| import torch |
| from tqdm import tqdm |
| import torch.optim as optim |
| from utils.dataset import GraphData |
|
|
|
|
class Trainer:
    """Train and evaluate a graph model on one data split.

    The constructor builds train/test loaders from ``G_data`` and an Adam
    optimizer for ``net``. :meth:`train` then alternates one training epoch
    and one evaluation epoch for ``args.num_epochs`` epochs. At the end it
    appends the best evaluation accuracy to ``args.acc_file``, tagged with
    ``fold_idx`` (presumably a cross-validation fold index).
    """

    # L2 penalty passed to Adam; hard-coded for this trainer, overridable in
    # a subclass.
    WEIGHT_DECAY = 0.0008

    def __init__(self, args, net, G_data):
        """Store the configuration and prepare the loaders and optimizer.

        Args:
            args: Configuration namespace. Attributes read by this class:
                ``batch``, ``lr``, ``num_epochs`` and ``acc_file``.
            net: Model to train (presumably a ``torch.nn.Module``). It is
                moved to the GPU when CUDA is available.
            G_data: Object exposing ``feat_dim``, ``fold_idx``, ``train_gs``
                and ``test_gs``.
        """
        self.args = args
        self.net = net
        self.feat_dim = G_data.feat_dim
        self.fold_idx = G_data.fold_idx
        # Move the model before ``init`` builds the optimizer. The
        # torch.optim docs recommend doing this so the optimizer references
        # the final, on-device parameters.
        if torch.cuda.is_available():
            self.net.cuda()
        self.init(args, G_data.train_gs, G_data.test_gs)

    def init(self, args, train_gs, test_gs):
        """Build the train/test data loaders and the Adam optimizer.

        Args:
            args: Configuration namespace; ``batch`` and ``lr`` are read.
            train_gs: Training graphs.
            test_gs: Evaluation graphs.
        """
        print('#train: %d, #test: %d' % (len(train_gs), len(test_gs)))
        train_data = GraphData(train_gs, self.feat_dim)
        test_data = GraphData(test_gs, self.feat_dim)
        # Second argument: shuffle the training loader but not the test
        # loader (presumably; confirm against GraphData.loader).
        self.train_d = train_data.loader(args.batch, True)
        self.test_d = test_data.loader(args.batch, False)
        self.optimizer = optim.Adam(
            self.net.parameters(), lr=args.lr, amsgrad=True,
            weight_decay=self.WEIGHT_DECAY)

    def to_cuda(self, gs):
        """Move a tensor, or a list/tuple of tensors, to the GPU if available.

        Returns ``gs`` unchanged when CUDA is unavailable. Otherwise returns
        ``gs.cuda()``, or a list with ``.cuda()`` applied to each element.
        """
        if not torch.cuda.is_available():
            return gs
        if isinstance(gs, (list, tuple)):
            return [g.cuda() for g in gs]
        return gs.cuda()

    def run_epoch(self, epoch, data, model, optimizer):
        """Run ``model`` over every batch of ``data`` once.

        Args:
            epoch: Epoch index; used only as the progress-bar label.
            data: Iterable of ``(cur_len, gs, hs, ys)`` batches, where
                ``cur_len`` is the number of samples in the batch.
            model: Callable ``model(gs, hs, ys) -> (loss, acc)``. Both values
                are re-weighted by ``cur_len``, so they are presumably
                per-batch means.
            optimizer: Optimizer stepped after each batch, or ``None`` to run
                without parameter updates (evaluation).

        Returns:
            ``(avg_loss, avg_acc)`` as Python numbers, each a sample-weighted
            mean over the whole epoch.

        Raises:
            ValueError: If ``data`` yields no samples.
        """
        losses, accs, n_samples = [], [], 0
        for batch in tqdm(data, desc=str(epoch), unit='b'):
            cur_len, gs, hs, ys = batch
            gs, hs, ys = map(self.to_cuda, [gs, hs, ys])
            loss, acc = model(gs, hs, ys)
            # Detach before storing. Keeping ``loss * cur_len`` attached would
            # hold every batch's autograd graph alive until the epoch ends.
            losses.append(loss.detach() * cur_len)
            accs.append(acc * cur_len)
            n_samples += cur_len
            if optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if not n_samples:
            raise ValueError('run_epoch received no samples')
        avg_loss = sum(losses) / n_samples
        avg_acc = sum(accs) / n_samples
        return avg_loss.item(), avg_acc.item()

    def train(self):
        """Train for ``args.num_epochs`` epochs, evaluating after each one.

        Per-epoch train and evaluation metrics are printed. When all epochs
        finish, the best evaluation accuracy is appended to ``args.acc_file``
        as ``"<fold_idx>:\\t<max_acc>"``.
        """
        max_acc = 0.0
        train_str = 'Train epoch %d: loss %.5f acc %.5f'
        test_str = 'Test epoch %d: loss %.5f acc %.5f max %.5f'
        line_str = '%d:\t%.5f\n'
        for e_id in range(self.args.num_epochs):
            self.net.train()
            loss, acc = self.run_epoch(
                e_id, self.train_d, self.net, self.optimizer)
            print(train_str % (e_id, loss, acc))

            self.net.eval()
            with torch.no_grad():
                loss, acc = self.run_epoch(e_id, self.test_d, self.net, None)
            # NOTE(review): this is the best accuracy on the evaluation
            # (test) loader across epochs, i.e. model selection on the test
            # split. Confirm that this is the intended reporting protocol.
            max_acc = max(max_acc, acc)
            print(test_str % (e_id, loss, acc, max_acc))

        with open(self.args.acc_file, 'a') as f:
            f.write(line_str % (self.fold_idx, max_acc))
|
|