import math
from typing import Optional, Callable

import numpy as np
import torch
import torch.nn as nn
from scipy.stats import truncnorm


def _prod(nums):
    """Return the product of an iterable of numbers."""
    out = 1
    for n in nums:
        out = out * n
    return out


def _calculate_fan(linear_weight_shape, fan="fan_in"):
    """Compute the fan value of a 2D linear weight shaped (fan_out, fan_in)."""
    fan_out, fan_in = linear_weight_shape

    if fan == "fan_in":
        f = fan_in
    elif fan == "fan_out":
        f = fan_out
    elif fan == "fan_avg":
        f = (fan_in + fan_out) / 2
    else:
        raise ValueError("Invalid fan option")

    return f


def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    """Fan-scaled truncated normal initialization (truncated at +/- 2 std)."""
    shape = weights.shape
    f = _calculate_fan(shape, fan)
    scale = scale / max(1, f)
    a = -2
    b = 2
    # Compensate for the variance lost to truncation so that the samples
    # end up with std sqrt(scale / fan).
    std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
    size = _prod(shape)
    samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
    samples = np.reshape(samples, shape)
    with torch.no_grad():
        weights.copy_(torch.tensor(samples, device=weights.device))


def lecun_normal_init_(weights):
    trunc_normal_init_(weights, scale=1.0)


def he_normal_init_(weights):
    trunc_normal_init_(weights, scale=2.0)


def glorot_uniform_init_(weights):
    nn.init.xavier_uniform_(weights, gain=1)


def final_init_(weights):
    with torch.no_grad():
        weights.fill_(0.0)


def gating_init_(weights):
    with torch.no_grad():
        weights.fill_(0.0)


def normal_init_(weights):
    nn.init.kaiming_normal_(weights, nonlinearity="linear")
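

# A quick sanity-check sketch (assumed, not part of the original source): the
# fan-scaled truncated normal above should produce an empirical std close to
# sqrt(scale / fan_in).
#
#     w = torch.empty(256, 128)          # shape is (fan_out, fan_in)
#     trunc_normal_init_(w, scale=1.0)   # LeCun-style "default" scaling
#     print(w.std().item())              # ~ (1.0 / 128) ** 0.5 ~= 0.088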


class Linear(nn.Linear):
    """
    A Linear layer with in-house nonstandard initializations.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
    ):
        """
        Args:
            in_dim:
                The final dimension of inputs to the layer
            out_dim:
                The final dimension of layer outputs
            bias:
                Whether to learn an additive bias. True by default
            init:
                The initializer to use. Choose from:

                "default": LeCun fan-in truncated normal initialization
                "relu": He initialization w/ truncated normal distribution
                "glorot": Fan-average Glorot uniform initialization
                "gating": Weights=0, Bias=1
                "normal": Normal initialization with std=1/sqrt(fan_in)
                "final": Weights=0, Bias=0

                Overridden by init_fn if the latter is not None.
            init_fn:
                A custom initializer taking weight and bias as inputs.
                Overrides init if not None.
        """
        super(Linear, self).__init__(in_dim, out_dim, bias=bias)

        if bias:
            with torch.no_grad():
                self.bias.fill_(0)
        if init_fn is not None:
            init_fn(self.weight, self.bias)
        else:
            if init == "default":
                lecun_normal_init_(self.weight)
            elif init == "relu":
                he_normal_init_(self.weight)
            elif init == "glorot":
                glorot_uniform_init_(self.weight)
            elif init == "gating":
                gating_init_(self.weight)
                if bias:
                    with torch.no_grad():
                        self.bias.fill_(1.0)
            elif init == "normal":
                normal_init_(self.weight)
            elif init == "final":
                final_init_(self.weight)
            else:
                raise ValueError("Invalid init string.")
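

# A brief usage sketch (assumed, not from the original source): the init
# string only controls how weights and biases are filled at construction;
# the forward pass is a standard nn.Linear.
#
#     proj = Linear(128, 256, init="relu")    # He truncated normal weights
#     gate = Linear(128, 256, init="gating")  # weights = 0, bias = 1
#     out = Linear(256, 3, init="final")      # weights = 0, bias = 0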


class NodeTransition(nn.Module):
    """Residual MLP transition over node embeddings, followed by LayerNorm."""

    def __init__(self, dim):
        super(NodeTransition, self).__init__()
        self.dim = dim
        self.linear_1 = Linear(dim, dim, init="relu")
        self.linear_2 = Linear(dim, dim, init="relu")
        self.linear_3 = Linear(dim, dim, init="final")
        self.relu = nn.ReLU()
        self.ln = nn.LayerNorm(dim)

    def forward(self, s):
        s_initial = s
        s = self.relu(self.linear_1(s))
        s = self.relu(self.linear_2(s))
        s = self.linear_3(s)
        s = s + s_initial
        s = self.ln(s)
        return s


class EdgeTransition(nn.Module):
    """Updates edge embeddings from pairs of node embeddings via a residual MLP."""

    def __init__(self,
            node_embed_size,
            edge_embed_in,
            edge_embed_out,
            num_layers=2,
            node_dilation=2
        ):
        super(EdgeTransition, self).__init__()

        bias_embed_size = node_embed_size // node_dilation
        self.initial_embed = Linear(
            node_embed_size, bias_embed_size, init="relu")
        hidden_size = bias_embed_size * 2 + edge_embed_in
        trunk_layers = []
        for _ in range(num_layers):
            trunk_layers.append(Linear(hidden_size, hidden_size, init="relu"))
            trunk_layers.append(nn.ReLU())
        self.trunk = nn.Sequential(*trunk_layers)
        self.final_layer = Linear(hidden_size, edge_embed_out, init="final")
        self.layer_norm = nn.LayerNorm(edge_embed_out)

    def forward(self, node_embed, edge_embed):
        # Project nodes down, then broadcast row- and column-wise to build
        # pairwise features that are concatenated onto the edge embeddings.
        node_embed = self.initial_embed(node_embed)
        batch_size, num_res, _ = node_embed.shape
        edge_bias = torch.cat([
            torch.tile(node_embed[:, :, None, :], (1, 1, num_res, 1)),
            torch.tile(node_embed[:, None, :, :], (1, num_res, 1, 1)),
        ], dim=-1)
        edge_embed = torch.cat(
            [edge_embed, edge_bias], dim=-1).reshape(
                batch_size * num_res**2, -1)
        edge_embed = self.final_layer(self.trunk(edge_embed) + edge_embed)
        edge_embed = self.layer_norm(edge_embed)
        edge_embed = edge_embed.reshape(
            batch_size, num_res, num_res, -1
        )
        return edge_embed
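

# A hypothetical shape walk-through for EdgeTransition (illustrative sizes,
# not from the original source):
#
#     edge_transition = EdgeTransition(
#         node_embed_size=256, edge_embed_in=128, edge_embed_out=128)
#     node_embed = torch.randn(2, 50, 256)      # [B, N, c_node]
#     edge_embed = torch.randn(2, 50, 50, 128)  # [B, N, N, c_edge_in]
#     out = edge_transition(node_embed, edge_embed)
#     # out.shape == (2, 50, 50, 128)           # [B, N, N, c_edge_out]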


class TorsionAngleHead(nn.Module):
    """Residual MLP head predicting (cos, sin) pairs for n torsion angles, normalized along the last dimension."""

    def __init__(self, in_dim, n_torsion_angles, eps=1e-8):
        super(TorsionAngleHead, self).__init__()

        self.linear_1 = Linear(in_dim, in_dim, init="relu")
        self.linear_2 = Linear(in_dim, in_dim, init="relu")
        # NOTE: linear_3 is defined but never applied in forward().
        self.linear_3 = Linear(in_dim, in_dim, init="final")
        self.linear_final = Linear(in_dim, n_torsion_angles * 2, init="final")
        self.relu = nn.ReLU()
        self.eps = eps

    def forward(self, s):
        s_initial = s
        s = self.relu(self.linear_1(s))
        s = self.linear_2(s)
        s = s + s_initial

        unnormalized_s = self.linear_final(s)
        # Rescale the raw output to unit norm (eps-clamped for stability).
        norm_denom = torch.sqrt(
            torch.clamp(
                torch.sum(unnormalized_s ** 2, dim=-1, keepdim=True),
                min=self.eps,
            )
        )
        normalized_s = unnormalized_s / norm_denom
        return normalized_s


class BackboneUpdate(nn.Module):
    """
    Implements part of Algorithm 23.
    """

    def __init__(self, c_s):
        """
        Args:
            c_s:
                Single representation channel dimension
        """
        super(BackboneUpdate, self).__init__()

        self.c_s = c_s
        self.linear = Linear(self.c_s, 6, init="final")

    def forward(self, s: torch.Tensor):
        """
        Args:
            s:
                [*, N_res, C_s] single representation
        Returns:
            [*, N_res, 6] update vector
        """

        update = self.linear(s)
        return update
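

# A minimal smoke-test sketch (assumed, not part of the original module): push
# random tensors through the node, torsion, and backbone heads to confirm the
# expected output shapes. All dimensions here are illustrative.
if __name__ == "__main__":
    batch, num_res, c_s = 2, 10, 64

    s = torch.randn(batch, num_res, c_s)
    node_transition = NodeTransition(c_s)
    torsion_head = TorsionAngleHead(c_s, n_torsion_angles=1)
    backbone_update = BackboneUpdate(c_s)

    print(node_transition(s).shape)   # torch.Size([2, 10, 64])
    print(torsion_head(s).shape)      # torch.Size([2, 10, 2])
    print(backbone_update(s).shape)   # torch.Size([2, 10, 6])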