llm.test.ts 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559
  1. /**
  2. * llm.test.ts - Unit tests for the LLM abstraction layer (node-llama-cpp)
  3. *
  4. * Run with: bun test src/llm.test.ts
  5. *
  6. * These tests require the actual models to be downloaded. Run the embed or
  7. * rerank functions first to trigger model downloads.
  8. */
  9. import { describe, test, expect, beforeAll, afterAll } from "bun:test";
  10. import {
  11. LlamaCpp,
  12. getDefaultLlamaCpp,
  13. disposeDefaultLlamaCpp,
  14. withLLMSession,
  15. canUnloadLLM,
  16. SessionReleasedError,
  17. type RerankDocument,
  18. type ILLMSession,
  19. } from "./llm.js";
  20. // =============================================================================
  21. // Singleton Tests (no model loading required)
  22. // =============================================================================
  23. describe("Default LlamaCpp Singleton", () => {
  24. // Test singleton behavior without resetting to avoid orphan instances
  25. test("getDefaultLlamaCpp returns same instance on subsequent calls", () => {
  26. const llm1 = getDefaultLlamaCpp();
  27. const llm2 = getDefaultLlamaCpp();
  28. expect(llm1).toBe(llm2);
  29. expect(llm1).toBeInstanceOf(LlamaCpp);
  30. });
  31. });
  32. // =============================================================================
  33. // Model Existence Tests
  34. // =============================================================================
  35. describe("LlamaCpp.modelExists", () => {
  36. test("returns exists:true for HuggingFace model URIs", async () => {
  37. const llm = getDefaultLlamaCpp();
  38. const result = await llm.modelExists("hf:org/repo/model.gguf");
  39. expect(result.exists).toBe(true);
  40. expect(result.name).toBe("hf:org/repo/model.gguf");
  41. });
  42. test("returns exists:false for non-existent local paths", async () => {
  43. const llm = getDefaultLlamaCpp();
  44. const result = await llm.modelExists("/nonexistent/path/model.gguf");
  45. expect(result.exists).toBe(false);
  46. expect(result.name).toBe("/nonexistent/path/model.gguf");
  47. });
  48. });
  49. // =============================================================================
  50. // Integration Tests (require actual models)
  51. // =============================================================================
  52. describe("LlamaCpp Integration", () => {
  53. // Use the singleton to avoid multiple Metal contexts
  54. const llm = getDefaultLlamaCpp();
  // Release native resources once the suite has finished; leaving them alive can
  // trigger ggml-metal asserts on process exit.
  afterAll(() => disposeDefaultLlamaCpp());
  59. describe("embed", () => {
  60. test("returns embedding with correct dimensions", async () => {
  61. const result = await llm.embed("Hello world");
  62. expect(result).not.toBeNull();
  63. expect(result!.embedding).toBeInstanceOf(Array);
  64. expect(result!.embedding.length).toBeGreaterThan(0);
  65. // embeddinggemma outputs 768 dimensions
  66. expect(result!.embedding.length).toBe(768);
  67. });
  68. test("returns consistent embeddings for same input", async () => {
  69. const result1 = await llm.embed("test text");
  70. const result2 = await llm.embed("test text");
  71. expect(result1).not.toBeNull();
  72. expect(result2).not.toBeNull();
  73. // Embeddings should be identical for the same input
  74. for (let i = 0; i < result1!.embedding.length; i++) {
  75. expect(result1!.embedding[i]).toBeCloseTo(result2!.embedding[i]!, 5);
  76. }
  77. });
  78. test("returns different embeddings for different inputs", async () => {
  79. const result1 = await llm.embed("cats are great");
  80. const result2 = await llm.embed("database optimization");
  81. expect(result1).not.toBeNull();
  82. expect(result2).not.toBeNull();
  83. // Calculate cosine similarity - should be less than 1.0 (not identical)
  84. let dotProduct = 0;
  85. let norm1 = 0;
  86. let norm2 = 0;
  87. for (let i = 0; i < result1!.embedding.length; i++) {
  88. const v1 = result1!.embedding[i]!;
  89. const v2 = result2!.embedding[i]!;
  90. dotProduct += v1 * v2;
  91. norm1 += v1 ** 2;
  92. norm2 += v2 ** 2;
  93. }
  94. const similarity = dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
  95. expect(similarity).toBeLessThan(0.95); // Should be meaningfully different
  96. });
  97. });
  98. describe("embedBatch", () => {
  99. test("returns embeddings for multiple texts", async () => {
  100. const texts = ["Hello world", "Test text", "Another document"];
  101. const results = await llm.embedBatch(texts);
  102. expect(results).toHaveLength(3);
  103. for (const result of results) {
  104. expect(result).not.toBeNull();
  105. expect(result!.embedding.length).toBe(768);
  106. }
  107. });
  108. test("returns same results as individual embed calls", async () => {
  109. const texts = ["cats are great", "dogs are awesome"];
  110. // Get batch embeddings
  111. const batchResults = await llm.embedBatch(texts);
  112. // Get individual embeddings
  113. const individualResults = await Promise.all(texts.map(t => llm.embed(t)));
  114. // Compare - should be identical
  115. for (let i = 0; i < texts.length; i++) {
  116. expect(batchResults[i]).not.toBeNull();
  117. expect(individualResults[i]).not.toBeNull();
  118. for (let j = 0; j < batchResults[i]!.embedding.length; j++) {
  119. expect(batchResults[i]!.embedding[j]).toBeCloseTo(individualResults[i]!.embedding[j]!, 5);
  120. }
  121. }
  122. });
  123. test("handles empty array", async () => {
  124. const results = await llm.embedBatch([]);
  125. expect(results).toHaveLength(0);
  126. });
  127. test("batch is faster than sequential", async () => {
  128. const texts = Array(10).fill(null).map((_, i) => `Document number ${i} with content`);
  129. // Time batch
  130. const batchStart = Date.now();
  131. await llm.embedBatch(texts);
  132. const batchTime = Date.now() - batchStart;
  133. // Time sequential
  134. const seqStart = Date.now();
  135. for (const text of texts) {
  136. await llm.embed(text);
  137. }
  138. const seqTime = Date.now() - seqStart;
  139. console.log(`Batch: ${batchTime}ms, Sequential: ${seqTime}ms`);
  140. // Performance is machine/load dependent. We only assert batch isn't drastically worse.
  141. expect(batchTime).toBeLessThanOrEqual(seqTime * 3);
  142. });
  143. test("handles concurrent embedBatch calls on fresh instance without race condition", async () => {
  144. // This test verifies the fix for a race condition where concurrent calls to
  145. // ensureEmbedContext() could create multiple contexts. Without the promise guard,
  146. // each concurrent embedBatch call sees embedContext === null and creates its own
  147. // context, causing resource leaks and potential "Context is disposed" errors.
  148. //
  149. // See: https://github.com/tobi/qmd/pull/54
  150. //
  151. // The fix uses a promise guard to ensure only one context creation runs at a time.
  152. // We verify this by instrumenting createEmbeddingContext to count invocations.
  153. const freshLlm = new LlamaCpp({});
  154. let contextCreateCount = 0;
  155. // Instrument the model's createEmbeddingContext to count calls
  156. const originalEnsureEmbedModel = (freshLlm as any).ensureEmbedModel.bind(freshLlm);
  157. let modelInstrumented = false;
  158. (freshLlm as any).ensureEmbedModel = async function() {
  159. const model = await originalEnsureEmbedModel();
  160. if (!modelInstrumented) {
  161. modelInstrumented = true;
  162. const originalCreate = model.createEmbeddingContext.bind(model);
  163. model.createEmbeddingContext = async function(...args: any[]) {
  164. contextCreateCount++;
  165. return originalCreate(...args);
  166. };
  167. }
  168. return model;
  169. };
  170. const texts = Array(10).fill(null).map((_, i) => `Document ${i}`);
  171. // Call embedBatch 5 TIMES in parallel on fresh instance.
  172. // Without the promise guard fix, this would create 5 contexts (one per call).
  173. // With the fix, only 1 context should be created.
  174. const batches = await Promise.all([
  175. freshLlm.embedBatch(texts.slice(0, 2)),
  176. freshLlm.embedBatch(texts.slice(2, 4)),
  177. freshLlm.embedBatch(texts.slice(4, 6)),
  178. freshLlm.embedBatch(texts.slice(6, 8)),
  179. freshLlm.embedBatch(texts.slice(8, 10)),
  180. ]);
  181. const allResults = batches.flat();
  182. expect(allResults).toHaveLength(10);
  183. const successCount = allResults.filter(r => r !== null).length;
  184. expect(successCount).toBe(10);
  185. // THE KEY ASSERTION: Only 1 context should be created, not 5
  186. // Without the fix, contextCreateCount would be 5 (one per concurrent embedBatch call)
  187. console.log(`Context creation count: ${contextCreateCount} (expected: 1)`);
  188. expect(contextCreateCount).toBe(1);
  189. await freshLlm.dispose();
  190. }, 60000);
  191. });
  192. describe("rerank", () => {
  193. test("scores capital of France question correctly", async () => {
  194. const query = "What is the capital of France?";
  195. const documents: RerankDocument[] = [
  196. { file: "butterflies.txt", text: "Butterflies indeed fly through the garden." },
  197. { file: "france.txt", text: "The capital of France is Paris." },
  198. { file: "canada.txt", text: "The capital of Canada is Ottawa." },
  199. ];
  200. const result = await llm.rerank(query, documents);
  201. expect(result.results).toHaveLength(3);
  202. // The France document should score highest
  203. expect(result.results[0]!.file).toBe("france.txt");
  204. expect(result.results[0]!.score).toBeGreaterThan(0.7);
  205. // Canada should be somewhat relevant (also about capitals)
  206. expect(result.results[1]!.file).toBe("canada.txt");
  207. // Butterflies should score lowest
  208. expect(result.results[2]!.file).toBe("butterflies.txt");
  209. expect(result.results[2]!.score).toBeLessThan(0.6);
  210. });
  211. test("scores authentication query correctly", async () => {
  212. const query = "How do I configure authentication?";
  213. const documents: RerankDocument[] = [
  214. { file: "weather.md", text: "The weather today is sunny with mild temperatures." },
  215. { file: "auth.md", text: "Authentication can be configured by setting the AUTH_SECRET environment variable." },
  216. { file: "pizza.md", text: "Our restaurant serves the best pizza in town." },
  217. { file: "jwt.md", text: "JWT authentication requires a secret key and expiration time." },
  218. ];
  219. const result = await llm.rerank(query, documents);
  220. expect(result.results).toHaveLength(4);
  221. // Auth documents should score highest
  222. const topTwo = result.results.slice(0, 2).map((r) => r.file);
  223. expect(topTwo).toContain("auth.md");
  224. expect(topTwo).toContain("jwt.md");
  225. // Irrelevant documents should score lowest
  226. const bottomTwo = result.results.slice(2).map((r) => r.file);
  227. expect(bottomTwo).toContain("weather.md");
  228. expect(bottomTwo).toContain("pizza.md");
  229. });
  230. test("handles programming queries correctly", async () => {
  231. const query = "How do I handle errors in JavaScript?";
  232. const documents: RerankDocument[] = [
  233. { file: "cooking.md", text: "To make a good pasta, boil water and add salt." },
  234. { file: "errors.md", text: "Use try-catch blocks to handle JavaScript errors gracefully." },
  235. { file: "python.md", text: "Python uses try-except for exception handling." },
  236. ];
  237. const result = await llm.rerank(query, documents);
  238. // JavaScript errors doc should score highest
  239. expect(result.results[0]!.file).toBe("errors.md");
  240. expect(result.results[0]!.score).toBeGreaterThan(0.7);
  241. // Python doc might be somewhat relevant (same concept, different language)
  242. // Cooking should be least relevant
  243. expect(result.results[2]!.file).toBe("cooking.md");
  244. });
  245. test("handles empty document list", async () => {
  246. const result = await llm.rerank("test query", []);
  247. expect(result.results).toHaveLength(0);
  248. });
  249. test("handles single document", async () => {
  250. const result = await llm.rerank("test", [{ file: "doc.md", text: "content" }]);
  251. expect(result.results).toHaveLength(1);
  252. expect(result.results[0]!.file).toBe("doc.md");
  253. });
  254. test("preserves original file paths", async () => {
  255. const documents: RerankDocument[] = [
  256. { file: "path/to/doc1.md", text: "content one" },
  257. { file: "another/path/doc2.md", text: "content two" },
  258. ];
  259. const result = await llm.rerank("query", documents);
  260. const files = result.results.map((r) => r.file).sort();
  261. expect(files).toEqual(["another/path/doc2.md", "path/to/doc1.md"]);
  262. });
  263. test("returns scores between 0 and 1", async () => {
  264. const documents: RerankDocument[] = [
  265. { file: "a.md", text: "The quick brown fox jumps over the lazy dog." },
  266. { file: "b.md", text: "Machine learning algorithms process data efficiently." },
  267. { file: "c.md", text: "React components use JSX syntax for rendering." },
  268. ];
  269. const result = await llm.rerank("Tell me about animals", documents);
  270. for (const doc of result.results) {
  271. expect(doc.score).toBeGreaterThanOrEqual(0);
  272. expect(doc.score).toBeLessThanOrEqual(1);
  273. }
  274. });
  275. test("batch reranks multiple documents efficiently", async () => {
  276. // Create 10 documents to verify batch processing works
  277. const documents: RerankDocument[] = Array(10)
  278. .fill(null)
  279. .map((_, i) => ({
  280. file: `doc${i}.md`,
  281. text: `Document number ${i} with some content about topic ${i % 3}`,
  282. }));
  283. const start = Date.now();
  284. const result = await llm.rerank("topic 1", documents);
  285. const elapsed = Date.now() - start;
  286. expect(result.results).toHaveLength(10);
  287. // Verify all documents are returned with valid scores
  288. for (const doc of result.results) {
  289. expect(doc.score).toBeGreaterThanOrEqual(0);
  290. expect(doc.score).toBeLessThanOrEqual(1);
  291. }
  292. // Log timing for monitoring batch performance
  293. console.log(`Batch rerank of 10 docs took ${elapsed}ms`);
  294. });
  295. });
  describe("expandQuery", () => {
    test("returns query expansions with correct types", async () => {
      const expansions = await llm.expandQuery("test query");

      // At least one expansion is expected, each tagged lex, vec or hyde and non-empty.
      expect(expansions.length).toBeGreaterThanOrEqual(1);
      const allowedTypes = ["lex", "vec", "hyde"];
      expansions.forEach((expansion) => {
        expect(allowedTypes).toContain(expansion.type);
        expect(expansion.text.length).toBeGreaterThan(0);
      });
    }, 30000); // 30s timeout for model loading

    test("can exclude lexical queries", async () => {
      const expansions = await llm.expandQuery("authentication setup", { includeLexical: false });

      // With includeLexical disabled, no entry of type 'lex' may be returned.
      const lexicalOnly = expansions.filter((expansion) => expansion.type === "lex");
      expect(lexicalOnly).toHaveLength(0);
    });
  });
  314. });
  315. // =============================================================================
  316. // Session Management Tests
  317. // =============================================================================
  318. describe("LLM Session Management", () => {
  319. describe("withLLMSession", () => {
  320. test("session provides access to LLM operations", async () => {
  321. const result = await withLLMSession(async (session) => {
  322. expect(session.isValid).toBe(true);
  323. const embedding = await session.embed("test text");
  324. expect(embedding).not.toBeNull();
  325. expect(embedding!.embedding.length).toBe(768);
  326. return "success";
  327. });
  328. expect(result).toBe("success");
  329. });
  330. test("session is invalid after release", async () => {
  331. let capturedSession: ILLMSession | null = null;
  332. await withLLMSession(async (session) => {
  333. capturedSession = session;
  334. expect(session.isValid).toBe(true);
  335. });
  336. // Session should be invalid after withLLMSession returns
  337. expect(capturedSession).not.toBeNull();
  338. expect(capturedSession!.isValid).toBe(false);
  339. });
  340. test("session prevents idle unload during operations", async () => {
  341. await withLLMSession(async (session) => {
  342. // While inside a session, canUnloadLLM should return false
  343. expect(canUnloadLLM()).toBe(false);
  344. // Perform an operation
  345. await session.embed("test");
  346. // Still should not be able to unload
  347. expect(canUnloadLLM()).toBe(false);
  348. });
  349. // After session ends, should be able to unload
  350. expect(canUnloadLLM()).toBe(true);
  351. });
  352. test("nested sessions increment ref count", async () => {
  353. await withLLMSession(async (outerSession) => {
  354. expect(canUnloadLLM()).toBe(false);
  355. await withLLMSession(async (innerSession) => {
  356. expect(canUnloadLLM()).toBe(false);
  357. expect(innerSession.isValid).toBe(true);
  358. expect(outerSession.isValid).toBe(true);
  359. });
  360. // Inner session released, but outer still active
  361. expect(canUnloadLLM()).toBe(false);
  362. expect(outerSession.isValid).toBe(true);
  363. });
  364. // All sessions released
  365. expect(canUnloadLLM()).toBe(true);
  366. });
  367. test("session embedBatch works correctly", async () => {
  368. await withLLMSession(async (session) => {
  369. const texts = ["Hello world", "Test text", "Another document"];
  370. const results = await session.embedBatch(texts);
  371. expect(results).toHaveLength(3);
  372. for (const result of results) {
  373. expect(result).not.toBeNull();
  374. expect(result!.embedding.length).toBe(768);
  375. }
  376. });
  377. });
  378. test("session rerank works correctly", async () => {
  379. await withLLMSession(async (session) => {
  380. const documents: RerankDocument[] = [
  381. { file: "a.txt", text: "The capital of France is Paris." },
  382. { file: "b.txt", text: "Dogs are great pets." },
  383. ];
  384. const result = await session.rerank("What is the capital of France?", documents);
  385. expect(result.results).toHaveLength(2);
  386. expect(result.results[0]!.file).toBe("a.txt");
  387. expect(result.results[0]!.score).toBeGreaterThan(result.results[1]!.score);
  388. });
  389. });
  390. test("max duration aborts session after timeout", async () => {
  391. let aborted = false;
  392. try {
  393. await withLLMSession(async (session) => {
  394. // Wait longer than max duration
  395. await new Promise(resolve => setTimeout(resolve, 150));
  396. // This operation should throw because session was aborted
  397. await session.embed("test");
  398. }, { maxDuration: 50 }); // 50ms max
  399. } catch (err) {
  400. if (err instanceof SessionReleasedError) {
  401. aborted = true;
  402. } else {
  403. throw err;
  404. }
  405. }
  406. expect(aborted).toBe(true);
  407. }, 5000);
  408. test("external abort signal propagates to session", async () => {
  409. const abortController = new AbortController();
  410. let sessionAborted = false;
  411. const promise = withLLMSession(async (session) => {
  412. // Wait a bit then check if aborted
  413. await new Promise(resolve => setTimeout(resolve, 100));
  414. if (!session.isValid) {
  415. sessionAborted = true;
  416. throw new SessionReleasedError("Session aborted");
  417. }
  418. return "should not reach";
  419. }, { signal: abortController.signal });
  420. // Abort after 20ms
  421. setTimeout(() => abortController.abort(), 20);
  422. try {
  423. await promise;
  424. } catch (err) {
  425. // Expected
  426. }
  427. expect(sessionAborted).toBe(true);
  428. }, 5000);
  429. test("session provides abort signal for monitoring", async () => {
  430. await withLLMSession(async (session) => {
  431. expect(session.signal).toBeInstanceOf(AbortSignal);
  432. expect(session.signal.aborted).toBe(false);
  433. });
  434. });
  435. test("returns value from callback", async () => {
  436. const result = await withLLMSession(async (session) => {
  437. await session.embed("test");
  438. return { status: "complete", count: 42 };
  439. });
  440. expect(result).toEqual({ status: "complete", count: 42 });
  441. });
  442. test("propagates errors from callback", async () => {
  443. const customError = new Error("Custom test error");
  444. await expect(
  445. withLLMSession(async () => {
  446. throw customError;
  447. })
  448. ).rejects.toThrow("Custom test error");
  449. });
  450. });
  451. });