Ver código fonte

fix: use Qwen3 recommended sampling params to prevent repetition loops

- Changed temperature from 0/0.1 to 0.7 (Qwen3 non-thinking mode default)
- Added topK=20, topP=0.8 per Qwen3 docs
- Added repeatPenalty with presencePenalty=0.5 for query expansion
- Fixes infinite loop on acronyms like DHH, BFCM

Qwen3 docs explicitly warn: 'DO NOT use greedy decoding, as it can
lead to performance degradation and endless repetitions'
Tobi Lütke 3 meses atrás
pai
commit
102ff861d3

+ 12 - 0
finetune/Justfile

@@ -29,3 +29,15 @@ train-local:
 grpo-local:
     CUDA_VISIBLE_DEVICES=1,2,3 HF_TOKEN=${HF_TOKEN} uv run torchrun --standalone --nproc_per_node 3 \
       train.py grpo --config configs/grpo.yaml |& tee /tmp/qmd-grpo-train.log
+
+gepa-local:
+    UV_CACHE_DIR=/tmp/uv-cache LITELLM_CACHE_DIR=/tmp/litellm-cache OLLAMA_API_BASE=http://localhost:11434 \
+      uv run python gepa/dspy_gepa.py \
+      --input data/qmd_expansion_v2.jsonl \
+      --model ollama/glm-4.7-flash:Q8_0 \
+      --reflection-model ollama/glm-4.7-flash:Q8_0 \
+      --max-metric-calls 100 --limit 20 \
+      --valset data/qmd_expansion_handcrafted.jsonl --val-limit 20 \
+      --max-tokens 512 --reflection-max-tokens 512 \
+      --emit gepa/gepa_outputs_glm.jsonl \
+      --save-prompt gepa/best_prompt_glm.txt

+ 1 - 0
finetune/gepa/best_prompt_glm.txt

@@ -0,0 +1 @@
+Expand a search query into lex/vec/hyde lines.

+ 37 - 12
finetune/gepa/dspy_gepa.py

@@ -27,7 +27,7 @@ repo_root = Path(__file__).parent.parent
 if str(repo_root) not in sys.path:
     sys.path.insert(0, str(repo_root))
 
-from dataset.schema import parse_output_text
+from dataset.schema import normalize_output_items, output_items_to_text, parse_output_text
 from reward import score_expansion_detailed
 
 
@@ -35,11 +35,12 @@ class ExpandSignature(dspy.Signature):
     """Expand a search query into lex/vec/hyde lines."""
 
     query = dspy.InputField(desc="User search query")
-    expansion = dspy.OutputField(
+    output = dspy.OutputField(
         desc=(
-            "Multi-line text with prefixes: 2-3 lex:, 2-3 vec:, optional 0-1 hyde:. "
-            "Lex lines are short keywords and must not echo the query. "
-            "Vec lines are natural language search phrases. "
+            "JSON array of [kind, text] pairs. kind is lex|vec|hyde. "
+            "Return 2-3 lex, 2-3 vec, optional 0-1 hyde. "
+            "Lex items are short keywords and must not echo the query. "
+            "Vec items are natural language search phrases. "
             "Hyde is 50-200 chars, single line."
         )
     )
@@ -55,7 +56,7 @@ class Expander(dspy.Module):
 
 
 def reward_metric(gold, pred, trace=None, pred_name=None, pred_trace=None):
-    expansion = getattr(pred, "expansion", "") or ""
+    expansion = output_items_to_text(_coerce_output_items(pred))
     detail = score_expansion_detailed(gold.query, expansion)
     score = detail["percentage"] / 100.0
     feedback = "; ".join(detail.get("deductions", [])) or f"score={detail['percentage']:.1f}"
@@ -80,10 +81,31 @@ def to_examples(queries: list[str]) -> list[dspy.Example]:
     return [dspy.Example(query=q).with_inputs("query") for q in queries]
 
 
-def write_jsonl(path: Path, queries: list[str], outputs: list[str]) -> None:
+def _coerce_output_items(pred) -> list[list[str]]:
+    raw_output = getattr(pred, "output", None)
+    if isinstance(raw_output, (list, tuple)):
+        return normalize_output_items(raw_output)
+
+    raw_text = str(raw_output or getattr(pred, "expansion", "") or "").strip()
+    if not raw_text:
+        return []
+
+    if raw_text[0] in ("[", "{"):
+        try:
+            obj = json.loads(raw_text)
+            if isinstance(obj, dict) and "output" in obj:
+                obj = obj["output"]
+            if isinstance(obj, (list, tuple)):
+                return normalize_output_items(obj)
+        except Exception:
+            pass
+
+    return parse_output_text(raw_text)
+
+
+def write_jsonl(path: Path, queries: list[str], outputs: list[list[list[str]]]) -> None:
     with path.open("w", encoding="utf-8") as f:
-        for query, output_text in zip(queries, outputs, strict=True):
-            output = parse_output_text(output_text)
+        for query, output in zip(queries, outputs, strict=True):
             f.write(json.dumps({"query": query, "output": output}, ensure_ascii=False) + "\n")
 
 
@@ -102,6 +124,8 @@ def main() -> int:
         default="grok-4-1-fast-reasoning",
         help="LM string in provider/model format (e.g., openai/gpt-4o)",
     )
