| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from functools import partial |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from openfold.utils.feats import ( |
| | pseudo_beta_fn, |
| | build_extra_msa_feat, |
| | build_template_angle_feat, |
| | build_template_pair_feat, |
| | atom14_to_atom37, |
| | ) |
| | from openfold.model.embedders import ( |
| | InputEmbedder, |
| | RecyclingEmbedder, |
| | TemplateAngleEmbedder, |
| | TemplatePairEmbedder, |
| | ExtraMSAEmbedder, |
| | ) |
| | from openfold.model.evoformer import EvoformerStack, ExtraMSAStack |
| | from openfold.model.heads import AuxiliaryHeads |
| | import openfold.np.residue_constants as residue_constants |
| | from openfold.model.structure_module import StructureModule |
| | from openfold.model.template import ( |
| | TemplatePairStack, |
| | TemplatePointwiseAttention, |
| | ) |
| | from openfold.utils.loss import ( |
| | compute_plddt, |
| | ) |
| | from openfold.utils.tensor_utils import ( |
| | dict_multimap, |
| | tensor_tree_map, |
| | ) |
| |
|
| |
|
class AlphaFold(nn.Module):
    """
    Alphafold 2.

    Implements Algorithm 2 (but with training).
    """

    def __init__(self, config):
        """
        Args:
            config:
                A dict-like config object (like the one in config.py). Its
                `globals` section is kept as `self.globals` (read for
                `chunk_size`); its `model` section supplies the keyword
                arguments of every submodule and is kept as `self.config`.
        """
        super().__init__()

        self.globals = config.globals
        config = config.model
        template_config = config.template
        extra_msa_config = config.extra_msa

        # Input + recycling embedders
        self.input_embedder = InputEmbedder(
            **config["input_embedder"],
        )
        self.recycling_embedder = RecyclingEmbedder(
            **config["recycling_embedder"],
        )

        # Template embedding pipeline
        self.template_angle_embedder = TemplateAngleEmbedder(
            **template_config["template_angle_embedder"],
        )
        self.template_pair_embedder = TemplatePairEmbedder(
            **template_config["template_pair_embedder"],
        )
        self.template_pair_stack = TemplatePairStack(
            **template_config["template_pair_stack"],
        )
        self.template_pointwise_att = TemplatePointwiseAttention(
            **template_config["template_pointwise_attention"],
        )

        # Extra MSA pipeline
        self.extra_msa_embedder = ExtraMSAEmbedder(
            **extra_msa_config["extra_msa_embedder"],
        )
        self.extra_msa_stack = ExtraMSAStack(
            **extra_msa_config["extra_msa_stack"],
        )

        # Main trunk + structure module
        self.evoformer = EvoformerStack(
            **config["evoformer_stack"],
        )
        self.structure_module = StructureModule(
            **config["structure_module"],
        )

        self.aux_heads = AuxiliaryHeads(
            config["heads"],
        )

        # Only the `model` section of the config is retained here
        self.config = config

    def embed_templates(self, batch, z, pair_mask, templ_dim):
        """
        Embed the template features and reduce them to one pair embedding.

        Templates are embedded one at a time by slicing every tensor in
        `batch` along `templ_dim`; the per-template embeddings are then
        concatenated back along that dimension, run through the template
        pair stack and collapsed onto `z` with pointwise attention.

        Args:
            batch:
                Dict of template features (the "template_*" entries of the
                feature dict). Every tensor must carry its template
                dimension at index `templ_dim`.
            z:
                [*, N_res, N_res, C_z] pair representation
            pair_mask:
                [*, N_res, N_res] pair mask
            templ_dim:
                Index of the template dimension (the number of batch dims)
        Returns:
            Dict with "template_pair_embedding" and, if
            config.template.embed_angles is set, "template_angle_embedding".
        """
        # Embed the templates one at a time (a manual "vmap" over templ_dim)
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
        for i in range(n_templ):
            idx = batch["template_aatype"].new_tensor(i)
            # The lambda is consumed immediately, so binding idx late is safe
            single_template_feats = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, idx),
                batch,
            )

            single_template_embeds = {}
            if self.config.template.embed_angles:
                template_angle_feat = build_template_angle_feat(
                    single_template_feats,
                )

                # Later concatenated onto the MSA representation (dim=-3)
                a = self.template_angle_embedder(template_angle_feat)

                single_template_embeds["angle"] = a

            t = build_template_pair_feat(
                single_template_feats,
                inf=self.config.template.inf,
                eps=self.config.template.eps,
                **self.config.template.distogram,
            ).to(z.dtype)
            t = self.template_pair_embedder(t)

            single_template_embeds.update({"pair": t})

            template_embeds.append(single_template_embeds)

        # Re-assemble the per-template embeddings along the template dim
        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
            template_embeds,
        )

        # unsqueeze(-3) adds a template axis to the pair mask
        t = self.template_pair_stack(
            template_embeds["pair"],
            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
            _mask_trans=self.config._mask_trans,
        )

        # Collapse the template axis onto the pair representation
        t = self.template_pointwise_att(
            t,
            z,
            template_mask=batch["template_mask"].to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
        )

        # Zero the embedding for every sample that has no active template.
        # This is decided per sample, not over the whole batch.
        t = self._zero_templateless(t, batch["template_mask"])

        ret = {}
        if self.config.template.embed_angles:
            ret["template_angle_embedding"] = template_embeds["angle"]

        ret.update({"template_pair_embedding": t})

        return ret

    @staticmethod
    def _zero_templateless(t, template_mask):
        """
        Zero `t` for every batch element whose template mask is all zero.

        Args:
            t:
                [*, N_res, N_res, C_z] template pair embedding
            template_mask:
                [*, N_templ] template-level mask
        Returns:
            `t` with the entries of template-less batch elements set to zero
        """
        # [*]: True where the sample has at least one active template
        has_templates = torch.sum(template_mask, dim=-1) > 0
        # Append singleton dims so the mask broadcasts over t's trailing dims
        has_templates = has_templates.reshape(
            *has_templates.shape,
            *([1] * (t.dim() - has_templates.dim())),
        )
        return t * has_templates

    def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
        """
        Run a single recycling iteration of the network.

        Args:
            feats:
                Feature dict for this iteration, with the recycling dimension
                already removed. Float32 entries are cast in place to the
                dtype of the model parameters.
            m_1_prev:
                [*, N_res, C_m] first MSA row from the previous iteration, or
                None on the first iteration
            z_prev:
                [*, N_res, N_res, C_z] pair representation from the previous
                iteration, or None
            x_prev:
                [*, N_res, atom_type_num, 3] predicted atom positions
                (atom37 layout, see atom14_to_atom37) from the previous
                iteration, or None
            _recycle:
                If False, the recycling embeddings are zeroed out. They are
                zeroed rather than skipped.

            If any of m_1_prev / z_prev / x_prev is None, all three are
            re-initialized to zeros.
        Returns:
            (outputs, m_1_prev, z_prev, x_prev): the output dict of this
            iteration and the values to recycle into the next one.
        """
        # Primary output dictionary
        outputs = {}

        # Cast float32 inputs to the parameter dtype (e.g. half-precision
        # weights); integer/bool features are left untouched
        dtype = next(self.parameters()).dtype
        for k in feats:
            if feats[k].dtype == torch.float32:
                feats[k] = feats[k].to(dtype=dtype)

        # Grab some data about the input
        batch_dims = feats["target_feat"].shape[:-2]
        no_batch_dims = len(batch_dims)
        n = feats["target_feat"].shape[-2]
        n_seq = feats["msa_feat"].shape[-3]

        # Prep some features
        seq_mask = feats["seq_mask"]
        # [*, N_res, N_res]
        pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
        msa_mask = feats["msa_mask"]

        # Initialize the MSA (m) and pair (z) representations
        m, z = self.input_embedder(
            feats["target_feat"],
            feats["residue_index"],
            feats["msa_feat"],
        )

        # Nothing to recycle yet: start from zeros. Identity checks avoid
        # relying on Tensor.__eq__(None).
        if any(prev is None for prev in (m_1_prev, z_prev, x_prev)):
            # [*, N_res, C_m]
            m_1_prev = m.new_zeros(
                (*batch_dims, n, self.config.input_embedder.c_m),
                requires_grad=False,
            )

            # [*, N_res, N_res, C_z]
            z_prev = z.new_zeros(
                (*batch_dims, n, n, self.config.input_embedder.c_z),
                requires_grad=False,
            )

            # [*, N_res, atom_type_num, 3]
            x_prev = z.new_zeros(
                (*batch_dims, n, residue_constants.atom_type_num, 3),
                requires_grad=False,
            )

        # Reduce the atom positions to pseudo-beta positions for the
        # recycling embedder. NOTE(review): presumably [*, N_res, 3];
        # confirm in pseudo_beta_fn.
        x_prev = pseudo_beta_fn(
            feats["aatype"], x_prev, None
        ).to(dtype=z.dtype)

        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
            m_1_prev,
            z_prev,
            x_prev,
        )

        # The embedder always runs and its output is zeroed rather than
        # skipped, so its parameters stay in the autograd graph.
        # NOTE(review): likely for DDP's unused-parameter check; confirm.
        if not _recycle:
            m_1_prev_emb *= 0
            z_prev_emb *= 0

        # Recycled first-row embedding goes into the first MSA row
        m[..., 0, :, :] += m_1_prev_emb

        z += z_prev_emb

        # Release references early; may lower peak memory in later stages
        del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb

        # Embed the templates + merge with MSA/pair embeddings
        if self.config.template.enabled:
            template_feats = {
                k: v for k, v in feats.items() if k.startswith("template_")
            }
            template_embeds = self.embed_templates(
                template_feats,
                z,
                pair_mask.to(dtype=z.dtype),
                no_batch_dims,
            )

            z = z + template_embeds["template_pair_embedding"]

            if self.config.template.embed_angles:
                # Append the template rows to the MSA (sequence dim = -3)
                m = torch.cat(
                    [m, template_embeds["template_angle_embedding"]],
                    dim=-3
                )

                # Extend the MSA mask for the appended template rows using
                # index 2 of the torsion-angle mask (TODO: confirm which
                # angle index 2 denotes)
                torsion_angles_mask = feats["template_torsion_angles_mask"]
                msa_mask = torch.cat(
                    [feats["msa_mask"], torsion_angles_mask[..., 2]],
                    dim=-2
                )

        # Embed extra MSA features + merge with the pair representation
        if self.config.extra_msa.enabled:
            a = self.extra_msa_embedder(build_extra_msa_feat(feats))

            z = self.extra_msa_stack(
                a,
                z,
                msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
                chunk_size=self.globals.chunk_size,
                pair_mask=pair_mask.to(dtype=z.dtype),
                _mask_trans=self.config._mask_trans,
            )

        # Run MSA + pair representations through the trunk of the network
        m, z, s = self.evoformer(
            m,
            z,
            msa_mask=msa_mask.to(dtype=m.dtype),
            pair_mask=pair_mask.to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
            _mask_trans=self.config._mask_trans,
        )

        # Drop any appended template rows from the MSA output
        outputs["msa"] = m[..., :n_seq, :, :]
        outputs["pair"] = z
        outputs["single"] = s

        # Predict 3D structure
        outputs["sm"] = self.structure_module(
            s,
            z,
            feats["aatype"],
            mask=feats["seq_mask"].to(dtype=s.dtype),
        )
        outputs["final_atom_positions"] = atom14_to_atom37(
            outputs["sm"]["positions"][-1], feats
        )
        outputs["final_atom_mask"] = feats["atom37_atom_exists"]
        outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]

        # Save embeddings for use during the next recycling iteration
        m_1_prev = m[..., 0, :, :]
        z_prev = z
        x_prev = outputs["final_atom_positions"]

        return outputs, m_1_prev, z_prev, x_prev

    def _disable_activation_checkpointing(self):
        """Turn off activation checkpointing in every checkpointed stack."""
        self.template_pair_stack.blocks_per_ckpt = None
        self.evoformer.blocks_per_ckpt = None

        for b in self.extra_msa_stack.blocks:
            b.ckpt = False

    def _enable_activation_checkpointing(self):
        """Restore the configured activation-checkpointing settings."""
        self.template_pair_stack.blocks_per_ckpt = (
            self.config.template.template_pair_stack.blocks_per_ckpt
        )
        self.evoformer.blocks_per_ckpt = (
            self.config.evoformer_stack.blocks_per_ckpt
        )

        for b in self.extra_msa_stack.blocks:
            b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt

    def forward(self, batch):
        """
        Args:
            batch:
                Dictionary of arguments outlined in Algorithm 2. Keys must
                include the official names of the features in the
                supplement subsection 1.2.9.

                The final dimension of each input must have length equal to
                the number of recycling iterations.

                Features (without the recycling dimension):

                    "aatype" ([*, N_res]):
                        Contrary to the supplement, this tensor of residue
                        indices is not one-hot.
                    "target_feat" ([*, N_res, C_tf])
                        One-hot encoding of the target sequence. C_tf is
                        config.model.input_embedder.tf_dim.
                    "residue_index" ([*, N_res])
                        Tensor whose final dimension consists of
                        consecutive indices from 0 to N_res.
                    "msa_feat" ([*, N_seq, N_res, C_msa])
                        MSA features, constructed as in the supplement.
                        C_msa is config.model.input_embedder.msa_dim.
                    "seq_mask" ([*, N_res])
                        1-D sequence mask
                    "msa_mask" ([*, N_seq, N_res])
                        MSA mask
                    "pair_mask" ([*, N_res, N_res])
                        2-D pair mask. Not read by this module; the pair
                        mask is recomputed from "seq_mask".
                    "extra_msa_mask" ([*, N_extra, N_res])
                        Extra MSA mask
                    "template_mask" ([*, N_templ])
                        Template mask (on the level of templates, not
                        residues)
                    "template_aatype" ([*, N_templ, N_res])
                        Tensor of template residue indices (indices greater
                        than 19 are clamped to 20 (Unknown))
                    "template_all_atom_positions"
                        ([*, N_templ, N_res, 37, 3])
                        Template atom coordinates in atom37 format
                    "template_all_atom_mask" ([*, N_templ, N_res, 37])
                        Template atom coordinate mask
                    "template_pseudo_beta" ([*, N_templ, N_res, 3])
                        Positions of template carbon "pseudo-beta" atoms
                        (i.e. C_beta for all residues but glycine, for
                        which C_alpha is used instead)
                    "template_pseudo_beta_mask" ([*, N_templ, N_res])
                        Pseudo-beta mask
        Returns:
            Output dict of the final recycling iteration, updated with the
            outputs of the auxiliary heads.
        Raises:
            ValueError: if the recycling dimension of the batch is empty.
        """
        num_iters = batch["aatype"].shape[-1]
        if num_iters == 0:
            raise ValueError(
                "batch must contain at least one recycling iteration "
                "(the last dimension of \"aatype\" has length 0)"
            )

        # Nothing has been recycled yet
        m_1_prev, z_prev, x_prev = None, None, None

        # Activation checkpointing stays off until the final iteration
        is_grad_enabled = torch.is_grad_enabled()
        self._disable_activation_checkpointing()

        # Main recycling loop
        for cycle_no in range(num_iters):
            # Select this cycle's slice of every feature (lambda used at once)
            feats = tensor_tree_map(lambda t: t[..., cycle_no], batch)

            # Track gradients only on the final iteration, and only if the
            # caller had them enabled
            is_final_iter = cycle_no == (num_iters - 1)
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                if is_final_iter:
                    self._enable_activation_checkpointing()
                    # Drop autocast casts cached during the no-grad
                    # iterations before the grad-enabled one
                    # (NOTE(review): AMP workaround, PyTorch issue #65766)
                    if torch.is_autocast_enabled():
                        torch.clear_autocast_cache()

                outputs, m_1_prev, z_prev, x_prev = self.iteration(
                    feats,
                    m_1_prev,
                    z_prev,
                    x_prev,
                    _recycle=(num_iters > 1)
                )

        # Run auxiliary heads on the final iteration's outputs
        outputs.update(self.aux_heads(outputs))

        return outputs
| |
|