Explorar el Código

add qmd model pull and refresh logic

Tobi Lütke hace 3 meses
padre
commit
479b68bbf1
Se han modificado 2 ficheros con 142 adiciones y 51 borrados
  1. 120 50
      src/llm.ts
  2. 22 1
      src/qmd.ts

+ 120 - 50
src/llm.ts

@@ -16,7 +16,7 @@ import {
 } from "node-llama-cpp";
 import { homedir } from "os";
 import { join } from "path";
-import { existsSync, mkdirSync } from "fs";
+import { existsSync, mkdirSync, statSync, unlinkSync, readdirSync, readFileSync, writeFileSync } from "fs";
 
 // =============================================================================
 // Embedding Formatting Functions
@@ -178,8 +178,105 @@ const DEFAULT_RERANK_MODEL = "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-re
 // const DEFAULT_GENERATE_MODEL = "hf:ggml-org/Qwen3-0.6B-GGUF/Qwen3-0.6B-Q8_0.gguf";
 const DEFAULT_GENERATE_MODEL = "hf:tobil/qmd-query-expansion-1.7B-gguf/qmd-query-expansion-1.7B-q4_k_m.gguf";
 
+export const DEFAULT_EMBED_MODEL_URI = DEFAULT_EMBED_MODEL;
+export const DEFAULT_RERANK_MODEL_URI = DEFAULT_RERANK_MODEL;
+export const DEFAULT_GENERATE_MODEL_URI = DEFAULT_GENERATE_MODEL;
+
 // Local model cache directory
 const MODEL_CACHE_DIR = join(homedir(), ".cache", "qmd", "models");
+export const DEFAULT_MODEL_CACHE_DIR = MODEL_CACHE_DIR;
+
+export type PullResult = {
+  model: string;
+  path: string;
+  sizeBytes: number;
+  refreshed: boolean;
+};
+
+type HfRef = {
+  repo: string;
+  file: string;
+};
+
+function parseHfUri(model: string): HfRef | null {
+  if (!model.startsWith("hf:")) return null;
+  const without = model.slice(3);
+  const parts = without.split("/");
+  if (parts.length < 3) return null;
+  const repo = parts.slice(0, 2).join("/");
+  const file = parts.slice(2).join("/");
+  return { repo, file };
+}
+
+async function getRemoteEtag(ref: HfRef): Promise<string | null> {
+  const url = `https://huggingface.co/${ref.repo}/resolve/main/${ref.file}`;
+  try {
+    const resp = await fetch(url, { method: "HEAD" });
+    if (!resp.ok) return null;
+    const etag = resp.headers.get("etag");
+    return etag || null;
+  } catch {
+    return null;
+  }
+}
+
+export async function pullModels(
+  models: string[],
+  options: { refresh?: boolean; cacheDir?: string } = {}
+): Promise<PullResult[]> {
+  const cacheDir = options.cacheDir || MODEL_CACHE_DIR;
+  if (!existsSync(cacheDir)) {
+    mkdirSync(cacheDir, { recursive: true });
+  }
+
+  const results: PullResult[] = [];
+  for (const model of models) {
+    let refreshed = false;
+    const hfRef = parseHfUri(model);
+    const filename = model.split("/").pop();
+    const entries = readdirSync(cacheDir, { withFileTypes: true });
+    const cached = filename
+      ? entries
+          .filter((entry) => entry.isFile() && entry.name.includes(filename))
+          .map((entry) => join(cacheDir, entry.name))
+      : [];
+
+    if (hfRef && filename) {
+      const etagPath = join(cacheDir, `${filename}.etag`);
+      const remoteEtag = await getRemoteEtag(hfRef);
+      const localEtag = existsSync(etagPath)
+        ? readFileSync(etagPath, "utf-8").trim()
+        : null;
+      const shouldRefresh =
+        options.refresh || !remoteEtag || remoteEtag !== localEtag || cached.length === 0;
+
+      if (shouldRefresh) {
+        for (const candidate of cached) {
+          if (existsSync(candidate)) unlinkSync(candidate);
+        }
+        if (existsSync(etagPath)) unlinkSync(etagPath);
+        refreshed = cached.length > 0;
+      }
+    } else if (options.refresh && filename) {
+      for (const candidate of cached) {
+        if (existsSync(candidate)) unlinkSync(candidate);
+        refreshed = true;
+      }
+    }
+
+    const path = await resolveModelFile(model, cacheDir);
+    const sizeBytes = existsSync(path) ? statSync(path).size : 0;
+    if (hfRef && filename) {
+      const remoteEtag = await getRemoteEtag(hfRef);
+      if (remoteEtag) {
+        const etagPath = join(cacheDir, `${filename}.etag`);
+        writeFileSync(etagPath, remoteEtag + "\n", "utf-8");
+      }
+    }
+    results.push({ model, path, sizeBytes, refreshed });
+  }
+  return results;
+}
 
 // =============================================================================
 // LLM Interface
