# Boltz2 / vb_modules_diffusionv2.py
# Uploaded by lhallee using huggingface_hub (commit 827d9ec).
# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
from __future__ import annotations
from math import sqrt
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from einops import rearrange
from torch import nn
from torch.nn import Module
from . import vb_const as const
from . import vb_layers_initialize as init
from .vb_loss_diffusionv2 import (
smooth_lddt_loss,
weighted_rigid_align,
)
from .vb_modules_encodersv2 import (
AtomAttentionDecoder,
AtomAttentionEncoder,
SingleConditioning,
)
from .vb_modules_transformersv2 import (
DiffusionTransformer,
)
from .vb_modules_utils import (
LinearNoBias,
center_random_augmentation,
compute_random_augmentation,
default,
log,
)
from .vb_potentials_potentials import get_potentials
class DiffusionModule(Module):
    """Diffusion score network.

    Conditions on a noise-level input and the single representations, encodes
    the noisy atom coordinates into token activations with sequence-local
    (windowed) atom attention, runs a token-level transformer, and decodes the
    result back into a per-atom coordinate update.
    """

    def __init__(
        self,
        token_s: int,
        atom_s: int,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        sigma_data: int = 16,
        dim_fourier: int = 256,
        atom_encoder_depth: int = 3,
        atom_encoder_heads: int = 4,
        token_transformer_depth: int = 24,
        token_transformer_heads: int = 8,
        atom_decoder_depth: int = 3,
        atom_decoder_heads: int = 4,
        conditioning_transition_layers: int = 2,
        activation_checkpointing: bool = False,
        transformer_post_ln: bool = False,
    ) -> None:
        super().__init__()
        self.atoms_per_window_queries = atoms_per_window_queries
        self.atoms_per_window_keys = atoms_per_window_keys
        self.sigma_data = sigma_data
        self.activation_checkpointing = activation_checkpointing

        # Token activations and the single conditioning both have this width.
        token_width = 2 * token_s

        # NOTE: submodules are built in a fixed order; changing it would change
        # both state-dict layout and the RNG stream used for initialisation.
        self.single_conditioner = SingleConditioning(
            sigma_data=sigma_data,
            token_s=token_s,
            dim_fourier=dim_fourier,
            num_transitions=conditioning_transition_layers,
        )
        self.atom_attention_encoder = AtomAttentionEncoder(
            atom_s=atom_s,
            token_s=token_s,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_encoder_depth=atom_encoder_depth,
            atom_encoder_heads=atom_encoder_heads,
            structure_prediction=True,
            activation_checkpointing=activation_checkpointing,
            transformer_post_layer_norm=transformer_post_ln,
        )

        # Projects the single conditioning into token-activation space.
        self.s_to_a_linear = nn.Sequential(
            nn.LayerNorm(token_width),
            LinearNoBias(token_width, token_width),
        )
        init.final_init_(self.s_to_a_linear[1].weight)

        # transformer_post_ln is intentionally NOT forwarded to the token
        # transformer, the final norm, or the decoder.
        self.token_transformer = DiffusionTransformer(
            dim=token_width,
            dim_single_cond=token_width,
            depth=token_transformer_depth,
            heads=token_transformer_heads,
            activation_checkpointing=activation_checkpointing,
        )
        self.a_norm = nn.LayerNorm(token_width)
        self.atom_attention_decoder = AtomAttentionDecoder(
            atom_s=atom_s,
            token_s=token_s,
            attn_window_queries=atoms_per_window_queries,
            attn_window_keys=atoms_per_window_keys,
            atom_decoder_depth=atom_decoder_depth,
            atom_decoder_heads=atom_decoder_heads,
            activation_checkpointing=activation_checkpointing,
        )

    def forward(
        self,
        s_inputs,  # Float['b n ts']
        s_trunk,  # Float['b n ts']
        r_noisy,  # Float['bm m 3']
        times,  # noise embedding input; AtomDiffusion passes one value per sample
        feats,
        diffusion_conditioning,
        multiplicity=1,
    ):
        """Predict the per-atom coordinate update for ``r_noisy``.

        ``s_inputs`` and ``s_trunk`` are repeated ``multiplicity`` times along
        dim 0 before conditioning. ``diffusion_conditioning`` must provide the
        keys ``q``, ``c``, ``atom_enc_bias``, ``to_keys``, ``token_trans_bias``
        and ``atom_dec_bias``; ``feats`` must provide ``token_pad_mask``.
        Returns the output of the atom attention decoder.
        """
        conditioner_args = (
            times,
            s_trunk.repeat_interleave(multiplicity, 0),
            s_inputs.repeat_interleave(multiplicity, 0),
        )
        if self.activation_checkpointing and self.training:
            s, _normed_fourier = torch.utils.checkpoint.checkpoint(
                self.single_conditioner, *conditioner_args
            )
        else:
            s, _normed_fourier = self.single_conditioner(*conditioner_args)

        cond = diffusion_conditioning

        # Sequence-local atom attention, aggregated to token level.
        token_acts, q_skip, c_skip, to_keys = self.atom_attention_encoder(
            feats=feats,
            q=cond["q"].float(),
            c=cond["c"].float(),
            atom_enc_bias=cond["atom_enc_bias"].float(),
            to_keys=cond["to_keys"],
            r=r_noisy,
            multiplicity=multiplicity,
        )

        # Full self-attention over tokens.
        token_acts = token_acts + self.s_to_a_linear(s)
        token_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        token_acts = self.token_transformer(
            token_acts,
            mask=token_mask.float(),
            s=s,
            # z is not expanded with multiplicity until after the bias is computed
            bias=cond["token_trans_bias"].float(),
            multiplicity=multiplicity,
        )
        token_acts = self.a_norm(token_acts)

        # Broadcast token activations back to atoms and decode the update.
        return self.atom_attention_decoder(
            a=token_acts,
            q=q_skip,
            c=c_skip,
            atom_dec_bias=cond["atom_dec_bias"].float(),
            feats=feats,
            multiplicity=multiplicity,
            to_keys=to_keys,
        )
