| import torch |
| import config |
| import math |
| import sys |
| import os |
| from tqdm import tqdm |
| from torch.optim import Adam |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
| from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer, AutoConfig |
| from pretrained_models import load_esm2_model |
| from model import MembraneMLM, MembraneTokenizer |
| from data_loader import get_dataloaders |
|
|
def save_hyperparams(ckpt_dir):
    """Dump every UPPER_CASE attribute of the `config` module to
    `<ckpt_dir>/hyperparameters.txt`, one "NAME: value" line per setting.

    Assumes `ckpt_dir` already exists (it is created by the checkpoint save
    that precedes every call in this file — TODO confirm `save_model` mkdirs).
    """
    out_path = os.path.join(ckpt_dir, "hyperparameters.txt")
    # Upper-case module attributes are the project's hyperparameter convention.
    lines = [
        f"{name}: {value}\n"
        for name, value in vars(config).items()
        if name.isupper()
    ]
    with open(out_path, 'w') as fh:
        fh.writelines(lines)
|
|
def train_and_validate(model, optimizer, device, train_loader, val_loader, num_epochs, ckpt_dir):
    """Train `model` for `num_epochs` epochs, optionally validating each epoch.

    Per epoch: runs one pass over `train_loader`, saves a checkpoint (plus the
    hyperparameters file) under `ckpt_dir`, and — when `val_loader` is given —
    evaluates and keeps the best checkpoint by average validation loss in
    `ckpt_dir/best_model_epoch`.

    Perplexity is computed as exp of the mask-weighted mean loss, i.e. each
    batch loss is weighted by its number of [MASK] tokens (assumes the model's
    loss is already a mean NLL over masked positions — TODO confirm).

    Args:
        model: MLM with `.tokenizer`, `.save_model(path)` and a HF-style
            forward returning an object with a `.loss` attribute.
        optimizer: torch optimizer over the trainable parameters.
        device: torch device inputs are moved to.
        train_loader: iterable of dict batches containing "input_ids".
        val_loader: like `train_loader`, or falsy to skip validation.
        num_epochs: number of passes over `train_loader`.
        ckpt_dir: directory checkpoints are written under.

    Returns:
        (avg_train_loss, train_perplexity, avg_val_loss, val_perplexity) for
        the final epoch; the validation entries are None when `val_loader`
        is falsy.
    """
    # Fix: read the tokenizer off the model instead of a module-level global,
    # so this function works when imported from another script.
    tokenizer = model.tokenizer
    best_val_loss = float('inf')
    # Defined up front so the final return/prints never hit a NameError when
    # validation is disabled.
    avg_val_loss = None
    val_perplexity = None

    for epoch in range(num_epochs):
        print(f"EPOCH {epoch+1}/{num_epochs}")
        sys.stderr.flush()
        total_train_loss = 0.0
        weighted_total_train_loss = 0.0
        total_masked_train_tokens = 0

        model.train()
        # tqdm already advances once per yielded item; `miniters` throttles the
        # display to ~4 refreshes per epoch (the old manual `.update(interval)`
        # double-counted progress). max(1, ...) guards loaders shorter than 4.
        train_refresh = max(1, len(train_loader) // 4)

        with tqdm(enumerate(train_loader), desc="Training batch", total=len(train_loader),
                  leave=True, position=0, ncols=100, miniters=train_refresh) as trainbar:
            for step, inputs in trainbar:
                inputs = {k: v.to(device) for k, v in inputs.items()}
                optimizer.zero_grad()
                outputs = model(**inputs)
                train_loss = outputs.loss
                train_loss.backward()
                optimizer.step()

                num_mask_tokens = (inputs["input_ids"] == tokenizer.mask_token_id).sum().item()
                total_masked_train_tokens += num_mask_tokens

                total_train_loss += train_loss.item()
                # Weight each batch loss by its masked-token count for the
                # corpus-level NLL estimate.
                weighted_total_train_loss += train_loss.item() * num_mask_tokens

        avg_train_loss = total_train_loss / len(train_loader)
        # Guard the degenerate no-masked-tokens epoch instead of crashing.
        if total_masked_train_tokens:
            train_perplexity = math.exp(weighted_total_train_loss / total_masked_train_tokens)
        else:
            train_perplexity = float("nan")

        # Fix: honor the `ckpt_dir` argument (was hard-coded to config.CKPT_DIR).
        train_ckpt_path = os.path.join(ckpt_dir, f'epoch{epoch+1}')
        model.save_model(train_ckpt_path)
        save_hyperparams(train_ckpt_path)

        if val_loader:
            model.eval()
            total_val_loss = 0.0
            weighted_total_val_loss = 0.0
            total_masked_val_tokens = 0

            with torch.no_grad():
                val_refresh = max(1, len(val_loader) // 4)

                # "Validation" typo fixed; same double-count fix as training.
                with tqdm(enumerate(val_loader), desc='Validation batch', total=len(val_loader),
                          leave=True, position=0, miniters=val_refresh) as valbar:
                    for step, inputs in valbar:
                        inputs = {k: v.to(device) for k, v in inputs.items()}
                        val_loss = model(**inputs).loss.item()

                        num_mask_tokens = (inputs['input_ids'] == tokenizer.mask_token_id).sum().item()
                        total_masked_val_tokens += num_mask_tokens

                        total_val_loss += val_loss
                        weighted_total_val_loss += val_loss * num_mask_tokens

            avg_val_loss = total_val_loss / len(val_loader)
            if total_masked_val_tokens:
                val_perplexity = math.exp(weighted_total_val_loss / total_masked_val_tokens)
            else:
                val_perplexity = float("nan")

            # Track the best epoch by average validation loss.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                val_ckpt_path = os.path.join(ckpt_dir, "best_model_epoch")
                model.save_model(val_ckpt_path)
                save_hyperparams(val_ckpt_path)

        print(f"Average train loss: {avg_train_loss}")
        print(f"Average train perplexity: {train_perplexity}\n")
        sys.stdout.flush()

        # Only report validation metrics when validation actually ran
        # (previously raised NameError for a falsy val_loader).
        if val_loader:
            print(f"Average validation loss: {avg_val_loss}")
            print(f"Average validation perplexity: {val_perplexity}\n")
            sys.stdout.flush()

    return avg_train_loss, train_perplexity, avg_val_loss, val_perplexity
| |
|
|
| def test(model, test_loader, device): |
| model.to(device).eval() |
| total_test_loss = 0.0 |
| weighted_total_test_loss = 0.0 |
| total_masked_test_tokens = 0.0 |
|
|
| with torch.no_grad(): |
| for step, inputs in enumerate(test_loader): |
| inputs = {k: v.to(device) for k, v in inputs.items()} |
| outputs = model(**inputs) |
| test_loss = outputs.loss.item() |
|
|
| num_mask_tokens = (inputs["input_ids"] == tokenizer.mask_token_id).sum().item() |
| total_masked_test_tokens += num_mask_tokens |
|
|
| total_test_loss += test_loss |
| weighted_total_test_loss += test_loss * num_mask_tokens |
| |
| avg_test_loss = total_test_loss / len(test_loader) |
| avg_test_neg_log_likilehood = weighted_total_test_loss / total_masked_test_tokens |
| test_perplexity = math.exp(avg_test_neg_log_likilehood) |
|
|
| return avg_test_loss, test_perplexity |
|
|
|
|
if __name__ == "__main__":
    # Prefer the GPU when one is visible.
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    print(device)

    # Build the MLM and fine-tune only the unfrozen top layers.
    model = MembraneMLM()
    model.to(device)
    model.freeze_model()
    model.unfreeze_n_layers()
    tokenizer = model.tokenizer  # module-level: the train/test helpers read this

    train_loader, val_loader, test_loader = get_dataloaders(config)
    # Optimize only the parameters left trainable after freezing.
    trainable = (p for p in model.parameters() if p.requires_grad)
    optimizer = Adam(trainable, lr=config.LEARNING_RATE)

    avg_train_loss, train_ppl, avg_val_loss, val_ppl = train_and_validate(
        model,
        optimizer,
        device,
        train_loader,
        val_loader,
        config.NUM_EPOCHS,
        config.CKPT_DIR,
    )
    avg_test_loss, test_ppl = test(model, test_loader, device)

    results_dict = {
        "Average train loss": avg_train_loss,
        "Average train perplexity": train_ppl,
        "Average val loss": avg_val_loss,
        "Average val perplexity": val_ppl,
        "Average test loss": avg_test_loss,
        "Average test perplexity": test_ppl,
    }

    print("TRAIN AND TEST RESULTS")
    for metric, value in results_dict.items():
        print(f"{metric}: {value}\n")

    # Persist the summary next to the checkpoints.
    with open(config.CKPT_DIR + "/train_test_results.txt", 'w') as f:
        f.writelines(f'{metric}: {value}\n' for metric, value in results_dict.items())
|
|
|
|
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |