# P2DFlow / models / add_module / structure_module.py
# Holmes
# test
# ca7299e
import torch.nn as nn
import torch
from opt_einsum import contract as einsum
from models.add_module.model_utils import scatter_add, Dropout
from models.add_module.Attention_module import *
from torch_scatter import scatter
from models.utils import get_time_embedding
import e3nn
import math
from typing import Optional, Callable, List, Sequence
from openfold.utils.rigid_utils import Rigid
class TensorProductConvLayer(torch.nn.Module):
    """Fully connected e3nn tensor product with internal or input-conditioned weights.

    With ``use_internal_weights=True`` the tensor product owns shared, learned
    weights.  Otherwise its weights are predicted for every sample by a small
    MLP (``self.fc``) from the ``edge_attr`` handed to :meth:`forward`.

    NOTE(review): the reviewed copy of this file had lost its indentation.
    ``self.fc`` and the ``init`` handling are reconstructed here as part of the
    external-weights branch only -- confirm against the repository.

    Args:
        in_irreps, sh_irreps, out_irreps: irreps of the two inputs and of the
            output, forwarded to ``e3nn.o3.FullyConnectedTensorProduct``.
        d_pair: feature width of ``edge_attr`` (input width of the MLP).
        dropout: dropout probability inside the MLP.
        fc_factor: the MLP hidden width is ``fc_factor * d_pair``.
        use_layer_norm: apply ``LayerNorm(d_pair)`` before the MLP, else identity.
        init: ``'normal'`` keeps the default initialisation; ``'final'`` sets
            every parameter of the MLP to zero.
        use_internal_weights: see above.

    Raises:
        ValueError: if ``init`` is neither ``'normal'`` nor ``'final'``.
    """

    def __init__(self, in_irreps, sh_irreps, out_irreps, d_pair, dropout=0.1, fc_factor=1, use_layer_norm=True,
                 init='normal', use_internal_weights=False):
        super().__init__()
        self.in_irreps = in_irreps
        self.out_irreps = out_irreps
        self.sh_irreps = sh_irreps
        self.fc_factor = fc_factor
        self.use_internal_weights = use_internal_weights

        if use_internal_weights:
            self.tp = e3nn.o3.FullyConnectedTensorProduct(
                in_irreps, sh_irreps, out_irreps,
                shared_weights=True, internal_weights=True,
            )
        else:
            self.tp = e3nn.o3.FullyConnectedTensorProduct(
                in_irreps, sh_irreps, out_irreps, shared_weights=False,
            )

            # MLP mapping edge attributes to one weight vector per sample.
            width = self.fc_factor * d_pair
            pre_norm = nn.LayerNorm(d_pair) if use_layer_norm else nn.Identity()
            self.fc = nn.Sequential(
                pre_norm,
                nn.Linear(d_pair, width),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(width, self.tp.weight_numel),
            )

            if init == 'final':
                # Zero every MLP parameter (including the LayerNorm ones).
                for weight in self.fc.parameters():
                    nn.init.zeros_(weight)
            elif init != 'normal':
                raise ValueError(f'Unknown init {init}')

    def forward(self, attr1, attr2, edge_attr=None):
        """Apply the tensor product to ``attr1`` and ``attr2``.

        ``edge_attr`` is only read when the layer does not use internal
        weights; in that case it is fed through ``self.fc`` to obtain the
        tensor-product weights.
        """
        if self.use_internal_weights:
            return self.tp(attr1, attr2)
        return self.tp(attr1, attr2, self.fc(edge_attr))
class E3_GNNlayer(nn.Module):
    """Single tensor-product message-passing layer over a pre-built edge list.

    Node scalars are projected to ``e3_d_node_l0`` l=0 channels, concatenated
    with the incoming l=1 features, convolved along the edges with the edge
    spherical harmonics and mean-pooled back onto the nodes.

    With ``final=True`` the convolution outputs ``6x0e`` and ``forward``
    returns those six channels; otherwise it returns residually updated
    ``(node, l1_feats)``.

    ``dist_max``, ``d_rbf`` and ``d_hidden`` are stored but not read inside
    this class.

    NOTE(review): the reviewed copy of this file had lost its indentation;
    ``proj_node`` is reconstructed as being created only when ``final`` is
    false (the only branch of ``forward`` that uses it) -- confirm.
    """

    def __init__(self, e3_d_node_l0=32, e3_d_node_l1=8, d_rbf=64, d_node=256, d_pair=128,
                 dist_max=20, final=False, init='normal', dropout=0.1):
        super().__init__()
        self.dist_max = dist_max
        self.d_rbf = d_rbf
        self.e3_d_node_l0 = e3_d_node_l0
        self.e3_d_node_l1 = e3_d_node_l1
        self.d_hidden = torch.tensor(e3_d_node_l0 + e3_d_node_l1 * 3)
        self.final = final

        irreps_input = e3nn.o3.Irreps(f"{e3_d_node_l0}x0e + {e3_d_node_l1}x1o")
        if final:
            irreps_conv_out = e3nn.o3.Irreps("6x0e")
        else:
            irreps_conv_out = irreps_input
            self.proj_node = nn.Linear(self.e3_d_node_l0, d_node)

        self.irreps_sh = e3nn.o3.Irreps.spherical_harmonics(2)
        self.proj_l0 = nn.Linear(d_node, self.e3_d_node_l0)
        self.tpc_conv = TensorProductConvLayer(
            irreps_input, self.irreps_sh, irreps_conv_out, d_pair,
            init=init, dropout=dropout,
        )

    def forward(self, node, pair, l1_feats, pair_index, edge_src, edge_dst,
                edge_sh, mask=None):
        """Run one round of message passing.

        Args:
            node: node features; the first two dims are read as (B, L).
            pair: pair features, indexed with ``pair_index`` to get edge features.
            l1_feats: l=1 node features, expected as (B, L, 3 * e3_d_node_l1).
            pair_index: unpacked into three index objects ``(b, i, j)`` into ``pair``.
            edge_src, edge_dst: flat node indices (into B * L) for every edge.
            edge_sh: per-edge input to the tensor product (irreps ``self.irreps_sh``).
            mask: accepted but not used by this method.

        Returns:
            ``final=True``: tensor (B, L, 6).
            ``final=False``: tuple ``(node, l1_feats)`` with residual updates.
        """
        batch, length = node.shape[:2]
        n_nodes = batch * length

        # Flattened graph features: (B * L, e3_d_node_l0 + 3 * e3_d_node_l1)
        scalars = self.proj_l0(node)
        flat_feats = torch.cat([scalars, l1_feats], dim=-1).reshape(n_nodes, -1)

        b_idx, i_idx, j_idx = pair_index
        edge_attr = pair[b_idx, i_idx, j_idx]

        # Features are gathered at edge_dst and the messages pooled at edge_src.
        messages = self.tpc_conv(flat_feats[edge_dst], edge_sh, edge_attr)
        pooled = scatter(messages, edge_src, dim=0, dim_size=n_nodes, reduce="mean")
        pooled = pooled.reshape(batch, length, -1)

        if self.final:
            return torch.cat((pooled[..., 0:3], pooled[..., 3:6]), dim=-1)  # (B, L, 6)

        split_at = self.e3_d_node_l0
        updated_l1 = pooled[..., split_at:] + l1_feats
        updated_node = self.proj_node(pooled[..., :split_at]) + node
        return updated_node, updated_l1
