| | |
| | |
| | import os |
| | import sys |
| | import json |
| | import argparse |
| | from tqdm import tqdm |
| | from os.path import splitext, basename |
| |
|
| | import ray |
| | import numpy as np |
| | import torch |
| | from torch.utils.data import DataLoader |
| |
|
| | from data.format import Atom, Block, VOCAB |
| | from data.converter.pdb_to_list_blocks import pdb_to_list_blocks |
| | from data.converter.list_blocks_to_pdb import list_blocks_to_pdb |
| | from data.codesign import calculate_covariance_matrix |
| | from utils.const import sidechain_atoms |
| | from utils.logger import print_log |
| | from evaluation.dG.openmm_relaxer import ForceFieldMinimizer |
| |
|
| |
|
class DesignDataset(torch.utils.data.Dataset):
    '''Dataset pairing a receptor epitope with a peptide chain to be generated.

    Each item concatenates the epitope residues (known structure) with a
    placeholder peptide (unknown structure, and unknown sequence in codesign
    mode) and tensorizes them for the model's sampling interface.
    '''

    # per-residue padding width for atom coordinates / atom masks
    MAX_N_ATOM = 14

    def __init__(self, pdbs, epitopes, lengths_range=None, seqs=None) -> None:
        '''
        Args:
            pdbs: list of target PDB file paths.
            epitopes: list of epitope-definition json paths, parallel to pdbs.
            lengths_range: per-item (min, max) sampled peptide length
                (max exclusive) for codesign mode, or None.
            seqs: per-item peptide sequence for structure prediction, or None.
        '''
        super().__init__()
        self.pdbs = pdbs
        self.epitopes = epitopes
        self.lengths_range = lengths_range
        self.seqs = seqs

        # exactly one of seqs / lengths_range must be provided
        # (use logical `or`, not bitwise `|`, for boolean conditions)
        assert (self.seqs is not None and self.lengths_range is None) or \
               (self.seqs is None and self.lengths_range is not None)

    def get_epitope(self, idx):
        '''Parse the epitope residues of item `idx` from its PDB.

        Returns:
            residues: list of epitope Blocks.
            position_ids: 1-based positions of those blocks within their chains.
            chain2blocks: chain name -> full list of Blocks for that chain.
        '''
        pdb, epitope_def = self.pdbs[idx], self.epitopes[idx]

        def to_str(pos):
            # canonical key for a residue position tuple (PEP 8: def over lambda)
            return f'{pos[0]}-{pos[1]}'

        with open(epitope_def, 'r') as fin:
            epitope = json.load(fin)
        epi_map = {}
        for chain_name, pos in epitope:
            if chain_name not in epi_map:
                epi_map[chain_name] = {}
            epi_map[chain_name][to_str(pos)] = True
        residues, position_ids = [], []
        chain2blocks = pdb_to_list_blocks(pdb, list(epi_map.keys()), dict_form=True)
        if len(chain2blocks) != len(epi_map):
            print_log(f'Some chains in the epitope are missing. Parsed {list(chain2blocks.keys())}, given {list(epi_map.keys())}.', level='WARN')
        for chain_name in chain2blocks:
            chain = chain2blocks[chain_name]
            for i, block in enumerate(chain):
                if to_str(block.id) in epi_map[chain_name]:
                    residues.append(block)
                    position_ids.append(i + 1)  # positions are 1-based
        return residues, position_ids, chain2blocks

    def generate_pep_chain(self, idx):
        '''Build the placeholder peptide blocks for item `idx`.

        Codesign mode: a chain of UNK residues (one CA atom at the origin) with
        random length in [lmin, lmax). Structure-prediction mode: the given
        sequence with every atom initialized at the origin.
        '''
        if self.lengths_range is not None:
            lmin, lmax = self.lengths_range[idx]
            # Use torch's RNG instead of numpy's: torch reseeds each DataLoader
            # worker, while numpy's state is identical across forked workers,
            # which would make parallel workers draw the same "random" lengths.
            # lmax stays exclusive, matching np.random.randint semantics.
            length = int(torch.randint(lmin, lmax, (1,)).item())
            unk_block = Block(VOCAB.symbol_to_abrv(VOCAB.UNK), [Atom('CA', [0, 0, 0], 'C')])
            return [unk_block] * length
        else:
            seq = self.seqs[idx]
            blocks = []
            for s in seq:
                atoms = []
                for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(s, []):
                    atoms.append(Atom(atom_name, [0, 0, 0], atom_name[0]))
                blocks.append(Block(VOCAB.symbol_to_abrv(s), atoms))
            return blocks

    def __len__(self):
        '''One item per target pdb.'''
        return len(self.pdbs)

    def __getitem__(self, idx: int):
        '''Tensorize receptor epitope + placeholder peptide for item `idx`.

        Returns a dict with coordinates X (n_res, MAX_N_ATOM, 3), residue-type
        ids S, per-chain position ids, a peptide mask (1 = peptide residue),
        an atom validity mask, the item length, the parsed receptor chains, and
        L, the Cholesky factor of the (regularized) covariance of the receptor
        anchor-atom coordinates.
        '''
        rec_blocks, rec_position_ids, rec_chain2blocks = self.get_epitope(idx)
        lig_blocks = self.generate_pep_chain(idx)

        # 0 = receptor residue, 1 = peptide residue
        mask = [0 for _ in rec_blocks] + [1 for _ in lig_blocks]
        position_ids = rec_position_ids + [i + 1 for i, _ in enumerate(lig_blocks)]
        X, S, atom_mask = [], [], []
        for block in rec_blocks + lig_blocks:
            symbol = VOCAB.abrv_to_symbol(block.abrv)
            atom2coord = { unit.name: unit.get_coord() for unit in block.units }
            # centroid of the atoms present; used as the padding coordinate
            bb_pos = np.mean(list(atom2coord.values()), axis=0).tolist()
            coords, coord_mask = [], []
            for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(symbol, []):
                if atom_name in atom2coord:
                    coords.append(atom2coord[atom_name])
                    coord_mask.append(1)
                else:
                    coords.append(bb_pos)
                    coord_mask.append(0)
            n_pad = self.MAX_N_ATOM - len(coords)
            for _ in range(n_pad):
                coords.append(bb_pos)
                coord_mask.append(0)

            X.append(coords)
            S.append(VOCAB.symbol_to_idx(symbol))
            atom_mask.append(coord_mask)

        X, atom_mask = torch.tensor(X, dtype=torch.float), torch.tensor(atom_mask, dtype=torch.bool)
        mask = torch.tensor(mask, dtype=torch.bool)
        # covariance of receptor (~mask) coordinates at atom index 1
        # (presumably CA given VOCAB.backbone_atoms order -- TODO confirm),
        # restricted to atoms that are actually present
        cov = calculate_covariance_matrix(X[~mask][:, 1][atom_mask[~mask][:, 1]].numpy())
        eps = 1e-4
        # small diagonal jitter keeps the matrix positive definite for Cholesky
        cov = cov + eps * np.identity(cov.shape[0])
        L = torch.from_numpy(np.linalg.cholesky(cov)).float().unsqueeze(0)

        return {
            'X': X,
            'S': torch.tensor(S, dtype=torch.long),
            'position_ids': torch.tensor(position_ids, dtype=torch.long),
            'mask': mask,
            'atom_mask': atom_mask,
            'lengths': len(S),
            'rec_chain2blocks': rec_chain2blocks,
            'L': L
        }

    def collate_fn(self, batch):
        '''Flatten items into one batch: lengths become a tensor, the parsed
        chain dicts stay a python list, everything else is concatenated.'''
        results = {}
        for key in batch[0]:
            values = [item[key] for item in batch]
            if key == 'lengths':
                results[key] = torch.tensor(values, dtype=torch.long)
            elif key == 'rec_chain2blocks':
                results[key] = values
            else:
                results[key] = torch.cat(values, dim=0)
        return results
