import torch
import torch.nn as nn

from models.lib.wav2vec import Wav2Vec2Model
from models.utils import init_biased_mask, enc_dec_mask, PeriodicPositionalEncoding
from base import BaseModel


class CodeTalker(BaseModel):
    """
    audio: (batch_size, raw_wav)
    template: (batch_size, V*3)
    vertice: (batch_size, seq_len, V*3)
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.dataset = args.dataset
        self.audio_encoder = Wav2Vec2Model.from_pretrained(args.wav2vec2model_path)
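        # The convolutional feature extractor of wav2vec 2.0 is kept frozen below; only
        # the transformer layers of the audio encoder are fine-tuned. Assumption: the
        # checkpoint at args.wav2vec2model_path is a large-variant model with hidden
        # size 1024, matching the input width of audio_feature_map.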
        self.audio_encoder.feature_extractor._freeze_parameters()

        self.audio_feature_map = nn.Linear(1024, args.feature_dim)
        self.vertice_map = nn.Linear(args.vertice_dim, args.feature_dim)
        self.PPE = PeriodicPositionalEncoding(args.feature_dim, period=args.period)

        self.biased_mask = init_biased_mask(n_head=4, max_seq_len=600, period=args.period)
        decoder_layer = nn.TransformerDecoderLayer(d_model=args.feature_dim, nhead=args.n_head,
                                                   dim_feedforward=2*args.feature_dim, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=args.num_layers)
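        # Note: the biased causal mask is built for 4 heads while the decoder uses
        # args.n_head. A 3D tgt_mask must have batch_size * n_head rows, so the code
        # implicitly assumes args.n_head == 4 and a batch size of 1.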

        self.feat_map = nn.Linear(args.feature_dim, args.face_quan_num*args.zquant_dim, bias=False)
        self.learnable_style_emb = nn.Embedding(len(args.train_subjects.split()), args.feature_dim)
        self.device = args.device
        nn.init.constant_(self.feat_map.weight, 0)
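        # Zero-initializing feat_map makes the decoder emit all-zero code features at
        # the start of training, so early predictions collapse to the codebook entries
        # nearest the origin, presumably a stabler starting point than a random projection.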

        from models.stage1_vocaset import VQAutoEncoder
        self.autoencoder = VQAutoEncoder(args)
        checkpoint = torch.load(args.vqvae_pretrained_path)
        self.autoencoder.load_state_dict(checkpoint['state_dict'])
        for param in self.autoencoder.parameters():
            param.requires_grad = False
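        # With the codebook frozen, only the cross-modal transformer decoder above is
        # trained in this stage; the pretrained stage-1 VQ-VAE acts as a fixed discrete
        # prior over facial motion.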

    def forward(self, audio_name, audio, template, vertice, one_hot, criterion):
        template = template.unsqueeze(1)  # (B, 1, V*3)
        # Look up the style embedding of the conditioning training subject.
        obj_embedding = self.learnable_style_emb(torch.argmax(one_hot, dim=1))
        obj_embedding = obj_embedding.unsqueeze(1)  # (B, 1, feature_dim)
        frame_num = vertice.shape[1]

        hidden_states = self.audio_encoder(audio, self.dataset, frame_num=frame_num).last_hidden_state
        if self.dataset == "BIWI" or self.dataset == "multi":
            if hidden_states.shape[1] < frame_num*2:
                vertice = vertice[:, :hidden_states.shape[1]//2]
                frame_num = hidden_states.shape[1]//2
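        # For BIWI/multi the code assumes two audio features per mesh frame (the motion
        # frame rate is half the wav2vec feature rate), so when the audio yields fewer
        # features than expected, the ground-truth motion is trimmed to match.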

        hidden_states = self.audio_feature_map(hidden_states)

        # Ground-truth code features: quantize the motion residual (vertices minus the
        # neutral template) with the frozen stage-1 VQ-VAE; used as regression targets below.
        feat_q_gt, _ = self.autoencoder.get_quant(vertice - template)
        feat_q_gt = feat_q_gt.permute(0, 2, 1)
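
        # Teacher forcing: the decoder input is the ground-truth motion shifted right by
        # one frame, with the template acting as the start token; the style embedding is
        # added to every frame and the periodic positional encoding is applied on top.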
        vertice_emb = obj_embedding
        style_emb = vertice_emb
        vertice_input = torch.cat((template, vertice[:, :-1]), 1)
        vertice_input = vertice_input - template
        vertice_input = self.vertice_map(vertice_input)
        vertice_input = vertice_input + style_emb
        vertice_input = self.PPE(vertice_input)
        tgt_mask = self.biased_mask[:, :vertice_input.shape[1], :vertice_input.shape[1]].clone().detach().to(device=self.device)
        memory_mask = enc_dec_mask(self.device, self.dataset, vertice_input.shape[1], hidden_states.shape[1])
        feat_out = self.transformer_decoder(vertice_input, hidden_states, tgt_mask=tgt_mask, memory_mask=memory_mask)
        feat_out = self.feat_map(feat_out)
        # Split each frame's output into face_quan_num code groups of width zquant_dim.
        feat_out = feat_out.reshape(feat_out.shape[0], feat_out.shape[1]*self.args.face_quan_num, -1)
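
        # The predicted features are snapped to their nearest codebook entries. Assuming
        # autoencoder.quantize is a standard straight-through VQ-VAE quantizer, gradients
        # still flow back to feat_out through this step.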
        feat_out_q, _, _ = self.autoencoder.quantize(feat_out)
        vertice_out = self.autoencoder.decode(feat_out_q)
        vertice_out = vertice_out + template

        # Vertex-space reconstruction loss plus a code-space regularizer that pulls the
        # predicted code features toward the (detached) ground-truth codes.
        loss_motion = criterion(vertice_out, vertice)
        loss_reg = criterion(feat_out, feat_q_gt.detach())

        return self.args.motion_weight*loss_motion + self.args.reg_weight*loss_reg, [loss_motion, loss_reg]

    def predict(self, audio, template, one_hot, one_hot2=None, weight_of_one_hot=None, gt_frame_num=None):
        template = template.unsqueeze(1)  # (B, 1, V*3)
        obj_embedding = self.learnable_style_emb(torch.argmax(one_hot, dim=1))

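        # Optional style interpolation: when a second subject one-hot and a blend weight
        # are supplied, the two learned style embeddings are mixed linearly, allowing
        # synthesis of intermediate talking styles.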
        if one_hot2 is not None and weight_of_one_hot is not None:
            obj_embedding2 = self.learnable_style_emb(torch.argmax(one_hot2, dim=1))
            obj_embedding = obj_embedding * weight_of_one_hot + obj_embedding2 * (1 - weight_of_one_hot)
        obj_embedding = obj_embedding.unsqueeze(1)

        if gt_frame_num is not None:
            hidden_states = self.audio_encoder(audio, self.dataset, frame_num=gt_frame_num).last_hidden_state
        else:
            hidden_states = self.audio_encoder(audio, self.dataset).last_hidden_state

        if self.dataset == "BIWI":
            frame_num = hidden_states.shape[1]//2
        elif self.dataset == "vocaset":
            frame_num = hidden_states.shape[1]
        elif self.dataset == "multi":
            frame_num = gt_frame_num if gt_frame_num is not None else hidden_states.shape[1]//2
        else:
            raise ValueError(f"Unsupported dataset: {self.dataset}")

        hidden_states = self.audio_feature_map(hidden_states)
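
        # Autoregressive generation: frame by frame, the decoder consumes its own
        # previously generated motion (mapped back to feature space) in place of the
        # ground truth used during training.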
        for i in range(frame_num):
            if i == 0:
                vertice_emb = obj_embedding  # the style embedding doubles as the start token
                style_emb = vertice_emb
                vertice_input = self.PPE(style_emb)
            else:
                vertice_input = self.PPE(vertice_emb)
            tgt_mask = self.biased_mask[:, :vertice_input.shape[1], :vertice_input.shape[1]].clone().detach().to(device=self.device)
            memory_mask = enc_dec_mask(self.device, self.dataset, vertice_input.shape[1], hidden_states.shape[1])
            feat_out = self.transformer_decoder(vertice_input, hidden_states, tgt_mask=tgt_mask, memory_mask=memory_mask)
            feat_out = self.feat_map(feat_out)
            feat_out = feat_out.reshape(feat_out.shape[0], feat_out.shape[1]*self.args.face_quan_num, -1)

            feat_out_q, _, _ = self.autoencoder.quantize(feat_out)
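
            # At i == 0 there is only a single code group; it is duplicated along the
            # last axis (which appears to be the decoder's time axis here) so the frozen
            # decoder receives a long enough sequence, and only the first frame is kept.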
            if i == 0:
                vertice_out_q = self.autoencoder.decode(torch.cat([feat_out_q, feat_out_q], dim=-1))
                vertice_out_q = vertice_out_q[:, 0].unsqueeze(1)
            else:
                vertice_out_q = self.autoencoder.decode(feat_out_q)

            if i != frame_num - 1:
                # Feed the newest decoded frame back as the next decoder input.
                new_output = self.vertice_map(vertice_out_q[:, -1, :]).unsqueeze(1)
                new_output = new_output + style_emb
                vertice_emb = torch.cat((vertice_emb, new_output), 1)

        # Decode the complete sequence in one pass from the final feat_out (identical to
        # the last in-loop result) and restore absolute vertex positions.
        feat_out_q, _, _ = self.autoencoder.quantize(feat_out)
        vertice_out = self.autoencoder.decode(feat_out_q)
        vertice_out = vertice_out + template
        return vertice_out