| | """Protein dataset class.""" |
| | import os |
| | import pickle |
| | from pathlib import Path |
| | from glob import glob |
| | from typing import Optional, Sequence, List, Union |
| | from functools import lru_cache |
| | import tree |
| |
|
| | from tqdm import tqdm |
| | import numpy as np |
| | import pandas as pd |
| | import torch |
| |
|
| | from src.common import residue_constants, data_transforms, rigid_utils, protein |
| |
|
| |
|
# Position of the 'CA' atom in `residue_constants.atom_order`; used below to
# pick the CA column out of per-atom arrays.
CA_IDX = residue_constants.atom_order['CA']

# Target torch dtype for each feature key that `map_to_tensors` casts
# explicitly; keys not listed here keep whatever dtype `torch.as_tensor` infers.
DTYPE_MAPPING = dict(
    aatype=torch.long,
    atom_positions=torch.double,
    atom_mask=torch.double,
)
| |
|
| |
|
class ProteinFeatureTransform:
    """Turn a raw per-chain feature dict (numpy arrays) into tensor features.

    The pipeline applied by ``__call__`` is, in order: add derived masks and
    indices, optionally strip unmodelled ends, optionally crop to a random
    window, optionally recenter/scale coordinates, convert to torch tensors,
    and finally run the ``data_transforms`` feature builders.

    Args:
        unit: Coordinate unit of the output. ``'angstrom'`` keeps the scale
            (factor 1.0); ``'nm'`` / ``'nanometer'`` multiplies by 0.1.
            The factor is only applied when ``recenter_and_scale`` is True.
        truncate_length: If given, chains longer than this are cropped to a
            random contiguous window of this length. Must be positive.
        strip_missing_residues: Whether to trim leading/trailing residues
            whose ``aatype`` is 20.
        recenter_and_scale: Whether to center coordinates on the CA centroid
            and apply the unit scale.
        eps: Small constant added to the centroid denominator to avoid
            division by zero.

    Raises:
        ValueError: If ``unit`` is not one of the supported names.
    """

    def __init__(self,
                 unit: Optional[str] = 'angstrom',
                 truncate_length: Optional[int] = None,
                 strip_missing_residues: bool = True,
                 recenter_and_scale: bool = True,
                 eps: float = 1e-8,
    ):
        if unit == 'angstrom':
            self.coordinate_scale = 1.0
        elif unit in ('nm', 'nanometer'):
            # Previously assigned to a misspelled attribute, which left
            # `coordinate_scale` undefined for nanometer units.
            self.coordinate_scale = 0.1
        else:
            raise ValueError(f"Invalid unit: {unit}")

        if truncate_length is not None:
            assert truncate_length > 0, f"Invalid truncate_length: {truncate_length}"
        self.truncate_length = truncate_length

        self.strip_missing_residues = strip_missing_residues
        self.recenter_and_scale = recenter_and_scale
        self.eps = eps

    def __call__(self, chain_feats):
        """Apply the full transform pipeline to ``chain_feats`` and return the result.

        Note that ``chain_feats`` is updated in place by the first step.
        """
        chain_feats = self.patch_feats(chain_feats)

        if self.strip_missing_residues:
            chain_feats = self.strip_ends(chain_feats)

        if self.truncate_length is not None:
            chain_feats = self.random_truncate(chain_feats, max_len=self.truncate_length)

        # Recenter and scale atom positions (still numpy at this point).
        if self.recenter_and_scale:
            chain_feats = self.recenter_and_scale_coords(
                chain_feats, coordinate_scale=self.coordinate_scale, eps=self.eps)

        # numpy -> torch, with explicit dtypes for the keys in DTYPE_MAPPING.
        chain_feats = self.map_to_tensors(chain_feats)

        # Derived structural features (frames, torsions, atom14, ...).
        chain_feats = self.protein_data_transform(chain_feats)

        return chain_feats

    @staticmethod
    def patch_feats(chain_feats):
        """Add derived per-residue features to ``chain_feats`` (in place).

        Adds ``seq_mask`` and ``residue_mask`` (both the CA column of
        ``atom_mask``), ``residue_idx`` (``residue_index`` shifted so its
        minimum is 0), ``fixed_mask`` (zeros shaped like ``seq_mask``) and
        ``sc_ca_t`` (zeros of shape ``seq_mask.shape + (3,)``).
        """
        seq_mask = chain_feats['atom_mask'][:, CA_IDX]
        # Re-number residues so that the smallest index becomes 0.
        residue_idx = chain_feats['residue_index'] - np.min(chain_feats['residue_index'])
        patch_feats = {
            'seq_mask': seq_mask,
            'residue_mask': seq_mask,
            'residue_idx': residue_idx,
            'fixed_mask': np.zeros_like(seq_mask),
            'sc_ca_t': np.zeros(seq_mask.shape + (3, )),
        }
        chain_feats.update(patch_feats)
        return chain_feats

    @staticmethod
    def strip_ends(chain_feats):
        """Trim leading and trailing residues whose ``aatype`` equals 20.

        Only the ends are removed; entries equal to 20 between the first and
        last other residue are kept.

        Raises:
            ValueError: If every residue has ``aatype == 20``.
        """
        # NOTE(review): 20 is presumably the "unknown residue" index -- confirm
        # against residue_constants.
        modeled_idx = np.where(chain_feats['aatype'] != 20)[0]
        if modeled_idx.size == 0:
            # np.min on an empty array would raise an opaque ValueError.
            raise ValueError("Cannot strip ends: no residue with aatype != 20 in chain")
        min_idx, max_idx = np.min(modeled_idx), np.max(modeled_idx)
        chain_feats = tree.map_structure(
            lambda x: x[min_idx : (max_idx+1)], chain_feats)
        return chain_feats

    @staticmethod
    def random_truncate(chain_feats, max_len):
        """Crop every feature to a random contiguous window of ``max_len`` residues.

        Chains whose length (first dimension of ``aatype``) is at most
        ``max_len`` are returned unchanged.
        """
        L = chain_feats['aatype'].shape[0]
        if L > max_len:
            # randint's upper bound is exclusive, so start is in [0, L - max_len].
            start = np.random.randint(0, L - max_len + 1)
            end = start + max_len
            chain_feats = tree.map_structure(
                lambda x: x[start : end], chain_feats)
        return chain_feats

    @staticmethod
    def map_to_tensors(chain_feats):
        """Convert every value to a torch tensor and cast the keys in ``DTYPE_MAPPING``."""
        chain_feats = {k: torch.as_tensor(v) for k, v in chain_feats.items()}
        # Enforce the configured dtype for the keys that need a specific one.
        for k, dtype in DTYPE_MAPPING.items():
            if k in chain_feats:
                chain_feats[k] = chain_feats[k].type(dtype)
        return chain_feats

    @staticmethod
    def recenter_and_scale_coords(chain_feats, coordinate_scale, eps=1e-8):
        """Center ``atom_positions`` on the masked CA centroid, scale, and re-mask.

        Args:
            chain_feats: Feature dict with ``atom_positions``, ``atom_mask``
                and ``seq_mask`` entries.
            coordinate_scale: Factor applied to the centered coordinates.
            eps: Added to the residue count to avoid division by zero.

        Returns:
            The same dict with ``atom_positions`` replaced; positions of atoms
            with a zero ``atom_mask`` are zeroed.
        """
        bb_pos = chain_feats['atom_positions'][:, CA_IDX]
        ca_mask = chain_feats['seq_mask']
        # The denominator counts masked-in residues only, so the numerator must
        # be restricted to the same residues; otherwise stray coordinates of
        # masked-out residues would bias the centroid.
        bb_center = np.sum(bb_pos * ca_mask[:, None], axis=0) / (np.sum(ca_mask) + eps)
        centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
        scaled_pos = centered_pos * coordinate_scale
        chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
        return chain_feats

    @staticmethod
    def protein_data_transform(chain_feats):
        """Run the ``data_transforms`` feature builders on tensor features.

        ``atom_positions`` / ``atom_mask`` are temporarily exposed under the
        ``all_atom_*`` names for the duration of the transforms and those
        aliases are removed again before returning.
        """
        chain_feats.update(
            {
                "all_atom_positions": chain_feats["atom_positions"],
                "all_atom_mask": chain_feats["atom_mask"],
            }
        )
        chain_feats = data_transforms.atom37_to_frames(chain_feats)
        # The "" argument is presumably a key prefix for the curried
        # transform -- confirm against data_transforms.
        chain_feats = data_transforms.atom37_to_torsion_angles("")(chain_feats)
        chain_feats = data_transforms.get_backbone_frames(chain_feats)
        chain_feats = data_transforms.get_chi_angles(chain_feats)
        chain_feats = data_transforms.make_pseudo_beta("")(chain_feats)
        chain_feats = data_transforms.make_atom14_masks(chain_feats)
        chain_feats = data_transforms.make_atom14_positions(chain_feats)

        # Drop the temporary aliases added above.
        chain_feats.pop("all_atom_positions")
        chain_feats.pop("all_atom_mask")
        return chain_feats
