from functools import partial
from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn import (
    Linear,
    Module,
)
from torch.types import Device

LinearNoBias = partial(Linear, bias=False)


def exists(v):
    return v is not None


def default(v, d):
    return v if exists(v) else d


def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))


class SwiGLU(Module):
    def forward(
        self,
        x,
    ):
        x, gates = x.chunk(2, dim=-1)
        return F.silu(gates) * x


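# Illustrative usage sketch (not part of the original module): SwiGLU expects its
# input to already carry both the value and gate halves, so it is typically paired
# with a projection to twice the hidden width; `dim` and `dim_inner` below are
# hypothetical names.
#
#   proj = LinearNoBias(dim, dim_inner * 2)
#   act = SwiGLU()
#   out = act(proj(x))  # out has last dimension dim_inner

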
def center(atom_coords, atom_mask):
    atom_mean = torch.sum(
        atom_coords * atom_mask[:, :, None], dim=1, keepdim=True
    ) / torch.sum(atom_mask[:, :, None], dim=1, keepdim=True)
    atom_coords = atom_coords - atom_mean
    return atom_coords


def compute_random_augmentation(
    multiplicity, s_trans=1.0, device=None, dtype=torch.float32
):
    R = random_rotations(multiplicity, dtype=dtype, device=device)
    random_trans = (
        torch.randn((multiplicity, 1, 3), dtype=dtype, device=device) * s_trans
    )
    return R, random_trans


def randomly_rotate(coords, return_second_coords=False, second_coords=None):
    R = random_rotations(len(coords), coords.dtype, coords.device)

    if return_second_coords:
        return (
            torch.einsum("bmd,bds->bms", coords, R),
            torch.einsum("bmd,bds->bms", second_coords, R)
            if second_coords is not None
            else None,
        )

    return torch.einsum("bmd,bds->bms", coords, R)


def center_random_augmentation(
    atom_coords,
    atom_mask,
    s_trans=1.0,
    augmentation=True,
    centering=True,
    return_second_coords=False,
    second_coords=None,
):
    """Algorithm 19"""
    if centering:
        atom_mean = torch.sum(
            atom_coords * atom_mask[:, :, None], dim=1, keepdim=True
        ) / torch.sum(atom_mask[:, :, None], dim=1, keepdim=True)
        atom_coords = atom_coords - atom_mean

        if second_coords is not None:
            second_coords = second_coords - atom_mean

    if augmentation:
        atom_coords, second_coords = randomly_rotate(
            atom_coords, return_second_coords=True, second_coords=second_coords
        )
        random_trans = torch.randn_like(atom_coords[:, 0:1, :]) * s_trans
        atom_coords = atom_coords + random_trans

        if second_coords is not None:
            second_coords = second_coords + random_trans

    if return_second_coords:
        return atom_coords, second_coords

    return atom_coords


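# Illustrative usage sketch (not part of the original module): typical shapes are
# atom_coords of (batch, n_atoms, 3) and atom_mask of (batch, n_atoms); the call
# re-centers the masked atoms and applies a random rigid rotation and translation.
#
#   coords = torch.randn(2, 16, 3)
#   mask = torch.ones(2, 16)
#   aug_coords = center_random_augmentation(coords, mask, s_trans=1.0)

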
class ExponentialMovingAverage:
    """from https://github.com/yang-song/score_sde_pytorch/blob/main/models/ema.py, Apache-2.0 license

    Maintains (exponential) moving average of a set of parameters."""

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the result of
                `model.parameters()`.
            decay: The exponential decay.
            use_num_updates: Whether to use number of updates when computing
                averages.
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
        self.collected_params = []

    def update(self, parameters):
        """
        Update currently maintained parameters.

        Call this every time the parameters are updated, such as the result of
        the `optimizer.step()` call.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
                parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            parameters = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, parameters):
                s_param.sub_(one_minus_decay * (s_param - param))

    def compatible(self, parameters):
        if len(self.shadow_params) != len(parameters):
            print(
                f"Model has {len(self.shadow_params)} parameter tensors, "
                f"but the incoming EMA has {len(parameters)}"
            )
            return False

        for s_param, param in zip(self.shadow_params, parameters):
            if param.data.shape != s_param.data.shape:
                print(
                    f"Model has a parameter tensor of shape {s_param.data.shape}, "
                    f"but the incoming EMA has shape {param.data.shape}"
                )
                return False
        return True

    def copy_to(self, parameters):
        """
        Copy current parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages.
        """
        parameters = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.

        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def state_dict(self):
        return dict(
            decay=self.decay,
            num_updates=self.num_updates,
            shadow_params=self.shadow_params,
        )

    def load_state_dict(self, state_dict, device):
        self.decay = state_dict["decay"]
        self.num_updates = state_dict["num_updates"]
        self.shadow_params = [
            tensor.to(device) for tensor in state_dict["shadow_params"]
        ]

    def to(self, device):
        self.shadow_params = [tensor.to(device) for tensor in self.shadow_params]


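# Illustrative usage sketch (not part of the original module): how the EMA helper is
# typically driven from a training loop; `model`, `loader`, `compute_loss`,
# `optimizer`, and `validate` are assumed to exist in the caller's code.
#
#   ema = ExponentialMovingAverage(model.parameters(), decay=0.999)
#   for batch in loader:
#       loss = compute_loss(model, batch)
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()
#       ema.update(model.parameters())   # track averaged weights after each step
#
#   ema.store(model.parameters())        # stash the raw weights
#   ema.copy_to(model.parameters())      # evaluate with the averaged weights
#   validate(model)
#   ema.restore(model.parameters())      # put the raw weights back

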
def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Return a tensor where each element has the absolute value taken from the
    corresponding element of a, with sign taken from the corresponding
    element of b. This is like the standard copysign floating-point operation,
    but is not careful about negative 0 and NaN.

    Args:
        a: source tensor.
        b: tensor whose signs will be used, of the same shape as a.

    Returns:
        Tensor of the same shape as a with the signs of b.
    """
    signs_differ = (a < 0) != (b < 0)
    return torch.where(signs_differ, -a, a)


def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)

    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def random_quaternions(
    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """
    Generate random quaternions representing rotations,
    i.e. versors with nonnegative real part.

    Args:
        n: Number of quaternions in a batch to return.
        dtype: Type to return.
        device: Desired device of returned tensor. Default:
            uses the current device for the default tensor type.

    Returns:
        Quaternions as tensor of shape (N, 4).
    """
    if isinstance(device, str):
        device = torch.device(device)
    o = torch.randn((n, 4), dtype=dtype, device=device)
    s = (o * o).sum(1)
    o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
    return o


def random_rotations(
    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """
    Generate random rotations as 3x3 rotation matrices.

    Args:
        n: Number of rotation matrices in a batch to return.
        dtype: Type to return.
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type.

    Returns:
        Rotation matrices as tensor of shape (n, 3, 3).
    """
    quaternions = random_quaternions(n, dtype=dtype, device=device)
    return quaternion_to_matrix(quaternions)
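

# Illustrative sanity check (not part of the original module): each matrix returned
# by random_rotations should be a proper rotation, i.e. orthonormal with determinant 1.
#
#   R = random_rotations(4)
#   torch.allclose(R @ R.transpose(-1, -2), torch.eye(3).expand(4, 3, 3), atol=1e-5)
#   torch.allclose(torch.det(R), torch.ones(4), atol=1e-5)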