import importlib

import torch
from torch import Tensor, nn

from . import vb_layers_initialize as init


@torch.compiler.disable
def kernel_triangular_mult(
    x,
    direction,
    mask,
    norm_in_weight,
    norm_in_bias,
    p_in_weight,
    g_in_weight,
    norm_out_weight,
    norm_out_bias,
    p_out_weight,
    g_out_weight,
    eps,
):
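    # cuequivariance_torch is imported lazily so that importing this module
    # does not require the optional dependency; it is only needed when the
    # fused kernel path is actually called.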
    triangle_module = importlib.import_module(
        "cuequivariance_torch.primitives.triangle"
    )
    triangle_multiplicative_update = triangle_module.triangle_multiplicative_update
    return triangle_multiplicative_update(
        x,
        direction=direction,
        mask=mask,
        norm_in_weight=norm_in_weight,
        norm_in_bias=norm_in_bias,
        p_in_weight=p_in_weight,
        g_in_weight=g_in_weight,
        norm_out_weight=norm_out_weight,
        norm_out_bias=norm_out_bias,
        p_out_weight=p_out_weight,
        g_out_weight=g_out_weight,
        eps=eps,
    )


class TriangleMultiplicationOutgoing(nn.Module):
    """Triangle multiplicative update using outgoing edges."""

    def __init__(self, dim: int = 128) -> None:
        """Initialize the TriangleMultiplicationOutgoing module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128

        """
        super().__init__()

        self.norm_in = nn.LayerNorm(dim, eps=1e-5)
        self.p_in = nn.Linear(dim, 2 * dim, bias=False)
        self.g_in = nn.Linear(dim, 2 * dim, bias=False)

        self.norm_out = nn.LayerNorm(dim)
        self.p_out = nn.Linear(dim, dim, bias=False)
        self.g_out = nn.Linear(dim, dim, bias=False)
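
        # Norm affine parameters start at identity; the gating and final
        # projections go through the shared init helpers (zero-initializing
        # gating and final layers, the usual AlphaFold-style choice, keeps the
        # block's contribution small at the start of training).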
        init.bias_init_one_(self.norm_in.weight)
        init.bias_init_zero_(self.norm_in.bias)

        init.lecun_normal_init_(self.p_in.weight)
        init.gating_init_(self.g_in.weight)

        init.bias_init_one_(self.norm_out.weight)
        init.bias_init_zero_(self.norm_out.bias)

        init.final_init_(self.p_out.weight)
        init.gating_init_(self.g_out.weight)

    def forward(self, x: Tensor, mask: Tensor, use_kernels: bool = False) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (B, N, N, D)
        mask: torch.Tensor
            The input mask of shape (B, N, N)
        use_kernels: bool
            Whether to dispatch to the cuequivariance_torch triangle kernel
            instead of the eager PyTorch implementation

        Returns
        -------
        x: torch.Tensor
            The output data of shape (B, N, N, D)

        """
        if use_kernels:
            return kernel_triangular_mult(
                x,
                direction="outgoing",
                mask=mask,
                norm_in_weight=self.norm_in.weight,
                norm_in_bias=self.norm_in.bias,
                p_in_weight=self.p_in.weight,
                g_in_weight=self.g_in.weight,
                norm_out_weight=self.norm_out.weight,
                norm_out_bias=self.norm_out.bias,
                p_out_weight=self.p_out.weight,
                g_out_weight=self.g_out.weight,
                eps=1e-5,
            )
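
        # Eager PyTorch path: pre-norm the pair representation, then apply the
        # gated input projections.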
        x = self.norm_in(x)
        x_in = x
        x = self.p_in(x) * self.g_in(x).sigmoid()
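
        # Zero out padded pair positions so they cannot leak into the
        # contraction over k below.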
        x = x * mask.unsqueeze(-1)
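
        # Split the doubled channel dimension into the two projected maps;
        # the contraction is done in float32 to limit accumulation error
        # when the surrounding network runs in reduced precision.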
        a, b = torch.chunk(x.float(), 2, dim=-1)
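
        # "Outgoing" update: x_ij = sum_k a_ik * b_jk, independently per channel.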
        x = torch.einsum("bikd,bjkd->bijd", a, b)
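
        # Output norm and projection, gated by the (normalized) block input.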
        x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()

        return x


class TriangleMultiplicationIncoming(nn.Module):
    """Triangle multiplicative update using incoming edges."""

    def __init__(self, dim: int = 128) -> None:
        """Initialize the TriangleMultiplicationIncoming module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128

        """
        super().__init__()

        self.norm_in = nn.LayerNorm(dim, eps=1e-5)
        self.p_in = nn.Linear(dim, 2 * dim, bias=False)
        self.g_in = nn.Linear(dim, 2 * dim, bias=False)

        self.norm_out = nn.LayerNorm(dim)
        self.p_out = nn.Linear(dim, dim, bias=False)
        self.g_out = nn.Linear(dim, dim, bias=False)
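
        # Same initialization scheme as TriangleMultiplicationOutgoing above.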
        init.bias_init_one_(self.norm_in.weight)
        init.bias_init_zero_(self.norm_in.bias)

        init.lecun_normal_init_(self.p_in.weight)
        init.gating_init_(self.g_in.weight)

        init.bias_init_one_(self.norm_out.weight)
        init.bias_init_zero_(self.norm_out.bias)

        init.final_init_(self.p_out.weight)
        init.gating_init_(self.g_out.weight)

    def forward(self, x: Tensor, mask: Tensor, use_kernels: bool = False) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (B, N, N, D)
        mask: torch.Tensor
            The input mask of shape (B, N, N)
        use_kernels: bool
            Whether to dispatch to the cuequivariance_torch triangle kernel
            instead of the eager PyTorch implementation

        Returns
        -------
        x: torch.Tensor
            The output data of shape (B, N, N, D)

        """
        if use_kernels:
            return kernel_triangular_mult(
                x,
                direction="incoming",
                mask=mask,
                norm_in_weight=self.norm_in.weight,
                norm_in_bias=self.norm_in.bias,
                p_in_weight=self.p_in.weight,
                g_in_weight=self.g_in.weight,
                norm_out_weight=self.norm_out.weight,
                norm_out_bias=self.norm_out.bias,
                p_out_weight=self.p_out.weight,
                g_out_weight=self.g_out.weight,
                eps=1e-5,
            )
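
        # Eager PyTorch path; identical to the outgoing block except for the
        # einsum contraction pattern below.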
        x = self.norm_in(x)
        x_in = x
        x = self.p_in(x) * self.g_in(x).sigmoid()
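
        # Zero out padded pair positions before the contraction over k.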
        x = x * mask.unsqueeze(-1)
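
        # As in the outgoing block, contract in float32.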
        a, b = torch.chunk(x.float(), 2, dim=-1)
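
        # "Incoming" update: x_ij = sum_k a_ki * b_kj, independently per channel.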
        x = torch.einsum("bkid,bkjd->bijd", a, b)

        x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()

        return x
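

# A minimal usage sketch (shapes follow the forward() docstrings above;
# the B, N, D values here are arbitrary):
#
#     B, N, D = 1, 16, 128
#     x = torch.randn(B, N, N, D)
#     mask = torch.ones(B, N, N)
#     out = TriangleMultiplicationOutgoing(dim=D)(x, mask)
#     inc = TriangleMultiplicationIncoming(dim=D)(x, mask)
#     assert out.shape == inc.shape == (B, N, N, D)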