| | from __future__ import absolute_import
|
| | from __future__ import division
|
| | from __future__ import print_function
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
|
| | import numpy as np
|
| | from KMVE_RG.modules.Caption import MyCaption
|
| |
|
| |
|
def sort_pack_padded_sequence(input, lengths):
    """Sort a batch-first padded tensor by length (descending) and pack it.

    Args:
        input: batch-first padded tensor; its rows are reordered by ``lengths``.
        lengths: 1-D tensor holding the valid length of each row of ``input``.

    Returns:
        ``(packed, inv_ix)`` where ``packed`` is the PackedSequence built from
        the length-sorted rows and ``inv_ix`` is the inverse permutation, i.e.
        ``sorted_rows[inv_ix]`` restores the original row order.
    """
    sorted_lengths, indices = torch.sort(lengths, descending=True)
    # pack_padded_sequence requires tensor lengths to live on the CPU
    # (PyTorch >= 1.7); without .cpu() this fails for CUDA masks.
    packed = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
    inv_ix = indices.clone()
    # indices is a permutation, so scattering 0..n-1 through it inverts it.
    inv_ix[indices] = torch.arange(len(indices), device=indices.device, dtype=indices.dtype)
    return packed, inv_ix
|
| |
|
| |
|
def pad_unsort_packed_sequence(input, inv_ix):
    """Unpack ``input`` to a batch-first padded tensor and reorder its rows.

    Args:
        input: a PackedSequence.
        inv_ix: row permutation applied to the padded result (the inverse
            permutation produced when the batch was sorted before packing).

    Returns:
        The padded tensor with rows indexed by ``inv_ix``.
    """
    padded = pad_packed_sequence(input, batch_first=True)[0]
    return padded[inv_ix]
|
| |
|
| |
|
def pack_wrapper(module, att_feats, att_masks):
    """Apply ``module`` to ``att_feats``, restricted to valid positions when masked.

    Without a mask, ``module`` is applied to the whole tensor. With a mask, each
    row's length is the row sum of ``att_masks``. The rows are then
    length-sorted and packed, ``module`` runs on the packed data only, and the
    result is re-padded and returned in the original row order.
    NOTE(review): packing keeps the first ``length`` positions of each row, so
    this assumes the mask marks a contiguous prefix — confirm with callers.
    """
    if att_masks is None:
        return module(att_feats)
    row_lengths = att_masks.data.long().sum(1)
    packed, inv_ix = sort_pack_padded_sequence(att_feats, row_lengths)
    transformed = PackedSequence(module(packed[0]), packed[1])
    return pad_unsort_packed_sequence(transformed, inv_ix)
