| import os |
| import logging |
| from dataclasses import dataclass |
| from typing import Optional, Tuple, List, Dict, Any |
| import time |
| import json |
| import pathlib |
| from tqdm import tqdm |
| import pandas as pd |
| import numpy as np |
| import argparse |
| import torch |
| from torch import nn |
| from torch.utils.data import DataLoader, Dataset |
| from transformers import ( |
| get_linear_schedule_with_warmup, |
| BertForSequenceClassification, |
| AutoTokenizer, |
| AdamW |
| ) |
| from sklearn.metrics import roc_auc_score |
|
|
| import traceback |
|
|
|
|
# Send INFO-and-above records both to ``training.log`` (relative to the
# current working directory) and to the console, using a timestamped format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.FileHandler("training.log"), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class TrainingConfig:
    """Hyper-parameters and paths for a single BERT fine-tuning run.

    Call :meth:`validate` before use; it rejects values the training loop
    cannot work with.
    """

    max_seq_len: int = 50  # tokens per example, including special tokens
    epochs: int = 3
    batch_size: int = 32
    learning_rate: float = 2e-5
    patience: int = 1  # epochs without dev-accuracy improvement before early stop
    max_grad_norm: float = 10.0  # gradient-norm clipping threshold
    warmup_ratio: float = 0.1  # fraction of total steps used for linear warm-up
    model_path: str = '/cpfs01/shared/MA4Tool/hug_ckpts/BERT_ckpt'
    num_labels: int = 2
    if_save_model: bool = True  # save a checkpoint whenever dev accuracy improves
    out_dir: str = './run_0'  # directory that receives final_info.json

    def validate(self) -> None:
        """Check that every field is inside its usable range.

        Raises:
            ValueError: if any field is out of range.
        """
        if self.max_seq_len <= 0:
            raise ValueError("max_seq_len must be positive")
        if self.epochs <= 0:
            raise ValueError("epochs must be positive")
        if self.batch_size <= 0:
            raise ValueError("batch_size must be positive")
        # The original check only enforced the lower bound even though the
        # message promises an open (0, 1) interval.
        if not 0.0 < self.learning_rate < 1.0:
            raise ValueError("learning_rate must be between 0 and 1")
        # Warm-up steps are computed as int(total_steps * warmup_ratio), so
        # the ratio must be a fraction of the schedule.
        if not 0.0 <= self.warmup_ratio <= 1.0:
            raise ValueError("warmup_ratio must be between 0 and 1")
        if self.max_grad_norm <= 0:
            raise ValueError("max_grad_norm must be positive")
        # Evaluation reads the positive-class column (index 1) of the
        # softmax output, which needs at least two labels.
        if self.num_labels < 2:
            raise ValueError("num_labels must be at least 2")
