validate_schema.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = ["pydantic>=2.0"]
  5. # ///
  6. """Validate JSONL files against the strict QMD training schema."""
  7. from __future__ import annotations
  8. import argparse
  9. import json
  10. import sys
  11. from pathlib import Path
  12. sys.path.insert(0, str(Path(__file__).parent.parent))
  13. from dataset.schema import TrainingExample
  14. def validate_file(path: Path) -> tuple[int, int]:
  15. """Return (total_lines, error_count)."""
  16. total = 0
  17. errors = 0
  18. with path.open("r", encoding="utf-8") as f:
  19. for line_num, line in enumerate(f, 1):
  20. line = line.strip()
  21. if not line:
  22. continue
  23. total += 1
  24. try:
  25. obj = json.loads(line)
  26. except json.JSONDecodeError as e:
  27. print(f"{path}:{line_num}: invalid JSON ({e})")
  28. errors += 1
  29. continue
  30. try:
  31. TrainingExample.model_validate(obj)
  32. except Exception as e:
  33. print(f"{path}:{line_num}: {e}")
  34. errors += 1
  35. return total, errors
  36. def main() -> int:
  37. parser = argparse.ArgumentParser(description="Validate QMD JSONL schema")
  38. parser.add_argument(
  39. "paths",
  40. nargs="*",
  41. default=["finetune/data/*.jsonl"],
  42. help="JSONL files or glob patterns (default: finetune/data/*.jsonl)",
  43. )
  44. args = parser.parse_args()
  45. repo_root = Path(__file__).parent.parent.parent
  46. files: list[Path] = []
  47. for pattern in args.paths:
  48. if "*" in pattern:
  49. files.extend(repo_root.glob(pattern))
  50. else:
  51. files.append(repo_root / pattern)
  52. files = [p for p in files if p.exists()]
  53. if not files:
  54. print("No files found to validate.")
  55. return 1
  56. total_lines = 0
  57. total_errors = 0
  58. for path in sorted(files):
  59. lines, errors = validate_file(path)
  60. total_lines += lines
  61. total_errors += errors
  62. status = "OK" if errors == 0 else f"{errors} error(s)"
  63. print(f"{path}: {lines} lines, {status}")
  64. if total_errors:
  65. print(
  66. f"\nValidation failed: {total_errors} error(s) across {total_lines} lines"
  67. )
  68. return 1
  69. print(f"\nValidation passed: {total_lines} lines checked")
  70. return 0
  71. if __name__ == "__main__":
  72. raise SystemExit(main())