| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201 |
- #!/usr/bin/env python3
- """
- Fix template hyde entries in qmd_expansion_v2.jsonl
- Replaces generic "comprehensive guide covers everything" hydes with query-specific ones.
- """
- import json
- import os
- import sys
- from pathlib import Path
- from openai import OpenAI
- # Configuration
- INPUT_FILE = Path("qmd_expansion_v2.jsonl")
- OUTPUT_FILE = Path("qmd_expansion_v3.jsonl")
- CHECKPOINT_FILE = Path("fix_hyde_checkpoint.json")
- BAD_PATTERN = "comprehensive guide covers everything"
- BATCH_SIZE = 25 # Process 25 queries per API call
- def load_checkpoint():
- """Load progress checkpoint if exists."""
- if CHECKPOINT_FILE.exists():
- with open(CHECKPOINT_FILE) as f:
- return json.load(f)
- return {"processed_queries": {}, "completed_indices": []}
- def save_checkpoint(checkpoint):
- """Save progress checkpoint."""
- with open(CHECKPOINT_FILE, 'w') as f:
- json.dump(checkpoint, f)
- def load_examples():
- """Load all examples from input file."""
- examples = []
- with open(INPUT_FILE) as f:
- for line in f:
- examples.append(json.loads(line.strip()))
- return examples
- def is_bad_hyde(example):
- """Check if example has the bad template hyde."""
- for item in example.get("output", []):
- if item[0] == "hyde" and BAD_PATTERN in item[1]:
- return True
- return False
- def get_hyde_from_example(example):
- """Extract the hyde value from an example."""
- for item in example.get("output", []):
- if item[0] == "hyde":
- return item[1]
- return None
- def set_hyde_in_example(example, new_hyde):
- """Set the hyde value in an example."""
- for i, item in enumerate(example.get("output", [])):
- if item[0] == "hyde":
- example["output"][i] = ["hyde", new_hyde]
- return
- # If no hyde found, append it
- example["output"].append(["hyde", new_hyde])
- def generate_hydes_batch(client, queries):
- """Generate hydes for a batch of queries using GPT-4o-mini."""
- queries_text = "\n".join(f"{i+1}. {q}" for i, q in enumerate(queries))
-
- prompt = f"""Generate hypothetical document snippets (hyde) for each query below.
- Requirements:
- - 100-180 characters each
- - Query-specific factual information
- - Written as if from an actual document that would answer the query
- - NO generic phrases like "comprehensive guide" or "everything you need to know"
- - Include actual facts, numbers, names, or specifics
- Example:
- Query: "kubernetes pod networking"
- Hyde: "Pods communicate via cluster IP. Use CNI plugins like Calico or Flannel. Service discovery through DNS. NetworkPolicy controls traffic between namespaces."
- Queries to process:
- {queries_text}
- Output ONLY valid JSON - a single object mapping query numbers to hyde texts:
- {{"1": "hyde text for query 1", "2": "hyde text for query 2", ...}}"""
- response = client.chat.completions.create(
- model="gpt-4o-mini",
- max_tokens=4096,
- temperature=0.7,
- messages=[{"role": "user", "content": prompt}]
- )
-
- # Parse the response
- text = response.choices[0].message.content.strip()
- # Handle potential markdown code blocks
- if text.startswith("```"):
- text = text.split("```")[1]
- if text.startswith("json"):
- text = text[4:]
- text = text.strip()
-
- try:
- result = json.loads(text)
- # Convert keys to int
- return {int(k): v for k, v in result.items()}
- except json.JSONDecodeError as e:
- print(f"JSON parse error: {e}")
- print(f"Response text: {text[:500]}...")
- return {}
- def main():
- print("Loading examples...")
- examples = load_examples()
- print(f"Loaded {len(examples)} examples")
-
- # Find bad examples
- bad_indices = []
- for i, ex in enumerate(examples):
- if is_bad_hyde(ex):
- bad_indices.append(i)
-
- print(f"Found {len(bad_indices)} examples with bad hyde")
-
- # Load checkpoint
- checkpoint = load_checkpoint()
- completed = set(checkpoint.get("completed_indices", []))
- processed_queries = checkpoint.get("processed_queries", {})
-
- # Filter to only unprocessed
- to_process = [i for i in bad_indices if i not in completed]
- print(f"Already processed: {len(completed)}, remaining: {len(to_process)}")
-
- if not to_process:
- print("All examples already processed!")
- else:
- # Initialize OpenAI client
- client = OpenAI()
-
- # Process in batches
- for batch_start in range(0, len(to_process), BATCH_SIZE):
- batch_indices = to_process[batch_start:batch_start + BATCH_SIZE]
- queries = [examples[i]["query"] for i in batch_indices]
-
- print(f"\nProcessing batch {batch_start//BATCH_SIZE + 1}/{(len(to_process) + BATCH_SIZE - 1)//BATCH_SIZE}")
- print(f"Queries: {queries[:3]}...")
-
- try:
- hydes = generate_hydes_batch(client, queries)
-
- # Apply the generated hydes
- for j, idx in enumerate(batch_indices):
- query_num = j + 1
- if query_num in hydes:
- new_hyde = hydes[query_num]
- processed_queries[str(idx)] = new_hyde
- completed.add(idx)
- print(f" [{idx}] {examples[idx]['query'][:40]}... -> {new_hyde[:50]}...")
- else:
- print(f" [{idx}] MISSING hyde for: {examples[idx]['query']}")
-
- # Save checkpoint after each batch
- checkpoint = {
- "processed_queries": processed_queries,
- "completed_indices": list(completed)
- }
- save_checkpoint(checkpoint)
- print(f" Checkpoint saved: {len(completed)}/{len(bad_indices)} complete")
-
- except Exception as e:
- print(f"Error processing batch: {e}")
- import traceback
- traceback.print_exc()
- # Save checkpoint before exiting
- checkpoint = {
- "processed_queries": processed_queries,
- "completed_indices": list(completed)
- }
- save_checkpoint(checkpoint)
- raise
-
- # Apply all fixes and write output
- print(f"\nApplying {len(processed_queries)} fixes...")
- for idx_str, new_hyde in processed_queries.items():
- idx = int(idx_str)
- set_hyde_in_example(examples[idx], new_hyde)
-
- # Write output file
- print(f"Writing {OUTPUT_FILE}...")
- with open(OUTPUT_FILE, 'w') as f:
- for ex in examples:
- f.write(json.dumps(ex) + "\n")
-
- # Verify
- with open(OUTPUT_FILE) as f:
- bad_count = sum(1 for line in f if BAD_PATTERN in line)
-
- print(f"\nDone! Bad hydes remaining: {bad_count}")
- print(f"Output written to: {OUTPUT_FILE}")
- if __name__ == "__main__":
- main()
|