gemma-3-1b-it-Math-GRPO / train_grpo.py
"""C27: GRPO (Group Relative Policy Optimization) for math reasoning
Based on the GRPO algorithm from DeepSeekMath and a Gemma-2-2B training recipe reported in the literature.
"""
import json, re, random, torch, numpy as np, os
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    if torch.cuda.get_device_capability()[0] >= 8:
        torch.set_float32_matmul_precision('high')
BASE_MODEL = "outputs/models/c20-2-5x-replay"
SP = "์ฃผ์–ด์ง„ ์ˆ˜ํ•™ ๋ฌธ์ œ๋ฅผ ๋‹จ๊ณ„๋ณ„๋กœ ํ’€๊ณ  ๋‹ต๋ณ€์„ ์ž‘์„ฑํ•˜์„ธ์š”.\n๋ฐ˜๋“œ์‹œ ์ตœ์ข… ๋‹ต๋ณ€์„ \\boxed{์ •์ˆ˜} ํ˜•์‹์œผ๋กœ ๋งˆ์ง€๋ง‰ ์ค„์— ์ถœ๋ ฅํ•˜์„ธ์š”.\n์˜ˆ์‹œ: \\boxed{42}"
# === Load questions + ground truth ===
with open("data/GSM8K_full_qwen3_30b.json") as f:
    teacher_data = json.load(f)
def extract_boxed(text):
    m = re.findall(r'\\boxed\{([^}]+)\}', text)
    return m[-1].strip() if m else None
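# Illustrative example (made-up input, not from the dataset): extract_boxed keeps only the
# LAST boxed value, e.g. extract_boxed("step 1 gives \\boxed{3}; final answer: \\boxed{7}") -> "7".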
def normalize(a):
    if a is None: return None
    s = str(a).replace(",", "").replace(" ", "").strip()
    try:
        n = float(s)
        return str(int(n)) if n == int(n) else str(n)
    except ValueError:
        return s
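# Illustrative normalizations (made-up inputs): "1,000" -> "1000", "18.0" -> "18"; non-numeric
# strings such as "3/4" fail float() and are returned with commas and spaces stripped.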
# Build ground truth
gt_by_q = {}
for t in teacher_data:
    ans = extract_boxed(t["answer"])
    if ans is None: continue
    q = t["question"]
    na = normalize(ans)
    if q not in gt_by_q: gt_by_q[q] = {}
    gt_by_q[q][na] = gt_by_q[q].get(na, 0) + 1
ground_truth = {q: max(counts, key=counts.get) for q, counts in gt_by_q.items()}
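# ground_truth keeps the most frequent normalized answer per question across teacher samples
# (majority vote), e.g. counts {"72": 3, "70": 1} -> "72" (illustrative counts, not from the data).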
questions = list(ground_truth.keys())
random.shuffle(questions)
print(f"Total questions: {len(questions)}")
# Build dataset with prompt (conversational format) + answer column
dataset_items = []
for q in questions:
    dataset_items.append({
        "prompt": [{"role": "user", "content": SP + "\n\n" + q}],
        "answer": ground_truth[q],
    })
train_ds = Dataset.from_list(dataset_items)
print(f"Train dataset: {len(train_ds)}")
# === Reward function ===
def math_reward(prompts, completions, answer, **kwargs):
"""Reward: 1.0 if boxed answer matches ground truth, 0.0 otherwise.
Also +0.1 bonus for having boxed format (encourages format compliance)."""
rewards = []
for completion, gt in zip(completions, answer):
# Handle conversational format (list of dicts) or plain string
if isinstance(completion, list):
text = completion[-1]["content"] if completion else ""
elif isinstance(completion, dict):
text = completion.get("content", "")
else:
text = str(completion)
pred = extract_boxed(text)
if pred is None:
rewards.append(0.0) # No boxed = 0
elif normalize(pred) == gt:
rewards.append(1.0) # Correct = 1.0
else:
rewards.append(0.1) # Wrong but has boxed format = 0.1
return rewards
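# Minimal sanity check of the reward shaping on hand-written completions (illustrative strings,
# not real model output); assumes the conversational completion format TRL passes when prompts
# are chat-formatted.
_demo_rewards = math_reward(
    prompts=None,
    completions=[
        [{"role": "assistant", "content": "2 + 2 = 4, so the answer is \\boxed{4}"}],
        [{"role": "assistant", "content": "The answer is \\boxed{5}"}],
        [{"role": "assistant", "content": "I am not sure."}],
    ],
    answer=["4", "4", "4"],
)
assert _demo_rewards == [1.0, 0.1, 0.0], _demo_rewards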
# === GRPO Config ===
# Literature-informed settings for 1B model
NUM_GEN = 8
BATCH_SIZE = 8 # per device, must be divisible by num_generations
GRAD_ACCUM = 4 # effective batch = 8 * 4 = 32, / 8 gen = 4 prompts per step
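# On a single GPU (assumed here), each per-device batch of 8 holds the 8 sampled completions
# of one prompt; GRPO then normalizes rewards within that group of 8. Over 4 accumulation
# steps the optimizer sees 32 completions drawn from 4 distinct prompts.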
config = GRPOConfig(
output_dir="outputs/c27_grpo_ckpt",
report_to="none",
seed=SEED,
# Generation
num_generations=NUM_GEN,
max_completion_length=1024,
temperature=0.7,
# GRPO algorithm
beta=0.04, # Higher KL penalty to preserve format/quality
loss_type="grpo", # Standard GRPO
epsilon=0.2, # PPO-style clipping
scale_rewards="group", # Normalize within group
# Training
num_train_epochs=1,
per_device_train_batch_size=BATCH_SIZE,
gradient_accumulation_steps=GRAD_ACCUM,
learning_rate=5e-6,
lr_scheduler_type="cosine",
warmup_ratio=0.05,
max_grad_norm=0.1, # Very strict gradient clipping (from literature)
optim="paged_adamw_8bit",
bf16=True,
gradient_checkpointing=True,
# Logging & saving
logging_steps=10,
save_strategy="no",
max_steps=500,
# vLLM for fast generation
use_vllm=True,
vllm_mode="colocate",
vllm_gpu_memory_utilization=0.3,
)
print(f"Config: num_gen={NUM_GEN}, batch={BATCH_SIZE}, ga={GRAD_ACCUM}")
print(f" effective_batch={BATCH_SIZE * GRAD_ACCUM}, prompts_per_step={BATCH_SIZE * GRAD_ACCUM // NUM_GEN}")
print(f" lr={config.learning_rate}, beta={config.beta}, max_steps={config.max_steps}")
print(f" vllm={config.use_vllm}, mode={config.vllm_mode}")
# === Load model ===
print("\nLoading model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token
# === Train ===
trainer = GRPOTrainer(
    model=model,
    reward_funcs=math_reward,
    args=config,
    train_dataset=train_ds,
    processing_class=tokenizer,
)
print("\n=== Starting GRPO Training ===")
result = trainer.train()
print(f"\nTraining loss: {result.training_loss:.4f}")
# Save
SAVE_DIR = "outputs/models/c27-grpo"
os.makedirs(SAVE_DIR, exist_ok=True)
trainer.save_model(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
print(f"Saved: {SAVE_DIR}")
# Print training metrics summary
logs = trainer.state.log_history
reward_logs = [l for l in logs if "reward" in str(l)]
for l in reward_logs[-5:]:
print(f" step={l.get('step','?')}: reward={l.get('reward', l.get('rewards/mean','?'))}, "
f"completion_length={l.get('completion_length', '?')}")
print("\n=== GRPO Training Complete ===")