|
@@ -290,7 +290,11 @@ def score_expansion(query: str, expansion: str) -> float:
|
|
|
|
|
|
|
|
def extract_query_from_prompt(prompt: str) -> str:
|
|
def extract_query_from_prompt(prompt: str) -> str:
|
|
|
if "Expand this search query:" in prompt:
|
|
if "Expand this search query:" in prompt:
|
|
|
- return prompt.split("Expand this search query:")[-1].strip()
|
|
|
|
|
|
|
+ query = prompt.split("Expand this search query:")[-1].strip()
|
|
|
|
|
+ # Remove chat template artifacts if present
|
|
|
|
|
+ if "<|im_end|>" in query:
|
|
|
|
|
+ query = query.split("<|im_end|>")[0].strip()
|
|
|
|
|
+ return query
|
|
|
return prompt.strip()
|
|
return prompt.strip()
|
|
|
|
|
|
|
|
|
|
|
|
@@ -323,23 +327,27 @@ def main():
|
|
|
print("Logging in to HuggingFace Hub...")
|
|
print("Logging in to HuggingFace Hub...")
|
|
|
login(token=hf_token)
|
|
login(token=hf_token)
|
|
|
|
|
|
|
|
|
|
+ # Load tokenizer first (needed for chat template)
|
|
|
|
|
+ print(f"Loading tokenizer from {BASE_MODEL}...")
|
|
|
|
|
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
|
|
|
|
+ if tokenizer.pad_token is None:
|
|
|
|
|
+ tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
+
|
|
|
# Load dataset
|
|
# Load dataset
|
|
|
print("Loading dataset...")
|
|
print("Loading dataset...")
|
|
|
dataset = load_dataset(DATASET, split="train")
|
|
dataset = load_dataset(DATASET, split="train")
|
|
|
|
|
|
|
|
def extract_prompt(example):
|
|
def extract_prompt(example):
|
|
|
- return {"prompt": example["messages"][0]["content"]}
|
|
|
|
|
|
|
+ # Apply chat template so model sees the same format as SFT training
|
|
|
|
|
+ content = example["messages"][0]["content"]
|
|
|
|
|
+ messages = [{"role": "user", "content": content}]
|
|
|
|
|
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
|
|
|
|
+ return {"prompt": formatted}
|
|
|
|
|
|
|
|
dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
|
|
dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
|
|
|
dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
|
|
dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
|
|
|
print(f"Using {len(dataset)} prompts for GRPO")
|
|
print(f"Using {len(dataset)} prompts for GRPO")
|
|
|
|
|
|
|
|
- # Load tokenizer and model
|
|
|
|
|
- print(f"Loading tokenizer from {BASE_MODEL}...")
|
|
|
|
|
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
|
|
|
|
- if tokenizer.pad_token is None:
|
|
|
|
|
- tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
-
|
|
|
|
|
print(f"Loading SFT model from {SFT_MODEL}...")
|
|
print(f"Loading SFT model from {SFT_MODEL}...")
|
|
|
base_model = AutoModelForCausalLM.from_pretrained(
|
|
base_model = AutoModelForCausalLM.from_pretrained(
|
|
|
BASE_MODEL,
|
|
BASE_MODEL,
|