# Boltz2 / vb_potentials_potentials.py
# Hugging Face file page header: uploader lhallee,
# commit "Upload folder using huggingface_hub" (827d9ec, verified).
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, Set, List, Union
import torch
import numpy as np
from . import vb_const as const
from .vb_potentials_schedules import (
ParameterSchedule,
ExponentialInterpolation,
PiecewiseStepFunction,
)
from .vb_loss_diffusionv2 import weighted_rigid_align
class Potential(ABC):
    """Abstract base class for an energy term evaluated on coordinates.

    Subclasses provide three pieces:

    * ``compute_args`` -- which indices to evaluate and with which bounds,
    * ``compute_variable`` -- the geometric variable (and its gradient),
    * ``compute_function`` -- the energy as a function of that variable.

    ``compute`` and ``compute_gradient`` combine them into a per-sample
    energy and an analytic gradient with respect to the input coordinates.
    """

    def __init__(
        self,
        parameters: "Optional[Dict[str, Union[ParameterSchedule, float, int, bool]]]" = None,
    ):
        # Values may be constants or schedules; see ``compute_parameters``.
        self.parameters = parameters

    @staticmethod
    def _center_of_mass(coords, com_index, atom_pad_mask):
        """Average the unpadded atoms of ``coords`` per group in ``com_index``.

        ``include_self=False`` keeps the zero-initialised destination rows out
        of the average; with the default (``True``) every mean would be
        computed as ``sum / (n + 1)`` instead of ``sum / n``.
        """
        unpad_coords = coords[..., atom_pad_mask, :]
        unpad_com_index = com_index[atom_pad_mask]
        return torch.zeros(
            (*unpad_coords.shape[:-2], unpad_com_index.max() + 1, 3),
            dtype=unpad_coords.dtype,
            device=coords.device,
        ).scatter_reduce(
            -2,
            unpad_com_index.unsqueeze(-1).expand_as(unpad_coords),
            unpad_coords,
            "mean",
            include_self=False,
        )

    @staticmethod
    def _union_softmax(energy, union_index, union_lambda):
        """Return ``exp(-lambda * E)`` normalised within each union group.

        Entries whose group normaliser is exactly zero are set to zero
        instead of being left as the result of a division by zero.
        """
        neg_exp_energy = torch.exp(-1 * union_lambda * energy)
        Z = torch.zeros(
            (*energy.shape[:-1], union_index.max() + 1),
            dtype=neg_exp_energy.dtype,
            device=union_index.device,
        ).scatter_reduce(
            -1,
            union_index.expand_as(neg_exp_energy),
            neg_exp_energy,
            "sum",
        )
        group_Z = Z[..., union_index]
        softmax_energy = neg_exp_energy / group_Z
        softmax_energy[group_Z == 0] = 0
        return softmax_energy

    def compute(self, coords, feats, parameters):
        """Return the energy of ``coords`` under this potential.

        ``parameters`` is the mapping passed through to ``compute_args``;
        ``parameters["union_lambda"]`` is read only when ``compute_args``
        returns a union index. When ``compute_args`` yields an empty index
        (``index.shape[1] == 0``) a zero tensor of shape ``coords.shape[:-2]``
        is returned.
        """
        index, args, com_args, ref_args, operator_args = self.compute_args(
            feats, parameters
        )
        if index.shape[1] == 0:
            return torch.zeros(coords.shape[:-2], device=coords.device)

        if com_args is not None:
            # Replace atoms by the centres of mass of their groups.
            com_index, atom_pad_mask = com_args
            coords = self._center_of_mass(coords, com_index, atom_pad_mask)

        ref_coords, ref_mask = None, None
        if ref_args is not None:
            ref_coords, ref_mask, ref_atom_index, _ = ref_args
            coords = coords[..., ref_atom_index, :]

        negation_mask, union_index = None, None
        if operator_args is not None:
            negation_mask, union_index = operator_args

        value = self.compute_variable(
            coords,
            index,
            ref_coords=ref_coords,
            ref_mask=ref_mask,
            compute_gradient=False,
        )
        energy = self.compute_function(
            value, *args, negation_mask=negation_mask, compute_derivative=False
        )

        if union_index is not None:
            # Soft-min style combination of the terms inside each union group.
            softmax_energy = self._union_softmax(
                energy, union_index, parameters["union_lambda"]
            )
            return (energy * softmax_energy).sum(dim=-1)
        return energy.sum(dim=tuple(range(1, energy.dim())))

    def compute_gradient(self, coords, feats, parameters):
        """Return the analytic gradient of the energy w.r.t. ``coords``.

        When ``compute_args`` yields an empty index the result is
        ``torch.zeros_like(coords)``. When centre-of-mass arguments are used,
        the gradient of each centre is copied unchanged to every atom mapped
        to it via ``com_index`` (it is not divided by the group size). When
        reference arguments are used, gradients are gathered with
        ``ref_token_index``.
        """
        index, args, com_args, ref_args, operator_args = self.compute_args(
            feats, parameters
        )
        if index.shape[1] == 0:
            return torch.zeros_like(coords)

        com_index = None
        if com_args is not None:
            com_index, atom_pad_mask = com_args
            coords = self._center_of_mass(coords, com_index, atom_pad_mask)

        ref_coords, ref_mask, ref_token_index = None, None, None
        if ref_args is not None:
            ref_coords, ref_mask, ref_atom_index, ref_token_index = ref_args
            coords = coords[..., ref_atom_index, :]

        negation_mask, union_index = None, None
        if operator_args is not None:
            negation_mask, union_index = operator_args

        value, grad_value = self.compute_variable(
            coords,
            index,
            ref_coords=ref_coords,
            ref_mask=ref_mask,
            compute_gradient=True,
        )
        energy, dEnergy = self.compute_function(
            value, *args, negation_mask=negation_mask, compute_derivative=True
        )

        if union_index is not None:
            union_lambda = parameters["union_lambda"]
            softmax_energy = self._union_softmax(energy, union_index, union_lambda)
            # f: softmax-weighted energy of each union group.
            f = torch.zeros(
                (*energy.shape[:-1], union_index.max() + 1),
                dtype=energy.dtype,
                device=union_index.device,
            ).scatter_reduce(
                -1,
                union_index.expand_as(energy),
                energy * softmax_energy,
                "sum",
            )
            # Derivative of the softmax-weighted energy w.r.t. the variable.
            dEnergy = (
                dEnergy
                * softmax_energy
                * (1 + union_lambda * (energy - f[..., union_index]))
            )

        # Chain rule: dE/dvalue times dvalue/dcoords, one slab per index row.
        prod = dEnergy.tile(grad_value.shape[-3]).unsqueeze(
            -1
        ) * grad_value.flatten(start_dim=-3, end_dim=-2)
        if prod.dim() > 3:
            prod = prod.sum(dim=list(range(1, prod.dim() - 2)))
        # Accumulate the contributions of every constraint onto its atoms.
        grad_atom = torch.zeros_like(coords).scatter_reduce(
            -2,
            index.flatten(start_dim=0, end_dim=1)
            .unsqueeze(-1)
            .expand((*coords.shape[:-2], -1, 3)),
            prod,
            "sum",
        )

        if com_index is not None:
            grad_atom = grad_atom[..., com_index, :]
        elif ref_token_index is not None:
            grad_atom = grad_atom[..., ref_token_index, :]
        return grad_atom

    def compute_parameters(self, t):
        """Evaluate scheduled parameters at ``t``; constants pass through.

        Returns ``None`` when the potential was built without parameters.
        """
        if self.parameters is None:
            return None
        parameters = {
            name: parameter
            if not isinstance(parameter, ParameterSchedule)
            else parameter.compute(t)
            for name, parameter in self.parameters.items()
        }
        return parameters

    @abstractmethod
    def compute_function(
        self, value, *args, negation_mask=None, compute_derivative=False
    ):
        """Map the variable to an energy (and its derivative if requested)."""
        raise NotImplementedError

    @abstractmethod
    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Compute the geometric variable (and its gradient if requested).

        ``ref_coords`` and ``ref_mask`` are always passed by keyword from
        ``compute`` / ``compute_gradient``.
        """
        raise NotImplementedError

    @abstractmethod
    def compute_args(self, feats, parameters):
        """Return ``(index, args, com_args, ref_args, operator_args)``.

        This matches how ``compute`` and ``compute_gradient`` call it and
        unpack its result.
        """
        raise NotImplementedError

    def get_reference_coords(self, feats, parameters):
        """Default: no reference coordinates."""
        return None, None
class FlatBottomPotential(Potential):
    """Energy that is zero inside ``[lower, upper]`` and linear outside."""

    def compute_function(
        self,
        value,
        k,
        lower_bounds,
        upper_bounds,
        negation_mask=None,
        compute_derivative=False,
    ):
        """Evaluate the flat-bottom energy of ``value``.

        A missing bound (``None``) is treated as infinite. Where
        ``negation_mask`` is False the finite bound is flipped: an upper
        bound becomes a lower bound and vice versa. Returns ``energy`` or,
        when ``compute_derivative`` is set, ``(energy, dEnergy)``.
        """
        lo = (
            torch.full_like(value, float("-inf"))
            if lower_bounds is None
            else lower_bounds
        )
        hi = (
            torch.full_like(value, float("inf"))
            if upper_bounds is None
            else upper_bounds
        )
        # Work on private copies: the bounds are edited in place below.
        lo = lo.expand_as(value).clone()
        hi = hi.expand_as(value).clone()

        if negation_mask is not None:
            no_lower = torch.isneginf(lo)
            no_upper = torch.isposinf(hi)
            # A negated entry must have at least one open side.
            assert torch.all((no_lower + no_upper) + negation_mask)

            flip_upper = ~no_upper * ~negation_mask
            flip_lower = ~no_lower * ~negation_mask
            # "value <= upper" negated becomes "value >= upper".
            lo[flip_upper] = hi[flip_upper]
            hi[flip_upper] = float("inf")
            # "value >= lower" negated becomes "value <= lower".
            hi[flip_lower] = lo[flip_lower]
            lo[flip_lower] = float("-inf")

        below = value < lo
        above = value > hi

        energy = torch.zeros_like(value)
        energy[below] = (k * (lo - value))[below]
        energy[above] = (k * (value - hi))[above]
        if not compute_derivative:
            return energy

        # Slope is -k below the well, +k above it and 0 inside.
        dEnergy = torch.zeros_like(value)
        dEnergy[below] = -1 * k.expand_as(below)[below]
        dEnergy[above] = 1 * k.expand_as(above)[above]
        return energy, dEnergy
class ReferencePotential(Potential):
    """Variable: distance of selected coordinates to an aligned reference."""

    def compute_variable(
        self, coords, index, ref_coords, ref_mask, compute_gradient=False
    ):
        """Return the norm of ``coords[:, index] - aligned reference``.

        The reference is aligned onto the selected coordinates with
        ``weighted_rigid_align`` using ``ref_mask`` for both of its
        mask/weight arguments. With ``compute_gradient`` the masked unit
        displacement is returned as well, with an extra axis inserted at
        position 1.
        """
        selected = coords[:, index]
        aligned_ref_coords = weighted_rigid_align(
            ref_coords.float(),
            selected.float(),
            ref_mask,
            ref_mask,
        )
        displacement = selected - aligned_ref_coords
        distance = torch.linalg.norm(displacement, dim=-1)
        if not compute_gradient:
            return distance

        direction = displacement / distance.unsqueeze(-1)
        masked_direction = direction * ref_mask.unsqueeze(-1)
        return distance, masked_direction.unsqueeze(1)
class DistancePotential(Potential):
    """Variable: Euclidean distance between the two rows of ``index``."""

    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Return pair distances and, optionally, their gradients.

        The gradient stacks the unit vector for the first partner and its
        negation for the second partner along dimension 1.
        """
        first = coords.index_select(-2, index[0])
        second = coords.index_select(-2, index[1])
        displacement = first - second
        distance = torch.linalg.norm(displacement, dim=-1)
        if not compute_gradient:
            return distance

        unit = displacement / distance.unsqueeze(-1)
        return distance, torch.stack((unit, -1 * unit), dim=1)
