fix_lex_filler.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. #!/usr/bin/env python3
  2. """
  3. Fix lex entries in QMD training data by removing filler words that were
  4. inserted as padding rather than being genuine search intent.
  5. Filler words to remove (case insensitive):
  6. - overview
  7. - tutorial
  8. - guide
  9. - examples
  10. - documentation
  11. - best practices
  12. Keep these words when they're genuinely part of the query intent.
  13. """
  14. import json
  15. import re
  16. from pathlib import Path
  17. INPUT_FILE = Path(__file__).parent / "qmd_expansion_v2.jsonl"
  18. OUTPUT_FILE = Path(__file__).parent / "qmd_expansion_v3_lex_fixed.jsonl"
  19. # Filler words/phrases to remove when they're just padding
  20. FILLER_WORDS = [
  21. 'overview',
  22. 'tutorial',
  23. 'guide',
  24. 'examples',
  25. 'documentation',
  26. 'best practices',
  27. ]
  28. def count_word(text: str, word: str) -> int:
  29. """Count occurrences of a word in text (case insensitive, whole word)."""
  30. if word == 'best practices':
  31. return len(re.findall(r'\bbest practices\b', text, re.IGNORECASE))
  32. return len(re.findall(r'\b' + re.escape(word) + r'\b', text, re.IGNORECASE))
  33. def clean_lex_entry(lex: str, query: str) -> str:
  34. """
  35. Remove filler words from a lex entry.
  36. Logic:
  37. - If filler word appears in lex but NOT in query: remove all occurrences
  38. - If filler word appears in both: remove excess occurrences (keep same count as in query)
  39. """
  40. query_lower = query.lower()
  41. result = lex
  42. for filler in FILLER_WORDS:
  43. query_count = count_word(query_lower, filler)
  44. lex_count = count_word(result, filler)
  45. # Remove excess occurrences
  46. if lex_count > query_count:
  47. # Remove (lex_count - query_count) occurrences
  48. for _ in range(lex_count - query_count):
  49. if filler == 'best practices':
  50. result = re.sub(r'\bbest practices\b', '', result, count=1, flags=re.IGNORECASE)
  51. else:
  52. result = re.sub(r'\b' + re.escape(filler) + r'\b', '', result, count=1, flags=re.IGNORECASE)
  53. # Clean up extra whitespace
  54. result = ' '.join(result.split())
  55. return result.strip()
  56. def has_filler_to_clean(lex: str, query: str) -> bool:
  57. """Check if lex entry has filler words that need cleaning."""
  58. query_lower = query.lower()
  59. for filler in FILLER_WORDS:
  60. query_count = count_word(query_lower, filler)
  61. lex_count = count_word(lex, filler)
  62. if lex_count > query_count:
  63. return True
  64. return False
  65. def process_entry(entry: dict) -> tuple[dict, bool]:
  66. """
  67. Process a single entry, cleaning lex entries if needed.
  68. Returns (processed_entry, was_modified)
  69. """
  70. query = entry.get("query", "")
  71. output = entry.get("output", [])
  72. modified = False
  73. new_output = []
  74. for item in output:
  75. if item[0] == "lex":
  76. original_lex = item[1]
  77. if has_filler_to_clean(original_lex, query):
  78. cleaned_lex = clean_lex_entry(original_lex, query)
  79. if cleaned_lex != original_lex:
  80. new_output.append(["lex", cleaned_lex])
  81. modified = True
  82. else:
  83. new_output.append(item)
  84. else:
  85. new_output.append(item)
  86. else:
  87. new_output.append(item)
  88. new_entry = entry.copy()
  89. new_entry["output"] = new_output
  90. return new_entry, modified
  91. def main():
  92. entries = []
  93. modified_count = 0
  94. total_lex_modified = 0
  95. print(f"Reading {INPUT_FILE}...")
  96. with open(INPUT_FILE, 'r') as f:
  97. for line_num, line in enumerate(f, 1):
  98. if not line.strip():
  99. continue
  100. entry = json.loads(line)
  101. processed, modified = process_entry(entry)
  102. entries.append(processed)
  103. if modified:
  104. modified_count += 1
  105. # Count how many lex entries were modified
  106. orig_output = entry.get("output", [])
  107. new_output = processed.get("output", [])
  108. for i, item in enumerate(orig_output):
  109. if item[0] == "lex" and item[1] != new_output[i][1]:
  110. total_lex_modified += 1
  111. print(f"Total entries: {len(entries)}")
  112. print(f"Entries modified: {modified_count}")
  113. print(f"Total lex entries cleaned: {total_lex_modified}")
  114. print(f"\nWriting to {OUTPUT_FILE}...")
  115. with open(OUTPUT_FILE, 'w') as f:
  116. for entry in entries:
  117. f.write(json.dumps(entry) + '\n')
  118. print("Done!")
  119. # Show some examples of modifications
  120. print("\n--- Sample modifications ---")
  121. sample_count = 0
  122. with open(INPUT_FILE, 'r') as f:
  123. for line in f:
  124. if not line.strip():
  125. continue
  126. entry = json.loads(line)
  127. processed, modified = process_entry(entry)
  128. if modified and sample_count < 15:
  129. query = entry.get("query", "")
  130. print(f"\nQuery: {query}")
  131. orig_lex = [item[1] for item in entry.get("output", []) if item[0] == "lex"]
  132. new_lex = [item[1] for item in processed.get("output", []) if item[0] == "lex"]
  133. for orig, new in zip(orig_lex, new_lex):
  134. if orig != new:
  135. print(f" - \"{orig}\"")
  136. print(f" + \"{new}\"")
  137. sample_count += 1
  138. if __name__ == "__main__":
  139. main()