example.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. #!/usr/bin/env python3
  2. """GEPA example schema for QMD training JSONL lines."""
  3. from __future__ import annotations
  4. import json
  5. from dataclasses import dataclass, field
  6. from enum import Enum
  7. from pathlib import Path
  8. from typing import Iterable
  9. class SearchType(str, Enum):
  10. LexSearch = "LexSearch"
  11. VecSearch = "VecSearch"
  12. HydeSearch = "HydeSearch"
  13. SEARCH_TYPE_TO_PREFIX = {
  14. SearchType.LexSearch: "lex",
  15. SearchType.VecSearch: "vec",
  16. SearchType.HydeSearch: "hyde",
  17. }
  18. @dataclass
  19. class OutputItem:
  20. """Single expansion line with validation hints."""
  21. kind: SearchType
  22. text: str
  23. # Validation hints (not strict rules).
  24. min_chars: int = 3
  25. max_chars: int | None = None
  26. def __post_init__(self) -> None:
  27. self.text = str(self.text).strip()
  28. if not self.text:
  29. raise ValueError("OutputItem.text must be non-empty")
  30. if "\n" in self.text:
  31. raise ValueError("OutputItem.text must be single-line")
  32. if len(self.text) < self.min_chars:
  33. raise ValueError("OutputItem.text is too short")
  34. if self.max_chars is not None and len(self.text) > self.max_chars:
  35. raise ValueError("OutputItem.text is too long")
  36. def to_pair(self) -> list[str]:
  37. return [SEARCH_TYPE_TO_PREFIX[self.kind], self.text]
  38. @dataclass
  39. class Example:
  40. """JSONL line schema for QMD training data."""
  41. query: str
  42. output: list[OutputItem] = field(default_factory=list)
  43. def __post_init__(self) -> None:
  44. self.query = str(self.query).strip()
  45. if not self.query:
  46. raise ValueError("Example.query must be non-empty")
  47. if not self.output:
  48. raise ValueError("Example.output must not be empty")
  49. def to_json(self) -> dict:
  50. return {
  51. "query": self.query,
  52. "output": [item.to_pair() for item in self.output],
  53. }
  54. def to_jsonl(self) -> str:
  55. return json.dumps(self.to_json(), ensure_ascii=False)
  56. def parse_output_items(raw_output: Iterable[Iterable[str]]) -> list[OutputItem]:
  57. items: list[OutputItem] = []
  58. for item in raw_output:
  59. if not item or len(item) < 2:
  60. continue
  61. kind_raw, text = item[0], item[1]
  62. kind_map = {
  63. "lex": SearchType.LexSearch,
  64. "vec": SearchType.VecSearch,
  65. "hyde": SearchType.HydeSearch,
  66. }
  67. kind = kind_map.get(str(kind_raw).strip().lower())
  68. if kind is None:
  69. continue
  70. max_chars = 200 if kind is SearchType.HydeSearch else None
  71. items.append(OutputItem(kind=kind, text=str(text), max_chars=max_chars))
  72. return items
  73. def example_from_json(obj: dict) -> Example:
  74. query = obj.get("query") or obj.get("input") or ""
  75. output = obj.get("output") or []
  76. if isinstance(output, str):
  77. raise ValueError("String outputs are not supported in GEPA example schema")
  78. items = parse_output_items(output)
  79. return Example(query=query, output=items)
  80. def load_jsonl(path: str | Path) -> list[Example]:
  81. examples: list[Example] = []
  82. with Path(path).open("r", encoding="utf-8") as f:
  83. for line_num, line in enumerate(f, 1):
  84. line = line.strip()
  85. if not line:
  86. continue
  87. try:
  88. obj = json.loads(line)
  89. examples.append(example_from_json(obj))
  90. except Exception as exc:
  91. raise ValueError(f"Invalid line {line_num}: {exc}") from exc
  92. return examples