from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Union

import numpy as np
import torch

from . import vb_const as const
from .vb_loss_diffusionv2 import weighted_rigid_align
from .vb_potentials_schedules import (
    ExponentialInterpolation,
    ParameterSchedule,
    PiecewiseStepFunction,
)


class Potential(ABC):
    """Base class for steering potentials evaluated on atom coordinates."""

    def __init__(
        self,
        parameters: Optional[
            Dict[str, Union[ParameterSchedule, float, int, bool]]
        ] = None,
    ):
        self.parameters = parameters

    def compute(self, coords, feats, parameters):
        index, args, com_args, ref_args, operator_args = self.compute_args(
            feats, parameters
        )

        if index.shape[1] == 0:
            return torch.zeros(coords.shape[:-2], device=coords.device)

        if com_args is not None:
            # Pool atoms into per-group centers of mass before evaluating.
            com_index, atom_pad_mask = com_args
            unpad_com_index = com_index[atom_pad_mask]
            unpad_coords = coords[..., atom_pad_mask, :]
            coords = torch.zeros(
                (*unpad_coords.shape[:-2], unpad_com_index.max() + 1, 3),
                device=coords.device,
            ).scatter_reduce(
                -2,
                unpad_com_index.unsqueeze(-1).expand_as(unpad_coords),
                unpad_coords,
                "mean",
            )
        else:
            com_index, atom_pad_mask = None, None

        if ref_args is not None:
            ref_coords, ref_mask, ref_atom_index, ref_token_index = ref_args
            coords = coords[..., ref_atom_index, :]
        else:
            ref_coords, ref_mask, ref_atom_index, ref_token_index = (
                None,
                None,
                None,
                None,
            )

        if operator_args is not None:
            negation_mask, union_index = operator_args
        else:
            negation_mask, union_index = None, None

        value = self.compute_variable(
            coords,
            index,
            ref_coords=ref_coords,
            ref_mask=ref_mask,
            compute_gradient=False,
        )
        energy = self.compute_function(
            value, *args, negation_mask=negation_mask, compute_derivative=False
        )

        if union_index is not None:
            # Soft-min over union members: constraints sharing a union index
            # are combined with Boltzmann weights, so satisfying any one
            # member is enough to satisfy the group.
            neg_exp_energy = torch.exp(-1 * parameters["union_lambda"] * energy)
            Z = torch.zeros(
                (*energy.shape[:-1], union_index.max() + 1), device=union_index.device
            ).scatter_reduce(
                -1,
                union_index.expand_as(neg_exp_energy),
                neg_exp_energy,
                "sum",
            )
            softmax_energy = neg_exp_energy / Z[..., union_index]
            softmax_energy[Z[..., union_index] == 0] = 0
            return (energy * softmax_energy).sum(dim=-1)

        return energy.sum(dim=tuple(range(1, energy.dim())))

    def compute_gradient(self, coords, feats, parameters):
        index, args, com_args, ref_args, operator_args = self.compute_args(
            feats, parameters
        )
        if index.shape[1] == 0:
            return torch.zeros_like(coords)

        if com_args is not None:
            com_index, atom_pad_mask = com_args
            unpad_coords = coords[..., atom_pad_mask, :]
            unpad_com_index = com_index[atom_pad_mask]
            coords = torch.zeros(
                (*unpad_coords.shape[:-2], unpad_com_index.max() + 1, 3),
                device=coords.device,
            ).scatter_reduce(
                -2,
                unpad_com_index.unsqueeze(-1).expand_as(unpad_coords),
                unpad_coords,
                "mean",
            )
            com_counts = torch.bincount(unpad_com_index)
        else:
            com_index, atom_pad_mask = None, None

        if ref_args is not None:
            ref_coords, ref_mask, ref_atom_index, ref_token_index = ref_args
            coords = coords[..., ref_atom_index, :]
        else:
            ref_coords, ref_mask, ref_atom_index, ref_token_index = (
                None,
                None,
                None,
                None,
            )

        if operator_args is not None:
            negation_mask, union_index = operator_args
        else:
            negation_mask, union_index = None, None

        value, grad_value = self.compute_variable(
            coords,
            index,
            ref_coords=ref_coords,
            ref_mask=ref_mask,
            compute_gradient=True,
        )
        energy, dEnergy = self.compute_function(
            value, *args, negation_mask=negation_mask, compute_derivative=True
        )

        if union_index is not None:
            # Differentiate through the soft-min used in compute().
            neg_exp_energy = torch.exp(-1 * parameters["union_lambda"] * energy)
            Z = torch.zeros(
                (*energy.shape[:-1], union_index.max() + 1), device=union_index.device
            ).scatter_reduce(
                -1,
                union_index.expand_as(neg_exp_energy),
                neg_exp_energy,
                "sum",
            )
            softmax_energy = neg_exp_energy / Z[..., union_index]
            softmax_energy[Z[..., union_index] == 0] = 0
            f = torch.zeros(
                (*energy.shape[:-1], union_index.max() + 1), device=union_index.device
            ).scatter_reduce(
                -1,
                union_index.expand_as(energy),
                energy * softmax_energy,
                "sum",
            )
            dEnergy = (
                dEnergy
                * softmax_energy
                * (1 + parameters["union_lambda"] * (energy - f[..., union_index]))
            )

        # Chain rule: scatter d(energy)/d(variable) times d(variable)/d(point)
        # back onto the (possibly pooled) coordinates.
        prod = dEnergy.tile(grad_value.shape[-3]).unsqueeze(
            -1
        ) * grad_value.flatten(start_dim=-3, end_dim=-2)
        if prod.dim() > 3:
            prod = prod.sum(dim=list(range(1, prod.dim() - 2)))
        grad_atom = torch.zeros_like(coords).scatter_reduce(
            -2,
            index.flatten(start_dim=0, end_dim=1)
            .unsqueeze(-1)
            .expand((*coords.shape[:-2], -1, 3)),
            prod,
            "sum",
        )

        if com_index is not None:
            # Chain rule through the mean-pooling: each atom receives an
            # equal 1/|group| share of its group's COM gradient.
            grad_atom = grad_atom[..., com_index, :] / com_counts[
                com_index
            ].unsqueeze(-1)
        elif ref_token_index is not None:
            grad_atom = grad_atom[..., ref_token_index, :]

        return grad_atom
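
    # Sanity-check sketch (illustrative, not part of the module): compare
    # compute_gradient against a central finite difference of compute along a
    # random direction, for a single-sample batch. `potential`, `coords`,
    # `feats`, and `params` are assumed to exist.
    #
    #     eps = 1e-4
    #     d = torch.randn_like(coords)
    #     d = d / torch.linalg.norm(d)
    #     numeric = (
    #         potential.compute(coords + eps * d, feats, params)
    #         - potential.compute(coords - eps * d, feats, params)
    #     ) / (2 * eps)
    #     analytic = (potential.compute_gradient(coords, feats, params) * d).sum()
    #     # numeric and analytic agree to O(eps**2) away from the kinks of the
    #     # flat-bottom wells.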

    def compute_parameters(self, t):
        if self.parameters is None:
            return None
        parameters = {
            name: (
                parameter.compute(t)
                if isinstance(parameter, ParameterSchedule)
                else parameter
            )
            for name, parameter in self.parameters.items()
        }
        return parameters

    @abstractmethod
    def compute_function(
        self, value, *args, negation_mask=None, compute_derivative=False
    ):
        raise NotImplementedError

    @abstractmethod
    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        raise NotImplementedError

    @abstractmethod
    def compute_args(self, feats, parameters):
        raise NotImplementedError

    def get_reference_coords(self, feats, parameters):
        return None, None
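
# Usage sketch (illustrative, not part of the module): how a diffusion
# sampler might apply a potential during denoising. `coords`, `feats`, `t`,
# `step`, and `steering_args` are hypothetical stand-ins supplied by the
# caller.
#
#     for potential in get_potentials(steering_args):
#         params = potential.compute_parameters(t)
#         if params["guidance_weight"] > 0 and step % params["guidance_interval"] == 0:
#             grad = potential.compute_gradient(coords, feats, params)
#             coords = coords - params["guidance_weight"] * grad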


class FlatBottomPotential(Potential):
    """Potential that is zero inside [lower, upper] and linear outside."""

    def compute_function(
        self,
        value,
        k,
        lower_bounds,
        upper_bounds,
        negation_mask=None,
        compute_derivative=False,
    ):
        if lower_bounds is None:
            lower_bounds = torch.full_like(value, float("-inf"))
        if upper_bounds is None:
            upper_bounds = torch.full_like(value, float("inf"))
        lower_bounds = lower_bounds.expand_as(value).clone()
        upper_bounds = upper_bounds.expand_as(value).clone()

        if negation_mask is not None:
            # Negation flips a one-sided constraint to its complement, e.g.
            # "value <= u" becomes "value >= u". Only one-sided constraints
            # can be negated, which the assert below enforces.
            unbounded_below_mask = torch.isneginf(lower_bounds)
            unbounded_above_mask = torch.isposinf(upper_bounds)
            unbounded_mask = unbounded_below_mask + unbounded_above_mask
            assert torch.all(unbounded_mask + negation_mask)
            lower_bounds[~unbounded_above_mask * ~negation_mask] = upper_bounds[
                ~unbounded_above_mask * ~negation_mask
            ]
            upper_bounds[~unbounded_above_mask * ~negation_mask] = float("inf")
            upper_bounds[~unbounded_below_mask * ~negation_mask] = lower_bounds[
                ~unbounded_below_mask * ~negation_mask
            ]
            lower_bounds[~unbounded_below_mask * ~negation_mask] = float("-inf")

        neg_overflow_mask = value < lower_bounds
        pos_overflow_mask = value > upper_bounds

        energy = torch.zeros_like(value)
        energy[neg_overflow_mask] = (k * (lower_bounds - value))[neg_overflow_mask]
        energy[pos_overflow_mask] = (k * (value - upper_bounds))[pos_overflow_mask]
        if not compute_derivative:
            return energy

        dEnergy = torch.zeros_like(value)
        dEnergy[neg_overflow_mask] = (
            -1 * k.expand_as(neg_overflow_mask)[neg_overflow_mask]
        )
        dEnergy[pos_overflow_mask] = (
            1 * k.expand_as(pos_overflow_mask)[pos_overflow_mask]
        )

        return energy, dEnergy
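
# Worked example (illustrative): a flat-bottom well with k = 1 over
# [lower, upper] = [1.0, 4.0], evaluated through any concrete subclass
# (e.g. ConnectionsPotential below, which inherits this compute_function):
#
#     value = torch.tensor([0.5, 2.0, 5.0])
#     k = torch.ones_like(value)
#     lower = torch.full_like(value, 1.0)
#     upper = torch.full_like(value, 4.0)
#     energy = ConnectionsPotential().compute_function(value, k, lower, upper)
#     # energy == [0.5, 0.0, 1.0]: linear outside the well, zero inside.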


class ReferencePotential(Potential):
    def compute_variable(
        self, coords, index, ref_coords, ref_mask, compute_gradient=False
    ):
        # Rigidly align the reference onto the current coordinates, then
        # measure the per-atom deviation.
        aligned_ref_coords = weighted_rigid_align(
            ref_coords.float(),
            coords[:, index].float(),
            ref_mask,
            ref_mask,
        )

        r = coords[:, index] - aligned_ref_coords
        r_norm = torch.linalg.norm(r, dim=-1)

        if not compute_gradient:
            return r_norm

        r_hat = r / r_norm.unsqueeze(-1)
        grad = (r_hat * ref_mask.unsqueeze(-1)).unsqueeze(1)
        return r_norm, grad


class DistancePotential(Potential):
    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        r_ij = coords.index_select(-2, index[0]) - coords.index_select(-2, index[1])
        r_ij_norm = torch.linalg.norm(r_ij, dim=-1)
        r_hat_ij = r_ij / r_ij_norm.unsqueeze(-1)

        if not compute_gradient:
            return r_ij_norm

        grad_i = r_hat_ij
        grad_j = -1 * r_hat_ij
        grad = torch.stack((grad_i, grad_j), dim=1)
        return r_ij_norm, grad
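
# Note (illustrative): for the pair distance d = |x_i - x_j|, the returned
# gradients are d(d)/d(x_i) = r_hat_ij and d(d)/d(x_j) = -r_hat_ij, the unit
# vector from atom j to atom i and its negation.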


class DihedralPotential(Potential):
    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        # Signed dihedral angle about the j-k bond, via the standard
        # normal-vector construction.
        r_ij = coords.index_select(-2, index[0]) - coords.index_select(-2, index[1])
        r_kj = coords.index_select(-2, index[2]) - coords.index_select(-2, index[1])
        r_kl = coords.index_select(-2, index[2]) - coords.index_select(-2, index[3])

        n_ijk = torch.cross(r_ij, r_kj, dim=-1)
        n_jkl = torch.cross(r_kj, r_kl, dim=-1)

        r_kj_norm = torch.linalg.norm(r_kj, dim=-1)
        n_ijk_norm = torch.linalg.norm(n_ijk, dim=-1)
        n_jkl_norm = torch.linalg.norm(n_jkl, dim=-1)

        sign_phi = torch.sign(
            r_kj.unsqueeze(-2) @ torch.cross(n_ijk, n_jkl, dim=-1).unsqueeze(-1)
        ).squeeze(-1, -2)
        phi = sign_phi * torch.arccos(
            torch.clamp(
                (n_ijk.unsqueeze(-2) @ n_jkl.unsqueeze(-1)).squeeze(-1, -2)
                / (n_ijk_norm * n_jkl_norm),
                -1 + 1e-8,
                1 - 1e-8,
            )
        )

        if not compute_gradient:
            return phi

        a = (
            (r_ij.unsqueeze(-2) @ r_kj.unsqueeze(-1)).squeeze(-1, -2) / (r_kj_norm**2)
        ).unsqueeze(-1)
        b = (
            (r_kl.unsqueeze(-2) @ r_kj.unsqueeze(-1)).squeeze(-1, -2) / (r_kj_norm**2)
        ).unsqueeze(-1)

        grad_i = n_ijk * (r_kj_norm / n_ijk_norm**2).unsqueeze(-1)
        grad_l = -1 * n_jkl * (r_kj_norm / n_jkl_norm**2).unsqueeze(-1)
        grad_j = (a - 1) * grad_i - b * grad_l
        grad_k = (b - 1) * grad_l - a * grad_i
        grad = torch.stack((grad_i, grad_j, grad_k, grad_l), dim=1)
        return phi, grad
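
# Worked example (illustrative): for points i = (1, 0, 0), j = (0, 0, 0),
# k = (0, 0, 1), l = (0, 1, 1), the planes i-j-k and j-k-l meet at a right
# angle about the j-k bond, so compute_variable returns phi = +pi/2 (the sign
# comes from the r_kj . (n_ijk x n_jkl) test above).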


class AbsDihedralPotential(DihedralPotential):
    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        if not compute_gradient:
            phi = super().compute_variable(
                coords, index, compute_gradient=compute_gradient
            )
            phi = torch.abs(phi)
            return phi

        phi, grad = super().compute_variable(
            coords, index, compute_gradient=compute_gradient
        )
        # d|phi|/dx flips sign wherever phi is negative.
        grad[(phi < 0)[..., None, :, None].expand_as(grad)] *= -1
        phi = torch.abs(phi)

        return phi, grad


class PoseBustersPotential(FlatBottomPotential, DistancePotential):
    def compute_args(self, feats, parameters):
        pair_index = feats["rdkit_bounds_index"][0]
        lower_bounds = feats["rdkit_lower_bounds"][0].clone()
        upper_bounds = feats["rdkit_upper_bounds"][0].clone()
        bond_mask = feats["rdkit_bounds_bond_mask"][0]
        angle_mask = feats["rdkit_bounds_angle_mask"][0]

        # Relax the RDKit bounds by the buffer appropriate to each pair type.
        lower_bounds[bond_mask * ~angle_mask] *= 1.0 - parameters["bond_buffer"]
        upper_bounds[bond_mask * ~angle_mask] *= 1.0 + parameters["bond_buffer"]
        lower_bounds[~bond_mask * angle_mask] *= 1.0 - parameters["angle_buffer"]
        upper_bounds[~bond_mask * angle_mask] *= 1.0 + parameters["angle_buffer"]
        lower_bounds[bond_mask * angle_mask] *= 1.0 - min(
            parameters["bond_buffer"], parameters["angle_buffer"]
        )
        upper_bounds[bond_mask * angle_mask] *= 1.0 + min(
            parameters["bond_buffer"], parameters["angle_buffer"]
        )
        lower_bounds[~bond_mask * ~angle_mask] *= 1.0 - parameters["clash_buffer"]
        upper_bounds[~bond_mask * ~angle_mask] = float("inf")

        vdw_radii = torch.zeros(
            const.num_elements, dtype=torch.float32, device=pair_index.device
        )
        vdw_radii[1:119] = torch.tensor(
            const.vdw_radii, dtype=torch.float32, device=pair_index.device
        )
        atom_vdw_radii = (
            feats["ref_element"].float() @ vdw_radii.unsqueeze(-1)
        ).squeeze(-1)[0]
        bond_cutoffs = 0.35 + atom_vdw_radii[pair_index].mean(dim=0)
        lower_bounds[~bond_mask] = torch.max(
            lower_bounds[~bond_mask], bond_cutoffs[~bond_mask]
        )
        upper_bounds[bond_mask] = torch.min(
            upper_bounds[bond_mask], bond_cutoffs[bond_mask]
        )

        k = torch.ones_like(lower_bounds)

        return pair_index, (k, lower_bounds, upper_bounds), None, None, None
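
# Worked example (illustrative): with bond_buffer = 0.125, a pure bond
# constraint with RDKit bounds [1.40, 1.60] Angstrom is relaxed to
# [1.40 * 0.875, 1.60 * 1.125] = [1.225, 1.80] before the flat-bottom energy
# is applied; pairs that are neither bonds nor angles keep only a
# clash-avoiding lower bound.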


class ConnectionsPotential(FlatBottomPotential, DistancePotential):
    def compute_args(self, feats, parameters):
        pair_index = feats["connected_atom_index"][0]
        lower_bounds = None
        upper_bounds = torch.full(
            (pair_index.shape[1],), parameters["buffer"], device=pair_index.device
        )
        k = torch.ones_like(upper_bounds)

        return pair_index, (k, lower_bounds, upper_bounds), None, None, None


class VDWOverlapPotential(FlatBottomPotential, DistancePotential):
    def compute_args(self, feats, parameters):
        atom_chain_id = (
            torch.bmm(
                feats["atom_to_token"].float(), feats["asym_id"].unsqueeze(-1).float()
            )
            .squeeze(-1)
            .long()
        )[0]
        atom_pad_mask = feats["atom_pad_mask"][0].bool()
        chain_sizes = torch.bincount(atom_chain_id[atom_pad_mask])
        single_ion_mask = (chain_sizes > 1)[atom_chain_id]

        vdw_radii = torch.zeros(
            const.num_elements, dtype=torch.float32, device=atom_chain_id.device
        )
        vdw_radii[1:119] = torch.tensor(
            const.vdw_radii, dtype=torch.float32, device=atom_chain_id.device
        )
        atom_vdw_radii = (
            feats["ref_element"].float() @ vdw_radii.unsqueeze(-1)
        ).squeeze(-1)[0]

        # All unordered atom pairs; keep only pairs of non-padding atoms from
        # multi-atom chains, and drop pairs within the same chain or between
        # covalently connected chains.
        pair_index = torch.triu_indices(
            atom_chain_id.shape[0],
            atom_chain_id.shape[0],
            1,
            device=atom_chain_id.device,
        )

        pair_pad_mask = atom_pad_mask[pair_index].all(dim=0)
        pair_ion_mask = single_ion_mask[pair_index[0]] * single_ion_mask[pair_index[1]]

        num_chains = atom_chain_id.max() + 1
        connected_chain_index = feats["connected_chain_index"][0]
        connected_chain_matrix = torch.eye(
            num_chains, device=atom_chain_id.device, dtype=torch.bool
        )
        connected_chain_matrix[connected_chain_index[0], connected_chain_index[1]] = (
            True
        )
        connected_chain_matrix[connected_chain_index[1], connected_chain_index[0]] = (
            True
        )
        connected_chain_mask = connected_chain_matrix[
            atom_chain_id[pair_index[0]], atom_chain_id[pair_index[1]]
        ]

        pair_index = pair_index[
            :, pair_pad_mask * pair_ion_mask * ~connected_chain_mask
        ]

        lower_bounds = atom_vdw_radii[pair_index].sum(dim=0) * (
            1.0 - parameters["buffer"]
        )
        upper_bounds = None
        k = torch.ones_like(lower_bounds)

        return pair_index, (k, lower_bounds, upper_bounds), None, None, None
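
# Note (illustrative): torch.triu_indices(n, n, 1) enumerates every unordered
# atom pair exactly once. For n = 4 it yields
#
#     [[0, 0, 0, 1, 1, 2],
#      [1, 2, 3, 2, 3, 3]]
#
# so the lower bound r_i + r_j (scaled by 1 - buffer) is applied to each
# surviving inter-chain pair once.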


class SymmetricChainCOMPotential(FlatBottomPotential, DistancePotential):
    def compute_args(self, feats, parameters):
        atom_chain_id = (
            torch.bmm(
                feats["atom_to_token"].float(), feats["asym_id"].unsqueeze(-1).float()
            )
            .squeeze(-1)
            .long()
        )[0]
        atom_pad_mask = feats["atom_pad_mask"][0].bool()
        chain_sizes = torch.bincount(atom_chain_id[atom_pad_mask])
        single_ion_mask = chain_sizes > 1

        pair_index = feats["symmetric_chain_index"][0]
        pair_ion_mask = single_ion_mask[pair_index[0]] * single_ion_mask[pair_index[1]]
        pair_index = pair_index[:, pair_ion_mask]
        lower_bounds = torch.full(
            (pair_index.shape[1],),
            parameters["buffer"],
            dtype=torch.float32,
            device=pair_index.device,
        )
        upper_bounds = None
        k = torch.ones_like(lower_bounds)

        # The (atom_chain_id, atom_pad_mask) COM args make the distance act
        # on per-chain centers of mass rather than on individual atoms.
        return (
            pair_index,
            (k, lower_bounds, upper_bounds),
            (atom_chain_id, atom_pad_mask),
            None,
            None,
        )


class StereoBondPotential(FlatBottomPotential, AbsDihedralPotential):
    def compute_args(self, feats, parameters):
        stereo_bond_index = feats["stereo_bond_index"][0]
        stereo_bond_orientations = feats["stereo_bond_orientations"][0].bool()

        # Orientation True targets |phi| near pi (E/trans); False targets
        # |phi| near 0 (Z/cis), each padded by the buffer.
        lower_bounds = torch.zeros(
            stereo_bond_orientations.shape, device=stereo_bond_orientations.device
        )
        upper_bounds = torch.zeros(
            stereo_bond_orientations.shape, device=stereo_bond_orientations.device
        )
        lower_bounds[stereo_bond_orientations] = torch.pi - parameters["buffer"]
        upper_bounds[stereo_bond_orientations] = float("inf")
        lower_bounds[~stereo_bond_orientations] = float("-inf")
        upper_bounds[~stereo_bond_orientations] = parameters["buffer"]

        k = torch.ones_like(lower_bounds)

        return stereo_bond_index, (k, lower_bounds, upper_bounds), None, None, None


class ChiralAtomPotential(FlatBottomPotential, DihedralPotential):
    def compute_args(self, feats, parameters):
        chiral_atom_index = feats["chiral_atom_index"][0]
        chiral_atom_orientations = feats["chiral_atom_orientations"][0].bool()

        lower_bounds = torch.zeros(
            chiral_atom_orientations.shape, device=chiral_atom_orientations.device
        )
        upper_bounds = torch.zeros(
            chiral_atom_orientations.shape, device=chiral_atom_orientations.device
        )
        lower_bounds[chiral_atom_orientations] = parameters["buffer"]
        upper_bounds[chiral_atom_orientations] = float("inf")
        upper_bounds[~chiral_atom_orientations] = -1 * parameters["buffer"]
        lower_bounds[~chiral_atom_orientations] = float("-inf")

        k = torch.ones_like(lower_bounds)
        return chiral_atom_index, (k, lower_bounds, upper_bounds), None, None, None
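
# Note (illustrative): chirality is scored as a signed improper dihedral phi.
# With buffer b (the default 0.52360 rad in get_potentials() is 30 degrees):
#
#     orientation True  -> energy = k * max(b - phi, 0)   # push phi above +b
#     orientation False -> energy = k * max(phi + b, 0)   # push phi below -b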


class PlanarBondPotential(FlatBottomPotential, AbsDihedralPotential):
    def compute_args(self, feats, parameters):
        # Each planar (double) bond contributes two improper dihedrals, one
        # per bond end, both restrained to |phi| <= buffer.
        double_bond_index = feats["planar_bond_index"][0].T
        double_bond_improper_index = torch.tensor(
            [
                [1, 2, 3, 0],
                [4, 5, 0, 3],
            ],
            device=double_bond_index.device,
        ).T
        improper_index = (
            double_bond_index[:, double_bond_improper_index]
            .swapaxes(0, 1)
            .flatten(start_dim=1)
        )
        lower_bounds = None
        upper_bounds = torch.full(
            (improper_index.shape[1],),
            parameters["buffer"],
            device=improper_index.device,
        )
        k = torch.ones_like(upper_bounds)

        return improper_index, (k, lower_bounds, upper_bounds), None, None, None


class TemplateReferencePotential(FlatBottomPotential, ReferencePotential):
    def compute_args(self, feats, parameters):
        if "template_mask_cb" not in feats or "template_force" not in feats:
            return torch.empty([1, 0]), None, None, None, None

        template_mask = feats["template_mask_cb"][feats["template_force"]]
        if template_mask.shape[0] == 0:
            return torch.empty([1, 0]), None, None, None, None

        ref_coords = feats["template_cb"][feats["template_force"]].clone()
        ref_mask = feats["template_mask_cb"][feats["template_force"]].clone()
        # Map each token to its representative atom, and each atom back to
        # its token index.
        ref_atom_index = (
            torch.bmm(
                feats["token_to_rep_atom"].float(),
                torch.arange(
                    feats["atom_pad_mask"].shape[1],
                    device=feats["atom_pad_mask"].device,
                    dtype=torch.float32,
                )[None, :, None],
            )
            .squeeze(-1)
            .long()
        )[0]
        ref_token_index = (
            torch.bmm(
                feats["atom_to_token"].float(),
                feats["token_index"].unsqueeze(-1).float(),
            )
            .squeeze(-1)
            .long()
        )[0]

        index = torch.arange(
            template_mask.shape[-1], dtype=torch.long, device=template_mask.device
        )[None]
        upper_bounds = torch.full(
            template_mask.shape, float("inf"), device=index.device, dtype=torch.float32
        )
        ref_idxs = torch.argwhere(template_mask).T
        upper_bounds[ref_idxs.unbind()] = feats["template_force_threshold"][
            feats["template_force"]
        ][ref_idxs[0]]

        lower_bounds = None
        k = torch.ones_like(upper_bounds)
        return (
            index,
            (k, lower_bounds, upper_bounds),
            None,
            (ref_coords, ref_mask, ref_atom_index, ref_token_index),
            None,
        )


class ContactPotential(FlatBottomPotential, DistancePotential):
    def compute_args(self, feats, parameters):
        index = feats["contact_pair_index"][0]
        union_index = feats["contact_union_index"][0]
        negation_mask = feats["contact_negation_mask"][0]
        lower_bounds = None
        upper_bounds = feats["contact_thresholds"][0].clone()
        k = torch.ones_like(upper_bounds)
        return (
            index,
            (k, lower_bounds, upper_bounds),
            None,
            None,
            (negation_mask, union_index),
        )
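
# Note (illustrative): contacts sharing a union index are combined with the
# soft-min in Potential.compute(). With member energies e_g,
#
#     w_g = exp(-lambda * e_g) / sum_h exp(-lambda * e_h)
#     E   = sum_g w_g * e_g
#
# so as union_lambda grows the group energy approaches min_g e_g: satisfying
# any one contact in the group satisfies the union.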


def get_potentials(steering_args, boltz2=False):
    potentials = []
    if steering_args["fk_steering"] or steering_args["physical_guidance_update"]:
        potentials.extend(
            [
                SymmetricChainCOMPotential(
                    parameters={
                        "guidance_interval": 4,
                        "guidance_weight": (
                            0.5 if steering_args["physical_guidance_update"] else 0.0
                        ),
                        "resampling_weight": 0.5,
                        "buffer": ExponentialInterpolation(
                            start=1.0, end=5.0, alpha=-2.0
                        ),
                    }
                ),
                VDWOverlapPotential(
                    parameters={
                        "guidance_interval": 5,
                        "guidance_weight": (
                            PiecewiseStepFunction(thresholds=[0.4], values=[0.125, 0.0])
                            if steering_args["physical_guidance_update"]
                            else 0.0
                        ),
                        "resampling_weight": PiecewiseStepFunction(
                            thresholds=[0.6], values=[0.01, 0.0]
                        ),
                        "buffer": 0.225,
                    }
                ),
                ConnectionsPotential(
                    parameters={
                        "guidance_interval": 1,
                        "guidance_weight": (
                            0.15 if steering_args["physical_guidance_update"] else 0.0
                        ),
                        "resampling_weight": 1.0,
                        "buffer": 2.0,
                    }
                ),
                PoseBustersPotential(
                    parameters={
                        "guidance_interval": 1,
                        "guidance_weight": (
                            0.01 if steering_args["physical_guidance_update"] else 0.0
                        ),
                        "resampling_weight": 0.1,
                        "bond_buffer": 0.125,
                        "angle_buffer": 0.125,
                        "clash_buffer": 0.10,
                    }
                ),
                ChiralAtomPotential(
                    parameters={
                        "guidance_interval": 1,
                        "guidance_weight": (
                            0.1 if steering_args["physical_guidance_update"] else 0.0
                        ),
                        "resampling_weight": 1.0,
                        "buffer": 0.52360,
                    }
                ),
                StereoBondPotential(
                    parameters={
                        "guidance_interval": 1,
                        "guidance_weight": (
                            0.05 if steering_args["physical_guidance_update"] else 0.0
                        ),
                        "resampling_weight": 1.0,
                        "buffer": 0.52360,
                    }
                ),
                PlanarBondPotential(
                    parameters={
                        "guidance_interval": 1,
                        "guidance_weight": (
                            0.05 if steering_args["physical_guidance_update"] else 0.0
                        ),
                        "resampling_weight": 1.0,
                        "buffer": 0.26180,
                    }
                ),
            ]
        )
    if boltz2 and (
        steering_args["fk_steering"] or steering_args["contact_guidance_update"]
    ):
        potentials.extend(
            [
                ContactPotential(
                    parameters={
                        "guidance_interval": 4,
                        "guidance_weight": (
                            PiecewiseStepFunction(
                                thresholds=[0.25, 0.75], values=[0.0, 0.5, 1.0]
                            )
                            if steering_args["contact_guidance_update"]
                            else 0.0
                        ),
                        "resampling_weight": 1.0,
                        "union_lambda": ExponentialInterpolation(
                            start=8.0, end=0.0, alpha=-2.0
                        ),
                    }
                ),
                TemplateReferencePotential(
                    parameters={
                        "guidance_interval": 2,
                        "guidance_weight": (
                            0.1 if steering_args["contact_guidance_update"] else 0.0
                        ),
                        "resampling_weight": 1.0,
                    }
                ),
            ]
        )
    return potentials
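
# Example construction (illustrative; the keys mirror the lookups above):
#
#     steering_args = {
#         "fk_steering": True,
#         "physical_guidance_update": True,
#         "contact_guidance_update": False,
#     }
#     potentials = get_potentials(steering_args, boltz2=True)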