| import torch |
| import torch.nn as nn |
| from torch.utils.data import Dataset |
| import torch.utils.data |
| import json |
|
|
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
class Dataset(Dataset):
    """Map-style dataset of encoded (question, reply) chat pairs.

    Loads 'pairs_encoded.json', a JSON list where each element is a pair
    ``[question_ids, reply_ids]`` of integer token-id lists, and serves each
    pair as a tuple of ``LongTensor``s.

    NOTE(review): the class name shadows ``torch.utils.data.Dataset`` imported
    above; kept as-is because renaming would break existing callers.
    """

    def __init__(self):
        # Use a context manager so the file handle is closed deterministically
        # (the original ``json.load(open(...))`` leaked the handle).
        with open('pairs_encoded.json', 'r', encoding='utf-8') as f:
            self.pairs = json.load(f)
        self.dataset_size = len(self.pairs)

    def __getitem__(self, i):
        """Return pair *i* as (question, reply) LongTensors of token ids."""
        question = torch.LongTensor(self.pairs[i][0])
        reply = torch.LongTensor(self.pairs[i][1])
        return question, reply

    def __len__(self):
        """Number of question/reply pairs in the dataset."""
        return self.dataset_size
| |
|
|
def create_masks(question, reply_input, reply_target, device=None):
    """Build transformer attention masks; padding is assumed to be token id 0.

    Parameters
    ----------
    question : LongTensor (batch, q_len)
        Encoder input token ids.
    reply_input : LongTensor (batch, r_len)
        Decoder input token ids.
    reply_target : LongTensor (batch, r_len)
        Decoder target token ids.
    device : torch.device, optional
        Device for the question mask.  Defaults to CUDA when available,
        matching the module-level ``device`` the original code relied on.

    Returns
    -------
    question_mask : BoolTensor (batch, 1, 1, q_len)
        True at non-pad encoder positions.
    reply_input_mask : BoolTensor (batch, 1, r_len, r_len)
        Padding mask combined with the causal (no-peek-ahead) mask.
    reply_target_mask : BoolTensor (batch, r_len)
        True at non-pad target positions.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def subsequent_mask(size):
        # Lower-triangular causal mask: position i may attend to j <= i.
        # (Equivalent to the original triu(...).transpose(0, 1); built as
        # bool directly since the original converted to bool via type_as.)
        return torch.tril(torch.ones(size, size, dtype=torch.bool)).unsqueeze(0)

    question_mask = (question != 0).to(device)
    question_mask = question_mask.unsqueeze(1).unsqueeze(1)

    reply_input_mask = (reply_input != 0).unsqueeze(1)
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1))
    reply_input_mask = reply_input_mask.unsqueeze(1)
    reply_target_mask = reply_target != 0

    return question_mask, reply_input_mask, reply_target_mask
|
|
|
|
class AdamWarmup:
    """Optimizer wrapper applying the warmup learning-rate schedule
    ``lr = model_size^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)``
    (the "Noam" schedule): linear warmup followed by inverse-sqrt decay.

    Call ``step()`` instead of ``optimizer.step()``; it advances the step
    counter, writes the scheduled rate into every param group, then steps
    the wrapped optimizer.  The last rate applied is kept in ``self.lr``.
    """

    def __init__(self, model_size, warmup_steps, optimizer):
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0

    def get_lr(self):
        """Scheduled learning rate for the current step (must be >= 1)."""
        step = self.current_step
        scale = self.model_size ** (-0.5)
        return scale * min(step ** (-0.5), step * self.warmup_steps ** (-1.5))

    def step(self):
        """Advance one step: update every param group's lr, then optimize."""
        self.current_step += 1
        new_lr = self.get_lr()
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr
        self.lr = new_lr
        self.optimizer.step()
|
|
class LossWithLS(nn.Module):
    """KL-divergence loss with label smoothing and padding masking.

    The target distribution puts ``1 - smooth`` on the true token and spreads
    ``smooth`` uniformly over the remaining ``size - 1`` vocabulary entries;
    per-token losses at padded positions are masked out before averaging.
    """

    def __init__(self, size, smooth):
        """
        size: vocabulary size (number of classes in the prediction).
        smooth: total probability mass smoothed over non-target classes.
        """
        super(LossWithLS, self).__init__()
        # reduction='none' keeps per-element losses so padding can be masked
        # before averaging; it replaces the deprecated
        # size_average=False, reduce=False arguments (same elementwise output).
        self.criterion = nn.KLDivLoss(reduction='none')
        self.confidence = 1.0 - smooth
        self.smooth = smooth
        self.size = size

    def forward(self, prediction, target, mask):
        """
        prediction of shape: (batch_size, max_words, vocab_size)
            NOTE(review): KLDivLoss expects log-probabilities here — confirm
            callers apply log_softmax.
        target and mask of shape: (batch_size, max_words)
        """
        prediction = prediction.view(-1, prediction.size(-1))
        target = target.contiguous().view(-1)
        mask = mask.float().view(-1)
        # Smoothed one-hot target distribution: confidence at the true index,
        # smooth/(size-1) everywhere else.  full_like replaces the
        # autograd-unsafe prediction.data.clone() + fill_ of the original.
        labels = torch.full_like(prediction, self.smooth / (self.size - 1))
        labels.scatter_(1, target.unsqueeze(1), self.confidence)
        loss = self.criterion(prediction, labels)
        # Mean per-token loss over non-padded positions only.
        loss = (loss.sum(1) * mask).sum() / mask.sum()
        return loss
|
|