from functools import partialmethod
from typing import Optional, List

import torch
import torch.nn as nn

from openfold.model.primitives import Linear, LayerNorm, Attention
from openfold.utils.tensor_utils import (
    chunk_layer,
    permute_final_dims,
)


class TriangleAttention(nn.Module):
    def __init__(
        self, c_in, c_hidden, no_heads, starting, inf=1e9
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
            starting:
                If True, attend within rows of the pair representation
                ("starting node"); if False, within columns ("ending node")
            inf:
                Large constant used to build the additive attention mask
        """
        super(TriangleAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)

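        # Projects each pair entry to one scalar per head; after
        # permutation these become the triangle bias on the logits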
        self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")

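        # Multi-head attention; c_hidden is the overall hidden
        # dimension, split across the no_heads heads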
        self.mha = Attention(
            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
        )

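    # Evaluates the attention in fixed-size chunks over the batch-like
    # dimensions, trading a little speed for lower peak memory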
    @torch.jit.ignore
    def _chunk(self,
        x: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        mha_inputs = {
            "q_x": x,
            "kv_x": x,
            "biases": biases,
        }
        return chunk_layer(
            self.mha,
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
        )

    def forward(self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
            mask:
                [*, I, J] mask tensor (1 = keep, 0 = mask out)
            chunk_size:
                If given, attention is computed in chunks of this size
                to reduce peak memory
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
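            # [*, I, J]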
            mask = x.new_ones(x.shape[:-1])

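        # For the ending-node variant, transpose so attention runs over
        # the other pair dimension; undone before returning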
        if not self.starting:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

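        # [*, I, J, C_in]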
        x = self.layer_norm(x)

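        # [*, I, 1, 1, J]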
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

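        # [*, H, I, J]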
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

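        # [*, 1, H, I, J]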
        triangle_bias = triangle_bias.unsqueeze(-4)

        biases = [mask_bias, triangle_bias]

        if chunk_size is not None:
            x = self._chunk(x, biases, chunk_size)
        else:
            x = self.mha(q_x=x, kv_x=x, biases=biases)

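        # Undo the transpose applied for the ending-node variant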
        if not self.starting:
            x = x.transpose(-2, -3)

        return x


class TriangleAttentionStartingNode(TriangleAttention):
    """
    Implements Algorithm 13 (triangular gated self-attention around
    the starting node) from the AlphaFold 2 supplement.
    """

    __init__ = partialmethod(TriangleAttention.__init__, starting=True)


class TriangleAttentionEndingNode(TriangleAttention):
    """
    Implements Algorithm 14 (triangular gated self-attention around
    the ending node) from the AlphaFold 2 supplement.
    """

    __init__ = partialmethod(TriangleAttention.__init__, starting=False)
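

# Minimal usage sketch (hypothetical values; c_in/c_hidden/no_heads are
# illustrative, not fixed by this module):
#
#   attn = TriangleAttentionStartingNode(c_in=128, c_hidden=32, no_heads=4)
#   z = torch.randn(1, 64, 64, 128)  # [*, I, J, C_in] pair representation
#   z = attn(z, chunk_size=4)        # [*, I, J, C_in]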