import os
# Disable Unsloth compilation for GRPO stability - must be set before imports
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
import torch
# from unsloth import FastLanguageModel # Moved to lazy import
# Monkeypatch for TRANSFORMERS_CACHE which is needed by older llm_blender
import transformers.utils.hub
if not hasattr(transformers.utils.hub, "TRANSFORMERS_CACHE"):
transformers.utils.hub.TRANSFORMERS_CACHE = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")
from dataclasses import dataclass
from typing import Optional, Literal
from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig, GRPOTrainer, GRPOConfig
from datasets import load_dataset, Dataset
from huggingface_hub import HfApi
from agentic_data_gen import AgenticDataGenerator, AgenticDataConfig
@dataclass
class TrainerConfig:
model_name: str = "Qwen/Qwen2.5-7B"
dataset_name: str = ""
method: Literal["sft", "dpo", "grpo"] = "sft"
platform: Literal["kaggle", "hf_jobs", "local"] = "local"
max_seq_length: int = 4096
load_in_4bit: bool = True
load_in_8bit: bool = False
torch_dtype: str = "bfloat16" # "bfloat16", "float16", "float32"
lora_r: int = 16
lora_alpha: int = 16
lora_dropout: float = 0
learning_rate: float = 2e-4
per_device_train_batch_size: int = 4
gradient_accumulation_steps: int = 2
num_train_epochs: int = 3
output_dir: str = "outputs"
push_to_hub: bool = True
hub_model_id: Optional[str] = None
hf_token: Optional[str] = os.environ.get("HF_TOKEN")
# Agentic Data Generation
use_agentic_data: bool = False
task_description: str = ""
num_synthetic_records: int = 10
synthetic_data_path: str = "synthetic_data.jsonl"
generate_reasoning: bool = False # Whether to generate <reasoning>...<answer> format
# GRPO-specific
num_generations: int = 4
max_completion_length: int = 512
max_prompt_length: int = 512
    use_compile: bool = False  # Currently informational only: UNSLOTH_COMPILE_DISABLE is set unconditionally at import time above
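# Example (sketch): a minimal programmatic run that bypasses the CLI below.
# "my_data.jsonl" is a hypothetical local file with instruction/output columns.
#
#   config = TrainerConfig(
#       model_name="Qwen/Qwen2.5-7B",
#       dataset_name="my_data.jsonl",
#       method="sft",
#       push_to_hub=False,
#   )
#   QwenTrainer(config).run()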
class QwenTrainer:
def __init__(self, config: TrainerConfig):
self.config = config
self.model = None
self.tokenizer = None
def setup_model(self):
print(f"Loading model: {self.config.model_name}")
# Determine torch_dtype
if self.config.torch_dtype == "bfloat16":
dtype = torch.bfloat16
elif self.config.torch_dtype == "float16":
dtype = torch.float16
else:
dtype = torch.float32
# GRPO Stability Fix: Use standard transformers for GRPO due to Unsloth bugs
if self.config.method == "grpo":
print(f"Using standard transformers + peft for GRPO stability (dtype: {self.config.torch_dtype})")
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
bnb_config = None
if self.config.load_in_4bit:
print("Loading in 4-bit quantization")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=dtype,
bnb_4bit_use_double_quant=True,
)
elif self.config.load_in_8bit:
print("Loading in 8-bit quantization")
bnb_config = BitsAndBytesConfig(
load_in_8bit=True,
)
else:
print(f"Loading in full {self.config.torch_dtype}")
self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
self.model = AutoModelForCausalLM.from_pretrained(
self.config.model_name,
quantization_config=bnb_config,
torch_dtype=dtype,
device_map="auto",
)
# GRPO Stability Fix: Ensure all non-quantized parts are in the target dtype
# This is critical for preventing scalar type mismatches during KL div calculation
print(f"Ensuring non-quantized layers are in {self.config.torch_dtype}")
for name, module in self.model.named_modules():
if "norm" in name.lower() or "lm_head" in name.lower() or "embed" in name.lower():
module.to(dtype)
peft_config = LoraConfig(
r=self.config.lora_r,
lora_alpha=self.config.lora_alpha,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_dropout=self.config.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
self.model = get_peft_model(self.model, peft_config)
# GRPO Stability Fix: Fix for TRL GRPOTrainer trying to access warnings_issued
if not hasattr(self.model, "warnings_issued"):
self.model.warnings_issued = {}
else:
# SFT and DPO still use Unsloth for performance
from unsloth import FastLanguageModel
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
model_name=self.config.model_name,
max_seq_length=self.config.max_seq_length,
load_in_4bit=self.config.load_in_4bit,
dtype=dtype,
)
print("Attaching LoRA via Unsloth")
self.model = FastLanguageModel.get_peft_model(
self.model,
r=self.config.lora_r,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=self.config.lora_alpha,
lora_dropout=self.config.lora_dropout,
bias="none",
random_state=3407,
)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
def generate_agentic_data(self):
print(f"Generating agentic synthetic data for task: {self.config.task_description}")
gen_config = AgenticDataConfig(
num_records=self.config.num_synthetic_records,
task_description=self.config.task_description,
output_path=self.config.synthetic_data_path,
min_quality_score=2, # Allow partially correct/minor issues to pass
generate_dpo=(self.config.method == "dpo"),
generate_reasoning=(self.config.method == "grpo" or self.config.generate_reasoning)
)
generator = AgenticDataGenerator()
df = generator.generate(gen_config)
if df.empty:
raise ValueError("No records passed the quality threshold during agentic data generation. Try a different task description or lower min_quality_score.")
return df
def load_data(self):
if self.config.use_agentic_data:
df = self.generate_agentic_data()
dataset = Dataset.from_pandas(df)
else:
print(f"Loading dataset: {self.config.dataset_name}")
if os.path.exists(self.config.dataset_name):
ext = self.config.dataset_name.split(".")[-1]
if ext in ["jsonl", "json"]:
dataset = load_dataset("json", data_files=self.config.dataset_name, split="train")
elif ext == "csv":
dataset = load_dataset("csv", data_files=self.config.dataset_name, split="train")
elif ext == "parquet":
dataset = load_dataset("parquet", data_files=self.config.dataset_name, split="train")
else:
dataset = load_dataset(self.config.dataset_name, split="train")
else:
dataset = load_dataset(self.config.dataset_name, split="train")
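        # Expected source columns by method (remapped below as needed):
        #   sft  -> instruction, output
        #   dpo  -> instruction, output (and optionally rejected)
        #   grpo -> instruction, or a ready-made prompt column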
# Standard ChatML formatting
if self.config.method == "sft":
def format_chatml(example):
return {"text": f"<|im_start|>user\n{example['instruction']}<|im_end|>\n<|im_start|>assistant\n{example['output']}<|im_end|>"}
dataset = dataset.map(format_chatml)
elif self.config.method == "grpo":
# For GRPO, we need 'prompt' column
if "prompt" not in dataset.column_names:
print("Mapping 'instruction' to 'prompt' for GRPO")
def map_prompt(example):
return {"prompt": example["instruction"]}
dataset = dataset.map(map_prompt)
elif self.config.method == "dpo":
# For DPO, we need 'prompt', 'chosen', 'rejected'
if "prompt" not in dataset.column_names:
print("Mapping columns for DPO")
def map_dpo(example):
return {
"prompt": example["instruction"],
"chosen": example["output"],
"rejected": example.get("rejected", "I don't know.")
}
dataset = dataset.map(map_dpo)
return dataset
def run_sft(self, dataset):
print("Running SFT")
trainer = SFTTrainer(
model=self.model,
tokenizer=self.tokenizer,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=self.config.max_seq_length,
args=SFTConfig(
per_device_train_batch_size=self.config.per_device_train_batch_size,
gradient_accumulation_steps=self.config.gradient_accumulation_steps,
num_train_epochs=self.config.num_train_epochs,
learning_rate=self.config.learning_rate,
fp16=not torch.cuda.is_bf16_supported(),
bf16=torch.cuda.is_bf16_supported(),
logging_steps=1,
optim="adamw_8bit",
weight_decay=0.01,
lr_scheduler_type="linear",
seed=3407,
output_dir=self.config.output_dir,
),
)
trainer.train()
def run_dpo(self, dataset):
print("Running DPO")
trainer = DPOTrainer(
model=self.model,
tokenizer=self.tokenizer,
train_dataset=dataset,
args=DPOConfig(
per_device_train_batch_size=self.config.per_device_train_batch_size,
gradient_accumulation_steps=self.config.gradient_accumulation_steps,
num_train_epochs=self.config.num_train_epochs,
learning_rate=self.config.learning_rate,
fp16=not torch.cuda.is_bf16_supported(),
bf16=torch.cuda.is_bf16_supported(),
logging_steps=1,
optim="adamw_8bit",
output_dir=self.config.output_dir,
),
)
trainer.train()
def run_grpo(self, dataset):
print("Running GRPO")
from rewards import RewardFunctions
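        # For reference, TRL's GRPOTrainer calls each reward function with the
        # batch of generated completions (plus dataset columns as kwargs) and
        # expects one float per completion. A minimal sketch; the real logic
        # lives in rewards.py:
        #
        #   def combined_reward(completions, **kwargs):
        #       return [1.0 if "<answer>" in c else 0.0 for c in completions]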
        # For GRPO, the batch size must be a multiple of num_generations
        # (Unsloth prefers per_device_train_batch_size == num_generations),
        # so round the configured batch size up to the nearest multiple.
        num_gen = self.config.num_generations
        batch_size = max(self.config.per_device_train_batch_size, num_gen)
        batch_size = ((batch_size + num_gen - 1) // num_gen) * num_gen
        # Stability: nudge max_completion_length off a multiple-of-16 boundary
        max_comp = self.config.max_completion_length
        if max_comp % 16 == 0:
            max_comp += 1
            print(f"Adjusted max_completion_length to {max_comp} for stability")
trainer = GRPOTrainer(
model=self.model,
args=GRPOConfig(
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            num_generations=self.config.num_generations,
            num_train_epochs=self.config.num_train_epochs,
            learning_rate=self.config.learning_rate,
max_completion_length=max_comp,
# max_prompt_length=self.config.max_prompt_length, # Not supported in this version
beta=0.01,
warmup_steps=10,
logging_steps=1,
output_dir=self.config.output_dir,
optim="adamw_8bit",
seed=3407,
),
reward_funcs=[RewardFunctions.combined_reward],
train_dataset=dataset,
)
trainer.train()
    def save_and_push(self):
        if self.config.push_to_hub:
            if not self.config.hub_model_id:
                raise ValueError("push_to_hub=True requires hub_model_id (set --hub_id)")
            print(f"Saving and pushing to Hub: {self.config.hub_model_id}")
            if self.config.method != "grpo":
                from unsloth import FastLanguageModel  # noqa: F401 (imported for side effects)
            if hasattr(self.model, "save_pretrained_merged"):
                # Unsloth path: merge the LoRA adapters into the base weights and save in 16-bit
self.model.save_pretrained_merged(
"merged_model", self.tokenizer, save_method="merged_16bit"
)
else:
print("Merging and saving standard PEFT model")
merged_model = self.model.merge_and_unload()
merged_model.save_pretrained("merged_model")
self.tokenizer.save_pretrained("merged_model")
api = HfApi()
api.create_repo(repo_id=self.config.hub_model_id, token=self.config.hf_token, exist_ok=True)
api.upload_folder(
folder_path="merged_model",
repo_id=self.config.hub_model_id,
token=self.config.hf_token,
)
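    # Once uploaded, the merged checkpoint loads like any HF model, e.g.:
    #   from transformers import AutoModelForCausalLM
    #   model = AutoModelForCausalLM.from_pretrained("<hub_model_id>")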
def run(self):
self.setup_model()
dataset = self.load_data()
if self.config.method == "sft":
self.run_sft(dataset)
elif self.config.method == "dpo":
self.run_dpo(dataset)
elif self.config.method == "grpo":
self.run_grpo(dataset)
self.save_and_push()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Qwen Unified Trainer (SFT, DPO, GRPO)")
# Model/Dataset
parser.add_argument("--model", type=str, default="Qwen/Qwen3.5-2B", help="HF model ID")
parser.add_argument("--dataset", type=str, default="", help="HF dataset name or local path")
parser.add_argument("--method", type=str, choices=["sft", "dpo", "grpo"], default="sft", help="Training method")
# Training Hyperparameters
parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
parser.add_argument("--batch_size", type=int, default=4, help="Batch size per device")
parser.add_argument("--grad_acc", type=int, default=2, help="Gradient accumulation steps")
parser.add_argument("--max_seq_len", type=int, default=2048, help="Max sequence length")
# Agentic Data
parser.add_argument("--use_agentic", action="store_true", help="Generate synthetic data before training")
parser.add_argument("--task", type=str, default="", help="Task description for synthetic data")
parser.add_argument("--num_synthetic", type=int, default=10, help="Number of synthetic records")
parser.add_argument("--synthetic_path", type=str, default="synthetic_data.jsonl", help="Path to save synthetic data")
parser.add_argument("--reasoning", action="store_true", help="Generate reasoning format")
# Output/Hub
parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory")
parser.add_argument("--push", action="store_true", help="Push to HF Hub")
parser.add_argument("--hub_id", type=str, default=None, help="HF Hub model ID")
parser.add_argument("--no_compile", action="store_true", help="Disable Unsloth compilation for stability")
parser.add_argument("--dtype", type=str, choices=["bfloat16", "float16", "float32"], default="bfloat16", help="Torch dtype")
parser.add_argument("--load_8bit", action="store_true", help="Load in 8-bit")
parser.add_argument("--no_4bit", action="store_true", help="Disable 4-bit loading")
args = parser.parse_args()
config = TrainerConfig(
model_name=args.model,
dataset_name=args.dataset,
method=args.method,
learning_rate=args.lr,
num_train_epochs=args.epochs,
per_device_train_batch_size=args.batch_size,
gradient_accumulation_steps=args.grad_acc,
max_seq_length=args.max_seq_len,
use_agentic_data=args.use_agentic,
task_description=args.task,
num_synthetic_records=args.num_synthetic,
synthetic_data_path=args.synthetic_path,
generate_reasoning=args.reasoning,
output_dir=args.output_dir,
push_to_hub=args.push,
hub_model_id=args.hub_id,
use_compile=not args.no_compile,
torch_dtype=args.dtype,
        load_in_8bit=args.load_8bit,
        # 4-bit and 8-bit are mutually exclusive; --load_8bit disables 4-bit
        load_in_4bit=not args.no_4bit and not args.load_8bit
)
trainer = QwenTrainer(config)
trainer.run()
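# Example invocations (assuming this file is saved as train.py):
#   python train.py --method sft --dataset my_data.jsonl
#   python train.py --method grpo --use_agentic --task "grade-school math" --push --hub_id your-username/qwen-grpo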