import functools
import importlib

import torch
from torch import Tensor, nn

from . import vb_layers_initialize as init
| |
|
| |
|
@torch.compiler.disable
def kernel_triangular_mult(
    x,
    direction,
    mask,
    norm_in_weight,
    norm_in_bias,
    p_in_weight,
    g_in_weight,
    norm_out_weight,
    norm_out_bias,
    p_out_weight,
    g_out_weight,
    eps,
):
    """Dispatch to the fused cuEquivariance triangle-multiplicative-update kernel.

    Thin wrapper around
    ``cuequivariance_torch.primitives.triangle.triangle_multiplicative_update``
    that forwards the raw parameter tensors of a triangle-multiplication
    module. Kept out of ``torch.compile`` graphs via
    ``@torch.compiler.disable`` so the external kernel is called eagerly.

    Parameters
    ----------
    x:
        Input tensor for the triangular update.
    direction:
        Either ``"outgoing"`` or ``"incoming"`` (selects the einsum pattern
        inside the kernel).
    mask:
        Pair mask applied by the kernel.
    norm_in_weight, norm_in_bias:
        Input LayerNorm parameters.
    p_in_weight, g_in_weight:
        Input projection and gate weights.
    norm_out_weight, norm_out_bias:
        Output LayerNorm parameters.
    p_out_weight, g_out_weight:
        Output projection and gate weights.
    eps:
        LayerNorm epsilon forwarded to the kernel.

    Returns
    -------
    The kernel's output tensor.

    """
    # Resolve once and reuse; this wrapper sits on a hot path and should not
    # pay the module/attribute lookup on every call.
    return _resolve_triangle_kernel()(
        x,
        direction=direction,
        mask=mask,
        norm_in_weight=norm_in_weight,
        norm_in_bias=norm_in_bias,
        p_in_weight=p_in_weight,
        g_in_weight=g_in_weight,
        norm_out_weight=norm_out_weight,
        norm_out_bias=norm_out_bias,
        p_out_weight=p_out_weight,
        g_out_weight=g_out_weight,
        eps=eps,
    )


@functools.lru_cache(maxsize=1)
def _resolve_triangle_kernel():
    """Import and cache the cuEquivariance kernel entry point.

    The import stays lazy (performed on first use, not at module import) so
    ``cuequivariance_torch`` is only required when kernels are enabled.
    """
    triangle_module = importlib.import_module(
        "cuequivariance_torch.primitives.triangle"
    )
    return triangle_module.triangle_multiplicative_update
| |
|
| |
|
class TriangleMultiplicationOutgoing(nn.Module):
    """Triangular multiplicative update using outgoing edges."""

    def __init__(self, dim: int = 128) -> None:
        """Build the outgoing triangular-update module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128

        """
        super().__init__()

        # Input side: LayerNorm followed by a doubled projection/gate pair
        # (the two halves become the left/right operands of the einsum).
        self.norm_in = nn.LayerNorm(dim, eps=1e-5)
        self.p_in = nn.Linear(dim, 2 * dim, bias=False)
        self.g_in = nn.Linear(dim, 2 * dim, bias=False)

        # Output side: LayerNorm plus projection/gate back to `dim`.
        self.norm_out = nn.LayerNorm(dim)
        self.p_out = nn.Linear(dim, dim, bias=False)
        self.g_out = nn.Linear(dim, dim, bias=False)

        # Initialization order is kept fixed so seeded runs stay reproducible.
        init.bias_init_one_(self.norm_in.weight)
        init.bias_init_zero_(self.norm_in.bias)

        init.lecun_normal_init_(self.p_in.weight)
        init.gating_init_(self.g_in.weight)

        init.bias_init_one_(self.norm_out.weight)
        init.bias_init_zero_(self.norm_out.bias)

        init.final_init_(self.p_out.weight)
        init.gating_init_(self.g_out.weight)

    def forward(self, x: Tensor, mask: Tensor, use_kernels: bool = False) -> Tensor:
        """Apply the outgoing triangular multiplicative update.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (B, N, N, D)
        mask: torch.Tensor
            The input mask of shape (B, N, N)
        use_kernels: bool
            Whether to use the kernel

        Returns
        -------
        x: torch.Tensor
            The output data of shape (B, N, N, D)

        """
        if use_kernels:
            # Delegate the whole update to the fused external kernel.
            return kernel_triangular_mult(
                x,
                direction="outgoing",
                mask=mask,
                norm_in_weight=self.norm_in.weight,
                norm_in_bias=self.norm_in.bias,
                p_in_weight=self.p_in.weight,
                g_in_weight=self.g_in.weight,
                norm_out_weight=self.norm_out.weight,
                norm_out_bias=self.norm_out.bias,
                p_out_weight=self.p_out.weight,
                g_out_weight=self.g_out.weight,
                eps=1e-5,
            )

        # Normalize once; the normalized input also feeds the output gate.
        normed = self.norm_in(x)
        gated = self.p_in(normed) * self.g_in(normed).sigmoid()

        # Zero out masked pairs before the triangular contraction.
        gated = gated * mask.unsqueeze(-1)

        # Split into left/right operands; contraction runs in float32.
        left, right = torch.chunk(gated.float(), 2, dim=-1)

        # Outgoing edges: combine rows i and j over the shared index k.
        prod = torch.einsum("bikd,bjkd->bijd", left, right)

        # Project back to `dim`, gated by the normalized input.
        return self.p_out(self.norm_out(prod)) * self.g_out(normed).sigmoid()
