File size: 3,790 Bytes
827d9ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
from __future__ import annotations
import torch
from torch import nn
from torch.nn import Module
from .vb_modules_encodersv2 import (
AtomEncoder,
PairwiseConditioning,
)
class DiffusionConditioning(Module):
    """Compute conditioning outputs and per-layer pair projections.

    The module owns a ``PairwiseConditioning`` block, an ``AtomEncoder`` and
    three stacks of ``LayerNorm -> Linear(bias=False)`` projections (one
    projection per layer of the atom encoder, atom decoder and token
    transformer respectively).  ``forward`` runs the two sub-modules and
    concatenates the output of every projection in a stack along the last
    dimension.

    NOTE(review): the attribute names suggest the concatenated projections
    are consumed as attention biases elsewhere; the consumers are not part
    of this class, so confirm against the caller.
    """

    def __init__(
        self,
        token_s: int,
        token_z: int,
        atom_s: int,
        atom_z: int,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        atom_encoder_depth: int = 3,
        atom_encoder_heads: int = 4,
        token_transformer_depth: int = 24,
        token_transformer_heads: int = 8,
        atom_decoder_depth: int = 3,
        atom_decoder_heads: int = 4,
        atom_feature_dim: int = 128,
        conditioning_transition_layers: int = 2,
        use_no_atom_char: bool = False,
        use_atom_backbone_feat: bool = False,
        use_residue_feats_atoms: bool = False,
    ) -> None:
        """Create the conditioning sub-modules and projection stacks.

        Args:
            token_s: Forwarded to ``AtomEncoder`` as ``token_s``.
            token_z: Forwarded to ``AtomEncoder`` and ``PairwiseConditioning``
                (as both ``token_z`` and ``dim_token_rel_pos_feats``); also
                the input width of the token-transformer projections.
            atom_s: Forwarded to ``AtomEncoder`` as ``atom_s``.
            atom_z: Forwarded to ``AtomEncoder``; also the input width of the
                atom encoder/decoder projections.
            atoms_per_window_queries: Forwarded to ``AtomEncoder``.
            atoms_per_window_keys: Forwarded to ``AtomEncoder``.
            atom_encoder_depth: Number of atom-encoder projections.
            atom_encoder_heads: Output width of each atom-encoder projection.
            token_transformer_depth: Number of token-transformer projections.
            token_transformer_heads: Output width of each token-transformer
                projection.
            atom_decoder_depth: Number of atom-decoder projections.
            atom_decoder_heads: Output width of each atom-decoder projection.
            atom_feature_dim: Forwarded to ``AtomEncoder``.
            conditioning_transition_layers: Forwarded to
                ``PairwiseConditioning`` as ``num_transitions``.
            use_no_atom_char: Forwarded to ``AtomEncoder``.
            use_atom_backbone_feat: Forwarded to ``AtomEncoder``.
            use_residue_feats_atoms: Forwarded to ``AtomEncoder``.
        """
        super().__init__()

        self.pairwise_conditioner = PairwiseConditioning(
            token_z=token_z,
            dim_token_rel_pos_feats=token_z,
            num_transitions=conditioning_transition_layers,
        )

        self.atom_encoder = AtomEncoder(
            atom_s=atom_s,
            atom_z=atom_z,
            token_s=token_s,
            token_z=token_z,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_feature_dim=atom_feature_dim,
            structure_prediction=True,
            use_no_atom_char=use_no_atom_char,
            use_atom_backbone_feat=use_atom_backbone_feat,
            use_residue_feats_atoms=use_residue_feats_atoms,
        )

        # Built in the order encoder -> decoder -> token so that sub-module
        # registration and random initialisation happen in a fixed sequence.
        self.atom_enc_proj_z = self._projection_stack(
            atom_z, atom_encoder_heads, atom_encoder_depth
        )
        self.atom_dec_proj_z = self._projection_stack(
            atom_z, atom_decoder_heads, atom_decoder_depth
        )
        self.token_trans_proj_z = self._projection_stack(
            token_z, token_transformer_heads, token_transformer_depth
        )

    @staticmethod
    def _projection_stack(in_dim: int, heads: int, depth: int) -> nn.ModuleList:
        """Return ``depth`` independent ``LayerNorm -> Linear`` projections."""
        return nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(in_dim),
                    nn.Linear(in_dim, heads, bias=False),
                )
                for _ in range(depth)
            ]
        )

    def forward(
        self,
        s_trunk,  # Float['b n ts']
        z_trunk,  # Float['b n n tz']
        relative_position_encoding,  # Float['b n n tz']
        feats,
    ):
        """Run conditioning and build the concatenated projections.

        Args:
            s_trunk: Passed to the atom encoder as ``s_trunk``.
            z_trunk: First input of the pairwise conditioner.
            relative_position_encoding: Second input of the pairwise
                conditioner.
            feats: Passed unchanged to the atom encoder as ``feats``.

        Returns:
            A 6-tuple ``(q, c, to_keys, atom_enc_bias, atom_dec_bias,
            token_trans_bias)``.  ``q``, ``c`` and ``to_keys`` are the first,
            second and fourth outputs of the atom encoder.  The two atom
            biases are every projection of the respective stack applied to
            the atom encoder's third output, concatenated on the last axis
            (final width ``depth * heads``).  ``token_trans_bias`` is built
            the same way from the pairwise conditioner's output.
        """
        pair_cond = self.pairwise_conditioner(
            z_trunk,
            relative_position_encoding,
        )

        q, c, atom_pair, to_keys = self.atom_encoder(
            feats=feats,
            s_trunk=s_trunk,  # Float['b n ts']
            z=pair_cond,  # Float['b n n tz']
        )

        # One projection per layer; all layers are concatenated on dim -1.
        atom_enc_bias = torch.cat(
            [proj(atom_pair) for proj in self.atom_enc_proj_z], dim=-1
        )
        atom_dec_bias = torch.cat(
            [proj(atom_pair) for proj in self.atom_dec_proj_z], dim=-1
        )
        token_trans_bias = torch.cat(
            [proj(pair_cond) for proj in self.token_trans_proj_z], dim=-1
        )

        return q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias
|