| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch_scatter import scatter_sum |
| | from torch_geometric.nn import radius_graph, knn_graph |
| | from models.add_module.common import GaussianSmearing, MLP, batch_hybrid_edge_connection, NONLINEARITIES |
| |
|
| |
|
class EnBaseLayer(nn.Module):
    """One E(n)-equivariant message-passing layer.

    Updates node features ``h`` via summed edge messages gated by a learned
    edge weight, and (optionally) updates coordinates ``x`` with a bounded,
    distance-normalized displacement along each edge direction.
    """

    def __init__(self, hidden_dim, num_r_gaussian, update_x=True, act_fn='silu', norm=False):
        """
        Args:
            hidden_dim: node/edge feature width.
            num_r_gaussian: number of Gaussian RBFs for distance expansion;
                if <= 1, the raw squared distance is used instead.
            update_x: whether to also update coordinates in ``forward``.
            act_fn: activation name looked up in ``NONLINEARITIES``.
            norm: whether the MLPs use normalization layers.
        """
        super().__init__()
        self.r_min = 0.
        self.r_max = 10.
        self.hidden_dim = hidden_dim
        self.num_r_gaussian = num_r_gaussian
        self.update_x = update_x
        self.act_fn = act_fn
        self.norm = norm
        if num_r_gaussian > 1:
            self.distance_expansion = GaussianSmearing(self.r_min, self.r_max, num_gaussians=num_r_gaussian)
        # Edge message from (h_i, h_j, distance features).
        self.edge_mlp = MLP(2 * hidden_dim + num_r_gaussian, hidden_dim, hidden_dim,
                            num_layer=2, norm=norm, act_fn=act_fn, act_last=True)
        # Scalar gate in (0, 1) per edge (soft edge inference).
        self.edge_inf = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())
        if self.update_x:
            # Tiny-gain init on the last linear plus Tanh keeps the coordinate
            # update bounded and near-zero at the start of training.
            x_mlp = [nn.Linear(hidden_dim, hidden_dim), NONLINEARITIES[act_fn]]
            layer = nn.Linear(hidden_dim, 1, bias=False)
            torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
            x_mlp.append(layer)
            x_mlp.append(nn.Tanh())
            self.x_mlp = nn.Sequential(*x_mlp)

        self.node_mlp = MLP(2 * hidden_dim, hidden_dim, hidden_dim, num_layer=2, norm=norm, act_fn=act_fn)

    def forward(self, h, x, edge_index, edge_attr=None):
        """
        Args:
            h: (N, hidden_dim) node features.
            x: (N, D) node coordinates (D presumably 3 — confirm at call site).
            edge_index: (2, E) [source, target] indices.
            edge_attr: optional (E, F) extra edge features, concatenated after
                the distance features.

        Returns:
            Updated ``(h, x)``; ``x`` is returned unchanged if ``update_x`` is False.
        """
        src, dst = edge_index
        hi, hj = h[dst], h[src]

        rel_x = x[dst] - x[src]
        d_sq = torch.sum(rel_x ** 2, -1, keepdim=True)
        if self.num_r_gaussian > 1:
            # 1e-8 guards the sqrt gradient at zero distance.
            d_feat = self.distance_expansion(torch.sqrt(d_sq + 1e-8))
        else:
            d_feat = d_sq

        if edge_attr is not None:
            edge_feat = torch.cat([d_feat, edge_attr], -1)
        else:
            edge_feat = d_feat

        mij = self.edge_mlp(torch.cat([hi, hj, edge_feat], -1))
        eij = self.edge_inf(mij)
        mi = scatter_sum(mij * eij, dst, dim=0, dim_size=h.shape[0])

        # Residual node update.
        h = h + self.node_mlp(torch.cat([mi, h], -1))
        if self.update_x:
            xi, xj = x[dst], x[src]
            # BUG FIX: pass dim_size so nodes with no incoming edge still get a
            # zero row; without it the scatter output can be shorter than x and
            # the residual add below would fail (the node scatter above already
            # passes dim_size).
            delta_x = scatter_sum((xi - xj) / (torch.sqrt(d_sq + 1e-8) + 1) * self.x_mlp(mij),
                                  dst, dim=0, dim_size=x.shape[0])
            x = x + delta_x

        return h, x
| |
|
| |
|
class EGNN(nn.Module):
    """Stack of ``EnBaseLayer``s with per-layer graph (re)construction.

    Operates on batched, fixed-length sequences: inputs of shape (B, L, ...)
    are flattened to (B*L, ...) nodes, a kNN graph is rebuilt from the current
    coordinates before every layer, and outputs are reshaped back to (B, L, ...).
    """

    def __init__(self, num_layers=2, hidden_dim=128, num_r_gaussian=20, k=32, cutoff=20.0, cutoff_mode='knn',
                 update_x=True, act_fn='silu', norm=False):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.num_r_gaussian = num_r_gaussian
        self.update_x = update_x
        self.act_fn = act_fn
        self.norm = norm
        self.k = k

        self.cutoff = cutoff
        self.cutoff_mode = cutoff_mode
        self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=num_r_gaussian)
        self.net = self._build_network()

    def _build_network(self):
        # One EnBaseLayer per hop; ModuleList so all parameters are registered.
        layers = [
            EnBaseLayer(self.hidden_dim, self.num_r_gaussian,
                        update_x=self.update_x, act_fn=self.act_fn, norm=self.norm)
            for _ in range(self.num_layers)
        ]
        return nn.ModuleList(layers)

    def _connect_edge(self, x, batch, orgshape):
        """Build the edge index for the current coordinates.

        Args:
            x: (N, D) node coordinates.
            batch: (N,) graph-id per node, restricting edges within a graph.
            orgshape: original (B, L) shape; currently unused by the knn mode.
        """
        if self.cutoff_mode == 'knn':
            edge_index = knn_graph(x, k=self.k, batch=batch, flow='source_to_target')
        elif self.cutoff_mode == 'hybrid':
            # BUG FIX: this branch used to `pass` and then hit the return below,
            # raising a confusing NameError on edge_index. Fail loudly instead
            # until the hybrid connection is implemented.
            raise NotImplementedError('hybrid cutoff mode is not implemented yet')
        else:
            raise ValueError(f'Not supported cutoff mode: {self.cutoff_mode}')
        return edge_index

    def forward(self, node_embed, trans_t):
        """
        Args:
            node_embed: (B, L, hidden_dim) node features.
            trans_t: (B, L, D) node coordinates.

        Returns:
            Tuple ``(x, h)`` reshaped back to (B, L, D) and (B, L, hidden_dim).
        """
        B, L = node_embed.shape[:2]

        x = trans_t.reshape(B * L, -1)
        h = node_embed.reshape(B * L, -1)

        # Vectorized graph-id assignment: [0]*L + [1]*L + ... (was a Python loop).
        batch = torch.arange(B, device=node_embed.device).repeat_interleave(L)

        for layer in self.net:
            # Rebuild the graph each layer since coordinates may have moved.
            edge_index = self._connect_edge(x, batch, orgshape=(B, L))
            h, x = layer(h, x, edge_index)

        return x.reshape(B, L, -1), h.reshape(B, L, -1)
| |
|