+    parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens for student LM")
+    parser.add_argument("--reflection-max-tokens", type=int, default=512, help="Max tokens for reflection LM")
     parser.add_argument("--auto", type=str, default="light", choices=["light", "medium", "heavy"])
     parser.add_argument("--max-full-evals", type=int, default=None)
     parser.add_argument("--max-metric-calls", type=int, default=None)
@@ -134,8 +158,8 @@ def main() -> int:
             val_queries = val_queries[: args.val_limit]
         valset = to_examples(val_queries)
 
-    lm = dspy.LM(model=args.model)
-    reflection_lm = dspy.LM(model=args.reflection_model)
+    lm = dspy.LM(model=args.model, max_tokens=args.max_tokens)
+    reflection_lm = dspy.LM(model=args.reflection_model, max_tokens=args.reflection_max_tokens)
 
     student = Expander()
     student.set_lm(lm)
@@ -163,7 +187,8 @@ def main() -> int:
         outputs = []
         for q in queries:
             pred = optimized(query=q)
-            outputs.append(getattr(pred, "expansion", "") or "")
+            items = _coerce_output_items(pred)
+            outputs.append(items)
         write_jsonl(Path(args.emit), queries, outputs)
         print(f"Wrote {args.emit}")
 

+ 20 - 0
finetune/gepa/gepa_outputs_glm.jsonl

@@ -0,0 +1,20 @@
+{"query": "how tourism affects local cultures", "output": []}
+{"query": "how to ferment foods at home", "output": []}
+{"query": "how to mix modern and vintage decor", "output": []}
+{"query": "how to perform a scientific experiment", "output": []}
+{"query": "web mail", "output": []}
+{"query": "what does the quran cover", "output": []}
+{"query": "web config", "output": []}
+{"query": "how to choose farm equipment", "output": []}
+{"query": "how do thought experiments aid philosophical reasoning", "output": []}
+{"query": "what is the significance of logic in philosophy", "output": []}
+{"query": "how to train for a 5k run", "output": []}
+{"query": "how to engage with political dialogues", "output": []}
+{"query": "what is competitive analysis", "output": []}
+{"query": "how does the united nations operate", "output": []}
+{"query": "what are the crusades?", "output": []}
+{"query": "what is a literary theme?", "output": []}
+{"query": "what is the ethical significance of consent", "output": []}
+{"query": "paint mix", "output": []}
+{"query": "how to conserve energy in the office?", "output": []}
+{"query": "how to test soil ph?", "output": []}

+ 15 - 2
src/llm.ts

@@ -742,13 +742,17 @@ export class LlamaCpp implements LLM {
     const session = new LlamaChatSession({ contextSequence: sequence });
 
     const maxTokens = options.maxTokens ?? 150;
-    const temperature = options.temperature ?? 0;
+    // Qwen3 recommends temp=0.7, topP=0.8, topK=20 for non-thinking mode
+    // DO NOT use greedy decoding (temp=0) - causes repetition loops
+    const temperature = options.temperature ?? 0.7;
 
     let result = "";
     try {
       await session.prompt(prompt, {
         maxTokens,
         temperature,
+        topK: 20,
+        topP: 0.8,
         onTextChunk: (text) => {
           result += text;
         },
@@ -811,10 +815,19 @@ export class LlamaCpp implements LLM {
     const session = new LlamaChatSession({ contextSequence: sequence });
 
     try {
+      // Qwen3 recommended settings for non-thinking mode:
+      // temp=0.7, topP=0.8, topK=20, presence_penalty for repetition
+      // DO NOT use greedy decoding (temp=0) - causes infinite loops
       const result = await session.prompt(prompt, {
         grammar,
         maxTokens: 600,
-        temperature: 0.1,
+        temperature: 0.7,
+        topK: 20,
+        topP: 0.8,
+        repeatPenalty: {
+          lastTokens: 64,
+          presencePenalty: 0.5,
+        },
       });
 
       const lines = result.trim().split("\n");