| | |
| |
|
class MetadataFilter:
    """Row filter for a dataset metadata table (a ``pandas.DataFrame``).

    Every criterion is optional; a criterion left as ``None`` is skipped and
    its column is never accessed.

    Args:
        min_len / max_len: Inclusive bounds on the ``raw_seq_len`` column.
        min_chains / max_chains: Inclusive bounds on the ``num_chains`` column.
        min_resolution / max_resolution: Inclusive bounds on the
            ``resolution`` column.
        include_structure_method: Allowed values for the
            ``include_structure_method`` column.
        include_oligomeric_detail: Allowed values for the
            ``include_oligomeric_detail`` column.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 min_len: Optional[int] = None,
                 max_len: Optional[int] = None,
                 min_chains: Optional[int] = None,
                 max_chains: Optional[int] = None,
                 min_resolution: Optional[float] = None,
                 max_resolution: Optional[float] = None,
                 include_structure_method: Optional[List[str]] = None,
                 include_oligomeric_detail: Optional[List[str]] = None,
                 **kwargs,
    ):
        self.min_len = min_len
        self.max_len = max_len
        self.min_chains = min_chains
        self.max_chains = max_chains
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.include_structure_method = include_structure_method
        self.include_oligomeric_detail = include_oligomeric_detail

    def __call__(self, df):
        """Return the rows of ``df`` that satisfy every configured criterion.

        Prints a one-line summary of how many rows were kept and removed.
        """
        _pre_filter_len = len(df)

        # (column, inclusive lower bound, inclusive upper bound)
        range_criteria = (
            ('raw_seq_len', self.min_len, self.max_len),
            ('num_chains', self.min_chains, self.max_chains),
            ('resolution', self.min_resolution, self.max_resolution),
        )
        for column, lower, upper in range_criteria:
            if lower is not None:
                df = df[df[column] >= lower]
            if upper is not None:
                df = df[df[column] <= upper]

        # NOTE(review): these column names carry the `include_` prefix of the
        # constructor arguments; confirm the metadata CSV really names its
        # columns this way.
        if self.include_structure_method is not None:
            df = df[df['include_structure_method'].isin(self.include_structure_method)]
        if self.include_oligomeric_detail is not None:
            df = df[df['include_oligomeric_detail'].isin(self.include_oligomeric_detail)]

        # len(df) is the number of rows that survived, not the number removed,
        # so report both explicitly.
        n_kept = len(df)
        print(f">>> Metadata filter kept {n_kept} of {_pre_filter_len} samples "
              f"({_pre_filter_len - n_kept} filtered out)")
        return df
| |
|
| |
|
class RandomAccessProteinDataset(torch.utils.data.Dataset):
    """Random access to pickle protein objects of dataset.

    Each sample is read from a ``.pkl`` or ``.pdb`` file and returned as a
    dict. Pickled objects are expected to look like

    dict_keys(['atom_positions', 'aatype', 'atom_mask', 'residue_index', 'chain_index', 'b_factors'])

    Note that each value is a ndarray in shape (L, *), for example:
        'atom_positions': (L, 37, 3)

    Args:
        path_to_dataset: One of
            * a ``.csv`` metadata file (sorted by ``modeled_seq_len``
              descending, optionally filtered, sample paths taken from the
              ``processed_complex_path`` column),
            * a directory (all ``*<suffix>`` files in it), or
            * a glob pattern.
        path_to_seq_embedding: Optional directory holding
            ``<accession_code>.pt`` sequence-embedding files.
        metadata_filter: Optional callable applied to the metadata table
            (CSV input only).
        training: Stored on the instance as ``self.training``.
        transform: Optional callable applied to every loaded sample.
        suffix: ``'.pkl'`` or ``'.pdb'`` (leading dot optional).
        accession_code_fillter: If non-empty, only files whose basename
            (without extension) is in this collection are kept.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 path_to_dataset: Union[Path, str],
                 path_to_seq_embedding: Optional[Path] = None,
                 metadata_filter: Optional["MetadataFilter"] = None,
                 training: bool = True,
                 transform: Optional["ProteinFeatureTransform"] = None,
                 suffix: Optional[str] = '.pkl',
                 accession_code_fillter: Optional[Sequence[str]] = None,
                 **kwargs,
    ):
        super().__init__()
        path_to_dataset = os.path.expanduser(path_to_dataset)
        suffix = suffix if suffix.startswith('.') else '.' + suffix
        assert suffix in ('.pkl', '.pdb'), f"Invalid suffix: {suffix}"

        if os.path.isfile(path_to_dataset):  # path to csv file
            assert path_to_dataset.endswith('.csv'), f"Invalid file extension: {path_to_dataset} (have to be .csv)"
            self._df = pd.read_csv(path_to_dataset)
            # sort_values returns a new frame; the result must be assigned,
            # otherwise the sort has no effect.
            self._df = self._df.sort_values('modeled_seq_len', ascending=False)
            if metadata_filter:
                self._df = metadata_filter(self._df)
            self._data = self._df['processed_complex_path'].tolist()
        elif os.path.isdir(path_to_dataset):  # path to directory
            self._data = sorted(glob(os.path.join(path_to_dataset, '*' + suffix)))
            assert len(self._data) > 0, f"No {suffix} file found in '{path_to_dataset}'"
        else:  # path as glob pattern
            _pattern = path_to_dataset
            self._data = sorted(glob(_pattern))
            assert len(self._data) > 0, f"No files found in '{_pattern}'"

        # `is not None` (rather than truthiness) so numpy arrays are accepted.
        if accession_code_fillter is not None and len(accession_code_fillter) > 0:
            if isinstance(accession_code_fillter, str):
                # A bare string is one code, not a collection of characters.
                allowed_codes = {accession_code_fillter}
            else:
                allowed_codes = set(accession_code_fillter)
            self._data = [
                p for p in self._data
                if os.path.splitext(os.path.basename(p))[0] in allowed_codes
            ]

        self.data = np.asarray(self._data)
        self.path_to_seq_embedding = os.path.expanduser(path_to_seq_embedding) \
            if path_to_seq_embedding is not None else None
        self.suffix = suffix
        self.transform = transform
        self.training = training  # not used inside this class

    @property
    def num_samples(self):
        """Number of sample files in the dataset."""
        return len(self.data)

    def len(self):
        """Alias of ``__len__``."""
        return self.__len__()

    def __len__(self):
        return self.num_samples

    def get(self, idx):
        """Alias of ``__getitem__``."""
        return self.__getitem__(idx)

    @staticmethod
    @lru_cache(maxsize=100)
    def _load_raw(data_path, suffix):
        """Read one sample file from disk and return the untransformed dict.

        Cached on ``(data_path, suffix)`` only, so the cache neither holds
        dataset instances alive nor stores transform output. The returned
        object is shared between callers and must not be mutated.
        """
        if suffix == '.pkl':
            # NOTE: pickle.load executes arbitrary code from the file; only
            # use dataset files from a trusted source.
            with open(data_path, 'rb') as f:
                return pickle.load(f)
        if suffix == '.pdb':
            with open(data_path, 'r') as f:
                pdb_string = f.read()
            return protein.from_pdb_string(pdb_string).to_dict()
        raise ValueError(f"Invalid suffix: {suffix}")

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        The raw file content is cached, but the transform is re-applied on
        every access (so random cropping stays random) and every call returns
        a fresh dict with its own array copies. Sequence embeddings are read
        from disk on every access. The result always contains an
        ``'accession_code'`` entry (file basename without extension) and, when
        ``path_to_seq_embedding`` is set, a ``'seq_emb'`` entry.
        """
        data_path = str(self.data[idx])
        accession_code = os.path.splitext(os.path.basename(data_path))[0]

        raw_object = self._load_raw(data_path, self.suffix)
        # Copy the arrays so neither the transform nor the caller can modify
        # the cached raw sample.
        data_object = {
            key: value.copy() if isinstance(value, np.ndarray) else value
            for key, value in raw_object.items()
        }

        # Apply data transform.
        if self.transform is not None:
            data_object = self.transform(data_object)

        # Attach the pre-computed sequence embedding, if configured.
        if self.path_to_seq_embedding is not None:
            embed_dict = torch.load(
                os.path.join(self.path_to_seq_embedding, f"{accession_code}.pt")
            )
            data_object.update(
                {
                    # NOTE(review): 33 is presumably the layer index of the
                    # embedding model's final representation -- confirm.
                    'seq_emb': embed_dict['representations'][33].float(),
                }
            )

        data_object['accession_code'] = accession_code
        return data_object
| |
|
| | |
| |
|
class PretrainPDBDataset(RandomAccessProteinDataset):
    """``RandomAccessProteinDataset`` whose metadata filter and transform are required.

    All arguments, including any extra keyword arguments, are forwarded
    unchanged to the base-class constructor.
    """

    def __init__(self,
                 path_to_dataset: str,
                 metadata_filter: MetadataFilter,
                 transform: ProteinFeatureTransform,
                 **kwargs,
    ):
        super().__init__(
            path_to_dataset=path_to_dataset,
            metadata_filter=metadata_filter,
            transform=transform,
            **kwargs,
        )
| |
|
| |
|
class SamplingPDBDataset(RandomAccessProteinDataset):
    """Directory-backed dataset with sampling-oriented defaults.

    Defaults to ``training=False`` and ``suffix='.pdb'``; no metadata filter
    is used.

    Args:
        path_to_dataset: Directory containing the structure files. A leading
            ``~`` is expanded before the directory check.
        training: Forwarded to the base class.
        suffix: File suffix to collect, forwarded to the base class.
        transform: Optional callable applied to every loaded sample.
        accession_code_fillter: Optional collection of accession codes to keep.

    Raises:
        AssertionError: If ``path_to_dataset`` is not an existing directory.
    """

    def __init__(self,
                 path_to_dataset: str,
                 training: bool = False,
                 suffix: str = '.pdb',
                 transform: Optional[ProteinFeatureTransform] = None,
                 accession_code_fillter: Optional[Sequence[str]] = None,
    ):
        # Expand `~` before validating: the base class accepts such paths, so
        # the directory check must see the expanded form too.
        path_to_dataset = os.path.expanduser(path_to_dataset)
        assert os.path.isdir(path_to_dataset), f"Invalid path (expected to be directory): {path_to_dataset}"
        super().__init__(
            path_to_dataset=path_to_dataset,
            training=training,
            suffix=suffix,
            transform=transform,
            accession_code_fillter=accession_code_fillter,
            metadata_filter=None,
        )
| | |