|
|
|
|
class DataPrecessForSentence(Dataset):
    """Pre-tokenize single sentences into fixed-length BERT input tensors.

    Reads the ``s1`` (text) and ``similarity`` (integer label) columns of
    ``df``. For every row it builds ``input_ids``, ``attention_mask``,
    ``token_type_ids`` and the label, each stored as a ``torch.long`` tensor.
    Each sequence is laid out as ``[CLS] tokens... [SEP]`` and right-padded
    to ``max_seq_len``.
    """

    def __init__(self, bert_tokenizer: "AutoTokenizer", df: pd.DataFrame, max_seq_len: int = 50):
        """
        Args:
            bert_tokenizer: tokenizer providing ``tokenize`` and
                ``convert_tokens_to_ids``. Its ``cls_token``, ``sep_token``
                and ``pad_token_id`` are used when defined.
            df: frame with ``s1`` (sentence) and ``similarity`` (label) columns.
            max_seq_len: total sequence length, including ``[CLS]`` and ``[SEP]``.

        Raises:
            ValueError: if ``max_seq_len`` cannot hold the two special tokens.
        """
        if max_seq_len < 2:
            raise ValueError("max_seq_len must be at least 2 to fit [CLS] and [SEP]")
        self.bert_tokenizer = bert_tokenizer
        self.max_seq_len = max_seq_len
        # Prefer the tokenizer's own special tokens / pad id; fall back to the
        # BERT conventions when the tokenizer does not define them.
        self.cls_token = getattr(bert_tokenizer, "cls_token", None) or "[CLS]"
        self.sep_token = getattr(bert_tokenizer, "sep_token", None) or "[SEP]"
        pad_id = getattr(bert_tokenizer, "pad_token_id", None)
        self.pad_token_id = 0 if pad_id is None else pad_id
        self.input_ids, self.attention_mask, self.token_type_ids, self.labels = self._get_input(df)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(input_ids, attention_mask, token_type_ids, label)`` for row ``idx``."""
        return (
            self.input_ids[idx],
            self.attention_mask[idx],
            self.token_type_ids[idx],
            self.labels[idx],
        )

    def _get_input(self, df: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Tokenize every sentence in ``df`` and stack the results into tensors."""
        sentences = df['s1'].values
        labels = df['similarity'].values

        encoded = [self._truncate_and_pad(self.bert_tokenizer.tokenize(s)) for s in sentences]

        # view(-1, max_seq_len) keeps the (N, max_seq_len) shape even for an
        # empty frame, where torch.tensor([]) would otherwise be 1-D.
        input_ids = torch.tensor([e[0] for e in encoded], dtype=torch.long).view(-1, self.max_seq_len)
        attention_mask = torch.tensor([e[1] for e in encoded], dtype=torch.long).view(-1, self.max_seq_len)
        token_type_ids = torch.tensor([e[2] for e in encoded], dtype=torch.long).view(-1, self.max_seq_len)
        labels = torch.tensor(labels, dtype=torch.long)

        return input_ids, attention_mask, token_type_ids, labels

    def _truncate_and_pad(self, tokens_seq: List[str]) -> Tuple[List[int], List[int], List[int]]:
        """Build ``[CLS] tokens [SEP]``, truncate to fit, then right-pad.

        Returns:
            ``(input_ids, attention_mask, token_type_ids)``, each of length
            ``max_seq_len``.
        """
        # Reserve two positions for the special tokens (max_seq_len >= 2 is
        # enforced in __init__, so this slice bound is never negative).
        body = tokens_seq[:self.max_seq_len - 2]
        tokens = [self.cls_token] + body + [self.sep_token]
        padding_length = self.max_seq_len - len(tokens)

        input_ids = list(self.bert_tokenizer.convert_tokens_to_ids(tokens))
        input_ids += [self.pad_token_id] * padding_length
        attention_mask = [1] * len(tokens) + [0] * padding_length
        # Single-sentence input: every position belongs to segment 0.
        token_type_ids = [0] * self.max_seq_len

        return input_ids, attention_mask, token_type_ids
