# coding: UTF-8
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
from tensorboardX import SummaryWriter

from utils import get_time_dif
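
# Note: on recent PyTorch versions the same SummaryWriter API is also
# available from torch.utils.tensorboard, so the tensorboardX dependency
# can be dropped if preferred.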

# Weight initialization; defaults to Xavier.
def init_network(model, method='xavier', exclude='embedding', seed=123):
    # `seed` is accepted for API compatibility but unused here.
    for name, w in model.named_parameters():
        if exclude not in name:  # skip (pre-trained) embedding parameters
            if 'weight' in name:
                if method == 'xavier':
                    nn.init.xavier_normal_(w)
                elif method == 'kaiming':
                    nn.init.kaiming_normal_(w)
                else:
                    nn.init.normal_(w)
            elif 'bias' in name:
                nn.init.constant_(w, 0)
            else:
                pass
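
# A minimal usage sketch (the model class name below is a hypothetical
# stand-in; any nn.Module whose embedding parameters contain 'embedding'
# in their names fits):
#
#     model = TextRNN(config)   # hypothetical model class
#     init_network(model)       # Xavier-init weights, zero biases,
#                               # leaving embedding parameters untouched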

def train(config, model, train_iter, dev_iter, test_iter):
    start_time = time.time()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    # Exponential learning-rate decay: after each epoch, lr = gamma * lr.
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    total_batch = 0  # number of batches processed so far
    dev_best_loss = float('inf')
    last_improve = 0  # batch index of the last dev-loss improvement
    flag = False  # whether training has gone too long without improvement
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        # scheduler.step()  # learning-rate decay
        for i, (trains, labels) in enumerate(train_iter):
            outputs = model(trains)
            model.zero_grad()
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                # Report metrics on the training and dev sets every 100 batches.
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                writer.add_scalar("loss/train", loss.item(), total_batch)
                writer.add_scalar("loss/dev", dev_loss, total_batch)
                writer.add_scalar("acc/train", train_acc, total_batch)
                writer.add_scalar("acc/dev", dev_acc, total_batch)
                model.train()  # evaluate() switched to eval mode; switch back
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # Dev loss has not improved for `require_improvement` batches: early stop.
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    writer.close()
    test(config, model, test_iter)
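
# `config` is expected to expose the attributes read above and below:
#     learning_rate, num_epochs, log_path, save_path,
#     require_improvement, class_list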

def test(config, model, test_iter):
    # Evaluate the best checkpoint (saved during training) on the test set.
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

def evaluate(config, model, data_iter, test=False):
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss.item()  # accumulate a Python float, not a tensor
            labels = labels.data.cpu().numpy()
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic)
    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion
    return acc, loss_total / len(data_iter)
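
# A minimal end-to-end sketch, assuming the surrounding repo provides a
# `Config`/`Model` pair and a `build_iterator` helper (all three names are
# assumptions; none are defined in this file):
#
#     config = Config(dataset)
#     train_iter = build_iterator(train_data, config)   # likewise dev/test
#     model = Model(config).to(config.device)
#     init_network(model)
#     train(config, model, train_iter, dev_iter, test_iter)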