import math
from typing import Optional, Callable
import numpy as np
import torch
import torch.nn as nn
from scipy.stats import truncnorm
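# Weight-initialization helpers and small building-block modules (a custom
# Linear layer, node/edge transitions, a torsion-angle head, and a backbone
# update head).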
def _prod(nums):
out = 1
for n in nums:
out = out * n
return out
def _calculate_fan(linear_weight_shape, fan="fan_in"):
fan_out, fan_in = linear_weight_shape
if fan == "fan_in":
f = fan_in
elif fan == "fan_out":
f = fan_out
elif fan == "fan_avg":
f = (fan_in + fan_out) / 2
else:
raise ValueError("Invalid fan option")
return f
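# Truncated-normal initializer: draws from a normal truncated to [-2, 2]
# standard units; dividing by truncnorm.std compensates for the variance lost
# to truncation, so the resulting weights have variance ~= scale / fan.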
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
shape = weights.shape
f = _calculate_fan(shape, fan)
scale = scale / max(1, f)
a = -2
b = 2
std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
size = _prod(shape)
samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
samples = np.reshape(samples, shape)
with torch.no_grad():
weights.copy_(torch.tensor(samples, device=weights.device))
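# Named initializers used by the Linear wrapper below: LeCun (scale=1.0) and
# He (scale=2.0) truncated-normal fan-in inits, Glorot uniform, zero inits for
# gating/final layers, and a plain normal with std = 1 / sqrt(fan_in).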
def lecun_normal_init_(weights):
trunc_normal_init_(weights, scale=1.0)
def he_normal_init_(weights):
trunc_normal_init_(weights, scale=2.0)
def glorot_uniform_init_(weights):
nn.init.xavier_uniform_(weights, gain=1)
def final_init_(weights):
with torch.no_grad():
weights.fill_(0.0)
def gating_init_(weights):
with torch.no_grad():
weights.fill_(0.0)
def normal_init_(weights):
nn.init.kaiming_normal_(weights, nonlinearity="linear")
class Linear(nn.Linear):
"""
    A Linear layer with built-in, non-standard initializations. Behaves like torch.nn.Linear otherwise.
"""
def __init__(
self,
in_dim: int,
out_dim: int,
bias: bool = True,
init: str = "default",
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
):
"""
Args:
in_dim:
The final dimension of inputs to the layer
out_dim:
The final dimension of layer outputs
bias:
Whether to learn an additive bias. True by default
init:
The initializer to use. Choose from:
"default": LeCun fan-in truncated normal initialization
"relu": He initialization w/ truncated normal distribution
"glorot": Fan-average Glorot uniform initialization
"gating": Weights=0, Bias=1
"normal": Normal initialization with std=1/sqrt(fan_in)
"final": Weights=0, Bias=0
Overridden by init_fn if the latter is not None.
init_fn:
A custom initializer taking weight and bias as inputs.
Overrides init if not None.
"""
super(Linear, self).__init__(in_dim, out_dim, bias=bias)
if bias:
with torch.no_grad():
self.bias.fill_(0)
if init_fn is not None:
init_fn(self.weight, self.bias)
else:
if init == "default":
lecun_normal_init_(self.weight)
elif init == "relu":
he_normal_init_(self.weight)
elif init == "glorot":
glorot_uniform_init_(self.weight)
elif init == "gating":
gating_init_(self.weight)
if bias:
with torch.no_grad():
self.bias.fill_(1.0)
elif init == "normal":
normal_init_(self.weight)
elif init == "final":
final_init_(self.weight)
else:
raise ValueError("Invalid init string.")
# Simple 3-layer MLP with a residual connection and layer norm.
class NodeTransition(nn.Module):
def __init__(self, dim):
super(NodeTransition, self).__init__()
self.dim = dim
self.linear_1 = Linear(dim, dim, init="relu")
self.linear_2 = Linear(dim, dim, init="relu")
self.linear_3 = Linear(dim, dim, init="final")
self.relu = nn.ReLU()
self.ln = nn.LayerNorm(dim)
def forward(self, s):
s_initial = s
s = self.relu(self.linear_1(s))
s = self.relu(self.linear_2(s))
s = self.linear_3(s)
s = s + s_initial
s = self.ln(s)
return s
class EdgeTransition(nn.Module):
def __init__(self,
node_embed_size,
edge_embed_in,
edge_embed_out,
num_layers=2,
node_dilation=2
):
super(EdgeTransition, self).__init__()
bias_embed_size = node_embed_size // node_dilation
self.initial_embed = Linear(
node_embed_size, bias_embed_size, init="relu")
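        # Hidden width: down-projected node features from both residues of a
        # pair plus the incoming edge features.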
hidden_size = bias_embed_size * 2 + edge_embed_in
trunk_layers = []
for _ in range(num_layers):
trunk_layers.append(Linear(hidden_size, hidden_size, init="relu"))
trunk_layers.append(nn.ReLU())
self.trunk = nn.Sequential(*trunk_layers)
self.final_layer = Linear(hidden_size, edge_embed_out, init="final")
self.layer_norm = nn.LayerNorm(edge_embed_out)
def forward(self, node_embed, edge_embed):
node_embed = self.initial_embed(node_embed)
batch_size, num_res, _ = node_embed.shape
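        # Broadcast node features to every residue pair (i, j): one copy tiled
        # along rows, one along columns, concatenated on the channel axis.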
edge_bias = torch.cat([
torch.tile(node_embed[:, :, None, :], (1, 1, num_res, 1)),
torch.tile(node_embed[:, None, :, :], (1, num_res, 1, 1)),
], axis=-1)
edge_embed = torch.cat(
[edge_embed, edge_bias], axis=-1).reshape(
batch_size * num_res**2, -1)
edge_embed = self.final_layer(self.trunk(edge_embed) + edge_embed)
edge_embed = self.layer_norm(edge_embed)
edge_embed = edge_embed.reshape(
batch_size, num_res, num_res, -1
)
return edge_embed
class TorsionAngleHead(nn.Module):
def __init__(self, in_dim, n_torsion_angles, eps=1e-8):
super(TorsionAngleHead, self).__init__()
self.linear_1 = Linear(in_dim, in_dim, init="relu")
self.linear_2 = Linear(in_dim, in_dim, init="relu")
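        # linear_3 is defined here but not used in forward() below.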
self.linear_3 = Linear(in_dim, in_dim, init="final")
self.linear_final = Linear(in_dim, n_torsion_angles * 2, init="final")
self.relu = nn.ReLU()
self.eps = eps
def forward(self, s):
s_initial = s
s = self.relu(self.linear_1(s))
s = self.linear_2(s)
s = s + s_initial
unnormalized_s = self.linear_final(s)
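        # Normalize all sin/cos outputs jointly to unit L2 norm; eps guards
        # against division by zero.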
norm_denom = torch.sqrt(
torch.clamp(
torch.sum(unnormalized_s ** 2, dim=-1, keepdim=True),
min=self.eps,
)
)
normalized_s = unnormalized_s / norm_denom
return normalized_s
class BackboneUpdate(nn.Module):
"""
Implements part of Algorithm 23.
"""
def __init__(self, c_s):
"""
Args:
c_s:
Single representation channel dimension
"""
super(BackboneUpdate, self).__init__()
self.c_s = c_s
self.linear = Linear(self.c_s, 6, init="final")
def forward(self, s: torch.Tensor):
"""
Args:
            s:
                [*, N_res, C_s] single representation
Returns:
[*, N_res, 6] update vector
"""
        # [*, N_res, 6]
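        # In AF2 Algorithm 23 these six channels parameterize a rigid update:
        # the non-trivial quaternion components (b, c, d) plus a 3-D translation.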
update = self.linear(s)
return update
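# Minimal usage sketch (illustrative only): batch size, residue count, and
# channel widths below are arbitrary assumptions chosen to show the tensor
# shapes flowing through the modules defined above.
if __name__ == "__main__":
    batch, num_res, node_dim, edge_dim = 2, 8, 32, 16
    node_embed = torch.randn(batch, num_res, node_dim)
    edge_embed = torch.randn(batch, num_res, num_res, edge_dim)

    node_embed = NodeTransition(node_dim)(node_embed)                # [2, 8, 32]
    edge_embed = EdgeTransition(node_dim, edge_dim, edge_dim)(
        node_embed, edge_embed)                                      # [2, 8, 8, 16]
    torsions = TorsionAngleHead(node_dim, n_torsion_angles=7)(node_embed)  # [2, 8, 14]
    update = BackboneUpdate(node_dim)(node_embed)                    # [2, 8, 6]
    print(node_embed.shape, edge_embed.shape, torsions.shape, update.shape)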