| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from functools import partial |
| | import logging |
| | from typing import Dict, Optional, Tuple |
| |
|
| | import ml_collections |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from src.common import residue_constants |
| | from src.common.all_atom import compute_backbone |
| | from src.common.rigid_utils import Rotation, Rigid |
| | from src.utils.tensor_utils import ( |
| | tree_map, |
| | tensor_tree_map, |
| | masked_mean, |
| | permute_final_dims, |
| | batched_gather, |
| | sum_except_batch, |
| | inflate_array_like |
| | ) |
| |
|
| |
|
def softmax_cross_entropy(logits, labels):
    """Cross entropy between logits and a (possibly soft) label distribution.

    Args:
        logits: [..., C] unnormalized class scores
        labels: [..., C] target distribution over classes (e.g. one-hot)
    Returns:
        [...] per-example cross entropy
    """
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    return -torch.sum(labels * log_probs, dim=-1)
| |
|
| |
|
def sigmoid_cross_entropy(logits, labels):
    """Numerically stable element-wise binary cross entropy on logits.

    Args:
        logits: [...] unnormalized scores
        labels: [...] binary (or soft) targets, broadcastable to logits
    Returns:
        [...] per-element loss
    """
    # torch.log(torch.sigmoid(x)) underflows to -inf (and yields inf loss /
    # NaN gradients) once |x| is large; logsigmoid computes the same value
    # in a numerically stable way.
    log_p = torch.nn.functional.logsigmoid(logits)
    # log(1 - sigmoid(x)) == log(sigmoid(-x))
    log_not_p = torch.nn.functional.logsigmoid(-logits)
    loss = -labels * log_p - (1 - labels) * log_not_p
    return loss
| |
|
| |
|
def torsion_angle_loss(
    a,
    a_gt,
    a_alt_gt,
):
    """Torsion-angle loss on predicted (sin, cos) pairs.

    Normalizes the predictions to the unit circle, penalizes the squared
    distance to the closer of the two ground-truth namings, and adds a
    small penalty keeping the raw predictions near unit norm.

    Args:
        a: [*, N, 7, 2] predicted (sin, cos) angle representations
        a_gt: [*, N, 7, 2] ground-truth angles
        a_alt_gt: [*, N, 7, 2] alternate (symmetry-equivalent) ground truth
    Returns:
        [*] combined loss
    """
    ANGLE_NORM_WEIGHT = 0.02

    raw_norm = torch.norm(a, dim=-1)
    # Project predictions onto the unit circle before comparison.
    unit_a = a / raw_norm.unsqueeze(-1)

    sq_err_gt = torch.norm(unit_a - a_gt, dim=-1) ** 2
    sq_err_alt = torch.norm(unit_a - a_alt_gt, dim=-1) ** 2
    closest_sq_err = torch.minimum(sq_err_gt, sq_err_alt)

    # Mean over the angle and residue dimensions.
    l_torsion = torch.mean(closest_sq_err, dim=(-1, -2))
    # Keep the unnormalized predictions close to unit length.
    l_angle_norm = torch.mean(torch.abs(raw_norm - 1), dim=(-1, -2))

    return l_torsion + ANGLE_NORM_WEIGHT * l_angle_norm
| |
|
| |
|
def compute_fape(
    pred_frames: Rigid,
    target_frames: Rigid,
    frames_mask: torch.Tensor,
    pred_positions: torch.Tensor,
    target_positions: torch.Tensor,
    positions_mask: torch.Tensor,
    length_scale: float,
    l1_clamp_distance: Optional[float] = None,
    eps: float = 1e-8,
    ignore_nan: bool = True,
) -> torch.Tensor:
    """
    Computes FAPE loss.

    Args:
        pred_frames:
            [*, N_frames] Rigid object of predicted frames
        target_frames:
            [*, N_frames] Rigid object of ground truth frames
        frames_mask:
            [*, N_frames] binary mask for the frames
        pred_positions:
            [*, N_pts, 3] predicted atom positions
        target_positions:
            [*, N_pts, 3] ground truth positions
        positions_mask:
            [*, N_pts] positions mask
        length_scale:
            Length scale by which the loss is divided
        l1_clamp_distance:
            Cutoff above which distance errors are disregarded
        eps:
            Small value used to regularize denominators
        ignore_nan:
            If True, NaN entries of the masked error are zeroed out
            before the reductions
    Returns:
        [*] loss tensor
    """
    # Express every atom position in the local coordinate system of every
    # frame: [*, N_frames, N_pts, 3].
    local_pred_pos = pred_frames.invert()[..., None].apply(
        pred_positions[..., None, :, :],
    )
    local_target_pos = target_frames.invert()[..., None].apply(
        target_positions[..., None, :, :],
    )

    # Per-(frame, point) aligned error; eps keeps sqrt differentiable at 0.
    error_dist = torch.sqrt(
        torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
    )

    # Optionally clamp so distant mistakes don't dominate the loss.
    if l1_clamp_distance is not None:
        error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)

    normed_error = error_dist / length_scale
    normed_error = normed_error * frames_mask[..., None]
    normed_error = normed_error * positions_mask[..., None, :]
    if ignore_nan:
        normed_error = torch.nan_to_num(normed_error)

    # Reduce over the point dim, then the frame dim, dividing by the two
    # mask totals — overall a masked mean over frame/point pairs
    # (eps guards empty masks).
    normed_error = torch.sum(normed_error, dim=-1)
    normed_error = (
        normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
    )
    normed_error = torch.sum(normed_error, dim=-1)
    normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
    return normed_error
| |
|
| |
|
def backbone_loss(
    backbone_rigid_tensor: torch.Tensor,
    backbone_rigid_mask: torch.Tensor,
    traj: torch.Tensor,
    use_clamped_fape: Optional[torch.Tensor] = None,
    clamp_distance: float = 10.0,
    loss_unit_distance: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
    """
    Backbone FAPE loss over a trajectory of predicted frames.

    Args:
        backbone_rigid_tensor:
            [*, N, 4, 4] ground-truth backbone frames
        backbone_rigid_mask:
            [*, N] mask for the ground-truth frames
        traj:
            predicted frames in 7-component (quaternion + translation)
            format; assumed to carry a leading structure-module iteration
            dimension, given the [None] broadcasting below — confirm
        use_clamped_fape:
            optional blend weight between clamped and unclamped FAPE
        clamp_distance:
            clamp threshold forwarded to compute_fape
        loss_unit_distance:
            length scale forwarded to compute_fape
        eps:
            small denominator guard
    Returns:
        scalar loss, averaged over all leading dimensions
    """
    pred_aff = Rigid.from_tensor_7(traj)
    # Re-wrap with an explicit rotation-matrix representation so the rest
    # of the computation does not operate on quaternions.
    pred_aff = Rigid(
        Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
        pred_aff.get_trans(),
    )

    gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)

    # FAPE on the frame origins themselves; the ground truth and mask are
    # unsqueezed with [None] to broadcast across pred_aff's leading dim.
    fape_loss = compute_fape(
        pred_aff,
        gt_aff[None],
        backbone_rigid_mask[None],
        pred_aff.get_trans(),
        gt_aff[None].get_trans(),
        backbone_rigid_mask[None],
        l1_clamp_distance=clamp_distance,
        length_scale=loss_unit_distance,
        eps=eps,
    )
    if use_clamped_fape is not None:
        # Per-example blend of clamped and unclamped FAPE.
        unclamped_fape_loss = compute_fape(
            pred_aff,
            gt_aff[None],
            backbone_rigid_mask[None],
            pred_aff.get_trans(),
            gt_aff[None].get_trans(),
            backbone_rigid_mask[None],
            l1_clamp_distance=None,
            length_scale=loss_unit_distance,
            eps=eps,
        )

        fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
            1 - use_clamped_fape
        )

    # Average over trajectory iterations (and any remaining batch dims).
    fape_loss = torch.mean(fape_loss)

    return fape_loss