class DihedralPotential(Potential):
    """Variable: signed dihedral angle of the four rows of ``index``."""

    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Return the signed dihedral ``phi`` for atoms i-j-k-l.

        With ``compute_gradient`` the gradients for i, j, k and l are
        returned as well, stacked along dimension 1.
        """

        def dot(u, v):
            # Dot product over the last axis, written as a batched matmul.
            return (u.unsqueeze(-2) @ v.unsqueeze(-1)).squeeze(-1, -2)

        pos_i, pos_j, pos_k, pos_l = (
            coords.index_select(-2, index[n]) for n in range(4)
        )
        r_ij = pos_i - pos_j
        r_kj = pos_k - pos_j
        r_kl = pos_k - pos_l

        # Normals of the i-j-k and j-k-l planes.
        n_ijk = torch.cross(r_ij, r_kj, dim=-1)
        n_jkl = torch.cross(r_kj, r_kl, dim=-1)

        r_kj_norm = torch.linalg.norm(r_kj, dim=-1)
        n_ijk_norm = torch.linalg.norm(n_ijk, dim=-1)
        n_jkl_norm = torch.linalg.norm(n_jkl, dim=-1)

        # The sign comes from the handedness of the two normals about r_kj.
        sign_phi = torch.sign(dot(r_kj, torch.cross(n_ijk, n_jkl, dim=-1)))
        cos_phi = dot(n_ijk, n_jkl) / (n_ijk_norm * n_jkl_norm)
        phi = sign_phi * torch.arccos(torch.clamp(cos_phi, -1 + 1e-8, 1 - 1e-8))
        if not compute_gradient:
            return phi

        # Projections of r_ij and r_kl onto r_kj, relative to |r_kj|^2.
        a = (dot(r_ij, r_kj) / (r_kj_norm**2)).unsqueeze(-1)
        b = (dot(r_kl, r_kj) / (r_kj_norm**2)).unsqueeze(-1)

        grad_i = n_ijk * (r_kj_norm / n_ijk_norm**2).unsqueeze(-1)
        grad_l = -1 * n_jkl * (r_kj_norm / n_jkl_norm**2).unsqueeze(-1)
        grad_j = (a - 1) * grad_i - b * grad_l
        grad_k = (b - 1) * grad_l - a * grad_i
        return phi, torch.stack((grad_i, grad_j, grad_k, grad_l), dim=1)
class AbsDihedralPotential(DihedralPotential):
    """Variable: absolute value of the dihedral angle."""

    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Return ``|phi|``; gradients are negated wherever ``phi < 0``."""
        result = super().compute_variable(
            coords, index, compute_gradient=compute_gradient
        )
        if not compute_gradient:
            return torch.abs(result)

        phi, grad = result
        # d|phi|/dx = sign(phi) * dphi/dx: flip entries whose angle is negative.
        negative = (phi < 0)[..., None, :, None]
        grad = torch.where(negative, -grad, grad)
        return torch.abs(phi), grad
