| | from __future__ import annotations |
| |
|
| | import torch |
| | from torch import nn |
| | from torch.nn import Module |
| |
|
| | from .vb_modules_encodersv2 import ( |
| | AtomEncoder, |
| | PairwiseConditioning, |
| | ) |
| |
|
| |
|
class DiffusionConditioning(Module):
    """Produce the conditioning tensors consumed downstream of the trunk.

    Pipeline:

    1. ``z_trunk`` and ``relative_position_encoding`` are combined by a
       ``PairwiseConditioning`` block into a conditioned pair tensor ``z``.
    2. An ``AtomEncoder`` (built with ``structure_prediction=True``) turns
       ``feats``, ``s_trunk`` and ``z`` into ``(q, c, p, to_keys)``.
    3. Three families of per-layer projections, each one a
       ``LayerNorm -> Linear(bias=False)`` stack, map ``p`` (atom encoder and
       atom decoder) or ``z`` (token transformer) to ``*_heads`` channels per
       layer. The per-layer results are concatenated on the last dimension.
       Judging by the names, these are presumably per-layer attention biases;
       confirm against the consumer.
    """

    def __init__(
        self,
        token_s: int,
        token_z: int,
        atom_s: int,
        atom_z: int,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        atom_encoder_depth: int = 3,
        atom_encoder_heads: int = 4,
        token_transformer_depth: int = 24,
        token_transformer_heads: int = 8,
        atom_decoder_depth: int = 3,
        atom_decoder_heads: int = 4,
        atom_feature_dim: int = 128,
        conditioning_transition_layers: int = 2,
        use_no_atom_char: bool = False,
        use_atom_backbone_feat: bool = False,
        use_residue_feats_atoms: bool = False,
    ) -> None:
        super().__init__()

        # Submodules are registered in a fixed order: pair conditioner,
        # atom encoder, then the three projection families. Keeping this
        # order keeps parameter names and random-init order stable for
        # checkpoints.
        self.pairwise_conditioner = PairwiseConditioning(
            token_z=token_z,
            dim_token_rel_pos_feats=token_z,
            num_transitions=conditioning_transition_layers,
        )

        self.atom_encoder = AtomEncoder(
            atom_s=atom_s,
            atom_z=atom_z,
            token_s=token_s,
            token_z=token_z,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_feature_dim=atom_feature_dim,
            structure_prediction=True,
            use_no_atom_char=use_no_atom_char,
            use_atom_backbone_feat=use_atom_backbone_feat,
            use_residue_feats_atoms=use_residue_feats_atoms,
        )

        self.atom_enc_proj_z = self._bias_projections(
            atom_z, atom_encoder_heads, atom_encoder_depth
        )
        self.atom_dec_proj_z = self._bias_projections(
            atom_z, atom_decoder_heads, atom_decoder_depth
        )
        self.token_trans_proj_z = self._bias_projections(
            token_z, token_transformer_heads, token_transformer_depth
        )

    @staticmethod
    def _bias_projections(in_dim: int, num_heads: int, depth: int) -> nn.ModuleList:
        """Return ``depth`` independent ``LayerNorm -> Linear(in_dim, num_heads)`` stacks.

        The linear layers have no bias term. The stacks are created in
        index order, so the ``ModuleList`` keys are ``"0" .. str(depth - 1)``.
        """
        return nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(in_dim),
                    nn.Linear(in_dim, num_heads, bias=False),
                )
                for _ in range(depth)
            ]
        )

    @staticmethod
    def _stacked_bias(projections: nn.ModuleList, x: torch.Tensor) -> torch.Tensor:
        """Apply each projection to ``x`` and concatenate the results on the last dim.

        With an empty ``projections`` list (depth 0), ``torch.cat`` raises.
        """
        return torch.cat([proj(x) for proj in projections], dim=-1)

    def forward(
        self,
        s_trunk,
        z_trunk,
        relative_position_encoding,
        feats,
    ):
        """Compute the conditioning features and the per-layer bias tensors.

        Args:
            s_trunk: Forwarded to the atom encoder as ``s_trunk``. Presumably
                the trunk single representation, shaped (B, N, token_s);
                TODO confirm.
            z_trunk: Combined with ``relative_position_encoding`` by
                ``self.pairwise_conditioner``. Presumably (B, N, N, token_z);
                TODO confirm.
            relative_position_encoding: Second input to
                ``self.pairwise_conditioner``.
            feats: Forwarded unchanged to ``self.atom_encoder``.

        Returns:
            ``(q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias)``.
            ``q``, ``c`` and ``to_keys`` come straight from the atom encoder.
            Its third output ``p`` is only used to build the atom biases and
            is not returned. Each bias concatenates one projection per layer
            on the last dimension. For example,
            ``atom_enc_bias.shape[-1] == atom_encoder_depth * atom_encoder_heads``.
        """
        z = self.pairwise_conditioner(z_trunk, relative_position_encoding)

        q, c, p, to_keys = self.atom_encoder(feats=feats, s_trunk=s_trunk, z=z)

        # Atom-level biases are projected from the atom pair features p;
        # token-level biases are projected from the conditioned pair tensor z.
        atom_enc_bias = self._stacked_bias(self.atom_enc_proj_z, p)
        atom_dec_bias = self._stacked_bias(self.atom_dec_proj_z, p)
        token_trans_bias = self._stacked_bias(self.token_trans_proj_z, z)

        return q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias
| |
|