score_data.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = ["pydantic>=2.0"]
  5. # ///
  6. """Score JSONL datasets with the reward function."""
  7. from __future__ import annotations
  8. import argparse
  9. import statistics
  10. import sys
  11. from pathlib import Path
  12. sys.path.insert(0, str(Path(__file__).parent.parent))
  13. from dataset.schema import load_examples, output_items_to_text
  14. from reward import score_expansion_detailed
  15. def score_file(path: Path) -> tuple[int, int, list[float], dict]:
  16. total = 0
  17. errors = 0
  18. scores: list[float] = []
  19. ratings: dict[str, int] = {}
  20. try:
  21. examples = load_examples(path)
  22. except ValueError as e:
  23. print(f" Error loading {path}: {e}")
  24. return 0, 1, [], {}
  25. for ex in examples:
  26. total += 1
  27. output_text = output_items_to_text(ex.output)
  28. if not output_text:
  29. errors += 1
  30. continue
  31. detail = score_expansion_detailed(ex.query, output_text)
  32. score = detail["percentage"]
  33. scores.append(score)
  34. rating = detail["rating"]
  35. ratings[rating] = ratings.get(rating, 0) + 1
  36. return total, errors, scores, ratings
  37. def main() -> int:
  38. parser = argparse.ArgumentParser(description="Score QMD datasets")
  39. parser.add_argument(
  40. "paths",
  41. nargs="*",
  42. default=["finetune/data/*.jsonl"],
  43. help="JSONL files or glob patterns (default: finetune/data/*.jsonl)",
  44. )
  45. args = parser.parse_args()
  46. repo_root = Path(__file__).parent.parent.parent
  47. files: list[Path] = []
  48. for pattern in args.paths:
  49. if "*" in pattern:
  50. files.extend(repo_root.glob(pattern))
  51. else:
  52. files.append(repo_root / pattern)
  53. files = [p for p in files if p.exists()]
  54. if not files:
  55. print("No files found to score.")
  56. return 1
  57. for path in sorted(files):
  58. total, errors, scores, ratings = score_file(path)
  59. if scores:
  60. avg = statistics.mean(scores)
  61. median = statistics.median(scores)
  62. min_score = min(scores)
  63. max_score = max(scores)
  64. above_70 = sum(1 for s in scores if s >= 70.0)
  65. pct_70 = above_70 / len(scores) * 100
  66. print(
  67. f"{path}: {len(scores)} scored, {errors} errors, "
  68. f"avg {avg:.1f}, median {median:.1f}, min {min_score:.1f}, "
  69. f"max {max_score:.1f}, >=70 {pct_70:.1f}%"
  70. )
  71. else:
  72. print(f"{path}: 0 scored, {errors} errors")
  73. if ratings:
  74. rating_parts = [f"{k}:{v}" for k, v in sorted(ratings.items())]
  75. print(f" ratings: {', '.join(rating_parts)}")
  76. return 0
  77. if __name__ == "__main__":
  78. raise SystemExit(main())