#!/usr/bin/env python3
"""Fine-tune Qwen2.5-Coder-7B with the Zenith framework."""
import argparse
import sys
from pathlib import Path

# Make the sibling packages importable when running this file directly.
sys.path.append(str(Path(__file__).parent))

from configs.zenith_config import get_7b_config
from models.zenith_model import ZenithForCausalLM
from data.openthoughts_processor import OpenThoughtsProcessor, OpenThoughtsConfig
from data.advanced_tokenizer import AdvancedTokenizer
from training.trainer import Trainer, TrainerConfig

def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="Fine-tune Qwen2.5-Coder-7B")

    # Model and data
    parser.add_argument(
        "--base_model",
        type=str,
        default="Qwen/Qwen2.5-Coder-7B",
        help="Base model from Hugging Face",
    )
    parser.add_argument(
        "--train_data",
        type=str,
        required=True,
        help="Path to training data (JSON file or Hugging Face dataset)",
    )
    parser.add_argument(
        "--eval_data",
        type=str,
        default=None,
        help="Path to evaluation data",
    )

    # Training hyperparameters
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Per-device batch size",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of gradient accumulation steps",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-5,
        help="Peak learning rate",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.01,
        help="Weight decay",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=100,
        help="Number of warmup steps",
    )
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=1.0,
        help="Max gradient norm for clipping",
    )

    # LoRA/QLoRA settings
    parser.add_argument(
        "--use_lora",
        action="store_true",
        help="Use LoRA for parameter-efficient fine-tuning",
    )
    parser.add_argument(
        "--lora_r",
        type=int,
        default=16,
        help="LoRA rank",
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=32,
        help="LoRA alpha (scaling factor)",
    )
    parser.add_argument(
        "--lora_dropout",
        type=float,
        default=0.1,
        help="LoRA dropout",
    )
    parser.add_argument(
        "--use_qlora",
        action="store_true",
        help="Use 4-bit quantization with LoRA (QLoRA)",
    )

    # Zenith features
    parser.add_argument(
        "--use_eq_adapter",
        action="store_true",
        help="Enable the EQ adapter during fine-tuning",
    )
    parser.add_argument(
        "--eq_loss_weight",
        type=float,
        default=0.1,
        help="Weight for the EQ auxiliary loss",
    )
    parser.add_argument(
        "--use_moe",
        action="store_true",
        help="Enable MoE layers",
    )
    parser.add_argument(
        "--num_experts",
        type=int,
        default=8,
        help="Number of MoE experts",
    )

    # Data processing
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=8192,
        help="Maximum sequence length",
    )
    parser.add_argument(
        "--use_quality_filter",
        action="store_true",
        help="Apply quality filtering to the training data",
    )
    parser.add_argument(
        "--use_curriculum",
        action="store_true",
        help="Use curriculum learning",
    )
    parser.add_argument(
        "--use_augmentation",
        action="store_true",
        help="Use data augmentation",
    )

    # Output and logging
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./outputs/zenith-7b-finetuned",
        help="Output directory for checkpoints",
    )
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=10,
        help="Log every N steps",
    )
    parser.add_argument(
        "--eval_steps",
        type=int,
        default=500,
        help="Evaluate every N steps",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=1000,
        help="Save a checkpoint every N steps",
    )

    # Mixed precision
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help="Mixed-precision mode",
    )

    # Resume training
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Path to a checkpoint to resume training from",
    )

    return parser.parse_args()

def main():
    args = parse_args()

    # Create the output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load the base configuration and apply fine-tuning-specific settings
    config = get_7b_config()
    config.use_lora = args.use_lora
    config.lora_r = args.lora_r
    config.lora_alpha = args.lora_alpha
    config.lora_dropout = args.lora_dropout
    config.use_qlora = args.use_qlora
    config.use_eq_adapter = args.use_eq_adapter
    config.eq_loss_weight = args.eq_loss_weight
    config.use_moe = args.use_moe
    config.num_experts = args.num_experts
    config.max_seq_len = args.max_seq_length

    # Initialize the tokenizer from the base model
    print(f"Loading tokenizer from {args.base_model}...")
    tokenizer = AdvancedTokenizer.from_pretrained(args.base_model)
    config.vocab_size = tokenizer.get_vocab_size()

    # Load the base model
    print(f"Loading base model: {args.base_model}")
    model = ZenithForCausalLM.from_pretrained(
        args.base_model,
        config=config,
        device_map="auto",
    )

    # Prepare the OpenThoughts data processor
    ot_config = OpenThoughtsConfig(
        dataset_name="open-thoughts/OpenThoughts3-1.2M",
        streaming=True,
        max_seq_length=args.max_seq_length,
        quality_filtering=args.use_quality_filter,
        curriculum_learning=args.use_curriculum,
        augmentation=args.use_augmentation,
        tokenizer=tokenizer,
    )
    data_processor = OpenThoughtsProcessor(ot_config)

    # Load datasets
    print("Loading and preprocessing datasets...")
    train_dataset = data_processor.load_dataset(args.train_data, split="train")
    eval_dataset = None
    if args.eval_data:
        eval_dataset = data_processor.load_dataset(args.eval_data, split="validation")

    # Create dataloaders
    train_dataloader = data_processor.create_dataloader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
    )
    eval_dataloader = None
    if eval_dataset is not None:
        eval_dataloader = data_processor.create_dataloader(
            eval_dataset,
            batch_size=args.batch_size,
            shuffle=False,
        )

    # Trainer configuration
    trainer_config = TrainerConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps=args.warmup_steps,
        max_grad_norm=args.max_grad_norm,
        mixed_precision=args.mixed_precision,
        logging_steps=args.logging_steps,
        eval_steps=args.eval_steps,
        save_steps=args.save_steps,
        save_total_limit=3,
        report_to="tensorboard",
        run_name=f"zenith-7b-finetune-{args.base_model.split('/')[-1]}",
    )

    # Initialize the trainer
    print("Initializing trainer...")
    trainer = Trainer(
        model=model,
        config=trainer_config,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        tokenizer=tokenizer,
    )

    # Resume from a checkpoint if one was specified
    if args.resume_from_checkpoint:
        print(f"Resuming from checkpoint: {args.resume_from_checkpoint}")
        trainer.load_checkpoint(args.resume_from_checkpoint)

    # Start training
    print("Starting fine-tuning...")
    trainer.train()

    # Save the final model and tokenizer
    print(f"Saving final model to {output_dir}")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print("Fine-tuning complete!")


if __name__ == "__main__":
    main()
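
# Example invocation (illustrative only: the data path is hypothetical, all
# flags except --train_data are optional, and the effective batch size here
# would be batch_size * gradient_accumulation_steps = 32 per device):
#
#   python finetune_qwen.py \
#       --train_data data/train.json \
#       --use_lora --use_qlora \
#       --batch_size 4 --gradient_accumulation_steps 8 \
#       --mixed_precision bf16 \
#       --output_dir ./outputs/zenith-7b-finetuned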