import torch.nn as nn
import torch
from opt_einsum import contract as einsum
from models.add_module.model_utils import scatter_add, Dropout
from models.add_module.Attention_module import *
from torch_scatter import scatter
from models.utils import get_time_embedding
import e3nn
import math
import numpy as np  # needed for np.sqrt in Node2Pair_bias
from typing import Optional, Callable, List, Sequence
from openfold.utils.rigid_utils import Rigid


class TensorProductConvLayer(torch.nn.Module):
    def __init__(self, in_irreps, sh_irreps, out_irreps, d_pair, dropout=0.1, fc_factor=1, use_layer_norm=True,
                 init='normal', use_internal_weights=False):
        super().__init__()
        self.in_irreps = in_irreps
        self.out_irreps = out_irreps
        self.sh_irreps = sh_irreps
        self.fc_factor = fc_factor

        self.use_internal_weights = use_internal_weights
        if self.use_internal_weights:
            # Tensor product with its own learned weights; no conditioning on edge features.
            self.tp = e3nn.o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps,
                                                          shared_weights=True, internal_weights=True)
        else:
            # Tensor-product weights are predicted per edge from the pair features.
            self.tp = e3nn.o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=False)
            self.fc = nn.Sequential(
                nn.LayerNorm(d_pair) if use_layer_norm else nn.Identity(),
                nn.Linear(d_pair, self.fc_factor * d_pair),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(self.fc_factor * d_pair, self.tp.weight_numel)
            )

        if init == 'final':
            # Zero-init the weight-predicting MLP so the layer initially outputs zeros.
            with torch.no_grad():
                for para in self.fc.parameters():
                    para.fill_(0)
        elif init != 'normal':
            raise ValueError(f'Unknown init {init}')

    def forward(self, attr1, attr2, edge_attr=None):
        if self.use_internal_weights:
            out = self.tp(attr1, attr2)
        else:
            out = self.tp(attr1, attr2, self.fc(edge_attr))
        return out
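# Illustrative sketch, not part of the model: how TensorProductConvLayer is typically driven.
# The irreps, edge count, and d_pair below are demo assumptions; only the bookkeeping between
# the node irreps, the spherical-harmonic irreps, and the pair-feature width matters.
def _example_tensor_product_conv():
    in_irreps = e3nn.o3.Irreps("16x0e + 4x1o")          # 16 scalar + 4 vector channels per node
    sh_irreps = e3nn.o3.Irreps.spherical_harmonics(2)   # 1x0e + 1x1o + 1x2e (dim 9)
    out_irreps = in_irreps
    d_pair = 32

    layer = TensorProductConvLayer(in_irreps, sh_irreps, out_irreps, d_pair)

    n_edges = 5
    node_feats = torch.randn(n_edges, in_irreps.dim)    # node features gathered per edge
    edge_vec = torch.randn(n_edges, 3)                  # relative coordinates per edge
    edge_sh = e3nn.o3.spherical_harmonics(sh_irreps, edge_vec,
                                          normalize=True, normalization='component')
    edge_attr = torch.randn(n_edges, d_pair)            # pair features that condition the weights

    out = layer(node_feats, edge_sh, edge_attr)         # (n_edges, out_irreps.dim)
    return out.shape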
class E3_GNNlayer(nn.Module):
    def __init__(self, e3_d_node_l0=32, e3_d_node_l1=8, d_rbf=64, d_node=256, d_pair=128,
                 dist_max=20, final=False, init='normal', dropout=0.1):
        super().__init__()

        self.dist_max = dist_max
        self.d_rbf = d_rbf
        self.e3_d_node_l0 = e3_d_node_l0
        self.e3_d_node_l1 = e3_d_node_l1
        self.d_hidden = torch.tensor(e3_d_node_l0 + e3_d_node_l1 * 3)

        # Node features: e3_d_node_l0 scalar (0e) channels + e3_d_node_l1 vector (1o) channels.
        irreps_input = e3nn.o3.Irreps(f"{e3_d_node_l0}x0e + {e3_d_node_l1}x1o")

        self.final = final
        if final:
            # Final layer emits 6 scalar channels, read out as two 3-vectors downstream.
            irreps_hid1 = e3nn.o3.Irreps("6x0e")
        else:
            irreps_hid1 = irreps_input
            self.proj_node = nn.Linear(self.e3_d_node_l0, d_node)

        self.irreps_sh = e3nn.o3.Irreps.spherical_harmonics(2)
        self.proj_l0 = nn.Linear(d_node, self.e3_d_node_l0)

        self.tpc_conv = TensorProductConvLayer(irreps_input, self.irreps_sh, irreps_hid1, d_pair,
                                               init=init, dropout=dropout)

    def forward(self, node, pair, l1_feats, pair_index, edge_src, edge_dst,
                edge_sh, mask=None):
        '''
        node: (B, L, d_node)
        pair: (B, L, L, d_pair)
        l1_feats: (B, L, 3 * self.e3_d_node_l1)
        pair_index: tuple (b, i, j) of edge indices
        edge_src, edge_dst: flattened node indices of shape (E,)
        edge_sh: spherical harmonics of the edge vectors, (E, 9)
        '''
        B, L = node.shape[:2]
        num_nodes = B * L

        # Project scalar node features and concatenate with the l=1 (vector) features.
        l0_feats = self.proj_l0(node)
        node_graph_total = torch.concat([l0_feats, l1_feats], dim=-1)
        node_graph_total = node_graph_total.reshape(num_nodes, -1)

        # Gather pair features for each edge.
        b, i, j = pair_index
        edge_feats = pair[b, i, j]
        edge_feats_total = edge_feats

        # Tensor-product message passing, then mean aggregation onto the receiving nodes.
        conv_out = self.tpc_conv(node_graph_total[edge_dst], edge_sh, edge_feats_total)
        out = scatter(conv_out, edge_src, dim=0, dim_size=num_nodes, reduce="mean")

        out = out.reshape(B, L, -1)

        if self.final:
            return torch.concat([out[:, :, 0:3], out[:, :, 3:6]], dim=-1)
        else:
            l0_feats = out[:, :, :self.e3_d_node_l0]
            l1_feats = out[:, :, self.e3_d_node_l0:] + l1_feats
            node = self.proj_node(l0_feats) + node
            return node, l1_feats
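# Illustrative sketch, not part of the model: one way to assemble the flattened graph inputs
# (pair_index, edge_src, edge_dst, edge_sh) consumed by E3_GNNlayer (E3_transformer takes the
# same inputs). The fully connected edge set, the single reference coordinate per residue, and
# all sizes are demo assumptions.
def _example_e3_gnn_layer():
    B, L, d_node, d_pair = 2, 6, 256, 128
    layer = E3_GNNlayer(e3_d_node_l0=32, e3_d_node_l1=8, d_node=d_node, d_pair=d_pair)

    node = torch.randn(B, L, d_node)
    pair = torch.randn(B, L, L, d_pair)
    l1_feats = torch.zeros(B, L, 3 * layer.e3_d_node_l1)
    coords = torch.randn(B, L, 3)                      # e.g. Ca coordinates

    # Fully connected graph within each batch element, excluding self-edges.
    b, i, j = torch.meshgrid(torch.arange(B), torch.arange(L), torch.arange(L), indexing="ij")
    keep = (i != j).reshape(-1)
    b, i, j = b.reshape(-1)[keep], i.reshape(-1)[keep], j.reshape(-1)[keep]
    pair_index = (b, i, j)
    edge_src = b * L + i                               # node that receives the aggregated message
    edge_dst = b * L + j                               # node whose features are gathered (see forward)

    edge_vec = coords[b, j] - coords[b, i]
    edge_sh = e3nn.o3.spherical_harmonics(layer.irreps_sh, edge_vec,
                                          normalize=True, normalization='component')

    node_out, l1_out = layer(node, pair, l1_feats, pair_index, edge_src, edge_dst, edge_sh)
    return node_out.shape, l1_out.shape                # (B, L, d_node), (B, L, 3 * e3_d_node_l1)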
class E3_transformer(nn.Module):
    def __init__(self, e3_d_node_l0=16, e3_d_node_l1=4, d_rbf=64, d_node=256, d_pair=128,
                 dist_max=20, final=False):
        super().__init__()

        self.dist_max = dist_max
        self.d_rbf = d_rbf
        self.e3_d_node_l0 = e3_d_node_l0
        self.e3_d_node_l1 = e3_d_node_l1
        self.d_hidden = torch.tensor(e3_d_node_l0 + e3_d_node_l1 * 3)

        irreps_input = e3nn.o3.Irreps(f"{e3_d_node_l0}x0e + {e3_d_node_l1}x1o")
        # Hidden irreps also carry even (1e) vectors produced by the tensor product with l=2 harmonics.
        irreps_hid1 = e3nn.o3.Irreps(f"{e3_d_node_l0}x0e + {e3_d_node_l1}x1o + {e3_d_node_l1}x1e")
        # Stored as a float tensor so .sqrt() in forward() is valid.
        self.d_hidden_1 = torch.tensor(e3_d_node_l0 + e3_d_node_l1 * 3 + e3_d_node_l1 * 3,
                                       dtype=torch.float32)

        irreps_query = irreps_hid1
        irreps_key = irreps_hid1
        irreps_value = irreps_hid1
        self.irreps_sh = e3nn.o3.Irreps.spherical_harmonics(2)

        self.proj_l0 = nn.Linear(d_node, self.e3_d_node_l0)

        # Queries come from the destination node alone; keys/values are edge-conditioned tensor products.
        self.h_q = e3nn.o3.Linear(irreps_input, irreps_query)
        self.tpc_k = TensorProductConvLayer(irreps_input, self.irreps_sh, irreps_key, d_pair)
        self.tpc_v = TensorProductConvLayer(irreps_input, self.irreps_sh, irreps_value, d_pair)

        # Invariant attention logits from the query-key tensor product.
        self.dot = e3nn.o3.FullyConnectedTensorProduct(irreps_query, irreps_key, "1x0e")

        self.final = final
        if final:
            irreps_output = e3nn.o3.Irreps("2x1o + 2x1e")
        else:
            irreps_output = irreps_input
            self.proj_node = nn.Linear(self.e3_d_node_l0, d_node)
        self.h_hidden = e3nn.o3.Linear(irreps_value, irreps_value)
        self.h_out = e3nn.o3.Linear(irreps_value, irreps_output)

    def forward(self, node, pair, l1_feats, pair_index, edge_src, edge_dst,
                edge_sh, mask=None):
        '''
        node: (B, L, d_node)
        pair: (B, L, L, d_pair)
        l1_feats: (B, L, 3 * self.e3_d_node_l1)
        pair_index: tuple (b, i, j) of edge indices
        edge_src, edge_dst: flattened node indices of shape (E,)
        edge_sh: spherical harmonics of the edge vectors, (E, 9)
        '''
        B, L = node.shape[:2]
        num_nodes = B * L

        l0_feats = self.proj_l0(node)
        node_graph_total = torch.concat([l0_feats, l1_feats], dim=-1)
        node_graph_total = node_graph_total.reshape(num_nodes, -1)

        b, i, j = pair_index
        edge_feats = pair[b, i, j]
        edge_feats_total = edge_feats

        q = self.h_q(node_graph_total)[edge_dst]
        k = self.tpc_k(node_graph_total[edge_src], edge_sh, edge_feats_total)
        v = self.tpc_v(node_graph_total[edge_src], edge_sh, edge_feats_total)

        # Softmax over the edges incident to each destination node.
        exp = torch.exp(self.dot(q, k) / self.d_hidden_1.sqrt())
        z = scatter(exp, edge_dst, dim=0, dim_size=num_nodes, reduce="sum")
        alpha = exp / (z[edge_dst] + 1e-5)

        out = scatter(alpha * v, edge_dst, dim=0, dim_size=num_nodes, reduce="mean")

        out = self.h_hidden(out)
        out = self.h_out(out)

        out = out.reshape(B, L, -1)

        if self.final:
            # Combine the odd (1o) and even (1e) vector channels into two output 3-vectors.
            return torch.concat([out[:, :, 0:3] + out[:, :, 6:9],
                                 out[:, :, 3:6] + out[:, :, 9:12]], dim=-1)
        else:
            l0_feats = out[:, :, :self.e3_d_node_l0]
            l1_feats = out[:, :, self.e3_d_node_l0:] + l1_feats
            node = self.proj_node(l0_feats) + node
            return node, l1_feats
class E3_transformer_no_adjacency(nn.Module):
    def __init__(self, e3_d_node_l0=16, e3_d_node_l1=4, d_node=256, d_pair=128, no_heads=8, final=False):
        super().__init__()

        self.e3_d_node_l0 = e3_d_node_l0
        self.e3_d_node_l1 = e3_d_node_l1
        self.no_heads = no_heads
        self.inf = 1e5

        irreps_input = e3nn.o3.Irreps(f"{e3_d_node_l0}x0e + {e3_d_node_l1}x1o")
        irreps_hid1 = e3nn.o3.Irreps(f"{e3_d_node_l0}x0e + {e3_d_node_l1}x1o + {e3_d_node_l1}x1e")
        self.d_hid1 = irreps_hid1.dim

        irreps_query = irreps_hid1
        irreps_key = irreps_hid1
        irreps_value = irreps_hid1

        self.proj_l0 = nn.Linear(d_node, self.e3_d_node_l0)

        self.h_q = e3nn.o3.Linear(irreps_input, irreps_query * self.no_heads)
        # Keys/values come from a tensor product of the node features with themselves (no edge geometry).
        self.tpc_k = TensorProductConvLayer(irreps_input, irreps_input, irreps_key * self.no_heads, d_pair,
                                            use_internal_weights=True)
        self.tpc_v = TensorProductConvLayer(irreps_input, irreps_input, irreps_value * self.no_heads, d_pair,
                                            use_internal_weights=True)

        # Per-head attention bias from the pair representation.
        self.linear_pair = nn.Linear(d_pair, self.no_heads)

        self.final = final
        if final:
            irreps_output = e3nn.o3.Irreps("2x1o + 2x1e")
        else:
            irreps_output = irreps_input
            self.proj_node = nn.Linear(self.e3_d_node_l0, d_node)

        self.h_out = e3nn.o3.Linear(irreps_value * self.no_heads, irreps_output)

    def forward(self, node, pair, l1_feats, mask=None):
        '''
        node: (B, L, d_node)
        pair: (B, L, L, d_pair)
        l1_feats: (B, L, 3 * self.e3_d_node_l1)
        mask: (B, L)
        '''
        B, L = node.shape[:2]

        l0_feats = self.proj_l0(node)
        node_graph_total = torch.concat([l0_feats, l1_feats], dim=-1)

        # [B, L, d_hid1, no_heads]
        q = self.h_q(node_graph_total)
        q = torch.split(q, q.shape[-1] // self.no_heads, dim=-1)
        q = torch.stack(q, dim=-1)

        k = self.tpc_k(node_graph_total, node_graph_total)
        k = torch.split(k, k.shape[-1] // self.no_heads, dim=-1)
        k = torch.stack(k, dim=-1)

        v = self.tpc_v(node_graph_total, node_graph_total)
        v = torch.split(v, v.shape[-1] // self.no_heads, dim=-1)
        v = torch.stack(v, dim=-1)

        # Squared-difference attention logits plus a pair-feature bias: [B, L, L, no_heads]
        a = q.unsqueeze(-3) - k.unsqueeze(-4)
        a = torch.sum(a ** 2, dim=-2) / math.sqrt(3 * self.d_hid1)
        a = a + self.linear_pair(pair)

        if mask is not None:
            square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
            square_mask = self.inf * (square_mask - 1)
            a = a + square_mask.unsqueeze(-1)

        a = torch.softmax(a, dim=-2)
        out = torch.matmul(
            a.permute(0, 3, 1, 2),
            v.permute(0, 3, 1, 2),
        )
        out = out.permute(0, 2, 1, 3)
        out = out.reshape(B, L, -1)

        out = self.h_out(out)

        if self.final:
            return torch.concat([out[:, :, 0:3] + out[:, :, 6:9],
                                 out[:, :, 3:6] + out[:, :, 9:12]], dim=-1)
        else:
            l0_feats = out[:, :, :self.e3_d_node_l0]
            l1_feats = out[:, :, self.e3_d_node_l0:] + l1_feats
            node = self.proj_node(l0_feats) + node
            return node, l1_feats
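# Illustrative sketch, not part of the model: E3_transformer_no_adjacency attends densely over
# all residue pairs, so it only needs the node/pair tensors and an optional mask. Sizes are
# demo assumptions.
def _example_e3_transformer_no_adjacency():
    B, L, d_node, d_pair = 2, 8, 256, 128
    layer = E3_transformer_no_adjacency(e3_d_node_l0=16, e3_d_node_l1=4,
                                        d_node=d_node, d_pair=d_pair, no_heads=8)

    node = torch.randn(B, L, d_node)
    pair = torch.randn(B, L, L, d_pair)
    l1_feats = torch.zeros(B, L, 3 * 4)
    mask = torch.ones(B, L)

    node_out, l1_out = layer(node, pair, l1_feats, mask=mask)
    return node_out.shape, l1_out.shape                # (B, L, d_node), (B, L, 12)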
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


def flatten_final_dims(t: torch.Tensor, no_dims: int):
    return t.reshape(t.shape[:-no_dims] + (-1,))
class E3_transformer_test(nn.Module):

    def __init__(self, ipa_conf, inf: float = 1e5, eps: float = 1e-8,
                 e3_d_node_l1=4):
        super().__init__()

        self._ipa_conf = ipa_conf
        self.c_s = ipa_conf.c_s
        self.c_z = ipa_conf.c_z
        self.c_hidden = ipa_conf.c_hidden
        self.no_heads = ipa_conf.no_heads
        self.no_qk_points = ipa_conf.no_qk_points
        self.no_v_points = ipa_conf.no_v_points
        self.inf = inf
        self.eps = eps

        hc = self.c_hidden * self.no_heads
        self.linear_q = nn.Linear(self.c_s, hc)
        self.linear_kv = nn.Linear(self.c_s, 2 * hc)

        hpq = self.no_heads * self.no_qk_points * 3
        self.hpq = hpq // 3
        self.linear_q_points = nn.Linear(self.c_s, hpq)

        hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
        self.hpkv = hpkv // 3
        self.linear_kv_points = nn.Linear(self.c_s, hpkv)

        self.linear_b = nn.Linear(self.c_z, self.no_heads)
        self.down_z = nn.Linear(self.c_z, self.c_z // 4)

        concat_out_dim = (
            self.c_z // 4 + self.c_hidden + self.no_v_points * 4
        )
        self.linear_out = nn.Linear(self.no_heads * concat_out_dim, self.c_s)

        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()

        # When True, rigid-frame transforms are replaced by tensor products with the l=1 node features.
        self.use_e3_qkvo = False

        irreps_3 = e3nn.o3.Irreps(f"{e3_d_node_l1}x1o")
        self.l1_feats_proj = e3nn.o3.Linear(e3nn.o3.Irreps(f"{self.no_heads * self.no_v_points}x1o"),
                                            irreps_3)
        if self.use_e3_qkvo:
            irreps_1 = e3nn.o3.Irreps("3x0e")
            irreps_2 = e3nn.o3.Irreps("1x1o")
            self.tp_q_pts = TensorProductConvLayer(irreps_1, irreps_3, irreps_2, 3, fc_factor=4)
            self.tp_kv_pts = TensorProductConvLayer(irreps_1, irreps_3, irreps_2, 3, fc_factor=4)
            self.tp_o_pt_invert = TensorProductConvLayer(irreps_2, irreps_3, irreps_1, 1, fc_factor=4)

    def forward(
        self,
        s: torch.Tensor,
        z: Optional[torch.Tensor],
        l1_feats,
        r: Rigid,
        mask: torch.Tensor,
        _offload_inference: bool = False,
        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation
            l1_feats:
                [*, N_res, 3 * e3_d_node_l1] l=1 node features
            r:
                [*, N_res] transformation object
            mask:
                [*, N_res] mask
        Returns:
            [*, N_res, C_s] single representation update and updated l1_feats
        """
        if _offload_inference:
            z = _z_reference_list
        else:
            z = [z]

        s_org = s

        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)
        kv = self.linear_kv(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, 2 * C_hidden]
        kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, C_hidden] each
        k, v = torch.split(kv, self.c_hidden, dim=-1)

        # [*, N_res, H * P_q * 3]
        q_pts = self.linear_q_points(s)

        # [*, N_res, H * P_q, 3]
        q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
        q_pts = torch.stack(q_pts, dim=-1)
        if self.use_e3_qkvo:
            q_pts = self.tp_q_pts(q_pts, l1_feats[:, :, None, :].repeat(1, 1, self.hpq, 1), q_pts)
        else:
            q_pts = r[..., None].apply(q_pts)

        # [*, N_res, H, P_q, 3]
        q_pts = q_pts.view(
            q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
        )

        # [*, N_res, H * (P_q + P_v) * 3]
        kv_pts = self.linear_kv_points(s)

        # [*, N_res, H * (P_q + P_v), 3]
        kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
        kv_pts = torch.stack(kv_pts, dim=-1)
        if self.use_e3_qkvo:
            kv_pts = self.tp_kv_pts(kv_pts, l1_feats[:, :, None, :].repeat(1, 1, self.hpkv, 1), kv_pts)
        else:
            kv_pts = r[..., None].apply(kv_pts)

        # [*, N_res, H, P_q + P_v, 3]
        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))

        # [*, N_res, H, P_q, 3] and [*, N_res, H, P_v, 3]
        k_pts, v_pts = torch.split(
            kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
        )

        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])

        if _offload_inference:
            z[0] = z[0].cpu()

        # [*, H, N_res, N_res]
        a = torch.matmul(
            math.sqrt(1.0 / (3 * self.c_hidden)) *
            permute_final_dims(q, (1, 0, 2)),
            permute_final_dims(k, (1, 2, 0)),
        )
        a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))

        # [*, N_res, N_res, H, P_q, 3]
        pt_displacement = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
        pt_att = pt_displacement ** 2

        # [*, N_res, N_res, H, P_q]
        pt_att = sum(torch.unbind(pt_att, dim=-1))

        # [*, N_res, N_res, H]
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)

        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        square_mask = self.inf * (square_mask - 1)

        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))

        a = a + pt_att
        a = a + square_mask.unsqueeze(-3)
        a = self.softmax(a)

        # [*, N_res, H, C_hidden]
        o = torch.matmul(
            a, v.transpose(-2, -3)
        ).transpose(-2, -3)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, H, 3, N_res, P_v]
        o_pt = torch.sum(
            (
                a[..., None, :, :, None]
                * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
            ),
            dim=-2,
        )

        # [*, N_res, H, P_v, 3]
        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))

        # Update the l=1 node features from the attended value points.
        o_pt_l1 = o_pt.reshape(*o_pt.shape[:-3], -1)
        l1_feats = l1_feats + self.l1_feats_proj(o_pt_l1)

        if self.use_e3_qkvo:
            o_pt = self.tp_o_pt_invert(o_pt, l1_feats[:, :, None, None, :].repeat(
                1, 1, self.no_heads, self.no_v_points, 1),
                torch.sum(o_pt ** 2, dim=-1, keepdim=True))
        else:
            o_pt = r[..., None, None].invert_apply(o_pt)

        # [*, N_res, H * P_v]
        o_pt_dists = torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps)
        o_pt_norm_feats = flatten_final_dims(
            o_pt_dists, 2)

        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

        if _offload_inference:
            z[0] = z[0].to(o_pt.device)

        # [*, N_res, H, C_z // 4]
        pair_z = self.down_z(z[0])
        o_pair = torch.matmul(a.transpose(-2, -3), pair_z)

        # [*, N_res, H * C_z // 4]
        o_pair = flatten_final_dims(o_pair, 2)

        o_feats = [o, *torch.unbind(o_pt, dim=-1), o_pt_norm_feats, o_pair]

        # [*, N_res, C_s]
        s = self.linear_out(
            torch.cat(
                o_feats, dim=-1
            )
        )

        s = s + s_org

        return s, l1_feats
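# Illustrative sketch, not part of the model: driving the IPA-style E3_transformer_test with a
# minimal config. The SimpleNamespace stands in for whatever config object the caller uses; the
# field values, sequence length, and identity frames are demo assumptions.
def _example_e3_transformer_test():
    from types import SimpleNamespace

    ipa_conf = SimpleNamespace(c_s=256, c_z=128, c_hidden=16,
                               no_heads=8, no_qk_points=8, no_v_points=12)
    e3_d_node_l1 = 4
    layer = E3_transformer_test(ipa_conf, e3_d_node_l1=e3_d_node_l1)

    B, L = 2, 10
    s = torch.randn(B, L, ipa_conf.c_s)
    z = torch.randn(B, L, L, ipa_conf.c_z)
    l1_feats = torch.zeros(B, L, 3 * e3_d_node_l1)
    r = Rigid.identity((B, L), dtype=torch.float32, device=s.device, fmt="quat")
    mask = torch.ones(B, L)

    s_out, l1_out = layer(s, z, l1_feats, r, mask)
    return s_out.shape, l1_out.shape                   # (B, L, c_s), (B, L, 12)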
class Energy_Adapter_Node(nn.Module):
    def __init__(self, d_node=256, n_head=8, p_drop=0.1, ff_factor=1):
        super().__init__()

        self.d_node = d_node
        self.tfmr_layer = torch.nn.TransformerDecoderLayer(
            d_model=d_node,
            nhead=n_head,
            dim_feedforward=ff_factor * d_node,
            dropout=p_drop,
            batch_first=True,
            norm_first=False
        )

    def forward(self, node, energy, mask=None):
        '''
        node: (B, L, d_node)
        energy: (B,)
        mask: (B, L)
        '''
        # Embed the scalar energy with a sinusoidal embedding and use it as a single "memory" token.
        energy_emb = get_time_embedding(
            energy,
            self.d_node,
            max_positions=2056
        )[:, None, :]

        if mask is not None:
            mask = (1 - mask).bool()
        node = self.tfmr_layer(node, energy_emb, tgt_key_padding_mask=mask)

        return node
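# Illustrative sketch, not part of the model: conditioning the node track on a per-sample scalar
# energy via cross-attention to a single embedded token. The energy values are demo assumptions.
def _example_energy_adapter_node():
    B, L, d_node = 2, 10, 256
    adapter = Energy_Adapter_Node(d_node=d_node, n_head=8)

    node = torch.randn(B, L, d_node)
    energy = torch.tensor([1.5, 20.0])                 # one scalar per batch element
    mask = torch.ones(B, L)

    node_out = adapter(node, energy, mask=mask)        # (B, L, d_node)
    return node_out.shape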
class Node_update(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, n_head=8, p_drop=0.1, ff_factor=1):
        super().__init__()
        self.d_hidden = d_msa // n_head
        self.ff_factor = ff_factor
        self.norm_node = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)

        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.row_attn = RowAttentionWithBias(d_msa=d_msa, d_pair=d_pair,
                                             n_head=n_head, d_hidden=self.d_hidden)

        self.col_attn = ColAttention(d_msa=d_msa, n_head=n_head, d_hidden=self.d_hidden)
        self.ff = FeedForwardLayer(d_msa, self.ff_factor, p_drop=p_drop)

    def forward(self, msa, pair, mask=None):
        '''
        Inputs:
            - msa: node feature (B, L, d_msa)
            - pair: pair feature (B, L, L, d_pair)
        Output:
            - msa: updated node feature (B, L, d_msa)
        '''
        B, L = msa.shape[:2]

        msa = self.norm_node(msa)
        pair = self.norm_pair(pair)

        msa = msa + self.drop_row(self.row_attn(msa, pair, mask=mask))
        msa = msa + self.col_attn(msa, mask=mask)
        msa = msa + self.ff(msa)

        return msa
class Node2Pair_bias(nn.Module):
    def __init__(self, d_node=256, d_hidden=32, d_pair=128, d_head=16):
        super().__init__()
        self.d_hidden = d_hidden
        self.d_head = d_head
        self.norm = nn.LayerNorm(d_node)
        self.proj_left = nn.Linear(d_node, d_hidden)
        self.proj_right = nn.Linear(d_node, d_hidden)
        self.proj_out = nn.Linear(d_head, d_pair)

    def forward(self, node, mask=None):
        B, L = node.shape[:2]
        node = self.norm(node)
        if mask is not None:
            mask_re = mask[..., None].type(torch.float32)
            node = node * mask_re

        left = self.proj_left(node).reshape(B, L, self.d_head, -1)
        right = self.proj_right(node)
        right = (right / np.sqrt(self.d_hidden)).reshape(B, L, self.d_head, -1)
        out = einsum('bihk,bjhk->bijh', left, right)
        out = self.proj_out(out)
        return out
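# Illustrative sketch, not part of the model: Node2Pair_bias turns per-residue node features into
# a pairwise update via a multi-head outer product. Sizes are demo assumptions.
def _example_node2pair_bias():
    B, L, d_node, d_pair = 2, 10, 256, 128
    module = Node2Pair_bias(d_node=d_node, d_hidden=32, d_pair=d_pair, d_head=16)

    node = torch.randn(B, L, d_node)
    mask = torch.ones(B, L)

    out = module(node, mask=mask)                      # (B, L, L, d_pair)
    return out.shape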
class Pair_update(nn.Module):
    def __init__(self, d_node=256, d_pair=128, n_head=8, p_drop=0.1, ff_factor=1):
        super().__init__()
        self.d_hidden = d_pair // n_head
        self.d_pair_bias = d_pair
        self.d_node = d_node
        self.ff_factor = ff_factor

        self.proj_bias = nn.Linear(2 * self.d_node, self.d_pair_bias)

        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.drop_col = Dropout(broadcast_dim=2, p_drop=p_drop)

        self.row_attn = BiasedAxialAttention(d_pair, self.d_pair_bias, n_head, self.d_hidden, p_drop=p_drop, is_row=True)
        self.col_attn = BiasedAxialAttention(d_pair, self.d_pair_bias, n_head, self.d_hidden, p_drop=p_drop, is_row=False)

        self.ff = FeedForwardLayer(d_pair, self.ff_factor, p_drop=p_drop)

    def forward(self, node, pair, mask=None):
        B, L = pair.shape[:2]
        # Pairwise bias from the concatenated node features of residues i and j.
        pair_bias_total = torch.cat([
            torch.tile(node[:, :, None, :], (1, 1, L, 1)),
            torch.tile(node[:, None, :, :], (1, L, 1, 1)),
        ], dim=-1)
        pair_bias_total = self.proj_bias(pair_bias_total)

        if mask is not None:
            mask_re = torch.matmul(mask.unsqueeze(2).type(torch.float32), mask.unsqueeze(1).type(torch.float32))[..., None]
            pair = pair * mask_re
            pair_bias_total = pair_bias_total * mask_re

        pair = pair + self.drop_row(self.row_attn(pair, pair_bias_total, mask))
        pair = pair + self.drop_col(self.col_attn(pair, pair_bias_total, mask))
        pair = pair + self.ff(pair)

        return pair
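# Illustrative sketch, not part of the model: one trunk iteration alternating the node and pair
# updates defined above. Sizes are demo assumptions; the attention submodules come from
# models.add_module.Attention_module as imported at the top of this file.
def _example_node_pair_update():
    B, L, d_node, d_pair = 2, 10, 256, 128
    node_update = Node_update(d_msa=d_node, d_pair=d_pair, n_head=8)
    pair_update = Pair_update(d_node=d_node, d_pair=d_pair, n_head=8)

    node = torch.randn(B, L, d_node)
    pair = torch.randn(B, L, L, d_pair)
    mask = torch.ones(B, L)

    node = node_update(node, pair, mask=mask)          # (B, L, d_node)
    pair = pair_update(node, pair, mask=mask)          # (B, L, L, d_pair)
    return node.shape, pair.shape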
class FeatBlock(nn.Module):
    def __init__(self, d_node=256, d_pair=128, dist_max=40, T=500, d_time=64, L_max=64,
                 d_cent=36, d_rbf=36, p_drop=0.1):
        super().__init__()
        self.dist_max = dist_max
        self.d_rbf = d_rbf
        self.d_time = d_time

        self.seq_emb_layer = nn.Embedding(22, d_node)
        self.node_in_proj = nn.Sequential(
            nn.Linear(d_node + d_node, d_node),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(d_node, d_node)
        )

        self.pair_in_proj = nn.Sequential(
            nn.Linear(d_pair + d_rbf, d_pair),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(d_pair, d_pair)
        )

    def forward(self, node, pair, seq, xyz, t_emb, mask=None):
        '''
        input:
            node: (B, L, d_node)
            pair: (B, L, L, d_pair)
            seq: (B, L) integer sequence tokens
            xyz: (B, L, n_atom, 3); atom index 1 is taken as the Ca atom
            t_emb: timestep embedding (B, d_time)
            mask: (B, L)
        '''
        B, L = xyz.shape[:2]

        # Mix the sequence embedding into the node track.
        seq_emb = self.seq_emb_layer(seq)
        node = torch.concat([node, seq_emb], dim=-1)
        node = self.node_in_proj(node)

        # Mix Ca-Ca RBF distance features into the pair track.
        dist_mat = torch.cdist(xyz[:, :, 1, :], xyz[:, :, 1, :])
        rbf_dist = rbf(dist_mat, D_count=self.d_rbf)
        pair = torch.concat([pair, rbf_dist], dim=-1)
        pair = self.pair_in_proj(pair)

        if mask is not None:
            mask = mask.type(torch.float32)
            node = node * mask[..., None]
            mask_cross = torch.matmul(mask.unsqueeze(2), mask.unsqueeze(1))
            pair = pair * mask_cross[..., None]

        return node, pair, rbf_dist
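# Illustrative sketch, not part of the model: FeatBlock folds sequence embeddings into the node
# track and Ca-Ca RBF distances into the pair track. The (N, CA, C) atom layout and all sizes
# are demo assumptions; t_emb is passed through as required by the signature.
def _example_feat_block():
    B, L, d_node, d_pair, d_time = 2, 10, 256, 128, 64
    block = FeatBlock(d_node=d_node, d_pair=d_pair, d_time=d_time)

    node = torch.randn(B, L, d_node)
    pair = torch.randn(B, L, L, d_pair)
    seq = torch.randint(0, 22, (B, L))                 # token indices for the 22-entry embedding
    xyz = torch.randn(B, L, 3, 3)                      # (B, L, n_atom, 3); index 1 is the Ca atom
    t_emb = torch.randn(B, d_time)
    mask = torch.ones(B, L)

    node_out, pair_out, rbf_dist = block(node, pair, seq, xyz, t_emb, mask=mask)
    return node_out.shape, pair_out.shape, rbf_dist.shape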