from abc import ABC, abstractmethod from typing import Optional, Dict, Any, Set, List, Union import torch import numpy as np from . import vb_const as const from .vb_potentials_schedules import ( ParameterSchedule, ExponentialInterpolation, PiecewiseStepFunction, ) from .vb_loss_diffusionv2 import weighted_rigid_align class Potential(ABC): def __init__( self, parameters: Optional[ Dict[str, Union[ParameterSchedule, float, int, bool]] ] = None, ): self.parameters = parameters def compute(self, coords, feats, parameters): index, args, com_args, ref_args, operator_args = self.compute_args( feats, parameters ) if index.shape[1] == 0: return torch.zeros(coords.shape[:-2], device=coords.device) if com_args is not None: com_index, atom_pad_mask = com_args unpad_com_index = com_index[atom_pad_mask] unpad_coords = coords[..., atom_pad_mask, :] coords = torch.zeros( (*unpad_coords.shape[:-2], unpad_com_index.max() + 1, 3), device=coords.device, ).scatter_reduce( -2, unpad_com_index.unsqueeze(-1).expand_as(unpad_coords), unpad_coords, "mean", ) else: com_index, atom_pad_mask = None, None if ref_args is not None: ref_coords, ref_mask, ref_atom_index, ref_token_index = ref_args coords = coords[..., ref_atom_index, :] else: ref_coords, ref_mask, ref_atom_index, ref_token_index = ( None, None, None, None, ) if operator_args is not None: negation_mask, union_index = operator_args else: negation_mask, union_index = None, None value = self.compute_variable( coords, index, ref_coords=ref_coords, ref_mask=ref_mask, compute_gradient=False, ) energy = self.compute_function( value, *args, negation_mask=negation_mask, compute_derivative=False ) if union_index is not None: neg_exp_energy = torch.exp(-1 * parameters["union_lambda"] * energy) Z = torch.zeros( (*energy.shape[:-1], union_index.max() + 1), device=union_index.device ).scatter_reduce( -1, union_index.expand_as(neg_exp_energy), neg_exp_energy, "sum", ) softmax_energy = neg_exp_energy / Z[..., union_index] softmax_energy[Z[..., union_index] == 0] = 0 return (energy * softmax_energy).sum(dim=-1) return energy.sum(dim=tuple(range(1, energy.dim()))) def compute_gradient(self, coords, feats, parameters): index, args, com_args, ref_args, operator_args = self.compute_args( feats, parameters ) if index.shape[1] == 0: return torch.zeros_like(coords) if com_args is not None: com_index, atom_pad_mask = com_args unpad_coords = coords[..., atom_pad_mask, :] unpad_com_index = com_index[atom_pad_mask] coords = torch.zeros( (*unpad_coords.shape[:-2], unpad_com_index.max() + 1, 3), device=coords.device, ).scatter_reduce( -2, unpad_com_index.unsqueeze(-1).expand_as(unpad_coords), unpad_coords, "mean", ) com_counts = torch.bincount(com_index[atom_pad_mask]) else: com_index, atom_pad_mask = None, None if ref_args is not None: ref_coords, ref_mask, ref_atom_index, ref_token_index = ref_args coords = coords[..., ref_atom_index, :] else: ref_coords, ref_mask, ref_atom_index, ref_token_index = ( None, None, None, None, ) if operator_args is not None: negation_mask, union_index = operator_args else: negation_mask, union_index = None, None value, grad_value = self.compute_variable( coords, index, ref_coords=ref_coords, ref_mask=ref_mask, compute_gradient=True, ) energy, dEnergy = self.compute_function( value, *args, negation_mask=negation_mask, compute_derivative=True ) if union_index is not None: neg_exp_energy = torch.exp(-1 * parameters["union_lambda"] * energy) Z = torch.zeros( (*energy.shape[:-1], union_index.max() + 1), device=union_index.device ).scatter_reduce( -1, union_index.expand_as(energy), neg_exp_energy, "sum", ) softmax_energy = neg_exp_energy / Z[..., union_index] softmax_energy[Z[..., union_index] == 0] = 0 f = torch.zeros( (*energy.shape[:-1], union_index.max() + 1), device=union_index.device ).scatter_reduce( -1, union_index.expand_as(energy), energy * softmax_energy, "sum", ) dSoftmax = ( dEnergy * softmax_energy * (1 + parameters["union_lambda"] * (energy - f[..., union_index])) ) prod = dSoftmax.tile(grad_value.shape[-3]).unsqueeze( -1 ) * grad_value.flatten(start_dim=-3, end_dim=-2) if prod.dim() > 3: prod = prod.sum(dim=list(range(1, prod.dim() - 2))) grad_atom = torch.zeros_like(coords).scatter_reduce( -2, index.flatten(start_dim=0, end_dim=1) .unsqueeze(-1) .expand((*coords.shape[:-2], -1, 3)), prod, "sum", ) else: prod = dEnergy.tile(grad_value.shape[-3]).unsqueeze( -1 ) * grad_value.flatten(start_dim=-3, end_dim=-2) if prod.dim() > 3: prod = prod.sum(dim=list(range(1, prod.dim() - 2))) grad_atom = torch.zeros_like(coords).scatter_reduce( -2, index.flatten(start_dim=0, end_dim=1) .unsqueeze(-1) .expand((*coords.shape[:-2], -1, 3)), # 9 x 516 x 3 prod, "sum", ) if com_index is not None: grad_atom = grad_atom[..., com_index, :] elif ref_token_index is not None: grad_atom = grad_atom[..., ref_token_index, :] return grad_atom def compute_parameters(self, t): if self.parameters is None: return None parameters = { name: parameter if not isinstance(parameter, ParameterSchedule) else parameter.compute(t) for name, parameter in self.parameters.items() } return parameters @abstractmethod def compute_function( self, value, *args, negation_mask=None, compute_derivative=False ): raise NotImplementedError @abstractmethod def compute_variable(self, coords, index, compute_gradient=False): raise NotImplementedError @abstractmethod def compute_args(self, t, feats, **parameters): raise NotImplementedError def get_reference_coords(self, feats, parameters): return None, None class FlatBottomPotential(Potential): def compute_function( self, value, k, lower_bounds, upper_bounds, negation_mask=None, compute_derivative=False, ): if lower_bounds is None: lower_bounds = torch.full_like(value, float("-inf")) if upper_bounds is None: upper_bounds = torch.full_like(value, float("inf")) lower_bounds = lower_bounds.expand_as(value).clone() upper_bounds = upper_bounds.expand_as(value).clone() if negation_mask is not None: unbounded_below_mask = torch.isneginf(lower_bounds) unbounded_above_mask = torch.isposinf(upper_bounds) unbounded_mask = unbounded_below_mask + unbounded_above_mask assert torch.all(unbounded_mask + negation_mask) lower_bounds[~unbounded_above_mask * ~negation_mask] = upper_bounds[ ~unbounded_above_mask * ~negation_mask ] upper_bounds[~unbounded_above_mask * ~negation_mask] = float("inf") upper_bounds[~unbounded_below_mask * ~negation_mask] = lower_bounds[ ~unbounded_below_mask * ~negation_mask ] lower_bounds[~unbounded_below_mask * ~negation_mask] = float("-inf") neg_overflow_mask = value < lower_bounds pos_overflow_mask = value > upper_bounds energy = torch.zeros_like(value) energy[neg_overflow_mask] = (k * (lower_bounds - value))[neg_overflow_mask] energy[pos_overflow_mask] = (k * (value - upper_bounds))[pos_overflow_mask] if not compute_derivative: return energy dEnergy = torch.zeros_like(value) dEnergy[neg_overflow_mask] = ( -1 * k.expand_as(neg_overflow_mask)[neg_overflow_mask] ) dEnergy[pos_overflow_mask] = ( 1 * k.expand_as(pos_overflow_mask)[pos_overflow_mask] ) return energy, dEnergy class ReferencePotential(Potential): def compute_variable( self, coords, index, ref_coords, ref_mask, compute_gradient=False ): aligned_ref_coords = weighted_rigid_align( ref_coords.float(), coords[:, index].float(), ref_mask, ref_mask, ) r = coords[:, index] - aligned_ref_coords r_norm = torch.linalg.norm(r, dim=-1) if not compute_gradient: return r_norm r_hat = r / r_norm.unsqueeze(-1) grad = (r_hat * ref_mask.unsqueeze(-1)).unsqueeze(1) return r_norm, grad class DistancePotential(Potential): def compute_variable( self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False ): r_ij = coords.index_select(-2, index[0]) - coords.index_select(-2, index[1]) r_ij_norm = torch.linalg.norm(r_ij, dim=-1) r_hat_ij = r_ij / r_ij_norm.unsqueeze(-1) if not compute_gradient: return r_ij_norm grad_i = r_hat_ij grad_j = -1 * r_hat_ij grad = torch.stack((grad_i, grad_j), dim=1) return r_ij_norm, grad class DihedralPotential(Potential): def compute_variable( self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False ): r_ij = coords.index_select(-2, index[0]) - coords.index_select(-2, index[1]) r_kj = coords.index_select(-2, index[2]) - coords.index_select(-2, index[1]) r_kl = coords.index_select(-2, index[2]) - coords.index_select(-2, index[3]) n_ijk = torch.cross(r_ij, r_kj, dim=-1) n_jkl = torch.cross(r_kj, r_kl, dim=-1) r_kj_norm = torch.linalg.norm(r_kj, dim=-1) n_ijk_norm = torch.linalg.norm(n_ijk, dim=-1) n_jkl_norm = torch.linalg.norm(n_jkl, dim=-1) sign_phi = torch.sign( r_kj.unsqueeze(-2) @ torch.cross(n_ijk, n_jkl, dim=-1).unsqueeze(-1) ).squeeze(-1, -2) phi = sign_phi * torch.arccos( torch.clamp( (n_ijk.unsqueeze(-2) @ n_jkl.unsqueeze(-1)).squeeze(-1, -2) / (n_ijk_norm * n_jkl_norm), -1 + 1e-8, 1 - 1e-8, ) ) if not compute_gradient: return phi a = ( (r_ij.unsqueeze(-2) @ r_kj.unsqueeze(-1)).squeeze(-1, -2) / (r_kj_norm**2) ).unsqueeze(-1) b = ( (r_kl.unsqueeze(-2) @ r_kj.unsqueeze(-1)).squeeze(-1, -2) / (r_kj_norm**2) ).unsqueeze(-1) grad_i = n_ijk * (r_kj_norm / n_ijk_norm**2).unsqueeze(-1) grad_l = -1 * n_jkl * (r_kj_norm / n_jkl_norm**2).unsqueeze(-1) grad_j = (a - 1) * grad_i - b * grad_l grad_k = (b - 1) * grad_l - a * grad_i grad = torch.stack((grad_i, grad_j, grad_k, grad_l), dim=1) return phi, grad class AbsDihedralPotential(DihedralPotential): def compute_variable( self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False ): if not compute_gradient: phi = super().compute_variable( coords, index, compute_gradient=compute_gradient ) phi = torch.abs(phi) return phi phi, grad = super().compute_variable( coords, index, compute_gradient=compute_gradient ) grad[(phi < 0)[..., None, :, None].expand_as(grad)] *= -1 phi = torch.abs(phi) return phi, grad class PoseBustersPotential(FlatBottomPotential, DistancePotential): def compute_args(self, feats, parameters): pair_index = feats["rdkit_bounds_index"][0] lower_bounds = feats["rdkit_lower_bounds"][0].clone() upper_bounds = feats["rdkit_upper_bounds"][0].clone() bond_mask = feats["rdkit_bounds_bond_mask"][0] angle_mask = feats["rdkit_bounds_angle_mask"][0] lower_bounds[bond_mask * ~angle_mask] *= 1.0 - parameters["bond_buffer"] upper_bounds[bond_mask * ~angle_mask] *= 1.0 + parameters["bond_buffer"] lower_bounds[~bond_mask * angle_mask] *= 1.0 - parameters["angle_buffer"] upper_bounds[~bond_mask * angle_mask] *= 1.0 + parameters["angle_buffer"] lower_bounds[bond_mask * angle_mask] *= 1.0 - min( parameters["bond_buffer"], parameters["angle_buffer"] ) upper_bounds[bond_mask * angle_mask] *= 1.0 + min( parameters["bond_buffer"], parameters["angle_buffer"] ) lower_bounds[~bond_mask * ~angle_mask] *= 1.0 - parameters["clash_buffer"] upper_bounds[~bond_mask * ~angle_mask] = float("inf") vdw_radii = torch.zeros( const.num_elements, dtype=torch.float32, device=pair_index.device ) vdw_radii[1:119] = torch.tensor( const.vdw_radii, dtype=torch.float32, device=pair_index.device ) atom_vdw_radii = ( feats["ref_element"].float() @ vdw_radii.unsqueeze(-1) ).squeeze(-1)[0] bond_cutoffs = 0.35 + atom_vdw_radii[pair_index].mean(dim=0) lower_bounds[~bond_mask] = torch.max(lower_bounds[~bond_mask], bond_cutoffs[~bond_mask]) upper_bounds[bond_mask] = torch.min(upper_bounds[bond_mask], bond_cutoffs[bond_mask]) k = torch.ones_like(lower_bounds) return pair_index, (k, lower_bounds, upper_bounds), None, None, None class ConnectionsPotential(FlatBottomPotential, DistancePotential): def compute_args(self, feats, parameters): pair_index = feats["connected_atom_index"][0] lower_bounds = None upper_bounds = torch.full( (pair_index.shape[1],), parameters["buffer"], device=pair_index.device ) k = torch.ones_like(upper_bounds) return pair_index, (k, lower_bounds, upper_bounds), None, None, None class VDWOverlapPotential(FlatBottomPotential, DistancePotential): def compute_args(self, feats, parameters): atom_chain_id = ( torch.bmm( feats["atom_to_token"].float(), feats["asym_id"].unsqueeze(-1).float() ) .squeeze(-1) .long() )[0] atom_pad_mask = feats["atom_pad_mask"][0].bool() chain_sizes = torch.bincount(atom_chain_id[atom_pad_mask]) single_ion_mask = (chain_sizes > 1)[atom_chain_id] vdw_radii = torch.zeros( const.num_elements, dtype=torch.float32, device=atom_chain_id.device ) vdw_radii[1:119] = torch.tensor( const.vdw_radii, dtype=torch.float32, device=atom_chain_id.device ) atom_vdw_radii = ( feats["ref_element"].float() @ vdw_radii.unsqueeze(-1) ).squeeze(-1)[0] pair_index = torch.triu_indices( atom_chain_id.shape[0], atom_chain_id.shape[0], 1, device=atom_chain_id.device, ) pair_pad_mask = atom_pad_mask[pair_index].all(dim=0) pair_ion_mask = single_ion_mask[pair_index[0]] * single_ion_mask[pair_index[1]] num_chains = atom_chain_id.max() + 1 connected_chain_index = feats["connected_chain_index"][0] connected_chain_matrix = torch.eye( num_chains, device=atom_chain_id.device, dtype=torch.bool ) connected_chain_matrix[connected_chain_index[0], connected_chain_index[1]] = ( True ) connected_chain_matrix[connected_chain_index[1], connected_chain_index[0]] = ( True ) connected_chain_mask = connected_chain_matrix[ atom_chain_id[pair_index[0]], atom_chain_id[pair_index[1]] ] pair_index = pair_index[ :, pair_pad_mask * pair_ion_mask * ~connected_chain_mask ] lower_bounds = atom_vdw_radii[pair_index].sum(dim=0) * ( 1.0 - parameters["buffer"] ) upper_bounds = None k = torch.ones_like(lower_bounds) return pair_index, (k, lower_bounds, upper_bounds), None, None, None class SymmetricChainCOMPotential(FlatBottomPotential, DistancePotential): def compute_args(self, feats, parameters): atom_chain_id = ( torch.bmm( feats["atom_to_token"].float(), feats["asym_id"].unsqueeze(-1).float() ) .squeeze(-1) .long() )[0] atom_pad_mask = feats["atom_pad_mask"][0].bool() chain_sizes = torch.bincount(atom_chain_id[atom_pad_mask]) single_ion_mask = chain_sizes > 1 pair_index = feats["symmetric_chain_index"][0] pair_ion_mask = single_ion_mask[pair_index[0]] * single_ion_mask[pair_index[1]] pair_index = pair_index[:, pair_ion_mask] lower_bounds = torch.full( (pair_index.shape[1],), parameters["buffer"], dtype=torch.float32, device=pair_index.device, ) upper_bounds = None k = torch.ones_like(lower_bounds) return ( pair_index, (k, lower_bounds, upper_bounds), (atom_chain_id, atom_pad_mask), None, None, ) class StereoBondPotential(FlatBottomPotential, AbsDihedralPotential): def compute_args(self, feats, parameters): stereo_bond_index = feats["stereo_bond_index"][0] stereo_bond_orientations = feats["stereo_bond_orientations"][0].bool() lower_bounds = torch.zeros( stereo_bond_orientations.shape, device=stereo_bond_orientations.device ) upper_bounds = torch.zeros( stereo_bond_orientations.shape, device=stereo_bond_orientations.device ) lower_bounds[stereo_bond_orientations] = torch.pi - parameters["buffer"] upper_bounds[stereo_bond_orientations] = float("inf") lower_bounds[~stereo_bond_orientations] = float("-inf") upper_bounds[~stereo_bond_orientations] = parameters["buffer"] k = torch.ones_like(lower_bounds) return stereo_bond_index, (k, lower_bounds, upper_bounds), None, None, None class ChiralAtomPotential(FlatBottomPotential, DihedralPotential): def compute_args(self, feats, parameters): chiral_atom_index = feats["chiral_atom_index"][0] chiral_atom_orientations = feats["chiral_atom_orientations"][0].bool() lower_bounds = torch.zeros( chiral_atom_orientations.shape, device=chiral_atom_orientations.device ) upper_bounds = torch.zeros( chiral_atom_orientations.shape, device=chiral_atom_orientations.device ) lower_bounds[chiral_atom_orientations] = parameters["buffer"] upper_bounds[chiral_atom_orientations] = float("inf") upper_bounds[~chiral_atom_orientations] = -1 * parameters["buffer"] lower_bounds[~chiral_atom_orientations] = float("-inf") k = torch.ones_like(lower_bounds) return chiral_atom_index, (k, lower_bounds, upper_bounds), None, None, None class PlanarBondPotential(FlatBottomPotential, AbsDihedralPotential): def compute_args(self, feats, parameters): double_bond_index = feats["planar_bond_index"][0].T double_bond_improper_index = torch.tensor( [ [1, 2, 3, 0], [4, 5, 0, 3], ], device=double_bond_index.device, ).T improper_index = ( double_bond_index[:, double_bond_improper_index] .swapaxes(0, 1) .flatten(start_dim=1) ) lower_bounds = None upper_bounds = torch.full( (improper_index.shape[1],), parameters["buffer"], device=improper_index.device, ) k = torch.ones_like(upper_bounds) return improper_index, (k, lower_bounds, upper_bounds), None, None, None class TemplateReferencePotential(FlatBottomPotential, ReferencePotential): def compute_args(self, feats, parameters): if "template_mask_cb" not in feats or "template_force" not in feats: return torch.empty([1, 0]), None, None, None, None template_mask = feats["template_mask_cb"][feats["template_force"]] if template_mask.shape[0] == 0: return torch.empty([1, 0]), None, None, None, None ref_coords = feats["template_cb"][feats["template_force"]].clone() ref_mask = feats["template_mask_cb"][feats["template_force"]].clone() ref_atom_index = ( torch.bmm( feats["token_to_rep_atom"].float(), torch.arange( feats["atom_pad_mask"].shape[1], device=feats["atom_pad_mask"].device, dtype=torch.float32, )[None, :, None], ) .squeeze(-1) .long() )[0] ref_token_index = ( torch.bmm( feats["atom_to_token"].float(), feats["token_index"].unsqueeze(-1).float(), ) .squeeze(-1) .long() )[0] index = torch.arange( template_mask.shape[-1], dtype=torch.long, device=template_mask.device )[None] upper_bounds = torch.full( template_mask.shape, float("inf"), device=index.device, dtype=torch.float32 ) ref_idxs = torch.argwhere(template_mask).T upper_bounds[ref_idxs.unbind()] = feats["template_force_threshold"][ feats["template_force"] ][ref_idxs[0]] lower_bounds = None k = torch.ones_like(upper_bounds) return ( index, (k, lower_bounds, upper_bounds), None, (ref_coords, ref_mask, ref_atom_index, ref_token_index), None, ) class ContactPotentital(FlatBottomPotential, DistancePotential): def compute_args(self, feats, parameters): index = feats["contact_pair_index"][0] union_index = feats["contact_union_index"][0] negation_mask = feats["contact_negation_mask"][0] lower_bounds = None upper_bounds = feats["contact_thresholds"][0].clone() k = torch.ones_like(upper_bounds) return ( index, (k, lower_bounds, upper_bounds), None, None, (negation_mask, union_index), ) def get_potentials(steering_args, boltz2=False): potentials = [] if steering_args["fk_steering"] or steering_args["physical_guidance_update"]: potentials.extend( [ SymmetricChainCOMPotential( parameters={ "guidance_interval": 4, "guidance_weight": 0.5 if steering_args["physical_guidance_update"] else 0.0, "resampling_weight": 0.5, "buffer": ExponentialInterpolation( start=1.0, end=5.0, alpha=-2.0 ), } ), VDWOverlapPotential( parameters={ "guidance_interval": 5, "guidance_weight": ( PiecewiseStepFunction(thresholds=[0.4], values=[0.125, 0.0]) if steering_args["physical_guidance_update"] else 0.0 ), "resampling_weight": PiecewiseStepFunction( thresholds=[0.6], values=[0.01, 0.0] ), "buffer": 0.225, } ), ConnectionsPotential( parameters={ "guidance_interval": 1, "guidance_weight": 0.15 if steering_args["physical_guidance_update"] else 0.0, "resampling_weight": 1.0, "buffer": 2.0, } ), PoseBustersPotential( parameters={ "guidance_interval": 1, "guidance_weight": 0.01 if steering_args["physical_guidance_update"] else 0.0, "resampling_weight": 0.1, "bond_buffer": 0.125, "angle_buffer": 0.125, "clash_buffer": 0.10, } ), ChiralAtomPotential( parameters={ "guidance_interval": 1, "guidance_weight": 0.1 if steering_args["physical_guidance_update"] else 0.0, "resampling_weight": 1.0, "buffer": 0.52360, } ), StereoBondPotential( parameters={ "guidance_interval": 1, "guidance_weight": 0.05 if steering_args["physical_guidance_update"] else 0.0, "resampling_weight": 1.0, "buffer": 0.52360, } ), PlanarBondPotential( parameters={ "guidance_interval": 1, "guidance_weight": 0.05 if steering_args["physical_guidance_update"] else 0.0, "resampling_weight": 1.0, "buffer": 0.26180, } ), ] ) if boltz2 and ( steering_args["fk_steering"] or steering_args["contact_guidance_update"] ): potentials.extend( [ ContactPotentital( parameters={ "guidance_interval": 4, "guidance_weight": ( PiecewiseStepFunction( thresholds=[0.25, 0.75], values=[0.0, 0.5, 1.0] ) if steering_args["contact_guidance_update"] else 0.0 ), "resampling_weight": 1.0, "union_lambda": ExponentialInterpolation( start=8.0, end=0.0, alpha=-2.0 ), } ), TemplateReferencePotential( parameters={ "guidance_interval": 2, "guidance_weight": 0.1 if steering_args["contact_guidance_update"] else 0.0, "resampling_weight": 1.0, } ), ] ) return potentials