| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | import math
|
| | from opt_einsum import contract as einsum
|
| |
|
| |
|
class FeedForwardLayer(nn.Module):
    """Pre-norm transformer feed-forward block.

    Applies LayerNorm, expands the feature dimension by a factor of ``r_ff``,
    activates with ReLU, applies dropout, and projects back to ``d_model``.
    """

    def __init__(self, d_model, r_ff, p_drop=0.1):
        super(FeedForwardLayer, self).__init__()
        d_ff = d_model * r_ff  # width of the expansion layer
        self.norm = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(p_drop)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, src):
        # Pre-LN: normalize, expand, activate in-place, regularize, project back.
        hidden = self.linear1(self.norm(src))
        hidden = self.dropout(F.relu_(hidden))
        return self.linear2(hidden)
|
| |
|
class Attention(nn.Module):
    """Standard multi-head attention over separate query and key/value inputs.

    Projects queries, keys, and values into ``n_head`` heads of width
    ``d_hidden``, applies scaled dot-product attention, and maps the
    concatenated heads to ``d_out`` with a final linear layer.
    """

    def __init__(self, d_query, d_key, n_head, d_hidden, d_out):
        super(Attention, self).__init__()
        self.h = n_head
        self.dim = d_hidden

        self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False)

        self.to_out = nn.Linear(n_head*d_hidden, d_out)
        # 1/sqrt(d) keeps dot-product logits at roughly unit variance.
        self.scaling = 1/math.sqrt(d_hidden)

    def forward(self, query, key, value):
        n_batch, n_query = query.shape[:2]
        n_key = key.shape[1]

        # Project and split into heads: (B, len, h, d).
        q = self.to_q(query).reshape(n_batch, n_query, self.h, self.dim)
        k = self.to_k(key).reshape(n_batch, n_key, self.h, self.dim)
        v = self.to_v(value).reshape(n_batch, n_key, self.h, self.dim)

        q = q * self.scaling
        weights = F.softmax(einsum('bqhd,bkhd->bhqk', q, k), dim=-1)

        # Weighted sum of values, then merge the heads back together.
        merged = einsum('bhqk,bkhd->bqhd', weights, v)
        merged = merged.reshape(n_batch, n_query, self.h*self.dim)
        return self.to_out(merged)
|
| |
|
class AttentionWithBias(nn.Module):
    """Gated multi-head self-attention with an additive per-head bias.

    x:    (B, L, d_in)      token features (queries, keys, values)
    bias: (B, L, L, d_bias) pairwise features projected to one logit per head

    The output is gated channel-wise by a sigmoid of a linear projection of
    the (normalized) input, then mapped back to ``d_in``.
    """

    def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32):
        super(AttentionWithBias, self).__init__()
        self.norm_in = nn.LayerNorm(d_in)
        self.norm_bias = nn.LayerNorm(d_bias)

        self.to_q = nn.Linear(d_in, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_in, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_in, n_head*d_hidden, bias=False)
        self.to_b = nn.Linear(d_bias, n_head, bias=False)
        self.to_g = nn.Linear(d_in, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_in)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

    def forward(self, x, bias):
        n_batch, n_res = x.shape[:2]

        x = self.norm_in(x)
        bias = self.norm_bias(bias)

        head_shape = (n_batch, n_res, self.h, self.dim)
        q = self.to_q(x).reshape(head_shape)
        k = self.to_k(x).reshape(head_shape)
        v = self.to_v(x).reshape(head_shape)
        pair_bias = self.to_b(bias)             # (B, L, L, h)
        gate = torch.sigmoid(self.to_g(x))      # output gate in (0, 1)

        k = k * self.scaling
        logits = einsum('bqhd,bkhd->bqkh', q, k) + pair_bias
        # Layout is (B, q, k, h): normalize over the key axis (dim=-2).
        weights = F.softmax(logits, dim=-2)

        out = einsum('bqkh,bkhd->bqhd', weights, v).reshape(n_batch, n_res, -1)
        return self.to_out(gate * out)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
class RowAttentionWithBias(nn.Module):
    """Gated self-attention over MSA positions, biased by pair features.

    msa:  (B, L, d_msa)     sequence features
    pair: (B, L, L, d_pair) pairwise features projected to one logit per head
    mask: optional (B, L)   1 for valid positions, 0 for excluded ones
    """

    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
        super().__init__()
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)

        self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_b = nn.Linear(d_pair, n_head, bias=False)
        self.to_g = nn.Linear(d_msa, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_msa)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

    def forward(self, msa, pair, mask=None):
        n_batch, n_res = msa.shape[:2]

        msa = self.norm_msa(msa)
        pair = self.norm_pair(pair)

        q = self.to_q(msa).reshape(n_batch, n_res, self.h, self.dim)
        k = self.to_k(msa).reshape(n_batch, n_res, self.h, self.dim)
        v = self.to_v(msa).reshape(n_batch, n_res, self.h, self.dim)
        pair_bias = self.to_b(pair)             # (B, L, L, h)
        gate = torch.sigmoid(self.to_g(msa))    # output gate in (0, 1)

        k = k * self.scaling
        logits = einsum('bqhd,bkhd->bqkh', q, k) + pair_bias

        if mask is not None:
            # Outer product of the 1-D mask yields a (B, L, L, 1) pair mask;
            # invalid pairs get a large negative logit before the softmax.
            pair_mask = torch.matmul(mask.unsqueeze(2).type(torch.float32),
                                     mask.unsqueeze(1).type(torch.float32))[..., None]
            logits = logits * pair_mask - 1e9 * (1 - pair_mask)

        # Layout is (B, q, k, h): normalize over the key axis (dim=-2).
        weights = F.softmax(logits, dim=-2)

        out = einsum('bqkh,bkhd->bqhd', weights, v).reshape(n_batch, n_res, -1)
        return self.to_out(gate * out)
