| import os |
| import torch |
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForSequenceClassification, |
| TrainingArguments, |
| Trainer |
| ) |
| from .config import Config |
| from .dataset import DataProcessor |
| from .metrics import compute_metrics |
| from .visualization import plot_training_history |
|
|
def _select_device():
    """Pick the torch device to train on and report the choice on stdout.

    MPS is preferred when available, then CUDA, otherwise CPU.
    """
    if torch.backends.mps.is_available():
        backend, label = "mps", "MPS (Mac Silicon Acceleration)"
    elif torch.cuda.is_available():
        backend, label = "cuda", "CUDA"
    else:
        backend, label = "cpu", "CPU"

    device = torch.device(backend)
    print(f"Using device: {label}")
    return device


def _dataset_worker_count():
    """Return how many worker processes to use for dataset preprocessing.

    One core is left free, and at least one worker is always returned.
    """
    # os.cpu_count() returns None when the core count cannot be determined;
    # fall back to a single core instead of failing on `None - 1`.
    detected_cores = os.cpu_count() or 1
    return max(1, detected_cores - 1)


def main():
    """Fine-tune a sequence-classification model using the settings in ``Config``.

    The run selects a device, loads the tokenizer and model named by
    ``Config.BASE_MODEL``, builds the processed dataset, trains with the
    Hugging Face ``Trainer``, saves the model and tokenizer to
    ``Config.CHECKPOINT_DIR``, and writes a training-curve plot to
    ``training_curves.png`` inside ``Config.RESULTS_DIR``.
    """
    device = _select_device()

    print(f"Loading tokenizer from {Config.BASE_MODEL}...")
    tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)

    print("Preparing datasets...")
    processor = DataProcessor(tokenizer)
    dataset = processor.get_processed_dataset(
        cache_dir=Config.DATA_DIR,
        num_proc=_dataset_worker_count(),
    )

    # The 'test' split is what gets evaluated during training.
    train_dataset = dataset['train']
    eval_dataset = dataset['test']

    print(f"Training on {len(train_dataset)} samples, Validating on {len(eval_dataset)} samples.")

    print("Loading model...")
    model = AutoModelForSequenceClassification.from_pretrained(
        Config.BASE_MODEL,
        num_labels=Config.NUM_LABELS,
        id2label=Config.ID2LABEL,
        label2id=Config.LABEL2ID,
    )
    model.to(device)

    training_args = TrainingArguments(
        output_dir=Config.RESULTS_DIR,
        num_train_epochs=Config.NUM_EPOCHS,
        per_device_train_batch_size=Config.BATCH_SIZE,
        per_device_eval_batch_size=Config.BATCH_SIZE,
        learning_rate=Config.LEARNING_RATE,
        warmup_ratio=Config.WARMUP_RATIO,
        weight_decay=Config.WEIGHT_DECAY,
        logging_dir=os.path.join(Config.RESULTS_DIR, 'logs'),
        logging_steps=Config.LOGGING_STEPS,
        eval_strategy="steps",
        eval_steps=Config.EVAL_STEPS,
        save_steps=Config.SAVE_STEPS,
        # NOTE(review): best-checkpoint selection presumably relies on
        # compute_metrics reporting an "f1" entry, and on SAVE_STEPS lining up
        # with EVAL_STEPS -- confirm against metrics.py and Config.
        load_best_model_at_end=True,
        metric_for_best_model="f1",
    )

    # NOTE(review): recent transformers releases deprecate `tokenizer=` in
    # favour of `processing_class=`; confirm against the pinned version
    # before switching.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {Config.CHECKPOINT_DIR}...")
    trainer.save_model(Config.CHECKPOINT_DIR)
    tokenizer.save_pretrained(Config.CHECKPOINT_DIR)

    print("Generating training plots...")
    plot_save_path = os.path.join(Config.RESULTS_DIR, 'training_curves.png')
    plot_training_history(trainer.state.log_history, save_path=plot_save_path)

    print("Training completed!")
|
|
# Run the training pipeline only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|