import math

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from torch.nn import functional as F

# Only these two helpers are used from the wildcard import in the original.
from utils.base_model_util import get_activation, get_shape_list


class Norm(nn.Module):
    """ Norm Layer: applies pre-norm (LayerNorm) to the input of the wrapped module. """

    def __init__(self, fn, size):
        super().__init__()
        self.norm = nn.LayerNorm(size, eps=1e-5)
        self.fn = fn

    def forward(self, x_data):
        if isinstance(x_data, dict):
            # Cross-modal path: normalize only the 'x_b' stream, which the
            # wrapped layer transforms.
            return self.fn({'x_a': x_data['x_a'], 'x_b': self.norm(x_data['x_b'])})
        else:
            # Uni-modal path: x_data is a (tensor, mask_info) tuple.
            x, mask_info = x_data
            x_norm, _ = self.fn((self.norm(x), mask_info))
            return (x_norm, mask_info)


class Residual(nn.Module):
    """ Residual Layer: adds the wrapped module's output back to its input. """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x_data):
        if isinstance(x_data, dict):
            # Cross-modal path: the skip connection applies to modality b only.
            x_resid = self.fn(x_data)['x_b']
            return {'x_a': x_data['x_a'], 'x_b': x_resid + x_data['x_b']}
        else:
            # Uni-modal path: x_data is a (tensor, mask_info) tuple.
            x, mask_info = x_data
            x_resid, _ = self.fn(x_data)
            return (x_resid + x, mask_info)


class MLP(nn.Module):
    """ MLP Layer: two linear projections with a GELU activation in between. """

    def __init__(self, in_dim, out_dim, hidden_dim):
        super().__init__()
        self.l1 = nn.Linear(in_dim, hidden_dim)
        self.activation = get_activation("gelu")
        self.l2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x_data):
        if isinstance(x_data, dict):
            out = self.l2(self.activation(self.l1(x_data['x_b'])))
            return {'x_a': x_data['x_a'], 'x_b': out}
        else:
            x, mask_info = x_data
            out = self.l2(self.activation(self.l1(x)))
            return (out, mask_info)
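

# Minimal usage sketch (not part of the original pipeline; shapes are assumed):
# Norm, Residual, and MLP compose into the pre-norm residual block that the
# Transformer below stacks repeatedly.
def _example_prenorm_block():
    block = Residual(Norm(MLP(64, 64, 256), 64))
    x = torch.randn(2, 10, 64)                       # (batch, seq, feature)
    out, _ = block((x, {'mask': None, 'max_mask': None}))
    assert out.shape == x.shape                      # residual add preserves shape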


class CrossModalAttention(nn.Module):
    """ Cross Modal Attention Layer

    Given two modalities (a, b), computes K,V from modality b and Q from
    modality a.
    """

    def __init__(self, in_dim, dim, heads=8, in_dim2=None):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5

        if in_dim2 is not None:
            self.to_kv = nn.Linear(in_dim2, in_dim2 * 2, bias=False)
        else:
            self.to_kv = nn.Linear(in_dim, dim * 2, bias=False)
        self.to_q = nn.Linear(in_dim, dim, bias=False)
        # q has width `dim` and kv has width `2 * in_dim2`; the rearrange below
        # splits their concatenation into three equal chunks, so the output
        # projection must take one third of the total width.
        if in_dim2 is not None:
            dim2 = (dim + in_dim2 * 2) // 3
        else:
            dim2 = dim
        self.to_out = nn.Linear(dim2, dim)

        self.rearrange_qkv = Rearrange(
            "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
        self.rearrange_out = Rearrange("b h n d -> b n (h d)")

    def forward(self, x_data):
        x_a = x_data['x_a']
        x_b = x_data['x_b']

        kv = self.to_kv(x_b)
        q = self.to_q(x_a)

        qkv = torch.cat((q, kv), dim=-1)
        qkv = self.rearrange_qkv(qkv)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention over the sequence dimension.
        dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        attn = F.softmax(dots, dim=-1)

        out = torch.einsum("bhij,bhjd->bhid", attn, v)
        out = self.rearrange_out(out)
        out = self.to_out(out)
        return {'x_a': x_a, 'x_b': out}
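

# Minimal usage sketch (assumed shapes): with in_dim2 == dim, the concatenated
# q/kv tensor splits cleanly into equal q, k, v chunks across the heads.
def _example_cross_modal_attention():
    attn = CrossModalAttention(in_dim=64, dim=64, heads=8, in_dim2=64)
    x_a = torch.randn(2, 10, 64)    # queries come from modality a
    x_b = torch.randn(2, 10, 64)    # keys/values come from modality b
    out = attn({'x_a': x_a, 'x_b': x_b})
    assert out['x_b'].shape == (2, 10, 64)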


class Attention(nn.Module):
    """ Attention Layer: multi-head self-attention with an optional binary
    mask over the first `max_mask` positions. """

    def __init__(self, in_dim, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(in_dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

        self.rearrange_qkv = Rearrange(
            "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
        self.rearrange_out = Rearrange("b h n d -> b n (h d)")

    def forward(self, x_data):
        x, mask_info = x_data
        max_mask = mask_info['max_mask']
        mask = mask_info['mask']

        qkv = self.to_qkv(x)
        qkv = self.rearrange_qkv(qkv)
        q, k, v = qkv[0], qkv[1], qkv[2]

        dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        if max_mask is not None:
            # Zero entries in the mask are blocked from attending.
            dots[:, :, :max_mask, :max_mask] = \
                dots[:, :, :max_mask, :max_mask].masked_fill(mask == 0., float('-inf'))

        attn = F.softmax(dots, dim=-1)

        out = torch.einsum("bhij,bhjd->bhid", attn, v)
        out = self.rearrange_out(out)
        out = self.to_out(out)
        return (out, mask_info)
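

# Minimal usage sketch (assumed shapes): a causal mask over the first
# `max_mask` positions, broadcast across batch and heads.
def _example_masked_attention():
    attn = Attention(in_dim=64, dim=64, heads=8)
    x = torch.randn(2, 10, 64)
    causal = torch.tril(torch.ones(1, 1, 10, 10))    # lower-triangular mask
    out, _ = attn((x, {'mask': causal, 'max_mask': 10}))
    assert out.shape == (2, 10, 64)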


class Transformer(nn.Module):
    """ Transformer class

    Parameters
    ----------
    cross_modal : bool
        if True, uses cross-modal attention layers; otherwise this is a
        vanilla Transformer
    in_dim2 : int
        the feature size of the second modality when cross_modal is set
    """

    def __init__(self,
                 in_size=50,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 cross_modal=False,
                 in_dim2=None):
        super().__init__()
        blocks = []

        self.cross_modal = cross_modal
        if cross_modal:
            for i in range(num_hidden_layers):
                blocks.extend([
                    Residual(Norm(CrossModalAttention(in_size, hidden_size,
                                                      heads=num_attention_heads,
                                                      in_dim2=in_dim2), hidden_size)),
                    Residual(Norm(MLP(hidden_size, hidden_size, intermediate_size),
                                  hidden_size))
                ])
        else:
            for i in range(num_hidden_layers):
                blocks.extend([
                    Residual(Norm(Attention(in_size, hidden_size,
                                            heads=num_attention_heads), hidden_size)),
                    Residual(Norm(MLP(hidden_size, hidden_size, intermediate_size),
                                  hidden_size))
                ])
        self.net = torch.nn.Sequential(*blocks)

    def forward(self, x_data):
        if self.cross_modal:
            assert isinstance(x_data, dict)
            x_data = self.net(x_data)
            x = x_data['x_b']
        else:
            x, mask_info = x_data
            x, _ = self.net((x, mask_info))
        return x
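

# Minimal usage sketch (assumed sizes): a small uni-modal Transformer. Note
# that in_size must match hidden_size for the residual connections to line up,
# and mask_info with None entries disables masking.
def _example_transformer():
    net = Transformer(in_size=64, hidden_size=64, num_hidden_layers=2,
                      num_attention_heads=8, intermediate_size=256)
    x = torch.randn(2, 10, 64)
    out = net((x, {'mask': None, 'max_mask': None}))
    assert out.shape == (2, 10, 64)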


class LinearEmbedding(nn.Module):
    """ Linear Layer """

    def __init__(self, size, dim):
        super().__init__()
        self.net = nn.Linear(size, dim)

    def forward(self, x):
        return self.net(x)


class AudioEmbedding(nn.Module):
    """ Audio embedding layer

    Parameters
    ----------
    size : int
        the input feature size of the audio embedding
    dim : int
        the desired output feature size for the audio embedding
    quant_factor : int
        the number of max-pool layers applied along the temporal dimension
    version : str (default 'v6')
        which version of the audio embedding to use
    """

    def __init__(self, size, dim, quant_factor, version='v6'):
        super().__init__()
        self.proj = None
        if version == 'v6':
            print('MODEL V6')
            # Downsample time by 4, then by 2 for each of `quant_factor` pools.
            self.net = nn.MaxPool1d(4)
            layers = [nn.Sequential(nn.MaxPool1d(2))]
            for _ in range(1, quant_factor):
                layers += [nn.Sequential(nn.MaxPool1d(2))]
            self.squasher = nn.Sequential(*layers)
            self.proj = nn.Linear(size, dim)

    def forward(self, x):
        x = self.net(x)
        x = self.squasher(x)
        if self.proj is not None:
            # (B, C, T) -> (B, T, C) for the linear projection, then back.
            x = self.proj(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x
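

# Minimal usage sketch (assumed shapes): with quant_factor=2 the temporal
# dimension shrinks by 4 * 2**2 = 16, while `size` channels project to `dim`.
def _example_audio_embedding():
    embed = AudioEmbedding(size=128, dim=64, quant_factor=2)
    audio = torch.randn(2, 128, 64)         # (batch, features, time)
    out = embed(audio)
    assert out.shape == (2, 64, 4)          # time: 64 / 16 = 4; features: 128 -> 64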


class PositionEmbedding(nn.Module):
    """ Position Embedding Layer: learned additive positional embeddings. """

    def __init__(self, seq_length, dim):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.zeros(seq_length, dim))

    def forward(self, x):
        return x + self.pos_embedding


class PositionalEncoding(nn.Module):
    """ Sinusoidal positional encoding (Vaswani et al. 2017), added to the
    input and followed by dropout. Expects (seq, batch, feature) input. """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
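

# Minimal usage sketch: unlike the learned PositionEmbedding above, this
# encoding is fixed and sequence-first.
def _example_positional_encoding():
    pos_enc = PositionalEncoding(d_model=64, dropout=0.0, max_len=100)
    x = torch.zeros(10, 2, 64)              # (seq, batch, feature)
    out = pos_enc(x)
    assert out.shape == (10, 2, 64)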


class CrossModalLayer(nn.Module):
    """ Cross Modal Layer inspired by FACT [Li 2021] """

    def __init__(self, config):
        super().__init__()
        self.config = config
        model_config = self.config['transformer']
        self.transformer_layer = Transformer(
            in_size=model_config['hidden_size'],
            hidden_size=model_config['hidden_size'],
            num_hidden_layers=model_config['num_hidden_layers'],
            num_attention_heads=model_config['num_attention_heads'],
            intermediate_size=model_config['intermediate_size'])

        output_layer_config = self.config['output_layer']
        self.cross_norm_layer = nn.LayerNorm(self.config['in_dim'])
        self.cross_output_layer = nn.Linear(
            self.config['in_dim'],
            output_layer_config['out_dim'],
            bias=False)

        self.cross_pos_embedding = PositionEmbedding(
            self.config['sequence_length'], self.config['in_dim'])

    def forward(self, modal_a_sequences, modal_b_sequences, mask_info):
        """
        Parameters
        ----------
        modal_a_sequences : tensor
            the first modality (e.g. listener motion embedding)
        modal_b_sequences : tensor
            the second modality (e.g. speaker motion+audio embedding)
        mask_info : dict
            the binary mask applied to the Transformer attention
        """
        _, _, modal_a_width = get_shape_list(modal_a_sequences)
        merged_sequences = modal_a_sequences
        if modal_b_sequences is not None:
            _, _, modal_b_width = get_shape_list(modal_b_sequences)
            if modal_a_width != modal_b_width:
                raise ValueError(
                    "The modal_a hidden size (%d) should be the same as the "
                    "modal_b hidden size (%d)" % (modal_a_width, modal_b_width))
            # Concatenate the two modalities along the temporal dimension.
            merged_sequences = torch.cat([merged_sequences, modal_b_sequences],
                                         dim=1)

        merged_sequences = self.cross_pos_embedding(merged_sequences)
        merged_sequences = self.transformer_layer((merged_sequences, mask_info))
        merged_sequences = self.cross_norm_layer(merged_sequences)
        logits = self.cross_output_layer(merged_sequences)
        return logits
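

# Minimal usage sketch (hypothetical config values, derived from the keys the
# class reads): both modalities are concatenated in time, so sequence_length
# must cover the merged sequence.
def _example_cross_modal_layer():
    config = {
        'transformer': {'hidden_size': 64, 'num_hidden_layers': 2,
                        'num_attention_heads': 8, 'intermediate_size': 256},
        'output_layer': {'out_dim': 32},
        'in_dim': 64,
        'sequence_length': 20,              # 10 (modality a) + 10 (modality b)
    }
    layer = CrossModalLayer(config)
    a = torch.randn(2, 10, 64)              # e.g. listener motion embedding
    b = torch.randn(2, 10, 64)              # e.g. speaker motion+audio embedding
    logits = layer(a, b, {'mask': None, 'max_mask': None})
    assert logits.shape == (2, 20, 32)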