# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang from __future__ import annotations from math import sqrt import numpy as np import torch import torch.nn.functional as F # noqa: N812 from einops import rearrange from torch import nn from torch.nn import Module from . import vb_const as const from . import vb_layers_initialize as init from .vb_loss_diffusionv2 import ( smooth_lddt_loss, weighted_rigid_align, ) from .vb_modules_encodersv2 import ( AtomAttentionDecoder, AtomAttentionEncoder, SingleConditioning, ) from .vb_modules_transformersv2 import ( DiffusionTransformer, ) from .vb_modules_utils import ( LinearNoBias, center_random_augmentation, compute_random_augmentation, default, log, ) from .vb_potentials_potentials import get_potentials class DiffusionModule(Module): """Diffusion module""" def __init__( self, token_s: int, atom_s: int, atoms_per_window_queries: int = 32, atoms_per_window_keys: int = 128, sigma_data: int = 16, dim_fourier: int = 256, atom_encoder_depth: int = 3, atom_encoder_heads: int = 4, token_transformer_depth: int = 24, token_transformer_heads: int = 8, atom_decoder_depth: int = 3, atom_decoder_heads: int = 4, conditioning_transition_layers: int = 2, activation_checkpointing: bool = False, transformer_post_ln: bool = False, ) -> None: super().__init__() self.atoms_per_window_queries = atoms_per_window_queries self.atoms_per_window_keys = atoms_per_window_keys self.sigma_data = sigma_data self.activation_checkpointing = activation_checkpointing # conditioning self.single_conditioner = SingleConditioning( sigma_data=sigma_data, token_s=token_s, dim_fourier=dim_fourier, num_transitions=conditioning_transition_layers, ) self.atom_attention_encoder = AtomAttentionEncoder( atom_s=atom_s, token_s=token_s, atoms_per_window_queries=atoms_per_window_queries, atoms_per_window_keys=atoms_per_window_keys, atom_encoder_depth=atom_encoder_depth, atom_encoder_heads=atom_encoder_heads, structure_prediction=True, activation_checkpointing=activation_checkpointing, transformer_post_layer_norm=transformer_post_ln, ) self.s_to_a_linear = nn.Sequential( nn.LayerNorm(2 * token_s), LinearNoBias(2 * token_s, 2 * token_s) ) init.final_init_(self.s_to_a_linear[1].weight) self.token_transformer = DiffusionTransformer( dim=2 * token_s, dim_single_cond=2 * token_s, depth=token_transformer_depth, heads=token_transformer_heads, activation_checkpointing=activation_checkpointing, # post_layer_norm=transformer_post_ln, ) self.a_norm = nn.LayerNorm( 2 * token_s ) # if not transformer_post_ln else nn.Identity() self.atom_attention_decoder = AtomAttentionDecoder( atom_s=atom_s, token_s=token_s, attn_window_queries=atoms_per_window_queries, attn_window_keys=atoms_per_window_keys, atom_decoder_depth=atom_decoder_depth, atom_decoder_heads=atom_decoder_heads, activation_checkpointing=activation_checkpointing, # transformer_post_layer_norm=transformer_post_ln, ) def forward( self, s_inputs, # Float['b n ts'] s_trunk, # Float['b n ts'] r_noisy, # Float['bm m 3'] times, # Float['bm 1 1'] feats, diffusion_conditioning, multiplicity=1, ): if self.activation_checkpointing and self.training: s, normed_fourier = torch.utils.checkpoint.checkpoint( self.single_conditioner, times, s_trunk.repeat_interleave(multiplicity, 0), s_inputs.repeat_interleave(multiplicity, 0), ) else: s, normed_fourier = self.single_conditioner( times, s_trunk.repeat_interleave(multiplicity, 0), s_inputs.repeat_interleave(multiplicity, 0), ) # Sequence-local Atom Attention and aggregation to coarse-grained tokens a, q_skip, c_skip, to_keys = self.atom_attention_encoder( feats=feats, q=diffusion_conditioning["q"].float(), c=diffusion_conditioning["c"].float(), atom_enc_bias=diffusion_conditioning["atom_enc_bias"].float(), to_keys=diffusion_conditioning["to_keys"], r=r_noisy, # Float['b m 3'], multiplicity=multiplicity, ) # Full self-attention on token level a = a + self.s_to_a_linear(s) mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0) a = self.token_transformer( a, mask=mask.float(), s=s, bias=diffusion_conditioning[ "token_trans_bias" ].float(), # note z is not expanded with multiplicity until after bias is computed multiplicity=multiplicity, ) a = self.a_norm(a) # Broadcast token activations to atoms and run Sequence-local Atom Attention r_update = self.atom_attention_decoder( a=a, q=q_skip, c=c_skip, atom_dec_bias=diffusion_conditioning["atom_dec_bias"].float(), feats=feats, multiplicity=multiplicity, to_keys=to_keys, ) return r_update class AtomDiffusion(Module): def __init__( self, score_model_args, num_sampling_steps: int = 5, # number of sampling steps sigma_min: float = 0.0004, # min noise level sigma_max: float = 160.0, # max noise level sigma_data: float = 16.0, # standard deviation of data distribution rho: float = 7, # controls the sampling schedule P_mean: float = -1.2, # mean of log-normal distribution from which noise is drawn for training P_std: float = 1.5, # standard deviation of log-normal distribution from which noise is drawn for training gamma_0: float = 0.8, gamma_min: float = 1.0, noise_scale: float = 1.003, step_scale: float = 1.5, step_scale_random: list = None, coordinate_augmentation: bool = True, coordinate_augmentation_inference=None, compile_score: bool = False, alignment_reverse_diff: bool = False, synchronize_sigmas: bool = False, ): super().__init__() self.score_model = DiffusionModule( **score_model_args, ) if compile_score: self.score_model = torch.compile( self.score_model, dynamic=False, fullgraph=False ) # parameters self.sigma_min = sigma_min self.sigma_max = sigma_max self.sigma_data = sigma_data self.rho = rho self.P_mean = P_mean self.P_std = P_std self.num_sampling_steps = num_sampling_steps self.gamma_0 = gamma_0 self.gamma_min = gamma_min self.noise_scale = noise_scale self.step_scale = step_scale self.step_scale_random = step_scale_random self.coordinate_augmentation = coordinate_augmentation self.coordinate_augmentation_inference = ( coordinate_augmentation_inference if coordinate_augmentation_inference is not None else coordinate_augmentation ) self.alignment_reverse_diff = alignment_reverse_diff self.synchronize_sigmas = synchronize_sigmas self.token_s = score_model_args["token_s"] self.register_buffer("zero", torch.tensor(0.0), persistent=False) @property def device(self): return next(self.score_model.parameters()).device def c_skip(self, sigma): return (self.sigma_data**2) / (sigma**2 + self.sigma_data**2) def c_out(self, sigma): return sigma * self.sigma_data / torch.sqrt(self.sigma_data**2 + sigma**2) def c_in(self, sigma): return 1 / torch.sqrt(sigma**2 + self.sigma_data**2) def c_noise(self, sigma): return log(sigma / self.sigma_data) * 0.25 def preconditioned_network_forward( self, noised_atom_coords, #: Float['b m 3'], sigma, #: Float['b'] | Float[' '] | float, network_condition_kwargs: dict, ): batch, device = noised_atom_coords.shape[0], noised_atom_coords.device if isinstance(sigma, float): sigma = torch.full((batch,), sigma, device=device) padded_sigma = rearrange(sigma, "b -> b 1 1") r_update = self.score_model( r_noisy=self.c_in(padded_sigma) * noised_atom_coords, times=self.c_noise(sigma), **network_condition_kwargs, ) denoised_coords = ( self.c_skip(padded_sigma) * noised_atom_coords + self.c_out(padded_sigma) * r_update ) return denoised_coords def sample_schedule(self, num_sampling_steps=None): num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps) inv_rho = 1 / self.rho steps = torch.arange( num_sampling_steps, device=self.device, dtype=torch.float32 ) sigmas = ( self.sigma_max**inv_rho + steps / (num_sampling_steps - 1) * (self.sigma_min**inv_rho - self.sigma_max**inv_rho) ) ** self.rho sigmas = sigmas * self.sigma_data sigmas = F.pad(sigmas, (0, 1), value=0.0) # last step is sigma value of 0. return sigmas def sample( self, atom_mask, num_sampling_steps=None, multiplicity=1, max_parallel_samples=None, steering_args=None, **network_condition_kwargs, ): if steering_args is not None and ( steering_args["fk_steering"] or steering_args["physical_guidance_update"] or steering_args["contact_guidance_update"] ): potentials = get_potentials(steering_args, boltz2=True) if steering_args["fk_steering"]: multiplicity = multiplicity * steering_args["num_particles"] energy_traj = torch.empty((multiplicity, 0), device=self.device) resample_weights = torch.ones(multiplicity, device=self.device).reshape( -1, steering_args["num_particles"] ) if ( steering_args["physical_guidance_update"] or steering_args["contact_guidance_update"] ): scaled_guidance_update = torch.zeros( (multiplicity, *atom_mask.shape[1:], 3), dtype=torch.float32, device=self.device, ) if max_parallel_samples is None: max_parallel_samples = multiplicity num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps) atom_mask = atom_mask.repeat_interleave(multiplicity, 0) shape = (*atom_mask.shape, 3) # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma sigmas = self.sample_schedule(num_sampling_steps) gammas = torch.where(sigmas > self.gamma_min, self.gamma_0, 0.0) sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[1:])) if self.training and self.step_scale_random is not None: step_scale = np.random.choice(self.step_scale_random) else: step_scale = self.step_scale # atom position is noise at the beginning init_sigma = sigmas[0] atom_coords = init_sigma * torch.randn(shape, device=self.device) token_repr = None atom_coords_denoised = None # gradually denoise for step_idx, (sigma_tm, sigma_t, gamma) in enumerate(sigmas_and_gammas): random_R, random_tr = compute_random_augmentation( multiplicity, device=atom_coords.device, dtype=atom_coords.dtype ) atom_coords = atom_coords - atom_coords.mean(dim=-2, keepdims=True) atom_coords = ( torch.einsum("bmd,bds->bms", atom_coords, random_R) + random_tr ) if atom_coords_denoised is not None: atom_coords_denoised -= atom_coords_denoised.mean(dim=-2, keepdims=True) atom_coords_denoised = ( torch.einsum("bmd,bds->bms", atom_coords_denoised, random_R) + random_tr ) if ( steering_args["physical_guidance_update"] or steering_args["contact_guidance_update"] ) and scaled_guidance_update is not None: scaled_guidance_update = torch.einsum( "bmd,bds->bms", scaled_guidance_update, random_R ) sigma_tm, sigma_t, gamma = sigma_tm.item(), sigma_t.item(), gamma.item() t_hat = sigma_tm * (1 + gamma) steering_t = 1.0 - (step_idx / num_sampling_steps) noise_var = self.noise_scale**2 * (t_hat**2 - sigma_tm**2) eps = sqrt(noise_var) * torch.randn(shape, device=self.device) atom_coords_noisy = atom_coords + eps with torch.no_grad(): atom_coords_denoised = torch.zeros_like(atom_coords_noisy) sample_ids = torch.arange(multiplicity).to(atom_coords_noisy.device) sample_ids_chunks = sample_ids.chunk( multiplicity % max_parallel_samples + 1 ) for sample_ids_chunk in sample_ids_chunks: atom_coords_denoised_chunk = self.preconditioned_network_forward( atom_coords_noisy[sample_ids_chunk], t_hat, network_condition_kwargs=dict( multiplicity=sample_ids_chunk.numel(), **network_condition_kwargs, ), ) atom_coords_denoised[sample_ids_chunk] = atom_coords_denoised_chunk if steering_args["fk_steering"] and ( ( step_idx % steering_args["fk_resampling_interval"] == 0 and noise_var > 0 ) or step_idx == num_sampling_steps - 1 ): # Compute energy of x_0 prediction energy = torch.zeros(multiplicity, device=self.device) for potential in potentials: parameters = potential.compute_parameters(steering_t) if parameters["resampling_weight"] > 0: component_energy = potential.compute( atom_coords_denoised, network_condition_kwargs["feats"], parameters, ) energy += parameters["resampling_weight"] * component_energy energy_traj = torch.cat((energy_traj, energy.unsqueeze(1)), dim=1) # Compute log G values if step_idx == 0: log_G = -1 * energy else: log_G = energy_traj[:, -2] - energy_traj[:, -1] # Compute ll difference between guided and unguided transition distribution if ( steering_args["physical_guidance_update"] or steering_args["contact_guidance_update"] ) and noise_var > 0: ll_difference = ( eps**2 - (eps + scaled_guidance_update) ** 2 ).sum(dim=(-1, -2)) / (2 * noise_var) else: ll_difference = torch.zeros_like(energy) # Compute resampling weights resample_weights = F.softmax( (ll_difference + steering_args["fk_lambda"] * log_G).reshape( -1, steering_args["num_particles"] ), dim=1, ) # Compute guidance update to x_0 prediction if ( steering_args["physical_guidance_update"] or steering_args["contact_guidance_update"] ) and step_idx < num_sampling_steps - 1: guidance_update = torch.zeros_like(atom_coords_denoised) for guidance_step in range(steering_args["num_gd_steps"]): energy_gradient = torch.zeros_like(atom_coords_denoised) for potential in potentials: parameters = potential.compute_parameters(steering_t) if ( parameters["guidance_weight"] > 0 and (guidance_step) % parameters["guidance_interval"] == 0 ): energy_gradient += parameters[ "guidance_weight" ] * potential.compute_gradient( atom_coords_denoised + guidance_update, network_condition_kwargs["feats"], parameters, ) guidance_update -= energy_gradient atom_coords_denoised += guidance_update scaled_guidance_update = ( guidance_update * -1 * self.step_scale * (sigma_t - t_hat) / t_hat ) if steering_args["fk_steering"] and ( ( step_idx % steering_args["fk_resampling_interval"] == 0 and noise_var > 0 ) or step_idx == num_sampling_steps - 1 ): resample_indices = ( torch.multinomial( resample_weights, resample_weights.shape[1] if step_idx < num_sampling_steps - 1 else 1, replacement=True, ) + resample_weights.shape[1] * torch.arange( resample_weights.shape[0], device=resample_weights.device ).unsqueeze(-1) ).flatten() atom_coords = atom_coords[resample_indices] atom_coords_noisy = atom_coords_noisy[resample_indices] atom_mask = atom_mask[resample_indices] if atom_coords_denoised is not None: atom_coords_denoised = atom_coords_denoised[resample_indices] energy_traj = energy_traj[resample_indices] if ( steering_args["physical_guidance_update"] or steering_args["contact_guidance_update"] ): scaled_guidance_update = scaled_guidance_update[ resample_indices ] if token_repr is not None: token_repr = token_repr[resample_indices] if self.alignment_reverse_diff: with torch.autocast("cuda", enabled=False): atom_coords_noisy = weighted_rigid_align( atom_coords_noisy.float(), atom_coords_denoised.float(), atom_mask.float(), atom_mask.float(), ) atom_coords_noisy = atom_coords_noisy.to(atom_coords_denoised) denoised_over_sigma = (atom_coords_noisy - atom_coords_denoised) / t_hat atom_coords_next = ( atom_coords_noisy + step_scale * (sigma_t - t_hat) * denoised_over_sigma ) atom_coords = atom_coords_next return dict(sample_atom_coords=atom_coords, diff_token_repr=token_repr) def loss_weight(self, sigma): return (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2) def noise_distribution(self, batch_size): return ( self.sigma_data * ( self.P_mean + self.P_std * torch.randn((batch_size,), device=self.device) ).exp() ) def forward( self, s_inputs, s_trunk, feats, diffusion_conditioning, multiplicity=1, ): # training diffusion step batch_size = feats["coords"].shape[0] // multiplicity if self.synchronize_sigmas: sigmas = self.noise_distribution(batch_size).repeat_interleave( multiplicity, 0 ) else: sigmas = self.noise_distribution(batch_size * multiplicity) padded_sigmas = rearrange(sigmas, "b -> b 1 1") atom_coords = feats["coords"] atom_mask = feats["atom_pad_mask"] atom_mask = atom_mask.repeat_interleave(multiplicity, 0) atom_coords = center_random_augmentation( atom_coords, atom_mask, augmentation=self.coordinate_augmentation ) noise = torch.randn_like(atom_coords) noised_atom_coords = atom_coords + padded_sigmas * noise denoised_atom_coords = self.preconditioned_network_forward( noised_atom_coords, sigmas, network_condition_kwargs={ "s_inputs": s_inputs, "s_trunk": s_trunk, "feats": feats, "multiplicity": multiplicity, "diffusion_conditioning": diffusion_conditioning, }, ) return { "denoised_atom_coords": denoised_atom_coords, "sigmas": sigmas, "aligned_true_atom_coords": atom_coords, } def compute_loss( self, feats, out_dict, add_smooth_lddt_loss=True, nucleotide_loss_weight=5.0, ligand_loss_weight=10.0, multiplicity=1, filter_by_plddt=0.0, ): with torch.autocast("cuda", enabled=False): denoised_atom_coords = out_dict["denoised_atom_coords"].float() sigmas = out_dict["sigmas"].float() resolved_atom_mask_uni = feats["atom_resolved_mask"].float() if filter_by_plddt > 0: plddt_mask = feats["plddt"] > filter_by_plddt resolved_atom_mask_uni = resolved_atom_mask_uni * plddt_mask.float() resolved_atom_mask = resolved_atom_mask_uni.repeat_interleave( multiplicity, 0 ) align_weights = denoised_atom_coords.new_ones(denoised_atom_coords.shape[:2]) atom_type = ( torch.bmm( feats["atom_to_token"].float(), feats["mol_type"].unsqueeze(-1).float(), ) .squeeze(-1) .long() ) atom_type_mult = atom_type.repeat_interleave(multiplicity, 0) align_weights = ( align_weights * ( 1 + nucleotide_loss_weight * ( torch.eq(atom_type_mult, const.chain_type_ids["DNA"]).float() + torch.eq(atom_type_mult, const.chain_type_ids["RNA"]).float() ) + ligand_loss_weight * torch.eq( atom_type_mult, const.chain_type_ids["NONPOLYMER"] ).float() ).float() ) atom_coords = out_dict["aligned_true_atom_coords"].float() atom_coords_aligned_ground_truth = weighted_rigid_align( atom_coords.detach(), denoised_atom_coords.detach(), align_weights.detach(), mask=feats["atom_resolved_mask"] .float() .repeat_interleave(multiplicity, 0) .detach(), ) # Cast back atom_coords_aligned_ground_truth = atom_coords_aligned_ground_truth.to( denoised_atom_coords ) # weighted MSE loss of denoised atom positions mse_loss = ( (denoised_atom_coords - atom_coords_aligned_ground_truth) ** 2 ).sum(dim=-1) mse_loss = torch.sum( mse_loss * align_weights * resolved_atom_mask, dim=-1 ) / (torch.sum(3 * align_weights * resolved_atom_mask, dim=-1) + 1e-5) # weight by sigma factor loss_weights = self.loss_weight(sigmas) mse_loss = (mse_loss * loss_weights).mean() total_loss = mse_loss # proposed auxiliary smooth lddt loss lddt_loss = self.zero if add_smooth_lddt_loss: lddt_loss = smooth_lddt_loss( denoised_atom_coords, feats["coords"], torch.eq(atom_type, const.chain_type_ids["DNA"]).float() + torch.eq(atom_type, const.chain_type_ids["RNA"]).float(), coords_mask=resolved_atom_mask_uni, multiplicity=multiplicity, ) total_loss = total_loss + lddt_loss loss_breakdown = { "mse_loss": mse_loss, "smooth_lddt_loss": lddt_loss, } return {"loss": total_loss, "loss_breakdown": loss_breakdown}