| | """Neural network architecture for the flow model.""" |
| | import torch |
| | from torch import nn |
| | from models.node_embedder import NodeEmbedder |
| | from models.edge_embedder import EdgeEmbedder |
| | from models.add_module.structure_module import Node_update, Pair_update, E3_transformer, TensorProductConvLayer, E3_GNNlayer, E3_transformer_no_adjacency, E3_transformer_test, Energy_Adapter_Node |
| | from models.add_module.egnn import EGNN |
| | from models.add_module.model_utils import get_sub_graph |
| | from models import ipa_pytorch |
| | from data import utils as du |
| | from data import all_atom |
| | import e3nn |
| |
|
class FlowModel(nn.Module):
    """Predict backbone frames (and optionally torsions) from frames at time t.

    The trunk runs ``ipa.num_blocks`` blocks. Each block does the following:

    * attention: ``InvariantPointAttention``, or ``E3_transformer_test`` when
      ``use_e3_transformer`` is set;
    * a residual connection and a LayerNorm;
    * an optional energy adapter;
    * an EGNN layer;
    * an optional per-block backbone update.

    A final backbone update then produces the predicted frames. An optional
    head predicts 7 torsion pairs, which are normalised and padded to 8.
    """

    def __init__(self, model_conf):
        super().__init__()

        self._model_conf = model_conf
        self._ipa_conf = model_conf.ipa
        c_s = self._ipa_conf.c_s

        # NOTE: submodule attribute names and construction order are kept
        # stable. They determine state_dict keys (checkpoint loading) and the
        # RNG stream used for parameter initialisation.
        self.node_embedder = NodeEmbedder(model_conf.node_features)
        self.edge_embedder = EdgeEmbedder(model_conf.edge_features)

        self.use_torsions = model_conf.use_torsions
        self.use_adapter_node = model_conf.use_adapter_node
        self.use_mid_bb_update = model_conf.use_mid_bb_update
        self.use_mid_bb_update_e3 = model_conf.use_mid_bb_update_e3
        self.use_e3_transformer = model_conf.use_e3_transformer

        if self.use_adapter_node:
            self.energy_adapter = Energy_Adapter_Node(
                d_node=model_conf.node_embed_size,
                n_head=model_conf.ipa.no_heads,
                p_drop=model_conf.dropout,
            )

        if self.use_torsions:
            self.num_torsions = 7
            self.torsions_pred_layer1 = nn.Sequential(
                nn.Linear(c_s, c_s),
                nn.ReLU(),
                nn.Linear(c_s, c_s),
            )
            self.torsions_pred_layer2 = nn.Linear(c_s, self.num_torsions * 2)

        if self.use_e3_transformer or self.use_mid_bb_update_e3:
            self.max_dist = self._model_conf.edge_features.max_dist
            self.irreps_sh = e3nn.o3.Irreps.spherical_harmonics(2)
            # Three vector (1o) inputs per residue; see _init_l1_feats.
            input_d_l1 = 3
            e3_d_node_l1 = 32
            irreps_l1_in = e3nn.o3.Irreps(f"{input_d_l1}x1o")
            irreps_l1_out = e3nn.o3.Irreps(f"{e3_d_node_l1}x1o")
            self.proj_l1_init = e3nn.o3.Linear(irreps_l1_in, irreps_l1_out)

        self.trunk = nn.ModuleDict()
        for b in range(self._ipa_conf.num_blocks):
            if self.use_e3_transformer:
                self.trunk[f'ipa_{b}'] = E3_transformer_test(
                    self._ipa_conf, e3_d_node_l1=e3_d_node_l1)
            else:
                self.trunk[f'ipa_{b}'] = ipa_pytorch.InvariantPointAttention(
                    self._ipa_conf)
            self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(c_s)

            if self.use_adapter_node:
                self.trunk[f'energy_adapter_{b}'] = Energy_Adapter_Node(
                    d_node=model_conf.node_embed_size,
                    n_head=model_conf.ipa.no_heads,
                    p_drop=model_conf.dropout,
                )

            self.trunk[f'egnn_{b}'] = EGNN(hidden_dim=model_conf.node_embed_size)

            if self.use_mid_bb_update:
                if self.use_mid_bb_update_e3:
                    self.trunk[f'bb_update_{b}'] = E3_GNNlayer(
                        e3_d_node_l0=4 * e3_d_node_l1,
                        e3_d_node_l1=e3_d_node_l1,
                        final=True,
                        init='final',
                        dropout=0.0,
                    )
                else:
                    # NOTE(review): when use_e3_transformer is also set,
                    # forward() feeds concat([node_embed, l1_feats]) to this
                    # layer. That input is wider than c_s; confirm that
                    # BackboneUpdate accepts it.
                    self.trunk[f'bb_update_{b}'] = ipa_pytorch.BackboneUpdate(
                        c_s, use_rot_updates=True, dropout=0.0)

        if self.use_mid_bb_update_e3:
            self.bb_update_layer = E3_GNNlayer(
                e3_d_node_l0=4 * e3_d_node_l1,
                e3_d_node_l1=e3_d_node_l1,
                final=True,
                init='final',
                dropout=0.0,
            )
        elif self.use_e3_transformer:
            # Node features are c_s scalars plus e3_d_node_l1 vectors, and
            # the output is 6 scalars per residue (the rigid update).
            irreps_node = e3nn.o3.Irreps(f"{c_s}x0e + {e3_d_node_l1}x1o")
            irreps_out = e3nn.o3.Irreps("6x0e")
            self.bb_tpc = TensorProductConvLayer(
                irreps_node, irreps_node, irreps_out, c_s,
                fc_factor=1, init='final', dropout=0.1)
        else:
            self.bb_update_layer = ipa_pytorch.BackboneUpdate(
                c_s, use_rot_updates=True, dropout=0.0)

    # These are static methods rather than lambdas stored on the instance, so
    # the module stays picklable. The call syntax is unchanged:
    # ``self.rigids_ang_to_nm(r)``.
    @staticmethod
    def rigids_ang_to_nm(rigids):
        """Scale the translations of ``rigids`` by ``du.ANG_TO_NM_SCALE``."""
        return rigids.apply_trans_fn(lambda t: t * du.ANG_TO_NM_SCALE)

    @staticmethod
    def rigids_nm_to_ang(rigids):
        """Scale the translations of ``rigids`` by ``du.NM_TO_ANG_SCALE``."""
        return rigids.apply_trans_fn(lambda t: t * du.NM_TO_ANG_SCALE)

    def _build_sub_graph(self, xyz_t, node_mask):
        """Build the sparse residue graph and its edge spherical harmonics.

        Nodes are positioned at atom index 1 of ``xyz_t`` (presumably CA in
        the atom37 ordering; TODO confirm).

        Returns:
            ``(edge_src, edge_dst, pair_index, edge_sh)``. The first three
            come from ``get_sub_graph``. ``edge_sh`` holds the degree-<=2
            spherical harmonics of the normalised edge vectors.
        """
        anchor = xyz_t[:, :, 1, :]
        edge_src, edge_dst, pair_index = get_sub_graph(
            anchor, mask=node_mask, dist_max=self.max_dist, kmin=32)
        # Edge indices address the flattened (batch * length) node list.
        anchor_flat = anchor.reshape(-1, 3)
        edge_vec = anchor_flat[edge_src] - anchor_flat[edge_dst]
        edge_sh = e3nn.o3.spherical_harmonics(
            self.irreps_sh, edge_vec, normalize=True, normalization='component')
        return edge_src, edge_dst, pair_index, edge_sh

    def _init_l1_feats(self, xyz_t, node_mask):
        """Build the initial per-residue vector (1o) features.

        The features are ``[atom0 - atom1, atom1, atom2 - atom1]``, where
        atom1 is the node anchor used in ``_build_sub_graph``. They are
        projected to ``e3_d_node_l1`` vectors and padding positions are
        zeroed.
        """
        anchor = xyz_t[:, :, 1, :]
        l1_feats = torch.concat([
            xyz_t[:, :, 0, :] - anchor,
            anchor,
            xyz_t[:, :, 2, :] - anchor,
        ], dim=-1)
        l1_feats = self.proj_l1_init(l1_feats)
        return l1_feats * node_mask[..., None]

    @staticmethod
    def _normalize_and_pad_torsions(pred_torsions, eps=1e-8):
        """Unit-normalise torsion pairs and prepend fixed padding entries.

        Args:
            pred_torsions: tensor of shape ``(*batch, n, 2)`` with ``n <= 8``.
            eps: lower bound on the squared norm before ``sqrt``. A zero
                prediction then yields zeros instead of NaN, and the
                gradient stays finite.

        Returns:
            Tensor of shape ``(*batch, 8, 2)``. The first ``8 - n`` pairs are
            ``(0, 1)`` (presumably sin/cos of a zero angle, i.e. identity;
            TODO confirm the pair ordering). The remaining pairs are the
            normalised predictions.
        """
        sq_norm = torch.sum(pred_torsions ** 2, dim=-1, keepdim=True)
        pred_torsions = pred_torsions / torch.sqrt(torch.clamp(sq_norm, min=eps))

        batch_shape = pred_torsions.shape[:-2]
        num_pad = 8 - pred_torsions.shape[-2]
        pad = pred_torsions.new_zeros((1,) * len(batch_shape) + (num_pad, 2))
        pad[..., 1] = 1
        pad = pad.expand(*batch_shape, -1, -1)
        return torch.concat([pad, pred_torsions], dim=-2)

    def _predict_torsions(self, node_embed, aatype):
        """Run the torsion head and return the padded ``(*aatype.shape, 8, 2)`` tensor."""
        hidden = node_embed + self.torsions_pred_layer1(node_embed)
        raw = self.torsions_pred_layer2(hidden).reshape(
            aatype.shape + (self.num_torsions, 2))
        return self._normalize_and_pad_torsions(raw)

    def forward(self, input_feats, use_mask_aatype=False):
        '''Run the network on one batch.

        Note: B and L change during training.

        Args:
            input_feats: dict with keys
                'aatype'        (B, L)
                'res_mask'      (B, L)
                't'             (B, 1)
                'trans_1'       (B, L, 3)   (not read here)
                'rotmats_1'     (B, L, 3, 3)  (not read here)
                'trans_t'       (B, L, 3)
                'rotmats_t'     (B, L, 3, 3)
                'node_repr_pre' passed to the node embedder (shape: TODO
                                confirm)
                'pair_repr_pre' passed to the edge embedder (shape: TODO
                                confirm)
                'trans_sc'      optional; defaults to zeros_like(trans_t)
                'energy'        required only when use_adapter_node is set
            use_mask_aatype: currently unused. It is kept for interface
                compatibility.

        Returns:
            A dict with the following keys:

            * 'pred_trans': translations converted back with
              ``rigids_nm_to_ang``.
            * 'pred_rotmats': rotation matrices.
            * 'pred_torsions_with_CB': a ``(B, L, 8, 2)`` tensor, or ``None``
              when ``use_torsions`` is False.
        '''
        node_mask = input_feats['res_mask']
        edge_mask = node_mask[:, None] * node_mask[:, :, None]
        continuous_t = input_feats['t']
        trans_t = input_feats['trans_t']
        rotmats_t = input_feats['rotmats_t']
        aatype = input_feats['aatype']
        energy = input_feats['energy'] if self.use_adapter_node else None

        use_e3_graph = self.use_e3_transformer or self.use_mid_bb_update_e3
        edge_src = edge_dst = pair_index = edge_sh = l1_feats = None
        if use_e3_graph:
            # Keep the first three atom37 slots for the equivariant branch.
            xyz_t = all_atom.to_atom37(trans_t, rotmats_t)[:, :, :3, :]
            edge_src, edge_dst, pair_index, edge_sh = self._build_sub_graph(
                xyz_t, node_mask)

        node_repr_pre = input_feats['node_repr_pre']
        pair_repr_pre = input_feats['pair_repr_pre']
        if 'trans_sc' in input_feats:
            trans_sc = input_feats['trans_sc']
        else:
            trans_sc = torch.zeros_like(trans_t)

        node_embed = self.node_embedder(
            continuous_t, aatype, node_repr_pre, node_mask)
        node_embed = node_embed * node_mask[..., None]
        if self.use_adapter_node:
            node_embed = self.energy_adapter(node_embed, energy, mask=node_mask)
            node_embed = node_embed * node_mask[..., None]

        edge_embed = self.edge_embedder(
            node_embed, trans_t, trans_sc, pair_repr_pre, edge_mask)
        edge_embed = edge_embed * edge_mask[..., None]

        if use_e3_graph:
            l1_feats = self._init_l1_feats(xyz_t, node_mask)

        curr_rigids = self.rigids_ang_to_nm(du.create_rigid(rotmats_t, trans_t))

        for b in range(self._ipa_conf.num_blocks):
            if self.use_e3_transformer:
                ipa_embed, l1_feats = self.trunk[f'ipa_{b}'](
                    node_embed, edge_embed, l1_feats, curr_rigids, node_mask)
                l1_feats = l1_feats * node_mask[..., None]
            else:
                ipa_embed = self.trunk[f'ipa_{b}'](
                    node_embed, edge_embed, curr_rigids, node_mask)
            ipa_embed = (node_embed + ipa_embed) * node_mask[..., None]
            node_embed = self.trunk[f'ipa_ln_{b}'](ipa_embed)

            if self.use_adapter_node:
                node_embed = self.trunk[f'energy_adapter_{b}'](
                    node_embed, energy, mask=node_mask)
                node_embed = node_embed * node_mask[..., None]

            # The EGNN receives the input translations trans_t, not the
            # in-loop updated curr_rigids.
            _, node_embed = self.trunk[f'egnn_{b}'](node_embed, trans_t)
            node_embed = node_embed * node_mask[..., None]

            if self.use_mid_bb_update:
                if self.use_mid_bb_update_e3:
                    rigid_update = self.trunk[f'bb_update_{b}'](
                        node_embed, edge_embed, l1_feats, pair_index,
                        edge_src, edge_dst, edge_sh, mask=node_mask)
                elif self.use_e3_transformer:
                    # NOTE(review): see __init__. This layer is built as
                    # BackboneUpdate(c_s) but receives c_s + l1 channels here.
                    rigid_update = self.trunk[f'bb_update_{b}'](
                        torch.concat([node_embed, l1_feats], dim=-1))
                else:
                    rigid_update = self.trunk[f'bb_update_{b}'](node_embed)
                curr_rigids = curr_rigids.compose_q_update_vec(
                    rigid_update, node_mask[..., None])

        if self.use_mid_bb_update_e3:
            rigid_update = self.bb_update_layer(
                node_embed, edge_embed, l1_feats, pair_index,
                edge_src, edge_dst, edge_sh, mask=node_mask)
        elif self.use_e3_transformer:
            node_feats_total = torch.concat([node_embed, l1_feats], dim=-1)
            rigid_update = self.bb_tpc(
                node_feats_total, node_feats_total, node_embed)
        else:
            rigid_update = self.bb_update_layer(node_embed)
        curr_rigids = curr_rigids.compose_q_update_vec(
            rigid_update, node_mask[..., None])

        curr_rigids = self.rigids_nm_to_ang(curr_rigids)
        pred_trans = curr_rigids.get_trans()
        pred_rotmats = curr_rigids.get_rots().get_rot_mats()

        # Bind this unconditionally. Previously, use_torsions=False raised
        # UnboundLocalError at the return statement.
        pred_torsions_with_CB = None
        if self.use_torsions:
            pred_torsions_with_CB = self._predict_torsions(node_embed, aatype)

        return {
            'pred_trans': pred_trans,
            'pred_rotmats': pred_rotmats,
            'pred_torsions_with_CB': pred_torsions_with_CB,
        }
| |
|