class PoseBustersPotential(FlatBottomPotential, DistancePotential):
    """Distance bounds taken from the ``rdkit_*`` features, with buffers."""

    def compute_args(self, feats, parameters):
        """Build buffered lower/upper distance bounds for the RDKit pairs.

        Bounds are relaxed by ``bond_buffer`` / ``angle_buffer`` /
        ``clash_buffer`` depending on the pair masks, then tightened against
        a cutoff of 0.35 plus the mean van der Waals radius of the pair.
        """
        pair_index = feats["rdkit_bounds_index"][0]
        lower_bounds = feats["rdkit_lower_bounds"][0].clone()
        upper_bounds = feats["rdkit_upper_bounds"][0].clone()
        bond_mask = feats["rdkit_bounds_bond_mask"][0]
        angle_mask = feats["rdkit_bounds_angle_mask"][0]

        bond_buffer = parameters["bond_buffer"]
        angle_buffer = parameters["angle_buffer"]

        # Pairs flagged as both bond and angle use the tighter buffer.
        buffered_groups = (
            (bond_mask * ~angle_mask, bond_buffer),
            (~bond_mask * angle_mask, angle_buffer),
            (bond_mask * angle_mask, min(bond_buffer, angle_buffer)),
        )
        for group_mask, buffer in buffered_groups:
            lower_bounds[group_mask] *= 1.0 - buffer
            upper_bounds[group_mask] *= 1.0 + buffer

        # Pairs in neither mask only keep a (relaxed) lower bound.
        unflagged = ~bond_mask * ~angle_mask
        lower_bounds[unflagged] *= 1.0 - parameters["clash_buffer"]
        upper_bounds[unflagged] = float("inf")

        device = pair_index.device
        element_radii = torch.zeros(
            const.num_elements, dtype=torch.float32, device=device
        )
        element_radii[1:119] = torch.tensor(
            const.vdw_radii, dtype=torch.float32, device=device
        )
        # ``ref_element`` is multiplied against the radius table per atom.
        atom_vdw_radii = (
            feats["ref_element"].float() @ element_radii.unsqueeze(-1)
        ).squeeze(-1)[0]
        bond_cutoffs = 0.35 + atom_vdw_radii[pair_index].mean(dim=0)

        not_bond = ~bond_mask
        lower_bounds[not_bond] = torch.maximum(
            lower_bounds[not_bond], bond_cutoffs[not_bond]
        )
        upper_bounds[bond_mask] = torch.minimum(
            upper_bounds[bond_mask], bond_cutoffs[bond_mask]
        )

        k = torch.ones_like(lower_bounds)
        return pair_index, (k, lower_bounds, upper_bounds), None, None, None
