"""Fine-tune Qwen2.5-Coder-7B with the Zenith framework."""

import argparse
import sys
from pathlib import Path

import torch

# Make the sibling packages (configs, models, data, training) importable
# when this script is run directly.
sys.path.append(str(Path(__file__).parent))

from configs.zenith_config import get_7b_config
from models.zenith_model import ZenithForCausalLM
from data.advanced_tokenizer import AdvancedTokenizer
from data.openthoughts_processor import OpenThoughtsProcessor, OpenThoughtsConfig
from training.trainer import Trainer, TrainerConfig


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="Fine-tune Qwen2.5-Coder-7B")

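    # Model and data paths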
    parser.add_argument(
        "--base_model",
        type=str,
        default="Qwen/Qwen2.5-Coder-7B",
        help="Base model from Hugging Face"
    )
    parser.add_argument(
        "--train_data",
        type=str,
        required=True,
        help="Path to training data (JSON or Hugging Face dataset)"
    )
    parser.add_argument(
        "--eval_data",
        type=str,
        default=None,
        help="Path to evaluation data"
    )

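    # Training hyperparameters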
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Per-device batch size"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Gradient accumulation steps"
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-5,
        help="Learning rate"
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.01,
        help="Weight decay"
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=100,
        help="Warmup steps"
    )
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=1.0,
        help="Gradient clipping norm"
    )

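    # LoRA / QLoRA options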
    parser.add_argument(
        "--use_lora",
        action="store_true",
        help="Use LoRA for parameter-efficient fine-tuning"
    )
    parser.add_argument(
        "--lora_r",
        type=int,
        default=16,
        help="LoRA rank"
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=32,
        help="LoRA alpha"
    )
    parser.add_argument(
        "--lora_dropout",
        type=float,
        default=0.1,
        help="LoRA dropout"
    )
    parser.add_argument(
        "--use_qlora",
        action="store_true",
        help="Use 4-bit quantization with LoRA (QLoRA)"
    )

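    # Zenith-specific options: EQ adapter and mixture-of-experts layers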
    parser.add_argument(
        "--use_eq_adapter",
        action="store_true",
        help="Enable EQ adapter during fine-tuning"
    )
    parser.add_argument(
        "--eq_loss_weight",
        type=float,
        default=0.1,
        help="Weight for EQ auxiliary loss"
    )
    parser.add_argument(
        "--use_moe",
        action="store_true",
        help="Enable MoE layers"
    )
    parser.add_argument(
        "--num_experts",
        type=int,
        default=8,
        help="Number of MoE experts"
    )

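    # Data processing options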
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=8192,
        help="Maximum sequence length"
    )
    parser.add_argument(
        "--use_quality_filter",
        action="store_true",
        help="Apply quality filtering"
    )
    parser.add_argument(
        "--use_curriculum",
        action="store_true",
        help="Use curriculum learning"
    )
    parser.add_argument(
        "--use_augmentation",
        action="store_true",
        help="Use data augmentation"
    )

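    # Output, logging, and checkpoint frequency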
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./outputs/zenith-7b-finetuned",
        help="Output directory for checkpoints"
    )
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=10,
        help="Logging frequency"
    )
    parser.add_argument(
        "--eval_steps",
        type=int,
        default=500,
        help="Evaluation frequency"
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=1000,
        help="Checkpoint save frequency"
    )

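    # Mixed precision and checkpoint resumption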
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help="Mixed precision training"
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Resume from checkpoint"
    )

    return parser.parse_args()


def main():
    args = parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

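    # Load the Zenith 7B preset config and apply the CLI overrides.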
    config = get_7b_config()

    config.use_lora = args.use_lora
    config.lora_r = args.lora_r
    config.lora_alpha = args.lora_alpha
    config.lora_dropout = args.lora_dropout
    config.use_qlora = args.use_qlora
    config.use_eq_adapter = args.use_eq_adapter
    config.eq_loss_weight = args.eq_loss_weight
    config.use_moe = args.use_moe
    config.num_experts = args.num_experts
    config.max_seq_len = args.max_seq_length

| print(f"Loading tokenizer from {args.base_model}...")
|
| tokenizer = AdvancedTokenizer.from_pretrained(args.base_model)
|
| config.vocab_size = tokenizer.get_vocab_size()
|
|
|
|
|
| print(f"Loading base model: {args.base_model}")
|
| model = ZenithForCausalLM.from_pretrained(
|
| args.base_model,
|
| config=config,
|
| device_map="auto"
|
| )
|
|
|
|
|
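    # Configure the streaming OpenThoughts data pipeline.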
    ot_config = OpenThoughtsConfig(
        dataset_name="open-thoughts/OpenThoughts3-1.2M",
        streaming=True,
        max_seq_length=args.max_seq_length,
        quality_filtering=args.use_quality_filter,
        curriculum_learning=args.use_curriculum,
        augmentation=args.use_augmentation,
        tokenizer=tokenizer
    )
    data_processor = OpenThoughtsProcessor(ot_config)

    print("Loading and preprocessing datasets...")
    train_dataset = data_processor.load_dataset(args.train_data, split="train")
    eval_dataset = None
    if args.eval_data:
        eval_dataset = data_processor.load_dataset(args.eval_data, split="validation")

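    # Build dataloaders; only the training split is shuffled.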
    train_dataloader = data_processor.create_dataloader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True
    )
    eval_dataloader = None
    if eval_dataset is not None:
        eval_dataloader = data_processor.create_dataloader(
            eval_dataset,
            batch_size=args.batch_size,
            shuffle=False
        )

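    # Map the CLI arguments onto the trainer configuration.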
    trainer_config = TrainerConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps=args.warmup_steps,
        max_grad_norm=args.max_grad_norm,
        mixed_precision=args.mixed_precision,
        logging_steps=args.logging_steps,
        eval_steps=args.eval_steps,
        save_steps=args.save_steps,
        save_total_limit=3,
        report_to="tensorboard",
        run_name=f"zenith-7b-finetune-{args.base_model.split('/')[-1]}"
    )

    print("Initializing trainer...")
    trainer = Trainer(
        model=model,
        config=trainer_config,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        tokenizer=tokenizer
    )

    if args.resume_from_checkpoint:
        print(f"Resuming from checkpoint: {args.resume_from_checkpoint}")
        trainer.load_checkpoint(args.resume_from_checkpoint)

    print("Starting fine-tuning...")
    trainer.train()

    print(f"Saving final model to {output_dir}")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    print("Fine-tuning complete!")


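# Example invocation (illustrative only; the script filename and data path are
# placeholders, not part of this repository):
#
#   python finetune_zenith_7b.py \
#       --train_data data/openthoughts_train.json \
#       --use_lora --use_qlora \
#       --batch_size 2 --gradient_accumulation_steps 8 \
#       --output_dir ./outputs/zenith-7b-finetuned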
if __name__ == "__main__":
    main()