score_data.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. #!/usr/bin/env python3
  2. """Score JSONL datasets with the reward function."""
  3. from __future__ import annotations
  4. import argparse
  5. import json
  6. import statistics
  7. import sys
  8. from pathlib import Path
  9. sys.path.insert(0, str(Path(__file__).parent.parent))
  10. from dataset.schema import (
  11. normalize_output_items,
  12. output_items_to_text,
  13. parse_output_text,
  14. )
  15. from reward import score_expansion_detailed
  16. def score_file(path: Path) -> tuple[int, int, list[float], dict]:
  17. total = 0
  18. errors = 0
  19. scores: list[float] = []
  20. ratings: dict[str, int] = {}
  21. with path.open("r", encoding="utf-8") as f:
  22. for line_num, line in enumerate(f, 1):
  23. line = line.strip()
  24. if not line:
  25. continue
  26. total += 1
  27. try:
  28. obj = json.loads(line)
  29. except json.JSONDecodeError:
  30. errors += 1
  31. continue
  32. query = obj.get("query") or obj.get("input")
  33. output = obj.get("output")
  34. if not isinstance(query, str) or not query.strip():
  35. errors += 1
  36. continue
  37. if output is None:
  38. errors += 1
  39. continue
  40. if isinstance(output, str):
  41. output_items = normalize_output_items(parse_output_text(output))
  42. else:
  43. output_items = normalize_output_items(output)
  44. output_text = output_items_to_text(output_items)
  45. if not output_text:
  46. errors += 1
  47. continue
  48. detail = score_expansion_detailed(query, output_text)
  49. score = detail["percentage"]
  50. scores.append(score)
  51. rating = detail["rating"]
  52. ratings[rating] = ratings.get(rating, 0) + 1
  53. return total, errors, scores, ratings
  54. def main() -> int:
  55. parser = argparse.ArgumentParser(description="Score QMD datasets")
  56. parser.add_argument(
  57. "paths",
  58. nargs="*",
  59. default=["finetune/data/*.jsonl"],
  60. help="JSONL files or glob patterns (default: finetune/data/*.jsonl)",
  61. )
  62. args = parser.parse_args()
  63. repo_root = Path(__file__).parent.parent.parent
  64. files: list[Path] = []
  65. for pattern in args.paths:
  66. if "*" in pattern:
  67. files.extend(repo_root.glob(pattern))
  68. else:
  69. files.append(repo_root / pattern)
  70. files = [p for p in files if p.exists()]
  71. if not files:
  72. print("No files found to score.")
  73. return 1
  74. for path in sorted(files):
  75. total, errors, scores, ratings = score_file(path)
  76. if scores:
  77. avg = statistics.mean(scores)
  78. median = statistics.median(scores)
  79. min_score = min(scores)
  80. max_score = max(scores)
  81. above_70 = sum(1 for s in scores if s >= 70.0)
  82. pct_70 = above_70 / len(scores) * 100
  83. print(
  84. f"{path}: {len(scores)} scored, {errors} errors, "
  85. f"avg {avg:.1f}, median {median:.1f}, min {min_score:.1f}, "
  86. f"max {max_score:.1f}, >=70 {pct_70:.1f}%"
  87. )
  88. else:
  89. print(f"{path}: 0 scored, {errors} errors")
  90. if ratings:
  91. rating_parts = [f"{k}:{v}" for k, v in sorted(ratings.items())]
  92. print(f" ratings: {', '.join(rating_parts)}")
  93. return 0
  94. if __name__ == "__main__":
  95. raise SystemExit(main())