import torch import torch.nn as nn import torch.nn.functional as F from torch_scatter import scatter_sum from torch_geometric.nn import radius_graph, knn_graph from models.add_module.common import GaussianSmearing, MLP, batch_hybrid_edge_connection, NONLINEARITIES class EnBaseLayer(nn.Module): def __init__(self, hidden_dim, num_r_gaussian, update_x=True, act_fn='silu', norm=False): super().__init__() self.r_min = 0. self.r_max = 10. self.hidden_dim = hidden_dim self.num_r_gaussian = num_r_gaussian self.update_x = update_x self.act_fn = act_fn self.norm = norm if num_r_gaussian > 1: self.distance_expansion = GaussianSmearing(self.r_min, self.r_max, num_gaussians=num_r_gaussian) self.edge_mlp = MLP(2 * hidden_dim + num_r_gaussian, hidden_dim, hidden_dim, num_layer=2, norm=norm, act_fn=act_fn, act_last=True) self.edge_inf = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid()) if self.update_x: # self.x_mlp = MLP(hidden_dim, 1, hidden_dim, num_layer=2, norm=norm, act_fn=act_fn) x_mlp = [nn.Linear(hidden_dim, hidden_dim), NONLINEARITIES[act_fn]] layer = nn.Linear(hidden_dim, 1, bias=False) torch.nn.init.xavier_uniform_(layer.weight, gain=0.001) x_mlp.append(layer) x_mlp.append(nn.Tanh()) self.x_mlp = nn.Sequential(*x_mlp) self.node_mlp = MLP(2 * hidden_dim, hidden_dim, hidden_dim, num_layer=2, norm=norm, act_fn=act_fn) def forward(self, h, x, edge_index, edge_attr=None): src, dst = edge_index hi, hj = h[dst], h[src] # \phi_e in Eq(3) rel_x = x[dst] - x[src] d_sq = torch.sum(rel_x ** 2, -1, keepdim=True) if self.num_r_gaussian > 1: d_feat = self.distance_expansion(torch.sqrt(d_sq + 1e-8)) else: d_feat = d_sq if edge_attr is not None: edge_feat = torch.cat([d_feat, edge_attr], -1) else: edge_feat = d_feat mij = self.edge_mlp(torch.cat([hi, hj, edge_feat], -1)) eij = self.edge_inf(mij) mi = scatter_sum(mij * eij, dst, dim=0, dim_size=h.shape[0]) # h update in Eq(6) h = h + self.node_mlp(torch.cat([mi, h], -1)) if self.update_x: # x update in Eq(4) xi, xj = x[dst], x[src] # (xi - xj) / (\|xi - xj\| + C) to make it more stable delta_x = scatter_sum((xi - xj) / (torch.sqrt(d_sq + 1e-8) + 1) * self.x_mlp(mij), dst, dim=0) x = x + delta_x return h, x class EGNN(nn.Module): def __init__(self, num_layers=2, hidden_dim=128, num_r_gaussian=20, k=32, cutoff=20.0, cutoff_mode='knn', update_x=True, act_fn='silu', norm=False): super().__init__() # Build the network self.num_layers = num_layers self.hidden_dim = hidden_dim self.num_r_gaussian = num_r_gaussian self.update_x = update_x self.act_fn = act_fn self.norm = norm self.k = k # self.k_seq = 16 self.cutoff = cutoff self.cutoff_mode = cutoff_mode self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=num_r_gaussian) self.net = self._build_network() def _build_network(self): # Equivariant layers layers = [] for l_idx in range(self.num_layers): layer = EnBaseLayer(self.hidden_dim, self.num_r_gaussian, update_x=self.update_x, act_fn=self.act_fn, norm=self.norm) layers.append(layer) return nn.ModuleList(layers) def _connect_edge(self, x, batch, orgshape): # if self.cutoff_mode == 'radius': # edge_index = radius_graph(x, r=self.r, batch=batch, flow='source_to_target') if self.cutoff_mode == 'knn': edge_index = knn_graph(x, k=self.k, batch=batch, flow='source_to_target') # mask = torch.zeros((x.shape[0],x.shape[0]), dtype = bool) # (B*L,B*L) # mask[edge_index[0], edge_index[1]] = True # B,L = orgshape # i,j = torch.meshgrid(torch.arange(L),torch.arange(L)) # mask_seq = (torch.abs(i-j) < self.k_seq) # (L,L) # mask_seq = torch.block_diag(*([mask_seq]*B)) # (B*L,B*L) # mask = torch.logical_or(mask,mask_seq) # edge_index = torch.where(mask) # edge_index = (edge_index[0].to(x.device),edge_index[1].to(x.device)) elif self.cutoff_mode == 'hybrid': pass # edge_index = batch_hybrid_edge_connection( # x, k=self.k, mask_ligand=mask_ligand, batch=batch, add_p_index=True) else: raise ValueError(f'Not supported cutoff mode: {self.cutoff_mode}') return edge_index def forward(self, node_embed, trans_t): B,L = node_embed.shape[:2] x = trans_t.reshape(B*L,-1) h = node_embed.reshape(B*L,-1) batch = [] for idx in range(B): batch += [idx] * L batch = torch.tensor(batch,device = node_embed.device) all_x = [x] all_h = [h] for l_idx, layer in enumerate(self.net): edge_index = self._connect_edge(x, batch, orgshape = (B,L)) h, x = layer(h, x, edge_index) all_x.append(x) all_h.append(h) outputs = {'x': x, 'h': h} return x.reshape(B,L,-1), h.reshape(B,L,-1)