diff --git a/README.md b/README.md index 6ddff2f305ace52cf0365e74625592c1a32ac0be..41d9bffaf15028ff141f8e5bf7f458c44510add8 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,134 @@ ---- -license: gpl-3.0 ---- +# P2DFlow + +> ## ℹ️ The version 2 of codes for P2DFlow will come soon to align with the new revised paper, and for easier use and modification (the old version can run correctly) + +P2DFlow is a protein ensemble generative model with SE(3) flow matching based on ESMFold, the ensembles generated by P2DFlow could aid in understanding protein functions across various scenarios. + +Technical details and evaluation results are provided in our paper: +* [P2DFlow: A Protein Ensemble Generative Model with SE(3) Flow Matching](https://arxiv.org/abs/2411.17196) + +

+ +

+ +![P2DFlow](resources/gen_example.gif) + + +## Table of Contents +1. [Installation](#Installation) +2. [Prepare Dataset](#Prepare-Dataset) +3. [Model weights](#Model-weights) +4. [Training](#Training) +5. [Inference](#Inference) +6. [Evaluation](#Evaluation) +7. [License](#License) +8. [Citation](#Citation) + + +## Installation +In an environment with cuda 11.7, run: +``` +conda env create -f environment.yml +``` +To activate the environment, run: +``` +conda activate P2DFlow +``` + +## Prepare Dataset +#### (tips: If you want to use the data we have preprocessed, please go directly to `3. Process selected dataset`; if you prefer to process the data from scratch or work with your own data, please start from the beginning) + +#### 1. Download raw ATLAS dataset +(i) Download the `Analysis & MDs` dataset from [ATLAS](https://www.dsimb.inserm.fr/ATLAS/), or you can use `./dataset/download.py` by running: +``` +python ./dataset/download.py +``` +We will use `.pdb` and `.xtc` files for the following calculation. + +#### 2. Calculate the 'approximate energy' and select representative structures +(i) Use `gaussian_kde` to calculate the 'approximate energy' (You need to put all files above in `./dataset`, including `ATLAS_filename.txt` for filenames of all proteins): +``` +python ./dataset/traj_analyse.py +``` +And you will get `traj_info.csv`. + +(ii) Select representative structures at equal intervals based on the 'approximate energy': +``` +python ./dataset/md_select.py +``` + +#### 3. Process selected dataset + +(i) Download the selected dataset (or get it from the two steps above) from [Google Drive](https://drive.google.com/drive/folders/11mdVfMi2rpVn7nNG2mQAGA5sNXCKePZj?usp=sharing) whose filename is `selected_dataset.tar`, and decompress it using: +``` +tar -xvf selected_dataset.tar +``` +(ii) Preprocess `.pdb` files to get `.pkl` files: +``` +python ./data/process_pdb_files.py --pdb_dir ${pdb_dir} --write_dir ${write_dir} +``` +And you will get `metadata.csv`. 
+ +then compute node representation and pair representation using ESM-2 (`csv_path` is the path of `metadata.csv`): +``` +python ./data/cal_repr.py --csv_path ${csv_path} +``` +then compute predicted static structure using ESMFold (`csv_path` is the path of `metadata.csv`): +``` +python ./data/cal_static_structure.py --csv_path ${csv_path} +``` +(iii) Provide the necessary `.csv` files for training + +If you are using the data we have preprocessed, download the `.csv` files from [Google Drive](https://drive.google.com/drive/folders/11mdVfMi2rpVn7nNG2mQAGA5sNXCKePZj?usp=sharing) whose filenames are `train_dataset.csv` and `train_dataset_energy.csv`(they correspond to `csv_path` and `energy_csv_path` in `./configs/base.yaml` during training). + +Or if you are using your own data, you can get `metadata.csv` from step 3 (correspond to `csv_path` in `./configs/base.yaml` during training, and you need to split a subset from it as the train dataset), and get `traj_info.csv` from step 2 (correspond to `energy_csv_path`). + + + +## Model weights +Download the pretrained checkpoint from [Google Drive](https://drive.google.com/drive/folders/11mdVfMi2rpVn7nNG2mQAGA5sNXCKePZj?usp=sharing) whose filename is `pretrained.ckpt`, and put it into `./weights` folder. You can use the pretrained weight for inference. + + +## Training +To train P2DFlow, firstly make sure you have prepared the dataset according to `Prepare Dataset`, and put it in the right folder, then modify `./configs/base.yaml` (especially for `csv_path` and `energy_csv_path`). After this, you can run: +``` +python experiments/train_se3_flows.py +``` +And you will get the checkpoints in `./ckpt`. + + +## Inference +To infer for specified protein sequence, firstly modify `./inference/valid_seq.csv` and `./configs/inference.yaml` (especially for `validset_path`), then run: +``` +python experiments/inference_se3_flows.py +``` +And you will get the results in `./inference_outputs/weights/`. 
+ + +## Evaluation +To evaluate metrics related to fidelity and dynamics, specify paths in `./analysis/eval_test.py`, then run: +``` +python ./analysis/eval_test.py +``` +To evaluate PCA, specify paths in `./analysis/pca_analyse.py`, then run: +``` +python ./analysis/pca_analyse.py +``` +To draw the ramachandran plots, specify paths in `./analysis/Ramachandran_plot.py`, then run: +``` +python ./analysis/Ramachandran_plot.py +``` + +## License +This project is licensed under the terms of the GPL-3.0 license. + + +## Citation +``` +@article{jin2024p2dflow, + title={P2DFlow: A Protein Ensemble Generative Model with SE(3) Flow Matching}, + author={Yaowei Jin, Qi Huang, Ziyang Song, Mingyue Zheng, Dan Teng, Qian Shi}, + journal={arXiv preprint arXiv:2411.17196}, + year={2024} +} +``` diff --git a/analysis/Ramachandran_plot.py b/analysis/Ramachandran_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..c3660a4e016b22952e8ba86dbd26ed66812a0bfa --- /dev/null +++ b/analysis/Ramachandran_plot.py @@ -0,0 +1,99 @@ +import MDAnalysis as mda +import numpy as np +from MDAnalysis.analysis.dihedrals import Ramachandran +import os +import re +import pandas as pd +import matplotlib.pyplot as plt + + +def ramachandran_eval(all_paths, pdb_file, output_dir): + angle_results_all = [] + + for dirpath in all_paths: + pdb_path = os.path.join(dirpath,pdb_file) + + u = mda.Universe(pdb_path) + protein = u.select_atoms('protein') + # print('There are {} residues in the protein'.format(len(protein.residues))) + + ramachandran = Ramachandran(protein) + ramachandran.run() + angle_results = ramachandran.results.angles + # print(angle_results.shape) + + # ramachandran.plot(color='black', marker='.') + + angle_results_all.append(angle_results.reshape([-1,2])) + + + # df = pd.DataFrame(angle_results.reshape([-1,2])) + # df.to_csv(os.path.join(output_dir, os.path.basename(dirpath)+'_'+pdb_file.split('.')[0]+'.csv'), index=False) + + + points1 = angle_results_all[0] + grid_size 
def ramachandran_eval(all_paths, pdb_file, output_dir):
    """Compare Ramachandran-plot coverage of several ensembles against a reference.

    ``pdb_file`` is loaded from every directory in ``all_paths`` and its
    backbone dihedral angles are computed with MDAnalysis' ``Ramachandran``.
    The first directory is the reference.  For each other directory the
    function reports the fraction of its occupied grid cells that are also
    occupied by the reference.

    Args:
        all_paths: Directories that each contain ``pdb_file``; index 0 is the
            reference ensemble.
        pdb_file: File name to load from every directory.
        output_dir: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        dict mapping ``os.path.basename(path)`` of every non-reference
        directory to its ratio, plus the key ``'file'`` holding ``pdb_file``.
        The ratio is NaN when a prediction occupies no grid cell.
    """
    angle_results_all = []
    for dirpath in all_paths:
        universe = mda.Universe(os.path.join(dirpath, pdb_file))
        protein = universe.select_atoms('protein')
        rama = Ramachandran(protein)
        rama.run()
        # Flatten (frames, residues, 2) into a list of (phi, psi) points.
        angle_results_all.append(rama.results.angles.reshape([-1, 2]))

    grid_size = 360  # number of bin edges per axis
    x_bins = np.linspace(-180, 180, grid_size)
    y_bins = np.linspace(-180, 180, grid_size)

    # The reference occupancy does not depend on the prediction, so it is
    # computed once instead of once per compared directory.
    reference = angle_results_all[0]
    ref_hist, _, _ = np.histogram2d(
        reference[:, 0], reference[:, 1], bins=[x_bins, y_bins])
    ref_mask = ref_hist > 0

    result_tmp = {}
    for path, points in zip(all_paths[1:], angle_results_all[1:]):
        # Boolean occupancy: does at least one point fall into the cell?
        hist, _, _ = np.histogram2d(
            points[:, 0], points[:, 1], bins=[x_bins, y_bins])
        mask = hist > 0

        occupied = mask.sum()
        shared = np.logical_and(ref_mask, mask).sum()
        # Avoid a silent 0/0 RuntimeWarning when nothing is occupied.
        val_ratio = shared / occupied if occupied else float('nan')

        name = os.path.basename(path)
        print(name, "val_ratio:", val_ratio)
        result_tmp[name] = val_ratio

    result_tmp['file'] = pdb_file
    return result_tmp
if __name__ == '__main__':
    # End-to-end evaluation: merge sampled PDBs, compute ensemble metrics,
    # then write the Ramachandran validity ratio of every target PDB.
    parser = argparse.ArgumentParser()

    parser.add_argument("--pred_org_dir", type=str, default="./inference_outputs/weights/pretrained/2025-03-13_10-08")
    parser.add_argument("--valid_csv_file", type=str, default="./inference/valid_seq.csv")
    parser.add_argument("--pred_merge_dir", type=str, default="./inference/test/pred_merge_results")
    parser.add_argument("--target_dir", type=str, default="./inference/test/target_dir")
    parser.add_argument("--crystal_dir", type=str, default="./inference/test/crystal_dir")

    args = parser.parse_args()

    # merge pdb
    merge_pdb_full(args.pred_org_dir, args.valid_csv_file, args.pred_merge_dir)

    # cal_eval
    evaluate_prediction(args.pred_merge_dir, args.target_dir, args.crystal_dir)

    # cal_RP
    all_paths = [
        args.target_dir,
        args.pred_merge_dir,
    ]
    results = {}
    for file in os.listdir(all_paths[0]):
        if re.search(r'\.pdb', file):
            print(file)
            result_tmp = ramachandran_eval(
                all_paths=all_paths,
                pdb_file=file,
                output_dir=args.pred_merge_dir,
            )

            for pred_path in all_paths[1:]:
                key_name = os.path.basename(pred_path)
                # The original tested `key_name is results.keys()`, which is
                # never true, so every file replaced the previous row.
                results.setdefault(key_name, []).append(result_tmp[key_name])

    out_csv = os.path.join(args.pred_merge_dir, 'Ramachandran_plot_validity.csv')
    pd.DataFrame(results).to_csv(out_csv, index=False)
    print(f"RP results saved to {out_csv}")
def merge_pdb(work_dir, new_file, ref_pdb):
    """Merge all sampled PDB files of one protein into a multi-model PDB.

    Every sub-directory of ``work_dir`` whose name contains ``ref_pdb`` is
    scanned for files whose name starts with ``sample`` and contains
    ``.pdb``.  The models of all files found are appended to the first
    structure and written to ``new_file``.  Nothing is written when no
    sample file is found.

    Args:
        work_dir: Directory holding one sub-directory per inference run.
        new_file: Path of the merged PDB file to write.
        ref_pdb: Protein name, matched literally against sub-directory names.
    """
    parser = PDBParser()
    sample_pattern = re.compile(r"sample.*\.pdb")
    structures = []
    # sorted() keeps the model order reproducible across file systems.
    for pdb_dir in sorted(os.listdir(work_dir)):
        pdb_dir_full = os.path.join(work_dir, pdb_dir)
        # Literal containment; the original built an unescaped regex from
        # ref_pdb, so names with regex metacharacters could mis-match.
        if not (os.path.isdir(pdb_dir_full) and ref_pdb in pdb_dir):
            continue
        for pdb_file in sorted(os.listdir(pdb_dir_full)):
            if sample_pattern.match(pdb_file):
                structures.append(
                    parser.get_structure(pdb_file, os.path.join(pdb_dir_full, pdb_file)))

    if not structures:
        return
    print(ref_pdb, len(structures), "files")

    new_structure = structures[0]
    count = 0
    for structure in structures[1:]:
        for model in structure:
            # Give every appended model a fresh id before adding it.
            count += 1
            model.id = count
            new_structure.add(model)

    io = PDBIO()
    io.set_structure(new_structure)
    io.save(new_file)


def merge_pdb_full(inference_dir_f, valid_csv, output_dir):
    """Run ``merge_pdb`` for every name in the ``file`` column of ``valid_csv``.

    Args:
        inference_dir_f: Directory passed to ``merge_pdb`` as ``work_dir``.
        valid_csv: CSV file with a ``file`` column of protein names.
        output_dir: Directory receiving ``<name>.pdb``; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    valid_set = pd.read_csv(valid_csv)
    for filename in valid_set['file']:
        output_file = os.path.join(output_dir, filename + ".pdb")
        merge_pdb(inference_dir_f, output_file, filename)
def calc_tm_score(pos_1, pos_2, seq_1, seq_2):
    """Return the two chain-normalised TM-scores reported by ``tm_align``."""
    result = tm_align(pos_1, pos_2, seq_1, seq_2)
    return result.tm_norm_chain1, result.tm_norm_chain2


def calc_mdtraj_metrics(pdb_path):
    """Compute secondary-structure fractions and radius of gyration of a PDB.

    On ``IndexError`` the error is printed and every metric is reported as 0.0.
    """
    try:
        traj = md.load(pdb_path)
        ss = md.compute_dssp(traj, simplified=True)
        coil, helix, strand = (np.mean(ss == code) for code in ('C', 'H', 'E'))
        non_coil = helix + strand
        rg = md.compute_rg(traj)[0]
    except IndexError as e:
        print('Error in calc_mdtraj_metrics: {}'.format(e))
        non_coil = coil = helix = strand = rg = 0.0
    return {
        'non_coil_percent': non_coil,
        'coil_percent': coil,
        'helix_percent': helix,
        'strand_percent': strand,
        'radius_of_gyration': rg,
    }


def calc_aligned_rmsd(pos_1, pos_2):
    """Align ``pos_1`` onto ``pos_2`` and return the mean per-point distance.

    NOTE(review): despite the name this is a mean of distances, not a
    root-mean-square value.
    """
    aligned = du.rigid_transform_3D(pos_1, pos_2)[0]
    return np.mean(np.linalg.norm(aligned - pos_2, axis=-1))


def calc_ca_ca_metrics(ca_pos, bond_tol=0.1, clash_tol=1.0):
    """Summarise consecutive CA-CA distances and CA-CA clashes.

    Returns the mean absolute deviation of consecutive distances from
    ``residue_constants.ca_ca``, the fraction of those distances below
    ``ca_ca + bond_tol`` and the number of non-zero pair distances (upper
    triangle) below ``clash_tol``.
    """
    # Distance between each point and its predecessor along the chain.
    bond_dists = np.linalg.norm(ca_pos[1:] - ca_pos[:-1], axis=-1)
    deviation = np.mean(np.abs(bond_dists - residue_constants.ca_ca))
    valid = np.mean(bond_dists < (residue_constants.ca_ca + bond_tol))

    pair_dists = np.linalg.norm(
        ca_pos[:, None, :] - ca_pos[None, :, :], axis=-1)
    upper_nonzero = np.triu(pair_dists, k=0) > 0
    n_clashes = np.sum(pair_dists[upper_nonzero] < clash_tol)
    return {
        'ca_ca_deviation': deviation,
        'ca_ca_valid_percent': valid,
        'num_ca_ca_clashes': n_clashes,
    }
def _save_pca_projection(projection, n_components, output_dir, stem):
    """Write one PCA projection as ``<stem>.csv`` and a PC1/PC2 scatter ``<stem>.png``."""
    df = pd.DataFrame(projection,
                      columns=['PC{}'.format(i + 1) for i in range(n_components)])
    df.to_csv(os.path.join(output_dir, f'{stem}.csv'))

    plt.scatter(df['PC1'], df['PC2'], marker='o')
    # Save before show(): on interactive backends the figure may be gone
    # after show() returns, which would leave an empty image on disk.
    plt.savefig(os.path.join(output_dir, f'{stem}.png'))
    plt.show()
    # Start the next plot from an empty figure so data sets do not pile up.
    plt.clf()


def cal_PCA(md_pdb_path, ref_path, pred_pdb_path, n_components=2):
    """Project MD and predicted ensembles onto the MD backbone principal components.

    The MD trajectory is aligned to the reference on N/CA/C atoms, a PCA is
    fitted on it, and the MD frames as well as every predicted ensemble
    (aligned the same way) are projected onto those components.  A CSV and
    a PNG are written next to each input file.

    Args:
        md_pdb_path: Multi-model PDB of the MD ensemble.
        ref_path: PDB of the reference structure used for alignment.
        pred_pdb_path: dict mapping a method name to its multi-model PDB path.
        n_components: Number of components to keep; the scatter plot uses
            PC1 and PC2, so at least 2 are required.
    """
    selection = 'name CA or name C or name N'

    print("")
    print('filename=', os.path.basename(ref_path))

    u = mda.Universe(md_pdb_path, md_pdb_path)
    u_ref = mda.Universe(ref_path, ref_path)

    # in_memory=True: the aligned coordinates replace those held by `u`.
    align.AlignTraj(u, u_ref, select=selection, in_memory=True).run()

    pc = pca.PCA(u,
                 select=selection,
                 align=False, mean=None,
                 n_components=n_components,
                 ).run()

    backbone = u.select_atoms(selection)
    print('There are {} backbone atoms in the analysis'.format(len(backbone)))

    for i in range(n_components):
        print(f"Cumulated variance {i+1}: {pc.cumulated_variance[i]:.3f}")

    transformed = pc.transform(backbone, n_components=n_components)
    print(transformed.shape)

    md_stem = os.path.basename(md_pdb_path).split('.')[0]
    _save_pca_projection(transformed, n_components,
                         os.path.dirname(md_pdb_path), f'{md_stem}_md_pca')

    for k, v in pred_pdb_path.items():
        u_pred = mda.Universe(v, v)
        align.AlignTraj(u_pred, u_ref, select=selection, in_memory=True).run()
        pred_backbone = u_pred.select_atoms(selection)
        pred_transformed = pc.transform(pred_backbone, n_components=n_components)

        pred_stem = os.path.basename(v).split('.')[0]
        _save_pca_projection(pred_transformed, n_components,
                             os.path.dirname(v), f'{pred_stem}_{k}_pca')
0000000000000000000000000000000000000000..3b8975ce7ab72af6367e9cd2689079059aead9ef Binary files /dev/null and b/analysis/src/__pycache__/__init__.cpython-37.pyc differ diff --git a/analysis/src/__pycache__/__init__.cpython-39.pyc b/analysis/src/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f0970c21c1f8f16b2738cb805bcbefebd8380fe Binary files /dev/null and b/analysis/src/__pycache__/__init__.cpython-39.pyc differ diff --git a/analysis/src/__pycache__/eval.cpython-310.pyc b/analysis/src/__pycache__/eval.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41c4561ea536a75d1dbff7dfc80355ec43e4a1b9 Binary files /dev/null and b/analysis/src/__pycache__/eval.cpython-310.pyc differ diff --git a/analysis/src/__pycache__/eval.cpython-37.pyc b/analysis/src/__pycache__/eval.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b58d4ae3099de98bab35e261c6d866a4a1f445e Binary files /dev/null and b/analysis/src/__pycache__/eval.cpython-37.pyc differ diff --git a/analysis/src/__pycache__/eval.cpython-39.pyc b/analysis/src/__pycache__/eval.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ffe4d66740cbc14ac30ca9a23bceeaee9c9dca8 Binary files /dev/null and b/analysis/src/__pycache__/eval.cpython-39.pyc differ diff --git a/analysis/src/common/__init__.py b/analysis/src/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/analysis/src/common/__pycache__/__init__.cpython-310.pyc b/analysis/src/common/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b27312e0ca59c8a4668f989b47079cfd0fc8bc3c Binary files /dev/null and b/analysis/src/common/__pycache__/__init__.cpython-310.pyc differ diff --git a/analysis/src/common/__pycache__/__init__.cpython-39.pyc 
b/analysis/src/common/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db4b9bec0684f401570540a5b6f21d4630d0bc9c Binary files /dev/null and b/analysis/src/common/__pycache__/__init__.cpython-39.pyc differ diff --git a/analysis/src/common/__pycache__/all_atom.cpython-39.pyc b/analysis/src/common/__pycache__/all_atom.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3c3307dc43eed2f9fc0eaf5106fb4322736a72b Binary files /dev/null and b/analysis/src/common/__pycache__/all_atom.cpython-39.pyc differ diff --git a/analysis/src/common/__pycache__/data_transforms.cpython-39.pyc b/analysis/src/common/__pycache__/data_transforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfcf5d1844248f47fa19a4083216ec3c27b6fe5c Binary files /dev/null and b/analysis/src/common/__pycache__/data_transforms.cpython-39.pyc differ diff --git a/analysis/src/common/__pycache__/geo_utils.cpython-310.pyc b/analysis/src/common/__pycache__/geo_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81e4f5ab50727c67b65472fb379bc620e47a71c0 Binary files /dev/null and b/analysis/src/common/__pycache__/geo_utils.cpython-310.pyc differ diff --git a/analysis/src/common/__pycache__/geo_utils.cpython-39.pyc b/analysis/src/common/__pycache__/geo_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..542b78adb96abcc3e32c308aecc1b3cf4bacec66 Binary files /dev/null and b/analysis/src/common/__pycache__/geo_utils.cpython-39.pyc differ diff --git a/analysis/src/common/__pycache__/pdb_utils.cpython-310.pyc b/analysis/src/common/__pycache__/pdb_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..400edcbb597ec5560f73df0b11c9346edb7c8b11 Binary files /dev/null and b/analysis/src/common/__pycache__/pdb_utils.cpython-310.pyc differ diff --git 
a/analysis/src/common/__pycache__/pdb_utils.cpython-39.pyc b/analysis/src/common/__pycache__/pdb_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8c90e9b484d6e3bd6d6d4a700d0d6d5c9360ccd Binary files /dev/null and b/analysis/src/common/__pycache__/pdb_utils.cpython-39.pyc differ diff --git a/analysis/src/common/__pycache__/protein.cpython-310.pyc b/analysis/src/common/__pycache__/protein.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea3f19716658346bec4690f0f591c1eef317a1b4 Binary files /dev/null and b/analysis/src/common/__pycache__/protein.cpython-310.pyc differ diff --git a/analysis/src/common/__pycache__/protein.cpython-39.pyc b/analysis/src/common/__pycache__/protein.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24c502a02413adb73c405278c3a974b7c07055e0 Binary files /dev/null and b/analysis/src/common/__pycache__/protein.cpython-39.pyc differ diff --git a/analysis/src/common/__pycache__/residue_constants.cpython-310.pyc b/analysis/src/common/__pycache__/residue_constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c3496c2c6dfd4cdbc79e7a4b1ddbd1f70bd7b1b Binary files /dev/null and b/analysis/src/common/__pycache__/residue_constants.cpython-310.pyc differ diff --git a/analysis/src/common/__pycache__/residue_constants.cpython-39.pyc b/analysis/src/common/__pycache__/residue_constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0edc630ea11af65cd953ad52c22b8f11a3999efc Binary files /dev/null and b/analysis/src/common/__pycache__/residue_constants.cpython-39.pyc differ diff --git a/analysis/src/common/__pycache__/rigid_utils.cpython-39.pyc b/analysis/src/common/__pycache__/rigid_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f436865e10980ad8ddeed2a8963702cd54c98065 Binary files /dev/null and 
def torsion_angles_to_frames(
    r: Rigid,
    alpha: torch.Tensor,
    aatype: torch.Tensor,
):
    """Build the eight per-residue rigid-group frames in the global frame.

    Looks up the default group frames for ``aatype``, rotates each by the
    matching torsion in ``alpha``, chains the chi2-chi4 frames onto chi1 and
    finally composes everything with the backbone frames ``r``.

    Args:
        r: Backbone frames (shape comments below suggest [*, N]).
        alpha: Torsions as two-component vectors; seven per residue are
            expected, an identity backbone rotation is prepended here.
            NOTE(review): presumably (sin, cos) pairs - confirm with caller.
        aatype: Residue type indices used to index ``DEFAULT_FRAMES``.

    Returns:
        Rigid with a trailing group dimension of 8.
    """
    # [*, N, 8, 4, 4]
    default_4x4 = DEFAULT_FRAMES[aatype, ...].to(r.device)

    # [*, N, 8] transformations, i.e.
    #   One [*, N, 8, 3, 3] rotation matrix and
    #   One [*, N, 8, 3]    translation matrix
    default_r = r.from_tensor_4x4(default_4x4)

    # Identity entry (0, 1) for the backbone group, which has no torsion.
    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
    bb_rot[..., 1] = 1

    # [*, N, 8, 2]
    alpha = torch.cat(
        [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
    )

    # [*, N, 8, 3, 3]
    # Produces rotation matrices of the form:
    # [
    #   [1, 0  , 0  ],
    #   [0, a_2,-a_1],
    #   [0, a_1, a_2]
    # ]
    # This follows the original code rather than the supplement, which uses
    # different indices.

    all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
    all_rots[..., 0, 0] = 1
    all_rots[..., 1, 1] = alpha[..., 1]
    all_rots[..., 1, 2] = -alpha[..., 0]
    all_rots[..., 2, 1:] = alpha

    # Rotation-only transform (no translation) for each group.
    all_rots = Rigid(Rotation(rot_mats=all_rots), None)

    all_frames = default_r.compose(all_rots)

    # Groups 5-7 are expressed relative to the previous chi frame, so they
    # are chained onto group 4 to express them relative to the backbone.
    chi2_frame_to_frame = all_frames[..., 5]
    chi3_frame_to_frame = all_frames[..., 6]
    chi4_frame_to_frame = all_frames[..., 7]

    chi1_frame_to_bb = all_frames[..., 4]
    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

    all_frames_to_bb = Rigid.cat(
        [
            all_frames[..., :5],
            chi2_frame_to_bb.unsqueeze(-1),
            chi3_frame_to_bb.unsqueeze(-1),
            chi4_frame_to_bb.unsqueeze(-1),
        ],
        dim=-1,
    )

    all_frames_to_global = r[..., None].compose(all_frames_to_bb)

    return all_frames_to_global


def prot_to_torsion_angles(aatype, atom37, atom37_mask):
    """Calculate torsion angle features from protein features.

    Returns the ``torsion_angles_sin_cos`` and ``torsion_angles_mask``
    entries produced by ``atom37_to_torsion_angles``.
    """
    prot_feats = {
        'aatype': aatype,
        'all_atom_positions': atom37,
        'all_atom_mask': atom37_mask,
    }
    torsion_angles_feats = atom37_to_torsion_angles()(prot_feats)
    torsion_angles = torsion_angles_feats['torsion_angles_sin_cos']
    torsion_mask = torsion_angles_feats['torsion_angles_mask']
    return torsion_angles, torsion_mask
def frames_to_atom14_pos(
        r: Rigid,
        aatype: torch.Tensor,
    ):
    """Convert frames to their idealized all atom representation.

    Args:
        r: All rigid groups. [..., N, 8, 3]
        aatype: Residue types. [..., N]

    Returns:
        Atom14 positions [..., N, 14, 3]; entries masked out by
        ``ATOM_MASK`` for the residue type are zero.
    """

    # [*, N, 14]
    group_mask = GROUP_IDX[aatype, ...]

    # [*, N, 14, 8]
    group_mask = torch.nn.functional.one_hot(
        group_mask,
        num_classes=DEFAULT_FRAMES.shape[-3],
    ).to(r.device)

    # [*, N, 14, 8]
    t_atoms_to_global = r[..., None, :] * group_mask

    # [*, N, 14] - the one-hot mask leaves a single non-zero group per atom,
    # so the sum selects that group's frame.
    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
        lambda x: torch.sum(x, dim=-1)
    )

    # [*, N, 14, 1]
    frame_atom_mask = ATOM_MASK[aatype, ...].unsqueeze(-1).to(r.device)

    # [*, N, 14, 3]
    frame_null_pos = IDEALIZED_POS[aatype, ...].to(r.device)
    pred_positions = t_atoms_to_global.apply(frame_null_pos)
    pred_positions = pred_positions * frame_atom_mask

    return pred_positions


def compute_backbone(bb_rigids, psi_torsions, aatype=None, device=None):
    """Reconstruct backbone atoms (N, CA, C, CB, O) from frames and psi torsions.

    Args:
        bb_rigids: Backbone frames; must expose ``.shape`` and ``.device``.
        psi_torsions: Psi torsion per residue; it is tiled to all seven
            torsion slots.
        aatype: Optional residue types; defaults to all zeros.
        device: Target device; defaults to ``bb_rigids.device``.

    Returns:
        Tuple ``(atom37_bb_pos, atom37_mask, aatype, atom14_pos)``.
    """
    if device is None:
        device = bb_rigids.device

    torsion_angles = torch.tile(
        psi_torsions[..., None, :],
        (1,) * len(bb_rigids.shape) + (7, 1)
    ).to(device)

    # aatype must be on cpu for initializing the tensor by indexing
    if aatype is None:
        # Built from the shape rather than torch.zeros_like(bb_rigids):
        # zeros_like only accepts tensors, while bb_rigids is handed on as a
        # Rigid below.
        aatype = torch.zeros(tuple(bb_rigids.shape), dtype=torch.long)
    else:
        aatype = aatype.cpu()

    all_frames = torsion_angles_to_frames(
        bb_rigids,
        torsion_angles,
        aatype,
    )
    atom14_pos = frames_to_atom14_pos(
        all_frames,
        aatype,
    )
    atom37_bb_pos = torch.zeros(bb_rigids.shape + (37, 3), device=device)
    # atom14 bb order = ['N', 'CA', 'C', 'O', 'CB']
    # atom37 bb order = ['N', 'CA', 'C', 'CB', 'O']
    atom37_bb_pos[..., :3, :] = atom14_pos[..., :3, :]
    atom37_bb_pos[..., 3, :] = atom14_pos[..., 4, :]
    atom37_bb_pos[..., 4, :] = atom14_pos[..., 3, :]
    atom37_mask = torch.any(atom37_bb_pos, axis=-1)
    return atom37_bb_pos, atom37_mask, aatype.to(device), atom14_pos
def calculate_neighbor_angles(R_ac, R_ab):
    """Calculate angles between atoms c <- a -> b.

    Parameters
    ----------
        R_ac: Tensor, shape = (N,3)
            Vector from atom a to c.
        R_ab: Tensor, shape = (N,3)
            Vector from atom a to b.

    Returns
    -------
        angle_cab: Tensor, shape = (N,)
            Angle between atoms c <- a -> b.
    """
    # cos(alpha) = (u * v) / (|u|*|v|)
    x = torch.sum(R_ac * R_ab, dim=-1)  # shape = (N,)
    # sin(alpha) = |u x v| / (|u|*|v|)
    # dim=-1 is explicit: without it torch.cross picks the first axis of
    # size 3, which is the batch axis whenever N == 3.
    y = torch.cross(R_ac, R_ab, dim=-1).norm(dim=-1)  # shape = (N,)
    # avoid that for y == (0,0,0) the gradient wrt. y becomes NaN;
    # clamp works on any device, unlike torch.max with a CPU scalar tensor.
    y = y.clamp(min=1e-9)
    angle = torch.atan2(y, x)
    return angle


def vector_projection(R_ab, P_n):
    """
    Project the vector R_ab onto a plane with normal vector P_n.

    Parameters
    ----------
        R_ab: Tensor, shape = (N,3)
            Vector from atom a to b.
        P_n: Tensor, shape = (N,3)
            Normal vector of a plane onto which to project R_ab.

    Returns
    -------
        R_ab_proj: Tensor, shape = (N,3)
            Projected vector (orthogonal to P_n).
    """
    a_x_b = torch.sum(R_ab * P_n, dim=-1)
    b_x_b = torch.sum(P_n * P_n, dim=-1)
    # unsqueeze(-1) equals [:, None] for (N,3) input and also accepts
    # additional leading batch dimensions.
    return R_ab - (a_x_b / b_x_b).unsqueeze(-1) * P_n
+# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +from functools import reduce, wraps +from operator import add + +import numpy as np +import torch + +from src.common import residue_constants as rc +from src.common.rigid_utils import Rotation, Rigid +from src.utils.tensor_utils import ( + tree_map, + tensor_tree_map, + batched_gather, +) + +NUM_RES = "num residues placeholder" +NUM_MSA_SEQ = "msa placeholder" +NUM_EXTRA_SEQ = "extra msa placeholder" +NUM_TEMPLATES = "num templates placeholder" + +MSA_FEATURE_NAMES = [ + "msa", + "deletion_matrix", + "msa_mask", + "msa_row_mask", + "bert_mask", + "true_msa", +] + + +def cast_to_64bit_ints(protein): + # We keep all ints as int64 + for k, v in protein.items(): + if v.dtype == torch.int32: + protein[k] = v.type(torch.int64) + + return protein + + +def make_one_hot(x, num_classes): + x_one_hot = torch.zeros(*x.shape, num_classes) + x_one_hot.scatter_(-1, x.unsqueeze(-1), 1) + return x_one_hot + + +def make_seq_mask(protein): + protein["seq_mask"] = torch.ones( + protein["aatype"].shape, dtype=torch.float32 + ) + return protein + + +def make_template_mask(protein): + protein["template_mask"] = torch.ones( + protein["template_aatype"].shape[0], dtype=torch.float32 + ) + return protein + + +def curry1(f): + """Supply all arguments but the first.""" + @wraps(f) + def fc(*args, **kwargs): + return lambda x: f(x, *args, **kwargs) + + return fc + + +def make_all_atom_aatype(protein): + protein["all_atom_aatype"] = protein["aatype"] + return protein + + +def fix_templates_aatype(protein): + # Map one-hot to indices + num_templates = protein["template_aatype"].shape[0] + if(num_templates > 0): + protein["template_aatype"] = torch.argmax( + protein["template_aatype"], dim=-1 + ) + # Map hhsearch-aatype to our aatype. 
+ new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = torch.tensor(new_order_list, dtype=torch.int64).expand( + num_templates, -1 + ) + protein["template_aatype"] = torch.gather( + new_order, 1, index=protein["template_aatype"] + ) + + return protein + + +def correct_msa_restypes(protein): + """Correct MSA restype to have the same order as rc.""" + new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = torch.tensor( + [new_order_list] * protein["msa"].shape[1], dtype=protein["msa"].dtype + ).transpose(0, 1) + protein["msa"] = torch.gather(new_order, 0, protein["msa"]) + + perm_matrix = np.zeros((22, 22), dtype=np.float32) + perm_matrix[range(len(new_order_list)), new_order_list] = 1.0 + + for k in protein: + if "profile" in k: + num_dim = protein[k].shape.as_list()[-1] + assert num_dim in [ + 20, + 21, + 22, + ], "num_dim for %s out of expected range: %s" % (k, num_dim) + protein[k] = torch.dot(protein[k], perm_matrix[:num_dim, :num_dim]) + + return protein + + +def squeeze_features(protein): + """Remove singleton and repeated dimensions in protein features.""" + protein["aatype"] = torch.argmax(protein["aatype"], dim=-1) + for k in [ + "domain_name", + "msa", + "num_alignments", + "seq_length", + "sequence", + "superfamily", + "deletion_matrix", + "resolution", + "between_segment_residues", + "residue_index", + "template_all_atom_mask", + ]: + if k in protein: + final_dim = protein[k].shape[-1] + if isinstance(final_dim, int) and final_dim == 1: + if torch.is_tensor(protein[k]): + protein[k] = torch.squeeze(protein[k], dim=-1) + else: + protein[k] = np.squeeze(protein[k], axis=-1) + + for k in ["seq_length", "num_alignments"]: + if k in protein: + protein[k] = protein[k][0] + + return protein + + +@curry1 +def randomly_replace_msa_with_unknown(protein, replace_proportion): + """Replace a portion of the MSA with 'X'.""" + msa_mask = torch.rand(protein["msa"].shape) < replace_proportion + x_idx = 20 + gap_idx = 21 + msa_mask = 
torch.logical_and(msa_mask, protein["msa"] != gap_idx) + protein["msa"] = torch.where( + msa_mask, + torch.ones_like(protein["msa"]) * x_idx, + protein["msa"] + ) + aatype_mask = torch.rand(protein["aatype"].shape) < replace_proportion + + protein["aatype"] = torch.where( + aatype_mask, + torch.ones_like(protein["aatype"]) * x_idx, + protein["aatype"], + ) + return protein + + +@curry1 +def sample_msa(protein, max_seq, keep_extra, seed=None): + """Sample MSA randomly, remaining sequences are stored are stored as `extra_*`.""" + num_seq = protein["msa"].shape[0] + g = torch.Generator(device=protein["msa"].device) + if seed is not None: + g.manual_seed(seed) + shuffled = torch.randperm(num_seq - 1, generator=g) + 1 + index_order = torch.cat((torch.tensor([0]), shuffled), dim=0) + num_sel = min(max_seq, num_seq) + sel_seq, not_sel_seq = torch.split( + index_order, [num_sel, num_seq - num_sel] + ) + + for k in MSA_FEATURE_NAMES: + if k in protein: + if keep_extra: + protein["extra_" + k] = torch.index_select( + protein[k], 0, not_sel_seq + ) + protein[k] = torch.index_select(protein[k], 0, sel_seq) + + return protein + + +@curry1 +def add_distillation_flag(protein, distillation): + protein['is_distillation'] = distillation + return protein + +@curry1 +def sample_msa_distillation(protein, max_seq): + if(protein["is_distillation"] == 1): + protein = sample_msa(max_seq, keep_extra=False)(protein) + return protein + + +@curry1 +def crop_extra_msa(protein, max_extra_msa): + num_seq = protein["extra_msa"].shape[0] + num_sel = min(max_extra_msa, num_seq) + select_indices = torch.randperm(num_seq)[:num_sel] + for k in MSA_FEATURE_NAMES: + if "extra_" + k in protein: + protein["extra_" + k] = torch.index_select( + protein["extra_" + k], 0, select_indices + ) + + return protein + + +def delete_extra_msa(protein): + for k in MSA_FEATURE_NAMES: + if "extra_" + k in protein: + del protein["extra_" + k] + return protein + + +# Not used in inference +@curry1 +def 
block_delete_msa(protein, config): + num_seq = protein["msa"].shape[0] + block_num_seq = torch.floor( + torch.tensor(num_seq, dtype=torch.float32) + * config.msa_fraction_per_block + ).to(torch.int32) + + if config.randomize_num_blocks: + nb = torch.distributions.uniform.Uniform( + 0, config.num_blocks + 1 + ).sample() + else: + nb = config.num_blocks + + del_block_starts = torch.distributions.Uniform(0, num_seq).sample(nb) + del_blocks = del_block_starts[:, None] + torch.range(block_num_seq) + del_blocks = torch.clip(del_blocks, 0, num_seq - 1) + del_indices = torch.unique(torch.sort(torch.reshape(del_blocks, [-1])))[0] + + # Make sure we keep the original sequence + combined = torch.cat((torch.range(1, num_seq)[None], del_indices[None])) + uniques, counts = combined.unique(return_counts=True) + difference = uniques[counts == 1] + intersection = uniques[counts > 1] + keep_indices = torch.squeeze(difference, 0) + + for k in MSA_FEATURE_NAMES: + if k in protein: + protein[k] = torch.gather(protein[k], keep_indices) + + return protein + + +@curry1 +def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0): + weights = torch.cat( + [torch.ones(21), gap_agreement_weight * torch.ones(1), torch.zeros(1)], + 0, + ) + + # Make agreement score as weighted Hamming distance + msa_one_hot = make_one_hot(protein["msa"], 23) + sample_one_hot = protein["msa_mask"][:, :, None] * msa_one_hot + extra_msa_one_hot = make_one_hot(protein["extra_msa"], 23) + extra_one_hot = protein["extra_msa_mask"][:, :, None] * extra_msa_one_hot + + num_seq, num_res, _ = sample_one_hot.shape + extra_num_seq, _, _ = extra_one_hot.shape + + # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights) + # in an optimized fashion to avoid possible memory or computation blowup. 
+ agreement = torch.matmul( + torch.reshape(extra_one_hot, [extra_num_seq, num_res * 23]), + torch.reshape( + sample_one_hot * weights, [num_seq, num_res * 23] + ).transpose(0, 1), + ) + + # Assign each sequence in the extra sequences to the closest MSA sample + protein["extra_cluster_assignment"] = torch.argmax(agreement, dim=1).to( + torch.int64 + ) + + return protein + + +def unsorted_segment_sum(data, segment_ids, num_segments): + """ + Computes the sum along segments of a tensor. Similar to + tf.unsorted_segment_sum, but only supports 1-D indices. + + :param data: A tensor whose segments are to be summed. + :param segment_ids: The 1-D segment indices tensor. + :param num_segments: The number of segments. + :return: A tensor of same data type as the data argument. + """ + assert ( + len(segment_ids.shape) == 1 and + segment_ids.shape[0] == data.shape[0] + ) + segment_ids = segment_ids.view( + segment_ids.shape[0], *((1,) * len(data.shape[1:])) + ) + segment_ids = segment_ids.expand(data.shape) + shape = [num_segments] + list(data.shape[1:]) + tensor = torch.zeros(*shape).scatter_add_(0, segment_ids, data.float()) + tensor = tensor.type(data.dtype) + return tensor + + +@curry1 +def summarize_clusters(protein): + """Produce profile and deletion_matrix_mean within each cluster.""" + num_seq = protein["msa"].shape[0] + + def csum(x): + return unsorted_segment_sum( + x, protein["extra_cluster_assignment"], num_seq + ) + + mask = protein["extra_msa_mask"] + mask_counts = 1e-6 + protein["msa_mask"] + csum(mask) # Include center + + msa_sum = csum(mask[:, :, None] * make_one_hot(protein["extra_msa"], 23)) + msa_sum += make_one_hot(protein["msa"], 23) # Original sequence + protein["cluster_profile"] = msa_sum / mask_counts[:, :, None] + del msa_sum + + del_sum = csum(mask * protein["extra_deletion_matrix"]) + del_sum += protein["deletion_matrix"] # Original sequence + protein["cluster_deletion_mean"] = del_sum / mask_counts + del del_sum + + return protein + + +def 
make_msa_mask(protein): + """Mask features are all ones, but will later be zero-padded.""" + protein["msa_mask"] = torch.ones(protein["msa"].shape, dtype=torch.float32) + protein["msa_row_mask"] = torch.ones( + (protein["msa"].shape[0]), dtype=torch.float32 + ) + return protein + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask): + """Create pseudo beta features.""" + is_gly = torch.eq(aatype, rc.restype_order["G"]) + ca_idx = rc.atom_order["CA"] + cb_idx = rc.atom_order["CB"] + pseudo_beta = torch.where( + torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :], + ) + + if all_atom_mask is not None: + pseudo_beta_mask = torch.where( + is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx] + ) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + + +@curry1 +def make_pseudo_beta(protein, prefix=""): + """Create pseudo-beta (alpha for glycine) position and mask.""" + assert prefix in ["", "template_"] + ( + protein[prefix + "pseudo_beta"], + protein[prefix + "pseudo_beta_mask"], + ) = pseudo_beta_fn( + protein["template_aatype" if prefix else "aatype"], + protein[prefix + "all_atom_positions"], + protein["template_all_atom_mask" if prefix else "all_atom_mask"], + ) + return protein + + +@curry1 +def add_constant_field(protein, key, value): + protein[key] = torch.tensor(value) + return protein + + +def shaped_categorical(probs, epsilon=1e-10): + ds = probs.shape + num_classes = ds[-1] + distribution = torch.distributions.categorical.Categorical( + torch.reshape(probs + epsilon, [-1, num_classes]) + ) + counts = distribution.sample() + return torch.reshape(counts, ds[:-1]) + + +def make_hhblits_profile(protein): + """Compute the HHblits MSA profile if not already present.""" + if "hhblits_profile" in protein: + return protein + + # Compute the profile for every residue (over all MSA sequences). 
+ msa_one_hot = make_one_hot(protein["msa"], 22) + + protein["hhblits_profile"] = torch.mean(msa_one_hot, dim=0) + return protein + + +@curry1 +def make_masked_msa(protein, config, replace_fraction): + """Create data for BERT on raw MSA.""" + # Add a random amino acid uniformly. + random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32) + + categorical_probs = ( + config.uniform_prob * random_aa + + config.profile_prob * protein["hhblits_profile"] + + config.same_prob * make_one_hot(protein["msa"], 22) + ) + + # Put all remaining probability on [MASK] which is a new column + pad_shapes = list( + reduce(add, [(0, 0) for _ in range(len(categorical_probs.shape))]) + ) + pad_shapes[1] = 1 + mask_prob = ( + 1.0 - config.profile_prob - config.same_prob - config.uniform_prob + ) + assert mask_prob >= 0.0 + categorical_probs = torch.nn.functional.pad( + categorical_probs, pad_shapes, value=mask_prob + ) + + sh = protein["msa"].shape + mask_position = torch.rand(sh) < replace_fraction + + bert_msa = shaped_categorical(categorical_probs) + bert_msa = torch.where(mask_position, bert_msa, protein["msa"]) + + # Mix real and masked MSA + protein["bert_mask"] = mask_position.to(torch.float32) + protein["true_msa"] = protein["msa"] + protein["msa"] = bert_msa + + return protein + + +@curry1 +def make_fixed_size( + protein, + shape_schema, + msa_cluster_size, + extra_msa_size, + num_res=0, + num_templates=0, +): + """Guess at the MSA and sequence dimension to make fixed size.""" + pad_size_map = { + NUM_RES: num_res, + NUM_MSA_SEQ: msa_cluster_size, + NUM_EXTRA_SEQ: extra_msa_size, + NUM_TEMPLATES: num_templates, + } + + for k, v in protein.items(): + # Don't transfer this to the accelerator. 
+ if k == "extra_cluster_assignment": + continue + shape = list(v.shape) + schema = shape_schema[k] + msg = "Rank mismatch between shape and shape schema for" + assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}" + pad_size = [ + pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema) + ] + + padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)] + padding.reverse() + padding = list(itertools.chain(*padding)) + if padding: + protein[k] = torch.nn.functional.pad(v, padding) + protein[k] = torch.reshape(protein[k], pad_size) + + return protein + + +@curry1 +def make_msa_feat(protein): + """Create and concatenate MSA features.""" + # Whether there is a domain break. Always zero for chains, but keeping for + # compatibility with domain datasets. + has_break = torch.clip( + protein["between_segment_residues"].to(torch.float32), 0, 1 + ) + aatype_1hot = make_one_hot(protein["aatype"], 21) + + target_feat = [ + torch.unsqueeze(has_break, dim=-1), + aatype_1hot, # Everyone gets the original sequence. 
+ ] + + msa_1hot = make_one_hot(protein["msa"], 23) + has_deletion = torch.clip(protein["deletion_matrix"], 0.0, 1.0) + deletion_value = torch.atan(protein["deletion_matrix"] / 3.0) * ( + 2.0 / np.pi + ) + + msa_feat = [ + msa_1hot, + torch.unsqueeze(has_deletion, dim=-1), + torch.unsqueeze(deletion_value, dim=-1), + ] + + if "cluster_profile" in protein: + deletion_mean_value = torch.atan( + protein["cluster_deletion_mean"] / 3.0 + ) * (2.0 / np.pi) + msa_feat.extend( + [ + protein["cluster_profile"], + torch.unsqueeze(deletion_mean_value, dim=-1), + ] + ) + + if "extra_deletion_matrix" in protein: + protein["extra_has_deletion"] = torch.clip( + protein["extra_deletion_matrix"], 0.0, 1.0 + ) + protein["extra_deletion_value"] = torch.atan( + protein["extra_deletion_matrix"] / 3.0 + ) * (2.0 / np.pi) + + protein["msa_feat"] = torch.cat(msa_feat, dim=-1) + protein["target_feat"] = torch.cat(target_feat, dim=-1) + return protein + + +@curry1 +def select_feat(protein, feature_list): + return {k: v for k, v in protein.items() if k in feature_list} + + +@curry1 +def crop_templates(protein, max_templates): + for k, v in protein.items(): + if k.startswith("template_"): + protein[k] = v[:max_templates] + return protein + + +def make_atom14_masks(protein): + """Construct denser atom positions (14 dimensions instead of 37).""" + restype_atom14_to_atom37 = [] + restype_atom37_to_atom14 = [] + restype_atom14_mask = [] + + for rt in rc.restypes: + atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]] + restype_atom14_to_atom37.append( + [(rc.atom_order[name] if name else 0) for name in atom_names] + ) + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append( + [ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in rc.atom_types + ] + ) + + restype_atom14_mask.append( + [(1.0 if name else 0.0) for name in atom_names] + ) + + # Add dummy mapping for restype 'UNK' + 
restype_atom14_to_atom37.append([0] * 14) + restype_atom37_to_atom14.append([0] * 37) + restype_atom14_mask.append([0.0] * 14) + + restype_atom14_to_atom37 = torch.tensor( + restype_atom14_to_atom37, + dtype=torch.int32, + device=protein["aatype"].device, + ) + restype_atom37_to_atom14 = torch.tensor( + restype_atom37_to_atom14, + dtype=torch.int32, + device=protein["aatype"].device, + ) + restype_atom14_mask = torch.tensor( + restype_atom14_mask, + dtype=torch.float32, + device=protein["aatype"].device, + ) + protein_aatype = protein['aatype'].to(torch.long) + + # create the mapping for (residx, atom14) --> atom37, i.e. an array + # with shape (num_res, 14) containing the atom37 indices for this protein + residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype] + residx_atom14_mask = restype_atom14_mask[protein_aatype] + + protein["atom14_atom_exists"] = residx_atom14_mask + protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long() + + # create the gather indices for mapping back + residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype] + protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long() + + # create the corresponding mask + restype_atom37_mask = torch.zeros( + [21, 37], dtype=torch.float32, device=protein["aatype"].device + ) + for restype, restype_letter in enumerate(rc.restypes): + restype_name = rc.restype_1to3[restype_letter] + atom_names = rc.residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = rc.atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + + residx_atom37_mask = restype_atom37_mask[protein_aatype] + protein["atom37_atom_exists"] = residx_atom37_mask + + return protein + + +def make_atom14_masks_np(batch): + batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray) + out = make_atom14_masks(batch) + out = tensor_tree_map(lambda t: np.array(t), out) + return out + + +def make_atom14_positions(protein): + """Constructs denser atom positions (14 dimensions 
instead of 37).""" + residx_atom14_mask = protein["atom14_atom_exists"] + residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"] + + # Create a mask for known ground truth positions. + residx_atom14_gt_mask = residx_atom14_mask * batched_gather( + protein["all_atom_mask"], + residx_atom14_to_atom37, + dim=-1, + no_batch_dims=len(protein["all_atom_mask"].shape[:-1]), + ) + + # Gather the ground truth positions. + residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * ( + batched_gather( + protein["all_atom_positions"], + residx_atom14_to_atom37, + dim=-2, + no_batch_dims=len(protein["all_atom_positions"].shape[:-2]), + ) + ) + + protein["atom14_atom_exists"] = residx_atom14_mask + protein["atom14_gt_exists"] = residx_atom14_gt_mask + protein["atom14_gt_positions"] = residx_atom14_gt_positions + + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative ground truth coordinates where the naming is swapped + restype_3 = [rc.restype_1to3[res] for res in rc.restypes] + restype_3 += ["UNK"] + + # Matrices for renaming ambiguous atoms. 
+ all_matrices = { + res: torch.eye( + 14, + dtype=protein["all_atom_mask"].dtype, + device=protein["all_atom_mask"].device, + ) + for res in restype_3 + } + for resname, swap in rc.residue_atom_renaming_swaps.items(): + correspondences = torch.arange( + 14, device=protein["all_atom_mask"].device + ) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = rc.restype_name_to_atom14_names[resname].index( + source_atom_swap + ) + target_index = rc.restype_name_to_atom14_names[resname].index( + target_atom_swap + ) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14)) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1.0 + all_matrices[resname] = renaming_matrix + renaming_matrices = torch.stack( + [all_matrices[restype] for restype in restype_3] + ) + + # Pick the transformation matrices for the given residue sequence + # shape (num_res, 14, 14). + renaming_transform = renaming_matrices[protein["aatype"]] + + # Apply it to the ground truth positions. shape (num_res, 14, 3). + alternative_gt_positions = torch.einsum( + "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform + ) + protein["atom14_alt_gt_positions"] = alternative_gt_positions + + # Create the mask for the alternative ground truth (differs from the + # ground truth mask, if only one of the atoms in an ambiguous pair has a + # ground truth position). + alternative_gt_mask = torch.einsum( + "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform + ) + protein["atom14_alt_gt_exists"] = alternative_gt_mask + + # Create an ambiguous atoms mask. shape: (21, 14). 
+ restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14)) + for resname, swap in rc.residue_atom_renaming_swaps.items(): + for atom_name1, atom_name2 in swap.items(): + restype = rc.restype_order[rc.restype_3to1[resname]] + atom_idx1 = rc.restype_name_to_atom14_names[resname].index( + atom_name1 + ) + atom_idx2 = rc.restype_name_to_atom14_names[resname].index( + atom_name2 + ) + restype_atom14_is_ambiguous[restype, atom_idx1] = 1 + restype_atom14_is_ambiguous[restype, atom_idx2] = 1 + + # From this create an ambiguous_mask for the given sequence. + protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[ + protein["aatype"] + ] + + return protein + + +def atom37_to_frames(protein, eps=1e-8): + aatype = protein["aatype"] + all_atom_positions = protein["all_atom_positions"] + all_atom_mask = protein["all_atom_mask"] + + batch_dims = len(aatype.shape[:-1]) + + restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object) + restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"] + restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"] + + for restype, restype_letter in enumerate(rc.restypes): + resname = rc.restype_1to3[restype_letter] + for chi_idx in range(4): + if rc.chi_angles_mask[restype][chi_idx]: + names = rc.chi_angles_atoms[resname][chi_idx] + restype_rigidgroup_base_atom_names[ + restype, chi_idx + 4, : + ] = names[1:] + + restype_rigidgroup_mask = all_atom_mask.new_zeros( + (*aatype.shape[:-1], 21, 8), + ) + restype_rigidgroup_mask[..., 0] = 1 + restype_rigidgroup_mask[..., 3] = 1 + restype_rigidgroup_mask[..., :20, 4:] = all_atom_mask.new_tensor( + rc.chi_angles_mask + ) + + lookuptable = rc.atom_order.copy() + lookuptable[""] = 0 + lookup = np.vectorize(lambda x: lookuptable[x]) + restype_rigidgroup_base_atom37_idx = lookup( + restype_rigidgroup_base_atom_names, + ) + restype_rigidgroup_base_atom37_idx = aatype.new_tensor( + restype_rigidgroup_base_atom37_idx, + ) + 
restype_rigidgroup_base_atom37_idx = ( + restype_rigidgroup_base_atom37_idx.view( + *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape + ) + ) + + residx_rigidgroup_base_atom37_idx = batched_gather( + restype_rigidgroup_base_atom37_idx, + aatype, + dim=-3, + no_batch_dims=batch_dims, + ) + + base_atom_pos = batched_gather( + all_atom_positions, + residx_rigidgroup_base_atom37_idx, + dim=-2, + no_batch_dims=len(all_atom_positions.shape[:-2]), + ) + + gt_frames = Rigid.from_3_points( + p_neg_x_axis=base_atom_pos[..., 0, :], + origin=base_atom_pos[..., 1, :], + p_xy_plane=base_atom_pos[..., 2, :], + eps=eps, + ) + + group_exists = batched_gather( + restype_rigidgroup_mask, + aatype, + dim=-2, + no_batch_dims=batch_dims, + ) + + gt_atoms_exist = batched_gather( + all_atom_mask, + residx_rigidgroup_base_atom37_idx, + dim=-1, + no_batch_dims=len(all_atom_mask.shape[:-1]), + ) + gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists + + rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device) + rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1)) + rots[..., 0, 0, 0] = -1 + rots[..., 0, 2, 2] = -1 + rots = Rotation(rot_mats=rots) + + gt_frames = gt_frames.compose(Rigid(rots, None)) + + restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros( + *((1,) * batch_dims), 21, 8 + ) + restype_rigidgroup_rots = torch.eye( + 3, dtype=all_atom_mask.dtype, device=aatype.device + ) + restype_rigidgroup_rots = torch.tile( + restype_rigidgroup_rots, + (*((1,) * batch_dims), 21, 8, 1, 1), + ) + + for resname, _ in rc.residue_atom_renaming_swaps.items(): + restype = rc.restype_order[rc.restype_3to1[resname]] + chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1) + restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1 + restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1 + restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1 + + residx_rigidgroup_is_ambiguous = batched_gather( + restype_rigidgroup_is_ambiguous, + aatype, + 
dim=-2, + no_batch_dims=batch_dims, + ) + + residx_rigidgroup_ambiguity_rot = batched_gather( + restype_rigidgroup_rots, + aatype, + dim=-4, + no_batch_dims=batch_dims, + ) + + residx_rigidgroup_ambiguity_rot = Rotation( + rot_mats=residx_rigidgroup_ambiguity_rot + ) + alt_gt_frames = gt_frames.compose( + Rigid(residx_rigidgroup_ambiguity_rot, None) + ) + + gt_frames_tensor = gt_frames.to_tensor_4x4() + alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4() + + protein["rigidgroups_gt_frames"] = gt_frames_tensor + protein["rigidgroups_gt_exists"] = gt_exists + protein["rigidgroups_group_exists"] = group_exists + protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous + protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor + + return protein + + +def get_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in rc.restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in rc.restypes: + residue_name = rc.restype_1to3[residue_name] + residue_chi_angles = rc.chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append([rc.atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append( + [0, 0, 0, 0] + ) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return chi_atom_indices + + +@curry1 +def atom37_to_torsion_angles( + protein, + prefix="", +): + """ + Convert coordinates to torsion angles. + + This function is extremely sensitive to floating point imprecisions + and should be run with double precision whenever possible. 
+ + Args: + Dict containing: + * (prefix)aatype: + [*, N_res] residue indices + * (prefix)all_atom_positions: + [*, N_res, 37, 3] atom positions (in atom37 + format) + * (prefix)all_atom_mask: + [*, N_res, 37] atom position mask + Returns: + The same dictionary updated with the following features: + + "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2]) + Torsion angles + "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2]) + Alternate torsion angles (accounting for 180-degree symmetry) + "(prefix)torsion_angles_mask" ([*, N_res, 7]) + Torsion angles mask + """ + aatype = protein[prefix + "aatype"] + all_atom_positions = protein[prefix + "all_atom_positions"] + all_atom_mask = protein[prefix + "all_atom_mask"] + + aatype = torch.clamp(aatype, max=20) + + pad = all_atom_positions.new_zeros( + [*all_atom_positions.shape[:-3], 1, 37, 3] + ) + prev_all_atom_positions = torch.cat( + [pad, all_atom_positions[..., :-1, :, :]], dim=-3 + ) + + pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37]) + prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2) + + pre_omega_atom_pos = torch.cat( + [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]], + dim=-2, + ) + phi_atom_pos = torch.cat( + [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]], + dim=-2, + ) + psi_atom_pos = torch.cat( + [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]], + dim=-2, + ) + + pre_omega_mask = torch.prod( + prev_all_atom_mask[..., 1:3], dim=-1 + ) * torch.prod(all_atom_mask[..., :2], dim=-1) + phi_mask = prev_all_atom_mask[..., 2] * torch.prod( + all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype + ) + psi_mask = ( + torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) + * all_atom_mask[..., 4] + ) + + chi_atom_indices = torch.as_tensor( + get_chi_atom_indices(), device=aatype.device + ) + + atom_indices = chi_atom_indices[..., aatype, :, :] + chis_atom_pos = batched_gather( + 
all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2]) + ) + + chi_angles_mask = list(rc.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask) + + chis_mask = chi_angles_mask[aatype, :] + + chi_angle_atoms_mask = batched_gather( + all_atom_mask, + atom_indices, + dim=-1, + no_batch_dims=len(atom_indices.shape[:-2]), + ) + chi_angle_atoms_mask = torch.prod( + chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype + ) + chis_mask = chis_mask * chi_angle_atoms_mask + + torsions_atom_pos = torch.cat( + [ + pre_omega_atom_pos[..., None, :, :], + phi_atom_pos[..., None, :, :], + psi_atom_pos[..., None, :, :], + chis_atom_pos, + ], + dim=-3, + ) + + torsion_angles_mask = torch.cat( + [ + pre_omega_mask[..., None], + phi_mask[..., None], + psi_mask[..., None], + chis_mask, + ], + dim=-1, + ) + + torsion_frames = Rigid.from_3_points( + torsions_atom_pos[..., 1, :], + torsions_atom_pos[..., 2, :], + torsions_atom_pos[..., 0, :], + eps=1e-8, + ) + + fourth_atom_rel_pos = torsion_frames.invert().apply( + torsions_atom_pos[..., 3, :] + ) + + torsion_angles_sin_cos = torch.stack( + [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1 + ) + + denom = torch.sqrt( + torch.sum( + torch.square(torsion_angles_sin_cos), + dim=-1, + dtype=torsion_angles_sin_cos.dtype, + keepdims=True, + ) + + 1e-8 + ) + torsion_angles_sin_cos = torsion_angles_sin_cos / denom + + torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor( + [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0], + )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)] + + chi_is_ambiguous = torsion_angles_sin_cos.new_tensor( + rc.chi_pi_periodic, + )[aatype, ...] 
+ + mirror_torsion_angles = torch.cat( + [ + all_atom_mask.new_ones(*aatype.shape, 3), + 1.0 - 2.0 * chi_is_ambiguous, + ], + dim=-1, + ) + + alt_torsion_angles_sin_cos = ( + torsion_angles_sin_cos * mirror_torsion_angles[..., None] + ) + + protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos + protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos + protein[prefix + "torsion_angles_mask"] = torsion_angles_mask + + return protein + + +def get_backbone_frames(protein): + # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why. + protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][ + ..., 0, :, : + ] + protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0] + + return protein + + +def get_chi_angles(protein): + dtype = protein["all_atom_mask"].dtype + protein["chi_angles_sin_cos"] = ( + protein["torsion_angles_sin_cos"][..., 3:, :] + ).to(dtype) + protein["chi_mask"] = protein["torsion_angles_mask"][..., 3:].to(dtype) + + return protein + + +@curry1 +def random_crop_to_size( + protein, + crop_size, + max_templates, + shape_schema, + subsample_templates=False, + seed=None, +): + """Crop randomly to `crop_size`, or keep as is if shorter than that.""" + # We want each ensemble to be cropped the same way + g = torch.Generator(device=protein["seq_length"].device) + if seed is not None: + g.manual_seed(seed) + + seq_length = protein["seq_length"] + + if "template_mask" in protein: + num_templates = protein["template_mask"].shape[-1] + else: + num_templates = 0 + + # No need to subsample templates if there aren't any + subsample_templates = subsample_templates and num_templates + + num_res_crop_size = min(int(seq_length), crop_size) + + def _randint(lower, upper): + return int(torch.randint( + lower, + upper + 1, + (1,), + device=protein["seq_length"].device, + generator=g, + )[0]) + + if subsample_templates: + templates_crop_start = _randint(0, num_templates) + templates_select_indices = 
torch.randperm( + num_templates, device=protein["seq_length"].device, generator=g + ) + else: + templates_crop_start = 0 + + num_templates_crop_size = min( + num_templates - templates_crop_start, max_templates + ) + + n = seq_length - num_res_crop_size + if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.: + right_anchor = n + else: + x = _randint(0, n) + right_anchor = n - x + + num_res_crop_start = _randint(0, right_anchor) + + for k, v in protein.items(): + if k not in shape_schema or ( + "template" not in k and NUM_RES not in shape_schema[k] + ): + continue + + # randomly permute the templates before cropping them. + if k.startswith("template") and subsample_templates: + v = v[templates_select_indices] + + slices = [] + for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)): + is_num_res = dim_size == NUM_RES + if i == 0 and k.startswith("template"): + crop_size = num_templates_crop_size + crop_start = templates_crop_start + else: + crop_start = num_res_crop_start if is_num_res else 0 + crop_size = num_res_crop_size if is_num_res else dim + slices.append(slice(crop_start, crop_start + crop_size)) + protein[k] = v[slices] + + protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size) + + return protein diff --git a/analysis/src/common/geo_utils.py b/analysis/src/common/geo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..62949d45b56a9c9716a5a9cc338f85f2a6d1cdc7 --- /dev/null +++ b/analysis/src/common/geo_utils.py @@ -0,0 +1,155 @@ +""" +Utility functions for geometric operations (torch only). 
+""" +import torch + + +def rots_mul_vecs(m, v): + """(Batch) Apply rotations 'm' to vectors 'v'.""" + return torch.stack([ + m[..., 0, 0] * v[..., 0] + m[..., 0, 1] * v[..., 1] + m[..., 0, 2] * v[..., 2], + m[..., 1, 0] * v[..., 0] + m[..., 1, 1] * v[..., 1] + m[..., 1, 2] * v[..., 2], + m[..., 2, 0] * v[..., 0] + m[..., 2, 1] * v[..., 1] + m[..., 2, 2] * v[..., 2], + ], dim=-1) + +def distance(p, eps=1e-10): + """Calculate distance between a pair of points (dim=-2).""" + # [*, 2, 3] + return (eps + torch.sum((p[..., 0, :] - p[..., 1, :]) ** 2, dim=-1)) ** 0.5 + +def dihedral(p, eps=1e-10): + """Calculate dihedral angle between a quadruple of points (dim=-2).""" + # p: [*, 4, 3] + + # [*, 3] + u1 = p[..., 1, :] - p[..., 0, :] + u2 = p[..., 2, :] - p[..., 1, :] + u3 = p[..., 3, :] - p[..., 2, :] + + # [*, 3] + u1xu2 = torch.cross(u1, u2, dim=-1) + u2xu3 = torch.cross(u2, u3, dim=-1) + + # [*] + u2_norm = (eps + torch.sum(u2 ** 2, dim=-1)) ** 0.5 + u1xu2_norm = (eps + torch.sum(u1xu2 ** 2, dim=-1)) ** 0.5 + u2xu3_norm = (eps + torch.sum(u2xu3 ** 2, dim=-1)) ** 0.5 + + # [*] + cos_enc = torch.einsum('...d,...d->...', u1xu2, u2xu3)/ (u1xu2_norm * u2xu3_norm) + sin_enc = torch.einsum('...d,...d->...', u2, torch.cross(u1xu2, u2xu3, dim=-1)) / (u2_norm * u1xu2_norm * u2xu3_norm) + + return torch.stack([cos_enc, sin_enc], dim=-1) + +def calc_distogram(pos: torch.Tensor, min_bin: float, max_bin: float, num_bins: int): + # pos: [*, L, 3] + dists_2d = torch.linalg.norm( + pos[..., :, None, :] - pos[..., None, :, :], axis=-1 + )[..., None] + lower = torch.linspace( + min_bin, + max_bin, + num_bins, + device=pos.device) + upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1) + distogram = ((dists_2d > lower) * (dists_2d < upper)).type(pos.dtype) + return distogram + +def rmsd(xyz1, xyz2): + """ Abbreviation for squared_deviation(xyz1, xyz2, 'rmsd') """ + return squared_deviation(xyz1, xyz2, 'rmsd') + +def squared_deviation(xyz1, xyz2, reduction='none'): + """Squared 
point-wise deviation between two point clouds after alignment. + + Args: + xyz1: (*, L, 3), to be transformed + xyz2: (*, L, 3), the reference + + Returns: + rmsd: (*, ) or none: (*, L) + """ + map_to_np = False + if not torch.is_tensor(xyz1): + map_to_np = True + xyz1 = torch.as_tensor(xyz1) + xyz2 = torch.as_tensor(xyz2) + + R, t = _find_rigid_alignment(xyz1, xyz2) + + # print(R.shape, t.shape) # B, 3, 3 & B, 3 + + # xyz1_aligned = (R.bmm(xyz1.transpose(-2,-1))).transpose(-2,-1) + t.unsqueeze(1) + xyz1_aligned = (torch.matmul(R, xyz1.transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0) + + sd = ((xyz1_aligned - xyz2)**2).sum(dim=-1) # (*, L) + + assert sd.shape == xyz1.shape[:-1] + if reduction == 'none': + pass + elif reduction == 'rmsd': + sd = torch.sqrt(sd.mean(dim=-1)) + else: + raise NotImplementedError() + + sd = sd.numpy() if map_to_np else sd + return sd + +def _find_rigid_alignment(src, tgt): + """Inspired by https://research.pasteur.fr/en/member/guillaume-bouvier/; + https://gist.github.com/bougui505/e392a371f5bab095a3673ea6f4976cc8 + + See: https://en.wikipedia.org/wiki/Kabsch_algorithm + + 2-D or 3-D registration with known correspondences. + Registration occurs in the zero centered coordinate system, and then + must be transported back. 
+ + Args: + src: Torch tensor of shape (*, L, 3) -- Point Cloud to Align (source) + tgt: Torch tensor of shape (*, L, 3) -- Reference Point Cloud (target) + Returns: + R: optimal rotation (*, 3, 3) + t: optimal translation (*, 3) + + Test on rotation + translation and on rotation + translation + reflection + >>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]], dtype=torch.float) + >>> R0 = torch.tensor([[np.cos(60), -np.sin(60)], [np.sin(60), np.cos(60)]], dtype=torch.float) + >>> B = (R0.mm(A.T)).T + >>> t0 = torch.tensor([3., 3.]) + >>> B += t0 + >>> R, t = find_rigid_alignment(A, B) + >>> A_aligned = (R.mm(A.T)).T + t + >>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean()) + >>> rmsd + tensor(3.7064e-07) + >>> B *= torch.tensor([-1., 1.]) + >>> R, t = find_rigid_alignment(A, B) + >>> A_aligned = (R.mm(A.T)).T + t + >>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean()) + >>> rmsd + tensor(3.7064e-07) + """ + assert src.shape[-2] > 1 + src_com = src.mean(dim=-2, keepdim=True) + tgt_com = tgt.mean(dim=-2, keepdim=True) + src_centered = src - src_com + tgt_centered = tgt - tgt_com + + # Covariance matrix + + # H = src_centered.transpose(-2,-1).bmm(tgt_centered) # *, 3, 3 + H = torch.matmul(src_centered.transpose(-2,-1), tgt_centered) + + U, S, V = torch.svd(H) + # Rotation matrix + + # R = V.bmm(U.transpose(-2,-1)) + R = torch.matmul(V, U.transpose(-2, -1)) + + # Translation vector + + # t = tgt_com - R.bmm(src_com.transpose(-2,-1)).transpose(-2,-1) + t = tgt_com - torch.matmul(R, src_com.transpose(-2, -1)).transpose(-2, -1) + + return R, t.squeeze(-2) # (B, 3, 3), (B, 3) diff --git a/analysis/src/common/pdb_utils.py b/analysis/src/common/pdb_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f65b7ce8ac7d4c91c3517da48c356eab2164162 --- /dev/null +++ b/analysis/src/common/pdb_utils.py @@ -0,0 +1,353 @@ +"""Utility functions for operating PDB files. 
+""" +import os +import re +from typing import Optional +from collections import OrderedDict + +import numpy as np +from tqdm import tqdm + +import biotite.structure as struct +from biotite.structure.io.pdb import PDBFile + +from src.common import protein + + +def write_pdb_string(pdb_string: str, save_to: str): + """Write pdb string to file""" + with open(save_to, 'w') as f: + f.write(pdb_string) + +def read_pdb_to_string(pdb_file): + """Read PDB file as pdb string. Convenient API""" + with open(pdb_file, 'r') as fi: + pdb_string = '' + for line in fi: + if line.startswith('END') or line.startswith('TER') \ + or line.startswith('MODEL') or line.startswith('ATOM'): + pdb_string += line + return pdb_string + +def merge_pdbfiles(input, output_file, verbose=True): + """ordered merging process of pdbs""" + if isinstance(input, str): + pdb_files = [os.path.join(input, f) for f in os.listdir(input) if f.endswith('.pdb')] + elif isinstance(input, list): + pdb_files = input + + os.makedirs(os.path.dirname(output_file), exist_ok=True) + + model_number = 0 + pdb_lines = [] + if verbose: + _iter = tqdm(pdb_files, desc='Merging PDBs') + else: + _iter = pdb_files + + for pdb_file in _iter: + with open(pdb_file, 'r') as pdb: + lines = pdb.readlines() + single_model = True + + for line in lines: + if line.startswith('MODEL') or line.startswith('ENDMDL'): + single_model = False + break + + if single_model: # single model + model_number += 1 + pdb_lines.append(f"MODEL {model_number}") + for line in lines: + if line.startswith('TER') or line.startswith('ATOM'): + pdb_lines.append(line.strip()) + pdb_lines.append("ENDMDL") + else: # multiple models + for line in lines: + if line.startswith('MODEL'): + model_number += 1 + if model_number > 1: + pdb_lines.append("ENDMDL") + pdb_lines.append(f"MODEL {model_number}") + elif line.startswith('END'): + continue + elif line.startswith('TER') or line.startswith('ATOM'): + pdb_lines.append(line.strip()) + pdb_lines.append('ENDMDL') + 
pdb_lines.append('END') + pdb_lines = [_line.ljust(80) for _line in pdb_lines] + pdb_str = '\n'.join(pdb_lines) + '\n' + with open(output_file, 'w') as fo: + fo.write(pdb_str) + + if verbose: + print(f"Merged {len(pdb_files)} PDB files into {output_file} with {model_number} models.") + + +def split_pdbfile(pdb_file, output_dir=None, suffix='index', verbose=True): + """Split a PDB file into multiple PDB files in output_dir. + Preassume that each model is wrapped by 'MODEL' and 'ENDMDL'. + """ + assert os.path.exists(pdb_file), f"File {pdb_file} does not exist." + assert suffix == 'index', 'Only support [suffix=index] for now.' + + if output_dir is not None: # also dump to output_dir + os.makedirs(output_dir, exist_ok=True) + base = os.path.splitext(os.path.basename(pdb_file))[0] + + i = 0 + pdb_strs = [] + pdb_string = '' + with open(pdb_file, 'r') as fi: + # pdb_string = '' + for line in fi: + if line.startswith('MODEL'): + pdb_string = '' + elif line.startswith('ATOM') or line.startswith('TER'): + pdb_string += line + elif line.startswith('ENDMDL') or line.startswith('END'): + if pdb_string == '': continue + pdb_string += 'END\n' + if output_dir is not None: + _save_to = os.path.join(output_dir, f'{base}_{i}.pdb') if suffix == 'index' else None + with open(_save_to, 'w') as fo: + fo.write(pdb_string) + pdb_strs.append(pdb_string) + pdb_string = '' + i += 1 + else: + if verbose: + print(f"Warning: line '{line}' is not recognized. Skip.") + if verbose: + print(f">>> Split pdb {pdb_file} into {i}/{len(pdb_strs)} structures.") + return pdb_strs + + +def stratify_sample_pdbfile(input_path, output_path, n_max_sample=1000, end_at=0, verbose=True): + """ """ + assert os.path.exists(input_path), f"File {input_path} does not exist." + assert not os.path.exists(output_path), f"Output path {output_path} already exists." 
+ + i = 0 + pdb_strs = [] + with open(input_path, 'r') as fi: + # pdb_string = '' + pdb_lines_per_model = [] + for line in fi: + if line.startswith('MODEL'): + pdb_lines_per_model = [] + elif line.startswith('ATOM') or line.startswith('TER'): + pdb_lines_per_model.append(line.strip()) + elif line.startswith('ENDMDL') or line.startswith('END'): + if pdb_lines_per_model == []: continue # skip empty model + # wrap up the model + pdb_lines_per_model.append('ENDMDL') + # Pad all lines to 80 characters. + pdb_lines_per_model = [_line.ljust(80) for _line in pdb_lines_per_model] + pdb_str_per_model = '\n'.join(pdb_lines_per_model) + '\n' # Add terminating newline. + pdb_strs.append(pdb_str_per_model) + # reset + pdb_lines_per_model = [] + i += 1 + else: + if verbose: + print(f"Warning: line '{line}' is not recognized. Skip.") + if end_at > 0 and i > end_at: + break + end = end_at if end_at > 0 else len(pdb_strs) + + # sample evenly + if end > n_max_sample: + interleave_step = int(end // n_max_sample) # floor + sampled_pdb_strs = pdb_strs[:end][::interleave_step][:n_max_sample] + else: + sampled_pdb_strs = pdb_strs[:end] + + output_str = '' + for i, pdb_str in enumerate(sampled_pdb_strs): # renumber models + output_str += f"MODEL {i+1}".ljust(80) + '\n' + output_str += pdb_str + output_str = output_str + ('END'.ljust(80) + '\n') + + write_pdb_string(output_str, save_to=output_path) + if verbose: + print(f">>> Split pdb {input_path} into {len(sampled_pdb_strs)}/{n_max_sample} structures.") + return + + +def protein_with_default_params( + atom_positions: np.ndarray, + atom_mask: np.ndarray, + aatype: Optional[np.ndarray] = None, + b_factors: Optional[np.ndarray] = None, + chain_index: Optional[np.ndarray] = None, + residue_index: Optional[np.ndarray] = None, +): + assert atom_positions.ndim == 3 + assert atom_positions.shape[-1] == 3 + assert atom_positions.shape[-2] == 37 + n = atom_positions.shape[0] + sqz = lambda x: np.squeeze(x) if x.shape[0] == 1 and len(x.shape) > 1 
else x + + residue_index = np.arange(n) + 1 if residue_index is None else sqz(residue_index) + chain_index = np.zeros(n) if chain_index is None else sqz(chain_index) + b_factors = np.zeros([n, 37]) if b_factors is None else sqz(b_factors) + aatype = np.zeros(n, dtype=int) if aatype is None else sqz(aatype) + + return protein.Protein( + atom_positions=atom_positions, + atom_mask=atom_mask, + aatype=aatype, + residue_index=residue_index, + chain_index=chain_index, + b_factors=b_factors + ) + +def atom37_to_pdb( + save_to: str, + atom_positions: np.ndarray, + aatype: Optional[np.ndarray] = None, + b_factors: Optional[np.ndarray] = None, + chain_index: Optional[np.ndarray] = None, + residue_index: Optional[np.ndarray] = None, + overwrite: bool = False, + no_indexing: bool = True, +): + # configure save path + if overwrite: + max_existing_idx = 0 + else: + file_dir = os.path.dirname(save_to) + file_name = os.path.basename(save_to).strip('.pdb') + existing_files = [x for x in os.listdir(file_dir) if file_name in x] + max_existing_idx = max([ + int(re.findall(r'_(\d+).pdb', x)[0]) for x in existing_files if re.findall(r'_(\d+).pdb', x) + if re.findall(r'_(\d+).pdb', x)] + [0]) + if not no_indexing: + save_to = save_to.replace('.pdb', '') + f'_{max_existing_idx+1}.pdb' + else: + save_to = save_to + + with open(save_to, 'w') as f: + if atom_positions.ndim == 4: + for mi, pos37 in enumerate(atom_positions): + atom_mask = np.sum(np.abs(pos37), axis=-1) > 1e-7 + prot = protein_with_default_params( + pos37, atom_mask, aatype=aatype, b_factors=b_factors, + chain_index=chain_index, residue_index=residue_index + ) + pdb_str = protein.to_pdb(prot, model=mi+1, add_end=False) + f.write(pdb_str) + elif atom_positions.ndim == 3: + atom_mask = np.sum(np.abs(atom_positions), axis=-1) > 1e-7 + prot = protein_with_default_params( + atom_positions, atom_mask, aatype=aatype, b_factors=b_factors, + chain_index=chain_index, residue_index=residue_index + ) + pdb_str = protein.to_pdb(prot, 
model=1, add_end=False) + f.write(pdb_str) + else: + raise ValueError(f'Invalid positions shape {atom_positions.shape}') + f.write('END') + + return save_to + +def extract_backbone_coords_from_pdb(pdb_path: str, target_atoms: Optional[list] = ["CA"]): + structure = PDBFile.read(pdb_path) + structure_list = structure.get_structure() + + coords_list = [] + for b_idx in range(structure.get_model_count()): + chain = structure_list[b_idx] + + backbone_atoms = chain[struct.filter_backbone(chain)] # This includes the “N”, “CA” and “C” atoms of amino acids. + ret_coords = OrderedDict() + # init dict + for k in target_atoms: + ret_coords[k] = [] + + for c in backbone_atoms: + if c.atom_name in ret_coords: + ret_coords[c.atom_name].append(c.coord) + + ret_coords = [np.vstack(v) for k,v in ret_coords.items()] + if len(target_atoms) == 1: + ret_coords = ret_coords[0] # L, 3 + else: + ret_coords = np.stack(ret_coords, axis=1) # L, na, 3 + + coords_list.append(ret_coords) + + coords_list = np.stack(coords_list, axis=0) # B, L, na, 3 or B, L, 3 (ca only) + return coords_list + + +def extract_backbone_coords_from_pdb_dir(pdb_dir: str): + return np.concatenate([ + extract_backbone_coords_from_pdb(os.path.join(pdb_dir, f)) + for f in os.listdir(pdb_dir) if f.endswith('.pdb') + ], axis=0) + +def extract_backbone_coords_from_npy(npy_path: str): + return np.load(npy_path) + + +def extract_backbone_coords(input_path: str, + max_n_model: Optional[int] = None, +): + """Extract backbone coordinates from PDB file. + + Args: + input_path (str): The path to the PDB file. + ca_only (bool): Whether to extract only CA coordinates. + max_n_model (int): The maximum number of models to extract. + """ + assert os.path.exists(input_path), f"File {input_path} does not exist." 
+ if input_path.endswith('.pdb'): + coords = extract_backbone_coords_from_pdb(input_path) + elif input_path.endswith('.npy'): + coords = extract_backbone_coords_from_npy(input_path) + elif os.path.isdir(input_path): + coords = extract_backbone_coords_from_pdb_dir(input_path) + else: + raise ValueError(f"Unrecognized input path {input_path}.") + + if max_n_model is not None and len(coords) > max_n_model > 0: + coords = coords[:max_n_model] + return coords + + + +if __name__ == '__main__': + import argparse + def get_argparser(): + parser = argparse.ArgumentParser(description='Main script for pdb processing.') + parser.add_argument("input", type=str, help="The generic path to sampled pdb directory / pdb file.") + parser.add_argument("-m", "--mode", type=str, help="The mode of processing.", + default="split") + parser.add_argument("-o", "--output", type=str, help="The output directory for processed pdb files.", + default=None) + + args = parser.parse_args() + return args + args = get_argparser() + + # ad hoc functions + def split_pdbs(args): + os.makedirs(args.output, exist_ok=True) + _ = split_pdbfile(pdb_file=args.input, + output_dir=args.output) + + def merge_pdbs(args): + output = args.output or f"{args.input}_all.pdb" + merge_pdbfiles(input=args.input, + output_file=output) + + if args.mode == "split": + split_pdbs(args) + elif args.mode == "merge": + merge_pdbs(args) + elif args.mode == "stratify": + stratify_sample_pdbfile(input_path=args.input, output_path=args.output) + else: + raise ValueError(f"Unrecognized mode {args.mode}.") \ No newline at end of file diff --git a/analysis/src/common/protein.py b/analysis/src/common/protein.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4ba6b66e48c956e7c966f60fc0819865d81eaf --- /dev/null +++ b/analysis/src/common/protein.py @@ -0,0 +1,289 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except 
in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Protein data type.""" +import dataclasses +import io +from typing import Any, Mapping, Optional + +from Bio.PDB import PDBParser +import numpy as np + +from src.common import residue_constants + + +FeatureDict = Mapping[str, np.ndarray] +ModelOutput = Mapping[str, Any] # Is a nested dict. + +# Complete sequence of chain IDs supported by the PDB format. +PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' +PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62. + + +@dataclasses.dataclass(frozen=True) +class Protein: + """Protein structure representation.""" + + # Cartesian coordinates of atoms in angstroms. The atom types correspond to + # residue_constants.atom_types, i.e. the first three are N, CA, CB. + atom_positions: np.ndarray # [num_res, num_atom_type, 3] + + # Amino-acid type for each residue represented as an integer between 0 and + # 20, where 20 is 'X'. + aatype: np.ndarray # [num_res] + + # Binary float mask to indicate presence of a particular atom. 1.0 if an atom + # is present and 0.0 if not. This should be used for loss masking. + atom_mask: np.ndarray # [num_res, num_atom_type] + + # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. + residue_index: np.ndarray # [num_res] + + # 0-indexed number corresponding to the chain in the protein that this residue + # belongs to. + chain_index: np.ndarray # [num_res] + + # B-factors, or temperature factors, of each residue (in sq. 
angstroms units), + # representing the displacement of the residue from its ground truth mean + # value. + b_factors: np.ndarray # [num_res, num_atom_type] + + def __post_init__(self): + if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS: + raise ValueError( + f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains ' + 'because these cannot be written to PDB format.') + + def to_dict(self): + return dataclasses.asdict(self) + + +def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: + """Takes a PDB string and constructs a Protein object. + + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + + Args: + pdb_str: The contents of the pdb file + chain_id: If chain_id is specified (e.g. A), then only that chain + is parsed. Otherwise all chains are parsed. + + Returns: + A new `Protein` parsed from the pdb contents. + """ + pdb_fh = io.StringIO(pdb_str) + parser = PDBParser(QUIET=True) + structure = parser.get_structure('none', pdb_fh) + models = list(structure.get_models()) + if len(models) != 1: + raise ValueError( + f'Only single model PDBs are supported. Found {len(models)} models.') + model = models[0] + + atom_positions = [] + aatype = [] + atom_mask = [] + residue_index = [] + chain_ids = [] + b_factors = [] + + for chain in model: + if chain_id is not None and chain.id != chain_id: + continue + for res in chain: + if res.id[2] != ' ': + raise ValueError( + f'PDB contains an insertion code at chain {chain.id} and residue ' + f'index {res.id[1]}. 
These are not supported.') + res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num) + pos = np.zeros((residue_constants.atom_type_num, 3)) + mask = np.zeros((residue_constants.atom_type_num,)) + res_b_factors = np.zeros((residue_constants.atom_type_num,)) + for atom in res: + if atom.name not in residue_constants.atom_types: + continue + pos[residue_constants.atom_order[atom.name]] = atom.coord + mask[residue_constants.atom_order[atom.name]] = 1. + res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor + if np.sum(mask) < 0.5: + # If no known atom positions are reported for the residue then skip it. + continue + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + chain_ids.append(chain.id) + b_factors.append(res_b_factors) + + # Chain IDs are usually characters so map these to ints. + unique_chain_ids = np.unique(chain_ids) + chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)} + chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids]) + + return Protein( + atom_positions=np.array(atom_positions), + atom_mask=np.array(atom_mask), + aatype=np.array(aatype), + residue_index=np.array(residue_index), + chain_index=chain_index, + b_factors=np.array(b_factors)) + + +def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str: + chain_end = 'TER' + return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} ' + f'{chain_name:>1}{residue_index:>4}') + + +def to_pdb(prot: Protein, model=1, add_end=True) -> str: + """Converts a `Protein` instance to a PDB string. + + Args: + prot: The protein to convert to PDB. + + Returns: + PDB string. 
+ """ + restypes = residue_constants.restypes + ['X'] + res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') + atom_types = residue_constants.atom_types + + pdb_lines = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(int) + chain_index = prot.chain_index.astype(int) + b_factors = prot.b_factors + + if np.any(aatype > residue_constants.restype_num): + raise ValueError('Invalid aatypes.') + + # Construct a mapping from chain integer indices to chain ID strings. + chain_ids = {} + for i in np.unique(chain_index): # np.unique gives sorted output. + if i >= PDB_MAX_CHAINS: + raise ValueError( + f'The PDB format supports at most {PDB_MAX_CHAINS} chains.') + chain_ids[i] = PDB_CHAIN_IDS[i] + + pdb_lines.append(f'MODEL {model}') + atom_index = 1 + last_chain_index = chain_index[0] + # Add all atom sites. + for i in range(aatype.shape[0]): + # Close the previous chain if in a multichain PDB. + if last_chain_index != chain_index[i]: + pdb_lines.append(_chain_end( + atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]], + residue_index[i - 1])) + last_chain_index = chain_index[i] + atom_index += 1 # Atom index increases at the TER symbol. + + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip( + atom_types, atom_positions[i], atom_mask[i], b_factors[i]): + if mask < 0.5: + continue + + # skip CB for GLY + if res_name_3 == 'GLY' and atom_name == 'CB': + continue + + record_type = 'ATOM' + name = atom_name if len(atom_name) == 4 else f' {atom_name}' + alt_loc = '' + insertion_code = '' + occupancy = 1.00 + element = atom_name[0] # Protein supports only C, N, O, S, this works. + charge = '' + # PDB is a columnar format, every space matters here! 
+ atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' + f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}' + f'{residue_index[i]:>4}{insertion_code:>1} ' + f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' + f'{occupancy:>6.2f}{b_factor:>6.2f} ' + f'{element:>2}{charge:>2}') + pdb_lines.append(atom_line) + atom_index += 1 + + # Close the final chain. + pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]), + chain_ids[chain_index[-1]], residue_index[-1])) + pdb_lines.append('ENDMDL') + if add_end: + pdb_lines.append('END') + + # Pad all lines to 80 characters. + pdb_lines = [line.ljust(80) for line in pdb_lines] + return '\n'.join(pdb_lines) + '\n' # Add terminating newline. + + +def ideal_atom_mask(prot: Protein) -> np.ndarray: + """Computes an ideal atom mask. + + `Protein.atom_mask` typically is defined according to the atoms that are + reported in the PDB. This function computes a mask according to heavy atoms + that should be present in the given sequence of amino acids. + + Args: + prot: `Protein` whose fields are `numpy.ndarray` objects. + + Returns: + An ideal atom mask. + """ + return residue_constants.STANDARD_ATOM_MASK[prot.aatype] + + +def from_prediction( + features: FeatureDict, + result: ModelOutput, + b_factors: Optional[np.ndarray] = None, + remove_leading_feature_dimension: bool = True) -> Protein: + """Assembles a protein from a prediction. + + Args: + features: Dictionary holding model inputs. + result: Dictionary holding model outputs. + b_factors: (Optional) B-factors to use for the protein. + remove_leading_feature_dimension: Whether to remove the leading dimension + of the `features` values. + + Returns: + A protein instance. 
+ """ + fold_output = result['structure_module'] + + def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray: + return arr[0] if remove_leading_feature_dimension else arr + + if 'asym_id' in features: + chain_index = _maybe_remove_leading_dim(features['asym_id']) + else: + chain_index = np.zeros_like(_maybe_remove_leading_dim(features['aatype'])) + + if b_factors is None: + b_factors = np.zeros_like(fold_output['final_atom_mask']) + + return Protein( + aatype=_maybe_remove_leading_dim(features['aatype']), + atom_positions=fold_output['final_atom_positions'], + atom_mask=fold_output['final_atom_mask'], + residue_index=_maybe_remove_leading_dim(features['residue_index']) + 1, + chain_index=chain_index, + b_factors=b_factors) \ No newline at end of file diff --git a/analysis/src/common/residue_constants.py b/analysis/src/common/residue_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..9b925b5c25bb76bbffc12940fae3afbc9f2ba1a8 --- /dev/null +++ b/analysis/src/common/residue_constants.py @@ -0,0 +1,897 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants used in AlphaFold.""" + +import collections +import functools +import os +from typing import List, Mapping, Tuple + +import numpy as np +import tree + +# Internal import (35fd). + + +# Distance from one CA to next CA [trans configuration: omega = 180]. 
+ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms = { + 'ALA': [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + 'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']], + 'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'CYS': [['N', 'CA', 'CB', 'SG']], + 'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1']], + 'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1']], + 'GLY': [], + 'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']], + 'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']], + 'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']], + 'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'], + ['CB', 'CG', 'SD', 'CE']], + 'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']], + 'SER': [['N', 'CA', 'CB', 'OG']], + 'THR': [['N', 'CA', 'CB', 'OG1']], + 'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'VAL': [['N', 'CA', 'CB', 'CG1']], +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). 
+chi_angles_mask = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [1.0, 1.0, 1.0, 1.0], # ARG + [1.0, 1.0, 0.0, 0.0], # ASN + [1.0, 1.0, 0.0, 0.0], # ASP + [1.0, 0.0, 0.0, 0.0], # CYS + [1.0, 1.0, 1.0, 0.0], # GLN + [1.0, 1.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [1.0, 1.0, 0.0, 0.0], # HIS + [1.0, 1.0, 0.0, 0.0], # ILE + [1.0, 1.0, 0.0, 0.0], # LEU + [1.0, 1.0, 1.0, 1.0], # LYS + [1.0, 1.0, 1.0, 0.0], # MET + [1.0, 1.0, 0.0, 0.0], # PHE + [1.0, 1.0, 0.0, 0.0], # PRO + [1.0, 0.0, 0.0, 0.0], # SER + [1.0, 0.0, 0.0, 0.0], # THR + [1.0, 1.0, 0.0, 0.0], # TRP + [1.0, 1.0, 0.0, 0.0], # TYR + [1.0, 0.0, 0.0, 0.0], # VAL +] + +# The following chi angles are pi periodic: they can be rotated by a multiple +# of pi without affecting the structure. +chi_pi_periodic = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [0.0, 0.0, 0.0, 0.0], # ARG + [0.0, 0.0, 0.0, 0.0], # ASN + [0.0, 1.0, 0.0, 0.0], # ASP + [0.0, 0.0, 0.0, 0.0], # CYS + [0.0, 0.0, 0.0, 0.0], # GLN + [0.0, 0.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [0.0, 0.0, 0.0, 0.0], # HIS + [0.0, 0.0, 0.0, 0.0], # ILE + [0.0, 0.0, 0.0, 0.0], # LEU + [0.0, 0.0, 0.0, 0.0], # LYS + [0.0, 0.0, 0.0, 0.0], # MET + [0.0, 1.0, 0.0, 0.0], # PHE + [0.0, 0.0, 0.0, 0.0], # PRO + [0.0, 0.0, 0.0, 0.0], # SER + [0.0, 0.0, 0.0, 0.0], # THR + [0.0, 0.0, 0.0, 0.0], # TRP + [0.0, 1.0, 0.0, 0.0], # TYR + [0.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # UNK +] + +# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, +# psi and chi angles: +# 0: 'backbone group', +# 1: 'pre-omega-group', (empty) +# 2: 'phi-group', (currently empty, because it defines only hydrogens) +# 3: 'psi-group', +# 4,5,6,7: 'chi1,2,3,4-group' +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. 
The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). +# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions = { + 'ALA': [ + ['N', 0, (-0.525, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.529, -0.774, -1.205)], + ['O', 3, (0.627, 1.062, 0.000)], + ], + 'ARG': [ + ['N', 0, (-0.524, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.524, -0.778, -1.209)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.616, 1.390, -0.000)], + ['CD', 5, (0.564, 1.414, 0.000)], + ['NE', 6, (0.539, 1.357, -0.000)], + ['NH1', 7, (0.206, 2.301, 0.000)], + ['NH2', 7, (2.078, 0.978, -0.000)], + ['CZ', 7, (0.758, 1.093, -0.000)], + ], + 'ASN': [ + ['N', 0, (-0.536, 1.357, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.531, -0.787, -1.200)], + ['O', 3, (0.625, 1.062, 0.000)], + ['CG', 4, (0.584, 1.399, 0.000)], + ['ND2', 5, (0.593, -1.188, 0.001)], + ['OD1', 5, (0.633, 1.059, 0.000)], + ], + 'ASP': [ + ['N', 0, (-0.525, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, 0.000, -0.000)], + ['CB', 0, (-0.526, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.593, 1.398, -0.000)], + ['OD1', 5, (0.610, 1.091, 0.000)], + ['OD2', 5, (0.592, -1.101, -0.003)], + ], + 'CYS': [ + ['N', 0, (-0.522, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, 0.000)], + ['CB', 0, (-0.519, -0.773, -1.212)], + ['O', 3, (0.625, 1.062, -0.000)], + ['SG', 4, (0.728, 1.653, 0.000)], + ], + 'GLN': [ + ['N', 0, (-0.526, 1.361, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.779, -1.207)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.615, 1.393, 
0.000)], + ['CD', 5, (0.587, 1.399, -0.000)], + ['NE2', 6, (0.593, -1.189, -0.001)], + ['OE1', 6, (0.634, 1.060, 0.000)], + ], + 'GLU': [ + ['N', 0, (-0.528, 1.361, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.526, -0.781, -1.207)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.615, 1.392, 0.000)], + ['CD', 5, (0.600, 1.397, 0.000)], + ['OE1', 6, (0.607, 1.095, -0.000)], + ['OE2', 6, (0.589, -1.104, -0.001)], + ], + 'GLY': [ + ['N', 0, (-0.572, 1.337, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.517, -0.000, -0.000)], + ['O', 3, (0.626, 1.062, -0.000)], + ], + 'HIS': [ + ['N', 0, (-0.527, 1.360, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.778, -1.208)], + ['O', 3, (0.625, 1.063, 0.000)], + ['CG', 4, (0.600, 1.370, -0.000)], + ['CD2', 5, (0.889, -1.021, 0.003)], + ['ND1', 5, (0.744, 1.160, -0.000)], + ['CE1', 5, (2.030, 0.851, 0.002)], + ['NE2', 5, (2.145, -0.466, 0.004)], + ], + 'ILE': [ + ['N', 0, (-0.493, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.536, -0.793, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.534, 1.437, -0.000)], + ['CG2', 4, (0.540, -0.785, -1.199)], + ['CD1', 5, (0.619, 1.391, 0.000)], + ], + 'LEU': [ + ['N', 0, (-0.520, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.773, -1.214)], + ['O', 3, (0.625, 1.063, -0.000)], + ['CG', 4, (0.678, 1.371, 0.000)], + ['CD1', 5, (0.530, 1.430, -0.000)], + ['CD2', 5, (0.535, -0.774, 1.200)], + ], + 'LYS': [ + ['N', 0, (-0.526, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.524, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.619, 1.390, 0.000)], + ['CD', 5, (0.559, 1.417, 0.000)], + ['CE', 6, (0.560, 1.416, 0.000)], + ['NZ', 7, (0.554, 1.387, 0.000)], + ], + 'MET': 
[ + ['N', 0, (-0.521, 1.364, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.210)], + ['O', 3, (0.625, 1.062, -0.000)], + ['CG', 4, (0.613, 1.391, -0.000)], + ['SD', 5, (0.703, 1.695, 0.000)], + ['CE', 6, (0.320, 1.786, -0.000)], + ], + 'PHE': [ + ['N', 0, (-0.518, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, -0.000)], + ['CB', 0, (-0.525, -0.776, -1.212)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.607, 1.377, 0.000)], + ['CD1', 5, (0.709, 1.195, -0.000)], + ['CD2', 5, (0.706, -1.196, 0.000)], + ['CE1', 5, (2.102, 1.198, -0.000)], + ['CE2', 5, (2.098, -1.201, -0.000)], + ['CZ', 5, (2.794, -0.003, -0.001)], + ], + 'PRO': [ + ['N', 0, (-0.566, 1.351, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, 0.000)], + ['CB', 0, (-0.546, -0.611, -1.293)], + ['O', 3, (0.621, 1.066, 0.000)], + ['CG', 4, (0.382, 1.445, 0.0)], + # ['CD', 5, (0.427, 1.440, 0.0)], + ['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + 'SER': [ + ['N', 0, (-0.529, 1.360, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.518, -0.777, -1.211)], + ['O', 3, (0.626, 1.062, -0.000)], + ['OG', 4, (0.503, 1.325, 0.000)], + ], + 'THR': [ + ['N', 0, (-0.517, 1.364, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, -0.000)], + ['CB', 0, (-0.516, -0.793, -1.215)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG2', 4, (0.550, -0.718, -1.228)], + ['OG1', 4, (0.472, 1.353, 0.000)], + ], + 'TRP': [ + ['N', 0, (-0.521, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.212)], + ['O', 3, (0.627, 1.062, 0.000)], + ['CG', 4, (0.609, 1.370, -0.000)], + ['CD1', 5, (0.824, 1.091, 0.000)], + ['CD2', 5, (0.854, -1.148, -0.005)], + ['CE2', 5, (2.186, -0.678, -0.007)], + ['CE3', 5, (0.622, -2.530, -0.007)], + ['NE1', 5, (2.140, 0.690, 
-0.004)], + ['CH2', 5, (3.028, -2.890, -0.013)], + ['CZ2', 5, (3.283, -1.543, -0.011)], + ['CZ3', 5, (1.715, -3.389, -0.011)], + ], + 'TYR': [ + ['N', 0, (-0.522, 1.362, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.776, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG', 4, (0.607, 1.382, -0.000)], + ['CD1', 5, (0.716, 1.195, -0.000)], + ['CD2', 5, (0.713, -1.194, -0.001)], + ['CE1', 5, (2.107, 1.200, -0.002)], + ['CE2', 5, (2.104, -1.201, -0.003)], + ['OH', 5, (4.168, -0.002, -0.005)], + ['CZ', 5, (2.791, -0.001, -0.003)], + ], + 'VAL': [ + ['N', 0, (-0.494, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.533, -0.795, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.540, 1.429, -0.000)], + ['CG2', 4, (0.533, -0.776, 1.203)], + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms = { + 'ALA': ['C', 'CA', 'CB', 'N', 'O'], + 'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'], + 'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'], + 'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'], + 'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'], + 'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'], + 'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'], + 'GLY': ['C', 'CA', 'N', 'O'], + 'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'], + 'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'], + 'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'], + 'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'], + 'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'], + 'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'], + 'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'], + 'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'], + 'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'], + 'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 
'CE3', 'CZ2', 'CZ3', + 'CH2', 'N', 'NE1', 'O'], + 'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O', + 'OH'], + 'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O'] +} + +# Naming swaps for ambiguous atom names. +# Due to symmetries in the amino acids the naming of atoms is ambiguous in +# 4 of the 20 amino acids. +# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities +# in LEU, VAL and ARG can be resolved by using the 3d constellations of +# the 'ambiguous' atoms and their neighbours) +residue_atom_renaming_swaps = { + 'ASP': {'OD1': 'OD2'}, + 'GLU': {'OE1': 'OE2'}, + 'PHE': {'CD1': 'CD2', 'CE1': 'CE2'}, + 'TYR': {'CD1': 'CD2', 'CE1': 'CE2'}, +} + +# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) +van_der_waals_radius = { + 'C': 1.7, + 'N': 1.55, + 'O': 1.52, + 'S': 1.8, +} + +Bond = collections.namedtuple( + 'Bond', ['atom1_name', 'atom2_name', 'length', 'stddev']) +BondAngle = collections.namedtuple( + 'BondAngle', + ['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev']) + + +@functools.lru_cache(maxsize=None) +def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]], + Mapping[str, List[Bond]], + Mapping[str, List[BondAngle]]]: + """Load stereo_chemical_props.txt into a nice structure. + + Load literature values for bond lengths and bond angles and translate + bond angles into the length of the opposite edge of the triangle + ("residue_virtual_bonds"). + + Returns: + residue_bonds: Dict that maps resname -> list of Bond tuples. + residue_virtual_bonds: Dict that maps resname -> list of Bond tuples. + residue_bond_angles: Dict that maps resname -> list of BondAngle tuples. + """ + stereo_chemical_props_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), 'stereo_chemical_props.txt' + ) + with open(stereo_chemical_props_path, 'rt') as f: + stereo_chemical_props = f.read() + lines_iter = iter(stereo_chemical_props.splitlines()) + # Load bond lengths. 
+ residue_bonds = {} + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, length, stddev = line.split() + atom1, atom2 = bond.split('-') + if resname not in residue_bonds: + residue_bonds[resname] = [] + residue_bonds[resname].append( + Bond(atom1, atom2, float(length), float(stddev))) + residue_bonds['UNK'] = [] + + # Load bond angles. + residue_bond_angles = {} + next(lines_iter) # Skip empty line. + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, angle_degree, stddev_degree = line.split() + atom1, atom2, atom3 = bond.split('-') + if resname not in residue_bond_angles: + residue_bond_angles[resname] = [] + residue_bond_angles[resname].append( + BondAngle(atom1, atom2, atom3, + float(angle_degree) / 180. * np.pi, + float(stddev_degree) / 180. * np.pi)) + residue_bond_angles['UNK'] = [] + + def make_bond_key(atom1_name, atom2_name): + """Unique key to lookup bonds.""" + return '-'.join(sorted([atom1_name, atom2_name])) + + # Translate bond angles into distances ("virtual bonds"). + residue_virtual_bonds = {} + for resname, bond_angles in residue_bond_angles.items(): + # Create a fast lookup dict for bond lengths. + bond_cache = {} + for b in residue_bonds[resname]: + bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b + residue_virtual_bonds[resname] = [] + for ba in bond_angles: + bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] + bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] + + # Compute distance between atom1 and atom3 using the law of cosines + # c^2 = a^2 + b^2 - 2ab*cos(gamma). + gamma = ba.angle_rad + length = np.sqrt(bond1.length**2 + bond2.length**2 + - 2 * bond1.length * bond2.length * np.cos(gamma)) + + # Propagation of uncertainty assuming uncorrelated errors. 
+ dl_outer = 0.5 / length + dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer + dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer + dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer + stddev = np.sqrt((dl_dgamma * ba.stddev)**2 + + (dl_db1 * bond1.stddev)**2 + + (dl_db2 * bond2.stddev)**2) + residue_virtual_bonds[resname].append( + Bond(ba.atom1_name, ba.atom3name, length, stddev)) + + return (residue_bonds, + residue_virtual_bonds, + residue_bond_angles) + + +# Between-residue bond lengths for general bonds (first element) and for Proline +# (second element). +between_res_bond_length_c_n = [1.329, 1.341] +between_res_bond_length_stddev_c_n = [0.014, 0.016] + +# Between-residue cos_angles. +between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995 + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). +atom_types = [ + 'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD', + 'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3', + 'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2', + 'CZ3', 'NZ', 'OXT' +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. 
+ +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names = { + 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''], + 'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''], + 'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''], + 'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''], + 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''], + 'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''], + 'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''], + 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''], + 'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''], + 'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''], + 'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''], + 'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''], + 'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''], + 'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''], + 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''], + 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''], + 'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''], + 'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'], + 'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''], + 'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''], + 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''], + +} +# pylint: enable=line-too-long +# pylint: enable=bad-whitespace + + +# This is the standard residue order when coding AA type 
as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes = [ + 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', + 'S', 'T', 'W', 'Y', 'V' +] +restype_order = {restype: i for i, restype in enumerate(restypes)} +restype_num = len(restypes) # := 20. +unk_restype_index = restype_num # Catch-all index for unknown restypes. + +restypes_with_x = restypes + ['X'] +restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)} + + +def sequence_to_onehot( + sequence: str, + mapping: Mapping[str, int], + map_unknown_to_x: bool = False) -> np.ndarray: + """Maps the given sequence into a one-hot encoded matrix. + + Args: + sequence: An amino acid sequence. + mapping: A dictionary mapping amino acids to integers. + map_unknown_to_x: If True, any amino acid that is not in the mapping will be + mapped to the unknown amino acid 'X'. If the mapping doesn't contain + amino acid 'X', an error will be thrown. If False, any amino acid not in + the mapping will throw an error. + + Returns: + A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of + the sequence. + + Raises: + ValueError: If the mapping doesn't contain values from 0 to + num_unique_aas - 1 without any gaps. + """ + num_entries = max(mapping.values()) + 1 + + if sorted(set(mapping.values())) != list(range(num_entries)): + raise ValueError('The mapping must have values from 0 to num_unique_aas-1 ' + 'without any gaps. 
Got: %s' % sorted(mapping.values())) + + one_hot_arr = np.zeros((len(sequence), num_entries), dtype=int) + + for aa_index, aa_type in enumerate(sequence): + if map_unknown_to_x: + if aa_type.isalpha() and aa_type.isupper(): + aa_id = mapping.get(aa_type, mapping['X']) + else: + raise ValueError(f'Invalid character in the sequence: {aa_type}') + else: + aa_id = mapping[aa_type] + one_hot_arr[aa_index, aa_id] = 1 + + return one_hot_arr + + +restype_1to3 = { + 'A': 'ALA', + 'R': 'ARG', + 'N': 'ASN', + 'D': 'ASP', + 'C': 'CYS', + 'Q': 'GLN', + 'E': 'GLU', + 'G': 'GLY', + 'H': 'HIS', + 'I': 'ILE', + 'L': 'LEU', + 'K': 'LYS', + 'M': 'MET', + 'F': 'PHE', + 'P': 'PRO', + 'S': 'SER', + 'T': 'THR', + 'W': 'TRP', + 'Y': 'TYR', + 'V': 'VAL', +} + + +# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple +# 1-to-1 mapping of 3 letter names to one letter names. The latter contains +# many more, and less common, three letter names as keys and maps many of these +# to the same one letter name (including 'X' and 'U' which we don't use here). +restype_3to1 = {v: k for k, v in restype_1to3.items()} + +# Define a restype name for all unknown residues. +unk_restype = 'UNK' + +resnames = [restype_1to3[r] for r in restypes] + [unk_restype] +resname_to_idx = {resname: i for i, resname in enumerate(resnames)} + + +# The mapping here uses hhblits convention, so that B is mapped to D, J and O +# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the +# remaining 20 amino acids are kept in alphabetical order. +# There are 2 non-amino acid codes, X (representing any amino acid) and +# "-" representing a missing amino acid in an alignment. The id for these +# codes is put at the end (20 and 21) so that they can easily be ignored if +# desired. 
+HHBLITS_AA_TO_ID = { + 'A': 0, + 'B': 2, + 'C': 1, + 'D': 2, + 'E': 3, + 'F': 4, + 'G': 5, + 'H': 6, + 'I': 7, + 'J': 20, + 'K': 8, + 'L': 9, + 'M': 10, + 'N': 11, + 'O': 20, + 'P': 12, + 'Q': 13, + 'R': 14, + 'S': 15, + 'T': 16, + 'U': 1, + 'V': 17, + 'W': 18, + 'X': 20, + 'Y': 19, + 'Z': 3, + '-': 21, +} + +# Partial inversion of HHBLITS_AA_TO_ID. +ID_TO_HHBLITS_AA = { + 0: 'A', + 1: 'C', # Also U. + 2: 'D', # Also B. + 3: 'E', # Also Z. + 4: 'F', + 5: 'G', + 6: 'H', + 7: 'I', + 8: 'K', + 9: 'L', + 10: 'M', + 11: 'N', + 12: 'P', + 13: 'Q', + 14: 'R', + 15: 'S', + 16: 'T', + 17: 'V', + 18: 'W', + 19: 'Y', + 20: 'X', # Includes J and O. + 21: '-', +} + +restypes_with_x_and_gap = restypes + ['X', '-'] +MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple( + restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) + for i in range(len(restypes_with_x_and_gap))) + + +def _make_standard_atom_mask() -> np.ndarray: + """Returns [num_res_types, num_atom_types] mask array.""" + # +1 to account for unknown (all 0s). + mask = np.zeros([restype_num + 1, atom_type_num], dtype=int) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + mask[restype, atom_type] = 1 + return mask + + +STANDARD_ATOM_MASK = _make_standard_atom_mask() + + +# A one hot representation for the first and second atoms defining the axis +# of rotation for each chi-angle in each residue. 
+def chi_angle_atom(atom_index: int) -> np.ndarray: + """Define chi-angle rigid groups via one-hot representations.""" + chi_angles_index = {} + one_hots = [] + + for k, v in chi_angles_atoms.items(): + indices = [atom_types.index(s[atom_index]) for s in v] + indices.extend([-1]*(4-len(indices))) + chi_angles_index[k] = indices + + for r in restypes: + res3 = restype_1to3[r] + one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] + one_hots.append(one_hot) + + one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. + one_hot = np.stack(one_hots, axis=0) + one_hot = np.transpose(one_hot, [0, 2, 1]) + + return one_hot + +chi_atom_1_one_hot = chi_angle_atom(1) +chi_atom_2_one_hot = chi_angle_atom(2) + +# An array like chi_angles_atoms but using indices rather than names. +chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes] +chi_angles_atom_indices = tree.map_structure( + lambda atom_name: atom_order[atom_name], chi_angles_atom_indices) +chi_angles_atom_indices = np.array([ + chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) + for chi_atoms in chi_angles_atom_indices]) + +# Mapping from (res_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. +chi_groups_for_atom = collections.defaultdict(list) +for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) +chi_groups_for_atom = dict(chi_groups_for_atom) + + +def _make_rigid_transformation_4x4(ex, ey, translation): + """Create a rigid 4x4 transformation matrix from two axes and transl.""" + # Normalize ex. 
+ ex_normalized = ex / np.linalg.norm(ex) + + # make ey perpendicular to ex + ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized + ey_normalized /= np.linalg.norm(ey_normalized) + + # compute ez as cross product + eznorm = np.cross(ex_normalized, ey_normalized) + m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose() + m = np.concatenate([m, [[0., 0., 0., 1.]]], axis=0) + return m + + +# create an array with (restype, atomtype) --> rigid_group_idx +# and an array with (restype, atomtype, coord) for the atom positions +# and compute affine transformation matrices (4,4) from one rigid group to the +# previous group +restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int) +restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) +restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) +restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int) +restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) +restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) +restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) + + +def _make_rigid_group_constants(): + """Fill the arrays above.""" + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for atomname, group_idx, atom_position in rigid_group_atom_positions[ + resname]: + atomtype = atom_order[atomname] + restype_atom37_to_rigid_group[restype, atomtype] = group_idx + restype_atom37_mask[restype, atomtype] = 1 + restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position + + atom14idx = restype_name_to_atom14_names[resname].index(atomname) + restype_atom14_to_rigid_group[restype, atom14idx] = group_idx + restype_atom14_mask[restype, atom14idx] = 1 + restype_atom14_rigid_group_positions[restype, + atom14idx, :] = atom_position + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_positions = {name: 
np.array(pos) for name, _, pos + in rigid_group_atom_positions[resname]} + + # backbone to backbone is the identity transform + restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) + + # pre-omega-frame to backbone (currently dummy identity matrix) + restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) + + # phi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['N'] - atom_positions['CA'], + ey=np.array([1., 0., 0.]), + translation=atom_positions['N']) + restype_rigid_group_default_frame[restype, 2, :, :] = mat + + # psi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['C'] - atom_positions['CA'], + ey=atom_positions['CA'] - atom_positions['N'], + translation=atom_positions['C']) + restype_rigid_group_default_frame[restype, 3, :, :] = mat + + # chi1-frame to backbone + if chi_angles_mask[restype][0]: + base_atom_names = chi_angles_atoms[resname][0] + base_atom_positions = [atom_positions[name] for name in base_atom_names] + mat = _make_rigid_transformation_4x4( + ex=base_atom_positions[2] - base_atom_positions[1], + ey=base_atom_positions[0] - base_atom_positions[1], + translation=base_atom_positions[2]) + restype_rigid_group_default_frame[restype, 4, :, :] = mat + + # chi2-frame to chi1-frame + # chi3-frame to chi2-frame + # chi4-frame to chi3-frame + # luckily all rotation axes for the next frame start at (0,0,0) of the + # previous frame + for chi_idx in range(1, 4): + if chi_angles_mask[restype][chi_idx]: + axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = _make_rigid_transformation_4x4( + ex=axis_end_atom_position, + ey=np.array([-1., 0., 0.]), + translation=axis_end_atom_position) + restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat + + +_make_rigid_group_constants() + + +def make_atom14_dists_bounds(overlap_tolerance=1.5, + bond_length_tolerance_factor=15): + """compute upper and lower 
bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_list = restype_name_to_atom14_names[resname] + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev + return {'lower_bound': 
restype_atom14_bond_lower_bound, # shape (21,14,14) + 'upper_bound': restype_atom14_bond_upper_bound, # shape (21,14,14) + 'stddev': restype_atom14_bond_stddev, # shape (21,14,14) + } \ No newline at end of file diff --git a/analysis/src/common/rigid_utils.py b/analysis/src/common/rigid_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..699be6541a9c34e9ae9f23e74b97fca70f3eadc8 --- /dev/null +++ b/analysis/src/common/rigid_utils.py @@ -0,0 +1,1451 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Any, Sequence, Callable, Optional + +import numpy as np +import torch + +from src.common import rotation3d + + +def rot_matmul( + a: torch.Tensor, + b: torch.Tensor +) -> torch.Tensor: + """ + Performs matrix multiplication of two rotation matrix tensors. Written + out by hand to avoid AMP downcasting. 
+ + Args: + a: [*, 3, 3] left multiplicand + b: [*, 3, 3] right multiplicand + Returns: + The product ab + """ + row_1 = torch.stack( + [ + a[..., 0, 0] * b[..., 0, 0] + + a[..., 0, 1] * b[..., 1, 0] + + a[..., 0, 2] * b[..., 2, 0], + a[..., 0, 0] * b[..., 0, 1] + + a[..., 0, 1] * b[..., 1, 1] + + a[..., 0, 2] * b[..., 2, 1], + a[..., 0, 0] * b[..., 0, 2] + + a[..., 0, 1] * b[..., 1, 2] + + a[..., 0, 2] * b[..., 2, 2], + ], + dim=-1, + ) + row_2 = torch.stack( + [ + a[..., 1, 0] * b[..., 0, 0] + + a[..., 1, 1] * b[..., 1, 0] + + a[..., 1, 2] * b[..., 2, 0], + a[..., 1, 0] * b[..., 0, 1] + + a[..., 1, 1] * b[..., 1, 1] + + a[..., 1, 2] * b[..., 2, 1], + a[..., 1, 0] * b[..., 0, 2] + + a[..., 1, 1] * b[..., 1, 2] + + a[..., 1, 2] * b[..., 2, 2], + ], + dim=-1, + ) + row_3 = torch.stack( + [ + a[..., 2, 0] * b[..., 0, 0] + + a[..., 2, 1] * b[..., 1, 0] + + a[..., 2, 2] * b[..., 2, 0], + a[..., 2, 0] * b[..., 0, 1] + + a[..., 2, 1] * b[..., 1, 1] + + a[..., 2, 2] * b[..., 2, 1], + a[..., 2, 0] * b[..., 0, 2] + + a[..., 2, 1] * b[..., 1, 2] + + a[..., 2, 2] * b[..., 2, 2], + ], + dim=-1, + ) + + return torch.stack([row_1, row_2, row_3], dim=-2) + + +def rot_vec_mul( + r: torch.Tensor, + t: torch.Tensor +) -> torch.Tensor: + """ + Applies a rotation to a vector. Written out by hand to avoid transfer + to avoid AMP downcasting. 
+ + Args: + r: [*, 3, 3] rotation matrices + t: [*, 3] coordinate tensors + Returns: + [*, 3] rotated coordinates + """ + x = t[..., 0] + y = t[..., 1] + z = t[..., 2] + return torch.stack( + [ + r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z, + r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z, + r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z, + ], + dim=-1, + ) + + +def identity_rot_mats( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + rots = torch.eye( + 3, dtype=dtype, device=device, requires_grad=requires_grad + ) + rots = rots.view(*((1,) * len(batch_dims)), 3, 3) + rots = rots.expand(*batch_dims, -1, -1) + + return rots + + +def identity_trans( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + trans = torch.zeros( + (*batch_dims, 3), + dtype=dtype, + device=device, + requires_grad=requires_grad + ) + return trans + + +def identity_quats( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + quat = torch.zeros( + (*batch_dims, 4), + dtype=dtype, + device=device, + requires_grad=requires_grad + ) + + with torch.no_grad(): + quat[..., 0] = 1 + + return quat + + +_quat_elements = ["a", "b", "c", "d"] +_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements] +_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)} + + +def _to_mat(pairs): + mat = np.zeros((4, 4)) + for pair in pairs: + key, value = pair + ind = _qtr_ind_dict[key] + mat[ind // 4][ind % 4] = value + + return mat + + +_QTR_MAT = np.zeros((4, 4, 3, 3)) +_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)]) +_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)]) +_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)]) 
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])


def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
    """
        Converts a quaternion to a rotation matrix.

        Every rotation-matrix entry is a fixed linear combination of the
        pairwise products of the quaternion components; the coefficients
        are stored in the module-level _QTR_MAT table.

        Args:
            quat: [*, 4] quaternions
        Returns:
            [*, 3, 3] rotation matrices
    """
    # [*, 4, 4] outer product of the quaternion with itself
    quat = quat[..., None] * quat[..., None, :]

    # [4, 4, 3, 3] coefficient table on the same dtype/device as the input
    mat = quat.new_tensor(_QTR_MAT, requires_grad=False)

    # [*, 4, 4, 3, 3]
    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
    quat = quat[..., None, None] * shaped_qtr_mat

    # [*, 3, 3] sum out the two quaternion-product dimensions
    return torch.sum(quat, dim=(-3, -4))


def rot_to_quat(
    rot: torch.Tensor,
):
    """
        Converts rotation matrices to quaternions by taking the eigenvector
        belonging to the largest eigenvalue of a symmetric 4x4 matrix built
        from the rotation entries.

        Args:
            rot: [*, 3, 3] rotation matrices
        Returns:
            [*, 4] quaternions (last column of the torch.linalg.eigh
            eigenvector matrix)
        Raises:
            ValueError: If the last two dimensions of rot are not (3, 3)
    """
    if rot.shape[-2:] != (3, 3):
        raise ValueError("Input rotation is incorrectly shaped")

    rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
    [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot

    k = [
        [xx + yy + zz, zy - yz, xz - zx, yx - xy],
        [zy - yz, xx - yy - zz, xy + yx, xz + zx],
        [xz - zx, xy + yx, yy - xx - zz, yz + zy],
        [yx - xy, xz + zx, yz + zy, zz - xx - yy],
    ]

    k = (1. / 3.) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)

    # eigh returns eigenvalues in ascending order, so the last eigenvector
    # corresponds to the largest eigenvalue.
    _, vectors = torch.linalg.eigh(k)
    return vectors[..., -1]


# _QUAT_MULTIPLY[i, j, k] is the coefficient of q1[i] * q2[j] in component k
# of the product (this is how quat_multiply contracts it).
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0],
                           [0, -1, 0, 0],
                           [0, 0, -1, 0],
                           [0, 0, 0, -1]]

_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0],
                           [1, 0, 0, 0],
                           [0, 0, 0, 1],
                           [0, 0, -1, 0]]

_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0],
                           [0, 0, 0, -1],
                           [1, 0, 0, 0],
                           [0, 1, 0, 0]]

_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1],
                           [0, 0, 1, 0],
                           [0, -1, 0, 0],
                           [1, 0, 0, 0]]

# Same table with the real-part row of the right operand removed, for
# multiplying by a pure-vector quaternion (0, x, y, z).
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]


def quat_multiply(quat1, quat2):
    """Multiply a quaternion by another quaternion.

        Args:
            quat1: [*, 4] left quaternions
            quat2: [*, 4] right quaternions
        Returns:
            [*, 4] products
    """
    mat = quat1.new_tensor(_QUAT_MULTIPLY)
    reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
    return torch.sum(
        reshaped_mat *
        quat1[..., :, None, None] *
        quat2[..., None, :, None],
        dim=(-3, -2)
    )


def quat_multiply_by_vec(quat, vec):
    """Multiply a quaternion by a pure-vector quaternion.

        Args:
            quat: [*, 4] quaternions
            vec: [*, 3] vector parts of the right operand
        Returns:
            [*, 4] products
    """
    mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
    reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
    return torch.sum(
        reshaped_mat *
        quat[..., :, None, None] *
        vec[..., None, :, None],
        dim=(-3, -2)
    )


def invert_rot_mat(rot_mat: torch.Tensor):
    """Returns the transpose of the last two dimensions of rot_mat, which
    is the inverse for a proper rotation matrix."""
    return rot_mat.transpose(-1, -2)


def invert_quat(quat: torch.Tensor):
    """Returns the quaternion inverse: the conjugate divided by the squared
    norm. The input is cloned, not modified in place."""
    quat_prime = quat.clone()
    quat_prime[..., 1:] *= -1
    inv = quat_prime / torch.sum(quat ** 2, dim=-1, keepdim=True)
    return inv


class Rotation:
    """
        A 3D rotation. Depending on how the object is initialized, the
        rotation is represented by either a rotation matrix or a
        quaternion, though both formats are made available by helper functions.
        To simplify gradient computation, the underlying format of the
        rotation cannot be changed in-place. Like Rigid, the class is designed
        to mimic the behavior of a torch Tensor, almost as if each Rotation
        object were a tensor of rotations, in one format or another.
    """
    def __init__(self,
        rot_mats: Optional[torch.Tensor] = None,
        quats: Optional[torch.Tensor] = None,
        normalize_quats: bool = True,
    ):
        """
            Args:
                rot_mats:
                    A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
                    quats
                quats:
                    A [*, 4] quaternion. Mutually exclusive with rot_mats. If
                    normalize_quats is not True, must be a unit quaternion
                normalize_quats:
                    If quats is specified, whether to normalize quats
            Raises:
                ValueError:
                    If neither or both of rot_mats/quats are given, or if
                    the given tensor has the wrong trailing shape
        """
        if ((rot_mats is None and quats is None) or
                (rot_mats is not None and quats is not None)):
            raise ValueError("Exactly one input argument must be specified")

        if ((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
                (quats is not None and quats.shape[-1] != 4)):
            raise ValueError(
                "Incorrectly shaped rotation matrix or quaternion"
            )

        # Force full-precision
        if quats is not None:
            quats = quats.type(torch.float32)
        if rot_mats is not None:
            rot_mats = rot_mats.type(torch.float32)

        if quats is not None and normalize_quats:
            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)

        self._rot_mats = rot_mats
        self._quats = quats

    @staticmethod
    def identity(
        shape,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = True,
        fmt: str = "quat",
    ):
        """
            Returns an identity Rotation.

            Args:
                shape:
                    The "shape" of the resulting Rotation object. See documentation
                    for the shape property
                dtype:
                    The torch dtype for the rotation
                device:
                    The torch device for the new rotation
                requires_grad:
                    Whether the underlying tensors in the new rotation object
                    should require gradient computation
                fmt:
                    One of "quat" or "rot_mat". Determines the underlying format
                    of the new object's rotation
            Returns:
                A new identity rotation
            Raises:
                ValueError: If fmt is neither "quat" nor "rot_mat"
        """
        if fmt == "rot_mat":
            rot_mats = identity_rot_mats(
                shape, dtype, device, requires_grad,
            )
            return Rotation(rot_mats=rot_mats, quats=None)
        elif fmt == "quat":
            quats = identity_quats(shape, dtype, device, requires_grad)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            # The original message interpolated a stray literal "f" before
            # the offending value.
            raise ValueError(f"Invalid format: {fmt}")

    # Magic methods

    def __getitem__(self, index: Any):
        """
            Allows torch-style indexing over the virtual shape of the rotation
            object. See documentation for the shape property.

            Args:
                index:
                    A torch index. E.g. (1, 3, 2), or (slice(None,))
            Returns:
                The indexed rotation
        """
        if type(index) != tuple:
            index = (index,)

        if self._rot_mats is not None:
            # Keep the trailing 3x3 dimensions whole
            rot_mats = self._rot_mats[index + (slice(None), slice(None))]
            return Rotation(rot_mats=rot_mats)
        elif self._quats is not None:
            quats = self._quats[index + (slice(None),)]
            return Rotation(quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def __mul__(self,
        right: torch.Tensor,
    ):
        """
            Pointwise left multiplication of the rotation with a tensor. Can be
            used to e.g. mask the Rotation.

            Args:
                right:
                    The tensor multiplicand
            Returns:
                The product
            Raises:
                TypeError: If right is not a torch.Tensor
        """
        if not isinstance(right, torch.Tensor):
            raise TypeError("The other multiplicand must be a Tensor")

        if self._rot_mats is not None:
            rot_mats = self._rot_mats * right[..., None, None]
            return Rotation(rot_mats=rot_mats, quats=None)
        elif self._quats is not None:
            quats = self._quats * right[..., None]
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def __rmul__(self,
        left: torch.Tensor,
    ):
        """
            Reverse pointwise multiplication of the rotation with a tensor.

            Args:
                left:
                    The left multiplicand
            Returns:
                The product
        """
        return self.__mul__(left)

    # Properties

    @property
    def shape(self) -> torch.Size:
        """
            Returns the virtual shape of the rotation object. This shape is
            defined as the batch dimensions of the underlying rotation matrix
            or quaternion. If the Rotation was initialized with a [10, 3, 3]
            rotation matrix tensor, for example, the resulting shape would be
            [10].

            Returns:
                The virtual shape of the rotation object
        """
        if self._quats is not None:
            return self._quats.shape[:-1]
        return self._rot_mats.shape[:-2]

    @property
    def dtype(self) -> torch.dtype:
        """
            Returns the dtype of the underlying rotation.

            Returns:
                The dtype of the underlying rotation
        """
        if self._rot_mats is not None:
            return self._rot_mats.dtype
        elif self._quats is not None:
            return self._quats.dtype
        else:
            raise ValueError("Both rotations are None")

    @property
    def device(self) -> torch.device:
        """
            The device of the underlying rotation

            Returns:
                The device of the underlying rotation
        """
        if self._rot_mats is not None:
            return self._rot_mats.device
        elif self._quats is not None:
            return self._quats.device
        else:
            raise ValueError("Both rotations are None")

    @property
    def requires_grad(self) -> bool:
        """
            Returns the requires_grad property of the underlying rotation

            Returns:
                The requires_grad property of the underlying tensor
        """
        if self._rot_mats is not None:
            return self._rot_mats.requires_grad
        elif self._quats is not None:
            return self._quats.requires_grad
        else:
            raise ValueError("Both rotations are None")

    def get_rot_mats(self) -> torch.Tensor:
        """
            Returns the underlying rotation as a rotation matrix tensor.

            Returns:
                The rotation as a rotation matrix tensor
        """
        rot_mats = self._rot_mats
        if rot_mats is None:
            if self._quats is None:
                raise ValueError("Both rotations are None")
            else:
                rot_mats = quat_to_rot(self._quats)

        return rot_mats

    def get_quats(self) -> torch.Tensor:
        """
            Returns the underlying rotation as a quaternion tensor.

            If the Rotation was initialized with rotation matrices, they are
            converted with rotation3d.matrix_to_quaternion (this replaced
            the eigh-based rot_to_quat).

            Returns:
                The rotation as a quaternion tensor.
        """
        quats = self._quats
        if quats is None:
            if self._rot_mats is None:
                raise ValueError("Both rotations are None")
            else:
                quats = rotation3d.matrix_to_quaternion(self._rot_mats)

        return quats

    def get_cur_rot(self) -> torch.Tensor:
        """
            Return the underlying rotation in its current form

            Returns:
                The stored rotation
        """
        if self._rot_mats is not None:
            return self._rot_mats
        elif self._quats is not None:
            return self._quats
        else:
            raise ValueError("Both rotations are None")

    def get_rotvec(self, eps=1e-6) -> torch.Tensor:
        """
            Return the underlying axis-angle rotation vector.

            Follows scipy's implementation:
            https://github.com/scipy/scipy/blob/HEAD/scipy/spatial/transform/_rotation.pyx#L1385-L1402

            Args:
                eps:
                    Small offset added inside the sine of the large-angle
                    branch
            Returns:
                The stored rotation as an axis-angle vector.
        """
        quat = self.get_quats()
        # w > 0 to ensure 0 <= angle <= pi
        flip = (quat[..., :1] < 0).float()
        quat = (-1 * quat) * flip + (1 - flip) * quat

        angle = 2 * torch.atan2(
            torch.linalg.norm(quat[..., 1:], dim=-1),
            quat[..., 0]
        )

        # Polynomial scale for small angles, where angle / sin(angle / 2)
        # is numerically delicate
        angle2 = angle * angle
        small_angle_scales = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
        large_angle_scales = angle / torch.sin(angle / 2 + eps)

        small_angles = (angle <= 1e-3).float()
        rot_vec_scale = (
            small_angle_scales * small_angles
            + (1 - small_angles) * large_angle_scales
        )
        rot_vec = rot_vec_scale[..., None] * quat[..., 1:]
        return rot_vec

    # Rotation functions

    def compose_q_update_vec(self,
        q_update_vec: torch.Tensor,
        normalize_quats: bool = True,
        update_mask: Optional[torch.Tensor] = None,
    ):
        """
            Returns a new quaternion Rotation after updating the current
            object's underlying rotation with a quaternion update, formatted
            as a [*, 3] tensor whose final three columns represent x, y, z such
            that (1, x, y, z) is the desired (not necessarily unit) quaternion
            update.

            Args:
                q_update_vec:
                    A [*, 3] quaternion update tensor
                normalize_quats:
                    Whether to normalize the output quaternion
                update_mask:
                    Optional tensor multiplied into the quaternion update
                    before it is added to the current quaternion
            Returns:
                An updated Rotation
        """
        quats = self.get_quats()
        quat_update = quat_multiply_by_vec(quats, q_update_vec)
        if update_mask is not None:
            quat_update = quat_update * update_mask
        # q * (1, x, y, z) = q + q * (0, x, y, z)
        new_quats = quats + quat_update
        return Rotation(
            rot_mats=None,
            quats=new_quats,
            normalize_quats=normalize_quats,
        )

    def compose_r(self, r):
        """
            Compose the rotation matrices of the current Rotation object with
            those of another.

            Args:
                r:
                    An update rotation object
            Returns:
                An updated rotation object
        """
        r1 = self.get_rot_mats()
        r2 = r.get_rot_mats()
        new_rot_mats = rot_matmul(r1, r2)
        return Rotation(rot_mats=new_rot_mats, quats=None)

    def compose_q(self, r, normalize_quats: bool = True):
        """
            Compose the quaternions of the current Rotation object with those
            of another.

            Either operand that is stored as rotation matrices is first
            converted to quaternions by get_quats().

            Args:
                r:
                    An update rotation object
                normalize_quats:
                    Whether to normalize the output quaternion
            Returns:
                An updated rotation object
        """
        q1 = self.get_quats()
        q2 = r.get_quats()
        new_quats = quat_multiply(q1, q2)
        return Rotation(
            rot_mats=None, quats=new_quats, normalize_quats=normalize_quats
        )

    def apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
            Apply the current Rotation as a rotation matrix to a set of 3D
            coordinates.

            Args:
                pts:
                    A [*, 3] set of points
            Returns:
                [*, 3] rotated points
        """
        rot_mats = self.get_rot_mats()
        return rot_vec_mul(rot_mats, pts)

    def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
            The inverse of the apply() method.

            Args:
                pts:
                    A [*, 3] set of points
            Returns:
                [*, 3] inverse-rotated points
        """
        rot_mats = self.get_rot_mats()
        inv_rot_mats = invert_rot_mat(rot_mats)
        return rot_vec_mul(inv_rot_mats, pts)

    def invert(self):
        """
            Returns the inverse of the current Rotation.

            Returns:
                The inverse of the current Rotation
        """
        if self._rot_mats is not None:
            return Rotation(
                rot_mats=invert_rot_mat(self._rot_mats),
                quats=None
            )
        elif self._quats is not None:
            return Rotation(
                rot_mats=None,
                quats=invert_quat(self._quats),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")

    # "Tensor" stuff

    def unsqueeze(self,
        dim: int,
    ):
        """
            Analogous to torch.unsqueeze. The dimension is relative to the
            shape of the Rotation object.

            Args:
                dim: A positive or negative dimension index.
            Returns:
                The unsqueezed Rotation.
            Raises:
                ValueError: If dim >= len(self.shape)
        """
        if dim >= len(self.shape):
            raise ValueError("Invalid dimension")

        # Negative indices are shifted past the trailing 3x3 (or 4) dims
        if self._rot_mats is not None:
            rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
            return Rotation(rot_mats=rot_mats, quats=None)
        elif self._quats is not None:
            quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    @staticmethod
    def cat(
        rs,
        dim: int,
    ):
        """
            Concatenates rotations along one of the batch dimensions. Analogous
            to torch.cat().

            Note that the output of this operation is always a rotation matrix,
            regardless of the format of input rotations.

            Args:
                rs:
                    A list of rotation objects
                dim:
                    The dimension along which the rotations should be
                    concatenated
            Returns:
                A concatenated Rotation object in rotation matrix format
        """
        rot_mats = [r.get_rot_mats() for r in rs]
        rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)

        return Rotation(rot_mats=rot_mats, quats=None)

    def map_tensor_fn(self,
        fn
    ):
        """
            Apply a Tensor -> Tensor function to underlying rotation tensors,
            mapping over the rotation dimension(s). Can be used e.g. to sum out
            a one-hot batch dimension.

            Args:
                fn:
                    A Tensor -> Tensor function to be mapped over the Rotation
            Returns:
                The transformed Rotation object
        """
        if self._rot_mats is not None:
            # Flatten 3x3 -> 9 so fn sees one [*] tensor per matrix entry
            rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
            rot_mats = torch.stack(
                list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
            )
            rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
            return Rotation(rot_mats=rot_mats, quats=None)
        elif self._quats is not None:
            quats = torch.stack(
                list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
            )
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def cuda(self):
        """
            Analogous to the cuda() method of torch Tensors

            Returns:
                A copy of the Rotation in CUDA memory
        """
        if self._rot_mats is not None:
            return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
        elif self._quats is not None:
            return Rotation(
                rot_mats=None,
                quats=self._quats.cuda(),
                normalize_quats=False
            )
        else:
            raise ValueError("Both rotations are None")

    def to(self,
        device: Optional[torch.device],
        dtype: Optional[torch.dtype]
    ):
        """
            Analogous to the to() method of torch Tensors

            Note that the Rotation constructor casts its input to float32,
            so the returned object always holds a float32 tensor.

            Args:
                device:
                    A torch device
                dtype:
                    A torch dtype
            Returns:
                A copy of the Rotation using the new device and dtype
        """
        if self._rot_mats is not None:
            return Rotation(
                rot_mats=self._rot_mats.to(device=device, dtype=dtype),
                quats=None,
            )
        elif self._quats is not None:
            return Rotation(
                rot_mats=None,
                quats=self._quats.to(device=device, dtype=dtype),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")

    def detach(self):
        """
            Returns a copy of the Rotation whose underlying Tensor has been
            detached from its torch graph.

            Returns:
                A copy of the Rotation whose underlying Tensor has been detached
                from its torch graph
        """
        if self._rot_mats is not None:
            return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
        elif self._quats is not None:
            return Rotation(
                rot_mats=None,
                quats=self._quats.detach(),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")


class Rigid:
    """
        A class representing a rigid transformation. Little more than a wrapper
        around two objects: a Rotation object and a [*, 3] translation
        Designed to behave approximately like a single torch tensor with the
        shape of the shared batch dimensions of its component parts.
    """
    def __init__(self,
        rots: Optional[Rotation],
        trans: Optional[torch.Tensor],
    ):
        """
            Args:
                rots:
                    A Rotation object of shape [*], or None for the identity
                    rotation
                trans:
                    A corresponding [*, 3] translation tensor, or None for a
                    zero translation
            Raises:
                ValueError:
                    If both arguments are None, or if their batch shapes or
                    devices differ
        """
        # (we need device, dtype, etc. from at least one input)

        batch_dims, dtype, device, requires_grad = None, None, None, None
        if trans is not None:
            batch_dims = trans.shape[:-1]
            dtype = trans.dtype
            device = trans.device
            requires_grad = trans.requires_grad
        elif rots is not None:
            batch_dims = rots.shape
            dtype = rots.dtype
            device = rots.device
            requires_grad = rots.requires_grad
        else:
            raise ValueError("At least one input argument must be specified")

        if rots is None:
            rots = Rotation.identity(
                batch_dims, dtype, device, requires_grad,
            )
        elif trans is None:
            trans = identity_trans(
                batch_dims, dtype, device, requires_grad,
            )

        if ((rots.shape != trans.shape[:-1]) or
                (rots.device != trans.device)):
            raise ValueError("Rots and trans incompatible")

        # Force full precision. Happens to the rotations automatically.
        trans = trans.type(torch.float32)

        self._rots = rots
        self._trans = trans

    @staticmethod
    def identity(
        shape: Tuple[int],
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = True,
        fmt: str = "quat",
    ):
        """
            Constructs an identity transformation.

            Args:
                shape:
                    The desired shape
                dtype:
                    The dtype of both internal tensors
                device:
                    The device of both internal tensors
                requires_grad:
                    Whether grad should be enabled for the internal tensors
                fmt:
                    Rotation format, forwarded to Rotation.identity
            Returns:
                The identity transformation
        """
        return Rigid(
            Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
            identity_trans(shape, dtype, device, requires_grad),
        )

    def __getitem__(self,
        index: Any,
    ):
        """
            Indexes the affine transformation with PyTorch-style indices.
            The index is applied to the shared dimensions of both the rotation
            and the translation.

            E.g.::

                r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
                t = Rigid(r, torch.rand(10, 10, 3))
                indexed = t[3, 4:6]
                assert(indexed.shape == (2,))
                assert(indexed.get_rots().shape == (2,))
                assert(indexed.get_trans().shape == (2, 3))

            Args:
                index: A standard torch tensor index. E.g. 8, (10, None, 3),
                or (3, slice(0, 1, None))
            Returns:
                The indexed tensor
        """
        if type(index) != tuple:
            index = (index,)

        return Rigid(
            self._rots[index],
            self._trans[index + (slice(None),)],
        )

    def __mul__(self,
        right: torch.Tensor,
    ):
        """
            Pointwise left multiplication of the transformation with a tensor.
            Can be used to e.g. mask the Rigid.

            Args:
                right:
                    The tensor multiplicand
            Returns:
                The product
            Raises:
                TypeError: If right is not a torch.Tensor
        """
        if not isinstance(right, torch.Tensor):
            raise TypeError("The other multiplicand must be a Tensor")

        new_rots = self._rots * right
        new_trans = self._trans * right[..., None]

        return Rigid(new_rots, new_trans)

    def __rmul__(self,
        left: torch.Tensor,
    ):
        """
            Reverse pointwise multiplication of the transformation with a
            tensor.

            Args:
                left:
                    The left multiplicand
            Returns:
                The product
        """
        return self.__mul__(left)

    @property
    def shape(self) -> torch.Size:
        """
            Returns the shape of the shared dimensions of the rotation and
            the translation.

            Returns:
                The shape of the transformation
        """
        return self._trans.shape[:-1]

    @property
    def device(self) -> torch.device:
        """
            Returns the device on which the Rigid's tensors are located.

            Returns:
                The device on which the Rigid's tensors are located
        """
        return self._trans.device

    def get_rots(self) -> Rotation:
        """
            Getter for the rotation.

            Returns:
                The rotation object
        """
        return self._rots

    def get_trans(self) -> torch.Tensor:
        """
            Getter for the translation.

            Returns:
                The stored translation
        """
        return self._trans

    def compose_q_update_vec(self,
        q_update_vec: torch.Tensor,
        update_mask: Optional[torch.Tensor] = None,
    ):
        """
            Composes the transformation with a quaternion update vector of
            shape [*, 6], where the final 6 columns represent the x, y, and
            z values of a quaternion of form (1, x, y, z) followed by a 3D
            translation.

            Args:
                q_update_vec:
                    The [*, 6] quaternion/translation update vector.
                update_mask:
                    Optional tensor multiplied into both the quaternion
                    update and the rotated translation update
            Returns:
                The composed transformation.
        """
        q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
        new_rots = self._rots.compose_q_update_vec(
            q_vec, update_mask=update_mask)

        # The translation update is expressed in the current local frame
        trans_update = self._rots.apply(t_vec)
        if update_mask is not None:
            trans_update = trans_update * update_mask
        new_translation = self._trans + trans_update

        return Rigid(new_rots, new_translation)

    def compose(self,
        r,
    ):
        """
            Composes the current rigid object with another.

            Args:
                r:
                    Another Rigid object
            Returns:
                The composition of the two transformations
        """
        new_rot = self._rots.compose_r(r._rots)
        new_trans = self._rots.apply(r._trans) + self._trans
        return Rigid(new_rot, new_trans)

    def compose_r(self,
        rot,
        order='right'
    ):
        """
            Composes the rotation of the current rigid object with a
            Rotation, leaving the translation unchanged.

            Args:
                rot:
                    A Rotation object
                order:
                    Order in which to perform rotation multiplication:
                    'right' gives self_rot * rot, 'left' gives
                    rot * self_rot.
            Returns:
                A Rigid with the composed rotation and the original
                translation
            Raises:
                ValueError: If order is neither 'right' nor 'left'
        """
        if order == 'right':
            new_rot = self._rots.compose_r(rot)
        elif order == 'left':
            new_rot = rot.compose_r(self._rots)
        else:
            raise ValueError(f'Unrecognized multiplication order: {order}')
        return Rigid(new_rot, self._trans)

    def apply(self,
        pts: torch.Tensor,
    ) -> torch.Tensor:
        """
            Applies the transformation to a coordinate tensor.

            Args:
                pts: A [*, 3] coordinate tensor.
            Returns:
                The transformed points.
        """
        rotated = self._rots.apply(pts)
        return rotated + self._trans

    def invert_apply(self,
        pts: torch.Tensor
    ) -> torch.Tensor:
        """
            Applies the inverse of the transformation to a coordinate tensor.

            Args:
                pts: A [*, 3] coordinate tensor
            Returns:
                The transformed points.
        """
        pts = pts - self._trans
        return self._rots.invert_apply(pts)

    def invert(self):
        """
            Inverts the transformation.

            Returns:
                The inverse transformation.
        """
        # (R, t)^-1 = (R^-1, -R^-1 t)
        rot_inv = self._rots.invert()
        trn_inv = rot_inv.apply(self._trans)

        return Rigid(rot_inv, -1 * trn_inv)

    def map_tensor_fn(self,
        fn
    ):
        """
            Apply a Tensor -> Tensor function to underlying translation and
            rotation tensors, mapping over the translation/rotation dimensions
            respectively.

            Args:
                fn:
                    A Tensor -> Tensor function to be mapped over the Rigid
            Returns:
                The transformed Rigid object
        """
        new_rots = self._rots.map_tensor_fn(fn)
        new_trans = torch.stack(
            list(map(fn, torch.unbind(self._trans, dim=-1))),
            dim=-1
        )

        return Rigid(new_rots, new_trans)

    def to_tensor_4x4(self) -> torch.Tensor:
        """
            Converts a transformation to a homogenous transformation tensor.

            Returns:
                A [*, 4, 4] homogenous transformation tensor
        """
        tensor = self._trans.new_zeros((*self.shape, 4, 4))
        tensor[..., :3, :3] = self._rots.get_rot_mats()
        tensor[..., :3, 3] = self._trans
        tensor[..., 3, 3] = 1
        return tensor

    @staticmethod
    def from_tensor_4x4(
        t: torch.Tensor
    ):
        """
            Constructs a transformation from a homogenous transformation
            tensor.

            Args:
                t: [*, 4, 4] homogenous transformation tensor
            Returns:
                T object with shape [*]
            Raises:
                ValueError: If the last two dimensions of t are not (4, 4)
        """
        if t.shape[-2:] != (4, 4):
            raise ValueError("Incorrectly shaped input tensor")

        rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
        trans = t[..., :3, 3]

        return Rigid(rots, trans)

    def to_tensor_7(self) -> torch.Tensor:
        """
            Converts a transformation to a tensor with 7 final columns, four
            for the quaternion followed by three for the translation.

            Returns:
                A [*, 7] tensor representation of the transformation
        """
        tensor = self._trans.new_zeros((*self.shape, 7))
        tensor[..., :4] = self._rots.get_quats()
        tensor[..., 4:] = self._trans

        return tensor

    @staticmethod
    def from_tensor_7(
        t: torch.Tensor,
        normalize_quats: bool = False,
    ):
        """
            Constructs a transformation from a [*, 7] tensor: four
            quaternion columns followed by three translation columns.

            Args:
                t: [*, 7] tensor
                normalize_quats: Whether to normalize the quaternion part
            Returns:
                A transformation object of shape [*]
            Raises:
                ValueError: If the last dimension of t is not 7
        """
        if t.shape[-1] != 7:
            raise ValueError("Incorrectly shaped input tensor")

        quats, trans = t[..., :4], t[..., 4:]

        rots = Rotation(
            rot_mats=None,
            quats=quats,
            normalize_quats=normalize_quats
        )

        return Rigid(rots, trans)

    @staticmethod
    def from_3_points(
        p_neg_x_axis: torch.Tensor,
        origin: torch.Tensor,
        p_xy_plane: torch.Tensor,
        eps: float = 1e-8
    ):
        """
            Implements algorithm 21. Constructs transformations from sets of 3
            points using the Gram-Schmidt algorithm.

            Args:
                p_neg_x_axis: [*, 3] coordinates
                origin: [*, 3] coordinates used as frame origins
                p_xy_plane: [*, 3] coordinates
                eps: Small epsilon value
            Returns:
                A transformation object of shape [*]
        """
        p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
        origin = torch.unbind(origin, dim=-1)
        p_xy_plane = torch.unbind(p_xy_plane, dim=-1)

        e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
        e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]

        # Normalize e0, remove its component from e1, normalize e1
        denom = torch.sqrt(sum((c * c for c in e0)) + eps)
        e0 = [c / denom for c in e0]
        dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
        e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
        denom = torch.sqrt(sum((c * c for c in e1)) + eps)
        e1 = [c / denom for c in e1]
        # e2 = e0 x e1
        e2 = [
            e0[1] * e1[2] - e0[2] * e1[1],
            e0[2] * e1[0] - e0[0] * e1[2],
            e0[0] * e1[1] - e0[1] * e1[0],
        ]

        # Interleave so that e0, e1, e2 become the columns of the matrix
        rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
        rots = rots.reshape(rots.shape[:-1] + (3, 3))

        rot_obj = Rotation(rot_mats=rots, quats=None)

        return Rigid(rot_obj, torch.stack(origin, dim=-1))

    def unsqueeze(self,
        dim: int,
    ):
        """
            Analogous to torch.unsqueeze. The dimension is relative to the
            shared dimensions of the rotation/translation.

            Args:
                dim: A positive or negative dimension index.
            Returns:
                The unsqueezed transformation.
            Raises:
                ValueError: If dim >= len(self.shape)
        """
        if dim >= len(self.shape):
            raise ValueError("Invalid dimension")
        rots = self._rots.unsqueeze(dim)
        trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)

        return Rigid(rots, trans)

    @staticmethod
    def cat(
        ts,
        dim: int,
    ):
        """
            Concatenates transformations along an existing batch dimension.

            Args:
                ts:
                    A list of T objects
                dim:
                    The dimension along which the transformations should be
                    concatenated
            Returns:
                A concatenated transformation object
        """
        rots = Rotation.cat([t._rots for t in ts], dim)
        trans = torch.cat(
            [t._trans for t in ts], dim=dim if dim >= 0 else dim - 1
        )

        return Rigid(rots, trans)

    def apply_rot_fn(self, fn):
        """
            Applies a Rotation -> Rotation function to the stored rotation
            object.

            Args:
                fn: A function of type Rotation -> Rotation
            Returns:
                A transformation object with a transformed rotation.
        """
        return Rigid(fn(self._rots), self._trans)

    def apply_trans_fn(self, fn):
        """
            Applies a Tensor -> Tensor function to the stored translation.

            Args:
                fn:
                    A function of type Tensor -> Tensor to be applied to the
                    translation
            Returns:
                A transformation object with a transformed translation.
        """
        return Rigid(self._rots, fn(self._trans))

    def scale_translation(self, trans_scale_factor: float):
        """
            Scales the translation by a constant factor.

            Args:
                trans_scale_factor:
                    The constant factor
            Returns:
                A transformation object with a scaled translation.
        """
        return self.apply_trans_fn(lambda t: t * trans_scale_factor)

    def stop_rot_gradient(self):
        """
            Detaches the underlying rotation object

            Returns:
                A transformation object with detached rotations
        """
        return self.apply_rot_fn(lambda r: r.detach())

    @staticmethod
    def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
        """
            Returns a transformation object from reference coordinates.

            Note that this method does not take care of symmetries. If you
            provide the atom positions in the non-standard way, the N atom will
            end up not at [-0.527250, 1.359329, 0.0] but instead at
            [-0.527250, -1.359329, 0.0]. You need to take care of such cases in
            your code.

            Args:
                n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
                ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
                c_xyz: A [*, 3] tensor of carbon xyz coordinates.
                eps: Small value added under the square roots.
            Returns:
                A transformation object. After applying the translation and
                rotation to the reference backbone, the coordinates will
                approximately equal to the input coordinates.
        """
        # Move CA to the origin
        translation = -1 * ca_xyz
        n_xyz = n_xyz + translation
        c_xyz = c_xyz + translation

        # First rotation: about z, bringing C into the xz-plane
        c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
        norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2)
        sin_c1 = -c_y / norm
        cos_c1 = c_x / norm

        c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
        c1_rots[..., 0, 0] = cos_c1
        c1_rots[..., 0, 1] = -1 * sin_c1
        c1_rots[..., 1, 0] = sin_c1
        c1_rots[..., 1, 1] = cos_c1
        c1_rots[..., 2, 2] = 1

        # Second rotation: about y, bringing C onto the x-axis
        norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2)
        sin_c2 = c_z / norm
        cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm

        c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
        c2_rots[..., 0, 0] = cos_c2
        c2_rots[..., 0, 2] = sin_c2
        c2_rots[..., 1, 1] = 1
        # The last row belongs to c2_rots. It was previously written into
        # c1_rots, which corrupted the first rotation and left c2_rots with
        # an all-zero third row.
        c2_rots[..., 2, 0] = -1 * sin_c2
        c2_rots[..., 2, 2] = cos_c2

        c_rots = rot_matmul(c2_rots, c1_rots)
        n_xyz = rot_vec_mul(c_rots, n_xyz)

        # Third rotation: about x, bringing N into the xy-plane
        _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
        norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2)
        sin_n = -n_z / norm
        cos_n = n_y / norm

        n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
        n_rots[..., 0, 0] = 1
        n_rots[..., 1, 1] = cos_n
        n_rots[..., 1, 2] = -1 * sin_n
        n_rots[..., 2, 1] = sin_n
        n_rots[..., 2, 2] = cos_n

        rots = rot_matmul(n_rots, c_rots)

        # Invert: we want the map from the reference frame to the input
        rots = rots.transpose(-1, -2)
        translation = -1 * translation

        rot_obj = Rotation(rot_mats=rots, quats=None)

        return Rigid(rot_obj, translation)

    def cuda(self):
        """
            Moves the transformation object to GPU memory

            Returns:
                A version of the transformation on GPU
        """
        return Rigid(self._rots.cuda(), self._trans.cuda())


# ---- patch metadata and header of the next file in this dump, kept from the
# ---- original text (analysis/src/common/rotation3d.py) ----
# diff --git a/analysis/src/common/rotation3d.py b/analysis/src/common/rotation3d.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..2d9febabf302b0b81d8b97cece719570aeb6402c
# --- /dev/null
# +++ b/analysis/src/common/rotation3d.py
# @@ -0,0 +1,596 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Union

import torch
from torch.nn import functional as F

Device = Union[str, torch.device]


"""
The transformation matrices returned from the functions in this file assume
the points on which the transformation will be applied are column vectors.
i.e. the R matrix is structured as

    R = [
            [Rxx, Rxy, Rxz],
            [Ryx, Ryy, Ryz],
            [Rzx, Rzy, Rzz],
        ]  # (3, 3)

This matrix can be applied to column vectors by post multiplication
by the points e.g.

    points = [[0], [1], [2]]  # (3 x 1) xyz coordinates of a point
    transformed_points = R * points

To apply the same matrix to points which are row vectors, the R matrix
can be transposed and pre multiplied by the points:

e.g.
    points = [[0, 1, 2]]  # (1 x 3) xyz coordinates of a point
    transformed_points = points * R.transpose(1, 0)
"""


def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    # two_s is 2 / |q|^2: dividing by the squared norm here means the
    # formula below does not rely on the input being normalized.
    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    # The nine entries are listed in row-major order and reshaped to (3, 3).
    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Return a tensor where each element has the absolute value taken from the
    corresponding element of a, with sign taken from the corresponding
    element of b. This is like the standard copysign floating-point operation,
    but is not careful about negative 0 and NaN.

    Args:
        a: source tensor.
        b: tensor whose signs will be used, of the same shape as a.

    Returns:
        Tensor of the same shape as a with the signs of b.
    """
    signs_differ = (a < 0) != (b < 0)
    return torch.where(signs_differ, -a, a)


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    # sqrt is only evaluated on the strictly positive entries; every other
    # entry keeps the 0 it was initialised with.
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret


def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: if the last two dimensions of `matrix` are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)

    # The one-hot of the argmax is used as a boolean mask that keeps exactly
    # one of the four candidates per matrix.
    return quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
    ].reshape(batch_dim + (4,))


def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Return the rotation matrices for one of the rotations about an axis
    of which Euler angles describe, for each value of the angle given.

    Args:
        axis: Axis label "X", "Y" or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: if `axis` is not one of "X", "Y", "Z".
    """

    cos = torch.cos(angle)
    sin = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    # Each tuple is the 3x3 matrix flattened in row-major order.
    if axis == "X":
        R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
    elif axis == "Y":
        R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
    elif axis == "Z":
        R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
    else:
        raise ValueError("letter must be either X, Y or Z.")

    return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))


def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: if `euler_angles` does not end in a dimension of size 3,
            or if `convention` is not three letters from X/Y/Z with the
            middle letter different from both of its neighbours.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    matrices = [
        _axis_angle_rotation(c, e)
        for c, e in zip(convention, torch.unbind(euler_angles, -1))
    ]
    # return functools.reduce(torch.matmul, matrices)
    return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])


def _angle_from_tan(
    axis: str, other_axis: str, data: torch.Tensor, horizontal: bool, tait_bryan: bool
) -> torch.Tensor:
    """
    Extract the first or third Euler angle from the two members of
    the matrix which are positive constant times its sine and cosine.

    Args:
        axis: Axis label "X", "Y" or "Z" for the angle we are finding.
        other_axis: Axis label "X", "Y" or "Z" for the middle axis in the
            convention.
        data: Rotation matrices as tensor of shape (..., 3, 3).
        horizontal: Whether we are looking for the angle for the third axis,
            which means the relevant entries are in the same row of the
            rotation matrix. If not, they are in the same column.
        tait_bryan: Whether the first and third axes in the convention differ.

    Returns:
        Euler Angles in radians for each matrix in data as a tensor
        of shape (...).
    """
    # NOTE(review): the only caller visible in this file passes a single
    # row/column slice of the matrix (last dim 3), not a full (..., 3, 3)
    # tensor as the Args section states -- confirm the intended shape.

    i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
    if horizontal:
        i2, i1 = i1, i2
    even = (axis + other_axis) in ["XY", "YZ", "ZX"]
    if horizontal == even:
        return torch.atan2(data[..., i1], data[..., i2])
    if tait_bryan:
        return torch.atan2(-data[..., i2], data[..., i1])
    return torch.atan2(data[..., i2], -data[..., i1])


def _index_from_letter(letter: str) -> int:
    """
    Map an axis label "X", "Y" or "Z" to the index 0, 1 or 2.

    Raises:
        ValueError: if `letter` is not one of "X", "Y", "Z".
    """
    if letter == "X":
        return 0
    if letter == "Y":
        return 1
    if letter == "Z":
        return 2
    raise ValueError("letter must be either X, Y or Z.")


def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to Euler angles in radians.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).

    Raises:
        ValueError: if `convention` is malformed or `matrix` does not end
            in dimensions (3, 3).
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    i0 = _index_from_letter(convention[0])
    i2 = _index_from_letter(convention[2])
    tait_bryan = i0 != i2
    # The middle angle is read from a single matrix entry: via asin when the
    # outer axes differ, via acos when they coincide.
    if tait_bryan:
        central_angle = torch.asin(
            matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
        )
    else:
        central_angle = torch.acos(matrix[..., i0, i0])

    o = (
        _angle_from_tan(
            convention[0], convention[1], matrix[..., i2], False, tait_bryan
        ),
        central_angle,
        _angle_from_tan(
            convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
        ),
    )
    return torch.stack(o, -1)


def random_quaternions(
    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """
    Generate random quaternions representing rotations,
    i.e. versors with nonnegative real part.

    Args:
        n: Number of quaternions in a batch to return.
        dtype: Type to return.
        device: Desired device of returned tensor. Default:
            uses the current device for the default tensor type.

    Returns:
        Quaternions as tensor of shape (N, 4).
    """
    if isinstance(device, str):
        device = torch.device(device)
    o = torch.randn((n, 4), dtype=dtype, device=device)
    s = (o * o).sum(1)
    # Dividing by the norm carrying the sign of the real part normalizes the
    # sample and makes its real part nonnegative in one step.
    o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
    return o


def random_rotations(
    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """
    Generate random rotations as 3x3 rotation matrices.

    Args:
        n: Number of rotation matrices in a batch to return.
        dtype: Type to return.
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type.

    Returns:
        Rotation matrices as tensor of shape (n, 3, 3).
    """
    quaternions = random_quaternions(n, dtype=dtype, device=device)
    return quaternion_to_matrix(quaternions)


def random_rotation(
    dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """
    Generate a single random 3x3 rotation matrix.

    Args:
        dtype: Type to return
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type

    Returns:
        Rotation matrix as tensor of shape (3, 3).
    """
    return random_rotations(1, dtype, device)[0]


def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)


def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Multiply two quaternions.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions shape (..., 4).
    """
    aw, ax, ay, az = torch.unbind(a, -1)
    bw, bx, by, bz = torch.unbind(b, -1)
    ow = aw * bw - ax * bx - ay * by - az * bz
    ox = aw * bx + ax * bw + ay * bz - az * by
    oy = aw * by - ax * bz + ay * bw + az * bx
    oz = aw * bz + ax * by - ay * bx + az * bw
    return torch.stack((ow, ox, oy, oz), -1)


def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Multiply two quaternions representing rotations, returning the quaternion
    representing their composition, i.e. the versor with nonnegative real part.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions of shape (..., 4).
    """
    ab = quaternion_raw_multiply(a, b)
    return standardize_quaternion(ab)


def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Given a quaternion representing rotation, get the quaternion representing
    its inverse.

    Args:
        quaternion: Quaternions as tensor of shape (..., 4), with real part
            first, which must be versors (unit quaternions).

    Returns:
        The inverse, a tensor of quaternions of shape (..., 4).
    """

    # Only the three imaginary components are negated; no renormalization
    # is performed here.
    scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device)
    return quaternion * scaling


def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
    """
    Apply the rotation given by a quaternion to a 3D point.
    Usual torch rules for broadcasting apply.

    Args:
        quaternion: Tensor of quaternions, real part first, of shape (..., 4).
        point: Tensor of 3D points of shape (..., 3).

    Returns:
        Tensor of rotated points of shape (..., 3).

    Raises:
        ValueError: if the last dimension of `point` is not 3.
    """
    if point.size(-1) != 3:
        raise ValueError(f"Points are not in 3D, {point.shape}.")
    # Embed the point as a quaternion with zero real part, multiply it by
    # the quaternion on the left and the inverted quaternion on the right,
    # then drop the real part again.
    real_parts = point.new_zeros(point.shape[:-1] + (1,))
    point_as_quaternion = torch.cat((real_parts, point), -1)
    out = quaternion_raw_multiply(
        quaternion_raw_multiply(quaternion, point_as_quaternion),
        quaternion_invert(quaternion),
    )
    return out[..., 1:]


def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as axis/angle to rotation matrices.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))


def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to axis/angle.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
            of shape (..., 3), where the magnitude is the angle
            turned anticlockwise in radians around the vector's
            direction.
    """
    return quaternion_to_axis_angle(matrix_to_quaternion(matrix))


def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as axis/angle to quaternions.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = angles * 0.5
    eps = 1e-6
    small_angles = angles.abs() < eps
    # Below eps the quotient sin(x/2)/x is replaced by its series expansion
    # instead of being evaluated directly.
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    quaternions = torch.cat(
        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
    )
    return quaternions


def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to axis/angle.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
            of shape (..., 3), where the magnitude is the angle
            turned anticlockwise in radians around the vector's
            direction.
    """
    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    half_angles = torch.atan2(norms, quaternions[..., :1])
    angles = 2 * half_angles
    eps = 1e-6
    small_angles = angles.abs() < eps
    # Same small-angle treatment as in axis_angle_to_quaternion: below eps
    # the quotient sin(x/2)/x comes from its series expansion.
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    return quaternions[..., 1:] / sin_half_angles_over_angles


def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """
    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
    using Gram--Schmidt orthogonalization per Section B of [1].
    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """

    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    # Remove from a2 its component along b1, then normalize the remainder.
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    # b1, b2, b3 become the rows of the returned matrix (stacked on dim -2).
    return torch.stack((b1, b2, b3), dim=-2)


def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
    """
    Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
    by dropping the last row. Note that 6D representation is not unique.
    Args:
        matrix: batch of rotation matrices of size (*, 3, 3)

    Returns:
        6D rotation representation, of size (*, 6)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    batch_dim = matrix.size()[:-2]
    # clone() so the returned tensor does not share storage with `matrix`.
    return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
class ProteinFeatureTransform:
    """Per-chain feature pipeline applied to one protein feature dict.

    Calling an instance runs, in order: `patch_feats`, optional
    `strip_ends`, optional `random_truncate`, optional
    `recenter_and_scale_coords`, `map_to_tensors` and
    `protein_data_transform`.

    Args:
        unit: Coordinate unit of the output. 'angstrom' keeps the input
            scale (factor 1.0); 'nm' / 'nanometer' multiplies by 0.1.
        truncate_length: If given, chains longer than this are randomly
            cropped to this many residues. Must be positive.
        strip_missing_residues: Whether to run `strip_ends`.
        recenter_and_scale: Whether to run `recenter_and_scale_coords`.
        eps: Small constant added to the mask sum when computing the center.

    Raises:
        ValueError: if `unit` is not one of the supported names.
    """

    def __init__(self,
                 unit: Optional[str] = 'angstrom',
                 truncate_length: Optional[int] = None,
                 strip_missing_residues: bool = True,
                 recenter_and_scale: bool = True,
                 eps: float = 1e-8,
    ):
        if unit == 'angstrom':
            self.coordinate_scale = 1.0
        elif unit in ('nm', 'nanometer'):
            # Fixed: this branch used to assign the misspelled attribute
            # `coordiante_scale`, leaving `coordinate_scale` undefined.
            self.coordinate_scale = 0.1
        else:
            raise ValueError(f"Invalid unit: {unit}")

        if truncate_length is not None:
            assert truncate_length > 0, f"Invalid truncate_length: {truncate_length}"
        # Always assigned (None disables truncation) because __call__ reads
        # the attribute unconditionally.
        self.truncate_length = truncate_length

        self.strip_missing_residues = strip_missing_residues
        self.recenter_and_scale = recenter_and_scale
        self.eps = eps

    def __call__(self, chain_feats):
        """Run the full pipeline on `chain_feats` and return the result."""
        chain_feats = self.patch_feats(chain_feats)

        if self.strip_missing_residues:
            chain_feats = self.strip_ends(chain_feats)

        if self.truncate_length is not None:
            chain_feats = self.random_truncate(chain_feats, max_len=self.truncate_length)

        # Recenter and scale atom positions
        if self.recenter_and_scale:
            chain_feats = self.recenter_and_scale_coords(
                chain_feats, coordinate_scale=self.coordinate_scale, eps=self.eps
            )

        # Map to torch Tensor
        chain_feats = self.map_to_tensors(chain_feats)
        # Add extra features from AF2
        chain_feats = self.protein_data_transform(chain_feats)

        # ** refer to line 170 in pdb_data_loader.py **
        return chain_feats

    @staticmethod
    def patch_feats(chain_feats):
        """Add mask / index features derived from the CA column of `atom_mask`."""
        seq_mask = chain_feats['atom_mask'][:, CA_IDX]  # a little hack here
        # residue_idx = np.arange(seq_mask.shape[0], dtype=np.int64)
        # start from 0, possibly has chain break
        residue_idx = chain_feats['residue_index'] - np.min(chain_feats['residue_index'])
        patch_feats = {
            'seq_mask': seq_mask,
            'residue_mask': seq_mask,
            'residue_idx': residue_idx,
            'fixed_mask': np.zeros_like(seq_mask),
            'sc_ca_t': np.zeros(seq_mask.shape + (3, )),
        }
        chain_feats.update(patch_feats)
        return chain_feats

    @staticmethod
    def strip_ends(chain_feats):
        """Crop every feature to the span between the first and last modeled residue."""
        # Strip missing residues on both ends
        # NOTE(review): 20 is presumably the "unknown residue" aatype index -- confirm.
        modeled_idx = np.where(chain_feats['aatype'] != 20)[0]
        min_idx, max_idx = np.min(modeled_idx), np.max(modeled_idx)
        chain_feats = tree.map_structure(
            lambda x: x[min_idx : (max_idx+1)], chain_feats)
        return chain_feats

    @staticmethod
    def random_truncate(chain_feats, max_len):
        """Randomly crop all features to `max_len` residues if the chain is longer."""
        L = chain_feats['aatype'].shape[0]
        if L > max_len:
            # Randomly truncate
            start = np.random.randint(0, L - max_len + 1)
            end = start + max_len
            chain_feats = tree.map_structure(
                lambda x: x[start : end], chain_feats)
        return chain_feats

    @staticmethod
    def map_to_tensors(chain_feats):
        """Convert every value to a torch tensor and apply `DTYPE_MAPPING`."""
        chain_feats = {k: torch.as_tensor(v) for k, v in chain_feats.items()}
        # Alter dtype
        for k, dtype in DTYPE_MAPPING.items():
            if k in chain_feats:
                chain_feats[k] = chain_feats[k].type(dtype)
        return chain_feats

    @staticmethod
    def recenter_and_scale_coords(chain_feats, coordinate_scale, eps=1e-8):
        """Center atom positions on the CA centroid, scale them, and zero masked atoms."""
        # recenter and scale atom positions
        bb_pos = chain_feats['atom_positions'][:, CA_IDX]
        # The denominator counts residues flagged in seq_mask; eps guards
        # against division by zero when none are.
        bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['seq_mask']) + eps)
        centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
        scaled_pos = centered_pos * coordinate_scale
        chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
        return chain_feats

    @staticmethod
    def protein_data_transform(chain_feats):
        """Run the `data_transforms` feature builders on `chain_feats`."""
        # The transforms read the all_atom_* keys, so alias them temporarily.
        chain_feats.update(
            {
                "all_atom_positions": chain_feats["atom_positions"],
                "all_atom_mask": chain_feats["atom_mask"],
            }
        )
        chain_feats = data_transforms.atom37_to_frames(chain_feats)
        chain_feats = data_transforms.atom37_to_torsion_angles("")(chain_feats)
        chain_feats = data_transforms.get_backbone_frames(chain_feats)
        chain_feats = data_transforms.get_chi_angles(chain_feats)
        chain_feats = data_transforms.make_pseudo_beta("")(chain_feats)
        chain_feats = data_transforms.make_atom14_masks(chain_feats)
        chain_feats = data_transforms.make_atom14_positions(chain_feats)

        # Drop the temporary aliases again
        chain_feats.pop("all_atom_positions")
        chain_feats.pop("all_atom_mask")
        return chain_feats


class MetadataFilter:
    """Row filter for a metadata DataFrame.

    Every criterion left as None is skipped. Extra keyword arguments are
    accepted and ignored.
    """

    def __init__(self,
                 min_len: Optional[int] = None,
                 max_len: Optional[int] = None,
                 min_chains: Optional[int] = None,
                 max_chains: Optional[int] = None,
                 min_resolution: Optional[float] = None,
                 max_resolution: Optional[float] = None,
                 include_structure_method: Optional[List[str]] = None,
                 include_oligomeric_detail: Optional[List[str]] = None,
                 **kwargs,
    ):
        self.min_len = min_len
        self.max_len = max_len
        self.min_chains = min_chains
        self.max_chains = max_chains
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.include_structure_method = include_structure_method
        self.include_oligomeric_detail = include_oligomeric_detail

    def __call__(self, df):
        """Return the rows of `df` that satisfy every configured criterion."""
        _pre_filter_len = len(df)
        if self.min_len is not None:
            df = df[df['raw_seq_len'] >= self.min_len]
        if self.max_len is not None:
            df = df[df['raw_seq_len'] <= self.max_len]
        if self.min_chains is not None:
            df = df[df['num_chains'] >= self.min_chains]
        if self.max_chains is not None:
            df = df[df['num_chains'] <= self.max_chains]
        if self.min_resolution is not None:
            df = df[df['resolution'] >= self.min_resolution]
        if self.max_resolution is not None:
            df = df[df['resolution'] <= self.max_resolution]
        # NOTE(review): these two column names repeat the option names
        # ('include_...'); verify they match the metadata CSV schema.
        if self.include_structure_method is not None:
            df = df[df['include_structure_method'].isin(self.include_structure_method)]
        if self.include_oligomeric_detail is not None:
            df = df[df['include_oligomeric_detail'].isin(self.include_oligomeric_detail)]

        # NOTE: the first number printed is the count of rows that remain.
        print(f">>> Filter out {len(df)} samples out of {_pre_filter_len} by the metadata filter")
        return df


class RandomAccessProteinDataset(torch.utils.data.Dataset):
    """Random access to pickle protein objects of dataset.

    dict_keys(['atom_positions', 'aatype', 'atom_mask', 'residue_index', 'chain_index', 'b_factors'])

    Note that each value is a ndarray in shape (L, *), for example:
        'atom_positions': (L, 37, 3)

    Args:
        path_to_dataset: One of (a) a metadata `.csv` file with the columns
            'modeled_seq_len' and 'processed_complex_path', (b) a directory
            searched for `*<suffix>` files, or (c) a glob pattern.
        path_to_seq_embedding: Optional directory holding one
            `<accession_code>.pt` embedding file per sample.
        metadata_filter: Optional callable applied to the metadata frame
            (csv mode only).
        training: Stored on the instance; not used in this class.
        transform: Optional callable applied to every loaded feature dict.
        suffix: '.pkl' or '.pdb' (the leading dot may be omitted).
        accession_code_fillter: Optional collection of accession codes (file
            name without extension); when non-empty only these are kept.
    """
    def __init__(self,
                 path_to_dataset: Union[Path, str],
                 path_to_seq_embedding: Optional[Path] = None,
                 metadata_filter: Optional[MetadataFilter] = None,
                 training: bool = True,
                 transform: Optional[ProteinFeatureTransform] = None,
                 suffix: Optional[str] = '.pkl',
                 accession_code_fillter: Optional[Sequence[str]] = None,
                 **kwargs,
    ):
        super().__init__()
        path_to_dataset = os.path.expanduser(path_to_dataset)
        suffix = suffix if suffix.startswith('.') else '.' + suffix
        assert suffix in ('.pkl', '.pdb'), f"Invalid suffix: {suffix}"

        if os.path.isfile(path_to_dataset):  # path to csv file
            assert path_to_dataset.endswith('.csv'), f"Invalid file extension: {path_to_dataset} (have to be .csv)"
            self._df = pd.read_csv(path_to_dataset)
            # Fixed: sort_values returns a new frame, so the result must be
            # kept; it used to be discarded. mergesort keeps ties in file order.
            self._df = self._df.sort_values(
                'modeled_seq_len', ascending=False, kind='mergesort'
            )
            if metadata_filter:
                self._df = metadata_filter(self._df)
            self._data = self._df['processed_complex_path'].tolist()
        elif os.path.isdir(path_to_dataset):  # path to directory
            self._data = sorted(glob(os.path.join(path_to_dataset, '*' + suffix)))
            assert len(self._data) > 0, f"No {suffix} file found in '{path_to_dataset}'"
        else:  # path as glob pattern
            _pattern = path_to_dataset
            self._data = sorted(glob(_pattern))
            assert len(self._data) > 0, f"No files found in '{_pattern}'"

        if accession_code_fillter is not None and len(accession_code_fillter) > 0:
            # Build the lookup once instead of calling np.isin per file.
            allowed_codes = set(accession_code_fillter)
            self._data = [
                p for p in self._data
                if os.path.splitext(os.path.basename(p))[0] in allowed_codes
            ]

        self.data = np.asarray(self._data)
        self.path_to_seq_embedding = os.path.expanduser(path_to_seq_embedding) \
            if path_to_seq_embedding is not None else None
        self.suffix = suffix
        self.transform = transform
        self.training = training  # not implemented yet

    @property
    def num_samples(self):
        """Number of sample paths held by the dataset."""
        return len(self.data)

    def len(self):
        """Alias of `__len__`."""
        return self.__len__()

    def __len__(self):
        return self.num_samples

    def get(self, idx):
        """Alias of `__getitem__`."""
        return self.__getitem__(idx)

    # NOTE(review): lru_cache on a method keys on `self` (keeps the instance
    # alive) and returns the *same* dict object on a cache hit, so a random
    # crop produced by the transform is frozen while an index stays cached.
    # Kept as-is to preserve current behavior -- confirm this is intended.
    @lru_cache(maxsize=100)
    def __getitem__(self, idx):
        """Load sample `idx` and return it as a feature dict.

        The dict holds the loaded (and, if a transform is set, transformed)
        features, 'seq_emb' when an embedding directory is configured, and
        'accession_code' (the file name without its extension).
        """
        data_path = self.data[idx]
        accession_code = os.path.splitext(os.path.basename(data_path))[0]

        if self.suffix == '.pkl':
            # Load pickled protein
            # SECURITY: unpickling executes arbitrary code; only point this
            # dataset at files from a trusted source.
            with open(data_path, 'rb') as f:
                data_object = pickle.load(f)
        elif self.suffix == '.pdb':
            # Load pdb file
            with open(data_path, 'r') as f:
                pdb_string = f.read()
            data_object = protein.from_pdb_string(pdb_string).to_dict()

        # Apply data transform
        if self.transform is not None:
            data_object = self.transform(data_object)

        # Get sequence embedding if have
        if self.path_to_seq_embedding is not None:
            # SECURITY: torch.load is pickle-based as well; trusted files only.
            embed_dict = torch.load(
                os.path.join(self.path_to_seq_embedding, f"{accession_code}.pt")
            )
            data_object.update(
                {
                    'seq_emb': embed_dict['representations'][33].float(),
                }  # 33 is for ESM650M
            )

        data_object['accession_code'] = accession_code
        return data_object  # dict of arrays


class PretrainPDBDataset(RandomAccessProteinDataset):
    """`RandomAccessProteinDataset` with a required metadata filter and transform."""

    def __init__(self,
                 path_to_dataset: str,
                 metadata_filter: MetadataFilter,
                 transform: ProteinFeatureTransform,
                 **kwargs,
    ):
        super(PretrainPDBDataset, self).__init__(path_to_dataset=path_to_dataset,
                                                 metadata_filter=metadata_filter,
                                                 transform=transform,
                                                 **kwargs,
        )


class SamplingPDBDataset(RandomAccessProteinDataset):
    """Directory-only dataset: '.pdb' suffix by default, no metadata filter."""

    def __init__(self,
                 path_to_dataset: str,
                 training: bool = False,
                 suffix: str = '.pdb',
                 transform: Optional[ProteinFeatureTransform] = None,
                 accession_code_fillter: Optional[Sequence[str]] = None,
    ):
        assert os.path.isdir(path_to_dataset), f"Invalid path (expected to be directory): {path_to_dataset}"
        super(SamplingPDBDataset, self).__init__(path_to_dataset=path_to_dataset,
                                                 training=training,
                                                 suffix=suffix,
                                                 transform=transform,
                                                 accession_code_fillter=accession_code_fillter,
                                                 metadata_filter=None,
        )
"""Callable to convert an unprocessed (labels + strings) batch to a + processed (labels + tensor) batch. + """ + def __init__(self, target_keys: Optional[List] = None): + self.target_keys = target_keys + + def __call__(self, raw_batch: Sequence[Dict[str, object]]): + B = len(raw_batch) + # Only do for Tensor + target_keys = self.target_keys \ + if self.target_keys is not None else [k for k,v in raw_batch[0].items() if torch.is_tensor(v)] + # Non-array, for example string, int + non_array_keys = [k for k in raw_batch[0] if k not in target_keys] + collated_batch = dict() + for k in target_keys: + collated_batch[k] = self.collate_dense_tensors([d[k] for d in raw_batch], pad_v=0.0) + for k in non_array_keys: # return non-array keys as is + collated_batch[k] = [d[k] for d in raw_batch] + return collated_batch + + @staticmethod + def collate_dense_tensors(samples: Sequence, pad_v: float = 0.0): + """ + Takes a list of tensors with the following dimensions: + [(d_11, ..., d_1K), + (d_21, ..., d_2K), + ..., + (d_N1, ..., d_NK)] + and stack + pads them into a single tensor of: + (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK}) + """ + if len(samples) == 0: + return torch.Tensor() + if len(set(x.dim() for x in samples)) != 1: + raise RuntimeError( + f"Samples has varying dimensions: {[x.dim() for x in samples]}" + ) + (device,) = tuple(set(x.device for x in samples)) # assumes all on same device + max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])] + result = torch.empty( + len(samples), *max_shape, dtype=samples[0].dtype, device=device + ) + result.fill_(pad_v) + for i in range(len(samples)): + result_i = result[i] + t = samples[i] + result_i[tuple(slice(0, k) for k in t.shape)] = t + return result + + +class ProteinDataModule(LightningDataModule): + """`LightningDataModule` for a single protein dataset, + for pretrain or finetune purpose. 
+ + ### To be revised.### + + The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples. + It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a + fixed-size image. The original black and white images from NIST were size normalized to fit in a 20x20 pixel box + while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing + technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of + mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field. + + A `LightningDataModule` implements 7 key methods: + + ```python + def prepare_data(self): + # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP). + # Download data, pre-process, split, save to disk, etc... + + def setup(self, stage): + # Things to do on every process in DDP. + # Load data, set variables, etc... + + def train_dataloader(self): + # return train dataloader + + def val_dataloader(self): + # return validation dataloader + + def test_dataloader(self): + # return test dataloader + + def predict_dataloader(self): + # return predict dataloader + + def teardown(self, stage): + # Called on every process in DDP. + # Clean up after fit or test. + ``` + + This allows you to share a full dataset without explaining how to download, + split, transform and process the data. + + Read the docs: + https://lightning.ai/docs/pytorch/latest/data/datamodule.html + """ + + def __init__( + self, + dataset: torch.utils.data.Dataset, + batch_size: int = 64, + generator_seed: int = 42, + train_val_split: Tuple[float, float] = (0.95, 0.05), + num_workers: int = 0, + pin_memory: bool = False, + shuffle: bool = False, + ) -> None: + """Initialize a `MNISTDataModule`. + + :param data_dir: The data directory. Defaults to `"data/"`. 
+ :param train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`. + :param batch_size: The batch size. Defaults to `64`. + :param num_workers: The number of workers. Defaults to `0`. + :param pin_memory: Whether to pin memory. Defaults to `False`. + """ + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) + + self.dataset = dataset + + self.data_train: Optional[Dataset] = None + self.data_val: Optional[Dataset] = None + self.data_test: Optional[Dataset] = None + + self.batch_size_per_device = batch_size + + def prepare_data(self) -> None: + """Download data if needed. Lightning ensures that `self.prepare_data()` is called only + within a single process on CPU, so you can safely add your downloading logic within. In + case of multi-node training, the execution of this hook depends upon + `self.prepare_data_per_node()`. + + Do not use it to assign state (self.x = y). + """ + pass + + def setup(self, stage: Optional[str] = None) -> None: + """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. + + This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and + `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after + `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to + `self.setup()` once the data is prepared and available for use. + + :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``. + """ + # Divide batch size by the number of devices. + if self.trainer is not None: + if self.hparams.batch_size % self.trainer.world_size != 0: + raise RuntimeError( + f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})." 
+ ) + self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size + + # load and split datasets only if not loaded already + if stage == 'fit' and not self.data_train and not self.data_val: + # dataset = ConcatDataset(datasets=[trainset, testset]) + self.data_train, self.data_val = random_split( + dataset=self.dataset, + lengths=self.hparams.train_val_split, + generator=torch.Generator().manual_seed(self.hparams.generator_seed), + ) + elif stage in ('predict', 'test'): + self.data_test = self.dataset + else: + raise NotImplementedError(f"Stage {stage} not implemented.") + + def _dataloader_template(self, dataset: Dataset[Any]) -> DataLoader[Any]: + """Create a dataloader from a dataset. + + :param dataset: The dataset. + :return: The dataloader. + """ + batch_collator = BatchTensorConverter() # list of dicts -> dict of tensors + return DataLoader( + dataset=dataset, + collate_fn=batch_collator, + batch_size=self.batch_size_per_device, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=self.hparams.shuffle, + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Create and return the train dataloader. + + :return: The train dataloader. + """ + return self._dataloader_template(self.data_train) + + + def val_dataloader(self) -> DataLoader[Any]: + """Create and return the validation dataloader. + + :return: The validation dataloader. + """ + return self._dataloader_template(self.data_val) + + def test_dataloader(self) -> DataLoader[Any]: + """Create and return the test dataloader. + + :return: The test dataloader. + """ + return self._dataloader_template(self.data_test) + + def teardown(self, stage: Optional[str] = None) -> None: + """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, + `trainer.test()`, and `trainer.predict()`. + + :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. + Defaults to ``None``. 
+ """ + pass + + def state_dict(self) -> Dict[Any, Any]: + """Called when saving a checkpoint. Implement to generate and save the datamodule state. + + :return: A dictionary containing the datamodule state that you want to save. + """ + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Called when loading a checkpoint. Implement to reload datamodule state given datamodule + `state_dict()`. + + :param state_dict: The datamodule state returned by `self.state_dict()`. + """ + pass + diff --git a/analysis/src/eval.py b/analysis/src/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..9a6e8462fb3c86c6ebee79d76614d16fee2d5b09 --- /dev/null +++ b/analysis/src/eval.py @@ -0,0 +1,217 @@ +from typing import Any, Dict, List, Tuple +import os +from time import strftime + +import numpy as np +import pandas as pd +import torch +# import hydra +# import rootutils +# from lightning import LightningDataModule, LightningModule, Trainer +# from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +# rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +# ------------------------------------------------------------------------------------ # +# the setup_root above is equivalent to: +# - adding project root dir to PYTHONPATH +# (so you don't need to force user to install project as a package) +# (necessary before importing any local modules e.g. `from src import utils`) +# - setting up PROJECT_ROOT environment variable +# (which is used as a base for paths in "configs/paths/default.yaml") +# (this way all filepaths are the same no matter where you run the code) +# - loading environment variables from ".env" in root dir +# +# you can remove it if you: +# 1. either install project as a package or move entry files to project root dir +# 2. set `root_dir` to "." 
in "configs/paths/default.yaml" +# +# more info: https://github.com/ashleve/rootutils +# ------------------------------------------------------------------------------------ # + +from src.utils import ( + RankedLogger, + extras, + instantiate_loggers, + log_hyperparameters, + task_wrapper, + checkpoint_utils, + plot_utils, +) +from src.common.pdb_utils import extract_backbone_coords +from src.metrics import metrics +from src.common.geo_utils import _find_rigid_alignment + +log = RankedLogger(__name__, rank_zero_only=True) + + +def evaluate_prediction(pred_dir: str, target_dir: str = None, crystal_dir: str = None, tag: str = None): + """Evaluate prediction results based on pdb files. + """ + if target_dir is None or not os.path.isdir(target_dir): + log.warning(f"target_dir {target_dir} does not exist. Skip evaluation.") + return {} + + assert os.path.isdir(pred_dir), f"pred_dir {pred_dir} is not a directory." + + targets = [ + d.replace(".pdb", "") for d in os.listdir(target_dir) + ] + # pred_bases = os.listdir(pred_dir) + output_dir = pred_dir + tag = tag if tag is not None else "dev" + timestamp = strftime("%m%d-%H-%M") + + fns = { + 'val_clash': metrics.validity, + 'val_bond': metrics.bonding_validity, + 'js_pwd': metrics.js_pwd, + 'js_rg': metrics.js_rg, + # 'js_tica_pos': metrics.js_tica_pos, + 'w2_rmwd': metrics.w2_rmwd, + # 'div_rmsd': metrics.div_rmsd, + 'div_rmsf': metrics.div_rmsf, + 'pro_w_contacks': metrics.pro_w_contacts, + 'pro_t_contacks': metrics.pro_t_contacts, + # 'pro_c_contacks': metrics.pro_c_contacts, + } + eval_res = {k: {} for k in fns} + + + print(f"total_md_num = {len(targets)}") + count = 0 + for target in targets: + count += 1 + print("") + print(count, target) + pred_file = os.path.join(pred_dir, f"{target}.pdb") + # assert os.path.isfile(pred_file), f"pred_file {pred_file} does not exist." 
+ if not os.path.isfile(pred_file): + continue + + target_file = os.path.join(target_dir, f"{target}.pdb") + ca_coords = { + 'target': extract_backbone_coords(target_file), + 'pred': extract_backbone_coords(pred_file), + } + cry_target_file = os.path.join(crystal_dir, f"{target}.pdb") + cry_ca_coords = extract_backbone_coords(cry_target_file)[0] + + + for f_name, func in fns.items(): + print(f_name) + + + if f_name == 'w2_rmwd': + v_ref = torch.as_tensor(ca_coords['target'][0]) + for k, v in ca_coords.items(): + v = torch.as_tensor(v) # (250,356,3) + for idx in range(v.shape[0]): + R, t = _find_rigid_alignment(v[idx], v_ref) + v[idx] = (torch.matmul(R, v[idx].transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0) + ca_coords[k] = v.numpy() + + + if f_name.startswith('js_'): + res = func(ca_coords, ref_key='target') + elif f_name == 'pro_c_contacks': + res = func(target_file, pred_file, cry_target_file) + elif f_name.startswith('pro_'): + res = func(ca_coords, cry_ca_coords) + else: + res = func(ca_coords) + + if f_name == 'js_tica' or f_name == 'js_tica_pos': + pass + # eval_res[f_name][target] = res[0]['pred'] + # save_to = os.path.join(output_dir, f"tica_{target}_{tag}_{timestamp}.png") + # plot_utils.scatterplot_2d(res[1], save_to=save_to, ref_key='target') + else: + eval_res[f_name][target] = res['pred'] + + csv_save_to = os.path.join(output_dir, f"metrics_{tag}_{timestamp}.csv") + df = pd.DataFrame.from_dict(eval_res) # row = target, col = metric name + df.to_csv(csv_save_to) + print(f"metrics saved to {csv_save_to}") + mean_metrics = np.around(df.mean(), decimals=4) + + return mean_metrics + + +# @task_wrapper +# def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: +# """Sample on a test set and report evaluation metrics. + +# This method is wrapped in optional @task_wrapper decorator, that controls the behavior during +# failure. Useful for multiruns, saving info about the crash, etc. 
+ +# :param cfg: DictConfig configuration composed by Hydra. +# :return: Tuple[dict, dict] with metrics and dict with all instantiated objects. +# """ +# # assert cfg.ckpt_path +# pred_dir = cfg.get("pred_dir") +# if pred_dir and os.path.isdir(pred_dir): +# log.info(f"Found pre-computed prediction directory {pred_dir}.") +# metric_dict = evaluate_prediction(pred_dir, target_dir=cfg.target_dir) +# return metric_dict, None + +# log.info(f"Instantiating datamodule <{cfg.data._target_}>") +# datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + +# log.info(f"Instantiating model <{cfg.model._target_}>") +# model: LightningModule = hydra.utils.instantiate(cfg.model) + +# log.info("Instantiating loggers...") +# logger: List[Logger] = instantiate_loggers(cfg.get("logger")) + +# log.info(f"Instantiating trainer <{cfg.trainer._target_}>") +# trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger) + +# object_dict = { +# "cfg": cfg, +# "datamodule": datamodule, +# "model": model, +# "logger": logger, +# "trainer": trainer, +# } + +# if logger: +# log.info("Logging hyperparameters!") +# log_hyperparameters(object_dict) + +# # Load checkpoint manually. +# model, ckpt_path = checkpoint_utils.load_model_checkpoint(model, cfg.ckpt_path) + +# # log.info("Starting testing!") +# # trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path) + +# # Get dataloader for prediction. 
+# datamodule.setup(stage="predict") +# dataloaders = datamodule.test_dataloader() + +# log.info("Starting predictions.") +# pred_dir = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=ckpt_path)[-1] + +# # metric_dict = trainer.callback_metrics +# log.info("Starting evaluations.") +# metric_dict = evaluate_prediction(pred_dir, target_dir=cfg.target_dir) + +# return metric_dict, object_dict + + +# @hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml") +# def main(cfg: DictConfig) -> None: +# """Main entry point for evaluation. + +# :param cfg: DictConfig configuration composed by Hydra. +# """ +# # apply extra utilities +# # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) +# extras(cfg) + +# evaluate(cfg) + + +# if __name__ == "__main__": +# main() diff --git a/analysis/src/metrics/__init__.py b/analysis/src/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/analysis/src/metrics/__pycache__/__init__.cpython-310.pyc b/analysis/src/metrics/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..943786a20c0c58e337532cf097a1e607873eeb48 Binary files /dev/null and b/analysis/src/metrics/__pycache__/__init__.cpython-310.pyc differ diff --git a/analysis/src/metrics/__pycache__/__init__.cpython-39.pyc b/analysis/src/metrics/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62f483500f3111605d7d3dfc16fc71594d0a2aa1 Binary files /dev/null and b/analysis/src/metrics/__pycache__/__init__.cpython-39.pyc differ diff --git a/analysis/src/metrics/__pycache__/metrics.cpython-310.pyc b/analysis/src/metrics/__pycache__/metrics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1a62a83bd62b9161c18369ceb843570d75934b5 Binary files /dev/null and 
b/analysis/src/metrics/__pycache__/metrics.cpython-310.pyc differ diff --git a/analysis/src/metrics/__pycache__/metrics.cpython-39.pyc b/analysis/src/metrics/__pycache__/metrics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bacc6af013cf47b8708b0305d26ad5af468f777b Binary files /dev/null and b/analysis/src/metrics/__pycache__/metrics.cpython-39.pyc differ diff --git a/analysis/src/metrics/metrics.py b/analysis/src/metrics/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..01a1e521f38f018ea1e814bdc10a3a589191c588 --- /dev/null +++ b/analysis/src/metrics/metrics.py @@ -0,0 +1,550 @@ +import os +from typing import * + +import numpy as np +import torch +from scipy.spatial import distance +# from deeptime.decomposition import TICA +from src.common.geo_utils import rmsd, _find_rigid_alignment, squared_deviation +from scipy.linalg import fractional_matrix_power +from sklearn.mixture import GaussianMixture +from Bio.PDB import PDBParser +# import freesasa +from Bio.PDB.Polypeptide import PPBuilder +import multiprocessing as mp + +EPS = 1e-12 +PSEUDO_C = 1e-6 + + +def adjacent_ca_distance(coords): + """Calculate distance array for a single chain of CA atoms. Only k=1 neighbors. + Args: + coords: (..., L, 3) + return + dist: (..., L-1) + """ + assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}" # (B, L, 3) + dX = coords[..., :-1, :] - coords[..., 1:, :] # (..., L-1, 3) + dist = np.sqrt(np.sum(dX**2, axis=-1)) + return dist # (..., L-1) + + +def distance_matrix_ca(coords): + """Calculate distance matrix for a single chain of CA atoms. W/o exclude neighbors. 
+ Args: + coords: (..., L, 3) + Return: + dist: (..., L, L) + """ + assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}" # (B, L, 3) + dX = coords[..., None, :, :] - coords[..., None, :] # (..., L, L, 3) + dist = np.sqrt(np.sum(dX**2, axis=-1)) + return dist # (..., L, L) + + +def pairwise_distance_ca(coords, k=1): + """Calculate pairwise distance vector for a single chain of CA atoms. W/o exclude neighbors. + Args: + coords: (..., L, 3) + Return: + dist: (..., D) (D=L * (L - 1) // 2) when k=1) + """ + assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}" # (B, L, 3) + dist = distance_matrix_ca(coords) + L = dist.shape[-1] + row, col = np.triu_indices(L, k=k) + triu = dist[..., row, col] # unified (but unclear) order + return triu # (..., D) + + +def radius_of_gyration(coords, masses=None): + """Compute the radius of gyration for every frame. + + Args: + coords: (..., num_atoms, 3) + masses: (num_atoms,) + + Returns: + Rg: (..., ) + + If masses are none, assumes equal masses. + """ + assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}" # (B, L, 3) + + if masses is None: + masses = np.ones(coords.shape[-2]) + else: + assert len(masses.shape) == 1, f"masses should be 1D, got {masses.shape}" + assert masses.shape[0] == coords.shape[-2], f"masses {masses.shape} != number of particles {coords.shape[-2]}" + + weights = masses / masses.sum() + centered = coords - coords.mean(-2, keepdims=True) + squared_dists = (centered ** 2).sum(-1) + Rg = (squared_dists * weights).sum(-1) ** 0.5 + return Rg + + +def _steric_clash(coords, ca_vdw_radius=1.7, allowable_overlap=0.4, k_exclusion=0): + """ https://www.schrodinger.com/sites/default/files/s3/public/python_api/2022-3/_modules/schrodinger/structutils/interactions/steric_clash.html#clash_iterator + Calculate the number of clashes in a single chain of CA atoms. 
+ + Usage: + n_clash = calc_clash(coords) + + Args: + coords: (n_atoms, 3), CA coordinates, coords should from one protein chain. + ca_vdw_radius: float, default 1.7. + allowable_overlap: float, default 0.4. + k_exclusion: int, default 0. Exclude neighbors within [i-k-1, i+k+1]. + + """ + assert np.isnan(coords).sum() == 0, "coords should not contain nan" + assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}" # (B, L, 3) + assert k_exclusion >= 0, "k_exclusion should be non-negative" + bar = 2 * ca_vdw_radius - allowable_overlap + # L = len(coords) + # dist = np.sqrt(np.sum((coords[:L-k_exclusion, None, :] - coords[None, k_exclusion:, :])**2, axis=-1)) + pwd = pairwise_distance_ca(coords, k=k_exclusion+1) # by default, only excluding self (k=1) + + # print('val_clash') + # print(pwd.shape) + # print(pwd.max(),pwd.min()) + # idx_min=-1 + # smin=10 + # for idx,pwd_single in enumerate(pwd): + # if pwd_single.min()0).mean() for k, v in n_clash.items() + # } + results = { + k: 1.0 - (v/num_residue).mean() for k, v in n_clash.items() + } + + results = {k: np.around(v, decimals=4) for k, v in results.items()} + return results + + +def bonding_validity(ca_coords_dict, ref_key='target', eps=1e-6): + """Calculate bonding dissociation validity of ensembles.""" + adj_dist = {k: adjacent_ca_distance(v) + for k, v in ca_coords_dict.items() + } + thres = adj_dist[ref_key].max()+ 1e-6 + + # print('val_bond') + # print('target') + # print(adj_dist['target'].shape) + # print(adj_dist['target'].max(),adj_dist['target'].min()) + # print('pred') + # print(adj_dist['pred'].shape) + # print(adj_dist['pred'].max(),adj_dist['pred'].min()) + + # idx_max=-1 + # smax=0 + # for idx,adj in enumerate(adj_dist['pred']): + # if adj.max()>smax: + # smax=adj.max() + # idx_max=idx+1 + # print('smax=',smax) + # print('idx_max=',idx_max) + # max_index2 = np.argmax(adj_dist['pred'], axis=-1) + # print('res_max_index_all',max_index2+1) + # 
print('res_max=',max_index2[idx_max-1]+1) + + # results = { + # k: (v < thres).all(-1).sum().item() / len(v) + # for k, v in adj_dist.items() + # } + results = { + k: (v < thres).mean() + for k, v in adj_dist.items() + } + + results = {k: np.around(v, decimals=4) for k, v in results.items()} + return results + + +def js_pwd(ca_coords_dict, ref_key='target', n_bins=50, pwd_offset=3, weights=None): + # n_bins = 50 follows idpGAN + # k=3 follows 2for1 + + ca_pwd = { + k: pairwise_distance_ca(v, k=pwd_offset) for k, v in ca_coords_dict.items() + } # (B, D) + + if weights is None: + weights = {} + weights.update({k: np.ones(len(v)) for k,v in ca_coords_dict.items() if k not in weights}) + + d_min = ca_pwd[ref_key].min(axis=0) # (D, ) + d_max = ca_pwd[ref_key].max(axis=0) + ca_pwd_binned = { + k: np.apply_along_axis(lambda a: np.histogram(a[:-2], bins=n_bins, weights=weights[k], range=(a[-2], a[-1]))[0]+PSEUDO_C, 0, + np.concatenate([v, d_min[None], d_max[None]], axis=0)) + for k, v in ca_pwd.items() + } # (N_bins, D)-> (N_bins * D, ) + # js divergence per channel and average + results = {k: distance.jensenshannon(v, ca_pwd_binned[ref_key], axis=0).mean() + for k, v in ca_pwd_binned.items() if k != ref_key} + results[ref_key] = 0.0 + results = {k: np.around(v, decimals=4) for k, v in results.items()} + return results + +# def js_tica(ca_coords_dict, ref_key='target', n_bins=50, lagtime=20, return_tic=True, weights=None): +# # n_bins = 50 follows idpGAN +# ca_pwd = { +# k: pairwise_distance_ca(v) for k, v in ca_coords_dict.items() +# } # (B, D) + +# print('tica1', ca_pwd[ref_key].shape) +# estimator = TICA(dim=2, lagtime=lagtime).fit(ca_pwd[ref_key]) +# print('tica2') +# tica = estimator.fetch_model() +# # dimension reduction into 2D +# ca_dr2d = { +# k: tica.transform(v) for k, v in ca_pwd.items() +# } +# if weights is None: weights = {} +# weights.update({k: np.ones(len(v)) for k,v in ca_coords_dict.items() if k not in weights}) + +# d_min = ca_dr2d[ref_key].min(axis=0) 
# (D, ) +# d_max = ca_dr2d[ref_key].max(axis=0) +# ca_dr2d_binned = { +# k: np.apply_along_axis(lambda a: np.histogram(a[:-2], bins=n_bins, weights=weights[k], range=(a[-2], a[-1]))[0]+PSEUDO_C, 0, +# np.concatenate([v, d_min[None], d_max[None]], axis=0)) +# for k, v in ca_dr2d.items() +# } # (N_bins, 2) +# results = {k: distance.jensenshannon(v, ca_dr2d_binned[ref_key], axis=0).mean() +# for k, v in ca_dr2d_binned.items() if k != ref_key} +# results[ref_key] = 0.0 +# results = {k: np.around(v, decimals=4) for k, v in results.items()} +# if return_tic: +# return results, ca_dr2d +# return results + +# def js_tica_pos(ca_coords_dict, ref_key='target', n_bins=50, lagtime=20, return_tic=True, weights=None): +# # n_bins = 50 follows idpGAN +# v_ref = torch.as_tensor(ca_coords_dict['target'][0]) +# for k, v in ca_coords_dict.items(): +# v = torch.as_tensor(v) +# for idx in range(v.shape[0]): +# R, t = _find_rigid_alignment(v[idx], v_ref) +# v[idx] = (torch.matmul(R, v[idx].transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0) +# ca_coords_dict[k] = v.numpy() + +# ca_pos = { k: v.reshape(v.shape[0],-1) for k, v in ca_coords_dict.items()} # (B, 3*N) + +# estimator = TICA(dim=2, lagtime=lagtime).fit(ca_pos[ref_key]) +# tica = estimator.fetch_model() +# # dimension reduction into 2D +# ca_dr2d = { +# k: tica.transform(v) for k, v in ca_pos.items() +# } +# if weights is None: weights = {} +# weights.update({k: np.ones(len(v)) for k,v in ca_coords_dict.items() if k not in weights}) + +# d_min = ca_dr2d[ref_key].min(axis=0) # (D, ) +# d_max = ca_dr2d[ref_key].max(axis=0) +# ca_dr2d_binned = { +# k: np.apply_along_axis(lambda a: np.histogram(a[:-2], bins=n_bins, weights=weights[k], range=(a[-2], a[-1]))[0]+PSEUDO_C, 0, +# np.concatenate([v, d_min[None], d_max[None]], axis=0)) +# for k, v in ca_dr2d.items() +# } # (N_bins, 2) +# results = {k: distance.jensenshannon(v, ca_dr2d_binned[ref_key], axis=0).mean() +# for k, v in ca_dr2d_binned.items() if k != ref_key} +# 
results[ref_key] = 0.0 +# results = {k: np.around(v, decimals=4) for k, v in results.items()} +# if return_tic: +# return results, ca_dr2d +# return results + +def js_rg(ca_coords_dict, ref_key='target', n_bins=50, weights=None): + ca_rg = { + k: radius_of_gyration(v) for k, v in ca_coords_dict.items() + } # (B, ) + if weights is None: + weights = {} + weights.update({k: np.ones(len(v)) for k,v in ca_coords_dict.items() if k not in weights}) + + d_min = ca_rg[ref_key].min() # (1, ) + d_max = ca_rg[ref_key].max() + ca_rg_binned = { + k: np.histogram(v, bins=n_bins, weights=weights[k], range=(d_min, d_max))[0]+PSEUDO_C + for k, v in ca_rg.items() + } # (N_bins, ) + # print("ca_rg_binned shape", {k: v.shape for k, v in ca_rg_binned.items()}) + results = {k: distance.jensenshannon(v, ca_rg_binned[ref_key], axis=0).mean() + for k, v in ca_rg_binned.items() if k != ref_key} + + results[ref_key] = 0.0 + results = {k: np.around(v, decimals=4) for k, v in results.items()} + return results + +def div_rmsd(ca_coords_dict): + results = {} + for k, v in ca_coords_dict.items(): + + # print(k) # [target, pred] + # print(v.shape) # (25,356,3) + # only calculate Ca + + v = torch.as_tensor(v) + # for idx in range(v.shape[0]): + # R, t = _find_rigid_alignment(v[idx], v[0]) + # v[idx] = (torch.matmul(R, v[idx].transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0) + + # v1 = v.numpy() + # count = v1.shape[0] + # rmsd_sum = np.sum(np.sqrt(np.sum((v1[:, None, :] - v1[None, :, :])**2, axis=-1))) + + count = 0 + rmsd_2_sum = 0 + for coord1 in v: + for coord2 in v: + count += 1 + rmsd_2_sum += squared_deviation(coord1,coord2,reduction='none') # (356,) + + # with mp.Pool() as pool: + # res= pool.starmap(squared_deviation,[(coord1, coord2, 'none') for coord1 in v for coord2 in v]) + # pool.close() + # pool.join() + # count = len(res)-v.shape[0] + # rmsd_2_sum = sum(res) + + results[k]=torch.sqrt(rmsd_2_sum/count) + results[k]=np.around(float(torch.mean(results[k])), decimals=4) + 
results['pred'] = (results['pred']-results['target'])/results['target'] + results = {k: np.around(v, decimals=4) for k, v in results.items()} + # print(result) + return results + +def div_rmsf(ca_coords_dict): + ''' + 1D and 0D data + ''' + results = {} + for k, v in ca_coords_dict.items(): + + v = torch.as_tensor(v) # (250,356,3) + # for idx in range(v.shape[0]): + # R, t = _find_rigid_alignment(v[idx], v[0]) + # v[idx] = (torch.matmul(R, v[idx].transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0) + + count = 0 + rmsd_2_sum = 0 + mean_str = torch.mean(v,dim = 0) # (356,3) + for coord1 in v: + count += 1 + rmsd_2_sum += squared_deviation(coord1,mean_str,reduction='none') # (356,) + + # count = v.shape[0] + # rmsd_2_sum = torch.sum(torch.norm(v - mean_str[None,...], dim=-1) ** 2) + + # mean_str = torch.mean(v,dim = 0) # (356,3) + # with mp.Pool() as pool: + # res= pool.starmap(squared_deviation,[[(coord1, mean_str, 'none') for coord1 in v]]) + # pool.close() + # pool.join() + # count = len(res) + # rmsd_2_sum = sum(res) + + results[k]=torch.sqrt(rmsd_2_sum/count) + results[k]=np.around(float(torch.mean(results[k])), decimals=4) + # print(result[k]) + results['pred'] = (results['pred']-results['target'])/results['target'] + results = {k: np.around(v, decimals=4) for k, v in results.items()} + return results + +def w2_rmwd(ca_coords_dict): + result = {} + means_total = {} + covariances_total = {} + count = 0 + v_ref = torch.as_tensor(ca_coords_dict['target'][0]) + for k, v in ca_coords_dict.items(): + + v = torch.as_tensor(v) + # for idx in range(v.shape[0]): + # R, t = _find_rigid_alignment(v[idx], v_ref) + # v[idx] = (torch.matmul(R, v[idx].transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0) + + means_total[k] = [] + covariances_total[k] = [] + + for idx_residue in range(v.shape[1]): + gmm = GaussianMixture(n_components=1) + gmm.fit(v[:, idx_residue, :]) + means = torch.as_tensor(gmm.means_[0]) # 形状为 (3,) + covariances = torch.as_tensor(gmm.covariances_[0]) # 
形状为 (3, 3) + + means_total[k].append(means) + covariances_total[k].append(covariances) + means_total[k] = torch.stack(means_total[k], dim=0) # (356, 3) + covariances_total[k] = torch.stack(covariances_total[k], dim=0) # (356, 3, 3) + # print(means_total[k].shape, covariances_total[k].shape) + # print(means_total[k][0], covariances_total[k][0]) + + sigma_1_2_sqrt = [torch.as_tensor(fractional_matrix_power(i, 0.5)) for i in torch.matmul(covariances_total['target'], covariances_total['pred'])] + sigma_1_2_sqrt = torch.stack(sigma_1_2_sqrt, dim=0) + sigma_trace = covariances_total['target'] + covariances_total['pred'] - 2 * sigma_1_2_sqrt + sigma_trace = [torch.trace(i) for i in sigma_trace] + sigma_trace = torch.stack(sigma_trace, dim=0) + + result_1D = torch.sum((means_total['target'] - means_total['pred'])**2, dim=-1) + sigma_trace + result['pred'] = np.around(float(torch.mean(result_1D)), decimals=4) + # print(result['pred']) + + return result + +def pro_w_contacts(ca_coords_dict, cry_ca_coords, dist_threshold = 8.0, percent_threshold = 0.1): + result = {} + w_contacts_total = {} + + dist = distance_matrix_ca(cry_ca_coords) + L = dist.shape[-1] + row, col = np.triu_indices(L, k=1) + triu = dist[..., row, col] # (n*(n-1)/2) + w_contacts_crystall = (triu < dist_threshold) + + for k, v in ca_coords_dict.items(): + + dist = distance_matrix_ca(v) + + L = dist.shape[-1] + row, col = np.triu_indices(L, k=1) + triu = dist[..., row, col] # (b, n*(n-1)/2) + + w_contacts = (torch.tensor(triu) > dist_threshold).type(torch.float32) + w_contacts = torch.mean(w_contacts, dim=0) # (n*(n-1)/2,) + w_contacts = w_contacts > percent_threshold + + w_contacts_total[k] = w_contacts & w_contacts_crystall + + jac_w_contacts = torch.sum(w_contacts_total['target'] & w_contacts_total['pred'])/torch.sum(w_contacts_total['target'] | w_contacts_total['pred']) + result['pred'] = np.around(float(jac_w_contacts), decimals=4) + # print(result['pred']) + + return result + +def 
pro_t_contacts(ca_coords_dict, cry_ca_coords, dist_threshold = 8.0, percent_threshold = 0.1): + result = {} + w_contacts_total = {} + + dist = distance_matrix_ca(cry_ca_coords) + L = dist.shape[-1] + row, col = np.triu_indices(L, k=1) + triu = dist[..., row, col] # (n*(n-1)/2) + w_contacts_crystall = (triu >= dist_threshold) + + for k, v in ca_coords_dict.items(): + + dist = distance_matrix_ca(v) + + L = dist.shape[-1] + row, col = np.triu_indices(L, k=1) + triu = dist[..., row, col] # (b, n*(n-1)/2) + + w_contacts = (torch.tensor(triu) <= dist_threshold).type(torch.float32) + w_contacts = torch.mean(w_contacts, dim=0) # (n*(n-1)/2,) + w_contacts = w_contacts > percent_threshold + + w_contacts_total[k] = w_contacts & w_contacts_crystall + + jac_w_contacts = torch.sum(w_contacts_total['target'] & w_contacts_total['pred'])/torch.sum(w_contacts_total['target'] | w_contacts_total['pred']) + result['pred'] = np.around(float(jac_w_contacts), decimals=4) + # print(result['pred']) + + return result + +# def pro_c_contacts(target_file, pred_file, cry_target_file, area_threshold = 2.0, percent_threshold = 0.1): +# result = {} +# c_contacts_total = {} + +# parser = PDBParser() +# params = freesasa.Parameters({'algorithm': 'ShrakeRupley', 'probe-radius': 2.8}) + +# structure_cry_target = parser.get_structure('cry_target', cry_target_file) +# str_params = {'separate-chains': False, 'separate-models': True} +# structure_target = freesasa.structureArray(target_file,str_params) +# structure_pred = freesasa.structureArray(pred_file,str_params) + + +# structure_cry_target = freesasa.structureFromBioPDB(structure_cry_target) +# sasa = freesasa.calc(structure_cry_target,params) +# residue_sasa = sasa.residueAreas() + +# c_contacts_crystall = [] +# # 打印每个残基的 SASA +# for chain_id in residue_sasa: +# for residue_id in residue_sasa[chain_id]: +# # print(f"Chain {chain_id}, Residue {residue_id}: {residue_sasa[chain_id][residue_id].residueType}, area: 
{residue_sasa[chain_id][residue_id].total}") +# c_contacts_crystall.append(residue_sasa[chain_id][residue_id].total < area_threshold) +# c_contacts_crystall = torch.tensor(c_contacts_crystall) + +# c_contacts_target = 0 +# count = 0 +# for structure_temp in structure_target: +# count += 1 +# sasa = freesasa.calc(structure_temp,params) +# residue_sasa = sasa.residueAreas() + +# c_contacts_temp = [] +# # 打印每个残基的 SASA +# for chain_id in residue_sasa: +# for residue_id in residue_sasa[chain_id]: +# # print(f"Chain {chain_id}, Residue {residue_id}: {residue_sasa[chain_id][residue_id].residueType}, area: {residue_sasa[chain_id][residue_id].total}") +# c_contacts_temp.append(residue_sasa[chain_id][residue_id].total > area_threshold) +# c_contacts_temp = torch.tensor(c_contacts_temp).type(torch.float32) +# c_contacts_target += c_contacts_temp +# c_contacts_target = c_contacts_target / count +# c_contacts_total['target'] = (c_contacts_target > percent_threshold) & c_contacts_crystall + + +# c_contacts_pred = 0 +# count = 0 +# for structure_temp in structure_pred: +# count += 1 +# sasa = freesasa.calc(structure_temp,params) +# residue_sasa = sasa.residueAreas() + +# c_contacts_temp = [] +# # 打印每个残基的 SASA +# for chain_id in residue_sasa: +# for residue_id in residue_sasa[chain_id]: +# # print(f"Chain {chain_id}, Residue {residue_id}: {residue_sasa[chain_id][residue_id].residueType}, area: {residue_sasa[chain_id][residue_id].total}") +# c_contacts_temp.append(residue_sasa[chain_id][residue_id].total > area_threshold) +# c_contacts_temp = torch.tensor(c_contacts_temp).type(torch.float32) +# c_contacts_pred += c_contacts_temp +# c_contacts_pred = c_contacts_pred / count +# c_contacts_total['pred'] = (c_contacts_pred > percent_threshold) & c_contacts_crystall + +# jac_w_contacts = torch.sum(c_contacts_total['target'] & c_contacts_total['pred'])/torch.sum(c_contacts_total['target'] | c_contacts_total['pred']) +# result['pred'] = np.around(float(jac_w_contacts), decimals=4) +# # 
print(jac_w_contacts) +# return result \ No newline at end of file diff --git a/analysis/src/models/__init__.py b/analysis/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/analysis/src/models/__pycache__/__init__.cpython-39.pyc b/analysis/src/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36259e9bc2d83c41aad93af61758d758f3108391 Binary files /dev/null and b/analysis/src/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/analysis/src/models/__pycache__/diffusion_module.cpython-39.pyc b/analysis/src/models/__pycache__/diffusion_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..851fa55c838e0e105cdf51b7f0c19e6a37538f78 Binary files /dev/null and b/analysis/src/models/__pycache__/diffusion_module.cpython-39.pyc differ diff --git a/analysis/src/models/__pycache__/loss.cpython-39.pyc b/analysis/src/models/__pycache__/loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c562b4260ada6ae3604b36d9cc562a2c66dc0b34 Binary files /dev/null and b/analysis/src/models/__pycache__/loss.cpython-39.pyc differ diff --git a/analysis/src/models/diffusion_module.py b/analysis/src/models/diffusion_module.py new file mode 100644 index 0000000000000000000000000000000000000000..671de760108128189ab3ef7d4b54187be039d16d --- /dev/null +++ b/analysis/src/models/diffusion_module.py @@ -0,0 +1,408 @@ +import os +from typing import Any, Dict, Tuple, Optional +from random import random +from copy import deepcopy + +import numpy as np +import torch +from lightning import LightningModule +from torchmetrics import MinMetric, MeanMetric + +from src.models.score.frame import FrameDiffuser +from src.models.loss import ScoreMatchingLoss +from src.common.rigid_utils import Rigid +from src.common.all_atom import compute_backbone +from src.common.pdb_utils import 
atom37_to_pdb, merge_pdbfiles + + +class DiffusionLitModule(LightningModule): + """Example of a `LightningModule` for denoising diffusion training. + + A `LightningModule` implements 8 key methods: + + ```python + def __init__(self): + # Define initialization code here. + + def setup(self, stage): + # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'. + # This hook is called on every process when using DDP. + + def training_step(self, batch, batch_idx): + # The complete training step. + + def validation_step(self, batch, batch_idx): + # The complete validation step. + + def test_step(self, batch, batch_idx): + # The complete test step. + + def predict_step(self, batch, batch_idx): + # The complete predict step. + + def configure_optimizers(self): + # Define and configure optimizers and LR schedulers. + ``` + + Docs: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html + """ + + def __init__( + self, + net: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + diffuser: FrameDiffuser, + loss: Dict[str, Any], + compile: bool, + inference: Optional[Dict[str, Any]] = None, + ) -> None: + """Initialize a `MNISTLitModule`. + + :param net: The model to train. + :param optimizer: The optimizer to use for training. + :param scheduler: The learning rate scheduler to use for training. 
def model_step(
    self, batch: Dict[str, Any], training: Optional[bool] = True
) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """Noise the ground-truth frames at a random time, run the network and compute the loss.

    The batch dict is modified in place: the outputs of
    ``diffuser.forward_marginal`` plus ``'t'`` and ``'rigids_0'`` are merged
    into it, and ``'sc_ca_t'`` is added when self-conditioning is drawn.

    :param batch: Feature dict; this method reads ``'rigidgroups_gt_frames'``
        and ``'residue_mask'`` and passes the whole dict to the network and loss.
    :param training: Currently unused by this method.

    :return: A tuple containing (in order):
        - The total loss returned by ``self.loss``.
        - The loss breakdown returned by ``self.loss`` (iterated with
          ``.items()`` in ``training_step``).
    """
    # preprocess by augmenting additional feats
    # index 0 along the rigid-group axis -- presumably the backbone frame (TODO confirm)
    rigids_0 = Rigid.from_tensor_4x4(batch['rigidgroups_gt_frames'][..., 0, :, :])
    batch_size = rigids_0.shape[0]
    # one diffusion time per sample, uniform in [min_t, 1)
    t = (1.0 - self.diffuser.min_t) * torch.rand(batch_size, device=rigids_0.device) + self.diffuser.min_t
    perturb_feats = self.diffuser.forward_marginal(
        rigids_0=rigids_0,
        t=t,
        diffuse_mask=None,
        as_tensor_7=True,
    )
    patch_feats = {
        't': t,
        'rigids_0': rigids_0,
    }
    # in-place: the caller's batch now also carries the perturbed features
    batch.update({**perturb_feats, **patch_feats})

    # probably add self-conditioning (recycle once)
    # applied to roughly half of the steps; the extra forward pass is gradient-free
    if self.net.embedder.self_conditioning and random() > 0.5:
        with torch.no_grad():
            # components 4: of the tensor_7 output -- presumably the translation part (TODO confirm)
            batch['sc_ca_t'] = self.net(batch, as_tensor_7=True)['rigids'][..., 4:]

    # feedforward
    out = self.net(batch)

    # postprocess by add score computation
    pred_scores = self.diffuser.score(
        rigids_0=out['rigids'],
        rigids_t=Rigid.from_tensor_7(batch['rigids_t']),
        t=t,
        mask=batch['residue_mask'],
    )
    out.update(pred_scores)

    # calculate losses
    loss, loss_bd = self.loss(out, batch, _return_breakdown=True)
    return loss, loss_bd
def on_validation_epoch_end(self) -> None:
    """Lightning hook run after each validation epoch: track and log the best loss.

    Feeds the epoch's aggregated validation loss into the running-minimum
    metric and logs the resulting minimum under ``"val/loss_best"``.
    """
    epoch_loss = self.val_loss.compute()

    best_tracker = self.val_loss_best
    best_tracker(epoch_loss)

    # Log the computed value rather than the metric object itself; otherwise
    # Lightning would reset the metric after each epoch and lose the minimum.
    best_so_far = best_tracker.compute()
    self.log("val/loss_best", best_so_far, sync_dist=True, prog_bar=True)
def predict_step(
    self, batch: Dict[str, Any], batch_idx: int,
) -> str:
    """Sample conformations for one protein and write them to PDB files.

    This prediction step will sample `n_replica` copies from the forward-backward process,
        repeated for each delta-T in the range of [delta_min, delta_max] with step size
        `delta_step`. If `backward_only` is set to True, then only backward process will be
        performed, and `n_replica` will be multiplied by the number of delta-Ts.

    Files written (all settings come from ``self.hparams.inference``):
        - ``<output_dir>/<t_delta>/<accession_code>.pdb`` for every delta-T
          (the directory is named ``-1.0`` in backward-only mode);
        - ``<output_dir>/all_delta/<accession_code>.pdb`` merging all of the above.

    :param batch: Feature dict for a single protein (batch size must be 1). Reads
        ``'aatype'``, ``'accession_code'``, ``'chain_index'``, ``'residue_index'``,
        ``'rigidgroups_gt_frames'``, ``'residue_mask'`` and ``'fixed_mask'``.
    :param batch_idx: The index of the current batch (unused).
    :return: Path of the ``all_delta`` directory (not of the merged file).
    """
    # extract hyperparams for inference
    n_replica = self.hparams.inference.n_replica
    replica_per_batch = self.hparams.inference.replica_per_batch
    # the small epsilon makes delta_max itself part of the range
    delta_range = np.arange(
        self.hparams.inference.delta_min,
        self.hparams.inference.delta_max + 1e-5,
        self.hparams.inference.delta_step
    )
    delta_range = np.around(delta_range, decimals=2)    # up to 2 decimal places
    num_timesteps = self.hparams.inference.num_timesteps
    noise_scale = self.hparams.inference.noise_scale
    probability_flow = self.hparams.inference.probability_flow
    # self-conditioning needs both the inference flag and a network built with it
    self_conditioning = self.hparams.inference.self_conditioning and self.net.embedder.self_conditioning
    min_t = self.hparams.inference.min_t
    output_dir = self.hparams.inference.output_dir
    backward_only = self.hparams.inference.backward_only
    # if backward_only, then only perform backward process (vanilla sampling of diffusion)
    if backward_only:
        # keep the total sample count equal to the forward-backward setting
        n_replica *= len(delta_range)
        # a non-positive delta is the sentinel for "start from the prior"
        delta_range = [-1.0]

    # NOTE(review): assert is stripped under `python -O`; consider raising instead.
    assert batch['aatype'].shape[0] == 1, "Batch size must be 1 for correct inference."

    # get extra features of the current protein
    accession_code = batch['accession_code'][0]
    extra = {
        'aatype': batch['aatype'][0].detach().cpu().numpy(),
        'chain_index': batch['chain_index'][0].detach().cpu().numpy(),
        'residue_index': batch['residue_index'][0].detach().cpu().numpy(),
    }

    # define sampling subroutine
    def forward_backward(rigids_0: Rigid, t_delta: float):
        """Noise `rigids_0` up to `t_delta` (or draw from the prior) and denoise back.

        Returns the final atom37 coordinates as a numpy array.
        """
        # if t_delta <= 0 (invalid), then perform backward only (from t=T)
        T = t_delta if t_delta > 0 else 1.0

        batch_size, device = rigids_0.shape[0], rigids_0.device
        # number of reverse steps scales with the starting time
        _num_timesteps = int(float(num_timesteps) * T)
        # NOTE(review): raises ZeroDivisionError when num_timesteps * T < 1.
        dt = 1.0 / _num_timesteps
        ts = np.linspace(min_t, T, _num_timesteps)[::-1]    # reverse in time

        # replicate the per-protein features once per replica in this mini-batch
        # NOTE(review): the filter lists 'residue_idx' while `extra` above reads
        # 'residue_index' -- confirm which key the network expects.
        _feats = deepcopy({
            k: v.repeat(batch_size, *(1,)*(v.ndim-1))
            for k,v in batch.items() if k in ('aatype', 'residue_mask', 'fixed_mask', 'residue_idx', 'torsion_angles_sin_cos')
        })
        if t_delta > 0:
            # forward process: perturb the ground-truth frames to time t_delta
            rigids_t = self.diffuser.forward_marginal(
                rigids_0=rigids_0,
                t=t_delta * torch.ones(batch_size, device=device),
                diffuse_mask=_feats['residue_mask'],
                as_tensor_7=True,
            )['rigids_t']
        else:
            # backward only: start from a sample of the prior
            rigids_t = self.diffuser.sample_prior(
                shape=rigids_0.shape,
                device=device,
                as_tensor_7=True,
            )['rigids_t']

        _feats['rigids_t'] = rigids_t

        # NOTE(review): traj_atom37 and fixed_mask below are assigned but never read.
        traj_atom37 = []
        with torch.no_grad():
            fixed_mask = _feats['fixed_mask'] * _feats['residue_mask']
            diffuse_mask = (1 - _feats['fixed_mask']) * _feats['residue_mask']

            if self_conditioning:
                # warm-up pass at the first time step with a zero self-conditioning input
                _feats['sc_ca_t'] = torch.zeros_like(rigids_t[..., 4:])
                _feats['t'] = ts[0] * torch.ones(batch_size, device=device)
                _feats['sc_ca_t'] = self.net(_feats, as_tensor_7=True)['rigids'][..., 4:]    # update self-conditioning feats

            for t in ts:
                _feats['t'] = t * torch.ones(batch_size, device=device)
                out = self.net(_feats, as_tensor_7=False)

                # compute predicted rigids
                # NOTE(review): exact float equality; expected to hold only on the last
                # iteration, where `t` is the linspace start point `min_t`.
                if t == min_t:
                    # final step: take the network's prediction directly
                    rigids_pred = out['rigids']
                else:
                    # update self-conditioning feats
                    if self_conditioning:
                        _feats['sc_ca_t'] = out['rigids'].to_tensor_7()[..., 4:]
                    # get score based on predicted rigids
                    pred_scores = self.diffuser.score(
                        rigids_0=out['rigids'],
                        rigids_t=Rigid.from_tensor_7(_feats['rigids_t']),
                        t=_feats['t'],
                        mask=_feats['residue_mask'],
                    )
                    rigids_pred = self.diffuser.reverse(
                        rigids_t=Rigid.from_tensor_7(_feats['rigids_t']),
                        rot_score=pred_scores['rot_score'],
                        trans_score=pred_scores['trans_score'],
                        t=_feats['t'],
                        dt=dt,
                        diffuse_mask=diffuse_mask,
                        center_trans=True,
                        noise_scale=noise_scale,
                        probability_flow=probability_flow,
                    )   # Rigid object
                    # update rigids_t as tensor_7
                    _feats['rigids_t'] = rigids_pred.to_tensor_7()

            # compute atom37 positions
            atom37 = compute_backbone(rigids_pred, out['psi'], aatype=_feats['aatype'])[0]
            atom37 = atom37.detach().cpu().numpy()  # (Bi, L, 37 ,3)
        return atom37

    saved_paths = []

    # iterate over delta-Ts
    for t_delta in delta_range:
        gt_rigids_4x4 = batch['rigidgroups_gt_frames'][..., 0, :, :].clone()
        # split the replicas into mini-batches of at most replica_per_batch
        n_bs = n_replica // replica_per_batch   # number of batches
        last_bs = n_replica % replica_per_batch # last batch size
        atom_positions = []
        for _ in range(n_bs):
            rigids_0 = Rigid.from_tensor_4x4(gt_rigids_4x4.repeat(replica_per_batch, *(1,)*(gt_rigids_4x4.ndim-1)))
            traj_atom37 = forward_backward(rigids_0, t_delta)
            atom_positions.append(traj_atom37)
        if last_bs > 0:
            rigids_0 = Rigid.from_tensor_4x4(gt_rigids_4x4.repeat(last_bs, *(1,)*(gt_rigids_4x4.ndim-1)))
            traj_atom37 = forward_backward(rigids_0, t_delta)
            atom_positions.append(traj_atom37)
        atom_positions = np.concatenate(atom_positions, axis=0) # (B, L, 37, 3)

        # Save atom positions to pdb.
        t_delta_dir = os.path.join(output_dir, f"{t_delta}")
        os.makedirs(t_delta_dir, exist_ok=True)
        save_to = os.path.join(t_delta_dir, f"{accession_code}.pdb")
        saved_to = atom37_to_pdb(
            atom_positions=atom_positions,
            save_to=save_to,
            **extra,
        )
        saved_paths.append(saved_to)

    # merge the per-delta files into a single file for this protein
    all_delta_dir = os.path.join(output_dir, "all_delta")
    os.makedirs(all_delta_dir, exist_ok=True)
    merge_pdbfiles(saved_paths, os.path.join(all_delta_dir, f"{accession_code}.pdb"))

    return all_delta_dir
if __name__ == "__main__":
    # Construction smoke test. `compile` is a required argument of
    # DiffusionLitModule.__init__, so it has to be passed explicitly; keyword
    # arguments keep the call readable and robust to parameter reordering.
    # NOTE(review): whether ScoreMatchingLoss accepts a `None` config is not
    # visible from this module -- confirm before relying on this entry point.
    _ = DiffusionLitModule(
        net=None,
        optimizer=None,
        scheduler=None,
        diffuser=None,
        loss=None,
        compile=False,
    )
def sigmoid_cross_entropy(logits, labels):
    """Element-wise binary cross-entropy between logits and (soft) labels.

    Computes ``-labels * log(sigmoid(logits)) - (1 - labels) * log(sigmoid(-logits))``
    without reduction. ``logsigmoid`` is used instead of ``log(sigmoid(.))``
    because the latter underflows to ``log(0) = -inf`` for large ``|logits|``,
    which turns the loss into ``inf`` or (when multiplied by a zero label
    weight) ``NaN``.

    :param logits: Tensor of unnormalized scores.
    :param labels: Tensor broadcastable to ``logits`` with targets in [0, 1].
    :return: Tensor of per-element losses with the broadcast shape of the inputs.
    """
    log_p = torch.nn.functional.logsigmoid(logits)
    log_not_p = torch.nn.functional.logsigmoid(-logits)
    loss = -labels * log_p - (1 - labels) * log_not_p
    return loss
+ + Args: + pred_frames: + [*, N_frames] Rigid object of predicted frames + target_frames: + [*, N_frames] Rigid object of ground truth frames + frames_mask: + [*, N_frames] binary mask for the frames + pred_positions: + [*, N_pts, 3] predicted atom positions + target_positions: + [*, N_pts, 3] ground truth positions + positions_mask: + [*, N_pts] positions mask + length_scale: + Length scale by which the loss is divided + l1_clamp_distance: + Cutoff above which distance errors are disregarded + eps: + Small value used to regularize denominators + Returns: + [*] loss tensor + """ + # [*, N_frames, N_pts, 3] + local_pred_pos = pred_frames.invert()[..., None].apply( + pred_positions[..., None, :, :], + ) + local_target_pos = target_frames.invert()[..., None].apply( + target_positions[..., None, :, :], + ) + + error_dist = torch.sqrt( + torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps + ) + + if l1_clamp_distance is not None: + error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance) + + normed_error = error_dist / length_scale + normed_error = normed_error * frames_mask[..., None] + normed_error = normed_error * positions_mask[..., None, :] + if ignore_nan: + normed_error = torch.nan_to_num(normed_error) + + # FP16-friendly averaging. 
def backbone_loss(
    backbone_rigid_tensor: torch.Tensor,
    backbone_rigid_mask: torch.Tensor,
    traj: torch.Tensor,
    use_clamped_fape: Optional[torch.Tensor] = None,
    clamp_distance: float = 10.0,
    loss_unit_distance: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
    """Backbone FAPE between predicted frames and ground-truth backbone frames.

    :param backbone_rigid_tensor: Ground-truth frames as 4x4 transforms.
    :param backbone_rigid_mask: Mask used both for frames and for positions.
    :param traj: Predicted frames in tensor_7 form.
    :param use_clamped_fape: Optional blend weight; when given, the result is
        ``clamped * w + unclamped * (1 - w)``.
    :param clamp_distance: Clamp applied to the per-point error in the clamped term.
    :param loss_unit_distance: Length scale the error is divided by.
    :param eps: Numerical stabilizer forwarded to ``compute_fape``.
    :param kwargs: Ignored; lets callers splat a whole feature dict.
    :return: Scalar loss averaged over all remaining dimensions.
    """
    predicted = Rigid.from_tensor_7(traj)
    # Rebuild the prediction from rotation matrices only (no quaternions).
    predicted = Rigid(
        Rotation(rot_mats=predicted.get_rots().get_rot_mats(), quats=None),
        predicted.get_trans(),
    )

    # DISCREPANCY (kept from the original implementation): DeepMind normalizes
    # a tensor_7 version of the ground truth and converts it back to rotation
    # matrices. The 4x4 matrices are used directly here to avoid a potentially
    # unstable matrix-to-quaternion conversion.
    target = Rigid.from_tensor_4x4(backbone_rigid_tensor)[None]
    mask = backbone_rigid_mask[None]

    def _fape(clamp):
        # Same frames, points and masks for both variants; only the clamp differs.
        return compute_fape(
            predicted,
            target,
            mask,
            predicted.get_trans(),
            target.get_trans(),
            mask,
            l1_clamp_distance=clamp,
            length_scale=loss_unit_distance,
            eps=eps,
        )

    loss = _fape(clamp_distance)
    if use_clamped_fape is not None:
        unclamped = _fape(None)
        loss = loss * use_clamped_fape + unclamped * (1 - use_clamped_fape)

    # Average over the batch dimension
    return torch.mean(loss)
def fape_loss(
    out: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    config: ml_collections.ConfigDict,
) -> torch.Tensor:
    """Weighted sum of the backbone and side-chain FAPE terms.

    :param out: Model outputs; reads ``out["sm"]["frames"]``,
        ``out["sm"]["sidechain_frames"]`` and ``out["sm"]["positions"]``.
    :param batch: Ground-truth features, splatted into both loss functions.
    :param config: Provides ``backbone`` and ``sidechain`` sections; each is
        splatted as keyword arguments (overriding ``batch`` on key clashes)
        and supplies the ``weight`` of its term.
    :return: Scalar loss.
    """
    structure = out["sm"]

    backbone_term = backbone_loss(
        traj=structure["frames"],
        **{**batch, **config.backbone},
    )
    sidechain_term = sidechain_loss(
        structure["sidechain_frames"],
        structure["positions"],
        **{**batch, **config.sidechain},
    )

    weighted = (
        config.backbone.weight * backbone_term
        + config.sidechain.weight * sidechain_term
    )

    # Average over the batch dimension
    return torch.mean(weighted)
def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
    """Convert binned lDDT logits into an expected score on a 0-100 scale.

    The last axis of ``logits`` indexes equally wide bins covering [0, 1];
    the result is the softmax-weighted mean of the bin centers times 100,
    with the last axis reduced away.
    """
    n_bins = logits.shape[-1]
    width = 1.0 / n_bins
    centers = torch.arange(
        start=0.5 * width, end=1.0, step=width, device=logits.device
    )
    # Prepend singleton axes so the centers broadcast against the probabilities.
    leading_axes = (1,) * (logits.dim() - 1)
    centers = centers.reshape(*leading_axes, *centers.shape)

    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    expected = torch.sum(probabilities * centers, dim=-1)
    return expected * 100
def lddt_ca(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """lDDT computed on C-alpha atoms only.

    Selects the CA entry (index from ``residue_constants.atom_order``) along
    the atom axis of the predicted positions, true positions and mask, then
    delegates to ``lddt`` with the remaining arguments unchanged.
    """
    ca = residue_constants.atom_order["CA"]

    pred_ca = all_atom_pred_pos[..., ca, :]
    true_ca = all_atom_positions[..., ca, :]
    # Slice (not index) the mask so its trailing axis is kept with size 1.
    ca_mask = all_atom_mask[..., slice(ca, ca + 1)]

    return lddt(
        pred_ca,
        true_ca,
        ca_mask,
        cutoff=cutoff,
        eps=eps,
        per_residue=per_residue,
    )
all_atom_mask, + cutoff=cutoff, + eps=eps + ) + + score = score.detach() + + bin_index = torch.floor(score * no_bins).long() + bin_index = torch.clamp(bin_index, max=(no_bins - 1)) + lddt_ca_one_hot = torch.nn.functional.one_hot( + bin_index, num_classes=no_bins + ) + + errors = softmax_cross_entropy(logits, lddt_ca_one_hot) + all_atom_mask = all_atom_mask.squeeze(-1) + loss = torch.sum(errors * all_atom_mask, dim=-1) / ( + eps + torch.sum(all_atom_mask, dim=-1) + ) + + loss = loss * ( + (resolution >= min_resolution) & (resolution <= max_resolution) + ) + + # Average over the batch dimension + loss = torch.mean(loss) + + return loss + + +def distogram_loss( + logits, + pseudo_beta, + pseudo_beta_mask, + min_bin=2.3125, + max_bin=21.6875, + no_bins=64, + eps=1e-6, + **kwargs, +): + boundaries = torch.linspace( + min_bin, + max_bin, + no_bins - 1, + device=logits.device, + ) + boundaries = boundaries ** 2 + + dists = torch.sum( + (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2, + dim=-1, + keepdims=True, + ) + + true_bins = torch.sum(dists > boundaries, dim=-1) + + errors = softmax_cross_entropy( + logits, + torch.nn.functional.one_hot(true_bins, no_bins), + ) + + square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :] + + # FP16-friendly sum. 
Equivalent to: + # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) / + # (eps + torch.sum(square_mask, dim=(-1, -2)))) + denom = eps + torch.sum(square_mask, dim=(-1, -2)) + mean = errors * square_mask + mean = torch.sum(mean, dim=-1) + mean = mean / denom[..., None] + mean = torch.sum(mean, dim=-1) + + # Average over the batch dimensions + mean = torch.mean(mean) + + return mean + + +def _calculate_bin_centers(boundaries: torch.Tensor): + step = boundaries[1] - boundaries[0] + bin_centers = boundaries + step / 2 + bin_centers = torch.cat( + [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0 + ) + return bin_centers + + +def _calculate_expected_aligned_error( + alignment_confidence_breaks: torch.Tensor, + aligned_distance_error_probs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + bin_centers = _calculate_bin_centers(alignment_confidence_breaks) + return ( + torch.sum(aligned_distance_error_probs * bin_centers, dim=-1), + bin_centers[-1], + ) + + +def compute_predicted_aligned_error( + logits: torch.Tensor, + max_bin: int = 31, + no_bins: int = 64, + **kwargs, +) -> Dict[str, torch.Tensor]: + """Computes aligned confidence metrics from logits. + + Args: + logits: [*, num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + max_bin: Maximum bin value + no_bins: Number of bins + Returns: + aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted + aligned error probabilities over bins for each residue pair. + predicted_aligned_error: [*, num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: [*] the maximum predicted error possible. 
def compute_tm(
    logits: torch.Tensor,
    residue_weights: Optional[torch.Tensor] = None,
    max_bin: int = 31,
    no_bins: int = 64,
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    """Compute a predicted TM-score from aligned-error logits.

    For every aligning residue the expected TM term is accumulated over all
    other residues, weighted by ``residue_weights``. The value returned is the
    one belonging to the residue whose weighted score is largest.

    Args:
        logits: Aligned-error logits whose last axis has ``no_bins`` entries
            (presumably [*, N, N, no_bins] -- confirm against the head).
        residue_weights: Per-residue weights of length ``logits.shape[-2]``.
            Defaults to all ones.
        max_bin: Upper end of the error range covered by the bin breaks.
        no_bins: Number of error bins.
        eps: Small constant guarding the weight normalisation.
        **kwargs: Ignored; lets callers splat a larger config dict.

    Returns:
        The selected per-alignment predicted TM-score.
    """
    if residue_weights is None:
        residue_weights = logits.new_ones(logits.shape[-2])

    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )

    bin_centers = _calculate_bin_centers(boundaries)

    # d0 is undefined or negative for short chains (the cube-root argument
    # is <= 0 at n <= 15), so the length is clipped from below.
    n = logits.shape[-2]
    clipped_n = max(n, 19)

    d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8

    probs = torch.nn.functional.softmax(logits, dim=-1)

    # TM contribution of each error bin, then its expectation under probs.
    tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    normed_residue_mask = residue_weights / (eps + residue_weights.sum())
    per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)

    # Choose the best alignment among weighted residues, but report the
    # unweighted per-alignment score at that position.
    weighted = per_alignment * residue_weights
    argmax = (weighted == torch.max(weighted)).nonzero()[0]
    return per_alignment[tuple(argmax)]
def between_residue_bond_loss(
    pred_atom_positions: torch.Tensor,  # (*, N, 37/14, 3)
    pred_atom_mask: torch.Tensor,  # (*, N, 37/14)
    residue_index: torch.Tensor,  # (*, N)
    aatype: torch.Tensor,  # (*, N)
    tolerance_factor_soft=12.0,
    tolerance_factor_hard=12.0,
    eps=1e-6,
) -> Dict[str, torch.Tensor]:
    """Flat-bottom loss to penalize structural violations between residues.

    This is a loss penalizing any violation of the geometry around the peptide
    bond between consecutive amino acids. This loss corresponds to
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.

    Args:
        pred_atom_positions: Atom positions in atom37/14 representation
        pred_atom_mask: Atom mask in atom37/14 representation
        residue_index: Residue index for given amino acid, this is assumed to be
            monotonically increasing.
        aatype: Amino acid type of given residue
        tolerance_factor_soft: soft tolerance factor measured in standard
            deviations of pdb distributions
        tolerance_factor_hard: hard tolerance factor measured in standard
            deviations of pdb distributions
        eps: Small constant keeping square roots and divisions finite.

    Returns:
        Dict containing:
            * 'c_n_loss_mean': Loss for peptide bond length violations
            * 'ca_c_n_loss_mean': Loss for violations of bond angle around C
                spanned by CA, C, N
            * 'c_n_ca_loss_mean': Loss for violations of bond angle around N
                spanned by C, N, CA
            * 'per_residue_loss_sum': sum of all losses for each residue
            * 'per_residue_violation_mask': mask denoting all residues with
                violation present.
    """
    # Get the positions of the relevant backbone atoms. Atom slots 0, 1, 2
    # are read as N, CA, C; "this" is residue i, "next" is residue i + 1.
    this_ca_pos = pred_atom_positions[..., :-1, 1, :]
    this_ca_mask = pred_atom_mask[..., :-1, 1]
    this_c_pos = pred_atom_positions[..., :-1, 2, :]
    this_c_mask = pred_atom_mask[..., :-1, 2]
    next_n_pos = pred_atom_positions[..., 1:, 0, :]
    next_n_mask = pred_atom_mask[..., 1:, 0]
    next_ca_pos = pred_atom_positions[..., 1:, 1, :]
    next_ca_mask = pred_atom_mask[..., 1:, 1]
    # Only pairs with consecutive residue indices share a peptide bond.
    has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0

    # Compute loss for the C--N bond.
    c_n_bond_length = torch.sqrt(
        eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)
    )

    # The C-N bond to proline has slightly different length because of the ring.
    next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
    gt_length = (
        (~next_is_proline) * residue_constants.between_res_bond_length_c_n[0]
        + next_is_proline * residue_constants.between_res_bond_length_c_n[1]
    )
    gt_stddev = (
        (~next_is_proline)
        * residue_constants.between_res_bond_length_stddev_c_n[0]
        + next_is_proline
        * residue_constants.between_res_bond_length_stddev_c_n[1]
    )
    c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
    c_n_loss_per_residue = torch.nn.functional.relu(
        c_n_bond_length_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_c_mask * next_n_mask * has_no_gap_mask
    c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    c_n_violation_mask = mask * (
        c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
    )

    # Compute loss for the angles.
    ca_c_bond_length = torch.sqrt(
        eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)
    )
    n_ca_bond_length = torch.sqrt(
        eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)
    )

    c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
    c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None]
    n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None]

    # Angle at C, spanned by CA-C-N, compared in cosine space.
    ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1)
    gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
    # Use this angle's own cosine spread (entry 1 of the [mean, stddev]
    # pair), as the C-N-CA branch below does. The previous code reused the
    # C-N bond-length stddev here, which is a length, not a cosine spread,
    # and made the tolerance tighter than intended.
    gt_stddev = residue_constants.between_res_cos_angles_ca_c_n[1]
    ca_c_n_cos_angle_error = torch.sqrt(
        eps + (ca_c_n_cos_angle - gt_angle) ** 2
    )
    ca_c_n_loss_per_residue = torch.nn.functional.relu(
        ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
    ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    ca_c_n_violation_mask = mask * (
        ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)
    )

    # Angle at N, spanned by C-N-CA. The C->N vector is negated so both
    # vectors point away from N.
    c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
    gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
    gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
    c_n_ca_cos_angle_error = torch.sqrt(
        eps + (c_n_ca_cos_angle - gt_angle) ** 2
    )
    c_n_ca_loss_per_residue = torch.nn.functional.relu(
        c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
    c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    c_n_ca_violation_mask = mask * (
        c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
    )

    # Compute a per residue loss (equally distribute the loss to both
    # neighbouring residues). Padding right and left maps the N - 1 pair
    # values back onto N residues.
    per_residue_loss_sum = (
        c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
    )
    per_residue_loss_sum = 0.5 * (
        torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
        + torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
    )

    # Compute hard violations: a residue is flagged if any of the three
    # checks fails on either of its two peptide bonds.
    violation_mask = torch.max(
        torch.stack(
            [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
            dim=-2,
        ),
        dim=-2,
    )[0]
    violation_mask = torch.maximum(
        torch.nn.functional.pad(violation_mask, (0, 1)),
        torch.nn.functional.pad(violation_mask, (1, 0)),
    )

    return {
        "c_n_loss_mean": c_n_loss,
        "ca_c_n_loss_mean": ca_c_n_loss,
        "c_n_ca_loss_mean": c_n_ca_loss,
        "per_residue_loss_sum": per_residue_loss_sum,
        "per_residue_violation_mask": violation_mask,
    }
+ + This is a loss penalizing any steric clashes due to non bonded atoms in + different peptides coming too close. This loss corresponds to the part with + different residues of + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46. + + Args: + atom14_pred_positions: Predicted positions of atoms in + global prediction frame + atom14_atom_exists: Mask denoting whether atom at positions exists for given + amino acid type + atom14_atom_radius: Van der Waals radius for each atom. + residue_index: Residue index for given amino acid. + overlap_tolerance_soft: Soft tolerance factor. + overlap_tolerance_hard: Hard tolerance factor. + + Returns: + Dict containing: + * 'mean_loss': average clash loss + * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14) + * 'per_atom_clash_mask': mask whether atom clashes with any other atom + shape (N, 14) + """ + fp_type = atom14_pred_positions.dtype + + # Create the distance matrix. + # (N, N, 14, 14) + dists = torch.sqrt( + eps + + torch.sum( + ( + atom14_pred_positions[..., :, None, :, None, :] + - atom14_pred_positions[..., None, :, None, :, :] + ) + ** 2, + dim=-1, + ) + ) + + # Create the mask for valid distances. + # shape (N, N, 14, 14) + dists_mask = ( + atom14_atom_exists[..., :, None, :, None] + * atom14_atom_exists[..., None, :, None, :] + ).type(fp_type) + + # Mask out all the duplicate entries in the lower triangular matrix. + # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms + # are handled separately. + dists_mask = dists_mask * ( + residue_index[..., :, None, None, None] + < residue_index[..., None, :, None, None] + ) + + # Backbone C--N bond between subsequent residues is no clash. 
+ c_one_hot = torch.nn.functional.one_hot( + residue_index.new_tensor(2), num_classes=14 + ) + c_one_hot = c_one_hot.reshape( + *((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape + ) + c_one_hot = c_one_hot.type(fp_type) + n_one_hot = torch.nn.functional.one_hot( + residue_index.new_tensor(0), num_classes=14 + ) + n_one_hot = n_one_hot.reshape( + *((1,) * len(residue_index.shape[:-1])), *n_one_hot.shape + ) + n_one_hot = n_one_hot.type(fp_type) + + neighbour_mask = ( + residue_index[..., :, None, None, None] + 1 + ) == residue_index[..., None, :, None, None] + c_n_bonds = ( + neighbour_mask + * c_one_hot[..., None, None, :, None] + * n_one_hot[..., None, None, None, :] + ) + dists_mask = dists_mask * (1.0 - c_n_bonds) + + # Disulfide bridge between two cysteines is no clash. + cys = residue_constants.restype_name_to_atom14_names["CYS"] + cys_sg_idx = cys.index("SG") + cys_sg_idx = residue_index.new_tensor(cys_sg_idx) + cys_sg_idx = cys_sg_idx.reshape( + *((1,) * len(residue_index.shape[:-1])), 1 + ).squeeze(-1) + cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14) + disulfide_bonds = ( + cys_sg_one_hot[..., None, None, :, None] + * cys_sg_one_hot[..., None, None, None, :] + ) + dists_mask = dists_mask * (1.0 - disulfide_bonds) + + # Compute the lower bound for the allowed distances. + # shape (N, N, 14, 14) + dists_lower_bound = dists_mask * ( + atom14_atom_radius[..., :, None, :, None] + + atom14_atom_radius[..., None, :, None, :] + ) + + # Compute the error. + # shape (N, N, 14, 14) + dists_to_low_error = dists_mask * torch.nn.functional.relu( + dists_lower_bound - overlap_tolerance_soft - dists + ) + + # Compute the mean loss. + # shape () + mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask)) + + # Compute the per atom loss sum. + # shape (N, 14) + per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum( + dists_to_low_error, axis=(-3, -1) + ) + + # Compute the hard clash mask. 
def within_residue_violations(
    atom14_pred_positions: torch.Tensor,
    atom14_atom_exists: torch.Tensor,
    atom14_dists_lower_bound: torch.Tensor,
    atom14_dists_upper_bound: torch.Tensor,
    tighten_bounds_for_loss=0.0,
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """Loss to penalize steric clashes within residues.

    This is a loss penalizing any steric violations or clashes of non-bonded
    atoms in a given peptide. This loss corresponds to the part with
    the same residues of
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.

    Args:
        atom14_pred_positions ([*, N, 14, 3]):
            Predicted positions of atoms in global prediction frame.
        atom14_atom_exists ([*, N, 14]):
            Mask denoting whether atom at positions exists for given
            amino acid type
        atom14_dists_lower_bound:
            Lower bound on allowed distances. It is combined elementwise
            with the [*, N, 14, 14] intra-residue distance matrix, so it
            must broadcast against that shape.
        atom14_dists_upper_bound:
            Upper bound on allowed distances, same shape requirement.
        tighten_bounds_for_loss:
            Amount by which both bounds are tightened for the loss term only
            (a scalar, or anything broadcastable against the distances).
        eps:
            Small constant keeping the square root finite at zero distance.

    Returns:
        Dict containing:
            * 'per_atom_loss_sum' ([*, N, 14]):
                sum of all violation losses each atom takes part in
            * 'per_atom_violations' ([*, N, 14]):
                mask whether the atom violates a bound with any other atom
                of the same residue
    """
    # Compute the mask for each residue: both atoms must exist and an atom is
    # never paired with itself.
    dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None]
    dists_masks = dists_masks.reshape(
        *((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape
    )
    dists_masks = (
        atom14_atom_exists[..., :, :, None]
        * atom14_atom_exists[..., :, None, :]
        * dists_masks
    )

    # Distance matrix
    dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., :, :, None, :]
                - atom14_pred_positions[..., :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # Compute the loss.
    dists_to_low_error = torch.nn.functional.relu(
        atom14_dists_lower_bound + tighten_bounds_for_loss - dists
    )
    dists_to_high_error = torch.nn.functional.relu(
        dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)
    )
    loss = dists_masks * (dists_to_low_error + dists_to_high_error)

    # Compute the per atom loss sum. Each pair is credited to both of its
    # atoms, once through the row sum and once through the column sum.
    per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)

    # Compute the violations mask. The hard check uses the untightened bounds.
    violations = dists_masks * (
        (dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
    )

    # Compute the per atom violations.
    per_atom_violations = torch.maximum(
        torch.max(violations, dim=-2)[0], torch.max(violations, dim=-1)[0]
    )

    return {
        "per_atom_loss_sum": per_atom_loss_sum,
        "per_atom_violations": per_atom_violations,
    }
+ connection_violations = between_residue_bond_loss( + pred_atom_positions=atom14_pred_positions, + pred_atom_mask=batch["atom14_atom_exists"], + residue_index=batch["residue_index"], + aatype=batch["aatype"], + tolerance_factor_soft=violation_tolerance_factor, + tolerance_factor_hard=violation_tolerance_factor, + ) + + # Compute the Van der Waals radius for every atom + # (the first letter of the atom name is the element type). + # Shape: (N, 14). + atomtype_radius = [ + residue_constants.van_der_waals_radius[name[0]] + for name in residue_constants.atom_types + ] + atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius) + atom14_atom_radius = ( + batch["atom14_atom_exists"] + * atomtype_radius[batch["residx_atom14_to_atom37"]] + ) + + # Compute the between residue clash loss. + between_residue_clashes = between_residue_clash_loss( + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch["atom14_atom_exists"], + atom14_atom_radius=atom14_atom_radius, + residue_index=batch["residue_index"], + overlap_tolerance_soft=clash_overlap_tolerance, + overlap_tolerance_hard=clash_overlap_tolerance, + ) + + # Compute all within-residue violations (clashes, + # bond length and angle violations). 
+ restype_atom14_bounds = residue_constants.make_atom14_dists_bounds( + overlap_tolerance=clash_overlap_tolerance, + bond_length_tolerance_factor=violation_tolerance_factor, + ) + atom14_atom_exists = batch["atom14_atom_exists"] + atom14_dists_lower_bound = atom14_pred_positions.new_tensor( + restype_atom14_bounds["lower_bound"] + )[batch["aatype"]] + atom14_dists_upper_bound = atom14_pred_positions.new_tensor( + restype_atom14_bounds["upper_bound"] + )[batch["aatype"]] + residue_violations = within_residue_violations( + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch["atom14_atom_exists"], + atom14_dists_lower_bound=atom14_dists_lower_bound, + atom14_dists_upper_bound=atom14_dists_upper_bound, + tighten_bounds_for_loss=0.0, + ) + + # Combine them to a single per-residue violation mask (used later for LDDT). + per_residue_violations_mask = torch.max( + torch.stack( + [ + connection_violations["per_residue_violation_mask"], + torch.max( + between_residue_clashes["per_atom_clash_mask"], dim=-1 + )[0], + torch.max(residue_violations["per_atom_violations"], dim=-1)[0], + ], + dim=-1, + ), + dim=-1, + )[0] + + return { + "between_residues": { + "bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"], # () + "angles_ca_c_n_loss_mean": connection_violations[ + "ca_c_n_loss_mean" + ], # () + "angles_c_n_ca_loss_mean": connection_violations[ + "c_n_ca_loss_mean" + ], # () + "connections_per_residue_loss_sum": connection_violations[ + "per_residue_loss_sum" + ], # (N) + "connections_per_residue_violation_mask": connection_violations[ + "per_residue_violation_mask" + ], # (N) + "clashes_mean_loss": between_residue_clashes["mean_loss"], # () + "clashes_per_atom_loss_sum": between_residue_clashes[ + "per_atom_loss_sum" + ], # (N, 14) + "clashes_per_atom_clash_mask": between_residue_clashes[ + "per_atom_clash_mask" + ], # (N, 14) + }, + "within_residues": { + "per_atom_loss_sum": residue_violations[ + "per_atom_loss_sum" + ], # (N, 14) + 
"per_atom_violations": residue_violations[ + "per_atom_violations" + ], # (N, 14), + }, + "total_per_residue_violations_mask": per_residue_violations_mask, # (N) + } + + +def find_structural_violations_np( + batch: Dict[str, np.ndarray], + atom14_pred_positions: np.ndarray, + config: ml_collections.ConfigDict, +) -> Dict[str, np.ndarray]: + to_tensor = lambda x: torch.tensor(x) + batch = tree_map(to_tensor, batch, np.ndarray) + atom14_pred_positions = to_tensor(atom14_pred_positions) + + out = find_structural_violations(batch, atom14_pred_positions, **config) + + to_np = lambda x: np.array(x) + np_out = tensor_tree_map(to_np, out) + + return np_out + + +def extreme_ca_ca_distance_violations( + pred_atom_positions: torch.Tensor, # (N, 37(14), 3) + pred_atom_mask: torch.Tensor, # (N, 37(14)) + residue_index: torch.Tensor, # (N) + max_angstrom_tolerance=1.5, + eps=1e-6, +) -> torch.Tensor: + """Counts residues whose Ca is a large distance from its neighbour. + + Measures the fraction of CA-CA pairs between consecutive amino acids that are + more than 'max_angstrom_tolerance' apart. + + Args: + pred_atom_positions: Atom positions in atom37/14 representation + pred_atom_mask: Atom mask in atom37/14 representation + residue_index: Residue index for given amino acid, this is assumed to be + monotonically increasing. + max_angstrom_tolerance: Maximum distance allowed to not count as violation. + Returns: + Fraction of consecutive CA-CA pairs with violation. 
def compute_violation_metrics(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,  # (N, 14, 3)
    violations: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """Summarise structural violations as masked means over the sequence.

    Args:
        batch: Must provide "atom14_atom_exists", "residue_index" and
            "seq_mask".
        atom14_pred_positions: Predicted atom positions.
        violations: Nested dict with "between_residues", "within_residues"
            and "total_per_residue_violations_mask" entries (presumably the
            output of find_structural_violations).

    Returns:
        Dict with one entry per violation type, each prefixed "violations_".
    """
    ca_ca_metric = extreme_ca_ca_distance_violations(
        pred_atom_positions=atom14_pred_positions,
        pred_atom_mask=batch["atom14_atom_exists"],
        residue_index=batch["residue_index"],
    )

    seq_mask = batch["seq_mask"]
    between = violations["between_residues"]

    bond_metric = masked_mean(
        seq_mask,
        between["connections_per_residue_violation_mask"],
        dim=-1,
    )

    # Per-atom masks are collapsed to per-residue ones: a residue counts as
    # soon as any of its atoms is flagged.
    residue_clashes = torch.max(
        between["clashes_per_atom_clash_mask"], dim=-1
    )[0]
    clash_metric = masked_mean(mask=seq_mask, value=residue_clashes, dim=-1)

    residue_internal = torch.max(
        violations["within_residues"]["per_atom_violations"], dim=-1
    )[0]
    within_metric = masked_mean(mask=seq_mask, value=residue_internal, dim=-1)

    overall_metric = masked_mean(
        mask=seq_mask,
        value=violations["total_per_residue_violations_mask"],
        dim=-1,
    )

    return {
        "violations_extreme_ca_ca_distance": ca_ca_metric,
        "violations_between_residue_bond": bond_metric,
        "violations_between_residue_clash": clash_metric,
        "violations_within_residue": within_metric,
        "violations_per_residue": overall_metric,
    }
def violation_loss(
    violations: Dict[str, torch.Tensor],
    atom14_atom_exists: torch.Tensor,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    """Combine the structural violation terms into a single loss.

    The between-residue and within-residue per-atom losses are summed and
    divided by the number of existing atoms, then added to the three
    between-residue bond and angle means.

    Args:
        violations: Nested dict with "between_residues" and "within_residues"
            entries holding the individual loss terms.
        atom14_atom_exists: Atom existence mask; its sum is the atom count.
        eps: Small constant guarding the division by the atom count.
        **kwargs: Ignored; lets callers splat a larger config dict.

    Returns:
        The combined violation loss.
    """
    between = violations["between_residues"]
    within = violations["within_residues"]

    # Clash term: total per-atom loss normalised by the number of atoms.
    per_atom_total = (
        between["clashes_per_atom_loss_sum"] + within["per_atom_loss_sum"]
    )
    clash_term = torch.sum(per_atom_total) / (
        eps + torch.sum(atom14_atom_exists)
    )

    return (
        between["bonds_c_n_loss_mean"]
        + between["angles_ca_c_n_loss_mean"]
        + between["angles_c_n_ca_loss_mean"]
        + clash_term
    )
+ atom14_pred_positions: Array of atom positions in global frame with shape + Returns: + Dictionary containing: + alt_naming_is_better: Array with 1.0 where alternative swap is better. + renamed_atom14_gt_positions: Array of optimal ground truth positions + after renaming swaps are performed. + renamed_atom14_gt_exists: Mask after renaming swap is performed. + """ + + pred_dists = torch.sqrt( + eps + + torch.sum( + ( + atom14_pred_positions[..., None, :, None, :] + - atom14_pred_positions[..., None, :, None, :, :] + ) + ** 2, + dim=-1, + ) + ) + + atom14_gt_positions = batch["atom14_gt_positions"] + gt_dists = torch.sqrt( + eps + + torch.sum( + ( + atom14_gt_positions[..., None, :, None, :] + - atom14_gt_positions[..., None, :, None, :, :] + ) + ** 2, + dim=-1, + ) + ) + + atom14_alt_gt_positions = batch["atom14_alt_gt_positions"] + alt_gt_dists = torch.sqrt( + eps + + torch.sum( + ( + atom14_alt_gt_positions[..., None, :, None, :] + - atom14_alt_gt_positions[..., None, :, None, :, :] + ) + ** 2, + dim=-1, + ) + ) + + lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2) + alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2) + + atom14_gt_exists = batch["atom14_gt_exists"] + atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"] + mask = ( + atom14_gt_exists[..., None, :, None] + * atom14_atom_is_ambiguous[..., None, :, None] + * atom14_gt_exists[..., None, :, None, :] + * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :]) + ) + + per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3)) + alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3)) + + fp_type = atom14_pred_positions.dtype + alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type) + + renamed_atom14_gt_positions = ( + 1.0 - alt_naming_is_better[..., None, None] + ) * atom14_gt_positions + alt_naming_is_better[ + ..., None, None + ] * atom14_alt_gt_positions + + renamed_atom14_gt_mask = ( + 1.0 - alt_naming_is_better[..., None] + ) * atom14_gt_exists + 
def experimentally_resolved_loss(
    logits: torch.Tensor,
    atom37_atom_exists: torch.Tensor,
    all_atom_mask: torch.Tensor,
    resolution: torch.Tensor,
    min_resolution: float,
    max_resolution: float,
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    """Sigmoid cross-entropy loss for the "experimentally resolved" head.

    Each atom's logit is scored against ``all_atom_mask``. Errors are averaged
    over the atoms flagged in ``atom37_atom_exists``, zeroed for examples
    outside the resolution window, and averaged over the remaining leading
    dimensions.

    Args:
        logits: Per-atom logits (assumed [*, N, 37] -- confirm against the
            head that produces them).
        atom37_atom_exists: Mask of atoms that exist for the residue type;
            must have at least the trailing (residue, atom) axes.
        all_atom_mask: Target mask, same layout as ``logits``.
        resolution: Per-example resolution used to gate the loss.
        min_resolution: Lower bound (inclusive) of the resolution window.
        max_resolution: Upper bound (inclusive) of the resolution window.
        eps: Small constant guarding the division.
        **kwargs: Ignored; lets callers splat a larger config/batch dict.

    Returns:
        Scalar loss tensor.
    """
    errors = sigmoid_cross_entropy(logits, all_atom_mask)
    loss = torch.sum(errors * atom37_atom_exists, dim=-1)

    # The atom count has lost both the residue and atom axes. Restore a
    # trailing axis so it divides the per-residue sums example by example
    # rather than broadcasting along the residue axis.
    denom = eps + torch.sum(atom37_atom_exists, dim=(-1, -2))
    loss = loss / denom[..., None]
    loss = torch.sum(loss, dim=-1)

    # Only examples inside the resolution window contribute.
    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    loss = torch.mean(loss)

    return loss
def compute_drmsd(structure_1, structure_2, mask=None):
    """Distance-RMSD between two point sets.

    Compares the pairwise distance matrix of ``structure_1`` with that of
    ``structure_2`` and returns the root of the mean squared difference over
    the ``n * (n - 1)`` ordered pairs of distinct points.

    Args:
        structure_1: Coordinates indexed as ``[..., point, xyz]``.
        structure_2: Coordinates with the same shape as ``structure_1``.
        mask: Optional per-point mask indexed as ``[..., point]``. A pair
            counts only when both of its points are unmasked. Leading
            (batch) dimensions are supported.

    Returns:
        The dRMSD with the point axis reduced away. It is 0 wherever fewer
        than two points are available.
    """
    delta_1 = structure_1[..., :, None, :] - structure_1[..., None, :, :]
    delta_2 = structure_2[..., :, None, :] - structure_2[..., None, :, :]

    d1 = torch.sqrt(torch.sum(delta_1 ** 2, dim=-1))
    d2 = torch.sqrt(torch.sum(delta_2 ** 2, dim=-1))

    sq_err = (d1 - d2) ** 2

    if mask is None:
        n = d1.shape[-1]
        total = torch.sum(sq_err, dim=(-1, -2))
        total = total * (1 / (n * (n - 1))) if n > 1 else (total * 0.)
        return torch.sqrt(total)

    # Mask pairs, not coordinates. Zeroing coordinates would move masked
    # points to the origin, and their distances to real points would then
    # leak into the sum.
    mask = mask.to(sq_err.dtype)
    pair_mask = mask[..., :, None] * mask[..., None, :]
    total = torch.sum(sq_err * pair_mask, dim=(-1, -2))

    # n is a tensor here (one count per batch element), so the "fewer than
    # two points" case is selected elementwise, not with a Python `if`.
    n = torch.sum(mask, dim=-1)
    n_pairs = n * (n - 1)
    scale = torch.where(
        n > 1,
        1.0 / torch.clamp(n_pairs, min=1.0),
        torch.zeros_like(n_pairs),
    )
    return torch.sqrt(total * scale)
class ScoreMatchingLoss(nn.Module):
    """Aggregation of the various losses described in the supplement.

    The translation and rotation score-matching terms are always computed.
    Every auxiliary loss whose ``config.<name>.enabled`` flag is set is added
    on top, and each term is multiplied by ``config.<name>.weight`` before
    being summed into the total.
    """

    def __init__(self, config):
        """Store the loss configuration.

        Args:
            config: Loss configuration. ``forward`` reads ``eps``, the
                ``translation`` settings (``coordinate_scaling``,
                ``x0_threshold``) and, for each loss name, ``enabled`` /
                ``weight`` plus the loss-specific keyword arguments.
        """
        super().__init__()
        self.config = config  # config.loss

    def forward(self, out, batch, _return_breakdown=False):
        """Compute the weighted total loss.

        Args:
            out: Model outputs. The score terms read ``rot_score``,
                ``trans_score`` and ``rigids``; enabled auxiliary losses read
                further keys (e.g. ``atom37``, ``distogram_logits``).
            batch: Ground-truth features. The score terms read ``seq_mask``,
                ``fixed_mask``, ``rot_score``, ``trans_score``,
                ``rot_score_scaling``, ``trans_score_scaling``, ``rigids_0``
                and ``t``.
            _return_breakdown: When true, also return the per-term values.

        Returns:
            The total loss, or ``(total, losses)`` when ``_return_breakdown``
            is set, where ``losses`` maps each loss name (plus ``"loss"`` for
            the total) to a detached copy of its value.
        """
        # Configure masks.
        seq_mask = batch['seq_mask']
        diffuse_mask = 1. - batch['fixed_mask']
        loss_mask = seq_mask * diffuse_mask  # (B, L)
        _denom = sum_except_batch(loss_mask) + self.config.eps  # normalizing sum

        ###
        # Score Matching loss.
        ###
        pred_rot_score = out['rot_score'] * diffuse_mask[..., None]
        pred_trans_score = out['trans_score'] * diffuse_mask[..., None]
        gt_rot_score = batch['rot_score'] * diffuse_mask[..., None]
        gt_trans_score = batch['trans_score'] * diffuse_mask[..., None]
        # Translation component.
        trans_score_loss = (gt_trans_score - pred_trans_score) * loss_mask[..., None]
        trans_score_loss /= inflate_array_like(batch['trans_score_scaling'], trans_score_loss)
        trans_score_loss = torch.sum(trans_score_loss**2, dim=(-1, -2)) / _denom
        # Alternative x0 loss.
        trans_x0_loss = (
            self.config.translation.coordinate_scaling
            * (batch['rigids_0'].get_trans() - out['rigids'].get_trans())
            * loss_mask[..., None]
        )
        trans_x0_loss = torch.sum(trans_x0_loss**2, dim=(-1, -2)) / _denom
        # Per sample, use the score loss above the time threshold and the
        # x0 (coordinate) loss at or below it.
        trans_loss = torch.mean(
            trans_score_loss * (batch['t'] > self.config.translation.x0_threshold)
            + trans_x0_loss * (batch['t'] <= self.config.translation.x0_threshold)
        )
        # Rotation component.
        rot_loss = (gt_rot_score - pred_rot_score) * loss_mask[..., None]
        rot_loss /= inflate_array_like(batch['rot_score_scaling'], rot_loss)
        rot_loss = torch.mean(torch.sum(rot_loss**2, dim=(-1, -2)) / _denom)

        # Losses are wrapped in thunks so that disabled terms are never
        # evaluated and every term is computed inside the loop below.
        loss_fns = {
            "translation": lambda: trans_loss,
            "rotation": lambda: rot_loss,
        }

        # Auxiliary folding loss.
        if self.config.distogram.enabled:
            loss_fns["distogram"] = lambda: distogram_loss(
                logits=out["distogram_logits"],
                **{**batch, **self.config.distogram},
            )
        if self.config.supervised_chi.enabled:
            loss_fns["supervised_chi"] = lambda: supervised_chi_loss(
                out["sm"]["angles"],
                out["sm"]["unnormalized_angles"],
                **{**batch, **self.config.supervised_chi},
            )
        if self.config.lddt.enabled:
            loss_fns["lddt"] = lambda: lddt_loss(
                logits=out["lddt_logits"],
                all_atom_pred_pos=out["final_atom_positions"],
                **{**batch, **self.config.lddt},
            )
        if self.config.fape.enabled:
            loss_fns["fape"] = lambda: fape_loss(
                out,
                batch,
                self.config.fape,
            )
        if self.config.tm.enabled:
            loss_fns["tm"] = lambda: tm_loss(
                logits=out["tm_logits"],
                **{**batch, **out, **self.config.tm},
            )
        if self.config.backbone.enabled:
            loss_fns["backbone"] = lambda: backbone_atom_loss(
                pred_atom37=out["atom37"],
                batch=batch,
                mask=loss_mask,
                **self.config.backbone,
            )
        if self.config.pwd.enabled:
            loss_fns["pwd"] = lambda: pairwise_distance_loss(
                pred_atom37=out["atom37"],
                batch=batch,
                mask=loss_mask,
                **self.config.pwd,
            )

        cum_loss = 0.
        losses = {}
        for loss_name, loss_fn in loss_fns.items():
            weight = self.config[loss_name].weight
            loss = loss_fn()
            # A non-finite term (NaN or +/-Inf) is replaced by zero so it
            # cannot poison the total; the warning names both cases.
            if not torch.isfinite(loss):
                logging.warning(f"{loss_name} loss is NaN or Inf. Skipping...")
                loss = loss.new_tensor(0., requires_grad=True)
            cum_loss = cum_loss + weight * loss
            losses[loss_name] = loss.detach().clone()

        # losses["unscaled_loss"] = cum_loss.detach().clone()

        # Scale the loss by the square root of the minimum of the crop size and
        # the (average) sequence length. See subsection 1.9.
        # seq_len = torch.mean(batch["seq_length"].float())
        # crop_len = batch["aatype"].shape[-1]
        # cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))

        losses["loss"] = cum_loss.detach().clone()

        if not _return_breakdown:
            return cum_loss

        return cum_loss, losses
b/analysis/src/models/net/__pycache__/ipa.cpython-39.pyc differ diff --git a/analysis/src/models/net/__pycache__/layers.cpython-39.pyc b/analysis/src/models/net/__pycache__/layers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7173fb6bc524cad8068b4f18e5ae180751d69da Binary files /dev/null and b/analysis/src/models/net/__pycache__/layers.cpython-39.pyc differ diff --git a/analysis/src/models/net/denoising_ipa.py b/analysis/src/models/net/denoising_ipa.py new file mode 100644 index 0000000000000000000000000000000000000000..a954a7b62ca3db1498e4c8bb492c80c3920d2153 --- /dev/null +++ b/analysis/src/models/net/denoising_ipa.py @@ -0,0 +1,211 @@ +import math +from functools import partial + +import torch +from torch import nn +from torch.nn import functional as F + +from src.common.all_atom import compute_backbone +from src.common.geo_utils import calc_distogram +from src.models.net.ipa import TranslationIPA + + +def get_positional_embedding(indices, embedding_dim, max_len=2056): + """Creates sine / cosine positional embeddings from a prespecified indices. + + Args: + indices: offsets of size [..., N_edges] of type integer + max_len: maximum length. 
def get_timestep_embedding(timesteps, embedding_dim, max_len=10000):
    """Build sinusoidal embeddings for a 1-D batch of timesteps.

    Code from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py

    The timesteps are multiplied by ``max_len`` and then by ``half_dim``
    geometrically spaced frequencies running from ``1`` down to
    ``1 / max_len``; the sines and cosines of those products are
    concatenated. An odd ``embedding_dim`` gets one trailing zero column.

    Args:
        timesteps: 1-D tensor of timesteps.
        embedding_dim: Size of the last axis of the returned embedding.
        max_len: Timestep multiplier and base of the frequency spacing.

    Returns:
        Float tensor of shape ``(len(timesteps), embedding_dim)``.
    """
    assert len(timesteps.shape) == 1
    timesteps = timesteps * max_len
    half_dim = embedding_dim // 2
    # With a single frequency (embedding_dim of 2 or 3) ``half_dim - 1`` is
    # zero; clamp the divisor so that case yields frequency 1 instead of
    # raising ZeroDivisionError. Larger sizes are unaffected.
    log_step = math.log(max_len) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * -log_step
    )
    emb = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
used in RFDiffusion + self.self_conditioning = self_conditioning + if self_conditioning: + edge_in_dim += num_bins + + self.edge_embed = nn.Sequential( + nn.Linear(edge_in_dim, edge_embed_size), + nn.ReLU(), + nn.Linear(edge_embed_size, edge_embed_size), + nn.ReLU(), + nn.Linear(edge_embed_size, edge_embed_size), + nn.LayerNorm(edge_embed_size), + ) + + self.time_embed = partial( + get_timestep_embedding, embedding_dim=t_embed_size + ) + self.position_embed = partial( + get_positional_embedding, embedding_dim=pos_embed_size + ) + self.distogram_embed = partial( + calc_distogram, + min_bin=min_bin, + max_bin=max_bin, + num_bins=num_bins, + ) + + def forward( + self, + residue_idx, + t, + fixed_mask, + self_conditioning_ca, + ): + """ + Args: + residue_idx: [..., N] Positional sequence index for each residue. + t: Sampled t in [0, 1]. + fixed_mask: mask of fixed (motif) residues. + self_conditioning_ca: [..., N, 3] Ca positions of self-conditioning + input. + + Returns: + node_embed: [B, N, D_node] + edge_embed: [B, N, N, D_edge] + """ + B, L = residue_idx.shape + fixed_mask = fixed_mask[..., None].float() + node_feats = [] + pair_feats = [] + + # configure time embedding + t_embed = torch.tile(self.time_embed(t)[:, None, :], (1, L, 1)) + t_embed = torch.cat([t_embed, fixed_mask], dim=-1) + node_feats.append(t_embed) + + # make pair embedding from 1d time feats + concat_1d = torch.cat( + [torch.tile(t_embed[:, :, None, :], (1, 1, L, 1)), + torch.tile(t_embed[:, None, :, :], (1, L, 1, 1))], + dim=-1).float().reshape([B, L**2, -1]) + pair_feats.append(concat_1d) + + # positional embedding + node_feats.append(self.position_embed(residue_idx)) + + # relative 2d positional embedding + rel_seq_offset = residue_idx[:, :, None] - residue_idx[:, None, :] + rel_seq_offset = rel_seq_offset.reshape([B, L**2]) + pair_feats.append(self.position_embed(rel_seq_offset)) + + # self-conditioning distogram of C-alpha atoms + if self.self_conditioning: + ca_dist = 
class DenoisingNet(nn.Module):
    """Denoising network: an embedding module followed by a frame translator.

    ``forward`` embeds the batch, lets the translator predict updated rigid
    frames and a psi torsion, and rebuilds backbone atom coordinates from
    those predictions.
    """

    def __init__(self,
                 embedder: nn.Module,
                 translator: nn.Module,
                 ):
        super().__init__()
        self.embedder = embedder  # embedding module
        self.translator = translator  # translationIPA

    def forward(self, batch, as_tensor_7=False):
        """Predict denoised frames, psi angles and backbone atoms.

        Args:
            batch: Feature mapping. Reads ``residue_mask``, ``fixed_mask``,
                ``residue_idx``, ``t``, ``sc_ca_t``,
                ``torsion_angles_sin_cos`` and, when present, ``aatype``.
            as_tensor_7: When true, the returned ``rigids`` entry is
                converted with ``to_tensor_7()``.

        Returns:
            Dict with keys ``rigids``, ``psi``, ``atom37`` and ``atom14``.
        """
        residue_mask = batch['residue_mask'].type(torch.float)  # [B, N]
        fixed = batch['fixed_mask'].type(torch.float)
        pair_mask = residue_mask[..., None] * residue_mask[..., None, :]

        # Embed nodes and edges, then blank out everything that is padding.
        node_embed, edge_embed = self.embedder(
            residue_idx=batch['residue_idx'],
            t=batch['t'],
            fixed_mask=fixed,
            self_conditioning_ca=batch['sc_ca_t'],
        )
        node_embed = node_embed * residue_mask[..., None]  # (L, D)
        edge_embed = edge_embed * pair_mask[..., None]  # (L, L, D)

        # Frame translation.
        translated = self.translator(node_embed, edge_embed, batch)

        # Psi angle: ground truth on fixed residues, prediction elsewhere.
        fixed_col = fixed[..., None]
        reference_psi = batch['torsion_angles_sin_cos'][..., 2, :]
        psi = reference_psi * fixed_col + translated['psi'] * (1 - fixed_col)
        rigids = translated['out_rigids']

        aatype = batch['aatype'] if 'aatype' in batch else None
        backbone = compute_backbone(rigids, psi, aatype=aatype)
        # First and last entries of the returned tuple are used as the
        # atom37 and atom14 coordinates respectively.
        target_device = rigids.device
        atom37 = backbone[0].to(target_device)
        atom14 = backbone[-1].to(target_device)

        if as_tensor_7:
            rigids = rigids.to_tensor_7()

        return {
            'rigids': rigids,
            'psi': psi,
            'atom37': atom37,
            'atom14': atom14,
        }
+ """ + def __init__( + self, + c_s: int, + c_z: int, + c_hidden: int, + no_heads: int, + no_qk_points: int, + no_v_points: int, + inf: float = 1e5, + eps: float = 1e-8, + ): + """ + Args: + c_s: + Single representation channel dimension + c_z: + Pair representation channel dimension + c_hidden: + Hidden channel dimension + no_heads: + Number of attention heads + no_qk_points: + Number of query/key points to generate + no_v_points: + Number of value points to generate + """ + super(InvariantPointAttention, self).__init__() + + self.c_s = c_s + self.c_z = c_z + self.c_hidden = c_hidden + self.no_heads = no_heads + self.no_qk_points = no_qk_points + self.no_v_points = no_v_points + self.inf = inf + self.eps = eps + + # These linear layers differ from their specifications in the + # supplement. There, they lack bias and use Glorot initialization. + # Here as in the official source, they have bias and use the default + # Lecun initialization. + hc = self.c_hidden * self.no_heads + self.linear_q = Linear(self.c_s, hc) + self.linear_kv = Linear(self.c_s, 2 * hc) + + hpq = self.no_heads * self.no_qk_points * 3 + self.linear_q_points = Linear(self.c_s, hpq) + + hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3 + self.linear_kv_points = Linear(self.c_s, hpkv) + + self.linear_b = Linear(self.c_z, self.no_heads) + self.down_z = Linear(self.c_z, self.c_z // 4) + + self.head_weights = nn.Parameter(torch.zeros((self.no_heads))) + ipa_point_weights_init_(self.head_weights) + + concat_out_dim = ( + self.c_z // 4 + self.c_hidden + self.no_v_points * 4 + ) + self.linear_out = Linear(self.no_heads * concat_out_dim, self.c_s, init="final") + + self.softmax = nn.Softmax(dim=-1) + self.softplus = nn.Softplus() + + def forward( + self, + s: torch.Tensor, + z: Optional[torch.Tensor], + r: Rigid, + mask: torch.Tensor, + _offload_inference: bool = False, + _z_reference_list: Optional[Sequence[torch.Tensor]] = None, + ) -> torch.Tensor: + """ + Args: + s: + [*, N_res, C_s] 
single representation + z: + [*, N_res, N_res, C_z] pair representation + r: + [*, N_res] transformation object + mask: + [*, N_res] mask + Returns: + [*, N_res, C_s] single representation update + """ + if _offload_inference: + z = _z_reference_list + else: + z = [z] + + ####################################### + # Generate scalar and point activations + ####################################### + # [*, N_res, H * C_hidden] + q = self.linear_q(s) + kv = self.linear_kv(s) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, 2 * C_hidden] + kv = kv.view(kv.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, C_hidden] + k, v = torch.split(kv, self.c_hidden, dim=-1) + + # [*, N_res, H * P_q * 3] + q_pts = self.linear_q_points(s) + + # This is kind of clunky, but it's how the original does it + # [*, N_res, H * P_q, 3] + q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) + q_pts = torch.stack(q_pts, dim=-1) + q_pts = r[..., None].apply(q_pts) + + # [*, N_res, H, P_q, 3] + q_pts = q_pts.view( + q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3) + ) + + # [*, N_res, H * (P_q + P_v) * 3] + kv_pts = self.linear_kv_points(s) + + # [*, N_res, H * (P_q + P_v), 3] + kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) + kv_pts = torch.stack(kv_pts, dim=-1) + kv_pts = r[..., None].apply(kv_pts) + + # [*, N_res, H, (P_q + P_v), 3] + kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3)) + + # [*, N_res, H, P_q/P_v, 3] + k_pts, v_pts = torch.split( + kv_pts, [self.no_qk_points, self.no_v_points], dim=-2 + ) + + ########################## + # Compute attention scores + ########################## + # [*, N_res, N_res, H] + b = self.linear_b(z[0]) + + if(_offload_inference): + z[0] = z[0].cpu() + + # [*, H, N_res, N_res] + a = torch.matmul( + permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + a *= math.sqrt(1.0 / (3 * self.c_hidden)) 
+ a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))) + + # [*, N_res, N_res, H, P_q, 3] + pt_displacement = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + pt_att = pt_displacement ** 2 + + # [*, N_res, N_res, H, P_q] + pt_att = sum(torch.unbind(pt_att, dim=-1)) + head_weights = self.softplus(self.head_weights).view( + *((1,) * len(pt_att.shape[:-2]) + (-1, 1)) + ) + head_weights = head_weights * math.sqrt( + 1.0 / (3 * (self.no_qk_points * 9.0 / 2)) + ) + pt_att = pt_att * head_weights + + # [*, N_res, N_res, H] + pt_att = torch.sum(pt_att, dim=-1) * (-0.5) + # [*, N_res, N_res] + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + square_mask = self.inf * (square_mask - 1) + + # [*, H, N_res, N_res] + pt_att = permute_final_dims(pt_att, (2, 0, 1)) + + a = a + pt_att + a = a + square_mask.unsqueeze(-3) + a = self.softmax(a) + + ################ + # Compute output + ################ + # [*, N_res, H, C_hidden] + o = torch.matmul( + a, v.transpose(-2, -3).to(dtype=a.dtype) + ).transpose(-2, -3) + + # [*, N_res, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, H, 3, N_res, P_v] + o_pt = torch.sum( + ( + a[..., None, :, :, None] + * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :] + ), + dim=-2, + ) + + # [*, N_res, H, P_v, 3] + o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) + o_pt = r[..., None, None].invert_apply(o_pt) + + # [*, N_res, H * P_v] + o_pt_dists = torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps) + o_pt_norm_feats = flatten_final_dims( + o_pt_dists, 2) + + # [*, N_res, H * P_v, 3] + o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3) + + if(_offload_inference): + z[0] = z[0].to(o_pt.device) + + # [*, N_res, H, C_z // 4] + pair_z = self.down_z(z[0]).to(dtype=a.dtype) + o_pair = torch.matmul(a.transpose(-2, -3), pair_z) + + # [*, N_res, H * C_z // 4] + o_pair = flatten_final_dims(o_pair, 2) + + o_feats = [o, *torch.unbind(o_pt, dim=-1), o_pt_norm_feats, o_pair] + + # [*, N_res, C_s] + s = self.linear_out( + torch.cat( + o_feats, dim=-1 + 
class TranslationIPA(nn.Module):
    """Stack of IPA + transformer blocks that iteratively updates rigid frames.

    Each block runs invariant point attention, a sequence transformer over
    the node embeddings (concatenated with a projection of the initial
    embeddings), a node transition MLP and a backbone update that is composed
    onto the current frames. All blocks except the last also update the edge
    embeddings. A torsion head predicts one angle from the final node
    embeddings.
    """

    def __init__(self,
                 c_s: int,
                 c_z: int,
                 coordinate_scaling: float,
                 no_ipa_blocks: int,
                 skip_embed_size: int,
                 transformer_num_heads: int = 4,
                 transformer_num_layers: int = 2,
                 c_hidden: int = 256,
                 no_heads: int = 8,
                 no_qk_points: int = 8,
                 no_v_points: int = 12,
                 dropout: float = 0.0,
                 ):
        """
        Args:
            c_s: Node (single) embedding size.
            c_z: Edge (pair) embedding size.
            coordinate_scaling: Factor applied to frame translations before
                the trunk and divided out again afterwards.
            no_ipa_blocks: Number of trunk blocks.
            skip_embed_size: Size of the projected initial node embedding
                concatenated before each transformer.
            transformer_num_heads: Attention heads per transformer layer.
            transformer_num_layers: Layers per sequence transformer.
            c_hidden: IPA hidden channel size.
            no_heads: Number of IPA heads.
            no_qk_points: Number of IPA query/key points.
            no_v_points: Number of IPA value points.
            dropout: Accepted for interface compatibility but not used in
                this class; the transformer layers keep PyTorch's own
                default dropout.
        """
        super(TranslationIPA, self).__init__()

        self.scale_pos = lambda x: x * coordinate_scaling
        self.scale_rigids = lambda x: x.apply_trans_fn(self.scale_pos)

        self.unscale_pos = lambda x: x / coordinate_scaling
        self.unscale_rigids = lambda x: x.apply_trans_fn(self.unscale_pos)
        self.trunk = nn.ModuleDict()
        self.num_blocks = no_ipa_blocks

        for b in range(no_ipa_blocks):
            self.trunk[f'ipa_{b}'] = InvariantPointAttention(
                c_s=c_s,
                c_z=c_z,
                c_hidden=c_hidden,
                no_heads=no_heads,
                no_qk_points=no_qk_points,
                no_v_points=no_v_points,
            )
            self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(c_s)
            self.trunk[f'skip_embed_{b}'] = Linear(
                c_s,
                skip_embed_size,
                init="final"
            )
            _in_dim = c_s + skip_embed_size
            transformer_layer = nn.TransformerEncoderLayer(
                d_model=_in_dim,
                nhead=transformer_num_heads,
                dim_feedforward=_in_dim,
            )
            self.trunk[f'transformer_{b}'] = nn.TransformerEncoder(transformer_layer, transformer_num_layers)
            self.trunk[f'linear_{b}'] = Linear(_in_dim, c_s, init="final")
            self.trunk[f'node_transition_{b}'] = NodeTransition(c_s)
            self.trunk[f'bb_update_{b}'] = BackboneUpdate(c_s)

            if b < self.num_blocks - 1:  # No need edge update on the last block
                self.trunk[f'edge_transition_{b}'] = EdgeTransition(
                    node_embed_size=c_s,
                    edge_embed_in=c_z,
                    edge_embed_out=c_z,
                )
        self.torsion_pred = TorsionAngleHead(c_s, 1)
        # self.chi_pred = TorsionAngleHead(c_s, 4)

    def forward(self, node_embed, edge_embed, batch):
        """Run the trunk and return the initial and updated frames plus psi.

        Args:
            node_embed: Node embeddings, batch-first.
            edge_embed: Edge embeddings, batch-first.
            batch: Feature mapping. Reads ``residue_mask``, ``fixed_mask``
                and ``rigids_t`` (frames in 7-number tensor form).

        Returns:
            Dict with ``in_rigids`` (input frames), ``out_rigids`` (updated
            frames, translations back in the input scale) and ``psi``.
        """
        node_mask = batch['residue_mask'].type(torch.float)
        diffuse_mask = (1 - batch['fixed_mask'].type(torch.float)) * node_mask
        edge_mask = node_mask[..., None] * node_mask[..., None, :]

        # Boolean padding mask for the sequence transformers: True marks a
        # position to be ignored. A floating-point mask would instead be
        # *added* to the attention logits, which does not hide padding.
        # NOTE: a sample whose residues are all padding would make the
        # attention softmax undefined for that sample.
        padding_mask = ~node_mask.bool()

        init_frames = batch['rigids_t'].type(torch.float)
        curr_rigids = Rigid.from_tensor_7(torch.clone(init_frames))
        init_rigids = Rigid.from_tensor_7(init_frames)
        curr_rigids = self.scale_rigids(curr_rigids)

        # Main trunk
        init_node_embed = node_embed
        for b in range(self.num_blocks):
            ipa_embed = self.trunk[f'ipa_{b}'](
                node_embed,
                edge_embed,
                curr_rigids,
                node_mask
            )
            ipa_embed *= node_mask[..., None]
            node_embed = self.trunk[f'ipa_ln_{b}'](node_embed + ipa_embed)
            # MHA + FF
            concat_node_embed = torch.cat([
                node_embed, self.trunk[f'skip_embed_{b}'](init_node_embed)
            ], dim=-1)
            # The encoder layers are built without ``batch_first``, so they
            # take sequence-first input; swap axes on the way in and out.
            concat_node_embed = torch.transpose(concat_node_embed, 0, 1)
            transformed_embed = self.trunk[f'transformer_{b}'](
                concat_node_embed, src_key_padding_mask=padding_mask
            )
            transformed_embed = torch.transpose(transformed_embed, 0, 1)
            # residual
            node_embed = node_embed + self.trunk[f'linear_{b}'](transformed_embed)

            # MLP
            node_embed = self.trunk[f'node_transition_{b}'](node_embed)
            # apply mask
            node_embed = node_embed * node_mask[..., None]
            # backbone update (only non-fixed, non-padding residues move)
            rigid_update = self.trunk[f'bb_update_{b}'](node_embed * diffuse_mask[..., None])
            # apply update
            curr_rigids = curr_rigids.compose_q_update_vec(rigid_update, diffuse_mask[..., None])

            if b < self.num_blocks - 1:
                edge_embed = self.trunk[f'edge_transition_{b}'](node_embed, edge_embed) * edge_mask[..., None]

        # torsion prediction
        psi_pred = self.torsion_pred(node_embed)
        # chi_pred = self.chi_pred(node_embed)

        # scale back
        curr_rigids = self.unscale_rigids(curr_rigids)

        model_out = {
            'in_rigids': init_rigids,
            'out_rigids': curr_rigids,
            'psi': psi_pred,
            # 'chi': chi_pred,
        }
        return model_out
class Linear(nn.Linear):
    """``nn.Linear`` with a selectable, non-default initialisation scheme."""

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
    ):
        """
        Args:
            in_dim:
                Size of the last axis of the layer input.
            out_dim:
                Size of the last axis of the layer output.
            bias:
                Whether the layer has an additive bias (on by default).
            init:
                Name of the weight initializer, one of:

                "default": LeCun fan-in truncated normal initialization
                "relu": He initialization w/ truncated normal distribution
                "glorot": Fan-average Glorot uniform initialization
                "gating": Weights=0, Bias=1
                "normal": Normal initialization with std=1/sqrt(fan_in)
                "final": Weights=0, Bias=0

                Ignored when ``init_fn`` is given.
            init_fn:
                Custom initializer called as ``init_fn(weight, bias)``;
                when provided it replaces the scheme named by ``init``.

        Raises:
            ValueError: If ``init_fn`` is None and ``init`` is not one of
                the names listed above.
        """
        super().__init__(in_dim, out_dim, bias=bias)

        # Whatever scheme follows, the bias starts out at zero.
        if bias:
            with torch.no_grad():
                self.bias.fill_(0)

        # A caller-supplied initializer wins over the named schemes.
        if init_fn is not None:
            init_fn(self.weight, self.bias)
            return

        if init == "gating":
            gating_init_(self.weight)
            # Gating layers are the one scheme with a non-zero bias.
            if bias:
                with torch.no_grad():
                    self.bias.fill_(1.0)
        elif init == "final":
            final_init_(self.weight)
        elif init == "normal":
            normal_init_(self.weight)
        elif init == "glorot":
            glorot_uniform_init_(self.weight)
        elif init == "relu":
            he_normal_init_(self.weight)
        elif init == "default":
            lecun_normal_init_(self.weight)
        else:
            raise ValueError("Invalid init string.")
class TorsionAngleHead(nn.Module):
    """Predict torsion angles as normalised vectors from node embeddings."""

    def __init__(self, in_dim, n_torsion_angles, eps=1e-8):
        """
        Args:
            in_dim: Size of the last axis of the input embedding.
            n_torsion_angles: Number of angles; the output has twice as many
                channels.
            eps: Lower bound applied to the squared norm before the square
                root, guarding the division in ``forward``.
        """
        super().__init__()

        self.linear_1 = Linear(in_dim, in_dim, init="relu")
        self.linear_2 = Linear(in_dim, in_dim, init="relu")
        # NOTE(review): linear_3 is registered here but forward() never
        # calls it; it is left in place so the module's parameters are
        # unchanged.
        self.linear_3 = Linear(in_dim, in_dim, init="final")
        self.linear_final = Linear(in_dim, n_torsion_angles * 2, init="final")
        self.relu = nn.ReLU()
        self.eps = eps

    def forward(self, s):
        """Return the output vector divided by its (clamped) L2 norm.

        The norm is taken over the whole last axis of the projection.
        """
        residual = s
        # Two linear layers with a ReLU in between, plus a skip connection.
        hidden = self.linear_2(self.relu(self.linear_1(s)))
        hidden = hidden + residual

        raw = self.linear_final(hidden)
        squared_norm = torch.sum(raw ** 2, dim=-1, keepdim=True)
        return raw / torch.sqrt(torch.clamp(squared_norm, min=self.eps))
+ """ + + def __init__(self, c_s): + """ + Args: + c_s: + Single representation channel dimension + """ + super(BackboneUpdate, self).__init__() + + self.c_s = c_s + self.linear = Linear(self.c_s, 6, init="final") + + def forward(self, s: torch.Tensor): + """ + Args: + [*, N_res, C_s] single representation + Returns: + [*, N_res, 6] update vector + """ + # [*, 6] + update = self.linear(s) + return update \ No newline at end of file diff --git a/analysis/src/models/score/__init__.py b/analysis/src/models/score/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/analysis/src/models/score/__pycache__/__init__.cpython-39.pyc b/analysis/src/models/score/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04258e1ce8dd14ffd1def9b52a53e41cbdce88f4 Binary files /dev/null and b/analysis/src/models/score/__pycache__/__init__.cpython-39.pyc differ diff --git a/analysis/src/models/score/__pycache__/frame.cpython-39.pyc b/analysis/src/models/score/__pycache__/frame.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..643ca419ca118df819762549957054e13764d3d0 Binary files /dev/null and b/analysis/src/models/score/__pycache__/frame.cpython-39.pyc differ diff --git a/analysis/src/models/score/__pycache__/r3.cpython-39.pyc b/analysis/src/models/score/__pycache__/r3.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..472954bc6f4c67f7282cb6c96da062cdef60c9af Binary files /dev/null and b/analysis/src/models/score/__pycache__/r3.cpython-39.pyc differ diff --git a/analysis/src/models/score/__pycache__/so3.cpython-39.pyc b/analysis/src/models/score/__pycache__/so3.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78eecc5862cbaaaa6aec9445e2a967d6e6d364f5 Binary files /dev/null and b/analysis/src/models/score/__pycache__/so3.cpython-39.pyc differ diff --git 
def apply_mask(x_tgt, x_src, tgt_mask):
    """Blend two tensors: ``x_tgt`` where the mask is 1, ``x_src`` where it is 0.

    Args:
        x_tgt: Values taken where ``tgt_mask`` is set.
        x_src: Values taken where ``tgt_mask`` is clear.
        tgt_mask: Mask broadcastable against the inputs. Numeric masks are
            used as blending weights (values between 0 and 1 interpolate);
            boolean tensors are accepted and treated as 0/1.

    Returns:
        ``tgt_mask * x_tgt + (1 - tgt_mask) * x_src``.
    """
    if isinstance(tgt_mask, torch.Tensor) and tgt_mask.dtype == torch.bool:
        # ``1 - mask`` is not defined for boolean tensors, so convert the
        # mask to the dtype of the data before doing arithmetic with it.
        dtype = x_tgt.dtype if isinstance(x_tgt, torch.Tensor) else torch.get_default_dtype()
        tgt_mask = tgt_mask.to(dtype=dtype)
    return tgt_mask * x_tgt + (1 - tgt_mask) * x_src
rotation score norm + """ + output = {} + rot_0 = rotation3d.matrix_to_axis_angle(rigids_0.get_rots().get_rot_mats()) + trans_0 = rigids_0.get_trans() + + if self.rot_diffuser is None: + rot_t = rot_0 + rot_score, rot_score_scaling = torch.zeros_like(rot_0), t + else: + rot_t, rot_score = self.rot_diffuser.forward_marginal(rot_0, t) + rot_score_scaling = self.rot_diffuser.score_scaling(t) + + if self.trans_diffuser is None: + trans_t, trans_score, trans_score_scaling = ( + trans_0, + torch.zeros_like(trans_0), + torch.ones_like(t) + ) + else: + trans_t, trans_score = self.trans_diffuser.forward_marginal(trans_0, t) + trans_score_scaling = self.trans_diffuser.score_scaling(t) + + # Perturb only a subset of residues + if diffuse_mask is not None: + diffuse_mask = torch.as_tensor(diffuse_mask, device=trans_t.device, dtype=trans_t.dtype)[..., None] + + rot_t = apply_mask(rot_t, rot_0, diffuse_mask) + trans_t = apply_mask(trans_t, trans_0, diffuse_mask) + + trans_score = apply_mask( + trans_score, + torch.zeros_like(trans_score), + diffuse_mask + ) + rot_score = apply_mask( + rot_score, + torch.zeros_like(rot_score), + diffuse_mask + ) + + rigids_t = assemble_rigid(rot_t, trans_t) + + if as_tensor_7: + rigids_t = rigids_t.to_tensor_7() + + output = { + 'rigids_t': rigids_t, + 'trans_score': trans_score, + 'rot_score': rot_score, + 'trans_score_scaling': trans_score_scaling, + 'rot_score_scaling': rot_score_scaling, + } + return output + + def score( + self, + rigids_0: Rigid, + rigids_t: Rigid, + t: torch.Tensor, + mask: torch.Tensor = None, + ): + rot_0, trans_0 = rigids_0.get_rots(), rigids_0.get_trans() + rot_t, trans_t = rigids_t.get_rots(), rigids_t.get_trans() + + if self.rot_diffuser is None: + rot_score = torch.zeros_like(rot_0) + else: + rot_0_inv = rot_0.invert() + quat_0_inv = rotation3d.matrix_to_quaternion(rot_0_inv.get_rot_mats()) + quat_t = rotation3d.matrix_to_quaternion(rot_t.get_rot_mats()) + # get relative rotation + quat_0t = 
quat_multiply(quat_0_inv, quat_t) + rotvec_0t = rotation3d.quaternion_to_axis_angle(quat_0t) + # calculate score + rot_score = self.rot_diffuser.score(rotvec_0t, t) + + if self.trans_diffuser is None: + trans_score = torch.zeros_like(trans_0) + else: + trans_score = self.trans_diffuser.score(trans_t, trans_0, t, scale=True) + + if mask is not None: + trans_score = trans_score * mask[..., None] + rot_score = rot_score * mask[..., None] + + return { + 'trans_score': trans_score, + 'rot_score': rot_score + } + + def score_scaling(self, t): + rot_score_scaling = self.rot_diffuser.score_scaling(t) + trans_score_scaling = self.trans_diffuser.score_scaling(t) + return { + 'trans_score_scaling': trans_score_scaling, + 'rot_score_scaling': rot_score_scaling, + } + + def reverse( + self, + rigids_t: Rigid, + rot_score: torch.Tensor, + trans_score: torch.Tensor, + t: torch.Tensor, + dt: float, + diffuse_mask: torch.Tensor = None, + center_trans: bool = True, + noise_scale: float = 1.0, + probability_flow: bool = True, + ): + """Reverse sampling function from (t) to (t-1). + + Args: + rigids_t: [..., N] protein rigid objects at time t. + rot_score: [..., N, 3] rotation score. + trans_score: [..., N, 3] translation score. + t: continuous time in [0, 1]. + dt: continuous step size in [0, 1]. + mask: [..., N] which residues to update. + center_trans: true to set center of mass to zero after step + probability_flow: whether to use probability flow ODE. + + Returns: + rigids_t_1: [..., N] protein rigid objects at time t-1. 
+ """ + # extract rot and trans as tensors + rot_t = rotation3d.matrix_to_axis_angle(rigids_t.get_rots().get_rot_mats()) + trans_t = rigids_t.get_trans() + + # reverse rot + rot_t_1 = self.rot_diffuser.reverse( + rot_t=rot_t, + score_t=rot_score, + t=t, + dt=dt, + noise_scale=noise_scale, + probability_flow=probability_flow, + ) if self.rot_diffuser is not None else rot_t # if no diffusion module, return as-is + + # reverse trans + trans_t_1 = self.trans_diffuser.reverse( + x_t=trans_t, + score_t=trans_score, + t=t, + dt=dt, + center=center_trans, + noise_scale=noise_scale, + probability_flow=probability_flow, + ) if self.trans_diffuser is not None else trans_t + + # apply mask + if diffuse_mask is not None: + trans_t_1 = apply_mask(trans_t_1, trans_t, diffuse_mask[..., None]) + rot_t_1 = apply_mask(rot_t_1, rot_t, diffuse_mask[..., None]) + + return assemble_rigid(rot_t_1, trans_t_1) + + def sample_prior( + self, + shape: torch.Size, + device: torch.device, + reference_rigids: Rigid = None, + diffuse_mask: torch.Tensor = None, + as_tensor_7: bool = False + ): + """Samples rigids from reference distribution. 
+ + """ + if reference_rigids is not None: + assert reference_rigids.shape[:-1] == shape, f"reference_rigids.shape[:-1] = {reference_rigids.shape[:-1]}, shape = {shape}" + assert diffuse_mask is not None, "diffuse_mask must be provided if reference_rigids is given" + rot_ref = rotation3d.matrix_to_axis_angle(reference_rigids.get_rots().get_rot_mats()) + trans_ref = reference_rigids.get_trans() + + trans_ref = self.trans_diffuser.scale(trans_ref) + else: + # sanity check + assert diffuse_mask is None, "diffuse_mask must be None if reference_rigids is None" + assert self.rot_diffuser is not None and self.trans_diffuser is not None + + # sample from prior + trans_shape, rot_shape = shape + (3, ), shape + (3, ) + rot_sample = self.rot_diffuser.sample_prior(shape=rot_shape, device=device) \ + if self.rot_diffuser is not None else rot_ref + trans_sample = self.trans_diffuser.sample_prior(shape=trans_shape, device=device) \ + if self.trans_diffuser is not None else trans_ref + + # apply mask + if diffuse_mask is not None: + rot_sample = apply_mask(rot_sample, rot_ref, diffuse_mask[..., None]) + trans_sample = apply_mask(trans_sample, trans_ref, diffuse_mask[..., None]) + + trans_sample = self.trans_diffuser.unscale(trans_sample) + + # assemble sampled rot and trans -> rigid + rigids_t = assemble_rigid(rot_sample, trans_sample) + + if as_tensor_7: + rigids_t = rigids_t.to_tensor_7() + + return {'rigids_t': rigids_t} \ No newline at end of file diff --git a/analysis/src/models/score/r3.py b/analysis/src/models/score/r3.py new file mode 100644 index 0000000000000000000000000000000000000000..c11b6b37147e6c25ac9795ee22c4754272e4f7d8 --- /dev/null +++ b/analysis/src/models/score/r3.py @@ -0,0 +1,147 @@ +"""Inspired by https://github.com/jasonkyuyim/se3_diffusion/blob/master/data/r3_diffuser.py""" + +from math import sqrt +import torch + +from src.utils.tensor_utils import inflate_array_like + +class R3Diffuser: + """VPSDE diffusion module.""" + def __init__( + self, + min_b: 
float = 0.1, + max_b: float = 20.0, + coordinate_scaling: float = 1.0, + ): + self.min_b = min_b + self.max_b = max_b + self.coordinate_scaling = coordinate_scaling + + def scale(self, x): + return x * self.coordinate_scaling + + def unscale(self, x): + return x / self.coordinate_scaling + + def b_t(self, t: torch.Tensor): + if torch.any(t < 0) or torch.any(t > 1): + raise ValueError(f'Invalid t={t}') + return self.min_b + t * (self.max_b - self.min_b) + + def diffusion_coef(self, t): + return torch.sqrt(self.b_t(t)) + + def drift_coef(self, x, t): + return -0.5 * self.b_t(t) * x + + def sample_prior(self, shape, device=None): + return torch.randn(size=shape, device=device) + + def marginal_b_t(self, t): + return t*self.min_b + 0.5*(t**2)*(self.max_b-self.min_b) + + def calc_trans_0(self, score_t, x_t, t): + beta_t = self.marginal_b_t(t) + beta_t = beta_t[..., None, None] + cond_var = 1 - torch.exp(-beta_t) + return (score_t * cond_var + x_t) / torch.exp(-0.5*beta_t) + + def forward_marginal( + self, + x_0: torch.Tensor, + t: torch.Tensor + ): + """Samples marginal p(x(t) | x(0)). + + Args: + x_0: [..., n, 3] initial positions in Angstroms. + t: continuous time in [0, 1]. + + Returns: + x_t: [..., n, 3] positions at time t in Angstroms. + score_t: [..., n, 3] score at time t in scaled Angstroms. 
+ """ + t = inflate_array_like(t, x_0) + x_0 = self.scale(x_0) + + loc = torch.exp(-0.5 * self.marginal_b_t(t)) * x_0 + scale = torch.sqrt(1 - torch.exp(-self.marginal_b_t(t))) + z = torch.randn_like(x_0) + x_t = z * scale + loc + score_t = self.score(x_t, x_0, t) + + x_t = self.unscale(x_t) + return x_t, score_t + + def score_scaling(self, t: torch.Tensor): + return 1.0 / torch.sqrt(self.conditional_var(t)) + + def reverse( + self, + x_t: torch.Tensor, + score_t: torch.Tensor, + t: torch.Tensor, + dt: float, + mask: torch.Tensor = None, + center: bool = True, + noise_scale: float = 1.0, + probability_flow: bool = True, + ): + """Simulates the reverse SDE for 1 step + + Args: + x_t: [..., 3] current positions at time t in angstroms. + score_t: [..., 3] rotation score at time t. + t: continuous time in [0, 1]. + dt: continuous step size in [0, 1]. + mask: True indicates which residues to diffuse. + probability_flow: whether to use probability flow ODE. + + Returns: + [..., 3] positions at next step t-1. + """ + t = inflate_array_like(t, x_t) + x_t = self.scale(x_t) + + f_t = self.drift_coef(x_t, t) + g_t = self.diffusion_coef(t) + + z = noise_scale * torch.randn_like(score_t) + + rev_drift = (f_t - g_t ** 2 * score_t) * dt * (0.5 if probability_flow else 1.) + rev_diffusion = 0. if probability_flow else (g_t * sqrt(dt) * z) + perturb = rev_drift + rev_diffusion + + if mask is not None: + perturb *= mask[..., None] + else: + mask = torch.ones_like(x_t[..., 0]) + x_t_1 = x_t - perturb # reverse in time + if center: + com = torch.sum(x_t_1, dim=-2) / torch.sum(mask, dim=-1)[..., None] # reduce length dim + x_t_1 -= com[..., None, :] + + x_t_1 = self.unscale(x_t_1) + return x_t_1 + + def conditional_var(self, t, use_torch=False): + """Conditional variance of p(xt|x0). 
+ Var[x_t|x_0] = conditional_var(t) * I + """ + return 1.0 - torch.exp(-self.marginal_b_t(t)) + + def score(self, x_t, x_0, t, scale=False): + t = inflate_array_like(t, x_t) + if scale: + x_t, x_0 = self.scale(x_t), self.scale(x_0) + return -(x_t - torch.exp(-0.5 * self.marginal_b_t(t)) * x_0) / self.conditional_var(t) + + def distribution(self, x_t, score_t, t, mask, dt): + x_t = self.scale(x_t) + f_t = self.drift_coef(x_t, t) + g_t = self.diffusion_coef(t) + std = g_t * sqrt(dt) + mu = x_t - (f_t - g_t**2 * score_t) * dt + if mask is not None: + mu *= mask[..., None] + return mu, std \ No newline at end of file diff --git a/analysis/src/models/score/so3.py b/analysis/src/models/score/so3.py new file mode 100644 index 0000000000000000000000000000000000000000..8a43caa5ba1e062020b1a7e95406427b29bb02c3 --- /dev/null +++ b/analysis/src/models/score/so3.py @@ -0,0 +1,371 @@ +"""Inspired by https://github.com/jasonkyuyim/se3_diffusion/blob/master/data/so3_diffuser.py""" + +import os +import math + +import numpy as np +import torch + +from src.utils.tensor_utils import inflate_array_like +from src.common import rotation3d + + +def compose_rotvec(rotvec1, rotvec2): + """Compose two rotation euler vectors.""" + dtype = rotvec1.dtype + R1 = rotation3d.axis_angle_to_matrix(rotvec1) + R2 = rotation3d.axis_angle_to_matrix(rotvec2) + cR = torch.einsum('...ij,...jk->...ik', R1.double(), R2.double()) + return rotation3d.matrix_to_axis_angle(cR).type(dtype) + +def igso3_expansion(omega, eps, L=1000, use_torch=False): + """Truncated sum of IGSO(3) distribution. + + This function approximates the power series in equation 5 of + "DENOISING DIFFUSION PROBABILISTIC MODELS ON SO(3) FOR ROTATIONAL + ALIGNMENT" + Leach et al. 2022 + + This expression diverges from the expression in Leach in that here, eps = + sqrt(2) * eps_leach, if eps_leach were the scale parameter of the IGSO(3). + + With this reparameterization, IGSO(3) agrees with the Brownian motion on + SO(3) with t=eps^2. 
+ + Args: + omega: rotation of Euler vector (i.e. the angle of rotation) + eps: std of IGSO(3). + L: Truncation level + use_torch: set true to use torch tensors, otherwise use numpy arrays. + """ + + lib = torch if use_torch else np + ls = lib.arange(L) + if use_torch: + ls = ls.to(omega.device) + + if len(omega.shape) == 2: + # Used during predicted score calculation. + ls = ls[None, None] # [1, 1, L] + omega = omega[..., None] # [num_batch, num_res, 1] + eps = eps[..., None] + elif len(omega.shape) == 1: + # Used during cache computation. + ls = ls[None] # [1, L] + omega = omega[..., None] # [num_batch, 1] + else: + raise ValueError("Omega must be 1D or 2D.") + p = (2*ls + 1) * lib.exp(-ls*(ls+1)*eps**2/2) * lib.sin(omega*(ls+1/2)) / lib.sin(omega/2) + if use_torch: + return p.sum(dim=-1) + else: + return p.sum(axis=-1) + + +def density(expansion, omega, marginal=True, use_torch=False): + """IGSO(3) density. + + Args: + expansion: truncated approximation of the power series in the IGSO(3) + density. + omega: length of an Euler vector (i.e. angle of rotation) + marginal: set true to give marginal density over the angle of rotation, + otherwise include normalization to give density on SO(3) or a + rotation with angle omega. + """ + lib = torch if use_torch else np + if marginal: + # if marginal, density over [0, pi], else over SO(3) + return expansion * (1.0 - lib.cos(omega)) / np.pi + else: + # the constant factor doesn't affect any actual calculations though + return expansion / 8 / np.pi**2 + + +def score(exp, omega, eps, L=1000, use_torch=False): # score of density over SO(3) + """score uses the quotient rule to compute the scaling factor for the score + of the IGSO(3) density. + + This function is used within the Diffuser class to when computing the score + as an element of the tangent space of SO(3). 
+ + This uses the quotient rule of calculus, and take the derivative of the + log: + d hi(x)/lo(x) = (lo(x) d hi(x)/dx - hi(x) d lo(x)/dx) / lo(x)^2 + and + d log expansion(x) / dx = (d expansion(x)/ dx) / expansion(x) + + Args: + exp: truncated expansion of the power series in the IGSO(3) density + omega: length of an Euler vector (i.e. angle of rotation) + eps: scale parameter for IGSO(3) -- as in expansion() this scaling + differ from that in Leach by a factor of sqrt(2). + L: truncation level + use_torch: set true to use torch tensors, otherwise use numpy arrays. + + Returns: + The d/d omega log IGSO3(omega; eps)/(1-cos(omega)) + + """ + lib = torch if use_torch else np + ls = lib.arange(L) + if use_torch: + ls = ls.to(omega.device) + ls = ls[None] + if len(omega.shape) == 2: + ls = ls[None] + elif len(omega.shape) > 2: + raise ValueError("Omega must be 1D or 2D.") + omega = omega[..., None] + eps = eps[..., None] + hi = lib.sin(omega * (ls + 1 / 2)) + dhi = (ls + 1 / 2) * lib.cos(omega * (ls + 1 / 2)) + lo = lib.sin(omega / 2) + dlo = 1 / 2 * lib.cos(omega / 2) + dSigma = (2 * ls + 1) * lib.exp(-ls * (ls + 1) * eps**2/2) * (lo * dhi - hi * dlo) / lo ** 2 + if use_torch: + dSigma = dSigma.sum(dim=-1) + else: + dSigma = dSigma.sum(axis=-1) + return dSigma / (exp + 1e-4) + + +class SO3Diffuser: + def __init__( + self, + cache_dir: str = './cache', + schedule: str = 'logarithmic', + min_sigma: float = 0.1, + max_sigma: float = 1.5, + num_sigma: int = 1000, + num_omega: int = 1000, + use_cached_score: bool = False, + eps: float = 1e-6, + ): + self.schedule = schedule + self.min_sigma = min_sigma + self.max_sigma = max_sigma + self.num_sigma = num_sigma + self.use_cached_score = use_cached_score + self.eps = eps + + # Discretize omegas for calculating CDFs. Skip omega=0. + self.discrete_omega = torch.linspace(0, np.pi, steps=num_omega+1)[1:] + + # Configure cache directory. 
+ replace_period = lambda x: str(x).replace('.', '_') + cache_dir = os.path.join( + cache_dir, + f'eps_{num_sigma}_omega_{num_omega}_min_sigma_{replace_period(min_sigma)}_max_sigma_{replace_period(max_sigma)}_schedule_{schedule}' + ) + if not os.path.isdir(cache_dir): + os.makedirs(cache_dir) + + pdf_cache = os.path.join(cache_dir, 'pdf_vals.pt') + cdf_cache = os.path.join(cache_dir, 'cdf_vals.pt') + score_norms_cache = os.path.join(cache_dir, 'score_norms.pt') + + if os.path.exists(pdf_cache) \ + and os.path.exists(cdf_cache) \ + and os.path.exists(score_norms_cache): + self._pdf = torch.load(pdf_cache, map_location='cpu') + self._cdf = torch.load(cdf_cache, map_location='cpu') + self._score_norms = torch.load(score_norms_cache, map_location='cpu') + else: + disc_omega = self.discrete_omega.numpy() + disc_sigma = self.discrete_sigma.numpy() + # Compute the expansion of the power series. + exp_vals = np.asarray( + [igso3_expansion(disc_omega, sigma) for sigma in disc_sigma] + ) + # Compute the pdf and cdf values for the marginal distribution of the axis-angles (readily needed for sampling) + self._pdf = np.asarray( + [density(x, disc_omega, marginal=True) for x in exp_vals] + ) + self._cdf = np.asarray( + [pdf.cumsum() / num_omega * np.pi for pdf in self._pdf] + ) + # Compute the norms of the scores. This are used to scale the rotation axis when + # computing the score as a vector. 
+ self._score_norms = np.asarray( + [score(exp_vals[i], disc_omega, x) for i, x in enumerate(disc_sigma)] + ) + self._pdf, self._cdf, self._score_norms = map( + torch.as_tensor, (self._pdf, self._cdf, self._score_norms) + ) + # Cache the precomputed values + torch.save(obj=self._pdf, f=pdf_cache) + torch.save(obj=self._cdf, f=cdf_cache) + torch.save(obj=self._score_norms, f=score_norms_cache) + + self._score_scaling = torch.sqrt(torch.abs( + torch.sum(self._score_norms**2 * self._pdf, dim=-1) / torch.sum(self._pdf, dim=-1) + )) / np.sqrt(3) + + @property + def discrete_sigma(self): + return self.sigma( + torch.linspace(0.0, 1.0, self.num_sigma) + ) + + def sigma_idx(self, sigma: torch.Tensor): + """Calculates the index for discretized sigma during IGSO(3) initialization.""" + return torch.as_tensor(np.digitize(sigma.cpu().numpy(), self.discrete_sigma) - 1, + dtype=torch.long) + + def sigma(self, t: torch.Tensor): + """Extract \sigma(t) corresponding to chosen sigma schedule.""" + if torch.any(t < 0) or torch.any(t > 1): + raise ValueError(f'Invalid t={t}') + if self.schedule == 'logarithmic': + return torch.log(t * math.exp(self.max_sigma) + (1 - t) * math.exp(self.min_sigma)) + else: + raise ValueError(f'Unrecognize schedule {self.schedule}') + + def diffusion_coef(self, t: torch.Tensor): + """Compute diffusion coefficient (g_t).""" + if self.schedule == 'logarithmic': + g_t = torch.sqrt( + 2 * (math.exp(self.max_sigma) - math.exp(self.min_sigma)) \ + * self.sigma(t) / torch.exp(self.sigma(t)) + ) + else: + raise ValueError(f'Unrecognize schedule {self.schedule}') + return g_t + + def t_to_idx(self, t: torch.Tensor): + """Helper function to go from time t to corresponding sigma_idx.""" + return self.sigma_idx(self.sigma(t)) + + def sample_prior(self, shape: torch.Size, device=None): + t = torch.ones(shape[0], dtype=torch.float, device=device) + return self.sample(t, shape) + + def sample(self, t: torch.Tensor, shape: torch.Size): + """Generates rotation vector(s) 
from IGSO(3). + + Args: + t: continuous time in [0, 1]. + shape: shape of the output tensor. + + Returns: + (shape, ) axis-angle rotation vectors sampled from IGSO(3). + """ + assert t.ndim == 1 and t.shape[0] == shape[0], \ + f"t.shape={t.shape}, shape={shape}" + assert shape[-1] == 3, f"The last dim should be 3, but got shape={shape}" + + z = torch.randn(shape, device=t.device) # axis-angle + x = z / torch.linalg.norm(z, dim=-1, keepdims=True) + + # sample igso3 + z_igso3 = torch.rand(shape[:-1]) # (num_batch, num_res) + igso3_scaling = [] + for i, _t in enumerate(t): + t_idx = self.t_to_idx(_t).item() # [1] + # np.interp can accept Tensor as input + igso3_scaling.append( + np.interp(z_igso3[i], self._cdf[t_idx], self.discrete_omega), + ) # (num_res, ) + igso3_scaling = torch.as_tensor(np.asarray(igso3_scaling), dtype=x.dtype, device=t.device) + + return x * igso3_scaling[..., None] + + def score( + self, + vec: torch.tensor, + t: torch.tensor, + ): + """Computes the score of IGSO(3) density as a rotation vector. + + Same as score function but uses pytorch and performs a look-up. + + Args: + vec: [..., 3] array of axis-angle rotation vectors. + t: continuous time in [0, 1]. + + Returns: + [..., 3] score vector in the direction of the sampled vector with + magnitude given by _score_norms. 
+ """ + assert t.ndim == 1 and t.shape[0] == vec.shape[0], \ + f"t.shape={t.shape}, vec.shape={vec.shape}" + device = vec.device + + omega = torch.linalg.norm(vec, dim=-1) + self.eps # [num_batch, num_res] + + if self.use_cached_score: + score_norms_t = self._score_norms[self.t_to_idx(t)] + score_norms_t = torch.as_tensor(score_norms_t).to(device) + omega_idx = torch.bucketize( + omega, torch.as_tensor(self.discrete_omega[:-1]).to(device)) + omega_scores_t = torch.gather(score_norms_t, 1, omega_idx) + else: # compute on the fly + sigma = self.discrete_sigma[self.t_to_idx(t)] + sigma = torch.as_tensor(sigma).to(device) + omega_vals = igso3_expansion(omega, sigma[:, None], use_torch=True) + omega_scores_t = score(omega_vals, omega, sigma[:, None], use_torch=True) + + return omega_scores_t[..., None] * vec / (omega[..., None] + self.eps) + + def score_scaling(self, t: torch.Tensor): + """Calculates scaling used for scores during trianing.""" + return torch.as_tensor(self._score_scaling[self.t_to_idx(t)], device=t.device) + + def forward_marginal(self, rot_0: torch.Tensor, t: torch.Tensor): + """Samples from the forward diffusion process at time index t. + + Args: + rot_0: [..., 3] initial rotations. + t: continuous time in [0, 1]. + + Returns: + rot_t: [..., 3] noised rotation vectors. + rot_score: [..., 3] score of rot_t as a rotation vector. + """ + rotvec_0t = self.sample(t, shape=rot_0.shape) # "delta_rot" + rot_score = self.score(rotvec_0t, t) + + # Right multiply => vector_add in Euclidean space + rot_t = compose_rotvec(rot_0, rotvec_0t) + return rot_t, rot_score + + def reverse( + self, + rot_t: torch.Tensor, + score_t: torch.Tensor, + t: torch.Tensor, + dt: float, + mask: torch.Tensor = None, + noise_scale: float = 1.0, + probability_flow: bool = True, + ): + """Simulates the reverse SDE for 1 step using the Geodesic random walk. + + Args: + rot_t: [..., 3] current rotations at time t. + score_t: [..., 3] rotation score at time t. 
+ t: continuous time in [0, 1]. + dt: continuous step size in [0, 1]. + add_noise: set False to set diffusion coefficent to 0. + mask: True indicates which residues to diffuse. + probability_flow: set True to use probability flow ODE. + + Returns: + [..., 3] rotation vector at next step. + """ + t = inflate_array_like(t, rot_t) + + g_t = self.diffusion_coef(t) + z = noise_scale * torch.randn_like(score_t) + + rev_drift = -1.0 * (g_t ** 2) * score_t * dt * (0.5 if probability_flow else 1.) + rev_diffusion = 0. if probability_flow else (g_t * np.sqrt(dt) * z) + perturb = rev_drift + rev_diffusion + + if mask is not None: + perturb *= mask[..., None] + + # Right multiply. + rot_t_1 = compose_rotvec(rot_t, -1.0 * perturb) # reverse in time + return rot_t_1 \ No newline at end of file diff --git a/analysis/src/train.py b/analysis/src/train.py new file mode 100644 index 0000000000000000000000000000000000000000..c9973055301bf5844aef703317f637742f924c6c --- /dev/null +++ b/analysis/src/train.py @@ -0,0 +1,135 @@ +from typing import Any, Dict, List, Optional, Tuple + +import hydra +import lightning as L +import rootutils +import torch +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +# ------------------------------------------------------------------------------------ # +# the setup_root above is equivalent to: +# - adding project root dir to PYTHONPATH +# (so you don't need to force user to install project as a package) +# (necessary before importing any local modules e.g. `from src import utils`) +# - setting up PROJECT_ROOT environment variable +# (which is used as a base for paths in "configs/paths/default.yaml") +# (this way all filepaths are the same no matter where you run the code) +# - loading environment variables from ".env" in root dir +# +# you can remove it if you: +# 1. 
either install project as a package or move entry files to project root dir +# 2. set `root_dir` to "." in "configs/paths/default.yaml" +# +# more info: https://github.com/ashleve/rootutils +# ------------------------------------------------------------------------------------ # + +from src.utils import ( + RankedLogger, + extras, + get_metric_value, + instantiate_callbacks, + instantiate_loggers, + log_hyperparameters, + task_wrapper, + checkpoint_utils, +) + +log = RankedLogger(__name__, rank_zero_only=True) + + +@task_wrapper +def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + + :param cfg: A DictConfig configuration composed by Hydra. + :return: A tuple with metrics and dict with all instantiated objects. 
+ """ + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: List[Logger] = instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + log_hyperparameters(object_dict) + + model, ckpt_path = checkpoint_utils.load_model_checkpoint(model, cfg.get("ckpt_path")) + + if cfg.get("train"): + log.info("Starting training!") + trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! Using current weights for testing...") + ckpt_path = None + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") +def main(cfg: DictConfig) -> Optional[float]: + """Main entry point for training. 
+ + :param cfg: DictConfig configuration composed by Hydra. + :return: Optional[float] with optimized metric value. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) + extras(cfg) + + # train the model + metric_dict, _ = train(cfg) + + # safely retrieve metric value for hydra-based hyperparameter optimization + metric_value = get_metric_value( + metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") + ) + + # return optimized metric + return metric_value + + +if __name__ == "__main__": + main() diff --git a/analysis/src/utils/__init__.py b/analysis/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b0707ca57ec89fc5f5cb1a023135eeb756a8e1e --- /dev/null +++ b/analysis/src/utils/__init__.py @@ -0,0 +1,5 @@ +from src.utils.instantiators import instantiate_callbacks, instantiate_loggers +from src.utils.logging_utils import log_hyperparameters +from src.utils.pylogger import RankedLogger +from src.utils.rich_utils import enforce_tags, print_config_tree +from src.utils.utils import extras, get_metric_value, task_wrapper diff --git a/analysis/src/utils/__pycache__/__init__.cpython-310.pyc b/analysis/src/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc38620891cdc2e6b3c469b37dfb5ccf5971370c Binary files /dev/null and b/analysis/src/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/analysis/src/utils/__pycache__/__init__.cpython-39.pyc b/analysis/src/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e16fdb32c1ff6c131fe52d84b7b2eaefb7456ed Binary files /dev/null and b/analysis/src/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/analysis/src/utils/__pycache__/checkpoint_utils.cpython-310.pyc b/analysis/src/utils/__pycache__/checkpoint_utils.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..a0e5c64df79d10dee3601326a0ed066cc4659070 Binary files /dev/null and b/analysis/src/utils/__pycache__/checkpoint_utils.cpython-310.pyc differ diff --git a/analysis/src/utils/__pycache__/checkpoint_utils.cpython-39.pyc b/analysis/src/utils/__pycache__/checkpoint_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44d97c36604a77e84444152aa33ba67fd17f0f68 Binary files /dev/null and b/analysis/src/utils/__pycache__/checkpoint_utils.cpython-39.pyc differ diff --git a/analysis/src/utils/__pycache__/instantiators.cpython-310.pyc b/analysis/src/utils/__pycache__/instantiators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ebab309ccf49d8e01a2dadcccf996dc4e425d97 Binary files /dev/null and b/analysis/src/utils/__pycache__/instantiators.cpython-310.pyc differ diff --git a/analysis/src/utils/__pycache__/instantiators.cpython-39.pyc b/analysis/src/utils/__pycache__/instantiators.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e23b93bac589979f49f9b8fc6fac281b3026070 Binary files /dev/null and b/analysis/src/utils/__pycache__/instantiators.cpython-39.pyc differ diff --git a/analysis/src/utils/__pycache__/logging_utils.cpython-310.pyc b/analysis/src/utils/__pycache__/logging_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e403b520b31a1bc9184a0bf74301017f0ce1e58 Binary files /dev/null and b/analysis/src/utils/__pycache__/logging_utils.cpython-310.pyc differ diff --git a/analysis/src/utils/__pycache__/logging_utils.cpython-39.pyc b/analysis/src/utils/__pycache__/logging_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26c790bd93755550c51ec8cba7e6d2f8537a1e7c Binary files /dev/null and b/analysis/src/utils/__pycache__/logging_utils.cpython-39.pyc differ diff --git a/analysis/src/utils/__pycache__/plot_utils.cpython-310.pyc 
b/analysis/src/utils/__pycache__/plot_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1784bf4f81c7600125742d41edea5f19838c106 Binary files /dev/null and b/analysis/src/utils/__pycache__/plot_utils.cpython-310.pyc differ diff --git a/analysis/src/utils/__pycache__/plot_utils.cpython-39.pyc b/analysis/src/utils/__pycache__/plot_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7dad302ccc0ef4a29b00490c6ac9eea37a24fc7 Binary files /dev/null and b/analysis/src/utils/__pycache__/plot_utils.cpython-39.pyc differ diff --git a/analysis/src/utils/__pycache__/pylogger.cpython-310.pyc b/analysis/src/utils/__pycache__/pylogger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ef8808d9784553d45122ebebf168931145d9a9e Binary files /dev/null and b/analysis/src/utils/__pycache__/pylogger.cpython-310.pyc differ diff --git a/analysis/src/utils/__pycache__/pylogger.cpython-39.pyc b/analysis/src/utils/__pycache__/pylogger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..315a10f77f0a243ef68628c7df04759004806549 Binary files /dev/null and b/analysis/src/utils/__pycache__/pylogger.cpython-39.pyc differ diff --git a/analysis/src/utils/__pycache__/rich_utils.cpython-310.pyc b/analysis/src/utils/__pycache__/rich_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acc7ea30ef6b89b3944f92dc3e6d3f3f836b7244 Binary files /dev/null and b/analysis/src/utils/__pycache__/rich_utils.cpython-310.pyc differ diff --git a/analysis/src/utils/__pycache__/rich_utils.cpython-39.pyc b/analysis/src/utils/__pycache__/rich_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e48ab6a604dcd05449c7a29afe102bfc1e0e488 Binary files /dev/null and b/analysis/src/utils/__pycache__/rich_utils.cpython-39.pyc differ diff --git a/analysis/src/utils/__pycache__/tensor_utils.cpython-39.pyc 
def load_model_checkpoint(model, ckpt_path):
    """Load network weights from a checkpoint file, or defer loading to the trainer.

    The checkpoint flavour is decided by the file suffix:

    * ``.pth``: the file is read here with ``torch.load`` (mapped to CPU), its
      ``'state_dict'`` entry is loaded into ``model.net`` and ``None`` is
      returned in place of the path. Only network parameters are restored,
      which avoids the ambiguity of also restoring #epochs/lr when finetuning.
    * ``.ckpt``: nothing is loaded here; the path is returned unchanged so it
      can be handled later by the trainer.

    :param model: The model to load the state dict into. Must expose ``net``.
    :param ckpt_path: The path to the checkpoint file, or ``None``.
    :return: ``(model, ckpt_path)``, where ``ckpt_path`` is ``None`` if it was
        ``None`` on input or if a ``.pth`` file was consumed here.
    :raises ValueError: If ``ckpt_path`` ends with neither ``.pth`` nor ``.ckpt``.
    """
    if ckpt_path is None:
        return model, None

    if ckpt_path.endswith(".pth"):
        state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict']
        # Drop only the *leading* 'net.' so that keys such as
        # 'net.subnet.weight' become 'subnet.weight'. A blanket str.replace
        # would also eat the 'net.' inside 'subnet.' and corrupt the key.
        prefix = 'net.'
        net_params = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        model.net.load_state_dict(net_params)
        return model, None

    if ckpt_path.endswith(".ckpt"):
        # will be handled later by the trainer
        return model, ckpt_path

    # suffix check
    raise ValueError(f"ckpt_path {ckpt_path} is not a valid checkpoint file.")
@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Send a curated subset of the run configuration to every trainer logger.

    The payload holds the ``model``, ``data`` and ``trainer`` config sections
    (a missing one raises ``KeyError``), the ``callbacks``, ``extras``,
    ``task_name``, ``tags``, ``ckpt_path`` and ``seed`` entries (``None`` when
    absent), and three parameter counts for the model: total, trainable and
    non-trainable.

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    """
    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    # Without a logger on the trainer there is nowhere to send the payload.
    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams = {"model": cfg["model"]}

    # One pass over the parameters; the non-trainable count is the remainder.
    n_total = 0
    n_trainable = 0
    for param in model.parameters():
        size = param.numel()
        n_total += size
        if param.requires_grad:
            n_trainable += size
    hparams["model/params/total"] = n_total
    hparams["model/params/trainable"] = n_trainable
    hparams["model/params/non_trainable"] = n_total - n_trainable

    # Mandatory sections are indexed directly, optional ones go through .get().
    for section in ("data", "trainer"):
        hparams[section] = cfg[section]
    for section in ("callbacks", "extras", "task_name", "tags", "ckpt_path", "seed"):
        hparams[section] = cfg.get(section)

    # send hparams to all loggers
    for run_logger in trainer.loggers:
        run_logger.log_hyperparams(hparams)
def scatterplot_2d(
    data_dict: Dict,
    save_to: str,
    ref_key: str = 'target',
    xlabel: str = 'tIC1',
    ylabel: str = 'tIC2',
    n_max_point: int = 1000,
    pop_ref: bool = False,
    xylim_key: Optional[str] = 'PDB_clusters',
    plot_kde: bool = False,
    density_mapping: Optional[Dict] = None
):
    """Draw one density-coloured 2D scatter panel per entry of ``data_dict`` and save the figure.

    Note that ``data_dict`` is modified in place: the ``xylim_key`` entry is
    popped when present, and the ``ref_key`` entry is popped when ``pop_ref``
    is set.

    :param data_dict: Maps a panel title to a 2-D array; columns 0 and 1 are
        used as the x and y coordinates.
    :param save_to: Path handed to ``plt.savefig``.
    :param ref_key: Entry that defines the axis limits when no ``xylim_key``
        entry exists, and the data drawn by ``plot_kde``. It is never subsampled.
    :param xlabel: Label of the x axis.
    :param ylabel: Label of the y axis (drawn on the first column only).
    :param n_max_point: Non-reference entries with more rows than this are
        randomly subsampled (without replacement) before plotting.
    :param pop_ref: Remove the reference entry so it gets no panel of its own.
    :param xylim_key: Entry whose points define the axis limits and are
        overlaid on every panel as red circles; falsy to disable.
    :param plot_kde: Overlay a KDE contour of the reference data on each panel.
    :param density_mapping: Optional precomputed colour values per panel title,
        used instead of the estimated density.
    :raises KeyError: If the reference entry is needed but missing.
    :raises ValueError: If nothing is left to plot.
    """
    # Axis limits come from the dedicated entry when available, otherwise
    # from the reference data.
    if xylim_key and xylim_key in data_dict:
        xylim = data_dict.pop(xylim_key)
        bounds_src = xylim
    else:
        xylim = None
        bounds_src = data_dict[ref_key]
    x_max = max(bounds_src[:, 0])
    x_min = min(bounds_src[:, 0])
    y_max = max(bounds_src[:, 1])
    y_min = min(bounds_src[:, 1])

    # Add margin. The margin is computed once from the raw range so both sides
    # grow by the same 20%; updating x_min first and then reusing it would make
    # the upper margin larger than the lower one.
    x_margin = (x_max - x_min) / 5.0
    y_margin = (y_max - y_min) / 5.0
    x_min, x_max = x_min - x_margin, x_max + x_margin
    y_min, y_max = y_min - y_margin, y_max + y_margin

    # Grab the reference data before it may be popped, so that plot_kde keeps
    # working together with pop_ref.
    ref_data = data_dict[ref_key] if plot_kde else None

    # Remove reference data to save time.
    if pop_ref:
        data_dict.pop(ref_key)

    n_panels = len(data_dict)
    if n_panels == 0:
        raise ValueError("data_dict contains no entries to plot.")

    # plot tica
    print(f">>> Plotting scatter in 2D space. Image save to {save_to}")

    # Configure subplots. Columns are rounded up so the grid always has at
    # least one cell per panel (e.g. 11 panels -> 2 x 6, not 2 x 5).
    if n_panels > 5:
        plot_n_row = n_panels // 5
        plot_n_columns = (n_panels + plot_n_row - 1) // plot_n_row
    else:
        plot_n_row = 1
        plot_n_columns = n_panels
    fig = plt.figure(figsize=(6 * plot_n_columns, plot_n_row * 6))

    try:
        for i, (k, v) in enumerate(data_dict.items(), start=1):
            plt.subplot(plot_n_row, plot_n_columns, i)

            if k != ref_key and v.shape[0] > n_max_point:  # subsample for visualize
                idx = np.random.choice(v.shape[0], n_max_point, replace=False)
                v = v[idx]

            if v.shape[0] < v.shape[1]:
                print(f"Warning: {k} has more dimensions than samples, using uniform density.")
                # Float buffer so the in-place normalisation also works for
                # integer-typed inputs.
                density = np.ones(v.shape[0], dtype=float)
                density /= density.sum()
            else:
                cov = np.transpose(v)
                density = gaussian_kde(cov)(cov)

            # Optional precomputed density mapping.
            if density_mapping and k in density_mapping:
                density = density_mapping[k]

            plt.scatter(v[:, 0], v[:, 1], s=10, alpha=0.7, c=density, cmap="mako_r", vmin=-0.05, vmax=0.40)

            if plot_kde:
                sns.kdeplot(x=ref_data[:, 0], y=ref_data[:, 1])  # landscape

            if xylim is not None:
                plt.scatter(xylim[:, 0], xylim[:, 1], s=40, marker="o", c="none", edgecolors="tab:red")  # cluster centers

            plt.xlabel(xlabel, fontsize=FONTSIZE, fontfamily="sans-serif")
            if (i - 1) % plot_n_columns == 0:
                plt.ylabel(ylabel, fontsize=FONTSIZE, fontfamily="sans-serif")

            plt.xlim(x_min, x_max)
            plt.ylim(y_min, y_max)
            plt.title(k, fontsize=FONTSIZE, fontfamily="sans-serif")

        plt.tight_layout()
        plt.savefig(save_to, dpi=500)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger.

    Every emitted message is prefixed with the rank of the current process,
    and emission can be restricted to rank zero or to one requested rank.
    """

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = False,
        extra: Optional[Mapping[str, object]] = None,
    ) -> None:
        """Wrap the standard logger called ``name``.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
        """
        super().__init__(logger=logging.getLogger(name), extra=extra)
        self.rank_zero_only = rank_zero_only

    def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
        """Prefix ``msg`` with the current process rank and forward it to the wrapped logger.

        When the adapter was built with ``rank_zero_only=True`` the message is
        emitted on rank 0 only. Otherwise it is emitted everywhere when ``rank``
        is ``None``, or only on the process whose rank equals ``rank``.

        NOTE: ``rank`` is the third positional parameter, so the first
        positional argument after ``msg`` binds to ``rank`` and not to ``args``.

        :param level: The level to log at. Look at `logging.__init__.py` for more information.
        :param msg: The message to log.
        :param rank: The rank to log at.
        :param args: Additional args to pass to the underlying logging function.
        :param kwargs: Any additional keyword args to pass to the underlying logging function.
        :raises RuntimeError: If ``rank_zero_only.rank`` has not been set.
        """
        if not self.isEnabledFor(level):
            return

        msg, kwargs = self.process(msg, kwargs)

        # Here `rank_zero_only` is the imported helper, not the instance flag.
        current_rank = getattr(rank_zero_only, "rank", None)
        if current_rank is None:
            raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
        msg = rank_prefixed_message(msg, current_rank)

        if self.rank_zero_only:
            should_emit = current_rank == 0
        else:
            should_emit = rank is None or current_rank == rank

        if should_emit:
            self.logger.log(level, msg, *args, **kwargs)
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    The entered text is split on commas; each piece is stripped of surrounding
    whitespace and pieces that end up empty are discarded. The resulting list
    is stored as ``cfg.tags``.

    :param cfg: A DictConfig composed by Hydra.
    :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
    :raises ValueError: If no tags are configured and the Hydra job config
        contains an ``id`` entry (treated here as a multirun, where prompting
        is refused).
    """
    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        answer = Prompt.ask("Enter a list of comma separated tags", default="dev")
        # Strip first, then filter: testing for emptiness before stripping let
        # whitespace-only pieces such as the middle of "a, ,b" through as ''.
        tags = [tag for tag in (piece.strip() for piece in answer.split(",")) if tag]

        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
            rich.print(cfg.tags, file=file)
def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """Index ``data`` with ``inds`` along ``dim``, separately for each leading batch position.

    For each of the first ``no_batch_dims`` axes of ``data`` an ``arange``
    index is built and reshaped to broadcast against ``inds``, so every batch
    position is paired with its own slice of ``inds``. All non-batch axes
    other than ``dim`` are taken whole.

    :param data: Tensor to read from.
    :param inds: Integer index tensor placed at axis ``dim``; its leading
        axes line up with the batch axes of ``data``.
    :param dim: Axis of ``data`` that ``inds`` indexes. Non-negative values
        count from the first axis of ``data`` (batch axes included); negative
        values count from the last axis.
    :param no_batch_dims: Number of leading axes of ``data`` treated as batch axes.
    :return: The result of advanced indexing ``data`` with the assembled index.
    """
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        r = torch.arange(s)
        # Put the arange on axis i and pad with singleton axes up to the rank
        # of `inds` so the two broadcast together.
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - no_batch_dims)
    ]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple: a tuple is the explicit form of multi-axis indexing,
    # whereas a plain list relies on legacy sequence-as-tuple handling.
    return data[tuple(ranges)]
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    The returned function keeps the name and docstring of ``task_func``.

    :param task_func: The task function to be wrapped. It is called as
        ``task_func(cfg=cfg)`` and must return ``(metric_dict, object_dict)``.

    :return: The wrapped task function.
    """
    # Imported locally so the module's import header does not need to change.
    from functools import wraps

    @wraps(task_func)
    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
def _next_indexed_pdb_path(file_path, overwrite):
    """Return ``<file_path without .pdb>_<k>.pdb``, with k one past the highest index already on disk (1 when overwriting)."""
    suffix = '.pdb'
    # Remove the extension only at the end of the path; str.strip('.pdb')
    # strips a *set of characters* from both ends, and str.replace would also
    # hit '.pdb' inside directory names.
    stem_path = file_path[:-len(suffix)] if file_path.endswith(suffix) else file_path

    max_existing_idx = 0
    if not overwrite:
        # A bare filename has an empty dirname; scan the current directory then.
        file_dir = os.path.dirname(file_path) or '.'
        indexed_name = re.compile(re.escape(os.path.basename(stem_path)) + r'_(\d+)\.pdb')
        for name in os.listdir(file_dir):
            match = indexed_name.fullmatch(name)
            if match:
                max_existing_idx = max(max_existing_idx, int(match.group(1)))
    return f'{stem_path}_{max_existing_idx + 1}.pdb'


def write_prot_to_pdb(
        prot_pos: np.ndarray,
        file_path: str,
        aatype: np.ndarray=None,
        overwrite=False,
        no_indexing=False,
        b_factors=None,
    ):
    """Write one structure, or a sequence of structures as numbered models, to a PDB file.

    A 3-D ``prot_pos`` is written as a single model numbered 1. A 4-D
    ``prot_pos`` is iterated over its first axis and each item is written as
    model ``t + 1``. An atom is treated as present when the sum of the absolute
    values of its coordinates exceeds 1e-7. The file is terminated with ``END``.

    :param prot_pos: Atom coordinates with 3 or 4 dimensions; each structure
        is passed to ``create_full_prot``.
    :param file_path: Target path. Unless ``no_indexing`` is set, a trailing
        ``.pdb`` is removed and ``_<k>.pdb`` is appended.
    :param aatype: Optional residue types forwarded to ``create_full_prot``.
    :param overwrite: When true the index ``k`` is always 1; otherwise it is
        one more than the highest ``<name>_<k>.pdb`` found in the target
        directory.
    :param no_indexing: Write to ``file_path`` exactly as given.
    :param b_factors: Optional B-factors forwarded to ``create_full_prot``.
    :return: The path that was written.
    :raises ValueError: If ``prot_pos`` is neither 3-D nor 4-D. This is checked
        before the file is opened, so no empty file is left behind.
    """
    if prot_pos.ndim not in (3, 4):
        raise ValueError(f'Invalid positions shape {prot_pos.shape}')

    if no_indexing:
        save_path = file_path
    else:
        save_path = _next_indexed_pdb_path(file_path, overwrite)

    # Treat a single structure as a one-frame trajectory so both cases share
    # the same writing loop (the lone frame becomes model 1).
    frames = prot_pos if prot_pos.ndim == 4 else prot_pos[None]
    with open(save_path, 'w') as f:
        for t, pos37 in enumerate(frames):
            atom37_mask = np.sum(np.abs(pos37), axis=-1) > 1e-7
            prot = create_full_prot(
                pos37, atom37_mask, aatype=aatype, b_factors=b_factors)
            pdb_prot = protein.to_pdb(prot, model=t + 1, add_end=False)
            f.write(pdb_prot)
        f.write('END')
    return save_path
c_pos_emb: 1280 + # c_pos_emb: ${model.edge_embed_size} + c_timestep_emb: ${model.node_embed_size} + embed_diffuse_mask: False + separate_t: ${interpolant.separate_t} + max_num_res: 2000 + timestep_int: 1000 + dropout: ${model.dropout} + edge_features: + single_bias_transition_n: 2 + c_s: ${model.node_embed_size} + c_p: ${model.edge_embed_size} + relpos_k: 64 + use_rbf: True + num_rbf: 64 + feat_dim: ${model.edge_embed_size} + num_bins: 64 + # max_dist: angstrom + max_dist: 30.0 # 50.0 + dropout: ${model.dropout} + self_condition: True + ipa: + c_s: ${model.node_embed_size} + c_z: ${model.edge_embed_size} + c_hidden: ${model.edge_embed_size} + no_heads: 16 + no_qk_points: 32 + no_v_points: 8 + seq_tfmr_num_heads: 16 + seq_tfmr_num_layers: 2 + num_blocks: 6 + dropout: ${model.dropout} + +experiment: + debug: False + seed: 123 + num_devices: 1 + # warm_start: ./weights/esm_egf4.ckpt + warm_start: null + warm_start_cfg_override: True + use_swa: False + batch_ot: + enabled: True + cost: kabsch + noise_per_sample: 1 + permute: False + training: + min_plddt_mask: null + t_normalize_clip: 0.9 + loss: se3_vf_loss + # trans_scale: 0.1 + translation_loss_weight: 0.2 + rotation_loss_weights: 1.0 # 1.0 + bb_atom_scale: 0.1 # 0.1 + aux_loss_weight: 1.0 # 1.0 + aux_loss_t_pass: 0.25 + aatype_loss_weight: 0.0 + wandb: + name: loss_full + project: P2DFlow + save_code: True + tags: [] + optimizer: + lr: 1e-4 # 1e-4 + trainer: + overfit_batches: 0 + min_epochs: 1 # prevents early stopping + max_epochs: 2000 + accelerator: gpu + log_every_n_steps: 1 + deterministic: False + # strategy: ddp + strategy: ddp_find_unused_parameters_true + check_val_every_n_epoch: 1 + accumulate_grad_batches: 2 + gradient_clip_val: 10 + checkpointer: + dirpath: ckpt/${experiment.wandb.project}/${experiment.wandb.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} + save_last: True + save_top_k: 2000 + monitor: valid/rmsd_loss + mode: min diff --git a/configs/inference.yaml b/configs/inference.yaml new file mode 100644 
class ESMFold_Pred():
    """Thin wrapper around the ESMFold v1 model for predicting a structure from a PDB file's sequence."""

    def __init__(self, device):
        """Load ESMFold v1 in eval mode with gradients disabled and move it to ``device``."""
        folding_model = esm.pretrained.esmfold_v1().eval()
        folding_model.requires_grad_(False)
        folding_model.to(device)
        self._folding_model = folding_model

    def predict_str(self, pdbfile, save_path, max_seq_len = 1500):
        """Predict a structure for the first usable sequence in ``pdbfile`` and write it to ``save_path``.

        Sequences are read with ``SeqIO.parse(pdbfile, "pdb-atom")``. Records
        longer than ``max_seq_len`` are skipped and every kept sequence is
        printed. Only the first kept sequence is folded; if none is kept,
        nothing is written.

        :param pdbfile: Input handed to ``SeqIO.parse``.
        :param save_path: File that receives the output of ``infer_pdb``.
        :param max_seq_len: Longest sequence length that is still accepted.
        """
        kept_seqs = []
        for record in SeqIO.parse(pdbfile, "pdb-atom"):
            sequence = str(record.seq)
            if len(sequence) <= max_seq_len:
                print(f'seq {len(kept_seqs)}:', sequence)
                kept_seqs.append(sequence)

        if not kept_seqs:
            return

        # only infer for the first seq
        with torch.no_grad():
            pdb_text = self._folding_model.infer_pdb(kept_seqs[0])
        with open(save_path, "w+") as handle:
            handle.write(pdb_text)
0000000000000000000000000000000000000000..465801d3621feaeb3df0fd0a490d84641954e3fa Binary files /dev/null and b/data/__pycache__/all_atom.cpython-310.pyc differ diff --git a/data/__pycache__/cal_trans_rotmats.cpython-310.pyc b/data/__pycache__/cal_trans_rotmats.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e54e6cf8ca9e5f14c5025946f335a6a02434734 Binary files /dev/null and b/data/__pycache__/cal_trans_rotmats.cpython-310.pyc differ diff --git a/data/__pycache__/errors.cpython-310.pyc b/data/__pycache__/errors.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ceeb5873186dc6804d8e68e9c5dc67932bea353 Binary files /dev/null and b/data/__pycache__/errors.cpython-310.pyc differ diff --git a/data/__pycache__/interpolant.cpython-310.pyc b/data/__pycache__/interpolant.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..414afb84f0859964fd0a27243a241bcea766f814 Binary files /dev/null and b/data/__pycache__/interpolant.cpython-310.pyc differ diff --git a/data/__pycache__/parsers.cpython-310.pyc b/data/__pycache__/parsers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7aa7575b05a659c8d9fea738835b6b53b10ed091 Binary files /dev/null and b/data/__pycache__/parsers.cpython-310.pyc differ diff --git a/data/__pycache__/pdb_dataloader.cpython-310.pyc b/data/__pycache__/pdb_dataloader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d864214417cd92ee5c85fa2a42ade625eccd025 Binary files /dev/null and b/data/__pycache__/pdb_dataloader.cpython-310.pyc differ diff --git a/data/__pycache__/protein.cpython-310.pyc b/data/__pycache__/protein.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38d42ce9ba8247026fb7220e63b82bb51f7c6f1d Binary files /dev/null and b/data/__pycache__/protein.cpython-310.pyc differ diff --git a/data/__pycache__/protein.cpython-38.pyc 
b/data/__pycache__/protein.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e20a9345bf101231b0840ce932c373f81a5e21b8 Binary files /dev/null and b/data/__pycache__/protein.cpython-38.pyc differ diff --git a/data/__pycache__/repr.cpython-310.pyc b/data/__pycache__/repr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c46ecaf7b5c0ebe47982c57c86cc1707dc54943c Binary files /dev/null and b/data/__pycache__/repr.cpython-310.pyc differ diff --git a/data/__pycache__/residue_constants.cpython-310.pyc b/data/__pycache__/residue_constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bec5e8f79ec87bd090fc6b342da8e54efa3dc019 Binary files /dev/null and b/data/__pycache__/residue_constants.cpython-310.pyc differ diff --git a/data/__pycache__/residue_constants.cpython-38.pyc b/data/__pycache__/residue_constants.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7aa3129887b89f164b60dffb65189952c7f9c40 Binary files /dev/null and b/data/__pycache__/residue_constants.cpython-38.pyc differ diff --git a/data/__pycache__/so3_utils.cpython-310.pyc b/data/__pycache__/so3_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f865ab345204a0c5824d8a65813991dffdfd02ce Binary files /dev/null and b/data/__pycache__/so3_utils.cpython-310.pyc differ diff --git a/data/__pycache__/utils.cpython-310.pyc b/data/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d50c90627c6c52db534b8eeb036277693a3b88a7 Binary files /dev/null and b/data/__pycache__/utils.cpython-310.pyc differ diff --git a/data/all_atom.py b/data/all_atom.py new file mode 100644 index 0000000000000000000000000000000000000000..05e7bd41fa54fea9663b9e5eec780aa25c16517b --- /dev/null +++ b/data/all_atom.py @@ -0,0 +1,371 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for calculating all atom representations. +Code adapted from OpenFold. +""" + +import torch +from openfold.data import data_transforms +from openfold.np import residue_constants +from openfold.utils import rigid_utils as ru +from data import utils as du + +Rigid = ru.Rigid +Rotation = ru.Rotation + +# Residue Constants from OpenFold/AlphaFold2. + + +IDEALIZED_POS = torch.tensor(residue_constants.restype_atom14_rigid_group_positions) +DEFAULT_FRAMES = torch.tensor(residue_constants.restype_rigid_group_default_frame) +ATOM_MASK = torch.tensor(residue_constants.restype_atom14_mask) +GROUP_IDX = torch.tensor(residue_constants.restype_atom14_to_rigid_group) + +IDEALIZED_POS_37 = torch.tensor(residue_constants.restype_atom37_rigid_group_positions) +ATOM_MASK_37 = torch.tensor(residue_constants.restype_atom37_mask) +GROUP_IDX_37 = torch.tensor(residue_constants.restype_atom37_to_rigid_group) + +def to_atom37(trans, rots, aatype=None, torsions_with_CB=None, get_mask=False): + num_batch, num_res, _ = trans.shape + + if torsions_with_CB is None: # (B,L,) + torsions_with_CB = torch.concat( + [torch.zeros((num_batch,num_res,8,1),device=trans.device), + torch.ones((num_batch,num_res,8,1),device=trans.device)], + dim=-1 + ) # (B,L,8,2) + + # final_atom37 = compute_backbone( + # du.create_rigid(rots, trans), + # torch.zeros(num_batch, num_res, 2, device=trans.device) + # )[0] + + final_atom37, atom37_mask = 
compute_atom37_pos( + du.create_rigid(rots, trans), + torsions_with_CB, + aatype=aatype, + )[:2] # (B,L,37,3) + + if get_mask: + return final_atom37, atom37_mask + else: + return final_atom37 + +def torsion_angles_to_frames( + r: Rigid, # type: ignore [valid-type] + alpha: torch.Tensor, + aatype: torch.Tensor, + bb_rot = None +): + """Conversion method of torsion angles to frames provided the backbone. + + Args: + r: Backbone rigid groups. + alpha: Torsion angles. (B,L,7,2) + aatype: residue types. + + Returns: + All 8 frames corresponding to each torsion frame. + + + + + + !!! May need to set omega and fai angle to be zero !!! + + + + + + """ + # [*, N, 8, 4, 4] + with torch.no_grad(): + default_4x4 = DEFAULT_FRAMES.to(aatype.device)[aatype, ...] # type: ignore [attr-defined] + + # [*, N, 8] transformations, i.e. + # One [*, N, 8, 3, 3] rotation matrix and + # One [*, N, 8, 3] translation matrix + default_r = r.from_tensor_4x4(default_4x4) # (B,L,8) + + if bb_rot is None: + bb_rot = alpha.new_zeros(((1,) * len(alpha.shape[:-1]))+(2,)) # (1,1,1,2) + bb_rot[..., 1] = 1 + bb_rot = bb_rot.expand(*alpha.shape[:-2], -1, -1) # (B,L,1,2) + + alpha = torch.cat([bb_rot, alpha], dim=-2) # (B,L,8,2) + + # [*, N, 8, 3, 3] + # Produces rotation matrices of the form: + # [ + # [1, 0 , 0 ], + # [0, a_2,-a_1], + # [0, a_1, a_2] + # ] + # This follows the original code rather than the supplement, which uses + # different indices. 
+ + all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape) # (B,L,8,3,3) + all_rots[..., 0, 0] = 1 + all_rots[..., 1, 1] = alpha[..., 1] + all_rots[..., 1, 2] = -alpha[..., 0] + all_rots[..., 2, 1:] = alpha + + all_rots = Rigid(Rotation(rot_mats=all_rots), None) + + all_frames = default_r.compose(all_rots) + + chi2_frame_to_frame = all_frames[..., 5] + chi3_frame_to_frame = all_frames[..., 6] + chi4_frame_to_frame = all_frames[..., 7] + + chi1_frame_to_bb = all_frames[..., 4] + chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) + chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) + chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) + + all_frames_to_bb = Rigid.cat( + [ + all_frames[..., :5], + chi2_frame_to_bb.unsqueeze(-1), + chi3_frame_to_bb.unsqueeze(-1), + chi4_frame_to_bb.unsqueeze(-1), + ], + dim=-1, + ) + + all_frames_to_global = r[..., None].compose(all_frames_to_bb) # type: ignore [index] + + return all_frames_to_global + + +def prot_to_torsion_angles(aatype, atom37, atom37_mask): + """Calculate torsion angle features from protein features.""" + prot_feats = { + "aatype": aatype, + "all_atom_positions": atom37, + "all_atom_mask": atom37_mask, + } + torsion_angles_feats = data_transforms.atom37_to_torsion_angles()(prot_feats) + torsion_angles = torsion_angles_feats["torsion_angles_sin_cos"] + torsion_mask = torsion_angles_feats["torsion_angles_mask"] + return torsion_angles, torsion_mask + + +def frames_to_atom14_pos( + r: Rigid, # type: ignore [valid-type] + aatype: torch.Tensor, +): + """Convert frames to their idealized all atom representation. + + Args: + r: All rigid groups. [B,L,8] + aatype: Residue types. [B,L] + + Returns: + + """ + with torch.no_grad(): + group_mask = GROUP_IDX.to(aatype.device)[aatype, ...] 
# (B,L,14) + group_mask = torch.nn.functional.one_hot( + group_mask, + num_classes=DEFAULT_FRAMES.shape[-3], # (21,8,4,4) + ) # (B,L,14,8) + frame_atom_mask = ATOM_MASK.to(aatype.device)[aatype, ...].unsqueeze(-1) # (B,L,14,1) + frame_null_pos = IDEALIZED_POS.to(aatype.device)[aatype, ...] # (B,L,14,3) + + t_atoms_to_global = r[..., None, :] * group_mask # (B,L,14,8) + + t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) # (B,L,14) + + pred_positions = t_atoms_to_global.apply(frame_null_pos) # (B,L,14,3) + pred_positions = pred_positions * frame_atom_mask + + return pred_positions + +def frames_to_atom37_pos( + r: Rigid, # type: ignore [valid-type] + aatype: torch.Tensor, +): + """Convert frames to their idealized all atom representation. + + Args: + r: All rigid groups. [B,L] + aatype: Residue types. [B,L] + + Returns: + + """ + with torch.no_grad(): + group_mask = GROUP_IDX_37.to(aatype.device)[aatype, ...] # (B,L,37) + group_mask = torch.nn.functional.one_hot( + group_mask, + num_classes=DEFAULT_FRAMES.shape[-3], # (21,8,4,4) + ) # (B,L,37,8) + frame_atom_mask = ATOM_MASK_37.to(aatype.device)[aatype, ...].unsqueeze(-1) # (B,L,37,1) + frame_null_pos = IDEALIZED_POS_37.to(aatype.device)[aatype, ...] # (B,L,37,3) + + t_atoms_to_global = r[..., None, :] * group_mask # (B,L,37,8) + + t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) # (B,L,37) + + pred_positions = t_atoms_to_global.apply(frame_null_pos) # (B,L,37,3) + pred_positions = pred_positions * frame_atom_mask + + return pred_positions, frame_atom_mask[...,0] + +#! 
may need to remove it +def compute_backbone(bb_rigids, psi_torsions, aatype=None): + torsion_angles = torch.tile( + psi_torsions[..., None, :], tuple([1 for _ in range(len(bb_rigids.shape))]) + (7, 1) + ) + ''' + psi_torsions[..., None, :].shape: (B,L,1,2) + torsion_angles.shape: (B,L,7,2) + bb_rigids.shape: (B,L) + ''' + + if aatype is None: + aatype = torch.zeros(bb_rigids.shape, device=bb_rigids.device).long() + + all_frames = torsion_angles_to_frames( + bb_rigids, + torsion_angles, + aatype, + ) + atom14_pos = frames_to_atom14_pos(all_frames, aatype) # (B,L,14,3) + atom37_bb_pos = torch.zeros(bb_rigids.shape + (37, 3), device=bb_rigids.device) # (B,L,37,3) + + # atom14 bb order = ['N', 'CA', 'C', 'O', 'CB'] + # atom37 bb order = ['N', 'CA', 'C', 'CB', 'O'] + atom37_bb_pos[..., :3, :] = atom14_pos[..., :3, :] + atom37_mask = torch.any(atom37_bb_pos, axis=-1) # mask atom with all 0 xyz + + return atom37_bb_pos, atom37_mask, aatype, atom14_pos + +def compute_atom37_pos(bb_rigids, torsions_with_CB, aatype=None): + ''' + torsions_with_CB.shape: (B,L,8,2) + bb_rigids.shape: (B,L) + ''' + + if aatype is None: + aatype = torch.zeros(bb_rigids.shape, device=bb_rigids.device).long() + + all_frames = torsion_angles_to_frames( + bb_rigids, + torsions_with_CB[:,:,1:,:], + aatype, + bb_rot = torsions_with_CB[:,:,0:1,:], + ) + + atom14_pos = frames_to_atom14_pos(all_frames, aatype) # (B,L,14,3) + atom37_pos,atom37_mask = frames_to_atom37_pos(all_frames, aatype) # (B,L,37,3) + + return atom37_pos, atom37_mask, aatype, atom14_pos + +def calculate_neighbor_angles(R_ac, R_ab): + """Calculate angles between atoms c <- a -> b. + + Parameters + ---------- + R_ac: Tensor, shape = (N,3) + Vector from atom a to c. + R_ab: Tensor, shape = (N,3) + Vector from atom a to b. + + Returns + ------- + angle_cab: Tensor, shape = (N,) + Angle between atoms c <- a -> b. 
+ """ + # cos(alpha) = (u * v) / (|u|*|v|) + x = torch.sum(R_ac * R_ab, dim=1) # shape = (N,) + # sin(alpha) = |u x v| / (|u|*|v|) + y = torch.cross(R_ac, R_ab).norm(dim=-1) # shape = (N,) + # avoid that for y == (0,0,0) the gradient wrt. y becomes NaN + y = torch.max(y, torch.tensor(1e-9)) + angle = torch.atan2(y, x) + return angle + + +def vector_projection(R_ab, P_n): + """ + Project the vector R_ab onto a plane with normal vector P_n. + + Parameters + ---------- + R_ab: Tensor, shape = (N,3) + Vector from atom a to b. + P_n: Tensor, shape = (N,3) + Normal vector of a plane onto which to project R_ab. + + Returns + ------- + R_ab_proj: Tensor, shape = (N,3) + Projected vector (orthogonal to P_n). + """ + a_x_b = torch.sum(R_ab * P_n, dim=-1) + b_x_b = torch.sum(P_n * P_n, dim=-1) + return R_ab - (a_x_b / b_x_b)[:, None] * P_n + + +def transrot_to_atom37(transrot_traj, res_mask, aatype=None, torsions_with_CB=None): + atom37_traj = [] + res_mask = res_mask.detach().cpu() + num_batch = res_mask.shape[0] + + for trans, rots in transrot_traj: + atom37 = to_atom37(trans, rots, aatype=aatype, torsions_with_CB=torsions_with_CB,get_mask=False) + atom37 = atom37.detach().cpu() + # batch_atom37 = [] + # for i in range(num_batch): + # batch_atom37.append( + # du.adjust_oxygen_pos(atom37[i], res_mask[i]) + # ) + # atom37_traj.append(torch.stack(batch_atom37)) + atom37_traj.append(atom37) + return atom37_traj + +# def atom37_from_trans_rot(trans, rots, res_mask): +# rigids = du.create_rigid(rots, trans) +# atom37 = compute_backbone( +# rigids, +# torch.zeros( +# trans.shape[0], +# trans.shape[1], +# 2, +# device=trans.device +# ) +# )[0] +# atom37 = atom37.detach().cpu() +# batch_atom37 = [] +# num_batch = res_mask.shape[0] +# for i in range(num_batch): +# batch_atom37.append( +# du.adjust_oxygen_pos(atom37[i], res_mask[i]) +# ) +# return torch.stack(batch_atom37) + + +# def process_trans_rot_traj(trans_traj, rots_traj, res_mask): +# res_mask = res_mask.detach().cpu() +# 
atom37_traj = [ +# atom37_from_trans_rot(trans, rots, res_mask) +# for trans, rots in zip(trans_traj, rots_traj) +# ] +# atom37_traj = torch.stack(atom37_traj).swapaxes(0, 1) +# return atom37_traj diff --git a/data/cal_trans_rotmats.py b/data/cal_trans_rotmats.py new file mode 100644 index 0000000000000000000000000000000000000000..1c75c43cd2793763228c56ede69ebb283e8beaee --- /dev/null +++ b/data/cal_trans_rotmats.py @@ -0,0 +1,60 @@ +import torch +import dataclasses + +import numpy as np + +from Bio import PDB +from data import parsers, errors +from data import utils as du +from openfold.data import data_transforms +from openfold.utils import rigid_utils + + +def cal_trans_rotmats(save_path): + metadata = {} + parser = PDB.PDBParser(QUIET=True) + structure = parser.get_structure('test', save_path) + + # Extract all chains + struct_chains = { + chain.id.upper(): chain + for chain in structure.get_chains()} + metadata['num_chains'] = len(struct_chains) + # Extract features + struct_feats = [] + all_seqs = set() + for chain_id, chain in struct_chains.items(): + # Convert chain id into int + chain_id = du.chain_str_to_int(chain_id) + chain_prot = parsers.process_chain(chain, chain_id) + chain_dict = dataclasses.asdict(chain_prot) + chain_dict = du.parse_chain_feats(chain_dict) + all_seqs.add(tuple(chain_dict['aatype'])) + struct_feats.append(chain_dict) + if len(all_seqs) == 1: + metadata['quaternary_category'] = 'homomer' + else: + metadata['quaternary_category'] = 'heteromer' + complex_feats = du.concat_np_features(struct_feats, False) + # Process geometry features + complex_aatype = complex_feats['aatype'] + metadata['seq_len'] = len(complex_aatype) + modeled_idx = np.where(complex_aatype != 20)[0] + if np.sum(complex_aatype != 20) == 0: + raise errors.LengthError('No modeled residues') + min_modeled_idx = np.min(modeled_idx) + max_modeled_idx = np.max(modeled_idx) + metadata['modeled_seq_len'] = max_modeled_idx - min_modeled_idx + 1 + complex_feats['modeled_idx'] = 
modeled_idx + + processed_feats = du.parse_chain_feats(complex_feats) + chain_feats_temp = { + 'aatype': torch.tensor(processed_feats['aatype']).long(), + 'all_atom_positions': torch.tensor(processed_feats['atom_positions']).double(), + 'all_atom_mask': torch.tensor(processed_feats['atom_mask']).double() + } + chain_feats_temp = data_transforms.atom37_to_frames(chain_feats_temp) + curr_rigid = rigid_utils.Rigid.from_tensor_4x4(chain_feats_temp['rigidgroups_gt_frames'])[:, 0] + trans = curr_rigid.get_trans().cpu() + rotmats = curr_rigid.get_rots().get_rot_mats().cpu() + return trans, rotmats \ No newline at end of file diff --git a/data/errors.py b/data/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..fdf1026227fa6f2881c0a17b9a4cc53f3ac77317 --- /dev/null +++ b/data/errors.py @@ -0,0 +1,26 @@ +"""Error class for handled errors.""" + + +class DataError(Exception): + """Data exception.""" + pass + + +class FileExistsError(DataError): + """Raised when file already exists.""" + pass + + +class MmcifParsingError(DataError): + """Raised when mmcif parsing fails.""" + pass + + +class ResolutionError(DataError): + """Raised when resolution isn't acceptable.""" + pass + + +class LengthError(DataError): + """Raised when length isn't acceptable.""" + pass \ No newline at end of file diff --git a/data/interpolant.py b/data/interpolant.py new file mode 100644 index 0000000000000000000000000000000000000000..b66ec5f4442759bfccdc29b1be1dea6d273f01b1 --- /dev/null +++ b/data/interpolant.py @@ -0,0 +1,310 @@ +import torch +import numpy as np +from data import so3_utils +from data import utils as du +from scipy.spatial.transform import Rotation +from data import all_atom +import copy +from scipy.optimize import linear_sum_assignment + + +def _centered_gaussian(num_batch, num_res, device): + noise = torch.randn(num_batch, num_res, 3, device=device) + return noise - torch.mean(noise, dim=-2, keepdims=True) + +def _uniform_so3(num_batch, num_res, device): + 
return torch.tensor( + Rotation.random(num_batch*num_res).as_matrix(), + device=device, + dtype=torch.float32, + ).reshape(num_batch, num_res, 3, 3) + +def _trans_diffuse_mask(trans_t, trans_1, diffuse_mask): + return trans_t * diffuse_mask[..., None] + trans_1 * (1 - diffuse_mask[..., None]) + +def _rots_diffuse_mask(rotmats_t, rotmats_1, diffuse_mask): + return ( + rotmats_t * diffuse_mask[..., None, None] + + rotmats_1 * (1 - diffuse_mask[..., None, None]) + ) + + +class Interpolant: + + def __init__(self, cfg): + self._cfg = cfg + self._rots_cfg = cfg.rots + self._trans_cfg = cfg.trans + self._sample_cfg = cfg.sampling + self.add_noise = cfg.add_noise + self._igso3 = None + + @property + def igso3(self): + if self._igso3 is None: + sigma_grid = torch.linspace(0.1, 1.5, 1000) + self._igso3 = so3_utils.SampleIGSO3( + 1000, sigma_grid, cache_dir='.cache') + return self._igso3 + + def set_device(self, device): + self._device = device + + def sample_t(self, num_batch): + # t: [min_t, 1-min_t] + t = torch.rand(num_batch, device=self._device) + return t * (1 - 2*self._cfg.min_t) + self._cfg.min_t + + def _esmfold_gaussian(self, num_batch, num_res, device, trans_esmfold): + noise = torch.randn(num_batch, num_res, 3, device=device) # (B,L,3) + noise = self._trans_cfg.noise_scale * noise + trans_esmfold + return noise - torch.mean(noise, dim=-2, keepdims=True) + + def _corrupt_trans(self, trans_1, t, res_mask, trans_esmfold): + # trans_nm_0 = _centered_gaussian(*res_mask.shape, self._device) + # trans_0 = trans_nm_0 * du.NM_TO_ANG_SCALE + + if self.add_noise: + trans_0 = self._esmfold_gaussian(*res_mask.shape, self._device, trans_esmfold) + else: + trans_0 = trans_esmfold + + + trans_0 = self._batch_ot(trans_0, trans_1, res_mask) + trans_t = (1 - t[..., None]) * trans_0 + t[..., None] * trans_1 + trans_t = _trans_diffuse_mask(trans_t, trans_1, res_mask) + return trans_t * res_mask[..., None] + + def _batch_ot(self, trans_0, trans_1, res_mask): + num_batch, num_res = 
trans_0.shape[:2] + noise_idx, gt_idx = torch.where( + torch.ones(num_batch, num_batch)) + batch_nm_0 = trans_0[noise_idx] + batch_nm_1 = trans_1[gt_idx] + batch_mask = res_mask[gt_idx] + aligned_nm_0, aligned_nm_1, _ = du.batch_align_structures( + batch_nm_0, batch_nm_1, mask=batch_mask + ) + + aligned_nm_0 = aligned_nm_0.reshape(num_batch, num_batch, num_res, 3) + aligned_nm_1 = aligned_nm_1.reshape(num_batch, num_batch, num_res, 3) + + # Compute cost matrix of aligned noise to ground truth + batch_mask = batch_mask.reshape(num_batch, num_batch, num_res) + cost_matrix = torch.sum( + torch.linalg.norm(aligned_nm_0 - aligned_nm_1, dim=-1), dim=-1 + ) / torch.sum(batch_mask, dim=-1) + noise_perm, gt_perm = linear_sum_assignment(du.to_numpy(cost_matrix)) + return aligned_nm_0[(tuple(gt_perm), tuple(noise_perm))] + # return aligned_nm_0 + + def _esmfold_igso3(self, res_mask, rotmats_esmfold): + num_batch, num_res = res_mask.shape + noisy_rotmats = self.igso3.sample( + torch.tensor([self._rots_cfg.noise_scale]), + num_batch*num_res + ).to(self._device) + noisy_rotmats = noisy_rotmats.reshape(num_batch, num_res, 3, 3) + rotmats_0 = torch.einsum( + "...ij,...jk->...ik", rotmats_esmfold, noisy_rotmats) + return rotmats_0 + + def _corrupt_rotmats(self, rotmats_1, t, res_mask, rotmats_esmfold): + # num_batch, num_res = res_mask.shape + # noisy_rotmats = self.igso3.sample( + # torch.tensor([1.5]), + # num_batch*num_res + # ).to(self._device) + # noisy_rotmats = noisy_rotmats.reshape(num_batch, num_res, 3, 3) + # rotmats_0 = torch.einsum( + # "...ij,...jk->...ik", rotmats_1, noisy_rotmats) + + if self.add_noise: + rotmats_0 = self._esmfold_igso3(res_mask, rotmats_esmfold) + else: + rotmats_0 = rotmats_esmfold + + + rotmats_t = so3_utils.geodesic_t(t[..., None], rotmats_1, rotmats_0) + identity = torch.eye(3, device=self._device) + rotmats_t = ( + rotmats_t * res_mask[..., None, None] + + identity[None, None] * (1 - res_mask[..., None, None]) + ) + return 
_rots_diffuse_mask(rotmats_t, rotmats_1, res_mask) + + def corrupt_batch(self, batch): + noisy_batch = copy.deepcopy(batch) + + # [B, N, 3] + trans_1 = batch['trans_1'] # Angstrom + + # [B, N, 3, 3] + rotmats_1 = batch['rotmats_1'] + + # [B, N] + res_mask = batch['res_mask'] + num_batch, _ = res_mask.shape + + # [B, 1] + t = self.sample_t(num_batch)[:, None] + noisy_batch['t'] = t + + # Apply corruptions + trans_t = self._corrupt_trans(trans_1, t, res_mask, batch['trans_esmfold']) + + noisy_batch['trans_t'] = trans_t + + rotmats_t = self._corrupt_rotmats(rotmats_1, t, res_mask, batch['rotmats_esmfold']) + + noisy_batch['rotmats_t'] = rotmats_t + + + # noisy_batch['t'] = 0.5 * torch.ones_like(t) + # noisy_batch['trans_t'] = batch['trans_1'] + # noisy_batch['rotmats_t'] = batch['rotmats_1'] + + + return noisy_batch + + def rot_sample_kappa(self, t): + if self._rots_cfg.sample_schedule == 'exp': + return 1 - torch.exp(-t*self._rots_cfg.exp_rate) + elif self._rots_cfg.sample_schedule == 'linear': + return t + else: + raise ValueError( + f'Invalid schedule: {self._rots_cfg.sample_schedule}') + + def _trans_euler_step(self, d_t, t, trans_1, trans_t): + trans_vf = (trans_1 - trans_t) / (1 - t) + return trans_t + trans_vf * d_t + + def _rots_euler_step(self, d_t, t, rotmats_1, rotmats_t): + if self._rots_cfg.sample_schedule == 'linear': + scaling = 1 / (1 - t) + elif self._rots_cfg.sample_schedule == 'exp': + scaling = self._rots_cfg.exp_rate + else: + raise ValueError( + f'Unknown sample schedule {self._rots_cfg.sample_schedule}') + return so3_utils.geodesic_t( + scaling * d_t, rotmats_1, rotmats_t) + + def sample( + self, + batch, + model, + ): + res_mask = batch['res_mask'] + num_batch = batch['aatype'].shape[0] + num_res = batch['aatype'].shape[1] + aatype = batch['aatype'] + motif_mask = batch.get('motif_mask',torch.ones(aatype.shape)) + + + # Set-up initial prior samples + + # trans_0 = _centered_gaussian( + # num_batch, num_res, self._device) * du.NM_TO_ANG_SCALE + 
# rotmats_0 = _uniform_so3(num_batch, num_res, self._device) + + + if self.add_noise: + trans_0 = self._esmfold_gaussian(*res_mask.shape, self._device, batch['trans_esmfold']) + rotmats_0 = self._esmfold_igso3(res_mask, batch['rotmats_esmfold']) + else: + trans_0 = batch['trans_esmfold'] + rotmats_0 = batch['rotmats_esmfold'] + + + if not torch.all(motif_mask==torch.ones(aatype.shape,device=motif_mask.device)): + trans_0 = motif_mask[...,None]*trans_0+(1-motif_mask[...,None])*batch['trans_fix'] + rotmats_0 = motif_mask[...,None,None]*rotmats_0+(1-motif_mask[...,None,None])*batch['rotmats_fix'] + + + # Set-up time + + ts = torch.linspace( + self._cfg.min_t, 1.0, self._sample_cfg.num_timesteps) + + # ts = torch.linspace(np.exp(self._cfg.min_t), np.exp(1.0), self._sample_cfg.num_timesteps) + # ts = torch.log(ts) + + + + t_1 = ts[0] + + prot_traj = [(trans_0, rotmats_0)] + clean_traj = [] + for t_2 in ts[1:]: + + # Run model. + trans_t_1, rotmats_t_1 = prot_traj[-1] + batch['trans_t'] = trans_t_1 + batch['rotmats_t'] = rotmats_t_1 + t = torch.ones((num_batch, 1), device=self._device) * t_1 + batch['t'] = t + with torch.no_grad(): + model_out = model(batch) + + # Process model output. 
+ + + pred_trans_1 = model_out['pred_trans'] + pred_rotmats_1 = model_out['pred_rotmats'] + if not torch.all(motif_mask==torch.ones(aatype.shape,device=motif_mask.device)): + pred_trans_1 = motif_mask[...,None]* pred_trans_1+(1-motif_mask[...,None])*batch['trans_fix'] + pred_rotmats_1 = motif_mask[...,None,None]*pred_rotmats_1+(1-motif_mask[...,None,None])*batch['rotmats_fix'] + + + clean_traj.append( + (pred_trans_1.detach(), pred_rotmats_1.detach()) + ) + if self._cfg.self_condition: + batch['trans_sc'] = pred_trans_1 + + # Take reverse step + d_t = t_2 - t_1 + + + trans_t_2 = self._trans_euler_step( + d_t, t_1, pred_trans_1, trans_t_1) + rotmats_t_2 = self._rots_euler_step( + d_t, t_1, pred_rotmats_1, rotmats_t_1) + if not torch.all(motif_mask==torch.ones(aatype.shape,device=motif_mask.device)): + trans_t_2 = motif_mask[...,None]* trans_t_2+(1-motif_mask[...,None])*batch['trans_fix'] + rotmats_t_2 = motif_mask[...,None,None]*rotmats_t_2+(1-motif_mask[...,None,None])*batch['rotmats_fix'] + + + + prot_traj.append((trans_t_2, rotmats_t_2)) + t_1 = t_2 + + # We only integrated to min_t, so need to make a final step + t_1 = ts[-1] + trans_t_1, rotmats_t_1 = prot_traj[-1] + batch['trans_t'] = trans_t_1 + batch['rotmats_t'] = rotmats_t_1 + batch['t'] = torch.ones((num_batch, 1), device=self._device) * t_1 + with torch.no_grad(): + model_out = model(batch) + + + pred_trans_1 = model_out['pred_trans'] + pred_rotmats_1 = model_out['pred_rotmats'] + if not torch.all(motif_mask==torch.ones(aatype.shape,device=motif_mask.device)): + pred_trans_1 = motif_mask[...,None]* pred_trans_1+(1-motif_mask[...,None])*batch['trans_fix'] + pred_rotmats_1 = motif_mask[...,None,None]*pred_rotmats_1+(1-motif_mask[...,None,None])*batch['rotmats_fix'] + + + clean_traj.append( + (pred_trans_1.detach(), pred_rotmats_1.detach()) + ) + prot_traj.append((pred_trans_1, pred_rotmats_1)) + + # Convert trajectories to atom37. 
+ atom37_traj = all_atom.transrot_to_atom37(prot_traj, res_mask, aatype=aatype, torsions_with_CB=model_out['pred_torsions_with_CB']) + clean_atom37_traj = all_atom.transrot_to_atom37(clean_traj, res_mask, aatype=aatype, torsions_with_CB=model_out['pred_torsions_with_CB']) + + return atom37_traj, clean_atom37_traj, clean_traj diff --git a/data/parsers.py b/data/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..931da393fe82668dc19066a504a180f127c5c107 --- /dev/null +++ b/data/parsers.py @@ -0,0 +1,82 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Library for parsing different data structures. +Code adapted from Openfold protein.py. +""" +from Bio.PDB.Chain import Chain +import numpy as np + +from data import residue_constants +from data import protein + +Protein = protein.Protein + + +def process_chain(chain: Chain, chain_id: str) -> Protein: + """Convert a PDB chain object into a AlphaFold Protein instance. + + Forked from alphafold.common.protein.from_pdb_string + + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + + Took out lines 94-97 which don't allow insertions in the PDB. + Sabdab uses insertions for the chothia numbering so we need to allow them. + + Took out lines 110-112 since that would mess up CDR numbering. + + Args: + chain: Instance of Biopython's chain class. 
+ + Returns: + Protein object with protein features. + """ + atom_positions = [] + aatype = [] + atom_mask = [] + residue_index = [] + b_factors = [] + chain_ids = [] + for res in chain: + # for residue type "X", the chain may need to be removed + res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num) + pos = np.zeros((residue_constants.atom_type_num, 3)) + mask = np.zeros((residue_constants.atom_type_num,)) + res_b_factors = np.zeros((residue_constants.atom_type_num,)) + for atom in res: + if atom.name not in residue_constants.atom_types: + continue + pos[residue_constants.atom_order[atom.name]] = atom.coord + mask[residue_constants.atom_order[atom.name]] = 1. + res_b_factors[residue_constants.atom_order[atom.name] + ] = atom.bfactor + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + b_factors.append(res_b_factors) + chain_ids.append(chain_id) + + return Protein( + atom_positions=np.array(atom_positions), + atom_mask=np.array(atom_mask), + aatype=np.array(aatype), + residue_index=np.array(residue_index), + chain_index=np.array(chain_ids), + b_factors=np.array(b_factors)) \ No newline at end of file diff --git a/data/pdb_dataloader.py b/data/pdb_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..68b1bd22198be930dcf13f290c52ca34dede21bd --- /dev/null +++ b/data/pdb_dataloader.py @@ -0,0 +1,242 @@ +"""PDB data loader.""" +import math +import torch +import tree +import numpy as np +import torch +import pandas as pd +import logging +import os +import random +import esm +import copy + +from data import utils as du +from data.repr import get_pre_repr +from openfold.data import data_transforms +from openfold.utils import rigid_utils +from data.residue_constants import restype_atom37_mask, order2restype_with_mask + +from pytorch_lightning import 
LightningDataModule +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler, dist +from scipy.spatial.transform import Rotation as scipy_R + + + + +class PdbDataModule(LightningDataModule): + def __init__(self, data_cfg): + super().__init__() + self.data_cfg = data_cfg + self.loader_cfg = data_cfg.loader + self.dataset_cfg = data_cfg.dataset + self.sampler_cfg = data_cfg.sampler + + def setup(self, stage: str): + self._train_dataset = PdbDataset( + dataset_cfg=self.dataset_cfg, + is_training=True, + ) + self._valid_dataset = PdbDataset( + dataset_cfg=self.dataset_cfg, + is_training=False, + ) + + def train_dataloader(self, rank=None, num_replicas=None): + num_workers = self.loader_cfg.num_workers + return DataLoader( # default batch_size is 1, and it is expand in training_step of FlowModule + self._train_dataset, + + # batch_sampler=LengthBatcher( + # sampler_cfg=self.sampler_cfg, + # metadata_csv=self._train_dataset.csv, + # rank=rank, + # num_replicas=num_replicas, + # ), + sampler=DistributedSampler(self._train_dataset, shuffle=True), + + num_workers=self.loader_cfg.num_workers, + prefetch_factor=None if num_workers == 0 else self.loader_cfg.prefetch_factor, + persistent_workers=True if num_workers > 0 else False, + # persistent_workers=False, + ) + + def val_dataloader(self): + num_workers = self.loader_cfg.num_workers + return DataLoader( + self._valid_dataset, + sampler=DistributedSampler(self._valid_dataset, shuffle=False), + num_workers=self.loader_cfg.num_workers, + prefetch_factor=None if num_workers == 0 else self.loader_cfg.prefetch_factor, + persistent_workers=True, + # persistent_workers=False, + ) + + +class PdbDataset(Dataset): + def __init__( + self, + *, + dataset_cfg, + is_training, + ): + self._log = logging.getLogger(__name__) + self._is_training = is_training + self._dataset_cfg = dataset_cfg + self.split_frac = self._dataset_cfg.split_frac + self.random_seed = self._dataset_cfg.seed + # 
self.count = 0 + + self._init_metadata() + + @property + def is_training(self): + return self._is_training + + @property + def dataset_cfg(self): + return self._dataset_cfg + + def _init_metadata(self): + """Initialize metadata.""" + + # Process CSV with different filtering criterions. + pdb_csv = pd.read_csv(self.dataset_cfg.csv_path) + self.raw_csv = pdb_csv + pdb_csv = pdb_csv[pdb_csv.modeled_seq_len <= self.dataset_cfg.max_num_res] + pdb_csv = pdb_csv[pdb_csv.modeled_seq_len >= self.dataset_cfg.min_num_res] + + if self.dataset_cfg.subset is not None: + pdb_csv = pdb_csv.iloc[:self.dataset_cfg.subset] + pdb_csv = pdb_csv.sort_values('modeled_seq_len', ascending=False) + + # energy_csv_path = self.dataset_cfg.energy_csv_path + # self.energy_csv = pd.read_csv(energy_csv_path) + + ## Training or validation specific logic. + if self.is_training: + self.csv = pdb_csv[pdb_csv['is_trainset']] + self.csv = pdb_csv.sample(frac=self.split_frac, random_state=self.random_seed).reset_index() + self.csv.to_csv(os.path.join(os.path.dirname(self.dataset_cfg.csv_path),"train.csv"), index=False) + + # self.chain_feats_total = [] + + # for idx in range(len(self.csv)): + # if idx % 200 == 0: + # self._log.info(f"pre_count= {idx}") + + # # processed_path = self.csv.iloc[idx]['processed_path'] + + # # chain_feats_temp = self._process_csv_row(processed_path) + # # chain_feats_temp = du.read_pkl(processed_path) + # # chain_feats_temp['energy'] = torch.tensor(self.csv.iloc[idx]['energy'], dtype=torch.float32) + # self.chain_feats_total[idx]['energy'] = torch.tensor(self.csv.iloc[idx]['energy'], dtype=torch.float32) + + # # self.chain_feats_total += [chain_feats_temp] + + + # self._log.info( + # f"Training: {len(self.chain_feats_total)} examples, len_range is {self.csv['modeled_seq_len'].min()}-{self.csv['modeled_seq_len'].max()}") + self._log.info( + f"Training: {len(self.csv)} examples, len_range is {self.csv['modeled_seq_len'].min()}-{self.csv['modeled_seq_len'].max()}") + else: + 
self.csv = pdb_csv[~pdb_csv['is_trainset']] + # if self.split_frac < 1.0: + # train_csv = pdb_csv.sample(frac=self.split_frac, random_state=self.random_seed) + # pdb_csv = pdb_csv.drop(train_csv.index) + self.csv = pdb_csv[pdb_csv.modeled_seq_len <= self.dataset_cfg.max_eval_length] + self.csv.to_csv(os.path.join(os.path.dirname(self.dataset_cfg.csv_path),"valid.csv"), index=False) + + self.csv = self.csv.sample(n=min(self.dataset_cfg.max_valid_num, len(self.csv)), random_state=self.random_seed).reset_index() + + # self.chain_feats_total = [] + # for idx in range(len(self.csv)): + # processed_path = self.csv.iloc[idx]['processed_path'] + + # # chain_feats_temp = self._process_csv_row(processed_path) + # chain_feats_temp = du.read_pkl(processed_path) + # chain_feats_temp['energy'] = torch.tensor(self.csv.iloc[idx]['energy'], dtype=torch.float32) + + # self.chain_feats_total += [chain_feats_temp] + + # self._log.info( + # f"Valid: {len(self.chain_feats_total)} examples, len_range is {self.csv['modeled_seq_len'].min()}-{self.csv['modeled_seq_len'].max()}") + self._log.info( + f"Valid: {len(self.csv)} examples, len_range is {self.csv['modeled_seq_len'].min()}-{self.csv['modeled_seq_len'].max()}") + + # def _process_csv_row(self, processed_file_path): + # self.count += 1 + # if self.count%200==0: + # self._log.info( + # f"pre_count= {self.count}") + + # output_total = du.read_pkl(processed_file_path) + # energy_csv = self.energy_csv + + # file = os.path.basename(processed_file_path).replace(".pkl", ".pdb") + + # matching_rows = energy_csv[energy_csv['traj_filename'] == file] + # if not matching_rows.empty: + # output_total['energy'] = torch.tensor(matching_rows['energy'].values[0], dtype=torch.float32) + + # return output_total + + def __len__(self): + return len(self.csv) + + def __getitem__(self, idx): + # chain_feats = self.chain_feats_total[idx] + + processed_path = self.csv.iloc[idx]['processed_path'] + chain_feats = du.read_pkl(processed_path) + 
chain_feats['energy'] = torch.tensor(self.csv.iloc[idx]['energy'], dtype=torch.float32) + + energy = chain_feats['energy'] + + + if self.is_training and self._dataset_cfg.use_split: + # split_len = self._dataset_cfg.split_len + + split_len = random.randint(self.dataset_cfg.min_num_res, min(self._dataset_cfg.split_len, chain_feats['aatype'].shape[0])) + + idx = random.randint(0,chain_feats['aatype'].shape[0]-split_len) + output_total = copy.deepcopy(chain_feats) + + output_total['energy'] = torch.ones(chain_feats['aatype'].shape) + + output_temp = tree.map_structure(lambda x: x[idx:idx+split_len], output_total) + + bb_center = np.sum(output_temp['bb_positions'], axis=0) / (np.sum(output_temp['res_mask'].numpy()) + 1e-5) # (3,) + output_temp['trans_1']=(output_temp['trans_1'] - torch.from_numpy(bb_center[None, :])).float() + output_temp['bb_positions']=output_temp['bb_positions']- bb_center[None, :] + output_temp['all_atom_positions']=output_temp['all_atom_positions'] - torch.from_numpy(bb_center[None, None, :]) + output_temp['pair_repr_pre'] = output_temp['pair_repr_pre'][:,idx:idx+split_len] + + bb_center_esmfold = torch.sum(output_temp['trans_esmfold'], dim=0) / (np.sum(output_temp['res_mask'].numpy()) + 1e-5) # (3,) + output_temp['trans_esmfold']=(output_temp['trans_esmfold'] - bb_center_esmfold[None, :]).float() + + chain_feats = output_temp + chain_feats['energy'] = energy + + + if self._dataset_cfg.use_rotate_enhance: + rot_vet = [random.random() for _ in range(3)] + rot_mat = torch.tensor(scipy_R.from_rotvec(rot_vet).as_matrix()) # (3,3) + chain_feats['all_atom_positions']=torch.einsum('lij,kj->lik',chain_feats['all_atom_positions'], + rot_mat.type(chain_feats['all_atom_positions'].dtype)) + + all_atom_mask = np.array([restype_atom37_mask[i] for i in chain_feats['aatype']]) + + chain_feats_temp = { + 'aatype': chain_feats['aatype'], + 'all_atom_positions': chain_feats['all_atom_positions'], + 'all_atom_mask': torch.tensor(all_atom_mask).double(), + } + 
def process_file(file_path: str, write_dir: str):
    """Parse one PDB file, pickle its concatenated chain features, return metadata.

    Args:
        file_path: Path of the PDB file to read.
        write_dir: Directory that receives `<pdb_name>.pkl`.

    Returns:
        Dict with the keys `pdb_name`, `processed_path`, `raw_path`,
        `num_chains`, `quaternary_category`, `seq_len` and `modeled_seq_len`.

    Raises:
        errors.LengthError: If no residue has an aatype other than 20.
        Any other error is unexpected and is propagated.
    """
    pdb_name = os.path.basename(file_path).replace('.pdb', '')
    processed_path = os.path.join(write_dir, f'{pdb_name}.pkl')
    metadata = {
        'pdb_name': pdb_name,
        'processed_path': os.path.abspath(processed_path),
        'raw_path': file_path,
    }

    structure = PDB.PDBParser(QUIET=True).get_structure(pdb_name, file_path)

    # Chains are keyed by upper-cased id, so ids differing only in case
    # collapse into a single entry.
    chains_by_id = {}
    for chain in structure.get_chains():
        chains_by_id[chain.id.upper()] = chain
    metadata['num_chains'] = len(chains_by_id)

    # Per-chain features; the distinct aatype sequences decide homo/heteromer.
    per_chain_feats = []
    distinct_seqs = set()
    for chain_str, chain in chains_by_id.items():
        chain_prot = parsers.process_chain(chain, du.chain_str_to_int(chain_str))
        feats = du.parse_chain_feats(dataclasses.asdict(chain_prot))
        distinct_seqs.add(tuple(feats['aatype']))
        per_chain_feats.append(feats)
    metadata['quaternary_category'] = (
        'homomer' if len(distinct_seqs) == 1 else 'heteromer')
    complex_feats = du.concat_np_features(per_chain_feats, False)

    # Modeled residues are those whose aatype is not 20; the modeled length
    # is the span from the first to the last of them.
    complex_aatype = complex_feats['aatype']
    metadata['seq_len'] = len(complex_aatype)
    modeled_idx = np.where(complex_aatype != 20)[0]
    if len(modeled_idx) == 0:
        raise errors.LengthError('No modeled residues')
    first_modeled = np.min(modeled_idx)
    last_modeled = np.max(modeled_idx)
    metadata['modeled_seq_len'] = last_modeled - first_modeled + 1
    complex_feats['modeled_idx'] = modeled_idx

    # Write features to pickles.
    du.write_pkl(processed_path, complex_feats)

    return metadata
pool.map(_process_fn, all_file_paths) # all jobs to be done are saved in a iterable object (such as list) + all_metadata = [x for x in all_metadata if x is not None] + metadata_df = pd.DataFrame(all_metadata) + metadata_df.to_csv(metadata_path, index=False) + succeeded = len(all_metadata) + print( + f'Finished processing {succeeded}/{total_num_paths} files') + + +def cal_repr(processed_file_path, model_esm2, alphabet, batch_converter, esm_device): + print(f'cal_repr for {processed_file_path}') + processed_feats_org = du.read_pkl(processed_file_path) + processed_feats = du.parse_chain_feats(processed_feats_org) + + # Only take modeled residues. + modeled_idx = processed_feats['modeled_idx'] + min_idx = np.min(modeled_idx) + max_idx = np.max(modeled_idx) + # del processed_feats['modeled_idx'] + processed_feats = tree.map_structure( + lambda x: x[min_idx:(max_idx+1)], processed_feats) + + # Run through OpenFold data transforms. + chain_feats = { + 'aatype': torch.tensor(processed_feats['aatype']).long(), + 'all_atom_positions': torch.tensor(processed_feats['atom_positions']).double(), + 'all_atom_mask': torch.tensor(processed_feats['atom_mask']).double() + } + chain_feats = data_transforms.atom37_to_frames(chain_feats) + rigids_1 = rigid_utils.Rigid.from_tensor_4x4(chain_feats['rigidgroups_gt_frames'])[:, 0] + rotmats_1 = rigids_1.get_rots().get_rot_mats() + trans_1 = rigids_1.get_trans() + # res_idx = processed_feats['residue_index'] + + node_repr_pre, pair_repr_pre = get_pre_repr(chain_feats['aatype'], model_esm2, alphabet, batch_converter, device = esm_device) # (B,L,d_node_pre=1280), (B,L,L,d_edge_pre=20) + node_repr_pre = node_repr_pre[0].cpu() + pair_repr_pre = pair_repr_pre[0].cpu() + + out = { + 'aatype': chain_feats['aatype'], + 'rotmats_1': rotmats_1, + 'trans_1': trans_1, # (L,3) + 'res_mask': torch.tensor(processed_feats['bb_mask']).int(), + 'bb_positions': processed_feats['bb_positions'], + 'all_atom_positions':chain_feats['all_atom_positions'], + 
def merge_pdb(metadata_path, traj_info_file, valid_seq_file, merged_output_file):
    """Join metadata with per-trajectory energies and flag the training rows.

    Args:
        metadata_path: CSV with a `raw_path` column.
        traj_info_file: CSV with `traj_filename` and `energy` columns.
        valid_seq_file: CSV with a `file` column of validation identifiers.
        merged_output_file: Destination CSV path.

    The output holds every metadata row plus `traj_filename` (basename of
    `raw_path`), `energy` (left-joined on `traj_filename`) and `is_trainset`,
    which is False when the first six characters of `traj_filename` appear
    in the validation `file` column.
    """
    metadata_df = pd.read_csv(metadata_path)
    traj_info_df = pd.read_csv(traj_info_file)
    valid_seq_df = pd.read_csv(valid_seq_file)

    # Derive the file name from the raw path.
    metadata_df['traj_filename'] = list(map(os.path.basename, metadata_df['raw_path']))

    # Attach the energy of each trajectory file; rows without a match keep NaN.
    energy_by_traj = traj_info_df[['traj_filename', 'energy']]
    merged = metadata_df.merge(energy_by_traj, on='traj_filename', how='left')

    # Rows whose six-character file-name prefix is listed for validation are
    # excluded from the training set.
    name_prefix = merged['traj_filename'].str[:6]
    merged['is_trainset'] = ~name_prefix.isin(valid_seq_df['file'])

    # Save the result.
    merged.to_csv(merged_output_file, index=False)
    print('merge complete!')
default='./dataset/ATLAS/select/traj_info_select.csv') + parser.add_argument("--valid_seq_file", type=str, default='./inference/valid_seq.csv') + parser.add_argument("--merged_output_file", type=str, default='./dataset/ATLAS/select/pkl/metadata_merged.csv') + + args = parser.parse_args() + + # process .pdb to .pkl + main(args) + + # cal_repr + csv_path = os.path.join(args.write_dir, args.csv_name) + pdb_csv = pd.read_csv(csv_path) + pdb_csv = pdb_csv.sort_values('modeled_seq_len', ascending=False) + model_esm2, alphabet = esm.pretrained.esm2_t33_650M_UR50D() + batch_converter = alphabet.get_batch_converter() + model_esm2.eval() + model_esm2.requires_grad_(False) + model_esm2.to(args.esm_device) + for idx in range(len(pdb_csv)): + cal_repr(pdb_csv.iloc[idx]['processed_path'], model_esm2, alphabet, batch_converter, args.esm_device) + + # cal_static_structure + csv_path = os.path.join(args.write_dir, args.csv_name) + pdb_csv = pd.read_csv(csv_path) + ESMFold = ESMFold_Pred(device = args.esm_device) + for idx in range(len(pdb_csv)): + cal_static_structure(pdb_csv.iloc[idx]['processed_path'], pdb_csv.iloc[idx]['raw_path'], ESMFold) + + # merge csv + csv_path = os.path.join(args.write_dir, args.csv_name) + merge_pdb(csv_path, args.traj_info_file, args.valid_seq_file, args.merged_output_file) + + + + + diff --git a/data/protein.py b/data/protein.py new file mode 100644 index 0000000000000000000000000000000000000000..5bcb3b1411ef907e8f24a9e4b62e718fd25ca8aa --- /dev/null +++ b/data/protein.py @@ -0,0 +1,280 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
@dataclasses.dataclass(frozen=True)
class Protein:
    """Immutable bundle of per-residue arrays describing a protein structure."""

    # [num_res, num_atom_type, 3] Cartesian atom coordinates in angstroms.
    # Atom types follow residue_constants.atom_types, i.e. the first three
    # are N, CA, CB.
    atom_positions: np.ndarray

    # [num_res] Amino-acid type of each residue as an integer between 0 and
    # 20, where 20 is 'X'.
    aatype: np.ndarray

    # [num_res, num_atom_type] Binary float mask: 1.0 if the atom is present,
    # 0.0 if not. This should be used for loss masking.
    atom_mask: np.ndarray

    # [num_res] Residue index as used in PDB; not necessarily continuous or
    # 0-indexed.
    residue_index: np.ndarray

    # [num_res] 0-indexed number of the chain each residue belongs to.
    chain_index: np.ndarray

    # [num_res, num_atom_type] B-factors (temperature factors) in square
    # angstroms, representing the displacement of the residue from its
    # ground truth mean value.
    b_factors: np.ndarray

    def __post_init__(self):
        """Reject instances holding more distinct chains than PDB_MAX_CHAINS."""
        num_chains = np.unique(self.chain_index).size
        if num_chains <= PDB_MAX_CHAINS:
            return
        raise ValueError(
            f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
            'because these cannot be written to PDB format.')
+ res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor + if np.sum(mask) < 0.5: + # If no known atom positions are reported for the residue then skip it. + continue + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + chain_ids.append(chain.id) + b_factors.append(res_b_factors) + + # Chain IDs are usually characters so map these to ints. + unique_chain_ids = np.unique(chain_ids) + chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)} + chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids]) + + return Protein( + atom_positions=np.array(atom_positions), + atom_mask=np.array(atom_mask), + aatype=np.array(aatype), + residue_index=np.array(residue_index), + chain_index=chain_index, + b_factors=np.array(b_factors)) + + +def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str: + chain_end = 'TER' + return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} ' + f'{chain_name:>1}{residue_index:>4}') + + +def to_pdb(prot: Protein, model=1, add_end=True) -> str: + """Converts a `Protein` instance to a PDB string. + + Args: + prot: The protein to convert to PDB. + + Returns: + PDB string. + """ + restypes = residue_constants.restypes + ['X'] + res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') + atom_types = residue_constants.atom_types + + pdb_lines = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(int) + chain_index = prot.chain_index.astype(int) + b_factors = prot.b_factors + + if np.any(aatype > residue_constants.restype_num): + raise ValueError('Invalid aatypes.') + + # Construct a mapping from chain integer indices to chain ID strings. + chain_ids = {} + for i in np.unique(chain_index): # np.unique gives sorted output. 
+ if i >= PDB_MAX_CHAINS: + raise ValueError( + f'The PDB format supports at most {PDB_MAX_CHAINS} chains.') + chain_ids[i] = PDB_CHAIN_IDS[i] + + pdb_lines.append(f'MODEL {model}') + atom_index = 1 + last_chain_index = chain_index[0] + # Add all atom sites. + for i in range(aatype.shape[0]): + # Close the previous chain if in a multichain PDB. + if last_chain_index != chain_index[i]: + pdb_lines.append(_chain_end( + atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]], + residue_index[i - 1])) + last_chain_index = chain_index[i] + atom_index += 1 # Atom index increases at the TER symbol. + + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip( + atom_types, atom_positions[i], atom_mask[i], b_factors[i]): + if mask < 0.5: + continue + + record_type = 'ATOM' + name = atom_name if len(atom_name) == 4 else f' {atom_name}' + alt_loc = '' + insertion_code = '' + occupancy = 1.00 + element = atom_name[0] # Protein supports only C, N, O, S, this works. + charge = '' + # PDB is a columnar format, every space matters here! + atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' + f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}' + f'{residue_index[i]:>4}{insertion_code:>1} ' + f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' + f'{occupancy:>6.2f}{b_factor:>6.2f} ' + f'{element:>2}{charge:>2}') + pdb_lines.append(atom_line) + atom_index += 1 + + # Close the final chain. + pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]), + chain_ids[chain_index[-1]], residue_index[-1])) + pdb_lines.append('ENDMDL') + if add_end: + pdb_lines.append('END') + + # Pad all lines to 80 characters. + pdb_lines = [line.ljust(80) for line in pdb_lines] + return '\n'.join(pdb_lines) + '\n' # Add terminating newline. + + +def ideal_atom_mask(prot: Protein) -> np.ndarray: + """Computes an ideal atom mask. + + `Protein.atom_mask` typically is defined according to the atoms that are + reported in the PDB. 
def get_pre_repr(seqs, model, alphabet, batch_converter, device="cuda:0"):
    """Run one sequence through the ESM model and return node and pair features.

    Args:
        seqs: Iterable of integer residue codes for a single sequence; each
            code is translated through `order2restype_with_mask`.
        model: Callable invoked as
            `model(tokens, repr_layers=[33], return_contacts=True)` whose
            result provides `"representations"` and `"attentions"`.
        alphabet: Not used by this function; kept so existing callers keep
            working.
        batch_converter: Callable turning `[(label, sequence_string)]` into
            `(labels, strings, tokens)`.
        device: Device the token tensor is moved to before the forward pass.

    Returns:
        A `(node_repr, pair_repr)` tuple. `node_repr` is the layer-33
        representation without the first and last token positions.
        `pair_repr` is the attention of layer index 32 with the same
        positions removed on both sequence axes and the axes permuted to
        `(0, 2, 3, 1)`.
    """
    seq_string = ''.join(order2restype_with_mask[int(i)] for i in seqs)
    _, _, batch_tokens = batch_converter([("protein_0", seq_string)])

    # Forward pass without gradients; the tokens (and hence the outputs) live
    # on `device`.
    with torch.no_grad():
        results = model(batch_tokens.to(device), repr_layers=[33], return_contacts=True)

    # NOTE: token 0 is always a beginning-of-sequence token, so the first
    # residue is token 1; the slice 1:-1 also drops the final token.
    node_repr = results["representations"][33][:, 1:-1, :]
    pair_repr = results['attentions'][:, 33 - 1, :, 1:-1, 1:-1].permute(0, 2, 3, 1)
    return node_repr, pair_repr
+ 'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']], + 'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'CYS': [['N', 'CA', 'CB', 'SG']], + 'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1']], + 'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1']], + 'GLY': [], + 'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']], + 'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']], + 'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']], + 'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'], + ['CB', 'CG', 'SD', 'CE']], + 'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']], + 'SER': [['N', 'CA', 'CB', 'OG']], + 'THR': [['N', 'CA', 'CB', 'OG1']], + 'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'VAL': [['N', 'CA', 'CB', 'CG1']], +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). 
# One row per residue type (20 rows, ALA..VAL; there is no UNK row here) and
# one column per chi angle; 1.0 keeps the angle, 0.0 masks it out.
chi_angles_mask = [
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [1.0, 1.0, 1.0, 1.0],  # ARG
    [1.0, 1.0, 0.0, 0.0],  # ASN
    [1.0, 1.0, 0.0, 0.0],  # ASP
    [1.0, 0.0, 0.0, 0.0],  # CYS
    [1.0, 1.0, 1.0, 0.0],  # GLN
    [1.0, 1.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [1.0, 1.0, 0.0, 0.0],  # HIS
    [1.0, 1.0, 0.0, 0.0],  # ILE
    [1.0, 1.0, 0.0, 0.0],  # LEU
    [1.0, 1.0, 1.0, 1.0],  # LYS
    [1.0, 1.0, 1.0, 0.0],  # MET
    [1.0, 1.0, 0.0, 0.0],  # PHE
    [1.0, 1.0, 0.0, 0.0],  # PRO
    [1.0, 0.0, 0.0, 0.0],  # SER
    [1.0, 0.0, 0.0, 0.0],  # THR
    [1.0, 1.0, 0.0, 0.0],  # TRP
    [1.0, 1.0, 0.0, 0.0],  # TYR
    [1.0, 0.0, 0.0, 0.0],  # VAL
]

# The following chi angles are pi periodic: they can be rotated by a multiple
# of pi without affecting the structure.
# Same row/column layout as chi_angles_mask, plus a trailing UNK row
# (21 rows in total).
chi_pi_periodic = [
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [0.0, 0.0, 0.0, 0.0],  # ARG
    [0.0, 0.0, 0.0, 0.0],  # ASN
    [0.0, 1.0, 0.0, 0.0],  # ASP
    [0.0, 0.0, 0.0, 0.0],  # CYS
    [0.0, 0.0, 0.0, 0.0],  # GLN
    [0.0, 0.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [0.0, 0.0, 0.0, 0.0],  # HIS
    [0.0, 0.0, 0.0, 0.0],  # ILE
    [0.0, 0.0, 0.0, 0.0],  # LEU
    [0.0, 0.0, 0.0, 0.0],  # LYS
    [0.0, 0.0, 0.0, 0.0],  # MET
    [0.0, 1.0, 0.0, 0.0],  # PHE
    [0.0, 0.0, 0.0, 0.0],  # PRO
    [0.0, 0.0, 0.0, 0.0],  # SER
    [0.0, 0.0, 0.0, 0.0],  # THR
    [0.0, 0.0, 0.0, 0.0],  # TRP
    [0.0, 1.0, 0.0, 0.0],  # TYR
    [0.0, 0.0, 0.0, 0.0],  # VAL
    [0.0, 0.0, 0.0, 0.0],  # UNK
]

# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
# psi and chi angles:
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
# The atom positions are relative to the axis-end-atom of the corresponding
# rotation axis.
The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). +# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions = { + 'ALA': [ + ['N', 0, (-0.525, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.529, -0.774, -1.205)], + ['O', 3, (0.627, 1.062, 0.000)], + ], + 'ARG': [ + ['N', 0, (-0.524, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.524, -0.778, -1.209)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.616, 1.390, -0.000)], + ['CD', 5, (0.564, 1.414, 0.000)], + ['NE', 6, (0.539, 1.357, -0.000)], + ['NH1', 7, (0.206, 2.301, 0.000)], + ['NH2', 7, (2.078, 0.978, -0.000)], + ['CZ', 7, (0.758, 1.093, -0.000)], + ], + 'ASN': [ + ['N', 0, (-0.536, 1.357, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.531, -0.787, -1.200)], + ['O', 3, (0.625, 1.062, 0.000)], + ['CG', 4, (0.584, 1.399, 0.000)], + ['ND2', 5, (0.593, -1.188, 0.001)], + ['OD1', 5, (0.633, 1.059, 0.000)], + ], + 'ASP': [ + ['N', 0, (-0.525, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, 0.000, -0.000)], + ['CB', 0, (-0.526, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.593, 1.398, -0.000)], + ['OD1', 5, (0.610, 1.091, 0.000)], + ['OD2', 5, (0.592, -1.101, -0.003)], + ], + 'CYS': [ + ['N', 0, (-0.522, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, 0.000)], + ['CB', 0, (-0.519, -0.773, -1.212)], + ['O', 3, (0.625, 1.062, -0.000)], + ['SG', 4, (0.728, 1.653, 0.000)], + ], + 'GLN': [ + ['N', 0, (-0.526, 1.361, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.779, -1.207)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.615, 1.393, 
0.000)], + ['CD', 5, (0.587, 1.399, -0.000)], + ['NE2', 6, (0.593, -1.189, -0.001)], + ['OE1', 6, (0.634, 1.060, 0.000)], + ], + 'GLU': [ + ['N', 0, (-0.528, 1.361, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.526, -0.781, -1.207)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.615, 1.392, 0.000)], + ['CD', 5, (0.600, 1.397, 0.000)], + ['OE1', 6, (0.607, 1.095, -0.000)], + ['OE2', 6, (0.589, -1.104, -0.001)], + ], + 'GLY': [ + ['N', 0, (-0.572, 1.337, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.517, -0.000, -0.000)], + ['O', 3, (0.626, 1.062, -0.000)], + ], + 'HIS': [ + ['N', 0, (-0.527, 1.360, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.778, -1.208)], + ['O', 3, (0.625, 1.063, 0.000)], + ['CG', 4, (0.600, 1.370, -0.000)], + ['CD2', 5, (0.889, -1.021, 0.003)], + ['ND1', 5, (0.744, 1.160, -0.000)], + ['CE1', 5, (2.030, 0.851, 0.002)], + ['NE2', 5, (2.145, -0.466, 0.004)], + ], + 'ILE': [ + ['N', 0, (-0.493, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.536, -0.793, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.534, 1.437, -0.000)], + ['CG2', 4, (0.540, -0.785, -1.199)], + ['CD1', 5, (0.619, 1.391, 0.000)], + ], + 'LEU': [ + ['N', 0, (-0.520, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.773, -1.214)], + ['O', 3, (0.625, 1.063, -0.000)], + ['CG', 4, (0.678, 1.371, 0.000)], + ['CD1', 5, (0.530, 1.430, -0.000)], + ['CD2', 5, (0.535, -0.774, 1.200)], + ], + 'LYS': [ + ['N', 0, (-0.526, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.524, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.619, 1.390, 0.000)], + ['CD', 5, (0.559, 1.417, 0.000)], + ['CE', 6, (0.560, 1.416, 0.000)], + ['NZ', 7, (0.554, 1.387, 0.000)], + ], + 'MET': 
[ + ['N', 0, (-0.521, 1.364, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.210)], + ['O', 3, (0.625, 1.062, -0.000)], + ['CG', 4, (0.613, 1.391, -0.000)], + ['SD', 5, (0.703, 1.695, 0.000)], + ['CE', 6, (0.320, 1.786, -0.000)], + ], + 'PHE': [ + ['N', 0, (-0.518, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, -0.000)], + ['CB', 0, (-0.525, -0.776, -1.212)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.607, 1.377, 0.000)], + ['CD1', 5, (0.709, 1.195, -0.000)], + ['CD2', 5, (0.706, -1.196, 0.000)], + ['CE1', 5, (2.102, 1.198, -0.000)], + ['CE2', 5, (2.098, -1.201, -0.000)], + ['CZ', 5, (2.794, -0.003, -0.001)], + ], + 'PRO': [ + ['N', 0, (-0.566, 1.351, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, 0.000)], + ['CB', 0, (-0.546, -0.611, -1.293)], + ['O', 3, (0.621, 1.066, 0.000)], + ['CG', 4, (0.382, 1.445, 0.0)], + # ['CD', 5, (0.427, 1.440, 0.0)], + ['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + 'SER': [ + ['N', 0, (-0.529, 1.360, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.518, -0.777, -1.211)], + ['O', 3, (0.626, 1.062, -0.000)], + ['OG', 4, (0.503, 1.325, 0.000)], + ], + 'THR': [ + ['N', 0, (-0.517, 1.364, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, -0.000)], + ['CB', 0, (-0.516, -0.793, -1.215)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG2', 4, (0.550, -0.718, -1.228)], + ['OG1', 4, (0.472, 1.353, 0.000)], + ], + 'TRP': [ + ['N', 0, (-0.521, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.212)], + ['O', 3, (0.627, 1.062, 0.000)], + ['CG', 4, (0.609, 1.370, -0.000)], + ['CD1', 5, (0.824, 1.091, 0.000)], + ['CD2', 5, (0.854, -1.148, -0.005)], + ['CE2', 5, (2.186, -0.678, -0.007)], + ['CE3', 5, (0.622, -2.530, -0.007)], + ['NE1', 5, (2.140, 0.690, 
-0.004)], + ['CH2', 5, (3.028, -2.890, -0.013)], + ['CZ2', 5, (3.283, -1.543, -0.011)], + ['CZ3', 5, (1.715, -3.389, -0.011)], + ], + 'TYR': [ + ['N', 0, (-0.522, 1.362, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.776, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG', 4, (0.607, 1.382, -0.000)], + ['CD1', 5, (0.716, 1.195, -0.000)], + ['CD2', 5, (0.713, -1.194, -0.001)], + ['CE1', 5, (2.107, 1.200, -0.002)], + ['CE2', 5, (2.104, -1.201, -0.003)], + ['OH', 5, (4.168, -0.002, -0.005)], + ['CZ', 5, (2.791, -0.001, -0.003)], + ], + 'VAL': [ + ['N', 0, (-0.494, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.533, -0.795, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.540, 1.429, -0.000)], + ['CG2', 4, (0.533, -0.776, 1.203)], + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms = { + 'ALA': ['C', 'CA', 'CB', 'N', 'O'], + 'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'], + 'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'], + 'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'], + 'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'], + 'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'], + 'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'], + 'GLY': ['C', 'CA', 'N', 'O'], + 'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'], + 'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'], + 'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'], + 'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'], + 'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'], + 'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'], + 'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'], + 'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'], + 'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'], + 'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 
@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]],
                                          Mapping[str, List[Bond]],
                                          Mapping[str, List[BondAngle]]]:
    """Parse stereo_chemical_props.txt into per-residue bond and angle tables.

    The text file is looked up next to this module. It holds a bond-length
    table followed by a bond-angle table; each table has a header row and is
    terminated by a line containing only '-'. Every bond angle is additionally
    converted into the distance between its two outer atoms (a "virtual bond")
    using the law of cosines. The result is cached, so the file is read once
    per process.

    Returns:
      residue_bonds: Dict that maps resname -> list of Bond tuples.
      residue_virtual_bonds: Dict that maps resname -> list of Bond tuples.
      residue_bond_angles: Dict that maps resname -> list of BondAngle tuples.
      All three dicts contain an empty list under the key 'UNK'.
    """
    props_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'stereo_chemical_props.txt')
    with open(props_path, 'rt') as handle:
        rows = iter(handle.read().splitlines())

    def bond_key(name_a, name_b):
        """Order-independent lookup key for a pair of atom names."""
        return '-'.join(sorted([name_a, name_b]))

    # --- Bond lengths -------------------------------------------------------
    residue_bonds = {}
    next(rows)  # Header row of the bond table.
    for row in rows:
        if row.strip() == '-':  # A lone dash closes the table.
            break
        atom_pair, resname, mean, sigma = row.split()
        first, second = atom_pair.split('-')
        residue_bonds.setdefault(resname, []).append(
            Bond(first, second, float(mean), float(sigma)))
    residue_bonds['UNK'] = []

    # --- Bond angles (degrees in the file, radians in the result) -----------
    residue_bond_angles = {}
    next(rows)  # Blank separator line.
    next(rows)  # Header row of the angle table.
    for row in rows:
        if row.strip() == '-':
            break
        atom_triple, resname, mean_deg, sigma_deg = row.split()
        first, middle, last = atom_triple.split('-')
        residue_bond_angles.setdefault(resname, []).append(
            BondAngle(first, middle, last,
                      float(mean_deg) / 180. * np.pi,
                      float(sigma_deg) / 180. * np.pi))
    residue_bond_angles['UNK'] = []

    # --- Virtual bonds: distance between the outer atoms of each angle ------
    residue_virtual_bonds = {}
    for resname, angles in residue_bond_angles.items():
        # Index this residue's bonds by their (unordered) atom pair.
        known_bonds = {bond_key(b.atom1_name, b.atom2_name): b
                       for b in residue_bonds[resname]}
        virtual = []
        residue_virtual_bonds[resname] = virtual
        for angle in angles:
            leg_a = known_bonds[bond_key(angle.atom1_name, angle.atom2_name)]
            leg_b = known_bonds[bond_key(angle.atom2_name, angle.atom3name)]

            # Law of cosines: c^2 = a^2 + b^2 - 2ab*cos(gamma).
            gamma = angle.angle_rad
            length = np.sqrt(leg_a.length**2 + leg_b.length**2
                             - 2 * leg_a.length * leg_b.length * np.cos(gamma))

            # First-order error propagation, treating the three input
            # uncertainties as uncorrelated.
            dl_outer = 0.5 / length
            dl_dgamma = (2 * leg_a.length * leg_b.length * np.sin(gamma)) * dl_outer
            dl_db1 = (2 * leg_a.length - 2 * leg_b.length * np.cos(gamma)) * dl_outer
            dl_db2 = (2 * leg_b.length - 2 * leg_a.length * np.cos(gamma)) * dl_outer
            stddev = np.sqrt((dl_dgamma * angle.stddev)**2 +
                             (dl_db1 * leg_a.stddev)**2 +
                             (dl_db2 * leg_b.stddev)**2)
            virtual.append(
                Bond(angle.atom1_name, angle.atom3name, length, stddev))

    return (residue_bonds,
            residue_virtual_bonds,
            residue_bond_angles)
+ +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names = { + 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''], + 'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''], + 'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''], + 'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''], + 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''], + 'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''], + 'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''], + 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''], + 'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''], + 'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''], + 'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''], + 'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''], + 'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''], + 'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''], + 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''], + 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''], + 'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''], + 'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'], + 'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''], + 'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''], + 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''], + +} +# pylint: enable=line-too-long +# pylint: enable=bad-whitespace + + +# This is the standard residue order when coding AA type 
def sequence_to_onehot(
    sequence: str,
    mapping: Mapping[str, int],
    map_unknown_to_x: bool = False) -> np.ndarray:
    """One-hot encode an amino acid sequence.

    Args:
      sequence: An amino acid sequence.
      mapping: A dictionary mapping amino acids to integers. Its values must
        cover 0..max without gaps.
      map_unknown_to_x: If True, every upper-case letter missing from the
        mapping is encoded as 'X' (so the mapping must contain 'X', otherwise
        the lookup fails), and any character that is not an upper-case letter
        is rejected. If False, a character missing from the mapping raises the
        mapping's own lookup error.

    Returns:
      A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
      the sequence.

    Raises:
      ValueError: If the mapping doesn't contain values from 0 to
        num_unique_aas - 1 without any gaps, or if map_unknown_to_x is True and
        the sequence contains a character that is not an upper-case letter.
    """
    width = max(mapping.values()) + 1
    distinct_ids = sorted(set(mapping.values()))
    if distinct_ids != list(range(width)):
        raise ValueError('The mapping must have values from 0 to num_unique_aas-1 '
                         'without any gaps. Got: %s' % sorted(mapping.values()))

    def lookup(symbol):
        """Return the column index for one sequence character."""
        if not map_unknown_to_x:
            return mapping[symbol]
        if not (symbol.isalpha() and symbol.isupper()):
            raise ValueError(f'Invalid character in the sequence: {symbol}')
        # mapping['X'] is evaluated eagerly, so 'X' must be present even when
        # the symbol itself is known.
        return mapping.get(symbol, mapping['X'])

    encoded = np.zeros((len(sequence), width), dtype=int)
    for row, symbol in enumerate(sequence):
        encoded[row, lookup(symbol)] = 1

    return encoded
def _make_standard_atom_mask() -> np.ndarray:
    """Returns [num_res_types, num_atom_types] mask array.

    Entry [r, a] is 1 when residue type r (indexed as in `restypes`) contains
    atom type a (indexed as in `atom_order`), according to `residue_atoms`.
    One extra trailing row is allocated for the unknown residue type and is
    left all zeros.
    """
    mask = np.zeros([restype_num + 1, atom_type_num], dtype=int)
    for row, letter in enumerate(restypes):
        # Column indices of every heavy atom this residue type owns.
        columns = [atom_order[name] for name in residue_atoms[restype_1to3[letter]]]
        mask[row, columns] = 1
    return mask
def chi_angle_atom(atom_index: int) -> np.ndarray:
    """Define chi-angle rigid groups via one-hot representations.

    For every residue type, take atom number `atom_index` of each of its chi
    angle definitions in `chi_angles_atoms` and one-hot encode it over
    `atom_types`.

    Args:
      atom_index: Position (0-3) inside each 4-atom chi angle definition.

    Returns:
      Array of shape (len(restypes) + 1, atom_type_num, 4); the last entry
      along the first axis (residue `X`) is all zeros.
    """
    identity = np.eye(atom_type_num)

    padded_indices = {}
    for resname, chi_defs in chi_angles_atoms.items():
        picked = [atom_types.index(chi_def[atom_index]) for chi_def in chi_defs]
        # Pad to four chi slots per residue.
        # NOTE(review): -1 selects the LAST row of the identity (the final
        # entry of atom_types), not an all-zero row -- confirm this is the
        # intended encoding for missing chi angles.
        padded_indices[resname] = picked + [-1] * (4 - len(picked))

    per_residue = [identity[padded_indices[restype_1to3[letter]]]
                   for letter in restypes]
    per_residue.append(np.zeros([4, atom_type_num]))  # Add zeros for residue `X`.

    # (residue, chi, atom_type) -> (residue, atom_type, chi)
    return np.transpose(np.stack(per_residue, axis=0), [0, 2, 1])


def _make_rigid_transformation_4x4(ex, ey, translation):
    """Create a rigid 4x4 transformation matrix from two axes and transl.

    The rotation columns are the normalized `ex`, the component of `ey`
    orthogonal to it (normalized), and their cross product; `translation`
    forms the fourth column and the bottom row is (0, 0, 0, 1).
    """
    x_axis = ex / np.linalg.norm(ex)

    # Gram-Schmidt step: strip the part of ey that lies along x_axis.
    y_axis = ey - np.dot(ey, x_axis) * x_axis
    y_axis = y_axis / np.linalg.norm(y_axis)

    z_axis = np.cross(x_axis, y_axis)

    upper = np.stack([x_axis, y_axis, z_axis, translation]).transpose()
    return np.concatenate([upper, [[0., 0., 0., 1.]]], axis=0)


def _make_rigid_group_constants():
    """Fill the module-level restype_atom37_*/restype_atom14_* arrays and
    restype_rigid_group_default_frame from rigid_group_atom_positions.

    For each of the standard residue types this records, per atom, the rigid
    group index, a presence mask and the position inside its group (in both
    the 37-atom and the 14-atom layout), and then the 4x4 default frame of
    each rigid group relative to the previous one. Entries that are not
    assigned here keep whatever value the arrays already hold.
    """
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        group_entries = rigid_group_atom_positions[resname]
        atom14_names = restype_name_to_atom14_names[resname]

        # Per-atom tables in the atom37 and atom14 layouts.
        for atomname, group_idx, atom_position in group_entries:
            col37 = atom_order[atomname]
            restype_atom37_to_rigid_group[restype, col37] = group_idx
            restype_atom37_mask[restype, col37] = 1
            restype_atom37_rigid_group_positions[restype, col37, :] = atom_position

            col14 = atom14_names.index(atomname)
            restype_atom14_to_rigid_group[restype, col14] = group_idx
            restype_atom14_mask[restype, col14] = 1
            restype_atom14_rigid_group_positions[restype, col14, :] = atom_position

        coords = {name: np.array(pos) for name, _, pos in group_entries}
        frames = restype_rigid_group_default_frame[restype]  # View; writes go through.

        # backbone to backbone is the identity transform
        frames[0] = np.eye(4)

        # pre-omega-frame to backbone (currently dummy identity matrix)
        frames[1] = np.eye(4)

        # phi-frame to backbone
        frames[2] = _make_rigid_transformation_4x4(
            ex=coords['N'] - coords['CA'],
            ey=np.array([1., 0., 0.]),
            translation=coords['N'])

        # psi-frame to backbone
        frames[3] = _make_rigid_transformation_4x4(
            ex=coords['C'] - coords['CA'],
            ey=coords['CA'] - coords['N'],
            translation=coords['C'])

        residue_chi_mask = chi_angles_mask[restype]
        residue_chi_atoms = chi_angles_atoms[resname]

        # chi1-frame to backbone
        if residue_chi_mask[0]:
            first, second, third, _ = (coords[name] for name in residue_chi_atoms[0])
            frames[4] = _make_rigid_transformation_4x4(
                ex=third - second,
                ey=first - second,
                translation=third)

        # chi2-frame to chi1-frame, chi3-frame to chi2-frame and
        # chi4-frame to chi3-frame: every rotation axis for the next frame
        # starts at (0,0,0) of the previous frame, so the axis-end atom's
        # position is both the axis direction and the translation.
        for chi_idx in range(1, 4):
            if not residue_chi_mask[chi_idx]:
                continue
            axis_end = coords[residue_chi_atoms[chi_idx][2]]
            frames[4 + chi_idx] = _make_rigid_transformation_4x4(
                ex=axis_end,
                ey=np.array([-1., 0., 0.]),
                translation=axis_end)
bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_list = restype_name_to_atom14_names[resname] + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev + return {'lower_bound': 
def scale_rotmat(
    rotation_matrix: torch.Tensor, scalar: torch.Tensor, tol: float = 1e-7
) -> torch.Tensor:
    """
    Scale rotation matrix. This is done by converting it to vector representation,
    scaling the length of the vector and converting back to matrix representation.

    Args:
        rotation_matrix: Rotation matrices.
        scalar: Scalar values used for scaling. Should have one fewer dimension than the
            rotation matrices for correct broadcasting.
        tol: Numerical offset for stability, forwarded to `rotvec_to_rotmat`.

    Returns:
        Scaled rotation matrix.
    """
    # Check whether dimensions match.
    assert rotation_matrix.ndim - 1 == scalar.ndim, (
        "scalar must have exactly one dimension fewer than rotation_matrix"
    )
    scaled_rmat = rotvec_to_rotmat(rotmat_to_rotvec(rotation_matrix) * scalar, tol=tol)
    return scaled_rmat


def _broadcast_identity(target: torch.Tensor) -> torch.Tensor:
    """
    Generate a 3 by 3 identity matrix and broadcast it to a batch of target matrices.

    The identity is created on the device and with the dtype of `target`.

    Args:
        target (torch.Tensor): Batch of target 3 by 3 matrices.

    Returns:
        torch.Tensor: 3 by 3 identity matrices in the shapes of the target.
    """
    id3 = torch.eye(3, device=target.device, dtype=target.dtype)
    id3 = torch.broadcast_to(id3, target.shape)
    return id3


def skew_matrix_exponential_map_axis_angle(
    angles: torch.Tensor, skew_matrices: torch.Tensor
) -> torch.Tensor:
    r"""
    Compute the matrix exponential of a rotation in axis-angle representation with the axis in skew
    matrix representation form. Maps the rotation from the Lie algebra so(3) to the rotation
    matrix (SO(3)) representation. Uses Rodrigues' formula instead of `torch.linalg.matrix_exp`
    for better computational performance:

    .. math::

        \exp(\theta \mathbf{K}) = \mathbf{I} + \sin(\theta) \mathbf{K} + [1 - \cos(\theta)] \mathbf{K}^2

    Args:
        angles (torch.Tensor): Batch of rotation angles; its shape must broadcast against the
            leading (batch) dimensions of `skew_matrices`.
        skew_matrices (torch.Tensor): Batch of rotation axes in skew matrix (lie so(3)) basis.
            A single unbatched 3 by 3 matrix is accepted as well.

    Returns:
        torch.Tensor: Batch of corresponding rotation matrices.
    """
    # Set up identity matrix and broadcast.
    id3 = _broadcast_identity(skew_matrices)

    # Broadcast angle vector to right dimensions
    angles = angles[..., None, None]

    # matmul multiplies over the last two dimensions for any number of leading batch
    # dimensions (including none).
    skew_squared = torch.matmul(skew_matrices, skew_matrices)

    exp_skew = (
        id3
        + torch.sin(angles) * skew_matrices
        + (1.0 - torch.cos(angles)) * skew_squared
    )
    return exp_skew


def skew_matrix_exponential_map(
    angles: torch.Tensor, skew_matrices: torch.Tensor, tol: float = 1e-7
) -> torch.Tensor:
    r"""
    Compute the matrix exponential of a rotation vector in skew matrix representation. Maps the
    rotation from the Lie algebra so(3) to the rotation matrix (SO(3)) representation. Uses the
    following form of Rodrigues' formula instead of `torch.linalg.matrix_exp` for better
    computational performance (in this case the skew matrix already contains the angle factor):

    .. math::

        \exp(\mathbf{K}) = \mathbf{I} + \frac{\sin(\theta)}{\theta} \mathbf{K} + \frac{1-\cos(\theta)}{\theta^2} \mathbf{K}^2

    This form has the advantage, that Taylor expansions can be used for small angles (instead of
    having to compute the unit length axis by dividing the rotation vector by small angles):

    .. math::

        \frac{\sin(\theta)}{\theta} \approx 1 - \frac{\theta^2}{6}

        \frac{1-\cos(\theta)}{\theta^2} \approx \frac{1}{2} - \frac{\theta^2}{24}

    Args:
        angles (torch.Tensor): Batch of rotation angles.
        skew_matrices (torch.Tensor): Batch of rotation vectors in skew matrix (lie so(3)) basis,
            i.e. already scaled by the rotation angle. A single unbatched 3 by 3 matrix is
            accepted as well.
        tol (float): Angles with absolute value below this threshold use the Taylor expansions.

    Returns:
        torch.Tensor: Batch of corresponding rotation matrices.
    """
    # Set up identity matrix and broadcast.
    id3 = _broadcast_identity(skew_matrices)

    # Broadcast angles and pre-compute square.
    angles = angles[..., None, None]
    angles_sq = angles.square()

    mask_zero = torch.abs(angles) < tol

    # Replace near-zero angles by 1 before dividing. torch.where would already discard the
    # 0/0 result in the forward pass, but the division would still inject NaN into the
    # gradient of `angles`; dividing by a safe value keeps the backward pass finite.
    safe_angles = torch.where(mask_zero, torch.ones_like(angles), angles)

    # Standard terms, only used away from zero.
    sin_coeff = torch.sin(safe_angles) / safe_angles
    cos_coeff = (1.0 - torch.cos(safe_angles)) / safe_angles.square()
    # Use second order Taylor expansion for values close to zero.
    sin_coeff_small = 1.0 - angles_sq / 6.0
    cos_coeff_small = 0.5 - angles_sq / 24.0

    sin_coeff = torch.where(mask_zero, sin_coeff_small, sin_coeff)
    cos_coeff = torch.where(mask_zero, cos_coeff_small, cos_coeff)

    # Compute matrix exponential using Rodrigues' formula. matmul multiplies over the last two
    # dimensions for any number of leading batch dimensions (including none).
    exp_skew = (
        id3
        + sin_coeff * skew_matrices
        + cos_coeff * torch.matmul(skew_matrices, skew_matrices)
    )
    return exp_skew


def rotvec_to_rotmat(rotation_vectors: torch.Tensor, tol: float = 1e-7) -> torch.Tensor:
    """
    Convert rotation vectors to rotation matrix representation. The length of the rotation vector
    is the angle of rotation, the unit vector the rotation axis.

    Args:
        rotation_vectors (torch.Tensor): Batch of rotation vectors.
        tol: small offset for numerical stability, forwarded to `skew_matrix_exponential_map`.

    Returns:
        torch.Tensor: Rotation in rotation matrix representation.
    """
    # Compute rotation angle as vector norm.
    rotation_angles = torch.norm(rotation_vectors, dim=-1)

    # Map axis to skew matrix basis.
    skew_matrices = vector_to_skew_matrix(rotation_vectors)

    # Compute rotation matrices via matrix exponential.
    rotation_matrices = skew_matrix_exponential_map(rotation_angles, skew_matrices, tol=tol)

    return rotation_matrices
+ + Adapted from https://github.com/jasonkyuyim/se3_diffusion/blob/2cba9e09fdc58112126a0441493b42022c62bbea/data/so3_utils.py + which was adapted from https://github.com/geomstats/geomstats/blob/master/geomstats/geometry/special_orthogonal.py + with heavy help from https://cvg.cit.tum.de/_media/members/demmeln/nurlanov2021so3log.pdf + + Args: + rotation_matrices (torch.Tensor): Input batch of rotation matrices. + + Returns: + torch.Tensor: Batch of rotation vectors. + """ + # Get angles and sin/cos from rotation matrix. + angles, angles_sin, _ = angle_from_rotmat(rotation_matrices) + # Compute skew matrix representation and extract so(3) vector components. + vector = skew_matrix_to_vector(rotation_matrices - rotation_matrices.transpose(-2, -1)) + + # Three main cases for angle theta, which are captured + # 1) Angle is 0 or close to zero -> use Taylor series for small values / return 0 vector. + mask_zero = torch.isclose(angles, torch.zeros_like(angles)).to(angles.dtype) + # 2) Angle is close to pi -> use outer product relation. + mask_pi = torch.isclose(angles, torch.full_like(angles, np.pi), atol=1e-2).to(angles.dtype) + # 3) Angle is unproblematic -> use the standard formula. + mask_else = (1 - mask_zero) * (1 - mask_pi) + + # Compute case dependent pre-factor (1/2 for angle close to 0, angle otherwise). + numerator = mask_zero / 2.0 + angles * mask_else + # The Taylor expansion used here is actually the inverse of the Taylor expansion of the inverted + # fraction sin(x) / x which gives better accuracy over a wider range (hence the minus and + # position in denominator). + denominator = ( + (1.0 - angles**2 / 6.0) * mask_zero # Taylor expansion for small angles. + + 2.0 * angles_sin * mask_else # Standard formula. + + mask_pi # Avoid zero division at angle == pi. + ) + prefactor = numerator / denominator + vector = vector * prefactor[..., None] + + # For angles close to pi, derive vectors from their outer product (ww' = 1 + R). 
+ id3 = _broadcast_identity(rotation_matrices) + skew_outer = (id3 + rotation_matrices) / 2.0 + # Ensure diagonal is >= 0 for square root (uses identity for masking). + skew_outer = skew_outer + (torch.relu(skew_outer) - skew_outer) * id3 + + # Get basic rotation vector as sqrt of diagonal (is unit vector). + vector_pi = torch.sqrt(torch.diagonal(skew_outer, dim1=-2, dim2=-1)) + + # Compute the signs of vector elements (up to a global phase). + # Fist select indices for outer product slices with the largest norm. + signs_line_idx = torch.argmax(torch.norm(skew_outer, dim=-1), dim=-1).long() + # Select rows of outer product and determine signs. + signs_line = torch.take_along_dim(skew_outer, dim=-2, indices=signs_line_idx[..., None, None]) + signs_line = signs_line.squeeze(-2) + signs = torch.sign(signs_line) + + # Apply signs and rotation vector. + vector_pi = vector_pi * angles[..., None] * signs + + # Fill entries for angle == pi in rotation vector (basic vector has zero entries at this point). + vector = vector + vector_pi * mask_pi[..., None] + + return vector + + +def angle_from_rotmat( + rotation_matrices: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute rotation angles (as well as their sines and cosines) encoded by rotation matrices. + Uses atan2 for better numerical stability for small angles. + + Args: + rotation_matrices (torch.Tensor): Batch of rotation matrices. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Batch of computed angles, sines of the + angles and cosines of angles. 
+ """ + # Compute sine of angles (uses the relation that the unnormalized skew vector generated by a + # rotation matrix has the length 2*sin(theta)) + skew_matrices = rotation_matrices - rotation_matrices.transpose(-2, -1) + skew_vectors = skew_matrix_to_vector(skew_matrices) + angles_sin = torch.norm(skew_vectors, dim=-1) / 2.0 + # Compute the cosine of the angle using the relation cos theta = 1/2 * (Tr[R] - 1) + angles_cos = (torch.einsum("...ii", rotation_matrices) - 1.0) / 2.0 + + # Compute angles using the more stable atan2 + angles = torch.atan2(angles_sin, angles_cos) + + return angles, angles_sin, angles_cos + + +def vector_to_skew_matrix(vectors: torch.Tensor) -> torch.Tensor: + """ + Map a vector into the corresponding skew matrix so(3) basis. + ``` + [ 0 -z y] + [x,y,z] -> [ z 0 -x] + [ -y x 0] + ``` + + Args: + vectors (torch.Tensor): Batch of vectors to be mapped to skew matrices. + + Returns: + torch.Tensor: Vectors in skew matrix representation. + """ + # Generate empty skew matrices. + skew_matrices = torch.zeros((*vectors.shape, 3), device=vectors.device, dtype=vectors.dtype) + + # Populate positive values. + skew_matrices[..., 2, 1] = vectors[..., 0] + skew_matrices[..., 0, 2] = vectors[..., 1] + skew_matrices[..., 1, 0] = vectors[..., 2] + + # Generate skew symmetry. + skew_matrices = skew_matrices - skew_matrices.transpose(-2, -1) + + return skew_matrices + + +def skew_matrix_to_vector(skew_matrices: torch.Tensor) -> torch.Tensor: + """ + Extract a rotation vector from the so(3) skew matrix basis. + + Args: + skew_matrices (torch.Tensor): Skew matrices. + + Returns: + torch.Tensor: Rotation vectors corresponding to skew matrices. 
+ """ + vectors = torch.zeros_like(skew_matrices[..., 0]) + vectors[..., 0] = skew_matrices[..., 2, 1] + vectors[..., 1] = skew_matrices[..., 0, 2] + vectors[..., 2] = skew_matrices[..., 1, 0] + return vectors + + +def _rotquat_to_axis_angle( + rotation_quaternions: torch.Tensor, tol: float = 1e-7 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Auxiliary routine for computing rotation angle and rotation axis from unit quaternions. To avoid + complications, rotations vectors with angles below `tol` are set to zero. + + Args: + rotation_quaternions (torch.Tensor): Rotation quaternions in [r, i, j, k] format. + tol (float, optional): Threshold for small rotations. Defaults to 1e-7. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Rotation angles and axes. + """ + # Compute rotation axis and normalize (accounting for small length axes). + rotation_axes = rotation_quaternions[..., 1:] + rotation_axes_norms = torch.norm(rotation_axes, dim=-1) + + # Compute rotation angle via atan2 + rotation_angles = 2.0 * torch.atan2(rotation_axes_norms, rotation_quaternions[..., 0]) + + # Save division. + rotation_axes = rotation_axes / (rotation_axes_norms[:, None] + tol) + return rotation_angles, rotation_axes + + +def rotquat_to_rotvec(rotation_quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert unit quaternions to rotation vectors. + + Args: + rotation_quaternions (torch.Tensor): Input quaternions in [r,i,j,k] format. + + Returns: + torch.Tensor: Rotation vectors. + """ + rotation_angles, rotation_axes = _rotquat_to_axis_angle(rotation_quaternions) + rotation_vectors = rotation_axes * rotation_angles[..., None] + return rotation_vectors + + +def rotquat_to_rotmat(rotation_quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert unit quaternion to rotation matrix. + + Args: + rotation_quaternions (torch.Tensor): Input quaternions in [r,i,j,k] format. + + Returns: + torch.Tensor: Rotation matrices. 
+ """ + rotation_angles, rotation_axes = _rotquat_to_axis_angle(rotation_quaternions) + skew_matrices = vector_to_skew_matrix(rotation_axes * rotation_angles[..., None]) + rotation_matrices = skew_matrix_exponential_map(rotation_angles, skew_matrices) + return rotation_matrices + + +def apply_rotvec_to_rotmat( + rotation_matrices: torch.Tensor, + rotation_vectors: torch.Tensor, + tol: float = 1e-7, +) -> torch.Tensor: + """ + Update a rotation encoded in a rotation matrix with a rotation vector. + + Args: + rotation_matrices: Input batch of rotation matrices. + rotation_vectors: Input batch of rotation vectors. + tol: Small offset for numerical stability. + + Returns: + Updated rotation matrices. + """ + # Convert vector to matrices. + rmat_right = rotvec_to_rotmat(rotation_vectors, tol=tol) + # Accumulate rotation. + rmat_rotated = torch.einsum("...ij,...jk->...ik", rotation_matrices, rmat_right) + return rmat_rotated + + +def rotmat_to_skew_matrix(mat: torch.Tensor) -> torch.Tensor: + """ + Generates skew matrix for corresponding rotation matrix. + + Args: + mat (torch.Tensor): Batch of rotation matrices. + + Returns: + torch.Tensor: Skew matrices in the shapes of mat. + """ + vec = rotmat_to_rotvec(mat) + return vector_to_skew_matrix(vec) + + +def skew_matrix_to_rotmat(skew: torch.Tensor) -> torch.Tensor: + """ + Generates rotation matrix for corresponding skew matrix. + + Args: + skew (torch.Tensor): Batch of target 3 by 3 skew symmetric matrices. + + Returns: + torch.Tensor: Rotation matrices in the shapes of skew. + """ + vec = skew_matrix_to_vector(skew) + return rotvec_to_rotmat(vec) + + +def local_log(point: torch.Tensor, base_point: torch.Tensor) -> torch.Tensor: + """ + Matrix logarithm. Computes left-invariant vector field of beinging base_point to point + on the manifold. Follows the signature of geomstats' equivalent function. 
+ https://geomstats.github.io/api/geometry.html#geomstats.geometry.lie_group.MatrixLieGroup.log + + Args: + point (torch.Tensor): Batch of rotation matrices to compute vector field at. + base_point (torch.Tensor): Transport coordinates to take matrix logarithm. + + Returns: + torch.Tensor: Skew matrix that holds the vector field (in the tangent space). + """ + return rotmat_to_skew_matrix(rot_mult(rot_transpose(base_point), point)) + + +def multidim_trace(mat: torch.Tensor) -> torch.Tensor: + """Take the trace of a matrix with leading dimensions.""" + return torch.einsum("...ii->...", mat) + + +def geodesic_dist(mat_1: torch.Tensor, mat_2: torch.Tensor) -> torch.Tensor: + """ + Calculate the geodesic distance of two rotation matrices. + + Args: + mat_1 (torch.Tensor): First rotation matrix. + mat_2 (torch.Tensor): Second rotation matrix. + + Returns: + Scalar for the geodesic distance between mat_1 and mat_2 with the same + leading (i.e. batch) dimensions. + """ + A = rotmat_to_skew_matrix(rot_mult(rot_transpose(mat_1), mat_2)) + return torch.sqrt(multidim_trace(rot_mult(A, rot_transpose(A)))) + + +def rot_transpose(mat: torch.Tensor) -> torch.Tensor: + """Take the transpose of the last two dimensions.""" + return torch.transpose(mat, -1, -2) + + +def rot_mult(mat_1: torch.Tensor, mat_2: torch.Tensor) -> torch.Tensor: + """Matrix multiply two rotation matrices with leading dimensions.""" + return torch.einsum("...ij,...jk->...ik", mat_1, mat_2) + + +def calc_rot_vf(mat_t: torch.Tensor, mat_1: torch.Tensor) -> torch.Tensor: + """ + Computes the vector field Log_{mat_t}(mat_1). + + Args: + mat_t (torch.Tensor): base point to compute vector field at. + mat_1 (torch.Tensor): target rotation. + + Returns: + Rotation vector representing the vector field. + """ + return rotmat_to_rotvec(rot_mult(rot_transpose(mat_t), mat_1)) + + +def geodesic_t(t: float, mat: torch.Tensor, base_mat: torch.Tensor, rot_vf=None) -> torch.Tensor: + """ + Computes the geodesic at time t. 
Specifically, R_t = Exp_{base_mat}(t * Log_{base_mat}(mat)). + + Args: + t: time along geodesic. + mat: target points on manifold. + base_mat: source point on manifold. + + Returns: + Point along geodesic starting at base_mat and ending at mat. + """ + if rot_vf is None: + rot_vf = calc_rot_vf(base_mat, mat) + mat_t = rotvec_to_rotmat(t * rot_vf) + if base_mat.shape != mat_t.shape: + raise ValueError( + f'Incompatible shapes: base_mat={base_mat.shape}, mat_t={mat_t.shape}') + return torch.einsum("...ij,...jk->...ik", base_mat, mat_t) + + +class SO3LookupCache: + def __init__( + self, + cache_dir: str, + cache_file: str, + overwrite: bool = False, + ) -> None: + """ + Auxiliary class for handling storage / loading of SO(3) lookup tables in npz format. + + Args: + cache_dir: Path to the cache directory. + cache_file: Basic file name of the cache file. + overwrite: Whether existing cache files should be overwritten if requested. + """ + if not cache_file.endswith(".npz"): + raise ValueError("Filename should have '.npz' extension.") + self.cache_file = cache_file + self.cache_dir = cache_dir + self.cache_path = os.path.join(cache_dir, cache_file) + self.overwrite = overwrite + + @property + def path_exists(self) -> bool: + return os.path.exists(self.cache_path) + + @property + def dir_exists(self) -> bool: + return os.path.exists(self.cache_dir) + + def delete_cache(self) -> None: + """ + Delete the cache file. + """ + if self.path_exists: + os.remove(self.cache_path) + + def load_cache(self) -> Dict[str, torch.Tensor]: + """ + Load data from the cache file. + + Returns: + Dictionary of loaded data tensors. + """ + if self.path_exists: + # Load data and convert to torch tensors. 
+ npz_data = np.load(self.cache_path) + torch_dict = {f: torch.from_numpy(npz_data[f]) for f in npz_data.files} + logger.info(f"Data loaded from {self.cache_path}") + return torch_dict + else: + raise ValueError(f"No cache data found at {self.cache_path}.") + + def save_cache(self, data: Dict[str, torch.Tensor]) -> None: + """ + Save a dictionary of tensors to the cache file. If overwrite is set to True, an existing + file is overwritten, otherwise a warning is raised and the file is not modified. + + Args: + data: Dictionary of tensors that should be saved to the cache. + """ + if not self.dir_exists: + os.makedirs(self.cache_dir) + + if self.path_exists: + if self.overwrite: + logger.info("Overwriting cache ...") + self.delete_cache() + else: + logger.warn( + f"Cache at {self.cache_path} exits and overwriting disabled. Doing nothing." + ) + else: + # Move everything to CPU and numpy and store. + logger.info(f"Data saved to {self.cache_path}") + numpy_dict = {k: v.detach().cpu().numpy() for k, v in data.items()} + np.savez(self.cache_path, **numpy_dict) + + +class BaseSampleSO3(nn.Module): + so3_type: str = "base" # cache basename + + def __init__( + self, + num_omega: int, + sigma_grid: torch.Tensor, + omega_exponent: int = 3, + tol: float = 1e-7, + interpolate: bool = True, + cache_dir: Optional[str] = None, + overwrite_cache: bool = False, + device: str = 'cpu', + ) -> None: + """ + Base torch.nn module for sampling rotations from the IGSO(3) distribution. Samples are + created by uniformly sampling a rotation axis and using inverse transform sampling for + the angles. The latter uses the associated SO(3) cumulative probability distribution + function (CDF) and a uniform distribution [0,1] as described in [#leach2022_1]_. CDF values + are obtained by numerically integrating the probability distribution evaluated on a grid of + angles and noise levels and stored in a lookup table. Linear interpolation is used to + approximate continuos sampling of the function. 
Angles are discretized in an interval [0,pi] + and the grid can be squashed to have higher resolutions at low angles by taking different + powers. Since sampling relies on tabulated values of the CDF and indexing in the form of + `torch.bucketize`, gradients are not supported. + + Args: + num_omega (int): Number of discrete angles used for generating the lookup table. + sigma_grid (torch.Tensor): Grid of IGSO3 std devs. + omega_exponent (int, optional): Make the angle grid denser for smaller angles by taking + its power with the provided number. Defaults to 3. + tol (float, optional): Small value for numerical stability. Defaults to 1e-7. + interpolate (bool, optional): If enables, perform linear interpolation of the angle CDF + to sample angles. Otherwise the closest tabulated point is returned. Defaults to True. + cache_dir: Path to an optional cache directory. If set to None, lookup tables are + computed on the fly. + overwrite_cache: If set to true, existing cache files are overwritten. Can be used for + updating stale caches. + + References + ---------- + .. [#leach2022_1] Leach, Schmon, Degiacomi, Willcocks: + Denoising diffusion probabilistic models on so (3) for rotational alignment. + ICLR 2022 Workshop on Geometrical and Topological Representation Learning. 2022. + """ + super().__init__() + self.num_omega = num_omega + self.omega_exponent = omega_exponent + self.tol = tol + self.interpolate = interpolate + self.device = device + self.register_buffer("sigma_grid", sigma_grid, persistent=False) + + # Generate / load lookups and store in non-persistent buffers. 
+ omega_grid, cdf_igso3 = self._setup_lookup(sigma_grid, cache_dir, overwrite_cache) + self.register_buffer("omega_grid", omega_grid, persistent=False) + self.register_buffer("cdf_igso3", cdf_igso3, persistent=False) + + def _setup_lookup( + self, + sigma_grid: torch.Tensor, + cache_dir: Optional[str] = None, + overwrite_cache: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Master function for setting up the lookup tables. These can either be loaded from a npz + cache file or computed on the fly. Lookup tables will always be created and stored in double + precision. Casting to the target dtype is done at the end of the function. + + Args: + sigma_grid: Grid of sigma values used for computing the lookup tables. + cache_dir: Path to the cache directory. + overwrite_cache: If set to true, an existing cache is overwritten. Can be used for + updating stale caches. + + Returns: + Grid of angle values and SO(3) cumulative distribution function. + """ + if cache_dir is not None: + cache_name = self._get_cache_name() + cache = SO3LookupCache(cache_dir, cache_name, overwrite=True) + + # If cache dir is provided, check whether the necessary cache exists and whether it + # should be overwritten. + if cache.path_exists and not overwrite_cache: + # Load data from cache. + cache_data = cache.load_cache() + omega_grid = cache_data["omega_grid"] + cdf_igso3 = cache_data["cdf_igso3"] + else: + # Store data in cache (overwrite if requested). + omega_grid, cdf_igso3 = self._generate_lookup(sigma_grid) + cache.save_cache({"omega_grid": omega_grid, "cdf_igso3": cdf_igso3}) + else: + # Other wise just generate the tables. + omega_grid, cdf_igso3 = self._generate_lookup(sigma_grid) + + return omega_grid.to(sigma_grid.dtype), cdf_igso3.to(sigma_grid.dtype) + + def _get_cache_name(self) -> str: + """ + Auxiliary function for determining the cache file name based on the parameters (sigma, + omega, l, etc.) used for generating the lookup tables. 
+ + Returns: + Base name of the cache file. + """ + cache_name = "cache_{:s}_s{:04.3f}-{:04.3f}-{:d}_o{:d}-{:d}.npz".format( + self.so3_type, + torch.min(self.sigma_grid).cpu().item(), + torch.max(self.sigma_grid).cpu().item(), + self.sigma_grid.shape[0], + self.num_omega, + self.omega_exponent, + ) + return cache_name + + def get_sigma_idx(self, sigma: torch.Tensor) -> torch.Tensor: + """ + Convert continuous sigmas to the indices of the closest tabulated values. + + Args: + sigma (torch.Tensor): IGSO3 std devs. + + Returns: + torch.Tensor: Index tensor mapping the provided sigma values to the internal lookup + table. + """ + return torch.bucketize(sigma, self.sigma_grid) + + def expansion_function( + self, omega_grid: torch.Tensor, sigma_grid: torch.Tensor + ) -> torch.Tensor: + """ + Function for generating the angle probability distribution. Should return a 2D tensor with + values for the std dev at the first dimension (rows) and angles at the second + (columns). + + Args: + omega_grid (torch.Tensor): Grid of angle values. + sigma_grid (torch.Tensor): IGSO3 std devs. + + Returns: + torch.Tensor: Distribution for angles discretized on a 2D grid. + """ + raise NotImplementedError + + @torch.no_grad() + def _generate_lookup(self, sigma_grid: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Generate the lookup table for sampling from the target SO(3) CDF. The table is 2D, with the + rows corresponding to different sigma values and the columns with angles computed on a grid. + Variance is scaled by a factor of 1/2 to account for the deacceleration of time in the + diffusion process due to the choice of SO(3) basis and guarantee time-reversibility (see + appendix E.3 in [#yim2023_2]_). The returned tables are double precision and will be cast + to the target dtype in `_setup_lookup`. + + Args: + sigma_grid (torch.Tensor): Grid of IGSO3 std devs. 
+ + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing the grid used to compute the angles + and the associated lookup table. + + References + ---------- + .. [#yim2023_2] Yim, Trippe, De Bortoli, Mathieu, Doucet, Barzilay, Jaakkola: + SE(3) diffusion model with application to protein backbone generation. + arXiv preprint arXiv:2302.02277. 2023. + """ + + current_device = sigma_grid.device + sigma_grid_tmp = sigma_grid.to(torch.float64) + + # If cuda is available, initialize everything on GPU. + # Even if Pytorch Lightning usually handles GPU allocation after initialization, this is + # required to initialize the module in GPU reducing the initializaiton time by orders of magnitude. + if torch.cuda.is_available(): + sigma_grid_tmp = sigma_grid_tmp.to(device=self.device) + + # Set up grid for angle resolution. Convert to double precision for better handling of numerics. + omega_grid = torch.linspace(0.0, 1, self.num_omega + 1).to(sigma_grid_tmp) + + # If requested, increase sample density for lower values + omega_grid = omega_grid**self.omega_exponent + + omega_grid = omega_grid * np.pi + + # Compute the expansion for all omegas and sigmas. + pdf_igso3 = self.expansion_function(omega_grid, sigma_grid_tmp) + + # Apply the pre-factor from USO(3). + pdf_igso3 = pdf_igso3 * (1.0 - torch.cos(omega_grid)) / np.pi + + # Compute the cumulative probability distribution. + cdf_igso3 = integrate_trapezoid_cumulative(pdf_igso3, omega_grid) + # Normalize integral area to 1. + cdf_igso3 = cdf_igso3 / cdf_igso3[:, -1][:, None] + + # Move back to original device. 
+ cdf_igso3 = cdf_igso3.to(device=current_device) + omega_grid = omega_grid.to(device=current_device) + + return omega_grid[1:].to(sigma_grid.dtype), cdf_igso3.to(sigma_grid.dtype) + + def sample(self, sigma: torch.Tensor, num_samples: int) -> torch.Tensor: + """ + Generate samples from the target SO(3) distribution by sampling a rotation axis angle, + which are then combined into a rotation vector and transformed into the corresponding + rotation matrix via an exponential map. + + Args: + sigma_indices (torch.Tensor): Indices of the IGSO3 std devs for which to take samples. + num_samples (int): Number of angle samples to take for each std dev + + Returns: + torch.Tensor: Sampled rotations in matrix representation with dimensions + [num_sigma x num_samples x 3 x 3]. + """ + + vectors = self.sample_vector(sigma.shape[0], num_samples) + angles = self.sample_angle(sigma, num_samples) + + # Do postprocessing on angles. + angles = self._process_angles(sigma, angles) + + rotation_vectors = vectors * angles[..., None] + + rotation_matrices = rotvec_to_rotmat(rotation_vectors, tol=self.tol) + return rotation_matrices + + def _process_angles(self, sigma: torch.Tensor, angles: torch.Tensor) -> torch.Tensor: + """ + Auxiliary function for performing additional processing steps on the sampled angles. One + example would be to ensure sampled angles are 0 for a std dev of 0 for IGSO(3). + + Args: + sigma (torch.Tensor): Current values of sigma. + angles (torch.Tensor): Sampled angles. + + Returns: + torch.Tensor: Processed sampled angles. + """ + return angles + + def sample_vector(self, num_sigma: int, num_samples: int) -> torch.Tensor: + """ + Uniformly sample rotation axis for constructing the overall rotation. + + Args: + num_sigma (int): Number of samples to draw for each std dev. + num_samples (int): Number of angle samples to take for each std dev. + + Returns: + torch.Tensor: Batch of rotation axes with dimensions [num_sigma x num_samples x 3]. 
+ """ + vectors = torch.randn(num_sigma, num_samples, 3, device=self.sigma_grid.device) + vectors = vectors / torch.norm(vectors, dim=2, keepdim=True) + return vectors + + def sample_angle(self, sigma: torch.Tensor, num_samples: int) -> torch.Tensor: + """ + Create a series of samples from the IGSO(3) angle distribution. + + Args: + sigma_indices (torch.Tensor): Indices of the IGSO3 std deves for which to + take samples. + num_samples (int): Number of angle samples to take for each std dev. + + Returns: + torch.Tensor: Collected samples, will have the dimension [num_sigma x num_samples]. + """ + # Convert sigmas to respective indices for lookup table. + sigma_indices = self.get_sigma_idx(sigma) + # Get relevant sigma slices from stored CDFs. + cdf_tmp = self.cdf_igso3[sigma_indices, :] + + # Draw from uniform distribution. + p_uniform = torch.rand((*sigma_indices.shape, *[num_samples]), device=sigma_indices.device) + + # Determine indices for CDF. + idx_stop = torch.sum(cdf_tmp[..., None] < p_uniform[:, None, :], dim=1).long() + idx_start = torch.clamp(idx_stop - 1, min=0) + + if not self.interpolate: + omega = torch.gather(cdf_tmp, dim=1, index=idx_stop) + else: + # Get CDF values. + cdf_start = torch.gather(cdf_tmp, dim=1, index=idx_start) + cdf_stop = torch.gather(cdf_tmp, dim=1, index=idx_stop) + + # Compute weights for linear interpolation. + cdf_delta = torch.clamp(cdf_stop - cdf_start, min=self.tol) + cdf_weight = torch.clamp((p_uniform - cdf_start) / cdf_delta, min=0.0, max=1.0) + + # Get angle range for interpolation. + omega_start = self.omega_grid[idx_start] + omega_stop = self.omega_grid[idx_stop] + + # Interpolate. 
+ omega = torch.lerp(omega_start, omega_stop, cdf_weight) + + return omega + + +class SampleIGSO3(BaseSampleSO3): + so3_type = "igso3" # cache basename + + def __init__( + self, + num_omega: int, + sigma_grid: torch.Tensor, + omega_exponent: int = 3, + tol: float = 1e-7, + interpolate: bool = True, + l_max: int = 1000, + cache_dir: Optional[str] = None, + overwrite_cache: bool = False, + device: str = 'cpu', + ) -> None: + """ + Module for sampling rotations from the IGSO(3) distribution using the explicit series + expansion. Samples are created using inverse transform sampling based on the associated + cumulative probability distribution function (CDF) and a uniform distribution [0,1] as + described in [#leach2022_2]_. CDF values are obtained by numerically integrating the + probability distribution evaluated on a grid of angles and noise levels and stored in a + lookup table. Linear interpolation is used to approximate continuos sampling of the + function. Angles are discretized in an interval [0,pi] and the grid can be squashed to have + higher resolutions at low angles by taking different powers. + Since sampling relies on tabulated values of the CDF and indexing in the form of + `torch.bucketize`, gradients are not supported. + + Args: + num_omega (int): Number of discrete angles used for generating the lookup table. + sigma_grid (torch.Tensor): Grid of IGSO3 std devs. + omega_exponent (int, optional): Make the angle grid denser for smaller angles by taking + its power with the provided number. Defaults to 3. + tol (float, optional): Small value for numerical stability. Defaults to 1e-7. + interpolate (bool, optional): If enables, perform linear interpolation of the angle CDF + to sample angles. Otherwise the closest tabulated point is returned. Defaults to True. + l_max (int, optional): Maximum number of terms used in the series expansion. + cache_dir: Path to an optional cache directory. If set to None, lookup tables are + computed on the fly. 
+ overwrite_cache: If set to true, existing cache files are overwritten. Can be used for + updating stale caches. + + References + ---------- + .. [#leach2022_2] Leach, Schmon, Degiacomi, Willcocks: + Denoising diffusion probabilistic models on so (3) for rotational alignment. + ICLR 2022 Workshop on Geometrical and Topological Representation Learning. 2022. + """ + self.l_max = l_max + super().__init__( + num_omega=num_omega, + sigma_grid=sigma_grid, + omega_exponent=omega_exponent, + tol=tol, + interpolate=interpolate, + cache_dir=cache_dir, + overwrite_cache=overwrite_cache, + device=device, + ) + + def _get_cache_name(self) -> str: + """ + Auxiliary function for determining the cache file name based on the parameters (sigma, + omega, l, etc.) used for generating the lookup tables. + + Returns: + Base name of the cache file. + """ + cache_name = "cache_{:s}_s{:04.3f}-{:04.3f}-{:d}_l{:d}_o{:d}-{:d}.npz".format( + self.so3_type, + torch.min(self.sigma_grid).cpu().item(), + torch.max(self.sigma_grid).cpu().item(), + self.sigma_grid.shape[0], + self.l_max, + self.num_omega, + self.omega_exponent, + ) + return cache_name + + def expansion_function( + self, + omega_grid: torch.Tensor, + sigma_grid: torch.Tensor, + ) -> torch.Tensor: + """ + Use the truncated expansion of the IGSO(3) probability function to generate the lookup table. + + Args: + omega_grid (torch.Tensor): Grid of angle values. + sigma_grid (torch.Tensor): Grid of IGSO3 std devs. + + Returns: + torch.Tensor: IGSO(3) distribution for angles discretized on a 2D grid. + """ + return generate_igso3_lookup_table(omega_grid, sigma_grid, l_max=self.l_max, tol=self.tol) + + def _process_angles(self, sigma: torch.Tensor, angles: torch.Tensor) -> torch.Tensor: + """ + Ensure sampled angles are 0 for small noise levels in IGSO(3). (Series expansion gives + uniform probability distribution.) + + Args: + sigma (torch.Tensor): Current values of sigma. + angles (torch.Tensor): Sampled angles. 
+ + Returns: + torch.Tensor: Processed sampled angles. + """ + angles = torch.where( + sigma[..., None] < self.tol, + torch.zeros_like(angles), + angles, + ) + return angles + + +class SampleUSO3(BaseSampleSO3): + so3_type = "uso3" # cache basename + + def __init__( + self, + num_omega: int, + sigma_grid: torch.Tensor, + omega_exponent: int = 3, + tol: float = 1e-7, + interpolate: bool = True, + cache_dir: Optional[str] = None, + overwrite_cache: bool = False, + ) -> None: + """ + Module for sampling rotations from the USO(3) distribution. Can be used to generate initial + unbiased samples in the reverse process. Samples are created using inverse transform + sampling based on the associated cumulative probability distribution function (CDF) and a + uniform distribution [0,1] as described in [#leach2022_4]_. CDF values are obtained by + numerically integrating the probability distribution evaluated on a grid of angles and noise + levels and stored in a lookup table. Linear interpolation is used to approximate continuos + sampling of the function. Angles are discretized in an interval [0,pi] and the grid can be + squashed to have higher resolutions at low angles by taking different powers. + Since sampling relies on tabulated values of the CDF and indexing in the form of + `torch.bucketize`, gradients are not supported. + + Args: + num_omega (int): Number of discrete angles used for generating the lookup table. + sigma_grid (torch.Tensor): Grid of IGSO3 std devs. + omega_exponent (int, optional): Make the angle grid denser for smaller angles by taking + its power with the provided number. Defaults to 3. + tol (float, optional): Small value for numerical stability. Defaults to 1e-7. + interpolate (bool, optional): If enables, perform linear interpolation of the angle CDF + to sample angles. Otherwise the closest tabulated point is returned. Defaults to True. + cache_dir: Path to an optional cache directory. If set to None, lookup tables are + computed on the fly. 
+ overwrite_cache: If set to true, existing cache files are overwritten. Can be used for + updating stale caches. + + References + ---------- + .. [#leach2022_4] Leach, Schmon, Degiacomi, Willcocks: + Denoising diffusion probabilistic models on so (3) for rotational alignment. + ICLR 2022 Workshop on Geometrical and Topological Representation Learning. 2022. + """ + super().__init__( + num_omega=num_omega, + sigma_grid=sigma_grid, + omega_exponent=omega_exponent, + tol=tol, + interpolate=interpolate, + cache_dir=cache_dir, + overwrite_cache=overwrite_cache, + ) + + def get_sigma_idx(self, sigma: torch.Tensor) -> torch.Tensor: + return torch.zeros_like(sigma).long() + + def sample_shape(self, num_sigma: int, num_samples: int) -> torch.Tensor: + dummy_sigma = torch.zeros(num_sigma, device=self.sigma_grid.device) + return self.sample(dummy_sigma, num_samples) + + def expansion_function( + self, + omega_grid: torch.Tensor, + sigma_grid: torch.Tensor, + ) -> torch.Tensor: + """ + The probability density function of the uniform SO(3) distribution is the cosine scaling + term (1-cos(omega))/pi which is applied automatically during sampling. This means, it is + sufficient to return a tensor of ones to create the correct USO(3) lookup table. + + Args: + omega_grid (torch.Tensor): Grid of angle values. + sigma_grid (torch.Tensor): Grid of IGSO3 std devs. + + Returns: + torch.Tensor: USO(3) distribution for angles discretized on a 2D grid. + """ + return torch.ones(1, omega_grid.shape[0], device=omega_grid.device) + + +@torch.no_grad() +def integrate_trapezoid_cumulative(f_grid: torch.Tensor, x_grid: torch.Tensor) -> torch.Tensor: + """ + Auxiliary function for numerically integrating a discretized 1D function using the trapezoid + rule. This is mainly used for computing the cumulative probability distributions for sampling + from the IGSO(3) distribution. Works on a single 1D grid or a batch of grids. + + Args: + f_grid (torch.Tensor): Discretized function values. 
+ x_grid (torch.Tensor): Discretized input values. + + Returns: + torch.Tensor: Integrated function (not normalized). + """ + f_sum = f_grid[..., :-1] + f_grid[..., 1:] + delta_x = torch.diff(x_grid, dim=-1) + integral = torch.cumsum((f_sum * delta_x[None, :]) / 2.0, dim=-1) + return integral + + +def uniform_so3_density(omega: torch.Tensor) -> torch.Tensor: + """ + Compute the density over the uniform angle distribution in SO(3). + + Args: + omega: Angles in radians. + + Returns: + Uniform distribution density. + """ + return (1.0 - torch.cos(omega)) / np.pi + + +def igso3_expansion( + omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7 +) -> torch.Tensor: + """ + Compute the IGSO(3) angle probability distribution function for pairs of angles and std dev + levels. The expansion is computed using a grid of expansion orders ranging from 0 to l_max. + + This function approximates the power series in equation 5 of [#yim2023_3]_. With this + parameterization, IGSO(3) agrees with the Brownian motion on SO(3) with t=sigma^2. + + Args: + omega: Values of angles (1D tensor). + sigma: Values of std dev of IGSO3 distribution (1D tensor of same shape as `omega`). + l_grid: Tensor containing expansion orders (0 to l_max). + tol: Small offset for numerical stability. + + Returns: + IGSO(3) angle distribution function (without pre-factor for uniform SO(3) distribution). + + References + ---------- + .. [#yim2023_3] Yim, Trippe, De Bortoli, Mathieu, Doucet, Barzilay, Jaakkola: + SE(3) diffusion model with application to protein backbone generation. + arXiv preprint arXiv:2302.02277. 2023. + """ + # Pre-compute sine in denominator and clamp for stability. + denom_sin = torch.sin(0.5 * omega) + + # Pre-compute terms that rely only on expansion orders. + l_fac_1 = 2.0 * l_grid + 1.0 + l_fac_2 = -l_grid * (l_grid + 1.0) + + # Pre-compute numerator of expansion which only depends on angles. 
+ numerator_sin = torch.sin((l_grid[None, :] + 1 / 2) * omega[:, None]) + + # Pre-compute exponential term with (2l+1) prefactor. + exponential_term = l_fac_1[None, :] * torch.exp(l_fac_2[None, :] * sigma[:, None] ** 2 / 2) + + # Compute series expansion + f_igso = torch.sum(exponential_term * numerator_sin, dim=1) + # For small omega, accumulate limit of sine fraction instead: + # lim[x->0] sin((l+1/2)x) / sin(x/2) = 2l + 1 + f_limw = torch.sum(exponential_term * l_fac_1[None, :], dim=1) + + # Finalize expansion. Offset for stability can be added since omega is [0,pi] and sin(omega/2) + # is positive in this interval. + f_igso = f_igso / (denom_sin + tol) + + # Replace values at small omega with limit. + f_igso = torch.where(omega <= tol, f_limw, f_igso) + + # Remove remaining numerical problems + f_igso = torch.where( + torch.logical_or(torch.isinf(f_igso), torch.isnan(f_igso)), torch.zeros_like(f_igso), f_igso + ) + + return f_igso + + +def digso3_expansion( + omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7 +) -> torch.Tensor: + """ + Compute the derivative of the IGSO(3) angle probability distribution function with respect to + the angles for pairs of angles and std dev levels. As in `igso3_expansion` a grid is used for the + expansion levels. Evaluates the derivative directly in order to avoid second derivatives during + backpropagation. + + The derivative of the angle-dependent part is computed as: + + .. math :: + \frac{\partial}{\partial \omega} \frac{\sin((l+\tfrac{1}{2})\omega)}{\sin(\tfrac{1}{2}\omega)} = \frac{l\sin((l+1)\omega) - (l+1)\sin(l\omega)}{1 - \cos(\omega)} + + (obtained via quotient rule + different trigonometric identities). + + Args: + omega: Values of angles (1D tensor). + sigma: Values of IGSO3 distribution std devs (1D tensor of same shape as `omega`). + l_grid: Tensor containing expansion orders (0 to l_max). + tol: Small offset for numerical stability. 
+ + Returns: + IGSO(3) angle distribution derivative (without pre-factor for uniform SO(3) distribution). + """ + denom_cos = 1.0 - torch.cos(omega) + + l_fac_1 = 2.0 * l_grid + 1.0 + l_fac_2 = l_grid + 1.0 + l_fac_3 = -l_grid * l_fac_2 + + # Pre-compute numerator of expansion which only depends on angles. + numerator_sin = l_grid[None, :] * torch.sin(l_fac_2[None, :] * omega[:, None]) - l_fac_2[ + None, : + ] * torch.sin(l_grid[None, :] * omega[:, None]) + + # Compute series expansion + df_igso = torch.sum( + l_fac_1[None, :] * torch.exp(l_fac_3[None, :] * sigma[:, None] ** 2 / 2) * numerator_sin, + dim=1, + ) + + # Finalize expansion. Offset for stability can be added since omega is [0,pi] and cosine term + # is positive in this interval. + df_igso = df_igso / (denom_cos + tol) + + # Replace values at small omega with limit (=0). + df_igso = torch.where(omega <= tol, torch.zeros_like(df_igso), df_igso) + + # Remove remaining numerical problems + df_igso = torch.where( + torch.logical_or(torch.isinf(df_igso), torch.isnan(df_igso)), + torch.zeros_like(df_igso), + df_igso, + ) + + return df_igso + + +def dlog_igso3_expansion( + omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7 +) -> torch.Tensor: + """ + Compute the derivative of the logarithm of the IGSO(3) angle distribution function for pairs of + angles and std dev levels: + + .. math :: + \frac{\partial}{\partial \omega} \log f(\omega) = \frac{\tfrac{\partial}{\partial \omega} f(\omega)}{f(\omega)} + + Required for SO(3) score computation. + + Args: + omega: Values of angles (1D tensor). + sigma: Values of IGSO3 std devs (1D tensor of same shape as `omega`). + l_grid: Tensor containing expansion orders (0 to l_max). + tol: Small offset for numerical stability. + + Returns: + IGSO(3) angle distribution derivative (without pre-factor for uniform SO(3) distribution). 
+ """ + f_igso3 = igso3_expansion(omega, sigma, l_grid, tol=tol) + df_igso3 = digso3_expansion(omega, sigma, l_grid, tol=tol) + + return df_igso3 / (f_igso3 + tol) + + +@torch.no_grad() +def generate_lookup_table( + base_function: Callable, + omega_grid: torch.Tensor, + sigma_grid: torch.Tensor, + l_max: int = 1000, + tol: float = 1e-7, +): + """ + Auxiliary function for generating a lookup table from IGSO(3) expansions and their derivatives. + Takes a basic function and loops over different std dev levels. + + Args: + base_function: Function used for setting up the lookup table. + omega_grid: Grid of angle values ranging from [0,pi] (shape is[num_omega]). + sigma_grid: Grid of IGSO3 std dev values (shape is [num_sigma]). + l_max: Number of terms used in the series expansion. + tol: Small value for numerical stability. + + Returns: + Table of function values evaluated at different angles and std dev levels. The final shape is + [num_sigma x num_omega]. + """ + # Generate grid of expansion orders. + l_grid = torch.arange(l_max + 1, device=omega_grid.device).to(omega_grid.dtype) + + n_omega = len(omega_grid) + n_sigma = len(sigma_grid) + + # Populate lookup table for different time frames. + f_table = torch.zeros(n_sigma, n_omega, device=omega_grid.device, dtype=omega_grid.dtype) + + for eps_idx in tqdm(range(n_sigma), desc=f"Computing {base_function.__name__}"): + f_table[eps_idx, :] = base_function( + omega_grid, + torch.ones_like(omega_grid) * sigma_grid[eps_idx], + l_grid, + tol=tol, + ) + + return f_table + + +def generate_igso3_lookup_table( + omega_grid: torch.Tensor, + sigma_grid: torch.Tensor, + l_max: int = 1000, + tol: float = 1e-7, +) -> torch.Tensor: + """ + Generate a lookup table for the IGSO(3) probability distribution function of angles. + + Args: + omega_grid: Grid of angle values ranging from [0,pi] (shape is[num_omega]). + sigma_grid: Grid of IGSO3 std dev values (shape is [num_sigma]). + l_max: Number of terms used in the series expansion. 
+ tol: Small value for numerical stability. + + Returns: + Table of function values evaluated at different angles and std dev levels. The final shape is + [num_sigma x num_omega]. + """ + f_igso = generate_lookup_table( + base_function=igso3_expansion, + omega_grid=omega_grid, + sigma_grid=sigma_grid, + l_max=l_max, + tol=tol, + ) + return f_igso + + +def generate_dlog_igso3_lookup_table( + omega_grid: torch.Tensor, + sigma_grid: torch.Tensor, + l_max: int = 1000, + tol: float = 1e-7, +) -> torch.Tensor: + """ + Generate a lookup table for the derivative of the logarithm of the angular IGSO(3) probability + distribution function. Used e.g. for computing scaling of SO(3) norms. + + Args: + omega_grid: Grid of angle values ranging from [0,pi] (shape is[num_omega]). + sigma_grid: Grid of IGSO3 std dev values (shape is [num_sigma]). + l_max: Number of terms used in the series expansion. + tol: Small value for numerical stability. + + Returns: + Table of function values evaluated at different angles and std dev levels. The final shape is + [num_sigma x num_omega]. + """ + dlog_igso = generate_lookup_table( + base_function=dlog_igso3_expansion, + omega_grid=omega_grid, + sigma_grid=sigma_grid, + l_max=l_max, + tol=tol, + ) + return dlog_igso diff --git a/data/utils.py b/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aad2f7f201561f6d4a0694d1f531b23aff170519 --- /dev/null +++ b/data/utils.py @@ -0,0 +1,507 @@ +from typing import List, Dict, Any +from openfold.utils import rigid_utils as ru +from data import residue_constants +import numpy as np +import collections +import string +import pickle +import os +import torch +from torch_scatter import scatter_add, scatter +from Bio.PDB.Chain import Chain +from data import protein +import dataclasses +from Bio import PDB + +Rigid = ru.Rigid +Protein = protein.Protein + +# Global map from chain characters to integers. 
+ALPHANUMERIC = string.ascii_letters + string.digits + ' ' +CHAIN_TO_INT = { + chain_char: i for i, chain_char in enumerate(ALPHANUMERIC) +} +INT_TO_CHAIN = { + i: chain_char for i, chain_char in enumerate(ALPHANUMERIC) +} + +NM_TO_ANG_SCALE = 10.0 +ANG_TO_NM_SCALE = 1 / NM_TO_ANG_SCALE + +CHAIN_FEATS = [ + 'atom_positions', 'aatype', 'atom_mask', 'residue_index', 'b_factors' +] + +to_numpy = lambda x: x.detach().cpu().numpy() +aatype_to_seq = lambda aatype: ''.join([ + residue_constants.restypes_with_x[x] for x in aatype]) + + +class CPU_Unpickler(pickle.Unpickler): + """Pytorch pickle loading workaround. + + https://github.com/pytorch/pytorch/issues/16797 + """ + def find_class(self, module, name): + if module == 'torch.storage' and name == '_load_from_bytes': + return lambda b: torch.load(io.BytesIO(b), map_location='cpu') + else: return super().find_class(module, name) + + +def create_rigid(rots, trans): + rots = ru.Rotation(rot_mats=rots) + return Rigid(rots=rots, trans=trans) + + +def batch_align_structures(pos_1, pos_2, mask=None): + if pos_1.shape != pos_2.shape: + raise ValueError('pos_1 and pos_2 must have the same shape.') + if pos_1.ndim != 3: + raise ValueError(f'Expected inputs to have shape [B, N, 3]') + num_batch = pos_1.shape[0] + device = pos_1.device + batch_indices = ( + torch.ones(*pos_1.shape[:2], device=device, dtype=torch.int64) + * torch.arange(num_batch, device=device)[:, None] + ) + flat_pos_1 = pos_1.reshape(-1, 3) + flat_pos_2 = pos_2.reshape(-1, 3) + flat_batch_indices = batch_indices.reshape(-1) + if mask is None: + # aligned_pos_1, aligned_pos_2, align_rots = align_structures( + # flat_pos_1, flat_batch_indices, flat_pos_2) + # aligned_pos_1 = aligned_pos_1.reshape(num_batch, -1, 3) + # aligned_pos_2 = aligned_pos_2.reshape(num_batch, -1, 3) + # return aligned_pos_1, aligned_pos_2, align_rots + mask = torch.ones(*pos_1.shape[:2], device=device).reshape(-1).bool() + + flat_mask = mask.reshape(-1).bool() + # _, _, align_rots = 
align_structures( + # flat_pos_1[flat_mask], + # flat_batch_indices[flat_mask], + # flat_pos_2[flat_mask] + # ) + # aligned_pos_1 = torch.bmm( + # pos_1, + # align_rots + # ) + # return aligned_pos_1, pos_2, align_rots + aligned_pos_1, aligned_pos_2, align_rots = align_structures( + flat_pos_1[flat_mask], flat_batch_indices[flat_mask], flat_pos_2[flat_mask]) + aligned_pos_1 = aligned_pos_1.reshape(num_batch, -1, 3) + aligned_pos_2 = aligned_pos_2.reshape(num_batch, -1, 3) + return aligned_pos_1, aligned_pos_2, align_rots + + + +def adjust_oxygen_pos( + atom_37: torch.Tensor, pos_is_known = None +) -> torch.Tensor: + """ + Imputes the position of the oxygen atom on the backbone by using adjacent frame information. + Specifically, we say that the oxygen atom is in the plane created by the Calpha and C from the + current frame and the nitrogen of the next frame. The oxygen is then placed c_o_bond_length Angstrom + away from the C in the current frame in the direction away from the Ca-C-N triangle. + + For cases where the next frame is not available, for example we are at the C-terminus or the + next frame is not available in the data then we place the oxygen in the same plane as the + N-Ca-C of the current frame and pointing in the same direction as the average of the + Ca->C and Ca->N vectors. + + Args: + atom_37 (torch.Tensor): (N, 37, 3) tensor of positions of the backbone atoms in atom_37 ordering + which is ['N', 'CA', 'C', 'CB', 'O', ...] + pos_is_known (torch.Tensor): (N,) mask for known residues. + """ + + N = atom_37.shape[0] + assert atom_37.shape == (N, 37, 3) + + # Get vectors to Carbonly from Carbon alpha and N of next residue. (N-1, 3) + # Note that the (N,) ordering is from N-terminal to C-terminal. + + # Calpha to carbonyl both in the current frame. 
+ calpha_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[:-1, 1, :]) / ( + torch.norm(atom_37[:-1, 2, :] - atom_37[:-1, 1, :], keepdim=True, dim=1) + 1e-7 + ) + # For masked positions, they are all 0 and so we add 1e-7 to avoid division by 0. + # The positions are in Angstroms and so are on the order ~1 so 1e-7 is an insignificant change. + + # Nitrogen of the next frame to carbonyl of the current frame. + nitrogen_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[1:, 0, :]) / ( + torch.norm(atom_37[:-1, 2, :] - atom_37[1:, 0, :], keepdim=True, dim=1) + 1e-7 + ) + + carbonyl_to_oxygen: torch.Tensor = calpha_to_carbonyl + nitrogen_to_carbonyl # (N-1, 3) + carbonyl_to_oxygen = carbonyl_to_oxygen / ( + torch.norm(carbonyl_to_oxygen, dim=1, keepdim=True) + 1e-7 + ) + + atom_37[:-1, 4, :] = atom_37[:-1, 2, :] + carbonyl_to_oxygen * 1.23 + + # Now we deal with frames for which there is no next frame available. + + # Calpha to carbonyl both in the current frame. (N, 3) + calpha_to_carbonyl_term: torch.Tensor = (atom_37[:, 2, :] - atom_37[:, 1, :]) / ( + torch.norm(atom_37[:, 2, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7 + ) + # Calpha to nitrogen both in the current frame. (N, 3) + calpha_to_nitrogen_term: torch.Tensor = (atom_37[:, 0, :] - atom_37[:, 1, :]) / ( + torch.norm(atom_37[:, 0, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7 + ) + carbonyl_to_oxygen_term: torch.Tensor = ( + calpha_to_carbonyl_term + calpha_to_nitrogen_term + ) # (N, 3) + carbonyl_to_oxygen_term = carbonyl_to_oxygen_term / ( + torch.norm(carbonyl_to_oxygen_term, dim=1, keepdim=True) + 1e-7 + ) + + # Create a mask that is 1 when the next residue is not available either + # due to this frame being the C-terminus or the next residue is not + # known due to pos_is_known being false. 
+ + if pos_is_known is None: + pos_is_known = torch.ones((atom_37.shape[0],), dtype=torch.int64, device=atom_37.device) + + next_res_gone: torch.Tensor = ~pos_is_known.bool() # (N,) + next_res_gone = torch.cat( + [next_res_gone, torch.ones((1,), device=pos_is_known.device).bool()], dim=0 + ) # (N+1, ) + next_res_gone = next_res_gone[1:] # (N,) + + atom_37[next_res_gone, 4, :] = ( + atom_37[next_res_gone, 2, :] + + carbonyl_to_oxygen_term[next_res_gone, :] * 1.23 + ) + + return atom_37 + + +def write_pkl( + save_path: str, pkl_data: Any, create_dir: bool = False, use_torch=False): + """Serialize data into a pickle file.""" + if create_dir: + os.makedirs(os.path.dirname(save_path), exist_ok=True) + if use_torch: + torch.save(pkl_data, save_path, pickle_protocol=pickle.HIGHEST_PROTOCOL) + else: + with open(save_path, 'wb') as handle: + pickle.dump(pkl_data, handle, protocol=pickle.HIGHEST_PROTOCOL) + + +def read_pkl(read_path: str, verbose=True, use_torch=False, map_location=None): + """Read data from a pickle file.""" + try: + if use_torch: + return torch.load(read_path, map_location=map_location) + else: + with open(read_path, 'rb') as handle: + return pickle.load(handle) + except Exception as e: + try: + with open(read_path, 'rb') as handle: + return CPU_Unpickler(handle).load() + except Exception as e2: + if verbose: + print(f'Failed to read {read_path}. 
First error: {e}\n Second error: {e2}') + raise(e) + + +def chain_str_to_int(chain_str: str): + chain_int = 0 + if len(chain_str) == 1: + return CHAIN_TO_INT[chain_str] + for i, chain_char in enumerate(chain_str): + chain_int += CHAIN_TO_INT[chain_char] + (i * len(ALPHANUMERIC)) + return chain_int + + +def parse_chain_feats(chain_feats, scale_factor=1.): + ca_idx = residue_constants.atom_order['CA'] + chain_feats['bb_mask'] = chain_feats['atom_mask'][:, ca_idx] + bb_pos = chain_feats['atom_positions'][:, ca_idx] + bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['bb_mask']) + 1e-5) + centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :] + scaled_pos = centered_pos / scale_factor + chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None] + chain_feats['bb_positions'] = chain_feats['atom_positions'][:, ca_idx] + return chain_feats + + +def concat_np_features( + np_dicts: List[Dict[str, np.ndarray]], add_batch_dim: bool): + """Performs a nested concatenation of feature dicts. + + Args: + np_dicts: list of dicts with the same structure. + Each dict must have the same keys and numpy arrays as the values. + add_batch_dim: whether to add a batch dimension to each feature. + + Returns: + A single dict with all the features concatenated. + """ + combined_dict = collections.defaultdict(list) + for chain_dict in np_dicts: + for feat_name, feat_val in chain_dict.items(): + if add_batch_dim: + feat_val = feat_val[None] + combined_dict[feat_name].append(feat_val) + # Concatenate each feature + for feat_name, feat_vals in combined_dict.items(): + combined_dict[feat_name] = np.concatenate(feat_vals, axis=0) + return combined_dict + + +def center_zero(pos: torch.Tensor, batch_indexes: torch.LongTensor) -> torch.Tensor: + """ + Move the molecule center to zero for sparse position tensors. + + Args: + pos: [N, 3] batch positions of atoms in the molecule in sparse batch format. 
+ batch_indexes: [N] batch index for each atom in sparse batch format. + + Returns: + pos: [N, 3] zero-centered batch positions of atoms in the molecule in sparse batch format. + """ + assert len(pos.shape) == 2 and pos.shape[-1] == 3, "pos must have shape [N, 3]" + + means = scatter(pos, batch_indexes, dim=0, reduce="mean") + return pos - means[batch_indexes] + + +@torch.no_grad() +def align_structures( + batch_positions: torch.Tensor, + batch_indices: torch.Tensor, + reference_positions: torch.Tensor, + broadcast_reference: bool = False, +): + """ + Align structures in a ChemGraph batch to a reference, e.g. for RMSD computation. This uses the + sparse formulation of pytorch geometric. If the ChemGraph is composed of a single system, then + the reference can be given as a single structure and broadcasted. Returns the structure + coordinates shifted to the geometric center and the batch structures rotated to match the + reference structures. Uses the Kabsch algorithm (see e.g. [kabsch_align1]_). No permutation of + atoms is carried out. + + Args: + batch_positions (Tensor): Batch of structures (e.g. from ChemGraph) which should be aligned + to a reference. + batch_indices (Tensor): Index tensor mapping each node / atom in batch to the respective + system (e.g. batch attribute of ChemGraph batch). + reference_positions (Tensor): Reference structure. Can either be a batch of structures or a + single structure. In the second case, broadcasting is possible if the input batch is + composed exclusively of this structure. + broadcast_reference (bool, optional): If reference batch contains only a single structure, + broadcast this structure to match the ChemGraph batch. Defaults to False. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tensors containing the centered positions of batch + structures rotated into the reference and the centered reference batch. + + References + ---------- + .. 
[kabsch_align1] Lawrence, Bernal, Witzgall: + A purely algebraic justification of the Kabsch-Umeyama algorithm. + Journal of research of the National Institute of Standards and Technology, 124, 1. 2019. + """ + # Minimize || Q @ R.T - P ||, which is the same as || Q - P @ R || + # batch_positions -> P [BN x 3] + # reference_positions -> Q [B / BN x 3] + + if batch_positions.shape[0] != reference_positions.shape[0]: + if broadcast_reference: + # Get number of systems in batch and broadcast reference structure. + # This assumes, all systems in the current batch correspond to the reference system. + # Typically always the case during evaluation. + num_molecules = int(torch.max(batch_indices) + 1) + reference_positions = reference_positions.repeat(num_molecules, 1) + else: + raise ValueError("Mismatch in batch dimensions.") + + # Center structures at origin (takes care of translation alignment) + batch_positions = center_zero(batch_positions, batch_indices) + reference_positions = center_zero(reference_positions, batch_indices) + + # Compute covariance matrix for optimal rotation (Q.T @ P) -> [B x 3 x 3]. + cov = scatter_add( + batch_positions[:, None, :] * reference_positions[:, :, None], batch_indices, dim=0 + ) + + # Perform singular value decomposition. (all [B x 3 x 3]) + u, _, v_t = torch.linalg.svd(cov) + # Convenience transposes. + u_t = u.transpose(1, 2) + v = v_t.transpose(1, 2) + + # Compute rotation matrix correction for ensuring right-handed coordinate system + # For comparison with other sources: det(AB) = det(A)*det(B) and det(A) = det(A.T) + sign_correction = torch.sign(torch.linalg.det(torch.bmm(v, u_t))) + # Correct transpose of U: diag(1, 1, sign_correction) @ U.T + u_t[:, 2, :] = u_t[:, 2, :] * sign_correction[:, None] + + # Compute optimal rotation matrix (R = V @ diag(1, 1, sign_correction) @ U.T). 
+ rotation_matrices = torch.bmm(v, u_t) + + # print('batch_positions=',batch_positions.dtype) + # print('rotation_matrices=',rotation_matrices.dtype) + rotation_matrices = rotation_matrices.type(batch_positions.dtype) + + # Rotate batch positions P to optimal alignment with Q (P @ R) + batch_positions_rotated = torch.bmm( + batch_positions[:, None, :], + rotation_matrices[batch_indices], + ).squeeze(1) + + return batch_positions_rotated, reference_positions, rotation_matrices + + +def parse_pdb_feats( + pdb_name: str, + pdb_path: str, + scale_factor=1., + # TODO: Make the default behaviour read all chains. + chain_id='A', + ): + """ + Args: + pdb_name: name of PDB to parse. + pdb_path: path to PDB file to read. + scale_factor: factor to scale atom positions. + mean_center: whether to mean center atom positions. + Returns: + Dict with CHAIN_FEATS features extracted from PDB with specified + preprocessing. + """ + parser = PDB.PDBParser(QUIET=True) + structure = parser.get_structure(pdb_name, pdb_path) + struct_chains = { + chain.id: chain + for chain in structure.get_chains()} + + def _process_chain_id(x): + chain_prot = process_chain(struct_chains[x], x) + chain_dict = dataclasses.asdict(chain_prot) + + # Process features + feat_dict = {x: chain_dict[x] for x in CHAIN_FEATS} + return parse_chain_feats( + feat_dict, scale_factor=scale_factor) + + if isinstance(chain_id, str): + return _process_chain_id(chain_id) + elif isinstance(chain_id, list): + return { + x: _process_chain_id(x) for x in chain_id + } + elif chain_id is None: + return { + x: _process_chain_id(x) for x in struct_chains + } + else: + raise ValueError(f'Unrecognized chain list {chain_id}') + +def rigid_transform_3D(A, B, verbose=False): + # Transforms A to look like B + # https://github.com/nghiaho12/rigid_transform_3D + assert A.shape == B.shape + A = A.T + B = B.T + + num_rows, num_cols = A.shape + if num_rows != 3: + raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}") + + 
num_rows, num_cols = B.shape + if num_rows != 3: + raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}") + + # find mean column wise + centroid_A = np.mean(A, axis=1) + centroid_B = np.mean(B, axis=1) + + # ensure centroids are 3x1 + centroid_A = centroid_A.reshape(-1, 1) + centroid_B = centroid_B.reshape(-1, 1) + + # subtract mean + Am = A - centroid_A + Bm = B - centroid_B + + H = Am @ np.transpose(Bm) + + # sanity check + #if linalg.matrix_rank(H) < 3: + # raise ValueError("rank of H = {}, expecting 3".format(linalg.matrix_rank(H))) + + # find rotation + U, S, Vt = np.linalg.svd(H) + R = Vt.T @ U.T + + # special reflection case + reflection_detected = False + if np.linalg.det(R) < 0: + if verbose: + print("det(R) < R, reflection detected!, correcting for it ...") + Vt[2,:] *= -1 + R = Vt.T @ U.T + reflection_detected = True + + t = -R @ centroid_A + centroid_B + optimal_A = R @ A + t + + return optimal_A.T, R, t, reflection_detected + +def process_chain(chain: Chain, chain_id: str) -> Protein: + """Convert a PDB chain object into a AlphaFold Protein instance. + + Forked from alphafold.common.protein.from_pdb_string + + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + + Took out lines 94-97 which don't allow insertions in the PDB. + Sabdab uses insertions for the chothia numbering so we need to allow them. + + Took out lines 110-112 since that would mess up CDR numbering. + + Args: + chain: Instance of Biopython's chain class. + + Returns: + Protein object with protein features. 
+ """ + atom_positions = [] + aatype = [] + atom_mask = [] + residue_index = [] + b_factors = [] + chain_ids = [] + for res in chain: + res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num) + pos = np.zeros((residue_constants.atom_type_num, 3)) + mask = np.zeros((residue_constants.atom_type_num,)) + res_b_factors = np.zeros((residue_constants.atom_type_num,)) + for atom in res: + if atom.name not in residue_constants.atom_types: + continue + pos[residue_constants.atom_order[atom.name]] = atom.coord + mask[residue_constants.atom_order[atom.name]] = 1. + res_b_factors[residue_constants.atom_order[atom.name] + ] = atom.bfactor + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + b_factors.append(res_b_factors) + chain_ids.append(chain_id) + + return Protein( + atom_positions=np.array(atom_positions), + atom_mask=np.array(atom_mask), + aatype=np.array(aatype), + residue_index=np.array(residue_index), + chain_index=np.array(chain_ids), + b_factors=np.array(b_factors)) diff --git a/dataset/ATLAS_filename.txt b/dataset/ATLAS_filename.txt new file mode 100644 index 0000000000000000000000000000000000000000..a7c90622cc0b78be36866c41896ed2cc048a457e --- /dev/null +++ b/dataset/ATLAS_filename.txt @@ -0,0 +1,1390 @@ +1k5n_A +5e3e_A +3ilw_A +3bl9_A +6o2v_A +1zd7_B +2lis_A +3qra_A +7ead_A +3bn0_A +1fx4_A +4byz_A +4a56_A +1d3y_B +4oe9_B +5kko_D +4bt9_B +3pkz_A +2nwv_A +4tsh_A +5kdw_B +3rv1_B +4goq_B +1f20_A +3nrw_A +4g75_A +1gwe_A +6irx_A +3g5t_A +6uof_A +2pmb_C +3kby_A +1w2w_A +4geh_A +6lus_A +5kdv_A +2py5_A +3wol_A +3cw3_A +2hqk_A +3ll7_A +5ezu_A +1f3u_G +3a7l_A +2i5v_O +1tez_B +1jyo_F +1jyo_B +1o1x_A +6eof_A +2f5t_X +1ys1_X +3e7h_A +1nns_A +3ipf_A +6qj0_A +4ezi_A +4u9b_A +2v6a_O +1q74_C +2jdd_A +1pv5_A +4zya_A +2ii1_D +3h1d_A +3cjs_B +1r8m_E +1lln_A +5opz_A +3oyv_A +3jsz_A +2gg2_A +4q7o_B +3vkw_A 
+3hgl_A +1n2m_A +1vk1_A +1vdk_B +2cb5_A +3hms_A +2wnp_F +5x1e_D +3qc7_A +1lc3_A +6j56_A +6fsw_A +1z1l_A +3m6z_A +2fhz_A +5ws1_A +2x5g_A +1rm6_B +3nr1_A +3nuf_A +2xdg_A +3ov5_A +4wyw_A +3do6_A +3g28_A +2wzo_A +3lx5_A +3ooi_A +1btk_A +2gkp_A +2pmk_A +1dd3_B +6cka_B +7ec1_A +1qlm_A +3fpu_A +1msk_A +2qsk_A +3emi_A +2qcu_B +1ux6_A +2hy5_A +1gqv_A +5ol4_A +2rkq_A +2vi1_A +3po8_A +1egw_B +2d5b_A +2x96_A +1f9p_A +1qw2_A +2ot9_A +5gv0_A +6hj6_A +5kts_A +2wkx_A +6dgk_B +2x49_A +2db7_A +1sf9_A +3dan_A +6qfk_A +6fc0_B +2qhq_B +4f87_A +3se9_G +1bq8_A +1tg7_A +1dfu_P +3onj_A +1t9i_B +2ygs_A +4aqr_D +4adn_B +3g3s_B +6xds_A +6dlm_A +1tgr_A +4dol_A +3e0z_A +2e5y_B +4p7x_A +1pch_A +2vrs_C +3it4_B +1j0p_A +4ywk_A +6xrx_A +6q9c_B +1tdz_A +2wr8_A +1feh_A +2phn_A +2ox7_B +3gqv_A +1n8u_A +4ayg_B +3jub_A +6rrv_A +5o9e_A +3v0r_A +2qnl_A +1xg0_A +1m7v_A +1k7c_A +2wn3_C +1ouw_C +4qtq_A +2prv_A +3eoj_A +1ear_A +2gkr_I +1b2s_F +1jke_C +2fos_A +7lao_A +2w0i_A +1nl1_A +3lur_B +1ijx_C +4mzv_A +4yxp_A +6l4l_A +1o98_A +1sx3_A +4m0q_A +2fj8_A +1xhk_B +1a62_A +2d28_C +2v9m_A +3l1e_A +1yoz_B +4mmj_A +7asg_A +2p54_A +1g38_A +4epi_A +1ez3_B +3im1_A +1u55_A +2dtr_A +3c2q_B +4p3x_A +2p6v_A +6mdw_A +1ewf_A +2i9x_A +2fb5_A +5lpa_A +6bwq_A +1io1_A +1rlk_A +4obi_A +1uwk_B +4n67_A +1r5l_A +6kty_A +1u84_A +3vgp_A +2o6x_A +6vjg_A +4h9n_C +4tkx_L +1mv8_C +1xvg_E +1zdy_A +5t76_A +2p12_A +3d7a_B +6sms_A +4xqm_A +1m5q_I +1y7t_B +4lq4_A +3fm2_A +1jq5_A +1lr0_A +1uoy_A +6l3r_E +1xo7_B +3o0g_D +3lvu_D +1qcs_A +2zca_B +3gnu_P +1juv_A +1qdd_A +2i7c_C +5el3_B +3e4w_B +4ew5_A +4cv4_A +5op0_B +1jo0_A +5ygh_A +3bpt_A +3vnb_A +5b1n_A +2qev_A +2y44_A +7qsu_A +3fhf_A +7p46_A +5ydn_A +3ncl_A +7e2s_A +2nuk_A +3h5q_A +1o75_A +3d2q_A +2y1q_A +1wui_S +3lwx_A +5cfl_A +1gqe_A +3nrl_B +3nke_B +3q1c_A +4ddp_A +2y39_A +4jja_A +1ctf_A +3dgp_B +2w39_A +1x6j_A +1wx1_B +1h8e_G +2i2f_A +1j8e_A +2jhf_B +6avj_B +3omd_B +2in8_A +2hgx_B +2xvs_A +5w82_E +4rgi_A +1dow_A +4za9_A +3l4h_A +5b3p_A +2q9r_A +3mx7_A +6pxz_B +3q1n_A +1o0s_A +1ojh_C +2ebb_A 
+3g5k_D +2rff_A +2gec_A +3mmy_D +4xb5_A +3h93_A +2vfx_C +2vh3_A +3zr8_X +1j8r_A +1rqj_A +3v68_A +5hqh_A +3ii2_A +5h7w_B +3ip4_A +3zzo_A +3ghd_A +6cb7_A +3aa6_B +2b0a_A +5yrv_I +2j6a_A +2e7z_A +3f0p_A +5osw_A +3o3x_A +4jo7_A +4omf_B +5a5y_A +1jwq_A +2obl_A +3a3d_B +1kpt_A +1md6_A +4e5r_A +3owt_B +2znr_A +4cjn_A +4ndt_A +4ghn_A +2w6k_A +1rk8_B +6ovk_R +5h6x_A +5ul3_A +3jv1_A +5bs1_B +3ajf_A +2p84_A +2pcn_A +3zfn_A +2y20_A +6ndw_B +5z51_A +2w9y_A +3fh2_A +1i2t_A +1vzy_B +5kwb_A +3eto_A +6a9a_A +2wya_A +5erq_A +3l4p_A +6pce_B +4v23_A +3kbb_A +1w53_A +7p41_D +1hbn_E +1dzf_A +1h16_A +6h86_A +1hq0_A +3czz_B +2v05_A +1im5_A +1jif_A +2gpi_A +1svb_A +1pcf_C +2glz_B +4xdu_A +4mp8_A +5a67_A +5b5i_B +7jfl_C +6anz_A +6iah_A +1fs1_A +2bti_A +5iz3_B +1gte_D +3a1g_B +3ig9_C +3vej_B +5dbl_A +1zsw_A +1k7j_A +2j7z_A +4zi3_D +3rt4_A +3zoq_C +6atg_C +3fn2_B +1i4m_A +1ig0_B +2bw3_B +4f98_A +2x9o_A +2qdx_A +5imu_A +3sfv_B +4jqu_B +1krn_A +4ehx_A +6y2x_A +1z7k_B +7nmq_A +5ef9_A +4rp5_A +4qx8_B +1msc_A +1l2w_I +5w4a_B +2gwn_A +6xb3_H +1zjc_A +4zav_A +2nml_A +2d1l_A +4a3p_A +5n0s_A +1u09_A +2bo1_A +1d2t_A +3ohg_A +2p58_A +2yp8_A +4oj6_C +1hp1_A +2gc4_L +6jwh_A +2wsi_A +1z21_A +2waw_A +2bk9_A +4xb6_F +3rmq_A +1xkn_A +1b2s_E +3wa1_A +3cyi_A +1nij_A +2ahn_A +3ot9_B +5x1u_B +2ed6_A +2gg6_A +1dlj_A +2ad6_A +3mil_A +3lbl_A +2fip_A +2jek_A +6l4p_B +2vfk_A +1g9g_A +3x0g_A +1ng6_A +1h2e_A +5n6y_C +1tov_A +2cxi_B +2p2v_A +1q2h_A +4erv_A +1ifg_A +6jpt_A +2xjq_A +2b9d_A +1t61_E +1uz3_A +3bk5_A +5j3t_B +4ykd_A +5zmo_A +5hee_B +5ngd_B +1qwz_A +2pst_X +1njh_A +4zds_B +6bk4_A +1qb7_A +3ib5_A +1ypy_A +5g68_A +2pqx_A +3luu_A +3p5p_A +2au3_A +3wkx_A +2v0u_A +1ka1_A +1hc9_B +5vz0_A +1pby_C +3tgt_A +6b1z_A +3e8t_A +3rxy_A +2ims_A +1g6g_A +7a66_B +3obl_B +3ebh_A +3twz_A +2zzd_E +2ov9_C +5cof_A +5eg7_A +1wpb_A +1rss_A +3bjd_B +6okd_C +3boe_A +2inc_A +1v7r_A +1fn9_A +3on9_A +1koe_A +3fke_A +3ino_A +2bln_B +2x5r_A +6eu8_A +1nf8_A +1iom_A +1zn6_A +2cwy_A +6mbg_A +4z9p_A +16pk_A +3qc5_X +6bm5_A +3l51_B +4m37_A +5i8j_A 
+6f45_D +5um3_A +6in7_A +2xrh_A +3pvh_A +3fkc_A +3b49_A +1u8v_C +1dtd_B +2a6s_B +3ueb_E +2fma_A +3lpc_A +7onn_A +2ehq_A +6fub_B +3b4q_B +3bzn_A +1qxo_A +3m66_A +3mhs_B +4e1p_A +4g6d_B +1y0n_A +6j33_B +1v7z_F +4qb0_A +1xsz_A +5mux_B +5l87_A +3tvj_I +5xbc_B +1ois_A +6ono_C +1d4t_A +2v4n_A +2pfx_A +1bkp_A +3n4j_A +1pp0_C +2ad6_D +4xo1_A +3o04_A +5a0n_A +4gn0_A +2qmj_A +3jtz_A +3cy4_A +2nnu_A +3lq9_A +5cyz_C +1ocy_A +6d7y_A +4gvq_C +1wq8_A +1k8k_E +2fef_B +1t7v_A +2y3w_A +2fp1_A +1q8f_A +4exo_A +3mcb_B +4myy_A +1fd3_A +1lsl_A +3da8_A +2hba_A +1m9z_A +3im3_A +3t7h_B +3shg_B +2pq8_A +3enu_A +2v27_A +1bgf_A +1ryl_A +1nty_A +3cla_A +4fc2_B +1j2l_A +1u6t_A +1oh9_A +3m4r_A +5ij2_B +4ngd_A +3l32_A +5ce9_A +2wcu_A +1fm4_A +5vap_A +2g7o_A +5jx1_A +2pu9_A +1c96_A +2eo4_A +6odd_B +1u6r_B +2z0x_A +2v4b_B +2jft_A +3hol_A +3l9a_X +1vbw_A +4ljo_A +1tag_A +4xb6_H +3imk_A +2o7o_A +1tzw_A +2yxn_A +2xol_B +2wpq_C +2hq4_A +4y7s_A +1coj_A +1kgq_A +6gfx_C +2gwm_A +5dpo_A +1dvo_A +2fe5_A +1vr8_A +1rrm_A +3a39_A +2p9x_D +2dka_A +5ert_A +3it4_C +3bhg_A +1m8s_A +3rlk_A +2dt8_A +6p5x_B +1jnr_C +3u5v_A +3ipj_A +3zrg_B +1hye_A +3lyf_D +2rkv_A +2yjg_A +3g98_B +2pyq_B +4hwx_A +3a1u_A +3ush_B +3st1_A +3o79_A +5bnh_A +1llf_B +6tgk_C +2bjq_A +3kz7_A +3fbl_A +2nsa_A +1dwk_A +2at8_X +2qia_A +4iaj_B +2hu9_A +2e6x_A +6a02_A +2yvq_A +3ju3_A +4alz_A +4dja_A +3ic3_B +4qxo_A +1r9f_A +2yzt_A +3dff_A +1wd5_A +3e9v_A +1s9u_A +1wui_L +2ozl_A +5wvo_C +4qbu_A +5n0n_A +2vmc_A +6c0h_A +7dmn_A +3gs9_A +4cbe_A +1z0b_A +3i94_A +1dpt_A +6dnm_A +5pal_A +2gs5_A +2oeb_A +1nnh_A +4v4e_M +4zac_D +1ef1_D +1huf_A +2ixm_A +7lp1_A +3mol_B +2cx7_A +3nyy_A +3ti2_A +2xmx_A +2ppv_A +2q7x_A +1v0w_A +6l34_A +5kvk_A +5hz7_A +5h1n_B +1utg_A +1m15_A +3igz_B +3f4m_A +7ned_A +2r2d_A +7s86_A +4na5_A +4jnu_B +2qtd_A +3vc8_B +2amh_A +5l0s_A +2rij_A +2onf_B +2ie6_A +3oi8_A +1bx7_A +2rjw_A +4z8q_B +6l8s_A +2gc4_B +2gno_A +3zdm_A +3zl8_A +4p5a_A +5dmu_A +2j6b_A +7bwf_B +1c1k_A +4xgl_A +4k2m_A +3apr_E +2z0j_E +1esw_A +1zxx_A +1b8d_K +3msw_A +5z42_A 
+1qau_A +1wbe_A +7aex_A +1ntv_A +2pv4_A +2gh0_D +4wzg_A +1yqe_A +3mm6_A +6fy5_A +1p5g_X +3kde_C +1whi_A +4wrp_A +2hnu_A +6as3_A +5w5c_A +1t8k_A +3c1q_B +2w0g_A +5njc_D +2qjz_A +4mb5_A +3dnh_B +1c52_A +4fhr_B +6d7y_B +6e7e_A +2o4t_A +2hhv_A +1xeo_A +1z2u_A +2b97_A +5exe_C +3a2q_A +1yoe_A +3qfl_A +3ijw_A +2wbf_X +6bn0_A +2ycl_A +4ys0_A +3bec_A +5wfy_A +2bmo_B +1s9r_A +7k7p_B +1ikt_A +2jh3_B +4v1g_B +5oh5_A +3eye_A +2f60_K +1qnx_A +4h4n_A +3h0n_A +4o66_D +1ig3_A +1jhs_A +3agn_A +7buy_A +5gw6_A +2b1y_A +5v0m_A +1o4k_A +6yhu_B +1r0k_A +3eqx_A +2end_A +1dj0_B +3mfd_A +2ia1_A +4hfq_B +6hem_A +2igp_A +2bbr_A +1qhv_A +3gmv_X +4qn8_B +2dlb_A +5z1n_A +4ke2_C +2op6_A +2x5h_B +3o6p_A +2wqi_D +1v9m_A +5vu3_B +1t95_A +3m7d_A +2xz3_A +2bfw_A +6c62_C +2wb6_A +4wlr_B +2ovj_A +5njl_A +3q0i_A +3nl9_A +1n7z_A +3fcn_A +2dwk_A +3m5q_A +1lj5_A +1gnt_A +5jbg_A +5m45_K +5ori_A +2xf2_A +3u9g_A +2yvo_A +3ldu_A +2cyj_A +1dk8_A +1yu5_X +2zwd_A +2pa6_B +2p6h_B +4cfi_A +4j4r_A +3lk7_A +5iff_A +4oie_A +2cfe_A +7fd1_A +1hw1_B +1v4p_C +2fzp_A +4f55_A +5zlq_A +5wx9_A +2vy8_A +2c9i_D +2hxm_A +2au7_A +4phr_A +5ggl_A +6h49_A +2hng_A +3nqi_A +1ail_A +1vbi_A +3mqz_A +4xb6_E +5gof_A +2vlf_B +1j77_A +4fch_B +1vyi_A +7aqx_A +2rji_A +1mlw_A +1g2r_A +2qq4_B +2uxy_A +4z3x_A +4qmd_A +3fp5_A +2rci_A +2b0t_A +2y1y_A +3h8t_A +7c45_A +1upt_D +5ltj_A +2xeu_A +1dci_C +3lyy_A +4lim_A +3d3y_A +5eqz_A +4koq_A +4q0y_A +2gke_A +3jz9_A +1mk0_A +4b6i_D +1fxo_G +2qsa_A +3klr_A +1ab1_A +3msu_B +3fo8_D +4iu3_B +1u07_A +3hvv_A +4ubp_C +4s39_A +6gus_A +4d71_A +5ftw_A +3lgb_B +2pu3_A +3fyb_A +2oxo_A +3pf6_A +5klh_A +6ao8_A +5g38_A +3eh1_A +3tbn_A +4dlh_B +6q9c_A +4dok_A +5o5s_A +4lty_A +3a1g_A +5c8q_B +3d1g_A +5hkx_A +5uki_A +7n0j_E +6o6y_A +2bf6_A +2b3y_A +3v3l_A +3ga3_A +6zsl_B +1xta_A +3ni2_A +2w56_A +4ftf_A +2ra8_A +3bwu_D +7rm7_A +3g06_A +4rsw_B +3n4i_B +3ikw_A +1ssq_D +1oao_C +3gqh_A +1zpe_A +4rr9_A +5by8_B +5dje_B +3dpg_B +1l5o_A +5m1m_A +1yoc_B +3hfw_A +4af1_A +2fb6_A +3mmh_B +2xpp_A +4gv2_A +1ptq_A +1whz_A +1j5u_A +5j47_A 
+2hyv_A +3e2d_A +1ojq_A +4cdp_A +2imf_A +3rjs_A +4pz1_A +2oxl_B +2bko_A +3c8w_C +3p3g_A +1lri_A +2ex2_A +2rdi_A +2qwo_B +2vnl_A +4a1k_A +3lpe_B +2rbk_A +3ho6_A +6ypi_A +1fcz_A +2i5u_A +3ofg_B +5nir_A +3b9q_A +1ugp_A +2po4_A +3l84_A +5txr_B +3c0f_B +2psp_B +5w2f_A +2es9_A +3lae_A +2i9w_A +3vth_A +2oy9_B +2ysk_A +1aol_A +1by2_A +1sei_A +1wlg_B +1ijt_A +2ofz_A +2dt4_A +2it2_B +4bgp_A +1vcl_A +2aj7_A +4yal_A +3frr_A +6ro6_A +2y7b_A +3ezu_A +1l9l_A +2oya_B +7mf4_A +2e7u_A +2qp2_A +4gip_A +1wf3_A +3bzl_D +2ejn_B +3c05_A +3k5j_A +1fk5_A +4lpq_A +7jrq_A +5v7m_A +2xz2_A +5e7x_A +3dso_A +3fpn_A +2iay_A +1u7k_D +3ka5_A +3g0m_A +1ok0_A +2pw8_I +2h1t_B +4o6q_A +5hub_A +3pes_A +5en2_C +2p3p_B +5xct_B +1lfp_A +4jmu_A +5v03_B +2vg3_C +1ah7_A +7wab_A +3hlz_B +1wna_A +2vkp_B +1vkm_C +1doi_A +2b7r_A +3djl_A +3c9p_A +5znj_A +5j2l_B +3dza_A +2vac_A +3hdj_B +3apa_A +3bqk_A +4ksn_C +2z6r_A +4ebg_A +2pbk_B +1h02_B +1wq6_A +3t4r_A +5btz_A +1ql0_B +2h2z_A +1x6i_A +1ve0_A +4xb4_A +1yu0_A +1vcc_A +2pag_A +3kd4_A +1zt5_A +2erl_A +2r31_A +5mdu_A +4y93_A +1lwd_B +4hfv_A +1f5n_A +3las_B +2z4u_A +6pnv_A +2o74_F +3vor_A +2fi9_A +3pzw_A +2xse_A +3bpj_B +2e8e_A +1r6w_A +5jca_S +1vaj_A +1uha_A +6rwt_A +2ftx_A +2ccq_A +2w8x_A +2nqw_A +6oz1_A +3h1w_A +4csb_A +3f7w_A +5ja5_A +3dai_A +3mwz_A +1vlb_A +5uc0_A +2gib_B +3let_B +1kop_A +2oez_B +2nug_B +1h99_A +4ue8_B +3kz5_E +6nl2_A +2igi_A +2vzc_A +3wx4_A +5naz_A +2vm5_A +1y6i_A +1rgz_A +3nis_B +2qud_A +3gkj_A +6e33_A +4whe_A +1g7s_A +1iuq_A +4gk9_A +5iqj_C +4ytw_C +5exh_C +5j72_A +2id6_A +1xkg_A +5muy_B +5noh_A +3wvz_A +2jlq_A +3a0y_B +1tt8_A +4h9m_A +3ju4_A +5ok6_A +6p5h_A +3hxw_A +2wfr_A +2w86_A +2pnw_A +1cpq_A +1kew_A +1gpj_A +3t92_A +3ehg_A +1uxy_A +6q10_A +1xty_B +3mxz_A +1cuo_A +5jak_A +3wur_B +3wwh_A +3d4i_A +4m62_S +3rkg_A +3ol3_A +1puc_A +3p4h_A +4lzh_A +5t2p_B +1wpu_A +6sup_A +1knt_A +5h2d_A +6jv8_A +3crv_A +1ix9_A +4wt3_A +6lrd_A +6tly_A +2rku_A +6idx_A +3m7a_A +4oxw_A +2wlt_A +2zfz_D +1sdi_A +3pt1_A +3jsy_B +2g64_A +2vri_A +4gsj_A +4rnz_A +2cg7_A 
+2xf7_A +1chd_A +3kvd_D +4j8c_A +4kct_B +4u5h_G +1qgi_A +2y4x_B +4o87_A +2aib_A +3f0o_B +3b4w_A +2q2g_A +5ii7_A +1l2p_A +1zos_C +6e5y_A +3h2d_B +3nci_A +3kef_B +7la6_A +3txs_C +1wqj_B +3lti_A +3gwi_A +1xwl_A +3bqp_B +1xmk_A +1nh2_B +4zmk_A +4cv7_A +1ru4_A +5k6d_B +2jc5_A +3dha_A +2fyg_A +5x1e_B +4ct7_A +2vcc_A +3lkt_Q +2f6e_A +1ikp_A +4kk7_A +1b25_D +1bxy_A +2z7f_I +3oxp_A +5e5y_C +4o6g_A +4eo1_A +3l2a_A +4npd_A +3iiu_M +2xu8_B +1rl6_A +3t7z_A +4b2f_A +4ea9_A +4xpz_A +3dd6_A +1xkr_A +2abk_A +4jgi_A +5b1q_B +5w53_B +6crk_G \ No newline at end of file diff --git a/dataset/download.py b/dataset/download.py new file mode 100644 index 0000000000000000000000000000000000000000..1fcb989680da18b3cdf9205f149b01834851a73a --- /dev/null +++ b/dataset/download.py @@ -0,0 +1,38 @@ +import os +import argparse +import multiprocessing as mp + + +def fn(para): + file, url_base, data_dir = para + url = url_base + file + output_filename = file+".zip" + output_path = os.path.join(data_dir, output_filename) + unzip_path = os.path.join(data_dir, file) + os.system(f"curl -X GET {url} -H accept: */* --output {output_path}") + os.system(f"unzip {output_path} -d {unzip_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--filename_file", type=str, default="./dataset/ATLAS_filename.txt") + parser.add_argument("--url_base", type=str, default="https://www.dsimb.inserm.fr/ATLAS/api/ATLAS/analysis/") + parser.add_argument("--output_dir", type=str, default="./dataset/ATLAS_test") + + args = parser.parse_args() + + num_processes = 48 + + os.makedirs(args.output_dir, exist_ok=True) + + with open(args.filename_file,'r+') as f: + file_cont = f.read() + file_list = file_cont.split("\n") + para_list = [(file, args.url_base, args.output_dir) for file in file_list] + + with mp.Pool(num_processes) as pool: + _ = pool.map(fn, para_list) + print("finished") + + os.system(f"cp -rf {args.filename_file} {args.output_dir}") diff --git 
a/dataset/traj_analyse_select.py b/dataset/traj_analyse_select.py new file mode 100644 index 0000000000000000000000000000000000000000..a91fdd302784ecb307ceb44239d58d19dffa072c --- /dev/null +++ b/dataset/traj_analyse_select.py @@ -0,0 +1,161 @@ +import os +import argparse +import shutil + +import pandas as pd +import MDAnalysis as mda +import numpy as np +import matplotlib.pyplot as plt +import multiprocessing as mp + +from scipy.stats import gaussian_kde +from MDAnalysis.analysis import align + + + +def cal_energy(para1): + file_md, dirpath = para1 + mdpath = os.path.join(dirpath, file_md) + filename = file_md + + k=2.32*1e-4 # unit(eV/K) + T=298.15 # unit(K) + + pdb_filepath = os.path.join(mdpath, filename+".pdb") + topology_filepath = os.path.join(mdpath, filename+".pdb") + + u_ref = mda.Universe(pdb_filepath) + protein_ref = u_ref.select_atoms('protein') + bb_atom_ref = protein_ref.select_atoms('name CA or name C or name N') + + info = { + 'rad_gyr': [], + 'rmsd_ref':[], + 'traj_filename':[], + 'energy':[], + } + + for xtc_idx in range(1,4): + trajectory_filepath = os.path.join(mdpath,filename+"_R"+str(xtc_idx)+".xtc") + + u = mda.Universe(topology_filepath, trajectory_filepath) + + protein = u.select_atoms('protein') + bb_atom = protein.select_atoms('name CA or name C or name N') + + # CA_atoms = u.select_atoms('name CA') + # bb_atoms = u.select_atoms('backbone') + + count = 0 + # for ts in u.trajectory: + for _ in u.trajectory: + count += 1 + + rad_gyr = bb_atom.radius_of_gyration() + rmsd_ref = align.alignto(bb_atom, bb_atom_ref, select='all', match_atoms=False)[-1] + info['rad_gyr'].append(rad_gyr) + info['rmsd_ref'].append(rmsd_ref) + + traj_filename = filename + '_R' + str(xtc_idx) + '_'+str(count)+".pdb" + info['traj_filename'].append(traj_filename) + print(traj_filename) + protein.write(os.path.join(mdpath, traj_filename)) + + + info_array = np.stack([info['rad_gyr'],info['rmsd_ref']],axis=0) # (2,2500) + kde = gaussian_kde(info_array) + density = 
kde(info_array) # (2500,) + G = k*T*np.log(np.max(density)/density) # (2500,) + G = (G-np.min(G))/(np.max(G)-np.min(G)) + + info['energy'] += G.tolist() + + out_total = pd.DataFrame(info) + x, y = np.meshgrid(np.linspace(min(out_total['rad_gyr'])-0.25, max(out_total['rad_gyr'])+0.25, 200), + np.linspace(min(out_total['rmsd_ref'])-0.25, max(out_total['rmsd_ref'])+0.25, 200)) + grid_coordinates = np.vstack([x.ravel(), y.ravel()]) + density_values = kde(grid_coordinates) + # 将密度值变形为与网格坐标相同的形状 + density_map = density_values.reshape(x.shape) + # 绘制高斯核密度估计图 + plt.contourf(x, y, density_map, levels= np.arange(np.max(density_map)/20, np.max(density_map)*1.1, np.max(density_map)/10)) + plt.colorbar() + + plt.savefig(os.path.join(mdpath,"md.png")) + plt.close() + + out_total.to_csv(os.path.join(mdpath,"traj_info.csv"),index=False) + + +def select_str(file, data_dir, output_dir, select_num=100): + info_total = { + 'rad_gyr': [], + 'rmsd_ref': [], + 'traj_filename': [], + 'energy': [], + } + + print(f"Processing {file}") + md_dir = os.path.join(data_dir, file) + md_csv = pd.read_csv(os.path.join(md_dir, 'traj_info.csv')) + md_csv = md_csv.sort_values('energy', ascending=True) + + idx_total = np.linspace(0, len(md_csv) - 1, select_num) + idx_total = (idx_total / idx_total[-1]) ** (1 / 3) * (len(md_csv) - 1) # mapping energy with f(x)=x**(1/3) to get more relatively high-energy structure + idx_total = np.unique(np.round(idx_total).astype(int)) + + for idx in idx_total: + info = md_csv.iloc[idx] + traj_filename = info['traj_filename'] + shutil.copy(os.path.join(md_dir, traj_filename), output_dir) + + info_total['traj_filename'].append(traj_filename) + info_total['energy'].append(info['energy']) + info_total['rad_gyr'].append(info['rad_gyr']) + info_total['rmsd_ref'].append(info['rmsd_ref']) + + return info_total + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dir_path", type=str, default="./dataset/ATLAS") + 
parser.add_argument("--filename", type=str, default="ATLAS_filename.txt") + + parser.add_argument("--select_num", type=int, default=100) + parser.add_argument("--select_dir", type=str, default="./dataset/ATLAS/select") + + args = parser.parse_args() + + num_processes = 48 + + file_txt = os.path.join(args.dir_path, args.filename) + os.makedirs(args.select_dir, exist_ok=True) + + with open(file_txt,'r+') as f: + file_cont = f.read() + file_list = file_cont.split("\n") + + para1_list = [(file, args.dir_path) for file in file_list] + para2_list = [(file, args.dir_path, args.select_dir, args.select_num) for file in file_list] + + info_total_all = { + 'rad_gyr': [], + 'rmsd_ref': [], + 'traj_filename': [], + 'energy': [], + } + + with mp.Pool(num_processes) as pool: + _ = pool.map(cal_energy, para1_list) + results = pool.starmap(select_str, para2_list) + + for result in results: + info_total_all['traj_filename'].extend(result['traj_filename']) + info_total_all['energy'].extend(result['energy']) + info_total_all['rad_gyr'].extend(result['rad_gyr']) + info_total_all['rmsd_ref'].extend(result['rmsd_ref']) + + df = pd.DataFrame(info_total_all) + df.to_csv(os.path.join(args.select_dir, 'traj_info_select.csv'), index=False) \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..596341a039e2dacbe8f5b3db42942328610cf71b --- /dev/null +++ b/environment.yml @@ -0,0 +1,481 @@ +name: P2DFlow +channels: + - pytorch + - nvidia + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge + - conda-forge +dependencies: + - pip + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - anyio=3.7.1=pyhd8ed1ab_0 + - argon2-cffi=21.3.0=pyhd8ed1ab_0 + - argon2-cffi-bindings=21.2.0=py310h5764c6d_3 + - arrow=1.2.3=pyhd8ed1ab_0 + - asttokens=2.2.1=pyhd8ed1ab_0 + - astunparse=1.6.3=pyhd8ed1ab_0 + - async-lru=2.0.4=pyhd8ed1ab_0 + - attrs=23.1.0=pyh71513ae_1 + - babel=2.12.1=pyhd8ed1ab_1 + - 
backcall=0.2.0=pyh9f0ad1d_0 + - backports=1.0=pyhd8ed1ab_3 + - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0 + - beautifulsoup4=4.12.2=pyha770c72_0 + - biopython=1.81=py310h1fa729e_0 + - bleach=6.0.0=pyhd8ed1ab_0 + - blosc=1.21.4=h0f2a231_0 + - boltons=23.1.1=pyhd8ed1ab_0 + - boost=1.78.0=py310hcb52e73_5 + - boost-cpp=1.78.0=h2c5509c_4 + - brotli=1.0.9=h166bdaf_9 + - brotli-bin=1.0.9=h166bdaf_9 + - brotli-python=1.0.9=py310hd8f1fbe_9 + - bzip2=1.0.8=h7f98852_4 + - c-ares=1.19.1=hd590300_0 + - c-blosc2=2.10.2=hb4ffafa_0 + - ca-certificates=2024.2.2=hbcca054_0 + - cached-property=1.5.2=hd8ed1ab_1 + - cached_property=1.5.2=pyha770c72_1 + - cairo=1.18.0=h3faef2a_0 + - certifi=2024.2.2=pyhd8ed1ab_0 + - cffi=1.15.1=py310h255011f_3 + - cftime=1.6.3=py310h1f7b6fc_0 + - chardet=5.2.0=py310hff52083_1 + - charset-normalizer=3.2.0=pyhd8ed1ab_0 + - colorama=0.4.6=pyhd8ed1ab_0 + - comm=0.1.4=pyhd8ed1ab_0 + - conda=23.7.4=py310hff52083_0 + - conda-package-handling=2.2.0=pyh38be061_0 + - conda-package-streaming=0.9.0=pyhd8ed1ab_0 + - contourpy=1.1.0=py310hd41b1e2_0 + - cryptography=41.0.7=py310hb8475ec_1 + - cuda=11.6.0=0 + - cuda-cccl=11.6.55=hf6102b2_0 + - cuda-command-line-tools=11.6.2=0 + - cuda-compiler=11.6.2=0 + - cuda-cudart=11.6.55=he381448_0 + - cuda-cudart-dev=11.6.55=h42ad0f4_0 + - cuda-cuobjdump=11.6.124=h2eeebcb_0 + - cuda-cupti=11.6.124=h86345e5_0 + - cuda-cuxxfilt=11.6.124=hecbf4f6_0 + - cuda-driver-dev=11.6.55=0 + - cuda-gdb=12.0.90=hd47b8d6_0 + - cuda-libraries=11.6.2=0 + - cuda-libraries-dev=11.6.0=0 + - cuda-memcheck=11.8.86=0 + - cuda-nsight=12.0.78=ha770c72_0 + - cuda-nsight-compute=12.2.1=0 + - cuda-nvcc=11.6.124=hbba6d2d_0 + - cuda-nvdisasm=12.0.76=h59595ed_0 + - cuda-nvml-dev=11.6.55=haa9ef22_0 + - cuda-nvprof=12.0.90=h59595ed_0 + - cuda-nvprune=11.6.124=he22ec0a_0 + - cuda-nvrtc=11.6.124=h020bade_0 + - cuda-nvrtc-dev=11.6.124=h249d397_0 + - cuda-nvtx=11.6.124=h0630a44_0 + - cuda-nvvp=12.0.90=h59595ed_0 + - cuda-runtime=11.6.2=0 + - 
cuda-samples=11.6.101=h8efea70_0 + - cuda-sanitizer-api=12.0.90=h59595ed_0 + - cuda-toolkit=11.6.0=0 + - cuda-tools=11.6.0=0 + - cuda-version=12.0=hffde075_2 + - cuda-visual-tools=11.6.0=0 + - cycler=0.11.0=pyhd8ed1ab_0 + - debugpy=1.6.8=py310hc6cd4ac_0 + - decorator=5.1.1=pyhd8ed1ab_0 + - defusedxml=0.7.1=pyhd8ed1ab_0 + - entrypoints=0.4=pyhd8ed1ab_0 + - exceptiongroup=1.1.3=pyhd8ed1ab_0 + - executing=1.2.0=pyhd8ed1ab_0 + - expat=2.6.2=h59595ed_0 + - fasteners=0.17.3=pyhd8ed1ab_0 + - flit-core=3.9.0=pyhd8ed1ab_0 + - fmt=10.1.1=h00ab1b0_1 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + - font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=h77eed37_1 + - fontconfig=2.14.2=h14ed4e7_0 + - fonts-conda-ecosystem=1=0 + - fonts-conda-forge=1=0 + - fonttools=4.42.0=py310h2372a71_0 + - fqdn=1.5.1=pyhd8ed1ab_0 + - freetype=2.12.1=hca18f0e_1 + - freetype-py=2.3.0=pyhd8ed1ab_0 + - gds-tools=1.5.0.59=hcb278e6_0 + - gmp=6.2.1=h58526e2_0 + - greenlet=3.0.3=py310hc6cd4ac_0 + - griddataformats=1.0.2=pyhd8ed1ab_0 + - gsd=3.2.0=py310h1f7b6fc_0 + - h5py=3.9.0=nompi_py310hcca72df_101 + - hdf4=4.2.15=h501b40f_6 + - hdf5=1.14.1=nompi_h4f84152_100 + - icu=73.2=h59595ed_0 + - idna=3.4=pyhd8ed1ab_0 + - importlib-metadata=6.8.0=pyha770c72_0 + - importlib_metadata=6.8.0=hd8ed1ab_0 + - importlib_resources=6.0.1=pyhd8ed1ab_0 + - ipykernel=6.25.1=pyh71e2992_0 + - ipython=8.14.0=pyh41d4057_0 + - ipywidgets=8.1.1=pyhd8ed1ab_0 + - isoduration=20.11.0=pyhd8ed1ab_0 + - jedi=0.19.0=pyhd8ed1ab_0 + - jinja2=3.1.2=pyhd8ed1ab_1 + - joblib=1.3.2=pyhd8ed1ab_0 + - json5=0.9.14=pyhd8ed1ab_0 + - jsonpatch=1.33=pyhd8ed1ab_0 + - jsonpointer=2.0=py_0 + - jsonschema=4.19.0=pyhd8ed1ab_1 + - jsonschema-specifications=2023.7.1=pyhd8ed1ab_0 + - jsonschema-with-format-nongpl=4.19.0=pyhd8ed1ab_1 + - jupyter-lsp=2.2.0=pyhd8ed1ab_0 + - jupyter_client=8.3.0=pyhd8ed1ab_0 + - jupyter_core=5.3.1=py310hff52083_0 + - jupyter_events=0.7.0=pyhd8ed1ab_2 + - 
jupyter_server=2.7.1=pyhd8ed1ab_0 + - jupyter_server_terminals=0.4.4=pyhd8ed1ab_1 + - jupyterlab=4.0.5=pyhd8ed1ab_0 + - jupyterlab_pygments=0.2.2=pyhd8ed1ab_0 + - jupyterlab_server=2.24.0=pyhd8ed1ab_0 + - jupyterlab_widgets=3.0.9=pyhd8ed1ab_0 + - keyutils=1.6.1=h166bdaf_0 + - kiwisolver=1.4.4=py310hbf28c38_1 + - krb5=1.21.2=h659d440_0 + - lcms2=2.15=haa2dc70_1 + - ld_impl_linux-64=2.40=h41732ed_0 + - lerc=4.0.0=h27087fc_0 + - libaec=1.0.6=hcb278e6_1 + - libarchive=3.6.2=h039dbb9_1 + - libblas=3.9.0=17_linux64_openblas + - libbrotlicommon=1.0.9=h166bdaf_9 + - libbrotlidec=1.0.9=h166bdaf_9 + - libbrotlienc=1.0.9=h166bdaf_9 + - libcblas=3.9.0=17_linux64_openblas + - libcublas=12.0.1.189=hcb278e6_2 + - libcublas-dev=12.0.1.189=hcb278e6_2 + - libcufft=11.0.0.21=hcb278e6_1 + - libcufft-dev=11.0.0.21=hcb278e6_1 + - libcufile=1.5.0.59=hcb278e6_0 + - libcufile-dev=1.5.0.59=hcb278e6_0 + - libcurand=10.3.1.50=hcb278e6_0 + - libcurand-dev=10.3.1.50=hcb278e6_0 + - libcurl=8.2.1=hca28451_0 + - libcusolver=11.4.2.57=hcb278e6_1 + - libcusparse=12.0.0.76=hcb278e6_1 + - libdeflate=1.18=h0b41bf4_0 + - libedit=3.1.20191231=he28a2e2_2 + - libev=4.33=h516909a_1 + - libexpat=2.6.2=h59595ed_0 + - libffi=3.4.2=h7f98852_5 + - libgcc-ng=13.1.0=he5830b7_0 + - libgfortran-ng=13.1.0=h69a702a_0 + - libgfortran5=13.1.0=h15d22d2_0 + - libglib=2.80.0=hf2295e7_5 + - libgomp=13.1.0=he5830b7_0 + - libiconv=1.17=hd590300_2 + - libjpeg-turbo=2.1.5.1=h0b41bf4_0 + - liblapack=3.9.0=17_linux64_openblas + - libmamba=1.5.1=h744094f_0 + - libmambapy=1.5.1=py310h39ff949_0 + - libnetcdf=4.9.2=nompi_he09a3a9_107 + - libnghttp2=1.52.0=h61bc06f_0 + - libnpp=12.0.0.30=h59595ed_0 + - libnpp-dev=12.0.0.30=h59595ed_0 + - libnsl=2.0.0=h7f98852_0 + - libnuma=2.0.16=h0b41bf4_1 + - libnvjitlink=12.0.76=hcb278e6_1 + - libnvjpeg=12.0.0.28=hcb278e6_0 + - libnvjpeg-dev=12.0.0.28=ha770c72_0 + - libopenblas=0.3.23=pthreads_h80387f5_0 + - libpng=1.6.39=h753d276_0 + - libsodium=1.0.18=h36c2ea0_1 + - libsolv=0.7.27=hfc55251_0 + - 
libsqlite=3.42.0=h2797004_0 + - libssh2=1.11.0=h0841786_0 + - libstdcxx-ng=13.1.0=hfd8a6a1_0 + - libtiff=4.5.1=h8b53f26_0 + - libuuid=2.38.1=h0b41bf4_0 + - libwebp-base=1.3.1=hd590300_0 + - libxcb=1.15=h0b41bf4_0 + - libxml2=2.12.4=h232c23b_1 + - libzip=1.10.1=h2629f0a_3 + - libzlib=1.2.13=hd590300_5 + - lz4-c=1.9.4=hcb278e6_0 + - lzo=2.10=h516909a_1000 + - mamba=1.5.1=py310h51d5547_0 + - markupsafe=2.1.3=py310h2372a71_0 + - matplotlib-base=3.7.2=py310hf38f957_0 + - matplotlib-inline=0.1.6=pyhd8ed1ab_0 + - mda-xdrlib=0.2.0=pyhd8ed1ab_0 + - mdanalysis=2.7.0=py310hcc13569_1 + - mdtraj=1.9.9=py310h8e08b51_0 + - mistune=3.0.1=pyhd8ed1ab_0 + - mmtf-python=1.1.3=pyhd8ed1ab_0 + - mrcfile=1.5.0=pyhd8ed1ab_0 + - msgpack-python=1.0.7=py310hd41b1e2_0 + - munkres=1.1.4=pyh9f0ad1d_0 + - nbclient=0.8.0=pyhd8ed1ab_0 + - nbconvert-core=7.7.4=pyhd8ed1ab_0 + - nbformat=5.9.2=pyhd8ed1ab_0 + - ncurses=6.4=hcb278e6_0 + - nest-asyncio=1.5.6=pyhd8ed1ab_0 + - netcdf4=1.6.4=nompi_py310h6f5dce6_101 + - nglview=3.1.1=pyh15ce09e_0 + - nomkl=1.0=h5ca1d4c_0 + - notebook-shim=0.2.3=pyhd8ed1ab_0 + - nsight-compute=2023.2.1.3=0 + - numexpr=2.8.4=py310hd91493a_101 + - numpy=1.25.2=py310ha4c1d20_0 + - openjpeg=2.5.0=hfec8fc6_2 + - openssl=3.2.1=hd590300_1 + - overrides=7.4.0=pyhd8ed1ab_0 + - packaging=23.1=pyhd8ed1ab_0 + - pandas=2.0.3=py310h7cbd5c2_1 + - pandocfilters=1.5.0=pyhd8ed1ab_0 + - parso=0.8.3=pyhd8ed1ab_0 + - patsy=0.5.3=pyhd8ed1ab_0 + - pcre2=10.43=hcad00b1_0 + - pexpect=4.8.0=pyh1a96a4e_2 + - pickleshare=0.7.5=py_1003 + - pillow=10.0.0=py310h582fbeb_0 + - pixman=0.43.2=h59595ed_0 + - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0 + - platformdirs=3.10.0=pyhd8ed1ab_0 + - pluggy=1.3.0=pyhd8ed1ab_0 + - pooch=1.7.0=pyha770c72_3 + - prometheus_client=0.17.1=pyhd8ed1ab_0 + - prompt-toolkit=3.0.39=pyha770c72_0 + - prompt_toolkit=3.0.39=hd8ed1ab_0 + - psutil=5.9.5=py310h1fa729e_0 + - pthread-stubs=0.4=h36c2ea0_1001 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - pure_eval=0.2.2=pyhd8ed1ab_0 + - 
py-cpuinfo=9.0.0=pyhd8ed1ab_0 + - pybind11-abi=4=hd8ed1ab_3 + - pycairo=1.26.0=py310hda9f760_0 + - pycosat=0.6.6=py310h2372a71_0 + - pycparser=2.21=pyhd8ed1ab_0 + - pyedr=0.8.0=pyhd8ed1ab_0 + - pygments=2.16.1=pyhd8ed1ab_0 + - pyopenssl=23.3.0=pyhd8ed1ab_0 + - pyparsing=3.0.9=pyhd8ed1ab_0 + - pysocks=1.7.1=pyha2e5f31_6 + - pytables=3.8.0=py310ha028ce3_2 + - python=3.10.12=hd12c33a_0_cpython + - python-dateutil=2.8.2=pyhd8ed1ab_0 + - python-fastjsonschema=2.18.0=pyhd8ed1ab_0 + - python-json-logger=2.0.7=pyhd8ed1ab_0 + - python-tzdata=2023.3=pyhd8ed1ab_0 + - python_abi=3.10=3_cp310 + - pytng=0.3.1=py310h7d11597_1 + - pytorch-cuda=11.6=h867d48c_0 + - pytz=2023.3=pyhd8ed1ab_0 + - pyyaml=6.0=py310h5764c6d_5 + - pyzmq=25.1.1=py310h5bbb5d0_0 + - rdkit=2023.03.3=py310h399bcf7_0 + - readline=8.2=h8228510_1 + - referencing=0.30.2=pyhd8ed1ab_0 + - reportlab=4.1.0=py310h2372a71_0 + - reproc=14.2.4.post0=hd590300_1 + - reproc-cpp=14.2.4.post0=h59595ed_1 + - requests=2.31.0=pyhd8ed1ab_0 + - rfc3339-validator=0.1.4=pyhd8ed1ab_0 + - rfc3986-validator=0.1.1=pyh9f0ad1d_0 + - rlpycairo=0.2.0=pyhd8ed1ab_0 + - rpds-py=0.9.2=py310hcb5633a_0 + - ruamel.yaml=0.17.40=py310h2372a71_0 + - ruamel.yaml.clib=0.2.7=py310h2372a71_2 + - scikit-learn=1.3.2=py310h1fdf081_2 + - scipy=1.11.1=py310ha4c1d20_0 + - seaborn=0.12.2=hd8ed1ab_0 + - seaborn-base=0.12.2=pyhd8ed1ab_0 + - send2trash=1.8.2=pyh41d4057_0 + - setuptools=68.1.2=pyhd8ed1ab_0 + - six=1.16.0=pyh6c4a22f_0 + - snappy=1.1.10=h9fff704_0 + - sniffio=1.3.0=pyhd8ed1ab_0 + - soupsieve=2.3.2.post1=pyhd8ed1ab_0 + - sqlalchemy=2.0.29=py310h2372a71_0 + - stack_data=0.6.2=pyhd8ed1ab_0 + - statsmodels=0.14.0=py310h278f3c1_1 + - terminado=0.17.1=pyh41d4057_0 + - threadpoolctl=3.2.0=pyha21a80b_0 + - tidynamics=1.1.2=pyhd8ed1ab_0 + - tinycss2=1.2.1=pyhd8ed1ab_0 + - tk=8.6.12=h27826a3_0 + - tomli=2.0.1=pyhd8ed1ab_0 + - toolz=0.12.0=pyhd8ed1ab_0 + - tornado=6.3.3=py310h2372a71_0 + - tqdm=4.66.1=pyhd8ed1ab_0 + - traitlets=5.9.0=pyhd8ed1ab_0 + - 
typing-extensions=4.7.1=hd8ed1ab_0 + - typing_extensions=4.7.1=pyha770c72_0 + - typing_utils=0.1.0=pyhd8ed1ab_0 + - tzdata=2023c=h71feb2d_0 + - unicodedata2=15.0.0=py310h5764c6d_0 + - uri-template=1.3.0=pyhd8ed1ab_0 + - urllib3=2.0.4=pyhd8ed1ab_0 + - wcwidth=0.2.6=pyhd8ed1ab_0 + - webcolors=1.13=pyhd8ed1ab_0 + - webencodings=0.5.1=py_1 + - websocket-client=1.6.1=pyhd8ed1ab_0 + - wheel=0.41.1=pyhd8ed1ab_0 + - widgetsnbextension=4.0.9=pyhd8ed1ab_0 + - xorg-kbproto=1.0.7=h7f98852_1002 + - xorg-libice=1.1.1=hd590300_0 + - xorg-libsm=1.2.4=h7391055_0 + - xorg-libx11=1.8.9=h8ee46fc_0 + - xorg-libxau=1.0.11=hd590300_0 + - xorg-libxdmcp=1.1.3=h7f98852_0 + - xorg-libxext=1.3.4=h0b41bf4_2 + - xorg-libxrender=0.9.11=hd590300_0 + - xorg-renderproto=0.11.1=h7f98852_1002 + - xorg-xextproto=7.3.0=h0b41bf4_1003 + - xorg-xproto=7.0.31=h7f98852_1007 + - xz=5.2.6=h166bdaf_0 + - yaml=0.2.5=h7f98852_2 + - yaml-cpp=0.7.0=h59595ed_3 + - zeromq=4.3.4=h9c3ff4c_1 + - zipp=3.16.2=pyhd8ed1ab_0 + - zlib=1.2.13=hd590300_5 + - zlib-ng=2.0.7=h0b41bf4_0 + - zstandard=0.19.0=py310h1275a96_2 + - zstd=1.5.2=hfc55251_7 + - pip: + - absl-py==1.4.0 + - aiohttp==3.8.5 + - aiosignal==1.3.1 + - alabaster==0.7.16 + - annotated-types==0.5.0 + - antlr4-python3-runtime==4.9.3 + - appdirs==1.4.4 + - async-timeout==4.0.3 + - backoff==2.2.1 + - biotite==0.37.0 + - blessed==1.20.0 + - cachetools==5.3.1 + - click==8.1.7 + - cmake==3.27.2 + - contextlib2==21.6.0 + - croniter==1.4.1 + - dateutils==0.6.12 + - deepdiff==6.3.1 + - deepspeed==0.12.3 + - dm-tree==0.1.8 + - docker-pycreds==0.4.0 + - docutils==0.21.2 + - e3nn==0.5.1 + - einops==0.6.1 + - fair-esm==2.0.0 + - fastapi==0.101.1 + - filelock==3.12.2 + - frozenlist==1.4.0 + - fsspec==2023.6.0 + - gitdb==4.0.10 + - gitpython==3.1.32 + - gpustat==1.1 + - gputil==1.4.0 + - h11==0.14.0 + - hjson==3.1.0 + - hydra-core==1.3.2 + - imagesize==1.4.1 + - inquirer==3.1.3 + - ipdb==0.13.13 + - ipython-genutils==0.2.0 + - itsdangerous==2.1.2 + - jupyter==1.0.0 + - 
jupyter-console==6.6.3 + - jupyter-contrib-core==0.4.2 + - jupyter-contrib-nbextensions==0.7.0 + - jupyter-highlight-selected-word==0.2.0 + - jupyter-nbextensions-configurator==0.6.3 + - lightning==2.0.7 + - lightning-cloud==0.5.37 + - lightning-utilities==0.9.0 + - lit==16.0.6 + - lxml==5.1.0 + - markdown-it-py==3.0.0 + - mdurl==0.1.2 + - mizani==0.9.3 + - ml-collections==0.1.1 + - mpmath==1.3.0 + - msgpack==1.0.5 + - multidict==6.0.4 + - networkx==3.1 + - ninja==1.11.1 + - notebook==7.0.7 + - numpydoc==1.7.0 + - nvidia-cublas-cu11==11.10.3.66 + - nvidia-cuda-cupti-cu11==11.7.101 + - nvidia-cuda-nvrtc-cu11==11.7.99 + - nvidia-cuda-runtime-cu11==11.7.99 + - nvidia-cudnn-cu11==8.5.0.96 + - nvidia-cufft-cu11==10.9.0.58 + - nvidia-curand-cu11==10.2.10.91 + - nvidia-cusolver-cu11==11.4.0.1 + - nvidia-cusparse-cu11==11.7.4.91 + - nvidia-ml-py==12.535.77 + - nvidia-nccl-cu11==2.14.3 + - nvidia-nvtx-cu11==11.7.91 + - nvitop==1.3.0 + - oddt==0.7 + - omegaconf==2.3.0 + - opt-einsum==3.3.0 + - opt-einsum-fx==0.1.4 + - ordered-set==4.1.0 + - pathtools==0.1.2 + - pip==24.1.2 + - plotly==5.16.1 + - plotnine==0.12.3 + - protobuf==4.24.1 + - pydantic==2.1.1 + - pydantic-core==2.4.0 + - pyg-lib==0.4.0+pt20cu117 + - -f https://data.pyg.org/whl/torch-2.0.0+cu117.html + - pyjwt==2.8.0 + - pynvml==11.5.0 + - python-editor==1.0.4 + - python-multipart==0.0.6 + - pytorch-lightning==2.0.7 + - qtconsole==5.5.1 + - qtpy==2.4.1 + - readchar==4.0.5 + - rich==13.5.2 + - sentry-sdk==1.29.2 + - setproctitle==1.3.2 + - smmap==5.0.0 + - snakeviz==2.2.0 + - snowballstemmer==2.2.0 + - sphinx==7.3.7 + - sphinxcontrib-applehelp==1.0.8 + - sphinxcontrib-devhelp==1.0.6 + - sphinxcontrib-htmlhelp==2.0.5 + - sphinxcontrib-jsmath==1.0.1 + - sphinxcontrib-qthelp==1.0.7 + - sphinxcontrib-serializinghtml==1.1.10 + - starlette==0.27.0 + - starsessions==1.3.0 + - swig==4.2.1 + - sympy==1.12 + - tabulate==0.9.0 + - tenacity==8.2.3 + - termcolor==2.3.0 + - tmtools==0.0.3 + - torch==2.0.1 + - 
torch-cluster==1.6.3+pt20cu117 + - -f https://data.pyg.org/whl/torch-2.0.0+cu117.html + - torch-geometric==2.5.2 + - -f https://data.pyg.org/whl/torch-2.0.0+cu117.html + - torch-scatter==2.1.2+pt20cu117 + - -f https://data.pyg.org/whl/torch-2.0.0+cu117.html + - torch-sparse==0.6.18+pt20cu117 + - -f https://data.pyg.org/whl/torch-2.0.0+cu117.html + - torch-spline-conv==1.2.2+pt20cu117 + - -f https://data.pyg.org/whl/torch-2.0.0+cu117.html + - torchmetrics==1.0.3 + - triton==2.0.0 + - uvicorn==0.23.2 + - wandb==0.15.8 + - websockets==11.0.3 + - yarl==1.9.2 diff --git a/experiments/__pycache__/utils.cpython-310.pyc b/experiments/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c399f62acbebf8b735f941099a0491a7fcbd9aa5 Binary files /dev/null and b/experiments/__pycache__/utils.cpython-310.pyc differ diff --git a/experiments/inference_se3_flows.py b/experiments/inference_se3_flows.py new file mode 100644 index 0000000000000000000000000000000000000000..7505c143e0aef78d3ff86b72ead8e664cfdb2389 --- /dev/null +++ b/experiments/inference_se3_flows.py @@ -0,0 +1,119 @@ +"""DDP inference script.""" +import os +import time +import numpy as np +import hydra +import torch +import GPUtil +import sys + +from pytorch_lightning import Trainer +from omegaconf import DictConfig, OmegaConf +from experiments import utils as eu +from models.flow_module import FlowModule +import re +from typing import Optional +import subprocess +from biotite.sequence.io import fasta +from data import utils as du +from analysis import metrics +import pandas as pd +import esm +import shutil +import biotite.structure.io as bsio + + +torch.set_float32_matmul_precision('high') +log = eu.get_pylogger(__name__) + +class Sampler: + + def __init__(self, cfg: DictConfig): + """Initialize sampler. + + Args: + cfg: inference config. 
+ """ + ckpt_path = cfg.inference.ckpt_path + ckpt_dir = os.path.dirname(ckpt_path) + # ckpt_cfg = OmegaConf.load(os.path.join(ckpt_dir, 'config.yaml')) + # ckpt_cfg = torch.load(ckpt_path, map_location="cpu")['hyper_parameters']['cfg'] + + # Set-up config. + OmegaConf.set_struct(cfg, False) + # OmegaConf.set_struct(ckpt_cfg, False) + # cfg = OmegaConf.merge(cfg, ckpt_cfg) + # cfg.experiment.checkpointer.dirpath = './' + + self._cfg = cfg + # self._pmpnn_dir = cfg.inference.pmpnn_dir + self._infer_cfg = cfg.inference + self._samples_cfg = self._infer_cfg.samples + # self._rng = np.random.default_rng(self._infer_cfg.seed) + + # Set-up directories to write results to + self._ckpt_name = '/'.join(ckpt_path.replace('.ckpt', '').split('/')[-3:]) + self._output_dir = os.path.join( + self._infer_cfg.output_dir, + self._ckpt_name, + self._infer_cfg.name, + ) + os.makedirs(self._output_dir, exist_ok=True) + log.info(f'Saving results to {self._output_dir}') + config_path = os.path.join(self._output_dir, 'config.yaml') + with open(config_path, 'w') as f: + OmegaConf.save(config=self._cfg, f=f) + log.info(f'Saving inference config to {config_path}') + + # Read checkpoint and initialize module. 
+ self._flow_module = FlowModule.load_from_checkpoint( + checkpoint_path=ckpt_path, + ) + self._flow_module.eval() + self._flow_module._infer_cfg = self._infer_cfg + self._flow_module._samples_cfg = self._samples_cfg + self._flow_module._output_dir = self._output_dir + + # devices = GPUtil.getAvailable( + # order='memory', limit = 8)[:4] + # print(GPUtil.getAvailable(order='memory', limit = 8)) + + devices = [torch.cuda.current_device()] + + self._folding_model = esm.pretrained.esmfold_v1().eval() + self._folding_model = self._folding_model.to(devices[-1]) + + def run_sampling(self): + # devices = GPUtil.getAvailable( + # order='memory', limit = 8)[:self._infer_cfg.num_gpus] + devices = [torch.cuda.current_device()] + + log.info(f"Using devices: {devices}") + + eval_dataset = eu.LengthDataset(self._samples_cfg) + dataloader = torch.utils.data.DataLoader( + eval_dataset, batch_size=self._samples_cfg.sample_batch, shuffle=False, drop_last=False) + + trainer = Trainer( + accelerator="gpu", + strategy="ddp", + devices=devices, + ) + trainer.predict(self._flow_module, dataloaders=dataloader) + + + +@hydra.main(version_base=None, config_path="../configs", config_name="inference") +def run(cfg: DictConfig) -> None: + + # Read model checkpoint. 
class Experiment:
    """Bundle the config, data module and flow model for one training run.

    Args:
        cfg: Full experiment config. Its ``data`` and ``experiment``
            sub-configs are kept for later use; the whole config is passed
            to ``FlowModule``.
    """

    def __init__(self, *, cfg: DictConfig):
        self._cfg = cfg
        self._data_cfg = cfg.data
        self._exp_cfg = cfg.experiment
        self._datamodule: LightningDataModule = PdbDataModule(self._data_cfg)
        self._model: LightningModule = FlowModule(self._cfg)

    def train(self):
        """Set up logging and checkpointing, then call ``Trainer.fit``.

        In debug mode no logger or checkpoint callback is created, and the
        run is forced to one device with zero dataloader workers. Otherwise
        a ``WandbLogger`` and a ``ModelCheckpoint`` callback are created and
        the config is written to ``config.yaml`` in the checkpoint
        directory. If ``experiment.warm_start`` is set, the model weights
        are loaded from that checkpoint before fitting.
        """
        # NOTE(review): nesting reconstructed from a flattened source. The
        # checkpoint/config-saving statements are kept in the non-debug
        # branch because they dereference `logger`, which is None in debug
        # mode.
        callbacks = []
        if self._exp_cfg.debug:
            log.info("Debug mode.")
            logger = None
            self._exp_cfg.num_devices = 1
            self._data_cfg.loader.num_workers = 0
        else:
            logger = WandbLogger(
                **self._exp_cfg.wandb,
            )

            # Checkpoint directory.
            ckpt_dir = self._exp_cfg.checkpointer.dirpath
            os.makedirs(ckpt_dir, exist_ok=True)
            log.info(f"Checkpoints saved to {ckpt_dir}")

            # Model checkpoints
            callbacks.append(ModelCheckpoint(**self._exp_cfg.checkpointer))

            # Save config. Pass the path directly: the previous code opened
            # the file itself and then gave OmegaConf only the file *name*,
            # so the same path was opened for writing twice.
            cfg_path = os.path.join(ckpt_dir, 'config.yaml')
            OmegaConf.save(config=self._cfg, f=cfg_path)
            cfg_dict = OmegaConf.to_container(self._cfg, resolve=True)
            flat_cfg = dict(eu.flatten_dict(cfg_dict))
            if isinstance(logger.experiment.config, wandb.sdk.wandb_config.Config):
                logger.experiment.config.update(flat_cfg)

        # Pick the GPUs GPUtil ranks first by memory, capped at num_devices.
        devices = GPUtil.getAvailable(order='memory', limit=8)[:self._exp_cfg.num_devices]
        log.info(f"Using devices: {devices}")
        trainer = Trainer(
            **self._exp_cfg.trainer,
            callbacks=callbacks,
            logger=logger,
            use_distributed_sampler=False,
            enable_progress_bar=True,
            enable_model_summary=True,
            devices=devices,
        )

        if self._exp_cfg.warm_start is not None:
            # Lightning will always save the hyperparameters to the checkpoint.
            # load_from_checkpoint is a classmethod, so call it on the class
            # (as the inference Sampler does) rather than on the instance.
            self._model = FlowModule.load_from_checkpoint(
                self._exp_cfg.warm_start, strict=False, map_location="cpu")

        trainer.fit(
            model=self._model,
            datamodule=self._datamodule,
        )
class LengthDataset(torch.utils.data.Dataset):
    """Inference dataset with one item per (sequence, sampled energy) pair.

    Every row of the validation CSV (columns ``file`` and ``seq``) is
    repeated ``samples_cfg.sample_num`` times, each repeat with its own
    randomly drawn energy value. ESM-2 and ESMFold are loaded onto the
    current CUDA device in the constructor and used in ``__getitem__``.

    Args:
        samples_cfg: Sampling config; ``validset_path``, ``sample_num`` and
            ``esm_savepath`` are read from it.
    """

    def __init__(self, samples_cfg):
        self._samples_cfg = samples_cfg

        validcsv = pd.read_csv(self._samples_cfg.validset_path)

        self._all_sample_seqs = []
        self._all_filename = []

        # Cumulative, max-normalised weights exp(-2 * k / prob_num) for
        # k = 0..prob_num-1. A uniform draw is mapped to the first bin whose
        # cumulative weight exceeds it, so small energies are more likely.
        prob_num = 500
        exp_prob = np.exp([-prob/prob_num*2 for prob in range(prob_num)]).cumsum()
        exp_prob = exp_prob/np.max(exp_prob)

        for idx in range(len(validcsv['seq'])):
            self._all_filename += [validcsv['file'][idx]] * self._samples_cfg.sample_num

            for _ in range(self._samples_cfg.sample_num):
                rand = random.random()
                # side='right' gives the first index with rand < exp_prob[i],
                # which is what the original linear scan found. The clamp
                # only guards against float round-off in the last bin.
                bin_idx = min(
                    int(np.searchsorted(exp_prob, rand, side='right')),
                    prob_num - 1,
                )
                energy = torch.tensor(bin_idx/prob_num)
                self._all_sample_seqs += [(validcsv['seq'][idx], energy)]

        self._all_sample_ids = self._all_sample_seqs

        self.device_esm = f'cuda:{torch.cuda.current_device()}'

        # Load ESM-2 model
        self.model_esm2, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.batch_converter = self.alphabet.get_batch_converter()
        self.model_esm2.eval().cuda(self.device_esm)  # disables dropout for deterministic results
        self.model_esm2.requires_grad_(False)

        # Load ESMFold once. The previous code loaded it twice and threw the
        # first copy away after moving it to the GPU.
        self._folding_model = esm.pretrained.esmfold_v1().eval()
        self._folding_model.requires_grad_(False)
        self._folding_model.to(self.device_esm)

        # Cache directory for ESMFold predictions; create it up front so the
        # first write in __getitem__ cannot fail on a missing directory.
        self.esm_savepath = self._samples_cfg.esm_savepath
        os.makedirs(self.esm_savepath, exist_ok=True)

    def run_folding(self, sequence, save_path):
        """Run ESMFold on sequence, write the PDB text to save_path and return it."""
        with torch.no_grad():
            output = self._folding_model.infer_pdb(sequence)
        # NOTE(review): this leaves the folding model on the CPU, so any
        # later folding in __getitem__ would also run there - confirm that
        # this is intended.
        self._folding_model.to("cpu")

        with open(save_path, "w") as f:
            f.write(output)
        return output

    def __len__(self):
        return len(self._all_sample_ids)

    def __getitem__(self, idx):
        """Build the conditioning batch for sample ``idx``.

        Computes the ESM-2 representations for the sequence, makes sure an
        ESMFold prediction exists at ``esm_<file>.pdb`` under
        ``esm_savepath`` (folding the sequence if the file is missing), and
        reads translations and rotation matrices from that file.

        Returns:
            dict with keys ``filename``, ``trans_esmfold``,
            ``rotmats_esmfold``, ``motif_mask``, ``res_mask``, ``num_res``,
            ``energy``, ``aatype``, ``seq``, ``node_repr_pre`` and
            ``pair_repr_pre``.
        """
        seq, energy = self._all_sample_ids[idx]
        aatype = torch.tensor([restype_order[s] for s in seq])
        num_res = len(aatype)

        # Shapes per the original author's note:
        # (B,L,d_node_pre=1280), (B,L,L,d_edge_pre=20)
        node_repr_pre, pair_repr_pre = get_pre_repr(
            aatype, self.model_esm2, self.alphabet, self.batch_converter,
            device=self.device_esm)
        node_repr_pre = node_repr_pre[0].cpu()
        pair_repr_pre = pair_repr_pre[0].cpu()

        motif_mask = torch.ones(aatype.shape)

        # Fold with ESMFold only when no cached prediction exists.
        save_path = os.path.join(self.esm_savepath, "esm_" + self._all_filename[idx] + ".pdb")
        if not os.path.exists(save_path):
            with torch.no_grad():
                output = self._folding_model.infer_pdb(seq)
            with open(save_path, "w") as f:
                f.write(output)

        trans_esmfold, rotmats_esmfold = cal_trans_rotmats(save_path)

        batch = {
            'filename': self._all_filename[idx],
            'trans_esmfold': trans_esmfold,
            'rotmats_esmfold': rotmats_esmfold,
            'motif_mask': motif_mask,
            'res_mask': torch.ones(num_res).int(),
            'num_res': num_res,
            'energy': energy,
            'aatype': aatype,
            'seq': seq,
            'node_repr_pre': node_repr_pre,
            'pair_repr_pre': pair_repr_pre,
        }
        return batch
def get_pylogger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python command line logger.

    Args:
        name: Name passed to ``logging.getLogger``.

    Returns:
        The logger, with each of its level methods wrapped in
        ``rank_zero_only``.
    """

    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for level in logging_levels:
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger


def flatten_dict(raw_dict, *, sep=':'):
    """Flattens a nested dict into a list of ``(joined_key, value)`` pairs.

    Keys along a nested path are joined with ``sep``, in insertion order.
    A value that is an empty dict contributes no pairs, so its key is
    dropped from the result.

    Args:
        raw_dict: Dict whose values may themselves be dicts.
        sep: Separator placed between nested key components. Defaults to
            ``':'``, the separator that was previously hard-coded.

    Returns:
        List of ``(key, value)`` tuples; wrap in ``dict(...)`` for a mapping.
    """
    flattened = []
    for k, v in raw_dict.items():
        if isinstance(v, dict):
            flattened.extend(
                (f'{k}{sep}{i}', j) for i, j in flatten_dict(v, sep=sep)
            )
        else:
            flattened.append((k, v))
    return flattened
+3nuf_A,119,GMTPLDANVELPTEVKAMIEQSSDAQAATALVNYVIKLAAAAEIHFTDLQLQVLTNHLIEMLGRSKSGEQLPAVDPTMFAEVSQKSLDLADQVVQHIGHLEVAEKYVLSIHFEAAQDKI +4wlr_B,102,QVDLASVLTPEIMAPILANADVQERLLPYLPSGESLPQTADEIQNTLTSPQFQQALGMFSAALASGQLGPLMCQFGLPAEAVEAANKGDVEAFAKAMQNNAK +2wkx_A,261,MGAGEKGIVEKEGYQLDTRRQAQAAYPRIKVLVIHYTADDFDSSLATLTDKQVSSHYLVPAVPPRYNGKPRIWQLVPEQELAWHAGISAWRGATRLNDTSIGIELENRGWQKSAGVKYFAPFEPAQIQALIPLAKDIIARYHIKPENVVAHADIAPQRKDDPGPLFPWQQLAQQGIGAWPDAQRVNFYLAGRAPHTPVDTASLLELLARYGYDVKPDMTPREQRRVIMAFQMHFRPTLYNGEADAETQAIAEALLEKYGQD +1xmk_A,79,GSHMASLDMAEIKEKICDYLFNVSDSSALNLAKNIGLTKARDINAVLIDMERQGDVYRQGTTPPIWHLTDKKRERMQIK +3vnb_A,155,MNKKMINGGTVNNWICINFSRQVQDNLARTFCQELAQMCYVSGMAFNPEPVLPPVSARPEQVEKVLKTRYHDATSKLSQGKEIDLLIVILPDKKNSDLYGDLKRICETELGIVSQCCLTKHVFKMSKQYMANVALKINVKVGGRNTVLVHHHHHH +6crk_G,71,MASNNTASIAQARKLVEQLKMEANIDRIKVSKAAADLMAYCEAHAKEDPLLTPVPASENPFREKKFFSAIL +5j3t_B,242,MSFTNATFSQVLDDLSARFILNLPAEEQSSVERLCFQIEQAHWFYEDFIRAQNDQLPSLGLRVFSAKLFAHCPLLWKWSKVHEEAFDDFLRYKTRIPVRGAIMLDMSMQQCVLVKGWKASSGWGFPKGKIDKDESDVDCAIREVYEETGFDCSSRINPNEFIDMTIRGQNVRLYIIPGISLDTRFESRTRKEISKIEWHNLMDLPTFKKNKPQTMKNKFYMVIPFLAPLKKWIKKRNIANNT +2prv_A,153,GMIYSKVENFINENKQNAIFTEGASHENIGRIEENLQCDLPNSYKWFLEKYGAGGLFGVLVLGYNFDHASVVNRTNEYKEHYGLTDGLVVIEDVDYFAYCLDTNKMKDGECPVVEWDRVIGYQDTVADSFIEFFYNKIQEAKDDWDEDEDWDD +3pes_A,85,SNAMLAEFEDRVAGIPCLIVVTYWEPYVPAKVSGPPEYCYPAEGGCGEWEVRDRRGRPAPWLERKLTEAERERIDQAVFDRMEGR +3u5v_A,76,MGADKRAHHNALERKRRRDINEAFRELGRMCQMHLKSDKAQTKLLILQQAVQVILGLEQQVRERNLNPLNHHHHHH 
+5ori_A,583,DTHKSEIAHRFNDLGEENFQGLVLIAFSQYLQQCPFDEHVKLVKELTEFAKTCVADESHAGCDKSLHTLFGDELCKVATLRETYGDMADCCEKQEPERNECFLKHKDDSPDLPKLKPEPDTLCAEFKADEKKFWGKYLYEVARRHPYFYAPELLYYANKYNGVFQECCQAEDKGACLLPKIETMREKVLASSARQRLRCASIQKFGERALKAWSVARLSQKFPKADFTDVTKIVTDLTKVHKECCHGDLLECADDRADLAKYICDHQDTLSSKLKECCDKPVLEKSHCIAEIDKDAVPENLPPLTADFAEDKEVCKNYQEAKDVFLGSFLYEYSRRHPEYAVSVLLRLAKEYEATLEDCCAKEDPHACYATVFDKLKHLVDEPQNLIKKNCELFEKHGEYGFQNALIVRYTRKAPQVSTPTLVEISRSLGKVGTKCCAKPESERMPCTEDYLSLILNRLCVLHEKTPVSEKVTKCCTESLVNRRPCFSDLTLDETYVPKPFDGESFTFHADICTLPDTEKQIKKQTALVELLKHKPKATDEQLKTVMENFVAFVDKCCAADDKEGCFLLEGPKLVASTQAALA +3wur_B,171,HHHHHHMTALTLPEDIRQQEPSALLYTLVSAYLEHTAQTGDESLSCLSDDQHTLTAFCYLDSQVEEGGFVQLIASGYGEYIFRNPLADSLRRWKIKAVPKVLDKAKALYEQHGKTIETLADGGADIPSLRKQFPEFEEWDGAYYEAAEQDLPLLAEHIQSNWETFAHIGQA +2ra8_A,362,MKRVFVFQDFKSQKFWSIDVRGTDVIVNYGKLGTDGQTQVKNFSSAGEAEKAAGKLIAEKTKKGYVETLEEVAKEMKVEAKKYALSYDEAEEGVNLMDKILKDKKLPSLKQITIGCWGYEGEDCSDIADGIVENKEKFAHFEGLFWGDIDFEEQEISWIEQVDLSPVLDAMPLLNNLKIKGTNNLSIGKKPRPNLKSLEIISGGLPDSVVEDILGSDLPNLEKLVLYVGVEDYGFDGDMNVFRPLFSKDRFPNLKWLGIVDAEEQNVVVEMFLESDILPQLETMDISAGVLTDEGARLLLDHVDKIKHLKFINMKYNYLSDEMKKELQKSLPMKIDVSDSQEYDDDYSYPMITELEHHHHHH +4xo1_A,60,MNIEELKKQAETEIADFIAQKIAEMNKNTGKEVSEMRFTAREKMTGLESYDVKIKIMLEH +3wx4_A,100,MIIDSQSVVQYTFKIDILEKLYKFLPNLYHSIVNELVEELHLENNDFLIGTYKDLSKAGYFYVIPAPGKNIDDVLKTIMIYVHDYEIEDYFELEHHHHHH +3dza_A,191,GATDSATAAPAAAATTQVQKEAADVLQVAVQGANAMRDIQFARLALFHGQPDSAKKLTDDAAALLAADDASWAKFVKTDAKAKMIADRYVIINASIALSEDYVATPEKESAIQSANEKLAKGDQKGAIDTLRLAGIGVIENQYLMPLNQTRKAVAQSQELLKAGKYYEANLVLKGAEEGIVVDSEMLVAGN +2au3_A,407,GHMSSDIDELRREIDIVDVISEYLNLEKVGSNYRTNCPFHPDDTPSFYVSPSKQIFKCFGCGVGGDAIKFVSLYEDISYFEAALELAKRYGKKLDLEKISKDEKVYVALDRVCDFYRESLLKNREASEYVKSRGIDPKVARKFDLGYAPSSEALVKVLKENDLLEAYLETKNLLSPTKGVYRDLFLRRVVIPIKDPRGRVIGFGGRRIVEDKSPKYINSPDSRVFKKGENLFGLYEAKEYIKEEGFAILVEGYFDLLRLFSEGIRNVVAPLGTALTQNQANLLSKFTKKVYILYDGDDAGRKAMKSAIPLLLSAGVEVYPVYLPEGYDPDEFIKEFGKEELRRLINSSGELFETLIKTARENLEEKTREFRYYLGFISDGVRRFALASEFHTKYKVPMEILLMKIEK 
+2wsi_A,306,MQLSKAAEMCYEITNSYLHIDQKSQIIASTQEAIRLTRKYLLSEIFVRWSPLNGEISFSYNGGKDCQVLLLLYLSCLWEYFFIKAQNSQFDFEFQSFPMQRLPTVFIDQEETFPTLENFVLETSERYCLSLYESQRQSGASVNMADAFRDFIKIYPETEAIVIGIRHTDPFGEALKPIQRTDSNWPDFMRLQPLLHWDLTNIWSFLLYSNEPICGLYGKGFTSIGGINNSLPNPHLRKDSNNPALHFEWEIIHAFGKDAEGERSSAINTSPISVVDKERFSKYHDNYYPGWYLVDDTLERAGRIKN +3bwu_D,125,DLYFNPRFLADDPQAVADLSRFENGQELPPGTYRVDIYLNNGYMATRDVTFNTGDSEQGIVPCLTRAQLASMGLNTASVAGMNLLADDACVPLTTMVQDATAHLDVGQQRLNLTIPQAFMSNRAR +1g6g_A,127,GENIVCRVICTTGQIPIRDLSADISQVLKEKRSIKKVWTFGRNPACDYHLGNISRLSNKHFQILLGEDGNLLLNDISTNGTWLNGQKVEKNSNQLLSQGDEITVGVGVESDILSLVIFINDKFKQCL +3dff_A,273,MPHDPGATRLLAISPHLDDAVLSFGAGLAQAAQDGANVLVYTVFAGAAQPPYSPAAQRMHTIWGLAPDDDAVLYRRKEDIAALDHLRVAHRHGRFLDSIYRKLPDGRWLTAHVEGRQKLAVNDHSPDSDHDLVGEVADDIRSIIDEFDPTLVVTCAAIGEHPDHEATRDAALFATHEKNVPVRLWEDLPYAVFKSGAVELPQGFRLGSADVSSVKPEMRSQKFQAVERYSSQMVLLNGSENNLFDRLDEHARQNAPHGGYGETTWPVVRSDDS +2bko_A,205,MEEVEEFKYEPKSVKEIFIEMKDTVELMVDLAYASLLFGDKEIAEEVLELEERIDLLNYQLMMHSVLAARNVKEAEQVITILQIANAIEDISNAAGDLAKMVLEGVELHPVIKETILEGEEIIGKIQVYPESVIVGKTLGELDLATNTGVWIIAVRRGKRWIFGPNENFKIRAGDVLIGRGTRTSIDHLKEIARGAIRVIGNERA +3ipj_A,95,SNANKYNKIANELIKIIGEDNIISITHCATRLRVMVKDREIINDKKVEKVDEVKGVFFTSGQYQIILGTGIVNKVYAEVEKMGLKTLSKKEQDEL +4z3x_A,653,MRYAETGYVLEVDLTKGSIERVATDPRDTELYLGGLGTNAKILWDRVPPEVEPFSPENLLIFAAGLLCGTPATGCNRTIVSTVSPQTKLMAFSMMGGFWAPELKYAGYDKIIFRGKSPELVYLYINNDKVEIRDASHLKGKGAIETAEIIKKELNEPRAQVAAIGKAGENRVFYASIEQGRSSASRGGIGAVMGDKGLKAVVVRGTKDLCVAKPEEYIGLCNEVLDYIKHREENPIPDVMPILAGLGSPQEMKVHDEKWHTENFNWGNARTRRKDFWTDEVSHAWEKTMDKARTRLISCYNCPMKCGATISMEGLPTYMMKCFTKLTYTMAAYSDLDFGLRIAQKATEYGLDGFSAPQVMAFAFELLEKGILKDSDFPGLPEGNEERFFYLLDKIVNRDGIGDILANGTYWAAQEIGNGAEDYAHNNIKKHEQLPLKLSMLNPIYYLMYCTGEKINITQIEGQFPQAPYPKLEQREAFVEDWIQVPDEKFKKIFLEWEPRGEKSMPNFPTVDMCCDIVDWQEMMHYIDDALGQCAGLSSFPLKPPYHIHNYPKFIAAGAGIEMDTEKLKKAAKRYRTLVRAFNIRRGMRRVDEQPPANHWKNRFPELEKELLDSYYKLKGWNDDGIPTKETLDDLGLGYVGDEFIKRGILSAG 
+4wyw_A,481,GLIVDTRDVEERVHVMRKTELAPTVAHGVFNPEFGPAALSNKDPRLNEGVVLDEVIFSKHKGDTKMSAEDKALFRRCAADYASRLHSVLGTANAPLSIYEAIKGVDGLDAMEPDTAPGLPWALQGKRRGALIDFENGTVGPEVEAALKLMEKREYKFACQTFLKDEIRPMEKVRAGKTRIVDVLPVEHILYTRMMIGRFCAQMHSNNGPQIGSAVGCNPDVDWQRFGTHFAQYRNVWDVDYSAFDANHCSDAMNIMFEEVFRTEFGFHPNAEWILKTLVNTEHAYENKRITVEGGMPSGCSATSIINTILNNIYVLYALRRHYEGVELDTYTMISYGDDIVVASDYDLDFEALKPHFKSLGQTITPADKSDKGFVLGHSITDVTFLKRHFHMDYGTGFYKPVMASKTLEAILSFARRGTIQEKLISVAGLAVHSGPDEYRRLFEPFQGLFEIPSYRSLYLRWVNAVCGDAAAALEHHHHHH +3ka5_A,65,EIVKTKRFAIKPMSEEEAVLEMELLGHNFFVFQNGDSNEVNVVYKRKDGNYGLIEPELEHHHHHH +3dnh_B,258,GHMLDVAPPVITPRGTKIEPSAGAPFEAVRVARDVLHTSRTAALATLDPVSGYPYTTATNIGIEPDGTPFFFAAGLTLHARNMETDARISVTLAPFGKGDALTLPRLTLVGRADRIGPDEVPLAIARYIARYPKAKLYLSLPDTRLYRLRTEGVQINGGPARNASNITPADLRTDLSGAEELMAAAESEATRLNAIKGEASRLAVLAGAKTGRWKITSIDPDGIDLASASDLARLWFAERVETLKQFEKALAQLLKGS +3vgp_A,164,GPDLTEDWKEALEWMRTSLEEQNYLNPYEKPEYSVMSWWDYGNWILYVSKKAVVANNFQAGAVDAAKFFTAKSEDEAIKIAKKRGVRYVVTADEITMKDANNTKFPAIMRIAGYNVDLMTEGEILNFFNHTVLYRLHMENAENLTHFRLVKEFGDVKIFEVVGS +3pt1_A,471,HMTIPGRFMTIDKGTFGEYTASTRWPIIIQNAIDDLSKHQETEKSNGTKFEQGEVIKKELKEFRQEIIDRVPLRPFTEEEIKIANVPLSFNEYLKKHPEVNWGAVEWLFSEVYLYRRVNVLFQRQCEWAKFDIFNRLKQSTFESSFYGVVELALRYENLLPQLREMKQNPGNEIDDILKVLFKEFIEISLWGNATDLSLLTNATLEDIKSIQGAKARAASESKIVVNDTEKAWEVLTKARADANSREIRVDFVLDNSGFELYADLMLAAFLLQSGLATKCIFHAKDIPYMVSDVMLKDFDILVHDLRDREFFPSGEPSTKESRALDLFAGEMEKFVSSGKIEFREDSFWTTELDYWNLDANETKYHGSILHKDLQKSNLVIFKGDLNYRKLTGDRKWPRTTKWETAIGPLATNGITSLSLRTCKADVQVALPEGLDAKLSQEWEKENPGRGSWWCCSGKWAVICFCSGIHK +3boe_A,210,SISPAQIAEALQGRGWDAEIVTDASMAGQLVDVRPEGILKCVDGRGSDNTRMGGPKMPGGIYAIAHNRGVTSIEGLKQITKEVASKGHLPSVHGDHSSDMLGCGFFKLWVTGRFDDMGYPRPQFDADQGANAVKDAGGIIEMHHGSHTEKVVYINLLANKTLEPNENDQRFIVDGWAADKFGLDVPKFLIAAAATVEMLGGPKNAKIVVP 
+7p46_A,282,KNLRDLEPGIHTDLEGRLTYGGYLRLDQLLSAQQPLSEPAHHDEMLFIIQSQTSELWLKLLAHELRAAIVHLQRDEVWQCRKVLARSKQVLRQLTEQWSVLETLTPSEYMGFRDVLGPSSGFQSLQYRYIEFLLGNKNPQMLQVFAYDPAGQARLREVLEAPSLYEEFLRYLARFGHAIPQQYQARDWTAAHVADDTLRPVFERIYENTDRYWREYSLCEDLVDVETQFQLWRFRHMRTVMRVIGFKRGTGGSSGVGFLQQALALTFFPELFDVRTSVGVDN +2z4u_A,261,VNTITFDVGNATINKYATFMESLRNEAKDPTLKCYGIPMLPDSNLTPKYVLVKLQDASSKTITLMLRRNNLYVMGYSDLYNGKCRYHIFNDISSTESTDVENTLCPNSNSREKKAINYNSQYSTLQNKAGVSSRSQVQLGIQILNSDIGKISGVSTFTDKTEAEFLLVAIQMVSEAARFKYIENQVKTNFNRAFNPNPKVLSLEENWGKISLAIHNAKNGALTSPLELKNADDTKWIVLRVDEIKPDMGLLNYVSGTCQTT +1d2t_A,231,LALVATGNDTTTKPDLYYLKNSEAINSLALLPPPPAVGSIAFLNDQAMYEQGRLLRNTERGKLAAEDANLSSGGVANAFSGAFGSPITEKDAPALHKLLTNMIEDAGDLATRSAKDHYMRIRPFAFYGVSTCNTTEQDKLSKNGSYPSGHTSIGWATALVLAEINPQRQNEILKRGYELGQSRVICGYHWQSDVDAARVVGSAVVATLHTNPAFQQQLQKAKAEFAQHQKK +5znj_A,567,MKQSKVFIPTMRDVPSEAEAQSHRLLLKSGLIKQSTSGIYSYLPLATRVLNNITAIVRQEMERIDSVEILMPALQQAELWEESGRWGAYGPELMRLQDRHGRQFALGPTHEELVTSIVRNELKSYKQLPMTLFQIQSKFRDEKRPRFGLLRGREFIMKDAYSFHADEASLDQTYQDMYQAYSRIFERVGINARPVVADSGAIGGSHTHEFMALSAIGEDTIVYSKESDYTANIEKAEVVYEPNHKHSTVQPLEKIETPNVKTAQELADFLGRPVDEIAKTMIFKVDGEYIMVLVRGHHEINDIKLKSYFGTDNIELATQDEIVNLVGANPGSLGPVIDKEIKIYADNFVQDLNNLVVGANEDGYHLINVNVGRDFNVDEYGDFRFILEGEKLSDGSGVAHFAEGIEVGQVFKLGTKYSESMNATFLDNQGKAQPLIMGCYGIGISRTLSAIVEQNHDDNGIVWPKSVTPFDLHLISINPKKDDQRELADALYAEFNTKFDVLYDDRQERAGVKFNDADLIGLPLRIVVGKRASEGIVEVKERLTGDSEEVHIDDLMTVITNKYDNLK +6jv8_A,76,LQLLHQKVEEQAAKYKHRVPKKCCYDGARENKYETCEQRVARVTIGPHCIRAFNECCTIADKIRKESHHKGMLLGR +3vej_B,41,SVDLTVPWDDIEALLKNNFENDQAAVRQVMERLQKGWSLAK +1dwk_A,156,MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGTGLAEAFVTAALLGQQALPADAARLVGAKLDLDEDSILLLQMIPLRGCIDDRIPTDPTMYRFYEMLQVYGTTLKALVHEKFGDGIISAINFKLDVKKVADPEGGERAVITLDGKYLPTKPF +6jpt_A,122,MEDTPLVISKQKTEVVCGVPTQVVCTAFSSHILVVVTQFGKMGTLVSLEPSSVASDVSKPVLTTKVLLGQDEPLIHVFAKNLVAFVSQEAGNRAVLLAVAVKDKSMEGLKALREVIRVCQVW 
+7aex_A,275,MHHHHHHSEINPAEFEQVNMVLQGFVETSVLPVLELSADESHIEFREHSRNAHTVVWKIISTSYQDELTVSLHITTGKLQIQGRPLSCYRVFTFNLAALLDLQGLEKVLIRQEDGKANIVQQEVARTYLQTVMADAYPHLHVTAEKLLVSGLCVKLAAPDLPDYCMLLYPELRTIEGVLKSKMSGLGMPVQQPAGFGTYFDKPAAHYILKPQFAATLRPEQINIISTAYTFFNVERHSLFHMETVVDASRMISDMARLMGKATRAWGIIKDLYIV +2p12_A,176,GHMAGASHDIHPPKVEYFDLRDHTNTDPKGFVRHVDHYRVEPWGLYMARTSDHPQFHYLESWLLPDLGLRASIFHYHPYHQRDQDHYVDIGTFTRGDDVWKSEDHYLDLVVRTGRDTELLDVDELMEAHTTGLLDTATAEQAILTATTAIDGIAAHGHDLGRWLASIGMPIDWRGS +2oy9_B,98,MSLKTTLPISLDWSTEEVIDVVHFFQAIEQAYDQGIAREDLLGKYRRFKEIVPSKSEEKQLFRAYEQENDVSCYQTIKKAREEMEEHIQMEGHHHHHH +4byz_A,194,YTMSAQVIQIGRQRFVGGLFWQSLSRRNELRAEAVELAKKLKFDLMVLRIDRGVAAAGYANTRDGFAPGHLSLGAMVSRAIALEGAFYNGRRQPAPNWLGAFALPDGRWAYFAVRDHAFMPNGDWVGSREEALERLHTDYAWGGWNVVIGEPELERQGFQNFQPKRLDDLLPRRGGRPRTERWWALRPVERRLS +3mmy_D,56,TGTTIKFNPPTGTDTMVKAGVSTNISTKHQCITAMKEYESKSLEELRLEDYQANRK +1dk8_A,147,GSASPTPPYLKWAESLHSLLDDQDGISLFRTFLKQEGCADLLDFWFACTGFRKLEPCDSNEEKRLKLARAIYRKYILDNNGIVSRQTKPATKSFIKGCIMKQLIDPAMFDQAQTEIQATMEENTYPSFLKSDIYLEYTRTGSESPKV +7jfl_C,47,SALQDLLRTLKSPSSPQQQQQVLNILKSNPQLMAAFIKQRTAKYVAN +3c2q_B,345,SNANIPEIENANLKPALKDSVLPDGFYSTTNHPTHVKVNDEWIEVANPKMDAVIVVYPEEKRAETKVIRKVKKGDFVLIGHNGIRVMPPEKSREAGQLFEFMNSEVSSEKPKEAIIKRIAKEMHEIREEYKKTGTGGIAIVGGPAIIHTGGGPALAKMVELGYIQAILAGNALATHDIESALYGTSLGVNIKTAKPVTGGHKHHIYAINAINDAGNIKNAVESGVLKEGIMYQCIKNNIPYVLAGSIRDDGPIPDVITDSMVAQDKMRTTVMDKKMVIMLSTLLHSVATGNLMPSYIKTVCVDIQPSTVTKLMDRGTSQAIGVVTDVGVFLVLLLKELERLELQE +2bw3_B,84,SHMQSRELKTVSADCKKEAIEKCAQWVVRDCRPFSAVSGSGFIDMIKFFIKVGAEYGDHVNVEELLPSPITLSRKVTSDAKEKA +6bwq_A,149,TADAVLMIEANLDDQTGEGLGYVMNQLLTAGAYDVFFTPIQMKKDRPATKLTVLGNVNDKDLLTKLILQETTTIGVRYQTWQRTIMQRHFLTVATPYGDVQVKVATYQDIEKKMPEYADCAQLAQQFHIPFRTVYQAALVAVDQLDEEA +5k6d_B,78,MWKVSERCLKGHGKFQADQEIGNGLATAKGQCKGTDSDQKKAGKCDKHCTGVCLGSGGSCGDGSSQKPNKEDCYCKSK +5eqz_A,146,MAHHHHHHYVEEKKEIDSLMEDVLALVNDSSGGKFKDYKDKINELKENLKDIGNAELKEKLLNLQNSFQDKLAAKLAALKAAKNTIENITDKDQDISKRKIWSEAKLVGVTVPLLGSNTSGNGDKMSKNAVEQIDKVIKFLEEGTN 
+3d4i_A,273,AMGSATISRRGILVIRHGERVDQVFGKSWLQQCTTADGKYYRPDLNFPRSLPRRSNGIKDFENDPPLSSCGIFQARLAGEALLDSGVRVTAVFASPALRCVQTAKHILEELKLEKKLKIRVEPGIFEWMKWEASKATLTFLTLEELKEANFNVDLDYRPALPRCSLMPAESYDQYVERCAVSMGQIINTCPQDMGITLIVSHSSALDSCTRPLLGLPPRECGDFAQLVRKIPSLGMCFCEENREDGKWDLVNPPVKTLTHGANSVFNWRNWIS +5n0n_A,410,GTSTQTKAGSLTIVGTGIESIGQMTLQALSYIEAAAKVFYCVIDPATEAFILTKNKNCVDLFQYYDNGKSRLNTYTQMSELMVREVRKGLDVVGVFYGHPGVFVNPSHRALAIAKSEGYRARMLPGVSAEDCLFADLCIDPSNPGCLTYEASDFLIRDRPVSIHSHLVLFQVGCVGIADFNFTGFDNNKFGVLVDRLEQEYGAEHPVVHYIAAMMPHQDPVTDKYTVAQLREPEIAKRVGGVSTFYIPPKARKASNLDIIRRLELLPAGQVPDKKARIYPANQWEPDVPEVEPYRPSDQAAIAQLADHAPPEQYQPLATSKAMSDVMTKLALDPKALADYKADHRAFAQSVPDLTPQERAALELGDSWAIRCAMKNMPSSLLDAARESGEEASQNGFPWVIVVGVIGVIG +5iz3_B,316,ELGDSLEEFLAKATTDKNLARLLVCMGEALRTIAFKVRTASCGATACTNTFGDEQLAVDMLADKLLFEALRHSHVCKYACSEEEPILQDMEGEGFSVAFDPLDGSSIVDTNFTVGTIFGVWPGDKLTGITGRDQAASAMGIYGPRTTYVVAINGFPGTHEFLLMDDGKWQHVKETTEIKEGKLFSPGNLRATFDNADYEKLINYYVSEKYTLRYTGGMVPDVNQIIVKERGIFTNVTSPTTKAKLRLLFEVAPLGLLIENAGGYSSDGKQSVLDKVVVNTDDRTQVAYGSRDEIIRFEETLYGDSRLKAELAATVV +2ov9_C,216,VSVGTHPTFENSPSGTVLTSPPDGSAVDRATDAARRVVDALLRTDRGNANLERVAEELNSIAGHLEEHAPAVAERLIDMWNGEGVTRHDPVTGPENALAPPVVLEGLSDGSVRGTVTLTIPYQGPPGHVHGGVSALLLDHVLGVANAWGGKAGMTAQLSTRYHRPTPLFEPLTLTGKLMSVDGRKITTAGDIRTADGQVCVSVEGLFVDKTVPRPR +3tbn_A,87,GSSHHHHHHTDPVIAQKAPYPVTVEAGKTYHWCACGRSKAQPFCDGSHKGTGLAPVAYTPDKAGTAYFCGCKASKAPPLCDGTHKTL +4ue8_B,38,GPHMTKLIYERAFMKNLRGSPLSQTPPSNVPSCLLRGT +3bl9_A,301,SAPVRLPFSGFRLQKVLRESARDKIIFLHGKVNEASGDGDGEDAVVILEKTPFQVEQVAQLLTGSPELQLQFSNDIYSTYHLFPPRQLNDVKTTVVYPATEKHLQKYLRQDLRLIRETGDDYRNITLPHLESQSLSIQWVYNILDKKAEADRIVFENPDPSDGFVLIPDLKWNQQQLDDLYLIAICHRRGIRSLRDLTPEHLPLLRNILHQGQEAILQRYRMKGDHLRVYLHYLPSYYHLHVHFTALGFEAPGSGVERAHLLAEVIENLECDPRHYQQRTLTFALRADDPLLKLLQEAQQS 
+2ozl_A,365,MRGSFANDATFEIKKCDLHRLEEGPPVTTVLTREDGLKYYRMMQTVRRMELKADQLYKQKIIRGFCHLCDGQEACCVGLEAGINPTDHLITAYRAHGFTFTRGLSVREILAELTGRKGGCAKGKGGSMHMYAKNFYGGNGIVGAQVPLGAGIALACKYNGKDEVCLTLYGDGAANQGQIFEAYNMAALWKLPCIFICENNRYGMGTSVERAAASTDYYKRGDFIPGLRVDGMDILCVREATRFAAAYCRSGKGPILMELQTYRYHGHEMSDPGVSYRTREEIQEVRSKSDPIMLLKDRMVNSNLASVEELKEIDVEVRKEIEDAAQFATADPEPPLEELGYHIYSSDPPFEVRGANQWIKFKSVS +5l87_A,69,GAMWHTHSEREKRVSNAVEFLLDSRVRRTPTSSKVHFLKSKGLSAEEICEAFTKVGQPKTLNEIKRILS +1xkg_A,312,RPSSIKTFEEYKKAFNKSYATFEDEEAARKNFLESVKYVQSNGGAINHLSDLSLDEFKNRFLMSAEAFEHLKTQFDLNAETNACSINGNAPAEIDLRQMRTVTPIRMQGGCGSAWAFSGVAATESAYLAYRDQSLDLAEQELVDCASQHGCHGDTIPRGIEYIQHNGVVQESYYRYVAREQSCRRPNAQRFGISNYCQIYPPNANKIREALAQTHSAIAVIIGIKDLDAFRHYDGRTIIQRDNGYQPNYHAVNIVGYSNAQGVDYWIVRNSWDTNWGDNGYGYFAANIDLMMIEEYPYVVILGQTGHHHHHH +1sx3_A,525,AAKDVKFGNDAGVKMLRGVNVLADAVKVTLGPKGRNVVLDKSFGAPTITKDGVSVAREIELEDKFENMGAQMVKEVASKANDAAGDGTTTATVLAQAIITEGLKAVAAGMNPMDLKRGIDKAVTVAVEELKALSVPCSDSKAIAQVGTISANSDETVGKLIAEAMDKVGKEGVITVEDGTGLQDELDVVEGMQFDRGYLSPYFINKPETGAVELESPFILLADKKISNIREMLPVLEAVAKAGKPLLIIAEDVEGEALATLVVNTMRGIVKVAAVKAPGFGDRRKAMLQDIATLTGGTVISEEIGMELEKATLEDLGQAKRVVINKDTTTIIDGVGEEAAIQGRVAQIRQQIEEATSDYDREKLQERVAKLAGGVAVIKVGAATEVEMKEKKARVEDALHATRAAVEEGVVAGGGVALIRVASKLADLRGQNEDQNVGIKVALRAMEAPLRQIVLNCGEEPSVVANTVKGGDGNYGYNAATEEYGNMIDMGILDPTKVTRSALQYAASVAGLMITTECMVTDLPK +5j2l_B,81,GSHMGSEDYKLREAQRELDKQRKDTEEIRKRLKEIQRLTDERTSTADELIKELREIIRRLQEQSEKLREIIEELEKIIRKR +2fef_B,294,MHAFPLTLDNRLAEALPLWRNLARTDRAPRRNIDLADWKADWRELIAALDRFSRSHGYRQPFAAQGHAALENAWAWGQAAENASTLLLKAIDRGLAGAELRSIYLETAALWLDYSRLLGAARDSLREQGEVDFETAPALAPRTGQYPFALQLLAMGVLLDAQELIPALVEEVLQFDTDRLLDYLGAAALGLTSASEETFHPRPFGQLRAFFEEADGSDAQALAPYLQSQYREFFQLSPKAQKKTRRLTGPYAWGWWAMEVSALGVLYGWDDGVLRASPHYLGDLVDYARARGDA 
+7p41_D,448,MRGSMQQVGTVAQLWIYPVKSCKGVPVSEAECTAMGLRSGNLRDRFWLVINQEGNMVTARQEPRLVLISLTCDGDTLTLSAMNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSELDKAIGRNCNGVITKDEAEKLFNQDVDAAVRGILRNAKLKPVYDSLDAVRRCALINMVFQMGETGVAGFTNSLRMLQQKRWDEAAVNLAKSRWYNQTPNRAKRVITTFRTGTWDAYTKDLLLPIKTPTTNAVHKCRVHGLEIEGRDCGEATAQWITSFLKSQPYRLVHFEPHMRPRRPHQIADLFRPKDQIAYSDTSPFLILSEASLADLNSRLEKKVKATNFRPNIVISGCDVYAEDSWDELLIGDVELKRVMACSRCILTTVDPDTGVMSRKEPLETLKSYRQCDPSERKLYGKSPLFGQYFVLENPGTIKVGDPVYLLG +2bbr_A,195,MSDSKEVPSLPFLRHLLEELDSHEDSLLLFLCHDAAPGCTTVTQALCSLSQQRKLTLAALVEMLYVLQRMDLLKSRFGLSKEGAEQLLGTSFLTRYRKLMVCVGEELDSSELRALRLFACNLNPSLSTALSESSRFVELVLALENVGLVSPSSVSVLADMLRTLRRLDLCQQLVEYEQQEQARYRYCLEHHHHHH +6y2x_A,235,GSEPEPEQVIKNYTEELKVPPDEDCIICMEKLSTASGYSDVTDSKAIGSLAVGHLTKCSHAFHLLCLLAMYCNGNKDGSLQCPSCKTIYGEKTGTQPQGKMEVLRFQMSLPGHEDCGTILIVYSIPHGIQGPEHPNPGKPFTARGFPRQCYLPDNAQGRKVLELLKVAWKRRLIFTVGTSSTTGETDTVVWNEIHHKTEMDRNITGHGYPDPNYLQNVLAELAAQGVTEDCLEQQ +5a0n_A,220,GAMGGFPNDAKGISGNGKYYSLGQIEKLYSNQFATYNNLTVITSDTHENSDNFAFCLANGKRFPSFTDEKPKGIYTLVKDINKEQYTKLLKENHKWSSIPNLNQAWDTFSRLSYMYLKDPTDIVKRAWGTDLNTARTYFHQVIQYEIWRYTDGMRVSSDTNVYIYEKFSPQQKKALEMIRTDLYNFTVPYENLEYRFYKPDWVFGLGFQALATVRWKIEP +2vri_A,174,KAGSAAAPFTKPFAVYKNVKFYLGDISHLVNCVSFDFVVNAANENLLHGGGVARAIDILTEGQLQSLSKDYISSNGPLKVGAGVMLECEKFNVFNVVGPRTGKHEHSLLVEAYNSILFENGIPLMPLLSCGIFGVRIENSLKALFSCDINKPLQVFVYSSNEEQAVLKFLDGLD +2j7z_A,68,KPVSLSYRCPCRFFESHVARANVKHLKILNTPNCALQIVARLKNNNRQVCIDPKLKWIQEYLEKALNK +3a1g_A,80,SQRGVLEDEQMYQRCCNLFEKFFPSSSYRRPVGISSMVEAMVSRARIDARIDFESGRIKKEEFTEIMKICSTIEELRRQK 
+6oz1_A,643,QGMTETISTPAATTTPDLADQQELERRVAQVVSNDPQLQALLPDDAVSGAVNEPGLTLIELIRRLLEGYGDRPALGQRAVELVTDEHGATTVALKTEFVTTSYRELWNRAEAIAAAWYAQGIRDGDFVAQLGFTSTDFASLDVAGLRLGTVSVPLQTGASVQQRNAILEETQPTVFAASVEYLEAAVDSVLATPSVQLLSVFDYHPEADAHRAALSAVRDRLETAGRTITIDSLGDAIARGRELPAAPQPSEDPDALRLLIYTSGSTGTPKGAMYPQWLVANLWQKKWLTTTVIPSVGVNFMPMSHLAGRLTLMGTLSGGGTAYYIASSDLSTFFEDIALIRPTEVLFVPRVVEMVFQRYQAELDRSLAPGESNAEIAEKIKVRIREEDFGGRVLNAGSGSAPLSPEMNDFMESLLQVAMLDGYGSTEAGAVWRDGVLQRPPVTEYKLVDVPELGYFTTDSPHPRGELRIKSETMFPGYYKRPETTADVFDEDGFYMTGDVVAELGPDHLKYLDRVKNVLKLAQGEFVAVSKLEAAYTGSPLVRQIFVYGNSERSYLLAVVVPTPEALERYADSPDALKPLIQDSLQQVAKGAELQSYEIPRDFIVETVPFTVESGLLSDARKLLRPKLKEHYGERLEALYAD +1ijx_C,127,GSAACEPVRIPLCKSLPWEMTKMPNHLHHSTQANAILAMEQFEGLLGTHCSPDLLFFLCAMYAPICTIDFQHEPIKPCKSVCERARQGCEPILIKYRHSWPESLACDELPVYDRGVCISPEAIVTAD +1q2h_A,69,GSHMPDWRQFCELHAQAAAVDFAHKFCRFLRDNPAYDTPDAGASFSRHFAANFLDVFGEEVRRVLVAGP +5dje_B,140,GSSCDFVADVPPPKKGTDYDFYEAWGPVFEAEARFSKKTPIPSLGNMDSSKKEVEQFYAFWHRFDSWRTFEFLDEDVPDDSSNRDHKRYIERKNKAARDKKKTADMARLVKLVERAVSEDPRIKMFKEEEKKEKERRKWE +4zav_A,209,MSGPERITLAMTGASGAQYGLRLLDCLVQEEREVHFLISKAAQLVMATETDVALPAKPQAMQAFLTEYCGAAAGQIRVFGQNDWMAPPASGSSAPNAMVICPCSTGTLSAVATGACNNLIERAADVALKERRPLVLVPREAPFSSIHLENMLKLSNLGAVILPAAPGFYHQPQSVEDLVDFVVARILNTLGIPQDMLPRWGEQHLVSDE +2q7x_A,326,GMRKPKITVIGGGTGSPVILKSLREKDVEIAAIVTVADDGGSSGELRKNMQQLTPPGDLRNVLVAMSDMPKFYEKVFQYRFSEDAGAFAGHPLGNLIIAGLSEMQGSTYNAMQLLSKFFHTTGKIYPSSDHPLTLHAVFQDGTEVAGESHIVDHRGIIDNVYVTNALNDDTPLASRRVVQTILESDMIVLGPGSLFTSILPNIVIKEIGRALLETKAEIAYVCNIMTQRGETEHFTDSDHVEVLHRHLGRPFIDTVLVNIEKVPQEYMNSNRFDEYLVQVEHDFVGLCKQVSRVISSNFLRLENGGAFHDGDLIVDELMRIIQVKK +5n6y_C,113,MSQSHLDDLFAYVEERCLWQFFSRTWDREENIEGVLNQVGRLLTGQEPLRGTPQERLFYADALAMANDVRERFPWASQVNKEEIEFLLDGLKSRLVDVTITRSTNRELNHHLY +1h02_B,70,GPETLCGAELVDALQFVCGDRGFYFNKPTGYGSSSRRAPQTGIVDECCFRSCDLRRLEMYCAPLKPAKSA 
+1q74_C,303,MSETPRLLFVHAHPDDESLSNGATIAHYTSRGAQVHVVTCTLGEEGEVIGDRWAQLTADHADQLGGYRIGELTAALRALGVSAPIYLGGAGRWRDSGMAGTDQRSQRRFVDADPRQTVGALVAIIRELRPHVVVTYDPNGGYGHPDHVHTHTVTTAAVAAAGVGSGTADHPGDPWTVPKFYWTVLGLSALISGARALVPDDLRPEWVLPRADEIAFGYSDDGIDAVVEADEQARAAKVAALAAHATQVVVGPTGRAAALSNNLALPILADEHYVLAGGSAGARDERGWETDLLAGLGFTASGT +5c8q_B,154,GSHGNYAGNFSGSSRDICLDGARLRAECRRGDGGYSTSVIDLNRYLSNDNGHFRWVSTATVTVQQGDTLRDIGRRFDCDFHEIARRNNIQNEDLIYPGQVLQVGGNFWDSARDVRLVDGGKVLEAELRYSGGWNRSRIYLDEHIGNRNGELIHC +4m37_A,343,GYRPPDAFVNRIDRNIPVPARLRHTPVSLIEAVNDFHYAMMNDEERNNFYYEVLKKHVTPETGVLEIGAGSGLLSLMAAKLGAKWVVAVEGSEELAKLARENIRANNMEHQVKVLHMMSTELKSKHLPEPPDVLLSEIFGTMMLGESALDYVVDVRNRLLKPTTKIIPQFGTQYAVPIECDALHRISSVSGWRDLDLKHMMTLQDTVSIVFAKHYGIRMNSVNFRRLSDPIELFRVDFSSSNRNDIPRRKHFDVVAKESGTAHAMLFYWKVTDDEFVMSTDPEDTVNNFPRDMQWGQALQLLDASNGPLPTPVVFTEGKNYNFECNFSGDRVILHMQLCPESG +5n0s_A,410,GTSTQTKAGSLTIVGTGIESIGQMTLQALSYIEAAAKVFYCVIDPATEAFILTKNKNCVDLYQYYDNGKSRLNTYTQMSELMVREVRKGLDVVGVFAGHPGVFVNPSHRALAIAKSEGYRARMLPGVSAEDCLFADLCIDPSNPGCLTYEASDFLIRDRPVSIHSHLVLFQVGCVGIADFNFTGFDNNKFGVLVDRLEQEYGAEHPVVHYIAAMMPHQDPVTDKYTVAQLREPEIAKRVGGVSTFYIPPKARKASNLDIIRRLELLPAGQVPDKKARIYPANQWEPDVPEVEPYRPSDQAAIAQLADHAPPEQYQPLATSKAMSDVMTKLALDPKALADYKADHRAFAQSVPDLTPQERAALELGDSWAIRCAMKNMPSSLLDAARESGEEASQNGFPWVIVVGVIGVIG +2d1l_A,253,AGHMEAVIEKECSALGGLFQTIISDMKGSYPVWEDFINKAGKLQSQLRTTVVAAAAFLDAFQKVADMATNTRGGTREIGSALTRMCMRHRSIEAKLRQFSSALIDCLINPLQEQMEEWKKVANQLDKDHAKEYKKARQEIKKKSSDTLKLQKKAKKVDAQGRGDIQPQLDSALQDVNDKYLLLEETEKQAVRKALIEERGRFCTFISMLRPVIEEEISMLGEITHLQTISEDLKSLTMDPHKLPSSSEQVILD 
+3dan_A,473,MDPSSKPLREIPGSYGIPFFQPIKDRLEYFYGTGGRDEYFRSRMQKYQSTVFRANMPPGPFVSSNPKVIVLLDAKSFPILFDVSKVEKKDLFTGTYMPSTKLTGGYRVLSYLDPSEPRHAQLKNLLFFMLKNSSNRVIPQFETTYTELFEGLEAELAKNGKAAFNDVGEQAAFRFLGRAYFNSNPEETKLGTSAPTLISSWVLFNLAPTLDLGLPWFLQEPLLHTFRLPAFLIKSTYNKLYDYFQSVATPVMEQAEKLGVPKDEAVHNILFAVCFNTFGGVKILFPNTLKWIGLAGENLHTQLAEEIRGAIKSYGDGNVTLEAIEQMPLTKSVVYESLRIEPPVPPQYGKAKSNFTIESHDATFEVKKGEMLFGYQPFATKDPKVFDRPEEYVPDRFVGDGEALLKYVWWSNGPETESPTVENKQCAGKDFVVLITRLFVIELFRRYDSFEIELGESPLGAAVTLTFLKRASI +1ys1_X,320,ADNYAATRYPIILVHGLTGTDKYAGVLEYWYGIQEDLQQRGATVYVANLSGFQSDDGPNGRGEQLLAYVKTVLAATGATKVNLVGHSQGGLTSRYVAAVAPDLVASVTTIGTPHRGSEFADFVQGVLAYDPTGLSSTVIAAFVNVFGILTSSSNNTNQDALAALKTLTTAQAATYNQNYPSAGLGAPGSCQTGAPTETVGGNTHLLYSWAGTAIQPTISVGGVTGATDTSTIPLVDPANALDPSTLALFGTGTVMVNRGSGQNDGVVSKCSALYGQVLSTSYKWNHLDEINQLLGVRGANAEDPVAVIRTHANRLKLAGV +2nug_B,221,MKMLEQLEKKLGYTFKDKSLLEKALTHVSYSKKEHYETLEFLGDALVNFFIVDLLVQYSPNKREGFLSPLKAYLISEEFFNLLAQKLELHKFIRIKRGKINETIIGDVFEALWAAVYIDSGRDANFTRELFYKLFKEDILSAIKEGRVKKDYKTILQEITQKRWKERPEYRLISVEGPHHKKKFIVEAKIKEYRTLGEGKSKKEAEQRAAEELIKLLEESE +3rmq_A,116,SNAMQRYLWQQADGKRHVYDTARHRVQAGRPFTALCGETVTPQTERGDLTAGLWFDGECPVCTIALAKALGWPMREISDLAHRFDWSPALITRLAEVLHCSFGEVVELTGARMVDA +1rk8_B,147,MSTEDFYLRYYVGHKGKFGHEFLEFEFRPDGKLRYANNSNYKNDTMIRKEAFVHQSVMEELKRIIIDSEIMQEDDLPWPPPDRVGRQELEIVIGDEHISFTTSKTGSLVDVNRSKDPEGLRCFYYLVQDLKCLVFSLIGLHFKIKPI +1msk_A,331,KKPRTPPVTLEAARDNDFAFDWQAYTPPVAHRLGVQEVEASIETLRNYIDWTPFFMTWSLAGKYPRILEDEVVGVEAQRLFKDANDMLDKLSAEKTLNPRGVVGLFPANRVGDDIEIYRDETRTHVINVSHHLRQQTEKTGFANYCLADFVAPKLSGKADYIGAFAVTGGLEEDALADAFEAQHDDYNKIMVKALADRLAEAFAEYLHERVRKVYWGYAPNENLSNEELIRENYQGIRPAPGYPACPEHTEKATIWELLEVEKHTGMKLTESFAMWPGASVSGWYFSHPDSKYYAVAQIQRDQVEDYARRKGMSVTEVERWLAPNLGYDAD +2p58_A,65,TQLEEQLHNVETVRSITMQLEMALTKLKKDMMRGGDAKQYQVWQRESKALESAIAIIHYVAGDLK +3zoq_C,56,MVQNDFVDSYDVTMLLQDDDGKQYYEYHKGLSLSDFEVLYGNTADEIIKLRLDKVL 
+1v0w_A,506,ADSATPHLDAVEQTLRQVSPGLEGDVWERTSGNKLDGSAADPSDWLLQTPGCWGDDKCADRVGTKRLLAKMTENIGNATRTVDISTLAPFPNGAFQDAIVAGLKESAAKGNKLKVRILVGAAPVYHMNVIPSKYRDELTAKLGKAAENITLNVASMTTSKTAFSWNHSKILVVDGQSALTGGINSWKDDYLDTTHPVSDVDLALTGPAAGSAGRYLDTLWTWTCQNKSNIASVWFAASGNAGCMPTMHKDTNPKASPATGNVPVIAVGGLGVGIKDVDPKSTFRPDLPTASDTKCVVGLHDNTNADRDYDTVNPEESALRALVASAKGHIEISQQDLNATCPPLPRYDIRLYDALAAKMAAGVKVRIVVSDPANRGAVGSGGYSQIKSLSEISDTLRNRLANITGGQQAAKTAMCSNLQLATFRSSPNGKWADGHPYAQHHKLVSVDSSTFYIGSKNLYPSWLQDFGYIVESPEAAKQLDAKLLDPQWKYSQETATVDYARGICNA +3ho6_A,267,SEDNGVDFNKNTALDKNYLLNNKIPSNNVEEAGSKNYVHYIIQLQGDDISYEATCNLFSKNPKNSIIIQRNMNESAKSYFLSDDGESILELNKYRIPERLKNKEKVKVTFIGHGKDEFNTSEFARLSVDSLSNEISSFLDTIKLDISPKNVEVNLLGCNMFSYDFNVEETYPGKLLLSIMDKITSTLPDVNKNSITIGANQYEVRINSEGRKELLAHSGKWINKEEAIMSDLSSKEYIFFDSIDNKLKAKSKNIPGLASISEDIKTL +5opz_A,137,GPLGSLNDIEEIRFTARSEENLRGVHPDLVRVIRLALRYSLVPFSVSEGLRSMARQREMVRAGSSQTLRSRHLTGHAVDVVAMPAGVVSWEWDYYAQIAVAVRRAARECGIIVEWGGEWKTLKDGPHFQLTFRDYPA +5noh_A,103,ALGVVGVLESYIGSINNITKQSACVAMSKLLTELNSDDIKKLRDNEELNSPKIRVYNTVISYIESNRKNNKQTIHLLKRLPADVLKKTIKNTLDIHKSITINN +1u84_A,90,SNAMDGQQLNRLLLEWIGAWDPFGLGKDAYDVEAASVLQAVYETEDARTLAARIQSIYEFAFDEPIPFPHCLKLARRLLELKQAASCPLP +3c8w_C,255,GMDKYLSANSLEGVIDNEFSMPAPRWLNTYPAGPYRFINREFFIIAYETDPDLLQAILPPDMELLEPVVKFEFIRMPDSTGFGDYTESGQVVPVRYKGEEGGFTISMFLDCHAPIAGGREIWGFPKKLAKPKLFVEEDTLIGILKYGSIDIAIATMGYKHRPLDAEKVLESVKKPVFLLKNIPNVDGTPLVNQLTKTYLTDITVKGAWTGPGSLELHPHALAPISNLYIKKIVSVSHFITDLTLPYGKVVADYLA +4cdp_A,354,MGSSHHHHHHGSMNHYTRWLELKEQNPGKYARDIAGLMNIREAELAFARVTHDAWRMHGDIREILAALESVGETKCICRNEYAVHEQVGTFTNQHLNGHAGLILNPRALDLRLFLNQWASVFHIKENTARGERQSIQFFDHQGDALLKVYATDNTDMAAWSELLARFITDENTPLELKAVDAPVVQTRADATVVEQEWRAMTDVHQFFTLLKRHNLTRQQAFNLVADDLACKVSNSALAQILESAQQDGNEIMVFVGNRGCVQIFTGVVEKVVPMKGWLNIFNPTFTLHLLEESIAEAWVTRKPTSDGYVTSLELFAHDGTQIAQLYGQRTEGEQEQAQWRKQIASLIPEGVAA +4jmu_A,112,HMGARASVLSGGELDKWEKIRLRPGGKKQYKLKHIVWASRELERFAVNPGLLETSEGCRQILGQLQPSLQTGSEELRSLYNTIAVLYCVHQRIDVKDTKEALDKIEEEQNKS diff --git a/models/__pycache__/edge_embedder.cpython-310.pyc 
b/models/__pycache__/edge_embedder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce8d099a4e79827c586c307b3378c588694d6b8e Binary files /dev/null and b/models/__pycache__/edge_embedder.cpython-310.pyc differ diff --git a/models/__pycache__/flow_model.cpython-310.pyc b/models/__pycache__/flow_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b0e49814dabcc5ade5c7e7260201151c1ec8730 Binary files /dev/null and b/models/__pycache__/flow_model.cpython-310.pyc differ diff --git a/models/__pycache__/flow_module.cpython-310.pyc b/models/__pycache__/flow_module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71f640fb0eaab00144ce0d812d2c8649f0f9b94b Binary files /dev/null and b/models/__pycache__/flow_module.cpython-310.pyc differ diff --git a/models/__pycache__/ipa_pytorch.cpython-310.pyc b/models/__pycache__/ipa_pytorch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccbc0db01f51cda2d42c5e6d143f452dc5156e8d Binary files /dev/null and b/models/__pycache__/ipa_pytorch.cpython-310.pyc differ diff --git a/models/__pycache__/node_embedder.cpython-310.pyc b/models/__pycache__/node_embedder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cdbbab2253225fba49d91dc456d6059180e5104 Binary files /dev/null and b/models/__pycache__/node_embedder.cpython-310.pyc differ diff --git a/models/__pycache__/utils.cpython-310.pyc b/models/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3c9552d9aab93864bff0a7e815ad9823448106b Binary files /dev/null and b/models/__pycache__/utils.cpython-310.pyc differ diff --git a/models/add_module/Attention_module.py b/models/add_module/Attention_module.py new file mode 100644 index 0000000000000000000000000000000000000000..aa16ca025fbd7815c574c1eccf184373b817a43b --- /dev/null +++ b/models/add_module/Attention_module.py 
class FeedForwardLayer(nn.Module):
    """Pre-normalised two-layer feed-forward block.

    The input is layer-normalised, expanded to ``d_model * r_ff`` features,
    passed through ReLU and dropout, and projected back to ``d_model``.
    No residual sum is performed inside this module.

    Args:
        d_model: Size of the last (feature) axis of the input and output.
        r_ff: Integer expansion ratio of the hidden layer.
        p_drop: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, r_ff, p_drop=0.1):
        super().__init__()
        d_inner = d_model * r_ff
        self.norm = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, d_inner)
        self.dropout = nn.Dropout(p_drop)
        self.linear2 = nn.Linear(d_inner, d_model)
        # No custom parameter initialisation is applied here; the layers keep
        # whatever nn.Linear / nn.LayerNorm set up on construction.

    def forward(self, src):
        """Return the feed-forward transform of ``src`` (same shape as ``src``)."""
        hidden = self.linear1(self.norm(src))
        hidden = F.relu_(hidden)  # in-place on the freshly created activation
        hidden = self.dropout(hidden)
        return self.linear2(hidden)
class AttentionWithBias(nn.Module):
    """Gated multi-head self-attention with an additive pairwise bias.

    Args:
        d_in: Feature size of the sequence input ``x``.
        d_bias: Feature size of the pairwise ``bias`` input.
        n_head: Number of attention heads.
        d_hidden: Per-head feature size.
    """

    def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32):
        super().__init__()
        inner = n_head * d_hidden
        self.norm_in = nn.LayerNorm(d_in)
        self.norm_bias = nn.LayerNorm(d_bias)

        self.to_q = nn.Linear(d_in, inner, bias=False)
        self.to_k = nn.Linear(d_in, inner, bias=False)
        self.to_v = nn.Linear(d_in, inner, bias=False)
        self.to_b = nn.Linear(d_bias, n_head, bias=False)
        self.to_g = nn.Linear(d_in, inner)
        self.to_out = nn.Linear(inner, d_in)

        self.scaling = 1 / math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

    def forward(self, x, bias):
        """Attend over the second axis of ``x``.

        Args:
            x: Tensor of shape (B, L, d_in).
            bias: Tensor whose projection ``to_b`` yields (B, L, L, n_head).

        Returns:
            Tensor of shape (B, L, d_in).
        """
        n_batch, n_res = x.shape[:2]
        head_shape = (n_batch, n_res, self.h, self.dim)

        feats = self.norm_in(x)
        pair_bias = self.to_b(self.norm_bias(bias))  # (B, L, L, h)

        q = self.to_q(feats).reshape(head_shape)
        k = self.to_k(feats).reshape(head_shape) * self.scaling  # scale applied on the key side
        v = self.to_v(feats).reshape(head_shape)
        gate = torch.sigmoid(self.to_g(feats))

        logits = einsum('bqhd,bkhd->bqkh', q, k) + pair_bias
        weights = F.softmax(logits, dim=-2)  # normalise over the key axis

        mixed = einsum('bqkh,bkhd->bqhd', weights, v).reshape(n_batch, n_res, -1)
        return self.to_out(gate * mixed)
class ColAttention(nn.Module):
    """Gated multi-head self-attention over a (B, L, d_msa) tensor.

    The attention logits are laid out as (B, q, k, h). The softmax is taken
    over axis ``-3`` of that tensor, and the weighted sum contracts that same
    axis against the value sequence axis.

    Args:
        d_msa: Feature size of the input.
        n_head: Number of attention heads.
        d_hidden: Per-head feature size.
    """

    def __init__(self, d_msa=256, n_head=8, d_hidden=32):
        super().__init__()
        inner = n_head * d_hidden
        self.norm_msa = nn.LayerNorm(d_msa)

        self.to_q = nn.Linear(d_msa, inner, bias=False)
        self.to_k = nn.Linear(d_msa, inner, bias=False)
        self.to_v = nn.Linear(d_msa, inner, bias=False)
        self.to_g = nn.Linear(d_msa, inner)
        self.to_out = nn.Linear(inner, d_msa)

        self.scaling = 1 / math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

    def forward(self, msa, mask=None):
        """Apply the attention.

        Args:
            msa: Tensor of shape (B, L, d_msa).
            mask: Optional (B, L) tensor; it is cast to float32 and its outer
                product with itself selects which (q, k) logits are kept.

        Returns:
            Tensor of shape (B, L, d_msa).
        """
        n_batch, n_res = msa.shape[:2]
        head_shape = (n_batch, n_res, self.h, self.dim)

        feats = self.norm_msa(msa)
        q = self.to_q(feats).reshape(head_shape) * self.scaling  # (B, L, H, D)
        k = self.to_k(feats).reshape(head_shape)
        v = self.to_v(feats).reshape(head_shape)
        gate = torch.sigmoid(self.to_g(feats))

        logits = einsum('bqhd,bkhd->bqkh', q, k)  # (B, L, L, H)

        if mask is not None:
            mask_f = mask.type(torch.float32)
            pair_mask = torch.matmul(mask_f.unsqueeze(2), mask_f.unsqueeze(1))[..., None]
            # Kept entries retain their logit; dropped entries become -1e9.
            logits = logits * pair_mask - 1e9 * (1 - pair_mask)

        weights = F.softmax(logits, dim=-3)  # (B, L, L, H)

        mixed = einsum('bkqh,bkhd->bqhd', weights, v).reshape(n_batch, n_res, -1)  # (B, L, H*D)
        return self.to_out(gate * mixed)
nn.init.xavier_uniform_(self.to_v.weight) + +# # gating: zero weights, one biases (mostly open gate at the begining) +# nn.init.zeros_(self.to_g.weight) +# nn.init.ones_(self.to_g.bias) + +# # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining +# nn.init.zeros_(self.to_out.weight) +# nn.init.zeros_(self.to_out.bias) + +# def forward(self, msa): +# B, N, L = msa.shape[:3] +# # +# msa = self.norm_msa(msa) +# # +# query = self.to_q(msa).reshape(B, N, L, self.h, self.dim) +# query = query.mean(dim=1) # (B, L, h, dim) +# key = self.to_k(msa) # (B, N, L, dim) +# value = self.to_v(msa) # (B, N, L, dim) +# gate = torch.sigmoid(self.to_g(msa)) # (B, N, L, h*dim) +# # +# query = query * self.scaling +# attn = einsum('bihd,bkid->bihk', query, key) # (B, L, h, N) +# attn = F.softmax(attn, dim=-1) +# # +# out = einsum('bihk,bkid->bihd', attn, value).reshape(B, 1, L, -1) # (B, 1, L, h*dim) +# out = gate * out # (B, N, L, h*dim) +# # +# out = self.to_out(out) +# return out + +# Instead of triangle attention, use Tied axail attention with bias from coordinates..? 
class BiasedAxialAttention(nn.Module):
    """Tied axial attention over a pair tensor with an additive bias.

    Queries and keys are contracted over the first pair axis as well as the
    feature axis ("tied" attention), giving one (L, L) attention map per head
    that is shared by every slice along that axis.

    Args:
        d_pair: Feature size of the pair input.
        d_bias: Feature size of the bias input.
        n_head: Number of attention heads.
        d_hidden: Per-head feature size.
        p_drop: Accepted for interface compatibility; no dropout layer is
            created by this module.
        is_row: When true, the two sequence axes of ``pair`` and ``bias`` are
            swapped before attention and the output is swapped back.
    """

    def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
        super().__init__()

        self.is_row = is_row
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_bias = nn.LayerNorm(d_bias)

        self.to_q = nn.Linear(d_pair, n_head * d_hidden, bias=False)
        self.to_k = nn.Linear(d_pair, n_head * d_hidden, bias=False)
        self.to_v = nn.Linear(d_pair, n_head * d_hidden, bias=False)
        self.to_b = nn.Linear(d_bias, n_head, bias=False)
        self.to_g = nn.Linear(d_pair, n_head * d_hidden)
        self.to_out = nn.Linear(n_head * d_hidden, d_pair)

        self.scaling = 1 / math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

    def forward(self, pair, bias, mask=None):
        """Apply tied axial attention.

        Args:
            pair: (B, L, L, d_pair)
            bias: (B, L, L, d_bias)
            mask: Optional (B, L); entries equal to 1 mark valid positions and
                entries equal to 0 mark positions to exclude as attention keys.

        Returns:
            Tensor of shape (B, L, L, d_pair).
        """
        B, L = pair.shape[:2]

        if self.is_row:
            pair = pair.permute(0, 2, 1, 3)
            bias = bias.permute(0, 2, 1, 3)

        pair = self.norm_pair(pair)
        bias = self.norm_bias(bias)

        query = self.to_q(pair).reshape(B, L, L, self.h, self.dim)  # (B, L, L, h, dim)
        key = self.to_k(pair).reshape(B, L, L, self.h, self.dim)
        value = self.to_v(pair).reshape(B, L, L, self.h, self.dim)
        bias = self.to_b(bias)  # (B, L, L, h)
        gate = torch.sigmoid(self.to_g(pair))  # (B, L, L, h*dim)

        query = query * self.scaling
        key = key / math.sqrt(L)  # normalize for tied attention
        attn = einsum('bnihk,bnjhk->bijh', query, key)  # tied attention (B, L, L, h)
        attn = attn + bias  # apply bias
        if mask is not None:
            # 0 for valid keys, -1e9 for masked keys. The previous factor of
            # 1e-9 shifted masked logits by a negligible amount, so padded
            # positions were still attended to. -1e9 matches the masking
            # constant used by the other attention modules in this file.
            mask_bias = 1e9 * (mask.type(torch.float) - 1)  # (B, L)
            attn = attn + mask_bias.unsqueeze(1).unsqueeze(-1)  # broadcast over queries and heads

        attn = F.softmax(attn, dim=-2)  # (B, L, L, h), normalised over keys

        out = einsum('bijh,bkjhd->bikhd', attn, value).reshape(B, L, L, -1)  # (B, L, L, h*dim)
        out = gate * out

        out = self.to_out(out)
        if self.is_row:
            out = out.permute(0, 2, 1, 3)
        return out
class GaussianSmearing(nn.Module):
    """Expand scalar distances into Gaussian basis features.

    Args:
        start: Lowest centre when ``fixed_offset`` is false.
        stop: Highest centre when ``fixed_offset`` is false.
        num_gaussians: Number of evenly spaced centres when ``fixed_offset``
            is false.
        fixed_offset: When true, a hand-picked list of 20 centres is used and
            ``start`` / ``stop`` / ``num_gaussians`` are only stored for
            ``__repr__``.
    """

    def __init__(self, start=0.0, stop=5.0, num_gaussians=50, fixed_offset=True):
        super().__init__()
        self.start, self.stop, self.num_gaussians = start, stop, num_gaussians

        if not fixed_offset:
            centres = torch.linspace(start, stop, num_gaussians)
        else:
            # customized offset
            centres = torch.tensor(
                [0, 1, 1.25, 1.5, 1.75, 2, 2.25, 2.5, 2.75, 3, 3.5, 4, 4.5, 5, 5.5, 6, 7, 8, 9, 10]
            )

        # Width is taken from the gap between the first two centres only.
        first_gap = (centres[1] - centres[0]).item()
        self.coeff = -0.5 / first_gap ** 2
        self.register_buffer('offset', centres)

    def __repr__(self):
        return f'GaussianSmearing(start={self.start}, stop={self.stop}, num_gaussians={self.num_gaussians})'

    def forward(self, dist):
        """Return ``exp(coeff * (d - centre)^2)`` with shape (dist.numel(), n_centres)."""
        delta = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(delta, 2))
def outer_product(*vectors):
    """Return the row-wise flattened outer product of several 2-D tensors.

    Each input has shape (N, d_i). For every row the outer product of the
    per-input feature vectors is formed and flattened, with the first input
    varying slowest.

    Args:
        *vectors: One or more tensors sharing the same leading size ``N``.

    Returns:
        Tensor of shape (N, d_1 * d_2 * ... * d_m).

    Raises:
        ValueError: If no tensors are given.
    """
    if not vectors:
        raise ValueError('outer_product() requires at least one input tensor')

    first, *rest = vectors
    out = first.unsqueeze(-1)  # (N, d_1, 1)
    for vector in rest:
        out = out * vector.unsqueeze(1)  # (N, P, d_i)
        out = out.reshape(out.shape[0], -1).unsqueeze(-1)  # (N, P * d_i, 1)
    # Drop only the trailing helper axis. A bare squeeze() would also drop
    # the batch axis when N == 1.
    return out.squeeze(-1)
def compose_context_prop(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
    """Merge protein and ligand nodes into one tensor set grouped by batch index.

    Protein entries are concatenated before ligand entries and the result is
    reordered by batch index. The sort is stable, so within each batch element
    protein nodes stay ahead of ligand nodes and each group keeps its input
    order. This matches the ordering used by ``compose_context``.

    Args:
        h_protein: (N_protein, H) protein node features.
        h_ligand: (N_ligand, H) ligand node features.
        pos_protein: (N_protein, 3) protein coordinates.
        pos_ligand: (N_ligand, 3) ligand coordinates.
        batch_protein: (N_protein,) batch index of each protein node.
        batch_ligand: (N_ligand,) batch index of each ligand node.

    Returns:
        Tuple ``(h_ctx, pos_ctx, batch_ctx)`` with N_protein + N_ligand rows each.
    """
    batch_ctx = torch.cat([batch_protein, batch_ligand], dim=0)
    # A plain argsort() gives no ordering guarantee among equal batch indices.
    sort_idx = torch.sort(batch_ctx, stable=True).indices

    batch_ctx = batch_ctx[sort_idx]
    h_ctx = torch.cat([h_protein, h_ligand], dim=0)[sort_idx]  # (N_protein+N_ligand, H)
    pos_ctx = torch.cat([pos_protein, pos_ligand], dim=0)[sort_idx]  # (N_protein+N_ligand, 3)

    return h_ctx, pos_ctx, batch_ctx
batch_hybrid_edge_connection(x, k, mask_ligand, batch, add_p_index=False): + batch_size = batch.max().item() + 1 + batch_ll_edge_index, batch_pl_edge_index, batch_p_edge_index = [], [], [] + with torch.no_grad(): + for i in range(batch_size): + ligand_index = ((batch == i) & (mask_ligand == 1)).nonzero()[:, 0] + protein_index = ((batch == i) & (mask_ligand == 0)).nonzero()[:, 0] + ligand_pos, protein_pos = x[ligand_index], x[protein_index] + ll_edge_index, pl_edge_index = hybrid_edge_connection( + ligand_pos, protein_pos, k, ligand_index, protein_index) + batch_ll_edge_index.append(ll_edge_index) + batch_pl_edge_index.append(pl_edge_index) + if add_p_index: + all_pos = torch.cat([protein_pos, ligand_pos], 0) + p_edge_index = knn_graph(all_pos, k=k, flow='source_to_target') + p_edge_index = p_edge_index[:, p_edge_index[1] < len(protein_pos)] + p_src, p_dst = p_edge_index + all_index = torch.cat([protein_index, ligand_index], 0) + p_edge_index = torch.stack([all_index[p_src], all_index[p_dst]], 0) + batch_p_edge_index.append(p_edge_index) + + if add_p_index: + edge_index = [torch.cat([ll, pl, p], -1) for ll, pl, p in zip( + batch_ll_edge_index, batch_pl_edge_index, batch_p_edge_index)] + else: + edge_index = [torch.cat([ll, pl], -1) for ll, pl in zip(batch_ll_edge_index, batch_pl_edge_index)] + edge_index = torch.cat(edge_index, -1) + return edge_index diff --git a/models/add_module/egnn.py b/models/add_module/egnn.py new file mode 100644 index 0000000000000000000000000000000000000000..dd3577c94fb7045870c9fb417a7777952410e6af --- /dev/null +++ b/models/add_module/egnn.py @@ -0,0 +1,139 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_scatter import scatter_sum +from torch_geometric.nn import radius_graph, knn_graph +from models.add_module.common import GaussianSmearing, MLP, batch_hybrid_edge_connection, NONLINEARITIES + + +class EnBaseLayer(nn.Module): + def __init__(self, hidden_dim, num_r_gaussian, update_x=True, act_fn='silu', 
class EGNN(nn.Module):
    """Stack of ``EnBaseLayer`` blocks applied to batched, fixed-length point sets.

    Args:
        num_layers: Number of ``EnBaseLayer`` blocks.
        hidden_dim: Node feature size passed to every layer.
        num_r_gaussian: Number of distance features passed to every layer.
        k: Number of neighbours used when ``cutoff_mode == 'knn'``.
        cutoff: Upper bound passed to the module-level ``GaussianSmearing``.
        cutoff_mode: Edge construction strategy; only ``'knn'`` is implemented.
        update_x: Forwarded to every layer.
        act_fn: Activation name forwarded to every layer.
        norm: Forwarded to every layer.
    """

    def __init__(self, num_layers=2, hidden_dim=128, num_r_gaussian=20, k=32, cutoff=20.0, cutoff_mode='knn',
                 update_x=True, act_fn='silu', norm=False):
        super().__init__()
        # Build the network
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.num_r_gaussian = num_r_gaussian
        self.update_x = update_x
        self.act_fn = act_fn
        self.norm = norm
        self.k = k
        self.cutoff = cutoff
        self.cutoff_mode = cutoff_mode
        self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=num_r_gaussian)
        self.net = self._build_network()

    def _build_network(self):
        """Create the list of equivariant layers."""
        return nn.ModuleList([
            EnBaseLayer(self.hidden_dim, self.num_r_gaussian,
                        update_x=self.update_x, act_fn=self.act_fn, norm=self.norm)
            for _ in range(self.num_layers)
        ])

    def _connect_edge(self, x, batch, orgshape):
        """Build the edge index for the current coordinates.

        Args:
            x: (B*L, 3) flattened coordinates.
            batch: (B*L,) batch index of every node.
            orgshape: ``(B, L)``; currently unused.

        Raises:
            NotImplementedError: For ``cutoff_mode == 'hybrid'``, which has no
                edge construction in this module.
            ValueError: For any other unsupported ``cutoff_mode``.
        """
        if self.cutoff_mode == 'knn':
            return knn_graph(x, k=self.k, batch=batch, flow='source_to_target')
        if self.cutoff_mode == 'hybrid':
            # This branch used to fall through and fail with UnboundLocalError
            # on return; report the missing implementation explicitly.
            raise NotImplementedError("cutoff mode 'hybrid' is not implemented in EGNN")
        raise ValueError(f'Not supported cutoff mode: {self.cutoff_mode}')

    def forward(self, node_embed, trans_t):
        """Run all layers, rebuilding the graph before each one.

        Args:
            node_embed: (B, L, hidden) node features.
            trans_t: (B, L, 3) coordinates.

        Returns:
            Tuple ``(x, h)``: updated coordinates and features, reshaped back
            to (B, L, -1).
        """
        B, L = node_embed.shape[:2]

        x = trans_t.reshape(B * L, -1)
        h = node_embed.reshape(B * L, -1)

        # Node i of batch element b gets batch index b: [0]*L + [1]*L + ...
        batch = torch.arange(B, device=node_embed.device).repeat_interleave(L)

        for layer in self.net:
            edge_index = self._connect_edge(x, batch, orgshape=(B, L))
            h, x = layer(h, x, edge_index)
        return x.reshape(B, L, -1), h.reshape(B, L, -1)
class Dropout(nn.Module):
    """Dropout whose mask can be shared along one dimension.

    With ``broadcast_dim`` set, the mask has size 1 along that dimension, so
    entire rows/columns are kept or dropped together. Kept values are rescaled
    by ``1 / (1 - p_drop)``.

    Args:
        broadcast_dim: Dimension along which one mask value is shared, or
            ``None`` for an independent mask value per element.
        p_drop: Drop probability; must satisfy ``0 <= p_drop < 1``.

    Raises:
        ValueError: If ``p_drop`` is outside ``[0, 1)``. ``p_drop == 1`` used to
            be accepted and produced NaN through a 0/0 in ``forward``.
    """

    def __init__(self, broadcast_dim=None, p_drop=0.15):
        super().__init__()
        if not 0.0 <= p_drop < 1.0:
            raise ValueError(f'p_drop must be in [0, 1), got {p_drop}')
        # Gives ones with probability 1 - p_drop and zeros with probability p_drop.
        self.sampler = torch.distributions.bernoulli.Bernoulli(torch.tensor([1 - p_drop]))
        self.broadcast_dim = broadcast_dim
        self.p_drop = p_drop

    def forward(self, x):
        """Return ``x`` unchanged in eval mode, else the masked and rescaled ``x``."""
        if not self.training:  # no dropout during evaluation
            return x
        shape = list(x.shape)
        if self.broadcast_dim is not None:
            shape[self.broadcast_dim] = 1
        # The sampler appends its own trailing dim of size 1; view() drops it.
        mask = self.sampler.sample(shape).to(x.device).view(shape)
        return mask * x / (1.0 - self.p_drop)


def rbf(D, D_min=0., D_max=20., D_count=36):
    """Expand distances into Gaussian radial-basis features.

    Args:
        D: Distance tensor, e.g. ``(B, L, L)``.
        D_min: Centre of the first Gaussian.
        D_max: Centre of the last Gaussian.
        D_count: Number of Gaussians, evenly spaced between the two centres.

    Returns:
        Tensor of shape ``D.shape + (D_count,)`` whose entries are
        ``exp(-((d - mu_k) / sigma) ** 2)`` with ``sigma = (D_max - D_min) / D_count``.
    """
    D_mu = torch.linspace(D_min, D_max, D_count, device=D.device)[None, :]
    D_sigma = (D_max - D_min) / D_count
    D_expand = torch.unsqueeze(D, -1)
    return torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
def get_sub_graph(xyz, mask=None, dist_max=20, kmin=32):
    """Select residue pairs that are close in space or close in sequence.

    A pair ``(i, j)`` with ``i != j`` is kept when its distance is below
    ``dist_max`` or when ``0 < |i - j| < kmin``. With a mask, only pairs whose
    two members are both unmasked can be kept.

    Args:
        xyz: Coordinates of shape ``(B, L, 3)``.
        mask: Optional ``(B, L)`` mask; nonzero marks a valid position.
        dist_max: Distance cutoff.
        kmin: Sequence-separation cutoff (exclusive).

    Returns:
        ``(src, tgt, (b, i, j))``. ``b, i, j`` are the batch/row/column indices
        of the kept pairs; ``src = b * L + i`` and ``tgt = b * L + j`` are the
        same pairs as indices into a flattened ``B * L`` node list.
    """
    L = xyz.shape[1]
    device = xyz.device

    # Pairwise distances; the large diagonal offset excludes self-pairs.
    D = torch.cdist(xyz, xyz) + torch.eye(L, device=device).unsqueeze(0) * 10 * dist_max  # (B, L, L)
    idx = torch.arange(L, device=device)[None, :]  # (1, L)
    seq = (idx[:, None, :] - idx[:, :, None]).abs()  # (1, L, L)

    if mask is not None:
        mask_cross = mask.unsqueeze(2).type(torch.float32) * mask.unsqueeze(1).type(torch.float32)  # (B, L, L)
        # Push masked pairs beyond the cutoff and zero their sequence
        # separation so that neither condition below can select them.
        D = D + 10 * dist_max * (1 - mask_cross)
        seq = seq * mask_cross  # (B, L, L)

    seq_cond = torch.logical_and(seq > 0, seq < kmin)
    cond = torch.logical_or(D < dist_max, seq_cond)
    b, i, j = torch.where(cond)

    src = b * L + i
    tgt = b * L + j
    return src, tgt, (b, i, j)


def scatter_add(src, index, dim_index, num_nodes):
    """Sum the rows of ``src`` that share the same ``index`` value.

    Args:
        src: ``(E, C)`` tensor of per-edge values.
        index: ``E`` target node ids, one per row of ``src``.
        dim_index: Dimension passed to ``Tensor.scatter_add_``.
        num_nodes: Number of rows in the output.

    Returns:
        ``(num_nodes, C)`` tensor; rows that receive nothing stay zero.
    """
    out = src.new_zeros(num_nodes, src.shape[1])
    index = index.reshape(-1, 1).expand_as(src)
    return out.scatter_add_(dim_index, index, src)
repr_layers=[33], return_contacts=True) + + node_repr = results["representations"][33][0,1:-1,:] + pair_repr = results['attentions'][0,-1,:,1:-1,1:-1].permute(1,2,0) + # Generate per-sequence representations via averaging + # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1. + # sequence_representations = [] + # for i, tokens_len in enumerate(batch_lens): + # sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0)) + + # Look at the unsupervised self-attention map contact predictions + # for (_, seq), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]): + # plt.matshow(attention_contacts[: tokens_len, : tokens_len]) + return node_repr, pair_repr \ No newline at end of file diff --git a/models/add_module/structure_module.py b/models/add_module/structure_module.py new file mode 100644 index 0000000000000000000000000000000000000000..9f2560cc277fca326695a02156e87f313797f728 --- /dev/null +++ b/models/add_module/structure_module.py @@ -0,0 +1,943 @@ +import torch.nn as nn +import torch +from opt_einsum import contract as einsum +from models.add_module.model_utils import scatter_add, Dropout +from models.add_module.Attention_module import * +from torch_scatter import scatter +from models.utils import get_time_embedding +import e3nn +import math +from typing import Optional, Callable, List, Sequence +from openfold.utils.rigid_utils import Rigid + + +class TensorProductConvLayer(torch.nn.Module): + def __init__(self, in_irreps, sh_irreps, out_irreps, d_pair, dropout=0.1, fc_factor=1, use_layer_norm=True, + init='normal',use_internal_weights=False): + super().__init__() + self.in_irreps = in_irreps + self.out_irreps = out_irreps + self.sh_irreps = sh_irreps + self.fc_factor = fc_factor + + self.use_internal_weights = use_internal_weights + if self.use_internal_weights: + self.tp = e3nn.o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=True, + 
internal_weights=True) + else: + self.tp = e3nn.o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=False) + self.fc = nn.Sequential( + nn.LayerNorm(d_pair) if use_layer_norm else nn.Identity(), + nn.Linear(d_pair, self.fc_factor * d_pair), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(self.fc_factor * d_pair, self.tp.weight_numel) + ) + + if init == 'final': + with torch.no_grad(): + for para in self.fc.parameters(): + para.fill_(0) + elif init != 'normal': + raise ValueError(f'Unknown init {init}') + + def forward(self, attr1, attr2, edge_attr=None): + if self.use_internal_weights: + out = self.tp(attr1, attr2) + else: + out = self.tp(attr1, attr2, self.fc(edge_attr)) + return out + +class E3_GNNlayer(nn.Module): + def __init__(self, e3_d_node_l0=32, e3_d_node_l1=8, d_rbf=64, d_node=256, d_pair=128, + dist_max=20, final = False, init='normal',dropout=0.1): + super().__init__() + + self.dist_max = dist_max + self.d_rbf = d_rbf + self.e3_d_node_l0 = e3_d_node_l0 + self.e3_d_node_l1 = e3_d_node_l1 + self.d_hidden = torch.tensor(e3_d_node_l0 + e3_d_node_l1 * 3) + + irreps_input = e3nn.o3.Irreps(str(e3_d_node_l0)+"x0e + "+str(e3_d_node_l1)+"x1o") + # irreps_hid1 = e3nn.o3.Irreps(str(e3_d_node_l0)+"x0e + "+str(e3_d_node_l1)+"x1o + "+str(e3_d_node_l1)+"x1e") + + + self.final = final + if final: + # irreps_output = e3nn.o3.Irreps(str(6)+"x0e + "+str(2)+"x1o + "+str(2)+"x1e") + # irreps_output = e3nn.o3.Irreps(str(2)+"x1o + "+str(2)+"x1e") + # irreps_hid1 = e3nn.o3.Irreps(str(2)+"x1o + "+str(2)+"x1e") + # irreps_hid1 = e3nn.o3.Irreps(str(6)+"x0e + "+str(2)+"x1o + "+str(2)+"x1e") + # irreps_hid1 = e3nn.o3.Irreps(str(2)+"x1o") + irreps_hid1 = e3nn.o3.Irreps(str(6)+"x0e") + else: + # irreps_output = irreps_input + irreps_hid1 = irreps_input + self.proj_node = nn.Linear(self.e3_d_node_l0, d_node) + + + # self.d_hidden_1 = torch.tensor(e3_d_node_l0 + e3_d_node_l1 * 3 + e3_d_node_l1 * 3) + + self.irreps_sh = e3nn.o3.Irreps.spherical_harmonics(2) + 
self.proj_l0 = nn.Linear(d_node, self.e3_d_node_l0) + + self.tpc_conv = TensorProductConvLayer(irreps_input, self.irreps_sh, irreps_hid1, d_pair,init=init, dropout=dropout) + + + # self.h_hidden = e3nn.o3.Linear(irreps_hid1, irreps_hid1) + # self.h_out = e3nn.o3.Linear(irreps_hid1, irreps_output) + # self.h_hidden = TensorProductConvLayer(irreps_value, self.irreps_sh, irreps_value, d_pair) + # self.h_out = TensorProductConvLayer(irreps_value, self.irreps_sh, irreps_output, d_pair) + + def forward(self, node, pair, l1_feats, pair_index, edge_src, edge_dst, + edge_sh, mask = None): + ''' + l1_feats: (B,L,3*self.e3_d_node_l1) + xyz: (B,L,3,3) + ''' + + # create graph + B, L = node.shape[:2] + num_nodes = B * L + + # (num_nodes, irreps_input) + l0_feats = self.proj_l0(node) # (B,L,self.e3_d_node_l0) + node_graph_total = torch.concat([l0_feats,l1_feats],dim=-1) + node_graph_total = node_graph_total.reshape(num_nodes,-1) # (B*L, e3_d_node_l0 + 3* e3_d_node_l1) + + # xyz_graph = xyz[:,:,1,:].reshape(num_nodes,3) # only for CA + # edge_src, edge_dst, edge_feats = get_sub_graph(xyz[:,:,1,:], pair, mask = mask, dist_max=self.dist_max) + # edge_vec = xyz_graph[edge_src] - xyz_graph[edge_dst] + + # edge_length = edge_vec.norm(dim=1) + # rbf_dist = rbf(edge_length, D_count = self.d_rbf) + # edge_feats_total = torch.concat([edge_feats, rbf_dist], dim=-1) + b,i,j = pair_index + edge_feats = pair[b,i,j] + edge_feats_total = edge_feats + + # compute the queries (per node), keys (per edge) and values (per edge) + conv_out = self.tpc_conv(node_graph_total[edge_dst], edge_sh, edge_feats_total) + # scatter_out = scatter_add(conv_out, edge_dst, dim_index=0, num_nodes=num_nodes) # (num_nodes, irreps_output) + out = scatter(conv_out, edge_src, dim=0, dim_size=num_nodes, reduce="mean") # (num_nodes, irreps_output) + + # out = self.h_hidden(out) + # out = self.h_out(out) + # out = self.h_hidden(value_out, edge_sh, edge_feats_total) + # out = self.h_out(out, edge_sh, edge_feats_total) + out 
= out.reshape(B,L,-1) + + if self.final: + # return torch.concat([out[:,:,0:3] + out[:,:,6:9] + out[:,:,12:15], + # out[:,:,3:6] + out[:,:,9:12]+ out[:,:,15:18]],dim=-1) # (B,L,6) + + # return torch.concat([out[:,:,0:3] + out[:,:,6:9], + # out[:,:,3:6] + out[:,:,9:12]],dim=-1) # (B,L,6) + + return torch.concat([out[:,:,0:3], out[:,:,3:6]],dim=-1) # (B,L,6) + else: + l0_feats = out[:,:,:self.e3_d_node_l0] + # l1_feats = out[:,:,self.e3_d_node_l0:] + # node = self.proj_node(l0_feats) + l1_feats = out[:,:,self.e3_d_node_l0:] + l1_feats + node = self.proj_node(l0_feats) + node + return node, l1_feats + +class E3_transformer(nn.Module): + def __init__(self, e3_d_node_l0=16, e3_d_node_l1=4, d_rbf=64, d_node=256, d_pair=128, + dist_max=20, final = False): + super().__init__() + + self.dist_max = dist_max + self.d_rbf = d_rbf + self.e3_d_node_l0 = e3_d_node_l0 + self.e3_d_node_l1 = e3_d_node_l1 + self.d_hidden = torch.tensor(e3_d_node_l0 + e3_d_node_l1 * 3) + + irreps_input = e3nn.o3.Irreps(str(e3_d_node_l0)+"x0e + "+str(e3_d_node_l1)+"x1o") + irreps_hid1 = e3nn.o3.Irreps(str(e3_d_node_l0)+"x0e + "+str(e3_d_node_l1)+"x1o + "+str(e3_d_node_l1)+"x1e") + self.d_hidden_1 = torch.tensor(e3_d_node_l0 + e3_d_node_l1 * 3 + e3_d_node_l1 * 3) + + irreps_query = irreps_hid1 + irreps_key = irreps_hid1 + irreps_value = irreps_hid1 + self.irreps_sh = e3nn.o3.Irreps.spherical_harmonics(2) + + self.proj_l0 = nn.Linear(d_node, self.e3_d_node_l0) + + # self.tpc_q = TensorProductConvLayer(irreps_input, self.irreps_sh, irreps_query, d_pair) + self.h_q = e3nn.o3.Linear(irreps_input, irreps_query) # same as o3.FullyConnectedTensorProduct(irreps_input, "1x0e", irreps_query, shared_weights=False) + + self.tpc_k = TensorProductConvLayer(irreps_input, self.irreps_sh, irreps_key, d_pair) + # self.tp_k = e3nn.o3.FullyConnectedTensorProduct(irreps_input, self.irreps_sh, irreps_key, shared_weights=False) + # self.fc_k = e3nn.nn.FullyConnectedNet([d_pair+d_rbf, d_pair+d_rbf, self.tp_k.weight_numel], 
act=torch.nn.functional.relu) + + self.tpc_v = TensorProductConvLayer(irreps_input, self.irreps_sh, irreps_value, d_pair) + # self.tp_v = e3nn.o3.FullyConnectedTensorProduct(irreps_input, self.irreps_sh, irreps_value, shared_weights=False) + # self.fc_v = e3nn.nn.FullyConnectedNet([d_pair+d_rbf, d_pair+d_rbf, self.tp_v.weight_numel], act=torch.nn.functional.relu) + + self.dot = e3nn.o3.FullyConnectedTensorProduct(irreps_query, irreps_key, "1x0e") + + self.final = final + if final: + # irreps_output = e3nn.o3.Irreps(str(6)+"x0e + "+str(2)+"x1o + "+str(2)+"x1e") + irreps_output = e3nn.o3.Irreps(str(2)+"x1o + "+str(2)+"x1e") + else: + irreps_output = irreps_input + self.proj_node = nn.Linear(self.e3_d_node_l0, d_node) + self.h_hidden = e3nn.o3.Linear(irreps_value, irreps_value) + self.h_out = e3nn.o3.Linear(irreps_value, irreps_output) + # self.h_hidden = TensorProductConvLayer(irreps_value, self.irreps_sh, irreps_value, d_pair) + # self.h_out = TensorProductConvLayer(irreps_value, self.irreps_sh, irreps_output, d_pair) + + def forward(self, node, pair, l1_feats, pair_index, edge_src, edge_dst, + edge_sh, mask = None): + ''' + l1_feats: (B,L,3*self.e3_d_node_l1) + xyz: (B,L,3,3) + ''' + + # create graph + B, L = node.shape[:2] + num_nodes = B * L + + # (num_nodes, irreps_input) + l0_feats = self.proj_l0(node) # (B,L,self.e3_d_node_l0) + node_graph_total = torch.concat([l0_feats,l1_feats],dim=-1) + node_graph_total = node_graph_total.reshape(num_nodes,-1) + + # xyz_graph = xyz[:,:,1,:].reshape(num_nodes,3) # only for CA + # edge_src, edge_dst, edge_feats = get_sub_graph(xyz[:,:,1,:], pair, mask = mask, dist_max=self.dist_max) + # edge_vec = xyz_graph[edge_src] - xyz_graph[edge_dst] + + # edge_length = edge_vec.norm(dim=1) + # rbf_dist = rbf(edge_length, D_count = self.d_rbf) + # edge_feats_total = torch.concat([edge_feats, rbf_dist], dim=-1) + b,i,j = pair_index + edge_feats = pair[b,i,j] + edge_feats_total = edge_feats + + # compute the queries (per node), keys (per 
edge) and values (per edge) + # q = self.h_q(node_graph_total)[edge_dst] + # k = self.tp_k(node_graph_total[edge_src], edge_sh, self.fc_k(edge_feats_total)) + # v = self.tp_v(node_graph_total[edge_src], edge_sh, self.fc_v(edge_feats_total)) + + q = self.h_q(node_graph_total)[edge_dst] + # q = self.tpc_q(node_graph_total[edge_dst], edge_sh, edge_feats_total) + k = self.tpc_k(node_graph_total[edge_src], edge_sh, edge_feats_total) + v = self.tpc_v(node_graph_total[edge_src], edge_sh, edge_feats_total) + + # compute the softmax (per edge) + exp = torch.exp(self.dot(q, k)/self.d_hidden_1.sqrt()) # compute the numerator (num_edges, 1) + # z = scatter_add(exp, edge_dst, dim_index=0, num_nodes=num_nodes) # compute the denominator (nodes, 1) + z = scatter(exp, edge_dst, dim=0, dim_size=num_nodes, reduce="sum") # compute the denominator (nodes, 1) + # z[z == 0] = 1 # to avoid 0/0 when all the neighbors are exactly at the cutoff + alpha = exp / (z[edge_dst] + 1e-5) # (num_edges, 1) + + # value_out = scatter_add(alpha * v, edge_dst, dim_index=0, num_nodes=num_nodes) # (nodes, irreps_output) + out = scatter(alpha * v, edge_dst, dim=0, dim_size=num_nodes, reduce="mean") # (num_nodes, irreps_output) + + out = self.h_hidden(out) + out = self.h_out(out) + # out = self.h_hidden(value_out, edge_sh, edge_feats_total) + # out = self.h_out(out, edge_sh, edge_feats_total) + out = out.reshape(B,L,-1) + + if self.final: + # return torch.concat([out[:,:,0:3] + out[:,:,6:9] + out[:,:,12:15], + # out[:,:,3:6] + out[:,:,9:12]+ out[:,:,15:18]],dim=-1) # (B,L,6) + return torch.concat([out[:,:,0:3] + out[:,:,6:9], + out[:,:,3:6] + out[:,:,9:12]],dim=-1) # (B,L,6) + else: + l0_feats = out[:,:,:self.e3_d_node_l0] + l1_feats = out[:,:,self.e3_d_node_l0:] + l1_feats + node = self.proj_node(l0_feats) + node + return node, l1_feats + + # # /self.d_hidden_1.sqrt() + # v = self.tp_hidden(v, edge_sh, self.fc_hidden(edge_feats_total)) + # v = self.tp_out(v, edge_sh, self.fc_out(edge_feats_total)) + # 
value_out = scatter_add(alpha * v, edge_dst, dim_index=0, num_nodes=num_nodes) # (nodes, irreps_output) + + # # print("before_out_max=",torch.max(value_out)) + # out = self.h_hidden(value_out + self.res_connect(node_graph_total)) + # out = self.h_out(out) + # out = out.reshape(B,L,-1) + + #return out[:,:,6:9] + out[:,:,9:], out[:,:,:3] + out[:,:,3:6] + +class E3_transformer_no_adjacency(nn.Module): + def __init__(self, e3_d_node_l0=16, e3_d_node_l1=4, d_node=256, d_pair=128, no_heads=8, final = False): + super().__init__() + + self.e3_d_node_l0 = e3_d_node_l0 + self.e3_d_node_l1 = e3_d_node_l1 + self.no_heads = no_heads + self.inf = 1e5 + + irreps_input = e3nn.o3.Irreps(str(e3_d_node_l0)+"x0e + "+str(e3_d_node_l1)+"x1o") + irreps_hid1 = e3nn.o3.Irreps(str(e3_d_node_l0)+"x0e + "+ str(e3_d_node_l1)+"x1o + "+ str(e3_d_node_l1)+"x1e") + self.d_hid1 = irreps_hid1.dim + + irreps_query = irreps_hid1 + irreps_key = irreps_hid1 + irreps_value = irreps_hid1 + + self.proj_l0 = nn.Linear(d_node, self.e3_d_node_l0) + + self.h_q = e3nn.o3.Linear(irreps_input, irreps_query*self.no_heads) # same as o3.FullyConnectedTensorProduct(irreps_input, "1x0e", irreps_query, shared_weights=False) + + # self.h_k = e3nn.o3.Linear(irreps_input, irreps_key*self.no_heads) + # self.h_v = e3nn.o3.Linear(irreps_input, irreps_value*self.no_heads) + self.tpc_k = TensorProductConvLayer(irreps_input, irreps_input, irreps_key*self.no_heads, d_pair, + use_internal_weights=True) + self.tpc_v = TensorProductConvLayer(irreps_input, irreps_input, irreps_value*self.no_heads, d_pair, + use_internal_weights=True) + + # self.dot = e3nn.o3.FullyConnectedTensorProduct(irreps_query, irreps_key, "1x0e") + self.linear_pair = nn.Linear(d_pair, self.no_heads) + + self.final = final + if final: + # irreps_output = e3nn.o3.Irreps(str(6)+"x0e + "+str(2)+"x1o + "+str(2)+"x1e") + irreps_output = e3nn.o3.Irreps(str(2)+"x1o + "+str(2)+"x1e") + else: + irreps_output = irreps_input + self.proj_node = nn.Linear(self.e3_d_node_l0, 
d_node) + + self.h_out = e3nn.o3.Linear(irreps_value*self.no_heads, irreps_output) + # self.h_out = TensorProductConvLayer(irreps_value*self.no_heads, self.irreps_sh, irreps_output, d_pair) + + def forward(self, node, pair, l1_feats, mask = None): + ''' + l1_feats: (B,L,3*self.e3_d_node_l1) + xyz: (B,L,3,3) + ''' + B, L = node.shape[:2] + + # (num_nodes, irreps_input) + l0_feats = self.proj_l0(node) # (B,L,self.e3_d_node_l0) + node_graph_total = torch.concat([l0_feats,l1_feats],dim=-1) # (B,L,d_hid1) + + q = self.h_q(node_graph_total) # (B,L,d_hid1*no_heads) + q = torch.split(q, q.shape[-1] // self.no_heads, dim=-1) + q = torch.stack(q, dim=-1) # (B,L,d_hid1,no_heads) + + # k = self.h_k(node_graph_total) + k = self.tpc_k(node_graph_total, node_graph_total) + k = torch.split(k, k.shape[-1] // self.no_heads, dim=-1) + k = torch.stack(k, dim=-1) # (B,L,d_hid1,no_heads) + + # v = self.h_v(node_graph_total) + v = self.tpc_v(node_graph_total, node_graph_total) + v = torch.split(v, v.shape[-1] // self.no_heads, dim=-1) + v = torch.stack(v, dim=-1) # (B,L,d_hid1,no_heads) + + # compute the softmax (per edge) + a = q.unsqueeze(-3) - k.unsqueeze(-4) # (B,L,L,d_hid1,no_heads) + a = torch.sum(a**2, dim=-2)/math.sqrt(3*self.d_hid1) # (B,L,L,no_heads) invariable + a = a + self.linear_pair(pair) + + if mask is not None: + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) # (B,L,L) + square_mask = self.inf * (square_mask - 1) + a = a + square_mask.unsqueeze(-1) + + a = torch.softmax(a, dim=-2) # (B,L,L,no_heads) + out = torch.matmul( + a.permute(0,3,1,2), # (B,no_heads,L,L) + v.permute(0,3,1,2), # (B,no_heads,L,d_hid1) + ) # (B,no_heads,L,d_hid1) + out = out.permute(0,2,1,3) # (B,L,no_heads,d_hid1) + out = out.reshape(B,L,-1) # (B,L,d_hid1*no_heads) + + out = self.h_out(out) + + if self.final: + # return torch.concat([out[:,:,0:3] + out[:,:,6:9] + out[:,:,12:15], + # out[:,:,3:6] + out[:,:,9:12]+ out[:,:,15:18]],dim=-1) # (B,L,6) + return torch.concat([out[:,:,0:3] + 
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """Permute only the last ``len(inds)`` dimensions of ``tensor``.

    ``inds`` gives the new order of those trailing dimensions, counted from the
    first of them; all leading dimensions stay where they are.
    """
    offset = -len(inds)
    leading = list(range(len(tensor.shape[:offset])))
    trailing = [offset + axis for axis in inds]
    return tensor.permute(leading + trailing)


def flatten_final_dims(t: torch.Tensor, no_dims: int):
    """Collapse the last ``no_dims`` dimensions of ``t`` into one."""
    kept = t.shape[:-no_dims]
    return t.reshape(kept + (-1,))
+ hc = self.c_hidden * self.no_heads + self.linear_q = nn.Linear(self.c_s, hc) + self.linear_kv = nn.Linear(self.c_s, 2 * hc) + + hpq = self.no_heads * self.no_qk_points * 3 + self.hpq = hpq//3 + self.linear_q_points = nn.Linear(self.c_s, hpq) + + hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3 + self.hpkv = hpkv//3 + self.linear_kv_points = nn.Linear(self.c_s, hpkv) + + self.linear_b = nn.Linear(self.c_z, self.no_heads) + self.down_z = nn.Linear(self.c_z, self.c_z // 4) + + # self.head_weights = nn.Parameter(torch.zeros((ipa_conf.no_heads))) + # ipa_point_weights_init_(self.head_weights) + + concat_out_dim = ( + self.c_z // 4 + self.c_hidden + self.no_v_points * 4 + ) + self.linear_out = nn.Linear(self.no_heads * concat_out_dim, self.c_s) + + self.softmax = nn.Softmax(dim=-1) + self.softplus = nn.Softplus() + + + + self.use_e3_qkvo = False + + irreps_3 = e3nn.o3.Irreps(str(e3_d_node_l1)+"x1o") + self.l1_feats_proj = e3nn.o3.Linear(e3nn.o3.Irreps(str(self.no_heads*self.no_v_points)+"x1o"), + irreps_3) + if self.use_e3_qkvo: + irreps_1 = e3nn.o3.Irreps("3x0e") + irreps_2 = e3nn.o3.Irreps("1x1o") + self.tp_q_pts = TensorProductConvLayer(irreps_1, irreps_3, irreps_2, 3,fc_factor=4) + # self.tp_q_pts = e3nn.o3.FullyConnectedTensorProduct(irreps_1, irreps_3, irreps_2, shared_weights=True, + # internal_weights=True) + + self.tp_kv_pts = TensorProductConvLayer(irreps_1, irreps_3, irreps_2, 3,fc_factor=4) + # self.tp_kv_pts = e3nn.o3.FullyConnectedTensorProduct(irreps_1, irreps_3, irreps_2, shared_weights=True, + # internal_weights=True) + + self.tp_o_pt_invert = TensorProductConvLayer(irreps_2, irreps_3, irreps_1, 1,fc_factor=4) + # self.tp_o_pt_invert = e3nn.o3.FullyConnectedTensorProduct(irreps_2, irreps_3, irreps_1, shared_weights=True, + # internal_weights=True) + + def forward( + self, + s: torch.Tensor, + z: Optional[torch.Tensor], + l1_feats, + r: Rigid, + mask: torch.Tensor, + _offload_inference: bool = False, + _z_reference_list: 
Optional[Sequence[torch.Tensor]] = None, + ) -> torch.Tensor: + """ + Args: + s: + [*, N_res, C_s] single representation + z: + [*, N_res, N_res, C_z] pair representation + r: + [*, N_res] transformation object + mask: + [*, N_res] mask + Returns: + [*, N_res, C_s] single representation update + """ + if _offload_inference: + z = _z_reference_list + else: + z = [z] + + ####################################### + # Generate scalar and point activations + ####################################### + # [*, N_res, H * C_hidden] + s_org = s + + q = self.linear_q(s) + kv = self.linear_kv(s) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, 2 * C_hidden] + kv = kv.view(kv.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, C_hidden] + k, v = torch.split(kv, self.c_hidden, dim=-1) + + # [*, N_res, H * P_q * 3] + q_pts = self.linear_q_points(s) + + # This is kind of clunky, but it's how the original does it + # [*, N_res, H * P_q, 3] + q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) + q_pts = torch.stack(q_pts, dim=-1) + if self.use_e3_qkvo: + q_pts = self.tp_q_pts(q_pts, l1_feats[:,:,None,:].repeat(1,1,self.hpq,1),q_pts) # [*, N_res, H * P_q, 3] + # q_pts = self.tp_q_pts(q_pts, l1_feats[:,:,None,:].repeat(1,1,self.hpq,1)) # [*, N_res, H * P_q, 3] + else: + q_pts = r[..., None].apply(q_pts) + + # [*, N_res, H, P_q, 3] + q_pts = q_pts.view( + q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3) + ) + + # [*, N_res, H * (P_q + P_v) * 3] + kv_pts = self.linear_kv_points(s) + + # [*, N_res, H * (P_q + P_v), 3] + kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) + kv_pts = torch.stack(kv_pts, dim=-1) + if self.use_e3_qkvo: + kv_pts = self.tp_kv_pts(kv_pts, l1_feats[:,:,None,:].repeat(1,1,self.hpkv,1),kv_pts) # [*, N_res, H * (P_q + P_v), 3] + # kv_pts = self.tp_kv_pts(kv_pts, l1_feats[:,:,None,:].repeat(1,1,self.hpkv,1)) # [*, N_res, H * (P_q + P_v), 3] + else: + kv_pts = r[..., None].apply(kv_pts) + + # [*, 
N_res, H, (P_q + P_v), 3] + kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3)) + + # [*, N_res, H, P_q/P_v, 3] + k_pts, v_pts = torch.split( + kv_pts, [self.no_qk_points, self.no_v_points], dim=-2 + ) + + ########################## + # Compute attention scores + ########################## + # [*, N_res, N_res, H] + b = self.linear_b(z[0]) + + if(_offload_inference): + z[0] = z[0].cpu() + + # [*, H, N_res, N_res] + a = torch.matmul( + math.sqrt(1.0 / (3 * self.c_hidden)) * + permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + # a *= math.sqrt(1.0 / (3 * self.c_hidden)) + a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))) + + # [*, N_res, N_res, H, P_q, 3] + pt_displacement = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + pt_att = pt_displacement ** 2 + + # [*, N_res, N_res, H, P_q] + pt_att = sum(torch.unbind(pt_att, dim=-1)) + # head_weights = self.softplus(self.head_weights).view( + # *((1,) * len(pt_att.shape[:-2]) + (-1, 1)) + # ) + # head_weights = head_weights * math.sqrt( + # 1.0 / (3 * (self.no_qk_points * 9.0 / 2)) + # ) + # pt_att = pt_att * head_weights + + # [*, N_res, N_res, H] + pt_att = torch.sum(pt_att, dim=-1) * (-0.5) + # [*, N_res, N_res] + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + square_mask = self.inf * (square_mask - 1) + + # [*, H, N_res, N_res] + pt_att = permute_final_dims(pt_att, (2, 0, 1)) + + a = a + pt_att + a = a + square_mask.unsqueeze(-3) + a = self.softmax(a) + + ################ + # Compute output + ################ + # [*, N_res, H, C_hidden] + o = torch.matmul( + a, v.transpose(-2, -3) + ).transpose(-2, -3) + + # [*, N_res, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, H, 3, N_res, P_v] + o_pt = torch.sum( + ( + a[..., None, :, :, None] + * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :] + ), + dim=-2, + ) + + # [*, N_res, H, P_v, 3] + o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) + + o_pt_l1 = 
class Energy_Adapter_Node(nn.Module):
    """Condition node features on a per-sample energy value.

    The energy is embedded with ``get_time_embedding`` and used as the
    one-token memory of a ``TransformerDecoderLayer`` whose target sequence is
    the node features.

    Args:
        d_node: Width of the node features and of the energy embedding.
        n_head: Number of attention heads.
        p_drop: Dropout probability inside the transformer layer.
        ff_factor: Feed-forward width as a multiple of ``d_node``.
    """

    def __init__(self, d_node=256, n_head=8, p_drop=0.1, ff_factor=1):
        super().__init__()

        self.d_node = d_node
        self.tfmr_layer = torch.nn.TransformerDecoderLayer(
            d_model=d_node,
            nhead=n_head,
            dim_feedforward=ff_factor * d_node,
            dropout=p_drop,
            batch_first=True,
            norm_first=False,
        )

    def forward(self, node, energy, mask=None):
        """Return node features updated by attention to the energy embedding.

        Args:
            node: Node features passed as the decoder target
                (presumably ``(B, L, d_node)`` - confirm against callers).
            energy: ``(B,)`` energy values.
            mask: Optional mask where 1 (or ``True``) marks a valid position;
                every other position is treated as padding.

        Returns:
            Output of the transformer decoder layer for ``node``.
        """
        energy_emb = get_time_embedding(
            energy,
            self.d_node,
            max_positions=2056,
        )[:, None, :]  # (B, 1, d_node)

        if mask is not None:
            # Same result as (1 - mask).bool() for numeric masks, but also
            # accepts boolean masks, for which `1 - mask` raises in PyTorch.
            mask = mask != 1
        return self.tfmr_layer(node, energy_emb, tgt_key_padding_mask=mask)
super().__init__() +# self.model = model +# self.decay = decay +# self.device = None +# self.shadow = {} +# for name, param in self.model.named_parameters(): +# if param.requires_grad: +# self.shadow[name] = param.data.clone().detach() + +# def to(self, device): +# self.shadow = {name: param.to(device) for name, param in self.shadow.items()} +# self.device = device + +# def update(self): +# for name, param in self.model.named_parameters(): +# if param.requires_grad: +# assert name in self.shadow +# new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name] +# self.shadow[name] = new_average.clone() + +# def apply_shadow(self): +# for name, param in self.model.named_parameters(): +# if param.requires_grad: +# assert name in self.shadow +# param.data.copy_(self.shadow[name]) +# self.model.load_state_dict(self.shadow) + +class Node_update(nn.Module): + def __init__(self, d_msa=256, d_pair=128, n_head=8, p_drop=0.1, ff_factor=1): + super().__init__() + self.d_hidden = d_msa // n_head + self.ff_factor = ff_factor + self.norm_node = nn.LayerNorm(d_msa) + self.norm_pair = nn.LayerNorm(d_pair) + + self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop) + self.row_attn = RowAttentionWithBias(d_msa=d_msa, d_pair=d_pair, + n_head=n_head, d_hidden=self.d_hidden) + + self.col_attn = ColAttention(d_msa=d_msa, n_head=n_head, d_hidden=self.d_hidden) + self.ff = FeedForwardLayer(d_msa, self.ff_factor, p_drop=p_drop) + + def forward(self, msa, pair, mask = None): + ''' + Inputs: + - msa: MSA feature (B, L, d_msa) + - pair: Pair feature (B, L, L, d_pair) + - rbf_feat: Ca-Ca distance feature calculated from xyz coordinates (B, L, L, 36) + - xyz: xyz coordinates (B, L, n_atom, 3) + Output: + - msa: Updated MSA feature (B, L, d_msa) + ''' + B, L = msa.shape[:2] + + # prepare input bias feature by combining pair & coordinate info + msa = self.norm_node(msa) + pair = self.norm_pair(pair) + + msa = msa + self.drop_row(self.row_attn(msa, pair, mask=mask)) + msa = msa + 
class Node2Pair_bias(nn.Module):
    """Turn node features into a pair bias via a per-head outer product.

    Nodes are layer-normalised and projected twice to ``d_hidden`` channels.
    Those channels are split into ``d_head`` heads, every pair ``(i, j)`` gets
    one dot product per head, and a final linear layer maps the ``d_head``
    values to ``d_pair`` channels.

    Args:
        d_node: Width of the incoming node features.
        d_hidden: Width of the left/right projections.
        d_pair: Width of the returned pair features.
        d_head: Number of heads ``d_hidden`` is split into.

    Raises:
        ValueError: If ``d_hidden`` is not divisible by ``d_head``. This used to
            surface only as a reshape error in ``forward``.
    """

    def __init__(self, d_node=256, d_hidden=32, d_pair=128, d_head=16):
        super().__init__()
        if d_hidden % d_head != 0:
            raise ValueError(
                f'd_hidden ({d_hidden}) must be divisible by d_head ({d_head})')
        self.d_hidden = d_hidden
        self.d_head = d_head
        self.norm = nn.LayerNorm(d_node)
        self.proj_left = nn.Linear(d_node, d_hidden)
        self.proj_right = nn.Linear(d_node, d_hidden)
        self.proj_out = nn.Linear(d_head, d_pair)

    def forward(self, node, mask=None):
        """Compute the pair bias.

        Args:
            node: ``(B, L, d_node)`` node features.
            mask: Optional ``(B, L)`` mask; normalised features are multiplied
                by it before the projections.

        Returns:
            ``(B, L, L, d_pair)`` pair features.
        """
        B, L = node.shape[:2]
        node = self.norm(node)
        if mask is not None:
            node = node * mask[..., None].type(torch.float32)

        left = self.proj_left(node).reshape(B, L, self.d_head, -1)
        # math.sqrt instead of np.sqrt: this module never imports numpy itself.
        right = self.proj_right(node) / math.sqrt(self.d_hidden)
        right = right.reshape(B, L, self.d_head, -1)

        # One dot product per head for every (i, j) pair. Plain two-operand
        # contraction, so torch.einsum is used directly.
        out = torch.einsum('bihk,bjhk->bijh', left, right)  # (B, L, L, d_head)
        return self.proj_out(out)
class FeatBlock(nn.Module):
    """Build the initial node and pair features from sequence and coordinates.

    Node features are concatenated with a learned residue-type embedding and
    passed through a small MLP. Pair features are concatenated with a radial
    basis encoding of the pairwise distances between the atoms stored at
    index 1 of the atom axis of ``xyz`` and passed through a second MLP.

    ``T``, ``L_max`` and ``d_cent`` are accepted but not used by this block;
    ``dist_max`` and ``d_time`` are only stored.
    """

    def __init__(self, d_node=256, d_pair=128, dist_max = 40, T=500, d_time=64, L_max=64,
                 d_cent=36, d_rbf=36, p_drop=0.1):
        super().__init__()
        self.dist_max = dist_max
        self.d_rbf = d_rbf
        self.d_time = d_time

        # 22 residue-type tokens, embedded to the node width.
        self.seq_emb_layer = nn.Embedding(22, d_node)

        node_layers = [
            nn.Linear(2 * d_node, d_node),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(d_node, d_node),
        ]
        self.node_in_proj = nn.Sequential(*node_layers)

        pair_layers = [
            nn.Linear(d_pair + d_rbf, d_pair),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(d_pair, d_pair),
        ]
        self.pair_in_proj = nn.Sequential(*pair_layers)

    def forward(self, node, pair, seq, xyz, t_emb, mask = None):
        '''
        input:
            node: node features (B, L, d_node)
            pair: pair features (B, L, L, d_pair)
            seq: residue-type indices fed to the embedding (B, L)
            xyz: coordinates (B, L, n_atom, 3); only atom index 1 is read
            t_emb: timestep embedding (B, d_time); currently unused
            mask: optional residue mask (B, L)
        output:
            node (B, L, d_node), pair (B, L, L, d_pair),
            rbf_dist (B, L, L, d_rbf) -- the distance encoding, not masked
        '''
        # Node branch: append the sequence embedding, then project.
        node_in = torch.cat((node, self.seq_emb_layer(seq)), dim=-1)
        node_out = self.node_in_proj(node_in)  # (B, L, d_node)

        # Pair branch: append the RBF-encoded distance matrix, then project.
        anchor_xyz = xyz[:, :, 1, :]
        rbf_dist = rbf(torch.cdist(anchor_xyz, anchor_xyz), D_count=self.d_rbf)  # (B, L, L, d_rbf)
        pair_out = self.pair_in_proj(torch.cat((pair, rbf_dist), dim=-1))  # (B, L, L, d_pair)

        if mask is None:
            return node_out, pair_out, rbf_dist

        mask_f = mask.type(torch.float32)  # (B, L)
        pair_mask = mask_f.unsqueeze(2) @ mask_f.unsqueeze(1)  # outer product, (B, L, L)
        return node_out * mask_f[..., None], pair_out * pair_mask[..., None], rbf_dist
device=s.device).unsqueeze(0).repeat(num_batch, 1) + relpos_feats = self.embed_relpos(pos) + # node_split_heads = s.reshape(num_batch, num_res, d_node//self.num_cross_heads, self.num_cross_heads) # (B,L,d_node//num_head,num_head) + # cross_node_feats =torch.einsum('bijh,bkjh->bikh', node_split_heads, node_split_heads) # (B,L,L,num_head) + + pos = t + dists_2d = torch.linalg.norm( + pos[:, :, None, :] - pos[:, None, :, :], axis=-1) # (B,L,L) + + dist_feats = rbf(dists_2d, D_min = 0., D_max = self._cfg.max_dist, D_count = self._cfg.num_bins) + # dist_feats = rbf(dists_2d, D_min = 0., D_max = 20.0, D_count = self._cfg.num_bins) + + pos = sc_t + dists_2d = torch.linalg.norm( + pos[:, :, None, :] - pos[:, None, :, :], axis=-1) # (B,L,L) + + sc_feats = rbf(dists_2d, D_min = 0., D_max = self._cfg.max_dist, D_count = self._cfg.num_bins) + # sc_feats = rbf(dists_2d, D_min = 0., D_max = 20.0, D_count = self._cfg.num_bins) + + all_edge_feats = torch.concat( + [cross_node_feats, relpos_feats, dist_feats, sc_feats, pair_repr_pre], dim=-1) + # all_edge_feats = torch.concat( + # [cross_node_feats, dist_feats, sc_feats, pair_repr_pre], dim=-1) + edge_feats = self.edge_embedder(all_edge_feats) # (B,L,L,c_p) + return edge_feats \ No newline at end of file diff --git a/models/flow_model.py b/models/flow_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e16a4549d5531a7c9b23826479f6f9e35643f80d --- /dev/null +++ b/models/flow_model.py @@ -0,0 +1,357 @@ +"""Neural network architecture for the flow model.""" +import torch +from torch import nn +from models.node_embedder import NodeEmbedder +from models.edge_embedder import EdgeEmbedder +from models.add_module.structure_module import Node_update, Pair_update, E3_transformer, TensorProductConvLayer, E3_GNNlayer, E3_transformer_no_adjacency, E3_transformer_test, Energy_Adapter_Node +from models.add_module.egnn import EGNN +from models.add_module.model_utils import get_sub_graph +from models import ipa_pytorch 
class FlowModel(nn.Module):
    """Flow network predicting denoised frames and, optionally, torsions.

    The network embeds nodes and edges, refines the node embedding through
    ``model_conf.ipa.num_blocks`` trunk blocks (attention, layer norm,
    optional energy adapter, EGNN), and turns the final node embedding into
    a rigid-frame update plus, when enabled, per-residue torsion vectors.

    Feature flags read from ``model_conf``:
        use_torsions:         build and run the torsion head.
        use_adapter_node:     condition node embeddings on ``input_feats['energy']``.
        use_mid_bb_update:    update the frames inside every trunk block.
        use_mid_bb_update_e3: use ``E3_GNNlayer`` for the frame updates.
        use_e3_transformer:   use ``E3_transformer_test`` instead of IPA.
    """

    # Lower bound applied to the squared torsion norm before the square root,
    # so that an all-zero prediction cannot produce NaN/inf on division.
    _TORSION_NORM_EPS = 1e-8

    def __init__(self, model_conf):
        super().__init__()

        self._model_conf = model_conf
        self._ipa_conf = model_conf.ipa
        self.node_embedder = NodeEmbedder(model_conf.node_features)
        self.edge_embedder = EdgeEmbedder(model_conf.edge_features)

        self.use_torsions = model_conf.use_torsions
        self.use_adapter_node = model_conf.use_adapter_node
        self.use_mid_bb_update = model_conf.use_mid_bb_update
        self.use_mid_bb_update_e3 = model_conf.use_mid_bb_update_e3
        self.use_e3_transformer = model_conf.use_e3_transformer

        c_s = self._ipa_conf.c_s
        # Multiplicity of the l=1 (vector) channel used by every E3 layer.
        e3_d_node_l1 = 32

        if self.use_adapter_node:
            self.energy_adapter = Energy_Adapter_Node(
                d_node=model_conf.node_embed_size,
                n_head=model_conf.ipa.no_heads,
                p_drop=model_conf.dropout)

        if self.use_torsions:
            self.num_torsions = 7
            self.torsions_pred_layer1 = nn.Sequential(
                nn.Linear(c_s, c_s),
                nn.ReLU(),
                nn.Linear(c_s, c_s),
            )
            self.torsions_pred_layer2 = nn.Linear(c_s, self.num_torsions * 2)

        if self.use_e3_transformer or self.use_mid_bb_update_e3:
            self.max_dist = self._model_conf.edge_features.max_dist
            self.irreps_sh = e3nn.o3.Irreps.spherical_harmonics(2)
            input_d_l1 = 3
            irreps_l1_in = e3nn.o3.Irreps(str(input_d_l1)+"x1o")
            irreps_l1_out = e3nn.o3.Irreps(str(e3_d_node_l1)+"x1o")
            self.proj_l1_init = e3nn.o3.Linear(irreps_l1_in, irreps_l1_out)

        # Attention trunk
        self.trunk = nn.ModuleDict()
        for b in range(self._ipa_conf.num_blocks):
            if self.use_e3_transformer:
                self.trunk[f'ipa_{b}'] = E3_transformer_test(
                    self._ipa_conf, e3_d_node_l1=e3_d_node_l1)
            else:
                self.trunk[f'ipa_{b}'] = ipa_pytorch.InvariantPointAttention(self._ipa_conf)
            self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(c_s)

            if self.use_adapter_node:
                self.trunk[f'energy_adapter_{b}'] = Energy_Adapter_Node(
                    d_node=model_conf.node_embed_size,
                    n_head=model_conf.ipa.no_heads,
                    p_drop=model_conf.dropout)

            self.trunk[f'egnn_{b}'] = EGNN(hidden_dim=model_conf.node_embed_size)

            if self.use_mid_bb_update:
                if self.use_mid_bb_update_e3:
                    self.trunk[f'bb_update_{b}'] = E3_GNNlayer(
                        e3_d_node_l0=4*e3_d_node_l1,
                        e3_d_node_l1=e3_d_node_l1, final=True,
                        init='final', dropout=0.0)
                else:
                    self.trunk[f'bb_update_{b}'] = ipa_pytorch.BackboneUpdate(
                        c_s, use_rot_updates=True, dropout=0.0)

        # Final frame-update head (applied once, after the trunk).
        if self.use_mid_bb_update_e3:
            self.bb_update_layer = E3_GNNlayer(
                e3_d_node_l0=4*e3_d_node_l1,
                e3_d_node_l1=e3_d_node_l1, final=True,
                init='final', dropout=0.0)
        elif self.use_e3_transformer:
            irreps_1 = e3nn.o3.Irreps(str(c_s)+"x0e + "+str(e3_d_node_l1)+"x1o")
            irreps_2 = e3nn.o3.Irreps(str(c_s)+"x0e + "+str(e3_d_node_l1)+"x1o")
            irreps_3 = e3nn.o3.Irreps("6x0e")
            self.bb_tpc = TensorProductConvLayer(
                irreps_1, irreps_2, irreps_3, c_s, fc_factor=1,
                init='final', dropout=0.1)
        else:
            self.bb_update_layer = ipa_pytorch.BackboneUpdate(
                c_s, use_rot_updates=True, dropout=0.0)

    def rigids_ang_to_nm(self, rigids):
        """Return ``rigids`` with translations scaled by ``du.ANG_TO_NM_SCALE``."""
        return rigids.apply_trans_fn(lambda trans: trans * du.ANG_TO_NM_SCALE)

    def rigids_nm_to_ang(self, rigids):
        """Return ``rigids`` with translations scaled by ``du.NM_TO_ANG_SCALE``."""
        return rigids.apply_trans_fn(lambda trans: trans * du.NM_TO_ANG_SCALE)

    def _predict_torsions(self, node_embed, lead_shape):
        """Return unit-norm torsion 2-vectors padded to 8 entries per residue.

        ``lead_shape`` is the (B, L) shape of ``aatype``. The result has shape
        (B, L, 8, 2): ``8 - num_torsions`` fixed ``[0, 1]`` entries followed by
        the ``num_torsions`` predicted, normalised vectors.
        """
        pred_torsions = node_embed + self.torsions_pred_layer1(node_embed)
        pred_torsions = self.torsions_pred_layer2(pred_torsions).reshape(
            lead_shape + (self.num_torsions, 2))  # (B,L,num_torsions,2)

        # Clamp the squared norm so a zero vector does not divide by zero.
        sq_norm = torch.sum(pred_torsions ** 2, dim=-1, keepdim=True)
        norm_torsions = torch.sqrt(torch.clamp(sq_norm, min=self._TORSION_NORM_EPS))
        pred_torsions = pred_torsions / norm_torsions  # (B,L,num_torsions,2)

        add_rot = pred_torsions.new_zeros(
            (1,) * len(pred_torsions.shape[:-2]) + (8 - self.num_torsions, 2))
        add_rot[..., 1] = 1
        add_rot = add_rot.expand(*pred_torsions.shape[:-2], -1, -1)  # (B,L,8-num_torsions,2)
        return torch.concat([add_rot, pred_torsions], dim=-2)  # (B,L,8,2)

    def forward(self, input_feats, use_mask_aatype = False):
        '''
        note: B and L are changing during training
        input_feats.keys():
            'aatype'        (B,L)
            'res_mask'      (B,L)
            't'             (B,1)
            'trans_t'       (B,L,3)
            'rotmats_t'     (B,L,3,3)
            'node_repr_pre', 'pair_repr_pre'   precomputed representations
            'energy'        (B,)  read only when use_adapter_node is set
            'trans_sc'      optional self-conditioning translations; zeros if absent

        use_mask_aatype is accepted for interface compatibility and is
        currently not used.

        Returns a dict with 'pred_trans', 'pred_rotmats' and
        'pred_torsions_with_CB' ((B,L,8,2), or None when use_torsions is off).
        '''
        node_mask = input_feats['res_mask']
        edge_mask = node_mask[:, None] * node_mask[:, :, None]
        continuous_t = input_feats['t']
        trans_t = input_feats['trans_t']
        rotmats_t = input_feats['rotmats_t']
        aatype = input_feats['aatype']

        use_e3_feats = self.use_e3_transformer or self.use_mid_bb_update_e3

        if self.use_adapter_node:
            energy = input_feats['energy']  # (B,)

        if use_e3_feats:
            xyz_t = all_atom.to_atom37(trans_t, rotmats_t)[:, :, :3, :]  # (B,L,3,3)
            edge_src, edge_dst, pair_index = get_sub_graph(
                xyz_t[:, :, 1, :], mask=node_mask,
                dist_max=self.max_dist, kmin=32)
            xyz_graph = xyz_t[:, :, 1, :].reshape(-1, 3)
            edge_vec = xyz_graph[edge_src] - xyz_graph[edge_dst]
            edge_sh = e3nn.o3.spherical_harmonics(
                self.irreps_sh, edge_vec,
                normalize=True, normalization='component')  # (num_edges, irreps_sh)

        node_repr_pre = input_feats['node_repr_pre']
        pair_repr_pre = input_feats['pair_repr_pre']

        if 'trans_sc' in input_feats:
            trans_sc = input_feats['trans_sc']
        else:
            trans_sc = torch.zeros_like(trans_t)

        # Initialize node and edge embeddings.
        init_node_embed = self.node_embedder(continuous_t, aatype, node_repr_pre, node_mask)
        init_node_embed = init_node_embed * node_mask[..., None]

        if self.use_adapter_node:
            init_node_embed = self.energy_adapter(init_node_embed, energy, mask=node_mask)
            init_node_embed = init_node_embed * node_mask[..., None]

        init_edge_embed = self.edge_embedder(
            init_node_embed, trans_t, trans_sc, pair_repr_pre, edge_mask)
        init_edge_embed = init_edge_embed * edge_mask[..., None]

        if use_e3_feats:
            # Three vectors per residue: two offsets from atom 1 plus atom 1 itself.
            l1_feats = torch.concat([xyz_t[:, :, 0, :] - xyz_t[:, :, 1, :],
                                     xyz_t[:, :, 1, :],
                                     xyz_t[:, :, 2, :] - xyz_t[:, :, 1, :]], dim=-1)  # (B,L,3*3)
            l1_feats = self.proj_l1_init(l1_feats)
            l1_feats = l1_feats * node_mask[..., None]

        # Initial rigids, converted to nanometres for the trunk.
        curr_rigids = du.create_rigid(rotmats_t, trans_t,)
        curr_rigids = self.rigids_ang_to_nm(curr_rigids)

        node_embed = init_node_embed
        edge_embed = init_edge_embed
        for b in range(self._ipa_conf.num_blocks):

            if self.use_e3_transformer:
                ipa_embed, l1_feats = self.trunk[f'ipa_{b}'](node_embed,
                                                             edge_embed,
                                                             l1_feats,
                                                             curr_rigids,
                                                             node_mask)  # (B,L,d_node)
                l1_feats = l1_feats * node_mask[..., None]
            else:
                ipa_embed = self.trunk[f'ipa_{b}'](node_embed,
                                                   edge_embed,
                                                   curr_rigids,
                                                   node_mask)  # (B,L,d_node)
                # Residual connection around plain IPA.
                ipa_embed = node_embed + ipa_embed

            ipa_embed = ipa_embed * node_mask[..., None]
            node_embed = self.trunk[f'ipa_ln_{b}'](ipa_embed)

            if self.use_adapter_node:
                node_embed = self.trunk[f'energy_adapter_{b}'](node_embed, energy, mask=node_mask)
                node_embed = node_embed * node_mask[..., None]

            # Only the feature output of the EGNN is kept; its first output is discarded.
            _, node_embed = self.trunk[f'egnn_{b}'](node_embed, trans_t)
            node_embed = node_embed * node_mask[..., None]

            if self.use_mid_bb_update:
                if self.use_mid_bb_update_e3:
                    rigid_update = self.trunk[f'bb_update_{b}'](
                        node_embed, edge_embed, l1_feats, pair_index,
                        edge_src, edge_dst, edge_sh, mask=node_mask)
                elif self.use_e3_transformer:
                    # NOTE(review): this layer is constructed as BackboneUpdate(c_s)
                    # but receives node_embed concatenated with l1_feats here;
                    # the input width looks mismatched -- confirm before enabling
                    # use_mid_bb_update together with use_e3_transformer.
                    rigid_update = self.trunk[f'bb_update_{b}'](
                        torch.concat([node_embed, l1_feats], dim=-1))
                else:
                    rigid_update = self.trunk[f'bb_update_{b}'](node_embed)
                curr_rigids = curr_rigids.compose_q_update_vec(
                    rigid_update, node_mask[..., None])

        # Final frame update.
        if self.use_mid_bb_update_e3:
            rigid_update = self.bb_update_layer(
                node_embed, edge_embed, l1_feats, pair_index,
                edge_src, edge_dst, edge_sh, mask=node_mask)
        elif self.use_e3_transformer:
            node_feats_total = torch.concat([node_embed, l1_feats], dim=-1)
            rigid_update = self.bb_tpc(node_feats_total, node_feats_total, node_embed)
        else:
            rigid_update = self.bb_update_layer(node_embed)

        curr_rigids = curr_rigids.compose_q_update_vec(
            rigid_update, node_mask[..., None])

        curr_rigids = self.rigids_nm_to_ang(curr_rigids)
        pred_trans = curr_rigids.get_trans()
        pred_rotmats = curr_rigids.get_rots().get_rot_mats()

        # Without the torsion head the key is still present (as None) so the
        # return statement no longer fails on an unbound local.
        pred_torsions_with_CB = None
        if self.use_torsions:
            pred_torsions_with_CB = self._predict_torsions(node_embed, aatype.shape)

        return {
            'pred_trans': pred_trans,
            'pred_rotmats': pred_rotmats,
            'pred_torsions_with_CB': pred_torsions_with_CB,
        }
data.interpolant import Interpolant +from data import utils as du +from data import all_atom +from data import so3_utils +from experiments import utils as eu +from data import residue_constants +from data.residue_constants import order2restype_with_mask +from pytorch_lightning.loggers.wandb import WandbLogger +from openfold.utils.exponential_moving_average import ExponentialMovingAverage as EMA +from openfold.utils.tensor_utils import ( + tensor_tree_map, +) + +os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + +class FlowModule(LightningModule): + + def __init__(self, cfg, folding_cfg=None): + super().__init__() + self._print_logger = logging.getLogger(__name__) + self._exp_cfg = cfg.experiment + self._model_cfg = cfg.model + self._data_cfg = cfg.data + self._interpolant_cfg = cfg.interpolant + + # Set-up vector field prediction model + self.model = FlowModel(cfg.model) + + # self.ema = EMA(self.model, decay=0.99) + + # Set-up interpolant + self.interpolant = Interpolant(cfg.interpolant) + + self._sample_write_dir = self._exp_cfg.checkpointer.dirpath + os.makedirs(self._sample_write_dir, exist_ok=True) + + self.validation_epoch_metrics = [] + self.validation_epoch_samples = [] + self.save_hyperparameters() + + def on_train_start(self): + self._epoch_start_time = time.time() + + def on_train_epoch_end(self): + epoch_time = (time.time() - self._epoch_start_time) / 60.0 + self.log( + 'train/epoch_time_minutes', + epoch_time, + on_step=False, + on_epoch=True, + prog_bar=False + ) + self._epoch_start_time = time.time() + + def model_step(self, noisy_batch: Any, use_mask_aatype = True, eps=1e-8): + training_cfg = self._exp_cfg.training + loss_mask = noisy_batch['res_mask'] + if training_cfg.min_plddt_mask is not None: + plddt_mask = noisy_batch['res_plddt'] > training_cfg.min_plddt_mask + loss_mask *= plddt_mask + num_batch, num_res = loss_mask.shape + + # Ground truth labels + gt_trans_1 = noisy_batch['trans_1'] + gt_rotmats_1 = noisy_batch['rotmats_1'] + rotmats_t = 
noisy_batch['rotmats_t'] + gt_rot_vf = so3_utils.calc_rot_vf( + rotmats_t, gt_rotmats_1.type(torch.float32)) + + gt_bb_atoms_total_result = all_atom.to_atom37(gt_trans_1, gt_rotmats_1, aatype=noisy_batch['aatype'], get_mask=True) + gt_bb_atoms = gt_bb_atoms_total_result[0][:, :, :3, :] + atom37_mask = gt_bb_atoms_total_result[1] + + gt_atoms = noisy_batch['all_atom_positions'] # (B, L, 37, 3) + + # Timestep used for normalization. + t = noisy_batch['t'] + norm_scale = 1 - torch.min( + t[..., None], torch.tensor(training_cfg.t_normalize_clip)) + + # Model output predictions. + model_output = self.model(noisy_batch, use_mask_aatype=use_mask_aatype) + pred_trans_1 = model_output['pred_trans'] + pred_rotmats_1 = model_output['pred_rotmats'] + pred_torsions_with_CB = model_output['pred_torsions_with_CB'] + pred_rots_vf = so3_utils.calc_rot_vf(rotmats_t, pred_rotmats_1) + + # aatype loss + # pred_aatype_1 = model_output['pred_aatype'] # (B,L,21) + # pred_aatype_1_logits = torch.log(torch.softmax(pred_aatype_1,dim=-1)) # (B,L,21) + # aatype_onehot = torch.nn.functional.one_hot(noisy_batch['aatype'],num_classes=21) # (B,L,21) + # aatype_loss = -torch.sum(torch.mul(pred_aatype_1_logits, aatype_onehot), dim=-1) * loss_mask # (B,L) + # aatype_loss = torch.sum(aatype_loss, dim=-1) / torch.sum(loss_mask, dim=-1) + # aatype_loss = training_cfg.aatype_loss_weight * aatype_loss # (B,) + + # Translation VF loss + trans_error = (gt_trans_1 - pred_trans_1) / norm_scale + loss_denom = torch.sum(loss_mask, dim=-1) * 3 # + trans_loss = training_cfg.translation_loss_weight * torch.sum( + trans_error ** 2 * loss_mask[..., None], + dim=(-1, -2) + ) / loss_denom + # trans_loss = torch.zeros(num_batch,device=pred_trans_1.device,dtype=torch.float32) + + # Rotation VF loss + rots_vf_error = (gt_rot_vf - pred_rots_vf) / norm_scale + rots_vf_loss = training_cfg.rotation_loss_weights * torch.sum( + rots_vf_error ** 2 * loss_mask[..., None], + dim=(-1, -2) + ) / loss_denom + # rots_vf_loss = 
torch.zeros(num_batch,device=pred_trans_1.device,dtype=torch.float32) + + + # torsion loss + torsion_angles, torsion_mask = all_atom.prot_to_torsion_angles(noisy_batch['aatype'], gt_atoms, atom37_mask) + # print('atom37_mask',atom37_mask.shape, atom37_mask[0,:5,:15]) # (B,L,37) + # print('torsion_angles',torsion_angles.shape, torsion_angles[0,:5,:15,:]) # (B,L,7,2) + # print('torsion_mask',torsion_mask.shape, torsion_mask[0,:5,:15]) # (B,L,7) + torsion_loss = torch.sum( + (torsion_angles - pred_torsions_with_CB[:,:,1:,:]) ** 2 * torsion_mask[..., None], + dim=(-1, -2, -3) + ) / torch.sum(torsion_mask, dim=(-1, -2)) + + + # atom loss + pred_atoms, atoms_mask = all_atom.to_atom37(pred_trans_1, pred_rotmats_1, aatype = noisy_batch['aatype'], torsions_with_CB = pred_torsions_with_CB, get_mask = True) # atoms_mask (B,L,37) + atoms_mask = atoms_mask * loss_mask[...,None] # (B,L,37) + pred_atoms_flat, gt_atoms_flat, _ = du.batch_align_structures( + pred_atoms.reshape(num_batch, -1, 3), gt_atoms.reshape(num_batch, -1, 3), mask=atoms_mask.reshape(num_batch, -1) + ) + gt_atoms = gt_atoms_flat * training_cfg.bb_atom_scale / norm_scale # (B, true_atoms,3) + pred_atoms = pred_atoms_flat * training_cfg.bb_atom_scale / norm_scale + + bb_atom_loss = torch.sum( + (gt_atoms - pred_atoms) ** 2, + dim=(-1, -2) + ) / torch.sum(atoms_mask, dim=(-1, -2)) + + + + # atom distance loss + # pred_atoms, atoms_mask = all_atom.to_atom37(pred_trans_1, pred_rotmats_1, aatype = noisy_batch['aatype'], torsions_with_CB = pred_torsions_with_CB, get_mask = True) # atoms_mask (B,L,37) + # gt_atoms *= training_cfg.bb_atom_scale / norm_scale[..., None] + # pred_atoms *= training_cfg.bb_atom_scale / norm_scale[..., None] + # atoms_mask = atoms_mask * loss_mask[...,None] # (B,L,37) + # gt_flat_atoms = gt_atoms.reshape([num_batch, num_res*37, 3]) # (B,L*37,3) + # gt_pair_dists = torch.linalg.norm( + # gt_flat_atoms[:, :, None, :] - gt_flat_atoms[:, None, :, :], dim=-1) # (B,L*37,L*37) + # pred_flat_atoms = 
pred_atoms.reshape([num_batch, num_res*37, 3]) + # pred_pair_dists = torch.linalg.norm( + # pred_flat_atoms[:, :, None, :] - pred_flat_atoms[:, None, :, :], dim=-1) + # flat_mask = atoms_mask.reshape([num_batch, num_res*37]) # (B,L*37) + + # gt_pair_dists = gt_pair_dists * flat_mask[..., None] + # pred_pair_dists = pred_pair_dists * flat_mask[..., None] + # pair_dist_mask = flat_mask[..., None] * flat_mask[:, None, :] # (B,L*37, L*37) + + # bb_atom_loss = torch.sum( + # (gt_pair_dists - pred_pair_dists)**2 * pair_dist_mask, + # dim=(1, 2)) + # bb_atom_loss /= (torch.sum(pair_dist_mask, dim=(1, 2))) + + + + # Pairwise distance loss + pred_bb_atoms = all_atom.to_atom37(pred_trans_1, pred_rotmats_1)[:, :, :3] + gt_bb_atoms_pair = gt_bb_atoms * training_cfg.bb_atom_scale / norm_scale[..., None] + pred_bb_atoms_pair = pred_bb_atoms * training_cfg.bb_atom_scale / norm_scale[..., None] + gt_flat_atoms = gt_bb_atoms_pair.reshape([num_batch, num_res*3, 3]) + gt_pair_dists = torch.linalg.norm( + gt_flat_atoms[:, :, None, :] - gt_flat_atoms[:, None, :, :], dim=-1) + pred_flat_atoms = pred_bb_atoms_pair.reshape([num_batch, num_res*3, 3]) + pred_pair_dists = torch.linalg.norm( + pred_flat_atoms[:, :, None, :] - pred_flat_atoms[:, None, :, :], dim=-1) + + flat_loss_mask = torch.tile(loss_mask[:, :, None], (1, 1, 3)) + flat_loss_mask = flat_loss_mask.reshape([num_batch, num_res*3]) # (B,L*3) + flat_res_mask = torch.tile(loss_mask[:, :, None], (1, 1, 3)) + flat_res_mask = flat_res_mask.reshape([num_batch, num_res*3]) # (B,L*3) + + gt_pair_dists = gt_pair_dists * flat_loss_mask[..., None] + pred_pair_dists = pred_pair_dists * flat_loss_mask[..., None] + pair_dist_mask = flat_loss_mask[..., None] * flat_res_mask[:, None, :] # (B,L*3, L*3) + + pair_dist_mat_loss = torch.sum( + (gt_pair_dists - pred_pair_dists)**2 * pair_dist_mask, + dim=(1, 2)) / (torch.sum(pair_dist_mask, dim=(1, 2)) - num_res) + + + # sequence distance loss + pred_bb_atoms = all_atom.to_atom37(pred_trans_1, 
def validation_step(self, batch: Any, batch_idx: int):
    """Sample structures for one validation batch and score them.

    For every item in the batch the final sampled structure is compared
    with ``batch['all_atom_positions']``: CA-CA metrics, an aligned RMSD over
    the first three atoms of each residue, and a TM-score over atom index 1.
    Each sample is also written to a PDB file under ``self._sample_write_dir``.
    The per-sample metrics are collected in a DataFrame that is appended to
    ``self.validation_epoch_metrics``; nothing is returned.

    Args:
        batch: Feature dict. Keys read here are ``'res_mask'``,
            ``'all_atom_positions'`` and ``'aatype'``.
        batch_idx: Index of the batch; only used in the output file name.
    """
    # if self.trainer.global_step > 100000:
    #     # load ema weights
    #     clone_param = lambda t: t.detach().clone()
    #     self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
    #     self.model.load_state_dict(self.ema.state_dict()["params"])

    res_mask = batch['res_mask']
    self.interpolant.set_device(res_mask.device)
    num_batch, num_res = res_mask.shape

    # Sample in eval mode, then restore train mode for the following steps.
    self.model.eval()
    # [0] is the first element returned by sample() and [-1] its last entry.
    # NOTE(review): .numpy() requires a CPU tensor, so this presumably relies
    # on sample() returning CPU tensors -- confirm against the interpolant.
    samples = self.interpolant.sample(
        batch,
        self.model,
    )[0][-1].numpy()
    self.model.train()

    batch_metrics = []
    for i in range(num_batch):

        # Write out sample to PDB file
        final_pos = samples[i]

        # mdtraj_metrics = metrics.calc_mdtraj_metrics(saved_path)
        ca_ca_metrics = metrics.calc_ca_ca_metrics(final_pos[:, residue_constants.atom_order['CA']])

        # use 3 bb atoms coords to calculate RMSD, and use CA to calculate tm_score
        rmsd = metrics.calc_aligned_rmsd(
            final_pos[:, :3].reshape(-1,3), batch['all_atom_positions'][i].cpu().numpy()[:,:3].reshape(-1,3))
        seq_string = ''.join([order2restype_with_mask[int(aa)] for aa in batch['aatype'][i].cpu()])

        # Fall back to a poly-alanine string when the decoded sequence length
        # does not match the number of sampled residues.
        if len(seq_string) != final_pos[:, 1].shape[0]:
            seq_string = 'A'*final_pos[:, 1].shape[0]

        tm_score,_ = metrics.calc_tm_score(
            final_pos[:, 1], batch['all_atom_positions'][i].cpu().numpy()[:,1],
            seq_string, seq_string)

        valid_loss = {'rmsd_loss':rmsd,
                      'tm_score':tm_score }

        batch_metrics.append((ca_ca_metrics | valid_loss))  # merge metrics
        # NOTE(review): the file name does not contain the sample index i, so
        # two samples of one batch are only kept apart by their rmsd/tm values.
        saved_path = au.write_prot_to_pdb(
            final_pos,
            os.path.join(
                self._sample_write_dir,
                f'idx_{batch_idx}_len_{num_res}_rmsd_{rmsd}_tm_{tm_score}.pdb'),
            aatype=batch['aatype'][i].cpu().numpy(),
            no_indexing=True
        )
        # Samples are only queued for table logging when a W&B logger is used.
        if isinstance(self.logger, WandbLogger):
            self.validation_epoch_samples.append(
                [saved_path, self.global_step, wandb.Molecule(saved_path)]
            )

    batch_metrics = pd.DataFrame(batch_metrics)

    # Print the per-batch mean of every metric on a single console line.
    print("")
    for key in batch_metrics.columns:
        print('%s=%.3f ' %(key, np.mean(batch_metrics[key])), end='')

    self.validation_epoch_metrics.append(batch_metrics)
def _log_scalar(
    self,
    key,
    value,
    on_step=True,
    on_epoch=False,
    prog_bar=True,
    batch_size=None,
    sync_dist=False,
    rank_zero_only=True
):
    """Forward one scalar to ``self.log`` with this module's defaults.

    Args:
        key: Name the value is logged under.
        value: Value to log.
        on_step, on_epoch, prog_bar, batch_size, sync_dist, rank_zero_only:
            Passed through unchanged as keyword arguments of ``self.log``.

    Raises:
        ValueError: If ``sync_dist`` and ``rank_zero_only`` are both truthy.
    """
    # The two flags are rejected together before anything is logged.
    if rank_zero_only and sync_dist:
        raise ValueError('Unable to sync dist when rank_zero_only=True')

    log_options = {
        "on_step": on_step,
        "on_epoch": on_epoch,
        "prog_bar": prog_bar,
        "batch_size": batch_size,
        "sync_dist": sync_dist,
        "rank_zero_only": rank_zero_only,
    }
    self.log(key, value, **log_options)
def configure_optimizers(self):
    """Build the AdamW optimizer for the model.

    Only parameters of ``self.model`` with ``requires_grad`` set are handed to
    the optimizer; every entry of ``self._exp_cfg.optimizer`` is forwarded as
    a keyword argument of ``torch.optim.AdamW``.

    Returns:
        The configured ``torch.optim.AdamW`` instance.
    """
    trainable_params = (
        param for param in self.model.parameters() if param.requires_grad
    )
    optimizer_kwargs = self._exp_cfg.optimizer
    return torch.optim.AdamW(params=trainable_params, **optimizer_kwargs)
range(atom37_traj[0].shape[0]): + bb_traj = du.to_numpy(torch.stack(atom37_traj, dim=0))[:,batch_index] # (101,L,37,3) + x0_traj = du.to_numpy(torch.stack(model_traj, dim=0))[:,batch_index] + _ = eu.save_traj( + bb_traj[-1], + bb_traj, + np.flip(x0_traj, axis=0), + du.to_numpy(diffuse_mask)[batch_index], + output_dir=sample_dir, + aatype=batch['aatype'][batch_index].cpu().numpy(), + index=batch_index, + ) + + with open(os.path.join(sample_dir,'seq.txt'), 'w') as f: + f.write(batch['seq'][0]) + diff --git a/models/ipa_pytorch.py b/models/ipa_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..b126b4a36669b90e254f0d5cfd73ddf1ee871fc9 --- /dev/null +++ b/models/ipa_pytorch.py @@ -0,0 +1,693 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """Permute only the trailing ``len(inds)`` dimensions of ``tensor``.

    Args:
        tensor: Input tensor; its leading dimensions keep their order.
        inds: New order of the trailing dimensions, given as indices relative
            to the first of those trailing dimensions.

    Returns:
        The result of ``tensor.permute`` with the leading dimensions in place
        and the trailing ones reordered according to ``inds``.
    """
    tail_start = -len(inds)
    # Leading dimensions are addressed with non-negative indices, trailing
    # ones with negative indices counted from the end.
    leading = range(len(tensor.shape[:tail_start]))
    order = [*leading, *(tail_start + i for i in inds)]
    return tensor.permute(order)
def compute_angles(ca_pos, pts):
    """Compute angles between residue-pair CA vectors and point offsets.

    Args:
        ca_pos: Tensor indexed here as (batch, residue, xyz).
        pts: Tensor of shape (batch, residue, heads, points, xyz).

    Returns:
        The output of ``all_atom.calculate_neighbor_angles`` reshaped to
        (batch, residue, residue, heads, points).
    """
    n_batch, n_res, n_heads, n_pts, _ = pts.shape

    # Pairwise CA displacement with a small offset, repeated for every
    # head and point: (B, L, L, H, P, 3).
    pair_vecs = (ca_pos[:, :, None, :] - ca_pos[:, None, :, :]) + 1e-10
    pair_vecs = pair_vecs[:, :, :, None, None, :].expand(
        -1, -1, -1, n_heads, n_pts, -1)

    # Offset of each point from its own residue's CA, repeated along the
    # second residue axis: (B, L, L, H, P, 3).
    pt_offsets = pts[:, :, None, :, :, :] - ca_pos[:, :, None, None, None, :]
    pt_offsets = pt_offsets.expand(-1, -1, n_res, -1, -1, -1)

    flat_angles = all_atom.calculate_neighbor_angles(
        pair_vecs.reshape(-1, 3),
        pt_offsets.reshape(-1, 3),
    )
    return flat_angles.reshape(n_batch, n_res, n_res, n_heads, n_pts)
class StructureModuleTransition(nn.Module):
    """Residual three-layer MLP followed by a LayerNorm.

    All three linear layers map ``c`` channels to ``c`` channels; the first
    two use the "relu" initializer and the last one the "final" initializer.
    """

    def __init__(self, c):
        """
        Args:
            c:
                Channel dimension of the input and of every layer
        """
        super().__init__()

        self.c = c

        # Creation order is kept so parameter registration order is unchanged.
        self.linear_1 = Linear(c, c, init="relu")
        self.linear_2 = Linear(c, c, init="relu")
        self.linear_3 = Linear(c, c, init="final")
        self.relu = nn.ReLU()
        self.ln = nn.LayerNorm(c)

    def forward(self, s):
        """Apply the MLP, add the input back and layer-normalize the sum."""
        hidden = self.relu(self.linear_1(s))
        hidden = self.relu(self.linear_2(hidden))
        update = self.linear_3(hidden)
        return self.ln(update + s)
class InvariantPointAttention(nn.Module):
    """
    Implements Algorithm 22.

    The attention logits are the sum of a scalar query/key term, a bias
    projected from the pair representation and a squared-distance term
    between query and key points placed in the global frame by ``r``. The
    output concatenates scalar values, attended points (mapped back to the
    local frame), their norms and attended pair features.
    """
    def __init__(
        self,
        ipa_conf,
        inf: float = 1e5,
        eps: float = 1e-8,
    ):
        """
        Args:
            ipa_conf:
                Config object. The attributes read here are:
                    c_s: Single representation channel dimension
                    c_z: Pair representation channel dimension
                    c_hidden: Hidden channel dimension
                    no_heads: Number of attention heads
                    no_qk_points: Number of query/key points to generate
                    no_v_points: Number of value points to generate
            inf:
                Magnitude subtracted from the logits of masked residue pairs
            eps:
                Constant added under the square root of the point norms
        """
        super(InvariantPointAttention, self).__init__()
        self._ipa_conf = ipa_conf

        self.c_s = ipa_conf.c_s
        self.c_z = ipa_conf.c_z
        self.c_hidden = ipa_conf.c_hidden
        self.no_heads = ipa_conf.no_heads
        self.no_qk_points = ipa_conf.no_qk_points
        self.no_v_points = ipa_conf.no_v_points
        self.inf = inf
        self.eps = eps

        # These linear layers differ from their specifications in the
        # supplement. There, they lack bias and use Glorot initialization.
        # Here as in the official source, they have bias and use the default
        # Lecun initialization.
        hc = self.c_hidden * self.no_heads
        self.linear_q = Linear(self.c_s, hc)
        self.linear_kv = Linear(self.c_s, 2 * hc)

        hpq = self.no_heads * self.no_qk_points * 3
        self.linear_q_points = Linear(self.c_s, hpq)

        hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
        self.linear_kv_points = Linear(self.c_s, hpkv)

        self.linear_b = Linear(self.c_z, self.no_heads)
        # The pair representation is projected to a quarter of its width
        # before it is attended over in forward().
        self.down_z = Linear(self.c_z, self.c_z // 4)

        # One learnable weight per head, passed through softplus in forward().
        self.head_weights = nn.Parameter(torch.zeros((ipa_conf.no_heads)))
        ipa_point_weights_init_(self.head_weights)

        # Per head: pair features (c_z // 4), scalar values (c_hidden) and,
        # for every value point, three coordinates plus one norm.
        concat_out_dim = (
            self.c_z // 4 + self.c_hidden + self.no_v_points * 4
        )
        self.linear_out = Linear(self.no_heads * concat_out_dim, self.c_s, init="final")

        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()

    def forward(
        self,
        s: torch.Tensor,
        z: Optional[torch.Tensor],
        r: Rigid,
        mask: torch.Tensor,
        _offload_inference: bool = False,
        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation
            r:
                [*, N_res] transformation object
            mask:
                [*, N_res] mask
            _offload_inference:
                If True, the pair representation is taken from
                ``_z_reference_list[0]`` instead of ``z`` and is moved to the
                CPU while the attention logits are computed. The list entry
                is reassigned in place.
            _z_reference_list:
                Container holding the pair representation; only used when
                ``_offload_inference`` is True
        Returns:
            [*, N_res, C_s] single representation update
        """
        # Both paths leave z as an indexable container whose first element
        # is the pair representation.
        if _offload_inference:
            z = _z_reference_list
        else:
            z = [z]

        #######################################
        # Generate scalar and point activations
        #######################################
        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)
        kv = self.linear_kv(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, 2 * C_hidden]
        kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, C_hidden]
        k, v = torch.split(kv, self.c_hidden, dim=-1)

        # [*, N_res, H * P_q * 3]
        q_pts = self.linear_q_points(s)

        # This is kind of clunky, but it's how the original does it
        # [*, N_res, H * P_q, 3]
        q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
        q_pts = torch.stack(q_pts, dim=-1)
        # Place the query points with each residue's rigid transformation.
        q_pts = r[..., None].apply(q_pts)

        # [*, N_res, H, P_q, 3]
        q_pts = q_pts.view(
            q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
        )

        # [*, N_res, H * (P_q + P_v) * 3]
        kv_pts = self.linear_kv_points(s)

        # [*, N_res, H * (P_q + P_v), 3]
        kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
        kv_pts = torch.stack(kv_pts, dim=-1)
        kv_pts = r[..., None].apply(kv_pts)

        # [*, N_res, H, (P_q + P_v), 3]
        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))

        # [*, N_res, H, P_q/P_v, 3]
        k_pts, v_pts = torch.split(
            kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
        )

        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])

        # The pair representation is not needed again until the output
        # stage, so it is moved off the device in the meantime.
        if(_offload_inference):
            z[0] = z[0].cpu()

        # [*, H, N_res, N_res]
        a = torch.matmul(
            math.sqrt(1.0 / (3 * self.c_hidden)) *
            permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
            permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
        )
        # a *= math.sqrt(1.0 / (3 * self.c_hidden))
        a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))

        # [*, N_res, N_res, H, P_q, 3]
        pt_displacement = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
        pt_att = pt_displacement ** 2

        # [*, N_res, N_res, H, P_q]
        pt_att = sum(torch.unbind(pt_att, dim=-1))
        # Softplus keeps the per-head point weights positive; the view makes
        # them broadcast over the head axis of pt_att.
        head_weights = self.softplus(self.head_weights).view(
            *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
        )
        head_weights = head_weights * math.sqrt(
            1.0 / (3 * (self.no_qk_points * 9.0 / 2))
        )
        pt_att = pt_att * head_weights

        # [*, N_res, N_res, H]
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
        # [*, N_res, N_res]
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        # 0 where both residues are unmasked, -inf (i.e. -self.inf) elsewhere.
        square_mask = self.inf * (square_mask - 1)

        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))

        a = a + pt_att
        a = a + square_mask.unsqueeze(-3)
        a = self.softmax(a)

        ################
        # Compute output
        ################
        # [*, N_res, H, C_hidden]
        o = torch.matmul(
            a, v.transpose(-2, -3)
        ).transpose(-2, -3)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, H, 3, N_res, P_v]
        o_pt = torch.sum(
            (
                a[..., None, :, :, None]
                * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
            ),
            dim=-2,
        )

        # [*, N_res, H, P_v, 3]
        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
        # Map the attended points back with the inverse of each residue's
        # transformation.
        o_pt = r[..., None, None].invert_apply(o_pt)

        # [*, N_res, H * P_v]
        # eps keeps the square root away from an exactly-zero argument.
        o_pt_dists = torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps)
        o_pt_norm_feats = flatten_final_dims(
            o_pt_dists, 2)

        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

        # Bring the offloaded pair representation back to the compute device.
        if(_offload_inference):
            z[0] = z[0].to(o_pt.device)

        # [*, N_res, H, C_z // 4]
        pair_z = self.down_z(z[0])
        o_pair = torch.matmul(a.transpose(-2, -3), pair_z)

        # [*, N_res, H * C_z // 4]
        o_pair = flatten_final_dims(o_pair, 2)

        o_feats = [o, *torch.unbind(o_pt, dim=-1), o_pt_norm_feats, o_pair]

        # [*, N_res, C_s]
        s = self.linear_out(
            torch.cat(
                o_feats, dim=-1
            )
        )

        return s
class RotationVFLayer(nn.Module):
    """Residual MLP head that maps ``dim`` channels to 6 output channels."""

    def __init__(self, dim):
        """
        Args:
            dim:
                Channel dimension of the input and of the hidden layers
        """
        super().__init__()

        # Creation order is kept so parameter registration order is unchanged.
        self.linear_1 = Linear(dim, dim, init="relu")
        self.linear_2 = Linear(dim, dim, init="relu")
        self.linear_3 = Linear(dim, dim)
        self.final_linear = Linear(dim, 6, init="final")
        self.relu = nn.ReLU()

    def forward(self, s):
        """Run the three-layer MLP, add the input back and project to 6."""
        hidden = self.relu(self.linear_1(s))
        hidden = self.relu(self.linear_2(hidden))
        residual = self.linear_3(hidden) + s
        return self.final_linear(residual)
class IpaScore(nn.Module):
    """Stack of IPA blocks that refines frames and returns score estimates.

    Each block applies invariant point attention, a sequence transformer, a
    node transition and a backbone update; all blocks but the last also
    update the edge embedding.
    """

    def __init__(self, model_conf, diffuser):
        """
        Args:
            model_conf:
                Model config. Read here: ``ipa`` (the IPA sub-config with
                ``coordinate_scaling``, ``num_blocks``, ``c_s``, ``c_skip``,
                ``seq_tfmr_num_heads`` and ``seq_tfmr_num_layers``),
                ``node_embed_size`` and ``edge_embed_size``.
            diffuser:
                Object providing ``calc_rot_score`` and ``calc_trans_score``.
        """
        super(IpaScore, self).__init__()
        self._model_conf = model_conf
        ipa_conf = model_conf.ipa
        self._ipa_conf = ipa_conf
        self.diffuser = diffuser

        # Translations are scaled before the trunk and unscaled afterwards.
        self.scale_pos = lambda x: x * ipa_conf.coordinate_scaling
        self.scale_rigids = lambda x: x.apply_trans_fn(self.scale_pos)

        self.unscale_pos = lambda x: x / ipa_conf.coordinate_scaling
        self.unscale_rigids = lambda x: x.apply_trans_fn(self.unscale_pos)
        self.trunk = nn.ModuleDict()

        for b in range(ipa_conf.num_blocks):
            self.trunk[f'ipa_{b}'] = InvariantPointAttention(ipa_conf)
            self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(ipa_conf.c_s)
            self.trunk[f'skip_embed_{b}'] = Linear(
                self._model_conf.node_embed_size,
                self._ipa_conf.c_skip,
                init="final"
            )
            tfmr_in = ipa_conf.c_s + self._ipa_conf.c_skip
            tfmr_layer = torch.nn.TransformerEncoderLayer(
                d_model=tfmr_in,
                nhead=ipa_conf.seq_tfmr_num_heads,
                dim_feedforward=tfmr_in,
                batch_first=True,
                dropout=0.0,
                norm_first=False
            )
            self.trunk[f'seq_tfmr_{b}'] = torch.nn.TransformerEncoder(
                tfmr_layer, ipa_conf.seq_tfmr_num_layers)
            self.trunk[f'post_tfmr_{b}'] = Linear(
                tfmr_in, ipa_conf.c_s, init="final")
            self.trunk[f'node_transition_{b}'] = StructureModuleTransition(
                c=ipa_conf.c_s)
            # BackboneUpdate requires use_rot_updates; omitting it raised a
            # TypeError. forward() feeds the update to compose_q_update_vec,
            # so the 6-dim rotation + translation update is requested.
            self.trunk[f'bb_update_{b}'] = BackboneUpdate(
                ipa_conf.c_s, use_rot_updates=True)

            if b < ipa_conf.num_blocks-1:
                # No edge update on the last block.
                edge_in = self._model_conf.edge_embed_size
                self.trunk[f'edge_transition_{b}'] = EdgeTransition(
                    node_embed_size=ipa_conf.c_s,
                    edge_embed_in=edge_in,
                    edge_embed_out=self._model_conf.edge_embed_size,
                )

        self.torsion_pred = TorsionAngles(ipa_conf.c_s, 1)

    def forward(self, init_node_embed, edge_embed, input_feats):
        """Run the trunk and compute rotation / translation scores.

        Args:
            init_node_embed: Initial node embedding; multiplied by the
                residue mask and also fed to every block's skip embedding.
            edge_embed: Edge embedding, updated by all blocks but the last.
            input_feats: Feature dict. Keys read here are ``'res_mask'``,
                ``'fixed_mask'``, ``'rigids_t'`` and ``'t'``.

        Returns:
            Dict with keys ``'psi'``, ``'rot_score'``, ``'trans_score'`` and
            ``'final_rigids'``.
        """
        node_mask = input_feats['res_mask'].type(torch.float32)
        # Only residues that are present and not fixed receive frame updates.
        diffuse_mask = (1 - input_feats['fixed_mask'].type(torch.float32)) * node_mask
        edge_mask = node_mask[..., None] * node_mask[..., None, :]
        init_frames = input_feats['rigids_t'].type(torch.float32)

        curr_rigids = Rigid.from_tensor_7(torch.clone(init_frames))
        init_rigids = Rigid.from_tensor_7(init_frames)

        # A float key-padding mask is ADDED to the attention logits, so the
        # former float `1 - node_mask` did not mask anything. A boolean mask
        # (True = ignore this key) excludes padded residues as intended.
        padding_mask = (1 - node_mask).to(torch.bool)

        # Main trunk
        curr_rigids = self.scale_rigids(curr_rigids)
        init_node_embed = init_node_embed * node_mask[..., None]
        node_embed = init_node_embed * node_mask[..., None]
        for b in range(self._ipa_conf.num_blocks):
            ipa_embed = self.trunk[f'ipa_{b}'](
                node_embed,
                edge_embed,
                curr_rigids,
                node_mask)
            ipa_embed *= node_mask[..., None]
            node_embed = self.trunk[f'ipa_ln_{b}'](node_embed + ipa_embed)
            seq_tfmr_in = torch.cat([
                node_embed, self.trunk[f'skip_embed_{b}'](init_node_embed)
            ], dim=-1)
            seq_tfmr_out = self.trunk[f'seq_tfmr_{b}'](
                seq_tfmr_in, src_key_padding_mask=padding_mask)
            node_embed = node_embed + self.trunk[f'post_tfmr_{b}'](seq_tfmr_out)
            node_embed = self.trunk[f'node_transition_{b}'](node_embed)
            node_embed = node_embed * node_mask[..., None]
            rigid_update = self.trunk[f'bb_update_{b}'](
                node_embed * diffuse_mask[..., None])
            curr_rigids = curr_rigids.compose_q_update_vec(
                rigid_update, diffuse_mask[..., None])

            if b < self._ipa_conf.num_blocks-1:
                edge_embed = self.trunk[f'edge_transition_{b}'](
                    node_embed, edge_embed)
                edge_embed *= edge_mask[..., None]

        # The rotation score is taken before unscaling; unscale_rigids only
        # transforms the translations.
        rot_score = self.diffuser.calc_rot_score(
            init_rigids.get_rots(),
            curr_rigids.get_rots(),
            input_feats['t']
        )
        rot_score = rot_score * node_mask[..., None]

        curr_rigids = self.unscale_rigids(curr_rigids)
        trans_score = self.diffuser.calc_trans_score(
            init_rigids.get_trans(),
            curr_rigids.get_trans(),
            input_feats['t'][:, None, None],
            use_torch=True,
        )
        trans_score = trans_score * node_mask[..., None]
        # torsion_pred returns (unnormalized, normalized); only the
        # normalized prediction is exposed.
        _, psi_pred = self.torsion_pred(node_embed)
        model_out = {
            'psi': psi_pred,
            'rot_score': rot_score,
            'trans_score': trans_score,
            'final_rigids': curr_rigids,
        }
        return model_out
def calc_distogram(pos, min_bin, max_bin, num_bins):
    """One-hot encode pairwise distances into distance bins.

    Args:
        pos: Coordinates of shape (B, L, 3).
        min_bin: Lower edge of the first bin.
        max_bin: Lower edge of the last bin.
        num_bins: Number of bins.

    Returns:
        Tensor of shape (B, L, L, num_bins) with the dtype of ``pos``. An
        entry is 1 when the pair distance lies strictly between a bin's lower
        and upper edge; the last bin's upper edge is 1e8. Distances that sit
        exactly on an edge, or are not above ``min_bin``, fall in no bin.
    """
    displacements = pos[:, :, None, :] - pos[:, None, :, :]
    pair_dists = torch.linalg.norm(displacements, dim=-1)[..., None]  # (B,L,L,1)

    lower_edges = torch.linspace(min_bin, max_bin, num_bins, device=pos.device)
    open_ended = lower_edges.new_tensor([1e8])
    upper_edges = torch.cat([lower_edges[1:], open_ended], dim=-1)

    in_bin = torch.logical_and(pair_dists > lower_edges, pair_dists < upper_edges)
    return in_bin.to(pos.dtype)  # (B,L,L,num_bins)
def get_time_embedding(timesteps, embedding_dim, max_positions=2000):
    """Build sinusoidal embeddings for a 1-D tensor of timesteps.

    Code from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py

    Args:
        timesteps: 1-D tensor of timesteps; multiplied by ``max_positions``
            before the sinusoids are evaluated.
        embedding_dim: Size of the embedding to create. Odd sizes are
            zero-padded by one column.
        max_positions: Scale applied to ``timesteps``; also sets the range of
            the geometric frequency progression.

    Returns:
        Tensor of shape (len(timesteps), embedding_dim) holding the sine half
        followed by the cosine half.

    Raises:
        AssertionError: If ``timesteps`` is not 1-D.
    """
    assert len(timesteps.shape) == 1
    timesteps = timesteps * max_positions
    half_dim = embedding_dim // 2
    # With a single frequency (embedding_dim of 2 or 3) `half_dim - 1` is 0
    # and the division used to raise ZeroDivisionError. The only exponent is
    # multiplied by index 0 anyway, so clamping the divisor changes nothing
    # for larger sizes and makes the small ones work.
    log_step = math.log(max_positions) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        * -log_step
    )
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
flat_losses = batch_loss.flatten() + flat_t = batch_t.flatten() + bin_edges = np.linspace(0.0, 1.0 + 1e-3, num_bins+1) + bin_idx = np.sum(bin_edges[:, None] <= flat_t[None, :], axis=0) - 1 + t_binned_loss = np.bincount(bin_idx, weights=flat_losses) + t_binned_n = np.bincount(bin_idx) + stratified_losses = {} + if loss_name is None: + loss_name = 'loss' + for t_bin in np.unique(bin_idx).tolist(): + bin_start = bin_edges[t_bin] + bin_end = bin_edges[t_bin+1] + t_range = f'{loss_name} t=[{bin_start:.2f},{bin_end:.2f})' + range_loss = t_binned_loss[t_bin] / t_binned_n[t_bin] + stratified_losses[t_range] = range_loss + return stratified_losses diff --git a/openfold/__pycache__/config.cpython-310.pyc b/openfold/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7ec659827725a7d5c55f0234fd9d83286346e55 Binary files /dev/null and b/openfold/__pycache__/config.cpython-310.pyc differ diff --git a/openfold/config.py b/openfold/config.py new file mode 100644 index 0000000000000000000000000000000000000000..b73acb91d367c07f3c99cf7400883041baf67a7a --- /dev/null +++ b/openfold/config.py @@ -0,0 +1,4 @@ +NUM_RES = "num residues placeholder" +NUM_MSA_SEQ = "msa placeholder" +NUM_EXTRA_SEQ = "extra msa placeholder" +NUM_TEMPLATES = "num templates placeholder" diff --git a/openfold/data/__init__.py b/openfold/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/openfold/data/__pycache__/__init__.cpython-310.pyc b/openfold/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bd86eeba1f6edf1a55a3c90c7dfcb997482dd47 Binary files /dev/null and b/openfold/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/openfold/data/__pycache__/data_transforms.cpython-310.pyc b/openfold/data/__pycache__/data_transforms.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3cc04441a75896d85cf8dadb99a4eb4ec8ca9cd4 Binary files /dev/null and b/openfold/data/__pycache__/data_transforms.cpython-310.pyc differ diff --git a/openfold/data/data_modules.py b/openfold/data/data_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..fec88f02607b8366552658583213bdbdb49db7c4 --- /dev/null +++ b/openfold/data/data_modules.py @@ -0,0 +1,693 @@ +import copy +from functools import partial +import json +import logging +import os +import pickle +from typing import Optional, Sequence, List, Any + +import ml_collections as mlc +import numpy as np +import pytorch_lightning as pl +import torch +from torch.utils.data import RandomSampler + +from openfold.data import ( + data_pipeline, + feature_pipeline, + mmcif_parsing, + templates, +) +from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap + + +class OpenFoldSingleDataset(torch.utils.data.Dataset): + def __init__(self, + data_dir: str, + alignment_dir: str, + template_mmcif_dir: str, + max_template_date: str, + config: mlc.ConfigDict, + kalign_binary_path: str = '/usr/bin/kalign', + max_template_hits: int = 4, + obsolete_pdbs_file_path: Optional[str] = None, + template_release_dates_cache_path: Optional[str] = None, + shuffle_top_k_prefiltered: Optional[int] = None, + treat_pdb_as_distillation: bool = True, + mapping_path: Optional[str] = None, + mode: str = "train", + _output_raw: bool = False, + _alignment_index: Optional[Any] = None + ): + """ + Args: + data_dir: + A path to a directory containing mmCIF files (in train + mode) or FASTA files (in inference mode). + alignment_dir: + A path to a directory containing only data in the format + output by an AlignmentRunner + (defined in openfold.features.alignment_runner). + I.e. a directory of directories named {PDB_ID}_{CHAIN_ID} + or simply {PDB_ID}, each containing .a3m, .sto, and .hhr + files. + template_mmcif_dir: + Path to a directory containing template mmCIF files. 
+ config: + A dataset config object. See openfold.config + kalign_binary_path: + Path to kalign binary. + max_template_hits: + An upper bound on how many templates are considered. During + training, the templates ultimately used are subsampled + from this total quantity. + template_release_dates_cache_path: + Path to the output of scripts/generate_mmcif_cache. + obsolete_pdbs_file_path: + Path to the file containing replacements for obsolete PDBs. + shuffle_top_k_prefiltered: + Whether to uniformly shuffle the top k template hits before + parsing max_template_hits of them. Can be used to + approximate DeepMind's training-time template subsampling + scheme much more performantly. + treat_pdb_as_distillation: + Whether to assume that .pdb files in the data_dir are from + the self-distillation set (and should be subjected to + special distillation set preprocessing steps). + mode: + "train", "val", or "predict" + """ + super(OpenFoldSingleDataset, self).__init__() + self.data_dir = data_dir + self.alignment_dir = alignment_dir + self.config = config + self.treat_pdb_as_distillation = treat_pdb_as_distillation + self.mode = mode + self._output_raw = _output_raw + self._alignment_index = _alignment_index + + valid_modes = ["train", "eval", "predict"] + if(mode not in valid_modes): + raise ValueError(f'mode must be one of {valid_modes}') + + if(template_release_dates_cache_path is None): + logging.warning( + "Template release dates cache does not exist. 
Remember to run " + "scripts/generate_mmcif_cache.py before running OpenFold" + ) + + if(_alignment_index is not None): + self._chain_ids = list(_alignment_index.keys()) + elif(mapping_path is None): + self._chain_ids = list(os.listdir(alignment_dir)) + else: + with open(mapping_path, "r") as f: + self._chain_ids = [l.strip() for l in f.readlines()] + + self._chain_id_to_idx_dict = { + chain: i for i, chain in enumerate(self._chain_ids) + } + + template_featurizer = templates.TemplateHitFeaturizer( + mmcif_dir=template_mmcif_dir, + max_template_date=max_template_date, + max_hits=max_template_hits, + kalign_binary_path=kalign_binary_path, + release_dates_path=template_release_dates_cache_path, + obsolete_pdbs_path=obsolete_pdbs_file_path, + _shuffle_top_k_prefiltered=shuffle_top_k_prefiltered, + ) + + self.data_pipeline = data_pipeline.DataPipeline( + template_featurizer=template_featurizer, + ) + + if(not self._output_raw): + self.feature_pipeline = feature_pipeline.FeaturePipeline(config) + + def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index): + with open(path, 'r') as f: + mmcif_string = f.read() + + mmcif_object = mmcif_parsing.parse( + file_id=file_id, mmcif_string=mmcif_string + ) + + # Crash if an error is encountered. Any parsing errors should have + # been dealt with at the alignment stage. 
+ if(mmcif_object.mmcif_object is None): + raise list(mmcif_object.errors.values())[0] + + mmcif_object = mmcif_object.mmcif_object + + data = self.data_pipeline.process_mmcif( + mmcif=mmcif_object, + alignment_dir=alignment_dir, + chain_id=chain_id, + _alignment_index=_alignment_index + ) + + return data + + def chain_id_to_idx(self, chain_id): + return self._chain_id_to_idx_dict[chain_id] + + def idx_to_chain_id(self, idx): + return self._chain_ids[idx] + + def __getitem__(self, idx): + name = self.idx_to_chain_id(idx) + alignment_dir = os.path.join(self.alignment_dir, name) + + _alignment_index = None + if(self._alignment_index is not None): + alignment_dir = self.alignment_dir + _alignment_index = self._alignment_index[name] + + if(self.mode == 'train' or self.mode == 'eval'): + spl = name.rsplit('_', 1) + if(len(spl) == 2): + file_id, chain_id = spl + else: + file_id, = spl + chain_id = None + + path = os.path.join(self.data_dir, file_id) + if(os.path.exists(path + ".cif")): + data = self._parse_mmcif( + path + ".cif", file_id, chain_id, alignment_dir, _alignment_index, + ) + elif(os.path.exists(path + ".core")): + data = self.data_pipeline.process_core( + path + ".core", alignment_dir, _alignment_index, + ) + elif(os.path.exists(path + ".pdb")): + data = self.data_pipeline.process_pdb( + pdb_path=path + ".pdb", + alignment_dir=alignment_dir, + is_distillation=self.treat_pdb_as_distillation, + chain_id=chain_id, + _alignment_index=_alignment_index, + ) + else: + raise ValueError("Invalid file type") + else: + path = os.path.join(name, name + ".fasta") + data = self.data_pipeline.process_fasta( + fasta_path=path, + alignment_dir=alignment_dir, + _alignment_index=_alignment_index, + ) + + if(self._output_raw): + return data + + feats = self.feature_pipeline.process_features( + data, self.mode + ) + + return feats + + def __len__(self): + return len(self._chain_ids) + + +def deterministic_train_filter( + chain_data_cache_entry: Any, + max_resolution: float = 9., 
+ max_single_aa_prop: float = 0.8, +) -> bool: + # Hard filters + resolution = chain_data_cache_entry.get("resolution", None) + if(resolution is not None and resolution > max_resolution): + return False + + seq = chain_data_cache_entry["seq"] + counts = {} + for aa in seq: + counts.setdefault(aa, 0) + counts[aa] += 1 + largest_aa_count = max(counts.values()) + largest_single_aa_prop = largest_aa_count / len(seq) + if(largest_single_aa_prop > max_single_aa_prop): + return False + + return True + + +def get_stochastic_train_filter_prob( + chain_data_cache_entry: Any, +) -> List[float]: + # Stochastic filters + probabilities = [] + + cluster_size = chain_data_cache_entry.get("cluster_size", None) + if(cluster_size is not None and cluster_size > 0): + probabilities.append(1 / cluster_size) + + chain_length = len(chain_data_cache_entry["seq"]) + probabilities.append((1 / 512) * (max(min(chain_length, 512), 256))) + + # Risk of underflow here? + out = 1 + for p in probabilities: + out *= p + + return out + + +class OpenFoldDataset(torch.utils.data.Dataset): + """ + Implements the stochastic filters applied during AlphaFold's training. + Because samples are selected from constituent datasets randomly, the + length of an OpenFoldFilteredDataset is arbitrary. Samples are selected + and filtered once at initialization. + """ + def __init__(self, + datasets: Sequence[OpenFoldSingleDataset], + probabilities: Sequence[int], + epoch_len: int, + chain_data_cache_paths: List[str], + generator: torch.Generator = None, + _roll_at_init: bool = True, + ): + self.datasets = datasets + self.probabilities = probabilities + self.epoch_len = epoch_len + self.generator = generator + + self.chain_data_caches = [] + for path in chain_data_cache_paths: + with open(path, "r") as fp: + self.chain_data_caches.append(json.load(fp)) + + def looped_shuffled_dataset_idx(dataset_len): + while True: + # Uniformly shuffle each dataset's indices + weights = [1. 
for _ in range(dataset_len)] + shuf = torch.multinomial( + torch.tensor(weights), + num_samples=dataset_len, + replacement=False, + generator=self.generator, + ) + for idx in shuf: + yield idx + + def looped_samples(dataset_idx): + max_cache_len = int(epoch_len * probabilities[dataset_idx]) + dataset = self.datasets[dataset_idx] + idx_iter = looped_shuffled_dataset_idx(len(dataset)) + chain_data_cache = self.chain_data_caches[dataset_idx] + while True: + weights = [] + idx = [] + for _ in range(max_cache_len): + candidate_idx = next(idx_iter) + chain_id = dataset.idx_to_chain_id(candidate_idx) + chain_data_cache_entry = chain_data_cache[chain_id] + if(not deterministic_train_filter(chain_data_cache_entry)): + continue + + p = get_stochastic_train_filter_prob( + chain_data_cache_entry, + ) + weights.append([1. - p, p]) + idx.append(candidate_idx) + + samples = torch.multinomial( + torch.tensor(weights), + num_samples=1, + generator=self.generator, + ) + samples = samples.squeeze() + + cache = [i for i, s in zip(idx, samples) if s] + + for datapoint_idx in cache: + yield datapoint_idx + + self._samples = [looped_samples(i) for i in range(len(self.datasets))] + + if(_roll_at_init): + self.reroll() + + def __getitem__(self, idx): + dataset_idx, datapoint_idx = self.datapoints[idx] + return self.datasets[dataset_idx][datapoint_idx] + + def __len__(self): + return self.epoch_len + + def reroll(self): + dataset_choices = torch.multinomial( + torch.tensor(self.probabilities), + num_samples=self.epoch_len, + replacement=True, + generator=self.generator, + ) + + self.datapoints = [] + for dataset_idx in dataset_choices: + samples = self._samples[dataset_idx] + datapoint_idx = next(samples) + self.datapoints.append((dataset_idx, datapoint_idx)) + + +class OpenFoldBatchCollator: + def __init__(self, config, stage="train"): + self.stage = stage + self.feature_pipeline = feature_pipeline.FeaturePipeline(config) + + def __call__(self, raw_prots): + processed_prots = [] + for prot 
in raw_prots: + features = self.feature_pipeline.process_features( + prot, self.stage + ) + processed_prots.append(features) + + stack_fn = partial(torch.stack, dim=0) + return dict_multimap(stack_fn, processed_prots) + + +class OpenFoldDataLoader(torch.utils.data.DataLoader): + def __init__(self, *args, config, stage="train", generator=None, **kwargs): + super().__init__(*args, **kwargs) + self.config = config + self.stage = stage + + if(generator is None): + generator = torch.Generator() + + self.generator = generator + self._prep_batch_properties_probs() + + def _prep_batch_properties_probs(self): + keyed_probs = [] + stage_cfg = self.config[self.stage] + + max_iters = self.config.common.max_recycling_iters + if(stage_cfg.supervised): + clamp_prob = self.config.supervised.clamp_prob + keyed_probs.append( + ("use_clamped_fape", [1 - clamp_prob, clamp_prob]) + ) + + if(stage_cfg.uniform_recycling): + recycling_probs = [ + 1. / (max_iters + 1) for _ in range(max_iters + 1) + ] + else: + recycling_probs = [ + 0. for _ in range(max_iters + 1) + ] + recycling_probs[-1] = 1. + + keyed_probs.append( + ("no_recycling_iters", recycling_probs) + ) + + keys, probs = zip(*keyed_probs) + max_len = max([len(p) for p in probs]) + padding = [[0.] 
* (max_len - len(p)) for p in probs] + + self.prop_keys = keys + self.prop_probs_tensor = torch.tensor( + [p + pad for p, pad in zip(probs, padding)], + dtype=torch.float32, + ) + + def _add_batch_properties(self, batch): + samples = torch.multinomial( + self.prop_probs_tensor, + num_samples=1, # 1 per row + replacement=True, + generator=self.generator + ) + + aatype = batch["aatype"] + batch_dims = aatype.shape[:-2] + recycling_dim = aatype.shape[-1] + no_recycling = recycling_dim + for i, key in enumerate(self.prop_keys): + sample = int(samples[i][0]) + sample_tensor = torch.tensor( + sample, + device=aatype.device, + requires_grad=False + ) + orig_shape = sample_tensor.shape + sample_tensor = sample_tensor.view( + (1,) * len(batch_dims) + sample_tensor.shape + (1,) + ) + sample_tensor = sample_tensor.expand( + batch_dims + orig_shape + (recycling_dim,) + ) + batch[key] = sample_tensor + + if(key == "no_recycling_iters"): + no_recycling = sample + + resample_recycling = lambda t: t[..., :no_recycling + 1] + batch = tensor_tree_map(resample_recycling, batch) + + return batch + + def __iter__(self): + it = super().__iter__() + + def _batch_prop_gen(iterator): + for batch in iterator: + yield self._add_batch_properties(batch) + + return _batch_prop_gen(it) + + +class OpenFoldDataModule(pl.LightningDataModule): + def __init__(self, + config: mlc.ConfigDict, + template_mmcif_dir: str, + max_template_date: str, + train_data_dir: Optional[str] = None, + train_alignment_dir: Optional[str] = None, + train_chain_data_cache_path: Optional[str] = None, + distillation_data_dir: Optional[str] = None, + distillation_alignment_dir: Optional[str] = None, + distillation_chain_data_cache_path: Optional[str] = None, + val_data_dir: Optional[str] = None, + val_alignment_dir: Optional[str] = None, + predict_data_dir: Optional[str] = None, + predict_alignment_dir: Optional[str] = None, + kalign_binary_path: str = '/usr/bin/kalign', + train_mapping_path: Optional[str] = None, + 
distillation_mapping_path: Optional[str] = None, + obsolete_pdbs_file_path: Optional[str] = None, + template_release_dates_cache_path: Optional[str] = None, + batch_seed: Optional[int] = None, + train_epoch_len: int = 50000, + _alignment_index_path: Optional[str] = None, + **kwargs + ): + super(OpenFoldDataModule, self).__init__() + + self.config = config + self.template_mmcif_dir = template_mmcif_dir + self.max_template_date = max_template_date + self.train_data_dir = train_data_dir + self.train_alignment_dir = train_alignment_dir + self.train_chain_data_cache_path = train_chain_data_cache_path + self.distillation_data_dir = distillation_data_dir + self.distillation_alignment_dir = distillation_alignment_dir + self.distillation_chain_data_cache_path = ( + distillation_chain_data_cache_path + ) + self.val_data_dir = val_data_dir + self.val_alignment_dir = val_alignment_dir + self.predict_data_dir = predict_data_dir + self.predict_alignment_dir = predict_alignment_dir + self.kalign_binary_path = kalign_binary_path + self.train_mapping_path = train_mapping_path + self.distillation_mapping_path = distillation_mapping_path + self.template_release_dates_cache_path = ( + template_release_dates_cache_path + ) + self.obsolete_pdbs_file_path = obsolete_pdbs_file_path + self.batch_seed = batch_seed + self.train_epoch_len = train_epoch_len + + if(self.train_data_dir is None and self.predict_data_dir is None): + raise ValueError( + 'At least one of train_data_dir or predict_data_dir must be ' + 'specified' + ) + + self.training_mode = self.train_data_dir is not None + + if(self.training_mode and train_alignment_dir is None): + raise ValueError( + 'In training mode, train_alignment_dir must be specified' + ) + elif(not self.training_mode and predict_alignment_dir is None): + raise ValueError( + 'In inference mode, predict_alignment_dir must be specified' + ) + elif(val_data_dir is not None and val_alignment_dir is None): + raise ValueError( + 'If val_data_dir is specified, 
val_alignment_dir must ' + 'be specified as well' + ) + + # An ad-hoc measure for our particular filesystem restrictions + self._alignment_index = None + if(_alignment_index_path is not None): + with open(_alignment_index_path, "r") as fp: + self._alignment_index = json.load(fp) + + def setup(self): + # Most of the arguments are the same for the three datasets + dataset_gen = partial(OpenFoldSingleDataset, + template_mmcif_dir=self.template_mmcif_dir, + max_template_date=self.max_template_date, + config=self.config, + kalign_binary_path=self.kalign_binary_path, + template_release_dates_cache_path= + self.template_release_dates_cache_path, + obsolete_pdbs_file_path= + self.obsolete_pdbs_file_path, + ) + + if(self.training_mode): + train_dataset = dataset_gen( + data_dir=self.train_data_dir, + alignment_dir=self.train_alignment_dir, + mapping_path=self.train_mapping_path, + max_template_hits=self.config.train.max_template_hits, + shuffle_top_k_prefiltered= + self.config.train.shuffle_top_k_prefiltered, + treat_pdb_as_distillation=False, + mode="train", + _output_raw=True, + _alignment_index=self._alignment_index, + ) + + distillation_dataset = None + if(self.distillation_data_dir is not None): + distillation_dataset = dataset_gen( + data_dir=self.distillation_data_dir, + alignment_dir=self.distillation_alignment_dir, + mapping_path=self.distillation_mapping_path, + max_template_hits=self.train.max_template_hits, + treat_pdb_as_distillation=True, + mode="train", + _output_raw=True, + ) + + d_prob = self.config.train.distillation_prob + + if(distillation_dataset is not None): + datasets = [train_dataset, distillation_dataset] + d_prob = self.config.train.distillation_prob + probabilities = [1 - d_prob, d_prob] + chain_data_cache_paths = [ + self.train_chain_data_cache_path, + self.distillation_chain_data_cache_path, + ] + else: + datasets = [train_dataset] + probabilities = [1.] 
+ chain_data_cache_paths = [ + self.train_chain_data_cache_path, + ] + + self.train_dataset = OpenFoldDataset( + datasets=datasets, + probabilities=probabilities, + epoch_len=self.train_epoch_len, + chain_data_cache_paths=chain_data_cache_paths, + _roll_at_init=False, + ) + + if(self.val_data_dir is not None): + self.eval_dataset = dataset_gen( + data_dir=self.val_data_dir, + alignment_dir=self.val_alignment_dir, + mapping_path=None, + max_template_hits=self.config.eval.max_template_hits, + mode="eval", + _output_raw=True, + ) + else: + self.eval_dataset = None + else: + self.predict_dataset = dataset_gen( + data_dir=self.predict_data_dir, + alignment_dir=self.predict_alignment_dir, + mapping_path=None, + max_template_hits=self.config.predict.max_template_hits, + mode="predict", + ) + + def _gen_dataloader(self, stage): + generator = torch.Generator() + if(self.batch_seed is not None): + generator = generator.manual_seed(self.batch_seed) + + dataset = None + if(stage == "train"): + dataset = self.train_dataset + + # Filter the dataset, if necessary + dataset.reroll() + elif(stage == "eval"): + dataset = self.eval_dataset + elif(stage == "predict"): + dataset = self.predict_dataset + else: + raise ValueError("Invalid stage") + + batch_collator = OpenFoldBatchCollator(self.config, stage) + + dl = OpenFoldDataLoader( + dataset, + config=self.config, + stage=stage, + generator=generator, + batch_size=self.config.data_module.data_loaders.batch_size, + num_workers=self.config.data_module.data_loaders.num_workers, + collate_fn=batch_collator, + ) + + return dl + + def train_dataloader(self): + return self._gen_dataloader("train") + + def val_dataloader(self): + if(self.eval_dataset is not None): + return self._gen_dataloader("eval") + return None + + def predict_dataloader(self): + return self._gen_dataloader("predict") + + +class DummyDataset(torch.utils.data.Dataset): + def __init__(self, batch_path): + with open(batch_path, "rb") as f: + self.batch = pickle.load(f) + + 
def __getitem__(self, idx): + return copy.deepcopy(self.batch) + + def __len__(self): + return 1000 + + +class DummyDataLoader(pl.LightningDataModule): + def __init__(self, batch_path): + super().__init__() + self.dataset = DummyDataset(batch_path) + + def train_dataloader(self): + return torch.utils.data.DataLoader(self.dataset) diff --git a/openfold/data/data_pipeline.py b/openfold/data/data_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..eb6419b7e1c48679005812cb522078f740978239 --- /dev/null +++ b/openfold/data/data_pipeline.py @@ -0,0 +1,678 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import datetime +from multiprocessing import cpu_count +from typing import Mapping, Optional, Sequence, Any + +import numpy as np + +from openfold.data import templates, parsers, mmcif_parsing +from openfold.data.tools import jackhmmer, hhblits, hhsearch +from openfold.data.tools.utils import to_date +from openfold.np import residue_constants, protein + + +FeatureDict = Mapping[str, np.ndarray] + +def empty_template_feats(n_res) -> FeatureDict: + return { + "template_aatype": np.zeros((0, n_res)).astype(np.int64), + "template_all_atom_positions": + np.zeros((0, n_res, 37, 3)).astype(np.float32), + "template_sum_probs": np.zeros((0, 1)).astype(np.float32), + "template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32), + } + + +def make_template_features( + input_sequence: str, + hits: Sequence[Any], + template_featurizer: Any, + query_pdb_code: Optional[str] = None, + query_release_date: Optional[str] = None, +) -> FeatureDict: + hits_cat = sum(hits.values(), []) + if(len(hits_cat) == 0 or template_featurizer is None): + template_features = empty_template_feats(len(input_sequence)) + else: + templates_result = template_featurizer.get_templates( + query_sequence=input_sequence, + query_pdb_code=query_pdb_code, + query_release_date=query_release_date, + hits=hits_cat, + ) + template_features = templates_result.features + + # The template featurizer doesn't format empty template features + # properly. This is a quick fix. 
+ if(template_features["template_aatype"].shape[0] == 0): + template_features = empty_template_feats(len(input_sequence)) + + return template_features + + +def make_sequence_features( + sequence: str, description: str, num_res: int +) -> FeatureDict: + """Construct a feature dict of sequence features.""" + features = {} + features["aatype"] = residue_constants.sequence_to_onehot( + sequence=sequence, + mapping=residue_constants.restype_order_with_x, + map_unknown_to_x=True, + ) + features["between_segment_residues"] = np.zeros((num_res,), dtype=int) + features["domain_name"] = np.array( + [description.encode("utf-8")], dtype=np.object_ + ) + features["residue_index"] = np.array(range(num_res), dtype=int) + features["seq_length"] = np.array([num_res] * num_res, dtype=int) + features["sequence"] = np.array( + [sequence.encode("utf-8")], dtype=np.object_ + ) + return features + + +def make_mmcif_features( + mmcif_object: mmcif_parsing.MmcifObject, chain_id: str +) -> FeatureDict: + input_sequence = mmcif_object.chain_to_seqres[chain_id] + description = "_".join([mmcif_object.file_id, chain_id]) + num_res = len(input_sequence) + + mmcif_feats = {} + + mmcif_feats.update( + make_sequence_features( + sequence=input_sequence, + description=description, + num_res=num_res, + ) + ) + + all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords( + mmcif_object=mmcif_object, chain_id=chain_id + ) + mmcif_feats["all_atom_positions"] = all_atom_positions + mmcif_feats["all_atom_mask"] = all_atom_mask + + mmcif_feats["resolution"] = np.array( + [mmcif_object.header["resolution"]], dtype=np.float32 + ) + + mmcif_feats["release_date"] = np.array( + [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_ + ) + + mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32) + + return mmcif_feats + + +def _aatype_to_str_sequence(aatype): + return ''.join([ + residue_constants.restypes_with_x[aatype[i]] + for i in range(len(aatype)) + ]) + + +def 
make_protein_features( + protein_object: protein.Protein, + description: str, + _is_distillation: bool = False, +) -> FeatureDict: + pdb_feats = {} + aatype = protein_object.aatype + sequence = _aatype_to_str_sequence(aatype) + pdb_feats.update( + make_sequence_features( + sequence=sequence, + description=description, + num_res=len(protein_object.aatype), + ) + ) + + all_atom_positions = protein_object.atom_positions + all_atom_mask = protein_object.atom_mask + + pdb_feats["all_atom_positions"] = all_atom_positions.astype(np.float32) + pdb_feats["all_atom_mask"] = all_atom_mask.astype(np.float32) + + pdb_feats["resolution"] = np.array([0.]).astype(np.float32) + pdb_feats["is_distillation"] = np.array( + 1. if _is_distillation else 0. + ).astype(np.float32) + + return pdb_feats + + +def make_pdb_features( + protein_object: protein.Protein, + description: str, + confidence_threshold: float = 0.5, + is_distillation: bool = True, +) -> FeatureDict: + pdb_feats = make_protein_features( + protein_object, description, _is_distillation=True + ) + + if(is_distillation): + high_confidence = protein_object.b_factors > confidence_threshold + high_confidence = np.any(high_confidence, axis=-1) + for i, confident in enumerate(high_confidence): + if(not confident): + pdb_feats["all_atom_mask"][i] = 0 + + return pdb_feats + + +def make_msa_features( + msas: Sequence[Sequence[str]], + deletion_matrices: Sequence[parsers.DeletionMatrix], +) -> FeatureDict: + """Constructs a feature dict of MSA features.""" + if not msas: + raise ValueError("At least one MSA must be provided.") + + int_msa = [] + deletion_matrix = [] + seen_sequences = set() + for msa_index, msa in enumerate(msas): + if not msa: + raise ValueError( + f"MSA {msa_index} must contain at least one sequence." 
+ ) + for sequence_index, sequence in enumerate(msa): + if sequence in seen_sequences: + continue + seen_sequences.add(sequence) + int_msa.append( + [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence] + ) + deletion_matrix.append(deletion_matrices[msa_index][sequence_index]) + + num_res = len(msas[0][0]) + num_alignments = len(int_msa) + features = {} + features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=int) + features["msa"] = np.array(int_msa, dtype=int) + features["num_alignments"] = np.array( + [num_alignments] * num_res, dtype=int + ) + return features + + +class AlignmentRunner: + """Runs alignment tools and saves the results""" + def __init__( + self, + jackhmmer_binary_path: Optional[str] = None, + hhblits_binary_path: Optional[str] = None, + hhsearch_binary_path: Optional[str] = None, + uniref90_database_path: Optional[str] = None, + mgnify_database_path: Optional[str] = None, + bfd_database_path: Optional[str] = None, + uniclust30_database_path: Optional[str] = None, + pdb70_database_path: Optional[str] = None, + use_small_bfd: Optional[bool] = None, + no_cpus: Optional[int] = None, + uniref_max_hits: int = 10000, + mgnify_max_hits: int = 5000, + ): + """ + Args: + jackhmmer_binary_path: + Path to jackhmmer binary + hhblits_binary_path: + Path to hhblits binary + hhsearch_binary_path: + Path to hhsearch binary + uniref90_database_path: + Path to uniref90 database. If provided, jackhmmer_binary_path + must also be provided + mgnify_database_path: + Path to mgnify database. If provided, jackhmmer_binary_path + must also be provided + bfd_database_path: + Path to BFD database. Depending on the value of use_small_bfd, + one of hhblits_binary_path or jackhmmer_binary_path must be + provided. + uniclust30_database_path: + Path to uniclust30. Searched alongside BFD if use_small_bfd is + false. + pdb70_database_path: + Path to pdb70 database. 
+ use_small_bfd: + Whether to search the BFD database alone with jackhmmer or + in conjunction with uniclust30 with hhblits. + no_cpus: + The number of CPUs available for alignment. By default, all + CPUs are used. + uniref_max_hits: + Max number of uniref hits + mgnify_max_hits: + Max number of mgnify hits + """ + db_map = { + "jackhmmer": { + "binary": jackhmmer_binary_path, + "dbs": [ + uniref90_database_path, + mgnify_database_path, + bfd_database_path if use_small_bfd else None, + ], + }, + "hhblits": { + "binary": hhblits_binary_path, + "dbs": [ + bfd_database_path if not use_small_bfd else None, + ], + }, + "hhsearch": { + "binary": hhsearch_binary_path, + "dbs": [ + pdb70_database_path, + ], + }, + } + + for name, dic in db_map.items(): + binary, dbs = dic["binary"], dic["dbs"] + if(binary is None and not all([x is None for x in dbs])): + raise ValueError( + f"{name} DBs provided but {name} binary is None" + ) + + if(not all([x is None for x in db_map["hhsearch"]["dbs"]]) + and uniref90_database_path is None): + raise ValueError( + """uniref90_database_path must be specified in order to perform + template search""" + ) + + self.uniref_max_hits = uniref_max_hits + self.mgnify_max_hits = mgnify_max_hits + self.use_small_bfd = use_small_bfd + + if(no_cpus is None): + no_cpus = cpu_count() + + self.jackhmmer_uniref90_runner = None + if(jackhmmer_binary_path is not None and + uniref90_database_path is not None + ): + self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=uniref90_database_path, + n_cpu=no_cpus, + ) + + self.jackhmmer_small_bfd_runner = None + self.hhblits_bfd_uniclust_runner = None + if(bfd_database_path is not None): + if use_small_bfd: + self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=bfd_database_path, + n_cpu=no_cpus, + ) + else: + dbs = [bfd_database_path] + if(uniclust30_database_path is not None): + 
dbs.append(uniclust30_database_path) + self.hhblits_bfd_uniclust_runner = hhblits.HHBlits( + binary_path=hhblits_binary_path, + databases=dbs, + n_cpu=no_cpus, + ) + + self.jackhmmer_mgnify_runner = None + if(mgnify_database_path is not None): + self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=mgnify_database_path, + n_cpu=no_cpus, + ) + + self.hhsearch_pdb70_runner = None + if(pdb70_database_path is not None): + self.hhsearch_pdb70_runner = hhsearch.HHSearch( + binary_path=hhsearch_binary_path, + databases=[pdb70_database_path], + n_cpu=no_cpus, + ) + + def run( + self, + fasta_path: str, + output_dir: str, + ): + """Runs alignment tools on a sequence""" + if(self.jackhmmer_uniref90_runner is not None): + jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query( + fasta_path + )[0] + uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( + jackhmmer_uniref90_result["sto"], + max_sequences=self.uniref_max_hits + ) + uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m") + with open(uniref90_out_path, "w") as f: + f.write(uniref90_msa_as_a3m) + + if(self.hhsearch_pdb70_runner is not None): + hhsearch_result = self.hhsearch_pdb70_runner.query( + uniref90_msa_as_a3m + ) + pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr") + with open(pdb70_out_path, "w") as f: + f.write(hhsearch_result) + + if(self.jackhmmer_mgnify_runner is not None): + jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query( + fasta_path + )[0] + mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m( + jackhmmer_mgnify_result["sto"], + max_sequences=self.mgnify_max_hits + ) + mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m") + with open(mgnify_out_path, "w") as f: + f.write(mgnify_msa_as_a3m) + + if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None): + jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query( + fasta_path + )[0] + bfd_out_path = os.path.join(output_dir, 
class DataPipeline:
    """Assembles input features from precomputed alignments and templates."""

    def __init__(
        self,
        template_featurizer: Optional[templates.TemplateHitFeaturizer],
    ):
        """
        Args:
            template_featurizer:
                Object forwarded to make_template_features for every
                processed input.
        """
        self.template_featurizer = template_featurizer

    def _parse_msa_data(
        self,
        alignment_dir: str,
        _alignment_index: Optional[Any] = None,
    ) -> Mapping[str, Any]:
        """
        Parses every .a3m / .sto alignment available for one target.

        Args:
            alignment_dir:
                Directory holding the alignment files, or the packed
                alignment database when an index is given.
            _alignment_index:
                Optional mapping with a "db" file name (relative to
                alignment_dir) and a "files" list of (name, start, size)
                byte ranges inside that file.
        Returns:
            A dict mapping file name to
            {"msa": ..., "deletion_matrix": ...}. Files with any other
            extension are ignored.
        """
        msa_data = {}

        if _alignment_index is not None:
            db_path = os.path.join(alignment_dir, _alignment_index["db"])
            # The context manager closes the database even if a parser raises.
            with open(db_path, "rb") as fp:

                def read_msa(start, size):
                    fp.seek(start)
                    return fp.read(size).decode("utf-8")

                for (name, start, size) in _alignment_index["files"]:
                    ext = os.path.splitext(name)[-1]

                    if ext == ".a3m":
                        msa, deletion_matrix = parsers.parse_a3m(
                            read_msa(start, size)
                        )
                    elif ext == ".sto":
                        msa, deletion_matrix, _ = parsers.parse_stockholm(
                            read_msa(start, size)
                        )
                    else:
                        continue

                    msa_data[name] = {
                        "msa": msa,
                        "deletion_matrix": deletion_matrix,
                    }
        else:
            for f in os.listdir(alignment_dir):
                path = os.path.join(alignment_dir, f)
                ext = os.path.splitext(f)[-1]

                if ext == ".a3m":
                    with open(path, "r") as fp:
                        msa, deletion_matrix = parsers.parse_a3m(fp.read())
                elif ext == ".sto":
                    with open(path, "r") as fp:
                        msa, deletion_matrix, _ = parsers.parse_stockholm(
                            fp.read()
                        )
                else:
                    continue

                msa_data[f] = {"msa": msa, "deletion_matrix": deletion_matrix}

        return msa_data

    def _parse_template_hits(
        self,
        alignment_dir: str,
        _alignment_index: Optional[Any] = None
    ) -> Mapping[str, Any]:
        """
        Parses every .hhr template-search result available for one target.

        Takes the same arguments as _parse_msa_data and returns a dict
        mapping file name to the value returned by parsers.parse_hhr.
        """
        all_hits = {}

        if _alignment_index is not None:
            db_path = os.path.join(alignment_dir, _alignment_index["db"])
            # The context manager closes the database even if parsing raises.
            with open(db_path, "rb") as fp:

                def read_template(start, size):
                    fp.seek(start)
                    return fp.read(size).decode("utf-8")

                for (name, start, size) in _alignment_index["files"]:
                    ext = os.path.splitext(name)[-1]

                    if ext == ".hhr":
                        all_hits[name] = parsers.parse_hhr(
                            read_template(start, size)
                        )
        else:
            for f in os.listdir(alignment_dir):
                path = os.path.join(alignment_dir, f)
                ext = os.path.splitext(f)[-1]

                if ext == ".hhr":
                    with open(path, "r") as fp:
                        all_hits[f] = parsers.parse_hhr(fp.read())

        return all_hits

    def _process_msa_feats(
        self,
        alignment_dir: str,
        input_sequence: Optional[str] = None,
        _alignment_index: Optional[Any] = None
    ) -> Mapping[str, Any]:
        """
        Builds MSA features, falling back to a single-sequence dummy MSA.

        Raises:
            ValueError: if no alignment was found and input_sequence is None.
        """
        msa_data = self._parse_msa_data(alignment_dir, _alignment_index)

        if len(msa_data) == 0:
            if input_sequence is None:
                raise ValueError(
                    """
                    If the alignment dir contains no MSAs, an input sequence
                    must be provided.
                    """
                )
            # One-row MSA made of the query itself, with no deletions.
            msa_data["dummy"] = {
                "msa": [input_sequence],
                "deletion_matrix": [[0 for _ in input_sequence]],
            }

        msas, deletion_matrices = zip(*[
            (v["msa"], v["deletion_matrix"]) for v in msa_data.values()
        ])

        return make_msa_features(
            msas=msas,
            deletion_matrices=deletion_matrices,
        )

    def process_fasta(
        self,
        fasta_path: str,
        alignment_dir: str,
        _alignment_index: Optional[Any] = None,
    ) -> FeatureDict:
        """
        Assembles features for a single sequence in a FASTA file.

        Raises:
            ValueError: if the file does not contain exactly one sequence.
        """
        with open(fasta_path) as f:
            fasta_str = f.read()
        input_seqs, input_descs = parsers.parse_fasta(fasta_str)
        if len(input_seqs) != 1:
            # Covers the empty file as well as the multi-sequence case.
            raise ValueError(
                f"Expected exactly one input sequence in {fasta_path}, "
                f"found {len(input_seqs)}."
            )
        input_sequence = input_seqs[0]
        input_description = input_descs[0]
        num_res = len(input_sequence)

        hits = self._parse_template_hits(alignment_dir, _alignment_index)
        template_features = make_template_features(
            input_sequence,
            hits,
            self.template_featurizer,
        )

        sequence_features = make_sequence_features(
            sequence=input_sequence,
            description=input_description,
            num_res=num_res,
        )

        msa_features = self._process_msa_feats(
            alignment_dir, input_sequence, _alignment_index
        )

        return {
            **sequence_features,
            **msa_features,
            **template_features
        }

    def process_mmcif(
        self,
        mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
        alignment_dir: str,
        chain_id: Optional[str] = None,
        _alignment_index: Optional[Any] = None,
    ) -> FeatureDict:
        """
        Assembles features for a specific chain in an mmCIF object.

        If chain_id is None, the first chain yielded by
        mmcif.structure.get_chains() is used.

        Raises:
            ValueError: if chain_id is None and the structure has no chains.
        """
        if chain_id is None:
            chains = mmcif.structure.get_chains()
            chain = next(chains, None)
            if chain is None:
                raise ValueError("No chains in mmCIF file")
            chain_id = chain.id

        mmcif_feats = make_mmcif_features(mmcif, chain_id)

        input_sequence = mmcif.chain_to_seqres[chain_id]
        hits = self._parse_template_hits(alignment_dir, _alignment_index)
        template_features = make_template_features(
            input_sequence,
            hits,
            self.template_featurizer,
            query_release_date=to_date(mmcif.header["release_date"])
        )

        msa_features = self._process_msa_feats(
            alignment_dir, input_sequence, _alignment_index
        )

        return {**mmcif_feats, **template_features, **msa_features}

    def process_pdb(
        self,
        pdb_path: str,
        alignment_dir: str,
        is_distillation: bool = True,
        chain_id: Optional[str] = None,
        _alignment_index: Optional[Any] = None,
    ) -> FeatureDict:
        """
        Assembles features for a protein in a PDB file.

        The description is the upper-cased file name without its extension.
        """
        with open(pdb_path, 'r') as f:
            pdb_str = f.read()

        protein_object = protein.from_pdb_string(pdb_str, chain_id)
        input_sequence = _aatype_to_str_sequence(protein_object.aatype)
        description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
        pdb_feats = make_pdb_features(
            protein_object,
            description,
            is_distillation
        )

        hits = self._parse_template_hits(alignment_dir, _alignment_index)
        template_features = make_template_features(
            input_sequence,
            hits,
            self.template_featurizer,
        )

        msa_features = self._process_msa_feats(
            alignment_dir, input_sequence, _alignment_index
        )

        return {**pdb_feats, **template_features, **msa_features}

    def process_core(
        self,
        core_path: str,
        alignment_dir: str,
        _alignment_index: Optional[Any] = None,
    ) -> FeatureDict:
        """
        Assembles features for a protein in a ProteinNet .core file.

        The description is the upper-cased file name without its extension.
        """
        with open(core_path, 'r') as f:
            core_str = f.read()

        protein_object = protein.from_proteinnet_string(core_str)
        input_sequence = _aatype_to_str_sequence(protein_object.aatype)
        description = os.path.splitext(os.path.basename(core_path))[0].upper()
        core_feats = make_protein_features(protein_object, description)

        hits = self._parse_template_hits(alignment_dir, _alignment_index)
        template_features = make_template_features(
            input_sequence,
            hits,
            self.template_featurizer,
        )

        # Forward the index so MSAs come from the same source as the
        # template hits, as in the other process_* methods.
        msa_features = self._process_msa_feats(
            alignment_dir, input_sequence, _alignment_index
        )

        return {**core_feats, **template_features, **msa_features}
def make_one_hot(x, num_classes):
    """
    One-hot encodes an integer tensor along a new trailing dimension.

    Args:
        x: Integer tensor of class indices in [0, num_classes).
        num_classes: Size of the new trailing dimension.
    Returns:
        A tensor of shape [*x.shape, num_classes] in the default floating
        dtype, on the same device as x.
    """
    # Allocate next to the indices: scatter_ needs both on one device.
    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
    # scatter_ only accepts int64 indices.
    x_one_hot.scatter_(-1, x.unsqueeze(-1).long(), 1)
    return x_one_hot
def correct_msa_restypes(protein):
    """
    Correct MSA restype to have the same order as rc.

    Remaps every entry of protein["msa"] through
    rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE and applies the matching column
    permutation to every feature whose key contains "profile".

    Args:
        protein: Feature dict; "msa" must be an integer tensor.
    Returns:
        The same dict, updated in place.
    """
    new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    msa = protein["msa"]
    lookup = torch.tensor(new_order_list, dtype=msa.dtype, device=msa.device)
    # Element-wise table lookup: each residue id becomes lookup[id].
    protein["msa"] = lookup[msa.long()]

    # perm_matrix[i, new_order_list[i]] = 1 moves profile column i to its
    # new position under a right-multiplication.
    perm_matrix = torch.zeros((22, 22), dtype=torch.float32)
    perm_matrix[
        torch.arange(len(new_order_list)), torch.tensor(new_order_list)
    ] = 1.0

    for k in protein:
        if "profile" in k:
            num_dim = protein[k].shape[-1]
            assert num_dim in [
                20,
                21,
                22,
            ], "num_dim for %s out of expected range: %s" % (k, num_dim)
            perm = perm_matrix[:num_dim, :num_dim].to(
                dtype=protein[k].dtype, device=protein[k].device
            )
            # matmul contracts the last axis for any rank; torch.dot is
            # restricted to 1-D operands.
            protein[k] = torch.matmul(protein[k], perm)

    return protein
@curry1
def sample_msa(protein, max_seq, keep_extra, seed=None):
    """
    Sample MSA randomly, remaining sequences are stored as `extra_*`.

    Row 0 always stays first; the other rows are shuffled, the first
    max_seq rows of that order are kept and, if keep_extra is set, the
    rest are stored under "extra_" + feature name.

    Args:
        protein: Feature dict containing "msa".
        max_seq: Maximum number of rows to keep.
        keep_extra: Whether to store the unselected rows.
        seed: Optional seed for a reproducible shuffle. With None the
            global torch RNG is used.
    Returns:
        The same dict, updated in place.
    """
    num_seq = protein["msa"].shape[0]
    device = protein["msa"].device

    # A freshly constructed torch.Generator always starts from the same
    # default seed, so only create one when the caller asks for a seed.
    g = None
    if seed is not None:
        g = torch.Generator(device=device)
        g.manual_seed(seed)

    shuffled = torch.randperm(num_seq - 1, generator=g, device=device) + 1
    first = torch.zeros(1, dtype=shuffled.dtype, device=device)
    index_order = torch.cat((first, shuffled), dim=0)
    num_sel = min(max_seq, num_seq)
    sel_seq, not_sel_seq = torch.split(
        index_order, [num_sel, num_seq - num_sel]
    )

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            if keep_extra:
                protein["extra_" + k] = torch.index_select(
                    protein[k], 0, not_sel_seq
                )
            protein[k] = torch.index_select(protein[k], 0, sel_seq)

    return protein
# Not used in inference
@curry1
def block_delete_msa(protein, config):
    """
    Deletes random contiguous blocks of MSA rows, always keeping row 0.

    Args:
        protein: Feature dict containing "msa".
        config: Object with attributes
            msa_fraction_per_block (block length as a fraction of the
            number of rows), num_blocks, and randomize_num_blocks (draw
            the block count uniformly from 0..num_blocks).
    Returns:
        The same dict with every present MSA_FEATURE_NAMES entry reduced
        to the surviving rows.
    """
    num_seq = protein["msa"].shape[0]
    if num_seq == 0:
        return protein
    device = protein["msa"].device

    # int() truncates, which equals floor for non-negative values.
    block_num_seq = int(num_seq * config.msa_fraction_per_block)

    if config.randomize_num_blocks:
        # randint's upper bound is exclusive.
        nb = int(torch.randint(0, config.num_blocks + 1, ()).item())
    else:
        nb = config.num_blocks

    del_block_starts = torch.randint(0, num_seq, (nb,), device=device)
    del_blocks = del_block_starts[:, None] + torch.arange(
        block_num_seq, device=device
    )
    del_blocks = torch.clamp(del_blocks, 0, num_seq - 1)

    keep_mask = torch.ones(num_seq, dtype=torch.bool, device=device)
    keep_mask[del_blocks.reshape(-1)] = False
    # Make sure we keep the original sequence
    keep_mask[0] = True
    keep_indices = torch.nonzero(keep_mask).squeeze(-1)

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            protein[k] = torch.index_select(protein[k], 0, keep_indices)

    return protein
def unsorted_segment_sum(data, segment_ids, num_segments):
    """
    Computes the sum along segments of a tensor. Similar to
    tf.unsorted_segment_sum, but only supports 1-D indices.

    :param data: A tensor whose segments are to be summed.
    :param segment_ids: The 1-D segment indices tensor, one entry per row
        of data, with values in [0, num_segments).
    :param num_segments: The number of segments.
    :return: A tensor of shape [num_segments, *data.shape[1:]] with the
        same dtype and device as data. Sums are accumulated in float32.
    """
    assert (
        len(segment_ids.shape) == 1 and
        segment_ids.shape[0] == data.shape[0]
    )
    # Broadcast the row ids over the trailing dimensions of data.
    # scatter_add_ needs an int64 index with the same shape as the source.
    index = segment_ids.to(torch.int64).view(
        segment_ids.shape[0], *((1,) * len(data.shape[1:]))
    )
    index = index.expand(data.shape)

    shape = [num_segments] + list(data.shape[1:])
    # Allocate on data's device so scatter_add_ sees matching devices.
    sums = torch.zeros(shape, dtype=torch.float32, device=data.device)
    sums.scatter_add_(0, index, data.float())
    return sums.type(data.dtype)
def shaped_categorical(probs, epsilon=1e-10):
    """
    Draws one categorical sample per distribution in probs.

    The last dimension of probs holds the (unnormalized) class weights.
    epsilon is added to every weight before sampling. The result has the
    shape of probs without its last dimension.
    """
    *batch_shape, num_classes = probs.shape
    flat_weights = torch.reshape(probs + epsilon, [-1, num_classes])
    sampler = torch.distributions.categorical.Categorical(flat_weights)
    draws = sampler.sample()
    return torch.reshape(draws, batch_shape)
@curry1
def make_masked_msa(protein, config, replace_fraction):
    """
    Create data for BERT on raw MSA.

    Each MSA position is selected with probability replace_fraction. A
    selected position is replaced by a draw from a mixture of a uniform
    amino-acid distribution, the HHblits profile, the original residue,
    and a [MASK] class that receives the remaining probability mass.
    The mixture weights come from config.uniform_prob,
    config.profile_prob and config.same_prob.

    Writes "bert_mask", "true_msa" and the corrupted "msa" into protein.
    """
    msa = protein["msa"]

    # Uniform over the 20 amino acids, zero on the two remaining classes.
    uniform_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)

    mixture = (
        config.uniform_prob * uniform_aa
        + config.profile_prob * protein["hhblits_profile"]
        + config.same_prob * make_one_hot(msa, 22)
    )

    # Put all remaining probability on [MASK] which is a new column:
    # pad the last dimension by one entry on the right.
    pad_spec = [0, 0] * len(mixture.shape)
    pad_spec[1] = 1
    mask_prob = (
        1.0 - config.profile_prob - config.same_prob - config.uniform_prob
    )
    assert mask_prob >= 0.0
    mixture = torch.nn.functional.pad(mixture, pad_spec, value=mask_prob)

    # Choose the positions first, then draw the replacements.
    selected = torch.rand(msa.shape) < replace_fraction
    replacements = shaped_categorical(mixture)

    # Mix real and masked MSA
    protein["bert_mask"] = selected.to(torch.float32)
    protein["true_msa"] = msa
    protein["msa"] = torch.where(selected, replacements, msa)

    return protein
+ if k == "extra_cluster_assignment": + continue + shape = list(v.shape) + schema = shape_schema[k] + msg = "Rank mismatch between shape and shape schema for" + assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}" + pad_size = [ + pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema) + ] + + padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)] + padding.reverse() + padding = list(itertools.chain(*padding)) + if padding: + protein[k] = torch.nn.functional.pad(v, padding) + protein[k] = torch.reshape(protein[k], pad_size) + + return protein + + +@curry1 +def make_msa_feat(protein): + """Create and concatenate MSA features.""" + # Whether there is a domain break. Always zero for chains, but keeping for + # compatibility with domain datasets. + has_break = torch.clip( + protein["between_segment_residues"].to(torch.float32), 0, 1 + ) + aatype_1hot = make_one_hot(protein["aatype"], 21) + + target_feat = [ + torch.unsqueeze(has_break, dim=-1), + aatype_1hot, # Everyone gets the original sequence. 
@curry1
def crop_templates(protein, max_templates):
    """
    Truncates every "template_*" feature to its first max_templates
    entries along the leading dimension. Updates protein in place and
    returns it.
    """
    template_keys = [key for key in protein if key.startswith("template_")]
    for key in template_keys:
        protein[key] = protein[key][:max_templates]
    return protein
def make_atom14_masks_np(batch):
    """
    NumPy wrapper around make_atom14_masks.

    Converts the np.ndarray leaves of batch to tensors, runs
    make_atom14_masks, and converts the tensors in the result back to
    NumPy arrays.
    """
    tensor_batch = tree_map(torch.tensor, batch, np.ndarray)
    masked = make_atom14_masks(tensor_batch)
    return tensor_tree_map(np.array, masked)
instead of 37).""" + residx_atom14_mask = protein["atom14_atom_exists"] + residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"] + + # Create a mask for known ground truth positions. + residx_atom14_gt_mask = residx_atom14_mask * batched_gather( + protein["all_atom_mask"], + residx_atom14_to_atom37, + dim=-1, + no_batch_dims=len(protein["all_atom_mask"].shape[:-1]), + ) + + # Gather the ground truth positions. + residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * ( + batched_gather( + protein["all_atom_positions"], + residx_atom14_to_atom37, + dim=-2, + no_batch_dims=len(protein["all_atom_positions"].shape[:-2]), + ) + ) + + protein["atom14_atom_exists"] = residx_atom14_mask + protein["atom14_gt_exists"] = residx_atom14_gt_mask + protein["atom14_gt_positions"] = residx_atom14_gt_positions + + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative ground truth coordinates where the naming is swapped + restype_3 = [rc.restype_1to3[res] for res in rc.restypes] + restype_3 += ["UNK"] + + # Matrices for renaming ambiguous atoms. 
+ all_matrices = { + res: torch.eye( + 14, + dtype=protein["all_atom_mask"].dtype, + device=protein["all_atom_mask"].device, + ) + for res in restype_3 + } + for resname, swap in rc.residue_atom_renaming_swaps.items(): + correspondences = torch.arange( + 14, device=protein["all_atom_mask"].device + ) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = rc.restype_name_to_atom14_names[resname].index( + source_atom_swap + ) + target_index = rc.restype_name_to_atom14_names[resname].index( + target_atom_swap + ) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14)) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1.0 + all_matrices[resname] = renaming_matrix + renaming_matrices = torch.stack( + [all_matrices[restype] for restype in restype_3] + ) + + # Pick the transformation matrices for the given residue sequence + # shape (num_res, 14, 14). + renaming_transform = renaming_matrices[protein["aatype"]] + + # Apply it to the ground truth positions. shape (num_res, 14, 3). + alternative_gt_positions = torch.einsum( + "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform + ) + protein["atom14_alt_gt_positions"] = alternative_gt_positions + + # Create the mask for the alternative ground truth (differs from the + # ground truth mask, if only one of the atoms in an ambiguous pair has a + # ground truth position). + alternative_gt_mask = torch.einsum( + "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform + ) + protein["atom14_alt_gt_exists"] = alternative_gt_mask + + # Create an ambiguous atoms mask. shape: (21, 14). 
+ restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14)) + for resname, swap in rc.residue_atom_renaming_swaps.items(): + for atom_name1, atom_name2 in swap.items(): + restype = rc.restype_order[rc.restype_3to1[resname]] + atom_idx1 = rc.restype_name_to_atom14_names[resname].index( + atom_name1 + ) + atom_idx2 = rc.restype_name_to_atom14_names[resname].index( + atom_name2 + ) + restype_atom14_is_ambiguous[restype, atom_idx1] = 1 + restype_atom14_is_ambiguous[restype, atom_idx2] = 1 + + # From this create an ambiguous_mask for the given sequence. + protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[ + protein["aatype"] + ] + + return protein + + +def atom37_to_frames(protein, eps=1e-8): + aatype = protein["aatype"] + all_atom_positions = protein["all_atom_positions"] + all_atom_mask = protein["all_atom_mask"] + + batch_dims = len(aatype.shape[:-1]) + + restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object) + restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"] + restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"] + + for restype, restype_letter in enumerate(rc.restypes): + resname = rc.restype_1to3[restype_letter] + for chi_idx in range(4): + if rc.chi_angles_mask[restype][chi_idx]: + names = rc.chi_angles_atoms[resname][chi_idx] + restype_rigidgroup_base_atom_names[ + restype, chi_idx + 4, : + ] = names[1:] + + restype_rigidgroup_mask = all_atom_mask.new_zeros( + (*aatype.shape[:-1], 21, 8), + ) + restype_rigidgroup_mask[..., 0] = 1 + restype_rigidgroup_mask[..., 3] = 1 + restype_rigidgroup_mask[..., :20, 4:] = all_atom_mask.new_tensor( + rc.chi_angles_mask + ) + + lookuptable = rc.atom_order.copy() + lookuptable[""] = 0 + lookup = np.vectorize(lambda x: lookuptable[x]) + restype_rigidgroup_base_atom37_idx = lookup( + restype_rigidgroup_base_atom_names, + ) + restype_rigidgroup_base_atom37_idx = aatype.new_tensor( + restype_rigidgroup_base_atom37_idx, + ) + 
restype_rigidgroup_base_atom37_idx = ( + restype_rigidgroup_base_atom37_idx.view( + *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape + ) + ) + + residx_rigidgroup_base_atom37_idx = batched_gather( + restype_rigidgroup_base_atom37_idx, + aatype, + dim=-3, + no_batch_dims=batch_dims, + ) + + base_atom_pos = batched_gather( + all_atom_positions, + residx_rigidgroup_base_atom37_idx, + dim=-2, + no_batch_dims=len(all_atom_positions.shape[:-2]), + ) + + gt_frames = Rigid.from_3_points( + p_neg_x_axis=base_atom_pos[..., 0, :], + origin=base_atom_pos[..., 1, :], + p_xy_plane=base_atom_pos[..., 2, :], + eps=eps, + ) + + group_exists = batched_gather( + restype_rigidgroup_mask, + aatype, + dim=-2, + no_batch_dims=batch_dims, + ) + + gt_atoms_exist = batched_gather( + all_atom_mask, + residx_rigidgroup_base_atom37_idx, + dim=-1, + no_batch_dims=len(all_atom_mask.shape[:-1]), + ) + gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists + + rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device) + rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1)) + rots[..., 0, 0, 0] = -1 + rots[..., 0, 2, 2] = -1 + rots = Rotation(rot_mats=rots) + + gt_frames = gt_frames.compose(Rigid(rots, None)) + + restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros( + *((1,) * batch_dims), 21, 8 + ) + restype_rigidgroup_rots = torch.eye( + 3, dtype=all_atom_mask.dtype, device=aatype.device + ) + restype_rigidgroup_rots = torch.tile( + restype_rigidgroup_rots, + (*((1,) * batch_dims), 21, 8, 1, 1), + ) + + for resname, _ in rc.residue_atom_renaming_swaps.items(): + restype = rc.restype_order[rc.restype_3to1[resname]] + chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1) + restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1 + restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1 + restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1 + + residx_rigidgroup_is_ambiguous = batched_gather( + restype_rigidgroup_is_ambiguous, + aatype, + 
dim=-2, + no_batch_dims=batch_dims, + ) + + residx_rigidgroup_ambiguity_rot = batched_gather( + restype_rigidgroup_rots, + aatype, + dim=-4, + no_batch_dims=batch_dims, + ) + + residx_rigidgroup_ambiguity_rot = Rotation( + rot_mats=residx_rigidgroup_ambiguity_rot + ) + alt_gt_frames = gt_frames.compose( + Rigid(residx_rigidgroup_ambiguity_rot, None) + ) + + gt_frames_tensor = gt_frames.to_tensor_4x4() + alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4() + + protein["rigidgroups_gt_frames"] = gt_frames_tensor + protein["rigidgroups_gt_exists"] = gt_exists + protein["rigidgroups_group_exists"] = group_exists + protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous + protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor + + return protein + + +def get_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in rc.restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in rc.restypes: + residue_name = rc.restype_1to3[residue_name] + residue_chi_angles = rc.chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append([rc.atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append( + [0, 0, 0, 0] + ) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return chi_atom_indices + + +@curry1 +def atom37_to_torsion_angles( + protein, + prefix="", +): + """ + Convert coordinates to torsion angles. + + This function is extremely sensitive to floating point imprecisions + and should be run with double precision whenever possible. 
+ + Args: + Dict containing: + * (prefix)aatype: + [*, N_res] residue indices + * (prefix)all_atom_positions: + [*, N_res, 37, 3] atom positions (in atom37 + format) + * (prefix)all_atom_mask: + [*, N_res, 37] atom position mask + Returns: + The same dictionary updated with the following features: + + "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2]) + Torsion angles + "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2]) + Alternate torsion angles (accounting for 180-degree symmetry) + "(prefix)torsion_angles_mask" ([*, N_res, 7]) + Torsion angles mask + """ + aatype = protein[prefix + "aatype"] + all_atom_positions = protein[prefix + "all_atom_positions"] + all_atom_mask = protein[prefix + "all_atom_mask"] + + aatype = torch.clamp(aatype, max=20) + + pad = all_atom_positions.new_zeros( + [*all_atom_positions.shape[:-3], 1, 37, 3] + ) + prev_all_atom_positions = torch.cat( + [pad, all_atom_positions[..., :-1, :, :]], dim=-3 + ) + + pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37]) + prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2) + + pre_omega_atom_pos = torch.cat( + [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]], + dim=-2, + ) + phi_atom_pos = torch.cat( + [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]], + dim=-2, + ) + psi_atom_pos = torch.cat( + [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]], + dim=-2, + ) + + pre_omega_mask = torch.prod( + prev_all_atom_mask[..., 1:3], dim=-1 + ) * torch.prod(all_atom_mask[..., :2], dim=-1) + phi_mask = prev_all_atom_mask[..., 2] * torch.prod( + all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype + ) + psi_mask = ( + torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) + * all_atom_mask[..., 4] + ) + + chi_atom_indices = torch.as_tensor( + get_chi_atom_indices(), device=aatype.device + ) + + atom_indices = chi_atom_indices[..., aatype, :, :] + chis_atom_pos = batched_gather( + 
all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2]) + ) + + chi_angles_mask = list(rc.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask) + + chis_mask = chi_angles_mask[aatype, :] + + chi_angle_atoms_mask = batched_gather( + all_atom_mask, + atom_indices, + dim=-1, + no_batch_dims=len(atom_indices.shape[:-2]), + ) + chi_angle_atoms_mask = torch.prod( + chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype + ) + chis_mask = chis_mask * chi_angle_atoms_mask + + torsions_atom_pos = torch.cat( + [ + pre_omega_atom_pos[..., None, :, :], + phi_atom_pos[..., None, :, :], + psi_atom_pos[..., None, :, :], + chis_atom_pos, + ], + dim=-3, + ) + + torsion_angles_mask = torch.cat( + [ + pre_omega_mask[..., None], + phi_mask[..., None], + psi_mask[..., None], + chis_mask, + ], + dim=-1, + ) + + torsion_frames = Rigid.from_3_points( + torsions_atom_pos[..., 1, :], + torsions_atom_pos[..., 2, :], + torsions_atom_pos[..., 0, :], + eps=1e-8, + ) + + fourth_atom_rel_pos = torsion_frames.invert().apply( + torsions_atom_pos[..., 3, :] + ) + + torsion_angles_sin_cos = torch.stack( + [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1 + ) + + denom = torch.sqrt( + torch.sum( + torch.square(torsion_angles_sin_cos), + dim=-1, + dtype=torsion_angles_sin_cos.dtype, + keepdims=True, + ) + + 1e-8 + ) + torsion_angles_sin_cos = torsion_angles_sin_cos / denom + + torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor( + [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0], + )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)] + + chi_is_ambiguous = torsion_angles_sin_cos.new_tensor( + rc.chi_pi_periodic, + )[aatype, ...] 
+ + mirror_torsion_angles = torch.cat( + [ + all_atom_mask.new_ones(*aatype.shape, 3), + 1.0 - 2.0 * chi_is_ambiguous, + ], + dim=-1, + ) + + alt_torsion_angles_sin_cos = ( + torsion_angles_sin_cos * mirror_torsion_angles[..., None] + ) + + protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos + protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos + protein[prefix + "torsion_angles_mask"] = torsion_angles_mask + + return protein + + +def get_backbone_frames(protein): + # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why. + protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][ + ..., 0, :, : + ] + protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0] + + return protein + + +def get_chi_angles(protein): + dtype = protein["all_atom_mask"].dtype + protein["chi_angles_sin_cos"] = ( + protein["torsion_angles_sin_cos"][..., 3:, :] + ).to(dtype) + protein["chi_mask"] = protein["torsion_angles_mask"][..., 3:].to(dtype) + + return protein + + +@curry1 +def random_crop_to_size( + protein, + crop_size, + max_templates, + shape_schema, + subsample_templates=False, + seed=None, +): + """Crop randomly to `crop_size`, or keep as is if shorter than that.""" + # We want each ensemble to be cropped the same way + g = torch.Generator(device=protein["seq_length"].device) + if seed is not None: + g.manual_seed(seed) + + seq_length = protein["seq_length"] + + if "template_mask" in protein: + num_templates = protein["template_mask"].shape[-1] + else: + num_templates = 0 + + # No need to subsample templates if there aren't any + subsample_templates = subsample_templates and num_templates + + num_res_crop_size = min(int(seq_length), crop_size) + + def _randint(lower, upper): + return int(torch.randint( + lower, + upper + 1, + (1,), + device=protein["seq_length"].device, + generator=g, + )[0]) + + if subsample_templates: + templates_crop_start = _randint(0, num_templates) + templates_select_indices = 
torch.randperm( + num_templates, device=protein["seq_length"].device, generator=g + ) + else: + templates_crop_start = 0 + + num_templates_crop_size = min( + num_templates - templates_crop_start, max_templates + ) + + n = seq_length - num_res_crop_size + if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.: + right_anchor = n + else: + x = _randint(0, n) + right_anchor = n - x + + num_res_crop_start = _randint(0, right_anchor) + + for k, v in protein.items(): + if k not in shape_schema or ( + "template" not in k and NUM_RES not in shape_schema[k] + ): + continue + + # randomly permute the templates before cropping them. + if k.startswith("template") and subsample_templates: + v = v[templates_select_indices] + + slices = [] + for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)): + is_num_res = dim_size == NUM_RES + if i == 0 and k.startswith("template"): + crop_size = num_templates_crop_size + crop_start = templates_crop_start + else: + crop_start = num_res_crop_start if is_num_res else 0 + crop_size = num_res_crop_size if is_num_res else dim + slices.append(slice(crop_start, crop_start + crop_size)) + protein[k] = v[slices] + + protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size) + + return protein diff --git a/openfold/data/errors.py b/openfold/data/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..7809d5ae29a736ccc7688d555a06a5301a3ab46f --- /dev/null +++ b/openfold/data/errors.py @@ -0,0 +1,22 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General-purpose errors used throughout the data pipeline""" +class Error(Exception): + """Base class for exceptions.""" + + +class MultipleChainsError(Error): + """An error indicating that multiple chains were found for a given ID.""" diff --git a/openfold/data/feature_pipeline.py b/openfold/data/feature_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..77b9a9f2570d01307b724ca73cf3b00016526383 --- /dev/null +++ b/openfold/data/feature_pipeline.py @@ -0,0 +1,115 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +from typing import Mapping, Tuple, List, Optional, Dict, Sequence + +import ml_collections +import numpy as np +import torch + +from openfold.data import input_pipeline + + +FeatureDict = Mapping[str, np.ndarray] +TensorDict = Dict[str, torch.Tensor] + + +def np_to_tensor_dict( + np_example: Mapping[str, np.ndarray], + features: Sequence[str], +) -> TensorDict: + """Creates dict of tensors from a dict of NumPy arrays. + + Args: + np_example: A dict of NumPy feature arrays. + features: A list of strings of feature names to be returned in the dataset. + + Returns: + A dictionary of features mapping feature names to features. Only the given + features are returned, all other ones are filtered out. + """ + tensor_dict = { + k: torch.tensor(v) for k, v in np_example.items() if k in features + } + return tensor_dict + + +def make_data_config( + config: ml_collections.ConfigDict, + mode: str, + num_res: int, +) -> Tuple[ml_collections.ConfigDict, List[str]]: + cfg = copy.deepcopy(config) + mode_cfg = cfg[mode] + with cfg.unlocked(): + if mode_cfg.crop_size is None: + mode_cfg.crop_size = num_res + + feature_names = cfg.common.unsupervised_features + + if cfg.common.use_templates: + feature_names += cfg.common.template_features + + if cfg[mode].supervised: + feature_names += cfg.supervised.supervised_features + + return cfg, feature_names + + +def np_example_to_features( + np_example: FeatureDict, + config: ml_collections.ConfigDict, + mode: str, +): + np_example = dict(np_example) + num_res = int(np_example["seq_length"][0]) + cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) + + if "deletion_matrix_int" in np_example: + np_example["deletion_matrix"] = np_example.pop( + "deletion_matrix_int" + ).astype(np.float32) + + tensor_dict = np_to_tensor_dict( + np_example=np_example, features=feature_names + ) + with torch.no_grad(): + features = input_pipeline.process_tensors_from_config( + tensor_dict, + cfg.common, + cfg[mode], + ) + + 
return {k: v for k, v in features.items()} + + +class FeaturePipeline: + def __init__( + self, + config: ml_collections.ConfigDict, + ): + self.config = config + + def process_features( + self, + raw_features: FeatureDict, + mode: str = "train", + ) -> FeatureDict: + return np_example_to_features( + np_example=raw_features, + config=self.config, + mode=mode, + ) diff --git a/openfold/data/input_pipeline.py b/openfold/data/input_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..b706ec3df523580d25d1603fc7f653c597be2902 --- /dev/null +++ b/openfold/data/input_pipeline.py @@ -0,0 +1,208 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial + +import torch + +from openfold.data import data_transforms + + +def nonensembled_transform_fns(common_cfg, mode_cfg): + """Input pipeline data transformers that are not ensembled.""" + transforms = [ + data_transforms.cast_to_64bit_ints, + data_transforms.correct_msa_restypes, + data_transforms.squeeze_features, + data_transforms.randomly_replace_msa_with_unknown(0.0), + data_transforms.make_seq_mask, + data_transforms.make_msa_mask, + data_transforms.make_hhblits_profile, + ] + if common_cfg.use_templates: + transforms.extend( + [ + data_transforms.fix_templates_aatype, + data_transforms.make_template_mask, + data_transforms.make_pseudo_beta("template_"), + ] + ) + if common_cfg.use_template_torsion_angles: + transforms.extend( + [ + data_transforms.atom37_to_torsion_angles("template_"), + ] + ) + + transforms.extend( + [ + data_transforms.make_atom14_masks, + ] + ) + + if mode_cfg.supervised: + transforms.extend( + [ + data_transforms.make_atom14_positions, + data_transforms.atom37_to_frames, + data_transforms.atom37_to_torsion_angles(""), + data_transforms.make_pseudo_beta(""), + data_transforms.get_backbone_frames, + data_transforms.get_chi_angles, + ] + ) + + return transforms + + +def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): + """Input pipeline data transformers that can be ensembled and averaged.""" + transforms = [] + + if "max_distillation_msa_clusters" in mode_cfg: + transforms.append( + data_transforms.sample_msa_distillation( + mode_cfg.max_distillation_msa_clusters + ) + ) + + if common_cfg.reduce_msa_clusters_by_max_templates: + pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates + else: + pad_msa_clusters = mode_cfg.max_msa_clusters + + max_msa_clusters = pad_msa_clusters + max_extra_msa = common_cfg.max_extra_msa + + msa_seed = None + if(not common_cfg.resample_msa_in_recycling): + msa_seed = ensemble_seed + + transforms.append( + data_transforms.sample_msa( + 
max_msa_clusters, + keep_extra=True, + seed=msa_seed, + ) + ) + + if "masked_msa" in common_cfg: + # Masked MSA should come *before* MSA clustering so that + # the clustering and full MSA profile do not leak information about + # the masked locations and secret corrupted locations. + transforms.append( + data_transforms.make_masked_msa( + common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction + ) + ) + + if common_cfg.msa_cluster_features: + transforms.append(data_transforms.nearest_neighbor_clusters()) + transforms.append(data_transforms.summarize_clusters()) + + # Crop after creating the cluster profiles. + if max_extra_msa: + transforms.append(data_transforms.crop_extra_msa(max_extra_msa)) + else: + transforms.append(data_transforms.delete_extra_msa) + + transforms.append(data_transforms.make_msa_feat()) + + crop_feats = dict(common_cfg.feat) + + if mode_cfg.fixed_size: + transforms.append(data_transforms.select_feat(list(crop_feats))) + transforms.append( + data_transforms.random_crop_to_size( + mode_cfg.crop_size, + mode_cfg.max_templates, + crop_feats, + mode_cfg.subsample_templates, + seed=ensemble_seed + 1, + ) + ) + transforms.append( + data_transforms.make_fixed_size( + crop_feats, + pad_msa_clusters, + common_cfg.max_extra_msa, + mode_cfg.crop_size, + mode_cfg.max_templates, + ) + ) + else: + transforms.append( + data_transforms.crop_templates(mode_cfg.max_templates) + ) + + return transforms + + +def process_tensors_from_config(tensors, common_cfg, mode_cfg): + """Based on the config, apply filters and transformations to the data.""" + + ensemble_seed = torch.Generator().seed() + + def wrap_ensemble_fn(data, i): + """Function to be mapped over the ensemble dimension.""" + d = data.copy() + fns = ensembled_transform_fns( + common_cfg, + mode_cfg, + ensemble_seed, + ) + fn = compose(fns) + d["ensemble_index"] = i + return fn(d) + + no_templates = True + if("template_aatype" in tensors): + no_templates = tensors["template_aatype"].shape[0] == 0 + + 
nonensembled = nonensembled_transform_fns( + common_cfg, + mode_cfg, + ) + + tensors = compose(nonensembled)(tensors) + + if("no_recycling_iters" in tensors): + num_recycling = int(tensors["no_recycling_iters"]) + else: + num_recycling = common_cfg.max_recycling_iters + + tensors = map_fn( + lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1) + ) + + return tensors + + +@data_transforms.curry1 +def compose(x, fs): + for f in fs: + x = f(x) + return x + + +def map_fn(fun, x): + ensembles = [fun(elem) for elem in x] + features = ensembles[0].keys() + ensembled_dict = {} + for feat in features: + ensembled_dict[feat] = torch.stack( + [dict_i[feat] for dict_i in ensembles], dim=-1 + ) + return ensembled_dict diff --git a/openfold/data/mmcif_parsing.py b/openfold/data/mmcif_parsing.py new file mode 100644 index 0000000000000000000000000000000000000000..c95047e1f9ab0cf1549b51d9f01da1c5e26080f1 --- /dev/null +++ b/openfold/data/mmcif_parsing.py @@ -0,0 +1,485 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Parses the mmCIF file format.""" +import collections +import dataclasses +import io +import json +import logging +import os +from typing import Any, Mapping, Optional, Sequence, Tuple + +from Bio import PDB +from Bio.Data import SCOPData +import numpy as np + +from openfold.data.errors import MultipleChainsError +import openfold.np.residue_constants as residue_constants + + +# Type aliases: +ChainId = str +PdbHeader = Mapping[str, Any] +PdbStructure = PDB.Structure.Structure +SeqRes = str +MmCIFDict = Mapping[str, Sequence[str]] + + +@dataclasses.dataclass(frozen=True) +class Monomer: + id: str + num: int + + +# Note - mmCIF format provides no guarantees on the type of author-assigned +# sequence numbers. They need not be integers. +@dataclasses.dataclass(frozen=True) +class AtomSite: + residue_name: str + author_chain_id: str + mmcif_chain_id: str + author_seq_num: str + mmcif_seq_num: int + insertion_code: str + hetatm_atom: str + model_num: int + + +# Used to map SEQRES index to a residue in the structure. +@dataclasses.dataclass(frozen=True) +class ResiduePosition: + chain_id: str + residue_number: int + insertion_code: str + + +@dataclasses.dataclass(frozen=True) +class ResidueAtPosition: + position: Optional[ResiduePosition] + name: str + is_missing: bool + hetflag: str + + +@dataclasses.dataclass(frozen=True) +class MmcifObject: + """Representation of a parsed mmCIF file. + + Contains: + file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all + files being processed. + header: Biopython header. + structure: Biopython structure. + chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g. + {'A': 'ABCDEFG'} + seqres_to_structure: Dict; for each chain_id contains a mapping between + SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition, + 1: ResidueAtPosition, + ...}} + raw_string: The raw string used to construct the MmcifObject. 
+ """ + + file_id: str + header: PdbHeader + structure: PdbStructure + chain_to_seqres: Mapping[ChainId, SeqRes] + seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]] + raw_string: Any + + +@dataclasses.dataclass(frozen=True) +class ParsingResult: + """Returned by the parse function. + + Contains: + mmcif_object: A MmcifObject, may be None if no chain could be successfully + parsed. + errors: A dict mapping (file_id, chain_id) to any exception generated. + """ + + mmcif_object: Optional[MmcifObject] + errors: Mapping[Tuple[str, str], Any] + + +class ParseError(Exception): + """An error indicating that an mmCIF file could not be parsed.""" + + +def mmcif_loop_to_list( + prefix: str, parsed_info: MmCIFDict +) -> Sequence[Mapping[str, str]]: + """Extracts loop associated with a prefix from mmCIF data as a list. + + Reference for loop_ in mmCIF: + http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html + + Args: + prefix: Prefix shared by each of the data items in the loop. + e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + _entity_poly_seq.mon_id. Should include the trailing period. + parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython + parser. + + Returns: + Returns a list of dicts; each dict represents 1 entry from an mmCIF loop. + """ + cols = [] + data = [] + for key, value in parsed_info.items(): + if key.startswith(prefix): + cols.append(key) + data.append(value) + + assert all([len(xs) == len(data[0]) for xs in data]), ( + "mmCIF error: Not all loops are the same length: %s" % cols + ) + + return [dict(zip(cols, xs)) for xs in zip(*data)] + + +def mmcif_loop_to_dict( + prefix: str, + index: str, + parsed_info: MmCIFDict, +) -> Mapping[str, Mapping[str, str]]: + """Extracts loop associated with a prefix from mmCIF data as a dictionary. + + Args: + prefix: Prefix shared by each of the data items in the loop. + e.g. 
'_entity_poly_seq.', where the data items are _entity_poly_seq.num, + _entity_poly_seq.mon_id. Should include the trailing period. + index: Which item of loop data should serve as the key. + parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython + parser. + + Returns: + Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop, + indexed by the index column. + """ + entries = mmcif_loop_to_list(prefix, parsed_info) + return {entry[index]: entry for entry in entries} + + +def parse( + *, file_id: str, mmcif_string: str, catch_all_errors: bool = True +) -> ParsingResult: + """Entry point, parses an mmcif_string. + + Args: + file_id: A string identifier for this file. Should be unique within the + collection of files being processed. + mmcif_string: Contents of an mmCIF file. + catch_all_errors: If True, all exceptions are caught and error messages are + returned as part of the ParsingResult. If False exceptions will be allowed + to propagate. + + Returns: + A ParsingResult. + """ + errors = {} + try: + parser = PDB.MMCIFParser(QUIET=True) + handle = io.StringIO(mmcif_string) + full_structure = parser.get_structure("", handle) + first_model_structure = _get_first_model(full_structure) + # Extract the _mmcif_dict from the parser, which contains useful fields not + # reflected in the Biopython structure. + parsed_info = parser._mmcif_dict # pylint:disable=protected-access + + # Ensure all values are lists, even if singletons. + for key, value in parsed_info.items(): + if not isinstance(value, list): + parsed_info[key] = [value] + + header = _get_header(parsed_info) + + # Determine the protein chains, and their start numbers according to the + # internal mmCIF numbering scheme (likely but not guaranteed to be 1). 
+ valid_chains = _get_protein_chains(parsed_info=parsed_info) + if not valid_chains: + return ParsingResult( + None, {(file_id, ""): "No protein chains found in this file."} + ) + seq_start_num = { + chain_id: min([monomer.num for monomer in seq]) + for chain_id, seq in valid_chains.items() + } + + # Loop over the atoms for which we have coordinates. Populate two mappings: + # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used + # the authors / Biopython). + # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition). + mmcif_to_author_chain_id = {} + seq_to_structure_mappings = {} + for atom in _get_atom_site_list(parsed_info): + if atom.model_num != "1": + # We only process the first model at the moment. + continue + + mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id + + if atom.mmcif_chain_id in valid_chains: + hetflag = " " + if atom.hetatm_atom == "HETATM": + # Water atoms are assigned a special hetflag of W in Biopython. We + # need to do the same, so that this hetflag can be used to fetch + # a residue from the Biopython structure by id. + if atom.residue_name in ("HOH", "WAT"): + hetflag = "W" + else: + hetflag = "H_" + atom.residue_name + insertion_code = atom.insertion_code + if not _is_set(atom.insertion_code): + insertion_code = " " + position = ResiduePosition( + chain_id=atom.author_chain_id, + residue_number=int(atom.author_seq_num), + insertion_code=insertion_code, + ) + seq_idx = ( + int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id] + ) + current = seq_to_structure_mappings.get( + atom.author_chain_id, {} + ) + current[seq_idx] = ResidueAtPosition( + position=position, + name=atom.residue_name, + is_missing=False, + hetflag=hetflag, + ) + seq_to_structure_mappings[atom.author_chain_id] = current + + # Add missing residue information to seq_to_structure_mappings. 
+ for chain_id, seq_info in valid_chains.items(): + author_chain = mmcif_to_author_chain_id[chain_id] + current_mapping = seq_to_structure_mappings[author_chain] + for idx, monomer in enumerate(seq_info): + if idx not in current_mapping: + current_mapping[idx] = ResidueAtPosition( + position=None, + name=monomer.id, + is_missing=True, + hetflag=" ", + ) + + author_chain_to_sequence = {} + for chain_id, seq_info in valid_chains.items(): + author_chain = mmcif_to_author_chain_id[chain_id] + seq = [] + for monomer in seq_info: + code = SCOPData.protein_letters_3to1.get(monomer.id, "X") + seq.append(code if len(code) == 1 else "X") + seq = "".join(seq) + author_chain_to_sequence[author_chain] = seq + + mmcif_object = MmcifObject( + file_id=file_id, + header=header, + structure=first_model_structure, + chain_to_seqres=author_chain_to_sequence, + seqres_to_structure=seq_to_structure_mappings, + raw_string=parsed_info, + ) + + return ParsingResult(mmcif_object=mmcif_object, errors=errors) + except Exception as e: # pylint:disable=broad-except + errors[(file_id, "")] = e + if not catch_all_errors: + raise + return ParsingResult(mmcif_object=None, errors=errors) + + +def _get_first_model(structure: PdbStructure) -> PdbStructure: + """Returns the first model in a Biopython structure.""" + return next(structure.get_models()) + + +_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21 + + +def get_release_date(parsed_info: MmCIFDict) -> str: + """Returns the oldest revision date.""" + revision_dates = parsed_info["_pdbx_audit_revision_history.revision_date"] + return min(revision_dates) + + +def _get_header(parsed_info: MmCIFDict) -> PdbHeader: + """Returns a basic header containing method, release date and resolution.""" + header = {} + + experiments = mmcif_loop_to_list("_exptl.", parsed_info) + header["structure_method"] = ",".join( + [experiment["_exptl.method"].lower() for experiment in experiments] + ) + + # Note: The release_date here corresponds to the oldest revision. 
We prefer to + # use this for dataset filtering over the deposition_date. + if "_pdbx_audit_revision_history.revision_date" in parsed_info: + header["release_date"] = get_release_date(parsed_info) + else: + logging.warning( + "Could not determine release_date: %s", parsed_info["_entry.id"] + ) + + header["resolution"] = 0.00 + for res_key in ( + "_refine.ls_d_res_high", + "_em_3d_reconstruction.resolution", + "_reflns.d_resolution_high", + ): + if res_key in parsed_info: + try: + raw_resolution = parsed_info[res_key][0] + header["resolution"] = float(raw_resolution) + except ValueError: + logging.info( + "Invalid resolution format: %s", parsed_info[res_key] + ) + + return header + + +def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]: + """Returns list of atom sites; contains data not present in the structure.""" + return [ + AtomSite(*site) + for site in zip( # pylint:disable=g-complex-comprehension + parsed_info["_atom_site.label_comp_id"], + parsed_info["_atom_site.auth_asym_id"], + parsed_info["_atom_site.label_asym_id"], + parsed_info["_atom_site.auth_seq_id"], + parsed_info["_atom_site.label_seq_id"], + parsed_info["_atom_site.pdbx_PDB_ins_code"], + parsed_info["_atom_site.group_PDB"], + parsed_info["_atom_site.pdbx_PDB_model_num"], + ) + ] + + +def _get_protein_chains( + *, parsed_info: Mapping[str, Any] +) -> Mapping[ChainId, Sequence[Monomer]]: + """Extracts polymer information for protein chains only. + + Args: + parsed_info: _mmcif_dict produced by the Biopython parser. + + Returns: + A dict mapping mmcif chain id to a list of Monomers. + """ + # Get polymer information for each entity in the structure. 
def get_atom_coords(
    mmcif_object: MmcifObject,
    chain_id: str,
    _zero_center_positions: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
    """Extracts atom coordinates and an atom-presence mask for one chain.

    Args:
        mmcif_object: Parsed structure. This function reads its `structure`,
            `chain_to_seqres` and `seqres_to_structure` attributes.
        chain_id: Id of the chain to extract.
        _zero_center_positions: If True, the mean position of all present
            atoms is subtracted from every present atom.

    Returns:
        A tuple `(all_atom_positions, all_atom_mask)` of float32 arrays with
        shapes `[num_res, atom_type_num, 3]` and `[num_res, atom_type_num]`,
        where `num_res` is the length of `chain_to_seqres[chain_id]`. Entries
        for missing residues and absent atoms are left at zero.

    Raises:
        MultipleChainsError: If the structure does not contain exactly one
            chain whose id equals `chain_id`.
    """
    # Locate the right chain.
    relevant_chains = [
        c for c in mmcif_object.structure.get_chains() if c.id == chain_id
    ]
    if len(relevant_chains) != 1:
        raise MultipleChainsError(
            f"Expected exactly one chain in structure with id {chain_id}."
        )
    chain = relevant_chains[0]

    atom_order = residue_constants.atom_order
    atom_type_num = residue_constants.atom_type_num

    # Extract the coordinates.
    num_res = len(mmcif_object.chain_to_seqres[chain_id])
    all_atom_positions = np.zeros([num_res, atom_type_num, 3], dtype=np.float32)
    all_atom_mask = np.zeros([num_res, atom_type_num], dtype=np.float32)
    for res_index in range(num_res):
        res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
        if res_at_position.is_missing:
            continue
        res = chain[
            (
                res_at_position.hetflag,
                res_at_position.position.residue_number,
                res_at_position.position.insertion_code,
            )
        ]
        for atom in res.get_atoms():
            atom_name = atom.get_name()
            if atom_name in atom_order:
                column = atom_order[atom_name]
            elif atom_name.upper() == "SE" and res.get_resname() == "MSE":
                # Put the coords of the selenium atom in the sulphur column.
                column = atom_order["SD"]
            else:
                continue
            x, y, z = atom.get_coord()
            all_atom_positions[res_index, column] = [x, y, z]
            all_atom_mask[res_index, column] = 1.0

    if _zero_center_positions:
        binary_mask = all_atom_mask.astype(bool)
        # With no atoms present the mean of the empty selection is NaN and
        # NumPy emits a "Mean of empty slice" RuntimeWarning; there is nothing
        # to centre in that case, so skip it.
        if binary_mask.any():
            translation_vec = all_atom_positions[binary_mask].mean(axis=0)
            all_atom_positions[binary_mask] -= translation_vec

    return all_atom_positions, all_atom_mask
def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
    """Splits the text of a FASTA file into sequences and their descriptions.

    Arguments:
      fasta_string: The string contents of a FASTA file.

    Returns:
      A tuple of two lists, in matching order:
      * The sequences, each with its (whitespace-stripped) lines concatenated.
      * The descriptions, i.e. the header lines without the leading '>'.
    """
    descriptions = []
    chunks_per_record = []
    for raw_line in fasta_string.splitlines():
        text = raw_line.strip()
        if not text:
            continue  # Blank lines carry no information.
        if text.startswith(">"):
            # A header opens a new record.
            descriptions.append(text[1:])
            chunks_per_record.append([])
        else:
            # Sequence data always belongs to the most recent header.
            chunks_per_record[-1].append(text)

    sequences = ["".join(chunks) for chunks in chunks_per_record]
    return sequences, descriptions
def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
    """Reads an a3m alignment into aligned sequences plus a deletion matrix.

    Args:
      a3m_string: The string contents of a a3m file. The first sequence in the
        file should be the query sequence.

    Returns:
      A tuple of:
        * The sequences with every ASCII lowercase character removed. These
          might contain duplicates.
        * The deletion matrix as a list of lists: `deletion_matrix[i][j]` is
          the number of lowercase characters that immediately precede the
          j-th non-lowercase character of sequence i.
    """
    raw_rows, _ = parse_fasta(a3m_string)
    strip_insertions = str.maketrans("", "", string.ascii_lowercase)

    aligned_sequences = []
    deletion_matrix = []
    for row in raw_rows:
        pending = 0
        row_deletions = []
        for symbol in row:
            if symbol.islower():
                pending += 1
                continue
            # A non-lowercase symbol closes the current run of insertions.
            row_deletions.append(pending)
            pending = 0
        deletion_matrix.append(row_deletions)
        aligned_sequences.append(row.translate(strip_insertions))

    return aligned_sequences, deletion_matrix
+ columns = line.split(maxsplit=3) + seqname, feature = columns[1:3] + value = columns[3] if len(columns) == 4 else "" + if feature != "DE": + continue + if reached_max_sequences and seqname not in sequences: + continue + descriptions[seqname] = value + if len(descriptions) == len(sequences): + break + + # Convert sto format to a3m line by line + a3m_sequences = {} + # query_sequence is assumed to be the first sequence + query_sequence = next(iter(sequences.values())) + query_non_gaps = [res != "-" for res in query_sequence] + for seqname, sto_sequence in sequences.items(): + a3m_sequences[seqname] = "".join( + _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence) + ) + + fasta_chunks = ( + f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}" + for k in a3m_sequences + ) + return "\n".join(fasta_chunks) + "\n" # Include terminating newline. + + +def _get_hhr_line_regex_groups( + regex_pattern: str, line: str +) -> Sequence[Optional[str]]: + match = re.match(regex_pattern, line) + if match is None: + raise RuntimeError(f"Could not parse query line {line}") + return match.groups() + + +def _update_hhr_residue_indices_list( + sequence: str, start_index: int, indices_list: List[int] +): + """Computes the relative indices for each residue with respect to the original sequence.""" + counter = start_index + for symbol in sequence: + if symbol == "-": + indices_list.append(-1) + else: + indices_list.append(counter) + counter += 1 + + +def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: + """Parses the detailed HMM HMM comparison section for a single Hit. + + This works on .hhr files generated from both HHBlits and HHSearch. + + Args: + detailed_lines: A list of lines from a single comparison section between 2 + sequences (which each have their own HMM's) + + Returns: + A dictionary with the information from that detailed comparison section + + Raises: + RuntimeError: If a certain line cannot be processed + """ + # Parse first 2 lines. 
+ number_of_hit = int(detailed_lines[0].split()[-1]) + name_hit = detailed_lines[1][1:] + + # Parse the summary line. + pattern = ( + "Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t" + " ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t " + "]*Template_Neff=(.*)" + ) + match = re.match(pattern, detailed_lines[2]) + if match is None: + raise RuntimeError( + "Could not parse section: %s. Expected this: \n%s to contain summary." + % (detailed_lines, detailed_lines[2]) + ) + (prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [ + float(x) for x in match.groups() + ] + + # The next section reads the detailed comparisons. These are in a 'human + # readable' format which has a fixed length. The strategy employed is to + # assume that each block starts with the query sequence line, and to parse + # that with a regexp in order to deduce the fixed length used for that block. + query = "" + hit_sequence = "" + indices_query = [] + indices_hit = [] + length_block = None + + for line in detailed_lines[3:]: + # Parse the query sequence line + if ( + line.startswith("Q ") + and not line.startswith("Q ss_dssp") + and not line.startswith("Q ss_pred") + and not line.startswith("Q Consensus") + ): + # Thus the first 17 characters must be 'Q ', and we can parse + # everything after that. + # start sequence end total_sequence_length + patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)" + groups = _get_hhr_line_regex_groups(patt, line[17:]) + + # Get the length of the parsed block using the start and finish indices, + # and ensure it is the same as the actual block length. + start = int(groups[0]) - 1 # Make index zero based. + delta_query = groups[1] + end = int(groups[2]) + num_insertions = len([x for x in delta_query if x == "-"]) + length_block = end - start + num_insertions + assert length_block == len(delta_query) + + # Update the query sequence and indices list. 
+ query += delta_query + _update_hhr_residue_indices_list(delta_query, start, indices_query) + + elif line.startswith("T "): + # Parse the hit sequence. + if ( + not line.startswith("T ss_dssp") + and not line.startswith("T ss_pred") + and not line.startswith("T Consensus") + ): + # Thus the first 17 characters must be 'T ', and we can + # parse everything after that. + # start sequence end total_sequence_length + patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)" + groups = _get_hhr_line_regex_groups(patt, line[17:]) + start = int(groups[0]) - 1 # Make index zero based. + delta_hit_sequence = groups[1] + assert length_block == len(delta_hit_sequence) + + # Update the hit sequence and indices list. + hit_sequence += delta_hit_sequence + _update_hhr_residue_indices_list( + delta_hit_sequence, start, indices_hit + ) + + return TemplateHit( + index=number_of_hit, + name=name_hit, + aligned_cols=int(aligned_cols), + sum_probs=sum_probs, + query=query, + hit_sequence=hit_sequence, + indices_query=indices_query, + indices_hit=indices_hit, + ) + + +def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]: + """Parses the content of an entire HHR file.""" + lines = hhr_string.splitlines() + + # Each .hhr file starts with a results table, then has a sequence of hit + # "paragraphs", each paragraph starting with a line 'No '. We + # iterate through each paragraph to parse each hit. + + block_starts = [i for i, line in enumerate(lines) if line.startswith("No ")] + + hits = [] + if block_starts: + block_starts.append(len(lines)) # Add the end of the final block. 
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
    """Parses a target-name to E-value mapping from a Jackhmmer tblout string.

    Args:
      tblout: The string contents of a tblout file.

    Returns:
      A dict mapping each target name (field 1) to its full-sequence E-value
      (field 5, numbering from 1). The dict always contains the entry
      `"query": 0`.
    """
    e_values = {"query": 0}
    # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
    # space-delimited. Relevant fields are (1) target name: and
    # (5) E-value (full sequence) (numbering from 1).
    for line in tblout.splitlines():
        # Skip comments and blank lines. Indexing `line[0]` directly would
        # raise IndexError on an empty line.
        if not line.strip() or line.startswith("#"):
            continue
        fields = line.split()
        e_values[fields[0]] = float(fields[4])
    return e_values
+ +"""Functions for getting templates and calculating template features.""" +import dataclasses +import datetime +import glob +import json +import logging +import os +import re +from typing import Any, Dict, Mapping, Optional, Sequence, Tuple + +import numpy as np + +from openfold.data import parsers, mmcif_parsing +from openfold.data.errors import Error +from openfold.data.tools import kalign +from openfold.data.tools.utils import to_date +from openfold.np import residue_constants + + +class NoChainsError(Error): + """An error indicating that template mmCIF didn't have any chains.""" + + +class SequenceNotInTemplateError(Error): + """An error indicating that template mmCIF didn't contain the sequence.""" + + +class NoAtomDataInTemplateError(Error): + """An error indicating that template mmCIF didn't contain atom positions.""" + + +class TemplateAtomMaskAllZerosError(Error): + """An error indicating that template mmCIF had all atom positions masked.""" + + +class QueryToTemplateAlignError(Error): + """An error indicating that the query can't be aligned to the template.""" + + +class CaDistanceError(Error): + """An error indicating that a CA atom distance exceeds a threshold.""" + + +# Prefilter exceptions. 
# Maps each template feature name to a NumPy dtype. The two string-valued
# entries use the builtin `object`: the `np.object` alias was deprecated in
# NumPy 1.20 and removed in 1.24, where accessing it raises AttributeError when
# this module is imported. `np.object` was an alias of `object`, so the dtype
# is unchanged on older NumPy versions.
TEMPLATE_FEATURES = {
    "template_aatype": np.int64,
    "template_all_atom_mask": np.float32,
    "template_all_atom_positions": np.float32,
    "template_domain_names": object,
    "template_sequence": object,
    "template_sum_probs": np.float32,
}
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
    """Writes a JSON cache of the release dates of the mmCIF files in a directory.

    Every file directly inside `mmcif_dir` whose name ends in ".cif" is parsed.
    Files that fail to parse are logged and skipped. The result, a JSON object
    mapping the file name without extension to the `release_date` entry of the
    parsed header, is written to `out_path`.

    Args:
      mmcif_dir: Directory containing the `.cif` files.
      out_path: Path of the JSON file to create (overwritten if it exists).
    """
    dates = {}
    for f in os.listdir(mmcif_dir):
        if not f.endswith(".cif"):
            continue

        path = os.path.join(mmcif_dir, f)
        with open(path, "r") as fp:
            mmcif_string = fp.read()

        file_id = os.path.splitext(f)[0]
        mmcif = mmcif_parsing.parse(
            file_id=file_id, mmcif_string=mmcif_string
        )
        if mmcif.mmcif_object is None:
            logging.info(f"Failed to parse {f}. Skipping...")
            continue

        dates[file_id] = mmcif.mmcif_object.header["release_date"]

    # NOTE(review): this writes a flat {file_id: release_date} object, whereas
    # _parse_release_dates in this module iterates `d.items()` on every value
    # and looks for a "release_date" key. Confirm which cache format the
    # reader is meant to consume.
    # The file is created here, so it must be opened for writing; mode "r"
    # fails with FileNotFoundError (or io.UnsupportedOperation on write).
    with open(out_path, "w") as fp:
        fp.write(json.dumps(dates))
+ """ + aligned_cols = hit.aligned_cols + align_ratio = aligned_cols / len(query_sequence) + + template_sequence = hit.hit_sequence.replace("-", "") + length_ratio = float(len(template_sequence)) / len(query_sequence) + + # Check whether the template is a large subsequence or duplicate of original + # query. This can happen due to duplicate entries in the PDB database. + duplicate = ( + template_sequence in query_sequence + and length_ratio > max_subsequence_ratio + ) + + if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff): + date = release_dates[hit_pdb_code.upper()] + raise DateError( + f"Date ({date}) > max template date " + f"({release_date_cutoff})." + ) + + if query_pdb_code is not None: + if query_pdb_code.lower() == hit_pdb_code.lower(): + raise PdbIdError("PDB code identical to Query PDB code.") + + if align_ratio <= min_align_ratio: + raise AlignRatioError( + "Proportion of residues aligned to query too small. " + f"Align ratio: {align_ratio}." + ) + + if duplicate: + raise DuplicateError( + "Template is an exact subsequence of query with large " + f"coverage. Length ratio: {length_ratio}." + ) + + if len(template_sequence) < 10: + raise LengthError( + f"Template too short. Length: {len(template_sequence)}." + ) + + return True + + +def _find_template_in_pdb( + template_chain_id: str, + template_sequence: str, + mmcif_object: mmcif_parsing.MmcifObject, +) -> Tuple[str, str, int]: + """Tries to find the template chain in the given pdb file. + + This method tries the three following things in order: + 1. Tries if there is an exact match in both the chain ID and the sequence. + If yes, the chain sequence is returned. Otherwise: + 2. Tries if there is an exact match only in the sequence. + If yes, the chain sequence is returned. Otherwise: + 3. Tries if there is a fuzzy match (X = wildcard) in the sequence. + If yes, the chain sequence is returned. + If none of these succeed, a SequenceNotInTemplateError is thrown. 
+ + Args: + template_chain_id: The template chain ID. + template_sequence: The template chain sequence. + mmcif_object: The PDB object to search for the template in. + + Returns: + A tuple with: + * The chain sequence that was found to match the template in the PDB object. + * The ID of the chain that is being returned. + * The offset where the template sequence starts in the chain sequence. + + Raises: + SequenceNotInTemplateError: If no match is found after the steps described + above. + """ + # Try if there is an exact match in both the chain ID and the (sub)sequence. + pdb_id = mmcif_object.file_id + chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id) + if chain_sequence and (template_sequence in chain_sequence): + logging.info( + "Found an exact template match %s_%s.", pdb_id, template_chain_id + ) + mapping_offset = chain_sequence.find(template_sequence) + return chain_sequence, template_chain_id, mapping_offset + + # Try if there is an exact match in the (sub)sequence only. + for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): + if chain_sequence and (template_sequence in chain_sequence): + logging.info("Found a sequence-only match %s_%s.", pdb_id, chain_id) + mapping_offset = chain_sequence.find(template_sequence) + return chain_sequence, chain_id, mapping_offset + + # Return a chain sequence that fuzzy matches (X = wildcard) the template. + # Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit. + regex = ["." if aa == "X" else "(?:%s|X)" % aa for aa in template_sequence] + regex = re.compile("".join(regex)) + for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): + match = re.search(regex, chain_sequence) + if match: + logging.info( + "Found a fuzzy sequence-only match %s_%s.", pdb_id, chain_id + ) + mapping_offset = match.start() + return chain_sequence, chain_id, mapping_offset + + # No hits, raise an error. 
def _realign_pdb_template_to_query(
    old_template_sequence: str,
    template_chain_id: str,
    mmcif_object: mmcif_parsing.MmcifObject,
    old_mapping: Mapping[int, int],
    kalign_binary_path: str,
) -> Tuple[str, Mapping[int, int]]:
    """Aligns template from the mmcif_object to the query.

    In case PDB70 contains a different version of the template sequence, we need
    to perform a realignment to the actual sequence that is in the mmCIF file.
    This method performs such realignment, but returns the new sequence and
    mapping only if the sequence in the mmCIF file is 90% identical to the old
    sequence.

    Note that the old_template_sequence comes from the hit, and contains only
    that part of the chain that matches with the query while the
    new_template_sequence is the full chain.

    Args:
      old_template_sequence: The template sequence that was returned by the PDB
        template search (typically done using HHSearch).
      template_chain_id: The template chain id that was returned by the PDB
        template search. This is used to find the right chain in the
        mmcif_object chain_to_seqres mapping.
      mmcif_object: A mmcif_object which holds the actual template data.
      old_mapping: A mapping from the query sequence to the template sequence.
        It is composed with the old-to-new template alignment to produce the
        new mapping.
      kalign_binary_path: The path to a kalign executable.

    Returns:
      A tuple (new_template_sequence, new_query_to_template_mapping) where:
        * new_template_sequence is the actual template sequence that was found
          in the mmcif_object.
        * new_query_to_template_mapping is the new mapping from the query to the
          actual template found in the mmcif_object. Query positions whose old
          template index has no aligned counterpart map to -1.

    Raises:
      QueryToTemplateAlignError:
        * If the chain cannot be found and the structure has more than one
          sequence.
        * If there was an error thrown by the alignment tool (the original
          exception is attached as `__cause__`).
        * Or if the actual template sequence differs by more than 10% from the
          old_template_sequence.
    """
    aligner = kalign.Kalign(binary_path=kalign_binary_path)
    new_template_sequence = mmcif_object.chain_to_seqres.get(
        template_chain_id, ""
    )

    # Sometimes the template chain id is unknown. But if there is only a single
    # sequence within the mmcif_object, it is safe to assume it is that one.
    if not new_template_sequence:
        if len(mmcif_object.chain_to_seqres) == 1:
            logging.info(
                "Could not find %s in %s, but there is only 1 sequence, so "
                "using that one.",
                template_chain_id,
                mmcif_object.file_id,
            )
            new_template_sequence = next(
                iter(mmcif_object.chain_to_seqres.values())
            )
        else:
            raise QueryToTemplateAlignError(
                f"Could not find chain {template_chain_id} in {mmcif_object.file_id}. "
                "If there are no mmCIF parsing errors, it is possible it was not a "
                "protein chain."
            )

    try:
        (old_aligned_template, new_aligned_template), _ = parsers.parse_a3m(
            aligner.align([old_template_sequence, new_template_sequence])
        )
    except Exception as e:
        # Chain the original exception so the aligner's traceback is kept.
        raise QueryToTemplateAlignError(
            "Could not align old template %s to template %s (%s_%s). Error: %s"
            % (
                old_template_sequence,
                new_template_sequence,
                mmcif_object.file_id,
                template_chain_id,
                str(e),
            )
        ) from e

    logging.info(
        "Old aligned template: %s\nNew aligned template: %s",
        old_aligned_template,
        new_aligned_template,
    )

    # Walk the two aligned rows in lockstep, tracking the ungapped index in
    # each, and record which old positions pair with which new positions.
    old_to_new_template_mapping = {}
    old_template_index = -1
    new_template_index = -1
    num_same = 0
    for old_template_aa, new_template_aa in zip(
        old_aligned_template, new_aligned_template
    ):
        if old_template_aa != "-":
            old_template_index += 1
        if new_template_aa != "-":
            new_template_index += 1
        if old_template_aa != "-" and new_template_aa != "-":
            old_to_new_template_mapping[old_template_index] = new_template_index
            if old_template_aa == new_template_aa:
                num_same += 1

    # Require at least 90 % sequence identity wrt to the shorter of the sequences.
    if (
        float(num_same)
        / min(len(old_template_sequence), len(new_template_sequence))
        < 0.9
    ):
        raise QueryToTemplateAlignError(
            "Insufficient similarity of the sequence in the database: %s to the "
            "actual sequence in the mmCIF file %s_%s: %s. We require at least "
            "90 %% similarity wrt to the shorter of the sequences. This is not a "
            "problem unless you think this is a template that should be included."
            % (
                old_template_sequence,
                mmcif_object.file_id,
                template_chain_id,
                new_template_sequence,
            )
        )

    new_query_to_template_mapping = {}
    for query_index, old_template_index in old_mapping.items():
        new_query_to_template_mapping[
            query_index
        ] = old_to_new_template_mapping.get(old_template_index, -1)

    new_template_sequence = new_template_sequence.replace("-", "")

    return new_template_sequence, new_query_to_template_mapping
def _get_atom_positions(
    mmcif_object: mmcif_parsing.MmcifObject,
    auth_chain_id: str,
    max_ca_ca_distance: float,
    _zero_center_positions: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """Fetches atom positions and mask for a chain and validates CA spacing.

    The arrays come from `mmcif_parsing.get_atom_coords`. They are then passed
    to `_check_residue_distances`, which raises CaDistanceError when two
    neighbouring unmasked CA atoms are further apart than `max_ca_ca_distance`.
    """
    positions, mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object,
        chain_id=auth_chain_id,
        _zero_center_positions=_zero_center_positions,
    )
    _check_residue_distances(
        all_positions=positions,
        all_positions_mask=mask,
        max_ca_ca_distance=max_ca_ca_distance,
    )
    return positions, mask
+ * A warning message if the hit was realigned to the actual mmCIF sequence. + Otherwise None. + + Raises: + NoChainsError: If the mmcif object doesn't contain any chains. + SequenceNotInTemplateError: If the given chain id / sequence can't + be found in the mmcif object. + QueryToTemplateAlignError: If the actual template in the mmCIF file + can't be aligned to the query. + NoAtomDataInTemplateError: If the mmcif object doesn't contain + atom positions. + TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any + unmasked residues. + """ + if mmcif_object is None or not mmcif_object.chain_to_seqres: + raise NoChainsError( + "No chains in PDB: %s_%s" % (pdb_id, template_chain_id) + ) + + warning = None + try: + seqres, chain_id, mapping_offset = _find_template_in_pdb( + template_chain_id=template_chain_id, + template_sequence=template_sequence, + mmcif_object=mmcif_object, + ) + except SequenceNotInTemplateError: + # If PDB70 contains a different version of the template, we use the sequence + # from the mmcif_object. + chain_id = template_chain_id + warning = ( + f"The exact sequence {template_sequence} was not found in " + f"{pdb_id}_{chain_id}. Realigning the template to the actual sequence." + ) + logging.warning(warning) + # This throws an exception if it fails to realign the hit. + seqres, mapping = _realign_pdb_template_to_query( + old_template_sequence=template_sequence, + template_chain_id=template_chain_id, + mmcif_object=mmcif_object, + old_mapping=mapping, + kalign_binary_path=kalign_binary_path, + ) + logging.info( + "Sequence in %s_%s: %s successfully realigned to %s", + pdb_id, + chain_id, + template_sequence, + seqres, + ) + # The template sequence changed. + template_sequence = seqres + # No mapping offset, the query is aligned to the actual sequence. + mapping_offset = 0 + + try: + # Essentially set to infinity - we don't want to reject templates unless + # they're really really bad. 
+ all_atom_positions, all_atom_mask = _get_atom_positions( + mmcif_object, + chain_id, + max_ca_ca_distance=150.0, + _zero_center_positions=_zero_center_positions, + ) + except (CaDistanceError, KeyError) as ex: + raise NoAtomDataInTemplateError( + "Could not get atom data (%s_%s): %s" % (pdb_id, chain_id, str(ex)) + ) from ex + + all_atom_positions = np.split( + all_atom_positions, all_atom_positions.shape[0] + ) + all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0]) + + output_templates_sequence = [] + templates_all_atom_positions = [] + templates_all_atom_masks = [] + + for _ in query_sequence: + # Residues in the query_sequence that are not in the template_sequence: + templates_all_atom_positions.append( + np.zeros((residue_constants.atom_type_num, 3)) + ) + templates_all_atom_masks.append( + np.zeros(residue_constants.atom_type_num) + ) + output_templates_sequence.append("-") + + for k, v in mapping.items(): + template_index = v + mapping_offset + templates_all_atom_positions[k] = all_atom_positions[template_index][0] + templates_all_atom_masks[k] = all_atom_masks[template_index][0] + output_templates_sequence[k] = template_sequence[v] + + # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O). + if np.sum(templates_all_atom_masks) < 5: + raise TemplateAtomMaskAllZerosError( + "Template all atom mask was all zeros: %s_%s. 
Residue range: %d-%d" + % ( + pdb_id, + chain_id, + min(mapping.values()) + mapping_offset, + max(mapping.values()) + mapping_offset, + ) + ) + + output_templates_sequence = "".join(output_templates_sequence) + + templates_aatype = residue_constants.sequence_to_onehot( + output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID + ) + + return ( + { + "template_all_atom_positions": np.array( + templates_all_atom_positions + ), + "template_all_atom_mask": np.array(templates_all_atom_masks), + "template_sequence": output_templates_sequence.encode(), + "template_aatype": np.array(templates_aatype), + "template_domain_names": f"{pdb_id.lower()}_{chain_id}".encode(), + }, + warning, + ) + + +def _build_query_to_hit_index_mapping( + hit_query_sequence: str, + hit_sequence: str, + indices_hit: Sequence[int], + indices_query: Sequence[int], + original_query_sequence: str, +) -> Mapping[int, int]: + """Gets mapping from indices in original query sequence to indices in the hit. + + hit_query_sequence and hit_sequence are two aligned sequences containing gap + characters. hit_query_sequence contains only the part of the original query + sequence that matched the hit. When interpreting the indices from the .hhr, we + need to correct for this to recover a mapping from original query sequence to + the hit sequence. + + Args: + hit_query_sequence: The portion of the query sequence that is in the .hhr + hit + hit_sequence: The portion of the hit sequence that is in the .hhr + indices_hit: The indices for each aminoacid relative to the hit sequence + indices_query: The indices for each aminoacid relative to the original query + sequence + original_query_sequence: String describing the original query sequence. + + Returns: + Dictionary with indices in the original query sequence as keys and indices + in the hit sequence as values. 
+ """ + # If the hit is empty (no aligned residues), return empty mapping + if not hit_query_sequence: + return {} + + # Remove gaps and find the offset of hit.query relative to original query. + hhsearch_query_sequence = hit_query_sequence.replace("-", "") + hit_sequence = hit_sequence.replace("-", "") + hhsearch_query_offset = original_query_sequence.find( + hhsearch_query_sequence + ) + + # Index of -1 used for gap characters. Subtract the min index ignoring gaps. + min_idx = min(x for x in indices_hit if x > -1) + fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit] + + min_idx = min(x for x in indices_query if x > -1) + fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query] + + # Zip the corrected indices, ignore case where both seqs have gap characters. + mapping = {} + for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit): + if q_t != -1 and q_i != -1: + if q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len( + original_query_sequence + ): + continue + mapping[q_i + hhsearch_query_offset] = q_t + + return mapping + + +@dataclasses.dataclass(frozen=True) +class PrefilterResult: + valid: bool + error: Optional[str] + warning: Optional[str] + +@dataclasses.dataclass(frozen=True) +class SingleHitResult: + features: Optional[Mapping[str, Any]] + error: Optional[str] + warning: Optional[str] + + +def _prefilter_hit( + query_sequence: str, + query_pdb_code: Optional[str], + hit: parsers.TemplateHit, + max_template_date: datetime.datetime, + release_dates: Mapping[str, datetime.datetime], + obsolete_pdbs: Mapping[str, str], + strict_error_check: bool = False, +): + # Fail hard if we can't get the PDB ID and chain name from the hit. + hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit) + + if hit_pdb_code not in release_dates: + if hit_pdb_code in obsolete_pdbs: + hit_pdb_code = obsolete_pdbs[hit_pdb_code] + + # Pass hit_pdb_code since it might have changed due to the pdb being + # obsolete. 
+ try: + _assess_hhsearch_hit( + hit=hit, + hit_pdb_code=hit_pdb_code, + query_sequence=query_sequence, + query_pdb_code=query_pdb_code, + release_dates=release_dates, + release_date_cutoff=max_template_date, + ) + except PrefilterError as e: + hit_name = f"{hit_pdb_code}_{hit_chain_id}" + msg = f"hit {hit_name} did not pass prefilter: {str(e)}" + logging.info("%s: %s", query_pdb_code, msg) + if strict_error_check and isinstance( + e, (DateError, PdbIdError, DuplicateError) + ): + # In strict mode we treat some prefilter cases as errors. + return PrefilterResult(valid=False, error=msg, warning=None) + + return PrefilterResult(valid=False, error=None, warning=None) + + return PrefilterResult(valid=True, error=None, warning=None) + + +def _process_single_hit( + query_sequence: str, + query_pdb_code: Optional[str], + hit: parsers.TemplateHit, + mmcif_dir: str, + max_template_date: datetime.datetime, + release_dates: Mapping[str, datetime.datetime], + obsolete_pdbs: Mapping[str, str], + kalign_binary_path: str, + strict_error_check: bool = False, + _zero_center_positions: bool = True, +) -> SingleHitResult: + """Tries to extract template features from a single HHSearch hit.""" + # Fail hard if we can't get the PDB ID and chain name from the hit. + hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit) + + if hit_pdb_code not in release_dates: + if hit_pdb_code in obsolete_pdbs: + hit_pdb_code = obsolete_pdbs[hit_pdb_code] + + mapping = _build_query_to_hit_index_mapping( + hit.query, + hit.hit_sequence, + hit.indices_hit, + hit.indices_query, + query_sequence, + ) + + # The mapping is from the query to the actual hit sequence, so we need to + # remove gaps (which regardless have a missing confidence score). + template_sequence = hit.hit_sequence.replace("-", "") + + cif_path = os.path.join(mmcif_dir, hit_pdb_code + ".cif") + logging.info( + "Reading PDB entry from %s. 
Query: %s, template: %s", + cif_path, + query_sequence, + template_sequence, + ) + # Fail if we can't find the mmCIF file. + with open(cif_path, "r") as cif_file: + cif_string = cif_file.read() + + parsing_result = mmcif_parsing.parse( + file_id=hit_pdb_code, mmcif_string=cif_string + ) + + if parsing_result.mmcif_object is not None: + hit_release_date = datetime.datetime.strptime( + parsing_result.mmcif_object.header["release_date"], "%Y-%m-%d" + ) + if hit_release_date > max_template_date: + error = "Template %s date (%s) > max template date (%s)." % ( + hit_pdb_code, + hit_release_date, + max_template_date, + ) + if strict_error_check: + return SingleHitResult(features=None, error=error, warning=None) + else: + logging.info(error) + return SingleHitResult(features=None, error=None, warning=None) + + try: + features, realign_warning = _extract_template_features( + mmcif_object=parsing_result.mmcif_object, + pdb_id=hit_pdb_code, + mapping=mapping, + template_sequence=template_sequence, + query_sequence=query_sequence, + template_chain_id=hit_chain_id, + kalign_binary_path=kalign_binary_path, + _zero_center_positions=_zero_center_positions, + ) + features["template_sum_probs"] = [hit.sum_probs] + + # It is possible there were some errors when parsing the other chains in the + # mmCIF file, but the template features for the chain we want were still + # computed. In such case the mmCIF parsing errors are not relevant. + return SingleHitResult( + features=features, error=None, warning=realign_warning + ) + except ( + NoChainsError, + NoAtomDataInTemplateError, + TemplateAtomMaskAllZerosError, + ) as e: + # These 3 errors indicate missing mmCIF experimental data rather than a + # problem with the template search, so turn them into warnings. 
+ warning = ( + "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: " + "%s, mmCIF parsing errors: %s" + % ( + hit_pdb_code, + hit_chain_id, + hit.sum_probs, + hit.index, + str(e), + parsing_result.errors, + ) + ) + if strict_error_check: + return SingleHitResult(features=None, error=warning, warning=None) + else: + return SingleHitResult(features=None, error=None, warning=warning) + except Error as e: + error = ( + "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: " + "%s, mmCIF parsing errors: %s" + % ( + hit_pdb_code, + hit_chain_id, + hit.sum_probs, + hit.index, + str(e), + parsing_result.errors, + ) + ) + return SingleHitResult(features=None, error=error, warning=None) + + +@dataclasses.dataclass(frozen=True) +class TemplateSearchResult: + features: Mapping[str, Any] + errors: Sequence[str] + warnings: Sequence[str] + + +class TemplateHitFeaturizer: + """A class for turning hhr hits to template features.""" + def __init__( + self, + mmcif_dir: str, + max_template_date: str, + max_hits: int, + kalign_binary_path: str, + release_dates_path: Optional[str] = None, + obsolete_pdbs_path: Optional[str] = None, + strict_error_check: bool = False, + _shuffle_top_k_prefiltered: Optional[int] = None, + _zero_center_positions: bool = True, + ): + """Initializes the Template Search. + + Args: + mmcif_dir: Path to a directory with mmCIF structures. Once a template ID + is found by HHSearch, this directory is used to retrieve the template + data. + max_template_date: The maximum date permitted for template structures. No + template with date higher than this date will be returned. In ISO8601 + date format, YYYY-MM-DD. + max_hits: The maximum number of templates that will be returned. + kalign_binary_path: The path to a kalign executable used for template + realignment. + release_dates_path: An optional path to a file with a mapping from PDB IDs + to their release dates. 
Thanks to this we don't have to redundantly + parse mmCIF files to get that information. + obsolete_pdbs_path: An optional path to a file containing a mapping from + obsolete PDB IDs to the PDB IDs of their replacements. + strict_error_check: If True, then the following will be treated as errors: + * If any template date is after the max_template_date. + * If any template has identical PDB ID to the query. + * If any template is a duplicate of the query. + * Any feature computation errors. + """ + self._mmcif_dir = mmcif_dir + if not glob.glob(os.path.join(self._mmcif_dir, "*.cif")): + logging.error("Could not find CIFs in %s", self._mmcif_dir) + raise ValueError(f"Could not find CIFs in {self._mmcif_dir}") + + try: + self._max_template_date = datetime.datetime.strptime( + max_template_date, "%Y-%m-%d" + ) + except ValueError: + raise ValueError( + "max_template_date must be set and have format YYYY-MM-DD." + ) + self.max_hits = max_hits + self._kalign_binary_path = kalign_binary_path + self._strict_error_check = strict_error_check + + if release_dates_path: + logging.info( + "Using precomputed release dates %s.", release_dates_path + ) + self._release_dates = _parse_release_dates(release_dates_path) + else: + self._release_dates = {} + + if obsolete_pdbs_path: + logging.info( + "Using precomputed obsolete pdbs %s.", obsolete_pdbs_path + ) + self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path) + else: + self._obsolete_pdbs = {} + + self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered + self._zero_center_positions = _zero_center_positions + + def get_templates( + self, + query_sequence: str, + query_pdb_code: Optional[str], + query_release_date: Optional[datetime.datetime], + hits: Sequence[parsers.TemplateHit], + ) -> TemplateSearchResult: + """Computes the templates for given query sequence (more details above).""" + logging.info("Searching for template for: %s", query_pdb_code) + + template_features = {} + for template_feature_name in 
TEMPLATE_FEATURES: + template_features[template_feature_name] = [] + + # Always use a max_template_date. Set to query_release_date minus 60 days + # if that's earlier. + template_cutoff_date = self._max_template_date + if query_release_date: + delta = datetime.timedelta(days=60) + if query_release_date - delta < template_cutoff_date: + template_cutoff_date = query_release_date - delta + assert template_cutoff_date < query_release_date + assert template_cutoff_date <= self._max_template_date + + num_hits = 0 + errors = [] + warnings = [] + + filtered = [] + for hit in hits: + prefilter_result = _prefilter_hit( + query_sequence=query_sequence, + query_pdb_code=query_pdb_code, + hit=hit, + max_template_date=template_cutoff_date, + release_dates=self._release_dates, + obsolete_pdbs=self._obsolete_pdbs, + strict_error_check=self._strict_error_check, + ) + + if prefilter_result.error: + errors.append(prefilter_result.error) + + if prefilter_result.warning: + warnings.append(prefilter_result.warning) + + if prefilter_result.valid: + filtered.append(hit) + + filtered = list( + sorted(filtered, key=lambda x: x.sum_probs, reverse=True) + ) + idx = list(range(len(filtered))) + if(self._shuffle_top_k_prefiltered): + stk = self._shuffle_top_k_prefiltered + idx[:stk] = np.random.permutation(idx[:stk]) + + for i in idx: + # We got all the templates we wanted, stop processing hits. + if num_hits >= self.max_hits: + break + + hit = filtered[i] + + result = _process_single_hit( + query_sequence=query_sequence, + query_pdb_code=query_pdb_code, + hit=hit, + mmcif_dir=self._mmcif_dir, + max_template_date=template_cutoff_date, + release_dates=self._release_dates, + obsolete_pdbs=self._obsolete_pdbs, + strict_error_check=self._strict_error_check, + kalign_binary_path=self._kalign_binary_path, + _zero_center_positions=self._zero_center_positions, + ) + + if result.error: + errors.append(result.error) + + # There could be an error even if there are some results, e.g. 
thrown by + # other unparsable chains in the same mmCIF file. + if result.warning: + warnings.append(result.warning) + + if result.features is None: + logging.info( + "Skipped invalid hit %s, error: %s, warning: %s", + hit.name, + result.error, + result.warning, + ) + else: + # Increment the hit counter, since we got features out of this hit. + num_hits += 1 + for k in template_features: + template_features[k].append(result.features[k]) + + for name in template_features: + if num_hits > 0: + template_features[name] = np.stack( + template_features[name], axis=0 + ).astype(TEMPLATE_FEATURES[name]) + else: + # Make sure the feature has correct dtype even if empty. + template_features[name] = np.array( + [], dtype=TEMPLATE_FEATURES[name] + ) + + return TemplateSearchResult( + features=template_features, errors=errors, warnings=warnings + ) diff --git a/openfold/data/tools/__init__.py b/openfold/data/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/openfold/data/tools/hhblits.py b/openfold/data/tools/hhblits.py new file mode 100644 index 0000000000000000000000000000000000000000..3e1c1e7bfdbf1cf7d6f373f79d8f7d2c354195fb --- /dev/null +++ b/openfold/data/tools/hhblits.py @@ -0,0 +1,175 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Library to run HHblits from Python.""" +import glob +import logging +import os +import subprocess +from typing import Any, Mapping, Optional, Sequence + +from openfold.data.tools import utils + + +_HHBLITS_DEFAULT_P = 20 +_HHBLITS_DEFAULT_Z = 500 + + +class HHBlits: + """Python wrapper of the HHblits binary.""" + + def __init__( + self, + *, + binary_path: str, + databases: Sequence[str], + n_cpu: int = 4, + n_iter: int = 3, + e_value: float = 0.001, + maxseq: int = 1_000_000, + realign_max: int = 100_000, + maxfilt: int = 100_000, + min_prefilter_hits: int = 1000, + all_seqs: bool = False, + alt: Optional[int] = None, + p: int = _HHBLITS_DEFAULT_P, + z: int = _HHBLITS_DEFAULT_Z, + ): + """Initializes the Python HHblits wrapper. + + Args: + binary_path: The path to the HHblits executable. + databases: A sequence of HHblits database paths. This should be the + common prefix for the database files (i.e. up to but not including + _hhm.ffindex etc.) + n_cpu: The number of CPUs to give HHblits. + n_iter: The number of HHblits iterations. + e_value: The E-value, see HHblits docs for more details. + maxseq: The maximum number of rows in an input alignment. Note that this + parameter is only supported in HHBlits version 3.1 and higher. + realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500. + maxfilt: Max number of hits allowed to pass the 2nd prefilter. + HHblits default: 20000. + min_prefilter_hits: Min number of hits to pass prefilter. + HHblits default: 100. + all_seqs: Return all sequences in the MSA / Do not filter the result MSA. + HHblits default: False. + alt: Show up to this many alternative alignments. + p: Minimum Prob for a hit to be included in the output hhr file. + HHblits default: 20. + z: Hard cap on number of hits reported in the hhr file. + HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. + + Raises: + RuntimeError: If HHblits binary not found within the path. 
+ """ + self.binary_path = binary_path + self.databases = databases + + for database_path in self.databases: + if not glob.glob(database_path + "_*"): + logging.error( + "Could not find HHBlits database %s", database_path + ) + raise ValueError( + f"Could not find HHBlits database {database_path}" + ) + + self.n_cpu = n_cpu + self.n_iter = n_iter + self.e_value = e_value + self.maxseq = maxseq + self.realign_max = realign_max + self.maxfilt = maxfilt + self.min_prefilter_hits = min_prefilter_hits + self.all_seqs = all_seqs + self.alt = alt + self.p = p + self.z = z + + def query(self, input_fasta_path: str) -> Mapping[str, Any]: + """Queries the database using HHblits.""" + with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir: + a3m_path = os.path.join(query_tmp_dir, "output.a3m") + + db_cmd = [] + for db_path in self.databases: + db_cmd.append("-d") + db_cmd.append(db_path) + cmd = [ + self.binary_path, + "-i", + input_fasta_path, + "-cpu", + str(self.n_cpu), + "-oa3m", + a3m_path, + "-o", + "/dev/null", + "-n", + str(self.n_iter), + "-e", + str(self.e_value), + "-maxseq", + str(self.maxseq), + "-realign_max", + str(self.realign_max), + "-maxfilt", + str(self.maxfilt), + "-min_prefilter_hits", + str(self.min_prefilter_hits), + ] + if self.all_seqs: + cmd += ["-all"] + if self.alt: + cmd += ["-alt", str(self.alt)] + if self.p != _HHBLITS_DEFAULT_P: + cmd += ["-p", str(self.p)] + if self.z != _HHBLITS_DEFAULT_Z: + cmd += ["-Z", str(self.z)] + cmd += db_cmd + + logging.info('Launching subprocess "%s"', " ".join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + with utils.timing("HHblits query"): + stdout, stderr = process.communicate() + retcode = process.wait() + + if retcode: + # Logs have a 15k character limit, so log HHblits error line by line. + logging.error("HHblits failed. 
HHblits stderr begin:") + for error_line in stderr.decode("utf-8").splitlines(): + if error_line.strip(): + logging.error(error_line.strip()) + logging.error("HHblits stderr end") + raise RuntimeError( + "HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n" + % (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8")) + ) + + with open(a3m_path) as f: + a3m = f.read() + + raw_output = dict( + a3m=a3m, + output=stdout, + stderr=stderr, + n_iter=self.n_iter, + e_value=self.e_value, + ) + return raw_output diff --git a/openfold/data/tools/hhsearch.py b/openfold/data/tools/hhsearch.py new file mode 100644 index 0000000000000000000000000000000000000000..d62e3b410c5bcf1e1089d961c8c44aad81d7f2a2 --- /dev/null +++ b/openfold/data/tools/hhsearch.py @@ -0,0 +1,106 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Library to run HHsearch from Python.""" +import glob +import logging +import os +import subprocess +from typing import Sequence + +from openfold.data.tools import utils + + +class HHSearch: + """Python wrapper of the HHsearch binary.""" + + def __init__( + self, + *, + binary_path: str, + databases: Sequence[str], + n_cpu: int = 2, + maxseq: int = 1_000_000, + ): + """Initializes the Python HHsearch wrapper. + + Args: + binary_path: The path to the HHsearch executable. + databases: A sequence of HHsearch database paths. This should be the + common prefix for the database files (i.e. 
up to but not including + _hhm.ffindex etc.) + n_cpu: The number of CPUs to use + maxseq: The maximum number of rows in an input alignment. Note that this + parameter is only supported in HHBlits version 3.1 and higher. + + Raises: + RuntimeError: If HHsearch binary not found within the path. + """ + self.binary_path = binary_path + self.databases = databases + self.n_cpu = n_cpu + self.maxseq = maxseq + + for database_path in self.databases: + if not glob.glob(database_path + "_*"): + logging.error( + "Could not find HHsearch database %s", database_path + ) + raise ValueError( + f"Could not find HHsearch database {database_path}" + ) + + def query(self, a3m: str) -> str: + """Queries the database using HHsearch using a given a3m.""" + with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir: + input_path = os.path.join(query_tmp_dir, "query.a3m") + hhr_path = os.path.join(query_tmp_dir, "output.hhr") + with open(input_path, "w") as f: + f.write(a3m) + + db_cmd = [] + for db_path in self.databases: + db_cmd.append("-d") + db_cmd.append(db_path) + cmd = [ + self.binary_path, + "-i", + input_path, + "-o", + hhr_path, + "-maxseq", + str(self.maxseq), + "-cpu", + str(self.n_cpu), + ] + db_cmd + + logging.info('Launching subprocess "%s"', " ".join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + with utils.timing("HHsearch query"): + stdout, stderr = process.communicate() + retcode = process.wait() + + if retcode: + # Stderr is truncated to prevent proto size errors in Beam. 
+ raise RuntimeError( + "HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n" + % (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8")) + ) + + with open(hhr_path) as f: + hhr = f.read() + return hhr diff --git a/openfold/data/tools/jackhmmer.py b/openfold/data/tools/jackhmmer.py new file mode 100644 index 0000000000000000000000000000000000000000..fca93edfe08fee383e40c451a647d8ff48230f3c --- /dev/null +++ b/openfold/data/tools/jackhmmer.py @@ -0,0 +1,228 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Library to run Jackhmmer from Python.""" + +from concurrent import futures +import glob +import logging +import os +import subprocess +from typing import Any, Callable, Mapping, Optional, Sequence +from urllib import request + +from openfold.data.tools import utils + + +class Jackhmmer: + """Python wrapper of the Jackhmmer binary.""" + + def __init__( + self, + *, + binary_path: str, + database_path: str, + n_cpu: int = 8, + n_iter: int = 1, + e_value: float = 0.0001, + z_value: Optional[int] = None, + get_tblout: bool = False, + filter_f1: float = 0.0005, + filter_f2: float = 0.00005, + filter_f3: float = 0.0000005, + incdom_e: Optional[float] = None, + dom_e: Optional[float] = None, + num_streamed_chunks: Optional[int] = None, + streaming_callback: Optional[Callable[[int], None]] = None, + ): + """Initializes the Python Jackhmmer wrapper. 
+ + Args: + binary_path: The path to the jackhmmer executable. + database_path: The path to the jackhmmer database (FASTA format). + n_cpu: The number of CPUs to give Jackhmmer. + n_iter: The number of Jackhmmer iterations. + e_value: The E-value, see Jackhmmer docs for more details. + z_value: The Z-value, see Jackhmmer docs for more details. + get_tblout: Whether to save tblout string. + filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. + filter_f2: Viterbi pre-filter, set to >1.0 to turn off. + filter_f3: Forward pre-filter, set to >1.0 to turn off. + incdom_e: Domain e-value criteria for inclusion of domains in MSA/next + round. + dom_e: Domain e-value criteria for inclusion in tblout. + num_streamed_chunks: Number of database chunks to stream over. + streaming_callback: Callback function run after each chunk iteration with + the iteration number as argument. + """ + self.binary_path = binary_path + self.database_path = database_path + self.num_streamed_chunks = num_streamed_chunks + + if ( + not os.path.exists(self.database_path) + and num_streamed_chunks is None + ): + logging.error("Could not find Jackhmmer database %s", database_path) + raise ValueError( + f"Could not find Jackhmmer database {database_path}" + ) + + self.n_cpu = n_cpu + self.n_iter = n_iter + self.e_value = e_value + self.z_value = z_value + self.filter_f1 = filter_f1 + self.filter_f2 = filter_f2 + self.filter_f3 = filter_f3 + self.incdom_e = incdom_e + self.dom_e = dom_e + self.get_tblout = get_tblout + self.streaming_callback = streaming_callback + + def _query_chunk( + self, input_fasta_path: str, database_path: str + ) -> Mapping[str, Any]: + """Queries the database chunk using Jackhmmer.""" + with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir: + sto_path = os.path.join(query_tmp_dir, "output.sto") + + # The F1/F2/F3 are the expected proportion to pass each of the filtering + # stages (which get progressively more expensive), reducing these + # speeds 
up the pipeline at the expensive of sensitivity. They are + # currently set very low to make querying Mgnify run in a reasonable + # amount of time. + cmd_flags = [ + # Don't pollute stdout with Jackhmmer output. + "-o", + "/dev/null", + "-A", + sto_path, + "--noali", + "--F1", + str(self.filter_f1), + "--F2", + str(self.filter_f2), + "--F3", + str(self.filter_f3), + "--incE", + str(self.e_value), + # Report only sequences with E-values <= x in per-sequence output. + "-E", + str(self.e_value), + "--cpu", + str(self.n_cpu), + "-N", + str(self.n_iter), + ] + if self.get_tblout: + tblout_path = os.path.join(query_tmp_dir, "tblout.txt") + cmd_flags.extend(["--tblout", tblout_path]) + + if self.z_value: + cmd_flags.extend(["-Z", str(self.z_value)]) + + if self.dom_e is not None: + cmd_flags.extend(["--domE", str(self.dom_e)]) + + if self.incdom_e is not None: + cmd_flags.extend(["--incdomE", str(self.incdom_e)]) + + cmd = ( + [self.binary_path] + + cmd_flags + + [input_fasta_path, database_path] + ) + + logging.info('Launching subprocess "%s"', " ".join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + with utils.timing( + f"Jackhmmer ({os.path.basename(database_path)}) query" + ): + _, stderr = process.communicate() + retcode = process.wait() + + if retcode: + raise RuntimeError( + "Jackhmmer failed\nstderr:\n%s\n" % stderr.decode("utf-8") + ) + + # Get e-values for each target name + tbl = "" + if self.get_tblout: + with open(tblout_path) as f: + tbl = f.read() + + with open(sto_path) as f: + sto = f.read() + + raw_output = dict( + sto=sto, + tbl=tbl, + stderr=stderr, + n_iter=self.n_iter, + e_value=self.e_value, + ) + + return raw_output + + def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]: + """Queries the database using Jackhmmer.""" + if self.num_streamed_chunks is None: + return [self._query_chunk(input_fasta_path, self.database_path)] + + db_basename = os.path.basename(self.database_path) + 
db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}" + db_local_chunk = lambda db_idx: f"/tmp/ramdisk/{db_basename}.{db_idx}" + + # Remove existing files to prevent OOM + for f in glob.glob(db_local_chunk("[0-9]*")): + try: + os.remove(f) + except OSError: + print(f"OSError while deleting {f}") + + # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk + with futures.ThreadPoolExecutor(max_workers=2) as executor: + chunked_output = [] + for i in range(1, self.num_streamed_chunks + 1): + # Copy the chunk locally + if i == 1: + future = executor.submit( + request.urlretrieve, + db_remote_chunk(i), + db_local_chunk(i), + ) + if i < self.num_streamed_chunks: + next_future = executor.submit( + request.urlretrieve, + db_remote_chunk(i + 1), + db_local_chunk(i + 1), + ) + + # Run Jackhmmer with the chunk + future.result() + chunked_output.append( + self._query_chunk(input_fasta_path, db_local_chunk(i)) + ) + + # Remove the local copy of the chunk + os.remove(db_local_chunk(i)) + future = next_future + if self.streaming_callback: + self.streaming_callback(i) + return chunked_output diff --git a/openfold/data/tools/kalign.py b/openfold/data/tools/kalign.py new file mode 100644 index 0000000000000000000000000000000000000000..5306f864bc220dbd979fd1ae556df56c3e65ff71 --- /dev/null +++ b/openfold/data/tools/kalign.py @@ -0,0 +1,115 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Python wrapper for Kalign.""" +import os +import subprocess +from typing import Sequence + +from absl import logging + +from openfold.data.tools import utils + + +def _to_a3m(sequences: Sequence[str]) -> str: + """Converts sequences to an a3m file.""" + names = ["sequence %d" % i for i in range(1, len(sequences) + 1)] + a3m = [] + for sequence, name in zip(sequences, names): + a3m.append(u">" + name + u"\n") + a3m.append(sequence + u"\n") + return "".join(a3m) + + +class Kalign: + """Python wrapper of the Kalign binary.""" + + def __init__(self, *, binary_path: str): + """Initializes the Python Kalign wrapper. + + Args: + binary_path: The path to the Kalign binary. + + Raises: + RuntimeError: If Kalign binary not found within the path. + """ + self.binary_path = binary_path + + def align(self, sequences: Sequence[str]) -> str: + """Aligns the sequences and returns the alignment in A3M string. + + Args: + sequences: A list of query sequence strings. The sequences have to be at + least 6 residues long (Kalign requires this). Note that the order in + which you give the sequences might alter the output slightly as + different alignment tree might get constructed. + + Returns: + A string with the alignment in a3m format. + + Raises: + RuntimeError: If Kalign fails. + ValueError: If any of the sequences is less than 6 residues long. + """ + logging.info("Aligning %d sequences", len(sequences)) + + for s in sequences: + if len(s) < 6: + raise ValueError( + "Kalign requires all sequences to be at least 6 " + "residues long. Got %s (%d residues)." 
% (s, len(s)) + ) + + with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir: + input_fasta_path = os.path.join(query_tmp_dir, "input.fasta") + output_a3m_path = os.path.join(query_tmp_dir, "output.a3m") + + with open(input_fasta_path, "w") as f: + f.write(_to_a3m(sequences)) + + cmd = [ + self.binary_path, + "-i", + input_fasta_path, + "-o", + output_a3m_path, + "-format", + "fasta", + ] + + logging.info('Launching subprocess "%s"', " ".join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + with utils.timing("Kalign query"): + stdout, stderr = process.communicate() + retcode = process.wait() + logging.info( + "Kalign stdout:\n%s\n\nstderr:\n%s\n", + stdout.decode("utf-8"), + stderr.decode("utf-8"), + ) + + if retcode: + raise RuntimeError( + "Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n" + % (stdout.decode("utf-8"), stderr.decode("utf-8")) + ) + + with open(output_a3m_path) as f: + a3m = f.read() + + return a3m diff --git a/openfold/data/tools/utils.py b/openfold/data/tools/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c9e8684f1e3e6f20f9098c16009836ac737106c2 --- /dev/null +++ b/openfold/data/tools/utils.py @@ -0,0 +1,48 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
    """Yields a fresh temporary directory and removes it on exit.

    The directory is created with ``tempfile.mkdtemp`` under ``base_dir``
    and deleted (errors ignored) even if the managed body raises.
    """
    scratch_dir = tempfile.mkdtemp(dir=base_dir)
    try:
        yield scratch_dir
    finally:
        shutil.rmtree(scratch_dir, ignore_errors=True)


@contextlib.contextmanager
def timing(msg: str):
    """Logs the start of ``msg`` and, on normal exit, its wall-clock duration."""
    logging.info("Started %s", msg)
    started_at = time.perf_counter()
    yield
    elapsed = time.perf_counter() - started_at
    logging.info("Finished %s in %.3f seconds", msg, elapsed)


def to_date(s: str):
    """Builds a ``datetime.datetime`` from the fixed positions of ``s``.

    Characters 0-3 are read as the year, 5-6 as the month and 8-9 as the
    day; the separator characters in between are not inspected.
    """
    year = int(s[:4])
    month = int(s[5:7])
    day = int(s[8:10])
    return datetime.datetime(year=year, month=month, day=day)
+ m, __name__)) for m in __all__] +for _m in _modules: + globals()[_m[0]] = _m[1] + +# Avoid needlessly cluttering the global namespace +del _files, _m, _modules diff --git a/openfold/model/__pycache__/__init__.cpython-310.pyc b/openfold/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dfd49ba79ae06d6b030200c99200cde4dbb34d5 Binary files /dev/null and b/openfold/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/openfold/model/__pycache__/dropout.cpython-310.pyc b/openfold/model/__pycache__/dropout.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6397ee233b8e3a9f5cb266c7a728682262b20c4 Binary files /dev/null and b/openfold/model/__pycache__/dropout.cpython-310.pyc differ diff --git a/openfold/model/__pycache__/embedders.cpython-310.pyc b/openfold/model/__pycache__/embedders.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4b71ba49c974cff62e3984e5929ac3bdcc2b453 Binary files /dev/null and b/openfold/model/__pycache__/embedders.cpython-310.pyc differ diff --git a/openfold/model/__pycache__/evoformer.cpython-310.pyc b/openfold/model/__pycache__/evoformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..203fe5b4b8db1cd767c96f0eff244f84b90d6623 Binary files /dev/null and b/openfold/model/__pycache__/evoformer.cpython-310.pyc differ diff --git a/openfold/model/__pycache__/heads.cpython-310.pyc b/openfold/model/__pycache__/heads.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd6b3bd70406b11a71d1bd3dea71682ca43dfffa Binary files /dev/null and b/openfold/model/__pycache__/heads.cpython-310.pyc differ diff --git a/openfold/model/__pycache__/model.cpython-310.pyc b/openfold/model/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfb8bb6e63c21cd5ee96cc84904409239864f4fd Binary files /dev/null and 
b/openfold/model/__pycache__/model.cpython-310.pyc differ diff --git a/openfold/model/__pycache__/msa.cpython-310.pyc b/openfold/model/__pycache__/msa.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7af39fa36da6d3535a25d69877a645baeb91c46 Binary files /dev/null and b/openfold/model/__pycache__/msa.cpython-310.pyc differ diff --git a/openfold/model/__pycache__/outer_product_mean.cpython-310.pyc b/openfold/model/__pycache__/outer_product_mean.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6af94ce1b8963ca6664cfda25f3fad3e8349a2d Binary files /dev/null and b/openfold/model/__pycache__/outer_product_mean.cpython-310.pyc differ diff --git a/openfold/model/__pycache__/pair_transition.cpython-310.pyc b/openfold/model/__pycache__/pair_transition.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cf7918be7f8591ae0a45ef31e131cd72cac260e Binary files /dev/null and b/openfold/model/__pycache__/pair_transition.cpython-310.pyc differ diff --git a/openfold/model/__pycache__/primitives.cpython-310.pyc b/openfold/model/__pycache__/primitives.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8c357d467780e29c9dc8a85647c72a30337725d Binary files /dev/null and b/openfold/model/__pycache__/primitives.cpython-310.pyc differ diff --git a/openfold/model/__pycache__/structure_module.cpython-310.pyc b/openfold/model/__pycache__/structure_module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e32c64bb867cf233488161e2a7d8a03740698911 Binary files /dev/null and b/openfold/model/__pycache__/structure_module.cpython-310.pyc differ diff --git a/openfold/model/__pycache__/template.cpython-310.pyc b/openfold/model/__pycache__/template.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e2614459d6d5e5c9ccdb9fd13a1bcc845464d4e Binary files /dev/null and 
b/openfold/model/__pycache__/template.cpython-310.pyc differ diff --git a/openfold/model/__pycache__/torchscript.cpython-310.pyc b/openfold/model/__pycache__/torchscript.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cbfc71f0e8a3230e5063a74c0fee2b15368c6a9 Binary files /dev/null and b/openfold/model/__pycache__/torchscript.cpython-310.pyc differ diff --git a/openfold/model/__pycache__/triangular_attention.cpython-310.pyc b/openfold/model/__pycache__/triangular_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab34565e5dcc515d04a6d1b8f8ba2d683110c49f Binary files /dev/null and b/openfold/model/__pycache__/triangular_attention.cpython-310.pyc differ diff --git a/openfold/model/__pycache__/triangular_multiplicative_update.cpython-310.pyc b/openfold/model/__pycache__/triangular_multiplicative_update.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a49270d3fc7837b5ad9561a01a37184d5d13122 Binary files /dev/null and b/openfold/model/__pycache__/triangular_multiplicative_update.cpython-310.pyc differ diff --git a/openfold/model/dropout.py b/openfold/model/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..651b9775ef44fba20dec75c60703f00beac66e0c --- /dev/null +++ b/openfold/model/dropout.py @@ -0,0 +1,78 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class Dropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask
    along a particular dimension.

    If not in training mode, this module computes the identity function.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        """
        Args:
            r:
                Dropout rate
            batch_dim:
                Dimension(s) along which the dropout mask is shared
        """
        super().__init__()

        self.r = r
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                Tensor to which dropout is applied. Can have any shape
                compatible with self.batch_dim
        Returns:
            A new tensor equal to x multiplied by the (rescaled) dropout
            mask; x itself is left unmodified.
        """
        # The mask has size 1 along every shared dimension so that it
        # broadcasts, i.e. the same entries are dropped along those dims.
        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1
        mask = self.dropout(x.new_ones(shape))

        # Out-of-place on purpose: an in-place `x *= mask` silently mutates
        # the caller's tensor and raises for leaf tensors that require grad.
        return x * mask


class DropoutRowwise(Dropout):
    """
    Convenience class for rowwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-3)


class DropoutColumnwise(Dropout):
    """
    Convenience class for columnwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-2)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from typing import Tuple + +from openfold.model.primitives import Linear, LayerNorm +from openfold.utils.tensor_utils import one_hot + + +class InputEmbedder(nn.Module): + """ + Embeds a subset of the input features. + + Implements Algorithms 3 (InputEmbedder) and 4 (relpos). + """ + + def __init__( + self, + tf_dim: int, + msa_dim: int, + c_z: int, + c_m: int, + relpos_k: int, + **kwargs, + ): + """ + Args: + tf_dim: + Final dimension of the target features + msa_dim: + Final dimension of the MSA features + c_z: + Pair embedding dimension + c_m: + MSA embedding dimension + relpos_k: + Window size used in relative positional encoding + """ + super(InputEmbedder, self).__init__() + + self.tf_dim = tf_dim + self.msa_dim = msa_dim + + self.c_z = c_z + self.c_m = c_m + + self.linear_tf_z_i = Linear(tf_dim, c_z) + self.linear_tf_z_j = Linear(tf_dim, c_z) + self.linear_tf_m = Linear(tf_dim, c_m) + self.linear_msa_m = Linear(msa_dim, c_m) + + # RPE stuff + self.relpos_k = relpos_k + self.no_bins = 2 * relpos_k + 1 + self.linear_relpos = Linear(self.no_bins, c_z) + + def relpos(self, ri: torch.Tensor): + """ + Computes relative positional encodings + + Implements Algorithm 4. 
class RecyclingEmbedder(nn.Module):
    """
    Embeds the output of an iteration of the model for recycling.

    Implements Algorithm 32.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        min_bin: float,
        max_bin: float,
        no_bins: int,
        inf: float = 1e8,
        **kwargs,
    ):
        """
        Args:
            c_m:
                MSA channel dimension
            c_z:
                Pair embedding channel dimension
            min_bin:
                Smallest distogram bin (Angstroms)
            max_bin:
                Largest distogram bin (Angstroms)
            no_bins:
                Number of distogram bins
            inf:
                Value used as the (squared) upper edge of the last bin
        """
        super().__init__()

        self.c_m = c_m
        self.c_z = c_z
        self.min_bin = min_bin
        self.max_bin = max_bin
        self.no_bins = no_bins
        self.inf = inf

        # Lazily built in forward() to match the input's dtype and device.
        self.bins = None

        self.linear = Linear(self.no_bins, self.c_z)
        self.layer_norm_m = LayerNorm(self.c_m)
        self.layer_norm_z = LayerNorm(self.c_z)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            m:
                First row of the MSA embedding. [*, N_res, C_m]
            z:
                [*, N_res, N_res, C_z] pair embedding
            x:
                [*, N_res, 3] predicted C_beta coordinates
        Returns:
            m:
                [*, N_res, C_m] MSA embedding update
            z:
                [*, N_res, N_res, C_z] pair embedding update
        """
        # self.bins is a plain attribute, not a registered buffer, so it does
        # not follow the module through .to()/.half()/.double(). Rebuild it
        # whenever it no longer matches the coordinates being embedded.
        if (
            self.bins is None
            or self.bins.dtype != x.dtype
            or self.bins.device != x.device
        ):
            self.bins = torch.linspace(
                self.min_bin,
                self.max_bin,
                self.no_bins,
                dtype=x.dtype,
                device=x.device,
                requires_grad=False,
            )

        # [*, N, C_m]
        m_update = self.layer_norm_m(m)

        # This squared method might become problematic in FP16 mode.
        # Squared distances are compared against squared bin edges.
        squared_bins = self.bins ** 2
        upper = torch.cat(
            [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
        )
        # [*, N, N, 1] pairwise squared distances
        d = torch.sum(
            (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True
        )

        # [*, N, N, no_bins]: 1 where lower edge < d < upper edge
        d = ((d > squared_bins) * (d < upper)).type(x.dtype)

        # [*, N, N, C_z]
        d = self.linear(d)
        z_update = d + self.layer_norm_z(z)

        return m_update, z_update
class TemplatePairEmbedder(nn.Module):
    """
    Linear embedding of the "template_pair_feat" features.

    Implements Algorithm 2, line 9.
    """

    def __init__(
        self,
        c_in: int,
        c_out: int,
        **kwargs,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_out:
                Output channel dimension
        """
        super(TemplatePairEmbedder, self).__init__()

        self.c_in, self.c_out = c_in, c_out

        # Despite there being no relu nearby, the source uses that initializer
        self.linear = Linear(self.c_in, self.c_out, init="relu")

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, C_in] input tensor
        Returns:
            [*, C_out] output tensor
        """
        return self.linear(x)


class ExtraMSAEmbedder(nn.Module):
    """
    Linear embedding of the unclustered MSA sequences.

    Implements Algorithm 2, line 15
    """

    def __init__(
        self,
        c_in: int,
        c_out: int,
        **kwargs,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_out:
                Output channel dimension
        """
        super(ExtraMSAEmbedder, self).__init__()

        self.c_in, self.c_out = c_in, c_out

        self.linear = Linear(self.c_in, self.c_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                [*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
        Returns:
            [*, N_extra_seq, N_res, C_out] embedding
        """
        return self.linear(x)
class MSATransition(nn.Module):
    """
    Position-wise feed-forward network for MSA activations:
    layer norm, expand by a factor n, ReLU, project back, then mask.

    Implements Algorithm 9
    """
    def __init__(self, c_m, n):
        """
        Args:
            c_m:
                MSA channel dimension
            n:
                Factor multiplied to c_m to obtain the hidden channel
                dimension
        """
        super(MSATransition, self).__init__()

        self.c_m = c_m
        self.n = n

        c_hidden = self.n * self.c_m
        self.layer_norm = LayerNorm(self.c_m)
        self.linear_1 = Linear(self.c_m, c_hidden, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(c_hidden, self.c_m, init="final")

    def _transition(self, m, mask):
        # Two-layer MLP; masked positions are zeroed in the output.
        hidden = self.relu(self.linear_1(m))
        return self.linear_2(hidden) * mask

    @torch.jit.ignore
    def _chunk(self,
        m: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        # Everything except the trailing two dims of m counts as batch dims.
        inputs = {"m": m, "mask": mask}
        return chunk_layer(
            self._transition,
            inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA activation
            mask:
                [*, N_seq, N_res] MSA mask (a trailing channel dim is
                added here); defaults to all ones
            chunk_size:
                If given, the transition is evaluated through chunk_layer
                with this chunk size
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA activation update
        """
        # DISCREPANCY: DeepMind forgets to apply the MSA mask here.
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        # [*, N_seq, N_res, 1] so that it broadcasts over channels
        mask = mask.unsqueeze(-1)

        m = self.layer_norm(m)

        if chunk_size is None:
            return self._transition(m, mask)
        return self._chunk(m, mask, chunk_size)
class EvoformerBlock(nn.Module):
    """
    One Evoformer block: MSA row attention with pair bias (followed by
    row-wise dropout), MSA column attention, then the shared
    EvoformerBlockCore, each MSA update added residually.
    """
    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
    ):
        super(EvoformerBlock, self).__init__()

        self.msa_att_row = MSARowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
        )

        self.msa_att_col = MSAColumnAttention(
            c_m,
            c_hidden_msa_att,
            no_heads_msa,
            inf=inf,
        )

        self.msa_dropout_layer = DropoutRowwise(msa_dropout)

        core_config = dict(
            c_m=c_m,
            c_z=c_z,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
        )
        self.core = EvoformerBlockCore(**core_config)

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            m: MSA embedding
            z: pair embedding
            msa_mask: mask passed to both MSA attention modules and the core
            pair_mask: mask passed to the core
            chunk_size: forwarded unchanged to every sub-module
            _mask_trans: forwarded to the core
        Returns:
            The updated (m, z) pair produced by the core.
        """
        # Row-wise attention with pair bias, dropout applied to the update.
        row_update = self.msa_att_row(
            m, z=z, mask=msa_mask, chunk_size=chunk_size
        )
        m = m + self.msa_dropout_layer(row_update)

        # Column-wise attention (no dropout on this update).
        col_update = self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
        m = m + col_update

        m, z = self.core(
            m,
            z,
            msa_mask=msa_mask,
            pair_mask=pair_mask,
            chunk_size=chunk_size,
            _mask_trans=_mask_trans,
        )

        return m, z
+ if(torch.is_grad_enabled()): + m1 = m1 + m2 + else: + m1 += m2 + + return m1 + + m = add(m, self.msa_dropout_layer( + self.msa_att_row( + m.clone() if torch.is_grad_enabled() else m, + z=z.clone() if torch.is_grad_enabled() else z, + mask=msa_mask, + chunk_size=chunk_size, + _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None, + _checkpoint_chunks= + self.ckpt if torch.is_grad_enabled() else False, + ) + )) + + def fn(m, z): + m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)) + m, z = self.core( + m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size + ) + + return m, z + + if(torch.is_grad_enabled() and self.ckpt): + checkpoint_fn = get_checkpoint_fn() + m, z = checkpoint_fn(fn, m, z) + else: + m, z = fn(m, z) + + return m, z + + +class EvoformerStack(nn.Module): + """ + Main Evoformer trunk. + + Implements Algorithm 6. + """ + + def __init__( + self, + c_m: int, + c_z: int, + c_hidden_msa_att: int, + c_hidden_opm: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + c_s: int, + no_heads_msa: int, + no_heads_pair: int, + no_blocks: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + blocks_per_ckpt: int, + inf: float, + eps: float, + clear_cache_between_blocks: bool = False, + **kwargs, + ): + """ + Args: + c_m: + MSA channel dimension + c_z: + Pair channel dimension + c_hidden_msa_att: + Hidden dimension in MSA attention + c_hidden_opm: + Hidden dimension in outer product mean module + c_hidden_mul: + Hidden dimension in multiplicative updates + c_hidden_pair_att: + Hidden dimension in triangular attention + c_s: + Channel dimension of the output "single" embedding + no_heads_msa: + Number of heads used for MSA attention + no_heads_pair: + Number of heads used for pair attention + no_blocks: + Number of Evoformer blocks in the stack + transition_n: + Factor by which to multiply c_m to obtain the MSATransition + hidden dimension + msa_dropout: + Dropout rate for MSA activations + pair_dropout: 
+ Dropout used for pair activations + blocks_per_ckpt: + Number of Evoformer blocks in each activation checkpoint + clear_cache_between_blocks: + Whether to clear CUDA's GPU memory cache between blocks of the + stack. Slows down each block but can reduce fragmentation + """ + super(EvoformerStack, self).__init__() + + self.blocks_per_ckpt = blocks_per_ckpt + self.clear_cache_between_blocks = clear_cache_between_blocks + + self.blocks = nn.ModuleList() + + for _ in range(no_blocks): + block = EvoformerBlock( + c_m=c_m, + c_z=c_z, + c_hidden_msa_att=c_hidden_msa_att, + c_hidden_opm=c_hidden_opm, + c_hidden_mul=c_hidden_mul, + c_hidden_pair_att=c_hidden_pair_att, + no_heads_msa=no_heads_msa, + no_heads_pair=no_heads_pair, + transition_n=transition_n, + msa_dropout=msa_dropout, + pair_dropout=pair_dropout, + inf=inf, + eps=eps, + ) + self.blocks.append(block) + + self.linear = Linear(c_m, c_s) + + def forward(self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + chunk_size: int, + _mask_trans: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + m: + [*, N_seq, N_res, C_m] MSA embedding + z: + [*, N_res, N_res, C_z] pair embedding + msa_mask: + [*, N_seq, N_res] MSA mask + pair_mask: + [*, N_res, N_res] pair mask + Returns: + m: + [*, N_seq, N_res, C_m] MSA embedding + z: + [*, N_res, N_res, C_z] pair embedding + s: + [*, N_res, C_s] single embedding (or None if extra MSA stack) + """ + blocks = [ + partial( + b, + msa_mask=msa_mask, + pair_mask=pair_mask, + chunk_size=chunk_size, + _mask_trans=_mask_trans, + ) + for b in self.blocks + ] + + if(self.clear_cache_between_blocks): + def block_with_cache_clear(block, *args): + torch.cuda.empty_cache() + return block(*args) + + blocks = [partial(block_with_cache_clear, b) for b in blocks] + + m, z = checkpoint_blocks( + blocks, + args=(m, z), + blocks_per_ckpt=self.blocks_per_ckpt if self.training else None, + ) + + s = self.linear(m[..., 
class ExtraMSAStack(nn.Module):
    """
    Implements Algorithm 18.

    A sequential stack of ExtraMSABlock modules. Every block consumes the
    extra MSA embedding together with the pair embedding and returns updated
    versions of both; the stack itself only returns the final pair embedding.
    """

    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        no_blocks: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
        ckpt: bool,
        clear_cache_between_blocks: bool = False,
        **kwargs,
    ):
        """
        Args:
            c_m, c_z, c_hidden_msa_att, c_hidden_opm, c_hidden_mul,
            c_hidden_pair_att, no_heads_msa, no_heads_pair, transition_n,
            msa_dropout, pair_dropout, inf, eps, ckpt:
                Hyperparameters handed unchanged to every ExtraMSABlock
            no_blocks:
                Number of ExtraMSABlock modules in the stack
            clear_cache_between_blocks:
                Whether to call torch.cuda.empty_cache() after each block
            kwargs:
                Accepted and ignored, so that a config section containing
                additional keys can be splatted into the constructor
        """
        super(ExtraMSAStack, self).__init__()

        self.clear_cache_between_blocks = clear_cache_between_blocks

        # All blocks share one set of hyperparameters
        block_config = dict(
            c_m=c_m,
            c_z=c_z,
            c_hidden_msa_att=c_hidden_msa_att,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            msa_dropout=msa_dropout,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
            ckpt=ckpt,
        )
        self.blocks = nn.ModuleList(
            [ExtraMSABlock(**block_config) for _ in range(no_blocks)]
        )

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        chunk_size: int,
        msa_mask: Optional[torch.Tensor] = None,
        pair_mask: Optional[torch.Tensor] = None,
        _mask_trans: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_extra, N_res, C_m] extra MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            chunk_size:
                Chunk size forwarded to every block
            msa_mask:
                Optional [*, N_extra, N_res] MSA mask
            pair_mask:
                Optional [*, N_res, N_res] pair mask
            _mask_trans:
                Accepted for signature compatibility with the other stacks;
                it is not passed on to the blocks
        Returns:
            [*, N_res, N_res, C_z] pair update
        """
        for block in self.blocks:
            m, z = block(m, z, msa_mask, pair_mask, chunk_size=chunk_size)

            if self.clear_cache_between_blocks:
                torch.cuda.empty_cache()

        return z
class PerResidueLDDTCaPredictor(nn.Module):
    """
    Small MLP head that turns an embedding into lDDT bin logits.

    The input is layer-normalised, sent through two ReLU-activated linear
    layers of width c_hidden and finally projected to no_bins logits.
    """

    def __init__(self, no_bins, c_in, c_hidden):
        """
        Args:
            no_bins:
                Number of output bins
            c_in:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
        """
        super().__init__()

        self.no_bins, self.c_in, self.c_hidden = no_bins, c_in, c_hidden

        # Creation order fixes the parameter registration order
        self.layer_norm = LayerNorm(c_in)
        self.linear_1 = Linear(c_in, c_hidden, init="relu")
        self.linear_2 = Linear(c_hidden, c_hidden, init="relu")
        self.linear_3 = Linear(c_hidden, no_bins, init="final")
        self.relu = nn.ReLU()

    def forward(self, s):
        """
        Args:
            s:
                [*, C_in] embedding
        Returns:
            [*, no_bins] logits
        """
        hidden = self.layer_norm(s)
        for layer in (self.linear_1, self.linear_2):
            hidden = self.relu(layer(hidden))

        return self.linear_3(hidden)
class ExperimentallyResolvedHead(nn.Module):
    """
    For use in computation of "experimentally resolved" loss, subsection
    1.9.10
    """

    def __init__(self, c_s, c_out, **kwargs):
        """
        Args:
            c_s:
                Input channel dimension
            c_out:
                Output channel dimension (number of logits per residue)
            kwargs:
                Ignored; lets a config section with extra keys be splatted
                into the constructor
        """
        super().__init__()

        self.c_s, self.c_out = c_s, c_out
        self.linear = Linear(c_s, c_out, init="final")

    def forward(self, s):
        """
        Args:
            s:
                [*, N_res, C_s] single embedding
        Returns:
            [*, N, C_out] logits
        """
        # A single linear projection: [*, N, C_s] -> [*, N, C_out]
        return self.linear(s)
class AlphaFold(nn.Module):
    """
    Alphafold 2.

    Implements Algorithm 2 (but with training).
    """

    def __init__(self, config):
        """
        Args:
            config:
                A dict-like config object (like the one in config.py)
        """
        super(AlphaFold, self).__init__()

        # config.globals is kept separately; from here on `config` refers to
        # the model sub-config only (this is also what self.config stores).
        self.globals = config.globals
        config = config.model
        template_config = config.template
        extra_msa_config = config.extra_msa

        # Main trunk + structure module
        self.input_embedder = InputEmbedder(
            **config["input_embedder"],
        )
        self.recycling_embedder = RecyclingEmbedder(
            **config["recycling_embedder"],
        )
        self.template_angle_embedder = TemplateAngleEmbedder(
            **template_config["template_angle_embedder"],
        )
        self.template_pair_embedder = TemplatePairEmbedder(
            **template_config["template_pair_embedder"],
        )
        self.template_pair_stack = TemplatePairStack(
            **template_config["template_pair_stack"],
        )
        self.template_pointwise_att = TemplatePointwiseAttention(
            **template_config["template_pointwise_attention"],
        )
        self.extra_msa_embedder = ExtraMSAEmbedder(
            **extra_msa_config["extra_msa_embedder"],
        )
        self.extra_msa_stack = ExtraMSAStack(
            **extra_msa_config["extra_msa_stack"],
        )
        self.evoformer = EvoformerStack(
            **config["evoformer_stack"],
        )
        self.structure_module = StructureModule(
            **config["structure_module"],
        )

        self.aux_heads = AuxiliaryHeads(
            config["heads"],
        )

        self.config = config

    def embed_templates(self, batch, z, pair_mask, templ_dim):
        """
        Embeds the template features and reduces them to a pair update.

        Args:
            batch:
                Dictionary of template features. Must contain
                "template_aatype" and "template_mask"; the remaining keys
                are consumed by the feature builders
            z:
                Pair embedding; its dtype is used for the template tensors
            pair_mask:
                Pair mask (unsqueezed at dim -3 before entering the
                template pair stack)
            templ_dim:
                Index of the template dimension in the batch tensors
        Returns:
            Dictionary with "template_pair_embedding" and, if
            config.template.embed_angles is set, "template_angle_embedding"
        """
        # Embed the templates one at a time (with a poor man's vmap)
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
        for i in range(n_templ):
            # index_select with a 0-dim index tensor
            idx = batch["template_aatype"].new_tensor(i)
            single_template_feats = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, idx),
                batch,
            )

            single_template_embeds = {}
            if self.config.template.embed_angles:
                template_angle_feat = build_template_angle_feat(
                    single_template_feats,
                )

                # [*, S_t, N, C_m]
                a = self.template_angle_embedder(template_angle_feat)

                single_template_embeds["angle"] = a

            # [*, S_t, N, N, C_t]
            t = build_template_pair_feat(
                single_template_feats,
                inf=self.config.template.inf,
                eps=self.config.template.eps,
                **self.config.template.distogram,
            ).to(z.dtype)
            t = self.template_pair_embedder(t)

            single_template_embeds.update({"pair": t})

            template_embeds.append(single_template_embeds)

        # Re-assemble the per-template embeddings along the template dim
        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
            template_embeds,
        )

        # [*, S_t, N, N, C_z]
        t = self.template_pair_stack(
            template_embeds["pair"],
            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
            _mask_trans=self.config._mask_trans,
        )

        # [*, N, N, C_z]
        t = self.template_pointwise_att(
            t,
            z,
            template_mask=batch["template_mask"].to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
        )
        # Zero the update when the template mask sums to zero
        t = t * (torch.sum(batch["template_mask"]) > 0)

        ret = {}
        if self.config.template.embed_angles:
            ret["template_angle_embedding"] = template_embeds["angle"]

        ret.update({"template_pair_embedding": t})

        return ret

    def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
        """
        Runs a single recycling iteration of the model.

        Args:
            feats:
                Feature dictionary for the current recycling cycle. Note
                that float32 entries are cast in place to the parameter
                dtype
            m_1_prev:
                First MSA row from the previous iteration, or None
            z_prev:
                Pair embedding from the previous iteration, or None
            x_prev:
                Atom positions from the previous iteration, or None
            _recycle:
                If False, the recycling embeddings are multiplied by zero
        Returns:
            A tuple (outputs, m_1_prev, z_prev, x_prev) where outputs is
            the output dictionary and the remaining entries are the values
            to feed into the next iteration
        """
        # Primary output dictionary
        outputs = {}

        # This needs to be done manually for DeepSpeed's sake
        dtype = next(self.parameters()).dtype
        for k in feats:
            if(feats[k].dtype == torch.float32):
                feats[k] = feats[k].to(dtype=dtype)

        # Grab some data about the input
        batch_dims = feats["target_feat"].shape[:-2]
        no_batch_dims = len(batch_dims)
        n = feats["target_feat"].shape[-2]
        n_seq = feats["msa_feat"].shape[-3]
        # NOTE(review): `device` is not referenced again in this method
        device = feats["target_feat"].device

        # Prep some features
        seq_mask = feats["seq_mask"]
        pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
        msa_mask = feats["msa_mask"]

        # Initialize the MSA and pair representations

        # m: [*, S_c, N, C_m]
        # z: [*, N, N, C_z]
        m, z = self.input_embedder(
            feats["target_feat"],
            feats["residue_index"],
            feats["msa_feat"],
        )

        # Initialize the recycling embeddings, if needs be
        if None in [m_1_prev, z_prev, x_prev]:
            # [*, N, C_m]
            m_1_prev = m.new_zeros(
                (*batch_dims, n, self.config.input_embedder.c_m),
                requires_grad=False,
            )

            # [*, N, N, C_z]
            z_prev = z.new_zeros(
                (*batch_dims, n, n, self.config.input_embedder.c_z),
                requires_grad=False,
            )

            # [*, N, atom_type_num, 3]
            x_prev = z.new_zeros(
                (*batch_dims, n, residue_constants.atom_type_num, 3),
                requires_grad=False,
            )

        # Applied on every iteration, whether x_prev was just zero-initialized
        # or carried over from the previous iteration
        x_prev = pseudo_beta_fn(
            feats["aatype"], x_prev, None
        ).to(dtype=z.dtype)

        # m_1_prev_emb: [*, N, C_m]
        # z_prev_emb: [*, N, N, C_z]
        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
            m_1_prev,
            z_prev,
            x_prev,
        )

        # If the number of recycling iterations is 0, skip recycling
        # altogether. We zero them this way instead of computing them
        # conditionally to avoid leaving parameters unused, which has annoying
        # implications for DDP training.
        if(not _recycle):
            m_1_prev_emb *= 0
            z_prev_emb *= 0

        # [*, S_c, N, C_m]
        m[..., 0, :, :] += m_1_prev_emb

        # [*, N, N, C_z]
        z += z_prev_emb

        # Possibly prevents memory fragmentation
        del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb

        # Embed the templates + merge with MSA/pair embeddings
        if self.config.template.enabled:
            template_feats = {
                k: v for k, v in feats.items() if k.startswith("template_")
            }
            template_embeds = self.embed_templates(
                template_feats,
                z,
                pair_mask.to(dtype=z.dtype),
                no_batch_dims,
            )

            # [*, N, N, C_z]
            z = z + template_embeds["template_pair_embedding"]

            if self.config.template.embed_angles:
                # [*, S = S_c + S_t, N, C_m]
                m = torch.cat(
                    [m, template_embeds["template_angle_embedding"]],
                    dim=-3
                )

                # [*, S, N]
                torsion_angles_mask = feats["template_torsion_angles_mask"]
                msa_mask = torch.cat(
                    [feats["msa_mask"], torsion_angles_mask[..., 2]],
                    dim=-2
                )

        # Embed extra MSA features + merge with pairwise embeddings
        if self.config.extra_msa.enabled:
            # [*, S_e, N, C_e]
            a = self.extra_msa_embedder(build_extra_msa_feat(feats))

            # [*, N, N, C_z]
            z = self.extra_msa_stack(
                a,
                z,
                msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
                chunk_size=self.globals.chunk_size,
                pair_mask=pair_mask.to(dtype=z.dtype),
                _mask_trans=self.config._mask_trans,
            )

        # Run MSA + pair embeddings through the trunk of the network
        # m: [*, S, N, C_m]
        # z: [*, N, N, C_z]
        # s: [*, N, C_s]
        m, z, s = self.evoformer(
            m,
            z,
            msa_mask=msa_mask.to(dtype=m.dtype),
            pair_mask=pair_mask.to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
            _mask_trans=self.config._mask_trans,
        )

        # Only the first n_seq rows are returned, which drops any template
        # angle rows concatenated above
        outputs["msa"] = m[..., :n_seq, :, :]
        outputs["pair"] = z
        outputs["single"] = s

        # Predict 3D structure
        outputs["sm"] = self.structure_module(
            s,
            z,
            feats["aatype"],
            mask=feats["seq_mask"].to(dtype=s.dtype),
        )
        outputs["final_atom_positions"] = atom14_to_atom37(
            outputs["sm"]["positions"][-1], feats
        )
        outputs["final_atom_mask"] = feats["atom37_atom_exists"]
        outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]

        # Save embeddings for use during the next recycling iteration

        # [*, N, C_m]
        m_1_prev = m[..., 0, :, :]

        # [*, N, N, C_z]
        z_prev = z

        # Atom positions as returned by atom14_to_atom37
        x_prev = outputs["final_atom_positions"]

        return outputs, m_1_prev, z_prev, x_prev

    def _disable_activation_checkpointing(self):
        """
        Turns activation checkpointing off in the template pair stack, the
        Evoformer and every extra MSA block.
        """
        self.template_pair_stack.blocks_per_ckpt = None
        self.evoformer.blocks_per_ckpt = None

        for b in self.extra_msa_stack.blocks:
            b.ckpt = False

    def _enable_activation_checkpointing(self):
        """
        Restores the activation checkpointing settings stored in the config.
        """
        self.template_pair_stack.blocks_per_ckpt = (
            self.config.template.template_pair_stack.blocks_per_ckpt
        )
        self.evoformer.blocks_per_ckpt = (
            self.config.evoformer_stack.blocks_per_ckpt
        )

        for b in self.extra_msa_stack.blocks:
            b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt

    def forward(self, batch):
        """
        Args:
            batch:
                Dictionary of arguments outlined in Algorithm 2. Keys must
                include the official names of the features in the
                supplement subsection 1.2.9.

                The final dimension of each input must have length equal to
                the number of recycling iterations.

                Features (without the recycling dimension):

                    "aatype" ([*, N_res]):
                        Contrary to the supplement, this tensor of residue
                        indices is not one-hot.
                    "target_feat" ([*, N_res, C_tf])
                        One-hot encoding of the target sequence. C_tf is
                        config.model.input_embedder.tf_dim.
                    "residue_index" ([*, N_res])
                        Tensor whose final dimension consists of
                        consecutive indices from 0 to N_res.
                    "msa_feat" ([*, N_seq, N_res, C_msa])
                        MSA features, constructed as in the supplement.
                        C_msa is config.model.input_embedder.msa_dim.
                    "seq_mask" ([*, N_res])
                        1-D sequence mask
                    "msa_mask" ([*, N_seq, N_res])
                        MSA mask
                    "pair_mask" ([*, N_res, N_res])
                        2-D pair mask
                    "extra_msa_mask" ([*, N_extra, N_res])
                        Extra MSA mask
                    "template_mask" ([*, N_templ])
                        Template mask (on the level of templates, not
                        residues)
                    "template_aatype" ([*, N_templ, N_res])
                        Tensor of template residue indices (indices greater
                        than 19 are clamped to 20 (Unknown))
                    "template_all_atom_positions"
                        ([*, N_templ, N_res, 37, 3])
                        Template atom coordinates in atom37 format
                    "template_all_atom_mask" ([*, N_templ, N_res, 37])
                        Template atom coordinate mask
                    "template_pseudo_beta" ([*, N_templ, N_res, 3])
                        Positions of template carbon "pseudo-beta" atoms
                        (i.e. C_beta for all residues but glycine, for
                        which C_alpha is used instead)
                    "template_pseudo_beta_mask" ([*, N_templ, N_res])
                        Pseudo-beta mask
        Returns:
            The output dictionary of the final recycling iteration, updated
            with the outputs of the auxiliary heads
        """
        # Initialize recycling embeddings
        m_1_prev, z_prev, x_prev = None, None, None

        # Disable activation checkpointing for the first few recycling iters
        is_grad_enabled = torch.is_grad_enabled()
        self._disable_activation_checkpointing()

        # Main recycling loop
        # The number of iterations is read off the trailing (recycling)
        # dimension of "aatype"
        num_iters = batch["aatype"].shape[-1]
        for cycle_no in range(num_iters):
            # Select the features for the current recycling cycle
            fetch_cur_batch = lambda t: t[..., cycle_no]
            feats = tensor_tree_map(fetch_cur_batch, batch)

            # Enable grad iff we're training and it's the final recycling layer
            is_final_iter = cycle_no == (num_iters - 1)
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                if is_final_iter:
                    self._enable_activation_checkpointing()
                    # Sidestep AMP bug (PyTorch issue #65766)
                    if torch.is_autocast_enabled():
                        torch.clear_autocast_cache()

                # Run the next iteration of the model
                outputs, m_1_prev, z_prev, x_prev = self.iteration(
                    feats,
                    m_1_prev,
                    z_prev,
                    x_prev,
                    _recycle=(num_iters > 1)
                )

        # Run auxiliary heads
        outputs.update(self.aux_heads(outputs))

        return outputs
class MSAAttention(nn.Module):
    """
    Multi-head attention over an MSA embedding, with an optional additive
    bias derived from the pair embedding.
    """

    def __init__(
        self,
        c_in,
        c_hidden,
        no_heads,
        pair_bias=False,
        c_z=None,
        inf=1e9,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            pair_bias:
                Whether to use pair embedding bias
            c_z:
                Pair embedding channel dimension. Ignored unless pair_bias
                is true
            inf:
                A large number to be used in computing the attention mask
        """
        super(MSAAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.pair_bias = pair_bias
        self.c_z = c_z
        self.inf = inf

        self.layer_norm_m = LayerNorm(self.c_in)

        # The pair-bias modules only exist when pair_bias is set; they stay
        # None otherwise and are None-checked in _prep_inputs.
        self.layer_norm_z = None
        self.linear_z = None
        if self.pair_bias:
            self.layer_norm_z = LayerNorm(self.c_z)
            self.linear_z = Linear(
                self.c_z, self.no_heads, bias=False, init="normal"
            )

        self.mha = Attention(
            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
        )

    @torch.jit.ignore
    def _chunk(self,
        m: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        """
        Runs self.mha through chunk_layer, using m as both query and
        key/value input and treating all but the last two dimensions of m
        as batch dimensions.
        """
        return chunk_layer(
            self.mha,
            {"q_x": m, "kv_x": m, "biases": biases},
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def _prep_inputs(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor],
        mask: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Layer-normalises m, builds the additive mask bias and, if pair bias
        is enabled and z is given, projects z to a per-head bias.

        Returns:
            (m, mask_bias, z). NOTE(review): z is returned unchanged, and
            may therefore be None, when the pair-bias branch is skipped,
            even though the annotation lists three tensors.
        """
        # [*, N_seq, N_res, C_m]
        m = self.layer_norm_m(m)

        n_seq, n_res = m.shape[-3:-1]
        if mask is None:
            # [*, N_seq, N_res]
            mask = m.new_ones(
                m.shape[:-3] + (n_seq, n_res),
            )

        # Masked positions (mask == 0) get a bias of -inf, kept ones get 0
        # [*, N_seq, 1, 1, N_res]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # This step simply returns a larger view of the bias, and does not
        # consume additional memory.
        # [*, N_seq, no_heads, N_res, N_res]
        #bias = bias.expand(
        #    ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
        #)

        if (self.pair_bias and
            z is not None and                       # For the
            self.layer_norm_z is not None and       # benefit of
            self.linear_z is not None               # TorchScript
        ):
            # [*, N_res, N_res, C_z]
            z = self.layer_norm_z(z)

            # [*, N_res, N_res, no_heads]
            z = self.linear_z(z)

            # [*, 1, no_heads, N_res, N_res]
            z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)

        return m, mask_bias, z

    @torch.jit.ignore
    def _chunked_msa_attn(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor],
        mask: Optional[torch.Tensor],
        chunk_logits: int,
        checkpoint: bool,
    ) -> torch.Tensor:
        """
        Alternative attention path that computes the attention through
        _attention_chunked_trainable, chunked along the MSA dimension with
        chunks of size chunk_logits. If checkpoint is set and grad is
        enabled, the input preparation and the output projection are run
        through the checkpoint function as well.
        """
        MSA_DIM = -4

        def _get_qkv(m, z):
            m, mask_bias, z = self._prep_inputs(m, z, mask)
            q, k, v = self.mha._prep_qkv(m, m)
            return m, q, k, v, mask_bias, z

        checkpoint_fn = get_checkpoint_fn()

        if(torch.is_grad_enabled() and checkpoint):
            m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
        else:
            m, q, k, v, mask_bias, z = _get_qkv(m, z)

        # NOTE(review): z is passed as a bias even when it is None;
        # presumably _attention_chunked_trainable tolerates that, confirm
        o = _attention_chunked_trainable(
            query=q,
            key=k,
            value=v,
            biases=[mask_bias, z],
            chunk_size=chunk_logits,
            chunk_dim=MSA_DIM,
            checkpoint=checkpoint,
        )

        if(torch.is_grad_enabled() and checkpoint):
            # Storing an additional m here is far from ideal
            m = checkpoint_fn(self.mha._wrap_up, o, m)
        else:
            m = self.mha._wrap_up(o, m)

        return m

    def forward(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        _chunk_logits: Optional[int] = None,
        _checkpoint_chunks: Optional[bool] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding. Required only if
                pair_bias is True
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Size of chunks into which the inputs are split along their
                batch dimensions. A low value decreases memory overhead at the
                cost of slower execution. Chunking is not performed by default.
            _chunk_logits:
                If not None, the chunked attention path
                (_chunked_msa_attn) is used with this chunk size, and
                chunk_size is ignored
            _checkpoint_chunks:
                Forwarded as the checkpoint flag of the chunked attention
                path
        Returns:
            The output of the attention module
        """
        if(_chunk_logits is not None):
            return self._chunked_msa_attn(
                m=m, z=z, mask=mask,
                chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
            )

        m, mask_bias, z = self._prep_inputs(m, z, mask)

        # The pair bias is only appended when it exists
        biases = [mask_bias]
        if(z is not None):
            biases.append(z)

        if chunk_size is not None:
            m = self._chunk(m, biases, chunk_size)
        else:
            m = self.mha(
                q_x=m,
                kv_x=m,
                biases=biases
            )

        return m
class MSAColumnGlobalAttention(nn.Module):
    """
    Column-wise global attention over an MSA embedding.

    The embedding is transposed so that the sequence dimension is the
    second-to-last one, layer-normalised, passed through GlobalAttention and
    transposed back.
    """

    def __init__(
        self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large number forwarded to GlobalAttention
            eps:
                Small number forwarded to GlobalAttention
        """
        super(MSAColumnGlobalAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf
        self.eps = eps

        self.layer_norm_m = nn.LayerNorm(c_in)

        self.global_attention = GlobalAttention(
            c_in=c_in,
            c_hidden=c_hidden,
            no_heads=no_heads,
            inf=inf,
            eps=eps,
        )

    @torch.jit.ignore
    def _chunk(self,
        m: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        """
        Runs the global attention through chunk_layer, treating all but the
        last two dimensions of m as batch dimensions.
        """
        mha_input = {
            "m": m,
            "mask": mask,
        }
        return chunk_layer(
            self.global_attention,
            mha_input,
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_in] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask. Defaults to all ones
            chunk_size:
                If given, the attention is evaluated in chunks of this size
        Returns:
            [*, N_seq, N_res, C_in] MSA embedding update
        """
        if mask is None:
            # [*, N_seq, N_res]
            # new_ones inherits dtype and device from m and never requires
            # grad, matching how the sibling modules build default masks
            mask = m.new_ones(m.shape[:-1])

        # [*, N_res, N_seq, C_in]
        m = m.transpose(-2, -3)
        mask = mask.transpose(-1, -2)

        # [*, N_res, N_seq, C_in]
        m = self.layer_norm_m(m)

        if chunk_size is not None:
            m = self._chunk(m, mask, chunk_size)
        else:
            m = self.global_attention(m=m, mask=mask)

        # [*, N_seq, N_res, C_in]
        m = m.transpose(-2, -3)

        return m
+ """ + + def __init__(self, c_m, c_z, c_hidden, eps=1e-3): + """ + Args: + c_m: + MSA embedding channel dimension + c_z: + Pair embedding channel dimension + c_hidden: + Hidden channel dimension + """ + super(OuterProductMean, self).__init__() + + self.c_m = c_m + self.c_z = c_z + self.c_hidden = c_hidden + self.eps = eps + + self.layer_norm = nn.LayerNorm(c_m) + self.linear_1 = Linear(c_m, c_hidden) + self.linear_2 = Linear(c_m, c_hidden) + self.linear_out = Linear(c_hidden ** 2, c_z, init="final") + + def _opm(self, a, b): + # [*, N_res, N_res, C, C] + outer = torch.einsum("...bac,...dae->...bdce", a, b) + + # [*, N_res, N_res, C * C] + outer = outer.reshape(outer.shape[:-2] + (-1,)) + + # [*, N_res, N_res, C_z] + outer = self.linear_out(outer) + + return outer + + @torch.jit.ignore + def _chunk(self, + a: torch.Tensor, + b: torch.Tensor, + chunk_size: int + ) -> torch.Tensor: + # Since the "batch dim" in this case is not a true batch dimension + # (in that the shape of the output depends on it), we need to + # iterate over it ourselves + a_reshape = a.reshape((-1,) + a.shape[-3:]) + b_reshape = b.reshape((-1,) + b.shape[-3:]) + out = [] + for a_prime, b_prime in zip(a_reshape, b_reshape): + outer = chunk_layer( + partial(self._opm, b=b_prime), + {"a": a_prime}, + chunk_size=chunk_size, + no_batch_dims=1, + ) + out.append(outer) + outer = torch.stack(out, dim=0) + outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) + + return outer + + def forward(self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None + ) -> torch.Tensor: + """ + Args: + m: + [*, N_seq, N_res, C_m] MSA embedding + mask: + [*, N_seq, N_res] MSA mask + Returns: + [*, N_res, N_res, C_z] pair embedding update + """ + if mask is None: + mask = m.new_ones(m.shape[:-1]) + + # [*, N_seq, N_res, C_m] + m = self.layer_norm(m) + + # [*, N_seq, N_res, C] + mask = mask.unsqueeze(-1) + a = self.linear_1(m) * mask + b = self.linear_2(m) * mask + + a = 
a.transpose(-2, -3) + b = b.transpose(-2, -3) + + if chunk_size is not None: + outer = self._chunk(a, b, chunk_size) + else: + outer = self._opm(a, b) + + # [*, N_res, N_res, 1] + norm = torch.einsum("...abc,...adc->...bdc", mask, mask) + + # [*, N_res, N_res, C_z] + outer = outer / (self.eps + norm) + + return outer diff --git a/openfold/model/pair_transition.py b/openfold/model/pair_transition.py new file mode 100644 index 0000000000000000000000000000000000000000..c81d2a4c0e087ffdb0c98d8cb0e91c5b24d6a0e8 --- /dev/null +++ b/openfold/model/pair_transition.py @@ -0,0 +1,99 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +import torch.nn as nn + +from openfold.model.primitives import Linear, LayerNorm +from openfold.utils.tensor_utils import chunk_layer + + +class PairTransition(nn.Module): + """ + Implements Algorithm 15. 
+ """ + + def __init__(self, c_z, n): + """ + Args: + c_z: + Pair transition channel dimension + n: + Factor by which c_z is multiplied to obtain hidden channel + dimension + """ + super(PairTransition, self).__init__() + + self.c_z = c_z + self.n = n + + self.layer_norm = LayerNorm(self.c_z) + self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu") + self.relu = nn.ReLU() + self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") + + def _transition(self, z, mask): + # [*, N_res, N_res, C_hidden] + z = self.linear_1(z) + z = self.relu(z) + + # [*, N_res, N_res, C_z] + z = self.linear_2(z) * mask + + return z + + @torch.jit.ignore + def _chunk(self, + z: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self._transition, + {"z": z, "mask": mask}, + chunk_size=chunk_size, + no_batch_dims=len(z.shape[:-2]), + ) + + + def forward(self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + """ + Args: + z: + [*, N_res, N_res, C_z] pair embedding + Returns: + [*, N_res, N_res, C_z] pair embedding update + """ + # DISCREPANCY: DeepMind forgets to apply the mask in this module. + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + # [*, N_res, N_res, 1] + mask = mask.unsqueeze(-1) + + # [*, N_res, N_res, C_z] + z = self.layer_norm(z) + + if chunk_size is not None: + z = self._chunk(z, mask, chunk_size) + else: + z = self._transition(z=z, mask=mask) + + return z diff --git a/openfold/model/primitives.py b/openfold/model/primitives.py new file mode 100644 index 0000000000000000000000000000000000000000..0742038e49b327661c3b6a137918f500e883cd04 --- /dev/null +++ b/openfold/model/primitives.py @@ -0,0 +1,587 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +import math +from typing import Optional, Callable, List, Tuple, Sequence +import numpy as np + +import deepspeed +import torch +import torch.nn as nn +from scipy.stats import truncnorm + +from openfold.utils.checkpointing import get_checkpoint_fn +from openfold.utils.tensor_utils import ( + permute_final_dims, + flatten_final_dims, + _chunk_slice, +) + + +def _prod(nums): + out = 1 + for n in nums: + out = out * n + return out + + +def _calculate_fan(linear_weight_shape, fan="fan_in"): + fan_out, fan_in = linear_weight_shape + + if fan == "fan_in": + f = fan_in + elif fan == "fan_out": + f = fan_out + elif fan == "fan_avg": + f = (fan_in + fan_out) / 2 + else: + raise ValueError("Invalid fan option") + + return f + + +def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): + shape = weights.shape + f = _calculate_fan(shape, fan) + scale = scale / max(1, f) + a = -2 + b = 2 + std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1) + size = _prod(shape) + samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size) + samples = np.reshape(samples, shape) + with torch.no_grad(): + weights.copy_(torch.tensor(samples, device=weights.device)) + + +def lecun_normal_init_(weights): + trunc_normal_init_(weights, scale=1.0) + + +def he_normal_init_(weights): + trunc_normal_init_(weights, scale=2.0) + + +def glorot_uniform_init_(weights): + nn.init.xavier_uniform_(weights, gain=1) + + +def final_init_(weights): + with torch.no_grad(): + weights.fill_(0.0) + + +def gating_init_(weights): + with torch.no_grad(): + 
weights.fill_(0.0) + + +def normal_init_(weights): + torch.nn.init.kaiming_normal_(weights, nonlinearity="linear") + + +def ipa_point_weights_init_(weights): + with torch.no_grad(): + softplus_inverse_1 = 0.541324854612918 + weights.fill_(softplus_inverse_1) + + +class Linear(nn.Linear): + """ + A Linear layer with built-in nonstandard initializations. Called just + like torch.nn.Linear. + + Implements the initializers in 1.11.4, plus some additional ones found + in the code. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + bias: bool = True, + init: str = "default", + init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, + ): + """ + Args: + in_dim: + The final dimension of inputs to the layer + out_dim: + The final dimension of layer outputs + bias: + Whether to learn an additive bias. True by default + init: + The initializer to use. Choose from: + + "default": LeCun fan-in truncated normal initialization + "relu": He initialization w/ truncated normal distribution + "glorot": Fan-average Glorot uniform initialization + "gating": Weights=0, Bias=1 + "normal": Normal initialization with std=1/sqrt(fan_in) + "final": Weights=0, Bias=0 + + Overridden by init_fn if the latter is not None. + init_fn: + A custom initializer taking weight and bias as inputs. + Overrides init if not None. 
+ """ + super(Linear, self).__init__(in_dim, out_dim, bias=bias) + + if bias: + with torch.no_grad(): + self.bias.fill_(0) + + if init_fn is not None: + init_fn(self.weight, self.bias) + else: + if init == "default": + lecun_normal_init_(self.weight) + elif init == "relu": + he_normal_init_(self.weight) + elif init == "glorot": + glorot_uniform_init_(self.weight) + elif init == "gating": + gating_init_(self.weight) + if bias: + with torch.no_grad(): + self.bias.fill_(1.0) + elif init == "normal": + normal_init_(self.weight) + elif init == "final": + final_init_(self.weight) + else: + raise ValueError("Invalid init string.") + + +class LayerNorm(nn.Module): + def __init__(self, c_in, eps=1e-5): + super(LayerNorm, self).__init__() + + self.c_in = (c_in,) + self.eps = eps + + self.weight = nn.Parameter(torch.ones(c_in)) + self.bias = nn.Parameter(torch.zeros(c_in)) + + def forward(self, x): + d = x.dtype + if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()): + with torch.cuda.amp.autocast(enabled=False): + out = nn.functional.layer_norm( + x, + self.c_in, + self.weight.to(dtype=d), + self.bias.to(dtype=d), + self.eps + ) + else: + out = nn.functional.layer_norm( + x, + self.c_in, + self.weight, + self.bias, + self.eps, + ) + + return out + +@torch.jit.ignore +def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Softmax, but without automatic casting to fp32 when the input is of + type bfloat16 + """ + d = t.dtype + if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()): + with torch.cuda.amp.autocast(enabled=False): + s = torch.nn.functional.softmax(t, dim=dim) + else: + s = torch.nn.functional.softmax(t, dim=dim) + + return s + + +#@torch.jit.script +def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor: + # [*, H, Q, C_hidden] + query = permute_final_dims(query, (1, 0, 2)) + + # [*, H, C_hidden, K] + key = permute_final_dims(key, (1, 2, 0)) + + # [*, H, V, 
C_hidden] + value = permute_final_dims(value, (1, 0, 2)) + + # [*, H, Q, K] + a = torch.matmul(query, key) + + for b in biases: + a += b + + a = softmax(a, -1) + + # [*, H, Q, C_hidden] + a = torch.matmul(a, value) + + # [*, Q, H, C_hidden] + a = a.transpose(-2, -3) + + return a + + +@torch.jit.ignore +def _attention_chunked_trainable( + query, key, value, biases, chunk_size, chunk_dim, checkpoint, +): + if(checkpoint and len(biases) > 2): + raise ValueError( + "Checkpointed version permits only permits two bias terms" + ) + + def _checkpointable_attention(q, k, v, b1, b2): + bs = [b for b in [b1, b2] if b is not None] + return _attention(q, k, v, bs) + + o_chunks = [] + checkpoint_fn = get_checkpoint_fn() + count = query.shape[chunk_dim] + for start in range(0, count, chunk_size): + end = start + chunk_size + idx = [slice(None)] * len(query.shape) + idx[chunk_dim] = slice(start, end) + idx_tup = tuple(idx) + q_chunk = query[idx_tup] + k_chunk = key[idx_tup] + v_chunk = value[idx_tup] + + def _slice_bias(b): + idx[chunk_dim] = ( + slice(start, end) if b.shape[chunk_dim] != 1 else slice(None) + ) + return b[tuple(idx)] + + if(checkpoint): + bias_1_chunk, bias_2_chunk = [ + _slice_bias(b) if b is not None else None + for b in (biases + [None, None])[:2] + ] + + o_chunk = checkpoint_fn(_checkpointable_attention, + q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk + ) + else: + bias_chunks = [ + _slice_bias(b) for b in biases + ] + + o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks) + + o_chunks.append(o_chunk) + + o = torch.cat(o_chunks, dim=chunk_dim) + return o + + +class Attention(nn.Module): + """ + Standard multi-head attention using AlphaFold's default layer + initialization. Allows multiple bias vectors. 
+ """ + def __init__( + self, + c_q: int, + c_k: int, + c_v: int, + c_hidden: int, + no_heads: int, + gating: bool = True, + ): + """ + Args: + c_q: + Input dimension of query data + c_k: + Input dimension of key data + c_v: + Input dimension of value data + c_hidden: + Per-head hidden dimension + no_heads: + Number of attention heads + gating: + Whether the output should be gated using query data + """ + super(Attention, self).__init__() + + self.c_q = c_q + self.c_k = c_k + self.c_v = c_v + self.c_hidden = c_hidden + self.no_heads = no_heads + self.gating = gating + + # DISCREPANCY: c_hidden is not the per-head channel dimension, as + # stated in the supplement, but the overall channel dimension. + + self.linear_q = Linear( + self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot" + ) + self.linear_k = Linear( + self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot" + ) + self.linear_v = Linear( + self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot" + ) + self.linear_o = Linear( + self.c_hidden * self.no_heads, self.c_q, init="final" + ) + + self.linear_g = None + if self.gating: + self.linear_g = Linear( + self.c_q, self.c_hidden * self.no_heads, init="gating" + ) + + self.sigmoid = nn.Sigmoid() + + def _prep_qkv(self, + q_x: torch.Tensor, + kv_x: torch.Tensor + ) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor + ]: + # [*, Q/K/V, H * C_hidden] + q = self.linear_q(q_x) + k = self.linear_k(kv_x) + v = self.linear_v(kv_x) + + # [*, Q/K, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + k = k.view(k.shape[:-1] + (self.no_heads, -1)) + v = v.view(v.shape[:-1] + (self.no_heads, -1)) + + q /= math.sqrt(self.c_hidden) + + return q, k, v + + def _wrap_up(self, + o: torch.Tensor, + q_x: torch.Tensor + ) -> torch.Tensor: + if(self.linear_g is not None): + g = self.sigmoid(self.linear_g(q_x)) + + # [*, Q, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + o = o * g + + # [*, Q, H * C_hidden] + o = 
flatten_final_dims(o, 2) + + # [*, Q, C_q] + o = self.linear_o(o) + + return o + + def forward( + self, + q_x: torch.Tensor, + kv_x: torch.Tensor, + biases: Optional[List[torch.Tensor]] = None, + use_lma: bool = False, + q_chunk_size: Optional[int] = None, + kv_chunk_size: Optional[int] = None, + ) -> torch.Tensor: + """ + Args: + q_x: + [*, Q, C_q] query data + kv_x: + [*, K, C_k] key data + biases: + List of biases that broadcast to [*, H, Q, K] + use_lma: + Whether to use low-memory attention + q_chunk_size: + Query chunk size (for LMA) + kv_chunk_size: + Key/Value chunk size (for LMA) + Returns + [*, Q, C_q] attention update + """ + if(biases is None): + biases = [] + if(use_lma and (q_chunk_size is None or kv_chunk_size is None)): + raise ValueError( + "If use_lma is specified, q_chunk_size and kv_chunk_size must " + "be provided" + ) + + q, k, v = self._prep_qkv(q_x, kv_x) + + if(use_lma): + biases = [ + b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) + for b in biases + ] + + o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size) + else: + o = _attention(q, k, v, biases) + + o = self._wrap_up(o, q_x) + + return o + + +class GlobalAttention(nn.Module): + def __init__(self, c_in, c_hidden, no_heads, inf, eps): + super(GlobalAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.inf = inf + self.eps = eps + + self.linear_q = Linear( + c_in, c_hidden * no_heads, bias=False, init="glorot" + ) + + self.linear_k = Linear( + c_in, c_hidden, bias=False, init="glorot", + ) + self.linear_v = Linear( + c_in, c_hidden, bias=False, init="glorot", + ) + self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating") + self.linear_o = Linear(c_hidden * no_heads, c_in, init="final") + + self.sigmoid = nn.Sigmoid() + + def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # [*, N_res, C_in] + q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / ( + torch.sum(mask, dim=-1)[..., None] + 
self.eps + ) + + # [*, N_res, H * C_hidden] + q = self.linear_q(q) + q *= (self.c_hidden ** (-0.5)) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, N_seq, C_hidden] + k = self.linear_k(m) + v = self.linear_v(m) + + # [*, N_res, H, N_seq] + a = torch.matmul( + q, + k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq] + ) + bias = (self.inf * (mask - 1))[..., :, None, :] + a += bias + a = softmax(a) + + # [*, N_res, H, C_hidden] + o = torch.matmul( + a, + v, + ) + + # [*, N_res, N_seq, C_hidden] + g = self.sigmoid(self.linear_g(m)) + + # [*, N_res, N_seq, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, N_seq, H, C_hidden] + o = o.unsqueeze(-3) * g + + # [*, N_res, N_seq, H * C_hidden] + o = o.reshape(o.shape[:-2] + (-1,)) + + # [*, N_res, N_seq, C_in] + m = self.linear_o(o) + + return m + + +def _lma( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + biases: List[torch.Tensor], + q_chunk_size: int, + kv_chunk_size: int, +): + no_q, no_kv = q.shape[-3], k.shape[-3] + + # [*, Q, H, C_hidden] + o = q.new_zeros(q.shape) + for q_s in range(0, no_q, q_chunk_size): + q_chunk = q[..., q_s: q_s + q_chunk_size, :, :] + large_bias_chunks = [ + b[..., q_s: q_s + q_chunk_size, :] for b in biases + ] + + maxes = [] + weights = [] + values = [] + for kv_s in range(0, no_kv, kv_chunk_size): + k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :] + v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :] + small_bias_chunks = [ + b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks + ] + + a = torch.einsum( + "...qhd,...khd->...hqk", q_chunk, k_chunk, + ) + + for b in small_bias_chunks: + a += b + + a = a.transpose(-2, -3) + + max_a = torch.max(a, dim=-1, keepdim=True)[0] + exp_a = torch.exp(a - max_a) + exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a) + + maxes.append(max_a.detach().squeeze(-1)) + weights.append(torch.sum(exp_a, dim=-1)) + values.append(exp_v) + + chunk_max = 
torch.stack(maxes, dim=-3) + chunk_weights = torch.stack(weights, dim=-3) + chunk_values = torch.stack(values, dim=-4) + + global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0] + max_diffs = torch.exp(chunk_max - global_max) + chunk_values *= max_diffs.unsqueeze(-1) + chunk_weights *= max_diffs + + all_values = torch.sum(chunk_values, dim=-4) + all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4) + + q_chunk_out = all_values / all_weights + + o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out + + return o diff --git a/openfold/model/structure_module.py b/openfold/model/structure_module.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6ba1fe42b995b123a43095bc51a947999b102b --- /dev/null +++ b/openfold/model/structure_module.py @@ -0,0 +1,820 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from functools import reduce +import importlib +import math +import sys +from operator import mul + +import torch +import torch.nn as nn +from typing import Optional, Tuple, Sequence + +from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_ +from openfold.np.residue_constants import ( + restype_rigid_group_default_frame, + restype_atom14_to_rigid_group, + restype_atom14_mask, + restype_atom14_rigid_group_positions, +) +from openfold.utils.feats import ( + frames_and_literature_positions_to_atom14_pos, + torsion_angles_to_frames, +) +from openfold.utils.precision_utils import is_fp16_enabled +from openfold.utils.rigid_utils import Rotation, Rigid +from openfold.utils.tensor_utils import ( + dict_multimap, + permute_final_dims, + flatten_final_dims, +) + +# attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda") + + +class AngleResnetBlock(nn.Module): + def __init__(self, c_hidden): + """ + Args: + c_hidden: + Hidden channel dimension + """ + super(AngleResnetBlock, self).__init__() + + self.c_hidden = c_hidden + + self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu") + self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final") + + self.relu = nn.ReLU() + + def forward(self, a: torch.Tensor) -> torch.Tensor: + + s_initial = a + + a = self.relu(a) + a = self.linear_1(a) + a = self.relu(a) + a = self.linear_2(a) + + return a + s_initial + + +class AngleResnet(nn.Module): + """ + Implements Algorithm 20, lines 11-14 + """ + + def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon): + """ + Args: + c_in: + Input channel dimension + c_hidden: + Hidden channel dimension + no_blocks: + Number of resnet blocks + no_angles: + Number of torsion angles to generate + epsilon: + Small constant for normalization + """ + super(AngleResnet, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_blocks = no_blocks + self.no_angles = no_angles + self.eps = epsilon + + self.linear_in = 
Linear(self.c_in, self.c_hidden) + self.linear_initial = Linear(self.c_in, self.c_hidden) + + self.layers = nn.ModuleList() + for _ in range(self.no_blocks): + layer = AngleResnetBlock(c_hidden=self.c_hidden) + self.layers.append(layer) + + self.linear_out = Linear(self.c_hidden, self.no_angles * 2) + + self.relu = nn.ReLU() + + def forward( + self, s: torch.Tensor, s_initial: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + s: + [*, C_hidden] single embedding + s_initial: + [*, C_hidden] single embedding as of the start of the + StructureModule + Returns: + [*, no_angles, 2] predicted angles + """ + # NOTE: The ReLU's applied to the inputs are absent from the supplement + # pseudocode but present in the source. For maximal compatibility with + # the pretrained weights, I'm going with the source. + + # [*, C_hidden] + s_initial = self.relu(s_initial) + s_initial = self.linear_initial(s_initial) + s = self.relu(s) + s = self.linear_in(s) + s = s + s_initial + + for l in self.layers: + s = l(s) + + s = self.relu(s) + + # [*, no_angles * 2] + s = self.linear_out(s) + + # [*, no_angles, 2] + s = s.view(s.shape[:-1] + (-1, 2)) + + unnormalized_s = s + norm_denom = torch.sqrt( + torch.clamp( + torch.sum(s ** 2, dim=-1, keepdim=True), + min=self.eps, + ) + ) + s = s / norm_denom + + return unnormalized_s, s + + +class InvariantPointAttention(nn.Module): + """ + Implements Algorithm 22. 
+ """ + def __init__( + self, + c_s: int, + c_z: int, + c_hidden: int, + no_heads: int, + no_qk_points: int, + no_v_points: int, + inf: float = 1e5, + eps: float = 1e-8, + ): + """ + Args: + c_s: + Single representation channel dimension + c_z: + Pair representation channel dimension + c_hidden: + Hidden channel dimension + no_heads: + Number of attention heads + no_qk_points: + Number of query/key points to generate + no_v_points: + Number of value points to generate + """ + super(InvariantPointAttention, self).__init__() + + self.c_s = c_s + self.c_z = c_z + self.c_hidden = c_hidden + self.no_heads = no_heads + self.no_qk_points = no_qk_points + self.no_v_points = no_v_points + self.inf = inf + self.eps = eps + + # These linear layers differ from their specifications in the + # supplement. There, they lack bias and use Glorot initialization. + # Here as in the official source, they have bias and use the default + # Lecun initialization. + hc = self.c_hidden * self.no_heads + self.linear_q = Linear(self.c_s, hc) + self.linear_kv = Linear(self.c_s, 2 * hc) + + hpq = self.no_heads * self.no_qk_points * 3 + self.linear_q_points = Linear(self.c_s, hpq) + + hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3 + self.linear_kv_points = Linear(self.c_s, hpkv) + + hpv = self.no_heads * self.no_v_points * 3 + + self.linear_b = Linear(self.c_z, self.no_heads) + + self.head_weights = nn.Parameter(torch.zeros((no_heads))) + ipa_point_weights_init_(self.head_weights) + + concat_out_dim = self.no_heads * ( + self.c_z + self.c_hidden + self.no_v_points * 4 + ) + self.linear_out = Linear(concat_out_dim, self.c_s, init="final") + + self.softmax = nn.Softmax(dim=-1) + self.softplus = nn.Softplus() + + def forward( + self, + s: torch.Tensor, + z: Optional[torch.Tensor], + r: Rigid, + mask: torch.Tensor, + inplace_safe: bool = False, + _offload_inference: bool = False, + _z_reference_list: Optional[Sequence[torch.Tensor]] = None, + ) -> torch.Tensor: + """ + Args: + s: + 
[*, N_res, C_s] single representation + z: + [*, N_res, N_res, C_z] pair representation + r: + [*, N_res] transformation object + mask: + [*, N_res] mask + Returns: + [*, N_res, C_s] single representation update + """ + if(_offload_inference and inplace_safe): + z = _z_reference_list + else: + z = [z] + + ####################################### + # Generate scalar and point activations + ####################################### + # [*, N_res, H * C_hidden] + q = self.linear_q(s) + kv = self.linear_kv(s) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, 2 * C_hidden] + kv = kv.view(kv.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, C_hidden] + k, v = torch.split(kv, self.c_hidden, dim=-1) + + # [*, N_res, H * P_q * 3] + q_pts = self.linear_q_points(s) + + # This is kind of clunky, but it's how the original does it + # [*, N_res, H * P_q, 3] + q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) + q_pts = torch.stack(q_pts, dim=-1) + q_pts = r[..., None].apply(q_pts) + + # [*, N_res, H, P_q, 3] + q_pts = q_pts.view( + q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3) + ) + + # [*, N_res, H * (P_q + P_v) * 3] + kv_pts = self.linear_kv_points(s) + + # [*, N_res, H * (P_q + P_v), 3] + kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) + kv_pts = torch.stack(kv_pts, dim=-1) + kv_pts = r[..., None].apply(kv_pts) + + # [*, N_res, H, (P_q + P_v), 3] + kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3)) + + # [*, N_res, H, P_q/P_v, 3] + k_pts, v_pts = torch.split( + kv_pts, [self.no_qk_points, self.no_v_points], dim=-2 + ) + + ########################## + # Compute attention scores + ########################## + # [*, N_res, N_res, H] + b = self.linear_b(z[0]) + + if(_offload_inference): + assert(sys.getrefcount(z[0]) == 2) + z[0] = z[0].cpu() + + # [*, H, N_res, N_res] + if(is_fp16_enabled()): + with torch.cuda.amp.autocast(enabled=False): + a = torch.matmul( + 
permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + else: + a = torch.matmul( + permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + + a *= math.sqrt(1.0 / (3 * self.c_hidden)) + a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))) + + # [*, N_res, N_res, H, P_q, 3] + pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + if(inplace_safe): + pt_att *= pt_att + else: + pt_att = pt_att ** 2 + + # [*, N_res, N_res, H, P_q] + pt_att = sum(torch.unbind(pt_att, dim=-1)) + head_weights = self.softplus(self.head_weights).view( + *((1,) * len(pt_att.shape[:-2]) + (-1, 1)) + ) + head_weights = head_weights * math.sqrt( + 1.0 / (3 * (self.no_qk_points * 9.0 / 2)) + ) + if(inplace_safe): + pt_att *= head_weights + else: + pt_att = pt_att * head_weights + + # [*, N_res, N_res, H] + pt_att = torch.sum(pt_att, dim=-1) * (-0.5) + # [*, N_res, N_res] + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + square_mask = self.inf * (square_mask - 1) + + # [*, H, N_res, N_res] + pt_att = permute_final_dims(pt_att, (2, 0, 1)) + + if(inplace_safe): + a += pt_att + del pt_att + a += square_mask.unsqueeze(-3) + # in-place softmax + attn_core_inplace_cuda.forward_( + a, + reduce(mul, a.shape[:-1]), + a.shape[-1], + ) + else: + a = a + pt_att + a = a + square_mask.unsqueeze(-3) + a = self.softmax(a) + + ################ + # Compute output + ################ + # [*, N_res, H, C_hidden] + o = torch.matmul( + a, v.transpose(-2, -3).to(dtype=a.dtype) + ).transpose(-2, -3) + + # [*, N_res, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, H, 3, N_res, P_v] + if(inplace_safe): + v_pts = permute_final_dims(v_pts, (1, 3, 0, 2)) + o_pt = [ + torch.matmul(a, v.to(a.dtype)) + for v in torch.unbind(v_pts, dim=-3) + ] + o_pt = torch.stack(o_pt, dim=-3) + else: + o_pt = torch.sum( + ( + a[..., None, :, :, None] + * 
permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :] + ), + dim=-2, + ) + + # [*, N_res, H, P_v, 3] + o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) + o_pt = r[..., None, None].invert_apply(o_pt) + + # [*, N_res, H * P_v] + o_pt_norm = flatten_final_dims( + torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2 + ) + + # [*, N_res, H * P_v, 3] + o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3) + + if(_offload_inference): + z[0] = z[0].to(o_pt.device) + + # [*, N_res, H, C_z] + o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype)) + + # [*, N_res, H * C_z] + o_pair = flatten_final_dims(o_pair, 2) + + # [*, N_res, C_s] + s = self.linear_out( + torch.cat( + (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1 + ).to(dtype=z[0].dtype) + ) + + return s + + +class BackboneUpdate(nn.Module): + """ + Implements part of Algorithm 23. + """ + + def __init__(self, c_s): + """ + Args: + c_s: + Single representation channel dimension + """ + super(BackboneUpdate, self).__init__() + + self.c_s = c_s + + self.linear = Linear(self.c_s, 6, init="final") + + def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + [*, N_res, C_s] single representation + Returns: + [*, N_res, 6] update vector + """ + # [*, 6] + update = self.linear(s) + + return update + + +class StructureModuleTransitionLayer(nn.Module): + def __init__(self, c): + super(StructureModuleTransitionLayer, self).__init__() + + self.c = c + + self.linear_1 = Linear(self.c, self.c, init="relu") + self.linear_2 = Linear(self.c, self.c, init="relu") + self.linear_3 = Linear(self.c, self.c, init="final") + + self.relu = nn.ReLU() + + def forward(self, s): + s_initial = s + s = self.linear_1(s) + s = self.relu(s) + s = self.linear_2(s) + s = self.relu(s) + s = self.linear_3(s) + + s = s + s_initial + + return s + + +class StructureModuleTransition(nn.Module): + def __init__(self, c, num_layers, dropout_rate): + super(StructureModuleTransition, self).__init__() + + self.c = c + 
self.num_layers = num_layers + self.dropout_rate = dropout_rate + + self.layers = nn.ModuleList() + for _ in range(self.num_layers): + l = StructureModuleTransitionLayer(self.c) + self.layers.append(l) + + self.dropout = nn.Dropout(self.dropout_rate) + self.layer_norm = LayerNorm(self.c) + + def forward(self, s): + for l in self.layers: + s = l(s) + + s = self.dropout(s) + s = self.layer_norm(s) + + return s + + +class StructureModule(nn.Module): + def __init__( + self, + c_s, + c_z, + c_ipa, + c_resnet, + no_heads_ipa, + no_qk_points, + no_v_points, + dropout_rate, + no_blocks, + no_transition_layers, + no_resnet_blocks, + no_angles, + trans_scale_factor, + epsilon, + inf, + **kwargs, + ): + """ + Args: + c_s: + Single representation channel dimension + c_z: + Pair representation channel dimension + c_ipa: + IPA hidden channel dimension + c_resnet: + Angle resnet (Alg. 23 lines 11-14) hidden channel dimension + no_heads_ipa: + Number of IPA heads + no_qk_points: + Number of query/key points to generate during IPA + no_v_points: + Number of value points to generate during IPA + dropout_rate: + Dropout rate used throughout the layer + no_blocks: + Number of structure module blocks + no_transition_layers: + Number of layers in the single representation transition + (Alg. 
    def forward(
        self,
        evoformer_output_dict,
        aatype,
        mask=None,
        inplace_safe=False,
        _offload_inference=False,
    ):
        """
        Run `self.no_blocks` rounds of IPA / transition / backbone update and
        collect the per-round predictions.

        Args:
            evoformer_output_dict:
                Dictionary containing:
                    "single":
                        [*, N_res, C_s] single representation
                    "pair":
                        [*, N_res, N_res, C_z] pair representation
            aatype:
                [*, N_res] amino acid indices
            mask:
                Optional [*, N_res] sequence mask. Defaults to all ones.
            inplace_safe:
                Forwarded unchanged to the IPA module.
            _offload_inference:
                If True, the raw "pair" entry of `evoformer_output_dict` is
                moved to the CPU for the duration of the loop (the dict is
                mutated in place) and moved back to `s.device` before
                returning. The normalized pair tensor is then handed to the
                IPA module through `_z_reference_list` instead of `z`.
        Returns:
            A dictionary with the keys "frames", "sidechain_frames",
            "unnormalized_angles", "angles", "positions" and "states", each
            stacked over the blocks with torch.stack, plus "single" holding
            the final single representation.
        """
        s = evoformer_output_dict["single"]

        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])

        # [*, N, C_s]
        s = self.layer_norm_s(s)

        # [*, N, N, C_z]
        z = self.layer_norm_z(evoformer_output_dict["pair"])

        z_reference_list = None
        if(_offload_inference):
            # Presumably guards that the dict holds the only reference to the
            # pair tensor (the second count is getrefcount's own argument),
            # so that replacing it with a CPU copy actually frees the
            # original -- TODO confirm.
            assert(sys.getrefcount(evoformer_output_dict["pair"]) == 2)
            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
            # The normalized pair tensor now lives only inside this list.
            z_reference_list = [z]
            z = None

        # [*, N, C_s]
        s_initial = s
        s = self.linear_in(s)

        # [*, N]
        rigids = Rigid.identity(
            s.shape[:-1],
            s.dtype,
            s.device,
            self.training,
            fmt="quat",
        )
        outputs = []
        for i in range(self.no_blocks):
            # [*, N, C_s]
            s = s + self.ipa(
                s,
                z,
                rigids,
                mask,
                inplace_safe=inplace_safe,
                _offload_inference=_offload_inference,
                _z_reference_list=z_reference_list
            )
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)
            s = self.transition(s)

            # [*, N]
            rigids = rigids.compose_q_update_vec(self.bb_update(s))

            # To hew as closely as possible to AlphaFold, we convert our
            # quaternion-based transformations to rotation-matrix ones
            # here
            backb_to_global = Rigid(
                Rotation(
                    rot_mats=rigids.get_rots().get_rot_mats(),
                    quats=None
                ),
                rigids.get_trans(),
            )

            backb_to_global = backb_to_global.scale_translation(
                self.trans_scale_factor
            )

            # [*, N, 7, 2]
            unnormalized_angles, angles = self.angle_resnet(s, s_initial)

            all_frames_to_global = self.torsion_angles_to_frames(
                backb_to_global,
                angles,
                aatype,
            )

            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
                all_frames_to_global,
                aatype,
            )

            # The reported frames use the scaled translation; `rigids` itself
            # stays unscaled for the next iteration.
            scaled_rigids = rigids.scale_translation(self.trans_scale_factor)

            preds = {
                "frames": scaled_rigids.to_tensor_7(),
                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
                "unnormalized_angles": unnormalized_angles,
                "angles": angles,
                "positions": pred_xyz,
                "states": s,
            }

            outputs.append(preds)

            rigids = rigids.stop_rot_gradient()

        # Drop the local references to the normalized pair tensor.
        del z, z_reference_list

        if(_offload_inference):
            # Undo the offload so the caller's dict is back on the device.
            evoformer_output_dict["pair"] = (
                evoformer_output_dict["pair"].to(s.device)
            )

        outputs = dict_multimap(torch.stack, outputs)
        outputs["single"] = s

        return outputs
pred_xyz, + "states": s, + } + + outputs.append(preds) + + rigids = rigids.stop_rot_gradient() + + del z, z_reference_list + + if(_offload_inference): + evoformer_output_dict["pair"] = ( + evoformer_output_dict["pair"].to(s.device) + ) + + outputs = dict_multimap(torch.stack, outputs) + outputs["single"] = s + + return outputs + + def _init_residue_constants(self, float_dtype, device): + if not hasattr(self, "default_frames"): + self.register_buffer( + "default_frames", + torch.tensor( + restype_rigid_group_default_frame, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "group_idx"): + self.register_buffer( + "group_idx", + torch.tensor( + restype_atom14_to_rigid_group, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "atom_mask"): + self.register_buffer( + "atom_mask", + torch.tensor( + restype_atom14_mask, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "lit_positions"): + self.register_buffer( + "lit_positions", + torch.tensor( + restype_atom14_rigid_group_positions, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + + def torsion_angles_to_frames(self, r, alpha, f): + # Lazily initialize the residue constants on the correct device + self._init_residue_constants(alpha.dtype, alpha.device) + # Separated purely to make testing less annoying + return torsion_angles_to_frames(r, alpha, f, self.default_frames) + + def frames_and_literature_positions_to_atom14_pos( + self, r, f # [*, N, 8] # [*, N] + ): + # Lazily initialize the residue constants on the correct device + self._init_residue_constants(r.get_rots().dtype, r.get_rots().device) + return frames_and_literature_positions_to_atom14_pos( + r, + f, + self.default_frames, + self.group_idx, + self.atom_mask, + self.lit_positions, + ) diff --git a/openfold/model/template.py b/openfold/model/template.py new 
class TemplatePointwiseAttention(nn.Module):
    """
    Implements Algorithm 17.

    Each pair position attends over the template dimension: the pair
    embedding is the query and the template embeddings are keys/values.
    """
    def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
        """
        Args:
            c_t:
                Template embedding channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Magnitude of the negative bias applied to masked-out
                templates
            **kwargs:
                Accepted and ignored
        """
        super(TemplatePointwiseAttention, self).__init__()

        self.c_t = c_t
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf

        # Queries come from the pair embedding (c_z); keys and values come
        # from the template embedding (c_t).
        self.mha = Attention(
            self.c_z,
            self.c_t,
            self.c_t,
            self.c_hidden,
            self.no_heads,
            gating=False,
        )

    def _chunk(self,
        z: torch.Tensor,
        t: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        """Run `self.mha` through chunk_layer with the given chunk size."""
        mha_inputs = {
            "q_x": z,
            "kv_x": t,
            "biases": biases,
        }
        return chunk_layer(
            self.mha,
            mha_inputs,
            chunk_size=chunk_size,
            # Every dimension except the last two is treated as batch
            no_batch_dims=len(z.shape[:-2]),
        )


    def forward(self,
        t: torch.Tensor,
        z: torch.Tensor,
        template_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Args:
            t:
                [*, N_templ, N_res, N_res, C_t] template embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            template_mask:
                [*, N_templ] template mask. Defaults to all ones.
            chunk_size:
                If not None, the attention is evaluated in chunks of this
                size via chunk_layer
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        if template_mask is None:
            template_mask = t.new_ones(t.shape[:-3])

        # 0 where the mask is 1, -inf (approximately) where the mask is 0;
        # the trailing axis is the template axis.
        bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)

        # [*, N_res, N_res, 1, C_z]
        z = z.unsqueeze(-2)

        # [*, N_res, N_res, N_temp, C_t]
        t = permute_final_dims(t, (1, 2, 0, 3))

        biases = [bias]
        # [*, N_res, N_res, 1, C_z]
        if chunk_size is not None:
            z = self._chunk(z, t, biases, chunk_size)
        else:
            z = self.mha(q_x=z, kv_x=t, biases=biases)

        # [*, N_res, N_res, C_z]
        z = z.squeeze(-2)

        return z
class TemplatePairStackBlock(nn.Module):
    """
    One block of the template pair stack: triangle attention (starting and
    ending node), triangle multiplication (outgoing and incoming) and a pair
    transition, each added residually and applied to every template
    separately.
    """
    def __init__(
        self,
        c_t: int,
        c_hidden_tri_att: int,
        c_hidden_tri_mul: int,
        no_heads: int,
        pair_transition_n: int,
        dropout_rate: float,
        inf: float,
        **kwargs,
    ):
        """
        Args:
            c_t:
                Template embedding channel dimension
            c_hidden_tri_att:
                Hidden dimension handed to the triangle attention layers
            c_hidden_tri_mul:
                Hidden dimension handed to the triangle multiplication layers
            no_heads:
                Number of attention heads
            pair_transition_n:
                Scale factor handed to the pair transition
            dropout_rate:
                Rate of the row-wise and column-wise dropout
            inf:
                Large value handed to the attention layers for masking
            **kwargs:
                Accepted and ignored
        """
        super().__init__()

        self.c_t = c_t
        self.c_hidden_tri_att = c_hidden_tri_att
        self.c_hidden_tri_mul = c_hidden_tri_mul
        self.no_heads = no_heads
        self.pair_transition_n = pair_transition_n
        self.dropout_rate = dropout_rate
        self.inf = inf

        # Submodules are created in a fixed order: dropout, attention,
        # multiplication, transition.
        self.dropout_row = DropoutRowwise(self.dropout_rate)
        self.dropout_col = DropoutColumnwise(self.dropout_rate)

        self.tri_att_start = TriangleAttentionStartingNode(
            self.c_t, self.c_hidden_tri_att, self.no_heads, inf=inf,
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            self.c_t, self.c_hidden_tri_att, self.no_heads, inf=inf,
        )

        self.tri_mul_out = TriangleMultiplicationOutgoing(
            self.c_t, self.c_hidden_tri_mul,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            self.c_t, self.c_hidden_tri_mul,
        )

        self.pair_transition = PairTransition(
            self.c_t, self.pair_transition_n,
        )

    def forward(self,
        z: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True
    ):
        """
        Update every template slice of `z` independently and concatenate the
        results back along the template axis (dim -4 of `z`, dim -3 of
        `mask`).
        """
        # Split into per-template tensors that keep a singleton template axis
        per_template = [
            piece.unsqueeze(-4) for piece in torch.unbind(z, dim=-4)
        ]
        per_template_masks = [
            piece.unsqueeze(-3) for piece in torch.unbind(mask, dim=-3)
        ]

        for idx, pair_rep in enumerate(per_template):
            pair_mask = per_template_masks[idx]

            start_update = self.tri_att_start(
                pair_rep, chunk_size=chunk_size, mask=pair_mask
            )
            pair_rep = pair_rep + self.dropout_row(start_update)

            end_update = self.tri_att_end(
                pair_rep, chunk_size=chunk_size, mask=pair_mask
            )
            pair_rep = pair_rep + self.dropout_col(end_update)

            outgoing_update = self.tri_mul_out(pair_rep, mask=pair_mask)
            pair_rep = pair_rep + self.dropout_row(outgoing_update)

            incoming_update = self.tri_mul_in(pair_rep, mask=pair_mask)
            pair_rep = pair_rep + self.dropout_row(incoming_update)

            # The transition is masked only when _mask_trans is set
            transition_mask = pair_mask if _mask_trans else None
            pair_rep = pair_rep + self.pair_transition(
                pair_rep, mask=transition_mask, chunk_size=chunk_size,
            )

            per_template[idx] = pair_rep

        return torch.cat(per_template, dim=-4)
class TemplatePairStack(nn.Module):
    """
    Implements Algorithm 16.
    """
    def __init__(
        self,
        c_t,
        c_hidden_tri_att,
        c_hidden_tri_mul,
        no_blocks,
        no_heads,
        pair_transition_n,
        dropout_rate,
        blocks_per_ckpt,
        inf=1e9,
        **kwargs,
    ):
        """
        Args:
            c_t:
                Template embedding channel dimension
            c_hidden_tri_att:
                Per-head hidden dimension for triangular attention
            c_hidden_tri_mul:
                Hidden dimension for triangular multiplication
            no_blocks:
                Number of blocks in the stack
            no_heads:
                Number of attention heads, forwarded to every block
            pair_transition_n:
                Scale of pair transition (Alg. 15) hidden dimension
            dropout_rate:
                Dropout rate used throughout the stack
            blocks_per_ckpt:
                Number of blocks per activation checkpoint. None disables
                activation checkpointing
            inf:
                Large value forwarded to every block for attention masking
            **kwargs:
                Accepted and ignored
        """
        super(TemplatePairStack, self).__init__()

        self.blocks_per_ckpt = blocks_per_ckpt

        self.blocks = nn.ModuleList()
        for _ in range(no_blocks):
            block = TemplatePairStackBlock(
                c_t=c_t,
                c_hidden_tri_att=c_hidden_tri_att,
                c_hidden_tri_mul=c_hidden_tri_mul,
                no_heads=no_heads,
                pair_transition_n=pair_transition_n,
                dropout_rate=dropout_rate,
                inf=inf,
            )
            self.blocks.append(block)

        self.layer_norm = LayerNorm(c_t)

    def forward(
        self,
        t: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
        _mask_trans: bool = True,
    ):
        """
        Args:
            t:
                [*, N_templ, N_res, N_res, C_t] template embedding
            mask:
                [*, N_templ, N_res, N_res] mask
            chunk_size:
                Chunk size forwarded to every block
            _mask_trans:
                Forwarded to every block
        Returns:
            [*, N_templ, N_res, N_res, C_t] template embedding update
        """
        # A mask with a singleton template axis is broadcast to N_templ
        if(mask.shape[-3] == 1):
            expand_idx = list(mask.shape)
            expand_idx[-3] = t.shape[-4]
            mask = mask.expand(*expand_idx)

        t, = checkpoint_blocks(
            blocks=[
                partial(
                    b,
                    mask=mask,
                    chunk_size=chunk_size,
                    _mask_trans=_mask_trans,
                )
                for b in self.blocks
            ],
            args=(t,),
            # Checkpointing is only requested in training mode
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        t = self.layer_norm(t)

        return t
def script_preset_(model: torch.nn.Module):
    """
    TorchScript a handful of low-level but frequently used submodule types
    that are known to be scriptable.

    The matching submodules of `model` are replaced in place; nothing is
    returned.

    Args:
        model:
            A torch.nn.Module. It should contain at least some modules from
            this repository, or this function won't do anything.
    """
    script_submodules_(
        model,
        [
            nn.Dropout,
            Attention,
            GlobalAttention,
            EvoformerBlock,
            # NOTE(review): TemplatePairStackBlock is deliberately left out
            # of the preset here; the reason is not recorded in this file.
            #TemplatePairStackBlock,
        ],
        # No tracing fallback and no stand-in batch dimensions for the preset
        attempt_trace=False,
        batch_dims=None,
    )
It should contain at least some modules from + this repository, or this function won't do anything. + """ + script_submodules_( + model, + [ + nn.Dropout, + Attention, + GlobalAttention, + EvoformerBlock, + #TemplatePairStackBlock, + ], + attempt_trace=False, + batch_dims=None, + ) + + +def _get_module_device(module: torch.nn.Module) -> torch.device: + """ + Fetches the device of a module, assuming that all of the module's + parameters reside on a single device + + Args: + module: A torch.nn.Module + Returns: + The module's device + """ + return next(module.parameters()).device + + +def _trace_module(module, batch_dims=None): + if(batch_dims is None): + batch_dims = () + + # Stand-in values + n_seq = 10 + n_res = 10 + + device = _get_module_device(module) + + def msa(channel_dim): + return torch.rand( + (*batch_dims, n_seq, n_res, channel_dim), + device=device, + ) + + def pair(channel_dim): + return torch.rand( + (*batch_dims, n_res, n_res, channel_dim), + device=device, + ) + + if(isinstance(module, MSARowAttentionWithPairBias)): + inputs = { + "forward": ( + msa(module.c_in), # m + pair(module.c_z), # z + torch.randint( + 0, 2, + (*batch_dims, n_seq, n_res) + ), # mask + ), + } + elif(isinstance(module, MSAColumnAttention)): + inputs = { + "forward": ( + msa(module.c_in), # m + torch.randint( + 0, 2, + (*batch_dims, n_seq, n_res) + ), # mask + ), + } + elif(isinstance(module, OuterProductMean)): + inputs = { + "forward": ( + msa(module.c_m), + torch.randint( + 0, 2, + (*batch_dims, n_seq, n_res) + ) + ) + } + else: + raise TypeError( + f"tracing is not supported for modules of type {type(module)}" + ) + + return torch.jit.trace_module(module, inputs) + + +def _script_submodules_helper_( + model, + types, + attempt_trace, + to_trace, +): + for name, child in model.named_children(): + if(types is None or any(isinstance(child, t) for t in types)): + try: + scripted = torch.jit.script(child) + setattr(model, name, scripted) + continue + except (RuntimeError, 
torch.jit.frontend.NotSupportedError) as e: + if(attempt_trace): + to_trace.add(type(child)) + else: + raise e + + _script_submodules_helper_(child, types, attempt_trace, to_trace) + + +def _trace_submodules_( + model, + types, + batch_dims=None, +): + for name, child in model.named_children(): + if(any(isinstance(child, t) for t in types)): + traced = _trace_module(child, batch_dims=batch_dims) + setattr(model, name, traced) + else: + _trace_submodules_(child, types, batch_dims=batch_dims) + + +def script_submodules_( + model: nn.Module, + types: Optional[Sequence[type]] = None, + attempt_trace: Optional[bool] = True, + batch_dims: Optional[Tuple[int]] = None, +): + """ + Convert all submodules whose types match one of those in the input + list to recursively scripted equivalents in place. To script the entire + model, just call torch.jit.script on it directly. + + When types is None, all submodules are scripted. + + Args: + model: + A torch.nn.Module + types: + A list of types of submodules to script + attempt_trace: + Whether to attempt to trace specified modules if scripting + fails. Recall that tracing eliminates all conditional + logic---with great tracing comes the mild responsibility of + having to remember to ensure that the modules in question + perform the same computations no matter what. + """ + to_trace = set() + + # Aggressively script as much as possible first... + _script_submodules_helper_(model, types, attempt_trace, to_trace) + + # ... and then trace stragglers. 
def script_submodules_(
    model: nn.Module,
    types: Optional[Sequence[type]] = None,
    attempt_trace: Optional[bool] = True,
    batch_dims: Optional[Tuple[int]] = None,
):
    """
    Replace, in place, every submodule whose type matches one of `types`
    with a recursively scripted equivalent. To script the entire model,
    call torch.jit.script on it directly instead.

    When types is None, all submodules are scripted.

    Args:
        model:
            A torch.nn.Module
        types:
            A list of types of submodules to script
        attempt_trace:
            Whether to attempt to trace specified modules if scripting
            fails. Recall that tracing eliminates all conditional
            logic---with great tracing comes the mild responsibility of
            having to remember to ensure that the modules in question
            perform the same computations no matter what.
        batch_dims:
            Leading batch dimensions of the stand-in inputs used when
            tracing
    """
    failed_types = set()

    # First pass: script everything that can be scripted, remembering the
    # types that could not be ...
    _script_submodules_helper_(model, types, attempt_trace, failed_types)

    # ... second pass: trace whatever is left, if tracing was requested.
    if attempt_trace and failed_types:
        _trace_submodules_(model, failed_types, batch_dims=batch_dims)
class TriangleAttention(nn.Module):
    """
    Triangle self-attention over a pair representation. `starting` selects
    the orientation: the starting-node variant attends along the last pair
    axis, the ending-node variant transposes the two pair axes first and
    transposes the result back.
    """
    def __init__(
        self, c_in, c_hidden, no_heads, starting, inf=1e9
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
            starting:
                True for the starting-node variant, False for the
                ending-node variant
            inf:
                Magnitude of the negative bias applied to masked positions
        """
        super(TriangleAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)

        # Projects each pair entry to one bias value per head
        self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")

        self.mha = Attention(
            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
        )

    @torch.jit.ignore
    def _chunk(self,
        x: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        """Run self-attention on `x` through chunk_layer."""
        mha_inputs = {
            "q_x": x,
            "kv_x": x,
            "biases": biases,
        }
        return chunk_layer(
            # partial() binds no arguments here; it only wraps self.mha
            partial(self.mha),
            mha_inputs,
            chunk_size=chunk_size,
            # Every dimension except the last two is treated as batch
            no_batch_dims=len(x.shape[:-2]),
        )

    def forward(self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
            mask:
                Optional [*, I, J] mask. Defaults to all ones.
            chunk_size:
                If not None, the attention is evaluated in chunks of this
                size
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J]
            mask = x.new_ones(
                x.shape[:-1],
            )

        # Shape annotations assume self.starting. Else, I and J are flipped
        if not self.starting:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # [*, I, 1, 1, J]
        # 0 where the mask is 1, -inf (approximately) where the mask is 0
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # [*, H, I, J]
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

        # [*, 1, H, I, J]
        triangle_bias = triangle_bias.unsqueeze(-4)

        biases = [mask_bias, triangle_bias]

        if chunk_size is not None:
            x = self._chunk(x, biases, chunk_size)
        else:
            x = self.mha(q_x=x, kv_x=x, biases=biases)

        # Undo the transpose applied for the ending-node variant
        if not self.starting:
            x = x.transpose(-2, -3)

        return x


class TriangleAttentionStartingNode(TriangleAttention):
    """
    Implements Algorithm 13.
    """

    # Same constructor as TriangleAttention with `starting` fixed to True
    __init__ = partialmethod(TriangleAttention.__init__, starting=True)


class TriangleAttentionEndingNode(TriangleAttention):
    """
    Implements Algorithm 14.
    """

    # Same constructor as TriangleAttention with `starting` fixed to False
    __init__ = partialmethod(TriangleAttention.__init__, starting=False)
class TriangleMultiplicativeUpdate(nn.Module):
    """
    Implements Algorithms 11 and 12.

    Base class: subclasses supply `_combine_projections`, which decides
    whether the update is the "outgoing" or the "incoming" variant.
    """
    def __init__(self, c_z, c_hidden, _outgoing=True):
        """
        Args:
            c_z:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
            _outgoing:
                Stored on the instance; not read by this class itself
        """
        super(TriangleMultiplicativeUpdate, self).__init__()
        self.c_z = c_z
        self.c_hidden = c_hidden
        self._outgoing = _outgoing

        # Projection / gate pairs for the two operands a and b
        self.linear_a_p = Linear(self.c_z, self.c_hidden)
        self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
        self.linear_b_p = Linear(self.c_z, self.c_hidden)
        self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
        # Output gate and output projection
        self.linear_g = Linear(self.c_z, self.c_z, init="gating")
        self.linear_z = Linear(self.c_hidden, self.c_z, init="final")

        self.layer_norm_in = LayerNorm(self.c_z)
        self.layer_norm_out = LayerNorm(self.c_hidden)

        self.sigmoid = nn.Sigmoid()

    def _combine_projections(self,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        """Combine the two gated projections; must be overridden."""
        raise NotImplementedError("This method needs to be overridden")

    def forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask. Defaults to all ones.
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        # Add a channel axis so the mask broadcasts over the projections
        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)
        a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
        a = a * mask
        b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
        b = b * mask
        x = self._combine_projections(a, b)
        x = self.layer_norm_out(x)
        x = self.linear_z(x)
        # The output gate is computed from the normalized input
        g = self.sigmoid(self.linear_g(z))
        z = x * g

        return z
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 11.
    """
    def _combine_projections(self,
        a: torch.Tensor,  # [*, N_i, N_k, C]
        b: torch.Tensor,  # [*, N_j, N_k, C]
    ):
        """Contract `a` and `b` over their shared second pair axis."""
        # Bring the channel axis in front so matmul runs once per channel
        left = permute_final_dims(a, (2, 0, 1))
        right = permute_final_dims(b, (2, 1, 0))

        # [*, C, N_i, N_j]
        product = torch.matmul(left, right)

        # [*, N_i, N_j, C]
        return permute_final_dims(product, (1, 2, 0))


class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 12.
    """
    def _combine_projections(self,
        a: torch.Tensor,  # [*, N_k, N_i, C]
        b: torch.Tensor,  # [*, N_k, N_j, C]
    ):
        """Contract `a` and `b` over their shared first pair axis."""
        # Bring the channel axis in front so matmul runs once per channel
        left = permute_final_dims(a, (2, 1, 0))
        right = permute_final_dims(b, (2, 0, 1))

        # [*, C, N_i, N_j]
        product = torch.matmul(left, right)

        # [*, N_i, N_j, C]
        return permute_final_dims(product, (1, 2, 0))
# Import every sibling module of this package and expose it both as a
# package attribute and through __all__.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]

# All modules are imported before any name is bound. Using a comprehension
# (rather than a for loop) leaves no loop variable behind, so the cleanup
# below cannot fail with a NameError when the package has no sibling modules.
globals().update(
    {m: importlib.import_module("." + m, __name__) for m in __all__}
)

# Avoid needlessly cluttering the global namespace
del _files
@dataclasses.dataclass(frozen=True)
class Protein:
    """Protein structure representation.

    Instances are immutable (frozen dataclass). The first five fields are
    required; the remaining ones default to None.
    """

    # Cartesian coordinates of atoms in angstroms. The atom types correspond to
    # residue_constants.atom_types.
    # NOTE(review): an earlier version of this comment said "the first three
    # are N, CA, CB" -- verify against residue_constants.atom_types.
    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

    # Amino-acid type for each residue represented as an integer between 0 and
    # 20, where 20 is 'X'.
    aatype: np.ndarray  # [num_res]

    # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
    # is present and 0.0 if not. This should be used for loss masking.
    atom_mask: np.ndarray  # [num_res, num_atom_type]

    # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
    residue_index: np.ndarray  # [num_res]

    # B-factors, or temperature factors (in sq. angstroms units), stored per
    # atom of each residue, representing the displacement from the ground
    # truth mean value.
    b_factors: np.ndarray  # [num_res, num_atom_type]

    # Chain indices for multi-chain predictions
    chain_index: Optional[np.ndarray] = None

    # Optional remark about the protein. Included as a comment in output PDB
    # files
    remark: Optional[str] = None

    # Templates used to generate this protein (prediction-only)
    parents: Optional[Sequence[str]] = None

    # Chain corresponding to each parent
    parents_chain_index: Optional[Sequence[int]] = None
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    """Takes a PDB string and constructs a Protein object.

    WARNING: All non-standard residue types will be converted into UNK. All
      non-standard atoms will be ignored.

    Args:
      pdb_str: The contents of the pdb file
      chain_id: If None, every chain of the model is parsed. If chain_id is
        specified (e.g. A), then only that chain is parsed.

    Returns:
      A new `Protein` parsed from the pdb contents.

    Raises:
      ValueError: if the PDB holds more than one model or a residue carries
        an insertion code.
      KeyError: if a parsed chain id is not an ASCII letter or digit.
    """
    pdb_fh = io.StringIO(pdb_str)
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("none", pdb_fh)
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(
            f"Only single model PDBs are supported. Found {len(models)} models."
        )
    model = models[0]

    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    chain_ids = []
    b_factors = []

    for chain in model:
        if(chain_id is not None and chain.id != chain_id):
            continue
        for res in chain:
            if res.id[2] != " ":
                raise ValueError(
                    f"PDB contains an insertion code at chain {chain.id} and residue "
                    f"index {res.id[1]}. These are not supported."
                )
            res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
            restype_idx = residue_constants.restype_order.get(
                res_shortname, residue_constants.restype_num
            )
            pos = np.zeros((residue_constants.atom_type_num, 3))
            mask = np.zeros((residue_constants.atom_type_num,))
            res_b_factors = np.zeros((residue_constants.atom_type_num,))
            for atom in res:
                if atom.name not in residue_constants.atom_types:
                    continue
                atom_idx = residue_constants.atom_order[atom.name]
                pos[atom_idx] = atom.coord
                mask[atom_idx] = 1.0
                res_b_factors[atom_idx] = atom.bfactor
            if np.sum(mask) < 0.5:
                # If no known atom positions are reported for the residue then skip it.
                continue
            aatype.append(restype_idx)
            atom_positions.append(pos)
            atom_mask.append(mask)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    parents = None
    parents_chain_index = None
    if("PARENT" in pdb_str):
        parents = []
        parents_chain_index = []
        # Every PARENT line describes the next chain, in file order. A
        # separate counter keeps the `chain_id` argument intact.
        parent_chain_idx = 0
        for l in pdb_str.split("\n"):
            if("PARENT" in l):
                if(not "N/A" in l):
                    parent_names = l.split()[1:]
                    parents.extend(parent_names)
                    parents_chain_index.extend([
                        parent_chain_idx for _ in parent_names
                    ])
                parent_chain_idx += 1

    # Chain ids map to fixed indices: 'A'-'Z' -> 0-25 (unchanged from the
    # previous behaviour), then 'a'-'z' -> 26-51 and '0'-'9' -> 52-61, so
    # that lowercase and numeric chain ids no longer fail.
    chain_alphabet = (
        string.ascii_uppercase + string.ascii_lowercase + string.digits
    )
    chain_id_mapping = {cid: n for n, cid in enumerate(chain_alphabet)}
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=chain_index,
        b_factors=np.array(b_factors),
        parents=parents,
        parents_chain_index=parents_chain_index,
    )
def from_proteinnet_string(proteinnet_str: str) -> Protein:
    """Builds a `Protein` from a ProteinNet text record.

    The record is split on its `[TAG]` header lines. Three sections are
    used: `[PRIMARY]` (sequence), `[TERTIARY]` (backbone coordinates, three
    rows for x/y/z with N, CA, C interleaved per residue) and `[MASK]`
    ('+'/'-' per residue). Only the N, CA and C atom slots are filled.

    Args:
      proteinnet_str: The text of a single ProteinNet record.

    Returns:
      A `Protein` with a 0-based contiguous `residue_index` and
      `b_factors=None`. `atom_positions` / `atom_mask` are None when the
      corresponding section is absent.

    Raises:
      ValueError: if the record has no `[PRIMARY]` section.
    """
    tag_re = r'(\[[A-Z]+\]\n)'
    tags = [
        tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
    ]
    # Pair every tag with the lines of the section that follows it
    groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])

    atoms = ['N', 'CA', 'C']
    aatype = None
    atom_positions = None
    atom_mask = None
    for g in groups:
        if("[PRIMARY]" == g[0]):
            seq = g[1][0].strip()
            # Unknown symbols fall back to restype_num (the 'X' index).
            # The previous in-place `seq[i] = 'X'` rewrite raised TypeError
            # because str is immutable, and was redundant with this default.
            aatype = np.array([
                residue_constants.restype_order.get(
                    res_symbol, residue_constants.restype_num
                ) for res_symbol in seq
            ])
        elif("[TERTIARY]" == g[0]):
            tertiary = []
            for axis in range(3):
                tertiary.append(list(map(float, g[1][axis].split())))
            tertiary_np = np.array(tertiary)
            atom_positions = np.zeros(
                (len(tertiary[0])//3, residue_constants.atom_type_num, 3)
            ).astype(np.float32)
            for i, atom in enumerate(atoms):
                # Columns i, i+3, i+6, ... belong to backbone atom i
                atom_positions[:, residue_constants.atom_order[atom], :] = (
                    np.transpose(tertiary_np[:, i::3])
                )
            atom_positions *= PICO_TO_ANGSTROM
        elif("[MASK]" == g[0]):
            mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
            atom_mask = np.zeros(
                (len(mask), residue_constants.atom_type_num,)
            ).astype(np.float32)
            for i, atom in enumerate(atoms):
                atom_mask[:, residue_constants.atom_order[atom]] = 1
            atom_mask *= mask[..., None]

    if aatype is None:
        raise ValueError("ProteinNet record has no [PRIMARY] section")

    return Protein(
        atom_positions=atom_positions,
        atom_mask=atom_mask,
        aatype=aatype,
        residue_index=np.arange(len(aatype)),
        # NOTE(review): ProteinNet carries no B-factors, so None is passed
        # although the field is annotated as an array.
        b_factors=None,
    )
residue_constants.atom_order[atom], :] = ( + np.transpose(tertiary_np[:, i::3]) + ) + atom_positions *= PICO_TO_ANGSTROM + elif("[MASK]" == g[0]): + mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip()))) + atom_mask = np.zeros( + (len(mask), residue_constants.atom_type_num,) + ).astype(np.float32) + for i, atom in enumerate(atoms): + atom_mask[:, residue_constants.atom_order[atom]] = 1 + atom_mask *= mask[..., None] + + return Protein( + atom_positions=atom_positions, + atom_mask=atom_mask, + aatype=aatype, + residue_index=np.arange(len(aatype)), + b_factors=None, + ) + + +def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]: + pdb_headers = [] + + remark = prot.remark + if(remark is not None): + pdb_headers.append(f"REMARK {remark}") + + parents = prot.parents + parents_chain_index = prot.parents_chain_index + if(parents_chain_index is not None): + parents = [ + p for i, p in zip(parents_chain_index, parents) if i == chain_id + ] + + if(parents is None or len(parents) == 0): + parents = ["N/A"] + + pdb_headers.append(f"PARENT {' '.join(parents)}") + + return pdb_headers + + +def add_pdb_headers(prot: Protein, pdb_str: str) -> str: + """ Add pdb headers to an existing PDB string. 
Useful during multi-chain + recycling + """ + out_pdb_lines = [] + lines = pdb_str.split('\n') + + remark = prot.remark + if(remark is not None): + out_pdb_lines.append(f"REMARK {remark}") + + parents_per_chain = None + if(prot.parents is not None and len(prot.parents) > 0): + parents_per_chain = [] + if(prot.parents_chain_index is not None): + cur_chain = prot.parents_chain_index[0] + parent_dict = {} + for p, i in zip(prot.parents, prot.parents_chain_index): + parent_dict.setdefault(str(i), []) + parent_dict[str(i)].append(p) + + max_idx = max([int(chain_idx) for chain_idx in parent_dict]) + for i in range(max_idx + 1): + chain_parents = parent_dict.get(str(i), ["N/A"]) + parents_per_chain.append(chain_parents) + else: + parents_per_chain.append(prot.parents) + else: + parents_per_chain = [["N/A"]] + + make_parent_line = lambda p: f"PARENT {' '.join(p)}" + + out_pdb_lines.append(make_parent_line(parents_per_chain[0])) + + chain_counter = 0 + for i, l in enumerate(lines): + if("PARENT" not in l and "REMARK" not in l): + out_pdb_lines.append(l) + if("TER" in l and not "END" in lines[i + 1]): + chain_counter += 1 + if(not chain_counter >= len(parents_per_chain)): + chain_parents = parents_per_chain[chain_counter] + else: + chain_parents = ["N/A"] + + out_pdb_lines.append(make_parent_line(chain_parents)) + + return '\n'.join(out_pdb_lines) + + +def to_pdb(prot: Protein) -> str: + """Converts a `Protein` instance to a PDB string. + + Args: + prot: The protein to convert to PDB. + + Returns: + PDB string. 
+ """ + restypes = residue_constants.restypes + ["X"] + res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK") + atom_types = residue_constants.atom_types + + pdb_lines = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(int) + b_factors = prot.b_factors + chain_index = prot.chain_index + + if np.any(aatype > residue_constants.restype_num): + raise ValueError("Invalid aatypes.") + + headers = get_pdb_headers(prot) + if(len(headers) > 0): + pdb_lines.extend(headers) + + n = aatype.shape[0] + atom_index = 1 + prev_chain_index = 0 + chain_tags = string.ascii_uppercase + # Add all atom sites. + for i in range(n): + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip( + atom_types, atom_positions[i], atom_mask[i], b_factors[i] + ): + if mask < 0.5: + continue + + record_type = "ATOM" + name = atom_name if len(atom_name) == 4 else f" {atom_name}" + alt_loc = "" + insertion_code = "" + occupancy = 1.00 + element = atom_name[ + 0 + ] # Protein supports only C, N, O, S, this works. + charge = "" + + chain_tag = "A" + if(chain_index is not None): + chain_tag = chain_tags[chain_index[i]] + + # PDB is a columnar format, every space matters here! + atom_line = ( + f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}" + f"{res_name_3:>3} {chain_tag:>1}" + f"{residue_index[i]:>4}{insertion_code:>1} " + f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}" + f"{occupancy:>6.2f}{b_factor:>6.2f} " + f"{element:>2}{charge:>2}" + ) + pdb_lines.append(atom_line) + atom_index += 1 + + should_terminate = (i == n - 1) + if(chain_index is not None): + if(i != n - 1 and chain_index[i + 1] != prev_chain_index): + should_terminate = True + prev_chain_index = chain_index[i + 1] + + if(should_terminate): + # Close the chain. 
+ chain_end = "TER" + chain_termination_line = ( + f"{chain_end:<6}{atom_index:>5} " + f"{res_1to3(aatype[i]):>3} " + f"{chain_tag:>1}{residue_index[i]:>4}" + ) + pdb_lines.append(chain_termination_line) + atom_index += 1 + + if(i != n - 1): + # "prev" is a misnomer here. This happens at the beginning of + # each new chain. + pdb_lines.extend(get_pdb_headers(prot, prev_chain_index)) + + pdb_lines.append("END") + pdb_lines.append("") + return "\n".join(pdb_lines) + + +def ideal_atom_mask(prot: Protein) -> np.ndarray: + """Computes an ideal atom mask. + + `Protein.atom_mask` typically is defined according to the atoms that are + reported in the PDB. This function computes a mask according to heavy atoms + that should be present in the given sequence of amino acids. + + Args: + prot: `Protein` whose fields are `numpy.ndarray` objects. + + Returns: + An ideal atom mask. + """ + return residue_constants.STANDARD_ATOM_MASK[prot.aatype] + + +def from_prediction( + features: FeatureDict, + result: ModelOutput, + b_factors: Optional[np.ndarray] = None, + chain_index: Optional[np.ndarray] = None, + remark: Optional[str] = None, + parents: Optional[Sequence[str]] = None, + parents_chain_index: Optional[Sequence[int]] = None +) -> Protein: + """Assembles a protein from a prediction. + + Args: + features: Dictionary holding model inputs. + result: Dictionary holding model outputs. + b_factors: (Optional) B-factors to use for the protein. + chain_index: (Optional) Chain indices for multi-chain predictions + remark: (Optional) Remark about the prediction + parents: (Optional) List of template names + Returns: + A protein instance. 
+ """ + if b_factors is None: + b_factors = np.zeros_like(result["final_atom_mask"]) + + return Protein( + aatype=features["aatype"], + atom_positions=result["final_atom_positions"], + atom_mask=result["final_atom_mask"], + residue_index=features["residue_index"] + 1, + b_factors=b_factors, + chain_index=chain_index, + remark=remark, + parents=parents, + parents_chain_index=parents_chain_index, + ) diff --git a/openfold/np/relax/__init__.py b/openfold/np/relax/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25d1a5ba4be7300076d6f44af1541a33ca9e4ab1 --- /dev/null +++ b/openfold/np/relax/__init__.py @@ -0,0 +1,16 @@ +import os +import glob +import importlib as importlib + +_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py")) +__all__ = [ + os.path.basename(f)[:-3] + for f in _files + if os.path.isfile(f) and not f.endswith("__init__.py") +] +_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__] +for _m in _modules: + globals()[_m[0]] = _m[1] + +# Avoid needlessly cluttering the global namespace +del _files, _m, _modules diff --git a/openfold/np/relax/amber_minimize.py b/openfold/np/relax/amber_minimize.py new file mode 100644 index 0000000000000000000000000000000000000000..8f35a2d9fdf56fca3e89d822b7c6a8f56a265ccd --- /dev/null +++ b/openfold/np/relax/amber_minimize.py @@ -0,0 +1,612 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Restrained Amber Minimization of a structure.""" + +import io +import time +from typing import Collection, Optional, Sequence + +from absl import logging +from openfold.np import ( + protein, + residue_constants, +) +import openfold.utils.loss as loss +from openfold.np.relax import cleanup, utils +import ml_collections +import numpy as np +import openmm +from openmm import unit +from openmm import app as openmm_app +from openmm.app.internal.pdbstructure import PdbStructure + +ENERGY = unit.kilocalories_per_mole +LENGTH = unit.angstroms + + +def will_restrain(atom: openmm_app.Atom, rset: str) -> bool: + """Returns True if the atom will be restrained by the given restraint set.""" + + if rset == "non_hydrogen": + return atom.element.name != "hydrogen" + elif rset == "c_alpha": + return atom.name == "CA" + + +def _add_restraints( + system: openmm.System, + reference_pdb: openmm_app.PDBFile, + stiffness: unit.Unit, + rset: str, + exclude_residues: Sequence[int], +): + """Adds a harmonic potential that restrains the system to a structure.""" + assert rset in ["non_hydrogen", "c_alpha"] + + force = openmm.CustomExternalForce( + "0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)" + ) + force.addGlobalParameter("k", stiffness) + for p in ["x0", "y0", "z0"]: + force.addPerParticleParameter(p) + + for i, atom in enumerate(reference_pdb.topology.atoms()): + if atom.residue.index in exclude_residues: + continue + if will_restrain(atom, rset): + force.addParticle(i, reference_pdb.positions[i]) + logging.info( + "Restraining %d / %d particles.", + force.getNumParticles(), + system.getNumParticles(), + ) + system.addForce(force) + + +def _openmm_minimize( + pdb_str: str, + max_iterations: int, + tolerance: unit.Unit, + stiffness: unit.Unit, + restraint_set: str, + exclude_residues: Sequence[int], + use_gpu: bool, +): + """Minimize energy via openmm.""" + + pdb_file = 
io.StringIO(pdb_str) + pdb = openmm_app.PDBFile(pdb_file) + + force_field = openmm_app.ForceField("amber99sb.xml") + constraints = openmm_app.HBonds + system = force_field.createSystem(pdb.topology, constraints=constraints) + if stiffness > 0 * ENERGY / (LENGTH ** 2): + _add_restraints(system, pdb, stiffness, restraint_set, exclude_residues) + + integrator = openmm.LangevinIntegrator(0, 0.01, 0.0) + platform = openmm.Platform.getPlatformByName("CUDA" if use_gpu else "CPU") + simulation = openmm_app.Simulation( + pdb.topology, system, integrator, platform + ) + simulation.context.setPositions(pdb.positions) + + ret = {} + state = simulation.context.getState(getEnergy=True, getPositions=True) + ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY) + ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH) + simulation.minimizeEnergy(maxIterations=max_iterations, tolerance=tolerance) + state = simulation.context.getState(getEnergy=True, getPositions=True) + ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY) + ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH) + ret["min_pdb"] = _get_pdb_string(simulation.topology, state.getPositions()) + return ret + + +def _get_pdb_string(topology: openmm_app.Topology, positions: unit.Quantity): + """Returns a pdb string provided OpenMM topology and positions.""" + with io.StringIO() as f: + openmm_app.PDBFile.writeFile(topology, positions, f) + return f.getvalue() + + +def _check_cleaned_atoms(pdb_cleaned_string: str, pdb_ref_string: str): + """Checks that no atom positions have been altered by cleaning.""" + cleaned = openmm_app.PDBFile(io.StringIO(pdb_cleaned_string)) + reference = openmm_app.PDBFile(io.StringIO(pdb_ref_string)) + + cl_xyz = np.array(cleaned.getPositions().value_in_unit(LENGTH)) + ref_xyz = np.array(reference.getPositions().value_in_unit(LENGTH)) + + for ref_res, cl_res in zip( + reference.topology.residues(), cleaned.topology.residues() + ): + assert 
ref_res.name == cl_res.name + for rat in ref_res.atoms(): + for cat in cl_res.atoms(): + if cat.name == rat.name: + if not np.array_equal( + cl_xyz[cat.index], ref_xyz[rat.index] + ): + raise ValueError( + f"Coordinates of cleaned atom {cat} do not match " + f"coordinates of reference atom {rat}." + ) + + +def _check_residues_are_well_defined(prot: protein.Protein): + """Checks that all residues contain non-empty atom sets.""" + if (prot.atom_mask.sum(axis=-1) == 0).any(): + raise ValueError( + "Amber minimization can only be performed on proteins with" + " well-defined residues. This protein contains at least" + " one residue with no atoms." + ) + + +def _check_atom_mask_is_ideal(prot): + """Sanity-check the atom mask is ideal, up to a possible OXT.""" + atom_mask = prot.atom_mask + ideal_atom_mask = protein.ideal_atom_mask(prot) + utils.assert_equal_nonterminal_atom_types(atom_mask, ideal_atom_mask) + + +def clean_protein(prot: protein.Protein, checks: bool = True): + """Adds missing atoms to Protein instance. + + Args: + prot: A `protein.Protein` instance. + checks: A `bool` specifying whether to add additional checks to the cleaning + process. + + Returns: + pdb_string: A string of the cleaned protein. + """ + _check_atom_mask_is_ideal(prot) + + # Clean pdb. + prot_pdb_string = protein.to_pdb(prot) + pdb_file = io.StringIO(prot_pdb_string) + alterations_info = {} + fixed_pdb = cleanup.fix_pdb(pdb_file, alterations_info) + fixed_pdb_file = io.StringIO(fixed_pdb) + pdb_structure = PdbStructure(fixed_pdb_file) + cleanup.clean_structure(pdb_structure, alterations_info) + + logging.info("alterations info: %s", alterations_info) + + # Write pdb file of cleaned structure. 
+ as_file = openmm_app.PDBFile(pdb_structure) + pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions()) + if checks: + _check_cleaned_atoms(pdb_string, prot_pdb_string) + return pdb_string + + +def make_atom14_positions(prot): + """Constructs denser atom positions (14 dimensions instead of 37).""" + restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37 + restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14 + restype_atom14_mask = [] + + for rt in residue_constants.restypes: + atom_names = residue_constants.restype_name_to_atom14_names[ + residue_constants.restype_1to3[rt] + ] + + restype_atom14_to_atom37.append( + [ + (residue_constants.atom_order[name] if name else 0) + for name in atom_names + ] + ) + + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append( + [ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in residue_constants.atom_types + ] + ) + + restype_atom14_mask.append( + [(1.0 if name else 0.0) for name in atom_names] + ) + + # Add dummy mapping for restype 'UNK'. + restype_atom14_to_atom37.append([0] * 14) + restype_atom37_to_atom14.append([0] * 37) + restype_atom14_mask.append([0.0] * 14) + + restype_atom14_to_atom37 = np.array( + restype_atom14_to_atom37, dtype=int + ) + restype_atom37_to_atom14 = np.array( + restype_atom37_to_atom14, dtype=int + ) + restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) + + # Create the mapping for (residx, atom14) --> atom37, i.e. an array + # with shape (num_res, 14) containing the atom37 indices for this protein. + residx_atom14_to_atom37 = restype_atom14_to_atom37[prot["aatype"]] + residx_atom14_mask = restype_atom14_mask[prot["aatype"]] + + # Create a mask for known ground truth positions. 
+ residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis( + prot["all_atom_mask"], residx_atom14_to_atom37, axis=1 + ).astype(np.float32) + + # Gather the ground truth positions. + residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * ( + np.take_along_axis( + prot["all_atom_positions"], + residx_atom14_to_atom37[..., None], + axis=1, + ) + ) + + prot["atom14_atom_exists"] = residx_atom14_mask + prot["atom14_gt_exists"] = residx_atom14_gt_mask + prot["atom14_gt_positions"] = residx_atom14_gt_positions + + prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37.astype(np.int64) + + # Create the gather indices for mapping back. + residx_atom37_to_atom14 = restype_atom37_to_atom14[prot["aatype"]] + prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14.astype(np.int64) + + # Create the corresponding mask. + restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) + for restype, restype_letter in enumerate(residue_constants.restypes): + restype_name = residue_constants.restype_1to3[restype_letter] + atom_names = residue_constants.residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = residue_constants.atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + + residx_atom37_mask = restype_atom37_mask[prot["aatype"]] + prot["atom37_atom_exists"] = residx_atom37_mask + + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative ground truth coordinates where the naming is swapped + restype_3 = [ + residue_constants.restype_1to3[res] + for res in residue_constants.restypes + ] + restype_3 += ["UNK"] + + # Matrices for renaming ambiguous atoms. 
+ all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3} + for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): + correspondences = np.arange(14) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = residue_constants.restype_name_to_atom14_names[ + resname + ].index(source_atom_swap) + target_index = residue_constants.restype_name_to_atom14_names[ + resname + ].index(target_atom_swap) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = np.zeros((14, 14), dtype=np.float32) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1.0 + all_matrices[resname] = renaming_matrix.astype(np.float32) + renaming_matrices = np.stack( + [all_matrices[restype] for restype in restype_3] + ) + + # Pick the transformation matrices for the given residue sequence + # shape (num_res, 14, 14). + renaming_transform = renaming_matrices[prot["aatype"]] + + # Apply it to the ground truth positions. shape (num_res, 14, 3). + alternative_gt_positions = np.einsum( + "rac,rab->rbc", residx_atom14_gt_positions, renaming_transform + ) + prot["atom14_alt_gt_positions"] = alternative_gt_positions + + # Create the mask for the alternative ground truth (differs from the + # ground truth mask, if only one of the atoms in an ambiguous pair has a + # ground truth position). + alternative_gt_mask = np.einsum( + "ra,rab->rb", residx_atom14_gt_mask, renaming_transform + ) + + prot["atom14_alt_gt_exists"] = alternative_gt_mask + + # Create an ambiguous atoms mask. shape: (21, 14). 
+ restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32) + for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): + for atom_name1, atom_name2 in swap.items(): + restype = residue_constants.restype_order[ + residue_constants.restype_3to1[resname] + ] + atom_idx1 = residue_constants.restype_name_to_atom14_names[ + resname + ].index(atom_name1) + atom_idx2 = residue_constants.restype_name_to_atom14_names[ + resname + ].index(atom_name2) + restype_atom14_is_ambiguous[restype, atom_idx1] = 1 + restype_atom14_is_ambiguous[restype, atom_idx2] = 1 + + # From this create an ambiguous_mask for the given sequence. + prot["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[ + prot["aatype"] + ] + + return prot + + +def find_violations(prot_np: protein.Protein): + """Analyzes a protein and returns structural violation information. + + Args: + prot_np: A protein. + + Returns: + violations: A `dict` of structure components with structural violations. + violation_metrics: A `dict` of violation metrics. + """ + batch = { + "aatype": prot_np.aatype, + "all_atom_positions": prot_np.atom_positions.astype(np.float32), + "all_atom_mask": prot_np.atom_mask.astype(np.float32), + "residue_index": prot_np.residue_index, + } + + batch["seq_mask"] = np.ones_like(batch["aatype"], np.float32) + batch = make_atom14_positions(batch) + + violations = loss.find_structural_violations_np( + batch=batch, + atom14_pred_positions=batch["atom14_gt_positions"], + config=ml_collections.ConfigDict( + { + "violation_tolerance_factor": 12, # Taken from model config. + "clash_overlap_tolerance": 1.5, # Taken from model config. 
+ } + ), + ) + violation_metrics = loss.compute_violation_metrics_np( + batch=batch, + atom14_pred_positions=batch["atom14_gt_positions"], + violations=violations, + ) + + return violations, violation_metrics + + +def get_violation_metrics(prot: protein.Protein): + """Computes violation and alignment metrics.""" + structural_violations, struct_metrics = find_violations(prot) + violation_idx = np.flatnonzero( + structural_violations["total_per_residue_violations_mask"] + ) + + struct_metrics["residue_violations"] = violation_idx + struct_metrics["num_residue_violations"] = len(violation_idx) + struct_metrics["structural_violations"] = structural_violations + return struct_metrics + + +def _run_one_iteration( + *, + pdb_string: str, + max_iterations: int, + tolerance: float, + stiffness: float, + restraint_set: str, + max_attempts: int, + exclude_residues: Optional[Collection[int]] = None, + use_gpu: bool, +): + """Runs the minimization pipeline. + + Args: + pdb_string: A pdb string. + max_iterations: An `int` specifying the maximum number of L-BFGS iterations. + A value of 0 specifies no limit. + tolerance: kcal/mol, the energy tolerance of L-BFGS. + stiffness: kcal/mol A**2, spring constant of heavy atom restraining + potential. + restraint_set: The set of atoms to restrain. + max_attempts: The maximum number of minimization attempts. + exclude_residues: An optional list of zero-indexed residues to exclude from + restraints. + use_gpu: Whether to run relaxation on GPU + Returns: + A `dict` of minimization info. + """ + exclude_residues = exclude_residues or [] + + # Assign physical dimensions. 
+ tolerance = tolerance * ENERGY + stiffness = stiffness * ENERGY / (LENGTH ** 2) + + start = time.perf_counter() + minimized = False + attempts = 0 + while not minimized and attempts < max_attempts: + attempts += 1 + try: + logging.info( + "Minimizing protein, attempt %d of %d.", attempts, max_attempts + ) + ret = _openmm_minimize( + pdb_string, + max_iterations=max_iterations, + tolerance=tolerance, + stiffness=stiffness, + restraint_set=restraint_set, + exclude_residues=exclude_residues, + use_gpu=use_gpu, + ) + minimized = True + except Exception as e: # pylint: disable=broad-except + print(e) + logging.info(e) + if not minimized: + raise ValueError(f"Minimization failed after {max_attempts} attempts.") + ret["opt_time"] = time.perf_counter() - start + ret["min_attempts"] = attempts + return ret + + +def run_pipeline( + prot: protein.Protein, + stiffness: float, + use_gpu: bool, + max_outer_iterations: int = 1, + place_hydrogens_every_iteration: bool = True, + max_iterations: int = 0, + tolerance: float = 2.39, + restraint_set: str = "non_hydrogen", + max_attempts: int = 100, + checks: bool = True, + exclude_residues: Optional[Sequence[int]] = None, +): + """Run iterative amber relax. + + Successive relax iterations are performed until all violations have been + resolved. Each iteration involves a restrained Amber minimization, with + restraint exclusions determined by violation-participating residues. + + Args: + prot: A protein to be relaxed. + stiffness: kcal/mol A**2, the restraint stiffness. + use_gpu: Whether to run on GPU + max_outer_iterations: The maximum number of iterative minimization. + place_hydrogens_every_iteration: Whether hydrogens are re-initialized + prior to every minimization. + max_iterations: An `int` specifying the maximum number of L-BFGS steps + per relax iteration. A value of 0 specifies no limit. + tolerance: kcal/mol, the energy tolerance of L-BFGS. + The default value is the OpenMM default. 
+ restraint_set: The set of atoms to restrain. + max_attempts: The maximum number of minimization attempts per iteration. + checks: Whether to perform cleaning checks. + exclude_residues: An optional list of zero-indexed residues to exclude from + restraints. + + Returns: + out: A dictionary of output values. + """ + + # `protein.to_pdb` will strip any poorly-defined residues so we need to + # perform this check before `clean_protein`. + _check_residues_are_well_defined(prot) + pdb_string = clean_protein(prot, checks=checks) + + exclude_residues = exclude_residues or [] + exclude_residues = set(exclude_residues) + violations = np.inf + iteration = 0 + + while violations > 0 and iteration < max_outer_iterations: + ret = _run_one_iteration( + pdb_string=pdb_string, + exclude_residues=exclude_residues, + max_iterations=max_iterations, + tolerance=tolerance, + stiffness=stiffness, + restraint_set=restraint_set, + max_attempts=max_attempts, + use_gpu=use_gpu, + ) + prot = protein.from_pdb_string(ret["min_pdb"]) + if place_hydrogens_every_iteration: + pdb_string = clean_protein(prot, checks=True) + else: + pdb_string = ret["min_pdb"] + ret.update(get_violation_metrics(prot)) + ret.update( + { + "num_exclusions": len(exclude_residues), + "iteration": iteration, + } + ) + violations = ret["violations_per_residue"] + exclude_residues = exclude_residues.union(ret["residue_violations"]) + + logging.info( + "Iteration completed: Einit %.2f Efinal %.2f Time %.2f s " + "num residue violations %d num residue exclusions %d ", + ret["einit"], + ret["efinal"], + ret["opt_time"], + ret["num_residue_violations"], + ret["num_exclusions"], + ) + iteration += 1 + return ret + + +def get_initial_energies( + pdb_strs: Sequence[str], + stiffness: float = 0.0, + restraint_set: str = "non_hydrogen", + exclude_residues: Optional[Sequence[int]] = None, +): + """Returns initial potential energies for a sequence of PDBs. 
+ + Assumes the input PDBs are ready for minimization, and all have the same + topology. + Allows time to be saved by not pdbfixing / rebuilding the system. + + Args: + pdb_strs: List of PDB strings. + stiffness: kcal/mol A**2, spring constant of heavy atom restraining + potential. + restraint_set: Which atom types to restrain. + exclude_residues: An optional list of zero-indexed residues to exclude from + restraints. + + Returns: + A list of initial energies in the same order as pdb_strs. + """ + exclude_residues = exclude_residues or [] + + openmm_pdbs = [ + openmm_app.PDBFile(PdbStructure(io.StringIO(p))) for p in pdb_strs + ] + force_field = openmm_app.ForceField("amber99sb.xml") + system = force_field.createSystem( + openmm_pdbs[0].topology, constraints=openmm_app.HBonds + ) + stiffness = stiffness * ENERGY / (LENGTH ** 2) + if stiffness > 0 * ENERGY / (LENGTH ** 2): + _add_restraints( + system, openmm_pdbs[0], stiffness, restraint_set, exclude_residues + ) + simulation = openmm_app.Simulation( + openmm_pdbs[0].topology, + system, + openmm.LangevinIntegrator(0, 0.01, 0.0), + openmm.Platform.getPlatformByName("CPU"), + ) + energies = [] + for pdb in openmm_pdbs: + try: + simulation.context.setPositions(pdb.positions) + state = simulation.context.getState(getEnergy=True) + energies.append(state.getPotentialEnergy().value_in_unit(ENERGY)) + except Exception as e: # pylint: disable=broad-except + logging.error( + "Error getting initial energy, returning large value %s", e + ) + energies.append(unit.Quantity(1e20, ENERGY)) + return energies diff --git a/openfold/np/relax/cleanup.py b/openfold/np/relax/cleanup.py new file mode 100644 index 0000000000000000000000000000000000000000..a1f59a48b3b2184c68ccbbce9048b6a137086ca4 --- /dev/null +++ b/openfold/np/relax/cleanup.py @@ -0,0 +1,131 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations. + +fix_pdb uses a third-party tool. We also support fixing some additional edge +cases like removing chains of length one (see clean_structure). +""" +import io + +import pdbfixer +from simtk.openmm import app +from simtk.openmm.app import element + + +def fix_pdb(pdbfile, alterations_info): + """Apply pdbfixer to the contents of a PDB file; return a PDB string result. + + 1) Replaces nonstandard residues. + 2) Removes heterogens (non protein residues) including water. + 3) Adds missing residues and missing atoms within existing residues. + 4) Adds hydrogens assuming pH=7.0. + 5) KeepIds is currently true, so the fixer must keep the existing chain and + residue identifiers. This will fail for some files in wider PDB that have + invalid IDs. + + Args: + pdbfile: Input PDB file handle. + alterations_info: A dict that will store details of changes made. + + Returns: + A PDB string representing the fixed structure. 
+ """ + fixer = pdbfixer.PDBFixer(pdbfile=pdbfile) + fixer.findNonstandardResidues() + alterations_info["nonstandard_residues"] = fixer.nonstandardResidues + fixer.replaceNonstandardResidues() + _remove_heterogens(fixer, alterations_info, keep_water=False) + fixer.findMissingResidues() + alterations_info["missing_residues"] = fixer.missingResidues + fixer.findMissingAtoms() + alterations_info["missing_heavy_atoms"] = fixer.missingAtoms + alterations_info["missing_terminals"] = fixer.missingTerminals + fixer.addMissingAtoms(seed=0) + fixer.addMissingHydrogens() + out_handle = io.StringIO() + app.PDBFile.writeFile( + fixer.topology, fixer.positions, out_handle, keepIds=True + ) + return out_handle.getvalue() + + +def clean_structure(pdb_structure, alterations_info): + """Applies additional fixes to an OpenMM structure, to handle edge cases. + + Args: + pdb_structure: An OpenMM structure to modify and fix. + alterations_info: A dict that will store details of changes made. + """ + _replace_met_se(pdb_structure, alterations_info) + _remove_chains_of_length_one(pdb_structure, alterations_info) + + +def _remove_heterogens(fixer, alterations_info, keep_water): + """Removes the residues that Pdbfixer considers to be heterogens. + + Args: + fixer: A Pdbfixer instance. + alterations_info: A dict that will store details of changes made. + keep_water: If True, water (HOH) is not considered to be a heterogen. 
+ """ + initial_resnames = set() + for chain in fixer.topology.chains(): + for residue in chain.residues(): + initial_resnames.add(residue.name) + fixer.removeHeterogens(keepWater=keep_water) + final_resnames = set() + for chain in fixer.topology.chains(): + for residue in chain.residues(): + final_resnames.add(residue.name) + alterations_info["removed_heterogens"] = initial_resnames.difference( + final_resnames + ) + + +def _replace_met_se(pdb_structure, alterations_info): + """Replace the Se in any MET residues that were not marked as modified.""" + modified_met_residues = [] + for res in pdb_structure.iter_residues(): + name = res.get_name_with_spaces().strip() + if name == "MET": + s_atom = res.get_atom("SD") + if s_atom.element_symbol == "Se": + s_atom.element_symbol = "S" + s_atom.element = element.get_by_symbol("S") + modified_met_residues.append(s_atom.residue_number) + alterations_info["Se_in_MET"] = modified_met_residues + + +def _remove_chains_of_length_one(pdb_structure, alterations_info): + """Removes chains that correspond to a single amino acid. + + A single amino acid in a chain is both N and C terminus. There is no force + template for this case. + + Args: + pdb_structure: An OpenMM pdb_structure to modify and fix. + alterations_info: A dict that will store details of changes made. 
+ """ + removed_chains = {} + for model in pdb_structure.iter_models(): + valid_chains = [c for c in model.iter_chains() if len(c) > 1] + invalid_chain_ids = [ + c.chain_id for c in model.iter_chains() if len(c) <= 1 + ] + model.chains = valid_chains + for chain_id in invalid_chain_ids: + model.chains_by_id.pop(chain_id) + removed_chains[model.number] = invalid_chain_ids + alterations_info["removed_chains"] = removed_chains diff --git a/openfold/np/relax/relax.py b/openfold/np/relax/relax.py new file mode 100644 index 0000000000000000000000000000000000000000..8feee742d033ee3165f5975cb14e13c599bfb364 --- /dev/null +++ b/openfold/np/relax/relax.py @@ -0,0 +1,90 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Amber relaxation.""" +from typing import Any, Dict, Sequence, Tuple +from openfold.np import protein +from openfold.np.relax import amber_minimize, utils +import numpy as np + + +class AmberRelaxation(object): + """Amber relaxation.""" + def __init__( + self, + *, + max_iterations: int, + tolerance: float, + stiffness: float, + exclude_residues: Sequence[int], + max_outer_iterations: int, + use_gpu: bool, + ): + """Initialize Amber Relaxer. + + Args: + max_iterations: Maximum number of L-BFGS iterations. 0 means no max. + tolerance: kcal/mol, the energy tolerance of L-BFGS. + stiffness: kcal/mol A**2, spring constant of heavy atom restraining + potential. 
+ exclude_residues: Residues to exclude from per-atom restraining. + Zero-indexed. + max_outer_iterations: Maximum number of violation-informed relax + iterations. A value of 1 will run the non-iterative procedure used in + CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes + as soon as there are no violations, hence in most cases this causes no + slowdown. In the worst case we do 20 outer iterations. + use_gpu: Whether to run on GPU + """ + + self._max_iterations = max_iterations + self._tolerance = tolerance + self._stiffness = stiffness + self._exclude_residues = exclude_residues + self._max_outer_iterations = max_outer_iterations + self._use_gpu = use_gpu + + def process( + self, *, prot: protein.Protein + ) -> Tuple[str, Dict[str, Any], np.ndarray]: + """Runs Amber relax on a prediction, adds hydrogens, returns PDB string.""" + out = amber_minimize.run_pipeline( + prot=prot, + max_iterations=self._max_iterations, + tolerance=self._tolerance, + stiffness=self._stiffness, + exclude_residues=self._exclude_residues, + max_outer_iterations=self._max_outer_iterations, + use_gpu=self._use_gpu, + ) + min_pos = out["pos"] + start_pos = out["posinit"] + rmsd = np.sqrt(np.sum((start_pos - min_pos) ** 2) / start_pos.shape[0]) + debug_data = { + "initial_energy": out["einit"], + "final_energy": out["efinal"], + "attempts": out["min_attempts"], + "rmsd": rmsd, + } + pdb_str = amber_minimize.clean_protein(prot) + min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos) + min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors) + utils.assert_equal_nonterminal_atom_types( + protein.from_pdb_string(min_pdb).atom_mask, prot.atom_mask + ) + violations = out["structural_violations"][ + "total_per_residue_violations_mask" + ] + return min_pdb, debug_data, violations diff --git a/openfold/np/relax/utils.py b/openfold/np/relax/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1e41aea524a97ddd563dc32963b4671f438c7ae4 --- 
/dev/null +++ b/openfold/np/relax/utils.py @@ -0,0 +1,88 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils for minimization.""" +import io +from openfold.np import residue_constants +from Bio import PDB +import numpy as np +# simtk.openmm is not supported anymore. Remove simtk. +# https://github.com/openmm/openmm/releases +from openmm import app as openmm_app +from openmm.app.internal.pdbstructure import PdbStructure + + +def overwrite_pdb_coordinates(pdb_str: str, pos) -> str: + pdb_file = io.StringIO(pdb_str) + structure = PdbStructure(pdb_file) + topology = openmm_app.PDBFile(structure).getTopology() + with io.StringIO() as f: + openmm_app.PDBFile.writeFile(topology, pos, f) + return f.getvalue() + + +def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str: + """Overwrites the B-factors in pdb_str with contents of bfactors array. + + Args: + pdb_str: An input PDB string. + bfactors: A numpy array with shape [1, n_residues, 37]. We assume that the + B-factors are per residue; i.e. that the nonzero entries are identical in + [0, i, :]. + + Returns: + A new PDB string with the B-factors replaced. + """ + if bfactors.shape[-1] != residue_constants.atom_type_num: + raise ValueError( + f"Invalid final dimension size for bfactors: {bfactors.shape[-1]}." 
+ ) + + parser = PDB.PDBParser(QUIET=True) + handle = io.StringIO(pdb_str) + structure = parser.get_structure("", handle) + + curr_resid = ("", "", "") + idx = -1 + for atom in structure.get_atoms(): + atom_resid = atom.parent.get_id() + if atom_resid != curr_resid: + idx += 1 + if idx >= bfactors.shape[0]: + raise ValueError( + "Index into bfactors exceeds number of residues. " + "B-factors shape: {shape}, idx: {idx}." + ) + curr_resid = atom_resid + atom.bfactor = bfactors[idx, residue_constants.atom_order["CA"]] + + new_pdb = io.StringIO() + pdb_io = PDB.PDBIO() + pdb_io.set_structure(structure) + pdb_io.save(new_pdb) + return new_pdb.getvalue() + + +def assert_equal_nonterminal_atom_types( + atom_mask: np.ndarray, ref_atom_mask: np.ndarray +): + """Checks that pre- and post-minimized proteins have same atom set.""" + # Ignore any terminal OXT atoms which may have been added by minimization. + oxt = residue_constants.atom_order["OXT"] + no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=np.bool) + no_oxt_mask[..., oxt] = False + np.testing.assert_almost_equal( + ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask] + ) diff --git a/openfold/np/residue_constants.py b/openfold/np/residue_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..ca3eae6fe48634e2dee88375c4a38c293305c5b0 --- /dev/null +++ b/openfold/np/residue_constants.py @@ -0,0 +1,1310 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants used in AlphaFold.""" + +import collections +import functools +from typing import Mapping, List, Tuple +from importlib import resources + +import numpy as np +import tree + +# Internal import (35fd). + + +# Distance from one CA to next CA [trans configuration: omega = 180]. +ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms = { + "ALA": [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + "ARG": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "NE"], + ["CG", "CD", "NE", "CZ"], + ], + "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], + "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], + "CYS": [["N", "CA", "CB", "SG"]], + "GLN": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "OE1"], + ], + "GLU": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "OE1"], + ], + "GLY": [], + "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]], + "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]], + "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "LYS": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "CE"], + ["CG", "CD", "CE", "NZ"], + ], + "MET": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "SD"], + ["CB", "CG", "SD", "CE"], + ], + "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]], + "SER": [["N", "CA", "CB", "OG"]], + "THR": [["N", "CA", "CB", "OG1"]], + "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "VAL": [["N", "CA", "CB", "CG1"]], 
+} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). +chi_angles_mask = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [1.0, 1.0, 1.0, 1.0], # ARG + [1.0, 1.0, 0.0, 0.0], # ASN + [1.0, 1.0, 0.0, 0.0], # ASP + [1.0, 0.0, 0.0, 0.0], # CYS + [1.0, 1.0, 1.0, 0.0], # GLN + [1.0, 1.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [1.0, 1.0, 0.0, 0.0], # HIS + [1.0, 1.0, 0.0, 0.0], # ILE + [1.0, 1.0, 0.0, 0.0], # LEU + [1.0, 1.0, 1.0, 1.0], # LYS + [1.0, 1.0, 1.0, 0.0], # MET + [1.0, 1.0, 0.0, 0.0], # PHE + [1.0, 1.0, 0.0, 0.0], # PRO + [1.0, 0.0, 0.0, 0.0], # SER + [1.0, 0.0, 0.0, 0.0], # THR + [1.0, 1.0, 0.0, 0.0], # TRP + [1.0, 1.0, 0.0, 0.0], # TYR + [1.0, 0.0, 0.0, 0.0], # VAL +] + +# The following chi angles are pi periodic: they can be rotated by a multiple +# of pi without affecting the structure. +chi_pi_periodic = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [0.0, 0.0, 0.0, 0.0], # ARG + [0.0, 0.0, 0.0, 0.0], # ASN + [0.0, 1.0, 0.0, 0.0], # ASP + [0.0, 0.0, 0.0, 0.0], # CYS + [0.0, 0.0, 0.0, 0.0], # GLN + [0.0, 0.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [0.0, 0.0, 0.0, 0.0], # HIS + [0.0, 0.0, 0.0, 0.0], # ILE + [0.0, 0.0, 0.0, 0.0], # LEU + [0.0, 0.0, 0.0, 0.0], # LYS + [0.0, 0.0, 0.0, 0.0], # MET + [0.0, 1.0, 0.0, 0.0], # PHE + [0.0, 0.0, 0.0, 0.0], # PRO + [0.0, 0.0, 0.0, 0.0], # SER + [0.0, 0.0, 0.0, 0.0], # THR + [0.0, 0.0, 0.0, 0.0], # TRP + [0.0, 1.0, 0.0, 0.0], # TYR + [0.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # UNK +] + +# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, +# psi and chi angles: +# 0: 'backbone group', +# 1: 'pre-omega-group', (empty) +# 2: 'phi-group', (currently empty, because it defines only hydrogens) +# 3: 'psi-group', +# 4,5,6,7: 'chi1,2,3,4-group' +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. 
The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). +# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions = { + "ALA": [ + ["N", 0, (-0.525, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.529, -0.774, -1.205)], + ["O", 3, (0.627, 1.062, 0.000)], + ], + "ARG": [ + ["N", 0, (-0.524, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.524, -0.778, -1.209)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG", 4, (0.616, 1.390, -0.000)], + ["CD", 5, (0.564, 1.414, 0.000)], + ["NE", 6, (0.539, 1.357, -0.000)], + ["NH1", 7, (0.206, 2.301, 0.000)], + ["NH2", 7, (2.078, 0.978, -0.000)], + ["CZ", 7, (0.758, 1.093, -0.000)], + ], + "ASN": [ + ["N", 0, (-0.536, 1.357, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.531, -0.787, -1.200)], + ["O", 3, (0.625, 1.062, 0.000)], + ["CG", 4, (0.584, 1.399, 0.000)], + ["ND2", 5, (0.593, -1.188, 0.001)], + ["OD1", 5, (0.633, 1.059, 0.000)], + ], + "ASP": [ + ["N", 0, (-0.525, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, 0.000, -0.000)], + ["CB", 0, (-0.526, -0.778, -1.208)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.593, 1.398, -0.000)], + ["OD1", 5, (0.610, 1.091, 0.000)], + ["OD2", 5, (0.592, -1.101, -0.003)], + ], + "CYS": [ + ["N", 0, (-0.522, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, 0.000, 0.000)], + ["CB", 0, (-0.519, -0.773, -1.212)], + ["O", 3, (0.625, 1.062, -0.000)], + ["SG", 4, (0.728, 1.653, 0.000)], + ], + "GLN": [ + ["N", 0, (-0.526, 1.361, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, 0.000)], + ["CB", 0, (-0.525, -0.779, -1.207)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.615, 1.393, 
0.000)], + ["CD", 5, (0.587, 1.399, -0.000)], + ["NE2", 6, (0.593, -1.189, -0.001)], + ["OE1", 6, (0.634, 1.060, 0.000)], + ], + "GLU": [ + ["N", 0, (-0.528, 1.361, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.526, -0.781, -1.207)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG", 4, (0.615, 1.392, 0.000)], + ["CD", 5, (0.600, 1.397, 0.000)], + ["OE1", 6, (0.607, 1.095, -0.000)], + ["OE2", 6, (0.589, -1.104, -0.001)], + ], + "GLY": [ + ["N", 0, (-0.572, 1.337, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.517, -0.000, -0.000)], + ["O", 3, (0.626, 1.062, -0.000)], + ], + "HIS": [ + ["N", 0, (-0.527, 1.360, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, 0.000, 0.000)], + ["CB", 0, (-0.525, -0.778, -1.208)], + ["O", 3, (0.625, 1.063, 0.000)], + ["CG", 4, (0.600, 1.370, -0.000)], + ["CD2", 5, (0.889, -1.021, 0.003)], + ["ND1", 5, (0.744, 1.160, -0.000)], + ["CE1", 5, (2.030, 0.851, 0.002)], + ["NE2", 5, (2.145, -0.466, 0.004)], + ], + "ILE": [ + ["N", 0, (-0.493, 1.373, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, -0.000)], + ["CB", 0, (-0.536, -0.793, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG1", 4, (0.534, 1.437, -0.000)], + ["CG2", 4, (0.540, -0.785, -1.199)], + ["CD1", 5, (0.619, 1.391, 0.000)], + ], + "LEU": [ + ["N", 0, (-0.520, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.522, -0.773, -1.214)], + ["O", 3, (0.625, 1.063, -0.000)], + ["CG", 4, (0.678, 1.371, 0.000)], + ["CD1", 5, (0.530, 1.430, -0.000)], + ["CD2", 5, (0.535, -0.774, 1.200)], + ], + "LYS": [ + ["N", 0, (-0.526, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, 0.000)], + ["CB", 0, (-0.524, -0.778, -1.208)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.619, 1.390, 0.000)], + ["CD", 5, (0.559, 1.417, 0.000)], + ["CE", 6, (0.560, 1.416, 0.000)], + ["NZ", 7, (0.554, 1.387, 0.000)], + ], + "MET": 
[ + ["N", 0, (-0.521, 1.364, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, 0.000, 0.000)], + ["CB", 0, (-0.523, -0.776, -1.210)], + ["O", 3, (0.625, 1.062, -0.000)], + ["CG", 4, (0.613, 1.391, -0.000)], + ["SD", 5, (0.703, 1.695, 0.000)], + ["CE", 6, (0.320, 1.786, -0.000)], + ], + "PHE": [ + ["N", 0, (-0.518, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, 0.000, -0.000)], + ["CB", 0, (-0.525, -0.776, -1.212)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.607, 1.377, 0.000)], + ["CD1", 5, (0.709, 1.195, -0.000)], + ["CD2", 5, (0.706, -1.196, 0.000)], + ["CE1", 5, (2.102, 1.198, -0.000)], + ["CE2", 5, (2.098, -1.201, -0.000)], + ["CZ", 5, (2.794, -0.003, -0.001)], + ], + "PRO": [ + ["N", 0, (-0.566, 1.351, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, 0.000)], + ["CB", 0, (-0.546, -0.611, -1.293)], + ["O", 3, (0.621, 1.066, 0.000)], + ["CG", 4, (0.382, 1.445, 0.0)], + # ['CD', 5, (0.427, 1.440, 0.0)], + ["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + "SER": [ + ["N", 0, (-0.529, 1.360, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.518, -0.777, -1.211)], + ["O", 3, (0.626, 1.062, -0.000)], + ["OG", 4, (0.503, 1.325, 0.000)], + ], + "THR": [ + ["N", 0, (-0.517, 1.364, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, -0.000)], + ["CB", 0, (-0.516, -0.793, -1.215)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG2", 4, (0.550, -0.718, -1.228)], + ["OG1", 4, (0.472, 1.353, 0.000)], + ], + "TRP": [ + ["N", 0, (-0.521, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, 0.000)], + ["CB", 0, (-0.523, -0.776, -1.212)], + ["O", 3, (0.627, 1.062, 0.000)], + ["CG", 4, (0.609, 1.370, -0.000)], + ["CD1", 5, (0.824, 1.091, 0.000)], + ["CD2", 5, (0.854, -1.148, -0.005)], + ["CE2", 5, (2.186, -0.678, -0.007)], + ["CE3", 5, (0.622, -2.530, -0.007)], + ["NE1", 5, (2.140, 0.690, 
-0.004)], + ["CH2", 5, (3.028, -2.890, -0.013)], + ["CZ2", 5, (3.283, -1.543, -0.011)], + ["CZ3", 5, (1.715, -3.389, -0.011)], + ], + "TYR": [ + ["N", 0, (-0.522, 1.362, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, -0.000, -0.000)], + ["CB", 0, (-0.522, -0.776, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG", 4, (0.607, 1.382, -0.000)], + ["CD1", 5, (0.716, 1.195, -0.000)], + ["CD2", 5, (0.713, -1.194, -0.001)], + ["CE1", 5, (2.107, 1.200, -0.002)], + ["CE2", 5, (2.104, -1.201, -0.003)], + ["OH", 5, (4.168, -0.002, -0.005)], + ["CZ", 5, (2.791, -0.001, -0.003)], + ], + "VAL": [ + ["N", 0, (-0.494, 1.373, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, -0.000)], + ["CB", 0, (-0.533, -0.795, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG1", 4, (0.540, 1.429, -0.000)], + ["CG2", 4, (0.533, -0.776, 1.203)], + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms = { + "ALA": ["C", "CA", "CB", "N", "O"], + "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"], + "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"], + "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"], + "CYS": ["C", "CA", "CB", "N", "O", "SG"], + "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"], + "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"], + "GLY": ["C", "CA", "N", "O"], + "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"], + "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"], + "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"], + "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"], + "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"], + "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"], + "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"], + "SER": ["C", "CA", "CB", "N", "O", "OG"], + "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"], + "TRP": [ + "C", + "CA", + "CB", + "CG", + "CD1", + 
"CD2", + "CE2", + "CE3", + "CZ2", + "CZ3", + "CH2", + "N", + "NE1", + "O", + ], + "TYR": [ + "C", + "CA", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "N", + "O", + "OH", + ], + "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"], +} + +# Naming swaps for ambiguous atom names. +# Due to symmetries in the amino acids the naming of atoms is ambiguous in +# 4 of the 20 amino acids. +# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities +# in LEU, VAL and ARG can be resolved by using the 3d constellations of +# the 'ambiguous' atoms and their neighbours) +# TODO: ^ interpret this +residue_atom_renaming_swaps = { + "ASP": {"OD1": "OD2"}, + "GLU": {"OE1": "OE2"}, + "PHE": {"CD1": "CD2", "CE1": "CE2"}, + "TYR": {"CD1": "CD2", "CE1": "CE2"}, +} + +# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) +van_der_waals_radius = { + "C": 1.7, + "N": 1.55, + "O": 1.52, + "S": 1.8, +} + +Bond = collections.namedtuple( + "Bond", ["atom1_name", "atom2_name", "length", "stddev"] +) +BondAngle = collections.namedtuple( + "BondAngle", + ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"], +) + + +@functools.lru_cache(maxsize=None) +def load_stereo_chemical_props() -> Tuple[ + Mapping[str, List[Bond]], + Mapping[str, List[Bond]], + Mapping[str, List[BondAngle]], +]: + """Load stereo_chemical_props.txt into a nice structure. + + Load literature values for bond lengths and bond angles and translate + bond angles into the length of the opposite edge of the triangle + ("residue_virtual_bonds"). 
+ + Returns: + residue_bonds: dict that maps resname --> list of Bond tuples + residue_virtual_bonds: dict that maps resname --> list of Bond tuples + residue_bond_angles: dict that maps resname --> list of BondAngle tuples + """ + # TODO: this file should be downloaded in a setup script + stereo_chemical_props = resources.read_text("openfold.resources", "stereo_chemical_props.txt") + + lines_iter = iter(stereo_chemical_props.splitlines()) + # Load bond lengths. + residue_bonds = {} + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == "-": + break + bond, resname, length, stddev = line.split() + atom1, atom2 = bond.split("-") + if resname not in residue_bonds: + residue_bonds[resname] = [] + residue_bonds[resname].append( + Bond(atom1, atom2, float(length), float(stddev)) + ) + residue_bonds["UNK"] = [] + + # Load bond angles. + residue_bond_angles = {} + next(lines_iter) # Skip empty line. + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == "-": + break + bond, resname, angle_degree, stddev_degree = line.split() + atom1, atom2, atom3 = bond.split("-") + if resname not in residue_bond_angles: + residue_bond_angles[resname] = [] + residue_bond_angles[resname].append( + BondAngle( + atom1, + atom2, + atom3, + float(angle_degree) / 180.0 * np.pi, + float(stddev_degree) / 180.0 * np.pi, + ) + ) + residue_bond_angles["UNK"] = [] + + def make_bond_key(atom1_name, atom2_name): + """Unique key to lookup bonds.""" + return "-".join(sorted([atom1_name, atom2_name])) + + # Translate bond angles into distances ("virtual bonds"). + residue_virtual_bonds = {} + for resname, bond_angles in residue_bond_angles.items(): + # Create a fast lookup dict for bond lengths. 
+ bond_cache = {} + for b in residue_bonds[resname]: + bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b + residue_virtual_bonds[resname] = [] + for ba in bond_angles: + bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] + bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] + + # Compute distance between atom1 and atom3 using the law of cosines + # c^2 = a^2 + b^2 - 2ab*cos(gamma). + gamma = ba.angle_rad + length = np.sqrt( + bond1.length ** 2 + + bond2.length ** 2 + - 2 * bond1.length * bond2.length * np.cos(gamma) + ) + + # Propagation of uncertainty assuming uncorrelated errors. + dl_outer = 0.5 / length + dl_dgamma = ( + 2 * bond1.length * bond2.length * np.sin(gamma) + ) * dl_outer + dl_db1 = ( + 2 * bond1.length - 2 * bond2.length * np.cos(gamma) + ) * dl_outer + dl_db2 = ( + 2 * bond2.length - 2 * bond1.length * np.cos(gamma) + ) * dl_outer + stddev = np.sqrt( + (dl_dgamma * ba.stddev) ** 2 + + (dl_db1 * bond1.stddev) ** 2 + + (dl_db2 * bond2.stddev) ** 2 + ) + residue_virtual_bonds[resname].append( + Bond(ba.atom1_name, ba.atom3name, length, stddev) + ) + + return (residue_bonds, residue_virtual_bonds, residue_bond_angles) + + +# Between-residue bond lengths for general bonds (first element) and for Proline +# (second element). +between_res_bond_length_c_n = [1.329, 1.341] +between_res_bond_length_stddev_c_n = [0.014, 0.016] + +# Between-residue cos_angles. +between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995 + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). 
+atom_types = [ + "N", + "CA", + "C", + "CB", + "O", + "CG", + "CG1", + "CG2", + "OG", + "OG1", + "SG", + "CD", + "CD1", + "CD2", + "ND1", + "ND2", + "OD1", + "OD2", + "SD", + "CE", + "CE1", + "CE2", + "CE3", + "NE", + "NE1", + "NE2", + "OE1", + "OE2", + "CH2", + "NH1", + "NH2", + "OH", + "CZ", + "CZ2", + "CZ3", + "NZ", + "OXT", +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. + +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names = { + "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""], + "ARG": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "NE", + "CZ", + "NH1", + "NH2", + "", + "", + "", + ], + "ASN": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "OD1", + "ND2", + "", + "", + "", + "", + "", + "", + ], + "ASP": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "OD1", + "OD2", + "", + "", + "", + "", + "", + "", + ], + "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""], + "GLN": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "OE1", + "NE2", + "", + "", + "", + "", + "", + ], + "GLU": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "OE1", + "OE2", + "", + "", + "", + "", + "", + ], + "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""], + "HIS": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "ND1", + "CD2", + "CE1", + "NE2", + "", + "", + "", + "", + ], + "ILE": [ + "N", + "CA", + "C", + "O", + "CB", + "CG1", + "CG2", + "CD1", + "", + "", + "", + "", + "", + "", + ], + "LEU": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "", + "", + "", + "", + "", + "", + ], + "LYS": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "CE", + "NZ", + "", + "", + "", + "", + "", + ], + "MET": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "SD", + "CE", + "", + "", + "", + "", + "", + "", + ], + "PHE": [ + "N", + 
"CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "", + "", + "", + ], + "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""], + "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""], + "THR": [ + "N", + "CA", + "C", + "O", + "CB", + "OG1", + "CG2", + "", + "", + "", + "", + "", + "", + "", + ], + "TRP": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "NE1", + "CE2", + "CE3", + "CZ2", + "CZ3", + "CH2", + ], + "TYR": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "OH", + "", + "", + ], + "VAL": [ + "N", + "CA", + "C", + "O", + "CB", + "CG1", + "CG2", + "", + "", + "", + "", + "", + "", + "", + ], + "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""], +} +# pylint: enable=line-too-long +# pylint: enable=bad-whitespace + + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes = [ + "A", + "R", + "N", + "D", + "C", + "Q", + "E", + "G", + "H", + "I", + "L", + "K", + "M", + "F", + "P", + "S", + "T", + "W", + "Y", + "V", +] +restype_order = {restype: i for i, restype in enumerate(restypes)} +restype_num = len(restypes) # := 20. +unk_restype_index = restype_num # Catch-all index for unknown restypes. + +restypes_with_x = restypes + ["X"] +restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)} + + +def sequence_to_onehot( + sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False +) -> np.ndarray: + """Maps the given sequence into a one-hot encoded matrix. + + Args: + sequence: An amino acid sequence. + mapping: A dictionary mapping amino acids to integers. + map_unknown_to_x: If True, any amino acid that is not in the mapping will be + mapped to the unknown amino acid 'X'. If the mapping doesn't contain + amino acid 'X', an error will be thrown. 
If False, any amino acid not in + the mapping will throw an error. + + Returns: + A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of + the sequence. + + Raises: + ValueError: If the mapping doesn't contain values from 0 to + num_unique_aas - 1 without any gaps. + """ + num_entries = max(mapping.values()) + 1 + + if sorted(set(mapping.values())) != list(range(num_entries)): + raise ValueError( + "The mapping must have values from 0 to num_unique_aas-1 " + "without any gaps. Got: %s" % sorted(mapping.values()) + ) + + one_hot_arr = np.zeros((len(sequence), num_entries), dtype=int) + + for aa_index, aa_type in enumerate(sequence): + if map_unknown_to_x: + if aa_type.isalpha() and aa_type.isupper(): + aa_id = mapping.get(aa_type, mapping["X"]) + else: + raise ValueError( + f"Invalid character in the sequence: {aa_type}" + ) + else: + aa_id = mapping[aa_type] + one_hot_arr[aa_index, aa_id] = 1 + + return one_hot_arr + + +restype_1to3 = { + "A": "ALA", + "R": "ARG", + "N": "ASN", + "D": "ASP", + "C": "CYS", + "Q": "GLN", + "E": "GLU", + "G": "GLY", + "H": "HIS", + "I": "ILE", + "L": "LEU", + "K": "LYS", + "M": "MET", + "F": "PHE", + "P": "PRO", + "S": "SER", + "T": "THR", + "W": "TRP", + "Y": "TYR", + "V": "VAL", +} + + +# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple +# 1-to-1 mapping of 3 letter names to one letter names. The latter contains +# many more, and less common, three letter names as keys and maps many of these +# to the same one letter name (including 'X' and 'U' which we don't use here). +restype_3to1 = {v: k for k, v in restype_1to3.items()} + +# Define a restype name for all unknown residues. +unk_restype = "UNK" + +resnames = [restype_1to3[r] for r in restypes] + [unk_restype] +resname_to_idx = {resname: i for i, resname in enumerate(resnames)} + + +# The mapping here uses hhblits convention, so that B is mapped to D, J and O +# are mapped to X, U is mapped to C, and Z is mapped to E. 
Other than that the +# remaining 20 amino acids are kept in alphabetical order. +# There are 2 non-amino acid codes, X (representing any amino acid) and +# "-" representing a missing amino acid in an alignment. The id for these +# codes is put at the end (20 and 21) so that they can easily be ignored if +# desired. +HHBLITS_AA_TO_ID = { + "A": 0, + "B": 2, + "C": 1, + "D": 2, + "E": 3, + "F": 4, + "G": 5, + "H": 6, + "I": 7, + "J": 20, + "K": 8, + "L": 9, + "M": 10, + "N": 11, + "O": 20, + "P": 12, + "Q": 13, + "R": 14, + "S": 15, + "T": 16, + "U": 1, + "V": 17, + "W": 18, + "X": 20, + "Y": 19, + "Z": 3, + "-": 21, +} + +# Partial inversion of HHBLITS_AA_TO_ID. +ID_TO_HHBLITS_AA = { + 0: "A", + 1: "C", # Also U. + 2: "D", # Also B. + 3: "E", # Also Z. + 4: "F", + 5: "G", + 6: "H", + 7: "I", + 8: "K", + 9: "L", + 10: "M", + 11: "N", + 12: "P", + 13: "Q", + 14: "R", + 15: "S", + 16: "T", + 17: "V", + 18: "W", + 19: "Y", + 20: "X", # Includes J and O. + 21: "-", +} + +restypes_with_x_and_gap = restypes + ["X", "-"] +MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple( + restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) + for i in range(len(restypes_with_x_and_gap)) +) + + +def _make_standard_atom_mask() -> np.ndarray: + """Returns [num_res_types, num_atom_types] mask array.""" + # +1 to account for unknown (all 0s). + mask = np.zeros([restype_num + 1, atom_type_num], dtype=int) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + mask[restype, atom_type] = 1 + return mask + + +STANDARD_ATOM_MASK = _make_standard_atom_mask() + + +# A one hot representation for the first and second atoms defining the axis +# of rotation for each chi-angle in each residue. 
+def chi_angle_atom(atom_index: int) -> np.ndarray: + """Define chi-angle rigid groups via one-hot representations.""" + chi_angles_index = {} + one_hots = [] + + for k, v in chi_angles_atoms.items(): + indices = [atom_types.index(s[atom_index]) for s in v] + indices.extend([-1] * (4 - len(indices))) + chi_angles_index[k] = indices + + for r in restypes: + res3 = restype_1to3[r] + one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] + one_hots.append(one_hot) + + one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. + one_hot = np.stack(one_hots, axis=0) + one_hot = np.transpose(one_hot, [0, 2, 1]) + + return one_hot + + +chi_atom_1_one_hot = chi_angle_atom(1) +chi_atom_2_one_hot = chi_angle_atom(2) + +# An array like chi_angles_atoms but using indices rather than names. +chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes] +chi_angles_atom_indices = tree.map_structure( + lambda atom_name: atom_order[atom_name], chi_angles_atom_indices +) +chi_angles_atom_indices = np.array( + [ + chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) + for chi_atoms in chi_angles_atom_indices + ] +) + +# Mapping from (res_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. +chi_groups_for_atom = collections.defaultdict(list) +for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) +chi_groups_for_atom = dict(chi_groups_for_atom) + + +def _make_rigid_transformation_4x4(ex, ey, translation): + """Create a rigid 4x4 transformation matrix from two axes and transl.""" + # Normalize ex. 
+ ex_normalized = ex / np.linalg.norm(ex) + + # make ey perpendicular to ex + ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized + ey_normalized /= np.linalg.norm(ey_normalized) + + # compute ez as cross product + eznorm = np.cross(ex_normalized, ey_normalized) + m = np.stack( + [ex_normalized, ey_normalized, eznorm, translation] + ).transpose() + m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0) + return m + + +# create an array with (restype, atomtype) --> rigid_group_idx +# and an array with (restype, atomtype, coord) for the atom positions +# and compute affine transformation matrices (4,4) from one rigid group to the +# previous group +restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int) +restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) +restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) +restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int) +restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) +restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) +restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) + + +def _make_rigid_group_constants(): + """Fill the arrays above.""" + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for atomname, group_idx, atom_position in rigid_group_atom_positions[ + resname + ]: + atomtype = atom_order[atomname] + restype_atom37_to_rigid_group[restype, atomtype] = group_idx + restype_atom37_mask[restype, atomtype] = 1 + restype_atom37_rigid_group_positions[ + restype, atomtype, : + ] = atom_position + + atom14idx = restype_name_to_atom14_names[resname].index(atomname) + restype_atom14_to_rigid_group[restype, atom14idx] = group_idx + restype_atom14_mask[restype, atom14idx] = 1 + restype_atom14_rigid_group_positions[ + restype, atom14idx, : + ] = atom_position + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + 
atom_positions = { + name: np.array(pos) + for name, _, pos in rigid_group_atom_positions[resname] + } + + # backbone to backbone is the identity transform + restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) + + # pre-omega-frame to backbone (currently dummy identity matrix) + restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) + + # phi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions["N"] - atom_positions["CA"], + ey=np.array([1.0, 0.0, 0.0]), + translation=atom_positions["N"], + ) + restype_rigid_group_default_frame[restype, 2, :, :] = mat + + # psi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions["C"] - atom_positions["CA"], + ey=atom_positions["CA"] - atom_positions["N"], + translation=atom_positions["C"], + ) + restype_rigid_group_default_frame[restype, 3, :, :] = mat + + # chi1-frame to backbone + if chi_angles_mask[restype][0]: + base_atom_names = chi_angles_atoms[resname][0] + base_atom_positions = [ + atom_positions[name] for name in base_atom_names + ] + mat = _make_rigid_transformation_4x4( + ex=base_atom_positions[2] - base_atom_positions[1], + ey=base_atom_positions[0] - base_atom_positions[1], + translation=base_atom_positions[2], + ) + restype_rigid_group_default_frame[restype, 4, :, :] = mat + + # chi2-frame to chi1-frame + # chi3-frame to chi2-frame + # chi4-frame to chi3-frame + # luckily all rotation axes for the next frame start at (0,0,0) of the + # previous frame + for chi_idx in range(1, 4): + if chi_angles_mask[restype][chi_idx]: + axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = _make_rigid_transformation_4x4( + ex=axis_end_atom_position, + ey=np.array([-1.0, 0.0, 0.0]), + translation=axis_end_atom_position, + ) + restype_rigid_group_default_frame[ + restype, 4 + chi_idx, :, : + ] = mat + + +_make_rigid_group_constants() + + +def make_atom14_dists_bounds( + 
overlap_tolerance=1.5, bond_length_tolerance_factor=15 +): + """compute upper and lower bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_list = restype_name_to_atom14_names[resname] + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[ + restype, atom1_idx, atom2_idx + ] = lower + restype_atom14_bond_lower_bound[ + restype, atom2_idx, atom1_idx + ] = lower + restype_atom14_bond_upper_bound[ + restype, atom1_idx, atom2_idx + ] = upper + restype_atom14_bond_upper_bound[ + restype, atom2_idx, atom1_idx + ] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[ + restype, atom1_idx, atom2_idx + ] = lower + restype_atom14_bond_lower_bound[ + restype, atom2_idx, atom1_idx + ] = lower + restype_atom14_bond_upper_bound[ + restype, atom1_idx, atom2_idx + ] = upper + restype_atom14_bond_upper_bound[ + restype, atom2_idx, atom1_idx + ] = upper + restype_atom14_bond_stddev[restype, 
atom1_idx, atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev + return { + "lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14) + "upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14) + "stddev": restype_atom14_bond_stddev, # shape (21,14,14) + } + + +restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32) +restype_atom14_ambiguous_atoms_swap_idx = np.tile( + np.arange(14, dtype=int), (21, 1) +) + + +def _make_atom14_ambiguity_feats(): + for res, pairs in residue_atom_renaming_swaps.items(): + res_idx = restype_order[restype_3to1[res]] + for atom1, atom2 in pairs.items(): + atom1_idx = restype_name_to_atom14_names[res].index(atom1) + atom2_idx = restype_name_to_atom14_names[res].index(atom2) + restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1 + restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1 + restype_atom14_ambiguous_atoms_swap_idx[ + res_idx, atom1_idx + ] = atom2_idx + restype_atom14_ambiguous_atoms_swap_idx[ + res_idx, atom2_idx + ] = atom1_idx + + +_make_atom14_ambiguity_feats() + + +def aatype_to_str_sequence(aatype): + return ''.join([ + restypes_with_x[aatype[i]] + for i in range(len(aatype)) + ]) diff --git a/openfold/utils/__pycache__/checkpointing.cpython-310.pyc b/openfold/utils/__pycache__/checkpointing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1849c7538e102c16b4d5faff4c0a109a946f852 Binary files /dev/null and b/openfold/utils/__pycache__/checkpointing.cpython-310.pyc differ diff --git a/openfold/utils/__pycache__/exponential_moving_average.cpython-310.pyc b/openfold/utils/__pycache__/exponential_moving_average.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbc8d273987289da03ea78e161e8709db1e9c64c Binary files /dev/null and b/openfold/utils/__pycache__/exponential_moving_average.cpython-310.pyc differ diff --git a/openfold/utils/__pycache__/feats.cpython-310.pyc 
b/openfold/utils/__pycache__/feats.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bae498213af441eba021d71240ee3450b7a20aef Binary files /dev/null and b/openfold/utils/__pycache__/feats.cpython-310.pyc differ diff --git a/openfold/utils/__pycache__/loss.cpython-310.pyc b/openfold/utils/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dc3735c9e452c66675182977db67028923c192c Binary files /dev/null and b/openfold/utils/__pycache__/loss.cpython-310.pyc differ diff --git a/openfold/utils/__pycache__/precision_utils.cpython-310.pyc b/openfold/utils/__pycache__/precision_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e169848727ad015efaa0fdbbbc47122d6c96eefa Binary files /dev/null and b/openfold/utils/__pycache__/precision_utils.cpython-310.pyc differ diff --git a/openfold/utils/__pycache__/rigid_utils.cpython-310.pyc b/openfold/utils/__pycache__/rigid_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..119fe21dcd80406a3b5d234c28a37f56d6932cdb Binary files /dev/null and b/openfold/utils/__pycache__/rigid_utils.cpython-310.pyc differ diff --git a/openfold/utils/__pycache__/rigid_utils.cpython-38.pyc b/openfold/utils/__pycache__/rigid_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b043ca55f494fc84d30c696ac552219247160ed Binary files /dev/null and b/openfold/utils/__pycache__/rigid_utils.cpython-38.pyc differ diff --git a/openfold/utils/__pycache__/tensor_utils.cpython-310.pyc b/openfold/utils/__pycache__/tensor_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee386bc0b20fe83c201b59d11cb6da9618a33bd2 Binary files /dev/null and b/openfold/utils/__pycache__/tensor_utils.cpython-310.pyc differ diff --git a/openfold/utils/argparse.py b/openfold/utils/argparse.py new file mode 100644 index 
0000000000000000000000000000000000000000..d13f895ec418ef74af1c1ffd7ec227cdd691236f --- /dev/null +++ b/openfold/utils/argparse.py @@ -0,0 +1,30 @@ +from argparse import HelpFormatter +from operator import attrgetter + +class ArgparseAlphabetizer(HelpFormatter): + """ + Sorts the optional arguments of an argparse parser alphabetically + """ + + @staticmethod + def sort_actions(actions): + return sorted(actions, key=attrgetter("option_strings")) + + # Formats the help message + def add_arguments(self, actions): + actions = ArgparseAlphabetizer.sort_actions(actions) + super(ArgparseAlphabetizer, self).add_arguments(actions) + + # Formats the usage message + def add_usage(self, usage, actions, groups, prefix=None): + actions = ArgparseAlphabetizer.sort_actions(actions) + args = usage, actions, groups, prefix + super(ArgparseAlphabetizer, self).add_usage(*args) + + +def remove_arguments(parser, args): + for arg in args: + for action in parser._actions: + opts = vars(action)["option_strings"] + if(arg in opts): + parser._handle_conflict_resolve(None, [(arg, action)]) diff --git a/openfold/utils/callbacks.py b/openfold/utils/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..872dbdce98bcca08e58368481c05024986b4a71f --- /dev/null +++ b/openfold/utils/callbacks.py @@ -0,0 +1,14 @@ +from pytorch_lightning.utilities import rank_zero_info +from pytorch_lightning.callbacks.early_stopping import EarlyStopping + +class EarlyStoppingVerbose(EarlyStopping): + """ + The default EarlyStopping callback's verbose mode is too verbose. + This class outputs a message only when it's getting ready to stop. 
+ """ + def _evalute_stopping_criteria(self, *args): + should_stop, reason = super()._evalute_stopping_criteria(*args) + if(should_stop): + rank_zero_info(f"{reason}\n") + + return should_stop, reason diff --git a/openfold/utils/checkpointing.py b/openfold/utils/checkpointing.py new file mode 100644 index 0000000000000000000000000000000000000000..75a4455ae6b1581d76dab7bcf3adf590e9be15df --- /dev/null +++ b/openfold/utils/checkpointing.py @@ -0,0 +1,88 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import deepspeed +import torch +import torch.utils.checkpoint +from typing import Any, Tuple, List, Callable, Optional + + +BLOCK_ARG = Any +BLOCK_ARGS = List[BLOCK_ARG] + + +def get_checkpoint_fn(): + if(deepspeed.checkpointing.is_configured()): + checkpoint = deepspeed.checkpointing.checkpoint + else: + checkpoint = torch.utils.checkpoint.checkpoint + + return checkpoint + + +@torch.jit.ignore +def checkpoint_blocks( + blocks: List[Callable], + args: BLOCK_ARGS, + blocks_per_ckpt: Optional[int], +) -> BLOCK_ARGS: + """ + Chunk a list of blocks and run each chunk with activation + checkpointing. We define a "block" as a callable whose only inputs are + the outputs of the previous block. + + Implements Subsection 1.11.8 + + Args: + blocks: + List of blocks + args: + Tuple of arguments for the first block. + blocks_per_ckpt: + Size of each chunk. A higher value corresponds to fewer + checkpoints, and trades memory for speed. 
If None, no checkpointing + is performed. + Returns: + The output of the final block + """ + def wrap(a): + return (a,) if type(a) is not tuple else a + + def exec(b, a): + for block in b: + a = wrap(block(*a)) + return a + + def chunker(s, e): + def exec_sliced(*a): + return exec(blocks[s:e], a) + + return exec_sliced + + # Avoids mishaps when the blocks take just one argument + args = wrap(args) + + if blocks_per_ckpt is None: + return exec(blocks, args) + elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks): + raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)") + + checkpoint = get_checkpoint_fn() + + for s in range(0, len(blocks), blocks_per_ckpt): + e = s + blocks_per_ckpt + args = checkpoint(chunker(s, e), *args) + args = wrap(args) + + return args \ No newline at end of file diff --git a/openfold/utils/exponential_moving_average.py b/openfold/utils/exponential_moving_average.py new file mode 100644 index 0000000000000000000000000000000000000000..53649506a7af46c7530687c14c79c2cb8d43c089 --- /dev/null +++ b/openfold/utils/exponential_moving_average.py @@ -0,0 +1,70 @@ +from collections import OrderedDict +import copy +import torch +import torch.nn as nn + +from openfold.utils.tensor_utils import tensor_tree_map + + +class ExponentialMovingAverage: + """ + Maintains moving averages of parameters with exponential decay + + At each step, the stored copy `copy` of each parameter `param` is + updated as follows: + + `copy = decay * copy + (1 - decay) * param` + + where `decay` is an attribute of the ExponentialMovingAverage object. + """ + + def __init__(self, model: nn.Module, decay: float): + """ + Args: + model: + A torch.nn.Module whose parameters are to be tracked + decay: + A value (usually close to 1.) 
by which updates are + weighted as part of the above formula + """ + super(ExponentialMovingAverage, self).__init__() + + clone_param = lambda t: t.clone().detach() + self.params = tensor_tree_map(clone_param, model.state_dict()) + self.decay = decay + self.device = next(model.parameters()).device + + def to(self, device): + self.params = tensor_tree_map(lambda t: t.to(device), self.params) + self.device = device + + def _update_state_dict_(self, update, state_dict): + with torch.no_grad(): + for k, v in update.items(): + stored = state_dict[k] + if not isinstance(v, torch.Tensor): + self._update_state_dict_(v, stored) + else: + diff = stored - v + diff *= 1 - self.decay + stored -= diff + + def update(self, model: torch.nn.Module) -> None: + """ + Updates the stored parameters using the state dict of the provided + module. The module should have the same structure as that used to + initialize the ExponentialMovingAverage object. + """ + self._update_state_dict_(model.state_dict(), self.params) + + def load_state_dict(self, state_dict: OrderedDict) -> None: + self.params = state_dict["params"] + self.decay = state_dict["decay"] + + def state_dict(self) -> OrderedDict: + return OrderedDict( + { + "params": self.params, + "decay": self.decay, + } + ) diff --git a/openfold/utils/feats.py b/openfold/utils/feats.py new file mode 100644 index 0000000000000000000000000000000000000000..3be298ff4cfd578d446142826937b5643551db29 --- /dev/null +++ b/openfold/utils/feats.py @@ -0,0 +1,267 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch +import torch.nn as nn +from typing import Dict + +from openfold.np import protein +import openfold.np.residue_constants as rc +from openfold.utils.rigid_utils import Rotation, Rigid +from openfold.utils.tensor_utils import ( + batched_gather, + one_hot, + tree_map, + tensor_tree_map, +) + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): + is_gly = aatype == rc.restype_order["G"] + ca_idx = rc.atom_order["CA"] + cb_idx = rc.atom_order["CB"] + pseudo_beta = torch.where( + is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :], + ) + + if all_atom_masks is not None: + pseudo_beta_mask = torch.where( + is_gly, + all_atom_masks[..., ca_idx], + all_atom_masks[..., cb_idx], + ) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + + +def atom14_to_atom37(atom14, batch): + atom37_data = batched_gather( + atom14, + batch["residx_atom37_to_atom14"], + dim=-2, + no_batch_dims=len(atom14.shape[:-2]), + ) + + atom37_data = atom37_data * batch["atom37_atom_exists"][..., None] + + return atom37_data + + +def build_template_angle_feat(template_feats): + template_aatype = template_feats["template_aatype"] + torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"] + alt_torsion_angles_sin_cos = template_feats[ + "template_alt_torsion_angles_sin_cos" + ] + torsion_angles_mask = template_feats["template_torsion_angles_mask"] + template_angle_feat = torch.cat( + [ + 
nn.functional.one_hot(template_aatype, 22), + torsion_angles_sin_cos.reshape( + *torsion_angles_sin_cos.shape[:-2], 14 + ), + alt_torsion_angles_sin_cos.reshape( + *alt_torsion_angles_sin_cos.shape[:-2], 14 + ), + torsion_angles_mask, + ], + dim=-1, + ) + + return template_angle_feat + + +def build_template_pair_feat( + batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8 +): + template_mask = batch["template_pseudo_beta_mask"] + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + + # Compute distogram (this seems to differ slightly from Alg. 5) + tpb = batch["template_pseudo_beta"] + dgram = torch.sum( + (tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True + ) + lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2 + upper = torch.cat([lower[:-1], lower.new_tensor([inf])], dim=-1) + dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) + + to_concat = [dgram, template_mask_2d[..., None]] + + aatype_one_hot = nn.functional.one_hot( + batch["template_aatype"], + rc.restype_num + 2, + ) + + n_res = batch["template_aatype"].shape[-1] + to_concat.append( + aatype_one_hot[..., None, :, :].expand( + *aatype_one_hot.shape[:-2], n_res, -1, -1 + ) + ) + to_concat.append( + aatype_one_hot[..., None, :].expand( + *aatype_one_hot.shape[:-2], -1, n_res, -1 + ) + ) + + n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]] + rigids = Rigid.make_transform_from_reference( + n_xyz=batch["template_all_atom_positions"][..., n, :], + ca_xyz=batch["template_all_atom_positions"][..., ca, :], + c_xyz=batch["template_all_atom_positions"][..., c, :], + eps=eps, + ) + points = rigids.get_trans()[..., None, :, :] + rigid_vec = rigids[..., None].invert_apply(points) + + inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec ** 2, dim=-1)) + + t_aa_masks = batch["template_all_atom_mask"] + template_mask = ( + t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c] + ) + template_mask_2d = template_mask[..., None] 
* template_mask[..., None, :] + + inv_distance_scalar = inv_distance_scalar * template_mask_2d + unit_vector = rigid_vec * inv_distance_scalar[..., None] + to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1)) + to_concat.append(template_mask_2d[..., None]) + + act = torch.cat(to_concat, dim=-1) + act = act * template_mask_2d[..., None] + + return act + + +def build_extra_msa_feat(batch): + msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23) + msa_feat = [ + msa_1hot, + batch["extra_has_deletion"].unsqueeze(-1), + batch["extra_deletion_value"].unsqueeze(-1), + ] + return torch.cat(msa_feat, dim=-1) + + +def torsion_angles_to_frames( + r: Rigid, + alpha: torch.Tensor, + aatype: torch.Tensor, + rrgdf: torch.Tensor, +): + # [*, N, 8, 4, 4] + default_4x4 = rrgdf[aatype, ...] + + # [*, N, 8] transformations, i.e. + # One [*, N, 8, 3, 3] rotation matrix and + # One [*, N, 8, 3] translation matrix + default_r = r.from_tensor_4x4(default_4x4) + + bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2)) + bb_rot[..., 1] = 1 + + # [*, N, 8, 2] + alpha = torch.cat( + [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2 + ) + + # [*, N, 8, 3, 3] + # Produces rotation matrices of the form: + # [ + # [1, 0 , 0 ], + # [0, a_2,-a_1], + # [0, a_1, a_2] + # ] + # This follows the original code rather than the supplement, which uses + # different indices. 
+ + all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape) + all_rots[..., 0, 0] = 1 + all_rots[..., 1, 1] = alpha[..., 1] + all_rots[..., 1, 2] = -alpha[..., 0] + all_rots[..., 2, 1:] = alpha + + all_rots = Rigid(Rotation(rot_mats=all_rots), None) + + all_frames = default_r.compose(all_rots) + + chi2_frame_to_frame = all_frames[..., 5] + chi3_frame_to_frame = all_frames[..., 6] + chi4_frame_to_frame = all_frames[..., 7] + + chi1_frame_to_bb = all_frames[..., 4] + chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) + chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) + chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) + + all_frames_to_bb = Rigid.cat( + [ + all_frames[..., :5], + chi2_frame_to_bb.unsqueeze(-1), + chi3_frame_to_bb.unsqueeze(-1), + chi4_frame_to_bb.unsqueeze(-1), + ], + dim=-1, + ) + + all_frames_to_global = r[..., None].compose(all_frames_to_bb) + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + r: Rigid, + aatype: torch.Tensor, + default_frames, + group_idx, + atom_mask, + lit_positions, +): + # [*, N, 14, 4, 4] + default_4x4 = default_frames[aatype, ...] + + # [*, N, 14] + group_mask = group_idx[aatype, ...] + + # [*, N, 14, 8] + group_mask = nn.functional.one_hot( + group_mask, + num_classes=default_frames.shape[-3], + ) + + # [*, N, 14, 8] + t_atoms_to_global = r[..., None, :] * group_mask + + # [*, N, 14] + t_atoms_to_global = t_atoms_to_global.map_tensor_fn( + lambda x: torch.sum(x, dim=-1) + ) + + # [*, N, 14, 1] + atom_mask = atom_mask[aatype, ...].unsqueeze(-1) + + # [*, N, 14, 3] + lit_positions = lit_positions[aatype, ...] 
+ pred_positions = t_atoms_to_global.apply(lit_positions) + pred_positions = pred_positions * atom_mask + + return pred_positions diff --git a/openfold/utils/import_weights.py b/openfold/utils/import_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..4b1ebd6d6adc6d60e9dadbb04c677db013cb7f3a --- /dev/null +++ b/openfold/utils/import_weights.py @@ -0,0 +1,449 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from enum import Enum +from dataclasses import dataclass +from functools import partial +import numpy as np +import torch +from typing import Union, List + + +_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/" + + +# With Param, a poor man's enum with attributes (Rust-style) +class ParamType(Enum): + LinearWeight = partial( # hack: partial prevents fns from becoming methods + lambda w: w.transpose(-1, -2) + ) + LinearWeightMHA = partial( + lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2) + ) + LinearMHAOutputWeight = partial( + lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2) + ) + LinearBiasMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1)) + LinearWeightOPM = partial( + lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2) + ) + Other = partial(lambda w: w) + + def __init__(self, fn): + self.transformation = fn + + +@dataclass +class Param: + param: Union[torch.Tensor, List[torch.Tensor]] + param_type: ParamType = ParamType.Other + stacked: bool = False + + +def _process_translations_dict(d, top_layer=True): + flat = {} + for k, v in d.items(): + if type(v) == dict: + prefix = _NPZ_KEY_PREFIX if top_layer else "" + sub_flat = { + (prefix + "/".join([k, k_prime])): v_prime + for k_prime, v_prime in _process_translations_dict( + v, top_layer=False + ).items() + } + flat.update(sub_flat) + else: + k = "/" + k if not top_layer else k + flat[k] = v + + return flat + + +def stacked(param_dict_list, out=None): + """ + Args: + param_dict_list: + A list of (nested) Param dicts to stack. The structure of + each dict must be the identical (down to the ParamTypes of + "parallel" Params). There must be at least one dict + in the list. 
+ """ + if out is None: + out = {} + template = param_dict_list[0] + for k, _ in template.items(): + v = [d[k] for d in param_dict_list] + if type(v[0]) is dict: + out[k] = {} + stacked(v, out=out[k]) + elif type(v[0]) is Param: + stacked_param = Param( + param=[param.param for param in v], + param_type=v[0].param_type, + stacked=True, + ) + + out[k] = stacked_param + + return out + + +def assign(translation_dict, orig_weights): + for k, param in translation_dict.items(): + with torch.no_grad(): + weights = torch.as_tensor(orig_weights[k]) + ref, param_type = param.param, param.param_type + if param.stacked: + weights = torch.unbind(weights, 0) + else: + weights = [weights] + ref = [ref] + + try: + weights = list(map(param_type.transformation, weights)) + for p, w in zip(ref, weights): + p.copy_(w) + except: + print(k) + print(ref[0].shape) + print(weights[0].shape) + raise + + +def import_jax_weights_(model, npz_path, version="model_1"): + data = np.load(npz_path) + + ####################### + # Some templates + ####################### + + LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight)) + + LinearBias = lambda l: (Param(l)) + + LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA)) + + LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA)) + + LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM)) + + LinearParams = lambda l: { + "weights": LinearWeight(l.weight), + "bias": LinearBias(l.bias), + } + + LayerNormParams = lambda l: { + "scale": Param(l.weight), + "offset": Param(l.bias), + } + + AttentionParams = lambda att: { + "query_w": LinearWeightMHA(att.linear_q.weight), + "key_w": LinearWeightMHA(att.linear_k.weight), + "value_w": LinearWeightMHA(att.linear_v.weight), + "output_w": Param( + att.linear_o.weight, + param_type=ParamType.LinearMHAOutputWeight, + ), + "output_b": LinearBias(att.linear_o.bias), + } + + AttentionGatedParams = lambda att: dict( + 
**AttentionParams(att), + **{ + "gating_w": LinearWeightMHA(att.linear_g.weight), + "gating_b": LinearBiasMHA(att.linear_g.bias), + }, + ) + + GlobalAttentionParams = lambda att: dict( + AttentionGatedParams(att), + key_w=LinearWeight(att.linear_k.weight), + value_w=LinearWeight(att.linear_v.weight), + ) + + TriAttParams = lambda tri_att: { + "query_norm": LayerNormParams(tri_att.layer_norm), + "feat_2d_weights": LinearWeight(tri_att.linear.weight), + "attention": AttentionGatedParams(tri_att.mha), + } + + TriMulOutParams = lambda tri_mul: { + "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in), + "left_projection": LinearParams(tri_mul.linear_a_p), + "right_projection": LinearParams(tri_mul.linear_b_p), + "left_gate": LinearParams(tri_mul.linear_a_g), + "right_gate": LinearParams(tri_mul.linear_b_g), + "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out), + "output_projection": LinearParams(tri_mul.linear_z), + "gating_linear": LinearParams(tri_mul.linear_g), + } + + # see commit b88f8da on the Alphafold repo + # Alphafold swaps the pseudocode's a and b between the incoming/outcoming + # iterations of triangle multiplication, which is confusing and not + # reproduced in our implementation. 
+ TriMulInParams = lambda tri_mul: { + "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in), + "left_projection": LinearParams(tri_mul.linear_b_p), + "right_projection": LinearParams(tri_mul.linear_a_p), + "left_gate": LinearParams(tri_mul.linear_b_g), + "right_gate": LinearParams(tri_mul.linear_a_g), + "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out), + "output_projection": LinearParams(tri_mul.linear_z), + "gating_linear": LinearParams(tri_mul.linear_g), + } + + PairTransitionParams = lambda pt: { + "input_layer_norm": LayerNormParams(pt.layer_norm), + "transition1": LinearParams(pt.linear_1), + "transition2": LinearParams(pt.linear_2), + } + + MSAAttParams = lambda matt: { + "query_norm": LayerNormParams(matt.layer_norm_m), + "attention": AttentionGatedParams(matt.mha), + } + + MSAColAttParams = lambda matt: { + "query_norm": LayerNormParams(matt._msa_att.layer_norm_m), + "attention": AttentionGatedParams(matt._msa_att.mha), + } + + MSAGlobalAttParams = lambda matt: { + "query_norm": LayerNormParams(matt.layer_norm_m), + "attention": GlobalAttentionParams(matt.global_attention), + } + + MSAAttPairBiasParams = lambda matt: dict( + **MSAAttParams(matt), + **{ + "feat_2d_norm": LayerNormParams(matt.layer_norm_z), + "feat_2d_weights": LinearWeight(matt.linear_z.weight), + }, + ) + + IPAParams = lambda ipa: { + "q_scalar": LinearParams(ipa.linear_q), + "kv_scalar": LinearParams(ipa.linear_kv), + "q_point_local": LinearParams(ipa.linear_q_points), + "kv_point_local": LinearParams(ipa.linear_kv_points), + "trainable_point_weights": Param( + param=ipa.head_weights, param_type=ParamType.Other + ), + "attention_2d": LinearParams(ipa.linear_b), + "output_projection": LinearParams(ipa.linear_out), + } + + TemplatePairBlockParams = lambda b: { + "triangle_attention_starting_node": TriAttParams(b.tri_att_start), + "triangle_attention_ending_node": TriAttParams(b.tri_att_end), + "triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out), + 
"triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in), + "pair_transition": PairTransitionParams(b.pair_transition), + } + + MSATransitionParams = lambda m: { + "input_layer_norm": LayerNormParams(m.layer_norm), + "transition1": LinearParams(m.linear_1), + "transition2": LinearParams(m.linear_2), + } + + OuterProductMeanParams = lambda o: { + "layer_norm_input": LayerNormParams(o.layer_norm), + "left_projection": LinearParams(o.linear_1), + "right_projection": LinearParams(o.linear_2), + "output_w": LinearWeightOPM(o.linear_out.weight), + "output_b": LinearBias(o.linear_out.bias), + } + + def EvoformerBlockParams(b, is_extra_msa=False): + if is_extra_msa: + col_att_name = "msa_column_global_attention" + msa_col_att_params = MSAGlobalAttParams(b.msa_att_col) + else: + col_att_name = "msa_column_attention" + msa_col_att_params = MSAColAttParams(b.msa_att_col) + + d = { + "msa_row_attention_with_pair_bias": MSAAttPairBiasParams( + b.msa_att_row + ), + col_att_name: msa_col_att_params, + "msa_transition": MSATransitionParams(b.core.msa_transition), + "outer_product_mean": + OuterProductMeanParams(b.core.outer_product_mean), + "triangle_multiplication_outgoing": + TriMulOutParams(b.core.tri_mul_out), + "triangle_multiplication_incoming": + TriMulInParams(b.core.tri_mul_in), + "triangle_attention_starting_node": + TriAttParams(b.core.tri_att_start), + "triangle_attention_ending_node": + TriAttParams(b.core.tri_att_end), + "pair_transition": + PairTransitionParams(b.core.pair_transition), + } + + return d + + ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True) + + FoldIterationParams = lambda sm: { + "invariant_point_attention": IPAParams(sm.ipa), + "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa), + "transition": LinearParams(sm.transition.layers[0].linear_1), + "transition_1": LinearParams(sm.transition.layers[0].linear_2), + "transition_2": LinearParams(sm.transition.layers[0].linear_3), + "transition_layer_norm": 
LayerNormParams(sm.transition.layer_norm), + "affine_update": LinearParams(sm.bb_update.linear), + "rigid_sidechain": { + "input_projection": LinearParams(sm.angle_resnet.linear_in), + "input_projection_1": LinearParams(sm.angle_resnet.linear_initial), + "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1), + "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2), + "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1), + "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2), + "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out), + }, + } + + ############################ + # translations dict overflow + ############################ + + tps_blocks = model.template_pair_stack.blocks + tps_blocks_params = stacked( + [TemplatePairBlockParams(b) for b in tps_blocks] + ) + + ems_blocks = model.extra_msa_stack.blocks + ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks]) + + evo_blocks = model.evoformer.blocks + evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks]) + + translations = { + "evoformer": { + "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m), + "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m), + "left_single": LinearParams(model.input_embedder.linear_tf_z_i), + "right_single": LinearParams(model.input_embedder.linear_tf_z_j), + "prev_pos_linear": LinearParams(model.recycling_embedder.linear), + "prev_msa_first_row_norm": LayerNormParams( + model.recycling_embedder.layer_norm_m + ), + "prev_pair_norm": LayerNormParams( + model.recycling_embedder.layer_norm_z + ), + "pair_activiations": LinearParams( + model.input_embedder.linear_relpos + ), + "template_embedding": { + "single_template_embedding": { + "embedding2d": LinearParams( + model.template_pair_embedder.linear + ), + "template_pair_stack": { + "__layer_stack_no_state": tps_blocks_params, + }, + "output_layer_norm": LayerNormParams( + model.template_pair_stack.layer_norm + ), + }, + 
"attention": AttentionParams(model.template_pointwise_att.mha), + }, + "extra_msa_activations": LinearParams( + model.extra_msa_embedder.linear + ), + "extra_msa_stack": ems_blocks_params, + "template_single_embedding": LinearParams( + model.template_angle_embedder.linear_1 + ), + "template_projection": LinearParams( + model.template_angle_embedder.linear_2 + ), + "evoformer_iteration": evo_blocks_params, + "single_activations": LinearParams(model.evoformer.linear), + }, + "structure_module": { + "single_layer_norm": LayerNormParams( + model.structure_module.layer_norm_s + ), + "initial_projection": LinearParams( + model.structure_module.linear_in + ), + "pair_layer_norm": LayerNormParams( + model.structure_module.layer_norm_z + ), + "fold_iteration": FoldIterationParams(model.structure_module), + }, + "predicted_lddt_head": { + "input_layer_norm": LayerNormParams( + model.aux_heads.plddt.layer_norm + ), + "act_0": LinearParams(model.aux_heads.plddt.linear_1), + "act_1": LinearParams(model.aux_heads.plddt.linear_2), + "logits": LinearParams(model.aux_heads.plddt.linear_3), + }, + "distogram_head": { + "half_logits": LinearParams(model.aux_heads.distogram.linear), + }, + "experimentally_resolved_head": { + "logits": LinearParams( + model.aux_heads.experimentally_resolved.linear + ), + }, + "masked_msa_head": { + "logits": LinearParams(model.aux_heads.masked_msa.linear), + }, + } + + no_templ = [ + "model_3", + "model_4", + "model_5", + "model_3_ptm", + "model_4_ptm", + "model_5_ptm", + ] + if version in no_templ: + evo_dict = translations["evoformer"] + keys = list(evo_dict.keys()) + for k in keys: + if "template_" in k: + evo_dict.pop(k) + + if "_ptm" in version: + translations["predicted_aligned_error_head"] = { + "logits": LinearParams(model.aux_heads.tm.linear) + } + + # Flatten keys and insert missing key prefixes + flat = _process_translations_dict(translations) + + # Sanity check + keys = list(data.keys()) + flat_keys = list(flat.keys()) + incorrect = [k for 
k in flat_keys if k not in keys] + missing = [k for k in keys if k not in flat_keys] + # print(f"Incorrect: {incorrect}") + # print(f"Missing: {missing}") + + assert len(incorrect) == 0 + # assert(sorted(list(flat.keys())) == sorted(list(data.keys()))) + + # Set weights + assign(flat, data) diff --git a/openfold/utils/logger.py b/openfold/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e2223e3020765b8d847c269285a5f5de248b94 --- /dev/null +++ b/openfold/utils/logger.py @@ -0,0 +1,81 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import operator +import time + +import dllogger as logger +import numpy as np +import torch.cuda.profiler as profiler +from dllogger import JSONStreamBackend, StdOutBackend, Verbosity +from pytorch_lightning import Callback + + +def is_main_process(): + return int(os.getenv("LOCAL_RANK", "0")) == 0 + + +class PerformanceLoggingCallback(Callback): + def __init__(self, log_file, global_batch_size, warmup_steps: int = 0, profile: bool = False): + logger.init(backends=[JSONStreamBackend(Verbosity.VERBOSE, log_file), StdOutBackend(Verbosity.VERBOSE)]) + self.warmup_steps = warmup_steps + self.global_batch_size = global_batch_size + self.step = 0 + self.profile = profile + self.timestamps = [] + + def do_step(self): + self.step += 1 + if self.profile and self.step == self.warmup_steps: + profiler.start() + if self.step > self.warmup_steps: + self.timestamps.append(time.time()) + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + self.do_step() + + def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + self.do_step() + + def process_performance_stats(self, deltas): + def _round3(val): + return round(val, 3) + + throughput_imgps = _round3(self.global_batch_size / np.mean(deltas)) + timestamps_ms = 1000 * deltas + stats = { + f"throughput": throughput_imgps, + f"latency_mean": _round3(timestamps_ms.mean()), + } + for level in [90, 95, 99]: + stats.update({f"latency_{level}": _round3(np.percentile(timestamps_ms, level))}) + + return stats + + def _log(self): + if is_main_process(): + diffs = list(map(operator.sub, self.timestamps[1:], self.timestamps[:-1])) + deltas = np.array(diffs) + stats = self.process_performance_stats(deltas) + logger.log(step=(), data=stats) + logger.flush() + + def on_train_end(self, trainer, pl_module): + if self.profile: + profiler.stop() + self._log() + + def on_epoch_end(self, trainer, pl_module): + self._log() diff --git a/openfold/utils/loss.py 
b/openfold/utils/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..50809af9e7c3564e7104bfb298f50fc64f8dd597 --- /dev/null +++ b/openfold/utils/loss.py @@ -0,0 +1,1637 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +import logging +import ml_collections +import numpy as np +import torch +import torch.nn as nn +from torch.distributions.bernoulli import Bernoulli +from typing import Dict, Optional, Tuple + +from openfold.np import residue_constants +from openfold.utils import feats +from openfold.utils.rigid_utils import Rotation, Rigid +from openfold.utils.tensor_utils import ( + tree_map, + tensor_tree_map, + masked_mean, + permute_final_dims, + batched_gather, +) + + +def softmax_cross_entropy(logits, labels): + loss = -1 * torch.sum( + labels * torch.nn.functional.log_softmax(logits, dim=-1), + dim=-1, + ) + return loss + + +def sigmoid_cross_entropy(logits, labels): + log_p = torch.log(torch.sigmoid(logits)) + log_not_p = torch.log(torch.sigmoid(-logits)) + loss = -labels * log_p - (1 - labels) * log_not_p + return loss + + +def torsion_angle_loss( + a, # [*, N, 7, 2] + a_gt, # [*, N, 7, 2] + a_alt_gt, # [*, N, 7, 2] +): + # [*, N, 7] + norm = torch.norm(a, dim=-1) + + # [*, N, 7, 2] + a = a / norm.unsqueeze(-1) + + # [*, N, 7] + diff_norm_gt = torch.norm(a - a_gt, dim=-1) + diff_norm_alt_gt = torch.norm(a - a_alt_gt, 
dim=-1) + min_diff = torch.minimum(diff_norm_gt ** 2, diff_norm_alt_gt ** 2) + + # [*] + l_torsion = torch.mean(min_diff, dim=(-1, -2)) + l_angle_norm = torch.mean(torch.abs(norm - 1), dim=(-1, -2)) + + an_weight = 0.02 + return l_torsion + an_weight * l_angle_norm + + +def compute_fape( + pred_frames: Rigid, + target_frames: Rigid, + frames_mask: torch.Tensor, + pred_positions: torch.Tensor, + target_positions: torch.Tensor, + positions_mask: torch.Tensor, + length_scale: float, + l1_clamp_distance: Optional[float] = None, + eps=1e-8, + ignore_nan=True, +) -> torch.Tensor: + """ + Computes FAPE loss. + + Args: + pred_frames: + [*, N_frames] Rigid object of predicted frames + target_frames: + [*, N_frames] Rigid object of ground truth frames + frames_mask: + [*, N_frames] binary mask for the frames + pred_positions: + [*, N_pts, 3] predicted atom positions + target_positions: + [*, N_pts, 3] ground truth positions + positions_mask: + [*, N_pts] positions mask + length_scale: + Length scale by which the loss is divided + l1_clamp_distance: + Cutoff above which distance errors are disregarded + eps: + Small value used to regularize denominators + Returns: + [*] loss tensor + """ + # [*, N_frames, N_pts, 3] + local_pred_pos = pred_frames.invert()[..., None].apply( + pred_positions[..., None, :, :], + ) + local_target_pos = target_frames.invert()[..., None].apply( + target_positions[..., None, :, :], + ) + + error_dist = torch.sqrt( + torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps + ) + + if l1_clamp_distance is not None: + error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance) + + normed_error = error_dist / length_scale + normed_error = normed_error * frames_mask[..., None] + normed_error = normed_error * positions_mask[..., None, :] + if ignore_nan: + normed_error = torch.nan_to_num(normed_error) + + # FP16-friendly averaging. 
Roughly equivalent to: + # + # norm_factor = ( + # torch.sum(frames_mask, dim=-1) * + # torch.sum(positions_mask, dim=-1) + # ) + # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor) + # + # ("roughly" because eps is necessarily duplicated in the latter) + normed_error = torch.sum(normed_error, dim=-1) + normed_error = ( + normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None] + ) + normed_error = torch.sum(normed_error, dim=-1) + normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1)) + return normed_error + + +def backbone_loss( + backbone_rigid_tensor: torch.Tensor, + backbone_rigid_mask: torch.Tensor, + traj: torch.Tensor, + use_clamped_fape: Optional[torch.Tensor] = None, + clamp_distance: float = 10.0, + loss_unit_distance: float = 10.0, + eps: float = 1e-4, + **kwargs, +) -> torch.Tensor: + pred_aff = Rigid.from_tensor_7(traj) + pred_aff = Rigid( + Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None), + pred_aff.get_trans(), + ) + + # DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of + # backbone tensor, normalizes it, and then turns it back to a rotation + # matrix. To avoid a potentially numerically unstable rotation matrix + # to quaternion conversion, we just use the original rotation matrix + # outright. This one hasn't been composed a bunch of times, though, so + # it might be fine. 
+ gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor) + + fape_loss = compute_fape( + pred_aff, + gt_aff[None], + backbone_rigid_mask[None], + pred_aff.get_trans(), + gt_aff[None].get_trans(), + backbone_rigid_mask[None], + l1_clamp_distance=clamp_distance, + length_scale=loss_unit_distance, + eps=eps, + ) + if use_clamped_fape is not None: + unclamped_fape_loss = compute_fape( + pred_aff, + gt_aff[None], + backbone_rigid_mask[None], + pred_aff.get_trans(), + gt_aff[None].get_trans(), + backbone_rigid_mask[None], + l1_clamp_distance=None, + length_scale=loss_unit_distance, + eps=eps, + ) + + fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * ( + 1 - use_clamped_fape + ) + + # Average over the batch dimension + fape_loss = torch.mean(fape_loss) + + return fape_loss + + +def sidechain_loss( + sidechain_frames: torch.Tensor, + sidechain_atom_pos: torch.Tensor, + rigidgroups_gt_frames: torch.Tensor, + rigidgroups_alt_gt_frames: torch.Tensor, + rigidgroups_gt_exists: torch.Tensor, + renamed_atom14_gt_positions: torch.Tensor, + renamed_atom14_gt_exists: torch.Tensor, + alt_naming_is_better: torch.Tensor, + clamp_distance: float = 10.0, + length_scale: float = 10.0, + eps: float = 1e-4, + **kwargs, +) -> torch.Tensor: + renamed_gt_frames = ( + 1.0 - alt_naming_is_better[..., None, None, None] + ) * rigidgroups_gt_frames + alt_naming_is_better[ + ..., None, None, None + ] * rigidgroups_alt_gt_frames + + # Steamroll the inputs + sidechain_frames = sidechain_frames[-1] + batch_dims = sidechain_frames.shape[:-4] + sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4) + sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames) + renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4) + renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames) + rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1) + sidechain_atom_pos = sidechain_atom_pos[-1] + sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3) + 
renamed_atom14_gt_positions = renamed_atom14_gt_positions.view( + *batch_dims, -1, 3 + ) + renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1) + + fape = compute_fape( + sidechain_frames, + renamed_gt_frames, + rigidgroups_gt_exists, + sidechain_atom_pos, + renamed_atom14_gt_positions, + renamed_atom14_gt_exists, + l1_clamp_distance=clamp_distance, + length_scale=length_scale, + eps=eps, + ) + + return fape + + +def fape_loss( + out: Dict[str, torch.Tensor], + batch: Dict[str, torch.Tensor], + config: ml_collections.ConfigDict, +) -> torch.Tensor: + bb_loss = backbone_loss( + traj=out["sm"]["frames"], + **{**batch, **config.backbone}, + ) + + sc_loss = sidechain_loss( + out["sm"]["sidechain_frames"], + out["sm"]["positions"], + **{**batch, **config.sidechain}, + ) + + loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss + + # Average over the batch dimension + loss = torch.mean(loss) + + return loss + + +def supervised_chi_loss( + angles_sin_cos: torch.Tensor, + unnormalized_angles_sin_cos: torch.Tensor, + aatype: torch.Tensor, + seq_mask: torch.Tensor, + chi_mask: torch.Tensor, + chi_angles_sin_cos: torch.Tensor, + chi_weight: float, + angle_norm_weight: float, + eps=1e-6, + **kwargs, +) -> torch.Tensor: + """ + Implements Algorithm 27 (torsionAngleLoss) + + Args: + angles_sin_cos: + [*, N, 7, 2] predicted angles + unnormalized_angles_sin_cos: + The same angles, but unnormalized + aatype: + [*, N] residue indices + seq_mask: + [*, N] sequence mask + chi_mask: + [*, N, 7] angle mask + chi_angles_sin_cos: + [*, N, 7, 2] ground truth angles + chi_weight: + Weight for the angle component of the loss + angle_norm_weight: + Weight for the normalization component of the loss + Returns: + [*] loss tensor + """ + pred_angles = angles_sin_cos[..., 3:, :] + residue_type_one_hot = torch.nn.functional.one_hot( + aatype, + residue_constants.restype_num + 1, + ) + chi_pi_periodic = torch.einsum( + "...ij,jk->ik", + 
residue_type_one_hot.type(angles_sin_cos.dtype), + angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic), + ) + + true_chi = chi_angles_sin_cos[None] + + shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1) + true_chi_shifted = shifted_mask * true_chi + sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1) + sq_chi_error_shifted = torch.sum( + (true_chi_shifted - pred_angles) ** 2, dim=-1 + ) + sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted) + # The ol' switcheroo + sq_chi_error = sq_chi_error.permute( + *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1 + ) + sq_chi_loss = masked_mean( + chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3) + ) + + loss = chi_weight * sq_chi_loss + + angle_norm = torch.sqrt( + torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps + ) + norm_error = torch.abs(angle_norm - 1.0) + norm_error = norm_error.permute( + *range(len(norm_error.shape))[1:-2], 0, -2, -1 + ) + angle_norm_loss = masked_mean( + seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3) + ) + + loss = loss + angle_norm_weight * angle_norm_loss + + # Average over the batch dimension + loss = torch.mean(loss) + + return loss + + +def compute_plddt(logits: torch.Tensor) -> torch.Tensor: + num_bins = logits.shape[-1] + bin_width = 1.0 / num_bins + bounds = torch.arange( + start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device + ) + probs = torch.nn.functional.softmax(logits, dim=-1) + pred_lddt_ca = torch.sum( + probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape), + dim=-1, + ) + return pred_lddt_ca * 100 + + +def lddt( + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + cutoff: float = 15.0, + eps: float = 1e-10, + per_residue: bool = True, +) -> torch.Tensor: + n = all_atom_mask.shape[-2] + dmat_true = torch.sqrt( + eps + + torch.sum( + ( + all_atom_positions[..., None, :] + - all_atom_positions[..., None, :, :] + ) + ** 2, + dim=-1, + ) + ) + 
+ dmat_pred = torch.sqrt( + eps + + torch.sum( + ( + all_atom_pred_pos[..., None, :] + - all_atom_pred_pos[..., None, :, :] + ) + ** 2, + dim=-1, + ) + ) + dists_to_score = ( + (dmat_true < cutoff) + * all_atom_mask + * permute_final_dims(all_atom_mask, (1, 0)) + * (1.0 - torch.eye(n, device=all_atom_mask.device)) + ) + + dist_l1 = torch.abs(dmat_true - dmat_pred) + + score = ( + (dist_l1 < 0.5).type(dist_l1.dtype) + + (dist_l1 < 1.0).type(dist_l1.dtype) + + (dist_l1 < 2.0).type(dist_l1.dtype) + + (dist_l1 < 4.0).type(dist_l1.dtype) + ) + score = score * 0.25 + + dims = (-1,) if per_residue else (-2, -1) + norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims)) + score = norm * (eps + torch.sum(dists_to_score * score, dim=dims)) + + return score + + +def lddt_ca( + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + cutoff: float = 15.0, + eps: float = 1e-10, + per_residue: bool = True, +) -> torch.Tensor: + ca_pos = residue_constants.atom_order["CA"] + all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] + all_atom_positions = all_atom_positions[..., ca_pos, :] + all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim + + return lddt( + all_atom_pred_pos, + all_atom_positions, + all_atom_mask, + cutoff=cutoff, + eps=eps, + per_residue=per_residue, + ) + + +def lddt_loss( + logits: torch.Tensor, + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + resolution: torch.Tensor, + cutoff: float = 15.0, + no_bins: int = 50, + min_resolution: float = 0.1, + max_resolution: float = 3.0, + eps: float = 1e-10, + **kwargs, +) -> torch.Tensor: + n = all_atom_mask.shape[-2] + + ca_pos = residue_constants.atom_order["CA"] + all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] + all_atom_positions = all_atom_positions[..., ca_pos, :] + all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim + + score = lddt( + all_atom_pred_pos, + all_atom_positions, + 
all_atom_mask, + cutoff=cutoff, + eps=eps + ) + + score = score.detach() + + bin_index = torch.floor(score * no_bins).long() + bin_index = torch.clamp(bin_index, max=(no_bins - 1)) + lddt_ca_one_hot = torch.nn.functional.one_hot( + bin_index, num_classes=no_bins + ) + + errors = softmax_cross_entropy(logits, lddt_ca_one_hot) + all_atom_mask = all_atom_mask.squeeze(-1) + loss = torch.sum(errors * all_atom_mask, dim=-1) / ( + eps + torch.sum(all_atom_mask, dim=-1) + ) + + loss = loss * ( + (resolution >= min_resolution) & (resolution <= max_resolution) + ) + + # Average over the batch dimension + loss = torch.mean(loss) + + return loss + + +def distogram_loss( + logits, + pseudo_beta, + pseudo_beta_mask, + min_bin=2.3125, + max_bin=21.6875, + no_bins=64, + eps=1e-6, + **kwargs, +): + boundaries = torch.linspace( + min_bin, + max_bin, + no_bins - 1, + device=logits.device, + ) + boundaries = boundaries ** 2 + + dists = torch.sum( + (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2, + dim=-1, + keepdims=True, + ) + + true_bins = torch.sum(dists > boundaries, dim=-1) + + errors = softmax_cross_entropy( + logits, + torch.nn.functional.one_hot(true_bins, no_bins), + ) + + square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :] + + # FP16-friendly sum. 
Equivalent to: + # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) / + # (eps + torch.sum(square_mask, dim=(-1, -2)))) + denom = eps + torch.sum(square_mask, dim=(-1, -2)) + mean = errors * square_mask + mean = torch.sum(mean, dim=-1) + mean = mean / denom[..., None] + mean = torch.sum(mean, dim=-1) + + # Average over the batch dimensions + mean = torch.mean(mean) + + return mean + + +def _calculate_bin_centers(boundaries: torch.Tensor): + step = boundaries[1] - boundaries[0] + bin_centers = boundaries + step / 2 + bin_centers = torch.cat( + [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0 + ) + return bin_centers + + +def _calculate_expected_aligned_error( + alignment_confidence_breaks: torch.Tensor, + aligned_distance_error_probs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + bin_centers = _calculate_bin_centers(alignment_confidence_breaks) + return ( + torch.sum(aligned_distance_error_probs * bin_centers, dim=-1), + bin_centers[-1], + ) + + +def compute_predicted_aligned_error( + logits: torch.Tensor, + max_bin: int = 31, + no_bins: int = 64, + **kwargs, +) -> Dict[str, torch.Tensor]: + """Computes aligned confidence metrics from logits. + + Args: + logits: [*, num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + max_bin: Maximum bin value + no_bins: Number of bins + Returns: + aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted + aligned error probabilities over bins for each residue pair. + predicted_aligned_error: [*, num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: [*] the maximum predicted error possible. 
+ """ + boundaries = torch.linspace( + 0, max_bin, steps=(no_bins - 1), device=logits.device + ) + + aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1) + ( + predicted_aligned_error, + max_predicted_aligned_error, + ) = _calculate_expected_aligned_error( + alignment_confidence_breaks=boundaries, + aligned_distance_error_probs=aligned_confidence_probs, + ) + + return { + "aligned_confidence_probs": aligned_confidence_probs, + "predicted_aligned_error": predicted_aligned_error, + "max_predicted_aligned_error": max_predicted_aligned_error, + } + + +def compute_tm( + logits: torch.Tensor, + residue_weights: Optional[torch.Tensor] = None, + max_bin: int = 31, + no_bins: int = 64, + eps: float = 1e-8, + **kwargs, +) -> torch.Tensor: + if residue_weights is None: + residue_weights = logits.new_ones(logits.shape[-2]) + + boundaries = torch.linspace( + 0, max_bin, steps=(no_bins - 1), device=logits.device + ) + + bin_centers = _calculate_bin_centers(boundaries) + torch.sum(residue_weights) + n = logits.shape[-2] + clipped_n = max(n, 19) + + d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8 + + probs = torch.nn.functional.softmax(logits, dim=-1) + + tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2)) + predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1) + + normed_residue_mask = residue_weights / (eps + residue_weights.sum()) + per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1) + weighted = per_alignment * residue_weights + argmax = (weighted == torch.max(weighted)).nonzero()[0] + return per_alignment[tuple(argmax)] + + +def tm_loss( + logits, + final_affine_tensor, + backbone_rigid_tensor, + backbone_rigid_mask, + resolution, + max_bin=31, + no_bins=64, + min_resolution: float = 0.1, + max_resolution: float = 3.0, + eps=1e-8, + **kwargs, +): + pred_affine = Rigid.from_tensor_7(final_affine_tensor) + backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor) + + def _points(affine): + pts = affine.get_trans()[..., None, 
:, :] + return affine.invert()[..., None].apply(pts) + + sq_diff = torch.sum( + (_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1 + ) + + sq_diff = sq_diff.detach() + + boundaries = torch.linspace( + 0, max_bin, steps=(no_bins - 1), device=logits.device + ) + boundaries = boundaries ** 2 + true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1) + + errors = softmax_cross_entropy( + logits, torch.nn.functional.one_hot(true_bins, no_bins) + ) + + square_mask = ( + backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :] + ) + + loss = torch.sum(errors * square_mask, dim=-1) + scale = 0.5 # hack to help FP16 training along + denom = eps + torch.sum(scale * square_mask, dim=(-1, -2)) + loss = loss / denom[..., None] + loss = torch.sum(loss, dim=-1) + loss = loss * scale + + loss = loss * ( + (resolution >= min_resolution) & (resolution <= max_resolution) + ) + + # Average over the loss dimension + loss = torch.mean(loss) + + return loss + + +def between_residue_bond_loss( + pred_atom_positions: torch.Tensor, # (*, N, 37/14, 3) + pred_atom_mask: torch.Tensor, # (*, N, 37/14) + residue_index: torch.Tensor, # (*, N) + aatype: torch.Tensor, # (*, N) + tolerance_factor_soft=12.0, + tolerance_factor_hard=12.0, + eps=1e-6, +) -> Dict[str, torch.Tensor]: + """Flat-bottom loss to penalize structural violations between residues. + + This is a loss penalizing any violation of the geometry around the peptide + bond between consecutive amino acids. This loss corresponds to + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45. + + Args: + pred_atom_positions: Atom positions in atom37/14 representation + pred_atom_mask: Atom mask in atom37/14 representation + residue_index: Residue index for given amino acid, this is assumed to be + monotonically increasing. 
+ aatype: Amino acid type of given residue + tolerance_factor_soft: soft tolerance factor measured in standard deviations + of pdb distributions + tolerance_factor_hard: hard tolerance factor measured in standard deviations + of pdb distributions + + Returns: + Dict containing: + * 'c_n_loss_mean': Loss for peptide bond length violations + * 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned + by CA, C, N + * 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned + by C, N, CA + * 'per_residue_loss_sum': sum of all losses for each residue + * 'per_residue_violation_mask': mask denoting all residues with violation + present. + """ + # Get the positions of the relevant backbone atoms. + this_ca_pos = pred_atom_positions[..., :-1, 1, :] + this_ca_mask = pred_atom_mask[..., :-1, 1] + this_c_pos = pred_atom_positions[..., :-1, 2, :] + this_c_mask = pred_atom_mask[..., :-1, 2] + next_n_pos = pred_atom_positions[..., 1:, 0, :] + next_n_mask = pred_atom_mask[..., 1:, 0] + next_ca_pos = pred_atom_positions[..., 1:, 1, :] + next_ca_mask = pred_atom_mask[..., 1:, 1] + has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0 + + # Compute loss for the C--N bond. + c_n_bond_length = torch.sqrt( + eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1) + ) + + # The C-N bond to proline has slightly different length because of the ring. 
+ next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"] + gt_length = ( + ~next_is_proline + ) * residue_constants.between_res_bond_length_c_n[ + 0 + ] + next_is_proline * residue_constants.between_res_bond_length_c_n[ + 1 + ] + gt_stddev = ( + ~next_is_proline + ) * residue_constants.between_res_bond_length_stddev_c_n[ + 0 + ] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[ + 1 + ] + c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2) + c_n_loss_per_residue = torch.nn.functional.relu( + c_n_bond_length_error - tolerance_factor_soft * gt_stddev + ) + mask = this_c_mask * next_n_mask * has_no_gap_mask + c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / ( + torch.sum(mask, dim=-1) + eps + ) + c_n_violation_mask = mask * ( + c_n_bond_length_error > (tolerance_factor_hard * gt_stddev) + ) + + # Compute loss for the angles. + ca_c_bond_length = torch.sqrt( + eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1) + ) + n_ca_bond_length = torch.sqrt( + eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1) + ) + + c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None] + c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None] + n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None] + + ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1) + gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0] + gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0] + ca_c_n_cos_angle_error = torch.sqrt( + eps + (ca_c_n_cos_angle - gt_angle) ** 2 + ) + ca_c_n_loss_per_residue = torch.nn.functional.relu( + ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev + ) + mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask + ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / ( + torch.sum(mask, dim=-1) + eps + ) + ca_c_n_violation_mask = mask * ( + ca_c_n_cos_angle_error > (tolerance_factor_hard 
* gt_stddev) + ) + + c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1) + gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0] + gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1] + c_n_ca_cos_angle_error = torch.sqrt( + eps + torch.square(c_n_ca_cos_angle - gt_angle) + ) + c_n_ca_loss_per_residue = torch.nn.functional.relu( + c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev + ) + mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask + c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / ( + torch.sum(mask, dim=-1) + eps + ) + c_n_ca_violation_mask = mask * ( + c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev) + ) + + # Compute a per residue loss (equally distribute the loss to both + # neighbouring residues). + per_residue_loss_sum = ( + c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue + ) + per_residue_loss_sum = 0.5 * ( + torch.nn.functional.pad(per_residue_loss_sum, (0, 1)) + + torch.nn.functional.pad(per_residue_loss_sum, (1, 0)) + ) + + # Compute hard violations. + violation_mask = torch.max( + torch.stack( + [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask], + dim=-2, + ), + dim=-2, + )[0] + violation_mask = torch.maximum( + torch.nn.functional.pad(violation_mask, (0, 1)), + torch.nn.functional.pad(violation_mask, (1, 0)), + ) + + return { + "c_n_loss_mean": c_n_loss, + "ca_c_n_loss_mean": ca_c_n_loss, + "c_n_ca_loss_mean": c_n_ca_loss, + "per_residue_loss_sum": per_residue_loss_sum, + "per_residue_violation_mask": violation_mask, + } + + +def between_residue_clash_loss( + atom14_pred_positions: torch.Tensor, + atom14_atom_exists: torch.Tensor, + atom14_atom_radius: torch.Tensor, + residue_index: torch.Tensor, + overlap_tolerance_soft=1.5, + overlap_tolerance_hard=1.5, + eps=1e-10, +) -> Dict[str, torch.Tensor]: + """Loss to penalize steric clashes between residues. 
+ + This is a loss penalizing any steric clashes due to non bonded atoms in + different peptides coming too close. This loss corresponds to the part with + different residues of + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46. + + Args: + atom14_pred_positions: Predicted positions of atoms in + global prediction frame + atom14_atom_exists: Mask denoting whether atom at positions exists for given + amino acid type + atom14_atom_radius: Van der Waals radius for each atom. + residue_index: Residue index for given amino acid. + overlap_tolerance_soft: Soft tolerance factor. + overlap_tolerance_hard: Hard tolerance factor. + + Returns: + Dict containing: + * 'mean_loss': average clash loss + * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14) + * 'per_atom_clash_mask': mask whether atom clashes with any other atom + shape (N, 14) + """ + fp_type = atom14_pred_positions.dtype + + # Create the distance matrix. + # (N, N, 14, 14) + dists = torch.sqrt( + eps + + torch.sum( + ( + atom14_pred_positions[..., :, None, :, None, :] + - atom14_pred_positions[..., None, :, None, :, :] + ) + ** 2, + dim=-1, + ) + ) + + # Create the mask for valid distances. + # shape (N, N, 14, 14) + dists_mask = ( + atom14_atom_exists[..., :, None, :, None] + * atom14_atom_exists[..., None, :, None, :] + ).type(fp_type) + + # Mask out all the duplicate entries in the lower triangular matrix. + # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms + # are handled separately. + dists_mask = dists_mask * ( + residue_index[..., :, None, None, None] + < residue_index[..., None, :, None, None] + ) + + # Backbone C--N bond between subsequent residues is no clash. 
+ c_one_hot = torch.nn.functional.one_hot( + residue_index.new_tensor(2), num_classes=14 + ) + c_one_hot = c_one_hot.reshape( + *((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape + ) + c_one_hot = c_one_hot.type(fp_type) + n_one_hot = torch.nn.functional.one_hot( + residue_index.new_tensor(0), num_classes=14 + ) + n_one_hot = n_one_hot.reshape( + *((1,) * len(residue_index.shape[:-1])), *n_one_hot.shape + ) + n_one_hot = n_one_hot.type(fp_type) + + neighbour_mask = ( + residue_index[..., :, None, None, None] + 1 + ) == residue_index[..., None, :, None, None] + c_n_bonds = ( + neighbour_mask + * c_one_hot[..., None, None, :, None] + * n_one_hot[..., None, None, None, :] + ) + dists_mask = dists_mask * (1.0 - c_n_bonds) + + # Disulfide bridge between two cysteines is no clash. + cys = residue_constants.restype_name_to_atom14_names["CYS"] + cys_sg_idx = cys.index("SG") + cys_sg_idx = residue_index.new_tensor(cys_sg_idx) + cys_sg_idx = cys_sg_idx.reshape( + *((1,) * len(residue_index.shape[:-1])), 1 + ).squeeze(-1) + cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14) + disulfide_bonds = ( + cys_sg_one_hot[..., None, None, :, None] + * cys_sg_one_hot[..., None, None, None, :] + ) + dists_mask = dists_mask * (1.0 - disulfide_bonds) + + # Compute the lower bound for the allowed distances. + # shape (N, N, 14, 14) + dists_lower_bound = dists_mask * ( + atom14_atom_radius[..., :, None, :, None] + + atom14_atom_radius[..., None, :, None, :] + ) + + # Compute the error. + # shape (N, N, 14, 14) + dists_to_low_error = dists_mask * torch.nn.functional.relu( + dists_lower_bound - overlap_tolerance_soft - dists + ) + + # Compute the mean loss. + # shape () + mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask)) + + # Compute the per atom loss sum. + # shape (N, 14) + per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum( + dists_to_low_error, axis=(-3, -1) + ) + + # Compute the hard clash mask. 
def within_residue_violations(
    atom14_pred_positions: torch.Tensor,
    atom14_atom_exists: torch.Tensor,
    atom14_dists_lower_bound: torch.Tensor,
    atom14_dists_upper_bound: torch.Tensor,
    tighten_bounds_for_loss=0.0,
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """Loss to penalize distance violations between atoms of one residue.

    For every pair of distinct, existing atoms inside the same residue the
    predicted distance is compared with a lower and an upper bound. This
    corresponds to the same-residue part of
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.

    The number of atoms per residue (A, 14 in the atom14 representation) is
    read from the last dimension of ``atom14_atom_exists``.

    Args:
        atom14_pred_positions ([*, N, A, 3]):
            Predicted atom positions.
        atom14_atom_exists ([*, N, A]):
            Mask that is non-zero for atoms that exist.
        atom14_dists_lower_bound ([*, N, A, A]):
            Lower bound on the allowed distance of every atom pair.
        atom14_dists_upper_bound ([*, N, A, A]):
            Upper bound on the allowed distance of every atom pair.
        tighten_bounds_for_loss:
            Margin added to the lower bound and subtracted from the upper
            bound when computing the loss. It is NOT applied when computing
            the violation mask.
        eps:
            Small constant added under the square root of the distances.

    Returns:
        Dict containing:
            * 'per_atom_loss_sum' ([*, N, A]):
                Sum of the bound-violation losses of all pairs the atom
                takes part in.
            * 'per_atom_violations' ([*, N, A]):
                Non-zero where the atom takes part in at least one pair
                whose distance lies outside the (untightened) bounds.
    """
    num_atoms = atom14_atom_exists.shape[-1]

    # Valid pairs: both atoms exist and they are not the same atom. The
    # (A, A) off-diagonal matrix broadcasts against the batch dimensions.
    off_diagonal = 1.0 - torch.eye(
        num_atoms, device=atom14_atom_exists.device
    )
    dists_masks = (
        atom14_atom_exists[..., :, None]
        * atom14_atom_exists[..., None, :]
        * off_diagonal
    )

    # Pairwise distances inside each residue, shape [*, N, A, A].
    dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., :, None, :]
                - atom14_pred_positions[..., None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # Hinge losses for distances below / above the (tightened) bounds.
    dists_to_low_error = torch.nn.functional.relu(
        atom14_dists_lower_bound + tighten_bounds_for_loss - dists
    )
    dists_to_high_error = torch.nn.functional.relu(
        dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)
    )
    loss = dists_masks * (dists_to_low_error + dists_to_high_error)

    # Each pair is counted for both of its atoms (rows and columns).
    per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)

    # Hard violations use the untightened bounds.
    violations = dists_masks * (
        (dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
    )

    # An atom is flagged if any pair in its row or column is violated.
    per_atom_violations = torch.maximum(
        torch.max(violations, dim=-2)[0], torch.max(violations, dim=-1)[0]
    )

    return {
        "per_atom_loss_sum": per_atom_loss_sum,
        "per_atom_violations": per_atom_violations,
    }
+ connection_violations = between_residue_bond_loss( + pred_atom_positions=atom14_pred_positions, + pred_atom_mask=batch["atom14_atom_exists"], + residue_index=batch["residue_index"], + aatype=batch["aatype"], + tolerance_factor_soft=violation_tolerance_factor, + tolerance_factor_hard=violation_tolerance_factor, + ) + + # Compute the Van der Waals radius for every atom + # (the first letter of the atom name is the element type). + # Shape: (N, 14). + atomtype_radius = [ + residue_constants.van_der_waals_radius[name[0]] + for name in residue_constants.atom_types + ] + atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius) + atom14_atom_radius = ( + batch["atom14_atom_exists"] + * atomtype_radius[batch["residx_atom14_to_atom37"]] + ) + + # Compute the between residue clash loss. + between_residue_clashes = between_residue_clash_loss( + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch["atom14_atom_exists"], + atom14_atom_radius=atom14_atom_radius, + residue_index=batch["residue_index"], + overlap_tolerance_soft=clash_overlap_tolerance, + overlap_tolerance_hard=clash_overlap_tolerance, + ) + + # Compute all within-residue violations (clashes, + # bond length and angle violations). 
+ restype_atom14_bounds = residue_constants.make_atom14_dists_bounds( + overlap_tolerance=clash_overlap_tolerance, + bond_length_tolerance_factor=violation_tolerance_factor, + ) + atom14_atom_exists = batch["atom14_atom_exists"] + atom14_dists_lower_bound = atom14_pred_positions.new_tensor( + restype_atom14_bounds["lower_bound"] + )[batch["aatype"]] + atom14_dists_upper_bound = atom14_pred_positions.new_tensor( + restype_atom14_bounds["upper_bound"] + )[batch["aatype"]] + residue_violations = within_residue_violations( + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch["atom14_atom_exists"], + atom14_dists_lower_bound=atom14_dists_lower_bound, + atom14_dists_upper_bound=atom14_dists_upper_bound, + tighten_bounds_for_loss=0.0, + ) + + # Combine them to a single per-residue violation mask (used later for LDDT). + per_residue_violations_mask = torch.max( + torch.stack( + [ + connection_violations["per_residue_violation_mask"], + torch.max( + between_residue_clashes["per_atom_clash_mask"], dim=-1 + )[0], + torch.max(residue_violations["per_atom_violations"], dim=-1)[0], + ], + dim=-1, + ), + dim=-1, + )[0] + + return { + "between_residues": { + "bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"], # () + "angles_ca_c_n_loss_mean": connection_violations[ + "ca_c_n_loss_mean" + ], # () + "angles_c_n_ca_loss_mean": connection_violations[ + "c_n_ca_loss_mean" + ], # () + "connections_per_residue_loss_sum": connection_violations[ + "per_residue_loss_sum" + ], # (N) + "connections_per_residue_violation_mask": connection_violations[ + "per_residue_violation_mask" + ], # (N) + "clashes_mean_loss": between_residue_clashes["mean_loss"], # () + "clashes_per_atom_loss_sum": between_residue_clashes[ + "per_atom_loss_sum" + ], # (N, 14) + "clashes_per_atom_clash_mask": between_residue_clashes[ + "per_atom_clash_mask" + ], # (N, 14) + }, + "within_residues": { + "per_atom_loss_sum": residue_violations[ + "per_atom_loss_sum" + ], # (N, 14) + 
"per_atom_violations": residue_violations[ + "per_atom_violations" + ], # (N, 14), + }, + "total_per_residue_violations_mask": per_residue_violations_mask, # (N) + } + + +def find_structural_violations_np( + batch: Dict[str, np.ndarray], + atom14_pred_positions: np.ndarray, + config: ml_collections.ConfigDict, +) -> Dict[str, np.ndarray]: + to_tensor = lambda x: torch.tensor(x) + batch = tree_map(to_tensor, batch, np.ndarray) + atom14_pred_positions = to_tensor(atom14_pred_positions) + + out = find_structural_violations(batch, atom14_pred_positions, **config) + + to_np = lambda x: np.array(x) + np_out = tensor_tree_map(to_np, out) + + return np_out + + +def extreme_ca_ca_distance_violations( + pred_atom_positions: torch.Tensor, # (N, 37(14), 3) + pred_atom_mask: torch.Tensor, # (N, 37(14)) + residue_index: torch.Tensor, # (N) + max_angstrom_tolerance=1.5, + eps=1e-6, +) -> torch.Tensor: + """Counts residues whose Ca is a large distance from its neighbour. + + Measures the fraction of CA-CA pairs between consecutive amino acids that are + more than 'max_angstrom_tolerance' apart. + + Args: + pred_atom_positions: Atom positions in atom37/14 representation + pred_atom_mask: Atom mask in atom37/14 representation + residue_index: Residue index for given amino acid, this is assumed to be + monotonically increasing. + max_angstrom_tolerance: Maximum distance allowed to not count as violation. + Returns: + Fraction of consecutive CA-CA pairs with violation. 
def compute_violation_metrics(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,  # (N, 14, 3)
    violations: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """Reduce the structural violation masks to summary metrics.

    Args:
        batch: Features; reads "atom14_atom_exists", "residue_index" and
            "seq_mask".
        atom14_pred_positions: Predicted atom positions.
        violations: Nested dict with the "between_residues",
            "within_residues" and "total_per_residue_violations_mask"
            entries read below.

    Returns:
        Dict with the CA-CA distance metric and four `masked_mean`
        reductions (over the last dimension, weighted by "seq_mask") of the
        per-residue violation masks.
    """
    metrics = {
        "violations_extreme_ca_ca_distance": extreme_ca_ca_distance_violations(
            pred_atom_positions=atom14_pred_positions,
            pred_atom_mask=batch["atom14_atom_exists"],
            residue_index=batch["residue_index"],
        )
    }

    seq_mask = batch["seq_mask"]
    between = violations["between_residues"]

    metrics["violations_between_residue_bond"] = masked_mean(
        seq_mask,
        between["connections_per_residue_violation_mask"],
        dim=-1,
    )

    # Per-atom masks are collapsed to per-residue masks with a max over atoms.
    residue_has_clash = torch.max(
        between["clashes_per_atom_clash_mask"], dim=-1
    )[0]
    metrics["violations_between_residue_clash"] = masked_mean(
        mask=seq_mask, value=residue_has_clash, dim=-1
    )

    residue_has_internal_violation = torch.max(
        violations["within_residues"]["per_atom_violations"], dim=-1
    )[0]
    metrics["violations_within_residue"] = masked_mean(
        mask=seq_mask, value=residue_has_internal_violation, dim=-1
    )

    metrics["violations_per_residue"] = masked_mean(
        mask=seq_mask,
        value=violations["total_per_residue_violations_mask"],
        dim=-1,
    )
    return metrics
def compute_renamed_ground_truth(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """
    Pick, per residue, the ground-truth atom naming closest to the prediction.

    Alg. 26 "renameSymmetricGroundTruthAtoms"

    For every residue, the distances from its ambiguous atoms to all
    non-ambiguous atoms are compared between the prediction and each of the
    two ground-truth namings. The naming with the smaller summed deviation
    is selected for that residue.

    Args:
        batch: Dictionary containing:
            * atom14_gt_positions: Ground truth positions.
            * atom14_alt_gt_positions: Ground truth positions with renaming
                swaps.
            * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
                renaming swaps.
            * atom14_gt_exists: Mask for which atoms exist in ground truth.
            * atom14_alt_gt_exists: Mask for which atoms exist in ground
                truth after renaming.
        atom14_pred_positions: Predicted atom positions, [*, N, A, 3].
        eps: Small constant added under every square root.
    Returns:
        Dictionary containing:
            alt_naming_is_better: [*, N], 1.0 where the alternative naming
                is strictly closer to the prediction.
            renamed_atom14_gt_positions: Ground truth positions after the
                selected renaming is applied.
            renamed_atom14_gt_exists: Existence mask after the selected
                renaming is applied.
    """

    def _all_pair_dists(coords):
        # [*, N, 1, A, 1, 3] - [*, 1, N, 1, A, 3] -> [*, N, N, A, A]
        delta = (
            coords[..., None, :, None, :]
            - coords[..., None, :, None, :, :]
        )
        return torch.sqrt(eps + torch.sum(delta ** 2, dim=-1))

    pred_dists = _all_pair_dists(atom14_pred_positions)

    gt_positions = batch["atom14_gt_positions"]
    gt_dists = _all_pair_dists(gt_positions)

    alt_gt_positions = batch["atom14_alt_gt_positions"]
    alt_gt_dists = _all_pair_dists(alt_gt_positions)

    # Smooth absolute deviation of the predicted distances from each naming.
    deviation = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
    alt_deviation = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)

    gt_exists = batch["atom14_gt_exists"]
    is_ambiguous = batch["atom14_atom_is_ambiguous"]

    # Rows: existing ambiguous atoms. Columns: existing non-ambiguous atoms.
    pair_mask = (
        gt_exists[..., None, :, None]
        * is_ambiguous[..., None, :, None]
        * gt_exists[..., None, :, None, :]
        * (1.0 - is_ambiguous[..., None, :, None, :])
    )

    per_res_deviation = torch.sum(pair_mask * deviation, dim=(-1, -2, -3))
    alt_per_res_deviation = torch.sum(
        pair_mask * alt_deviation, dim=(-1, -2, -3)
    )

    use_alt = (alt_per_res_deviation < per_res_deviation).type(
        atom14_pred_positions.dtype
    )

    # Blend the two namings with the 0/1 selector.
    use_alt_xyz = use_alt[..., None, None]
    renamed_positions = (
        (1.0 - use_alt_xyz) * gt_positions + use_alt_xyz * alt_gt_positions
    )

    use_alt_atom = use_alt[..., None]
    renamed_exists = (
        (1.0 - use_alt_atom) * gt_exists
        + use_alt_atom * batch["atom14_alt_gt_exists"]
    )

    return {
        "alt_naming_is_better": use_alt,
        "renamed_atom14_gt_positions": renamed_positions,
        "renamed_atom14_gt_exists": renamed_exists,
    }
def compute_drmsd(structure_1, structure_2, mask=None):
    """Compute the distance RMSD (dRMSD) between two point sets.

    dRMSD compares the two intra-structure pairwise distance matrices, so it
    does not depend on how either structure is rotated or translated.

    Args:
        structure_1: [*, N, 3] coordinates.
        structure_2: [*, N, 3] coordinates.
        mask: Optional [*, N] mask, non-zero for points to include. Only
            pairs whose two points are both unmasked contribute. Each batch
            element may have a different number of unmasked points.

    Returns:
        [*] tensor of dRMSD values. Entries with fewer than two (unmasked)
        points are 0.
    """
    if mask is not None:
        mask = mask.to(structure_1.dtype)
        structure_1 = structure_1 * mask[..., None]
        structure_2 = structure_2 * mask[..., None]

    def _pairwise_dists(points):
        diff = points[..., :, None, :] - points[..., None, :, :]
        return torch.sqrt(torch.sum(diff ** 2, dim=-1))

    d1 = _pairwise_dists(structure_1)
    d2 = _pairwise_dists(structure_2)
    sq_diff = (d1 - d2) ** 2

    if mask is None:
        n = d1.shape[-1]
        total = torch.sum(sq_diff, dim=(-1, -2))
        if n > 1:
            return torch.sqrt(total * (1 / (n * (n - 1))))
        return torch.sqrt(total * 0.)

    # Zeroing the coordinates alone is not enough: a masked point then sits
    # at the origin and its distance to every unmasked point would still be
    # summed. Drop every pair that involves a masked point.
    pair_mask = mask[..., :, None] * mask[..., None, :]
    total = torch.sum(sq_diff * pair_mask, dim=(-1, -2))

    # n may differ per batch element, so select element-wise rather than
    # with a Python `if` on a tensor.
    n = torch.sum(mask, dim=-1)
    has_pairs = n > 1
    denom = torch.where(has_pairs, n * (n - 1), torch.ones_like(n))
    mean_sq = torch.where(has_pairs, total / denom, torch.zeros_like(total))

    return torch.sqrt(mean_sq)
class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """ Implements the learning rate schedule defined in the AlphaFold 2
        supplement. A linear warmup is followed by a plateau at the maximum
        learning rate and then exponential decay.

        Note that the initial learning rate of the optimizer in question is
        ignored; use this class' base_lr parameter to specify the starting
        point of the warmup.

        Schedule, with ``step`` being ``last_epoch``:
            * ``step <= warmup_no_steps`` (only if ``warmup_no_steps > 0``):
              linear interpolation from ``base_lr`` (step 0) to ``max_lr``
              (step ``warmup_no_steps``).
            * ``step > start_decay_after_n_steps``: ``max_lr`` multiplied by
              ``decay_factor`` once immediately and once more after every
              further ``decay_every_n_steps`` steps.
            * otherwise: ``max_lr``.

        The same learning rate is returned for every parameter group.
    """
    def __init__(self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        base_lr: float = 0.,
        max_lr: float = 0.001,
        warmup_no_steps: int = 1000,
        start_decay_after_n_steps: int = 50000,
        decay_every_n_steps: int = 50000,
        decay_factor: float = 0.95,
    ):
        """
        Raises:
            ValueError: If a step count is negative, if
                decay_every_n_steps is not positive, or if the warmup would
                end after the decay starts.
        """
        step_counts = {
            "warmup_no_steps": warmup_no_steps,
            "start_decay_after_n_steps": start_decay_after_n_steps,
        }

        for name, value in step_counts.items():
            if value < 0:
                raise ValueError(f"{name} must be nonnegative")

        # The decay exponent is computed with `// decay_every_n_steps`;
        # reject a zero interval here instead of failing mid-training.
        if decay_every_n_steps <= 0:
            raise ValueError("decay_every_n_steps must be positive")

        if warmup_no_steps > start_decay_after_n_steps:
            raise ValueError(
                "warmup_no_steps must not exceed start_decay_after_n_steps"
            )

        # These must be set before the base constructor runs, because it
        # performs an initial step that calls get_lr().
        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.verbose = verbose
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.warmup_no_steps = warmup_no_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.decay_every_n_steps = decay_every_n_steps
        self.decay_factor = decay_factor

        super(AlphaFoldLRScheduler, self).__init__(
            optimizer,
            last_epoch=last_epoch,
            verbose=verbose,
        )

    def state_dict(self):
        """Return every attribute except the optimizer."""
        state_dict = {
            k: v for k, v in self.__dict__.items() if k not in ["optimizer"]
        }

        return state_dict

    def load_state_dict(self, state_dict):
        """Restore attributes previously produced by state_dict()."""
        self.__dict__.update(state_dict)

    def get_lr(self):
        """Return the learning rate of the current step, once per group.

        Raises:
            RuntimeError: If called outside of the scheduler's step().
        """
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        # `warmup_no_steps > 0` avoids 0 / 0 when no warmup is requested.
        if self.warmup_no_steps > 0 and step_no <= self.warmup_no_steps:
            # Interpolate so that the warmup ends exactly at max_lr, where
            # the plateau begins, also for a non-zero base_lr.
            progress = step_no / self.warmup_no_steps
            lr = self.base_lr + progress * (self.max_lr - self.base_lr)
        elif step_no > self.start_decay_after_n_steps:
            steps_since_decay = step_no - self.start_decay_after_n_steps
            exp = (steps_since_decay // self.decay_every_n_steps) + 1
            lr = self.max_lr * (self.decay_factor ** exp)
        else:  # plateau
            lr = self.max_lr

        return [lr] * len(self.optimizer.param_groups)
def is_fp16_enabled():
    """Tell whether autocast is enabled with float16 as its GPU dtype."""
    # Autocast world
    if torch.get_autocast_gpu_dtype() != torch.float16:
        return False

    return torch.is_autocast_enabled()
def rot_matmul(
    a: torch.Tensor,
    b: torch.Tensor
) -> torch.Tensor:
    """
        Multiplies two tensors of rotation matrices. Every output entry is
        spelled out as element-wise products and sums instead of calling a
        matmul kernel, to avoid AMP downcasting.

        Args:
            a: [*, 3, 3] left multiplicand
            b: [*, 3, 3] right multiplicand
        Returns:
            The product ab
    """
    def _entry(i, j):
        # (ab)[i, j] = sum_k a[i, k] * b[k, j], summed left to right.
        return (
            a[..., i, 0] * b[..., 0, j]
            + a[..., i, 1] * b[..., 1, j]
            + a[..., i, 2] * b[..., 2, j]
        )

    rows = [
        torch.stack([_entry(i, j) for j in range(3)], dim=-1)
        for i in range(3)
    ]

    return torch.stack(rows, dim=-2)
def identity_quats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
        Creates a tensor of identity quaternions (1, 0, 0, 0).

        Args:
            batch_dims: Leading dimensions of the result
            dtype: Optional dtype of the result
            device: Optional device of the result
            requires_grad: Whether the result requires grad
        Returns:
            [*batch_dims, 4] tensor whose first component is 1 and whose
            other components are 0
    """
    shape = (*batch_dims, 4)
    identity = torch.zeros(
        shape, dtype=dtype, device=device, requires_grad=requires_grad
    )

    # The in-place write must not be recorded by autograd, since the tensor
    # may already require grad.
    with torch.no_grad():
        identity[..., 0] = 1

    return identity
def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
    """
    Turns quaternions into rotation matrices.

    Every entry of the rotation matrix is a fixed linear combination of the
    pairwise products q_i * q_j; the coefficients of those combinations are
    read from the module-level table _QTR_MAT.

    Args:
        quat: [*, 4] quaternions
    Returns:
        [*, 3, 3] rotation matrices
    """
    # [*, 4, 4] all pairwise products q_i * q_j
    pairwise = quat[..., None] * quat[..., None, :]

    # [4, 4, 3, 3] coefficient table, created with quat's dtype and device
    coeffs = quat.new_tensor(_QTR_MAT, requires_grad=False)

    # [1, ..., 1, 4, 4, 3, 3] so the table broadcasts over the batch dims
    n_batch_dims = len(pairwise.shape) - 2
    coeffs = coeffs.view((1,) * n_batch_dims + coeffs.shape)

    # [*, 4, 4, 3, 3]
    weighted = pairwise[..., None, None] * coeffs

    # [*, 3, 3] after summing out both quaternion axes
    return torch.sum(weighted, dim=(-3, -4))
def invert_quat(quat: torch.Tensor):
    """
    Computes the multiplicative inverse of a batch of quaternions.

    The inverse is the conjugate (vector part negated) divided by the
    squared norm, so the input does not have to be unit length.

    Args:
        quat: [*, 4] quaternions, scalar component first
    Returns:
        [*, 4] inverse quaternions; the input tensor is left untouched
    """
    # Conjugate: scalar part kept, vector part negated.
    conjugate = torch.cat([quat[..., :1], quat[..., 1:] * -1], dim=-1)
    squared_norm = torch.sum(quat ** 2, dim=-1, keepdim=True)
    return conjugate / squared_norm
@staticmethod
def identity(
    shape,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
    fmt: str = "quat",
):
    """
        Returns an identity Rotation.

        Args:
            shape:
                The "shape" of the resulting Rotation object. See documentation
                for the shape property
            dtype:
                The torch dtype for the rotation
            device:
                The torch device for the new rotation
            requires_grad:
                Whether the underlying tensors in the new rotation object
                should require gradient computation
            fmt:
                One of "quat" or "rot_mat". Determines the underlying format
                of the new object's rotation
        Returns:
            A new identity rotation
        Raises:
            ValueError: If fmt is neither "quat" nor "rot_mat"
    """
    if fmt == "rot_mat":
        rot_mats = identity_rot_mats(
            shape, dtype, device, requires_grad,
        )
        return Rotation(rot_mats=rot_mats, quats=None)
    if fmt == "quat":
        quats = identity_quats(shape, dtype, device, requires_grad)
        # Identity quaternions are unit length already, so skip normalization
        return Rotation(rot_mats=None, quats=quats, normalize_quats=False)

    # The message used to read f"...: f{fmt}", which prefixed the offending
    # value with a stray "f".
    raise ValueError(f"Invalid format: {fmt}")
+ + Args: + left: + The left multiplicand + Returns: + The product + """ + return self.__mul__(left) + + # Properties + + @property + def shape(self) -> torch.Size: + """ + Returns the virtual shape of the rotation object. This shape is + defined as the batch dimensions of the underlying rotation matrix + or quaternion. If the Rotation was initialized with a [10, 3, 3] + rotation matrix tensor, for example, the resulting shape would be + [10]. + + Returns: + The virtual shape of the rotation object + """ + s = None + if(self._quats is not None): + s = self._quats.shape[:-1] + else: + s = self._rot_mats.shape[:-2] + + return s + + @property + def dtype(self) -> torch.dtype: + """ + Returns the dtype of the underlying rotation. + + Returns: + The dtype of the underlying rotation + """ + if(self._rot_mats is not None): + return self._rot_mats.dtype + elif(self._quats is not None): + return self._quats.dtype + else: + raise ValueError("Both rotations are None") + + @property + def device(self) -> torch.device: + """ + The device of the underlying rotation + + Returns: + The device of the underlying rotation + """ + if(self._rot_mats is not None): + return self._rot_mats.device + elif(self._quats is not None): + return self._quats.device + else: + raise ValueError("Both rotations are None") + + @property + def requires_grad(self) -> bool: + """ + Returns the requires_grad property of the underlying rotation + + Returns: + The requires_grad property of the underlying tensor + """ + if(self._rot_mats is not None): + return self._rot_mats.requires_grad + elif(self._quats is not None): + return self._quats.requires_grad + else: + raise ValueError("Both rotations are None") + + def get_rot_mats(self) -> torch.Tensor: + """ + Returns the underlying rotation as a rotation matrix tensor. 
def get_rotvec(self, eps=1e-6) -> torch.Tensor:
    """
        Return the underlying rotation as an axis-angle rotation vector.

        Follows scipy's implementation:
        https://github.com/scipy/scipy/blob/HEAD/scipy/spatial/transform/_rotation.pyx#L1385-L1402

        The quaternion is obtained through get_quats(), then scaled so that
        its vector part becomes axis * angle.

        Args:
            eps:
                Small constant added to the argument of the sine in the
                large-angle scale factor, keeping that denominator away
                from zero
        Returns:
            The stored rotation as an axis-angle vector, one 3-vector per
            quaternion.
    """
    quat = self.get_quats()
    # w > 0 to ensure 0 <= angle <= pi
    # (q and -q encode the same rotation, so quaternions with a negative
    # scalar part are negated; flip is 1.0 exactly where that happens)
    flip = (quat[..., :1] < 0).float()
    quat = (-1 * quat) * flip + (1 - flip) * quat

    # angle = 2 * atan2(|vector part|, scalar part)
    angle = 2 * torch.atan2(
        torch.linalg.norm(quat[..., 1:], dim=-1),
        quat[..., 0]
    )

    # Scale factor angle / sin(angle / 2): a truncated Taylor series is
    # used near zero, the closed form elsewhere.
    angle2 = angle * angle
    small_angle_scales = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
    large_angle_scales = angle / torch.sin(angle / 2 + eps)

    # Both candidates are computed for every element and blended with a
    # 0/1 mask, so the large-angle expression must stay finite even where
    # the small-angle one is selected -- hence eps above.
    small_angles = (angle <= 1e-3).float()
    rot_vec_scale = small_angle_scales * small_angles + (1 - small_angles) * large_angle_scales
    rot_vec = rot_vec_scale[..., None] * quat[..., 1:]
    return rot_vec
+ + Args: + r: + An update rotation object + Returns: + An updated rotation object + """ + r1 = self.get_rot_mats() + r2 = r.get_rot_mats() + new_rot_mats = rot_matmul(r1, r2) + return Rotation(rot_mats=new_rot_mats, quats=None) + + def compose_q(self, r, normalize_quats: bool = True): + """ + Compose the quaternions of the current Rotation object with those + of another. + + Depending on whether either Rotation was initialized with + quaternions, this function may call torch.linalg.eigh. + + Args: + r: + An update rotation object + Returns: + An updated rotation object + """ + q1 = self.get_quats() + q2 = r.get_quats() + new_quats = quat_multiply(q1, q2) + return Rotation( + rot_mats=None, quats=new_quats, normalize_quats=normalize_quats + ) + + def apply(self, pts: torch.Tensor) -> torch.Tensor: + """ + Apply the current Rotation as a rotation matrix to a set of 3D + coordinates. + + Args: + pts: + A [*, 3] set of points + Returns: + [*, 3] rotated points + """ + rot_mats = self.get_rot_mats() + return rot_vec_mul(rot_mats, pts) + + def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: + """ + The inverse of the apply() method. + + Args: + pts: + A [*, 3] set of points + Returns: + [*, 3] inverse-rotated points + """ + rot_mats = self.get_rot_mats() + inv_rot_mats = invert_rot_mat(rot_mats) + return rot_vec_mul(inv_rot_mats, pts) + + def invert(self) : + """ + Returns the inverse of the current Rotation. + + Returns: + The inverse of the current Rotation + """ + if(self._rot_mats is not None): + return Rotation( + rot_mats=invert_rot_mat(self._rot_mats), + quats=None + ) + elif(self._quats is not None): + return Rotation( + rot_mats=None, + quats=invert_quat(self._quats), + normalize_quats=False, + ) + else: + raise ValueError("Both rotations are None") + + # "Tensor" stuff + + def unsqueeze(self, + dim: int, + ): + """ + Analogous to torch.unsqueeze. The dimension is relative to the + shape of the Rotation object. 
+ + Args: + dim: A positive or negative dimension index. + Returns: + The unsqueezed Rotation. + """ + if dim >= len(self.shape): + raise ValueError("Invalid dimension") + + if(self._rot_mats is not None): + rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2) + return Rotation(rot_mats=rot_mats, quats=None) + elif(self._quats is not None): + quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1) + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + @staticmethod + def cat( + rs, + dim: int, + ): + """ + Concatenates rotations along one of the batch dimensions. Analogous + to torch.cat(). + + Note that the output of this operation is always a rotation matrix, + regardless of the format of input rotations. + + Args: + rs: + A list of rotation objects + dim: + The dimension along which the rotations should be + concatenated + Returns: + A concatenated Rotation object in rotation matrix format + """ + rot_mats = [r.get_rot_mats() for r in rs] + rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) + + return Rotation(rot_mats=rot_mats, quats=None) + + def map_tensor_fn(self, + fn + ): + """ + Apply a Tensor -> Tensor function to underlying rotation tensors, + mapping over the rotation dimension(s). Can be used e.g. to sum out + a one-hot batch dimension. 
def to(self,
    device: Optional[torch.device],
    dtype: Optional[torch.dtype]
):
    """
        Analogous to the to() method of torch Tensors

        Only the tensor that backs this Rotation (rotation matrices or
        quaternions) is converted; the format of the result is the same as
        that of the current object.

        NOTE: the result is built through the Rotation constructor, which
        casts its input to torch.float32. A dtype other than float32 is
        therefore not preserved in the returned object.

        Args:
            device:
                A torch device
            dtype:
                A torch dtype
        Returns:
            A copy of the Rotation using the new device and dtype
        Raises:
            ValueError: If neither underlying tensor is set
    """
    if(self._rot_mats is not None):
        return Rotation(
            rot_mats=self._rot_mats.to(device=device, dtype=dtype),
            quats=None,
        )
    elif(self._quats is not None):
        return Rotation(
            rot_mats=None,
            quats=self._quats.to(device=device, dtype=dtype),
            # Stored quaternions are passed through unchanged
            normalize_quats=False,
        )
    else:
        raise ValueError("Both rotations are None")
def __init__(self,
    rots: Optional[Rotation],
    trans: Optional[torch.Tensor],
):
    """
        Builds a rigid transformation from a rotation and a translation.
        Either argument may be None, in which case it is replaced by the
        identity, but not both.

        Args:
            rots:
                A Rotation object of shape [*], or None for an identity
                rotation
            trans:
                A corresponding [*, 3] translation tensor, or None for a
                zero translation
        Raises:
            ValueError:
                If both arguments are None, or if the batch shapes or
                devices of the rotation and the translation differ
    """
    # (we need device, dtype, etc. from at least one input)
    # The translation takes precedence as the source of these properties.

    batch_dims, dtype, device, requires_grad = None, None, None, None
    if(trans is not None):
        batch_dims = trans.shape[:-1]
        dtype = trans.dtype
        device = trans.device
        requires_grad = trans.requires_grad
    elif(rots is not None):
        batch_dims = rots.shape
        dtype = rots.dtype
        device = rots.device
        requires_grad = rots.requires_grad
    else:
        raise ValueError("At least one input argument must be specified")

    # Fill in whichever half is missing with an identity that matches the
    # half that was provided. Rotation.identity is called with its default
    # format here.
    if(rots is None):
        rots = Rotation.identity(
            batch_dims, dtype, device, requires_grad,
        )
    elif(trans is None):
        trans = identity_trans(
            batch_dims, dtype, device, requires_grad,
        )

    if((rots.shape != trans.shape[:-1]) or
       (rots.device != trans.device)):
        raise ValueError("Rots and trans incompatible")

    # Force full precision. Happens to the rotations automatically.
    trans = trans.type(torch.float32)

    self._rots = rots
    self._trans = trans
+ + Args: + right: + The tensor multiplicand + Returns: + The product + """ + if not(isinstance(right, torch.Tensor)): + raise TypeError("The other multiplicand must be a Tensor") + + new_rots = self._rots * right + new_trans = self._trans * right[..., None] + + return Rigid(new_rots, new_trans) + + def __rmul__(self, + left: torch.Tensor, + ): + """ + Reverse pointwise multiplication of the transformation with a + tensor. + + Args: + left: + The left multiplicand + Returns: + The product + """ + return self.__mul__(left) + + @property + def shape(self) -> torch.Size: + """ + Returns the shape of the shared dimensions of the rotation and + the translation. + + Returns: + The shape of the transformation + """ + s = self._trans.shape[:-1] + return s + + @property + def device(self) -> torch.device: + """ + Returns the device on which the Rigid's tensors are located. + + Returns: + The device on which the Rigid's tensors are located + """ + return self._trans.device + + def get_rots(self) -> Rotation: + """ + Getter for the rotation. + + Returns: + The rotation object + """ + return self._rots + + def get_trans(self) -> torch.Tensor: + """ + Getter for the translation. + + Returns: + The stored translation + """ + return self._trans + + def compose_q_update_vec(self, + q_update_vec: torch.Tensor, + update_mask: torch.Tensor=None, + ): + """ + Composes the transformation with a quaternion update vector of + shape [*, 6], where the final 6 columns represent the x, y, and + z values of a quaternion of form (1, x, y, z) followed by a 3D + translation. + + Args: + q_vec: The quaternion update vector. + Returns: + The composed transformation. 
def compose_tran_update_vec(self,
    t_vec: torch.Tensor,
    update_mask: torch.Tensor=None,
):
    """
        Composes the transformation with a translation-only update. The
        update is expressed in the local frame: it is rotated by the
        current rotation before being added to the stored translation.
        The rotation itself is left unchanged.

        Args:
            t_vec:
                A [*, 3] translation update
            update_mask:
                Optional tensor multiplied elementwise into the rotated
                update, e.g. to zero out some entries. Presumably
                broadcastable against [*, 3] -- confirm against callers.
        Returns:
            The composed transformation.
    """
    trans_update = self._rots.apply(t_vec)
    if update_mask is not None:
        trans_update = trans_update * update_mask
    new_translation = self._trans + trans_update

    return Rigid(self._rots, new_translation)
def invert(self):
    """
        Inverts the transformation.

        For a transformation x -> R x + t, the inverse is
        x -> R^-1 x - R^-1 t, so the new translation is the old one
        rotated by the inverted rotation and negated.

        Returns:
            The inverse transformation.
    """
    rot_inv = self._rots.invert()
    # R^-1 t; the sign is flipped when the result is assembled below
    trn_inv = rot_inv.apply(self._trans)

    return Rigid(rot_inv, -1 * trn_inv)
@staticmethod
def from_tensor_7(
    t: torch.Tensor,
    normalize_quats: bool = False,
):
    """
        Constructs a transformation from a tensor with 7 final columns,
        four for the quaternion followed by three for the translation.

        Args:
            t: [*, 7] tensor representation of the transformation
            normalize_quats:
                Whether the quaternion columns should be normalized when
                the Rotation is built
        Returns:
            A transformation object of shape [*] in quaternion format
        Raises:
            ValueError: If the last dimension of t is not 7
    """
    if(t.shape[-1] != 7):
        raise ValueError("Incorrectly shaped input tensor")

    # Columns 0-3 hold the quaternion, columns 4-6 the translation
    quats, trans = t[..., :4], t[..., 4:]

    rots = Rotation(
        rot_mats=None,
        quats=quats,
        normalize_quats=normalize_quats
    )

    return Rigid(rots, trans)
+ + Args: + p_neg_x_axis: [*, 3] coordinates + origin: [*, 3] coordinates used as frame origins + p_xy_plane: [*, 3] coordinates + eps: Small epsilon value + Returns: + A transformation object of shape [*] + """ + p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1) + origin = torch.unbind(origin, dim=-1) + p_xy_plane = torch.unbind(p_xy_plane, dim=-1) + + e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)] + e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)] + + denom = torch.sqrt(sum((c * c for c in e0)) + eps) + e0 = [c / denom for c in e0] + dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) + e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] + denom = torch.sqrt(sum((c * c for c in e1)) + eps) + e1 = [c / denom for c in e1] + e2 = [ + e0[1] * e1[2] - e0[2] * e1[1], + e0[2] * e1[0] - e0[0] * e1[2], + e0[0] * e1[1] - e0[1] * e1[0], + ] + + rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) + rots = rots.reshape(rots.shape[:-1] + (3, 3)) + + rot_obj = Rotation(rot_mats=rots, quats=None) + + return Rigid(rot_obj, torch.stack(origin, dim=-1)) + + def unsqueeze(self, + dim: int, + ): + """ + Analogous to torch.unsqueeze. The dimension is relative to the + shared dimensions of the rotation/translation. + + Args: + dim: A positive or negative dimension index. + Returns: + The unsqueezed transformation. + """ + if dim >= len(self.shape): + raise ValueError("Invalid dimension") + rots = self._rots.unsqueeze(dim) + trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1) + + return Rigid(rots, trans) + + @staticmethod + def cat( + ts, + dim: int, + ): + """ + Concatenates transformations along a new dimension. 
@staticmethod
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
    """
        Returns a transformation object from reference coordinates.

        The frame is built in three steps after moving CA to the origin:
        a rotation about z that brings C into the x-z plane, a rotation
        about y that brings C onto the positive x axis, and a rotation
        about x that brings N into the x-y plane. The product of the three
        is then inverted (transposed), so that the returned transformation
        maps the reference backbone onto the input coordinates.

        Note that this method does not take care of symmetries. If you
        provide the atom positions in the non-standard way, the N atom will
        end up not at [-0.527250, 1.359329, 0.0] but instead at
        [-0.527250, -1.359329, 0.0]. You need to take care of such cases in
        your code.

        Args:
            n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
            ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
            c_xyz: A [*, 3] tensor of carbon xyz coordinates.
            eps: Small constant added under the square roots so that the
                norms used as denominators are never zero.
        Returns:
            A transformation object. After applying the translation and
            rotation to the reference backbone, the coordinates will
            approximately equal to the input coordinates.
    """
    # Shift everything so that CA sits at the origin
    translation = -1 * ca_xyz
    n_xyz = n_xyz + translation
    c_xyz = c_xyz + translation

    # Rotation about the z axis: zeroes the y component of C
    c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
    norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2)
    sin_c1 = -c_y / norm
    cos_c1 = c_x / norm

    c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
    c1_rots[..., 0, 0] = cos_c1
    c1_rots[..., 0, 1] = -1 * sin_c1
    c1_rots[..., 1, 0] = sin_c1
    c1_rots[..., 1, 1] = cos_c1
    c1_rots[..., 2, 2] = 1

    # Rotation about the y axis: zeroes the z component of C
    norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2)
    sin_c2 = c_z / norm
    cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm

    c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
    c2_rots[..., 0, 0] = cos_c2
    c2_rots[..., 0, 2] = sin_c2
    c2_rots[..., 1, 1] = 1
    # The third row belongs to c2_rots. Writing it into c1_rots (as was
    # done before) left c2_rots singular and corrupted c1_rots, so neither
    # matrix was a rotation.
    c2_rots[..., 2, 0] = -1 * sin_c2
    c2_rots[..., 2, 2] = cos_c2

    c_rots = rot_matmul(c2_rots, c1_rots)
    n_xyz = rot_vec_mul(c_rots, n_xyz)

    # Rotation about the x axis: zeroes the z component of N
    _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
    norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2)
    sin_n = -n_z / norm
    cos_n = n_y / norm

    n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
    n_rots[..., 0, 0] = 1
    n_rots[..., 1, 1] = cos_n
    n_rots[..., 1, 2] = -1 * sin_n
    n_rots[..., 2, 1] = sin_n
    n_rots[..., 2, 2] = cos_n

    rots = rot_matmul(n_rots, c_rots)

    # Invert: the matrices above map input -> reference, callers want
    # reference -> input
    rots = rots.transpose(-1, -2)
    translation = -1 * translation

    rot_obj = Rotation(rot_mats=rots, quats=None)

    return Rigid(rot_obj, translation)
b/openfold/utils/seed.py new file mode 100644 index 0000000000000000000000000000000000000000..b45b81379257909e06b00839bbb90848cf3d0b3f --- /dev/null +++ b/openfold/utils/seed.py @@ -0,0 +1,19 @@ +import os +import logging +import random +import numpy as np +from pytorch_lightning.utilities.seed import seed_everything + +from openfold.utils.suppress_output import SuppressLogging + + +def seed_globally(seed=None): + if("PL_GLOBAL_SEED" not in os.environ): + if(seed is None): + seed = random.randint(0, np.iinfo(np.uint32).max) + os.environ["PL_GLOBAL_SEED"] = str(seed) + logging.info(f'os.environ["PL_GLOBAL_SEED"] set to {seed}') + + # seed_everything is a bit log-happy + with SuppressLogging(logging.INFO): + seed_everything(seed=None) diff --git a/openfold/utils/superimposition.py b/openfold/utils/superimposition.py new file mode 100644 index 0000000000000000000000000000000000000000..835498f014b8083b3f715bfb3b8dd426ce296f3e --- /dev/null +++ b/openfold/utils/superimposition.py @@ -0,0 +1,107 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from Bio.SVDSuperimposer import SVDSuperimposer +import numpy as np +import torch + + +def _superimpose_np(reference, coords): + """ + Superimposes coordinates onto a reference by minimizing RMSD using SVD. + + Args: + reference: + [N, 3] reference array + coords: + [N, 3] array + Returns: + A tuple of [N, 3] superimposed coords and the final RMSD. 
+ """ + sup = SVDSuperimposer() + sup.set(reference, coords) + sup.run() + return sup + +def _superimpose_single(reference, coords): + reference_np = reference.detach().cpu().numpy() + coords_np = coords.detach().cpu().numpy() + sup = _superimpose_np(reference_np, coords_np) + rot, tran = sup.get_rotran() + superimposed, rmsd = sup.get_transformed(), sup.get_rms() + return coords.new_tensor(superimposed), coords.new_tensor(rmsd), rot, tran + + +def superimpose(reference, coords, mask, return_transform=False): + """ + Superimposes coordinates onto a reference by minimizing RMSD using SVD. + + Args: + reference: + [*, N, 3] reference tensor + coords: + [*, N, 3] tensor + mask: + [*, N] tensor + Returns: + A tuple of [*, N, 3] superimposed coords and [*] final RMSDs. + """ + def select_unmasked_coords(coords, mask): + return torch.masked_select( + coords, + (mask > 0.)[..., None], + ).reshape(-1, 3) + + batch_dims = reference.shape[:-2] + flat_reference = reference.reshape((-1,) + reference.shape[-2:]) + flat_coords = coords.reshape((-1,) + reference.shape[-2:]) + flat_mask = mask.reshape((-1,) + mask.shape[-1:]) + superimposed_list = [] + rmsds = [] + rots = [] + trans = [] + for r, c, m in zip(flat_reference, flat_coords, flat_mask): + r_unmasked_coords = select_unmasked_coords(r, m) + c_unmasked_coords = select_unmasked_coords(c, m) + superimposed, rmsd, rot, tran = _superimpose_single( + r_unmasked_coords, + c_unmasked_coords + ) + rots.append(rot) + trans.append(tran) + # This is very inelegant, but idk how else to invert the masking + # procedure. 
+ count = 0 + superimposed_full_size = torch.zeros_like(r) + for i, unmasked in enumerate(m): + if(unmasked): + superimposed_full_size[i] = superimposed[count] + count += 1 + + superimposed_list.append(superimposed_full_size) + rmsds.append(rmsd) + + superimposed_stacked = torch.stack(superimposed_list, dim=0) + rmsds_stacked = torch.stack(rmsds, dim=0) + rots_stacked = torch.tensor(np.stack(rots, axis=0), device=coords.device) + trans_stacked = torch.tensor(np.stack(trans, axis=0), device=coords.device) + + superimposed_reshaped = superimposed_stacked.reshape( + batch_dims + coords.shape[-2:] + ) + rmsds_reshaped = rmsds_stacked.reshape( + batch_dims + ) + if return_transform: + return superimposed_reshaped, rmsds_reshaped, rots_stacked, trans_stacked + return superimposed_reshaped, rmsds_reshaped diff --git a/openfold/utils/suppress_output.py b/openfold/utils/suppress_output.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c821b36a6b8d047e541a51faceef037f48ff90 --- /dev/null +++ b/openfold/utils/suppress_output.py @@ -0,0 +1,26 @@ +import logging +import sys + + +class SuppressStdout: + def __enter__(self): + self.stdout = sys.stdout + dev_null = open("/dev/null", "w") + sys.stdout = dev_null + + def __exit__(self, typ, value, traceback): + fp = sys.stdout + sys.stdout = self.stdout + fp.close() + + +class SuppressLogging: + def __init__(self, level): + self.level = level + + def __enter__(self): + logging.disable(self.level) + + def __exit__(self, typ, value, traceback): + logging.disable(logging.NOTSET) + diff --git a/openfold/utils/tensor_utils.py b/openfold/utils/tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7e5e8e4b6b5e9a26ccf1c5cedf1ff10102e75b2f --- /dev/null +++ b/openfold/utils/tensor_utils.py @@ -0,0 +1,408 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this 
# file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional


def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """
    Permutes the final len(inds) dimensions of a tensor and leaves all
    leading dimensions where they are.

    Args:
        tensor:
            The tensor to permute
        inds:
            The new order of the final dimensions, expressed as offsets
            from the first of those dimensions (e.g. [1, 0] swaps the
            last two dimensions)
    Returns:
        The permuted tensor
    """
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


def flatten_final_dims(t: torch.Tensor, no_dims: int):
    """
    Collapses the final no_dims dimensions of t into a single dimension.

    Args:
        t:
            The tensor to reshape
        no_dims:
            The number of trailing dimensions to merge
    Returns:
        A tensor of shape t.shape[:-no_dims] + (-1,)
    """
    return t.reshape(t.shape[:-no_dims] + (-1,))


def masked_mean(mask, value, dim, eps=1e-4):
    """
    Computes the mean of value along dim, weighted by mask.

    Args:
        mask:
            A tensor that can be expanded to the shape of value
        value:
            The tensor to average
        dim:
            The dimension(s) to reduce
        eps:
            Small constant added to the denominator so that fully
            masked-out slices do not divide by zero
    Returns:
        sum(mask * value) / (eps + sum(mask)), reduced over dim
    """
    mask = mask.expand(*value.shape)
    return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))


def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    """
    Bins all pairwise Euclidean distances between points.

    Args:
        pts:
            A [*, N, 3]-like tensor of points; the last dimension is
            reduced when computing distances
        min_bin:
            The first bin boundary
        max_bin:
            The last bin boundary
        no_bins:
            The number of bins (no_bins - 1 evenly spaced boundaries are
            created between min_bin and max_bin)
    Returns:
        A [*, N, N] tensor of bin indices as computed by torch.bucketize
    """
    boundaries = torch.linspace(
        min_bin, max_bin, no_bins - 1, device=pts.device
    )
    dists = torch.sqrt(
        torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
    )
    return torch.bucketize(dists, boundaries)


def dict_multimap(fn, dicts):
    """
    Applies fn to the list of values gathered under each key from a list
    of identically-structured (possibly nested) dictionaries.

    Args:
        fn:
            A function that takes a list of leaves (one per dictionary)
        dicts:
            A non-empty list of dictionaries. The keys of the first one
            determine the keys of the output
    Returns:
        A dictionary with the same structure as dicts[0]
    """
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)

    return new_dict


def one_hot(x, v_bins):
    """
    One-hot encodes each element of x by the bin in v_bins closest to it.

    Args:
        x:
            A tensor of values
        v_bins:
            A 1-D tensor of bin centers
    Returns:
        A float tensor of shape x.shape + (len(v_bins),)
    """
    reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
    diffs = x[..., None] - reshaped_bins
    am = torch.argmin(torch.abs(diffs), dim=-1)
    return nn.functional.one_hot(am, num_classes=len(v_bins)).float()


def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """
    Indexes data along dim with inds, independently for every element of
    the leading no_batch_dims batch dimensions.

    Args:
        data:
            The tensor to gather from
        inds:
            An integer index tensor whose leading dimensions line up with
            the batch dimensions of data
        dim:
            The dimension of data to index with inds
        no_batch_dims:
            The number of leading batch dimensions of data
    Returns:
        The gathered tensor
    """
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        # One broadcastable arange per batch dimension, so that advanced
        # indexing pairs each batch element with its own row of inds
        r = torch.arange(s, device=inds.device)
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - no_batch_dims)
    ]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Multi-dimensional indexing must use a tuple; indexing with a list
    # relies on deprecated sequence-as-tuple behavior
    return data[tuple(ranges)]


# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):
    """
    Applies tree_map to every value of a (possibly nested) dictionary.

    Args:
        fn:
            The function to apply to each leaf
        dic:
            The dictionary to traverse
        leaf_type:
            The type of the leaves fn should be applied to
    Returns:
        A new dictionary with the same keys and transformed values
    """
    new_dict = {}
    for k, v in dic.items():
        if type(v) is dict:
            new_dict[k] = dict_map(fn, v, leaf_type)
        else:
            new_dict[k] = tree_map(fn, v, leaf_type)

    return new_dict


def tree_map(fn, tree, leaf_type):
    """
    Applies fn to every leaf of a nested structure of dicts, lists and
    tuples.

    Args:
        fn:
            The function to apply to each leaf
        tree:
            A dict, list, tuple or leaf_type instance
        leaf_type:
            The type of the leaves fn should be applied to
    Returns:
        A structure of the same layout with fn applied to each leaf
    Raises:
        ValueError: If a node is neither a supported container nor an
            instance of leaf_type
    """
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple([tree_map(fn, x, leaf_type) for x in tree])
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        # Report the offending type in the exception instead of printing
        # it to stdout
        raise ValueError(f"Not supported: {type(tree)}")


tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)


def _fetch_dims(tree):
    """
    Collects the shapes of all tensors in a nested structure of dicts,
    lists and tuples, in traversal order.

    Raises:
        ValueError: If a node is not a dict, list, tuple or tensor
    """
    shapes = []
    # isinstance (rather than exact type checks) keeps this consistent
    # with tree_map, which accepts subclasses such as nn.Parameter
    if isinstance(tree, dict):
        for v in tree.values():
            shapes.extend(_fetch_dims(v))
    elif isinstance(tree, (list, tuple)):
        for t in tree:
            shapes.extend(_fetch_dims(t))
    elif isinstance(tree, torch.Tensor):
        shapes.append(tree.shape)
    else:
        raise ValueError("Not supported")

    return shapes


@torch.jit.ignore
def _flat_idx_to_idx(
    flat_idx: int,
    dims: Tuple[int],
) -> Tuple[int]:
    """
    Converts an index into a flattened (row-major) tensor into the
    corresponding multi-dimensional index for a tensor of shape dims.
    """
    idx = []
    for d in reversed(dims):
        idx.append(flat_idx % d)
        flat_idx = flat_idx // d

    return tuple(reversed(idx))


@torch.jit.ignore
def _get_minimal_slice_set(
    start: Sequence[int],
    end: Sequence[int],
    dims: Sequence[int],
    start_edges: Optional[Sequence[bool]] = None,
    end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[slice]]:
    """
    Produces an ordered sequence of tensor slices that, when used in
    sequence on a tensor with shape dims, yields tensors that contain every
    leaf in the contiguous range [start, end]. Care is taken to yield a
    short sequence of slices, and perhaps even the shortest possible (I'm
    pretty sure it's the latter).

    end is INCLUSIVE.

    Args:
        start:
            Multi-dimensional index of the first leaf of the range
        end:
            Multi-dimensional index of the last leaf of the range
        dims:
            The shape being indexed
        start_edges:
            Optional precomputed edge flags for start (see below)
        end_edges:
            Optional precomputed edge flags for end (see below)
    Returns:
        A list of tuples of slices
    """
    # start_edges and end_edges both indicate whether, starting from any given
    # dimension, the start/end index is at the top/bottom edge of the
    # corresponding tensor, modeled as a tree
    def reduce_edge_list(l):
        # Suffix product: entry i ends up truthy only if it and every
        # entry after it were truthy
        tally = 1
        for i in range(len(l)):
            reversed_idx = -1 * (i + 1)
            l[reversed_idx] *= tally
            tally = l[reversed_idx]

    if start_edges is None:
        start_edges = [s == 0 for s in start]
        reduce_edge_list(start_edges)
    if end_edges is None:
        end_edges = [e == (d - 1) for e, d in zip(end, dims)]
        reduce_edge_list(end_edges)

    # Base cases. Either start/end are empty and we're done, or the final,
    # one-dimensional tensor can be simply sliced
    if len(start) == 0:
        return [tuple()]
    elif len(start) == 1:
        return [(slice(start[0], end[0] + 1),)]

    slices = []
    path = []

    # Dimensions common to start and end can be selected directly
    for s, e in zip(start, end):
        if s == e:
            path.append(slice(s, s + 1))
        else:
            break

    path = tuple(path)
    divergence_idx = len(path)

    # start == end, and we're done
    if divergence_idx == len(dims):
        return [tuple(path)]

    def upper():
        # Everything from start to the end of the subtree rooted at
        # start[divergence_idx]
        sdi = start[divergence_idx]
        return [
            path + (slice(sdi, sdi + 1),) + s for s in
            _get_minimal_slice_set(
                start[divergence_idx + 1:],
                [d - 1 for d in dims[divergence_idx + 1:]],
                dims[divergence_idx + 1:],
                start_edges=start_edges[divergence_idx + 1:],
                end_edges=[1 for _ in end_edges[divergence_idx + 1:]]
            )
        ]

    def lower():
        # Everything from the beginning of the subtree rooted at
        # end[divergence_idx] up to end
        edi = end[divergence_idx]
        return [
            path + (slice(edi, edi + 1),) + s for s in
            _get_minimal_slice_set(
                [0 for _ in start[divergence_idx + 1:]],
                end[divergence_idx + 1:],
                dims[divergence_idx + 1:],
                start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
                end_edges=end_edges[divergence_idx + 1:],
            )
        ]

    # If both start and end are at the edges of the subtree rooted at
    # divergence_idx, we can just select the whole subtree at once
    if start_edges[divergence_idx] and end_edges[divergence_idx]:
        slices.append(
            path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
        )
    # If just start is at the edge, we can grab almost all of the subtree,
    # treating only the ragged bottom edge as an edge case
    elif start_edges[divergence_idx]:
        slices.append(
            path + (slice(start[divergence_idx], end[divergence_idx]),)
        )
        slices.extend(lower())
    # Analogous to the previous case, but the top is ragged this time
    elif end_edges[divergence_idx]:
        slices.extend(upper())
        slices.append(
            path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
        )
    # If both sides of the range are ragged, we need to handle both sides
    # separately. If there's contiguous meat in between them, we can index it
    # in one big chunk
    else:
        slices.extend(upper())
        middle_ground = end[divergence_idx] - start[divergence_idx]
        if middle_ground > 1:
            slices.append(
                path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
            )
        slices.extend(lower())

    return [tuple(s) for s in slices]


@torch.jit.ignore
def _chunk_slice(
    t: torch.Tensor,
    flat_start: int,
    flat_end: int,
    no_batch_dims: int,
) -> torch.Tensor:
    """
    Equivalent to

        t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]

    but without the need for the initial reshape call, which can be
    memory-intensive in certain situations. The only reshape operations
    in this function are performed on sub-tensors that scale with
    (flat_end - flat_start), the chunk size.
    """

    batch_dims = t.shape[:no_batch_dims]
    start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
    # _get_minimal_slice_set is inclusive
    end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))

    # Get an ordered list of slices to perform
    slices = _get_minimal_slice_set(
        start_idx,
        end_idx,
        batch_dims,
    )

    sliced_tensors = [t[s] for s in slices]

    # reshape rather than view: slices of an expanded (stride-0) tensor
    # cannot always have their batch dimensions merged by view, which
    # raises. reshape is identical to view whenever a view is possible and
    # otherwise copies only the chunk-sized slice.
    return torch.cat(
        [s.reshape((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
    )


def chunk_layer(
    layer: Callable,
    inputs: Dict[str, Any],
    chunk_size: int,
    no_batch_dims: int,
    low_mem: bool = False,
) -> Any:
    """
    Implements the "chunking" procedure described in section 1.11.8.

    Layer outputs and inputs are assumed to be simple "pytrees,"
    consisting only of (arbitrarily nested) lists, tuples, and dicts with
    torch.Tensor leaves.

    Args:
        layer:
            The layer to be applied chunk-wise
        inputs:
            A (non-nested) dictionary of keyworded inputs. All leaves must
            be tensors and must share the same batch dimensions.
        chunk_size:
            The number of sub-batches per chunk. If multiple batch
            dimensions are specified, a "sub-batch" is defined as a single
            indexing of all batch dimensions simultaneously (s.t. the
            number of sub-batches is the product of the batch dimensions).
        no_batch_dims:
            How many of the initial dimensions of each input tensor can
            be considered batch dimensions.
        low_mem:
            Avoids flattening potentially large input tensors. Unnecessary
            in most cases, and is ever so slightly slower than the default
            setting.
    Returns:
        The reassembled output of the layer on the inputs.
    Raises:
        ValueError: If inputs is empty, or if the layer output is not a
            dict, tuple or tensor
    """
    if not (len(inputs) > 0):
        raise ValueError("Must provide at least one input")

    # The full batch shape is the element-wise max over all inputs, so
    # inputs with size-1 batch dimensions are treated as broadcastable
    initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
    orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

    def _prep_inputs(t):
        # TODO: make this more memory efficient. This sucks
        if not low_mem:
            # Inputs whose batch dimensions are all 1 are left unexpanded;
            # they are passed whole to every chunk below
            if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
                t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
            t = t.reshape(-1, *t.shape[no_batch_dims:])
        else:
            t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
        return t

    prepped_inputs = tensor_tree_map(_prep_inputs, inputs)

    flat_batch_dim = 1
    for d in orig_batch_dims:
        flat_batch_dim *= d

    # Ceiling division
    no_chunks = flat_batch_dim // chunk_size + (
        flat_batch_dim % chunk_size != 0
    )

    i = 0
    out = None
    for _ in range(no_chunks):
        # Chunk the input
        if not low_mem:
            select_chunk = (
                lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
            )
        else:
            select_chunk = (
                partial(
                    _chunk_slice,
                    flat_start=i,
                    flat_end=min(flat_batch_dim, i + chunk_size),
                    no_batch_dims=len(orig_batch_dims)
                )
            )

        chunks = tensor_tree_map(select_chunk, prepped_inputs)

        # Run the layer on the chunk
        output_chunk = layer(**chunks)

        # Allocate space for the output
        if out is None:
            allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
            out = tensor_tree_map(allocate, output_chunk)

        # Put the chunk in its pre-allocated space
        out_type = type(output_chunk)
        if out_type is dict:
            def assign(d1, d2):
                for k, v in d1.items():
                    if type(v) is dict:
                        assign(v, d2[k])
                    else:
                        v[i : i + chunk_size] = d2[k]

            assign(out, output_chunk)
        elif out_type is tuple:
            for x1, x2 in zip(out, output_chunk):
                x1[i : i + chunk_size] = x2
        elif out_type is torch.Tensor:
            out[i : i + chunk_size] = output_chunk
        else:
            raise ValueError("Not supported")

        i += chunk_size

    # Restore the original batch dimensions
    reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
    out = tensor_tree_map(reshape, out)

    return out

# diff --git a/openfold/utils/validation_metrics.py b/openfold/utils/validation_metrics.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..81b8da9860452030b6e9e1c2f46b8e49fdc17b75
# --- /dev/null
# +++ b/openfold/utils/validation_metrics.py
# @@ -0,0 +1,38
# @@
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def gdt(p1, p2, mask, cutoffs):
    """
    Computes a GDT-style score: the fraction of unmasked positions whose
    two coordinates lie within a distance cutoff of each other, averaged
    over all cutoffs.

    Args:
        p1:
            [*, N, 3] tensor of coordinates
        p2:
            [*, N, 3] tensor of coordinates, compared position-wise
            against p1 (no superposition is performed here)
        mask:
            [*, N] tensor selecting the positions that count
        cutoffs:
            A non-empty sequence of distance cutoffs
    Returns:
        A [*] tensor of scores. Rows whose mask sums to zero score 0
        instead of NaN.
    """
    n = torch.sum(mask, dim=-1)
    # With an all-zero mask the numerator is 0 as well; dividing by 1
    # instead of 0 yields a score of 0 rather than NaN. Rows with a
    # non-zero count are unaffected.
    safe_n = torch.where(n > 0, n, torch.ones_like(n))

    p1 = p1.float()
    p2 = p2.float()
    distances = torch.sqrt(torch.sum((p1 - p2) ** 2, dim=-1))

    scores = []
    for c in cutoffs:
        score = torch.sum((distances <= c) * mask, dim=-1) / safe_n
        scores.append(score)

    return sum(scores) / len(scores)


def gdt_ts(p1, p2, mask):
    """
    GDT score with cutoffs of 1, 2, 4 and 8 (see gdt for the arguments).
    """
    return gdt(p1, p2, mask, [1., 2., 4., 8.])


def gdt_ha(p1, p2, mask):
    """
    GDT score with the tighter cutoffs of 0.5, 1, 2 and 4 (see gdt for the
    arguments).
    """
    return gdt(p1, p2, mask, [0.5, 1., 2., 4.])