Selaa lähdekoodia

fix: truncate oversized text before embedding to prevent GGML crash

When a chunk exceeds the embedding model's context window (trainContextSize),
node-llama-cpp's getEmbeddingFor() triggers a native SIGABRT in GGML/Metal,
crashing the entire process.

Fix: Add truncateToContextSize() guard in embed() and embedBatch() that uses
the model's own tokenizer to check token count before calling getEmbeddingFor().
Oversized text is truncated to (trainContextSize - 4) tokens with a warning,
preserving partial embedding coverage instead of crashing.

Fixes #303
edy 2 kuukautta sitten
vanhempi
commit
9718d3767c
1 muutettua tiedostoa jossa 40 lisäystä ja 3 poistoa
  1. 40 3
      src/llm.ts

+ 40 - 3
src/llm.ts

@@ -800,13 +800,42 @@ export class LlamaCpp implements LLM {
   // Core API methods
   // ==========================================================================
 
+  /**
+   * Truncate text to fit within the embedding model's context window.
+   * Uses the model's own tokenizer for accurate token counting, then
+   * detokenizes back to text if truncation is needed.
+   * Returns the (possibly truncated) text and whether truncation occurred.
+   */
+  private async truncateToContextSize(text: string): Promise<{ text: string; truncated: boolean }> {
+    if (!this.embedModel) return { text, truncated: false };
+
+    const maxTokens = this.embedModel.trainContextSize;
+    if (maxTokens <= 0) return { text, truncated: false };
+
+    const tokens = this.embedModel.tokenize(text);
+    if (tokens.length <= maxTokens) return { text, truncated: false };
+
+    // Leave a small margin (4 tokens) for BOS/EOS overhead
+    const safeLimit = Math.max(1, maxTokens - 4);
+    const truncatedTokens = tokens.slice(0, safeLimit);
+    const truncatedText = this.embedModel.detokenize(truncatedTokens);
+    return { text: truncatedText, truncated: true };
+  }
+
   async embed(text: string, options: EmbedOptions = {}): Promise<EmbeddingResult | null> {
     // Ping activity at start to keep models alive during this operation
     this.touchActivity();
 
     try {
       const context = await this.ensureEmbedContext();
-      const embedding = await context.getEmbeddingFor(text);
+
+      // Guard: truncate text that exceeds model context window to prevent GGML crash
+      const { text: safeText, truncated } = await this.truncateToContextSize(text);
+      if (truncated) {
+        console.warn(`⚠ Text truncated to fit embedding context (${this.embedModel?.trainContextSize} tokens)`);
+      }
+
+      const embedding = await context.getEmbeddingFor(safeText);
 
       return {
         embedding: Array.from(embedding.vector),
@@ -838,7 +867,11 @@ export class LlamaCpp implements LLM {
         const embeddings: ({ embedding: number[]; model: string } | null)[] = [];
         for (const text of texts) {
           try {
-            const embedding = await context.getEmbeddingFor(text);
+            const { text: safeText, truncated } = await this.truncateToContextSize(text);
+            if (truncated) {
+              console.warn(`⚠ Batch text truncated to fit embedding context (${this.embedModel?.trainContextSize} tokens)`);
+            }
+            const embedding = await context.getEmbeddingFor(safeText);
             this.touchActivity();
             embeddings.push({ embedding: Array.from(embedding.vector), model: this.embedModelUri });
           } catch (err) {
@@ -861,7 +894,11 @@ export class LlamaCpp implements LLM {
           const results: (EmbeddingResult | null)[] = [];
           for (const text of chunk) {
             try {
-              const embedding = await ctx.getEmbeddingFor(text);
+              const { text: safeText, truncated } = await this.truncateToContextSize(text);
+              if (truncated) {
+                console.warn(`⚠ Batch text truncated to fit embedding context (${this.embedModel?.trainContextSize} tokens)`);
+              }
+              const embedding = await ctx.getEmbeddingFor(safeText);
               this.touchActivity();
               results.push({ embedding: Array.from(embedding.vector), model: this.embedModelUri });
             } catch (err) {