|
@@ -157,50 +157,65 @@ def format_for_training(input_text: str, output_text: str) -> dict:
|
|
|
|
|
|
|
|
def main():
|
|
def main():
|
|
|
parser = argparse.ArgumentParser(description="Prepare data for training")
|
|
parser = argparse.ArgumentParser(description="Prepare data for training")
|
|
|
- parser.add_argument("--input", type=str, default="data/qmd_expansion.jsonl", help="Input JSONL file")
|
|
|
|
|
|
|
+ parser.add_argument("--input", type=str, default="data/*.jsonl", help="Input JSONL file(s) - supports glob patterns")
|
|
|
parser.add_argument("--output", type=str, default="data/train", help="Output directory")
|
|
parser.add_argument("--output", type=str, default="data/train", help="Output directory")
|
|
|
parser.add_argument("--split", type=float, default=0.1, help="Validation split ratio")
|
|
parser.add_argument("--split", type=float, default=0.1, help="Validation split ratio")
|
|
|
parser.add_argument("--add-short", type=int, default=3, help="Variations per short query to add")
|
|
parser.add_argument("--add-short", type=int, default=3, help="Variations per short query to add")
|
|
|
args = parser.parse_args()
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
- input_path = Path(args.input)
|
|
|
|
|
output_dir = Path(args.output)
|
|
output_dir = Path(args.output)
|
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
- if not input_path.exists():
|
|
|
|
|
- print(f"Error: Input file not found: {input_path}")
|
|
|
|
|
- exit(1)
|
|
|
|
|
-
|
|
|
|
|
- # Load and clean existing examples
|
|
|
|
|
|
|
+ # Support glob patterns for input
|
|
|
|
|
+ import glob
|
|
|
|
|
+ if "*" in args.input:
|
|
|
|
|
+ input_files = sorted(glob.glob(args.input))
|
|
|
|
|
+ if not input_files:
|
|
|
|
|
+ print(f"Error: No files found matching: {args.input}")
|
|
|
|
|
+ exit(1)
|
|
|
|
|
+ print(f"Found {len(input_files)} input files: {[Path(f).name for f in input_files]}")
|
|
|
|
|
+ else:
|
|
|
|
|
+ input_path = Path(args.input)
|
|
|
|
|
+ if not input_path.exists():
|
|
|
|
|
+ print(f"Error: Input file not found: {input_path}")
|
|
|
|
|
+ exit(1)
|
|
|
|
|
+ input_files = [str(input_path)]
|
|
|
|
|
+
|
|
|
|
|
+ # Load and clean existing examples from all input files
|
|
|
examples = []
|
|
examples = []
|
|
|
seen_queries = set()
|
|
seen_queries = set()
|
|
|
long_hyde_count = 0
|
|
long_hyde_count = 0
|
|
|
|
|
|
|
|
- with open(input_path) as f:
|
|
|
|
|
- for line in f:
|
|
|
|
|
- if line.strip():
|
|
|
|
|
- ex = json.loads(line)
|
|
|
|
|
-
|
|
|
|
|
- # Clean output (truncate hyde, remove invalid lines)
|
|
|
|
|
- original_output = ex["output"]
|
|
|
|
|
- ex["output"] = clean_output(ex["output"])
|
|
|
|
|
-
|
|
|
|
|
- # Track hyde truncation
|
|
|
|
|
- if "hyde:" in original_output:
|
|
|
|
|
- for orig_line in original_output.split("\n"):
|
|
|
|
|
- if orig_line.strip().startswith("hyde:"):
|
|
|
|
|
- if len(orig_line) > 160:
|
|
|
|
|
- long_hyde_count += 1
|
|
|
|
|
-
|
|
|
|
|
- # Validate cleaned output
|
|
|
|
|
- has_lex = "lex:" in ex["output"]
|
|
|
|
|
- has_vec = "vec:" in ex["output"]
|
|
|
|
|
-
|
|
|
|
|
- if has_lex and has_vec:
|
|
|
|
|
- examples.append(ex)
|
|
|
|
|
- seen_queries.add(ex["input"].lower())
|
|
|
|
|
-
|
|
|
|
|
- print(f"Loaded and cleaned {len(examples)} examples")
|
|
|
|
|
|
|
+ for input_file in input_files:
|
|
|
|
|
+ file_count = 0
|
|
|
|
|
+ with open(input_file) as f:
|
|
|
|
|
+ for line in f:
|
|
|
|
|
+ if line.strip():
|
|
|
|
|
+ ex = json.loads(line)
|
|
|
|
|
+
|
|
|
|
|
+ # Clean output (truncate hyde, remove invalid lines)
|
|
|
|
|
+ original_output = ex["output"]
|
|
|
|
|
+ ex["output"] = clean_output(ex["output"])
|
|
|
|
|
+
|
|
|
|
|
+ # Track hyde truncation
|
|
|
|
|
+ if "hyde:" in original_output:
|
|
|
|
|
+ for orig_line in original_output.split("\n"):
|
|
|
|
|
+ if orig_line.strip().startswith("hyde:"):
|
|
|
|
|
+ if len(orig_line) > 160:
|
|
|
|
|
+ long_hyde_count += 1
|
|
|
|
|
+
|
|
|
|
|
+ # Validate cleaned output
|
|
|
|
|
+ has_lex = "lex:" in ex["output"]
|
|
|
|
|
+ has_vec = "vec:" in ex["output"]
|
|
|
|
|
+
|
|
|
|
|
+ # Skip duplicates
|
|
|
|
|
+ if has_lex and has_vec and ex["input"].lower() not in seen_queries:
|
|
|
|
|
+ examples.append(ex)
|
|
|
|
|
+ seen_queries.add(ex["input"].lower())
|
|
|
|
|
+ file_count += 1
|
|
|
|
|
+ print(f" {Path(input_file).name}: {file_count} examples")
|
|
|
|
|
+
|
|
|
|
|
+ print(f"Loaded and cleaned {len(examples)} examples total")
|
|
|
print(f"Truncated {long_hyde_count} long hyde sections")
|
|
print(f"Truncated {long_hyde_count} long hyde sections")
|
|
|
|
|
|
|
|
# Count existing short queries
|
|
# Count existing short queries
|