from functools import partial
from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn import (
    Linear,
    Module,
)
from torch.types import Device

# Convenience constructor for bias-free linear layers.
LinearNoBias = partial(Linear, bias=False)

def exists(v):
    return v is not None


def default(v, d):
    return v if exists(v) else d


def log(t, eps=1e-20):
    # clamp away from zero so the log never produces -inf
    return torch.log(t.clamp(min=eps))

class SwiGLU(Module):
    """SwiGLU activation: split the last dimension in half and gate one half
    with the SiLU of the other."""

    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return F.silu(gates) * x

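# Illustrative usage of SwiGLU (not part of the original module): the input's
# last dimension is split in half, so the preceding projection must produce
# twice the desired output width. The sizes below are arbitrary examples.
#
#     ff_in = LinearNoBias(256, 2 * 512)
#     swiglu = SwiGLU()
#     h = swiglu(ff_in(torch.randn(8, 256)))  # h.shape == (8, 512)
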
def center(atom_coords, atom_mask):
    """Translate the coordinates so their masked mean sits at the origin."""
    atom_mean = torch.sum(
        atom_coords * atom_mask[:, :, None], dim=1, keepdim=True
    ) / torch.sum(atom_mask[:, :, None], dim=1, keepdim=True)
    atom_coords = atom_coords - atom_mean
    return atom_coords

def compute_random_augmentation(
    multiplicity, s_trans=1.0, device=None, dtype=torch.float32
):
    """Sample one random rotation and translation per augmentation."""
    R = random_rotations(multiplicity, dtype=dtype, device=device)
    random_trans = (
        torch.randn((multiplicity, 1, 3), dtype=dtype, device=device) * s_trans
    )
    return R, random_trans

def randomly_rotate(coords, return_second_coords=False, second_coords=None):
    """Rotate each batch of coordinates by an independent random rotation,
    optionally applying the same rotations to a second coordinate set."""
    R = random_rotations(len(coords), dtype=coords.dtype, device=coords.device)

    rotated = torch.einsum("bmd,bds->bms", coords, R)

    if return_second_coords:
        rotated_second = (
            torch.einsum("bmd,bds->bms", second_coords, R)
            if second_coords is not None
            else None
        )
        return rotated, rotated_second

    return rotated

def center_random_augmentation(
    atom_coords,
    atom_mask,
    s_trans=1.0,
    augmentation=True,
    centering=True,
    return_second_coords=False,
    second_coords=None,
):
    """Algorithm 19: center the coordinates, then apply a random rotation
    and translation."""
    if centering:
        atom_mean = torch.sum(
            atom_coords * atom_mask[:, :, None], dim=1, keepdim=True
        ) / torch.sum(atom_mask[:, :, None], dim=1, keepdim=True)
        atom_coords = atom_coords - atom_mean

        if second_coords is not None:
            # apply the same centering to the second coordinate set
            second_coords = second_coords - atom_mean

    if augmentation:
        atom_coords, second_coords = randomly_rotate(
            atom_coords, return_second_coords=True, second_coords=second_coords
        )
        random_trans = torch.randn_like(atom_coords[:, 0:1, :]) * s_trans
        atom_coords = atom_coords + random_trans

        if second_coords is not None:
            second_coords = second_coords + random_trans

    if return_second_coords:
        return atom_coords, second_coords

    return atom_coords

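# Minimal usage sketch (illustrative; shapes are arbitrary examples).
# Coordinates are (batch, atoms, 3); the mask is (batch, atoms) with 1 for
# real atoms and 0 for padding.
#
#     coords = torch.randn(2, 100, 3)
#     mask = torch.ones(2, 100)
#     augmented = center_random_augmentation(coords, mask)  # same shape as coords
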
class ExponentialMovingAverage:
    """Maintains an exponential moving average of a set of parameters.

    Adapted from
    https://github.com/yang-song/score_sde_pytorch/blob/main/models/ema.py
    (Apache-2.0 license).
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the result of
                `model.parameters()`.
            decay: The exponential decay.
            use_num_updates: Whether to use number of updates when computing
                averages.
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        self.shadow_params = [
            p.clone().detach() for p in parameters if p.requires_grad
        ]
        self.collected_params = []

    def update(self, parameters):
        """
        Update currently maintained parameters.
        Call this every time the parameters are updated, such as after the
        `optimizer.step()` call.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
                parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # warm up the decay so early averages track the model more closely
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            parameters = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, parameters):
                s_param.sub_(one_minus_decay * (s_param - param))

    def compatible(self, parameters):
        """Check that `parameters` matches the shadow parameters in count and shape."""
        if len(self.shadow_params) != len(parameters):
            print(
                f"Model has {len(self.shadow_params)} parameter tensors, "
                f"but the incoming EMA has {len(parameters)}"
            )
            return False

        for s_param, param in zip(self.shadow_params, parameters):
            if param.data.shape != s_param.data.shape:
                print(
                    f"Model has a parameter tensor of shape {s_param.data.shape}, "
                    f"but the incoming EMA has {param.data.shape}"
                )
                return False
        return True

    def copy_to(self, parameters):
        """
        Copy current parameters into given collection of parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages.
        """
        parameters = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful for validating the model with EMA parameters without affecting
        the original optimization process. Store the parameters before calling
        `copy_to`. After validation (or model saving), use this method to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def state_dict(self):
        return dict(
            decay=self.decay,
            num_updates=self.num_updates,
            shadow_params=self.shadow_params,
        )

    def load_state_dict(self, state_dict, device):
        self.decay = state_dict["decay"]
        self.num_updates = state_dict["num_updates"]
        self.shadow_params = [
            tensor.to(device) for tensor in state_dict["shadow_params"]
        ]

    def to(self, device):
        self.shadow_params = [tensor.to(device) for tensor in self.shadow_params]

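# Illustrative EMA training-loop usage (sketch only; `model`, `optimizer`,
# `loader`, and `validate` are assumed to exist elsewhere):
#
#     ema = ExponentialMovingAverage(model.parameters(), decay=0.999)
#     for batch in loader:
#         ...  # forward, backward
#         optimizer.step()
#         ema.update(model.parameters())
#
#     # evaluate with the averaged weights, then restore the raw ones
#     ema.store(model.parameters())
#     ema.copy_to(model.parameters())
#     validate(model)
#     ema.restore(model.parameters())
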
# The quaternion/rotation helpers below mirror
# pytorch3d.transforms.rotation_conversions.


def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Return a tensor where each element has the absolute value taken from the
    corresponding element of a, with sign taken from the corresponding
    element of b. This is like the standard copysign floating-point operation,
    but is not careful about negative 0 and NaN.

    Args:
        a: source tensor.
        b: tensor whose signs will be used, of the same shape as a.

    Returns:
        Tensor of the same shape as a with the signs of b.
    """
    signs_differ = (a < 0) != (b < 0)
    return torch.where(signs_differ, -a, a)

def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    # equals 2 for unit quaternions; normalizes non-unit inputs
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))

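# Quick sanity check (illustrative): the identity quaternion (1, 0, 0, 0)
# should map to the 3x3 identity matrix.
#
#     q = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
#     torch.allclose(quaternion_to_matrix(q), torch.eye(3)[None])  # True
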
def random_quaternions(
    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """
    Generate random quaternions representing rotations,
    i.e. versors with nonnegative real part.

    Args:
        n: Number of quaternions in a batch to return.
        dtype: Type to return.
        device: Desired device of returned tensor. Default:
            uses the current device for the default tensor type.

    Returns:
        Quaternions as tensor of shape (N, 4).
    """
    if isinstance(device, str):
        device = torch.device(device)
    o = torch.randn((n, 4), dtype=dtype, device=device)
    s = (o * o).sum(1)
    # normalize to unit length; the sign flip keeps the real part nonnegative
    o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
    return o

def random_rotations(
    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """
    Generate random rotations as 3x3 rotation matrices.

    Args:
        n: Number of rotation matrices in a batch to return.
        dtype: Type to return.
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type.

    Returns:
        Rotation matrices as tensor of shape (n, 3, 3).
    """
    quaternions = random_quaternions(n, dtype=dtype, device=device)
    return quaternion_to_matrix(quaternions)
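# Illustrative check that the sampled matrices are valid rotations:
# orthogonal (R @ R^T == I) with determinant +1.
#
#     R = random_rotations(4)
#     I = torch.eye(3).expand(4, 3, 3)
#     torch.allclose(R @ R.transpose(-1, -2), I, atol=1e-5)  # True
#     torch.allclose(torch.linalg.det(R), torch.ones(4), atol=1e-5)  # True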