class E3_transformer(nn.Module):
    """Edge-wise attention layer built from e3nn tensor products.

    Queries come from an equivariant linear map of the node features; keys and
    values come from ``TensorProductConvLayer`` modules conditioned on the
    edge features.  Attention weights are normalised per destination node.

    With ``final=True`` the output irreps are ``2x1o + 2x1e`` and ``forward``
    returns a (B, L, 6) tensor; otherwise it returns residually updated
    ``(node, l1_feats)``.

    ``dist_max``, ``d_rbf`` and ``d_hidden`` are stored but not read inside
    this class.

    NOTE(review): the reviewed copy of this file had lost its indentation;
    ``proj_node`` is reconstructed as being created only when ``final`` is
    false (the only branch of ``forward`` that uses it) -- confirm.
    """

    def __init__(self, e3_d_node_l0=16, e3_d_node_l1=4, d_rbf=64, d_node=256, d_pair=128,
                 dist_max=20, final=False):
        super().__init__()
        self.dist_max = dist_max
        self.d_rbf = d_rbf
        self.e3_d_node_l0 = e3_d_node_l0
        self.e3_d_node_l1 = e3_d_node_l1
        self.d_hidden = torch.tensor(e3_d_node_l0 + e3_d_node_l1 * 3)
        # Width of the hidden irreps; its square root scales the attention logits.
        self.d_hidden_1 = torch.tensor(e3_d_node_l0 + e3_d_node_l1 * 3 + e3_d_node_l1 * 3)
        self.final = final

        irreps_input = e3nn.o3.Irreps(f"{e3_d_node_l0}x0e + {e3_d_node_l1}x1o")
        irreps_hidden = e3nn.o3.Irreps(
            f"{e3_d_node_l0}x0e + {e3_d_node_l1}x1o + {e3_d_node_l1}x1e"
        )
        self.irreps_sh = e3nn.o3.Irreps.spherical_harmonics(2)

        self.proj_l0 = nn.Linear(d_node, self.e3_d_node_l0)
        self.h_q = e3nn.o3.Linear(irreps_input, irreps_hidden)
        self.tpc_k = TensorProductConvLayer(irreps_input, self.irreps_sh, irreps_hidden, d_pair)
        self.tpc_v = TensorProductConvLayer(irreps_input, self.irreps_sh, irreps_hidden, d_pair)
        # Scalar ("1x0e") score from a query/key pair.
        self.dot = e3nn.o3.FullyConnectedTensorProduct(irreps_hidden, irreps_hidden, "1x0e")

        if final:
            irreps_output = e3nn.o3.Irreps("2x1o + 2x1e")
        else:
            irreps_output = irreps_input
            self.proj_node = nn.Linear(self.e3_d_node_l0, d_node)

        self.h_hidden = e3nn.o3.Linear(irreps_hidden, irreps_hidden)
        self.h_out = e3nn.o3.Linear(irreps_hidden, irreps_output)

    def forward(self, node, pair, l1_feats, pair_index, edge_src, edge_dst,
                edge_sh, mask=None):
        """Attend over the given edges and update the node features.

        Args:
            node: node features; the first two dims are read as (B, L).
            pair: pair features, indexed with ``pair_index`` to get edge features.
            l1_feats: l=1 node features, expected as (B, L, 3 * e3_d_node_l1).
            pair_index: unpacked into three index objects ``(b, i, j)`` into ``pair``.
            edge_src, edge_dst: flat node indices (into B * L) for every edge.
            edge_sh: per-edge input to the key/value tensor products.
            mask: accepted but not used by this method.

        Returns:
            ``final=True``: tensor (B, L, 6).
            ``final=False``: tuple ``(node, l1_feats)`` with residual updates.
        """
        batch, length = node.shape[:2]
        n_nodes = batch * length

        scalars = self.proj_l0(node)
        flat_feats = torch.cat([scalars, l1_feats], dim=-1).reshape(n_nodes, -1)

        b_idx, i_idx, j_idx = pair_index
        edge_attr = pair[b_idx, i_idx, j_idx]

        # Queries live on destination nodes; keys and values on source nodes.
        query = self.h_q(flat_feats)[edge_dst]
        source_feats = flat_feats[edge_src]
        key = self.tpc_k(source_feats, edge_sh, edge_attr)
        value = self.tpc_v(source_feats, edge_sh, edge_attr)

        # Softmax over the incoming edges of each destination node.
        # NOTE(review): the exponential is taken without subtracting a
        # per-node maximum, so very large scores could overflow.
        logits = self.dot(query, key) / self.d_hidden_1.sqrt()
        numer = torch.exp(logits)  # (num_edges, 1)
        denom = scatter(numer, edge_dst, dim=0, dim_size=n_nodes, reduce="sum")
        attn = numer / (denom[edge_dst] + 1e-5)

        # NOTE(review): the already normalised messages are pooled with
        # reduce="mean" (not "sum") -- confirm this is intended.
        pooled = scatter(attn * value, edge_dst, dim=0, dim_size=n_nodes, reduce="mean")
        out = self.h_out(self.h_hidden(pooled)).reshape(batch, length, -1)

        if self.final:
            # Sum the two halves of the 12 output channels -> (B, L, 6).
            return out[..., 0:6] + out[..., 6:12]

        split_at = self.e3_d_node_l0
        updated_l1 = out[..., split_at:] + l1_feats
        updated_node = self.proj_node(out[..., :split_at]) + node
        return updated_node, updated_l1
