import os

# Disable Unsloth's torch.compile integration. The env var is presumably read
# at unsloth import time, so it must be set before unsloth is imported
# anywhere below — TODO confirm against the installed unsloth version.
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"


import torch


# Compatibility shim: some dependencies still reference
# transformers.utils.hub.TRANSFORMERS_CACHE, which is absent in some
# transformers versions; restore it with the default HF hub cache path.
import transformers.utils.hub
if not hasattr(transformers.utils.hub, "TRANSFORMERS_CACHE"):
    transformers.utils.hub.TRANSFORMERS_CACHE = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")
|
|
| import pandas as pd |
| from dataclasses import dataclass, field |
| from typing import Optional, List, Literal, Dict, Any |
| from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig, GRPOTrainer, GRPOConfig |
| from datasets import load_dataset, Dataset |
| from transformers import TrainingArguments |
| from huggingface_hub import HfApi |
| from agentic_data_gen import AgenticDataGenerator, AgenticDataConfig |
|
|
@dataclass
class TrainerConfig:
    """Configuration for QwenTrainer.

    Groups model/quantization settings, LoRA hyperparameters, optimizer
    schedule, Hub-push options, agentic synthetic-data generation options,
    and GRPO-specific knobs.
    """

    # --- model, data, and training method ---
    model_name: str = "Qwen/Qwen2.5-7B"
    dataset_name: str = ""  # HF dataset name or local file path
    method: Literal["sft", "dpo", "grpo"] = "sft"
    platform: Literal["kaggle", "hf_jobs", "local"] = "local"
    max_seq_length: int = 4096
    # Quantization flags; 4-bit takes precedence over 8-bit in setup_model().
    load_in_4bit: bool = True
    load_in_8bit: bool = False
    torch_dtype: str = "bfloat16"  # "bfloat16" | "float16" | anything else -> float32

    # --- LoRA hyperparameters ---
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.0

    # --- optimizer / schedule ---
    learning_rate: float = 2e-4
    per_device_train_batch_size: int = 4
    gradient_accumulation_steps: int = 2
    num_train_epochs: int = 3

    # --- output & Hub ---
    output_dir: str = "outputs"
    push_to_hub: bool = True
    hub_model_id: Optional[str] = None
    # Read HF_TOKEN lazily at instantiation time. A plain default expression
    # would be evaluated once at import time and miss tokens exported later.
    hf_token: Optional[str] = field(default_factory=lambda: os.environ.get("HF_TOKEN"))

    # --- agentic synthetic-data generation ---
    use_agentic_data: bool = False
    task_description: str = ""
    num_synthetic_records: int = 10
    synthetic_data_path: str = "synthetic_data.jsonl"
    generate_reasoning: bool = False

    # --- GRPO-specific ---
    num_generations: int = 4  # completions sampled per prompt
    max_completion_length: int = 512
    max_prompt_length: int = 512
    use_compile: bool = False