| |
|
| |
|
def sidechain_loss(
    sidechain_frames: torch.Tensor,
    sidechain_atom_pos: torch.Tensor,
    rigidgroups_gt_frames: torch.Tensor,
    rigidgroups_alt_gt_frames: torch.Tensor,
    rigidgroups_gt_exists: torch.Tensor,
    renamed_atom14_gt_positions: torch.Tensor,
    renamed_atom14_gt_exists: torch.Tensor,
    alt_naming_is_better: torch.Tensor,
    clamp_distance: float = 10.0,
    length_scale: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
    """
    Sidechain FAPE loss over per-residue rigid-group frames.

    Args:
        sidechain_frames: predicted rigid-group frames; indexed with [-1]
            below, so assumed to carry a leading structure-module
            iteration dim, of which only the final one is scored — confirm
        sidechain_atom_pos: predicted atom14 positions, same leading dim
        rigidgroups_gt_frames: [*, N, G, 4, 4] ground-truth frames
        rigidgroups_alt_gt_frames: [*, N, G, 4, 4] alternate
            (symmetry-swapped) ground-truth frames
        rigidgroups_gt_exists: [*, N, G] frame mask
        renamed_atom14_gt_positions: [*, N, 14, 3] renamed gt positions
        renamed_atom14_gt_exists: [*, N, 14] renamed gt atom mask
        alt_naming_is_better: [*, N] 1.0 where the alternate naming fits
            the prediction better
        clamp_distance: clamp threshold forwarded to compute_fape
        length_scale: length scale forwarded to compute_fape
        eps: small denominator guard
    Returns:
        [*] sidechain FAPE
    """
    # Per residue, pick ground-truth frames from whichever atom naming
    # matches the prediction best.
    renamed_gt_frames = (
        1.0 - alt_naming_is_better[..., None, None, None]
    ) * rigidgroups_gt_frames + alt_naming_is_better[
        ..., None, None, None
    ] * rigidgroups_alt_gt_frames

    # Score only the final iteration; flatten residues x rigid groups into
    # one frame dimension and residues x atoms into one point dimension.
    sidechain_frames = sidechain_frames[-1]
    batch_dims = sidechain_frames.shape[:-4]
    sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
    sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames)
    renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
    renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
    rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
    sidechain_atom_pos = sidechain_atom_pos[-1]
    sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
    renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(
        *batch_dims, -1, 3
    )
    renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)

    fape = compute_fape(
        sidechain_frames,
        renamed_gt_frames,
        rigidgroups_gt_exists,
        sidechain_atom_pos,
        renamed_atom14_gt_positions,
        renamed_atom14_gt_exists,
        l1_clamp_distance=clamp_distance,
        length_scale=length_scale,
        eps=eps,
    )

    return fape
| |
|
| |
|
def fape_loss(
    out: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    config: ml_collections.ConfigDict,
) -> torch.Tensor:
    """Weighted sum of backbone and sidechain FAPE, averaged to a scalar."""
    sm_out = out["sm"]

    # Config entries override batch entries of the same name.
    backbone_kwargs = {**batch, **config.backbone}
    bb_term = backbone_loss(traj=sm_out["frames"], **backbone_kwargs)

    sidechain_kwargs = {**batch, **config.sidechain}
    sc_term = sidechain_loss(
        sm_out["sidechain_frames"],
        sm_out["positions"],
        **sidechain_kwargs,
    )

    weighted = (
        config.backbone.weight * bb_term
        + config.sidechain.weight * sc_term
    )

    return torch.mean(weighted)