# # /self.d_hidden_1.sqrt()
# v = self.tp_hidden(v, edge_sh, self.fc_hidden(edge_feats_total))
# v = self.tp_out(v, edge_sh, self.fc_out(edge_feats_total))
# value_out = scatter_add(alpha * v, edge_dst, dim_index=0, num_nodes=num_nodes) # (nodes, irreps_output)
# # print("before_out_max=",torch.max(value_out))
# out = self.h_hidden(value_out + self.res_connect(node_graph_total))
# out = self.h_out(out)
# out = out.reshape(B,L,-1)
#return out[:,:,6:9] + out[:,:,9:], out[:,:,:3] + out[:,:,3:6]
class E3_transformer_no_adjacency(nn.Module):
    """Dense multi-head attention over all residue pairs using e3nn features.

    No edge list is used: every residue attends to every other residue.  The
    attention logits combine the squared difference between per-head queries
    and keys with a linear projection of the pair features.

    With ``final=True`` the output irreps are ``2x1o + 2x1e`` and ``forward``
    returns a (B, L, 6) tensor; otherwise it returns residually updated
    ``(node, l1_feats)``.

    NOTE(review): the reviewed copy of this file had lost its indentation;
    ``proj_node`` is reconstructed as being created only when ``final`` is
    false (the only branch of ``forward`` that uses it) -- confirm.
    """

    def __init__(self, e3_d_node_l0=16, e3_d_node_l1=4, d_node=256, d_pair=128, no_heads=8, final=False):
        super().__init__()
        self.e3_d_node_l0 = e3_d_node_l0
        self.e3_d_node_l1 = e3_d_node_l1
        self.no_heads = no_heads
        self.inf = 1e5  # magnitude of the additive mask penalty
        self.final = final

        irreps_input = e3nn.o3.Irreps(f"{e3_d_node_l0}x0e + {e3_d_node_l1}x1o")
        irreps_hidden = e3nn.o3.Irreps(
            f"{e3_d_node_l0}x0e + {e3_d_node_l1}x1o + {e3_d_node_l1}x1e"
        )
        self.d_hid1 = irreps_hidden.dim
        irreps_heads = irreps_hidden * self.no_heads

        self.proj_l0 = nn.Linear(d_node, self.e3_d_node_l0)
        self.h_q = e3nn.o3.Linear(irreps_input, irreps_heads)
        # Keys and values: tensor product of the node features with themselves.
        self.tpc_k = TensorProductConvLayer(irreps_input, irreps_input, irreps_heads, d_pair,
                                            use_internal_weights=True)
        self.tpc_v = TensorProductConvLayer(irreps_input, irreps_input, irreps_heads, d_pair,
                                            use_internal_weights=True)
        self.linear_pair = nn.Linear(d_pair, self.no_heads)

        if final:
            irreps_output = e3nn.o3.Irreps("2x1o + 2x1e")
        else:
            irreps_output = irreps_input
            self.proj_node = nn.Linear(self.e3_d_node_l0, d_node)

        self.h_out = e3nn.o3.Linear(irreps_heads, irreps_output)

    def _split_heads(self, feats):
        """Cut the last dim into ``no_heads`` equal chunks and stack them on a new last dim."""
        per_head = feats.shape[-1] // self.no_heads
        return torch.stack(torch.split(feats, per_head, dim=-1), dim=-1)

    def forward(self, node, pair, l1_feats, mask=None):
        """Update node features with dense attention.

        Args:
            node: node features; the first two dims are read as (B, L).
            pair: pair features; ``linear_pair(pair)`` is added to the
                (B, L, L, no_heads) logits.
            l1_feats: l=1 node features, expected as (B, L, 3 * e3_d_node_l1).
            mask: optional (B, L) residue mask; pairs whose mask product is 0
                receive an additive penalty of ``-self.inf``.

        Returns:
            ``final=True``: tensor (B, L, 6).
            ``final=False``: tuple ``(node, l1_feats)`` with residual updates.
        """
        batch, length = node.shape[:2]

        scalars = self.proj_l0(node)
        feats = torch.cat([scalars, l1_feats], dim=-1)

        # Each of q, k, v: (B, L, d_hid1, no_heads)
        q = self._split_heads(self.h_q(feats))
        k = self._split_heads(self.tpc_k(feats, feats))
        v = self._split_heads(self.tpc_v(feats, feats))

        # Squared q/k difference summed over channels -> (B, L, L, no_heads).
        # NOTE(review): the squared distance enters the logits with a positive
        # sign -- confirm this is intended.
        diff = q.unsqueeze(-3) - k.unsqueeze(-4)
        logits = torch.sum(diff ** 2, dim=-2) / math.sqrt(3 * self.d_hid1)
        logits = logits + self.linear_pair(pair)

        if mask is not None:
            pair_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)  # (B, L, L)
            penalty = self.inf * (pair_mask - 1)
            logits = logits + penalty.unsqueeze(-1)

        attn = torch.softmax(logits, dim=-2)  # normalised over the key residue

        mixed = torch.matmul(
            attn.permute(0, 3, 1, 2),  # (B, no_heads, L, L)
            v.permute(0, 3, 1, 2),     # (B, no_heads, L, d_hid1)
        )
        mixed = mixed.permute(0, 2, 1, 3).reshape(batch, length, -1)  # (B, L, no_heads * d_hid1)
        out = self.h_out(mixed)

        if self.final:
            # Sum the two halves of the 12 output channels -> (B, L, 6).
            return out[..., 0:6] + out[..., 6:12]

        split_at = self.e3_d_node_l0
        updated_l1 = out[..., split_at:] + l1_feats
        updated_node = self.proj_node(out[..., :split_at]) + node
        return updated_node, updated_l1
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """Permute only the trailing ``len(inds)`` dimensions of ``tensor``.

    ``inds`` gives the new order of those trailing dimensions, counted from
    the first of them; all leading dimensions keep their position.
    """
    offset = -len(inds)
    n_leading = len(tensor.shape[:offset])
    order = list(range(n_leading))
    order.extend(offset + i for i in inds)
    return tensor.permute(order)
