| | """PDB data loader.""" |
| | import math |
| | import torch |
| | import tree |
| | import numpy as np |
| | import torch |
| | import pandas as pd |
| | import logging |
| | import os |
| | import random |
| | import esm |
| | import copy |
| |
|
| | from data import utils as du |
| | from data.repr import get_pre_repr |
| | from openfold.data import data_transforms |
| | from openfold.utils import rigid_utils |
| | from data.residue_constants import restype_atom37_mask, order2restype_with_mask |
| |
|
| | from pytorch_lightning import LightningDataModule |
| | from torch.utils.data import DataLoader, Dataset |
| | from torch.utils.data.distributed import DistributedSampler, dist |
| | from scipy.spatial.transform import Rotation as scipy_R |
| |
|
| |
|
| |
|
| |
|
class PdbDataModule(LightningDataModule):
    """Lightning data module wrapping the train/valid PDB chain datasets."""

    def __init__(self, data_cfg):
        """Store the config sections used to build datasets and loaders.

        Args:
            data_cfg: config object with `loader`, `dataset` and `sampler`
                sub-sections (attribute access, e.g. an OmegaConf node).
        """
        super().__init__()
        self.data_cfg = data_cfg
        self.loader_cfg = data_cfg.loader
        self.dataset_cfg = data_cfg.dataset
        self.sampler_cfg = data_cfg.sampler

    def setup(self, stage: str):
        """Instantiate train/valid datasets (called once per stage by Lightning)."""
        self._train_dataset = PdbDataset(
            dataset_cfg=self.dataset_cfg,
            is_training=True,
        )
        self._valid_dataset = PdbDataset(
            dataset_cfg=self.dataset_cfg,
            is_training=False,
        )

    def train_dataloader(self, rank=None, num_replicas=None):
        """Build the training dataloader with a shuffling DistributedSampler.

        `rank` / `num_replicas` are accepted for interface compatibility but
        are currently unused: DistributedSampler infers both from the default
        process group.
        """
        num_workers = self.loader_cfg.num_workers
        return DataLoader(
            self._train_dataset,
            sampler=DistributedSampler(self._train_dataset, shuffle=True),
            num_workers=num_workers,
            # prefetch_factor / persistent_workers are only valid with
            # worker processes; both must be disabled when num_workers == 0.
            prefetch_factor=None if num_workers == 0 else self.loader_cfg.prefetch_factor,
            persistent_workers=num_workers > 0,
        )

    def val_dataloader(self):
        """Build the validation dataloader (deterministic order, no shuffle)."""
        num_workers = self.loader_cfg.num_workers
        return DataLoader(
            self._valid_dataset,
            sampler=DistributedSampler(self._valid_dataset, shuffle=False),
            num_workers=num_workers,
            prefetch_factor=None if num_workers == 0 else self.loader_cfg.prefetch_factor,
            # BUG FIX: was unconditionally True, which makes DataLoader raise
            # ValueError whenever num_workers == 0. Gate it exactly like the
            # training loader does.
            persistent_workers=num_workers > 0,
        )