|
|
|
|
class BertClassifier(nn.Module):
    """Wrap a Hugging Face ``BertForSequenceClassification`` and its tokenizer.

    ``forward`` returns ``(loss, logits, probabilities)``, where
    ``probabilities`` is the softmax of ``logits`` over the last dimension.
    """

    def __init__(self, model_path: str, num_labels: int, requires_grad: bool = True):
        """
        Args:
            model_path: local path or hub id passed to ``from_pretrained``.
            num_labels: number of output classes of the classification head.
            requires_grad: when False, every parameter of the wrapped model,
                classification head included, is frozen.
        """
        super().__init__()
        try:
            self.bert = BertForSequenceClassification.from_pretrained(
                model_path,
                num_labels=num_labels
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        except Exception as e:
            logger.error(f"Failed to load BERT model: {e}")
            raise

        # Kept for callers that read it; forward() itself does not use it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        for param in self.bert.parameters():
            param.requires_grad = requires_grad

    def forward(
        self,
        batch_seqs: torch.Tensor,
        batch_seq_masks: torch.Tensor,
        batch_seq_segments: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
        """Run the model and return ``(loss, logits, probabilities)``.

        ``labels`` is optional so the wrapper can also be used for inference;
        in that case the returned loss is ``None``.
        """
        # Read the named output fields instead of slicing a tuple: with no
        # labels the loss is absent, so positional unpacking would fail.
        outputs = self.bert(
            input_ids=batch_seqs,
            attention_mask=batch_seq_masks,
            token_type_ids=batch_seq_segments,
            labels=labels,
            return_dict=True,
        )
        logits = outputs.logits
        probabilities = nn.functional.softmax(logits, dim=-1)
        return outputs.loss, logits, probabilities
|
|
|
|
class BertTrainer:
    """Fine-tune, validate, checkpoint and predict with a :class:`BertClassifier`.

    The model is placed on CUDA when available, otherwise on CPU.
    """

    def __init__(self, config: TrainingConfig):
        self.config = config
        self.config.validate()
        self.model = BertClassifier(config.model_path, config.num_labels)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def _make_loader(self, df: pd.DataFrame, shuffle: bool) -> DataLoader:
        """Tokenize ``df`` and wrap it in a DataLoader using the configured batch size."""
        dataset = DataPrecessForSentence(
            self.model.tokenizer,
            df,
            max_seq_len=self.config.max_seq_len
        )
        return DataLoader(dataset, shuffle=shuffle, batch_size=self.config.batch_size)

    def _prepare_data(
        self,
        train_df: pd.DataFrame,
        dev_df: pd.DataFrame,
        test_df: pd.DataFrame
    ) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """Build the train loader (shuffled) and the dev/test loaders (in order)."""
        train_loader = self._make_loader(train_df, shuffle=True)
        dev_loader = self._make_loader(dev_df, shuffle=False)
        test_loader = self._make_loader(test_df, shuffle=False)
        return train_loader, dev_loader, test_loader

    def _prepare_optimizer(self, num_training_steps: int) -> Tuple[AdamW, Any]:
        """Create AdamW plus a linear warm-up/decay schedule.

        Bias and LayerNorm parameters are excluded from weight decay; all
        other parameters use a weight decay of 0.01.
        """
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                'weight_decay': 0.01
            },
            {
                'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            }
        ]

        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.config.learning_rate
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(num_training_steps * self.config.warmup_ratio),
            num_training_steps=num_training_steps
        )

        return optimizer, scheduler

    def _initialize_training_stats(self) -> Dict[str, List]:
        """Return empty per-epoch history lists."""
        return {
            'epochs_count': [],
            'train_losses': [],
            'train_accuracies': [],
            'valid_losses': [],
            'valid_accuracies': [],
            'valid_aucs': []
        }

    def _update_training_stats(
        self,
        training_stats: Dict[str, List],
        epoch: int,
        train_metrics: Dict[str, float],
        val_metrics: Dict[str, float]
    ) -> None:
        """Append one epoch's metrics to ``training_stats`` in place and log them."""
        training_stats['epochs_count'].append(epoch)
        training_stats['train_losses'].append(train_metrics['loss'])
        training_stats['train_accuracies'].append(train_metrics['accuracy'])
        training_stats['valid_losses'].append(val_metrics['loss'])
        training_stats['valid_accuracies'].append(val_metrics['accuracy'])
        training_stats['valid_aucs'].append(val_metrics['auc'])

        logger.info(
            f"Training - Loss: {train_metrics['loss']:.4f}, "
            f"Accuracy: {train_metrics['accuracy'] * 100:.2f}%"
        )
        logger.info(
            f"Validation - Loss: {val_metrics['loss']:.4f}, "
            f"Accuracy: {val_metrics['accuracy'] * 100:.2f}%, "
            f"AUC: {val_metrics['auc']:.4f}"
        )

    def _save_checkpoint(
        self,
        target_dir: str,
        epoch: int,
        optimizer: AdamW,
        best_score: float,
        training_stats: Dict[str, List],
        scheduler: Any = None
    ) -> None:
        """Write ``best.pth.tar`` into ``target_dir``, overwriting any previous one.

        The scheduler state is included when ``scheduler`` is given, so a
        resumed run continues the learning-rate schedule instead of
        restarting warm-up.
        """
        checkpoint = {
            "epoch": epoch,
            "model": self.model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "best_score": best_score,
            **training_stats
        }
        if scheduler is not None:
            checkpoint["scheduler"] = scheduler.state_dict()
        torch.save(checkpoint, os.path.join(target_dir, "best.pth.tar"))
        logger.info("Model saved successfully")

    def _load_checkpoint(
        self,
        checkpoint_path: str,
        optimizer: AdamW,
        training_stats: Dict[str, List],
        scheduler: Any = None
    ) -> float:
        """Restore model, optimizer, optional scheduler and history from a checkpoint.

        ``training_stats`` is updated in place with the stored history.

        Returns:
            The best dev score stored in the checkpoint.
        """
        # map_location lets a checkpoint written on GPU be loaded on a
        # CPU-only machine (and vice versa).
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        # Older checkpoints may lack scheduler state; resume without it then.
        if scheduler is not None and "scheduler" in checkpoint:
            scheduler.load_state_dict(checkpoint["scheduler"])
        for key in training_stats:
            training_stats[key] = checkpoint[key]
        logger.info(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
        return checkpoint["best_score"]

    def _train_epoch(
        self,
        train_loader: DataLoader,
        optimizer: AdamW,
        scheduler: Any
    ) -> Dict[str, float]:
        """Run one optimization pass over ``train_loader``.

        Returns:
            ``loss``: mean of the per-batch losses.
            ``accuracy``: fraction of training examples predicted correctly.
        """
        self.model.train()
        total_loss = 0.0
        correct_preds = 0

        for batch in tqdm(train_loader, desc="Training"):
            input_ids, attention_mask, token_type_ids, labels = (t.to(self.device) for t in batch)

            optimizer.zero_grad()
            loss, _, probabilities = self.model(input_ids, attention_mask, token_type_ids, labels)

            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            correct_preds += (probabilities.argmax(dim=1) == labels).sum().item()

        return {
            'loss': total_loss / len(train_loader),
            'accuracy': correct_preds / len(train_loader.dataset)
        }

    def _validate_epoch(self, dev_loader: DataLoader) -> Tuple[Dict[str, float], List[float]]:
        """Evaluate on ``dev_loader`` without gradient tracking.

        Returns:
            A metrics dict with ``loss``, ``accuracy`` and ``auc``, plus the
            per-example probability of class 1, in loader order. ``auc`` is
            NaN when the labels contain a single class, because ROC AUC is
            undefined there.
        """
        self.model.eval()
        total_loss = 0.0
        correct_preds = 0
        all_probs = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(dev_loader, desc="Validating"):
                input_ids, attention_mask, token_type_ids, labels = (t.to(self.device) for t in batch)

                loss, _, probabilities = self.model(input_ids, attention_mask, token_type_ids, labels)

                total_loss += loss.item()
                correct_preds += (probabilities.argmax(dim=1) == labels).sum().item()
                all_probs.extend(probabilities[:, 1].cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        if len(set(all_labels)) < 2:
            # roc_auc_score raises ValueError when only one class is present
            # (e.g. a test split with placeholder labels).
            logger.warning("AUC is undefined when only one class is present; reporting NaN")
            auc = float("nan")
        else:
            # Cast to a plain float: the value ends up in the checkpoint, and
            # NumPy scalars are rejected by weights_only torch.load.
            auc = float(roc_auc_score(all_labels, all_probs))

        metrics = {
            'loss': total_loss / len(dev_loader),
            'accuracy': correct_preds / len(dev_loader.dataset),
            'auc': auc
        }

        return metrics, all_probs

    def _evaluate_test_set(
        self,
        test_loader: DataLoader,
        target_dir: str,
        epoch: int
    ) -> None:
        """Score the test set and write ``test_prediction_epoch_<epoch>.csv`` to ``target_dir``."""
        test_metrics, all_probs = self._validate_epoch(test_loader)
        logger.info(f"Test accuracy: {test_metrics['accuracy'] * 100:.2f}%")

        test_prediction = pd.DataFrame({'prob_1': all_probs})
        test_prediction['prob_0'] = 1 - test_prediction['prob_1']
        # Vectorized version of the row-wise rule; ties resolve to class 1.
        test_prediction['prediction'] = np.where(
            test_prediction['prob_0'] > test_prediction['prob_1'], 0, 1
        )

        output_path = os.path.join(target_dir, f"test_prediction_epoch_{epoch}.csv")
        test_prediction.to_csv(output_path, index=False)
        logger.info(f"Test predictions saved to {output_path}")

    def train_and_evaluate(
        self,
        train_df: pd.DataFrame,
        dev_df: pd.DataFrame,
        test_df: pd.DataFrame,
        target_dir: str,
        checkpoint: Optional[str] = None
    ) -> None:
        """Train with early stopping on dev accuracy and record the results.

        Whenever dev accuracy improves, the trainer optionally saves a
        checkpoint and writes test-set predictions to ``target_dir``. At the
        end it writes ``final_info.json`` (best dev accuracy) into
        ``config.out_dir``.

        Args:
            train_df, dev_df, test_df: frames with ``s1``/``similarity`` columns.
            target_dir: directory for checkpoints and prediction CSVs.
            checkpoint: optional checkpoint path to resume from. Training
                continues after the epoch stored in it.
        """
        try:
            os.makedirs(target_dir, exist_ok=True)

            train_loader, dev_loader, test_loader = self._prepare_data(
                train_df, dev_df, test_df
            )

            optimizer, scheduler = self._prepare_optimizer(
                len(train_loader) * self.config.epochs
            )

            training_stats = self._initialize_training_stats()
            best_score = 0.0
            patience_counter = 0
            start_epoch = 1

            if checkpoint:
                best_score = self._load_checkpoint(checkpoint, optimizer, training_stats, scheduler)
                # The restored history ends with the epoch the checkpoint was
                # taken at. Continue after it instead of re-running (and
                # re-appending) epochs that are already recorded.
                if training_stats['epochs_count']:
                    start_epoch = training_stats['epochs_count'][-1] + 1

            for epoch in range(start_epoch, self.config.epochs + 1):
                logger.info(f"Training epoch {epoch}")

                train_metrics = self._train_epoch(train_loader, optimizer, scheduler)
                val_metrics, _ = self._validate_epoch(dev_loader)
                self._update_training_stats(training_stats, epoch, train_metrics, val_metrics)

                if val_metrics['accuracy'] > best_score:
                    best_score = val_metrics['accuracy']
                    patience_counter = 0
                    if self.config.if_save_model:
                        self._save_checkpoint(
                            target_dir,
                            epoch,
                            optimizer,
                            best_score,
                            training_stats,
                            scheduler
                        )
                    self._evaluate_test_set(test_loader, target_dir, epoch)
                else:
                    patience_counter += 1
                    if patience_counter >= self.config.patience:
                        logger.info("Early stopping triggered")
                        break

            final_infos = {
                "sentiment": {
                    "means": {
                        "best_acc": best_score
                    }
                }
            }

            # out_dir is not guaranteed to exist when the trainer is used
            # outside main().
            os.makedirs(self.config.out_dir, exist_ok=True)
            with open(os.path.join(self.config.out_dir, "final_info.json"), "w") as f:
                json.dump(final_infos, f)

        except Exception as e:
            # logger.exception also records the traceback in training.log.
            logger.exception(f"Training failed: {e}")
            raise
|
|
|
|
def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    puts cuDNN into deterministic, non-benchmarking mode, and exports
    ``PYTHONHASHSEED``.
    """
    import random

    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE: the interpreter reads PYTHONHASHSEED only at start-up, so this
    # influences child processes launched afterwards, not this process's
    # str/bytes hashing.
    os.environ['PYTHONHASHSEED'] = str(seed)
|
|
|
|
def main(
    out_dir,
    data_path: str = "/cpfs01/shared/MA4Tool/datasets/SST-2/",
    target_dir: str = "./output/Bert/",
) -> None:
    """Fine-tune BERT on the train/dev/test TSV splits and record the results.

    Args:
        out_dir: run directory; created if missing. It is passed to
            ``TrainingConfig`` as ``out_dir``.
        data_path: directory holding headerless ``train.tsv``, ``dev.tsv`` and
            ``test.tsv`` files whose first column is read as the label
            (``similarity``) and second as the sentence (``s1``).
        target_dir: directory for checkpoints and test predictions.
    """
    try:
        config = TrainingConfig(out_dir=out_dir)
        pathlib.Path(config.out_dir).mkdir(parents=True, exist_ok=True)

        def read_split(split: str) -> pd.DataFrame:
            """Load one tab-separated split with the label/sentence columns."""
            return pd.read_csv(
                os.path.join(data_path, f"{split}.tsv"),
                sep='\t',
                header=None,
                names=['similarity', 's1']
            )

        train_df, dev_df, test_df = [read_split(s) for s in ("train", "dev", "test")]

        # Seed before the trainer builds the model so the classification
        # head is initialized reproducibly.
        set_seed(2024)

        trainer = BertTrainer(config)
        trainer.train_and_evaluate(train_df, dev_df, test_df, target_dir)

    except Exception as e:
        # logger.exception keeps the traceback, e.g. for data-loading errors
        # that occur before the trainer's own error logging.
        logger.exception(f"Program failed: {e}")
        raise
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_dir", type=str, default="run_0")
    args = parser.parse_args()
    try:
        main(args.out_dir)
    except Exception:
        print("Original error in subprocess:", flush=True)
        # main() can fail before it creates out_dir. Create it here so writing
        # the traceback does not raise a new error that hides the original.
        os.makedirs(args.out_dir, exist_ok=True)
        # The context manager closes (and flushes) the log file before re-raising.
        with open(os.path.join(args.out_dir, "traceback.log"), "w") as log_file:
            traceback.print_exc(file=log_file)
        raise
|
|