Sfoglia il codice sorgente

perf: parallel reranking with multiple contexts (2.7x speedup)

node-llama-cpp's LlamaRankingContext uses a single sequence with a
withLock() guard, making rankAll() effectively sequential despite
using Promise.all(). Each document evaluation erases the context,
evaluates tokens, and extracts the logit — all serialized.

Fix: create 4 parallel ranking contexts from the same model (model
weights are shared, only KV cache is duplicated). Split documents
across contexts and evaluate in parallel via Promise.all().

Benchmarks (40 chunks, CUDA, 4x A6000):
- 1 context:  898ms (baseline)
- 2 contexts: 460ms (2.0x)
- 4 contexts: 338ms (2.7x)  ← sweet spot
- 8 contexts: 458ms (VRAM contention)

End-to-end 'qmd query' time: 7.5s → 3.7s

Gracefully handles VRAM limits — if creating the Nth context fails,
falls back to however many were successfully created.
Tobi Lütke 3 mesi fa
parent
commit
0a0e1e6f29
1 ha cambiato i file con 46 aggiunte e 14 eliminazioni
  1. 46 14
      src/llm.ts

+ 46 - 14
src/llm.ts

@@ -357,7 +357,7 @@ export class LlamaCpp implements LLM {
   private embedContext: LlamaEmbeddingContext | null = null;
   private generateModel: LlamaModel | null = null;
   private rerankModel: LlamaModel | null = null;
-  private rerankContext: Awaited<ReturnType<LlamaModel["createRankingContext"]>> | null = null;
+  private rerankContexts: Awaited<ReturnType<LlamaModel["createRankingContext"]>>[] = [];
 
   private embedModelUri: string;
   private generateModelUri: string;
@@ -423,7 +423,7 @@ export class LlamaCpp implements LLM {
    * Check if any contexts are currently loaded (and therefore worth unloading on inactivity).
    */
   private hasLoadedContexts(): boolean {
-    return !!(this.embedContext || this.rerankContext);
+    return !!(this.embedContext || this.rerankContexts.length > 0);
   }
 
   /**
@@ -449,10 +449,10 @@ export class LlamaCpp implements LLM {
       await this.embedContext.dispose();
       this.embedContext = null;
     }
-    if (this.rerankContext) {
-      await this.rerankContext.dispose();
-      this.rerankContext = null;
+    for (const ctx of this.rerankContexts) {
+      await ctx.dispose();
     }
+    this.rerankContexts = [];
 
     // Optionally dispose models too (opt-in)
     if (this.disposeModelsOnInactivity) {
@@ -646,15 +646,31 @@ export class LlamaCpp implements LLM {
   }
 
   /**
-   * Load rerank context (lazy). Context can be disposed and recreated without reloading the model.
+   * Load rerank contexts (lazy). Creates multiple contexts for parallel ranking.
+   * Each context has its own sequence, so they can evaluate independently.
    */
-  private async ensureRerankContext(): Promise<Awaited<ReturnType<LlamaModel["createRankingContext"]>>> {
-    if (!this.rerankContext) {
+  private static readonly RERANK_PARALLEL_CONTEXTS = 4;
+
+  private async ensureRerankContexts(): Promise<Awaited<ReturnType<LlamaModel["createRankingContext"]>>[]> {
+    if (this.rerankContexts.length === 0) {
       const model = await this.ensureRerankModel();
-      this.rerankContext = await model.createRankingContext();
+      const n = LlamaCpp.RERANK_PARALLEL_CONTEXTS;
+      // Create contexts sequentially to avoid VRAM allocation races
+      for (let i = 0; i < n; i++) {
+        try {
+          this.rerankContexts.push(await model.createRankingContext());
+        } catch {
+          // VRAM exhausted — use however many we got
+          if (this.rerankContexts.length === 0) {
+            // Must have at least one
+            throw new Error("Failed to create any rerank context");
+          }
+          break;
+        }
+      }
     }
     this.touchActivity();
-    return this.rerankContext;
+    return this.rerankContexts;
   }
 
   // ==========================================================================
@@ -901,7 +917,7 @@ export class LlamaCpp implements LLM {
     // Ping activity at start to keep models alive during this operation
     this.touchActivity();
 
-    const context = await this.ensureRerankContext();
+    const contexts = await this.ensureRerankContexts();
 
     // Build a map from document text to original indices (for lookup after sorting)
     const textToDoc = new Map<string, { file: string; index: number }>();
@@ -912,8 +928,24 @@ export class LlamaCpp implements LLM {
     // Extract just the text for ranking
     const texts = documents.map((doc) => doc.text);
 
-    // Use the proper ranking API - returns [{document: string, score: number}] sorted by score
-    const ranked = await context.rankAndSort(query, texts);
+    // Split documents across contexts for parallel evaluation.
+    // Each context has its own sequence with a lock, so parallelism comes
+    // from multiple contexts evaluating different chunks simultaneously.
+    const n = contexts.length;
+    const chunkSize = Math.ceil(texts.length / n);
+    const chunks = Array.from({ length: n }, (_, i) =>
+      texts.slice(i * chunkSize, (i + 1) * chunkSize)
+    ).filter(chunk => chunk.length > 0);
+
+    const allScores = await Promise.all(
+      chunks.map((chunk, i) => contexts[i]!.rankAll(query, chunk))
+    );
+
+    // Reassemble scores in original order and sort
+    const flatScores = allScores.flat();
+    const ranked = texts
+      .map((text, i) => ({ document: text, score: flatScores[i]! }))
+      .sort((a, b) => b.score - a.score);
 
     // Map back to our result format using the text-to-doc map
     const results: RerankDocumentResult[] = ranked.map((item) => {
@@ -984,7 +1016,7 @@ export class LlamaCpp implements LLM {
 
     // Clear references
     this.embedContext = null;
-    this.rerankContext = null;
+    this.rerankContexts = [];
     this.embedModel = null;
     this.generateModel = null;
     this.rerankModel = null;