| | """Utility functions for experiments.""" |
| | import logging |
| | import torch |
| | import os |
| | import re |
| | import random |
| | import esm |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import random |
| |
|
| | from analysis import utils as au |
| | from pytorch_lightning.utilities.rank_zero import rank_zero_only |
| | from data.residue_constants import restype_order |
| | from data.repr import get_pre_repr |
| | from data import utils as du |
| | from data.residue_constants import restype_atom37_mask |
| | from openfold.data import data_transforms |
| | from openfold.utils import rigid_utils |
| | from data.cal_trans_rotmats import cal_trans_rotmats |
| |
|
| |
|

class LengthDataset(torch.utils.data.Dataset):
    """Dataset pairing validation sequences with sampled energy levels and
    precomputed ESM features, used for sampling/inference."""

    def __init__(self, samples_cfg):
        self._samples_cfg = samples_cfg

        validcsv = pd.read_csv(self._samples_cfg.validset_path)

        self._all_sample_seqs = []
        self._all_filename = []
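
        # Discretized CDF of an exponential distribution over `prob_num` bins,
        # normalized to [0, 1]; used below to draw a per-sample energy level.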
        prob_num = 500
        exp_prob = np.exp([-prob / prob_num * 2 for prob in range(prob_num)]).cumsum()
        exp_prob = exp_prob / np.max(exp_prob)

        for idx in range(len(validcsv['seq'])):
            self._all_filename += [validcsv['file'][idx]] * self._samples_cfg.sample_num
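
            # For each replica of this sequence, draw an energy level in [0, 1)
            # by inverse-transform sampling from the exponential CDF above.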
            for batch_idx in range(self._samples_cfg.sample_num):
                rand = random.random()
                for prob in range(prob_num):
                    if rand < exp_prob[prob]:
                        energy = torch.tensor(prob / prob_num)
                        break

                self._all_sample_seqs += [(validcsv['seq'][idx], energy)]

        self._all_sample_ids = self._all_sample_seqs

        # ESM-2 (650M) for per-residue and pairwise sequence representations.
        self.device_esm = f'cuda:{torch.cuda.current_device()}'
        self.model_esm2, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.batch_converter = self.alphabet.get_batch_converter()
        self.model_esm2.eval().cuda(self.device_esm)
        self.model_esm2.requires_grad_(False)

        # ESMFold for structure prediction of the input sequences,
        # loaded once and kept frozen.
        self._folding_model = esm.pretrained.esmfold_v1().eval()
        self._folding_model.requires_grad_(False)
        self._folding_model = self._folding_model.to(self.device_esm)

        self.esm_savepath = self._samples_cfg.esm_savepath

    def run_folding(self, sequence, save_path):
        """Run ESMFold on a sequence and write the predicted PDB to save_path."""
        with torch.no_grad():
            output = self._folding_model.infer_pdb(sequence)
        # Move the folding model off the GPU to free memory after inference.
        self._folding_model.to("cpu")

        with open(save_path, "w") as f:
            f.write(output)
        return output

    def __len__(self):
        return len(self._all_sample_ids)

    def __getitem__(self, idx):
        seq, energy = self._all_sample_ids[idx]
        aatype = torch.tensor([restype_order[s] for s in seq])
        num_res = len(aatype)
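
        # Precompute ESM-2 node (per-residue) and pair representations for the
        # sequence; moved to CPU before being returned in the batch.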
        node_repr_pre, pair_repr_pre = get_pre_repr(
            aatype, self.model_esm2, self.alphabet, self.batch_converter,
            device=self.device_esm)
        node_repr_pre = node_repr_pre[0].cpu()
        pair_repr_pre = pair_repr_pre[0].cpu()

        motif_mask = torch.ones(aatype.shape)
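
        # Fold the sequence with ESMFold unless a cached PDB already exists.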
        save_path = os.path.join(self.esm_savepath, "esm_" + self._all_filename[idx] + ".pdb")
        if not os.path.exists(save_path):
            with torch.no_grad():
                output = self._folding_model.infer_pdb(seq)
            with open(save_path, "w") as f:
                f.write(output)

        # Backbone translations and rotation matrices from the ESMFold structure.
        trans_esmfold, rotmats_esmfold = cal_trans_rotmats(save_path)

        batch = {
            'filename': self._all_filename[idx],
            'trans_esmfold': trans_esmfold,
            'rotmats_esmfold': rotmats_esmfold,
            'motif_mask': motif_mask,
            'res_mask': torch.ones(num_res).int(),
            'num_res': num_res,
            'energy': energy,
            'aatype': aatype,
            'seq': seq,
            'node_repr_pre': node_repr_pre,
            'pair_repr_pre': pair_repr_pre,
        }
        return batch
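

# Illustrative usage sketch (not part of the original module; the config field
# names are taken from LengthDataset, and any object exposing these attributes
# would work):
#
#   from types import SimpleNamespace
#   samples_cfg = SimpleNamespace(
#       validset_path='valid.csv',      # CSV with 'seq' and 'file' columns
#       sample_num=8,                   # replicas per sequence
#       esm_savepath='./esmfold_pdbs',  # cache directory for ESMFold PDBs
#   )
#   dataset = LengthDataset(samples_cfg)
#   batch = dataset[0]  # dict with 'aatype', 'energy', 'trans_esmfold', ...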


def save_traj(
    sample: np.ndarray,
    bb_prot_traj: np.ndarray,
    x0_traj: np.ndarray,
    diffuse_mask: np.ndarray,
    output_dir: str,
    aatype=None,
    index=0,
):
    """Writes the final sample and the reverse diffusion trajectory.

    Args:
        sample: [N, 37, 3] final atom37 sample.
        bb_prot_traj: [T, N, 37, 3] atom37 sampled diffusion states.
            T is the number of time steps. The first time step is t=eps,
            i.e. bb_prot_traj[0] is the final sample after reverse diffusion.
            N is the number of residues.
        x0_traj: [T, N, 3] x_0 predictions of C-alpha at each time step.
        diffuse_mask: [N] which residues are diffused.
        output_dir: where to save samples.
        aatype: [T, N, 21] amino acid probability vector trajectory.
        index: integer suffix used in the output file names.

    Returns:
        Dictionary with paths to saved samples.
            'sample_path': PDB file of the final state of the reverse trajectory.
            'traj_path': PDB file of all intermediate diffused states.
            'x0_traj_path': PDB file of C-alpha x_0 predictions at each state.
        b_factors are set to 100 for diffused residues and 0 for motif
        residues, if there are any.
    """
    diffuse_mask = diffuse_mask.astype(bool)
    sample_path = os.path.join(output_dir, 'sample_' + str(index) + '.pdb')
    prot_traj_path = os.path.join(output_dir, 'bb_traj_' + str(index) + '.pdb')
    x0_traj_path = os.path.join(output_dir, 'x0_traj_' + str(index) + '.pdb')
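
    # Per-atom B-factors: 100 for diffused residues, 0 for fixed (motif)
    # residues, tiled across all 37 atom positions.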
    b_factors = np.tile((diffuse_mask * 100)[:, None], (1, 37))

    sample_path = au.write_prot_to_pdb(
        sample,
        sample_path,
        b_factors=b_factors,
        no_indexing=True,
        aatype=aatype,
    )
    prot_traj_path = au.write_prot_to_pdb(
        bb_prot_traj,
        prot_traj_path,
        b_factors=b_factors,
        no_indexing=True,
        aatype=aatype,
    )
    x0_traj_path = au.write_prot_to_pdb(
        x0_traj,
        x0_traj_path,
        b_factors=b_factors,
        no_indexing=True,
        aatype=aatype,
    )
    return {
        'sample_path': sample_path,
        'traj_path': prot_traj_path,
        'x0_traj_path': x0_traj_path,
    }


def get_pylogger(name=__name__) -> logging.Logger:
    """Initializes a multi-GPU-friendly python command line logger."""

    logger = logging.getLogger(name)
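
    # Wrap each logging method with rank_zero_only so messages are emitted
    # only once (from rank 0) in multi-GPU / multi-process runs.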
    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for level in logging_levels:
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger


def flatten_dict(raw_dict):
    """Flattens a nested dict into a list of (key, value) pairs."""
    flattened = []
    for k, v in raw_dict.items():
        if isinstance(v, dict):
            flattened.extend([
                (f'{k}:{i}', j) for i, j in flatten_dict(v)
            ])
        else:
            flattened.append((k, v))
    return flattened
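
# Illustrative example (not in the original module): nested keys are joined
# with ':' separators, e.g.
#   flatten_dict({'optimizer': {'lr': 1e-4}, 'seed': 0})
#   -> [('optimizer:lr', 0.0001), ('seed', 0)]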