Bladeren bron

chore: clean up finetune folder and fix training workflow

- Remove versioned files (sft_v4.yaml, prepare_v4_dataset.py, train_v2/)
- Update configs to use local data/train/ directory
- Add glob pattern support to prepare_data.py and train.py
- Update .gitignore to properly ignore outputs/ and data/train*/
- Document data preparation step in CLAUDE.md

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke 3 maanden geleden
bovenliggende
commit
38073799c0

+ 5 - 5
finetune/.gitignore

@@ -1,13 +1,13 @@
-# Model checkpoints (stored on HuggingFace Hub)
-qmd-query-expansion-*/
+# Training outputs (run eval before pushing to HuggingFace)
+outputs/
+
+# Model checkpoints
 *.pt
 *.safetensors
 
 # Processed data files (regenerated by prepare_data.py)
 data/train/
-data/train_v2/train.jsonl
-data/train_v2/train_chat.jsonl
-data/train_v2/val.jsonl
+data/train_v2/
 data/qmd_expansion_cleaned.jsonl
 data/quality_report.txt
 

+ 15 - 0
finetune/CLAUDE.md

@@ -70,6 +70,21 @@ Always use **Qwen3-1.7B** as the base model unless explicitly stated otherwise.
 
 Training can run **locally** (requires CUDA GPU) or via **HuggingFace Jobs** (cloud GPU, no local hardware needed).
 
