schema.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. #!/usr/bin/env python3
  2. """
  3. Strict schema for QMD training data.
  4. Every JSONL file in data/ MUST conform to this format:
  5. {"query": "auth config", "output": [["hyde", "..."], ["lex", "..."], ["vec", "..."]]}
  6. - query: non-empty string
  7. - output: list of [type, text] pairs where type is "lex", "vec", or "hyde"
  8. - Optional metadata (category, intent: str; is_short: bool) is type-checked when present; any other extra fields are ignored
  9. There is exactly ONE format. No alternatives, no legacy fallbacks.
  10. """
  11. from __future__ import annotations
  12. import json
  13. from enum import Enum
  14. from pathlib import Path
  15. from typing import Annotated, Iterable
  16. from pydantic import (
  17. BaseModel,
  18. BeforeValidator,
  19. ConfigDict,
  20. field_validator,
  21. )
  22. # ---------------------------------------------------------------------------
  23. # Types
  24. # ---------------------------------------------------------------------------
  25. class OutputType(str, Enum):
  26. lex = "lex"
  27. vec = "vec"
  28. hyde = "hyde"
  29. VALID_OUTPUT_TYPES = {t.value for t in OutputType}
  30. class OutputPair(BaseModel):
  31. """A single expansion line: [type, text]."""
  32. type: OutputType
  33. text: str
  34. model_config = ConfigDict(frozen=True)
  35. @field_validator("text")
  36. @classmethod
  37. def text_not_empty(cls, v: str) -> str:
  38. if not v or not v.strip():
  39. raise ValueError("text must not be empty")
  40. return v
  41. def to_list(self) -> list[str]:
  42. return [self.type.value, self.text]
  43. def _coerce_output_pairs(v: list) -> list[OutputPair]:
  44. """Accept [["lex", "..."], ...] from JSON and coerce to OutputPair list."""
  45. pairs = []
  46. for i, item in enumerate(v):
  47. if isinstance(item, OutputPair):
  48. pairs.append(item)
  49. elif isinstance(item, (list, tuple)) and len(item) == 2:
  50. pairs.append(OutputPair(type=item[0], text=item[1]))
  51. else:
  52. raise ValueError(
  53. f"output[{i}] must be [type, text], got {item!r}"
  54. )
  55. return pairs
  56. # ---------------------------------------------------------------------------
  57. # Pydantic model — single source of truth for the JSONL schema
  58. # ---------------------------------------------------------------------------
  59. class TrainingExample(BaseModel):
  60. """One training example in the canonical JSONL format."""
  61. query: str
  62. output: Annotated[list[OutputPair], BeforeValidator(_coerce_output_pairs)]
  63. # Optional metadata — present in some files, ignored during training.
  64. category: str | None = None
  65. intent: str | None = None
  66. is_short: bool | None = None
  67. model_config = ConfigDict(extra="ignore")
  68. @field_validator("query")
  69. @classmethod
  70. def query_not_empty(cls, v: str) -> str:
  71. if not v or not v.strip():
  72. raise ValueError("query must not be empty")
  73. return v
  74. @field_validator("output")
  75. @classmethod
  76. def output_not_empty(cls, v: list[OutputPair]) -> list[OutputPair]:
  77. if not v:
  78. raise ValueError("output must not be empty")
  79. return v
  80. def output_as_lists(self) -> list[list[str]]:
  81. """Return output as list-of-lists for JSON serialization."""
  82. return [p.to_list() for p in self.output]
  83. # ---------------------------------------------------------------------------
  84. # Loading
  85. # ---------------------------------------------------------------------------
  86. def load_examples(path: str | Path) -> list[TrainingExample]:
  87. """Load and validate a JSONL file. Fails loudly on any bad line."""
  88. path = Path(path)
  89. examples: list[TrainingExample] = []
  90. with path.open("r", encoding="utf-8") as f:
  91. for line_num, line in enumerate(f, 1):
  92. line = line.strip()
  93. if not line:
  94. continue
  95. try:
  96. obj = json.loads(line)
  97. except json.JSONDecodeError as e:
  98. raise ValueError(f"{path}:{line_num}: invalid JSON: {e}") from e
  99. try:
  100. examples.append(TrainingExample.model_validate(obj))
  101. except Exception as e:
  102. raise ValueError(f"{path}:{line_num}: {e}") from e
  103. return examples
  104. # ---------------------------------------------------------------------------
  105. # Helpers (used by prepare_data.py, reward.py, and other tools)
  106. # ---------------------------------------------------------------------------
  107. def parse_output_text(text: str) -> list[list[str]]:
  108. """Parse prefixed output text into list pairs.
  109. >>> parse_output_text("lex: foo\\nvec: bar")
  110. [["lex", "foo"], ["vec", "bar"]]
  111. """
  112. items: list[list[str]] = []
  113. for raw_line in text.strip().split("\n"):
  114. line = raw_line.strip()
  115. if not line:
  116. continue
  117. if line.startswith("lex:"):
  118. items.append(["lex", line[4:].strip()])
  119. elif line.startswith("vec:"):
  120. items.append(["vec", line[4:].strip()])
  121. elif line.startswith("hyde:"):
  122. items.append(["hyde", line[5:].strip()])
  123. return items
  124. def reorder_hyde_first(items: list[list[str]]) -> list[list[str]]:
  125. """Reorder items to put hyde first, then lex, then vec."""
  126. hyde_items = [item for item in items if item and item[0] == "hyde"]
  127. lex_items = [item for item in items if item and item[0] == "lex"]
  128. vec_items = [item for item in items if item and item[0] == "vec"]
  129. return hyde_items + lex_items + vec_items
  130. def output_items_to_text(
  131. items: Iterable, hyde_first: bool = True
  132. ) -> str:
  133. """Render output pairs to prefixed text lines.
  134. Accepts list[OutputPair] or list[list[str]].
  135. """
  136. normalized = []
  137. for item in items:
  138. if isinstance(item, OutputPair):
  139. normalized.append([item.type.value, item.text.strip()])
  140. continue
  141. if not item:
  142. continue
  143. try:
  144. kind, text = item[0], item[1]
  145. except Exception:
  146. continue
  147. if kind not in VALID_OUTPUT_TYPES:
  148. continue
  149. if text is None:
  150. continue
  151. text = str(text).strip()
  152. if not text:
  153. continue
  154. normalized.append([kind, text])
  155. if hyde_first:
  156. normalized = reorder_hyde_first(normalized)
  157. lines = [f"{kind}: {text}" for kind, text in normalized]
  158. return "\n".join(lines)
  159. def normalize_output_items(
  160. items: Iterable, hyde_first: bool = True
  161. ) -> list[list[str]]:
  162. """Normalize output pairs (filter invalid, trim whitespace, reorder).
  163. Accepts list[OutputPair] or list[list[str]].
  164. """
  165. normalized: list[list[str]] = []
  166. for item in items:
  167. if isinstance(item, OutputPair):
  168. normalized.append([item.type.value, item.text.strip()])
  169. continue
  170. if not item:
  171. continue
  172. try:
  173. kind, text = item[0], item[1]
  174. except Exception:
  175. continue
  176. if kind not in VALID_OUTPUT_TYPES:
  177. continue
  178. if text is None:
  179. continue
  180. text = str(text).strip()
  181. if not text:
  182. continue
  183. normalized.append([kind, text])
  184. if hyde_first:
  185. normalized = reorder_hyde_first(normalized)
  186. return normalized
  187. def has_type(items: Iterable, kind: str) -> bool:
  188. for item in items:
  189. if isinstance(item, OutputPair):
  190. if item.type.value == kind:
  191. return True
  192. elif item and item[0] == kind:
  193. return True
  194. return False