fix_hyde.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. #!/usr/bin/env python3
  2. """
  3. Fix template hyde entries in qmd_expansion_v2.jsonl
  4. Replaces generic "comprehensive guide covers everything" hydes with query-specific ones.
  5. """
  6. import json
  7. import os
  8. import sys
  9. from pathlib import Path
  10. from openai import OpenAI
  11. # Configuration
  12. INPUT_FILE = Path("qmd_expansion_v2.jsonl")
  13. OUTPUT_FILE = Path("qmd_expansion_v3.jsonl")
  14. CHECKPOINT_FILE = Path("fix_hyde_checkpoint.json")
  15. BAD_PATTERN = "comprehensive guide covers everything"
  16. BATCH_SIZE = 25 # Process 25 queries per API call
  17. def load_checkpoint():
  18. """Load progress checkpoint if exists."""
  19. if CHECKPOINT_FILE.exists():
  20. with open(CHECKPOINT_FILE) as f:
  21. return json.load(f)
  22. return {"processed_queries": {}, "completed_indices": []}
  23. def save_checkpoint(checkpoint):
  24. """Save progress checkpoint."""
  25. with open(CHECKPOINT_FILE, 'w') as f:
  26. json.dump(checkpoint, f)
  27. def load_examples():
  28. """Load all examples from input file."""
  29. examples = []
  30. with open(INPUT_FILE) as f:
  31. for line in f:
  32. examples.append(json.loads(line.strip()))
  33. return examples
  34. def is_bad_hyde(example):
  35. """Check if example has the bad template hyde."""
  36. for item in example.get("output", []):
  37. if item[0] == "hyde" and BAD_PATTERN in item[1]:
  38. return True
  39. return False
  40. def get_hyde_from_example(example):
  41. """Extract the hyde value from an example."""
  42. for item in example.get("output", []):
  43. if item[0] == "hyde":
  44. return item[1]
  45. return None
  46. def set_hyde_in_example(example, new_hyde):
  47. """Set the hyde value in an example."""
  48. for i, item in enumerate(example.get("output", [])):
  49. if item[0] == "hyde":
  50. example["output"][i] = ["hyde", new_hyde]
  51. return
  52. # If no hyde found, append it
  53. example["output"].append(["hyde", new_hyde])
  54. def generate_hydes_batch(client, queries):
  55. """Generate hydes for a batch of queries using GPT-4o-mini."""
  56. queries_text = "\n".join(f"{i+1}. {q}" for i, q in enumerate(queries))
  57. prompt = f"""Generate hypothetical document snippets (hyde) for each query below.
  58. Requirements:
  59. - 100-180 characters each
  60. - Query-specific factual information
  61. - Written as if from an actual document that would answer the query
  62. - NO generic phrases like "comprehensive guide" or "everything you need to know"
  63. - Include actual facts, numbers, names, or specifics
  64. Example:
  65. Query: "kubernetes pod networking"
  66. Hyde: "Pods communicate via cluster IP. Use CNI plugins like Calico or Flannel. Service discovery through DNS. NetworkPolicy controls traffic between namespaces."
  67. Queries to process:
  68. {queries_text}
  69. Output ONLY valid JSON - a single object mapping query numbers to hyde texts:
  70. {{"1": "hyde text for query 1", "2": "hyde text for query 2", ...}}"""
  71. response = client.chat.completions.create(
  72. model="gpt-4o-mini",
  73. max_tokens=4096,
  74. temperature=0.7,
  75. messages=[{"role": "user", "content": prompt}]
  76. )
  77. # Parse the response
  78. text = response.choices[0].message.content.strip()
  79. # Handle potential markdown code blocks
  80. if text.startswith("```"):
  81. text = text.split("```")[1]
  82. if text.startswith("json"):
  83. text = text[4:]
  84. text = text.strip()
  85. try:
  86. result = json.loads(text)
  87. # Convert keys to int
  88. return {int(k): v for k, v in result.items()}
  89. except json.JSONDecodeError as e:
  90. print(f"JSON parse error: {e}")
  91. print(f"Response text: {text[:500]}...")
  92. return {}
  93. def main():
  94. print("Loading examples...")
  95. examples = load_examples()
  96. print(f"Loaded {len(examples)} examples")
  97. # Find bad examples
  98. bad_indices = []
  99. for i, ex in enumerate(examples):
  100. if is_bad_hyde(ex):
  101. bad_indices.append(i)
  102. print(f"Found {len(bad_indices)} examples with bad hyde")
  103. # Load checkpoint
  104. checkpoint = load_checkpoint()
  105. completed = set(checkpoint.get("completed_indices", []))
  106. processed_queries = checkpoint.get("processed_queries", {})
  107. # Filter to only unprocessed
  108. to_process = [i for i in bad_indices if i not in completed]
  109. print(f"Already processed: {len(completed)}, remaining: {len(to_process)}")
  110. if not to_process:
  111. print("All examples already processed!")
  112. else:
  113. # Initialize OpenAI client
  114. client = OpenAI()
  115. # Process in batches
  116. for batch_start in range(0, len(to_process), BATCH_SIZE):
  117. batch_indices = to_process[batch_start:batch_start + BATCH_SIZE]
  118. queries = [examples[i]["query"] for i in batch_indices]
  119. print(f"\nProcessing batch {batch_start//BATCH_SIZE + 1}/{(len(to_process) + BATCH_SIZE - 1)//BATCH_SIZE}")
  120. print(f"Queries: {queries[:3]}...")
  121. try:
  122. hydes = generate_hydes_batch(client, queries)
  123. # Apply the generated hydes
  124. for j, idx in enumerate(batch_indices):
  125. query_num = j + 1
  126. if query_num in hydes:
  127. new_hyde = hydes[query_num]
  128. processed_queries[str(idx)] = new_hyde
  129. completed.add(idx)
  130. print(f" [{idx}] {examples[idx]['query'][:40]}... -> {new_hyde[:50]}...")
  131. else:
  132. print(f" [{idx}] MISSING hyde for: {examples[idx]['query']}")
  133. # Save checkpoint after each batch
  134. checkpoint = {
  135. "processed_queries": processed_queries,
  136. "completed_indices": list(completed)
  137. }
  138. save_checkpoint(checkpoint)
  139. print(f" Checkpoint saved: {len(completed)}/{len(bad_indices)} complete")
  140. except Exception as e:
  141. print(f"Error processing batch: {e}")
  142. import traceback
  143. traceback.print_exc()
  144. # Save checkpoint before exiting
  145. checkpoint = {
  146. "processed_queries": processed_queries,
  147. "completed_indices": list(completed)
  148. }
  149. save_checkpoint(checkpoint)
  150. raise
  151. # Apply all fixes and write output
  152. print(f"\nApplying {len(processed_queries)} fixes...")
  153. for idx_str, new_hyde in processed_queries.items():
  154. idx = int(idx_str)
  155. set_hyde_in_example(examples[idx], new_hyde)
  156. # Write output file
  157. print(f"Writing {OUTPUT_FILE}...")
  158. with open(OUTPUT_FILE, 'w') as f:
  159. for ex in examples:
  160. f.write(json.dumps(ex) + "\n")
  161. # Verify
  162. with open(OUTPUT_FILE) as f:
  163. bad_count = sum(1 for line in f if BAD_PATTERN in line)
  164. print(f"\nDone! Bad hydes remaining: {bad_count}")
  165. print(f"Output written to: {OUTPUT_FILE}")
  166. if __name__ == "__main__":
  167. main()