|
|
class QwenTrainer:
    """Unified fine-tuning driver for Qwen models supporting SFT, DPO and GRPO.

    Lifecycle (see ``run``): load the base model and attach LoRA adapters,
    load or synthesize the training dataset, train with the configured
    method, then optionally merge the adapters and push to the HF Hub.
    """

    def __init__(self, config: TrainerConfig):
        self.config = config
        self.model = None      # set by setup_model()
        self.tokenizer = None  # set by setup_model()

    def setup_model(self):
        """Load the base model + tokenizer and attach LoRA adapters.

        GRPO uses vanilla transformers + peft (per the original comment, for
        stability); SFT and DPO go through Unsloth's FastLanguageModel.
        """
        print(f"Loading model: {self.config.model_name}")

        # Resolve the configured dtype string to a torch dtype; anything
        # other than bfloat16/float16 falls back to float32.
        if self.config.torch_dtype == "bfloat16":
            dtype = torch.bfloat16
        elif self.config.torch_dtype == "float16":
            dtype = torch.float16
        else:
            dtype = torch.float32

        if self.config.method == "grpo":
            print(f"Using standard transformers + peft for GRPO stability (dtype: {self.config.torch_dtype})")
            from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
            from peft import LoraConfig, get_peft_model

            # 4-bit takes precedence over 8-bit when both flags are set.
            bnb_config = None
            if self.config.load_in_4bit:
                print("Loading in 4-bit quantization")
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=dtype,
                    bnb_4bit_use_double_quant=True,
                )
            elif self.config.load_in_8bit:
                print("Loading in 8-bit quantization")
                bnb_config = BitsAndBytesConfig(
                    load_in_8bit=True,
                )
            else:
                print(f"Loading in full {self.config.torch_dtype}")

            self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config.model_name,
                quantization_config=bnb_config,
                torch_dtype=dtype,
                device_map="auto",
            )

            # Force norm / embedding / lm_head modules to the requested
            # compute dtype — quantized loading can leave them mixed.
            print(f"Ensuring non-quantized layers are in {self.config.torch_dtype}")
            for name, module in self.model.named_modules():
                if "norm" in name.lower() or "lm_head" in name.lower() or "embed" in name.lower():
                    module.to(dtype)

            peft_config = LoraConfig(
                r=self.config.lora_r,
                lora_alpha=self.config.lora_alpha,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                "gate_proj", "up_proj", "down_proj"],
                lora_dropout=self.config.lora_dropout,
                bias="none",
                task_type="CAUSAL_LM",
            )
            self.model = get_peft_model(self.model, peft_config)

            # NOTE(review): presumably TRL expects a `warnings_issued` dict on
            # the model and some peft wrappers lack it — confirm still needed.
            if not hasattr(self.model, "warnings_issued"):
                self.model.warnings_issued = {}

        else:
            from unsloth import FastLanguageModel
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=self.config.model_name,
                max_seq_length=self.config.max_seq_length,
                load_in_4bit=self.config.load_in_4bit,
                dtype=dtype,
            )

            print("Attaching LoRA via Unsloth")
            self.model = FastLanguageModel.get_peft_model(
                self.model,
                r=self.config.lora_r,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                "gate_proj", "up_proj", "down_proj"],
                lora_alpha=self.config.lora_alpha,
                lora_dropout=self.config.lora_dropout,
                bias="none",
                random_state=3407,
            )

        # Some base models ship without a pad token; fall back to EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def generate_agentic_data(self):
        """Synthesize training data via the agentic data generator.

        Returns:
            A pandas DataFrame of generated records.

        Raises:
            ValueError: when no record clears the quality threshold.
        """
        print(f"Generating agentic synthetic data for task: {self.config.task_description}")
        gen_config = AgenticDataConfig(
            num_records=self.config.num_synthetic_records,
            task_description=self.config.task_description,
            output_path=self.config.synthetic_data_path,
            min_quality_score=2,
            generate_dpo=(self.config.method == "dpo"),
            # GRPO always wants reasoning traces; SFT/DPO only on request.
            generate_reasoning=(self.config.method == "grpo" or self.config.generate_reasoning)
        )
        generator = AgenticDataGenerator()
        df = generator.generate(gen_config)

        if df.empty:
            raise ValueError("No records passed the quality threshold during agentic data generation. Try a different task description or lower min_quality_score.")

        return df

    def load_data(self):
        """Load (or synthesize) the dataset and reshape columns per method.

        Assumes 'instruction'/'output' columns (plus optional 'rejected' for
        DPO) when the method-specific columns are absent — TODO confirm for
        the datasets used in practice.
        """
        if self.config.use_agentic_data:
            df = self.generate_agentic_data()
            dataset = Dataset.from_pandas(df)
        else:
            print(f"Loading dataset: {self.config.dataset_name}")
            if os.path.exists(self.config.dataset_name):
                # Local file: pick the loader from the file extension.
                ext = self.config.dataset_name.split(".")[-1]
                if ext in ["jsonl", "json"]:
                    dataset = load_dataset("json", data_files=self.config.dataset_name, split="train")
                elif ext == "csv":
                    dataset = load_dataset("csv", data_files=self.config.dataset_name, split="train")
                elif ext == "parquet":
                    dataset = load_dataset("parquet", data_files=self.config.dataset_name, split="train")
                else:
                    dataset = load_dataset(self.config.dataset_name, split="train")
            else:
                dataset = load_dataset(self.config.dataset_name, split="train")

        if self.config.method == "sft":
            # SFT trains on a single ChatML-formatted 'text' field.
            def format_chatml(example):
                return {"text": f"<|im_start|>user\n{example['instruction']}<|im_end|>\n<|im_start|>assistant\n{example['output']}<|im_end|>"}
            dataset = dataset.map(format_chatml)
        elif self.config.method == "grpo":
            # GRPO only needs prompts; completions are sampled during training.
            if "prompt" not in dataset.column_names:
                print("Mapping 'instruction' to 'prompt' for GRPO")
                def map_prompt(example):
                    return {"prompt": example["instruction"]}
                dataset = dataset.map(map_prompt)
        elif self.config.method == "dpo":
            # DPO needs prompt/chosen/rejected triples; fall back to a stock
            # rejection string when the dataset lacks one.
            if "prompt" not in dataset.column_names:
                print("Mapping columns for DPO")
                def map_dpo(example):
                    return {
                        "prompt": example["instruction"],
                        "chosen": example["output"],
                        "rejected": example.get("rejected", "I don't know.")
                    }
                dataset = dataset.map(map_dpo)

        return dataset

    def run_sft(self, dataset):
        """Run supervised fine-tuning on the ChatML-formatted 'text' field."""
        print("Running SFT")
        trainer = SFTTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            train_dataset=dataset,
            dataset_text_field="text",
            max_seq_length=self.config.max_seq_length,
            args=SFTConfig(
                per_device_train_batch_size=self.config.per_device_train_batch_size,
                gradient_accumulation_steps=self.config.gradient_accumulation_steps,
                num_train_epochs=self.config.num_train_epochs,
                learning_rate=self.config.learning_rate,
                # Prefer bf16 on hardware that supports it, else fp16.
                fp16=not torch.cuda.is_bf16_supported(),
                bf16=torch.cuda.is_bf16_supported(),
                logging_steps=1,
                optim="adamw_8bit",
                weight_decay=0.01,
                lr_scheduler_type="linear",
                seed=3407,
                output_dir=self.config.output_dir,
            ),
        )
        trainer.train()

    def run_dpo(self, dataset):
        """Run Direct Preference Optimization on prompt/chosen/rejected triples."""
        print("Running DPO")
        trainer = DPOTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            train_dataset=dataset,
            args=DPOConfig(
                per_device_train_batch_size=self.config.per_device_train_batch_size,
                gradient_accumulation_steps=self.config.gradient_accumulation_steps,
                num_train_epochs=self.config.num_train_epochs,
                learning_rate=self.config.learning_rate,
                fp16=not torch.cuda.is_bf16_supported(),
                bf16=torch.cuda.is_bf16_supported(),
                logging_steps=1,
                optim="adamw_8bit",
                output_dir=self.config.output_dir,
            ),
        )
        trainer.train()

    def run_grpo(self, dataset):
        """Run GRPO training with the project's combined reward function."""
        print("Running GRPO")
        from rewards import RewardFunctions

        # The per-device batch must accommodate the num_generations sampled
        # completions per prompt.
        batch_size = max(self.config.per_device_train_batch_size, self.config.num_generations)

        # Workaround: bump 16-aligned completion lengths by one token
        # ("for stability" per the original). NOTE(review): confirm this is
        # still required with current library versions.
        max_comp = self.config.max_completion_length
        if max_comp % 16 == 0:
            max_comp += 1
            print(f"Adjusted max_completion_length to {max_comp} for stability")

        trainer = GRPOTrainer(
            model=self.model,
            args=GRPOConfig(
                per_device_train_batch_size=batch_size,
                # Honor the configured schedule; previously these config
                # fields were silently ignored and GRPOConfig defaults used.
                gradient_accumulation_steps=self.config.gradient_accumulation_steps,
                num_train_epochs=self.config.num_train_epochs,
                num_generations=self.config.num_generations,
                learning_rate=self.config.learning_rate,
                max_prompt_length=self.config.max_prompt_length,
                max_completion_length=max_comp,
                # Small KL coefficient keeps the policy near the reference.
                beta=0.01,
                warmup_steps=10,
                logging_steps=1,
                output_dir=self.config.output_dir,
                optim="adamw_8bit",
                seed=3407,
            ),
            reward_funcs=[RewardFunctions.combined_reward],
            train_dataset=dataset,
        )
        trainer.train()

    def save_and_push(self):
        """Merge LoRA adapters into the base weights and push to the Hub.

        No-op when push_to_hub is False.

        Raises:
            ValueError: when pushing is requested without a hub_model_id.
        """
        if not self.config.push_to_hub:
            return

        if not self.config.hub_model_id:
            # Fail early with a clear message instead of letting the Hub API
            # error out on a None repo_id after an expensive merge.
            raise ValueError("push_to_hub=True requires hub_model_id to be set")

        print(f"Saving and pushing to Hub: {self.config.hub_model_id}")
        if self.config.method != "grpo":
            # NOTE(review): presumably imported for its side effects on the
            # model classes (enables save_pretrained_merged) — confirm.
            from unsloth import FastLanguageModel

        if hasattr(self.model, "save_pretrained_merged"):
            # Unsloth path: merge and save in 16-bit in one call.
            self.model.save_pretrained_merged(
                "merged_model", self.tokenizer, save_method="merged_16bit"
            )
        else:
            print("Merging and saving standard PEFT model")
            merged_model = self.model.merge_and_unload()
            merged_model.save_pretrained("merged_model")
            self.tokenizer.save_pretrained("merged_model")

        api = HfApi()
        api.create_repo(repo_id=self.config.hub_model_id, token=self.config.hf_token, exist_ok=True)
        api.upload_folder(
            folder_path="merged_model",
            repo_id=self.config.hub_model_id,
            token=self.config.hf_token,
        )

    def run(self):
        """End-to-end entry point: setup, load data, train, save/push."""
        self.setup_model()
        dataset = self.load_data()

        if self.config.method == "sft":
            self.run_sft(dataset)
        elif self.config.method == "dpo":
            self.run_dpo(dataset)
        elif self.config.method == "grpo":
            self.run_grpo(dataset)

        self.save_and_push()