class ConnectionsPotential(FlatBottomPotential, DistancePotential):
    """Upper distance bound for the pairs in ``connected_atom_index``."""

    def compute_args(self, feats, parameters):
        """Every connected pair gets ``parameters["buffer"]`` as upper bound."""
        pair_index = feats["connected_atom_index"][0]
        num_pairs = pair_index.shape[1]
        upper_bounds = torch.full(
            (num_pairs,), parameters["buffer"], device=pair_index.device
        )
        bounds = (torch.ones_like(upper_bounds), None, upper_bounds)
        return pair_index, bounds, None, None, None
class VDWOverlapPotential(FlatBottomPotential, DistancePotential):
    """Lower distance bound between atoms of different, unconnected chains."""

    def compute_args(self, feats, parameters):
        """Select atom pairs and give them a van der Waals lower bound.

        A pair is kept when both atoms are unpadded, both belong to chains
        with more than one unpadded atom, and their chains are neither the
        same nor listed in ``connected_chain_index``. The lower bound is the
        sum of the two radii scaled by ``1 - parameters["buffer"]``.
        """
        chain_of_atom = (
            torch.bmm(
                feats["atom_to_token"].float(), feats["asym_id"].unsqueeze(-1).float()
            )
            .squeeze(-1)
            .long()
        )[0]
        device = chain_of_atom.device

        real_atom = feats["atom_pad_mask"][0].bool()
        atoms_per_chain = torch.bincount(chain_of_atom[real_atom])
        # True for atoms whose chain holds more than one unpadded atom.
        in_multi_atom_chain = (atoms_per_chain > 1)[chain_of_atom]

        element_radii = torch.zeros(
            const.num_elements, dtype=torch.float32, device=device
        )
        element_radii[1:119] = torch.tensor(
            const.vdw_radii, dtype=torch.float32, device=device
        )
        atom_vdw_radii = (
            feats["ref_element"].float() @ element_radii.unsqueeze(-1)
        ).squeeze(-1)[0]

        # All unordered atom pairs (strict upper triangle).
        num_atoms = chain_of_atom.shape[0]
        all_pairs = torch.triu_indices(num_atoms, num_atoms, 1, device=device)
        first, second = all_pairs[0], all_pairs[1]

        both_real = real_atom[all_pairs].all(dim=0)
        both_multi_atom = in_multi_atom_chain[first] * in_multi_atom_chain[second]

        # Symmetric chain adjacency; the diagonal marks same-chain pairs.
        num_chains = chain_of_atom.max() + 1
        connected_chain_index = feats["connected_chain_index"][0]
        chain_a, chain_b = connected_chain_index[0], connected_chain_index[1]
        chains_linked = torch.eye(num_chains, device=device, dtype=torch.bool)
        chains_linked[chain_a, chain_b] = True
        chains_linked[chain_b, chain_a] = True
        pair_linked = chains_linked[chain_of_atom[first], chain_of_atom[second]]

        pair_index = all_pairs[:, both_real * both_multi_atom * ~pair_linked]

        lower_bounds = atom_vdw_radii[pair_index].sum(dim=0) * (
            1.0 - parameters["buffer"]
        )
        k = torch.ones_like(lower_bounds)
        return pair_index, (k, lower_bounds, None), None, None, None