@@ -706,48 +803,7 @@ export class LlamaCpp implements LLM {
       `
     });
 
-    const prompt = `You are a search query optimization expert. Your task is to improve retrieval by rewriting queries and generating hypothetical documents.
-
-Original Query: ${query}
-
-${context ? `Additional Context, ONLY USE IF RELEVANT:\n\n<context>${context}</context>` : ""}
-
-## Step 1: Query Analysis
-Identify entities, search intent, and missing context.
-
-## Step 2: Generate Hypothetical Document
-Write a focused sentence passage that would answer the query. Include specific terminology and domain vocabulary.
-
-## Step 3: Query Rewrites
-Generate 2-3 alternative search queries that resolve ambiguities. Use terminology from the hypothetical document.
-
-## Step 4: Final Retrieval Text
-Output MAX ONE 'hyde' line FIRST, then 1-3 'lex' lines, then 1-3 'vec' lines.
-
-<format>
-hyde: {complete hypothetical document passage from Step 2 on a SINGLE LINE}
-lex: {single search term}
-vec: {single vector query}
-</format>
-
-<example>
-Example (FOR FORMAT ONLY - DO NOT COPY THIS CONTENT):
-hyde: This is an example of a hypothetical document passage that would answer the example query. It contains multiple sentences and relevant vocabulary.
-lex: example keyword 1
-lex: example keyword 2
-vec: example semantic query
-</example>
-
-<rules>
-- DO NOT repeat the same line.
-- Each 'lex:' line MUST be a different keyword variation based on the ORIGINAL QUERY.
-- Each 'vec:' line MUST be a different semantic variation based on the ORIGINAL QUERY.
-- The 'hyde:' line MUST be the full sentence passage from Step 2, but all on one line.
-- DO NOT use the example content above.
-${!includeLexical ? "- Do NOT output any 'lex:' lines" : ""}
-</rules>
-
-Final Output:`;
+    const prompt = `/no_think Expand this search query: ${query}`;
 
     // Create fresh context for each call
     const genContext = await this.generateModel!.createContext();
@@ -757,25 +813,40 @@ Final Output:`;
     try {
       const result = await session.prompt(prompt, {
         grammar,
-        maxTokens: 1000,
-        temperature: 1,
+        maxTokens: 600,
+        temperature: 0.1,
       });
 
       const lines = result.trim().split("\n");
+      const queryLower = query.toLowerCase();
+      const queryTerms = queryLower.replace(/[^a-z0-9\s]/g, " ").split(/\s+/).filter(Boolean);
+
+      const hasQueryTerm = (text: string): boolean => {
+        const lower = text.toLowerCase();
+        if (queryTerms.length === 0) return true;
+        return queryTerms.some(term => lower.includes(term));
+      };
+
       const queryables: Queryable[] = lines.map(line => {
         const colonIdx = line.indexOf(":");
         if (colonIdx === -1) return null;
         const type = line.slice(0, colonIdx).trim();
         if (type !== 'lex' && type !== 'vec' && type !== 'hyde') return null;
         const text = line.slice(colonIdx + 1).trim();
+        if (!hasQueryTerm(text)) return null;
         return { type: type as QueryType, text };
       }).filter((q): q is Queryable => q !== null);
 
       // Filter out lex entries if not requested
-      if (!includeLexical) {
-        return queryables.filter(q => q.type !== 'lex');
-      }
-      return queryables;
+      const filtered = includeLexical ? queryables : queryables.filter(q => q.type !== 'lex');
+      if (filtered.length > 0) return filtered;
+
+      const fallback: Queryable[] = [
+        { type: 'hyde', text: `Information about ${query}` },
+        { type: 'lex', text: query },
+        { type: 'vec', text: query },
+      ];
+      return includeLexical ? fallback : fallback.filter(q => q.type !== 'lex');
     } catch (error) {
       console.error("Structured query expansion failed:", error);
       // Fallback to original query
@@ -1122,4 +1193,3 @@ export async function disposeDefaultLlamaCpp(): Promise<void> {
     defaultLlamaCpp = null;
   }
 }
-

+ 22 - 1
src/qmd.ts

@@ -65,7 +65,7 @@ import {
   createStore,
   getDefaultDbPath,
 } from "./store.js";
-import { getDefaultLlamaCpp, disposeDefaultLlamaCpp, withLLMSession, type ILLMSession, type RerankDocument, type Queryable, type QueryType } from "./llm.js";
+import { getDefaultLlamaCpp, disposeDefaultLlamaCpp, withLLMSession, pullModels, DEFAULT_EMBED_MODEL_URI, DEFAULT_GENERATE_MODEL_URI, DEFAULT_RERANK_MODEL_URI, DEFAULT_MODEL_CACHE_DIR, type ILLMSession, type RerankDocument, type Queryable, type QueryType } from "./llm.js";
 import type { SearchResult, RankedResult } from "./store.js";
 import {
   formatSearchResults,
@@ -2316,6 +2316,7 @@ function parseCLI() {
       force: { type: "boolean", short: "f" },
       // Update options
       pull: { type: "boolean" },  // git pull before update
+      refresh: { type: "boolean" },
       // Get options
       l: { type: "string" },  // max lines
       from: { type: "string" },  // start line
@@ -2595,6 +2596,26 @@ if (import.meta.main) {
       await vectorIndex(DEFAULT_EMBED_MODEL, !!cli.values.force);
       break;
 
+    case "pull": {
+      const refresh = cli.values.refresh === undefined ? false : Boolean(cli.values.refresh);
+      const models = [
+        DEFAULT_EMBED_MODEL_URI,
+        DEFAULT_GENERATE_MODEL_URI,
+        DEFAULT_RERANK_MODEL_URI,
+      ];
+      console.log(`${c.bold}Pulling models${c.reset}`);
+      const results = await pullModels(models, {
+        refresh,
+        cacheDir: DEFAULT_MODEL_CACHE_DIR,
+      });
+      for (const result of results) {
+        const size = formatBytes(result.sizeBytes);
+        const note = result.refreshed ? "refreshed" : "cached/checked";
+        console.log(`- ${result.model} -> ${result.path} (${size}, ${note})`);
+      }
+      break;
+    }
+
     case "search":
       if (!cli.query) {
         console.error("Usage: qmd search [options] <query>");