| |
|
| |
|
@ray.remote(num_cpus=1, num_gpus=1/16)
def openmm_relax(pdb_path):
    '''Ray task: relax the structure at `pdb_path` in place with OpenMM,
    then return the (unchanged) path so callers can track completion.'''
    minimizer = ForceFieldMinimizer()
    minimizer(pdb_path, pdb_path)  # input and output are the same file
    return pdb_path
| |
|
| |
|
def design(mode, ckpt, gpu, pdbs, epitope_defs, n_samples, out_dir,
           lengths_range=None, seqs=None, identifiers=None, batch_size=8, num_workers=4):
    '''Sample peptides against the given targets and write results to out_dir.

    Args:
        mode: 'codesign' or 'struct_pred'; only used to pick the sampling
            energy weight.
        ckpt: path to a model checkpoint loadable with torch.load.
        gpu: CUDA device index, or -1 for CPU.
        pdbs: list of target PDB file paths.
        epitope_defs: list of epitope-definition json paths, parallel to pdbs.
        n_samples: list of per-target sample counts, parallel to pdbs.
        out_dir: output directory; created if missing.
        lengths_range: per-target (min, max) peptide lengths (codesign), or None.
        seqs: per-target peptide sequences (struct_pred), or None.
        identifiers: per-target output-name prefixes; defaults to pdb basenames.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.

    Side effects: writes <id>_<i>.pdb per sample and summary.jsonl under
    out_dir, then relaxes every generated pdb in place with OpenMM via ray.
    '''
    os.makedirs(out_dir, exist_ok=True)  # race-free: no exists() pre-check
    if identifiers is None:
        identifiers = [splitext(basename(pdb))[0] for pdb in pdbs]

    # load model
    device = torch.device('cpu' if gpu == -1 else f'cuda:{gpu}')
    model = torch.load(ckpt, map_location='cpu')
    model.to(device)
    model.eval()

    # expand per-target specs so that each requested sample is one dataset item
    if lengths_range is None: lengths_range = [None for _ in pdbs]
    if seqs is None: seqs = [None for _ in pdbs]
    expand_pdbs, expand_epitopes, expand_lens, expand_ids, expand_seqs = [], [], [], [], []
    for _id, pdb, epitope, l, s, n in zip(identifiers, pdbs, epitope_defs, lengths_range, seqs, n_samples):
        expand_ids.extend([f'{_id}_{i}' for i in range(n)])
        expand_pdbs.extend([pdb for _ in range(n)])
        expand_epitopes.extend([epitope for _ in range(n)])
        expand_lens.extend([l for _ in range(n)])
        expand_seqs.extend([s for _ in range(n)])

    # collapse all-None columns back to None to satisfy the dataset's
    # either/or contract (seqs XOR lengths_range)
    if expand_lens[0] is None: expand_lens = None
    if expand_seqs[0] is None: expand_seqs = None
    dataset = DesignDataset(expand_pdbs, expand_epitopes, expand_lens, expand_seqs)
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            num_workers=num_workers,
                            collate_fn=dataset.collate_fn,
                            shuffle=False
                            )

    # generate candidates; `with` guarantees summary.jsonl is closed even if
    # sampling or pdb writing raises (the original leaked the handle on error)
    cnt = 0
    all_pdbs = []
    with open(os.path.join(out_dir, 'summary.jsonl'), 'w') as result_summary:
        for batch in tqdm(dataloader):
            with torch.no_grad():
                # move tensor entries (but not python-object entries) to device
                for k in batch:
                    if hasattr(batch[k], 'to'):
                        batch[k] = batch[k].to(device)

                batch_X, batch_S, batch_pmetric = model.sample(
                    batch['X'], batch['S'],
                    batch['mask'], batch['position_ids'],
                    batch['lengths'], batch['atom_mask'],
                    L=batch['L'], sample_opt={
                        'energy_func': 'default',
                        'energy_lambda': 0.5 if mode == 'struct_pred' else 0.8
                    }
                )

            for X, S, pmetric, rec_chain2blocks in zip(batch_X, batch_S, batch_pmetric, batch['rec_chain2blocks']):
                if S is None: S = expand_seqs[cnt]  # struct_pred: sequence was fixed, not sampled
                # rebuild the sampled peptide as Block objects with full atom sets
                lig_blocks = []
                for x, s in zip(X, S):
                    abrv = VOCAB.symbol_to_abrv(s)
                    atoms = VOCAB.backbone_atoms + sidechain_atoms[VOCAB.abrv_to_symbol(abrv)]
                    units = [
                        Atom(atom_name, coord, atom_name[0]) for atom_name, coord in zip(atoms, x)
                    ]
                    lig_blocks.append(Block(abrv, units))
                # receptor chains keep their names; peptide gets the next letter
                # after the highest receptor chain id
                list_blocks, chain_names = [], []
                for chain in rec_chain2blocks:
                    list_blocks.append(rec_chain2blocks[chain])
                    chain_names.append(chain)
                pep_chain_id = chr(max([ord(c) for c in chain_names]) + 1)
                list_blocks.append(lig_blocks)
                chain_names.append(pep_chain_id)
                out_pdb = os.path.join(out_dir, expand_ids[cnt] + '.pdb')
                list_blocks_to_pdb(list_blocks, chain_names, out_pdb)
                all_pdbs.append(out_pdb)
                result_summary.write(json.dumps({
                    'id': expand_ids[cnt],
                    'rec_chains': list(rec_chain2blocks.keys()),
                    'pep_chain': pep_chain_id,
                    'pep_seq': ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks])
                }) + '\n')
                result_summary.flush()  # keep the summary usable if relaxation later crashes
                cnt += 1

    # relax all generated structures in parallel on ray workers
    print_log('Running openmm relaxation...')  # plain string: no placeholders
    ray.init(num_cpus=8)
    futures = [openmm_relax.remote(path) for path in all_pdbs]
    pbar = tqdm(total=len(futures))
    while len(futures) > 0:
        done_ids, futures = ray.wait(futures, num_returns=1)
        for done_id in done_ids:
            ray.get(done_id)  # re-raises remote exceptions; return value unused
            pbar.update(1)
    pbar.close()  # release the progress bar explicitly
    print_log('Done')
