| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204 |
- #!/usr/bin/env python3
- """Run DSPy GEPA using reward.py as the metric."""
- from __future__ import annotations
- import argparse
- import importlib
- import json
- import sys
- from pathlib import Path
- def _import_dspy():
- script_dir = Path(__file__).parent
- repo_root = script_dir.parent
- original_sys_path = list(sys.path)
- try:
- sys.path = [p for p in sys.path if p and str(p) != str(script_dir)]
- return importlib.import_module("dspy")
- finally:
- sys.path = original_sys_path
# Load dspy through the helper so the script directory is masked during import.
dspy = _import_dspy()

# Put the parent of this script's directory at the front of the search path
# (once) before the project imports that follow.
repo_root = Path(__file__).parent.parent
if sys.path.count(str(repo_root)) == 0:
    sys.path.insert(0, str(repo_root))
- from dataset.schema import normalize_output_items, output_items_to_text, parse_output_text
- from reward import score_expansion_detailed
# Description attached to the ``output`` field of ExpandSignature.  The text
# is sent as-is, so sentence wording and order are part of the behavior.
_EXPAND_OUTPUT_DESC = " ".join(
    [
        "JSON array of [kind, text] pairs.",
        "kind is lex|vec|hyde.",
        "Return 2-3 lex, 2-3 vec, optional 0-1 hyde.",
        "Lex items are short keywords and must not echo the query.",
        "Vec items are natural language search phrases.",
        "Hyde is 50-200 chars, single line.",
    ]
)


class ExpandSignature(dspy.Signature):
    """Expand a search query into lex/vec/hyde lines."""

    # NOTE: the class docstring above is runtime text -- main() saves the
    # optimized signature's ``__doc__`` as the prompt -- so edit it deliberately.

    # Single input: the raw query string.
    query = dspy.InputField(desc="User search query")
    # Single output: the expansion, described to the model by the text above.
    output = dspy.OutputField(desc=_EXPAND_OUTPUT_DESC)
class Expander(dspy.Module):
    """Module that answers a query with one ``dspy.Predict`` call."""

    def __init__(self):
        """Create the predictor bound to ``ExpandSignature``."""
        super().__init__()
        # Kept under the name ``predict``; main() reads
        # ``optimized.predict.signature`` when saving the prompt.
        self.predict = dspy.Predict(ExpandSignature)

    def forward(self, query: str):
        """Run the predictor on ``query`` and return whatever it produces."""
        prediction = self.predict(query=query)
        return prediction
def reward_metric(gold, pred, trace=None, pred_name=None, pred_trace=None):
    """Score a prediction with ``score_expansion_detailed`` and attach feedback.

    Args:
        gold: Example whose ``query`` attribute is the query being expanded.
        pred: Prediction object; its output is normalized with
            ``_coerce_output_items`` before scoring.
        trace: Unused here (presumably part of the optimizer's metric
            calling convention -- confirm against the GEPA docs).
        pred_name: Unused here.
        pred_trace: Unused here.

    Returns:
        A ``dspy.Prediction`` with ``score`` (the reported percentage divided
        by 100) and ``feedback`` (the deductions joined with "; ", or a
        ``score=...`` summary when there are none).
    """
    items = _coerce_output_items(pred)
    expansion_text = output_items_to_text(items)
    detail = score_expansion_detailed(gold.query, expansion_text)

    score = detail["percentage"] / 100.0

    feedback = "; ".join(detail.get("deductions", []))
    if not feedback:
        # Nothing was deducted: report the raw percentage instead.
        feedback = f"score={detail['percentage']:.1f}"

    return dspy.Prediction(score=score, feedback=feedback)
def load_queries(path: Path) -> list[str]:
    """Read search queries from a JSONL file.

    Each non-blank line is parsed as JSON.  For object records, the query is
    the first of the ``"query"`` / ``"input"`` values that is a non-blank
    string; it is stripped before being collected.  Blank lines, records that
    are not JSON objects, and records without a usable query are skipped.

    Args:
        path: Location of the UTF-8 encoded JSONL file.

    Returns:
        The stripped query strings, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON; the
            message names the file and the 1-based line number.
    """
    queries: list[str] = []
    with path.open("r", encoding="utf-8") as f:
        for lineno, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                # Same exception type, but say where in the file it happened.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            if not isinstance(record, dict):
                # ``null``, arrays, and scalars carry no query; skip them
                # like any other record without one.
                continue
            for key in ("query", "input"):
                value = record.get(key)
                if isinstance(value, str) and value.strip():
                    queries.append(value.strip())
                    break
    return queries
def to_examples(queries: list[str]) -> list[dspy.Example]:
    """Wrap each query string in a ``dspy.Example`` with ``query`` as its input.

    Args:
        queries: Query strings, one example is produced per entry.

    Returns:
        The examples, in the same order as ``queries``.
    """
    examples = []
    for text in queries:
        example = dspy.Example(query=text)
        examples.append(example.with_inputs("query"))
    return examples
