#!/usr/bin/env python3 """Run DSPy GEPA using reward.py as the metric.""" from __future__ import annotations import argparse import importlib import json import sys from pathlib import Path def _import_dspy(): script_dir = Path(__file__).parent repo_root = script_dir.parent original_sys_path = list(sys.path) try: sys.path = [p for p in sys.path if p and str(p) != str(script_dir)] return importlib.import_module("dspy") finally: sys.path = original_sys_path dspy = _import_dspy() repo_root = Path(__file__).parent.parent if str(repo_root) not in sys.path: sys.path.insert(0, str(repo_root)) from dataset.schema import normalize_output_items, output_items_to_text, parse_output_text from reward import score_expansion_detailed class ExpandSignature(dspy.Signature): """Expand a search query into lex/vec/hyde lines.""" query = dspy.InputField(desc="User search query") output = dspy.OutputField( desc=( "JSON array of [kind, text] pairs. kind is lex|vec|hyde. " "Return 2-3 lex, 2-3 vec, optional 0-1 hyde. " "Lex items are short keywords and must not echo the query. " "Vec items are natural language search phrases. " "Hyde is 50-200 chars, single line." ) ) class Expander(dspy.Module): def __init__(self): super().__init__() self.predict = dspy.Predict(ExpandSignature) def forward(self, query: str): return self.predict(query=query) def reward_metric(gold, pred, trace=None, pred_name=None, pred_trace=None): expansion = output_items_to_text(_coerce_output_items(pred)) detail = score_expansion_detailed(gold.query, expansion) score = detail["percentage"] / 100.0 feedback = "; ".join(detail.get("deductions", [])) or f"score={detail['percentage']:.1f}" return dspy.Prediction(score=score, feedback=feedback) def load_queries(path: Path) -> list[str]: queries: list[str] = [] with path.open("r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue obj = json.loads(line) query = obj.get("query") or obj.get("input") if isinstance(query, str) and query.strip(): queries.append(query.strip()) return queries def to_examples(queries: list[str]) -> list[dspy.Example]: return [dspy.Example(query=q).with_inputs("query") for q in queries] def _coerce_output_items(pred) -> list[list[str]]: raw_output = getattr(pred, "output", None) if isinstance(raw_output, (list, tuple)): return normalize_output_items(raw_output) raw_text = str(raw_output or getattr(pred, "expansion", "") or "").strip() if not raw_text: return [] if raw_text[0] in ("[", "{"): try: obj = json.loads(raw_text) if isinstance(obj, dict) and "output" in obj: obj = obj["output"] if isinstance(obj, (list, tuple)): return normalize_output_items(obj) except Exception: pass return parse_output_text(raw_text) def write_jsonl(path: Path, queries: list[str], outputs: list[list[list[str]]]) -> None: with path.open("w", encoding="utf-8") as f: for query, output in zip(queries, outputs, strict=True): f.write(json.dumps({"query": query, "output": output}, ensure_ascii=False) + "\n") def main() -> int: parser = argparse.ArgumentParser(description="Run DSPy GEPA with reward.py") parser.add_argument("--input", type=str, required=True, help="Training JSONL path") parser.add_argument( "--model", type=str, default="grok-4-1-fast-reasoning", help="LM string in provider/model format (e.g., openai/gpt-4o)", ) parser.add_argument( "--reflection-model", type=str, default="grok-4-1-fast-reasoning", help="LM string in provider/model format (e.g., openai/gpt-4o)", ) parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens for student LM") parser.add_argument("--reflection-max-tokens", type=int, default=512, help="Max tokens for reflection LM") parser.add_argument("--auto", type=str, default="light", choices=["light", "medium", "heavy"]) parser.add_argument("--max-full-evals", type=int, default=None) parser.add_argument("--max-metric-calls", type=int, default=None) parser.add_argument("--valset", type=str, default=None, help="Optional valset JSONL path") parser.add_argument("--limit", type=int, default=None, help="Limit number of training queries") parser.add_argument("--val-limit", type=int, default=None, help="Limit number of val queries") parser.add_argument("--emit", type=str, default=None, help="Write generated JSONL after compile") parser.add_argument("--save-prompt", type=str, default=None, help="Write best prompt text to file") args = parser.parse_args() if "/" not in args.model or "/" not in args.reflection_model: print("Error: DSPy expects provider/model format for LM strings (e.g., xai/grok-4-1-fast-reasoning).") return 1 if args.max_full_evals is not None and args.max_metric_calls is not None: print("Provide only one of --max-full-evals or --max-metric-calls") return 1 if args.max_full_evals is not None or args.max_metric_calls is not None: args.auto = None train_path = Path(args.input) queries = load_queries(train_path) if args.limit is not None: queries = queries[: args.limit] trainset = to_examples(queries) valset = None if args.valset: val_queries = load_queries(Path(args.valset)) if args.val_limit is not None: val_queries = val_queries[: args.val_limit] valset = to_examples(val_queries) lm = dspy.LM(model=args.model, max_tokens=args.max_tokens) reflection_lm = dspy.LM(model=args.reflection_model, max_tokens=args.reflection_max_tokens) student = Expander() student.set_lm(lm) compiler = dspy.GEPA( metric=reward_metric, reflection_lm=reflection_lm, auto=None if args.auto is None else args.auto, max_full_evals=args.max_full_evals, max_metric_calls=args.max_metric_calls, track_stats=True, track_best_outputs=True, failure_score=0.0, perfect_score=1.0, ) optimized = compiler.compile(student=student, trainset=trainset, valset=valset) if args.save_prompt: prompt_text = getattr(optimized.predict.signature, "__doc__", "") or "" Path(args.save_prompt).write_text(prompt_text.strip() + "\n", encoding="utf-8") print(f"Wrote {args.save_prompt}") if args.emit: outputs = [] for q in queries: pred = optimized(query=q) items = _coerce_output_items(pred) outputs.append(items) write_jsonl(Path(args.emit), queries, outputs) print(f"Wrote {args.emit}") if hasattr(optimized, "detailed_results"): best = getattr(optimized.detailed_results, "best_outputs_valset", None) if best: print(f"Best outputs tracked: {len(best)}") return 0 if __name__ == "__main__": raise SystemExit(main())