# NOTE(review): the file previously opened with three stray lines
# ("Holmes", "test", "ca7299e") — apparent VCS/merge artifacts that would
# raise NameError at import time. Removed; original content preserved below.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_sum
from torch_geometric.nn import radius_graph, knn_graph
from models.add_module.common import GaussianSmearing, MLP, batch_hybrid_edge_connection, NONLINEARITIES
class EnBaseLayer(nn.Module):
    """One E(n)-equivariant graph message-passing layer (EGNN-style).

    Computes edge messages from node features and inter-node distances,
    aggregates them per node to update the invariant features ``h``, and
    (optionally) applies an equivariant update to the coordinates ``x``.
    """

    def __init__(self, hidden_dim, num_r_gaussian, update_x=True, act_fn='silu', norm=False):
        """
        Args:
            hidden_dim: width of node features and edge messages.
            num_r_gaussian: number of Gaussian basis functions used to expand
                distances; if <= 1, the raw squared distance is used instead.
            update_x: whether to also update coordinates (Eq. 4).
            act_fn: activation key looked up in ``NONLINEARITIES``.
            norm: whether the internal MLPs use normalization layers.
        """
        super().__init__()
        self.r_min = 0.
        self.r_max = 10.
        self.hidden_dim = hidden_dim
        self.num_r_gaussian = num_r_gaussian
        self.update_x = update_x
        self.act_fn = act_fn
        self.norm = norm
        if num_r_gaussian > 1:
            self.distance_expansion = GaussianSmearing(self.r_min, self.r_max, num_gaussians=num_r_gaussian)
        self.edge_mlp = MLP(2 * hidden_dim + num_r_gaussian, hidden_dim, hidden_dim,
                            num_layer=2, norm=norm, act_fn=act_fn, act_last=True)
        # Scalar edge gate e_ij in (0, 1).
        self.edge_inf = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())
        if self.update_x:
            # Coordinate-update head: tiny init gain plus Tanh bound the
            # per-edge displacement scale, stabilizing early training.
            x_mlp = [nn.Linear(hidden_dim, hidden_dim), NONLINEARITIES[act_fn]]
            layer = nn.Linear(hidden_dim, 1, bias=False)
            torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
            x_mlp.append(layer)
            x_mlp.append(nn.Tanh())
            self.x_mlp = nn.Sequential(*x_mlp)
        self.node_mlp = MLP(2 * hidden_dim, hidden_dim, hidden_dim, num_layer=2, norm=norm, act_fn=act_fn)

    def forward(self, h, x, edge_index, edge_attr=None):
        """Run one message-passing step.

        Args:
            h: node features, shape (N, hidden_dim).
            x: node coordinates, shape (N, D) — presumably D == 3, TODO confirm.
            edge_index: (2, E) long tensor of (src, dst) node indices.
            edge_attr: optional extra edge features concatenated to the
                distance expansion.

        Returns:
            Tuple ``(h, x)`` of updated node features and coordinates.
        """
        src, dst = edge_index
        hi, hj = h[dst], h[src]
        # \phi_e in Eq(3): message from pair features + distance encoding.
        rel_x = x[dst] - x[src]
        d_sq = torch.sum(rel_x ** 2, -1, keepdim=True)
        if self.num_r_gaussian > 1:
            # eps guards sqrt gradient at zero distance.
            d_feat = self.distance_expansion(torch.sqrt(d_sq + 1e-8))
        else:
            d_feat = d_sq
        if edge_attr is not None:
            edge_feat = torch.cat([d_feat, edge_attr], -1)
        else:
            edge_feat = d_feat
        mij = self.edge_mlp(torch.cat([hi, hj, edge_feat], -1))
        eij = self.edge_inf(mij)
        mi = scatter_sum(mij * eij, dst, dim=0, dim_size=h.shape[0])
        # h update in Eq(6): residual over aggregated messages.
        h = h + self.node_mlp(torch.cat([mi, h], -1))
        if self.update_x:
            # x update in Eq(4)
            xi, xj = x[dst], x[src]
            # (xi - xj) / (\|xi - xj\| + C) to make it more stable.
            # BUG FIX: pass dim_size explicitly — without it, scatter_sum
            # returns fewer rows than x whenever the highest-index nodes have
            # no incoming edges, breaking the residual add below (the feature
            # aggregation above already does this correctly).
            delta_x = scatter_sum((xi - xj) / (torch.sqrt(d_sq + 1e-8) + 1) * self.x_mlp(mij),
                                  dst, dim=0, dim_size=x.shape[0])
            x = x + delta_x
        return h, x
