"""Triangular multiplicative update layers (Boltz2 / vb_layers_triangular_mult.py)."""

import importlib
import torch
from torch import Tensor, nn
from . import vb_layers_initialize as init


@torch.compiler.disable
def kernel_triangular_mult(
    x: Tensor,
    direction: str,
    mask: Tensor,
    norm_in_weight: Tensor,
    norm_in_bias: Tensor,
    p_in_weight: Tensor,
    g_in_weight: Tensor,
    norm_out_weight: Tensor,
    norm_out_bias: Tensor,
    p_out_weight: Tensor,
    g_out_weight: Tensor,
    eps: float,
) -> Tensor:
    """Apply the fused triangular multiplicative update kernel.

    Thin wrapper around ``triangle_multiplicative_update`` from
    ``cuequivariance_torch``. The import is deferred so the package stays an
    optional dependency, and ``@torch.compiler.disable`` keeps the call out
    of ``torch.compile`` graphs.
    """
triangle_module = importlib.import_module("cuequivariance_torch.primitives.triangle")
triangle_multiplicative_update = triangle_module.triangle_multiplicative_update
return triangle_multiplicative_update(
x,
direction=direction,
mask=mask,
norm_in_weight=norm_in_weight,
norm_in_bias=norm_in_bias,
p_in_weight=p_in_weight,
g_in_weight=g_in_weight,
norm_out_weight=norm_out_weight,
norm_out_bias=norm_out_bias,
p_out_weight=p_out_weight,
g_out_weight=g_out_weight,
eps=eps,
)
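

# NOTE: The fused path requires the optional ``cuequivariance-torch`` package
# and a CUDA device. A minimal sketch of exercising it through the modules
# defined below (assuming the package is installed):
#
#     layer = TriangleMultiplicationOutgoing(dim=128).cuda()
#     x = torch.randn(1, 32, 32, 128, device="cuda")
#     mask = torch.ones(1, 32, 32, device="cuda")
#     out = layer(x, mask, use_kernels=True)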


class TriangleMultiplicationOutgoing(nn.Module):
    """Triangle multiplicative update using outgoing edges."""

    def __init__(self, dim: int = 128) -> None:
        """Initialize the TriangleMultiplicationOutgoing module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128
        """
super().__init__()
self.norm_in = nn.LayerNorm(dim, eps=1e-5)
self.p_in = nn.Linear(dim, 2 * dim, bias=False)
self.g_in = nn.Linear(dim, 2 * dim, bias=False)
self.norm_out = nn.LayerNorm(dim)
self.p_out = nn.Linear(dim, dim, bias=False)
self.g_out = nn.Linear(dim, dim, bias=False)
        # Parameter initialization (helpers from ``vb_layers_initialize``):
        # the LayerNorms start as the identity, and p_out gets the "final"
        # init (zeros in AlphaFold-style codebases), so the update starts
        # as a near no-op.
        init.bias_init_one_(self.norm_in.weight)
        init.bias_init_zero_(self.norm_in.bias)
        init.lecun_normal_init_(self.p_in.weight)
        init.gating_init_(self.g_in.weight)
        init.bias_init_one_(self.norm_out.weight)
        init.bias_init_zero_(self.norm_out.bias)
        init.final_init_(self.p_out.weight)
        init.gating_init_(self.g_out.weight)

    def forward(self, x: Tensor, mask: Tensor, use_kernels: bool = False) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (B, N, N, D)
        mask: torch.Tensor
            The input mask of shape (B, N, N)
        use_kernels: bool
            Whether to use the fused triangle kernel from
            ``cuequivariance-torch`` instead of the pure-PyTorch path

        Returns
        -------
        x: torch.Tensor
            The output data of shape (B, N, N, D)
        """
        if use_kernels:
            # Fused path: norm, projections, gating, and the triangular
            # contraction all happen inside the single kernel call.
return kernel_triangular_mult(
x,
direction="outgoing",
mask=mask,
norm_in_weight=self.norm_in.weight,
norm_in_bias=self.norm_in.bias,
p_in_weight=self.p_in.weight,
g_in_weight=self.g_in.weight,
norm_out_weight=self.norm_out.weight,
norm_out_bias=self.norm_out.bias,
p_out_weight=self.p_out.weight,
g_out_weight=self.g_out.weight,
eps=1e-5,
)
        # Input projection and gating: D -> 2D
x = self.norm_in(x)
x_in = x
x = self.p_in(x) * self.g_in(x).sigmoid()
# Apply mask
x = x * mask.unsqueeze(-1)
        # Split into the a and b halves; do the contraction in float32
        a, b = torch.chunk(x.float(), 2, dim=-1)
        # Outgoing triangular update: out[b, i, j] = sum_k a[b, i, k] * b[b, j, k]
        x = torch.einsum("bikd,bjkd->bijd", a, b)
# Output gating
x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()
return x
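

# For intuition: per channel d, the outgoing einsum "bikd,bjkd->bijd" is the
# batched matrix product A @ B^T over the shared neighbor index k, i.e.
#
#     out[b, i, j, d] = sum_k a[b, i, k, d] * b[b, j, k, d]
#
# so pair (i, j) aggregates the edges leaving i and leaving j through every
# intermediate node k (the "outgoing" triangle in AlphaFold terminology).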


class TriangleMultiplicationIncoming(nn.Module):
    """Triangle multiplicative update using incoming edges."""

    def __init__(self, dim: int = 128) -> None:
        """Initialize the TriangleMultiplicationIncoming module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128
        """
super().__init__()
self.norm_in = nn.LayerNorm(dim, eps=1e-5)
self.p_in = nn.Linear(dim, 2 * dim, bias=False)
self.g_in = nn.Linear(dim, 2 * dim, bias=False)
self.norm_out = nn.LayerNorm(dim)
self.p_out = nn.Linear(dim, dim, bias=False)
self.g_out = nn.Linear(dim, dim, bias=False)
        # Same initialization scheme as TriangleMultiplicationOutgoing.
        init.bias_init_one_(self.norm_in.weight)
        init.bias_init_zero_(self.norm_in.bias)
        init.lecun_normal_init_(self.p_in.weight)
        init.gating_init_(self.g_in.weight)
        init.bias_init_one_(self.norm_out.weight)
        init.bias_init_zero_(self.norm_out.bias)
        init.final_init_(self.p_out.weight)
        init.gating_init_(self.g_out.weight)

    def forward(self, x: Tensor, mask: Tensor, use_kernels: bool = False) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (B, N, N, D)
        mask: torch.Tensor
            The input mask of shape (B, N, N)
        use_kernels: bool
            Whether to use the fused triangle kernel from
            ``cuequivariance-torch`` instead of the pure-PyTorch path

        Returns
        -------
        x: torch.Tensor
            The output data of shape (B, N, N, D)
        """
        if use_kernels:
            # Fused path (see TriangleMultiplicationOutgoing.forward).
return kernel_triangular_mult(
x,
direction="incoming",
mask=mask,
norm_in_weight=self.norm_in.weight,
norm_in_bias=self.norm_in.bias,
p_in_weight=self.p_in.weight,
g_in_weight=self.g_in.weight,
norm_out_weight=self.norm_out.weight,
norm_out_bias=self.norm_out.bias,
p_out_weight=self.p_out.weight,
g_out_weight=self.g_out.weight,
eps=1e-5,
)
        # Input projection and gating: D -> 2D
x = self.norm_in(x)
x_in = x
x = self.p_in(x) * self.g_in(x).sigmoid()
# Apply mask
x = x * mask.unsqueeze(-1)
        # Split into the a and b halves; do the contraction in float32
        a, b = torch.chunk(x.float(), 2, dim=-1)
        # Incoming triangular update: out[b, i, j] = sum_k a[b, k, i] * b[b, k, j]
        x = torch.einsum("bkid,bkjd->bijd", a, b)
# Output gating
x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()
return x
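

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original API): run both
    # update directions on random inputs using the pure-PyTorch path
    # (use_kernels=False), which needs no optional dependencies.
    torch.manual_seed(0)
    B, N, D = 2, 16, 128
    x = torch.randn(B, N, N, D)
    mask = torch.ones(B, N, N)
    outgoing = TriangleMultiplicationOutgoing(dim=D)
    incoming = TriangleMultiplicationIncoming(dim=D)
    print(outgoing(x, mask).shape)  # torch.Size([2, 16, 16, 128])
    print(incoming(x, mask).shape)  # torch.Size([2, 16, 16, 128])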