|
|
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Qwen Unified Trainer (SFT, DPO, GRPO)")

    # Model / dataset / training method.
    # Default matches TrainerConfig.model_name; the previous default
    # "Qwen/Qwen3.5-2B" does not correspond to a published repo.
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-7B", help="HF model ID")
    parser.add_argument("--dataset", type=str, default="", help="HF dataset name or local path")
    parser.add_argument("--method", type=str, choices=["sft", "dpo", "grpo"], default="sft", help="Training method")

    # Optimization hyperparameters.
    parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size per device")
    parser.add_argument("--grad_acc", type=int, default=2, help="Gradient accumulation steps")
    # NOTE(review): CLI default (2048) differs from the TrainerConfig
    # default (4096) — confirm which is the intended source of truth.
    parser.add_argument("--max_seq_len", type=int, default=2048, help="Max sequence length")

    # Agentic synthetic-data generation.
    parser.add_argument("--use_agentic", action="store_true", help="Generate synthetic data before training")
    parser.add_argument("--task", type=str, default="", help="Task description for synthetic data")
    parser.add_argument("--num_synthetic", type=int, default=10, help="Number of synthetic records")
    parser.add_argument("--synthetic_path", type=str, default="synthetic_data.jsonl", help="Path to save synthetic data")
    parser.add_argument("--reasoning", action="store_true", help="Generate reasoning format")

    # Output / Hub / runtime flags.
    parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory")
    parser.add_argument("--push", action="store_true", help="Push to HF Hub")
    parser.add_argument("--hub_id", type=str, default=None, help="HF Hub model ID")
    parser.add_argument("--no_compile", action="store_true", help="Disable Unsloth compilation for stability")
    parser.add_argument("--dtype", type=str, choices=["bfloat16", "float16", "float32"], default="bfloat16", help="Torch dtype")
    parser.add_argument("--load_8bit", action="store_true", help="Load in 8-bit")
    parser.add_argument("--no_4bit", action="store_true", help="Disable 4-bit loading")

    args = parser.parse_args()

    config = TrainerConfig(
        model_name=args.model,
        dataset_name=args.dataset,
        method=args.method,
        learning_rate=args.lr,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_acc,
        max_seq_length=args.max_seq_len,
        use_agentic_data=args.use_agentic,
        task_description=args.task,
        num_synthetic_records=args.num_synthetic,
        synthetic_data_path=args.synthetic_path,
        generate_reasoning=args.reasoning,
        output_dir=args.output_dir,
        push_to_hub=args.push,
        hub_model_id=args.hub_id,
        use_compile=not args.no_compile,
        torch_dtype=args.dtype,
        load_in_8bit=args.load_8bit,
        # 4-bit wins over 8-bit in setup_model(), so an explicit --load_8bit
        # must also turn 4-bit off (previously --load_8bit alone was
        # silently ignored and 4-bit was still used).
        load_in_4bit=not args.no_4bit and not args.load_8bit,
    )

    trainer = QwenTrainer(config)
    trainer.run()
|
|