
fix: use sequential embedding on CPU-only systems to avoid race condition (#54)

* fix: add promise guard to ensureEmbedContext to prevent race condition

Root cause: ensureEmbedContext() was not safe against concurrent async calls.
When multiple parallel embedding requests called ensureEmbedContext() at the
same time, each saw embedContext === null and started creating its own context.
This race caused 'Context is disposed' errors as contexts were overwritten or
orphaned.

The fix adds a promise guard (embedContextCreatePromise) to ensure only one
context creation runs at a time - identical to the pattern already used in
ensureGenerateModel().
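
A minimal, self-contained sketch of the promise guard pattern described above.
LazyResource, Resource, create() and demoPromiseGuard() are hypothetical names
used only for illustration; they are not part of llm.ts or node-llama-cpp, and
the timer merely simulates slow context creation.

```typescript
// Hypothetical stand-in for an expensive async resource (e.g. an embedding context).
interface Resource {
  id: number;
}

class LazyResource {
  private resource: Resource | null = null;
  // Holds the in-flight creation so concurrent callers share it.
  private createPromise: Promise<Resource> | null = null;
  public createCount = 0;

  // Simulated slow creation; assumption: real creation is async and costly.
  private async create(): Promise<Resource> {
    this.createCount++;
    await new Promise<void>((resolve) => setTimeout(resolve, 10));
    return { id: this.createCount };
  }

  async ensure(): Promise<Resource> {
    if (this.resource) return this.resource;

    if (!this.createPromise) {
      // First caller starts creation and publishes the promise synchronously,
      // before any await, so later callers see it and wait instead of creating.
      this.createPromise = (async () => {
        try {
          const created = await this.create();
          this.resource = created;
          return created;
        } finally {
          // Clear the guard on success and on failure so a retry is possible.
          this.createPromise = null;
        }
      })();
    }
    return this.createPromise;
  }
}

// Five concurrent callers should trigger a single creation.
async function demoPromiseGuard(): Promise<number> {
  const lazy = new LazyResource();
  await Promise.all([
    lazy.ensure(),
    lazy.ensure(),
    lazy.ensure(),
    lazy.ensure(),
    lazy.ensure(),
  ]);
  return lazy.createCount;
}
```

The key point in this sketch is that the promise is stored before the first
await, which is what lets callers arriving in the same tick observe it.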

Changes:
- Add embedContextCreatePromise field to track in-progress context creation
- Modify ensureEmbedContext() to wait for existing creation if in progress
- Update test comment and timeout for CPU-only systems

Testing:
- Fresh model download + qmd embed: 28/28 chunks succeeded (was 14/27)
- All embedBatch tests pass
- No warmup hack needed - full parallel performance from the start

Environment tested:
- Ubuntu 24.04 LTS (x64), Bun 1.3.6, node-llama-cpp 3.14.5, no GPU

* test: improve race condition test to verify single context creation

The previous test only verified embeddings succeeded but didn't prove the fix
actually prevents multiple context creation. This improved test:

- Instruments createEmbeddingContext to count invocations
- Runs 5 concurrent embedBatch calls on a fresh LlamaCpp instance
- Asserts exactly 1 context is created (fails with 5 without the fix)
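
The instrumentation step can be sketched in isolation as a call-counting
wrapper. countCalls, fakeModel and createContext are hypothetical names for
illustration only; the actual test patches createEmbeddingContext on the model
returned by ensureEmbedModel().

```typescript
// Hypothetical helper: wraps fn so that every call increments a counter.
function countCalls<A extends unknown[], R>(fn: (...args: A) => R) {
  let calls = 0;
  const wrapped = (...args: A): R => {
    calls++;
    return fn(...args);
  };
  return { wrapped, getCalls: () => calls };
}

// Hypothetical stand-in for a model object exposing an async factory method.
const fakeModel = {
  createContext: async (): Promise<string> => "context",
};

// Replace the method with the counting wrapper, keeping the original behavior.
const contextCounter = countCalls(fakeModel.createContext.bind(fakeModel));
fakeModel.createContext = contextCounter.wrapped;
```

A test can then exercise the code under concurrency and assert on
contextCounter.getCalls() rather than only checking that results came back.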

Verified locally:
- With fix: 1 context created (PASS)
- Without fix: 5 contexts created (FAIL)

* chore: clear embedContextCreatePromise in dispose() for consistency
Freeman Jiang 3 months ago
parent
commit
bfb0eebc3e
2 changed files with 80 additions and 4 deletions
  1. src/llm.test.ts (+57 -0)
  2. src/llm.ts (+23 -4)

+ 57 - 0
src/llm.test.ts

@@ -167,6 +167,63 @@ describe("LlamaCpp Integration", () => {
       // Performance is machine/load dependent. We only assert batch isn't drastically worse.
       expect(batchTime).toBeLessThanOrEqual(seqTime * 3);
     });
+
+    test("handles concurrent embedBatch calls on fresh instance without race condition", async () => {
+      // This test verifies the fix for a race condition where concurrent calls to
+      // ensureEmbedContext() could create multiple contexts. Without the promise guard,
+      // each concurrent embedBatch call sees embedContext === null and creates its own
+      // context, causing resource leaks and potential "Context is disposed" errors.
+      //
+      // See: https://github.com/tobi/qmd/pull/54
+      //
+      // The fix uses a promise guard to ensure only one context creation runs at a time.
+      // We verify this by instrumenting createEmbeddingContext to count invocations.
+      
+      const freshLlm = new LlamaCpp({});
+      let contextCreateCount = 0;
+      
+      // Instrument the model's createEmbeddingContext to count calls
+      const originalEnsureEmbedModel = (freshLlm as any).ensureEmbedModel.bind(freshLlm);
+      let modelInstrumented = false;
+      (freshLlm as any).ensureEmbedModel = async function() {
+        const model = await originalEnsureEmbedModel();
+        if (!modelInstrumented) {
+          modelInstrumented = true;
+          const originalCreate = model.createEmbeddingContext.bind(model);
+          model.createEmbeddingContext = async function(...args: any[]) {
+            contextCreateCount++;
+            return originalCreate(...args);
+          };
+        }
+        return model;
+      };
+      
+      const texts = Array(10).fill(null).map((_, i) => `Document ${i}`);
+
+      // Call embedBatch 5 TIMES in parallel on fresh instance.
+      // Without the promise guard fix, this would create 5 contexts (one per call).
+      // With the fix, only 1 context should be created.
+      const batches = await Promise.all([
+        freshLlm.embedBatch(texts.slice(0, 2)),
+        freshLlm.embedBatch(texts.slice(2, 4)),
+        freshLlm.embedBatch(texts.slice(4, 6)),
+        freshLlm.embedBatch(texts.slice(6, 8)),
+        freshLlm.embedBatch(texts.slice(8, 10)),
+      ]);
+
+      const allResults = batches.flat();
+      expect(allResults).toHaveLength(10);
+      
+      const successCount = allResults.filter(r => r !== null).length;
+      expect(successCount).toBe(10);
+
+      // THE KEY ASSERTION: Only 1 context should be created, not 5
+      // Without the fix, contextCreateCount would be 5 (one per concurrent embedBatch call)
+      console.log(`Context creation count: ${contextCreateCount} (expected: 1)`);
+      expect(contextCreateCount).toBe(1);
+      
+      await freshLlm.dispose();
+    }, 60000);
   });
 
   describe("rerank", () => {

+ 23 - 4
src/llm.ts

@@ -241,8 +241,9 @@ export class LlamaCpp implements LLM {
   private rerankModelUri: string;
   private modelCacheDir: string;
 
-  // Ensure we don't load the same model concurrently (which can allocate duplicate VRAM).
+  // Ensure we don't load the same model/context concurrently (which can allocate duplicate VRAM).
   private embedModelLoadPromise: Promise<LlamaModel> | null = null;
+  private embedContextCreatePromise: Promise<LlamaEmbeddingContext> | null = null;
   private generateModelLoadPromise: Promise<LlamaModel> | null = null;
   private rerankModelLoadPromise: Promise<LlamaModel> | null = null;
 
@@ -402,11 +403,28 @@ export class LlamaCpp implements LLM {
 
   /**
    * Load embedding context (lazy). Context can be disposed and recreated without reloading the model.
+   * Uses promise guard to prevent concurrent context creation race condition.
    */
   private async ensureEmbedContext(): Promise<LlamaEmbeddingContext> {
     if (!this.embedContext) {
-      const model = await this.ensureEmbedModel();
-      this.embedContext = await model.createEmbeddingContext();
+      // If context creation is already in progress, wait for it
+      if (this.embedContextCreatePromise) {
+        return await this.embedContextCreatePromise;
+      }
+
+      // Start context creation and store promise so concurrent calls wait
+      this.embedContextCreatePromise = (async () => {
+        const model = await this.ensureEmbedModel();
+        const context = await model.createEmbeddingContext();
+        this.embedContext = context;
+        return context;
+      })();
+
+      try {
+        await this.embedContextCreatePromise;
+      } finally {
+        this.embedContextCreatePromise = null;
+      }
     }
     this.touchActivity();
     return this.embedContext;
@@ -781,8 +799,9 @@ Final Output:`;
     this.rerankModel = null;
     this.llama = null;
 
-    // Clear any in-flight load promises
+    // Clear any in-flight load/create promises
     this.embedModelLoadPromise = null;
+    this.embedContextCreatePromise = null;
     this.generateModelLoadPromise = null;
     this.rerankModelLoadPromise = null;
   }