bench-rerank.ts 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. #!/usr/bin/env bun
  2. /**
  3. * QMD Reranker Benchmark
  4. *
  5. * Measures reranking performance across different configurations.
  6. * Reports device, parallelism, memory, VRAM, and throughput.
  7. *
  8. * Usage:
  9. * bun src/bench-rerank.ts # full benchmark
  10. * bun src/bench-rerank.ts --quick # quick smoke test (10 docs, 1 iteration)
  11. * bun src/bench-rerank.ts --docs 100 # custom doc count
  12. */
  13. import {
  14. getLlama,
  15. resolveModelFile,
  16. LlamaLogLevel,
  17. type Llama,
  18. type LlamaModel,
  19. } from "node-llama-cpp";
  20. import { homedir } from "os";
  21. import { join } from "path";
  22. import { cpus } from "os";
  23. // ============================================================================
  24. // Config
  25. // ============================================================================
  26. const RERANK_MODEL = "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf";
  27. const MODEL_CACHE = join(homedir(), ".cache", "qmd", "models");
  28. const CONTEXT_SIZE = 2048;
  29. const args = process.argv.slice(2);
  30. const quick = args.includes("--quick");
  31. const docsIdx = args.indexOf("--docs");
  32. const DOC_COUNT = docsIdx >= 0 ? parseInt(args[docsIdx + 1]!) : (quick ? 10 : 40);
  33. const ITERATIONS = quick ? 1 : 3;
  34. const PARALLEL_CONFIGS = quick ? [1, 4] : [1, 2, 4, 8];
  35. // ============================================================================
  36. // Test data — realistic-ish chunks of varying length
  37. // ============================================================================
  38. const QUERY = "How do AI agents work and what are their limitations?";
  39. function generateDocs(n: number): string[] {
  40. const templates = [
  41. "Artificial intelligence agents are software systems that perceive their environment and take actions to achieve goals. They use techniques like reinforcement learning, planning, and natural language processing to operate autonomously.",
  42. "The transformer architecture, introduced in 2017, revolutionized natural language processing. Self-attention mechanisms allow models to weigh the importance of different parts of input sequences when generating outputs.",
  43. "Machine learning models require careful evaluation to avoid overfitting. Cross-validation, holdout sets, and metrics like precision, recall, and F1 score help assess generalization performance.",
  44. "Retrieval-augmented generation combines information retrieval with language models. Documents are embedded into vector spaces, retrieved based on query similarity, and used as context for generation.",
  45. "Neural network training involves forward propagation, loss computation, and backpropagation. Optimizers like Adam and SGD adjust weights to minimize the loss function over training iterations.",
  46. "Large language models exhibit emergent capabilities at scale, including few-shot learning, chain-of-thought reasoning, and instruction following. These properties were not explicitly trained for.",
  47. "Embedding models convert text into dense vector representations that capture semantic meaning. Similar texts produce similar vectors, enabling efficient similarity search and clustering.",
  48. "Autonomous agents face challenges including hallucination, lack of grounding, limited planning horizons, and difficulty with multi-step reasoning. Safety and alignment remain open research problems.",
  49. "The attention mechanism computes query-key-value interactions to determine which parts of the input are most relevant. Multi-head attention allows the model to attend to different representation subspaces.",
  50. "Fine-tuning adapts a pre-trained model to specific tasks using domain-specific data. Techniques like LoRA reduce the number of trainable parameters while maintaining performance.",
  51. ];
  52. return Array.from({ length: n }, (_, i) => templates[i % templates.length]!);
  53. }
  54. // ============================================================================
  55. // Helpers
  56. // ============================================================================
  57. function formatBytes(bytes: number): string {
  58. if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(1)} KB`;
  59. if (bytes < 1024 * 1024 * 1024) return `${(bytes / (1024 * 1024)).toFixed(1)} MB`;
  60. return `${(bytes / (1024 * 1024 * 1024)).toFixed(2)} GB`;
  61. }
  62. function getMemUsage(): { rss: number; heapUsed: number } {
  63. const m = process.memoryUsage();
  64. return { rss: m.rss, heapUsed: m.heapUsed };
  65. }
  66. function median(arr: number[]): number {
  67. const sorted = [...arr].sort((a, b) => a - b);
  68. const mid = Math.floor(sorted.length / 2);
  69. return sorted.length % 2 !== 0 ? sorted[mid]! : (sorted[mid - 1]! + sorted[mid]!) / 2;
  70. }
  71. // ============================================================================
  72. // Benchmark runner
  73. // ============================================================================
  74. interface BenchResult {
  75. parallelism: number;
  76. contextSize: number;
  77. flashAttention: boolean;
  78. times: number[]; // ms per run
  79. medianMs: number;
  80. docsPerSec: number;
  81. vramPerContext: number; // bytes
  82. totalVram: number; // bytes
  83. peakRss: number; // bytes
  84. }
  85. async function benchmarkConfig(
  86. model: LlamaModel,
  87. llama: Llama,
  88. docs: string[],
  89. parallelism: number,
  90. flash: boolean,
  91. ): Promise<BenchResult> {
  92. // Measure VRAM before
  93. const vramBefore = llama.gpu ? await llama.getVramState() : null;
  94. const rssBefore = getMemUsage().rss;
  95. // Create contexts. On CPU, split threads evenly across contexts.
  96. const cpuThreads = !llama.gpu ? Math.floor(llama.cpuMathCores / parallelism) : 0;
  97. const contexts = [];
  98. for (let i = 0; i < parallelism; i++) {
  99. try {
  100. contexts.push(await model.createRankingContext({
  101. contextSize: CONTEXT_SIZE,
  102. flashAttention: flash,
  103. ...(cpuThreads > 0 ? { threads: cpuThreads } : {}),
  104. }));
  105. } catch {
  106. if (contexts.length === 0) {
  107. // Try without flash
  108. contexts.push(await model.createRankingContext({
  109. contextSize: CONTEXT_SIZE,
  110. ...(cpuThreads > 0 ? { threads: cpuThreads } : {}),
  111. }));
  112. }
  113. break;
  114. }
  115. }
  116. const actualParallelism = contexts.length;
  117. // Measure VRAM after context creation
  118. const vramAfter = llama.gpu ? await llama.getVramState() : null;
  119. const vramUsed = vramBefore && vramAfter ? vramAfter.used - vramBefore.used : 0;
  120. const vramPerCtx = actualParallelism > 0 ? vramUsed / actualParallelism : 0;
  121. // Warm up
  122. await contexts[0]!.rankAll(QUERY, docs.slice(0, 2));
  123. // Benchmark iterations
  124. const times: number[] = [];
  125. let peakRss = getMemUsage().rss;
  126. for (let iter = 0; iter < ITERATIONS; iter++) {
  127. const chunkSize = Math.ceil(docs.length / actualParallelism);
  128. const t0 = performance.now();
  129. const allScores = await Promise.all(
  130. Array.from({ length: actualParallelism }, (_, i) => {
  131. const chunk = docs.slice(i * chunkSize, (i + 1) * chunkSize);
  132. return chunk.length > 0 ? contexts[i]!.rankAll(QUERY, chunk) : Promise.resolve([]);
  133. })
  134. );
  135. const elapsed = performance.now() - t0;
  136. times.push(elapsed);
  137. // Verify scores are valid
  138. const flat = allScores.flat();
  139. if (flat.some(s => s < 0 || s > 1 || isNaN(s))) {
  140. throw new Error("Invalid scores detected");
  141. }
  142. const currentRss = getMemUsage().rss;
  143. if (currentRss > peakRss) peakRss = currentRss;
  144. }
  145. // Cleanup
  146. for (const ctx of contexts) await ctx.dispose();
  147. const med = median(times);
  148. return {
  149. parallelism: actualParallelism,
  150. contextSize: CONTEXT_SIZE,
  151. flashAttention: flash,
  152. times,
  153. medianMs: med,
  154. docsPerSec: (docs.length / med) * 1000,
  155. vramPerContext: vramPerCtx,
  156. totalVram: vramUsed,
  157. peakRss,
  158. };
  159. }
  160. // ============================================================================
  161. // Main
  162. // ============================================================================
  163. async function main() {
  164. console.log("═══════════════════════════════════════════════════════════════");
  165. console.log(" QMD Reranker Benchmark");
  166. console.log("═══════════════════════════════════════════════════════════════\n");
  167. const llama = await getLlama({
  168. // attempt to build
  169. build: "autoAttempt",
  170. logLevel: LlamaLogLevel.error
  171. });
  172. let gpuLabel: string = llama.gpu === false
  173. ? "cpu"
  174. : llama.gpu;
  175. // System info
  176. const cpuInfo = cpus();
  177. const cpuModel = cpuInfo[0]?.model || "unknown";
  178. const cpuCount = cpuInfo.length;
  179. console.log("System");
  180. console.log(` CPU: ${cpuModel}`);
  181. console.log(` Cores: ${cpuCount} (${llama.cpuMathCores} math)`);
  182. console.log(` Device: ${gpuLabel}`);
  183. if (llama.gpu) {
  184. const gpuNames = await llama.getGpuDeviceNames();
  185. const counts = new Map<string, number>();
  186. for (const name of gpuNames) counts.set(name, (counts.get(name) || 0) + 1);
  187. const devStr = Array.from(counts.entries())
  188. .map(([name, n]) => n > 1 ? `${n}× ${name}` : name).join(", ");
  189. console.log(` GPU: ${devStr}`);
  190. const vram = await llama.getVramState();
  191. console.log(` VRAM: ${formatBytes(vram.total)} total, ${formatBytes(vram.free)} free`);
  192. }
  193. console.log(` RAM: ${formatBytes(getMemUsage().rss)} RSS at start`);
  194. // Load model
  195. console.log(`\nModel`);
  196. console.log(` URI: ${RERANK_MODEL}`);
  197. const modelPath = await resolveModelFile(RERANK_MODEL, MODEL_CACHE);
  198. const vramPreModel = llama.gpu ? await llama.getVramState() : null;
  199. const model = await llama.loadModel({ modelPath });
  200. const vramPostModel = llama.gpu ? await llama.getVramState() : null;
  201. const modelVram = vramPreModel && vramPostModel ? vramPostModel.used - vramPreModel.used : 0;
  202. console.log(` Params: ${model.trainContextSize} train ctx`);
  203. if (modelVram > 0) console.log(` VRAM: ${formatBytes(modelVram)} (model weights)`);
  204. // Generate test docs
  205. const docs = generateDocs(DOC_COUNT);
  206. console.log(`\nBenchmark`);
  207. console.log(` Documents: ${DOC_COUNT}`);
  208. console.log(` Ctx size: ${CONTEXT_SIZE}`);
  209. console.log(` Iterations:${ITERATIONS}`);
  210. console.log(` Query: "${QUERY.slice(0, 50)}..."`);
  211. // Run benchmarks
  212. const results: BenchResult[] = [];
  213. for (const p of PARALLEL_CONFIGS) {
  214. if (!llama.gpu && p > 1) {
  215. // CPU: only test if we have enough cores (at least 4 per context)
  216. if (llama.cpuMathCores < p * 4) {
  217. console.log(`\n [${p} ctx] skipped (need ${p * 4} cores, have ${llama.cpuMathCores})`);
  218. continue;
  219. }
  220. }
  221. // Test with flash attention
  222. process.stdout.write(`\n [${p} ctx, flash] running...`);
  223. try {
  224. const r = await benchmarkConfig(model, llama, docs, p, true);
  225. results.push(r);
  226. process.stdout.write(` ${r.medianMs.toFixed(0)}ms (${r.docsPerSec.toFixed(1)} docs/s)\n`);
  227. } catch (e: any) {
  228. process.stdout.write(` failed: ${e.message}\n`);
  229. // Try without flash
  230. process.stdout.write(` [${p} ctx, no flash] running...`);
  231. try {
  232. const r = await benchmarkConfig(model, llama, docs, p, false);
  233. results.push(r);
  234. process.stdout.write(` ${r.medianMs.toFixed(0)}ms (${r.docsPerSec.toFixed(1)} docs/s)\n`);
  235. } catch (e2: any) {
  236. process.stdout.write(` failed: ${e2.message}\n`);
  237. }
  238. }
  239. }
  240. // Summary table
  241. console.log("\n═══════════════════════════════════════════════════════════════");
  242. console.log(" Results");
  243. console.log("═══════════════════════════════════════════════════════════════\n");
  244. const header = " Ctx Flash Median Docs/s VRAM/ctx Total VRAM Peak RSS";
  245. const sep = " ─── ───── ────── ────── ──────── ────────── ────────";
  246. console.log(header);
  247. console.log(sep);
  248. const baseline = results[0]?.medianMs ?? 1;
  249. for (const r of results) {
  250. const speedup = baseline / r.medianMs;
  251. const speedupStr = r === results[0] ? " " : `(${speedup.toFixed(1)}×)`;
  252. console.log(
  253. ` ${String(r.parallelism).padStart(3)} ` +
  254. `${r.flashAttention ? " yes " : " no "} ` +
  255. `${r.medianMs.toFixed(0).padStart(5)}ms ` +
  256. `${r.docsPerSec.toFixed(1).padStart(6)} ` +
  257. `${formatBytes(r.vramPerContext).padStart(8)} ` +
  258. `${formatBytes(r.totalVram).padStart(10)} ` +
  259. `${formatBytes(r.peakRss).padStart(8)} ` +
  260. speedupStr
  261. );
  262. }
  263. // Best config
  264. if (results.length > 0) {
  265. const best = results.reduce((a, b) => a.docsPerSec > b.docsPerSec ? a : b);
  266. console.log(`\n Best: ${best.parallelism} contexts, flash=${best.flashAttention}`);
  267. console.log(` ${best.medianMs.toFixed(0)}ms for ${DOC_COUNT} docs (${best.docsPerSec.toFixed(1)} docs/s)`);
  268. if (best.totalVram > 0) console.log(` ${formatBytes(best.totalVram)} VRAM`);
  269. }
  270. console.log("");
  271. await model.dispose();
  272. await llama.dispose();
  273. }
  274. main().catch(console.error);