score.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. #!/usr/bin/env python3
  2. """Score GEPA JSONL outputs using reward.py."""
  3. from __future__ import annotations
  4. import argparse
  5. import json
  6. import statistics
  7. from pathlib import Path
  8. from example import example_from_json
  9. from reward import score_expansion_detailed
  10. from dataset.schema import output_items_to_text
  11. def score_file(path: Path) -> tuple[int, int, list[float], dict]:
  12. total = 0
  13. errors = 0
  14. scores: list[float] = []
  15. ratings: dict[str, int] = {}
  16. with path.open("r", encoding="utf-8") as f:
  17. for line_num, line in enumerate(f, 1):
  18. line = line.strip()
  19. if not line:
  20. continue
  21. total += 1
  22. try:
  23. obj = json.loads(line)
  24. example = example_from_json(obj)
  25. except Exception:
  26. errors += 1
  27. continue
  28. output_text = output_items_to_text([item.to_pair() for item in example.output])
  29. if not output_text:
  30. errors += 1
  31. continue
  32. detail = score_expansion_detailed(example.query, output_text)
  33. score = detail["percentage"]
  34. scores.append(score)
  35. rating = detail["rating"]
  36. ratings[rating] = ratings.get(rating, 0) + 1
  37. return total, errors, scores, ratings
  38. def main() -> int:
  39. parser = argparse.ArgumentParser(description="Score GEPA JSONL outputs")
  40. parser.add_argument("--input", type=str, required=True, help="Input JSONL file")
  41. args = parser.parse_args()
  42. path = Path(args.input)
  43. if not path.exists():
  44. print(f"Input not found: {path}")
  45. return 1
  46. total, errors, scores, ratings = score_file(path)
  47. if scores:
  48. avg = statistics.mean(scores)
  49. median = statistics.median(scores)
  50. min_score = min(scores)
  51. max_score = max(scores)
  52. above_70 = sum(1 for s in scores if s >= 70.0)
  53. pct_70 = above_70 / len(scores) * 100
  54. print(
  55. f"{path}: {len(scores)} scored, {errors} errors, "
  56. f"avg {avg:.1f}, median {median:.1f}, min {min_score:.1f}, "
  57. f"max {max_score:.1f}, >=70 {pct_70:.1f}%"
  58. )
  59. else:
  60. print(f"{path}: 0 scored, {errors} errors")
  61. if ratings:
  62. rating_parts = [f\"{k}:{v}\" for k, v in sorted(ratings.items())]
  63. print(f\" ratings: {', '.join(rating_parts)}\")
  64. return 0
  65. if __name__ == "__main__":
  66. raise SystemExit(main())