| |
|
| |
|
def parse():
    '''Build the CLI and parse sys.argv.

    --peptide_seq is only required in struct_pred mode and --length_min /
    --length_max only in codesign mode; since --mode itself is only known
    after parsing, the requirements are inferred by sniffing raw sys.argv.
    '''
    struct_pred_mode = 'struct_pred' in sys.argv
    codesign_mode = 'codesign' in sys.argv

    parser = argparse.ArgumentParser(description='run pepglad for codesign or structure prediction')
    parser.add_argument('--mode', type=str, required=True, choices=['codesign', 'struct_pred'],
                        help='Running mode')
    parser.add_argument('--pdb', type=str, required=True,
                        help='Path to the PDB file of the target protein')
    parser.add_argument('--pocket', type=str, required=True,
                        help='Path to the pocket definition (*.json generated by detect_pocket)')
    parser.add_argument('--n_samples', type=int, default=10,
                        help='Number of samples')
    parser.add_argument('--out_dir', type=str, required=True,
                        help='Output directory')
    parser.add_argument('--peptide_seq', type=str, required=struct_pred_mode,
                        help='Peptide sequence for structure prediction')
    parser.add_argument('--length_min', type=int, required=codesign_mode,
                        help='Minimum peptide length for codesign (inclusive)')
    parser.add_argument('--length_max', type=int, required=codesign_mode,
                        help='Maximum peptide length for codesign (exclusive)')
    parser.add_argument('--gpu', type=int, default=0,
                        help='GPU to use')

    return parser.parse_args()
| |
|
| |
|
if __name__ == '__main__':
    args = parse()
    # checkpoints live in <repo root>/checkpoints, one directory above this script;
    # struct_pred uses the fixed-sequence model, codesign the joint model
    proj_dir = os.path.join(os.path.dirname(__file__), '..')
    ckpt_name = 'fixseq.ckpt' if args.mode == 'struct_pred' else 'codesign.ckpt'
    ckpt = os.path.join(proj_dir, 'checkpoints', ckpt_name)
    print_log(f'Loading checkpoint: {ckpt}')
    is_codesign = args.mode == 'codesign'
    design(
        mode=args.mode,
        ckpt=ckpt,
        gpu=args.gpu,
        pdbs=[args.pdb],
        epitope_defs=[args.pocket],
        n_samples=[args.n_samples],
        out_dir=args.out_dir,
        identifiers=[os.path.basename(os.path.splitext(args.pdb)[0])],
        lengths_range=[(args.length_min, args.length_max)] if is_codesign else None,
        seqs=None if is_codesign else [args.peptide_seq]
    )