Преглед на файлове

Fix QMD_LLAMA_GPU backend override handling

Jeff Gardner преди 1 месец
родител
ревизия
1ecb5c9f96
променени са 3 файла, в които са добавени 54 реда и са изтрити 7 реда
  1. 4 0
      CHANGELOG.md
  2. 17 7
      src/llm.ts
  3. 33 0
      test/llm.test.ts

+ 4 - 0
CHANGELOG.md

@@ -2,6 +2,10 @@
 
 ## [Unreleased]
 
+### Fixes
+
+- GPU: respect explicit `QMD_LLAMA_GPU=metal|vulkan|cuda` backend overrides instead of always using auto GPU selection. #529
+
 ## [2.1.0] - 2026-04-05
 
 Code files now chunk at function and class boundaries via tree-sitter,

+ 17 - 7
src/llm.ts

@@ -385,6 +385,18 @@ export type LlamaCppConfig = {
 const DEFAULT_INACTIVITY_TIMEOUT_MS = 5 * 60 * 1000;
 const DEFAULT_EXPAND_CONTEXT_SIZE = 2048;
 
+type LlamaGpuMode = "auto" | "metal" | "vulkan" | "cuda" | false;
+
+export function resolveLlamaGpuMode(envValue = process.env.QMD_LLAMA_GPU): LlamaGpuMode {
+  const normalized = envValue?.trim().toLowerCase() ?? "";
+  if (!normalized) return "auto";
+  if (["false", "off", "none", "disable", "disabled", "0"].includes(normalized)) return false;
+  if (normalized === "metal" || normalized === "vulkan" || normalized === "cuda") return normalized;
+
+  process.stderr.write(`QMD Warning: invalid QMD_LLAMA_GPU="${envValue}", using auto GPU selection.\n`);
+  return "auto";
+}
+
 function resolveExpandContextSize(configValue?: number): number {
   if (configValue !== undefined) {
     if (!Number.isInteger(configValue) || configValue <= 0) {
@@ -552,11 +564,9 @@ export class LlamaCpp implements LLM {
    */
   private async ensureLlama(): Promise<Llama> {
     if (!this.llama) {
-      // Allow override via QMD_LLAMA_GPU: "false" | "off" | "none" forces CPU
-      const gpuOverride = (process.env.QMD_LLAMA_GPU ?? "").toLowerCase();
-      const forceCpu = ["false", "off", "none", "disable", "disabled", "0"].includes(gpuOverride);
+      const gpuMode = resolveLlamaGpuMode();
 
-      const loadLlama = async (gpu: "auto" | false) =>
+      const loadLlama = async (gpu: LlamaGpuMode) =>
         await getLlama({
           build: "autoAttempt",
           logLevel: LlamaLogLevel.error,
@@ -564,16 +574,16 @@ export class LlamaCpp implements LLM {
         });
 
       let llama: Llama;
-      if (forceCpu) {
+      if (gpuMode === false) {
         llama = await loadLlama(false);
       } else {
         try {
-          llama = await loadLlama("auto");
+          llama = await loadLlama(gpuMode);
         } catch (err) {
           // GPU backend (e.g. Vulkan on headless/driverless machines) can throw at init.
           // Fall back to CPU so qmd still works.
           process.stderr.write(
-            `QMD Warning: GPU init failed (${err instanceof Error ? err.message : String(err)}), falling back to CPU.\n`
+            `QMD Warning: GPU init failed${gpuMode === "auto" ? "" : ` for QMD_LLAMA_GPU=${gpuMode}`} (${err instanceof Error ? err.message : String(err)}), falling back to CPU.\n`
           );
           llama = await loadLlama(false);
         }

+ 33 - 0
test/llm.test.ts

@@ -12,6 +12,7 @@ import {
   LlamaCpp,
   getDefaultLlamaCpp,
   disposeDefaultLlamaCpp,
+  resolveLlamaGpuMode,
   withLLMSession,
   canUnloadLLM,
   SessionReleasedError,
@@ -55,6 +56,38 @@ describe("LlamaCpp.modelExists", () => {
   });
 });
 
+describe("QMD_LLAMA_GPU resolution", () => {
+  test("uses auto when unset or blank", () => {
+    expect(resolveLlamaGpuMode(undefined)).toBe("auto");
+    expect(resolveLlamaGpuMode("   ")).toBe("auto");
+  });
+
+  test("maps CPU disable values to false", () => {
+    expect(resolveLlamaGpuMode("false")).toBe(false);
+    expect(resolveLlamaGpuMode("OFF")).toBe(false);
+    expect(resolveLlamaGpuMode(" none ")).toBe(false);
+    expect(resolveLlamaGpuMode("disabled")).toBe(false);
+    expect(resolveLlamaGpuMode("0")).toBe(false);
+  });
+
+  test("passes through supported GPU backends", () => {
+    expect(resolveLlamaGpuMode("metal")).toBe("metal");
+    expect(resolveLlamaGpuMode("VULKAN")).toBe("vulkan");
+    expect(resolveLlamaGpuMode(" cuda ")).toBe("cuda");
+  });
+
+  test("warns and falls back to auto for unsupported values", () => {
+    const stderrSpy = vi.spyOn(process.stderr, "write").mockReturnValue(true);
+    try {
+      expect(resolveLlamaGpuMode("rocm")).toBe("auto");
+      expect(stderrSpy).toHaveBeenCalled();
+      expect(String(stderrSpy.mock.calls[0]?.[0] || "")).toContain("QMD_LLAMA_GPU");
+    } finally {
+      stderrSpy.mockRestore();
+    }
+  });
+});
+
 describe("LlamaCpp expand context size config", () => {
   const defaultExpandContextSize = 2048;