def flatten_final_dims(t: torch.Tensor, no_dims: int):
    """Collapse the last ``no_dims`` dimensions of ``t`` into a single dimension."""
    kept = t.shape[:-no_dims]
    return t.reshape(*kept, -1)
class E3_transformer_test(nn.Module):
    """Point-attention layer that additionally updates l=1 features.

    Scalar queries/keys/values and 3-D query/key/value points are produced
    from the single representation ``s``; the points are moved by the rigid
    transforms ``r``; attention combines the scalar dot product, a pair bias
    and squared point distances.  The attended value points are projected by
    an e3nn linear layer and added to ``l1_feats``.

    ``use_e3_qkvo`` is hard-coded to ``False`` in ``__init__``, so the
    tensor-product branches guarded by it are never built nor executed.
    """
    def __init__(self,ipa_conf,inf: float = 1e5, eps: float = 1e-8,
                 e3_d_node_l1=4):
        """
        Args:
            ipa_conf: config object; the attributes ``c_s``, ``c_z``,
                ``c_hidden``, ``no_heads``, ``no_qk_points`` and
                ``no_v_points`` are read from it.
            inf: magnitude of the additive penalty applied to masked pairs.
            eps: added under the square root when computing point norms.
            e3_d_node_l1: number of ``1o`` channels of ``l1_feats``.
        """
        super().__init__()
        self._ipa_conf = ipa_conf
        self.c_s = ipa_conf.c_s
        self.c_z = ipa_conf.c_z
        self.c_hidden = ipa_conf.c_hidden
        self.no_heads = ipa_conf.no_heads
        self.no_qk_points = ipa_conf.no_qk_points
        self.no_v_points = ipa_conf.no_v_points
        self.inf = inf
        self.eps = eps
        # These linear layers differ from their specifications in the
        # supplement. There, they lack bias and use Glorot initialization.
        # Here as in the official source, they have bias and use the default
        # Lecun initialization.
        hc = self.c_hidden * self.no_heads
        self.linear_q = nn.Linear(self.c_s, hc)
        self.linear_kv = nn.Linear(self.c_s, 2 * hc)
        hpq = self.no_heads * self.no_qk_points * 3
        # Number of query points over all heads (H * P_q).
        self.hpq = hpq//3
        self.linear_q_points = nn.Linear(self.c_s, hpq)
        hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
        # Number of key + value points over all heads (H * (P_q + P_v)).
        self.hpkv = hpkv//3
        self.linear_kv_points = nn.Linear(self.c_s, hpkv)
        self.linear_b = nn.Linear(self.c_z, self.no_heads)
        self.down_z = nn.Linear(self.c_z, self.c_z // 4)
        # self.head_weights = nn.Parameter(torch.zeros((ipa_conf.no_heads)))
        # ipa_point_weights_init_(self.head_weights)
        # Per-head output width: pair features (c_z // 4), scalar values
        # (c_hidden) and, per value point, 3 coordinates plus 1 norm.
        concat_out_dim = (
            self.c_z // 4 + self.c_hidden + self.no_v_points * 4
        )
        self.linear_out = nn.Linear(self.no_heads * concat_out_dim, self.c_s)
        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()
        # Hard-coded switch: the e3nn tensor-product path below is disabled.
        self.use_e3_qkvo = False
        irreps_3 = e3nn.o3.Irreps(str(e3_d_node_l1)+"x1o")
        # Maps the H * P_v attended value points onto the l1 feature channels.
        self.l1_feats_proj = e3nn.o3.Linear(e3nn.o3.Irreps(str(self.no_heads*self.no_v_points)+"x1o"),
                                            irreps_3)
        if self.use_e3_qkvo:
            irreps_1 = e3nn.o3.Irreps("3x0e")
            irreps_2 = e3nn.o3.Irreps("1x1o")
            self.tp_q_pts = TensorProductConvLayer(irreps_1, irreps_3, irreps_2, 3,fc_factor=4)
            # self.tp_q_pts = e3nn.o3.FullyConnectedTensorProduct(irreps_1, irreps_3, irreps_2, shared_weights=True,
            # internal_weights=True)
            self.tp_kv_pts = TensorProductConvLayer(irreps_1, irreps_3, irreps_2, 3,fc_factor=4)
            # self.tp_kv_pts = e3nn.o3.FullyConnectedTensorProduct(irreps_1, irreps_3, irreps_2, shared_weights=True,
            # internal_weights=True)
            self.tp_o_pt_invert = TensorProductConvLayer(irreps_2, irreps_3, irreps_1, 1,fc_factor=4)
            # self.tp_o_pt_invert = e3nn.o3.FullyConnectedTensorProduct(irreps_2, irreps_3, irreps_1, shared_weights=True,
            # internal_weights=True)
    def forward(
        self,
        s: torch.Tensor,
        z: Optional[torch.Tensor],
        l1_feats,
        r: Rigid,
        mask: torch.Tensor,
        _offload_inference: bool = False,
        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation
            l1_feats:
                l=1 features; assumed [*, N_res, 3 * e3_d_node_l1] -- TODO confirm
            r:
                [*, N_res] transformation object
            mask:
                [*, N_res] mask
            _offload_inference:
                if True, the pair representation is taken from
                ``_z_reference_list`` (``z`` is ignored) and moved to the CPU
                while the attention is computed
            _z_reference_list:
                sequence whose first element is the pair representation;
                only read when ``_offload_inference`` is True
        Returns:
            Tuple ``(s, l1_feats)``: the [*, N_res, C_s] single
            representation with the update added residually, and the updated
            l=1 features.
            NOTE(review): the ``-> torch.Tensor`` annotation does not match
            this tuple return.
        """
        if _offload_inference:
            z = _z_reference_list
        else:
            z = [z]
        #######################################
        # Generate scalar and point activations
        #######################################
        # [*, N_res, H * C_hidden]
        s_org = s
        q = self.linear_q(s)
        kv = self.linear_kv(s)
        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        # [*, N_res, H, 2 * C_hidden]
        kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
        # [*, N_res, H, C_hidden]
        k, v = torch.split(kv, self.c_hidden, dim=-1)
        # [*, N_res, H * P_q * 3]
        q_pts = self.linear_q_points(s)
        # This is kind of clunky, but it's how the original does it
        # [*, N_res, H * P_q, 3]
        q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
        q_pts = torch.stack(q_pts, dim=-1)
        if self.use_e3_qkvo:
            q_pts = self.tp_q_pts(q_pts, l1_feats[:,:,None,:].repeat(1,1,self.hpq,1),q_pts) # [*, N_res, H * P_q, 3]
            # q_pts = self.tp_q_pts(q_pts, l1_feats[:,:,None,:].repeat(1,1,self.hpq,1)) # [*, N_res, H * P_q, 3]
        else:
            # Apply each residue's rigid transform to its query points.
            q_pts = r[..., None].apply(q_pts)
        # [*, N_res, H, P_q, 3]
        q_pts = q_pts.view(
            q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
        )
        # [*, N_res, H * (P_q + P_v) * 3]
        kv_pts = self.linear_kv_points(s)
        # [*, N_res, H * (P_q + P_v), 3]
        kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
        kv_pts = torch.stack(kv_pts, dim=-1)
        if self.use_e3_qkvo:
            kv_pts = self.tp_kv_pts(kv_pts, l1_feats[:,:,None,:].repeat(1,1,self.hpkv,1),kv_pts) # [*, N_res, H * (P_q + P_v), 3]
            # kv_pts = self.tp_kv_pts(kv_pts, l1_feats[:,:,None,:].repeat(1,1,self.hpkv,1)) # [*, N_res, H * (P_q + P_v), 3]
        else:
            # Apply each residue's rigid transform to its key/value points.
            kv_pts = r[..., None].apply(kv_pts)
        # [*, N_res, H, (P_q + P_v), 3]
        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
        # [*, N_res, H, P_q/P_v, 3]
        k_pts, v_pts = torch.split(
            kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
        )
        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])
        if(_offload_inference):
            z[0] = z[0].cpu()
        # [*, H, N_res, N_res]
        a = torch.matmul(
            math.sqrt(1.0 / (3 * self.c_hidden)) *
            permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
            permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
        )
        # a *= math.sqrt(1.0 / (3 * self.c_hidden))
        a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
        # [*, N_res, N_res, H, P_q, 3]
        pt_displacement = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
        pt_att = pt_displacement ** 2
        # [*, N_res, N_res, H, P_q]
        pt_att = sum(torch.unbind(pt_att, dim=-1))
        # The learned per-head point weights are disabled (commented out
        # below); the squared distances are scaled by a fixed -0.5 instead.
        # head_weights = self.softplus(self.head_weights).view(
        # *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
        # )
        # head_weights = head_weights * math.sqrt(
        # 1.0 / (3 * (self.no_qk_points * 9.0 / 2))
        # )
        # pt_att = pt_att * head_weights
        # [*, N_res, N_res, H]
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
        # [*, N_res, N_res]
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        # 0 where both residues are unmasked, -inf (self.inf) otherwise.
        square_mask = self.inf * (square_mask - 1)
        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))
        a = a + pt_att
        a = a + square_mask.unsqueeze(-3)
        a = self.softmax(a)
        ################
        # Compute output
        ################
        # [*, N_res, H, C_hidden]
        o = torch.matmul(
            a, v.transpose(-2, -3)
        ).transpose(-2, -3)
        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)
        # [*, H, 3, N_res, P_v]
        o_pt = torch.sum(
            (
                a[..., None, :, :, None]
                * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
            ),
            dim=-2,
        )
        # [*, N_res, H, P_v, 3]
        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
        # The attended value points (before they are mapped back by the
        # inverse transform below) update the l=1 features residually.
        o_pt_l1 = o_pt.reshape(*o_pt.shape[:-3], -1) # [*, N_res, H * P_v * 3]
        l1_feats = l1_feats + self.l1_feats_proj(o_pt_l1)
        if self.use_e3_qkvo:
            o_pt = self.tp_o_pt_invert(o_pt,l1_feats[:,:,None,None,:].repeat(
                1,1,self.no_heads, self.no_v_points,1),
                torch.sum(o_pt ** 2, dim=-1,keepdim=True))
            # o_pt = self.tp_o_pt_invert(o_pt,l1_feats[:,:,None,None,:].repeat(
            # 1,1,self.no_heads, self.no_v_points,1))
        else:
            # Apply the inverse of each residue's rigid transform to its output points.
            o_pt = r[..., None, None].invert_apply(o_pt)
        # [*, N_res, H * P_v]
        o_pt_dists = torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps)
        o_pt_norm_feats = flatten_final_dims(
            o_pt_dists, 2)
        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
        if(_offload_inference):
            z[0] = z[0].to(o_pt.device)
        # [*, N_res, H, C_z // 4]
        pair_z = self.down_z(z[0])
        o_pair = torch.matmul(a.transpose(-2, -3), pair_z)
        # [*, N_res, H * C_z // 4]
        o_pair = flatten_final_dims(o_pair, 2)
        o_feats = [o, *torch.unbind(o_pt, dim=-1), o_pt_norm_feats, o_pair]
        # [*, N_res, C_s]
        s = self.linear_out(
            torch.cat(
                o_feats, dim=-1
            )
        )
        # Residual connection on the single representation.
        s = s + s_org
        return s, l1_feats
class Energy_Adapter_Node(nn.Module):
    """Condition node features on a per-sample energy via one transformer decoder layer.

    The energy is embedded with ``get_time_embedding`` and used as a
    length-1 memory sequence that the node features cross-attend to.

    Args:
        d_node: node feature width (``d_model`` of the decoder layer).
        n_head: number of attention heads.
        p_drop: dropout probability of the decoder layer.
        ff_factor: the feed-forward width is ``ff_factor * d_node``.
    """

    def __init__(self, d_node=256, n_head=8, p_drop=0.1, ff_factor=1):
        super().__init__()
        # self.energy_proj = nn.Linear(1, d_node)
        self.d_node = d_node
        self.tfmr_layer = torch.nn.TransformerDecoderLayer(
            d_model=d_node,
            nhead=n_head,
            dim_feedforward=ff_factor*d_node,
            dropout=p_drop,
            batch_first=True,
            norm_first=False
        )

    def forward(self, node, energy, mask = None):
        '''
        Args:
            node: node features, batch first; assumed (B, L, d_node) -- TODO confirm.
            energy: (B,)
            mask: optional residue mask where a non-zero / True entry marks a
                valid position.  Boolean masks are accepted as well as
                numeric 0/1 masks.

        Returns:
            The output of the decoder layer for ``node``.
        '''
        energy_emb = get_time_embedding(
            energy,
            self.d_node,
            max_positions=2056
        )[:,None,:]  # (B,1,d_node)

        if mask is not None:
            # tgt_key_padding_mask flags the positions to IGNORE, i.e. the
            # inverse of `mask`.  `1 - mask` raises for boolean tensors, so
            # those are inverted with `~`; other dtypes keep the original
            # arithmetic.
            if mask.dtype == torch.bool:
                mask = ~mask
            else:
                mask = (1 - mask).bool()

        node = self.tfmr_layer(node, energy_emb, tgt_key_padding_mask = mask)
        return node
# class EMA():
# def __init__(self, model, decay):
# super().__init__()
# self.model = model
# self.decay = decay
# self.device = None
# self.shadow = {}
# for name, param in self.model.named_parameters():
# if param.requires_grad:
# self.shadow[name] = param.data.clone().detach()
# def to(self, device):
# self.shadow = {name: param.to(device) for name, param in self.shadow.items()}
# self.device = device
# def update(self):
# for name, param in self.model.named_parameters():
# if param.requires_grad:
# assert name in self.shadow
# new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
# self.shadow[name] = new_average.clone()
# def apply_shadow(self):
# for name, param in self.model.named_parameters():
# if param.requires_grad:
# assert name in self.shadow
# param.data.copy_(self.shadow[name])
# self.model.load_state_dict(self.shadow)
class Node_update(nn.Module):
    """Update node features with pair-biased row attention, column attention and a feed-forward layer.

    Args:
        d_msa: node feature width.
        d_pair: pair feature width.
        n_head: number of attention heads; the per-head width is ``d_msa // n_head``.
        p_drop: dropout probability.
        ff_factor: factor passed to ``FeedForwardLayer``.
    """

    def __init__(self, d_msa=256, d_pair=128, n_head=8, p_drop=0.1, ff_factor=1):
        super().__init__()
        head_width = d_msa // n_head
        self.d_hidden = head_width
        self.ff_factor = ff_factor

        self.norm_node = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)
        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.row_attn = RowAttentionWithBias(
            d_msa=d_msa, d_pair=d_pair, n_head=n_head, d_hidden=head_width,
        )
        self.col_attn = ColAttention(d_msa=d_msa, n_head=n_head, d_hidden=head_width)
        self.ff = FeedForwardLayer(d_msa, ff_factor, p_drop=p_drop)

    def forward(self, msa, pair, mask=None):
        '''
        Inputs:
            - msa: node feature (B, L, d_msa)
            - pair: pair feature (B, L, L, d_pair)
            - mask: optional mask forwarded to both attention layers
        Output:
            - msa: updated node feature (B, L, d_msa)

        Both inputs are layer-normalised first; the residual connections are
        taken on the normalised node features.
        '''
        normed_pair = self.norm_pair(pair)
        feats = self.norm_node(msa)

        row_update = self.row_attn(feats, normed_pair, mask=mask)
        feats = feats + self.drop_row(row_update)
        feats = feats + self.col_attn(feats, mask=mask)
        return feats + self.ff(feats)
# class Node2Node(nn.Module):
# def __init__(self, d_node=256, d_pair=128, n_head = 8, dist_max=40,
# d_rbf=64, SE3_param=None, p_drop=0.1):
# super().__init__()
# # initial node & pair feature process
# d_node_l0 = d_node // 2
# self.dist_max = dist_max
# self.norm_node = nn.LayerNorm(d_node_l0)
# self.proj_node = nn.Linear(d_node, d_node_l0)
# self.norm_pair = nn.LayerNorm(d_pair)
# self.proj_pair = nn.Linear(d_pair, d_pair)
# self.d_node_l0 = d_node_l0
# self.d_node = d_node
# self.proj_node_t2 = nn.Linear(d_node_l0, d_node)
# self.node2node_temp = Node2Node_temp(d_msa=d_node, d_pair=d_pair,
# n_head=n_head, d_hidden=d_node//n_head, p_drop=p_drop)
# self.E3 = E3_transformer(e3_d_node_l0=d_node_l0, e3_d_node_l1=48, d_rbf=d_rbf,
# d_pair=d_pair, dist_max = dist_max, keep_dim = True)
# self.alpha_pred = AlphaPred(d_node=d_node,p_drop=p_drop)
# # @torch.cuda.amp.autocast(enabled=False)
# def forward(self, node, pair, l1_feats, xyz, mask = None):
# """
# input:
# msa: node embeddings (B, L, d_node)
# pair: residue pair embeddings (B, L, L, d_pair)
# xyz: initial BB coordinates (B, L, 3, 3)
# R_in: (B,L,3,3)
# """
# B, L = node.shape[:2]
# node = self.node2node_temp(node, pair, mask = mask)
# node = self.norm_node(self.proj_node(node))
# pair = self.norm_pair(self.proj_pair(pair))
# out = self.E3(node, pair, l1_feats, xyz, mask = mask)
# node = self.proj_node_t2(out[...,:self.d_node_l0])
# # l1_feats = l1_feats + out[...,self.d_node_l0:].reshape(B*L,48,3)
# l1_feats = out[...,self.d_node_l0:].reshape(B*L,48,3)
# alpha = self.alpha_pred(node, mask=mask)
# return node, l1_feats, alpha
class Node2Pair_bias(nn.Module):
    """Turn node features into a pair bias via a per-head outer product.

    Two linear projections of the normalised node features are split into
    ``d_head`` heads; for every residue pair (i, j) and head, the dot product
    of the left projection at i and the scaled right projection at j is
    computed, and the resulting ``d_head`` values are projected to ``d_pair``.

    Args:
        d_node: node feature width.
        d_hidden: width of the left/right projections; must be divisible by ``d_head``.
        d_pair: output pair feature width.
        d_head: number of heads.
    """

    def __init__(self, d_node=256, d_hidden=32, d_pair=128, d_head = 16):
        super().__init__()
        self.d_hidden = d_hidden
        self.d_head = d_head
        self.norm = nn.LayerNorm(d_node)
        self.proj_left = nn.Linear(d_node, d_hidden)
        self.proj_right = nn.Linear(d_node, d_hidden)
        self.proj_out = nn.Linear(d_head, d_pair)

    def forward(self, node, mask = None):
        """Compute the pair bias.

        Args:
            node: (B, L, d_node) node features.
            mask: optional (B, L) mask; normalised features of positions with
                a zero mask are zeroed before the projections.

        Returns:
            (B, L, L, d_pair) tensor.
        """
        B, L = node.shape[:2]
        node = self.norm(node)
        if mask is not None:
            node = node * mask[..., None].type(torch.float32)

        left = self.proj_left(node).reshape(B, L, self.d_head, -1)
        # math.sqrt instead of np.sqrt: numpy is not imported by this module,
        # so the original only worked if `np` leaked in through a star import.
        right = self.proj_right(node) / math.sqrt(self.d_hidden)
        right = right.reshape(B, L, self.d_head, -1)

        # Per-head dot product between every residue pair: (B, L, L, d_head)
        out = torch.einsum('bihk,bjhk->bijh', left, right)
        return self.proj_out(out)
class Pair_update(nn.Module):
    """Update pair features with node-derived biased axial attention and a feed-forward layer.

    Args:
        d_node: node feature width.
        d_pair: pair feature width (also used as the bias width).
        n_head: number of attention heads; the per-head width is ``d_pair // n_head``.
        p_drop: dropout probability.
        ff_factor: factor passed to ``FeedForwardLayer``.
    """

    def __init__(self, d_node=256, d_pair=128, n_head=8, p_drop=0.1, ff_factor=1):
        super().__init__()
        self.d_hidden = d_pair // n_head
        self.d_pair_bias = d_pair
        self.d_node = d_node
        self.ff_factor = ff_factor

        # Projects the concatenated features of residues i and j to the bias.
        self.proj_bias = nn.Linear(2 * d_node, d_pair)
        # Dropout layers configured with broadcast_dim=1 and broadcast_dim=2.
        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.drop_col = Dropout(broadcast_dim=2, p_drop=p_drop)
        self.row_attn = BiasedAxialAttention(
            d_pair, self.d_pair_bias, n_head, self.d_hidden, p_drop=p_drop, is_row=True,
        )
        self.col_attn = BiasedAxialAttention(
            d_pair, self.d_pair_bias, n_head, self.d_hidden, p_drop=p_drop, is_row=False,
        )
        self.ff = FeedForwardLayer(d_pair, ff_factor, p_drop=p_drop)

    def forward(self, node, pair, mask=None):
        """Return the updated pair features.

        Args:
            node: (B, L, d_node) node features.
            pair: (B, L, L, d_pair) pair features.
            mask: optional (B, L) mask; pair and bias entries whose mask
                product is zero are zeroed before the attention layers, which
                also receive the original ``mask``.
        """
        length = pair.shape[1]

        # Broadcast node features along rows and columns: (B, L, L, 2 * d_node)
        as_rows = node[:, :, None, :].expand(-1, -1, length, -1)
        as_cols = node[:, None, :, :].expand(-1, length, -1, -1)
        bias = self.proj_bias(torch.cat([as_rows, as_cols], dim=-1))  # (B, L, L, d_pair_bias)

        if mask is not None:
            mask_f = mask.type(torch.float32)
            pair_mask = (mask_f.unsqueeze(2) * mask_f.unsqueeze(1))[..., None]  # (B, L, L, 1)
            pair = pair * pair_mask
            bias = bias * pair_mask

        pair = pair + self.drop_row(self.row_attn(pair, bias, mask))
        pair = pair + self.drop_col(self.col_attn(pair, bias, mask))
        return pair + self.ff(pair)
class FeatBlock(nn.Module):
    """Build the initial node and pair features.

    Node features are concatenated with a learned sequence embedding and
    passed through an MLP; pair features are concatenated with a radial-basis
    encoding of the pairwise distances and passed through a second MLP.

    ``T``, ``L_max`` and ``d_cent`` are accepted but not used; ``dist_max``
    and ``d_time`` are only stored.
    """

    def __init__(self, d_node=256, d_pair=128, dist_max=40, T=500, d_time=64, L_max=64,
                 d_cent=36, d_rbf=36, p_drop=0.1):
        super().__init__()
        self.dist_max = dist_max
        self.d_rbf = d_rbf
        self.d_time = d_time

        # 22 sequence tokens, each embedded to d_node.
        self.seq_emb_layer = nn.Embedding(22, d_node)

        node_layers = [
            nn.Linear(d_node + d_node, d_node),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(d_node, d_node),
        ]
        self.node_in_proj = nn.Sequential(*node_layers)

        pair_layers = [
            nn.Linear(d_pair + d_rbf, d_pair),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(d_pair, d_pair),
        ]
        self.pair_in_proj = nn.Sequential(*pair_layers)

    def forward(self, node, pair, seq, xyz, t_emb, mask=None):
        '''
        input:
            node: (B, L, d_node)
            pair: (B, L, L, d_pair)
            seq: integer sequence tokens fed to the embedding layer
            xyz: coordinates; atom index 1 along dim 2 is used for distances
                (presumably the C-alpha atom -- confirm)
            t_emb: timestep (B, d_time); accepted but not used here
            mask: (B,L)
        output:
            node: (B, L, d_node), pair: (B, L, L, d_pair),
            rbf_dist: (B, L, L, d_rbf)
        '''
        # Node track: append the sequence embedding, then project.
        seq_emb = self.seq_emb_layer(seq)
        node = self.node_in_proj(torch.cat([node, seq_emb], dim=-1))

        # Pair track: append the distance encoding, then project.
        # `rbf` is not defined in this module's visible code; it presumably
        # arrives through the star import -- confirm.
        anchor = xyz[:, :, 1, :]
        dist_mat = torch.cdist(anchor, anchor)  # (B, L, L)
        rbf_dist = rbf(dist_mat, D_count=self.d_rbf)
        pair = self.pair_in_proj(torch.cat([pair, rbf_dist], dim=-1))

        if mask is not None:
            mask = mask.type(torch.float32)  # (B, L)
            node = node * mask[..., None]
            mask_cross = torch.matmul(mask.unsqueeze(2), mask.unsqueeze(1))  # (B, L, L)
            pair = pair * mask_cross[..., None]

        return node, pair, rbf_dist