|
| |
|
| |
|
class GenModel(MyCaption):
    """Attention-based sequence generator built on ``MyCaption``.

    This class sets up the attention-feature embedding and the autoregressive
    decoding loops. It uses ``self.core``, ``self.logit``, ``self.ctx2att`` and
    ``self.sample_next_word``, none of which are defined here (presumably they
    come from ``MyCaption`` or a subclass — confirm).
    """

    def __init__(self, args, tokenizer):
        super(GenModel, self).__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer.idx2token)
        self.input_encoding_size = args.d_model
        self.rnn_size = args.d_ff
        self.num_layers = args.num_layers
        self.drop_prob_lm = args.drop_prob_lm
        self.max_seq_length = args.max_seq_length
        self.att_feat_size = args.d_vf
        self.att_hid_size = args.d_model

        self.bos_idx = args.bos_idx
        self.eos_idx = args.eos_idx
        self.pad_idx = args.pad_idx

        # Truthy use_bn adds a BatchNorm before the projection;
        # use_bn == 2 also adds one after it.
        self.use_bn = args.use_bn

        # Token and global-feature embeddings are identity functions here.
        self.embed = lambda x: x
        self.fc_embed = lambda x: x
        # Projects attention features from d_vf to d_model:
        # [BatchNorm] -> Linear -> ReLU -> Dropout -> [BatchNorm].
        self.att_embed = nn.Sequential(*(
            ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ()) +
            (nn.Linear(self.att_feat_size, self.input_encoding_size),
             nn.ReLU(),
             nn.Dropout(self.drop_prob_lm)) +
            ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn == 2 else ())))

    def clip_att(self, att_feats, att_masks):
        """Trim features and mask along dim 1 to the longest valid length.

        The valid length of each row is the row sum of ``att_masks``. When
        ``att_masks`` is None, both inputs are returned unchanged.
        """
        if att_masks is not None:
            max_len = att_masks.data.long().sum(1).max()
            att_feats = att_feats[:, :max_len].contiguous()
            att_masks = att_masks[:, :max_len].contiguous()
        return att_feats, att_masks

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        """Clip, embed and project the encoder features.

        Returns:
            ``(fc_feats, att_feats, p_att_feats, att_masks)``: ``fc_feats`` after
            ``fc_embed``; ``att_feats`` after ``att_embed`` (applied to masked-valid
            positions only); their projection through ``self.ctx2att``; and the
            clipped mask.
        """
        att_feats, att_masks = self.clip_att(att_feats, att_masks)

        fc_feats = self.fc_embed(fc_feats)
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
        p_att_feats = self.ctx2att(att_feats)

        return fc_feats, att_feats, p_att_feats, att_masks

    def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1):
        """Run one decoding step for the token ids ``it``.

        ``fc_feats`` and ``att_feats`` are not used by this implementation.

        Returns:
            ``(logprobs, state, attn_map)``. ``logprobs`` is the log-softmax over
            ``self.logit(output)`` when ``output_logsoftmax`` is truthy, otherwise
            the raw logits. ``attn_map`` is ``p_att_feats @ output[..., None]``
            (for ``p_att_feats`` of shape (N, L, D) and ``output`` of shape (N, D)
            this is (N, L, 1) — shapes assumed, confirm against ``self.core``).
        """
        xt = self.embed(it)
        output, state = self.core(xt, p_att_feats, state, att_masks)
        if output_logsoftmax:
            logprobs = F.log_softmax(self.logit(output), dim=1)
        else:
            logprobs = self.logit(output)

        output_weight = output.unsqueeze(-1)
        attn_map = torch.matmul(p_att_feats, output_weight)
        return logprobs, state, attn_map

    @staticmethod
    def _repeat_rows(x, n):
        """Repeat every row of ``x`` ``n`` times along dim 0; ``None`` passes through."""
        return None if x is None else x.repeat_interleave(n, dim=0)

    def _decode(self, fc_feats, att_feats, att_masks, on_step=None):
        """Shared decoding loop behind ``_sample`` and ``_evaluate``.

        Options are read from ``self.args``: ``sample_n`` (default 1),
        ``output_logsoftmax`` (default 1) and ``decoding_constraint`` (default 0).

        Args:
            on_step: optional callback ``on_step(it, logprobs, attn_map)`` invoked
                after each emitted step with that step's (masked) tokens and
                log-probs.

        Returns:
            ``(seq, seqLogprobs)`` with ``batch_size * sample_n`` rows. Positions
            after a row finishes hold ``pad_idx`` / zeros.
        """
        opt = self.args.__dict__
        sample_n = int(opt.get('sample_n', 1))
        output_logsoftmax = opt.get('output_logsoftmax', 1)
        decoding_constraint = opt.get('decoding_constraint', 0)

        batch_size = fc_feats.size(0)
        n_rows = batch_size * sample_n
        state = []
        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
        if sample_n > 1:
            # The output buffers hold batch_size * sample_n rows, so the features
            # must be repeated to match (the samples of one item stay adjacent).
            p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = (
                self._repeat_rows(x, sample_n)
                for x in (p_fc_feats, p_att_feats, pp_att_feats, p_att_masks))

        seq = fc_feats.new_full((n_rows, self.max_seq_length), self.pad_idx, dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(n_rows, self.max_seq_length, self.vocab_size + 1)

        it = fc_feats.new_full([n_rows], self.bos_idx, dtype=torch.long)
        unfinished = None
        for t in range(self.max_seq_length):
            logprobs, state, attn_map = self.get_logprobs_state(
                it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state,
                output_logsoftmax=output_logsoftmax)
            if decoding_constraint and t > 0:
                # Forbid emitting the same token as the previous step.
                tmp = logprobs.new_zeros(logprobs.size())
                tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
                logprobs = logprobs + tmp

            it, _ = self.sample_next_word(logprobs)

            if t == 0:
                unfinished = it != self.eos_idx
            else:
                # Rows that already emitted EOS keep emitting padding.
                it[~unfinished] = self.pad_idx
                # masked_fill rather than multiplying by the mask: a -inf from
                # the decoding constraint times 0 would produce NaN.
                logprobs = logprobs.masked_fill(~unfinished.unsqueeze(1), 0)
                unfinished = unfinished & (it != self.eos_idx)
            seq[:, t] = it
            seqLogprobs[:, t] = logprobs
            if on_step is not None:
                on_step(it, logprobs, attn_map)

            if unfinished.sum() == 0:
                break
        return seq, seqLogprobs

    def _sample(self, fc_feats, att_feats, att_masks=None):
        """Decode sequences for the batch.

        Returns:
            ``(seq, seqLogprobs)`` as produced by ``_decode``.
        """
        return self._decode(fc_feats, att_feats, att_masks)

    def _evaluate(self, fc_feats, att_feats, att_masks=None):
        """Decode like ``_sample`` and also trace row 0 of the batch.

        Returns:
            ``(seq, first_sentence, first_attmap, first_sentence_probs)``:
            ``first_sentence`` holds row 0's token id at each emitted step,
            ``first_attmap`` the matching ``attn_map[0]``, and
            ``first_sentence_probs`` ``exp(logprobs[0])`` at that token (a 0-d
            tensor). After row 0 finishes, its tokens are ``pad_idx`` and its
            log-probs are zeroed, so the recorded value becomes 1.
        """
        first_sentence = []
        first_attmap = []
        first_sentence_probs = []

        def record_first_row(it, logprobs, attn_map):
            # detach() so this also works when logprobs tracks gradients.
            probabilities = logprobs[0].detach().cpu().exp()
            index = int(it[0])
            first_attmap.append(attn_map[0])
            first_sentence.append(index)
            first_sentence_probs.append(probabilities[index])

        seq, _ = self._decode(fc_feats, att_feats, att_masks, on_step=record_first_row)
        return seq, first_sentence, first_attmap, first_sentence_probs
|
| |
|