class SymmetricChainCOMPotential(FlatBottomPotential, DistancePotential):
    """Lower bound on the centre-of-mass distance of symmetric chain pairs."""

    def compute_args(self, feats, parameters):
        """Return chain pairs, their lower bound and the COM arguments.

        Pairs from ``symmetric_chain_index`` are kept only when both chains
        have more than one unpadded atom. The third returned element makes
        the base class evaluate the distance on per-chain centres of mass.
        """
        chain_of_atom = (
            torch.bmm(
                feats["atom_to_token"].float(), feats["asym_id"].unsqueeze(-1).float()
            )
            .squeeze(-1)
            .long()
        )[0]
        real_atom = feats["atom_pad_mask"][0].bool()

        # Indexed by chain id: True when the chain has > 1 unpadded atom.
        multi_atom_chain = torch.bincount(chain_of_atom[real_atom]) > 1

        candidate_pairs = feats["symmetric_chain_index"][0]
        keep = (
            multi_atom_chain[candidate_pairs[0]]
            * multi_atom_chain[candidate_pairs[1]]
        )
        pair_index = candidate_pairs[:, keep]

        lower_bounds = torch.full(
            (pair_index.shape[1],),
            parameters["buffer"],
            dtype=torch.float32,
            device=pair_index.device,
        )
        bounds = (torch.ones_like(lower_bounds), lower_bounds, None)
        com_args = (chain_of_atom, real_atom)
        return pair_index, bounds, com_args, None, None