+### Stage 0: Prepare Data
+
+Raw data in `data/*.jsonl` must be converted to Qwen3 chat format before training:
+
+```bash
+# Process all JSONL files in data/
+uv run dataset/prepare_data.py
+# Creates: data/train/train.jsonl, data/train/val.jsonl
+
+# Or process a specific file
+uv run dataset/prepare_data.py --input data/qmd_expansion_v2.jsonl
+```
+
+This applies the Qwen3 chat template, deduplicates, and splits into train/val sets.
+
 ### Stage 1: SFT
 
 ```bash

+ 3 - 1
finetune/configs/grpo.yaml

@@ -13,7 +13,9 @@ model:
   output: "outputs/grpo"  # Local training output (push to HF manually after eval)
 
 dataset:
-  name: "tobil/qmd-query-expansion-train-v2"
+  # Local: run `uv run dataset/prepare_data.py` first, then use "data/train/"
+  # HuggingFace: use "tobil/qmd-query-expansion-train" (already prepared)
+  name: "data/train/"
   prompt_field: "messages"
   max_samples: 1000
 

+ 3 - 1
finetune/configs/sft.yaml

@@ -8,7 +8,9 @@ model:
   output: "outputs/sft"  # Local training output (push to HF manually after eval)
 
 dataset:
-  name: "tobil/qmd-query-expansion-train-v2"
+  # Local: run `uv run dataset/prepare_data.py` first, then use "data/train/"
+  # HuggingFace: use "tobil/qmd-query-expansion-train" (already prepared)
+  name: "data/train/"
   text_field: "text"
   split: "train"
   eval_split: 0.1

+ 0 - 38
finetune/configs/sft_v4.yaml

@@ -1,38 +0,0 @@
-# SFT Training Config - v4 with /only: support
-# Usage: accelerate launch --config_file configs/accelerate_multi_gpu.yaml train.py sft --config configs/sft_v4.yaml
-
-model:
-  base: "Qwen/Qwen3-1.7B"
-  output: "qmd-sft-v4"
-
-dataset:
-  name: "data/train_v4"
-  text_field: "text"
-  split: "train"
-  eval_split: 0.1
-
-training:
-  epochs: 3
-  batch_size: 2
-  gradient_accumulation_steps: 4
-  learning_rate: 0.0002
-  max_length: 512
-  warmup_ratio: 0.03
-  lr_scheduler: "cosine"
-
-lora:
-  rank: 16
-  alpha: 32
-  dropout: 0.0
-  target_modules:
-    - "q_proj"
-    - "k_proj"
-    - "v_proj"
-    - "o_proj"
-    - "gate_proj"
-    - "up_proj"
-    - "down_proj"
-
-tracking:
-  project: "qmd-query-expansion"
-  run_name: "sft-1.7B-v4-only-modes"

+ 0 - 12
finetune/data/train/dataset_info.json

@@ -1,12 +0,0 @@
-{
-  "dataset_name": "qmd-query-expansion",
-  "train_samples": 1891,
-  "val_samples": 211,
-  "short_query_pct": 33.6,
-  "columns": [
-    "prompt",
-    "completion",
-    "text",
-    "messages"
-  ]
-}

+ 0 - 12
finetune/data/train_v2/dataset_info.json

@@ -1,12 +0,0 @@
-{
-  "dataset_name": "qmd-query-expansion",
-  "train_samples": 1145,
-  "val_samples": 128,
-  "short_query_pct": 29.3,
-  "columns": [
-    "prompt",
-    "completion",
-    "text",
-    "messages"
-  ]
-}

+ 47 - 32
finetune/dataset/prepare_data.py

@@ -157,50 +157,65 @@ def format_for_training(input_text: str, output_text: str) -> dict:
 
 def main():
     parser = argparse.ArgumentParser(description="Prepare data for training")
-    parser.add_argument("--input", type=str, default="data/qmd_expansion.jsonl", help="Input JSONL file")
+    parser.add_argument("--input", type=str, default="data/*.jsonl", help="Input JSONL file(s) - supports glob patterns")
     parser.add_argument("--output", type=str, default="data/train", help="Output directory")
     parser.add_argument("--split", type=float, default=0.1, help="Validation split ratio")
     parser.add_argument("--add-short", type=int, default=3, help="Variations per short query to add")
     args = parser.parse_args()
 
-    input_path = Path(args.input)
     output_dir = Path(args.output)
     output_dir.mkdir(parents=True, exist_ok=True)
 
-    if not input_path.exists():
-        print(f"Error: Input file not found: {input_path}")
-        exit(1)
-
-    # Load and clean existing examples
+    # Support glob patterns for input
+    import glob
+    if "*" in args.input:
+        input_files = sorted(glob.glob(args.input))
+        if not input_files:
+            print(f"Error: No files found matching: {args.input}")
+            exit(1)
+        print(f"Found {len(input_files)} input files: {[Path(f).name for f in input_files]}")
+    else:
+        input_path = Path(args.input)
+        if not input_path.exists():
+            print(f"Error: Input file not found: {input_path}")
+            exit(1)
+        input_files = [str(input_path)]
+
+    # Load and clean existing examples from all input files
     examples = []
     seen_queries = set()
     long_hyde_count = 0
 
-    with open(input_path) as f:
-        for line in f:
-            if line.strip():
-                ex = json.loads(line)
-
-                # Clean output (truncate hyde, remove invalid lines)
-                original_output = ex["output"]
-                ex["output"] = clean_output(ex["output"])
-
-                # Track hyde truncation
-                if "hyde:" in original_output:
-                    for orig_line in original_output.split("\n"):
-                        if orig_line.strip().startswith("hyde:"):
-                            if len(orig_line) > 160:
-                                long_hyde_count += 1
-
-                # Validate cleaned output
-                has_lex = "lex:" in ex["output"]
-                has_vec = "vec:" in ex["output"]
-
-                if has_lex and has_vec:
-                    examples.append(ex)
-                    seen_queries.add(ex["input"].lower())
-
-    print(f"Loaded and cleaned {len(examples)} examples")
+    for input_file in input_files:
+        file_count = 0
+        with open(input_file) as f:
+            for line in f:
+                if line.strip():
+                    ex = json.loads(line)
+
+                    # Clean output (truncate hyde, remove invalid lines)
+                    original_output = ex["output"]
+                    ex["output"] = clean_output(ex["output"])
+
+                    # Track hyde truncation
+                    if "hyde:" in original_output:
+                        for orig_line in original_output.split("\n"):
+                            if orig_line.strip().startswith("hyde:"):
+                                if len(orig_line) > 160:
+                                    long_hyde_count += 1
+
+                    # Validate cleaned output
+                    has_lex = "lex:" in ex["output"]
+                    has_vec = "vec:" in ex["output"]
+
+                    # Skip duplicates
+                    if has_lex and has_vec and ex["input"].lower() not in seen_queries:
+                        examples.append(ex)
+                        seen_queries.add(ex["input"].lower())
+                        file_count += 1
+        print(f"  {Path(input_file).name}: {file_count} examples")
+
+    print(f"Loaded and cleaned {len(examples)} examples total")
     print(f"Truncated {long_hyde_count} long hyde sections")
 
     # Count existing short queries

+ 0 - 107
finetune/prepare_v4_dataset.py

@@ -1,107 +0,0 @@
-# /// script
-# requires-python = ">=3.10"
-# dependencies = []
-# ///
-"""Prepare v4 dataset: high-quality expansions + /only: variants."""
-
-import json
-import random
-from pathlib import Path
-
-def to_chat_format(query: str, output: str) -> dict:
-    """Convert input/output to chat format with /no_think."""
-    # For /only: queries, keep the suffix in the prompt
-    prompt = f"/no_think Expand this search query: {query}"
-    
-    text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n{output}<|im_end|>\n"
-    
-    messages = [
-        {"role": "user", "content": prompt},
-        {"role": "assistant", "content": output}
-    ]
-    
-    return {"text": text, "messages": messages}
-
-
-def load_jsonl(path: Path) -> list[dict]:
-    """Load JSONL file."""
-    data = []
-    with open(path) as f:
-        for line in f:
-            line = line.strip()
-            if line:
-                data.append(json.loads(line))
-    return data
-
-
-def main():
-    data_dir = Path("data")
-    
-    # High-quality sources
-    sources = [
-        ("qmd_expansion_v2.jsonl", "v2"),
-        ("qmd_expansion_handcrafted.jsonl", "handcrafted"),
-        ("qmd_only_variants.jsonl", "only"),
-    ]
-    
-    all_examples = []
-    stats = {}
-    
-    for filename, label in sources:
-        path = data_dir / filename
-        if not path.exists():
-            print(f"  Skipping {filename} (not found)")
-            continue
-        
-        raw = load_jsonl(path)
-        converted = []
-        
-        for item in raw:
-            query = item.get("input", "")
-            output = item.get("output", "")
-            if query and output:
-                converted.append(to_chat_format(query, output))
-        
-        all_examples.extend(converted)
-        stats[label] = len(converted)
-        print(f"  {label}: {len(converted)} examples")
-    
-    # Shuffle
-    random.seed(42)
-    random.shuffle(all_examples)
-    
-    # Split 90/10
-    split_idx = int(len(all_examples) * 0.9)
-    train = all_examples[:split_idx]
-    val = all_examples[split_idx:]
-    
-    # Write output
-    out_dir = data_dir / "train_v4"
-    out_dir.mkdir(exist_ok=True)
-    
-    with open(out_dir / "train.jsonl", "w") as f:
-        for ex in train:
-            f.write(json.dumps(ex) + "\n")
-    
-    with open(out_dir / "val.jsonl", "w") as f:
-        for ex in val:
-            f.write(json.dumps(ex) + "\n")
-    
-    # Dataset info
-    info = {
-        "dataset_name": "qmd-query-expansion-v4",
-        "train_samples": len(train),
-        "val_samples": len(val),
-        "sources": stats,
-    }
-    with open(out_dir / "dataset_info.json", "w") as f:
-        json.dump(info, f, indent=2)
-    
-    print(f"\n✓ Dataset prepared in {out_dir}/")
-    print(f"  Train: {len(train)}")
-    print(f"  Val: {len(val)}")
-    print(f"  Total: {len(all_examples)}")
-
-
-if __name__ == "__main__":
-    main()

+ 16 - 6
finetune/train.py

@@ -50,15 +50,25 @@ def cmd_sft(args):
     dataset_name = cfg["dataset"]["name"]
     print(f"Loading dataset: {dataset_name}...")
 
-    # Support local JSONL files
+    # Support local JSONL files and glob patterns
     if dataset_name.startswith("data/") or dataset_name.endswith(".jsonl"):
         from pathlib import Path
-        data_path = Path(dataset_name)
-        if data_path.is_dir():
-            train_file = data_path / "train.jsonl"
-            dataset = load_dataset("json", data_files=str(train_file), split="train")
+        import glob
+
+        # Handle glob patterns like "data/*.jsonl"
+        if "*" in dataset_name:
+            jsonl_files = sorted(glob.glob(dataset_name))
+            if not jsonl_files:
+                raise ValueError(f"No files found matching: {dataset_name}")
+            print(f"  Found {len(jsonl_files)} JSONL files: {[Path(f).name for f in jsonl_files]}")
+            dataset = load_dataset("json", data_files=jsonl_files, split="train")
         else:
-            dataset = load_dataset("json", data_files=dataset_name, split="train")
+            data_path = Path(dataset_name)
+            if data_path.is_dir():
+                train_file = data_path / "train.jsonl"
+                dataset = load_dataset("json", data_files=str(train_file), split="train")
+            else:
+                dataset = load_dataset("json", data_files=dataset_name, split="train")
     else:
         dataset = load_dataset(dataset_name, split=cfg["dataset"]["split"])
     print(f"Dataset loaded: {len(dataset)} examples")