| |
|
| |
|
def supervised_chi_loss(
    angles_sin_cos: torch.Tensor,
    unnormalized_angles_sin_cos: torch.Tensor,
    aatype: torch.Tensor,
    seq_mask: torch.Tensor,
    chi_mask: torch.Tensor,
    chi_angles_sin_cos: torch.Tensor,
    chi_weight: float,
    angle_norm_weight: float,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    """
    Implements Algorithm 27 (torsionAngleLoss)

    Args:
        angles_sin_cos:
            [*, N, 7, 2] predicted angles
        unnormalized_angles_sin_cos:
            The same angles, but unnormalized
        aatype:
            [*, N] residue indices
        seq_mask:
            [*, N] sequence mask
        chi_mask:
            [*, N, 7] angle mask
        chi_angles_sin_cos:
            [*, N, 7, 2] ground truth angles
        chi_weight:
            Weight for the angle component of the loss
        angle_norm_weight:
            Weight for the normalization component of the loss
    Returns:
        [*] loss tensor
    """
    # Drop the first 3 of the 7 torsions; keep the side-chain chi angles.
    pred_angles = angles_sin_cos[..., 3:, :]
    residue_type_one_hot = torch.nn.functional.one_hot(
        aatype,
        residue_constants.restype_num + 1,
    )
    # Per-residue flags for chi angles that are pi-periodic (selected from
    # the constant table via the residue-type one-hot).
    chi_pi_periodic = torch.einsum(
        "...ij,jk->ik",
        residue_type_one_hot.type(angles_sin_cos.dtype),
        angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
    )

    # [None] adds a broadcast dim; presumably angles_sin_cos carries a
    # leading structure-module iteration dim the target must match — confirm.
    true_chi = chi_angles_sin_cos[None]

    # A pi shift negates both sin and cos, hence the (1 - 2*periodic) sign.
    shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
    true_chi_shifted = shifted_mask * true_chi
    sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
    sq_chi_error_shifted = torch.sum(
        (true_chi_shifted - pred_angles) ** 2, dim=-1
    )
    # Score against the closer of the two periodic namings.
    sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
    # Move the leading broadcast dim next to the residue/angle dims so it
    # can be averaged out together with them below.
    sq_chi_error = sq_chi_error.permute(
        *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
    )
    sq_chi_loss = masked_mean(
        chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
    )

    loss = chi_weight * sq_chi_loss

    # Penalize deviation of the raw (unnormalized) predictions from unit norm.
    angle_norm = torch.sqrt(
        torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
    )
    norm_error = torch.abs(angle_norm - 1.0)
    norm_error = norm_error.permute(
        *range(len(norm_error.shape))[1:-2], 0, -2, -1
    )
    angle_norm_loss = masked_mean(
        seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
    )

    loss = loss + angle_norm_weight * angle_norm_loss

    # Average over the remaining (batch) dimensions.
    loss = torch.mean(loss)

    return loss
| |
|
| |
|
def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
    """Expected lDDT (scaled to 0-100) from binned confidence logits.

    Args:
        logits: [..., no_bins] confidence-head logits
    Returns:
        [...] expected per-residue pLDDT
    """
    no_bins = logits.shape[-1]
    step = 1.0 / no_bins
    # Midpoints of the no_bins uniform bins partitioning [0, 1].
    bin_centers = torch.arange(
        start=0.5 * step, end=1.0, step=step, device=logits.device
    )
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # Reshape centers for broadcasting against the leading dims of probs.
    shaped_centers = bin_centers.view(*((1,) * (logits.dim() - 1)), -1)
    expected = torch.sum(probs * shaped_centers, dim=-1)
    return expected * 100
| |
|
| |
|
def lddt(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """lDDT: fraction of preserved pairwise distances at 0.5/1/2/4 A tolerances.

    Args:
        all_atom_pred_pos: [*, N, 3] predicted positions
        all_atom_positions: [*, N, 3] ground-truth positions
        all_atom_mask: [*, N, 1] atom mask
        cutoff: only true-structure pairs closer than this are scored
        eps: numerical guard for sqrt and the normalization
        per_residue: if True return [*, N] scores, else [*] global scores
    Returns:
        lDDT score(s) in [0, 1]
    """
    n_res = all_atom_mask.shape[-2]

    def _pairwise_dist(pos):
        # [*, N, N] Euclidean distances; eps keeps sqrt differentiable.
        diff = pos[..., None, :] - pos[..., None, :, :]
        return torch.sqrt(eps + torch.sum(diff ** 2, dim=-1))

    dmat_true = _pairwise_dist(all_atom_positions)
    dmat_pred = _pairwise_dist(all_atom_pred_pos)

    # Score only pairs within the cutoff in the TRUE structure, where both
    # atoms exist, excluding self-pairs.
    pair_mask = (
        (dmat_true < cutoff)
        * all_atom_mask
        * all_atom_mask.transpose(-1, -2)
        * (1.0 - torch.eye(n_res, device=all_atom_mask.device))
    )

    dist_err = torch.abs(dmat_true - dmat_pred)

    # Fraction of the four tolerance thresholds each pair satisfies.
    preserved = sum(
        (dist_err < t).type(dist_err.dtype) for t in (0.5, 1.0, 2.0, 4.0)
    )
    preserved = preserved * 0.25

    reduce_dims = (-1,) if per_residue else (-2, -1)
    inv_norm = 1.0 / (eps + torch.sum(pair_mask, dim=reduce_dims))
    return inv_norm * (eps + torch.sum(pair_mask * preserved, dim=reduce_dims))
| |
|
| |
|
def lddt_ca(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """lDDT restricted to C-alpha atoms (one atom per residue).

    Args/Returns: see lddt; inputs here are full atom-wise tensors from
    which the CA column is selected.
    """
    ca_idx = residue_constants.atom_order["CA"]

    # Select the CA column; slicing the mask keeps its trailing singleton dim.
    ca_pred = all_atom_pred_pos[..., ca_idx, :]
    ca_true = all_atom_positions[..., ca_idx, :]
    ca_mask = all_atom_mask[..., ca_idx : (ca_idx + 1)]

    return lddt(
        ca_pred,
        ca_true,
        ca_mask,
        cutoff=cutoff,
        eps=eps,
        per_residue=per_residue,
    )
| |
|
| |
|
def lddt_loss(
    logits: torch.Tensor,
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    resolution: torch.Tensor,
    cutoff: float = 15.0,
    no_bins: int = 50,
    min_resolution: float = 0.1,
    max_resolution: float = 3.0,
    eps: float = 1e-10,
    **kwargs,
) -> torch.Tensor:
    """
    Confidence-head (pLDDT) training loss.

    The true per-residue CA-lDDT is computed from coordinates, discretized
    into no_bins classes, and the head's logits are scored against the
    resulting one-hot targets with cross entropy.

    Args:
        logits: [*, N, no_bins] confidence-head logits
        all_atom_pred_pos: [*, N, n_atoms, 3] predicted atom positions
        all_atom_positions: [*, N, n_atoms, 3] ground-truth positions
        all_atom_mask: [*, N, n_atoms] atom mask
        resolution: [*] experimental resolution per example
        cutoff: inclusion radius for the lDDT computation
        no_bins: number of lDDT bins
        min_resolution: loss is zeroed below this resolution
        max_resolution: loss is zeroed above this resolution
        eps: denominator guard
    Returns:
        scalar loss
    """
    # Restrict everything to C-alpha atoms.
    ca_pos = residue_constants.atom_order["CA"]
    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    all_atom_positions = all_atom_positions[..., ca_pos, :]
    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]

    score = lddt(
        all_atom_pred_pos,
        all_atom_positions,
        all_atom_mask,
        cutoff=cutoff,
        eps=eps
    )

    # The lDDT target is a label; no gradient flows into the structure.
    score = score.detach()

    # Discretize the [0, 1] score into no_bins one-hot targets (clamp keeps
    # a perfect score of 1.0 in the last bin).
    bin_index = torch.floor(score * no_bins).long()
    bin_index = torch.clamp(bin_index, max=(no_bins - 1))
    lddt_ca_one_hot = torch.nn.functional.one_hot(
        bin_index, num_classes=no_bins
    )

    errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
    all_atom_mask = all_atom_mask.squeeze(-1)
    # Masked mean over residues.
    loss = torch.sum(errors * all_atom_mask, dim=-1) / (
        eps + torch.sum(all_atom_mask, dim=-1)
    )

    # Zero out examples outside the trusted resolution range.
    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    # Average over the batch dimension(s).
    loss = torch.mean(loss)

    return loss
| |
|
| |
|
def distogram_loss(
    logits,
    pseudo_beta,
    pseudo_beta_mask,
    min_bin=2.3125,
    max_bin=21.6875,
    no_bins=64,
    eps=1e-6,
    **kwargs,
):
    """Masked cross entropy between distogram logits and binned true
    pseudo-beta pairwise distances.

    Args:
        logits: [*, N, N, no_bins] predicted distogram logits
        pseudo_beta: [*, N, 3] pseudo-beta coordinates
        pseudo_beta_mask: [*, N] residue mask
        min_bin, max_bin: range of the distance bin boundaries (Angstrom)
        no_bins: number of distance bins
        eps: denominator guard
    Returns:
        scalar loss
    """
    # Squared boundaries let us bin squared distances and skip the sqrt.
    sq_boundaries = (
        torch.linspace(
            min_bin,
            max_bin,
            no_bins - 1,
            device=logits.device,
        )
        ** 2
    )

    sq_dists = torch.sum(
        (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
        dim=-1,
        keepdims=True,
    )

    # Bin index = number of boundaries each squared distance exceeds.
    true_bins = torch.sum(sq_dists > sq_boundaries, dim=-1)

    errors = softmax_cross_entropy(
        logits,
        torch.nn.functional.one_hot(true_bins, no_bins),
    )

    pair_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]

    # Masked mean over the residue-pair dimensions.
    denom = eps + torch.sum(pair_mask, dim=(-1, -2))
    masked_errors = errors * pair_mask
    per_row = torch.sum(masked_errors, dim=-1)
    per_row = per_row / denom[..., None]
    per_example = torch.sum(per_row, dim=-1)

    # Average over the batch dimension(s).
    return torch.mean(per_example)
| |
|
| |
|
| | def _calculate_bin_centers(boundaries: torch.Tensor): |
| | step = boundaries[1] - boundaries[0] |
| | bin_centers = boundaries + step / 2 |
| | bin_centers = torch.cat( |
| | [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0 |
| | ) |
| | return bin_centers |
| |
|
| |
|
def _calculate_expected_aligned_error(
    alignment_confidence_breaks: torch.Tensor,
    aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Expected aligned error per residue pair, plus the largest
    representable error (the last bin center)."""
    bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
    expected_error = torch.sum(
        aligned_distance_error_probs * bin_centers, dim=-1
    )
    max_error = bin_centers[-1]
    return expected_error, max_error
| |
|
| |
|
def compute_predicted_aligned_error(
    logits: torch.Tensor,
    max_bin: int = 31,
    no_bins: int = 64,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes aligned confidence metrics from logits.

    Args:
        logits: [*, num_res, num_res, num_bins] the logits output from
            PredictedAlignedErrorHead.
        max_bin: Maximum bin value
        no_bins: Number of bins
    Returns:
        aligned_confidence_probs: [*, num_res, num_res, num_bins] the
            predicted aligned error probabilities over bins per residue pair.
        predicted_aligned_error: [*, num_res, num_res] the expected aligned
            distance error for each pair of residues.
        max_predicted_aligned_error: [*] the maximum predicted error
            possible.
    """
    breaks = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )

    probs = torch.nn.functional.softmax(logits, dim=-1)

    expected_error, max_error = _calculate_expected_aligned_error(
        alignment_confidence_breaks=breaks,
        aligned_distance_error_probs=probs,
    )

    return {
        "aligned_confidence_probs": probs,
        "predicted_aligned_error": expected_error,
        "max_predicted_aligned_error": max_error,
    }
| |
|
| |
|
def compute_tm(
    logits: torch.Tensor,
    residue_weights: Optional[torch.Tensor] = None,
    max_bin: int = 31,
    no_bins: int = 64,
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    """Predicted TM-score from pairwise aligned-error logits.

    Args:
        logits: [*, N, N, no_bins] predicted-aligned-error logits
        residue_weights: optional [N] per-residue weights (defaults to ones)
        max_bin: upper boundary of the explicit error bins
        no_bins: number of error bins
        eps: guards the weight normalization against an all-zero mask
    Returns:
        predicted TM-score of the best weighted alignment
    """
    if residue_weights is None:
        residue_weights = logits.new_ones(logits.shape[-2])

    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )

    bin_centers = _calculate_bin_centers(boundaries)
    # NOTE: a dead `torch.sum(residue_weights)` statement (result unused)
    # was removed here.
    n = logits.shape[-2]
    # d0 from the TM-score definition; clamping short chains keeps d0 in a
    # sensible range.
    clipped_n = max(n, 19)

    d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8

    probs = torch.nn.functional.softmax(logits, dim=-1)

    # Expected TM kernel value per residue pair.
    tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    normed_residue_mask = residue_weights / (eps + residue_weights.sum())
    per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)

    # Return the per-alignment score at the best weighted alignment.
    weighted = per_alignment * residue_weights
    argmax = (weighted == torch.max(weighted)).nonzero()[0]
    return per_alignment[tuple(argmax)]
| |
|
| |
|
def tm_loss(
    logits,
    final_affine_tensor,
    backbone_rigid_tensor,
    backbone_rigid_mask,
    resolution,
    max_bin=31,
    no_bins=64,
    min_resolution: float = 0.1,
    max_resolution: float = 3.0,
    eps=1e-8,
    **kwargs,
):
    """
    Training loss for the predicted-aligned-error (TM) head: cross entropy
    between the head's logits and the binned true pairwise aligned errors.

    Args:
        logits: [*, N, N, no_bins] predicted-aligned-error logits
        final_affine_tensor: predicted frames in 7-component format
        backbone_rigid_tensor: [*, N, 4, 4] ground-truth backbone frames
        backbone_rigid_mask: [*, N] frame mask
        resolution: [*] experimental resolution per example
        max_bin: upper boundary of the explicit error bins
        no_bins: number of error bins
        min_resolution, max_resolution: loss is zeroed outside this range
        eps: denominator guard
    Returns:
        scalar loss
    """
    pred_affine = Rigid.from_tensor_7(final_affine_tensor)
    backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)

    def _points(affine):
        # Express every frame origin in the local coordinates of every
        # frame: [*, N, N, 3].
        pts = affine.get_trans()[..., None, :, :]
        return affine.invert()[..., None].apply(pts)

    sq_diff = torch.sum(
        (_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1
    )

    # The binned targets are labels; no gradient flows through them.
    sq_diff = sq_diff.detach()

    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )
    # Compare squared errors against squared boundaries to avoid a sqrt;
    # bin index = number of boundaries exceeded.
    boundaries = boundaries ** 2
    true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)

    errors = softmax_cross_entropy(
        logits, torch.nn.functional.one_hot(true_bins, no_bins)
    )

    square_mask = (
        backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]
    )

    loss = torch.sum(errors * square_mask, dim=-1)
    # NOTE(review): the 0.5 scale appears in both the denominator and the
    # final product, so it largely cancels in the ratio; presumably a
    # guard against overflow of the intermediate sums (e.g. fp16) — confirm.
    scale = 0.5
    denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
    loss = loss / denom[..., None]
    loss = torch.sum(loss, dim=-1)
    loss = loss * scale

    # Zero out examples outside the trusted resolution range.
    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    # Average over the batch dimension(s).
    loss = torch.mean(loss)

    return loss
