# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang

from functools import partial
from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn import (
    Linear,
    Module,
)
from torch.types import Device

# `Linear` layer constructor with the bias term disabled.
LinearNoBias = partial(Linear, bias=False)


def exists(v):
    """Return True when `v` is not None."""
    return v is not None


def default(v, d):
    """Return `v` unless it is None, in which case return the fallback `d`."""
    return v if exists(v) else d


def log(t, eps=1e-20):
    """Element-wise natural log of `t` after clamping it from below at `eps`.

    The clamp keeps zero (or negative) entries from producing -inf / NaN.
    """
    return torch.log(t.clamp(min=eps))


class SwiGLU(Module):
    """Gated activation: split the last dim in two halves, return silu(gates) * x."""

    def forward(
        self,
        x,  #: Float['... d']
    ):  # -> Float[' ... (d//2)']:
        # First half is the value branch, second half drives the SiLU gate.
        x, gates = x.chunk(2, dim=-1)
        return F.silu(gates) * x


def _masked_centroid(atom_coords, atom_mask):
    """Mask-weighted mean of `atom_coords` over dim 1, kept as a size-1 dim.

    `atom_mask` is indexed as `atom_mask[:, :, None]`, i.e. it is a 2-D tensor
    broadcast over the trailing coordinate dimension of `atom_coords`.
    NOTE: a row whose mask sums to zero yields a 0/0 division (NaN), exactly
    as in the original inline computation.
    """
    weights = atom_mask[:, :, None]
    return torch.sum(atom_coords * weights, dim=1, keepdim=True) / torch.sum(
        weights, dim=1, keepdim=True
    )


def center(atom_coords, atom_mask):
    """Subtract the mask-weighted mean (over dim 1) from `atom_coords`.

    Args:
        atom_coords: coordinates tensor; atoms are indexed along dim 1.
        atom_mask: 2-D weights used for the mean (broadcast over the last dim).

    Returns:
        The centered coordinates, same shape as `atom_coords`.
    """
    return atom_coords - _masked_centroid(atom_coords, atom_mask)


def compute_random_augmentation(
    multiplicity, s_trans=1.0, device=None, dtype=torch.float32
):
    """Sample `multiplicity` random rotations and Gaussian translations.

    Args:
        multiplicity: number of (rotation, translation) pairs to sample.
        s_trans: scale multiplied onto the standard-normal translations.
        device: device of the returned tensors.
        dtype: dtype of the returned tensors.

    Returns:
        Tuple `(R, random_trans)` with shapes `(multiplicity, 3, 3)` and
        `(multiplicity, 1, 3)`.
    """
    R = random_rotations(multiplicity, dtype=dtype, device=device)
    random_trans = (
        torch.randn((multiplicity, 1, 3), dtype=dtype, device=device) * s_trans
    )
    return R, random_trans


def randomly_rotate(coords, return_second_coords=False, second_coords=None):
    """Apply one random rotation per batch element to `coords`.

    Args:
        coords: tensor contracted as "bmd,bds->bms" with the rotation
            matrices, i.e. (batch, points, 3).
        return_second_coords: when True, return a 2-tuple instead of a
            single tensor.
        second_coords: optional second tensor rotated with the *same*
            matrices; only used when `return_second_coords` is True.

    Returns:
        The rotated `coords`, or `(rotated_coords, rotated_second_coords)`
        when `return_second_coords` is True (the second entry is None if
        `second_coords` was not given).
    """
    R = random_rotations(len(coords), dtype=coords.dtype, device=coords.device)
    rotated = torch.einsum("bmd,bds->bms", coords, R)
    if not return_second_coords:
        return rotated

    rotated_second = None
    if second_coords is not None:
        rotated_second = torch.einsum("bmd,bds->bms", second_coords, R)
    return rotated, rotated_second


def center_random_augmentation(
    atom_coords,
    atom_mask,
    s_trans=1.0,
    augmentation=True,
    centering=True,
    return_second_coords=False,
    second_coords=None,
):
    """Algorithm 19 (presumably of the AlphaFold 3 supplement): center, then
    randomly rotate and translate the coordinates.

    Args:
        atom_coords: coordinates, (batch, atoms, 3).
        atom_mask: 2-D weights used to compute the centroid.
        s_trans: scale of the standard-normal random translation.
        augmentation: when True, apply a random rotation and translation.
        centering: when True, subtract the mask-weighted centroid first.
        return_second_coords: when True, return `(atom_coords, second_coords)`.
        second_coords: optional second coordinate tensor that receives exactly
            the same centering offset, rotation and translation.

    Returns:
        The transformed `atom_coords`, or a tuple with the transformed
        `second_coords` (possibly None) when `return_second_coords` is True.
    """
    if centering:
        atom_mean = _masked_centroid(atom_coords, atom_mask)
        atom_coords = atom_coords - atom_mean

        if second_coords is not None:
            # apply same transformation also to this input
            second_coords = second_coords - atom_mean

    if augmentation:
        atom_coords, second_coords = randomly_rotate(
            atom_coords, return_second_coords=True, second_coords=second_coords
        )
        # One translation per batch element, broadcast over all atoms.
        random_trans = torch.randn_like(atom_coords[:, 0:1, :]) * s_trans
        atom_coords = atom_coords + random_trans

        if second_coords is not None:
            second_coords = second_coords + random_trans

    if return_second_coords:
        return atom_coords, second_coords

    return atom_coords


