from __future__ import annotations

from math import sqrt

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn import Module

from . import vb_const as const
from . import vb_layers_initialize as init
from .vb_loss_diffusionv2 import (
    smooth_lddt_loss,
    weighted_rigid_align,
)
from .vb_modules_encodersv2 import (
    AtomAttentionDecoder,
    AtomAttentionEncoder,
    SingleConditioning,
)
from .vb_modules_transformersv2 import (
    DiffusionTransformer,
)
from .vb_modules_utils import (
    LinearNoBias,
    center_random_augmentation,
    compute_random_augmentation,
    default,
    log,
)
from .vb_potentials_potentials import get_potentials


class DiffusionModule(Module):
    """Diffusion module.

    Denoises noisy atom coordinates: single conditioning on the trunk embeddings,
    an atom attention encoder, a token-level diffusion transformer, and an atom
    attention decoder that predicts a per-atom coordinate update.
    """

    def __init__(
        self,
        token_s: int,
        atom_s: int,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        sigma_data: int = 16,
        dim_fourier: int = 256,
        atom_encoder_depth: int = 3,
        atom_encoder_heads: int = 4,
        token_transformer_depth: int = 24,
        token_transformer_heads: int = 8,
        atom_decoder_depth: int = 3,
        atom_decoder_heads: int = 4,
        conditioning_transition_layers: int = 2,
        activation_checkpointing: bool = False,
        transformer_post_ln: bool = False,
    ) -> None:
        super().__init__()

        self.atoms_per_window_queries = atoms_per_window_queries
        self.atoms_per_window_keys = atoms_per_window_keys
        self.sigma_data = sigma_data
        self.activation_checkpointing = activation_checkpointing

        self.single_conditioner = SingleConditioning(
            sigma_data=sigma_data,
            token_s=token_s,
            dim_fourier=dim_fourier,
            num_transitions=conditioning_transition_layers,
        )

        self.atom_attention_encoder = AtomAttentionEncoder(
            atom_s=atom_s,
            token_s=token_s,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_encoder_depth=atom_encoder_depth,
            atom_encoder_heads=atom_encoder_heads,
            structure_prediction=True,
            activation_checkpointing=activation_checkpointing,
            transformer_post_layer_norm=transformer_post_ln,
        )

        self.s_to_a_linear = nn.Sequential(
            nn.LayerNorm(2 * token_s), LinearNoBias(2 * token_s, 2 * token_s)
        )
        init.final_init_(self.s_to_a_linear[1].weight)

        self.token_transformer = DiffusionTransformer(
            dim=2 * token_s,
            dim_single_cond=2 * token_s,
            depth=token_transformer_depth,
            heads=token_transformer_heads,
            activation_checkpointing=activation_checkpointing,
        )

        self.a_norm = nn.LayerNorm(2 * token_s)

        self.atom_attention_decoder = AtomAttentionDecoder(
            atom_s=atom_s,
            token_s=token_s,
            attn_window_queries=atoms_per_window_queries,
            attn_window_keys=atoms_per_window_keys,
            atom_decoder_depth=atom_decoder_depth,
            atom_decoder_heads=atom_decoder_heads,
            activation_checkpointing=activation_checkpointing,
        )

    def forward(
        self,
        s_inputs,
        s_trunk,
        r_noisy,
        times,
        feats,
        diffusion_conditioning,
        multiplicity=1,
    ):
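        """Predict a per-atom coordinate update for the noised coordinates r_noisy.

        Conditions on the trunk single embeddings (s_trunk, s_inputs) and the noise
        embedding (times), encodes atoms around r_noisy, runs the token-level
        diffusion transformer, and decodes the update with the atom decoder.
        """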
        if self.activation_checkpointing and self.training:
            s, normed_fourier = torch.utils.checkpoint.checkpoint(
                self.single_conditioner,
                times,
                s_trunk.repeat_interleave(multiplicity, 0),
                s_inputs.repeat_interleave(multiplicity, 0),
            )
        else:
            s, normed_fourier = self.single_conditioner(
                times,
                s_trunk.repeat_interleave(multiplicity, 0),
                s_inputs.repeat_interleave(multiplicity, 0),
            )
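
        # Encode atoms around the noisy coordinates; q_skip, c_skip and to_keys are
        # reused by the atom attention decoder below.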
        a, q_skip, c_skip, to_keys = self.atom_attention_encoder(
            feats=feats,
            q=diffusion_conditioning["q"].float(),
            c=diffusion_conditioning["c"].float(),
            atom_enc_bias=diffusion_conditioning["atom_enc_bias"].float(),
            to_keys=diffusion_conditioning["to_keys"],
            r=r_noisy,
            multiplicity=multiplicity,
        )
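
        # Inject the single conditioning into the token activations, then run the
        # token-level diffusion transformer with the precomputed pair bias.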
        a = a + self.s_to_a_linear(s)

        mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        a = self.token_transformer(
            a,
            mask=mask.float(),
            s=s,
            bias=diffusion_conditioning["token_trans_bias"].float(),
            multiplicity=multiplicity,
        )
        a = self.a_norm(a)
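
        # Decode token activations back to a per-atom coordinate update.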
        r_update = self.atom_attention_decoder(
            a=a,
            q=q_skip,
            c=c_skip,
            atom_dec_bias=diffusion_conditioning["atom_dec_bias"].float(),
            feats=feats,
            multiplicity=multiplicity,
            to_keys=to_keys,
        )

        return r_update


class AtomDiffusion(Module):
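    """EDM-style atom diffusion wrapper around DiffusionModule.

    Provides the preconditioning functions (c_skip, c_out, c_in, c_noise), the
    rho noise schedule, a sampling loop with optional Feynman-Kac steering and
    potential-based guidance, and the training loss (weighted MSE plus an
    optional smooth lDDT term).
    """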

    def __init__(
        self,
        score_model_args,
        num_sampling_steps: int = 5,
        sigma_min: float = 0.0004,
        sigma_max: float = 160.0,
        sigma_data: float = 16.0,
        rho: float = 7,
        P_mean: float = -1.2,
        P_std: float = 1.5,
        gamma_0: float = 0.8,
        gamma_min: float = 1.0,
        noise_scale: float = 1.003,
        step_scale: float = 1.5,
        step_scale_random: list | None = None,
        coordinate_augmentation: bool = True,
        coordinate_augmentation_inference=None,
        compile_score: bool = False,
        alignment_reverse_diff: bool = False,
        synchronize_sigmas: bool = False,
    ):
        super().__init__()
        self.score_model = DiffusionModule(
            **score_model_args,
        )
        if compile_score:
            self.score_model = torch.compile(
                self.score_model, dynamic=False, fullgraph=False
            )

        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.rho = rho
        self.P_mean = P_mean
        self.P_std = P_std
        self.num_sampling_steps = num_sampling_steps
        self.gamma_0 = gamma_0
        self.gamma_min = gamma_min
        self.noise_scale = noise_scale
        self.step_scale = step_scale
        self.step_scale_random = step_scale_random
        self.coordinate_augmentation = coordinate_augmentation
        self.coordinate_augmentation_inference = (
            coordinate_augmentation_inference
            if coordinate_augmentation_inference is not None
            else coordinate_augmentation
        )
        self.alignment_reverse_diff = alignment_reverse_diff
        self.synchronize_sigmas = synchronize_sigmas

        self.token_s = score_model_args["token_s"]
        self.register_buffer("zero", torch.tensor(0.0), persistent=False)

    @property
    def device(self):
        return next(self.score_model.parameters()).device
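
    # EDM-style preconditioning (after Karras et al., 2022): the denoiser is
    #   D(x, sigma) = c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma)),
    # with sigma_data setting the coordinate scale. See preconditioned_network_forward.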
    def c_skip(self, sigma):
        return (self.sigma_data**2) / (sigma**2 + self.sigma_data**2)

    def c_out(self, sigma):
        return sigma * self.sigma_data / torch.sqrt(self.sigma_data**2 + sigma**2)

    def c_in(self, sigma):
        return 1 / torch.sqrt(sigma**2 + self.sigma_data**2)

    def c_noise(self, sigma):
        return log(sigma / self.sigma_data) * 0.25

    def preconditioned_network_forward(
        self,
        noised_atom_coords,
        sigma,
        network_condition_kwargs: dict,
    ):
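        """Apply the score model with preconditioned inputs and return denoised coordinates."""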
        batch, device = noised_atom_coords.shape[0], noised_atom_coords.device

        if isinstance(sigma, float):
            sigma = torch.full((batch,), sigma, device=device)

        padded_sigma = rearrange(sigma, "b -> b 1 1")

        r_update = self.score_model(
            r_noisy=self.c_in(padded_sigma) * noised_atom_coords,
            times=self.c_noise(sigma),
            **network_condition_kwargs,
        )

        denoised_coords = (
            self.c_skip(padded_sigma) * noised_atom_coords
            + self.c_out(padded_sigma) * r_update
        )
        return denoised_coords

    def sample_schedule(self, num_sampling_steps=None):
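        """Noise-level schedule for sampling.

        Rho-spaced between sigma_max and sigma_min, scaled by sigma_data, and padded
        with a final sigma of zero:
        sigma_i = sigma_data * (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho.
        """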
        num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
        inv_rho = 1 / self.rho

        steps = torch.arange(
            num_sampling_steps, device=self.device, dtype=torch.float32
        )
        sigmas = (
            self.sigma_max**inv_rho
            + steps
            / (num_sampling_steps - 1)
            * (self.sigma_min**inv_rho - self.sigma_max**inv_rho)
        ) ** self.rho

        sigmas = sigmas * self.sigma_data

        sigmas = F.pad(sigmas, (0, 1), value=0.0)
        return sigmas

    def sample(
        self,
        atom_mask,
        num_sampling_steps=None,
        multiplicity=1,
        max_parallel_samples=None,
        steering_args=None,
        **network_condition_kwargs,
    ):
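        """Sample atom coordinates by reverse diffusion.

        Each step applies a random rigid augmentation, inflates the noise level
        (stochastic churn controlled by gamma_0, gamma_min and noise_scale), runs
        the preconditioned score model, optionally applies Feynman-Kac resampling
        and potential-based guidance, and takes a step_scale-scaled Euler step to
        the next noise level.
        """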
        if steering_args is None:
            # No steering by default; later checks index these keys directly.
            steering_args = {
                "fk_steering": False,
                "physical_guidance_update": False,
                "contact_guidance_update": False,
            }
        if (
            steering_args["fk_steering"]
            or steering_args["physical_guidance_update"]
            or steering_args["contact_guidance_update"]
        ):
            potentials = get_potentials(steering_args, boltz2=True)

        if steering_args["fk_steering"]:
            multiplicity = multiplicity * steering_args["num_particles"]
            energy_traj = torch.empty((multiplicity, 0), device=self.device)
            resample_weights = torch.ones(multiplicity, device=self.device).reshape(
                -1, steering_args["num_particles"]
            )
        if (
            steering_args["physical_guidance_update"]
            or steering_args["contact_guidance_update"]
        ):
            scaled_guidance_update = torch.zeros(
                (multiplicity, *atom_mask.shape[1:], 3),
                dtype=torch.float32,
                device=self.device,
            )
        if max_parallel_samples is None:
            max_parallel_samples = multiplicity

        num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        shape = (*atom_mask.shape, 3)

        sigmas = self.sample_schedule(num_sampling_steps)
        gammas = torch.where(sigmas > self.gamma_min, self.gamma_0, 0.0)
        sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[1:]))
        if self.training and self.step_scale_random is not None:
            step_scale = np.random.choice(self.step_scale_random)
        else:
            step_scale = self.step_scale

        init_sigma = sigmas[0]
        atom_coords = init_sigma * torch.randn(shape, device=self.device)
        token_repr = None
        atom_coords_denoised = None

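        # Main reverse-diffusion loop over consecutive noise levels (sigma_{i-1}, sigma_i).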
        for step_idx, (sigma_tm, sigma_t, gamma) in enumerate(sigmas_and_gammas):
            random_R, random_tr = compute_random_augmentation(
                multiplicity, device=atom_coords.device, dtype=atom_coords.dtype
            )
            atom_coords = atom_coords - atom_coords.mean(dim=-2, keepdims=True)
            atom_coords = (
                torch.einsum("bmd,bds->bms", atom_coords, random_R) + random_tr
            )
            if atom_coords_denoised is not None:
                atom_coords_denoised -= atom_coords_denoised.mean(dim=-2, keepdims=True)
                atom_coords_denoised = (
                    torch.einsum("bmd,bds->bms", atom_coords_denoised, random_R)
                    + random_tr
                )
            if (
                steering_args["physical_guidance_update"]
                or steering_args["contact_guidance_update"]
            ) and scaled_guidance_update is not None:
                scaled_guidance_update = torch.einsum(
                    "bmd,bds->bms", scaled_guidance_update, random_R
                )
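
            # Stochastic churn: inflate the noise level to t_hat = sigma_{i-1} * (1 + gamma)
            # and add matching extra noise (scaled by noise_scale) before denoising.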
            sigma_tm, sigma_t, gamma = sigma_tm.item(), sigma_t.item(), gamma.item()

            t_hat = sigma_tm * (1 + gamma)
            steering_t = 1.0 - (step_idx / num_sampling_steps)
            noise_var = self.noise_scale**2 * (t_hat**2 - sigma_tm**2)
            eps = sqrt(noise_var) * torch.randn(shape, device=self.device)
            atom_coords_noisy = atom_coords + eps

            with torch.no_grad():
                atom_coords_denoised = torch.zeros_like(atom_coords_noisy)
                sample_ids = torch.arange(multiplicity).to(atom_coords_noisy.device)
                # Denoise in chunks of at most max_parallel_samples samples.
                sample_ids_chunks = sample_ids.split(max_parallel_samples)

                for sample_ids_chunk in sample_ids_chunks:
                    atom_coords_denoised_chunk = self.preconditioned_network_forward(
                        atom_coords_noisy[sample_ids_chunk],
                        t_hat,
                        network_condition_kwargs=dict(
                            multiplicity=sample_ids_chunk.numel(),
                            **network_condition_kwargs,
                        ),
                    )
                    atom_coords_denoised[sample_ids_chunk] = atom_coords_denoised_chunk
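
                # Feynman-Kac steering: score the denoised particles with the steering
                # potentials and convert energy differences into resampling weights.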
                if steering_args["fk_steering"] and (
                    (
                        step_idx % steering_args["fk_resampling_interval"] == 0
                        and noise_var > 0
                    )
                    or step_idx == num_sampling_steps - 1
                ):
                    energy = torch.zeros(multiplicity, device=self.device)
                    for potential in potentials:
                        parameters = potential.compute_parameters(steering_t)
                        if parameters["resampling_weight"] > 0:
                            component_energy = potential.compute(
                                atom_coords_denoised,
                                network_condition_kwargs["feats"],
                                parameters,
                            )
                            energy += parameters["resampling_weight"] * component_energy
                    energy_traj = torch.cat((energy_traj, energy.unsqueeze(1)), dim=1)

                    if step_idx == 0:
                        log_G = -1 * energy
                    else:
                        log_G = energy_traj[:, -2] - energy_traj[:, -1]

                    if (
                        steering_args["physical_guidance_update"]
                        or steering_args["contact_guidance_update"]
                    ) and noise_var > 0:
                        ll_difference = (
                            eps**2 - (eps + scaled_guidance_update) ** 2
                        ).sum(dim=(-1, -2)) / (2 * noise_var)
                    else:
                        ll_difference = torch.zeros_like(energy)

                    resample_weights = F.softmax(
                        (ll_difference + steering_args["fk_lambda"] * log_G).reshape(
                            -1, steering_args["num_particles"]
                        ),
                        dim=1,
                    )
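
                # Potential-based guidance: take num_gd_steps gradient steps on the
                # steering potentials and apply the accumulated update to the
                # denoised prediction.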
                if (
                    steering_args["physical_guidance_update"]
                    or steering_args["contact_guidance_update"]
                ) and step_idx < num_sampling_steps - 1:
                    guidance_update = torch.zeros_like(atom_coords_denoised)
                    for guidance_step in range(steering_args["num_gd_steps"]):
                        energy_gradient = torch.zeros_like(atom_coords_denoised)
                        for potential in potentials:
                            parameters = potential.compute_parameters(steering_t)
                            if (
                                parameters["guidance_weight"] > 0
                                and guidance_step % parameters["guidance_interval"] == 0
                            ):
                                energy_gradient += parameters[
                                    "guidance_weight"
                                ] * potential.compute_gradient(
                                    atom_coords_denoised + guidance_update,
                                    network_condition_kwargs["feats"],
                                    parameters,
                                )
                        guidance_update -= energy_gradient
                    atom_coords_denoised += guidance_update
                    scaled_guidance_update = (
                        guidance_update
                        * -1
                        * self.step_scale
                        * (sigma_t - t_hat)
                        / t_hat
                    )
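
                # FK resampling: within each sample, draw particle indices according to
                # the resampling weights and gather all per-particle state.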
                if steering_args["fk_steering"] and (
                    (
                        step_idx % steering_args["fk_resampling_interval"] == 0
                        and noise_var > 0
                    )
                    or step_idx == num_sampling_steps - 1
                ):
                    resample_indices = (
                        torch.multinomial(
                            resample_weights,
                            resample_weights.shape[1]
                            if step_idx < num_sampling_steps - 1
                            else 1,
                            replacement=True,
                        )
                        + resample_weights.shape[1]
                        * torch.arange(
                            resample_weights.shape[0], device=resample_weights.device
                        ).unsqueeze(-1)
                    ).flatten()

                    atom_coords = atom_coords[resample_indices]
                    atom_coords_noisy = atom_coords_noisy[resample_indices]
                    atom_mask = atom_mask[resample_indices]
                    if atom_coords_denoised is not None:
                        atom_coords_denoised = atom_coords_denoised[resample_indices]
                    energy_traj = energy_traj[resample_indices]
                    if (
                        steering_args["physical_guidance_update"]
                        or steering_args["contact_guidance_update"]
                    ):
                        scaled_guidance_update = scaled_guidance_update[
                            resample_indices
                        ]
                    if token_repr is not None:
                        token_repr = token_repr[resample_indices]
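
            # Optionally rigid-align the noisy coordinates onto the denoised prediction,
            # then take the (step_scale-scaled) Euler step toward sigma_t.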
            if self.alignment_reverse_diff:
                with torch.autocast("cuda", enabled=False):
                    atom_coords_noisy = weighted_rigid_align(
                        atom_coords_noisy.float(),
                        atom_coords_denoised.float(),
                        atom_mask.float(),
                        atom_mask.float(),
                    )

                atom_coords_noisy = atom_coords_noisy.to(atom_coords_denoised)

            denoised_over_sigma = (atom_coords_noisy - atom_coords_denoised) / t_hat
            atom_coords_next = (
                atom_coords_noisy + step_scale * (sigma_t - t_hat) * denoised_over_sigma
            )

            atom_coords = atom_coords_next

        return dict(sample_atom_coords=atom_coords, diff_token_repr=token_repr)

    def loss_weight(self, sigma):
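        """EDM loss weight: (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""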
        return (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2)

    def noise_distribution(self, batch_size):
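        """Training noise levels: sigma = sigma_data * exp(P_mean + P_std * N(0, 1))."""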
        return (
            self.sigma_data
            * (
                self.P_mean
                + self.P_std * torch.randn((batch_size,), device=self.device)
            ).exp()
        )

    def forward(
        self,
        s_inputs,
        s_trunk,
        feats,
        diffusion_conditioning,
        multiplicity=1,
    ):
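        """Training forward pass.

        Noises the (augmented) ground-truth coordinates at sampled noise levels and
        returns the denoised prediction together with the sigmas and the coordinates
        used as the regression target.
        """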
        batch_size = feats["coords"].shape[0] // multiplicity

        if self.synchronize_sigmas:
            sigmas = self.noise_distribution(batch_size).repeat_interleave(
                multiplicity, 0
            )
        else:
            sigmas = self.noise_distribution(batch_size * multiplicity)
        padded_sigmas = rearrange(sigmas, "b -> b 1 1")

        atom_coords = feats["coords"]

        atom_mask = feats["atom_pad_mask"]
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        atom_coords = center_random_augmentation(
            atom_coords, atom_mask, augmentation=self.coordinate_augmentation
        )

        noise = torch.randn_like(atom_coords)
        noised_atom_coords = atom_coords + padded_sigmas * noise

        denoised_atom_coords = self.preconditioned_network_forward(
            noised_atom_coords,
            sigmas,
            network_condition_kwargs={
                "s_inputs": s_inputs,
                "s_trunk": s_trunk,
                "feats": feats,
                "multiplicity": multiplicity,
                "diffusion_conditioning": diffusion_conditioning,
            },
        )

        return {
            "denoised_atom_coords": denoised_atom_coords,
            "sigmas": sigmas,
            "aligned_true_atom_coords": atom_coords,
        }

    def compute_loss(
        self,
        feats,
        out_dict,
        add_smooth_lddt_loss=True,
        nucleotide_loss_weight=5.0,
        ligand_loss_weight=10.0,
        multiplicity=1,
        filter_by_plddt=0.0,
    ):
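        """Diffusion training loss.

        Sigma-weighted MSE against the rigid-aligned ground truth, with nucleotide
        and ligand atoms upweighted, plus an optional smooth lDDT term.
        """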
        with torch.autocast("cuda", enabled=False):
            denoised_atom_coords = out_dict["denoised_atom_coords"].float()
            sigmas = out_dict["sigmas"].float()

            resolved_atom_mask_uni = feats["atom_resolved_mask"].float()

            if filter_by_plddt > 0:
                plddt_mask = feats["plddt"] > filter_by_plddt
                resolved_atom_mask_uni = resolved_atom_mask_uni * plddt_mask.float()

            resolved_atom_mask = resolved_atom_mask_uni.repeat_interleave(
                multiplicity, 0
            )

            align_weights = denoised_atom_coords.new_ones(denoised_atom_coords.shape[:2])
            atom_type = (
                torch.bmm(
                    feats["atom_to_token"].float(),
                    feats["mol_type"].unsqueeze(-1).float(),
                )
                .squeeze(-1)
                .long()
            )
            atom_type_mult = atom_type.repeat_interleave(multiplicity, 0)
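
            # Per-atom alignment/loss weights: 1 for protein atoms, 1 + nucleotide_loss_weight
            # for DNA/RNA atoms, and 1 + ligand_loss_weight for non-polymer (ligand) atoms.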
            align_weights = (
                align_weights
                * (
                    1
                    + nucleotide_loss_weight
                    * (
                        torch.eq(atom_type_mult, const.chain_type_ids["DNA"]).float()
                        + torch.eq(atom_type_mult, const.chain_type_ids["RNA"]).float()
                    )
                    + ligand_loss_weight
                    * torch.eq(
                        atom_type_mult, const.chain_type_ids["NONPOLYMER"]
                    ).float()
                ).float()
            )

            atom_coords = out_dict["aligned_true_atom_coords"].float()
            atom_coords_aligned_ground_truth = weighted_rigid_align(
                atom_coords.detach(),
                denoised_atom_coords.detach(),
                align_weights.detach(),
                mask=feats["atom_resolved_mask"]
                .float()
                .repeat_interleave(multiplicity, 0)
                .detach(),
            )

            atom_coords_aligned_ground_truth = atom_coords_aligned_ground_truth.to(
                denoised_atom_coords
            )
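
            # Weighted MSE per atom; the factor of 3 in the denominator averages over
            # the x, y, z coordinate components.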
            mse_loss = (
                (denoised_atom_coords - atom_coords_aligned_ground_truth) ** 2
            ).sum(dim=-1)
            mse_loss = torch.sum(
                mse_loss * align_weights * resolved_atom_mask, dim=-1
            ) / (torch.sum(3 * align_weights * resolved_atom_mask, dim=-1) + 1e-5)

            loss_weights = self.loss_weight(sigmas)
            mse_loss = (mse_loss * loss_weights).mean()

            total_loss = mse_loss

            lddt_loss = self.zero
            if add_smooth_lddt_loss:
                lddt_loss = smooth_lddt_loss(
                    denoised_atom_coords,
                    feats["coords"],
                    torch.eq(atom_type, const.chain_type_ids["DNA"]).float()
                    + torch.eq(atom_type, const.chain_type_ids["RNA"]).float(),
                    coords_mask=resolved_atom_mask_uni,
                    multiplicity=multiplicity,
                )

                total_loss = total_loss + lddt_loss

            loss_breakdown = {
                "mse_loss": mse_loss,
                "smooth_lddt_loss": lddt_loss,
            }

        return {"loss": total_loss, "loss_breakdown": loss_breakdown}