| |
|
| |
|
def between_residue_bond_loss(
    pred_atom_positions: torch.Tensor,
    pred_atom_mask: torch.Tensor,
    residue_index: torch.Tensor,
    aatype: torch.Tensor,
    tolerance_factor_soft=12.0,
    tolerance_factor_hard=12.0,
    eps=1e-6,
) -> Dict[str, torch.Tensor]:
    """Flat-bottom loss to penalize structural violations between residues.

    This is a loss penalizing any violation of the geometry around the peptide
    bond between consecutive amino acids. This loss corresponds to
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.

    Args:
      pred_atom_positions: Atom positions in atom37/14 representation
      pred_atom_mask: Atom mask in atom37/14 representation
      residue_index: Residue index for given amino acid, this is assumed to be
        monotonically increasing.
      aatype: Amino acid type of given residue
      tolerance_factor_soft: soft tolerance factor measured in standard deviations
        of pdb distributions
      tolerance_factor_hard: hard tolerance factor measured in standard deviations
        of pdb distributions
      eps: small value guarding sqrt gradients and denominators

    Returns:
      Dict containing:
        * 'c_n_loss_mean': Loss for peptide bond length violations
        * 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned
          by CA, C, N
        * 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned
          by C, N, CA
        * 'per_residue_loss_sum': sum of all losses for each residue
        * 'per_residue_violation_mask': mask denoting all residues with violation
          present.
    """
    # Backbone atom columns: 0 = N, 1 = CA, 2 = C (assumed from the atom14/37
    # conventions used elsewhere — confirm against residue_constants).
    this_ca_pos = pred_atom_positions[..., :-1, 1, :]
    this_ca_mask = pred_atom_mask[..., :-1, 1]
    this_c_pos = pred_atom_positions[..., :-1, 2, :]
    this_c_mask = pred_atom_mask[..., :-1, 2]
    next_n_pos = pred_atom_positions[..., 1:, 0, :]
    next_n_mask = pred_atom_mask[..., 1:, 0]
    next_ca_pos = pred_atom_positions[..., 1:, 1, :]
    next_ca_mask = pred_atom_mask[..., 1:, 1]
    # Only consecutive residue indices form a real peptide bond.
    has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0

    # --- C--N peptide bond length ---
    c_n_bond_length = torch.sqrt(
        eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)
    )

    # The bond to a following proline has a different reference length and
    # stddev, so select per-pair from the two-entry constant tables.
    next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
    gt_length = (
        ~next_is_proline
    ) * residue_constants.between_res_bond_length_c_n[
        0
    ] + next_is_proline * residue_constants.between_res_bond_length_c_n[
        1
    ]
    gt_stddev = (
        ~next_is_proline
    ) * residue_constants.between_res_bond_length_stddev_c_n[
        0
    ] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
        1
    ]
    c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
    # Flat-bottom: no penalty within tolerance_factor_soft stddevs.
    c_n_loss_per_residue = torch.nn.functional.relu(
        c_n_bond_length_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_c_mask * next_n_mask * has_no_gap_mask
    c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    c_n_violation_mask = mask * (
        c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
    )

    # --- Bond angles, compared in cosine space via unit vectors ---
    ca_c_bond_length = torch.sqrt(
        eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)
    )
    n_ca_bond_length = torch.sqrt(
        eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)
    )

    c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
    c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None]
    n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None]

    # CA-C-N angle.
    ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1)
    gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
    # NOTE(review): the stddev here is taken from the bond-length table
    # rather than the angle table; this mirrors the original AlphaFold
    # implementation — confirm it is intentional.
    gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
    ca_c_n_cos_angle_error = torch.sqrt(
        eps + (ca_c_n_cos_angle - gt_angle) ** 2
    )
    ca_c_n_loss_per_residue = torch.nn.functional.relu(
        ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
    ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    ca_c_n_violation_mask = mask * (
        ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)
    )

    # C-N-CA angle (note the sign flip to point the C->N vector backwards).
    c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
    gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
    gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
    c_n_ca_cos_angle_error = torch.sqrt(
        eps + torch.square(c_n_ca_cos_angle - gt_angle)
    )
    c_n_ca_loss_per_residue = torch.nn.functional.relu(
        c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
    c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    c_n_ca_violation_mask = mask * (
        c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
    )

    # Distribute each inter-residue penalty equally onto the two residues
    # that share the bond, by padding on either side and averaging.
    per_residue_loss_sum = (
        c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
    )
    per_residue_loss_sum = 0.5 * (
        torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
        + torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
    )

    # A residue is flagged if any of the three hard tolerances is violated
    # on either of its adjacent bonds.
    violation_mask = torch.max(
        torch.stack(
            [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
            dim=-2,
        ),
        dim=-2,
    )[0]
    violation_mask = torch.maximum(
        torch.nn.functional.pad(violation_mask, (0, 1)),
        torch.nn.functional.pad(violation_mask, (1, 0)),
    )

    return {
        "c_n_loss_mean": c_n_loss,
        "ca_c_n_loss_mean": ca_c_n_loss,
        "c_n_ca_loss_mean": c_n_ca_loss,
        "per_residue_loss_sum": per_residue_loss_sum,
        "per_residue_violation_mask": violation_mask,
    }
| |
|
| |
|
def between_residue_clash_loss(
    atom14_pred_positions: torch.Tensor,
    atom14_atom_exists: torch.Tensor,
    atom14_atom_radius: torch.Tensor,
    residue_index: torch.Tensor,
    overlap_tolerance_soft=1.5,
    overlap_tolerance_hard=1.5,
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """Loss to penalize steric clashes between residues.

    This is a loss penalizing any steric clashes due to non bonded atoms in
    different peptides coming too close. This loss corresponds to the part with
    different residues of
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.

    Args:
      atom14_pred_positions: Predicted positions of atoms in
        global prediction frame
      atom14_atom_exists: Mask denoting whether atom at positions exists for given
        amino acid type
      atom14_atom_radius: Van der Waals radius for each atom.
      residue_index: Residue index for given amino acid.
      overlap_tolerance_soft: Soft tolerance factor.
      overlap_tolerance_hard: Hard tolerance factor.
      eps: small value keeping the distance sqrt differentiable at 0

    Returns:
      Dict containing:
        * 'mean_loss': average clash loss
        * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14)
        * 'per_atom_clash_mask': mask whether atom clashes with any other atom
            shape (N, 14)
    """
    fp_type = atom14_pred_positions.dtype

    # All-vs-all atom distances across residue pairs:
    # [*, N, N, 14, 14].
    dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., :, None, :, None, :]
                - atom14_pred_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # Only score pairs where both atoms exist.
    dists_mask = (
        atom14_atom_exists[..., :, None, :, None]
        * atom14_atom_exists[..., None, :, None, :]
    ).type(fp_type)

    # Keep only the strict upper triangle over residues (i < j), so each
    # residue pair is counted once and same-residue pairs are excluded
    # (those are handled by within_residue_violations).
    dists_mask = dists_mask * (
        residue_index[..., :, None, None, None]
        < residue_index[..., None, :, None, None]
    )

    # One-hot selectors for the backbone C (column 2) and N (column 0)
    # atoms in the atom14 layout.
    c_one_hot = torch.nn.functional.one_hot(
        residue_index.new_tensor(2), num_classes=14
    )
    c_one_hot = c_one_hot.reshape(
        *((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape
    )
    c_one_hot = c_one_hot.type(fp_type)
    n_one_hot = torch.nn.functional.one_hot(
        residue_index.new_tensor(0), num_classes=14
    )
    n_one_hot = n_one_hot.reshape(
        *((1,) * len(residue_index.shape[:-1])), *n_one_hot.shape
    )
    n_one_hot = n_one_hot.type(fp_type)

    # Exempt the covalent C(i)--N(i+1) peptide bond between neighbours
    # from the clash penalty.
    neighbour_mask = (
        residue_index[..., :, None, None, None] + 1
    ) == residue_index[..., None, :, None, None]
    c_n_bonds = (
        neighbour_mask
        * c_one_hot[..., None, None, :, None]
        * n_one_hot[..., None, None, None, :]
    )
    dists_mask = dists_mask * (1.0 - c_n_bonds)

    # Likewise exempt potential disulfide bridges (SG--SG pairs).
    cys = residue_constants.restype_name_to_atom14_names["CYS"]
    cys_sg_idx = cys.index("SG")
    cys_sg_idx = residue_index.new_tensor(cys_sg_idx)
    cys_sg_idx = cys_sg_idx.reshape(
        *((1,) * len(residue_index.shape[:-1])), 1
    ).squeeze(-1)
    cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14)
    disulfide_bonds = (
        cys_sg_one_hot[..., None, None, :, None]
        * cys_sg_one_hot[..., None, None, None, :]
    )
    dists_mask = dists_mask * (1.0 - disulfide_bonds)

    # Minimum allowed separation: sum of the two Van der Waals radii.
    dists_lower_bound = dists_mask * (
        atom14_atom_radius[..., :, None, :, None]
        + atom14_atom_radius[..., None, :, None, :]
    )

    # Soft (flat-bottom) penalty for dipping below the lower bound.
    dists_to_low_error = dists_mask * torch.nn.functional.relu(
        dists_lower_bound - overlap_tolerance_soft - dists
    )

    # Mean over all scored pairs.
    mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask))

    # Attribute each pair's penalty to both participating atoms (the pair
    # appears only once in the i < j triangle, hence the two reductions).
    per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
        dists_to_low_error, axis=(-3, -1)
    )

    # Hard clashes, using the hard tolerance.
    clash_mask = dists_mask * (
        dists < (dists_lower_bound - overlap_tolerance_hard)
    )

    # An atom is flagged if it hard-clashes with any other atom, in either
    # role of the (i, j) pair.
    per_atom_clash_mask = torch.maximum(
        torch.amax(clash_mask, axis=(-4, -2)),
        torch.amax(clash_mask, axis=(-3, -1)),
    )

    return {
        "mean_loss": mean_loss,
        "per_atom_loss_sum": per_atom_loss_sum,
        "per_atom_clash_mask": per_atom_clash_mask,
    }
