generate_data_offline.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = [
  5. # "datasets",
  6. # ]
  7. # ///
  8. """
  9. Generate QMD training data by transforming s-emanuilov/query-expansion dataset
  10. and adding synthetic hyde passages. No API calls needed.
  11. """
  12. import json
  13. import random
  14. from pathlib import Path
  15. # HyDE passage templates for different query types
  16. HYDE_TEMPLATES = {
  17. "how_to": [
  18. "To {action}, you need to {steps}. This can be done by {method}.",
  19. "The recommended way to {action} is to first {step1}, then {step2}.",
  20. "{Topic} can be achieved by {method}. Make sure to {consideration}.",
  21. ],
  22. "what_is": [
  23. "{Topic} is a {category} that {description}. It is commonly used for {use_case}.",
  24. "{Topic} refers to {definition}. Key features include {features}.",
  25. ],
  26. "config": [
  27. "To configure {topic}, set the {setting} option to {value}. You can also customize {other}.",
  28. "Configuration for {topic} is done in the {file} file. Key settings include {settings}.",
  29. ],
  30. "error": [
  31. "The {error} error occurs when {cause}. To fix this, {solution}.",
  32. "If you encounter {error}, check that {check}. Common solutions include {solutions}.",
  33. ],
  34. "general": [
  35. "{Topic} provides {benefit} for {use_case}. It works by {mechanism}.",
  36. "When working with {topic}, consider {considerations}. Best practices include {practices}.",
  37. ],
  38. }
  39. def classify_query(query: str) -> str:
  40. """Classify query type for hyde template selection."""
  41. q = query.lower()
  42. if any(w in q for w in ["how to", "how do", "setup", "install", "configure", "create"]):
  43. return "how_to"
  44. if any(w in q for w in ["what is", "what are", "definition", "meaning"]):
  45. return "what_is"
  46. if any(w in q for w in ["config", "setting", "option"]):
  47. return "config"
  48. if any(w in q for w in ["error", "issue", "problem", "fix", "debug"]):
  49. return "error"
  50. return "general"
  51. def extract_topic(query: str) -> str:
  52. """Extract main topic from query."""
  53. # Remove common prefixes
  54. for prefix in ["how to ", "how do i ", "what is ", "what are ", "configure ", "setup "]:
  55. if query.lower().startswith(prefix):
  56. return query[len(prefix):].strip()
  57. return query
  58. def generate_hyde(query: str, expansions: list[str]) -> str:
  59. """Generate a hypothetical document passage by combining expansions naturally."""
  60. topic = extract_topic(query)
  61. query_type = classify_query(query)
  62. # Use the longest, most descriptive expansion as the base
  63. sorted_exp = sorted(expansions, key=len, reverse=True)
  64. main_exp = sorted_exp[0] if sorted_exp else topic
  65. # Build a natural passage based on query type
  66. if query_type == "how_to":
  67. templates = [
  68. f"To {topic}, start by reviewing the requirements and dependencies. {main_exp.capitalize()} is the recommended approach. Make sure all prerequisites are met before proceeding.",
  69. f"The process of {topic} involves several steps. First, {main_exp}. Follow the official documentation for detailed instructions.",
  70. f"When you need to {topic}, the most effective method is to {main_exp}. This ensures compatibility and follows best practices.",
  71. ]
  72. elif query_type == "what_is":
  73. templates = [
  74. f"{topic.capitalize()} refers to {main_exp}. It is widely used in various applications and provides significant benefits.",
  75. f"The concept of {topic} encompasses {main_exp}. Understanding this is essential for effective implementation.",
  76. f"{topic.capitalize()} is defined as {main_exp}. This plays a crucial role in modern development practices.",
  77. ]
  78. elif query_type == "config":
  79. templates = [
  80. f"Configuration for {topic} requires setting the appropriate parameters. {main_exp.capitalize()} should be adjusted based on your specific requirements.",
  81. f"To configure {topic}, modify the settings in your configuration file. Key options include those related to {main_exp}.",
  82. f"The {topic} configuration can be customized by {main_exp}. Default values work for most use cases.",
  83. ]
  84. elif query_type == "error":
  85. templates = [
  86. f"The {topic} issue typically occurs when dependencies are misconfigured. To resolve this, {main_exp}. Check your environment settings.",
  87. f"If you encounter problems with {topic}, verify that {main_exp}. Common solutions include updating dependencies and checking permissions.",
  88. f"Debugging {topic} requires understanding the root cause. Often, {main_exp} resolves the issue. Review logs for details.",
  89. ]
  90. else:
  91. templates = [
  92. f"{topic.capitalize()} is an important concept that relates to {main_exp}. It provides functionality for various use cases in software development.",
  93. f"Understanding {topic} is essential for modern development. Key aspects include {main_exp}. This knowledge helps in building robust applications.",
  94. f"The topic of {topic} covers {main_exp}. Proper implementation follows established patterns and best practices.",
  95. ]
  96. return random.choice(templates)
  97. def transform_to_qmd_format(query: str, expansions: list[str]) -> str:
  98. """Transform s-emanuilov format to QMD lex/vec/hyde format."""
  99. lines = []
  100. # Generate hyde line first
  101. hyde = generate_hyde(query, expansions)
  102. lines.append(f"hyde: {hyde}")
  103. # Generate lex lines (keyword-focused, shorter)
  104. lex_candidates = []
  105. for exp in expansions:
  106. # Shorter versions for lex
  107. words = exp.split()
  108. if len(words) <= 4:
  109. lex_candidates.append(exp)
  110. else:
  111. # Take key phrases
  112. lex_candidates.append(" ".join(words[:3]))
  113. # Add 1-2 lex lines
  114. for lex in lex_candidates[:2]:
  115. if lex.lower() != query.lower():
  116. lines.append(f"lex: {lex}")
  117. # Generate vec lines (semantic, complete phrases)
  118. vec_candidates = [exp for exp in expansions if len(exp.split()) >= 3]
  119. if not vec_candidates:
  120. vec_candidates = expansions
  121. # Add 1-2 vec lines
  122. for vec in vec_candidates[:2]:
  123. if vec.lower() != query.lower():
  124. lines.append(f"vec: {vec}")
  125. return "\n".join(lines)
  126. def main():
  127. try:
  128. from datasets import load_dataset
  129. except ImportError:
  130. print("Installing datasets...")
  131. import subprocess
  132. subprocess.run(["uv", "pip", "install", "datasets"], check=True)
  133. from datasets import load_dataset
  134. print("Loading s-emanuilov/query-expansion dataset...")
  135. dataset = load_dataset("s-emanuilov/query-expansion", split="train")
  136. print(f"Loaded {len(dataset)} examples")
  137. # Transform each example
  138. output_path = Path("data/qmd_expansion.jsonl")
  139. output_path.parent.mkdir(parents=True, exist_ok=True)
  140. examples = []
  141. for item in dataset:
  142. query = item["query"]
  143. expansions = item["expansions"]
  144. output = transform_to_qmd_format(query, expansions)
  145. examples.append({"input": query, "output": output})
  146. # Shuffle
  147. random.seed(42)
  148. random.shuffle(examples)
  149. # Write output
  150. with open(output_path, "w") as f:
  151. for ex in examples:
  152. f.write(json.dumps(ex) + "\n")
  153. print(f"Generated {len(examples)} examples to {output_path}")
  154. # Show sample
  155. print("\nSample output:")
  156. print("-" * 50)
  157. sample = examples[0]
  158. print(f"Input: {sample['input']}")
  159. print(f"Output:\n{sample['output']}")
  160. if __name__ == "__main__":
  161. main()