class AtomDiffusion(Module):
def __init__(
self,
score_model_args,
num_sampling_steps: int = 5, # number of sampling steps
sigma_min: float = 0.0004, # min noise level
sigma_max: float = 160.0, # max noise level
sigma_data: float = 16.0, # standard deviation of data distribution
rho: float = 7, # controls the sampling schedule
P_mean: float = -1.2, # mean of log-normal distribution from which noise is drawn for training
P_std: float = 1.5, # standard deviation of log-normal distribution from which noise is drawn for training
gamma_0: float = 0.8,
gamma_min: float = 1.0,
noise_scale: float = 1.003,
step_scale: float = 1.5,
step_scale_random: list = None,
coordinate_augmentation: bool = True,
coordinate_augmentation_inference=None,
compile_score: bool = False,
alignment_reverse_diff: bool = False,
synchronize_sigmas: bool = False,
):
super().__init__()
self.score_model = DiffusionModule(
**score_model_args,
)
if compile_score:
self.score_model = torch.compile(
self.score_model, dynamic=False, fullgraph=False
)
# parameters
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.sigma_data = sigma_data
self.rho = rho
self.P_mean = P_mean
self.P_std = P_std
self.num_sampling_steps = num_sampling_steps
self.gamma_0 = gamma_0
self.gamma_min = gamma_min
self.noise_scale = noise_scale
self.step_scale = step_scale
self.step_scale_random = step_scale_random
self.coordinate_augmentation = coordinate_augmentation
self.coordinate_augmentation_inference = (
coordinate_augmentation_inference
if coordinate_augmentation_inference is not None
else coordinate_augmentation
)
self.alignment_reverse_diff = alignment_reverse_diff
self.synchronize_sigmas = synchronize_sigmas
self.token_s = score_model_args["token_s"]
self.register_buffer("zero", torch.tensor(0.0), persistent=False)
@property
def device(self):
return next(self.score_model.parameters()).device
def c_skip(self, sigma):
return (self.sigma_data**2) / (sigma**2 + self.sigma_data**2)
def c_out(self, sigma):
return sigma * self.sigma_data / torch.sqrt(self.sigma_data**2 + sigma**2)
def c_in(self, sigma):
return 1 / torch.sqrt(sigma**2 + self.sigma_data**2)
def c_noise(self, sigma):
return log(sigma / self.sigma_data) * 0.25
def preconditioned_network_forward(
self,
noised_atom_coords, #: Float['b m 3'],
sigma, #: Float['b'] | Float[' '] | float,
network_condition_kwargs: dict,
):
batch, device = noised_atom_coords.shape[0], noised_atom_coords.device
if isinstance(sigma, float):
sigma = torch.full((batch,), sigma, device=device)
padded_sigma = rearrange(sigma, "b -> b 1 1")
r_update = self.score_model(
r_noisy=self.c_in(padded_sigma) * noised_atom_coords,
times=self.c_noise(sigma),
**network_condition_kwargs,
)
denoised_coords = (
self.c_skip(padded_sigma) * noised_atom_coords
+ self.c_out(padded_sigma) * r_update
)
return denoised_coords
def sample_schedule(self, num_sampling_steps=None):
num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
inv_rho = 1 / self.rho
steps = torch.arange(
num_sampling_steps, device=self.device, dtype=torch.float32
)
sigmas = (
self.sigma_max**inv_rho
+ steps
/ (num_sampling_steps - 1)
* (self.sigma_min**inv_rho - self.sigma_max**inv_rho)
) ** self.rho
sigmas = sigmas * self.sigma_data
sigmas = F.pad(sigmas, (0, 1), value=0.0) # last step is sigma value of 0.
return sigmas
def sample(
self,
atom_mask,
num_sampling_steps=None,
multiplicity=1,
max_parallel_samples=None,
steering_args=None,
**network_condition_kwargs,
):
if steering_args is not None and (
steering_args["fk_steering"]
or steering_args["physical_guidance_update"]
or steering_args["contact_guidance_update"]
):
potentials = get_potentials(steering_args, boltz2=True)
if steering_args["fk_steering"]:
multiplicity = multiplicity * steering_args["num_particles"]
energy_traj = torch.empty((multiplicity, 0), device=self.device)
resample_weights = torch.ones(multiplicity, device=self.device).reshape(
-1, steering_args["num_particles"]
)
if (
steering_args["physical_guidance_update"]
or steering_args["contact_guidance_update"]
):
scaled_guidance_update = torch.zeros(
(multiplicity, *atom_mask.shape[1:], 3),
dtype=torch.float32,
device=self.device,
)
if max_parallel_samples is None:
max_parallel_samples = multiplicity
num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
atom_mask = atom_mask.repeat_interleave(multiplicity, 0)
shape = (*atom_mask.shape, 3)
# get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma
sigmas = self.sample_schedule(num_sampling_steps)
gammas = torch.where(sigmas > self.gamma_min, self.gamma_0, 0.0)
sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[1:]))
if self.training and self.step_scale_random is not None:
step_scale = np.random.choice(self.step_scale_random)
else:
step_scale = self.step_scale
# atom position is noise at the beginning
init_sigma = sigmas[0]
atom_coords = init_sigma * torch.randn(shape, device=self.device)
token_repr = None
atom_coords_denoised = None
# gradually denoise
for step_idx, (sigma_tm, sigma_t, gamma) in enumerate(sigmas_and_gammas):
random_R, random_tr = compute_random_augmentation(
multiplicity, device=atom_coords.device, dtype=atom_coords.dtype
)
atom_coords = atom_coords - atom_coords.mean(dim=-2, keepdims=True)
atom_coords = (
torch.einsum("bmd,bds->bms", atom_coords, random_R) + random_tr
)
if atom_coords_denoised is not None:
atom_coords_denoised -= atom_coords_denoised.mean(dim=-2, keepdims=True)
atom_coords_denoised = (
torch.einsum("bmd,bds->bms", atom_coords_denoised, random_R)
+ random_tr
)
if (
steering_args["physical_guidance_update"]
or steering_args["contact_guidance_update"]
) and scaled_guidance_update is not None:
scaled_guidance_update = torch.einsum(
"bmd,bds->bms", scaled_guidance_update, random_R
)
sigma_tm, sigma_t, gamma = sigma_tm.item(), sigma_t.item(), gamma.item()
t_hat = sigma_tm * (1 + gamma)
steering_t = 1.0 - (step_idx / num_sampling_steps)
noise_var = self.noise_scale**2 * (t_hat**2 - sigma_tm**2)
eps = sqrt(noise_var) * torch.randn(shape, device=self.device)
atom_coords_noisy = atom_coords + eps
with torch.no_grad():
atom_coords_denoised = torch.zeros_like(atom_coords_noisy)
sample_ids = torch.arange(multiplicity).to(atom_coords_noisy.device)
sample_ids_chunks = sample_ids.chunk(
multiplicity % max_parallel_samples + 1
)
for sample_ids_chunk in sample_ids_chunks:
atom_coords_denoised_chunk = self.preconditioned_network_forward(
atom_coords_noisy[sample_ids_chunk],
t_hat,
network_condition_kwargs=dict(
multiplicity=sample_ids_chunk.numel(),
**network_condition_kwargs,
),
)
atom_coords_denoised[sample_ids_chunk] = atom_coords_denoised_chunk
if steering_args["fk_steering"] and (
(
step_idx % steering_args["fk_resampling_interval"] == 0
and noise_var > 0
)
or step_idx == num_sampling_steps - 1
):
# Compute energy of x_0 prediction
energy = torch.zeros(multiplicity, device=self.device)
for potential in potentials:
parameters = potential.compute_parameters(steering_t)
if parameters["resampling_weight"] > 0:
component_energy = potential.compute(
atom_coords_denoised,
network_condition_kwargs["feats"],
parameters,
)
energy += parameters["resampling_weight"] * component_energy
energy_traj = torch.cat((energy_traj, energy.unsqueeze(1)), dim=1)
# Compute log G values
if step_idx == 0:
log_G = -1 * energy
else:
log_G = energy_traj[:, -2] - energy_traj[:, -1]
# Compute ll difference between guided and unguided transition distribution
if (
steering_args["physical_guidance_update"]
or steering_args["contact_guidance_update"]
) and noise_var > 0:
ll_difference = (
eps**2 - (eps + scaled_guidance_update) ** 2
).sum(dim=(-1, -2)) / (2 * noise_var)
else:
ll_difference = torch.zeros_like(energy)
# Compute resampling weights
resample_weights = F.softmax(
(ll_difference + steering_args["fk_lambda"] * log_G).reshape(
-1, steering_args["num_particles"]
),
dim=1,
)
# Compute guidance update to x_0 prediction
if (
steering_args["physical_guidance_update"]
or steering_args["contact_guidance_update"]
) and step_idx < num_sampling_steps - 1:
guidance_update = torch.zeros_like(atom_coords_denoised)
for guidance_step in range(steering_args["num_gd_steps"]):
energy_gradient = torch.zeros_like(atom_coords_denoised)
for potential in potentials:
parameters = potential.compute_parameters(steering_t)
if (
parameters["guidance_weight"] > 0
and (guidance_step) % parameters["guidance_interval"]
== 0
):
energy_gradient += parameters[
"guidance_weight"
] * potential.compute_gradient(
atom_coords_denoised + guidance_update,
network_condition_kwargs["feats"],
parameters,
)
guidance_update -= energy_gradient
atom_coords_denoised += guidance_update
scaled_guidance_update = (
guidance_update
* -1
* self.step_scale
* (sigma_t - t_hat)
/ t_hat
)
if steering_args["fk_steering"] and (
(
step_idx % steering_args["fk_resampling_interval"] == 0
and noise_var > 0
)
or step_idx == num_sampling_steps - 1
):
resample_indices = (
torch.multinomial(
resample_weights,
resample_weights.shape[1]
if step_idx < num_sampling_steps - 1
else 1,
replacement=True,
)
+ resample_weights.shape[1]
* torch.arange(
resample_weights.shape[0], device=resample_weights.device
).unsqueeze(-1)
).flatten()
atom_coords = atom_coords[resample_indices]
atom_coords_noisy = atom_coords_noisy[resample_indices]
atom_mask = atom_mask[resample_indices]
if atom_coords_denoised is not None:
atom_coords_denoised = atom_coords_denoised[resample_indices]
energy_traj = energy_traj[resample_indices]
if (
steering_args["physical_guidance_update"]
or steering_args["contact_guidance_update"]
):
scaled_guidance_update = scaled_guidance_update[
resample_indices
]
if token_repr is not None:
token_repr = token_repr[resample_indices]
if self.alignment_reverse_diff:
with torch.autocast("cuda", enabled=False):
atom_coords_noisy = weighted_rigid_align(
atom_coords_noisy.float(),
atom_coords_denoised.float(),
atom_mask.float(),
atom_mask.float(),
)
atom_coords_noisy = atom_coords_noisy.to(atom_coords_denoised)
denoised_over_sigma = (atom_coords_noisy - atom_coords_denoised) / t_hat
atom_coords_next = (
atom_coords_noisy + step_scale * (sigma_t - t_hat) * denoised_over_sigma
)
atom_coords = atom_coords_next
return dict(sample_atom_coords=atom_coords, diff_token_repr=token_repr)
def loss_weight(self, sigma):
return (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2)
def noise_distribution(self, batch_size):
return (
self.sigma_data
* (
self.P_mean
+ self.P_std * torch.randn((batch_size,), device=self.device)
).exp()
)
def forward(
self,
s_inputs,
s_trunk,
feats,
diffusion_conditioning,
multiplicity=1,
):
# training diffusion step
batch_size = feats["coords"].shape[0] // multiplicity
if self.synchronize_sigmas:
sigmas = self.noise_distribution(batch_size).repeat_interleave(
multiplicity, 0
)
else:
sigmas = self.noise_distribution(batch_size * multiplicity)
padded_sigmas = rearrange(sigmas, "b -> b 1 1")
atom_coords = feats["coords"]
atom_mask = feats["atom_pad_mask"]
atom_mask = atom_mask.repeat_interleave(multiplicity, 0)
atom_coords = center_random_augmentation(
atom_coords, atom_mask, augmentation=self.coordinate_augmentation
)
noise = torch.randn_like(atom_coords)
noised_atom_coords = atom_coords + padded_sigmas * noise
denoised_atom_coords = self.preconditioned_network_forward(
noised_atom_coords,
sigmas,
network_condition_kwargs={
"s_inputs": s_inputs,
"s_trunk": s_trunk,
"feats": feats,
"multiplicity": multiplicity,
"diffusion_conditioning": diffusion_conditioning,
},
)
return {
"denoised_atom_coords": denoised_atom_coords,
"sigmas": sigmas,
"aligned_true_atom_coords": atom_coords,
}
def compute_loss(
self,
feats,
out_dict,
add_smooth_lddt_loss=True,
nucleotide_loss_weight=5.0,
ligand_loss_weight=10.0,
multiplicity=1,
filter_by_plddt=0.0,
):
with torch.autocast("cuda", enabled=False):
denoised_atom_coords = out_dict["denoised_atom_coords"].float()
sigmas = out_dict["sigmas"].float()
resolved_atom_mask_uni = feats["atom_resolved_mask"].float()
if filter_by_plddt > 0:
plddt_mask = feats["plddt"] > filter_by_plddt
resolved_atom_mask_uni = resolved_atom_mask_uni * plddt_mask.float()
resolved_atom_mask = resolved_atom_mask_uni.repeat_interleave(
multiplicity, 0
)
align_weights = denoised_atom_coords.new_ones(denoised_atom_coords.shape[:2])
atom_type = (
torch.bmm(
feats["atom_to_token"].float(),
feats["mol_type"].unsqueeze(-1).float(),
)
.squeeze(-1)
.long()
)
atom_type_mult = atom_type.repeat_interleave(multiplicity, 0)
align_weights = (
align_weights
* (
1
+ nucleotide_loss_weight
* (
torch.eq(atom_type_mult, const.chain_type_ids["DNA"]).float()
+ torch.eq(atom_type_mult, const.chain_type_ids["RNA"]).float()
)
+ ligand_loss_weight
* torch.eq(
atom_type_mult, const.chain_type_ids["NONPOLYMER"]
).float()
).float()
)
atom_coords = out_dict["aligned_true_atom_coords"].float()
atom_coords_aligned_ground_truth = weighted_rigid_align(
atom_coords.detach(),
denoised_atom_coords.detach(),
align_weights.detach(),
mask=feats["atom_resolved_mask"]
.float()
.repeat_interleave(multiplicity, 0)
.detach(),
)
# Cast back
atom_coords_aligned_ground_truth = atom_coords_aligned_ground_truth.to(
denoised_atom_coords
)
# weighted MSE loss of denoised atom positions
mse_loss = (
(denoised_atom_coords - atom_coords_aligned_ground_truth) ** 2
).sum(dim=-1)
mse_loss = torch.sum(
mse_loss * align_weights * resolved_atom_mask, dim=-1
) / (torch.sum(3 * align_weights * resolved_atom_mask, dim=-1) + 1e-5)
# weight by sigma factor
loss_weights = self.loss_weight(sigmas)
mse_loss = (mse_loss * loss_weights).mean()
total_loss = mse_loss
# proposed auxiliary smooth lddt loss
lddt_loss = self.zero
if add_smooth_lddt_loss:
lddt_loss = smooth_lddt_loss(
denoised_atom_coords,
feats["coords"],
torch.eq(atom_type, const.chain_type_ids["DNA"]).float()
+ torch.eq(atom_type, const.chain_type_ids["RNA"]).float(),
coords_mask=resolved_atom_mask_uni,
multiplicity=multiplicity,
)
total_loss = total_loss + lddt_loss
loss_breakdown = {
"mse_loss": mse_loss,
"smooth_lddt_loss": lddt_loss,
}
return {"loss": total_loss, "loss_breakdown": loss_breakdown}