dspy_gepa.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. #!/usr/bin/env python3
  2. """Run DSPy GEPA using reward.py as the metric."""
  3. from __future__ import annotations
  4. import argparse
  5. import importlib
  6. import json
  7. import sys
  8. from pathlib import Path
  9. def _import_dspy():
  10. script_dir = Path(__file__).parent
  11. repo_root = script_dir.parent
  12. original_sys_path = list(sys.path)
  13. try:
  14. sys.path = [p for p in sys.path if p and str(p) != str(script_dir)]
  15. return importlib.import_module("dspy")
  16. finally:
  17. sys.path = original_sys_path
# Import dspy with the script's own directory hidden from sys.path (see
# _import_dspy). The module-level definitions below (e.g. ExpandSignature)
# use `dspy` at definition time, so this must run first.
dspy = _import_dspy()
# Make the repository root importable so the project-local `dataset` and
# `reward` modules below resolve. This assumes the script sits exactly one
# directory below the repo root (Path(__file__).parent.parent).
repo_root = Path(__file__).parent.parent
if str(repo_root) not in sys.path:
    # Prepended only when absent. If the root is already on sys.path
    # further down, it is left where it is.
    sys.path.insert(0, str(repo_root))
# These imports depend on the sys.path insertion above; keep them after it.
from dataset.schema import parse_output_text
from reward import score_expansion_detailed
# DSPy signature for the query-expansion task.
# The class docstring is runtime text, not just documentation: main() saves
# the optimised predictor's signature docstring as the "best prompt", i.e.
# it serves as the task instructions that GEPA optimises. Do not reword it
# casually.
class ExpandSignature(dspy.Signature):
    """Expand a search query into lex/vec/hyde lines."""
    # Input: the raw user query string (the only field examples provide).
    query = dspy.InputField(desc="User search query")
    # Output: multi-line expansion text. DSPy typically includes field `desc`
    # text in the prompt, so this string is also prompt content. The format it
    # describes is presumably what reward.score_expansion_detailed grades.
    # TODO: confirm the two stay in sync.
    expansion = dspy.OutputField(
        desc=(
            "Multi-line text with prefixes: 2-3 lex:, 2-3 vec:, optional 0-1 hyde:. "
            "Lex lines are short keywords and must not echo the query. "
            "Vec lines are natural language search phrases. "
            "Hyde is 50-200 chars, single line."
        )
    )
  35. class Expander(dspy.Module):
  36. def __init__(self):
  37. super().__init__()
  38. self.predict = dspy.Predict(ExpandSignature)
  39. def forward(self, query: str):
  40. return self.predict(query=query)
  41. def reward_metric(gold, pred, trace=None, pred_name=None, pred_trace=None):
  42. expansion = getattr(pred, "expansion", "") or ""
  43. detail = score_expansion_detailed(gold.query, expansion)
  44. score = detail["percentage"] / 100.0
  45. feedback = "; ".join(detail.get("deductions", [])) or f"score={detail['percentage']:.1f}"
  46. return dspy.Prediction(score=score, feedback=feedback)
  47. def load_queries(path: Path) -> list[str]:
  48. queries: list[str] = []
  49. with path.open("r", encoding="utf-8") as f:
  50. for line in f:
  51. line = line.strip()
  52. if not line:
  53. continue
  54. obj = json.loads(line)
  55. query = obj.get("query") or obj.get("input")
  56. if isinstance(query, str) and query.strip():
  57. queries.append(query.strip())
  58. return queries
  59. def to_examples(queries: list[str]) -> list[dspy.Example]:
  60. return [dspy.Example(query=q).with_inputs("query") for q in queries]
  61. def write_jsonl(path: Path, queries: list[str], outputs: list[str]) -> None:
  62. with path.open("w", encoding="utf-8") as f:
  63. for query, output_text in zip(queries, outputs, strict=True):
  64. output = parse_output_text(output_text)
  65. f.write(json.dumps({"query": query, "output": output}, ensure_ascii=False) + "\n")
  66. def main() -> int:
  67. parser = argparse.ArgumentParser(description="Run DSPy GEPA with reward.py")
  68. parser.add_argument("--input", type=str, required=True, help="Training JSONL path")
  69. parser.add_argument(
  70. "--model",
  71. type=str,
  72. default="grok-4-1-fast-reasoning",
  73. help="LM string in provider/model format (e.g., openai/gpt-4o)",
  74. )
  75. parser.add_argument(
  76. "--reflection-model",
  77. type=str,
  78. default="grok-4-1-fast-reasoning",
  79. help="LM string in provider/model format (e.g., openai/gpt-4o)",
  80. )
  81. parser.add_argument("--auto", type=str, default="light", choices=["light", "medium", "heavy"])
  82. parser.add_argument("--max-full-evals", type=int, default=None)
  83. parser.add_argument("--max-metric-calls", type=int, default=None)
  84. parser.add_argument("--valset", type=str, default=None, help="Optional valset JSONL path")
  85. parser.add_argument("--limit", type=int, default=None, help="Limit number of training queries")
  86. parser.add_argument("--val-limit", type=int, default=None, help="Limit number of val queries")
  87. parser.add_argument("--emit", type=str, default=None, help="Write generated JSONL after compile")
  88. parser.add_argument("--save-prompt", type=str, default=None, help="Write best prompt text to file")
  89. args = parser.parse_args()
  90. if "/" not in args.model or "/" not in args.reflection_model:
  91. print("Error: DSPy expects provider/model format for LM strings (e.g., xai/grok-4-1-fast-reasoning).")
  92. return 1
  93. if args.max_full_evals is not None and args.max_metric_calls is not None:
  94. print("Provide only one of --max-full-evals or --max-metric-calls")
  95. return 1
  96. if args.max_full_evals is not None or args.max_metric_calls is not None:
  97. args.auto = None
  98. train_path = Path(args.input)
  99. queries = load_queries(train_path)
  100. if args.limit is not None:
  101. queries = queries[: args.limit]
  102. trainset = to_examples(queries)
  103. valset = None
  104. if args.valset:
  105. val_queries = load_queries(Path(args.valset))
  106. if args.val_limit is not None:
  107. val_queries = val_queries[: args.val_limit]
  108. valset = to_examples(val_queries)
  109. lm = dspy.LM(model=args.model)
  110. reflection_lm = dspy.LM(model=args.reflection_model)
  111. student = Expander()
  112. student.set_lm(lm)
  113. compiler = dspy.GEPA(
  114. metric=reward_metric,
  115. reflection_lm=reflection_lm,
  116. auto=None if args.auto is None else args.auto,
  117. max_full_evals=args.max_full_evals,
  118. max_metric_calls=args.max_metric_calls,
  119. track_stats=True,
  120. track_best_outputs=True,
  121. failure_score=0.0,
  122. perfect_score=1.0,
  123. )
  124. optimized = compiler.compile(student=student, trainset=trainset, valset=valset)
  125. if args.save_prompt:
  126. prompt_text = getattr(optimized.predict.signature, "__doc__", "") or ""
  127. Path(args.save_prompt).write_text(prompt_text.strip() + "\n", encoding="utf-8")
  128. print(f"Wrote {args.save_prompt}")
  129. if args.emit:
  130. outputs = []
  131. for q in queries:
  132. pred = optimized(query=q)
  133. outputs.append(getattr(pred, "expansion", "") or "")
  134. write_jsonl(Path(args.emit), queries, outputs)
  135. print(f"Wrote {args.emit}")
  136. if hasattr(optimized, "detailed_results"):
  137. best = getattr(optimized.detailed_results, "best_outputs_valset", None)
  138. if best:
  139. print(f"Best outputs tracked: {len(best)}")
  140. return 0
  141. if __name__ == "__main__":
  142. raise SystemExit(main())