from functools import partial
import math
from typing import Optional, List

import torch
import torch.nn as nn

from openfold.model.primitives import Linear, LayerNorm, Attention
from openfold.model.dropout import (
    DropoutRowwise,
    DropoutColumnwise,
)
from openfold.model.pair_transition import PairTransition
from openfold.model.triangular_attention import (
    TriangleAttentionStartingNode,
    TriangleAttentionEndingNode,
)
from openfold.model.triangular_multiplicative_update import (
    TriangleMultiplicationOutgoing,
    TriangleMultiplicationIncoming,
)
from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.tensor_utils import (
    chunk_layer,
    permute_final_dims,
    flatten_final_dims,
)


class TemplatePointwiseAttention(nn.Module):
    """
    Implements Algorithm 17.
    """
    def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
        """
        Args:
            c_t:
                Template embedding channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large value used to construct the attention mask bias
        """
        super(TemplatePointwiseAttention, self).__init__()

        self.c_t = c_t
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf

        self.mha = Attention(
            self.c_z,
            self.c_t,
            self.c_t,
            self.c_hidden,
            self.no_heads,
            gating=False,
        )

    def _chunk(self,
        z: torch.Tensor,
        t: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        # Run the attention in chunks over the batch-like dimensions to
        # reduce peak memory usage
        mha_inputs = {
            "q_x": z,
            "kv_x": t,
            "biases": biases,
        }
        return chunk_layer(
            self.mha,
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(z.shape[:-2]),
        )

    def forward(self,
        t: torch.Tensor,
        z: torch.Tensor,
        template_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            t:
                [*, N_templ, N_res, N_res, C_t] template embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            template_mask:
                [*, N_templ] template mask
            chunk_size:
                Chunk size for the attention computation. None disables
                chunking
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        if template_mask is None:
            template_mask = t.new_ones(t.shape[:-3])

        # [*, 1, 1, 1, 1, N_templ] mask bias
        bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)

        # [*, N_res, N_res, 1, C_z]
        z = z.unsqueeze(-2)

        # [*, N_res, N_res, N_templ, C_t]
        t = permute_final_dims(t, (1, 2, 0, 3))

        biases = [bias]

        # [*, N_res, N_res, 1, C_z]
        if chunk_size is not None:
            z = self._chunk(z, t, biases, chunk_size)
        else:
            z = self.mha(q_x=z, kv_x=t, biases=biases)

        # [*, N_res, N_res, C_z]
        z = z.squeeze(-2)

        return z
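
# Illustrative usage sketch (not part of the original module). The sizes
# below are assumptions chosen only to demonstrate the expected tensor
# shapes; real values come from the model config.
def _example_template_pointwise_attention():
    n_templ, n_res, c_t, c_z = 4, 16, 64, 128
    attn = TemplatePointwiseAttention(
        c_t=c_t, c_z=c_z, c_hidden=16, no_heads=4, inf=1e5
    )
    t = torch.rand(n_templ, n_res, n_res, c_t)
    z = torch.rand(n_res, n_res, c_z)
    template_mask = torch.ones(n_templ)
    # The output is a [N_res, N_res, C_z] update to the pair embedding
    update = attn(t, z, template_mask=template_mask)
    assert update.shape == (n_res, n_res, c_z)
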
class TemplatePairStackBlock(nn.Module):
    """
    One block of the template pair stack (see Algorithm 16).
    """
    def __init__(
        self,
        c_t: int,
        c_hidden_tri_att: int,
        c_hidden_tri_mul: int,
        no_heads: int,
        pair_transition_n: int,
        dropout_rate: float,
        inf: float,
        **kwargs,
    ):
        super(TemplatePairStackBlock, self).__init__()

        self.c_t = c_t
        self.c_hidden_tri_att = c_hidden_tri_att
        self.c_hidden_tri_mul = c_hidden_tri_mul
        self.no_heads = no_heads
        self.pair_transition_n = pair_transition_n
        self.dropout_rate = dropout_rate
        self.inf = inf

        self.dropout_row = DropoutRowwise(self.dropout_rate)
        self.dropout_col = DropoutColumnwise(self.dropout_rate)

        self.tri_att_start = TriangleAttentionStartingNode(
            self.c_t,
            self.c_hidden_tri_att,
            self.no_heads,
            inf=inf,
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            self.c_t,
            self.c_hidden_tri_att,
            self.no_heads,
            inf=inf,
        )

        self.tri_mul_out = TriangleMultiplicationOutgoing(
            self.c_t,
            self.c_hidden_tri_mul,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            self.c_t,
            self.c_hidden_tri_mul,
        )

        self.pair_transition = PairTransition(
            self.c_t,
            self.pair_transition_n,
        )

    def forward(self,
        z: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ):
        # Process templates one at a time along the N_templ dimension
        single_templates = [
            t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
        ]
        single_templates_masks = [
            m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
        ]
        for i in range(len(single_templates)):
            single = single_templates[i]
            single_mask = single_templates_masks[i]

            single = single + self.dropout_row(
                self.tri_att_start(
                    single,
                    chunk_size=chunk_size,
                    mask=single_mask,
                )
            )
            single = single + self.dropout_col(
                self.tri_att_end(
                    single,
                    chunk_size=chunk_size,
                    mask=single_mask,
                )
            )
            single = single + self.dropout_row(
                self.tri_mul_out(
                    single,
                    mask=single_mask,
                )
            )
            single = single + self.dropout_row(
                self.tri_mul_in(
                    single,
                    mask=single_mask,
                )
            )
            single = single + self.pair_transition(
                single,
                mask=single_mask if _mask_trans else None,
                chunk_size=chunk_size,
            )

            single_templates[i] = single

        z = torch.cat(single_templates, dim=-4)

        return z
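
# Illustrative sketch (not part of the original module): running a single
# block on a toy template pair tensor. All sizes are assumptions made up
# for the example.
def _example_template_pair_stack_block():
    n_templ, n_res, c_t = 2, 8, 64
    block = TemplatePairStackBlock(
        c_t=c_t,
        c_hidden_tri_att=16,
        c_hidden_tri_mul=16,
        no_heads=4,
        pair_transition_n=2,
        dropout_rate=0.25,
        inf=1e9,
    )
    t = torch.rand(n_templ, n_res, n_res, c_t)
    mask = torch.ones(n_templ, n_res, n_res)
    # The output keeps the input shape: [N_templ, N_res, N_res, C_t]
    out = block(t, mask, chunk_size=None)
    assert out.shape == t.shape
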
class TemplatePairStack(nn.Module):
    """
    Implements Algorithm 16.
    """
    def __init__(
        self,
        c_t,
        c_hidden_tri_att,
        c_hidden_tri_mul,
        no_blocks,
        no_heads,
        pair_transition_n,
        dropout_rate,
        blocks_per_ckpt,
        inf=1e9,
        **kwargs,
    ):
        """
        Args:
            c_t:
                Template embedding channel dimension
            c_hidden_tri_att:
                Per-head hidden dimension for triangular attention
            c_hidden_tri_mul:
                Hidden dimension for triangular multiplication
            no_blocks:
                Number of blocks in the stack
            no_heads:
                Number of triangular attention heads
            pair_transition_n:
                Scale of pair transition (Alg. 15) hidden dimension
            dropout_rate:
                Dropout rate used throughout the stack
            blocks_per_ckpt:
                Number of blocks per activation checkpoint. None disables
                activation checkpointing
            inf:
                Large value used for attention masking
        """
        super(TemplatePairStack, self).__init__()

        self.blocks_per_ckpt = blocks_per_ckpt

        self.blocks = nn.ModuleList()
        for _ in range(no_blocks):
            block = TemplatePairStackBlock(
                c_t=c_t,
                c_hidden_tri_att=c_hidden_tri_att,
                c_hidden_tri_mul=c_hidden_tri_mul,
                no_heads=no_heads,
                pair_transition_n=pair_transition_n,
                dropout_rate=dropout_rate,
                inf=inf,
            )
            self.blocks.append(block)

        self.layer_norm = LayerNorm(c_t)

    def forward(
        self,
        t: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
        _mask_trans: bool = True,
    ):
        """
        Args:
            t:
                [*, N_templ, N_res, N_res, C_t] template embedding
            mask:
                [*, N_templ, N_res, N_res] mask
        Returns:
            [*, N_templ, N_res, N_res, C_t] template embedding update
        """
        # Broadcast a single-template mask across all templates
        if mask.shape[-3] == 1:
            expand_idx = list(mask.shape)
            expand_idx[-3] = t.shape[-4]
            mask = mask.expand(*expand_idx)

        # Activation checkpointing is only applied during training
        t, = checkpoint_blocks(
            blocks=[
                partial(
                    b,
                    mask=mask,
                    chunk_size=chunk_size,
                    _mask_trans=_mask_trans,
                )
                for b in self.blocks
            ],
            args=(t,),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        t = self.layer_norm(t)

        return t
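
# Illustrative sketch (not part of the original module): building and running
# a small template pair stack. All hyperparameters here are assumptions made
# up for the example, not the values used by any released model.
def _example_template_pair_stack():
    n_templ, n_res, c_t = 2, 8, 64
    stack = TemplatePairStack(
        c_t=c_t,
        c_hidden_tri_att=16,
        c_hidden_tri_mul=16,
        no_blocks=2,
        no_heads=4,
        pair_transition_n=2,
        dropout_rate=0.25,
        blocks_per_ckpt=None,  # disable activation checkpointing
    )
    t = torch.rand(n_templ, n_res, n_res, c_t)
    mask = torch.ones(n_templ, n_res, n_res)
    # The output keeps the input shape: [N_templ, N_res, N_res, C_t]
    out = stack(t, mask, chunk_size=None)
    assert out.shape == t.shape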