generate.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. #!/usr/bin/env python3
  2. """Generate expansions using a saved GEPA prompt."""
  3. from __future__ import annotations
  4. import argparse
  5. import importlib
  6. import json
  7. import sys
  8. from pathlib import Path
  def _import_dspy():
      """Import and return the ``dspy`` module with this script's directory hidden.

      While the import runs, ``sys.path`` is stripped of empty entries (which
      Python treats as the current directory) and of every entry that points at
      the directory containing this script -- presumably so that a sibling file
      or package named ``dspy`` cannot shadow the installed library. The
      original entries are put back afterwards, even if the import fails.

      Returns:
          The imported ``dspy`` module object.

      Raises:
          ImportError: If ``dspy`` cannot be found on the filtered path.
      """
      script_file = Path(__file__)
      # Both spellings matter: the interpreter resolves symlinks when it puts
      # the script directory on sys.path, while __file__ may still be the
      # unresolved (or relative) path the script was started with.
      hidden_dirs = {script_file.parent.resolve(), script_file.resolve().parent}

      def _points_at_script_dir(entry) -> bool:
          try:
              return Path(str(entry)).resolve() in hidden_dirs
          except (OSError, RuntimeError):
              # Unresolvable entry (e.g. symlink loop): cannot be our directory.
              return False

      original_entries = list(sys.path)
      try:
          # Slice assignment keeps the identity of the sys.path list, so code
          # holding a reference to it never sees a stale list.
          sys.path[:] = [
              entry
              for entry in original_entries
              if entry and not _points_at_script_dir(entry)
          ]
          return importlib.import_module("dspy")
      finally:
          sys.path[:] = original_entries
  # Load DSPy through the guarded importer defined above in this file.
  dspy = _import_dspy()

  # Directory two levels above this file. The path is resolved first because
  # a relative __file__ (e.g. "generate.py") would otherwise collapse to "."
  # when walking up with .parent, which is not the intended directory.
  repo_root = Path(__file__).resolve().parent.parent
  if str(repo_root) not in sys.path:
      # Prepended so the import below can find the `dataset` package
      # (presumably located directly under repo_root -- confirm layout).
      sys.path.insert(0, str(repo_root))

  from dataset.schema import parse_output_text  # noqa: E402  (needs the sys.path tweak above)
  # Keys probed, in priority order, when a line is a JSON object.
  _TOPIC_KEYS = ("topic", "query", "input")


  def _topic_from_json_line(line: str) -> str:
      """Return the topic stored in a JSON-object line, or "" if there is none."""
      if not (line.startswith("{") and line.endswith("}")):
          return ""
      try:
          obj = json.loads(line)
      except json.JSONDecodeError:
          # Looks like JSON but is not: the caller keeps the raw line.
          return ""
      if not isinstance(obj, dict):
          return ""
      for key in _TOPIC_KEYS:
          value = obj.get(key)
          # Skip non-string and blank values so a later key can still supply
          # the topic instead of the raw JSON text being used.
          if isinstance(value, str) and value.strip():
              return value.strip()
      return ""


  def load_topics(path: Path) -> list[str]:
      """Read topics from a text file, one topic per non-blank line.

      Each line is either plain text or a JSON object (JSONL). For a JSON
      object, the first of the keys ``topic``, ``query``, ``input`` that holds
      a non-blank string supplies the topic. Lines that merely look like JSON,
      or JSON objects without a usable key, are kept verbatim as the topic.

      Args:
          path: File to read; decoded as UTF-8, with an optional leading BOM
              ignored.

      Returns:
          The topics in file order, each stripped of surrounding whitespace.

      Raises:
          OSError: If the file cannot be opened.
      """
      topics: list[str] = []
      # utf-8-sig reads plain UTF-8 too, but drops a leading BOM that would
      # otherwise defeat the "{" check on the first line and leak into the topic.
      with path.open("r", encoding="utf-8-sig") as f:
          for raw in f:
              line = raw.strip()
              if not line:
                  continue
              topics.append(_topic_from_json_line(line) or line)
      return topics
  def write_jsonl_line(handle, query: str, output_text: str) -> None:
      """Write one ``{"query": ..., "output": ...}`` JSON record plus a newline.

      Args:
          handle: Writable text object; only its ``write`` method is used.
          query: Query string stored under the ``"query"`` key.
          output_text: Raw text passed through ``parse_output_text``; the
              result is stored under the ``"output"`` key.
      """
      record = {"query": query, "output": parse_output_text(output_text)}
      # ensure_ascii=False keeps non-ASCII characters readable in the file.
      serialized = json.dumps(record, ensure_ascii=False)
      handle.write(serialized + "\n")
  def _is_list_marker(token: str) -> bool:
      """Return True for a bare bullet (``*``, ``•``) or a counter such as ``1.`` or ``2)``."""
      if token in ("*", "\u2022"):
          return True
      return len(token) > 1 and token[-1] in ".)" and token[:-1].isdecimal()


  def parse_queries(text: str) -> list[str]:
      """Split text into one query per non-empty line.

      Each line is stripped of surrounding whitespace and leading dashes. A
      single leading list marker (``*``, ``•``, ``1.``, ``2)``) followed by
      whitespace is removed as well, so numbered or bulleted input does not
      leave the marker inside the query. Tokens such as ``3.5`` or ``2024`` are
      not markers and are left alone.

      Args:
          text: Multi-line text, one candidate query per line.

      Returns:
          The cleaned, non-empty queries in their original order.
      """
      queries: list[str] = []
      for raw in text.splitlines():
          line = raw.strip().lstrip("-").strip()
          parts = line.split(None, 1)
          if parts and _is_list_marker(parts[0]):
              # Drop the marker; a line holding only a marker becomes empty.
              line = parts[1] if len(parts) > 1 else ""
          if line:
              queries.append(line)
      return queries
  def main() -> int:
      """Generate queries per topic, expand each one, and write JSONL output.

      Command-line arguments:
          --prompt     Path to the saved prompt text.
          --topics     Topics file (one per line or JSONL).
          --output     Output JSONL path (opened in write mode).
          --model      LM string in provider/model format.
          --per-topic  Queries to generate per topic (default 3, must be >= 1).

      For every topic, the query predictor is asked for ``--per-topic``
      queries. Each query is run through the expansion predictor, and the
      resulting record is written to the output file and echoed to stdout.

      Returns:
          0 on completion.

      Raises:
          SystemExit: Raised by argparse for missing or invalid arguments,
              including a ``--per-topic`` value below 1.
      """
      parser = argparse.ArgumentParser(description="Generate with saved GEPA prompt")
      parser.add_argument("--prompt", type=str, required=True, help="Path to saved prompt text")
      parser.add_argument("--topics", type=str, required=True, help="Topics file (one per line or JSONL)")
      parser.add_argument("--output", type=str, required=True, help="Output JSONL path")
      parser.add_argument("--model", type=str, required=True, help="LM string in provider/model format")
      parser.add_argument("--per-topic", type=int, default=3, help="Queries to generate per topic")
      args = parser.parse_args()
      # Reject nonsensical counts up front: 0 would silently produce an empty
      # output file, and a negative value would make the slice below drop
      # queries from the end instead of limiting them.
      if args.per_topic < 1:
          parser.error("--per-topic must be a positive integer")

      prompt_text = Path(args.prompt).read_text(encoding="utf-8").strip()
      # The saved prompt is handed to dspy.Signature as its second argument
      # (the signature's instructions, per the DSPy API -- confirm on upgrade).
      expansion_sig = dspy.Signature("query -> expansion", prompt_text)
      query_sig = dspy.Signature(
          "topic, count -> queries",
          (
              "Generate distinct user search queries for the given topic. "
              "Return exactly `count` queries, one per line, no numbering or extra text."
          ),
      )

      class Generator(dspy.Module):
          """Runs the expansion signature on a single query."""

          def __init__(self):
              super().__init__()
              self.predict = dspy.Predict(expansion_sig)

          def forward(self, query: str):
              return self.predict(query=query)

      class QueryGenerator(dspy.Module):
          """Asks for `count` search queries about a topic."""

          def __init__(self):
              super().__init__()
              self.predict = dspy.Predict(query_sig)

          def forward(self, topic: str, count: int):
              # The count is passed to the predictor as text.
              return self.predict(topic=topic, count=str(count))

      # One LM instance is shared by both modules.
      lm = dspy.LM(model=args.model)
      gen = Generator()
      gen.set_lm(lm)
      qgen = QueryGenerator()
      qgen.set_lm(lm)

      topics = load_topics(Path(args.topics))
      with Path(args.output).open("w", encoding="utf-8") as f_out:
          for topic in topics:
              qpred = qgen(topic=topic, count=args.per_topic)
              # A missing or None `queries` field is treated as empty text.
              qtext = getattr(qpred, "queries", "") or ""
              generated = parse_queries(qtext)
              if not generated:
                  # No usable query came back: expand the topic itself.
                  generated = [topic]
              # Keep at most the requested number of queries per topic.
              for query in generated[: args.per_topic]:
                  pred = gen(query=query)
                  output_text = getattr(pred, "expansion", "") or ""
                  write_jsonl_line(f_out, query, output_text)
                  # Echo the same record to stdout for progress monitoring.
                  print(
                      json.dumps(
                          {"query": query, "output": parse_output_text(output_text)},
                          ensure_ascii=False,
                      )
                  )
      print(f"Wrote {args.output}")
      return 0
  102. if __name__ == "__main__":
  103. raise SystemExit(main())