| import torch
|
| import torch.nn as nn
|
| from models.lib.wav2vec import Wav2Vec2Model
|
| from models.utils import init_biased_mask, enc_dec_mask, PeriodicPositionalEncoding
|
| from base import BaseModel
|
| import pdb
|
| import os
|
| import random
|
| import torch.nn.functional as F
|
|
|
|
|
| class CodeTalker(BaseModel):
|
| def __init__(self, args):
|
| super(CodeTalker, self).__init__()
|
| """
|
| audio: (batch_size, raw_wav)
|
| template: (batch_size, V*3)
|
| vertice: (batch_size, seq_len, V*3)
|
| """
|
| self.args = args
|
| self.dataset = args.dataset
|
| self.audio_encoder = Wav2Vec2Model.from_pretrained(args.wav2vec2model_path)
|
|
|
| self.audio_encoder.feature_extractor._freeze_parameters()
|
|
|
|
|
|
|
| self.audio_feature_map = nn.Linear(1024, args.feature_dim)
|
|
|
|
|
| self.vertice_map = nn.Linear(args.vertice_dim, args.feature_dim)
|
|
|
| self.PPE = PeriodicPositionalEncoding(args.feature_dim, period = args.period)
|
|
|
| self.biased_mask = init_biased_mask(n_head = 4, max_seq_len = 600, period=args.period)
|
| decoder_layer = nn.TransformerDecoderLayer(d_model=args.feature_dim, nhead=args.n_head, dim_feedforward=2*args.feature_dim, batch_first=True)
|
| self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=args.num_layers)
|
|
|
| self.feat_map = nn.Linear(args.feature_dim, args.face_quan_num*args.zquant_dim, bias=False)
|
|
|
| self.learnable_style_emb = nn.Embedding(len(args.train_subjects.split()), args.feature_dim)
|
|
|
| self.device = args.device
|
| nn.init.constant_(self.feat_map.weight, 0)
|
|
|
|
|
| from models.stage1_vocaset import VQAutoEncoder
|
|
|
| self.autoencoder = VQAutoEncoder(args)
|
| temp = torch.load(args.vqvae_pretrained_path)['state_dict']
|
|
|
| self.autoencoder.load_state_dict(torch.load(args.vqvae_pretrained_path)['state_dict'])
|
| for param in self.autoencoder.parameters():
|
| param.requires_grad = False
|
|
|
| def forward(self, audio_name, audio, template, vertice, one_hot, criterion):
|
|
|
|
|
| template = template.unsqueeze(1)
|
|
|
|
|
|
|
| obj_embedding = self.learnable_style_emb(torch.argmax(one_hot, dim=1))
|
| obj_embedding = obj_embedding.unsqueeze(1)
|
|
|
| frame_num = vertice.shape[1]
|
|
|
|
|
| hidden_states = self.audio_encoder(audio, self.dataset, frame_num=frame_num).last_hidden_state
|
| if self.dataset == "BIWI" or self.dataset=="multi":
|
| if hidden_states.shape[1]<frame_num*2:
|
| vertice = vertice[:, :hidden_states.shape[1]//2]
|
| frame_num = hidden_states.shape[1]//2
|
|
|
| hidden_states = self.audio_feature_map(hidden_states)
|
|
|
|
|
| feat_q_gt, _ = self.autoencoder.get_quant(vertice - template)
|
| feat_q_gt = feat_q_gt.permute(0,2,1)
|
|
|
|
|
| vertice_emb = obj_embedding
|
| style_emb = vertice_emb
|
| vertice_input = torch.cat((template,vertice[:,:-1]), 1)
|
| vertice_input = vertice_input - template
|
| vertice_input = self.vertice_map(vertice_input)
|
| vertice_input = vertice_input + style_emb
|
| vertice_input = self.PPE(vertice_input)
|
| tgt_mask = self.biased_mask[:, :vertice_input.shape[1], :vertice_input.shape[1]].clone().detach().to(device=self.device)
|
| memory_mask = enc_dec_mask(self.device, self.dataset, vertice_input.shape[1], hidden_states.shape[1])
|
| feat_out = self.transformer_decoder(vertice_input, hidden_states, tgt_mask=tgt_mask, memory_mask=memory_mask)
|
| feat_out = self.feat_map(feat_out)
|
| feat_out = feat_out.reshape(feat_out.shape[0], feat_out.shape[1]*self.args.face_quan_num, -1)
|
|
|
| feat_out_q, _, _ = self.autoencoder.quantize(feat_out)
|
|
|
| vertice_out = self.autoencoder.decode(feat_out_q)
|
| vertice_out = vertice_out + template
|
|
|
|
|
| loss_motion = criterion(vertice_out, vertice)
|
| loss_reg = criterion(feat_out, feat_q_gt.detach())
|
|
|
| return self.args.motion_weight*loss_motion + self.args.reg_weight*loss_reg, [loss_motion, loss_reg]
|
|
|
|
|
| def predict(self, audio, template, one_hot, one_hot2=None, weight_of_one_hot=None, gt_frame_num=None):
|
| template = template.unsqueeze(1)
|
|
|
|
|
| obj_embedding = self.learnable_style_emb(torch.argmax(one_hot, dim=1))
|
|
|
|
|
| if one_hot2 is not None and weight_of_one_hot is not None:
|
| obj_embedding2 = self.learnable_style_emb(torch.argmax(one_hot2, dim=1))
|
| obj_embedding = obj_embedding * weight_of_one_hot + obj_embedding2 * (1-weight_of_one_hot)
|
| obj_embedding = obj_embedding.unsqueeze(1)
|
|
|
|
|
| if gt_frame_num:
|
| hidden_states = self.audio_encoder(audio, self.dataset, frame_num=gt_frame_num).last_hidden_state
|
| else:
|
| hidden_states = self.audio_encoder(audio, self.dataset).last_hidden_state
|
|
|
| if self.dataset == "BIWI":
|
| frame_num = hidden_states.shape[1]//2
|
| elif self.dataset == "vocaset":
|
| frame_num = hidden_states.shape[1]
|
| elif self.dataset == "multi":
|
| if not gt_frame_num:
|
| frame_num = hidden_states.shape[1]//2
|
| else:
|
| frame_num=gt_frame_num
|
|
|
| hidden_states = self.audio_feature_map(hidden_states)
|
|
|
|
|
| for i in range(frame_num):
|
| if i==0:
|
| vertice_emb = obj_embedding
|
| style_emb = vertice_emb
|
| vertice_input = self.PPE(style_emb)
|
| else:
|
| vertice_input = self.PPE(vertice_emb)
|
| tgt_mask = self.biased_mask[:, :vertice_input.shape[1], :vertice_input.shape[1]].clone().detach().to(device=self.device)
|
| memory_mask = enc_dec_mask(self.device, self.dataset, vertice_input.shape[1], hidden_states.shape[1])
|
| feat_out = self.transformer_decoder(vertice_input, hidden_states, tgt_mask=tgt_mask, memory_mask=memory_mask)
|
| feat_out = self.feat_map(feat_out)
|
|
|
| feat_out = feat_out.reshape(feat_out.shape[0], feat_out.shape[1]*self.args.face_quan_num, -1)
|
|
|
| feat_out_q, _, _ = self.autoencoder.quantize(feat_out)
|
|
|
| if i == 0:
|
| vertice_out_q = self.autoencoder.decode(torch.cat([feat_out_q, feat_out_q], dim=-1))
|
| vertice_out_q = vertice_out_q[:,0].unsqueeze(1)
|
| else:
|
| vertice_out_q = self.autoencoder.decode(feat_out_q)
|
|
|
| if i != frame_num - 1:
|
| new_output = self.vertice_map(vertice_out_q[:,-1,:]).unsqueeze(1)
|
| new_output = new_output + style_emb
|
| vertice_emb = torch.cat((vertice_emb, new_output), 1)
|
|
|
|
|
| feat_out_q, _, _ = self.autoencoder.quantize(feat_out)
|
|
|
| vertice_out = self.autoencoder.decode(feat_out_q)
|
|
|
| vertice_out = vertice_out + template
|
| return vertice_out
|
|
|