| |
|
| |
|
def within_residue_violations(
    atom14_pred_positions: torch.Tensor,
    atom14_atom_exists: torch.Tensor,
    atom14_dists_lower_bound: torch.Tensor,
    atom14_dists_upper_bound: torch.Tensor,
    tighten_bounds_for_loss=0.0,
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """Loss to penalize steric clashes within residues.

    This is a loss penalizing any steric violations or clashes of non-bonded atoms
    in a given peptide. This loss corresponds to the part with
    the same residues of
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.

    Args:
        atom14_pred_positions ([*, N, 14, 3]):
            Predicted positions of atoms in global prediction frame.
        atom14_atom_exists ([*, N, 14]):
            Mask denoting whether atom at positions exists for given
            amino acid type
        atom14_dists_lower_bound ([*, N, 14]):
            Lower bound on allowed distances.
        atom14_dists_upper_bound ([*, N, 14]):
            Upper bound on allowed distances
        tighten_bounds_for_loss ([*, N]):
            Extra factor to tighten loss
        eps:
            small value keeping the distance sqrt differentiable at 0

    Returns:
      Dict containing:
        * 'per_atom_loss_sum' ([*, N, 14]):
            sum of all clash losses per atom, shape
        * 'per_atom_clash_mask' ([*, N, 14]):
            mask whether atom clashes with any other atom shape
    """
    # Pair mask over the 14 atoms of each residue: both atoms must exist,
    # and self-pairs (the diagonal) are excluded.
    dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None]
    dists_masks = dists_masks.reshape(
        *((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape
    )
    dists_masks = (
        atom14_atom_exists[..., :, :, None]
        * atom14_atom_exists[..., :, None, :]
        * dists_masks
    )

    # Intra-residue atom-atom distances: [*, N, 14, 14].
    dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., :, :, None, :]
                - atom14_pred_positions[..., :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # Flat-bottom penalties for straying outside the (tightened) bounds.
    dists_to_low_error = torch.nn.functional.relu(
        atom14_dists_lower_bound + tighten_bounds_for_loss - dists
    )
    dists_to_high_error = torch.nn.functional.relu(
        dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)
    )
    loss = dists_masks * (dists_to_low_error + dists_to_high_error)

    # Attribute each pair's penalty to both participating atoms.
    per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)

    # Hard violations against the untightened bounds.
    violations = dists_masks * (
        (dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
    )

    # An atom is flagged if it violates a bound with any other atom.
    per_atom_violations = torch.maximum(
        torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0]
    )

    return {
        "per_atom_loss_sum": per_atom_loss_sum,
        "per_atom_violations": per_atom_violations,
    }