class StereoBondPotential(FlatBottomPotential, AbsDihedralPotential):
    """Bounds on ``|dihedral|`` for the entries of ``stereo_bond_index``."""

    def compute_args(self, feats, parameters):
        """Build one-sided bounds from ``stereo_bond_orientations``.

        Where the orientation flag is True the absolute dihedral must be at
        least ``pi - buffer``; where it is False it must be at most
        ``buffer``.
        """
        stereo_bond_index = feats["stereo_bond_index"][0]
        orientation = feats["stereo_bond_orientations"][0].bool()
        buffer = parameters["buffer"]

        # Start from fully open bounds, then close one side per entry.
        lower_bounds = torch.full(
            orientation.shape, float("-inf"), device=orientation.device
        )
        upper_bounds = torch.full(
            orientation.shape, float("inf"), device=orientation.device
        )
        lower_bounds[orientation] = torch.pi - buffer
        upper_bounds[~orientation] = buffer

        k = torch.ones_like(lower_bounds)
        return stereo_bond_index, (k, lower_bounds, upper_bounds), None, None, None
class ChiralAtomPotential(FlatBottomPotential, DihedralPotential):
    """Bounds on the signed dihedral for entries of ``chiral_atom_index``."""

    def compute_args(self, feats, parameters):
        """Build one-sided bounds from ``chiral_atom_orientations``.

        Where the orientation flag is True the dihedral must be at least
        ``buffer``; where it is False it must be at most ``-buffer``.
        """
        chiral_atom_index = feats["chiral_atom_index"][0]
        orientation = feats["chiral_atom_orientations"][0].bool()
        buffer = parameters["buffer"]

        # Start from fully open bounds, then close one side per entry.
        lower_bounds = torch.full(
            orientation.shape, float("-inf"), device=orientation.device
        )
        upper_bounds = torch.full(
            orientation.shape, float("inf"), device=orientation.device
        )
        lower_bounds[orientation] = buffer
        upper_bounds[~orientation] = -1 * buffer

        k = torch.ones_like(lower_bounds)
        return chiral_atom_index, (k, lower_bounds, upper_bounds), None, None, None
