Просмотр исходного кода

perf(rerank): cap parallelism, deduplicate chunks, cache by content

- Cap rerank contexts at 4 to avoid VRAM exhaustion on high-core machines
- Deduplicate identical chunk texts before sending to reranker
- Cache rerank scores by chunk content instead of file path — same text
  from different files now shares a single reranker call
- Add truncation cache to avoid re-tokenizing duplicate documents
Tobi Lutke 2 месяца назад
Родитель
Коммит
e3549dab1a
4 изменённых файла: 158 добавлено и 31 удалено
  1. 48 20
      src/llm.ts
  2. 16 11
      src/store.ts
  3. 59 0
      test/llm.test.ts
  4. 35 0
      test/store.test.ts

+ 48 - 20
src/llm.ts

@@ -763,7 +763,7 @@ export class LlamaCpp implements LLM {
     if (this.rerankContexts.length === 0) {
       const model = await this.ensureRerankModel();
       // ~960 MB per context with flash attention at contextSize 2048
-      const n = await this.computeParallelism(1000);
+      const n = Math.min(await this.computeParallelism(1000), 4);
       const threads = await this.threadsPerContext(n);
       for (let i = 0; i < n; i++) {
         try {
@@ -1057,6 +1057,7 @@ export class LlamaCpp implements LLM {
 
   // Qwen3 reranker chat template overhead (system prompt, tags, separators)
   private static readonly RERANK_TEMPLATE_OVERHEAD = 200;
+  private static readonly RERANK_TARGET_DOCS_PER_CONTEXT = 10;
 
   async rerank(
     query: string,
@@ -1073,34 +1074,58 @@ export class LlamaCpp implements LLM {
     // Budget = contextSize - template overhead - query tokens
     const queryTokens = model.tokenize(query).length;
     const maxDocTokens = LlamaCpp.RERANK_CONTEXT_SIZE - LlamaCpp.RERANK_TEMPLATE_OVERHEAD - queryTokens;
+    const truncationCache = new Map<string, string>();
 
     const truncatedDocs = documents.map((doc) => {
+      const cached = truncationCache.get(doc.text);
+      if (cached !== undefined) {
+        return cached === doc.text ? doc : { ...doc, text: cached };
+      }
+
       const tokens = model.tokenize(doc.text);
-      if (tokens.length <= maxDocTokens) return doc;
-      const truncatedText = model.detokenize(tokens.slice(0, maxDocTokens));
+      const truncatedText = tokens.length <= maxDocTokens
+        ? doc.text
+        : model.detokenize(tokens.slice(0, maxDocTokens));
+      truncationCache.set(doc.text, truncatedText);
+
+      if (truncatedText === doc.text) return doc;
       return { ...doc, text: truncatedText };
     });
 
-    // Build a map from document text to original indices (for lookup after sorting)
-    const textToDoc = new Map<string, { file: string; index: number }>();
+    // Deduplicate identical effective texts before scoring.
+    // This avoids redundant work for repeated chunks and fixes collisions where
+    // multiple docs map to the same chunk text.
+    const textToDocs = new Map<string, { file: string; index: number }[]>();
     truncatedDocs.forEach((doc, index) => {
-      textToDoc.set(doc.text, { file: doc.file, index });
+      const existing = textToDocs.get(doc.text);
+      if (existing) {
+        existing.push({ file: doc.file, index });
+      } else {
+        textToDocs.set(doc.text, [{ file: doc.file, index }]);
+      }
     });
 
     // Extract just the text for ranking
-    const texts = truncatedDocs.map((doc) => doc.text);
+    const texts = Array.from(textToDocs.keys());
 
     // Split documents across contexts for parallel evaluation.
     // Each context has its own sequence with a lock, so parallelism comes
     // from multiple contexts evaluating different chunks simultaneously.
-    const n = contexts.length;
-    const chunkSize = Math.ceil(texts.length / n);
-    const chunks = Array.from({ length: n }, (_, i) =>
+    const activeContextCount = Math.max(
+      1,
+      Math.min(
+        contexts.length,
+        Math.ceil(texts.length / LlamaCpp.RERANK_TARGET_DOCS_PER_CONTEXT)
+      )
+    );
+    const activeContexts = contexts.slice(0, activeContextCount);
+    const chunkSize = Math.ceil(texts.length / activeContexts.length);
+    const chunks = Array.from({ length: activeContexts.length }, (_, i) =>
       texts.slice(i * chunkSize, (i + 1) * chunkSize)
     ).filter(chunk => chunk.length > 0);
 
     const allScores = await Promise.all(
-      chunks.map((chunk, i) => contexts[i]!.rankAll(query, chunk))
+      chunks.map((chunk, i) => activeContexts[i]!.rankAll(query, chunk))
     );
 
     // Reassemble scores in original order and sort
@@ -1109,15 +1134,18 @@ export class LlamaCpp implements LLM {
       .map((text, i) => ({ document: text, score: flatScores[i]! }))
       .sort((a, b) => b.score - a.score);
 
-    // Map back to our result format using the text-to-doc map
-    const results: RerankDocumentResult[] = ranked.map((item) => {
-      const docInfo = textToDoc.get(item.document)!;
-      return {
-        file: docInfo.file,
-        score: item.score,
-        index: docInfo.index,
-      };
-    });
+    // Map back to our result format.
+    const results: RerankDocumentResult[] = [];
+    for (const item of ranked) {
+      const docInfos = textToDocs.get(item.document) ?? [];
+      for (const docInfo of docInfos) {
+        results.push({
+          file: docInfo.file,
+          score: item.score,
+          index: docInfo.index,
+        });
+      }
+    }
 
     return {
       results,

+ 16 - 11
src/store.ts

@@ -2381,38 +2381,43 @@ export async function expandQuery(query: string, model: string = DEFAULT_QUERY_M
 
 export async function rerank(query: string, documents: { file: string; text: string }[], model: string = DEFAULT_RERANK_MODEL, db: Database): Promise<{ file: string; score: number }[]> {
   const cachedResults: Map<string, number> = new Map();
-  const uncachedDocs: RerankDocument[] = [];
+  const uncachedDocsByChunk: Map<string, RerankDocument> = new Map();
 
   // Check cache for each document
   // Cache key includes chunk text — different queries can select different chunks
   // from the same file, and the reranker score depends on which chunk was sent.
+  // File path is excluded from the new cache key because the reranker score
+  // depends on the chunk content, not where it came from.
   for (const doc of documents) {
-    const cacheKey = getCacheKey("rerank", { query, file: doc.file, model, chunk: doc.text });
-    const cached = getCachedResult(db, cacheKey);
+    const cacheKey = getCacheKey("rerank", { query, model, chunk: doc.text });
+    const legacyCacheKey = getCacheKey("rerank", { query, file: doc.file, model, chunk: doc.text });
+    const cached = getCachedResult(db, cacheKey) ?? getCachedResult(db, legacyCacheKey);
     if (cached !== null) {
-      cachedResults.set(doc.file, parseFloat(cached));
+      cachedResults.set(doc.text, parseFloat(cached));
     } else {
-      uncachedDocs.push({ file: doc.file, text: doc.text });
+      uncachedDocsByChunk.set(doc.text, { file: doc.file, text: doc.text });
     }
   }
 
   // Rerank uncached documents using LlamaCpp
-  if (uncachedDocs.length > 0) {
+  if (uncachedDocsByChunk.size > 0) {
     const llm = getDefaultLlamaCpp();
+    const uncachedDocs = [...uncachedDocsByChunk.values()];
     const rerankResult = await llm.rerank(query, uncachedDocs, { model });
 
-    // Cache results — use original doc.text for cache key (result.file lacks chunk text)
-    const textByFile = new Map(documents.map(d => [d.file, d.text]));
+    // Cache results by chunk text so identical chunks across files are scored once.
+    const textByFile = new Map(uncachedDocs.map(d => [d.file, d.text]));
     for (const result of rerankResult.results) {
-      const cacheKey = getCacheKey("rerank", { query, file: result.file, model, chunk: textByFile.get(result.file) || "" });
+      const chunk = textByFile.get(result.file) || "";
+      const cacheKey = getCacheKey("rerank", { query, model, chunk });
       setCachedResult(db, cacheKey, result.score.toString());
-      cachedResults.set(result.file, result.score);
+      cachedResults.set(chunk, result.score);
     }
   }
 
   // Return all results sorted by score
   return documents
-    .map(doc => ({ file: doc.file, score: cachedResults.get(doc.file) || 0 }))
+    .map(doc => ({ file: doc.file, score: cachedResults.get(doc.text) || 0 }))
     .sort((a, b) => b.score - a.score);
 }
 

+ 59 - 0
test/llm.test.ts

@@ -117,6 +117,37 @@ describe("LlamaCpp expand context size config", () => {
   });
 });
 
+describe("LlamaCpp rerank deduping", () => {
+  test("deduplicates identical document texts before scoring", async () => {
+    const llm = new LlamaCpp({}) as any;
+    const rankAll = vi.fn(async (_query: string, docs: string[]) =>
+      docs.map((doc) => doc === "shared chunk" ? 0.9 : 0.2)
+    );
+
+    llm.touchActivity = vi.fn();
+    llm.ensureRerankContexts = vi.fn().mockResolvedValue([{ rankAll }]);
+    llm.ensureRerankModel = vi.fn().mockResolvedValue({
+      tokenize: (text: string) => Array.from(text),
+      detokenize: (tokens: string[]) => tokens.join(""),
+    });
+
+    const result = await llm.rerank("query", [
+      { file: "a.md", text: "shared chunk" },
+      { file: "b.md", text: "shared chunk" },
+      { file: "c.md", text: "different chunk" },
+    ]);
+
+    expect(rankAll).toHaveBeenCalledTimes(1);
+    expect(rankAll).toHaveBeenCalledWith("query", ["shared chunk", "different chunk"]);
+    expect(result.results).toHaveLength(3);
+
+    const scoreByFile = new Map(result.results.map((item) => [item.file, item.score]));
+    expect(scoreByFile.get("a.md")).toBe(0.9);
+    expect(scoreByFile.get("b.md")).toBe(0.9);
+    expect(scoreByFile.get("c.md")).toBe(0.2);
+  });
+});
+
 // =============================================================================
 // Integration Tests (require actual models)
 // =============================================================================
@@ -428,6 +459,34 @@ describe.skipIf(!!process.env.CI)("LlamaCpp Integration", () => {
       console.log(`Batch rerank of 10 docs took ${elapsed}ms`);
     });
 
+    test("uses fewer active rerank contexts for small batches", async () => {
+      const freshLlm = new LlamaCpp({});
+      const calls: number[] = [];
+      const fakeModel = {
+        tokenize: (text: string) => Array.from(text),
+        detokenize: (tokens: string[]) => tokens.join(""),
+      };
+      const fakeContexts = Array.from({ length: 4 }, (_, idx) => ({
+        rankAll: async (_query: string, docs: string[]) => {
+          calls.push(idx);
+          return docs.map(() => 0.5);
+        },
+      }));
+
+      (freshLlm as any).ensureRerankModel = async () => fakeModel;
+      (freshLlm as any).ensureRerankContexts = async () => fakeContexts;
+
+      const documents: RerankDocument[] = Array.from({ length: 20 }, (_, i) => ({
+        file: `doc${i}.md`,
+        text: `Document number ${i}`,
+      }));
+
+      const result = await freshLlm.rerank("topic 1", documents);
+
+      expect(result.results).toHaveLength(20);
+      expect(calls).toEqual([0, 1]);
+    });
+
     test("truncates and reranks document exceeding 2048 token context size", async () => {
       // The reranker context is created with contextSize=2048. Documents that
       // exceed the token budget (contextSize - template overhead - query tokens)

+ 35 - 0
test/store.test.ts

@@ -13,6 +13,7 @@ import { unlink, mkdtemp, rmdir, writeFile } from "node:fs/promises";
 import { tmpdir } from "node:os";
 import { join } from "node:path";
 import YAML from "yaml";
+import * as llmModule from "../src/llm.js";
 import { disposeDefaultLlamaCpp } from "../src/llm.js";
 import {
   createStore,
@@ -2427,6 +2428,40 @@ describe.skipIf(!!process.env.CI)("LlamaCpp Integration", () => {
 
     await cleanupTestDb(store);
   });
+
+  test("rerank deduplicates identical chunks across files", async () => {
+    const store = await createTestStore();
+    const rerankSpy = vi.fn(async (_query: string, docs: { file: string; text: string }[]) => ({
+      results: docs.map((doc, index) => ({
+        file: doc.file,
+        score: 1 - index * 0.1,
+        index,
+      })),
+      model: "mock-reranker",
+    }));
+
+    const llmSpy = vi.spyOn(llmModule, "getDefaultLlamaCpp").mockReturnValue({
+      rerank: rerankSpy,
+    } as any);
+
+    try {
+      const docs = [
+        { file: "doc1.md", text: "Shared chunk text" },
+        { file: "doc2.md", text: "Shared chunk text" },
+      ];
+
+      const first = await store.rerank("shared", docs);
+      const second = await store.rerank("shared", docs);
+
+      expect(first).toHaveLength(2);
+      expect(second).toHaveLength(2);
+      expect(rerankSpy).toHaveBeenCalledTimes(1);
+      expect(rerankSpy.mock.calls[0]?.[1]).toEqual([{ file: "doc2.md", text: "Shared chunk text" }]);
+    } finally {
+      llmSpy.mockRestore();
+      await cleanupTestDb(store);
+    }
+  });
 });
 
 // =============================================================================