generate_data_offline.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = [
  5. # "datasets",
  6. # ]
  7. # ///
  8. """
  9. Generate QMD training data by transforming the s-emanuilov/query-expansion dataset
  10. and adding synthetic HyDE passages. No API calls needed.
  11. """
  12. import json
  13. import random
  14. from pathlib import Path
  15. from dataset.schema import normalize_output_items, parse_output_text
# HyDE passage templates for different query types.
#
# The keys mirror the category names returned by classify_query(); each value
# is a list of sentence templates written with `{name}`-style placeholders.
# NOTE(review): no function visible in this file reads HYDE_TEMPLATES --
# generate_hyde() builds its own f-string templates -- confirm whether this
# constant is still needed or is consumed from another module.
HYDE_TEMPLATES = {
    "how_to": [
        "To {action}, you need to {steps}. This can be done by {method}.",
        "The recommended way to {action} is to first {step1}, then {step2}.",
        "{Topic} can be achieved by {method}. Make sure to {consideration}.",
    ],
    "what_is": [
        "{Topic} is a {category} that {description}. It is commonly used for {use_case}.",
        "{Topic} refers to {definition}. Key features include {features}.",
    ],
    "config": [
        "To configure {topic}, set the {setting} option to {value}. You can also customize {other}.",
        "Configuration for {topic} is done in the {file} file. Key settings include {settings}.",
    ],
    "error": [
        "The {error} error occurs when {cause}. To fix this, {solution}.",
        "If you encounter {error}, check that {check}. Common solutions include {solutions}.",
    ],
    "general": [
        "{Topic} provides {benefit} for {use_case}. It works by {mechanism}.",
        "When working with {topic}, consider {considerations}. Best practices include {practices}.",
    ],
}
  40. def classify_query(query: str) -> str:
  41. """Classify query type for hyde template selection."""
  42. q = query.lower()
  43. if any(
  44. w in q for w in ["how to", "how do", "setup", "install", "configure", "create"]
  45. ):
  46. return "how_to"
  47. if any(w in q for w in ["what is", "what are", "definition", "meaning"]):
  48. return "what_is"
  49. if any(w in q for w in ["config", "setting", "option"]):
  50. return "config"
  51. if any(w in q for w in ["error", "issue", "problem", "fix", "debug"]):
  52. return "error"
  53. return "general"
  54. def extract_topic(query: str) -> str:
  55. """Extract main topic from query."""
  56. # Remove common prefixes
  57. for prefix in [
  58. "how to ",
  59. "how do i ",
  60. "what is ",
  61. "what are ",
  62. "configure ",
  63. "setup ",
  64. ]:
  65. if query.lower().startswith(prefix):
  66. return query[len(prefix) :].strip()
  67. return query
  68. def generate_hyde(query: str, expansions: list[str]) -> str:
  69. """Generate a hypothetical document passage by combining expansions naturally."""
  70. topic = extract_topic(query)
  71. query_type = classify_query(query)
  72. # Use the longest, most descriptive expansion as the base
  73. sorted_exp = sorted(expansions, key=len, reverse=True)
  74. main_exp = sorted_exp[0] if sorted_exp else topic
  75. # Build a natural passage based on query type
  76. if query_type == "how_to":
  77. templates = [
  78. f"To {topic}, start by reviewing the requirements and dependencies. {main_exp.capitalize()} is the recommended approach. Make sure all prerequisites are met before proceeding.",
  79. f"The process of {topic} involves several steps. First, {main_exp}. Follow the official documentation for detailed instructions.",
  80. f"When you need to {topic}, the most effective method is to {main_exp}. This ensures compatibility and follows best practices.",
  81. ]
  82. elif query_type == "what_is":
  83. templates = [
  84. f"{topic.capitalize()} refers to {main_exp}. It is widely used in various applications and provides significant benefits.",
  85. f"The concept of {topic} encompasses {main_exp}. Understanding this is essential for effective implementation.",
  86. f"{topic.capitalize()} is defined as {main_exp}. This plays a crucial role in modern development practices.",
  87. ]
  88. elif query_type == "config":
  89. templates = [
  90. f"Configuration for {topic} requires setting the appropriate parameters. {main_exp.capitalize()} should be adjusted based on your specific requirements.",
  91. f"To configure {topic}, modify the settings in your configuration file. Key options include those related to {main_exp}.",
  92. f"The {topic} configuration can be customized by {main_exp}. Default values work for most use cases.",
  93. ]
  94. elif query_type == "error":
  95. templates = [
  96. f"The {topic} issue typically occurs when dependencies are misconfigured. To resolve this, {main_exp}. Check your environment settings.",
  97. f"If you encounter problems with {topic}, verify that {main_exp}. Common solutions include updating dependencies and checking permissions.",
  98. f"Debugging {topic} requires understanding the root cause. Often, {main_exp} resolves the issue. Review logs for details.",
  99. ]
  100. else:
  101. templates = [
  102. f"{topic.capitalize()} is an important concept that relates to {main_exp}. It provides functionality for various use cases in software development.",
  103. f"Understanding {topic} is essential for modern development. Key aspects include {main_exp}. This knowledge helps in building robust applications.",
  104. f"The topic of {topic} covers {main_exp}. Proper implementation follows established patterns and best practices.",
  105. ]
  106. return random.choice(templates)
  107. def transform_to_qmd_format(query: str, expansions: list[str]) -> str:
  108. """Transform s-emanuilov format to QMD lex/vec/hyde format."""
  109. lines = []
  110. # Generate hyde line first
  111. hyde = generate_hyde(query, expansions)
  112. lines.append(f"hyde: {hyde}")
  113. # Generate lex lines (keyword-focused, shorter)
  114. lex_candidates = []
  115. for exp in expansions:
  116. # Shorter versions for lex
  117. words = exp.split()
  118. if len(words) <= 4:
  119. lex_candidates.append(exp)
  120. else:
  121. # Take key phrases
  122. lex_candidates.append(" ".join(words[:3]))
  123. # Add 1-2 lex lines
  124. for lex in lex_candidates[:2]:
  125. if lex.lower() != query.lower():
  126. lines.append(f"lex: {lex}")
  127. # Generate vec lines (semantic, complete phrases)
  128. vec_candidates = [exp for exp in expansions if len(exp.split()) >= 3]
  129. if not vec_candidates:
  130. vec_candidates = expansions
  131. # Add 1-2 vec lines
  132. for vec in vec_candidates[:2]:
  133. if vec.lower() != query.lower():
  134. lines.append(f"vec: {vec}")
  135. return "\n".join(lines)
  136. def main():
  137. try:
  138. from datasets import load_dataset
  139. except ImportError:
  140. print("Installing datasets...")
  141. import subprocess
  142. subprocess.run(["uv", "pip", "install", "datasets"], check=True)
  143. from datasets import load_dataset
  144. print("Loading s-emanuilov/query-expansion dataset...")
  145. dataset = load_dataset("s-emanuilov/query-expansion", split="train")
  146. print(f"Loaded {len(dataset)} examples")
  147. # Transform each example
  148. output_path = Path("data/qmd_expansion.jsonl")
  149. output_path.parent.mkdir(parents=True, exist_ok=True)
  150. examples = []
  151. for item in dataset:
  152. query = item["query"]
  153. expansions = item["expansions"]
  154. output = transform_to_qmd_format(query, expansions)
  155. output_items = normalize_output_items(parse_output_text(output))
  156. examples.append({"query": query, "output": output_items})
  157. # Shuffle
  158. random.seed(42)
  159. random.shuffle(examples)
  160. # Write output
  161. with open(output_path, "w") as f:
  162. for ex in examples:
  163. f.write(json.dumps(ex) + "\n")
  164. print(f"Generated {len(examples)} examples to {output_path}")
  165. # Show sample
  166. print("\nSample output:")
  167. print("-" * 50)
  168. sample = examples[0]
  169. print(f"Input: {sample['query']}")
  170. print(f"Output: {sample['output']}")
  171. if __name__ == "__main__":
  172. main()