import torch from tqdm import tqdm import torch.optim as optim from utils.dataset import GraphData class Trainer: def __init__(self, args, net, G_data): self.args = args self.net = net self.feat_dim = G_data.feat_dim self.fold_idx = G_data.fold_idx self.init(args, G_data.train_gs, G_data.test_gs) if torch.cuda.is_available(): self.net.cuda() def init(self, args, train_gs, test_gs): print('#train: %d, #test: %d' % (len(train_gs), len(test_gs))) train_data = GraphData(train_gs, self.feat_dim) test_data = GraphData(test_gs, self.feat_dim) self.train_d = train_data.loader(self.args.batch, True) self.test_d = test_data.loader(self.args.batch, False) self.optimizer = optim.Adam( self.net.parameters(), lr=self.args.lr, amsgrad=True, weight_decay=0.0008) def to_cuda(self, gs): if torch.cuda.is_available(): if type(gs) == list: return [g.cuda() for g in gs] return gs.cuda() return gs def run_epoch(self, epoch, data, model, optimizer): losses, accs, n_samples = [], [], 0 for batch in tqdm(data, desc=str(epoch), unit='b'): cur_len, gs, hs, ys = batch gs, hs, ys = map(self.to_cuda, [gs, hs, ys]) loss, acc = model(gs, hs, ys) losses.append(loss*cur_len) accs.append(acc*cur_len) n_samples += cur_len if optimizer is not None: optimizer.zero_grad() loss.backward() optimizer.step() avg_loss, avg_acc = sum(losses) / n_samples, sum(accs) / n_samples return avg_loss.item(), avg_acc.item() def train(self): max_acc = 0.0 train_str = 'Train epoch %d: loss %.5f acc %.5f' test_str = 'Test epoch %d: loss %.5f acc %.5f max %.5f' line_str = '%d:\t%.5f\n' for e_id in range(self.args.num_epochs): self.net.train() loss, acc = self.run_epoch( e_id, self.train_d, self.net, self.optimizer) print(train_str % (e_id, loss, acc)) with torch.no_grad(): self.net.eval() loss, acc = self.run_epoch(e_id, self.test_d, self.net, None) max_acc = max(max_acc, acc) print(test_str % (e_id, loss, acc, max_acc)) with open(self.args.acc_file, 'a+') as f: f.write(line_str % (self.fold_idx, max_acc))