|
| |
|
class ColAttention(nn.Module):
    """Gated self-attention that normalizes over the first sequence axis.

    The softmax runs over dim=-3 of the (B, q, k, h) logits; that axis is then
    relabeled as the contraction index in the output einsum, so the softmax
    normalizes exactly the axis that is summed over.
    """

    def __init__(self, d_msa=256, n_head=8, d_hidden=32):
        super().__init__()
        self.norm_msa = nn.LayerNorm(d_msa)

        self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_g = nn.Linear(d_msa, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_msa)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

    def forward(self, msa, mask=None):
        '''
        msa (B,L,d_node)
        '''
        n_batch, n_res = msa.shape[:2]

        msa = self.norm_msa(msa)

        q = self.to_q(msa).reshape(n_batch, n_res, self.h, self.dim)
        k = self.to_k(msa).reshape(n_batch, n_res, self.h, self.dim)
        v = self.to_v(msa).reshape(n_batch, n_res, self.h, self.dim)
        gate = torch.sigmoid(self.to_g(msa))  # output gate in (0, 1)

        q = q * self.scaling
        logits = einsum('bqhd,bkhd->bqkh', q, k)

        if mask is not None:
            # (B, L, L, 1) pair mask from the outer product of the 1-D mask;
            # invalid pairs get a large negative logit before the softmax.
            pair_mask = torch.matmul(mask.unsqueeze(2).type(torch.float32),
                                     mask.unsqueeze(1).type(torch.float32))[..., None]
            logits = logits * pair_mask - 1e9 * (1 - pair_mask)

        # Normalize over dim=-3 (the 'q' slot of the bqkh layout), which the
        # einsum below treats as the summed key index 'k'.
        weights = F.softmax(logits, dim=-3)

        out = einsum('bkqh,bkhd->bqhd', weights, v).reshape(n_batch, n_res, -1)
        return self.to_out(gate * out)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
class BiasedAxialAttention(nn.Module):
    """Axial self-attention over a pairwise feature map, biased by a second
    pair tensor and gated per channel.

    With ``is_row=True`` the inputs are transposed so the same column-wise
    code path attends along the other axis. Queries are scaled by
    1/sqrt(d_hidden) and keys by 1/sqrt(L) because the logits sum over the L
    positions of the shared axis ('n' in the einsum below).

    NOTE(review): ``p_drop`` is accepted but no dropout module is created;
    kept for interface compatibility.
    """

    def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
        super().__init__()

        self.is_row = is_row
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_bias = nn.LayerNorm(d_bias)

        self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_b = nn.Linear(d_bias, n_head, bias=False)
        self.to_g = nn.Linear(d_pair, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_pair)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

    def forward(self, pair, bias, mask=None):
        '''
        pair: (B, L, L, d_pair)
        bias: (B, L, L, d_bias)
        mask: (B, L); 1 for valid positions, 0 for positions to exclude
        '''

        B, L = pair.shape[:2]

        if self.is_row:
            # Transpose so the attended axis is always the second one.
            pair = pair.permute(0,2,1,3)
            bias = bias.permute(0,2,1,3)

        pair = self.norm_pair(pair)
        bias = self.norm_bias(bias)

        query = self.to_q(pair).reshape(B, L, L, self.h, self.dim)
        key = self.to_k(pair).reshape(B, L, L, self.h, self.dim)
        value = self.to_v(pair).reshape(B, L, L, self.h, self.dim)
        bias = self.to_b(bias)                 # (B, L, L, h): one logit offset per head
        gate = torch.sigmoid(self.to_g(pair))  # output gate in (0, 1)

        query = query * self.scaling
        key = key / math.sqrt(L)  # normalize the sum over the shared axis 'n'
        attn = einsum('bnihk,bnjhk->bijh', query, key)
        attn = attn + bias
        if mask is not None:
            # BUG FIX: this penalty was scaled by 1e-9, which is far too small
            # to suppress a logit, so the mask had no effect. Use -1e9 on
            # masked positions (matching the -1e9 masking used by the other
            # attention modules in this file) so their softmax weight vanishes.
            mask_temp = 1e9 * (mask.type(torch.float) - 1)
            attn = attn + mask_temp.unsqueeze(1).unsqueeze(-1)

        attn = F.softmax(attn, dim=-2)  # normalize over the key index j

        out = einsum('bijh,bkjhd->bikhd', attn, value).reshape(B, L, L, -1)
        out = gate * out

        out = self.to_out(out)
        if self.is_row:
            out = out.permute(0,2,1,3)
        return out
|
| |
|
| |
|