# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import datetime
from multiprocessing import cpu_count
from typing import Mapping, Optional, Sequence, Any

import numpy as np
from openfold.data import templates, parsers, mmcif_parsing
from openfold.data.tools import jackhmmer, hhblits, hhsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein

FeatureDict = Mapping[str, np.ndarray]


def empty_template_feats(n_res: int) -> FeatureDict:
return {
"template_aatype": np.zeros((0, n_res)).astype(np.int64),
"template_all_atom_positions":
np.zeros((0, n_res, 37, 3)).astype(np.float32),
"template_sum_probs": np.zeros((0, 1)).astype(np.float32),
"template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32),
}
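
# Example: for a 128-residue query with no usable templates, every array is
# empty along the leading (template-count) dimension but keeps the
# per-residue atom37 shape expected downstream:
#
#   feats = empty_template_feats(128)
#   feats["template_aatype"].shape              # (0, 128)
#   feats["template_all_atom_positions"].shape  # (0, 128, 37, 3)
#   feats["template_sum_probs"].shape           # (0, 1)
#   feats["template_all_atom_mask"].shape       # (0, 128, 37)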


def make_template_features(
    input_sequence: str,
    hits: Mapping[str, Any],
template_featurizer: Any,
query_pdb_code: Optional[str] = None,
query_release_date: Optional[str] = None,
) -> FeatureDict:
hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0 or template_featurizer is None):
template_features = empty_template_feats(len(input_sequence))
else:
templates_result = template_featurizer.get_templates(
query_sequence=input_sequence,
query_pdb_code=query_pdb_code,
query_release_date=query_release_date,
hits=hits_cat,
)
template_features = templates_result.features
# The template featurizer doesn't format empty template features
# properly. This is a quick fix.
if(template_features["template_aatype"].shape[0] == 0):
template_features = empty_template_feats(len(input_sequence))
return template_features
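
# Sketch of the expected input: `hits` maps alignment filenames to lists of
# parsed template hits (see _parse_template_hits below). With no featurizer,
# or no hits at all, the fallback keeps downstream shapes consistent:
#
#   feats = make_template_features("MKV", {"pdb70_hits.hhr": []}, None)
#   feats["template_aatype"].shape  # (0, 3)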


def make_sequence_features(
sequence: str, description: str, num_res: int
) -> FeatureDict:
"""Construct a feature dict of sequence features."""
features = {}
features["aatype"] = residue_constants.sequence_to_onehot(
sequence=sequence,
mapping=residue_constants.restype_order_with_x,
map_unknown_to_x=True,
)
features["between_segment_residues"] = np.zeros((num_res,), dtype=int)
features["domain_name"] = np.array(
[description.encode("utf-8")], dtype=np.object_
)
features["residue_index"] = np.array(range(num_res), dtype=int)
features["seq_length"] = np.array([num_res] * num_res, dtype=int)
features["sequence"] = np.array(
[sequence.encode("utf-8")], dtype=np.object_
)
return features
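
# Example: for the tripeptide "MKV", `aatype` is a one-hot (num_res, 21)
# array over the 20 standard residues plus X, and per-sequence scalars are
# tiled across residues:
#
#   feats = make_sequence_features("MKV", "query", 3)
#   feats["aatype"].shape   # (3, 21)
#   feats["residue_index"]  # array([0, 1, 2])
#   feats["seq_length"]     # array([3, 3, 3])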


def make_mmcif_features(
mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
input_sequence = mmcif_object.chain_to_seqres[chain_id]
description = "_".join([mmcif_object.file_id, chain_id])
num_res = len(input_sequence)
mmcif_feats = {}
mmcif_feats.update(
make_sequence_features(
sequence=input_sequence,
description=description,
num_res=num_res,
)
)
all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=chain_id
)
mmcif_feats["all_atom_positions"] = all_atom_positions
mmcif_feats["all_atom_mask"] = all_atom_mask
mmcif_feats["resolution"] = np.array(
[mmcif_object.header["resolution"]], dtype=np.float32
)
mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
)
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
return mmcif_feats


def _aatype_to_str_sequence(aatype: np.ndarray) -> str:
return ''.join([
residue_constants.restypes_with_x[aatype[i]]
for i in range(len(aatype))
])
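
# Inverts residue_constants.restype_order_with_x, e.g. an aatype array of
# [12, 11, 19] decodes to "MKV" under the standard "ARNDCQEGHILKMFPSTWYV"
# (+ "X") ordering.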


def make_protein_features(
protein_object: protein.Protein,
description: str,
_is_distillation: bool = False,
) -> FeatureDict:
pdb_feats = {}
aatype = protein_object.aatype
sequence = _aatype_to_str_sequence(aatype)
pdb_feats.update(
make_sequence_features(
sequence=sequence,
description=description,
num_res=len(protein_object.aatype),
)
)
all_atom_positions = protein_object.atom_positions
all_atom_mask = protein_object.atom_mask
pdb_feats["all_atom_positions"] = all_atom_positions.astype(np.float32)
pdb_feats["all_atom_mask"] = all_atom_mask.astype(np.float32)
pdb_feats["resolution"] = np.array([0.]).astype(np.float32)
pdb_feats["is_distillation"] = np.array(
1. if _is_distillation else 0.
).astype(np.float32)
return pdb_feats


def make_pdb_features(
    protein_object: protein.Protein,
    description: str,
    confidence_threshold: float = 0.5,
    is_distillation: bool = True,
) -> FeatureDict:
    pdb_feats = make_protein_features(
        protein_object, description, _is_distillation=is_distillation
    )
if(is_distillation):
high_confidence = protein_object.b_factors > confidence_threshold
high_confidence = np.any(high_confidence, axis=-1)
for i, confident in enumerate(high_confidence):
if(not confident):
pdb_feats["all_atom_mask"][i] = 0
return pdb_feats
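
# Vectorized sketch of the masking step above, assuming b_factors hold
# per-atom confidence scores (e.g. pLDDT-like values): a residue survives
# only if at least one of its atoms clears confidence_threshold.
#
#   b_factors = protein_object.b_factors               # (num_res, 37)
#   keep = (b_factors > confidence_threshold).any(-1)  # (num_res,)
#   pdb_feats["all_atom_mask"][~keep] = 0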


def make_msa_features(
msas: Sequence[Sequence[str]],
deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
"""Constructs a feature dict of MSA features."""
if not msas:
raise ValueError("At least one MSA must be provided.")
int_msa = []
deletion_matrix = []
seen_sequences = set()
for msa_index, msa in enumerate(msas):
if not msa:
raise ValueError(
f"MSA {msa_index} must contain at least one sequence."
)
for sequence_index, sequence in enumerate(msa):
if sequence in seen_sequences:
continue
seen_sequences.add(sequence)
int_msa.append(
[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]
)
deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
num_res = len(msas[0][0])
num_alignments = len(int_msa)
features = {}
features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=int)
features["msa"] = np.array(int_msa, dtype=int)
features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=int
)
return features
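
# Example: duplicates are dropped across *all* MSAs, so a sequence repeated
# in a later alignment contributes only one row:
#
#   msas = [["MKV", "MRV"], ["MKV"]]
#   dels = [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0]]]
#   feats = make_msa_features(msas, dels)
#   feats["msa"].shape       # (2, 3) -- the second "MKV" is skipped
#   feats["num_alignments"]  # array([2, 2, 2])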


class AlignmentRunner:
    """Runs alignment tools and saves the results."""

    def __init__(
self,
jackhmmer_binary_path: Optional[str] = None,
hhblits_binary_path: Optional[str] = None,
hhsearch_binary_path: Optional[str] = None,
uniref90_database_path: Optional[str] = None,
mgnify_database_path: Optional[str] = None,
bfd_database_path: Optional[str] = None,
uniclust30_database_path: Optional[str] = None,
pdb70_database_path: Optional[str] = None,
use_small_bfd: Optional[bool] = None,
no_cpus: Optional[int] = None,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
):
"""
Args:
jackhmmer_binary_path:
Path to jackhmmer binary
hhblits_binary_path:
Path to hhblits binary
hhsearch_binary_path:
Path to hhsearch binary
uniref90_database_path:
Path to uniref90 database. If provided, jackhmmer_binary_path
must also be provided
mgnify_database_path:
Path to mgnify database. If provided, jackhmmer_binary_path
must also be provided
bfd_database_path:
Path to BFD database. Depending on the value of use_small_bfd,
one of hhblits_binary_path or jackhmmer_binary_path must be
provided.
uniclust30_database_path:
Path to uniclust30. Searched alongside BFD if use_small_bfd is
false.
pdb70_database_path:
Path to pdb70 database.
use_small_bfd:
Whether to search the BFD database alone with jackhmmer or
in conjunction with uniclust30 with hhblits.
no_cpus:
The number of CPUs available for alignment. By default, all
CPUs are used.
uniref_max_hits:
Max number of uniref hits
mgnify_max_hits:
Max number of mgnify hits
"""
db_map = {
"jackhmmer": {
"binary": jackhmmer_binary_path,
"dbs": [
uniref90_database_path,
mgnify_database_path,
bfd_database_path if use_small_bfd else None,
],
},
"hhblits": {
"binary": hhblits_binary_path,
"dbs": [
bfd_database_path if not use_small_bfd else None,
],
},
"hhsearch": {
"binary": hhsearch_binary_path,
"dbs": [
pdb70_database_path,
],
},
}
for name, dic in db_map.items():
binary, dbs = dic["binary"], dic["dbs"]
if(binary is None and not all([x is None for x in dbs])):
raise ValueError(
f"{name} DBs provided but {name} binary is None"
)
        if(not all([x is None for x in db_map["hhsearch"]["dbs"]])
            and uniref90_database_path is None):
            raise ValueError(
                "uniref90_database_path must be specified in order to "
                "perform template search"
            )
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
self.use_small_bfd = use_small_bfd
if(no_cpus is None):
no_cpus = cpu_count()
self.jackhmmer_uniref90_runner = None
if(jackhmmer_binary_path is not None and
uniref90_database_path is not None
):
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniref90_database_path,
n_cpu=no_cpus,
)
self.jackhmmer_small_bfd_runner = None
self.hhblits_bfd_uniclust_runner = None
if(bfd_database_path is not None):
if use_small_bfd:
self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=bfd_database_path,
n_cpu=no_cpus,
)
else:
dbs = [bfd_database_path]
if(uniclust30_database_path is not None):
dbs.append(uniclust30_database_path)
self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=dbs,
n_cpu=no_cpus,
)
self.jackhmmer_mgnify_runner = None
if(mgnify_database_path is not None):
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path,
n_cpu=no_cpus,
)
self.hhsearch_pdb70_runner = None
if(pdb70_database_path is not None):
self.hhsearch_pdb70_runner = hhsearch.HHSearch(
binary_path=hhsearch_binary_path,
databases=[pdb70_database_path],
n_cpu=no_cpus,
)

    def run(
        self,
        fasta_path: str,
        output_dir: str,
    ):
        """Runs alignment tools on a sequence."""
if(self.jackhmmer_uniref90_runner is not None):
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
fasta_path
)[0]
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_uniref90_result["sto"],
max_sequences=self.uniref_max_hits
)
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
with open(uniref90_out_path, "w") as f:
f.write(uniref90_msa_as_a3m)
if(self.hhsearch_pdb70_runner is not None):
hhsearch_result = self.hhsearch_pdb70_runner.query(
uniref90_msa_as_a3m
)
pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr")
with open(pdb70_out_path, "w") as f:
f.write(hhsearch_result)
if(self.jackhmmer_mgnify_runner is not None):
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
fasta_path
)[0]
mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_mgnify_result["sto"],
max_sequences=self.mgnify_max_hits
)
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
with open(mgnify_out_path, "w") as f:
f.write(mgnify_msa_as_a3m)
if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None):
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
fasta_path
)[0]
bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
with open(bfd_out_path, "w") as f:
f.write(jackhmmer_small_bfd_result["sto"])
elif(self.hhblits_bfd_uniclust_runner is not None):
hhblits_bfd_uniclust_result = (
self.hhblits_bfd_uniclust_runner.query(fasta_path)
)
if output_dir is not None:
bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
with open(bfd_out_path, "w") as f:
f.write(hhblits_bfd_uniclust_result["a3m"])
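
# Usage sketch (binary and database paths are placeholders; point them at
# your local installation). With only uniref90 + pdb70 configured, run()
# produces uniref90_hits.a3m and pdb70_hits.hhr in output_dir, which is the
# alignment_dir layout DataPipeline consumes below:
#
#   runner = AlignmentRunner(
#       jackhmmer_binary_path="/usr/bin/jackhmmer",
#       hhsearch_binary_path="/usr/bin/hhsearch",
#       uniref90_database_path="/data/uniref90/uniref90.fasta",
#       pdb70_database_path="/data/pdb70/pdb70",
#       no_cpus=8,
#   )
#   runner.run("query.fasta", "alignments/query/")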


class DataPipeline:
    """Assembles input features."""

    def __init__(
self,
template_featurizer: Optional[templates.TemplateHitFeaturizer],
):
self.template_featurizer = template_featurizer

    def _parse_msa_data(
self,
alignment_dir: str,
_alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]:
msa_data = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")
def read_msa(start, size):
fp.seek(start)
msa = fp.read(size).decode("utf-8")
return msa
for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".a3m"):
msa, deletion_matrix = parsers.parse_a3m(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
msa, deletion_matrix, _ = parsers.parse_stockholm(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else:
continue
msa_data[name] = data
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(ext == ".a3m"):
with open(path, "r") as fp:
msa, deletion_matrix = parsers.parse_a3m(fp.read())
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
with open(path, "r") as fp:
msa, deletion_matrix, _ = parsers.parse_stockholm(
fp.read()
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else:
continue
msa_data[f] = data
return msa_data

    def _parse_template_hits(
self,
alignment_dir: str,
_alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
all_hits = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
return all_hits

    def _process_msa_feats(
        self,
        alignment_dir: str,
        input_sequence: Optional[str] = None,
        _alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]:
msa_data = self._parse_msa_data(alignment_dir, _alignment_index)
        if(len(msa_data) == 0):
            if(input_sequence is None):
                raise ValueError(
                    "If the alignment dir contains no MSAs, an input "
                    "sequence must be provided."
                )
            msa_data["dummy"] = {
                "msa": [input_sequence],
                "deletion_matrix": [[0 for _ in input_sequence]],
            }
msas, deletion_matrices = zip(*[
(v["msa"], v["deletion_matrix"]) for v in msa_data.values()
])
msa_features = make_msa_features(
msas=msas,
deletion_matrices=deletion_matrices,
)
return msa_features

    def process_fasta(
        self,
        fasta_path: str,
        alignment_dir: str,
        _alignment_index: Optional[Any] = None,
    ) -> FeatureDict:
        """Assembles features for a single sequence in a FASTA file."""
with open(fasta_path) as f:
fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(fasta_str)
        if len(input_seqs) != 1:
            raise ValueError(
                f"Expected exactly one sequence in {fasta_path}, "
                f"found {len(input_seqs)}."
            )
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
sequence_features = make_sequence_features(
sequence=input_sequence,
description=input_description,
num_res=num_res,
)
        msa_features = self._process_msa_feats(
            alignment_dir, input_sequence, _alignment_index
        )
return {
**sequence_features,
**msa_features,
**template_features
}

    def process_mmcif(
        self,
        mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
        alignment_dir: str,
        chain_id: Optional[str] = None,
        _alignment_index: Optional[Any] = None,
    ) -> FeatureDict:
        """
        Assembles features for a specific chain in an mmCIF object.

        If chain_id is None, the first chain in the object is used; a
        ValueError is raised only if the object contains no chains.
        """
if chain_id is None:
chains = mmcif.structure.get_chains()
chain = next(chains, None)
if chain is None:
raise ValueError("No chains in mmCIF file")
chain_id = chain.id
mmcif_feats = make_mmcif_features(mmcif, chain_id)
input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
query_release_date=to_date(mmcif.header["release_date"])
)
        msa_features = self._process_msa_feats(
            alignment_dir, input_sequence, _alignment_index
        )
return {**mmcif_feats, **template_features, **msa_features}

    def process_pdb(
        self,
        pdb_path: str,
        alignment_dir: str,
        is_distillation: bool = True,
        chain_id: Optional[str] = None,
        _alignment_index: Optional[Any] = None,
    ) -> FeatureDict:
        """
        Assembles features for a protein in a PDB file.
        """
with open(pdb_path, 'r') as f:
pdb_str = f.read()
protein_object = protein.from_pdb_string(pdb_str, chain_id)
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
        pdb_feats = make_pdb_features(
            protein_object,
            description,
            is_distillation=is_distillation,
        )
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
        msa_features = self._process_msa_feats(
            alignment_dir, input_sequence, _alignment_index
        )
return {**pdb_feats, **template_features, **msa_features}

    def process_core(
        self,
        core_path: str,
        alignment_dir: str,
        _alignment_index: Optional[Any] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a ProteinNet .core file.
"""
with open(core_path, 'r') as f:
core_str = f.read()
protein_object = protein.from_proteinnet_string(core_str)
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
        msa_features = self._process_msa_feats(
            alignment_dir, input_sequence, _alignment_index
        )
return {**core_feats, **template_features, **msa_features}
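
# End-to-end sketch, assuming alignments were precomputed into alignment_dir
# (e.g. by AlignmentRunner above) and template featurization is skipped by
# passing template_featurizer=None:
#
#   pipeline = DataPipeline(template_featurizer=None)
#   feats = pipeline.process_fasta(
#       fasta_path="query.fasta",
#       alignment_dir="alignments/query/",
#   )
#   # feats merges sequence, MSA, and (empty) template features.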