class ExponentialMovingAverage:
    """from https://github.com/yang-song/score_sde_pytorch/blob/main/models/ema.py, Apache-2.0 license
    Maintains (exponential) moving average of a set of parameters."""

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the result of
                `model.parameters()`.
            decay: The exponential decay.
            use_num_updates: Whether to use number of updates when computing
                averages.

        Raises:
            ValueError: if `decay` is outside [0, 1].
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        # Only trainable parameters are tracked.
        self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
        self.collected_params = []

    def update(self, parameters):
        """
        Update currently maintained parameters.
        Call this every time the parameters are updated, such as the result of
        the `optimizer.step()` call.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
                parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Warm-up: use a smaller effective decay during the first updates.
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            trainable = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, trainable):
                s_param.sub_(one_minus_decay * (s_param - param))

    def compatible(self, parameters):
        """
        Check that `parameters` matches the shadow parameters in count and shapes.

        Args:
            parameters: Iterable of tensors / parameters to compare against the
                stored shadow parameters (generators are accepted).

        Returns:
            True when the number of tensors and every shape agree, else False
            (a diagnostic message is printed for the first mismatch).
        """
        # Materialize once so that generators such as `model.parameters()`
        # support `len()` and can still be iterated afterwards.
        parameters = list(parameters)
        if len(self.shadow_params) != len(parameters):
            print(
                f"Model has {len(self.shadow_params)} parameter tensors, the incoming ema {len(parameters)}"
            )
            return False

        for s_param, param in zip(self.shadow_params, parameters):
            if param.data.shape != s_param.data.shape:
                print(
                    f"Model has parameter tensor of shape {s_param.data.shape} , the incoming ema {param.data.shape}"
                )
                return False
        return True

    def copy_to(self, parameters):
        """
        Copy current parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. Entries that do not
                require grad are skipped.
        """
        trainable = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, trainable):
            param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        # Detach so the saved copies do not stay attached to the autograd graph.
        self.collected_params = [param.detach().clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def state_dict(self):
        """Return decay, update counter and shadow parameters as a dict.

        The shadow parameter list is returned by reference, not copied.
        """
        return dict(
            decay=self.decay,
            num_updates=self.num_updates,
            shadow_params=self.shadow_params,
        )

    def load_state_dict(self, state_dict, device):
        """Load the values produced by `state_dict`, moving tensors to `device`."""
        self.decay = state_dict["decay"]
        self.num_updates = state_dict["num_updates"]
        self.shadow_params = [
            tensor.to(device) for tensor in state_dict["shadow_params"]
        ]

    def to(self, device):
        """Move all shadow parameters to `device` (updates this object, returns None)."""
        self.shadow_params = [tensor.to(device) for tensor in self.shadow_params]


# the following is copied from Torch3D, BSD License, Copyright (c) Meta Platforms, Inc. and affiliates.


def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Return a tensor where each element has the absolute value taken from the
    corresponding element of a, with sign taken from the corresponding
    element of b. This is like the standard copysign floating-point operation,
    but is not careful about negative 0 and NaN.

    Args:
        a: source tensor.
        b: tensor whose signs will be used, of the same shape as a.

    Returns:
        Tensor of the same shape as a with the signs of b.
    """
    signs_differ = (a < 0) != (b < 0)
    return torch.where(signs_differ, -a, a)


def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    # Dividing by the squared norm means the input need not be unit length.
    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def random_quaternions(
    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """
    Generate random quaternions representing rotations,
    i.e. versors with nonnegative real part.

    Args:
        n: Number of quaternions in a batch to return.
        dtype: Type to return.
        device: Desired device of returned tensor. Default:
            uses the current device for the default tensor type.

    Returns:
        Quaternions as tensor of shape (n, 4).
    """
    if isinstance(device, str):
        device = torch.device(device)
    o = torch.randn((n, 4), dtype=dtype, device=device)
    s = (o * o).sum(1)
    # Normalize to unit length; the copysign flips each quaternion so that its
    # real part becomes nonnegative.
    o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
    return o


def random_rotations(
    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """
    Generate random rotations as 3x3 rotation matrices.

    Args:
        n: Number of rotation matrices in a batch to return.
        dtype: Type to return.
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type.

    Returns:
        Rotation matrices as tensor of shape (n, 3, 3).
    """
    quaternions = random_quaternions(n, dtype=dtype, device=device)
    return quaternion_to_matrix(quaternions)