| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import datetime |
| | from multiprocessing import cpu_count |
| | from typing import Mapping, Optional, Sequence, Any |
| |
|
| | import numpy as np |
| |
|
| | from openfold.data import templates, parsers, mmcif_parsing |
| | from openfold.data.tools import jackhmmer, hhblits, hhsearch |
| | from openfold.data.tools.utils import to_date |
| | from openfold.np import residue_constants, protein |
| |
|
| |
|
| | FeatureDict = Mapping[str, np.ndarray] |
| |
|
| | def empty_template_feats(n_res) -> FeatureDict: |
| | return { |
| | "template_aatype": np.zeros((0, n_res)).astype(np.int64), |
| | "template_all_atom_positions": |
| | np.zeros((0, n_res, 37, 3)).astype(np.float32), |
| | "template_sum_probs": np.zeros((0, 1)).astype(np.float32), |
| | "template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32), |
| | } |
| |
|
| |
|
| | def make_template_features( |
| | input_sequence: str, |
| | hits: Sequence[Any], |
| | template_featurizer: Any, |
| | query_pdb_code: Optional[str] = None, |
| | query_release_date: Optional[str] = None, |
| | ) -> FeatureDict: |
| | hits_cat = sum(hits.values(), []) |
| | if(len(hits_cat) == 0 or template_featurizer is None): |
| | template_features = empty_template_feats(len(input_sequence)) |
| | else: |
| | templates_result = template_featurizer.get_templates( |
| | query_sequence=input_sequence, |
| | query_pdb_code=query_pdb_code, |
| | query_release_date=query_release_date, |
| | hits=hits_cat, |
| | ) |
| | template_features = templates_result.features |
| |
|
| | |
| | |
| | if(template_features["template_aatype"].shape[0] == 0): |
| | template_features = empty_template_feats(len(input_sequence)) |
| |
|
| | return template_features |
| |
|
| |
|
| | def make_sequence_features( |
| | sequence: str, description: str, num_res: int |
| | ) -> FeatureDict: |
| | """Construct a feature dict of sequence features.""" |
| | features = {} |
| | features["aatype"] = residue_constants.sequence_to_onehot( |
| | sequence=sequence, |
| | mapping=residue_constants.restype_order_with_x, |
| | map_unknown_to_x=True, |
| | ) |
| | features["between_segment_residues"] = np.zeros((num_res,), dtype=int) |
| | features["domain_name"] = np.array( |
| | [description.encode("utf-8")], dtype=np.object_ |
| | ) |
| | features["residue_index"] = np.array(range(num_res), dtype=int) |
| | features["seq_length"] = np.array([num_res] * num_res, dtype=int) |
| | features["sequence"] = np.array( |
| | [sequence.encode("utf-8")], dtype=np.object_ |
| | ) |
| | return features |
| |
|
| |
|
| | def make_mmcif_features( |
| | mmcif_object: mmcif_parsing.MmcifObject, chain_id: str |
| | ) -> FeatureDict: |
| | input_sequence = mmcif_object.chain_to_seqres[chain_id] |
| | description = "_".join([mmcif_object.file_id, chain_id]) |
| | num_res = len(input_sequence) |
| |
|
| | mmcif_feats = {} |
| |
|
| | mmcif_feats.update( |
| | make_sequence_features( |
| | sequence=input_sequence, |
| | description=description, |
| | num_res=num_res, |
| | ) |
| | ) |
| |
|
| | all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords( |
| | mmcif_object=mmcif_object, chain_id=chain_id |
| | ) |
| | mmcif_feats["all_atom_positions"] = all_atom_positions |
| | mmcif_feats["all_atom_mask"] = all_atom_mask |
| |
|
| | mmcif_feats["resolution"] = np.array( |
| | [mmcif_object.header["resolution"]], dtype=np.float32 |
| | ) |
| |
|
| | mmcif_feats["release_date"] = np.array( |
| | [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_ |
| | ) |
| |
|
| | mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32) |
| |
|
| | return mmcif_feats |
| |
|
| |
|
| | def _aatype_to_str_sequence(aatype): |
| | return ''.join([ |
| | residue_constants.restypes_with_x[aatype[i]] |
| | for i in range(len(aatype)) |
| | ]) |
| |
|
| |
|
| | def make_protein_features( |
| | protein_object: protein.Protein, |
| | description: str, |
| | _is_distillation: bool = False, |
| | ) -> FeatureDict: |
| | pdb_feats = {} |
| | aatype = protein_object.aatype |
| | sequence = _aatype_to_str_sequence(aatype) |
| | pdb_feats.update( |
| | make_sequence_features( |
| | sequence=sequence, |
| | description=description, |
| | num_res=len(protein_object.aatype), |
| | ) |
| | ) |
| |
|
| | all_atom_positions = protein_object.atom_positions |
| | all_atom_mask = protein_object.atom_mask |
| |
|
| | pdb_feats["all_atom_positions"] = all_atom_positions.astype(np.float32) |
| | pdb_feats["all_atom_mask"] = all_atom_mask.astype(np.float32) |
| |
|
| | pdb_feats["resolution"] = np.array([0.]).astype(np.float32) |
| | pdb_feats["is_distillation"] = np.array( |
| | 1. if _is_distillation else 0. |
| | ).astype(np.float32) |
| |
|
| | return pdb_feats |
| |
|
| |
|
| | def make_pdb_features( |
| | protein_object: protein.Protein, |
| | description: str, |
| | confidence_threshold: float = 0.5, |
| | is_distillation: bool = True, |
| | ) -> FeatureDict: |
| | pdb_feats = make_protein_features( |
| | protein_object, description, _is_distillation=True |
| | ) |
| |
|
| | if(is_distillation): |
| | high_confidence = protein_object.b_factors > confidence_threshold |
| | high_confidence = np.any(high_confidence, axis=-1) |
| | for i, confident in enumerate(high_confidence): |
| | if(not confident): |
| | pdb_feats["all_atom_mask"][i] = 0 |
| |
|
| | return pdb_feats |
| |
|
| |
|
| | def make_msa_features( |
| | msas: Sequence[Sequence[str]], |
| | deletion_matrices: Sequence[parsers.DeletionMatrix], |
| | ) -> FeatureDict: |
| | """Constructs a feature dict of MSA features.""" |
| | if not msas: |
| | raise ValueError("At least one MSA must be provided.") |
| |
|
| | int_msa = [] |
| | deletion_matrix = [] |
| | seen_sequences = set() |
| | for msa_index, msa in enumerate(msas): |
| | if not msa: |
| | raise ValueError( |
| | f"MSA {msa_index} must contain at least one sequence." |
| | ) |
| | for sequence_index, sequence in enumerate(msa): |
| | if sequence in seen_sequences: |
| | continue |
| | seen_sequences.add(sequence) |
| | int_msa.append( |
| | [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence] |
| | ) |
| | deletion_matrix.append(deletion_matrices[msa_index][sequence_index]) |
| |
|
| | num_res = len(msas[0][0]) |
| | num_alignments = len(int_msa) |
| | features = {} |
| | features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=int) |
| | features["msa"] = np.array(int_msa, dtype=int) |
| | features["num_alignments"] = np.array( |
| | [num_alignments] * num_res, dtype=int |
| | ) |
| | return features |
| |
|
| |
|
| | class AlignmentRunner: |
| | """Runs alignment tools and saves the results""" |
| | def __init__( |
| | self, |
| | jackhmmer_binary_path: Optional[str] = None, |
| | hhblits_binary_path: Optional[str] = None, |
| | hhsearch_binary_path: Optional[str] = None, |
| | uniref90_database_path: Optional[str] = None, |
| | mgnify_database_path: Optional[str] = None, |
| | bfd_database_path: Optional[str] = None, |
| | uniclust30_database_path: Optional[str] = None, |
| | pdb70_database_path: Optional[str] = None, |
| | use_small_bfd: Optional[bool] = None, |
| | no_cpus: Optional[int] = None, |
| | uniref_max_hits: int = 10000, |
| | mgnify_max_hits: int = 5000, |
| | ): |
| | """ |
| | Args: |
| | jackhmmer_binary_path: |
| | Path to jackhmmer binary |
| | hhblits_binary_path: |
| | Path to hhblits binary |
| | hhsearch_binary_path: |
| | Path to hhsearch binary |
| | uniref90_database_path: |
| | Path to uniref90 database. If provided, jackhmmer_binary_path |
| | must also be provided |
| | mgnify_database_path: |
| | Path to mgnify database. If provided, jackhmmer_binary_path |
| | must also be provided |
| | bfd_database_path: |
| | Path to BFD database. Depending on the value of use_small_bfd, |
| | one of hhblits_binary_path or jackhmmer_binary_path must be |
| | provided. |
| | uniclust30_database_path: |
| | Path to uniclust30. Searched alongside BFD if use_small_bfd is |
| | false. |
| | pdb70_database_path: |
| | Path to pdb70 database. |
| | use_small_bfd: |
| | Whether to search the BFD database alone with jackhmmer or |
| | in conjunction with uniclust30 with hhblits. |
| | no_cpus: |
| | The number of CPUs available for alignment. By default, all |
| | CPUs are used. |
| | uniref_max_hits: |
| | Max number of uniref hits |
| | mgnify_max_hits: |
| | Max number of mgnify hits |
| | """ |
| | db_map = { |
| | "jackhmmer": { |
| | "binary": jackhmmer_binary_path, |
| | "dbs": [ |
| | uniref90_database_path, |
| | mgnify_database_path, |
| | bfd_database_path if use_small_bfd else None, |
| | ], |
| | }, |
| | "hhblits": { |
| | "binary": hhblits_binary_path, |
| | "dbs": [ |
| | bfd_database_path if not use_small_bfd else None, |
| | ], |
| | }, |
| | "hhsearch": { |
| | "binary": hhsearch_binary_path, |
| | "dbs": [ |
| | pdb70_database_path, |
| | ], |
| | }, |
| | } |
| |
|
| | for name, dic in db_map.items(): |
| | binary, dbs = dic["binary"], dic["dbs"] |
| | if(binary is None and not all([x is None for x in dbs])): |
| | raise ValueError( |
| | f"{name} DBs provided but {name} binary is None" |
| | ) |
| |
|
| | if(not all([x is None for x in db_map["hhsearch"]["dbs"]]) |
| | and uniref90_database_path is None): |
| | raise ValueError( |
| | """uniref90_database_path must be specified in order to perform |
| | template search""" |
| | ) |
| |
|
| | self.uniref_max_hits = uniref_max_hits |
| | self.mgnify_max_hits = mgnify_max_hits |
| | self.use_small_bfd = use_small_bfd |
| |
|
| | if(no_cpus is None): |
| | no_cpus = cpu_count() |
| |
|
| | self.jackhmmer_uniref90_runner = None |
| | if(jackhmmer_binary_path is not None and |
| | uniref90_database_path is not None |
| | ): |
| | self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( |
| | binary_path=jackhmmer_binary_path, |
| | database_path=uniref90_database_path, |
| | n_cpu=no_cpus, |
| | ) |
| | |
| | self.jackhmmer_small_bfd_runner = None |
| | self.hhblits_bfd_uniclust_runner = None |
| | if(bfd_database_path is not None): |
| | if use_small_bfd: |
| | self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer( |
| | binary_path=jackhmmer_binary_path, |
| | database_path=bfd_database_path, |
| | n_cpu=no_cpus, |
| | ) |
| | else: |
| | dbs = [bfd_database_path] |
| | if(uniclust30_database_path is not None): |
| | dbs.append(uniclust30_database_path) |
| | self.hhblits_bfd_uniclust_runner = hhblits.HHBlits( |
| | binary_path=hhblits_binary_path, |
| | databases=dbs, |
| | n_cpu=no_cpus, |
| | ) |
| |
|
| | self.jackhmmer_mgnify_runner = None |
| | if(mgnify_database_path is not None): |
| | self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer( |
| | binary_path=jackhmmer_binary_path, |
| | database_path=mgnify_database_path, |
| | n_cpu=no_cpus, |
| | ) |
| |
|
| | self.hhsearch_pdb70_runner = None |
| | if(pdb70_database_path is not None): |
| | self.hhsearch_pdb70_runner = hhsearch.HHSearch( |
| | binary_path=hhsearch_binary_path, |
| | databases=[pdb70_database_path], |
| | n_cpu=no_cpus, |
| | ) |
| |
|
| | def run( |
| | self, |
| | fasta_path: str, |
| | output_dir: str, |
| | ): |
| | """Runs alignment tools on a sequence""" |
| | if(self.jackhmmer_uniref90_runner is not None): |
| | jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query( |
| | fasta_path |
| | )[0] |
| | uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( |
| | jackhmmer_uniref90_result["sto"], |
| | max_sequences=self.uniref_max_hits |
| | ) |
| | uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m") |
| | with open(uniref90_out_path, "w") as f: |
| | f.write(uniref90_msa_as_a3m) |
| |
|
| | if(self.hhsearch_pdb70_runner is not None): |
| | hhsearch_result = self.hhsearch_pdb70_runner.query( |
| | uniref90_msa_as_a3m |
| | ) |
| | pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr") |
| | with open(pdb70_out_path, "w") as f: |
| | f.write(hhsearch_result) |
| |
|
| | if(self.jackhmmer_mgnify_runner is not None): |
| | jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query( |
| | fasta_path |
| | )[0] |
| | mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m( |
| | jackhmmer_mgnify_result["sto"], |
| | max_sequences=self.mgnify_max_hits |
| | ) |
| | mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m") |
| | with open(mgnify_out_path, "w") as f: |
| | f.write(mgnify_msa_as_a3m) |
| |
|
| | if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None): |
| | jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query( |
| | fasta_path |
| | )[0] |
| | bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto") |
| | with open(bfd_out_path, "w") as f: |
| | f.write(jackhmmer_small_bfd_result["sto"]) |
| | elif(self.hhblits_bfd_uniclust_runner is not None): |
| | hhblits_bfd_uniclust_result = ( |
| | self.hhblits_bfd_uniclust_runner.query(fasta_path) |
| | ) |
| | if output_dir is not None: |
| | bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m") |
| | with open(bfd_out_path, "w") as f: |
| | f.write(hhblits_bfd_uniclust_result["a3m"]) |
| |
|
| |
|
| | class DataPipeline: |
| | """Assembles input features.""" |
| | def __init__( |
| | self, |
| | template_featurizer: Optional[templates.TemplateHitFeaturizer], |
| | ): |
| | self.template_featurizer = template_featurizer |
| |
|
| | def _parse_msa_data( |
| | self, |
| | alignment_dir: str, |
| | _alignment_index: Optional[Any] = None, |
| | ) -> Mapping[str, Any]: |
| | msa_data = {} |
| | |
| | if(_alignment_index is not None): |
| | fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb") |
| |
|
| | def read_msa(start, size): |
| | fp.seek(start) |
| | msa = fp.read(size).decode("utf-8") |
| | return msa |
| |
|
| | for (name, start, size) in _alignment_index["files"]: |
| | ext = os.path.splitext(name)[-1] |
| |
|
| | if(ext == ".a3m"): |
| | msa, deletion_matrix = parsers.parse_a3m( |
| | read_msa(start, size) |
| | ) |
| | data = {"msa": msa, "deletion_matrix": deletion_matrix} |
| | elif(ext == ".sto"): |
| | msa, deletion_matrix, _ = parsers.parse_stockholm( |
| | read_msa(start, size) |
| | ) |
| | data = {"msa": msa, "deletion_matrix": deletion_matrix} |
| | else: |
| | continue |
| | |
| | msa_data[name] = data |
| | |
| | fp.close() |
| | else: |
| | for f in os.listdir(alignment_dir): |
| | path = os.path.join(alignment_dir, f) |
| | ext = os.path.splitext(f)[-1] |
| |
|
| | if(ext == ".a3m"): |
| | with open(path, "r") as fp: |
| | msa, deletion_matrix = parsers.parse_a3m(fp.read()) |
| | data = {"msa": msa, "deletion_matrix": deletion_matrix} |
| | elif(ext == ".sto"): |
| | with open(path, "r") as fp: |
| | msa, deletion_matrix, _ = parsers.parse_stockholm( |
| | fp.read() |
| | ) |
| | data = {"msa": msa, "deletion_matrix": deletion_matrix} |
| | else: |
| | continue |
| | |
| | msa_data[f] = data |
| |
|
| | return msa_data |
| |
|
| | def _parse_template_hits( |
| | self, |
| | alignment_dir: str, |
| | _alignment_index: Optional[Any] = None |
| | ) -> Mapping[str, Any]: |
| | all_hits = {} |
| | if(_alignment_index is not None): |
| | fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb') |
| |
|
| | def read_template(start, size): |
| | fp.seek(start) |
| | return fp.read(size).decode("utf-8") |
| |
|
| | for (name, start, size) in _alignment_index["files"]: |
| | ext = os.path.splitext(name)[-1] |
| |
|
| | if(ext == ".hhr"): |
| | hits = parsers.parse_hhr(read_template(start, size)) |
| | all_hits[name] = hits |
| |
|
| | fp.close() |
| | else: |
| | for f in os.listdir(alignment_dir): |
| | path = os.path.join(alignment_dir, f) |
| | ext = os.path.splitext(f)[-1] |
| |
|
| | if(ext == ".hhr"): |
| | with open(path, "r") as fp: |
| | hits = parsers.parse_hhr(fp.read()) |
| | all_hits[f] = hits |
| |
|
| | return all_hits |
| |
|
| | def _process_msa_feats( |
| | self, |
| | alignment_dir: str, |
| | input_sequence: Optional[str] = None, |
| | _alignment_index: Optional[str] = None |
| | ) -> Mapping[str, Any]: |
| | msa_data = self._parse_msa_data(alignment_dir, _alignment_index) |
| | |
| | if(len(msa_data) == 0): |
| | if(input_sequence is None): |
| | raise ValueError( |
| | """ |
| | If the alignment dir contains no MSAs, an input sequence |
| | must be provided. |
| | """ |
| | ) |
| | msa_data["dummy"] = { |
| | "msa": [input_sequence], |
| | "deletion_matrix": [[0 for _ in input_sequence]], |
| | } |
| |
|
| | msas, deletion_matrices = zip(*[ |
| | (v["msa"], v["deletion_matrix"]) for v in msa_data.values() |
| | ]) |
| |
|
| | msa_features = make_msa_features( |
| | msas=msas, |
| | deletion_matrices=deletion_matrices, |
| | ) |
| |
|
| | return msa_features |
| |
|
| | def process_fasta( |
| | self, |
| | fasta_path: str, |
| | alignment_dir: str, |
| | _alignment_index: Optional[str] = None, |
| | ) -> FeatureDict: |
| | """Assembles features for a single sequence in a FASTA file""" |
| | with open(fasta_path) as f: |
| | fasta_str = f.read() |
| | input_seqs, input_descs = parsers.parse_fasta(fasta_str) |
| | if len(input_seqs) != 1: |
| | raise ValueError( |
| | f"More than one input sequence found in {fasta_path}." |
| | ) |
| | input_sequence = input_seqs[0] |
| | input_description = input_descs[0] |
| | num_res = len(input_sequence) |
| |
|
| | hits = self._parse_template_hits(alignment_dir, _alignment_index) |
| | template_features = make_template_features( |
| | input_sequence, |
| | hits, |
| | self.template_featurizer, |
| | ) |
| |
|
| | sequence_features = make_sequence_features( |
| | sequence=input_sequence, |
| | description=input_description, |
| | num_res=num_res, |
| | ) |
| |
|
| | msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index) |
| | |
| | return { |
| | **sequence_features, |
| | **msa_features, |
| | **template_features |
| | } |
| |
|
| | def process_mmcif( |
| | self, |
| | mmcif: mmcif_parsing.MmcifObject, |
| | alignment_dir: str, |
| | chain_id: Optional[str] = None, |
| | _alignment_index: Optional[str] = None, |
| | ) -> FeatureDict: |
| | """ |
| | Assembles features for a specific chain in an mmCIF object. |
| | |
| | If chain_id is None, it is assumed that there is only one chain |
| | in the object. Otherwise, a ValueError is thrown. |
| | """ |
| | if chain_id is None: |
| | chains = mmcif.structure.get_chains() |
| | chain = next(chains, None) |
| | if chain is None: |
| | raise ValueError("No chains in mmCIF file") |
| | chain_id = chain.id |
| |
|
| | mmcif_feats = make_mmcif_features(mmcif, chain_id) |
| |
|
| | input_sequence = mmcif.chain_to_seqres[chain_id] |
| | hits = self._parse_template_hits(alignment_dir, _alignment_index) |
| | template_features = make_template_features( |
| | input_sequence, |
| | hits, |
| | self.template_featurizer, |
| | query_release_date=to_date(mmcif.header["release_date"]) |
| | ) |
| | |
| | msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index) |
| |
|
| | return {**mmcif_feats, **template_features, **msa_features} |
| |
|
| | def process_pdb( |
| | self, |
| | pdb_path: str, |
| | alignment_dir: str, |
| | is_distillation: bool = True, |
| | chain_id: Optional[str] = None, |
| | _alignment_index: Optional[str] = None, |
| | ) -> FeatureDict: |
| | """ |
| | Assembles features for a protein in a PDB file. |
| | """ |
| | with open(pdb_path, 'r') as f: |
| | pdb_str = f.read() |
| |
|
| | protein_object = protein.from_pdb_string(pdb_str, chain_id) |
| | input_sequence = _aatype_to_str_sequence(protein_object.aatype) |
| | description = os.path.splitext(os.path.basename(pdb_path))[0].upper() |
| | pdb_feats = make_pdb_features( |
| | protein_object, |
| | description, |
| | is_distillation |
| | ) |
| |
|
| | hits = self._parse_template_hits(alignment_dir, _alignment_index) |
| | template_features = make_template_features( |
| | input_sequence, |
| | hits, |
| | self.template_featurizer, |
| | ) |
| |
|
| | msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index) |
| |
|
| | return {**pdb_feats, **template_features, **msa_features} |
| |
|
| | def process_core( |
| | self, |
| | core_path: str, |
| | alignment_dir: str, |
| | _alignment_index: Optional[str] = None, |
| | ) -> FeatureDict: |
| | """ |
| | Assembles features for a protein in a ProteinNet .core file. |
| | """ |
| | with open(core_path, 'r') as f: |
| | core_str = f.read() |
| |
|
| | protein_object = protein.from_proteinnet_string(core_str) |
| | input_sequence = _aatype_to_str_sequence(protein_object.aatype) |
| | description = os.path.splitext(os.path.basename(core_path))[0].upper() |
| | core_feats = make_protein_features(protein_object, description) |
| | |
| | hits = self._parse_template_hits(alignment_dir, _alignment_index) |
| | template_features = make_template_features( |
| | input_sequence, |
| | hits, |
| | self.template_featurizer, |
| | ) |
| |
|
| | msa_features = self._process_msa_feats(alignment_dir, input_sequence) |
| |
|
| | return {**core_feats, **template_features, **msa_features} |
| |
|
| |
|