Просмотр исходного кода

perf: batch vector embeddings + collection-aware FTS filtering

Three improvements to hybridQuery:

1. Collection filter pushed into SQL: searchFTS and searchVec now
   accept collectionName directly instead of filtering post-hoc.
   Reduces noise in FTS probe and all expanded-query FTS calls.
   Also fixes MCP server's FTS search to use SQL-level filtering.

2. Batch embed for vector searches: instead of embedding each
   vec/hyde query sequentially (one embed call per query), we now
   collect all texts that need vector search and embed them in a
   single embedBatch() call. The sqlite-vec lookups still run
   sequentially (they're fast), but the expensive LLM embed step
   is batched.

3. FTS-first ordering: all lex expansions run immediately (sync,
   no LLM needed) before the vector embedding batch. This means
   FTS results are already collected before any embedding work starts.

Also cleans up legacy collectionId parameter naming (was number,
now properly string collectionName throughout).
Tobi Lütke 3 месяцев назад
Родитель
Коммит
b69fae7aa3
4 изменённых файла: 55 добавлено и 49 удалено
  1. 1 3
      src/mcp.ts
  2. 1 2
      src/qmd.ts
  3. 2 2
      src/store.test.ts
  4. 51 42
      src/store.ts

+ 1 - 3
src/mcp.ts

@@ -237,9 +237,7 @@ function createMcpServer(store: Store): McpServer {
       },
     },
     async ({ query, limit, minScore, collection }) => {
-      // Note: Collection filtering is now done post-search since collections are managed in YAML
-      const results = store.searchFTS(query, limit || 10)
-        .filter(r => !collection || r.collectionName === collection);
+      const results = store.searchFTS(query, limit || 10, collection);
       const filtered: SearchResultItem[] = results
         .filter(r => r.score >= (minScore || 0))
         .map(r => {

+ 1 - 2
src/qmd.ts

@@ -1871,8 +1871,7 @@ function search(query: string, opts: OutputOptions): void {
 
   // Use large limit for --all, otherwise fetch more than needed and let outputResults filter
   const fetchLimit = opts.all ? 100000 : Math.max(50, opts.limit * 2);
-  // searchFTS accepts collection name as number parameter for legacy reasons (will be fixed in store.ts)
-  const results = searchFTS(db, query, fetchLimit, collectionName as any);
+  const results = searchFTS(db, query, fetchLimit, collectionName);
 
   // Add context to results
   const resultsWithContext = results.map(r => ({

+ 2 - 2
src/store.test.ts

@@ -911,8 +911,8 @@ describe("FTS Search", () => {
     const allResults = store.searchFTS("searchable", 10);
     expect(allResults).toHaveLength(2);
 
-    // Filter by collection name (collectionId is now treated as collection name string)
-    const filtered = store.searchFTS("searchable", 10, collection1 as unknown as number);
+    // Filter by collection name
+    const filtered = store.searchFTS("searchable", 10, collection1);
     expect(filtered).toHaveLength(1);
     expect(filtered[0]!.displayPath).toBe(`${collection1}/doc1.md`);
 

+ 51 - 42
src/store.ts

@@ -669,8 +669,8 @@ export type Store = {
   toVirtualPath: (absolutePath: string) => string | null;
 
   // Search
-  searchFTS: (query: string, limit?: number, collectionId?: number) => SearchResult[];
-  searchVec: (query: string, model: string, limit?: number, collectionName?: string) => Promise<SearchResult[]>;
+  searchFTS: (query: string, limit?: number, collectionName?: string) => SearchResult[];
+  searchVec: (query: string, model: string, limit?: number, collectionName?: string, session?: ILLMSession, precomputedEmbedding?: number[]) => Promise<SearchResult[]>;
 
   // Query expansion & reranking
   expandQuery: (query: string, model?: string) => Promise<ExpandedQuery[]>;
@@ -752,8 +752,8 @@ export function createStore(dbPath?: string): Store {
     toVirtualPath: (absolutePath: string) => toVirtualPath(db, absolutePath),
 
     // Search
-    searchFTS: (query: string, limit?: number, collectionId?: number) => searchFTS(db, query, limit, collectionId),
-    searchVec: (query: string, model: string, limit?: number, collectionName?: string) => searchVec(db, query, model, limit, collectionName),
+    searchFTS: (query: string, limit?: number, collectionName?: string) => searchFTS(db, query, limit, collectionName),
+    searchVec: (query: string, model: string, limit?: number, collectionName?: string, session?: ILLMSession, precomputedEmbedding?: number[]) => searchVec(db, query, model, limit, collectionName, session, precomputedEmbedding),
 
     // Query expansion & reranking
     expandQuery: (query: string, model?: string) => expandQuery(query, model, db),
@@ -1897,7 +1897,7 @@ function buildFTS5Query(query: string): string | null {
   return terms.map(t => `"${t}"*`).join(' AND ');
 }
 
-export function searchFTS(db: Database, query: string, limit: number = 20, collectionId?: number): SearchResult[] {
+export function searchFTS(db: Database, query: string, limit: number = 20, collectionName?: string): SearchResult[] {
   const ftsQuery = buildFTS5Query(query);
   if (!ftsQuery) return [];
 
@@ -1916,12 +1916,9 @@ export function searchFTS(db: Database, query: string, limit: number = 20, colle
   `;
   const params: (string | number)[] = [ftsQuery];
 
-  if (collectionId) {
-    // Note: collectionId is a legacy parameter that should be phased out
-    // Collections are now managed in YAML. For now, we interpret it as a collection name filter.
-    // This code path is likely unused as collection filtering should be done at CLI level.
+  if (collectionName) {
     sql += ` AND d.collection = ?`;
-    params.push(String(collectionId));
+    params.push(String(collectionName));
   }
 
   // bm25 lower is better; sort ascending.
@@ -1957,11 +1954,11 @@ export function searchFTS(db: Database, query: string, limit: number = 20, colle
 // Vector Search
 // =============================================================================
 
-export async function searchVec(db: Database, query: string, model: string, limit: number = 20, collectionName?: string, session?: ILLMSession): Promise<SearchResult[]> {
+export async function searchVec(db: Database, query: string, model: string, limit: number = 20, collectionName?: string, session?: ILLMSession, precomputedEmbedding?: number[]): Promise<SearchResult[]> {
   const tableExists = db.prepare(`SELECT name FROM sqlite_master WHERE type='table' AND name='vectors_vec'`).get();
   if (!tableExists) return [];
 
-  const embedding = await getEmbedding(query, model, true, session);
+  const embedding = precomputedEmbedding ?? await getEmbedding(query, model, true, session);
   if (!embedding) return [];
 
   // IMPORTANT: We use a two-step query approach here because sqlite-vec virtual tables
@@ -2722,8 +2719,8 @@ export async function hybridQuery(
   ).get();
 
   // Step 1: BM25 probe — strong signal skips expensive LLM expansion
-  const initialFts = store.searchFTS(query, 20)
-    .filter(r => !collection || r.collectionName === collection);
+  // Pass collection directly into FTS query (filter at SQL level, not post-hoc)
+  const initialFts = store.searchFTS(query, 20, collection);
   const topScore = initialFts[0]?.score ?? 0;
   const secondScore = initialFts[1]?.score ?? 0;
   const hasStrongSignal = initialFts.length > 0
@@ -2749,26 +2746,15 @@ export async function hybridQuery(
   }
 
   // Step 3: Route searches by query type
-  // Original query → vector search (FTS already covered by probe in step 1).
-  // Vector searches run sequentially — node-llama-cpp's embed context
-  // hangs on concurrent embed() calls (known limitation).
-  if (hasVectors) {
-    const vecResults = await store.searchVec(query, DEFAULT_EMBED_MODEL, 20, collection);
-    if (vecResults.length > 0) {
-      for (const r of vecResults) docidMap.set(r.filepath, r.docid);
-      rankedLists.push(vecResults.map(r => ({
-        file: r.filepath, displayPath: r.displayPath,
-        title: r.title, body: r.body || "", score: r.score,
-      })));
-    }
-  }
+  //
+  // Strategy: run all FTS queries immediately (they're sync/instant), then
+  // batch-embed all vector queries in one embedBatch() call, then run
+  // sqlite-vec lookups with pre-computed embeddings.
 
-  // Expanded queries → route by type: lex→FTS only, vec/hyde→vector only.
-  // This restores the CLI's query-type-aware routing that was lost in the initial refactor.
+  // 3a: Run FTS for all lex expansions right away (no LLM needed)
   for (const q of expanded) {
     if (q.type === 'lex') {
-      const ftsResults = store.searchFTS(q.text, 20)
-        .filter(r => !collection || r.collectionName === collection);
+      const ftsResults = store.searchFTS(q.text, 20, collection);
       if (ftsResults.length > 0) {
         for (const r of ftsResults) docidMap.set(r.filepath, r.docid);
         rankedLists.push(ftsResults.map(r => ({
@@ -2776,17 +2762,40 @@ export async function hybridQuery(
           title: r.title, body: r.body || "", score: r.score,
         })));
       }
-    } else {
-      // vec or hyde → vector search only
-      if (hasVectors) {
-        const vecResults = await store.searchVec(q.text, DEFAULT_EMBED_MODEL, 20, collection);
-        if (vecResults.length > 0) {
-          for (const r of vecResults) docidMap.set(r.filepath, r.docid);
-          rankedLists.push(vecResults.map(r => ({
-            file: r.filepath, displayPath: r.displayPath,
-            title: r.title, body: r.body || "", score: r.score,
-          })));
-        }
+    }
+  }
+
+  // 3b: Collect all texts that need vector search (original query + vec/hyde expansions)
+  if (hasVectors) {
+    const vecQueries: { text: string; isOriginal: boolean }[] = [
+      { text: query, isOriginal: true },
+    ];
+    for (const q of expanded) {
+      if (q.type === 'vec' || q.type === 'hyde') {
+        vecQueries.push({ text: q.text, isOriginal: false });
+      }
+    }
+
+    // Batch embed all vector queries in a single call
+    const llm = getDefaultLlamaCpp();
+    const textsToEmbed = vecQueries.map(q => formatQueryForEmbedding(q.text));
+    const embeddings = await llm.embedBatch(textsToEmbed);
+
+    // Run sqlite-vec lookups with pre-computed embeddings
+    for (let i = 0; i < vecQueries.length; i++) {
+      const embedding = embeddings[i]?.embedding;
+      if (!embedding) continue;
+
+      const vecResults = await store.searchVec(
+        vecQueries[i]!.text, DEFAULT_EMBED_MODEL, 20, collection,
+        undefined, embedding
+      );
+      if (vecResults.length > 0) {
+        for (const r of vecResults) docidMap.set(r.filepath, r.docid);
+        rankedLists.push(vecResults.map(r => ({
+          file: r.filepath, displayPath: r.displayPath,
+          title: r.title, body: r.body || "", score: r.score,
+        })));
       }
     }
   }