import os

# Disable Unsloth compilation for GRPO stability - must be set before imports
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"

import torch

# from unsloth import FastLanguageModel  # Moved to lazy import

# Monkeypatch for TRANSFORMERS_CACHE, which is needed by older llm_blender versions
import transformers.utils.hub

if not hasattr(transformers.utils.hub, "TRANSFORMERS_CACHE"):
    transformers.utils.hub.TRANSFORMERS_CACHE = os.path.join(
        os.path.expanduser("~"), ".cache", "huggingface", "hub"
    )

from dataclasses import dataclass
from typing import Optional, Literal

from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig, GRPOTrainer, GRPOConfig
from datasets import load_dataset, Dataset
from huggingface_hub import HfApi

from agentic_data_gen import AgenticDataGenerator, AgenticDataConfig


@dataclass
class TrainerConfig:
    model_name: str = "Qwen/Qwen2.5-7B"
    dataset_name: str = ""
    method: Literal["sft", "dpo", "grpo"] = "sft"
    platform: Literal["kaggle", "hf_jobs", "local"] = "local"
    max_seq_length: int = 4096
    load_in_4bit: bool = True
    load_in_8bit: bool = False
    torch_dtype: str = "bfloat16"  # "bfloat16", "float16", "float32"
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.0
    learning_rate: float = 2e-4
    per_device_train_batch_size: int = 4
    gradient_accumulation_steps: int = 2
    num_train_epochs: int = 3
    output_dir: str = "outputs"
    push_to_hub: bool = True
    hub_model_id: Optional[str] = None
    hf_token: Optional[str] = os.environ.get("HF_TOKEN")

    # Agentic Data Generation
    use_agentic_data: bool = False
    task_description: str = ""
    num_synthetic_records: int = 10
    synthetic_data_path: str = "synthetic_data.jsonl"
    generate_reasoning: bool = False  # Whether to generate reasoning-format outputs

    # GRPO-specific
    num_generations: int = 4
    max_completion_length: int = 512
    max_prompt_length: int = 512
    use_compile: bool = False  # Disabled by default for GRPO stability


class QwenTrainer:
    def __init__(self, config: TrainerConfig):
        self.config = config
        self.model = None
        self.tokenizer = None

    def setup_model(self):
        print(f"Loading model: {self.config.model_name}")

        # Determine torch_dtype
        if self.config.torch_dtype == "bfloat16":
            dtype = torch.bfloat16
        elif self.config.torch_dtype == "float16":
            dtype = torch.float16
        else:
            dtype = torch.float32

        # GRPO Stability Fix: use standard transformers for GRPO due to Unsloth bugs
        if self.config.method == "grpo":
            print(f"Using standard transformers + peft for GRPO stability (dtype: {self.config.torch_dtype})")
            from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
            from peft import LoraConfig, get_peft_model

            bnb_config = None
            if self.config.load_in_4bit:
                print("Loading in 4-bit quantization")
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=dtype,
                    bnb_4bit_use_double_quant=True,
                )
            elif self.config.load_in_8bit:
                print("Loading in 8-bit quantization")
                bnb_config = BitsAndBytesConfig(load_in_8bit=True)
            else:
                print(f"Loading in full {self.config.torch_dtype}")

            self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config.model_name,
                quantization_config=bnb_config,
                torch_dtype=dtype,
                device_map="auto",
            )

            # GRPO Stability Fix: ensure all non-quantized parts are in the target
            # dtype. This is critical for preventing scalar type mismatches during
            # the KL divergence calculation.
            print(f"Ensuring non-quantized layers are in {self.config.torch_dtype}")
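            # Note: the substring matching below is an assumption based on
            # Qwen-style module names, where normalization layers ("norm"), the
            # output head ("lm_head"), and token embeddings ("embed_tokens") are
            # the usual modules left un-quantized by bitsandbytes loading.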
            for name, module in self.model.named_modules():
                if "norm" in name.lower() or "lm_head" in name.lower() or "embed" in name.lower():
                    module.to(dtype)

            peft_config = LoraConfig(
                r=self.config.lora_r,
                lora_alpha=self.config.lora_alpha,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
                lora_dropout=self.config.lora_dropout,
                bias="none",
                task_type="CAUSAL_LM",
            )
            self.model = get_peft_model(self.model, peft_config)

            # GRPO Stability Fix: TRL's GRPOTrainer tries to access `warnings_issued`,
            # which plain PEFT models may lack
            if not hasattr(self.model, "warnings_issued"):
                self.model.warnings_issued = {}
        else:
            # SFT and DPO still use Unsloth for performance
            from unsloth import FastLanguageModel

            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=self.config.model_name,
                max_seq_length=self.config.max_seq_length,
                load_in_4bit=self.config.load_in_4bit,
                dtype=dtype,
            )

            print("Attaching LoRA via Unsloth")
            self.model = FastLanguageModel.get_peft_model(
                self.model,
                r=self.config.lora_r,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
                lora_alpha=self.config.lora_alpha,
                lora_dropout=self.config.lora_dropout,
                bias="none",
                random_state=3407,
            )

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def generate_agentic_data(self):
        print(f"Generating agentic synthetic data for task: {self.config.task_description}")
        gen_config = AgenticDataConfig(
            num_records=self.config.num_synthetic_records,
            task_description=self.config.task_description,
            output_path=self.config.synthetic_data_path,
            min_quality_score=2,  # Allow partially correct/minor issues to pass
            generate_dpo=(self.config.method == "dpo"),
            generate_reasoning=(self.config.method == "grpo" or self.config.generate_reasoning),
        )
        generator = AgenticDataGenerator()
        df = generator.generate(gen_config)
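        # `generator.generate` is assumed (per the local agentic_data_gen module)
        # to return a pandas DataFrame containing only records that cleared
        # min_quality_score, so an empty frame means everything was filtered out.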
        if df.empty:
            raise ValueError(
                "No records passed the quality threshold during agentic data generation. "
                "Try a different task description or lower min_quality_score."
            )
        return df

    def load_data(self):
        if self.config.use_agentic_data:
            df = self.generate_agentic_data()
            dataset = Dataset.from_pandas(df)
        else:
            print(f"Loading dataset: {self.config.dataset_name}")
            if os.path.exists(self.config.dataset_name):
                ext = self.config.dataset_name.split(".")[-1]
                if ext in ["jsonl", "json"]:
                    dataset = load_dataset("json", data_files=self.config.dataset_name, split="train")
                elif ext == "csv":
                    dataset = load_dataset("csv", data_files=self.config.dataset_name, split="train")
                elif ext == "parquet":
                    dataset = load_dataset("parquet", data_files=self.config.dataset_name, split="train")
                else:
                    dataset = load_dataset(self.config.dataset_name, split="train")
            else:
                dataset = load_dataset(self.config.dataset_name, split="train")

        # Standard ChatML formatting
        if self.config.method == "sft":
            def format_chatml(example):
                return {
                    "text": (
                        f"<|im_start|>user\n{example['instruction']}<|im_end|>\n"
                        f"<|im_start|>assistant\n{example['output']}<|im_end|>"
                    )
                }

            dataset = dataset.map(format_chatml)
        elif self.config.method == "grpo":
            # GRPO requires a 'prompt' column
            if "prompt" not in dataset.column_names:
                print("Mapping 'instruction' to 'prompt' for GRPO")

                def map_prompt(example):
                    return {"prompt": example["instruction"]}

                dataset = dataset.map(map_prompt)
        elif self.config.method == "dpo":
            # DPO requires 'prompt', 'chosen', and 'rejected' columns
            if "prompt" not in dataset.column_names:
                print("Mapping columns for DPO")

                def map_dpo(example):
                    return {
                        "prompt": example["instruction"],
                        "chosen": example["output"],
                        "rejected": example.get("rejected", "I don't know."),
                    }

                dataset = dataset.map(map_dpo)

        return dataset

    def run_sft(self, dataset):
        print("Running SFT")
        trainer = SFTTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            train_dataset=dataset,
            dataset_text_field="text",
            max_seq_length=self.config.max_seq_length,
            args=SFTConfig(
                per_device_train_batch_size=self.config.per_device_train_batch_size,
                gradient_accumulation_steps=self.config.gradient_accumulation_steps,
                num_train_epochs=self.config.num_train_epochs,
                learning_rate=self.config.learning_rate,
                fp16=not torch.cuda.is_bf16_supported(),
                bf16=torch.cuda.is_bf16_supported(),
                logging_steps=1,
                optim="adamw_8bit",
                weight_decay=0.01,
                lr_scheduler_type="linear",
                seed=3407,
                output_dir=self.config.output_dir,
            ),
        )
        trainer.train()

    def run_dpo(self, dataset):
        print("Running DPO")
        trainer = DPOTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            train_dataset=dataset,
            args=DPOConfig(
                per_device_train_batch_size=self.config.per_device_train_batch_size,
                gradient_accumulation_steps=self.config.gradient_accumulation_steps,
                num_train_epochs=self.config.num_train_epochs,
                learning_rate=self.config.learning_rate,
                fp16=not torch.cuda.is_bf16_supported(),
                bf16=torch.cuda.is_bf16_supported(),
                logging_steps=1,
                optim="adamw_8bit",
                output_dir=self.config.output_dir,
            ),
        )
        trainer.train()

    def run_grpo(self, dataset):
        print("Running GRPO")
        from rewards import RewardFunctions

        # GRPO requires the per-device batch size to be a multiple of
        # num_generations, so round the configured batch size up to the nearest
        # valid multiple (at least num_generations itself).
        gens = self.config.num_generations
        batch_size = -(-self.config.per_device_train_batch_size // gens) * gens

        # Stability: bump max_completion_length off multiples of 16 (which covers
        # the usual power-of-two settings), a boundary the author found problematic
        max_comp = self.config.max_completion_length
        if max_comp % 16 == 0:
            max_comp += 1
            print(f"Adjusted max_completion_length to {max_comp} for stability")
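        # `RewardFunctions.combined_reward` (from the local rewards module) is
        # assumed to follow TRL's reward-function convention: it receives the
        # generated `completions` (plus keyword columns such as `prompts`) and
        # returns one float score per completion.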
        trainer = GRPOTrainer(
            model=self.model,
            args=GRPOConfig(
                per_device_train_batch_size=batch_size,
                num_generations=self.config.num_generations,
                learning_rate=self.config.learning_rate,
                max_completion_length=max_comp,
                # max_prompt_length=self.config.max_prompt_length,  # Not supported in this TRL version
                beta=0.01,
                warmup_steps=10,
                logging_steps=1,
                output_dir=self.config.output_dir,
                optim="adamw_8bit",
                seed=3407,
            ),
            reward_funcs=[RewardFunctions.combined_reward],
            train_dataset=dataset,
        )
        trainer.train()

    def save_and_push(self):
        if self.config.push_to_hub:
            print(f"Saving and pushing to Hub: {self.config.hub_model_id}")

            if hasattr(self.model, "save_pretrained_merged"):
                # Unsloth models (SFT/DPO path) expose a merged-save helper
                self.model.save_pretrained_merged(
                    "merged_model", self.tokenizer, save_method="merged_16bit"
                )
            else:
                print("Merging and saving standard PEFT model")
                merged_model = self.model.merge_and_unload()
                merged_model.save_pretrained("merged_model")
                self.tokenizer.save_pretrained("merged_model")

            api = HfApi()
            api.create_repo(repo_id=self.config.hub_model_id, token=self.config.hf_token, exist_ok=True)
            api.upload_folder(
                folder_path="merged_model",
                repo_id=self.config.hub_model_id,
                token=self.config.hf_token,
            )

    def run(self):
        self.setup_model()
        dataset = self.load_data()
        if self.config.method == "sft":
            self.run_sft(dataset)
        elif self.config.method == "dpo":
            self.run_dpo(dataset)
        elif self.config.method == "grpo":
            self.run_grpo(dataset)
        self.save_and_push()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Qwen Unified Trainer (SFT, DPO, GRPO)")

    # Model/Dataset
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-7B", help="HF model ID")
    parser.add_argument("--dataset", type=str, default="", help="HF dataset name or local path")
    parser.add_argument("--method", type=str, choices=["sft", "dpo", "grpo"], default="sft", help="Training method")

    # Training Hyperparameters
    parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size per device")
    parser.add_argument("--grad_acc", type=int, default=2, help="Gradient accumulation steps")
    parser.add_argument("--max_seq_len", type=int, default=2048, help="Max sequence length")

    # Agentic Data
    parser.add_argument("--use_agentic", action="store_true", help="Generate synthetic data before training")
    parser.add_argument("--task", type=str, default="", help="Task description for synthetic data")
    parser.add_argument("--num_synthetic", type=int, default=10, help="Number of synthetic records")
    parser.add_argument("--synthetic_path", type=str, default="synthetic_data.jsonl", help="Path to save synthetic data")
    parser.add_argument("--reasoning", action="store_true", help="Generate reasoning format")

    # Output/Hub
    parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory")
    parser.add_argument("--push", action="store_true", help="Push to HF Hub")
    parser.add_argument("--hub_id", type=str, default=None, help="HF Hub model ID")
    parser.add_argument("--no_compile", action="store_true", help="Disable Unsloth compilation for stability")
    parser.add_argument("--dtype", type=str, choices=["bfloat16", "float16", "float32"], default="bfloat16", help="Torch dtype")
    parser.add_argument("--load_8bit", action="store_true", help="Load in 8-bit")
    parser.add_argument("--no_4bit", action="store_true", help="Disable 4-bit loading")

    args = parser.parse_args()
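    # Example invocations (assuming this file is saved as trainer.py; the dataset
    # path and Hub ID below are placeholders):
    #   python trainer.py --method sft --dataset data.jsonl
    #   python trainer.py --method grpo --use_agentic --task "Extract entities from news" \
    #       --push --hub_id your-username/qwen-grpo-demo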
    config = TrainerConfig(
        model_name=args.model,
        dataset_name=args.dataset,
        method=args.method,
        learning_rate=args.lr,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_acc,
        max_seq_length=args.max_seq_len,
        use_agentic_data=args.use_agentic,
        task_description=args.task,
        num_synthetic_records=args.num_synthetic,
        synthetic_data_path=args.synthetic_path,
        generate_reasoning=args.reasoning,
        output_dir=args.output_dir,
        push_to_hub=args.push,
        hub_model_id=args.hub_id,
        use_compile=not args.no_compile,
        torch_dtype=args.dtype,
        load_in_8bit=args.load_8bit,
        load_in_4bit=not args.no_4bit,
    )

    trainer = QwenTrainer(config)
    trainer.run()
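# Programmatic usage (a minimal sketch; "my_data.jsonl" is a placeholder path and
# the module name `trainer` is an assumption about how this file is saved):
#   from trainer import TrainerConfig, QwenTrainer
#   config = TrainerConfig(dataset_name="my_data.jsonl", method="sft", push_to_hub=False)
#   QwenTrainer(config).run()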