Parcourir la source

Merge pull request #502 from JohnRichardEnders/feat/yaml-model-config

feat: support model configuration in index.yml
Tobias Lütke il y a 1 mois
Parent
commit
54fc7b01a9
8 fichiers modifiés avec 100 ajouts et 16 suppressions
  1. 4 0
      CHANGELOG.md
  2. 8 1
      src/cli/qmd.ts
  3. 10 0
      src/collections.ts
  4. 7 2
      src/index.ts
  5. 9 6
      src/llm.ts
  6. 12 2
      src/mcp/server.ts
  7. 6 5
      src/store.ts
  8. 44 0
      test/llm.test.ts

+ 4 - 0
CHANGELOG.md

@@ -18,6 +18,10 @@
   Measures precision@k, recall, MRR, and F1 across BM25, vector, hybrid,
   and full pipeline backends. Ships with an example fixture against
   the eval-docs test collection.
+- `models:` section in `index.yml` lets you configure `embed`, `rerank`,
+  and `generate` model URIs per config. Resolution order is
+  config > env var (`QMD_EMBED_MODEL`, `QMD_RERANK_MODEL`,
+  `QMD_GENERATE_MODEL`) > built-in default.
 - CLI search output now emits clickable OSC 8 terminal hyperlinks when
   stdout is a TTY. Links resolve `qmd://` paths to absolute filesystem
   paths and open in editors via URI templates (default:

+ 8 - 1
src/cli/qmd.ts

@@ -77,7 +77,7 @@ import {
   type ReindexResult,
   type ChunkStrategy,
 } from "../store.js";
-import { disposeDefaultLlamaCpp, getDefaultLlamaCpp, withLLMSession, pullModels, DEFAULT_EMBED_MODEL_URI, DEFAULT_GENERATE_MODEL_URI, DEFAULT_RERANK_MODEL_URI, DEFAULT_MODEL_CACHE_DIR } from "../llm.js";
+import { disposeDefaultLlamaCpp, getDefaultLlamaCpp, setDefaultLlamaCpp, LlamaCpp, withLLMSession, pullModels, DEFAULT_EMBED_MODEL_URI, DEFAULT_GENERATE_MODEL_URI, DEFAULT_RERANK_MODEL_URI, DEFAULT_MODEL_CACHE_DIR } from "../llm.js";
 import {
   formatSearchResults,
   formatDocuments,
@@ -118,6 +118,13 @@ function getStore(): ReturnType<typeof createStore> {
     try {
       const config = loadConfig();
       syncConfigToDb(store.db, config);
+      if (config.models) {
+        setDefaultLlamaCpp(new LlamaCpp({
+          embedModel: config.models.embed,
+          generateModel: config.models.generate,
+          rerankModel: config.models.rerank,
+        }));
+      }
     } catch {
       // Config may not exist yet — that's fine, DB works without it
     }

+ 10 - 0
src/collections.ts

@@ -33,6 +33,15 @@ export interface Collection {
   includeByDefault?: boolean; // Include in queries by default (default: true)
 }
 
+/**
+ * Model configuration for embedding, reranking, and generation
+ */
+export interface ModelsConfig {
+  embed?: string;
+  rerank?: string;
+  generate?: string;
+}
+
 /**
  * The complete configuration file structure
  */
@@ -41,6 +50,7 @@ export interface CollectionConfig {
   editor_uri?: string;                        // Editor URI template for terminal hyperlinks
   editor_uri_template?: string;               // Alias for editor_uri
   collections: Record<string, Collection>;    // Collection name -> config
+  models?: ModelsConfig;
 }
 
 /**

+ 7 - 2
src/index.ts

@@ -351,21 +351,26 @@ export async function createStore(options: StoreOptions): Promise<QMDStore> {
   const hasYamlConfig = !!options.configPath;
 
   // Sync config into SQLite store_collections
+  let config: CollectionConfig | undefined;
   if (options.configPath) {
     // YAML mode: inject config source for write-through, sync to DB
     setConfigSource({ configPath: options.configPath });
-    const config = loadConfig();
+    config = loadConfig();
     syncConfigToDb(db, config);
   } else if (options.config) {
     // Inline config mode: inject config source for mutations, sync to DB
     setConfigSource({ config: options.config });
-    syncConfigToDb(db, options.config);
+    config = options.config;
+    syncConfigToDb(db, config);
   }
   // else: DB-only mode — no external config, use existing store_collections
 
   // Create a per-store LlamaCpp instance — lazy-loads models on first use,
   // auto-unloads after 5 min inactivity to free VRAM.
   const llm = new LlamaCpp({
+    embedModel: config?.models?.embed,
+    generateModel: config?.models?.generate,
+    rerankModel: config?.models?.rerank,
     inactivityTimeoutMs: 5 * 60 * 1000,
     disposeModelsOnInactivity: true,
   });

+ 9 - 6
src/llm.ts

@@ -193,7 +193,7 @@ export type RerankDocument = {
 // HuggingFace model URIs for node-llama-cpp
 // Format: hf:<user>/<repo>/<file>
 // Override via QMD_EMBED_MODEL env var (e.g. hf:Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf)
-const DEFAULT_EMBED_MODEL = process.env.QMD_EMBED_MODEL ?? "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf";
+const DEFAULT_EMBED_MODEL = "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf";
 const DEFAULT_RERANK_MODEL = "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf";
 // const DEFAULT_GENERATE_MODEL = "hf:ggml-org/Qwen3-0.6B-GGUF/Qwen3-0.6B-Q8_0.gguf";
 const DEFAULT_GENERATE_MODEL = "hf:tobil/qmd-query-expansion-1.7B-gguf/qmd-query-expansion-1.7B-q4_k_m.gguf";
@@ -436,15 +436,19 @@ export class LlamaCpp implements LLM {
 
 
   constructor(config: LlamaCppConfig = {}) {
-    this.embedModelUri = config.embedModel || DEFAULT_EMBED_MODEL;
-    this.generateModelUri = config.generateModel || DEFAULT_GENERATE_MODEL;
-    this.rerankModelUri = config.rerankModel || DEFAULT_RERANK_MODEL;
+    this.embedModelUri = config.embedModel || process.env.QMD_EMBED_MODEL || DEFAULT_EMBED_MODEL;
+    this.generateModelUri = config.generateModel || process.env.QMD_GENERATE_MODEL || DEFAULT_GENERATE_MODEL;
+    this.rerankModelUri = config.rerankModel || process.env.QMD_RERANK_MODEL || DEFAULT_RERANK_MODEL;
     this.modelCacheDir = config.modelCacheDir || MODEL_CACHE_DIR;
     this.expandContextSize = resolveExpandContextSize(config.expandContextSize);
     this.inactivityTimeoutMs = config.inactivityTimeoutMs ?? DEFAULT_INACTIVITY_TIMEOUT_MS;
     this.disposeModelsOnInactivity = config.disposeModelsOnInactivity ?? false;
   }
 
+  get embedModelName(): string {
+    return this.embedModelUri;
+  }
+
   /**
    * Reset the inactivity timer. Called after each model operation.
    * When timer fires, models are unloaded to free memory (if no active sessions).
@@ -1559,8 +1563,7 @@ let defaultLlamaCpp: LlamaCpp | null = null;
  */
 export function getDefaultLlamaCpp(): LlamaCpp {
   if (!defaultLlamaCpp) {
-    const embedModel = process.env.QMD_EMBED_MODEL;
-    defaultLlamaCpp = new LlamaCpp(embedModel ? { embedModel } : {});
+    defaultLlamaCpp = new LlamaCpp();
   }
   return defaultLlamaCpp;
 }

+ 12 - 2
src/mcp/server.ts

@@ -18,6 +18,7 @@ import { WebStandardStreamableHTTPServerTransport }
   from "@modelcontextprotocol/sdk/server/webStandardStreamableHttp.js";
 import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js";
 import { z } from "zod";
+import { existsSync } from "fs";
 import {
   createStore,
   extractSnippet,
@@ -28,6 +29,7 @@ import {
   type ExpandedQuery,
   type IndexStatus,
 } from "../index.js";
+import { getConfigPath } from "../collections.js";
 
 // =============================================================================
 // Types for structured content
@@ -536,7 +538,11 @@ Intent-aware lex (C++ performance, not sports):
 // =============================================================================
 
 export async function startMcpServer(): Promise<void> {
-  const store = await createStore({ dbPath: getDefaultDbPath() });
+  const configPath = getConfigPath();
+  const store = await createStore({
+    dbPath: getDefaultDbPath(),
+    ...(existsSync(configPath) ? { configPath } : {}),
+  });
   const server = await createMcpServer(store);
   const transport = new StdioServerTransport();
   await server.connect(transport);
@@ -557,7 +563,11 @@ export type HttpServerHandle = {
  * Binds to localhost only. Returns a handle for shutdown and port discovery.
  */
 export async function startMcpHttpServer(port: number, options?: { quiet?: boolean }): Promise<HttpServerHandle> {
-  const store = await createStore({ dbPath: getDefaultDbPath() });
+  const configPath = getConfigPath();
+  const store = await createStore({
+    dbPath: getDefaultDbPath(),
+    ...(existsSync(configPath) ? { configPath } : {}),
+  });
 
   // Pre-fetch default collection names for REST endpoint
   const defaultCollectionNames = await store.getDefaultCollectionNames();

+ 6 - 5
src/store.ts

@@ -1414,6 +1414,7 @@ export async function generateEmbeddings(
 
   // Use store's LlamaCpp or global singleton, wrapped in a session
   const llm = getLlm(store);
+  const embedModelUri = llm.embedModelName;
 
   // Create a session manager for this llm instance
   const result = await withLLMSessionForLlm(llm, async (session) => {
@@ -1471,7 +1472,7 @@ export async function generateEmbeddings(
 
       if (!vectorTableInitialized) {
         const firstChunk = batchChunks[0]!;
-        const firstText = formatDocForEmbedding(firstChunk.text, firstChunk.title, model);
+        const firstText = formatDocForEmbedding(firstChunk.text, firstChunk.title, embedModelUri);
         const firstResult = await session.embed(firstText, { model });
         if (!firstResult) {
           throw new Error("Failed to get embedding dimensions from first chunk");
@@ -1503,7 +1504,7 @@ export async function generateEmbeddings(
 
         const batchEnd = Math.min(batchStart + BATCH_SIZE, batchChunks.length);
         const chunkBatch = batchChunks.slice(batchStart, batchEnd);
-        const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title, model));
+        const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title, embedModelUri));
 
         try {
           const embeddings = await session.embedBatch(texts, { model });
@@ -1527,7 +1528,7 @@ export async function generateEmbeddings(
           } else {
             for (const chunk of chunkBatch) {
               try {
-                const text = formatDocForEmbedding(chunk.text, chunk.title, model);
+                const text = formatDocForEmbedding(chunk.text, chunk.title, embedModelUri);
                 const result = await session.embed(text, { model });
                 if (result) {
                   insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(result.embedding), model, now);
@@ -3985,7 +3986,7 @@ export async function hybridQuery(
 
     // Batch embed all vector queries in a single call
     const llm = getLlm(store);
-    const textsToEmbed = vecQueries.map(q => formatQueryForEmbedding(q.text));
+    const textsToEmbed = vecQueries.map(q => formatQueryForEmbedding(q.text, llm.embedModelName));
     hooks?.onEmbedStart?.(textsToEmbed.length);
     const embedStart = Date.now();
     const embeddings = await llm.embedBatch(textsToEmbed);
@@ -4368,7 +4369,7 @@ export async function structuredSearch(
     );
     if (vecSearches.length > 0) {
       const llm = getLlm(store);
-      const textsToEmbed = vecSearches.map(s => formatQueryForEmbedding(s.query));
+      const textsToEmbed = vecSearches.map(s => formatQueryForEmbedding(s.query, llm.embedModelName));
       hooks?.onEmbedStart?.(textsToEmbed.length);
       const embedStart = Date.now();
       const embeddings = await llm.embedBatch(textsToEmbed);

+ 44 - 0
test/llm.test.ts

@@ -117,6 +117,50 @@ describe("LlamaCpp expand context size config", () => {
   });
 });
 
+describe("LlamaCpp model resolution (config > env > default)", () => {
+  const HARDCODED_EMBED = "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf";
+  const HARDCODED_RERANK = "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf";
+  const HARDCODED_GENERATE = "hf:tobil/qmd-query-expansion-1.7B-gguf/qmd-query-expansion-1.7B-q4_k_m.gguf";
+
+  test("uses hardcoded default when no config or env is set", () => {
+    const prev = process.env.QMD_EMBED_MODEL;
+    delete process.env.QMD_EMBED_MODEL;
+    try {
+      const llm = new LlamaCpp({}) as any;
+      expect(llm.embedModelUri).toBe(HARDCODED_EMBED);
+      expect(llm.rerankModelUri).toBe(HARDCODED_RERANK);
+      expect(llm.generateModelUri).toBe(HARDCODED_GENERATE);
+    } finally {
+      if (prev === undefined) delete process.env.QMD_EMBED_MODEL;
+      else process.env.QMD_EMBED_MODEL = prev;
+    }
+  });
+
+  test("env var overrides hardcoded default", () => {
+    const prev = process.env.QMD_EMBED_MODEL;
+    process.env.QMD_EMBED_MODEL = "hf:custom/embed-model.gguf";
+    try {
+      const llm = new LlamaCpp({}) as any;
+      expect(llm.embedModelUri).toBe("hf:custom/embed-model.gguf");
+    } finally {
+      if (prev === undefined) delete process.env.QMD_EMBED_MODEL;
+      else process.env.QMD_EMBED_MODEL = prev;
+    }
+  });
+
+  test("config overrides env var", () => {
+    const prev = process.env.QMD_EMBED_MODEL;
+    process.env.QMD_EMBED_MODEL = "hf:env/model.gguf";
+    try {
+      const llm = new LlamaCpp({ embedModel: "hf:config/model.gguf" }) as any;
+      expect(llm.embedModelUri).toBe("hf:config/model.gguf");
+    } finally {
+      if (prev === undefined) delete process.env.QMD_EMBED_MODEL;
+      else process.env.QMD_EMBED_MODEL = prev;
+    }
+  });
+});
+
 describe("LlamaCpp rerank deduping", () => {
   test("deduplicates identical document texts before scoring", async () => {
     const llm = new LlamaCpp({}) as any;