"""C20: Variants of C18-2 (the 48.5% recipe) with different replay ratios"""
import json, re, random, torch, numpy as np, os
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from datasets import Dataset
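# Reproducibility: fix the seeds for Python's random, NumPy, and PyTorch
# (CUDA RNGs are seeded below when a GPU is available).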
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    if torch.cuda.get_device_capability()[0] >= 8: torch.set_float32_matmul_precision('high')
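# System prompt (Korean), roughly: "Solve the given math problem step by step and write out
# your answer. Always print the final answer on the last line in \boxed{integer} format.
# Example: \boxed{42}"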
SP = "์ฃผ์–ด์ง„ ์ˆ˜ํ•™ ๋ฌธ์ œ๋ฅผ ๋‹จ๊ณ„๋ณ„๋กœ ํ’€๊ณ  ๋‹ต๋ณ€์„ ์ž‘์„ฑํ•˜์„ธ์š”.\n๋ฐ˜๋“œ์‹œ ์ตœ์ข… ๋‹ต๋ณ€์„ \\boxed{์ •์ˆ˜} ํ˜•์‹์œผ๋กœ ๋งˆ์ง€๋ง‰ ์ค„์— ์ถœ๋ ฅํ•˜์„ธ์š”.\n์˜ˆ์‹œ: \\boxed{42}"
BASE = "outputs/models/c17d-gemma-3-1b-it-Math"
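# SFT starts from the earlier c17d checkpoint, presumably the model that generated
# the rejection-sampling (RS) datasets loaded below.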
# Load RS1+RS2 (the winning combo)
with open("outputs/c17d_rs/sft_dataset.json") as f:
    rs1 = json.load(f)
with open("outputs/c17d_rs2/sft_dataset.json") as f:
    rs2 = json.load(f)
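# Rejection-sampling filter: keep only examples answered correctly at least 4 times
# (n_correct >= 4), deduplicated on the (question, answer) pair.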
seen = set()
rs_combined = []
for d in rs1 + rs2:
    if d["n_correct"] < 4: continue
    key = (d["question"], d["answer"])
    if key not in seen:
        seen.add(key)
        rs_combined.append({"question": d["question"], "answer": d["answer"], "source": "gsm8k"})
print(f"RS1+RS2 combined: {len(rs_combined)}")
with open("data/GSM8K_full_qwen3_30b.json") as f:
    orig_data = json.load(f)
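# Keep only teacher answers of <= 1500 characters, presumably so examples stay
# comfortably within the 2048-token max_length used in the configs below.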
orig_filtered = [d for d in orig_data if len(d["answer"]) <= 1500]
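# Convert each example into the prompt/completion chat format that SFTTrainer consumes.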
def to_sft(ex):
    return {"prompt": [{"role": "user", "content": SP + "\n\n" + ex["question"]}],
            "completion": [{"role": "assistant", "content": ex["answer"]}]}
# === Condition 1: RS1+RS2 + 2x replay (more aggressive RS) ===
print("\n=== C20-1: RS1+RS2 + 2x replay ===")
random.seed(SEED)
rs_qs = set(d["question"] for d in rs_combined)
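# Replay set: teacher data for questions not already covered by the RS set;
# this condition mixes it in at 2x the RS count.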
replay = [d for d in orig_filtered if d["question"] not in rs_qs]
random.shuffle(replay)
replay1 = replay[:int(len(rs_combined) * 2)]
mixed1 = rs_combined + replay1
random.shuffle(mixed1)
print(f" RS: {len(rs_combined)} + replay: {len(replay1)} = {len(mixed1)}")
ds1 = Dataset.from_list(mixed1)
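# Strip all original columns so only the prompt/completion pair from to_sft remains.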
cols = [c for c in ds1.column_names if c not in ["prompt","completion"]]
ds1 = ds1.map(to_sft, remove_columns=cols)
tokenizer = AutoTokenizer.from_pretrained(BASE)
model = AutoModelForCausalLM.from_pretrained(BASE, dtype=torch.bfloat16, device_map="auto", attn_implementation='flash_attention_2')
tokenizer.pad_token = tokenizer.eos_token
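# Gradient checkpointing trades compute for memory; the KV cache must be disabled alongside it.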
model.gradient_checkpointing_enable(); model.config.use_cache = False
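# Effective batch size: 8 per device x 4 gradient-accumulation steps = 32 sequences per optimizer step.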
cfg1 = SFTConfig(report_to='none', seed=SEED, num_train_epochs=1, warmup_ratio=0.05,
                 weight_decay=0.01, max_grad_norm=1.0, per_device_train_batch_size=8,
                 gradient_accumulation_steps=4, max_length=2048, lr_scheduler_type='cosine',
                 learning_rate=2e-6, bf16=True, optim="paged_adamw_8bit",
                 output_dir="outputs/c20_1_ckpt", logging_steps=25, save_strategy="no")
trainer = SFTTrainer(model=model, processing_class=tokenizer, train_dataset=ds1, args=cfg1)
r = trainer.train()
print(f" Loss: {r.training_loss:.4f}")
SAVE1 = "outputs/models/c20-1-2x-replay"
os.makedirs(SAVE1, exist_ok=True)
model.eval(); model.save_pretrained(SAVE1, safe_serialization=False)
tokenizer.save_pretrained(SAVE1)
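# Release the trained model and free GPU memory before loading the next condition's copy.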
del model, trainer; torch.cuda.empty_cache()
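# Conditions 2 and 3 below repeat the same pipeline; only the replay ratio
# (5x, then 3x) and, in C20-3, the learning rate (3e-6) change.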
# === Condition 2: RS1+RS2 + 5x replay (more teacher data) ===
print("\n=== C20-2: RS1+RS2 + 5x replay ===")
random.seed(SEED)
replay = [d for d in orig_filtered if d["question"] not in rs_qs]
random.shuffle(replay)
replay2 = replay[:int(len(rs_combined) * 5)]
mixed2 = rs_combined + replay2
random.shuffle(mixed2)
print(f" RS: {len(rs_combined)} + replay: {len(replay2)} = {len(mixed2)}")
ds2 = Dataset.from_list(mixed2)
cols = [c for c in ds2.column_names if c not in ["prompt","completion"]]
ds2 = ds2.map(to_sft, remove_columns=cols)
tokenizer = AutoTokenizer.from_pretrained(BASE)
model = AutoModelForCausalLM.from_pretrained(BASE, dtype=torch.bfloat16, device_map="auto", attn_implementation='flash_attention_2')
tokenizer.pad_token = tokenizer.eos_token
model.gradient_checkpointing_enable(); model.config.use_cache = False
cfg2 = SFTConfig(report_to='none', seed=SEED, num_train_epochs=1, warmup_ratio=0.05,
                 weight_decay=0.01, max_grad_norm=1.0, per_device_train_batch_size=8,
                 gradient_accumulation_steps=4, max_length=2048, lr_scheduler_type='cosine',
                 learning_rate=2e-6, bf16=True, optim="paged_adamw_8bit",
                 output_dir="outputs/c20_2_ckpt", logging_steps=25, save_strategy="no")
trainer = SFTTrainer(model=model, processing_class=tokenizer, train_dataset=ds2, args=cfg2)
r = trainer.train()
print(f" Loss: {r.training_loss:.4f}")
SAVE2 = "outputs/models/c20-2-5x-replay"
os.makedirs(SAVE2, exist_ok=True)
model.eval(); model.save_pretrained(SAVE2, safe_serialization=False)
tokenizer.save_pretrained(SAVE2)
del model, trainer; torch.cuda.empty_cache()
# === Condition 3: RS1+RS2 + 3x replay + lr=3e-6 (higher lr) ===
print("\n=== C20-3: RS1+RS2 + 3x replay + lr=3e-6 ===")
random.seed(SEED)
replay = [d for d in orig_filtered if d["question"] not in rs_qs]
random.shuffle(replay)
replay3 = replay[:int(len(rs_combined) * 3)]
mixed3 = rs_combined + replay3
random.shuffle(mixed3)
print(f" RS: {len(rs_combined)} + replay: {len(replay3)} = {len(mixed3)}")
ds3 = Dataset.from_list(mixed3)
cols = [c for c in ds3.column_names if c not in ["prompt","completion"]]
ds3 = ds3.map(to_sft, remove_columns=cols)
tokenizer = AutoTokenizer.from_pretrained(BASE)
model = AutoModelForCausalLM.from_pretrained(BASE, dtype=torch.bfloat16, device_map="auto", attn_implementation='flash_attention_2')
tokenizer.pad_token = tokenizer.eos_token
model.gradient_checkpointing_enable(); model.config.use_cache = False
cfg3 = SFTConfig(report_to='none', seed=SEED, num_train_epochs=1, warmup_ratio=0.05,
                 weight_decay=0.01, max_grad_norm=1.0, per_device_train_batch_size=8,
                 gradient_accumulation_steps=4, max_length=2048, lr_scheduler_type='cosine',
                 learning_rate=3e-6, bf16=True, optim="paged_adamw_8bit",
                 output_dir="outputs/c20_3_ckpt", logging_steps=25, save_strategy="no")
trainer = SFTTrainer(model=model, processing_class=tokenizer, train_dataset=ds3, args=cfg3)
r = trainer.train()
print(f" Loss: {r.training_loss:.4f}")
SAVE3 = "outputs/models/c20-3-lr3e-6"
os.makedirs(SAVE3, exist_ok=True)
model.eval(); model.save_pretrained(SAVE3, safe_serialization=False)
tokenizer.save_pretrained(SAVE3)
del model, trainer; torch.cuda.empty_cache()
print("\n=== All conditions complete ===")