|
|
@@ -27,7 +27,7 @@ repo_root = Path(__file__).parent.parent
|
|
|
if str(repo_root) not in sys.path:
|
|
|
sys.path.insert(0, str(repo_root))
|
|
|
|
|
|
-from dataset.schema import parse_output_text
|
|
|
+from dataset.schema import normalize_output_items, output_items_to_text, parse_output_text
|
|
|
from reward import score_expansion_detailed
|
|
|
|
|
|
|
|
|
@@ -35,11 +35,12 @@ class ExpandSignature(dspy.Signature):
|
|
|
"""Expand a search query into lex/vec/hyde lines."""
|
|
|
|
|
|
query = dspy.InputField(desc="User search query")
|
|
|
- expansion = dspy.OutputField(
|
|
|
+ output = dspy.OutputField(
|
|
|
desc=(
|
|
|
- "Multi-line text with prefixes: 2-3 lex:, 2-3 vec:, optional 0-1 hyde:. "
|
|
|
- "Lex lines are short keywords and must not echo the query. "
|
|
|
- "Vec lines are natural language search phrases. "
|
|
|
+ "JSON array of [kind, text] pairs. kind is lex|vec|hyde. "
|
|
|
+ "Return 2-3 lex, 2-3 vec, optional 0-1 hyde. "
|
|
|
+ "Lex items are short keywords and must not echo the query. "
|
|
|
+ "Vec items are natural language search phrases. "
|
|
|
"Hyde is 50-200 chars, single line."
|
|
|
)
|
|
|
)
|
|
|
@@ -55,7 +56,7 @@ class Expander(dspy.Module):
|
|
|
|
|
|
|
|
|
def reward_metric(gold, pred, trace=None, pred_name=None, pred_trace=None):
|
|
|
- expansion = getattr(pred, "expansion", "") or ""
|
|
|
+ expansion = output_items_to_text(_coerce_output_items(pred))
|
|
|
detail = score_expansion_detailed(gold.query, expansion)
|
|
|
score = detail["percentage"] / 100.0
|
|
|
feedback = "; ".join(detail.get("deductions", [])) or f"score={detail['percentage']:.1f}"
|
|
|
@@ -80,10 +81,31 @@ def to_examples(queries: list[str]) -> list[dspy.Example]:
|
|
|
return [dspy.Example(query=q).with_inputs("query") for q in queries]
|
|
|
|
|
|
|
|
|
-def write_jsonl(path: Path, queries: list[str], outputs: list[str]) -> None:
|
|
|
+def _coerce_output_items(pred) -> list[list[str]]:
|
|
|
+ raw_output = getattr(pred, "output", None)
|
|
|
+ if isinstance(raw_output, (list, tuple)):
|
|
|
+ return normalize_output_items(raw_output)
|
|
|
+
|
|
|
+ raw_text = str(raw_output or getattr(pred, "expansion", "") or "").strip()
|
|
|
+ if not raw_text:
|
|
|
+ return []
|
|
|
+
|
|
|
+ if raw_text[0] in ("[", "{"):
|
|
|
+ try:
|
|
|
+ obj = json.loads(raw_text)
|
|
|
+ if isinstance(obj, dict) and "output" in obj:
|
|
|
+ obj = obj["output"]
|
|
|
+ if isinstance(obj, (list, tuple)):
|
|
|
+ return normalize_output_items(obj)
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+
|
|
|
+ return parse_output_text(raw_text)
|
|
|
+
|
|
|
+
|
|
|
+def write_jsonl(path: Path, queries: list[str], outputs: list[list[list[str]]]) -> None:
|
|
|
with path.open("w", encoding="utf-8") as f:
|
|
|
- for query, output_text in zip(queries, outputs, strict=True):
|
|
|
- output = parse_output_text(output_text)
|
|
|
+ for query, output in zip(queries, outputs, strict=True):
|
|
|
f.write(json.dumps({"query": query, "output": output}, ensure_ascii=False) + "\n")
|
|
|
|
|
|
|
|
|
@@ -102,6 +124,8 @@ def main() -> int:
|
|
|
default="grok-4-1-fast-reasoning",
|
|
|
help="LM string in provider/model format (e.g., openai/gpt-4o)",
|
|
|
)
|
|
|
+ parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens for student LM")
|
|
|
+ parser.add_argument("--reflection-max-tokens", type=int, default=512, help="Max tokens for reflection LM")
|
|
|
parser.add_argument("--auto", type=str, default="light", choices=["light", "medium", "heavy"])
|
|
|
parser.add_argument("--max-full-evals", type=int, default=None)
|
|
|
parser.add_argument("--max-metric-calls", type=int, default=None)
|
|
|
@@ -134,8 +158,8 @@ def main() -> int:
|
|
|
val_queries = val_queries[: args.val_limit]
|
|
|
valset = to_examples(val_queries)
|
|
|
|
|
|
- lm = dspy.LM(model=args.model)
|
|
|
- reflection_lm = dspy.LM(model=args.reflection_model)
|
|
|
+ lm = dspy.LM(model=args.model, max_tokens=args.max_tokens)
|
|
|
+ reflection_lm = dspy.LM(model=args.reflection_model, max_tokens=args.reflection_max_tokens)
|
|
|
|
|
|
student = Expander()
|
|
|
student.set_lm(lm)
|
|
|
@@ -163,7 +187,8 @@ def main() -> int:
|
|
|
outputs = []
|
|
|
for q in queries:
|
|
|
pred = optimized(query=q)
|
|
|
- outputs.append(getattr(pred, "expansion", "") or "")
|
|
|
+ items = _coerce_output_items(pred)
|
|
|
+ outputs.append(items)
|
|
|
write_jsonl(Path(args.emit), queries, outputs)
|
|
|
print(f"Wrote {args.emit}")
|
|
|
|