| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import time |
| import copy |
| import sys |
| import pandas as pd |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from sklearn.metrics import roc_auc_score |
| from torch.autograd import Variable |
| from scripts.multiAUC import Metric |
| import numpy |
| from tqdm import tqdm |
| from random import sample |
| from scripts.plot import bootstrap_auc,result_csv,plotimage |
| import pynvml |
| pynvml.nvmlInit() |
| from prettytable import PrettyTable |
|
|
|
|
def _safe_auc(y_true, y_score):
    """Binary AUROC that degrades gracefully: returns 0.0 instead of raising
    when `y_true` contains a single class (roc_auc_score raises ValueError)."""
    try:
        return roc_auc_score(y_true, y_score)
    except ValueError:
        return 0.0


def train_model(model, dataloaders, criterion, optimizer, num_epochs, modelname, device):
    """Train and evaluate `model` over the 'train'/'valid'/'test' splits,
    tracking loss and AUROC per phase and checkpointing on improvement.

    Epoch 0 is an evaluation-only baseline pass: no gradient updates are made
    and the LR scheduler is not stepped, so the histories start from the
    untrained (or pre-loaded) model's metrics.

    Args:
        model:       nn.Module producing per-class scores of shape
                     (batch, num_classes); column 1 is used as the
                     positive-class score for binary AUROC.
        dataloaders: dict keyed 'train'/'valid'/'test'; each batch yields
                     (inputs, labels, sub). NOTE(review): `sub` is assumed to
                     be a binary subgroup indicator (0 = majority,
                     1 = minority) — confirm against the Dataset definition.
        criterion:   loss callable taking (outputs, labels).
        optimizer:   torch optimizer; LR is halved every 20 epochs (StepLR).
        num_epochs:  total epochs, including the epoch-0 baseline pass.
        modelname:   experiment tag; "Thyroid_PF" selects the plain binary
                     AUROC path, anything else uses the project's Metric.
        device:      torch device the model already lives on.

    Side effects: saves improving checkpoints under modelsaved/<modelname>/,
    writes AUC/Loss plots and a result CSV via the project's plot helpers.

    Returns:
        The model with the best checkpointed weights loaded.

    NOTE(review): the original file had its aggregation section stripped out
    (Label/Output/Batch/G/G1 were undefined) — this rebuilds it. The gamma
    (G/G1) tracking depended entirely on the missing code and was dropped.
    """
    since = time.time()
    train_loss_history, valid_loss_history, test_loss_history = [], [], []
    train_auc_history, val_auc_history, test_auc_history = [], [], []
    best_model_wts = copy.deepcopy(model.state_dict())
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    # NVML handle and total memory are loop-invariant: query them once
    # instead of once per batch.
    gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    total_mem = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle).total

    for epoch in range(num_epochs):
        start = time.time()
        print('{} Epoch {}/{} {}'.format('-' * 30, epoch, num_epochs - 1, '-' * 30))

        for phase in ['train', 'valid', 'test']:
            # Epoch 0 runs every phase in eval mode (baseline pass).
            is_train = (phase == 'train' and epoch != 0)
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = []
            prob_all, label_all, sub_all = [], [], []

            with tqdm(range(len(dataloaders[phase])), desc='%s' % phase, ncols=100) as t:
                if epoch == 0:
                    t.set_postfix(L=0.000, usedMemory=0)
                for inputs, labels, sub in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    optimizer.zero_grad(set_to_none=True)
                    with torch.set_grad_enabled(is_train):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                    running_loss.append(loss.item())
                    # Column 1 of the output is the positive-class score.
                    prob_all.extend(outputs[:, 1].cpu().detach().numpy())
                    label_all.extend(labels.cpu().detach().numpy())
                    sub_all.extend(np.asarray(sub).ravel())

                    used_frac = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle).used / total_mem
                    t.set_postfix(loss=loss.item(), usedMemory=used_frac)
                    t.update()

            Output = np.asarray(prob_all)
            Label = np.asarray(label_all)
            Sub = np.asarray(sub_all)
            # Mean per-batch loss. (The original divided a *list* by an
            # undefined `Batch`, which could never run.)
            epoch_loss = float(np.mean(running_loss)) if running_loss else 0.0

            # Subgroup AUROCs; 0.0 when a group is absent or single-class.
            maj_mask = Sub == 0
            min_mask = Sub == 1
            auc_maj = _safe_auc(Label[maj_mask], Output[maj_mask]) if maj_mask.any() else 0.0
            auc_min = _safe_auc(Label[min_mask], Output[min_mask]) if min_mask.any() else 0.0

            if modelname == "Thyroid_PF":
                data_auc = _safe_auc(Label, Output)
                # The original appended an undefined `auc` to the histories on
                # this path; track the overall AUROC instead.
                hist_auc = data_auc
            else:
                # Project-specific metric; auROC() returns a pair — the first
                # is printed, the second is tracked (mirrors the original).
                myMetic = Metric(Output, Label)
                data_auc, hist_auc = myMetic.auROC()

            # Bootstrap spread over 5 resamples (project helper).
            statistics = bootstrap_auc(Label, Output, [0, 1, 2, 3, 4])
            max_auc = np.max(statistics, axis=1).max()
            # NOTE(review): the original also used .max() here; kept as-is,
            # but a lower CI bound would normally be .min() — confirm intent.
            min_auc = np.min(statistics, axis=1).max()

            print('{} --> Num: {} Loss: {:.4f} AUROC: {:.4f} ({:.2f} ~ {:.2f}) (Maj {:.4f}, Min {:.4f})'.format(
                phase, len(Label), epoch_loss, data_auc, min_auc, max_auc, auc_maj, auc_min))

            if phase == 'train':
                train_loss_history.append(epoch_loss)
                train_auc_history.append(hist_auc)
            elif phase == 'valid':
                valid_loss_history.append(epoch_loss)
                val_auc_history.append(hist_auc)
            else:
                test_loss_history.append(epoch_loss)
                test_auc_history.append(hist_auc)

        # Checkpoint after all three phases so this epoch's valid AND test
        # AUCs exist (the original read test_auc_history during 'valid':
        # stale by one epoch, and an IndexError at epoch 0).
        if train_auc_history[-1] >= 0.9 and (
                val_auc_history[-1] >= max(val_auc_history)
                or test_auc_history[-1] >= max(test_auc_history)):
            print("In epoch %d, better AUC(%.3f) and save model. " % (epoch, float(val_auc_history[-1])))
            PATH = '/export/home/daifang/Diffusion/Resnet/modelsaved/%s/e%d_%s_V%.3fT%.3f.pth' % (
                modelname, epoch, modelname, val_auc_history[-1], test_auc_history[-1])
            torch.save(model.state_dict(), PATH)
            # Remember the best weights so the final restore below is
            # meaningful (the original never updated best_model_wts).
            best_model_wts = copy.deepcopy(model.state_dict())

        print("learning rate = %.6f time: %.1f sec" % (optimizer.param_groups[-1]['lr'], time.time() - start))
        if epoch != 0:
            scheduler.step()
        print()

    plotimage(train_auc_history, val_auc_history, test_auc_history, "AUC", modelname)
    plotimage(train_loss_history, valid_loss_history, test_loss_history, "Loss", modelname)
    result_csv(train_auc_history, val_auc_history, test_auc_history, modelname)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # Restore the best checkpointed weights before handing the model back.
    model.load_state_dict(best_model_wts)
    return model