- def _coerce_output_items(pred) -> list[list[str]]:
- raw_output = getattr(pred, "output", None)
- if isinstance(raw_output, (list, tuple)):
- return normalize_output_items(raw_output)
- raw_text = str(raw_output or getattr(pred, "expansion", "") or "").strip()
- if not raw_text:
- return []
- if raw_text[0] in ("[", "{"):
- try:
- obj = json.loads(raw_text)
- if isinstance(obj, dict) and "output" in obj:
- obj = obj["output"]
- if isinstance(obj, (list, tuple)):
- return normalize_output_items(obj)
- except Exception:
- pass
- return parse_output_text(raw_text)
- def write_jsonl(path: Path, queries: list[str], outputs: list[list[list[str]]]) -> None:
- with path.open("w", encoding="utf-8") as f:
- for query, output in zip(queries, outputs, strict=True):
- f.write(json.dumps({"query": query, "output": output}, ensure_ascii=False) + "\n")
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the GEPA runner."""
    parser = argparse.ArgumentParser(description="Run DSPy GEPA with reward.py")
    parser.add_argument("--input", type=str, required=True, help="Training JSONL path")
    # The defaults carry a provider prefix so that a run with no --model /
    # --reflection-model passes the provider/model check in main().
    parser.add_argument(
        "--model",
        type=str,
        default="xai/grok-4-1-fast-reasoning",
        help="LM string in provider/model format (e.g., openai/gpt-4o)",
    )
    parser.add_argument(
        "--reflection-model",
        type=str,
        default="xai/grok-4-1-fast-reasoning",
        help="LM string in provider/model format (e.g., openai/gpt-4o)",
    )
    parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens for student LM")
    parser.add_argument(
        "--reflection-max-tokens", type=int, default=512, help="Max tokens for reflection LM"
    )
    parser.add_argument("--auto", type=str, default="light", choices=["light", "medium", "heavy"])
    parser.add_argument("--max-full-evals", type=int, default=None)
    parser.add_argument("--max-metric-calls", type=int, default=None)
    parser.add_argument("--valset", type=str, default=None, help="Optional valset JSONL path")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of training queries")
    parser.add_argument("--val-limit", type=int, default=None, help="Limit number of val queries")
    parser.add_argument("--emit", type=str, default=None, help="Write generated JSONL after compile")
    parser.add_argument("--save-prompt", type=str, default=None, help="Write best prompt text to file")
    return parser


def main() -> int:
    """Run GEPA over the training queries and optionally save its results.

    Steps: parse and validate the command line, load the training (and
    optional validation) queries, compile an ``Expander`` with ``dspy.GEPA``
    using ``reward_metric``, then write the prompt text (``--save-prompt``)
    and/or the generated expansions for the training queries (``--emit``).

    Returns:
        0 on success; 1 when the arguments are invalid or no training
        queries were loaded.  Errors are reported on stderr.
    """
    args = _build_arg_parser().parse_args()

    if "/" not in args.model or "/" not in args.reflection_model:
        print(
            "Error: DSPy expects provider/model format for LM strings (e.g., xai/grok-4-1-fast-reasoning).",
            file=sys.stderr,
        )
        return 1
    if args.max_full_evals is not None and args.max_metric_calls is not None:
        print("Provide only one of --max-full-evals or --max-metric-calls", file=sys.stderr)
        return 1

    # An explicit budget replaces the --auto preset; otherwise the preset applies.
    explicit_budget = args.max_full_evals is not None or args.max_metric_calls is not None
    auto = None if explicit_budget else args.auto

    train_path = Path(args.input)
    queries = load_queries(train_path)
    if args.limit is not None:
        queries = queries[: args.limit]
    if not queries:
        # Fail early with a clear message instead of handing GEPA an empty set.
        print(f"Error: no training queries loaded from {train_path}", file=sys.stderr)
        return 1
    trainset = to_examples(queries)

    valset = None
    if args.valset:
        val_queries = load_queries(Path(args.valset))
        if args.val_limit is not None:
            val_queries = val_queries[: args.val_limit]
        valset = to_examples(val_queries)

    lm = dspy.LM(model=args.model, max_tokens=args.max_tokens)
    reflection_lm = dspy.LM(model=args.reflection_model, max_tokens=args.reflection_max_tokens)

    student = Expander()
    student.set_lm(lm)

    compiler = dspy.GEPA(
        metric=reward_metric,
        reflection_lm=reflection_lm,
        auto=auto,
        max_full_evals=args.max_full_evals,
        max_metric_calls=args.max_metric_calls,
        track_stats=True,
        track_best_outputs=True,
        failure_score=0.0,
        perfect_score=1.0,
    )
    optimized = compiler.compile(student=student, trainset=trainset, valset=valset)

    if args.save_prompt:
        # The prompt text is taken from the optimized signature's docstring.
        prompt_text = getattr(optimized.predict.signature, "__doc__", "") or ""
        Path(args.save_prompt).write_text(prompt_text.strip() + "\n", encoding="utf-8")
        print(f"Wrote {args.save_prompt}")

    if args.emit:
        # Re-run the optimized program on every training query.
        outputs = [_coerce_output_items(optimized(query=q)) for q in queries]
        write_jsonl(Path(args.emit), queries, outputs)
        print(f"Wrote {args.emit}")

    if hasattr(optimized, "detailed_results"):
        best = getattr(optimized.detailed_results, "best_outputs_valset", None)
        if best:
            print(f"Best outputs tracked: {len(best)}")

    return 0
- if __name__ == "__main__":
- raise SystemExit(main())
|