| |
|
| |
|
class TriangleMultiplicationIncoming(nn.Module):
    """Triangular multiplicative update using incoming edges."""

    def __init__(self, dim: int = 128) -> None:
        """Build the incoming triangular-update module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128

        """
        super().__init__()

        # Input side: LayerNorm followed by a doubled projection/gate pair
        # (the two halves become the left/right operands of the einsum).
        self.norm_in = nn.LayerNorm(dim, eps=1e-5)
        self.p_in = nn.Linear(dim, 2 * dim, bias=False)
        self.g_in = nn.Linear(dim, 2 * dim, bias=False)

        # Output side: LayerNorm plus projection/gate back to `dim`.
        self.norm_out = nn.LayerNorm(dim)
        self.p_out = nn.Linear(dim, dim, bias=False)
        self.g_out = nn.Linear(dim, dim, bias=False)

        # Initialization order is kept fixed so seeded runs stay reproducible.
        init.bias_init_one_(self.norm_in.weight)
        init.bias_init_zero_(self.norm_in.bias)

        init.lecun_normal_init_(self.p_in.weight)
        init.gating_init_(self.g_in.weight)

        init.bias_init_one_(self.norm_out.weight)
        init.bias_init_zero_(self.norm_out.bias)

        init.final_init_(self.p_out.weight)
        init.gating_init_(self.g_out.weight)

    def forward(self, x: Tensor, mask: Tensor, use_kernels: bool = False) -> Tensor:
        """Apply the incoming triangular multiplicative update.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (B, N, N, D)
        mask: torch.Tensor
            The input mask of shape (B, N, N)
        use_kernels: bool
            Whether to use the kernel

        Returns
        -------
        x: torch.Tensor
            The output data of shape (B, N, N, D)

        """
        if use_kernels:
            # Delegate the whole update to the fused external kernel.
            return kernel_triangular_mult(
                x,
                direction="incoming",
                mask=mask,
                norm_in_weight=self.norm_in.weight,
                norm_in_bias=self.norm_in.bias,
                p_in_weight=self.p_in.weight,
                g_in_weight=self.g_in.weight,
                norm_out_weight=self.norm_out.weight,
                norm_out_bias=self.norm_out.bias,
                p_out_weight=self.p_out.weight,
                g_out_weight=self.g_out.weight,
                eps=1e-5,
            )

        # Normalize once; the normalized input also feeds the output gate.
        normed = self.norm_in(x)
        gated = self.p_in(normed) * self.g_in(normed).sigmoid()

        # Zero out masked pairs before the triangular contraction.
        gated = gated * mask.unsqueeze(-1)

        # Split into left/right operands; contraction runs in float32.
        left, right = torch.chunk(gated.float(), 2, dim=-1)

        # Incoming edges: contract over the leading shared index k.
        prod = torch.einsum("bkid,bkjd->bijd", left, right)

        # Project back to `dim`, gated by the normalized input.
        return self.p_out(self.norm_out(prod)) * self.g_out(normed).sigmoid()
| |
|