class PlanarBondPotential(FlatBottomPotential, AbsDihedralPotential):
    """Upper bound on two improper ``|dihedral|`` terms per planar bond."""

    def compute_args(self, feats, parameters):
        """Expand each ``planar_bond_index`` entry into two improper quads.

        The quads pick columns ``(1, 2, 3, 0)`` and ``(4, 5, 0, 3)`` of each
        entry; both are bounded above by ``parameters["buffer"]``.
        """
        bond_atoms = feats["planar_bond_index"][0].T
        # After the transpose: one row per quad position, one column per quad.
        quad_columns = torch.tensor(
            [
                [1, 2, 3, 0],
                [4, 5, 0, 3],
            ],
            device=bond_atoms.device,
        ).T
        gathered = bond_atoms[:, quad_columns]
        improper_index = gathered.swapaxes(0, 1).flatten(start_dim=1)

        upper_bounds = torch.full(
            (improper_index.shape[1],),
            parameters["buffer"],
            device=improper_index.device,
        )
        bounds = (torch.ones_like(upper_bounds), None, upper_bounds)
        return improper_index, bounds, None, None, None
class TemplateReferencePotential(FlatBottomPotential, ReferencePotential):
    """Upper bound on the deviation from forced template coordinates."""

    def compute_args(self, feats, parameters):
        """Collect reference coordinates and thresholds of forced templates.

        Returns an empty index (shape ``[1, 0]``) when the template features
        are missing or no template is selected by ``template_force``.
        """
        if not ("template_mask_cb" in feats and "template_force" in feats):
            return torch.empty([1, 0]), None, None, None, None

        forced = feats["template_force"]
        template_mask = feats["template_mask_cb"][forced]
        if template_mask.shape[0] == 0:
            return torch.empty([1, 0]), None, None, None, None

        ref_coords = feats["template_cb"][forced].clone()
        ref_mask = template_mask.clone()

        # Multiply ``token_to_rep_atom`` by 0..n_atoms-1 to get an atom
        # index per token.
        atom_positions = torch.arange(
            feats["atom_pad_mask"].shape[1],
            device=feats["atom_pad_mask"].device,
            dtype=torch.float32,
        )[None, :, None]
        ref_atom_index = (
            torch.bmm(feats["token_to_rep_atom"].float(), atom_positions)
            .squeeze(-1)
            .long()
        )[0]
        # Multiply ``atom_to_token`` by ``token_index`` to get a token index
        # per atom.
        ref_token_index = (
            torch.bmm(
                feats["atom_to_token"].float(),
                feats["token_index"].unsqueeze(-1).float(),
            )
            .squeeze(-1)
            .long()
        )[0]

        index = torch.arange(
            template_mask.shape[-1], dtype=torch.long, device=template_mask.device
        )[None]

        # Unmasked positions stay unbounded; masked-in positions receive the
        # threshold of the template they belong to.
        upper_bounds = torch.full(
            template_mask.shape, float("inf"), device=index.device, dtype=torch.float32
        )
        ref_idxs = torch.argwhere(template_mask).T
        thresholds = feats["template_force_threshold"][forced]
        upper_bounds[ref_idxs.unbind()] = thresholds[ref_idxs[0]]

        bounds = (torch.ones_like(upper_bounds), None, upper_bounds)
        ref_args = (ref_coords, ref_mask, ref_atom_index, ref_token_index)
        return index, bounds, None, ref_args, None