class EGNN(nn.Module):
    """Stack of ``EnBaseLayer``s with per-layer dynamic graph construction.

    Edges are rebuilt from the current coordinates before every layer
    (k-NN by default), so the graph tracks the moving point cloud.
    """

    def __init__(self, num_layers=2, hidden_dim=128, num_r_gaussian=20, k=32, cutoff=20.0, cutoff_mode='knn',
                 update_x=True, act_fn='silu', norm=False):
        """
        Args:
            num_layers: number of equivariant layers.
            hidden_dim: node-feature width, shared by all layers.
            num_r_gaussian: Gaussian basis size for distance expansion.
            k: neighbor count for k-NN graph construction.
            cutoff: distance cutoff (used by the Gaussian smearing range).
            cutoff_mode: edge-building strategy; only 'knn' is implemented.
            update_x: whether layers also update coordinates.
            act_fn: activation key for the layer MLPs.
            norm: whether layer MLPs use normalization.
        """
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.num_r_gaussian = num_r_gaussian
        self.update_x = update_x
        self.act_fn = act_fn
        self.norm = norm
        self.k = k
        self.cutoff = cutoff
        self.cutoff_mode = cutoff_mode
        self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=num_r_gaussian)
        self.net = self._build_network()

    def _build_network(self):
        """Instantiate the equivariant layer stack."""
        layers = [
            EnBaseLayer(self.hidden_dim, self.num_r_gaussian,
                        update_x=self.update_x, act_fn=self.act_fn, norm=self.norm)
            for _ in range(self.num_layers)
        ]
        return nn.ModuleList(layers)

    def _connect_edge(self, x, batch, orgshape):
        """Build the edge index from current coordinates.

        Args:
            x: flattened coordinates, shape (B*L, D).
            batch: (B*L,) graph-membership vector for torch_geometric.
            orgshape: original (B, L) shape; kept for interface compatibility
                (only used by alternative edge schemes, currently none).

        Returns:
            (2, E) long tensor of edges.
        """
        if self.cutoff_mode == 'knn':
            return knn_graph(x, k=self.k, batch=batch, flow='source_to_target')
        if self.cutoff_mode == 'hybrid':
            # BUG FIX: this branch previously executed `pass` and then
            # returned an unbound variable (UnboundLocalError). Fail loudly
            # until hybrid connection is implemented.
            raise NotImplementedError('hybrid cutoff mode is not implemented')
        raise ValueError(f'Not supported cutoff mode: {self.cutoff_mode}')

    def forward(self, node_embed, trans_t):
        """Run all layers, rebuilding edges from updated coordinates each time.

        Args:
            node_embed: node features, shape (B, L, hidden_dim).
            trans_t: coordinates, shape (B, L, D) — presumably D == 3, TODO confirm.

        Returns:
            Tuple ``(x, h)`` reshaped back to (B, L, -1).
        """
        B, L = node_embed.shape[:2]
        x = trans_t.reshape(B * L, -1)
        h = node_embed.reshape(B * L, -1)
        # Vectorized replacement for the former Python loop:
        # batch = [0]*L + [1]*L + ... + [B-1]*L
        batch = torch.repeat_interleave(
            torch.arange(B, device=node_embed.device), L)
        for layer in self.net:
            edge_index = self._connect_edge(x, batch, orgshape=(B, L))
            h, x = layer(h, x, edge_index)
        return x.reshape(B, L, -1), h.reshape(B, L, -1)