瀏覽代碼

Fix qmd embed model selection (#494)

LJY 1 月之前
父節點
當前提交
698b44fe87
共有 4 個文件被更改,包括 47 次插入17 次删除
  1. 2 2
      src/cli/qmd.ts
  2. 7 7
      src/llm.ts
  3. 6 6
      src/store.ts
  4. 32 2
      test/store.test.ts

+ 2 - 2
src/cli/qmd.ts

@@ -1652,7 +1652,7 @@ function parseChunkStrategy(value: unknown): ChunkStrategy | undefined {
 }
 
 async function vectorIndex(
-  model: string = DEFAULT_EMBED_MODEL,
+  model: string = DEFAULT_EMBED_MODEL_URI,
   force: boolean = false,
   batchOptions?: { maxDocsPerBatch?: number; maxBatchBytes?: number; chunkStrategy?: ChunkStrategy },
 ): Promise<void> {
@@ -3002,7 +3002,7 @@ if (isMain) {
         const maxDocsPerBatch = parseEmbedBatchOption("maxDocsPerBatch", cli.values["max-docs-per-batch"]);
         const maxBatchMb = parseEmbedBatchOption("maxBatchBytes", cli.values["max-batch-mb"]);
         const embedChunkStrategy = parseChunkStrategy(cli.values["chunk-strategy"]);
-        await vectorIndex(DEFAULT_EMBED_MODEL, !!cli.values.force, {
+        await vectorIndex(DEFAULT_EMBED_MODEL_URI, !!cli.values.force, {
           maxDocsPerBatch,
           maxBatchBytes: maxBatchMb === undefined ? undefined : maxBatchMb * 1024 * 1024,
           chunkStrategy: embedChunkStrategy,

+ 7 - 7
src/llm.ts

@@ -155,7 +155,7 @@ export type LLMSessionOptions = {
  */
 export interface ILLMSession {
   embed(text: string, options?: EmbedOptions): Promise<EmbeddingResult | null>;
-  embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]>;
+  embedBatch(texts: string[], options?: EmbedOptions): Promise<(EmbeddingResult | null)[]>;
   expandQuery(query: string, options?: { context?: string; includeLexical?: boolean }): Promise<Queryable[]>;
   rerank(query: string, documents: RerankDocument[], options?: RerankOptions): Promise<RerankResult>;
   /** Whether this session is still valid (not released or aborted) */
@@ -880,7 +880,7 @@ export class LlamaCpp implements LLM {
 
       return {
         embedding: Array.from(embedding.vector),
-        model: this.embedModelUri,
+        model: options.model ?? this.embedModelUri,
       };
     } catch (error) {
       console.error("Embedding error:", error);
@@ -892,7 +892,7 @@ export class LlamaCpp implements LLM {
    * Batch embed multiple texts efficiently
    * Uses Promise.all for parallel embedding - node-llama-cpp handles batching internally
    */
-  async embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]> {
+  async embedBatch(texts: string[], options: EmbedOptions = {}): Promise<(EmbeddingResult | null)[]> {
     if (this._ciMode) throw new Error("LLM operations are disabled in CI (set CI=true)");
     // Ping activity at start to keep models alive during this operation
     this.touchActivity();
@@ -915,7 +915,7 @@ export class LlamaCpp implements LLM {
             }
             const embedding = await context.getEmbeddingFor(safeText);
             this.touchActivity();
-            embeddings.push({ embedding: Array.from(embedding.vector), model: this.embedModelUri });
+            embeddings.push({ embedding: Array.from(embedding.vector), model: options.model ?? this.embedModelUri });
           } catch (err) {
             console.error("Embedding error for text:", err);
             embeddings.push(null);
@@ -942,7 +942,7 @@ export class LlamaCpp implements LLM {
               }
               const embedding = await ctx.getEmbeddingFor(safeText);
               this.touchActivity();
-              results.push({ embedding: Array.from(embedding.vector), model: this.embedModelUri });
+              results.push({ embedding: Array.from(embedding.vector), model: options.model ?? this.embedModelUri });
             } catch (err) {
               console.error("Embedding error for text:", err);
               results.push(null);
@@ -1431,8 +1431,8 @@ class LLMSession implements ILLMSession {
     return this.withOperation(() => this.manager.getLlamaCpp().embed(text, options));
   }
 
-  async embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]> {
-    return this.withOperation(() => this.manager.getLlamaCpp().embedBatch(texts));
+  async embedBatch(texts: string[], options?: EmbedOptions): Promise<(EmbeddingResult | null)[]> {
+    return this.withOperation(() => this.manager.getLlamaCpp().embedBatch(texts, options));
   }
 
   async expandQuery(

+ 6 - 6
src/store.ts

@@ -1466,8 +1466,8 @@ export async function generateEmbeddings(
 
       if (!vectorTableInitialized) {
         const firstChunk = batchChunks[0]!;
-        const firstText = formatDocForEmbedding(firstChunk.text, firstChunk.title);
-        const firstResult = await session.embed(firstText);
+        const firstText = formatDocForEmbedding(firstChunk.text, firstChunk.title, model);
+        const firstResult = await session.embed(firstText, { model });
         if (!firstResult) {
           throw new Error("Failed to get embedding dimensions from first chunk");
         }
@@ -1498,10 +1498,10 @@ export async function generateEmbeddings(
 
         const batchEnd = Math.min(batchStart + BATCH_SIZE, batchChunks.length);
         const chunkBatch = batchChunks.slice(batchStart, batchEnd);
-        const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title));
+        const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title, model));
 
         try {
-          const embeddings = await session.embedBatch(texts);
+          const embeddings = await session.embedBatch(texts, { model });
           for (let i = 0; i < chunkBatch.length; i++) {
             const chunk = chunkBatch[i]!;
             const embedding = embeddings[i];
@@ -1522,8 +1522,8 @@ export async function generateEmbeddings(
           } else {
             for (const chunk of chunkBatch) {
               try {
-                const text = formatDocForEmbedding(chunk.text, chunk.title);
-                const result = await session.embed(text);
+                const text = formatDocForEmbedding(chunk.text, chunk.title, model);
+                const result = await session.embed(text, { model });
                 if (result) {
                   insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(result.embedding), model, now);
                   chunksEmbedded++;

+ 32 - 2
test/store.test.ts

@@ -2753,13 +2753,19 @@ describe("Embedding batching", () => {
 
   function createFakeEmbedLlm() {
     const embedBatchCalls: string[][] = [];
+    const embedCalls: { text: string; options?: { model?: string } }[] = [];
+    const embedBatchModelCalls: ({ model?: string } | undefined)[] = [];
     return {
       embedBatchCalls,
-      async embed(_text: string) {
+      embedCalls,
+      embedBatchModelCalls,
+      async embed(text: string, options?: { model?: string }) {
+        embedCalls.push({ text, options });
         return { embedding: [0.1, 0.2, 0.3], model: "fake-embed" };
       },
-      async embedBatch(texts: string[]) {
+      async embedBatch(texts: string[], options?: { model?: string }) {
         embedBatchCalls.push([...texts]);
+        embedBatchModelCalls.push(options);
         return texts.map((_text, index) => ({
           embedding: [index + 1, index + 2, index + 3],
           model: "fake-embed",
@@ -2832,6 +2838,30 @@ describe("Embedding batching", () => {
     }
   });
 
+  test("generateEmbeddings passes the selected model through to embed calls and metadata", async () => {
+    const store = await createTestStore();
+    const db = store.db;
+    const fakeLlm = createFakeEmbedLlm();
+    const model = "hf:Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf";
+
+    setDefaultLlamaCpp(createFakeTokenizer() as any);
+    store.llm = fakeLlm as any;
+
+    try {
+      await insertTestDocument(db, "docs", { name: "one", body: "# One\n\nAlpha" });
+
+      const result = await generateEmbeddings(store, { model });
+
+      expect(result.chunksEmbedded).toBe(1);
+      expect(fakeLlm.embedCalls[0]?.options?.model).toBe(model);
+      expect(fakeLlm.embedBatchModelCalls).toEqual([{ model }]);
+      expect(db.prepare(`SELECT DISTINCT model FROM content_vectors`).all()).toEqual([{ model }]);
+    } finally {
+      setDefaultLlamaCpp(null);
+      await cleanupTestDb(store);
+    }
+  });
+
   test("generateEmbeddings rejects invalid batch limits", async () => {
     const store = await createTestStore();