| |
|
| |
|
class PdbDataset(Dataset):
    """Dataset of preprocessed PDB chains described by a metadata CSV.

    Each row of the CSV points at a pickled feature dict (`processed_path`)
    and carries per-chain metadata (`modeled_seq_len`, `energy`, ...).
    """

    def __init__(
        self,
        *,
        dataset_cfg,
        is_training,
    ):
        """
        Args:
            dataset_cfg: dataset section of the config. Fields read here and
                in `_init_metadata`: csv_path, max_num_res, min_num_res,
                subset, split_frac, seed, max_eval_length, max_valid_num
                (plus use_split / split_len / use_rotate_enhance in
                `__getitem__`).
            is_training: True selects the training split, False validation.
        """
        self._log = logging.getLogger(__name__)
        self._is_training = is_training
        self._dataset_cfg = dataset_cfg
        self.split_frac = self._dataset_cfg.split_frac
        self.random_seed = self._dataset_cfg.seed
        self._init_metadata()

    @property
    def is_training(self):
        return self._is_training

    @property
    def dataset_cfg(self):
        return self._dataset_cfg

    def _init_metadata(self):
        """Load the metadata CSV, filter by length, and build this split.

        Side effects: writes the realized split back out as `train.csv` /
        `valid.csv` next to the source CSV.
        """
        pdb_csv = pd.read_csv(self.dataset_cfg.csv_path)
        self.raw_csv = pdb_csv
        # Keep only chains inside the configured modeled-length window.
        pdb_csv = pdb_csv[pdb_csv.modeled_seq_len <= self.dataset_cfg.max_num_res]
        pdb_csv = pdb_csv[pdb_csv.modeled_seq_len >= self.dataset_cfg.min_num_res]

        if self.dataset_cfg.subset is not None:
            pdb_csv = pdb_csv.iloc[:self.dataset_cfg.subset]
        pdb_csv = pdb_csv.sort_values('modeled_seq_len', ascending=False)

        csv_dir = os.path.dirname(self.dataset_cfg.csv_path)
        if self.is_training:
            # NOTE(review): an earlier `is_trainset` column filter here was
            # dead code (its result was immediately overwritten); the
            # effective training split is a seeded random fraction of all
            # length-filtered rows, exactly as before. The dead assignment
            # has been removed.
            self.csv = pdb_csv.sample(
                frac=self.split_frac, random_state=self.random_seed
            ).reset_index()
            self.csv.to_csv(os.path.join(csv_dir, "train.csv"), index=False)
            self._log.info(
                f"Training: {len(self.csv)} examples, len_range is {self.csv['modeled_seq_len'].min()}-{self.csv['modeled_seq_len'].max()}")
        else:
            # NOTE(review): likewise, the `~is_trainset` filter was dead code;
            # validation draws from all rows no longer than max_eval_length,
            # subsampled to at most max_valid_num examples.
            self.csv = pdb_csv[pdb_csv.modeled_seq_len <= self.dataset_cfg.max_eval_length]
            self.csv.to_csv(os.path.join(csv_dir, "valid.csv"), index=False)
            self.csv = self.csv.sample(
                n=min(self.dataset_cfg.max_valid_num, len(self.csv)),
                random_state=self.random_seed,
            ).reset_index()
            self._log.info(
                f"Valid: {len(self.csv)} examples, len_range is {self.csv['modeled_seq_len'].min()}-{self.csv['modeled_seq_len'].max()}")

    def __len__(self):
        return len(self.csv)

    def __getitem__(self, idx):
        """Load one chain's feature dict; optionally crop and rotate it, then
        recompute the backbone rigid frames from the (transformed) atoms.
        """
        row = self.csv.iloc[idx]
        chain_feats = du.read_pkl(row['processed_path'])
        chain_feats['energy'] = torch.tensor(row['energy'], dtype=torch.float32)
        energy = chain_feats['energy']

        if self.is_training and self._dataset_cfg.use_split:
            # Randomly crop a contiguous segment of the chain for training.
            num_res = chain_feats['aatype'].shape[0]
            split_len = random.randint(
                self.dataset_cfg.min_num_res,
                min(self._dataset_cfg.split_len, num_res))
            # Crop start index (was shadowing the `idx` parameter; renamed).
            start = random.randint(0, num_res - split_len)
            output_total = copy.deepcopy(chain_feats)
            # Temporarily replace the scalar energy with a per-residue
            # placeholder so the uniform slicing below is shape-safe; the
            # true scalar energy is restored right after the crop.
            output_total['energy'] = torch.ones(chain_feats['aatype'].shape)
            output_temp = tree.map_structure(
                lambda x: x[start:start + split_len], output_total)

            # Re-center the cropped coordinates on the backbone centroid.
            bb_center = np.sum(output_temp['bb_positions'], axis=0) / (
                np.sum(output_temp['res_mask'].numpy()) + 1e-5)
            output_temp['trans_1'] = (
                output_temp['trans_1'] - torch.from_numpy(bb_center[None, :])).float()
            output_temp['bb_positions'] = output_temp['bb_positions'] - bb_center[None, :]
            output_temp['all_atom_positions'] = (
                output_temp['all_atom_positions']
                - torch.from_numpy(bb_center[None, None, :]))
            # The pairwise representation must be cropped on its second
            # residue axis as well (map_structure only sliced the first).
            output_temp['pair_repr_pre'] = output_temp['pair_repr_pre'][:, start:start + split_len]

            # Same centering for the ESMFold-predicted translations.
            bb_center_esmfold = torch.sum(output_temp['trans_esmfold'], dim=0) / (
                np.sum(output_temp['res_mask'].numpy()) + 1e-5)
            output_temp['trans_esmfold'] = (
                output_temp['trans_esmfold'] - bb_center_esmfold[None, :]).float()

            chain_feats = output_temp
            chain_feats['energy'] = energy

        if self._dataset_cfg.use_rotate_enhance:
            # Random-rotation augmentation of all atom coordinates.
            # NOTE(review): rotvec components are drawn from [0, 1), so the
            # rotations are not uniform over SO(3) — confirm this is intended.
            rot_vet = [random.random() for _ in range(3)]
            rot_mat = torch.tensor(scipy_R.from_rotvec(rot_vet).as_matrix())
            chain_feats['all_atom_positions'] = torch.einsum(
                'lij,kj->lik', chain_feats['all_atom_positions'],
                rot_mat.type(chain_feats['all_atom_positions'].dtype))

        # Recompute rigid backbone frames from the (possibly cropped and
        # rotated) atom37 coordinates so trans/rot stay consistent with them.
        all_atom_mask = np.array([restype_atom37_mask[i] for i in chain_feats['aatype']])
        chain_feats_temp = {
            'aatype': chain_feats['aatype'],
            'all_atom_positions': chain_feats['all_atom_positions'],
            'all_atom_mask': torch.tensor(all_atom_mask).double(),
        }
        chain_feats_temp = data_transforms.atom37_to_frames(chain_feats_temp)
        # [:, 0] selects the backbone frame from the rigid groups.
        curr_rigid = rigid_utils.Rigid.from_tensor_4x4(
            chain_feats_temp['rigidgroups_gt_frames'])[:, 0]
        chain_feats['trans_1'] = curr_rigid.get_trans()
        chain_feats['rotmats_1'] = curr_rigid.get_rots().get_rot_mats()
        chain_feats['bb_positions'] = chain_feats['trans_1'].numpy().astype(
            chain_feats['bb_positions'].dtype)
        return chain_feats
| |
|