class ContactPotentital(FlatBottomPotential, DistancePotential):
    """Upper distance bound for contact pairs, with union/negation operators.

    The class name is misspelled; it is kept as-is because callers refer to
    it. A correctly spelled alias, ``ContactPotential``, is defined below.
    """

    def compute_args(self, feats, parameters):
        """Return contact pairs, their thresholds and the operator arguments.

        ``contact_thresholds`` is used as the upper bound. The negation mask
        and union index are passed as the fifth element so the base class
        applies them when combining the per-pair energies.
        """
        index = feats["contact_pair_index"][0]
        union_index = feats["contact_union_index"][0]
        negation_mask = feats["contact_negation_mask"][0]
        lower_bounds = None
        # Clone so later in-place edits of the bounds cannot touch ``feats``.
        upper_bounds = feats["contact_thresholds"][0].clone()
        k = torch.ones_like(upper_bounds)
        return (
            index,
            (k, lower_bounds, upper_bounds),
            None,
            None,
            (negation_mask, union_index),
        )


# Correctly spelled, backward-compatible alias for the class above.
ContactPotential = ContactPotentital
def get_potentials(steering_args, boltz2=False):
    """Build the list of steering potentials selected by ``steering_args``.

    The physical potentials are added when ``fk_steering`` or
    ``physical_guidance_update`` is set. The contact and template potentials
    are added when ``boltz2`` is true and ``fk_steering`` or
    ``contact_guidance_update`` is set. A potential's ``guidance_weight`` is
    ``0.0`` when its ``*_guidance_update`` flag is falsy.
    """
    potentials = []

    if steering_args["fk_steering"] or steering_args["physical_guidance_update"]:
        physical = steering_args["physical_guidance_update"]

        def physical_weight(weight):
            # Guidance is switched off unless physical guidance is enabled.
            return weight if physical else 0.0

        vdw_guidance = (
            PiecewiseStepFunction(thresholds=[0.4], values=[0.125, 0.0])
            if physical
            else 0.0
        )

        potentials += [
            SymmetricChainCOMPotential(
                parameters={
                    "guidance_interval": 4,
                    "guidance_weight": physical_weight(0.5),
                    "resampling_weight": 0.5,
                    "buffer": ExponentialInterpolation(
                        start=1.0, end=5.0, alpha=-2.0
                    ),
                }
            ),
            VDWOverlapPotential(
                parameters={
                    "guidance_interval": 5,
                    "guidance_weight": vdw_guidance,
                    "resampling_weight": PiecewiseStepFunction(
                        thresholds=[0.6], values=[0.01, 0.0]
                    ),
                    "buffer": 0.225,
                }
            ),
            ConnectionsPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": physical_weight(0.15),
                    "resampling_weight": 1.0,
                    "buffer": 2.0,
                }
            ),
            PoseBustersPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": physical_weight(0.01),
                    "resampling_weight": 0.1,
                    "bond_buffer": 0.125,
                    "angle_buffer": 0.125,
                    "clash_buffer": 0.10,
                }
            ),
            ChiralAtomPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": physical_weight(0.1),
                    "resampling_weight": 1.0,
                    "buffer": 0.52360,
                }
            ),
            StereoBondPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": physical_weight(0.05),
                    "resampling_weight": 1.0,
                    "buffer": 0.52360,
                }
            ),
            PlanarBondPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": physical_weight(0.05),
                    "resampling_weight": 1.0,
                    "buffer": 0.26180,
                }
            ),
        ]

    if boltz2 and (
        steering_args["fk_steering"] or steering_args["contact_guidance_update"]
    ):
        contact = steering_args["contact_guidance_update"]
        contact_guidance = (
            PiecewiseStepFunction(thresholds=[0.25, 0.75], values=[0.0, 0.5, 1.0])
            if contact
            else 0.0
        )

        potentials += [
            ContactPotentital(
                parameters={
                    "guidance_interval": 4,
                    "guidance_weight": contact_guidance,
                    "resampling_weight": 1.0,
                    "union_lambda": ExponentialInterpolation(
                        start=8.0, end=0.0, alpha=-2.0
                    ),
                }
            ),
            TemplateReferencePotential(
                parameters={
                    "guidance_interval": 2,
                    "guidance_weight": 0.1 if contact else 0.0,
                    "resampling_weight": 1.0,
                }
            ),
        ]

    return potentials