| |
|
| |
|
def find_structural_violations(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,
    violation_tolerance_factor: float,
    clash_overlap_tolerance: float,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes several checks for structural violations."""
    atom_exists = batch["atom14_atom_exists"]

    # (1) Bond-length / bond-angle violations between consecutive residues.
    connection_violations = between_residue_bond_loss(
        pred_atom_positions=atom14_pred_positions,
        pred_atom_mask=atom_exists,
        residue_index=batch["residue_index"],
        aatype=batch["aatype"],
        tolerance_factor_soft=violation_tolerance_factor,
        tolerance_factor_hard=violation_tolerance_factor,
    )

    # (2) Steric clashes between atoms of different residues. Van der Waals
    # radius per atom37 slot, looked up by the first letter of the atom name,
    # then gathered into atom14 order and masked to existing atoms.
    vdw_radii = atom14_pred_positions.new_tensor(
        [
            residue_constants.van_der_waals_radius[name[0]]
            for name in residue_constants.atom_types
        ]
    )
    atom14_atom_radius = atom_exists * vdw_radii[batch["residx_atom14_to_atom37"]]

    between_residue_clashes = between_residue_clash_loss(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=atom_exists,
        atom14_atom_radius=atom14_atom_radius,
        residue_index=batch["residue_index"],
        overlap_tolerance_soft=clash_overlap_tolerance,
        overlap_tolerance_hard=clash_overlap_tolerance,
    )

    # (3) Within-residue bound violations, using per-aatype distance bounds.
    bounds = residue_constants.make_atom14_dists_bounds(
        overlap_tolerance=clash_overlap_tolerance,
        bond_length_tolerance_factor=violation_tolerance_factor,
    )
    lower = atom14_pred_positions.new_tensor(bounds["lower_bound"])[batch["aatype"]]
    upper = atom14_pred_positions.new_tensor(bounds["upper_bound"])[batch["aatype"]]
    residue_violations = within_residue_violations(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=atom_exists,
        atom14_dists_lower_bound=lower,
        atom14_dists_upper_bound=upper,
        tighten_bounds_for_loss=0.0,
    )

    # A residue is in violation if any of its bond, clash, or bound checks is.
    per_residue_violations_mask = torch.max(
        torch.stack(
            [
                connection_violations["per_residue_violation_mask"],
                torch.max(
                    between_residue_clashes["per_atom_clash_mask"], dim=-1
                )[0],
                torch.max(residue_violations["per_atom_violations"], dim=-1)[0],
            ],
            dim=-1,
        ),
        dim=-1,
    )[0]

    between = {
        "bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"],
        "angles_ca_c_n_loss_mean": connection_violations["ca_c_n_loss_mean"],
        "angles_c_n_ca_loss_mean": connection_violations["c_n_ca_loss_mean"],
        "connections_per_residue_loss_sum": connection_violations[
            "per_residue_loss_sum"
        ],
        "connections_per_residue_violation_mask": connection_violations[
            "per_residue_violation_mask"
        ],
        "clashes_mean_loss": between_residue_clashes["mean_loss"],
        "clashes_per_atom_loss_sum": between_residue_clashes["per_atom_loss_sum"],
        "clashes_per_atom_clash_mask": between_residue_clashes[
            "per_atom_clash_mask"
        ],
    }
    within = {
        "per_atom_loss_sum": residue_violations["per_atom_loss_sum"],
        "per_atom_violations": residue_violations["per_atom_violations"],
    }
    return {
        "between_residues": between,
        "within_residues": within,
        "total_per_residue_violations_mask": per_residue_violations_mask,
    }
| |
|
| |
|
def find_structural_violations_np(
    batch: Dict[str, np.ndarray],
    atom14_pred_positions: np.ndarray,
    config: ml_collections.ConfigDict,
) -> Dict[str, np.ndarray]:
    """NumPy wrapper around find_structural_violations: converts inputs to
    tensors, runs the torch implementation, and converts outputs back."""
    as_tensor = lambda arr: torch.tensor(arr)
    result = find_structural_violations(
        tree_map(as_tensor, batch, np.ndarray),
        as_tensor(atom14_pred_positions),
        **config,
    )
    return tensor_tree_map(lambda t: np.array(t), result)
| |
|
| |
|
def extreme_ca_ca_distance_violations(
    pred_atom_positions: torch.Tensor,
    pred_atom_mask: torch.Tensor,
    residue_index: torch.Tensor,
    max_angstrom_tolerance=1.5,
    eps=1e-6,
) -> torch.Tensor:
    """Counts residues whose Ca is a large distance from its neighbour.

    Measures the fraction of CA-CA pairs between consecutive amino acids that
    exceed the reference CA-CA distance by more than `max_angstrom_tolerance`.

    Args:
        pred_atom_positions: Atom positions in atom37/14 representation
        pred_atom_mask: Atom mask in atom37/14 representation
        residue_index: Residue index for given amino acid, this is assumed to
            be monotonically increasing.
        max_angstrom_tolerance: Maximum distance allowed to not count as
            violation.
    Returns:
        Fraction of consecutive CA-CA pairs with violation.
    """
    # CA is atom index 1 in both atom37 and atom14 layouts.
    ca_pos_i = pred_atom_positions[..., :-1, 1, :]
    ca_mask_i = pred_atom_mask[..., :-1, 1]
    ca_pos_j = pred_atom_positions[..., 1:, 1, :]
    ca_mask_j = pred_atom_mask[..., 1:, 1]

    # Only score pairs that are truly consecutive in residue numbering.
    consecutive = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0

    ca_ca_dist = torch.sqrt(
        eps + torch.sum((ca_pos_i - ca_pos_j) ** 2, dim=-1)
    )
    is_violation = (
        ca_ca_dist - residue_constants.ca_ca
    ) > max_angstrom_tolerance
    pair_mask = ca_mask_i * ca_mask_j * consecutive
    return masked_mean(pair_mask, is_violation, -1)
| |
|
| |
|
def compute_violation_metrics(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,
    violations: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """Compute several metrics to assess the structural violations."""
    seq_mask = batch["seq_mask"]
    between = violations["between_residues"]
    within = violations["within_residues"]

    metrics = {
        # Fraction of consecutive CA-CA pairs that are extremely far apart.
        "violations_extreme_ca_ca_distance": extreme_ca_ca_distance_violations(
            pred_atom_positions=atom14_pred_positions,
            pred_atom_mask=batch["atom14_atom_exists"],
            residue_index=batch["residue_index"],
        ),
        # Masked means of the per-residue violation masks.
        "violations_between_residue_bond": masked_mean(
            seq_mask,
            between["connections_per_residue_violation_mask"],
            dim=-1,
        ),
        "violations_between_residue_clash": masked_mean(
            mask=seq_mask,
            value=torch.max(between["clashes_per_atom_clash_mask"], dim=-1)[0],
            dim=-1,
        ),
        "violations_within_residue": masked_mean(
            mask=seq_mask,
            value=torch.max(within["per_atom_violations"], dim=-1)[0],
            dim=-1,
        ),
        "violations_per_residue": masked_mean(
            mask=seq_mask,
            value=violations["total_per_residue_violations_mask"],
            dim=-1,
        ),
    }
    return metrics
| |
|
| |
|
def compute_violation_metrics_np(
    batch: Dict[str, np.ndarray],
    atom14_pred_positions: np.ndarray,
    violations: Dict[str, np.ndarray],
) -> Dict[str, np.ndarray]:
    """NumPy wrapper around compute_violation_metrics: converts inputs to
    tensors, runs the torch implementation, and converts outputs back."""
    as_tensor = lambda arr: torch.tensor(arr)
    metrics = compute_violation_metrics(
        tree_map(as_tensor, batch, np.ndarray),
        as_tensor(atom14_pred_positions),
        tree_map(as_tensor, violations, np.ndarray),
    )
    return tree_map(lambda t: np.array(t), metrics, torch.Tensor)
| |
|
| |
|
def violation_loss(
    violations: Dict[str, torch.Tensor],
    atom14_atom_exists: torch.Tensor,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    """Total structural-violation loss: the three mean bond-geometry terms
    plus all clash penalties (between- and within-residue) averaged over the
    number of existing atoms."""
    between = violations["between_residues"]
    within = violations["within_residues"]

    atom_count = torch.sum(atom14_atom_exists)
    clash_sum = torch.sum(
        between["clashes_per_atom_loss_sum"] + within["per_atom_loss_sum"]
    )
    clash_term = clash_sum / (eps + atom_count)

    return (
        between["bonds_c_n_loss_mean"]
        + between["angles_ca_c_n_loss_mean"]
        + between["angles_c_n_ca_loss_mean"]
        + clash_term
    )
| |
|
| |
|
def compute_renamed_ground_truth(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """
    Find optimal renaming of ground truth based on the predicted positions.

    Alg. 26 "renameSymmetricGroundTruthAtoms"

    This renamed ground truth is then used for all losses,
    such that each loss moves the atoms in the same direction.

    Args:
        batch: Dictionary containing:
            * atom14_gt_positions: Ground truth positions.
            * atom14_alt_gt_positions: Ground truth positions with renaming
              swaps.
            * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
              renaming swaps.
            * atom14_gt_exists: Mask for which atoms exist in ground truth.
            * atom14_alt_gt_exists: Mask for which atoms exist in ground
              truth after renaming.
            * atom14_atom_exists: Mask for whether each atom is part of the
              given amino acid type.
        atom14_pred_positions: Array of atom positions in global frame.
    Returns:
        Dictionary containing:
            alt_naming_is_better: Array with 1.0 where alternative swap is
                better.
            renamed_atom14_gt_positions: Array of optimal ground truth
                positions after renaming swaps are performed.
            renamed_atom14_gt_exists: Mask after renaming swap is performed.
    """

    def _cross_residue_dists(coords):
        # [*, N, 14, 3] -> [*, N, N, 14, 14] all-vs-all atom distances.
        delta = (
            coords[..., None, :, None, :]
            - coords[..., None, :, None, :, :]
        )
        return torch.sqrt(eps + torch.sum(delta ** 2, dim=-1))

    gt_positions = batch["atom14_gt_positions"]
    alt_gt_positions = batch["atom14_alt_gt_positions"]

    pred_dists = _cross_residue_dists(atom14_pred_positions)
    gt_dists = _cross_residue_dists(gt_positions)
    alt_gt_dists = _cross_residue_dists(alt_gt_positions)

    # Per-pair absolute distance error for both namings.
    err_gt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
    err_alt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)

    # Only compare ambiguous atoms of residue i against unambiguous,
    # existing atoms of residue j.
    gt_exists = batch["atom14_gt_exists"]
    is_ambiguous = batch["atom14_atom_is_ambiguous"]
    mask = (
        gt_exists[..., None, :, None]
        * is_ambiguous[..., None, :, None]
        * gt_exists[..., None, :, None, :]
        * (1.0 - is_ambiguous[..., None, :, None, :])
    )

    per_res_err = torch.sum(mask * err_gt, dim=(-1, -2, -3))
    alt_per_res_err = torch.sum(mask * err_alt, dim=(-1, -2, -3))

    # 1.0 where the swapped naming matches the prediction better.
    fp_type = atom14_pred_positions.dtype
    alt_naming_is_better = (alt_per_res_err < per_res_err).type(fp_type)

    # Blend positions and masks according to the chosen naming per residue.
    use_alt = alt_naming_is_better[..., None, None]
    renamed_positions = (1.0 - use_alt) * gt_positions + use_alt * alt_gt_positions

    use_alt_mask = alt_naming_is_better[..., None]
    renamed_exists = (
        (1.0 - use_alt_mask) * gt_exists
        + use_alt_mask * batch["atom14_alt_gt_exists"]
    )

    return {
        "alt_naming_is_better": alt_naming_is_better,
        "renamed_atom14_gt_positions": renamed_positions,
        "renamed_atom14_gt_exists": renamed_exists,
    }
| |
|
| |
|
def experimentally_resolved_loss(
    logits: torch.Tensor,
    atom37_atom_exists: torch.Tensor,
    all_atom_mask: torch.Tensor,
    resolution: torch.Tensor,
    min_resolution: float,
    max_resolution: float,
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    """Experimentally-resolved head loss: per-atom sigmoid cross-entropy
    against the resolved-atom mask, averaged over existing atoms and gated
    to structures inside the configured resolution window."""
    # Sigmoid cross-entropy against the resolved mask
    # (module-level sigmoid_cross_entropy inlined).
    log_p = torch.log(torch.sigmoid(logits))
    log_not_p = torch.log(torch.sigmoid(-logits))
    errors = -all_atom_mask * log_p - (1.0 - all_atom_mask) * log_not_p

    # Restrict to atoms that exist for the residue type, normalize by the
    # total existing-atom count, then sum over residues.
    per_residue = torch.sum(errors * atom37_atom_exists, dim=-1)
    per_residue = per_residue / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
    per_example = torch.sum(per_residue, dim=-1)

    # Zero the loss outside the supervised resolution range.
    per_example = per_example * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    return torch.mean(per_example)
| |
|
| |
|
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
    """
    Computes BERT-style masked MSA loss. Implements subsection 1.9.9.

    Args:
        logits: [*, N_seq, N_res, 23] predicted residue distribution
        true_msa: [*, N_seq, N_res] true MSA
        bert_mask: [*, N_seq, N_res] MSA mask
    Returns:
        Masked MSA loss
    """
    # Cross-entropy against the one-hot true MSA
    # (module-level softmax_cross_entropy inlined).
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    targets = torch.nn.functional.one_hot(true_msa, num_classes=23)
    errors = -1 * torch.sum(targets * log_probs, dim=-1)

    # Masked average over sequences and residues. The 0.5 scaling of the
    # denominator is cancelled exactly by the final multiply by `scale`,
    # so the net result equals a plain masked mean over the MSA.
    scale = 0.5
    masked_errors = torch.sum(errors * bert_mask, dim=-1)
    denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2))
    per_seq = masked_errors / denom[..., None]
    per_example = torch.sum(per_seq, dim=-1) * scale

    return torch.mean(per_example)
| |
|
| |
|
def compute_drmsd(structure_1, structure_2, mask=None):
    """Distance-matrix RMSD between two point sets.

    Args:
        structure_1: [*, N, 3] coordinates.
        structure_2: [*, N, 3] coordinates.
        mask: optional [*, N] binary mask; masked points are zeroed before
            the distance matrices are formed, and the normalization uses the
            per-example masked count.
    Returns:
        [*] dRMSD values (scalar for unbatched input).

    Fix over the original: with a mask and more than one batch element, the
    original evaluated `if n > 1` on a multi-element tensor, which raises.
    The tensor path now guards n <= 1 elementwise instead.
    """
    if mask is not None:
        structure_1 = structure_1 * mask[..., None]
        structure_2 = structure_2 * mask[..., None]

    # Pairwise distance matrices [*, N, N].
    d1 = structure_1[..., :, None, :] - structure_1[..., None, :, :]
    d2 = structure_2[..., :, None, :] - structure_2[..., None, :, :]
    d1 = torch.sqrt(torch.sum(d1 ** 2, dim=-1))
    d2 = torch.sqrt(torch.sum(d2 ** 2, dim=-1))

    drmsd = torch.sum((d1 - d2) ** 2, dim=(-1, -2))

    if mask is None:
        # n is a plain int here; keep the original scalar behavior.
        n = d1.shape[-1]
        drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
    else:
        # n is a per-example tensor; guard degenerate counts elementwise.
        n = torch.sum(mask, dim=-1)
        denom = n * (n - 1)
        drmsd = torch.where(
            denom > 0,
            drmsd / denom.clamp(min=1),
            torch.zeros_like(drmsd),
        )

    return torch.sqrt(drmsd)
| |
|
| |
|
def compute_drmsd_np(structure_1, structure_2, mask=None):
    """NumPy convenience wrapper: converts arrays to tensors and delegates
    to compute_drmsd."""
    mask_t = None if mask is None else torch.tensor(mask)
    return compute_drmsd(
        torch.tensor(structure_1),
        torch.tensor(structure_2),
        mask_t,
    )
| |
|
| |
|
def backbone_atom_loss(
    pred_atom37: torch.Tensor,
    batch: Dict[str, torch.Tensor],
    mask: torch.Tensor = None,
    eps: float = 1e-4,
    t_threshold: Optional[float] = None,
    **kwargs,
):
    """Mean squared error on the first 5 atom37 backbone atom positions.

    Ground-truth atoms are rebuilt from `batch['rigids_0']` and the
    ground-truth psi torsion via compute_backbone. When `t_threshold` is
    given, examples with t >= t_threshold contribute zero.
    """
    pred_bb = pred_atom37[:, :, :5]

    # Reconstruct ground-truth backbone coordinates from frames + psi.
    psi_gt = batch["torsion_angles_sin_cos"][..., 2, :]
    gt_atom37, atom37_mask, _, _ = compute_backbone(
        batch['rigids_0'], psi_gt, batch["aatype"]
    )
    gt_bb = gt_atom37[:, :, :5]
    bb_mask = atom37_mask[:, :, :5]

    if mask is not None:
        bb_mask = bb_mask * mask[..., None]

    sq_err = ((pred_bb - gt_bb) ** 2) * bb_mask[..., None]
    per_example = torch.sum(sq_err, dim=(-1, -2, -3)) / (
        bb_mask.sum(dim=(-1, -2)) + eps
    )

    if t_threshold is not None:
        per_example = per_example * (batch['t'] < t_threshold)
    return torch.mean(per_example)
| |
|
| |
|
def pairwise_distance_loss(
    pred_atom37: torch.Tensor,
    batch: Dict[str, torch.Tensor],
    mask: torch.Tensor = None,
    eps: float = 1e-4,
    t_threshold: Optional[float] = None,
    dist_threshold: float = 6.0,
    **kwargs,
):
    """Local pairwise-distance loss on the 5 backbone atoms.

    Compares predicted and ground-truth all-vs-all backbone atom distances,
    restricted to pairs whose predicted distance is below `dist_threshold`.
    When `t_threshold` is given, examples with t >= t_threshold contribute
    zero.
    """
    num_ex, n_res = pred_atom37.shape[:2]
    pred_bb = pred_atom37[:, :, :5].reshape(num_ex, -1, 3)

    # Ground-truth backbone atoms rebuilt from frames + psi torsion.
    psi_gt = batch["torsion_angles_sin_cos"][..., 2, :]
    gt_atom37, _, _, _ = compute_backbone(
        batch['rigids_0'], psi_gt, batch["aatype"]
    )
    gt_bb = gt_atom37[:, :, :5].reshape(num_ex, -1, 3)

    # Per-atom mask: residue mask broadcast over the 5 backbone atoms.
    atom_mask = batch['seq_mask']
    if mask is not None:
        atom_mask = atom_mask * mask
    atom_mask = torch.tile(atom_mask[:, :, None], (1, 1, 5)).view(num_ex, -1)

    gt_pwd = torch.linalg.norm(
        gt_bb[:, :, None, :] - gt_bb[:, None, :, :], dim=-1
    ) * atom_mask[..., None]
    pred_pwd = torch.linalg.norm(
        pred_bb[:, :, None, :] - pred_bb[:, None, :, :], dim=-1
    ) * atom_mask[..., None]

    # Valid pairs: both atoms unmasked and predicted to be nearby.
    pair_mask = atom_mask[:, :, None] * atom_mask[:, None, :]
    pair_mask = pair_mask * (pred_pwd < dist_threshold)

    # NOTE(review): the denominator subtracts n_res, presumably to discount
    # zero-error diagonal self-pairs, but the flattened diagonal actually
    # holds 5 * n_res entries — confirm intent before changing.
    per_example = torch.sum(
        (gt_pwd - pred_pwd) ** 2 * pair_mask, dim=(-1, -2)
    ) / (torch.sum(pair_mask, dim=(-1, -2)) - n_res + eps)

    if t_threshold is not None:
        per_example = per_example * (batch['t'] < t_threshold)
    return torch.mean(per_example)
| | |
| | |
| |
|
| | |
| |
|
| |
|
class ScoreMatchingLoss(nn.Module):
    """Aggregation of the various losses described in the supplement.

    Sums the diffusion score-matching terms (rotation and translation) with
    whichever auxiliary structure losses are enabled in the config
    (distogram, supervised chi, lddt, fape, tm, backbone, pwd), each scaled
    by its configured weight.
    """
    def __init__(self, config):
        # config: per-loss sub-configs (each exposing at least `weight` plus
        # term-specific fields such as translation.x0_threshold), and a
        # global `eps` used to stabilize the mask normalization in forward.
        super(ScoreMatchingLoss, self).__init__()
        self.config = config

    def forward(self, out, batch, _return_breakdown=False):
        """Compute the total weighted training loss.

        Args:
            out: model output dict; reads 'rot_score', 'trans_score',
                'rigids', and — when the corresponding aux losses are
                enabled — 'distogram_logits', 'sm', 'lddt_logits',
                'final_atom_positions', 'tm_logits', 'atom37'.
            batch: batch dict; reads 'seq_mask', 'fixed_mask', 'rot_score',
                'trans_score', their scalings, 'rigids_0', and 't'.
            _return_breakdown: when True, also return a dict of the
                individual (unweighted, detached) terms plus the total
                under key "loss".

        Returns:
            Scalar total loss, or (total, breakdown) when
            _return_breakdown is True.
        """
        # Residues contribute only if present in the sequence AND being
        # diffused (i.e. not held fixed).
        seq_mask = batch['seq_mask']
        diffuse_mask = 1. - batch['fixed_mask']
        loss_mask = seq_mask * diffuse_mask
        _denom = sum_except_batch(loss_mask) + self.config.eps

        # Zero predicted and target scores on fixed residues so they cannot
        # contribute to the residuals below.
        pred_rot_score = out['rot_score'] * diffuse_mask[..., None]
        pred_trans_score = out['trans_score'] * diffuse_mask[..., None]
        gt_rot_score = batch['rot_score'] * diffuse_mask[..., None]
        gt_trans_score = batch['trans_score'] * diffuse_mask[..., None]

        # Translation score-matching loss, normalized by the per-example
        # score scaling before taking the squared norm.
        trans_score_loss = (gt_trans_score - pred_trans_score) * loss_mask[..., None]
        trans_score_loss /= inflate_array_like(batch['trans_score_scaling'], trans_score_loss)
        trans_score_loss = torch.sum(trans_score_loss**2, dim=(-1, -2)) / _denom

        # Direct coordinate (x0) translation loss, used instead of the score
        # loss at small t (see the x0_threshold switch below).
        trans_x0_loss = (self.config.translation.coordinate_scaling *
            (batch['rigids_0'].get_trans() - out['rigids'].get_trans()) *
            loss_mask[..., None]
        )
        trans_x0_loss = torch.sum(trans_x0_loss**2, dim=(-1, -2)) / _denom
        # Per example: score loss above the threshold, x0 loss at or below.
        trans_loss = torch.mean(
            trans_score_loss * (batch['t'] > self.config.translation.x0_threshold) +
            trans_x0_loss * (batch['t'] <= self.config.translation.x0_threshold)
        )

        # Rotation score-matching loss, normalized analogously.
        rot_loss = (gt_rot_score - pred_rot_score) * loss_mask[..., None]
        rot_loss /= inflate_array_like(batch['rot_score_scaling'], rot_loss)
        rot_loss = torch.mean(torch.sum(rot_loss**2, dim=(-1, -2)) / _denom)

        # Deferred thunks: disabled losses are never evaluated, and every
        # term is computed inside the weighting loop below.
        loss_fns = {
            "translation": lambda: trans_loss,
            "rotation": lambda: rot_loss,
        }

        # Auxiliary losses, each gated by its config flag. Config entries
        # are splatted after the batch, so config keys override batch keys.
        if self.config.distogram.enabled:
            loss_fns["distogram"] = lambda: distogram_loss(
                logits=out["distogram_logits"],
                **{**batch, **self.config.distogram},
            )
        if self.config.supervised_chi.enabled:
            loss_fns["supervised_chi"] = lambda: supervised_chi_loss(
                out["sm"]["angles"],
                out["sm"]["unnormalized_angles"],
                **{**batch, **self.config.supervised_chi},
            )
        if self.config.lddt.enabled:
            loss_fns["lddt"] = lambda: lddt_loss(
                logits=out["lddt_logits"],
                all_atom_pred_pos=out["final_atom_positions"],
                **{**batch, **self.config.lddt},
            )
        if self.config.fape.enabled:
            loss_fns["fape"] = lambda: fape_loss(
                out,
                batch,
                self.config.fape,
            )
        if self.config.tm.enabled:
            loss_fns["tm"] = lambda: tm_loss(
                logits=out["tm_logits"],
                **{**batch, **out, **self.config.tm},
            )
        if self.config.backbone.enabled:
            loss_fns["backbone"] = lambda: backbone_atom_loss(
                pred_atom37=out["atom37"],
                batch=batch,
                mask=loss_mask,
                **self.config.backbone,
            )
        if self.config.pwd.enabled:
            loss_fns["pwd"] = lambda: pairwise_distance_loss(
                pred_atom37=out["atom37"],
                batch=batch,
                mask=loss_mask,
                **self.config.pwd,
            )

        cum_loss = 0.
        losses = {}
        for loss_name, loss_fn in loss_fns.items():
            weight = self.config[loss_name].weight
            loss = loss_fn()
            # Drop non-finite terms: the replacement is a fresh zero leaf
            # tensor, so this term contributes no gradient to the model.
            if torch.isnan(loss) or torch.isinf(loss):
                logging.warning(f"{loss_name} loss is NaN. Skipping...")
                loss = loss.new_tensor(0., requires_grad=True)
            cum_loss = cum_loss + weight * loss
            # Breakdown stores the unweighted, detached value per term.
            losses[loss_name] = loss.detach().clone()

        losses["loss"] = cum_loss.detach().clone()

        if not _return_breakdown:
            return cum_loss

        return cum_loss, losses
| |
|