import argparse
import dataclasses
import functools as fn
import multiprocessing as mp
import os
import time

import esm
import numpy as np
import pandas as pd
import torch
import tree
from Bio import PDB

from data import errors, parsers
from data import utils as du
from data.cal_trans_rotmats import cal_trans_rotmats
from data.ESMfold_pred import ESMFold_Pred
from data.repr import get_pre_repr
from openfold.data import data_transforms
from openfold.utils import rigid_utils
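
# This script runs a four-stage pipeline over a directory of PDB files:
#   1. process_file/main: parse raw PDBs into per-complex feature pickles.
#   2. cal_repr: enrich each pickle with ESM2 node/pair representations
#      and backbone rigid-frame targets.
#   3. cal_static_structure: add ESMFold-predicted backbone frames.
#   4. merge_pdb: join metadata with trajectory info and the train/valid split.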


def process_file(file_path: str, write_dir: str):
    """Processes a protein PDB file into a smaller, model-ready pickle.

    Args:
        file_path: Path to the PDB file to read.
        write_dir: Directory to write the pickle to.

    Returns:
        Metadata dict for the processed protein; the features themselves
        are saved to a pickle under write_dir.

    Raises:
        DataError if a known filtering rule is hit.
        All other errors are unexpected and are propagated.
    """
    metadata = {}
    pdb_name = os.path.basename(file_path).replace('.pdb', '')
    metadata['pdb_name'] = pdb_name

    processed_path = os.path.join(write_dir, f'{pdb_name}.pkl')
    metadata['processed_path'] = os.path.abspath(processed_path)
    metadata['raw_path'] = file_path
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_name, file_path)
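
    # Index chains by upper-cased chain ID; multimeric complexes are
    # featurized chain-by-chain below.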
    struct_chains = {
        chain.id.upper(): chain
        for chain in structure.get_chains()}
    metadata['num_chains'] = len(struct_chains)
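
    # Featurize each chain and classify the complex as homomer vs. heteromer
    # by the number of unique sequences.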
    struct_feats = []
    all_seqs = set()
    for chain_id, chain in struct_chains.items():
        chain_id = du.chain_str_to_int(chain_id)
        chain_prot = parsers.process_chain(chain, chain_id)
        chain_dict = dataclasses.asdict(chain_prot)
        chain_dict = du.parse_chain_feats(chain_dict)
        all_seqs.add(tuple(chain_dict['aatype']))
        struct_feats.append(chain_dict)
    if len(all_seqs) == 1:
        metadata['quaternary_category'] = 'homomer'
    else:
        metadata['quaternary_category'] = 'heteromer'
    complex_feats = du.concat_np_features(struct_feats, False)
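
    # Record sequence length and the span of modeled residues
    # (aatype 20 is the unknown-residue token).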
    complex_aatype = complex_feats['aatype']
    metadata['seq_len'] = len(complex_aatype)
    modeled_idx = np.where(complex_aatype != 20)[0]
    if modeled_idx.size == 0:
        raise errors.LengthError('No modeled residues')
    min_modeled_idx = np.min(modeled_idx)
    max_modeled_idx = np.max(modeled_idx)
    metadata['modeled_seq_len'] = max_modeled_idx - min_modeled_idx + 1
    complex_feats['modeled_idx'] = modeled_idx

    du.write_pkl(processed_path, complex_feats)

    return metadata


def process_serially(all_paths, write_dir):
    all_metadata = []
    for file_path in all_paths:
        try:
            start_time = time.time()
            metadata = process_file(file_path, write_dir)
            elapsed_time = time.time() - start_time
            print(f'Finished {file_path} in {elapsed_time:2.2f}s')
            all_metadata.append(metadata)
        except errors.DataError as e:
            print(f'Failed {file_path}: {e}')
    return all_metadata


def process_fn(file_path, verbose=None, write_dir=None):
    """Multiprocessing wrapper around process_file; returns None on DataError."""
    try:
        start_time = time.time()
        metadata = process_file(file_path, write_dir)
        elapsed_time = time.time() - start_time
        if verbose:
            print(f'Finished {file_path} in {elapsed_time:2.2f}s')
        return metadata
    except errors.DataError as e:
        if verbose:
            print(f'Failed {file_path}: {e}')


def main(args):
    pdb_dir = args.pdb_dir
    all_file_paths = [
        os.path.join(pdb_dir, x)
        for x in os.listdir(pdb_dir) if x.endswith('.pdb')]
    total_num_paths = len(all_file_paths)
    write_dir = args.write_dir
    os.makedirs(write_dir, exist_ok=True)
    if args.debug:
        metadata_file_name = 'metadata_debug.csv'
    else:
        metadata_file_name = 'metadata.csv'
    metadata_path = os.path.join(write_dir, metadata_file_name)
    print(f'Files will be written to {write_dir}')

    if args.num_processes == 1 or args.debug:
        all_metadata = process_serially(all_file_paths, write_dir)
    else:
        _process_fn = fn.partial(
            process_fn,
            verbose=args.verbose,
            write_dir=write_dir)
        with mp.Pool(processes=args.num_processes) as pool:
            all_metadata = pool.map(_process_fn, all_file_paths)
        # Drop entries for files that failed with a DataError.
        all_metadata = [x for x in all_metadata if x is not None]
    metadata_df = pd.DataFrame(all_metadata)
    metadata_df.to_csv(metadata_path, index=False)
    succeeded = len(all_metadata)
    print(f'Finished processing {succeeded}/{total_num_paths} files')
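

# Note: each pickle is rewritten in place as it moves through the pipeline.
# cal_repr replaces the raw features with model-ready tensors, and
# cal_static_structure then appends the ESMFold-derived frames.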


def cal_repr(processed_file_path, model_esm2, alphabet, batch_converter, esm_device):
    """Adds ESM2 representations and rigid-frame targets to a processed pickle."""
    print(f'cal_repr for {processed_file_path}')
    processed_feats_org = du.read_pkl(processed_file_path)
    processed_feats = du.parse_chain_feats(processed_feats_org)

    # Crop all features to the modeled region.
    modeled_idx = processed_feats['modeled_idx']
    min_idx = np.min(modeled_idx)
    max_idx = np.max(modeled_idx)
    processed_feats = tree.map_structure(
        lambda x: x[min_idx:(max_idx + 1)], processed_feats)

    # Convert atom37 coordinates into per-residue backbone rigid frames.
    chain_feats = {
        'aatype': torch.tensor(processed_feats['aatype']).long(),
        'all_atom_positions': torch.tensor(processed_feats['atom_positions']).double(),
        'all_atom_mask': torch.tensor(processed_feats['atom_mask']).double(),
    }
    chain_feats = data_transforms.atom37_to_frames(chain_feats)
    rigids_1 = rigid_utils.Rigid.from_tensor_4x4(
        chain_feats['rigidgroups_gt_frames'])[:, 0]
    rotmats_1 = rigids_1.get_rots().get_rot_mats()
    trans_1 = rigids_1.get_trans()

    # Precompute ESM2 per-residue (node) and pairwise representations.
    node_repr_pre, pair_repr_pre = get_pre_repr(
        chain_feats['aatype'], model_esm2, alphabet, batch_converter,
        device=esm_device)
    node_repr_pre = node_repr_pre[0].cpu()
    pair_repr_pre = pair_repr_pre[0].cpu()

    out = {
        'aatype': chain_feats['aatype'],
        'rotmats_1': rotmats_1,
        'trans_1': trans_1,
        'res_mask': torch.tensor(processed_feats['bb_mask']).int(),
        'bb_positions': processed_feats['bb_positions'],
        'all_atom_positions': chain_feats['all_atom_positions'],
        'node_repr_pre': node_repr_pre,
        'pair_repr_pre': pair_repr_pre,
    }

    # Overwrite the processed pickle in place with the enriched features.
    du.write_pkl(processed_file_path, out)
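

# A minimal sketch of reading one enriched pickle downstream (the file name
# here is illustrative; du.read_pkl is the project's own helper):
#   feats = du.read_pkl('./dataset/ATLAS/select/pkl/1abc_A.pkl')
#   feats['trans_1']    # (N, 3) backbone translations
#   feats['rotmats_1']  # (N, 3, 3) backbone rotation matrices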


def cal_static_structure(processed_file_path, raw_pdb_file, ESMFold):
    """Adds ESMFold-predicted backbone frames to a processed pickle."""
    output_total = du.read_pkl(processed_file_path)

    save_dir = os.path.join(os.path.dirname(raw_pdb_file), 'ESMFold_Pred_results')
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(
        save_dir, os.path.basename(processed_file_path)[:6] + '_esmfold.pdb')
    # Only run ESMFold if the prediction has not been cached yet; the frames
    # are extracted from the (possibly cached) prediction either way.
    if not os.path.exists(save_path):
        print(f'cal_static_structure for {processed_file_path}')
        ESMFold.predict_str(raw_pdb_file, save_path)
    trans, rotmats = cal_trans_rotmats(save_path)
    output_total['trans_esmfold'] = trans
    output_total['rotmats_esmfold'] = rotmats

    du.write_pkl(processed_file_path, output_total)


def merge_pdb(metadata_path, traj_info_file, valid_seq_file, merged_output_file):
    """Joins processing metadata with trajectory info and marks the train split."""
    df1 = pd.read_csv(metadata_path)
    df2 = pd.read_csv(traj_info_file)
    df3 = pd.read_csv(valid_seq_file)

    df1['traj_filename'] = [os.path.basename(i) for i in df1['raw_path']]

    # Attach per-trajectory energies, then flag everything outside the
    # validation sequences as training data (matched on the 6-char PDB ID).
    merged = df1.merge(df2[['traj_filename', 'energy']], on='traj_filename', how='left')
    merged['is_trainset'] = ~merged['traj_filename'].str[:6].isin(df3['file'])

    merged.to_csv(merged_output_file, index=False)
    print('merge complete!')
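

# Example invocation (the script name is illustrative):
#   python process_dataset.py --pdb_dir ./dataset/ATLAS/select \
#       --write_dir ./dataset/ATLAS/select/pkl --num_processes 48 --verbose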


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--pdb_dir", type=str, default="./dataset/ATLAS/select")
    parser.add_argument("--write_dir", type=str, default="./dataset/ATLAS/select/pkl")
    parser.add_argument("--csv_name", type=str, default="metadata.csv")
    # argparse's type=bool treats any non-empty string as True, so use a flag.
    parser.add_argument("--debug", action='store_true')
    parser.add_argument("--num_processes", type=int, default=48)
    parser.add_argument('--verbose', help='Whether to log everything.', action='store_true')

    parser.add_argument("--esm_device", type=str, default='cuda')

    parser.add_argument("--traj_info_file", type=str, default='./dataset/ATLAS/select/traj_info_select.csv')
    parser.add_argument("--valid_seq_file", type=str, default='./inference/valid_seq.csv')
    parser.add_argument("--merged_output_file", type=str, default='./dataset/ATLAS/select/pkl/metadata_merged.csv')

    args = parser.parse_args()

    # Stage 1: parse raw PDBs into feature pickles and write metadata.
    main(args)

    # Stage 2: enrich each pickle with ESM2 representations (longest sequences first).
    csv_path = os.path.join(args.write_dir, args.csv_name)
    pdb_csv = pd.read_csv(csv_path)
    pdb_csv = pdb_csv.sort_values('modeled_seq_len', ascending=False)
    model_esm2, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    model_esm2.eval()
    model_esm2.requires_grad_(False)
    model_esm2.to(args.esm_device)
    for idx in range(len(pdb_csv)):
        cal_repr(pdb_csv.iloc[idx]['processed_path'], model_esm2, alphabet,
                 batch_converter, args.esm_device)
|
| | |
| | csv_path = os.path.join(args.write_dir, args.csv_name) |
| | pdb_csv = pd.read_csv(csv_path) |
| | ESMFold = ESMFold_Pred(device = args.esm_device) |
| | for idx in range(len(pdb_csv)): |
| | cal_static_structure(pdb_csv.iloc[idx]['processed_path'], pdb_csv.iloc[idx]['raw_path'], ESMFold) |
| |
|
| | |
| | csv_path = os.path.join(args.write_dir, args.csv_name) |
| | merge_pdb(csv_path, args.traj_info_file, args.valid_seq_file, args.merged_output_file) |
| |
|
| |
|
| | |
| | |
| |
|
| |
|