dspy_gepa.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. #!/usr/bin/env python3
  2. """Run DSPy GEPA using reward.py as the metric."""
  3. from __future__ import annotations
  4. import argparse
  5. import importlib
  6. import json
  7. import sys
  8. from pathlib import Path
  def _import_dspy():
      """Import and return the ``dspy`` package with a filtered ``sys.path``.

      While the import runs, ``sys.path`` is replaced by a copy from which
      falsy entries and the directory containing this script have been removed
      (presumably so that a module sitting next to this script cannot be picked
      up in place of the installed package). The original ``sys.path`` is put
      back afterwards, whether or not the import succeeds.

      Returns:
          The module object returned by ``importlib.import_module("dspy")``.

      Raises:
          ImportError: If ``dspy`` cannot be imported from the remaining
              ``sys.path`` entries.
      """
      script_dir = str(Path(__file__).parent)
      saved_sys_path = list(sys.path)
      try:
          # Falsy entries (e.g. "" for the current directory) are dropped as well.
          sys.path = [p for p in sys.path if p and str(p) != script_dir]
          return importlib.import_module("dspy")
      finally:
          # Always undo the temporary filtering, even when the import fails.
          sys.path = saved_sys_path
  # Load dspy through the guarded importer so the script directory is not
  # searched while the package is being imported.
  dspy = _import_dspy()

  # Put the directory one level above this script's directory at the front of
  # sys.path (once); the project imports that follow are presumably resolved
  # from there.
  repo_root = Path(__file__).parent.parent
  if str(repo_root) not in sys.path:
      sys.path[:0] = [str(repo_root)]
  22. from dataset.schema import normalize_output_items, output_items_to_text, parse_output_text
  23. from reward import score_expansion_detailed
  class ExpandSignature(dspy.Signature):
      """Expand a search query into lex/vec/hyde lines."""

      # NOTE(review): the class docstring and the ``desc`` strings below look
      # like prompt text handed to the LM (main() saves the signature's
      # ``__doc__`` as the "prompt"), so they are kept byte-for-byte.

      # Single input: the query to expand.
      query = dspy.InputField(desc="User search query")

      # Single output: the expansion, described to the model as a JSON array of
      # [kind, text] pairs.
      output = dspy.OutputField(
          desc="JSON array of [kind, text] pairs. kind is lex|vec|hyde. "
          "Return 2-3 lex, 2-3 vec, optional 0-1 hyde. "
          "Lex items are short keywords and must not echo the query. "
          "Vec items are natural language search phrases. "
          "Hyde is 50-200 chars, single line."
      )
  class Expander(dspy.Module):
      """DSPy module wrapping a single ``dspy.Predict`` over ``ExpandSignature``."""

      def __init__(self) -> None:
          super().__init__()
          # Kept under the attribute name ``predict``: main() reads
          # ``optimized.predict.signature`` when saving the prompt.
          self.predict = dspy.Predict(ExpandSignature)

      def forward(self, query: str):
          """Run the predictor on *query* and return its prediction unchanged."""
          prediction = self.predict(query=query)
          return prediction
  def reward_metric(gold, pred, trace=None, pred_name=None, pred_trace=None):
      """Score one prediction with ``score_expansion_detailed``.

      Args:
          gold: Example whose ``query`` attribute is the query being expanded.
          pred: Prediction object; its items are extracted with
              ``_coerce_output_items`` and rendered with ``output_items_to_text``.
          trace: Accepted but not used.
          pred_name: Accepted but not used.
          pred_trace: Accepted but not used.

      Returns:
          A ``dspy.Prediction`` with ``score`` set to the reported percentage
          divided by 100 and ``feedback`` set to the reported deductions joined
          with "; ", or to a ``score=<percentage>`` string when there are none.
      """
      items = _coerce_output_items(pred)
      report = score_expansion_detailed(gold.query, output_items_to_text(items))

      percentage = report["percentage"]
      notes = "; ".join(report.get("deductions", []))
      if not notes:
          # Nothing was deducted: fall back to reporting the bare score.
          notes = f"score={percentage:.1f}"

      return dspy.Prediction(score=percentage / 100.0, feedback=notes)
  def load_queries(path: Path) -> list[str]:
      """Read query strings from a JSONL file.

      Each non-blank line must be valid JSON. For a JSON object, the first of
      the keys ``"query"`` and ``"input"`` that holds a non-blank string
      supplies the query (stripped of surrounding whitespace). Lines that are
      not JSON objects, or that have no usable string under either key, are
      skipped.

      Args:
          path: Location of the UTF-8 encoded JSONL file.

      Returns:
          The extracted queries, in file order.

      Raises:
          OSError: If the file cannot be opened.
          json.JSONDecodeError: If a non-blank line is not valid JSON; the
              message is prefixed with ``<path>:<line number>``.
      """
      queries: list[str] = []
      with path.open("r", encoding="utf-8") as f:
          for line_number, raw_line in enumerate(f, start=1):
              line = raw_line.strip()
              if not line:
                  continue
              try:
                  obj = json.loads(line)
              except json.JSONDecodeError as err:
                  # Same exception type as before, but say which line is broken.
                  raise json.JSONDecodeError(
                      f"{path}:{line_number}: {err.msg}", err.doc, err.pos
                  ) from err
              if not isinstance(obj, dict):
                  # A bare list/string/number has no keys to look up; skip it
                  # like any other line without a usable query.
                  continue
              for key in ("query", "input"):
                  value = obj.get(key)
                  if isinstance(value, str) and value.strip():
                      queries.append(value.strip())
                      break
      return queries
  def to_examples(queries: list[str]) -> list[dspy.Example]:
      """Wrap each query string in a ``dspy.Example`` whose input field is ``query``."""
      examples = []
      for text in queries:
          example = dspy.Example(query=text)
          examples.append(example.with_inputs("query"))
      return examples
  def _coerce_output_items(pred) -> list[list[str]]:
      """Extract the expansion items from a prediction-like object.

      The ``output`` attribute of *pred* is tried first; if it is falsy, the
      ``expansion`` attribute is used instead. Handling by shape:

      * list/tuple ``output``: passed straight to ``normalize_output_items``.
      * text starting with ``[`` or ``{``: parsed as JSON; a dict with an
        ``"output"`` key is unwrapped, and a resulting list/tuple goes through
        ``normalize_output_items``.
      * anything else, or any failure on the JSON path: handed to
        ``parse_output_text``.
      * no text at all: an empty list.
      """
      candidate = getattr(pred, "output", None)
      if isinstance(candidate, (list, tuple)):
          return normalize_output_items(candidate)

      text = str(candidate or getattr(pred, "expansion", "") or "").strip()
      if not text:
          return []

      if text.startswith(("[", "{")):
          try:
              decoded = json.loads(text)
              if isinstance(decoded, dict) and "output" in decoded:
                  decoded = decoded["output"]
              if isinstance(decoded, (list, tuple)):
                  return normalize_output_items(decoded)
          except Exception:
              # Deliberately best-effort: anything that goes wrong on the JSON
              # path falls through to the plain-text parser below.
              pass

      return parse_output_text(text)
  79. def write_jsonl(path: Path, queries: list[str], outputs: list[list[list[str]]]) -> None:
  80. with path.open("w", encoding="utf-8") as f:
  81. for query, output in zip(queries, outputs, strict=True):
  82. f.write(json.dumps({"query": query, "output": output}, ensure_ascii=False) + "\n")
  def _build_arg_parser() -> argparse.ArgumentParser:
      """Build the command-line parser for the GEPA runner."""
      parser = argparse.ArgumentParser(description="Run DSPy GEPA with reward.py")
      parser.add_argument("--input", type=str, required=True, help="Training JSONL path")
      # The defaults carry a provider prefix so that a run without --model /
      # --reflection-model passes the provider/model check in main().
      parser.add_argument(
          "--model",
          type=str,
          default="xai/grok-4-1-fast-reasoning",
          help="LM string in provider/model format (e.g., openai/gpt-4o)",
      )
      parser.add_argument(
          "--reflection-model",
          type=str,
          default="xai/grok-4-1-fast-reasoning",
          help="LM string in provider/model format (e.g., openai/gpt-4o)",
      )
      parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens for student LM")
      parser.add_argument("--reflection-max-tokens", type=int, default=512, help="Max tokens for reflection LM")
      parser.add_argument("--auto", type=str, default="light", choices=["light", "medium", "heavy"])
      parser.add_argument("--max-full-evals", type=int, default=None)
      parser.add_argument("--max-metric-calls", type=int, default=None)
      parser.add_argument("--valset", type=str, default=None, help="Optional valset JSONL path")
      parser.add_argument("--limit", type=int, default=None, help="Limit number of training queries")
      parser.add_argument("--val-limit", type=int, default=None, help="Limit number of val queries")
      parser.add_argument("--emit", type=str, default=None, help="Write generated JSONL after compile")
      parser.add_argument("--save-prompt", type=str, default=None, help="Write best prompt text to file")
      return parser


  def main() -> int:
      """Parse the command line, run GEPA over the training queries, write outputs.

      Steps:
          1. Validate the arguments (provider/model strings, at most one
             explicit budget option).
          2. Load the training queries and, if given, the validation queries.
          3. Compile an ``Expander`` with ``dspy.GEPA`` using ``reward_metric``.
          4. Optionally save the optimized signature docstring (--save-prompt)
             and the optimized program's outputs for the training queries
             (--emit).

      Returns:
          0 on success, 1 when the arguments are rejected or no training
          queries were loaded. Error messages go to stderr.
      """
      args = _build_arg_parser().parse_args()

      if "/" not in args.model or "/" not in args.reflection_model:
          print(
              "Error: DSPy expects provider/model format for LM strings (e.g., xai/grok-4-1-fast-reasoning).",
              file=sys.stderr,
          )
          return 1
      if args.max_full_evals is not None and args.max_metric_calls is not None:
          print("Provide only one of --max-full-evals or --max-metric-calls", file=sys.stderr)
          return 1

      # An explicit budget replaces the --auto preset.
      explicit_budget = args.max_full_evals is not None or args.max_metric_calls is not None
      auto = None if explicit_budget else args.auto

      train_path = Path(args.input)
      queries = load_queries(train_path)
      if args.limit is not None:
          queries = queries[: args.limit]
      if not queries:
          print(f"Error: no training queries loaded from {train_path}", file=sys.stderr)
          return 1
      trainset = to_examples(queries)

      valset = None
      if args.valset:
          val_queries = load_queries(Path(args.valset))
          if args.val_limit is not None:
              val_queries = val_queries[: args.val_limit]
          valset = to_examples(val_queries)

      lm = dspy.LM(model=args.model, max_tokens=args.max_tokens)
      reflection_lm = dspy.LM(model=args.reflection_model, max_tokens=args.reflection_max_tokens)

      student = Expander()
      student.set_lm(lm)

      compiler = dspy.GEPA(
          metric=reward_metric,
          reflection_lm=reflection_lm,
          auto=auto,
          max_full_evals=args.max_full_evals,
          max_metric_calls=args.max_metric_calls,
          track_stats=True,
          track_best_outputs=True,
          failure_score=0.0,
          perfect_score=1.0,
      )
      optimized = compiler.compile(student=student, trainset=trainset, valset=valset)

      if args.save_prompt:
          # The "prompt" that gets saved is the docstring of the optimized
          # predictor's signature.
          prompt_text = getattr(optimized.predict.signature, "__doc__", "") or ""
          Path(args.save_prompt).write_text(prompt_text.strip() + "\n", encoding="utf-8")
          print(f"Wrote {args.save_prompt}")

      if args.emit:
          # Re-run the optimized program on the (possibly limited) training queries.
          outputs = [_coerce_output_items(optimized(query=q)) for q in queries]
          write_jsonl(Path(args.emit), queries, outputs)
          print(f"Wrote {args.emit}")

      if hasattr(optimized, "detailed_results"):
          best = getattr(optimized.detailed_results, "best_outputs_valset", None)
          if best:
              print(f"Best outputs tracked: {len(best)}")
      return 0
  # Script entry point: main()'s return value becomes the process exit status.
  if __name__ == "__main__":
      sys.exit(main())