Parcourir la source

Harden embedding overflow handling

Bek il y a 1 mois
Parent
commit
e4990e470e
4 fichiers modifiés avec 130 ajouts et 35 suppressions
  1. 22 12
      src/llm.ts
  2. 57 23
      src/store.ts
  3. 26 0
      test/llm.test.ts
  4. 25 0
      test/store.test.ts

+ 22 - 12
src/llm.ts

@@ -889,20 +889,30 @@ export class LlamaCpp implements LLM {
    * detokenizes back to text if truncation is needed.
    * Returns the (possibly truncated) text and whether truncation occurred.
    */
-  private async truncateToContextSize(text: string): Promise<{ text: string; truncated: boolean }> {
-    if (!this.embedModel) return { text, truncated: false };
+  private resolveEmbedTokenLimit(): number {
+    const trainedContextSize = this.embedModel?.trainContextSize;
+    if (typeof trainedContextSize === "number" && Number.isFinite(trainedContextSize) && trainedContextSize > 0) {
+      return Math.max(1, Math.min(LlamaCpp.EMBED_CONTEXT_SIZE, trainedContextSize));
+    }
+    return LlamaCpp.EMBED_CONTEXT_SIZE;
+  }
+
+  private async truncateToContextSize(
+    text: string
+  ): Promise<{ text: string; truncated: boolean; limit: number }> {
+    if (!this.embedModel) return { text, truncated: false, limit: LlamaCpp.EMBED_CONTEXT_SIZE };
 
-    const maxTokens = this.embedModel.trainContextSize;
-    if (maxTokens <= 0) return { text, truncated: false };
+    const maxTokens = this.resolveEmbedTokenLimit();
+    if (maxTokens <= 0) return { text, truncated: false, limit: maxTokens };
 
     const tokens = this.embedModel.tokenize(text);
-    if (tokens.length <= maxTokens) return { text, truncated: false };
+    if (tokens.length <= maxTokens) return { text, truncated: false, limit: maxTokens };
 
     // Leave a small margin (4 tokens) for BOS/EOS overhead
     const safeLimit = Math.max(1, maxTokens - 4);
     const truncatedTokens = tokens.slice(0, safeLimit);
     const truncatedText = this.embedModel.detokenize(truncatedTokens);
-    return { text: truncatedText, truncated: true };
+    return { text: truncatedText, truncated: true, limit: maxTokens };
   }
 
   async embed(text: string, options: EmbedOptions = {}): Promise<EmbeddingResult | null> {
@@ -913,9 +923,9 @@ export class LlamaCpp implements LLM {
       const context = await this.ensureEmbedContext();
 
       // Guard: truncate text that exceeds model context window to prevent GGML crash
-      const { text: safeText, truncated } = await this.truncateToContextSize(text);
+      const { text: safeText, truncated, limit } = await this.truncateToContextSize(text);
       if (truncated) {
-        console.warn(`⚠ Text truncated to fit embedding context (${this.embedModel?.trainContextSize} tokens)`);
+        console.warn(`⚠ Text truncated to fit embedding context (${limit} tokens)`);
       }
 
       const embedding = await context.getEmbeddingFor(safeText);
@@ -951,9 +961,9 @@ export class LlamaCpp implements LLM {
         const embeddings: ({ embedding: number[]; model: string } | null)[] = [];
         for (const text of texts) {
           try {
-            const { text: safeText, truncated } = await this.truncateToContextSize(text);
+            const { text: safeText, truncated, limit } = await this.truncateToContextSize(text);
             if (truncated) {
-              console.warn(`⚠ Batch text truncated to fit embedding context (${this.embedModel?.trainContextSize} tokens)`);
+              console.warn(`⚠ Batch text truncated to fit embedding context (${limit} tokens)`);
             }
             const embedding = await context.getEmbeddingFor(safeText);
             this.touchActivity();
@@ -978,9 +988,9 @@ export class LlamaCpp implements LLM {
           const results: (EmbeddingResult | null)[] = [];
           for (const text of chunk) {
             try {
-              const { text: safeText, truncated } = await this.truncateToContextSize(text);
+              const { text: safeText, truncated, limit } = await this.truncateToContextSize(text);
               if (truncated) {
-                console.warn(`⚠ Batch text truncated to fit embedding context (${this.embedModel?.trainContextSize} tokens)`);
+                console.warn(`⚠ Batch text truncated to fit embedding context (${limit} tokens)`);
               }
               const embedding = await ctx.getEmbeddingFor(safeText);
               this.touchActivity();

+ 57 - 23
src/store.ts

@@ -2228,33 +2228,67 @@ export async function chunkDocumentByTokens(
 
   // Tokenize and split any chunks that still exceed limit
   const results: { text: string; pos: number; tokens: number }[] = [];
+  const clampOverlapChars = (value: number, maxChars: number): number => {
+    if (maxChars <= 1) return 0;
+    return Math.max(0, Math.min(maxChars - 1, Math.floor(value)));
+  };
 
-  for (const chunk of charChunks) {
-    // Respect abort signal to avoid runaway tokenization
-    if (signal?.aborted) break;
+  const pushChunkWithinTokenLimit = async (text: string, pos: number): Promise<void> => {
+    if (signal?.aborted) return;
 
-    const tokens = await llm.tokenize(chunk.text);
+    const tokens = await llm.tokenize(text);
+    if (tokens.length <= maxTokens || text.length <= 1) {
+      results.push({ text, pos, tokens: tokens.length });
+      return;
+    }
 
-    if (tokens.length <= maxTokens) {
-      results.push({ text: chunk.text, pos: chunk.pos, tokens: tokens.length });
-    } else {
-      // Chunk is still too large - split it further
-      // Use actual token count to estimate better char limit
-      const actualCharsPerToken = chunk.text.length / tokens.length;
-      const safeMaxChars = Math.floor(maxTokens * actualCharsPerToken * 0.95); // 5% safety margin
-
-      const subChunks = chunkDocument(chunk.text, safeMaxChars, Math.floor(overlapChars * actualCharsPerToken / 2), Math.floor(windowChars * actualCharsPerToken / 2));
-
-      for (const subChunk of subChunks) {
-        if (signal?.aborted) break;
-        const subTokens = await llm.tokenize(subChunk.text);
-        results.push({
-          text: subChunk.text,
-          pos: chunk.pos + subChunk.pos,
-          tokens: subTokens.length,
-        });
-      }
+    const actualCharsPerToken = text.length / tokens.length;
+    let safeMaxChars = Math.floor(maxTokens * actualCharsPerToken * 0.95);
+    if (!Number.isFinite(safeMaxChars) || safeMaxChars < 1) {
+      safeMaxChars = Math.floor(text.length / 2);
+    }
+    safeMaxChars = Math.max(1, Math.min(text.length - 1, safeMaxChars));
+
+    let nextOverlapChars = clampOverlapChars(
+      overlapChars * actualCharsPerToken / 2,
+      safeMaxChars,
+    );
+    let nextWindowChars = Math.max(0, Math.floor(windowChars * actualCharsPerToken / 2));
+    let subChunks = chunkDocument(text, safeMaxChars, nextOverlapChars, nextWindowChars);
+
+    // Pathological single-line blobs can produce no meaningful breakpoint progress.
+    // Fall back to a simple half split so every recursion step strictly shrinks.
+    if (
+      subChunks.length <= 1
+      || subChunks[0]?.text.length === text.length
+    ) {
+      safeMaxChars = Math.max(1, Math.floor(text.length / 2));
+      nextOverlapChars = 0;
+      nextWindowChars = 0;
+      subChunks = chunkDocument(text, safeMaxChars, nextOverlapChars, nextWindowChars);
     }
+
+    if (
+      subChunks.length <= 1
+      || subChunks[0]?.text.length === text.length
+    ) {
+      const fallbackTokens = tokens.slice(0, Math.max(1, maxTokens));
+      const truncatedText = await llm.detokenize(fallbackTokens);
+      results.push({
+        text: truncatedText,
+        pos,
+        tokens: fallbackTokens.length,
+      });
+      return;
+    }
+
+    for (const subChunk of subChunks) {
+      await pushChunkWithinTokenLimit(text.slice(subChunk.pos, subChunk.pos + subChunk.text.length), pos + subChunk.pos);
+    }
+  };
+
+  for (const chunk of charChunks) {
+    await pushChunkWithinTokenLimit(chunk.text, chunk.pos);
   }
 
   return results;

+ 26 - 0
test/llm.test.ts

@@ -194,6 +194,32 @@ describe("LlamaCpp model resolution (config > env > default)", () => {
   });
 });
 
+describe("LlamaCpp embedding truncation", () => {
+  test("truncates against the active embedding context limit, not the model train context", async () => {
+    const llm = new LlamaCpp({}) as any;
+    const getEmbeddingFor = vi.fn(async (text: string) => ({
+      vector: new Float32Array([0.25, 0.5]),
+      text,
+    }));
+
+    llm.touchActivity = vi.fn();
+    llm.embedModel = {
+      trainContextSize: 8192,
+      tokenize: (text: string) => Array.from({ length: text.length }, () => 1),
+      detokenize: (tokens: readonly number[]) => "x".repeat(tokens.length),
+    };
+    llm.ensureEmbedContext = vi.fn().mockResolvedValue({ getEmbeddingFor });
+
+    const result = await llm.embed("x".repeat(3000));
+
+    expect(getEmbeddingFor).toHaveBeenCalledWith("x".repeat(2044));
+    expect(result).toEqual({
+      embedding: [0.25, 0.5],
+      model: llm.embedModelUri,
+    });
+  });
+});
+
 describe("LlamaCpp rerank deduping", () => {
   test("deduplicates identical document texts before scoring", async () => {
     const llm = new LlamaCpp({}) as any;

+ 25 - 0
test/store.test.ts

@@ -2805,6 +2805,31 @@ describe("Embedding batching", () => {
   });
 });
 
+describe("Token chunking guardrails", () => {
+  test("chunkDocumentByTokens keeps pathological single-line blobs under the token limit", async () => {
+    setDefaultLlamaCpp({
+      async tokenize(text: string) {
+        return Array.from({ length: text.length }, () => 1);
+      },
+      async detokenize(tokens: readonly number[]) {
+        return "x".repeat(tokens.length);
+      },
+    } as any);
+
+    try {
+      const chunks = await chunkDocumentByTokens("x".repeat(1200), 100, 15, 20);
+
+      expect(chunks.length).toBeGreaterThan(1);
+      expect(chunks.every((chunk) => chunk.tokens <= 100)).toBe(true);
+      for (let i = 1; i < chunks.length; i++) {
+        expect(chunks[i]!.pos).toBeGreaterThan(chunks[i - 1]!.pos);
+      }
+    } finally {
+      setDefaultLlamaCpp(null);
+    }
+  });
+});
+
 // =============================================================================
 // Content-Addressable Storage Tests
 // =============================================================================