diff --git a/README.md b/README.md
index 6ddff2f305ace52cf0365e74625592c1a32ac0be..41d9bffaf15028ff141f8e5bf7f458c44510add8 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,134 @@
----
-license: gpl-3.0
----
+# P2DFlow
+
+> ## ℹ️ The version 2 of codes for P2DFlow will come soon to align with the new revised paper, and for easier use and modification (the old version can run correctly)
+
+P2DFlow is a protein ensemble generative model with SE(3) flow matching based on ESMFold, the ensembles generated by P2DFlow could aid in understanding protein functions across various scenarios.
+
+Technical details and evaluation results are provided in our paper:
+* [P2DFlow: A Protein Ensemble Generative Model with SE(3) Flow Matching](https://arxiv.org/abs/2411.17196)
+
+
+
+
+
+
+
+
+## Table of Contents
+1. [Installation](#Installation)
+2. [Prepare Dataset](#Prepare-Dataset)
+3. [Model weights](#Model-weights)
+4. [Training](#Training)
+5. [Inference](#Inference)
+6. [Evaluation](#Evaluation)
+7. [License](#License)
+8. [Citation](#Citation)
+
+
+## Installation
+In an environment with cuda 11.7, run:
+```
+conda env create -f environment.yml
+```
+To activate the environment, run:
+```
+conda activate P2DFlow
+```
+
+## Prepare Dataset
+#### (tips: If you want to use the data we have preprocessed, please go directly to `3. Process selected dataset`; if you prefer to process the data from scratch or work with your own data, please start from the beginning)
+
+#### 1. Download raw ATLAS dataset
+(i) Download the `Analysis & MDs` dataset from [ATLAS](https://www.dsimb.inserm.fr/ATLAS/), or you can use `./dataset/download.py` by running:
+```
+python ./dataset/download.py
+```
+We will use `.pdb` and `.xtc` files for the following calculation.
+
+#### 2. Calculate the 'approximate energy and select representative structures
+(i) Use `gaussian_kde` to calculate the 'approximate energy' (You need to put all files above in `./dataset`, include `ATLAS_filename.txt` for filenames of all proteins):
+```
+python ./dataset/traj_analyse.py
+```
+And you will get `traj_info.csv`.
+
+(ii) Select representative structures at equal intervals based on the 'approximate energy':
+```
+python ./dataset/md_select.py
+```
+
+#### 3. Process selected dataset
+
+(i) Download the selected dataset (or get it from the two steps above) from [Google Drive](https://drive.google.com/drive/folders/11mdVfMi2rpVn7nNG2mQAGA5sNXCKePZj?usp=sharing) whose filename is `selected_dataset.tar`, and decompress it using:
+```
+tar -xvf select_dataset.tar
+```
+(ii) Preprocess `.pdb` files to get `.pkl` files:
+```
+python ./data/process_pdb_files.py --pdb_dir ${pdb_dir} --write_dir ${write_dir}
+```
+And you will get `metadata.csv`.
+
+then compute node representation and pair representation using ESM-2 (`csv_path` is the path of `metadata.csv`):
+```
+python ./data/cal_repr.py --csv_path ${csv_path}
+```
+then compute predicted static structure using ESMFold (`csv_path` is the path of `metadata.csv`):
+```
+python ./data/cal_static_structure.py --csv_path ${csv_path}
+```
+(iii) Provide the necessary `.csv` files for training
+
+If you are using the data we have preprocessed, download the `.csv` files from [Google Drive](https://drive.google.com/drive/folders/11mdVfMi2rpVn7nNG2mQAGA5sNXCKePZj?usp=sharing) whose filenames are `train_dataset.csv` and `train_dataset_energy.csv`(they correspond to `csv_path` and `energy_csv_path` in `./configs/base.yaml` during training).
+
+Or if you are using your own data, you can get `metadata.csv` from step 3 (correspond to `csv_path` in `./configs/base.yaml` during training, and you need to split a subset from it as the train dataset), and get `traj_info.csv` from step 2 (correspond to `energy_csv_path`).
+
+
+
+## Model weights
+Download the pretrained checkpoint from [Google Drive](https://drive.google.com/drive/folders/11mdVfMi2rpVn7nNG2mQAGA5sNXCKePZj?usp=sharing) whose filename is `pretrained.ckpt`, and put it into `./weights` folder. You can use the pretrained weight for inference.
+
+
+## Training
+To train P2DFlow, firstly make sure you have prepared the dataset according to `Prepare Dataset`, and put it in the right folder, then modify `./configs/base.yaml` (especially for `csv_path` and `energy_csv_path`). After this, you can run:
+```
+python experiments/train_se3_flows.py
+```
+And you will get the checkpoints in `./ckpt`.
+
+
+## Inference
+To infer for specified protein sequence, firstly modify `./inference/valid_seq.csv` and `./configs/inference.yaml` (especially for `validset_path`), then run:
+```
+python experiments/inference_se3_flows.py
+```
+And you will get the results in `./inference_outputs/weights/`.
+
+
+## Evaluation
+To evaluate metrics related to fidelity and dynamics, specify paths in `./analysis/eval_test.py`, then run:
+```
+python ./analysis/eval_test.py
+```
+To evaluate PCA, specify paths in `./analysis/pca_analyse.py`, then run:
+```
+python ./analysis/pca_analyse.py
+```
+To draw the ramachandran plots, specify paths in `./analysis/Ramachandran_plot.py`, then run:
+```
+python ./analysis/Ramachandran_plot.py
+```
+
+## License
+This project is licensed under the terms of the GPL-3.0 license.
+
+
+## Citation
+```
+@article{jin2024p2dflow,
+ title={P2DFlow: A Protein Ensemble Generative Model with SE(3) Flow Matching},
+ author={Yaowei Jin, Qi Huang, Ziyang Song, Mingyue Zheng, Dan Teng, Qian Shi},
+ journal={arXiv preprint arXiv:2411.17196},
+ year={2024}
+}
+```
diff --git a/analysis/Ramachandran_plot.py b/analysis/Ramachandran_plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3660a4e016b22952e8ba86dbd26ed66812a0bfa
--- /dev/null
+++ b/analysis/Ramachandran_plot.py
@@ -0,0 +1,99 @@
+import MDAnalysis as mda
+import numpy as np
+from MDAnalysis.analysis.dihedrals import Ramachandran
+import os
+import re
+import pandas as pd
+import matplotlib.pyplot as plt
+
+
+def ramachandran_eval(all_paths, pdb_file, output_dir):
+ angle_results_all = []
+
+ for dirpath in all_paths:
+ pdb_path = os.path.join(dirpath,pdb_file)
+
+ u = mda.Universe(pdb_path)
+ protein = u.select_atoms('protein')
+ # print('There are {} residues in the protein'.format(len(protein.residues)))
+
+ ramachandran = Ramachandran(protein)
+ ramachandran.run()
+ angle_results = ramachandran.results.angles
+ # print(angle_results.shape)
+
+ # ramachandran.plot(color='black', marker='.')
+
+ angle_results_all.append(angle_results.reshape([-1,2]))
+
+
+ # df = pd.DataFrame(angle_results.reshape([-1,2]))
+ # df.to_csv(os.path.join(output_dir, os.path.basename(dirpath)+'_'+pdb_file.split('.')[0]+'.csv'), index=False)
+
+
+ points1 = angle_results_all[0]
+ grid_size = 360 # 网格的大小
+ x_bins = np.linspace(-180, 180, grid_size)
+ y_bins = np.linspace(-180, 180, grid_size)
+ result_tmp={}
+ for idx in range(len(angle_results_all[1:])):
+ idx = idx + 1
+ points2 = angle_results_all[idx]
+
+ # 使用2D直方图统计每组点在网格上的分布
+ hist1, _, _ = np.histogram2d(points1[:, 0], points1[:, 1], bins=[x_bins, y_bins])
+ hist2, _, _ = np.histogram2d(points2[:, 0], points2[:, 1], bins=[x_bins, y_bins])
+
+ # 将直方图转换为布尔值,表示某个网格是否有点落入
+ mask1 = hist1 > 0
+ mask2 = hist2 > 0
+
+ intersection = np.logical_and(mask1, mask2).sum()
+ all_mask2 = mask2.sum()
+ val_ratio = intersection / all_mask2
+ print(os.path.basename(all_paths[idx]), "val_ratio:", val_ratio)
+
+
+ result_tmp[os.path.basename(all_paths[idx])] = val_ratio
+ result_tmp['file'] = pdb_file
+
+ return result_tmp
+
+
+
+if __name__ == "__main__":
+ key1 = 'P2DFlow_epoch19'
+ all_paths = [
+ "/cluster/home/shiqian/frame-flow-test1/valid/evaluate/ATLAS_valid",
+ # "/cluster/home/shiqian/frame-flow-test1/valid/evaluate/esm_n_pred",
+ "/cluster/home/shiqian/frame-flow-test1/valid/evaluate/alphaflow_pred",
+ "/cluster/home/shiqian/frame-flow-test1/valid/evaluate/Str2Str_pred",
+
+ f'/cluster/home/shiqian/frame-flow-test1/valid/evaluate/{key1}',
+
+ ]
+ output_dir = '/cluster/home/shiqian/frame-flow-test1/valid/evaluate/Ramachandran'
+ os.makedirs(output_dir, exist_ok=True)
+ results={
+ 'file':[],
+ # 'esm_n_pred':[],
+ 'alphaflow_pred':[],
+ 'Str2Str_pred':[],
+
+ key1:[],
+ }
+ for file in os.listdir(all_paths[0]):
+ if re.search('\.pdb',file):
+
+ pdb_file = file
+ print(file)
+ result_tmp = ramachandran_eval(
+ all_paths=all_paths,
+ pdb_file=pdb_file,
+ output_dir=output_dir
+ )
+ for key in results.keys():
+ results[key].append(result_tmp[key])
+
+ out_total_df = pd.DataFrame(results)
+ out_total_df.to_csv(os.path.join(output_dir,f'Ramachandran_plot_validity_{key1}.csv'),index=False)
diff --git a/analysis/__pycache__/Ramachandran_plot.cpython-310.pyc b/analysis/__pycache__/Ramachandran_plot.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4ec969c89ac3f14e4082d215fa1ab4e9e48a217
Binary files /dev/null and b/analysis/__pycache__/Ramachandran_plot.cpython-310.pyc differ
diff --git a/analysis/__pycache__/merge_pred_pdb.cpython-310.pyc b/analysis/__pycache__/merge_pred_pdb.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2006cceb04da306e21b3d87a1e73548994aa25cc
Binary files /dev/null and b/analysis/__pycache__/merge_pred_pdb.cpython-310.pyc differ
diff --git a/analysis/__pycache__/metrics.cpython-310.pyc b/analysis/__pycache__/metrics.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c9fe51afc9bc53a9707b083799cd1d594fc67e8
Binary files /dev/null and b/analysis/__pycache__/metrics.cpython-310.pyc differ
diff --git a/analysis/__pycache__/utils.cpython-310.pyc b/analysis/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a3cf847b8371ad05b40a50d85f93733692824c6
Binary files /dev/null and b/analysis/__pycache__/utils.cpython-310.pyc differ
diff --git a/analysis/__pycache__/utils.cpython-38.pyc b/analysis/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7dba74aaa6f8107a87afd99f9de5c18c57001386
Binary files /dev/null and b/analysis/__pycache__/utils.cpython-38.pyc differ
diff --git a/analysis/eval_result.py b/analysis/eval_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..48ae898e37bf51709dcdba4cdfe15b7be59a8238
--- /dev/null
+++ b/analysis/eval_result.py
@@ -0,0 +1,66 @@
+import os
+import re
+import sys
+sys.path.append('./analysis')
+import argparse
+
+import pandas as pd
+from src.eval import evaluate_prediction
+from merge_pred_pdb import merge_pdb_full
+from Ramachandran_plot import ramachandran_eval
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--pred_org_dir", type=str, default="./inference_outputs/weights/pretrained/2025-03-13_10-08")
+ parser.add_argument("--valid_csv_file", type=str, default="./inference/valid_seq.csv")
+ parser.add_argument("--pred_merge_dir", type=str, default="./inference/test/pred_merge_results")
+ parser.add_argument("--target_dir", type=str, default="./inference/test/target_dir")
+ parser.add_argument("--crystal_dir", type=str, default="./inference/test/crystal_dir")
+
+ args = parser.parse_args()
+
+
+ # merge pdb
+ pred_org_dir = args.pred_org_dir
+ valid_csv_file = args.valid_csv_file
+ pred_merge_dir = args.pred_merge_dir
+ merge_pdb_full(pred_org_dir, valid_csv_file, pred_merge_dir)
+
+
+ # cal_eval
+ pred_merge_dir = args.pred_merge_dir
+ target_dir = args.target_dir
+ crystal_dir = args.crystal_dir
+ evaluate_prediction(pred_merge_dir, target_dir, crystal_dir)
+
+
+ # cal_RP
+ all_paths = [
+ args.target_dir,
+ args.pred_merge_dir,
+ ]
+ results={}
+ for file in os.listdir(all_paths[0]):
+ if re.search('\.pdb',file):
+
+ pdb_file = file
+ print(file)
+ result_tmp = ramachandran_eval(
+ all_paths=all_paths,
+ pdb_file=pdb_file,
+ output_dir=args.pred_merge_dir,
+ )
+
+ for pred_paths in all_paths[1:]:
+ key_name = os.path.basename(pred_paths)
+ if key_name is results.keys():
+ results[key_name].append(result_tmp[key_name])
+ else:
+ results[key_name] = [result_tmp[key_name]]
+
+ out_total_df = pd.DataFrame(results)
+ out_total_df.to_csv(os.path.join(args.pred_merge_dir, f'Ramachandran_plot_validity.csv'), index=False)
+ print(f"RP results saved to {os.path.join(args.pred_merge_dir, f'Ramachandran_plot_validity.csv')}")
+
+
diff --git a/analysis/merge_pred_pdb.py b/analysis/merge_pred_pdb.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac4366187ef1c19abe425ea43c99d63dcbf90812
--- /dev/null
+++ b/analysis/merge_pred_pdb.py
@@ -0,0 +1,45 @@
+import os
+import re
+import pandas as pd
+from Bio.PDB import PDBParser, PDBIO
+
+def merge_pdb(work_dir, new_file, ref_pdb):
+ parser = PDBParser()
+ structures = []
+ for pdb_dir in os.listdir(work_dir):
+ pattern=".*"+ref_pdb
+ pdb_dir_full=os.path.join(work_dir,pdb_dir)
+ if os.path.isdir(pdb_dir_full) and re.match(pattern,pdb_dir):
+ for pdb_file in os.listdir(pdb_dir_full):
+ if re.match("sample.*\.pdb",pdb_file):
+ structure = parser.get_structure(pdb_file, os.path.join(work_dir,pdb_dir,pdb_file))
+ structures.append(structure)
+
+ if len(structures) == 0:
+ return
+ print(ref_pdb,len(structures),"files")
+
+ new_structure = structures[0]
+ count = 0
+ for structure in structures[1:]:
+ for model in structure:
+ count += 1
+ # print(dir(model))
+ model.id = count
+ new_structure.add(model)
+
+ io = PDBIO()
+ io.set_structure(new_structure)
+ io.save(new_file)
+
+
+def merge_pdb_full(inference_dir_f, valid_csv, output_dir):
+ os.makedirs(output_dir,exist_ok=True)
+ valid_set = pd.read_csv(valid_csv)
+ for filename in valid_set['file']:
+ output_file = os.path.join(output_dir, filename+".pdb")
+ merge_pdb(inference_dir_f, output_file, filename)
+
+
+
+
diff --git a/analysis/metrics.py b/analysis/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..08f52fb6ac7ca7ee0547c5c76565f0ee8f5a45cd
--- /dev/null
+++ b/analysis/metrics.py
@@ -0,0 +1,54 @@
+""" Metrics. """
+import mdtraj as md
+import numpy as np
+from openfold.np import residue_constants
+from tmtools import tm_align
+from data import utils as du
+
+def calc_tm_score(pos_1, pos_2, seq_1, seq_2):
+ tm_results = tm_align(pos_1, pos_2, seq_1, seq_2)
+ return tm_results.tm_norm_chain1, tm_results.tm_norm_chain2
+
+def calc_mdtraj_metrics(pdb_path):
+ try:
+ traj = md.load(pdb_path)
+ pdb_ss = md.compute_dssp(traj, simplified=True)
+ pdb_coil_percent = np.mean(pdb_ss == 'C')
+ pdb_helix_percent = np.mean(pdb_ss == 'H')
+ pdb_strand_percent = np.mean(pdb_ss == 'E')
+ pdb_ss_percent = pdb_helix_percent + pdb_strand_percent
+ pdb_rg = md.compute_rg(traj)[0]
+ except IndexError as e:
+ print('Error in calc_mdtraj_metrics: {}'.format(e))
+ pdb_ss_percent = 0.0
+ pdb_coil_percent = 0.0
+ pdb_helix_percent = 0.0
+ pdb_strand_percent = 0.0
+ pdb_rg = 0.0
+ return {
+ 'non_coil_percent': pdb_ss_percent,
+ 'coil_percent': pdb_coil_percent,
+ 'helix_percent': pdb_helix_percent,
+ 'strand_percent': pdb_strand_percent,
+ 'radius_of_gyration': pdb_rg,
+ }
+
+def calc_aligned_rmsd(pos_1, pos_2):
+ aligned_pos_1 = du.rigid_transform_3D(pos_1, pos_2)[0]
+ return np.mean(np.linalg.norm(aligned_pos_1 - pos_2, axis=-1))
+
+def calc_ca_ca_metrics(ca_pos, bond_tol=0.1, clash_tol=1.0):
+ ca_bond_dists = np.linalg.norm(
+ ca_pos - np.roll(ca_pos, 1, axis=0), axis=-1)[1:]
+ ca_ca_dev = np.mean(np.abs(ca_bond_dists - residue_constants.ca_ca))
+ ca_ca_valid = np.mean(ca_bond_dists < (residue_constants.ca_ca + bond_tol))
+
+ ca_ca_dists2d = np.linalg.norm(
+ ca_pos[:, None, :] - ca_pos[None, :, :], axis=-1)
+ inter_dists = ca_ca_dists2d[np.where(np.triu(ca_ca_dists2d, k=0) > 0)]
+ clashes = inter_dists < clash_tol
+ return {
+ 'ca_ca_deviation': ca_ca_dev,
+ 'ca_ca_valid_percent': ca_ca_valid,
+ 'num_ca_ca_clashes': np.sum(clashes),
+ }
diff --git a/analysis/pca_analyse.py b/analysis/pca_analyse.py
new file mode 100644
index 0000000000000000000000000000000000000000..a49c04e261eedfa1935952f681327e76646cd8b4
--- /dev/null
+++ b/analysis/pca_analyse.py
@@ -0,0 +1,116 @@
+import os
+import re
+import pandas as pd
+import MDAnalysis as mda
+from MDAnalysis.analysis import pca, align, rms
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+import warnings
+import argparse
+warnings.filterwarnings("ignore")
+
+
+def cal_PCA(md_pdb_path,ref_path,pred_pdb_path,n_components = 2):
+ print("")
+ print('filename=',os.path.basename(ref_path))
+
+ u = mda.Universe(md_pdb_path, md_pdb_path)
+ u_ref = mda.Universe(ref_path, ref_path)
+
+ aligner = align.AlignTraj(u,
+ u_ref,
+ select='name CA or name C or name N',
+ in_memory=True).run()
+
+ pc = pca.PCA(u,
+ select='name CA or name C or name N',
+ align=False, mean=None,
+ # n_components=None,
+ n_components=n_components,
+ ).run()
+
+ backbone = u.select_atoms('name CA or name C or name N')
+ n_bb = len(backbone)
+ print('There are {} backbone atoms in the analysis'.format(n_bb))
+
+ for i in range(n_components):
+ print(f"Cumulated variance {i+1}: {pc.cumulated_variance[i]:.3f}")
+
+ transformed = pc.transform(backbone, n_components=n_components)
+
+ print(transformed.shape) # (3000, 2)
+
+ df = pd.DataFrame(transformed,
+ columns=['PC{}'.format(i+1) for i in range(n_components)])
+
+
+ plt.scatter(df['PC1'],df['PC2'],marker='o')
+ plt.show()
+
+ output_dir = os.path.dirname(md_pdb_path)
+ output_filename = os.path.basename(md_pdb_path).split('.')[0]
+
+ df.to_csv(os.path.join(output_dir, f'{output_filename}_md_pca.csv'))
+ plt.savefig(os.path.join(output_dir, f'{output_filename}_md_pca.png'))
+
+
+ for k,v in pred_pdb_path.items():
+ u_pred = mda.Universe(v, v)
+ aligner = align.AlignTraj(u_pred,
+ u_ref,
+ select='name CA or name C or name N',
+ in_memory=True).run()
+ pred_backbone = u_pred.select_atoms('name CA or name C or name N')
+ pred_transformed = pc.transform(pred_backbone, n_components=n_components)
+
+ df = pd.DataFrame(pred_transformed,
+ columns=['PC{}'.format(i+1) for i in range(n_components)])
+
+ plt.scatter(df['PC1'],df['PC2'],marker='o')
+ plt.show()
+
+ output_dir = os.path.dirname(v)
+ output_filename = os.path.basename(v).split('.')[0]
+ df.to_csv(os.path.join(output_dir, f'{output_filename}_{k}_pca.csv'))
+ plt.savefig(os.path.join(output_dir, f'{output_filename}_{k}_pca.png'))
+ plt.clf()
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--pred_pdb_dir", type=str, default="./inference/test/pred_merge_results")
+ parser.add_argument("--target_dir", type=str, default="./inference/test/target_dir")
+ parser.add_argument("--crystal_dir", type=str, default="./inference/test/crystal_dir")
+
+ args = parser.parse_args()
+
+
+ pred_pdb_path_org={
+ 'P2DFlow':args.pred_pdb_dir,
+ }
+ md_pdb_path_org = args.target_dir
+ ref_path_org = args.crystal_dir
+
+
+ for file in os.listdir(md_pdb_path_org):
+ if re.search('\.pdb',file):
+ pred_pdb_path={
+ 'P2DFlow':'',
+ # 'alphaflow':'',
+ # 'Str2Str':'',
+ }
+ for k,v in pred_pdb_path.items():
+ pred_pdb_path[k]=os.path.join(pred_pdb_path_org[k],file)
+ md_pdb_path = os.path.join(md_pdb_path_org, file)
+ ref_path = os.path.join(ref_path_org, file)
+ cal_PCA(md_pdb_path,ref_path,pred_pdb_path)
+
+
+
+
+
+
+
diff --git a/analysis/src/__init__.py b/analysis/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/analysis/src/__pycache__/__init__.cpython-310.pyc b/analysis/src/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97b45c4b20fb276672c934f9937f860c0d778ff5
Binary files /dev/null and b/analysis/src/__pycache__/__init__.cpython-310.pyc differ
diff --git a/analysis/src/__pycache__/__init__.cpython-37.pyc b/analysis/src/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b8975ce7ab72af6367e9cd2689079059aead9ef
Binary files /dev/null and b/analysis/src/__pycache__/__init__.cpython-37.pyc differ
diff --git a/analysis/src/__pycache__/__init__.cpython-39.pyc b/analysis/src/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f0970c21c1f8f16b2738cb805bcbefebd8380fe
Binary files /dev/null and b/analysis/src/__pycache__/__init__.cpython-39.pyc differ
diff --git a/analysis/src/__pycache__/eval.cpython-310.pyc b/analysis/src/__pycache__/eval.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..41c4561ea536a75d1dbff7dfc80355ec43e4a1b9
Binary files /dev/null and b/analysis/src/__pycache__/eval.cpython-310.pyc differ
diff --git a/analysis/src/__pycache__/eval.cpython-37.pyc b/analysis/src/__pycache__/eval.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b58d4ae3099de98bab35e261c6d866a4a1f445e
Binary files /dev/null and b/analysis/src/__pycache__/eval.cpython-37.pyc differ
diff --git a/analysis/src/__pycache__/eval.cpython-39.pyc b/analysis/src/__pycache__/eval.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ffe4d66740cbc14ac30ca9a23bceeaee9c9dca8
Binary files /dev/null and b/analysis/src/__pycache__/eval.cpython-39.pyc differ
diff --git a/analysis/src/common/__init__.py b/analysis/src/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/analysis/src/common/__pycache__/__init__.cpython-310.pyc b/analysis/src/common/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b27312e0ca59c8a4668f989b47079cfd0fc8bc3c
Binary files /dev/null and b/analysis/src/common/__pycache__/__init__.cpython-310.pyc differ
diff --git a/analysis/src/common/__pycache__/__init__.cpython-39.pyc b/analysis/src/common/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db4b9bec0684f401570540a5b6f21d4630d0bc9c
Binary files /dev/null and b/analysis/src/common/__pycache__/__init__.cpython-39.pyc differ
diff --git a/analysis/src/common/__pycache__/all_atom.cpython-39.pyc b/analysis/src/common/__pycache__/all_atom.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3c3307dc43eed2f9fc0eaf5106fb4322736a72b
Binary files /dev/null and b/analysis/src/common/__pycache__/all_atom.cpython-39.pyc differ
diff --git a/analysis/src/common/__pycache__/data_transforms.cpython-39.pyc b/analysis/src/common/__pycache__/data_transforms.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cfcf5d1844248f47fa19a4083216ec3c27b6fe5c
Binary files /dev/null and b/analysis/src/common/__pycache__/data_transforms.cpython-39.pyc differ
diff --git a/analysis/src/common/__pycache__/geo_utils.cpython-310.pyc b/analysis/src/common/__pycache__/geo_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..81e4f5ab50727c67b65472fb379bc620e47a71c0
Binary files /dev/null and b/analysis/src/common/__pycache__/geo_utils.cpython-310.pyc differ
diff --git a/analysis/src/common/__pycache__/geo_utils.cpython-39.pyc b/analysis/src/common/__pycache__/geo_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..542b78adb96abcc3e32c308aecc1b3cf4bacec66
Binary files /dev/null and b/analysis/src/common/__pycache__/geo_utils.cpython-39.pyc differ
diff --git a/analysis/src/common/__pycache__/pdb_utils.cpython-310.pyc b/analysis/src/common/__pycache__/pdb_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..400edcbb597ec5560f73df0b11c9346edb7c8b11
Binary files /dev/null and b/analysis/src/common/__pycache__/pdb_utils.cpython-310.pyc differ
diff --git a/analysis/src/common/__pycache__/pdb_utils.cpython-39.pyc b/analysis/src/common/__pycache__/pdb_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8c90e9b484d6e3bd6d6d4a700d0d6d5c9360ccd
Binary files /dev/null and b/analysis/src/common/__pycache__/pdb_utils.cpython-39.pyc differ
diff --git a/analysis/src/common/__pycache__/protein.cpython-310.pyc b/analysis/src/common/__pycache__/protein.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea3f19716658346bec4690f0f591c1eef317a1b4
Binary files /dev/null and b/analysis/src/common/__pycache__/protein.cpython-310.pyc differ
diff --git a/analysis/src/common/__pycache__/protein.cpython-39.pyc b/analysis/src/common/__pycache__/protein.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24c502a02413adb73c405278c3a974b7c07055e0
Binary files /dev/null and b/analysis/src/common/__pycache__/protein.cpython-39.pyc differ
diff --git a/analysis/src/common/__pycache__/residue_constants.cpython-310.pyc b/analysis/src/common/__pycache__/residue_constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c3496c2c6dfd4cdbc79e7a4b1ddbd1f70bd7b1b
Binary files /dev/null and b/analysis/src/common/__pycache__/residue_constants.cpython-310.pyc differ
diff --git a/analysis/src/common/__pycache__/residue_constants.cpython-39.pyc b/analysis/src/common/__pycache__/residue_constants.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0edc630ea11af65cd953ad52c22b8f11a3999efc
Binary files /dev/null and b/analysis/src/common/__pycache__/residue_constants.cpython-39.pyc differ
diff --git a/analysis/src/common/__pycache__/rigid_utils.cpython-39.pyc b/analysis/src/common/__pycache__/rigid_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f436865e10980ad8ddeed2a8963702cd54c98065
Binary files /dev/null and b/analysis/src/common/__pycache__/rigid_utils.cpython-39.pyc differ
diff --git a/analysis/src/common/__pycache__/rotation3d.cpython-39.pyc b/analysis/src/common/__pycache__/rotation3d.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ddea4ceb14e6a9aae905da01b8af85006fc7064
Binary files /dev/null and b/analysis/src/common/__pycache__/rotation3d.cpython-39.pyc differ
diff --git a/analysis/src/common/all_atom.py b/analysis/src/common/all_atom.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6b22993bdbb680c3c583e800eb0741b65fd0c36
--- /dev/null
+++ b/analysis/src/common/all_atom.py
@@ -0,0 +1,219 @@
+"""
+Utilities for calculating all atom representations.
+"""
+
+import torch
+
+from src.common import residue_constants as rc
+from src.common.data_transforms import atom37_to_torsion_angles
+from src.common.rigid_utils import Rigid, Rotation
+
+
+# Residue Constants from OpenFold/AlphaFold2.
+IDEALIZED_POS37 = torch.tensor(rc.restype_atom37_rigid_group_positions)
+IDEALIZED_POS37_MASK = torch.any(IDEALIZED_POS37, axis=-1)
+IDEALIZED_POS = torch.tensor(rc.restype_atom14_rigid_group_positions)
+DEFAULT_FRAMES = torch.tensor(rc.restype_rigid_group_default_frame)
+ATOM_MASK = torch.tensor(rc.restype_atom14_mask)
+GROUP_IDX = torch.tensor(rc.restype_atom14_to_rigid_group)
+
+
+def torsion_angles_to_frames(
+ r: Rigid,
+ alpha: torch.Tensor,
+ aatype: torch.Tensor,
+):
+ # [*, N, 8, 4, 4]
+ default_4x4 = DEFAULT_FRAMES[aatype, ...].to(r.device)
+
+ # [*, N, 8] transformations, i.e.
+ # One [*, N, 8, 3, 3] rotation matrix and
+ # One [*, N, 8, 3] translation matrix
+ default_r = r.from_tensor_4x4(default_4x4)
+
+ bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
+ bb_rot[..., 1] = 1
+
+ # [*, N, 8, 2]
+ alpha = torch.cat(
+ [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
+ )
+
+ # [*, N, 8, 3, 3]
+ # Produces rotation matrices of the form:
+ # [
+ # [1, 0 , 0 ],
+ # [0, a_2,-a_1],
+ # [0, a_1, a_2]
+ # ]
+ # This follows the original code rather than the supplement, which uses
+ # different indices.
+
+ all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
+ all_rots[..., 0, 0] = 1
+ all_rots[..., 1, 1] = alpha[..., 1]
+ all_rots[..., 1, 2] = -alpha[..., 0]
+ all_rots[..., 2, 1:] = alpha
+
+ all_rots = Rigid(Rotation(rot_mats=all_rots), None)
+
+ all_frames = default_r.compose(all_rots)
+
+ chi2_frame_to_frame = all_frames[..., 5]
+ chi3_frame_to_frame = all_frames[..., 6]
+ chi4_frame_to_frame = all_frames[..., 7]
+
+ chi1_frame_to_bb = all_frames[..., 4]
+ chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
+ chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
+ chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
+
+ all_frames_to_bb = Rigid.cat(
+ [
+ all_frames[..., :5],
+ chi2_frame_to_bb.unsqueeze(-1),
+ chi3_frame_to_bb.unsqueeze(-1),
+ chi4_frame_to_bb.unsqueeze(-1),
+ ],
+ dim=-1,
+ )
+
+ all_frames_to_global = r[..., None].compose(all_frames_to_bb)
+
+ return all_frames_to_global
+
+
+def prot_to_torsion_angles(aatype, atom37, atom37_mask):
+ """Calculate torsion angle features from protein features."""
+ prot_feats = {
+ 'aatype': aatype,
+ 'all_atom_positions': atom37,
+ 'all_atom_mask': atom37_mask,
+ }
+ torsion_angles_feats = atom37_to_torsion_angles()(prot_feats)
+ torsion_angles = torsion_angles_feats['torsion_angles_sin_cos']
+ torsion_mask = torsion_angles_feats['torsion_angles_mask']
+ return torsion_angles, torsion_mask
+
+
+def frames_to_atom14_pos(
+ r: Rigid,
+ aatype: torch.Tensor,
+ ):
+ """Convert frames to their idealized all atom representation.
+
+ Args:
+ r: All rigid groups. [..., N, 8, 3]
+ aatype: Residue types. [..., N]
+
+ Returns:
+
+ """
+
+ # [*, N, 14]
+ group_mask = GROUP_IDX[aatype, ...]
+
+ # [*, N, 14, 8]
+ group_mask = torch.nn.functional.one_hot(
+ group_mask,
+ num_classes=DEFAULT_FRAMES.shape[-3],
+ ).to(r.device)
+
+ # [*, N, 14, 8]
+ t_atoms_to_global = r[..., None, :] * group_mask
+
+ # [*, N, 14]
+ t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
+ lambda x: torch.sum(x, dim=-1)
+ )
+
+ # [*, N, 14, 1]
+ frame_atom_mask = ATOM_MASK[aatype, ...].unsqueeze(-1).to(r.device)
+
+ # [*, N, 14, 3]
+ frame_null_pos = IDEALIZED_POS[aatype, ...].to(r.device)
+ pred_positions = t_atoms_to_global.apply(frame_null_pos)
+ pred_positions = pred_positions * frame_atom_mask
+
+ return pred_positions
+
+
+def compute_backbone(bb_rigids, psi_torsions, aatype=None, device=None):
+ if device is None:
+ device = bb_rigids.device
+
+ torsion_angles = torch.tile(
+ psi_torsions[..., None, :],
+ tuple([1 for _ in range(len(bb_rigids.shape))]) + (7, 1)
+ ).to(device)
+
+ # aatype must be on cpu for initializing the tensor by indexing
+ if aatype is None:
+ aatype = torch.zeros_like(bb_rigids).cpu().long()
+ else:
+ aatype = aatype.cpu()
+
+
+ all_frames = torsion_angles_to_frames(
+ bb_rigids,
+ torsion_angles,
+ aatype,
+ )
+ atom14_pos = frames_to_atom14_pos(
+ all_frames,
+ aatype,
+ )
+ atom37_bb_pos = torch.zeros(bb_rigids.shape + (37, 3), device=device)
+ # atom14 bb order = ['N', 'CA', 'C', 'O', 'CB']
+ # atom37 bb order = ['N', 'CA', 'C', 'CB', 'O']
+ atom37_bb_pos[..., :3, :] = atom14_pos[..., :3, :]
+ atom37_bb_pos[..., 3, :] = atom14_pos[..., 4, :]
+ atom37_bb_pos[..., 4, :] = atom14_pos[..., 3, :]
+ atom37_mask = torch.any(atom37_bb_pos, axis=-1)
+ return atom37_bb_pos, atom37_mask, aatype.to(device), atom14_pos
+
+
+def calculate_neighbor_angles(R_ac, R_ab):
+ """Calculate angles between atoms c <- a -> b.
+
+ Parameters
+ ----------
+ R_ac: Tensor, shape = (N,3)
+ Vector from atom a to c.
+ R_ab: Tensor, shape = (N,3)
+ Vector from atom a to b.
+
+ Returns
+ -------
+ angle_cab: Tensor, shape = (N,)
+ Angle between atoms c <- a -> b.
+ """
+ # cos(alpha) = (u * v) / (|u|*|v|)
+ x = torch.sum(R_ac * R_ab, dim=1) # shape = (N,)
+ # sin(alpha) = |u x v| / (|u|*|v|)
+ y = torch.cross(R_ac, R_ab).norm(dim=-1) # shape = (N,)
+ # avoid that for y == (0,0,0) the gradient wrt. y becomes NaN
+ y = torch.max(y, torch.tensor(1e-9))
+ angle = torch.atan2(y, x)
+ return angle
+
+
+def vector_projection(R_ab, P_n):
+ """
+ Project the vector R_ab onto a plane with normal vector P_n.
+
+ Parameters
+ ----------
+ R_ab: Tensor, shape = (N,3)
+ Vector from atom a to b.
+ P_n: Tensor, shape = (N,3)
+ Normal vector of a plane onto which to project R_ab.
+
+ Returns
+ -------
+ R_ab_proj: Tensor, shape = (N,3)
+ Projected vector (orthogonal to P_n).
+ """
+ a_x_b = torch.sum(R_ab * P_n, dim=-1)
+ b_x_b = torch.sum(P_n * P_n, dim=-1)
+ return R_ab - (a_x_b / b_x_b)[:, None] * P_n
diff --git a/analysis/src/common/data_transforms.py b/analysis/src/common/data_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..85133bac06721da875d25369884bdc8021b63d2e
--- /dev/null
+++ b/analysis/src/common/data_transforms.py
@@ -0,0 +1,1194 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+from functools import reduce, wraps
+from operator import add
+
+import numpy as np
+import torch
+
+from src.common import residue_constants as rc
+from src.common.rigid_utils import Rotation, Rigid
+from src.utils.tensor_utils import (
+ tree_map,
+ tensor_tree_map,
+ batched_gather,
+)
+
+NUM_RES = "num residues placeholder"
+NUM_MSA_SEQ = "msa placeholder"
+NUM_EXTRA_SEQ = "extra msa placeholder"
+NUM_TEMPLATES = "num templates placeholder"
+
+MSA_FEATURE_NAMES = [
+ "msa",
+ "deletion_matrix",
+ "msa_mask",
+ "msa_row_mask",
+ "bert_mask",
+ "true_msa",
+]
+
+
+def cast_to_64bit_ints(protein):
+ # We keep all ints as int64
+ for k, v in protein.items():
+ if v.dtype == torch.int32:
+ protein[k] = v.type(torch.int64)
+
+ return protein
+
+
+def make_one_hot(x, num_classes):
+ x_one_hot = torch.zeros(*x.shape, num_classes)
+ x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
+ return x_one_hot
+
+
+def make_seq_mask(protein):
+ protein["seq_mask"] = torch.ones(
+ protein["aatype"].shape, dtype=torch.float32
+ )
+ return protein
+
+
+def make_template_mask(protein):
+ protein["template_mask"] = torch.ones(
+ protein["template_aatype"].shape[0], dtype=torch.float32
+ )
+ return protein
+
+
+def curry1(f):
+ """Supply all arguments but the first."""
+ @wraps(f)
+ def fc(*args, **kwargs):
+ return lambda x: f(x, *args, **kwargs)
+
+ return fc
+
+
+def make_all_atom_aatype(protein):
+ protein["all_atom_aatype"] = protein["aatype"]
+ return protein
+
+
+def fix_templates_aatype(protein):
+ # Map one-hot to indices
+ num_templates = protein["template_aatype"].shape[0]
+ if(num_templates > 0):
+ protein["template_aatype"] = torch.argmax(
+ protein["template_aatype"], dim=-1
+ )
+ # Map hhsearch-aatype to our aatype.
+ new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
+ new_order = torch.tensor(new_order_list, dtype=torch.int64).expand(
+ num_templates, -1
+ )
+ protein["template_aatype"] = torch.gather(
+ new_order, 1, index=protein["template_aatype"]
+ )
+
+ return protein
+
+
+def correct_msa_restypes(protein):
+ """Correct MSA restype to have the same order as rc."""
+ new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
+ new_order = torch.tensor(
+ [new_order_list] * protein["msa"].shape[1], dtype=protein["msa"].dtype
+ ).transpose(0, 1)
+ protein["msa"] = torch.gather(new_order, 0, protein["msa"])
+
+ perm_matrix = np.zeros((22, 22), dtype=np.float32)
+ perm_matrix[range(len(new_order_list)), new_order_list] = 1.0
+
+ for k in protein:
+ if "profile" in k:
+ num_dim = protein[k].shape.as_list()[-1]
+ assert num_dim in [
+ 20,
+ 21,
+ 22,
+ ], "num_dim for %s out of expected range: %s" % (k, num_dim)
+ protein[k] = torch.dot(protein[k], perm_matrix[:num_dim, :num_dim])
+
+ return protein
+
+
+def squeeze_features(protein):
+ """Remove singleton and repeated dimensions in protein features."""
+ protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
+ for k in [
+ "domain_name",
+ "msa",
+ "num_alignments",
+ "seq_length",
+ "sequence",
+ "superfamily",
+ "deletion_matrix",
+ "resolution",
+ "between_segment_residues",
+ "residue_index",
+ "template_all_atom_mask",
+ ]:
+ if k in protein:
+ final_dim = protein[k].shape[-1]
+ if isinstance(final_dim, int) and final_dim == 1:
+ if torch.is_tensor(protein[k]):
+ protein[k] = torch.squeeze(protein[k], dim=-1)
+ else:
+ protein[k] = np.squeeze(protein[k], axis=-1)
+
+ for k in ["seq_length", "num_alignments"]:
+ if k in protein:
+ protein[k] = protein[k][0]
+
+ return protein
+
+
+@curry1
+def randomly_replace_msa_with_unknown(protein, replace_proportion):
+ """Replace a portion of the MSA with 'X'."""
+ msa_mask = torch.rand(protein["msa"].shape) < replace_proportion
+ x_idx = 20
+ gap_idx = 21
+ msa_mask = torch.logical_and(msa_mask, protein["msa"] != gap_idx)
+ protein["msa"] = torch.where(
+ msa_mask,
+ torch.ones_like(protein["msa"]) * x_idx,
+ protein["msa"]
+ )
+ aatype_mask = torch.rand(protein["aatype"].shape) < replace_proportion
+
+ protein["aatype"] = torch.where(
+ aatype_mask,
+ torch.ones_like(protein["aatype"]) * x_idx,
+ protein["aatype"],
+ )
+ return protein
+
+
+@curry1
+def sample_msa(protein, max_seq, keep_extra, seed=None):
+ """Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
+ num_seq = protein["msa"].shape[0]
+ g = torch.Generator(device=protein["msa"].device)
+ if seed is not None:
+ g.manual_seed(seed)
+ shuffled = torch.randperm(num_seq - 1, generator=g) + 1
+ index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
+ num_sel = min(max_seq, num_seq)
+ sel_seq, not_sel_seq = torch.split(
+ index_order, [num_sel, num_seq - num_sel]
+ )
+
+ for k in MSA_FEATURE_NAMES:
+ if k in protein:
+ if keep_extra:
+ protein["extra_" + k] = torch.index_select(
+ protein[k], 0, not_sel_seq
+ )
+ protein[k] = torch.index_select(protein[k], 0, sel_seq)
+
+ return protein
+
+
+@curry1
+def add_distillation_flag(protein, distillation):
+ protein['is_distillation'] = distillation
+ return protein
+
+@curry1
+def sample_msa_distillation(protein, max_seq):
+ if(protein["is_distillation"] == 1):
+ protein = sample_msa(max_seq, keep_extra=False)(protein)
+ return protein
+
+
+@curry1
+def crop_extra_msa(protein, max_extra_msa):
+ num_seq = protein["extra_msa"].shape[0]
+ num_sel = min(max_extra_msa, num_seq)
+ select_indices = torch.randperm(num_seq)[:num_sel]
+ for k in MSA_FEATURE_NAMES:
+ if "extra_" + k in protein:
+ protein["extra_" + k] = torch.index_select(
+ protein["extra_" + k], 0, select_indices
+ )
+
+ return protein
+
+
+def delete_extra_msa(protein):
+ for k in MSA_FEATURE_NAMES:
+ if "extra_" + k in protein:
+ del protein["extra_" + k]
+ return protein
+
+
+# Not used in inference
+@curry1
+def block_delete_msa(protein, config):
+ num_seq = protein["msa"].shape[0]
+ block_num_seq = torch.floor(
+ torch.tensor(num_seq, dtype=torch.float32)
+ * config.msa_fraction_per_block
+ ).to(torch.int32)
+
+ if config.randomize_num_blocks:
+ nb = torch.distributions.uniform.Uniform(
+ 0, config.num_blocks + 1
+ ).sample()
+ else:
+ nb = config.num_blocks
+
+ del_block_starts = torch.distributions.Uniform(0, num_seq).sample(nb)
+ del_blocks = del_block_starts[:, None] + torch.range(block_num_seq)
+ del_blocks = torch.clip(del_blocks, 0, num_seq - 1)
+ del_indices = torch.unique(torch.sort(torch.reshape(del_blocks, [-1])))[0]
+
+ # Make sure we keep the original sequence
+ combined = torch.cat((torch.range(1, num_seq)[None], del_indices[None]))
+ uniques, counts = combined.unique(return_counts=True)
+ difference = uniques[counts == 1]
+ intersection = uniques[counts > 1]
+ keep_indices = torch.squeeze(difference, 0)
+
+ for k in MSA_FEATURE_NAMES:
+ if k in protein:
+ protein[k] = torch.gather(protein[k], keep_indices)
+
+ return protein
+
+
+@curry1
+def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
+ weights = torch.cat(
+ [torch.ones(21), gap_agreement_weight * torch.ones(1), torch.zeros(1)],
+ 0,
+ )
+
+ # Make agreement score as weighted Hamming distance
+ msa_one_hot = make_one_hot(protein["msa"], 23)
+ sample_one_hot = protein["msa_mask"][:, :, None] * msa_one_hot
+ extra_msa_one_hot = make_one_hot(protein["extra_msa"], 23)
+ extra_one_hot = protein["extra_msa_mask"][:, :, None] * extra_msa_one_hot
+
+ num_seq, num_res, _ = sample_one_hot.shape
+ extra_num_seq, _, _ = extra_one_hot.shape
+
+ # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
+ # in an optimized fashion to avoid possible memory or computation blowup.
+ agreement = torch.matmul(
+ torch.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
+ torch.reshape(
+ sample_one_hot * weights, [num_seq, num_res * 23]
+ ).transpose(0, 1),
+ )
+
+ # Assign each sequence in the extra sequences to the closest MSA sample
+ protein["extra_cluster_assignment"] = torch.argmax(agreement, dim=1).to(
+ torch.int64
+ )
+
+ return protein
+
+
+def unsorted_segment_sum(data, segment_ids, num_segments):
+ """
+ Computes the sum along segments of a tensor. Similar to
+ tf.unsorted_segment_sum, but only supports 1-D indices.
+
+ :param data: A tensor whose segments are to be summed.
+ :param segment_ids: The 1-D segment indices tensor.
+ :param num_segments: The number of segments.
+ :return: A tensor of same data type as the data argument.
+ """
+ assert (
+ len(segment_ids.shape) == 1 and
+ segment_ids.shape[0] == data.shape[0]
+ )
+ segment_ids = segment_ids.view(
+ segment_ids.shape[0], *((1,) * len(data.shape[1:]))
+ )
+ segment_ids = segment_ids.expand(data.shape)
+ shape = [num_segments] + list(data.shape[1:])
+ tensor = torch.zeros(*shape).scatter_add_(0, segment_ids, data.float())
+ tensor = tensor.type(data.dtype)
+ return tensor
+
+
+@curry1
+def summarize_clusters(protein):
+ """Produce profile and deletion_matrix_mean within each cluster."""
+ num_seq = protein["msa"].shape[0]
+
+ def csum(x):
+ return unsorted_segment_sum(
+ x, protein["extra_cluster_assignment"], num_seq
+ )
+
+ mask = protein["extra_msa_mask"]
+ mask_counts = 1e-6 + protein["msa_mask"] + csum(mask) # Include center
+
+ msa_sum = csum(mask[:, :, None] * make_one_hot(protein["extra_msa"], 23))
+ msa_sum += make_one_hot(protein["msa"], 23) # Original sequence
+ protein["cluster_profile"] = msa_sum / mask_counts[:, :, None]
+ del msa_sum
+
+ del_sum = csum(mask * protein["extra_deletion_matrix"])
+ del_sum += protein["deletion_matrix"] # Original sequence
+ protein["cluster_deletion_mean"] = del_sum / mask_counts
+ del del_sum
+
+ return protein
+
+
+def make_msa_mask(protein):
+ """Mask features are all ones, but will later be zero-padded."""
+ protein["msa_mask"] = torch.ones(protein["msa"].shape, dtype=torch.float32)
+ protein["msa_row_mask"] = torch.ones(
+ (protein["msa"].shape[0]), dtype=torch.float32
+ )
+ return protein
+
+
+def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
+ """Create pseudo beta features."""
+ is_gly = torch.eq(aatype, rc.restype_order["G"])
+ ca_idx = rc.atom_order["CA"]
+ cb_idx = rc.atom_order["CB"]
+ pseudo_beta = torch.where(
+ torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
+ all_atom_positions[..., ca_idx, :],
+ all_atom_positions[..., cb_idx, :],
+ )
+
+ if all_atom_mask is not None:
+ pseudo_beta_mask = torch.where(
+ is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
+ )
+ return pseudo_beta, pseudo_beta_mask
+ else:
+ return pseudo_beta
+
+
+@curry1
+def make_pseudo_beta(protein, prefix=""):
+ """Create pseudo-beta (alpha for glycine) position and mask."""
+ assert prefix in ["", "template_"]
+ (
+ protein[prefix + "pseudo_beta"],
+ protein[prefix + "pseudo_beta_mask"],
+ ) = pseudo_beta_fn(
+ protein["template_aatype" if prefix else "aatype"],
+ protein[prefix + "all_atom_positions"],
+ protein["template_all_atom_mask" if prefix else "all_atom_mask"],
+ )
+ return protein
+
+
+@curry1
+def add_constant_field(protein, key, value):
+ protein[key] = torch.tensor(value)
+ return protein
+
+
+def shaped_categorical(probs, epsilon=1e-10):
+ ds = probs.shape
+ num_classes = ds[-1]
+ distribution = torch.distributions.categorical.Categorical(
+ torch.reshape(probs + epsilon, [-1, num_classes])
+ )
+ counts = distribution.sample()
+ return torch.reshape(counts, ds[:-1])
+
+
+def make_hhblits_profile(protein):
+ """Compute the HHblits MSA profile if not already present."""
+ if "hhblits_profile" in protein:
+ return protein
+
+ # Compute the profile for every residue (over all MSA sequences).
+ msa_one_hot = make_one_hot(protein["msa"], 22)
+
+ protein["hhblits_profile"] = torch.mean(msa_one_hot, dim=0)
+ return protein
+
+
+@curry1
+def make_masked_msa(protein, config, replace_fraction):
+ """Create data for BERT on raw MSA."""
+ # Add a random amino acid uniformly.
+ random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)
+
+ categorical_probs = (
+ config.uniform_prob * random_aa
+ + config.profile_prob * protein["hhblits_profile"]
+ + config.same_prob * make_one_hot(protein["msa"], 22)
+ )
+
+ # Put all remaining probability on [MASK] which is a new column
+ pad_shapes = list(
+ reduce(add, [(0, 0) for _ in range(len(categorical_probs.shape))])
+ )
+ pad_shapes[1] = 1
+ mask_prob = (
+ 1.0 - config.profile_prob - config.same_prob - config.uniform_prob
+ )
+ assert mask_prob >= 0.0
+ categorical_probs = torch.nn.functional.pad(
+ categorical_probs, pad_shapes, value=mask_prob
+ )
+
+ sh = protein["msa"].shape
+ mask_position = torch.rand(sh) < replace_fraction
+
+ bert_msa = shaped_categorical(categorical_probs)
+ bert_msa = torch.where(mask_position, bert_msa, protein["msa"])
+
+ # Mix real and masked MSA
+ protein["bert_mask"] = mask_position.to(torch.float32)
+ protein["true_msa"] = protein["msa"]
+ protein["msa"] = bert_msa
+
+ return protein
+
+
+@curry1
+def make_fixed_size(
+ protein,
+ shape_schema,
+ msa_cluster_size,
+ extra_msa_size,
+ num_res=0,
+ num_templates=0,
+):
+ """Guess at the MSA and sequence dimension to make fixed size."""
+ pad_size_map = {
+ NUM_RES: num_res,
+ NUM_MSA_SEQ: msa_cluster_size,
+ NUM_EXTRA_SEQ: extra_msa_size,
+ NUM_TEMPLATES: num_templates,
+ }
+
+ for k, v in protein.items():
+ # Don't transfer this to the accelerator.
+ if k == "extra_cluster_assignment":
+ continue
+ shape = list(v.shape)
+ schema = shape_schema[k]
+ msg = "Rank mismatch between shape and shape schema for"
+ assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}"
+ pad_size = [
+ pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
+ ]
+
+ padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)]
+ padding.reverse()
+ padding = list(itertools.chain(*padding))
+ if padding:
+ protein[k] = torch.nn.functional.pad(v, padding)
+ protein[k] = torch.reshape(protein[k], pad_size)
+
+ return protein
+
+
+@curry1
+def make_msa_feat(protein):
+ """Create and concatenate MSA features."""
+ # Whether there is a domain break. Always zero for chains, but keeping for
+ # compatibility with domain datasets.
+ has_break = torch.clip(
+ protein["between_segment_residues"].to(torch.float32), 0, 1
+ )
+ aatype_1hot = make_one_hot(protein["aatype"], 21)
+
+ target_feat = [
+ torch.unsqueeze(has_break, dim=-1),
+ aatype_1hot, # Everyone gets the original sequence.
+ ]
+
+ msa_1hot = make_one_hot(protein["msa"], 23)
+ has_deletion = torch.clip(protein["deletion_matrix"], 0.0, 1.0)
+ deletion_value = torch.atan(protein["deletion_matrix"] / 3.0) * (
+ 2.0 / np.pi
+ )
+
+ msa_feat = [
+ msa_1hot,
+ torch.unsqueeze(has_deletion, dim=-1),
+ torch.unsqueeze(deletion_value, dim=-1),
+ ]
+
+ if "cluster_profile" in protein:
+ deletion_mean_value = torch.atan(
+ protein["cluster_deletion_mean"] / 3.0
+ ) * (2.0 / np.pi)
+ msa_feat.extend(
+ [
+ protein["cluster_profile"],
+ torch.unsqueeze(deletion_mean_value, dim=-1),
+ ]
+ )
+
+ if "extra_deletion_matrix" in protein:
+ protein["extra_has_deletion"] = torch.clip(
+ protein["extra_deletion_matrix"], 0.0, 1.0
+ )
+ protein["extra_deletion_value"] = torch.atan(
+ protein["extra_deletion_matrix"] / 3.0
+ ) * (2.0 / np.pi)
+
+ protein["msa_feat"] = torch.cat(msa_feat, dim=-1)
+ protein["target_feat"] = torch.cat(target_feat, dim=-1)
+ return protein
+
+
+@curry1
+def select_feat(protein, feature_list):
+ return {k: v for k, v in protein.items() if k in feature_list}
+
+
+@curry1
+def crop_templates(protein, max_templates):
+ for k, v in protein.items():
+ if k.startswith("template_"):
+ protein[k] = v[:max_templates]
+ return protein
+
+
+def make_atom14_masks(protein):
+ """Construct denser atom positions (14 dimensions instead of 37)."""
+ restype_atom14_to_atom37 = []
+ restype_atom37_to_atom14 = []
+ restype_atom14_mask = []
+
+ for rt in rc.restypes:
+ atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
+ restype_atom14_to_atom37.append(
+ [(rc.atom_order[name] if name else 0) for name in atom_names]
+ )
+ atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
+ restype_atom37_to_atom14.append(
+ [
+ (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
+ for name in rc.atom_types
+ ]
+ )
+
+ restype_atom14_mask.append(
+ [(1.0 if name else 0.0) for name in atom_names]
+ )
+
+ # Add dummy mapping for restype 'UNK'
+ restype_atom14_to_atom37.append([0] * 14)
+ restype_atom37_to_atom14.append([0] * 37)
+ restype_atom14_mask.append([0.0] * 14)
+
+ restype_atom14_to_atom37 = torch.tensor(
+ restype_atom14_to_atom37,
+ dtype=torch.int32,
+ device=protein["aatype"].device,
+ )
+ restype_atom37_to_atom14 = torch.tensor(
+ restype_atom37_to_atom14,
+ dtype=torch.int32,
+ device=protein["aatype"].device,
+ )
+ restype_atom14_mask = torch.tensor(
+ restype_atom14_mask,
+ dtype=torch.float32,
+ device=protein["aatype"].device,
+ )
+ protein_aatype = protein['aatype'].to(torch.long)
+
+ # create the mapping for (residx, atom14) --> atom37, i.e. an array
+ # with shape (num_res, 14) containing the atom37 indices for this protein
+ residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
+ residx_atom14_mask = restype_atom14_mask[protein_aatype]
+
+ protein["atom14_atom_exists"] = residx_atom14_mask
+ protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
+
+ # create the gather indices for mapping back
+ residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
+ protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
+
+ # create the corresponding mask
+ restype_atom37_mask = torch.zeros(
+ [21, 37], dtype=torch.float32, device=protein["aatype"].device
+ )
+ for restype, restype_letter in enumerate(rc.restypes):
+ restype_name = rc.restype_1to3[restype_letter]
+ atom_names = rc.residue_atoms[restype_name]
+ for atom_name in atom_names:
+ atom_type = rc.atom_order[atom_name]
+ restype_atom37_mask[restype, atom_type] = 1
+
+ residx_atom37_mask = restype_atom37_mask[protein_aatype]
+ protein["atom37_atom_exists"] = residx_atom37_mask
+
+ return protein
+
+
+def make_atom14_masks_np(batch):
+ batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
+ out = make_atom14_masks(batch)
+ out = tensor_tree_map(lambda t: np.array(t), out)
+ return out
+
+
+def make_atom14_positions(protein):
+ """Constructs denser atom positions (14 dimensions instead of 37)."""
+ residx_atom14_mask = protein["atom14_atom_exists"]
+ residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]
+
+ # Create a mask for known ground truth positions.
+ residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
+ protein["all_atom_mask"],
+ residx_atom14_to_atom37,
+ dim=-1,
+ no_batch_dims=len(protein["all_atom_mask"].shape[:-1]),
+ )
+
+ # Gather the ground truth positions.
+ residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
+ batched_gather(
+ protein["all_atom_positions"],
+ residx_atom14_to_atom37,
+ dim=-2,
+ no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
+ )
+ )
+
+ protein["atom14_atom_exists"] = residx_atom14_mask
+ protein["atom14_gt_exists"] = residx_atom14_gt_mask
+ protein["atom14_gt_positions"] = residx_atom14_gt_positions
+
+ # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
+ # alternative ground truth coordinates where the naming is swapped
+ restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
+ restype_3 += ["UNK"]
+
+ # Matrices for renaming ambiguous atoms.
+ all_matrices = {
+ res: torch.eye(
+ 14,
+ dtype=protein["all_atom_mask"].dtype,
+ device=protein["all_atom_mask"].device,
+ )
+ for res in restype_3
+ }
+ for resname, swap in rc.residue_atom_renaming_swaps.items():
+ correspondences = torch.arange(
+ 14, device=protein["all_atom_mask"].device
+ )
+ for source_atom_swap, target_atom_swap in swap.items():
+ source_index = rc.restype_name_to_atom14_names[resname].index(
+ source_atom_swap
+ )
+ target_index = rc.restype_name_to_atom14_names[resname].index(
+ target_atom_swap
+ )
+ correspondences[source_index] = target_index
+ correspondences[target_index] = source_index
+ renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
+ for index, correspondence in enumerate(correspondences):
+ renaming_matrix[index, correspondence] = 1.0
+ all_matrices[resname] = renaming_matrix
+ renaming_matrices = torch.stack(
+ [all_matrices[restype] for restype in restype_3]
+ )
+
+ # Pick the transformation matrices for the given residue sequence
+ # shape (num_res, 14, 14).
+ renaming_transform = renaming_matrices[protein["aatype"]]
+
+ # Apply it to the ground truth positions. shape (num_res, 14, 3).
+ alternative_gt_positions = torch.einsum(
+ "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
+ )
+ protein["atom14_alt_gt_positions"] = alternative_gt_positions
+
+ # Create the mask for the alternative ground truth (differs from the
+ # ground truth mask, if only one of the atoms in an ambiguous pair has a
+ # ground truth position).
+ alternative_gt_mask = torch.einsum(
+ "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
+ )
+ protein["atom14_alt_gt_exists"] = alternative_gt_mask
+
+ # Create an ambiguous atoms mask. shape: (21, 14).
+ restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
+ for resname, swap in rc.residue_atom_renaming_swaps.items():
+ for atom_name1, atom_name2 in swap.items():
+ restype = rc.restype_order[rc.restype_3to1[resname]]
+ atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
+ atom_name1
+ )
+ atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
+ atom_name2
+ )
+ restype_atom14_is_ambiguous[restype, atom_idx1] = 1
+ restype_atom14_is_ambiguous[restype, atom_idx2] = 1
+
+ # From this create an ambiguous_mask for the given sequence.
+ protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
+ protein["aatype"]
+ ]
+
+ return protein
+
+
+def atom37_to_frames(protein, eps=1e-8):
+ aatype = protein["aatype"]
+ all_atom_positions = protein["all_atom_positions"]
+ all_atom_mask = protein["all_atom_mask"]
+
+ batch_dims = len(aatype.shape[:-1])
+
+ restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
+ restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
+ restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]
+
+ for restype, restype_letter in enumerate(rc.restypes):
+ resname = rc.restype_1to3[restype_letter]
+ for chi_idx in range(4):
+ if rc.chi_angles_mask[restype][chi_idx]:
+ names = rc.chi_angles_atoms[resname][chi_idx]
+ restype_rigidgroup_base_atom_names[
+ restype, chi_idx + 4, :
+ ] = names[1:]
+
+ restype_rigidgroup_mask = all_atom_mask.new_zeros(
+ (*aatype.shape[:-1], 21, 8),
+ )
+ restype_rigidgroup_mask[..., 0] = 1
+ restype_rigidgroup_mask[..., 3] = 1
+ restype_rigidgroup_mask[..., :20, 4:] = all_atom_mask.new_tensor(
+ rc.chi_angles_mask
+ )
+
+ lookuptable = rc.atom_order.copy()
+ lookuptable[""] = 0
+ lookup = np.vectorize(lambda x: lookuptable[x])
+ restype_rigidgroup_base_atom37_idx = lookup(
+ restype_rigidgroup_base_atom_names,
+ )
+ restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
+ restype_rigidgroup_base_atom37_idx,
+ )
+ restype_rigidgroup_base_atom37_idx = (
+ restype_rigidgroup_base_atom37_idx.view(
+ *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
+ )
+ )
+
+ residx_rigidgroup_base_atom37_idx = batched_gather(
+ restype_rigidgroup_base_atom37_idx,
+ aatype,
+ dim=-3,
+ no_batch_dims=batch_dims,
+ )
+
+ base_atom_pos = batched_gather(
+ all_atom_positions,
+ residx_rigidgroup_base_atom37_idx,
+ dim=-2,
+ no_batch_dims=len(all_atom_positions.shape[:-2]),
+ )
+
+ gt_frames = Rigid.from_3_points(
+ p_neg_x_axis=base_atom_pos[..., 0, :],
+ origin=base_atom_pos[..., 1, :],
+ p_xy_plane=base_atom_pos[..., 2, :],
+ eps=eps,
+ )
+
+ group_exists = batched_gather(
+ restype_rigidgroup_mask,
+ aatype,
+ dim=-2,
+ no_batch_dims=batch_dims,
+ )
+
+ gt_atoms_exist = batched_gather(
+ all_atom_mask,
+ residx_rigidgroup_base_atom37_idx,
+ dim=-1,
+ no_batch_dims=len(all_atom_mask.shape[:-1]),
+ )
+ gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
+
+ rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
+ rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
+ rots[..., 0, 0, 0] = -1
+ rots[..., 0, 2, 2] = -1
+ rots = Rotation(rot_mats=rots)
+
+ gt_frames = gt_frames.compose(Rigid(rots, None))
+
+ restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
+ *((1,) * batch_dims), 21, 8
+ )
+ restype_rigidgroup_rots = torch.eye(
+ 3, dtype=all_atom_mask.dtype, device=aatype.device
+ )
+ restype_rigidgroup_rots = torch.tile(
+ restype_rigidgroup_rots,
+ (*((1,) * batch_dims), 21, 8, 1, 1),
+ )
+
+ for resname, _ in rc.residue_atom_renaming_swaps.items():
+ restype = rc.restype_order[rc.restype_3to1[resname]]
+ chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
+ restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
+ restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
+ restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1
+
+ residx_rigidgroup_is_ambiguous = batched_gather(
+ restype_rigidgroup_is_ambiguous,
+ aatype,
+ dim=-2,
+ no_batch_dims=batch_dims,
+ )
+
+ residx_rigidgroup_ambiguity_rot = batched_gather(
+ restype_rigidgroup_rots,
+ aatype,
+ dim=-4,
+ no_batch_dims=batch_dims,
+ )
+
+ residx_rigidgroup_ambiguity_rot = Rotation(
+ rot_mats=residx_rigidgroup_ambiguity_rot
+ )
+ alt_gt_frames = gt_frames.compose(
+ Rigid(residx_rigidgroup_ambiguity_rot, None)
+ )
+
+ gt_frames_tensor = gt_frames.to_tensor_4x4()
+ alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
+
+ protein["rigidgroups_gt_frames"] = gt_frames_tensor
+ protein["rigidgroups_gt_exists"] = gt_exists
+ protein["rigidgroups_group_exists"] = group_exists
+ protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
+ protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor
+
+ return protein
+
+
+def get_chi_atom_indices():
+ """Returns atom indices needed to compute chi angles for all residue types.
+
+ Returns:
+ A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
+ in the order specified in rc.restypes + unknown residue type
+ at the end. For chi angles which are not defined on the residue, the
+ positions indices are by default set to 0.
+ """
+ chi_atom_indices = []
+ for residue_name in rc.restypes:
+ residue_name = rc.restype_1to3[residue_name]
+ residue_chi_angles = rc.chi_angles_atoms[residue_name]
+ atom_indices = []
+ for chi_angle in residue_chi_angles:
+ atom_indices.append([rc.atom_order[atom] for atom in chi_angle])
+ for _ in range(4 - len(atom_indices)):
+ atom_indices.append(
+ [0, 0, 0, 0]
+ ) # For chi angles not defined on the AA.
+ chi_atom_indices.append(atom_indices)
+
+ chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
+
+ return chi_atom_indices
+
+
+@curry1
+def atom37_to_torsion_angles(
+ protein,
+ prefix="",
+):
+ """
+ Convert coordinates to torsion angles.
+
+ This function is extremely sensitive to floating point imprecisions
+ and should be run with double precision whenever possible.
+
+ Args:
+ Dict containing:
+ * (prefix)aatype:
+ [*, N_res] residue indices
+ * (prefix)all_atom_positions:
+ [*, N_res, 37, 3] atom positions (in atom37
+ format)
+ * (prefix)all_atom_mask:
+ [*, N_res, 37] atom position mask
+ Returns:
+ The same dictionary updated with the following features:
+
+ "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
+ Torsion angles
+ "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
+ Alternate torsion angles (accounting for 180-degree symmetry)
+ "(prefix)torsion_angles_mask" ([*, N_res, 7])
+ Torsion angles mask
+ """
+ aatype = protein[prefix + "aatype"]
+ all_atom_positions = protein[prefix + "all_atom_positions"]
+ all_atom_mask = protein[prefix + "all_atom_mask"]
+
+ aatype = torch.clamp(aatype, max=20)
+
+ pad = all_atom_positions.new_zeros(
+ [*all_atom_positions.shape[:-3], 1, 37, 3]
+ )
+ prev_all_atom_positions = torch.cat(
+ [pad, all_atom_positions[..., :-1, :, :]], dim=-3
+ )
+
+ pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
+ prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
+
+ pre_omega_atom_pos = torch.cat(
+ [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
+ dim=-2,
+ )
+ phi_atom_pos = torch.cat(
+ [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
+ dim=-2,
+ )
+ psi_atom_pos = torch.cat(
+ [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
+ dim=-2,
+ )
+
+ pre_omega_mask = torch.prod(
+ prev_all_atom_mask[..., 1:3], dim=-1
+ ) * torch.prod(all_atom_mask[..., :2], dim=-1)
+ phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
+ all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
+ )
+ psi_mask = (
+ torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
+ * all_atom_mask[..., 4]
+ )
+
+ chi_atom_indices = torch.as_tensor(
+ get_chi_atom_indices(), device=aatype.device
+ )
+
+ atom_indices = chi_atom_indices[..., aatype, :, :]
+ chis_atom_pos = batched_gather(
+ all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
+ )
+
+ chi_angles_mask = list(rc.chi_angles_mask)
+ chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
+ chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)
+
+ chis_mask = chi_angles_mask[aatype, :]
+
+ chi_angle_atoms_mask = batched_gather(
+ all_atom_mask,
+ atom_indices,
+ dim=-1,
+ no_batch_dims=len(atom_indices.shape[:-2]),
+ )
+ chi_angle_atoms_mask = torch.prod(
+ chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
+ )
+ chis_mask = chis_mask * chi_angle_atoms_mask
+
+ torsions_atom_pos = torch.cat(
+ [
+ pre_omega_atom_pos[..., None, :, :],
+ phi_atom_pos[..., None, :, :],
+ psi_atom_pos[..., None, :, :],
+ chis_atom_pos,
+ ],
+ dim=-3,
+ )
+
+ torsion_angles_mask = torch.cat(
+ [
+ pre_omega_mask[..., None],
+ phi_mask[..., None],
+ psi_mask[..., None],
+ chis_mask,
+ ],
+ dim=-1,
+ )
+
+ torsion_frames = Rigid.from_3_points(
+ torsions_atom_pos[..., 1, :],
+ torsions_atom_pos[..., 2, :],
+ torsions_atom_pos[..., 0, :],
+ eps=1e-8,
+ )
+
+ fourth_atom_rel_pos = torsion_frames.invert().apply(
+ torsions_atom_pos[..., 3, :]
+ )
+
+ torsion_angles_sin_cos = torch.stack(
+ [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
+ )
+
+ denom = torch.sqrt(
+ torch.sum(
+ torch.square(torsion_angles_sin_cos),
+ dim=-1,
+ dtype=torsion_angles_sin_cos.dtype,
+ keepdims=True,
+ )
+ + 1e-8
+ )
+ torsion_angles_sin_cos = torsion_angles_sin_cos / denom
+
+ torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
+ [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
+ )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
+
+ chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
+ rc.chi_pi_periodic,
+ )[aatype, ...]
+
+ mirror_torsion_angles = torch.cat(
+ [
+ all_atom_mask.new_ones(*aatype.shape, 3),
+ 1.0 - 2.0 * chi_is_ambiguous,
+ ],
+ dim=-1,
+ )
+
+ alt_torsion_angles_sin_cos = (
+ torsion_angles_sin_cos * mirror_torsion_angles[..., None]
+ )
+
+ protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
+ protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
+ protein[prefix + "torsion_angles_mask"] = torsion_angles_mask
+
+ return protein
+
+
+def get_backbone_frames(protein):
+ # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
+ protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
+ ..., 0, :, :
+ ]
+ protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]
+
+ return protein
+
+
+def get_chi_angles(protein):
+ dtype = protein["all_atom_mask"].dtype
+ protein["chi_angles_sin_cos"] = (
+ protein["torsion_angles_sin_cos"][..., 3:, :]
+ ).to(dtype)
+ protein["chi_mask"] = protein["torsion_angles_mask"][..., 3:].to(dtype)
+
+ return protein
+
+
+@curry1
+def random_crop_to_size(
+ protein,
+ crop_size,
+ max_templates,
+ shape_schema,
+ subsample_templates=False,
+ seed=None,
+):
+ """Crop randomly to `crop_size`, or keep as is if shorter than that."""
+ # We want each ensemble to be cropped the same way
+ g = torch.Generator(device=protein["seq_length"].device)
+ if seed is not None:
+ g.manual_seed(seed)
+
+ seq_length = protein["seq_length"]
+
+ if "template_mask" in protein:
+ num_templates = protein["template_mask"].shape[-1]
+ else:
+ num_templates = 0
+
+ # No need to subsample templates if there aren't any
+ subsample_templates = subsample_templates and num_templates
+
+ num_res_crop_size = min(int(seq_length), crop_size)
+
+ def _randint(lower, upper):
+ return int(torch.randint(
+ lower,
+ upper + 1,
+ (1,),
+ device=protein["seq_length"].device,
+ generator=g,
+ )[0])
+
+ if subsample_templates:
+ templates_crop_start = _randint(0, num_templates)
+ templates_select_indices = torch.randperm(
+ num_templates, device=protein["seq_length"].device, generator=g
+ )
+ else:
+ templates_crop_start = 0
+
+ num_templates_crop_size = min(
+ num_templates - templates_crop_start, max_templates
+ )
+
+ n = seq_length - num_res_crop_size
+ if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
+ right_anchor = n
+ else:
+ x = _randint(0, n)
+ right_anchor = n - x
+
+ num_res_crop_start = _randint(0, right_anchor)
+
+ for k, v in protein.items():
+ if k not in shape_schema or (
+ "template" not in k and NUM_RES not in shape_schema[k]
+ ):
+ continue
+
+ # randomly permute the templates before cropping them.
+ if k.startswith("template") and subsample_templates:
+ v = v[templates_select_indices]
+
+ slices = []
+ for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
+ is_num_res = dim_size == NUM_RES
+ if i == 0 and k.startswith("template"):
+ crop_size = num_templates_crop_size
+ crop_start = templates_crop_start
+ else:
+ crop_start = num_res_crop_start if is_num_res else 0
+ crop_size = num_res_crop_size if is_num_res else dim
+ slices.append(slice(crop_start, crop_start + crop_size))
+ protein[k] = v[slices]
+
+ protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
+
+ return protein
diff --git a/analysis/src/common/geo_utils.py b/analysis/src/common/geo_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..62949d45b56a9c9716a5a9cc338f85f2a6d1cdc7
--- /dev/null
+++ b/analysis/src/common/geo_utils.py
@@ -0,0 +1,155 @@
+"""
+Utility functions for geometric operations (torch only).
+"""
+import torch
+
+
+def rots_mul_vecs(m, v):
+ """(Batch) Apply rotations 'm' to vectors 'v'."""
+ return torch.stack([
+ m[..., 0, 0] * v[..., 0] + m[..., 0, 1] * v[..., 1] + m[..., 0, 2] * v[..., 2],
+ m[..., 1, 0] * v[..., 0] + m[..., 1, 1] * v[..., 1] + m[..., 1, 2] * v[..., 2],
+ m[..., 2, 0] * v[..., 0] + m[..., 2, 1] * v[..., 1] + m[..., 2, 2] * v[..., 2],
+ ], dim=-1)
+
+def distance(p, eps=1e-10):
+ """Calculate distance between a pair of points (dim=-2)."""
+ # [*, 2, 3]
+ return (eps + torch.sum((p[..., 0, :] - p[..., 1, :]) ** 2, dim=-1)) ** 0.5
+
+def dihedral(p, eps=1e-10):
+ """Calculate dihedral angle between a quadruple of points (dim=-2)."""
+ # p: [*, 4, 3]
+
+ # [*, 3]
+ u1 = p[..., 1, :] - p[..., 0, :]
+ u2 = p[..., 2, :] - p[..., 1, :]
+ u3 = p[..., 3, :] - p[..., 2, :]
+
+ # [*, 3]
+ u1xu2 = torch.cross(u1, u2, dim=-1)
+ u2xu3 = torch.cross(u2, u3, dim=-1)
+
+ # [*]
+ u2_norm = (eps + torch.sum(u2 ** 2, dim=-1)) ** 0.5
+ u1xu2_norm = (eps + torch.sum(u1xu2 ** 2, dim=-1)) ** 0.5
+ u2xu3_norm = (eps + torch.sum(u2xu3 ** 2, dim=-1)) ** 0.5
+
+ # [*]
+ cos_enc = torch.einsum('...d,...d->...', u1xu2, u2xu3)/ (u1xu2_norm * u2xu3_norm)
+ sin_enc = torch.einsum('...d,...d->...', u2, torch.cross(u1xu2, u2xu3, dim=-1)) / (u2_norm * u1xu2_norm * u2xu3_norm)
+
+ return torch.stack([cos_enc, sin_enc], dim=-1)
+
+def calc_distogram(pos: torch.Tensor, min_bin: float, max_bin: float, num_bins: int):
+ # pos: [*, L, 3]
+ dists_2d = torch.linalg.norm(
+ pos[..., :, None, :] - pos[..., None, :, :], axis=-1
+ )[..., None]
+ lower = torch.linspace(
+ min_bin,
+ max_bin,
+ num_bins,
+ device=pos.device)
+ upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1)
+ distogram = ((dists_2d > lower) * (dists_2d < upper)).type(pos.dtype)
+ return distogram
+
+def rmsd(xyz1, xyz2):
+ """ Abbreviation for squared_deviation(xyz1, xyz2, 'rmsd') """
+ return squared_deviation(xyz1, xyz2, 'rmsd')
+
+def squared_deviation(xyz1, xyz2, reduction='none'):
+ """Squared point-wise deviation between two point clouds after alignment.
+
+ Args:
+ xyz1: (*, L, 3), to be transformed
+ xyz2: (*, L, 3), the reference
+
+ Returns:
+ rmsd: (*, ) or none: (*, L)
+ """
+ map_to_np = False
+ if not torch.is_tensor(xyz1):
+ map_to_np = True
+ xyz1 = torch.as_tensor(xyz1)
+ xyz2 = torch.as_tensor(xyz2)
+
+ R, t = _find_rigid_alignment(xyz1, xyz2)
+
+ # print(R.shape, t.shape) # B, 3, 3 & B, 3
+
+ # xyz1_aligned = (R.bmm(xyz1.transpose(-2,-1))).transpose(-2,-1) + t.unsqueeze(1)
+ xyz1_aligned = (torch.matmul(R, xyz1.transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0)
+
+ sd = ((xyz1_aligned - xyz2)**2).sum(dim=-1) # (*, L)
+
+ assert sd.shape == xyz1.shape[:-1]
+ if reduction == 'none':
+ pass
+ elif reduction == 'rmsd':
+ sd = torch.sqrt(sd.mean(dim=-1))
+ else:
+ raise NotImplementedError()
+
+ sd = sd.numpy() if map_to_np else sd
+ return sd
+
+def _find_rigid_alignment(src, tgt):
+ """Inspired by https://research.pasteur.fr/en/member/guillaume-bouvier/;
+ https://gist.github.com/bougui505/e392a371f5bab095a3673ea6f4976cc8
+
+ See: https://en.wikipedia.org/wiki/Kabsch_algorithm
+
+ 2-D or 3-D registration with known correspondences.
+ Registration occurs in the zero centered coordinate system, and then
+ must be transported back.
+
+ Args:
+ src: Torch tensor of shape (*, L, 3) -- Point Cloud to Align (source)
+ tgt: Torch tensor of shape (*, L, 3) -- Reference Point Cloud (target)
+ Returns:
+ R: optimal rotation (*, 3, 3)
+ t: optimal translation (*, 3)
+
+ Test on rotation + translation and on rotation + translation + reflection
+ >>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]], dtype=torch.float)
+ >>> R0 = torch.tensor([[np.cos(60), -np.sin(60)], [np.sin(60), np.cos(60)]], dtype=torch.float)
+ >>> B = (R0.mm(A.T)).T
+ >>> t0 = torch.tensor([3., 3.])
+ >>> B += t0
+ >>> R, t = find_rigid_alignment(A, B)
+ >>> A_aligned = (R.mm(A.T)).T + t
+ >>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean())
+ >>> rmsd
+ tensor(3.7064e-07)
+ >>> B *= torch.tensor([-1., 1.])
+ >>> R, t = find_rigid_alignment(A, B)
+ >>> A_aligned = (R.mm(A.T)).T + t
+ >>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean())
+ >>> rmsd
+ tensor(3.7064e-07)
+ """
+ assert src.shape[-2] > 1
+ src_com = src.mean(dim=-2, keepdim=True)
+ tgt_com = tgt.mean(dim=-2, keepdim=True)
+ src_centered = src - src_com
+ tgt_centered = tgt - tgt_com
+
+ # Covariance matrix
+
+ # H = src_centered.transpose(-2,-1).bmm(tgt_centered) # *, 3, 3
+ H = torch.matmul(src_centered.transpose(-2,-1), tgt_centered)
+
+ U, S, V = torch.svd(H)
+ # Rotation matrix
+
+ # R = V.bmm(U.transpose(-2,-1))
+ R = torch.matmul(V, U.transpose(-2, -1))
+
+ # Translation vector
+
+ # t = tgt_com - R.bmm(src_com.transpose(-2,-1)).transpose(-2,-1)
+ t = tgt_com - torch.matmul(R, src_com.transpose(-2, -1)).transpose(-2, -1)
+
+ return R, t.squeeze(-2) # (B, 3, 3), (B, 3)
diff --git a/analysis/src/common/pdb_utils.py b/analysis/src/common/pdb_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f65b7ce8ac7d4c91c3517da48c356eab2164162
--- /dev/null
+++ b/analysis/src/common/pdb_utils.py
@@ -0,0 +1,353 @@
+"""Utility functions for operating PDB files.
+"""
+import os
+import re
+from typing import Optional
+from collections import OrderedDict
+
+import numpy as np
+from tqdm import tqdm
+
+import biotite.structure as struct
+from biotite.structure.io.pdb import PDBFile
+
+from src.common import protein
+
+
+def write_pdb_string(pdb_string: str, save_to: str):
+ """Write pdb string to file"""
+ with open(save_to, 'w') as f:
+ f.write(pdb_string)
+
+def read_pdb_to_string(pdb_file):
+ """Read PDB file as pdb string. Convenient API"""
+ with open(pdb_file, 'r') as fi:
+ pdb_string = ''
+ for line in fi:
+ if line.startswith('END') or line.startswith('TER') \
+ or line.startswith('MODEL') or line.startswith('ATOM'):
+ pdb_string += line
+ return pdb_string
+
+def merge_pdbfiles(input, output_file, verbose=True):
+ """ordered merging process of pdbs"""
+ if isinstance(input, str):
+ pdb_files = [os.path.join(input, f) for f in os.listdir(input) if f.endswith('.pdb')]
+ elif isinstance(input, list):
+ pdb_files = input
+
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
+
+ model_number = 0
+ pdb_lines = []
+ if verbose:
+ _iter = tqdm(pdb_files, desc='Merging PDBs')
+ else:
+ _iter = pdb_files
+
+ for pdb_file in _iter:
+ with open(pdb_file, 'r') as pdb:
+ lines = pdb.readlines()
+ single_model = True
+
+ for line in lines:
+ if line.startswith('MODEL') or line.startswith('ENDMDL'):
+ single_model = False
+ break
+
+ if single_model: # single model
+ model_number += 1
+ pdb_lines.append(f"MODEL {model_number}")
+ for line in lines:
+ if line.startswith('TER') or line.startswith('ATOM'):
+ pdb_lines.append(line.strip())
+ pdb_lines.append("ENDMDL")
+ else: # multiple models
+ for line in lines:
+ if line.startswith('MODEL'):
+ model_number += 1
+ if model_number > 1:
+ pdb_lines.append("ENDMDL")
+ pdb_lines.append(f"MODEL {model_number}")
+ elif line.startswith('END'):
+ continue
+ elif line.startswith('TER') or line.startswith('ATOM'):
+ pdb_lines.append(line.strip())
+ pdb_lines.append('ENDMDL')
+ pdb_lines.append('END')
+ pdb_lines = [_line.ljust(80) for _line in pdb_lines]
+ pdb_str = '\n'.join(pdb_lines) + '\n'
+ with open(output_file, 'w') as fo:
+ fo.write(pdb_str)
+
+ if verbose:
+ print(f"Merged {len(pdb_files)} PDB files into {output_file} with {model_number} models.")
+
+
+def split_pdbfile(pdb_file, output_dir=None, suffix='index', verbose=True):
+ """Split a PDB file into multiple PDB files in output_dir.
+ Preassume that each model is wrapped by 'MODEL' and 'ENDMDL'.
+ """
+ assert os.path.exists(pdb_file), f"File {pdb_file} does not exist."
+ assert suffix == 'index', 'Only support [suffix=index] for now.'
+
+ if output_dir is not None: # also dump to output_dir
+ os.makedirs(output_dir, exist_ok=True)
+ base = os.path.splitext(os.path.basename(pdb_file))[0]
+
+ i = 0
+ pdb_strs = []
+ pdb_string = ''
+ with open(pdb_file, 'r') as fi:
+ # pdb_string = ''
+ for line in fi:
+ if line.startswith('MODEL'):
+ pdb_string = ''
+ elif line.startswith('ATOM') or line.startswith('TER'):
+ pdb_string += line
+ elif line.startswith('ENDMDL') or line.startswith('END'):
+ if pdb_string == '': continue
+ pdb_string += 'END\n'
+ if output_dir is not None:
+ _save_to = os.path.join(output_dir, f'{base}_{i}.pdb') if suffix == 'index' else None
+ with open(_save_to, 'w') as fo:
+ fo.write(pdb_string)
+ pdb_strs.append(pdb_string)
+ pdb_string = ''
+ i += 1
+ else:
+ if verbose:
+ print(f"Warning: line '{line}' is not recognized. Skip.")
+ if verbose:
+ print(f">>> Split pdb {pdb_file} into {i}/{len(pdb_strs)} structures.")
+ return pdb_strs
+
+
+def stratify_sample_pdbfile(input_path, output_path, n_max_sample=1000, end_at=0, verbose=True):
+ """ """
+ assert os.path.exists(input_path), f"File {input_path} does not exist."
+ assert not os.path.exists(output_path), f"Output path {output_path} already exists."
+
+ i = 0
+ pdb_strs = []
+ with open(input_path, 'r') as fi:
+ # pdb_string = ''
+ pdb_lines_per_model = []
+ for line in fi:
+ if line.startswith('MODEL'):
+ pdb_lines_per_model = []
+ elif line.startswith('ATOM') or line.startswith('TER'):
+ pdb_lines_per_model.append(line.strip())
+ elif line.startswith('ENDMDL') or line.startswith('END'):
+ if pdb_lines_per_model == []: continue # skip empty model
+ # wrap up the model
+ pdb_lines_per_model.append('ENDMDL')
+ # Pad all lines to 80 characters.
+ pdb_lines_per_model = [_line.ljust(80) for _line in pdb_lines_per_model]
+ pdb_str_per_model = '\n'.join(pdb_lines_per_model) + '\n' # Add terminating newline.
+ pdb_strs.append(pdb_str_per_model)
+ # reset
+ pdb_lines_per_model = []
+ i += 1
+ else:
+ if verbose:
+ print(f"Warning: line '{line}' is not recognized. Skip.")
+ if end_at > 0 and i > end_at:
+ break
+ end = end_at if end_at > 0 else len(pdb_strs)
+
+ # sample evenly
+ if end > n_max_sample:
+ interleave_step = int(end // n_max_sample) # floor
+ sampled_pdb_strs = pdb_strs[:end][::interleave_step][:n_max_sample]
+ else:
+ sampled_pdb_strs = pdb_strs[:end]
+
+ output_str = ''
+ for i, pdb_str in enumerate(sampled_pdb_strs): # renumber models
+ output_str += f"MODEL {i+1}".ljust(80) + '\n'
+ output_str += pdb_str
+ output_str = output_str + ('END'.ljust(80) + '\n')
+
+ write_pdb_string(output_str, save_to=output_path)
+ if verbose:
+ print(f">>> Split pdb {input_path} into {len(sampled_pdb_strs)}/{n_max_sample} structures.")
+ return
+
+
+def protein_with_default_params(
+ atom_positions: np.ndarray,
+ atom_mask: np.ndarray,
+ aatype: Optional[np.ndarray] = None,
+ b_factors: Optional[np.ndarray] = None,
+ chain_index: Optional[np.ndarray] = None,
+ residue_index: Optional[np.ndarray] = None,
+):
+ assert atom_positions.ndim == 3
+ assert atom_positions.shape[-1] == 3
+ assert atom_positions.shape[-2] == 37
+ n = atom_positions.shape[0]
+ sqz = lambda x: np.squeeze(x) if x.shape[0] == 1 and len(x.shape) > 1 else x
+
+ residue_index = np.arange(n) + 1 if residue_index is None else sqz(residue_index)
+ chain_index = np.zeros(n) if chain_index is None else sqz(chain_index)
+ b_factors = np.zeros([n, 37]) if b_factors is None else sqz(b_factors)
+ aatype = np.zeros(n, dtype=int) if aatype is None else sqz(aatype)
+
+ return protein.Protein(
+ atom_positions=atom_positions,
+ atom_mask=atom_mask,
+ aatype=aatype,
+ residue_index=residue_index,
+ chain_index=chain_index,
+ b_factors=b_factors
+ )
+
+def atom37_to_pdb(
+ save_to: str,
+ atom_positions: np.ndarray,
+ aatype: Optional[np.ndarray] = None,
+ b_factors: Optional[np.ndarray] = None,
+ chain_index: Optional[np.ndarray] = None,
+ residue_index: Optional[np.ndarray] = None,
+ overwrite: bool = False,
+ no_indexing: bool = True,
+):
+ # configure save path
+ if overwrite:
+ max_existing_idx = 0
+ else:
+ file_dir = os.path.dirname(save_to)
+ file_name = os.path.basename(save_to).strip('.pdb')
+ existing_files = [x for x in os.listdir(file_dir) if file_name in x]
+ max_existing_idx = max([
+ int(re.findall(r'_(\d+).pdb', x)[0]) for x in existing_files if re.findall(r'_(\d+).pdb', x)
+ if re.findall(r'_(\d+).pdb', x)] + [0])
+ if not no_indexing:
+ save_to = save_to.replace('.pdb', '') + f'_{max_existing_idx+1}.pdb'
+ else:
+ save_to = save_to
+
+ with open(save_to, 'w') as f:
+ if atom_positions.ndim == 4:
+ for mi, pos37 in enumerate(atom_positions):
+ atom_mask = np.sum(np.abs(pos37), axis=-1) > 1e-7
+ prot = protein_with_default_params(
+ pos37, atom_mask, aatype=aatype, b_factors=b_factors,
+ chain_index=chain_index, residue_index=residue_index
+ )
+ pdb_str = protein.to_pdb(prot, model=mi+1, add_end=False)
+ f.write(pdb_str)
+ elif atom_positions.ndim == 3:
+ atom_mask = np.sum(np.abs(atom_positions), axis=-1) > 1e-7
+ prot = protein_with_default_params(
+ atom_positions, atom_mask, aatype=aatype, b_factors=b_factors,
+ chain_index=chain_index, residue_index=residue_index
+ )
+ pdb_str = protein.to_pdb(prot, model=1, add_end=False)
+ f.write(pdb_str)
+ else:
+ raise ValueError(f'Invalid positions shape {atom_positions.shape}')
+ f.write('END')
+
+ return save_to
+
+def extract_backbone_coords_from_pdb(pdb_path: str, target_atoms: Optional[list] = ["CA"]):
+ structure = PDBFile.read(pdb_path)
+ structure_list = structure.get_structure()
+
+ coords_list = []
+ for b_idx in range(structure.get_model_count()):
+ chain = structure_list[b_idx]
+
+ backbone_atoms = chain[struct.filter_backbone(chain)] # This includes the “N”, “CA” and “C” atoms of amino acids.
+ ret_coords = OrderedDict()
+ # init dict
+ for k in target_atoms:
+ ret_coords[k] = []
+
+ for c in backbone_atoms:
+ if c.atom_name in ret_coords:
+ ret_coords[c.atom_name].append(c.coord)
+
+ ret_coords = [np.vstack(v) for k,v in ret_coords.items()]
+ if len(target_atoms) == 1:
+ ret_coords = ret_coords[0] # L, 3
+ else:
+ ret_coords = np.stack(ret_coords, axis=1) # L, na, 3
+
+ coords_list.append(ret_coords)
+
+ coords_list = np.stack(coords_list, axis=0) # B, L, na, 3 or B, L, 3 (ca only)
+ return coords_list
+
+
+def extract_backbone_coords_from_pdb_dir(pdb_dir: str):
+ return np.concatenate([
+ extract_backbone_coords_from_pdb(os.path.join(pdb_dir, f))
+ for f in os.listdir(pdb_dir) if f.endswith('.pdb')
+ ], axis=0)
+
+def extract_backbone_coords_from_npy(npy_path: str):
+ return np.load(npy_path)
+
+
+def extract_backbone_coords(input_path: str,
+ max_n_model: Optional[int] = None,
+):
+ """Extract backbone coordinates from PDB file.
+
+ Args:
+ input_path (str): The path to the PDB file.
+ ca_only (bool): Whether to extract only CA coordinates.
+ max_n_model (int): The maximum number of models to extract.
+ """
+ assert os.path.exists(input_path), f"File {input_path} does not exist."
+ if input_path.endswith('.pdb'):
+ coords = extract_backbone_coords_from_pdb(input_path)
+ elif input_path.endswith('.npy'):
+ coords = extract_backbone_coords_from_npy(input_path)
+ elif os.path.isdir(input_path):
+ coords = extract_backbone_coords_from_pdb_dir(input_path)
+ else:
+ raise ValueError(f"Unrecognized input path {input_path}.")
+
+ if max_n_model is not None and len(coords) > max_n_model > 0:
+ coords = coords[:max_n_model]
+ return coords
+
+
+
+if __name__ == '__main__':
+ import argparse
+ def get_argparser():
+ parser = argparse.ArgumentParser(description='Main script for pdb processing.')
+ parser.add_argument("input", type=str, help="The generic path to sampled pdb directory / pdb file.")
+ parser.add_argument("-m", "--mode", type=str, help="The mode of processing.",
+ default="split")
+ parser.add_argument("-o", "--output", type=str, help="The output directory for processed pdb files.",
+ default=None)
+
+ args = parser.parse_args()
+ return args
+ args = get_argparser()
+
+ # ad hoc functions
+ def split_pdbs(args):
+ os.makedirs(args.output, exist_ok=True)
+ _ = split_pdbfile(pdb_file=args.input,
+ output_dir=args.output)
+
+ def merge_pdbs(args):
+ output = args.output or f"{args.input}_all.pdb"
+ merge_pdbfiles(input=args.input,
+ output_file=output)
+
+ if args.mode == "split":
+ split_pdbs(args)
+ elif args.mode == "merge":
+ merge_pdbs(args)
+ elif args.mode == "stratify":
+ stratify_sample_pdbfile(input_path=args.input, output_path=args.output)
+ else:
+ raise ValueError(f"Unrecognized mode {args.mode}.")
\ No newline at end of file
diff --git a/analysis/src/common/protein.py b/analysis/src/common/protein.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a4ba6b66e48c956e7c966f60fc0819865d81eaf
--- /dev/null
+++ b/analysis/src/common/protein.py
@@ -0,0 +1,289 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Protein data type."""
+import dataclasses
+import io
+from typing import Any, Mapping, Optional
+
+from Bio.PDB import PDBParser
+import numpy as np
+
+from src.common import residue_constants
+
+
+FeatureDict = Mapping[str, np.ndarray]
+ModelOutput = Mapping[str, Any] # Is a nested dict.
+
+# Complete sequence of chain IDs supported by the PDB format.
+PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
+PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62.
+
+
+@dataclasses.dataclass(frozen=True)
+class Protein:
+ """Protein structure representation."""
+
+ # Cartesian coordinates of atoms in angstroms. The atom types correspond to
+ # residue_constants.atom_types, i.e. the first three are N, CA, CB.
+ atom_positions: np.ndarray # [num_res, num_atom_type, 3]
+
+ # Amino-acid type for each residue represented as an integer between 0 and
+ # 20, where 20 is 'X'.
+ aatype: np.ndarray # [num_res]
+
+ # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
+ # is present and 0.0 if not. This should be used for loss masking.
+ atom_mask: np.ndarray # [num_res, num_atom_type]
+
+ # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
+ residue_index: np.ndarray # [num_res]
+
+ # 0-indexed number corresponding to the chain in the protein that this residue
+ # belongs to.
+ chain_index: np.ndarray # [num_res]
+
+ # B-factors, or temperature factors, of each residue (in sq. angstroms units),
+ # representing the displacement of the residue from its ground truth mean
+ # value.
+ b_factors: np.ndarray # [num_res, num_atom_type]
+
+ def __post_init__(self):
+ if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
+ raise ValueError(
+ f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
+ 'because these cannot be written to PDB format.')
+
+ def to_dict(self):
+ return dataclasses.asdict(self)
+
+
+def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
+ """Takes a PDB string and constructs a Protein object.
+
+ WARNING: All non-standard residue types will be converted into UNK. All
+ non-standard atoms will be ignored.
+
+ Args:
+ pdb_str: The contents of the pdb file
+ chain_id: If chain_id is specified (e.g. A), then only that chain
+ is parsed. Otherwise all chains are parsed.
+
+ Returns:
+ A new `Protein` parsed from the pdb contents.
+ """
+ pdb_fh = io.StringIO(pdb_str)
+ parser = PDBParser(QUIET=True)
+ structure = parser.get_structure('none', pdb_fh)
+ models = list(structure.get_models())
+ if len(models) != 1:
+ raise ValueError(
+ f'Only single model PDBs are supported. Found {len(models)} models.')
+ model = models[0]
+
+ atom_positions = []
+ aatype = []
+ atom_mask = []
+ residue_index = []
+ chain_ids = []
+ b_factors = []
+
+ for chain in model:
+ if chain_id is not None and chain.id != chain_id:
+ continue
+ for res in chain:
+ if res.id[2] != ' ':
+ raise ValueError(
+ f'PDB contains an insertion code at chain {chain.id} and residue '
+ f'index {res.id[1]}. These are not supported.')
+ res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
+ restype_idx = residue_constants.restype_order.get(
+ res_shortname, residue_constants.restype_num)
+ pos = np.zeros((residue_constants.atom_type_num, 3))
+ mask = np.zeros((residue_constants.atom_type_num,))
+ res_b_factors = np.zeros((residue_constants.atom_type_num,))
+ for atom in res:
+ if atom.name not in residue_constants.atom_types:
+ continue
+ pos[residue_constants.atom_order[atom.name]] = atom.coord
+ mask[residue_constants.atom_order[atom.name]] = 1.
+ res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
+ if np.sum(mask) < 0.5:
+ # If no known atom positions are reported for the residue then skip it.
+ continue
+ aatype.append(restype_idx)
+ atom_positions.append(pos)
+ atom_mask.append(mask)
+ residue_index.append(res.id[1])
+ chain_ids.append(chain.id)
+ b_factors.append(res_b_factors)
+
+ # Chain IDs are usually characters so map these to ints.
+ unique_chain_ids = np.unique(chain_ids)
+ chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
+ chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
+
+ return Protein(
+ atom_positions=np.array(atom_positions),
+ atom_mask=np.array(atom_mask),
+ aatype=np.array(aatype),
+ residue_index=np.array(residue_index),
+ chain_index=chain_index,
+ b_factors=np.array(b_factors))
+
+
+def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
+ chain_end = 'TER'
+ return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
+ f'{chain_name:>1}{residue_index:>4}')
+
+
+def to_pdb(prot: Protein, model=1, add_end=True) -> str:
+ """Converts a `Protein` instance to a PDB string.
+
+ Args:
+ prot: The protein to convert to PDB.
+
+ Returns:
+ PDB string.
+ """
+ restypes = residue_constants.restypes + ['X']
+ res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
+ atom_types = residue_constants.atom_types
+
+ pdb_lines = []
+
+ atom_mask = prot.atom_mask
+ aatype = prot.aatype
+ atom_positions = prot.atom_positions
+ residue_index = prot.residue_index.astype(int)
+ chain_index = prot.chain_index.astype(int)
+ b_factors = prot.b_factors
+
+ if np.any(aatype > residue_constants.restype_num):
+ raise ValueError('Invalid aatypes.')
+
+ # Construct a mapping from chain integer indices to chain ID strings.
+ chain_ids = {}
+ for i in np.unique(chain_index): # np.unique gives sorted output.
+ if i >= PDB_MAX_CHAINS:
+ raise ValueError(
+ f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
+ chain_ids[i] = PDB_CHAIN_IDS[i]
+
+ pdb_lines.append(f'MODEL {model}')
+ atom_index = 1
+ last_chain_index = chain_index[0]
+ # Add all atom sites.
+ for i in range(aatype.shape[0]):
+ # Close the previous chain if in a multichain PDB.
+ if last_chain_index != chain_index[i]:
+ pdb_lines.append(_chain_end(
+ atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]],
+ residue_index[i - 1]))
+ last_chain_index = chain_index[i]
+ atom_index += 1 # Atom index increases at the TER symbol.
+
+ res_name_3 = res_1to3(aatype[i])
+ for atom_name, pos, mask, b_factor in zip(
+ atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
+ if mask < 0.5:
+ continue
+
+ # skip CB for GLY
+ if res_name_3 == 'GLY' and atom_name == 'CB':
+ continue
+
+ record_type = 'ATOM'
+ name = atom_name if len(atom_name) == 4 else f' {atom_name}'
+ alt_loc = ''
+ insertion_code = ''
+ occupancy = 1.00
+ element = atom_name[0] # Protein supports only C, N, O, S, this works.
+ charge = ''
+ # PDB is a columnar format, every space matters here!
+ atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
+ f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}'
+ f'{residue_index[i]:>4}{insertion_code:>1} '
+ f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
+ f'{occupancy:>6.2f}{b_factor:>6.2f} '
+ f'{element:>2}{charge:>2}')
+ pdb_lines.append(atom_line)
+ atom_index += 1
+
+ # Close the final chain.
+ pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]),
+ chain_ids[chain_index[-1]], residue_index[-1]))
+ pdb_lines.append('ENDMDL')
+ if add_end:
+ pdb_lines.append('END')
+
+ # Pad all lines to 80 characters.
+ pdb_lines = [line.ljust(80) for line in pdb_lines]
+ return '\n'.join(pdb_lines) + '\n' # Add terminating newline.
+
+
+def ideal_atom_mask(prot: Protein) -> np.ndarray:
+ """Computes an ideal atom mask.
+
+ `Protein.atom_mask` typically is defined according to the atoms that are
+ reported in the PDB. This function computes a mask according to heavy atoms
+ that should be present in the given sequence of amino acids.
+
+ Args:
+ prot: `Protein` whose fields are `numpy.ndarray` objects.
+
+ Returns:
+ An ideal atom mask.
+ """
+ return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
+
+
+def from_prediction(
+ features: FeatureDict,
+ result: ModelOutput,
+ b_factors: Optional[np.ndarray] = None,
+ remove_leading_feature_dimension: bool = True) -> Protein:
+ """Assembles a protein from a prediction.
+
+ Args:
+ features: Dictionary holding model inputs.
+ result: Dictionary holding model outputs.
+ b_factors: (Optional) B-factors to use for the protein.
+ remove_leading_feature_dimension: Whether to remove the leading dimension
+ of the `features` values.
+
+ Returns:
+ A protein instance.
+ """
+ fold_output = result['structure_module']
+
+ def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
+ return arr[0] if remove_leading_feature_dimension else arr
+
+ if 'asym_id' in features:
+ chain_index = _maybe_remove_leading_dim(features['asym_id'])
+ else:
+ chain_index = np.zeros_like(_maybe_remove_leading_dim(features['aatype']))
+
+ if b_factors is None:
+ b_factors = np.zeros_like(fold_output['final_atom_mask'])
+
+ return Protein(
+ aatype=_maybe_remove_leading_dim(features['aatype']),
+ atom_positions=fold_output['final_atom_positions'],
+ atom_mask=fold_output['final_atom_mask'],
+ residue_index=_maybe_remove_leading_dim(features['residue_index']) + 1,
+ chain_index=chain_index,
+ b_factors=b_factors)
\ No newline at end of file
diff --git a/analysis/src/common/residue_constants.py b/analysis/src/common/residue_constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b925b5c25bb76bbffc12940fae3afbc9f2ba1a8
--- /dev/null
+++ b/analysis/src/common/residue_constants.py
@@ -0,0 +1,897 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Constants used in AlphaFold."""
+
+import collections
+import functools
+import os
+from typing import List, Mapping, Tuple
+
+import numpy as np
+import tree
+
+# Internal import (35fd).
+
+
+# Distance from one CA to next CA [trans configuration: omega = 180].
+ca_ca = 3.80209737096
+
+# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
+# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
+# chi angles so their chi angle lists are empty.
+chi_angles_atoms = {
+ 'ALA': [],
+ # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
+ 'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+ ['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']],
+ 'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
+ 'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
+ 'CYS': [['N', 'CA', 'CB', 'SG']],
+ 'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+ ['CB', 'CG', 'CD', 'OE1']],
+ 'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+ ['CB', 'CG', 'CD', 'OE1']],
+ 'GLY': [],
+ 'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']],
+ 'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']],
+ 'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+ 'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+ ['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']],
+ 'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'],
+ ['CB', 'CG', 'SD', 'CE']],
+ 'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+ 'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']],
+ 'SER': [['N', 'CA', 'CB', 'OG']],
+ 'THR': [['N', 'CA', 'CB', 'OG1']],
+ 'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+ 'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+ 'VAL': [['N', 'CA', 'CB', 'CG1']],
+}
+
+# If chi angles given in fixed-length array, this matrix determines how to mask
+# them for each AA type. The order is as per restype_order (see below).
+chi_angles_mask = [
+ [0.0, 0.0, 0.0, 0.0], # ALA
+ [1.0, 1.0, 1.0, 1.0], # ARG
+ [1.0, 1.0, 0.0, 0.0], # ASN
+ [1.0, 1.0, 0.0, 0.0], # ASP
+ [1.0, 0.0, 0.0, 0.0], # CYS
+ [1.0, 1.0, 1.0, 0.0], # GLN
+ [1.0, 1.0, 1.0, 0.0], # GLU
+ [0.0, 0.0, 0.0, 0.0], # GLY
+ [1.0, 1.0, 0.0, 0.0], # HIS
+ [1.0, 1.0, 0.0, 0.0], # ILE
+ [1.0, 1.0, 0.0, 0.0], # LEU
+ [1.0, 1.0, 1.0, 1.0], # LYS
+ [1.0, 1.0, 1.0, 0.0], # MET
+ [1.0, 1.0, 0.0, 0.0], # PHE
+ [1.0, 1.0, 0.0, 0.0], # PRO
+ [1.0, 0.0, 0.0, 0.0], # SER
+ [1.0, 0.0, 0.0, 0.0], # THR
+ [1.0, 1.0, 0.0, 0.0], # TRP
+ [1.0, 1.0, 0.0, 0.0], # TYR
+ [1.0, 0.0, 0.0, 0.0], # VAL
+]
+
+# The following chi angles are pi periodic: they can be rotated by a multiple
+# of pi without affecting the structure.
+chi_pi_periodic = [
+ [0.0, 0.0, 0.0, 0.0], # ALA
+ [0.0, 0.0, 0.0, 0.0], # ARG
+ [0.0, 0.0, 0.0, 0.0], # ASN
+ [0.0, 1.0, 0.0, 0.0], # ASP
+ [0.0, 0.0, 0.0, 0.0], # CYS
+ [0.0, 0.0, 0.0, 0.0], # GLN
+ [0.0, 0.0, 1.0, 0.0], # GLU
+ [0.0, 0.0, 0.0, 0.0], # GLY
+ [0.0, 0.0, 0.0, 0.0], # HIS
+ [0.0, 0.0, 0.0, 0.0], # ILE
+ [0.0, 0.0, 0.0, 0.0], # LEU
+ [0.0, 0.0, 0.0, 0.0], # LYS
+ [0.0, 0.0, 0.0, 0.0], # MET
+ [0.0, 1.0, 0.0, 0.0], # PHE
+ [0.0, 0.0, 0.0, 0.0], # PRO
+ [0.0, 0.0, 0.0, 0.0], # SER
+ [0.0, 0.0, 0.0, 0.0], # THR
+ [0.0, 0.0, 0.0, 0.0], # TRP
+ [0.0, 1.0, 0.0, 0.0], # TYR
+ [0.0, 0.0, 0.0, 0.0], # VAL
+ [0.0, 0.0, 0.0, 0.0], # UNK
+]
+
+# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
+# psi and chi angles:
+# 0: 'backbone group',
+# 1: 'pre-omega-group', (empty)
+# 2: 'phi-group', (currently empty, because it defines only hydrogens)
+# 3: 'psi-group',
+# 4,5,6,7: 'chi1,2,3,4-group'
+# The atom positions are relative to the axis-end-atom of the corresponding
+# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
+# is defined such that the dihedral-angle-definiting atom (the last entry in
+# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
+# format: [atomname, group_idx, rel_position]
+rigid_group_atom_positions = {
+ 'ALA': [
+ ['N', 0, (-0.525, 1.363, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, -0.000, -0.000)],
+ ['CB', 0, (-0.529, -0.774, -1.205)],
+ ['O', 3, (0.627, 1.062, 0.000)],
+ ],
+ 'ARG': [
+ ['N', 0, (-0.524, 1.362, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, -0.000, -0.000)],
+ ['CB', 0, (-0.524, -0.778, -1.209)],
+ ['O', 3, (0.626, 1.062, 0.000)],
+ ['CG', 4, (0.616, 1.390, -0.000)],
+ ['CD', 5, (0.564, 1.414, 0.000)],
+ ['NE', 6, (0.539, 1.357, -0.000)],
+ ['NH1', 7, (0.206, 2.301, 0.000)],
+ ['NH2', 7, (2.078, 0.978, -0.000)],
+ ['CZ', 7, (0.758, 1.093, -0.000)],
+ ],
+ 'ASN': [
+ ['N', 0, (-0.536, 1.357, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, -0.000, -0.000)],
+ ['CB', 0, (-0.531, -0.787, -1.200)],
+ ['O', 3, (0.625, 1.062, 0.000)],
+ ['CG', 4, (0.584, 1.399, 0.000)],
+ ['ND2', 5, (0.593, -1.188, 0.001)],
+ ['OD1', 5, (0.633, 1.059, 0.000)],
+ ],
+ 'ASP': [
+ ['N', 0, (-0.525, 1.362, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.527, 0.000, -0.000)],
+ ['CB', 0, (-0.526, -0.778, -1.208)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['CG', 4, (0.593, 1.398, -0.000)],
+ ['OD1', 5, (0.610, 1.091, 0.000)],
+ ['OD2', 5, (0.592, -1.101, -0.003)],
+ ],
+ 'CYS': [
+ ['N', 0, (-0.522, 1.362, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.524, 0.000, 0.000)],
+ ['CB', 0, (-0.519, -0.773, -1.212)],
+ ['O', 3, (0.625, 1.062, -0.000)],
+ ['SG', 4, (0.728, 1.653, 0.000)],
+ ],
+ 'GLN': [
+ ['N', 0, (-0.526, 1.361, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, 0.000, 0.000)],
+ ['CB', 0, (-0.525, -0.779, -1.207)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['CG', 4, (0.615, 1.393, 0.000)],
+ ['CD', 5, (0.587, 1.399, -0.000)],
+ ['NE2', 6, (0.593, -1.189, -0.001)],
+ ['OE1', 6, (0.634, 1.060, 0.000)],
+ ],
+ 'GLU': [
+ ['N', 0, (-0.528, 1.361, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, -0.000, -0.000)],
+ ['CB', 0, (-0.526, -0.781, -1.207)],
+ ['O', 3, (0.626, 1.062, 0.000)],
+ ['CG', 4, (0.615, 1.392, 0.000)],
+ ['CD', 5, (0.600, 1.397, 0.000)],
+ ['OE1', 6, (0.607, 1.095, -0.000)],
+ ['OE2', 6, (0.589, -1.104, -0.001)],
+ ],
+ 'GLY': [
+ ['N', 0, (-0.572, 1.337, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.517, -0.000, -0.000)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ],
+ 'HIS': [
+ ['N', 0, (-0.527, 1.360, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, 0.000, 0.000)],
+ ['CB', 0, (-0.525, -0.778, -1.208)],
+ ['O', 3, (0.625, 1.063, 0.000)],
+ ['CG', 4, (0.600, 1.370, -0.000)],
+ ['CD2', 5, (0.889, -1.021, 0.003)],
+ ['ND1', 5, (0.744, 1.160, -0.000)],
+ ['CE1', 5, (2.030, 0.851, 0.002)],
+ ['NE2', 5, (2.145, -0.466, 0.004)],
+ ],
+ 'ILE': [
+ ['N', 0, (-0.493, 1.373, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.527, -0.000, -0.000)],
+ ['CB', 0, (-0.536, -0.793, -1.213)],
+ ['O', 3, (0.627, 1.062, -0.000)],
+ ['CG1', 4, (0.534, 1.437, -0.000)],
+ ['CG2', 4, (0.540, -0.785, -1.199)],
+ ['CD1', 5, (0.619, 1.391, 0.000)],
+ ],
+ 'LEU': [
+ ['N', 0, (-0.520, 1.363, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, -0.000, -0.000)],
+ ['CB', 0, (-0.522, -0.773, -1.214)],
+ ['O', 3, (0.625, 1.063, -0.000)],
+ ['CG', 4, (0.678, 1.371, 0.000)],
+ ['CD1', 5, (0.530, 1.430, -0.000)],
+ ['CD2', 5, (0.535, -0.774, 1.200)],
+ ],
+ 'LYS': [
+ ['N', 0, (-0.526, 1.362, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, 0.000, 0.000)],
+ ['CB', 0, (-0.524, -0.778, -1.208)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['CG', 4, (0.619, 1.390, 0.000)],
+ ['CD', 5, (0.559, 1.417, 0.000)],
+ ['CE', 6, (0.560, 1.416, 0.000)],
+ ['NZ', 7, (0.554, 1.387, 0.000)],
+ ],
+ 'MET': [
+ ['N', 0, (-0.521, 1.364, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, 0.000, 0.000)],
+ ['CB', 0, (-0.523, -0.776, -1.210)],
+ ['O', 3, (0.625, 1.062, -0.000)],
+ ['CG', 4, (0.613, 1.391, -0.000)],
+ ['SD', 5, (0.703, 1.695, 0.000)],
+ ['CE', 6, (0.320, 1.786, -0.000)],
+ ],
+ 'PHE': [
+ ['N', 0, (-0.518, 1.363, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.524, 0.000, -0.000)],
+ ['CB', 0, (-0.525, -0.776, -1.212)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['CG', 4, (0.607, 1.377, 0.000)],
+ ['CD1', 5, (0.709, 1.195, -0.000)],
+ ['CD2', 5, (0.706, -1.196, 0.000)],
+ ['CE1', 5, (2.102, 1.198, -0.000)],
+ ['CE2', 5, (2.098, -1.201, -0.000)],
+ ['CZ', 5, (2.794, -0.003, -0.001)],
+ ],
+ 'PRO': [
+ ['N', 0, (-0.566, 1.351, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.527, -0.000, 0.000)],
+ ['CB', 0, (-0.546, -0.611, -1.293)],
+ ['O', 3, (0.621, 1.066, 0.000)],
+ ['CG', 4, (0.382, 1.445, 0.0)],
+ # ['CD', 5, (0.427, 1.440, 0.0)],
+ ['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
+ ],
+ 'SER': [
+ ['N', 0, (-0.529, 1.360, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, -0.000, -0.000)],
+ ['CB', 0, (-0.518, -0.777, -1.211)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['OG', 4, (0.503, 1.325, 0.000)],
+ ],
+ 'THR': [
+ ['N', 0, (-0.517, 1.364, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, 0.000, -0.000)],
+ ['CB', 0, (-0.516, -0.793, -1.215)],
+ ['O', 3, (0.626, 1.062, 0.000)],
+ ['CG2', 4, (0.550, -0.718, -1.228)],
+ ['OG1', 4, (0.472, 1.353, 0.000)],
+ ],
+ 'TRP': [
+ ['N', 0, (-0.521, 1.363, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, -0.000, 0.000)],
+ ['CB', 0, (-0.523, -0.776, -1.212)],
+ ['O', 3, (0.627, 1.062, 0.000)],
+ ['CG', 4, (0.609, 1.370, -0.000)],
+ ['CD1', 5, (0.824, 1.091, 0.000)],
+ ['CD2', 5, (0.854, -1.148, -0.005)],
+ ['CE2', 5, (2.186, -0.678, -0.007)],
+ ['CE3', 5, (0.622, -2.530, -0.007)],
+ ['NE1', 5, (2.140, 0.690, -0.004)],
+ ['CH2', 5, (3.028, -2.890, -0.013)],
+ ['CZ2', 5, (3.283, -1.543, -0.011)],
+ ['CZ3', 5, (1.715, -3.389, -0.011)],
+ ],
+ 'TYR': [
+ ['N', 0, (-0.522, 1.362, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.524, -0.000, -0.000)],
+ ['CB', 0, (-0.522, -0.776, -1.213)],
+ ['O', 3, (0.627, 1.062, -0.000)],
+ ['CG', 4, (0.607, 1.382, -0.000)],
+ ['CD1', 5, (0.716, 1.195, -0.000)],
+ ['CD2', 5, (0.713, -1.194, -0.001)],
+ ['CE1', 5, (2.107, 1.200, -0.002)],
+ ['CE2', 5, (2.104, -1.201, -0.003)],
+ ['OH', 5, (4.168, -0.002, -0.005)],
+ ['CZ', 5, (2.791, -0.001, -0.003)],
+ ],
+ 'VAL': [
+ ['N', 0, (-0.494, 1.373, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.527, -0.000, -0.000)],
+ ['CB', 0, (-0.533, -0.795, -1.213)],
+ ['O', 3, (0.627, 1.062, -0.000)],
+ ['CG1', 4, (0.540, 1.429, -0.000)],
+ ['CG2', 4, (0.533, -0.776, 1.203)],
+ ],
+}
+
+# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
+residue_atoms = {
+ 'ALA': ['C', 'CA', 'CB', 'N', 'O'],
+ 'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'],
+ 'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'],
+ 'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'],
+ 'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'],
+ 'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'],
+ 'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'],
+ 'GLY': ['C', 'CA', 'N', 'O'],
+ 'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'],
+ 'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'],
+ 'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'],
+ 'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'],
+ 'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'],
+ 'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'],
+ 'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'],
+ 'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'],
+ 'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'],
+ 'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3',
+ 'CH2', 'N', 'NE1', 'O'],
+ 'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O',
+ 'OH'],
+ 'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O']
+}
+
+# Naming swaps for ambiguous atom names.
+# Due to symmetries in the amino acids the naming of atoms is ambiguous in
+# 4 of the 20 amino acids.
+# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
+# in LEU, VAL and ARG can be resolved by using the 3d constellations of
+# the 'ambiguous' atoms and their neighbours)
+residue_atom_renaming_swaps = {
+ 'ASP': {'OD1': 'OD2'},
+ 'GLU': {'OE1': 'OE2'},
+ 'PHE': {'CD1': 'CD2', 'CE1': 'CE2'},
+ 'TYR': {'CD1': 'CD2', 'CE1': 'CE2'},
+}
+
+# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
+van_der_waals_radius = {
+ 'C': 1.7,
+ 'N': 1.55,
+ 'O': 1.52,
+ 'S': 1.8,
+}
+
+Bond = collections.namedtuple(
+ 'Bond', ['atom1_name', 'atom2_name', 'length', 'stddev'])
+BondAngle = collections.namedtuple(
+ 'BondAngle',
+ ['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev'])
+
+
+@functools.lru_cache(maxsize=None)
+def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]],
+ Mapping[str, List[Bond]],
+ Mapping[str, List[BondAngle]]]:
+ """Load stereo_chemical_props.txt into a nice structure.
+
+ Load literature values for bond lengths and bond angles and translate
+ bond angles into the length of the opposite edge of the triangle
+ ("residue_virtual_bonds").
+
+ Returns:
+ residue_bonds: Dict that maps resname -> list of Bond tuples.
+ residue_virtual_bonds: Dict that maps resname -> list of Bond tuples.
+ residue_bond_angles: Dict that maps resname -> list of BondAngle tuples.
+ """
+ stereo_chemical_props_path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), 'stereo_chemical_props.txt'
+ )
+ with open(stereo_chemical_props_path, 'rt') as f:
+ stereo_chemical_props = f.read()
+ lines_iter = iter(stereo_chemical_props.splitlines())
+ # Load bond lengths.
+ residue_bonds = {}
+ next(lines_iter) # Skip header line.
+ for line in lines_iter:
+ if line.strip() == '-':
+ break
+ bond, resname, length, stddev = line.split()
+ atom1, atom2 = bond.split('-')
+ if resname not in residue_bonds:
+ residue_bonds[resname] = []
+ residue_bonds[resname].append(
+ Bond(atom1, atom2, float(length), float(stddev)))
+ residue_bonds['UNK'] = []
+
+ # Load bond angles.
+ residue_bond_angles = {}
+ next(lines_iter) # Skip empty line.
+ next(lines_iter) # Skip header line.
+ for line in lines_iter:
+ if line.strip() == '-':
+ break
+ bond, resname, angle_degree, stddev_degree = line.split()
+ atom1, atom2, atom3 = bond.split('-')
+ if resname not in residue_bond_angles:
+ residue_bond_angles[resname] = []
+ residue_bond_angles[resname].append(
+ BondAngle(atom1, atom2, atom3,
+ float(angle_degree) / 180. * np.pi,
+ float(stddev_degree) / 180. * np.pi))
+ residue_bond_angles['UNK'] = []
+
+ def make_bond_key(atom1_name, atom2_name):
+ """Unique key to lookup bonds."""
+ return '-'.join(sorted([atom1_name, atom2_name]))
+
+ # Translate bond angles into distances ("virtual bonds").
+ residue_virtual_bonds = {}
+ for resname, bond_angles in residue_bond_angles.items():
+ # Create a fast lookup dict for bond lengths.
+ bond_cache = {}
+ for b in residue_bonds[resname]:
+ bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
+ residue_virtual_bonds[resname] = []
+ for ba in bond_angles:
+ bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
+ bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
+
+ # Compute distance between atom1 and atom3 using the law of cosines
+ # c^2 = a^2 + b^2 - 2ab*cos(gamma).
+ gamma = ba.angle_rad
+ length = np.sqrt(bond1.length**2 + bond2.length**2
+ - 2 * bond1.length * bond2.length * np.cos(gamma))
+
+ # Propagation of uncertainty assuming uncorrelated errors.
+ dl_outer = 0.5 / length
+ dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer
+ dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer
+ dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer
+ stddev = np.sqrt((dl_dgamma * ba.stddev)**2 +
+ (dl_db1 * bond1.stddev)**2 +
+ (dl_db2 * bond2.stddev)**2)
+ residue_virtual_bonds[resname].append(
+ Bond(ba.atom1_name, ba.atom3name, length, stddev))
+
+ return (residue_bonds,
+ residue_virtual_bonds,
+ residue_bond_angles)
+
+
+# Between-residue bond lengths for general bonds (first element) and for Proline
+# (second element).
+between_res_bond_length_c_n = [1.329, 1.341]
+between_res_bond_length_stddev_c_n = [0.014, 0.016]
+
+# Between-residue cos_angles.
+between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315
+between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
+
+# This mapping is used when we need to store atom data in a format that requires
+# fixed atom data size for every residue (e.g. a numpy array).
+atom_types = [
+ 'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
+ 'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
+ 'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
+ 'CZ3', 'NZ', 'OXT'
+]
+atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
+atom_type_num = len(atom_types) # := 37.
+
+# A compact atom encoding with 14 columns
+# pylint: disable=line-too-long
+# pylint: disable=bad-whitespace
+restype_name_to_atom14_names = {
+ 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''],
+ 'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''],
+ 'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''],
+ 'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''],
+ 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''],
+ 'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''],
+ 'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''],
+ 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''],
+ 'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''],
+ 'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''],
+ 'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''],
+ 'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''],
+ 'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''],
+ 'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''],
+ 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''],
+ 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''],
+ 'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''],
+ 'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],
+ 'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''],
+ 'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''],
+ 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''],
+
+}
+# pylint: enable=line-too-long
+# pylint: enable=bad-whitespace
+
+
+# This is the standard residue order when coding AA type as a number.
+# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
+restypes = [
+ 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P',
+ 'S', 'T', 'W', 'Y', 'V'
+]
+restype_order = {restype: i for i, restype in enumerate(restypes)}
+restype_num = len(restypes) # := 20.
+unk_restype_index = restype_num # Catch-all index for unknown restypes.
+
+restypes_with_x = restypes + ['X']
+restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
+
+
+def sequence_to_onehot(
+ sequence: str,
+ mapping: Mapping[str, int],
+ map_unknown_to_x: bool = False) -> np.ndarray:
+ """Maps the given sequence into a one-hot encoded matrix.
+
+ Args:
+ sequence: An amino acid sequence.
+ mapping: A dictionary mapping amino acids to integers.
+ map_unknown_to_x: If True, any amino acid that is not in the mapping will be
+ mapped to the unknown amino acid 'X'. If the mapping doesn't contain
+ amino acid 'X', an error will be thrown. If False, any amino acid not in
+ the mapping will throw an error.
+
+ Returns:
+ A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
+ the sequence.
+
+ Raises:
+ ValueError: If the mapping doesn't contain values from 0 to
+ num_unique_aas - 1 without any gaps.
+ """
+ num_entries = max(mapping.values()) + 1
+
+ if sorted(set(mapping.values())) != list(range(num_entries)):
+ raise ValueError('The mapping must have values from 0 to num_unique_aas-1 '
+ 'without any gaps. Got: %s' % sorted(mapping.values()))
+
+ one_hot_arr = np.zeros((len(sequence), num_entries), dtype=int)
+
+ for aa_index, aa_type in enumerate(sequence):
+ if map_unknown_to_x:
+ if aa_type.isalpha() and aa_type.isupper():
+ aa_id = mapping.get(aa_type, mapping['X'])
+ else:
+ raise ValueError(f'Invalid character in the sequence: {aa_type}')
+ else:
+ aa_id = mapping[aa_type]
+ one_hot_arr[aa_index, aa_id] = 1
+
+ return one_hot_arr
+
+
+restype_1to3 = {
+ 'A': 'ALA',
+ 'R': 'ARG',
+ 'N': 'ASN',
+ 'D': 'ASP',
+ 'C': 'CYS',
+ 'Q': 'GLN',
+ 'E': 'GLU',
+ 'G': 'GLY',
+ 'H': 'HIS',
+ 'I': 'ILE',
+ 'L': 'LEU',
+ 'K': 'LYS',
+ 'M': 'MET',
+ 'F': 'PHE',
+ 'P': 'PRO',
+ 'S': 'SER',
+ 'T': 'THR',
+ 'W': 'TRP',
+ 'Y': 'TYR',
+ 'V': 'VAL',
+}
+
+
+# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
+# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
+# many more, and less common, three letter names as keys and maps many of these
+# to the same one letter name (including 'X' and 'U' which we don't use here).
+restype_3to1 = {v: k for k, v in restype_1to3.items()}
+
+# Define a restype name for all unknown residues.
+unk_restype = 'UNK'
+
+resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
+resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
+
+
+# The mapping here uses hhblits convention, so that B is mapped to D, J and O
+# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
+# remaining 20 amino acids are kept in alphabetical order.
+# There are 2 non-amino acid codes, X (representing any amino acid) and
+# "-" representing a missing amino acid in an alignment. The id for these
+# codes is put at the end (20 and 21) so that they can easily be ignored if
+# desired.
+HHBLITS_AA_TO_ID = {
+ 'A': 0,
+ 'B': 2,
+ 'C': 1,
+ 'D': 2,
+ 'E': 3,
+ 'F': 4,
+ 'G': 5,
+ 'H': 6,
+ 'I': 7,
+ 'J': 20,
+ 'K': 8,
+ 'L': 9,
+ 'M': 10,
+ 'N': 11,
+ 'O': 20,
+ 'P': 12,
+ 'Q': 13,
+ 'R': 14,
+ 'S': 15,
+ 'T': 16,
+ 'U': 1,
+ 'V': 17,
+ 'W': 18,
+ 'X': 20,
+ 'Y': 19,
+ 'Z': 3,
+ '-': 21,
+}
+
+# Partial inversion of HHBLITS_AA_TO_ID.
+ID_TO_HHBLITS_AA = {
+ 0: 'A',
+ 1: 'C', # Also U.
+ 2: 'D', # Also B.
+ 3: 'E', # Also Z.
+ 4: 'F',
+ 5: 'G',
+ 6: 'H',
+ 7: 'I',
+ 8: 'K',
+ 9: 'L',
+ 10: 'M',
+ 11: 'N',
+ 12: 'P',
+ 13: 'Q',
+ 14: 'R',
+ 15: 'S',
+ 16: 'T',
+ 17: 'V',
+ 18: 'W',
+ 19: 'Y',
+ 20: 'X', # Includes J and O.
+ 21: '-',
+}
+
+restypes_with_x_and_gap = restypes + ['X', '-']
+MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
+ restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
+ for i in range(len(restypes_with_x_and_gap)))
+
+
+def _make_standard_atom_mask() -> np.ndarray:
+ """Returns [num_res_types, num_atom_types] mask array."""
+ # +1 to account for unknown (all 0s).
+ mask = np.zeros([restype_num + 1, atom_type_num], dtype=int)
+ for restype, restype_letter in enumerate(restypes):
+ restype_name = restype_1to3[restype_letter]
+ atom_names = residue_atoms[restype_name]
+ for atom_name in atom_names:
+ atom_type = atom_order[atom_name]
+ mask[restype, atom_type] = 1
+ return mask
+
+
+STANDARD_ATOM_MASK = _make_standard_atom_mask()
+
+
+# A one hot representation for the first and second atoms defining the axis
+# of rotation for each chi-angle in each residue.
+def chi_angle_atom(atom_index: int) -> np.ndarray:
+ """Define chi-angle rigid groups via one-hot representations."""
+ chi_angles_index = {}
+ one_hots = []
+
+ for k, v in chi_angles_atoms.items():
+ indices = [atom_types.index(s[atom_index]) for s in v]
+ indices.extend([-1]*(4-len(indices)))
+ chi_angles_index[k] = indices
+
+ for r in restypes:
+ res3 = restype_1to3[r]
+ one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
+ one_hots.append(one_hot)
+
+ one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
+ one_hot = np.stack(one_hots, axis=0)
+ one_hot = np.transpose(one_hot, [0, 2, 1])
+
+ return one_hot
+
+chi_atom_1_one_hot = chi_angle_atom(1)
+chi_atom_2_one_hot = chi_angle_atom(2)
+
+# An array like chi_angles_atoms but using indices rather than names.
+chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
+chi_angles_atom_indices = tree.map_structure(
+ lambda atom_name: atom_order[atom_name], chi_angles_atom_indices)
+chi_angles_atom_indices = np.array([
+ chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
+ for chi_atoms in chi_angles_atom_indices])
+
+# Mapping from (res_name, atom_name) pairs to the atom's chi group index
+# and atom index within that group.
+chi_groups_for_atom = collections.defaultdict(list)
+for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
+ for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
+ for atom_i, atom in enumerate(chi_group):
+ chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
+chi_groups_for_atom = dict(chi_groups_for_atom)
+
+
+def _make_rigid_transformation_4x4(ex, ey, translation):
+ """Create a rigid 4x4 transformation matrix from two axes and transl."""
+ # Normalize ex.
+ ex_normalized = ex / np.linalg.norm(ex)
+
+ # make ey perpendicular to ex
+ ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
+ ey_normalized /= np.linalg.norm(ey_normalized)
+
+ # compute ez as cross product
+ eznorm = np.cross(ex_normalized, ey_normalized)
+ m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
+ m = np.concatenate([m, [[0., 0., 0., 1.]]], axis=0)
+ return m
+
+
+# create an array with (restype, atomtype) --> rigid_group_idx
+# and an array with (restype, atomtype, coord) for the atom positions
+# and compute affine transformation matrices (4,4) from one rigid group to the
+# previous group
+restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
+restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
+restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
+restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
+restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
+restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
+restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
+
+
+def _make_rigid_group_constants():
+ """Fill the arrays above."""
+ for restype, restype_letter in enumerate(restypes):
+ resname = restype_1to3[restype_letter]
+ for atomname, group_idx, atom_position in rigid_group_atom_positions[
+ resname]:
+ atomtype = atom_order[atomname]
+ restype_atom37_to_rigid_group[restype, atomtype] = group_idx
+ restype_atom37_mask[restype, atomtype] = 1
+ restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position
+
+ atom14idx = restype_name_to_atom14_names[resname].index(atomname)
+ restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
+ restype_atom14_mask[restype, atom14idx] = 1
+ restype_atom14_rigid_group_positions[restype,
+ atom14idx, :] = atom_position
+
+ for restype, restype_letter in enumerate(restypes):
+ resname = restype_1to3[restype_letter]
+ atom_positions = {name: np.array(pos) for name, _, pos
+ in rigid_group_atom_positions[resname]}
+
+ # backbone to backbone is the identity transform
+ restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
+
+ # pre-omega-frame to backbone (currently dummy identity matrix)
+ restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
+
+ # phi-frame to backbone
+ mat = _make_rigid_transformation_4x4(
+ ex=atom_positions['N'] - atom_positions['CA'],
+ ey=np.array([1., 0., 0.]),
+ translation=atom_positions['N'])
+ restype_rigid_group_default_frame[restype, 2, :, :] = mat
+
+ # psi-frame to backbone
+ mat = _make_rigid_transformation_4x4(
+ ex=atom_positions['C'] - atom_positions['CA'],
+ ey=atom_positions['CA'] - atom_positions['N'],
+ translation=atom_positions['C'])
+ restype_rigid_group_default_frame[restype, 3, :, :] = mat
+
+ # chi1-frame to backbone
+ if chi_angles_mask[restype][0]:
+ base_atom_names = chi_angles_atoms[resname][0]
+ base_atom_positions = [atom_positions[name] for name in base_atom_names]
+ mat = _make_rigid_transformation_4x4(
+ ex=base_atom_positions[2] - base_atom_positions[1],
+ ey=base_atom_positions[0] - base_atom_positions[1],
+ translation=base_atom_positions[2])
+ restype_rigid_group_default_frame[restype, 4, :, :] = mat
+
+ # chi2-frame to chi1-frame
+ # chi3-frame to chi2-frame
+ # chi4-frame to chi3-frame
+ # luckily all rotation axes for the next frame start at (0,0,0) of the
+ # previous frame
+ for chi_idx in range(1, 4):
+ if chi_angles_mask[restype][chi_idx]:
+ axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
+ axis_end_atom_position = atom_positions[axis_end_atom_name]
+ mat = _make_rigid_transformation_4x4(
+ ex=axis_end_atom_position,
+ ey=np.array([-1., 0., 0.]),
+ translation=axis_end_atom_position)
+ restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat
+
+
+_make_rigid_group_constants()
+
+
+def make_atom14_dists_bounds(overlap_tolerance=1.5,
+ bond_length_tolerance_factor=15):
+ """compute upper and lower bounds for bonds to assess violations."""
+ restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
+ restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
+ restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
+ residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
+ for restype, restype_letter in enumerate(restypes):
+ resname = restype_1to3[restype_letter]
+ atom_list = restype_name_to_atom14_names[resname]
+
+ # create lower and upper bounds for clashes
+ for atom1_idx, atom1_name in enumerate(atom_list):
+ if not atom1_name:
+ continue
+ atom1_radius = van_der_waals_radius[atom1_name[0]]
+ for atom2_idx, atom2_name in enumerate(atom_list):
+ if (not atom2_name) or atom1_idx == atom2_idx:
+ continue
+ atom2_radius = van_der_waals_radius[atom2_name[0]]
+ lower = atom1_radius + atom2_radius - overlap_tolerance
+ upper = 1e10
+ restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
+ restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
+ restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
+ restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
+
+ # overwrite lower and upper bounds for bonds and angles
+ for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
+ atom1_idx = atom_list.index(b.atom1_name)
+ atom2_idx = atom_list.index(b.atom2_name)
+ lower = b.length - bond_length_tolerance_factor * b.stddev
+ upper = b.length + bond_length_tolerance_factor * b.stddev
+ restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
+ restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
+ restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
+ restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
+ restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
+ restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
+ return {'lower_bound': restype_atom14_bond_lower_bound, # shape (21,14,14)
+ 'upper_bound': restype_atom14_bond_upper_bound, # shape (21,14,14)
+ 'stddev': restype_atom14_bond_stddev, # shape (21,14,14)
+ }
\ No newline at end of file
diff --git a/analysis/src/common/rigid_utils.py b/analysis/src/common/rigid_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..699be6541a9c34e9ae9f23e74b97fca70f3eadc8
--- /dev/null
+++ b/analysis/src/common/rigid_utils.py
@@ -0,0 +1,1451 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple, Any, Sequence, Callable, Optional
+
+import numpy as np
+import torch
+
+from src.common import rotation3d
+
+
+def rot_matmul(
+ a: torch.Tensor,
+ b: torch.Tensor
+) -> torch.Tensor:
+ """
+ Performs matrix multiplication of two rotation matrix tensors. Written
+ out by hand to avoid AMP downcasting.
+
+ Args:
+ a: [*, 3, 3] left multiplicand
+ b: [*, 3, 3] right multiplicand
+ Returns:
+ The product ab
+ """
+ row_1 = torch.stack(
+ [
+ a[..., 0, 0] * b[..., 0, 0]
+ + a[..., 0, 1] * b[..., 1, 0]
+ + a[..., 0, 2] * b[..., 2, 0],
+ a[..., 0, 0] * b[..., 0, 1]
+ + a[..., 0, 1] * b[..., 1, 1]
+ + a[..., 0, 2] * b[..., 2, 1],
+ a[..., 0, 0] * b[..., 0, 2]
+ + a[..., 0, 1] * b[..., 1, 2]
+ + a[..., 0, 2] * b[..., 2, 2],
+ ],
+ dim=-1,
+ )
+ row_2 = torch.stack(
+ [
+ a[..., 1, 0] * b[..., 0, 0]
+ + a[..., 1, 1] * b[..., 1, 0]
+ + a[..., 1, 2] * b[..., 2, 0],
+ a[..., 1, 0] * b[..., 0, 1]
+ + a[..., 1, 1] * b[..., 1, 1]
+ + a[..., 1, 2] * b[..., 2, 1],
+ a[..., 1, 0] * b[..., 0, 2]
+ + a[..., 1, 1] * b[..., 1, 2]
+ + a[..., 1, 2] * b[..., 2, 2],
+ ],
+ dim=-1,
+ )
+ row_3 = torch.stack(
+ [
+ a[..., 2, 0] * b[..., 0, 0]
+ + a[..., 2, 1] * b[..., 1, 0]
+ + a[..., 2, 2] * b[..., 2, 0],
+ a[..., 2, 0] * b[..., 0, 1]
+ + a[..., 2, 1] * b[..., 1, 1]
+ + a[..., 2, 2] * b[..., 2, 1],
+ a[..., 2, 0] * b[..., 0, 2]
+ + a[..., 2, 1] * b[..., 1, 2]
+ + a[..., 2, 2] * b[..., 2, 2],
+ ],
+ dim=-1,
+ )
+
+ return torch.stack([row_1, row_2, row_3], dim=-2)
+
+
+def rot_vec_mul(
+ r: torch.Tensor,
+ t: torch.Tensor
+) -> torch.Tensor:
+ """
+ Applies a rotation to a vector. Written out by hand to avoid transfer
+ to avoid AMP downcasting.
+
+ Args:
+ r: [*, 3, 3] rotation matrices
+ t: [*, 3] coordinate tensors
+ Returns:
+ [*, 3] rotated coordinates
+ """
+ x = t[..., 0]
+ y = t[..., 1]
+ z = t[..., 2]
+ return torch.stack(
+ [
+ r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
+ r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
+ r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
+ ],
+ dim=-1,
+ )
+
+
+def identity_rot_mats(
+ batch_dims: Tuple[int],
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ requires_grad: bool = True,
+) -> torch.Tensor:
+ rots = torch.eye(
+ 3, dtype=dtype, device=device, requires_grad=requires_grad
+ )
+ rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
+ rots = rots.expand(*batch_dims, -1, -1)
+
+ return rots
+
+
+def identity_trans(
+ batch_dims: Tuple[int],
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ requires_grad: bool = True,
+) -> torch.Tensor:
+ trans = torch.zeros(
+ (*batch_dims, 3),
+ dtype=dtype,
+ device=device,
+ requires_grad=requires_grad
+ )
+ return trans
+
+
+def identity_quats(
+ batch_dims: Tuple[int],
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ requires_grad: bool = True,
+) -> torch.Tensor:
+ quat = torch.zeros(
+ (*batch_dims, 4),
+ dtype=dtype,
+ device=device,
+ requires_grad=requires_grad
+ )
+
+ with torch.no_grad():
+ quat[..., 0] = 1
+
+ return quat
+
+
+_quat_elements = ["a", "b", "c", "d"]
+_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
+_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
+
+
+def _to_mat(pairs):
+ mat = np.zeros((4, 4))
+ for pair in pairs:
+ key, value = pair
+ ind = _qtr_ind_dict[key]
+ mat[ind // 4][ind % 4] = value
+
+ return mat
+
+
+_QTR_MAT = np.zeros((4, 4, 3, 3))
+_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
+_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
+_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
+_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
+_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
+_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
+_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
+_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
+_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
+
+
+def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
+ """
+ Converts a quaternion to a rotation matrix.
+
+ Args:
+ quat: [*, 4] quaternions
+ Returns:
+ [*, 3, 3] rotation matrices
+ """
+ # [*, 4, 4]
+ quat = quat[..., None] * quat[..., None, :]
+
+ # [4, 4, 3, 3]
+ mat = quat.new_tensor(_QTR_MAT, requires_grad=False)
+
+ # [*, 4, 4, 3, 3]
+ shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
+ quat = quat[..., None, None] * shaped_qtr_mat
+
+ # [*, 3, 3]
+ return torch.sum(quat, dim=(-3, -4))
+
+
+def rot_to_quat(
+ rot: torch.Tensor,
+):
+ if(rot.shape[-2:] != (3, 3)):
+ raise ValueError("Input rotation is incorrectly shaped")
+
+ rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
+ [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
+
+ k = [
+ [ xx + yy + zz, zy - yz, xz - zx, yx - xy,],
+ [ zy - yz, xx - yy - zz, xy + yx, xz + zx,],
+ [ xz - zx, xy + yx, yy - xx - zz, yz + zy,],
+ [ yx - xy, xz + zx, yz + zy, zz - xx - yy,]
+ ]
+
+ k = (1./3.) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)
+
+ _, vectors = torch.linalg.eigh(k)
+ return vectors[..., -1]
+
+
+_QUAT_MULTIPLY = np.zeros((4, 4, 4))
+_QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
+ [ 0,-1, 0, 0],
+ [ 0, 0,-1, 0],
+ [ 0, 0, 0,-1]]
+
+_QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
+ [ 1, 0, 0, 0],
+ [ 0, 0, 0, 1],
+ [ 0, 0,-1, 0]]
+
+_QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
+ [ 0, 0, 0,-1],
+ [ 1, 0, 0, 0],
+ [ 0, 1, 0, 0]]
+
+_QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
+ [ 0, 0, 1, 0],
+ [ 0,-1, 0, 0],
+ [ 1, 0, 0, 0]]
+
+_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
+
+
+def quat_multiply(quat1, quat2):
+ """Multiply a quaternion by another quaternion."""
+ mat = quat1.new_tensor(_QUAT_MULTIPLY)
+ reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
+ return torch.sum(
+ reshaped_mat *
+ quat1[..., :, None, None] *
+ quat2[..., None, :, None],
+ dim=(-3, -2)
+ )
+
+
+def quat_multiply_by_vec(quat, vec):
+ """Multiply a quaternion by a pure-vector quaternion."""
+ mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
+ reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
+ return torch.sum(
+ reshaped_mat *
+ quat[..., :, None, None] *
+ vec[..., None, :, None],
+ dim=(-3, -2)
+ )
+
+
+def invert_rot_mat(rot_mat: torch.Tensor):
+ return rot_mat.transpose(-1, -2)
+
+
+def invert_quat(quat: torch.Tensor):
+ quat_prime = quat.clone()
+ quat_prime[..., 1:] *= -1
+ inv = quat_prime / torch.sum(quat ** 2, dim=-1, keepdim=True)
+ return inv
+
+
+class Rotation:
+ """
+ A 3D rotation. Depending on how the object is initialized, the
+ rotation is represented by either a rotation matrix or a
+ quaternion, though both formats are made available by helper functions.
+ To simplify gradient computation, the underlying format of the
+ rotation cannot be changed in-place. Like Rigid, the class is designed
+ to mimic the behavior of a torch Tensor, almost as if each Rotation
+ object were a tensor of rotations, in one format or another.
+ """
+ def __init__(self,
+ rot_mats: Optional[torch.Tensor] = None,
+ quats: Optional[torch.Tensor] = None,
+ normalize_quats: bool = True,
+ ):
+ """
+ Args:
+ rot_mats:
+ A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
+ quats
+ quats:
+ A [*, 4] quaternion. Mutually exclusive with rot_mats. If
+ normalize_quats is not True, must be a unit quaternion
+ normalize_quats:
+ If quats is specified, whether to normalize quats
+ """
+ if((rot_mats is None and quats is None) or
+ (rot_mats is not None and quats is not None)):
+ raise ValueError("Exactly one input argument must be specified")
+
+ if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
+ (quats is not None and quats.shape[-1] != 4)):
+ raise ValueError(
+ "Incorrectly shaped rotation matrix or quaternion"
+ )
+
+ # Force full-precision
+ if(quats is not None):
+ quats = quats.type(torch.float32)
+ if(rot_mats is not None):
+ rot_mats = rot_mats.type(torch.float32)
+
+ if(quats is not None and normalize_quats):
+ quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
+
+ self._rot_mats = rot_mats
+ self._quats = quats
+
+ @staticmethod
+ def identity(
+ shape,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ requires_grad: bool = True,
+ fmt: str = "quat",
+ ):
+ """
+ Returns an identity Rotation.
+
+ Args:
+ shape:
+ The "shape" of the resulting Rotation object. See documentation
+ for the shape property
+ dtype:
+ The torch dtype for the rotation
+ device:
+ The torch device for the new rotation
+ requires_grad:
+ Whether the underlying tensors in the new rotation object
+ should require gradient computation
+ fmt:
+ One of "quat" or "rot_mat". Determines the underlying format
+ of the new object's rotation
+ Returns:
+ A new identity rotation
+ """
+ if(fmt == "rot_mat"):
+ rot_mats = identity_rot_mats(
+ shape, dtype, device, requires_grad,
+ )
+ return Rotation(rot_mats=rot_mats, quats=None)
+ elif(fmt == "quat"):
+ quats = identity_quats(shape, dtype, device, requires_grad)
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
+ else:
+ raise ValueError(f"Invalid format: f{fmt}")
+
+ # Magic methods
+
+ def __getitem__(self, index: Any):
+ """
+ Allows torch-style indexing over the virtual shape of the rotation
+ object. See documentation for the shape property.
+
+ Args:
+ index:
+ A torch index. E.g. (1, 3, 2), or (slice(None,))
+ Returns:
+ The indexed rotation
+ """
+ if type(index) != tuple:
+ index = (index,)
+
+ if(self._rot_mats is not None):
+ rot_mats = self._rot_mats[index + (slice(None), slice(None))]
+ return Rotation(rot_mats=rot_mats)
+ elif(self._quats is not None):
+ quats = self._quats[index + (slice(None),)]
+ return Rotation(quats=quats, normalize_quats=False)
+ else:
+ raise ValueError("Both rotations are None")
+
+ def __mul__(self,
+ right: torch.Tensor,
+ ):
+ """
+ Pointwise left multiplication of the rotation with a tensor. Can be
+ used to e.g. mask the Rotation.
+
+ Args:
+ right:
+ The tensor multiplicand
+ Returns:
+ The product
+ """
+ if not(isinstance(right, torch.Tensor)):
+ raise TypeError("The other multiplicand must be a Tensor")
+
+ if(self._rot_mats is not None):
+ rot_mats = self._rot_mats * right[..., None, None]
+ return Rotation(rot_mats=rot_mats, quats=None)
+ elif(self._quats is not None):
+ quats = self._quats * right[..., None]
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
+ else:
+ raise ValueError("Both rotations are None")
+
+ def __rmul__(self,
+ left: torch.Tensor,
+ ):
+ """
+ Reverse pointwise multiplication of the rotation with a tensor.
+
+ Args:
+ left:
+ The left multiplicand
+ Returns:
+ The product
+ """
+ return self.__mul__(left)
+
+ # Properties
+
+ @property
+ def shape(self) -> torch.Size:
+ """
+ Returns the virtual shape of the rotation object. This shape is
+ defined as the batch dimensions of the underlying rotation matrix
+ or quaternion. If the Rotation was initialized with a [10, 3, 3]
+ rotation matrix tensor, for example, the resulting shape would be
+ [10].
+
+ Returns:
+ The virtual shape of the rotation object
+ """
+ s = None
+ if(self._quats is not None):
+ s = self._quats.shape[:-1]
+ else:
+ s = self._rot_mats.shape[:-2]
+
+ return s
+
+ @property
+ def dtype(self) -> torch.dtype:
+ """
+ Returns the dtype of the underlying rotation.
+
+ Returns:
+ The dtype of the underlying rotation
+ """
+ if(self._rot_mats is not None):
+ return self._rot_mats.dtype
+ elif(self._quats is not None):
+ return self._quats.dtype
+ else:
+ raise ValueError("Both rotations are None")
+
+ @property
+ def device(self) -> torch.device:
+ """
+ The device of the underlying rotation
+
+ Returns:
+ The device of the underlying rotation
+ """
+ if(self._rot_mats is not None):
+ return self._rot_mats.device
+ elif(self._quats is not None):
+ return self._quats.device
+ else:
+ raise ValueError("Both rotations are None")
+
+ @property
+ def requires_grad(self) -> bool:
+ """
+ Returns the requires_grad property of the underlying rotation
+
+ Returns:
+ The requires_grad property of the underlying tensor
+ """
+ if(self._rot_mats is not None):
+ return self._rot_mats.requires_grad
+ elif(self._quats is not None):
+ return self._quats.requires_grad
+ else:
+ raise ValueError("Both rotations are None")
+
+ def get_rot_mats(self) -> torch.Tensor:
+ """
+ Returns the underlying rotation as a rotation matrix tensor.
+
+ Returns:
+ The rotation as a rotation matrix tensor
+ """
+ rot_mats = self._rot_mats
+ if(rot_mats is None):
+ if(self._quats is None):
+ raise ValueError("Both rotations are None")
+ else:
+ rot_mats = quat_to_rot(self._quats)
+
+ return rot_mats
+
+ def get_quats(self) -> torch.Tensor:
+ """
+ Returns the underlying rotation as a quaternion tensor.
+
+ Depending on whether the Rotation was initialized with a
+ quaternion, this function may call torch.linalg.eigh.
+
+ Returns:
+ The rotation as a quaternion tensor.
+ """
+ quats = self._quats
+ if(quats is None):
+ if(self._rot_mats is None):
+ raise ValueError("Both rotations are None")
+ else:
+ # quats = rot_to_quat(self._rot_mats)
+ quats = rotation3d.matrix_to_quaternion(self._rot_mats)
+
+ return quats
+
+ def get_cur_rot(self) -> torch.Tensor:
+ """
+ Return the underlying rotation in its current form
+
+ Returns:
+ The stored rotation
+ """
+ if(self._rot_mats is not None):
+ return self._rot_mats
+ elif(self._quats is not None):
+ return self._quats
+ else:
+ raise ValueError("Both rotations are None")
+
+ def get_rotvec(self, eps=1e-6) -> torch.Tensor:
+ """
+ Return the underlying axis-angle rotation vector.
+
+ Follow's scipy's implementation:
+ https://github.com/scipy/scipy/blob/HEAD/scipy/spatial/transform/_rotation.pyx#L1385-L1402
+
+ Returns:
+ The stored rotation as a axis-angle vector.
+ """
+ quat = self.get_quats()
+ # w > 0 to ensure 0 <= angle <= pi
+ flip = (quat[..., :1] < 0).float()
+ quat = (-1 * quat) * flip + (1 - flip) * quat
+
+ angle = 2 * torch.atan2(
+ torch.linalg.norm(quat[..., 1:], dim=-1),
+ quat[..., 0]
+ )
+
+ angle2 = angle * angle
+ small_angle_scales = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
+ large_angle_scales = angle / torch.sin(angle / 2 + eps)
+
+ small_angles = (angle <= 1e-3).float()
+ rot_vec_scale = small_angle_scales * small_angles + (1 - small_angles) * large_angle_scales
+ rot_vec = rot_vec_scale[..., None] * quat[..., 1:]
+ return rot_vec
+
+ # Rotation functions
+
+ def compose_q_update_vec(self,
+ q_update_vec: torch.Tensor,
+ normalize_quats: bool = True,
+ update_mask: torch.Tensor = None,
+ ):
+ """
+ Returns a new quaternion Rotation after updating the current
+ object's underlying rotation with a quaternion update, formatted
+ as a [*, 3] tensor whose final three columns represent x, y, z such
+ that (1, x, y, z) is the desired (not necessarily unit) quaternion
+ update.
+
+ Args:
+ q_update_vec:
+ A [*, 3] quaternion update tensor
+ normalize_quats:
+ Whether to normalize the output quaternion
+ Returns:
+ An updated Rotation
+ """
+ quats = self.get_quats()
+ quat_update = quat_multiply_by_vec(quats, q_update_vec)
+ if update_mask is not None:
+ quat_update = quat_update * update_mask
+ new_quats = quats + quat_update
+ return Rotation(
+ rot_mats=None,
+ quats=new_quats,
+ normalize_quats=normalize_quats,
+ )
+
+ def compose_r(self, r):
+ """
+ Compose the rotation matrices of the current Rotation object with
+ those of another.
+
+ Args:
+ r:
+ An update rotation object
+ Returns:
+ An updated rotation object
+ """
+ r1 = self.get_rot_mats()
+ r2 = r.get_rot_mats()
+ new_rot_mats = rot_matmul(r1, r2)
+ return Rotation(rot_mats=new_rot_mats, quats=None)
+
+ def compose_q(self, r, normalize_quats: bool = True):
+ """
+ Compose the quaternions of the current Rotation object with those
+ of another.
+
+ Depending on whether either Rotation was initialized with
+ quaternions, this function may call torch.linalg.eigh.
+
+ Args:
+ r:
+ An update rotation object
+ Returns:
+ An updated rotation object
+ """
+ q1 = self.get_quats()
+ q2 = r.get_quats()
+ new_quats = quat_multiply(q1, q2)
+ return Rotation(
+ rot_mats=None, quats=new_quats, normalize_quats=normalize_quats
+ )
+
+ def apply(self, pts: torch.Tensor) -> torch.Tensor:
+ """
+ Apply the current Rotation as a rotation matrix to a set of 3D
+ coordinates.
+
+ Args:
+ pts:
+ A [*, 3] set of points
+ Returns:
+ [*, 3] rotated points
+ """
+ rot_mats = self.get_rot_mats()
+ return rot_vec_mul(rot_mats, pts)
+
+ def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
+ """
+ The inverse of the apply() method.
+
+ Args:
+ pts:
+ A [*, 3] set of points
+ Returns:
+ [*, 3] inverse-rotated points
+ """
+ rot_mats = self.get_rot_mats()
+ inv_rot_mats = invert_rot_mat(rot_mats)
+ return rot_vec_mul(inv_rot_mats, pts)
+
+ def invert(self) :
+ """
+ Returns the inverse of the current Rotation.
+
+ Returns:
+ The inverse of the current Rotation
+ """
+ if(self._rot_mats is not None):
+ return Rotation(
+ rot_mats=invert_rot_mat(self._rot_mats),
+ quats=None
+ )
+ elif(self._quats is not None):
+ return Rotation(
+ rot_mats=None,
+ quats=invert_quat(self._quats),
+ normalize_quats=False,
+ )
+ else:
+ raise ValueError("Both rotations are None")
+
+ # "Tensor" stuff
+
+ def unsqueeze(self,
+ dim: int,
+ ):
+ """
+ Analogous to torch.unsqueeze. The dimension is relative to the
+ shape of the Rotation object.
+
+ Args:
+ dim: A positive or negative dimension index.
+ Returns:
+ The unsqueezed Rotation.
+ """
+ if dim >= len(self.shape):
+ raise ValueError("Invalid dimension")
+
+ if(self._rot_mats is not None):
+ rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
+ return Rotation(rot_mats=rot_mats, quats=None)
+ elif(self._quats is not None):
+ quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
+ else:
+ raise ValueError("Both rotations are None")
+
+ @staticmethod
+ def cat(
+ rs,
+ dim: int,
+ ):
+ """
+ Concatenates rotations along one of the batch dimensions. Analogous
+ to torch.cat().
+
+ Note that the output of this operation is always a rotation matrix,
+ regardless of the format of input rotations.
+
+ Args:
+ rs:
+ A list of rotation objects
+ dim:
+ The dimension along which the rotations should be
+ concatenated
+ Returns:
+ A concatenated Rotation object in rotation matrix format
+ """
+ rot_mats = [r.get_rot_mats() for r in rs]
+ rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)
+
+ return Rotation(rot_mats=rot_mats, quats=None)
+
+ def map_tensor_fn(self,
+ fn
+ ):
+ """
+ Apply a Tensor -> Tensor function to underlying rotation tensors,
+ mapping over the rotation dimension(s). Can be used e.g. to sum out
+ a one-hot batch dimension.
+
+ Args:
+ fn:
+ A Tensor -> Tensor function to be mapped over the Rotation
+ Returns:
+ The transformed Rotation object
+ """
+ if(self._rot_mats is not None):
+ rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
+ rot_mats = torch.stack(
+ list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
+ )
+ rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
+ return Rotation(rot_mats=rot_mats, quats=None)
+ elif(self._quats is not None):
+ quats = torch.stack(
+ list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
+ )
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
+ else:
+ raise ValueError("Both rotations are None")
+
+ def cuda(self):
+ """
+ Analogous to the cuda() method of torch Tensors
+
+ Returns:
+ A copy of the Rotation in CUDA memory
+ """
+ if(self._rot_mats is not None):
+ return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
+ elif(self._quats is not None):
+ return Rotation(
+ rot_mats=None,
+ quats=self._quats.cuda(),
+ normalize_quats=False
+ )
+ else:
+ raise ValueError("Both rotations are None")
+
+ def to(self,
+ device: Optional[torch.device],
+ dtype: Optional[torch.dtype]
+ ):
+ """
+ Analogous to the to() method of torch Tensors
+
+ Args:
+ device:
+ A torch device
+ dtype:
+ A torch dtype
+ Returns:
+ A copy of the Rotation using the new device and dtype
+ """
+ if(self._rot_mats is not None):
+ return Rotation(
+ rot_mats=self._rot_mats.to(device=device, dtype=dtype),
+ quats=None,
+ )
+ elif(self._quats is not None):
+ return Rotation(
+ rot_mats=None,
+ quats=self._quats.to(device=device, dtype=dtype),
+ normalize_quats=False,
+ )
+ else:
+ raise ValueError("Both rotations are None")
+
+ def detach(self):
+ """
+ Returns a copy of the Rotation whose underlying Tensor has been
+ detached from its torch graph.
+
+ Returns:
+ A copy of the Rotation whose underlying Tensor has been detached
+ from its torch graph
+ """
+ if(self._rot_mats is not None):
+ return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
+ elif(self._quats is not None):
+ return Rotation(
+ rot_mats=None,
+ quats=self._quats.detach(),
+ normalize_quats=False,
+ )
+ else:
+ raise ValueError("Both rotations are None")
+
+
+class Rigid:
+ """
+ A class representing a rigid transformation. Little more than a wrapper
+ around two objects: a Rotation object and a [*, 3] translation
+ Designed to behave approximately like a single torch tensor with the
+ shape of the shared batch dimensions of its component parts.
+ """
+ def __init__(self,
+ rots: Optional[Rotation],
+ trans: Optional[torch.Tensor],
+ ):
+ """
+ Args:
+ rots: A [*, 3, 3] rotation tensor
+ trans: A corresponding [*, 3] translation tensor
+ """
+ # (we need device, dtype, etc. from at least one input)
+
+ batch_dims, dtype, device, requires_grad = None, None, None, None
+ if(trans is not None):
+ batch_dims = trans.shape[:-1]
+ dtype = trans.dtype
+ device = trans.device
+ requires_grad = trans.requires_grad
+ elif(rots is not None):
+ batch_dims = rots.shape
+ dtype = rots.dtype
+ device = rots.device
+ requires_grad = rots.requires_grad
+ else:
+ raise ValueError("At least one input argument must be specified")
+
+ if(rots is None):
+ rots = Rotation.identity(
+ batch_dims, dtype, device, requires_grad,
+ )
+ elif(trans is None):
+ trans = identity_trans(
+ batch_dims, dtype, device, requires_grad,
+ )
+
+ if((rots.shape != trans.shape[:-1]) or
+ (rots.device != trans.device)):
+ raise ValueError("Rots and trans incompatible")
+
+ # Force full precision. Happens to the rotations automatically.
+ trans = trans.type(torch.float32)
+
+ self._rots = rots
+ self._trans = trans
+
+ @staticmethod
+ def identity(
+ shape: Tuple[int],
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ requires_grad: bool = True,
+ fmt: str = "quat",
+ ):
+ """
+ Constructs an identity transformation.
+
+ Args:
+ shape:
+ The desired shape
+ dtype:
+ The dtype of both internal tensors
+ device:
+ The device of both internal tensors
+ requires_grad:
+ Whether grad should be enabled for the internal tensors
+ Returns:
+ The identity transformation
+ """
+ return Rigid(
+ Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
+ identity_trans(shape, dtype, device, requires_grad),
+ )
+
+ def __getitem__(self,
+ index: Any,
+ ):
+ """
+ Indexes the affine transformation with PyTorch-style indices.
+ The index is applied to the shared dimensions of both the rotation
+ and the translation.
+
+ E.g.::
+
+ r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
+ t = Rigid(r, torch.rand(10, 10, 3))
+ indexed = t[3, 4:6]
+ assert(indexed.shape == (2,))
+ assert(indexed.get_rots().shape == (2,))
+ assert(indexed.get_trans().shape == (2, 3))
+
+ Args:
+ index: A standard torch tensor index. E.g. 8, (10, None, 3),
+ or (3, slice(0, 1, None))
+ Returns:
+ The indexed tensor
+ """
+ if type(index) != tuple:
+ index = (index,)
+
+ return Rigid(
+ self._rots[index],
+ self._trans[index + (slice(None),)],
+ )
+
+ def __mul__(self,
+ right: torch.Tensor,
+ ):
+ """
+ Pointwise left multiplication of the transformation with a tensor.
+ Can be used to e.g. mask the Rigid.
+
+ Args:
+ right:
+ The tensor multiplicand
+ Returns:
+ The product
+ """
+ if not(isinstance(right, torch.Tensor)):
+ raise TypeError("The other multiplicand must be a Tensor")
+
+ new_rots = self._rots * right
+ new_trans = self._trans * right[..., None]
+
+ return Rigid(new_rots, new_trans)
+
+ def __rmul__(self,
+ left: torch.Tensor,
+ ):
+ """
+ Reverse pointwise multiplication of the transformation with a
+ tensor.
+
+ Args:
+ left:
+ The left multiplicand
+ Returns:
+ The product
+ """
+ return self.__mul__(left)
+
+ @property
+ def shape(self) -> torch.Size:
+ """
+ Returns the shape of the shared dimensions of the rotation and
+ the translation.
+
+ Returns:
+ The shape of the transformation
+ """
+ s = self._trans.shape[:-1]
+ return s
+
+ @property
+ def device(self) -> torch.device:
+ """
+ Returns the device on which the Rigid's tensors are located.
+
+ Returns:
+ The device on which the Rigid's tensors are located
+ """
+ return self._trans.device
+
+ def get_rots(self) -> Rotation:
+ """
+ Getter for the rotation.
+
+ Returns:
+ The rotation object
+ """
+ return self._rots
+
+ def get_trans(self) -> torch.Tensor:
+ """
+ Getter for the translation.
+
+ Returns:
+ The stored translation
+ """
+ return self._trans
+
+ def compose_q_update_vec(self,
+ q_update_vec: torch.Tensor,
+ update_mask: torch.Tensor=None,
+ ):
+ """
+ Composes the transformation with a quaternion update vector of
+ shape [*, 6], where the final 6 columns represent the x, y, and
+ z values of a quaternion of form (1, x, y, z) followed by a 3D
+ translation.
+
+ Args:
+ q_vec: The quaternion update vector.
+ Returns:
+ The composed transformation.
+ """
+ q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
+ new_rots = self._rots.compose_q_update_vec(
+ q_vec, update_mask=update_mask)
+
+ trans_update = self._rots.apply(t_vec)
+ if update_mask is not None:
+ trans_update = trans_update * update_mask
+ new_translation = self._trans + trans_update
+
+ return Rigid(new_rots, new_translation)
+
+ def compose(self,
+ r,
+ ):
+ """
+ Composes the current rigid object with another.
+
+ Args:
+ r:
+ Another Rigid object
+ Returns:
+ The composition of the two transformations
+ """
+ new_rot = self._rots.compose_r(r._rots)
+ new_trans = self._rots.apply(r._trans) + self._trans
+ return Rigid(new_rot, new_trans)
+
+ def compose_r(self,
+ rot,
+ order='right'
+ ):
+ """
+ Composes the current rigid object with another.
+
+ Args:
+ r:
+ Another Rigid object
+ order:
+ Order in which to perform rotation multiplication.
+ Returns:
+ The composition of the two transformations
+ """
+ if order == 'right':
+ new_rot = self._rots.compose_r(rot)
+ elif order == 'left':
+ new_rot = rot.compose_r(self._rots)
+ else:
+ raise ValueError(f'Unrecognized multiplication order: {order}')
+ return Rigid(new_rot, self._trans)
+
+ def apply(self,
+ pts: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Applies the transformation to a coordinate tensor.
+
+ Args:
+ pts: A [*, 3] coordinate tensor.
+ Returns:
+ The transformed points.
+ """
+ rotated = self._rots.apply(pts)
+ return rotated + self._trans
+
+ def invert_apply(self,
+ pts: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Applies the inverse of the transformation to a coordinate tensor.
+
+ Args:
+ pts: A [*, 3] coordinate tensor
+ Returns:
+ The transformed points.
+ """
+ pts = pts - self._trans
+ return self._rots.invert_apply(pts)
+
+ def invert(self):
+ """
+ Inverts the transformation.
+
+ Returns:
+ The inverse transformation.
+ """
+ rot_inv = self._rots.invert()
+ trn_inv = rot_inv.apply(self._trans)
+
+ return Rigid(rot_inv, -1 * trn_inv)
+
+ def map_tensor_fn(self,
+ fn
+ ):
+ """
+ Apply a Tensor -> Tensor function to underlying translation and
+ rotation tensors, mapping over the translation/rotation dimensions
+ respectively.
+
+ Args:
+ fn:
+ A Tensor -> Tensor function to be mapped over the Rigid
+ Returns:
+ The transformed Rigid object
+ """
+ new_rots = self._rots.map_tensor_fn(fn)
+ new_trans = torch.stack(
+ list(map(fn, torch.unbind(self._trans, dim=-1))),
+ dim=-1
+ )
+
+ return Rigid(new_rots, new_trans)
+
+ def to_tensor_4x4(self) -> torch.Tensor:
+ """
+ Converts a transformation to a homogenous transformation tensor.
+
+ Returns:
+ A [*, 4, 4] homogenous transformation tensor
+ """
+ tensor = self._trans.new_zeros((*self.shape, 4, 4))
+ tensor[..., :3, :3] = self._rots.get_rot_mats()
+ tensor[..., :3, 3] = self._trans
+ tensor[..., 3, 3] = 1
+ return tensor
+
+ @staticmethod
+ def from_tensor_4x4(
+ t: torch.Tensor
+ ):
+ """
+ Constructs a transformation from a homogenous transformation
+ tensor.
+
+ Args:
+ t: [*, 4, 4] homogenous transformation tensor
+ Returns:
+ T object with shape [*]
+ """
+ if(t.shape[-2:] != (4, 4)):
+ raise ValueError("Incorrectly shaped input tensor")
+
+ rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
+ trans = t[..., :3, 3]
+
+ return Rigid(rots, trans)
+
+ def to_tensor_7(self) -> torch.Tensor:
+ """
+ Converts a transformation to a tensor with 7 final columns, four
+ for the quaternion followed by three for the translation.
+
+ Returns:
+ A [*, 7] tensor representation of the transformation
+ """
+ tensor = self._trans.new_zeros((*self.shape, 7))
+ tensor[..., :4] = self._rots.get_quats()
+ tensor[..., 4:] = self._trans
+
+ return tensor
+
+ @staticmethod
+ def from_tensor_7(
+ t: torch.Tensor,
+ normalize_quats: bool = False,
+ ):
+ if(t.shape[-1] != 7):
+ raise ValueError("Incorrectly shaped input tensor")
+
+ quats, trans = t[..., :4], t[..., 4:]
+
+ rots = Rotation(
+ rot_mats=None,
+ quats=quats,
+ normalize_quats=normalize_quats
+ )
+
+ return Rigid(rots, trans)
+
+ @staticmethod
+ def from_3_points(
+ p_neg_x_axis: torch.Tensor,
+ origin: torch.Tensor,
+ p_xy_plane: torch.Tensor,
+ eps: float = 1e-8
+ ):
+ """
+ Implements algorithm 21. Constructs transformations from sets of 3
+ points using the Gram-Schmidt algorithm.
+
+ Args:
+ p_neg_x_axis: [*, 3] coordinates
+ origin: [*, 3] coordinates used as frame origins
+ p_xy_plane: [*, 3] coordinates
+ eps: Small epsilon value
+ Returns:
+ A transformation object of shape [*]
+ """
+ p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
+ origin = torch.unbind(origin, dim=-1)
+ p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
+
+ e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
+ e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
+
+ denom = torch.sqrt(sum((c * c for c in e0)) + eps)
+ e0 = [c / denom for c in e0]
+ dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
+ e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
+ denom = torch.sqrt(sum((c * c for c in e1)) + eps)
+ e1 = [c / denom for c in e1]
+ e2 = [
+ e0[1] * e1[2] - e0[2] * e1[1],
+ e0[2] * e1[0] - e0[0] * e1[2],
+ e0[0] * e1[1] - e0[1] * e1[0],
+ ]
+
+ rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
+ rots = rots.reshape(rots.shape[:-1] + (3, 3))
+
+ rot_obj = Rotation(rot_mats=rots, quats=None)
+
+ return Rigid(rot_obj, torch.stack(origin, dim=-1))
+
+ def unsqueeze(self,
+ dim: int,
+ ):
+ """
+ Analogous to torch.unsqueeze. The dimension is relative to the
+ shared dimensions of the rotation/translation.
+
+ Args:
+ dim: A positive or negative dimension index.
+ Returns:
+ The unsqueezed transformation.
+ """
+ if dim >= len(self.shape):
+ raise ValueError("Invalid dimension")
+ rots = self._rots.unsqueeze(dim)
+ trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
+
+ return Rigid(rots, trans)
+
+ @staticmethod
+ def cat(
+ ts,
+ dim: int,
+ ):
+ """
+ Concatenates transformations along a new dimension.
+
+ Args:
+ ts:
+ A list of T objects
+ dim:
+ The dimension along which the transformations should be
+ concatenated
+ Returns:
+ A concatenated transformation object
+ """
+ rots = Rotation.cat([t._rots for t in ts], dim)
+ trans = torch.cat(
+ [t._trans for t in ts], dim=dim if dim >= 0 else dim - 1
+ )
+
+ return Rigid(rots, trans)
+
+ def apply_rot_fn(self, fn):
+ """
+ Applies a Rotation -> Rotation function to the stored rotation
+ object.
+
+ Args:
+ fn: A function of type Rotation -> Rotation
+ Returns:
+ A transformation object with a transformed rotation.
+ """
+ return Rigid(fn(self._rots), self._trans)
+
+ def apply_trans_fn(self, fn):
+ """
+ Applies a Tensor -> Tensor function to the stored translation.
+
+ Args:
+ fn:
+ A function of type Tensor -> Tensor to be applied to the
+ translation
+ Returns:
+ A transformation object with a transformed translation.
+ """
+ return Rigid(self._rots, fn(self._trans))
+
+ def scale_translation(self, trans_scale_factor: float):
+ """
+ Scales the translation by a constant factor.
+
+ Args:
+ trans_scale_factor:
+ The constant factor
+ Returns:
+ A transformation object with a scaled translation.
+ """
+ fn = lambda t: t * trans_scale_factor
+ return self.apply_trans_fn(fn)
+
+ def stop_rot_gradient(self):
+ """
+ Detaches the underlying rotation object
+
+ Returns:
+ A transformation object with detached rotations
+ """
+ fn = lambda r: r.detach()
+ return self.apply_rot_fn(fn)
+
+ @staticmethod
+ def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
+ """
+ Returns a transformation object from reference coordinates.
+
+ Note that this method does not take care of symmetries. If you
+ provide the atom positions in the non-standard way, the N atom will
+ end up not at [-0.527250, 1.359329, 0.0] but instead at
+ [-0.527250, -1.359329, 0.0]. You need to take care of such cases in
+ your code.
+
+ Args:
+ n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
+ ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
+ c_xyz: A [*, 3] tensor of carbon xyz coordinates.
+ Returns:
+ A transformation object. After applying the translation and
+ rotation to the reference backbone, the coordinates will
+ approximately equal to the input coordinates.
+ """
+ translation = -1 * ca_xyz
+ n_xyz = n_xyz + translation
+ c_xyz = c_xyz + translation
+
+ c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
+ norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2)
+ sin_c1 = -c_y / norm
+ cos_c1 = c_x / norm
+ zeros = sin_c1.new_zeros(sin_c1.shape)
+ ones = sin_c1.new_ones(sin_c1.shape)
+
+ c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
+ c1_rots[..., 0, 0] = cos_c1
+ c1_rots[..., 0, 1] = -1 * sin_c1
+ c1_rots[..., 1, 0] = sin_c1
+ c1_rots[..., 1, 1] = cos_c1
+ c1_rots[..., 2, 2] = 1
+
+ norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2)
+ sin_c2 = c_z / norm
+ cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm
+
+ c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
+ c2_rots[..., 0, 0] = cos_c2
+ c2_rots[..., 0, 2] = sin_c2
+ c2_rots[..., 1, 1] = 1
+ c1_rots[..., 2, 0] = -1 * sin_c2
+ c1_rots[..., 2, 2] = cos_c2
+
+ c_rots = rot_matmul(c2_rots, c1_rots)
+ n_xyz = rot_vec_mul(c_rots, n_xyz)
+
+ _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
+ norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2)
+ sin_n = -n_z / norm
+ cos_n = n_y / norm
+
+ n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
+ n_rots[..., 0, 0] = 1
+ n_rots[..., 1, 1] = cos_n
+ n_rots[..., 1, 2] = -1 * sin_n
+ n_rots[..., 2, 1] = sin_n
+ n_rots[..., 2, 2] = cos_n
+
+ rots = rot_matmul(n_rots, c_rots)
+
+ rots = rots.transpose(-1, -2)
+ translation = -1 * translation
+
+ rot_obj = Rotation(rot_mats=rots, quats=None)
+
+ return Rigid(rot_obj, translation)
+
+ def cuda(self):
+ """
+ Moves the transformation object to GPU memory
+
+ Returns:
+ A version of the transformation on GPU
+ """
+ return Rigid(self._rots.cuda(), self._trans.cuda())
diff --git a/analysis/src/common/rotation3d.py b/analysis/src/common/rotation3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d9febabf302b0b81d8b97cece719570aeb6402c
--- /dev/null
+++ b/analysis/src/common/rotation3d.py
@@ -0,0 +1,596 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Union
+
+import torch
+from torch.nn import functional as F
+
+Device = Union[str, torch.device]
+
+
+"""
+The transformation matrices returned from the functions in this file assume
+the points on which the transformation will be applied are column vectors.
+i.e. the R matrix is structured as
+
+ R = [
+ [Rxx, Rxy, Rxz],
+ [Ryx, Ryy, Ryz],
+ [Rzx, Rzy, Rzz],
+ ] # (3, 3)
+
+This matrix can be applied to column vectors by post multiplication
+by the points e.g.
+
+ points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
+ transformed_points = R * points
+
+To apply the same matrix to points which are row vectors, the R matrix
+can be transposed and pre multiplied by the points:
+
+e.g.
+ points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
+ transformed_points = points * R.transpose(1, 0)
+"""
+
+
+def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ r, i, j, k = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ """
+ Return a tensor where each element has the absolute value taken from the,
+ corresponding element of a, with sign taken from the corresponding
+ element of b. This is like the standard copysign floating-point operation,
+ but is not careful about negative 0 and NaN.
+
+ Args:
+ a: source tensor.
+ b: tensor whose signs will be used, of the same shape as a.
+
+ Returns:
+ Tensor of the same shape as a with the signs of b.
+ """
+ signs_differ = (a < 0) != (b < 0)
+ return torch.where(signs_differ, -a, a)
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ return ret
+
+
+def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+ matrix.reshape(batch_dim + (9,)), dim=-1
+ )
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+
+ return quat_candidates[
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
+ ].reshape(batch_dim + (4,))
+
+
+def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
+ """
+ Return the rotation matrices for one of the rotations about an axis
+ of which Euler angles describe, for each value of the angle given.
+
+ Args:
+ axis: Axis label "X" or "Y or "Z".
+ angle: any shape tensor of Euler angles in radians
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+
+ cos = torch.cos(angle)
+ sin = torch.sin(angle)
+ one = torch.ones_like(angle)
+ zero = torch.zeros_like(angle)
+
+ if axis == "X":
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
+ elif axis == "Y":
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
+ elif axis == "Z":
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
+ else:
+ raise ValueError("letter must be either X, Y or Z.")
+
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
+
+
+def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
+ """
+ Convert rotations given as Euler angles in radians to rotation matrices.
+
+ Args:
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
+ convention: Convention string of three uppercase letters from
+ {"X", "Y", and "Z"}.
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
+ raise ValueError("Invalid input euler angles.")
+ if len(convention) != 3:
+ raise ValueError("Convention must have 3 letters.")
+ if convention[1] in (convention[0], convention[2]):
+ raise ValueError(f"Invalid convention {convention}.")
+ for letter in convention:
+ if letter not in ("X", "Y", "Z"):
+ raise ValueError(f"Invalid letter {letter} in convention string.")
+ matrices = [
+ _axis_angle_rotation(c, e)
+ for c, e in zip(convention, torch.unbind(euler_angles, -1))
+ ]
+ # return functools.reduce(torch.matmul, matrices)
+ return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
+
+
+def _angle_from_tan(
+ axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
+) -> torch.Tensor:
+ """
+ Extract the first or third Euler angle from the two members of
+ the matrix which are positive constant times its sine and cosine.
+
+ Args:
+ axis: Axis label "X" or "Y or "Z" for the angle we are finding.
+ other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
+ convention.
+ data: Rotation matrices as tensor of shape (..., 3, 3).
+ horizontal: Whether we are looking for the angle for the third axis,
+ which means the relevant entries are in the same row of the
+ rotation matrix. If not, they are in the same column.
+ tait_bryan: Whether the first and third axes in the convention differ.
+
+ Returns:
+ Euler Angles in radians for each matrix in data as a tensor
+ of shape (...).
+ """
+
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
+ if horizontal:
+ i2, i1 = i1, i2
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
+ if horizontal == even:
+ return torch.atan2(data[..., i1], data[..., i2])
+ if tait_bryan:
+ return torch.atan2(-data[..., i2], data[..., i1])
+ return torch.atan2(data[..., i2], -data[..., i1])
+
+
+def _index_from_letter(letter: str) -> int:
+ if letter == "X":
+ return 0
+ if letter == "Y":
+ return 1
+ if letter == "Z":
+ return 2
+ raise ValueError("letter must be either X, Y or Z.")
+
+
+def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to Euler angles in radians.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+ convention: Convention string of three uppercase letters.
+
+ Returns:
+ Euler angles in radians as tensor of shape (..., 3).
+ """
+ if len(convention) != 3:
+ raise ValueError("Convention must have 3 letters.")
+ if convention[1] in (convention[0], convention[2]):
+ raise ValueError(f"Invalid convention {convention}.")
+ for letter in convention:
+ if letter not in ("X", "Y", "Z"):
+ raise ValueError(f"Invalid letter {letter} in convention string.")
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+ i0 = _index_from_letter(convention[0])
+ i2 = _index_from_letter(convention[2])
+ tait_bryan = i0 != i2
+ if tait_bryan:
+ central_angle = torch.asin(
+ matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
+ )
+ else:
+ central_angle = torch.acos(matrix[..., i0, i0])
+
+ o = (
+ _angle_from_tan(
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
+ ),
+ central_angle,
+ _angle_from_tan(
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
+ ),
+ )
+ return torch.stack(o, -1)
+
+
+def random_quaternions(
+ n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
+) -> torch.Tensor:
+ """
+ Generate random quaternions representing rotations,
+ i.e. versors with nonnegative real part.
+
+ Args:
+ n: Number of quaternions in a batch to return.
+ dtype: Type to return.
+ device: Desired device of returned tensor. Default:
+ uses the current device for the default tensor type.
+
+ Returns:
+ Quaternions as tensor of shape (N, 4).
+ """
+ if isinstance(device, str):
+ device = torch.device(device)
+ o = torch.randn((n, 4), dtype=dtype, device=device)
+ s = (o * o).sum(1)
+ o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
+ return o
+
+
+def random_rotations(
+ n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
+) -> torch.Tensor:
+ """
+ Generate random rotations as 3x3 rotation matrices.
+
+ Args:
+ n: Number of rotation matrices in a batch to return.
+ dtype: Type to return.
+ device: Device of returned tensor. Default: if None,
+ uses the current device for the default tensor type.
+
+ Returns:
+ Rotation matrices as tensor of shape (n, 3, 3).
+ """
+ quaternions = random_quaternions(n, dtype=dtype, device=device)
+ return quaternion_to_matrix(quaternions)
+
+
+def random_rotation(
+ dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
+) -> torch.Tensor:
+ """
+ Generate a single random 3x3 rotation matrix.
+
+ Args:
+ dtype: Type to return
+ device: Device of returned tensor. Default: if None,
+ uses the current device for the default tensor type
+
+ Returns:
+ Rotation matrix as tensor of shape (3, 3).
+ """
+ return random_rotations(1, dtype, device)[0]
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
+
+
+def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ """
+ Multiply two quaternions.
+ Usual torch rules for broadcasting apply.
+
+ Args:
+ a: Quaternions as tensor of shape (..., 4), real part first.
+ b: Quaternions as tensor of shape (..., 4), real part first.
+
+ Returns:
+ The product of a and b, a tensor of quaternions shape (..., 4).
+ """
+ aw, ax, ay, az = torch.unbind(a, -1)
+ bw, bx, by, bz = torch.unbind(b, -1)
+ ow = aw * bw - ax * bx - ay * by - az * bz
+ ox = aw * bx + ax * bw + ay * bz - az * by
+ oy = aw * by - ax * bz + ay * bw + az * bx
+ oz = aw * bz + ax * by - ay * bx + az * bw
+ return torch.stack((ow, ox, oy, oz), -1)
+
+
+def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ """
+ Multiply two quaternions representing rotations, returning the quaternion
+ representing their composition, i.e. the versor with nonnegative real part.
+ Usual torch rules for broadcasting apply.
+
+ Args:
+ a: Quaternions as tensor of shape (..., 4), real part first.
+ b: Quaternions as tensor of shape (..., 4), real part first.
+
+ Returns:
+ The product of a and b, a tensor of quaternions of shape (..., 4).
+ """
+ ab = quaternion_raw_multiply(a, b)
+ return standardize_quaternion(ab)
+
+
+def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
+ """
+ Given a quaternion representing rotation, get the quaternion representing
+ its inverse.
+
+ Args:
+ quaternion: Quaternions as tensor of shape (..., 4), with real part
+ first, which must be versors (unit quaternions).
+
+ Returns:
+ The inverse, a tensor of quaternions of shape (..., 4).
+ """
+
+ scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device)
+ return quaternion * scaling
+
+
+def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
+ """
+ Apply the rotation given by a quaternion to a 3D point.
+ Usual torch rules for broadcasting apply.
+
+ Args:
+ quaternion: Tensor of quaternions, real part first, of shape (..., 4).
+ point: Tensor of 3D points of shape (..., 3).
+
+ Returns:
+ Tensor of rotated points of shape (..., 3).
+ """
+ if point.size(-1) != 3:
+ raise ValueError(f"Points are not in 3D, {point.shape}.")
+ real_parts = point.new_zeros(point.shape[:-1] + (1,))
+ point_as_quaternion = torch.cat((real_parts, point), -1)
+ out = quaternion_raw_multiply(
+ quaternion_raw_multiply(quaternion, point_as_quaternion),
+ quaternion_invert(quaternion),
+ )
+ return out[..., 1:]
+
+
+def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as axis/angle to rotation matrices.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
+
+
+def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to axis/angle.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ Rotations given as a vector in axis angle form, as a tensor
+ of shape (..., 3), where the magnitude is the angle
+ turned anticlockwise in radians around the vector's
+ direction.
+ """
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
+
+
+def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as axis/angle to quaternions.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
+ half_angles = angles * 0.5
+ eps = 1e-6
+ small_angles = angles.abs() < eps
+ sin_half_angles_over_angles = torch.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+ )
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ )
+ quaternions = torch.cat(
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
+ )
+ return quaternions
+
+
+def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as quaternions to axis/angle.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotations given as a vector in axis angle form, as a tensor
+ of shape (..., 3), where the magnitude is the angle
+ turned anticlockwise in radians around the vector's
+ direction.
+ """
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
+ half_angles = torch.atan2(norms, quaternions[..., :1])
+ angles = 2 * half_angles
+ eps = 1e-6
+ small_angles = angles.abs() < eps
+ sin_half_angles_over_angles = torch.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+ )
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ )
+ return quaternions[..., 1:] / sin_half_angles_over_angles
+
+
+def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
+ """
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+ using Gram--Schmidt orthogonalization per Section B of [1].
+ Args:
+ d6: 6D rotation representation, of size (*, 6)
+
+ Returns:
+ batch of rotation matrices of size (*, 3, 3)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+
+ a1, a2 = d6[..., :3], d6[..., 3:]
+ b1 = F.normalize(a1, dim=-1)
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+ b2 = F.normalize(b2, dim=-1)
+ b3 = torch.cross(b1, b2, dim=-1)
+ return torch.stack((b1, b2, b3), dim=-2)
+
+
+def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
+ by dropping the last row. Note that 6D representation is not unique.
+ Args:
+ matrix: batch of rotation matrices of size (*, 3, 3)
+
+ Returns:
+ 6D rotation representation, of size (*, 6)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+ batch_dim = matrix.size()[:-2]
+ return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
\ No newline at end of file
diff --git a/analysis/src/data/__init__.py b/analysis/src/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/analysis/src/data/__pycache__/__init__.cpython-39.pyc b/analysis/src/data/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf56b942c18dc19acdf77b510065416f88e26b01
Binary files /dev/null and b/analysis/src/data/__pycache__/__init__.cpython-39.pyc differ
diff --git a/analysis/src/data/__pycache__/protein_datamodule.cpython-39.pyc b/analysis/src/data/__pycache__/protein_datamodule.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f87bba346055e0ee2e6c40316502089321e7bc0b
Binary files /dev/null and b/analysis/src/data/__pycache__/protein_datamodule.cpython-39.pyc differ
diff --git a/analysis/src/data/components/__init__.py b/analysis/src/data/components/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/analysis/src/data/components/__pycache__/__init__.cpython-39.pyc b/analysis/src/data/components/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4e615af160dca3703e79f7b957b6e93d97e25973
Binary files /dev/null and b/analysis/src/data/components/__pycache__/__init__.cpython-39.pyc differ
diff --git a/analysis/src/data/components/__pycache__/dataset.cpython-39.pyc b/analysis/src/data/components/__pycache__/dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..411ee67d87b0b1f4e2d5bb461244b9a0be259511
Binary files /dev/null and b/analysis/src/data/components/__pycache__/dataset.cpython-39.pyc differ
diff --git a/analysis/src/data/components/dataset.py b/analysis/src/data/components/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e619d9d9d2a230d315844580336152191b4b52d2
--- /dev/null
+++ b/analysis/src/data/components/dataset.py
@@ -0,0 +1,321 @@
+"""Protein dataset class."""
+import os
+import pickle
+from pathlib import Path
+from glob import glob
+from typing import Optional, Sequence, List, Union
+from functools import lru_cache
+import tree
+
+from tqdm import tqdm
+import numpy as np
+import pandas as pd
+import torch
+
+from src.common import residue_constants, data_transforms, rigid_utils, protein
+
+
+CA_IDX = residue_constants.atom_order['CA']
+DTYPE_MAPPING = {
+ 'aatype': torch.long,
+ 'atom_positions': torch.double,
+ 'atom_mask': torch.double,
+}
+
+
+class ProteinFeatureTransform:
+ def __init__(self,
+ unit: Optional[str] = 'angstrom',
+ truncate_length: Optional[int] = None,
+ strip_missing_residues: bool = True,
+ recenter_and_scale: bool = True,
+ eps: float = 1e-8,
+ ):
+ if unit == 'angstrom':
+ self.coordinate_scale = 1.0
+ elif unit in ('nm', 'nanometer'):
+ self.coordiante_scale = 0.1
+ else:
+ raise ValueError(f"Invalid unit: {unit}")
+
+ if truncate_length is not None:
+ assert truncate_length > 0, f"Invalid truncate_length: {truncate_length}"
+ self.truncate_length = truncate_length
+
+ self.strip_missing_residues = strip_missing_residues
+ self.recenter_and_scale = recenter_and_scale
+ self.eps = eps
+
+ def __call__(self, chain_feats):
+ chain_feats = self.patch_feats(chain_feats)
+
+ if self.strip_missing_residues:
+ chain_feats = self.strip_ends(chain_feats)
+
+ if self.truncate_length is not None:
+ chain_feats = self.random_truncate(chain_feats, max_len=self.truncate_length)
+
+ # Recenter and scale atom positions
+ if self.recenter_and_scale:
+ chain_feats = self.recenter_and_scale_coords(chain_feats, coordinate_scale=self.coordinate_scale, eps=self.eps)
+
+ # Map to torch Tensor
+ chain_feats = self.map_to_tensors(chain_feats)
+ # Add extra features from AF2
+ chain_feats = self.protein_data_transform(chain_feats)
+
+ # ** refer to line 170 in pdb_data_loader.py **
+ return chain_feats
+
+ @staticmethod
+ def patch_feats(chain_feats):
+ seq_mask = chain_feats['atom_mask'][:, CA_IDX] # a little hack here
+ # residue_idx = np.arange(seq_mask.shape[0], dtype=np.int64)
+ residue_idx = chain_feats['residue_index'] - np.min(chain_feats['residue_index']) # start from 0, possibly has chain break
+ patch_feats = {
+ 'seq_mask': seq_mask,
+ 'residue_mask': seq_mask,
+ 'residue_idx': residue_idx,
+ 'fixed_mask': np.zeros_like(seq_mask),
+ 'sc_ca_t': np.zeros(seq_mask.shape + (3, )),
+ }
+ chain_feats.update(patch_feats)
+ return chain_feats
+
+ @staticmethod
+ def strip_ends(chain_feats):
+ # Strip missing residues on both ends
+ modeled_idx = np.where(chain_feats['aatype'] != 20)[0]
+ min_idx, max_idx = np.min(modeled_idx), np.max(modeled_idx)
+ chain_feats = tree.map_structure(
+ lambda x: x[min_idx : (max_idx+1)], chain_feats)
+ return chain_feats
+
+ @staticmethod
+ def random_truncate(chain_feats, max_len):
+ L = chain_feats['aatype'].shape[0]
+ if L > max_len:
+ # Randomly truncate
+ start = np.random.randint(0, L - max_len + 1)
+ end = start + max_len
+ chain_feats = tree.map_structure(
+ lambda x: x[start : end], chain_feats)
+ return chain_feats
+
+ @staticmethod
+ def map_to_tensors(chain_feats):
+ chain_feats = {k: torch.as_tensor(v) for k,v in chain_feats.items()}
+ # Alter dtype
+ for k, dtype in DTYPE_MAPPING.items():
+ if k in chain_feats:
+ chain_feats[k] = chain_feats[k].type(dtype)
+ return chain_feats
+
+ @staticmethod
+ def recenter_and_scale_coords(chain_feats, coordinate_scale, eps=1e-8):
+ # recenter and scale atom positions
+ bb_pos = chain_feats['atom_positions'][:, CA_IDX]
+ bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['seq_mask']) + eps)
+ centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
+ scaled_pos = centered_pos * coordinate_scale
+ chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
+ return chain_feats
+
+ @staticmethod
+ def protein_data_transform(chain_feats):
+ chain_feats.update(
+ {
+ "all_atom_positions": chain_feats["atom_positions"],
+ "all_atom_mask": chain_feats["atom_mask"],
+ }
+ )
+ chain_feats = data_transforms.atom37_to_frames(chain_feats)
+ chain_feats = data_transforms.atom37_to_torsion_angles("")(chain_feats)
+ chain_feats = data_transforms.get_backbone_frames(chain_feats)
+ chain_feats = data_transforms.get_chi_angles(chain_feats)
+ chain_feats = data_transforms.make_pseudo_beta("")(chain_feats)
+ chain_feats = data_transforms.make_atom14_masks(chain_feats)
+ chain_feats = data_transforms.make_atom14_positions(chain_feats)
+
+ # Add convenient key
+ chain_feats.pop("all_atom_positions")
+ chain_feats.pop("all_atom_mask")
+ return chain_feats
+
+
+class MetadataFilter:
+ def __init__(self,
+ min_len: Optional[int] = None,
+ max_len: Optional[int] = None,
+ min_chains: Optional[int] = None,
+ max_chains: Optional[int] = None,
+ min_resolution: Optional[int] = None,
+ max_resolution: Optional[int] = None,
+ include_structure_method: Optional[List[str]] = None,
+ include_oligomeric_detail: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ self.min_len = min_len
+ self.max_len = max_len
+ self.min_chains = min_chains
+ self.max_chains = max_chains
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.include_structure_method = include_structure_method
+ self.include_oligomeric_detail = include_oligomeric_detail
+
+ def __call__(self, df):
+ _pre_filter_len = len(df)
+ if self.min_len is not None:
+ df = df[df['raw_seq_len'] >= self.min_len]
+ if self.max_len is not None:
+ df = df[df['raw_seq_len'] <= self.max_len]
+ if self.min_chains is not None:
+ df = df[df['num_chains'] >= self.min_chains]
+ if self.max_chains is not None:
+ df = df[df['num_chains'] <= self.max_chains]
+ if self.min_resolution is not None:
+ df = df[df['resolution'] >= self.min_resolution]
+ if self.max_resolution is not None:
+ df = df[df['resolution'] <= self.max_resolution]
+ if self.include_structure_method is not None:
+ df = df[df['include_structure_method'].isin(self.include_structure_method)]
+ if self.include_oligomeric_detail is not None:
+ df = df[df['include_oligomeric_detail'].isin(self.include_oligomeric_detail)]
+
+ print(f">>> Filter out {len(df)} samples out of {_pre_filter_len} by the metadata filter")
+ return df
+
+
+class RandomAccessProteinDataset(torch.utils.data.Dataset):
+ """Random access to pickle protein objects of dataset.
+
+ dict_keys(['atom_positions', 'aatype', 'atom_mask', 'residue_index', 'chain_index', 'b_factors'])
+
+ Note that each value is a ndarray in shape (L, *), for example:
+ 'atom_positions': (L, 37, 3)
+ """
+ def __init__(self,
+ path_to_dataset: Union[Path, str],
+ path_to_seq_embedding: Optional[Path] = None,
+ metadata_filter: Optional[MetadataFilter] = None,
+ training: bool = True,
+ transform: Optional[ProteinFeatureTransform] = None,
+ suffix: Optional[str] = '.pkl',
+ accession_code_fillter: Optional[Sequence[str]] = None,
+ **kwargs,
+ ):
+ super().__init__()
+ path_to_dataset = os.path.expanduser(path_to_dataset)
+ suffix = suffix if suffix.startswith('.') else '.' + suffix
+ assert suffix in ('.pkl', '.pdb'), f"Invalid suffix: {suffix}"
+
+ if os.path.isfile(path_to_dataset): # path to csv file
+ assert path_to_dataset.endswith('.csv'), f"Invalid file extension: {path_to_dataset} (have to be .csv)"
+ self._df = pd.read_csv(path_to_dataset)
+ self._df.sort_values('modeled_seq_len', ascending=False)
+ if metadata_filter:
+ self._df = metadata_filter(self._df)
+ self._data = self._df['processed_complex_path'].tolist()
+ elif os.path.isdir(path_to_dataset): # path to directory
+ self._data = sorted(glob(os.path.join(path_to_dataset, '*' + suffix)))
+ assert len(self._data) > 0, f"No {suffix} file found in '{path_to_dataset}'"
+ else: # path as glob pattern
+ _pattern = path_to_dataset
+ self._data = sorted(glob(_pattern))
+ assert len(self._data) > 0, f"No files found in '{_pattern}'"
+
+ if accession_code_fillter and len(accession_code_fillter) > 0:
+ self._data = [p for p in self._data
+ if np.isin(os.path.splitext(os.path.basename(p))[0], accession_code_fillter)
+ ]
+
+ self.data = np.asarray(self._data)
+ self.path_to_seq_embedding = os.path.expanduser(path_to_seq_embedding) \
+ if path_to_seq_embedding is not None else None
+ self.suffix = suffix
+ self.transform = transform
+ self.training = training # not implemented yet
+
+
+ @property
+ def num_samples(self):
+ return len(self.data)
+
+ def len(self):
+ return self.__len__()
+
+ def __len__(self):
+ return self.num_samples
+
+ def get(self, idx):
+ return self.__getitem__(idx)
+
+ @lru_cache(maxsize=100)
+ def __getitem__(self, idx):
+ """return single pyg.Data() instance
+ """
+ data_path = self.data[idx]
+ accession_code = os.path.splitext(os.path.basename(data_path))[0]
+
+ if self.suffix == '.pkl':
+ # Load pickled protein
+ with open(data_path, 'rb') as f:
+ data_object = pickle.load(f)
+ elif self.suffix == '.pdb':
+ # Load pdb file
+ with open(data_path, 'r') as f:
+ pdb_string = f.read()
+ data_object = protein.from_pdb_string(pdb_string).to_dict()
+
+ # Apply data transform
+ if self.transform is not None:
+ data_object = self.transform(data_object)
+
+ # Get sequence embedding if have
+ if self.path_to_seq_embedding is not None:
+ embed_dict = torch.load(
+ os.path.join(self.path_to_seq_embedding, f"{accession_code}.pt")
+ )
+ data_object.update(
+ {
+ 'seq_emb': embed_dict['representations'][33].float(),
+ } # 33 is for ESM650M
+ )
+
+ data_object['accession_code'] = accession_code
+ return data_object # dict of arrays
+
+
+
+class PretrainPDBDataset(RandomAccessProteinDataset):
+ def __init__(self,
+ path_to_dataset: str,
+ metadata_filter: MetadataFilter,
+ transform: ProteinFeatureTransform,
+ **kwargs,
+ ):
+ super(PretrainPDBDataset, self).__init__(path_to_dataset=path_to_dataset,
+ metadata_filter=metadata_filter,
+ transform=transform,
+ **kwargs,
+ )
+
+
+class SamplingPDBDataset(RandomAccessProteinDataset):
+ def __init__(self,
+ path_to_dataset: str,
+ training: bool = False,
+ suffix: str = '.pdb',
+ transform: Optional[ProteinFeatureTransform] = None,
+ accession_code_fillter: Optional[Sequence[str]] = None,
+ ):
+ assert os.path.isdir(path_to_dataset), f"Invalid path (expected to be directory): {path_to_dataset}"
+ super(SamplingPDBDataset, self).__init__(path_to_dataset=path_to_dataset,
+ training=training,
+ suffix=suffix,
+ transform=transform,
+ accession_code_fillter=accession_code_fillter,
+ metadata_filter=None,
+ )
+
\ No newline at end of file
diff --git a/analysis/src/data/protein_datamodule.py b/analysis/src/data/protein_datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4162f3ca670361054de47e2ad18b93da55c97bb
--- /dev/null
+++ b/analysis/src/data/protein_datamodule.py
@@ -0,0 +1,242 @@
+from typing import Any, Dict, Optional, Tuple, List, Sequence
+
+import torch
+from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
+from lightning import LightningDataModule
+from hydra.utils import instantiate
+
+
+class BatchTensorConverter:
+ """Callable to convert an unprocessed (labels + strings) batch to a
+ processed (labels + tensor) batch.
+ """
+ def __init__(self, target_keys: Optional[List] = None):
+ self.target_keys = target_keys
+
+ def __call__(self, raw_batch: Sequence[Dict[str, object]]):
+ B = len(raw_batch)
+ # Only do for Tensor
+ target_keys = self.target_keys \
+ if self.target_keys is not None else [k for k,v in raw_batch[0].items() if torch.is_tensor(v)]
+ # Non-array, for example string, int
+ non_array_keys = [k for k in raw_batch[0] if k not in target_keys]
+ collated_batch = dict()
+ for k in target_keys:
+ collated_batch[k] = self.collate_dense_tensors([d[k] for d in raw_batch], pad_v=0.0)
+ for k in non_array_keys: # return non-array keys as is
+ collated_batch[k] = [d[k] for d in raw_batch]
+ return collated_batch
+
+ @staticmethod
+ def collate_dense_tensors(samples: Sequence, pad_v: float = 0.0):
+ """
+ Takes a list of tensors with the following dimensions:
+ [(d_11, ..., d_1K),
+ (d_21, ..., d_2K),
+ ...,
+ (d_N1, ..., d_NK)]
+ and stack + pads them into a single tensor of:
+ (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})
+ """
+ if len(samples) == 0:
+ return torch.Tensor()
+ if len(set(x.dim() for x in samples)) != 1:
+ raise RuntimeError(
+ f"Samples has varying dimensions: {[x.dim() for x in samples]}"
+ )
+ (device,) = tuple(set(x.device for x in samples)) # assumes all on same device
+ max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
+ result = torch.empty(
+ len(samples), *max_shape, dtype=samples[0].dtype, device=device
+ )
+ result.fill_(pad_v)
+ for i in range(len(samples)):
+ result_i = result[i]
+ t = samples[i]
+ result_i[tuple(slice(0, k) for k in t.shape)] = t
+ return result
+
+
+class ProteinDataModule(LightningDataModule):
+ """`LightningDataModule` for a single protein dataset,
+ for pretrain or finetune purpose.
+
+ ### To be revised.###
+
+ The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples.
+ It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a
+ fixed-size image. The original black and white images from NIST were size normalized to fit in a 20x20 pixel box
+ while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing
+ technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of
+ mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
+
+ A `LightningDataModule` implements 7 key methods:
+
+ ```python
+ def prepare_data(self):
+ # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
+ # Download data, pre-process, split, save to disk, etc...
+
+ def setup(self, stage):
+ # Things to do on every process in DDP.
+ # Load data, set variables, etc...
+
+ def train_dataloader(self):
+ # return train dataloader
+
+ def val_dataloader(self):
+ # return validation dataloader
+
+ def test_dataloader(self):
+ # return test dataloader
+
+ def predict_dataloader(self):
+ # return predict dataloader
+
+ def teardown(self, stage):
+ # Called on every process in DDP.
+ # Clean up after fit or test.
+ ```
+
+ This allows you to share a full dataset without explaining how to download,
+ split, transform and process the data.
+
+ Read the docs:
+ https://lightning.ai/docs/pytorch/latest/data/datamodule.html
+ """
+
+ def __init__(
+ self,
+ dataset: torch.utils.data.Dataset,
+ batch_size: int = 64,
+ generator_seed: int = 42,
+ train_val_split: Tuple[float, float] = (0.95, 0.05),
+ num_workers: int = 0,
+ pin_memory: bool = False,
+ shuffle: bool = False,
+ ) -> None:
+ """Initialize a `MNISTDataModule`.
+
+ :param data_dir: The data directory. Defaults to `"data/"`.
+ :param train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`.
+ :param batch_size: The batch size. Defaults to `64`.
+ :param num_workers: The number of workers. Defaults to `0`.
+ :param pin_memory: Whether to pin memory. Defaults to `False`.
+ """
+ super().__init__()
+
+ # this line allows to access init params with 'self.hparams' attribute
+ # also ensures init params will be stored in ckpt
+ self.save_hyperparameters(logger=False)
+
+ self.dataset = dataset
+
+ self.data_train: Optional[Dataset] = None
+ self.data_val: Optional[Dataset] = None
+ self.data_test: Optional[Dataset] = None
+
+ self.batch_size_per_device = batch_size
+
+ def prepare_data(self) -> None:
+ """Download data if needed. Lightning ensures that `self.prepare_data()` is called only
+ within a single process on CPU, so you can safely add your downloading logic within. In
+ case of multi-node training, the execution of this hook depends upon
+ `self.prepare_data_per_node()`.
+
+ Do not use it to assign state (self.x = y).
+ """
+ pass
+
+ def setup(self, stage: Optional[str] = None) -> None:
+ """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+
+ This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
+ `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
+ `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
+ `self.setup()` once the data is prepared and available for use.
+
+ :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
+ """
+ # Divide batch size by the number of devices.
+ if self.trainer is not None:
+ if self.hparams.batch_size % self.trainer.world_size != 0:
+ raise RuntimeError(
+ f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
+ )
+ self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size
+
+ # load and split datasets only if not loaded already
+ if stage == 'fit' and not self.data_train and not self.data_val:
+ # dataset = ConcatDataset(datasets=[trainset, testset])
+ self.data_train, self.data_val = random_split(
+ dataset=self.dataset,
+ lengths=self.hparams.train_val_split,
+ generator=torch.Generator().manual_seed(self.hparams.generator_seed),
+ )
+ elif stage in ('predict', 'test'):
+ self.data_test = self.dataset
+ else:
+ raise NotImplementedError(f"Stage {stage} not implemented.")
+
+ def _dataloader_template(self, dataset: Dataset[Any]) -> DataLoader[Any]:
+ """Create a dataloader from a dataset.
+
+ :param dataset: The dataset.
+ :return: The dataloader.
+ """
+ batch_collator = BatchTensorConverter() # list of dicts -> dict of tensors
+ return DataLoader(
+ dataset=dataset,
+ collate_fn=batch_collator,
+ batch_size=self.batch_size_per_device,
+ num_workers=self.hparams.num_workers,
+ pin_memory=self.hparams.pin_memory,
+ shuffle=self.hparams.shuffle,
+ )
+
+ def train_dataloader(self) -> DataLoader[Any]:
+ """Create and return the train dataloader.
+
+ :return: The train dataloader.
+ """
+ return self._dataloader_template(self.data_train)
+
+
+ def val_dataloader(self) -> DataLoader[Any]:
+ """Create and return the validation dataloader.
+
+ :return: The validation dataloader.
+ """
+ return self._dataloader_template(self.data_val)
+
+ def test_dataloader(self) -> DataLoader[Any]:
+ """Create and return the test dataloader.
+
+ :return: The test dataloader.
+ """
+ return self._dataloader_template(self.data_test)
+
+ def teardown(self, stage: Optional[str] = None) -> None:
+ """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
+ `trainer.test()`, and `trainer.predict()`.
+
+ :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+ Defaults to ``None``.
+ """
+ pass
+
+ def state_dict(self) -> Dict[Any, Any]:
+ """Called when saving a checkpoint. Implement to generate and save the datamodule state.
+
+ :return: A dictionary containing the datamodule state that you want to save.
+ """
+ return {}
+
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+ """Called when loading a checkpoint. Implement to reload datamodule state given datamodule
+ `state_dict()`.
+
+ :param state_dict: The datamodule state returned by `self.state_dict()`.
+ """
+ pass
+
diff --git a/analysis/src/eval.py b/analysis/src/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a6e8462fb3c86c6ebee79d76614d16fee2d5b09
--- /dev/null
+++ b/analysis/src/eval.py
@@ -0,0 +1,217 @@
+from typing import Any, Dict, List, Tuple
+import os
+from time import strftime
+
+import numpy as np
+import pandas as pd
+import torch
+# import hydra
+# import rootutils
+# from lightning import LightningDataModule, LightningModule, Trainer
+# from lightning.pytorch.loggers import Logger
+from omegaconf import DictConfig
+
+# rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+# ------------------------------------------------------------------------------------ #
+# the setup_root above is equivalent to:
+# - adding project root dir to PYTHONPATH
+# (so you don't need to force user to install project as a package)
+# (necessary before importing any local modules e.g. `from src import utils`)
+# - setting up PROJECT_ROOT environment variable
+# (which is used as a base for paths in "configs/paths/default.yaml")
+# (this way all filepaths are the same no matter where you run the code)
+# - loading environment variables from ".env" in root dir
+#
+# you can remove it if you:
+# 1. either install project as a package or move entry files to project root dir
+# 2. set `root_dir` to "." in "configs/paths/default.yaml"
+#
+# more info: https://github.com/ashleve/rootutils
+# ------------------------------------------------------------------------------------ #
+
+from src.utils import (
+ RankedLogger,
+ extras,
+ instantiate_loggers,
+ log_hyperparameters,
+ task_wrapper,
+ checkpoint_utils,
+ plot_utils,
+)
+from src.common.pdb_utils import extract_backbone_coords
+from src.metrics import metrics
+from src.common.geo_utils import _find_rigid_alignment
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def evaluate_prediction(pred_dir: str, target_dir: str = None, crystal_dir: str = None, tag: str = None):
+ """Evaluate prediction results based on pdb files.
+ """
+ if target_dir is None or not os.path.isdir(target_dir):
+ log.warning(f"target_dir {target_dir} does not exist. Skip evaluation.")
+ return {}
+
+ assert os.path.isdir(pred_dir), f"pred_dir {pred_dir} is not a directory."
+
+ targets = [
+ d.replace(".pdb", "") for d in os.listdir(target_dir)
+ ]
+ # pred_bases = os.listdir(pred_dir)
+ output_dir = pred_dir
+ tag = tag if tag is not None else "dev"
+ timestamp = strftime("%m%d-%H-%M")
+
+ fns = {
+ 'val_clash': metrics.validity,
+ 'val_bond': metrics.bonding_validity,
+ 'js_pwd': metrics.js_pwd,
+ 'js_rg': metrics.js_rg,
+ # 'js_tica_pos': metrics.js_tica_pos,
+ 'w2_rmwd': metrics.w2_rmwd,
+ # 'div_rmsd': metrics.div_rmsd,
+ 'div_rmsf': metrics.div_rmsf,
+ 'pro_w_contacks': metrics.pro_w_contacts,
+ 'pro_t_contacks': metrics.pro_t_contacts,
+ # 'pro_c_contacks': metrics.pro_c_contacts,
+ }
+ eval_res = {k: {} for k in fns}
+
+
+ print(f"total_md_num = {len(targets)}")
+ count = 0
+ for target in targets:
+ count += 1
+ print("")
+ print(count, target)
+ pred_file = os.path.join(pred_dir, f"{target}.pdb")
+ # assert os.path.isfile(pred_file), f"pred_file {pred_file} does not exist."
+ if not os.path.isfile(pred_file):
+ continue
+
+ target_file = os.path.join(target_dir, f"{target}.pdb")
+ ca_coords = {
+ 'target': extract_backbone_coords(target_file),
+ 'pred': extract_backbone_coords(pred_file),
+ }
+ cry_target_file = os.path.join(crystal_dir, f"{target}.pdb")
+ cry_ca_coords = extract_backbone_coords(cry_target_file)[0]
+
+
+ for f_name, func in fns.items():
+ print(f_name)
+
+
+ if f_name == 'w2_rmwd':
+ v_ref = torch.as_tensor(ca_coords['target'][0])
+ for k, v in ca_coords.items():
+ v = torch.as_tensor(v) # (250,356,3)
+ for idx in range(v.shape[0]):
+ R, t = _find_rigid_alignment(v[idx], v_ref)
+ v[idx] = (torch.matmul(R, v[idx].transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0)
+ ca_coords[k] = v.numpy()
+
+
+ if f_name.startswith('js_'):
+ res = func(ca_coords, ref_key='target')
+ elif f_name == 'pro_c_contacks':
+ res = func(target_file, pred_file, cry_target_file)
+ elif f_name.startswith('pro_'):
+ res = func(ca_coords, cry_ca_coords)
+ else:
+ res = func(ca_coords)
+
+ if f_name == 'js_tica' or f_name == 'js_tica_pos':
+ pass
+ # eval_res[f_name][target] = res[0]['pred']
+ # save_to = os.path.join(output_dir, f"tica_{target}_{tag}_{timestamp}.png")
+ # plot_utils.scatterplot_2d(res[1], save_to=save_to, ref_key='target')
+ else:
+ eval_res[f_name][target] = res['pred']
+
+ csv_save_to = os.path.join(output_dir, f"metrics_{tag}_{timestamp}.csv")
+ df = pd.DataFrame.from_dict(eval_res) # row = target, col = metric name
+ df.to_csv(csv_save_to)
+ print(f"metrics saved to {csv_save_to}")
+ mean_metrics = np.around(df.mean(), decimals=4)
+
+ return mean_metrics
+
+
+# @task_wrapper
+# def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+# """Sample on a test set and report evaluation metrics.
+
+# This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+# failure. Useful for multiruns, saving info about the crash, etc.
+
+# :param cfg: DictConfig configuration composed by Hydra.
+# :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
+# """
+# # assert cfg.ckpt_path
+# pred_dir = cfg.get("pred_dir")
+# if pred_dir and os.path.isdir(pred_dir):
+# log.info(f"Found pre-computed prediction directory {pred_dir}.")
+# metric_dict = evaluate_prediction(pred_dir, target_dir=cfg.target_dir)
+# return metric_dict, None
+
+# log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+# datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+# log.info(f"Instantiating model <{cfg.model._target_}>")
+# model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+# log.info("Instantiating loggers...")
+# logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
+
+# log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+# trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
+
+# object_dict = {
+# "cfg": cfg,
+# "datamodule": datamodule,
+# "model": model,
+# "logger": logger,
+# "trainer": trainer,
+# }
+
+# if logger:
+# log.info("Logging hyperparameters!")
+# log_hyperparameters(object_dict)
+
+# # Load checkpoint manually.
+# model, ckpt_path = checkpoint_utils.load_model_checkpoint(model, cfg.ckpt_path)
+
+# # log.info("Starting testing!")
+# # trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
+
+# # Get dataloader for prediction.
+# datamodule.setup(stage="predict")
+# dataloaders = datamodule.test_dataloader()
+
+# log.info("Starting predictions.")
+# pred_dir = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=ckpt_path)[-1]
+
+# # metric_dict = trainer.callback_metrics
+# log.info("Starting evaluations.")
+# metric_dict = evaluate_prediction(pred_dir, target_dir=cfg.target_dir)
+
+# return metric_dict, object_dict
+
+
+# @hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
+# def main(cfg: DictConfig) -> None:
+# """Main entry point for evaluation.
+
+# :param cfg: DictConfig configuration composed by Hydra.
+# """
+# # apply extra utilities
+# # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+# extras(cfg)
+
+# evaluate(cfg)
+
+
+# if __name__ == "__main__":
+# main()
diff --git a/analysis/src/metrics/__init__.py b/analysis/src/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/analysis/src/metrics/__pycache__/__init__.cpython-310.pyc b/analysis/src/metrics/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..943786a20c0c58e337532cf097a1e607873eeb48
Binary files /dev/null and b/analysis/src/metrics/__pycache__/__init__.cpython-310.pyc differ
diff --git a/analysis/src/metrics/__pycache__/__init__.cpython-39.pyc b/analysis/src/metrics/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62f483500f3111605d7d3dfc16fc71594d0a2aa1
Binary files /dev/null and b/analysis/src/metrics/__pycache__/__init__.cpython-39.pyc differ
diff --git a/analysis/src/metrics/__pycache__/metrics.cpython-310.pyc b/analysis/src/metrics/__pycache__/metrics.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1a62a83bd62b9161c18369ceb843570d75934b5
Binary files /dev/null and b/analysis/src/metrics/__pycache__/metrics.cpython-310.pyc differ
diff --git a/analysis/src/metrics/__pycache__/metrics.cpython-39.pyc b/analysis/src/metrics/__pycache__/metrics.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bacc6af013cf47b8708b0305d26ad5af468f777b
Binary files /dev/null and b/analysis/src/metrics/__pycache__/metrics.cpython-39.pyc differ
diff --git a/analysis/src/metrics/metrics.py b/analysis/src/metrics/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..01a1e521f38f018ea1e814bdc10a3a589191c588
--- /dev/null
+++ b/analysis/src/metrics/metrics.py
@@ -0,0 +1,550 @@
+import os
+from typing import *
+
+import numpy as np
+import torch
+from scipy.spatial import distance
+# from deeptime.decomposition import TICA
+from src.common.geo_utils import rmsd, _find_rigid_alignment, squared_deviation
+from scipy.linalg import fractional_matrix_power
+from sklearn.mixture import GaussianMixture
+from Bio.PDB import PDBParser
+# import freesasa
+from Bio.PDB.Polypeptide import PPBuilder
+import multiprocessing as mp
+
+EPS = 1e-12
+PSEUDO_C = 1e-6
+
+
+def adjacent_ca_distance(coords):
+ """Calculate distance array for a single chain of CA atoms. Only k=1 neighbors.
+ Args:
+ coords: (..., L, 3)
+ return
+ dist: (..., L-1)
+ """
+ assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}" # (B, L, 3)
+ dX = coords[..., :-1, :] - coords[..., 1:, :] # (..., L-1, 3)
+ dist = np.sqrt(np.sum(dX**2, axis=-1))
+ return dist # (..., L-1)
+
+
+def distance_matrix_ca(coords):
+ """Calculate distance matrix for a single chain of CA atoms. W/o exclude neighbors.
+ Args:
+ coords: (..., L, 3)
+ Return:
+ dist: (..., L, L)
+ """
+ assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}" # (B, L, 3)
+ dX = coords[..., None, :, :] - coords[..., None, :] # (..., L, L, 3)
+ dist = np.sqrt(np.sum(dX**2, axis=-1))
+ return dist # (..., L, L)
+
+
+def pairwise_distance_ca(coords, k=1):
+ """Calculate pairwise distance vector for a single chain of CA atoms. W/o exclude neighbors.
+ Args:
+ coords: (..., L, 3)
+ Return:
+ dist: (..., D) (D=L * (L - 1) // 2) when k=1)
+ """
+ assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}" # (B, L, 3)
+ dist = distance_matrix_ca(coords)
+ L = dist.shape[-1]
+ row, col = np.triu_indices(L, k=k)
+ triu = dist[..., row, col] # unified (but unclear) order
+ return triu # (..., D)
+
+
+def radius_of_gyration(coords, masses=None):
+ """Compute the radius of gyration for every frame.
+
+ Args:
+ coords: (..., num_atoms, 3)
+ masses: (num_atoms,)
+
+ Returns:
+ Rg: (..., )
+
+ If masses are none, assumes equal masses.
+ """
+ assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}" # (B, L, 3)
+
+ if masses is None:
+ masses = np.ones(coords.shape[-2])
+ else:
+ assert len(masses.shape) == 1, f"masses should be 1D, got {masses.shape}"
+ assert masses.shape[0] == coords.shape[-2], f"masses {masses.shape} != number of particles {coords.shape[-2]}"
+
+ weights = masses / masses.sum()
+ centered = coords - coords.mean(-2, keepdims=True)
+ squared_dists = (centered ** 2).sum(-1)
+ Rg = (squared_dists * weights).sum(-1) ** 0.5
+ return Rg
+
+
+def _steric_clash(coords, ca_vdw_radius=1.7, allowable_overlap=0.4, k_exclusion=0):
+ """ https://www.schrodinger.com/sites/default/files/s3/public/python_api/2022-3/_modules/schrodinger/structutils/interactions/steric_clash.html#clash_iterator
+ Calculate the number of clashes in a single chain of CA atoms.
+
+ Usage:
+ n_clash = calc_clash(coords)
+
+ Args:
+ coords: (n_atoms, 3), CA coordinates, coords should from one protein chain.
+ ca_vdw_radius: float, default 1.7.
+ allowable_overlap: float, default 0.4.
+ k_exclusion: int, default 0. Exclude neighbors within [i-k-1, i+k+1].
+
+ """
+ assert np.isnan(coords).sum() == 0, "coords should not contain nan"
+ assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}" # (B, L, 3)
+ assert k_exclusion >= 0, "k_exclusion should be non-negative"
+ bar = 2 * ca_vdw_radius - allowable_overlap
+ # L = len(coords)
+ # dist = np.sqrt(np.sum((coords[:L-k_exclusion, None, :] - coords[None, k_exclusion:, :])**2, axis=-1))
+ pwd = pairwise_distance_ca(coords, k=k_exclusion+1) # by default, only excluding self (k=1)
+
+ # print('val_clash')
+ # print(pwd.shape)
+ # print(pwd.max(),pwd.min())
+ # idx_min=-1
+ # smin=10
+ # for idx,pwd_single in enumerate(pwd):
+ # if pwd_single.min()0).mean() for k, v in n_clash.items()
+ # }
+ results = {
+ k: 1.0 - (v/num_residue).mean() for k, v in n_clash.items()
+ }
+
+ results = {k: np.around(v, decimals=4) for k, v in results.items()}
+ return results
+
+
+def bonding_validity(ca_coords_dict, ref_key='target', eps=1e-6):
+ """Calculate bonding dissociation validity of ensembles."""
+ adj_dist = {k: adjacent_ca_distance(v)
+ for k, v in ca_coords_dict.items()
+ }
+ thres = adj_dist[ref_key].max()+ 1e-6
+
+ # print('val_bond')
+ # print('target')
+ # print(adj_dist['target'].shape)
+ # print(adj_dist['target'].max(),adj_dist['target'].min())
+ # print('pred')
+ # print(adj_dist['pred'].shape)
+ # print(adj_dist['pred'].max(),adj_dist['pred'].min())
+
+ # idx_max=-1
+ # smax=0
+ # for idx,adj in enumerate(adj_dist['pred']):
+ # if adj.max()>smax:
+ # smax=adj.max()
+ # idx_max=idx+1
+ # print('smax=',smax)
+ # print('idx_max=',idx_max)
+ # max_index2 = np.argmax(adj_dist['pred'], axis=-1)
+ # print('res_max_index_all',max_index2+1)
+ # print('res_max=',max_index2[idx_max-1]+1)
+
+ # results = {
+ # k: (v < thres).all(-1).sum().item() / len(v)
+ # for k, v in adj_dist.items()
+ # }
+ results = {
+ k: (v < thres).mean()
+ for k, v in adj_dist.items()
+ }
+
+ results = {k: np.around(v, decimals=4) for k, v in results.items()}
+ return results
+
+
+def js_pwd(ca_coords_dict, ref_key='target', n_bins=50, pwd_offset=3, weights=None):
+ # n_bins = 50 follows idpGAN
+ # k=3 follows 2for1
+
+ ca_pwd = {
+ k: pairwise_distance_ca(v, k=pwd_offset) for k, v in ca_coords_dict.items()
+ } # (B, D)
+
+ if weights is None:
+ weights = {}
+ weights.update({k: np.ones(len(v)) for k,v in ca_coords_dict.items() if k not in weights})
+
+ d_min = ca_pwd[ref_key].min(axis=0) # (D, )
+ d_max = ca_pwd[ref_key].max(axis=0)
+ ca_pwd_binned = {
+ k: np.apply_along_axis(lambda a: np.histogram(a[:-2], bins=n_bins, weights=weights[k], range=(a[-2], a[-1]))[0]+PSEUDO_C, 0,
+ np.concatenate([v, d_min[None], d_max[None]], axis=0))
+ for k, v in ca_pwd.items()
+ } # (N_bins, D)-> (N_bins * D, )
+ # js divergence per channel and average
+ results = {k: distance.jensenshannon(v, ca_pwd_binned[ref_key], axis=0).mean()
+ for k, v in ca_pwd_binned.items() if k != ref_key}
+ results[ref_key] = 0.0
+ results = {k: np.around(v, decimals=4) for k, v in results.items()}
+ return results
+
+# def js_tica(ca_coords_dict, ref_key='target', n_bins=50, lagtime=20, return_tic=True, weights=None):
+# # n_bins = 50 follows idpGAN
+# ca_pwd = {
+# k: pairwise_distance_ca(v) for k, v in ca_coords_dict.items()
+# } # (B, D)
+
+# print('tica1', ca_pwd[ref_key].shape)
+# estimator = TICA(dim=2, lagtime=lagtime).fit(ca_pwd[ref_key])
+# print('tica2')
+# tica = estimator.fetch_model()
+# # dimension reduction into 2D
+# ca_dr2d = {
+# k: tica.transform(v) for k, v in ca_pwd.items()
+# }
+# if weights is None: weights = {}
+# weights.update({k: np.ones(len(v)) for k,v in ca_coords_dict.items() if k not in weights})
+
+# d_min = ca_dr2d[ref_key].min(axis=0) # (D, )
+# d_max = ca_dr2d[ref_key].max(axis=0)
+# ca_dr2d_binned = {
+# k: np.apply_along_axis(lambda a: np.histogram(a[:-2], bins=n_bins, weights=weights[k], range=(a[-2], a[-1]))[0]+PSEUDO_C, 0,
+# np.concatenate([v, d_min[None], d_max[None]], axis=0))
+# for k, v in ca_dr2d.items()
+# } # (N_bins, 2)
+# results = {k: distance.jensenshannon(v, ca_dr2d_binned[ref_key], axis=0).mean()
+# for k, v in ca_dr2d_binned.items() if k != ref_key}
+# results[ref_key] = 0.0
+# results = {k: np.around(v, decimals=4) for k, v in results.items()}
+# if return_tic:
+# return results, ca_dr2d
+# return results
+
+# def js_tica_pos(ca_coords_dict, ref_key='target', n_bins=50, lagtime=20, return_tic=True, weights=None):
+# # n_bins = 50 follows idpGAN
+# v_ref = torch.as_tensor(ca_coords_dict['target'][0])
+# for k, v in ca_coords_dict.items():
+# v = torch.as_tensor(v)
+# for idx in range(v.shape[0]):
+# R, t = _find_rigid_alignment(v[idx], v_ref)
+# v[idx] = (torch.matmul(R, v[idx].transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0)
+# ca_coords_dict[k] = v.numpy()
+
+# ca_pos = { k: v.reshape(v.shape[0],-1) for k, v in ca_coords_dict.items()} # (B, 3*N)
+
+# estimator = TICA(dim=2, lagtime=lagtime).fit(ca_pos[ref_key])
+# tica = estimator.fetch_model()
+# # dimension reduction into 2D
+# ca_dr2d = {
+# k: tica.transform(v) for k, v in ca_pos.items()
+# }
+# if weights is None: weights = {}
+# weights.update({k: np.ones(len(v)) for k,v in ca_coords_dict.items() if k not in weights})
+
+# d_min = ca_dr2d[ref_key].min(axis=0) # (D, )
+# d_max = ca_dr2d[ref_key].max(axis=0)
+# ca_dr2d_binned = {
+# k: np.apply_along_axis(lambda a: np.histogram(a[:-2], bins=n_bins, weights=weights[k], range=(a[-2], a[-1]))[0]+PSEUDO_C, 0,
+# np.concatenate([v, d_min[None], d_max[None]], axis=0))
+# for k, v in ca_dr2d.items()
+# } # (N_bins, 2)
+# results = {k: distance.jensenshannon(v, ca_dr2d_binned[ref_key], axis=0).mean()
+# for k, v in ca_dr2d_binned.items() if k != ref_key}
+# results[ref_key] = 0.0
+# results = {k: np.around(v, decimals=4) for k, v in results.items()}
+# if return_tic:
+# return results, ca_dr2d
+# return results
+
+def js_rg(ca_coords_dict, ref_key='target', n_bins=50, weights=None):
+ ca_rg = {
+ k: radius_of_gyration(v) for k, v in ca_coords_dict.items()
+ } # (B, )
+ if weights is None:
+ weights = {}
+ weights.update({k: np.ones(len(v)) for k,v in ca_coords_dict.items() if k not in weights})
+
+ d_min = ca_rg[ref_key].min() # (1, )
+ d_max = ca_rg[ref_key].max()
+ ca_rg_binned = {
+ k: np.histogram(v, bins=n_bins, weights=weights[k], range=(d_min, d_max))[0]+PSEUDO_C
+ for k, v in ca_rg.items()
+ } # (N_bins, )
+ # print("ca_rg_binned shape", {k: v.shape for k, v in ca_rg_binned.items()})
+ results = {k: distance.jensenshannon(v, ca_rg_binned[ref_key], axis=0).mean()
+ for k, v in ca_rg_binned.items() if k != ref_key}
+
+ results[ref_key] = 0.0
+ results = {k: np.around(v, decimals=4) for k, v in results.items()}
+ return results
+
+def div_rmsd(ca_coords_dict):
+ results = {}
+ for k, v in ca_coords_dict.items():
+
+ # print(k) # [target, pred]
+ # print(v.shape) # (25,356,3)
+ # only calculate Ca
+
+ v = torch.as_tensor(v)
+ # for idx in range(v.shape[0]):
+ # R, t = _find_rigid_alignment(v[idx], v[0])
+ # v[idx] = (torch.matmul(R, v[idx].transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0)
+
+ # v1 = v.numpy()
+ # count = v1.shape[0]
+ # rmsd_sum = np.sum(np.sqrt(np.sum((v1[:, None, :] - v1[None, :, :])**2, axis=-1)))
+
+ count = 0
+ rmsd_2_sum = 0
+ for coord1 in v:
+ for coord2 in v:
+ count += 1
+ rmsd_2_sum += squared_deviation(coord1,coord2,reduction='none') # (356,)
+
+ # with mp.Pool() as pool:
+ # res= pool.starmap(squared_deviation,[(coord1, coord2, 'none') for coord1 in v for coord2 in v])
+ # pool.close()
+ # pool.join()
+ # count = len(res)-v.shape[0]
+ # rmsd_2_sum = sum(res)
+
+ results[k]=torch.sqrt(rmsd_2_sum/count)
+ results[k]=np.around(float(torch.mean(results[k])), decimals=4)
+ results['pred'] = (results['pred']-results['target'])/results['target']
+ results = {k: np.around(v, decimals=4) for k, v in results.items()}
+ # print(result)
+ return results
+
+def div_rmsf(ca_coords_dict):
+ '''
+ 1D and 0D data
+ '''
+ results = {}
+ for k, v in ca_coords_dict.items():
+
+ v = torch.as_tensor(v) # (250,356,3)
+ # for idx in range(v.shape[0]):
+ # R, t = _find_rigid_alignment(v[idx], v[0])
+ # v[idx] = (torch.matmul(R, v[idx].transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0)
+
+ count = 0
+ rmsd_2_sum = 0
+ mean_str = torch.mean(v,dim = 0) # (356,3)
+ for coord1 in v:
+ count += 1
+ rmsd_2_sum += squared_deviation(coord1,mean_str,reduction='none') # (356,)
+
+ # count = v.shape[0]
+ # rmsd_2_sum = torch.sum(torch.norm(v - mean_str[None,...], dim=-1) ** 2)
+
+ # mean_str = torch.mean(v,dim = 0) # (356,3)
+ # with mp.Pool() as pool:
+ # res= pool.starmap(squared_deviation,[[(coord1, mean_str, 'none') for coord1 in v]])
+ # pool.close()
+ # pool.join()
+ # count = len(res)
+ # rmsd_2_sum = sum(res)
+
+ results[k]=torch.sqrt(rmsd_2_sum/count)
+ results[k]=np.around(float(torch.mean(results[k])), decimals=4)
+ # print(result[k])
+ results['pred'] = (results['pred']-results['target'])/results['target']
+ results = {k: np.around(v, decimals=4) for k, v in results.items()}
+ return results
+
+def w2_rmwd(ca_coords_dict):
+ result = {}
+ means_total = {}
+ covariances_total = {}
+ count = 0
+ v_ref = torch.as_tensor(ca_coords_dict['target'][0])
+ for k, v in ca_coords_dict.items():
+
+ v = torch.as_tensor(v)
+ # for idx in range(v.shape[0]):
+ # R, t = _find_rigid_alignment(v[idx], v_ref)
+ # v[idx] = (torch.matmul(R, v[idx].transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0)
+
+ means_total[k] = []
+ covariances_total[k] = []
+
+ for idx_residue in range(v.shape[1]):
+ gmm = GaussianMixture(n_components=1)
+ gmm.fit(v[:, idx_residue, :])
+ means = torch.as_tensor(gmm.means_[0]) # 形状为 (3,)
+ covariances = torch.as_tensor(gmm.covariances_[0]) # 形状为 (3, 3)
+
+ means_total[k].append(means)
+ covariances_total[k].append(covariances)
+ means_total[k] = torch.stack(means_total[k], dim=0) # (356, 3)
+ covariances_total[k] = torch.stack(covariances_total[k], dim=0) # (356, 3, 3)
+ # print(means_total[k].shape, covariances_total[k].shape)
+ # print(means_total[k][0], covariances_total[k][0])
+
+ sigma_1_2_sqrt = [torch.as_tensor(fractional_matrix_power(i, 0.5)) for i in torch.matmul(covariances_total['target'], covariances_total['pred'])]
+ sigma_1_2_sqrt = torch.stack(sigma_1_2_sqrt, dim=0)
+ sigma_trace = covariances_total['target'] + covariances_total['pred'] - 2 * sigma_1_2_sqrt
+ sigma_trace = [torch.trace(i) for i in sigma_trace]
+ sigma_trace = torch.stack(sigma_trace, dim=0)
+
+ result_1D = torch.sum((means_total['target'] - means_total['pred'])**2, dim=-1) + sigma_trace
+ result['pred'] = np.around(float(torch.mean(result_1D)), decimals=4)
+ # print(result['pred'])
+
+ return result
+
+def pro_w_contacts(ca_coords_dict, cry_ca_coords, dist_threshold = 8.0, percent_threshold = 0.1):
+ result = {}
+ w_contacts_total = {}
+
+ dist = distance_matrix_ca(cry_ca_coords)
+ L = dist.shape[-1]
+ row, col = np.triu_indices(L, k=1)
+ triu = dist[..., row, col] # (n*(n-1)/2)
+ w_contacts_crystall = (triu < dist_threshold)
+
+ for k, v in ca_coords_dict.items():
+
+ dist = distance_matrix_ca(v)
+
+ L = dist.shape[-1]
+ row, col = np.triu_indices(L, k=1)
+ triu = dist[..., row, col] # (b, n*(n-1)/2)
+
+ w_contacts = (torch.tensor(triu) > dist_threshold).type(torch.float32)
+ w_contacts = torch.mean(w_contacts, dim=0) # (n*(n-1)/2,)
+ w_contacts = w_contacts > percent_threshold
+
+ w_contacts_total[k] = w_contacts & w_contacts_crystall
+
+ jac_w_contacts = torch.sum(w_contacts_total['target'] & w_contacts_total['pred'])/torch.sum(w_contacts_total['target'] | w_contacts_total['pred'])
+ result['pred'] = np.around(float(jac_w_contacts), decimals=4)
+ # print(result['pred'])
+
+ return result
+
+def pro_t_contacts(ca_coords_dict, cry_ca_coords, dist_threshold = 8.0, percent_threshold = 0.1):
+ result = {}
+ w_contacts_total = {}
+
+ dist = distance_matrix_ca(cry_ca_coords)
+ L = dist.shape[-1]
+ row, col = np.triu_indices(L, k=1)
+ triu = dist[..., row, col] # (n*(n-1)/2)
+ w_contacts_crystall = (triu >= dist_threshold)
+
+ for k, v in ca_coords_dict.items():
+
+ dist = distance_matrix_ca(v)
+
+ L = dist.shape[-1]
+ row, col = np.triu_indices(L, k=1)
+ triu = dist[..., row, col] # (b, n*(n-1)/2)
+
+ w_contacts = (torch.tensor(triu) <= dist_threshold).type(torch.float32)
+ w_contacts = torch.mean(w_contacts, dim=0) # (n*(n-1)/2,)
+ w_contacts = w_contacts > percent_threshold
+
+ w_contacts_total[k] = w_contacts & w_contacts_crystall
+
+ jac_w_contacts = torch.sum(w_contacts_total['target'] & w_contacts_total['pred'])/torch.sum(w_contacts_total['target'] | w_contacts_total['pred'])
+ result['pred'] = np.around(float(jac_w_contacts), decimals=4)
+ # print(result['pred'])
+
+ return result
+
+# def pro_c_contacts(target_file, pred_file, cry_target_file, area_threshold = 2.0, percent_threshold = 0.1):
+# result = {}
+# c_contacts_total = {}
+
+# parser = PDBParser()
+# params = freesasa.Parameters({'algorithm': 'ShrakeRupley', 'probe-radius': 2.8})
+
+# structure_cry_target = parser.get_structure('cry_target', cry_target_file)
+# str_params = {'separate-chains': False, 'separate-models': True}
+# structure_target = freesasa.structureArray(target_file,str_params)
+# structure_pred = freesasa.structureArray(pred_file,str_params)
+
+
+# structure_cry_target = freesasa.structureFromBioPDB(structure_cry_target)
+# sasa = freesasa.calc(structure_cry_target,params)
+# residue_sasa = sasa.residueAreas()
+
+# c_contacts_crystall = []
+# # 打印每个残基的 SASA
+# for chain_id in residue_sasa:
+# for residue_id in residue_sasa[chain_id]:
+# # print(f"Chain {chain_id}, Residue {residue_id}: {residue_sasa[chain_id][residue_id].residueType}, area: {residue_sasa[chain_id][residue_id].total}")
+# c_contacts_crystall.append(residue_sasa[chain_id][residue_id].total < area_threshold)
+# c_contacts_crystall = torch.tensor(c_contacts_crystall)
+
+# c_contacts_target = 0
+# count = 0
+# for structure_temp in structure_target:
+# count += 1
+# sasa = freesasa.calc(structure_temp,params)
+# residue_sasa = sasa.residueAreas()
+
+# c_contacts_temp = []
+# # 打印每个残基的 SASA
+# for chain_id in residue_sasa:
+# for residue_id in residue_sasa[chain_id]:
+# # print(f"Chain {chain_id}, Residue {residue_id}: {residue_sasa[chain_id][residue_id].residueType}, area: {residue_sasa[chain_id][residue_id].total}")
+# c_contacts_temp.append(residue_sasa[chain_id][residue_id].total > area_threshold)
+# c_contacts_temp = torch.tensor(c_contacts_temp).type(torch.float32)
+# c_contacts_target += c_contacts_temp
+# c_contacts_target = c_contacts_target / count
+# c_contacts_total['target'] = (c_contacts_target > percent_threshold) & c_contacts_crystall
+
+
+# c_contacts_pred = 0
+# count = 0
+# for structure_temp in structure_pred:
+# count += 1
+# sasa = freesasa.calc(structure_temp,params)
+# residue_sasa = sasa.residueAreas()
+
+# c_contacts_temp = []
+# # 打印每个残基的 SASA
+# for chain_id in residue_sasa:
+# for residue_id in residue_sasa[chain_id]:
+# # print(f"Chain {chain_id}, Residue {residue_id}: {residue_sasa[chain_id][residue_id].residueType}, area: {residue_sasa[chain_id][residue_id].total}")
+# c_contacts_temp.append(residue_sasa[chain_id][residue_id].total > area_threshold)
+# c_contacts_temp = torch.tensor(c_contacts_temp).type(torch.float32)
+# c_contacts_pred += c_contacts_temp
+# c_contacts_pred = c_contacts_pred / count
+# c_contacts_total['pred'] = (c_contacts_pred > percent_threshold) & c_contacts_crystall
+
+# jac_w_contacts = torch.sum(c_contacts_total['target'] & c_contacts_total['pred'])/torch.sum(c_contacts_total['target'] | c_contacts_total['pred'])
+# result['pred'] = np.around(float(jac_w_contacts), decimals=4)
+# # print(jac_w_contacts)
+# return result
\ No newline at end of file
diff --git a/analysis/src/models/__init__.py b/analysis/src/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/analysis/src/models/__pycache__/__init__.cpython-39.pyc b/analysis/src/models/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..36259e9bc2d83c41aad93af61758d758f3108391
Binary files /dev/null and b/analysis/src/models/__pycache__/__init__.cpython-39.pyc differ
diff --git a/analysis/src/models/__pycache__/diffusion_module.cpython-39.pyc b/analysis/src/models/__pycache__/diffusion_module.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..851fa55c838e0e105cdf51b7f0c19e6a37538f78
Binary files /dev/null and b/analysis/src/models/__pycache__/diffusion_module.cpython-39.pyc differ
diff --git a/analysis/src/models/__pycache__/loss.cpython-39.pyc b/analysis/src/models/__pycache__/loss.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c562b4260ada6ae3604b36d9cc562a2c66dc0b34
Binary files /dev/null and b/analysis/src/models/__pycache__/loss.cpython-39.pyc differ
diff --git a/analysis/src/models/diffusion_module.py b/analysis/src/models/diffusion_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..671de760108128189ab3ef7d4b54187be039d16d
--- /dev/null
+++ b/analysis/src/models/diffusion_module.py
@@ -0,0 +1,408 @@
+import os
+from typing import Any, Dict, Tuple, Optional
+from random import random
+from copy import deepcopy
+
+import numpy as np
+import torch
+from lightning import LightningModule
+from torchmetrics import MinMetric, MeanMetric
+
+from src.models.score.frame import FrameDiffuser
+from src.models.loss import ScoreMatchingLoss
+from src.common.rigid_utils import Rigid
+from src.common.all_atom import compute_backbone
+from src.common.pdb_utils import atom37_to_pdb, merge_pdbfiles
+
+
+class DiffusionLitModule(LightningModule):
+ """Example of a `LightningModule` for denoising diffusion training.
+
+ A `LightningModule` implements 8 key methods:
+
+ ```python
+ def __init__(self):
+ # Define initialization code here.
+
+ def setup(self, stage):
+ # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
+ # This hook is called on every process when using DDP.
+
+ def training_step(self, batch, batch_idx):
+ # The complete training step.
+
+ def validation_step(self, batch, batch_idx):
+ # The complete validation step.
+
+ def test_step(self, batch, batch_idx):
+ # The complete test step.
+
+ def predict_step(self, batch, batch_idx):
+ # The complete predict step.
+
+ def configure_optimizers(self):
+ # Define and configure optimizers and LR schedulers.
+ ```
+
+ Docs:
+ https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
+ """
+
+ def __init__(
+ self,
+ net: torch.nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: torch.optim.lr_scheduler,
+ diffuser: FrameDiffuser,
+ loss: Dict[str, Any],
+ compile: bool,
+ inference: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ """Initialize a `MNISTLitModule`.
+
+ :param net: The model to train.
+ :param optimizer: The optimizer to use for training.
+ :param scheduler: The learning rate scheduler to use for training.
+ """
+ super().__init__()
+
+ # this line allows to access init params with 'self.hparams' attribute
+ # also ensures init params will be stored in ckpt
+ self.save_hyperparameters(logger=False)
+
+ # network and diffusion module
+ self.net = net
+ self.diffuser = diffuser
+
+ # loss function
+ self.loss = ScoreMatchingLoss(config=self.hparams.loss)
+
+ # for averaging loss across batches
+ self.train_loss = MeanMetric()
+ self.val_loss = MeanMetric()
+ # self.test_loss = MeanMetric()
+
+ # for tracking best so far validation accuracy
+ self.val_loss_best = MinMetric()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Perform a forward pass through the model `self.net`.
+ (Not actually used)
+
+ :param x: A tensor of images.
+ :return: A tensor of logits.
+ """
+ return self.net(x)
+
+ def on_train_start(self) -> None:
+ """Lightning hook that is called when training begins."""
+ # by default lightning executes validation step sanity checks before training starts,
+ # so it's worth to make sure validation metrics don't store results from these checks
+ self.val_loss.reset()
+ self.val_loss_best.reset()
+
+ def model_step(
+ self, batch: Tuple[torch.Tensor, torch.Tensor], training: Optional[bool] = True
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Perform a single model step on a batch of data.
+
+ :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
+
+ :return: A tuple containing (in order):
+ - A tensor of losses.
+ - A tensor of predictions.
+ - A tensor of target labels.
+ """
+ # preprocess by augmenting additional feats
+ rigids_0 = Rigid.from_tensor_4x4(batch['rigidgroups_gt_frames'][..., 0, :, :])
+ batch_size = rigids_0.shape[0]
+ t = (1.0 - self.diffuser.min_t) * torch.rand(batch_size, device=rigids_0.device) + self.diffuser.min_t
+ perturb_feats = self.diffuser.forward_marginal(
+ rigids_0=rigids_0,
+ t=t,
+ diffuse_mask=None,
+ as_tensor_7=True,
+ )
+ patch_feats = {
+ 't': t,
+ 'rigids_0': rigids_0,
+ }
+ batch.update({**perturb_feats, **patch_feats})
+
+ # probably add self-conditioning (recycle once)
+ if self.net.embedder.self_conditioning and random() > 0.5:
+ with torch.no_grad():
+ batch['sc_ca_t'] = self.net(batch, as_tensor_7=True)['rigids'][..., 4:]
+
+ # feedforward
+ out = self.net(batch)
+
+ # postprocess by add score computation
+ pred_scores = self.diffuser.score(
+ rigids_0=out['rigids'],
+ rigids_t=Rigid.from_tensor_7(batch['rigids_t']),
+ t=t,
+ mask=batch['residue_mask'],
+ )
+ out.update(pred_scores)
+
+ # calculate losses
+ loss, loss_bd = self.loss(out, batch, _return_breakdown=True)
+ return loss, loss_bd
+
+ def training_step(
+ self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
+ ) -> torch.Tensor:
+ """Perform a single training step on a batch of data from the training set.
+
+ :param batch: A batch of data (a tuple) containing the input tensor of images and target
+ labels.
+ :param batch_idx: The index of the current batch.
+ :return: A tensor of losses between model predictions and targets.
+ """
+ loss, loss_bd = self.model_step(batch)
+
+ # update and log metrics
+ self.train_loss(loss)
+ self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
+
+ for k,v in loss_bd.items():
+ if k == 'loss': continue
+ self.log(f"train/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
+
+ # return loss or backpropagation will fail
+ return loss
+
+ def on_train_epoch_end(self) -> None:
+ "Lightning hook that is called when a training epoch ends."
+ pass
+
+ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
+ """Perform a single validation step on a batch of data from the validation set.
+
+ :param batch: A batch of data (a tuple) containing the input tensor of images and target
+ labels.
+ :param batch_idx: The index of the current batch.
+ """
+ loss, loss_bd = self.model_step(batch, training=False)
+
+ # update and log metrics
+ self.val_loss(loss) # update
+ self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
+
+ def on_validation_epoch_end(self) -> None:
+ "Lightning hook that is called when a validation epoch ends."
+ _vall = self.val_loss.compute() # get current val acc
+ self.val_loss_best(_vall) # update best so far val acc
+ # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
+ # otherwise metric would be reset by lightning after each epoch
+ self.log("val/loss_best", self.val_loss_best.compute(), sync_dist=True, prog_bar=True)
+
+ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
+ """Perform a single test step on a batch of data from the test set.
+
+ :param batch: A batch of data (a tuple) containing the input tensor of images and target
+ labels.
+ :param batch_idx: The index of the current batch.
+ """
+ raise NotImplementedError("Test step not implemented.")
+
+ def on_test_epoch_end(self) -> None:
+ """Lightning hook that is called when a test epoch ends."""
+ pass
+
+ def predict_step(
+ self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int,
+ ) -> str:
+ """Perform a prediction step on a batch of data from the dataloader.
+
+ This prediction step will sample `n_replica` copies from the forward-backward process,
+ repeated for each delta-T in the range of [delta_min, delta_max] with step size
+ `delta_step`. If `backward_only` is set to True, then only backward process will be
+ performed, and `n_replica` will be multiplied by the number of delta-Ts.
+
+ :param batch: A batch of data (a tuple) containing the input tensor of images and target
+ labels.
+ :param batch_idx: The index of the current batch.
+ """
+ # extract hyperparams for inference
+ n_replica = self.hparams.inference.n_replica
+ replica_per_batch = self.hparams.inference.replica_per_batch
+ delta_range = np.arange(
+ self.hparams.inference.delta_min,
+ self.hparams.inference.delta_max + 1e-5,
+ self.hparams.inference.delta_step
+ )
+ delta_range = np.around(delta_range, decimals=2) # up to 2 decimal places
+ num_timesteps = self.hparams.inference.num_timesteps
+ noise_scale = self.hparams.inference.noise_scale
+ probability_flow = self.hparams.inference.probability_flow
+ self_conditioning = self.hparams.inference.self_conditioning and self.net.embedder.self_conditioning
+ min_t = self.hparams.inference.min_t
+ output_dir = self.hparams.inference.output_dir
+ backward_only = self.hparams.inference.backward_only
+ # if backward_only, then only perform backward process (vanilla sampling of diffusion)
+ if backward_only:
+ n_replica *= len(delta_range)
+ delta_range = [-1.0]
+
+ assert batch['aatype'].shape[0] == 1, "Batch size must be 1 for correct inference."
+
+ # get extra features of the current protein
+ accession_code = batch['accession_code'][0]
+ extra = {
+ 'aatype': batch['aatype'][0].detach().cpu().numpy(),
+ 'chain_index': batch['chain_index'][0].detach().cpu().numpy(),
+ 'residue_index': batch['residue_index'][0].detach().cpu().numpy(),
+ }
+
+ # define sampling subroutine
+ def forward_backward(rigids_0: Rigid, t_delta: float):
+ # if t_delta <= 0 (invalid), then perform backward only (from t=T)
+ T = t_delta if t_delta > 0 else 1.0
+
+ batch_size, device = rigids_0.shape[0], rigids_0.device
+ _num_timesteps = int(float(num_timesteps) * T)
+ dt = 1.0 / _num_timesteps
+ ts = np.linspace(min_t, T, _num_timesteps)[::-1] # reverse in time
+
+ _feats = deepcopy({
+ k: v.repeat(batch_size, *(1,)*(v.ndim-1))
+ for k,v in batch.items() if k in ('aatype', 'residue_mask', 'fixed_mask', 'residue_idx', 'torsion_angles_sin_cos')
+ })
+ if t_delta > 0:
+ rigids_t = self.diffuser.forward_marginal(
+ rigids_0=rigids_0,
+ t=t_delta * torch.ones(batch_size, device=device),
+ diffuse_mask=_feats['residue_mask'],
+ as_tensor_7=True,
+ )['rigids_t']
+ else:
+ rigids_t = self.diffuser.sample_prior(
+ shape=rigids_0.shape,
+ device=device,
+ as_tensor_7=True,
+ )['rigids_t']
+
+ _feats['rigids_t'] = rigids_t
+
+ traj_atom37 = []
+ with torch.no_grad():
+ fixed_mask = _feats['fixed_mask'] * _feats['residue_mask']
+ diffuse_mask = (1 - _feats['fixed_mask']) * _feats['residue_mask']
+
+ if self_conditioning:
+ _feats['sc_ca_t'] = torch.zeros_like(rigids_t[..., 4:])
+ _feats['t'] = ts[0] * torch.ones(batch_size, device=device)
+ _feats['sc_ca_t'] = self.net(_feats, as_tensor_7=True)['rigids'][..., 4:] # update self-conditioning feats
+
+ for t in ts:
+ _feats['t'] = t * torch.ones(batch_size, device=device)
+ out = self.net(_feats, as_tensor_7=False)
+
+ # compute predicted rigids
+ if t == min_t:
+ rigids_pred = out['rigids']
+ else:
+ # update self-conditioning feats
+ if self_conditioning:
+ _feats['sc_ca_t'] = out['rigids'].to_tensor_7()[..., 4:]
+ # get score based on predicted rigids
+ pred_scores = self.diffuser.score(
+ rigids_0=out['rigids'],
+ rigids_t=Rigid.from_tensor_7(_feats['rigids_t']),
+ t=_feats['t'],
+ mask=_feats['residue_mask'],
+ )
+ rigids_pred = self.diffuser.reverse(
+ rigids_t=Rigid.from_tensor_7(_feats['rigids_t']),
+ rot_score=pred_scores['rot_score'],
+ trans_score=pred_scores['trans_score'],
+ t=_feats['t'],
+ dt=dt,
+ diffuse_mask=diffuse_mask,
+ center_trans=True,
+ noise_scale=noise_scale,
+ probability_flow=probability_flow,
+ ) # Rigid object
+ # update rigids_t as tensor_7
+ _feats['rigids_t'] = rigids_pred.to_tensor_7()
+
+ # compute atom37 positions
+ atom37 = compute_backbone(rigids_pred, out['psi'], aatype=_feats['aatype'])[0]
+ atom37 = atom37.detach().cpu().numpy() # (Bi, L, 37 ,3)
+ return atom37
+
+ saved_paths = []
+
+ # iterate over delta-Ts
+ for t_delta in delta_range:
+ gt_rigids_4x4 = batch['rigidgroups_gt_frames'][..., 0, :, :].clone()
+ n_bs = n_replica // replica_per_batch # number of batches
+ last_bs = n_replica % replica_per_batch # last batch size
+ atom_positions = []
+ for _ in range(n_bs):
+ rigids_0 = Rigid.from_tensor_4x4(gt_rigids_4x4.repeat(replica_per_batch, *(1,)*(gt_rigids_4x4.ndim-1)))
+ traj_atom37 = forward_backward(rigids_0, t_delta)
+ atom_positions.append(traj_atom37)
+ if last_bs > 0:
+ rigids_0 = Rigid.from_tensor_4x4(gt_rigids_4x4.repeat(last_bs, *(1,)*(gt_rigids_4x4.ndim-1)))
+ traj_atom37 = forward_backward(rigids_0, t_delta)
+ atom_positions.append(traj_atom37)
+ atom_positions = np.concatenate(atom_positions, axis=0) # (B, L, 37, 3)
+
+ # Save atom positions to pdb.
+ t_delta_dir = os.path.join(output_dir, f"{t_delta}")
+ os.makedirs(t_delta_dir, exist_ok=True)
+ save_to = os.path.join(t_delta_dir, f"{accession_code}.pdb")
+ saved_to = atom37_to_pdb(
+ atom_positions=atom_positions,
+ save_to=save_to,
+ **extra,
+ )
+ saved_paths.append(saved_to)
+
+ all_delta_dir = os.path.join(output_dir, "all_delta")
+ os.makedirs(all_delta_dir, exist_ok=True)
+ merge_pdbfiles(saved_paths, os.path.join(all_delta_dir, f"{accession_code}.pdb"))
+
+ return all_delta_dir
+
+ def setup(self, stage: str) -> None:
+ """Lightning hook that is called at the beginning of fit (train + validate), validate,
+ test, or predict.
+
+ This is a good hook when you need to build models dynamically or adjust something about
+ them. This hook is called on every process when using DDP.
+
+ :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+ """
+ if self.hparams.compile and stage == "fit":
+ self.net = torch.compile(self.net)
+
+ def configure_optimizers(self) -> Dict[str, Any]:
+ """Choose what optimizers and learning-rate schedulers to use in your optimization.
+ Normally you'd need one. But in the case of GANs or similar you might have multiple.
+
+ Examples:
+ https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
+
+ :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
+ """
+ optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
+ if self.hparams.scheduler is not None:
+ scheduler = self.hparams.scheduler(optimizer=optimizer)
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {
+ "scheduler": scheduler,
+ "monitor": "val/loss",
+ "interval": "epoch",
+ "frequency": 1,
+ },
+ }
+ return {"optimizer": optimizer}
+
+
+if __name__ == "__main__":
+ _ = DiffusionLitModule(None, None, None, None, None)
diff --git a/analysis/src/models/loss.py b/analysis/src/models/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..3538b6d59abb48c841791fa8c28d704bbd894d67
--- /dev/null
+++ b/analysis/src/models/loss.py
@@ -0,0 +1,1741 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+import logging
+from typing import Dict, Optional, Tuple
+
+import ml_collections
+import numpy as np
+import torch
+import torch.nn as nn
+
+from src.common import residue_constants
+from src.common.all_atom import compute_backbone
+from src.common.rigid_utils import Rotation, Rigid
+from src.utils.tensor_utils import (
+ tree_map,
+ tensor_tree_map,
+ masked_mean,
+ permute_final_dims,
+ batched_gather,
+ sum_except_batch,
+ inflate_array_like
+)
+
+
+def softmax_cross_entropy(logits, labels):
+ loss = -1 * torch.sum(
+ labels * torch.nn.functional.log_softmax(logits, dim=-1),
+ dim=-1,
+ )
+ return loss
+
+
+def sigmoid_cross_entropy(logits, labels):
+ log_p = torch.log(torch.sigmoid(logits))
+ log_not_p = torch.log(torch.sigmoid(-logits))
+ loss = -labels * log_p - (1 - labels) * log_not_p
+ return loss
+
+
+def torsion_angle_loss(
+ a, # [*, N, 7, 2]
+ a_gt, # [*, N, 7, 2]
+ a_alt_gt, # [*, N, 7, 2]
+):
+ # [*, N, 7]
+ norm = torch.norm(a, dim=-1)
+
+ # [*, N, 7, 2]
+ a = a / norm.unsqueeze(-1)
+
+ # [*, N, 7]
+ diff_norm_gt = torch.norm(a - a_gt, dim=-1)
+ diff_norm_alt_gt = torch.norm(a - a_alt_gt, dim=-1)
+ min_diff = torch.minimum(diff_norm_gt ** 2, diff_norm_alt_gt ** 2)
+
+ # [*]
+ l_torsion = torch.mean(min_diff, dim=(-1, -2))
+ l_angle_norm = torch.mean(torch.abs(norm - 1), dim=(-1, -2))
+
+ an_weight = 0.02
+ return l_torsion + an_weight * l_angle_norm
+
+
+def compute_fape(
+ pred_frames: Rigid,
+ target_frames: Rigid,
+ frames_mask: torch.Tensor,
+ pred_positions: torch.Tensor,
+ target_positions: torch.Tensor,
+ positions_mask: torch.Tensor,
+ length_scale: float,
+ l1_clamp_distance: Optional[float] = None,
+ eps=1e-8,
+ ignore_nan=True,
+) -> torch.Tensor:
+ """
+ Computes FAPE loss.
+
+ Args:
+ pred_frames:
+ [*, N_frames] Rigid object of predicted frames
+ target_frames:
+ [*, N_frames] Rigid object of ground truth frames
+ frames_mask:
+ [*, N_frames] binary mask for the frames
+ pred_positions:
+ [*, N_pts, 3] predicted atom positions
+ target_positions:
+ [*, N_pts, 3] ground truth positions
+ positions_mask:
+ [*, N_pts] positions mask
+ length_scale:
+ Length scale by which the loss is divided
+ l1_clamp_distance:
+ Cutoff above which distance errors are disregarded
+ eps:
+ Small value used to regularize denominators
+ Returns:
+ [*] loss tensor
+ """
+ # [*, N_frames, N_pts, 3]
+ local_pred_pos = pred_frames.invert()[..., None].apply(
+ pred_positions[..., None, :, :],
+ )
+ local_target_pos = target_frames.invert()[..., None].apply(
+ target_positions[..., None, :, :],
+ )
+
+ error_dist = torch.sqrt(
+ torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
+ )
+
+ if l1_clamp_distance is not None:
+ error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)
+
+ normed_error = error_dist / length_scale
+ normed_error = normed_error * frames_mask[..., None]
+ normed_error = normed_error * positions_mask[..., None, :]
+ if ignore_nan:
+ normed_error = torch.nan_to_num(normed_error)
+
+ # FP16-friendly averaging. Roughly equivalent to:
+ #
+ # norm_factor = (
+ # torch.sum(frames_mask, dim=-1) *
+ # torch.sum(positions_mask, dim=-1)
+ # )
+ # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
+ #
+ # ("roughly" because eps is necessarily duplicated in the latter)
+ normed_error = torch.sum(normed_error, dim=-1)
+ normed_error = (
+ normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
+ )
+ normed_error = torch.sum(normed_error, dim=-1)
+ normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
+ return normed_error
+
+
+def backbone_loss(
+ backbone_rigid_tensor: torch.Tensor,
+ backbone_rigid_mask: torch.Tensor,
+ traj: torch.Tensor,
+ use_clamped_fape: Optional[torch.Tensor] = None,
+ clamp_distance: float = 10.0,
+ loss_unit_distance: float = 10.0,
+ eps: float = 1e-4,
+ **kwargs,
+) -> torch.Tensor:
+ pred_aff = Rigid.from_tensor_7(traj)
+ pred_aff = Rigid(
+ Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
+ pred_aff.get_trans(),
+ )
+
+ # DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
+ # backbone tensor, normalizes it, and then turns it back to a rotation
+ # matrix. To avoid a potentially numerically unstable rotation matrix
+ # to quaternion conversion, we just use the original rotation matrix
+ # outright. This one hasn't been composed a bunch of times, though, so
+ # it might be fine.
+ gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)
+
+ fape_loss = compute_fape(
+ pred_aff,
+ gt_aff[None],
+ backbone_rigid_mask[None],
+ pred_aff.get_trans(),
+ gt_aff[None].get_trans(),
+ backbone_rigid_mask[None],
+ l1_clamp_distance=clamp_distance,
+ length_scale=loss_unit_distance,
+ eps=eps,
+ )
+ if use_clamped_fape is not None:
+ unclamped_fape_loss = compute_fape(
+ pred_aff,
+ gt_aff[None],
+ backbone_rigid_mask[None],
+ pred_aff.get_trans(),
+ gt_aff[None].get_trans(),
+ backbone_rigid_mask[None],
+ l1_clamp_distance=None,
+ length_scale=loss_unit_distance,
+ eps=eps,
+ )
+
+ fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
+ 1 - use_clamped_fape
+ )
+
+ # Average over the batch dimension
+ fape_loss = torch.mean(fape_loss)
+
+ return fape_loss
+
+
+def sidechain_loss(
+ sidechain_frames: torch.Tensor,
+ sidechain_atom_pos: torch.Tensor,
+ rigidgroups_gt_frames: torch.Tensor,
+ rigidgroups_alt_gt_frames: torch.Tensor,
+ rigidgroups_gt_exists: torch.Tensor,
+ renamed_atom14_gt_positions: torch.Tensor,
+ renamed_atom14_gt_exists: torch.Tensor,
+ alt_naming_is_better: torch.Tensor,
+ clamp_distance: float = 10.0,
+ length_scale: float = 10.0,
+ eps: float = 1e-4,
+ **kwargs,
+) -> torch.Tensor:
+ renamed_gt_frames = (
+ 1.0 - alt_naming_is_better[..., None, None, None]
+ ) * rigidgroups_gt_frames + alt_naming_is_better[
+ ..., None, None, None
+ ] * rigidgroups_alt_gt_frames
+
+ # Steamroll the inputs
+ sidechain_frames = sidechain_frames[-1]
+ batch_dims = sidechain_frames.shape[:-4]
+ sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
+ sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames)
+ renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
+ renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
+ rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
+ sidechain_atom_pos = sidechain_atom_pos[-1]
+ sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
+ renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(
+ *batch_dims, -1, 3
+ )
+ renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)
+
+ fape = compute_fape(
+ sidechain_frames,
+ renamed_gt_frames,
+ rigidgroups_gt_exists,
+ sidechain_atom_pos,
+ renamed_atom14_gt_positions,
+ renamed_atom14_gt_exists,
+ l1_clamp_distance=clamp_distance,
+ length_scale=length_scale,
+ eps=eps,
+ )
+
+ return fape
+
+
+def fape_loss(
+ out: Dict[str, torch.Tensor],
+ batch: Dict[str, torch.Tensor],
+ config: ml_collections.ConfigDict,
+) -> torch.Tensor:
+ bb_loss = backbone_loss(
+ traj=out["sm"]["frames"],
+ **{**batch, **config.backbone},
+ )
+
+ sc_loss = sidechain_loss(
+ out["sm"]["sidechain_frames"],
+ out["sm"]["positions"],
+ **{**batch, **config.sidechain},
+ )
+
+ loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss
+
+ # Average over the batch dimension
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def supervised_chi_loss(
+ angles_sin_cos: torch.Tensor,
+ unnormalized_angles_sin_cos: torch.Tensor,
+ aatype: torch.Tensor,
+ seq_mask: torch.Tensor,
+ chi_mask: torch.Tensor,
+ chi_angles_sin_cos: torch.Tensor,
+ chi_weight: float,
+ angle_norm_weight: float,
+ eps=1e-6,
+ **kwargs,
+) -> torch.Tensor:
+ """
+ Implements Algorithm 27 (torsionAngleLoss)
+
+ Args:
+ angles_sin_cos:
+ [*, N, 7, 2] predicted angles
+ unnormalized_angles_sin_cos:
+ The same angles, but unnormalized
+ aatype:
+ [*, N] residue indices
+ seq_mask:
+ [*, N] sequence mask
+ chi_mask:
+ [*, N, 7] angle mask
+ chi_angles_sin_cos:
+ [*, N, 7, 2] ground truth angles
+ chi_weight:
+ Weight for the angle component of the loss
+ angle_norm_weight:
+ Weight for the normalization component of the loss
+ Returns:
+ [*] loss tensor
+ """
+ pred_angles = angles_sin_cos[..., 3:, :]
+ residue_type_one_hot = torch.nn.functional.one_hot(
+ aatype,
+ residue_constants.restype_num + 1,
+ )
+ chi_pi_periodic = torch.einsum(
+ "...ij,jk->ik",
+ residue_type_one_hot.type(angles_sin_cos.dtype),
+ angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
+ )
+
+ true_chi = chi_angles_sin_cos[None]
+
+ shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
+ true_chi_shifted = shifted_mask * true_chi
+ sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
+ sq_chi_error_shifted = torch.sum(
+ (true_chi_shifted - pred_angles) ** 2, dim=-1
+ )
+ sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
+ # The ol' switcheroo
+ sq_chi_error = sq_chi_error.permute(
+ *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
+ )
+ sq_chi_loss = masked_mean(
+ chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
+ )
+
+ loss = chi_weight * sq_chi_loss
+
+ angle_norm = torch.sqrt(
+ torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
+ )
+ norm_error = torch.abs(angle_norm - 1.0)
+ norm_error = norm_error.permute(
+ *range(len(norm_error.shape))[1:-2], 0, -2, -1
+ )
+ angle_norm_loss = masked_mean(
+ seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
+ )
+
+ loss = loss + angle_norm_weight * angle_norm_loss
+
+ # Average over the batch dimension
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
+ num_bins = logits.shape[-1]
+ bin_width = 1.0 / num_bins
+ bounds = torch.arange(
+ start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device
+ )
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+ pred_lddt_ca = torch.sum(
+ probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
+ dim=-1,
+ )
+ return pred_lddt_ca * 100
+
+
+def lddt(
+ all_atom_pred_pos: torch.Tensor,
+ all_atom_positions: torch.Tensor,
+ all_atom_mask: torch.Tensor,
+ cutoff: float = 15.0,
+ eps: float = 1e-10,
+ per_residue: bool = True,
+) -> torch.Tensor:
+ n = all_atom_mask.shape[-2]
+ dmat_true = torch.sqrt(
+ eps
+ + torch.sum(
+ (
+ all_atom_positions[..., None, :]
+ - all_atom_positions[..., None, :, :]
+ )
+ ** 2,
+ dim=-1,
+ )
+ )
+
+ dmat_pred = torch.sqrt(
+ eps
+ + torch.sum(
+ (
+ all_atom_pred_pos[..., None, :]
+ - all_atom_pred_pos[..., None, :, :]
+ )
+ ** 2,
+ dim=-1,
+ )
+ )
+ dists_to_score = (
+ (dmat_true < cutoff)
+ * all_atom_mask
+ * permute_final_dims(all_atom_mask, (1, 0))
+ * (1.0 - torch.eye(n, device=all_atom_mask.device))
+ )
+
+ dist_l1 = torch.abs(dmat_true - dmat_pred)
+
+ score = (
+ (dist_l1 < 0.5).type(dist_l1.dtype)
+ + (dist_l1 < 1.0).type(dist_l1.dtype)
+ + (dist_l1 < 2.0).type(dist_l1.dtype)
+ + (dist_l1 < 4.0).type(dist_l1.dtype)
+ )
+ score = score * 0.25
+
+ dims = (-1,) if per_residue else (-2, -1)
+ norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
+ score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
+
+ return score
+
+
+def lddt_ca(
+ all_atom_pred_pos: torch.Tensor,
+ all_atom_positions: torch.Tensor,
+ all_atom_mask: torch.Tensor,
+ cutoff: float = 15.0,
+ eps: float = 1e-10,
+ per_residue: bool = True,
+) -> torch.Tensor:
+ ca_pos = residue_constants.atom_order["CA"]
+ all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
+ all_atom_positions = all_atom_positions[..., ca_pos, :]
+ all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
+
+ return lddt(
+ all_atom_pred_pos,
+ all_atom_positions,
+ all_atom_mask,
+ cutoff=cutoff,
+ eps=eps,
+ per_residue=per_residue,
+ )
+
+
+def lddt_loss(
+ logits: torch.Tensor,
+ all_atom_pred_pos: torch.Tensor,
+ all_atom_positions: torch.Tensor,
+ all_atom_mask: torch.Tensor,
+ resolution: torch.Tensor,
+ cutoff: float = 15.0,
+ no_bins: int = 50,
+ min_resolution: float = 0.1,
+ max_resolution: float = 3.0,
+ eps: float = 1e-10,
+ **kwargs,
+) -> torch.Tensor:
+ n = all_atom_mask.shape[-2]
+
+ ca_pos = residue_constants.atom_order["CA"]
+ all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
+ all_atom_positions = all_atom_positions[..., ca_pos, :]
+ all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
+
+ score = lddt(
+ all_atom_pred_pos,
+ all_atom_positions,
+ all_atom_mask,
+ cutoff=cutoff,
+ eps=eps
+ )
+
+ score = score.detach()
+
+ bin_index = torch.floor(score * no_bins).long()
+ bin_index = torch.clamp(bin_index, max=(no_bins - 1))
+ lddt_ca_one_hot = torch.nn.functional.one_hot(
+ bin_index, num_classes=no_bins
+ )
+
+ errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
+ all_atom_mask = all_atom_mask.squeeze(-1)
+ loss = torch.sum(errors * all_atom_mask, dim=-1) / (
+ eps + torch.sum(all_atom_mask, dim=-1)
+ )
+
+ loss = loss * (
+ (resolution >= min_resolution) & (resolution <= max_resolution)
+ )
+
+ # Average over the batch dimension
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def distogram_loss(
+ logits,
+ pseudo_beta,
+ pseudo_beta_mask,
+ min_bin=2.3125,
+ max_bin=21.6875,
+ no_bins=64,
+ eps=1e-6,
+ **kwargs,
+):
+ boundaries = torch.linspace(
+ min_bin,
+ max_bin,
+ no_bins - 1,
+ device=logits.device,
+ )
+ boundaries = boundaries ** 2
+
+ dists = torch.sum(
+ (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
+ dim=-1,
+ keepdims=True,
+ )
+
+ true_bins = torch.sum(dists > boundaries, dim=-1)
+
+ errors = softmax_cross_entropy(
+ logits,
+ torch.nn.functional.one_hot(true_bins, no_bins),
+ )
+
+ square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
+
+ # FP16-friendly sum. Equivalent to:
+ # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) /
+ # (eps + torch.sum(square_mask, dim=(-1, -2))))
+ denom = eps + torch.sum(square_mask, dim=(-1, -2))
+ mean = errors * square_mask
+ mean = torch.sum(mean, dim=-1)
+ mean = mean / denom[..., None]
+ mean = torch.sum(mean, dim=-1)
+
+ # Average over the batch dimensions
+ mean = torch.mean(mean)
+
+ return mean
+
+
+def _calculate_bin_centers(boundaries: torch.Tensor):
+ step = boundaries[1] - boundaries[0]
+ bin_centers = boundaries + step / 2
+ bin_centers = torch.cat(
+ [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0
+ )
+ return bin_centers
+
+
+def _calculate_expected_aligned_error(
+ alignment_confidence_breaks: torch.Tensor,
+ aligned_distance_error_probs: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
+ return (
+ torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
+ bin_centers[-1],
+ )
+
+
+def compute_predicted_aligned_error(
+ logits: torch.Tensor,
+ max_bin: int = 31,
+ no_bins: int = 64,
+ **kwargs,
+) -> Dict[str, torch.Tensor]:
+ """Computes aligned confidence metrics from logits.
+
+ Args:
+ logits: [*, num_res, num_res, num_bins] the logits output from
+ PredictedAlignedErrorHead.
+ max_bin: Maximum bin value
+ no_bins: Number of bins
+ Returns:
+ aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
+ aligned error probabilities over bins for each residue pair.
+ predicted_aligned_error: [*, num_res, num_res] the expected aligned distance
+ error for each pair of residues.
+ max_predicted_aligned_error: [*] the maximum predicted error possible.
+ """
+ boundaries = torch.linspace(
+ 0, max_bin, steps=(no_bins - 1), device=logits.device
+ )
+
+ aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
+ (
+ predicted_aligned_error,
+ max_predicted_aligned_error,
+ ) = _calculate_expected_aligned_error(
+ alignment_confidence_breaks=boundaries,
+ aligned_distance_error_probs=aligned_confidence_probs,
+ )
+
+ return {
+ "aligned_confidence_probs": aligned_confidence_probs,
+ "predicted_aligned_error": predicted_aligned_error,
+ "max_predicted_aligned_error": max_predicted_aligned_error,
+ }
+
+
+def compute_tm(
+ logits: torch.Tensor,
+ residue_weights: Optional[torch.Tensor] = None,
+ max_bin: int = 31,
+ no_bins: int = 64,
+ eps: float = 1e-8,
+ **kwargs,
+) -> torch.Tensor:
+ if residue_weights is None:
+ residue_weights = logits.new_ones(logits.shape[-2])
+
+ boundaries = torch.linspace(
+ 0, max_bin, steps=(no_bins - 1), device=logits.device
+ )
+
+ bin_centers = _calculate_bin_centers(boundaries)
+ torch.sum(residue_weights)
+ n = logits.shape[-2]
+ clipped_n = max(n, 19)
+
+ d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
+
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+
+ tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
+ predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
+
+ normed_residue_mask = residue_weights / (eps + residue_weights.sum())
+ per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
+ weighted = per_alignment * residue_weights
+ argmax = (weighted == torch.max(weighted)).nonzero()[0]
+ return per_alignment[tuple(argmax)]
+
+
+def tm_loss(
+ logits,
+ final_affine_tensor,
+ backbone_rigid_tensor,
+ backbone_rigid_mask,
+ resolution,
+ max_bin=31,
+ no_bins=64,
+ min_resolution: float = 0.1,
+ max_resolution: float = 3.0,
+ eps=1e-8,
+ **kwargs,
+):
+ pred_affine = Rigid.from_tensor_7(final_affine_tensor)
+ backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
+
+ def _points(affine):
+ pts = affine.get_trans()[..., None, :, :]
+ return affine.invert()[..., None].apply(pts)
+
+ sq_diff = torch.sum(
+ (_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1
+ )
+
+ sq_diff = sq_diff.detach()
+
+ boundaries = torch.linspace(
+ 0, max_bin, steps=(no_bins - 1), device=logits.device
+ )
+ boundaries = boundaries ** 2
+ true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)
+
+ errors = softmax_cross_entropy(
+ logits, torch.nn.functional.one_hot(true_bins, no_bins)
+ )
+
+ square_mask = (
+ backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]
+ )
+
+ loss = torch.sum(errors * square_mask, dim=-1)
+ scale = 0.5 # hack to help FP16 training along
+ denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
+ loss = loss / denom[..., None]
+ loss = torch.sum(loss, dim=-1)
+ loss = loss * scale
+
+ loss = loss * (
+ (resolution >= min_resolution) & (resolution <= max_resolution)
+ )
+
+ # Average over the loss dimension
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def between_residue_bond_loss(
+ pred_atom_positions: torch.Tensor, # (*, N, 37/14, 3)
+ pred_atom_mask: torch.Tensor, # (*, N, 37/14)
+ residue_index: torch.Tensor, # (*, N)
+ aatype: torch.Tensor, # (*, N)
+ tolerance_factor_soft=12.0,
+ tolerance_factor_hard=12.0,
+ eps=1e-6,
+) -> Dict[str, torch.Tensor]:
+ """Flat-bottom loss to penalize structural violations between residues.
+
+ This is a loss penalizing any violation of the geometry around the peptide
+ bond between consecutive amino acids. This loss corresponds to
+ Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.
+
+ Args:
+ pred_atom_positions: Atom positions in atom37/14 representation
+ pred_atom_mask: Atom mask in atom37/14 representation
+ residue_index: Residue index for given amino acid, this is assumed to be
+ monotonically increasing.
+ aatype: Amino acid type of given residue
+ tolerance_factor_soft: soft tolerance factor measured in standard deviations
+ of pdb distributions
+ tolerance_factor_hard: hard tolerance factor measured in standard deviations
+ of pdb distributions
+
+ Returns:
+ Dict containing:
+ * 'c_n_loss_mean': Loss for peptide bond length violations
+ * 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned
+ by CA, C, N
+ * 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned
+ by C, N, CA
+ * 'per_residue_loss_sum': sum of all losses for each residue
+ * 'per_residue_violation_mask': mask denoting all residues with violation
+ present.
+ """
+ # Get the positions of the relevant backbone atoms.
+ this_ca_pos = pred_atom_positions[..., :-1, 1, :]
+ this_ca_mask = pred_atom_mask[..., :-1, 1]
+ this_c_pos = pred_atom_positions[..., :-1, 2, :]
+ this_c_mask = pred_atom_mask[..., :-1, 2]
+ next_n_pos = pred_atom_positions[..., 1:, 0, :]
+ next_n_mask = pred_atom_mask[..., 1:, 0]
+ next_ca_pos = pred_atom_positions[..., 1:, 1, :]
+ next_ca_mask = pred_atom_mask[..., 1:, 1]
+ has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
+
+ # Compute loss for the C--N bond.
+ c_n_bond_length = torch.sqrt(
+ eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)
+ )
+
+ # The C-N bond to proline has slightly different length because of the ring.
+ next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
+ gt_length = (
+ ~next_is_proline
+ ) * residue_constants.between_res_bond_length_c_n[
+ 0
+ ] + next_is_proline * residue_constants.between_res_bond_length_c_n[
+ 1
+ ]
+ gt_stddev = (
+ ~next_is_proline
+ ) * residue_constants.between_res_bond_length_stddev_c_n[
+ 0
+ ] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
+ 1
+ ]
+ c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
+ c_n_loss_per_residue = torch.nn.functional.relu(
+ c_n_bond_length_error - tolerance_factor_soft * gt_stddev
+ )
+ mask = this_c_mask * next_n_mask * has_no_gap_mask
+ c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (
+ torch.sum(mask, dim=-1) + eps
+ )
+ c_n_violation_mask = mask * (
+ c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
+ )
+
+ # Compute loss for the angles.
+ ca_c_bond_length = torch.sqrt(
+ eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)
+ )
+ n_ca_bond_length = torch.sqrt(
+ eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)
+ )
+
+ c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
+ c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None]
+ n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None]
+
+ ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1)
+ gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
+ gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
+ ca_c_n_cos_angle_error = torch.sqrt(
+ eps + (ca_c_n_cos_angle - gt_angle) ** 2
+ )
+ ca_c_n_loss_per_residue = torch.nn.functional.relu(
+ ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev
+ )
+ mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
+ ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (
+ torch.sum(mask, dim=-1) + eps
+ )
+ ca_c_n_violation_mask = mask * (
+ ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)
+ )
+
+ c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
+ gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
+ gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
+ c_n_ca_cos_angle_error = torch.sqrt(
+ eps + torch.square(c_n_ca_cos_angle - gt_angle)
+ )
+ c_n_ca_loss_per_residue = torch.nn.functional.relu(
+ c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev
+ )
+ mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
+ c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (
+ torch.sum(mask, dim=-1) + eps
+ )
+ c_n_ca_violation_mask = mask * (
+ c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
+ )
+
+ # Compute a per residue loss (equally distribute the loss to both
+ # neighbouring residues).
+ per_residue_loss_sum = (
+ c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
+ )
+ per_residue_loss_sum = 0.5 * (
+ torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
+ + torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
+ )
+
+ # Compute hard violations.
+ violation_mask = torch.max(
+ torch.stack(
+ [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
+ dim=-2,
+ ),
+ dim=-2,
+ )[0]
+ violation_mask = torch.maximum(
+ torch.nn.functional.pad(violation_mask, (0, 1)),
+ torch.nn.functional.pad(violation_mask, (1, 0)),
+ )
+
+ return {
+ "c_n_loss_mean": c_n_loss,
+ "ca_c_n_loss_mean": ca_c_n_loss,
+ "c_n_ca_loss_mean": c_n_ca_loss,
+ "per_residue_loss_sum": per_residue_loss_sum,
+ "per_residue_violation_mask": violation_mask,
+ }
+
+
+def between_residue_clash_loss(
+ atom14_pred_positions: torch.Tensor,
+ atom14_atom_exists: torch.Tensor,
+ atom14_atom_radius: torch.Tensor,
+ residue_index: torch.Tensor,
+ overlap_tolerance_soft=1.5,
+ overlap_tolerance_hard=1.5,
+ eps=1e-10,
+) -> Dict[str, torch.Tensor]:
+ """Loss to penalize steric clashes between residues.
+
+ This is a loss penalizing any steric clashes due to non bonded atoms in
+ different peptides coming too close. This loss corresponds to the part with
+ different residues of
+ Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.
+
+ Args:
+ atom14_pred_positions: Predicted positions of atoms in
+ global prediction frame
+ atom14_atom_exists: Mask denoting whether atom at positions exists for given
+ amino acid type
+ atom14_atom_radius: Van der Waals radius for each atom.
+ residue_index: Residue index for given amino acid.
+ overlap_tolerance_soft: Soft tolerance factor.
+ overlap_tolerance_hard: Hard tolerance factor.
+
+ Returns:
+ Dict containing:
+ * 'mean_loss': average clash loss
+ * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14)
+ * 'per_atom_clash_mask': mask whether atom clashes with any other atom
+ shape (N, 14)
+ """
+ fp_type = atom14_pred_positions.dtype
+
+ # Create the distance matrix.
+ # (N, N, 14, 14)
+ dists = torch.sqrt(
+ eps
+ + torch.sum(
+ (
+ atom14_pred_positions[..., :, None, :, None, :]
+ - atom14_pred_positions[..., None, :, None, :, :]
+ )
+ ** 2,
+ dim=-1,
+ )
+ )
+
+ # Create the mask for valid distances.
+ # shape (N, N, 14, 14)
+ dists_mask = (
+ atom14_atom_exists[..., :, None, :, None]
+ * atom14_atom_exists[..., None, :, None, :]
+ ).type(fp_type)
+
+ # Mask out all the duplicate entries in the lower triangular matrix.
+ # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
+ # are handled separately.
+ dists_mask = dists_mask * (
+ residue_index[..., :, None, None, None]
+ < residue_index[..., None, :, None, None]
+ )
+
+ # Backbone C--N bond between subsequent residues is no clash.
+ c_one_hot = torch.nn.functional.one_hot(
+ residue_index.new_tensor(2), num_classes=14
+ )
+ c_one_hot = c_one_hot.reshape(
+ *((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape
+ )
+ c_one_hot = c_one_hot.type(fp_type)
+ n_one_hot = torch.nn.functional.one_hot(
+ residue_index.new_tensor(0), num_classes=14
+ )
+ n_one_hot = n_one_hot.reshape(
+ *((1,) * len(residue_index.shape[:-1])), *n_one_hot.shape
+ )
+ n_one_hot = n_one_hot.type(fp_type)
+
+ neighbour_mask = (
+ residue_index[..., :, None, None, None] + 1
+ ) == residue_index[..., None, :, None, None]
+ c_n_bonds = (
+ neighbour_mask
+ * c_one_hot[..., None, None, :, None]
+ * n_one_hot[..., None, None, None, :]
+ )
+ dists_mask = dists_mask * (1.0 - c_n_bonds)
+
+ # Disulfide bridge between two cysteines is no clash.
+ cys = residue_constants.restype_name_to_atom14_names["CYS"]
+ cys_sg_idx = cys.index("SG")
+ cys_sg_idx = residue_index.new_tensor(cys_sg_idx)
+ cys_sg_idx = cys_sg_idx.reshape(
+ *((1,) * len(residue_index.shape[:-1])), 1
+ ).squeeze(-1)
+ cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14)
+ disulfide_bonds = (
+ cys_sg_one_hot[..., None, None, :, None]
+ * cys_sg_one_hot[..., None, None, None, :]
+ )
+ dists_mask = dists_mask * (1.0 - disulfide_bonds)
+
+ # Compute the lower bound for the allowed distances.
+ # shape (N, N, 14, 14)
+ dists_lower_bound = dists_mask * (
+ atom14_atom_radius[..., :, None, :, None]
+ + atom14_atom_radius[..., None, :, None, :]
+ )
+
+ # Compute the error.
+ # shape (N, N, 14, 14)
+ dists_to_low_error = dists_mask * torch.nn.functional.relu(
+ dists_lower_bound - overlap_tolerance_soft - dists
+ )
+
+ # Compute the mean loss.
+ # shape ()
+ mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask))
+
+ # Compute the per atom loss sum.
+ # shape (N, 14)
+ per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
+ dists_to_low_error, axis=(-3, -1)
+ )
+
+ # Compute the hard clash mask.
+ # shape (N, N, 14, 14)
+ clash_mask = dists_mask * (
+ dists < (dists_lower_bound - overlap_tolerance_hard)
+ )
+
+ # Compute the per atom clash.
+ # shape (N, 14)
+ per_atom_clash_mask = torch.maximum(
+ torch.amax(clash_mask, axis=(-4, -2)),
+ torch.amax(clash_mask, axis=(-3, -1)),
+ )
+
+ return {
+ "mean_loss": mean_loss, # shape ()
+ "per_atom_loss_sum": per_atom_loss_sum, # shape (N, 14)
+ "per_atom_clash_mask": per_atom_clash_mask, # shape (N, 14)
+ }
+
+
+def within_residue_violations(
+ atom14_pred_positions: torch.Tensor,
+ atom14_atom_exists: torch.Tensor,
+ atom14_dists_lower_bound: torch.Tensor,
+ atom14_dists_upper_bound: torch.Tensor,
+ tighten_bounds_for_loss=0.0,
+ eps=1e-10,
+) -> Dict[str, torch.Tensor]:
+ """Loss to penalize steric clashes within residues.
+
+ This is a loss penalizing any steric violations or clashes of non-bonded atoms
+ in a given peptide. This loss corresponds to the part with
+ the same residues of
+ Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.
+
+ Args:
+ atom14_pred_positions ([*, N, 14, 3]):
+ Predicted positions of atoms in global prediction frame.
+ atom14_atom_exists ([*, N, 14]):
+ Mask denoting whether atom at positions exists for given
+ amino acid type
+ atom14_dists_lower_bound ([*, N, 14]):
+ Lower bound on allowed distances.
+ atom14_dists_upper_bound ([*, N, 14]):
+ Upper bound on allowed distances
+ tighten_bounds_for_loss ([*, N]):
+ Extra factor to tighten loss
+
+ Returns:
+ Dict containing:
+ * 'per_atom_loss_sum' ([*, N, 14]):
+ sum of all clash losses per atom, shape
+ * 'per_atom_clash_mask' ([*, N, 14]):
+ mask whether atom clashes with any other atom shape
+ """
+ # Compute the mask for each residue.
+ dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None]
+ dists_masks = dists_masks.reshape(
+ *((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape
+ )
+ dists_masks = (
+ atom14_atom_exists[..., :, :, None]
+ * atom14_atom_exists[..., :, None, :]
+ * dists_masks
+ )
+
+ # Distance matrix
+ dists = torch.sqrt(
+ eps
+ + torch.sum(
+ (
+ atom14_pred_positions[..., :, :, None, :]
+ - atom14_pred_positions[..., :, None, :, :]
+ )
+ ** 2,
+ dim=-1,
+ )
+ )
+
+ # Compute the loss.
+ dists_to_low_error = torch.nn.functional.relu(
+ atom14_dists_lower_bound + tighten_bounds_for_loss - dists
+ )
+ dists_to_high_error = torch.nn.functional.relu(
+ dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)
+ )
+ loss = dists_masks * (dists_to_low_error + dists_to_high_error)
+
+ # Compute the per atom loss sum.
+ per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)
+
+ # Compute the violations mask.
+ violations = dists_masks * (
+ (dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
+ )
+
+ # Compute the per atom violations.
+ per_atom_violations = torch.maximum(
+ torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0]
+ )
+
+ return {
+ "per_atom_loss_sum": per_atom_loss_sum,
+ "per_atom_violations": per_atom_violations,
+ }
+
+
+def find_structural_violations(
+ batch: Dict[str, torch.Tensor],
+ atom14_pred_positions: torch.Tensor,
+ violation_tolerance_factor: float,
+ clash_overlap_tolerance: float,
+ **kwargs,
+) -> Dict[str, torch.Tensor]:
+ """Computes several checks for structural violations."""
+
+ # Compute between residue backbone violations of bonds and angles.
+ connection_violations = between_residue_bond_loss(
+ pred_atom_positions=atom14_pred_positions,
+ pred_atom_mask=batch["atom14_atom_exists"],
+ residue_index=batch["residue_index"],
+ aatype=batch["aatype"],
+ tolerance_factor_soft=violation_tolerance_factor,
+ tolerance_factor_hard=violation_tolerance_factor,
+ )
+
+ # Compute the Van der Waals radius for every atom
+ # (the first letter of the atom name is the element type).
+ # Shape: (N, 14).
+ atomtype_radius = [
+ residue_constants.van_der_waals_radius[name[0]]
+ for name in residue_constants.atom_types
+ ]
+ atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
+ atom14_atom_radius = (
+ batch["atom14_atom_exists"]
+ * atomtype_radius[batch["residx_atom14_to_atom37"]]
+ )
+
+ # Compute the between residue clash loss.
+ between_residue_clashes = between_residue_clash_loss(
+ atom14_pred_positions=atom14_pred_positions,
+ atom14_atom_exists=batch["atom14_atom_exists"],
+ atom14_atom_radius=atom14_atom_radius,
+ residue_index=batch["residue_index"],
+ overlap_tolerance_soft=clash_overlap_tolerance,
+ overlap_tolerance_hard=clash_overlap_tolerance,
+ )
+
+ # Compute all within-residue violations (clashes,
+ # bond length and angle violations).
+ restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
+ overlap_tolerance=clash_overlap_tolerance,
+ bond_length_tolerance_factor=violation_tolerance_factor,
+ )
+ atom14_atom_exists = batch["atom14_atom_exists"]
+ atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
+ restype_atom14_bounds["lower_bound"]
+ )[batch["aatype"]]
+ atom14_dists_upper_bound = atom14_pred_positions.new_tensor(
+ restype_atom14_bounds["upper_bound"]
+ )[batch["aatype"]]
+ residue_violations = within_residue_violations(
+ atom14_pred_positions=atom14_pred_positions,
+ atom14_atom_exists=batch["atom14_atom_exists"],
+ atom14_dists_lower_bound=atom14_dists_lower_bound,
+ atom14_dists_upper_bound=atom14_dists_upper_bound,
+ tighten_bounds_for_loss=0.0,
+ )
+
+ # Combine them to a single per-residue violation mask (used later for LDDT).
+ per_residue_violations_mask = torch.max(
+ torch.stack(
+ [
+ connection_violations["per_residue_violation_mask"],
+ torch.max(
+ between_residue_clashes["per_atom_clash_mask"], dim=-1
+ )[0],
+ torch.max(residue_violations["per_atom_violations"], dim=-1)[0],
+ ],
+ dim=-1,
+ ),
+ dim=-1,
+ )[0]
+
+ return {
+ "between_residues": {
+ "bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"], # ()
+ "angles_ca_c_n_loss_mean": connection_violations[
+ "ca_c_n_loss_mean"
+ ], # ()
+ "angles_c_n_ca_loss_mean": connection_violations[
+ "c_n_ca_loss_mean"
+ ], # ()
+ "connections_per_residue_loss_sum": connection_violations[
+ "per_residue_loss_sum"
+ ], # (N)
+ "connections_per_residue_violation_mask": connection_violations[
+ "per_residue_violation_mask"
+ ], # (N)
+ "clashes_mean_loss": between_residue_clashes["mean_loss"], # ()
+ "clashes_per_atom_loss_sum": between_residue_clashes[
+ "per_atom_loss_sum"
+ ], # (N, 14)
+ "clashes_per_atom_clash_mask": between_residue_clashes[
+ "per_atom_clash_mask"
+ ], # (N, 14)
+ },
+ "within_residues": {
+ "per_atom_loss_sum": residue_violations[
+ "per_atom_loss_sum"
+ ], # (N, 14)
+ "per_atom_violations": residue_violations[
+ "per_atom_violations"
+ ], # (N, 14),
+ },
+ "total_per_residue_violations_mask": per_residue_violations_mask, # (N)
+ }
+
+
+def find_structural_violations_np(
+ batch: Dict[str, np.ndarray],
+ atom14_pred_positions: np.ndarray,
+ config: ml_collections.ConfigDict,
+) -> Dict[str, np.ndarray]:
+ to_tensor = lambda x: torch.tensor(x)
+ batch = tree_map(to_tensor, batch, np.ndarray)
+ atom14_pred_positions = to_tensor(atom14_pred_positions)
+
+ out = find_structural_violations(batch, atom14_pred_positions, **config)
+
+ to_np = lambda x: np.array(x)
+ np_out = tensor_tree_map(to_np, out)
+
+ return np_out
+
+
+def extreme_ca_ca_distance_violations(
+ pred_atom_positions: torch.Tensor, # (N, 37(14), 3)
+ pred_atom_mask: torch.Tensor, # (N, 37(14))
+ residue_index: torch.Tensor, # (N)
+ max_angstrom_tolerance=1.5,
+ eps=1e-6,
+) -> torch.Tensor:
+ """Counts residues whose Ca is a large distance from its neighbour.
+
+ Measures the fraction of CA-CA pairs between consecutive amino acids that are
+ more than 'max_angstrom_tolerance' apart.
+
+ Args:
+ pred_atom_positions: Atom positions in atom37/14 representation
+ pred_atom_mask: Atom mask in atom37/14 representation
+ residue_index: Residue index for given amino acid, this is assumed to be
+ monotonically increasing.
+ max_angstrom_tolerance: Maximum distance allowed to not count as violation.
+ Returns:
+ Fraction of consecutive CA-CA pairs with violation.
+ """
+ this_ca_pos = pred_atom_positions[..., :-1, 1, :]
+ this_ca_mask = pred_atom_mask[..., :-1, 1]
+ next_ca_pos = pred_atom_positions[..., 1:, 1, :]
+ next_ca_mask = pred_atom_mask[..., 1:, 1]
+ has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
+ ca_ca_distance = torch.sqrt(
+ eps + torch.sum((this_ca_pos - next_ca_pos) ** 2, dim=-1)
+ )
+ violations = (
+ ca_ca_distance - residue_constants.ca_ca
+ ) > max_angstrom_tolerance
+ mask = this_ca_mask * next_ca_mask * has_no_gap_mask
+ mean = masked_mean(mask, violations, -1)
+ return mean
+
+
+def compute_violation_metrics(
+ batch: Dict[str, torch.Tensor],
+ atom14_pred_positions: torch.Tensor, # (N, 14, 3)
+ violations: Dict[str, torch.Tensor],
+) -> Dict[str, torch.Tensor]:
+ """Compute several metrics to assess the structural violations."""
+ ret = {}
+ extreme_ca_ca_violations = extreme_ca_ca_distance_violations(
+ pred_atom_positions=atom14_pred_positions,
+ pred_atom_mask=batch["atom14_atom_exists"],
+ residue_index=batch["residue_index"],
+ )
+ ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations
+ ret["violations_between_residue_bond"] = masked_mean(
+ batch["seq_mask"],
+ violations["between_residues"][
+ "connections_per_residue_violation_mask"
+ ],
+ dim=-1,
+ )
+ ret["violations_between_residue_clash"] = masked_mean(
+ mask=batch["seq_mask"],
+ value=torch.max(
+ violations["between_residues"]["clashes_per_atom_clash_mask"],
+ dim=-1,
+ )[0],
+ dim=-1,
+ )
+ ret["violations_within_residue"] = masked_mean(
+ mask=batch["seq_mask"],
+ value=torch.max(
+ violations["within_residues"]["per_atom_violations"], dim=-1
+ )[0],
+ dim=-1,
+ )
+ ret["violations_per_residue"] = masked_mean(
+ mask=batch["seq_mask"],
+ value=violations["total_per_residue_violations_mask"],
+ dim=-1,
+ )
+ return ret
+
+
+def compute_violation_metrics_np(
+ batch: Dict[str, np.ndarray],
+ atom14_pred_positions: np.ndarray,
+ violations: Dict[str, np.ndarray],
+) -> Dict[str, np.ndarray]:
+ to_tensor = lambda x: torch.tensor(x)
+ batch = tree_map(to_tensor, batch, np.ndarray)
+ atom14_pred_positions = to_tensor(atom14_pred_positions)
+ violations = tree_map(to_tensor, violations, np.ndarray)
+
+ out = compute_violation_metrics(batch, atom14_pred_positions, violations)
+
+ to_np = lambda x: np.array(x)
+ return tree_map(to_np, out, torch.Tensor)
+
+
+def violation_loss(
+ violations: Dict[str, torch.Tensor],
+ atom14_atom_exists: torch.Tensor,
+ eps=1e-6,
+ **kwargs,
+) -> torch.Tensor:
+ num_atoms = torch.sum(atom14_atom_exists)
+ l_clash = torch.sum(
+ violations["between_residues"]["clashes_per_atom_loss_sum"]
+ + violations["within_residues"]["per_atom_loss_sum"]
+ )
+ l_clash = l_clash / (eps + num_atoms)
+ loss = (
+ violations["between_residues"]["bonds_c_n_loss_mean"]
+ + violations["between_residues"]["angles_ca_c_n_loss_mean"]
+ + violations["between_residues"]["angles_c_n_ca_loss_mean"]
+ + l_clash
+ )
+
+ return loss
+
+
+def compute_renamed_ground_truth(
+ batch: Dict[str, torch.Tensor],
+ atom14_pred_positions: torch.Tensor,
+ eps=1e-10,
+) -> Dict[str, torch.Tensor]:
+ """
+ Find optimal renaming of ground truth based on the predicted positions.
+
+ Alg. 26 "renameSymmetricGroundTruthAtoms"
+
+ This renamed ground truth is then used for all losses,
+ such that each loss moves the atoms in the same direction.
+
+ Args:
+ batch: Dictionary containing:
+ * atom14_gt_positions: Ground truth positions.
+ * atom14_alt_gt_positions: Ground truth positions with renaming swaps.
+ * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
+ renaming swaps.
+ * atom14_gt_exists: Mask for which atoms exist in ground truth.
+ * atom14_alt_gt_exists: Mask for which atoms exist in ground truth
+ after renaming.
+ * atom14_atom_exists: Mask for whether each atom is part of the given
+ amino acid type.
+ atom14_pred_positions: Array of atom positions in global frame with shape
+ Returns:
+ Dictionary containing:
+ alt_naming_is_better: Array with 1.0 where alternative swap is better.
+ renamed_atom14_gt_positions: Array of optimal ground truth positions
+ after renaming swaps are performed.
+ renamed_atom14_gt_exists: Mask after renaming swap is performed.
+ """
+
+ pred_dists = torch.sqrt(
+ eps
+ + torch.sum(
+ (
+ atom14_pred_positions[..., None, :, None, :]
+ - atom14_pred_positions[..., None, :, None, :, :]
+ )
+ ** 2,
+ dim=-1,
+ )
+ )
+
+ atom14_gt_positions = batch["atom14_gt_positions"]
+ gt_dists = torch.sqrt(
+ eps
+ + torch.sum(
+ (
+ atom14_gt_positions[..., None, :, None, :]
+ - atom14_gt_positions[..., None, :, None, :, :]
+ )
+ ** 2,
+ dim=-1,
+ )
+ )
+
+ atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
+ alt_gt_dists = torch.sqrt(
+ eps
+ + torch.sum(
+ (
+ atom14_alt_gt_positions[..., None, :, None, :]
+ - atom14_alt_gt_positions[..., None, :, None, :, :]
+ )
+ ** 2,
+ dim=-1,
+ )
+ )
+
+ lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
+ alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)
+
+ atom14_gt_exists = batch["atom14_gt_exists"]
+ atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
+ mask = (
+ atom14_gt_exists[..., None, :, None]
+ * atom14_atom_is_ambiguous[..., None, :, None]
+ * atom14_gt_exists[..., None, :, None, :]
+ * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
+ )
+
+ per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
+ alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3))
+
+ fp_type = atom14_pred_positions.dtype
+ alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)
+
+ renamed_atom14_gt_positions = (
+ 1.0 - alt_naming_is_better[..., None, None]
+ ) * atom14_gt_positions + alt_naming_is_better[
+ ..., None, None
+ ] * atom14_alt_gt_positions
+
+ renamed_atom14_gt_mask = (
+ 1.0 - alt_naming_is_better[..., None]
+ ) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
+ "atom14_alt_gt_exists"
+ ]
+
+ return {
+ "alt_naming_is_better": alt_naming_is_better,
+ "renamed_atom14_gt_positions": renamed_atom14_gt_positions,
+ "renamed_atom14_gt_exists": renamed_atom14_gt_mask,
+ }
+
+
+def experimentally_resolved_loss(
+ logits: torch.Tensor,
+ atom37_atom_exists: torch.Tensor,
+ all_atom_mask: torch.Tensor,
+ resolution: torch.Tensor,
+ min_resolution: float,
+ max_resolution: float,
+ eps: float = 1e-8,
+ **kwargs,
+) -> torch.Tensor:
+ errors = sigmoid_cross_entropy(logits, all_atom_mask)
+ loss = torch.sum(errors * atom37_atom_exists, dim=-1)
+ loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
+ loss = torch.sum(loss, dim=-1)
+
+ loss = loss * (
+ (resolution >= min_resolution) & (resolution <= max_resolution)
+ )
+
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
+ """
+ Computes BERT-style masked MSA loss. Implements subsection 1.9.9.
+
+ Args:
+ logits: [*, N_seq, N_res, 23] predicted residue distribution
+ true_msa: [*, N_seq, N_res] true MSA
+ bert_mask: [*, N_seq, N_res] MSA mask
+ Returns:
+ Masked MSA loss
+ """
+ errors = softmax_cross_entropy(
+ logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
+ )
+
+ # FP16-friendly averaging. Equivalent to:
+ # loss = (
+ # torch.sum(errors * bert_mask, dim=(-1, -2)) /
+ # (eps + torch.sum(bert_mask, dim=(-1, -2)))
+ # )
+ loss = errors * bert_mask
+ loss = torch.sum(loss, dim=-1)
+ scale = 0.5
+ denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2))
+ loss = loss / denom[..., None]
+ loss = torch.sum(loss, dim=-1)
+ loss = loss * scale
+
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def compute_drmsd(structure_1, structure_2, mask=None):
+ if(mask is not None):
+ structure_1 = structure_1 * mask[..., None]
+ structure_2 = structure_2 * mask[..., None]
+
+ d1 = structure_1[..., :, None, :] - structure_1[..., None, :, :]
+ d2 = structure_2[..., :, None, :] - structure_2[..., None, :, :]
+
+ d1 = d1 ** 2
+ d2 = d2 ** 2
+
+ d1 = torch.sqrt(torch.sum(d1, dim=-1))
+ d2 = torch.sqrt(torch.sum(d2, dim=-1))
+
+ drmsd = d1 - d2
+ drmsd = drmsd ** 2
+ drmsd = torch.sum(drmsd, dim=(-1, -2))
+ n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1)
+ drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
+ drmsd = torch.sqrt(drmsd)
+
+ return drmsd
+
+
+def compute_drmsd_np(structure_1, structure_2, mask=None):
+ structure_1 = torch.tensor(structure_1)
+ structure_2 = torch.tensor(structure_2)
+ if(mask is not None):
+ mask = torch.tensor(mask)
+
+ return compute_drmsd(structure_1, structure_2, mask)
+
+
+def backbone_atom_loss(
+ pred_atom37: torch.Tensor,
+ batch: Dict[str, torch.Tensor],
+ mask: torch.Tensor = None,
+ eps: float = 1e-4,
+ t_threshold: Optional[float] = None,
+ **kwargs,
+):
+ pred_backb_atoms = pred_atom37[:, :, :5] # (B, L, 5, 3)
+ gt_rigids = batch['rigids_0']
+ gt_psi = batch["torsion_angles_sin_cos"][..., 2, :]
+
+ gt_atom37, atom37_mask, _, _ = compute_backbone(gt_rigids, gt_psi, batch["aatype"])
+ gt_backb_atoms, backb_mask = gt_atom37[:, :, :5], atom37_mask[:, :, :5]
+
+ if mask is not None:
+ backb_mask = backb_mask * mask[..., None] # (B, L, 5)
+
+ backb_atom_loss = torch.sum(
+ (pred_backb_atoms - gt_backb_atoms)**2 * backb_mask[..., None],
+ dim=(-1, -2, -3)
+ ) / (backb_mask.sum(dim=(-1, -2)) + eps)
+
+ if t_threshold is not None:
+ backb_atom_loss = backb_atom_loss * (batch['t'] < t_threshold)
+ return torch.mean(backb_atom_loss)
+
+
+def pairwise_distance_loss(
+ pred_atom37: torch.Tensor,
+ batch: Dict[str, torch.Tensor],
+ mask: torch.Tensor = None,
+ eps: float = 1e-4,
+ t_threshold: Optional[float] = None,
+ dist_threshold: float = 6.0,
+ **kwargs,
+):
+ batch_size, n_res = pred_atom37.shape[:2]
+ pred_backb_atoms = pred_atom37[:, :, :5].reshape(batch_size, -1, 3)
+
+ gt_rigids = batch['rigids_0']
+ gt_psi = batch["torsion_angles_sin_cos"][..., 2, :]
+ gt_atom37, _, _, _ = compute_backbone(gt_rigids, gt_psi, batch["aatype"])
+ gt_backb_atoms = gt_atom37[:, :, :5].reshape(batch_size, -1, 3)
+
+ # Configure masks.
+ residue_mask = batch['seq_mask']
+ if mask is not None:
+ residue_mask = residue_mask * mask
+ residue_mask = torch.tile(residue_mask[:, :, None], (1, 1, 5)).view(batch_size, -1)
+
+ gt_pwd = torch.linalg.norm(
+ gt_backb_atoms[:, :, None, :] - gt_backb_atoms[:, None, :, :],
+ dim=-1
+ ) * residue_mask[..., None]
+ pred_pwd = torch.linalg.norm(
+ pred_backb_atoms[:, :, None, :] - pred_backb_atoms[:, None, :, :],
+ dim=-1
+ ) * residue_mask[..., None]
+
+
+ pair_mask = residue_mask[:, :, None] * residue_mask[:, None, :] # atom-wise
+ pair_mask = pair_mask * (pred_pwd < dist_threshold)
+ pwd_loss = torch.sum(
+ (gt_pwd - pred_pwd)**2 * pair_mask, dim=(-1, -2)
+ ) / (torch.sum(pair_mask, dim=(-1, -2)) - n_res + eps)
+
+ if t_threshold is not None:
+ pwd_loss = pwd_loss * (batch['t'] < t_threshold)
+ return torch.mean(pwd_loss)
+
+
+
+#################### Training Losses ####################
+
+
+class ScoreMatchingLoss(nn.Module):
+ """Aggregation of the various losses described in the supplement"""
+ def __init__(self, config):
+ super(ScoreMatchingLoss, self).__init__()
+ self.config = config # config.loss
+
+ def forward(self, out, batch, _return_breakdown=False):
+ # Configure masks.
+ seq_mask = batch['seq_mask']
+ diffuse_mask = 1. - batch['fixed_mask']
+ loss_mask = seq_mask * diffuse_mask # (B, L)
+ _denom = sum_except_batch(loss_mask) + self.config.eps # normalizing sum
+
+ ###
+ # Score Matching loss.
+ ###
+ pred_rot_score = out['rot_score'] * diffuse_mask[..., None]
+ pred_trans_score = out['trans_score'] * diffuse_mask[..., None]
+ gt_rot_score = batch['rot_score'] * diffuse_mask[..., None]
+ gt_trans_score = batch['trans_score'] * diffuse_mask[..., None]
+ # Translation component.
+ trans_score_loss = (gt_trans_score - pred_trans_score) * loss_mask[..., None]
+ trans_score_loss /= inflate_array_like(batch['trans_score_scaling'], trans_score_loss)
+ trans_score_loss = torch.sum(trans_score_loss**2, dim=(-1, -2)) / _denom
+ # Alternative x0 loss.
+ trans_x0_loss = (self.config.translation.coordinate_scaling *
+ (batch['rigids_0'].get_trans() - out['rigids'].get_trans()) *
+ loss_mask[..., None]
+ )
+ trans_x0_loss = torch.sum(trans_x0_loss**2, dim=(-1, -2)) / _denom
+ trans_loss = torch.mean(
+ trans_score_loss * (batch['t'] > self.config.translation.x0_threshold) +
+ trans_x0_loss * (batch['t'] <= self.config.translation.x0_threshold)
+ )
+ # Rotation component.
+ rot_loss = (gt_rot_score - pred_rot_score) * loss_mask[..., None]
+ rot_loss /= inflate_array_like(batch['rot_score_scaling'], rot_loss)
+ rot_loss = torch.mean(torch.sum(rot_loss**2, dim=(-1, -2)) / _denom)
+
+ loss_fns = {
+ "translation": lambda: trans_loss,
+ "rotation": lambda: rot_loss,
+ }
+
+ # Auxiliary folding loss.
+ if self.config.distogram.enabled:
+ loss_fns["distogram"] = lambda: distogram_loss(
+ logits=out["distogram_logits"],
+ **{**batch, **self.config.distogram},
+ )
+ if self.config.supervised_chi.enabled:
+ loss_fns["supervised_chi"] = lambda: supervised_chi_loss(
+ out["sm"]["angles"],
+ out["sm"]["unnormalized_angles"],
+ **{**batch, **self.config.supervised_chi},
+ )
+ if self.config.lddt.enabled:
+ loss_fns["lddt"] = lambda: lddt_loss(
+ logits=out["lddt_logits"],
+ all_atom_pred_pos=out["final_atom_positions"],
+ **{**batch, **self.config.lddt},
+ )
+ if self.config.fape.enabled:
+ loss_fns["fape"] = lambda: fape_loss(
+ out,
+ batch,
+ self.config.fape,
+ )
+ if self.config.tm.enabled:
+ loss_fns["tm"] = lambda: tm_loss(
+ logits=out["tm_logits"],
+ **{**batch, **out, **self.config.tm},
+ )
+ if self.config.backbone.enabled:
+ loss_fns["backbone"] = lambda: backbone_atom_loss(
+ pred_atom37=out["atom37"],
+ batch=batch,
+ mask=loss_mask,
+ **self.config.backbone,
+ )
+ if self.config.pwd.enabled:
+ loss_fns["pwd"] = lambda: pairwise_distance_loss(
+ pred_atom37=out["atom37"],
+ batch=batch,
+ mask=loss_mask,
+ **self.config.pwd,
+ )
+
+ cum_loss = 0.
+ losses = {}
+ for loss_name, loss_fn in loss_fns.items():
+ weight = self.config[loss_name].weight
+ loss = loss_fn()
+ if torch.isnan(loss) or torch.isinf(loss):
+ logging.warning(f"{loss_name} loss is NaN. Skipping...")
+ loss = loss.new_tensor(0., requires_grad=True)
+ cum_loss = cum_loss + weight * loss
+ losses[loss_name] = loss.detach().clone()
+
+ # losses["unscaled_loss"] = cum_loss.detach().clone()
+
+ # Scale the loss by the square root of the minimum of the crop size and
+ # the (average) sequence length. See subsection 1.9.
+ # seq_len = torch.mean(batch["seq_length"].float())
+ # crop_len = batch["aatype"].shape[-1]
+ # cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))
+
+ losses["loss"] = cum_loss.detach().clone()
+
+ if not _return_breakdown:
+ return cum_loss
+
+ return cum_loss, losses
diff --git a/analysis/src/models/net/__init__.py b/analysis/src/models/net/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/analysis/src/models/net/__pycache__/__init__.cpython-39.pyc b/analysis/src/models/net/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fbe50b991d299f54e717393bcd3228481c5554a0
Binary files /dev/null and b/analysis/src/models/net/__pycache__/__init__.cpython-39.pyc differ
diff --git a/analysis/src/models/net/__pycache__/denoising_ipa.cpython-39.pyc b/analysis/src/models/net/__pycache__/denoising_ipa.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb0704f6950e4402f56d8c64add284c543a3cb76
Binary files /dev/null and b/analysis/src/models/net/__pycache__/denoising_ipa.cpython-39.pyc differ
diff --git a/analysis/src/models/net/__pycache__/ipa.cpython-39.pyc b/analysis/src/models/net/__pycache__/ipa.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3dc9a2164c3b81bc48399c82b57221f57e57fc4d
Binary files /dev/null and b/analysis/src/models/net/__pycache__/ipa.cpython-39.pyc differ
diff --git a/analysis/src/models/net/__pycache__/layers.cpython-39.pyc b/analysis/src/models/net/__pycache__/layers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7173fb6bc524cad8068b4f18e5ae180751d69da
Binary files /dev/null and b/analysis/src/models/net/__pycache__/layers.cpython-39.pyc differ
diff --git a/analysis/src/models/net/denoising_ipa.py b/analysis/src/models/net/denoising_ipa.py
new file mode 100644
index 0000000000000000000000000000000000000000..a954a7b62ca3db1498e4c8bb492c80c3920d2153
--- /dev/null
+++ b/analysis/src/models/net/denoising_ipa.py
@@ -0,0 +1,211 @@
+import math
+from functools import partial
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from src.common.all_atom import compute_backbone
+from src.common.geo_utils import calc_distogram
+from src.models.net.ipa import TranslationIPA
+
+
+def get_positional_embedding(indices, embedding_dim, max_len=2056):
+ """Creates sine / cosine positional embeddings from a prespecified indices.
+
+ Args:
+ indices: offsets of size [..., N_edges] of type integer
+ max_len: maximum length.
+ embedding_dim: dimension of the embeddings to create
+
+ Returns:
+ positional embedding of shape [N, embedding_dim]
+ """
+ K = torch.arange(embedding_dim//2, device=indices.device)
+ pos_embedding_sin = torch.sin(
+ indices[..., None] * math.pi / (max_len**(2*K[None]/embedding_dim))).to(indices.device)
+ pos_embedding_cos = torch.cos(
+ indices[..., None] * math.pi / (max_len**(2*K[None]/embedding_dim))).to(indices.device)
+ pos_embedding = torch.cat([
+ pos_embedding_sin, pos_embedding_cos], axis=-1)
+ return pos_embedding
+
+
+def get_timestep_embedding(timesteps, embedding_dim, max_len=10000):
+ # Code from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
+ assert len(timesteps.shape) == 1
+ timesteps = timesteps * max_len
+ half_dim = embedding_dim // 2
+ emb = math.log(max_len) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * -emb)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = F.pad(emb, (0, 1), mode='constant')
+ assert emb.shape == (timesteps.shape[0], embedding_dim)
+ return emb
+
+
+class EmbeddingModule(nn.Module):
+ def __init__(self,
+ init_embed_size: int,
+ node_embed_size: int,
+ edge_embed_size: int,
+ num_bins: int = 22,
+ min_bin: float = 1e-5,
+ max_bin: float = 20.0,
+ self_conditioning: bool = True,
+ ):
+ super(EmbeddingModule, self).__init__()
+ pos_embed_size = init_embed_size
+ t_embed_size = init_embed_size
+
+ # time embedding
+ node_in_dim = t_embed_size + 1
+ edge_in_dim = (t_embed_size + 1) * 2
+
+ # positional embedding
+ node_in_dim += pos_embed_size
+ edge_in_dim += pos_embed_size
+
+ self.node_embed = nn.Sequential(
+ nn.Linear(node_in_dim, node_embed_size),
+ nn.ReLU(),
+ nn.Linear(node_embed_size, node_embed_size),
+ nn.ReLU(),
+ nn.Linear(node_embed_size, node_embed_size),
+ nn.LayerNorm(node_embed_size),
+ )
+
+ # self-conditioning trick used in RFDiffusion
+ self.self_conditioning = self_conditioning
+ if self_conditioning:
+ edge_in_dim += num_bins
+
+ self.edge_embed = nn.Sequential(
+ nn.Linear(edge_in_dim, edge_embed_size),
+ nn.ReLU(),
+ nn.Linear(edge_embed_size, edge_embed_size),
+ nn.ReLU(),
+ nn.Linear(edge_embed_size, edge_embed_size),
+ nn.LayerNorm(edge_embed_size),
+ )
+
+ self.time_embed = partial(
+ get_timestep_embedding, embedding_dim=t_embed_size
+ )
+ self.position_embed = partial(
+ get_positional_embedding, embedding_dim=pos_embed_size
+ )
+ self.distogram_embed = partial(
+ calc_distogram,
+ min_bin=min_bin,
+ max_bin=max_bin,
+ num_bins=num_bins,
+ )
+
+ def forward(
+ self,
+ residue_idx,
+ t,
+ fixed_mask,
+ self_conditioning_ca,
+ ):
+ """
+ Args:
+ residue_idx: [..., N] Positional sequence index for each residue.
+ t: Sampled t in [0, 1].
+ fixed_mask: mask of fixed (motif) residues.
+ self_conditioning_ca: [..., N, 3] Ca positions of self-conditioning
+ input.
+
+ Returns:
+ node_embed: [B, N, D_node]
+ edge_embed: [B, N, N, D_edge]
+ """
+ B, L = residue_idx.shape
+ fixed_mask = fixed_mask[..., None].float()
+ node_feats = []
+ pair_feats = []
+
+ # configure time embedding
+ t_embed = torch.tile(self.time_embed(t)[:, None, :], (1, L, 1))
+ t_embed = torch.cat([t_embed, fixed_mask], dim=-1)
+ node_feats.append(t_embed)
+
+ # make pair embedding from 1d time feats
+ concat_1d = torch.cat(
+ [torch.tile(t_embed[:, :, None, :], (1, 1, L, 1)),
+ torch.tile(t_embed[:, None, :, :], (1, L, 1, 1))],
+ dim=-1).float().reshape([B, L**2, -1])
+ pair_feats.append(concat_1d)
+
+ # positional embedding
+ node_feats.append(self.position_embed(residue_idx))
+
+ # relative 2d positional embedding
+ rel_seq_offset = residue_idx[:, :, None] - residue_idx[:, None, :]
+ rel_seq_offset = rel_seq_offset.reshape([B, L**2])
+ pair_feats.append(self.position_embed(rel_seq_offset))
+
+ # self-conditioning distogram of C-alpha atoms
+ if self.self_conditioning:
+ ca_dist = self.distogram_embed(self_conditioning_ca)
+ pair_feats.append(ca_dist.reshape([B, L**2, -1]))
+
+ node_embed = self.node_embed(torch.cat(node_feats, dim=-1).float())
+ edge_embed = self.edge_embed(torch.cat(pair_feats, dim=-1).float())
+ edge_embed = edge_embed.reshape([B, L, L, -1])
+ return node_embed, edge_embed
+
+
+class DenoisingNet(nn.Module):
+ def __init__(self,
+ embedder: nn.Module,
+ translator: nn.Module,
+ ):
+ super(DenoisingNet, self).__init__()
+ self.embedder = embedder # embedding module
+ self.translator = translator # translationIPA
+
+ def forward(self, batch, as_tensor_7=False):
+ """Forward computes the denoised frames p(X^t|X^{t+1})
+ """
+ # Frames as [batch, res, 7] tensors.
+ node_mask = batch['residue_mask'].type(torch.float) # [B, N]
+ fixed_mask = batch['fixed_mask'].type(torch.float)
+ edge_mask = node_mask[..., None] * node_mask[..., None, :]
+
+ # Get embeddings.
+ node_embed, edge_embed = self.embedder(
+ residue_idx=batch['residue_idx'],
+ t=batch['t'],
+ fixed_mask=fixed_mask,
+ self_conditioning_ca=batch['sc_ca_t'],
+ )
+ node_embed = node_embed * node_mask[..., None] # (L, D)
+ edge_embed = edge_embed * edge_mask[..., None] # (L, L, D)
+
+ # Translation for frames.
+ model_out = self.translator(node_embed, edge_embed, batch)
+
+ # Psi angle prediction
+ gt_psi = batch['torsion_angles_sin_cos'][..., 2, :]
+ psi_pred = gt_psi * fixed_mask[..., None] + model_out['psi'] * (1 - fixed_mask[..., None])
+ rigids_pred = model_out['out_rigids']
+
+ bb_representations = compute_backbone(
+ rigids_pred, psi_pred, aatype=batch['aatype'] if 'aatype' in batch else None
+ )
+ atom37_pos = bb_representations[0].to(rigids_pred.device)
+ atom14_pos = bb_representations[-1].to(rigids_pred.device)
+
+ if as_tensor_7:
+ rigids_pred = rigids_pred.to_tensor_7()
+
+ return {
+ 'rigids': rigids_pred,
+ 'psi': psi_pred,
+ 'atom37': atom37_pos,
+ 'atom14': atom14_pos,
+ }
\ No newline at end of file
diff --git a/analysis/src/models/net/ipa.py b/analysis/src/models/net/ipa.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc42264f44ea24914b3b00929feab779a2ededb3
--- /dev/null
+++ b/analysis/src/models/net/ipa.py
@@ -0,0 +1,387 @@
+"""
+Adapted from [Openfold](https://github.com/aqlaboratory/openfold) IPA implementation.
+"""
+
+import math
+from typing import Optional, List, Sequence
+
+import torch
+import torch.nn as nn
+
+from src.common.rigid_utils import Rigid
+from src.models.net.layers import Linear, NodeTransition, EdgeTransition, TorsionAngleHead, BackboneUpdate
+
+
+def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
+ zero_index = -1 * len(inds)
+ first_inds = list(range(len(tensor.shape[:zero_index])))
+ return tensor.permute(first_inds + [zero_index + i for i in inds])
+
+
+def flatten_final_dims(t: torch.Tensor, no_dims: int):
+ return t.reshape(t.shape[:-no_dims] + (-1,))
+
+
+def ipa_point_weights_init_(weights):
+ with torch.no_grad():
+ softplus_inverse_1 = 0.541324854612918
+ weights.fill_(softplus_inverse_1)
+
+
+class InvariantPointAttention(nn.Module):
+ """
+ Implements Algorithm 22.
+ """
+ def __init__(
+ self,
+ c_s: int,
+ c_z: int,
+ c_hidden: int,
+ no_heads: int,
+ no_qk_points: int,
+ no_v_points: int,
+ inf: float = 1e5,
+ eps: float = 1e-8,
+ ):
+ """
+ Args:
+ c_s:
+ Single representation channel dimension
+ c_z:
+ Pair representation channel dimension
+ c_hidden:
+ Hidden channel dimension
+ no_heads:
+ Number of attention heads
+ no_qk_points:
+ Number of query/key points to generate
+ no_v_points:
+ Number of value points to generate
+ """
+ super(InvariantPointAttention, self).__init__()
+
+ self.c_s = c_s
+ self.c_z = c_z
+ self.c_hidden = c_hidden
+ self.no_heads = no_heads
+ self.no_qk_points = no_qk_points
+ self.no_v_points = no_v_points
+ self.inf = inf
+ self.eps = eps
+
+ # These linear layers differ from their specifications in the
+ # supplement. There, they lack bias and use Glorot initialization.
+ # Here as in the official source, they have bias and use the default
+ # Lecun initialization.
+ hc = self.c_hidden * self.no_heads
+ self.linear_q = Linear(self.c_s, hc)
+ self.linear_kv = Linear(self.c_s, 2 * hc)
+
+ hpq = self.no_heads * self.no_qk_points * 3
+ self.linear_q_points = Linear(self.c_s, hpq)
+
+ hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
+ self.linear_kv_points = Linear(self.c_s, hpkv)
+
+ self.linear_b = Linear(self.c_z, self.no_heads)
+ self.down_z = Linear(self.c_z, self.c_z // 4)
+
+ self.head_weights = nn.Parameter(torch.zeros((self.no_heads)))
+ ipa_point_weights_init_(self.head_weights)
+
+ concat_out_dim = (
+ self.c_z // 4 + self.c_hidden + self.no_v_points * 4
+ )
+ self.linear_out = Linear(self.no_heads * concat_out_dim, self.c_s, init="final")
+
+ self.softmax = nn.Softmax(dim=-1)
+ self.softplus = nn.Softplus()
+
+ def forward(
+ self,
+ s: torch.Tensor,
+ z: Optional[torch.Tensor],
+ r: Rigid,
+ mask: torch.Tensor,
+ _offload_inference: bool = False,
+ _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ s:
+ [*, N_res, C_s] single representation
+ z:
+ [*, N_res, N_res, C_z] pair representation
+ r:
+ [*, N_res] transformation object
+ mask:
+ [*, N_res] mask
+ Returns:
+ [*, N_res, C_s] single representation update
+ """
+ if _offload_inference:
+ z = _z_reference_list
+ else:
+ z = [z]
+
+ #######################################
+ # Generate scalar and point activations
+ #######################################
+ # [*, N_res, H * C_hidden]
+ q = self.linear_q(s)
+ kv = self.linear_kv(s)
+
+ # [*, N_res, H, C_hidden]
+ q = q.view(q.shape[:-1] + (self.no_heads, -1))
+
+ # [*, N_res, H, 2 * C_hidden]
+ kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
+
+ # [*, N_res, H, C_hidden]
+ k, v = torch.split(kv, self.c_hidden, dim=-1)
+
+ # [*, N_res, H * P_q * 3]
+ q_pts = self.linear_q_points(s)
+
+ # This is kind of clunky, but it's how the original does it
+ # [*, N_res, H * P_q, 3]
+ q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
+ q_pts = torch.stack(q_pts, dim=-1)
+ q_pts = r[..., None].apply(q_pts)
+
+ # [*, N_res, H, P_q, 3]
+ q_pts = q_pts.view(
+ q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
+ )
+
+ # [*, N_res, H * (P_q + P_v) * 3]
+ kv_pts = self.linear_kv_points(s)
+
+ # [*, N_res, H * (P_q + P_v), 3]
+ kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
+ kv_pts = torch.stack(kv_pts, dim=-1)
+ kv_pts = r[..., None].apply(kv_pts)
+
+ # [*, N_res, H, (P_q + P_v), 3]
+ kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
+
+ # [*, N_res, H, P_q/P_v, 3]
+ k_pts, v_pts = torch.split(
+ kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
+ )
+
+ ##########################
+ # Compute attention scores
+ ##########################
+ # [*, N_res, N_res, H]
+ b = self.linear_b(z[0])
+
+ if(_offload_inference):
+ z[0] = z[0].cpu()
+
+ # [*, H, N_res, N_res]
+ a = torch.matmul(
+ permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
+ permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
+ )
+ a *= math.sqrt(1.0 / (3 * self.c_hidden))
+ a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
+
+ # [*, N_res, N_res, H, P_q, 3]
+ pt_displacement = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
+ pt_att = pt_displacement ** 2
+
+ # [*, N_res, N_res, H, P_q]
+ pt_att = sum(torch.unbind(pt_att, dim=-1))
+ head_weights = self.softplus(self.head_weights).view(
+ *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
+ )
+ head_weights = head_weights * math.sqrt(
+ 1.0 / (3 * (self.no_qk_points * 9.0 / 2))
+ )
+ pt_att = pt_att * head_weights
+
+ # [*, N_res, N_res, H]
+ pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
+ # [*, N_res, N_res]
+ square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
+ square_mask = self.inf * (square_mask - 1)
+
+ # [*, H, N_res, N_res]
+ pt_att = permute_final_dims(pt_att, (2, 0, 1))
+
+ a = a + pt_att
+ a = a + square_mask.unsqueeze(-3)
+ a = self.softmax(a)
+
+ ################
+ # Compute output
+ ################
+ # [*, N_res, H, C_hidden]
+ o = torch.matmul(
+ a, v.transpose(-2, -3).to(dtype=a.dtype)
+ ).transpose(-2, -3)
+
+ # [*, N_res, H * C_hidden]
+ o = flatten_final_dims(o, 2)
+
+ # [*, H, 3, N_res, P_v]
+ o_pt = torch.sum(
+ (
+ a[..., None, :, :, None]
+ * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
+ ),
+ dim=-2,
+ )
+
+ # [*, N_res, H, P_v, 3]
+ o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
+ o_pt = r[..., None, None].invert_apply(o_pt)
+
+ # [*, N_res, H * P_v]
+ o_pt_dists = torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps)
+ o_pt_norm_feats = flatten_final_dims(
+ o_pt_dists, 2)
+
+ # [*, N_res, H * P_v, 3]
+ o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
+
+ if(_offload_inference):
+ z[0] = z[0].to(o_pt.device)
+
+ # [*, N_res, H, C_z // 4]
+ pair_z = self.down_z(z[0]).to(dtype=a.dtype)
+ o_pair = torch.matmul(a.transpose(-2, -3), pair_z)
+
+ # [*, N_res, H * C_z // 4]
+ o_pair = flatten_final_dims(o_pair, 2)
+
+ o_feats = [o, *torch.unbind(o_pt, dim=-1), o_pt_norm_feats, o_pair]
+
+ # [*, N_res, C_s]
+ s = self.linear_out(
+ torch.cat(
+ o_feats, dim=-1
+ ).to(dtype=z[0].dtype)
+ )
+
+ return s
+
+
+class TranslationIPA(nn.Module):
+ def __init__(self,
+ c_s: int,
+ c_z: int,
+ coordinate_scaling: float,
+ no_ipa_blocks: int,
+ skip_embed_size: int,
+ transformer_num_heads: int = 4,
+ transformer_num_layers: int = 2,
+ c_hidden: int = 256,
+ no_heads: int = 8,
+ no_qk_points: int = 8,
+ no_v_points: int = 12,
+ dropout: float = 0.0,
+ ):
+ super(TranslationIPA, self).__init__()
+
+ self.scale_pos = lambda x: x * coordinate_scaling
+ self.scale_rigids = lambda x: x.apply_trans_fn(self.scale_pos)
+
+ self.unscale_pos = lambda x: x / coordinate_scaling
+ self.unscale_rigids = lambda x: x.apply_trans_fn(self.unscale_pos)
+ self.trunk = nn.ModuleDict()
+ self.num_blocks = no_ipa_blocks
+
+ for b in range(no_ipa_blocks):
+ self.trunk[f'ipa_{b}'] = InvariantPointAttention(
+ c_s=c_s,
+ c_z=c_z,
+ c_hidden=c_hidden,
+ no_heads=no_heads,
+ no_qk_points=no_qk_points,
+ no_v_points=no_v_points,
+ )
+ self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(c_s)
+ self.trunk[f'skip_embed_{b}'] = Linear(
+ c_s,
+ skip_embed_size,
+ init="final"
+ )
+ _in_dim = c_s + skip_embed_size
+ transformer_layer = nn.TransformerEncoderLayer(
+ d_model=_in_dim,
+ nhead=transformer_num_heads,
+ dim_feedforward=_in_dim,
+ )
+ self.trunk[f'transformer_{b}'] = nn.TransformerEncoder(transformer_layer, transformer_num_layers)
+ self.trunk[f'linear_{b}'] = Linear(_in_dim, c_s, init="final")
+ self.trunk[f'node_transition_{b}'] = NodeTransition(c_s)
+ self.trunk[f'bb_update_{b}'] = BackboneUpdate(c_s)
+
+ if b < self.num_blocks - 1: # No need edge update on the last block
+ self.trunk[f'edge_transition_{b}'] = EdgeTransition(
+ node_embed_size=c_s,
+ edge_embed_in=c_z,
+ edge_embed_out=c_z,
+ )
+ self.torsion_pred = TorsionAngleHead(c_s, 1)
+ # self.chi_pred = TorsionAngleHead(c_s, 4)
+
+ def forward(self, node_embed, edge_embed, batch):
+ node_mask = batch['residue_mask'].type(torch.float)
+ diffuse_mask = (1 - batch['fixed_mask'].type(torch.float)) * node_mask
+ edge_mask = node_mask[..., None] * node_mask[..., None, :]
+
+ init_frames = batch['rigids_t'].type(torch.float)
+ curr_rigids = Rigid.from_tensor_7(torch.clone(init_frames))
+ init_rigids = Rigid.from_tensor_7(init_frames)
+ curr_rigids = self.scale_rigids(curr_rigids)
+
+ # Main trunk
+ init_node_embed = node_embed
+ for b in range(self.num_blocks):
+ ipa_embed = self.trunk[f'ipa_{b}'](
+ node_embed,
+ edge_embed,
+ curr_rigids,
+ node_mask
+ )
+ ipa_embed *= node_mask[..., None]
+ node_embed = self.trunk[f'ipa_ln_{b}'](node_embed + ipa_embed)
+ # MHA + FF
+ concat_node_embed = torch.cat([
+ node_embed, self.trunk[f'skip_embed_{b}'](init_node_embed)
+ ], dim=-1)
+ concat_node_embed = torch.transpose(concat_node_embed, 0, 1)
+ transformed_embed = self.trunk[f'transformer_{b}'](concat_node_embed, src_key_padding_mask=1.0 - node_mask)
+ transformed_embed = torch.transpose(transformed_embed, 0, 1)
+ # residual
+ node_embed = node_embed + self.trunk[f'linear_{b}'](transformed_embed)
+
+ # MLP
+ node_embed = self.trunk[f'node_transition_{b}'](node_embed)
+ # apply mask
+ node_embed = node_embed * node_mask[..., None]
+ # backbone update
+ rigid_update = self.trunk[f'bb_update_{b}'](node_embed * diffuse_mask[..., None])
+ # apply update
+ curr_rigids = curr_rigids.compose_q_update_vec(rigid_update, diffuse_mask[..., None])
+
+ if b < self.num_blocks - 1:
+ edge_embed = self.trunk[f'edge_transition_{b}'](node_embed, edge_embed) * edge_mask[..., None]
+
+ # torsion prediction
+ psi_pred = self.torsion_pred(node_embed)
+ # chi_pred = self.chi_pred(node_embed)
+
+ # scale back
+ curr_rigids = self.unscale_rigids(curr_rigids)
+
+ model_out = {
+ 'in_rigids': init_rigids,
+ 'out_rigids': curr_rigids,
+ 'psi': psi_pred,
+ # 'chi': chi_pred,
+ }
+ return model_out
diff --git a/analysis/src/models/net/layers.py b/analysis/src/models/net/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f777934166f409ba0ee83ef5b3c987746e18da31
--- /dev/null
+++ b/analysis/src/models/net/layers.py
@@ -0,0 +1,241 @@
+import math
+from typing import Optional, Callable, List
+
+import numpy as np
+import torch
+import torch.nn as nn
+from scipy.stats import truncnorm
+
+
+def _prod(nums):
+ out = 1
+ for n in nums:
+ out = out * n
+ return out
+
+def _calculate_fan(linear_weight_shape, fan="fan_in"):
+ fan_out, fan_in = linear_weight_shape
+
+ if fan == "fan_in":
+ f = fan_in
+ elif fan == "fan_out":
+ f = fan_out
+ elif fan == "fan_avg":
+ f = (fan_in + fan_out) / 2
+ else:
+ raise ValueError("Invalid fan option")
+
+ return f
+
+def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
+ shape = weights.shape
+ f = _calculate_fan(shape, fan)
+ scale = scale / max(1, f)
+ a = -2
+ b = 2
+ std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
+ size = _prod(shape)
+ samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
+ samples = np.reshape(samples, shape)
+ with torch.no_grad():
+ weights.copy_(torch.tensor(samples, device=weights.device))
+
+def lecun_normal_init_(weights):
+ trunc_normal_init_(weights, scale=1.0)
+
+def he_normal_init_(weights):
+ trunc_normal_init_(weights, scale=2.0)
+
+def glorot_uniform_init_(weights):
+ nn.init.xavier_uniform_(weights, gain=1)
+
+def final_init_(weights):
+ with torch.no_grad():
+ weights.fill_(0.0)
+
+def gating_init_(weights):
+ with torch.no_grad():
+ weights.fill_(0.0)
+
+def normal_init_(weights):
+ nn.init.kaiming_normal_(weights, nonlinearity="linear")
+
+
+class Linear(nn.Linear):
+ """
+ A Linear layer with in-house nonstandard initializations.
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ bias: bool = True,
+ init: str = "default",
+ init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
+ ):
+ """
+ Args:
+ in_dim:
+ The final dimension of inputs to the layer
+ out_dim:
+ The final dimension of layer outputs
+ bias:
+ Whether to learn an additive bias. True by default
+ init:
+ The initializer to use. Choose from:
+
+ "default": LeCun fan-in truncated normal initialization
+ "relu": He initialization w/ truncated normal distribution
+ "glorot": Fan-average Glorot uniform initialization
+ "gating": Weights=0, Bias=1
+ "normal": Normal initialization with std=1/sqrt(fan_in)
+ "final": Weights=0, Bias=0
+
+ Overridden by init_fn if the latter is not None.
+ init_fn:
+ A custom initializer taking weight and bias as inputs.
+ Overrides init if not None.
+ """
+ super(Linear, self).__init__(in_dim, out_dim, bias=bias)
+
+ if bias:
+ with torch.no_grad():
+ self.bias.fill_(0)
+ if init_fn is not None:
+ init_fn(self.weight, self.bias)
+ else:
+ if init == "default":
+ lecun_normal_init_(self.weight)
+ elif init == "relu":
+ he_normal_init_(self.weight)
+ elif init == "glorot":
+ glorot_uniform_init_(self.weight)
+ elif init == "gating":
+ gating_init_(self.weight)
+ if bias:
+ with torch.no_grad():
+ self.bias.fill_(1.0)
+ elif init == "normal":
+ normal_init_(self.weight)
+ elif init == "final":
+ final_init_(self.weight)
+ else:
+ raise ValueError("Invalid init string.")
+
+
+# Simply 3-layer MLP
+class NodeTransition(nn.Module):
+ def __init__(self, dim):
+ super(NodeTransition, self).__init__()
+ self.dim = dim
+ self.linear_1 = Linear(dim, dim, init="relu")
+ self.linear_2 = Linear(dim, dim, init="relu")
+ self.linear_3 = Linear(dim, dim, init="final")
+ self.relu = nn.ReLU()
+ self.ln = nn.LayerNorm(dim)
+
+ def forward(self, s):
+ s_initial = s
+ s = self.relu(self.linear_1(s))
+ s = self.relu(self.linear_2(s))
+ s = self.linear_3(s)
+ s = s + s_initial
+ s = self.ln(s)
+ return s
+
+
+class EdgeTransition(nn.Module):
+ def __init__(self,
+ node_embed_size,
+ edge_embed_in,
+ edge_embed_out,
+ num_layers=2,
+ node_dilation=2
+ ):
+ super(EdgeTransition, self).__init__()
+
+ bias_embed_size = node_embed_size // node_dilation
+ self.initial_embed = Linear(
+ node_embed_size, bias_embed_size, init="relu")
+ hidden_size = bias_embed_size * 2 + edge_embed_in
+ trunk_layers = []
+ for _ in range(num_layers):
+ trunk_layers.append(Linear(hidden_size, hidden_size, init="relu"))
+ trunk_layers.append(nn.ReLU())
+ self.trunk = nn.Sequential(*trunk_layers)
+ self.final_layer = Linear(hidden_size, edge_embed_out, init="final")
+ self.layer_norm = nn.LayerNorm(edge_embed_out)
+
+ def forward(self, node_embed, edge_embed):
+ node_embed = self.initial_embed(node_embed)
+ batch_size, num_res, _ = node_embed.shape
+ edge_bias = torch.cat([
+ torch.tile(node_embed[:, :, None, :], (1, 1, num_res, 1)),
+ torch.tile(node_embed[:, None, :, :], (1, num_res, 1, 1)),
+ ], axis=-1)
+ edge_embed = torch.cat(
+ [edge_embed, edge_bias], axis=-1).reshape(
+ batch_size * num_res**2, -1)
+ edge_embed = self.final_layer(self.trunk(edge_embed) + edge_embed)
+ edge_embed = self.layer_norm(edge_embed)
+ edge_embed = edge_embed.reshape(
+ batch_size, num_res, num_res, -1
+ )
+ return edge_embed
+
+
+class TorsionAngleHead(nn.Module):
+ def __init__(self, in_dim, n_torsion_angles, eps=1e-8):
+ super(TorsionAngleHead, self).__init__()
+
+ self.linear_1 = Linear(in_dim, in_dim, init="relu")
+ self.linear_2 = Linear(in_dim, in_dim, init="relu")
+ self.linear_3 = Linear(in_dim, in_dim, init="final")
+ self.linear_final = Linear(in_dim, n_torsion_angles * 2, init="final")
+ self.relu = nn.ReLU()
+ self.eps = eps
+
+ def forward(self, s):
+ s_initial = s
+ s = self.relu(self.linear_1(s))
+ s = self.linear_2(s)
+ s = s + s_initial
+
+ unnormalized_s = self.linear_final(s)
+ norm_denom = torch.sqrt(
+ torch.clamp(
+ torch.sum(unnormalized_s ** 2, dim=-1, keepdim=True),
+ min=self.eps,
+ )
+ )
+ normalized_s = unnormalized_s / norm_denom
+ return normalized_s
+
+
+class BackboneUpdate(nn.Module):
+ """
+ Implements part of Algorithm 23.
+ """
+
+ def __init__(self, c_s):
+ """
+ Args:
+ c_s:
+ Single representation channel dimension
+ """
+ super(BackboneUpdate, self).__init__()
+
+ self.c_s = c_s
+ self.linear = Linear(self.c_s, 6, init="final")
+
+ def forward(self, s: torch.Tensor):
+ """
+ Args:
+ [*, N_res, C_s] single representation
+ Returns:
+ [*, N_res, 6] update vector
+ """
+ # [*, 6]
+ update = self.linear(s)
+ return update
\ No newline at end of file
diff --git a/analysis/src/models/score/__init__.py b/analysis/src/models/score/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/analysis/src/models/score/__pycache__/__init__.cpython-39.pyc b/analysis/src/models/score/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04258e1ce8dd14ffd1def9b52a53e41cbdce88f4
Binary files /dev/null and b/analysis/src/models/score/__pycache__/__init__.cpython-39.pyc differ
diff --git a/analysis/src/models/score/__pycache__/frame.cpython-39.pyc b/analysis/src/models/score/__pycache__/frame.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..643ca419ca118df819762549957054e13764d3d0
Binary files /dev/null and b/analysis/src/models/score/__pycache__/frame.cpython-39.pyc differ
diff --git a/analysis/src/models/score/__pycache__/r3.cpython-39.pyc b/analysis/src/models/score/__pycache__/r3.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..472954bc6f4c67f7282cb6c96da062cdef60c9af
Binary files /dev/null and b/analysis/src/models/score/__pycache__/r3.cpython-39.pyc differ
diff --git a/analysis/src/models/score/__pycache__/so3.cpython-39.pyc b/analysis/src/models/score/__pycache__/so3.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..78eecc5862cbaaaa6aec9445e2a967d6e6d364f5
Binary files /dev/null and b/analysis/src/models/score/__pycache__/so3.cpython-39.pyc differ
diff --git a/analysis/src/models/score/frame.py b/analysis/src/models/score/frame.py
new file mode 100644
index 0000000000000000000000000000000000000000..26589afc1ccbbd19cb14cab7b47d378de9c62db8
--- /dev/null
+++ b/analysis/src/models/score/frame.py
@@ -0,0 +1,255 @@
+from typing import Optional
+import torch
+
+from src.models.score import so3, r3
+from src.common.rigid_utils import Rigid, Rotation, quat_multiply
+from src.common import rotation3d
+
+
+def assemble_rigid(rotvec: torch.Tensor, trans: torch.Tensor):
+ rotvec_shape = rotvec.shape
+ rotmat = rotation3d.axis_angle_to_matrix(rotvec).view(rotvec_shape[:-1] + (3, 3))
+ return Rigid(
+ rots=Rotation(rot_mats=rotmat),
+ trans=trans,
+ )
+
+def apply_mask(x_tgt, x_src, tgt_mask):
+ return tgt_mask * x_tgt + (1 - tgt_mask) * x_src
+
+
+class FrameDiffuser:
+ """
+ Wrapper class for diffusion of rigid body transformations,
+ including rotations and translations.
+ """
+ def __init__(self,
+ trans_diffuser: Optional[r3.R3Diffuser] = None,
+ rot_diffuser: Optional[so3.SO3Diffuser] = None,
+ min_t: float = 0.001,
+ ):
+ # if None, then no diffusion for this component
+ self.trans_diffuser = trans_diffuser
+ self.rot_diffuser = rot_diffuser
+ self.min_t = min_t
+
+ def forward_marginal(
+ self,
+ rigids_0: Rigid,
+ t: torch.Tensor,
+ diffuse_mask: torch.Tensor = None,
+ as_tensor_7: bool = True,
+ ):
+ """
+ Args:
+ rigids_0: [..., N] openfold Rigid objects
+ t: continuous time in [0, 1].
+
+ Returns:
+ Dict contains:
+ rigids_t: [..., N] noised rigid. [..., N, 7] if as_tensor_7 is true.
+ trans_score: [..., N, 3] translation score
+ rot_score: [..., N, 3] rotation score
+ trans_score_norm: [...] translation score norm
+ rot_score_norm: [...] rotation score norm
+ """
+ output = {}
+ rot_0 = rotation3d.matrix_to_axis_angle(rigids_0.get_rots().get_rot_mats())
+ trans_0 = rigids_0.get_trans()
+
+ if self.rot_diffuser is None:
+ rot_t = rot_0
+ rot_score, rot_score_scaling = torch.zeros_like(rot_0), t
+ else:
+ rot_t, rot_score = self.rot_diffuser.forward_marginal(rot_0, t)
+ rot_score_scaling = self.rot_diffuser.score_scaling(t)
+
+ if self.trans_diffuser is None:
+ trans_t, trans_score, trans_score_scaling = (
+ trans_0,
+ torch.zeros_like(trans_0),
+ torch.ones_like(t)
+ )
+ else:
+ trans_t, trans_score = self.trans_diffuser.forward_marginal(trans_0, t)
+ trans_score_scaling = self.trans_diffuser.score_scaling(t)
+
+ # Perturb only a subset of residues
+ if diffuse_mask is not None:
+ diffuse_mask = torch.as_tensor(diffuse_mask, device=trans_t.device, dtype=trans_t.dtype)[..., None]
+
+ rot_t = apply_mask(rot_t, rot_0, diffuse_mask)
+ trans_t = apply_mask(trans_t, trans_0, diffuse_mask)
+
+ trans_score = apply_mask(
+ trans_score,
+ torch.zeros_like(trans_score),
+ diffuse_mask
+ )
+ rot_score = apply_mask(
+ rot_score,
+ torch.zeros_like(rot_score),
+ diffuse_mask
+ )
+
+ rigids_t = assemble_rigid(rot_t, trans_t)
+
+ if as_tensor_7:
+ rigids_t = rigids_t.to_tensor_7()
+
+ output = {
+ 'rigids_t': rigids_t,
+ 'trans_score': trans_score,
+ 'rot_score': rot_score,
+ 'trans_score_scaling': trans_score_scaling,
+ 'rot_score_scaling': rot_score_scaling,
+ }
+ return output
+
+ def score(
+ self,
+ rigids_0: Rigid,
+ rigids_t: Rigid,
+ t: torch.Tensor,
+ mask: torch.Tensor = None,
+ ):
+ rot_0, trans_0 = rigids_0.get_rots(), rigids_0.get_trans()
+ rot_t, trans_t = rigids_t.get_rots(), rigids_t.get_trans()
+
+ if self.rot_diffuser is None:
+ rot_score = torch.zeros_like(rot_0)
+ else:
+ rot_0_inv = rot_0.invert()
+ quat_0_inv = rotation3d.matrix_to_quaternion(rot_0_inv.get_rot_mats())
+ quat_t = rotation3d.matrix_to_quaternion(rot_t.get_rot_mats())
+ # get relative rotation
+ quat_0t = quat_multiply(quat_0_inv, quat_t)
+ rotvec_0t = rotation3d.quaternion_to_axis_angle(quat_0t)
+ # calculate score
+ rot_score = self.rot_diffuser.score(rotvec_0t, t)
+
+ if self.trans_diffuser is None:
+ trans_score = torch.zeros_like(trans_0)
+ else:
+ trans_score = self.trans_diffuser.score(trans_t, trans_0, t, scale=True)
+
+ if mask is not None:
+ trans_score = trans_score * mask[..., None]
+ rot_score = rot_score * mask[..., None]
+
+ return {
+ 'trans_score': trans_score,
+ 'rot_score': rot_score
+ }
+
+ def score_scaling(self, t):
+ rot_score_scaling = self.rot_diffuser.score_scaling(t)
+ trans_score_scaling = self.trans_diffuser.score_scaling(t)
+ return {
+ 'trans_score_scaling': trans_score_scaling,
+ 'rot_score_scaling': rot_score_scaling,
+ }
+
+ def reverse(
+ self,
+ rigids_t: Rigid,
+ rot_score: torch.Tensor,
+ trans_score: torch.Tensor,
+ t: torch.Tensor,
+ dt: float,
+ diffuse_mask: torch.Tensor = None,
+ center_trans: bool = True,
+ noise_scale: float = 1.0,
+ probability_flow: bool = True,
+ ):
+ """Reverse sampling function from (t) to (t-1).
+
+ Args:
+ rigids_t: [..., N] protein rigid objects at time t.
+ rot_score: [..., N, 3] rotation score.
+ trans_score: [..., N, 3] translation score.
+ t: continuous time in [0, 1].
+ dt: continuous step size in [0, 1].
+ mask: [..., N] which residues to update.
+ center_trans: true to set center of mass to zero after step
+ probability_flow: whether to use probability flow ODE.
+
+ Returns:
+ rigids_t_1: [..., N] protein rigid objects at time t-1.
+ """
+ # extract rot and trans as tensors
+ rot_t = rotation3d.matrix_to_axis_angle(rigids_t.get_rots().get_rot_mats())
+ trans_t = rigids_t.get_trans()
+
+ # reverse rot
+ rot_t_1 = self.rot_diffuser.reverse(
+ rot_t=rot_t,
+ score_t=rot_score,
+ t=t,
+ dt=dt,
+ noise_scale=noise_scale,
+ probability_flow=probability_flow,
+ ) if self.rot_diffuser is not None else rot_t # if no diffusion module, return as-is
+
+ # reverse trans
+ trans_t_1 = self.trans_diffuser.reverse(
+ x_t=trans_t,
+ score_t=trans_score,
+ t=t,
+ dt=dt,
+ center=center_trans,
+ noise_scale=noise_scale,
+ probability_flow=probability_flow,
+ ) if self.trans_diffuser is not None else trans_t
+
+ # apply mask
+ if diffuse_mask is not None:
+ trans_t_1 = apply_mask(trans_t_1, trans_t, diffuse_mask[..., None])
+ rot_t_1 = apply_mask(rot_t_1, rot_t, diffuse_mask[..., None])
+
+ return assemble_rigid(rot_t_1, trans_t_1)
+
+ def sample_prior(
+ self,
+ shape: torch.Size,
+ device: torch.device,
+ reference_rigids: Rigid = None,
+ diffuse_mask: torch.Tensor = None,
+ as_tensor_7: bool = False
+ ):
+ """Samples rigids from reference distribution.
+
+ """
+ if reference_rigids is not None:
+ assert reference_rigids.shape[:-1] == shape, f"reference_rigids.shape[:-1] = {reference_rigids.shape[:-1]}, shape = {shape}"
+ assert diffuse_mask is not None, "diffuse_mask must be provided if reference_rigids is given"
+ rot_ref = rotation3d.matrix_to_axis_angle(reference_rigids.get_rots().get_rot_mats())
+ trans_ref = reference_rigids.get_trans()
+
+ trans_ref = self.trans_diffuser.scale(trans_ref)
+ else:
+ # sanity check
+ assert diffuse_mask is None, "diffuse_mask must be None if reference_rigids is None"
+ assert self.rot_diffuser is not None and self.trans_diffuser is not None
+
+ # sample from prior
+ trans_shape, rot_shape = shape + (3, ), shape + (3, )
+ rot_sample = self.rot_diffuser.sample_prior(shape=rot_shape, device=device) \
+ if self.rot_diffuser is not None else rot_ref
+ trans_sample = self.trans_diffuser.sample_prior(shape=trans_shape, device=device) \
+ if self.trans_diffuser is not None else trans_ref
+
+ # apply mask
+ if diffuse_mask is not None:
+ rot_sample = apply_mask(rot_sample, rot_ref, diffuse_mask[..., None])
+ trans_sample = apply_mask(trans_sample, trans_ref, diffuse_mask[..., None])
+
+ trans_sample = self.trans_diffuser.unscale(trans_sample)
+
+ # assemble sampled rot and trans -> rigid
+ rigids_t = assemble_rigid(rot_sample, trans_sample)
+
+ if as_tensor_7:
+ rigids_t = rigids_t.to_tensor_7()
+
+ return {'rigids_t': rigids_t}
\ No newline at end of file
diff --git a/analysis/src/models/score/r3.py b/analysis/src/models/score/r3.py
new file mode 100644
index 0000000000000000000000000000000000000000..c11b6b37147e6c25ac9795ee22c4754272e4f7d8
--- /dev/null
+++ b/analysis/src/models/score/r3.py
@@ -0,0 +1,147 @@
+"""Inspired by https://github.com/jasonkyuyim/se3_diffusion/blob/master/data/r3_diffuser.py"""
+
+from math import sqrt
+import torch
+
+from src.utils.tensor_utils import inflate_array_like
+
+class R3Diffuser:
+ """VPSDE diffusion module."""
+ def __init__(
+ self,
+ min_b: float = 0.1,
+ max_b: float = 20.0,
+ coordinate_scaling: float = 1.0,
+ ):
+ self.min_b = min_b
+ self.max_b = max_b
+ self.coordinate_scaling = coordinate_scaling
+
+ def scale(self, x):
+ return x * self.coordinate_scaling
+
+ def unscale(self, x):
+ return x / self.coordinate_scaling
+
+ def b_t(self, t: torch.Tensor):
+ if torch.any(t < 0) or torch.any(t > 1):
+ raise ValueError(f'Invalid t={t}')
+ return self.min_b + t * (self.max_b - self.min_b)
+
+ def diffusion_coef(self, t):
+ return torch.sqrt(self.b_t(t))
+
+ def drift_coef(self, x, t):
+ return -0.5 * self.b_t(t) * x
+
+ def sample_prior(self, shape, device=None):
+ return torch.randn(size=shape, device=device)
+
+ def marginal_b_t(self, t):
+ return t*self.min_b + 0.5*(t**2)*(self.max_b-self.min_b)
+
+ def calc_trans_0(self, score_t, x_t, t):
+ beta_t = self.marginal_b_t(t)
+ beta_t = beta_t[..., None, None]
+ cond_var = 1 - torch.exp(-beta_t)
+ return (score_t * cond_var + x_t) / torch.exp(-0.5*beta_t)
+
+ def forward_marginal(
+ self,
+ x_0: torch.Tensor,
+ t: torch.Tensor
+ ):
+ """Samples marginal p(x(t) | x(0)).
+
+ Args:
+ x_0: [..., n, 3] initial positions in Angstroms.
+ t: continuous time in [0, 1].
+
+ Returns:
+ x_t: [..., n, 3] positions at time t in Angstroms.
+ score_t: [..., n, 3] score at time t in scaled Angstroms.
+ """
+ t = inflate_array_like(t, x_0)
+ x_0 = self.scale(x_0)
+
+ loc = torch.exp(-0.5 * self.marginal_b_t(t)) * x_0
+ scale = torch.sqrt(1 - torch.exp(-self.marginal_b_t(t)))
+ z = torch.randn_like(x_0)
+ x_t = z * scale + loc
+ score_t = self.score(x_t, x_0, t)
+
+ x_t = self.unscale(x_t)
+ return x_t, score_t
+
+ def score_scaling(self, t: torch.Tensor):
+ return 1.0 / torch.sqrt(self.conditional_var(t))
+
+ def reverse(
+ self,
+ x_t: torch.Tensor,
+ score_t: torch.Tensor,
+ t: torch.Tensor,
+ dt: float,
+ mask: torch.Tensor = None,
+ center: bool = True,
+ noise_scale: float = 1.0,
+ probability_flow: bool = True,
+ ):
+ """Simulates the reverse SDE for 1 step
+
+ Args:
+ x_t: [..., 3] current positions at time t in angstroms.
+ score_t: [..., 3] rotation score at time t.
+ t: continuous time in [0, 1].
+ dt: continuous step size in [0, 1].
+ mask: True indicates which residues to diffuse.
+ probability_flow: whether to use probability flow ODE.
+
+ Returns:
+ [..., 3] positions at next step t-1.
+ """
+ t = inflate_array_like(t, x_t)
+ x_t = self.scale(x_t)
+
+ f_t = self.drift_coef(x_t, t)
+ g_t = self.diffusion_coef(t)
+
+ z = noise_scale * torch.randn_like(score_t)
+
+ rev_drift = (f_t - g_t ** 2 * score_t) * dt * (0.5 if probability_flow else 1.)
+ rev_diffusion = 0. if probability_flow else (g_t * sqrt(dt) * z)
+ perturb = rev_drift + rev_diffusion
+
+ if mask is not None:
+ perturb *= mask[..., None]
+ else:
+ mask = torch.ones_like(x_t[..., 0])
+ x_t_1 = x_t - perturb # reverse in time
+ if center:
+ com = torch.sum(x_t_1, dim=-2) / torch.sum(mask, dim=-1)[..., None] # reduce length dim
+ x_t_1 -= com[..., None, :]
+
+ x_t_1 = self.unscale(x_t_1)
+ return x_t_1
+
+ def conditional_var(self, t, use_torch=False):
+ """Conditional variance of p(xt|x0).
+ Var[x_t|x_0] = conditional_var(t) * I
+ """
+ return 1.0 - torch.exp(-self.marginal_b_t(t))
+
+ def score(self, x_t, x_0, t, scale=False):
+ t = inflate_array_like(t, x_t)
+ if scale:
+ x_t, x_0 = self.scale(x_t), self.scale(x_0)
+ return -(x_t - torch.exp(-0.5 * self.marginal_b_t(t)) * x_0) / self.conditional_var(t)
+
+ def distribution(self, x_t, score_t, t, mask, dt):
+ x_t = self.scale(x_t)
+ f_t = self.drift_coef(x_t, t)
+ g_t = self.diffusion_coef(t)
+ std = g_t * sqrt(dt)
+ mu = x_t - (f_t - g_t**2 * score_t) * dt
+ if mask is not None:
+ mu *= mask[..., None]
+ return mu, std
\ No newline at end of file
diff --git a/analysis/src/models/score/so3.py b/analysis/src/models/score/so3.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a43caa5ba1e062020b1a7e95406427b29bb02c3
--- /dev/null
+++ b/analysis/src/models/score/so3.py
@@ -0,0 +1,371 @@
+"""Inspired by https://github.com/jasonkyuyim/se3_diffusion/blob/master/data/so3_diffuser.py"""
+
+import os
+import math
+
+import numpy as np
+import torch
+
+from src.utils.tensor_utils import inflate_array_like
+from src.common import rotation3d
+
+
+def compose_rotvec(rotvec1, rotvec2):
+ """Compose two rotation euler vectors."""
+ dtype = rotvec1.dtype
+ R1 = rotation3d.axis_angle_to_matrix(rotvec1)
+ R2 = rotation3d.axis_angle_to_matrix(rotvec2)
+ cR = torch.einsum('...ij,...jk->...ik', R1.double(), R2.double())
+ return rotation3d.matrix_to_axis_angle(cR).type(dtype)
+
+def igso3_expansion(omega, eps, L=1000, use_torch=False):
+ """Truncated sum of IGSO(3) distribution.
+
+ This function approximates the power series in equation 5 of
+ "DENOISING DIFFUSION PROBABILISTIC MODELS ON SO(3) FOR ROTATIONAL
+ ALIGNMENT"
+ Leach et al. 2022
+
+ This expression diverges from the expression in Leach in that here, eps =
+ sqrt(2) * eps_leach, if eps_leach were the scale parameter of the IGSO(3).
+
+ With this reparameterization, IGSO(3) agrees with the Brownian motion on
+ SO(3) with t=eps^2.
+
+ Args:
+ omega: rotation of Euler vector (i.e. the angle of rotation)
+ eps: std of IGSO(3).
+ L: Truncation level
+ use_torch: set true to use torch tensors, otherwise use numpy arrays.
+ """
+
+ lib = torch if use_torch else np
+ ls = lib.arange(L)
+ if use_torch:
+ ls = ls.to(omega.device)
+
+ if len(omega.shape) == 2:
+ # Used during predicted score calculation.
+ ls = ls[None, None] # [1, 1, L]
+ omega = omega[..., None] # [num_batch, num_res, 1]
+ eps = eps[..., None]
+ elif len(omega.shape) == 1:
+ # Used during cache computation.
+ ls = ls[None] # [1, L]
+ omega = omega[..., None] # [num_batch, 1]
+ else:
+ raise ValueError("Omega must be 1D or 2D.")
+ p = (2*ls + 1) * lib.exp(-ls*(ls+1)*eps**2/2) * lib.sin(omega*(ls+1/2)) / lib.sin(omega/2)
+ if use_torch:
+ return p.sum(dim=-1)
+ else:
+ return p.sum(axis=-1)
+
+
+def density(expansion, omega, marginal=True, use_torch=False):
+ """IGSO(3) density.
+
+ Args:
+ expansion: truncated approximation of the power series in the IGSO(3)
+ density.
+ omega: length of an Euler vector (i.e. angle of rotation)
+ marginal: set true to give marginal density over the angle of rotation,
+ otherwise include normalization to give density on SO(3) or a
+ rotation with angle omega.
+ """
+ lib = torch if use_torch else np
+ if marginal:
+ # if marginal, density over [0, pi], else over SO(3)
+ return expansion * (1.0 - lib.cos(omega)) / np.pi
+ else:
+ # the constant factor doesn't affect any actual calculations though
+ return expansion / 8 / np.pi**2
+
+
+def score(exp, omega, eps, L=1000, use_torch=False): # score of density over SO(3)
+ """score uses the quotient rule to compute the scaling factor for the score
+ of the IGSO(3) density.
+
+ This function is used within the Diffuser class to when computing the score
+ as an element of the tangent space of SO(3).
+
+ This uses the quotient rule of calculus, and take the derivative of the
+ log:
+ d hi(x)/lo(x) = (lo(x) d hi(x)/dx - hi(x) d lo(x)/dx) / lo(x)^2
+ and
+ d log expansion(x) / dx = (d expansion(x)/ dx) / expansion(x)
+
+ Args:
+ exp: truncated expansion of the power series in the IGSO(3) density
+ omega: length of an Euler vector (i.e. angle of rotation)
+ eps: scale parameter for IGSO(3) -- as in expansion() this scaling
+ differ from that in Leach by a factor of sqrt(2).
+ L: truncation level
+ use_torch: set true to use torch tensors, otherwise use numpy arrays.
+
+ Returns:
+ The d/d omega log IGSO3(omega; eps)/(1-cos(omega))
+
+ """
+ lib = torch if use_torch else np
+ ls = lib.arange(L)
+ if use_torch:
+ ls = ls.to(omega.device)
+ ls = ls[None]
+ if len(omega.shape) == 2:
+ ls = ls[None]
+ elif len(omega.shape) > 2:
+ raise ValueError("Omega must be 1D or 2D.")
+ omega = omega[..., None]
+ eps = eps[..., None]
+ hi = lib.sin(omega * (ls + 1 / 2))
+ dhi = (ls + 1 / 2) * lib.cos(omega * (ls + 1 / 2))
+ lo = lib.sin(omega / 2)
+ dlo = 1 / 2 * lib.cos(omega / 2)
+ dSigma = (2 * ls + 1) * lib.exp(-ls * (ls + 1) * eps**2/2) * (lo * dhi - hi * dlo) / lo ** 2
+ if use_torch:
+ dSigma = dSigma.sum(dim=-1)
+ else:
+ dSigma = dSigma.sum(axis=-1)
+ return dSigma / (exp + 1e-4)
+
+
+class SO3Diffuser:
+ def __init__(
+ self,
+ cache_dir: str = './cache',
+ schedule: str = 'logarithmic',
+ min_sigma: float = 0.1,
+ max_sigma: float = 1.5,
+ num_sigma: int = 1000,
+ num_omega: int = 1000,
+ use_cached_score: bool = False,
+ eps: float = 1e-6,
+ ):
+ self.schedule = schedule
+ self.min_sigma = min_sigma
+ self.max_sigma = max_sigma
+ self.num_sigma = num_sigma
+ self.use_cached_score = use_cached_score
+ self.eps = eps
+
+ # Discretize omegas for calculating CDFs. Skip omega=0.
+ self.discrete_omega = torch.linspace(0, np.pi, steps=num_omega+1)[1:]
+
+ # Configure cache directory.
+ replace_period = lambda x: str(x).replace('.', '_')
+ cache_dir = os.path.join(
+ cache_dir,
+ f'eps_{num_sigma}_omega_{num_omega}_min_sigma_{replace_period(min_sigma)}_max_sigma_{replace_period(max_sigma)}_schedule_{schedule}'
+ )
+ if not os.path.isdir(cache_dir):
+ os.makedirs(cache_dir)
+
+ pdf_cache = os.path.join(cache_dir, 'pdf_vals.pt')
+ cdf_cache = os.path.join(cache_dir, 'cdf_vals.pt')
+ score_norms_cache = os.path.join(cache_dir, 'score_norms.pt')
+
+ if os.path.exists(pdf_cache) \
+ and os.path.exists(cdf_cache) \
+ and os.path.exists(score_norms_cache):
+ self._pdf = torch.load(pdf_cache, map_location='cpu')
+ self._cdf = torch.load(cdf_cache, map_location='cpu')
+ self._score_norms = torch.load(score_norms_cache, map_location='cpu')
+ else:
+ disc_omega = self.discrete_omega.numpy()
+ disc_sigma = self.discrete_sigma.numpy()
+ # Compute the expansion of the power series.
+ exp_vals = np.asarray(
+ [igso3_expansion(disc_omega, sigma) for sigma in disc_sigma]
+ )
+ # Compute the pdf and cdf values for the marginal distribution of the axis-angles (readily needed for sampling)
+ self._pdf = np.asarray(
+ [density(x, disc_omega, marginal=True) for x in exp_vals]
+ )
+ self._cdf = np.asarray(
+ [pdf.cumsum() / num_omega * np.pi for pdf in self._pdf]
+ )
+ # Compute the norms of the scores. This are used to scale the rotation axis when
+ # computing the score as a vector.
+ self._score_norms = np.asarray(
+ [score(exp_vals[i], disc_omega, x) for i, x in enumerate(disc_sigma)]
+ )
+ self._pdf, self._cdf, self._score_norms = map(
+ torch.as_tensor, (self._pdf, self._cdf, self._score_norms)
+ )
+ # Cache the precomputed values
+ torch.save(obj=self._pdf, f=pdf_cache)
+ torch.save(obj=self._cdf, f=cdf_cache)
+ torch.save(obj=self._score_norms, f=score_norms_cache)
+
+ self._score_scaling = torch.sqrt(torch.abs(
+ torch.sum(self._score_norms**2 * self._pdf, dim=-1) / torch.sum(self._pdf, dim=-1)
+ )) / np.sqrt(3)
+
+ @property
+ def discrete_sigma(self):
+ return self.sigma(
+ torch.linspace(0.0, 1.0, self.num_sigma)
+ )
+
+ def sigma_idx(self, sigma: torch.Tensor):
+ """Calculates the index for discretized sigma during IGSO(3) initialization."""
+ return torch.as_tensor(np.digitize(sigma.cpu().numpy(), self.discrete_sigma) - 1,
+ dtype=torch.long)
+
+ def sigma(self, t: torch.Tensor):
+ """Extract \sigma(t) corresponding to chosen sigma schedule."""
+ if torch.any(t < 0) or torch.any(t > 1):
+ raise ValueError(f'Invalid t={t}')
+ if self.schedule == 'logarithmic':
+ return torch.log(t * math.exp(self.max_sigma) + (1 - t) * math.exp(self.min_sigma))
+ else:
+ raise ValueError(f'Unrecognize schedule {self.schedule}')
+
+ def diffusion_coef(self, t: torch.Tensor):
+ """Compute diffusion coefficient (g_t)."""
+ if self.schedule == 'logarithmic':
+ g_t = torch.sqrt(
+ 2 * (math.exp(self.max_sigma) - math.exp(self.min_sigma)) \
+ * self.sigma(t) / torch.exp(self.sigma(t))
+ )
+ else:
+ raise ValueError(f'Unrecognize schedule {self.schedule}')
+ return g_t
+
+ def t_to_idx(self, t: torch.Tensor):
+ """Helper function to go from time t to corresponding sigma_idx."""
+ return self.sigma_idx(self.sigma(t))
+
+ def sample_prior(self, shape: torch.Size, device=None):
+ t = torch.ones(shape[0], dtype=torch.float, device=device)
+ return self.sample(t, shape)
+
+ def sample(self, t: torch.Tensor, shape: torch.Size):
+ """Generates rotation vector(s) from IGSO(3).
+
+ Args:
+ t: continuous time in [0, 1].
+ shape: shape of the output tensor.
+
+ Returns:
+ (shape, ) axis-angle rotation vectors sampled from IGSO(3).
+ """
+ assert t.ndim == 1 and t.shape[0] == shape[0], \
+ f"t.shape={t.shape}, shape={shape}"
+ assert shape[-1] == 3, f"The last dim should be 3, but got shape={shape}"
+
+ z = torch.randn(shape, device=t.device) # axis-angle
+ x = z / torch.linalg.norm(z, dim=-1, keepdims=True)
+
+ # sample igso3
+ z_igso3 = torch.rand(shape[:-1]) # (num_batch, num_res)
+ igso3_scaling = []
+ for i, _t in enumerate(t):
+ t_idx = self.t_to_idx(_t).item() # [1]
+ # np.interp can accept Tensor as input
+ igso3_scaling.append(
+ np.interp(z_igso3[i], self._cdf[t_idx], self.discrete_omega),
+ ) # (num_res, )
+ igso3_scaling = torch.as_tensor(np.asarray(igso3_scaling), dtype=x.dtype, device=t.device)
+
+ return x * igso3_scaling[..., None]
+
+ def score(
+ self,
+ vec: torch.tensor,
+ t: torch.tensor,
+ ):
+ """Computes the score of IGSO(3) density as a rotation vector.
+
+ Same as score function but uses pytorch and performs a look-up.
+
+ Args:
+ vec: [..., 3] array of axis-angle rotation vectors.
+ t: continuous time in [0, 1].
+
+ Returns:
+ [..., 3] score vector in the direction of the sampled vector with
+ magnitude given by _score_norms.
+ """
+ assert t.ndim == 1 and t.shape[0] == vec.shape[0], \
+ f"t.shape={t.shape}, vec.shape={vec.shape}"
+ device = vec.device
+
+ omega = torch.linalg.norm(vec, dim=-1) + self.eps # [num_batch, num_res]
+
+ if self.use_cached_score:
+ score_norms_t = self._score_norms[self.t_to_idx(t)]
+ score_norms_t = torch.as_tensor(score_norms_t).to(device)
+ omega_idx = torch.bucketize(
+ omega, torch.as_tensor(self.discrete_omega[:-1]).to(device))
+ omega_scores_t = torch.gather(score_norms_t, 1, omega_idx)
+ else: # compute on the fly
+ sigma = self.discrete_sigma[self.t_to_idx(t)]
+ sigma = torch.as_tensor(sigma).to(device)
+ omega_vals = igso3_expansion(omega, sigma[:, None], use_torch=True)
+ omega_scores_t = score(omega_vals, omega, sigma[:, None], use_torch=True)
+
+ return omega_scores_t[..., None] * vec / (omega[..., None] + self.eps)
+
+ def score_scaling(self, t: torch.Tensor):
+ """Calculates scaling used for scores during trianing."""
+ return torch.as_tensor(self._score_scaling[self.t_to_idx(t)], device=t.device)
+
+ def forward_marginal(self, rot_0: torch.Tensor, t: torch.Tensor):
+ """Samples from the forward diffusion process at time index t.
+
+ Args:
+ rot_0: [..., 3] initial rotations.
+ t: continuous time in [0, 1].
+
+ Returns:
+ rot_t: [..., 3] noised rotation vectors.
+ rot_score: [..., 3] score of rot_t as a rotation vector.
+ """
+ rotvec_0t = self.sample(t, shape=rot_0.shape) # "delta_rot"
+ rot_score = self.score(rotvec_0t, t)
+
+ # Right multiply => vector_add in Euclidean space
+ rot_t = compose_rotvec(rot_0, rotvec_0t)
+ return rot_t, rot_score
+
+ def reverse(
+ self,
+ rot_t: torch.Tensor,
+ score_t: torch.Tensor,
+ t: torch.Tensor,
+ dt: float,
+ mask: torch.Tensor = None,
+ noise_scale: float = 1.0,
+ probability_flow: bool = True,
+ ):
+ """Simulates the reverse SDE for 1 step using the Geodesic random walk.
+
+ Args:
+ rot_t: [..., 3] current rotations at time t.
+ score_t: [..., 3] rotation score at time t.
+ t: continuous time in [0, 1].
+ dt: continuous step size in [0, 1].
+ add_noise: set False to set diffusion coefficent to 0.
+ mask: True indicates which residues to diffuse.
+ probability_flow: set True to use probability flow ODE.
+
+ Returns:
+ [..., 3] rotation vector at next step.
+ """
+ t = inflate_array_like(t, rot_t)
+
+ g_t = self.diffusion_coef(t)
+ z = noise_scale * torch.randn_like(score_t)
+
+ rev_drift = -1.0 * (g_t ** 2) * score_t * dt * (0.5 if probability_flow else 1.)
+ rev_diffusion = 0. if probability_flow else (g_t * np.sqrt(dt) * z)
+ perturb = rev_drift + rev_diffusion
+
+ if mask is not None:
+ perturb *= mask[..., None]
+
+ # Right multiply.
+ rot_t_1 = compose_rotvec(rot_t, -1.0 * perturb) # reverse in time
+ return rot_t_1
\ No newline at end of file
diff --git a/analysis/src/train.py b/analysis/src/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9973055301bf5844aef703317f637742f924c6c
--- /dev/null
+++ b/analysis/src/train.py
@@ -0,0 +1,135 @@
+from typing import Any, Dict, List, Optional, Tuple
+
+import hydra
+import lightning as L
+import rootutils
+import torch
+from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
+from omegaconf import DictConfig
+
+rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+# ------------------------------------------------------------------------------------ #
+# the setup_root above is equivalent to:
+# - adding project root dir to PYTHONPATH
+# (so you don't need to force user to install project as a package)
+# (necessary before importing any local modules e.g. `from src import utils`)
+# - setting up PROJECT_ROOT environment variable
+# (which is used as a base for paths in "configs/paths/default.yaml")
+# (this way all filepaths are the same no matter where you run the code)
+# - loading environment variables from ".env" in root dir
+#
+# you can remove it if you:
+# 1. either install project as a package or move entry files to project root dir
+# 2. set `root_dir` to "." in "configs/paths/default.yaml"
+#
+# more info: https://github.com/ashleve/rootutils
+# ------------------------------------------------------------------------------------ #
+
+from src.utils import (
+ RankedLogger,
+ extras,
+ get_metric_value,
+ instantiate_callbacks,
+ instantiate_loggers,
+ log_hyperparameters,
+ task_wrapper,
+ checkpoint_utils,
+)
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+@task_wrapper
+def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+ training.
+
+ This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+ failure. Useful for multiruns, saving info about the crash, etc.
+
+ :param cfg: A DictConfig configuration composed by Hydra.
+ :return: A tuple with metrics and dict with all instantiated objects.
+ """
+ # set seed for random number generators in pytorch, numpy and python.random
+ if cfg.get("seed"):
+ L.seed_everything(cfg.seed, workers=True)
+
+ log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+ datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+ log.info(f"Instantiating model <{cfg.model._target_}>")
+ model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+ log.info("Instantiating callbacks...")
+ callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))
+
+ log.info("Instantiating loggers...")
+ logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
+
+ log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+ trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
+
+ object_dict = {
+ "cfg": cfg,
+ "datamodule": datamodule,
+ "model": model,
+ "callbacks": callbacks,
+ "logger": logger,
+ "trainer": trainer,
+ }
+
+ if logger:
+ log.info("Logging hyperparameters!")
+ log_hyperparameters(object_dict)
+
+ model, ckpt_path = checkpoint_utils.load_model_checkpoint(model, cfg.get("ckpt_path"))
+
+ if cfg.get("train"):
+ log.info("Starting training!")
+ trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+
+ train_metrics = trainer.callback_metrics
+
+ if cfg.get("test"):
+ log.info("Starting testing!")
+ ckpt_path = trainer.checkpoint_callback.best_model_path
+ if ckpt_path == "":
+ log.warning("Best ckpt not found! Using current weights for testing...")
+ ckpt_path = None
+ trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+ log.info(f"Best ckpt path: {ckpt_path}")
+
+ test_metrics = trainer.callback_metrics
+
+ # merge train and test metrics
+ metric_dict = {**train_metrics, **test_metrics}
+
+ return metric_dict, object_dict
+
+
+@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
+def main(cfg: DictConfig) -> Optional[float]:
+ """Main entry point for training.
+
+ :param cfg: DictConfig configuration composed by Hydra.
+ :return: Optional[float] with optimized metric value.
+ """
+ # apply extra utilities
+ # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+ extras(cfg)
+
+ # train the model
+ metric_dict, _ = train(cfg)
+
+ # safely retrieve metric value for hydra-based hyperparameter optimization
+ metric_value = get_metric_value(
+ metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
+ )
+
+ # return optimized metric
+ return metric_value
+
+
+if __name__ == "__main__":
+ main()
diff --git a/analysis/src/utils/__init__.py b/analysis/src/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b0707ca57ec89fc5f5cb1a023135eeb756a8e1e
--- /dev/null
+++ b/analysis/src/utils/__init__.py
@@ -0,0 +1,5 @@
+from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
+from src.utils.logging_utils import log_hyperparameters
+from src.utils.pylogger import RankedLogger
+from src.utils.rich_utils import enforce_tags, print_config_tree
+from src.utils.utils import extras, get_metric_value, task_wrapper
diff --git a/analysis/src/utils/__pycache__/__init__.cpython-310.pyc b/analysis/src/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc38620891cdc2e6b3c469b37dfb5ccf5971370c
Binary files /dev/null and b/analysis/src/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/analysis/src/utils/__pycache__/__init__.cpython-39.pyc b/analysis/src/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e16fdb32c1ff6c131fe52d84b7b2eaefb7456ed
Binary files /dev/null and b/analysis/src/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/analysis/src/utils/__pycache__/checkpoint_utils.cpython-310.pyc b/analysis/src/utils/__pycache__/checkpoint_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0e5c64df79d10dee3601326a0ed066cc4659070
Binary files /dev/null and b/analysis/src/utils/__pycache__/checkpoint_utils.cpython-310.pyc differ
diff --git a/analysis/src/utils/__pycache__/checkpoint_utils.cpython-39.pyc b/analysis/src/utils/__pycache__/checkpoint_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44d97c36604a77e84444152aa33ba67fd17f0f68
Binary files /dev/null and b/analysis/src/utils/__pycache__/checkpoint_utils.cpython-39.pyc differ
diff --git a/analysis/src/utils/__pycache__/instantiators.cpython-310.pyc b/analysis/src/utils/__pycache__/instantiators.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ebab309ccf49d8e01a2dadcccf996dc4e425d97
Binary files /dev/null and b/analysis/src/utils/__pycache__/instantiators.cpython-310.pyc differ
diff --git a/analysis/src/utils/__pycache__/instantiators.cpython-39.pyc b/analysis/src/utils/__pycache__/instantiators.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e23b93bac589979f49f9b8fc6fac281b3026070
Binary files /dev/null and b/analysis/src/utils/__pycache__/instantiators.cpython-39.pyc differ
diff --git a/analysis/src/utils/__pycache__/logging_utils.cpython-310.pyc b/analysis/src/utils/__pycache__/logging_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e403b520b31a1bc9184a0bf74301017f0ce1e58
Binary files /dev/null and b/analysis/src/utils/__pycache__/logging_utils.cpython-310.pyc differ
diff --git a/analysis/src/utils/__pycache__/logging_utils.cpython-39.pyc b/analysis/src/utils/__pycache__/logging_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26c790bd93755550c51ec8cba7e6d2f8537a1e7c
Binary files /dev/null and b/analysis/src/utils/__pycache__/logging_utils.cpython-39.pyc differ
diff --git a/analysis/src/utils/__pycache__/plot_utils.cpython-310.pyc b/analysis/src/utils/__pycache__/plot_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1784bf4f81c7600125742d41edea5f19838c106
Binary files /dev/null and b/analysis/src/utils/__pycache__/plot_utils.cpython-310.pyc differ
diff --git a/analysis/src/utils/__pycache__/plot_utils.cpython-39.pyc b/analysis/src/utils/__pycache__/plot_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7dad302ccc0ef4a29b00490c6ac9eea37a24fc7
Binary files /dev/null and b/analysis/src/utils/__pycache__/plot_utils.cpython-39.pyc differ
diff --git a/analysis/src/utils/__pycache__/pylogger.cpython-310.pyc b/analysis/src/utils/__pycache__/pylogger.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ef8808d9784553d45122ebebf168931145d9a9e
Binary files /dev/null and b/analysis/src/utils/__pycache__/pylogger.cpython-310.pyc differ
diff --git a/analysis/src/utils/__pycache__/pylogger.cpython-39.pyc b/analysis/src/utils/__pycache__/pylogger.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..315a10f77f0a243ef68628c7df04759004806549
Binary files /dev/null and b/analysis/src/utils/__pycache__/pylogger.cpython-39.pyc differ
diff --git a/analysis/src/utils/__pycache__/rich_utils.cpython-310.pyc b/analysis/src/utils/__pycache__/rich_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..acc7ea30ef6b89b3944f92dc3e6d3f3f836b7244
Binary files /dev/null and b/analysis/src/utils/__pycache__/rich_utils.cpython-310.pyc differ
diff --git a/analysis/src/utils/__pycache__/rich_utils.cpython-39.pyc b/analysis/src/utils/__pycache__/rich_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e48ab6a604dcd05449c7a29afe102bfc1e0e488
Binary files /dev/null and b/analysis/src/utils/__pycache__/rich_utils.cpython-39.pyc differ
diff --git a/analysis/src/utils/__pycache__/tensor_utils.cpython-39.pyc b/analysis/src/utils/__pycache__/tensor_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d9647e909d1c230486fcdd49ea9533e6602e9f55
Binary files /dev/null and b/analysis/src/utils/__pycache__/tensor_utils.cpython-39.pyc differ
diff --git a/analysis/src/utils/__pycache__/utils.cpython-310.pyc b/analysis/src/utils/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1deb35fe5b79b73bfc4da02dece2e9a529c7e461
Binary files /dev/null and b/analysis/src/utils/__pycache__/utils.cpython-310.pyc differ
diff --git a/analysis/src/utils/__pycache__/utils.cpython-39.pyc b/analysis/src/utils/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e50a4c7a3c2f693aa2bda2891fae98ead5fe034
Binary files /dev/null and b/analysis/src/utils/__pycache__/utils.cpython-39.pyc differ
diff --git a/analysis/src/utils/checkpoint_utils.py b/analysis/src/utils/checkpoint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ec721430698a9f9dd1ae4da8244baef1d4b5ea8
--- /dev/null
+++ b/analysis/src/utils/checkpoint_utils.py
@@ -0,0 +1,28 @@
+import torch
+
+def load_model_checkpoint(model, ckpt_path):
+ """Load state dict from checkpoint file.
+
+ :param model: The model to load the state dict into.
+ :param ckpt_path: The path to the checkpoint file.
+ """
+ if ckpt_path is None:
+ return model, None
+
+ # The ckpt_path ending with .ckpt is a checkpoint file saved by pytorch-lightning.
+ # If the ckpt_path is a .pth file, it is viewed as a checkpoint file saved by pytorch
+ # such that only net parameters are loaded.
+ # (This may avoid the ambiguity of loading #epochs/lr for finetuning)
+ if ckpt_path.endswith(".pth"):
+ net_params = torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict']
+ net_params = {k.replace('net.', ''): v for k, v in net_params.items()}
+ model.net.load_state_dict(net_params)
+ ckpt_path = None
+ elif ckpt_path.endswith(".ckpt"):
+ # will be handled later by the trainer
+ pass
+ else:
+ # suffix check
+ raise ValueError(f"ckpt_path {ckpt_path} is not a valid checkpoint file.")
+
+ return model, ckpt_path
\ No newline at end of file
diff --git a/analysis/src/utils/instantiators.py b/analysis/src/utils/instantiators.py
new file mode 100644
index 0000000000000000000000000000000000000000..82b9278a465d39565942f862442ebe79549825d7
--- /dev/null
+++ b/analysis/src/utils/instantiators.py
@@ -0,0 +1,56 @@
+from typing import List
+
+import hydra
+from lightning import Callback
+from lightning.pytorch.loggers import Logger
+from omegaconf import DictConfig
+
+from src.utils import pylogger
+
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+ """Instantiates callbacks from config.
+
+ :param callbacks_cfg: A DictConfig object containing callback configurations.
+ :return: A list of instantiated callbacks.
+ """
+ callbacks: List[Callback] = []
+
+ if not callbacks_cfg:
+ log.warning("No callback configs found! Skipping..")
+ return callbacks
+
+ if not isinstance(callbacks_cfg, DictConfig):
+ raise TypeError("Callbacks config must be a DictConfig!")
+
+ for _, cb_conf in callbacks_cfg.items():
+ if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+ log.info(f"Instantiating callback <{cb_conf._target_}>")
+ callbacks.append(hydra.utils.instantiate(cb_conf))
+
+ return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+ """Instantiates loggers from config.
+
+ :param logger_cfg: A DictConfig object containing logger configurations.
+ :return: A list of instantiated loggers.
+ """
+ logger: List[Logger] = []
+
+ if not logger_cfg:
+ log.warning("No logger configs found! Skipping...")
+ return logger
+
+ if not isinstance(logger_cfg, DictConfig):
+ raise TypeError("Logger config must be a DictConfig!")
+
+ for _, lg_conf in logger_cfg.items():
+ if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+ log.info(f"Instantiating logger <{lg_conf._target_}>")
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+ return logger
diff --git a/analysis/src/utils/logging_utils.py b/analysis/src/utils/logging_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..360abcdceec82e551995f756ce6ec3b2d06ae641
--- /dev/null
+++ b/analysis/src/utils/logging_utils.py
@@ -0,0 +1,57 @@
+from typing import Any, Dict
+
+from lightning_utilities.core.rank_zero import rank_zero_only
+from omegaconf import OmegaConf
+
+from src.utils import pylogger
+
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
+ """Controls which config parts are saved by Lightning loggers.
+
+ Additionally saves:
+ - Number of model parameters
+
+ :param object_dict: A dictionary containing the following objects:
+ - `"cfg"`: A DictConfig object containing the main config.
+ - `"model"`: The Lightning model.
+ - `"trainer"`: The Lightning trainer.
+ """
+ hparams = {}
+
+ cfg = OmegaConf.to_container(object_dict["cfg"])
+ model = object_dict["model"]
+ trainer = object_dict["trainer"]
+
+ if not trainer.logger:
+ log.warning("Logger not found! Skipping hyperparameter logging...")
+ return
+
+ hparams["model"] = cfg["model"]
+
+ # save number of model parameters
+ hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+ hparams["model/params/trainable"] = sum(
+ p.numel() for p in model.parameters() if p.requires_grad
+ )
+ hparams["model/params/non_trainable"] = sum(
+ p.numel() for p in model.parameters() if not p.requires_grad
+ )
+
+ hparams["data"] = cfg["data"]
+ hparams["trainer"] = cfg["trainer"]
+
+ hparams["callbacks"] = cfg.get("callbacks")
+ hparams["extras"] = cfg.get("extras")
+
+ hparams["task_name"] = cfg.get("task_name")
+ hparams["tags"] = cfg.get("tags")
+ hparams["ckpt_path"] = cfg.get("ckpt_path")
+ hparams["seed"] = cfg.get("seed")
+
+ # send hparams to all loggers
+ for logger in trainer.loggers:
+ logger.log_hyperparams(hparams)
diff --git a/analysis/src/utils/multiprocs.py b/analysis/src/utils/multiprocs.py
new file mode 100644
index 0000000000000000000000000000000000000000..01cdbe3476485adcc234753ebc2a737425e599ef
--- /dev/null
+++ b/analysis/src/utils/multiprocs.py
@@ -0,0 +1,34 @@
+"""
+Utilities for multiprocessing.
+"""
+
+import multiprocessing as mp
+
+import numpy as np
+
+
+def mp_map(func, args, n_cpu=None, verbose=True):
+ """
+ Multiprocessing wrapper for map function.
+ """
+ if n_cpu is None:
+ n_cpu = mp.cpu_count()
+ pool = mp.Pool(n_cpu)
+ if verbose:
+ print('Running on {} CPUs'.format(n_cpu))
+ result = pool.map(func, args)
+ pool.close()
+ pool.join()
+ return result
+
+def parse_mp_result(result_list, sort_fn=None):
+ """Parse output from mp_map"""
+ if sort_fn is None:
+ return result_list
+ result = sorted(result_list, key=sort_fn, ) # ascending
+ return result
+
+
+# unit test
+if __name__ == '__main__':
+ pass
\ No newline at end of file
diff --git a/analysis/src/utils/plot_utils.py b/analysis/src/utils/plot_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3316ac75194ae0e3befb6bd76fe427c6459ac3c0
--- /dev/null
+++ b/analysis/src/utils/plot_utils.py
@@ -0,0 +1,100 @@
+from typing import Dict, Optional
+import os
+
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from scipy.stats import gaussian_kde
+from scipy import interpolate
+
+
+##################################################
+FONTSIZE = 18
+##################################################
+
+
+def scatterplot_2d(
+ data_dict: Dict,
+ save_to: str,
+ ref_key: str = 'target',
+ xlabel: str = 'tIC1',
+ ylabel: str = 'tIC2',
+ n_max_point: int = 1000,
+ pop_ref: bool = False,
+ xylim_key: bool = 'PDB_clusters',
+ plot_kde: bool = False,
+ density_mapping: Optional[Dict] = None
+):
+ # configure min max
+ if xylim_key and xylim_key in data_dict:
+ xylim = data_dict.pop(xylim_key)
+ # plot
+ x_max = max(xylim[:,0])
+ x_min = min(xylim[:,0])
+ y_max = max(xylim[:,1])
+ y_min = min(xylim[:,1])
+ else:
+ xylim = None
+ x_max = max(data_dict[ref_key][:,0])
+ x_min = min(data_dict[ref_key][:,0])
+ y_max = max(data_dict[ref_key][:,1])
+ y_min = min(data_dict[ref_key][:,1])
+
+ # Add margin.
+ x_min -= (x_max - x_min)/5.0
+ x_max += (x_max - x_min)/5.0
+ y_min -= (y_max - y_min)/5.0
+ y_max += (y_max - y_min)/5.0
+
+ # Remove reference data to save time.
+ if pop_ref:
+ data_dict.pop(ref_key)
+
+ # plot tica
+ print(f">>> Plotting scatter in 2D space. Image save to {save_to}")
+
+ # Configure subplots.
+ plot_n_row = len(data_dict) // 5 if len(data_dict) > 5 else 1 # at most 6 columns
+ plot_n_columns = len(data_dict) // plot_n_row if len(data_dict) > 5 else len(data_dict)
+ plt.figure(figsize=(6 * plot_n_columns , plot_n_row * 6))
+
+ i = 0
+ for k, v in data_dict.items():
+ i += 1
+ plt.subplot(plot_n_row, plot_n_columns, i)
+
+ if k != ref_key and v.shape[0] > n_max_point: # subsample for visualize
+ idx = np.random.choice(v.shape[0], n_max_point, replace=False)
+ v = v[idx]
+
+ if v.shape[0] < v.shape[1]:
+ print(f"Warning: {k} has more dimensions than samples, using uniform density.")
+ density = np.ones_like(v[:,0])
+ density /= density.sum()
+ else:
+ cov = np.transpose(v)
+ density = gaussian_kde(cov)(cov)
+
+ # Optional precomputed density mapping.
+ if density_mapping and k in density_mapping:
+ density = density_mapping[k]
+
+ plt.scatter(v[:, 0], v[:,1], s=10, alpha=0.7, c=density, cmap="mako_r", vmin=-0.05, vmax=0.40)
+ # sns.scatterplot(x=v[:, 0], y=v[:,1], s=10, alpha=0.7, c=density, cmap="mako_r", vmin=-0.05, vmax=0.40)
+
+ if plot_kde:
+ sns.kdeplot(x=data_dict[ref_key][:, 0], y=data_dict[ref_key][:,1]) # landscape
+
+ if xylim is not None:
+ plt.scatter(xylim[:,0], xylim[:,1], s=40, marker="o", c="none", edgecolors="tab:red") # cluster centers
+
+ plt.xlabel(xlabel, fontsize=FONTSIZE, fontfamily="sans-serif")
+ if (i-1) % plot_n_columns == 0:
+ plt.ylabel(ylabel, fontsize=FONTSIZE, fontfamily="sans-serif")
+
+ plt.xlim(x_min, x_max)
+ plt.ylim(y_min, y_max)
+ plt.title(k, fontsize=FONTSIZE, fontfamily="sans-serif")
+
+ plt.tight_layout()
+ plt.savefig(save_to, dpi=500)
\ No newline at end of file
diff --git a/analysis/src/utils/pylogger.py b/analysis/src/utils/pylogger.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4ee8675ebde11b2a43b0679a03cd88d9268bc71
--- /dev/null
+++ b/analysis/src/utils/pylogger.py
@@ -0,0 +1,51 @@
+import logging
+from typing import Mapping, Optional
+
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+class RankedLogger(logging.LoggerAdapter):
+ """A multi-GPU-friendly python command line logger."""
+
+ def __init__(
+ self,
+ name: str = __name__,
+ rank_zero_only: bool = False,
+ extra: Optional[Mapping[str, object]] = None,
+ ) -> None:
+ """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+ with their rank prefixed in the log message.
+
+ :param name: The name of the logger. Default is ``__name__``.
+ :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
+ :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+ """
+ logger = logging.getLogger(name)
+ super().__init__(logger=logger, extra=extra)
+ self.rank_zero_only = rank_zero_only
+
+ def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
+ """Delegate a log call to the underlying logger, after prefixing its message with the rank
+ of the process it's being logged from. If `'rank'` is provided, then the log will only
+ occur on that rank/process.
+
+ :param level: The level to log at. Look at `logging.__init__.py` for more information.
+ :param msg: The message to log.
+ :param rank: The rank to log at.
+ :param args: Additional args to pass to the underlying logging function.
+ :param kwargs: Any additional keyword args to pass to the underlying logging function.
+ """
+ if self.isEnabledFor(level):
+ msg, kwargs = self.process(msg, kwargs)
+ current_rank = getattr(rank_zero_only, "rank", None)
+ if current_rank is None:
+ raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
+ msg = rank_prefixed_message(msg, current_rank)
+ if self.rank_zero_only:
+ if current_rank == 0:
+ self.logger.log(level, msg, *args, **kwargs)
+ else:
+ if rank is None:
+ self.logger.log(level, msg, *args, **kwargs)
+ elif current_rank == rank:
+ self.logger.log(level, msg, *args, **kwargs)
diff --git a/analysis/src/utils/rich_utils.py b/analysis/src/utils/rich_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeec6806bb1e4a15a04b91b710a546231590ab14
--- /dev/null
+++ b/analysis/src/utils/rich_utils.py
@@ -0,0 +1,99 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from lightning_utilities.core.rank_zero import rank_zero_only
+from omegaconf import DictConfig, OmegaConf, open_dict
+from rich.prompt import Prompt
+
+from src.utils import pylogger
+
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+@rank_zero_only
+def print_config_tree(
+ cfg: DictConfig,
+ print_order: Sequence[str] = (
+ "data",
+ "model",
+ "callbacks",
+ "logger",
+ "trainer",
+ "paths",
+ "extras",
+ ),
+ resolve: bool = False,
+ save_to_file: bool = False,
+) -> None:
+ """Prints the contents of a DictConfig as a tree structure using the Rich library.
+
+ :param cfg: A DictConfig composed by Hydra.
+ :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
+ "callbacks", "logger", "trainer", "paths", "extras")``.
+ :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
+ :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
+ """
+ style = "dim"
+ tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+ queue = []
+
+ # add fields from `print_order` to queue
+ for field in print_order:
+ queue.append(field) if field in cfg else log.warning(
+ f"Field '{field}' not found in config. Skipping '{field}' config printing..."
+ )
+
+ # add all the other fields to queue (not specified in `print_order`)
+ for field in cfg:
+ if field not in queue:
+ queue.append(field)
+
+ # generate config tree from queue
+ for field in queue:
+ branch = tree.add(field, style=style, guide_style=style)
+
+ config_group = cfg[field]
+ if isinstance(config_group, DictConfig):
+ branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+ else:
+ branch_content = str(config_group)
+
+ branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+ # print config tree
+ rich.print(tree)
+
+ # save config tree to file
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+ rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+ """Prompts user to input tags from command line if no tags are provided in config.
+
+ :param cfg: A DictConfig composed by Hydra.
+ :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
+ """
+ if not cfg.get("tags"):
+ if "id" in HydraConfig().cfg.hydra.job:
+ raise ValueError("Specify tags before launching a multirun!")
+
+ log.warning("No tags provided in config. Prompting user to input tags...")
+ tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+ tags = [t.strip() for t in tags.split(",") if t != ""]
+
+ with open_dict(cfg):
+ cfg.tags = tags
+
+ log.info(f"Tags: {cfg.tags}")
+
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+ rich.print(cfg.tags, file=file)
diff --git a/analysis/src/utils/tensor_utils.py b/analysis/src/utils/tensor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b61735bd28b939ffaa196f14346b56d1429b0f60
--- /dev/null
+++ b/analysis/src/utils/tensor_utils.py
@@ -0,0 +1,149 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import List
+
+
+import torch
+import torch.nn as nn
+
+
+def inflate_array_like(array, target):
+ """ (tested)
+ Inflates the array (array) with only a single axis (i.e. shape = (batch_size,), or possibly more empty
+ axes (i.e. shape (batch_size, 1, ..., 1)) to match the target shape.
+ Args:
+ array: (B, )
+ target: (B, ...)
+
+ Returns:
+ array: (B, ...)
+ """
+ if isinstance(array, float):
+ return array
+
+ diff_dims = target.ndim - array.ndim
+ assert diff_dims >= 0, f'Error: target.ndim {target.ndim} < array.ndim {array.ndim}'
+ if diff_dims == 0:
+ return array
+ assert target.shape[:array.ndim] == array.shape[:array.ndim], f'Error: target.shape[:array.ndim] {target.shape[:array.ndim]} != array.shape[:array.ndim] {array.shape[:array.ndim]}'
+ return array[(...,) + (None,) * diff_dims]
+
+
+def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
+ zero_index = -1 * len(inds)
+ first_inds = list(range(len(tensor.shape[:zero_index])))
+ return tensor.permute(first_inds + [zero_index + i for i in inds])
+
+def flatten_final_dims(t: torch.Tensor, no_dims: int):
+ return t.reshape(t.shape[:-no_dims] + (-1,))
+
+def sum_except_batch(t: torch.Tensor, batch_dims: int=1):
+ return t.reshape(t.shape[:batch_dims] + (-1,)).sum(dim=-1)
+
+def masked_mean(mask, value, dim, eps=1e-4):
+ mask = mask.expand(*value.shape)
+ return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
+
+
+def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
+ boundaries = torch.linspace(
+ min_bin, max_bin, no_bins - 1, device=pts.device
+ )
+ dists = torch.sqrt(
+ torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
+ )
+ return torch.bucketize(dists, boundaries)
+
+
+def dict_multimap(fn, dicts):
+ first = dicts[0]
+ new_dict = {}
+ for k, v in first.items():
+ all_v = [d[k] for d in dicts]
+ if type(v) is dict:
+ new_dict[k] = dict_multimap(fn, all_v)
+ else:
+ new_dict[k] = fn(all_v)
+
+ return new_dict
+
+
+def one_hot(x, v_bins):
+ reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
+ diffs = x[..., None] - reshaped_bins
+ am = torch.argmin(torch.abs(diffs), dim=-1)
+ return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
+
+
+def batched_gather(data, inds, dim=0, no_batch_dims=0):
+ ranges = []
+ for i, s in enumerate(data.shape[:no_batch_dims]):
+ r = torch.arange(s)
+ r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
+ ranges.append(r)
+
+ remaining_dims = [
+ slice(None) for _ in range(len(data.shape) - no_batch_dims)
+ ]
+ remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
+ ranges.extend(remaining_dims)
+ return data[ranges]
+
+
+# With tree_map, a poor man's JAX tree_map
+def dict_map(fn, dic, leaf_type):
+ new_dict = {}
+ for k, v in dic.items():
+ if type(v) is dict:
+ new_dict[k] = dict_map(fn, v, leaf_type)
+ else:
+ new_dict[k] = tree_map(fn, v, leaf_type)
+
+ return new_dict
+
+
+def tree_map(fn, tree, leaf_type):
+ if isinstance(tree, dict):
+ return dict_map(fn, tree, leaf_type)
+ elif isinstance(tree, list):
+ return [tree_map(fn, x, leaf_type) for x in tree]
+ elif isinstance(tree, tuple):
+ return tuple([tree_map(fn, x, leaf_type) for x in tree])
+ elif isinstance(tree, leaf_type):
+ return fn(tree)
+ else:
+ print(type(tree))
+ raise ValueError("Not supported")
+
+
+tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
+
+def _fetch_dims(tree):
+ shapes = []
+ tree_type = type(tree)
+ if tree_type is dict:
+ for v in tree.values():
+ shapes.extend(_fetch_dims(v))
+ elif tree_type is list or tree_type is tuple:
+ for t in tree:
+ shapes.extend(_fetch_dims(t))
+ elif tree_type is torch.Tensor:
+ shapes.append(tree.shape)
+ else:
+ raise ValueError("Not supported")
+
+ return shapes
diff --git a/analysis/src/utils/utils.py b/analysis/src/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..02b55765ad3de9441bed931a577ebbb3b669fda4
--- /dev/null
+++ b/analysis/src/utils/utils.py
@@ -0,0 +1,119 @@
+import warnings
+from importlib.util import find_spec
+from typing import Any, Callable, Dict, Optional, Tuple
+
+from omegaconf import DictConfig
+
+from src.utils import pylogger, rich_utils
+
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+def extras(cfg: DictConfig) -> None:
+ """Applies optional utilities before the task is started.
+
+ Utilities:
+ - Ignoring python warnings
+ - Setting tags from command line
+ - Rich config printing
+
+ :param cfg: A DictConfig object containing the config tree.
+ """
+ # return if no `extras` config
+ if not cfg.get("extras"):
+ log.warning("Extras config not found! ")
+ return
+
+ # disable python warnings
+ if cfg.extras.get("ignore_warnings"):
+ log.info("Disabling python warnings! ")
+ warnings.filterwarnings("ignore")
+
+ # prompt user to input tags from command line if none are provided in the config
+ if cfg.extras.get("enforce_tags"):
+ log.info("Enforcing tags! ")
+ rich_utils.enforce_tags(cfg, save_to_file=True)
+
+ # pretty print config tree using Rich library
+ if cfg.extras.get("print_config"):
+ log.info("Printing config tree with Rich! ")
+ rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+def task_wrapper(task_func: Callable) -> Callable:
+ """Optional decorator that controls the failure behavior when executing the task function.
+
+ This wrapper can be used to:
+ - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+ - save the exception to a `.log` file
+ - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+ - etc. (adjust depending on your needs)
+
+ Example:
+ ```
+ @utils.task_wrapper
+ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ ...
+ return metric_dict, object_dict
+ ```
+
+ :param task_func: The task function to be wrapped.
+
+ :return: The wrapped task function.
+ """
+
+ def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ # execute the task
+ try:
+ metric_dict, object_dict = task_func(cfg=cfg)
+
+ # things to do if exception occurs
+ except Exception as ex:
+ # save exception to `.log` file
+ log.exception("")
+
+ # some hyperparameter combinations might be invalid or cause out-of-memory errors
+ # so when using hparam search plugins like Optuna, you might want to disable
+ # raising the below exception to avoid multirun failure
+ raise ex
+
+ # things to always do after either success or exception
+ finally:
+ # display output dir path in terminal
+ log.info(f"Output dir: {cfg.paths.output_dir}")
+
+ # always close wandb run (even if exception occurs so multirun won't fail)
+ if find_spec("wandb"): # check if wandb is installed
+ import wandb
+
+ if wandb.run:
+ log.info("Closing wandb!")
+ wandb.finish()
+
+ return metric_dict, object_dict
+
+ return wrap
+
+
+def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
+ """Safely retrieves value of the metric logged in LightningModule.
+
+ :param metric_dict: A dict containing metric values.
+ :param metric_name: If provided, the name of the metric to retrieve.
+ :return: If a metric name was provided, the value of the metric.
+ """
+ if not metric_name:
+ log.info("Metric name is None! Skipping metric value retrieval...")
+ return None
+
+ if metric_name not in metric_dict:
+ raise Exception(
+ f"Metric value not found! \n"
+ "Make sure metric name logged in LightningModule is correct!\n"
+ "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+ )
+
+ metric_value = metric_dict[metric_name].item()
+ log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+ return metric_value
diff --git a/analysis/utils.py b/analysis/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce32888f21cc0f107d39c5a3264fedcafa431e6a
--- /dev/null
+++ b/analysis/utils.py
@@ -0,0 +1,74 @@
+import numpy as np
+import os
+import re
+from data import protein
+from openfold.utils import rigid_utils
+
+
+Rigid = rigid_utils.Rigid
+
+
+def create_full_prot(
+ atom37: np.ndarray,
+ atom37_mask: np.ndarray,
+ aatype=None,
+ b_factors=None,
+ ):
+ assert atom37.ndim == 3
+ assert atom37.shape[-1] == 3
+ assert atom37.shape[-2] == 37
+ n = atom37.shape[0]
+ residue_index = np.arange(n)
+ chain_index = np.zeros(n)
+ if b_factors is None:
+ b_factors = np.zeros([n, 37])
+ if aatype is None:
+ aatype = np.zeros(n, dtype=int)
+ return protein.Protein(
+ atom_positions=atom37,
+ atom_mask=atom37_mask,
+ aatype=aatype,
+ residue_index=residue_index,
+ chain_index=chain_index,
+ b_factors=b_factors)
+
+
+def write_prot_to_pdb(
+ prot_pos: np.ndarray,
+ file_path: str,
+ aatype: np.ndarray=None,
+ overwrite=False,
+ no_indexing=False,
+ b_factors=None,
+ ):
+ if overwrite:
+ max_existing_idx = 0
+ else:
+ file_dir = os.path.dirname(file_path)
+ file_name = os.path.basename(file_path).strip('.pdb')
+ existing_files = [x for x in os.listdir(file_dir) if file_name in x]
+ max_existing_idx = max([
+ int(re.findall(r'_(\d+).pdb', x)[0]) for x in existing_files if re.findall(r'_(\d+).pdb', x)
+ if re.findall(r'_(\d+).pdb', x)] + [0])
+ if not no_indexing:
+ save_path = file_path.replace('.pdb', '') + f'_{max_existing_idx+1}.pdb'
+ else:
+ save_path = file_path
+ with open(save_path, 'w') as f:
+ if prot_pos.ndim == 4:
+ for t, pos37 in enumerate(prot_pos):
+ atom37_mask = np.sum(np.abs(pos37), axis=-1) > 1e-7
+ prot = create_full_prot(
+ pos37, atom37_mask, aatype=aatype, b_factors=b_factors)
+ pdb_prot = protein.to_pdb(prot, model=t + 1, add_end=False)
+ f.write(pdb_prot)
+ elif prot_pos.ndim == 3:
+ atom37_mask = np.sum(np.abs(prot_pos), axis=-1) > 1e-7
+ prot = create_full_prot(
+ prot_pos, atom37_mask, aatype=aatype, b_factors=b_factors)
+ pdb_prot = protein.to_pdb(prot, model=1, add_end=False)
+ f.write(pdb_prot)
+ else:
+ raise ValueError(f'Invalid positions shape {prot_pos.shape}')
+ f.write('END')
+ return save_path
diff --git a/configs/base.yaml b/configs/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8cc27fcf4b9c6a2d28d07a361b20f9645314aff6
--- /dev/null
+++ b/configs/base.yaml
@@ -0,0 +1,144 @@
+data:
+ # CSV for path and metadata to training examples.
+ dataset:
+ seed: 123
+ use_rotate_enhance: False
+ split_frac: 1.0 # 1.0 means all data except validation set is used for training
+ use_split: True # whether cut the sequence to slices
+ split_len: 600 # 328
+ min_num_res: 32
+ max_num_res: 2000
+ cache_num_res: 0
+ subset: null
+ samples_per_eval_length: 5
+ num_eval_lengths: 32
+ max_eval_length: 2000
+ max_valid_num: 50
+ csv_path: ./dataset/ATLAS/select/pkl/metadata_merged.csv
+ loader:
+ num_workers: 8
+ prefetch_factor: 10
+ sampler:
+ max_batch_size: 64
+ max_num_res_squared: 500000 # dynamic batch size
+ use_batch_repeats: false
+ num_batches: null
+
+interpolant:
+ min_t: 0.01
+ separate_t: true
+ provide_kappa: true
+ hierarchical_t: false
+ add_noise: True
+ rots:
+ train_schedule: linear
+ sample_schedule: exp
+ exp_rate: 10
+ noise_scale: 1.5
+ trans:
+ batch_ot: true
+ train_schedule: linear
+ sample_schedule: linear
+ noise_scale: 10.0 # noise for esmfold_pred to get prior
+ sample_temp: 1.0
+ vpsde_bmin: 0.1
+ vpsde_bmax: 20.0
+ sampling:
+ num_timesteps: 100
+ do_sde: false
+ self_condition: ${model.edge_features.self_condition}
+
+model:
+ dropout: 0.2 # 0.2
+ node_embed_size: 256
+ edge_embed_size: 128
+ symmetric: False
+ use_torsions: True
+ use_adapter_node: True
+ use_mid_bb_update: False
+ use_mid_bb_update_e3: False
+ use_e3_transformer: False
+ node_features:
+ c_s: ${model.node_embed_size}
+ c_pos_emb: 1280
+ # c_pos_emb: ${model.edge_embed_size}
+ c_timestep_emb: ${model.node_embed_size}
+ embed_diffuse_mask: False
+ separate_t: ${interpolant.separate_t}
+ max_num_res: 2000
+ timestep_int: 1000
+ dropout: ${model.dropout}
+ edge_features:
+ single_bias_transition_n: 2
+ c_s: ${model.node_embed_size}
+ c_p: ${model.edge_embed_size}
+ relpos_k: 64
+ use_rbf: True
+ num_rbf: 64
+ feat_dim: ${model.edge_embed_size}
+ num_bins: 64
+ # max_dist: angstrom
+ max_dist: 30.0 # 50.0
+ dropout: ${model.dropout}
+ self_condition: True
+ ipa:
+ c_s: ${model.node_embed_size}
+ c_z: ${model.edge_embed_size}
+ c_hidden: ${model.edge_embed_size}
+ no_heads: 16
+ no_qk_points: 32
+ no_v_points: 8
+ seq_tfmr_num_heads: 16
+ seq_tfmr_num_layers: 2
+ num_blocks: 6
+ dropout: ${model.dropout}
+
+experiment:
+ debug: False
+ seed: 123
+ num_devices: 1
+ # warm_start: ./weights/esm_egf4.ckpt
+ warm_start: null
+ warm_start_cfg_override: True
+ use_swa: False
+ batch_ot:
+ enabled: True
+ cost: kabsch
+ noise_per_sample: 1
+ permute: False
+ training:
+ min_plddt_mask: null
+ t_normalize_clip: 0.9
+ loss: se3_vf_loss
+ # trans_scale: 0.1
+ translation_loss_weight: 0.2
+ rotation_loss_weights: 1.0 # 1.0
+ bb_atom_scale: 0.1 # 0.1
+ aux_loss_weight: 1.0 # 1.0
+ aux_loss_t_pass: 0.25
+ aatype_loss_weight: 0.0
+ wandb:
+ name: loss_full
+ project: P2DFlow
+ save_code: True
+ tags: []
+ optimizer:
+ lr: 1e-4 # 1e-4
+ trainer:
+ overfit_batches: 0
+ min_epochs: 1 # prevents early stopping
+ max_epochs: 2000
+ accelerator: gpu
+ log_every_n_steps: 1
+ deterministic: False
+ # strategy: ddp
+ strategy: ddp_find_unused_parameters_true
+ check_val_every_n_epoch: 1
+ accumulate_grad_batches: 2
+ gradient_clip_val: 10
+ checkpointer:
+ dirpath: ckpt/${experiment.wandb.project}/${experiment.wandb.name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
+ save_last: True
+ save_top_k: 2000
+ monitor: valid/rmsd_loss
+ mode: min
diff --git a/configs/inference.yaml b/configs/inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..86793d802acfc85a33e1f94049f067786c8e0de0
--- /dev/null
+++ b/configs/inference.yaml
@@ -0,0 +1,15 @@
+inference:
+ # Use this to write with date-time stamp.
+ name: ${now:%Y-%m-%d}_${now:%H-%M}
+ seed: 123
+ ckpt_path: ./weights/pretrained.ckpt
+ output_dir: inference_outputs/
+ use_gpu: True
+ num_gpus: 1
+
+ samples:
+ validset_path: ./inference/valid_seq.csv
+ esm_savepath: ${inference.output_dir}
+ sample_num: 250
+ sample_batch: 5
+
diff --git a/data/ESMfold_pred.py b/data/ESMfold_pred.py
new file mode 100644
index 0000000000000000000000000000000000000000..3782f88727d98c9f782b9f364e0e74811b7b5457
--- /dev/null
+++ b/data/ESMfold_pred.py
@@ -0,0 +1,35 @@
+import esm
+import torch
+
+from Bio import SeqIO
+
+class ESMFold_Pred():
+ def __init__(self, device):
+ self._folding_model = esm.pretrained.esmfold_v1().eval()
+ self._folding_model.requires_grad_(False)
+ self._folding_model.to(device)
+
+ def predict_str(self, pdbfile, save_path, max_seq_len = 1500):
+ seq_record = SeqIO.parse(pdbfile, "pdb-atom")
+ count = 0
+ seq_list = []
+ for record in seq_record:
+ seq = str(record.seq)
+ # seq = seq.replace("X","")
+
+ if len(seq) > max_seq_len:
+ continue
+
+ print(f'seq {count}:',seq)
+ seq_list.append(seq)
+ count += 1
+
+ for idx, seq in enumerate(seq_list):
+ with torch.no_grad():
+ output = self._folding_model.infer_pdb(seq)
+ with open(save_path, "w+") as f:
+ f.write(output)
+ break # only infer for the first seq
+
+
+
diff --git a/data/__pycache__/ESMfold_pred.cpython-310.pyc b/data/__pycache__/ESMfold_pred.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..90c24df2a1c5df266ab49f8f00d0371a9e26ea56
Binary files /dev/null and b/data/__pycache__/ESMfold_pred.cpython-310.pyc differ
diff --git a/data/__pycache__/all_atom.cpython-310.pyc b/data/__pycache__/all_atom.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..465801d3621feaeb3df0fd0a490d84641954e3fa
Binary files /dev/null and b/data/__pycache__/all_atom.cpython-310.pyc differ
diff --git a/data/__pycache__/cal_trans_rotmats.cpython-310.pyc b/data/__pycache__/cal_trans_rotmats.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e54e6cf8ca9e5f14c5025946f335a6a02434734
Binary files /dev/null and b/data/__pycache__/cal_trans_rotmats.cpython-310.pyc differ
diff --git a/data/__pycache__/errors.cpython-310.pyc b/data/__pycache__/errors.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ceeb5873186dc6804d8e68e9c5dc67932bea353
Binary files /dev/null and b/data/__pycache__/errors.cpython-310.pyc differ
diff --git a/data/__pycache__/interpolant.cpython-310.pyc b/data/__pycache__/interpolant.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..414afb84f0859964fd0a27243a241bcea766f814
Binary files /dev/null and b/data/__pycache__/interpolant.cpython-310.pyc differ
diff --git a/data/__pycache__/parsers.cpython-310.pyc b/data/__pycache__/parsers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7aa7575b05a659c8d9fea738835b6b53b10ed091
Binary files /dev/null and b/data/__pycache__/parsers.cpython-310.pyc differ
diff --git a/data/__pycache__/pdb_dataloader.cpython-310.pyc b/data/__pycache__/pdb_dataloader.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d864214417cd92ee5c85fa2a42ade625eccd025
Binary files /dev/null and b/data/__pycache__/pdb_dataloader.cpython-310.pyc differ
diff --git a/data/__pycache__/protein.cpython-310.pyc b/data/__pycache__/protein.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38d42ce9ba8247026fb7220e63b82bb51f7c6f1d
Binary files /dev/null and b/data/__pycache__/protein.cpython-310.pyc differ
diff --git a/data/__pycache__/protein.cpython-38.pyc b/data/__pycache__/protein.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e20a9345bf101231b0840ce932c373f81a5e21b8
Binary files /dev/null and b/data/__pycache__/protein.cpython-38.pyc differ
diff --git a/data/__pycache__/repr.cpython-310.pyc b/data/__pycache__/repr.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c46ecaf7b5c0ebe47982c57c86cc1707dc54943c
Binary files /dev/null and b/data/__pycache__/repr.cpython-310.pyc differ
diff --git a/data/__pycache__/residue_constants.cpython-310.pyc b/data/__pycache__/residue_constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bec5e8f79ec87bd090fc6b342da8e54efa3dc019
Binary files /dev/null and b/data/__pycache__/residue_constants.cpython-310.pyc differ
diff --git a/data/__pycache__/residue_constants.cpython-38.pyc b/data/__pycache__/residue_constants.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a7aa3129887b89f164b60dffb65189952c7f9c40
Binary files /dev/null and b/data/__pycache__/residue_constants.cpython-38.pyc differ
diff --git a/data/__pycache__/so3_utils.cpython-310.pyc b/data/__pycache__/so3_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f865ab345204a0c5824d8a65813991dffdfd02ce
Binary files /dev/null and b/data/__pycache__/so3_utils.cpython-310.pyc differ
diff --git a/data/__pycache__/utils.cpython-310.pyc b/data/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d50c90627c6c52db534b8eeb036277693a3b88a7
Binary files /dev/null and b/data/__pycache__/utils.cpython-310.pyc differ
diff --git a/data/all_atom.py b/data/all_atom.py
new file mode 100644
index 0000000000000000000000000000000000000000..05e7bd41fa54fea9663b9e5eec780aa25c16517b
--- /dev/null
+++ b/data/all_atom.py
@@ -0,0 +1,371 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Utilities for calculating all atom representations.
+Code adapted from OpenFold.
+"""
+
+import torch
+from openfold.data import data_transforms
+from openfold.np import residue_constants
+from openfold.utils import rigid_utils as ru
+from data import utils as du
+
+Rigid = ru.Rigid
+Rotation = ru.Rotation
+
+# Residue Constants from OpenFold/AlphaFold2.
+
+
+IDEALIZED_POS = torch.tensor(residue_constants.restype_atom14_rigid_group_positions)
+DEFAULT_FRAMES = torch.tensor(residue_constants.restype_rigid_group_default_frame)
+ATOM_MASK = torch.tensor(residue_constants.restype_atom14_mask)
+GROUP_IDX = torch.tensor(residue_constants.restype_atom14_to_rigid_group)
+
+IDEALIZED_POS_37 = torch.tensor(residue_constants.restype_atom37_rigid_group_positions)
+ATOM_MASK_37 = torch.tensor(residue_constants.restype_atom37_mask)
+GROUP_IDX_37 = torch.tensor(residue_constants.restype_atom37_to_rigid_group)
+
+def to_atom37(trans, rots, aatype=None, torsions_with_CB=None, get_mask=False):
+ num_batch, num_res, _ = trans.shape
+
+ if torsions_with_CB is None: # (B,L,)
+ torsions_with_CB = torch.concat(
+ [torch.zeros((num_batch,num_res,8,1),device=trans.device),
+ torch.ones((num_batch,num_res,8,1),device=trans.device)],
+ dim=-1
+ ) # (B,L,8,2)
+
+ # final_atom37 = compute_backbone(
+ # du.create_rigid(rots, trans),
+ # torch.zeros(num_batch, num_res, 2, device=trans.device)
+ # )[0]
+
+ final_atom37, atom37_mask = compute_atom37_pos(
+ du.create_rigid(rots, trans),
+ torsions_with_CB,
+ aatype=aatype,
+ )[:2] # (B,L,37,3)
+
+ if get_mask:
+ return final_atom37, atom37_mask
+ else:
+ return final_atom37
+
+def torsion_angles_to_frames(
+ r: Rigid, # type: ignore [valid-type]
+ alpha: torch.Tensor,
+ aatype: torch.Tensor,
+ bb_rot = None
+):
+ """Conversion method of torsion angles to frames provided the backbone.
+
+ Args:
+ r: Backbone rigid groups.
+ alpha: Torsion angles. (B,L,7,2)
+ aatype: residue types.
+
+ Returns:
+ All 8 frames corresponding to each torsion frame.
+
+
+
+
+
+ !!! May need to set omega and fai angle to be zero !!!
+
+
+
+
+
+ """
+ # [*, N, 8, 4, 4]
+ with torch.no_grad():
+ default_4x4 = DEFAULT_FRAMES.to(aatype.device)[aatype, ...] # type: ignore [attr-defined]
+
+ # [*, N, 8] transformations, i.e.
+ # One [*, N, 8, 3, 3] rotation matrix and
+ # One [*, N, 8, 3] translation matrix
+ default_r = r.from_tensor_4x4(default_4x4) # (B,L,8)
+
+ if bb_rot is None:
+ bb_rot = alpha.new_zeros(((1,) * len(alpha.shape[:-1]))+(2,)) # (1,1,1,2)
+ bb_rot[..., 1] = 1
+ bb_rot = bb_rot.expand(*alpha.shape[:-2], -1, -1) # (B,L,1,2)
+
+ alpha = torch.cat([bb_rot, alpha], dim=-2) # (B,L,8,2)
+
+ # [*, N, 8, 3, 3]
+ # Produces rotation matrices of the form:
+ # [
+ # [1, 0 , 0 ],
+ # [0, a_2,-a_1],
+ # [0, a_1, a_2]
+ # ]
+ # This follows the original code rather than the supplement, which uses
+ # different indices.
+
+ all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape) # (B,L,8,3,3)
+ all_rots[..., 0, 0] = 1
+ all_rots[..., 1, 1] = alpha[..., 1]
+ all_rots[..., 1, 2] = -alpha[..., 0]
+ all_rots[..., 2, 1:] = alpha
+
+ all_rots = Rigid(Rotation(rot_mats=all_rots), None)
+
+ all_frames = default_r.compose(all_rots)
+
+ chi2_frame_to_frame = all_frames[..., 5]
+ chi3_frame_to_frame = all_frames[..., 6]
+ chi4_frame_to_frame = all_frames[..., 7]
+
+ chi1_frame_to_bb = all_frames[..., 4]
+ chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
+ chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
+ chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
+
+ all_frames_to_bb = Rigid.cat(
+ [
+ all_frames[..., :5],
+ chi2_frame_to_bb.unsqueeze(-1),
+ chi3_frame_to_bb.unsqueeze(-1),
+ chi4_frame_to_bb.unsqueeze(-1),
+ ],
+ dim=-1,
+ )
+
+ all_frames_to_global = r[..., None].compose(all_frames_to_bb) # type: ignore [index]
+
+ return all_frames_to_global
+
+
+def prot_to_torsion_angles(aatype, atom37, atom37_mask):
+ """Calculate torsion angle features from protein features."""
+ prot_feats = {
+ "aatype": aatype,
+ "all_atom_positions": atom37,
+ "all_atom_mask": atom37_mask,
+ }
+ torsion_angles_feats = data_transforms.atom37_to_torsion_angles()(prot_feats)
+ torsion_angles = torsion_angles_feats["torsion_angles_sin_cos"]
+ torsion_mask = torsion_angles_feats["torsion_angles_mask"]
+ return torsion_angles, torsion_mask
+
+
+def frames_to_atom14_pos(
+ r: Rigid, # type: ignore [valid-type]
+ aatype: torch.Tensor,
+):
+ """Convert frames to their idealized all atom representation.
+
+ Args:
+ r: All rigid groups. [B,L,8]
+ aatype: Residue types. [B,L]
+
+ Returns:
+
+ """
+ with torch.no_grad():
+ group_mask = GROUP_IDX.to(aatype.device)[aatype, ...] # (B,L,14)
+ group_mask = torch.nn.functional.one_hot(
+ group_mask,
+ num_classes=DEFAULT_FRAMES.shape[-3], # (21,8,4,4)
+ ) # (B,L,14,8)
+ frame_atom_mask = ATOM_MASK.to(aatype.device)[aatype, ...].unsqueeze(-1) # (B,L,14,1)
+ frame_null_pos = IDEALIZED_POS.to(aatype.device)[aatype, ...] # (B,L,14,3)
+
+ t_atoms_to_global = r[..., None, :] * group_mask # (B,L,14,8)
+
+ t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) # (B,L,14)
+
+ pred_positions = t_atoms_to_global.apply(frame_null_pos) # (B,L,14,3)
+ pred_positions = pred_positions * frame_atom_mask
+
+ return pred_positions
+
+def frames_to_atom37_pos(
+ r: Rigid, # type: ignore [valid-type]
+ aatype: torch.Tensor,
+):
+ """Convert frames to their idealized all atom representation.
+
+ Args:
+ r: All rigid groups. [B,L]
+ aatype: Residue types. [B,L]
+
+ Returns:
+
+ """
+ with torch.no_grad():
+ group_mask = GROUP_IDX_37.to(aatype.device)[aatype, ...] # (B,L,37)
+ group_mask = torch.nn.functional.one_hot(
+ group_mask,
+ num_classes=DEFAULT_FRAMES.shape[-3], # (21,8,4,4)
+ ) # (B,L,37,8)
+ frame_atom_mask = ATOM_MASK_37.to(aatype.device)[aatype, ...].unsqueeze(-1) # (B,L,37,1)
+ frame_null_pos = IDEALIZED_POS_37.to(aatype.device)[aatype, ...] # (B,L,37,3)
+
+ t_atoms_to_global = r[..., None, :] * group_mask # (B,L,37,8)
+
+ t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) # (B,L,37)
+
+ pred_positions = t_atoms_to_global.apply(frame_null_pos) # (B,L,37,3)
+ pred_positions = pred_positions * frame_atom_mask
+
+ return pred_positions, frame_atom_mask[...,0]
+
+#! may need to remove it
+def compute_backbone(bb_rigids, psi_torsions, aatype=None):
+ torsion_angles = torch.tile(
+ psi_torsions[..., None, :], tuple([1 for _ in range(len(bb_rigids.shape))]) + (7, 1)
+ )
+ '''
+ psi_torsions[..., None, :].shape: (B,L,1,2)
+ torsion_angles.shape: (B,L,7,2)
+ bb_rigids.shape: (B,L)
+ '''
+
+ if aatype is None:
+ aatype = torch.zeros(bb_rigids.shape, device=bb_rigids.device).long()
+
+ all_frames = torsion_angles_to_frames(
+ bb_rigids,
+ torsion_angles,
+ aatype,
+ )
+ atom14_pos = frames_to_atom14_pos(all_frames, aatype) # (B,L,14,3)
+ atom37_bb_pos = torch.zeros(bb_rigids.shape + (37, 3), device=bb_rigids.device) # (B,L,37,3)
+
+ # atom14 bb order = ['N', 'CA', 'C', 'O', 'CB']
+ # atom37 bb order = ['N', 'CA', 'C', 'CB', 'O']
+ atom37_bb_pos[..., :3, :] = atom14_pos[..., :3, :]
+ atom37_mask = torch.any(atom37_bb_pos, axis=-1) # mask atom with all 0 xyz
+
+ return atom37_bb_pos, atom37_mask, aatype, atom14_pos
+
+def compute_atom37_pos(bb_rigids, torsions_with_CB, aatype=None):
+ '''
+ torsions_with_CB.shape: (B,L,8,2)
+ bb_rigids.shape: (B,L)
+ '''
+
+ if aatype is None:
+ aatype = torch.zeros(bb_rigids.shape, device=bb_rigids.device).long()
+
+ all_frames = torsion_angles_to_frames(
+ bb_rigids,
+ torsions_with_CB[:,:,1:,:],
+ aatype,
+ bb_rot = torsions_with_CB[:,:,0:1,:],
+ )
+
+ atom14_pos = frames_to_atom14_pos(all_frames, aatype) # (B,L,14,3)
+ atom37_pos,atom37_mask = frames_to_atom37_pos(all_frames, aatype) # (B,L,37,3)
+
+ return atom37_pos, atom37_mask, aatype, atom14_pos
+
+def calculate_neighbor_angles(R_ac, R_ab):
+ """Calculate angles between atoms c <- a -> b.
+
+ Parameters
+ ----------
+ R_ac: Tensor, shape = (N,3)
+ Vector from atom a to c.
+ R_ab: Tensor, shape = (N,3)
+ Vector from atom a to b.
+
+ Returns
+ -------
+ angle_cab: Tensor, shape = (N,)
+ Angle between atoms c <- a -> b.
+ """
+ # cos(alpha) = (u * v) / (|u|*|v|)
+ x = torch.sum(R_ac * R_ab, dim=1) # shape = (N,)
+ # sin(alpha) = |u x v| / (|u|*|v|)
+ y = torch.cross(R_ac, R_ab).norm(dim=-1) # shape = (N,)
+ # avoid that for y == (0,0,0) the gradient wrt. y becomes NaN
+ y = torch.max(y, torch.tensor(1e-9))
+ angle = torch.atan2(y, x)
+ return angle
+
+
+def vector_projection(R_ab, P_n):
+ """
+ Project the vector R_ab onto a plane with normal vector P_n.
+
+ Parameters
+ ----------
+ R_ab: Tensor, shape = (N,3)
+ Vector from atom a to b.
+ P_n: Tensor, shape = (N,3)
+ Normal vector of a plane onto which to project R_ab.
+
+ Returns
+ -------
+ R_ab_proj: Tensor, shape = (N,3)
+ Projected vector (orthogonal to P_n).
+ """
+ a_x_b = torch.sum(R_ab * P_n, dim=-1)
+ b_x_b = torch.sum(P_n * P_n, dim=-1)
+ return R_ab - (a_x_b / b_x_b)[:, None] * P_n
+
+
+def transrot_to_atom37(transrot_traj, res_mask, aatype=None, torsions_with_CB=None):
+ atom37_traj = []
+ res_mask = res_mask.detach().cpu()
+ num_batch = res_mask.shape[0]
+
+ for trans, rots in transrot_traj:
+ atom37 = to_atom37(trans, rots, aatype=aatype, torsions_with_CB=torsions_with_CB,get_mask=False)
+ atom37 = atom37.detach().cpu()
+ # batch_atom37 = []
+ # for i in range(num_batch):
+ # batch_atom37.append(
+ # du.adjust_oxygen_pos(atom37[i], res_mask[i])
+ # )
+ # atom37_traj.append(torch.stack(batch_atom37))
+ atom37_traj.append(atom37)
+ return atom37_traj
+
+# def atom37_from_trans_rot(trans, rots, res_mask):
+# rigids = du.create_rigid(rots, trans)
+# atom37 = compute_backbone(
+# rigids,
+# torch.zeros(
+# trans.shape[0],
+# trans.shape[1],
+# 2,
+# device=trans.device
+# )
+# )[0]
+# atom37 = atom37.detach().cpu()
+# batch_atom37 = []
+# num_batch = res_mask.shape[0]
+# for i in range(num_batch):
+# batch_atom37.append(
+# du.adjust_oxygen_pos(atom37[i], res_mask[i])
+# )
+# return torch.stack(batch_atom37)
+
+
+# def process_trans_rot_traj(trans_traj, rots_traj, res_mask):
+# res_mask = res_mask.detach().cpu()
+# atom37_traj = [
+# atom37_from_trans_rot(trans, rots, res_mask)
+# for trans, rots in zip(trans_traj, rots_traj)
+# ]
+# atom37_traj = torch.stack(atom37_traj).swapaxes(0, 1)
+# return atom37_traj
diff --git a/data/cal_trans_rotmats.py b/data/cal_trans_rotmats.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c75c43cd2793763228c56ede69ebb283e8beaee
--- /dev/null
+++ b/data/cal_trans_rotmats.py
@@ -0,0 +1,60 @@
+import torch
+import dataclasses
+
+import numpy as np
+
+from Bio import PDB
+from data import parsers, errors
+from data import utils as du
+from openfold.data import data_transforms
+from openfold.utils import rigid_utils
+
+
+def cal_trans_rotmats(save_path):
+ metadata = {}
+ parser = PDB.PDBParser(QUIET=True)
+ structure = parser.get_structure('test', save_path)
+
+ # Extract all chains
+ struct_chains = {
+ chain.id.upper(): chain
+ for chain in structure.get_chains()}
+ metadata['num_chains'] = len(struct_chains)
+ # Extract features
+ struct_feats = []
+ all_seqs = set()
+ for chain_id, chain in struct_chains.items():
+ # Convert chain id into int
+ chain_id = du.chain_str_to_int(chain_id)
+ chain_prot = parsers.process_chain(chain, chain_id)
+ chain_dict = dataclasses.asdict(chain_prot)
+ chain_dict = du.parse_chain_feats(chain_dict)
+ all_seqs.add(tuple(chain_dict['aatype']))
+ struct_feats.append(chain_dict)
+ if len(all_seqs) == 1:
+ metadata['quaternary_category'] = 'homomer'
+ else:
+ metadata['quaternary_category'] = 'heteromer'
+ complex_feats = du.concat_np_features(struct_feats, False)
+ # Process geometry features
+ complex_aatype = complex_feats['aatype']
+ metadata['seq_len'] = len(complex_aatype)
+ modeled_idx = np.where(complex_aatype != 20)[0]
+ if np.sum(complex_aatype != 20) == 0:
+ raise errors.LengthError('No modeled residues')
+ min_modeled_idx = np.min(modeled_idx)
+ max_modeled_idx = np.max(modeled_idx)
+ metadata['modeled_seq_len'] = max_modeled_idx - min_modeled_idx + 1
+ complex_feats['modeled_idx'] = modeled_idx
+
+ processed_feats = du.parse_chain_feats(complex_feats)
+ chain_feats_temp = {
+ 'aatype': torch.tensor(processed_feats['aatype']).long(),
+ 'all_atom_positions': torch.tensor(processed_feats['atom_positions']).double(),
+ 'all_atom_mask': torch.tensor(processed_feats['atom_mask']).double()
+ }
+ chain_feats_temp = data_transforms.atom37_to_frames(chain_feats_temp)
+ curr_rigid = rigid_utils.Rigid.from_tensor_4x4(chain_feats_temp['rigidgroups_gt_frames'])[:, 0]
+ trans = curr_rigid.get_trans().cpu()
+ rotmats = curr_rigid.get_rots().get_rot_mats().cpu()
+ return trans, rotmats
\ No newline at end of file
diff --git a/data/errors.py b/data/errors.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdf1026227fa6f2881c0a17b9a4cc53f3ac77317
--- /dev/null
+++ b/data/errors.py
@@ -0,0 +1,26 @@
+"""Error class for handled errors."""
+
+
+class DataError(Exception):
+ """Data exception."""
+ pass
+
+
+class FileExistsError(DataError):
+ """Raised when file already exists."""
+ pass
+
+
+class MmcifParsingError(DataError):
+ """Raised when mmcif parsing fails."""
+ pass
+
+
+class ResolutionError(DataError):
+ """Raised when resolution isn't acceptable."""
+ pass
+
+
+class LengthError(DataError):
+ """Raised when length isn't acceptable."""
+ pass
\ No newline at end of file
diff --git a/data/interpolant.py b/data/interpolant.py
new file mode 100644
index 0000000000000000000000000000000000000000..b66ec5f4442759bfccdc29b1be1dea6d273f01b1
--- /dev/null
+++ b/data/interpolant.py
@@ -0,0 +1,310 @@
+import torch
+import numpy as np
+from data import so3_utils
+from data import utils as du
+from scipy.spatial.transform import Rotation
+from data import all_atom
+import copy
+from scipy.optimize import linear_sum_assignment
+
+
+def _centered_gaussian(num_batch, num_res, device):
+ noise = torch.randn(num_batch, num_res, 3, device=device)
+ return noise - torch.mean(noise, dim=-2, keepdims=True)
+
+def _uniform_so3(num_batch, num_res, device):
+ return torch.tensor(
+ Rotation.random(num_batch*num_res).as_matrix(),
+ device=device,
+ dtype=torch.float32,
+ ).reshape(num_batch, num_res, 3, 3)
+
+def _trans_diffuse_mask(trans_t, trans_1, diffuse_mask):
+ return trans_t * diffuse_mask[..., None] + trans_1 * (1 - diffuse_mask[..., None])
+
+def _rots_diffuse_mask(rotmats_t, rotmats_1, diffuse_mask):
+ return (
+ rotmats_t * diffuse_mask[..., None, None]
+ + rotmats_1 * (1 - diffuse_mask[..., None, None])
+ )
+
+
+class Interpolant:
+
+ def __init__(self, cfg):
+ self._cfg = cfg
+ self._rots_cfg = cfg.rots
+ self._trans_cfg = cfg.trans
+ self._sample_cfg = cfg.sampling
+ self.add_noise = cfg.add_noise
+ self._igso3 = None
+
+ @property
+ def igso3(self):
+ if self._igso3 is None:
+ sigma_grid = torch.linspace(0.1, 1.5, 1000)
+ self._igso3 = so3_utils.SampleIGSO3(
+ 1000, sigma_grid, cache_dir='.cache')
+ return self._igso3
+
+ def set_device(self, device):
+ self._device = device
+
+ def sample_t(self, num_batch):
+ # t: [min_t, 1-min_t]
+ t = torch.rand(num_batch, device=self._device)
+ return t * (1 - 2*self._cfg.min_t) + self._cfg.min_t
+
+ def _esmfold_gaussian(self, num_batch, num_res, device, trans_esmfold):
+ noise = torch.randn(num_batch, num_res, 3, device=device) # (B,L,3)
+ noise = self._trans_cfg.noise_scale * noise + trans_esmfold
+ return noise - torch.mean(noise, dim=-2, keepdims=True)
+
+ def _corrupt_trans(self, trans_1, t, res_mask, trans_esmfold):
+ # trans_nm_0 = _centered_gaussian(*res_mask.shape, self._device)
+ # trans_0 = trans_nm_0 * du.NM_TO_ANG_SCALE
+
+ if self.add_noise:
+ trans_0 = self._esmfold_gaussian(*res_mask.shape, self._device, trans_esmfold)
+ else:
+ trans_0 = trans_esmfold
+
+
+ trans_0 = self._batch_ot(trans_0, trans_1, res_mask)
+ trans_t = (1 - t[..., None]) * trans_0 + t[..., None] * trans_1
+ trans_t = _trans_diffuse_mask(trans_t, trans_1, res_mask)
+ return trans_t * res_mask[..., None]
+
+ def _batch_ot(self, trans_0, trans_1, res_mask):
+ num_batch, num_res = trans_0.shape[:2]
+ noise_idx, gt_idx = torch.where(
+ torch.ones(num_batch, num_batch))
+ batch_nm_0 = trans_0[noise_idx]
+ batch_nm_1 = trans_1[gt_idx]
+ batch_mask = res_mask[gt_idx]
+ aligned_nm_0, aligned_nm_1, _ = du.batch_align_structures(
+ batch_nm_0, batch_nm_1, mask=batch_mask
+ )
+
+ aligned_nm_0 = aligned_nm_0.reshape(num_batch, num_batch, num_res, 3)
+ aligned_nm_1 = aligned_nm_1.reshape(num_batch, num_batch, num_res, 3)
+
+ # Compute cost matrix of aligned noise to ground truth
+ batch_mask = batch_mask.reshape(num_batch, num_batch, num_res)
+ cost_matrix = torch.sum(
+ torch.linalg.norm(aligned_nm_0 - aligned_nm_1, dim=-1), dim=-1
+ ) / torch.sum(batch_mask, dim=-1)
+ noise_perm, gt_perm = linear_sum_assignment(du.to_numpy(cost_matrix))
+ return aligned_nm_0[(tuple(gt_perm), tuple(noise_perm))]
+ # return aligned_nm_0
+
+ def _esmfold_igso3(self, res_mask, rotmats_esmfold):
+ num_batch, num_res = res_mask.shape
+ noisy_rotmats = self.igso3.sample(
+ torch.tensor([self._rots_cfg.noise_scale]),
+ num_batch*num_res
+ ).to(self._device)
+ noisy_rotmats = noisy_rotmats.reshape(num_batch, num_res, 3, 3)
+ rotmats_0 = torch.einsum(
+ "...ij,...jk->...ik", rotmats_esmfold, noisy_rotmats)
+ return rotmats_0
+
+ def _corrupt_rotmats(self, rotmats_1, t, res_mask, rotmats_esmfold):
+ # num_batch, num_res = res_mask.shape
+ # noisy_rotmats = self.igso3.sample(
+ # torch.tensor([1.5]),
+ # num_batch*num_res
+ # ).to(self._device)
+ # noisy_rotmats = noisy_rotmats.reshape(num_batch, num_res, 3, 3)
+ # rotmats_0 = torch.einsum(
+ # "...ij,...jk->...ik", rotmats_1, noisy_rotmats)
+
+ if self.add_noise:
+ rotmats_0 = self._esmfold_igso3(res_mask, rotmats_esmfold)
+ else:
+ rotmats_0 = rotmats_esmfold
+
+
+ rotmats_t = so3_utils.geodesic_t(t[..., None], rotmats_1, rotmats_0)
+ identity = torch.eye(3, device=self._device)
+ rotmats_t = (
+ rotmats_t * res_mask[..., None, None]
+ + identity[None, None] * (1 - res_mask[..., None, None])
+ )
+ return _rots_diffuse_mask(rotmats_t, rotmats_1, res_mask)
+
+ def corrupt_batch(self, batch):
+ noisy_batch = copy.deepcopy(batch)
+
+ # [B, N, 3]
+ trans_1 = batch['trans_1'] # Angstrom
+
+ # [B, N, 3, 3]
+ rotmats_1 = batch['rotmats_1']
+
+ # [B, N]
+ res_mask = batch['res_mask']
+ num_batch, _ = res_mask.shape
+
+ # [B, 1]
+ t = self.sample_t(num_batch)[:, None]
+ noisy_batch['t'] = t
+
+ # Apply corruptions
+ trans_t = self._corrupt_trans(trans_1, t, res_mask, batch['trans_esmfold'])
+
+ noisy_batch['trans_t'] = trans_t
+
+ rotmats_t = self._corrupt_rotmats(rotmats_1, t, res_mask, batch['rotmats_esmfold'])
+
+ noisy_batch['rotmats_t'] = rotmats_t
+
+
+ # noisy_batch['t'] = 0.5 * torch.ones_like(t)
+ # noisy_batch['trans_t'] = batch['trans_1']
+ # noisy_batch['rotmats_t'] = batch['rotmats_1']
+
+
+ return noisy_batch
+
+ def rot_sample_kappa(self, t):
+ if self._rots_cfg.sample_schedule == 'exp':
+ return 1 - torch.exp(-t*self._rots_cfg.exp_rate)
+ elif self._rots_cfg.sample_schedule == 'linear':
+ return t
+ else:
+ raise ValueError(
+ f'Invalid schedule: {self._rots_cfg.sample_schedule}')
+
+ def _trans_euler_step(self, d_t, t, trans_1, trans_t):
+ trans_vf = (trans_1 - trans_t) / (1 - t)
+ return trans_t + trans_vf * d_t
+
+ def _rots_euler_step(self, d_t, t, rotmats_1, rotmats_t):
+ if self._rots_cfg.sample_schedule == 'linear':
+ scaling = 1 / (1 - t)
+ elif self._rots_cfg.sample_schedule == 'exp':
+ scaling = self._rots_cfg.exp_rate
+ else:
+ raise ValueError(
+ f'Unknown sample schedule {self._rots_cfg.sample_schedule}')
+ return so3_utils.geodesic_t(
+ scaling * d_t, rotmats_1, rotmats_t)
+
+ def sample(
+ self,
+ batch,
+ model,
+ ):
+ res_mask = batch['res_mask']
+ num_batch = batch['aatype'].shape[0]
+ num_res = batch['aatype'].shape[1]
+ aatype = batch['aatype']
+ motif_mask = batch.get('motif_mask',torch.ones(aatype.shape))
+
+
+ # Set-up initial prior samples
+
+ # trans_0 = _centered_gaussian(
+ # num_batch, num_res, self._device) * du.NM_TO_ANG_SCALE
+ # rotmats_0 = _uniform_so3(num_batch, num_res, self._device)
+
+
+ if self.add_noise:
+ trans_0 = self._esmfold_gaussian(*res_mask.shape, self._device, batch['trans_esmfold'])
+ rotmats_0 = self._esmfold_igso3(res_mask, batch['rotmats_esmfold'])
+ else:
+ trans_0 = batch['trans_esmfold']
+ rotmats_0 = batch['rotmats_esmfold']
+
+
+ if not torch.all(motif_mask==torch.ones(aatype.shape,device=motif_mask.device)):
+ trans_0 = motif_mask[...,None]*trans_0+(1-motif_mask[...,None])*batch['trans_fix']
+ rotmats_0 = motif_mask[...,None,None]*rotmats_0+(1-motif_mask[...,None,None])*batch['rotmats_fix']
+
+
+ # Set-up time
+
+ ts = torch.linspace(
+ self._cfg.min_t, 1.0, self._sample_cfg.num_timesteps)
+
+ # ts = torch.linspace(np.exp(self._cfg.min_t), np.exp(1.0), self._sample_cfg.num_timesteps)
+ # ts = torch.log(ts)
+
+
+
+ t_1 = ts[0]
+
+ prot_traj = [(trans_0, rotmats_0)]
+ clean_traj = []
+ for t_2 in ts[1:]:
+
+ # Run model.
+ trans_t_1, rotmats_t_1 = prot_traj[-1]
+ batch['trans_t'] = trans_t_1
+ batch['rotmats_t'] = rotmats_t_1
+ t = torch.ones((num_batch, 1), device=self._device) * t_1
+ batch['t'] = t
+ with torch.no_grad():
+ model_out = model(batch)
+
+ # Process model output.
+
+
+ pred_trans_1 = model_out['pred_trans']
+ pred_rotmats_1 = model_out['pred_rotmats']
+ if not torch.all(motif_mask==torch.ones(aatype.shape,device=motif_mask.device)):
+ pred_trans_1 = motif_mask[...,None]* pred_trans_1+(1-motif_mask[...,None])*batch['trans_fix']
+ pred_rotmats_1 = motif_mask[...,None,None]*pred_rotmats_1+(1-motif_mask[...,None,None])*batch['rotmats_fix']
+
+
+ clean_traj.append(
+ (pred_trans_1.detach(), pred_rotmats_1.detach())
+ )
+ if self._cfg.self_condition:
+ batch['trans_sc'] = pred_trans_1
+
+ # Take reverse step
+ d_t = t_2 - t_1
+
+
+ trans_t_2 = self._trans_euler_step(
+ d_t, t_1, pred_trans_1, trans_t_1)
+ rotmats_t_2 = self._rots_euler_step(
+ d_t, t_1, pred_rotmats_1, rotmats_t_1)
+ if not torch.all(motif_mask==torch.ones(aatype.shape,device=motif_mask.device)):
+ trans_t_2 = motif_mask[...,None]* trans_t_2+(1-motif_mask[...,None])*batch['trans_fix']
+ rotmats_t_2 = motif_mask[...,None,None]*rotmats_t_2+(1-motif_mask[...,None,None])*batch['rotmats_fix']
+
+
+
+ prot_traj.append((trans_t_2, rotmats_t_2))
+ t_1 = t_2
+
+ # We only integrated to min_t, so need to make a final step
+ t_1 = ts[-1]
+ trans_t_1, rotmats_t_1 = prot_traj[-1]
+ batch['trans_t'] = trans_t_1
+ batch['rotmats_t'] = rotmats_t_1
+ batch['t'] = torch.ones((num_batch, 1), device=self._device) * t_1
+ with torch.no_grad():
+ model_out = model(batch)
+
+
+ pred_trans_1 = model_out['pred_trans']
+ pred_rotmats_1 = model_out['pred_rotmats']
+ if not torch.all(motif_mask==torch.ones(aatype.shape,device=motif_mask.device)):
+ pred_trans_1 = motif_mask[...,None]* pred_trans_1+(1-motif_mask[...,None])*batch['trans_fix']
+ pred_rotmats_1 = motif_mask[...,None,None]*pred_rotmats_1+(1-motif_mask[...,None,None])*batch['rotmats_fix']
+
+
+ clean_traj.append(
+ (pred_trans_1.detach(), pred_rotmats_1.detach())
+ )
+ prot_traj.append((pred_trans_1, pred_rotmats_1))
+
+ # Convert trajectories to atom37.
+ atom37_traj = all_atom.transrot_to_atom37(prot_traj, res_mask, aatype=aatype, torsions_with_CB=model_out['pred_torsions_with_CB'])
+ clean_atom37_traj = all_atom.transrot_to_atom37(clean_traj, res_mask, aatype=aatype, torsions_with_CB=model_out['pred_torsions_with_CB'])
+
+ return atom37_traj, clean_atom37_traj, clean_traj
diff --git a/data/parsers.py b/data/parsers.py
new file mode 100644
index 0000000000000000000000000000000000000000..931da393fe82668dc19066a504a180f127c5c107
--- /dev/null
+++ b/data/parsers.py
@@ -0,0 +1,82 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Library for parsing different data structures.
+Code adapted from Openfold protein.py.
+"""
+from Bio.PDB.Chain import Chain
+import numpy as np
+
+from data import residue_constants
+from data import protein
+
+Protein = protein.Protein
+
+
+def process_chain(chain: Chain, chain_id: str) -> Protein:
+ """Convert a PDB chain object into a AlphaFold Protein instance.
+
+ Forked from alphafold.common.protein.from_pdb_string
+
+ WARNING: All non-standard residue types will be converted into UNK. All
+ non-standard atoms will be ignored.
+
+ Took out lines 94-97 which don't allow insertions in the PDB.
+ Sabdab uses insertions for the chothia numbering so we need to allow them.
+
+ Took out lines 110-112 since that would mess up CDR numbering.
+
+ Args:
+ chain: Instance of Biopython's chain class.
+
+ Returns:
+ Protein object with protein features.
+ """
+ atom_positions = []
+ aatype = []
+ atom_mask = []
+ residue_index = []
+ b_factors = []
+ chain_ids = []
+ for res in chain:
+ # for residue type "X", the chain may need to be removed
+ res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
+ restype_idx = residue_constants.restype_order.get(
+ res_shortname, residue_constants.restype_num)
+ pos = np.zeros((residue_constants.atom_type_num, 3))
+ mask = np.zeros((residue_constants.atom_type_num,))
+ res_b_factors = np.zeros((residue_constants.atom_type_num,))
+ for atom in res:
+ if atom.name not in residue_constants.atom_types:
+ continue
+ pos[residue_constants.atom_order[atom.name]] = atom.coord
+ mask[residue_constants.atom_order[atom.name]] = 1.
+ res_b_factors[residue_constants.atom_order[atom.name]
+ ] = atom.bfactor
+ aatype.append(restype_idx)
+ atom_positions.append(pos)
+ atom_mask.append(mask)
+ residue_index.append(res.id[1])
+ b_factors.append(res_b_factors)
+ chain_ids.append(chain_id)
+
+ return Protein(
+ atom_positions=np.array(atom_positions),
+ atom_mask=np.array(atom_mask),
+ aatype=np.array(aatype),
+ residue_index=np.array(residue_index),
+ chain_index=np.array(chain_ids),
+ b_factors=np.array(b_factors))
\ No newline at end of file
diff --git a/data/pdb_dataloader.py b/data/pdb_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..68b1bd22198be930dcf13f290c52ca34dede21bd
--- /dev/null
+++ b/data/pdb_dataloader.py
@@ -0,0 +1,242 @@
+"""PDB data loader."""
+import math
+import torch
+import tree
+import numpy as np
+import torch
+import pandas as pd
+import logging
+import os
+import random
+import esm
+import copy
+
+from data import utils as du
+from data.repr import get_pre_repr
+from openfold.data import data_transforms
+from openfold.utils import rigid_utils
+from data.residue_constants import restype_atom37_mask, order2restype_with_mask
+
+from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
+from torch.utils.data.distributed import DistributedSampler, dist
+from scipy.spatial.transform import Rotation as scipy_R
+
+
+
+
+class PdbDataModule(LightningDataModule):
+ def __init__(self, data_cfg):
+ super().__init__()
+ self.data_cfg = data_cfg
+ self.loader_cfg = data_cfg.loader
+ self.dataset_cfg = data_cfg.dataset
+ self.sampler_cfg = data_cfg.sampler
+
+ def setup(self, stage: str):
+ self._train_dataset = PdbDataset(
+ dataset_cfg=self.dataset_cfg,
+ is_training=True,
+ )
+ self._valid_dataset = PdbDataset(
+ dataset_cfg=self.dataset_cfg,
+ is_training=False,
+ )
+
+ def train_dataloader(self, rank=None, num_replicas=None):
+ num_workers = self.loader_cfg.num_workers
+ return DataLoader( # default batch_size is 1, and it is expand in training_step of FlowModule
+ self._train_dataset,
+
+ # batch_sampler=LengthBatcher(
+ # sampler_cfg=self.sampler_cfg,
+ # metadata_csv=self._train_dataset.csv,
+ # rank=rank,
+ # num_replicas=num_replicas,
+ # ),
+ sampler=DistributedSampler(self._train_dataset, shuffle=True),
+
+ num_workers=self.loader_cfg.num_workers,
+ prefetch_factor=None if num_workers == 0 else self.loader_cfg.prefetch_factor,
+ persistent_workers=True if num_workers > 0 else False,
+ # persistent_workers=False,
+ )
+
+ def val_dataloader(self):
+ num_workers = self.loader_cfg.num_workers
+ return DataLoader(
+ self._valid_dataset,
+ sampler=DistributedSampler(self._valid_dataset, shuffle=False),
+ num_workers=self.loader_cfg.num_workers,
+ prefetch_factor=None if num_workers == 0 else self.loader_cfg.prefetch_factor,
+ persistent_workers=True,
+ # persistent_workers=False,
+ )
+
+
+class PdbDataset(Dataset):
+ def __init__(
+ self,
+ *,
+ dataset_cfg,
+ is_training,
+ ):
+ self._log = logging.getLogger(__name__)
+ self._is_training = is_training
+ self._dataset_cfg = dataset_cfg
+ self.split_frac = self._dataset_cfg.split_frac
+ self.random_seed = self._dataset_cfg.seed
+ # self.count = 0
+
+ self._init_metadata()
+
+ @property
+ def is_training(self):
+ return self._is_training
+
+ @property
+ def dataset_cfg(self):
+ return self._dataset_cfg
+
+ def _init_metadata(self):
+ """Initialize metadata."""
+
+ # Process CSV with different filtering criterions.
+ pdb_csv = pd.read_csv(self.dataset_cfg.csv_path)
+ self.raw_csv = pdb_csv
+ pdb_csv = pdb_csv[pdb_csv.modeled_seq_len <= self.dataset_cfg.max_num_res]
+ pdb_csv = pdb_csv[pdb_csv.modeled_seq_len >= self.dataset_cfg.min_num_res]
+
+ if self.dataset_cfg.subset is not None:
+ pdb_csv = pdb_csv.iloc[:self.dataset_cfg.subset]
+ pdb_csv = pdb_csv.sort_values('modeled_seq_len', ascending=False)
+
+ # energy_csv_path = self.dataset_cfg.energy_csv_path
+ # self.energy_csv = pd.read_csv(energy_csv_path)
+
+ ## Training or validation specific logic.
+ if self.is_training:
+ self.csv = pdb_csv[pdb_csv['is_trainset']]
+ self.csv = pdb_csv.sample(frac=self.split_frac, random_state=self.random_seed).reset_index()
+ self.csv.to_csv(os.path.join(os.path.dirname(self.dataset_cfg.csv_path),"train.csv"), index=False)
+
+ # self.chain_feats_total = []
+
+ # for idx in range(len(self.csv)):
+ # if idx % 200 == 0:
+ # self._log.info(f"pre_count= {idx}")
+
+ # # processed_path = self.csv.iloc[idx]['processed_path']
+
+ # # chain_feats_temp = self._process_csv_row(processed_path)
+ # # chain_feats_temp = du.read_pkl(processed_path)
+ # # chain_feats_temp['energy'] = torch.tensor(self.csv.iloc[idx]['energy'], dtype=torch.float32)
+ # self.chain_feats_total[idx]['energy'] = torch.tensor(self.csv.iloc[idx]['energy'], dtype=torch.float32)
+
+ # # self.chain_feats_total += [chain_feats_temp]
+
+
+ # self._log.info(
+ # f"Training: {len(self.chain_feats_total)} examples, len_range is {self.csv['modeled_seq_len'].min()}-{self.csv['modeled_seq_len'].max()}")
+ self._log.info(
+ f"Training: {len(self.csv)} examples, len_range is {self.csv['modeled_seq_len'].min()}-{self.csv['modeled_seq_len'].max()}")
+ else:
+ self.csv = pdb_csv[~pdb_csv['is_trainset']]
+ # if self.split_frac < 1.0:
+ # train_csv = pdb_csv.sample(frac=self.split_frac, random_state=self.random_seed)
+ # pdb_csv = pdb_csv.drop(train_csv.index)
+ self.csv = pdb_csv[pdb_csv.modeled_seq_len <= self.dataset_cfg.max_eval_length]
+ self.csv.to_csv(os.path.join(os.path.dirname(self.dataset_cfg.csv_path),"valid.csv"), index=False)
+
+ self.csv = self.csv.sample(n=min(self.dataset_cfg.max_valid_num, len(self.csv)), random_state=self.random_seed).reset_index()
+
+ # self.chain_feats_total = []
+ # for idx in range(len(self.csv)):
+ # processed_path = self.csv.iloc[idx]['processed_path']
+
+ # # chain_feats_temp = self._process_csv_row(processed_path)
+ # chain_feats_temp = du.read_pkl(processed_path)
+ # chain_feats_temp['energy'] = torch.tensor(self.csv.iloc[idx]['energy'], dtype=torch.float32)
+
+ # self.chain_feats_total += [chain_feats_temp]
+
+ # self._log.info(
+ # f"Valid: {len(self.chain_feats_total)} examples, len_range is {self.csv['modeled_seq_len'].min()}-{self.csv['modeled_seq_len'].max()}")
+ self._log.info(
+ f"Valid: {len(self.csv)} examples, len_range is {self.csv['modeled_seq_len'].min()}-{self.csv['modeled_seq_len'].max()}")
+
+ # def _process_csv_row(self, processed_file_path):
+ # self.count += 1
+ # if self.count%200==0:
+ # self._log.info(
+ # f"pre_count= {self.count}")
+
+ # output_total = du.read_pkl(processed_file_path)
+ # energy_csv = self.energy_csv
+
+ # file = os.path.basename(processed_file_path).replace(".pkl", ".pdb")
+
+ # matching_rows = energy_csv[energy_csv['traj_filename'] == file]
+ # if not matching_rows.empty:
+ # output_total['energy'] = torch.tensor(matching_rows['energy'].values[0], dtype=torch.float32)
+
+ # return output_total
+
+ def __len__(self):
+ return len(self.csv)
+
+ def __getitem__(self, idx):
+ # chain_feats = self.chain_feats_total[idx]
+
+ processed_path = self.csv.iloc[idx]['processed_path']
+ chain_feats = du.read_pkl(processed_path)
+ chain_feats['energy'] = torch.tensor(self.csv.iloc[idx]['energy'], dtype=torch.float32)
+
+ energy = chain_feats['energy']
+
+
+ if self.is_training and self._dataset_cfg.use_split:
+ # split_len = self._dataset_cfg.split_len
+
+ split_len = random.randint(self.dataset_cfg.min_num_res, min(self._dataset_cfg.split_len, chain_feats['aatype'].shape[0]))
+
+ idx = random.randint(0,chain_feats['aatype'].shape[0]-split_len)
+ output_total = copy.deepcopy(chain_feats)
+
+ output_total['energy'] = torch.ones(chain_feats['aatype'].shape)
+
+ output_temp = tree.map_structure(lambda x: x[idx:idx+split_len], output_total)
+
+ bb_center = np.sum(output_temp['bb_positions'], axis=0) / (np.sum(output_temp['res_mask'].numpy()) + 1e-5) # (3,)
+ output_temp['trans_1']=(output_temp['trans_1'] - torch.from_numpy(bb_center[None, :])).float()
+ output_temp['bb_positions']=output_temp['bb_positions']- bb_center[None, :]
+ output_temp['all_atom_positions']=output_temp['all_atom_positions'] - torch.from_numpy(bb_center[None, None, :])
+ output_temp['pair_repr_pre'] = output_temp['pair_repr_pre'][:,idx:idx+split_len]
+
+ bb_center_esmfold = torch.sum(output_temp['trans_esmfold'], dim=0) / (np.sum(output_temp['res_mask'].numpy()) + 1e-5) # (3,)
+ output_temp['trans_esmfold']=(output_temp['trans_esmfold'] - bb_center_esmfold[None, :]).float()
+
+ chain_feats = output_temp
+ chain_feats['energy'] = energy
+
+
+ if self._dataset_cfg.use_rotate_enhance:
+ rot_vet = [random.random() for _ in range(3)]
+ rot_mat = torch.tensor(scipy_R.from_rotvec(rot_vet).as_matrix()) # (3,3)
+ chain_feats['all_atom_positions']=torch.einsum('lij,kj->lik',chain_feats['all_atom_positions'],
+ rot_mat.type(chain_feats['all_atom_positions'].dtype))
+
+ all_atom_mask = np.array([restype_atom37_mask[i] for i in chain_feats['aatype']])
+
+ chain_feats_temp = {
+ 'aatype': chain_feats['aatype'],
+ 'all_atom_positions': chain_feats['all_atom_positions'],
+ 'all_atom_mask': torch.tensor(all_atom_mask).double(),
+ }
+ chain_feats_temp = data_transforms.atom37_to_frames(chain_feats_temp)
+ curr_rigid = rigid_utils.Rigid.from_tensor_4x4(chain_feats_temp['rigidgroups_gt_frames'])[:, 0]
+ chain_feats['trans_1'] = curr_rigid.get_trans()
+ chain_feats['rotmats_1'] = curr_rigid.get_rots().get_rot_mats()
+ chain_feats['bb_positions']=(chain_feats['trans_1']).numpy().astype(chain_feats['bb_positions'].dtype)
+
+ return chain_feats
diff --git a/data/process_pdb_files.py b/data/process_pdb_files.py
new file mode 100644
index 0000000000000000000000000000000000000000..36a3c37d2a8bd9a219d25a7e6fdc9693406e33f1
--- /dev/null
+++ b/data/process_pdb_files.py
@@ -0,0 +1,299 @@
+import argparse
+import dataclasses
+import functools as fn
+import pandas as pd
+import os
+import tree
+import torch
+import multiprocessing as mp
+import time
+import esm
+from Bio import PDB
+import numpy as np
+from data import utils as du
+from data import parsers
+from data import errors
+from data.repr import get_pre_repr
+from openfold.data import data_transforms
+from openfold.utils import rigid_utils
+from data.cal_trans_rotmats import cal_trans_rotmats
+from data.ESMfold_pred import ESMFold_Pred
+
+
+def process_file(file_path: str, write_dir: str):
+ """Processes protein file into usable, smaller pickles.
+
+ Args:
+ file_path: Path to file to read.
+ write_dir: Directory to write pickles to.
+
+ Returns:
+ Saves processed protein to pickle and returns metadata.
+
+ Raises:
+ DataError if a known filtering rule is hit.
+ All other errors are unexpected and are propogated.
+ """
+ metadata = {}
+ pdb_name = os.path.basename(file_path).replace('.pdb', '')
+ metadata['pdb_name'] = pdb_name
+
+ processed_path = os.path.join(write_dir, f'{pdb_name}.pkl')
+ metadata['processed_path'] = os.path.abspath(processed_path)
+ metadata['raw_path'] = file_path
+ parser = PDB.PDBParser(QUIET=True)
+ structure = parser.get_structure(pdb_name, file_path)
+
+ # Extract all chains
+ struct_chains = {
+ chain.id.upper(): chain
+ for chain in structure.get_chains()}
+ metadata['num_chains'] = len(struct_chains)
+
+ # Extract features
+ struct_feats = []
+ all_seqs = set()
+ for chain_id, chain in struct_chains.items():
+ # Convert chain id into int
+ chain_id = du.chain_str_to_int(chain_id)
+ chain_prot = parsers.process_chain(chain, chain_id)
+ chain_dict = dataclasses.asdict(chain_prot)
+ chain_dict = du.parse_chain_feats(chain_dict)
+ all_seqs.add(tuple(chain_dict['aatype']))
+ struct_feats.append(chain_dict)
+ if len(all_seqs) == 1:
+ metadata['quaternary_category'] = 'homomer'
+ else:
+ metadata['quaternary_category'] = 'heteromer'
+ complex_feats = du.concat_np_features(struct_feats, False)
+
+ # Process geometry features
+ complex_aatype = complex_feats['aatype']
+ metadata['seq_len'] = len(complex_aatype)
+ modeled_idx = np.where(complex_aatype != 20)[0]
+ if np.sum(complex_aatype != 20) == 0:
+ raise errors.LengthError('No modeled residues')
+ min_modeled_idx = np.min(modeled_idx)
+ max_modeled_idx = np.max(modeled_idx)
+ metadata['modeled_seq_len'] = max_modeled_idx - min_modeled_idx + 1
+ complex_feats['modeled_idx'] = modeled_idx
+
+ # try:
+ # # MDtraj
+ # traj = md.load(file_path)
+ # # SS calculation
+ # pdb_ss = md.compute_dssp(traj, simplified=True)
+ # # DG calculation
+ # pdb_dg = md.compute_rg(traj)
+ # # os.remove(file_path)
+ # except Exception as e:
+ # # os.remove(file_path)
+ # raise errors.DataError(f'Mdtraj failed with error {e}')
+
+ # chain_dict['ss'] = pdb_ss[0]
+ # metadata['coil_percent'] = np.sum(pdb_ss == 'C') / metadata['modeled_seq_len']
+ # metadata['helix_percent'] = np.sum(pdb_ss == 'H') / metadata['modeled_seq_len']
+ # metadata['strand_percent'] = np.sum(pdb_ss == 'E') / metadata['modeled_seq_len']
+
+ # # Radius of gyration
+ # metadata['radius_gyration'] = pdb_dg[0]
+
+ # Write features to pickles.
+ du.write_pkl(processed_path, complex_feats)
+
+ return metadata
+
+
+def process_serially(all_paths, write_dir):
+ all_metadata = []
+ for i, file_path in enumerate(all_paths):
+ try:
+ start_time = time.time()
+ metadata = process_file(
+ file_path,
+ write_dir)
+ elapsed_time = time.time() - start_time
+ print(f'Finished {file_path} in {elapsed_time:2.2f}s')
+ all_metadata.append(metadata)
+ except errors.DataError as e:
+ print(f'Failed {file_path}: {e}')
+ return all_metadata
+
+
+def process_fn(
+ file_path,
+ verbose=None,
+ write_dir=None):
+ try:
+ start_time = time.time()
+ metadata = process_file(
+ file_path,
+ write_dir)
+ elapsed_time = time.time() - start_time
+ if verbose:
+ print(f'Finished {file_path} in {elapsed_time:2.2f}s')
+ return metadata
+ except errors.DataError as e:
+ if verbose:
+ print(f'Failed {file_path}: {e}')
+
+
+def main(args):
+ pdb_dir = args.pdb_dir
+ all_file_paths = [
+ os.path.join(pdb_dir, x)
+ for x in os.listdir(args.pdb_dir) if '.pdb' in x]
+ total_num_paths = len(all_file_paths)
+ write_dir = args.write_dir
+ if not os.path.exists(write_dir):
+ os.makedirs(write_dir)
+ if args.debug:
+ metadata_file_name = 'metadata_debug.csv'
+ else:
+ metadata_file_name = 'metadata.csv'
+ metadata_path = os.path.join(write_dir, metadata_file_name)
+ print(f'Files will be written to {write_dir}')
+
+ # Process each mmcif file
+ if args.num_processes == 1 or args.debug:
+ all_metadata = process_serially(
+ all_file_paths,
+ write_dir)
+ else:
+ _process_fn = fn.partial( # fix some args of fn
+ process_fn,
+ verbose=args.verbose,
+ write_dir=write_dir)
+ with mp.Pool(processes=args.num_processes) as pool:
+ all_metadata = pool.map(_process_fn, all_file_paths) # all jobs to be done are saved in a iterable object (such as list)
+ all_metadata = [x for x in all_metadata if x is not None]
+ metadata_df = pd.DataFrame(all_metadata)
+ metadata_df.to_csv(metadata_path, index=False)
+ succeeded = len(all_metadata)
+ print(
+ f'Finished processing {succeeded}/{total_num_paths} files')
+
+
+def cal_repr(processed_file_path, model_esm2, alphabet, batch_converter, esm_device):
+ print(f'cal_repr for {processed_file_path}')
+ processed_feats_org = du.read_pkl(processed_file_path)
+ processed_feats = du.parse_chain_feats(processed_feats_org)
+
+ # Only take modeled residues.
+ modeled_idx = processed_feats['modeled_idx']
+ min_idx = np.min(modeled_idx)
+ max_idx = np.max(modeled_idx)
+ # del processed_feats['modeled_idx']
+ processed_feats = tree.map_structure(
+ lambda x: x[min_idx:(max_idx+1)], processed_feats)
+
+ # Run through OpenFold data transforms.
+ chain_feats = {
+ 'aatype': torch.tensor(processed_feats['aatype']).long(),
+ 'all_atom_positions': torch.tensor(processed_feats['atom_positions']).double(),
+ 'all_atom_mask': torch.tensor(processed_feats['atom_mask']).double()
+ }
+ chain_feats = data_transforms.atom37_to_frames(chain_feats)
+ rigids_1 = rigid_utils.Rigid.from_tensor_4x4(chain_feats['rigidgroups_gt_frames'])[:, 0]
+ rotmats_1 = rigids_1.get_rots().get_rot_mats()
+ trans_1 = rigids_1.get_trans()
+ # res_idx = processed_feats['residue_index']
+
+ node_repr_pre, pair_repr_pre = get_pre_repr(chain_feats['aatype'], model_esm2, alphabet, batch_converter, device = esm_device) # (B,L,d_node_pre=1280), (B,L,L,d_edge_pre=20)
+ node_repr_pre = node_repr_pre[0].cpu()
+ pair_repr_pre = pair_repr_pre[0].cpu()
+
+ out = {
+ 'aatype': chain_feats['aatype'],
+ 'rotmats_1': rotmats_1,
+ 'trans_1': trans_1, # (L,3)
+ 'res_mask': torch.tensor(processed_feats['bb_mask']).int(),
+ 'bb_positions': processed_feats['bb_positions'],
+ 'all_atom_positions':chain_feats['all_atom_positions'],
+ 'node_repr_pre':node_repr_pre,
+ 'pair_repr_pre':pair_repr_pre,
+ }
+
+ du.write_pkl(processed_file_path, out)
+
+def cal_static_structure(processed_file_path, raw_pdb_file, ESMFold):
+ output_total = du.read_pkl(processed_file_path)
+
+ save_dir = os.path.join(os.path.dirname(raw_pdb_file), 'ESMFold_Pred_results')
+ os.makedirs(save_dir, exist_ok=True)
+ save_path = os.path.join(save_dir, os.path.basename(processed_file_path)[:6]+'_esmfold.pdb')
+ if not os.path.exists(save_path):
+ print(f'cal_static_structure for {processed_file_path}')
+ ESMFold.predict_str(raw_pdb_file, save_path)
+ trans, rotmats = cal_trans_rotmats(save_path)
+ output_total['trans_esmfold'] = trans
+ output_total['rotmats_esmfold'] = rotmats
+
+ du.write_pkl(processed_file_path, output_total)
+
+
+def merge_pdb(metadata_path, traj_info_file, valid_seq_file, merged_output_file):
+ df1 = pd.read_csv(metadata_path)
+ df2 = pd.read_csv(traj_info_file)
+ df3 = pd.read_csv(valid_seq_file)
+
+ # 获取文件名
+ df1['traj_filename'] = [os.path.basename(i) for i in df1['raw_path']]
+
+ # 合并数据
+ merged = df1.merge(df2[['traj_filename', 'energy']], on='traj_filename', how='left')
+ merged['is_trainset'] = ~merged['traj_filename'].str[:6].isin(df3['file'])
+
+ # 保存结果
+ merged.to_csv(merged_output_file, index=False)
+ print('merge complete!')
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--pdb_dir", type=str, default="./dataset/ATLAS/select")
+ parser.add_argument("--write_dir", type=str, default="./dataset/ATLAS/select/pkl")
+ parser.add_argument("--csv_name", type=str, default="metadata.csv")
+ parser.add_argument("--debug", type=bool, default=False)
+ parser.add_argument("--num_processes", type=int, default=48)
+ parser.add_argument('--verbose', help='Whether to log everything.',action='store_true')
+
+ parser.add_argument("--esm_device", type=str, default='cuda')
+
+ parser.add_argument("--traj_info_file", type=str, default='./dataset/ATLAS/select/traj_info_select.csv')
+ parser.add_argument("--valid_seq_file", type=str, default='./inference/valid_seq.csv')
+ parser.add_argument("--merged_output_file", type=str, default='./dataset/ATLAS/select/pkl/metadata_merged.csv')
+
+ args = parser.parse_args()
+
+ # process .pdb to .pkl
+ main(args)
+
+ # cal_repr
+ csv_path = os.path.join(args.write_dir, args.csv_name)
+ pdb_csv = pd.read_csv(csv_path)
+ pdb_csv = pdb_csv.sort_values('modeled_seq_len', ascending=False)
+ model_esm2, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
+ batch_converter = alphabet.get_batch_converter()
+ model_esm2.eval()
+ model_esm2.requires_grad_(False)
+ model_esm2.to(args.esm_device)
+ for idx in range(len(pdb_csv)):
+ cal_repr(pdb_csv.iloc[idx]['processed_path'], model_esm2, alphabet, batch_converter, args.esm_device)
+
+ # cal_static_structure
+ csv_path = os.path.join(args.write_dir, args.csv_name)
+ pdb_csv = pd.read_csv(csv_path)
+ ESMFold = ESMFold_Pred(device = args.esm_device)
+ for idx in range(len(pdb_csv)):
+ cal_static_structure(pdb_csv.iloc[idx]['processed_path'], pdb_csv.iloc[idx]['raw_path'], ESMFold)
+
+ # merge csv
+ csv_path = os.path.join(args.write_dir, args.csv_name)
+ merge_pdb(csv_path, args.traj_info_file, args.valid_seq_file, args.merged_output_file)
+
+
+
+
+
diff --git a/data/protein.py b/data/protein.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bcb3b1411ef907e8f24a9e4b62e718fd25ca8aa
--- /dev/null
+++ b/data/protein.py
@@ -0,0 +1,280 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Protein data type."""
+import dataclasses
+import io
+from typing import Any, Mapping, Optional
+from data import residue_constants
+from Bio.PDB import PDBParser
+import numpy as np
+
+FeatureDict = Mapping[str, np.ndarray]
+ModelOutput = Mapping[str, Any] # Is a nested dict.
+
+# Complete sequence of chain IDs supported by the PDB format.
+PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
+PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62.
+
+
+@dataclasses.dataclass(frozen=True)
+class Protein:
+ """Protein structure representation."""
+
+ # Cartesian coordinates of atoms in angstroms. The atom types correspond to
+ # residue_constants.atom_types, i.e. the first three are N, CA, CB.
+ atom_positions: np.ndarray # [num_res, num_atom_type, 3]
+
+ # Amino-acid type for each residue represented as an integer between 0 and
+ # 20, where 20 is 'X'.
+ aatype: np.ndarray # [num_res]
+
+ # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
+ # is present and 0.0 if not. This should be used for loss masking.
+ atom_mask: np.ndarray # [num_res, num_atom_type]
+
+ # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
+ residue_index: np.ndarray # [num_res]
+
+ # 0-indexed number corresponding to the chain in the protein that this residue
+ # belongs to.
+ chain_index: np.ndarray # [num_res]
+
+ # B-factors, or temperature factors, of each residue (in sq. angstroms units),
+ # representing the displacement of the residue from its ground truth mean
+ # value.
+ b_factors: np.ndarray # [num_res, num_atom_type]
+
+ def __post_init__(self):
+ if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
+ raise ValueError(
+ f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
+ 'because these cannot be written to PDB format.')
+
+
+def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
+ """Takes a PDB string and constructs a Protein object.
+
+ WARNING: All non-standard residue types will be converted into UNK. All
+ non-standard atoms will be ignored.
+
+ Args:
+ pdb_str: The contents of the pdb file
+ chain_id: If chain_id is specified (e.g. A), then only that chain
+ is parsed. Otherwise all chains are parsed.
+
+ Returns:
+ A new `Protein` parsed from the pdb contents.
+ """
+ pdb_fh = io.StringIO(pdb_str)
+ parser = PDBParser(QUIET=True)
+ structure = parser.get_structure('none', pdb_fh)
+ models = list(structure.get_models())
+ if len(models) != 1:
+ raise ValueError(
+ f'Only single model PDBs are supported. Found {len(models)} models.')
+ model = models[0]
+
+ atom_positions = []
+ aatype = []
+ atom_mask = []
+ residue_index = []
+ chain_ids = []
+ b_factors = []
+
+ for chain in model:
+ if chain_id is not None and chain.id != chain_id:
+ continue
+ for res in chain:
+ if res.id[2] != ' ':
+ raise ValueError(
+ f'PDB contains an insertion code at chain {chain.id} and residue '
+ f'index {res.id[1]}. These are not supported.')
+ res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
+ restype_idx = residue_constants.restype_order.get(
+ res_shortname, residue_constants.restype_num)
+ pos = np.zeros((residue_constants.atom_type_num, 3))
+ mask = np.zeros((residue_constants.atom_type_num,))
+ res_b_factors = np.zeros((residue_constants.atom_type_num,))
+ for atom in res:
+ if atom.name not in residue_constants.atom_types:
+ continue
+ pos[residue_constants.atom_order[atom.name]] = atom.coord
+ mask[residue_constants.atom_order[atom.name]] = 1.
+ res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
+ if np.sum(mask) < 0.5:
+ # If no known atom positions are reported for the residue then skip it.
+ continue
+ aatype.append(restype_idx)
+ atom_positions.append(pos)
+ atom_mask.append(mask)
+ residue_index.append(res.id[1])
+ chain_ids.append(chain.id)
+ b_factors.append(res_b_factors)
+
+ # Chain IDs are usually characters so map these to ints.
+ unique_chain_ids = np.unique(chain_ids)
+ chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
+ chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
+
+ return Protein(
+ atom_positions=np.array(atom_positions),
+ atom_mask=np.array(atom_mask),
+ aatype=np.array(aatype),
+ residue_index=np.array(residue_index),
+ chain_index=chain_index,
+ b_factors=np.array(b_factors))
+
+
+def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
+ chain_end = 'TER'
+ return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
+ f'{chain_name:>1}{residue_index:>4}')
+
+
+def to_pdb(prot: Protein, model=1, add_end=True) -> str:
+ """Converts a `Protein` instance to a PDB string.
+
+ Args:
+ prot: The protein to convert to PDB.
+
+ Returns:
+ PDB string.
+ """
+ restypes = residue_constants.restypes + ['X']
+ res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
+ atom_types = residue_constants.atom_types
+
+ pdb_lines = []
+
+ atom_mask = prot.atom_mask
+ aatype = prot.aatype
+ atom_positions = prot.atom_positions
+ residue_index = prot.residue_index.astype(int)
+ chain_index = prot.chain_index.astype(int)
+ b_factors = prot.b_factors
+
+ if np.any(aatype > residue_constants.restype_num):
+ raise ValueError('Invalid aatypes.')
+
+ # Construct a mapping from chain integer indices to chain ID strings.
+ chain_ids = {}
+ for i in np.unique(chain_index): # np.unique gives sorted output.
+ if i >= PDB_MAX_CHAINS:
+ raise ValueError(
+ f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
+ chain_ids[i] = PDB_CHAIN_IDS[i]
+
+ pdb_lines.append(f'MODEL {model}')
+ atom_index = 1
+ last_chain_index = chain_index[0]
+ # Add all atom sites.
+ for i in range(aatype.shape[0]):
+ # Close the previous chain if in a multichain PDB.
+ if last_chain_index != chain_index[i]:
+ pdb_lines.append(_chain_end(
+ atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]],
+ residue_index[i - 1]))
+ last_chain_index = chain_index[i]
+ atom_index += 1 # Atom index increases at the TER symbol.
+
+ res_name_3 = res_1to3(aatype[i])
+ for atom_name, pos, mask, b_factor in zip(
+ atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
+ if mask < 0.5:
+ continue
+
+ record_type = 'ATOM'
+ name = atom_name if len(atom_name) == 4 else f' {atom_name}'
+ alt_loc = ''
+ insertion_code = ''
+ occupancy = 1.00
+ element = atom_name[0] # Protein supports only C, N, O, S, this works.
+ charge = ''
+ # PDB is a columnar format, every space matters here!
+ atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
+ f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}'
+ f'{residue_index[i]:>4}{insertion_code:>1} '
+ f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
+ f'{occupancy:>6.2f}{b_factor:>6.2f} '
+ f'{element:>2}{charge:>2}')
+ pdb_lines.append(atom_line)
+ atom_index += 1
+
+ # Close the final chain.
+ pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]),
+ chain_ids[chain_index[-1]], residue_index[-1]))
+ pdb_lines.append('ENDMDL')
+ if add_end:
+ pdb_lines.append('END')
+
+ # Pad all lines to 80 characters.
+ pdb_lines = [line.ljust(80) for line in pdb_lines]
+ return '\n'.join(pdb_lines) + '\n' # Add terminating newline.
+
+
+def ideal_atom_mask(prot: Protein) -> np.ndarray:
+ """Computes an ideal atom mask.
+
+ `Protein.atom_mask` typically is defined according to the atoms that are
+ reported in the PDB. This function computes a mask according to heavy atoms
+ that should be present in the given sequence of amino acids.
+
+ Args:
+ prot: `Protein` whose fields are `numpy.ndarray` objects.
+
+ Returns:
+ An ideal atom mask.
+ """
+ return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
+
+
+def from_prediction(
+ features: FeatureDict,
+ result: ModelOutput,
+ b_factors: Optional[np.ndarray] = None,
+ remove_leading_feature_dimension: bool = True) -> Protein:
+ """Assembles a protein from a prediction.
+
+ Args:
+ features: Dictionary holding model inputs.
+ result: Dictionary holding model outputs.
+ b_factors: (Optional) B-factors to use for the protein.
+ remove_leading_feature_dimension: Whether to remove the leading dimension
+ of the `features` values.
+
+ Returns:
+ A protein instance.
+ """
+ fold_output = result['structure_module']
+
+ def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
+ return arr[0] if remove_leading_feature_dimension else arr
+
+ if 'asym_id' in features:
+ chain_index = _maybe_remove_leading_dim(features['asym_id'])
+ else:
+ chain_index = np.zeros_like(_maybe_remove_leading_dim(features['aatype']))
+
+ if b_factors is None:
+ b_factors = np.zeros_like(fold_output['final_atom_mask'])
+
+ return Protein(
+ aatype=_maybe_remove_leading_dim(features['aatype']),
+ atom_positions=fold_output['final_atom_positions'],
+ atom_mask=fold_output['final_atom_mask'],
+ residue_index=_maybe_remove_leading_dim(features['residue_index']) + 1,
+ chain_index=chain_index,
+ b_factors=b_factors)
\ No newline at end of file
diff --git a/data/repr.py b/data/repr.py
new file mode 100644
index 0000000000000000000000000000000000000000..d97759ca982fb7cde94fefffad952753c85be534
--- /dev/null
+++ b/data/repr.py
@@ -0,0 +1,38 @@
+import torch
+from opt_einsum import contract as einsum
+import esm
+from data.residue_constants import order2restype_with_mask
+
+def get_pre_repr(seqs, model, alphabet, batch_converter, device="cuda:0"):
+
+ # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
+ # data = [
+ # ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
+ # ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
+ # ("protein2 with mask","KALTARQQEVFDLIRDISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
+ # ("protein3", "K A I S Q"),
+ # ]
+ data = []
+ for idx, seq in enumerate([seqs]):
+ seq_string = ''.join([order2restype_with_mask[int(i)] for i in seq])
+ data.append(("protein_"+str(idx), seq_string))
+
+ batch_labels, batch_strs, batch_tokens = batch_converter(data)
+ batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
+
+ # Extract per-residue representations (on CPU)
+ with torch.no_grad():
+ results = model(batch_tokens.to(device), repr_layers=[33], return_contacts=True)
+
+ node_repr = results["representations"][33][:,1:-1,:]
+ pair_repr = results['attentions'][:,33-1,:,1:-1,1:-1].permute(0,2,3,1)
+ # Generate per-sequence representations via averaging
+ # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
+ # sequence_representations = []
+ # for i, tokens_len in enumerate(batch_lens):
+ # sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0))
+
+ # Look at the unsupervised self-attention map contact predictions
+ # for (_, seq), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):
+ # plt.matshow(attention_contacts[: tokens_len, : tokens_len])
+ return node_repr, pair_repr
\ No newline at end of file
diff --git a/data/residue_constants.py b/data/residue_constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..9491831538549b2ce77224b7dd6513b7c0b114c7
--- /dev/null
+++ b/data/residue_constants.py
@@ -0,0 +1,902 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Constants used in AlphaFold."""
+
+import collections
+import functools
+import os
+from typing import List, Mapping, Tuple
+
+import numpy as np
+import tree
+
+# Internal import (35fd).
+
+
+# Distance from one CA to next CA [trans configuration: omega = 180].
+ca_ca = 3.80209737096
+
+# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
+# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
+# chi angles so their chi angle lists are empty.
+chi_angles_atoms = {
+ 'ALA': [],
+ # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
+ 'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+ ['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']],
+ 'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
+ 'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
+ 'CYS': [['N', 'CA', 'CB', 'SG']],
+ 'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+ ['CB', 'CG', 'CD', 'OE1']],
+ 'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+ ['CB', 'CG', 'CD', 'OE1']],
+ 'GLY': [],
+ 'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']],
+ 'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']],
+ 'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+ 'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+ ['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']],
+ 'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'],
+ ['CB', 'CG', 'SD', 'CE']],
+ 'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+ 'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']],
+ 'SER': [['N', 'CA', 'CB', 'OG']],
+ 'THR': [['N', 'CA', 'CB', 'OG1']],
+ 'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+ 'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+ 'VAL': [['N', 'CA', 'CB', 'CG1']],
+}
+
+# If chi angles given in fixed-length array, this matrix determines how to mask
+# them for each AA type. The order is as per restype_order (see below).
+chi_angles_mask = [
+ [0.0, 0.0, 0.0, 0.0], # ALA
+ [1.0, 1.0, 1.0, 1.0], # ARG
+ [1.0, 1.0, 0.0, 0.0], # ASN
+ [1.0, 1.0, 0.0, 0.0], # ASP
+ [1.0, 0.0, 0.0, 0.0], # CYS
+ [1.0, 1.0, 1.0, 0.0], # GLN
+ [1.0, 1.0, 1.0, 0.0], # GLU
+ [0.0, 0.0, 0.0, 0.0], # GLY
+ [1.0, 1.0, 0.0, 0.0], # HIS
+ [1.0, 1.0, 0.0, 0.0], # ILE
+ [1.0, 1.0, 0.0, 0.0], # LEU
+ [1.0, 1.0, 1.0, 1.0], # LYS
+ [1.0, 1.0, 1.0, 0.0], # MET
+ [1.0, 1.0, 0.0, 0.0], # PHE
+ [1.0, 1.0, 0.0, 0.0], # PRO
+ [1.0, 0.0, 0.0, 0.0], # SER
+ [1.0, 0.0, 0.0, 0.0], # THR
+ [1.0, 1.0, 0.0, 0.0], # TRP
+ [1.0, 1.0, 0.0, 0.0], # TYR
+ [1.0, 0.0, 0.0, 0.0], # VAL
+]
+
+# The following chi angles are pi periodic: they can be rotated by a multiple
+# of pi without affecting the structure.
+chi_pi_periodic = [
+ [0.0, 0.0, 0.0, 0.0], # ALA
+ [0.0, 0.0, 0.0, 0.0], # ARG
+ [0.0, 0.0, 0.0, 0.0], # ASN
+ [0.0, 1.0, 0.0, 0.0], # ASP
+ [0.0, 0.0, 0.0, 0.0], # CYS
+ [0.0, 0.0, 0.0, 0.0], # GLN
+ [0.0, 0.0, 1.0, 0.0], # GLU
+ [0.0, 0.0, 0.0, 0.0], # GLY
+ [0.0, 0.0, 0.0, 0.0], # HIS
+ [0.0, 0.0, 0.0, 0.0], # ILE
+ [0.0, 0.0, 0.0, 0.0], # LEU
+ [0.0, 0.0, 0.0, 0.0], # LYS
+ [0.0, 0.0, 0.0, 0.0], # MET
+ [0.0, 1.0, 0.0, 0.0], # PHE
+ [0.0, 0.0, 0.0, 0.0], # PRO
+ [0.0, 0.0, 0.0, 0.0], # SER
+ [0.0, 0.0, 0.0, 0.0], # THR
+ [0.0, 0.0, 0.0, 0.0], # TRP
+ [0.0, 1.0, 0.0, 0.0], # TYR
+ [0.0, 0.0, 0.0, 0.0], # VAL
+ [0.0, 0.0, 0.0, 0.0], # UNK
+]
+
+# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
+# psi and chi angles:
+# 0: 'backbone group',
+# 1: 'pre-omega-group', (empty)
+# 2: 'phi-group', (currently empty, because it defines only hydrogens)
+# 3: 'psi-group',
+# 4,5,6,7: 'chi1,2,3,4-group'
+# The atom positions are relative to the axis-end-atom of the corresponding
+# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
+# is defined such that the dihedral-angle-definiting atom (the last entry in
+# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
+# format: [atomname, group_idx, rel_position]
+rigid_group_atom_positions = {
+ 'ALA': [
+ ['N', 0, (-0.525, 1.363, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, -0.000, -0.000)],
+ ['CB', 0, (-0.529, -0.774, -1.205)],
+ ['O', 3, (0.627, 1.062, 0.000)],
+ ],
+ 'ARG': [
+ ['N', 0, (-0.524, 1.362, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, -0.000, -0.000)],
+ ['CB', 0, (-0.524, -0.778, -1.209)],
+ ['O', 3, (0.626, 1.062, 0.000)],
+ ['CG', 4, (0.616, 1.390, -0.000)],
+ ['CD', 5, (0.564, 1.414, 0.000)],
+ ['NE', 6, (0.539, 1.357, -0.000)],
+ ['NH1', 7, (0.206, 2.301, 0.000)],
+ ['NH2', 7, (2.078, 0.978, -0.000)],
+ ['CZ', 7, (0.758, 1.093, -0.000)],
+ ],
+ 'ASN': [
+ ['N', 0, (-0.536, 1.357, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, -0.000, -0.000)],
+ ['CB', 0, (-0.531, -0.787, -1.200)],
+ ['O', 3, (0.625, 1.062, 0.000)],
+ ['CG', 4, (0.584, 1.399, 0.000)],
+ ['ND2', 5, (0.593, -1.188, 0.001)],
+ ['OD1', 5, (0.633, 1.059, 0.000)],
+ ],
+ 'ASP': [
+ ['N', 0, (-0.525, 1.362, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.527, 0.000, -0.000)],
+ ['CB', 0, (-0.526, -0.778, -1.208)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['CG', 4, (0.593, 1.398, -0.000)],
+ ['OD1', 5, (0.610, 1.091, 0.000)],
+ ['OD2', 5, (0.592, -1.101, -0.003)],
+ ],
+ 'CYS': [
+ ['N', 0, (-0.522, 1.362, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.524, 0.000, 0.000)],
+ ['CB', 0, (-0.519, -0.773, -1.212)],
+ ['O', 3, (0.625, 1.062, -0.000)],
+ ['SG', 4, (0.728, 1.653, 0.000)],
+ ],
+ 'GLN': [
+ ['N', 0, (-0.526, 1.361, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, 0.000, 0.000)],
+ ['CB', 0, (-0.525, -0.779, -1.207)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['CG', 4, (0.615, 1.393, 0.000)],
+ ['CD', 5, (0.587, 1.399, -0.000)],
+ ['NE2', 6, (0.593, -1.189, -0.001)],
+ ['OE1', 6, (0.634, 1.060, 0.000)],
+ ],
+ 'GLU': [
+ ['N', 0, (-0.528, 1.361, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, -0.000, -0.000)],
+ ['CB', 0, (-0.526, -0.781, -1.207)],
+ ['O', 3, (0.626, 1.062, 0.000)],
+ ['CG', 4, (0.615, 1.392, 0.000)],
+ ['CD', 5, (0.600, 1.397, 0.000)],
+ ['OE1', 6, (0.607, 1.095, -0.000)],
+ ['OE2', 6, (0.589, -1.104, -0.001)],
+ ],
+ 'GLY': [
+ ['N', 0, (-0.572, 1.337, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.517, -0.000, -0.000)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ],
+ 'HIS': [
+ ['N', 0, (-0.527, 1.360, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, 0.000, 0.000)],
+ ['CB', 0, (-0.525, -0.778, -1.208)],
+ ['O', 3, (0.625, 1.063, 0.000)],
+ ['CG', 4, (0.600, 1.370, -0.000)],
+ ['CD2', 5, (0.889, -1.021, 0.003)],
+ ['ND1', 5, (0.744, 1.160, -0.000)],
+ ['CE1', 5, (2.030, 0.851, 0.002)],
+ ['NE2', 5, (2.145, -0.466, 0.004)],
+ ],
+ 'ILE': [
+ ['N', 0, (-0.493, 1.373, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.527, -0.000, -0.000)],
+ ['CB', 0, (-0.536, -0.793, -1.213)],
+ ['O', 3, (0.627, 1.062, -0.000)],
+ ['CG1', 4, (0.534, 1.437, -0.000)],
+ ['CG2', 4, (0.540, -0.785, -1.199)],
+ ['CD1', 5, (0.619, 1.391, 0.000)],
+ ],
+ 'LEU': [
+ ['N', 0, (-0.520, 1.363, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, -0.000, -0.000)],
+ ['CB', 0, (-0.522, -0.773, -1.214)],
+ ['O', 3, (0.625, 1.063, -0.000)],
+ ['CG', 4, (0.678, 1.371, 0.000)],
+ ['CD1', 5, (0.530, 1.430, -0.000)],
+ ['CD2', 5, (0.535, -0.774, 1.200)],
+ ],
+ 'LYS': [
+ ['N', 0, (-0.526, 1.362, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, 0.000, 0.000)],
+ ['CB', 0, (-0.524, -0.778, -1.208)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['CG', 4, (0.619, 1.390, 0.000)],
+ ['CD', 5, (0.559, 1.417, 0.000)],
+ ['CE', 6, (0.560, 1.416, 0.000)],
+ ['NZ', 7, (0.554, 1.387, 0.000)],
+ ],
+ 'MET': [
+ ['N', 0, (-0.521, 1.364, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, 0.000, 0.000)],
+ ['CB', 0, (-0.523, -0.776, -1.210)],
+ ['O', 3, (0.625, 1.062, -0.000)],
+ ['CG', 4, (0.613, 1.391, -0.000)],
+ ['SD', 5, (0.703, 1.695, 0.000)],
+ ['CE', 6, (0.320, 1.786, -0.000)],
+ ],
+ 'PHE': [
+ ['N', 0, (-0.518, 1.363, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.524, 0.000, -0.000)],
+ ['CB', 0, (-0.525, -0.776, -1.212)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['CG', 4, (0.607, 1.377, 0.000)],
+ ['CD1', 5, (0.709, 1.195, -0.000)],
+ ['CD2', 5, (0.706, -1.196, 0.000)],
+ ['CE1', 5, (2.102, 1.198, -0.000)],
+ ['CE2', 5, (2.098, -1.201, -0.000)],
+ ['CZ', 5, (2.794, -0.003, -0.001)],
+ ],
+ 'PRO': [
+ ['N', 0, (-0.566, 1.351, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.527, -0.000, 0.000)],
+ ['CB', 0, (-0.546, -0.611, -1.293)],
+ ['O', 3, (0.621, 1.066, 0.000)],
+ ['CG', 4, (0.382, 1.445, 0.0)],
+ # ['CD', 5, (0.427, 1.440, 0.0)],
+ ['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
+ ],
+ 'SER': [
+ ['N', 0, (-0.529, 1.360, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, -0.000, -0.000)],
+ ['CB', 0, (-0.518, -0.777, -1.211)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['OG', 4, (0.503, 1.325, 0.000)],
+ ],
+ 'THR': [
+ ['N', 0, (-0.517, 1.364, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, 0.000, -0.000)],
+ ['CB', 0, (-0.516, -0.793, -1.215)],
+ ['O', 3, (0.626, 1.062, 0.000)],
+ ['CG2', 4, (0.550, -0.718, -1.228)],
+ ['OG1', 4, (0.472, 1.353, 0.000)],
+ ],
+ 'TRP': [
+ ['N', 0, (-0.521, 1.363, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, -0.000, 0.000)],
+ ['CB', 0, (-0.523, -0.776, -1.212)],
+ ['O', 3, (0.627, 1.062, 0.000)],
+ ['CG', 4, (0.609, 1.370, -0.000)],
+ ['CD1', 5, (0.824, 1.091, 0.000)],
+ ['CD2', 5, (0.854, -1.148, -0.005)],
+ ['CE2', 5, (2.186, -0.678, -0.007)],
+ ['CE3', 5, (0.622, -2.530, -0.007)],
+ ['NE1', 5, (2.140, 0.690, -0.004)],
+ ['CH2', 5, (3.028, -2.890, -0.013)],
+ ['CZ2', 5, (3.283, -1.543, -0.011)],
+ ['CZ3', 5, (1.715, -3.389, -0.011)],
+ ],
+ 'TYR': [
+ ['N', 0, (-0.522, 1.362, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.524, -0.000, -0.000)],
+ ['CB', 0, (-0.522, -0.776, -1.213)],
+ ['O', 3, (0.627, 1.062, -0.000)],
+ ['CG', 4, (0.607, 1.382, -0.000)],
+ ['CD1', 5, (0.716, 1.195, -0.000)],
+ ['CD2', 5, (0.713, -1.194, -0.001)],
+ ['CE1', 5, (2.107, 1.200, -0.002)],
+ ['CE2', 5, (2.104, -1.201, -0.003)],
+ ['OH', 5, (4.168, -0.002, -0.005)],
+ ['CZ', 5, (2.791, -0.001, -0.003)],
+ ],
+ 'VAL': [
+ ['N', 0, (-0.494, 1.373, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.527, -0.000, -0.000)],
+ ['CB', 0, (-0.533, -0.795, -1.213)],
+ ['O', 3, (0.627, 1.062, -0.000)],
+ ['CG1', 4, (0.540, 1.429, -0.000)],
+ ['CG2', 4, (0.533, -0.776, 1.203)],
+ ],
+}
+
+# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
+residue_atoms = {
+ 'ALA': ['C', 'CA', 'CB', 'N', 'O'],
+ 'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'],
+ 'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'],
+ 'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'],
+ 'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'],
+ 'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'],
+ 'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'],
+ 'GLY': ['C', 'CA', 'N', 'O'],
+ 'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'],
+ 'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'],
+ 'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'],
+ 'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'],
+ 'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'],
+ 'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'],
+ 'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'],
+ 'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'],
+ 'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'],
+ 'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3',
+ 'CH2', 'N', 'NE1', 'O'],
+ 'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O',
+ 'OH'],
+ 'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O']
+}
+
+# Naming swaps for ambiguous atom names.
+# Due to symmetries in the amino acids the naming of atoms is ambiguous in
+# 4 of the 20 amino acids.
+# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
+# in LEU, VAL and ARG can be resolved by using the 3d constellations of
+# the 'ambiguous' atoms and their neighbours)
+residue_atom_renaming_swaps = {
+ 'ASP': {'OD1': 'OD2'},
+ 'GLU': {'OE1': 'OE2'},
+ 'PHE': {'CD1': 'CD2', 'CE1': 'CE2'},
+ 'TYR': {'CD1': 'CD2', 'CE1': 'CE2'},
+}
+
+# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
+van_der_waals_radius = {
+ 'C': 1.7,
+ 'N': 1.55,
+ 'O': 1.52,
+ 'S': 1.8,
+}
+
+Bond = collections.namedtuple(
+ 'Bond', ['atom1_name', 'atom2_name', 'length', 'stddev'])
+BondAngle = collections.namedtuple(
+ 'BondAngle',
+ ['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev'])
+
+
+@functools.lru_cache(maxsize=None)
+def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]],
+ Mapping[str, List[Bond]],
+ Mapping[str, List[BondAngle]]]:
+ """Load stereo_chemical_props.txt into a nice structure.
+
+ Load literature values for bond lengths and bond angles and translate
+ bond angles into the length of the opposite edge of the triangle
+ ("residue_virtual_bonds").
+
+ Returns:
+ residue_bonds: Dict that maps resname -> list of Bond tuples.
+ residue_virtual_bonds: Dict that maps resname -> list of Bond tuples.
+ residue_bond_angles: Dict that maps resname -> list of BondAngle tuples.
+ """
+ stereo_chemical_props_path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), 'stereo_chemical_props.txt'
+ )
+ with open(stereo_chemical_props_path, 'rt') as f:
+ stereo_chemical_props = f.read()
+ lines_iter = iter(stereo_chemical_props.splitlines())
+ # Load bond lengths.
+ residue_bonds = {}
+ next(lines_iter) # Skip header line.
+ for line in lines_iter:
+ if line.strip() == '-':
+ break
+ bond, resname, length, stddev = line.split()
+ atom1, atom2 = bond.split('-')
+ if resname not in residue_bonds:
+ residue_bonds[resname] = []
+ residue_bonds[resname].append(
+ Bond(atom1, atom2, float(length), float(stddev)))
+ residue_bonds['UNK'] = []
+
+ # Load bond angles.
+ residue_bond_angles = {}
+ next(lines_iter) # Skip empty line.
+ next(lines_iter) # Skip header line.
+ for line in lines_iter:
+ if line.strip() == '-':
+ break
+ bond, resname, angle_degree, stddev_degree = line.split()
+ atom1, atom2, atom3 = bond.split('-')
+ if resname not in residue_bond_angles:
+ residue_bond_angles[resname] = []
+ residue_bond_angles[resname].append(
+ BondAngle(atom1, atom2, atom3,
+ float(angle_degree) / 180. * np.pi,
+ float(stddev_degree) / 180. * np.pi))
+ residue_bond_angles['UNK'] = []
+
+ def make_bond_key(atom1_name, atom2_name):
+ """Unique key to lookup bonds."""
+ return '-'.join(sorted([atom1_name, atom2_name]))
+
+ # Translate bond angles into distances ("virtual bonds").
+ residue_virtual_bonds = {}
+ for resname, bond_angles in residue_bond_angles.items():
+ # Create a fast lookup dict for bond lengths.
+ bond_cache = {}
+ for b in residue_bonds[resname]:
+ bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
+ residue_virtual_bonds[resname] = []
+ for ba in bond_angles:
+ bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
+ bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
+
+ # Compute distance between atom1 and atom3 using the law of cosines
+ # c^2 = a^2 + b^2 - 2ab*cos(gamma).
+ gamma = ba.angle_rad
+ length = np.sqrt(bond1.length**2 + bond2.length**2
+ - 2 * bond1.length * bond2.length * np.cos(gamma))
+
+ # Propagation of uncertainty assuming uncorrelated errors.
+ dl_outer = 0.5 / length
+ dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer
+ dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer
+ dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer
+ stddev = np.sqrt((dl_dgamma * ba.stddev)**2 +
+ (dl_db1 * bond1.stddev)**2 +
+ (dl_db2 * bond2.stddev)**2)
+ residue_virtual_bonds[resname].append(
+ Bond(ba.atom1_name, ba.atom3name, length, stddev))
+
+ return (residue_bonds,
+ residue_virtual_bonds,
+ residue_bond_angles)
+
+
+# Between-residue bond lengths for general bonds (first element) and for Proline
+# (second element).
+between_res_bond_length_c_n = [1.329, 1.341]
+between_res_bond_length_stddev_c_n = [0.014, 0.016]
+
+# Between-residue cos_angles.
+between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315
+between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
+
+# This mapping is used when we need to store atom data in a format that requires
+# fixed atom data size for every residue (e.g. a numpy array).
+atom_types = [
+ 'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
+ 'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
+ 'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
+ 'CZ3', 'NZ', 'OXT'
+]
+atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
+atom_type_num = len(atom_types) # := 37.
+
+# A compact atom encoding with 14 columns
+# pylint: disable=line-too-long
+# pylint: disable=bad-whitespace
+restype_name_to_atom14_names = {
+ 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''],
+ 'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''],
+ 'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''],
+ 'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''],
+ 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''],
+ 'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''],
+ 'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''],
+ 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''],
+ 'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''],
+ 'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''],
+ 'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''],
+ 'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''],
+ 'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''],
+ 'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''],
+ 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''],
+ 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''],
+ 'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''],
+ 'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],
+ 'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''],
+ 'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''],
+ 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''],
+
+}
+# pylint: enable=line-too-long
+# pylint: enable=bad-whitespace
+
+
+# This is the standard residue order when coding AA type as a number.
+# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
+restypes = [
+ 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P',
+ 'S', 'T', 'W', 'Y', 'V'
+]
+restype_order = {restype: i for i, restype in enumerate(restypes)}
+restype_num = len(restypes) # := 20.
+unk_restype_index = restype_num # Catch-all index for unknown restypes.
+
+restypes_with_x = restypes + ['X']
+restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
+
+restypes_with_mask = restypes + [''] + [''] # idx=20 for unknown, idx=21 for mask
+# restype_order_with_mask = {restype: i for i, restype in enumerate(restypes_with_mask)}
+order2restype_with_mask = {i: restype for i, restype in enumerate(restypes_with_mask)}
+
+
+def sequence_to_onehot(
+ sequence: str,
+ mapping: Mapping[str, int],
+ map_unknown_to_x: bool = False) -> np.ndarray:
+ """Maps the given sequence into a one-hot encoded matrix.
+
+ Args:
+ sequence: An amino acid sequence.
+ mapping: A dictionary mapping amino acids to integers.
+ map_unknown_to_x: If True, any amino acid that is not in the mapping will be
+ mapped to the unknown amino acid 'X'. If the mapping doesn't contain
+ amino acid 'X', an error will be thrown. If False, any amino acid not in
+ the mapping will throw an error.
+
+ Returns:
+ A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
+ the sequence.
+
+ Raises:
+ ValueError: If the mapping doesn't contain values from 0 to
+ num_unique_aas - 1 without any gaps.
+ """
+ num_entries = max(mapping.values()) + 1
+
+ if sorted(set(mapping.values())) != list(range(num_entries)):
+ raise ValueError('The mapping must have values from 0 to num_unique_aas-1 '
+ 'without any gaps. Got: %s' % sorted(mapping.values()))
+
+ one_hot_arr = np.zeros((len(sequence), num_entries), dtype=int)
+
+ for aa_index, aa_type in enumerate(sequence):
+ if map_unknown_to_x:
+ if aa_type.isalpha() and aa_type.isupper():
+ aa_id = mapping.get(aa_type, mapping['X'])
+ else:
+ raise ValueError(f'Invalid character in the sequence: {aa_type}')
+ else:
+ aa_id = mapping[aa_type]
+ one_hot_arr[aa_index, aa_id] = 1
+
+ return one_hot_arr
+
+
+restype_1to3 = {
+ 'A': 'ALA',
+ 'R': 'ARG',
+ 'N': 'ASN',
+ 'D': 'ASP',
+ 'C': 'CYS',
+ 'Q': 'GLN',
+ 'E': 'GLU',
+ 'G': 'GLY',
+ 'H': 'HIS',
+ 'I': 'ILE',
+ 'L': 'LEU',
+ 'K': 'LYS',
+ 'M': 'MET',
+ 'F': 'PHE',
+ 'P': 'PRO',
+ 'S': 'SER',
+ 'T': 'THR',
+ 'W': 'TRP',
+ 'Y': 'TYR',
+ 'V': 'VAL',
+}
+
+
+# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
+# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
+# many more, and less common, three letter names as keys and maps many of these
+# to the same one letter name (including 'X' and 'U' which we don't use here).
+restype_3to1 = {v: k for k, v in restype_1to3.items()}
+
+# Define a restype name for all unknown residues.
+unk_restype = 'UNK'
+
+resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
+resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
+
+
+# The mapping here uses hhblits convention, so that B is mapped to D, J and O
+# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
+# remaining 20 amino acids are kept in alphabetical order.
+# There are 2 non-amino acid codes, X (representing any amino acid) and
+# "-" representing a missing amino acid in an alignment. The id for these
+# codes is put at the end (20 and 21) so that they can easily be ignored if
+# desired.
+HHBLITS_AA_TO_ID = {
+ 'A': 0,
+ 'B': 2,
+ 'C': 1,
+ 'D': 2,
+ 'E': 3,
+ 'F': 4,
+ 'G': 5,
+ 'H': 6,
+ 'I': 7,
+ 'J': 20,
+ 'K': 8,
+ 'L': 9,
+ 'M': 10,
+ 'N': 11,
+ 'O': 20,
+ 'P': 12,
+ 'Q': 13,
+ 'R': 14,
+ 'S': 15,
+ 'T': 16,
+ 'U': 1,
+ 'V': 17,
+ 'W': 18,
+ 'X': 20,
+ 'Y': 19,
+ 'Z': 3,
+ '-': 21,
+}
+
+# Partial inversion of HHBLITS_AA_TO_ID.
+ID_TO_HHBLITS_AA = {
+ 0: 'A',
+ 1: 'C', # Also U.
+ 2: 'D', # Also B.
+ 3: 'E', # Also Z.
+ 4: 'F',
+ 5: 'G',
+ 6: 'H',
+ 7: 'I',
+ 8: 'K',
+ 9: 'L',
+ 10: 'M',
+ 11: 'N',
+ 12: 'P',
+ 13: 'Q',
+ 14: 'R',
+ 15: 'S',
+ 16: 'T',
+ 17: 'V',
+ 18: 'W',
+ 19: 'Y',
+ 20: 'X', # Includes J and O.
+ 21: '-',
+}
+
+restypes_with_x_and_gap = restypes + ['X', '-']
+MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
+ restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
+ for i in range(len(restypes_with_x_and_gap)))
+
+
+def _make_standard_atom_mask() -> np.ndarray:
+ """Returns [num_res_types, num_atom_types] mask array."""
+ # +1 to account for unknown (all 0s).
+ mask = np.zeros([restype_num + 1, atom_type_num], dtype=int)
+ for restype, restype_letter in enumerate(restypes):
+ restype_name = restype_1to3[restype_letter]
+ atom_names = residue_atoms[restype_name]
+ for atom_name in atom_names:
+ atom_type = atom_order[atom_name]
+ mask[restype, atom_type] = 1
+ return mask
+
+
+STANDARD_ATOM_MASK = _make_standard_atom_mask()
+
+
+# A one hot representation for the first and second atoms defining the axis
+# of rotation for each chi-angle in each residue.
+def chi_angle_atom(atom_index: int) -> np.ndarray:
+ """Define chi-angle rigid groups via one-hot representations."""
+ chi_angles_index = {}
+ one_hots = []
+
+ for k, v in chi_angles_atoms.items():
+ indices = [atom_types.index(s[atom_index]) for s in v]
+ indices.extend([-1]*(4-len(indices)))
+ chi_angles_index[k] = indices
+
+ for r in restypes:
+ res3 = restype_1to3[r]
+ one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
+ one_hots.append(one_hot)
+
+ one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
+ one_hot = np.stack(one_hots, axis=0)
+ one_hot = np.transpose(one_hot, [0, 2, 1])
+
+ return one_hot
+
+chi_atom_1_one_hot = chi_angle_atom(1)
+chi_atom_2_one_hot = chi_angle_atom(2)
+
+# An array like chi_angles_atoms but using indices rather than names.
+chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
+chi_angles_atom_indices = tree.map_structure(
+ lambda atom_name: atom_order[atom_name], chi_angles_atom_indices)
+chi_angles_atom_indices = np.array([
+ chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
+ for chi_atoms in chi_angles_atom_indices])
+
+# Mapping from (res_name, atom_name) pairs to the atom's chi group index
+# and atom index within that group.
+chi_groups_for_atom = collections.defaultdict(list)
+for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
+ for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
+ for atom_i, atom in enumerate(chi_group):
+ chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
+chi_groups_for_atom = dict(chi_groups_for_atom)
+
+
+def _make_rigid_transformation_4x4(ex, ey, translation):
+ """Create a rigid 4x4 transformation matrix from two axes and transl."""
+ # Normalize ex.
+ ex_normalized = ex / np.linalg.norm(ex)
+
+ # make ey perpendicular to ex
+ ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
+ ey_normalized /= np.linalg.norm(ey_normalized)
+
+ # compute ez as cross product
+ eznorm = np.cross(ex_normalized, ey_normalized)
+ m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
+ m = np.concatenate([m, [[0., 0., 0., 1.]]], axis=0)
+ return m
+
+
+# create an array with (restype, atomtype) --> rigid_group_idx
+# and an array with (restype, atomtype, coord) for the atom positions
+# and compute affine transformation matrices (4,4) from one rigid group to the
+# previous group
+restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
+restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
+restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
+restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
+restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
+restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
+restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
+
+
+def _make_rigid_group_constants():
+ """Fill the arrays above."""
+ for restype, restype_letter in enumerate(restypes):
+ resname = restype_1to3[restype_letter]
+ for atomname, group_idx, atom_position in rigid_group_atom_positions[
+ resname]:
+ atomtype = atom_order[atomname]
+ restype_atom37_to_rigid_group[restype, atomtype] = group_idx
+ restype_atom37_mask[restype, atomtype] = 1
+ restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position
+
+ atom14idx = restype_name_to_atom14_names[resname].index(atomname)
+ restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
+ restype_atom14_mask[restype, atom14idx] = 1
+ restype_atom14_rigid_group_positions[restype,
+ atom14idx, :] = atom_position
+
+ for restype, restype_letter in enumerate(restypes):
+ resname = restype_1to3[restype_letter]
+ atom_positions = {name: np.array(pos) for name, _, pos
+ in rigid_group_atom_positions[resname]}
+
+ # backbone to backbone is the identity transform
+ restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
+
+ # pre-omega-frame to backbone (currently dummy identity matrix)
+ restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
+
+ # phi-frame to backbone
+ mat = _make_rigid_transformation_4x4(
+ ex=atom_positions['N'] - atom_positions['CA'],
+ ey=np.array([1., 0., 0.]),
+ translation=atom_positions['N'])
+ restype_rigid_group_default_frame[restype, 2, :, :] = mat
+
+ # psi-frame to backbone
+ mat = _make_rigid_transformation_4x4(
+ ex=atom_positions['C'] - atom_positions['CA'],
+ ey=atom_positions['CA'] - atom_positions['N'],
+ translation=atom_positions['C'])
+ restype_rigid_group_default_frame[restype, 3, :, :] = mat
+
+ # chi1-frame to backbone
+ if chi_angles_mask[restype][0]:
+ base_atom_names = chi_angles_atoms[resname][0]
+ base_atom_positions = [atom_positions[name] for name in base_atom_names]
+ mat = _make_rigid_transformation_4x4(
+ ex=base_atom_positions[2] - base_atom_positions[1],
+ ey=base_atom_positions[0] - base_atom_positions[1],
+ translation=base_atom_positions[2])
+ restype_rigid_group_default_frame[restype, 4, :, :] = mat
+
+ # chi2-frame to chi1-frame
+ # chi3-frame to chi2-frame
+ # chi4-frame to chi3-frame
+ # luckily all rotation axes for the next frame start at (0,0,0) of the
+ # previous frame
+ for chi_idx in range(1, 4):
+ if chi_angles_mask[restype][chi_idx]:
+ axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
+ axis_end_atom_position = atom_positions[axis_end_atom_name]
+ mat = _make_rigid_transformation_4x4(
+ ex=axis_end_atom_position,
+ ey=np.array([-1., 0., 0.]),
+ translation=axis_end_atom_position)
+ restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat
+
+
+_make_rigid_group_constants()
+
+
+def make_atom14_dists_bounds(overlap_tolerance=1.5,
+ bond_length_tolerance_factor=15):
+ """compute upper and lower bounds for bonds to assess violations."""
+ restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
+ restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
+ restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
+ residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
+ for restype, restype_letter in enumerate(restypes):
+ resname = restype_1to3[restype_letter]
+ atom_list = restype_name_to_atom14_names[resname]
+
+ # create lower and upper bounds for clashes
+ for atom1_idx, atom1_name in enumerate(atom_list):
+ if not atom1_name:
+ continue
+ atom1_radius = van_der_waals_radius[atom1_name[0]]
+ for atom2_idx, atom2_name in enumerate(atom_list):
+ if (not atom2_name) or atom1_idx == atom2_idx:
+ continue
+ atom2_radius = van_der_waals_radius[atom2_name[0]]
+ lower = atom1_radius + atom2_radius - overlap_tolerance
+ upper = 1e10
+ restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
+ restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
+ restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
+ restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
+
+ # overwrite lower and upper bounds for bonds and angles
+ for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
+ atom1_idx = atom_list.index(b.atom1_name)
+ atom2_idx = atom_list.index(b.atom2_name)
+ lower = b.length - bond_length_tolerance_factor * b.stddev
+ upper = b.length + bond_length_tolerance_factor * b.stddev
+ restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
+ restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
+ restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
+ restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
+ restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
+ restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
+ return {'lower_bound': restype_atom14_bond_lower_bound, # shape (21,14,14)
+ 'upper_bound': restype_atom14_bond_upper_bound, # shape (21,14,14)
+ 'stddev': restype_atom14_bond_stddev, # shape (21,14,14)
+ }
\ No newline at end of file
diff --git a/data/so3_utils.py b/data/so3_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..96353d8765d8b8b6258e8f3508d1c0e1909435bd
--- /dev/null
+++ b/data/so3_utils.py
@@ -0,0 +1,1375 @@
+import logging
+import os
+from typing import Callable, Dict, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+def scale_rotmat(
+ rotation_matrix: torch.Tensor, scalar: torch.Tensor, tol: float = 1e-7
+) -> torch.Tensor:
+ """
+ Scale rotation matrix. This is done by converting it to vector representation,
+ scaling the length of the vector and converting back to matrix representation.
+
+ Args:
+ rotation_matrix: Rotation matrices.
+ scalar: Scalar values used for scaling. Should have one fewer dimension than the
+ rotation matrices for correct broadcasting.
+ tol: Numerical offset for stability.
+
+ Returns:
+ Scaled rotation matrix.
+ """
+ # Check whether dimensions match.
+ assert rotation_matrix.ndim - 1 == scalar.ndim
+ scaled_rmat = rotvec_to_rotmat(rotmat_to_rotvec(rotation_matrix) * scalar, tol=tol)
+ return scaled_rmat
+
+
+def _broadcast_identity(target: torch.Tensor) -> torch.Tensor:
+ """
+ Generate a 3 by 3 identity matrix and broadcast it to a batch of target matrices.
+
+ Args:
+ target (torch.Tensor): Batch of target 3 by 3 matrices.
+
+ Returns:
+ torch.Tensor: 3 by 3 identity matrices in the shapes of the target.
+ """
+ id3 = torch.eye(3, device=target.device, dtype=target.dtype)
+ id3 = torch.broadcast_to(id3, target.shape)
+ return id3
+
+
+def skew_matrix_exponential_map_axis_angle(
+ angles: torch.Tensor, skew_matrices: torch.Tensor
+) -> torch.Tensor:
+ """
+ Compute the matrix exponential of a rotation in axis-angle representation with the axis in skew
+ matrix representation form. Maps the rotation from the lie group to the rotation matrix
+ representation. Uses Rodrigues' formula instead of `torch.linalg.matrix_exp` for better
+ computational performance:
+
+ .. math::
+
+ \exp(\theta \mathbf{K}) = \mathbf{I} + \sin(\theta) \mathbf{K} + [1 - \cos(\theta)] \mathbf{K}^2
+
+ Args:
+ angles (torch.Tensor): Batch of rotation angles.
+ skew_matrices (torch.Tensor): Batch of rotation axes in skew matrix (lie so(3)) basis.
+
+ Returns:
+ torch.Tensor: Batch of corresponding rotation matrices.
+ """
+ # Set up identity matrix and broadcast.
+ id3 = _broadcast_identity(skew_matrices)
+
+ # Broadcast angle vector to right dimensions
+ angles = angles[..., None, None]
+
+ exp_skew = (
+ id3
+ + torch.sin(angles) * skew_matrices
+ + (1.0 - torch.cos(angles))
+ * torch.einsum("b...ik,b...kj->b...ij", skew_matrices, skew_matrices)
+ )
+ return exp_skew
+
+
+def skew_matrix_exponential_map(
+ angles: torch.Tensor, skew_matrices: torch.Tensor, tol=1e-7
+) -> torch.Tensor:
+ """
+ Compute the matrix exponential of a rotation vector in skew matrix representation. Maps the
+ rotation from the lie group to the rotation matrix representation. Uses the following form of
+ Rodrigues' formula instead of `torch.linalg.matrix_exp` for better computational performance
+ (in this case the skew matrix already contains the angle factor):
+
+ .. math ::
+
+ \exp(\mathbf{K}) = \mathbf{I} + \frac{\sin(\theta)}{\theta} \mathbf{K} + \frac{1-\cos(\theta)}{\theta^2} \mathbf{K}^2
+
+ This form has the advantage, that Taylor expansions can be used for small angles (instead of
+ having to compute the unit length axis by dividing the rotation vector by small angles):
+
+ .. math ::
+
+ \frac{\sin(\theta)}{\theta} \approx 1 - \frac{\theta^2}{6}
+ \frac{1-\cos(\theta)}{\theta^2} \approx \frac{1}{2} - \frac{\theta^2}{24}
+
+ Args:
+ angles (torch.Tensor): Batch of rotation angles.
+ skew_matrices (torch.Tensor): Batch of rotation axes in skew matrix (lie so(3)) basis.
+
+ Returns:
+ torch.Tensor: Batch of corresponding rotation matrices.
+ """
+ # Set up identity matrix and broadcast.
+ id3 = _broadcast_identity(skew_matrices)
+
+ # Broadcast angles and pre-compute square.
+ angles = angles[..., None, None]
+ angles_sq = angles.square()
+
+ # Get standard terms.
+ sin_coeff = torch.sin(angles) / angles
+ cos_coeff = (1.0 - torch.cos(angles)) / angles_sq
+ # Use second order Taylor expansion for values close to zero.
+ sin_coeff_small = 1.0 - angles_sq / 6.0
+ cos_coeff_small = 0.5 - angles_sq / 24.0
+
+ mask_zero = torch.abs(angles) < tol
+ sin_coeff = torch.where(mask_zero, sin_coeff_small, sin_coeff)
+ cos_coeff = torch.where(mask_zero, cos_coeff_small, cos_coeff)
+
+ # Compute matrix exponential using Rodrigues' formula.
+ exp_skew = (
+ id3
+ + sin_coeff * skew_matrices
+ + cos_coeff * torch.einsum("b...ik,b...kj->b...ij", skew_matrices, skew_matrices)
+ )
+ return exp_skew
+
+
+def rotvec_to_rotmat(rotation_vectors: torch.Tensor, tol: float = 1e-7) -> torch.Tensor:
+ """
+ Convert rotation vectors to rotation matrix representation. The length of the rotation vector
+ is the angle of rotation, the unit vector the rotation axis.
+
+ Args:
+ rotation_vectors (torch.Tensor): Batch of rotation vectors.
+ tol: small offset for numerical stability.
+
+ Returns:
+ torch.Tensor: Rotation in rotation matrix representation.
+ """
+ # Compute rotation angle as vector norm.
+ rotation_angles = torch.norm(rotation_vectors, dim=-1)
+
+ # Map axis to skew matrix basis.
+ skew_matrices = vector_to_skew_matrix(rotation_vectors)
+
+ # Compute rotation matrices via matrix exponential.
+ rotation_matrices = skew_matrix_exponential_map(rotation_angles, skew_matrices, tol=tol)
+
+ return rotation_matrices
+
+
+def rotmat_to_rotvec(rotation_matrices: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a batch of rotation matrices to rotation vectors (logarithmic map from SO(3) to so(3)).
+ The standard logarithmic map can be derived from Rodrigues' formula via Taylor approximation
+ (in this case operating on the vector coefficients of the skew so(3) basis).
+
+ ..math ::
+
+ \left[\log(\mathbf{R})\right]^\lor = \frac{\theta}{2\sin(\theta)} \left[\mathbf{R} - \mathbf{R}^\top\right]^\lor
+
+ This formula has problems at 1) angles theta close or equal to zero and 2) at angles close and
+ equal to pi.
+
+ To improve numerical stability for case 1), the angle term at small or zero angles is
+ approximated by its truncated Taylor expansion:
+
+ .. math ::
+
+ \left[\log(\mathbf{R})\right]^\lor \approx \frac{1}{2} (1 + \frac{\theta^2}{6}) \left[\mathbf{R} - \mathbf{R}^\top\right]^\lor
+
+ For angles close or equal to pi (case 2), the outer product relation can be used to obtain the
+ squared rotation vector:
+
+ .. math :: \omega \otimes \omega = \frac{1}{2}(\mathbf{I} + R)
+
+ Taking the root of the diagonal elements recovers the normalized rotation vector up to the signs
+ of the component. The latter can be obtained from the off-diagonal elements.
+
+ Adapted from https://github.com/jasonkyuyim/se3_diffusion/blob/2cba9e09fdc58112126a0441493b42022c62bbea/data/so3_utils.py
+ which was adapted from https://github.com/geomstats/geomstats/blob/master/geomstats/geometry/special_orthogonal.py
+ with heavy help from https://cvg.cit.tum.de/_media/members/demmeln/nurlanov2021so3log.pdf
+
+ Args:
+ rotation_matrices (torch.Tensor): Input batch of rotation matrices.
+
+ Returns:
+ torch.Tensor: Batch of rotation vectors.
+ """
+ # Get angles and sin/cos from rotation matrix.
+ angles, angles_sin, _ = angle_from_rotmat(rotation_matrices)
+ # Compute skew matrix representation and extract so(3) vector components.
+ vector = skew_matrix_to_vector(rotation_matrices - rotation_matrices.transpose(-2, -1))
+
+ # Three main cases for angle theta, which are captured
+ # 1) Angle is 0 or close to zero -> use Taylor series for small values / return 0 vector.
+ mask_zero = torch.isclose(angles, torch.zeros_like(angles)).to(angles.dtype)
+ # 2) Angle is close to pi -> use outer product relation.
+ mask_pi = torch.isclose(angles, torch.full_like(angles, np.pi), atol=1e-2).to(angles.dtype)
+ # 3) Angle is unproblematic -> use the standard formula.
+ mask_else = (1 - mask_zero) * (1 - mask_pi)
+
+ # Compute case dependent pre-factor (1/2 for angle close to 0, angle otherwise).
+ numerator = mask_zero / 2.0 + angles * mask_else
+ # The Taylor expansion used here is actually the inverse of the Taylor expansion of the inverted
+ # fraction sin(x) / x which gives better accuracy over a wider range (hence the minus and
+ # position in denominator).
+ denominator = (
+ (1.0 - angles**2 / 6.0) * mask_zero # Taylor expansion for small angles.
+ + 2.0 * angles_sin * mask_else # Standard formula.
+ + mask_pi # Avoid zero division at angle == pi.
+ )
+ prefactor = numerator / denominator
+ vector = vector * prefactor[..., None]
+
+ # For angles close to pi, derive vectors from their outer product (ww' = 1 + R).
+ id3 = _broadcast_identity(rotation_matrices)
+ skew_outer = (id3 + rotation_matrices) / 2.0
+ # Ensure diagonal is >= 0 for square root (uses identity for masking).
+ skew_outer = skew_outer + (torch.relu(skew_outer) - skew_outer) * id3
+
+ # Get basic rotation vector as sqrt of diagonal (is unit vector).
+ vector_pi = torch.sqrt(torch.diagonal(skew_outer, dim1=-2, dim2=-1))
+
+ # Compute the signs of vector elements (up to a global phase).
+ # Fist select indices for outer product slices with the largest norm.
+ signs_line_idx = torch.argmax(torch.norm(skew_outer, dim=-1), dim=-1).long()
+ # Select rows of outer product and determine signs.
+ signs_line = torch.take_along_dim(skew_outer, dim=-2, indices=signs_line_idx[..., None, None])
+ signs_line = signs_line.squeeze(-2)
+ signs = torch.sign(signs_line)
+
+ # Apply signs and rotation vector.
+ vector_pi = vector_pi * angles[..., None] * signs
+
+ # Fill entries for angle == pi in rotation vector (basic vector has zero entries at this point).
+ vector = vector + vector_pi * mask_pi[..., None]
+
+ return vector
+
+
+def angle_from_rotmat(
+ rotation_matrices: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Compute rotation angles (as well as their sines and cosines) encoded by rotation matrices.
+ Uses atan2 for better numerical stability for small angles.
+
+ Args:
+ rotation_matrices (torch.Tensor): Batch of rotation matrices.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Batch of computed angles, sines of the
+ angles and cosines of angles.
+ """
+ # Compute sine of angles (uses the relation that the unnormalized skew vector generated by a
+ # rotation matrix has the length 2*sin(theta))
+ skew_matrices = rotation_matrices - rotation_matrices.transpose(-2, -1)
+ skew_vectors = skew_matrix_to_vector(skew_matrices)
+ angles_sin = torch.norm(skew_vectors, dim=-1) / 2.0
+ # Compute the cosine of the angle using the relation cos theta = 1/2 * (Tr[R] - 1)
+ angles_cos = (torch.einsum("...ii", rotation_matrices) - 1.0) / 2.0
+
+ # Compute angles using the more stable atan2
+ angles = torch.atan2(angles_sin, angles_cos)
+
+ return angles, angles_sin, angles_cos
+
+
+def vector_to_skew_matrix(vectors: torch.Tensor) -> torch.Tensor:
+ """
+ Map a vector into the corresponding skew matrix so(3) basis.
+ ```
+ [ 0 -z y]
+ [x,y,z] -> [ z 0 -x]
+ [ -y x 0]
+ ```
+
+ Args:
+ vectors (torch.Tensor): Batch of vectors to be mapped to skew matrices.
+
+ Returns:
+ torch.Tensor: Vectors in skew matrix representation.
+ """
+ # Generate empty skew matrices.
+ skew_matrices = torch.zeros((*vectors.shape, 3), device=vectors.device, dtype=vectors.dtype)
+
+ # Populate positive values.
+ skew_matrices[..., 2, 1] = vectors[..., 0]
+ skew_matrices[..., 0, 2] = vectors[..., 1]
+ skew_matrices[..., 1, 0] = vectors[..., 2]
+
+ # Generate skew symmetry.
+ skew_matrices = skew_matrices - skew_matrices.transpose(-2, -1)
+
+ return skew_matrices
+
+
+def skew_matrix_to_vector(skew_matrices: torch.Tensor) -> torch.Tensor:
+ """
+ Extract a rotation vector from the so(3) skew matrix basis.
+
+ Args:
+ skew_matrices (torch.Tensor): Skew matrices.
+
+ Returns:
+ torch.Tensor: Rotation vectors corresponding to skew matrices.
+ """
+ vectors = torch.zeros_like(skew_matrices[..., 0])
+ vectors[..., 0] = skew_matrices[..., 2, 1]
+ vectors[..., 1] = skew_matrices[..., 0, 2]
+ vectors[..., 2] = skew_matrices[..., 1, 0]
+ return vectors
+
+
+def _rotquat_to_axis_angle(
+ rotation_quaternions: torch.Tensor, tol: float = 1e-7
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Auxiliary routine for computing rotation angle and rotation axis from unit quaternions. To avoid
+ complications, rotations vectors with angles below `tol` are set to zero.
+
+ Args:
+ rotation_quaternions (torch.Tensor): Rotation quaternions in [r, i, j, k] format.
+ tol (float, optional): Threshold for small rotations. Defaults to 1e-7.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Rotation angles and axes.
+ """
+ # Compute rotation axis and normalize (accounting for small length axes).
+ rotation_axes = rotation_quaternions[..., 1:]
+ rotation_axes_norms = torch.norm(rotation_axes, dim=-1)
+
+ # Compute rotation angle via atan2
+ rotation_angles = 2.0 * torch.atan2(rotation_axes_norms, rotation_quaternions[..., 0])
+
+ # Save division.
+ rotation_axes = rotation_axes / (rotation_axes_norms[:, None] + tol)
+ return rotation_angles, rotation_axes
+
+
+def rotquat_to_rotvec(rotation_quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert unit quaternions to rotation vectors.
+
+ Args:
+ rotation_quaternions (torch.Tensor): Input quaternions in [r,i,j,k] format.
+
+ Returns:
+ torch.Tensor: Rotation vectors.
+ """
+ rotation_angles, rotation_axes = _rotquat_to_axis_angle(rotation_quaternions)
+ rotation_vectors = rotation_axes * rotation_angles[..., None]
+ return rotation_vectors
+
+
+def rotquat_to_rotmat(rotation_quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert unit quaternion to rotation matrix.
+
+ Args:
+ rotation_quaternions (torch.Tensor): Input quaternions in [r,i,j,k] format.
+
+ Returns:
+ torch.Tensor: Rotation matrices.
+ """
+ rotation_angles, rotation_axes = _rotquat_to_axis_angle(rotation_quaternions)
+ skew_matrices = vector_to_skew_matrix(rotation_axes * rotation_angles[..., None])
+ rotation_matrices = skew_matrix_exponential_map(rotation_angles, skew_matrices)
+ return rotation_matrices
+
+
+def apply_rotvec_to_rotmat(
+ rotation_matrices: torch.Tensor,
+ rotation_vectors: torch.Tensor,
+ tol: float = 1e-7,
+) -> torch.Tensor:
+ """
+ Update a rotation encoded in a rotation matrix with a rotation vector.
+
+ Args:
+ rotation_matrices: Input batch of rotation matrices.
+ rotation_vectors: Input batch of rotation vectors.
+ tol: Small offset for numerical stability.
+
+ Returns:
+ Updated rotation matrices.
+ """
+ # Convert vector to matrices.
+ rmat_right = rotvec_to_rotmat(rotation_vectors, tol=tol)
+ # Accumulate rotation.
+ rmat_rotated = torch.einsum("...ij,...jk->...ik", rotation_matrices, rmat_right)
+ return rmat_rotated
+
+
+def rotmat_to_skew_matrix(mat: torch.Tensor) -> torch.Tensor:
+ """
+ Generates skew matrix for corresponding rotation matrix.
+
+ Args:
+ mat (torch.Tensor): Batch of rotation matrices.
+
+ Returns:
+ torch.Tensor: Skew matrices in the shapes of mat.
+ """
+ vec = rotmat_to_rotvec(mat)
+ return vector_to_skew_matrix(vec)
+
+
+def skew_matrix_to_rotmat(skew: torch.Tensor) -> torch.Tensor:
+ """
+ Generates rotation matrix for corresponding skew matrix.
+
+ Args:
+ skew (torch.Tensor): Batch of target 3 by 3 skew symmetric matrices.
+
+ Returns:
+ torch.Tensor: Rotation matrices in the shapes of skew.
+ """
+ vec = skew_matrix_to_vector(skew)
+ return rotvec_to_rotmat(vec)
+
+
+def local_log(point: torch.Tensor, base_point: torch.Tensor) -> torch.Tensor:
+ """
+ Matrix logarithm. Computes left-invariant vector field of beinging base_point to point
+ on the manifold. Follows the signature of geomstats' equivalent function.
+ https://geomstats.github.io/api/geometry.html#geomstats.geometry.lie_group.MatrixLieGroup.log
+
+ Args:
+ point (torch.Tensor): Batch of rotation matrices to compute vector field at.
+ base_point (torch.Tensor): Transport coordinates to take matrix logarithm.
+
+ Returns:
+ torch.Tensor: Skew matrix that holds the vector field (in the tangent space).
+ """
+ return rotmat_to_skew_matrix(rot_mult(rot_transpose(base_point), point))
+
+
+def multidim_trace(mat: torch.Tensor) -> torch.Tensor:
+ """Take the trace of a matrix with leading dimensions."""
+ return torch.einsum("...ii->...", mat)
+
+
+def geodesic_dist(mat_1: torch.Tensor, mat_2: torch.Tensor) -> torch.Tensor:
+ """
+ Calculate the geodesic distance of two rotation matrices.
+
+ Args:
+ mat_1 (torch.Tensor): First rotation matrix.
+ mat_2 (torch.Tensor): Second rotation matrix.
+
+ Returns:
+ Scalar for the geodesic distance between mat_1 and mat_2 with the same
+ leading (i.e. batch) dimensions.
+ """
+ A = rotmat_to_skew_matrix(rot_mult(rot_transpose(mat_1), mat_2))
+ return torch.sqrt(multidim_trace(rot_mult(A, rot_transpose(A))))
+
+
+def rot_transpose(mat: torch.Tensor) -> torch.Tensor:
+ """Take the transpose of the last two dimensions."""
+ return torch.transpose(mat, -1, -2)
+
+
+def rot_mult(mat_1: torch.Tensor, mat_2: torch.Tensor) -> torch.Tensor:
+ """Matrix multiply two rotation matrices with leading dimensions."""
+ return torch.einsum("...ij,...jk->...ik", mat_1, mat_2)
+
+
+def calc_rot_vf(mat_t: torch.Tensor, mat_1: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the vector field Log_{mat_t}(mat_1).
+
+ Args:
+ mat_t (torch.Tensor): base point to compute vector field at.
+ mat_1 (torch.Tensor): target rotation.
+
+ Returns:
+ Rotation vector representing the vector field.
+ """
+ return rotmat_to_rotvec(rot_mult(rot_transpose(mat_t), mat_1))
+
+
+def geodesic_t(t: float, mat: torch.Tensor, base_mat: torch.Tensor, rot_vf=None) -> torch.Tensor:
+ """
+ Computes the geodesic at time t. Specifically, R_t = Exp_{base_mat}(t * Log_{base_mat}(mat)).
+
+ Args:
+ t: time along geodesic.
+ mat: target points on manifold.
+ base_mat: source point on manifold.
+
+ Returns:
+ Point along geodesic starting at base_mat and ending at mat.
+ """
+ if rot_vf is None:
+ rot_vf = calc_rot_vf(base_mat, mat)
+ mat_t = rotvec_to_rotmat(t * rot_vf)
+ if base_mat.shape != mat_t.shape:
+ raise ValueError(
+ f'Incompatible shapes: base_mat={base_mat.shape}, mat_t={mat_t.shape}')
+ return torch.einsum("...ij,...jk->...ik", base_mat, mat_t)
+
+
+class SO3LookupCache:
+ def __init__(
+ self,
+ cache_dir: str,
+ cache_file: str,
+ overwrite: bool = False,
+ ) -> None:
+ """
+ Auxiliary class for handling storage / loading of SO(3) lookup tables in npz format.
+
+ Args:
+ cache_dir: Path to the cache directory.
+ cache_file: Basic file name of the cache file.
+ overwrite: Whether existing cache files should be overwritten if requested.
+ """
+ if not cache_file.endswith(".npz"):
+ raise ValueError("Filename should have '.npz' extension.")
+ self.cache_file = cache_file
+ self.cache_dir = cache_dir
+ self.cache_path = os.path.join(cache_dir, cache_file)
+ self.overwrite = overwrite
+
+ @property
+ def path_exists(self) -> bool:
+ return os.path.exists(self.cache_path)
+
+ @property
+ def dir_exists(self) -> bool:
+ return os.path.exists(self.cache_dir)
+
+ def delete_cache(self) -> None:
+ """
+ Delete the cache file.
+ """
+ if self.path_exists:
+ os.remove(self.cache_path)
+
+ def load_cache(self) -> Dict[str, torch.Tensor]:
+ """
+ Load data from the cache file.
+
+ Returns:
+ Dictionary of loaded data tensors.
+ """
+ if self.path_exists:
+ # Load data and convert to torch tensors.
+ npz_data = np.load(self.cache_path)
+ torch_dict = {f: torch.from_numpy(npz_data[f]) for f in npz_data.files}
+ logger.info(f"Data loaded from {self.cache_path}")
+ return torch_dict
+ else:
+ raise ValueError(f"No cache data found at {self.cache_path}.")
+
+ def save_cache(self, data: Dict[str, torch.Tensor]) -> None:
+ """
+ Save a dictionary of tensors to the cache file. If overwrite is set to True, an existing
+ file is overwritten, otherwise a warning is raised and the file is not modified.
+
+ Args:
+ data: Dictionary of tensors that should be saved to the cache.
+ """
+ if not self.dir_exists:
+ os.makedirs(self.cache_dir)
+
+ if self.path_exists:
+ if self.overwrite:
+ logger.info("Overwriting cache ...")
+ self.delete_cache()
+ else:
+ logger.warn(
+ f"Cache at {self.cache_path} exits and overwriting disabled. Doing nothing."
+ )
+ else:
+ # Move everything to CPU and numpy and store.
+ logger.info(f"Data saved to {self.cache_path}")
+ numpy_dict = {k: v.detach().cpu().numpy() for k, v in data.items()}
+ np.savez(self.cache_path, **numpy_dict)
+
+
+class BaseSampleSO3(nn.Module):
+ so3_type: str = "base" # cache basename
+
+ def __init__(
+ self,
+ num_omega: int,
+ sigma_grid: torch.Tensor,
+ omega_exponent: int = 3,
+ tol: float = 1e-7,
+ interpolate: bool = True,
+ cache_dir: Optional[str] = None,
+ overwrite_cache: bool = False,
+ device: str = 'cpu',
+ ) -> None:
+ """
+ Base torch.nn module for sampling rotations from the IGSO(3) distribution. Samples are
+ created by uniformly sampling a rotation axis and using inverse transform sampling for
+ the angles. The latter uses the associated SO(3) cumulative probability distribution
+ function (CDF) and a uniform distribution [0,1] as described in [#leach2022_1]_. CDF values
+ are obtained by numerically integrating the probability distribution evaluated on a grid of
+ angles and noise levels and stored in a lookup table. Linear interpolation is used to
+ approximate continuos sampling of the function. Angles are discretized in an interval [0,pi]
+ and the grid can be squashed to have higher resolutions at low angles by taking different
+ powers. Since sampling relies on tabulated values of the CDF and indexing in the form of
+ `torch.bucketize`, gradients are not supported.
+
+ Args:
+ num_omega (int): Number of discrete angles used for generating the lookup table.
+ sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
+ omega_exponent (int, optional): Make the angle grid denser for smaller angles by taking
+ its power with the provided number. Defaults to 3.
+ tol (float, optional): Small value for numerical stability. Defaults to 1e-7.
+ interpolate (bool, optional): If enables, perform linear interpolation of the angle CDF
+ to sample angles. Otherwise the closest tabulated point is returned. Defaults to True.
+ cache_dir: Path to an optional cache directory. If set to None, lookup tables are
+ computed on the fly.
+ overwrite_cache: If set to true, existing cache files are overwritten. Can be used for
+ updating stale caches.
+
+ References
+ ----------
+ .. [#leach2022_1] Leach, Schmon, Degiacomi, Willcocks:
+ Denoising diffusion probabilistic models on so (3) for rotational alignment.
+ ICLR 2022 Workshop on Geometrical and Topological Representation Learning. 2022.
+ """
+ super().__init__()
+ self.num_omega = num_omega
+ self.omega_exponent = omega_exponent
+ self.tol = tol
+ self.interpolate = interpolate
+ self.device = device
+ self.register_buffer("sigma_grid", sigma_grid, persistent=False)
+
+ # Generate / load lookups and store in non-persistent buffers.
+ omega_grid, cdf_igso3 = self._setup_lookup(sigma_grid, cache_dir, overwrite_cache)
+ self.register_buffer("omega_grid", omega_grid, persistent=False)
+ self.register_buffer("cdf_igso3", cdf_igso3, persistent=False)
+
+ def _setup_lookup(
+ self,
+ sigma_grid: torch.Tensor,
+ cache_dir: Optional[str] = None,
+ overwrite_cache: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Master function for setting up the lookup tables. These can either be loaded from a npz
+ cache file or computed on the fly. Lookup tables will always be created and stored in double
+ precision. Casting to the target dtype is done at the end of the function.
+
+ Args:
+ sigma_grid: Grid of sigma values used for computing the lookup tables.
+ cache_dir: Path to the cache directory.
+ overwrite_cache: If set to true, an existing cache is overwritten. Can be used for
+ updating stale caches.
+
+ Returns:
+ Grid of angle values and SO(3) cumulative distribution function.
+ """
+ if cache_dir is not None:
+ cache_name = self._get_cache_name()
+ cache = SO3LookupCache(cache_dir, cache_name, overwrite=True)
+
+ # If cache dir is provided, check whether the necessary cache exists and whether it
+ # should be overwritten.
+ if cache.path_exists and not overwrite_cache:
+ # Load data from cache.
+ cache_data = cache.load_cache()
+ omega_grid = cache_data["omega_grid"]
+ cdf_igso3 = cache_data["cdf_igso3"]
+ else:
+ # Store data in cache (overwrite if requested).
+ omega_grid, cdf_igso3 = self._generate_lookup(sigma_grid)
+ cache.save_cache({"omega_grid": omega_grid, "cdf_igso3": cdf_igso3})
+ else:
+ # Other wise just generate the tables.
+ omega_grid, cdf_igso3 = self._generate_lookup(sigma_grid)
+
+ return omega_grid.to(sigma_grid.dtype), cdf_igso3.to(sigma_grid.dtype)
+
+ def _get_cache_name(self) -> str:
+ """
+ Auxiliary function for determining the cache file name based on the parameters (sigma,
+ omega, l, etc.) used for generating the lookup tables.
+
+ Returns:
+ Base name of the cache file.
+ """
+ cache_name = "cache_{:s}_s{:04.3f}-{:04.3f}-{:d}_o{:d}-{:d}.npz".format(
+ self.so3_type,
+ torch.min(self.sigma_grid).cpu().item(),
+ torch.max(self.sigma_grid).cpu().item(),
+ self.sigma_grid.shape[0],
+ self.num_omega,
+ self.omega_exponent,
+ )
+ return cache_name
+
+ def get_sigma_idx(self, sigma: torch.Tensor) -> torch.Tensor:
+ """
+ Convert continuous sigmas to the indices of the closest tabulated values.
+
+ Args:
+ sigma (torch.Tensor): IGSO3 std devs.
+
+ Returns:
+ torch.Tensor: Index tensor mapping the provided sigma values to the internal lookup
+ table.
+ """
+ return torch.bucketize(sigma, self.sigma_grid)
+
+ def expansion_function(
+ self, omega_grid: torch.Tensor, sigma_grid: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Function for generating the angle probability distribution. Should return a 2D tensor with
+ values for the std dev at the first dimension (rows) and angles at the second
+ (columns).
+
+ Args:
+ omega_grid (torch.Tensor): Grid of angle values.
+ sigma_grid (torch.Tensor): IGSO3 std devs.
+
+ Returns:
+ torch.Tensor: Distribution for angles discretized on a 2D grid.
+ """
+ raise NotImplementedError
+
+ @torch.no_grad()
+ def _generate_lookup(self, sigma_grid: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Generate the lookup table for sampling from the target SO(3) CDF. The table is 2D, with the
+ rows corresponding to different sigma values and the columns with angles computed on a grid.
+ Variance is scaled by a factor of 1/2 to account for the deacceleration of time in the
+ diffusion process due to the choice of SO(3) basis and guarantee time-reversibility (see
+ appendix E.3 in [#yim2023_2]_). The returned tables are double precision and will be cast
+ to the target dtype in `_setup_lookup`.
+
+ Args:
+ sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Tuple containing the grid used to compute the angles
+ and the associated lookup table.
+
+ References
+ ----------
+ .. [#yim2023_2] Yim, Trippe, De Bortoli, Mathieu, Doucet, Barzilay, Jaakkola:
+ SE(3) diffusion model with application to protein backbone generation.
+ arXiv preprint arXiv:2302.02277. 2023.
+ """
+
+ current_device = sigma_grid.device
+ sigma_grid_tmp = sigma_grid.to(torch.float64)
+
+ # If cuda is available, initialize everything on GPU.
+ # Even if Pytorch Lightning usually handles GPU allocation after initialization, this is
+ # required to initialize the module in GPU reducing the initializaiton time by orders of magnitude.
+ if torch.cuda.is_available():
+ sigma_grid_tmp = sigma_grid_tmp.to(device=self.device)
+
+ # Set up grid for angle resolution. Convert to double precision for better handling of numerics.
+ omega_grid = torch.linspace(0.0, 1, self.num_omega + 1).to(sigma_grid_tmp)
+
+ # If requested, increase sample density for lower values
+ omega_grid = omega_grid**self.omega_exponent
+
+ omega_grid = omega_grid * np.pi
+
+ # Compute the expansion for all omegas and sigmas.
+ pdf_igso3 = self.expansion_function(omega_grid, sigma_grid_tmp)
+
+ # Apply the pre-factor from USO(3).
+ pdf_igso3 = pdf_igso3 * (1.0 - torch.cos(omega_grid)) / np.pi
+
+ # Compute the cumulative probability distribution.
+ cdf_igso3 = integrate_trapezoid_cumulative(pdf_igso3, omega_grid)
+ # Normalize integral area to 1.
+ cdf_igso3 = cdf_igso3 / cdf_igso3[:, -1][:, None]
+
+ # Move back to original device.
+ cdf_igso3 = cdf_igso3.to(device=current_device)
+ omega_grid = omega_grid.to(device=current_device)
+
+ return omega_grid[1:].to(sigma_grid.dtype), cdf_igso3.to(sigma_grid.dtype)
+
+ def sample(self, sigma: torch.Tensor, num_samples: int) -> torch.Tensor:
+ """
+ Generate samples from the target SO(3) distribution by sampling a rotation axis angle,
+ which are then combined into a rotation vector and transformed into the corresponding
+ rotation matrix via an exponential map.
+
+ Args:
+ sigma_indices (torch.Tensor): Indices of the IGSO3 std devs for which to take samples.
+ num_samples (int): Number of angle samples to take for each std dev
+
+ Returns:
+ torch.Tensor: Sampled rotations in matrix representation with dimensions
+ [num_sigma x num_samples x 3 x 3].
+ """
+
+ vectors = self.sample_vector(sigma.shape[0], num_samples)
+ angles = self.sample_angle(sigma, num_samples)
+
+ # Do postprocessing on angles.
+ angles = self._process_angles(sigma, angles)
+
+ rotation_vectors = vectors * angles[..., None]
+
+ rotation_matrices = rotvec_to_rotmat(rotation_vectors, tol=self.tol)
+ return rotation_matrices
+
+ def _process_angles(self, sigma: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
+ """
+ Auxiliary function for performing additional processing steps on the sampled angles. One
+ example would be to ensure sampled angles are 0 for a std dev of 0 for IGSO(3).
+
+ Args:
+ sigma (torch.Tensor): Current values of sigma.
+ angles (torch.Tensor): Sampled angles.
+
+ Returns:
+ torch.Tensor: Processed sampled angles.
+ """
+ return angles
+
+ def sample_vector(self, num_sigma: int, num_samples: int) -> torch.Tensor:
+ """
+ Uniformly sample rotation axis for constructing the overall rotation.
+
+ Args:
+ num_sigma (int): Number of samples to draw for each std dev.
+ num_samples (int): Number of angle samples to take for each std dev.
+
+ Returns:
+ torch.Tensor: Batch of rotation axes with dimensions [num_sigma x num_samples x 3].
+ """
+ vectors = torch.randn(num_sigma, num_samples, 3, device=self.sigma_grid.device)
+ vectors = vectors / torch.norm(vectors, dim=2, keepdim=True)
+ return vectors
+
+ def sample_angle(self, sigma: torch.Tensor, num_samples: int) -> torch.Tensor:
+ """
+ Create a series of samples from the IGSO(3) angle distribution.
+
+ Args:
+ sigma_indices (torch.Tensor): Indices of the IGSO3 std deves for which to
+ take samples.
+ num_samples (int): Number of angle samples to take for each std dev.
+
+ Returns:
+ torch.Tensor: Collected samples, will have the dimension [num_sigma x num_samples].
+ """
+ # Convert sigmas to respective indices for lookup table.
+ sigma_indices = self.get_sigma_idx(sigma)
+ # Get relevant sigma slices from stored CDFs.
+ cdf_tmp = self.cdf_igso3[sigma_indices, :]
+
+ # Draw from uniform distribution.
+ p_uniform = torch.rand((*sigma_indices.shape, *[num_samples]), device=sigma_indices.device)
+
+ # Determine indices for CDF.
+ idx_stop = torch.sum(cdf_tmp[..., None] < p_uniform[:, None, :], dim=1).long()
+ idx_start = torch.clamp(idx_stop - 1, min=0)
+
+ if not self.interpolate:
+ omega = torch.gather(cdf_tmp, dim=1, index=idx_stop)
+ else:
+ # Get CDF values.
+ cdf_start = torch.gather(cdf_tmp, dim=1, index=idx_start)
+ cdf_stop = torch.gather(cdf_tmp, dim=1, index=idx_stop)
+
+ # Compute weights for linear interpolation.
+ cdf_delta = torch.clamp(cdf_stop - cdf_start, min=self.tol)
+ cdf_weight = torch.clamp((p_uniform - cdf_start) / cdf_delta, min=0.0, max=1.0)
+
+ # Get angle range for interpolation.
+ omega_start = self.omega_grid[idx_start]
+ omega_stop = self.omega_grid[idx_stop]
+
+ # Interpolate.
+ omega = torch.lerp(omega_start, omega_stop, cdf_weight)
+
+ return omega
+
+
+class SampleIGSO3(BaseSampleSO3):
+ so3_type = "igso3" # cache basename
+
+ def __init__(
+ self,
+ num_omega: int,
+ sigma_grid: torch.Tensor,
+ omega_exponent: int = 3,
+ tol: float = 1e-7,
+ interpolate: bool = True,
+ l_max: int = 1000,
+ cache_dir: Optional[str] = None,
+ overwrite_cache: bool = False,
+ device: str = 'cpu',
+ ) -> None:
+ """
+ Module for sampling rotations from the IGSO(3) distribution using the explicit series
+ expansion. Samples are created using inverse transform sampling based on the associated
+ cumulative probability distribution function (CDF) and a uniform distribution [0,1] as
+ described in [#leach2022_2]_. CDF values are obtained by numerically integrating the
+ probability distribution evaluated on a grid of angles and noise levels and stored in a
+ lookup table. Linear interpolation is used to approximate continuos sampling of the
+ function. Angles are discretized in an interval [0,pi] and the grid can be squashed to have
+ higher resolutions at low angles by taking different powers.
+ Since sampling relies on tabulated values of the CDF and indexing in the form of
+ `torch.bucketize`, gradients are not supported.
+
+ Args:
+ num_omega (int): Number of discrete angles used for generating the lookup table.
+ sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
+ omega_exponent (int, optional): Make the angle grid denser for smaller angles by taking
+ its power with the provided number. Defaults to 3.
+ tol (float, optional): Small value for numerical stability. Defaults to 1e-7.
+ interpolate (bool, optional): If enables, perform linear interpolation of the angle CDF
+ to sample angles. Otherwise the closest tabulated point is returned. Defaults to True.
+ l_max (int, optional): Maximum number of terms used in the series expansion.
+ cache_dir: Path to an optional cache directory. If set to None, lookup tables are
+ computed on the fly.
+ overwrite_cache: If set to true, existing cache files are overwritten. Can be used for
+ updating stale caches.
+
+ References
+ ----------
+ .. [#leach2022_2] Leach, Schmon, Degiacomi, Willcocks:
+ Denoising diffusion probabilistic models on so (3) for rotational alignment.
+ ICLR 2022 Workshop on Geometrical and Topological Representation Learning. 2022.
+ """
+ self.l_max = l_max
+ super().__init__(
+ num_omega=num_omega,
+ sigma_grid=sigma_grid,
+ omega_exponent=omega_exponent,
+ tol=tol,
+ interpolate=interpolate,
+ cache_dir=cache_dir,
+ overwrite_cache=overwrite_cache,
+ device=device,
+ )
+
+ def _get_cache_name(self) -> str:
+ """
+ Auxiliary function for determining the cache file name based on the parameters (sigma,
+ omega, l, etc.) used for generating the lookup tables.
+
+ Returns:
+ Base name of the cache file.
+ """
+ cache_name = "cache_{:s}_s{:04.3f}-{:04.3f}-{:d}_l{:d}_o{:d}-{:d}.npz".format(
+ self.so3_type,
+ torch.min(self.sigma_grid).cpu().item(),
+ torch.max(self.sigma_grid).cpu().item(),
+ self.sigma_grid.shape[0],
+ self.l_max,
+ self.num_omega,
+ self.omega_exponent,
+ )
+ return cache_name
+
+ def expansion_function(
+ self,
+ omega_grid: torch.Tensor,
+ sigma_grid: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Use the truncated expansion of the IGSO(3) probability function to generate the lookup table.
+
+ Args:
+ omega_grid (torch.Tensor): Grid of angle values.
+ sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
+
+ Returns:
+ torch.Tensor: IGSO(3) distribution for angles discretized on a 2D grid.
+ """
+ return generate_igso3_lookup_table(omega_grid, sigma_grid, l_max=self.l_max, tol=self.tol)
+
+ def _process_angles(self, sigma: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
+ """
+ Ensure sampled angles are 0 for small noise levels in IGSO(3). (Series expansion gives
+ uniform probability distribution.)
+
+ Args:
+ sigma (torch.Tensor): Current values of sigma.
+ angles (torch.Tensor): Sampled angles.
+
+ Returns:
+ torch.Tensor: Processed sampled angles.
+ """
+ angles = torch.where(
+ sigma[..., None] < self.tol,
+ torch.zeros_like(angles),
+ angles,
+ )
+ return angles
+
+
+class SampleUSO3(BaseSampleSO3):
+ so3_type = "uso3" # cache basename
+
+ def __init__(
+ self,
+ num_omega: int,
+ sigma_grid: torch.Tensor,
+ omega_exponent: int = 3,
+ tol: float = 1e-7,
+ interpolate: bool = True,
+ cache_dir: Optional[str] = None,
+ overwrite_cache: bool = False,
+ ) -> None:
+ """
+ Module for sampling rotations from the USO(3) distribution. Can be used to generate initial
+ unbiased samples in the reverse process. Samples are created using inverse transform
+ sampling based on the associated cumulative probability distribution function (CDF) and a
+ uniform distribution [0,1] as described in [#leach2022_4]_. CDF values are obtained by
+ numerically integrating the probability distribution evaluated on a grid of angles and noise
+ levels and stored in a lookup table. Linear interpolation is used to approximate continuos
+ sampling of the function. Angles are discretized in an interval [0,pi] and the grid can be
+ squashed to have higher resolutions at low angles by taking different powers.
+ Since sampling relies on tabulated values of the CDF and indexing in the form of
+ `torch.bucketize`, gradients are not supported.
+
+ Args:
+ num_omega (int): Number of discrete angles used for generating the lookup table.
+ sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
+ omega_exponent (int, optional): Make the angle grid denser for smaller angles by taking
+ its power with the provided number. Defaults to 3.
+ tol (float, optional): Small value for numerical stability. Defaults to 1e-7.
+ interpolate (bool, optional): If enables, perform linear interpolation of the angle CDF
+ to sample angles. Otherwise the closest tabulated point is returned. Defaults to True.
+ cache_dir: Path to an optional cache directory. If set to None, lookup tables are
+ computed on the fly.
+ overwrite_cache: If set to true, existing cache files are overwritten. Can be used for
+ updating stale caches.
+
+ References
+ ----------
+ .. [#leach2022_4] Leach, Schmon, Degiacomi, Willcocks:
+ Denoising diffusion probabilistic models on so (3) for rotational alignment.
+ ICLR 2022 Workshop on Geometrical and Topological Representation Learning. 2022.
+ """
+ super().__init__(
+ num_omega=num_omega,
+ sigma_grid=sigma_grid,
+ omega_exponent=omega_exponent,
+ tol=tol,
+ interpolate=interpolate,
+ cache_dir=cache_dir,
+ overwrite_cache=overwrite_cache,
+ )
+
+ def get_sigma_idx(self, sigma: torch.Tensor) -> torch.Tensor:
+ return torch.zeros_like(sigma).long()
+
+ def sample_shape(self, num_sigma: int, num_samples: int) -> torch.Tensor:
+ dummy_sigma = torch.zeros(num_sigma, device=self.sigma_grid.device)
+ return self.sample(dummy_sigma, num_samples)
+
+ def expansion_function(
+ self,
+ omega_grid: torch.Tensor,
+ sigma_grid: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ The probability density function of the uniform SO(3) distribution is the cosine scaling
+ term (1-cos(omega))/pi which is applied automatically during sampling. This means, it is
+ sufficient to return a tensor of ones to create the correct USO(3) lookup table.
+
+ Args:
+ omega_grid (torch.Tensor): Grid of angle values.
+ sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
+
+ Returns:
+ torch.Tensor: USO(3) distribution for angles discretized on a 2D grid.
+ """
+ return torch.ones(1, omega_grid.shape[0], device=omega_grid.device)
+
+
+@torch.no_grad()
+def integrate_trapezoid_cumulative(f_grid: torch.Tensor, x_grid: torch.Tensor) -> torch.Tensor:
+ """
+ Auxiliary function for numerically integrating a discretized 1D function using the trapezoid
+ rule. This is mainly used for computing the cumulative probability distributions for sampling
+ from the IGSO(3) distribution. Works on a single 1D grid or a batch of grids.
+
+ Args:
+ f_grid (torch.Tensor): Discretized function values.
+ x_grid (torch.Tensor): Discretized input values.
+
+ Returns:
+ torch.Tensor: Integrated function (not normalized).
+ """
+ f_sum = f_grid[..., :-1] + f_grid[..., 1:]
+ delta_x = torch.diff(x_grid, dim=-1)
+ integral = torch.cumsum((f_sum * delta_x[None, :]) / 2.0, dim=-1)
+ return integral
+
+
+def uniform_so3_density(omega: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the density over the uniform angle distribution in SO(3).
+
+ Args:
+ omega: Angles in radians.
+
+ Returns:
+ Uniform distribution density.
+ """
+ return (1.0 - torch.cos(omega)) / np.pi
+
+
+def igso3_expansion(
+ omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7
+) -> torch.Tensor:
+ """
+ Compute the IGSO(3) angle probability distribution function for pairs of angles and std dev
+ levels. The expansion is computed using a grid of expansion orders ranging from 0 to l_max.
+
+ This function approximates the power series in equation 5 of [#yim2023_3]_. With this
+ parameterization, IGSO(3) agrees with the Brownian motion on SO(3) with t=sigma^2.
+
+ Args:
+ omega: Values of angles (1D tensor).
+ sigma: Values of std dev of IGSO3 distribution (1D tensor of same shape as `omega`).
+ l_grid: Tensor containing expansion orders (0 to l_max).
+ tol: Small offset for numerical stability.
+
+ Returns:
+ IGSO(3) angle distribution function (without pre-factor for uniform SO(3) distribution).
+
+ References
+ ----------
+ .. [#yim2023_3] Yim, Trippe, De Bortoli, Mathieu, Doucet, Barzilay, Jaakkola:
+ SE(3) diffusion model with application to protein backbone generation.
+ arXiv preprint arXiv:2302.02277. 2023.
+ """
+ # Pre-compute sine in denominator and clamp for stability.
+ denom_sin = torch.sin(0.5 * omega)
+
+ # Pre-compute terms that rely only on expansion orders.
+ l_fac_1 = 2.0 * l_grid + 1.0
+ l_fac_2 = -l_grid * (l_grid + 1.0)
+
+ # Pre-compute numerator of expansion which only depends on angles.
+ numerator_sin = torch.sin((l_grid[None, :] + 1 / 2) * omega[:, None])
+
+ # Pre-compute exponential term with (2l+1) prefactor.
+ exponential_term = l_fac_1[None, :] * torch.exp(l_fac_2[None, :] * sigma[:, None] ** 2 / 2)
+
+ # Compute series expansion
+ f_igso = torch.sum(exponential_term * numerator_sin, dim=1)
+ # For small omega, accumulate limit of sine fraction instead:
+ # lim[x->0] sin((l+1/2)x) / sin(x/2) = 2l + 1
+ f_limw = torch.sum(exponential_term * l_fac_1[None, :], dim=1)
+
+ # Finalize expansion. Offset for stability can be added since omega is [0,pi] and sin(omega/2)
+ # is positive in this interval.
+ f_igso = f_igso / (denom_sin + tol)
+
+ # Replace values at small omega with limit.
+ f_igso = torch.where(omega <= tol, f_limw, f_igso)
+
+ # Remove remaining numerical problems
+ f_igso = torch.where(
+ torch.logical_or(torch.isinf(f_igso), torch.isnan(f_igso)), torch.zeros_like(f_igso), f_igso
+ )
+
+ return f_igso
+
+
+def digso3_expansion(
+ omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7
+) -> torch.Tensor:
+ """
+ Compute the derivative of the IGSO(3) angle probability distribution function with respect to
+ the angles for pairs of angles and std dev levels. As in `igso3_expansion` a grid is used for the
+ expansion levels. Evaluates the derivative directly in order to avoid second derivatives during
+ backpropagation.
+
+ The derivative of the angle-dependent part is computed as:
+
+ .. math ::
+ \frac{\partial}{\partial \omega} \frac{\sin((l+\tfrac{1}{2})\omega)}{\sin(\tfrac{1}{2}\omega)} = \frac{l\sin((l+1)\omega) - (l+1)\sin(l\omega)}{1 - \cos(\omega)}
+
+ (obtained via quotient rule + different trigonometric identities).
+
+ Args:
+ omega: Values of angles (1D tensor).
+ sigma: Values of IGSO3 distribution std devs (1D tensor of same shape as `omega`).
+ l_grid: Tensor containing expansion orders (0 to l_max).
+ tol: Small offset for numerical stability.
+
+ Returns:
+ IGSO(3) angle distribution derivative (without pre-factor for uniform SO(3) distribution).
+ """
+ denom_cos = 1.0 - torch.cos(omega)
+
+ l_fac_1 = 2.0 * l_grid + 1.0
+ l_fac_2 = l_grid + 1.0
+ l_fac_3 = -l_grid * l_fac_2
+
+ # Pre-compute numerator of expansion which only depends on angles.
+ numerator_sin = l_grid[None, :] * torch.sin(l_fac_2[None, :] * omega[:, None]) - l_fac_2[
+ None, :
+ ] * torch.sin(l_grid[None, :] * omega[:, None])
+
+ # Compute series expansion
+ df_igso = torch.sum(
+ l_fac_1[None, :] * torch.exp(l_fac_3[None, :] * sigma[:, None] ** 2 / 2) * numerator_sin,
+ dim=1,
+ )
+
+ # Finalize expansion. Offset for stability can be added since omega is [0,pi] and cosine term
+ # is positive in this interval.
+ df_igso = df_igso / (denom_cos + tol)
+
+ # Replace values at small omega with limit (=0).
+ df_igso = torch.where(omega <= tol, torch.zeros_like(df_igso), df_igso)
+
+ # Remove remaining numerical problems
+ df_igso = torch.where(
+ torch.logical_or(torch.isinf(df_igso), torch.isnan(df_igso)),
+ torch.zeros_like(df_igso),
+ df_igso,
+ )
+
+ return df_igso
+
+
+def dlog_igso3_expansion(
+ omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7
+) -> torch.Tensor:
+ """
+ Compute the derivative of the logarithm of the IGSO(3) angle distribution function for pairs of
+ angles and std dev levels:
+
+ .. math ::
+ \frac{\partial}{\partial \omega} \log f(\omega) = \frac{\tfrac{\partial}{\partial \omega} f(\omega)}{f(\omega)}
+
+ Required for SO(3) score computation.
+
+ Args:
+ omega: Values of angles (1D tensor).
+ sigma: Values of IGSO3 std devs (1D tensor of same shape as `omega`).
+ l_grid: Tensor containing expansion orders (0 to l_max).
+ tol: Small offset for numerical stability.
+
+ Returns:
+ IGSO(3) angle distribution derivative (without pre-factor for uniform SO(3) distribution).
+ """
+ f_igso3 = igso3_expansion(omega, sigma, l_grid, tol=tol)
+ df_igso3 = digso3_expansion(omega, sigma, l_grid, tol=tol)
+
+ return df_igso3 / (f_igso3 + tol)
+
+
+@torch.no_grad()
+def generate_lookup_table(
+ base_function: Callable,
+ omega_grid: torch.Tensor,
+ sigma_grid: torch.Tensor,
+ l_max: int = 1000,
+ tol: float = 1e-7,
+):
+ """
+ Auxiliary function for generating a lookup table from IGSO(3) expansions and their derivatives.
+ Takes a basic function and loops over different std dev levels.
+
+ Args:
+ base_function: Function used for setting up the lookup table.
+ omega_grid: Grid of angle values ranging from [0,pi] (shape is[num_omega]).
+ sigma_grid: Grid of IGSO3 std dev values (shape is [num_sigma]).
+ l_max: Number of terms used in the series expansion.
+ tol: Small value for numerical stability.
+
+ Returns:
+ Table of function values evaluated at different angles and std dev levels. The final shape is
+ [num_sigma x num_omega].
+ """
+ # Generate grid of expansion orders.
+ l_grid = torch.arange(l_max + 1, device=omega_grid.device).to(omega_grid.dtype)
+
+ n_omega = len(omega_grid)
+ n_sigma = len(sigma_grid)
+
+ # Populate lookup table for different time frames.
+ f_table = torch.zeros(n_sigma, n_omega, device=omega_grid.device, dtype=omega_grid.dtype)
+
+ for eps_idx in tqdm(range(n_sigma), desc=f"Computing {base_function.__name__}"):
+ f_table[eps_idx, :] = base_function(
+ omega_grid,
+ torch.ones_like(omega_grid) * sigma_grid[eps_idx],
+ l_grid,
+ tol=tol,
+ )
+
+ return f_table
+
+
+def generate_igso3_lookup_table(
+ omega_grid: torch.Tensor,
+ sigma_grid: torch.Tensor,
+ l_max: int = 1000,
+ tol: float = 1e-7,
+) -> torch.Tensor:
+ """
+ Generate a lookup table for the IGSO(3) probability distribution function of angles.
+
+ Args:
+ omega_grid: Grid of angle values ranging from [0,pi] (shape is[num_omega]).
+ sigma_grid: Grid of IGSO3 std dev values (shape is [num_sigma]).
+ l_max: Number of terms used in the series expansion.
+ tol: Small value for numerical stability.
+
+ Returns:
+ Table of function values evaluated at different angles and std dev levels. The final shape is
+ [num_sigma x num_omega].
+ """
+ f_igso = generate_lookup_table(
+ base_function=igso3_expansion,
+ omega_grid=omega_grid,
+ sigma_grid=sigma_grid,
+ l_max=l_max,
+ tol=tol,
+ )
+ return f_igso
+
+
+def generate_dlog_igso3_lookup_table(
+ omega_grid: torch.Tensor,
+ sigma_grid: torch.Tensor,
+ l_max: int = 1000,
+ tol: float = 1e-7,
+) -> torch.Tensor:
+ """
+ Generate a lookup table for the derivative of the logarithm of the angular IGSO(3) probability
+ distribution function. Used e.g. for computing scaling of SO(3) norms.
+
+ Args:
+ omega_grid: Grid of angle values ranging from [0,pi] (shape is[num_omega]).
+ sigma_grid: Grid of IGSO3 std dev values (shape is [num_sigma]).
+ l_max: Number of terms used in the series expansion.
+ tol: Small value for numerical stability.
+
+ Returns:
+ Table of function values evaluated at different angles and std dev levels. The final shape is
+ [num_sigma x num_omega].
+ """
+ dlog_igso = generate_lookup_table(
+ base_function=dlog_igso3_expansion,
+ omega_grid=omega_grid,
+ sigma_grid=sigma_grid,
+ l_max=l_max,
+ tol=tol,
+ )
+ return dlog_igso
diff --git a/data/utils.py b/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..aad2f7f201561f6d4a0694d1f531b23aff170519
--- /dev/null
+++ b/data/utils.py
@@ -0,0 +1,507 @@
+from typing import List, Dict, Any
+from openfold.utils import rigid_utils as ru
+from data import residue_constants
+import numpy as np
+import collections
+import string
+import pickle
+import os
+import torch
+from torch_scatter import scatter_add, scatter
+from Bio.PDB.Chain import Chain
+from data import protein
+import dataclasses
+from Bio import PDB
+
+Rigid = ru.Rigid
+Protein = protein.Protein
+
+# Global map from chain characters to integers.
+ALPHANUMERIC = string.ascii_letters + string.digits + ' '
+CHAIN_TO_INT = {
+ chain_char: i for i, chain_char in enumerate(ALPHANUMERIC)
+}
+INT_TO_CHAIN = {
+ i: chain_char for i, chain_char in enumerate(ALPHANUMERIC)
+}
+
+NM_TO_ANG_SCALE = 10.0
+ANG_TO_NM_SCALE = 1 / NM_TO_ANG_SCALE
+
+CHAIN_FEATS = [
+ 'atom_positions', 'aatype', 'atom_mask', 'residue_index', 'b_factors'
+]
+
+to_numpy = lambda x: x.detach().cpu().numpy()
+aatype_to_seq = lambda aatype: ''.join([
+ residue_constants.restypes_with_x[x] for x in aatype])
+
+
+class CPU_Unpickler(pickle.Unpickler):
+ """Pytorch pickle loading workaround.
+
+ https://github.com/pytorch/pytorch/issues/16797
+ """
+ def find_class(self, module, name):
+ if module == 'torch.storage' and name == '_load_from_bytes':
+ return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
+ else: return super().find_class(module, name)
+
+
+def create_rigid(rots, trans):
+ rots = ru.Rotation(rot_mats=rots)
+ return Rigid(rots=rots, trans=trans)
+
+
+def batch_align_structures(pos_1, pos_2, mask=None):
+ if pos_1.shape != pos_2.shape:
+ raise ValueError('pos_1 and pos_2 must have the same shape.')
+ if pos_1.ndim != 3:
+ raise ValueError(f'Expected inputs to have shape [B, N, 3]')
+ num_batch = pos_1.shape[0]
+ device = pos_1.device
+ batch_indices = (
+ torch.ones(*pos_1.shape[:2], device=device, dtype=torch.int64)
+ * torch.arange(num_batch, device=device)[:, None]
+ )
+ flat_pos_1 = pos_1.reshape(-1, 3)
+ flat_pos_2 = pos_2.reshape(-1, 3)
+ flat_batch_indices = batch_indices.reshape(-1)
+ if mask is None:
+ # aligned_pos_1, aligned_pos_2, align_rots = align_structures(
+ # flat_pos_1, flat_batch_indices, flat_pos_2)
+ # aligned_pos_1 = aligned_pos_1.reshape(num_batch, -1, 3)
+ # aligned_pos_2 = aligned_pos_2.reshape(num_batch, -1, 3)
+ # return aligned_pos_1, aligned_pos_2, align_rots
+ mask = torch.ones(*pos_1.shape[:2], device=device).reshape(-1).bool()
+
+ flat_mask = mask.reshape(-1).bool()
+ # _, _, align_rots = align_structures(
+ # flat_pos_1[flat_mask],
+ # flat_batch_indices[flat_mask],
+ # flat_pos_2[flat_mask]
+ # )
+ # aligned_pos_1 = torch.bmm(
+ # pos_1,
+ # align_rots
+ # )
+ # return aligned_pos_1, pos_2, align_rots
+ aligned_pos_1, aligned_pos_2, align_rots = align_structures(
+ flat_pos_1[flat_mask], flat_batch_indices[flat_mask], flat_pos_2[flat_mask])
+ aligned_pos_1 = aligned_pos_1.reshape(num_batch, -1, 3)
+ aligned_pos_2 = aligned_pos_2.reshape(num_batch, -1, 3)
+ return aligned_pos_1, aligned_pos_2, align_rots
+
+
+
+def adjust_oxygen_pos(
+ atom_37: torch.Tensor, pos_is_known = None
+) -> torch.Tensor:
+ """
+ Imputes the position of the oxygen atom on the backbone by using adjacent frame information.
+ Specifically, we say that the oxygen atom is in the plane created by the Calpha and C from the
+ current frame and the nitrogen of the next frame. The oxygen is then placed c_o_bond_length Angstrom
+ away from the C in the current frame in the direction away from the Ca-C-N triangle.
+
+ For cases where the next frame is not available, for example we are at the C-terminus or the
+ next frame is not available in the data then we place the oxygen in the same plane as the
+ N-Ca-C of the current frame and pointing in the same direction as the average of the
+ Ca->C and Ca->N vectors.
+
+ Args:
+ atom_37 (torch.Tensor): (N, 37, 3) tensor of positions of the backbone atoms in atom_37 ordering
+ which is ['N', 'CA', 'C', 'CB', 'O', ...]
+ pos_is_known (torch.Tensor): (N,) mask for known residues.
+ """
+
+ N = atom_37.shape[0]
+ assert atom_37.shape == (N, 37, 3)
+
+ # Get vectors to Carbonly from Carbon alpha and N of next residue. (N-1, 3)
+ # Note that the (N,) ordering is from N-terminal to C-terminal.
+
+ # Calpha to carbonyl both in the current frame.
+ calpha_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[:-1, 1, :]) / (
+ torch.norm(atom_37[:-1, 2, :] - atom_37[:-1, 1, :], keepdim=True, dim=1) + 1e-7
+ )
+ # For masked positions, they are all 0 and so we add 1e-7 to avoid division by 0.
+ # The positions are in Angstroms and so are on the order ~1 so 1e-7 is an insignificant change.
+
+ # Nitrogen of the next frame to carbonyl of the current frame.
+ nitrogen_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[1:, 0, :]) / (
+ torch.norm(atom_37[:-1, 2, :] - atom_37[1:, 0, :], keepdim=True, dim=1) + 1e-7
+ )
+
+ carbonyl_to_oxygen: torch.Tensor = calpha_to_carbonyl + nitrogen_to_carbonyl # (N-1, 3)
+ carbonyl_to_oxygen = carbonyl_to_oxygen / (
+ torch.norm(carbonyl_to_oxygen, dim=1, keepdim=True) + 1e-7
+ )
+
+ atom_37[:-1, 4, :] = atom_37[:-1, 2, :] + carbonyl_to_oxygen * 1.23
+
+ # Now we deal with frames for which there is no next frame available.
+
+ # Calpha to carbonyl both in the current frame. (N, 3)
+ calpha_to_carbonyl_term: torch.Tensor = (atom_37[:, 2, :] - atom_37[:, 1, :]) / (
+ torch.norm(atom_37[:, 2, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7
+ )
+ # Calpha to nitrogen both in the current frame. (N, 3)
+ calpha_to_nitrogen_term: torch.Tensor = (atom_37[:, 0, :] - atom_37[:, 1, :]) / (
+ torch.norm(atom_37[:, 0, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7
+ )
+ carbonyl_to_oxygen_term: torch.Tensor = (
+ calpha_to_carbonyl_term + calpha_to_nitrogen_term
+ ) # (N, 3)
+ carbonyl_to_oxygen_term = carbonyl_to_oxygen_term / (
+ torch.norm(carbonyl_to_oxygen_term, dim=1, keepdim=True) + 1e-7
+ )
+
+ # Create a mask that is 1 when the next residue is not available either
+ # due to this frame being the C-terminus or the next residue is not
+ # known due to pos_is_known being false.
+
+ if pos_is_known is None:
+ pos_is_known = torch.ones((atom_37.shape[0],), dtype=torch.int64, device=atom_37.device)
+
+ next_res_gone: torch.Tensor = ~pos_is_known.bool() # (N,)
+ next_res_gone = torch.cat(
+ [next_res_gone, torch.ones((1,), device=pos_is_known.device).bool()], dim=0
+ ) # (N+1, )
+ next_res_gone = next_res_gone[1:] # (N,)
+
+ atom_37[next_res_gone, 4, :] = (
+ atom_37[next_res_gone, 2, :]
+ + carbonyl_to_oxygen_term[next_res_gone, :] * 1.23
+ )
+
+ return atom_37
+
+
+def write_pkl(
+ save_path: str, pkl_data: Any, create_dir: bool = False, use_torch=False):
+ """Serialize data into a pickle file."""
+ if create_dir:
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ if use_torch:
+ torch.save(pkl_data, save_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
+ else:
+ with open(save_path, 'wb') as handle:
+ pickle.dump(pkl_data, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+def read_pkl(read_path: str, verbose=True, use_torch=False, map_location=None):
+ """Read data from a pickle file."""
+ try:
+ if use_torch:
+ return torch.load(read_path, map_location=map_location)
+ else:
+ with open(read_path, 'rb') as handle:
+ return pickle.load(handle)
+ except Exception as e:
+ try:
+ with open(read_path, 'rb') as handle:
+ return CPU_Unpickler(handle).load()
+ except Exception as e2:
+ if verbose:
+ print(f'Failed to read {read_path}. First error: {e}\n Second error: {e2}')
+ raise(e)
+
+
+def chain_str_to_int(chain_str: str):
+ chain_int = 0
+ if len(chain_str) == 1:
+ return CHAIN_TO_INT[chain_str]
+ for i, chain_char in enumerate(chain_str):
+ chain_int += CHAIN_TO_INT[chain_char] + (i * len(ALPHANUMERIC))
+ return chain_int
+
+
+def parse_chain_feats(chain_feats, scale_factor=1.):
+ ca_idx = residue_constants.atom_order['CA']
+ chain_feats['bb_mask'] = chain_feats['atom_mask'][:, ca_idx]
+ bb_pos = chain_feats['atom_positions'][:, ca_idx]
+ bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['bb_mask']) + 1e-5)
+ centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
+ scaled_pos = centered_pos / scale_factor
+ chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
+ chain_feats['bb_positions'] = chain_feats['atom_positions'][:, ca_idx]
+ return chain_feats
+
+
+def concat_np_features(
+ np_dicts: List[Dict[str, np.ndarray]], add_batch_dim: bool):
+ """Performs a nested concatenation of feature dicts.
+
+ Args:
+ np_dicts: list of dicts with the same structure.
+ Each dict must have the same keys and numpy arrays as the values.
+ add_batch_dim: whether to add a batch dimension to each feature.
+
+ Returns:
+ A single dict with all the features concatenated.
+ """
+ combined_dict = collections.defaultdict(list)
+ for chain_dict in np_dicts:
+ for feat_name, feat_val in chain_dict.items():
+ if add_batch_dim:
+ feat_val = feat_val[None]
+ combined_dict[feat_name].append(feat_val)
+ # Concatenate each feature
+ for feat_name, feat_vals in combined_dict.items():
+ combined_dict[feat_name] = np.concatenate(feat_vals, axis=0)
+ return combined_dict
+
+
+def center_zero(pos: torch.Tensor, batch_indexes: torch.LongTensor) -> torch.Tensor:
+ """
+ Move the molecule center to zero for sparse position tensors.
+
+ Args:
+ pos: [N, 3] batch positions of atoms in the molecule in sparse batch format.
+ batch_indexes: [N] batch index for each atom in sparse batch format.
+
+ Returns:
+ pos: [N, 3] zero-centered batch positions of atoms in the molecule in sparse batch format.
+ """
+ assert len(pos.shape) == 2 and pos.shape[-1] == 3, "pos must have shape [N, 3]"
+
+ means = scatter(pos, batch_indexes, dim=0, reduce="mean")
+ return pos - means[batch_indexes]
+
+
+@torch.no_grad()
+def align_structures(
+ batch_positions: torch.Tensor,
+ batch_indices: torch.Tensor,
+ reference_positions: torch.Tensor,
+ broadcast_reference: bool = False,
+):
+ """
+ Align structures in a ChemGraph batch to a reference, e.g. for RMSD computation. This uses the
+ sparse formulation of pytorch geometric. If the ChemGraph is composed of a single system, then
+ the reference can be given as a single structure and broadcasted. Returns the structure
+ coordinates shifted to the geometric center and the batch structures rotated to match the
+ reference structures. Uses the Kabsch algorithm (see e.g. [kabsch_align1]_). No permutation of
+ atoms is carried out.
+
+ Args:
+ batch_positions (Tensor): Batch of structures (e.g. from ChemGraph) which should be aligned
+ to a reference.
+ batch_indices (Tensor): Index tensor mapping each node / atom in batch to the respective
+ system (e.g. batch attribute of ChemGraph batch).
+ reference_positions (Tensor): Reference structure. Can either be a batch of structures or a
+ single structure. In the second case, broadcasting is possible if the input batch is
+ composed exclusively of this structure.
+ broadcast_reference (bool, optional): If reference batch contains only a single structure,
+ broadcast this structure to match the ChemGraph batch. Defaults to False.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Tensors containing the centered positions of batch
+ structures rotated into the reference and the centered reference batch.
+
+ References
+ ----------
+ .. [kabsch_align1] Lawrence, Bernal, Witzgall:
+ A purely algebraic justification of the Kabsch-Umeyama algorithm.
+ Journal of research of the National Institute of Standards and Technology, 124, 1. 2019.
+ """
+ # Minimize || Q @ R.T - P ||, which is the same as || Q - P @ R ||
+ # batch_positions -> P [BN x 3]
+ # reference_positions -> Q [B / BN x 3]
+
+ if batch_positions.shape[0] != reference_positions.shape[0]:
+ if broadcast_reference:
+ # Get number of systems in batch and broadcast reference structure.
+ # This assumes, all systems in the current batch correspond to the reference system.
+ # Typically always the case during evaluation.
+ num_molecules = int(torch.max(batch_indices) + 1)
+ reference_positions = reference_positions.repeat(num_molecules, 1)
+ else:
+ raise ValueError("Mismatch in batch dimensions.")
+
+ # Center structures at origin (takes care of translation alignment)
+ batch_positions = center_zero(batch_positions, batch_indices)
+ reference_positions = center_zero(reference_positions, batch_indices)
+
+ # Compute covariance matrix for optimal rotation (Q.T @ P) -> [B x 3 x 3].
+ cov = scatter_add(
+ batch_positions[:, None, :] * reference_positions[:, :, None], batch_indices, dim=0
+ )
+
+ # Perform singular value decomposition. (all [B x 3 x 3])
+ u, _, v_t = torch.linalg.svd(cov)
+ # Convenience transposes.
+ u_t = u.transpose(1, 2)
+ v = v_t.transpose(1, 2)
+
+ # Compute rotation matrix correction for ensuring right-handed coordinate system
+ # For comparison with other sources: det(AB) = det(A)*det(B) and det(A) = det(A.T)
+ sign_correction = torch.sign(torch.linalg.det(torch.bmm(v, u_t)))
+ # Correct transpose of U: diag(1, 1, sign_correction) @ U.T
+ u_t[:, 2, :] = u_t[:, 2, :] * sign_correction[:, None]
+
+ # Compute optimal rotation matrix (R = V @ diag(1, 1, sign_correction) @ U.T).
+ rotation_matrices = torch.bmm(v, u_t)
+
+ # print('batch_positions=',batch_positions.dtype)
+ # print('rotation_matrices=',rotation_matrices.dtype)
+ rotation_matrices = rotation_matrices.type(batch_positions.dtype)
+
+ # Rotate batch positions P to optimal alignment with Q (P @ R)
+ batch_positions_rotated = torch.bmm(
+ batch_positions[:, None, :],
+ rotation_matrices[batch_indices],
+ ).squeeze(1)
+
+ return batch_positions_rotated, reference_positions, rotation_matrices
+
+
+def parse_pdb_feats(
+ pdb_name: str,
+ pdb_path: str,
+ scale_factor=1.,
+ # TODO: Make the default behaviour read all chains.
+ chain_id='A',
+ ):
+ """
+ Args:
+ pdb_name: name of PDB to parse.
+ pdb_path: path to PDB file to read.
+ scale_factor: factor to scale atom positions.
+ mean_center: whether to mean center atom positions.
+ Returns:
+ Dict with CHAIN_FEATS features extracted from PDB with specified
+ preprocessing.
+ """
+ parser = PDB.PDBParser(QUIET=True)
+ structure = parser.get_structure(pdb_name, pdb_path)
+ struct_chains = {
+ chain.id: chain
+ for chain in structure.get_chains()}
+
+ def _process_chain_id(x):
+ chain_prot = process_chain(struct_chains[x], x)
+ chain_dict = dataclasses.asdict(chain_prot)
+
+ # Process features
+ feat_dict = {x: chain_dict[x] for x in CHAIN_FEATS}
+ return parse_chain_feats(
+ feat_dict, scale_factor=scale_factor)
+
+ if isinstance(chain_id, str):
+ return _process_chain_id(chain_id)
+ elif isinstance(chain_id, list):
+ return {
+ x: _process_chain_id(x) for x in chain_id
+ }
+ elif chain_id is None:
+ return {
+ x: _process_chain_id(x) for x in struct_chains
+ }
+ else:
+ raise ValueError(f'Unrecognized chain list {chain_id}')
+
+def rigid_transform_3D(A, B, verbose=False):
+ # Transforms A to look like B
+ # https://github.com/nghiaho12/rigid_transform_3D
+ assert A.shape == B.shape
+ A = A.T
+ B = B.T
+
+ num_rows, num_cols = A.shape
+ if num_rows != 3:
+ raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")
+
+ num_rows, num_cols = B.shape
+ if num_rows != 3:
+ raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
+
+ # find mean column wise
+ centroid_A = np.mean(A, axis=1)
+ centroid_B = np.mean(B, axis=1)
+
+ # ensure centroids are 3x1
+ centroid_A = centroid_A.reshape(-1, 1)
+ centroid_B = centroid_B.reshape(-1, 1)
+
+ # subtract mean
+ Am = A - centroid_A
+ Bm = B - centroid_B
+
+ H = Am @ np.transpose(Bm)
+
+ # sanity check
+ #if linalg.matrix_rank(H) < 3:
+ # raise ValueError("rank of H = {}, expecting 3".format(linalg.matrix_rank(H)))
+
+ # find rotation
+ U, S, Vt = np.linalg.svd(H)
+ R = Vt.T @ U.T
+
+ # special reflection case
+ reflection_detected = False
+ if np.linalg.det(R) < 0:
+ if verbose:
+ print("det(R) < R, reflection detected!, correcting for it ...")
+ Vt[2,:] *= -1
+ R = Vt.T @ U.T
+ reflection_detected = True
+
+ t = -R @ centroid_A + centroid_B
+ optimal_A = R @ A + t
+
+ return optimal_A.T, R, t, reflection_detected
+
+def process_chain(chain: Chain, chain_id: str) -> Protein:
+ """Convert a PDB chain object into a AlphaFold Protein instance.
+
+ Forked from alphafold.common.protein.from_pdb_string
+
+ WARNING: All non-standard residue types will be converted into UNK. All
+ non-standard atoms will be ignored.
+
+ Took out lines 94-97 which don't allow insertions in the PDB.
+ Sabdab uses insertions for the chothia numbering so we need to allow them.
+
+ Took out lines 110-112 since that would mess up CDR numbering.
+
+ Args:
+ chain: Instance of Biopython's chain class.
+
+ Returns:
+ Protein object with protein features.
+ """
+ atom_positions = []
+ aatype = []
+ atom_mask = []
+ residue_index = []
+ b_factors = []
+ chain_ids = []
+ for res in chain:
+ res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
+ restype_idx = residue_constants.restype_order.get(
+ res_shortname, residue_constants.restype_num)
+ pos = np.zeros((residue_constants.atom_type_num, 3))
+ mask = np.zeros((residue_constants.atom_type_num,))
+ res_b_factors = np.zeros((residue_constants.atom_type_num,))
+ for atom in res:
+ if atom.name not in residue_constants.atom_types:
+ continue
+ pos[residue_constants.atom_order[atom.name]] = atom.coord
+ mask[residue_constants.atom_order[atom.name]] = 1.
+ res_b_factors[residue_constants.atom_order[atom.name]
+ ] = atom.bfactor
+ aatype.append(restype_idx)
+ atom_positions.append(pos)
+ atom_mask.append(mask)
+ residue_index.append(res.id[1])
+ b_factors.append(res_b_factors)
+ chain_ids.append(chain_id)
+
+ return Protein(
+ atom_positions=np.array(atom_positions),
+ atom_mask=np.array(atom_mask),
+ aatype=np.array(aatype),
+ residue_index=np.array(residue_index),
+ chain_index=np.array(chain_ids),
+ b_factors=np.array(b_factors))
diff --git a/dataset/ATLAS_filename.txt b/dataset/ATLAS_filename.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a7c90622cc0b78be36866c41896ed2cc048a457e
--- /dev/null
+++ b/dataset/ATLAS_filename.txt
@@ -0,0 +1,1390 @@
+1k5n_A
+5e3e_A
+3ilw_A
+3bl9_A
+6o2v_A
+1zd7_B
+2lis_A
+3qra_A
+7ead_A
+3bn0_A
+1fx4_A
+4byz_A
+4a56_A
+1d3y_B
+4oe9_B
+5kko_D
+4bt9_B
+3pkz_A
+2nwv_A
+4tsh_A
+5kdw_B
+3rv1_B
+4goq_B
+1f20_A
+3nrw_A
+4g75_A
+1gwe_A
+6irx_A
+3g5t_A
+6uof_A
+2pmb_C
+3kby_A
+1w2w_A
+4geh_A
+6lus_A
+5kdv_A
+2py5_A
+3wol_A
+3cw3_A
+2hqk_A
+3ll7_A
+5ezu_A
+1f3u_G
+3a7l_A
+2i5v_O
+1tez_B
+1jyo_F
+1jyo_B
+1o1x_A
+6eof_A
+2f5t_X
+1ys1_X
+3e7h_A
+1nns_A
+3ipf_A
+6qj0_A
+4ezi_A
+4u9b_A
+2v6a_O
+1q74_C
+2jdd_A
+1pv5_A
+4zya_A
+2ii1_D
+3h1d_A
+3cjs_B
+1r8m_E
+1lln_A
+5opz_A
+3oyv_A
+3jsz_A
+2gg2_A
+4q7o_B
+3vkw_A
+3hgl_A
+1n2m_A
+1vk1_A
+1vdk_B
+2cb5_A
+3hms_A
+2wnp_F
+5x1e_D
+3qc7_A
+1lc3_A
+6j56_A
+6fsw_A
+1z1l_A
+3m6z_A
+2fhz_A
+5ws1_A
+2x5g_A
+1rm6_B
+3nr1_A
+3nuf_A
+2xdg_A
+3ov5_A
+4wyw_A
+3do6_A
+3g28_A
+2wzo_A
+3lx5_A
+3ooi_A
+1btk_A
+2gkp_A
+2pmk_A
+1dd3_B
+6cka_B
+7ec1_A
+1qlm_A
+3fpu_A
+1msk_A
+2qsk_A
+3emi_A
+2qcu_B
+1ux6_A
+2hy5_A
+1gqv_A
+5ol4_A
+2rkq_A
+2vi1_A
+3po8_A
+1egw_B
+2d5b_A
+2x96_A
+1f9p_A
+1qw2_A
+2ot9_A
+5gv0_A
+6hj6_A
+5kts_A
+2wkx_A
+6dgk_B
+2x49_A
+2db7_A
+1sf9_A
+3dan_A
+6qfk_A
+6fc0_B
+2qhq_B
+4f87_A
+3se9_G
+1bq8_A
+1tg7_A
+1dfu_P
+3onj_A
+1t9i_B
+2ygs_A
+4aqr_D
+4adn_B
+3g3s_B
+6xds_A
+6dlm_A
+1tgr_A
+4dol_A
+3e0z_A
+2e5y_B
+4p7x_A
+1pch_A
+2vrs_C
+3it4_B
+1j0p_A
+4ywk_A
+6xrx_A
+6q9c_B
+1tdz_A
+2wr8_A
+1feh_A
+2phn_A
+2ox7_B
+3gqv_A
+1n8u_A
+4ayg_B
+3jub_A
+6rrv_A
+5o9e_A
+3v0r_A
+2qnl_A
+1xg0_A
+1m7v_A
+1k7c_A
+2wn3_C
+1ouw_C
+4qtq_A
+2prv_A
+3eoj_A
+1ear_A
+2gkr_I
+1b2s_F
+1jke_C
+2fos_A
+7lao_A
+2w0i_A
+1nl1_A
+3lur_B
+1ijx_C
+4mzv_A
+4yxp_A
+6l4l_A
+1o98_A
+1sx3_A
+4m0q_A
+2fj8_A
+1xhk_B
+1a62_A
+2d28_C
+2v9m_A
+3l1e_A
+1yoz_B
+4mmj_A
+7asg_A
+2p54_A
+1g38_A
+4epi_A
+1ez3_B
+3im1_A
+1u55_A
+2dtr_A
+3c2q_B
+4p3x_A
+2p6v_A
+6mdw_A
+1ewf_A
+2i9x_A
+2fb5_A
+5lpa_A
+6bwq_A
+1io1_A
+1rlk_A
+4obi_A
+1uwk_B
+4n67_A
+1r5l_A
+6kty_A
+1u84_A
+3vgp_A
+2o6x_A
+6vjg_A
+4h9n_C
+4tkx_L
+1mv8_C
+1xvg_E
+1zdy_A
+5t76_A
+2p12_A
+3d7a_B
+6sms_A
+4xqm_A
+1m5q_I
+1y7t_B
+4lq4_A
+3fm2_A
+1jq5_A
+1lr0_A
+1uoy_A
+6l3r_E
+1xo7_B
+3o0g_D
+3lvu_D
+1qcs_A
+2zca_B
+3gnu_P
+1juv_A
+1qdd_A
+2i7c_C
+5el3_B
+3e4w_B
+4ew5_A
+4cv4_A
+5op0_B
+1jo0_A
+5ygh_A
+3bpt_A
+3vnb_A
+5b1n_A
+2qev_A
+2y44_A
+7qsu_A
+3fhf_A
+7p46_A
+5ydn_A
+3ncl_A
+7e2s_A
+2nuk_A
+3h5q_A
+1o75_A
+3d2q_A
+2y1q_A
+1wui_S
+3lwx_A
+5cfl_A
+1gqe_A
+3nrl_B
+3nke_B
+3q1c_A
+4ddp_A
+2y39_A
+4jja_A
+1ctf_A
+3dgp_B
+2w39_A
+1x6j_A
+1wx1_B
+1h8e_G
+2i2f_A
+1j8e_A
+2jhf_B
+6avj_B
+3omd_B
+2in8_A
+2hgx_B
+2xvs_A
+5w82_E
+4rgi_A
+1dow_A
+4za9_A
+3l4h_A
+5b3p_A
+2q9r_A
+3mx7_A
+6pxz_B
+3q1n_A
+1o0s_A
+1ojh_C
+2ebb_A
+3g5k_D
+2rff_A
+2gec_A
+3mmy_D
+4xb5_A
+3h93_A
+2vfx_C
+2vh3_A
+3zr8_X
+1j8r_A
+1rqj_A
+3v68_A
+5hqh_A
+3ii2_A
+5h7w_B
+3ip4_A
+3zzo_A
+3ghd_A
+6cb7_A
+3aa6_B
+2b0a_A
+5yrv_I
+2j6a_A
+2e7z_A
+3f0p_A
+5osw_A
+3o3x_A
+4jo7_A
+4omf_B
+5a5y_A
+1jwq_A
+2obl_A
+3a3d_B
+1kpt_A
+1md6_A
+4e5r_A
+3owt_B
+2znr_A
+4cjn_A
+4ndt_A
+4ghn_A
+2w6k_A
+1rk8_B
+6ovk_R
+5h6x_A
+5ul3_A
+3jv1_A
+5bs1_B
+3ajf_A
+2p84_A
+2pcn_A
+3zfn_A
+2y20_A
+6ndw_B
+5z51_A
+2w9y_A
+3fh2_A
+1i2t_A
+1vzy_B
+5kwb_A
+3eto_A
+6a9a_A
+2wya_A
+5erq_A
+3l4p_A
+6pce_B
+4v23_A
+3kbb_A
+1w53_A
+7p41_D
+1hbn_E
+1dzf_A
+1h16_A
+6h86_A
+1hq0_A
+3czz_B
+2v05_A
+1im5_A
+1jif_A
+2gpi_A
+1svb_A
+1pcf_C
+2glz_B
+4xdu_A
+4mp8_A
+5a67_A
+5b5i_B
+7jfl_C
+6anz_A
+6iah_A
+1fs1_A
+2bti_A
+5iz3_B
+1gte_D
+3a1g_B
+3ig9_C
+3vej_B
+5dbl_A
+1zsw_A
+1k7j_A
+2j7z_A
+4zi3_D
+3rt4_A
+3zoq_C
+6atg_C
+3fn2_B
+1i4m_A
+1ig0_B
+2bw3_B
+4f98_A
+2x9o_A
+2qdx_A
+5imu_A
+3sfv_B
+4jqu_B
+1krn_A
+4ehx_A
+6y2x_A
+1z7k_B
+7nmq_A
+5ef9_A
+4rp5_A
+4qx8_B
+1msc_A
+1l2w_I
+5w4a_B
+2gwn_A
+6xb3_H
+1zjc_A
+4zav_A
+2nml_A
+2d1l_A
+4a3p_A
+5n0s_A
+1u09_A
+2bo1_A
+1d2t_A
+3ohg_A
+2p58_A
+2yp8_A
+4oj6_C
+1hp1_A
+2gc4_L
+6jwh_A
+2wsi_A
+1z21_A
+2waw_A
+2bk9_A
+4xb6_F
+3rmq_A
+1xkn_A
+1b2s_E
+3wa1_A
+3cyi_A
+1nij_A
+2ahn_A
+3ot9_B
+5x1u_B
+2ed6_A
+2gg6_A
+1dlj_A
+2ad6_A
+3mil_A
+3lbl_A
+2fip_A
+2jek_A
+6l4p_B
+2vfk_A
+1g9g_A
+3x0g_A
+1ng6_A
+1h2e_A
+5n6y_C
+1tov_A
+2cxi_B
+2p2v_A
+1q2h_A
+4erv_A
+1ifg_A
+6jpt_A
+2xjq_A
+2b9d_A
+1t61_E
+1uz3_A
+3bk5_A
+5j3t_B
+4ykd_A
+5zmo_A
+5hee_B
+5ngd_B
+1qwz_A
+2pst_X
+1njh_A
+4zds_B
+6bk4_A
+1qb7_A
+3ib5_A
+1ypy_A
+5g68_A
+2pqx_A
+3luu_A
+3p5p_A
+2au3_A
+3wkx_A
+2v0u_A
+1ka1_A
+1hc9_B
+5vz0_A
+1pby_C
+3tgt_A
+6b1z_A
+3e8t_A
+3rxy_A
+2ims_A
+1g6g_A
+7a66_B
+3obl_B
+3ebh_A
+3twz_A
+2zzd_E
+2ov9_C
+5cof_A
+5eg7_A
+1wpb_A
+1rss_A
+3bjd_B
+6okd_C
+3boe_A
+2inc_A
+1v7r_A
+1fn9_A
+3on9_A
+1koe_A
+3fke_A
+3ino_A
+2bln_B
+2x5r_A
+6eu8_A
+1nf8_A
+1iom_A
+1zn6_A
+2cwy_A
+6mbg_A
+4z9p_A
+16pk_A
+3qc5_X
+6bm5_A
+3l51_B
+4m37_A
+5i8j_A
+6f45_D
+5um3_A
+6in7_A
+2xrh_A
+3pvh_A
+3fkc_A
+3b49_A
+1u8v_C
+1dtd_B
+2a6s_B
+3ueb_E
+2fma_A
+3lpc_A
+7onn_A
+2ehq_A
+6fub_B
+3b4q_B
+3bzn_A
+1qxo_A
+3m66_A
+3mhs_B
+4e1p_A
+4g6d_B
+1y0n_A
+6j33_B
+1v7z_F
+4qb0_A
+1xsz_A
+5mux_B
+5l87_A
+3tvj_I
+5xbc_B
+1ois_A
+6ono_C
+1d4t_A
+2v4n_A
+2pfx_A
+1bkp_A
+3n4j_A
+1pp0_C
+2ad6_D
+4xo1_A
+3o04_A
+5a0n_A
+4gn0_A
+2qmj_A
+3jtz_A
+3cy4_A
+2nnu_A
+3lq9_A
+5cyz_C
+1ocy_A
+6d7y_A
+4gvq_C
+1wq8_A
+1k8k_E
+2fef_B
+1t7v_A
+2y3w_A
+2fp1_A
+1q8f_A
+4exo_A
+3mcb_B
+4myy_A
+1fd3_A
+1lsl_A
+3da8_A
+2hba_A
+1m9z_A
+3im3_A
+3t7h_B
+3shg_B
+2pq8_A
+3enu_A
+2v27_A
+1bgf_A
+1ryl_A
+1nty_A
+3cla_A
+4fc2_B
+1j2l_A
+1u6t_A
+1oh9_A
+3m4r_A
+5ij2_B
+4ngd_A
+3l32_A
+5ce9_A
+2wcu_A
+1fm4_A
+5vap_A
+2g7o_A
+5jx1_A
+2pu9_A
+1c96_A
+2eo4_A
+6odd_B
+1u6r_B
+2z0x_A
+2v4b_B
+2jft_A
+3hol_A
+3l9a_X
+1vbw_A
+4ljo_A
+1tag_A
+4xb6_H
+3imk_A
+2o7o_A
+1tzw_A
+2yxn_A
+2xol_B
+2wpq_C
+2hq4_A
+4y7s_A
+1coj_A
+1kgq_A
+6gfx_C
+2gwm_A
+5dpo_A
+1dvo_A
+2fe5_A
+1vr8_A
+1rrm_A
+3a39_A
+2p9x_D
+2dka_A
+5ert_A
+3it4_C
+3bhg_A
+1m8s_A
+3rlk_A
+2dt8_A
+6p5x_B
+1jnr_C
+3u5v_A
+3ipj_A
+3zrg_B
+1hye_A
+3lyf_D
+2rkv_A
+2yjg_A
+3g98_B
+2pyq_B
+4hwx_A
+3a1u_A
+3ush_B
+3st1_A
+3o79_A
+5bnh_A
+1llf_B
+6tgk_C
+2bjq_A
+3kz7_A
+3fbl_A
+2nsa_A
+1dwk_A
+2at8_X
+2qia_A
+4iaj_B
+2hu9_A
+2e6x_A
+6a02_A
+2yvq_A
+3ju3_A
+4alz_A
+4dja_A
+3ic3_B
+4qxo_A
+1r9f_A
+2yzt_A
+3dff_A
+1wd5_A
+3e9v_A
+1s9u_A
+1wui_L
+2ozl_A
+5wvo_C
+4qbu_A
+5n0n_A
+2vmc_A
+6c0h_A
+7dmn_A
+3gs9_A
+4cbe_A
+1z0b_A
+3i94_A
+1dpt_A
+6dnm_A
+5pal_A
+2gs5_A
+2oeb_A
+1nnh_A
+4v4e_M
+4zac_D
+1ef1_D
+1huf_A
+2ixm_A
+7lp1_A
+3mol_B
+2cx7_A
+3nyy_A
+3ti2_A
+2xmx_A
+2ppv_A
+2q7x_A
+1v0w_A
+6l34_A
+5kvk_A
+5hz7_A
+5h1n_B
+1utg_A
+1m15_A
+3igz_B
+3f4m_A
+7ned_A
+2r2d_A
+7s86_A
+4na5_A
+4jnu_B
+2qtd_A
+3vc8_B
+2amh_A
+5l0s_A
+2rij_A
+2onf_B
+2ie6_A
+3oi8_A
+1bx7_A
+2rjw_A
+4z8q_B
+6l8s_A
+2gc4_B
+2gno_A
+3zdm_A
+3zl8_A
+4p5a_A
+5dmu_A
+2j6b_A
+7bwf_B
+1c1k_A
+4xgl_A
+4k2m_A
+3apr_E
+2z0j_E
+1esw_A
+1zxx_A
+1b8d_K
+3msw_A
+5z42_A
+1qau_A
+1wbe_A
+7aex_A
+1ntv_A
+2pv4_A
+2gh0_D
+4wzg_A
+1yqe_A
+3mm6_A
+6fy5_A
+1p5g_X
+3kde_C
+1whi_A
+4wrp_A
+2hnu_A
+6as3_A
+5w5c_A
+1t8k_A
+3c1q_B
+2w0g_A
+5njc_D
+2qjz_A
+4mb5_A
+3dnh_B
+1c52_A
+4fhr_B
+6d7y_B
+6e7e_A
+2o4t_A
+2hhv_A
+1xeo_A
+1z2u_A
+2b97_A
+5exe_C
+3a2q_A
+1yoe_A
+3qfl_A
+3ijw_A
+2wbf_X
+6bn0_A
+2ycl_A
+4ys0_A
+3bec_A
+5wfy_A
+2bmo_B
+1s9r_A
+7k7p_B
+1ikt_A
+2jh3_B
+4v1g_B
+5oh5_A
+3eye_A
+2f60_K
+1qnx_A
+4h4n_A
+3h0n_A
+4o66_D
+1ig3_A
+1jhs_A
+3agn_A
+7buy_A
+5gw6_A
+2b1y_A
+5v0m_A
+1o4k_A
+6yhu_B
+1r0k_A
+3eqx_A
+2end_A
+1dj0_B
+3mfd_A
+2ia1_A
+4hfq_B
+6hem_A
+2igp_A
+2bbr_A
+1qhv_A
+3gmv_X
+4qn8_B
+2dlb_A
+5z1n_A
+4ke2_C
+2op6_A
+2x5h_B
+3o6p_A
+2wqi_D
+1v9m_A
+5vu3_B
+1t95_A
+3m7d_A
+2xz3_A
+2bfw_A
+6c62_C
+2wb6_A
+4wlr_B
+2ovj_A
+5njl_A
+3q0i_A
+3nl9_A
+1n7z_A
+3fcn_A
+2dwk_A
+3m5q_A
+1lj5_A
+1gnt_A
+5jbg_A
+5m45_K
+5ori_A
+2xf2_A
+3u9g_A
+2yvo_A
+3ldu_A
+2cyj_A
+1dk8_A
+1yu5_X
+2zwd_A
+2pa6_B
+2p6h_B
+4cfi_A
+4j4r_A
+3lk7_A
+5iff_A
+4oie_A
+2cfe_A
+7fd1_A
+1hw1_B
+1v4p_C
+2fzp_A
+4f55_A
+5zlq_A
+5wx9_A
+2vy8_A
+2c9i_D
+2hxm_A
+2au7_A
+4phr_A
+5ggl_A
+6h49_A
+2hng_A
+3nqi_A
+1ail_A
+1vbi_A
+3mqz_A
+4xb6_E
+5gof_A
+2vlf_B
+1j77_A
+4fch_B
+1vyi_A
+7aqx_A
+2rji_A
+1mlw_A
+1g2r_A
+2qq4_B
+2uxy_A
+4z3x_A
+4qmd_A
+3fp5_A
+2rci_A
+2b0t_A
+2y1y_A
+3h8t_A
+7c45_A
+1upt_D
+5ltj_A
+2xeu_A
+1dci_C
+3lyy_A
+4lim_A
+3d3y_A
+5eqz_A
+4koq_A
+4q0y_A
+2gke_A
+3jz9_A
+1mk0_A
+4b6i_D
+1fxo_G
+2qsa_A
+3klr_A
+1ab1_A
+3msu_B
+3fo8_D
+4iu3_B
+1u07_A
+3hvv_A
+4ubp_C
+4s39_A
+6gus_A
+4d71_A
+5ftw_A
+3lgb_B
+2pu3_A
+3fyb_A
+2oxo_A
+3pf6_A
+5klh_A
+6ao8_A
+5g38_A
+3eh1_A
+3tbn_A
+4dlh_B
+6q9c_A
+4dok_A
+5o5s_A
+4lty_A
+3a1g_A
+5c8q_B
+3d1g_A
+5hkx_A
+5uki_A
+7n0j_E
+6o6y_A
+2bf6_A
+2b3y_A
+3v3l_A
+3ga3_A
+6zsl_B
+1xta_A
+3ni2_A
+2w56_A
+4ftf_A
+2ra8_A
+3bwu_D
+7rm7_A
+3g06_A
+4rsw_B
+3n4i_B
+3ikw_A
+1ssq_D
+1oao_C
+3gqh_A
+1zpe_A
+4rr9_A
+5by8_B
+5dje_B
+3dpg_B
+1l5o_A
+5m1m_A
+1yoc_B
+3hfw_A
+4af1_A
+2fb6_A
+3mmh_B
+2xpp_A
+4gv2_A
+1ptq_A
+1whz_A
+1j5u_A
+5j47_A
+2hyv_A
+3e2d_A
+1ojq_A
+4cdp_A
+2imf_A
+3rjs_A
+4pz1_A
+2oxl_B
+2bko_A
+3c8w_C
+3p3g_A
+1lri_A
+2ex2_A
+2rdi_A
+2qwo_B
+2vnl_A
+4a1k_A
+3lpe_B
+2rbk_A
+3ho6_A
+6ypi_A
+1fcz_A
+2i5u_A
+3ofg_B
+5nir_A
+3b9q_A
+1ugp_A
+2po4_A
+3l84_A
+5txr_B
+3c0f_B
+2psp_B
+5w2f_A
+2es9_A
+3lae_A
+2i9w_A
+3vth_A
+2oy9_B
+2ysk_A
+1aol_A
+1by2_A
+1sei_A
+1wlg_B
+1ijt_A
+2ofz_A
+2dt4_A
+2it2_B
+4bgp_A
+1vcl_A
+2aj7_A
+4yal_A
+3frr_A
+6ro6_A
+2y7b_A
+3ezu_A
+1l9l_A
+2oya_B
+7mf4_A
+2e7u_A
+2qp2_A
+4gip_A
+1wf3_A
+3bzl_D
+2ejn_B
+3c05_A
+3k5j_A
+1fk5_A
+4lpq_A
+7jrq_A
+5v7m_A
+2xz2_A
+5e7x_A
+3dso_A
+3fpn_A
+2iay_A
+1u7k_D
+3ka5_A
+3g0m_A
+1ok0_A
+2pw8_I
+2h1t_B
+4o6q_A
+5hub_A
+3pes_A
+5en2_C
+2p3p_B
+5xct_B
+1lfp_A
+4jmu_A
+5v03_B
+2vg3_C
+1ah7_A
+7wab_A
+3hlz_B
+1wna_A
+2vkp_B
+1vkm_C
+1doi_A
+2b7r_A
+3djl_A
+3c9p_A
+5znj_A
+5j2l_B
+3dza_A
+2vac_A
+3hdj_B
+3apa_A
+3bqk_A
+4ksn_C
+2z6r_A
+4ebg_A
+2pbk_B
+1h02_B
+1wq6_A
+3t4r_A
+5btz_A
+1ql0_B
+2h2z_A
+1x6i_A
+1ve0_A
+4xb4_A
+1yu0_A
+1vcc_A
+2pag_A
+3kd4_A
+1zt5_A
+2erl_A
+2r31_A
+5mdu_A
+4y93_A
+1lwd_B
+4hfv_A
+1f5n_A
+3las_B
+2z4u_A
+6pnv_A
+2o74_F
+3vor_A
+2fi9_A
+3pzw_A
+2xse_A
+3bpj_B
+2e8e_A
+1r6w_A
+5jca_S
+1vaj_A
+1uha_A
+6rwt_A
+2ftx_A
+2ccq_A
+2w8x_A
+2nqw_A
+6oz1_A
+3h1w_A
+4csb_A
+3f7w_A
+5ja5_A
+3dai_A
+3mwz_A
+1vlb_A
+5uc0_A
+2gib_B
+3let_B
+1kop_A
+2oez_B
+2nug_B
+1h99_A
+4ue8_B
+3kz5_E
+6nl2_A
+2igi_A
+2vzc_A
+3wx4_A
+5naz_A
+2vm5_A
+1y6i_A
+1rgz_A
+3nis_B
+2qud_A
+3gkj_A
+6e33_A
+4whe_A
+1g7s_A
+1iuq_A
+4gk9_A
+5iqj_C
+4ytw_C
+5exh_C
+5j72_A
+2id6_A
+1xkg_A
+5muy_B
+5noh_A
+3wvz_A
+2jlq_A
+3a0y_B
+1tt8_A
+4h9m_A
+3ju4_A
+5ok6_A
+6p5h_A
+3hxw_A
+2wfr_A
+2w86_A
+2pnw_A
+1cpq_A
+1kew_A
+1gpj_A
+3t92_A
+3ehg_A
+1uxy_A
+6q10_A
+1xty_B
+3mxz_A
+1cuo_A
+5jak_A
+3wur_B
+3wwh_A
+3d4i_A
+4m62_S
+3rkg_A
+3ol3_A
+1puc_A
+3p4h_A
+4lzh_A
+5t2p_B
+1wpu_A
+6sup_A
+1knt_A
+5h2d_A
+6jv8_A
+3crv_A
+1ix9_A
+4wt3_A
+6lrd_A
+6tly_A
+2rku_A
+6idx_A
+3m7a_A
+4oxw_A
+2wlt_A
+2zfz_D
+1sdi_A
+3pt1_A
+3jsy_B
+2g64_A
+2vri_A
+4gsj_A
+4rnz_A
+2cg7_A
+2xf7_A
+1chd_A
+3kvd_D
+4j8c_A
+4kct_B
+4u5h_G
+1qgi_A
+2y4x_B
+4o87_A
+2aib_A
+3f0o_B
+3b4w_A
+2q2g_A
+5ii7_A
+1l2p_A
+1zos_C
+6e5y_A
+3h2d_B
+3nci_A
+3kef_B
+7la6_A
+3txs_C
+1wqj_B
+3lti_A
+3gwi_A
+1xwl_A
+3bqp_B
+1xmk_A
+1nh2_B
+4zmk_A
+4cv7_A
+1ru4_A
+5k6d_B
+2jc5_A
+3dha_A
+2fyg_A
+5x1e_B
+4ct7_A
+2vcc_A
+3lkt_Q
+2f6e_A
+1ikp_A
+4kk7_A
+1b25_D
+1bxy_A
+2z7f_I
+3oxp_A
+5e5y_C
+4o6g_A
+4eo1_A
+3l2a_A
+4npd_A
+3iiu_M
+2xu8_B
+1rl6_A
+3t7z_A
+4b2f_A
+4ea9_A
+4xpz_A
+3dd6_A
+1xkr_A
+2abk_A
+4jgi_A
+5b1q_B
+5w53_B
+6crk_G
\ No newline at end of file
diff --git a/dataset/download.py b/dataset/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fcb989680da18b3cdf9205f149b01834851a73a
--- /dev/null
+++ b/dataset/download.py
@@ -0,0 +1,38 @@
+import os
+import argparse
+import multiprocessing as mp
+
+
+def fn(para):
+ file, url_base, data_dir = para
+ url = url_base + file
+ output_filename = file+".zip"
+ output_path = os.path.join(data_dir, output_filename)
+ unzip_path = os.path.join(data_dir, file)
+ os.system(f"curl -X GET {url} -H accept: */* --output {output_path}")
+ os.system(f"unzip {output_path} -d {unzip_path}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--filename_file", type=str, default="./dataset/ATLAS_filename.txt")
+ parser.add_argument("--url_base", type=str, default="https://www.dsimb.inserm.fr/ATLAS/api/ATLAS/analysis/")
+ parser.add_argument("--output_dir", type=str, default="./dataset/ATLAS_test")
+
+ args = parser.parse_args()
+
+ num_processes = 48
+
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ with open(args.filename_file,'r+') as f:
+ file_cont = f.read()
+ file_list = file_cont.split("\n")
+ para_list = [(file, args.url_base, args.output_dir) for file in file_list]
+
+ with mp.Pool(num_processes) as pool:
+ _ = pool.map(fn, para_list)
+ print("finished")
+
+ os.system(f"cp -rf {args.filename_file} {args.output_dir}")
diff --git a/dataset/traj_analyse_select.py b/dataset/traj_analyse_select.py
new file mode 100644
index 0000000000000000000000000000000000000000..a91fdd302784ecb307ceb44239d58d19dffa072c
--- /dev/null
+++ b/dataset/traj_analyse_select.py
@@ -0,0 +1,161 @@
+import os
+import argparse
+import shutil
+
+import pandas as pd
+import MDAnalysis as mda
+import numpy as np
+import matplotlib.pyplot as plt
+import multiprocessing as mp
+
+from scipy.stats import gaussian_kde
+from MDAnalysis.analysis import align
+
+
+
+def cal_energy(para1):
+ file_md, dirpath = para1
+ mdpath = os.path.join(dirpath, file_md)
+ filename = file_md
+
+ k=2.32*1e-4 # unit(eV/K)
+ T=298.15 # unit(K)
+
+ pdb_filepath = os.path.join(mdpath, filename+".pdb")
+ topology_filepath = os.path.join(mdpath, filename+".pdb")
+
+ u_ref = mda.Universe(pdb_filepath)
+ protein_ref = u_ref.select_atoms('protein')
+ bb_atom_ref = protein_ref.select_atoms('name CA or name C or name N')
+
+ info = {
+ 'rad_gyr': [],
+ 'rmsd_ref':[],
+ 'traj_filename':[],
+ 'energy':[],
+ }
+
+ for xtc_idx in range(1,4):
+ trajectory_filepath = os.path.join(mdpath,filename+"_R"+str(xtc_idx)+".xtc")
+
+ u = mda.Universe(topology_filepath, trajectory_filepath)
+
+ protein = u.select_atoms('protein')
+ bb_atom = protein.select_atoms('name CA or name C or name N')
+
+ # CA_atoms = u.select_atoms('name CA')
+ # bb_atoms = u.select_atoms('backbone')
+
+ count = 0
+ # for ts in u.trajectory:
+ for _ in u.trajectory:
+ count += 1
+
+ rad_gyr = bb_atom.radius_of_gyration()
+ rmsd_ref = align.alignto(bb_atom, bb_atom_ref, select='all', match_atoms=False)[-1]
+ info['rad_gyr'].append(rad_gyr)
+ info['rmsd_ref'].append(rmsd_ref)
+
+ traj_filename = filename + '_R' + str(xtc_idx) + '_'+str(count)+".pdb"
+ info['traj_filename'].append(traj_filename)
+ print(traj_filename)
+ protein.write(os.path.join(mdpath, traj_filename))
+
+
+ info_array = np.stack([info['rad_gyr'],info['rmsd_ref']],axis=0) # (2,2500)
+ kde = gaussian_kde(info_array)
+ density = kde(info_array) # (2500,)
+ G = k*T*np.log(np.max(density)/density) # (2500,)
+ G = (G-np.min(G))/(np.max(G)-np.min(G))
+
+ info['energy'] += G.tolist()
+
+ out_total = pd.DataFrame(info)
+ x, y = np.meshgrid(np.linspace(min(out_total['rad_gyr'])-0.25, max(out_total['rad_gyr'])+0.25, 200),
+ np.linspace(min(out_total['rmsd_ref'])-0.25, max(out_total['rmsd_ref'])+0.25, 200))
+ grid_coordinates = np.vstack([x.ravel(), y.ravel()])
+ density_values = kde(grid_coordinates)
+ # 将密度值变形为与网格坐标相同的形状
+ density_map = density_values.reshape(x.shape)
+ # 绘制高斯核密度估计图
+ plt.contourf(x, y, density_map, levels= np.arange(np.max(density_map)/20, np.max(density_map)*1.1, np.max(density_map)/10))
+ plt.colorbar()
+
+ plt.savefig(os.path.join(mdpath,"md.png"))
+ plt.close()
+
+ out_total.to_csv(os.path.join(mdpath,"traj_info.csv"),index=False)
+
+
+def select_str(file, data_dir, output_dir, select_num=100):
+ info_total = {
+ 'rad_gyr': [],
+ 'rmsd_ref': [],
+ 'traj_filename': [],
+ 'energy': [],
+ }
+
+ print(f"Processing {file}")
+ md_dir = os.path.join(data_dir, file)
+ md_csv = pd.read_csv(os.path.join(md_dir, 'traj_info.csv'))
+ md_csv = md_csv.sort_values('energy', ascending=True)
+
+ idx_total = np.linspace(0, len(md_csv) - 1, select_num)
+ idx_total = (idx_total / idx_total[-1]) ** (1 / 3) * (len(md_csv) - 1) # mapping energy with f(x)=x**(1/3) to get more relatively high-energy structure
+ idx_total = np.unique(np.round(idx_total).astype(int))
+
+ for idx in idx_total:
+ info = md_csv.iloc[idx]
+ traj_filename = info['traj_filename']
+ shutil.copy(os.path.join(md_dir, traj_filename), output_dir)
+
+ info_total['traj_filename'].append(traj_filename)
+ info_total['energy'].append(info['energy'])
+ info_total['rad_gyr'].append(info['rad_gyr'])
+ info_total['rmsd_ref'].append(info['rmsd_ref'])
+
+ return info_total
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--dir_path", type=str, default="./dataset/ATLAS")
+ parser.add_argument("--filename", type=str, default="ATLAS_filename.txt")
+
+ parser.add_argument("--select_num", type=int, default=100)
+ parser.add_argument("--select_dir", type=str, default="./dataset/ATLAS/select")
+
+ args = parser.parse_args()
+
+ num_processes = 48
+
+ file_txt = os.path.join(args.dir_path, args.filename)
+ os.makedirs(args.select_dir, exist_ok=True)
+
+ with open(file_txt,'r+') as f:
+ file_cont = f.read()
+ file_list = file_cont.split("\n")
+
+ para1_list = [(file, args.dir_path) for file in file_list]
+ para2_list = [(file, args.dir_path, args.select_dir, args.select_num) for file in file_list]
+
+ info_total_all = {
+ 'rad_gyr': [],
+ 'rmsd_ref': [],
+ 'traj_filename': [],
+ 'energy': [],
+ }
+
+ with mp.Pool(num_processes) as pool:
+ _ = pool.map(cal_energy, para1_list)
+ results = pool.starmap(select_str, para2_list)
+
+ for result in results:
+ info_total_all['traj_filename'].extend(result['traj_filename'])
+ info_total_all['energy'].extend(result['energy'])
+ info_total_all['rad_gyr'].extend(result['rad_gyr'])
+ info_total_all['rmsd_ref'].extend(result['rmsd_ref'])
+
+ df = pd.DataFrame(info_total_all)
+ df.to_csv(os.path.join(args.select_dir, 'traj_info_select.csv'), index=False)
\ No newline at end of file
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..596341a039e2dacbe8f5b3db42942328610cf71b
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,481 @@
+name: P2DFlow
+channels:
+ - pytorch
+ - nvidia
+ - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
+ - conda-forge
+dependencies:
+ - pip
+ - _libgcc_mutex=0.1=conda_forge
+ - _openmp_mutex=4.5=2_gnu
+ - anyio=3.7.1=pyhd8ed1ab_0
+ - argon2-cffi=21.3.0=pyhd8ed1ab_0
+ - argon2-cffi-bindings=21.2.0=py310h5764c6d_3
+ - arrow=1.2.3=pyhd8ed1ab_0
+ - asttokens=2.2.1=pyhd8ed1ab_0
+ - astunparse=1.6.3=pyhd8ed1ab_0
+ - async-lru=2.0.4=pyhd8ed1ab_0
+ - attrs=23.1.0=pyh71513ae_1
+ - babel=2.12.1=pyhd8ed1ab_1
+ - backcall=0.2.0=pyh9f0ad1d_0
+ - backports=1.0=pyhd8ed1ab_3
+ - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0
+ - beautifulsoup4=4.12.2=pyha770c72_0
+ - biopython=1.81=py310h1fa729e_0
+ - bleach=6.0.0=pyhd8ed1ab_0
+ - blosc=1.21.4=h0f2a231_0
+ - boltons=23.1.1=pyhd8ed1ab_0
+ - boost=1.78.0=py310hcb52e73_5
+ - boost-cpp=1.78.0=h2c5509c_4
+ - brotli=1.0.9=h166bdaf_9
+ - brotli-bin=1.0.9=h166bdaf_9
+ - brotli-python=1.0.9=py310hd8f1fbe_9
+ - bzip2=1.0.8=h7f98852_4
+ - c-ares=1.19.1=hd590300_0
+ - c-blosc2=2.10.2=hb4ffafa_0
+ - ca-certificates=2024.2.2=hbcca054_0
+ - cached-property=1.5.2=hd8ed1ab_1
+ - cached_property=1.5.2=pyha770c72_1
+ - cairo=1.18.0=h3faef2a_0
+ - certifi=2024.2.2=pyhd8ed1ab_0
+ - cffi=1.15.1=py310h255011f_3
+ - cftime=1.6.3=py310h1f7b6fc_0
+ - chardet=5.2.0=py310hff52083_1
+ - charset-normalizer=3.2.0=pyhd8ed1ab_0
+ - colorama=0.4.6=pyhd8ed1ab_0
+ - comm=0.1.4=pyhd8ed1ab_0
+ - conda=23.7.4=py310hff52083_0
+ - conda-package-handling=2.2.0=pyh38be061_0
+ - conda-package-streaming=0.9.0=pyhd8ed1ab_0
+ - contourpy=1.1.0=py310hd41b1e2_0
+ - cryptography=41.0.7=py310hb8475ec_1
+ - cuda=11.6.0=0
+ - cuda-cccl=11.6.55=hf6102b2_0
+ - cuda-command-line-tools=11.6.2=0
+ - cuda-compiler=11.6.2=0
+ - cuda-cudart=11.6.55=he381448_0
+ - cuda-cudart-dev=11.6.55=h42ad0f4_0
+ - cuda-cuobjdump=11.6.124=h2eeebcb_0
+ - cuda-cupti=11.6.124=h86345e5_0
+ - cuda-cuxxfilt=11.6.124=hecbf4f6_0
+ - cuda-driver-dev=11.6.55=0
+ - cuda-gdb=12.0.90=hd47b8d6_0
+ - cuda-libraries=11.6.2=0
+ - cuda-libraries-dev=11.6.0=0
+ - cuda-memcheck=11.8.86=0
+ - cuda-nsight=12.0.78=ha770c72_0
+ - cuda-nsight-compute=12.2.1=0
+ - cuda-nvcc=11.6.124=hbba6d2d_0
+ - cuda-nvdisasm=12.0.76=h59595ed_0
+ - cuda-nvml-dev=11.6.55=haa9ef22_0
+ - cuda-nvprof=12.0.90=h59595ed_0
+ - cuda-nvprune=11.6.124=he22ec0a_0
+ - cuda-nvrtc=11.6.124=h020bade_0
+ - cuda-nvrtc-dev=11.6.124=h249d397_0
+ - cuda-nvtx=11.6.124=h0630a44_0
+ - cuda-nvvp=12.0.90=h59595ed_0
+ - cuda-runtime=11.6.2=0
+ - cuda-samples=11.6.101=h8efea70_0
+ - cuda-sanitizer-api=12.0.90=h59595ed_0
+ - cuda-toolkit=11.6.0=0
+ - cuda-tools=11.6.0=0
+ - cuda-version=12.0=hffde075_2
+ - cuda-visual-tools=11.6.0=0
+ - cycler=0.11.0=pyhd8ed1ab_0
+ - debugpy=1.6.8=py310hc6cd4ac_0
+ - decorator=5.1.1=pyhd8ed1ab_0
+ - defusedxml=0.7.1=pyhd8ed1ab_0
+ - entrypoints=0.4=pyhd8ed1ab_0
+ - exceptiongroup=1.1.3=pyhd8ed1ab_0
+ - executing=1.2.0=pyhd8ed1ab_0
+ - expat=2.6.2=h59595ed_0
+ - fasteners=0.17.3=pyhd8ed1ab_0
+ - flit-core=3.9.0=pyhd8ed1ab_0
+ - fmt=10.1.1=h00ab1b0_1
+ - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
+ - font-ttf-inconsolata=3.000=h77eed37_0
+ - font-ttf-source-code-pro=2.038=h77eed37_0
+ - font-ttf-ubuntu=0.83=h77eed37_1
+ - fontconfig=2.14.2=h14ed4e7_0
+ - fonts-conda-ecosystem=1=0
+ - fonts-conda-forge=1=0
+ - fonttools=4.42.0=py310h2372a71_0
+ - fqdn=1.5.1=pyhd8ed1ab_0
+ - freetype=2.12.1=hca18f0e_1
+ - freetype-py=2.3.0=pyhd8ed1ab_0
+ - gds-tools=1.5.0.59=hcb278e6_0
+ - gmp=6.2.1=h58526e2_0
+ - greenlet=3.0.3=py310hc6cd4ac_0
+ - griddataformats=1.0.2=pyhd8ed1ab_0
+ - gsd=3.2.0=py310h1f7b6fc_0
+ - h5py=3.9.0=nompi_py310hcca72df_101
+ - hdf4=4.2.15=h501b40f_6
+ - hdf5=1.14.1=nompi_h4f84152_100
+ - icu=73.2=h59595ed_0
+ - idna=3.4=pyhd8ed1ab_0
+ - importlib-metadata=6.8.0=pyha770c72_0
+ - importlib_metadata=6.8.0=hd8ed1ab_0
+ - importlib_resources=6.0.1=pyhd8ed1ab_0
+ - ipykernel=6.25.1=pyh71e2992_0
+ - ipython=8.14.0=pyh41d4057_0
+ - ipywidgets=8.1.1=pyhd8ed1ab_0
+ - isoduration=20.11.0=pyhd8ed1ab_0
+ - jedi=0.19.0=pyhd8ed1ab_0
+ - jinja2=3.1.2=pyhd8ed1ab_1
+ - joblib=1.3.2=pyhd8ed1ab_0
+ - json5=0.9.14=pyhd8ed1ab_0
+ - jsonpatch=1.33=pyhd8ed1ab_0
+ - jsonpointer=2.0=py_0
+ - jsonschema=4.19.0=pyhd8ed1ab_1
+ - jsonschema-specifications=2023.7.1=pyhd8ed1ab_0
+ - jsonschema-with-format-nongpl=4.19.0=pyhd8ed1ab_1
+ - jupyter-lsp=2.2.0=pyhd8ed1ab_0
+ - jupyter_client=8.3.0=pyhd8ed1ab_0
+ - jupyter_core=5.3.1=py310hff52083_0
+ - jupyter_events=0.7.0=pyhd8ed1ab_2
+ - jupyter_server=2.7.1=pyhd8ed1ab_0
+ - jupyter_server_terminals=0.4.4=pyhd8ed1ab_1
+ - jupyterlab=4.0.5=pyhd8ed1ab_0
+ - jupyterlab_pygments=0.2.2=pyhd8ed1ab_0
+ - jupyterlab_server=2.24.0=pyhd8ed1ab_0
+ - jupyterlab_widgets=3.0.9=pyhd8ed1ab_0
+ - keyutils=1.6.1=h166bdaf_0
+ - kiwisolver=1.4.4=py310hbf28c38_1
+ - krb5=1.21.2=h659d440_0
+ - lcms2=2.15=haa2dc70_1
+ - ld_impl_linux-64=2.40=h41732ed_0
+ - lerc=4.0.0=h27087fc_0
+ - libaec=1.0.6=hcb278e6_1
+ - libarchive=3.6.2=h039dbb9_1
+ - libblas=3.9.0=17_linux64_openblas
+ - libbrotlicommon=1.0.9=h166bdaf_9
+ - libbrotlidec=1.0.9=h166bdaf_9
+ - libbrotlienc=1.0.9=h166bdaf_9
+ - libcblas=3.9.0=17_linux64_openblas
+ - libcublas=12.0.1.189=hcb278e6_2
+ - libcublas-dev=12.0.1.189=hcb278e6_2
+ - libcufft=11.0.0.21=hcb278e6_1
+ - libcufft-dev=11.0.0.21=hcb278e6_1
+ - libcufile=1.5.0.59=hcb278e6_0
+ - libcufile-dev=1.5.0.59=hcb278e6_0
+ - libcurand=10.3.1.50=hcb278e6_0
+ - libcurand-dev=10.3.1.50=hcb278e6_0
+ - libcurl=8.2.1=hca28451_0
+ - libcusolver=11.4.2.57=hcb278e6_1
+ - libcusparse=12.0.0.76=hcb278e6_1
+ - libdeflate=1.18=h0b41bf4_0
+ - libedit=3.1.20191231=he28a2e2_2
+ - libev=4.33=h516909a_1
+ - libexpat=2.6.2=h59595ed_0
+ - libffi=3.4.2=h7f98852_5
+ - libgcc-ng=13.1.0=he5830b7_0
+ - libgfortran-ng=13.1.0=h69a702a_0
+ - libgfortran5=13.1.0=h15d22d2_0
+ - libglib=2.80.0=hf2295e7_5
+ - libgomp=13.1.0=he5830b7_0
+ - libiconv=1.17=hd590300_2
+ - libjpeg-turbo=2.1.5.1=h0b41bf4_0
+ - liblapack=3.9.0=17_linux64_openblas
+ - libmamba=1.5.1=h744094f_0
+ - libmambapy=1.5.1=py310h39ff949_0
+ - libnetcdf=4.9.2=nompi_he09a3a9_107
+ - libnghttp2=1.52.0=h61bc06f_0
+ - libnpp=12.0.0.30=h59595ed_0
+ - libnpp-dev=12.0.0.30=h59595ed_0
+ - libnsl=2.0.0=h7f98852_0
+ - libnuma=2.0.16=h0b41bf4_1
+ - libnvjitlink=12.0.76=hcb278e6_1
+ - libnvjpeg=12.0.0.28=hcb278e6_0
+ - libnvjpeg-dev=12.0.0.28=ha770c72_0
+ - libopenblas=0.3.23=pthreads_h80387f5_0
+ - libpng=1.6.39=h753d276_0
+ - libsodium=1.0.18=h36c2ea0_1
+ - libsolv=0.7.27=hfc55251_0
+ - libsqlite=3.42.0=h2797004_0
+ - libssh2=1.11.0=h0841786_0
+ - libstdcxx-ng=13.1.0=hfd8a6a1_0
+ - libtiff=4.5.1=h8b53f26_0
+ - libuuid=2.38.1=h0b41bf4_0
+ - libwebp-base=1.3.1=hd590300_0
+ - libxcb=1.15=h0b41bf4_0
+ - libxml2=2.12.4=h232c23b_1
+ - libzip=1.10.1=h2629f0a_3
+ - libzlib=1.2.13=hd590300_5
+ - lz4-c=1.9.4=hcb278e6_0
+ - lzo=2.10=h516909a_1000
+ - mamba=1.5.1=py310h51d5547_0
+ - markupsafe=2.1.3=py310h2372a71_0
+ - matplotlib-base=3.7.2=py310hf38f957_0
+ - matplotlib-inline=0.1.6=pyhd8ed1ab_0
+ - mda-xdrlib=0.2.0=pyhd8ed1ab_0
+ - mdanalysis=2.7.0=py310hcc13569_1
+ - mdtraj=1.9.9=py310h8e08b51_0
+ - mistune=3.0.1=pyhd8ed1ab_0
+ - mmtf-python=1.1.3=pyhd8ed1ab_0
+ - mrcfile=1.5.0=pyhd8ed1ab_0
+ - msgpack-python=1.0.7=py310hd41b1e2_0
+ - munkres=1.1.4=pyh9f0ad1d_0
+ - nbclient=0.8.0=pyhd8ed1ab_0
+ - nbconvert-core=7.7.4=pyhd8ed1ab_0
+ - nbformat=5.9.2=pyhd8ed1ab_0
+ - ncurses=6.4=hcb278e6_0
+ - nest-asyncio=1.5.6=pyhd8ed1ab_0
+ - netcdf4=1.6.4=nompi_py310h6f5dce6_101
+ - nglview=3.1.1=pyh15ce09e_0
+ - nomkl=1.0=h5ca1d4c_0
+ - notebook-shim=0.2.3=pyhd8ed1ab_0
+ - nsight-compute=2023.2.1.3=0
+ - numexpr=2.8.4=py310hd91493a_101
+ - numpy=1.25.2=py310ha4c1d20_0
+ - openjpeg=2.5.0=hfec8fc6_2
+ - openssl=3.2.1=hd590300_1
+ - overrides=7.4.0=pyhd8ed1ab_0
+ - packaging=23.1=pyhd8ed1ab_0
+ - pandas=2.0.3=py310h7cbd5c2_1
+ - pandocfilters=1.5.0=pyhd8ed1ab_0
+ - parso=0.8.3=pyhd8ed1ab_0
+ - patsy=0.5.3=pyhd8ed1ab_0
+ - pcre2=10.43=hcad00b1_0
+ - pexpect=4.8.0=pyh1a96a4e_2
+ - pickleshare=0.7.5=py_1003
+ - pillow=10.0.0=py310h582fbeb_0
+ - pixman=0.43.2=h59595ed_0
+ - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
+ - platformdirs=3.10.0=pyhd8ed1ab_0
+ - pluggy=1.3.0=pyhd8ed1ab_0
+ - pooch=1.7.0=pyha770c72_3
+ - prometheus_client=0.17.1=pyhd8ed1ab_0
+ - prompt-toolkit=3.0.39=pyha770c72_0
+ - prompt_toolkit=3.0.39=hd8ed1ab_0
+ - psutil=5.9.5=py310h1fa729e_0
+ - pthread-stubs=0.4=h36c2ea0_1001
+ - ptyprocess=0.7.0=pyhd3deb0d_0
+ - pure_eval=0.2.2=pyhd8ed1ab_0
+ - py-cpuinfo=9.0.0=pyhd8ed1ab_0
+ - pybind11-abi=4=hd8ed1ab_3
+ - pycairo=1.26.0=py310hda9f760_0
+ - pycosat=0.6.6=py310h2372a71_0
+ - pycparser=2.21=pyhd8ed1ab_0
+ - pyedr=0.8.0=pyhd8ed1ab_0
+ - pygments=2.16.1=pyhd8ed1ab_0
+ - pyopenssl=23.3.0=pyhd8ed1ab_0
+ - pyparsing=3.0.9=pyhd8ed1ab_0
+ - pysocks=1.7.1=pyha2e5f31_6
+ - pytables=3.8.0=py310ha028ce3_2
+ - python=3.10.12=hd12c33a_0_cpython
+ - python-dateutil=2.8.2=pyhd8ed1ab_0
+ - python-fastjsonschema=2.18.0=pyhd8ed1ab_0
+ - python-json-logger=2.0.7=pyhd8ed1ab_0
+ - python-tzdata=2023.3=pyhd8ed1ab_0
+ - python_abi=3.10=3_cp310
+ - pytng=0.3.1=py310h7d11597_1
+ - pytorch-cuda=11.6=h867d48c_0
+ - pytz=2023.3=pyhd8ed1ab_0
+ - pyyaml=6.0=py310h5764c6d_5
+ - pyzmq=25.1.1=py310h5bbb5d0_0
+ - rdkit=2023.03.3=py310h399bcf7_0
+ - readline=8.2=h8228510_1
+ - referencing=0.30.2=pyhd8ed1ab_0
+ - reportlab=4.1.0=py310h2372a71_0
+ - reproc=14.2.4.post0=hd590300_1
+ - reproc-cpp=14.2.4.post0=h59595ed_1
+ - requests=2.31.0=pyhd8ed1ab_0
+ - rfc3339-validator=0.1.4=pyhd8ed1ab_0
+ - rfc3986-validator=0.1.1=pyh9f0ad1d_0
+ - rlpycairo=0.2.0=pyhd8ed1ab_0
+ - rpds-py=0.9.2=py310hcb5633a_0
+ - ruamel.yaml=0.17.40=py310h2372a71_0
+ - ruamel.yaml.clib=0.2.7=py310h2372a71_2
+ - scikit-learn=1.3.2=py310h1fdf081_2
+ - scipy=1.11.1=py310ha4c1d20_0
+ - seaborn=0.12.2=hd8ed1ab_0
+ - seaborn-base=0.12.2=pyhd8ed1ab_0
+ - send2trash=1.8.2=pyh41d4057_0
+ - setuptools=68.1.2=pyhd8ed1ab_0
+ - six=1.16.0=pyh6c4a22f_0
+ - snappy=1.1.10=h9fff704_0
+ - sniffio=1.3.0=pyhd8ed1ab_0
+ - soupsieve=2.3.2.post1=pyhd8ed1ab_0
+ - sqlalchemy=2.0.29=py310h2372a71_0
+ - stack_data=0.6.2=pyhd8ed1ab_0
+ - statsmodels=0.14.0=py310h278f3c1_1
+ - terminado=0.17.1=pyh41d4057_0
+ - threadpoolctl=3.2.0=pyha21a80b_0
+ - tidynamics=1.1.2=pyhd8ed1ab_0
+ - tinycss2=1.2.1=pyhd8ed1ab_0
+ - tk=8.6.12=h27826a3_0
+ - tomli=2.0.1=pyhd8ed1ab_0
+ - toolz=0.12.0=pyhd8ed1ab_0
+ - tornado=6.3.3=py310h2372a71_0
+ - tqdm=4.66.1=pyhd8ed1ab_0
+ - traitlets=5.9.0=pyhd8ed1ab_0
+ - typing-extensions=4.7.1=hd8ed1ab_0
+ - typing_extensions=4.7.1=pyha770c72_0
+ - typing_utils=0.1.0=pyhd8ed1ab_0
+ - tzdata=2023c=h71feb2d_0
+ - unicodedata2=15.0.0=py310h5764c6d_0
+ - uri-template=1.3.0=pyhd8ed1ab_0
+ - urllib3=2.0.4=pyhd8ed1ab_0
+ - wcwidth=0.2.6=pyhd8ed1ab_0
+ - webcolors=1.13=pyhd8ed1ab_0
+ - webencodings=0.5.1=py_1
+ - websocket-client=1.6.1=pyhd8ed1ab_0
+ - wheel=0.41.1=pyhd8ed1ab_0
+ - widgetsnbextension=4.0.9=pyhd8ed1ab_0
+ - xorg-kbproto=1.0.7=h7f98852_1002
+ - xorg-libice=1.1.1=hd590300_0
+ - xorg-libsm=1.2.4=h7391055_0
+ - xorg-libx11=1.8.9=h8ee46fc_0
+ - xorg-libxau=1.0.11=hd590300_0
+ - xorg-libxdmcp=1.1.3=h7f98852_0
+ - xorg-libxext=1.3.4=h0b41bf4_2
+ - xorg-libxrender=0.9.11=hd590300_0
+ - xorg-renderproto=0.11.1=h7f98852_1002
+ - xorg-xextproto=7.3.0=h0b41bf4_1003
+ - xorg-xproto=7.0.31=h7f98852_1007
+ - xz=5.2.6=h166bdaf_0
+ - yaml=0.2.5=h7f98852_2
+ - yaml-cpp=0.7.0=h59595ed_3
+ - zeromq=4.3.4=h9c3ff4c_1
+ - zipp=3.16.2=pyhd8ed1ab_0
+ - zlib=1.2.13=hd590300_5
+ - zlib-ng=2.0.7=h0b41bf4_0
+ - zstandard=0.19.0=py310h1275a96_2
+ - zstd=1.5.2=hfc55251_7
+ - pip:
+ - absl-py==1.4.0
+ - aiohttp==3.8.5
+ - aiosignal==1.3.1
+ - alabaster==0.7.16
+ - annotated-types==0.5.0
+ - antlr4-python3-runtime==4.9.3
+ - appdirs==1.4.4
+ - async-timeout==4.0.3
+ - backoff==2.2.1
+ - biotite==0.37.0
+ - blessed==1.20.0
+ - cachetools==5.3.1
+ - click==8.1.7
+ - cmake==3.27.2
+ - contextlib2==21.6.0
+ - croniter==1.4.1
+ - dateutils==0.6.12
+ - deepdiff==6.3.1
+ - deepspeed==0.12.3
+ - dm-tree==0.1.8
+ - docker-pycreds==0.4.0
+ - docutils==0.21.2
+ - e3nn==0.5.1
+ - einops==0.6.1
+ - fair-esm==2.0.0
+ - fastapi==0.101.1
+ - filelock==3.12.2
+ - frozenlist==1.4.0
+ - fsspec==2023.6.0
+ - gitdb==4.0.10
+ - gitpython==3.1.32
+ - gpustat==1.1
+ - gputil==1.4.0
+ - h11==0.14.0
+ - hjson==3.1.0
+ - hydra-core==1.3.2
+ - imagesize==1.4.1
+ - inquirer==3.1.3
+ - ipdb==0.13.13
+ - ipython-genutils==0.2.0
+ - itsdangerous==2.1.2
+ - jupyter==1.0.0
+ - jupyter-console==6.6.3
+ - jupyter-contrib-core==0.4.2
+ - jupyter-contrib-nbextensions==0.7.0
+ - jupyter-highlight-selected-word==0.2.0
+ - jupyter-nbextensions-configurator==0.6.3
+ - lightning==2.0.7
+ - lightning-cloud==0.5.37
+ - lightning-utilities==0.9.0
+ - lit==16.0.6
+ - lxml==5.1.0
+ - markdown-it-py==3.0.0
+ - mdurl==0.1.2
+ - mizani==0.9.3
+ - ml-collections==0.1.1
+ - mpmath==1.3.0
+ - msgpack==1.0.5
+ - multidict==6.0.4
+ - networkx==3.1
+ - ninja==1.11.1
+ - notebook==7.0.7
+ - numpydoc==1.7.0
+ - nvidia-cublas-cu11==11.10.3.66
+ - nvidia-cuda-cupti-cu11==11.7.101
+ - nvidia-cuda-nvrtc-cu11==11.7.99
+ - nvidia-cuda-runtime-cu11==11.7.99
+ - nvidia-cudnn-cu11==8.5.0.96
+ - nvidia-cufft-cu11==10.9.0.58
+ - nvidia-curand-cu11==10.2.10.91
+ - nvidia-cusolver-cu11==11.4.0.1
+ - nvidia-cusparse-cu11==11.7.4.91
+ - nvidia-ml-py==12.535.77
+ - nvidia-nccl-cu11==2.14.3
+ - nvidia-nvtx-cu11==11.7.91
+ - nvitop==1.3.0
+ - oddt==0.7
+ - omegaconf==2.3.0
+ - opt-einsum==3.3.0
+ - opt-einsum-fx==0.1.4
+ - ordered-set==4.1.0
+ - pathtools==0.1.2
+ - pip==24.1.2
+ - plotly==5.16.1
+ - plotnine==0.12.3
+ - protobuf==4.24.1
+ - pydantic==2.1.1
+ - pydantic-core==2.4.0
+ - pyg-lib==0.4.0+pt20cu117
+ - -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
+ - pyjwt==2.8.0
+ - pynvml==11.5.0
+ - python-editor==1.0.4
+ - python-multipart==0.0.6
+ - pytorch-lightning==2.0.7
+ - qtconsole==5.5.1
+ - qtpy==2.4.1
+ - readchar==4.0.5
+ - rich==13.5.2
+ - sentry-sdk==1.29.2
+ - setproctitle==1.3.2
+ - smmap==5.0.0
+ - snakeviz==2.2.0
+ - snowballstemmer==2.2.0
+ - sphinx==7.3.7
+ - sphinxcontrib-applehelp==1.0.8
+ - sphinxcontrib-devhelp==1.0.6
+ - sphinxcontrib-htmlhelp==2.0.5
+ - sphinxcontrib-jsmath==1.0.1
+ - sphinxcontrib-qthelp==1.0.7
+ - sphinxcontrib-serializinghtml==1.1.10
+ - starlette==0.27.0
+ - starsessions==1.3.0
+ - swig==4.2.1
+ - sympy==1.12
+ - tabulate==0.9.0
+ - tenacity==8.2.3
+ - termcolor==2.3.0
+ - tmtools==0.0.3
+ - torch==2.0.1
+ - torch-cluster==1.6.3+pt20cu117
+ - -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
+ - torch-geometric==2.5.2
+ - -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
+ - torch-scatter==2.1.2+pt20cu117
+ - -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
+ - torch-sparse==0.6.18+pt20cu117
+ - -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
+ - torch-spline-conv==1.2.2+pt20cu117
+ - -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
+ - torchmetrics==1.0.3
+ - triton==2.0.0
+ - uvicorn==0.23.2
+ - wandb==0.15.8
+ - websockets==11.0.3
+ - yarl==1.9.2
diff --git a/experiments/__pycache__/utils.cpython-310.pyc b/experiments/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c399f62acbebf8b735f941099a0491a7fcbd9aa5
Binary files /dev/null and b/experiments/__pycache__/utils.cpython-310.pyc differ
diff --git a/experiments/inference_se3_flows.py b/experiments/inference_se3_flows.py
new file mode 100644
index 0000000000000000000000000000000000000000..7505c143e0aef78d3ff86b72ead8e664cfdb2389
--- /dev/null
+++ b/experiments/inference_se3_flows.py
@@ -0,0 +1,119 @@
+"""DDP inference script."""
+import os
+import time
+import numpy as np
+import hydra
+import torch
+import GPUtil
+import sys
+
+from pytorch_lightning import Trainer
+from omegaconf import DictConfig, OmegaConf
+from experiments import utils as eu
+from models.flow_module import FlowModule
+import re
+from typing import Optional
+import subprocess
+from biotite.sequence.io import fasta
+from data import utils as du
+from analysis import metrics
+import pandas as pd
+import esm
+import shutil
+import biotite.structure.io as bsio
+
+
+torch.set_float32_matmul_precision('high')
+log = eu.get_pylogger(__name__)
+
+class Sampler:
+
+ def __init__(self, cfg: DictConfig):
+ """Initialize sampler.
+
+ Args:
+ cfg: inference config.
+ """
+ ckpt_path = cfg.inference.ckpt_path
+ ckpt_dir = os.path.dirname(ckpt_path)
+ # ckpt_cfg = OmegaConf.load(os.path.join(ckpt_dir, 'config.yaml'))
+ # ckpt_cfg = torch.load(ckpt_path, map_location="cpu")['hyper_parameters']['cfg']
+
+ # Set-up config.
+ OmegaConf.set_struct(cfg, False)
+ # OmegaConf.set_struct(ckpt_cfg, False)
+ # cfg = OmegaConf.merge(cfg, ckpt_cfg)
+ # cfg.experiment.checkpointer.dirpath = './'
+
+ self._cfg = cfg
+ # self._pmpnn_dir = cfg.inference.pmpnn_dir
+ self._infer_cfg = cfg.inference
+ self._samples_cfg = self._infer_cfg.samples
+ # self._rng = np.random.default_rng(self._infer_cfg.seed)
+
+ # Set-up directories to write results to
+ self._ckpt_name = '/'.join(ckpt_path.replace('.ckpt', '').split('/')[-3:])
+ self._output_dir = os.path.join(
+ self._infer_cfg.output_dir,
+ self._ckpt_name,
+ self._infer_cfg.name,
+ )
+ os.makedirs(self._output_dir, exist_ok=True)
+ log.info(f'Saving results to {self._output_dir}')
+ config_path = os.path.join(self._output_dir, 'config.yaml')
+ with open(config_path, 'w') as f:
+ OmegaConf.save(config=self._cfg, f=f)
+ log.info(f'Saving inference config to {config_path}')
+
+ # Read checkpoint and initialize module.
+ self._flow_module = FlowModule.load_from_checkpoint(
+ checkpoint_path=ckpt_path,
+ )
+ self._flow_module.eval()
+ self._flow_module._infer_cfg = self._infer_cfg
+ self._flow_module._samples_cfg = self._samples_cfg
+ self._flow_module._output_dir = self._output_dir
+
+ # devices = GPUtil.getAvailable(
+ # order='memory', limit = 8)[:4]
+ # print(GPUtil.getAvailable(order='memory', limit = 8))
+
+ devices = [torch.cuda.current_device()]
+
+ self._folding_model = esm.pretrained.esmfold_v1().eval()
+ self._folding_model = self._folding_model.to(devices[-1])
+
+ def run_sampling(self):
+ # devices = GPUtil.getAvailable(
+ # order='memory', limit = 8)[:self._infer_cfg.num_gpus]
+ devices = [torch.cuda.current_device()]
+
+ log.info(f"Using devices: {devices}")
+
+ eval_dataset = eu.LengthDataset(self._samples_cfg)
+ dataloader = torch.utils.data.DataLoader(
+ eval_dataset, batch_size=self._samples_cfg.sample_batch, shuffle=False, drop_last=False)
+
+ trainer = Trainer(
+ accelerator="gpu",
+ strategy="ddp",
+ devices=devices,
+ )
+ trainer.predict(self._flow_module, dataloaders=dataloader)
+
+
+
+@hydra.main(version_base=None, config_path="../configs", config_name="inference")
+def run(cfg: DictConfig) -> None:
+
+ # Read model checkpoint.
+ log.info(f'Starting inference with {cfg.inference.num_gpus} GPUs')
+ start_time = time.time()
+ sampler = Sampler(cfg)
+ sampler.run_sampling()
+ #sampler.eval_test()
+ elapsed_time = time.time() - start_time
+ log.info(f'Finished in {elapsed_time:.2f}s')
+
+if __name__ == '__main__':
+ run()
diff --git a/experiments/train_se3_flows.py b/experiments/train_se3_flows.py
new file mode 100644
index 0000000000000000000000000000000000000000..a80bbc90c2354c08d6af5a59074be70e3323cc25
--- /dev/null
+++ b/experiments/train_se3_flows.py
@@ -0,0 +1,107 @@
+import os
+import GPUtil
+import torch
+import sys
+import hydra
+import wandb
+
+# Pytorch lightning imports
+from pytorch_lightning import LightningDataModule, LightningModule, Trainer
+from pytorch_lightning.loggers.wandb import WandbLogger
+from pytorch_lightning.trainer import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+
+from omegaconf import DictConfig, OmegaConf
+from data.pdb_dataloader import PdbDataModule
+from models.flow_module import FlowModule
+from experiments import utils as eu
+
+
+os.environ["WANDB_MODE"] = "offline"
+log = eu.get_pylogger(__name__)
+torch.set_float32_matmul_precision('high')
+
+
+class Experiment:
+
+ def __init__(self, *, cfg: DictConfig):
+ self._cfg = cfg
+ self._data_cfg = cfg.data
+ self._exp_cfg = cfg.experiment
+ self._datamodule: LightningDataModule = PdbDataModule(self._data_cfg)
+ self._model: LightningModule = FlowModule(self._cfg)
+
+ def train(self):
+ callbacks = []
+ if self._exp_cfg.debug:
+ log.info("Debug mode.")
+ logger = None
+ self._exp_cfg.num_devices = 1
+ self._data_cfg.loader.num_workers = 0
+ else:
+ logger = WandbLogger(
+ **self._exp_cfg.wandb,
+ )
+
+ # Checkpoint directory.
+ ckpt_dir = self._exp_cfg.checkpointer.dirpath
+ os.makedirs(ckpt_dir, exist_ok=True)
+ log.info(f"Checkpoints saved to {ckpt_dir}")
+
+ # Model checkpoints
+ callbacks.append(ModelCheckpoint(**self._exp_cfg.checkpointer))
+
+ # Save config
+ cfg_path = os.path.join(ckpt_dir, 'config.yaml')
+ with open(cfg_path, 'w') as f:
+ OmegaConf.save(config=self._cfg, f=f.name)
+ cfg_dict = OmegaConf.to_container(self._cfg, resolve=True)
+ flat_cfg = dict(eu.flatten_dict(cfg_dict))
+ if isinstance(logger.experiment.config, wandb.sdk.wandb_config.Config):
+ logger.experiment.config.update(flat_cfg)
+
+ devices = GPUtil.getAvailable(order='memory', limit = 8)[:self._exp_cfg.num_devices]
+ log.info(f"Using devices: {devices}")
+ trainer = Trainer(
+ **self._exp_cfg.trainer,
+ callbacks=callbacks,
+ logger=logger,
+ use_distributed_sampler=False,
+ enable_progress_bar=True,
+ enable_model_summary=True,
+ devices=devices,
+ )
+
+ if self._exp_cfg.warm_start is not None:
+ # Lightning will always save the hyperparameters to the checkpoint
+ self._model = self._model.load_from_checkpoint(self._exp_cfg.warm_start, strict=False, map_location="cpu")
+
+ trainer.fit(
+ model=self._model,
+ datamodule=self._datamodule,
+ # ckpt_path=self._exp_cfg.warm_start
+ )
+
+
+@hydra.main(version_base=None, config_path="../configs", config_name="base.yaml")
+def main(cfg: DictConfig):
+
+ if cfg.experiment.warm_start is not None and cfg.experiment.warm_start_cfg_override:
+ # Loads warm start config.
+ warm_start_cfg_path = os.path.join(
+ os.path.dirname(cfg.experiment.warm_start), 'config.yaml')
+ warm_start_cfg = OmegaConf.load(warm_start_cfg_path)
+
+ # Warm start config may not have latest fields in the base config.
+ # Add these fields to the warm start config.
+ OmegaConf.set_struct(cfg.model, False)
+ OmegaConf.set_struct(warm_start_cfg.model, False)
+ cfg.model = OmegaConf.merge(cfg.model, warm_start_cfg.model)
+ OmegaConf.set_struct(cfg.model, True)
+ log.info(f'Loaded warm start config from {warm_start_cfg_path}')
+
+ exp = Experiment(cfg=cfg)
+ exp.train()
+
+if __name__ == "__main__":
+ main()
diff --git a/experiments/utils.py b/experiments/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b97b7fedc6528c56e9067fe7ff2af00dd6240fec
--- /dev/null
+++ b/experiments/utils.py
@@ -0,0 +1,248 @@
+"""Utility functions for experiments."""
+import logging
+import torch
+import os
+import re
+import random
+import esm
+
+import numpy as np
+import pandas as pd
+import random
+
+from analysis import utils as au
+from pytorch_lightning.utilities.rank_zero import rank_zero_only
+from data.residue_constants import restype_order
+from data.repr import get_pre_repr
+from data import utils as du
+from data.residue_constants import restype_atom37_mask
+from openfold.data import data_transforms
+from openfold.utils import rigid_utils
+from data.cal_trans_rotmats import cal_trans_rotmats
+
+
+class LengthDataset(torch.utils.data.Dataset):
+ def __init__(self, samples_cfg):
+ self._samples_cfg = samples_cfg
+
+ # self._all_sample_seqs = ['LSEEEKKELEKEKRKKEIVEEVYKELKEEGKIKNLPKEEFMKKGLEILEKNEGKLKTDEEAKEELL',
+ # 'PSPEELEKKAKEERKLRIVKEVGEELRKEGLIKDLPEEEFLKKGLEILKENEGKLKTEEEAKEALLEKFK',
+ # 'PMLIEVLCRTDKVEELIKEIEELTKDKILEVKVEKIDENTVKIEIVLEKKEAAEKAAKWLSEVEGCEVIEMREV',
+ # 'PTKITVLCPKSKVEELIEEIKEKTNDKILSVEVEEISPDSVKINIILETEEAALKAAEWLSEVEGCEVLEISEVELE',
+ # 'SSSVKKRYKLTVKIEGASEEEFKELAELAKKLGEELGGLIEFEGDKENGKLTLLMESKEKAEKVGEALKEAGVKGGYTIEEFD',
+ # 'VTSITKRYKLTVKITGASAAEFAALGAAAEAQGKALGGLLSFTADAANGTITVLMDTKEKAEKIGDALKALGVKGGYTISEFLEAD',
+ # 'SKIEETKKKIAEGNYEEIKKLKEEIEKEKKKFEEEEKKEKEKAEELLKKDPEKGKKEKAKKEAEFEKKKKEYEEILKIIEKALKGKE',
+ # 'SRIEEVKKQIEESDKEGVKELKKEILKEYEKFKKEAEKEKAEAEKLKKEDPEKGAKEEAELKKKHEEEKKEYEKILEIIEKRLKGAEEGK',
+ # 'GEEALKLMEEELAAAKTEEAKKFMEGLKKMIEEIAKAMATGDPEVIEEGKKRLLEWGKEVGEKGKKEGNPFLIELEKIIEYMAEGEIEEGLKKLMEFLKKKR',
+ # 'GAEALALMDEMLAAAKREEDKAFYARLRELVRRLAAALATGDPAVLAAGRAEAAAEGDALGAEGRATGDPFLVELAAIVAALAAGTPEEGLAALAAFLRAKAAAR']
+
+ # self._all_filename = ['P450'] * 250
+ # self._all_sample_seqs = [('GKLPPGPSPLPVLGNLLQMDRKGLLRSFLRLREKYGDVFTVYLGSRPVVVLCGTDAIREALVDQAEAFSGRGKIAVVDPIFQGYGVIFANGERWRALRRFSLATMRDFGMGKRSVEERIQEEARCLVEELRKSKGALLDNTLLFHSITSNIICSIVFGKRFDYKDPVFLRLLDLFFQSFSLISSFSSQVFELFSGFLKYFPGTHRQIYRNLQEINTFIGQSVEKHRATLDPSNPRDFIDVYLLRMEKDKSDPSSEFHHQNLILTVLSLFFAGTETTSTTLRYGFLLMLKYPHVTERVQKEIEQVIGSHRPPALDDRAKMPYTDAVIHEIQRLGDLIPFGVPHTVTKDTQFRGYVIPKNTEVFPVLSSALHDPRYFETPNTFNPGHFLDANGALKRNEGFMPFSLGKRICLGEGIARTELFLFFTTILQNFSIASPVPPEDIDLTPRESGVGNVPPSYQIRFLARH',0)] * 250
+
+ validcsv = pd.read_csv(self._samples_cfg.validset_path)
+
+ self._all_sample_seqs = []
+ self._all_filename = []
+
+ prob_num = 500
+ exp_prob = np.exp([-prob/prob_num*2 for prob in range(prob_num)]).cumsum()
+ exp_prob = exp_prob/np.max(exp_prob)
+
+ for idx in range(len(validcsv['seq'])):
+
+ # if idx >= 0 and idx < 15:
+ # if idx >= 15 and idx < 30:
+ # if idx >= 30 and idx < 45:
+ # if idx >= 45 and idx < 60:
+ # if idx >= 60 and idx < 75:
+ # if idx >= 75 and idx < 90:
+ # if idx >= 90 and idx < 105:
+ # pass
+ # else:
+ # continue
+
+
+ # if not re.search('2wsi_A',validcsv['file'][idx]):
+ # continue
+
+
+ self._all_filename += [validcsv['file'][idx]] * self._samples_cfg.sample_num
+
+ for batch_idx in range(self._samples_cfg.sample_num):
+
+ rand = random.random()
+ for prob in range(prob_num):
+ if rand < exp_prob[prob]:
+ energy = torch.tensor(prob/prob_num)
+ break
+
+ self._all_sample_seqs += [(validcsv['seq'][idx], energy)]
+
+
+ self._all_sample_ids = self._all_sample_seqs
+
+ # Load ESM-2 model
+ self.device_esm=f'cuda:{torch.cuda.current_device()}'
+ self.model_esm2, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
+ self.batch_converter = self.alphabet.get_batch_converter()
+ self.model_esm2.eval().cuda(self.device_esm) # disables dropout for deterministic results
+ self.model_esm2.requires_grad_(False)
+
+ self._folding_model = esm.pretrained.esmfold_v1().eval()
+ self._folding_model = self._folding_model.to(self.device_esm)
+
+ self.esm_savepath = self._samples_cfg.esm_savepath
+
+
+ self.device_esm=f'cuda:{torch.cuda.current_device()}'
+ self._folding_model = esm.pretrained.esmfold_v1().eval()
+ self._folding_model.requires_grad_(False)
+ self._folding_model.to(self.device_esm)
+
+ def run_folding(self, sequence, save_path):
+ """Run ESMFold on sequence."""
+ with torch.no_grad():
+ output = self._folding_model.infer_pdb(sequence)
+ self._folding_model.to("cpu")
+
+ with open(save_path, "w") as f:
+ f.write(output)
+ return output
+
+ def __len__(self):
+ return len(self._all_sample_ids)
+
+ def __getitem__(self, idx):
+ seq, energy = self._all_sample_ids[idx]
+ aatype = torch.tensor([restype_order[s] for s in seq])
+ num_res = len(aatype)
+
+ node_repr_pre, pair_repr_pre = get_pre_repr(aatype, self.model_esm2,
+ self.alphabet, self.batch_converter, device = self.device_esm) # (B,L,d_node_pre=1280), (B,L,L,d_edge_pre=20)
+ node_repr_pre = node_repr_pre[0].cpu()
+ pair_repr_pre = pair_repr_pre[0].cpu()
+
+ motif_mask = torch.ones(aatype.shape)
+
+
+ save_path = os.path.join(self.esm_savepath, "esm_" + self._all_filename[idx] + ".pdb")
+ if not os.path.exists(save_path):
+ seq_string = seq
+ with torch.no_grad():
+ output = self._folding_model.infer_pdb(seq_string)
+ with open(save_path, "w") as f:
+ f.write(output)
+
+
+ trans_esmfold, rotmats_esmfold = cal_trans_rotmats(save_path)
+
+ batch = {
+ 'filename':self._all_filename[idx],
+ 'trans_esmfold': trans_esmfold,
+ 'rotmats_esmfold': rotmats_esmfold,
+ 'motif_mask': motif_mask,
+ 'res_mask': torch.ones(num_res).int(),
+ 'num_res': num_res,
+ 'energy': energy,
+ 'aatype': aatype,
+ 'seq': seq,
+ 'node_repr_pre': node_repr_pre,
+ 'pair_repr_pre': pair_repr_pre,
+ }
+ return batch
+
+
+
+def save_traj(
+ sample: np.ndarray,
+ bb_prot_traj: np.ndarray,
+ x0_traj: np.ndarray,
+ diffuse_mask: np.ndarray,
+ output_dir: str,
+ aatype = None,
+ index=0,
+ ):
+ """Writes final sample and reverse diffusion trajectory.
+
+ Args:
+ bb_prot_traj: [T, N, 37, 3] atom37 sampled diffusion states.
+ T is number of time steps. First time step is t=eps,
+ i.e. bb_prot_traj[0] is the final sample after reverse diffusion.
+ N is number of residues.
+ x0_traj: [T, N, 3] x_0 predictions of C-alpha at each time step.
+ aatype: [T, N, 21] amino acid probability vector trajectory.
+ res_mask: [N] residue mask.
+ diffuse_mask: [N] which residues are diffused.
+ output_dir: where to save samples.
+
+ Returns:
+ Dictionary with paths to saved samples.
+ 'sample_path': PDB file of final state of reverse trajectory.
+ 'traj_path': PDB file os all intermediate diffused states.
+ 'x0_traj_path': PDB file of C-alpha x_0 predictions at each state.
+ b_factors are set to 100 for diffused residues and 0 for motif
+ residues if there are any.
+ """
+
+ # Write sample.
+ diffuse_mask = diffuse_mask.astype(bool) # (B,L)
+ sample_path = os.path.join(output_dir, 'sample_'+str(index)+'.pdb')
+ prot_traj_path = os.path.join(output_dir, 'bb_traj_'+str(index)+'.pdb')
+ x0_traj_path = os.path.join(output_dir, 'x0_traj_'+str(index)+'.pdb')
+
+ # Use b-factors to specify which residues are diffused.
+ b_factors = np.tile((diffuse_mask * 100)[:, None], (1, 37))
+
+ sample_path = au.write_prot_to_pdb(
+ sample,
+ sample_path,
+ b_factors=b_factors,
+ no_indexing=True,
+ aatype=aatype,
+ )
+ prot_traj_path = au.write_prot_to_pdb(
+ bb_prot_traj,
+ prot_traj_path,
+ b_factors=b_factors,
+ no_indexing=True,
+ aatype=aatype,
+ )
+ x0_traj_path = au.write_prot_to_pdb(
+ x0_traj,
+ x0_traj_path,
+ b_factors=b_factors,
+ no_indexing=True,
+ aatype=aatype
+ )
+ return {
+ 'sample_path': sample_path,
+ 'traj_path': prot_traj_path,
+ 'x0_traj_path': x0_traj_path,
+ }
+
+
+def get_pylogger(name=__name__) -> logging.Logger:
+ """Initializes multi-GPU-friendly python command line logger."""
+
+ logger = logging.getLogger(name)
+
+ # this ensures all logging levels get marked with the rank zero decorator
+ # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+ logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+ for level in logging_levels:
+ setattr(logger, level, rank_zero_only(getattr(logger, level)))
+
+ return logger
+
+
+def flatten_dict(raw_dict):
+ """Flattens a nested dict."""
+ flattened = []
+ for k, v in raw_dict.items():
+ if isinstance(v, dict):
+ flattened.extend([
+ (f'{k}:{i}', j) for i, j in flatten_dict(v)
+ ])
+ else:
+ flattened.append((k, v))
+ return flattened
diff --git a/inference/valid_seq.csv b/inference/valid_seq.csv
new file mode 100644
index 0000000000000000000000000000000000000000..3f3fdb28dff2677d36fb777a59a38f17f3756ca4
--- /dev/null
+++ b/inference/valid_seq.csv
@@ -0,0 +1,102 @@
+file,seq_len,seq
+6pce_B,70,PSMKERQVCWGARDEYWKCLDENLEDASQCKKLRSSFESSCPQQWIKYFDKRRDYLKFKEKFEAGQFEPS
+1vk1_A,242,MGVEKVPKYDIPVKKVEYVFIELDKMKPHEQLVQRELEDFIESVTGSGIFWKPMLLAKIPGTDEYLIVDGHHRWAGLQKLGAKRAPSVILDYFDEGVKVYTWYPAFKGDVNKVIERLKAEGLEVIEDEKAEEKAEKGEIAFALIGEKSFAIPGGLEEQKKVSKVLDEMDQAKEIELVYYGLKEDAKADMEKGEIDYVFIRKAPTKEEVMELVKRGEVFSPKTTRHVLPFIPDKIDVKLEDLF
+3nuf_A,119,GMTPLDANVELPTEVKAMIEQSSDAQAATALVNYVIKLAAAAEIHFTDLQLQVLTNHLIEMLGRSKSGEQLPAVDPTMFAEVSQKSLDLADQVVQHIGHLEVAEKYVLSIHFEAAQDKI
+4wlr_B,102,QVDLASVLTPEIMAPILANADVQERLLPYLPSGESLPQTADEIQNTLTSPQFQQALGMFSAALASGQLGPLMCQFGLPAEAVEAANKGDVEAFAKAMQNNAK
+2wkx_A,261,MGAGEKGIVEKEGYQLDTRRQAQAAYPRIKVLVIHYTADDFDSSLATLTDKQVSSHYLVPAVPPRYNGKPRIWQLVPEQELAWHAGISAWRGATRLNDTSIGIELENRGWQKSAGVKYFAPFEPAQIQALIPLAKDIIARYHIKPENVVAHADIAPQRKDDPGPLFPWQQLAQQGIGAWPDAQRVNFYLAGRAPHTPVDTASLLELLARYGYDVKPDMTPREQRRVIMAFQMHFRPTLYNGEADAETQAIAEALLEKYGQD
+1xmk_A,79,GSHMASLDMAEIKEKICDYLFNVSDSSALNLAKNIGLTKARDINAVLIDMERQGDVYRQGTTPPIWHLTDKKRERMQIK
+3vnb_A,155,MNKKMINGGTVNNWICINFSRQVQDNLARTFCQELAQMCYVSGMAFNPEPVLPPVSARPEQVEKVLKTRYHDATSKLSQGKEIDLLIVILPDKKNSDLYGDLKRICETELGIVSQCCLTKHVFKMSKQYMANVALKINVKVGGRNTVLVHHHHHH
+6crk_G,71,MASNNTASIAQARKLVEQLKMEANIDRIKVSKAAADLMAYCEAHAKEDPLLTPVPASENPFREKKFFSAIL
+5j3t_B,242,MSFTNATFSQVLDDLSARFILNLPAEEQSSVERLCFQIEQAHWFYEDFIRAQNDQLPSLGLRVFSAKLFAHCPLLWKWSKVHEEAFDDFLRYKTRIPVRGAIMLDMSMQQCVLVKGWKASSGWGFPKGKIDKDESDVDCAIREVYEETGFDCSSRINPNEFIDMTIRGQNVRLYIIPGISLDTRFESRTRKEISKIEWHNLMDLPTFKKNKPQTMKNKFYMVIPFLAPLKKWIKKRNIANNT
+2prv_A,153,GMIYSKVENFINENKQNAIFTEGASHENIGRIEENLQCDLPNSYKWFLEKYGAGGLFGVLVLGYNFDHASVVNRTNEYKEHYGLTDGLVVIEDVDYFAYCLDTNKMKDGECPVVEWDRVIGYQDTVADSFIEFFYNKIQEAKDDWDEDEDWDD
+3pes_A,85,SNAMLAEFEDRVAGIPCLIVVTYWEPYVPAKVSGPPEYCYPAEGGCGEWEVRDRRGRPAPWLERKLTEAERERIDQAVFDRMEGR
+3u5v_A,76,MGADKRAHHNALERKRRRDINEAFRELGRMCQMHLKSDKAQTKLLILQQAVQVILGLEQQVRERNLNPLNHHHHHH
+5ori_A,583,DTHKSEIAHRFNDLGEENFQGLVLIAFSQYLQQCPFDEHVKLVKELTEFAKTCVADESHAGCDKSLHTLFGDELCKVATLRETYGDMADCCEKQEPERNECFLKHKDDSPDLPKLKPEPDTLCAEFKADEKKFWGKYLYEVARRHPYFYAPELLYYANKYNGVFQECCQAEDKGACLLPKIETMREKVLASSARQRLRCASIQKFGERALKAWSVARLSQKFPKADFTDVTKIVTDLTKVHKECCHGDLLECADDRADLAKYICDHQDTLSSKLKECCDKPVLEKSHCIAEIDKDAVPENLPPLTADFAEDKEVCKNYQEAKDVFLGSFLYEYSRRHPEYAVSVLLRLAKEYEATLEDCCAKEDPHACYATVFDKLKHLVDEPQNLIKKNCELFEKHGEYGFQNALIVRYTRKAPQVSTPTLVEISRSLGKVGTKCCAKPESERMPCTEDYLSLILNRLCVLHEKTPVSEKVTKCCTESLVNRRPCFSDLTLDETYVPKPFDGESFTFHADICTLPDTEKQIKKQTALVELLKHKPKATDEQLKTVMENFVAFVDKCCAADDKEGCFLLEGPKLVASTQAALA
+3wur_B,171,HHHHHHMTALTLPEDIRQQEPSALLYTLVSAYLEHTAQTGDESLSCLSDDQHTLTAFCYLDSQVEEGGFVQLIASGYGEYIFRNPLADSLRRWKIKAVPKVLDKAKALYEQHGKTIETLADGGADIPSLRKQFPEFEEWDGAYYEAAEQDLPLLAEHIQSNWETFAHIGQA
+2ra8_A,362,MKRVFVFQDFKSQKFWSIDVRGTDVIVNYGKLGTDGQTQVKNFSSAGEAEKAAGKLIAEKTKKGYVETLEEVAKEMKVEAKKYALSYDEAEEGVNLMDKILKDKKLPSLKQITIGCWGYEGEDCSDIADGIVENKEKFAHFEGLFWGDIDFEEQEISWIEQVDLSPVLDAMPLLNNLKIKGTNNLSIGKKPRPNLKSLEIISGGLPDSVVEDILGSDLPNLEKLVLYVGVEDYGFDGDMNVFRPLFSKDRFPNLKWLGIVDAEEQNVVVEMFLESDILPQLETMDISAGVLTDEGARLLLDHVDKIKHLKFINMKYNYLSDEMKKELQKSLPMKIDVSDSQEYDDDYSYPMITELEHHHHHH
+4xo1_A,60,MNIEELKKQAETEIADFIAQKIAEMNKNTGKEVSEMRFTAREKMTGLESYDVKIKIMLEH
+3wx4_A,100,MIIDSQSVVQYTFKIDILEKLYKFLPNLYHSIVNELVEELHLENNDFLIGTYKDLSKAGYFYVIPAPGKNIDDVLKTIMIYVHDYEIEDYFELEHHHHHH
+3dza_A,191,GATDSATAAPAAAATTQVQKEAADVLQVAVQGANAMRDIQFARLALFHGQPDSAKKLTDDAAALLAADDASWAKFVKTDAKAKMIADRYVIINASIALSEDYVATPEKESAIQSANEKLAKGDQKGAIDTLRLAGIGVIENQYLMPLNQTRKAVAQSQELLKAGKYYEANLVLKGAEEGIVVDSEMLVAGN
+2au3_A,407,GHMSSDIDELRREIDIVDVISEYLNLEKVGSNYRTNCPFHPDDTPSFYVSPSKQIFKCFGCGVGGDAIKFVSLYEDISYFEAALELAKRYGKKLDLEKISKDEKVYVALDRVCDFYRESLLKNREASEYVKSRGIDPKVARKFDLGYAPSSEALVKVLKENDLLEAYLETKNLLSPTKGVYRDLFLRRVVIPIKDPRGRVIGFGGRRIVEDKSPKYINSPDSRVFKKGENLFGLYEAKEYIKEEGFAILVEGYFDLLRLFSEGIRNVVAPLGTALTQNQANLLSKFTKKVYILYDGDDAGRKAMKSAIPLLLSAGVEVYPVYLPEGYDPDEFIKEFGKEELRRLINSSGELFETLIKTARENLEEKTREFRYYLGFISDGVRRFALASEFHTKYKVPMEILLMKIEK
+2wsi_A,306,MQLSKAAEMCYEITNSYLHIDQKSQIIASTQEAIRLTRKYLLSEIFVRWSPLNGEISFSYNGGKDCQVLLLLYLSCLWEYFFIKAQNSQFDFEFQSFPMQRLPTVFIDQEETFPTLENFVLETSERYCLSLYESQRQSGASVNMADAFRDFIKIYPETEAIVIGIRHTDPFGEALKPIQRTDSNWPDFMRLQPLLHWDLTNIWSFLLYSNEPICGLYGKGFTSIGGINNSLPNPHLRKDSNNPALHFEWEIIHAFGKDAEGERSSAINTSPISVVDKERFSKYHDNYYPGWYLVDDTLERAGRIKN
+3bwu_D,125,DLYFNPRFLADDPQAVADLSRFENGQELPPGTYRVDIYLNNGYMATRDVTFNTGDSEQGIVPCLTRAQLASMGLNTASVAGMNLLADDACVPLTTMVQDATAHLDVGQQRLNLTIPQAFMSNRAR
+1g6g_A,127,GENIVCRVICTTGQIPIRDLSADISQVLKEKRSIKKVWTFGRNPACDYHLGNISRLSNKHFQILLGEDGNLLLNDISTNGTWLNGQKVEKNSNQLLSQGDEITVGVGVESDILSLVIFINDKFKQCL
+3dff_A,273,MPHDPGATRLLAISPHLDDAVLSFGAGLAQAAQDGANVLVYTVFAGAAQPPYSPAAQRMHTIWGLAPDDDAVLYRRKEDIAALDHLRVAHRHGRFLDSIYRKLPDGRWLTAHVEGRQKLAVNDHSPDSDHDLVGEVADDIRSIIDEFDPTLVVTCAAIGEHPDHEATRDAALFATHEKNVPVRLWEDLPYAVFKSGAVELPQGFRLGSADVSSVKPEMRSQKFQAVERYSSQMVLLNGSENNLFDRLDEHARQNAPHGGYGETTWPVVRSDDS
+2bko_A,205,MEEVEEFKYEPKSVKEIFIEMKDTVELMVDLAYASLLFGDKEIAEEVLELEERIDLLNYQLMMHSVLAARNVKEAEQVITILQIANAIEDISNAAGDLAKMVLEGVELHPVIKETILEGEEIIGKIQVYPESVIVGKTLGELDLATNTGVWIIAVRRGKRWIFGPNENFKIRAGDVLIGRGTRTSIDHLKEIARGAIRVIGNERA
+3ipj_A,95,SNANKYNKIANELIKIIGEDNIISITHCATRLRVMVKDREIINDKKVEKVDEVKGVFFTSGQYQIILGTGIVNKVYAEVEKMGLKTLSKKEQDEL
+4z3x_A,653,MRYAETGYVLEVDLTKGSIERVATDPRDTELYLGGLGTNAKILWDRVPPEVEPFSPENLLIFAAGLLCGTPATGCNRTIVSTVSPQTKLMAFSMMGGFWAPELKYAGYDKIIFRGKSPELVYLYINNDKVEIRDASHLKGKGAIETAEIIKKELNEPRAQVAAIGKAGENRVFYASIEQGRSSASRGGIGAVMGDKGLKAVVVRGTKDLCVAKPEEYIGLCNEVLDYIKHREENPIPDVMPILAGLGSPQEMKVHDEKWHTENFNWGNARTRRKDFWTDEVSHAWEKTMDKARTRLISCYNCPMKCGATISMEGLPTYMMKCFTKLTYTMAAYSDLDFGLRIAQKATEYGLDGFSAPQVMAFAFELLEKGILKDSDFPGLPEGNEERFFYLLDKIVNRDGIGDILANGTYWAAQEIGNGAEDYAHNNIKKHEQLPLKLSMLNPIYYLMYCTGEKINITQIEGQFPQAPYPKLEQREAFVEDWIQVPDEKFKKIFLEWEPRGEKSMPNFPTVDMCCDIVDWQEMMHYIDDALGQCAGLSSFPLKPPYHIHNYPKFIAAGAGIEMDTEKLKKAAKRYRTLVRAFNIRRGMRRVDEQPPANHWKNRFPELEKELLDSYYKLKGWNDDGIPTKETLDDLGLGYVGDEFIKRGILSAG
+4wyw_A,481,GLIVDTRDVEERVHVMRKTELAPTVAHGVFNPEFGPAALSNKDPRLNEGVVLDEVIFSKHKGDTKMSAEDKALFRRCAADYASRLHSVLGTANAPLSIYEAIKGVDGLDAMEPDTAPGLPWALQGKRRGALIDFENGTVGPEVEAALKLMEKREYKFACQTFLKDEIRPMEKVRAGKTRIVDVLPVEHILYTRMMIGRFCAQMHSNNGPQIGSAVGCNPDVDWQRFGTHFAQYRNVWDVDYSAFDANHCSDAMNIMFEEVFRTEFGFHPNAEWILKTLVNTEHAYENKRITVEGGMPSGCSATSIINTILNNIYVLYALRRHYEGVELDTYTMISYGDDIVVASDYDLDFEALKPHFKSLGQTITPADKSDKGFVLGHSITDVTFLKRHFHMDYGTGFYKPVMASKTLEAILSFARRGTIQEKLISVAGLAVHSGPDEYRRLFEPFQGLFEIPSYRSLYLRWVNAVCGDAAAALEHHHHHH
+3ka5_A,65,EIVKTKRFAIKPMSEEEAVLEMELLGHNFFVFQNGDSNEVNVVYKRKDGNYGLIEPELEHHHHHH
+3dnh_B,258,GHMLDVAPPVITPRGTKIEPSAGAPFEAVRVARDVLHTSRTAALATLDPVSGYPYTTATNIGIEPDGTPFFFAAGLTLHARNMETDARISVTLAPFGKGDALTLPRLTLVGRADRIGPDEVPLAIARYIARYPKAKLYLSLPDTRLYRLRTEGVQINGGPARNASNITPADLRTDLSGAEELMAAAESEATRLNAIKGEASRLAVLAGAKTGRWKITSIDPDGIDLASASDLARLWFAERVETLKQFEKALAQLLKGS
+3vgp_A,164,GPDLTEDWKEALEWMRTSLEEQNYLNPYEKPEYSVMSWWDYGNWILYVSKKAVVANNFQAGAVDAAKFFTAKSEDEAIKIAKKRGVRYVVTADEITMKDANNTKFPAIMRIAGYNVDLMTEGEILNFFNHTVLYRLHMENAENLTHFRLVKEFGDVKIFEVVGS
+3pt1_A,471,HMTIPGRFMTIDKGTFGEYTASTRWPIIIQNAIDDLSKHQETEKSNGTKFEQGEVIKKELKEFRQEIIDRVPLRPFTEEEIKIANVPLSFNEYLKKHPEVNWGAVEWLFSEVYLYRRVNVLFQRQCEWAKFDIFNRLKQSTFESSFYGVVELALRYENLLPQLREMKQNPGNEIDDILKVLFKEFIEISLWGNATDLSLLTNATLEDIKSIQGAKARAASESKIVVNDTEKAWEVLTKARADANSREIRVDFVLDNSGFELYADLMLAAFLLQSGLATKCIFHAKDIPYMVSDVMLKDFDILVHDLRDREFFPSGEPSTKESRALDLFAGEMEKFVSSGKIEFREDSFWTTELDYWNLDANETKYHGSILHKDLQKSNLVIFKGDLNYRKLTGDRKWPRTTKWETAIGPLATNGITSLSLRTCKADVQVALPEGLDAKLSQEWEKENPGRGSWWCCSGKWAVICFCSGIHK
+3boe_A,210,SISPAQIAEALQGRGWDAEIVTDASMAGQLVDVRPEGILKCVDGRGSDNTRMGGPKMPGGIYAIAHNRGVTSIEGLKQITKEVASKGHLPSVHGDHSSDMLGCGFFKLWVTGRFDDMGYPRPQFDADQGANAVKDAGGIIEMHHGSHTEKVVYINLLANKTLEPNENDQRFIVDGWAADKFGLDVPKFLIAAAATVEMLGGPKNAKIVVP
+7p46_A,282,KNLRDLEPGIHTDLEGRLTYGGYLRLDQLLSAQQPLSEPAHHDEMLFIIQSQTSELWLKLLAHELRAAIVHLQRDEVWQCRKVLARSKQVLRQLTEQWSVLETLTPSEYMGFRDVLGPSSGFQSLQYRYIEFLLGNKNPQMLQVFAYDPAGQARLREVLEAPSLYEEFLRYLARFGHAIPQQYQARDWTAAHVADDTLRPVFERIYENTDRYWREYSLCEDLVDVETQFQLWRFRHMRTVMRVIGFKRGTGGSSGVGFLQQALALTFFPELFDVRTSVGVDN
+2z4u_A,261,VNTITFDVGNATINKYATFMESLRNEAKDPTLKCYGIPMLPDSNLTPKYVLVKLQDASSKTITLMLRRNNLYVMGYSDLYNGKCRYHIFNDISSTESTDVENTLCPNSNSREKKAINYNSQYSTLQNKAGVSSRSQVQLGIQILNSDIGKISGVSTFTDKTEAEFLLVAIQMVSEAARFKYIENQVKTNFNRAFNPNPKVLSLEENWGKISLAIHNAKNGALTSPLELKNADDTKWIVLRVDEIKPDMGLLNYVSGTCQTT
+1d2t_A,231,LALVATGNDTTTKPDLYYLKNSEAINSLALLPPPPAVGSIAFLNDQAMYEQGRLLRNTERGKLAAEDANLSSGGVANAFSGAFGSPITEKDAPALHKLLTNMIEDAGDLATRSAKDHYMRIRPFAFYGVSTCNTTEQDKLSKNGSYPSGHTSIGWATALVLAEINPQRQNEILKRGYELGQSRVICGYHWQSDVDAARVVGSAVVATLHTNPAFQQQLQKAKAEFAQHQKK
+5znj_A,567,MKQSKVFIPTMRDVPSEAEAQSHRLLLKSGLIKQSTSGIYSYLPLATRVLNNITAIVRQEMERIDSVEILMPALQQAELWEESGRWGAYGPELMRLQDRHGRQFALGPTHEELVTSIVRNELKSYKQLPMTLFQIQSKFRDEKRPRFGLLRGREFIMKDAYSFHADEASLDQTYQDMYQAYSRIFERVGINARPVVADSGAIGGSHTHEFMALSAIGEDTIVYSKESDYTANIEKAEVVYEPNHKHSTVQPLEKIETPNVKTAQELADFLGRPVDEIAKTMIFKVDGEYIMVLVRGHHEINDIKLKSYFGTDNIELATQDEIVNLVGANPGSLGPVIDKEIKIYADNFVQDLNNLVVGANEDGYHLINVNVGRDFNVDEYGDFRFILEGEKLSDGSGVAHFAEGIEVGQVFKLGTKYSESMNATFLDNQGKAQPLIMGCYGIGISRTLSAIVEQNHDDNGIVWPKSVTPFDLHLISINPKKDDQRELADALYAEFNTKFDVLYDDRQERAGVKFNDADLIGLPLRIVVGKRASEGIVEVKERLTGDSEEVHIDDLMTVITNKYDNLK
+6jv8_A,76,LQLLHQKVEEQAAKYKHRVPKKCCYDGARENKYETCEQRVARVTIGPHCIRAFNECCTIADKIRKESHHKGMLLGR
+3vej_B,41,SVDLTVPWDDIEALLKNNFENDQAAVRQVMERLQKGWSLAK
+1dwk_A,156,MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGTGLAEAFVTAALLGQQALPADAARLVGAKLDLDEDSILLLQMIPLRGCIDDRIPTDPTMYRFYEMLQVYGTTLKALVHEKFGDGIISAINFKLDVKKVADPEGGERAVITLDGKYLPTKPF
+6jpt_A,122,MEDTPLVISKQKTEVVCGVPTQVVCTAFSSHILVVVTQFGKMGTLVSLEPSSVASDVSKPVLTTKVLLGQDEPLIHVFAKNLVAFVSQEAGNRAVLLAVAVKDKSMEGLKALREVIRVCQVW
+7aex_A,275,MHHHHHHSEINPAEFEQVNMVLQGFVETSVLPVLELSADESHIEFREHSRNAHTVVWKIISTSYQDELTVSLHITTGKLQIQGRPLSCYRVFTFNLAALLDLQGLEKVLIRQEDGKANIVQQEVARTYLQTVMADAYPHLHVTAEKLLVSGLCVKLAAPDLPDYCMLLYPELRTIEGVLKSKMSGLGMPVQQPAGFGTYFDKPAAHYILKPQFAATLRPEQINIISTAYTFFNVERHSLFHMETVVDASRMISDMARLMGKATRAWGIIKDLYIV
+2p12_A,176,GHMAGASHDIHPPKVEYFDLRDHTNTDPKGFVRHVDHYRVEPWGLYMARTSDHPQFHYLESWLLPDLGLRASIFHYHPYHQRDQDHYVDIGTFTRGDDVWKSEDHYLDLVVRTGRDTELLDVDELMEAHTTGLLDTATAEQAILTATTAIDGIAAHGHDLGRWLASIGMPIDWRGS
+2oy9_B,98,MSLKTTLPISLDWSTEEVIDVVHFFQAIEQAYDQGIAREDLLGKYRRFKEIVPSKSEEKQLFRAYEQENDVSCYQTIKKAREEMEEHIQMEGHHHHHH
+4byz_A,194,YTMSAQVIQIGRQRFVGGLFWQSLSRRNELRAEAVELAKKLKFDLMVLRIDRGVAAAGYANTRDGFAPGHLSLGAMVSRAIALEGAFYNGRRQPAPNWLGAFALPDGRWAYFAVRDHAFMPNGDWVGSREEALERLHTDYAWGGWNVVIGEPELERQGFQNFQPKRLDDLLPRRGGRPRTERWWALRPVERRLS
+3mmy_D,56,TGTTIKFNPPTGTDTMVKAGVSTNISTKHQCITAMKEYESKSLEELRLEDYQANRK
+1dk8_A,147,GSASPTPPYLKWAESLHSLLDDQDGISLFRTFLKQEGCADLLDFWFACTGFRKLEPCDSNEEKRLKLARAIYRKYILDNNGIVSRQTKPATKSFIKGCIMKQLIDPAMFDQAQTEIQATMEENTYPSFLKSDIYLEYTRTGSESPKV
+7jfl_C,47,SALQDLLRTLKSPSSPQQQQQVLNILKSNPQLMAAFIKQRTAKYVAN
+3c2q_B,345,SNANIPEIENANLKPALKDSVLPDGFYSTTNHPTHVKVNDEWIEVANPKMDAVIVVYPEEKRAETKVIRKVKKGDFVLIGHNGIRVMPPEKSREAGQLFEFMNSEVSSEKPKEAIIKRIAKEMHEIREEYKKTGTGGIAIVGGPAIIHTGGGPALAKMVELGYIQAILAGNALATHDIESALYGTSLGVNIKTAKPVTGGHKHHIYAINAINDAGNIKNAVESGVLKEGIMYQCIKNNIPYVLAGSIRDDGPIPDVITDSMVAQDKMRTTVMDKKMVIMLSTLLHSVATGNLMPSYIKTVCVDIQPSTVTKLMDRGTSQAIGVVTDVGVFLVLLLKELERLELQE
+2bw3_B,84,SHMQSRELKTVSADCKKEAIEKCAQWVVRDCRPFSAVSGSGFIDMIKFFIKVGAEYGDHVNVEELLPSPITLSRKVTSDAKEKA
+6bwq_A,149,TADAVLMIEANLDDQTGEGLGYVMNQLLTAGAYDVFFTPIQMKKDRPATKLTVLGNVNDKDLLTKLILQETTTIGVRYQTWQRTIMQRHFLTVATPYGDVQVKVATYQDIEKKMPEYADCAQLAQQFHIPFRTVYQAALVAVDQLDEEA
+5k6d_B,78,MWKVSERCLKGHGKFQADQEIGNGLATAKGQCKGTDSDQKKAGKCDKHCTGVCLGSGGSCGDGSSQKPNKEDCYCKSK
+5eqz_A,146,MAHHHHHHYVEEKKEIDSLMEDVLALVNDSSGGKFKDYKDKINELKENLKDIGNAELKEKLLNLQNSFQDKLAAKLAALKAAKNTIENITDKDQDISKRKIWSEAKLVGVTVPLLGSNTSGNGDKMSKNAVEQIDKVIKFLEEGTN
+3d4i_A,273,AMGSATISRRGILVIRHGERVDQVFGKSWLQQCTTADGKYYRPDLNFPRSLPRRSNGIKDFENDPPLSSCGIFQARLAGEALLDSGVRVTAVFASPALRCVQTAKHILEELKLEKKLKIRVEPGIFEWMKWEASKATLTFLTLEELKEANFNVDLDYRPALPRCSLMPAESYDQYVERCAVSMGQIINTCPQDMGITLIVSHSSALDSCTRPLLGLPPRECGDFAQLVRKIPSLGMCFCEENREDGKWDLVNPPVKTLTHGANSVFNWRNWIS
+5n0n_A,410,GTSTQTKAGSLTIVGTGIESIGQMTLQALSYIEAAAKVFYCVIDPATEAFILTKNKNCVDLFQYYDNGKSRLNTYTQMSELMVREVRKGLDVVGVFYGHPGVFVNPSHRALAIAKSEGYRARMLPGVSAEDCLFADLCIDPSNPGCLTYEASDFLIRDRPVSIHSHLVLFQVGCVGIADFNFTGFDNNKFGVLVDRLEQEYGAEHPVVHYIAAMMPHQDPVTDKYTVAQLREPEIAKRVGGVSTFYIPPKARKASNLDIIRRLELLPAGQVPDKKARIYPANQWEPDVPEVEPYRPSDQAAIAQLADHAPPEQYQPLATSKAMSDVMTKLALDPKALADYKADHRAFAQSVPDLTPQERAALELGDSWAIRCAMKNMPSSLLDAARESGEEASQNGFPWVIVVGVIGVIG
+5iz3_B,316,ELGDSLEEFLAKATTDKNLARLLVCMGEALRTIAFKVRTASCGATACTNTFGDEQLAVDMLADKLLFEALRHSHVCKYACSEEEPILQDMEGEGFSVAFDPLDGSSIVDTNFTVGTIFGVWPGDKLTGITGRDQAASAMGIYGPRTTYVVAINGFPGTHEFLLMDDGKWQHVKETTEIKEGKLFSPGNLRATFDNADYEKLINYYVSEKYTLRYTGGMVPDVNQIIVKERGIFTNVTSPTTKAKLRLLFEVAPLGLLIENAGGYSSDGKQSVLDKVVVNTDDRTQVAYGSRDEIIRFEETLYGDSRLKAELAATVV
+2ov9_C,216,VSVGTHPTFENSPSGTVLTSPPDGSAVDRATDAARRVVDALLRTDRGNANLERVAEELNSIAGHLEEHAPAVAERLIDMWNGEGVTRHDPVTGPENALAPPVVLEGLSDGSVRGTVTLTIPYQGPPGHVHGGVSALLLDHVLGVANAWGGKAGMTAQLSTRYHRPTPLFEPLTLTGKLMSVDGRKITTAGDIRTADGQVCVSVEGLFVDKTVPRPR
+3tbn_A,87,GSSHHHHHHTDPVIAQKAPYPVTVEAGKTYHWCACGRSKAQPFCDGSHKGTGLAPVAYTPDKAGTAYFCGCKASKAPPLCDGTHKTL
+4ue8_B,38,GPHMTKLIYERAFMKNLRGSPLSQTPPSNVPSCLLRGT
+3bl9_A,301,SAPVRLPFSGFRLQKVLRESARDKIIFLHGKVNEASGDGDGEDAVVILEKTPFQVEQVAQLLTGSPELQLQFSNDIYSTYHLFPPRQLNDVKTTVVYPATEKHLQKYLRQDLRLIRETGDDYRNITLPHLESQSLSIQWVYNILDKKAEADRIVFENPDPSDGFVLIPDLKWNQQQLDDLYLIAICHRRGIRSLRDLTPEHLPLLRNILHQGQEAILQRYRMKGDHLRVYLHYLPSYYHLHVHFTALGFEAPGSGVERAHLLAEVIENLECDPRHYQQRTLTFALRADDPLLKLLQEAQQS
+2ozl_A,365,MRGSFANDATFEIKKCDLHRLEEGPPVTTVLTREDGLKYYRMMQTVRRMELKADQLYKQKIIRGFCHLCDGQEACCVGLEAGINPTDHLITAYRAHGFTFTRGLSVREILAELTGRKGGCAKGKGGSMHMYAKNFYGGNGIVGAQVPLGAGIALACKYNGKDEVCLTLYGDGAANQGQIFEAYNMAALWKLPCIFICENNRYGMGTSVERAAASTDYYKRGDFIPGLRVDGMDILCVREATRFAAAYCRSGKGPILMELQTYRYHGHEMSDPGVSYRTREEIQEVRSKSDPIMLLKDRMVNSNLASVEELKEIDVEVRKEIEDAAQFATADPEPPLEELGYHIYSSDPPFEVRGANQWIKFKSVS
+5l87_A,69,GAMWHTHSEREKRVSNAVEFLLDSRVRRTPTSSKVHFLKSKGLSAEEICEAFTKVGQPKTLNEIKRILS
+1xkg_A,312,RPSSIKTFEEYKKAFNKSYATFEDEEAARKNFLESVKYVQSNGGAINHLSDLSLDEFKNRFLMSAEAFEHLKTQFDLNAETNACSINGNAPAEIDLRQMRTVTPIRMQGGCGSAWAFSGVAATESAYLAYRDQSLDLAEQELVDCASQHGCHGDTIPRGIEYIQHNGVVQESYYRYVAREQSCRRPNAQRFGISNYCQIYPPNANKIREALAQTHSAIAVIIGIKDLDAFRHYDGRTIIQRDNGYQPNYHAVNIVGYSNAQGVDYWIVRNSWDTNWGDNGYGYFAANIDLMMIEEYPYVVILGQTGHHHHHH
+1sx3_A,525,AAKDVKFGNDAGVKMLRGVNVLADAVKVTLGPKGRNVVLDKSFGAPTITKDGVSVAREIELEDKFENMGAQMVKEVASKANDAAGDGTTTATVLAQAIITEGLKAVAAGMNPMDLKRGIDKAVTVAVEELKALSVPCSDSKAIAQVGTISANSDETVGKLIAEAMDKVGKEGVITVEDGTGLQDELDVVEGMQFDRGYLSPYFINKPETGAVELESPFILLADKKISNIREMLPVLEAVAKAGKPLLIIAEDVEGEALATLVVNTMRGIVKVAAVKAPGFGDRRKAMLQDIATLTGGTVISEEIGMELEKATLEDLGQAKRVVINKDTTTIIDGVGEEAAIQGRVAQIRQQIEEATSDYDREKLQERVAKLAGGVAVIKVGAATEVEMKEKKARVEDALHATRAAVEEGVVAGGGVALIRVASKLADLRGQNEDQNVGIKVALRAMEAPLRQIVLNCGEEPSVVANTVKGGDGNYGYNAATEEYGNMIDMGILDPTKVTRSALQYAASVAGLMITTECMVTDLPK
+5j2l_B,81,GSHMGSEDYKLREAQRELDKQRKDTEEIRKRLKEIQRLTDERTSTADELIKELREIIRRLQEQSEKLREIIEELEKIIRKR
+2fef_B,294,MHAFPLTLDNRLAEALPLWRNLARTDRAPRRNIDLADWKADWRELIAALDRFSRSHGYRQPFAAQGHAALENAWAWGQAAENASTLLLKAIDRGLAGAELRSIYLETAALWLDYSRLLGAARDSLREQGEVDFETAPALAPRTGQYPFALQLLAMGVLLDAQELIPALVEEVLQFDTDRLLDYLGAAALGLTSASEETFHPRPFGQLRAFFEEADGSDAQALAPYLQSQYREFFQLSPKAQKKTRRLTGPYAWGWWAMEVSALGVLYGWDDGVLRASPHYLGDLVDYARARGDA
+7p41_D,448,MRGSMQQVGTVAQLWIYPVKSCKGVPVSEAECTAMGLRSGNLRDRFWLVINQEGNMVTARQEPRLVLISLTCDGDTLTLSAMNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSELDKAIGRNCNGVITKDEAEKLFNQDVDAAVRGILRNAKLKPVYDSLDAVRRCALINMVFQMGETGVAGFTNSLRMLQQKRWDEAAVNLAKSRWYNQTPNRAKRVITTFRTGTWDAYTKDLLLPIKTPTTNAVHKCRVHGLEIEGRDCGEATAQWITSFLKSQPYRLVHFEPHMRPRRPHQIADLFRPKDQIAYSDTSPFLILSEASLADLNSRLEKKVKATNFRPNIVISGCDVYAEDSWDELLIGDVELKRVMACSRCILTTVDPDTGVMSRKEPLETLKSYRQCDPSERKLYGKSPLFGQYFVLENPGTIKVGDPVYLLG
+2bbr_A,195,MSDSKEVPSLPFLRHLLEELDSHEDSLLLFLCHDAAPGCTTVTQALCSLSQQRKLTLAALVEMLYVLQRMDLLKSRFGLSKEGAEQLLGTSFLTRYRKLMVCVGEELDSSELRALRLFACNLNPSLSTALSESSRFVELVLALENVGLVSPSSVSVLADMLRTLRRLDLCQQLVEYEQQEQARYRYCLEHHHHHH
+6y2x_A,235,GSEPEPEQVIKNYTEELKVPPDEDCIICMEKLSTASGYSDVTDSKAIGSLAVGHLTKCSHAFHLLCLLAMYCNGNKDGSLQCPSCKTIYGEKTGTQPQGKMEVLRFQMSLPGHEDCGTILIVYSIPHGIQGPEHPNPGKPFTARGFPRQCYLPDNAQGRKVLELLKVAWKRRLIFTVGTSSTTGETDTVVWNEIHHKTEMDRNITGHGYPDPNYLQNVLAELAAQGVTEDCLEQQ
+5a0n_A,220,GAMGGFPNDAKGISGNGKYYSLGQIEKLYSNQFATYNNLTVITSDTHENSDNFAFCLANGKRFPSFTDEKPKGIYTLVKDINKEQYTKLLKENHKWSSIPNLNQAWDTFSRLSYMYLKDPTDIVKRAWGTDLNTARTYFHQVIQYEIWRYTDGMRVSSDTNVYIYEKFSPQQKKALEMIRTDLYNFTVPYENLEYRFYKPDWVFGLGFQALATVRWKIEP
+2vri_A,174,KAGSAAAPFTKPFAVYKNVKFYLGDISHLVNCVSFDFVVNAANENLLHGGGVARAIDILTEGQLQSLSKDYISSNGPLKVGAGVMLECEKFNVFNVVGPRTGKHEHSLLVEAYNSILFENGIPLMPLLSCGIFGVRIENSLKALFSCDINKPLQVFVYSSNEEQAVLKFLDGLD
+2j7z_A,68,KPVSLSYRCPCRFFESHVARANVKHLKILNTPNCALQIVARLKNNNRQVCIDPKLKWIQEYLEKALNK
+3a1g_A,80,SQRGVLEDEQMYQRCCNLFEKFFPSSSYRRPVGISSMVEAMVSRARIDARIDFESGRIKKEEFTEIMKICSTIEELRRQK
+6oz1_A,643,QGMTETISTPAATTTPDLADQQELERRVAQVVSNDPQLQALLPDDAVSGAVNEPGLTLIELIRRLLEGYGDRPALGQRAVELVTDEHGATTVALKTEFVTTSYRELWNRAEAIAAAWYAQGIRDGDFVAQLGFTSTDFASLDVAGLRLGTVSVPLQTGASVQQRNAILEETQPTVFAASVEYLEAAVDSVLATPSVQLLSVFDYHPEADAHRAALSAVRDRLETAGRTITIDSLGDAIARGRELPAAPQPSEDPDALRLLIYTSGSTGTPKGAMYPQWLVANLWQKKWLTTTVIPSVGVNFMPMSHLAGRLTLMGTLSGGGTAYYIASSDLSTFFEDIALIRPTEVLFVPRVVEMVFQRYQAELDRSLAPGESNAEIAEKIKVRIREEDFGGRVLNAGSGSAPLSPEMNDFMESLLQVAMLDGYGSTEAGAVWRDGVLQRPPVTEYKLVDVPELGYFTTDSPHPRGELRIKSETMFPGYYKRPETTADVFDEDGFYMTGDVVAELGPDHLKYLDRVKNVLKLAQGEFVAVSKLEAAYTGSPLVRQIFVYGNSERSYLLAVVVPTPEALERYADSPDALKPLIQDSLQQVAKGAELQSYEIPRDFIVETVPFTVESGLLSDARKLLRPKLKEHYGERLEALYAD
+1ijx_C,127,GSAACEPVRIPLCKSLPWEMTKMPNHLHHSTQANAILAMEQFEGLLGTHCSPDLLFFLCAMYAPICTIDFQHEPIKPCKSVCERARQGCEPILIKYRHSWPESLACDELPVYDRGVCISPEAIVTAD
+1q2h_A,69,GSHMPDWRQFCELHAQAAAVDFAHKFCRFLRDNPAYDTPDAGASFSRHFAANFLDVFGEEVRRVLVAGP
+5dje_B,140,GSSCDFVADVPPPKKGTDYDFYEAWGPVFEAEARFSKKTPIPSLGNMDSSKKEVEQFYAFWHRFDSWRTFEFLDEDVPDDSSNRDHKRYIERKNKAARDKKKTADMARLVKLVERAVSEDPRIKMFKEEEKKEKERRKWE
+4zav_A,209,MSGPERITLAMTGASGAQYGLRLLDCLVQEEREVHFLISKAAQLVMATETDVALPAKPQAMQAFLTEYCGAAAGQIRVFGQNDWMAPPASGSSAPNAMVICPCSTGTLSAVATGACNNLIERAADVALKERRPLVLVPREAPFSSIHLENMLKLSNLGAVILPAAPGFYHQPQSVEDLVDFVVARILNTLGIPQDMLPRWGEQHLVSDE
+2q7x_A,326,GMRKPKITVIGGGTGSPVILKSLREKDVEIAAIVTVADDGGSSGELRKNMQQLTPPGDLRNVLVAMSDMPKFYEKVFQYRFSEDAGAFAGHPLGNLIIAGLSEMQGSTYNAMQLLSKFFHTTGKIYPSSDHPLTLHAVFQDGTEVAGESHIVDHRGIIDNVYVTNALNDDTPLASRRVVQTILESDMIVLGPGSLFTSILPNIVIKEIGRALLETKAEIAYVCNIMTQRGETEHFTDSDHVEVLHRHLGRPFIDTVLVNIEKVPQEYMNSNRFDEYLVQVEHDFVGLCKQVSRVISSNFLRLENGGAFHDGDLIVDELMRIIQVKK
+5n6y_C,113,MSQSHLDDLFAYVEERCLWQFFSRTWDREENIEGVLNQVGRLLTGQEPLRGTPQERLFYADALAMANDVRERFPWASQVNKEEIEFLLDGLKSRLVDVTITRSTNRELNHHLY
+1h02_B,70,GPETLCGAELVDALQFVCGDRGFYFNKPTGYGSSSRRAPQTGIVDECCFRSCDLRRLEMYCAPLKPAKSA
+1q74_C,303,MSETPRLLFVHAHPDDESLSNGATIAHYTSRGAQVHVVTCTLGEEGEVIGDRWAQLTADHADQLGGYRIGELTAALRALGVSAPIYLGGAGRWRDSGMAGTDQRSQRRFVDADPRQTVGALVAIIRELRPHVVVTYDPNGGYGHPDHVHTHTVTTAAVAAAGVGSGTADHPGDPWTVPKFYWTVLGLSALISGARALVPDDLRPEWVLPRADEIAFGYSDDGIDAVVEADEQARAAKVAALAAHATQVVVGPTGRAAALSNNLALPILADEHYVLAGGSAGARDERGWETDLLAGLGFTASGT
+5c8q_B,154,GSHGNYAGNFSGSSRDICLDGARLRAECRRGDGGYSTSVIDLNRYLSNDNGHFRWVSTATVTVQQGDTLRDIGRRFDCDFHEIARRNNIQNEDLIYPGQVLQVGGNFWDSARDVRLVDGGKVLEAELRYSGGWNRSRIYLDEHIGNRNGELIHC
+4m37_A,343,GYRPPDAFVNRIDRNIPVPARLRHTPVSLIEAVNDFHYAMMNDEERNNFYYEVLKKHVTPETGVLEIGAGSGLLSLMAAKLGAKWVVAVEGSEELAKLARENIRANNMEHQVKVLHMMSTELKSKHLPEPPDVLLSEIFGTMMLGESALDYVVDVRNRLLKPTTKIIPQFGTQYAVPIECDALHRISSVSGWRDLDLKHMMTLQDTVSIVFAKHYGIRMNSVNFRRLSDPIELFRVDFSSSNRNDIPRRKHFDVVAKESGTAHAMLFYWKVTDDEFVMSTDPEDTVNNFPRDMQWGQALQLLDASNGPLPTPVVFTEGKNYNFECNFSGDRVILHMQLCPESG
+5n0s_A,410,GTSTQTKAGSLTIVGTGIESIGQMTLQALSYIEAAAKVFYCVIDPATEAFILTKNKNCVDLYQYYDNGKSRLNTYTQMSELMVREVRKGLDVVGVFAGHPGVFVNPSHRALAIAKSEGYRARMLPGVSAEDCLFADLCIDPSNPGCLTYEASDFLIRDRPVSIHSHLVLFQVGCVGIADFNFTGFDNNKFGVLVDRLEQEYGAEHPVVHYIAAMMPHQDPVTDKYTVAQLREPEIAKRVGGVSTFYIPPKARKASNLDIIRRLELLPAGQVPDKKARIYPANQWEPDVPEVEPYRPSDQAAIAQLADHAPPEQYQPLATSKAMSDVMTKLALDPKALADYKADHRAFAQSVPDLTPQERAALELGDSWAIRCAMKNMPSSLLDAARESGEEASQNGFPWVIVVGVIGVIG
+2d1l_A,253,AGHMEAVIEKECSALGGLFQTIISDMKGSYPVWEDFINKAGKLQSQLRTTVVAAAAFLDAFQKVADMATNTRGGTREIGSALTRMCMRHRSIEAKLRQFSSALIDCLINPLQEQMEEWKKVANQLDKDHAKEYKKARQEIKKKSSDTLKLQKKAKKVDAQGRGDIQPQLDSALQDVNDKYLLLEETEKQAVRKALIEERGRFCTFISMLRPVIEEEISMLGEITHLQTISEDLKSLTMDPHKLPSSSEQVILD
+3dan_A,473,MDPSSKPLREIPGSYGIPFFQPIKDRLEYFYGTGGRDEYFRSRMQKYQSTVFRANMPPGPFVSSNPKVIVLLDAKSFPILFDVSKVEKKDLFTGTYMPSTKLTGGYRVLSYLDPSEPRHAQLKNLLFFMLKNSSNRVIPQFETTYTELFEGLEAELAKNGKAAFNDVGEQAAFRFLGRAYFNSNPEETKLGTSAPTLISSWVLFNLAPTLDLGLPWFLQEPLLHTFRLPAFLIKSTYNKLYDYFQSVATPVMEQAEKLGVPKDEAVHNILFAVCFNTFGGVKILFPNTLKWIGLAGENLHTQLAEEIRGAIKSYGDGNVTLEAIEQMPLTKSVVYESLRIEPPVPPQYGKAKSNFTIESHDATFEVKKGEMLFGYQPFATKDPKVFDRPEEYVPDRFVGDGEALLKYVWWSNGPETESPTVENKQCAGKDFVVLITRLFVIELFRRYDSFEIELGESPLGAAVTLTFLKRASI
+1ys1_X,320,ADNYAATRYPIILVHGLTGTDKYAGVLEYWYGIQEDLQQRGATVYVANLSGFQSDDGPNGRGEQLLAYVKTVLAATGATKVNLVGHSQGGLTSRYVAAVAPDLVASVTTIGTPHRGSEFADFVQGVLAYDPTGLSSTVIAAFVNVFGILTSSSNNTNQDALAALKTLTTAQAATYNQNYPSAGLGAPGSCQTGAPTETVGGNTHLLYSWAGTAIQPTISVGGVTGATDTSTIPLVDPANALDPSTLALFGTGTVMVNRGSGQNDGVVSKCSALYGQVLSTSYKWNHLDEINQLLGVRGANAEDPVAVIRTHANRLKLAGV
+2nug_B,221,MKMLEQLEKKLGYTFKDKSLLEKALTHVSYSKKEHYETLEFLGDALVNFFIVDLLVQYSPNKREGFLSPLKAYLISEEFFNLLAQKLELHKFIRIKRGKINETIIGDVFEALWAAVYIDSGRDANFTRELFYKLFKEDILSAIKEGRVKKDYKTILQEITQKRWKERPEYRLISVEGPHHKKKFIVEAKIKEYRTLGEGKSKKEAEQRAAEELIKLLEESE
+3rmq_A,116,SNAMQRYLWQQADGKRHVYDTARHRVQAGRPFTALCGETVTPQTERGDLTAGLWFDGECPVCTIALAKALGWPMREISDLAHRFDWSPALITRLAEVLHCSFGEVVELTGARMVDA
+1rk8_B,147,MSTEDFYLRYYVGHKGKFGHEFLEFEFRPDGKLRYANNSNYKNDTMIRKEAFVHQSVMEELKRIIIDSEIMQEDDLPWPPPDRVGRQELEIVIGDEHISFTTSKTGSLVDVNRSKDPEGLRCFYYLVQDLKCLVFSLIGLHFKIKPI
+1msk_A,331,KKPRTPPVTLEAARDNDFAFDWQAYTPPVAHRLGVQEVEASIETLRNYIDWTPFFMTWSLAGKYPRILEDEVVGVEAQRLFKDANDMLDKLSAEKTLNPRGVVGLFPANRVGDDIEIYRDETRTHVINVSHHLRQQTEKTGFANYCLADFVAPKLSGKADYIGAFAVTGGLEEDALADAFEAQHDDYNKIMVKALADRLAEAFAEYLHERVRKVYWGYAPNENLSNEELIRENYQGIRPAPGYPACPEHTEKATIWELLEVEKHTGMKLTESFAMWPGASVSGWYFSHPDSKYYAVAQIQRDQVEDYARRKGMSVTEVERWLAPNLGYDAD
+2p58_A,65,TQLEEQLHNVETVRSITMQLEMALTKLKKDMMRGGDAKQYQVWQRESKALESAIAIIHYVAGDLK
+3zoq_C,56,MVQNDFVDSYDVTMLLQDDDGKQYYEYHKGLSLSDFEVLYGNTADEIIKLRLDKVL
+1v0w_A,506,ADSATPHLDAVEQTLRQVSPGLEGDVWERTSGNKLDGSAADPSDWLLQTPGCWGDDKCADRVGTKRLLAKMTENIGNATRTVDISTLAPFPNGAFQDAIVAGLKESAAKGNKLKVRILVGAAPVYHMNVIPSKYRDELTAKLGKAAENITLNVASMTTSKTAFSWNHSKILVVDGQSALTGGINSWKDDYLDTTHPVSDVDLALTGPAAGSAGRYLDTLWTWTCQNKSNIASVWFAASGNAGCMPTMHKDTNPKASPATGNVPVIAVGGLGVGIKDVDPKSTFRPDLPTASDTKCVVGLHDNTNADRDYDTVNPEESALRALVASAKGHIEISQQDLNATCPPLPRYDIRLYDALAAKMAAGVKVRIVVSDPANRGAVGSGGYSQIKSLSEISDTLRNRLANITGGQQAAKTAMCSNLQLATFRSSPNGKWADGHPYAQHHKLVSVDSSTFYIGSKNLYPSWLQDFGYIVESPEAAKQLDAKLLDPQWKYSQETATVDYARGICNA
+3ho6_A,267,SEDNGVDFNKNTALDKNYLLNNKIPSNNVEEAGSKNYVHYIIQLQGDDISYEATCNLFSKNPKNSIIIQRNMNESAKSYFLSDDGESILELNKYRIPERLKNKEKVKVTFIGHGKDEFNTSEFARLSVDSLSNEISSFLDTIKLDISPKNVEVNLLGCNMFSYDFNVEETYPGKLLLSIMDKITSTLPDVNKNSITIGANQYEVRINSEGRKELLAHSGKWINKEEAIMSDLSSKEYIFFDSIDNKLKAKSKNIPGLASISEDIKTL
+5opz_A,137,GPLGSLNDIEEIRFTARSEENLRGVHPDLVRVIRLALRYSLVPFSVSEGLRSMARQREMVRAGSSQTLRSRHLTGHAVDVVAMPAGVVSWEWDYYAQIAVAVRRAARECGIIVEWGGEWKTLKDGPHFQLTFRDYPA
+5noh_A,103,ALGVVGVLESYIGSINNITKQSACVAMSKLLTELNSDDIKKLRDNEELNSPKIRVYNTVISYIESNRKNNKQTIHLLKRLPADVLKKTIKNTLDIHKSITINN
+1u84_A,90,SNAMDGQQLNRLLLEWIGAWDPFGLGKDAYDVEAASVLQAVYETEDARTLAARIQSIYEFAFDEPIPFPHCLKLARRLLELKQAASCPLP
+3c8w_C,255,GMDKYLSANSLEGVIDNEFSMPAPRWLNTYPAGPYRFINREFFIIAYETDPDLLQAILPPDMELLEPVVKFEFIRMPDSTGFGDYTESGQVVPVRYKGEEGGFTISMFLDCHAPIAGGREIWGFPKKLAKPKLFVEEDTLIGILKYGSIDIAIATMGYKHRPLDAEKVLESVKKPVFLLKNIPNVDGTPLVNQLTKTYLTDITVKGAWTGPGSLELHPHALAPISNLYIKKIVSVSHFITDLTLPYGKVVADYLA
+4cdp_A,354,MGSSHHHHHHGSMNHYTRWLELKEQNPGKYARDIAGLMNIREAELAFARVTHDAWRMHGDIREILAALESVGETKCICRNEYAVHEQVGTFTNQHLNGHAGLILNPRALDLRLFLNQWASVFHIKENTARGERQSIQFFDHQGDALLKVYATDNTDMAAWSELLARFITDENTPLELKAVDAPVVQTRADATVVEQEWRAMTDVHQFFTLLKRHNLTRQQAFNLVADDLACKVSNSALAQILESAQQDGNEIMVFVGNRGCVQIFTGVVEKVVPMKGWLNIFNPTFTLHLLEESIAEAWVTRKPTSDGYVTSLELFAHDGTQIAQLYGQRTEGEQEQAQWRKQIASLIPEGVAA
+4jmu_A,112,HMGARASVLSGGELDKWEKIRLRPGGKKQYKLKHIVWASRELERFAVNPGLLETSEGCRQILGQLQPSLQTGSEELRSLYNTIAVLYCVHQRIDVKDTKEALDKIEEEQNKS
diff --git a/models/__pycache__/edge_embedder.cpython-310.pyc b/models/__pycache__/edge_embedder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce8d099a4e79827c586c307b3378c588694d6b8e
Binary files /dev/null and b/models/__pycache__/edge_embedder.cpython-310.pyc differ
diff --git a/models/__pycache__/flow_model.cpython-310.pyc b/models/__pycache__/flow_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b0e49814dabcc5ade5c7e7260201151c1ec8730
Binary files /dev/null and b/models/__pycache__/flow_model.cpython-310.pyc differ
diff --git a/models/__pycache__/flow_module.cpython-310.pyc b/models/__pycache__/flow_module.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71f640fb0eaab00144ce0d812d2c8649f0f9b94b
Binary files /dev/null and b/models/__pycache__/flow_module.cpython-310.pyc differ
diff --git a/models/__pycache__/ipa_pytorch.cpython-310.pyc b/models/__pycache__/ipa_pytorch.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ccbc0db01f51cda2d42c5e6d143f452dc5156e8d
Binary files /dev/null and b/models/__pycache__/ipa_pytorch.cpython-310.pyc differ
diff --git a/models/__pycache__/node_embedder.cpython-310.pyc b/models/__pycache__/node_embedder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6cdbbab2253225fba49d91dc456d6059180e5104
Binary files /dev/null and b/models/__pycache__/node_embedder.cpython-310.pyc differ
diff --git a/models/__pycache__/utils.cpython-310.pyc b/models/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3c9552d9aab93864bff0a7e815ad9823448106b
Binary files /dev/null and b/models/__pycache__/utils.cpython-310.pyc differ
diff --git a/models/add_module/Attention_module.py b/models/add_module/Attention_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa16ca025fbd7815c574c1eccf184373b817a43b
--- /dev/null
+++ b/models/add_module/Attention_module.py
@@ -0,0 +1,386 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from opt_einsum import contract as einsum
+# from model.model_utils import init_lecun_normal
+
+class FeedForwardLayer(nn.Module):
+ def __init__(self, d_model, r_ff, p_drop=0.1):
+ super(FeedForwardLayer, self).__init__()
+ self.norm = nn.LayerNorm(d_model)
+ self.linear1 = nn.Linear(d_model, d_model*r_ff)
+ self.dropout = nn.Dropout(p_drop)
+ self.linear2 = nn.Linear(d_model*r_ff, d_model)
+
+ # self.reset_parameter()
+
+ # def reset_parameter(self):
+ # # initialize linear layer right before ReLu: He initializer (kaiming normal)
+ # nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu')
+ # nn.init.zeros_(self.linear1.bias)
+
+ # # initialize linear layer right before residual connection: zero initialize
+ # nn.init.zeros_(self.linear2.weight)
+ # nn.init.zeros_(self.linear2.bias)
+
+ def forward(self, src):
+ src = self.norm(src)
+ src = self.linear2(self.dropout(F.relu_(self.linear1(src))))
+ return src
+
+class Attention(nn.Module):
+ # calculate multi-head attention
+ def __init__(self, d_query, d_key, n_head, d_hidden, d_out):
+ super(Attention, self).__init__()
+ self.h = n_head
+ self.dim = d_hidden
+ #
+ self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False)
+ self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False)
+ self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False)
+ #
+ self.to_out = nn.Linear(n_head*d_hidden, d_out)
+ self.scaling = 1/math.sqrt(d_hidden)
+ #
+ # initialize all parameters properly
+ # self.reset_parameter()
+
+ # def reset_parameter(self):
+ # # query/key/value projection: Glorot uniform / Xavier uniform
+ # nn.init.xavier_uniform_(self.to_q.weight)
+ # nn.init.xavier_uniform_(self.to_k.weight)
+ # nn.init.xavier_uniform_(self.to_v.weight)
+
+ # # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
+ # nn.init.zeros_(self.to_out.weight)
+ # nn.init.zeros_(self.to_out.bias)
+
+ def forward(self, query, key, value):
+ B, Q = query.shape[:2]
+ B, K = key.shape[:2]
+ #
+ query = self.to_q(query).reshape(B, Q, self.h, self.dim)
+ key = self.to_k(key).reshape(B, K, self.h, self.dim)
+ value = self.to_v(value).reshape(B, K, self.h, self.dim)
+ #
+ query = query * self.scaling
+ attn = einsum('bqhd,bkhd->bhqk', query, key)
+ attn = F.softmax(attn, dim=-1)
+ #
+ out = einsum('bhqk,bkhd->bqhd', attn, value)
+ out = out.reshape(B, Q, self.h*self.dim)
+ #
+ out = self.to_out(out)
+
+ return out
+
+class AttentionWithBias(nn.Module):
+ def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32):
+ super(AttentionWithBias, self).__init__()
+ self.norm_in = nn.LayerNorm(d_in)
+ self.norm_bias = nn.LayerNorm(d_bias)
+ #
+ self.to_q = nn.Linear(d_in, n_head*d_hidden, bias=False)
+ self.to_k = nn.Linear(d_in, n_head*d_hidden, bias=False)
+ self.to_v = nn.Linear(d_in, n_head*d_hidden, bias=False)
+ self.to_b = nn.Linear(d_bias, n_head, bias=False)
+ self.to_g = nn.Linear(d_in, n_head*d_hidden)
+ self.to_out = nn.Linear(n_head*d_hidden, d_in)
+
+ self.scaling = 1/math.sqrt(d_hidden)
+ self.h = n_head
+ self.dim = d_hidden
+
+ # self.reset_parameter()
+
+ # def reset_parameter(self):
+ # # query/key/value projection: Glorot uniform / Xavier uniform
+ # nn.init.xavier_uniform_(self.to_q.weight)
+ # nn.init.xavier_uniform_(self.to_k.weight)
+ # nn.init.xavier_uniform_(self.to_v.weight)
+
+ # # bias: normal distribution
+ # self.to_b = init_lecun_normal(self.to_b)
+
+ # # gating: zero weights, one biases (mostly open gate at the begining)
+ # nn.init.zeros_(self.to_g.weight)
+ # nn.init.ones_(self.to_g.bias)
+
+ # # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
+ # nn.init.zeros_(self.to_out.weight)
+ # nn.init.zeros_(self.to_out.bias)
+
+ def forward(self, x, bias):
+ B, L = x.shape[:2]
+ #
+ x = self.norm_in(x)
+ bias = self.norm_bias(bias)
+ #
+ query = self.to_q(x).reshape(B, L, self.h, self.dim)
+ key = self.to_k(x).reshape(B, L, self.h, self.dim)
+ value = self.to_v(x).reshape(B, L, self.h, self.dim)
+ bias = self.to_b(bias) # (B, L, L, h)
+ gate = torch.sigmoid(self.to_g(x))
+ #
+ key = key * self.scaling
+ attn = einsum('bqhd,bkhd->bqkh', query, key)
+ attn = attn + bias
+ attn = F.softmax(attn, dim=-2)
+ #
+ out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1)
+ out = gate * out
+ #
+ out = self.to_out(out)
+ return out
+
+# MSA Attention (row/column) from AlphaFold architecture
+# class SequenceWeight(nn.Module):
+# def __init__(self, d_msa, n_head, d_hidden, p_drop=0.1):
+# super(SequenceWeight, self).__init__()
+# self.h = n_head
+# self.dim = d_hidden
+# self.scale = 1.0 / math.sqrt(self.dim)
+
+# self.to_query = nn.Linear(d_msa, n_head*d_hidden)
+# self.to_key = nn.Linear(d_msa, n_head*d_hidden)
+# self.dropout = nn.Dropout(p_drop)
+
+# # self.reset_parameter()
+
+# # def reset_parameter(self):
+# # # query/key/value projection: Glorot uniform / Xavier uniform
+# # nn.init.xavier_uniform_(self.to_query.weight)
+# # nn.init.xavier_uniform_(self.to_key.weight)
+
+# def forward(self, msa):
+# B, N, L = msa.shape[:3]
+
+# tar_seq = msa[:,0]
+
+# q = self.to_query(tar_seq).view(B, 1, L, self.h, self.dim)
+# k = self.to_key(msa).view(B, N, L, self.h, self.dim)
+
+# q = q * self.scale
+# attn = einsum('bqihd,bkihd->bkihq', q, k)
+# attn = F.softmax(attn, dim=1)
+# return self.dropout(attn)
+
+class RowAttentionWithBias(nn.Module):
+ def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
+ super().__init__()
+ self.norm_msa = nn.LayerNorm(d_msa)
+ self.norm_pair = nn.LayerNorm(d_pair)
+ #
+ # self.seq_weight = SequenceWeight(d_msa, n_head, d_hidden, p_drop=0.1)
+ self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
+ self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
+ self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
+ self.to_b = nn.Linear(d_pair, n_head, bias=False)
+ self.to_g = nn.Linear(d_msa, n_head*d_hidden)
+ self.to_out = nn.Linear(n_head*d_hidden, d_msa)
+
+ self.scaling = 1/math.sqrt(d_hidden)
+ self.h = n_head
+ self.dim = d_hidden
+
+ def forward(self, msa, pair, mask = None): # TODO: make this as tied-attention
+ B, L = msa.shape[:2]
+ #
+ msa = self.norm_msa(msa)
+ pair = self.norm_pair(pair)
+ #
+ # seq_weight = self.seq_weight(msa) # (B, N, L, h, 1)
+ query = self.to_q(msa).reshape(B, L, self.h, self.dim) # (B, L, h, dim)
+ key = self.to_k(msa).reshape(B, L, self.h, self.dim)
+ value = self.to_v(msa).reshape(B, L, self.h, self.dim)
+ bias = self.to_b(pair) # (B, L, L, h)
+ gate = torch.sigmoid(self.to_g(msa))
+ #
+ # query = query * seq_weight.expand(-1, -1, -1, -1, self.dim)
+ key = key * self.scaling
+ attn = einsum('bqhd,bkhd->bqkh', query, key) # (B, L, L, h)
+ attn = attn + bias
+
+ if mask is not None:
+ mask_re = torch.matmul(mask.unsqueeze(2).type(torch.float32), mask.unsqueeze(1).type(torch.float32))[...,None]
+ attn = attn * mask_re - 1e9 * (1-mask_re)
+
+ attn = F.softmax(attn, dim=-2)
+ #
+ out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1)
+ out = gate * out
+ #
+ out = self.to_out(out)
+ return out
+
+class ColAttention(nn.Module):
+ def __init__(self, d_msa=256, n_head=8, d_hidden=32):
+ super().__init__()
+ self.norm_msa = nn.LayerNorm(d_msa)
+ #
+ self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
+ self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
+ self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
+ self.to_g = nn.Linear(d_msa, n_head*d_hidden)
+ self.to_out = nn.Linear(n_head*d_hidden, d_msa)
+
+ self.scaling = 1/math.sqrt(d_hidden)
+ self.h = n_head
+ self.dim = d_hidden
+
+ # self.reset_parameter()
+
+ # def reset_parameter(self):
+ # # query/key/value projection: Glorot uniform / Xavier uniform
+ # nn.init.xavier_uniform_(self.to_q.weight)
+ # nn.init.xavier_uniform_(self.to_k.weight)
+ # nn.init.xavier_uniform_(self.to_v.weight)
+
+ # # gating: zero weights, one biases (mostly open gate at the begining)
+ # nn.init.zeros_(self.to_g.weight)
+ # nn.init.ones_(self.to_g.bias)
+
+ # # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
+ # nn.init.zeros_(self.to_out.weight)
+ # nn.init.zeros_(self.to_out.bias)
+
+ def forward(self, msa, mask = None):
+ '''
+ msa (B,L,d_node)
+ '''
+ B, L = msa.shape[:2]
+ #
+ msa = self.norm_msa(msa)
+ #
+ query = self.to_q(msa).reshape(B, L, self.h, self.dim) # (B,L,H,D)
+ key = self.to_k(msa).reshape(B, L, self.h, self.dim)
+ value = self.to_v(msa).reshape(B, L, self.h, self.dim)
+ gate = torch.sigmoid(self.to_g(msa))
+ #
+ query = query * self.scaling
+ attn = einsum('bqhd,bkhd->bqkh', query, key) # (B,L,L,H)
+
+ if mask is not None:
+ mask_re = torch.matmul(mask.unsqueeze(2).type(torch.float32), mask.unsqueeze(1).type(torch.float32))[...,None]
+ attn = attn * mask_re - 1e9 * (1-mask_re)
+
+ attn = F.softmax(attn, dim=-3) # (B,L,L,H)
+ #
+ out = einsum('bkqh,bkhd->bqhd', attn, value).reshape(B, L, -1) # (B,L,H*D)
+ out = gate * out
+ #
+ out = self.to_out(out)
+ return out
+
+# class MSAColGlobalAttention(nn.Module):
+# def __init__(self, d_msa=64, n_head=8, d_hidden=8):
+# super(MSAColGlobalAttention, self).__init__()
+# self.norm_msa = nn.LayerNorm(d_msa)
+# #
+# self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
+# self.to_k = nn.Linear(d_msa, d_hidden, bias=False)
+# self.to_v = nn.Linear(d_msa, d_hidden, bias=False)
+# self.to_g = nn.Linear(d_msa, n_head*d_hidden)
+# self.to_out = nn.Linear(n_head*d_hidden, d_msa)
+
+# self.scaling = 1/math.sqrt(d_hidden)
+# self.h = n_head
+# self.dim = d_hidden
+
+# self.reset_parameter()
+
+# def reset_parameter(self):
+# # query/key/value projection: Glorot uniform / Xavier uniform
+# nn.init.xavier_uniform_(self.to_q.weight)
+# nn.init.xavier_uniform_(self.to_k.weight)
+# nn.init.xavier_uniform_(self.to_v.weight)
+
+# # gating: zero weights, one biases (mostly open gate at the begining)
+# nn.init.zeros_(self.to_g.weight)
+# nn.init.ones_(self.to_g.bias)
+
+# # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
+# nn.init.zeros_(self.to_out.weight)
+# nn.init.zeros_(self.to_out.bias)
+
+# def forward(self, msa):
+# B, N, L = msa.shape[:3]
+# #
+# msa = self.norm_msa(msa)
+# #
+# query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
+# query = query.mean(dim=1) # (B, L, h, dim)
+# key = self.to_k(msa) # (B, N, L, dim)
+# value = self.to_v(msa) # (B, N, L, dim)
+# gate = torch.sigmoid(self.to_g(msa)) # (B, N, L, h*dim)
+# #
+# query = query * self.scaling
+# attn = einsum('bihd,bkid->bihk', query, key) # (B, L, h, N)
+# attn = F.softmax(attn, dim=-1)
+# #
+# out = einsum('bihk,bkid->bihd', attn, value).reshape(B, 1, L, -1) # (B, 1, L, h*dim)
+# out = gate * out # (B, N, L, h*dim)
+# #
+# out = self.to_out(out)
+# return out
+
+# Instead of triangle attention, use Tied axail attention with bias from coordinates..?
+class BiasedAxialAttention(nn.Module):
+ def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
+ super().__init__()
+ #
+ self.is_row = is_row
+ self.norm_pair = nn.LayerNorm(d_pair)
+ self.norm_bias = nn.LayerNorm(d_bias)
+
+ self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
+ self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
+ self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
+ self.to_b = nn.Linear(d_bias, n_head, bias=False)
+ self.to_g = nn.Linear(d_pair, n_head*d_hidden)
+ self.to_out = nn.Linear(n_head*d_hidden, d_pair)
+
+ self.scaling = 1/math.sqrt(d_hidden)
+ self.h = n_head
+ self.dim = d_hidden
+
+ def forward(self, pair, bias, mask = None):
+ '''
+ pair: (B, L, L, d_pair)
+ mask: (B, L)
+ '''
+
+ B, L = pair.shape[:2]
+
+ if self.is_row:
+ pair = pair.permute(0,2,1,3)
+ bias = bias.permute(0,2,1,3)
+
+ pair = self.norm_pair(pair)
+ bias = self.norm_bias(bias)
+
+ query = self.to_q(pair).reshape(B, L, L, self.h, self.dim) # (B, L, L, h, dim)
+ key = self.to_k(pair).reshape(B, L, L, self.h, self.dim)
+ value = self.to_v(pair).reshape(B, L, L, self.h, self.dim)
+ bias = self.to_b(bias) # (B, L, L, h)
+ gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)
+
+ query = query * self.scaling
+ key = key / math.sqrt(L) # normalize for tied attention
+ attn = einsum('bnihk,bnjhk->bijh', query, key) # tied attention (B, L, L, h)
+ attn = attn + bias # apply bias
+ if mask is not None:
+ mask_temp = 1e-9 * (mask.type(torch.float) - 1) # (B,L)
+ attn = attn + mask_temp.unsqueeze(1).unsqueeze(-1)
+
+ attn = F.softmax(attn, dim=-2) # (B, L, L, h)
+
+ out = einsum('bijh,bkjhd->bikhd', attn, value).reshape(B, L, L, -1) # (B, L, L, h*dim)
+ out = gate * out
+
+ out = self.to_out(out)
+ if self.is_row:
+ out = out.permute(0,2,1,3)
+ return out
+
diff --git a/models/add_module/__pycache__/Attention_module.cpython-310.pyc b/models/add_module/__pycache__/Attention_module.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3e4ee430887eac7fbb527c2fcc543aeca23b727
Binary files /dev/null and b/models/add_module/__pycache__/Attention_module.cpython-310.pyc differ
diff --git a/models/add_module/__pycache__/common.cpython-310.pyc b/models/add_module/__pycache__/common.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46b019925197752bfed21ac6317c5abf2a113ce7
Binary files /dev/null and b/models/add_module/__pycache__/common.cpython-310.pyc differ
diff --git a/models/add_module/__pycache__/egnn.cpython-310.pyc b/models/add_module/__pycache__/egnn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35cf5c55116cdd49176a0a09dce839fea30de646
Binary files /dev/null and b/models/add_module/__pycache__/egnn.cpython-310.pyc differ
diff --git a/models/add_module/__pycache__/model_utils.cpython-310.pyc b/models/add_module/__pycache__/model_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce37eef2966200272a2a4a0e917e51d6050623c5
Binary files /dev/null and b/models/add_module/__pycache__/model_utils.cpython-310.pyc differ
diff --git a/models/add_module/__pycache__/structure_module.cpython-310.pyc b/models/add_module/__pycache__/structure_module.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6a0ff1ccbfeb8ae139d57b4332515000ce518ee
Binary files /dev/null and b/models/add_module/__pycache__/structure_module.cpython-310.pyc differ
diff --git a/models/add_module/common.py b/models/add_module/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ad140d80a60a0d11103cf90c8ab0eab74964ae4
--- /dev/null
+++ b/models/add_module/common.py
@@ -0,0 +1,217 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.nn import knn_graph
+
+
+class GaussianSmearing(nn.Module):
+ def __init__(self, start=0.0, stop=5.0, num_gaussians=50, fixed_offset=True):
+ super(GaussianSmearing, self).__init__()
+ self.start = start
+ self.stop = stop
+ self.num_gaussians = num_gaussians
+ if fixed_offset:
+ # customized offset
+ offset = torch.tensor([0, 1, 1.25, 1.5, 1.75, 2, 2.25, 2.5, 2.75, 3, 3.5, 4, 4.5, 5, 5.5, 6, 7, 8, 9, 10])
+ else:
+ offset = torch.linspace(start, stop, num_gaussians)
+ self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
+ self.register_buffer('offset', offset)
+
+ def __repr__(self):
+ return f'GaussianSmearing(start={self.start}, stop={self.stop}, num_gaussians={self.num_gaussians})'
+
+ def forward(self, dist):
+ dist = dist.view(-1, 1) - self.offset.view(1, -1)
+ return torch.exp(self.coeff * torch.pow(dist, 2))
+
+
+class AngleExpansion(nn.Module):
+ def __init__(self, start=1.0, stop=5.0, half_expansion=10):
+ super(AngleExpansion, self).__init__()
+ l_mul = 1. / torch.linspace(stop, start, half_expansion)
+ r_mul = torch.linspace(start, stop, half_expansion)
+ coeff = torch.cat([l_mul, r_mul], dim=-1)
+ self.register_buffer('coeff', coeff)
+
+ def forward(self, angle):
+ return torch.cos(angle.view(-1, 1) * self.coeff.view(1, -1))
+
+
+class Swish(nn.Module):
+ def __init__(self):
+ super(Swish, self).__init__()
+ self.beta = nn.Parameter(torch.tensor(1.0))
+
+ def forward(self, x):
+ return x * torch.sigmoid(self.beta * x)
+
+
+NONLINEARITIES = {
+ "tanh": nn.Tanh(),
+ "relu": nn.ReLU(),
+ "softplus": nn.Softplus(),
+ "elu": nn.ELU(),
+ "swish": Swish(),
+ 'silu': nn.SiLU()
+}
+
+
+class MLP(nn.Module):
+ """MLP with the same hidden dim across all layers."""
+
+ def __init__(self, in_dim, out_dim, hidden_dim, num_layer=2, norm=True, act_fn='relu', act_last=False):
+ super().__init__()
+ layers = []
+ for layer_idx in range(num_layer):
+ if layer_idx == 0:
+ layers.append(nn.Linear(in_dim, hidden_dim))
+ elif layer_idx == num_layer - 1:
+ layers.append(nn.Linear(hidden_dim, out_dim))
+ else:
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
+ if layer_idx < num_layer - 1 or act_last:
+ if norm:
+ layers.append(nn.LayerNorm(hidden_dim))
+ layers.append(NONLINEARITIES[act_fn])
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def outer_product(*vectors):
+ for index, vector in enumerate(vectors):
+ if index == 0:
+ out = vector.unsqueeze(-1)
+ else:
+ out = out * vector.unsqueeze(1)
+ out = out.view(out.shape[0], -1).unsqueeze(-1)
+ return out.squeeze()
+
+
+def get_h_dist(dist_metric, hi, hj):
+ if dist_metric == 'euclidean':
+ h_dist = torch.sum((hi - hj) ** 2, -1, keepdim=True)
+ return h_dist
+ elif dist_metric == 'cos_sim':
+ hi_norm = torch.norm(hi, p=2, dim=-1, keepdim=True)
+ hj_norm = torch.norm(hj, p=2, dim=-1, keepdim=True)
+ h_dist = torch.sum(hi * hj, -1, keepdim=True) / (hi_norm * hj_norm)
+ return h_dist, hj_norm
+
+
+def get_r_feat(r, r_exp_func, node_type=None, edge_index=None, mode='basic'):
+ if mode == 'origin':
+ r_feat = r
+ elif mode == 'basic':
+ r_feat = r_exp_func(r)
+ elif mode == 'sparse':
+ src, dst = edge_index
+ nt_src = node_type[src] # [n_edges, 8]
+ nt_dst = node_type[dst]
+ r_exp = r_exp_func(r)
+ r_feat = outer_product(nt_src, nt_dst, r_exp)
+ else:
+ raise ValueError(mode)
+ return r_feat
+
+
+def compose_context(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand, mask_protein_side):
+ # previous version has problems when ligand atom types are fixed
+ # (due to sorting randomly in case of same element)
+
+ batch_ctx = torch.cat([batch_protein, batch_ligand], dim=0)
+ # sort_idx = batch_ctx.argsort()
+ sort_idx = torch.sort(batch_ctx, stable=True).indices
+
+ mask_ligand = torch.cat([
+ torch.zeros([batch_protein.size(0)], device=batch_protein.device).bool(),
+ torch.ones([batch_ligand.size(0)], device=batch_ligand.device).bool(),
+ ], dim=0)[sort_idx]
+
+ mask_protein = torch.cat([
+ mask_protein_side,
+ torch.zeros([batch_ligand.size(0)], device=batch_ligand.device).bool(),
+ ], dim=0)[sort_idx]
+
+ batch_ctx = batch_ctx[sort_idx]
+ h_ctx = torch.cat([h_protein, h_ligand], dim=0)[sort_idx] # (N_protein+N_ligand, H)
+ pos_ctx = torch.cat([pos_protein, pos_ligand], dim=0)[sort_idx] # (N_protein+N_ligand, 3)
+
+ return h_ctx, pos_ctx, batch_ctx, mask_ligand, mask_protein, sort_idx
+
+
+def compose_context_prop(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
+ batch_ctx = torch.cat([batch_protein, batch_ligand], dim=0)
+ sort_idx = batch_ctx.argsort()
+
+ mask_protein = torch.cat([
+ torch.ones([batch_protein.size(0)], device=batch_protein.device).bool(),
+ torch.zeros([batch_ligand.size(0)], device=batch_ligand.device).bool(),
+ ], dim=0)[sort_idx]
+
+ batch_ctx = batch_ctx[sort_idx]
+ h_ctx = torch.cat([h_protein, h_ligand], dim=0)[sort_idx] # (N_protein+N_ligand, H)
+ pos_ctx = torch.cat([pos_protein, pos_ligand], dim=0)[sort_idx] # (N_protein+N_ligand, 3)
+
+ return h_ctx, pos_ctx, batch_ctx
+
+
+class ShiftedSoftplus(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.shift = torch.log(torch.tensor(2.0)).item()
+
+ def forward(self, x):
+ return F.softplus(x) - self.shift
+
+
+def hybrid_edge_connection(ligand_pos, protein_pos, k, ligand_index, protein_index):
+ # fully-connected for ligand atoms
+ dst = torch.repeat_interleave(ligand_index, len(ligand_index))
+ src = ligand_index.repeat(len(ligand_index))
+ mask = dst != src
+ dst, src = dst[mask], src[mask]
+ ll_edge_index = torch.stack([src, dst])
+
+ # knn for ligand-protein edges
+ ligand_protein_pos_dist = torch.unsqueeze(ligand_pos, 1) - torch.unsqueeze(protein_pos, 0)
+ ligand_protein_pos_dist = torch.norm(ligand_protein_pos_dist, p=2, dim=-1)
+ knn_p_idx = torch.topk(ligand_protein_pos_dist, k=k, largest=False, dim=1).indices
+ knn_p_idx = protein_index[knn_p_idx]
+ knn_l_idx = torch.unsqueeze(ligand_index, 1)
+ knn_l_idx = knn_l_idx.repeat(1, k)
+ pl_edge_index = torch.stack([knn_p_idx, knn_l_idx], dim=0)
+ pl_edge_index = pl_edge_index.view(2, -1)
+ return ll_edge_index, pl_edge_index
+
+
+def batch_hybrid_edge_connection(x, k, mask_ligand, batch, add_p_index=False):
+ batch_size = batch.max().item() + 1
+ batch_ll_edge_index, batch_pl_edge_index, batch_p_edge_index = [], [], []
+ with torch.no_grad():
+ for i in range(batch_size):
+ ligand_index = ((batch == i) & (mask_ligand == 1)).nonzero()[:, 0]
+ protein_index = ((batch == i) & (mask_ligand == 0)).nonzero()[:, 0]
+ ligand_pos, protein_pos = x[ligand_index], x[protein_index]
+ ll_edge_index, pl_edge_index = hybrid_edge_connection(
+ ligand_pos, protein_pos, k, ligand_index, protein_index)
+ batch_ll_edge_index.append(ll_edge_index)
+ batch_pl_edge_index.append(pl_edge_index)
+ if add_p_index:
+ all_pos = torch.cat([protein_pos, ligand_pos], 0)
+ p_edge_index = knn_graph(all_pos, k=k, flow='source_to_target')
+ p_edge_index = p_edge_index[:, p_edge_index[1] < len(protein_pos)]
+ p_src, p_dst = p_edge_index
+ all_index = torch.cat([protein_index, ligand_index], 0)
+ p_edge_index = torch.stack([all_index[p_src], all_index[p_dst]], 0)
+ batch_p_edge_index.append(p_edge_index)
+
+ if add_p_index:
+ edge_index = [torch.cat([ll, pl, p], -1) for ll, pl, p in zip(
+ batch_ll_edge_index, batch_pl_edge_index, batch_p_edge_index)]
+ else:
+ edge_index = [torch.cat([ll, pl], -1) for ll, pl in zip(batch_ll_edge_index, batch_pl_edge_index)]
+ edge_index = torch.cat(edge_index, -1)
+ return edge_index
diff --git a/models/add_module/egnn.py b/models/add_module/egnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd3577c94fb7045870c9fb417a7777952410e6af
--- /dev/null
+++ b/models/add_module/egnn.py
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_scatter import scatter_sum
+from torch_geometric.nn import radius_graph, knn_graph
+from models.add_module.common import GaussianSmearing, MLP, batch_hybrid_edge_connection, NONLINEARITIES
+
+
+class EnBaseLayer(nn.Module):
+ def __init__(self, hidden_dim, num_r_gaussian, update_x=True, act_fn='silu', norm=False):
+ super().__init__()
+ self.r_min = 0.
+ self.r_max = 10.
+ self.hidden_dim = hidden_dim
+ self.num_r_gaussian = num_r_gaussian
+ self.update_x = update_x
+ self.act_fn = act_fn
+ self.norm = norm
+ if num_r_gaussian > 1:
+ self.distance_expansion = GaussianSmearing(self.r_min, self.r_max, num_gaussians=num_r_gaussian)
+ self.edge_mlp = MLP(2 * hidden_dim + num_r_gaussian, hidden_dim, hidden_dim,
+ num_layer=2, norm=norm, act_fn=act_fn, act_last=True)
+ self.edge_inf = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())
+ if self.update_x:
+ # self.x_mlp = MLP(hidden_dim, 1, hidden_dim, num_layer=2, norm=norm, act_fn=act_fn)
+ x_mlp = [nn.Linear(hidden_dim, hidden_dim), NONLINEARITIES[act_fn]]
+ layer = nn.Linear(hidden_dim, 1, bias=False)
+ torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
+ x_mlp.append(layer)
+ x_mlp.append(nn.Tanh())
+ self.x_mlp = nn.Sequential(*x_mlp)
+
+ self.node_mlp = MLP(2 * hidden_dim, hidden_dim, hidden_dim, num_layer=2, norm=norm, act_fn=act_fn)
+
+ def forward(self, h, x, edge_index, edge_attr=None):
+ src, dst = edge_index
+ hi, hj = h[dst], h[src]
+ # \phi_e in Eq(3)
+ rel_x = x[dst] - x[src]
+ d_sq = torch.sum(rel_x ** 2, -1, keepdim=True)
+ if self.num_r_gaussian > 1:
+ d_feat = self.distance_expansion(torch.sqrt(d_sq + 1e-8))
+ else:
+ d_feat = d_sq
+
+ if edge_attr is not None:
+ edge_feat = torch.cat([d_feat, edge_attr], -1)
+ else:
+ edge_feat = d_feat
+
+ mij = self.edge_mlp(torch.cat([hi, hj, edge_feat], -1))
+ eij = self.edge_inf(mij)
+ mi = scatter_sum(mij * eij, dst, dim=0, dim_size=h.shape[0])
+
+ # h update in Eq(6)
+ h = h + self.node_mlp(torch.cat([mi, h], -1))
+ if self.update_x:
+ # x update in Eq(4)
+ xi, xj = x[dst], x[src]
+ # (xi - xj) / (\|xi - xj\| + C) to make it more stable
+ delta_x = scatter_sum((xi - xj) / (torch.sqrt(d_sq + 1e-8) + 1) * self.x_mlp(mij), dst, dim=0)
+ x = x + delta_x
+
+ return h, x
+
+
+class EGNN(nn.Module):
+ def __init__(self, num_layers=2, hidden_dim=128, num_r_gaussian=20, k=32, cutoff=20.0, cutoff_mode='knn',
+ update_x=True, act_fn='silu', norm=False):
+ super().__init__()
+ # Build the network
+ self.num_layers = num_layers
+ self.hidden_dim = hidden_dim
+ self.num_r_gaussian = num_r_gaussian
+ self.update_x = update_x
+ self.act_fn = act_fn
+ self.norm = norm
+ self.k = k
+ # self.k_seq = 16
+ self.cutoff = cutoff
+ self.cutoff_mode = cutoff_mode
+ self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=num_r_gaussian)
+ self.net = self._build_network()
+
+ def _build_network(self):
+ # Equivariant layers
+ layers = []
+ for l_idx in range(self.num_layers):
+ layer = EnBaseLayer(self.hidden_dim, self.num_r_gaussian,
+ update_x=self.update_x, act_fn=self.act_fn, norm=self.norm)
+ layers.append(layer)
+ return nn.ModuleList(layers)
+
+ def _connect_edge(self, x, batch, orgshape):
+ # if self.cutoff_mode == 'radius':
+ # edge_index = radius_graph(x, r=self.r, batch=batch, flow='source_to_target')
+ if self.cutoff_mode == 'knn':
+ edge_index = knn_graph(x, k=self.k, batch=batch, flow='source_to_target')
+
+ # mask = torch.zeros((x.shape[0],x.shape[0]), dtype = bool) # (B*L,B*L)
+ # mask[edge_index[0], edge_index[1]] = True
+ # B,L = orgshape
+ # i,j = torch.meshgrid(torch.arange(L),torch.arange(L))
+ # mask_seq = (torch.abs(i-j) < self.k_seq) # (L,L)
+ # mask_seq = torch.block_diag(*([mask_seq]*B)) # (B*L,B*L)
+ # mask = torch.logical_or(mask,mask_seq)
+
+ # edge_index = torch.where(mask)
+ # edge_index = (edge_index[0].to(x.device),edge_index[1].to(x.device))
+
+ elif self.cutoff_mode == 'hybrid':
+ pass
+ # edge_index = batch_hybrid_edge_connection(
+ # x, k=self.k, mask_ligand=mask_ligand, batch=batch, add_p_index=True)
+ else:
+ raise ValueError(f'Not supported cutoff mode: {self.cutoff_mode}')
+ return edge_index
+
+ def forward(self, node_embed, trans_t):
+
+ B,L = node_embed.shape[:2]
+
+ x = trans_t.reshape(B*L,-1)
+ h = node_embed.reshape(B*L,-1)
+
+ batch = []
+ for idx in range(B):
+ batch += [idx] * L
+ batch = torch.tensor(batch,device = node_embed.device)
+
+ all_x = [x]
+ all_h = [h]
+ for l_idx, layer in enumerate(self.net):
+ edge_index = self._connect_edge(x, batch, orgshape = (B,L))
+ h, x = layer(h, x, edge_index)
+ all_x.append(x)
+ all_h.append(h)
+ outputs = {'x': x, 'h': h}
+ return x.reshape(B,L,-1), h.reshape(B,L,-1)
diff --git a/models/add_module/model_utils.py b/models/add_module/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6125c28236558f73bc4a894cc8d3a257d4cc7ed
--- /dev/null
+++ b/models/add_module/model_utils.py
@@ -0,0 +1,162 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from opt_einsum import contract as einsum
+import copy
+# from utils.constants import *
+# from utils.structure_utils import rigid_from_3_points
+# from scipy.spatial.transform import Rotation as scipy_R
+
+# def init_lecun_normal(module):
+# def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
+# normal = torch.distributions.normal.Normal(0, 1)
+
+# alpha = (a - mu) / sigma
+# beta = (b - mu) / sigma
+
+# alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
+# p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
+
+# v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
+# x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
+# x = torch.clamp(x, a, b)
+
+# return x
+
+# def sample_truncated_normal(shape):
+# stddev = np.sqrt(1.0/shape[-1])/.87962566103423978 # shape[-1] = fan_in
+# return stddev * truncated_normal(torch.rand(shape))
+
+# module.weight = torch.nn.Parameter( (sample_truncated_normal(module.weight.shape)) )
+# return module
+
+# def init_lecun_normal_param(weight):
+# def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
+# normal = torch.distributions.normal.Normal(0, 1)
+
+# alpha = (a - mu) / sigma
+# beta = (b - mu) / sigma
+
+# alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
+# p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
+
+# v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
+# x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
+# x = torch.clamp(x, a, b)
+
+# return x
+
+# def sample_truncated_normal(shape):
+# stddev = np.sqrt(1.0/shape[-1])/.87962566103423978 # shape[-1] = fan_in
+# return stddev * truncated_normal(torch.rand(shape))
+
+# weight = torch.nn.Parameter( (sample_truncated_normal(weight.shape)) )
+# return weight
+
+# # for gradient checkpointing
+# def create_custom_forward(module, **kwargs):
+# def custom_forward(*inputs):
+# return module(*inputs, **kwargs)
+# return custom_forward
+
+# def get_clones(module, N):
+# return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+class Dropout(nn.Module):
+ # Dropout entire row or column
+ def __init__(self, broadcast_dim=None, p_drop=0.15):
+ super(Dropout, self).__init__()
+ # give ones with probability of 1-p_drop / zeros with p_drop
+ self.sampler = torch.distributions.bernoulli.Bernoulli(torch.tensor([1-p_drop]))
+ self.broadcast_dim=broadcast_dim
+ self.p_drop=p_drop
+ def forward(self, x):
+ if not self.training: # no drophead during evaluation mode
+ return x
+ shape = list(x.shape)
+ if not self.broadcast_dim == None:
+ shape[self.broadcast_dim] = 1
+ mask = self.sampler.sample(shape).to(x.device).view(shape)
+
+ x = mask * x / (1.0 - self.p_drop)
+ return x
+
+def rbf(D, D_min = 0., D_max = 20., D_count = 36):
+ '''
+ D: (B,L,L)
+ '''
+ # Distance radial basis function
+ D_mu = torch.linspace(D_min, D_max, D_count).to(D.device)
+ D_mu = D_mu[None,:]
+ D_sigma = (D_max - D_min) / D_count
+ D_expand = torch.unsqueeze(D, -1)
+ RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2) # (B, L, L, D_count)
+ return RBF
+
+# def get_centrality(D, dist_max):
+ link_mat = (D < dist_max).type(torch.int).to(D.device)
+ link_mat -= torch.eye(link_mat.shape[1], dtype = torch.int).to(D.device)
+ cent_vec = torch.sum(link_mat, dim=-1, dtype = torch.int)
+ return cent_vec, link_mat
+
+def get_sub_graph(xyz, mask = None, dist_max = 20, kmin=32):
+ B, L = xyz.shape[:2]
+ device = xyz.device
+ D = torch.cdist(xyz, xyz) + torch.eye(L, device=device).unsqueeze(0)* 10 * dist_max # (B, L, L)
+ idx = torch.arange(L, device=device)[None,:] # (1, L)
+ seq = (idx[:,None,:] - idx[:,:,None]).abs() # (1, L, L)
+
+ if mask is not None:
+ mask_cross = mask.unsqueeze(2).type(torch.float32)* mask.unsqueeze(1).type(torch.float32) # (B, L, L)
+ D = D + 10 * dist_max * (1 - mask_cross)
+ seq = seq * mask_cross # (B, L, L)
+
+ seq_cond = torch.logical_and(seq > 0, seq < kmin)
+ cond = torch.logical_or(D < dist_max, seq_cond)
+ b,i,j = torch.where(cond) # return idx where is true
+
+ src = b*L+i
+ tgt = b*L+j
+ return src, tgt, (b,i,j)
+
+def scatter_add(src ,index ,dim_index, num_nodes): # sum by the index of src for a source node in the graph
+ out = src.new_zeros(num_nodes ,src.shape[1])
+ index = index.reshape(-1 ,1).expand_as(src)
+ return out.scatter_add_(dim_index, index, src)
+
+# # Load ESM-2 model
+# model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
+# batch_converter = alphabet.get_batch_converter()
+# model.eval() # disables dropout for deterministic results
+
+# def get_pre_repr(seq):
+
+ # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
+ # data = [
+ # ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
+ # ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
+ # ("protein2 with mask","KALTARQQEVFDLIRDISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
+ # ("protein3", "K A I S Q"),
+ # ]
+ seq_string = ''.join([aa_321[num2aa[i]] for i in seq])
+ data = [("protein1", seq_string)]
+ batch_labels, batch_strs, batch_tokens = batch_converter(data)
+ batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
+
+ # Extract per-residue representations (on CPU)
+ with torch.no_grad():
+ results = model(batch_tokens, repr_layers=[33], return_contacts=True)
+
+ node_repr = results["representations"][33][0,1:-1,:]
+ pair_repr = results['attentions'][0,-1,:,1:-1,1:-1].permute(1,2,0)
+ # Generate per-sequence representations via averaging
+ # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
+ # sequence_representations = []
+ # for i, tokens_len in enumerate(batch_lens):
+ # sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0))
+
+ # Look at the unsupervised self-attention map contact predictions
+ # for (_, seq), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):
+ # plt.matshow(attention_contacts[: tokens_len, : tokens_len])
+ return node_repr, pair_repr
\ No newline at end of file
diff --git a/models/add_module/structure_module.py b/models/add_module/structure_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f2560cc277fca326695a02156e87f313797f728
--- /dev/null
+++ b/models/add_module/structure_module.py
@@ -0,0 +1,943 @@
+import torch.nn as nn
+import torch
+from opt_einsum import contract as einsum
+from models.add_module.model_utils import scatter_add, Dropout
+from models.add_module.Attention_module import *
+from torch_scatter import scatter
+from models.utils import get_time_embedding
+import e3nn
+import math
+from typing import Optional, Callable, List, Sequence
+from openfold.utils.rigid_utils import Rigid
+
+
+class TensorProductConvLayer(torch.nn.Module):
+ def __init__(self, in_irreps, sh_irreps, out_irreps, d_pair, dropout=0.1, fc_factor=1, use_layer_norm=True,
+ init='normal',use_internal_weights=False):
+ super().__init__()
+ self.in_irreps = in_irreps
+ self.out_irreps = out_irreps
+ self.sh_irreps = sh_irreps
+ self.fc_factor = fc_factor
+
+ self.use_internal_weights = use_internal_weights
+ if self.use_internal_weights:
+ self.tp = e3nn.o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=True,
+ internal_weights=True)
+ else:
+ self.tp = e3nn.o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=False)
+ self.fc = nn.Sequential(
+ nn.LayerNorm(d_pair) if use_layer_norm else nn.Identity(),
+ nn.Linear(d_pair, self.fc_factor * d_pair),
+ nn.ReLU(),
+ nn.Dropout(dropout),
+ nn.Linear(self.fc_factor * d_pair, self.tp.weight_numel)
+ )
+
+ if init == 'final':
+ with torch.no_grad():
+ for para in self.fc.parameters():
+ para.fill_(0)
+ elif init != 'normal':
+ raise ValueError(f'Unknown init {init}')
+
+ def forward(self, attr1, attr2, edge_attr=None):
+ if self.use_internal_weights:
+ out = self.tp(attr1, attr2)
+ else:
+ out = self.tp(attr1, attr2, self.fc(edge_attr))
+ return out
+
+class E3_GNNlayer(nn.Module):
+ def __init__(self, e3_d_node_l0=32, e3_d_node_l1=8, d_rbf=64, d_node=256, d_pair=128,
+ dist_max=20, final = False, init='normal',dropout=0.1):
+ super().__init__()
+
+ self.dist_max = dist_max
+ self.d_rbf = d_rbf
+ self.e3_d_node_l0 = e3_d_node_l0
+ self.e3_d_node_l1 = e3_d_node_l1
+ self.d_hidden = torch.tensor(e3_d_node_l0 + e3_d_node_l1 * 3)
+
+ irreps_input = e3nn.o3.Irreps(str(e3_d_node_l0)+"x0e + "+str(e3_d_node_l1)+"x1o")
+ # irreps_hid1 = e3nn.o3.Irreps(str(e3_d_node_l0)+"x0e + "+str(e3_d_node_l1)+"x1o + "+str(e3_d_node_l1)+"x1e")
+
+
+ self.final = final
+ if final:
+ # irreps_output = e3nn.o3.Irreps(str(6)+"x0e + "+str(2)+"x1o + "+str(2)+"x1e")
+ # irreps_output = e3nn.o3.Irreps(str(2)+"x1o + "+str(2)+"x1e")
+ # irreps_hid1 = e3nn.o3.Irreps(str(2)+"x1o + "+str(2)+"x1e")
+ # irreps_hid1 = e3nn.o3.Irreps(str(6)+"x0e + "+str(2)+"x1o + "+str(2)+"x1e")
+ # irreps_hid1 = e3nn.o3.Irreps(str(2)+"x1o")
+ irreps_hid1 = e3nn.o3.Irreps(str(6)+"x0e")
+ else:
+ # irreps_output = irreps_input
+ irreps_hid1 = irreps_input
+ self.proj_node = nn.Linear(self.e3_d_node_l0, d_node)
+
+
+ # self.d_hidden_1 = torch.tensor(e3_d_node_l0 + e3_d_node_l1 * 3 + e3_d_node_l1 * 3)
+
+ self.irreps_sh = e3nn.o3.Irreps.spherical_harmonics(2)
+ self.proj_l0 = nn.Linear(d_node, self.e3_d_node_l0)
+
+ self.tpc_conv = TensorProductConvLayer(irreps_input, self.irreps_sh, irreps_hid1, d_pair,init=init, dropout=dropout)
+
+
+ # self.h_hidden = e3nn.o3.Linear(irreps_hid1, irreps_hid1)
+ # self.h_out = e3nn.o3.Linear(irreps_hid1, irreps_output)
+ # self.h_hidden = TensorProductConvLayer(irreps_value, self.irreps_sh, irreps_value, d_pair)
+ # self.h_out = TensorProductConvLayer(irreps_value, self.irreps_sh, irreps_output, d_pair)
+
+ def forward(self, node, pair, l1_feats, pair_index, edge_src, edge_dst,
+ edge_sh, mask = None):
+ '''
+ l1_feats: (B,L,3*self.e3_d_node_l1)
+ xyz: (B,L,3,3)
+ '''
+
+ # create graph
+ B, L = node.shape[:2]
+ num_nodes = B * L
+
+ # (num_nodes, irreps_input)
+ l0_feats = self.proj_l0(node) # (B,L,self.e3_d_node_l0)
+ node_graph_total = torch.concat([l0_feats,l1_feats],dim=-1)
+ node_graph_total = node_graph_total.reshape(num_nodes,-1) # (B*L, e3_d_node_l0 + 3* e3_d_node_l1)
+
+ # xyz_graph = xyz[:,:,1,:].reshape(num_nodes,3) # only for CA
+ # edge_src, edge_dst, edge_feats = get_sub_graph(xyz[:,:,1,:], pair, mask = mask, dist_max=self.dist_max)
+ # edge_vec = xyz_graph[edge_src] - xyz_graph[edge_dst]
+
+ # edge_length = edge_vec.norm(dim=1)
+ # rbf_dist = rbf(edge_length, D_count = self.d_rbf)
+ # edge_feats_total = torch.concat([edge_feats, rbf_dist], dim=-1)
+ b,i,j = pair_index
+ edge_feats = pair[b,i,j]
+ edge_feats_total = edge_feats
+
+ # compute the queries (per node), keys (per edge) and values (per edge)
+ conv_out = self.tpc_conv(node_graph_total[edge_dst], edge_sh, edge_feats_total)
+ # scatter_out = scatter_add(conv_out, edge_dst, dim_index=0, num_nodes=num_nodes) # (num_nodes, irreps_output)
+ out = scatter(conv_out, edge_src, dim=0, dim_size=num_nodes, reduce="mean") # (num_nodes, irreps_output)
+
+ # out = self.h_hidden(out)
+ # out = self.h_out(out)
+ # out = self.h_hidden(value_out, edge_sh, edge_feats_total)
+ # out = self.h_out(out, edge_sh, edge_feats_total)
+ out = out.reshape(B,L,-1)
+
+ if self.final:
+ # return torch.concat([out[:,:,0:3] + out[:,:,6:9] + out[:,:,12:15],
+ # out[:,:,3:6] + out[:,:,9:12]+ out[:,:,15:18]],dim=-1) # (B,L,6)
+
+ # return torch.concat([out[:,:,0:3] + out[:,:,6:9],
+ # out[:,:,3:6] + out[:,:,9:12]],dim=-1) # (B,L,6)
+
+ return torch.concat([out[:,:,0:3], out[:,:,3:6]],dim=-1) # (B,L,6)
+ else:
+ l0_feats = out[:,:,:self.e3_d_node_l0]
+ # l1_feats = out[:,:,self.e3_d_node_l0:]
+ # node = self.proj_node(l0_feats)
+ l1_feats = out[:,:,self.e3_d_node_l0:] + l1_feats
+ node = self.proj_node(l0_feats) + node
+ return node, l1_feats
+
+class E3_transformer(nn.Module):
+ def __init__(self, e3_d_node_l0=16, e3_d_node_l1=4, d_rbf=64, d_node=256, d_pair=128,
+ dist_max=20, final = False):
+ super().__init__()
+
+ self.dist_max = dist_max
+ self.d_rbf = d_rbf
+ self.e3_d_node_l0 = e3_d_node_l0
+ self.e3_d_node_l1 = e3_d_node_l1
+ self.d_hidden = torch.tensor(e3_d_node_l0 + e3_d_node_l1 * 3)
+
+ irreps_input = e3nn.o3.Irreps(str(e3_d_node_l0)+"x0e + "+str(e3_d_node_l1)+"x1o")
+ irreps_hid1 = e3nn.o3.Irreps(str(e3_d_node_l0)+"x0e + "+str(e3_d_node_l1)+"x1o + "+str(e3_d_node_l1)+"x1e")
+ self.d_hidden_1 = torch.tensor(e3_d_node_l0 + e3_d_node_l1 * 3 + e3_d_node_l1 * 3)
+
+ irreps_query = irreps_hid1
+ irreps_key = irreps_hid1
+ irreps_value = irreps_hid1
+ self.irreps_sh = e3nn.o3.Irreps.spherical_harmonics(2)
+
+ self.proj_l0 = nn.Linear(d_node, self.e3_d_node_l0)
+
+ # self.tpc_q = TensorProductConvLayer(irreps_input, self.irreps_sh, irreps_query, d_pair)
+ self.h_q = e3nn.o3.Linear(irreps_input, irreps_query) # same as o3.FullyConnectedTensorProduct(irreps_input, "1x0e", irreps_query, shared_weights=False)
+
+ self.tpc_k = TensorProductConvLayer(irreps_input, self.irreps_sh, irreps_key, d_pair)
+ # self.tp_k = e3nn.o3.FullyConnectedTensorProduct(irreps_input, self.irreps_sh, irreps_key, shared_weights=False)
+ # self.fc_k = e3nn.nn.FullyConnectedNet([d_pair+d_rbf, d_pair+d_rbf, self.tp_k.weight_numel], act=torch.nn.functional.relu)
+
+ self.tpc_v = TensorProductConvLayer(irreps_input, self.irreps_sh, irreps_value, d_pair)
+ # self.tp_v = e3nn.o3.FullyConnectedTensorProduct(irreps_input, self.irreps_sh, irreps_value, shared_weights=False)
+ # self.fc_v = e3nn.nn.FullyConnectedNet([d_pair+d_rbf, d_pair+d_rbf, self.tp_v.weight_numel], act=torch.nn.functional.relu)
+
+ self.dot = e3nn.o3.FullyConnectedTensorProduct(irreps_query, irreps_key, "1x0e")
+
+ self.final = final
+ if final:
+ # irreps_output = e3nn.o3.Irreps(str(6)+"x0e + "+str(2)+"x1o + "+str(2)+"x1e")
+ irreps_output = e3nn.o3.Irreps(str(2)+"x1o + "+str(2)+"x1e")
+ else:
+ irreps_output = irreps_input
+ self.proj_node = nn.Linear(self.e3_d_node_l0, d_node)
+ self.h_hidden = e3nn.o3.Linear(irreps_value, irreps_value)
+ self.h_out = e3nn.o3.Linear(irreps_value, irreps_output)
+ # self.h_hidden = TensorProductConvLayer(irreps_value, self.irreps_sh, irreps_value, d_pair)
+ # self.h_out = TensorProductConvLayer(irreps_value, self.irreps_sh, irreps_output, d_pair)
+
+ def forward(self, node, pair, l1_feats, pair_index, edge_src, edge_dst,
+ edge_sh, mask = None):
+ '''
+ l1_feats: (B,L,3*self.e3_d_node_l1)
+ xyz: (B,L,3,3)
+ '''
+
+ # create graph
+ B, L = node.shape[:2]
+ num_nodes = B * L
+
+ # (num_nodes, irreps_input)
+ l0_feats = self.proj_l0(node) # (B,L,self.e3_d_node_l0)
+ node_graph_total = torch.concat([l0_feats,l1_feats],dim=-1)
+ node_graph_total = node_graph_total.reshape(num_nodes,-1)
+
+ # xyz_graph = xyz[:,:,1,:].reshape(num_nodes,3) # only for CA
+ # edge_src, edge_dst, edge_feats = get_sub_graph(xyz[:,:,1,:], pair, mask = mask, dist_max=self.dist_max)
+ # edge_vec = xyz_graph[edge_src] - xyz_graph[edge_dst]
+
+ # edge_length = edge_vec.norm(dim=1)
+ # rbf_dist = rbf(edge_length, D_count = self.d_rbf)
+ # edge_feats_total = torch.concat([edge_feats, rbf_dist], dim=-1)
+ b,i,j = pair_index
+ edge_feats = pair[b,i,j]
+ edge_feats_total = edge_feats
+
+ # compute the queries (per node), keys (per edge) and values (per edge)
+ # q = self.h_q(node_graph_total)[edge_dst]
+ # k = self.tp_k(node_graph_total[edge_src], edge_sh, self.fc_k(edge_feats_total))
+ # v = self.tp_v(node_graph_total[edge_src], edge_sh, self.fc_v(edge_feats_total))
+
+ q = self.h_q(node_graph_total)[edge_dst]
+ # q = self.tpc_q(node_graph_total[edge_dst], edge_sh, edge_feats_total)
+ k = self.tpc_k(node_graph_total[edge_src], edge_sh, edge_feats_total)
+ v = self.tpc_v(node_graph_total[edge_src], edge_sh, edge_feats_total)
+
+ # compute the softmax (per edge)
+ exp = torch.exp(self.dot(q, k)/self.d_hidden_1.sqrt()) # compute the numerator (num_edges, 1)
+ # z = scatter_add(exp, edge_dst, dim_index=0, num_nodes=num_nodes) # compute the denominator (nodes, 1)
+ z = scatter(exp, edge_dst, dim=0, dim_size=num_nodes, reduce="sum") # compute the denominator (nodes, 1)
+ # z[z == 0] = 1 # to avoid 0/0 when all the neighbors are exactly at the cutoff
+ alpha = exp / (z[edge_dst] + 1e-5) # (num_edges, 1)
+
+ # value_out = scatter_add(alpha * v, edge_dst, dim_index=0, num_nodes=num_nodes) # (nodes, irreps_output)
+ out = scatter(alpha * v, edge_dst, dim=0, dim_size=num_nodes, reduce="mean") # (num_nodes, irreps_output)
+
+ out = self.h_hidden(out)
+ out = self.h_out(out)
+ # out = self.h_hidden(value_out, edge_sh, edge_feats_total)
+ # out = self.h_out(out, edge_sh, edge_feats_total)
+ out = out.reshape(B,L,-1)
+
+ if self.final:
+ # return torch.concat([out[:,:,0:3] + out[:,:,6:9] + out[:,:,12:15],
+ # out[:,:,3:6] + out[:,:,9:12]+ out[:,:,15:18]],dim=-1) # (B,L,6)
+ return torch.concat([out[:,:,0:3] + out[:,:,6:9],
+ out[:,:,3:6] + out[:,:,9:12]],dim=-1) # (B,L,6)
+ else:
+ l0_feats = out[:,:,:self.e3_d_node_l0]
+ l1_feats = out[:,:,self.e3_d_node_l0:] + l1_feats
+ node = self.proj_node(l0_feats) + node
+ return node, l1_feats
+
+ # # /self.d_hidden_1.sqrt()
+ # v = self.tp_hidden(v, edge_sh, self.fc_hidden(edge_feats_total))
+ # v = self.tp_out(v, edge_sh, self.fc_out(edge_feats_total))
+ # value_out = scatter_add(alpha * v, edge_dst, dim_index=0, num_nodes=num_nodes) # (nodes, irreps_output)
+
+ # # print("before_out_max=",torch.max(value_out))
+ # out = self.h_hidden(value_out + self.res_connect(node_graph_total))
+ # out = self.h_out(out)
+ # out = out.reshape(B,L,-1)
+
+ #return out[:,:,6:9] + out[:,:,9:], out[:,:,:3] + out[:,:,3:6]
+
+class E3_transformer_no_adjacency(nn.Module):
+ def __init__(self, e3_d_node_l0=16, e3_d_node_l1=4, d_node=256, d_pair=128, no_heads=8, final = False):
+ super().__init__()
+
+ self.e3_d_node_l0 = e3_d_node_l0
+ self.e3_d_node_l1 = e3_d_node_l1
+ self.no_heads = no_heads
+ self.inf = 1e5
+
+ irreps_input = e3nn.o3.Irreps(str(e3_d_node_l0)+"x0e + "+str(e3_d_node_l1)+"x1o")
+ irreps_hid1 = e3nn.o3.Irreps(str(e3_d_node_l0)+"x0e + "+ str(e3_d_node_l1)+"x1o + "+ str(e3_d_node_l1)+"x1e")
+ self.d_hid1 = irreps_hid1.dim
+
+ irreps_query = irreps_hid1
+ irreps_key = irreps_hid1
+ irreps_value = irreps_hid1
+
+ self.proj_l0 = nn.Linear(d_node, self.e3_d_node_l0)
+
+ self.h_q = e3nn.o3.Linear(irreps_input, irreps_query*self.no_heads) # same as o3.FullyConnectedTensorProduct(irreps_input, "1x0e", irreps_query, shared_weights=False)
+
+ # self.h_k = e3nn.o3.Linear(irreps_input, irreps_key*self.no_heads)
+ # self.h_v = e3nn.o3.Linear(irreps_input, irreps_value*self.no_heads)
+ self.tpc_k = TensorProductConvLayer(irreps_input, irreps_input, irreps_key*self.no_heads, d_pair,
+ use_internal_weights=True)
+ self.tpc_v = TensorProductConvLayer(irreps_input, irreps_input, irreps_value*self.no_heads, d_pair,
+ use_internal_weights=True)
+
+ # self.dot = e3nn.o3.FullyConnectedTensorProduct(irreps_query, irreps_key, "1x0e")
+ self.linear_pair = nn.Linear(d_pair, self.no_heads)
+
+ self.final = final
+ if final:
+ # irreps_output = e3nn.o3.Irreps(str(6)+"x0e + "+str(2)+"x1o + "+str(2)+"x1e")
+ irreps_output = e3nn.o3.Irreps(str(2)+"x1o + "+str(2)+"x1e")
+ else:
+ irreps_output = irreps_input
+ self.proj_node = nn.Linear(self.e3_d_node_l0, d_node)
+
+ self.h_out = e3nn.o3.Linear(irreps_value*self.no_heads, irreps_output)
+ # self.h_out = TensorProductConvLayer(irreps_value*self.no_heads, self.irreps_sh, irreps_output, d_pair)
+
+ def forward(self, node, pair, l1_feats, mask = None):
+ '''
+ l1_feats: (B,L,3*self.e3_d_node_l1)
+ xyz: (B,L,3,3)
+ '''
+ B, L = node.shape[:2]
+
+ # (num_nodes, irreps_input)
+ l0_feats = self.proj_l0(node) # (B,L,self.e3_d_node_l0)
+ node_graph_total = torch.concat([l0_feats,l1_feats],dim=-1) # (B,L,d_hid1)
+
+ q = self.h_q(node_graph_total) # (B,L,d_hid1*no_heads)
+ q = torch.split(q, q.shape[-1] // self.no_heads, dim=-1)
+ q = torch.stack(q, dim=-1) # (B,L,d_hid1,no_heads)
+
+ # k = self.h_k(node_graph_total)
+ k = self.tpc_k(node_graph_total, node_graph_total)
+ k = torch.split(k, k.shape[-1] // self.no_heads, dim=-1)
+ k = torch.stack(k, dim=-1) # (B,L,d_hid1,no_heads)
+
+ # v = self.h_v(node_graph_total)
+ v = self.tpc_v(node_graph_total, node_graph_total)
+ v = torch.split(v, v.shape[-1] // self.no_heads, dim=-1)
+ v = torch.stack(v, dim=-1) # (B,L,d_hid1,no_heads)
+
+ # compute the softmax (per edge)
+ a = q.unsqueeze(-3) - k.unsqueeze(-4) # (B,L,L,d_hid1,no_heads)
+ a = torch.sum(a**2, dim=-2)/math.sqrt(3*self.d_hid1) # (B,L,L,no_heads) invariable
+ a = a + self.linear_pair(pair)
+
+ if mask is not None:
+ square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) # (B,L,L)
+ square_mask = self.inf * (square_mask - 1)
+ a = a + square_mask.unsqueeze(-1)
+
+ a = torch.softmax(a, dim=-2) # (B,L,L,no_heads)
+ out = torch.matmul(
+ a.permute(0,3,1,2), # (B,no_heads,L,L)
+ v.permute(0,3,1,2), # (B,no_heads,L,d_hid1)
+ ) # (B,no_heads,L,d_hid1)
+ out = out.permute(0,2,1,3) # (B,L,no_heads,d_hid1)
+ out = out.reshape(B,L,-1) # (B,L,d_hid1*no_heads)
+
+ out = self.h_out(out)
+
+ if self.final:
+ # return torch.concat([out[:,:,0:3] + out[:,:,6:9] + out[:,:,12:15],
+ # out[:,:,3:6] + out[:,:,9:12]+ out[:,:,15:18]],dim=-1) # (B,L,6)
+ return torch.concat([out[:,:,0:3] + out[:,:,6:9],
+ out[:,:,3:6] + out[:,:,9:12]],dim=-1) # (B,L,6)
+ else:
+ l0_feats = out[:,:,:self.e3_d_node_l0]
+ l1_feats = out[:,:,self.e3_d_node_l0:] + l1_feats
+ node = self.proj_node(l0_feats) + node
+ return node, l1_feats
+
+def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
+ zero_index = -1 * len(inds)
+ first_inds = list(range(len(tensor.shape[:zero_index])))
+ return tensor.permute(first_inds + [zero_index + i for i in inds])
+
+def flatten_final_dims(t: torch.Tensor, no_dims: int):
+ return t.reshape(t.shape[:-no_dims] + (-1,))
+
+class E3_transformer_test(nn.Module):
+
+ def __init__(self,ipa_conf,inf: float = 1e5, eps: float = 1e-8,
+ e3_d_node_l1=4):
+
+ super().__init__()
+
+ self._ipa_conf = ipa_conf
+ self.c_s = ipa_conf.c_s
+ self.c_z = ipa_conf.c_z
+ self.c_hidden = ipa_conf.c_hidden
+ self.no_heads = ipa_conf.no_heads
+ self.no_qk_points = ipa_conf.no_qk_points
+ self.no_v_points = ipa_conf.no_v_points
+ self.inf = inf
+ self.eps = eps
+
+ # These linear layers differ from their specifications in the
+ # supplement. There, they lack bias and use Glorot initialization.
+ # Here as in the official source, they have bias and use the default
+ # Lecun initialization.
+ hc = self.c_hidden * self.no_heads
+ self.linear_q = nn.Linear(self.c_s, hc)
+ self.linear_kv = nn.Linear(self.c_s, 2 * hc)
+
+ hpq = self.no_heads * self.no_qk_points * 3
+ self.hpq = hpq//3
+ self.linear_q_points = nn.Linear(self.c_s, hpq)
+
+ hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
+ self.hpkv = hpkv//3
+ self.linear_kv_points = nn.Linear(self.c_s, hpkv)
+
+ self.linear_b = nn.Linear(self.c_z, self.no_heads)
+ self.down_z = nn.Linear(self.c_z, self.c_z // 4)
+
+ # self.head_weights = nn.Parameter(torch.zeros((ipa_conf.no_heads)))
+ # ipa_point_weights_init_(self.head_weights)
+
+ concat_out_dim = (
+ self.c_z // 4 + self.c_hidden + self.no_v_points * 4
+ )
+ self.linear_out = nn.Linear(self.no_heads * concat_out_dim, self.c_s)
+
+ self.softmax = nn.Softmax(dim=-1)
+ self.softplus = nn.Softplus()
+
+
+
+ self.use_e3_qkvo = False
+
+ irreps_3 = e3nn.o3.Irreps(str(e3_d_node_l1)+"x1o")
+ self.l1_feats_proj = e3nn.o3.Linear(e3nn.o3.Irreps(str(self.no_heads*self.no_v_points)+"x1o"),
+ irreps_3)
+ if self.use_e3_qkvo:
+ irreps_1 = e3nn.o3.Irreps("3x0e")
+ irreps_2 = e3nn.o3.Irreps("1x1o")
+ self.tp_q_pts = TensorProductConvLayer(irreps_1, irreps_3, irreps_2, 3,fc_factor=4)
+ # self.tp_q_pts = e3nn.o3.FullyConnectedTensorProduct(irreps_1, irreps_3, irreps_2, shared_weights=True,
+ # internal_weights=True)
+
+ self.tp_kv_pts = TensorProductConvLayer(irreps_1, irreps_3, irreps_2, 3,fc_factor=4)
+ # self.tp_kv_pts = e3nn.o3.FullyConnectedTensorProduct(irreps_1, irreps_3, irreps_2, shared_weights=True,
+ # internal_weights=True)
+
+ self.tp_o_pt_invert = TensorProductConvLayer(irreps_2, irreps_3, irreps_1, 1,fc_factor=4)
+ # self.tp_o_pt_invert = e3nn.o3.FullyConnectedTensorProduct(irreps_2, irreps_3, irreps_1, shared_weights=True,
+ # internal_weights=True)
+
+ def forward(
+ self,
+ s: torch.Tensor,
+ z: Optional[torch.Tensor],
+ l1_feats,
+ r: Rigid,
+ mask: torch.Tensor,
+ _offload_inference: bool = False,
+ _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ s:
+ [*, N_res, C_s] single representation
+ z:
+ [*, N_res, N_res, C_z] pair representation
+ r:
+ [*, N_res] transformation object
+ mask:
+ [*, N_res] mask
+ Returns:
+ [*, N_res, C_s] single representation update
+ """
+ if _offload_inference:
+ z = _z_reference_list
+ else:
+ z = [z]
+
+ #######################################
+ # Generate scalar and point activations
+ #######################################
+ # [*, N_res, H * C_hidden]
+ s_org = s
+
+ q = self.linear_q(s)
+ kv = self.linear_kv(s)
+
+ # [*, N_res, H, C_hidden]
+ q = q.view(q.shape[:-1] + (self.no_heads, -1))
+
+ # [*, N_res, H, 2 * C_hidden]
+ kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
+
+ # [*, N_res, H, C_hidden]
+ k, v = torch.split(kv, self.c_hidden, dim=-1)
+
+ # [*, N_res, H * P_q * 3]
+ q_pts = self.linear_q_points(s)
+
+ # This is kind of clunky, but it's how the original does it
+ # [*, N_res, H * P_q, 3]
+ q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
+ q_pts = torch.stack(q_pts, dim=-1)
+ if self.use_e3_qkvo:
+ q_pts = self.tp_q_pts(q_pts, l1_feats[:,:,None,:].repeat(1,1,self.hpq,1),q_pts) # [*, N_res, H * P_q, 3]
+ # q_pts = self.tp_q_pts(q_pts, l1_feats[:,:,None,:].repeat(1,1,self.hpq,1)) # [*, N_res, H * P_q, 3]
+ else:
+ q_pts = r[..., None].apply(q_pts)
+
+ # [*, N_res, H, P_q, 3]
+ q_pts = q_pts.view(
+ q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
+ )
+
+ # [*, N_res, H * (P_q + P_v) * 3]
+ kv_pts = self.linear_kv_points(s)
+
+ # [*, N_res, H * (P_q + P_v), 3]
+ kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
+ kv_pts = torch.stack(kv_pts, dim=-1)
+ if self.use_e3_qkvo:
+ kv_pts = self.tp_kv_pts(kv_pts, l1_feats[:,:,None,:].repeat(1,1,self.hpkv,1),kv_pts) # [*, N_res, H * (P_q + P_v), 3]
+ # kv_pts = self.tp_kv_pts(kv_pts, l1_feats[:,:,None,:].repeat(1,1,self.hpkv,1)) # [*, N_res, H * (P_q + P_v), 3]
+ else:
+ kv_pts = r[..., None].apply(kv_pts)
+
+ # [*, N_res, H, (P_q + P_v), 3]
+ kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
+
+ # [*, N_res, H, P_q/P_v, 3]
+ k_pts, v_pts = torch.split(
+ kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
+ )
+
+ ##########################
+ # Compute attention scores
+ ##########################
+ # [*, N_res, N_res, H]
+ b = self.linear_b(z[0])
+
+ if(_offload_inference):
+ z[0] = z[0].cpu()
+
+ # [*, H, N_res, N_res]
+ a = torch.matmul(
+ math.sqrt(1.0 / (3 * self.c_hidden)) *
+ permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
+ permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
+ )
+ # a *= math.sqrt(1.0 / (3 * self.c_hidden))
+ a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
+
+ # [*, N_res, N_res, H, P_q, 3]
+ pt_displacement = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
+ pt_att = pt_displacement ** 2
+
+ # [*, N_res, N_res, H, P_q]
+ pt_att = sum(torch.unbind(pt_att, dim=-1))
+ # head_weights = self.softplus(self.head_weights).view(
+ # *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
+ # )
+ # head_weights = head_weights * math.sqrt(
+ # 1.0 / (3 * (self.no_qk_points * 9.0 / 2))
+ # )
+ # pt_att = pt_att * head_weights
+
+ # [*, N_res, N_res, H]
+ pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
+ # [*, N_res, N_res]
+ square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
+ square_mask = self.inf * (square_mask - 1)
+
+ # [*, H, N_res, N_res]
+ pt_att = permute_final_dims(pt_att, (2, 0, 1))
+
+ a = a + pt_att
+ a = a + square_mask.unsqueeze(-3)
+ a = self.softmax(a)
+
+ ################
+ # Compute output
+ ################
+ # [*, N_res, H, C_hidden]
+ o = torch.matmul(
+ a, v.transpose(-2, -3)
+ ).transpose(-2, -3)
+
+ # [*, N_res, H * C_hidden]
+ o = flatten_final_dims(o, 2)
+
+ # [*, H, 3, N_res, P_v]
+ o_pt = torch.sum(
+ (
+ a[..., None, :, :, None]
+ * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
+ ),
+ dim=-2,
+ )
+
+ # [*, N_res, H, P_v, 3]
+ o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
+
+ o_pt_l1 = o_pt.reshape(*o_pt.shape[:-3], -1) # [*, N_res, H * P_v * 3]
+ l1_feats = l1_feats + self.l1_feats_proj(o_pt_l1)
+
+ if self.use_e3_qkvo:
+ o_pt = self.tp_o_pt_invert(o_pt,l1_feats[:,:,None,None,:].repeat(
+ 1,1,self.no_heads, self.no_v_points,1),
+ torch.sum(o_pt ** 2, dim=-1,keepdim=True))
+ # o_pt = self.tp_o_pt_invert(o_pt,l1_feats[:,:,None,None,:].repeat(
+ # 1,1,self.no_heads, self.no_v_points,1))
+ else:
+ o_pt = r[..., None, None].invert_apply(o_pt)
+
+ # [*, N_res, H * P_v]
+ o_pt_dists = torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps)
+ o_pt_norm_feats = flatten_final_dims(
+ o_pt_dists, 2)
+
+ # [*, N_res, H * P_v, 3]
+ o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
+
+ if(_offload_inference):
+ z[0] = z[0].to(o_pt.device)
+
+ # [*, N_res, H, C_z // 4]
+ pair_z = self.down_z(z[0])
+ o_pair = torch.matmul(a.transpose(-2, -3), pair_z)
+
+ # [*, N_res, H * C_z // 4]
+ o_pair = flatten_final_dims(o_pair, 2)
+
+ o_feats = [o, *torch.unbind(o_pt, dim=-1), o_pt_norm_feats, o_pair]
+
+ # [*, N_res, C_s]
+ s = self.linear_out(
+ torch.cat(
+ o_feats, dim=-1
+ )
+ )
+
+ s = s + s_org
+
+ return s, l1_feats
+
+class Energy_Adapter_Node(nn.Module):
+ def __init__(self, d_node=256, n_head=8, p_drop=0.1, ff_factor=1):
+ super().__init__()
+
+ # self.energy_proj = nn.Linear(1, d_node)
+ self.d_node = d_node
+ self.tfmr_layer = torch.nn.TransformerDecoderLayer(
+ d_model=d_node,
+ nhead=n_head,
+ dim_feedforward=ff_factor*d_node,
+ dropout=p_drop,
+ batch_first=True,
+ norm_first=False
+ )
+
+ def forward(self, node, energy, mask = None):
+ '''
+ energy: (B,)
+ '''
+ # energy_emb = self.energy_proj(energy = energy.unsqueeze(-1).unsqueeze(-1)) # (B,1,d_node)
+ energy_emb = get_time_embedding(
+ energy,
+ self.d_node,
+ max_positions=2056
+ )[:,None,:] # (B,1,d_node)
+
+ if mask is not None:
+ mask = (1 - mask).bool()
+ node = self.tfmr_layer(node, energy_emb, tgt_key_padding_mask = mask)
+
+ return node
+
+# class EMA():
+# def __init__(self, model, decay):
+# super().__init__()
+# self.model = model
+# self.decay = decay
+# self.device = None
+# self.shadow = {}
+# for name, param in self.model.named_parameters():
+# if param.requires_grad:
+# self.shadow[name] = param.data.clone().detach()
+
+# def to(self, device):
+# self.shadow = {name: param.to(device) for name, param in self.shadow.items()}
+# self.device = device
+
+# def update(self):
+# for name, param in self.model.named_parameters():
+# if param.requires_grad:
+# assert name in self.shadow
+# new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
+# self.shadow[name] = new_average.clone()
+
+# def apply_shadow(self):
+# for name, param in self.model.named_parameters():
+# if param.requires_grad:
+# assert name in self.shadow
+# param.data.copy_(self.shadow[name])
+# self.model.load_state_dict(self.shadow)
+
+class Node_update(nn.Module):
+ def __init__(self, d_msa=256, d_pair=128, n_head=8, p_drop=0.1, ff_factor=1):
+ super().__init__()
+ self.d_hidden = d_msa // n_head
+ self.ff_factor = ff_factor
+ self.norm_node = nn.LayerNorm(d_msa)
+ self.norm_pair = nn.LayerNorm(d_pair)
+
+ self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
+ self.row_attn = RowAttentionWithBias(d_msa=d_msa, d_pair=d_pair,
+ n_head=n_head, d_hidden=self.d_hidden)
+
+ self.col_attn = ColAttention(d_msa=d_msa, n_head=n_head, d_hidden=self.d_hidden)
+ self.ff = FeedForwardLayer(d_msa, self.ff_factor, p_drop=p_drop)
+
+ def forward(self, msa, pair, mask = None):
+ '''
+ Inputs:
+ - msa: MSA feature (B, L, d_msa)
+ - pair: Pair feature (B, L, L, d_pair)
+ - rbf_feat: Ca-Ca distance feature calculated from xyz coordinates (B, L, L, 36)
+ - xyz: xyz coordinates (B, L, n_atom, 3)
+ Output:
+ - msa: Updated MSA feature (B, L, d_msa)
+ '''
+ B, L = msa.shape[:2]
+
+ # prepare input bias feature by combining pair & coordinate info
+ msa = self.norm_node(msa)
+ pair = self.norm_pair(pair)
+
+ msa = msa + self.drop_row(self.row_attn(msa, pair, mask=mask))
+ msa = msa + self.col_attn(msa, mask=mask)
+ msa = msa + self.ff(msa)
+
+ return msa
+
+# class Node2Node(nn.Module):
+# def __init__(self, d_node=256, d_pair=128, n_head = 8, dist_max=40,
+# d_rbf=64, SE3_param=None, p_drop=0.1):
+# super().__init__()
+
+# # initial node & pair feature process
+# d_node_l0 = d_node // 2
+
+
+# self.dist_max = dist_max
+# self.norm_node = nn.LayerNorm(d_node_l0)
+# self.proj_node = nn.Linear(d_node, d_node_l0)
+# self.norm_pair = nn.LayerNorm(d_pair)
+# self.proj_pair = nn.Linear(d_pair, d_pair)
+
+# self.d_node_l0 = d_node_l0
+# self.d_node = d_node
+# self.proj_node_t2 = nn.Linear(d_node_l0, d_node)
+# self.node2node_temp = Node2Node_temp(d_msa=d_node, d_pair=d_pair,
+# n_head=n_head, d_hidden=d_node//n_head, p_drop=p_drop)
+
+# self.E3 = E3_transformer(e3_d_node_l0=d_node_l0, e3_d_node_l1=48, d_rbf=d_rbf,
+# d_pair=d_pair, dist_max = dist_max, keep_dim = True)
+
+# self.alpha_pred = AlphaPred(d_node=d_node,p_drop=p_drop)
+# # @torch.cuda.amp.autocast(enabled=False)
+# def forward(self, node, pair, l1_feats, xyz, mask = None):
+# """
+# input:
+# msa: node embeddings (B, L, d_node)
+# pair: residue pair embeddings (B, L, L, d_pair)
+# xyz: initial BB coordinates (B, L, 3, 3)
+# R_in: (B,L,3,3)
+# """
+# B, L = node.shape[:2]
+# node = self.node2node_temp(node, pair, mask = mask)
+
+# node = self.norm_node(self.proj_node(node))
+# pair = self.norm_pair(self.proj_pair(pair))
+
+# out = self.E3(node, pair, l1_feats, xyz, mask = mask)
+
+# node = self.proj_node_t2(out[...,:self.d_node_l0])
+
+# # l1_feats = l1_feats + out[...,self.d_node_l0:].reshape(B*L,48,3)
+# l1_feats = out[...,self.d_node_l0:].reshape(B*L,48,3)
+
+# alpha = self.alpha_pred(node, mask=mask)
+# return node, l1_feats, alpha
+
+class Node2Pair_bias(nn.Module):
+ def __init__(self, d_node=256, d_hidden=32, d_pair=128, d_head = 16):
+ super().__init__()
+ self.d_hidden = d_hidden
+ self.d_head = d_head
+ self.norm = nn.LayerNorm(d_node)
+ self.proj_left = nn.Linear(d_node, d_hidden)
+ self.proj_right = nn.Linear(d_node, d_hidden)
+ self.proj_out = nn.Linear(d_head, d_pair)
+
+ # self.reset_parameter()
+
+ # def reset_parameter(self):
+ # # normal initialization
+ # self.proj_left = init_lecun_normal(self.proj_left)
+ # self.proj_right = init_lecun_normal(self.proj_right)
+ # nn.init.zeros_(self.proj_left.bias)
+ # nn.init.zeros_(self.proj_right.bias)
+
+ # # zero initialize output
+ # nn.init.zeros_(self.proj_out.weight)
+ # nn.init.zeros_(self.proj_out.bias)
+
+ def forward(self, node, mask = None):
+ B, L = node.shape[:2]
+ node = self.norm(node)
+ if mask is not None:
+ mask_re = mask[...,None].type(torch.float32)
+ node = node * mask_re
+
+ left = self.proj_left(node).reshape(B, L, self.d_head, -1)
+ right = self.proj_right(node)
+ right = (right / np.sqrt(self.d_hidden)).reshape(B, L, self.d_head, -1)
+ out = einsum('bihk,bjhk->bijh', left, right) # (B, L, L, d_head)
+ out = self.proj_out(out)
+ return out
+
+class Pair_update(nn.Module):
+ def __init__(self, d_node=256, d_pair=128, n_head=8, p_drop=0.1, ff_factor=1):
+ super().__init__()
+ self.d_hidden = d_pair//n_head
+ self.d_pair_bias = d_pair
+ self.d_node = d_node
+ self.ff_factor = ff_factor
+
+ self.proj_bias = nn.Linear(2*self.d_node, self.d_pair_bias)
+
+ self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop) # # Dropout entire row or column for broadcast_dim
+ self.drop_col = Dropout(broadcast_dim=2, p_drop=p_drop)
+
+ self.row_attn = BiasedAxialAttention(d_pair, self.d_pair_bias, n_head, self.d_hidden, p_drop=p_drop, is_row=True)
+ self.col_attn = BiasedAxialAttention(d_pair, self.d_pair_bias, n_head, self.d_hidden, p_drop=p_drop, is_row=False)
+
+ self.ff = FeedForwardLayer(d_pair, self.ff_factor, p_drop=p_drop)
+
+ def forward(self, node, pair, mask = None):
+
+ B, L = pair.shape[:2]
+ pair_bias_total = torch.cat([
+ torch.tile(node[:, :, None, :], (1, 1, L, 1)),
+ torch.tile(node[:, None, :, :], (1, L, 1, 1)),
+ ], axis=-1) # (B, L, L, 2*d_node)
+ pair_bias_total = self.proj_bias(pair_bias_total) # (B, L, L, d_pair_bias)
+
+ if mask is not None:
+ mask_re = torch.matmul(mask.unsqueeze(2).type(torch.float32), mask.unsqueeze(1).type(torch.float32))[...,None]
+ pair = pair * mask_re
+ pair_bias_total = pair_bias_total * mask_re
+
+ pair = pair + self.drop_row(self.row_attn(pair, pair_bias_total, mask))
+ pair = pair + self.drop_col(self.col_attn(pair, pair_bias_total, mask))
+ pair = pair + self.ff(pair)
+
+ return pair
+
+class FeatBlock(nn.Module):
+ def __init__(self, d_node=256, d_pair=128, dist_max = 40, T=500, d_time=64, L_max=64,
+ d_cent=36, d_rbf=36, p_drop=0.1):
+ super().__init__()
+ self.dist_max = dist_max
+ self.d_rbf = d_rbf
+ self.d_time = d_time
+ # self.time_emb_layer = nn.Embedding(T + 1, d_time)
+ # self.cent_emb_layer = nn.Embedding(L_max + 1, d_cent)
+ # self.cent_norm = nn.LayerNorm(d_cent)
+ # self.link_emb_layer = nn.Embedding(2, d_cent)
+ self.seq_emb_layer = nn.Embedding(22,d_node)
+ self.node_in_proj = nn.Sequential(
+ nn.Linear(d_node+d_node, d_node),
+ nn.ReLU(),
+ nn.Dropout(p_drop),
+ nn.Linear(d_node, d_node)
+ )
+
+ self.pair_in_proj = nn.Sequential(
+ nn.Linear(d_pair+d_rbf, d_pair),
+ nn.ReLU(),
+ nn.Dropout(p_drop),
+ nn.Linear(d_pair, d_pair)
+ )
+
+ def forward(self, node, pair, seq, xyz, t_emb, mask = None):
+ '''
+ input:
+ t_emb: timestep (B, d_time)
+ mask: (B,L)
+ '''
+ B,L = xyz.shape[:2]
+
+ # N, Ca, C = xyz[:, :, 0, :], xyz[:, :, 1, :], xyz[:, :, 2, :] # [batch, L, 3]
+ # R_true, __ = rigid_from_3_points(N, Ca, C)
+ # q_rotate = torch.tensor(scipy_R.from_matrix(R_true.cpu().numpy().reshape(-1,3,3)).as_rotvec()).reshape(R_true.shape[:3]).to(R_true.device).type(torch.float32) # [batch, L, 3]
+ # q_norm = torch.linalg.norm(q_rotate, dim=-1) # [batch, L]
+ # q_node_emb = []
+ # for i_d in range(1, self.d_rbf//2 + 1):
+ # q_node_emb = q_node_emb + [torch.sin(i_d * q_norm)[...,None],torch.cos(i_d * q_norm)[...,None]]
+ # q_node_emb = torch.concat(q_node_emb,dim = -1) # [batch, L, d_rbf]
+
+ # xyz_norm = torch.linalg.norm(xyz[:,:,1], dim=-1) # [batch, L]
+ # xyz_node_emb = []
+ # for i_d in range(1, self.d_rbf//2 + 1):
+ # xyz_node_emb = xyz_node_emb + [torch.sin(i_d * xyz_norm)[...,None],torch.cos(i_d * xyz_norm)[...,None]]
+ # xyz_node_emb = torch.concat(xyz_node_emb,dim = -1) # [batch, L, d_rbf]
+
+ seq_emb = self.seq_emb_layer(seq) # (B, L, d_seq)
+
+ # centrality, link_mat = get_centrality(dist_mat, self.dist_max) # (B, L)
+ # cent_emb = self.cent_norm(self.cent_emb_layer(centrality)) # (B, L, d_cent)
+ # link_mat_emb = self.link_emb_layer(link_mat) # (B, L, L, d_cent)
+
+ # t_emb1 = t_emb[:, None, :].repeat(1, node.shape[1], 1) # (B, L, d_time)
+ # print(node.shape,seq_emb.shape,t_emb1.shape,q_node_emb.shape,xyz_node_emb.shape)
+ node = torch.concat([node, seq_emb], dim = -1)
+ node = self.node_in_proj(node) # (B, L, d_node)
+
+ dist_mat = torch.cdist(xyz[:,:,1,:], xyz[:,:,1,:]) # (B, L, L)
+ rbf_dist = rbf(dist_mat, D_count = self.d_rbf) # (B, L, L, d_rbf)
+
+ # R_pair = torch.matmul(R_true[:,:,None,...], R_true[:,None,...].transpose(-1,-2)) # (B,L,L,3,3)
+ # q_pair = torch.tensor(scipy_R.from_matrix(R_pair.cpu().numpy().reshape(-1,3,3)).as_rotvec()).reshape(R_pair.shape[:4]).to(R_pair.device).type(torch.float32) # [batch, L, L, 3]
+ # q_pair_emb = []
+ # for i_d in range(1, self.d_rbf//2 + 1):
+ # q_pair_emb = q_pair_emb + [torch.sin(i_d * q_pair),torch.cos(i_d * q_pair)]
+ # q_pair_emb = torch.concat(q_pair_emb,dim = -1) # [batch, L, L, 3 * d_rbf]
+
+ # t_emb2 = t_emb[:, None, None, :].repeat(1, pair.shape[1], pair.shape[2], 1) # (B, L, L, d_time)
+ pair = torch.concat([pair, rbf_dist], dim = -1)
+ pair = self.pair_in_proj(pair) # (B, L, L, d_pair)
+
+ if mask is not None:
+ mask = mask.type(torch.float32) # (B, L)
+ node = node * mask[..., None] # (B, L, d_node)
+ mask_cross = torch.matmul(mask.unsqueeze(2), mask.unsqueeze(1)) # (B, L, L)
+ pair = pair * mask_cross[...,None] # (B, L, L, d_pair)
+
+ return node, pair, rbf_dist # node: (B, L, d_node) pair: (B, L, L, d_pair)
diff --git a/models/edge_embedder.py b/models/edge_embedder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e1fd41e2f02c98583b183479780e4a104d9b996
--- /dev/null
+++ b/models/edge_embedder.py
@@ -0,0 +1,80 @@
+import torch
+from torch import nn
+
+from models.utils import get_index_embedding, calc_distogram
+from models.add_module.model_utils import rbf
+
+class EdgeEmbedder(nn.Module):
+
+ def __init__(self, module_cfg):
+ super(EdgeEmbedder, self).__init__()
+ self._cfg = module_cfg
+
+ self.c_s = self._cfg.c_s
+ self.c_p = self._cfg.c_p
+ self.feat_dim = self._cfg.feat_dim
+
+ self.linear_s_p = nn.Linear(self.c_s, self.feat_dim)
+ self.linear_relpos = nn.Linear(self.feat_dim, self.feat_dim)
+
+ self.num_cross_heads = 32
+ self.c_pair_pre = 20
+ total_edge_feats = self.feat_dim * 3 + self._cfg.num_bins * 2 + self.c_pair_pre
+ # total_edge_feats = self.num_cross_heads + self._cfg.num_bins * 2 + self.c_pair_pre
+ self.edge_embedder = nn.Sequential(
+ nn.Linear(total_edge_feats, self.c_p),
+ nn.ReLU(),
+ nn.Dropout(self._cfg.dropout),
+ nn.Linear(self.c_p, self.c_p),
+ )
+
+ def embed_relpos(self, pos):
+ rel_pos = pos[:, :, None] - pos[:, None, :]
+ pos_emb = get_index_embedding(rel_pos, self._cfg.feat_dim, max_len=2056)
+ return self.linear_relpos(pos_emb)
+
+ def _cross_concat(self, feats_1d, num_batch, num_res):
+ '''
+ output: (B, L, L, 2*d_node)
+
+ '''
+ return torch.cat([
+ torch.tile(feats_1d[:, :, None, :], (1, 1, num_res, 1)),
+ torch.tile(feats_1d[:, None, :, :], (1, num_res, 1, 1)),
+ ], dim=-1).float().reshape([num_batch, num_res, num_res, -1])
+
+ def forward(self, s, t, sc_t, pair_repr_pre, p_mask):
+ '''
+ s: same as node, (B, L, d_node)
+
+ '''
+ num_batch, num_res, d_node = s.shape
+ p_i = self.linear_s_p(s) # (B,L,feat_dim)
+ cross_node_feats = self._cross_concat(p_i, num_batch, num_res)
+
+ pos = torch.arange(
+ num_res, device=s.device).unsqueeze(0).repeat(num_batch, 1)
+ relpos_feats = self.embed_relpos(pos)
+ # node_split_heads = s.reshape(num_batch, num_res, d_node//self.num_cross_heads, self.num_cross_heads) # (B,L,d_node//num_head,num_head)
+ # cross_node_feats =torch.einsum('bijh,bkjh->bikh', node_split_heads, node_split_heads) # (B,L,L,num_head)
+
+ pos = t
+ dists_2d = torch.linalg.norm(
+ pos[:, :, None, :] - pos[:, None, :, :], axis=-1) # (B,L,L)
+
+ dist_feats = rbf(dists_2d, D_min = 0., D_max = self._cfg.max_dist, D_count = self._cfg.num_bins)
+ # dist_feats = rbf(dists_2d, D_min = 0., D_max = 20.0, D_count = self._cfg.num_bins)
+
+ pos = sc_t
+ dists_2d = torch.linalg.norm(
+ pos[:, :, None, :] - pos[:, None, :, :], axis=-1) # (B,L,L)
+
+ sc_feats = rbf(dists_2d, D_min = 0., D_max = self._cfg.max_dist, D_count = self._cfg.num_bins)
+ # sc_feats = rbf(dists_2d, D_min = 0., D_max = 20.0, D_count = self._cfg.num_bins)
+
+ all_edge_feats = torch.concat(
+ [cross_node_feats, relpos_feats, dist_feats, sc_feats, pair_repr_pre], dim=-1)
+ # all_edge_feats = torch.concat(
+ # [cross_node_feats, dist_feats, sc_feats, pair_repr_pre], dim=-1)
+ edge_feats = self.edge_embedder(all_edge_feats) # (B,L,L,c_p)
+ return edge_feats
\ No newline at end of file
diff --git a/models/flow_model.py b/models/flow_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e16a4549d5531a7c9b23826479f6f9e35643f80d
--- /dev/null
+++ b/models/flow_model.py
@@ -0,0 +1,357 @@
+"""Neural network architecture for the flow model."""
+import torch
+from torch import nn
+from models.node_embedder import NodeEmbedder
+from models.edge_embedder import EdgeEmbedder
+from models.add_module.structure_module import Node_update, Pair_update, E3_transformer, TensorProductConvLayer, E3_GNNlayer, E3_transformer_no_adjacency, E3_transformer_test, Energy_Adapter_Node
+from models.add_module.egnn import EGNN
+from models.add_module.model_utils import get_sub_graph
+from models import ipa_pytorch
+from data import utils as du
+from data import all_atom
+import e3nn
+
+class FlowModel(nn.Module):
+
+ def __init__(self, model_conf):
+ super(FlowModel, self).__init__()
+
+ self._model_conf = model_conf
+ self._ipa_conf = model_conf.ipa
+ self.rigids_ang_to_nm = lambda x: x.apply_trans_fn(lambda x: x * du.ANG_TO_NM_SCALE)
+ self.rigids_nm_to_ang = lambda x: x.apply_trans_fn(lambda x: x * du.NM_TO_ANG_SCALE)
+ self.node_embedder = NodeEmbedder(model_conf.node_features)
+ self.edge_embedder = EdgeEmbedder(model_conf.edge_features)
+ # aatype_total = 21
+ # self.aatype_pred_layer = nn.Sequential(
+ # nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),
+ # nn.ReLU(),
+ # nn.Dropout(self._model_conf.dropout),
+ # nn.Linear(self._ipa_conf.c_s, aatype_total)
+ # )
+
+ # self.use_np_update = False
+ # self.use_e3_transformer = False
+ # self.use_torsions = True
+ # self.use_mid_bb_update = False
+ # self.use_mid_bb_update_e3 = False
+ # self.use_adapter_node = True
+
+
+ self.use_torsions = model_conf.use_torsions
+ self.use_adapter_node = model_conf.use_adapter_node
+ self.use_mid_bb_update = model_conf.use_mid_bb_update
+ self.use_mid_bb_update_e3 = model_conf.use_mid_bb_update_e3
+ self.use_e3_transformer = model_conf.use_e3_transformer
+
+
+ if self.use_adapter_node:
+ self.energy_adapter = Energy_Adapter_Node(d_node=model_conf.node_embed_size, n_head=model_conf.ipa.no_heads, p_drop=model_conf.dropout)
+
+ if self.use_torsions:
+ self.num_torsions = 7
+ self.torsions_pred_layer1 = nn.Sequential(
+ nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),
+ nn.ReLU(),
+ nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),
+ )
+ self.torsions_pred_layer2 = nn.Linear(self._ipa_conf.c_s, self.num_torsions * 2)
+
+ if self.use_e3_transformer or self.use_mid_bb_update_e3:
+ self.max_dist = self._model_conf.edge_features.max_dist
+ self.irreps_sh = e3nn.o3.Irreps.spherical_harmonics(2) #
+ input_d_l1 = 3
+ e3_d_node_l1 = 32
+ irreps_l1_in = e3nn.o3.Irreps(str(input_d_l1)+"x1o")
+ irreps_l1_out = e3nn.o3.Irreps(str(e3_d_node_l1)+"x1o")
+
+ self.proj_l1_init = e3nn.o3.Linear(irreps_l1_in, irreps_l1_out)
+
+ # Attention trunk
+ self.trunk = nn.ModuleDict()
+ for b in range(self._ipa_conf.num_blocks):
+ if self.use_e3_transformer:
+ # self.trunk[f'ipa_{b}'] = E3_GNNlayer(e3_d_node_l0 = 4*e3_d_node_l1,
+ # e3_d_node_l1 = e3_d_node_l1)
+ # self.trunk[f'ipa_{b}'] = E3_transformer(e3_d_node_l0 = 4*e3_d_node_l1, e3_d_node_l1 = e3_d_node_l1,
+ # d_node=model_conf.node_embed_size, d_pair=model_conf.edge_embed_size,
+ # final = False)
+ # self.trunk[f'ipa_{b}'] = E3_transformer_no_adjacency(e3_d_node_l0=4*e3_d_node_l1, e3_d_node_l1=e3_d_node_l1,
+ # d_node=model_conf.node_embed_size, d_pair=model_conf.edge_embed_size,
+ # no_heads=16, final = False)
+ self.trunk[f'ipa_{b}'] = E3_transformer_test(self._ipa_conf,e3_d_node_l1 = e3_d_node_l1)
+ else:
+ self.trunk[f'ipa_{b}'] = ipa_pytorch.InvariantPointAttention(self._ipa_conf)
+ self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(self._ipa_conf.c_s)
+
+ if self.use_adapter_node:
+ self.trunk[f'energy_adapter_{b}'] = Energy_Adapter_Node(d_node=model_conf.node_embed_size, n_head=model_conf.ipa.no_heads, p_drop=model_conf.dropout)
+
+
+ self.trunk[f'egnn_{b}'] = EGNN(hidden_dim = model_conf.node_embed_size)
+
+
+ # if self.use_np_update:
+ # self.trunk[f'node_transition_{b}'] = Node_update(d_msa=model_conf.node_embed_size, d_pair=model_conf.edge_embed_size)
+ # else:
+ # tfmr_in = self._ipa_conf.c_s
+ # tfmr_layer = torch.nn.TransformerEncoderLayer(
+ # d_model=tfmr_in,
+ # nhead=self._ipa_conf.seq_tfmr_num_heads,
+ # dim_feedforward=tfmr_in,
+ # batch_first=True,
+ # dropout=self._model_conf.dropout,
+ # norm_first=False
+ # )
+ # self.trunk[f'seq_tfmr_{b}'] = torch.nn.TransformerEncoder(
+ # tfmr_layer, self._ipa_conf.seq_tfmr_num_layers, enable_nested_tensor=False)
+ # self.trunk[f'post_tfmr_{b}'] = ipa_pytorch.Linear(
+ # tfmr_in, self._ipa_conf.c_s, init="final")
+
+ # self.trunk[f'node_transition_{b}'] = ipa_pytorch.StructureModuleTransition(
+ # c=self._ipa_conf.c_s)
+
+ if self.use_mid_bb_update:
+ if self.use_mid_bb_update_e3:
+ self.trunk[f'bb_update_{b}'] = E3_GNNlayer(e3_d_node_l0 = 4*e3_d_node_l1,
+ e3_d_node_l1 = e3_d_node_l1, final = True,
+ init='final',dropout=0.0)
+ else:
+ self.trunk[f'bb_update_{b}'] = ipa_pytorch.BackboneUpdate(
+ self._ipa_conf.c_s, use_rot_updates=True, dropout=0.0)
+
+ # if b < self._ipa_conf.num_blocks-1:
+ # # No edge update on the last block.
+ # if self.use_np_update:
+ # self.trunk[f'edge_transition_{b}'] = Pair_update(d_node=model_conf.node_embed_size, d_pair=model_conf.edge_embed_size)
+ # else:
+ # edge_in = self._model_conf.edge_embed_size
+ # self.trunk[f'edge_transition_{b}'] = ipa_pytorch.EdgeTransition(
+ # node_embed_size=self._ipa_conf.c_s,
+ # edge_embed_in=edge_in,
+ # edge_embed_out=self._model_conf.edge_embed_size,
+ # )
+
+
+
+
+ if self.use_mid_bb_update_e3:
+ # self.bb_update_layer = E3_transformer(e3_d_node_l0 = 4*e3_d_node_l1, e3_d_node_l1 = e3_d_node_l1,
+ # d_node=model_conf.node_embed_size, d_pair=model_conf.edge_embed_size,
+ # final = True)
+ self.bb_update_layer = E3_GNNlayer(e3_d_node_l0 = 4*e3_d_node_l1,
+ e3_d_node_l1 = e3_d_node_l1, final = True,
+ init='final',dropout=0.0)
+ elif self.use_e3_transformer:
+
+ irreps_1 = e3nn.o3.Irreps(str(self._ipa_conf.c_s)+"x0e + "+str(e3_d_node_l1)+"x1o")
+ irreps_2 = e3nn.o3.Irreps(str(self._ipa_conf.c_s)+"x0e + "+str(e3_d_node_l1)+"x1o")
+ # irreps_3 = e3nn.o3.Irreps("2x1o + 2x1e")
+ irreps_3 = e3nn.o3.Irreps("6x0e")
+ self.bb_tpc = TensorProductConvLayer(irreps_1, irreps_2, irreps_3, self._ipa_conf.c_s, fc_factor=1,
+ init='final', dropout=0.1)
+
+ # self.bb_update_layer = ipa_pytorch.BackboneUpdate(
+ # self._ipa_conf.c_s+3*e3_d_node_l1, use_rot_updates=True, dropout=0.0)
+
+ else:
+ self.bb_update_layer = ipa_pytorch.BackboneUpdate(
+ self._ipa_conf.c_s, use_rot_updates=True, dropout=0.0)
+
+ def forward(self, input_feats, use_mask_aatype = False):
+ '''
+ note: B and L are changing during training
+ input_feats.keys():
+ 'aatype' (B,L)
+ 'res_mask' (B,L)
+ 't' (B,1)
+ 'trans_1' (B,L,3)
+ 'rotmats_1' (B,L,3,3)
+ 'trans_t' (B,L,3)
+ 'rotmats_t' (B,L,3,3)
+ '''
+
+ node_mask = input_feats['res_mask']
+ edge_mask = node_mask[:, None] * node_mask[:, :, None]
+ continuous_t = input_feats['t']
+ trans_t = input_feats['trans_t']
+ rotmats_t = input_feats['rotmats_t']
+ aatype = input_feats['aatype']
+
+ if self.use_adapter_node:
+ energy = input_feats['energy'] # (B,)
+
+
+ if self.use_e3_transformer or self.use_mid_bb_update_e3:
+ xyz_t = all_atom.to_atom37(trans_t, rotmats_t)[:, :, :3, :] # (B,L,3,3)
+ edge_src, edge_dst, pair_index = get_sub_graph(xyz_t[:,:,1,:], mask = node_mask,
+ dist_max=self.max_dist, kmin=32)
+ xyz_graph = xyz_t[:,:,1,:].reshape(-1,3)
+ edge_vec = xyz_graph[edge_src] - xyz_graph[edge_dst]
+ edge_sh = e3nn.o3.spherical_harmonics(self.irreps_sh, edge_vec,
+ normalize=True, normalization='component') # (num_edges, irreps_sh)
+
+ # if use_mask_aatype:
+ # factor = 0.15
+ # mask_token = 21
+ # mask_aatype = torch.ones(aatype.shape)
+ # for b in range(aatype.shape[0]):
+ # mask_aa_num = torch.tensor(random.random()*factor*aatype.shape[1]).int()
+ # indices = random.sample(range(aatype.shape[1]), mask_aa_num)
+ # mask_aatype[b,indices] = 0
+ # mask_aatype = mask_aatype.to(aatype.device)
+ # aatype = (aatype * mask_aatype + (1-mask_aatype) * mask_token).int()
+
+ node_repr_pre = input_feats['node_repr_pre']
+ pair_repr_pre = input_feats['pair_repr_pre']
+
+ if 'trans_sc' not in input_feats:
+ trans_sc = torch.zeros_like(trans_t)
+ else:
+ trans_sc = input_feats['trans_sc']
+
+ # Initialize node and edge embeddings, there is only seq_pos info in node, and there is bond-len info in edge
+ init_node_embed = self.node_embedder(continuous_t, aatype, node_repr_pre, node_mask)
+ init_node_embed = init_node_embed * node_mask[..., None]
+
+ if self.use_adapter_node:
+ init_node_embed = self.energy_adapter(init_node_embed, energy, mask = node_mask)
+ init_node_embed = init_node_embed * node_mask[..., None]
+
+ init_edge_embed = self.edge_embedder(
+ init_node_embed, trans_t, trans_sc, pair_repr_pre, edge_mask)
+ init_edge_embed = init_edge_embed * edge_mask[..., None]
+
+ if self.use_e3_transformer or self.use_mid_bb_update_e3:
+ l1_feats = torch.concat([xyz_t[:,:,0,:]-xyz_t[:,:,1,:],
+ xyz_t[:,:,1,:],
+ xyz_t[:,:,2,:]-xyz_t[:,:,1,:]], dim=-1) # (B,L,3*3)
+
+ l1_feats = self.proj_l1_init(l1_feats) # (B,L,4*3)
+
+ l1_feats = l1_feats * node_mask[..., None]
+
+ # Initial rigids
+ curr_rigids = du.create_rigid(rotmats_t, trans_t,)
+
+ # Main trunk
+ curr_rigids = self.rigids_ang_to_nm(curr_rigids)
+
+ node_embed = init_node_embed
+ edge_embed = init_edge_embed
+ for b in range(self._ipa_conf.num_blocks):
+
+ if self.use_e3_transformer:
+ # ipa_embed, l1_feats = self.trunk[f'ipa_{b}'](node_embed, edge_embed, l1_feats, pair_index,
+ # edge_src, edge_dst, edge_sh, mask = node_mask)
+ # l1_feats = l1_feats * node_mask[..., None]
+ # ipa_embed, l1_feats = self.trunk[f'ipa_{b}'](node_embed,
+ # edge_embed,
+ # l1_feats,
+ # mask = node_mask)
+ # l1_feats = l1_feats * node_mask[..., None]
+ ipa_embed, l1_feats = self.trunk[f'ipa_{b}'](node_embed,
+ edge_embed,
+ l1_feats,
+ curr_rigids,
+ node_mask) # (B,L,d_node)
+ l1_feats = l1_feats * node_mask[..., None]
+ else:
+ ipa_embed = self.trunk[f'ipa_{b}'](node_embed,
+ edge_embed,
+ curr_rigids,
+ node_mask) # (B,L,d_node)
+ ipa_embed = node_embed + ipa_embed
+
+ ipa_embed = ipa_embed * node_mask[..., None]
+ node_embed = self.trunk[f'ipa_ln_{b}'](ipa_embed)
+
+ if self.use_adapter_node:
+ node_embed = self.trunk[f'energy_adapter_{b}'](node_embed, energy, mask = node_mask)
+ node_embed = node_embed * node_mask[..., None]
+
+
+ __, node_embed = self.trunk[f'egnn_{b}'](node_embed, trans_t)
+ node_embed = node_embed * node_mask[..., None]
+
+ # trans_t_new, node_embed = self.trunk[f'egnn_{b}'](node_embed, trans_t)
+ # rotmats_t_new = torch.zeros(trans_t_new.shape, device = trans_t_new.device) # (B,L,3)
+ # rigid_update = torch.concat([trans_t_new,rotmats_t_new], dim=-1) # (B,L,6)
+ # node_embed = node_embed * node_mask[..., None]
+ # curr_rigids = curr_rigids.compose_q_update_vec(rigid_update, node_mask[..., None])
+
+
+
+ # if self.use_np_update:
+ # node_embed = self.trunk[f'node_transition_{b}'](node_embed, edge_embed, node_mask)
+ # else:
+ # seq_tfmr_out = self.trunk[f'seq_tfmr_{b}'](
+ # node_embed, src_key_padding_mask=(1 - node_mask).bool())
+ # #node_embed = init_node_embed + self.trunk[f'post_tfmr_{b}'](seq_tfmr_out)
+ # node_embed = node_embed + self.trunk[f'post_tfmr_{b}'](seq_tfmr_out)
+ # node_embed = self.trunk[f'node_transition_{b}'](node_embed)
+ # node_embed = node_embed * node_mask[..., None]
+
+ if self.use_mid_bb_update:
+ if self.use_mid_bb_update_e3:
+ rigid_update = self.trunk[f'bb_update_{b}'](node_embed, edge_embed, l1_feats, pair_index,
+ edge_src, edge_dst, edge_sh, mask = node_mask)
+ elif self.use_e3_transformer:
+ rigid_update = self.trunk[f'bb_update_{b}'](torch.concat([node_embed,l1_feats],dim=-1))
+ else:
+ rigid_update = self.trunk[f'bb_update_{b}'](node_embed)
+ curr_rigids = curr_rigids.compose_q_update_vec(
+ rigid_update, node_mask[..., None])
+
+ # if b < self._ipa_conf.num_blocks-1:
+ # if self.use_np_update:
+ # edge_embed = self.trunk[f'edge_transition_{b}'](
+ # node_embed, edge_embed, node_mask)
+ # else:
+ # edge_embed = self.trunk[f'edge_transition_{b}'](
+ # node_embed, edge_embed)
+ # edge_embed *= edge_mask[..., None]
+
+
+ if self.use_mid_bb_update_e3:
+ rigid_update = self.bb_update_layer(node_embed, edge_embed, l1_feats, pair_index,
+ edge_src, edge_dst, edge_sh, mask = node_mask)
+ # print('\nl1_feats:','equ='+str(torch.mean(l1_feats[:5])),'inv='+str(torch.mean(torch.linalg.norm(l1_feats[:5],dim=-1))))
+
+ elif self.use_e3_transformer:
+ node_feats_total = torch.concat([node_embed,l1_feats],dim=-1)
+ rigid_update = self.bb_tpc(node_feats_total, node_feats_total, node_embed)
+
+ else:
+ rigid_update = self.bb_update_layer(node_embed)
+
+ curr_rigids = curr_rigids.compose_q_update_vec(
+ rigid_update, node_mask[..., None])
+
+ curr_rigids = self.rigids_nm_to_ang(curr_rigids)
+ pred_trans = curr_rigids.get_trans()
+ # pred_trans = pred_trans - torch.mean(pred_trans, dim=-2, keepdims=True)
+ pred_rotmats = curr_rigids.get_rots().get_rot_mats()
+ # pred_aatype = self.aatype_pred_layer(node_embed)
+
+
+ if self.use_torsions:
+ pred_torsions = node_embed + self.torsions_pred_layer1(node_embed)
+ pred_torsions = self.torsions_pred_layer2(pred_torsions).reshape(input_feats['aatype'].shape+(self.num_torsions,2)) # (B,L,self.num_torsions,2)
+
+ norm_torsions = torch.sqrt(torch.sum(pred_torsions ** 2, dim=-1, keepdim=True)) # (B,L,self.num_torsions,1)
+ pred_torsions = pred_torsions / norm_torsions # (B,L,self.num_torsions,2)
+
+ add_rot = pred_torsions.new_zeros((1,) * len(pred_torsions.shape[:-2])+(8-self.num_torsions,2)) # (1,1,8-self.num_torsions,2)
+ add_rot[..., 1] = 1
+ add_rot = add_rot.expand(*pred_torsions.shape[:-2], -1, -1) # (B,L,8-self.num_torsions,2)
+ pred_torsions_with_CB = torch.concat([add_rot, pred_torsions],dim=-2) # (B,L,8,2)
+
+ # aatype # (B,L)
+
+ return {
+ 'pred_trans': pred_trans,
+ 'pred_rotmats': pred_rotmats,
+ # 'pred_aatype': pred_aatype
+ 'pred_torsions_with_CB': pred_torsions_with_CB,
+ }
diff --git a/models/flow_module.py b/models/flow_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..1862670984b82589caae2ae63d058a2ae2b86221
--- /dev/null
+++ b/models/flow_module.py
@@ -0,0 +1,475 @@
+from typing import Any
+import torch
+import time
+import os
+import random
+import wandb
+import numpy as np
+import pandas as pd
+import logging
+from pytorch_lightning import LightningModule
+from analysis import metrics
+from analysis import utils as au
+from models.flow_model import FlowModel
+from models import utils as mu
+from data.interpolant import Interpolant
+from data import utils as du
+from data import all_atom
+from data import so3_utils
+from experiments import utils as eu
+from data import residue_constants
+from data.residue_constants import order2restype_with_mask
+from pytorch_lightning.loggers.wandb import WandbLogger
+from openfold.utils.exponential_moving_average import ExponentialMovingAverage as EMA
+from openfold.utils.tensor_utils import (
+ tensor_tree_map,
+)
+
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
+class FlowModule(LightningModule):
+
+ def __init__(self, cfg, folding_cfg=None):
+ super().__init__()
+ self._print_logger = logging.getLogger(__name__)
+ self._exp_cfg = cfg.experiment
+ self._model_cfg = cfg.model
+ self._data_cfg = cfg.data
+ self._interpolant_cfg = cfg.interpolant
+
+ # Set-up vector field prediction model
+ self.model = FlowModel(cfg.model)
+
+ # self.ema = EMA(self.model, decay=0.99)
+
+ # Set-up interpolant
+ self.interpolant = Interpolant(cfg.interpolant)
+
+ self._sample_write_dir = self._exp_cfg.checkpointer.dirpath
+ os.makedirs(self._sample_write_dir, exist_ok=True)
+
+ self.validation_epoch_metrics = []
+ self.validation_epoch_samples = []
+ self.save_hyperparameters()
+
+ def on_train_start(self):
+ self._epoch_start_time = time.time()
+
+ def on_train_epoch_end(self):
+ epoch_time = (time.time() - self._epoch_start_time) / 60.0
+ self.log(
+ 'train/epoch_time_minutes',
+ epoch_time,
+ on_step=False,
+ on_epoch=True,
+ prog_bar=False
+ )
+ self._epoch_start_time = time.time()
+
+ def model_step(self, noisy_batch: Any, use_mask_aatype = True, eps=1e-8):
+ training_cfg = self._exp_cfg.training
+ loss_mask = noisy_batch['res_mask']
+ if training_cfg.min_plddt_mask is not None:
+ plddt_mask = noisy_batch['res_plddt'] > training_cfg.min_plddt_mask
+ loss_mask *= plddt_mask
+ num_batch, num_res = loss_mask.shape
+
+ # Ground truth labels
+ gt_trans_1 = noisy_batch['trans_1']
+ gt_rotmats_1 = noisy_batch['rotmats_1']
+ rotmats_t = noisy_batch['rotmats_t']
+ gt_rot_vf = so3_utils.calc_rot_vf(
+ rotmats_t, gt_rotmats_1.type(torch.float32))
+
+ gt_bb_atoms_total_result = all_atom.to_atom37(gt_trans_1, gt_rotmats_1, aatype=noisy_batch['aatype'], get_mask=True)
+ gt_bb_atoms = gt_bb_atoms_total_result[0][:, :, :3, :]
+ atom37_mask = gt_bb_atoms_total_result[1]
+
+ gt_atoms = noisy_batch['all_atom_positions'] # (B, L, 37, 3)
+
+ # Timestep used for normalization.
+ t = noisy_batch['t']
+ norm_scale = 1 - torch.min(
+ t[..., None], torch.tensor(training_cfg.t_normalize_clip))
+
+ # Model output predictions.
+ model_output = self.model(noisy_batch, use_mask_aatype=use_mask_aatype)
+ pred_trans_1 = model_output['pred_trans']
+ pred_rotmats_1 = model_output['pred_rotmats']
+ pred_torsions_with_CB = model_output['pred_torsions_with_CB']
+ pred_rots_vf = so3_utils.calc_rot_vf(rotmats_t, pred_rotmats_1)
+
+ # aatype loss
+ # pred_aatype_1 = model_output['pred_aatype'] # (B,L,21)
+ # pred_aatype_1_logits = torch.log(torch.softmax(pred_aatype_1,dim=-1)) # (B,L,21)
+ # aatype_onehot = torch.nn.functional.one_hot(noisy_batch['aatype'],num_classes=21) # (B,L,21)
+ # aatype_loss = -torch.sum(torch.mul(pred_aatype_1_logits, aatype_onehot), dim=-1) * loss_mask # (B,L)
+ # aatype_loss = torch.sum(aatype_loss, dim=-1) / torch.sum(loss_mask, dim=-1)
+ # aatype_loss = training_cfg.aatype_loss_weight * aatype_loss # (B,)
+
+ # Translation VF loss
+ trans_error = (gt_trans_1 - pred_trans_1) / norm_scale
+ loss_denom = torch.sum(loss_mask, dim=-1) * 3 #
+ trans_loss = training_cfg.translation_loss_weight * torch.sum(
+ trans_error ** 2 * loss_mask[..., None],
+ dim=(-1, -2)
+ ) / loss_denom
+ # trans_loss = torch.zeros(num_batch,device=pred_trans_1.device,dtype=torch.float32)
+
+ # Rotation VF loss
+ rots_vf_error = (gt_rot_vf - pred_rots_vf) / norm_scale
+ rots_vf_loss = training_cfg.rotation_loss_weights * torch.sum(
+ rots_vf_error ** 2 * loss_mask[..., None],
+ dim=(-1, -2)
+ ) / loss_denom
+ # rots_vf_loss = torch.zeros(num_batch,device=pred_trans_1.device,dtype=torch.float32)
+
+
+ # torsion loss
+ torsion_angles, torsion_mask = all_atom.prot_to_torsion_angles(noisy_batch['aatype'], gt_atoms, atom37_mask)
+ # print('atom37_mask',atom37_mask.shape, atom37_mask[0,:5,:15]) # (B,L,37)
+ # print('torsion_angles',torsion_angles.shape, torsion_angles[0,:5,:15,:]) # (B,L,7,2)
+ # print('torsion_mask',torsion_mask.shape, torsion_mask[0,:5,:15]) # (B,L,7)
+ torsion_loss = torch.sum(
+ (torsion_angles - pred_torsions_with_CB[:,:,1:,:]) ** 2 * torsion_mask[..., None],
+ dim=(-1, -2, -3)
+ ) / torch.sum(torsion_mask, dim=(-1, -2))
+
+
+ # atom loss
+ pred_atoms, atoms_mask = all_atom.to_atom37(pred_trans_1, pred_rotmats_1, aatype = noisy_batch['aatype'], torsions_with_CB = pred_torsions_with_CB, get_mask = True) # atoms_mask (B,L,37)
+ atoms_mask = atoms_mask * loss_mask[...,None] # (B,L,37)
+ pred_atoms_flat, gt_atoms_flat, _ = du.batch_align_structures(
+ pred_atoms.reshape(num_batch, -1, 3), gt_atoms.reshape(num_batch, -1, 3), mask=atoms_mask.reshape(num_batch, -1)
+ )
+ gt_atoms = gt_atoms_flat * training_cfg.bb_atom_scale / norm_scale # (B, true_atoms,3)
+ pred_atoms = pred_atoms_flat * training_cfg.bb_atom_scale / norm_scale
+
+ bb_atom_loss = torch.sum(
+ (gt_atoms - pred_atoms) ** 2,
+ dim=(-1, -2)
+ ) / torch.sum(atoms_mask, dim=(-1, -2))
+
+
+
+ # atom distance loss
+ # pred_atoms, atoms_mask = all_atom.to_atom37(pred_trans_1, pred_rotmats_1, aatype = noisy_batch['aatype'], torsions_with_CB = pred_torsions_with_CB, get_mask = True) # atoms_mask (B,L,37)
+ # gt_atoms *= training_cfg.bb_atom_scale / norm_scale[..., None]
+ # pred_atoms *= training_cfg.bb_atom_scale / norm_scale[..., None]
+ # atoms_mask = atoms_mask * loss_mask[...,None] # (B,L,37)
+ # gt_flat_atoms = gt_atoms.reshape([num_batch, num_res*37, 3]) # (B,L*37,3)
+ # gt_pair_dists = torch.linalg.norm(
+ # gt_flat_atoms[:, :, None, :] - gt_flat_atoms[:, None, :, :], dim=-1) # (B,L*37,L*37)
+ # pred_flat_atoms = pred_atoms.reshape([num_batch, num_res*37, 3])
+ # pred_pair_dists = torch.linalg.norm(
+ # pred_flat_atoms[:, :, None, :] - pred_flat_atoms[:, None, :, :], dim=-1)
+ # flat_mask = atoms_mask.reshape([num_batch, num_res*37]) # (B,L*37)
+
+ # gt_pair_dists = gt_pair_dists * flat_mask[..., None]
+ # pred_pair_dists = pred_pair_dists * flat_mask[..., None]
+ # pair_dist_mask = flat_mask[..., None] * flat_mask[:, None, :] # (B,L*37, L*37)
+
+ # bb_atom_loss = torch.sum(
+ # (gt_pair_dists - pred_pair_dists)**2 * pair_dist_mask,
+ # dim=(1, 2))
+ # bb_atom_loss /= (torch.sum(pair_dist_mask, dim=(1, 2)))
+
+
+
+ # Pairwise distance loss
+ pred_bb_atoms = all_atom.to_atom37(pred_trans_1, pred_rotmats_1)[:, :, :3]
+ gt_bb_atoms_pair = gt_bb_atoms * training_cfg.bb_atom_scale / norm_scale[..., None]
+ pred_bb_atoms_pair = pred_bb_atoms * training_cfg.bb_atom_scale / norm_scale[..., None]
+ gt_flat_atoms = gt_bb_atoms_pair.reshape([num_batch, num_res*3, 3])
+ gt_pair_dists = torch.linalg.norm(
+ gt_flat_atoms[:, :, None, :] - gt_flat_atoms[:, None, :, :], dim=-1)
+ pred_flat_atoms = pred_bb_atoms_pair.reshape([num_batch, num_res*3, 3])
+ pred_pair_dists = torch.linalg.norm(
+ pred_flat_atoms[:, :, None, :] - pred_flat_atoms[:, None, :, :], dim=-1)
+
+ flat_loss_mask = torch.tile(loss_mask[:, :, None], (1, 1, 3))
+ flat_loss_mask = flat_loss_mask.reshape([num_batch, num_res*3]) # (B,L*3)
+ flat_res_mask = torch.tile(loss_mask[:, :, None], (1, 1, 3))
+ flat_res_mask = flat_res_mask.reshape([num_batch, num_res*3]) # (B,L*3)
+
+ gt_pair_dists = gt_pair_dists * flat_loss_mask[..., None]
+ pred_pair_dists = pred_pair_dists * flat_loss_mask[..., None]
+ pair_dist_mask = flat_loss_mask[..., None] * flat_res_mask[:, None, :] # (B,L*3, L*3)
+
+ pair_dist_mat_loss = torch.sum(
+ (gt_pair_dists - pred_pair_dists)**2 * pair_dist_mask,
+ dim=(1, 2)) / (torch.sum(pair_dist_mask, dim=(1, 2)) - num_res)
+
+
+ # sequence distance loss
+ pred_bb_atoms = all_atom.to_atom37(pred_trans_1, pred_rotmats_1)[:, :, :3]
+ gt_bb_atoms_seq = gt_bb_atoms * training_cfg.bb_atom_scale / norm_scale[..., None]
+ pred_bb_atoms_seq = pred_bb_atoms * training_cfg.bb_atom_scale / norm_scale[..., None]
+ gt_flat_atoms = gt_bb_atoms_seq.reshape([num_batch, num_res*3, 3]) # (B,L*3,3)
+ gt_seq_dists = torch.linalg.norm(
+ gt_flat_atoms[:, :-3, :] - gt_flat_atoms[:, 3:, :], dim=-1) # (B,3*(L-1))
+ pred_flat_atoms = pred_bb_atoms_seq.reshape([num_batch, num_res*3, 3]) # (B,L*3,3)
+ pred_seq_dists = torch.linalg.norm(
+ pred_flat_atoms[:, :-3, :] - pred_flat_atoms[:, 3:, :], dim=-1) # (B,3*(L-1))
+
+ flat_mask = torch.tile(loss_mask[:, :, None], (1, 1, 3)) # (B,L,3)
+ flat_mask = (flat_mask[:,1:,:] * flat_mask[:,:-1,:]).reshape([num_batch, 3*(num_res-1)]) # (B,3*(L-1))
+
+ gt_seq_dists = gt_seq_dists * flat_mask
+ pred_seq_dists = pred_seq_dists * flat_mask
+
+ seq_dist_mat_loss = torch.sum(
+ (gt_seq_dists - pred_seq_dists)**2 * flat_mask,
+ dim=(1)) / (torch.sum(flat_mask, dim=(1)))
+
+ dist_mat_loss = pair_dist_mat_loss + seq_dist_mat_loss
+
+ se3_vf_loss = trans_loss + rots_vf_loss
+ auxiliary_loss = (bb_atom_loss + dist_mat_loss) * (
+ t[:, 0] > training_cfg.aux_loss_t_pass
+ )
+ auxiliary_loss *= self._exp_cfg.training.aux_loss_weight
+ se3_vf_loss += auxiliary_loss
+
+ se3_vf_loss += torsion_loss
+
+ return {
+ "trans_loss": trans_loss,
+ "rots_vf_loss": rots_vf_loss,
+ "bb_atom_loss": bb_atom_loss,
+ "dist_mat_loss": dist_mat_loss,
+ "auxiliary_loss": auxiliary_loss,
+ "se3_vf_loss": se3_vf_loss,
+ # "aatype_loss":aatype_loss
+ }
+
+ def validation_step(self, batch: Any, batch_idx: int):
+ # if self.trainer.global_step > 100000:
+ # # load ema weights
+ # clone_param = lambda t: t.detach().clone()
+ # self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
+ # self.model.load_state_dict(self.ema.state_dict()["params"])
+
+ res_mask = batch['res_mask']
+ self.interpolant.set_device(res_mask.device)
+ num_batch, num_res = res_mask.shape
+
+ self.model.eval()
+ samples = self.interpolant.sample(
+ batch,
+ self.model,
+ )[0][-1].numpy()
+ self.model.train()
+
+ batch_metrics = []
+ for i in range(num_batch):
+
+ # Write out sample to PDB file
+ final_pos = samples[i]
+
+ # mdtraj_metrics = metrics.calc_mdtraj_metrics(saved_path)
+ ca_ca_metrics = metrics.calc_ca_ca_metrics(final_pos[:, residue_constants.atom_order['CA']])
+
+ # use 3 bb atoms coords to calculate RMSD, and use CA to calculate tm_score
+ rmsd = metrics.calc_aligned_rmsd(
+ final_pos[:, :3].reshape(-1,3), batch['all_atom_positions'][i].cpu().numpy()[:,:3].reshape(-1,3))
+ seq_string = ''.join([order2restype_with_mask[int(aa)] for aa in batch['aatype'][i].cpu()])
+
+ if len(seq_string) != final_pos[:, 1].shape[0]:
+ seq_string = 'A'*final_pos[:, 1].shape[0]
+
+ tm_score,_ = metrics.calc_tm_score(
+ final_pos[:, 1], batch['all_atom_positions'][i].cpu().numpy()[:,1],
+ seq_string, seq_string)
+
+ valid_loss = {'rmsd_loss':rmsd,
+ 'tm_score':tm_score }
+
+ batch_metrics.append((ca_ca_metrics | valid_loss)) # merge metrics
+ saved_path = au.write_prot_to_pdb(
+ final_pos,
+ os.path.join(
+ self._sample_write_dir,
+ f'idx_{batch_idx}_len_{num_res}_rmsd_{rmsd}_tm_{tm_score}.pdb'),
+ aatype=batch['aatype'][i].cpu().numpy(),
+ no_indexing=True
+ )
+ if isinstance(self.logger, WandbLogger):
+ self.validation_epoch_samples.append(
+ [saved_path, self.global_step, wandb.Molecule(saved_path)]
+ )
+
+ batch_metrics = pd.DataFrame(batch_metrics)
+
+ print("")
+ for key in batch_metrics.columns:
+ print('%s=%.3f ' %(key, np.mean(batch_metrics[key])), end='')
+
+ self.validation_epoch_metrics.append(batch_metrics)
+
+ def on_validation_epoch_end(self):
+ if len(self.validation_epoch_samples) > 0:
+ self.logger.log_table(
+ key='valid/samples',
+ columns=["sample_path", "global_step", "Protein"],
+ data=self.validation_epoch_samples)
+ self.validation_epoch_samples.clear()
+ val_epoch_metrics = pd.concat(self.validation_epoch_metrics)
+ for metric_name,metric_val in val_epoch_metrics.mean().to_dict().items():
+ if metric_name in ['rmsd_loss','tm_score']:
+ self._log_scalar(
+ f'valid/{metric_name}',
+ metric_val,
+ on_step=False,
+ on_epoch=True,
+ prog_bar=True,
+ batch_size=len(val_epoch_metrics),
+ )
+ else:
+ self._log_scalar(
+ f'valid/{metric_name}',
+ metric_val,
+ on_step=False,
+ on_epoch=True,
+ prog_bar=False,
+ batch_size=len(val_epoch_metrics),
+ )
+ self.validation_epoch_metrics.clear()
+
+ def _log_scalar(
+ self,
+ key,
+ value,
+ on_step=True,
+ on_epoch=False,
+ prog_bar=True,
+ batch_size=None,
+ sync_dist=False,
+ rank_zero_only=True
+ ):
+ if sync_dist and rank_zero_only:
+ raise ValueError('Unable to sync dist when rank_zero_only=True')
+ self.log(
+ key,
+ value,
+ on_step=on_step,
+ on_epoch=on_epoch,
+ prog_bar=prog_bar,
+ batch_size=batch_size,
+ sync_dist=sync_dist,
+ rank_zero_only=rank_zero_only
+ )
+
+ def training_step(self, batch: Any, stage: int):
+ # if(self.ema.device != batch["aatype"].device):
+ # self.ema.to(batch["aatype"].device)
+ # self.ema.update(self.model)
+
+ seq_len = batch["aatype"].shape[1]
+ batch_size = min(self._data_cfg.sampler.max_batch_size, self._data_cfg.sampler.max_num_res_squared // seq_len**2) # dynamic batch size
+
+ for key,value in batch.items():
+ batch[key] = value.repeat((batch_size,)+(1,)*(len(value.shape)-1))
+
+
+ # step_start_time = time.time()
+ self.interpolant.set_device(batch['res_mask'].device)
+ noisy_batch = self.interpolant.corrupt_batch(batch)
+ if self._interpolant_cfg.self_condition and random.random() > 0.5:
+ with torch.no_grad():
+ self.model.eval()
+ model_sc = self.model(noisy_batch)
+ noisy_batch['trans_sc'] = model_sc['pred_trans']
+ self.model.train()
+ batch_losses = self.model_step(noisy_batch)
+ num_batch = batch_losses['bb_atom_loss'].shape[0]
+ total_losses = {
+ k: torch.mean(v) for k,v in batch_losses.items()
+ }
+ for k,v in total_losses.items():
+ self._log_scalar(
+ f"train/{k}", v, prog_bar=False, batch_size=num_batch)
+
+ # Losses to track. Stratified across t.
+ t = torch.squeeze(noisy_batch['t'])
+ self._log_scalar(
+ "train/t",
+ np.mean(du.to_numpy(t)),
+ prog_bar=False, batch_size=num_batch)
+
+ # for loss_name, loss_dict in batch_losses.items():
+ # stratified_losses = mu.t_stratified_loss(
+ # t, loss_dict, loss_name=loss_name)
+ # for k,v in stratified_losses.items():
+ # self._log_scalar(
+ # f"train/{k}", v, prog_bar=False, batch_size=num_batch)
+
+ # Training throughput
+ self._log_scalar(
+ "train/length", batch['res_mask'].shape[1], prog_bar=False, batch_size=num_batch)
+ self._log_scalar(
+ "train/batch_size", num_batch, prog_bar=False)
+
+ # step_time = time.time() - step_start_time
+ # self._log_scalar(
+ # "train/examples_per_second", num_batch / step_time)
+
+ train_loss = (
+ total_losses[self._exp_cfg.training.loss]
+ )
+ self._log_scalar(
+ "train/loss", train_loss, batch_size=num_batch)
+
+
+ return train_loss
+
+ def configure_optimizers(self):
+ return torch.optim.AdamW(
+ params=filter(lambda p: p.requires_grad, self.model.parameters()),
+ # params=self.model.parameters(),
+ **self._exp_cfg.optimizer
+ )
+
+
+ def predict_step(self, batch, batch_idx):
+ # device = f'cuda:{torch.cuda.current_device()}'
+ # interpolant = Interpolant(self._infer_cfg.interpolant)
+ # interpolant.set_device(device)
+ # self.interpolant = interpolant
+
+ self.interpolant.set_device(batch['res_mask'].device)
+
+
+ diffuse_mask = torch.ones(batch['aatype'].shape)
+
+ filename = str(batch['filename'][0])
+ sample_dir = os.path.join(
+ self._output_dir, f'batch_idx_{batch_idx}_{filename}')
+
+ self.model.eval()
+ atom37_traj, model_traj, _ = self.interpolant.sample(
+ batch, self.model
+ )
+
+ # print('pred shape')
+ # print(batch['node_repr_pre'].shape) # (B,L,1280)
+ # print(len(atom37_traj),atom37_traj[0].shape) # 101,(B,L,37,3)
+ # print(len(model_traj),model_traj[0].shape) # 100,(B,L,37,3)
+
+ os.makedirs(sample_dir, exist_ok=True)
+ for batch_index in range(atom37_traj[0].shape[0]):
+ bb_traj = du.to_numpy(torch.stack(atom37_traj, dim=0))[:,batch_index] # (101,L,37,3)
+ x0_traj = du.to_numpy(torch.stack(model_traj, dim=0))[:,batch_index]
+ _ = eu.save_traj(
+ bb_traj[-1],
+ bb_traj,
+ np.flip(x0_traj, axis=0),
+ du.to_numpy(diffuse_mask)[batch_index],
+ output_dir=sample_dir,
+ aatype=batch['aatype'][batch_index].cpu().numpy(),
+ index=batch_index,
+ )
+
+ with open(os.path.join(sample_dir,'seq.txt'), 'w') as f:
+ f.write(batch['seq'][0])
+
diff --git a/models/ipa_pytorch.py b/models/ipa_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b126b4a36669b90e254f0d5cfd73ddf1ee871fc9
--- /dev/null
+++ b/models/ipa_pytorch.py
@@ -0,0 +1,693 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Modified code of Openfold's IPA."""
+
+import numpy as np
+import torch
+import math
+from scipy.stats import truncnorm
+import torch.nn as nn
+from typing import Optional, Callable, List, Sequence
+from openfold.utils.rigid_utils import Rigid
+from data import all_atom
+
+
+def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
+ zero_index = -1 * len(inds)
+ first_inds = list(range(len(tensor.shape[:zero_index])))
+ return tensor.permute(first_inds + [zero_index + i for i in inds])
+
+
+def flatten_final_dims(t: torch.Tensor, no_dims: int):
+ return t.reshape(t.shape[:-no_dims] + (-1,))
+
+
+def ipa_point_weights_init_(weights):
+ with torch.no_grad():
+ softplus_inverse_1 = 0.541324854612918
+ weights.fill_(softplus_inverse_1)
+
+def _prod(nums):
+ out = 1
+ for n in nums:
+ out = out * n
+ return out
+
+
+def _calculate_fan(linear_weight_shape, fan="fan_in"):
+ fan_out, fan_in = linear_weight_shape
+
+ if fan == "fan_in":
+ f = fan_in
+ elif fan == "fan_out":
+ f = fan_out
+ elif fan == "fan_avg":
+ f = (fan_in + fan_out) / 2
+ else:
+ raise ValueError("Invalid fan option")
+
+ return f
+
+def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
+ shape = weights.shape
+ f = _calculate_fan(shape, fan)
+ scale = scale / max(1, f)
+ a = -2
+ b = 2
+ std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
+ size = _prod(shape)
+ samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
+ samples = np.reshape(samples, shape)
+ with torch.no_grad():
+ weights.copy_(torch.tensor(samples, device=weights.device))
+
+
+def lecun_normal_init_(weights):
+ trunc_normal_init_(weights, scale=1.0)
+
+
+def he_normal_init_(weights):
+ trunc_normal_init_(weights, scale=2.0)
+
+
+def glorot_uniform_init_(weights):
+ nn.init.xavier_uniform_(weights, gain=1)
+
+
+def final_init_(weights):
+ with torch.no_grad():
+ weights.fill_(0.0)
+
+
+def gating_init_(weights):
+ with torch.no_grad():
+ weights.fill_(0.0)
+
+
+def normal_init_(weights):
+ torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
+
+
+def compute_angles(ca_pos, pts):
+ batch_size, num_res, num_heads, num_pts, _ = pts.shape
+ calpha_vecs = (ca_pos[:, :, None, :] - ca_pos[:, None, :, :]) + 1e-10
+ calpha_vecs = torch.tile(calpha_vecs[:, :, :, None, None, :], (1, 1, 1, num_heads, num_pts, 1))
+ ipa_pts = pts[:, :, None, :, :, :] - torch.tile(ca_pos[:, :, None, None, None, :], (1, 1, num_res, num_heads, num_pts, 1))
+ phi_angles = all_atom.calculate_neighbor_angles(
+ calpha_vecs.reshape(-1, 3),
+ ipa_pts.reshape(-1, 3)
+ ).reshape(batch_size, num_res, num_res, num_heads, num_pts)
+ return phi_angles
+
+
+class Linear(nn.Linear):
+ """
+ A Linear layer with built-in nonstandard initializations. Called just
+ like torch.nn.Linear.
+
+ Implements the initializers in 1.11.4, plus some additional ones found
+ in the code.
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ bias: bool = True,
+ init: str = "default",
+ init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
+ ):
+ """
+ Args:
+ in_dim:
+ The final dimension of inputs to the layer
+ out_dim:
+ The final dimension of layer outputs
+ bias:
+ Whether to learn an additive bias. True by default
+ init:
+ The initializer to use. Choose from:
+
+ "default": LeCun fan-in truncated normal initialization
+ "relu": He initialization w/ truncated normal distribution
+ "glorot": Fan-average Glorot uniform initialization
+ "gating": Weights=0, Bias=1
+ "normal": Normal initialization with std=1/sqrt(fan_in)
+ "final": Weights=0, Bias=0
+
+ Overridden by init_fn if the latter is not None.
+ init_fn:
+ A custom initializer taking weight and bias as inputs.
+ Overrides init if not None.
+ """
+ super(Linear, self).__init__(in_dim, out_dim, bias=bias)
+
+ if bias:
+ with torch.no_grad():
+ self.bias.fill_(0)
+
+ if init_fn is not None:
+ init_fn(self.weight, self.bias)
+ else:
+ if init == "default":
+ lecun_normal_init_(self.weight)
+ elif init == "relu":
+ he_normal_init_(self.weight)
+ elif init == "glorot":
+ glorot_uniform_init_(self.weight)
+ elif init == "gating":
+ gating_init_(self.weight)
+ if bias:
+ with torch.no_grad():
+ self.bias.fill_(1.0)
+ elif init == "normal":
+ normal_init_(self.weight)
+ elif init == "final":
+ final_init_(self.weight)
+ else:
+ raise ValueError("Invalid init string.")
+
+
+class StructureModuleTransition(nn.Module):
+ def __init__(self, c):
+ super(StructureModuleTransition, self).__init__()
+
+ self.c = c
+
+ self.linear_1 = Linear(self.c, self.c, init="relu")
+ self.linear_2 = Linear(self.c, self.c, init="relu")
+ self.linear_3 = Linear(self.c, self.c, init="final")
+ self.relu = nn.ReLU()
+ self.ln = nn.LayerNorm(self.c)
+
+ def forward(self, s):
+ s_initial = s
+ s = self.linear_1(s)
+ s = self.relu(s)
+ s = self.linear_2(s)
+ s = self.relu(s)
+ s = self.linear_3(s)
+ s = s + s_initial
+ s = self.ln(s)
+
+ return s
+
+
+class EdgeTransition(nn.Module):
+ def __init__(
+ self,
+ *,
+ node_embed_size,
+ edge_embed_in,
+ edge_embed_out,
+ num_layers=2,
+ node_dilation=2
+ ):
+ super(EdgeTransition, self).__init__()
+
+ bias_embed_size = node_embed_size // node_dilation
+ self.initial_embed = Linear(
+ node_embed_size, bias_embed_size, init="relu")
+ hidden_size = bias_embed_size * 2 + edge_embed_in
+ trunk_layers = []
+ for _ in range(num_layers):
+ trunk_layers.append(Linear(hidden_size, hidden_size, init="relu"))
+ trunk_layers.append(nn.ReLU())
+ self.trunk = nn.Sequential(*trunk_layers)
+ self.final_layer = Linear(hidden_size, edge_embed_out, init="final")
+ self.layer_norm = nn.LayerNorm(edge_embed_out)
+
+ def forward(self, node_embed, edge_embed):
+ node_embed = self.initial_embed(node_embed)
+ batch_size, num_res, _ = node_embed.shape
+ edge_bias = torch.cat([
+ torch.tile(node_embed[:, :, None, :], (1, 1, num_res, 1)),
+ torch.tile(node_embed[:, None, :, :], (1, num_res, 1, 1)),
+ ], axis=-1)
+ edge_embed = torch.cat(
+ [edge_embed, edge_bias], axis=-1).reshape(
+ batch_size * num_res**2, -1)
+ edge_embed = self.final_layer(self.trunk(edge_embed) + edge_embed)
+ edge_embed = self.layer_norm(edge_embed)
+ edge_embed = edge_embed.reshape(
+ batch_size, num_res, num_res, -1
+ )
+ return edge_embed
+
+
+class InvariantPointAttention(nn.Module):
+ """
+ Implements Algorithm 22.
+ """
+ def __init__(
+ self,
+ ipa_conf,
+ inf: float = 1e5,
+ eps: float = 1e-8,
+ ):
+ """
+ Args:
+ c_s:
+ Single representation channel dimension
+ c_z:
+ Pair representation channel dimension
+ c_hidden:
+ Hidden channel dimension
+ no_heads:
+ Number of attention heads
+ no_qk_points:
+ Number of query/key points to generate
+ no_v_points:
+ Number of value points to generate
+ """
+ super(InvariantPointAttention, self).__init__()
+ self._ipa_conf = ipa_conf
+
+ self.c_s = ipa_conf.c_s
+ self.c_z = ipa_conf.c_z
+ self.c_hidden = ipa_conf.c_hidden
+ self.no_heads = ipa_conf.no_heads
+ self.no_qk_points = ipa_conf.no_qk_points
+ self.no_v_points = ipa_conf.no_v_points
+ self.inf = inf
+ self.eps = eps
+
+ # These linear layers differ from their specifications in the
+ # supplement. There, they lack bias and use Glorot initialization.
+ # Here as in the official source, they have bias and use the default
+ # Lecun initialization.
+ hc = self.c_hidden * self.no_heads
+ self.linear_q = Linear(self.c_s, hc)
+ self.linear_kv = Linear(self.c_s, 2 * hc)
+
+ hpq = self.no_heads * self.no_qk_points * 3
+ self.linear_q_points = Linear(self.c_s, hpq)
+
+ hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
+ self.linear_kv_points = Linear(self.c_s, hpkv)
+
+ self.linear_b = Linear(self.c_z, self.no_heads)
+ self.down_z = Linear(self.c_z, self.c_z // 4)
+
+ self.head_weights = nn.Parameter(torch.zeros((ipa_conf.no_heads)))
+ ipa_point_weights_init_(self.head_weights)
+
+ concat_out_dim = (
+ self.c_z // 4 + self.c_hidden + self.no_v_points * 4
+ )
+ self.linear_out = Linear(self.no_heads * concat_out_dim, self.c_s, init="final")
+
+ self.softmax = nn.Softmax(dim=-1)
+ self.softplus = nn.Softplus()
+
+ def forward(
+ self,
+ s: torch.Tensor,
+ z: Optional[torch.Tensor],
+ r: Rigid,
+ mask: torch.Tensor,
+ _offload_inference: bool = False,
+ _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ s:
+ [*, N_res, C_s] single representation
+ z:
+ [*, N_res, N_res, C_z] pair representation
+ r:
+ [*, N_res] transformation object
+ mask:
+ [*, N_res] mask
+ Returns:
+ [*, N_res, C_s] single representation update
+ """
+ if _offload_inference:
+ z = _z_reference_list
+ else:
+ z = [z]
+
+ #######################################
+ # Generate scalar and point activations
+ #######################################
+ # [*, N_res, H * C_hidden]
+ q = self.linear_q(s)
+ kv = self.linear_kv(s)
+
+ # [*, N_res, H, C_hidden]
+ q = q.view(q.shape[:-1] + (self.no_heads, -1))
+
+ # [*, N_res, H, 2 * C_hidden]
+ kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
+
+ # [*, N_res, H, C_hidden]
+ k, v = torch.split(kv, self.c_hidden, dim=-1)
+
+ # [*, N_res, H * P_q * 3]
+ q_pts = self.linear_q_points(s)
+
+ # This is kind of clunky, but it's how the original does it
+ # [*, N_res, H * P_q, 3]
+ q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
+ q_pts = torch.stack(q_pts, dim=-1)
+ q_pts = r[..., None].apply(q_pts)
+
+ # [*, N_res, H, P_q, 3]
+ q_pts = q_pts.view(
+ q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
+ )
+
+ # [*, N_res, H * (P_q + P_v) * 3]
+ kv_pts = self.linear_kv_points(s)
+
+ # [*, N_res, H * (P_q + P_v), 3]
+ kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
+ kv_pts = torch.stack(kv_pts, dim=-1)
+ kv_pts = r[..., None].apply(kv_pts)
+
+ # [*, N_res, H, (P_q + P_v), 3]
+ kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
+
+ # [*, N_res, H, P_q/P_v, 3]
+ k_pts, v_pts = torch.split(
+ kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
+ )
+
+ ##########################
+ # Compute attention scores
+ ##########################
+ # [*, N_res, N_res, H]
+ b = self.linear_b(z[0])
+
+ if(_offload_inference):
+ z[0] = z[0].cpu()
+
+ # [*, H, N_res, N_res]
+ a = torch.matmul(
+ math.sqrt(1.0 / (3 * self.c_hidden)) *
+ permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
+ permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
+ )
+ # a *= math.sqrt(1.0 / (3 * self.c_hidden))
+ a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
+
+ # [*, N_res, N_res, H, P_q, 3]
+ pt_displacement = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
+ pt_att = pt_displacement ** 2
+
+ # [*, N_res, N_res, H, P_q]
+ pt_att = sum(torch.unbind(pt_att, dim=-1))
+ head_weights = self.softplus(self.head_weights).view(
+ *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
+ )
+ head_weights = head_weights * math.sqrt(
+ 1.0 / (3 * (self.no_qk_points * 9.0 / 2))
+ )
+ pt_att = pt_att * head_weights
+
+ # [*, N_res, N_res, H]
+ pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
+ # [*, N_res, N_res]
+ square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
+ square_mask = self.inf * (square_mask - 1)
+
+ # [*, H, N_res, N_res]
+ pt_att = permute_final_dims(pt_att, (2, 0, 1))
+
+ a = a + pt_att
+ a = a + square_mask.unsqueeze(-3)
+ a = self.softmax(a)
+
+ ################
+ # Compute output
+ ################
+ # [*, N_res, H, C_hidden]
+ o = torch.matmul(
+ a, v.transpose(-2, -3)
+ ).transpose(-2, -3)
+
+ # [*, N_res, H * C_hidden]
+ o = flatten_final_dims(o, 2)
+
+ # [*, H, 3, N_res, P_v]
+ o_pt = torch.sum(
+ (
+ a[..., None, :, :, None]
+ * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
+ ),
+ dim=-2,
+ )
+
+ # [*, N_res, H, P_v, 3]
+ o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
+ o_pt = r[..., None, None].invert_apply(o_pt)
+
+ # [*, N_res, H * P_v]
+ o_pt_dists = torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps)
+ o_pt_norm_feats = flatten_final_dims(
+ o_pt_dists, 2)
+
+ # [*, N_res, H * P_v, 3]
+ o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
+
+ if(_offload_inference):
+ z[0] = z[0].to(o_pt.device)
+
+ # [*, N_res, H, C_z // 4]
+ pair_z = self.down_z(z[0])
+ o_pair = torch.matmul(a.transpose(-2, -3), pair_z)
+
+ # [*, N_res, H * C_z // 4]
+ o_pair = flatten_final_dims(o_pair, 2)
+
+ o_feats = [o, *torch.unbind(o_pt, dim=-1), o_pt_norm_feats, o_pair]
+
+ # [*, N_res, C_s]
+ s = self.linear_out(
+ torch.cat(
+ o_feats, dim=-1
+ )
+ )
+
+ return s
+
+
+class TorsionAngles(nn.Module):
+ def __init__(self, c, num_torsions, eps=1e-8):
+ super(TorsionAngles, self).__init__()
+
+ self.c = c
+ self.eps = eps
+ self.num_torsions = num_torsions
+
+ self.linear_1 = Linear(self.c, self.c, init="relu")
+ self.linear_2 = Linear(self.c, self.c, init="relu")
+ # TODO: Remove after published checkpoint is updated without these weights.
+ self.linear_3 = Linear(self.c, self.c, init="final")
+ self.linear_final = Linear(
+ self.c, self.num_torsions * 2, init="final")
+
+ self.relu = nn.ReLU()
+
+ def forward(self, s):
+ s_initial = s
+ s = self.linear_1(s)
+ s = self.relu(s)
+ s = self.linear_2(s)
+
+ s = s + s_initial
+ unnormalized_s = self.linear_final(s)
+ norm_denom = torch.sqrt(
+ torch.clamp(
+ torch.sum(unnormalized_s ** 2, dim=-1, keepdim=True),
+ min=self.eps,
+ )
+ )
+ normalized_s = unnormalized_s / norm_denom
+
+ return unnormalized_s, normalized_s
+
+
+class RotationVFLayer(nn.Module):
+ def __init__(self, dim):
+ super(RotationVFLayer, self).__init__()
+
+ self.linear_1 = Linear(dim, dim, init="relu")
+ self.linear_2 = Linear(dim, dim, init="relu")
+ self.linear_3 = Linear(dim, dim)
+ self.final_linear = Linear(dim, 6, init="final")
+ self.relu = nn.ReLU()
+
+ def forward(self, s):
+ s_initial = s
+ s = self.linear_1(s)
+ s = self.relu(s)
+ s = self.linear_2(s)
+ s = self.relu(s)
+ s = self.linear_3(s)
+ s = s + s_initial
+ return self.final_linear(s)
+
+
+class BackboneUpdate(nn.Module):
+ """
+ Implements part of Algorithm 23.
+ """
+
+ def __init__(self, c_s, use_rot_updates, dropout=0.0):
+ """
+ Args:
+ c_s:
+ Single representation channel dimension
+ """
+ super(BackboneUpdate, self).__init__()
+
+ self.c_s = c_s
+ self.noise_schedule = 1.0
+ update_dim = 6 if use_rot_updates else 3
+
+ self.linear = nn.Sequential(
+ Linear(self.c_s, update_dim, init="final"),
+ # nn.Tanh(),
+ )
+
+
+ def forward(self, s: torch.Tensor):
+ """
+ Args:
+ [*, N_res, C_s] single representation
+ Returns:
+ [*, N_res, 6] update vector
+ """
+ # [*, 6]
+ update = self.noise_schedule * self.linear(s)
+
+ return update
+
+class IpaScore(nn.Module):
+
+ def __init__(self, model_conf, diffuser):
+ super(IpaScore, self).__init__()
+ self._model_conf = model_conf
+ ipa_conf = model_conf.ipa
+ self._ipa_conf = ipa_conf
+ self.diffuser = diffuser
+
+ self.scale_pos = lambda x: x * ipa_conf.coordinate_scaling
+ self.scale_rigids = lambda x: x.apply_trans_fn(self.scale_pos)
+
+ self.unscale_pos = lambda x: x / ipa_conf.coordinate_scaling
+ self.unscale_rigids = lambda x: x.apply_trans_fn(self.unscale_pos)
+ self.trunk = nn.ModuleDict()
+
+ for b in range(ipa_conf.num_blocks):
+ self.trunk[f'ipa_{b}'] = InvariantPointAttention(ipa_conf)
+ self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(ipa_conf.c_s)
+ self.trunk[f'skip_embed_{b}'] = Linear(
+ self._model_conf.node_embed_size,
+ self._ipa_conf.c_skip,
+ init="final"
+ )
+ tfmr_in = ipa_conf.c_s + self._ipa_conf.c_skip
+ tfmr_layer = torch.nn.TransformerEncoderLayer(
+ d_model=tfmr_in,
+ nhead=ipa_conf.seq_tfmr_num_heads,
+ dim_feedforward=tfmr_in,
+ batch_first=True,
+ dropout=0.0,
+ norm_first=False
+ )
+ self.trunk[f'seq_tfmr_{b}'] = torch.nn.TransformerEncoder(
+ tfmr_layer, ipa_conf.seq_tfmr_num_layers)
+ self.trunk[f'post_tfmr_{b}'] = Linear(
+ tfmr_in, ipa_conf.c_s, init="final")
+ self.trunk[f'node_transition_{b}'] = StructureModuleTransition(
+ c=ipa_conf.c_s)
+ self.trunk[f'bb_update_{b}'] = BackboneUpdate(ipa_conf.c_s)
+
+ if b < ipa_conf.num_blocks-1:
+ # No edge update on the last block.
+ edge_in = self._model_conf.edge_embed_size
+ self.trunk[f'edge_transition_{b}'] = EdgeTransition(
+ node_embed_size=ipa_conf.c_s,
+ edge_embed_in=edge_in,
+ edge_embed_out=self._model_conf.edge_embed_size,
+ )
+
+ self.torsion_pred = TorsionAngles(ipa_conf.c_s, 1)
+
+ def forward(self, init_node_embed, edge_embed, input_feats):
+ node_mask = input_feats['res_mask'].type(torch.float32)
+ diffuse_mask = (1 - input_feats['fixed_mask'].type(torch.float32)) * node_mask
+ edge_mask = node_mask[..., None] * node_mask[..., None, :]
+ init_frames = input_feats['rigids_t'].type(torch.float32)
+
+ curr_rigids = Rigid.from_tensor_7(torch.clone(init_frames))
+ init_rigids = Rigid.from_tensor_7(init_frames)
+ init_rots = init_rigids.get_rots()
+
+ # Main trunk
+ curr_rigids = self.scale_rigids(curr_rigids)
+ init_node_embed = init_node_embed * node_mask[..., None]
+ node_embed = init_node_embed * node_mask[..., None]
+ for b in range(self._ipa_conf.num_blocks):
+ ipa_embed = self.trunk[f'ipa_{b}'](
+ node_embed,
+ edge_embed,
+ curr_rigids,
+ node_mask)
+ ipa_embed *= node_mask[..., None]
+ node_embed = self.trunk[f'ipa_ln_{b}'](node_embed + ipa_embed)
+ seq_tfmr_in = torch.cat([
+ node_embed, self.trunk[f'skip_embed_{b}'](init_node_embed)
+ ], dim=-1)
+ seq_tfmr_out = self.trunk[f'seq_tfmr_{b}'](
+ seq_tfmr_in, src_key_padding_mask=1 - node_mask)
+ node_embed = node_embed + self.trunk[f'post_tfmr_{b}'](seq_tfmr_out)
+ node_embed = self.trunk[f'node_transition_{b}'](node_embed)
+ node_embed = node_embed * node_mask[..., None]
+ rigid_update = self.trunk[f'bb_update_{b}'](
+ node_embed * diffuse_mask[..., None])
+ curr_rigids = curr_rigids.compose_q_update_vec(
+ rigid_update, diffuse_mask[..., None])
+
+ if b < self._ipa_conf.num_blocks-1:
+ edge_embed = self.trunk[f'edge_transition_{b}'](
+ node_embed, edge_embed)
+ edge_embed *= edge_mask[..., None]
+ rot_score = self.diffuser.calc_rot_score(
+ init_rigids.get_rots(),
+ curr_rigids.get_rots(),
+ input_feats['t']
+ )
+ rot_score = rot_score * node_mask[..., None]
+
+ curr_rigids = self.unscale_rigids(curr_rigids)
+ trans_score = self.diffuser.calc_trans_score(
+ init_rigids.get_trans(),
+ curr_rigids.get_trans(),
+ input_feats['t'][:, None, None],
+ use_torch=True,
+ )
+ trans_score = trans_score * node_mask[..., None]
+ _, psi_pred = self.torsion_pred(node_embed)
+ model_out = {
+ 'psi': psi_pred,
+ 'rot_score': rot_score,
+ 'trans_score': trans_score,
+ 'final_rigids': curr_rigids,
+ }
+ return model_out
diff --git a/models/node_embedder.py b/models/node_embedder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c496003893dd282bab9f54aedfea6c642e0af760
--- /dev/null
+++ b/models/node_embedder.py
@@ -0,0 +1,65 @@
+"""Neural network for embedding node features."""
+import torch
+from torch import nn
+from models.utils import get_index_embedding, get_time_embedding, add_RoPE
+
+class NodeEmbedder(nn.Module):
+ def __init__(self, module_cfg):
+ super(NodeEmbedder, self).__init__()
+ self._cfg = module_cfg
+ self.c_s = self._cfg.c_s
+ self.c_pos_emb = self._cfg.c_pos_emb
+ self.c_timestep_emb = self._cfg.c_timestep_emb
+ self.c_node_pre = 1280
+ self.aatype_emb_dim = self._cfg.c_pos_emb
+
+ self.aatype_emb = nn.Embedding(21, self.aatype_emb_dim)
+
+ total_node_feats = self.aatype_emb_dim + self._cfg.c_timestep_emb + self.c_node_pre
+ # total_node_feats = self.aatype_emb_dim + self._cfg.c_timestep_emb
+
+ self.linear = nn.Sequential(
+ nn.Linear(total_node_feats, self.c_s),
+ nn.ReLU(),
+ nn.Dropout(self._cfg.dropout),
+ nn.Linear(self.c_s, self.c_s),
+ )
+
+ def embed_t(self, timesteps, mask):
+ timestep_emb = get_time_embedding(
+ timesteps[:, 0],
+ self.c_timestep_emb,
+ max_positions=2056
+ )[:, None, :].repeat(1, mask.shape[1], 1)
+ return timestep_emb * mask.unsqueeze(-1)
+
+ def forward(self, timesteps, aatype, node_repr_pre, mask):
+ '''
+ mask: [B,L]
+ timesteps: [B,1]
+ energy: [B,]
+ '''
+
+ b, num_res, device = mask.shape[0], mask.shape[1], mask.device
+
+ # [b, n_res, c_pos_emb]
+ # pos = torch.arange(num_res, dtype=torch.float32).to(device)[None] # (1,L)
+ # pos_emb = get_index_embedding(
+ # pos, self.c_pos_emb, max_len=2056
+ # )
+ # pos_emb = pos_emb.repeat([b, 1, 1])
+ # pos_emb = pos_emb * mask.unsqueeze(-1)
+
+ aatype_emb = self.aatype_emb(aatype) * mask.unsqueeze(-1)
+
+ # [b, n_res, c_timestep_emb]
+ input_feats = [aatype_emb]
+ # timesteps are between 0 and 1. Convert to integers.
+ time_emb = self.embed_t(timesteps, mask)
+ input_feats.append(time_emb)
+
+ input_feats.append(node_repr_pre)
+
+ out = self.linear(torch.cat(input_feats, dim=-1)) # (B,L,d_node)
+
+ return add_RoPE(out)
\ No newline at end of file
diff --git a/models/utils.py b/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e575b106a7832454f4ea610f5ec9b858cb299b4
--- /dev/null
+++ b/models/utils.py
@@ -0,0 +1,97 @@
+import math
+import torch
+from torch.nn import functional as F
+import numpy as np
+from data import utils as du
+
+
+def calc_distogram(pos, min_bin, max_bin, num_bins):
+ dists_2d = torch.linalg.norm(
+ pos[:, :, None, :] - pos[:, None, :, :], axis=-1)[..., None] # (B,L,L,1)
+ lower = torch.linspace(
+ min_bin,
+ max_bin,
+ num_bins,
+ device=pos.device)
+ upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1)
+ dgram = ((dists_2d > lower) * (dists_2d < upper)).type(pos.dtype) # (B,L,L,num_bins)
+ return dgram
+
+def add_RoPE(indices):
+ """Creates sine / cosine positional embeddings from a prespecified indices.
+
+ Args:
+ indices: (B,L,embed_size)
+ embed_size: dimension of the embeddings to create
+
+ Returns:
+ positional embedding of shape [B, L, embed_size]
+ """
+ seq_len, embed_size = indices.shape[-2:]
+ seq_all = torch.arange(seq_len, device=indices.device)[:,None] # (L,1)
+ theta_all = torch.pow(1e4, torch.arange(embed_size)//2 / -embed_size)[None,:] # (1,embed_size)
+ sinusoidal_pos = (seq_all * theta_all.to(indices.device))[None,...] # (1,L,embed_size)
+
+ cos_pos = torch.cos(sinusoidal_pos) # (1,L,embed_size)
+ sin_pos = torch.sin(sinusoidal_pos) # (1,L,embed_size)
+ indices_sin = torch.stack([-indices[..., 1::2], indices[..., ::2]], dim=-1) # (B,L,embed_size/2,2)
+ indices_sin = indices_sin.reshape(indices.shape) # (B,L,embed_size)
+ indices = indices * cos_pos + indices_sin * sin_pos
+ return indices
+
+def get_index_embedding(indices, embed_size, max_len=2056):
+ """Creates sine / cosine positional embeddings from a prespecified indices.
+
+ Args:
+ indices: offsets of size [..., N_edges] of type integer
+ max_len: maximum length.
+ embed_size: dimension of the embeddings to create
+
+ Returns:
+ positional embedding of shape [N, embed_size]
+ """
+ K = torch.arange(embed_size//2, device=indices.device)
+ pos_embedding_sin = torch.sin(
+ indices[..., None] * math.pi / (max_len**(2*K[None]/embed_size))).to(indices.device)
+ pos_embedding_cos = torch.cos(
+ indices[..., None] * math.pi / (max_len**(2*K[None]/embed_size))).to(indices.device)
+ pos_embedding = torch.cat([
+ pos_embedding_sin, pos_embedding_cos], axis=-1)
+ return pos_embedding
+
+
+def get_time_embedding(timesteps, embedding_dim, max_positions=2000):
+ # Code from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
+ assert len(timesteps.shape) == 1
+ timesteps = timesteps * max_positions
+ half_dim = embedding_dim // 2
+ emb = math.log(max_positions) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = F.pad(emb, (0, 1), mode='constant')
+ assert emb.shape == (timesteps.shape[0], embedding_dim)
+ return emb
+
+
+def t_stratified_loss(batch_t, batch_loss, num_bins=4, loss_name=None):
+ """Stratify loss by binning t."""
+ batch_t = du.to_numpy(batch_t)
+ batch_loss = du.to_numpy(batch_loss)
+ flat_losses = batch_loss.flatten()
+ flat_t = batch_t.flatten()
+ bin_edges = np.linspace(0.0, 1.0 + 1e-3, num_bins+1)
+ bin_idx = np.sum(bin_edges[:, None] <= flat_t[None, :], axis=0) - 1
+ t_binned_loss = np.bincount(bin_idx, weights=flat_losses)
+ t_binned_n = np.bincount(bin_idx)
+ stratified_losses = {}
+ if loss_name is None:
+ loss_name = 'loss'
+ for t_bin in np.unique(bin_idx).tolist():
+ bin_start = bin_edges[t_bin]
+ bin_end = bin_edges[t_bin+1]
+ t_range = f'{loss_name} t=[{bin_start:.2f},{bin_end:.2f})'
+ range_loss = t_binned_loss[t_bin] / t_binned_n[t_bin]
+ stratified_losses[t_range] = range_loss
+ return stratified_losses
diff --git a/openfold/__pycache__/config.cpython-310.pyc b/openfold/__pycache__/config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7ec659827725a7d5c55f0234fd9d83286346e55
Binary files /dev/null and b/openfold/__pycache__/config.cpython-310.pyc differ
diff --git a/openfold/config.py b/openfold/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b73acb91d367c07f3c99cf7400883041baf67a7a
--- /dev/null
+++ b/openfold/config.py
@@ -0,0 +1,4 @@
+NUM_RES = "num residues placeholder"
+NUM_MSA_SEQ = "msa placeholder"
+NUM_EXTRA_SEQ = "extra msa placeholder"
+NUM_TEMPLATES = "num templates placeholder"
diff --git a/openfold/data/__init__.py b/openfold/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/openfold/data/__pycache__/__init__.cpython-310.pyc b/openfold/data/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6bd86eeba1f6edf1a55a3c90c7dfcb997482dd47
Binary files /dev/null and b/openfold/data/__pycache__/__init__.cpython-310.pyc differ
diff --git a/openfold/data/__pycache__/data_transforms.cpython-310.pyc b/openfold/data/__pycache__/data_transforms.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3cc04441a75896d85cf8dadb99a4eb4ec8ca9cd4
Binary files /dev/null and b/openfold/data/__pycache__/data_transforms.cpython-310.pyc differ
diff --git a/openfold/data/data_modules.py b/openfold/data/data_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..fec88f02607b8366552658583213bdbdb49db7c4
--- /dev/null
+++ b/openfold/data/data_modules.py
@@ -0,0 +1,693 @@
+import copy
+from functools import partial
+import json
+import logging
+import os
+import pickle
+from typing import Optional, Sequence, List, Any
+
+import ml_collections as mlc
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from torch.utils.data import RandomSampler
+
+from openfold.data import (
+ data_pipeline,
+ feature_pipeline,
+ mmcif_parsing,
+ templates,
+)
+from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
+
+
+class OpenFoldSingleDataset(torch.utils.data.Dataset):
+ def __init__(self,
+ data_dir: str,
+ alignment_dir: str,
+ template_mmcif_dir: str,
+ max_template_date: str,
+ config: mlc.ConfigDict,
+ kalign_binary_path: str = '/usr/bin/kalign',
+ max_template_hits: int = 4,
+ obsolete_pdbs_file_path: Optional[str] = None,
+ template_release_dates_cache_path: Optional[str] = None,
+ shuffle_top_k_prefiltered: Optional[int] = None,
+ treat_pdb_as_distillation: bool = True,
+ mapping_path: Optional[str] = None,
+ mode: str = "train",
+ _output_raw: bool = False,
+ _alignment_index: Optional[Any] = None
+ ):
+ """
+ Args:
+ data_dir:
+ A path to a directory containing mmCIF files (in train
+ mode) or FASTA files (in inference mode).
+ alignment_dir:
+ A path to a directory containing only data in the format
+ output by an AlignmentRunner
+ (defined in openfold.features.alignment_runner).
+ I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
+ or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
+ files.
+ template_mmcif_dir:
+ Path to a directory containing template mmCIF files.
+ config:
+ A dataset config object. See openfold.config
+ kalign_binary_path:
+ Path to kalign binary.
+ max_template_hits:
+ An upper bound on how many templates are considered. During
+ training, the templates ultimately used are subsampled
+ from this total quantity.
+ template_release_dates_cache_path:
+ Path to the output of scripts/generate_mmcif_cache.
+ obsolete_pdbs_file_path:
+ Path to the file containing replacements for obsolete PDBs.
+ shuffle_top_k_prefiltered:
+ Whether to uniformly shuffle the top k template hits before
+ parsing max_template_hits of them. Can be used to
+ approximate DeepMind's training-time template subsampling
+ scheme much more performantly.
+ treat_pdb_as_distillation:
+ Whether to assume that .pdb files in the data_dir are from
+ the self-distillation set (and should be subjected to
+ special distillation set preprocessing steps).
+ mode:
+ "train", "val", or "predict"
+ """
+ super(OpenFoldSingleDataset, self).__init__()
+ self.data_dir = data_dir
+ self.alignment_dir = alignment_dir
+ self.config = config
+ self.treat_pdb_as_distillation = treat_pdb_as_distillation
+ self.mode = mode
+ self._output_raw = _output_raw
+ self._alignment_index = _alignment_index
+
+ valid_modes = ["train", "eval", "predict"]
+ if(mode not in valid_modes):
+ raise ValueError(f'mode must be one of {valid_modes}')
+
+ if(template_release_dates_cache_path is None):
+ logging.warning(
+ "Template release dates cache does not exist. Remember to run "
+ "scripts/generate_mmcif_cache.py before running OpenFold"
+ )
+
+ if(_alignment_index is not None):
+ self._chain_ids = list(_alignment_index.keys())
+ elif(mapping_path is None):
+ self._chain_ids = list(os.listdir(alignment_dir))
+ else:
+ with open(mapping_path, "r") as f:
+ self._chain_ids = [l.strip() for l in f.readlines()]
+
+ self._chain_id_to_idx_dict = {
+ chain: i for i, chain in enumerate(self._chain_ids)
+ }
+
+ template_featurizer = templates.TemplateHitFeaturizer(
+ mmcif_dir=template_mmcif_dir,
+ max_template_date=max_template_date,
+ max_hits=max_template_hits,
+ kalign_binary_path=kalign_binary_path,
+ release_dates_path=template_release_dates_cache_path,
+ obsolete_pdbs_path=obsolete_pdbs_file_path,
+ _shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
+ )
+
+ self.data_pipeline = data_pipeline.DataPipeline(
+ template_featurizer=template_featurizer,
+ )
+
+ if(not self._output_raw):
+ self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
+
+ def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
+ with open(path, 'r') as f:
+ mmcif_string = f.read()
+
+ mmcif_object = mmcif_parsing.parse(
+ file_id=file_id, mmcif_string=mmcif_string
+ )
+
+ # Crash if an error is encountered. Any parsing errors should have
+ # been dealt with at the alignment stage.
+ if(mmcif_object.mmcif_object is None):
+ raise list(mmcif_object.errors.values())[0]
+
+ mmcif_object = mmcif_object.mmcif_object
+
+ data = self.data_pipeline.process_mmcif(
+ mmcif=mmcif_object,
+ alignment_dir=alignment_dir,
+ chain_id=chain_id,
+ _alignment_index=_alignment_index
+ )
+
+ return data
+
+ def chain_id_to_idx(self, chain_id):
+ return self._chain_id_to_idx_dict[chain_id]
+
+ def idx_to_chain_id(self, idx):
+ return self._chain_ids[idx]
+
+ def __getitem__(self, idx):
+ name = self.idx_to_chain_id(idx)
+ alignment_dir = os.path.join(self.alignment_dir, name)
+
+ _alignment_index = None
+ if(self._alignment_index is not None):
+ alignment_dir = self.alignment_dir
+ _alignment_index = self._alignment_index[name]
+
+ if(self.mode == 'train' or self.mode == 'eval'):
+ spl = name.rsplit('_', 1)
+ if(len(spl) == 2):
+ file_id, chain_id = spl
+ else:
+ file_id, = spl
+ chain_id = None
+
+ path = os.path.join(self.data_dir, file_id)
+ if(os.path.exists(path + ".cif")):
+ data = self._parse_mmcif(
+ path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
+ )
+ elif(os.path.exists(path + ".core")):
+ data = self.data_pipeline.process_core(
+ path + ".core", alignment_dir, _alignment_index,
+ )
+ elif(os.path.exists(path + ".pdb")):
+ data = self.data_pipeline.process_pdb(
+ pdb_path=path + ".pdb",
+ alignment_dir=alignment_dir,
+ is_distillation=self.treat_pdb_as_distillation,
+ chain_id=chain_id,
+ _alignment_index=_alignment_index,
+ )
+ else:
+ raise ValueError("Invalid file type")
+ else:
+ path = os.path.join(name, name + ".fasta")
+ data = self.data_pipeline.process_fasta(
+ fasta_path=path,
+ alignment_dir=alignment_dir,
+ _alignment_index=_alignment_index,
+ )
+
+ if(self._output_raw):
+ return data
+
+ feats = self.feature_pipeline.process_features(
+ data, self.mode
+ )
+
+ return feats
+
+ def __len__(self):
+ return len(self._chain_ids)
+
+
+def deterministic_train_filter(
+ chain_data_cache_entry: Any,
+ max_resolution: float = 9.,
+ max_single_aa_prop: float = 0.8,
+) -> bool:
+ # Hard filters
+ resolution = chain_data_cache_entry.get("resolution", None)
+ if(resolution is not None and resolution > max_resolution):
+ return False
+
+ seq = chain_data_cache_entry["seq"]
+ counts = {}
+ for aa in seq:
+ counts.setdefault(aa, 0)
+ counts[aa] += 1
+ largest_aa_count = max(counts.values())
+ largest_single_aa_prop = largest_aa_count / len(seq)
+ if(largest_single_aa_prop > max_single_aa_prop):
+ return False
+
+ return True
+
+
+def get_stochastic_train_filter_prob(
+ chain_data_cache_entry: Any,
+) -> List[float]:
+ # Stochastic filters
+ probabilities = []
+
+ cluster_size = chain_data_cache_entry.get("cluster_size", None)
+ if(cluster_size is not None and cluster_size > 0):
+ probabilities.append(1 / cluster_size)
+
+ chain_length = len(chain_data_cache_entry["seq"])
+ probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
+
+ # Risk of underflow here?
+ out = 1
+ for p in probabilities:
+ out *= p
+
+ return out
+
+
+class OpenFoldDataset(torch.utils.data.Dataset):
+ """
+ Implements the stochastic filters applied during AlphaFold's training.
+ Because samples are selected from constituent datasets randomly, the
+ length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
+ and filtered once at initialization.
+ """
+ def __init__(self,
+ datasets: Sequence[OpenFoldSingleDataset],
+ probabilities: Sequence[int],
+ epoch_len: int,
+ chain_data_cache_paths: List[str],
+ generator: torch.Generator = None,
+ _roll_at_init: bool = True,
+ ):
+ self.datasets = datasets
+ self.probabilities = probabilities
+ self.epoch_len = epoch_len
+ self.generator = generator
+
+ self.chain_data_caches = []
+ for path in chain_data_cache_paths:
+ with open(path, "r") as fp:
+ self.chain_data_caches.append(json.load(fp))
+
+ def looped_shuffled_dataset_idx(dataset_len):
+ while True:
+ # Uniformly shuffle each dataset's indices
+ weights = [1. for _ in range(dataset_len)]
+ shuf = torch.multinomial(
+ torch.tensor(weights),
+ num_samples=dataset_len,
+ replacement=False,
+ generator=self.generator,
+ )
+ for idx in shuf:
+ yield idx
+
+ def looped_samples(dataset_idx):
+ max_cache_len = int(epoch_len * probabilities[dataset_idx])
+ dataset = self.datasets[dataset_idx]
+ idx_iter = looped_shuffled_dataset_idx(len(dataset))
+ chain_data_cache = self.chain_data_caches[dataset_idx]
+ while True:
+ weights = []
+ idx = []
+ for _ in range(max_cache_len):
+ candidate_idx = next(idx_iter)
+ chain_id = dataset.idx_to_chain_id(candidate_idx)
+ chain_data_cache_entry = chain_data_cache[chain_id]
+ if(not deterministic_train_filter(chain_data_cache_entry)):
+ continue
+
+ p = get_stochastic_train_filter_prob(
+ chain_data_cache_entry,
+ )
+ weights.append([1. - p, p])
+ idx.append(candidate_idx)
+
+ samples = torch.multinomial(
+ torch.tensor(weights),
+ num_samples=1,
+ generator=self.generator,
+ )
+ samples = samples.squeeze()
+
+ cache = [i for i, s in zip(idx, samples) if s]
+
+ for datapoint_idx in cache:
+ yield datapoint_idx
+
+ self._samples = [looped_samples(i) for i in range(len(self.datasets))]
+
+ if(_roll_at_init):
+ self.reroll()
+
+ def __getitem__(self, idx):
+ dataset_idx, datapoint_idx = self.datapoints[idx]
+ return self.datasets[dataset_idx][datapoint_idx]
+
+ def __len__(self):
+ return self.epoch_len
+
+ def reroll(self):
+ dataset_choices = torch.multinomial(
+ torch.tensor(self.probabilities),
+ num_samples=self.epoch_len,
+ replacement=True,
+ generator=self.generator,
+ )
+
+ self.datapoints = []
+ for dataset_idx in dataset_choices:
+ samples = self._samples[dataset_idx]
+ datapoint_idx = next(samples)
+ self.datapoints.append((dataset_idx, datapoint_idx))
+
+
+class OpenFoldBatchCollator:
+ def __init__(self, config, stage="train"):
+ self.stage = stage
+ self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
+
+ def __call__(self, raw_prots):
+ processed_prots = []
+ for prot in raw_prots:
+ features = self.feature_pipeline.process_features(
+ prot, self.stage
+ )
+ processed_prots.append(features)
+
+ stack_fn = partial(torch.stack, dim=0)
+ return dict_multimap(stack_fn, processed_prots)
+
+
+class OpenFoldDataLoader(torch.utils.data.DataLoader):
+ def __init__(self, *args, config, stage="train", generator=None, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.config = config
+ self.stage = stage
+
+ if(generator is None):
+ generator = torch.Generator()
+
+ self.generator = generator
+ self._prep_batch_properties_probs()
+
+ def _prep_batch_properties_probs(self):
+ keyed_probs = []
+ stage_cfg = self.config[self.stage]
+
+ max_iters = self.config.common.max_recycling_iters
+ if(stage_cfg.supervised):
+ clamp_prob = self.config.supervised.clamp_prob
+ keyed_probs.append(
+ ("use_clamped_fape", [1 - clamp_prob, clamp_prob])
+ )
+
+ if(stage_cfg.uniform_recycling):
+ recycling_probs = [
+ 1. / (max_iters + 1) for _ in range(max_iters + 1)
+ ]
+ else:
+ recycling_probs = [
+ 0. for _ in range(max_iters + 1)
+ ]
+ recycling_probs[-1] = 1.
+
+ keyed_probs.append(
+ ("no_recycling_iters", recycling_probs)
+ )
+
+ keys, probs = zip(*keyed_probs)
+ max_len = max([len(p) for p in probs])
+ padding = [[0.] * (max_len - len(p)) for p in probs]
+
+ self.prop_keys = keys
+ self.prop_probs_tensor = torch.tensor(
+ [p + pad for p, pad in zip(probs, padding)],
+ dtype=torch.float32,
+ )
+
+ def _add_batch_properties(self, batch):
+ samples = torch.multinomial(
+ self.prop_probs_tensor,
+ num_samples=1, # 1 per row
+ replacement=True,
+ generator=self.generator
+ )
+
+ aatype = batch["aatype"]
+ batch_dims = aatype.shape[:-2]
+ recycling_dim = aatype.shape[-1]
+ no_recycling = recycling_dim
+ for i, key in enumerate(self.prop_keys):
+ sample = int(samples[i][0])
+ sample_tensor = torch.tensor(
+ sample,
+ device=aatype.device,
+ requires_grad=False
+ )
+ orig_shape = sample_tensor.shape
+ sample_tensor = sample_tensor.view(
+ (1,) * len(batch_dims) + sample_tensor.shape + (1,)
+ )
+ sample_tensor = sample_tensor.expand(
+ batch_dims + orig_shape + (recycling_dim,)
+ )
+ batch[key] = sample_tensor
+
+ if(key == "no_recycling_iters"):
+ no_recycling = sample
+
+ resample_recycling = lambda t: t[..., :no_recycling + 1]
+ batch = tensor_tree_map(resample_recycling, batch)
+
+ return batch
+
+ def __iter__(self):
+ it = super().__iter__()
+
+ def _batch_prop_gen(iterator):
+ for batch in iterator:
+ yield self._add_batch_properties(batch)
+
+ return _batch_prop_gen(it)
+
+
+class OpenFoldDataModule(pl.LightningDataModule):
+ def __init__(self,
+ config: mlc.ConfigDict,
+ template_mmcif_dir: str,
+ max_template_date: str,
+ train_data_dir: Optional[str] = None,
+ train_alignment_dir: Optional[str] = None,
+ train_chain_data_cache_path: Optional[str] = None,
+ distillation_data_dir: Optional[str] = None,
+ distillation_alignment_dir: Optional[str] = None,
+ distillation_chain_data_cache_path: Optional[str] = None,
+ val_data_dir: Optional[str] = None,
+ val_alignment_dir: Optional[str] = None,
+ predict_data_dir: Optional[str] = None,
+ predict_alignment_dir: Optional[str] = None,
+ kalign_binary_path: str = '/usr/bin/kalign',
+ train_mapping_path: Optional[str] = None,
+ distillation_mapping_path: Optional[str] = None,
+ obsolete_pdbs_file_path: Optional[str] = None,
+ template_release_dates_cache_path: Optional[str] = None,
+ batch_seed: Optional[int] = None,
+ train_epoch_len: int = 50000,
+ _alignment_index_path: Optional[str] = None,
+ **kwargs
+ ):
+ super(OpenFoldDataModule, self).__init__()
+
+ self.config = config
+ self.template_mmcif_dir = template_mmcif_dir
+ self.max_template_date = max_template_date
+ self.train_data_dir = train_data_dir
+ self.train_alignment_dir = train_alignment_dir
+ self.train_chain_data_cache_path = train_chain_data_cache_path
+ self.distillation_data_dir = distillation_data_dir
+ self.distillation_alignment_dir = distillation_alignment_dir
+ self.distillation_chain_data_cache_path = (
+ distillation_chain_data_cache_path
+ )
+ self.val_data_dir = val_data_dir
+ self.val_alignment_dir = val_alignment_dir
+ self.predict_data_dir = predict_data_dir
+ self.predict_alignment_dir = predict_alignment_dir
+ self.kalign_binary_path = kalign_binary_path
+ self.train_mapping_path = train_mapping_path
+ self.distillation_mapping_path = distillation_mapping_path
+ self.template_release_dates_cache_path = (
+ template_release_dates_cache_path
+ )
+ self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
+ self.batch_seed = batch_seed
+ self.train_epoch_len = train_epoch_len
+
+ if(self.train_data_dir is None and self.predict_data_dir is None):
+ raise ValueError(
+ 'At least one of train_data_dir or predict_data_dir must be '
+ 'specified'
+ )
+
+ self.training_mode = self.train_data_dir is not None
+
+ if(self.training_mode and train_alignment_dir is None):
+ raise ValueError(
+ 'In training mode, train_alignment_dir must be specified'
+ )
+ elif(not self.training_mode and predict_alignment_dir is None):
+ raise ValueError(
+ 'In inference mode, predict_alignment_dir must be specified'
+ )
+ elif(val_data_dir is not None and val_alignment_dir is None):
+ raise ValueError(
+ 'If val_data_dir is specified, val_alignment_dir must '
+ 'be specified as well'
+ )
+
+ # An ad-hoc measure for our particular filesystem restrictions
+ self._alignment_index = None
+ if(_alignment_index_path is not None):
+ with open(_alignment_index_path, "r") as fp:
+ self._alignment_index = json.load(fp)
+
+ def setup(self):
+ # Most of the arguments are the same for the three datasets
+ dataset_gen = partial(OpenFoldSingleDataset,
+ template_mmcif_dir=self.template_mmcif_dir,
+ max_template_date=self.max_template_date,
+ config=self.config,
+ kalign_binary_path=self.kalign_binary_path,
+ template_release_dates_cache_path=
+ self.template_release_dates_cache_path,
+ obsolete_pdbs_file_path=
+ self.obsolete_pdbs_file_path,
+ )
+
+ if(self.training_mode):
+ train_dataset = dataset_gen(
+ data_dir=self.train_data_dir,
+ alignment_dir=self.train_alignment_dir,
+ mapping_path=self.train_mapping_path,
+ max_template_hits=self.config.train.max_template_hits,
+ shuffle_top_k_prefiltered=
+ self.config.train.shuffle_top_k_prefiltered,
+ treat_pdb_as_distillation=False,
+ mode="train",
+ _output_raw=True,
+ _alignment_index=self._alignment_index,
+ )
+
+ distillation_dataset = None
+ if(self.distillation_data_dir is not None):
+ distillation_dataset = dataset_gen(
+ data_dir=self.distillation_data_dir,
+ alignment_dir=self.distillation_alignment_dir,
+ mapping_path=self.distillation_mapping_path,
+ max_template_hits=self.train.max_template_hits,
+ treat_pdb_as_distillation=True,
+ mode="train",
+ _output_raw=True,
+ )
+
+ d_prob = self.config.train.distillation_prob
+
+ if(distillation_dataset is not None):
+ datasets = [train_dataset, distillation_dataset]
+ d_prob = self.config.train.distillation_prob
+ probabilities = [1 - d_prob, d_prob]
+ chain_data_cache_paths = [
+ self.train_chain_data_cache_path,
+ self.distillation_chain_data_cache_path,
+ ]
+ else:
+ datasets = [train_dataset]
+ probabilities = [1.]
+ chain_data_cache_paths = [
+ self.train_chain_data_cache_path,
+ ]
+
+ self.train_dataset = OpenFoldDataset(
+ datasets=datasets,
+ probabilities=probabilities,
+ epoch_len=self.train_epoch_len,
+ chain_data_cache_paths=chain_data_cache_paths,
+ _roll_at_init=False,
+ )
+
+ if(self.val_data_dir is not None):
+ self.eval_dataset = dataset_gen(
+ data_dir=self.val_data_dir,
+ alignment_dir=self.val_alignment_dir,
+ mapping_path=None,
+ max_template_hits=self.config.eval.max_template_hits,
+ mode="eval",
+ _output_raw=True,
+ )
+ else:
+ self.eval_dataset = None
+ else:
+ self.predict_dataset = dataset_gen(
+ data_dir=self.predict_data_dir,
+ alignment_dir=self.predict_alignment_dir,
+ mapping_path=None,
+ max_template_hits=self.config.predict.max_template_hits,
+ mode="predict",
+ )
+
+ def _gen_dataloader(self, stage):
+ generator = torch.Generator()
+ if(self.batch_seed is not None):
+ generator = generator.manual_seed(self.batch_seed)
+
+ dataset = None
+ if(stage == "train"):
+ dataset = self.train_dataset
+
+ # Filter the dataset, if necessary
+ dataset.reroll()
+ elif(stage == "eval"):
+ dataset = self.eval_dataset
+ elif(stage == "predict"):
+ dataset = self.predict_dataset
+ else:
+ raise ValueError("Invalid stage")
+
+ batch_collator = OpenFoldBatchCollator(self.config, stage)
+
+ dl = OpenFoldDataLoader(
+ dataset,
+ config=self.config,
+ stage=stage,
+ generator=generator,
+ batch_size=self.config.data_module.data_loaders.batch_size,
+ num_workers=self.config.data_module.data_loaders.num_workers,
+ collate_fn=batch_collator,
+ )
+
+ return dl
+
+ def train_dataloader(self):
+ return self._gen_dataloader("train")
+
+ def val_dataloader(self):
+ if(self.eval_dataset is not None):
+ return self._gen_dataloader("eval")
+ return None
+
+ def predict_dataloader(self):
+ return self._gen_dataloader("predict")
+
+
+class DummyDataset(torch.utils.data.Dataset):
+ def __init__(self, batch_path):
+ with open(batch_path, "rb") as f:
+ self.batch = pickle.load(f)
+
+ def __getitem__(self, idx):
+ return copy.deepcopy(self.batch)
+
+ def __len__(self):
+ return 1000
+
+
+class DummyDataLoader(pl.LightningDataModule):
+ def __init__(self, batch_path):
+ super().__init__()
+ self.dataset = DummyDataset(batch_path)
+
+ def train_dataloader(self):
+ return torch.utils.data.DataLoader(self.dataset)
diff --git a/openfold/data/data_pipeline.py b/openfold/data/data_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb6419b7e1c48679005812cb522078f740978239
--- /dev/null
+++ b/openfold/data/data_pipeline.py
@@ -0,0 +1,678 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import datetime
+from multiprocessing import cpu_count
+from typing import Mapping, Optional, Sequence, Any
+
+import numpy as np
+
+from openfold.data import templates, parsers, mmcif_parsing
+from openfold.data.tools import jackhmmer, hhblits, hhsearch
+from openfold.data.tools.utils import to_date
+from openfold.np import residue_constants, protein
+
+
+FeatureDict = Mapping[str, np.ndarray]
+
+def empty_template_feats(n_res) -> FeatureDict:
+ return {
+ "template_aatype": np.zeros((0, n_res)).astype(np.int64),
+ "template_all_atom_positions":
+ np.zeros((0, n_res, 37, 3)).astype(np.float32),
+ "template_sum_probs": np.zeros((0, 1)).astype(np.float32),
+ "template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32),
+ }
+
+
+def make_template_features(
+ input_sequence: str,
+ hits: Sequence[Any],
+ template_featurizer: Any,
+ query_pdb_code: Optional[str] = None,
+ query_release_date: Optional[str] = None,
+) -> FeatureDict:
+ hits_cat = sum(hits.values(), [])
+ if(len(hits_cat) == 0 or template_featurizer is None):
+ template_features = empty_template_feats(len(input_sequence))
+ else:
+ templates_result = template_featurizer.get_templates(
+ query_sequence=input_sequence,
+ query_pdb_code=query_pdb_code,
+ query_release_date=query_release_date,
+ hits=hits_cat,
+ )
+ template_features = templates_result.features
+
+ # The template featurizer doesn't format empty template features
+ # properly. This is a quick fix.
+ if(template_features["template_aatype"].shape[0] == 0):
+ template_features = empty_template_feats(len(input_sequence))
+
+ return template_features
+
+
+def make_sequence_features(
+ sequence: str, description: str, num_res: int
+) -> FeatureDict:
+ """Construct a feature dict of sequence features."""
+ features = {}
+ features["aatype"] = residue_constants.sequence_to_onehot(
+ sequence=sequence,
+ mapping=residue_constants.restype_order_with_x,
+ map_unknown_to_x=True,
+ )
+ features["between_segment_residues"] = np.zeros((num_res,), dtype=int)
+ features["domain_name"] = np.array(
+ [description.encode("utf-8")], dtype=np.object_
+ )
+ features["residue_index"] = np.array(range(num_res), dtype=int)
+ features["seq_length"] = np.array([num_res] * num_res, dtype=int)
+ features["sequence"] = np.array(
+ [sequence.encode("utf-8")], dtype=np.object_
+ )
+ return features
+
+
+def make_mmcif_features(
+ mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
+) -> FeatureDict:
+ input_sequence = mmcif_object.chain_to_seqres[chain_id]
+ description = "_".join([mmcif_object.file_id, chain_id])
+ num_res = len(input_sequence)
+
+ mmcif_feats = {}
+
+ mmcif_feats.update(
+ make_sequence_features(
+ sequence=input_sequence,
+ description=description,
+ num_res=num_res,
+ )
+ )
+
+ all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
+ mmcif_object=mmcif_object, chain_id=chain_id
+ )
+ mmcif_feats["all_atom_positions"] = all_atom_positions
+ mmcif_feats["all_atom_mask"] = all_atom_mask
+
+ mmcif_feats["resolution"] = np.array(
+ [mmcif_object.header["resolution"]], dtype=np.float32
+ )
+
+ mmcif_feats["release_date"] = np.array(
+ [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
+ )
+
+ mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
+
+ return mmcif_feats
+
+
+def _aatype_to_str_sequence(aatype):
+ return ''.join([
+ residue_constants.restypes_with_x[aatype[i]]
+ for i in range(len(aatype))
+ ])
+
+
+def make_protein_features(
+ protein_object: protein.Protein,
+ description: str,
+ _is_distillation: bool = False,
+) -> FeatureDict:
+ pdb_feats = {}
+ aatype = protein_object.aatype
+ sequence = _aatype_to_str_sequence(aatype)
+ pdb_feats.update(
+ make_sequence_features(
+ sequence=sequence,
+ description=description,
+ num_res=len(protein_object.aatype),
+ )
+ )
+
+ all_atom_positions = protein_object.atom_positions
+ all_atom_mask = protein_object.atom_mask
+
+ pdb_feats["all_atom_positions"] = all_atom_positions.astype(np.float32)
+ pdb_feats["all_atom_mask"] = all_atom_mask.astype(np.float32)
+
+ pdb_feats["resolution"] = np.array([0.]).astype(np.float32)
+ pdb_feats["is_distillation"] = np.array(
+ 1. if _is_distillation else 0.
+ ).astype(np.float32)
+
+ return pdb_feats
+
+
+def make_pdb_features(
+ protein_object: protein.Protein,
+ description: str,
+ confidence_threshold: float = 0.5,
+ is_distillation: bool = True,
+) -> FeatureDict:
+ pdb_feats = make_protein_features(
+ protein_object, description, _is_distillation=True
+ )
+
+ if(is_distillation):
+ high_confidence = protein_object.b_factors > confidence_threshold
+ high_confidence = np.any(high_confidence, axis=-1)
+ for i, confident in enumerate(high_confidence):
+ if(not confident):
+ pdb_feats["all_atom_mask"][i] = 0
+
+ return pdb_feats
+
+
+def make_msa_features(
+ msas: Sequence[Sequence[str]],
+ deletion_matrices: Sequence[parsers.DeletionMatrix],
+) -> FeatureDict:
+ """Constructs a feature dict of MSA features."""
+ if not msas:
+ raise ValueError("At least one MSA must be provided.")
+
+ int_msa = []
+ deletion_matrix = []
+ seen_sequences = set()
+ for msa_index, msa in enumerate(msas):
+ if not msa:
+ raise ValueError(
+ f"MSA {msa_index} must contain at least one sequence."
+ )
+ for sequence_index, sequence in enumerate(msa):
+ if sequence in seen_sequences:
+ continue
+ seen_sequences.add(sequence)
+ int_msa.append(
+ [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]
+ )
+ deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
+
+ num_res = len(msas[0][0])
+ num_alignments = len(int_msa)
+ features = {}
+ features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=int)
+ features["msa"] = np.array(int_msa, dtype=int)
+ features["num_alignments"] = np.array(
+ [num_alignments] * num_res, dtype=int
+ )
+ return features
+
+
+class AlignmentRunner:
+ """Runs alignment tools and saves the results"""
+ def __init__(
+ self,
+ jackhmmer_binary_path: Optional[str] = None,
+ hhblits_binary_path: Optional[str] = None,
+ hhsearch_binary_path: Optional[str] = None,
+ uniref90_database_path: Optional[str] = None,
+ mgnify_database_path: Optional[str] = None,
+ bfd_database_path: Optional[str] = None,
+ uniclust30_database_path: Optional[str] = None,
+ pdb70_database_path: Optional[str] = None,
+ use_small_bfd: Optional[bool] = None,
+ no_cpus: Optional[int] = None,
+ uniref_max_hits: int = 10000,
+ mgnify_max_hits: int = 5000,
+ ):
+ """
+ Args:
+ jackhmmer_binary_path:
+ Path to jackhmmer binary
+ hhblits_binary_path:
+ Path to hhblits binary
+ hhsearch_binary_path:
+ Path to hhsearch binary
+ uniref90_database_path:
+ Path to uniref90 database. If provided, jackhmmer_binary_path
+ must also be provided
+ mgnify_database_path:
+ Path to mgnify database. If provided, jackhmmer_binary_path
+ must also be provided
+ bfd_database_path:
+ Path to BFD database. Depending on the value of use_small_bfd,
+ one of hhblits_binary_path or jackhmmer_binary_path must be
+ provided.
+ uniclust30_database_path:
+ Path to uniclust30. Searched alongside BFD if use_small_bfd is
+ false.
+ pdb70_database_path:
+ Path to pdb70 database.
+ use_small_bfd:
+ Whether to search the BFD database alone with jackhmmer or
+ in conjunction with uniclust30 with hhblits.
+ no_cpus:
+ The number of CPUs available for alignment. By default, all
+ CPUs are used.
+ uniref_max_hits:
+ Max number of uniref hits
+ mgnify_max_hits:
+ Max number of mgnify hits
+ """
+ db_map = {
+ "jackhmmer": {
+ "binary": jackhmmer_binary_path,
+ "dbs": [
+ uniref90_database_path,
+ mgnify_database_path,
+ bfd_database_path if use_small_bfd else None,
+ ],
+ },
+ "hhblits": {
+ "binary": hhblits_binary_path,
+ "dbs": [
+ bfd_database_path if not use_small_bfd else None,
+ ],
+ },
+ "hhsearch": {
+ "binary": hhsearch_binary_path,
+ "dbs": [
+ pdb70_database_path,
+ ],
+ },
+ }
+
+ for name, dic in db_map.items():
+ binary, dbs = dic["binary"], dic["dbs"]
+ if(binary is None and not all([x is None for x in dbs])):
+ raise ValueError(
+ f"{name} DBs provided but {name} binary is None"
+ )
+
+ if(not all([x is None for x in db_map["hhsearch"]["dbs"]])
+ and uniref90_database_path is None):
+ raise ValueError(
+ """uniref90_database_path must be specified in order to perform
+ template search"""
+ )
+
+ self.uniref_max_hits = uniref_max_hits
+ self.mgnify_max_hits = mgnify_max_hits
+ self.use_small_bfd = use_small_bfd
+
+ if(no_cpus is None):
+ no_cpus = cpu_count()
+
+ self.jackhmmer_uniref90_runner = None
+ if(jackhmmer_binary_path is not None and
+ uniref90_database_path is not None
+ ):
+ self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
+ binary_path=jackhmmer_binary_path,
+ database_path=uniref90_database_path,
+ n_cpu=no_cpus,
+ )
+
+ self.jackhmmer_small_bfd_runner = None
+ self.hhblits_bfd_uniclust_runner = None
+ if(bfd_database_path is not None):
+ if use_small_bfd:
+ self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
+ binary_path=jackhmmer_binary_path,
+ database_path=bfd_database_path,
+ n_cpu=no_cpus,
+ )
+ else:
+ dbs = [bfd_database_path]
+ if(uniclust30_database_path is not None):
+ dbs.append(uniclust30_database_path)
+ self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
+ binary_path=hhblits_binary_path,
+ databases=dbs,
+ n_cpu=no_cpus,
+ )
+
+ self.jackhmmer_mgnify_runner = None
+ if(mgnify_database_path is not None):
+ self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
+ binary_path=jackhmmer_binary_path,
+ database_path=mgnify_database_path,
+ n_cpu=no_cpus,
+ )
+
+ self.hhsearch_pdb70_runner = None
+ if(pdb70_database_path is not None):
+ self.hhsearch_pdb70_runner = hhsearch.HHSearch(
+ binary_path=hhsearch_binary_path,
+ databases=[pdb70_database_path],
+ n_cpu=no_cpus,
+ )
+
+ def run(
+ self,
+ fasta_path: str,
+ output_dir: str,
+ ):
+ """Runs alignment tools on a sequence"""
+ if(self.jackhmmer_uniref90_runner is not None):
+ jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
+ fasta_path
+ )[0]
+ uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
+ jackhmmer_uniref90_result["sto"],
+ max_sequences=self.uniref_max_hits
+ )
+ uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
+ with open(uniref90_out_path, "w") as f:
+ f.write(uniref90_msa_as_a3m)
+
+ if(self.hhsearch_pdb70_runner is not None):
+ hhsearch_result = self.hhsearch_pdb70_runner.query(
+ uniref90_msa_as_a3m
+ )
+ pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr")
+ with open(pdb70_out_path, "w") as f:
+ f.write(hhsearch_result)
+
+ if(self.jackhmmer_mgnify_runner is not None):
+ jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
+ fasta_path
+ )[0]
+ mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
+ jackhmmer_mgnify_result["sto"],
+ max_sequences=self.mgnify_max_hits
+ )
+ mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
+ with open(mgnify_out_path, "w") as f:
+ f.write(mgnify_msa_as_a3m)
+
+ if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None):
+ jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
+ fasta_path
+ )[0]
+ bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
+ with open(bfd_out_path, "w") as f:
+ f.write(jackhmmer_small_bfd_result["sto"])
+ elif(self.hhblits_bfd_uniclust_runner is not None):
+ hhblits_bfd_uniclust_result = (
+ self.hhblits_bfd_uniclust_runner.query(fasta_path)
+ )
+ if output_dir is not None:
+ bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
+ with open(bfd_out_path, "w") as f:
+ f.write(hhblits_bfd_uniclust_result["a3m"])
+
+
+class DataPipeline:
+ """Assembles input features."""
+ def __init__(
+ self,
+ template_featurizer: Optional[templates.TemplateHitFeaturizer],
+ ):
+ self.template_featurizer = template_featurizer
+
+ def _parse_msa_data(
+ self,
+ alignment_dir: str,
+ _alignment_index: Optional[Any] = None,
+ ) -> Mapping[str, Any]:
+ msa_data = {}
+
+ if(_alignment_index is not None):
+ fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")
+
+ def read_msa(start, size):
+ fp.seek(start)
+ msa = fp.read(size).decode("utf-8")
+ return msa
+
+ for (name, start, size) in _alignment_index["files"]:
+ ext = os.path.splitext(name)[-1]
+
+ if(ext == ".a3m"):
+ msa, deletion_matrix = parsers.parse_a3m(
+ read_msa(start, size)
+ )
+ data = {"msa": msa, "deletion_matrix": deletion_matrix}
+ elif(ext == ".sto"):
+ msa, deletion_matrix, _ = parsers.parse_stockholm(
+ read_msa(start, size)
+ )
+ data = {"msa": msa, "deletion_matrix": deletion_matrix}
+ else:
+ continue
+
+ msa_data[name] = data
+
+ fp.close()
+ else:
+ for f in os.listdir(alignment_dir):
+ path = os.path.join(alignment_dir, f)
+ ext = os.path.splitext(f)[-1]
+
+ if(ext == ".a3m"):
+ with open(path, "r") as fp:
+ msa, deletion_matrix = parsers.parse_a3m(fp.read())
+ data = {"msa": msa, "deletion_matrix": deletion_matrix}
+ elif(ext == ".sto"):
+ with open(path, "r") as fp:
+ msa, deletion_matrix, _ = parsers.parse_stockholm(
+ fp.read()
+ )
+ data = {"msa": msa, "deletion_matrix": deletion_matrix}
+ else:
+ continue
+
+ msa_data[f] = data
+
+ return msa_data
+
+ def _parse_template_hits(
+ self,
+ alignment_dir: str,
+ _alignment_index: Optional[Any] = None
+ ) -> Mapping[str, Any]:
+ all_hits = {}
+ if(_alignment_index is not None):
+ fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb')
+
+ def read_template(start, size):
+ fp.seek(start)
+ return fp.read(size).decode("utf-8")
+
+ for (name, start, size) in _alignment_index["files"]:
+ ext = os.path.splitext(name)[-1]
+
+ if(ext == ".hhr"):
+ hits = parsers.parse_hhr(read_template(start, size))
+ all_hits[name] = hits
+
+ fp.close()
+ else:
+ for f in os.listdir(alignment_dir):
+ path = os.path.join(alignment_dir, f)
+ ext = os.path.splitext(f)[-1]
+
+ if(ext == ".hhr"):
+ with open(path, "r") as fp:
+ hits = parsers.parse_hhr(fp.read())
+ all_hits[f] = hits
+
+ return all_hits
+
+ def _process_msa_feats(
+ self,
+ alignment_dir: str,
+ input_sequence: Optional[str] = None,
+ _alignment_index: Optional[str] = None
+ ) -> Mapping[str, Any]:
+ msa_data = self._parse_msa_data(alignment_dir, _alignment_index)
+
+ if(len(msa_data) == 0):
+ if(input_sequence is None):
+ raise ValueError(
+ """
+ If the alignment dir contains no MSAs, an input sequence
+ must be provided.
+ """
+ )
+ msa_data["dummy"] = {
+ "msa": [input_sequence],
+ "deletion_matrix": [[0 for _ in input_sequence]],
+ }
+
+ msas, deletion_matrices = zip(*[
+ (v["msa"], v["deletion_matrix"]) for v in msa_data.values()
+ ])
+
+ msa_features = make_msa_features(
+ msas=msas,
+ deletion_matrices=deletion_matrices,
+ )
+
+ return msa_features
+
+ def process_fasta(
+ self,
+ fasta_path: str,
+ alignment_dir: str,
+ _alignment_index: Optional[str] = None,
+ ) -> FeatureDict:
+ """Assembles features for a single sequence in a FASTA file"""
+ with open(fasta_path) as f:
+ fasta_str = f.read()
+ input_seqs, input_descs = parsers.parse_fasta(fasta_str)
+ if len(input_seqs) != 1:
+ raise ValueError(
+ f"More than one input sequence found in {fasta_path}."
+ )
+ input_sequence = input_seqs[0]
+ input_description = input_descs[0]
+ num_res = len(input_sequence)
+
+ hits = self._parse_template_hits(alignment_dir, _alignment_index)
+ template_features = make_template_features(
+ input_sequence,
+ hits,
+ self.template_featurizer,
+ )
+
+ sequence_features = make_sequence_features(
+ sequence=input_sequence,
+ description=input_description,
+ num_res=num_res,
+ )
+
+ msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
+
+ return {
+ **sequence_features,
+ **msa_features,
+ **template_features
+ }
+
+ def process_mmcif(
+ self,
+ mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
+ alignment_dir: str,
+ chain_id: Optional[str] = None,
+ _alignment_index: Optional[str] = None,
+ ) -> FeatureDict:
+ """
+ Assembles features for a specific chain in an mmCIF object.
+
+ If chain_id is None, it is assumed that there is only one chain
+ in the object. Otherwise, a ValueError is thrown.
+ """
+ if chain_id is None:
+ chains = mmcif.structure.get_chains()
+ chain = next(chains, None)
+ if chain is None:
+ raise ValueError("No chains in mmCIF file")
+ chain_id = chain.id
+
+ mmcif_feats = make_mmcif_features(mmcif, chain_id)
+
+ input_sequence = mmcif.chain_to_seqres[chain_id]
+ hits = self._parse_template_hits(alignment_dir, _alignment_index)
+ template_features = make_template_features(
+ input_sequence,
+ hits,
+ self.template_featurizer,
+ query_release_date=to_date(mmcif.header["release_date"])
+ )
+
+ msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
+
+ return {**mmcif_feats, **template_features, **msa_features}
+
+ def process_pdb(
+ self,
+ pdb_path: str,
+ alignment_dir: str,
+ is_distillation: bool = True,
+ chain_id: Optional[str] = None,
+ _alignment_index: Optional[str] = None,
+ ) -> FeatureDict:
+ """
+ Assembles features for a protein in a PDB file.
+ """
+ with open(pdb_path, 'r') as f:
+ pdb_str = f.read()
+
+ protein_object = protein.from_pdb_string(pdb_str, chain_id)
+ input_sequence = _aatype_to_str_sequence(protein_object.aatype)
+ description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
+ pdb_feats = make_pdb_features(
+ protein_object,
+ description,
+ is_distillation
+ )
+
+ hits = self._parse_template_hits(alignment_dir, _alignment_index)
+ template_features = make_template_features(
+ input_sequence,
+ hits,
+ self.template_featurizer,
+ )
+
+ msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
+
+ return {**pdb_feats, **template_features, **msa_features}
+
+ def process_core(
+ self,
+ core_path: str,
+ alignment_dir: str,
+ _alignment_index: Optional[str] = None,
+ ) -> FeatureDict:
+ """
+ Assembles features for a protein in a ProteinNet .core file.
+ """
+ with open(core_path, 'r') as f:
+ core_str = f.read()
+
+ protein_object = protein.from_proteinnet_string(core_str)
+ input_sequence = _aatype_to_str_sequence(protein_object.aatype)
+ description = os.path.splitext(os.path.basename(core_path))[0].upper()
+ core_feats = make_protein_features(protein_object, description)
+
+ hits = self._parse_template_hits(alignment_dir, _alignment_index)
+ template_features = make_template_features(
+ input_sequence,
+ hits,
+ self.template_featurizer,
+ )
+
+ msa_features = self._process_msa_feats(alignment_dir, input_sequence)
+
+ return {**core_feats, **template_features, **msa_features}
+
diff --git a/openfold/data/data_transforms.py b/openfold/data/data_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f5b19f4c74b035563944894984de59214bb48dd
--- /dev/null
+++ b/openfold/data/data_transforms.py
@@ -0,0 +1,1191 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+from functools import reduce, wraps
+from operator import add
+
+import numpy as np
+import torch
+
+from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
+from openfold.np import residue_constants as rc
+from openfold.utils.rigid_utils import Rotation, Rigid
+from openfold.utils.tensor_utils import (
+ tree_map,
+ tensor_tree_map,
+ batched_gather,
+)
+
+
+MSA_FEATURE_NAMES = [
+ "msa",
+ "deletion_matrix",
+ "msa_mask",
+ "msa_row_mask",
+ "bert_mask",
+ "true_msa",
+]
+
+
+def cast_to_64bit_ints(protein):
+ # We keep all ints as int64
+ for k, v in protein.items():
+ if v.dtype == torch.int32:
+ protein[k] = v.type(torch.int64)
+
+ return protein
+
+
+def make_one_hot(x, num_classes):
+ x_one_hot = torch.zeros(*x.shape, num_classes)
+ x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
+ return x_one_hot
+
+
+def make_seq_mask(protein):
+ protein["seq_mask"] = torch.ones(
+ protein["aatype"].shape, dtype=torch.float32
+ )
+ return protein
+
+
+def make_template_mask(protein):
+ protein["template_mask"] = torch.ones(
+ protein["template_aatype"].shape[0], dtype=torch.float32
+ )
+ return protein
+
+
+def curry1(f):
+ """Supply all arguments but the first."""
+ @wraps(f)
+ def fc(*args, **kwargs):
+ return lambda x: f(x, *args, **kwargs)
+
+ return fc
+
+
+def make_all_atom_aatype(protein):
+ protein["all_atom_aatype"] = protein["aatype"]
+ return protein
+
+
+def fix_templates_aatype(protein):
+ # Map one-hot to indices
+ num_templates = protein["template_aatype"].shape[0]
+ if(num_templates > 0):
+ protein["template_aatype"] = torch.argmax(
+ protein["template_aatype"], dim=-1
+ )
+ # Map hhsearch-aatype to our aatype.
+ new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
+ new_order = torch.tensor(new_order_list, dtype=torch.int64).expand(
+ num_templates, -1
+ )
+ protein["template_aatype"] = torch.gather(
+ new_order, 1, index=protein["template_aatype"]
+ )
+
+ return protein
+
+
+def correct_msa_restypes(protein):
+ """Correct MSA restype to have the same order as rc."""
+ new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
+ new_order = torch.tensor(
+ [new_order_list] * protein["msa"].shape[1], dtype=protein["msa"].dtype
+ ).transpose(0, 1)
+ protein["msa"] = torch.gather(new_order, 0, protein["msa"])
+
+ perm_matrix = np.zeros((22, 22), dtype=np.float32)
+ perm_matrix[range(len(new_order_list)), new_order_list] = 1.0
+
+ for k in protein:
+ if "profile" in k:
+ num_dim = protein[k].shape.as_list()[-1]
+ assert num_dim in [
+ 20,
+ 21,
+ 22,
+ ], "num_dim for %s out of expected range: %s" % (k, num_dim)
+ protein[k] = torch.dot(protein[k], perm_matrix[:num_dim, :num_dim])
+
+ return protein
+
+
+def squeeze_features(protein):
+ """Remove singleton and repeated dimensions in protein features."""
+ protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
+ for k in [
+ "domain_name",
+ "msa",
+ "num_alignments",
+ "seq_length",
+ "sequence",
+ "superfamily",
+ "deletion_matrix",
+ "resolution",
+ "between_segment_residues",
+ "residue_index",
+ "template_all_atom_mask",
+ ]:
+ if k in protein:
+ final_dim = protein[k].shape[-1]
+ if isinstance(final_dim, int) and final_dim == 1:
+ if torch.is_tensor(protein[k]):
+ protein[k] = torch.squeeze(protein[k], dim=-1)
+ else:
+ protein[k] = np.squeeze(protein[k], axis=-1)
+
+ for k in ["seq_length", "num_alignments"]:
+ if k in protein:
+ protein[k] = protein[k][0]
+
+ return protein
+
+
+@curry1
+def randomly_replace_msa_with_unknown(protein, replace_proportion):
+ """Replace a portion of the MSA with 'X'."""
+ msa_mask = torch.rand(protein["msa"].shape) < replace_proportion
+ x_idx = 20
+ gap_idx = 21
+ msa_mask = torch.logical_and(msa_mask, protein["msa"] != gap_idx)
+ protein["msa"] = torch.where(
+ msa_mask,
+ torch.ones_like(protein["msa"]) * x_idx,
+ protein["msa"]
+ )
+ aatype_mask = torch.rand(protein["aatype"].shape) < replace_proportion
+
+ protein["aatype"] = torch.where(
+ aatype_mask,
+ torch.ones_like(protein["aatype"]) * x_idx,
+ protein["aatype"],
+ )
+ return protein
+
+
+@curry1
+def sample_msa(protein, max_seq, keep_extra, seed=None):
+ """Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
+ num_seq = protein["msa"].shape[0]
+ g = torch.Generator(device=protein["msa"].device)
+ if seed is not None:
+ g.manual_seed(seed)
+ shuffled = torch.randperm(num_seq - 1, generator=g) + 1
+ index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
+ num_sel = min(max_seq, num_seq)
+ sel_seq, not_sel_seq = torch.split(
+ index_order, [num_sel, num_seq - num_sel]
+ )
+
+ for k in MSA_FEATURE_NAMES:
+ if k in protein:
+ if keep_extra:
+ protein["extra_" + k] = torch.index_select(
+ protein[k], 0, not_sel_seq
+ )
+ protein[k] = torch.index_select(protein[k], 0, sel_seq)
+
+ return protein
+
+
+@curry1
+def add_distillation_flag(protein, distillation):
+ protein['is_distillation'] = distillation
+ return protein
+
+@curry1
+def sample_msa_distillation(protein, max_seq):
+ if(protein["is_distillation"] == 1):
+ protein = sample_msa(max_seq, keep_extra=False)(protein)
+ return protein
+
+
+@curry1
+def crop_extra_msa(protein, max_extra_msa):
+ num_seq = protein["extra_msa"].shape[0]
+ num_sel = min(max_extra_msa, num_seq)
+ select_indices = torch.randperm(num_seq)[:num_sel]
+ for k in MSA_FEATURE_NAMES:
+ if "extra_" + k in protein:
+ protein["extra_" + k] = torch.index_select(
+ protein["extra_" + k], 0, select_indices
+ )
+
+ return protein
+
+
+def delete_extra_msa(protein):
+ for k in MSA_FEATURE_NAMES:
+ if "extra_" + k in protein:
+ del protein["extra_" + k]
+ return protein
+
+
+# Not used in inference
+@curry1
+def block_delete_msa(protein, config):
+ num_seq = protein["msa"].shape[0]
+ block_num_seq = torch.floor(
+ torch.tensor(num_seq, dtype=torch.float32)
+ * config.msa_fraction_per_block
+ ).to(torch.int32)
+
+ if config.randomize_num_blocks:
+ nb = torch.distributions.uniform.Uniform(
+ 0, config.num_blocks + 1
+ ).sample()
+ else:
+ nb = config.num_blocks
+
+ del_block_starts = torch.distributions.Uniform(0, num_seq).sample(nb)
+ del_blocks = del_block_starts[:, None] + torch.range(block_num_seq)
+ del_blocks = torch.clip(del_blocks, 0, num_seq - 1)
+ del_indices = torch.unique(torch.sort(torch.reshape(del_blocks, [-1])))[0]
+
+ # Make sure we keep the original sequence
+ combined = torch.cat((torch.range(1, num_seq)[None], del_indices[None]))
+ uniques, counts = combined.unique(return_counts=True)
+ difference = uniques[counts == 1]
+ intersection = uniques[counts > 1]
+ keep_indices = torch.squeeze(difference, 0)
+
+ for k in MSA_FEATURE_NAMES:
+ if k in protein:
+ protein[k] = torch.gather(protein[k], keep_indices)
+
+ return protein
+
+
+@curry1
+def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
+ weights = torch.cat(
+ [torch.ones(21), gap_agreement_weight * torch.ones(1), torch.zeros(1)],
+ 0,
+ )
+
+ # Make agreement score as weighted Hamming distance
+ msa_one_hot = make_one_hot(protein["msa"], 23)
+ sample_one_hot = protein["msa_mask"][:, :, None] * msa_one_hot
+ extra_msa_one_hot = make_one_hot(protein["extra_msa"], 23)
+ extra_one_hot = protein["extra_msa_mask"][:, :, None] * extra_msa_one_hot
+
+ num_seq, num_res, _ = sample_one_hot.shape
+ extra_num_seq, _, _ = extra_one_hot.shape
+
+ # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
+ # in an optimized fashion to avoid possible memory or computation blowup.
+ agreement = torch.matmul(
+ torch.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
+ torch.reshape(
+ sample_one_hot * weights, [num_seq, num_res * 23]
+ ).transpose(0, 1),
+ )
+
+ # Assign each sequence in the extra sequences to the closest MSA sample
+ protein["extra_cluster_assignment"] = torch.argmax(agreement, dim=1).to(
+ torch.int64
+ )
+
+ return protein
+
+
+def unsorted_segment_sum(data, segment_ids, num_segments):
+ """
+ Computes the sum along segments of a tensor. Similar to
+ tf.unsorted_segment_sum, but only supports 1-D indices.
+
+ :param data: A tensor whose segments are to be summed.
+ :param segment_ids: The 1-D segment indices tensor.
+ :param num_segments: The number of segments.
+ :return: A tensor of same data type as the data argument.
+ """
+ assert (
+ len(segment_ids.shape) == 1 and
+ segment_ids.shape[0] == data.shape[0]
+ )
+ segment_ids = segment_ids.view(
+ segment_ids.shape[0], *((1,) * len(data.shape[1:]))
+ )
+ segment_ids = segment_ids.expand(data.shape)
+ shape = [num_segments] + list(data.shape[1:])
+ tensor = torch.zeros(*shape).scatter_add_(0, segment_ids, data.float())
+ tensor = tensor.type(data.dtype)
+ return tensor
+
+
+@curry1
+def summarize_clusters(protein):
+ """Produce profile and deletion_matrix_mean within each cluster."""
+ num_seq = protein["msa"].shape[0]
+
+ def csum(x):
+ return unsorted_segment_sum(
+ x, protein["extra_cluster_assignment"], num_seq
+ )
+
+ mask = protein["extra_msa_mask"]
+ mask_counts = 1e-6 + protein["msa_mask"] + csum(mask) # Include center
+
+ msa_sum = csum(mask[:, :, None] * make_one_hot(protein["extra_msa"], 23))
+ msa_sum += make_one_hot(protein["msa"], 23) # Original sequence
+ protein["cluster_profile"] = msa_sum / mask_counts[:, :, None]
+ del msa_sum
+
+ del_sum = csum(mask * protein["extra_deletion_matrix"])
+ del_sum += protein["deletion_matrix"] # Original sequence
+ protein["cluster_deletion_mean"] = del_sum / mask_counts
+ del del_sum
+
+ return protein
+
+
+def make_msa_mask(protein):
+ """Mask features are all ones, but will later be zero-padded."""
+ protein["msa_mask"] = torch.ones(protein["msa"].shape, dtype=torch.float32)
+ protein["msa_row_mask"] = torch.ones(
+ (protein["msa"].shape[0]), dtype=torch.float32
+ )
+ return protein
+
+
+def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
+ """Create pseudo beta features."""
+ is_gly = torch.eq(aatype, rc.restype_order["G"])
+ ca_idx = rc.atom_order["CA"]
+ cb_idx = rc.atom_order["CB"]
+ pseudo_beta = torch.where(
+ torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
+ all_atom_positions[..., ca_idx, :],
+ all_atom_positions[..., cb_idx, :],
+ )
+
+ if all_atom_mask is not None:
+ pseudo_beta_mask = torch.where(
+ is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
+ )
+ return pseudo_beta, pseudo_beta_mask
+ else:
+ return pseudo_beta
+
+
+@curry1
+def make_pseudo_beta(protein, prefix=""):
+ """Create pseudo-beta (alpha for glycine) position and mask."""
+ assert prefix in ["", "template_"]
+ (
+ protein[prefix + "pseudo_beta"],
+ protein[prefix + "pseudo_beta_mask"],
+ ) = pseudo_beta_fn(
+ protein["template_aatype" if prefix else "aatype"],
+ protein[prefix + "all_atom_positions"],
+ protein["template_all_atom_mask" if prefix else "all_atom_mask"],
+ )
+ return protein
+
+
+@curry1
+def add_constant_field(protein, key, value):
+ protein[key] = torch.tensor(value)
+ return protein
+
+
+def shaped_categorical(probs, epsilon=1e-10):
+ ds = probs.shape
+ num_classes = ds[-1]
+ distribution = torch.distributions.categorical.Categorical(
+ torch.reshape(probs + epsilon, [-1, num_classes])
+ )
+ counts = distribution.sample()
+ return torch.reshape(counts, ds[:-1])
+
+
+def make_hhblits_profile(protein):
+ """Compute the HHblits MSA profile if not already present."""
+ if "hhblits_profile" in protein:
+ return protein
+
+ # Compute the profile for every residue (over all MSA sequences).
+ msa_one_hot = make_one_hot(protein["msa"], 22)
+
+ protein["hhblits_profile"] = torch.mean(msa_one_hot, dim=0)
+ return protein
+
+
+@curry1
+def make_masked_msa(protein, config, replace_fraction):
+ """Create data for BERT on raw MSA."""
+ # Add a random amino acid uniformly.
+ random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)
+
+ categorical_probs = (
+ config.uniform_prob * random_aa
+ + config.profile_prob * protein["hhblits_profile"]
+ + config.same_prob * make_one_hot(protein["msa"], 22)
+ )
+
+ # Put all remaining probability on [MASK] which is a new column
+ pad_shapes = list(
+ reduce(add, [(0, 0) for _ in range(len(categorical_probs.shape))])
+ )
+ pad_shapes[1] = 1
+ mask_prob = (
+ 1.0 - config.profile_prob - config.same_prob - config.uniform_prob
+ )
+ assert mask_prob >= 0.0
+ categorical_probs = torch.nn.functional.pad(
+ categorical_probs, pad_shapes, value=mask_prob
+ )
+
+ sh = protein["msa"].shape
+ mask_position = torch.rand(sh) < replace_fraction
+
+ bert_msa = shaped_categorical(categorical_probs)
+ bert_msa = torch.where(mask_position, bert_msa, protein["msa"])
+
+ # Mix real and masked MSA
+ protein["bert_mask"] = mask_position.to(torch.float32)
+ protein["true_msa"] = protein["msa"]
+ protein["msa"] = bert_msa
+
+ return protein
+
+
+@curry1
+def make_fixed_size(
+ protein,
+ shape_schema,
+ msa_cluster_size,
+ extra_msa_size,
+ num_res=0,
+ num_templates=0,
+):
+ """Guess at the MSA and sequence dimension to make fixed size."""
+ pad_size_map = {
+ NUM_RES: num_res,
+ NUM_MSA_SEQ: msa_cluster_size,
+ NUM_EXTRA_SEQ: extra_msa_size,
+ NUM_TEMPLATES: num_templates,
+ }
+
+ for k, v in protein.items():
+ # Don't transfer this to the accelerator.
+ if k == "extra_cluster_assignment":
+ continue
+ shape = list(v.shape)
+ schema = shape_schema[k]
+ msg = "Rank mismatch between shape and shape schema for"
+ assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}"
+ pad_size = [
+ pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
+ ]
+
+ padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)]
+ padding.reverse()
+ padding = list(itertools.chain(*padding))
+ if padding:
+ protein[k] = torch.nn.functional.pad(v, padding)
+ protein[k] = torch.reshape(protein[k], pad_size)
+
+ return protein
+
+
+@curry1
+def make_msa_feat(protein):
+ """Create and concatenate MSA features."""
+ # Whether there is a domain break. Always zero for chains, but keeping for
+ # compatibility with domain datasets.
+ has_break = torch.clip(
+ protein["between_segment_residues"].to(torch.float32), 0, 1
+ )
+ aatype_1hot = make_one_hot(protein["aatype"], 21)
+
+ target_feat = [
+ torch.unsqueeze(has_break, dim=-1),
+ aatype_1hot, # Everyone gets the original sequence.
+ ]
+
+ msa_1hot = make_one_hot(protein["msa"], 23)
+ has_deletion = torch.clip(protein["deletion_matrix"], 0.0, 1.0)
+ deletion_value = torch.atan(protein["deletion_matrix"] / 3.0) * (
+ 2.0 / np.pi
+ )
+
+ msa_feat = [
+ msa_1hot,
+ torch.unsqueeze(has_deletion, dim=-1),
+ torch.unsqueeze(deletion_value, dim=-1),
+ ]
+
+ if "cluster_profile" in protein:
+ deletion_mean_value = torch.atan(
+ protein["cluster_deletion_mean"] / 3.0
+ ) * (2.0 / np.pi)
+ msa_feat.extend(
+ [
+ protein["cluster_profile"],
+ torch.unsqueeze(deletion_mean_value, dim=-1),
+ ]
+ )
+
+ if "extra_deletion_matrix" in protein:
+ protein["extra_has_deletion"] = torch.clip(
+ protein["extra_deletion_matrix"], 0.0, 1.0
+ )
+ protein["extra_deletion_value"] = torch.atan(
+ protein["extra_deletion_matrix"] / 3.0
+ ) * (2.0 / np.pi)
+
+ protein["msa_feat"] = torch.cat(msa_feat, dim=-1)
+ protein["target_feat"] = torch.cat(target_feat, dim=-1)
+ return protein
+
+
+@curry1
+def select_feat(protein, feature_list):
+ return {k: v for k, v in protein.items() if k in feature_list}
+
+
+@curry1
+def crop_templates(protein, max_templates):
+ for k, v in protein.items():
+ if k.startswith("template_"):
+ protein[k] = v[:max_templates]
+ return protein
+
+
+def make_atom14_masks(protein):
+ """Construct denser atom positions (14 dimensions instead of 37)."""
+ restype_atom14_to_atom37 = []
+ restype_atom37_to_atom14 = []
+ restype_atom14_mask = []
+
+ for rt in rc.restypes:
+ atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
+ restype_atom14_to_atom37.append(
+ [(rc.atom_order[name] if name else 0) for name in atom_names]
+ )
+ atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
+ restype_atom37_to_atom14.append(
+ [
+ (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
+ for name in rc.atom_types
+ ]
+ )
+
+ restype_atom14_mask.append(
+ [(1.0 if name else 0.0) for name in atom_names]
+ )
+
+ # Add dummy mapping for restype 'UNK'
+ restype_atom14_to_atom37.append([0] * 14)
+ restype_atom37_to_atom14.append([0] * 37)
+ restype_atom14_mask.append([0.0] * 14)
+
+ restype_atom14_to_atom37 = torch.tensor(
+ restype_atom14_to_atom37,
+ dtype=torch.int32,
+ device=protein["aatype"].device,
+ )
+ restype_atom37_to_atom14 = torch.tensor(
+ restype_atom37_to_atom14,
+ dtype=torch.int32,
+ device=protein["aatype"].device,
+ )
+ restype_atom14_mask = torch.tensor(
+ restype_atom14_mask,
+ dtype=torch.float32,
+ device=protein["aatype"].device,
+ )
+ protein_aatype = protein['aatype'].to(torch.long)
+
+ # create the mapping for (residx, atom14) --> atom37, i.e. an array
+ # with shape (num_res, 14) containing the atom37 indices for this protein
+ residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
+ residx_atom14_mask = restype_atom14_mask[protein_aatype]
+
+ protein["atom14_atom_exists"] = residx_atom14_mask
+ protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
+
+ # create the gather indices for mapping back
+ residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
+ protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
+
+ # create the corresponding mask
+ restype_atom37_mask = torch.zeros(
+ [21, 37], dtype=torch.float32, device=protein["aatype"].device
+ )
+ for restype, restype_letter in enumerate(rc.restypes):
+ restype_name = rc.restype_1to3[restype_letter]
+ atom_names = rc.residue_atoms[restype_name]
+ for atom_name in atom_names:
+ atom_type = rc.atom_order[atom_name]
+ restype_atom37_mask[restype, atom_type] = 1
+
+ residx_atom37_mask = restype_atom37_mask[protein_aatype]
+ protein["atom37_atom_exists"] = residx_atom37_mask
+
+ return protein
+
+
+def make_atom14_masks_np(batch):
+ batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
+ out = make_atom14_masks(batch)
+ out = tensor_tree_map(lambda t: np.array(t), out)
+ return out
+
+
+def make_atom14_positions(protein):
+ """Constructs denser atom positions (14 dimensions instead of 37)."""
+ residx_atom14_mask = protein["atom14_atom_exists"]
+ residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]
+
+ # Create a mask for known ground truth positions.
+ residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
+ protein["all_atom_mask"],
+ residx_atom14_to_atom37,
+ dim=-1,
+ no_batch_dims=len(protein["all_atom_mask"].shape[:-1]),
+ )
+
+ # Gather the ground truth positions.
+ residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
+ batched_gather(
+ protein["all_atom_positions"],
+ residx_atom14_to_atom37,
+ dim=-2,
+ no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
+ )
+ )
+
+ protein["atom14_atom_exists"] = residx_atom14_mask
+ protein["atom14_gt_exists"] = residx_atom14_gt_mask
+ protein["atom14_gt_positions"] = residx_atom14_gt_positions
+
+ # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
+ # alternative ground truth coordinates where the naming is swapped
+ restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
+ restype_3 += ["UNK"]
+
+ # Matrices for renaming ambiguous atoms.
+ all_matrices = {
+ res: torch.eye(
+ 14,
+ dtype=protein["all_atom_mask"].dtype,
+ device=protein["all_atom_mask"].device,
+ )
+ for res in restype_3
+ }
+ for resname, swap in rc.residue_atom_renaming_swaps.items():
+ correspondences = torch.arange(
+ 14, device=protein["all_atom_mask"].device
+ )
+ for source_atom_swap, target_atom_swap in swap.items():
+ source_index = rc.restype_name_to_atom14_names[resname].index(
+ source_atom_swap
+ )
+ target_index = rc.restype_name_to_atom14_names[resname].index(
+ target_atom_swap
+ )
+ correspondences[source_index] = target_index
+ correspondences[target_index] = source_index
+ renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
+ for index, correspondence in enumerate(correspondences):
+ renaming_matrix[index, correspondence] = 1.0
+ all_matrices[resname] = renaming_matrix
+ renaming_matrices = torch.stack(
+ [all_matrices[restype] for restype in restype_3]
+ )
+
+ # Pick the transformation matrices for the given residue sequence
+ # shape (num_res, 14, 14).
+ renaming_transform = renaming_matrices[protein["aatype"]]
+
+ # Apply it to the ground truth positions. shape (num_res, 14, 3).
+ alternative_gt_positions = torch.einsum(
+ "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
+ )
+ protein["atom14_alt_gt_positions"] = alternative_gt_positions
+
+ # Create the mask for the alternative ground truth (differs from the
+ # ground truth mask, if only one of the atoms in an ambiguous pair has a
+ # ground truth position).
+ alternative_gt_mask = torch.einsum(
+ "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
+ )
+ protein["atom14_alt_gt_exists"] = alternative_gt_mask
+
+ # Create an ambiguous atoms mask. shape: (21, 14).
+ restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
+ for resname, swap in rc.residue_atom_renaming_swaps.items():
+ for atom_name1, atom_name2 in swap.items():
+ restype = rc.restype_order[rc.restype_3to1[resname]]
+ atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
+ atom_name1
+ )
+ atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
+ atom_name2
+ )
+ restype_atom14_is_ambiguous[restype, atom_idx1] = 1
+ restype_atom14_is_ambiguous[restype, atom_idx2] = 1
+
+ # From this create an ambiguous_mask for the given sequence.
+ protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
+ protein["aatype"]
+ ]
+
+ return protein
+
+
+def atom37_to_frames(protein, eps=1e-8):
+ aatype = protein["aatype"]
+ all_atom_positions = protein["all_atom_positions"]
+ all_atom_mask = protein["all_atom_mask"]
+
+ batch_dims = len(aatype.shape[:-1])
+
+ restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
+ restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
+ restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]
+
+ for restype, restype_letter in enumerate(rc.restypes):
+ resname = rc.restype_1to3[restype_letter]
+ for chi_idx in range(4):
+ if rc.chi_angles_mask[restype][chi_idx]:
+ names = rc.chi_angles_atoms[resname][chi_idx]
+ restype_rigidgroup_base_atom_names[
+ restype, chi_idx + 4, :
+ ] = names[1:]
+
+ restype_rigidgroup_mask = all_atom_mask.new_zeros(
+ (*aatype.shape[:-1], 21, 8),
+ )
+ restype_rigidgroup_mask[..., 0] = 1
+ restype_rigidgroup_mask[..., 3] = 1
+ restype_rigidgroup_mask[..., :20, 4:] = all_atom_mask.new_tensor(
+ rc.chi_angles_mask
+ )
+
+ lookuptable = rc.atom_order.copy()
+ lookuptable[""] = 0
+ lookup = np.vectorize(lambda x: lookuptable[x])
+ restype_rigidgroup_base_atom37_idx = lookup(
+ restype_rigidgroup_base_atom_names,
+ )
+ restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
+ restype_rigidgroup_base_atom37_idx,
+ )
+ restype_rigidgroup_base_atom37_idx = (
+ restype_rigidgroup_base_atom37_idx.view(
+ *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
+ )
+ )
+
+ residx_rigidgroup_base_atom37_idx = batched_gather(
+ restype_rigidgroup_base_atom37_idx,
+ aatype,
+ dim=-3,
+ no_batch_dims=batch_dims,
+ )
+
+ base_atom_pos = batched_gather(
+ all_atom_positions,
+ residx_rigidgroup_base_atom37_idx,
+ dim=-2,
+ no_batch_dims=len(all_atom_positions.shape[:-2]),
+ )
+
+ gt_frames = Rigid.from_3_points(
+ p_neg_x_axis=base_atom_pos[..., 0, :],
+ origin=base_atom_pos[..., 1, :],
+ p_xy_plane=base_atom_pos[..., 2, :],
+ eps=eps,
+ )
+
+ group_exists = batched_gather(
+ restype_rigidgroup_mask,
+ aatype,
+ dim=-2,
+ no_batch_dims=batch_dims,
+ )
+
+ gt_atoms_exist = batched_gather(
+ all_atom_mask,
+ residx_rigidgroup_base_atom37_idx,
+ dim=-1,
+ no_batch_dims=len(all_atom_mask.shape[:-1]),
+ )
+ gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
+
+ rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
+ rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
+ rots[..., 0, 0, 0] = -1
+ rots[..., 0, 2, 2] = -1
+ rots = Rotation(rot_mats=rots)
+
+ gt_frames = gt_frames.compose(Rigid(rots, None))
+
+ restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
+ *((1,) * batch_dims), 21, 8
+ )
+ restype_rigidgroup_rots = torch.eye(
+ 3, dtype=all_atom_mask.dtype, device=aatype.device
+ )
+ restype_rigidgroup_rots = torch.tile(
+ restype_rigidgroup_rots,
+ (*((1,) * batch_dims), 21, 8, 1, 1),
+ )
+
+ for resname, _ in rc.residue_atom_renaming_swaps.items():
+ restype = rc.restype_order[rc.restype_3to1[resname]]
+ chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
+ restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
+ restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
+ restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1
+
+ residx_rigidgroup_is_ambiguous = batched_gather(
+ restype_rigidgroup_is_ambiguous,
+ aatype,
+ dim=-2,
+ no_batch_dims=batch_dims,
+ )
+
+ residx_rigidgroup_ambiguity_rot = batched_gather(
+ restype_rigidgroup_rots,
+ aatype,
+ dim=-4,
+ no_batch_dims=batch_dims,
+ )
+
+ residx_rigidgroup_ambiguity_rot = Rotation(
+ rot_mats=residx_rigidgroup_ambiguity_rot
+ )
+ alt_gt_frames = gt_frames.compose(
+ Rigid(residx_rigidgroup_ambiguity_rot, None)
+ )
+
+ gt_frames_tensor = gt_frames.to_tensor_4x4()
+ alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
+
+ protein["rigidgroups_gt_frames"] = gt_frames_tensor
+ protein["rigidgroups_gt_exists"] = gt_exists
+ protein["rigidgroups_group_exists"] = group_exists
+ protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
+ protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor
+
+ return protein
+
+
+def get_chi_atom_indices():
+ """Returns atom indices needed to compute chi angles for all residue types.
+
+ Returns:
+ A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
+ in the order specified in rc.restypes + unknown residue type
+ at the end. For chi angles which are not defined on the residue, the
+ positions indices are by default set to 0.
+ """
+ chi_atom_indices = []
+ for residue_name in rc.restypes:
+ residue_name = rc.restype_1to3[residue_name]
+ residue_chi_angles = rc.chi_angles_atoms[residue_name]
+ atom_indices = []
+ for chi_angle in residue_chi_angles:
+ atom_indices.append([rc.atom_order[atom] for atom in chi_angle])
+ for _ in range(4 - len(atom_indices)):
+ atom_indices.append(
+ [0, 0, 0, 0]
+ ) # For chi angles not defined on the AA.
+ chi_atom_indices.append(atom_indices)
+
+ chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
+
+ return chi_atom_indices
+
+
+@curry1
+def atom37_to_torsion_angles(
+ protein,
+ prefix="",
+):
+ """
+ Convert coordinates to torsion angles.
+
+ This function is extremely sensitive to floating point imprecisions
+ and should be run with double precision whenever possible.
+
+ Args:
+ Dict containing:
+ * (prefix)aatype:
+ [*, N_res] residue indices
+ * (prefix)all_atom_positions:
+ [*, N_res, 37, 3] atom positions (in atom37
+ format)
+ * (prefix)all_atom_mask:
+ [*, N_res, 37] atom position mask
+ Returns:
+ The same dictionary updated with the following features:
+
+ "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
+ Torsion angles
+ "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
+ Alternate torsion angles (accounting for 180-degree symmetry)
+ "(prefix)torsion_angles_mask" ([*, N_res, 7])
+ Torsion angles mask
+ """
+ aatype = protein[prefix + "aatype"]
+ all_atom_positions = protein[prefix + "all_atom_positions"]
+ all_atom_mask = protein[prefix + "all_atom_mask"]
+
+ aatype = torch.clamp(aatype, max=20)
+
+ pad = all_atom_positions.new_zeros(
+ [*all_atom_positions.shape[:-3], 1, 37, 3]
+ )
+ prev_all_atom_positions = torch.cat(
+ [pad, all_atom_positions[..., :-1, :, :]], dim=-3
+ )
+
+ pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
+ prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
+
+ pre_omega_atom_pos = torch.cat(
+ [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
+ dim=-2,
+ )
+ phi_atom_pos = torch.cat(
+ [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
+ dim=-2,
+ )
+ psi_atom_pos = torch.cat(
+ [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
+ dim=-2,
+ )
+
+ pre_omega_mask = torch.prod(
+ prev_all_atom_mask[..., 1:3], dim=-1
+ ) * torch.prod(all_atom_mask[..., :2], dim=-1)
+ phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
+ all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
+ )
+ psi_mask = (
+ torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
+ * all_atom_mask[..., 4]
+ )
+
+ chi_atom_indices = torch.as_tensor(
+ get_chi_atom_indices(), device=aatype.device
+ )
+
+ atom_indices = chi_atom_indices[..., aatype, :, :]
+ chis_atom_pos = batched_gather(
+ all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
+ )
+
+ chi_angles_mask = list(rc.chi_angles_mask)
+ chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
+ chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)
+
+ chis_mask = chi_angles_mask[aatype, :]
+
+ chi_angle_atoms_mask = batched_gather(
+ all_atom_mask,
+ atom_indices,
+ dim=-1,
+ no_batch_dims=len(atom_indices.shape[:-2]),
+ )
+ chi_angle_atoms_mask = torch.prod(
+ chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
+ )
+ chis_mask = chis_mask * chi_angle_atoms_mask
+
+ torsions_atom_pos = torch.cat(
+ [
+ pre_omega_atom_pos[..., None, :, :],
+ phi_atom_pos[..., None, :, :],
+ psi_atom_pos[..., None, :, :],
+ chis_atom_pos,
+ ],
+ dim=-3,
+ )
+
+ torsion_angles_mask = torch.cat(
+ [
+ pre_omega_mask[..., None],
+ phi_mask[..., None],
+ psi_mask[..., None],
+ chis_mask,
+ ],
+ dim=-1,
+ )
+
+ torsion_frames = Rigid.from_3_points(
+ torsions_atom_pos[..., 1, :],
+ torsions_atom_pos[..., 2, :],
+ torsions_atom_pos[..., 0, :],
+ eps=1e-8,
+ )
+
+ fourth_atom_rel_pos = torsion_frames.invert().apply(
+ torsions_atom_pos[..., 3, :]
+ )
+
+ torsion_angles_sin_cos = torch.stack(
+ [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
+ )
+
+ denom = torch.sqrt(
+ torch.sum(
+ torch.square(torsion_angles_sin_cos),
+ dim=-1,
+ dtype=torsion_angles_sin_cos.dtype,
+ keepdims=True,
+ )
+ + 1e-8
+ )
+ torsion_angles_sin_cos = torsion_angles_sin_cos / denom
+
+ torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
+ [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
+ )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
+
+ chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
+ rc.chi_pi_periodic,
+ )[aatype, ...]
+
+ mirror_torsion_angles = torch.cat(
+ [
+ all_atom_mask.new_ones(*aatype.shape, 3),
+ 1.0 - 2.0 * chi_is_ambiguous,
+ ],
+ dim=-1,
+ )
+
+ alt_torsion_angles_sin_cos = (
+ torsion_angles_sin_cos * mirror_torsion_angles[..., None]
+ )
+
+ protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
+ protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
+ protein[prefix + "torsion_angles_mask"] = torsion_angles_mask
+
+ return protein
+
+
+def get_backbone_frames(protein):
+ # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
+ protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
+ ..., 0, :, :
+ ]
+ protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]
+
+ return protein
+
+
+def get_chi_angles(protein):
+ dtype = protein["all_atom_mask"].dtype
+ protein["chi_angles_sin_cos"] = (
+ protein["torsion_angles_sin_cos"][..., 3:, :]
+ ).to(dtype)
+ protein["chi_mask"] = protein["torsion_angles_mask"][..., 3:].to(dtype)
+
+ return protein
+
+
+@curry1
+def random_crop_to_size(
+ protein,
+ crop_size,
+ max_templates,
+ shape_schema,
+ subsample_templates=False,
+ seed=None,
+):
+ """Crop randomly to `crop_size`, or keep as is if shorter than that."""
+ # We want each ensemble to be cropped the same way
+ g = torch.Generator(device=protein["seq_length"].device)
+ if seed is not None:
+ g.manual_seed(seed)
+
+ seq_length = protein["seq_length"]
+
+ if "template_mask" in protein:
+ num_templates = protein["template_mask"].shape[-1]
+ else:
+ num_templates = 0
+
+ # No need to subsample templates if there aren't any
+ subsample_templates = subsample_templates and num_templates
+
+ num_res_crop_size = min(int(seq_length), crop_size)
+
+ def _randint(lower, upper):
+ return int(torch.randint(
+ lower,
+ upper + 1,
+ (1,),
+ device=protein["seq_length"].device,
+ generator=g,
+ )[0])
+
+ if subsample_templates:
+ templates_crop_start = _randint(0, num_templates)
+ templates_select_indices = torch.randperm(
+ num_templates, device=protein["seq_length"].device, generator=g
+ )
+ else:
+ templates_crop_start = 0
+
+ num_templates_crop_size = min(
+ num_templates - templates_crop_start, max_templates
+ )
+
+ n = seq_length - num_res_crop_size
+ if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
+ right_anchor = n
+ else:
+ x = _randint(0, n)
+ right_anchor = n - x
+
+ num_res_crop_start = _randint(0, right_anchor)
+
+ for k, v in protein.items():
+ if k not in shape_schema or (
+ "template" not in k and NUM_RES not in shape_schema[k]
+ ):
+ continue
+
+ # randomly permute the templates before cropping them.
+ if k.startswith("template") and subsample_templates:
+ v = v[templates_select_indices]
+
+ slices = []
+ for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
+ is_num_res = dim_size == NUM_RES
+ if i == 0 and k.startswith("template"):
+ crop_size = num_templates_crop_size
+ crop_start = templates_crop_start
+ else:
+ crop_start = num_res_crop_start if is_num_res else 0
+ crop_size = num_res_crop_size if is_num_res else dim
+ slices.append(slice(crop_start, crop_start + crop_size))
+ protein[k] = v[slices]
+
+ protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
+
+ return protein
diff --git a/openfold/data/errors.py b/openfold/data/errors.py
new file mode 100644
index 0000000000000000000000000000000000000000..7809d5ae29a736ccc7688d555a06a5301a3ab46f
--- /dev/null
+++ b/openfold/data/errors.py
@@ -0,0 +1,22 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""General-purpose errors used throughout the data pipeline"""
+class Error(Exception):
+ """Base class for exceptions."""
+
+
+class MultipleChainsError(Error):
+ """An error indicating that multiple chains were found for a given ID."""
diff --git a/openfold/data/feature_pipeline.py b/openfold/data/feature_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..77b9a9f2570d01307b724ca73cf3b00016526383
--- /dev/null
+++ b/openfold/data/feature_pipeline.py
@@ -0,0 +1,115 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from typing import Mapping, Tuple, List, Optional, Dict, Sequence
+
+import ml_collections
+import numpy as np
+import torch
+
+from openfold.data import input_pipeline
+
+
+FeatureDict = Mapping[str, np.ndarray]
+TensorDict = Dict[str, torch.Tensor]
+
+
+def np_to_tensor_dict(
+ np_example: Mapping[str, np.ndarray],
+ features: Sequence[str],
+) -> TensorDict:
+ """Creates dict of tensors from a dict of NumPy arrays.
+
+ Args:
+ np_example: A dict of NumPy feature arrays.
+ features: A list of strings of feature names to be returned in the dataset.
+
+ Returns:
+ A dictionary of features mapping feature names to features. Only the given
+ features are returned, all other ones are filtered out.
+ """
+ tensor_dict = {
+ k: torch.tensor(v) for k, v in np_example.items() if k in features
+ }
+ return tensor_dict
+
+
+def make_data_config(
+ config: ml_collections.ConfigDict,
+ mode: str,
+ num_res: int,
+) -> Tuple[ml_collections.ConfigDict, List[str]]:
+ cfg = copy.deepcopy(config)
+ mode_cfg = cfg[mode]
+ with cfg.unlocked():
+ if mode_cfg.crop_size is None:
+ mode_cfg.crop_size = num_res
+
+ feature_names = cfg.common.unsupervised_features
+
+ if cfg.common.use_templates:
+ feature_names += cfg.common.template_features
+
+ if cfg[mode].supervised:
+ feature_names += cfg.supervised.supervised_features
+
+ return cfg, feature_names
+
+
+def np_example_to_features(
+ np_example: FeatureDict,
+ config: ml_collections.ConfigDict,
+ mode: str,
+):
+ np_example = dict(np_example)
+ num_res = int(np_example["seq_length"][0])
+ cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
+
+ if "deletion_matrix_int" in np_example:
+ np_example["deletion_matrix"] = np_example.pop(
+ "deletion_matrix_int"
+ ).astype(np.float32)
+
+ tensor_dict = np_to_tensor_dict(
+ np_example=np_example, features=feature_names
+ )
+ with torch.no_grad():
+ features = input_pipeline.process_tensors_from_config(
+ tensor_dict,
+ cfg.common,
+ cfg[mode],
+ )
+
+ return {k: v for k, v in features.items()}
+
+
+class FeaturePipeline:
+ def __init__(
+ self,
+ config: ml_collections.ConfigDict,
+ ):
+ self.config = config
+
+ def process_features(
+ self,
+ raw_features: FeatureDict,
+ mode: str = "train",
+ ) -> FeatureDict:
+ return np_example_to_features(
+ np_example=raw_features,
+ config=self.config,
+ mode=mode,
+ )
diff --git a/openfold/data/input_pipeline.py b/openfold/data/input_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..b706ec3df523580d25d1603fc7f653c597be2902
--- /dev/null
+++ b/openfold/data/input_pipeline.py
@@ -0,0 +1,208 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+import torch
+
+from openfold.data import data_transforms
+
+
+def nonensembled_transform_fns(common_cfg, mode_cfg):
+ """Input pipeline data transformers that are not ensembled."""
+ transforms = [
+ data_transforms.cast_to_64bit_ints,
+ data_transforms.correct_msa_restypes,
+ data_transforms.squeeze_features,
+ data_transforms.randomly_replace_msa_with_unknown(0.0),
+ data_transforms.make_seq_mask,
+ data_transforms.make_msa_mask,
+ data_transforms.make_hhblits_profile,
+ ]
+ if common_cfg.use_templates:
+ transforms.extend(
+ [
+ data_transforms.fix_templates_aatype,
+ data_transforms.make_template_mask,
+ data_transforms.make_pseudo_beta("template_"),
+ ]
+ )
+ if common_cfg.use_template_torsion_angles:
+ transforms.extend(
+ [
+ data_transforms.atom37_to_torsion_angles("template_"),
+ ]
+ )
+
+ transforms.extend(
+ [
+ data_transforms.make_atom14_masks,
+ ]
+ )
+
+ if mode_cfg.supervised:
+ transforms.extend(
+ [
+ data_transforms.make_atom14_positions,
+ data_transforms.atom37_to_frames,
+ data_transforms.atom37_to_torsion_angles(""),
+ data_transforms.make_pseudo_beta(""),
+ data_transforms.get_backbone_frames,
+ data_transforms.get_chi_angles,
+ ]
+ )
+
+ return transforms
+
+
+def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
+ """Input pipeline data transformers that can be ensembled and averaged."""
+ transforms = []
+
+ if "max_distillation_msa_clusters" in mode_cfg:
+ transforms.append(
+ data_transforms.sample_msa_distillation(
+ mode_cfg.max_distillation_msa_clusters
+ )
+ )
+
+ if common_cfg.reduce_msa_clusters_by_max_templates:
+ pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
+ else:
+ pad_msa_clusters = mode_cfg.max_msa_clusters
+
+ max_msa_clusters = pad_msa_clusters
+ max_extra_msa = common_cfg.max_extra_msa
+
+ msa_seed = None
+ if(not common_cfg.resample_msa_in_recycling):
+ msa_seed = ensemble_seed
+
+ transforms.append(
+ data_transforms.sample_msa(
+ max_msa_clusters,
+ keep_extra=True,
+ seed=msa_seed,
+ )
+ )
+
+ if "masked_msa" in common_cfg:
+ # Masked MSA should come *before* MSA clustering so that
+ # the clustering and full MSA profile do not leak information about
+ # the masked locations and secret corrupted locations.
+ transforms.append(
+ data_transforms.make_masked_msa(
+ common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
+ )
+ )
+
+ if common_cfg.msa_cluster_features:
+ transforms.append(data_transforms.nearest_neighbor_clusters())
+ transforms.append(data_transforms.summarize_clusters())
+
+ # Crop after creating the cluster profiles.
+ if max_extra_msa:
+ transforms.append(data_transforms.crop_extra_msa(max_extra_msa))
+ else:
+ transforms.append(data_transforms.delete_extra_msa)
+
+ transforms.append(data_transforms.make_msa_feat())
+
+ crop_feats = dict(common_cfg.feat)
+
+ if mode_cfg.fixed_size:
+ transforms.append(data_transforms.select_feat(list(crop_feats)))
+ transforms.append(
+ data_transforms.random_crop_to_size(
+ mode_cfg.crop_size,
+ mode_cfg.max_templates,
+ crop_feats,
+ mode_cfg.subsample_templates,
+ seed=ensemble_seed + 1,
+ )
+ )
+ transforms.append(
+ data_transforms.make_fixed_size(
+ crop_feats,
+ pad_msa_clusters,
+ common_cfg.max_extra_msa,
+ mode_cfg.crop_size,
+ mode_cfg.max_templates,
+ )
+ )
+ else:
+ transforms.append(
+ data_transforms.crop_templates(mode_cfg.max_templates)
+ )
+
+ return transforms
+
+
+def process_tensors_from_config(tensors, common_cfg, mode_cfg):
+ """Based on the config, apply filters and transformations to the data."""
+
+ ensemble_seed = torch.Generator().seed()
+
+ def wrap_ensemble_fn(data, i):
+ """Function to be mapped over the ensemble dimension."""
+ d = data.copy()
+ fns = ensembled_transform_fns(
+ common_cfg,
+ mode_cfg,
+ ensemble_seed,
+ )
+ fn = compose(fns)
+ d["ensemble_index"] = i
+ return fn(d)
+
+ no_templates = True
+ if("template_aatype" in tensors):
+ no_templates = tensors["template_aatype"].shape[0] == 0
+
+ nonensembled = nonensembled_transform_fns(
+ common_cfg,
+ mode_cfg,
+ )
+
+ tensors = compose(nonensembled)(tensors)
+
+ if("no_recycling_iters" in tensors):
+ num_recycling = int(tensors["no_recycling_iters"])
+ else:
+ num_recycling = common_cfg.max_recycling_iters
+
+ tensors = map_fn(
+ lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
+ )
+
+ return tensors
+
+
+@data_transforms.curry1
+def compose(x, fs):
+ for f in fs:
+ x = f(x)
+ return x
+
+
+def map_fn(fun, x):
+ ensembles = [fun(elem) for elem in x]
+ features = ensembles[0].keys()
+ ensembled_dict = {}
+ for feat in features:
+ ensembled_dict[feat] = torch.stack(
+ [dict_i[feat] for dict_i in ensembles], dim=-1
+ )
+ return ensembled_dict
diff --git a/openfold/data/mmcif_parsing.py b/openfold/data/mmcif_parsing.py
new file mode 100644
index 0000000000000000000000000000000000000000..c95047e1f9ab0cf1549b51d9f01da1c5e26080f1
--- /dev/null
+++ b/openfold/data/mmcif_parsing.py
@@ -0,0 +1,485 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Parses the mmCIF file format."""
+import collections
+import dataclasses
+import io
+import json
+import logging
+import os
+from typing import Any, Mapping, Optional, Sequence, Tuple
+
+from Bio import PDB
+from Bio.Data import SCOPData
+import numpy as np
+
+from openfold.data.errors import MultipleChainsError
+import openfold.np.residue_constants as residue_constants
+
+
+# Type aliases:
+ChainId = str
+PdbHeader = Mapping[str, Any]
+PdbStructure = PDB.Structure.Structure
+SeqRes = str
+MmCIFDict = Mapping[str, Sequence[str]]
+
+
+@dataclasses.dataclass(frozen=True)
+class Monomer:
+ id: str
+ num: int
+
+
+# Note - mmCIF format provides no guarantees on the type of author-assigned
+# sequence numbers. They need not be integers.
+@dataclasses.dataclass(frozen=True)
+class AtomSite:
+ residue_name: str
+ author_chain_id: str
+ mmcif_chain_id: str
+ author_seq_num: str
+ mmcif_seq_num: int
+ insertion_code: str
+ hetatm_atom: str
+ model_num: int
+
+
+# Used to map SEQRES index to a residue in the structure.
+@dataclasses.dataclass(frozen=True)
+class ResiduePosition:
+ chain_id: str
+ residue_number: int
+ insertion_code: str
+
+
+@dataclasses.dataclass(frozen=True)
+class ResidueAtPosition:
+ position: Optional[ResiduePosition]
+ name: str
+ is_missing: bool
+ hetflag: str
+
+
+@dataclasses.dataclass(frozen=True)
+class MmcifObject:
+ """Representation of a parsed mmCIF file.
+
+ Contains:
+ file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
+ files being processed.
+ header: Biopython header.
+ structure: Biopython structure.
+ chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
+ {'A': 'ABCDEFG'}
+ seqres_to_structure: Dict; for each chain_id contains a mapping between
+ SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
+ 1: ResidueAtPosition,
+ ...}}
+ raw_string: The raw string used to construct the MmcifObject.
+ """
+
+ file_id: str
+ header: PdbHeader
+ structure: PdbStructure
+ chain_to_seqres: Mapping[ChainId, SeqRes]
+ seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
+ raw_string: Any
+
+
+@dataclasses.dataclass(frozen=True)
+class ParsingResult:
+ """Returned by the parse function.
+
+ Contains:
+ mmcif_object: A MmcifObject, may be None if no chain could be successfully
+ parsed.
+ errors: A dict mapping (file_id, chain_id) to any exception generated.
+ """
+
+ mmcif_object: Optional[MmcifObject]
+ errors: Mapping[Tuple[str, str], Any]
+
+
+class ParseError(Exception):
+ """An error indicating that an mmCIF file could not be parsed."""
+
+
+def mmcif_loop_to_list(
+ prefix: str, parsed_info: MmCIFDict
+) -> Sequence[Mapping[str, str]]:
+ """Extracts loop associated with a prefix from mmCIF data as a list.
+
+ Reference for loop_ in mmCIF:
+ http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html
+
+ Args:
+ prefix: Prefix shared by each of the data items in the loop.
+ e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
+ _entity_poly_seq.mon_id. Should include the trailing period.
+ parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
+ parser.
+
+ Returns:
+ Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
+ """
+ cols = []
+ data = []
+ for key, value in parsed_info.items():
+ if key.startswith(prefix):
+ cols.append(key)
+ data.append(value)
+
+ assert all([len(xs) == len(data[0]) for xs in data]), (
+ "mmCIF error: Not all loops are the same length: %s" % cols
+ )
+
+ return [dict(zip(cols, xs)) for xs in zip(*data)]
+
+
+def mmcif_loop_to_dict(
+ prefix: str,
+ index: str,
+ parsed_info: MmCIFDict,
+) -> Mapping[str, Mapping[str, str]]:
+ """Extracts loop associated with a prefix from mmCIF data as a dictionary.
+
+ Args:
+ prefix: Prefix shared by each of the data items in the loop.
+ e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
+ _entity_poly_seq.mon_id. Should include the trailing period.
+ index: Which item of loop data should serve as the key.
+ parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
+ parser.
+
+ Returns:
+ Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop,
+ indexed by the index column.
+ """
+ entries = mmcif_loop_to_list(prefix, parsed_info)
+ return {entry[index]: entry for entry in entries}
+
+
+def parse(
+ *, file_id: str, mmcif_string: str, catch_all_errors: bool = True
+) -> ParsingResult:
+ """Entry point, parses an mmcif_string.
+
+ Args:
+ file_id: A string identifier for this file. Should be unique within the
+ collection of files being processed.
+ mmcif_string: Contents of an mmCIF file.
+ catch_all_errors: If True, all exceptions are caught and error messages are
+ returned as part of the ParsingResult. If False exceptions will be allowed
+ to propagate.
+
+ Returns:
+ A ParsingResult.
+ """
+ errors = {}
+ try:
+ parser = PDB.MMCIFParser(QUIET=True)
+ handle = io.StringIO(mmcif_string)
+ full_structure = parser.get_structure("", handle)
+ first_model_structure = _get_first_model(full_structure)
+ # Extract the _mmcif_dict from the parser, which contains useful fields not
+ # reflected in the Biopython structure.
+ parsed_info = parser._mmcif_dict # pylint:disable=protected-access
+
+ # Ensure all values are lists, even if singletons.
+ for key, value in parsed_info.items():
+ if not isinstance(value, list):
+ parsed_info[key] = [value]
+
+ header = _get_header(parsed_info)
+
+ # Determine the protein chains, and their start numbers according to the
+ # internal mmCIF numbering scheme (likely but not guaranteed to be 1).
+ valid_chains = _get_protein_chains(parsed_info=parsed_info)
+ if not valid_chains:
+ return ParsingResult(
+ None, {(file_id, ""): "No protein chains found in this file."}
+ )
+ seq_start_num = {
+ chain_id: min([monomer.num for monomer in seq])
+ for chain_id, seq in valid_chains.items()
+ }
+
+ # Loop over the atoms for which we have coordinates. Populate two mappings:
+ # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
+ # the authors / Biopython).
+ # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
+ mmcif_to_author_chain_id = {}
+ seq_to_structure_mappings = {}
+ for atom in _get_atom_site_list(parsed_info):
+ if atom.model_num != "1":
+ # We only process the first model at the moment.
+ continue
+
+ mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id
+
+ if atom.mmcif_chain_id in valid_chains:
+ hetflag = " "
+ if atom.hetatm_atom == "HETATM":
+ # Water atoms are assigned a special hetflag of W in Biopython. We
+ # need to do the same, so that this hetflag can be used to fetch
+ # a residue from the Biopython structure by id.
+ if atom.residue_name in ("HOH", "WAT"):
+ hetflag = "W"
+ else:
+ hetflag = "H_" + atom.residue_name
+ insertion_code = atom.insertion_code
+ if not _is_set(atom.insertion_code):
+ insertion_code = " "
+ position = ResiduePosition(
+ chain_id=atom.author_chain_id,
+ residue_number=int(atom.author_seq_num),
+ insertion_code=insertion_code,
+ )
+ seq_idx = (
+ int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
+ )
+ current = seq_to_structure_mappings.get(
+ atom.author_chain_id, {}
+ )
+ current[seq_idx] = ResidueAtPosition(
+ position=position,
+ name=atom.residue_name,
+ is_missing=False,
+ hetflag=hetflag,
+ )
+ seq_to_structure_mappings[atom.author_chain_id] = current
+
+ # Add missing residue information to seq_to_structure_mappings.
+ for chain_id, seq_info in valid_chains.items():
+ author_chain = mmcif_to_author_chain_id[chain_id]
+ current_mapping = seq_to_structure_mappings[author_chain]
+ for idx, monomer in enumerate(seq_info):
+ if idx not in current_mapping:
+ current_mapping[idx] = ResidueAtPosition(
+ position=None,
+ name=monomer.id,
+ is_missing=True,
+ hetflag=" ",
+ )
+
+ author_chain_to_sequence = {}
+ for chain_id, seq_info in valid_chains.items():
+ author_chain = mmcif_to_author_chain_id[chain_id]
+ seq = []
+ for monomer in seq_info:
+ code = SCOPData.protein_letters_3to1.get(monomer.id, "X")
+ seq.append(code if len(code) == 1 else "X")
+ seq = "".join(seq)
+ author_chain_to_sequence[author_chain] = seq
+
+ mmcif_object = MmcifObject(
+ file_id=file_id,
+ header=header,
+ structure=first_model_structure,
+ chain_to_seqres=author_chain_to_sequence,
+ seqres_to_structure=seq_to_structure_mappings,
+ raw_string=parsed_info,
+ )
+
+ return ParsingResult(mmcif_object=mmcif_object, errors=errors)
+ except Exception as e: # pylint:disable=broad-except
+ errors[(file_id, "")] = e
+ if not catch_all_errors:
+ raise
+ return ParsingResult(mmcif_object=None, errors=errors)
+
+
+def _get_first_model(structure: PdbStructure) -> PdbStructure:
+ """Returns the first model in a Biopython structure."""
+ return next(structure.get_models())
+
+
+_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21
+
+
+def get_release_date(parsed_info: MmCIFDict) -> str:
+ """Returns the oldest revision date."""
+ revision_dates = parsed_info["_pdbx_audit_revision_history.revision_date"]
+ return min(revision_dates)
+
+
+def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
+ """Returns a basic header containing method, release date and resolution."""
+ header = {}
+
+ experiments = mmcif_loop_to_list("_exptl.", parsed_info)
+ header["structure_method"] = ",".join(
+ [experiment["_exptl.method"].lower() for experiment in experiments]
+ )
+
+ # Note: The release_date here corresponds to the oldest revision. We prefer to
+ # use this for dataset filtering over the deposition_date.
+ if "_pdbx_audit_revision_history.revision_date" in parsed_info:
+ header["release_date"] = get_release_date(parsed_info)
+ else:
+ logging.warning(
+ "Could not determine release_date: %s", parsed_info["_entry.id"]
+ )
+
+ header["resolution"] = 0.00
+ for res_key in (
+ "_refine.ls_d_res_high",
+ "_em_3d_reconstruction.resolution",
+ "_reflns.d_resolution_high",
+ ):
+ if res_key in parsed_info:
+ try:
+ raw_resolution = parsed_info[res_key][0]
+ header["resolution"] = float(raw_resolution)
+ except ValueError:
+ logging.info(
+ "Invalid resolution format: %s", parsed_info[res_key]
+ )
+
+ return header
+
+
+def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
+ """Returns list of atom sites; contains data not present in the structure."""
+ return [
+ AtomSite(*site)
+ for site in zip( # pylint:disable=g-complex-comprehension
+ parsed_info["_atom_site.label_comp_id"],
+ parsed_info["_atom_site.auth_asym_id"],
+ parsed_info["_atom_site.label_asym_id"],
+ parsed_info["_atom_site.auth_seq_id"],
+ parsed_info["_atom_site.label_seq_id"],
+ parsed_info["_atom_site.pdbx_PDB_ins_code"],
+ parsed_info["_atom_site.group_PDB"],
+ parsed_info["_atom_site.pdbx_PDB_model_num"],
+ )
+ ]
+
+
+def _get_protein_chains(
+ *, parsed_info: Mapping[str, Any]
+) -> Mapping[ChainId, Sequence[Monomer]]:
+ """Extracts polymer information for protein chains only.
+
+ Args:
+ parsed_info: _mmcif_dict produced by the Biopython parser.
+
+ Returns:
+ A dict mapping mmcif chain id to a list of Monomers.
+ """
+ # Get polymer information for each entity in the structure.
+ entity_poly_seqs = mmcif_loop_to_list("_entity_poly_seq.", parsed_info)
+
+ polymers = collections.defaultdict(list)
+ for entity_poly_seq in entity_poly_seqs:
+ polymers[entity_poly_seq["_entity_poly_seq.entity_id"]].append(
+ Monomer(
+ id=entity_poly_seq["_entity_poly_seq.mon_id"],
+ num=int(entity_poly_seq["_entity_poly_seq.num"]),
+ )
+ )
+
+ # Get chemical compositions. Will allow us to identify which of these polymers
+ # are proteins.
+ chem_comps = mmcif_loop_to_dict("_chem_comp.", "_chem_comp.id", parsed_info)
+
+ # Get chains information for each entity. Necessary so that we can return a
+ # dict keyed on chain id rather than entity.
+ struct_asyms = mmcif_loop_to_list("_struct_asym.", parsed_info)
+
+ entity_to_mmcif_chains = collections.defaultdict(list)
+ for struct_asym in struct_asyms:
+ chain_id = struct_asym["_struct_asym.id"]
+ entity_id = struct_asym["_struct_asym.entity_id"]
+ entity_to_mmcif_chains[entity_id].append(chain_id)
+
+ # Identify and return the valid protein chains.
+ valid_chains = {}
+ for entity_id, seq_info in polymers.items():
+ chain_ids = entity_to_mmcif_chains[entity_id]
+
+ # Reject polymers without any peptide-like components, such as DNA/RNA.
+ if any(
+ [
+ "peptide" in chem_comps[monomer.id]["_chem_comp.type"]
+ for monomer in seq_info
+ ]
+ ):
+ for chain_id in chain_ids:
+ valid_chains[chain_id] = seq_info
+ return valid_chains
+
+
+def _is_set(data: str) -> bool:
+ """Returns False if data is a special mmCIF character indicating 'unset'."""
+ return data not in (".", "?")
+
+
+def get_atom_coords(
+ mmcif_object: MmcifObject,
+ chain_id: str,
+ _zero_center_positions: bool = True
+) -> Tuple[np.ndarray, np.ndarray]:
+ # Locate the right chain
+ chains = list(mmcif_object.structure.get_chains())
+ relevant_chains = [c for c in chains if c.id == chain_id]
+ if len(relevant_chains) != 1:
+ raise MultipleChainsError(
+ f"Expected exactly one chain in structure with id {chain_id}."
+ )
+ chain = relevant_chains[0]
+
+ # Extract the coordinates
+ num_res = len(mmcif_object.chain_to_seqres[chain_id])
+ all_atom_positions = np.zeros(
+ [num_res, residue_constants.atom_type_num, 3], dtype=np.float32
+ )
+ all_atom_mask = np.zeros(
+ [num_res, residue_constants.atom_type_num], dtype=np.float32
+ )
+ for res_index in range(num_res):
+ pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
+ mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
+ res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
+ if not res_at_position.is_missing:
+ res = chain[
+ (
+ res_at_position.hetflag,
+ res_at_position.position.residue_number,
+ res_at_position.position.insertion_code,
+ )
+ ]
+ for atom in res.get_atoms():
+ atom_name = atom.get_name()
+ x, y, z = atom.get_coord()
+ if atom_name in residue_constants.atom_order.keys():
+ pos[residue_constants.atom_order[atom_name]] = [x, y, z]
+ mask[residue_constants.atom_order[atom_name]] = 1.0
+ elif atom_name.upper() == "SE" and res.get_resname() == "MSE":
+ # Put the coords of the selenium atom in the sulphur column
+ pos[residue_constants.atom_order["SD"]] = [x, y, z]
+ mask[residue_constants.atom_order["SD"]] = 1.0
+
+ all_atom_positions[res_index] = pos
+ all_atom_mask[res_index] = mask
+
+ if _zero_center_positions:
+ binary_mask = all_atom_mask.astype(bool)
+ translation_vec = all_atom_positions[binary_mask].mean(axis=0)
+ all_atom_positions[binary_mask] -= translation_vec
+
+ return all_atom_positions, all_atom_mask
diff --git a/openfold/data/parsers.py b/openfold/data/parsers.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f0574c6e86f5c13d8235b16a61cefa365f0ab5b
--- /dev/null
+++ b/openfold/data/parsers.py
@@ -0,0 +1,388 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions for parsing various file formats."""
+import collections
+import dataclasses
+import re
+import string
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+
+
+DeletionMatrix = Sequence[Sequence[int]]
+
+
+@dataclasses.dataclass(frozen=True)
+class TemplateHit:
+ """Class representing a template hit."""
+
+ index: int
+ name: str
+ aligned_cols: int
+ sum_probs: float
+ query: str
+ hit_sequence: str
+ indices_query: List[int]
+ indices_hit: List[int]
+
+
+def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
+ """Parses FASTA string and returns list of strings with amino-acid sequences.
+
+ Arguments:
+ fasta_string: The string contents of a FASTA file.
+
+ Returns:
+ A tuple of two lists:
+ * A list of sequences.
+ * A list of sequence descriptions taken from the comment lines. In the
+ same order as the sequences.
+ """
+ sequences = []
+ descriptions = []
+ index = -1
+ for line in fasta_string.splitlines():
+ line = line.strip()
+ if line.startswith(">"):
+ index += 1
+ descriptions.append(line[1:]) # Remove the '>' at the beginning.
+ sequences.append("")
+ continue
+ elif not line:
+ continue # Skip blank lines.
+ sequences[index] += line
+
+ return sequences, descriptions
+
+
+def parse_stockholm(
+ stockholm_string: str,
+) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
+ """Parses sequences and deletion matrix from stockholm format alignment.
+
+ Args:
+ stockholm_string: The string contents of a stockholm file. The first
+ sequence in the file should be the query sequence.
+
+ Returns:
+ A tuple of:
+ * A list of sequences that have been aligned to the query. These
+ might contain duplicates.
+ * The deletion matrix for the alignment as a list of lists. The element
+ at `deletion_matrix[i][j]` is the number of residues deleted from
+ the aligned sequence i at residue position j.
+ * The names of the targets matched, including the jackhmmer subsequence
+ suffix.
+ """
+ name_to_sequence = collections.OrderedDict()
+ for line in stockholm_string.splitlines():
+ line = line.strip()
+ if not line or line.startswith(("#", "//")):
+ continue
+ name, sequence = line.split()
+ if name not in name_to_sequence:
+ name_to_sequence[name] = ""
+ name_to_sequence[name] += sequence
+
+ msa = []
+ deletion_matrix = []
+
+ query = ""
+ keep_columns = []
+ for seq_index, sequence in enumerate(name_to_sequence.values()):
+ if seq_index == 0:
+ # Gather the columns with gaps from the query
+ query = sequence
+ keep_columns = [i for i, res in enumerate(query) if res != "-"]
+
+ # Remove the columns with gaps in the query from all sequences.
+ aligned_sequence = "".join([sequence[c] for c in keep_columns])
+
+ msa.append(aligned_sequence)
+
+ # Count the number of deletions w.r.t. query.
+ deletion_vec = []
+ deletion_count = 0
+ for seq_res, query_res in zip(sequence, query):
+ if seq_res != "-" or query_res != "-":
+ if query_res == "-":
+ deletion_count += 1
+ else:
+ deletion_vec.append(deletion_count)
+ deletion_count = 0
+ deletion_matrix.append(deletion_vec)
+
+ return msa, deletion_matrix, list(name_to_sequence.keys())
+
+
+def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
+ """Parses sequences and deletion matrix from a3m format alignment.
+
+ Args:
+ a3m_string: The string contents of a a3m file. The first sequence in the
+ file should be the query sequence.
+
+ Returns:
+ A tuple of:
+ * A list of sequences that have been aligned to the query. These
+ might contain duplicates.
+ * The deletion matrix for the alignment as a list of lists. The element
+ at `deletion_matrix[i][j]` is the number of residues deleted from
+ the aligned sequence i at residue position j.
+ """
+ sequences, _ = parse_fasta(a3m_string)
+ deletion_matrix = []
+ for msa_sequence in sequences:
+ deletion_vec = []
+ deletion_count = 0
+ for j in msa_sequence:
+ if j.islower():
+ deletion_count += 1
+ else:
+ deletion_vec.append(deletion_count)
+ deletion_count = 0
+ deletion_matrix.append(deletion_vec)
+
+ # Make the MSA matrix out of aligned (deletion-free) sequences.
+ deletion_table = str.maketrans("", "", string.ascii_lowercase)
+ aligned_sequences = [s.translate(deletion_table) for s in sequences]
+ return aligned_sequences, deletion_matrix
+
+
+def _convert_sto_seq_to_a3m(
+ query_non_gaps: Sequence[bool], sto_seq: str
+) -> Iterable[str]:
+ for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
+ if is_query_res_non_gap:
+ yield sequence_res
+ elif sequence_res != "-":
+ yield sequence_res.lower()
+
+
+def convert_stockholm_to_a3m(
+ stockholm_format: str, max_sequences: Optional[int] = None
+) -> str:
+ """Converts MSA in Stockholm format to the A3M format."""
+ descriptions = {}
+ sequences = {}
+ reached_max_sequences = False
+
+ for line in stockholm_format.splitlines():
+ reached_max_sequences = (
+ max_sequences and len(sequences) >= max_sequences
+ )
+ if line.strip() and not line.startswith(("#", "//")):
+ # Ignore blank lines, markup and end symbols - remainder are alignment
+ # sequence parts.
+ seqname, aligned_seq = line.split(maxsplit=1)
+ if seqname not in sequences:
+ if reached_max_sequences:
+ continue
+ sequences[seqname] = ""
+ sequences[seqname] += aligned_seq
+
+ for line in stockholm_format.splitlines():
+ if line[:4] == "#=GS":
+ # Description row - example format is:
+ # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
+ columns = line.split(maxsplit=3)
+ seqname, feature = columns[1:3]
+ value = columns[3] if len(columns) == 4 else ""
+ if feature != "DE":
+ continue
+ if reached_max_sequences and seqname not in sequences:
+ continue
+ descriptions[seqname] = value
+ if len(descriptions) == len(sequences):
+ break
+
+ # Convert sto format to a3m line by line
+ a3m_sequences = {}
+ # query_sequence is assumed to be the first sequence
+ query_sequence = next(iter(sequences.values()))
+ query_non_gaps = [res != "-" for res in query_sequence]
+ for seqname, sto_sequence in sequences.items():
+ a3m_sequences[seqname] = "".join(
+ _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
+ )
+
+ fasta_chunks = (
+ f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
+ for k in a3m_sequences
+ )
+ return "\n".join(fasta_chunks) + "\n" # Include terminating newline.
+
+
+def _get_hhr_line_regex_groups(
+ regex_pattern: str, line: str
+) -> Sequence[Optional[str]]:
+ match = re.match(regex_pattern, line)
+ if match is None:
+ raise RuntimeError(f"Could not parse query line {line}")
+ return match.groups()
+
+
+def _update_hhr_residue_indices_list(
+ sequence: str, start_index: int, indices_list: List[int]
+):
+ """Computes the relative indices for each residue with respect to the original sequence."""
+ counter = start_index
+ for symbol in sequence:
+ if symbol == "-":
+ indices_list.append(-1)
+ else:
+ indices_list.append(counter)
+ counter += 1
+
+
+def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
+ """Parses the detailed HMM HMM comparison section for a single Hit.
+
+ This works on .hhr files generated from both HHBlits and HHSearch.
+
+ Args:
+ detailed_lines: A list of lines from a single comparison section between 2
+ sequences (which each have their own HMM's)
+
+ Returns:
+ A dictionary with the information from that detailed comparison section
+
+ Raises:
+ RuntimeError: If a certain line cannot be processed
+ """
+ # Parse first 2 lines.
+ number_of_hit = int(detailed_lines[0].split()[-1])
+ name_hit = detailed_lines[1][1:]
+
+ # Parse the summary line.
+ pattern = (
+ "Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t"
+ " ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t "
+ "]*Template_Neff=(.*)"
+ )
+ match = re.match(pattern, detailed_lines[2])
+ if match is None:
+ raise RuntimeError(
+ "Could not parse section: %s. Expected this: \n%s to contain summary."
+ % (detailed_lines, detailed_lines[2])
+ )
+ (prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
+ float(x) for x in match.groups()
+ ]
+
+ # The next section reads the detailed comparisons. These are in a 'human
+ # readable' format which has a fixed length. The strategy employed is to
+ # assume that each block starts with the query sequence line, and to parse
+ # that with a regexp in order to deduce the fixed length used for that block.
+ query = ""
+ hit_sequence = ""
+ indices_query = []
+ indices_hit = []
+ length_block = None
+
+ for line in detailed_lines[3:]:
+ # Parse the query sequence line
+ if (
+ line.startswith("Q ")
+ and not line.startswith("Q ss_dssp")
+ and not line.startswith("Q ss_pred")
+ and not line.startswith("Q Consensus")
+ ):
+ # Thus the first 17 characters must be 'Q ', and we can parse
+ # everything after that.
+ # start sequence end total_sequence_length
+ patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)"
+ groups = _get_hhr_line_regex_groups(patt, line[17:])
+
+ # Get the length of the parsed block using the start and finish indices,
+ # and ensure it is the same as the actual block length.
+ start = int(groups[0]) - 1 # Make index zero based.
+ delta_query = groups[1]
+ end = int(groups[2])
+ num_insertions = len([x for x in delta_query if x == "-"])
+ length_block = end - start + num_insertions
+ assert length_block == len(delta_query)
+
+ # Update the query sequence and indices list.
+ query += delta_query
+ _update_hhr_residue_indices_list(delta_query, start, indices_query)
+
+ elif line.startswith("T "):
+ # Parse the hit sequence.
+ if (
+ not line.startswith("T ss_dssp")
+ and not line.startswith("T ss_pred")
+ and not line.startswith("T Consensus")
+ ):
+ # Thus the first 17 characters must be 'T ', and we can
+ # parse everything after that.
+ # start sequence end total_sequence_length
+ patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)"
+ groups = _get_hhr_line_regex_groups(patt, line[17:])
+ start = int(groups[0]) - 1 # Make index zero based.
+ delta_hit_sequence = groups[1]
+ assert length_block == len(delta_hit_sequence)
+
+ # Update the hit sequence and indices list.
+ hit_sequence += delta_hit_sequence
+ _update_hhr_residue_indices_list(
+ delta_hit_sequence, start, indices_hit
+ )
+
+ return TemplateHit(
+ index=number_of_hit,
+ name=name_hit,
+ aligned_cols=int(aligned_cols),
+ sum_probs=sum_probs,
+ query=query,
+ hit_sequence=hit_sequence,
+ indices_query=indices_query,
+ indices_hit=indices_hit,
+ )
+
+
+def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
+ """Parses the content of an entire HHR file."""
+ lines = hhr_string.splitlines()
+
+ # Each .hhr file starts with a results table, then has a sequence of hit
+ # "paragraphs", each paragraph starting with a line 'No '. We
+ # iterate through each paragraph to parse each hit.
+
+ block_starts = [i for i, line in enumerate(lines) if line.startswith("No ")]
+
+ hits = []
+ if block_starts:
+ block_starts.append(len(lines)) # Add the end of the final block.
+ for i in range(len(block_starts) - 1):
+ hits.append(
+ _parse_hhr_hit(lines[block_starts[i] : block_starts[i + 1]])
+ )
+ return hits
+
+
+def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
+ """Parse target to e-value mapping parsed from Jackhmmer tblout string."""
+ e_values = {"query": 0}
+ lines = [line for line in tblout.splitlines() if line[0] != "#"]
+ # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
+ # space-delimited. Relevant fields are (1) target name: and
+ # (5) E-value (full sequence) (numbering from 1).
+ for line in lines:
+ fields = line.split()
+ e_value = fields[4]
+ target_name = fields[0]
+ e_values[target_name] = float(e_value)
+ return e_values
diff --git a/openfold/data/templates.py b/openfold/data/templates.py
new file mode 100644
index 0000000000000000000000000000000000000000..3071433826e50296818ed5d11d58b92f0bac36b2
--- /dev/null
+++ b/openfold/data/templates.py
@@ -0,0 +1,1107 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions for getting templates and calculating template features."""
+import dataclasses
+import datetime
+import glob
+import json
+import logging
+import os
+import re
+from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
+
+import numpy as np
+
+from openfold.data import parsers, mmcif_parsing
+from openfold.data.errors import Error
+from openfold.data.tools import kalign
+from openfold.data.tools.utils import to_date
+from openfold.np import residue_constants
+
+
+class NoChainsError(Error):
+ """An error indicating that template mmCIF didn't have any chains."""
+
+
+class SequenceNotInTemplateError(Error):
+ """An error indicating that template mmCIF didn't contain the sequence."""
+
+
+class NoAtomDataInTemplateError(Error):
+ """An error indicating that template mmCIF didn't contain atom positions."""
+
+
+class TemplateAtomMaskAllZerosError(Error):
+ """An error indicating that template mmCIF had all atom positions masked."""
+
+
+class QueryToTemplateAlignError(Error):
+ """An error indicating that the query can't be aligned to the template."""
+
+
+class CaDistanceError(Error):
+ """An error indicating that a CA atom distance exceeds a threshold."""
+
+
+# Prefilter exceptions.
+class PrefilterError(Exception):
+ """A base class for template prefilter exceptions."""
+
+
+class DateError(PrefilterError):
+ """An error indicating that the hit date was after the max allowed date."""
+
+
+class PdbIdError(PrefilterError):
+ """An error indicating that the hit PDB ID was identical to the query."""
+
+
+class AlignRatioError(PrefilterError):
+ """An error indicating that the hit align ratio to the query was too small."""
+
+
+class DuplicateError(PrefilterError):
+ """An error indicating that the hit was an exact subsequence of the query."""
+
+
+class LengthError(PrefilterError):
+ """An error indicating that the hit was too short."""
+
+
+TEMPLATE_FEATURES = {
+ "template_aatype": np.int64,
+ "template_all_atom_mask": np.float32,
+ "template_all_atom_positions": np.float32,
+ "template_domain_names": np.object,
+ "template_sequence": np.object,
+ "template_sum_probs": np.float32,
+}
+
+
+def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
+ """Returns PDB id and chain id for an HHSearch Hit."""
+ # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
+ id_match = re.match(r"[a-zA-Z\d]{4}_[a-zA-Z0-9.]+", hit.name)
+ if not id_match:
+ raise ValueError(f"hit.name did not start with PDBID_chain: {hit.name}")
+ pdb_id, chain_id = id_match.group(0).split("_")
+ return pdb_id.lower(), chain_id
+
+
+def _is_after_cutoff(
+ pdb_id: str,
+ release_dates: Mapping[str, datetime.datetime],
+ release_date_cutoff: Optional[datetime.datetime],
+) -> bool:
+ """Checks if the template date is after the release date cutoff.
+
+ Args:
+ pdb_id: 4 letter pdb code.
+ release_dates: Dictionary mapping PDB ids to their structure release dates.
+ release_date_cutoff: Max release date that is valid for this query.
+
+ Returns:
+ True if the template release date is after the cutoff, False otherwise.
+ """
+ pdb_id_upper = pdb_id.upper()
+ if release_date_cutoff is None:
+ raise ValueError("The release_date_cutoff must not be None.")
+ if pdb_id_upper in release_dates:
+ return release_dates[pdb_id_upper] > release_date_cutoff
+ else:
+ # Since this is just a quick prefilter to reduce the number of mmCIF files
+ # we need to parse, we don't have to worry about returning True here.
+ logging.info(
+ "Template structure not in release dates dict: %s", pdb_id
+ )
+ return False
+
+
+def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
+ """Parses the data file from PDB that lists which PDB ids are obsolete."""
+ with open(obsolete_file_path) as f:
+ result = {}
+ for line in f:
+ line = line.strip()
+ # We skip obsolete entries that don't contain a mapping to a new entry.
+ if line.startswith("OBSLTE") and len(line) > 30:
+ # Format: Date From To
+ # 'OBSLTE 31-JUL-94 116L 216L'
+ from_id = line[20:24].lower()
+ to_id = line[29:33].lower()
+ result[from_id] = to_id
+ return result
+
+
+def generate_release_dates_cache(mmcif_dir: str, out_path: str):
+ dates = {}
+ for f in os.listdir(mmcif_dir):
+ if f.endswith(".cif"):
+ path = os.path.join(mmcif_dir, f)
+ with open(path, "r") as fp:
+ mmcif_string = fp.read()
+
+ file_id = os.path.splitext(f)[0]
+ mmcif = mmcif_parsing.parse(
+ file_id=file_id, mmcif_string=mmcif_string
+ )
+ if mmcif.mmcif_object is None:
+ logging.info(f"Failed to parse {f}. Skipping...")
+ continue
+
+ mmcif = mmcif.mmcif_object
+ release_date = mmcif.header["release_date"]
+
+ dates[file_id] = release_date
+
+ with open(out_path, "r") as fp:
+ fp.write(json.dumps(dates))
+
+
+def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
+ """Parses release dates file, returns a mapping from PDBs to release dates."""
+ with open(path, "r") as fp:
+ data = json.load(fp)
+
+ return {
+ pdb.upper(): to_date(v)
+ for pdb, d in data.items()
+ for k, v in d.items()
+ if k == "release_date"
+ }
+
+
+def _assess_hhsearch_hit(
+ hit: parsers.TemplateHit,
+ hit_pdb_code: str,
+ query_sequence: str,
+ query_pdb_code: Optional[str],
+ release_dates: Mapping[str, datetime.datetime],
+ release_date_cutoff: datetime.datetime,
+ max_subsequence_ratio: float = 0.95,
+ min_align_ratio: float = 0.1,
+) -> bool:
+ """Determines if template is valid (without parsing the template mmcif file).
+
+ Args:
+ hit: HhrHit for the template.
+ hit_pdb_code: The 4 letter pdb code of the template hit. This might be
+ different from the value in the actual hit since the original pdb might
+ have become obsolete.
+ query_sequence: Amino acid sequence of the query.
+ query_pdb_code: 4 letter pdb code of the query.
+ release_dates: Dictionary mapping pdb codes to their structure release
+ dates.
+ release_date_cutoff: Max release date that is valid for this query.
+ max_subsequence_ratio: Exclude any exact matches with this much overlap.
+ min_align_ratio: Minimum overlap between the template and query.
+
+ Returns:
+ True if the hit passed the prefilter. Raises an exception otherwise.
+
+ Raises:
+ DateError: If the hit date was after the max allowed date.
+ PdbIdError: If the hit PDB ID was identical to the query.
+ AlignRatioError: If the hit align ratio to the query was too small.
+ DuplicateError: If the hit was an exact subsequence of the query.
+ LengthError: If the hit was too short.
+ """
+ aligned_cols = hit.aligned_cols
+ align_ratio = aligned_cols / len(query_sequence)
+
+ template_sequence = hit.hit_sequence.replace("-", "")
+ length_ratio = float(len(template_sequence)) / len(query_sequence)
+
+ # Check whether the template is a large subsequence or duplicate of original
+ # query. This can happen due to duplicate entries in the PDB database.
+ duplicate = (
+ template_sequence in query_sequence
+ and length_ratio > max_subsequence_ratio
+ )
+
+ if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
+ date = release_dates[hit_pdb_code.upper()]
+ raise DateError(
+ f"Date ({date}) > max template date "
+ f"({release_date_cutoff})."
+ )
+
+ if query_pdb_code is not None:
+ if query_pdb_code.lower() == hit_pdb_code.lower():
+ raise PdbIdError("PDB code identical to Query PDB code.")
+
+ if align_ratio <= min_align_ratio:
+ raise AlignRatioError(
+ "Proportion of residues aligned to query too small. "
+ f"Align ratio: {align_ratio}."
+ )
+
+ if duplicate:
+ raise DuplicateError(
+ "Template is an exact subsequence of query with large "
+ f"coverage. Length ratio: {length_ratio}."
+ )
+
+ if len(template_sequence) < 10:
+ raise LengthError(
+ f"Template too short. Length: {len(template_sequence)}."
+ )
+
+ return True
+
+
+def _find_template_in_pdb(
+ template_chain_id: str,
+ template_sequence: str,
+ mmcif_object: mmcif_parsing.MmcifObject,
+) -> Tuple[str, str, int]:
+ """Tries to find the template chain in the given pdb file.
+
+ This method tries the three following things in order:
+ 1. Tries if there is an exact match in both the chain ID and the sequence.
+ If yes, the chain sequence is returned. Otherwise:
+ 2. Tries if there is an exact match only in the sequence.
+ If yes, the chain sequence is returned. Otherwise:
+ 3. Tries if there is a fuzzy match (X = wildcard) in the sequence.
+ If yes, the chain sequence is returned.
+ If none of these succeed, a SequenceNotInTemplateError is thrown.
+
+ Args:
+ template_chain_id: The template chain ID.
+ template_sequence: The template chain sequence.
+ mmcif_object: The PDB object to search for the template in.
+
+ Returns:
+ A tuple with:
+ * The chain sequence that was found to match the template in the PDB object.
+ * The ID of the chain that is being returned.
+ * The offset where the template sequence starts in the chain sequence.
+
+ Raises:
+ SequenceNotInTemplateError: If no match is found after the steps described
+ above.
+ """
+ # Try if there is an exact match in both the chain ID and the (sub)sequence.
+ pdb_id = mmcif_object.file_id
+ chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id)
+ if chain_sequence and (template_sequence in chain_sequence):
+ logging.info(
+ "Found an exact template match %s_%s.", pdb_id, template_chain_id
+ )
+ mapping_offset = chain_sequence.find(template_sequence)
+ return chain_sequence, template_chain_id, mapping_offset
+
+ # Try if there is an exact match in the (sub)sequence only.
+ for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
+ if chain_sequence and (template_sequence in chain_sequence):
+ logging.info("Found a sequence-only match %s_%s.", pdb_id, chain_id)
+ mapping_offset = chain_sequence.find(template_sequence)
+ return chain_sequence, chain_id, mapping_offset
+
+ # Return a chain sequence that fuzzy matches (X = wildcard) the template.
+ # Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit.
+ regex = ["." if aa == "X" else "(?:%s|X)" % aa for aa in template_sequence]
+ regex = re.compile("".join(regex))
+ for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
+ match = re.search(regex, chain_sequence)
+ if match:
+ logging.info(
+ "Found a fuzzy sequence-only match %s_%s.", pdb_id, chain_id
+ )
+ mapping_offset = match.start()
+ return chain_sequence, chain_id, mapping_offset
+
+ # No hits, raise an error.
+ raise SequenceNotInTemplateError(
+ "Could not find the template sequence in %s_%s. Template sequence: %s, "
+ "chain_to_seqres: %s"
+ % (
+ pdb_id,
+ template_chain_id,
+ template_sequence,
+ mmcif_object.chain_to_seqres,
+ )
+ )
+
+
+def _realign_pdb_template_to_query(
+ old_template_sequence: str,
+ template_chain_id: str,
+ mmcif_object: mmcif_parsing.MmcifObject,
+ old_mapping: Mapping[int, int],
+ kalign_binary_path: str,
+) -> Tuple[str, Mapping[int, int]]:
+ """Aligns template from the mmcif_object to the query.
+
+ In case PDB70 contains a different version of the template sequence, we need
+ to perform a realignment to the actual sequence that is in the mmCIF file.
+ This method performs such realignment, but returns the new sequence and
+ mapping only if the sequence in the mmCIF file is 90% identical to the old
+ sequence.
+
+ Note that the old_template_sequence comes from the hit, and contains only that
+ part of the chain that matches with the query while the new_template_sequence
+ is the full chain.
+
+ Args:
+ old_template_sequence: The template sequence that was returned by the PDB
+ template search (typically done using HHSearch).
+ template_chain_id: The template chain id was returned by the PDB template
+ search (typically done using HHSearch). This is used to find the right
+ chain in the mmcif_object chain_to_seqres mapping.
+ mmcif_object: A mmcif_object which holds the actual template data.
+ old_mapping: A mapping from the query sequence to the template sequence.
+ This mapping will be used to compute the new mapping from the query
+ sequence to the actual mmcif_object template sequence by aligning the
+ old_template_sequence and the actual template sequence.
+ kalign_binary_path: The path to a kalign executable.
+
+ Returns:
+ A tuple (new_template_sequence, new_query_to_template_mapping) where:
+ * new_template_sequence is the actual template sequence that was found in
+ the mmcif_object.
+ * new_query_to_template_mapping is the new mapping from the query to the
+ actual template found in the mmcif_object.
+
+ Raises:
+ QueryToTemplateAlignError:
+ * If there was an error thrown by the alignment tool.
+ * Or if the actual template sequence differs by more than 10% from the
+ old_template_sequence.
+ """
+ aligner = kalign.Kalign(binary_path=kalign_binary_path)
+ new_template_sequence = mmcif_object.chain_to_seqres.get(
+ template_chain_id, ""
+ )
+
+ # Sometimes the template chain id is unknown. But if there is only a single
+ # sequence within the mmcif_object, it is safe to assume it is that one.
+ if not new_template_sequence:
+ if len(mmcif_object.chain_to_seqres) == 1:
+ logging.info(
+ "Could not find %s in %s, but there is only 1 sequence, so "
+ "using that one.",
+ template_chain_id,
+ mmcif_object.file_id,
+ )
+ new_template_sequence = list(mmcif_object.chain_to_seqres.values())[
+ 0
+ ]
+ else:
+ raise QueryToTemplateAlignError(
+ f"Could not find chain {template_chain_id} in {mmcif_object.file_id}. "
+ "If there are no mmCIF parsing errors, it is possible it was not a "
+ "protein chain."
+ )
+
+ try:
+ (old_aligned_template, new_aligned_template), _ = parsers.parse_a3m(
+ aligner.align([old_template_sequence, new_template_sequence])
+ )
+ except Exception as e:
+ raise QueryToTemplateAlignError(
+ "Could not align old template %s to template %s (%s_%s). Error: %s"
+ % (
+ old_template_sequence,
+ new_template_sequence,
+ mmcif_object.file_id,
+ template_chain_id,
+ str(e),
+ )
+ )
+
+ logging.info(
+ "Old aligned template: %s\nNew aligned template: %s",
+ old_aligned_template,
+ new_aligned_template,
+ )
+
+ old_to_new_template_mapping = {}
+ old_template_index = -1
+ new_template_index = -1
+ num_same = 0
+ for old_template_aa, new_template_aa in zip(
+ old_aligned_template, new_aligned_template
+ ):
+ if old_template_aa != "-":
+ old_template_index += 1
+ if new_template_aa != "-":
+ new_template_index += 1
+ if old_template_aa != "-" and new_template_aa != "-":
+ old_to_new_template_mapping[old_template_index] = new_template_index
+ if old_template_aa == new_template_aa:
+ num_same += 1
+
+ # Require at least 90 % sequence identity wrt to the shorter of the sequences.
+ if (
+ float(num_same)
+ / min(len(old_template_sequence), len(new_template_sequence))
+ < 0.9
+ ):
+ raise QueryToTemplateAlignError(
+ "Insufficient similarity of the sequence in the database: %s to the "
+ "actual sequence in the mmCIF file %s_%s: %s. We require at least "
+ "90 %% similarity wrt to the shorter of the sequences. This is not a "
+ "problem unless you think this is a template that should be included."
+ % (
+ old_template_sequence,
+ mmcif_object.file_id,
+ template_chain_id,
+ new_template_sequence,
+ )
+ )
+
+ new_query_to_template_mapping = {}
+ for query_index, old_template_index in old_mapping.items():
+ new_query_to_template_mapping[
+ query_index
+ ] = old_to_new_template_mapping.get(old_template_index, -1)
+
+ new_template_sequence = new_template_sequence.replace("-", "")
+
+ return new_template_sequence, new_query_to_template_mapping
+
+
+def _check_residue_distances(
+ all_positions: np.ndarray,
+ all_positions_mask: np.ndarray,
+ max_ca_ca_distance: float,
+):
+ """Checks if the distance between unmasked neighbor residues is ok."""
+ ca_position = residue_constants.atom_order["CA"]
+ prev_is_unmasked = False
+ prev_calpha = None
+ for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)):
+ this_is_unmasked = bool(mask[ca_position])
+ if this_is_unmasked:
+ this_calpha = coords[ca_position]
+ if prev_is_unmasked:
+ distance = np.linalg.norm(this_calpha - prev_calpha)
+ if distance > max_ca_ca_distance:
+ raise CaDistanceError(
+ "The distance between residues %d and %d is %f > limit %f."
+ % (i, i + 1, distance, max_ca_ca_distance)
+ )
+ prev_calpha = this_calpha
+ prev_is_unmasked = this_is_unmasked
+
+
+def _get_atom_positions(
+ mmcif_object: mmcif_parsing.MmcifObject,
+ auth_chain_id: str,
+ max_ca_ca_distance: float,
+ _zero_center_positions: bool = True,
+) -> Tuple[np.ndarray, np.ndarray]:
+ """Gets atom positions and mask from a list of Biopython Residues."""
+ coords_with_mask = mmcif_parsing.get_atom_coords(
+ mmcif_object=mmcif_object,
+ chain_id=auth_chain_id,
+ _zero_center_positions=_zero_center_positions,
+ )
+ all_atom_positions, all_atom_mask = coords_with_mask
+ _check_residue_distances(
+ all_atom_positions, all_atom_mask, max_ca_ca_distance
+ )
+ return all_atom_positions, all_atom_mask
+
+
+def _extract_template_features(
+ mmcif_object: mmcif_parsing.MmcifObject,
+ pdb_id: str,
+ mapping: Mapping[int, int],
+ template_sequence: str,
+ query_sequence: str,
+ template_chain_id: str,
+ kalign_binary_path: str,
+ _zero_center_positions: bool = True,
+) -> Tuple[Dict[str, Any], Optional[str]]:
+ """Parses atom positions in the target structure and aligns with the query.
+
+ Atoms for each residue in the template structure are indexed to coincide
+ with their corresponding residue in the query sequence, according to the
+ alignment mapping provided.
+
+ Args:
+ mmcif_object: mmcif_parsing.MmcifObject representing the template.
+ pdb_id: PDB code for the template.
+ mapping: Dictionary mapping indices in the query sequence to indices in
+ the template sequence.
+ template_sequence: String describing the amino acid sequence for the
+ template protein.
+ query_sequence: String describing the amino acid sequence for the query
+ protein.
+ template_chain_id: String ID describing which chain in the structure proto
+ should be used.
+ kalign_binary_path: The path to a kalign executable used for template
+ realignment.
+
+ Returns:
+ A tuple with:
+ * A dictionary containing the extra features derived from the template
+ protein structure.
+ * A warning message if the hit was realigned to the actual mmCIF sequence.
+ Otherwise None.
+
+ Raises:
+ NoChainsError: If the mmcif object doesn't contain any chains.
+ SequenceNotInTemplateError: If the given chain id / sequence can't
+ be found in the mmcif object.
+ QueryToTemplateAlignError: If the actual template in the mmCIF file
+ can't be aligned to the query.
+ NoAtomDataInTemplateError: If the mmcif object doesn't contain
+ atom positions.
+ TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any
+ unmasked residues.
+ """
+ if mmcif_object is None or not mmcif_object.chain_to_seqres:
+ raise NoChainsError(
+ "No chains in PDB: %s_%s" % (pdb_id, template_chain_id)
+ )
+
+ warning = None
+ try:
+ seqres, chain_id, mapping_offset = _find_template_in_pdb(
+ template_chain_id=template_chain_id,
+ template_sequence=template_sequence,
+ mmcif_object=mmcif_object,
+ )
+ except SequenceNotInTemplateError:
+ # If PDB70 contains a different version of the template, we use the sequence
+ # from the mmcif_object.
+ chain_id = template_chain_id
+ warning = (
+ f"The exact sequence {template_sequence} was not found in "
+ f"{pdb_id}_{chain_id}. Realigning the template to the actual sequence."
+ )
+ logging.warning(warning)
+ # This throws an exception if it fails to realign the hit.
+ seqres, mapping = _realign_pdb_template_to_query(
+ old_template_sequence=template_sequence,
+ template_chain_id=template_chain_id,
+ mmcif_object=mmcif_object,
+ old_mapping=mapping,
+ kalign_binary_path=kalign_binary_path,
+ )
+ logging.info(
+ "Sequence in %s_%s: %s successfully realigned to %s",
+ pdb_id,
+ chain_id,
+ template_sequence,
+ seqres,
+ )
+ # The template sequence changed.
+ template_sequence = seqres
+ # No mapping offset, the query is aligned to the actual sequence.
+ mapping_offset = 0
+
+ try:
+ # Essentially set to infinity - we don't want to reject templates unless
+ # they're really really bad.
+ all_atom_positions, all_atom_mask = _get_atom_positions(
+ mmcif_object,
+ chain_id,
+ max_ca_ca_distance=150.0,
+ _zero_center_positions=_zero_center_positions,
+ )
+ except (CaDistanceError, KeyError) as ex:
+ raise NoAtomDataInTemplateError(
+ "Could not get atom data (%s_%s): %s" % (pdb_id, chain_id, str(ex))
+ ) from ex
+
+ all_atom_positions = np.split(
+ all_atom_positions, all_atom_positions.shape[0]
+ )
+ all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])
+
+ output_templates_sequence = []
+ templates_all_atom_positions = []
+ templates_all_atom_masks = []
+
+ for _ in query_sequence:
+ # Residues in the query_sequence that are not in the template_sequence:
+ templates_all_atom_positions.append(
+ np.zeros((residue_constants.atom_type_num, 3))
+ )
+ templates_all_atom_masks.append(
+ np.zeros(residue_constants.atom_type_num)
+ )
+ output_templates_sequence.append("-")
+
+ for k, v in mapping.items():
+ template_index = v + mapping_offset
+ templates_all_atom_positions[k] = all_atom_positions[template_index][0]
+ templates_all_atom_masks[k] = all_atom_masks[template_index][0]
+ output_templates_sequence[k] = template_sequence[v]
+
+ # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
+ if np.sum(templates_all_atom_masks) < 5:
+ raise TemplateAtomMaskAllZerosError(
+ "Template all atom mask was all zeros: %s_%s. Residue range: %d-%d"
+ % (
+ pdb_id,
+ chain_id,
+ min(mapping.values()) + mapping_offset,
+ max(mapping.values()) + mapping_offset,
+ )
+ )
+
+ output_templates_sequence = "".join(output_templates_sequence)
+
+ templates_aatype = residue_constants.sequence_to_onehot(
+ output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID
+ )
+
+ return (
+ {
+ "template_all_atom_positions": np.array(
+ templates_all_atom_positions
+ ),
+ "template_all_atom_mask": np.array(templates_all_atom_masks),
+ "template_sequence": output_templates_sequence.encode(),
+ "template_aatype": np.array(templates_aatype),
+ "template_domain_names": f"{pdb_id.lower()}_{chain_id}".encode(),
+ },
+ warning,
+ )
+
+
+def _build_query_to_hit_index_mapping(
+ hit_query_sequence: str,
+ hit_sequence: str,
+ indices_hit: Sequence[int],
+ indices_query: Sequence[int],
+ original_query_sequence: str,
+) -> Mapping[int, int]:
+ """Gets mapping from indices in original query sequence to indices in the hit.
+
+ hit_query_sequence and hit_sequence are two aligned sequences containing gap
+ characters. hit_query_sequence contains only the part of the original query
+ sequence that matched the hit. When interpreting the indices from the .hhr, we
+ need to correct for this to recover a mapping from original query sequence to
+ the hit sequence.
+
+ Args:
+ hit_query_sequence: The portion of the query sequence that is in the .hhr
+ hit
+ hit_sequence: The portion of the hit sequence that is in the .hhr
+ indices_hit: The indices for each aminoacid relative to the hit sequence
+ indices_query: The indices for each aminoacid relative to the original query
+ sequence
+ original_query_sequence: String describing the original query sequence.
+
+ Returns:
+ Dictionary with indices in the original query sequence as keys and indices
+ in the hit sequence as values.
+ """
+ # If the hit is empty (no aligned residues), return empty mapping
+ if not hit_query_sequence:
+ return {}
+
+ # Remove gaps and find the offset of hit.query relative to original query.
+ hhsearch_query_sequence = hit_query_sequence.replace("-", "")
+ hit_sequence = hit_sequence.replace("-", "")
+ hhsearch_query_offset = original_query_sequence.find(
+ hhsearch_query_sequence
+ )
+
+ # Index of -1 used for gap characters. Subtract the min index ignoring gaps.
+ min_idx = min(x for x in indices_hit if x > -1)
+ fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit]
+
+ min_idx = min(x for x in indices_query if x > -1)
+ fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query]
+
+ # Zip the corrected indices, ignore case where both seqs have gap characters.
+ mapping = {}
+ for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit):
+ if q_t != -1 and q_i != -1:
+ if q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len(
+ original_query_sequence
+ ):
+ continue
+ mapping[q_i + hhsearch_query_offset] = q_t
+
+ return mapping
+
+
+@dataclasses.dataclass(frozen=True)
+class PrefilterResult:
+ valid: bool
+ error: Optional[str]
+ warning: Optional[str]
+
+@dataclasses.dataclass(frozen=True)
+class SingleHitResult:
+ features: Optional[Mapping[str, Any]]
+ error: Optional[str]
+ warning: Optional[str]
+
+
+def _prefilter_hit(
+ query_sequence: str,
+ query_pdb_code: Optional[str],
+ hit: parsers.TemplateHit,
+ max_template_date: datetime.datetime,
+ release_dates: Mapping[str, datetime.datetime],
+ obsolete_pdbs: Mapping[str, str],
+ strict_error_check: bool = False,
+):
+ # Fail hard if we can't get the PDB ID and chain name from the hit.
+ hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
+
+ if hit_pdb_code not in release_dates:
+ if hit_pdb_code in obsolete_pdbs:
+ hit_pdb_code = obsolete_pdbs[hit_pdb_code]
+
+ # Pass hit_pdb_code since it might have changed due to the pdb being
+ # obsolete.
+ try:
+ _assess_hhsearch_hit(
+ hit=hit,
+ hit_pdb_code=hit_pdb_code,
+ query_sequence=query_sequence,
+ query_pdb_code=query_pdb_code,
+ release_dates=release_dates,
+ release_date_cutoff=max_template_date,
+ )
+ except PrefilterError as e:
+ hit_name = f"{hit_pdb_code}_{hit_chain_id}"
+ msg = f"hit {hit_name} did not pass prefilter: {str(e)}"
+ logging.info("%s: %s", query_pdb_code, msg)
+ if strict_error_check and isinstance(
+ e, (DateError, PdbIdError, DuplicateError)
+ ):
+ # In strict mode we treat some prefilter cases as errors.
+ return PrefilterResult(valid=False, error=msg, warning=None)
+
+ return PrefilterResult(valid=False, error=None, warning=None)
+
+ return PrefilterResult(valid=True, error=None, warning=None)
+
+
+def _process_single_hit(
+ query_sequence: str,
+ query_pdb_code: Optional[str],
+ hit: parsers.TemplateHit,
+ mmcif_dir: str,
+ max_template_date: datetime.datetime,
+ release_dates: Mapping[str, datetime.datetime],
+ obsolete_pdbs: Mapping[str, str],
+ kalign_binary_path: str,
+ strict_error_check: bool = False,
+ _zero_center_positions: bool = True,
+) -> SingleHitResult:
+ """Tries to extract template features from a single HHSearch hit."""
+ # Fail hard if we can't get the PDB ID and chain name from the hit.
+ hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
+
+ if hit_pdb_code not in release_dates:
+ if hit_pdb_code in obsolete_pdbs:
+ hit_pdb_code = obsolete_pdbs[hit_pdb_code]
+
+ mapping = _build_query_to_hit_index_mapping(
+ hit.query,
+ hit.hit_sequence,
+ hit.indices_hit,
+ hit.indices_query,
+ query_sequence,
+ )
+
+ # The mapping is from the query to the actual hit sequence, so we need to
+ # remove gaps (which regardless have a missing confidence score).
+ template_sequence = hit.hit_sequence.replace("-", "")
+
+ cif_path = os.path.join(mmcif_dir, hit_pdb_code + ".cif")
+ logging.info(
+ "Reading PDB entry from %s. Query: %s, template: %s",
+ cif_path,
+ query_sequence,
+ template_sequence,
+ )
+ # Fail if we can't find the mmCIF file.
+ with open(cif_path, "r") as cif_file:
+ cif_string = cif_file.read()
+
+ parsing_result = mmcif_parsing.parse(
+ file_id=hit_pdb_code, mmcif_string=cif_string
+ )
+
+ if parsing_result.mmcif_object is not None:
+ hit_release_date = datetime.datetime.strptime(
+ parsing_result.mmcif_object.header["release_date"], "%Y-%m-%d"
+ )
+ if hit_release_date > max_template_date:
+ error = "Template %s date (%s) > max template date (%s)." % (
+ hit_pdb_code,
+ hit_release_date,
+ max_template_date,
+ )
+ if strict_error_check:
+ return SingleHitResult(features=None, error=error, warning=None)
+ else:
+ logging.info(error)
+ return SingleHitResult(features=None, error=None, warning=None)
+
+ try:
+ features, realign_warning = _extract_template_features(
+ mmcif_object=parsing_result.mmcif_object,
+ pdb_id=hit_pdb_code,
+ mapping=mapping,
+ template_sequence=template_sequence,
+ query_sequence=query_sequence,
+ template_chain_id=hit_chain_id,
+ kalign_binary_path=kalign_binary_path,
+ _zero_center_positions=_zero_center_positions,
+ )
+ features["template_sum_probs"] = [hit.sum_probs]
+
+ # It is possible there were some errors when parsing the other chains in the
+ # mmCIF file, but the template features for the chain we want were still
+ # computed. In such case the mmCIF parsing errors are not relevant.
+ return SingleHitResult(
+ features=features, error=None, warning=realign_warning
+ )
+ except (
+ NoChainsError,
+ NoAtomDataInTemplateError,
+ TemplateAtomMaskAllZerosError,
+ ) as e:
+ # These 3 errors indicate missing mmCIF experimental data rather than a
+ # problem with the template search, so turn them into warnings.
+ warning = (
+ "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
+ "%s, mmCIF parsing errors: %s"
+ % (
+ hit_pdb_code,
+ hit_chain_id,
+ hit.sum_probs,
+ hit.index,
+ str(e),
+ parsing_result.errors,
+ )
+ )
+ if strict_error_check:
+ return SingleHitResult(features=None, error=warning, warning=None)
+ else:
+ return SingleHitResult(features=None, error=None, warning=warning)
+ except Error as e:
+ error = (
+ "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
+ "%s, mmCIF parsing errors: %s"
+ % (
+ hit_pdb_code,
+ hit_chain_id,
+ hit.sum_probs,
+ hit.index,
+ str(e),
+ parsing_result.errors,
+ )
+ )
+ return SingleHitResult(features=None, error=error, warning=None)
+
+
+@dataclasses.dataclass(frozen=True)
+class TemplateSearchResult:
+ features: Mapping[str, Any]
+ errors: Sequence[str]
+ warnings: Sequence[str]
+
+
+class TemplateHitFeaturizer:
+ """A class for turning hhr hits to template features."""
+ def __init__(
+ self,
+ mmcif_dir: str,
+ max_template_date: str,
+ max_hits: int,
+ kalign_binary_path: str,
+ release_dates_path: Optional[str] = None,
+ obsolete_pdbs_path: Optional[str] = None,
+ strict_error_check: bool = False,
+ _shuffle_top_k_prefiltered: Optional[int] = None,
+ _zero_center_positions: bool = True,
+ ):
+ """Initializes the Template Search.
+
+ Args:
+ mmcif_dir: Path to a directory with mmCIF structures. Once a template ID
+ is found by HHSearch, this directory is used to retrieve the template
+ data.
+ max_template_date: The maximum date permitted for template structures. No
+ template with date higher than this date will be returned. In ISO8601
+ date format, YYYY-MM-DD.
+ max_hits: The maximum number of templates that will be returned.
+ kalign_binary_path: The path to a kalign executable used for template
+ realignment.
+ release_dates_path: An optional path to a file with a mapping from PDB IDs
+ to their release dates. Thanks to this we don't have to redundantly
+ parse mmCIF files to get that information.
+ obsolete_pdbs_path: An optional path to a file containing a mapping from
+ obsolete PDB IDs to the PDB IDs of their replacements.
+ strict_error_check: If True, then the following will be treated as errors:
+ * If any template date is after the max_template_date.
+ * If any template has identical PDB ID to the query.
+ * If any template is a duplicate of the query.
+ * Any feature computation errors.
+ """
+ self._mmcif_dir = mmcif_dir
+ if not glob.glob(os.path.join(self._mmcif_dir, "*.cif")):
+ logging.error("Could not find CIFs in %s", self._mmcif_dir)
+ raise ValueError(f"Could not find CIFs in {self._mmcif_dir}")
+
+ try:
+ self._max_template_date = datetime.datetime.strptime(
+ max_template_date, "%Y-%m-%d"
+ )
+ except ValueError:
+ raise ValueError(
+ "max_template_date must be set and have format YYYY-MM-DD."
+ )
+ self.max_hits = max_hits
+ self._kalign_binary_path = kalign_binary_path
+ self._strict_error_check = strict_error_check
+
+ if release_dates_path:
+ logging.info(
+ "Using precomputed release dates %s.", release_dates_path
+ )
+ self._release_dates = _parse_release_dates(release_dates_path)
+ else:
+ self._release_dates = {}
+
+ if obsolete_pdbs_path:
+ logging.info(
+ "Using precomputed obsolete pdbs %s.", obsolete_pdbs_path
+ )
+ self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)
+ else:
+ self._obsolete_pdbs = {}
+
+ self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
+ self._zero_center_positions = _zero_center_positions
+
+ def get_templates(
+ self,
+ query_sequence: str,
+ query_pdb_code: Optional[str],
+ query_release_date: Optional[datetime.datetime],
+ hits: Sequence[parsers.TemplateHit],
+ ) -> TemplateSearchResult:
+ """Computes the templates for given query sequence (more details above)."""
+ logging.info("Searching for template for: %s", query_pdb_code)
+
+ template_features = {}
+ for template_feature_name in TEMPLATE_FEATURES:
+ template_features[template_feature_name] = []
+
+ # Always use a max_template_date. Set to query_release_date minus 60 days
+ # if that's earlier.
+ template_cutoff_date = self._max_template_date
+ if query_release_date:
+ delta = datetime.timedelta(days=60)
+ if query_release_date - delta < template_cutoff_date:
+ template_cutoff_date = query_release_date - delta
+ assert template_cutoff_date < query_release_date
+ assert template_cutoff_date <= self._max_template_date
+
+ num_hits = 0
+ errors = []
+ warnings = []
+
+ filtered = []
+ for hit in hits:
+ prefilter_result = _prefilter_hit(
+ query_sequence=query_sequence,
+ query_pdb_code=query_pdb_code,
+ hit=hit,
+ max_template_date=template_cutoff_date,
+ release_dates=self._release_dates,
+ obsolete_pdbs=self._obsolete_pdbs,
+ strict_error_check=self._strict_error_check,
+ )
+
+ if prefilter_result.error:
+ errors.append(prefilter_result.error)
+
+ if prefilter_result.warning:
+ warnings.append(prefilter_result.warning)
+
+ if prefilter_result.valid:
+ filtered.append(hit)
+
+ filtered = list(
+ sorted(filtered, key=lambda x: x.sum_probs, reverse=True)
+ )
+ idx = list(range(len(filtered)))
+ if(self._shuffle_top_k_prefiltered):
+ stk = self._shuffle_top_k_prefiltered
+ idx[:stk] = np.random.permutation(idx[:stk])
+
+ for i in idx:
+ # We got all the templates we wanted, stop processing hits.
+ if num_hits >= self.max_hits:
+ break
+
+ hit = filtered[i]
+
+ result = _process_single_hit(
+ query_sequence=query_sequence,
+ query_pdb_code=query_pdb_code,
+ hit=hit,
+ mmcif_dir=self._mmcif_dir,
+ max_template_date=template_cutoff_date,
+ release_dates=self._release_dates,
+ obsolete_pdbs=self._obsolete_pdbs,
+ strict_error_check=self._strict_error_check,
+ kalign_binary_path=self._kalign_binary_path,
+ _zero_center_positions=self._zero_center_positions,
+ )
+
+ if result.error:
+ errors.append(result.error)
+
+ # There could be an error even if there are some results, e.g. thrown by
+ # other unparsable chains in the same mmCIF file.
+ if result.warning:
+ warnings.append(result.warning)
+
+ if result.features is None:
+ logging.info(
+ "Skipped invalid hit %s, error: %s, warning: %s",
+ hit.name,
+ result.error,
+ result.warning,
+ )
+ else:
+ # Increment the hit counter, since we got features out of this hit.
+ num_hits += 1
+ for k in template_features:
+ template_features[k].append(result.features[k])
+
+ for name in template_features:
+ if num_hits > 0:
+ template_features[name] = np.stack(
+ template_features[name], axis=0
+ ).astype(TEMPLATE_FEATURES[name])
+ else:
+ # Make sure the feature has correct dtype even if empty.
+ template_features[name] = np.array(
+ [], dtype=TEMPLATE_FEATURES[name]
+ )
+
+ return TemplateSearchResult(
+ features=template_features, errors=errors, warnings=warnings
+ )
diff --git a/openfold/data/tools/__init__.py b/openfold/data/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/openfold/data/tools/hhblits.py b/openfold/data/tools/hhblits.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e1c1e7bfdbf1cf7d6f373f79d8f7d2c354195fb
--- /dev/null
+++ b/openfold/data/tools/hhblits.py
@@ -0,0 +1,175 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Library to run HHblits from Python."""
+import glob
+import logging
+import os
+import subprocess
+from typing import Any, Mapping, Optional, Sequence
+
+from openfold.data.tools import utils
+
+
+_HHBLITS_DEFAULT_P = 20
+_HHBLITS_DEFAULT_Z = 500
+
+
+class HHBlits:
+ """Python wrapper of the HHblits binary."""
+
+ def __init__(
+ self,
+ *,
+ binary_path: str,
+ databases: Sequence[str],
+ n_cpu: int = 4,
+ n_iter: int = 3,
+ e_value: float = 0.001,
+ maxseq: int = 1_000_000,
+ realign_max: int = 100_000,
+ maxfilt: int = 100_000,
+ min_prefilter_hits: int = 1000,
+ all_seqs: bool = False,
+ alt: Optional[int] = None,
+ p: int = _HHBLITS_DEFAULT_P,
+ z: int = _HHBLITS_DEFAULT_Z,
+ ):
+ """Initializes the Python HHblits wrapper.
+
+ Args:
+ binary_path: The path to the HHblits executable.
+ databases: A sequence of HHblits database paths. This should be the
+ common prefix for the database files (i.e. up to but not including
+ _hhm.ffindex etc.)
+ n_cpu: The number of CPUs to give HHblits.
+ n_iter: The number of HHblits iterations.
+ e_value: The E-value, see HHblits docs for more details.
+ maxseq: The maximum number of rows in an input alignment. Note that this
+ parameter is only supported in HHBlits version 3.1 and higher.
+ realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
+ maxfilt: Max number of hits allowed to pass the 2nd prefilter.
+ HHblits default: 20000.
+ min_prefilter_hits: Min number of hits to pass prefilter.
+ HHblits default: 100.
+ all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
+ HHblits default: False.
+ alt: Show up to this many alternative alignments.
+ p: Minimum Prob for a hit to be included in the output hhr file.
+ HHblits default: 20.
+ z: Hard cap on number of hits reported in the hhr file.
+ HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.
+
+ Raises:
+ RuntimeError: If HHblits binary not found within the path.
+ """
+ self.binary_path = binary_path
+ self.databases = databases
+
+ for database_path in self.databases:
+ if not glob.glob(database_path + "_*"):
+ logging.error(
+ "Could not find HHBlits database %s", database_path
+ )
+ raise ValueError(
+ f"Could not find HHBlits database {database_path}"
+ )
+
+ self.n_cpu = n_cpu
+ self.n_iter = n_iter
+ self.e_value = e_value
+ self.maxseq = maxseq
+ self.realign_max = realign_max
+ self.maxfilt = maxfilt
+ self.min_prefilter_hits = min_prefilter_hits
+ self.all_seqs = all_seqs
+ self.alt = alt
+ self.p = p
+ self.z = z
+
+ def query(self, input_fasta_path: str) -> Mapping[str, Any]:
+ """Queries the database using HHblits."""
+ with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
+ a3m_path = os.path.join(query_tmp_dir, "output.a3m")
+
+ db_cmd = []
+ for db_path in self.databases:
+ db_cmd.append("-d")
+ db_cmd.append(db_path)
+ cmd = [
+ self.binary_path,
+ "-i",
+ input_fasta_path,
+ "-cpu",
+ str(self.n_cpu),
+ "-oa3m",
+ a3m_path,
+ "-o",
+ "/dev/null",
+ "-n",
+ str(self.n_iter),
+ "-e",
+ str(self.e_value),
+ "-maxseq",
+ str(self.maxseq),
+ "-realign_max",
+ str(self.realign_max),
+ "-maxfilt",
+ str(self.maxfilt),
+ "-min_prefilter_hits",
+ str(self.min_prefilter_hits),
+ ]
+ if self.all_seqs:
+ cmd += ["-all"]
+ if self.alt:
+ cmd += ["-alt", str(self.alt)]
+ if self.p != _HHBLITS_DEFAULT_P:
+ cmd += ["-p", str(self.p)]
+ if self.z != _HHBLITS_DEFAULT_Z:
+ cmd += ["-Z", str(self.z)]
+ cmd += db_cmd
+
+ logging.info('Launching subprocess "%s"', " ".join(cmd))
+ process = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+ )
+
+ with utils.timing("HHblits query"):
+ stdout, stderr = process.communicate()
+ retcode = process.wait()
+
+ if retcode:
+ # Logs have a 15k character limit, so log HHblits error line by line.
+ logging.error("HHblits failed. HHblits stderr begin:")
+ for error_line in stderr.decode("utf-8").splitlines():
+ if error_line.strip():
+ logging.error(error_line.strip())
+ logging.error("HHblits stderr end")
+ raise RuntimeError(
+ "HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n"
+ % (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8"))
+ )
+
+ with open(a3m_path) as f:
+ a3m = f.read()
+
+ raw_output = dict(
+ a3m=a3m,
+ output=stdout,
+ stderr=stderr,
+ n_iter=self.n_iter,
+ e_value=self.e_value,
+ )
+ return raw_output
diff --git a/openfold/data/tools/hhsearch.py b/openfold/data/tools/hhsearch.py
new file mode 100644
index 0000000000000000000000000000000000000000..d62e3b410c5bcf1e1089d961c8c44aad81d7f2a2
--- /dev/null
+++ b/openfold/data/tools/hhsearch.py
@@ -0,0 +1,106 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Library to run HHsearch from Python."""
+import glob
+import logging
+import os
+import subprocess
+from typing import Sequence
+
+from openfold.data.tools import utils
+
+
+class HHSearch:
+ """Python wrapper of the HHsearch binary."""
+
+ def __init__(
+ self,
+ *,
+ binary_path: str,
+ databases: Sequence[str],
+ n_cpu: int = 2,
+ maxseq: int = 1_000_000,
+ ):
+ """Initializes the Python HHsearch wrapper.
+
+ Args:
+ binary_path: The path to the HHsearch executable.
+ databases: A sequence of HHsearch database paths. This should be the
+ common prefix for the database files (i.e. up to but not including
+ _hhm.ffindex etc.)
+ n_cpu: The number of CPUs to use
+ maxseq: The maximum number of rows in an input alignment. Note that this
+ parameter is only supported in HHBlits version 3.1 and higher.
+
+ Raises:
+ RuntimeError: If HHsearch binary not found within the path.
+ """
+ self.binary_path = binary_path
+ self.databases = databases
+ self.n_cpu = n_cpu
+ self.maxseq = maxseq
+
+ for database_path in self.databases:
+ if not glob.glob(database_path + "_*"):
+ logging.error(
+ "Could not find HHsearch database %s", database_path
+ )
+ raise ValueError(
+ f"Could not find HHsearch database {database_path}"
+ )
+
+ def query(self, a3m: str) -> str:
+ """Queries the database using HHsearch using a given a3m."""
+ with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
+ input_path = os.path.join(query_tmp_dir, "query.a3m")
+ hhr_path = os.path.join(query_tmp_dir, "output.hhr")
+ with open(input_path, "w") as f:
+ f.write(a3m)
+
+ db_cmd = []
+ for db_path in self.databases:
+ db_cmd.append("-d")
+ db_cmd.append(db_path)
+ cmd = [
+ self.binary_path,
+ "-i",
+ input_path,
+ "-o",
+ hhr_path,
+ "-maxseq",
+ str(self.maxseq),
+ "-cpu",
+ str(self.n_cpu),
+ ] + db_cmd
+
+ logging.info('Launching subprocess "%s"', " ".join(cmd))
+ process = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+ )
+ with utils.timing("HHsearch query"):
+ stdout, stderr = process.communicate()
+ retcode = process.wait()
+
+ if retcode:
+ # Stderr is truncated to prevent proto size errors in Beam.
+ raise RuntimeError(
+ "HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
+ % (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
+ )
+
+ with open(hhr_path) as f:
+ hhr = f.read()
+ return hhr
diff --git a/openfold/data/tools/jackhmmer.py b/openfold/data/tools/jackhmmer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fca93edfe08fee383e40c451a647d8ff48230f3c
--- /dev/null
+++ b/openfold/data/tools/jackhmmer.py
@@ -0,0 +1,228 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Library to run Jackhmmer from Python."""
+
+from concurrent import futures
+import glob
+import logging
+import os
+import subprocess
+from typing import Any, Callable, Mapping, Optional, Sequence
+from urllib import request
+
+from openfold.data.tools import utils
+
+
+class Jackhmmer:
+ """Python wrapper of the Jackhmmer binary."""
+
+ def __init__(
+ self,
+ *,
+ binary_path: str,
+ database_path: str,
+ n_cpu: int = 8,
+ n_iter: int = 1,
+ e_value: float = 0.0001,
+ z_value: Optional[int] = None,
+ get_tblout: bool = False,
+ filter_f1: float = 0.0005,
+ filter_f2: float = 0.00005,
+ filter_f3: float = 0.0000005,
+ incdom_e: Optional[float] = None,
+ dom_e: Optional[float] = None,
+ num_streamed_chunks: Optional[int] = None,
+ streaming_callback: Optional[Callable[[int], None]] = None,
+ ):
+ """Initializes the Python Jackhmmer wrapper.
+
+ Args:
+ binary_path: The path to the jackhmmer executable.
+ database_path: The path to the jackhmmer database (FASTA format).
+ n_cpu: The number of CPUs to give Jackhmmer.
+ n_iter: The number of Jackhmmer iterations.
+ e_value: The E-value, see Jackhmmer docs for more details.
+ z_value: The Z-value, see Jackhmmer docs for more details.
+ get_tblout: Whether to save tblout string.
+ filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
+ filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
+ filter_f3: Forward pre-filter, set to >1.0 to turn off.
+ incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
+ round.
+ dom_e: Domain e-value criteria for inclusion in tblout.
+ num_streamed_chunks: Number of database chunks to stream over.
+ streaming_callback: Callback function run after each chunk iteration with
+ the iteration number as argument.
+ """
+ self.binary_path = binary_path
+ self.database_path = database_path
+ self.num_streamed_chunks = num_streamed_chunks
+
+ if (
+ not os.path.exists(self.database_path)
+ and num_streamed_chunks is None
+ ):
+ logging.error("Could not find Jackhmmer database %s", database_path)
+ raise ValueError(
+ f"Could not find Jackhmmer database {database_path}"
+ )
+
+ self.n_cpu = n_cpu
+ self.n_iter = n_iter
+ self.e_value = e_value
+ self.z_value = z_value
+ self.filter_f1 = filter_f1
+ self.filter_f2 = filter_f2
+ self.filter_f3 = filter_f3
+ self.incdom_e = incdom_e
+ self.dom_e = dom_e
+ self.get_tblout = get_tblout
+ self.streaming_callback = streaming_callback
+
+ def _query_chunk(
+ self, input_fasta_path: str, database_path: str
+ ) -> Mapping[str, Any]:
+ """Queries the database chunk using Jackhmmer."""
+ with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
+ sto_path = os.path.join(query_tmp_dir, "output.sto")
+
+ # The F1/F2/F3 are the expected proportion to pass each of the filtering
+ # stages (which get progressively more expensive), reducing these
+ # speeds up the pipeline at the expensive of sensitivity. They are
+ # currently set very low to make querying Mgnify run in a reasonable
+ # amount of time.
+ cmd_flags = [
+ # Don't pollute stdout with Jackhmmer output.
+ "-o",
+ "/dev/null",
+ "-A",
+ sto_path,
+ "--noali",
+ "--F1",
+ str(self.filter_f1),
+ "--F2",
+ str(self.filter_f2),
+ "--F3",
+ str(self.filter_f3),
+ "--incE",
+ str(self.e_value),
+ # Report only sequences with E-values <= x in per-sequence output.
+ "-E",
+ str(self.e_value),
+ "--cpu",
+ str(self.n_cpu),
+ "-N",
+ str(self.n_iter),
+ ]
+ if self.get_tblout:
+ tblout_path = os.path.join(query_tmp_dir, "tblout.txt")
+ cmd_flags.extend(["--tblout", tblout_path])
+
+ if self.z_value:
+ cmd_flags.extend(["-Z", str(self.z_value)])
+
+ if self.dom_e is not None:
+ cmd_flags.extend(["--domE", str(self.dom_e)])
+
+ if self.incdom_e is not None:
+ cmd_flags.extend(["--incdomE", str(self.incdom_e)])
+
+ cmd = (
+ [self.binary_path]
+ + cmd_flags
+ + [input_fasta_path, database_path]
+ )
+
+ logging.info('Launching subprocess "%s"', " ".join(cmd))
+ process = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+ )
+ with utils.timing(
+ f"Jackhmmer ({os.path.basename(database_path)}) query"
+ ):
+ _, stderr = process.communicate()
+ retcode = process.wait()
+
+ if retcode:
+ raise RuntimeError(
+ "Jackhmmer failed\nstderr:\n%s\n" % stderr.decode("utf-8")
+ )
+
+ # Get e-values for each target name
+ tbl = ""
+ if self.get_tblout:
+ with open(tblout_path) as f:
+ tbl = f.read()
+
+ with open(sto_path) as f:
+ sto = f.read()
+
+ raw_output = dict(
+ sto=sto,
+ tbl=tbl,
+ stderr=stderr,
+ n_iter=self.n_iter,
+ e_value=self.e_value,
+ )
+
+ return raw_output
+
+ def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
+ """Queries the database using Jackhmmer."""
+ if self.num_streamed_chunks is None:
+ return [self._query_chunk(input_fasta_path, self.database_path)]
+
+ db_basename = os.path.basename(self.database_path)
+ db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
+ db_local_chunk = lambda db_idx: f"/tmp/ramdisk/{db_basename}.{db_idx}"
+
+ # Remove existing files to prevent OOM
+ for f in glob.glob(db_local_chunk("[0-9]*")):
+ try:
+ os.remove(f)
+ except OSError:
+ print(f"OSError while deleting {f}")
+
+ # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
+ with futures.ThreadPoolExecutor(max_workers=2) as executor:
+ chunked_output = []
+ for i in range(1, self.num_streamed_chunks + 1):
+ # Copy the chunk locally
+ if i == 1:
+ future = executor.submit(
+ request.urlretrieve,
+ db_remote_chunk(i),
+ db_local_chunk(i),
+ )
+ if i < self.num_streamed_chunks:
+ next_future = executor.submit(
+ request.urlretrieve,
+ db_remote_chunk(i + 1),
+ db_local_chunk(i + 1),
+ )
+
+ # Run Jackhmmer with the chunk
+ future.result()
+ chunked_output.append(
+ self._query_chunk(input_fasta_path, db_local_chunk(i))
+ )
+
+ # Remove the local copy of the chunk
+ os.remove(db_local_chunk(i))
+ future = next_future
+ if self.streaming_callback:
+ self.streaming_callback(i)
+ return chunked_output
diff --git a/openfold/data/tools/kalign.py b/openfold/data/tools/kalign.py
new file mode 100644
index 0000000000000000000000000000000000000000..5306f864bc220dbd979fd1ae556df56c3e65ff71
--- /dev/null
+++ b/openfold/data/tools/kalign.py
@@ -0,0 +1,115 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A Python wrapper for Kalign."""
+import os
+import subprocess
+from typing import Sequence
+
+from absl import logging
+
+from openfold.data.tools import utils
+
+
+def _to_a3m(sequences: Sequence[str]) -> str:
+ """Converts sequences to an a3m file."""
+ names = ["sequence %d" % i for i in range(1, len(sequences) + 1)]
+ a3m = []
+ for sequence, name in zip(sequences, names):
+ a3m.append(u">" + name + u"\n")
+ a3m.append(sequence + u"\n")
+ return "".join(a3m)
+
+
+class Kalign:
+ """Python wrapper of the Kalign binary."""
+
+ def __init__(self, *, binary_path: str):
+ """Initializes the Python Kalign wrapper.
+
+ Args:
+ binary_path: The path to the Kalign binary.
+
+ Raises:
+ RuntimeError: If Kalign binary not found within the path.
+ """
+ self.binary_path = binary_path
+
+ def align(self, sequences: Sequence[str]) -> str:
+ """Aligns the sequences and returns the alignment in A3M string.
+
+ Args:
+ sequences: A list of query sequence strings. The sequences have to be at
+ least 6 residues long (Kalign requires this). Note that the order in
+ which you give the sequences might alter the output slightly as
+ different alignment tree might get constructed.
+
+ Returns:
+ A string with the alignment in a3m format.
+
+ Raises:
+ RuntimeError: If Kalign fails.
+ ValueError: If any of the sequences is less than 6 residues long.
+ """
+ logging.info("Aligning %d sequences", len(sequences))
+
+ for s in sequences:
+ if len(s) < 6:
+ raise ValueError(
+ "Kalign requires all sequences to be at least 6 "
+ "residues long. Got %s (%d residues)." % (s, len(s))
+ )
+
+ with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
+ input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
+ output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")
+
+ with open(input_fasta_path, "w") as f:
+ f.write(_to_a3m(sequences))
+
+ cmd = [
+ self.binary_path,
+ "-i",
+ input_fasta_path,
+ "-o",
+ output_a3m_path,
+ "-format",
+ "fasta",
+ ]
+
+ logging.info('Launching subprocess "%s"', " ".join(cmd))
+ process = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+ )
+
+ with utils.timing("Kalign query"):
+ stdout, stderr = process.communicate()
+ retcode = process.wait()
+ logging.info(
+ "Kalign stdout:\n%s\n\nstderr:\n%s\n",
+ stdout.decode("utf-8"),
+ stderr.decode("utf-8"),
+ )
+
+ if retcode:
+ raise RuntimeError(
+ "Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n"
+ % (stdout.decode("utf-8"), stderr.decode("utf-8"))
+ )
+
+ with open(output_a3m_path) as f:
+ a3m = f.read()
+
+ return a3m
diff --git a/openfold/data/tools/utils.py b/openfold/data/tools/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9e8684f1e3e6f20f9098c16009836ac737106c2
--- /dev/null
+++ b/openfold/data/tools/utils.py
@@ -0,0 +1,48 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common utilities for data pipeline tools."""
+import contextlib
+import datetime
+import logging
+import shutil
+import tempfile
+import time
+from typing import Optional
+
+
+@contextlib.contextmanager
+def tmpdir_manager(base_dir: Optional[str] = None):
+ """Context manager that deletes a temporary directory on exit."""
+ tmpdir = tempfile.mkdtemp(dir=base_dir)
+ try:
+ yield tmpdir
+ finally:
+ shutil.rmtree(tmpdir, ignore_errors=True)
+
+
+@contextlib.contextmanager
+def timing(msg: str):
+ logging.info("Started %s", msg)
+ tic = time.perf_counter()
+ yield
+ toc = time.perf_counter()
+ logging.info("Finished %s in %.3f seconds", msg, toc - tic)
+
+
+def to_date(s: str):
+ return datetime.datetime(
+ year=int(s[:4]), month=int(s[5:7]), day=int(s[8:10])
+ )
diff --git a/openfold/model/__init__.py b/openfold/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..25d1a5ba4be7300076d6f44af1541a33ca9e4ab1
--- /dev/null
+++ b/openfold/model/__init__.py
@@ -0,0 +1,16 @@
+import os
+import glob
+import importlib as importlib
+
+_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
+__all__ = [
+ os.path.basename(f)[:-3]
+ for f in _files
+ if os.path.isfile(f) and not f.endswith("__init__.py")
+]
+_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
+for _m in _modules:
+ globals()[_m[0]] = _m[1]
+
+# Avoid needlessly cluttering the global namespace
+del _files, _m, _modules
diff --git a/openfold/model/__pycache__/__init__.cpython-310.pyc b/openfold/model/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7dfd49ba79ae06d6b030200c99200cde4dbb34d5
Binary files /dev/null and b/openfold/model/__pycache__/__init__.cpython-310.pyc differ
diff --git a/openfold/model/__pycache__/dropout.cpython-310.pyc b/openfold/model/__pycache__/dropout.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6397ee233b8e3a9f5cb266c7a728682262b20c4
Binary files /dev/null and b/openfold/model/__pycache__/dropout.cpython-310.pyc differ
diff --git a/openfold/model/__pycache__/embedders.cpython-310.pyc b/openfold/model/__pycache__/embedders.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4b71ba49c974cff62e3984e5929ac3bdcc2b453
Binary files /dev/null and b/openfold/model/__pycache__/embedders.cpython-310.pyc differ
diff --git a/openfold/model/__pycache__/evoformer.cpython-310.pyc b/openfold/model/__pycache__/evoformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..203fe5b4b8db1cd767c96f0eff244f84b90d6623
Binary files /dev/null and b/openfold/model/__pycache__/evoformer.cpython-310.pyc differ
diff --git a/openfold/model/__pycache__/heads.cpython-310.pyc b/openfold/model/__pycache__/heads.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd6b3bd70406b11a71d1bd3dea71682ca43dfffa
Binary files /dev/null and b/openfold/model/__pycache__/heads.cpython-310.pyc differ
diff --git a/openfold/model/__pycache__/model.cpython-310.pyc b/openfold/model/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bfb8bb6e63c21cd5ee96cc84904409239864f4fd
Binary files /dev/null and b/openfold/model/__pycache__/model.cpython-310.pyc differ
diff --git a/openfold/model/__pycache__/msa.cpython-310.pyc b/openfold/model/__pycache__/msa.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7af39fa36da6d3535a25d69877a645baeb91c46
Binary files /dev/null and b/openfold/model/__pycache__/msa.cpython-310.pyc differ
diff --git a/openfold/model/__pycache__/outer_product_mean.cpython-310.pyc b/openfold/model/__pycache__/outer_product_mean.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e6af94ce1b8963ca6664cfda25f3fad3e8349a2d
Binary files /dev/null and b/openfold/model/__pycache__/outer_product_mean.cpython-310.pyc differ
diff --git a/openfold/model/__pycache__/pair_transition.cpython-310.pyc b/openfold/model/__pycache__/pair_transition.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1cf7918be7f8591ae0a45ef31e131cd72cac260e
Binary files /dev/null and b/openfold/model/__pycache__/pair_transition.cpython-310.pyc differ
diff --git a/openfold/model/__pycache__/primitives.cpython-310.pyc b/openfold/model/__pycache__/primitives.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8c357d467780e29c9dc8a85647c72a30337725d
Binary files /dev/null and b/openfold/model/__pycache__/primitives.cpython-310.pyc differ
diff --git a/openfold/model/__pycache__/structure_module.cpython-310.pyc b/openfold/model/__pycache__/structure_module.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e32c64bb867cf233488161e2a7d8a03740698911
Binary files /dev/null and b/openfold/model/__pycache__/structure_module.cpython-310.pyc differ
diff --git a/openfold/model/__pycache__/template.cpython-310.pyc b/openfold/model/__pycache__/template.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e2614459d6d5e5c9ccdb9fd13a1bcc845464d4e
Binary files /dev/null and b/openfold/model/__pycache__/template.cpython-310.pyc differ
diff --git a/openfold/model/__pycache__/torchscript.cpython-310.pyc b/openfold/model/__pycache__/torchscript.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9cbfc71f0e8a3230e5063a74c0fee2b15368c6a9
Binary files /dev/null and b/openfold/model/__pycache__/torchscript.cpython-310.pyc differ
diff --git a/openfold/model/__pycache__/triangular_attention.cpython-310.pyc b/openfold/model/__pycache__/triangular_attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab34565e5dcc515d04a6d1b8f8ba2d683110c49f
Binary files /dev/null and b/openfold/model/__pycache__/triangular_attention.cpython-310.pyc differ
diff --git a/openfold/model/__pycache__/triangular_multiplicative_update.cpython-310.pyc b/openfold/model/__pycache__/triangular_multiplicative_update.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a49270d3fc7837b5ad9561a01a37184d5d13122
Binary files /dev/null and b/openfold/model/__pycache__/triangular_multiplicative_update.cpython-310.pyc differ
diff --git a/openfold/model/dropout.py b/openfold/model/dropout.py
new file mode 100644
index 0000000000000000000000000000000000000000..651b9775ef44fba20dec75c60703f00beac66e0c
--- /dev/null
+++ b/openfold/model/dropout.py
@@ -0,0 +1,78 @@
+# Copyright 2021 AlQuraishi Laboratory
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+from functools import partialmethod
+from typing import Union, List
+
+
+class Dropout(nn.Module):
+ """
+ Implementation of dropout with the ability to share the dropout mask
+ along a particular dimension.
+
+ If not in training mode, this module computes the identity function.
+ """
+
+ def __init__(self, r: float, batch_dim: Union[int, List[int]]):
+ """
+ Args:
+ r:
+ Dropout rate
+ batch_dim:
+ Dimension(s) along which the dropout mask is shared
+ """
+ super(Dropout, self).__init__()
+
+ self.r = r
+ if type(batch_dim) == int:
+ batch_dim = [batch_dim]
+ self.batch_dim = batch_dim
+ self.dropout = nn.Dropout(self.r)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x:
+ Tensor to which dropout is applied. Can have any shape
+ compatible with self.batch_dim
+ """
+ shape = list(x.shape)
+ if self.batch_dim is not None:
+ for bd in self.batch_dim:
+ shape[bd] = 1
+ mask = x.new_ones(shape)
+ mask = self.dropout(mask)
+ x *= mask
+ return x
+
+
+class DropoutRowwise(Dropout):
+ """
+ Convenience class for rowwise dropout as described in subsection
+ 1.11.6.
+ """
+
+ __init__ = partialmethod(Dropout.__init__, batch_dim=-3)
+
+
+class DropoutColumnwise(Dropout):
+ """
+ Convenience class for columnwise dropout as described in subsection
+ 1.11.6.
+ """
+
+ __init__ = partialmethod(Dropout.__init__, batch_dim=-2)
diff --git a/openfold/model/embedders.py b/openfold/model/embedders.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1f1f9981c9965be630fd1ccb98ebdcf0fe42dc5
--- /dev/null
+++ b/openfold/model/embedders.py
@@ -0,0 +1,352 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from typing import Tuple
+
+from openfold.model.primitives import Linear, LayerNorm
+from openfold.utils.tensor_utils import one_hot
+
+
+class InputEmbedder(nn.Module):
+ """
+ Embeds a subset of the input features.
+
+ Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
+ """
+
+ def __init__(
+ self,
+ tf_dim: int,
+ msa_dim: int,
+ c_z: int,
+ c_m: int,
+ relpos_k: int,
+ **kwargs,
+ ):
+ """
+ Args:
+ tf_dim:
+ Final dimension of the target features
+ msa_dim:
+ Final dimension of the MSA features
+ c_z:
+ Pair embedding dimension
+ c_m:
+ MSA embedding dimension
+ relpos_k:
+ Window size used in relative positional encoding
+ """
+ super(InputEmbedder, self).__init__()
+
+ self.tf_dim = tf_dim
+ self.msa_dim = msa_dim
+
+ self.c_z = c_z
+ self.c_m = c_m
+
+ self.linear_tf_z_i = Linear(tf_dim, c_z)
+ self.linear_tf_z_j = Linear(tf_dim, c_z)
+ self.linear_tf_m = Linear(tf_dim, c_m)
+ self.linear_msa_m = Linear(msa_dim, c_m)
+
+ # RPE stuff
+ self.relpos_k = relpos_k
+ self.no_bins = 2 * relpos_k + 1
+ self.linear_relpos = Linear(self.no_bins, c_z)
+
+ def relpos(self, ri: torch.Tensor):
+ """
+ Computes relative positional encodings
+
+ Implements Algorithm 4.
+
+ Args:
+ ri:
+ "residue_index" features of shape [*, N]
+ """
+ d = ri[..., None] - ri[..., None, :]
+ boundaries = torch.arange(
+ start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
+ )
+ oh = one_hot(d, boundaries).type(ri.dtype)
+ return self.linear_relpos(oh)
+
+ def forward(
+ self,
+ tf: torch.Tensor,
+ ri: torch.Tensor,
+ msa: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ tf:
+ "target_feat" features of shape [*, N_res, tf_dim]
+ ri:
+ "residue_index" features of shape [*, N_res]
+ msa:
+ "msa_feat" features of shape [*, N_clust, N_res, msa_dim]
+ Returns:
+ msa_emb:
+ [*, N_clust, N_res, C_m] MSA embedding
+ pair_emb:
+ [*, N_res, N_res, C_z] pair embedding
+
+ """
+ # [*, N_res, c_z]
+ tf_emb_i = self.linear_tf_z_i(tf)
+ tf_emb_j = self.linear_tf_z_j(tf)
+
+ # [*, N_res, N_res, c_z]
+ pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
+ pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
+
+ # [*, N_clust, N_res, c_m]
+ n_clust = msa.shape[-3]
+ tf_m = (
+ self.linear_tf_m(tf)
+ .unsqueeze(-3)
+ .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
+ )
+ msa_emb = self.linear_msa_m(msa) + tf_m
+
+ return msa_emb, pair_emb
+
+
+class RecyclingEmbedder(nn.Module):
+ """
+ Embeds the output of an iteration of the model for recycling.
+
+ Implements Algorithm 32.
+ """
+
+ def __init__(
+ self,
+ c_m: int,
+ c_z: int,
+ min_bin: float,
+ max_bin: float,
+ no_bins: int,
+ inf: float = 1e8,
+ **kwargs,
+ ):
+ """
+ Args:
+ c_m:
+ MSA channel dimension
+ c_z:
+ Pair embedding channel dimension
+ min_bin:
+ Smallest distogram bin (Angstroms)
+ max_bin:
+ Largest distogram bin (Angstroms)
+ no_bins:
+ Number of distogram bins
+ """
+ super(RecyclingEmbedder, self).__init__()
+
+ self.c_m = c_m
+ self.c_z = c_z
+ self.min_bin = min_bin
+ self.max_bin = max_bin
+ self.no_bins = no_bins
+ self.inf = inf
+
+ self.bins = None
+
+ self.linear = Linear(self.no_bins, self.c_z)
+ self.layer_norm_m = LayerNorm(self.c_m)
+ self.layer_norm_z = LayerNorm(self.c_z)
+
+ def forward(
+ self,
+ m: torch.Tensor,
+ z: torch.Tensor,
+ x: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ m:
+ First row of the MSA embedding. [*, N_res, C_m]
+ z:
+ [*, N_res, N_res, C_z] pair embedding
+ x:
+ [*, N_res, 3] predicted C_beta coordinates
+ Returns:
+ m:
+ [*, N_res, C_m] MSA embedding update
+ z:
+ [*, N_res, N_res, C_z] pair embedding update
+ """
+ if self.bins is None:
+ self.bins = torch.linspace(
+ self.min_bin,
+ self.max_bin,
+ self.no_bins,
+ dtype=x.dtype,
+ device=x.device,
+ requires_grad=False,
+ )
+
+ # [*, N, C_m]
+ m_update = self.layer_norm_m(m)
+
+ # This squared method might become problematic in FP16 mode.
+ # I'm using it because my homegrown method had a stubborn discrepancy I
+ # couldn't find in time.
+ squared_bins = self.bins ** 2
+ upper = torch.cat(
+ [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
+ )
+ d = torch.sum(
+ (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
+ )
+
+ # [*, N, N, no_bins]
+ d = ((d > squared_bins) * (d < upper)).type(x.dtype)
+
+ # [*, N, N, C_z]
+ d = self.linear(d)
+ z_update = d + self.layer_norm_z(z)
+
+ return m_update, z_update
+
+
+class TemplateAngleEmbedder(nn.Module):
+ """
+ Embeds the "template_angle_feat" feature.
+
+ Implements Algorithm 2, line 7.
+ """
+
+ def __init__(
+ self,
+ c_in: int,
+ c_out: int,
+ **kwargs,
+ ):
+ """
+ Args:
+ c_in:
+ Final dimension of "template_angle_feat"
+ c_out:
+ Output channel dimension
+ """
+ super(TemplateAngleEmbedder, self).__init__()
+
+ self.c_out = c_out
+ self.c_in = c_in
+
+ self.linear_1 = Linear(self.c_in, self.c_out, init="relu")
+ self.relu = nn.ReLU()
+ self.linear_2 = Linear(self.c_out, self.c_out, init="relu")
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x: [*, N_templ, N_res, c_in] "template_angle_feat" features
+ Returns:
+ x: [*, N_templ, N_res, C_out] embedding
+ """
+ x = self.linear_1(x)
+ x = self.relu(x)
+ x = self.linear_2(x)
+
+ return x
+
+
+class TemplatePairEmbedder(nn.Module):
+ """
+ Embeds "template_pair_feat" features.
+
+ Implements Algorithm 2, line 9.
+ """
+
+ def __init__(
+ self,
+ c_in: int,
+ c_out: int,
+ **kwargs,
+ ):
+ """
+ Args:
+ c_in:
+
+ c_out:
+ Output channel dimension
+ """
+ super(TemplatePairEmbedder, self).__init__()
+
+ self.c_in = c_in
+ self.c_out = c_out
+
+ # Despite there being no relu nearby, the source uses that initializer
+ self.linear = Linear(self.c_in, self.c_out, init="relu")
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x:
+ [*, C_in] input tensor
+ Returns:
+ [*, C_out] output tensor
+ """
+ x = self.linear(x)
+
+ return x
+
+
+class ExtraMSAEmbedder(nn.Module):
+ """
+ Embeds unclustered MSA sequences.
+
+ Implements Algorithm 2, line 15
+ """
+
+ def __init__(
+ self,
+ c_in: int,
+ c_out: int,
+ **kwargs,
+ ):
+ """
+ Args:
+ c_in:
+ Input channel dimension
+ c_out:
+ Output channel dimension
+ """
+ super(ExtraMSAEmbedder, self).__init__()
+
+ self.c_in = c_in
+ self.c_out = c_out
+
+ self.linear = Linear(self.c_in, self.c_out)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x:
+ [*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
+ Returns:
+ [*, N_extra_seq, N_res, C_out] embedding
+ """
+ x = self.linear(x)
+
+ return x
diff --git a/openfold/model/evoformer.py b/openfold/model/evoformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0a862097a40dcb81514e3522494a66ea5078b01
--- /dev/null
+++ b/openfold/model/evoformer.py
@@ -0,0 +1,630 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import torch
+import torch.nn as nn
+from typing import Tuple, Optional
+from functools import partial
+
+from openfold.model.primitives import Linear, LayerNorm
+from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
+from openfold.model.msa import (
+ MSARowAttentionWithPairBias,
+ MSAColumnAttention,
+ MSAColumnGlobalAttention,
+)
+from openfold.model.outer_product_mean import OuterProductMean
+from openfold.model.pair_transition import PairTransition
+from openfold.model.triangular_attention import (
+ TriangleAttentionStartingNode,
+ TriangleAttentionEndingNode,
+)
+from openfold.model.triangular_multiplicative_update import (
+ TriangleMultiplicationOutgoing,
+ TriangleMultiplicationIncoming,
+)
+from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
+from openfold.utils.tensor_utils import chunk_layer
+
+
+class MSATransition(nn.Module):
+ """
+ Feed-forward network applied to MSA activations after attention.
+
+ Implements Algorithm 9
+ """
+ def __init__(self, c_m, n):
+ """
+ Args:
+ c_m:
+ MSA channel dimension
+ n:
+ Factor multiplied to c_m to obtain the hidden channel
+ dimension
+ """
+ super(MSATransition, self).__init__()
+
+ self.c_m = c_m
+ self.n = n
+
+ self.layer_norm = LayerNorm(self.c_m)
+ self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
+ self.relu = nn.ReLU()
+ self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
+
+ def _transition(self, m, mask):
+ m = self.linear_1(m)
+ m = self.relu(m)
+ m = self.linear_2(m) * mask
+ return m
+
+ @torch.jit.ignore
+ def _chunk(self,
+ m: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_size: int,
+ ) -> torch.Tensor:
+ return chunk_layer(
+ self._transition,
+ {"m": m, "mask": mask},
+ chunk_size=chunk_size,
+ no_batch_dims=len(m.shape[:-2]),
+ )
+
+
+ def forward(
+ self,
+ m: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ chunk_size: Optional[int] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ m:
+ [*, N_seq, N_res, C_m] MSA activation
+ mask:
+ [*, N_seq, N_res, C_m] MSA mask
+ Returns:
+ m:
+ [*, N_seq, N_res, C_m] MSA activation update
+ """
+ # DISCREPANCY: DeepMind forgets to apply the MSA mask here.
+ if mask is None:
+ mask = m.new_ones(m.shape[:-1])
+
+ mask = mask.unsqueeze(-1)
+
+ m = self.layer_norm(m)
+
+ if chunk_size is not None:
+ m = self._chunk(m, mask, chunk_size)
+ else:
+ m = self._transition(m, mask)
+
+ return m
+
+
+class EvoformerBlockCore(nn.Module):
+ def __init__(
+ self,
+ c_m: int,
+ c_z: int,
+ c_hidden_opm: int,
+ c_hidden_mul: int,
+ c_hidden_pair_att: int,
+ no_heads_msa: int,
+ no_heads_pair: int,
+ transition_n: int,
+ pair_dropout: float,
+ inf: float,
+ eps: float,
+ _is_extra_msa_stack: bool = False,
+ ):
+ super(EvoformerBlockCore, self).__init__()
+
+ self.msa_transition = MSATransition(
+ c_m=c_m,
+ n=transition_n,
+ )
+
+ self.outer_product_mean = OuterProductMean(
+ c_m,
+ c_z,
+ c_hidden_opm,
+ )
+
+ self.tri_mul_out = TriangleMultiplicationOutgoing(
+ c_z,
+ c_hidden_mul,
+ )
+ self.tri_mul_in = TriangleMultiplicationIncoming(
+ c_z,
+ c_hidden_mul,
+ )
+
+ self.tri_att_start = TriangleAttentionStartingNode(
+ c_z,
+ c_hidden_pair_att,
+ no_heads_pair,
+ inf=inf,
+ )
+ self.tri_att_end = TriangleAttentionEndingNode(
+ c_z,
+ c_hidden_pair_att,
+ no_heads_pair,
+ inf=inf,
+ )
+
+ self.pair_transition = PairTransition(
+ c_z,
+ transition_n,
+ )
+
+ self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
+ self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
+
+ def forward(
+ self,
+ m: torch.Tensor,
+ z: torch.Tensor,
+ msa_mask: torch.Tensor,
+ pair_mask: torch.Tensor,
+ chunk_size: Optional[int] = None,
+ _mask_trans: bool = True,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # DeepMind doesn't mask these transitions in the source, so _mask_trans
+ # should be disabled to better approximate the exact activations of
+ # the original.
+ msa_trans_mask = msa_mask if _mask_trans else None
+ pair_trans_mask = pair_mask if _mask_trans else None
+
+ m = m + self.msa_transition(
+ m, mask=msa_trans_mask, chunk_size=chunk_size
+ )
+ z = z + self.outer_product_mean(
+ m, mask=msa_mask, chunk_size=chunk_size
+ )
+ z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
+ z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
+ z = z + self.ps_dropout_row_layer(
+ self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
+ )
+ z = z + self.ps_dropout_col_layer(
+ self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
+ )
+ z = z + self.pair_transition(
+ z, mask=pair_trans_mask, chunk_size=chunk_size
+ )
+
+ return m, z
+
+
+class EvoformerBlock(nn.Module):
+ def __init__(self,
+ c_m: int,
+ c_z: int,
+ c_hidden_msa_att: int,
+ c_hidden_opm: int,
+ c_hidden_mul: int,
+ c_hidden_pair_att: int,
+ no_heads_msa: int,
+ no_heads_pair: int,
+ transition_n: int,
+ msa_dropout: float,
+ pair_dropout: float,
+ inf: float,
+ eps: float,
+ ):
+ super(EvoformerBlock, self).__init__()
+
+ self.msa_att_row = MSARowAttentionWithPairBias(
+ c_m=c_m,
+ c_z=c_z,
+ c_hidden=c_hidden_msa_att,
+ no_heads=no_heads_msa,
+ inf=inf,
+ )
+
+ self.msa_att_col = MSAColumnAttention(
+ c_m,
+ c_hidden_msa_att,
+ no_heads_msa,
+ inf=inf,
+ )
+
+ self.msa_dropout_layer = DropoutRowwise(msa_dropout)
+
+ self.core = EvoformerBlockCore(
+ c_m=c_m,
+ c_z=c_z,
+ c_hidden_opm=c_hidden_opm,
+ c_hidden_mul=c_hidden_mul,
+ c_hidden_pair_att=c_hidden_pair_att,
+ no_heads_msa=no_heads_msa,
+ no_heads_pair=no_heads_pair,
+ transition_n=transition_n,
+ pair_dropout=pair_dropout,
+ inf=inf,
+ eps=eps,
+ )
+
+ def forward(self,
+ m: torch.Tensor,
+ z: torch.Tensor,
+ msa_mask: torch.Tensor,
+ pair_mask: torch.Tensor,
+ chunk_size: Optional[int] = None,
+ _mask_trans: bool = True,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ m = m + self.msa_dropout_layer(
+ self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+ )
+ m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+ m, z = self.core(
+ m,
+ z,
+ msa_mask=msa_mask,
+ pair_mask=pair_mask,
+ chunk_size=chunk_size,
+ _mask_trans=_mask_trans,
+ )
+
+ return m, z
+
+
+class ExtraMSABlock(nn.Module):
+ """
+ Almost identical to the standard EvoformerBlock, except in that the
+ ExtraMSABlock uses GlobalAttention for MSA column attention and
+ requires more fine-grained control over checkpointing. Separated from
+ its twin to preserve the TorchScript-ability of the latter.
+ """
+ def __init__(self,
+ c_m: int,
+ c_z: int,
+ c_hidden_msa_att: int,
+ c_hidden_opm: int,
+ c_hidden_mul: int,
+ c_hidden_pair_att: int,
+ no_heads_msa: int,
+ no_heads_pair: int,
+ transition_n: int,
+ msa_dropout: float,
+ pair_dropout: float,
+ inf: float,
+ eps: float,
+ ckpt: bool,
+ ):
+ super(ExtraMSABlock, self).__init__()
+
+ self.ckpt = ckpt
+
+ self.msa_att_row = MSARowAttentionWithPairBias(
+ c_m=c_m,
+ c_z=c_z,
+ c_hidden=c_hidden_msa_att,
+ no_heads=no_heads_msa,
+ inf=inf,
+ )
+
+ self.msa_att_col = MSAColumnGlobalAttention(
+ c_in=c_m,
+ c_hidden=c_hidden_msa_att,
+ no_heads=no_heads_msa,
+ inf=inf,
+ eps=eps,
+ )
+
+ self.msa_dropout_layer = DropoutRowwise(msa_dropout)
+
+ self.core = EvoformerBlockCore(
+ c_m=c_m,
+ c_z=c_z,
+ c_hidden_opm=c_hidden_opm,
+ c_hidden_mul=c_hidden_mul,
+ c_hidden_pair_att=c_hidden_pair_att,
+ no_heads_msa=no_heads_msa,
+ no_heads_pair=no_heads_pair,
+ transition_n=transition_n,
+ pair_dropout=pair_dropout,
+ inf=inf,
+ eps=eps,
+ )
+
+ def forward(self,
+ m: torch.Tensor,
+ z: torch.Tensor,
+ msa_mask: torch.Tensor,
+ pair_mask: torch.Tensor,
+ chunk_size: Optional[int] = None,
+ _chunk_logits: Optional[int] = 1024,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ def add(m1, m2):
+ # The first operation in a checkpoint can't be in-place, but it's
+ # nice to have in-place addition during inference. Thus...
+ if(torch.is_grad_enabled()):
+ m1 = m1 + m2
+ else:
+ m1 += m2
+
+ return m1
+
+ m = add(m, self.msa_dropout_layer(
+ self.msa_att_row(
+ m.clone() if torch.is_grad_enabled() else m,
+ z=z.clone() if torch.is_grad_enabled() else z,
+ mask=msa_mask,
+ chunk_size=chunk_size,
+ _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
+ _checkpoint_chunks=
+ self.ckpt if torch.is_grad_enabled() else False,
+ )
+ ))
+
+ def fn(m, z):
+ m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
+ m, z = self.core(
+ m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
+ )
+
+ return m, z
+
+ if(torch.is_grad_enabled() and self.ckpt):
+ checkpoint_fn = get_checkpoint_fn()
+ m, z = checkpoint_fn(fn, m, z)
+ else:
+ m, z = fn(m, z)
+
+ return m, z
+
+
+class EvoformerStack(nn.Module):
+ """
+ Main Evoformer trunk.
+
+ Implements Algorithm 6.
+ """
+
+ def __init__(
+ self,
+ c_m: int,
+ c_z: int,
+ c_hidden_msa_att: int,
+ c_hidden_opm: int,
+ c_hidden_mul: int,
+ c_hidden_pair_att: int,
+ c_s: int,
+ no_heads_msa: int,
+ no_heads_pair: int,
+ no_blocks: int,
+ transition_n: int,
+ msa_dropout: float,
+ pair_dropout: float,
+ blocks_per_ckpt: int,
+ inf: float,
+ eps: float,
+ clear_cache_between_blocks: bool = False,
+ **kwargs,
+ ):
+ """
+ Args:
+ c_m:
+ MSA channel dimension
+ c_z:
+ Pair channel dimension
+ c_hidden_msa_att:
+ Hidden dimension in MSA attention
+ c_hidden_opm:
+ Hidden dimension in outer product mean module
+ c_hidden_mul:
+ Hidden dimension in multiplicative updates
+ c_hidden_pair_att:
+ Hidden dimension in triangular attention
+ c_s:
+ Channel dimension of the output "single" embedding
+ no_heads_msa:
+ Number of heads used for MSA attention
+ no_heads_pair:
+ Number of heads used for pair attention
+ no_blocks:
+ Number of Evoformer blocks in the stack
+ transition_n:
+ Factor by which to multiply c_m to obtain the MSATransition
+ hidden dimension
+ msa_dropout:
+ Dropout rate for MSA activations
+ pair_dropout:
+ Dropout used for pair activations
+ blocks_per_ckpt:
+ Number of Evoformer blocks in each activation checkpoint
+ clear_cache_between_blocks:
+ Whether to clear CUDA's GPU memory cache between blocks of the
+ stack. Slows down each block but can reduce fragmentation
+ """
+ super(EvoformerStack, self).__init__()
+
+ self.blocks_per_ckpt = blocks_per_ckpt
+ self.clear_cache_between_blocks = clear_cache_between_blocks
+
+ self.blocks = nn.ModuleList()
+
+ for _ in range(no_blocks):
+ block = EvoformerBlock(
+ c_m=c_m,
+ c_z=c_z,
+ c_hidden_msa_att=c_hidden_msa_att,
+ c_hidden_opm=c_hidden_opm,
+ c_hidden_mul=c_hidden_mul,
+ c_hidden_pair_att=c_hidden_pair_att,
+ no_heads_msa=no_heads_msa,
+ no_heads_pair=no_heads_pair,
+ transition_n=transition_n,
+ msa_dropout=msa_dropout,
+ pair_dropout=pair_dropout,
+ inf=inf,
+ eps=eps,
+ )
+ self.blocks.append(block)
+
+ self.linear = Linear(c_m, c_s)
+
+ def forward(self,
+ m: torch.Tensor,
+ z: torch.Tensor,
+ msa_mask: torch.Tensor,
+ pair_mask: torch.Tensor,
+ chunk_size: int,
+ _mask_trans: bool = True,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """
+ Args:
+ m:
+ [*, N_seq, N_res, C_m] MSA embedding
+ z:
+ [*, N_res, N_res, C_z] pair embedding
+ msa_mask:
+ [*, N_seq, N_res] MSA mask
+ pair_mask:
+ [*, N_res, N_res] pair mask
+ Returns:
+ m:
+ [*, N_seq, N_res, C_m] MSA embedding
+ z:
+ [*, N_res, N_res, C_z] pair embedding
+ s:
+ [*, N_res, C_s] single embedding (or None if extra MSA stack)
+ """
+ blocks = [
+ partial(
+ b,
+ msa_mask=msa_mask,
+ pair_mask=pair_mask,
+ chunk_size=chunk_size,
+ _mask_trans=_mask_trans,
+ )
+ for b in self.blocks
+ ]
+
+ if(self.clear_cache_between_blocks):
+ def block_with_cache_clear(block, *args):
+ torch.cuda.empty_cache()
+ return block(*args)
+
+ blocks = [partial(block_with_cache_clear, b) for b in blocks]
+
+ m, z = checkpoint_blocks(
+ blocks,
+ args=(m, z),
+ blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
+ )
+
+ s = self.linear(m[..., 0, :, :])
+
+ return m, z, s
+
+
+class ExtraMSAStack(nn.Module):
+ """
+ Implements Algorithm 18.
+ """
+
+ def __init__(self,
+ c_m: int,
+ c_z: int,
+ c_hidden_msa_att: int,
+ c_hidden_opm: int,
+ c_hidden_mul: int,
+ c_hidden_pair_att: int,
+ no_heads_msa: int,
+ no_heads_pair: int,
+ no_blocks: int,
+ transition_n: int,
+ msa_dropout: float,
+ pair_dropout: float,
+ inf: float,
+ eps: float,
+ ckpt: bool,
+ clear_cache_between_blocks: bool = False,
+ **kwargs,
+ ):
+ super(ExtraMSAStack, self).__init__()
+
+ self.clear_cache_between_blocks = clear_cache_between_blocks
+ self.blocks = nn.ModuleList()
+ for _ in range(no_blocks):
+ block = ExtraMSABlock(
+ c_m=c_m,
+ c_z=c_z,
+ c_hidden_msa_att=c_hidden_msa_att,
+ c_hidden_opm=c_hidden_opm,
+ c_hidden_mul=c_hidden_mul,
+ c_hidden_pair_att=c_hidden_pair_att,
+ no_heads_msa=no_heads_msa,
+ no_heads_pair=no_heads_pair,
+ transition_n=transition_n,
+ msa_dropout=msa_dropout,
+ pair_dropout=pair_dropout,
+ inf=inf,
+ eps=eps,
+ ckpt=ckpt,
+ )
+ self.blocks.append(block)
+
+ def forward(self,
+ m: torch.Tensor,
+ z: torch.Tensor,
+ chunk_size: int,
+ msa_mask: Optional[torch.Tensor] = None,
+ pair_mask: Optional[torch.Tensor] = None,
+ _mask_trans: bool = True,
+ ) -> torch.Tensor:
+ """
+ Args:
+ m:
+ [*, N_extra, N_res, C_m] extra MSA embedding
+ z:
+ [*, N_res, N_res, C_z] pair embedding
+ msa_mask:
+ Optional [*, N_extra, N_res] MSA mask
+ pair_mask:
+ Optional [*, N_res, N_res] pair mask
+ Returns:
+ [*, N_res, N_res, C_z] pair update
+ """
+ #checkpoint_fn = get_checkpoint_fn()
+ #blocks = [
+ # partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
+ #]
+
+ #def dodo(b, *args):
+ # torch.cuda.empty_cache()
+ # return b(*args)
+
+ #blocks = [partial(dodo, b) for b in blocks]
+
+ #for b in blocks:
+ # if(torch.is_grad_enabled()):
+ # m, z = checkpoint_fn(b, *(m, z))
+ # else:
+ # m, z = b(m, z)
+
+ for b in self.blocks:
+ m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+
+ if(self.clear_cache_between_blocks):
+ torch.cuda.empty_cache()
+
+ return z
diff --git a/openfold/model/heads.py b/openfold/model/heads.py
new file mode 100644
index 0000000000000000000000000000000000000000..f46126d741e79d177a80220f89e1523a408e569b
--- /dev/null
+++ b/openfold/model/heads.py
@@ -0,0 +1,251 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+
+from openfold.model.primitives import Linear, LayerNorm
+from openfold.utils.loss import (
+ compute_plddt,
+ compute_tm,
+ compute_predicted_aligned_error,
+)
+
+
+class AuxiliaryHeads(nn.Module):
+ def __init__(self, config):
+ super(AuxiliaryHeads, self).__init__()
+
+ self.plddt = PerResidueLDDTCaPredictor(
+ **config["lddt"],
+ )
+
+ self.distogram = DistogramHead(
+ **config["distogram"],
+ )
+
+ self.masked_msa = MaskedMSAHead(
+ **config["masked_msa"],
+ )
+
+ self.experimentally_resolved = ExperimentallyResolvedHead(
+ **config["experimentally_resolved"],
+ )
+
+ if config.tm.enabled:
+ self.tm = TMScoreHead(
+ **config.tm,
+ )
+
+ self.config = config
+
+ def forward(self, outputs):
+ aux_out = {}
+ lddt_logits = self.plddt(outputs["sm"]["single"])
+ aux_out["lddt_logits"] = lddt_logits
+
+ # Required for relaxation later on
+ aux_out["plddt"] = compute_plddt(lddt_logits)
+
+ distogram_logits = self.distogram(outputs["pair"])
+ aux_out["distogram_logits"] = distogram_logits
+
+ masked_msa_logits = self.masked_msa(outputs["msa"])
+ aux_out["masked_msa_logits"] = masked_msa_logits
+
+ experimentally_resolved_logits = self.experimentally_resolved(
+ outputs["single"]
+ )
+ aux_out[
+ "experimentally_resolved_logits"
+ ] = experimentally_resolved_logits
+
+ if self.config.tm.enabled:
+ tm_logits = self.tm(outputs["pair"])
+ aux_out["tm_logits"] = tm_logits
+ aux_out["predicted_tm_score"] = compute_tm(
+ tm_logits, **self.config.tm
+ )
+ aux_out.update(
+ compute_predicted_aligned_error(
+ tm_logits,
+ **self.config.tm,
+ )
+ )
+
+ return aux_out
+
+
+class PerResidueLDDTCaPredictor(nn.Module):
+ def __init__(self, no_bins, c_in, c_hidden):
+ super(PerResidueLDDTCaPredictor, self).__init__()
+
+ self.no_bins = no_bins
+ self.c_in = c_in
+ self.c_hidden = c_hidden
+
+ self.layer_norm = LayerNorm(self.c_in)
+
+ self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
+ self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
+ self.linear_3 = Linear(self.c_hidden, self.no_bins, init="final")
+
+ self.relu = nn.ReLU()
+
+ def forward(self, s):
+ s = self.layer_norm(s)
+ s = self.linear_1(s)
+ s = self.relu(s)
+ s = self.linear_2(s)
+ s = self.relu(s)
+ s = self.linear_3(s)
+
+ return s
+
+
+class DistogramHead(nn.Module):
+ """
+ Computes a distogram probability distribution.
+
+ For use in computation of distogram loss, subsection 1.9.8
+ """
+
+ def __init__(self, c_z, no_bins, **kwargs):
+ """
+ Args:
+ c_z:
+ Input channel dimension
+ no_bins:
+ Number of distogram bins
+ """
+ super(DistogramHead, self).__init__()
+
+ self.c_z = c_z
+ self.no_bins = no_bins
+
+ self.linear = Linear(self.c_z, self.no_bins, init="final")
+
+ def forward(self, z): # [*, N, N, C_z]
+ """
+ Args:
+ z:
+ [*, N_res, N_res, C_z] pair embedding
+ Returns:
+ [*, N, N, no_bins] distogram probability distribution
+ """
+ # [*, N, N, no_bins]
+ logits = self.linear(z)
+ logits = logits + logits.transpose(-2, -3)
+ return logits
+
+
+class TMScoreHead(nn.Module):
+ """
+ For use in computation of TM-score, subsection 1.9.7
+ """
+
+ def __init__(self, c_z, no_bins, **kwargs):
+ """
+ Args:
+ c_z:
+ Input channel dimension
+ no_bins:
+ Number of bins
+ """
+ super(TMScoreHead, self).__init__()
+
+ self.c_z = c_z
+ self.no_bins = no_bins
+
+ self.linear = Linear(self.c_z, self.no_bins, init="final")
+
+ def forward(self, z):
+ """
+ Args:
+ z:
+ [*, N_res, N_res, C_z] pairwise embedding
+ Returns:
+ [*, N_res, N_res, no_bins] prediction
+ """
+ # [*, N, N, no_bins]
+ logits = self.linear(z)
+ return logits
+
+
+class MaskedMSAHead(nn.Module):
+ """
+ For use in computation of masked MSA loss, subsection 1.9.9
+ """
+
+ def __init__(self, c_m, c_out, **kwargs):
+ """
+ Args:
+ c_m:
+ MSA channel dimension
+ c_out:
+ Output channel dimension
+ """
+ super(MaskedMSAHead, self).__init__()
+
+ self.c_m = c_m
+ self.c_out = c_out
+
+ self.linear = Linear(self.c_m, self.c_out, init="final")
+
+ def forward(self, m):
+ """
+ Args:
+ m:
+ [*, N_seq, N_res, C_m] MSA embedding
+ Returns:
+ [*, N_seq, N_res, C_out] reconstruction
+ """
+ # [*, N_seq, N_res, C_out]
+ logits = self.linear(m)
+ return logits
+
+
+class ExperimentallyResolvedHead(nn.Module):
+ """
+ For use in computation of "experimentally resolved" loss, subsection
+ 1.9.10
+ """
+
+ def __init__(self, c_s, c_out, **kwargs):
+ """
+ Args:
+ c_s:
+ Input channel dimension
+ c_out:
+ Number of distogram bins
+ """
+ super(ExperimentallyResolvedHead, self).__init__()
+
+ self.c_s = c_s
+ self.c_out = c_out
+
+ self.linear = Linear(self.c_s, self.c_out, init="final")
+
+ def forward(self, s):
+ """
+ Args:
+ s:
+ [*, N_res, C_s] single embedding
+ Returns:
+ [*, N, C_out] logits
+ """
+ # [*, N, C_out]
+ logits = self.linear(s)
+ return logits
diff --git a/openfold/model/model.py b/openfold/model/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..007b90f7af7b7f5b379f5c18d89f0f69f398c494
--- /dev/null
+++ b/openfold/model/model.py
@@ -0,0 +1,446 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+import torch
+import torch.nn as nn
+
+from openfold.utils.feats import (
+ pseudo_beta_fn,
+ build_extra_msa_feat,
+ build_template_angle_feat,
+ build_template_pair_feat,
+ atom14_to_atom37,
+)
+from openfold.model.embedders import (
+ InputEmbedder,
+ RecyclingEmbedder,
+ TemplateAngleEmbedder,
+ TemplatePairEmbedder,
+ ExtraMSAEmbedder,
+)
+from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
+from openfold.model.heads import AuxiliaryHeads
+import openfold.np.residue_constants as residue_constants
+from openfold.model.structure_module import StructureModule
+from openfold.model.template import (
+ TemplatePairStack,
+ TemplatePointwiseAttention,
+)
+from openfold.utils.loss import (
+ compute_plddt,
+)
+from openfold.utils.tensor_utils import (
+ dict_multimap,
+ tensor_tree_map,
+)
+
+
+class AlphaFold(nn.Module):
+ """
+ Alphafold 2.
+
+ Implements Algorithm 2 (but with training).
+ """
+
+ def __init__(self, config):
+ """
+ Args:
+ config:
+ A dict-like config object (like the one in config.py)
+ """
+ super(AlphaFold, self).__init__()
+
+ self.globals = config.globals
+ config = config.model
+ template_config = config.template
+ extra_msa_config = config.extra_msa
+
+ # Main trunk + structure module
+ self.input_embedder = InputEmbedder(
+ **config["input_embedder"],
+ )
+ self.recycling_embedder = RecyclingEmbedder(
+ **config["recycling_embedder"],
+ )
+ self.template_angle_embedder = TemplateAngleEmbedder(
+ **template_config["template_angle_embedder"],
+ )
+ self.template_pair_embedder = TemplatePairEmbedder(
+ **template_config["template_pair_embedder"],
+ )
+ self.template_pair_stack = TemplatePairStack(
+ **template_config["template_pair_stack"],
+ )
+ self.template_pointwise_att = TemplatePointwiseAttention(
+ **template_config["template_pointwise_attention"],
+ )
+ self.extra_msa_embedder = ExtraMSAEmbedder(
+ **extra_msa_config["extra_msa_embedder"],
+ )
+ self.extra_msa_stack = ExtraMSAStack(
+ **extra_msa_config["extra_msa_stack"],
+ )
+ self.evoformer = EvoformerStack(
+ **config["evoformer_stack"],
+ )
+ self.structure_module = StructureModule(
+ **config["structure_module"],
+ )
+
+ self.aux_heads = AuxiliaryHeads(
+ config["heads"],
+ )
+
+ self.config = config
+
+ def embed_templates(self, batch, z, pair_mask, templ_dim):
+ # Embed the templates one at a time (with a poor man's vmap)
+ template_embeds = []
+ n_templ = batch["template_aatype"].shape[templ_dim]
+ for i in range(n_templ):
+ idx = batch["template_aatype"].new_tensor(i)
+ single_template_feats = tensor_tree_map(
+ lambda t: torch.index_select(t, templ_dim, idx),
+ batch,
+ )
+
+ single_template_embeds = {}
+ if self.config.template.embed_angles:
+ template_angle_feat = build_template_angle_feat(
+ single_template_feats,
+ )
+
+ # [*, S_t, N, C_m]
+ a = self.template_angle_embedder(template_angle_feat)
+
+ single_template_embeds["angle"] = a
+
+ # [*, S_t, N, N, C_t]
+ t = build_template_pair_feat(
+ single_template_feats,
+ inf=self.config.template.inf,
+ eps=self.config.template.eps,
+ **self.config.template.distogram,
+ ).to(z.dtype)
+ t = self.template_pair_embedder(t)
+
+ single_template_embeds.update({"pair": t})
+
+ template_embeds.append(single_template_embeds)
+
+ template_embeds = dict_multimap(
+ partial(torch.cat, dim=templ_dim),
+ template_embeds,
+ )
+
+ # [*, S_t, N, N, C_z]
+ t = self.template_pair_stack(
+ template_embeds["pair"],
+ pair_mask.unsqueeze(-3).to(dtype=z.dtype),
+ chunk_size=self.globals.chunk_size,
+ _mask_trans=self.config._mask_trans,
+ )
+
+ # [*, N, N, C_z]
+ t = self.template_pointwise_att(
+ t,
+ z,
+ template_mask=batch["template_mask"].to(dtype=z.dtype),
+ chunk_size=self.globals.chunk_size,
+ )
+ t = t * (torch.sum(batch["template_mask"]) > 0)
+
+ ret = {}
+ if self.config.template.embed_angles:
+ ret["template_angle_embedding"] = template_embeds["angle"]
+
+ ret.update({"template_pair_embedding": t})
+
+ return ret
+
+ def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
+ # Primary output dictionary
+ outputs = {}
+
+ # This needs to be done manually for DeepSpeed's sake
+ dtype = next(self.parameters()).dtype
+ for k in feats:
+ if(feats[k].dtype == torch.float32):
+ feats[k] = feats[k].to(dtype=dtype)
+
+ # Grab some data about the input
+ batch_dims = feats["target_feat"].shape[:-2]
+ no_batch_dims = len(batch_dims)
+ n = feats["target_feat"].shape[-2]
+ n_seq = feats["msa_feat"].shape[-3]
+ device = feats["target_feat"].device
+
+ # Prep some features
+ seq_mask = feats["seq_mask"]
+ pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
+ msa_mask = feats["msa_mask"]
+
+ # Initialize the MSA and pair representations
+
+ # m: [*, S_c, N, C_m]
+ # z: [*, N, N, C_z]
+ m, z = self.input_embedder(
+ feats["target_feat"],
+ feats["residue_index"],
+ feats["msa_feat"],
+ )
+
+ # Initialize the recycling embeddings, if needs be
+ if None in [m_1_prev, z_prev, x_prev]:
+ # [*, N, C_m]
+ m_1_prev = m.new_zeros(
+ (*batch_dims, n, self.config.input_embedder.c_m),
+ requires_grad=False,
+ )
+
+ # [*, N, N, C_z]
+ z_prev = z.new_zeros(
+ (*batch_dims, n, n, self.config.input_embedder.c_z),
+ requires_grad=False,
+ )
+
+ # [*, N, 3]
+ x_prev = z.new_zeros(
+ (*batch_dims, n, residue_constants.atom_type_num, 3),
+ requires_grad=False,
+ )
+
+ x_prev = pseudo_beta_fn(
+ feats["aatype"], x_prev, None
+ ).to(dtype=z.dtype)
+
+ # m_1_prev_emb: [*, N, C_m]
+ # z_prev_emb: [*, N, N, C_z]
+ m_1_prev_emb, z_prev_emb = self.recycling_embedder(
+ m_1_prev,
+ z_prev,
+ x_prev,
+ )
+
+ # If the number of recycling iterations is 0, skip recycling
+ # altogether. We zero them this way instead of computing them
+ # conditionally to avoid leaving parameters unused, which has annoying
+ # implications for DDP training.
+ if(not _recycle):
+ m_1_prev_emb *= 0
+ z_prev_emb *= 0
+
+ # [*, S_c, N, C_m]
+ m[..., 0, :, :] += m_1_prev_emb
+
+ # [*, N, N, C_z]
+ z += z_prev_emb
+
+ # Possibly prevents memory fragmentation
+ del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
+
+ # Embed the templates + merge with MSA/pair embeddings
+ if self.config.template.enabled:
+ template_feats = {
+ k: v for k, v in feats.items() if k.startswith("template_")
+ }
+ template_embeds = self.embed_templates(
+ template_feats,
+ z,
+ pair_mask.to(dtype=z.dtype),
+ no_batch_dims,
+ )
+
+ # [*, N, N, C_z]
+ z = z + template_embeds["template_pair_embedding"]
+
+ if self.config.template.embed_angles:
+ # [*, S = S_c + S_t, N, C_m]
+ m = torch.cat(
+ [m, template_embeds["template_angle_embedding"]],
+ dim=-3
+ )
+
+ # [*, S, N]
+ torsion_angles_mask = feats["template_torsion_angles_mask"]
+ msa_mask = torch.cat(
+ [feats["msa_mask"], torsion_angles_mask[..., 2]],
+ dim=-2
+ )
+
+ # Embed extra MSA features + merge with pairwise embeddings
+ if self.config.extra_msa.enabled:
+ # [*, S_e, N, C_e]
+ a = self.extra_msa_embedder(build_extra_msa_feat(feats))
+
+ # [*, N, N, C_z]
+ z = self.extra_msa_stack(
+ a,
+ z,
+ msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
+ chunk_size=self.globals.chunk_size,
+ pair_mask=pair_mask.to(dtype=z.dtype),
+ _mask_trans=self.config._mask_trans,
+ )
+
+ # Run MSA + pair embeddings through the trunk of the network
+ # m: [*, S, N, C_m]
+ # z: [*, N, N, C_z]
+ # s: [*, N, C_s]
+ m, z, s = self.evoformer(
+ m,
+ z,
+ msa_mask=msa_mask.to(dtype=m.dtype),
+ pair_mask=pair_mask.to(dtype=z.dtype),
+ chunk_size=self.globals.chunk_size,
+ _mask_trans=self.config._mask_trans,
+ )
+
+ outputs["msa"] = m[..., :n_seq, :, :]
+ outputs["pair"] = z
+ outputs["single"] = s
+
+ # Predict 3D structure
+ outputs["sm"] = self.structure_module(
+ s,
+ z,
+ feats["aatype"],
+ mask=feats["seq_mask"].to(dtype=s.dtype),
+ )
+ outputs["final_atom_positions"] = atom14_to_atom37(
+ outputs["sm"]["positions"][-1], feats
+ )
+ outputs["final_atom_mask"] = feats["atom37_atom_exists"]
+ outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
+
+ # Save embeddings for use during the next recycling iteration
+
+ # [*, N, C_m]
+ m_1_prev = m[..., 0, :, :]
+
+ # [*, N, N, C_z]
+ z_prev = z
+
+ # [*, N, 3]
+ x_prev = outputs["final_atom_positions"]
+
+ return outputs, m_1_prev, z_prev, x_prev
+
+ def _disable_activation_checkpointing(self):
+ self.template_pair_stack.blocks_per_ckpt = None
+ self.evoformer.blocks_per_ckpt = None
+
+ for b in self.extra_msa_stack.blocks:
+ b.ckpt = False
+
+ def _enable_activation_checkpointing(self):
+ self.template_pair_stack.blocks_per_ckpt = (
+ self.config.template.template_pair_stack.blocks_per_ckpt
+ )
+ self.evoformer.blocks_per_ckpt = (
+ self.config.evoformer_stack.blocks_per_ckpt
+ )
+
+ for b in self.extra_msa_stack.blocks:
+ b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt
+
+ def forward(self, batch):
+ """
+ Args:
+ batch:
+ Dictionary of arguments outlined in Algorithm 2. Keys must
+ include the official names of the features in the
+ supplement subsection 1.2.9.
+
+ The final dimension of each input must have length equal to
+ the number of recycling iterations.
+
+ Features (without the recycling dimension):
+
+ "aatype" ([*, N_res]):
+ Contrary to the supplement, this tensor of residue
+ indices is not one-hot.
+ "target_feat" ([*, N_res, C_tf])
+ One-hot encoding of the target sequence. C_tf is
+ config.model.input_embedder.tf_dim.
+ "residue_index" ([*, N_res])
+ Tensor whose final dimension consists of
+ consecutive indices from 0 to N_res.
+ "msa_feat" ([*, N_seq, N_res, C_msa])
+ MSA features, constructed as in the supplement.
+ C_msa is config.model.input_embedder.msa_dim.
+ "seq_mask" ([*, N_res])
+ 1-D sequence mask
+ "msa_mask" ([*, N_seq, N_res])
+ MSA mask
+ "pair_mask" ([*, N_res, N_res])
+ 2-D pair mask
+ "extra_msa_mask" ([*, N_extra, N_res])
+ Extra MSA mask
+ "template_mask" ([*, N_templ])
+ Template mask (on the level of templates, not
+ residues)
+ "template_aatype" ([*, N_templ, N_res])
+ Tensor of template residue indices (indices greater
+ than 19 are clamped to 20 (Unknown))
+ "template_all_atom_positions"
+ ([*, N_templ, N_res, 37, 3])
+ Template atom coordinates in atom37 format
+ "template_all_atom_mask" ([*, N_templ, N_res, 37])
+ Template atom coordinate mask
+ "template_pseudo_beta" ([*, N_templ, N_res, 3])
+ Positions of template carbon "pseudo-beta" atoms
+ (i.e. C_beta for all residues but glycine, for
+ for which C_alpha is used instead)
+ "template_pseudo_beta_mask" ([*, N_templ, N_res])
+ Pseudo-beta mask
+ """
+ # Initialize recycling embeddings
+ m_1_prev, z_prev, x_prev = None, None, None
+
+ # Disable activation checkpointing for the first few recycling iters
+ is_grad_enabled = torch.is_grad_enabled()
+ self._disable_activation_checkpointing()
+
+ # Main recycling loop
+ num_iters = batch["aatype"].shape[-1]
+ for cycle_no in range(num_iters):
+ # Select the features for the current recycling cycle
+ fetch_cur_batch = lambda t: t[..., cycle_no]
+ feats = tensor_tree_map(fetch_cur_batch, batch)
+
+ # Enable grad iff we're training and it's the final recycling layer
+ is_final_iter = cycle_no == (num_iters - 1)
+ with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
+ if is_final_iter:
+ self._enable_activation_checkpointing()
+ # Sidestep AMP bug (PyTorch issue #65766)
+ if torch.is_autocast_enabled():
+ torch.clear_autocast_cache()
+
+ # Run the next iteration of the model
+ outputs, m_1_prev, z_prev, x_prev = self.iteration(
+ feats,
+ m_1_prev,
+ z_prev,
+ x_prev,
+ _recycle=(num_iters > 1)
+ )
+
+ # Run auxiliary heads
+ outputs.update(self.aux_heads(outputs))
+
+ return outputs
diff --git a/openfold/model/msa.py b/openfold/model/msa.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d3936f296a8cd5b7fa7f28a9d941668b2bd6c66
--- /dev/null
+++ b/openfold/model/msa.py
@@ -0,0 +1,392 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import torch
+import torch.nn as nn
+from typing import Optional, List, Tuple
+
+from openfold.model.primitives import (
+ Linear,
+ LayerNorm,
+ Attention,
+ GlobalAttention,
+ _attention_chunked_trainable,
+)
+from openfold.utils.checkpointing import get_checkpoint_fn
+from openfold.utils.tensor_utils import (
+ chunk_layer,
+ permute_final_dims,
+ flatten_final_dims,
+)
+
+
+class MSAAttention(nn.Module):
+ def __init__(
+ self,
+ c_in,
+ c_hidden,
+ no_heads,
+ pair_bias=False,
+ c_z=None,
+ inf=1e9,
+ ):
+ """
+ Args:
+ c_in:
+ Input channel dimension
+ c_hidden:
+ Per-head hidden channel dimension
+ no_heads:
+ Number of attention heads
+ pair_bias:
+ Whether to use pair embedding bias
+ c_z:
+ Pair embedding channel dimension. Ignored unless pair_bias
+ is true
+ inf:
+ A large number to be used in computing the attention mask
+ """
+ super(MSAAttention, self).__init__()
+
+ self.c_in = c_in
+ self.c_hidden = c_hidden
+ self.no_heads = no_heads
+ self.pair_bias = pair_bias
+ self.c_z = c_z
+ self.inf = inf
+
+ self.layer_norm_m = LayerNorm(self.c_in)
+
+ self.layer_norm_z = None
+ self.linear_z = None
+ if self.pair_bias:
+ self.layer_norm_z = LayerNorm(self.c_z)
+ self.linear_z = Linear(
+ self.c_z, self.no_heads, bias=False, init="normal"
+ )
+
+ self.mha = Attention(
+ self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
+ )
+
+ @torch.jit.ignore
+ def _chunk(self,
+ m: torch.Tensor,
+ biases: List[torch.Tensor],
+ chunk_size: int,
+ ) -> torch.Tensor:
+ return chunk_layer(
+ self.mha,
+ {"q_x": m, "kv_x": m, "biases": biases},
+ chunk_size=chunk_size,
+ no_batch_dims=len(m.shape[:-2]),
+ )
+
+ def _prep_inputs(self,
+ m: torch.Tensor,
+ z: Optional[torch.Tensor],
+ mask: Optional[torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # [*, N_seq, N_res, C_m]
+ m = self.layer_norm_m(m)
+
+ n_seq, n_res = m.shape[-3:-1]
+ if mask is None:
+ # [*, N_seq, N_res]
+ mask = m.new_ones(
+ m.shape[:-3] + (n_seq, n_res),
+ )
+
+ # [*, N_seq, 1, 1, N_res]
+ mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
+
+ # This step simply returns a larger view of the bias, and does not
+ # consume additional memory.
+ # [*, N_seq, no_heads, N_res, N_res]
+ #bias = bias.expand(
+ # ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
+ #)
+
+ if (self.pair_bias and
+ z is not None and # For the
+ self.layer_norm_z is not None and # benefit of
+ self.linear_z is not None # TorchScript
+ ):
+ # [*, N_res, N_res, C_z]
+ z = self.layer_norm_z(z)
+
+ # [*, N_res, N_res, no_heads]
+ z = self.linear_z(z)
+
+ # [*, 1, no_heads, N_res, N_res]
+ z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
+
+ return m, mask_bias, z
+
+ @torch.jit.ignore
+ def _chunked_msa_attn(self,
+ m: torch.Tensor,
+ z: Optional[torch.Tensor],
+ mask: Optional[torch.Tensor],
+ chunk_logits: int,
+ checkpoint: bool,
+ ) -> torch.Tensor:
+ MSA_DIM = -4
+
+ def _get_qkv(m, z):
+ m, mask_bias, z = self._prep_inputs(m, z, mask)
+ q, k, v = self.mha._prep_qkv(m, m)
+ return m, q, k, v, mask_bias, z
+
+ checkpoint_fn = get_checkpoint_fn()
+
+ if(torch.is_grad_enabled() and checkpoint):
+ m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
+ else:
+ m, q, k, v, mask_bias, z = _get_qkv(m, z)
+
+ o = _attention_chunked_trainable(
+ query=q,
+ key=k,
+ value=v,
+ biases=[mask_bias, z],
+ chunk_size=chunk_logits,
+ chunk_dim=MSA_DIM,
+ checkpoint=checkpoint,
+ )
+
+ if(torch.is_grad_enabled() and checkpoint):
+ # Storing an additional m here is far from ideal
+ m = checkpoint_fn(self.mha._wrap_up, o, m)
+ else:
+ m = self.mha._wrap_up(o, m)
+
+ return m
+
+ def forward(self,
+ m: torch.Tensor,
+ z: Optional[torch.Tensor] = None,
+ mask: Optional[torch.Tensor] = None,
+ chunk_size: Optional[int] = None,
+ _chunk_logits: Optional[int] = None,
+ _checkpoint_chunks: Optional[bool] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ m:
+ [*, N_seq, N_res, C_m] MSA embedding
+ z:
+ [*, N_res, N_res, C_z] pair embedding. Required only if
+ pair_bias is True
+ mask:
+ [*, N_seq, N_res] MSA mask
+ chunk_size:
+ Size of chunks into which the inputs are split along their
+ batch dimensions. A low value decreases memory overhead at the
+ cost of slower execution. Chunking is not performed by default.
+
+ """
+ if(_chunk_logits is not None):
+ return self._chunked_msa_attn(
+ m=m, z=z, mask=mask,
+ chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
+ )
+
+ m, mask_bias, z = self._prep_inputs(m, z, mask)
+
+ biases = [mask_bias]
+ if(z is not None):
+ biases.append(z)
+
+ if chunk_size is not None:
+ m = self._chunk(m, biases, chunk_size)
+ else:
+ m = self.mha(
+ q_x=m,
+ kv_x=m,
+ biases=biases
+ )
+
+ return m
+
+
+class MSARowAttentionWithPairBias(MSAAttention):
+ """
+ Implements Algorithm 7.
+ """
+
+ def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
+ """
+ Args:
+ c_m:
+ Input channel dimension
+ c_z:
+ Pair embedding channel dimension
+ c_hidden:
+ Per-head hidden channel dimension
+ no_heads:
+ Number of attention heads
+ inf:
+ Large number used to construct attention masks
+ """
+ super(MSARowAttentionWithPairBias, self).__init__(
+ c_m,
+ c_hidden,
+ no_heads,
+ pair_bias=True,
+ c_z=c_z,
+ inf=inf,
+ )
+
+
+class MSAColumnAttention(nn.Module):
+ """
+ Implements Algorithm 8.
+
+ By rights, this should also be a subclass of MSAAttention. Alas,
+ most inheritance isn't supported by TorchScript.
+ """
+
+ def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
+ """
+ Args:
+ c_m:
+ MSA channel dimension
+ c_hidden:
+ Per-head hidden channel dimension
+ no_heads:
+ Number of attention heads
+ inf:
+ Large number used to construct attention masks
+ """
+ super(MSAColumnAttention, self).__init__()
+
+ self.c_m = c_m
+ self.c_hidden = c_hidden
+ self.no_heads = no_heads
+ self.inf = inf
+
+ self._msa_att = MSAAttention(
+ c_in=c_m,
+ c_hidden=c_hidden,
+ no_heads=no_heads,
+ pair_bias=False,
+ c_z=None,
+ inf=inf,
+ )
+
+ def forward(self,
+ m: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ chunk_size: Optional[int] = None
+ ) -> torch.Tensor:
+ """
+ Args:
+ m:
+ [*, N_seq, N_res, C_m] MSA embedding
+ mask:
+ [*, N_seq, N_res] MSA mask
+ chunk_size:
+ Size of chunks into which the inputs are split along their
+ batch dimensions. A low value decreases memory overhead at the
+ cost of slower execution. Chunking is not performed by default.
+ """
+ # [*, N_res, N_seq, C_in]
+ m = m.transpose(-2, -3)
+ if mask is not None:
+ mask = mask.transpose(-1, -2)
+
+ m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
+
+ # [*, N_seq, N_res, C_in]
+ m = m.transpose(-2, -3)
+ if mask is not None:
+ mask = mask.transpose(-1, -2)
+
+ return m
+
+
+class MSAColumnGlobalAttention(nn.Module):
+ def __init__(
+ self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10,
+ ):
+ super(MSAColumnGlobalAttention, self).__init__()
+
+ self.c_in = c_in
+ self.c_hidden = c_hidden
+ self.no_heads = no_heads
+ self.inf = inf
+ self.eps = eps
+
+ self.layer_norm_m = nn.LayerNorm(c_in)
+
+ self.global_attention = GlobalAttention(
+ c_in=c_in,
+ c_hidden=c_hidden,
+ no_heads=no_heads,
+ inf=inf,
+ eps=eps,
+ )
+
+ @torch.jit.ignore
+ def _chunk(self,
+ m: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_size: int,
+ ) -> torch.Tensor:
+ mha_input = {
+ "m": m,
+ "mask": mask,
+ }
+ return chunk_layer(
+ self.global_attention,
+ mha_input,
+ chunk_size=chunk_size,
+ no_batch_dims=len(m.shape[:-2]),
+ )
+
+ def forward(
+ self,
+ m: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ chunk_size: Optional[int] = None,
+ ) -> torch.Tensor:
+ n_seq, n_res, c_in = m.shape[-3:]
+
+ if mask is None:
+ # [*, N_seq, N_res]
+ mask = torch.ones(
+ m.shape[:-1],
+ dtype=m.dtype,
+ device=m.device,
+ ).detach()
+
+ # [*, N_res, N_seq, C_in]
+ m = m.transpose(-2, -3)
+ mask = mask.transpose(-1, -2)
+
+ # [*, N_res, N_seq, C_in]
+ m = self.layer_norm_m(m)
+
+ if chunk_size is not None:
+ m = self._chunk(m, mask, chunk_size)
+ else:
+ m = self.global_attention(m=m, mask=mask)
+
+ # [*, N_seq, N_res, C_in]
+ m = m.transpose(-2, -3)
+
+ return m
diff --git a/openfold/model/outer_product_mean.py b/openfold/model/outer_product_mean.py
new file mode 100644
index 0000000000000000000000000000000000000000..2edef22e13235b1d40e403176b5b41828e10832f
--- /dev/null
+++ b/openfold/model/outer_product_mean.py
@@ -0,0 +1,129 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from openfold.model.primitives import Linear
+from openfold.utils.tensor_utils import chunk_layer
+
+
+class OuterProductMean(nn.Module):
+ """
+ Implements Algorithm 10.
+ """
+
+ def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
+ """
+ Args:
+ c_m:
+ MSA embedding channel dimension
+ c_z:
+ Pair embedding channel dimension
+ c_hidden:
+ Hidden channel dimension
+ """
+ super(OuterProductMean, self).__init__()
+
+ self.c_m = c_m
+ self.c_z = c_z
+ self.c_hidden = c_hidden
+ self.eps = eps
+
+ self.layer_norm = nn.LayerNorm(c_m)
+ self.linear_1 = Linear(c_m, c_hidden)
+ self.linear_2 = Linear(c_m, c_hidden)
+ self.linear_out = Linear(c_hidden ** 2, c_z, init="final")
+
+ def _opm(self, a, b):
+ # [*, N_res, N_res, C, C]
+ outer = torch.einsum("...bac,...dae->...bdce", a, b)
+
+ # [*, N_res, N_res, C * C]
+ outer = outer.reshape(outer.shape[:-2] + (-1,))
+
+ # [*, N_res, N_res, C_z]
+ outer = self.linear_out(outer)
+
+ return outer
+
+ @torch.jit.ignore
+ def _chunk(self,
+ a: torch.Tensor,
+ b: torch.Tensor,
+ chunk_size: int
+ ) -> torch.Tensor:
+ # Since the "batch dim" in this case is not a true batch dimension
+ # (in that the shape of the output depends on it), we need to
+ # iterate over it ourselves
+ a_reshape = a.reshape((-1,) + a.shape[-3:])
+ b_reshape = b.reshape((-1,) + b.shape[-3:])
+ out = []
+ for a_prime, b_prime in zip(a_reshape, b_reshape):
+ outer = chunk_layer(
+ partial(self._opm, b=b_prime),
+ {"a": a_prime},
+ chunk_size=chunk_size,
+ no_batch_dims=1,
+ )
+ out.append(outer)
+ outer = torch.stack(out, dim=0)
+ outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
+
+ return outer
+
+ def forward(self,
+ m: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ chunk_size: Optional[int] = None
+ ) -> torch.Tensor:
+ """
+ Args:
+ m:
+ [*, N_seq, N_res, C_m] MSA embedding
+ mask:
+ [*, N_seq, N_res] MSA mask
+ Returns:
+ [*, N_res, N_res, C_z] pair embedding update
+ """
+ if mask is None:
+ mask = m.new_ones(m.shape[:-1])
+
+ # [*, N_seq, N_res, C_m]
+ m = self.layer_norm(m)
+
+ # [*, N_seq, N_res, C]
+ mask = mask.unsqueeze(-1)
+ a = self.linear_1(m) * mask
+ b = self.linear_2(m) * mask
+
+ a = a.transpose(-2, -3)
+ b = b.transpose(-2, -3)
+
+ if chunk_size is not None:
+ outer = self._chunk(a, b, chunk_size)
+ else:
+ outer = self._opm(a, b)
+
+ # [*, N_res, N_res, 1]
+ norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
+
+ # [*, N_res, N_res, C_z]
+ outer = outer / (self.eps + norm)
+
+ return outer
diff --git a/openfold/model/pair_transition.py b/openfold/model/pair_transition.py
new file mode 100644
index 0000000000000000000000000000000000000000..c81d2a4c0e087ffdb0c98d8cb0e91c5b24d6a0e8
--- /dev/null
+++ b/openfold/model/pair_transition.py
@@ -0,0 +1,99 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from openfold.model.primitives import Linear, LayerNorm
+from openfold.utils.tensor_utils import chunk_layer
+
+
+class PairTransition(nn.Module):
+ """
+ Implements Algorithm 15.
+ """
+
+ def __init__(self, c_z, n):
+ """
+ Args:
+ c_z:
+ Pair transition channel dimension
+ n:
+ Factor by which c_z is multiplied to obtain hidden channel
+ dimension
+ """
+ super(PairTransition, self).__init__()
+
+ self.c_z = c_z
+ self.n = n
+
+ self.layer_norm = LayerNorm(self.c_z)
+ self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
+ self.relu = nn.ReLU()
+ self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
+
+ def _transition(self, z, mask):
+ # [*, N_res, N_res, C_hidden]
+ z = self.linear_1(z)
+ z = self.relu(z)
+
+ # [*, N_res, N_res, C_z]
+ z = self.linear_2(z) * mask
+
+ return z
+
+ @torch.jit.ignore
+ def _chunk(self,
+ z: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_size: int,
+ ) -> torch.Tensor:
+ return chunk_layer(
+ self._transition,
+ {"z": z, "mask": mask},
+ chunk_size=chunk_size,
+ no_batch_dims=len(z.shape[:-2]),
+ )
+
+
+ def forward(self,
+ z: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ chunk_size: Optional[int] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ z:
+ [*, N_res, N_res, C_z] pair embedding
+ Returns:
+ [*, N_res, N_res, C_z] pair embedding update
+ """
+ # DISCREPANCY: DeepMind forgets to apply the mask in this module.
+ if mask is None:
+ mask = z.new_ones(z.shape[:-1])
+
+ # [*, N_res, N_res, 1]
+ mask = mask.unsqueeze(-1)
+
+ # [*, N_res, N_res, C_z]
+ z = self.layer_norm(z)
+
+ if chunk_size is not None:
+ z = self._chunk(z, mask, chunk_size)
+ else:
+ z = self._transition(z=z, mask=mask)
+
+ return z
diff --git a/openfold/model/primitives.py b/openfold/model/primitives.py
new file mode 100644
index 0000000000000000000000000000000000000000..0742038e49b327661c3b6a137918f500e883cd04
--- /dev/null
+++ b/openfold/model/primitives.py
@@ -0,0 +1,587 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+import math
+from typing import Optional, Callable, List, Tuple, Sequence
+import numpy as np
+
+import deepspeed
+import torch
+import torch.nn as nn
+from scipy.stats import truncnorm
+
+from openfold.utils.checkpointing import get_checkpoint_fn
+from openfold.utils.tensor_utils import (
+ permute_final_dims,
+ flatten_final_dims,
+ _chunk_slice,
+)
+
+
+def _prod(nums):
+ out = 1
+ for n in nums:
+ out = out * n
+ return out
+
+
+def _calculate_fan(linear_weight_shape, fan="fan_in"):
+ fan_out, fan_in = linear_weight_shape
+
+ if fan == "fan_in":
+ f = fan_in
+ elif fan == "fan_out":
+ f = fan_out
+ elif fan == "fan_avg":
+ f = (fan_in + fan_out) / 2
+ else:
+ raise ValueError("Invalid fan option")
+
+ return f
+
+
+def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
+ shape = weights.shape
+ f = _calculate_fan(shape, fan)
+ scale = scale / max(1, f)
+ a = -2
+ b = 2
+ std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
+ size = _prod(shape)
+ samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
+ samples = np.reshape(samples, shape)
+ with torch.no_grad():
+ weights.copy_(torch.tensor(samples, device=weights.device))
+
+
+def lecun_normal_init_(weights):
+ trunc_normal_init_(weights, scale=1.0)
+
+
+def he_normal_init_(weights):
+ trunc_normal_init_(weights, scale=2.0)
+
+
+def glorot_uniform_init_(weights):
+ nn.init.xavier_uniform_(weights, gain=1)
+
+
+def final_init_(weights):
+ with torch.no_grad():
+ weights.fill_(0.0)
+
+
+def gating_init_(weights):
+ with torch.no_grad():
+ weights.fill_(0.0)
+
+
+def normal_init_(weights):
+ torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
+
+
+def ipa_point_weights_init_(weights):
+ with torch.no_grad():
+ softplus_inverse_1 = 0.541324854612918
+ weights.fill_(softplus_inverse_1)
+
+
+class Linear(nn.Linear):
+ """
+ A Linear layer with built-in nonstandard initializations. Called just
+ like torch.nn.Linear.
+
+ Implements the initializers in 1.11.4, plus some additional ones found
+ in the code.
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ bias: bool = True,
+ init: str = "default",
+ init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
+ ):
+ """
+ Args:
+ in_dim:
+ The final dimension of inputs to the layer
+ out_dim:
+ The final dimension of layer outputs
+ bias:
+ Whether to learn an additive bias. True by default
+ init:
+ The initializer to use. Choose from:
+
+ "default": LeCun fan-in truncated normal initialization
+ "relu": He initialization w/ truncated normal distribution
+ "glorot": Fan-average Glorot uniform initialization
+ "gating": Weights=0, Bias=1
+ "normal": Normal initialization with std=1/sqrt(fan_in)
+ "final": Weights=0, Bias=0
+
+ Overridden by init_fn if the latter is not None.
+ init_fn:
+ A custom initializer taking weight and bias as inputs.
+ Overrides init if not None.
+ """
+ super(Linear, self).__init__(in_dim, out_dim, bias=bias)
+
+ if bias:
+ with torch.no_grad():
+ self.bias.fill_(0)
+
+ if init_fn is not None:
+ init_fn(self.weight, self.bias)
+ else:
+ if init == "default":
+ lecun_normal_init_(self.weight)
+ elif init == "relu":
+ he_normal_init_(self.weight)
+ elif init == "glorot":
+ glorot_uniform_init_(self.weight)
+ elif init == "gating":
+ gating_init_(self.weight)
+ if bias:
+ with torch.no_grad():
+ self.bias.fill_(1.0)
+ elif init == "normal":
+ normal_init_(self.weight)
+ elif init == "final":
+ final_init_(self.weight)
+ else:
+ raise ValueError("Invalid init string.")
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, c_in, eps=1e-5):
+ super(LayerNorm, self).__init__()
+
+ self.c_in = (c_in,)
+ self.eps = eps
+
+ self.weight = nn.Parameter(torch.ones(c_in))
+ self.bias = nn.Parameter(torch.zeros(c_in))
+
+ def forward(self, x):
+ d = x.dtype
+ if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
+ with torch.cuda.amp.autocast(enabled=False):
+ out = nn.functional.layer_norm(
+ x,
+ self.c_in,
+ self.weight.to(dtype=d),
+ self.bias.to(dtype=d),
+ self.eps
+ )
+ else:
+ out = nn.functional.layer_norm(
+ x,
+ self.c_in,
+ self.weight,
+ self.bias,
+ self.eps,
+ )
+
+ return out
+
+@torch.jit.ignore
+def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
+ """
+ Softmax, but without automatic casting to fp32 when the input is of
+ type bfloat16
+ """
+ d = t.dtype
+ if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
+ with torch.cuda.amp.autocast(enabled=False):
+ s = torch.nn.functional.softmax(t, dim=dim)
+ else:
+ s = torch.nn.functional.softmax(t, dim=dim)
+
+ return s
+
+
+#@torch.jit.script
+def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
+ # [*, H, Q, C_hidden]
+ query = permute_final_dims(query, (1, 0, 2))
+
+ # [*, H, C_hidden, K]
+ key = permute_final_dims(key, (1, 2, 0))
+
+ # [*, H, V, C_hidden]
+ value = permute_final_dims(value, (1, 0, 2))
+
+ # [*, H, Q, K]
+ a = torch.matmul(query, key)
+
+ for b in biases:
+ a += b
+
+ a = softmax(a, -1)
+
+ # [*, H, Q, C_hidden]
+ a = torch.matmul(a, value)
+
+ # [*, Q, H, C_hidden]
+ a = a.transpose(-2, -3)
+
+ return a
+
+
+@torch.jit.ignore
+def _attention_chunked_trainable(
+ query, key, value, biases, chunk_size, chunk_dim, checkpoint,
+):
+ if(checkpoint and len(biases) > 2):
+ raise ValueError(
+ "Checkpointed version permits only permits two bias terms"
+ )
+
+ def _checkpointable_attention(q, k, v, b1, b2):
+ bs = [b for b in [b1, b2] if b is not None]
+ return _attention(q, k, v, bs)
+
+ o_chunks = []
+ checkpoint_fn = get_checkpoint_fn()
+ count = query.shape[chunk_dim]
+ for start in range(0, count, chunk_size):
+ end = start + chunk_size
+ idx = [slice(None)] * len(query.shape)
+ idx[chunk_dim] = slice(start, end)
+ idx_tup = tuple(idx)
+ q_chunk = query[idx_tup]
+ k_chunk = key[idx_tup]
+ v_chunk = value[idx_tup]
+
+ def _slice_bias(b):
+ idx[chunk_dim] = (
+ slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)
+ )
+ return b[tuple(idx)]
+
+ if(checkpoint):
+ bias_1_chunk, bias_2_chunk = [
+ _slice_bias(b) if b is not None else None
+ for b in (biases + [None, None])[:2]
+ ]
+
+ o_chunk = checkpoint_fn(_checkpointable_attention,
+ q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
+ )
+ else:
+ bias_chunks = [
+ _slice_bias(b) for b in biases
+ ]
+
+ o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
+
+ o_chunks.append(o_chunk)
+
+ o = torch.cat(o_chunks, dim=chunk_dim)
+ return o
+
+
+class Attention(nn.Module):
+ """
+ Standard multi-head attention using AlphaFold's default layer
+ initialization. Allows multiple bias vectors.
+ """
+ def __init__(
+ self,
+ c_q: int,
+ c_k: int,
+ c_v: int,
+ c_hidden: int,
+ no_heads: int,
+ gating: bool = True,
+ ):
+ """
+ Args:
+ c_q:
+ Input dimension of query data
+ c_k:
+ Input dimension of key data
+ c_v:
+ Input dimension of value data
+ c_hidden:
+ Per-head hidden dimension
+ no_heads:
+ Number of attention heads
+ gating:
+ Whether the output should be gated using query data
+ """
+ super(Attention, self).__init__()
+
+ self.c_q = c_q
+ self.c_k = c_k
+ self.c_v = c_v
+ self.c_hidden = c_hidden
+ self.no_heads = no_heads
+ self.gating = gating
+
+ # DISCREPANCY: c_hidden is not the per-head channel dimension, as
+ # stated in the supplement, but the overall channel dimension.
+
+ self.linear_q = Linear(
+ self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
+ )
+ self.linear_k = Linear(
+ self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
+ )
+ self.linear_v = Linear(
+ self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
+ )
+ self.linear_o = Linear(
+ self.c_hidden * self.no_heads, self.c_q, init="final"
+ )
+
+ self.linear_g = None
+ if self.gating:
+ self.linear_g = Linear(
+ self.c_q, self.c_hidden * self.no_heads, init="gating"
+ )
+
+ self.sigmoid = nn.Sigmoid()
+
+ def _prep_qkv(self,
+ q_x: torch.Tensor,
+ kv_x: torch.Tensor
+ ) -> Tuple[
+ torch.Tensor, torch.Tensor, torch.Tensor
+ ]:
+ # [*, Q/K/V, H * C_hidden]
+ q = self.linear_q(q_x)
+ k = self.linear_k(kv_x)
+ v = self.linear_v(kv_x)
+
+ # [*, Q/K, H, C_hidden]
+ q = q.view(q.shape[:-1] + (self.no_heads, -1))
+ k = k.view(k.shape[:-1] + (self.no_heads, -1))
+ v = v.view(v.shape[:-1] + (self.no_heads, -1))
+
+ q /= math.sqrt(self.c_hidden)
+
+ return q, k, v
+
+ def _wrap_up(self,
+ o: torch.Tensor,
+ q_x: torch.Tensor
+ ) -> torch.Tensor:
+ if(self.linear_g is not None):
+ g = self.sigmoid(self.linear_g(q_x))
+
+ # [*, Q, H, C_hidden]
+ g = g.view(g.shape[:-1] + (self.no_heads, -1))
+ o = o * g
+
+ # [*, Q, H * C_hidden]
+ o = flatten_final_dims(o, 2)
+
+ # [*, Q, C_q]
+ o = self.linear_o(o)
+
+ return o
+
+ def forward(
+ self,
+ q_x: torch.Tensor,
+ kv_x: torch.Tensor,
+ biases: Optional[List[torch.Tensor]] = None,
+ use_lma: bool = False,
+ q_chunk_size: Optional[int] = None,
+ kv_chunk_size: Optional[int] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ q_x:
+ [*, Q, C_q] query data
+ kv_x:
+ [*, K, C_k] key data
+ biases:
+ List of biases that broadcast to [*, H, Q, K]
+ use_lma:
+ Whether to use low-memory attention
+ q_chunk_size:
+ Query chunk size (for LMA)
+ kv_chunk_size:
+ Key/Value chunk size (for LMA)
+ Returns
+ [*, Q, C_q] attention update
+ """
+ if(biases is None):
+ biases = []
+ if(use_lma and (q_chunk_size is None or kv_chunk_size is None)):
+ raise ValueError(
+ "If use_lma is specified, q_chunk_size and kv_chunk_size must "
+ "be provided"
+ )
+
+ q, k, v = self._prep_qkv(q_x, kv_x)
+
+ if(use_lma):
+ biases = [
+ b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
+ for b in biases
+ ]
+
+ o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
+ else:
+ o = _attention(q, k, v, biases)
+
+ o = self._wrap_up(o, q_x)
+
+ return o
+
+
+class GlobalAttention(nn.Module):
+ def __init__(self, c_in, c_hidden, no_heads, inf, eps):
+ super(GlobalAttention, self).__init__()
+
+ self.c_in = c_in
+ self.c_hidden = c_hidden
+ self.no_heads = no_heads
+ self.inf = inf
+ self.eps = eps
+
+ self.linear_q = Linear(
+ c_in, c_hidden * no_heads, bias=False, init="glorot"
+ )
+
+ self.linear_k = Linear(
+ c_in, c_hidden, bias=False, init="glorot",
+ )
+ self.linear_v = Linear(
+ c_in, c_hidden, bias=False, init="glorot",
+ )
+ self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
+ self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
+
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ # [*, N_res, C_in]
+ q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
+ torch.sum(mask, dim=-1)[..., None] + self.eps
+ )
+
+ # [*, N_res, H * C_hidden]
+ q = self.linear_q(q)
+ q *= (self.c_hidden ** (-0.5))
+
+ # [*, N_res, H, C_hidden]
+ q = q.view(q.shape[:-1] + (self.no_heads, -1))
+
+ # [*, N_res, N_seq, C_hidden]
+ k = self.linear_k(m)
+ v = self.linear_v(m)
+
+ # [*, N_res, H, N_seq]
+ a = torch.matmul(
+ q,
+ k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
+ )
+ bias = (self.inf * (mask - 1))[..., :, None, :]
+ a += bias
+ a = softmax(a)
+
+ # [*, N_res, H, C_hidden]
+ o = torch.matmul(
+ a,
+ v,
+ )
+
+ # [*, N_res, N_seq, C_hidden]
+ g = self.sigmoid(self.linear_g(m))
+
+ # [*, N_res, N_seq, H, C_hidden]
+ g = g.view(g.shape[:-1] + (self.no_heads, -1))
+
+ # [*, N_res, N_seq, H, C_hidden]
+ o = o.unsqueeze(-3) * g
+
+ # [*, N_res, N_seq, H * C_hidden]
+ o = o.reshape(o.shape[:-2] + (-1,))
+
+ # [*, N_res, N_seq, C_in]
+ m = self.linear_o(o)
+
+ return m
+
+
+def _lma(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ biases: List[torch.Tensor],
+ q_chunk_size: int,
+ kv_chunk_size: int,
+):
+ no_q, no_kv = q.shape[-3], k.shape[-3]
+
+ # [*, Q, H, C_hidden]
+ o = q.new_zeros(q.shape)
+ for q_s in range(0, no_q, q_chunk_size):
+ q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
+ large_bias_chunks = [
+ b[..., q_s: q_s + q_chunk_size, :] for b in biases
+ ]
+
+ maxes = []
+ weights = []
+ values = []
+ for kv_s in range(0, no_kv, kv_chunk_size):
+ k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
+ v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
+ small_bias_chunks = [
+ b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
+ ]
+
+ a = torch.einsum(
+ "...qhd,...khd->...hqk", q_chunk, k_chunk,
+ )
+
+ for b in small_bias_chunks:
+ a += b
+
+ a = a.transpose(-2, -3)
+
+ max_a = torch.max(a, dim=-1, keepdim=True)[0]
+ exp_a = torch.exp(a - max_a)
+ exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
+
+ maxes.append(max_a.detach().squeeze(-1))
+ weights.append(torch.sum(exp_a, dim=-1))
+ values.append(exp_v)
+
+ chunk_max = torch.stack(maxes, dim=-3)
+ chunk_weights = torch.stack(weights, dim=-3)
+ chunk_values = torch.stack(values, dim=-4)
+
+ global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
+ max_diffs = torch.exp(chunk_max - global_max)
+ chunk_values *= max_diffs.unsqueeze(-1)
+ chunk_weights *= max_diffs
+
+ all_values = torch.sum(chunk_values, dim=-4)
+ all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
+
+ q_chunk_out = all_values / all_weights
+
+ o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out
+
+ return o
diff --git a/openfold/model/structure_module.py b/openfold/model/structure_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b6ba1fe42b995b123a43095bc51a947999b102b
--- /dev/null
+++ b/openfold/model/structure_module.py
@@ -0,0 +1,820 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import reduce
+import importlib
+import math
+import sys
+from operator import mul
+
+import torch
+import torch.nn as nn
+from typing import Optional, Tuple, Sequence
+
+from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
+from openfold.np.residue_constants import (
+ restype_rigid_group_default_frame,
+ restype_atom14_to_rigid_group,
+ restype_atom14_mask,
+ restype_atom14_rigid_group_positions,
+)
+from openfold.utils.feats import (
+ frames_and_literature_positions_to_atom14_pos,
+ torsion_angles_to_frames,
+)
+from openfold.utils.precision_utils import is_fp16_enabled
+from openfold.utils.rigid_utils import Rotation, Rigid
+from openfold.utils.tensor_utils import (
+ dict_multimap,
+ permute_final_dims,
+ flatten_final_dims,
+)
+
+# attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
+
+
+class AngleResnetBlock(nn.Module):
+ def __init__(self, c_hidden):
+ """
+ Args:
+ c_hidden:
+ Hidden channel dimension
+ """
+ super(AngleResnetBlock, self).__init__()
+
+ self.c_hidden = c_hidden
+
+ self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu")
+ self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final")
+
+ self.relu = nn.ReLU()
+
+ def forward(self, a: torch.Tensor) -> torch.Tensor:
+
+ s_initial = a
+
+ a = self.relu(a)
+ a = self.linear_1(a)
+ a = self.relu(a)
+ a = self.linear_2(a)
+
+ return a + s_initial
+
+
+class AngleResnet(nn.Module):
+ """
+ Implements Algorithm 20, lines 11-14
+ """
+
+ def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
+ """
+ Args:
+ c_in:
+ Input channel dimension
+ c_hidden:
+ Hidden channel dimension
+ no_blocks:
+ Number of resnet blocks
+ no_angles:
+ Number of torsion angles to generate
+ epsilon:
+ Small constant for normalization
+ """
+ super(AngleResnet, self).__init__()
+
+ self.c_in = c_in
+ self.c_hidden = c_hidden
+ self.no_blocks = no_blocks
+ self.no_angles = no_angles
+ self.eps = epsilon
+
+ self.linear_in = Linear(self.c_in, self.c_hidden)
+ self.linear_initial = Linear(self.c_in, self.c_hidden)
+
+ self.layers = nn.ModuleList()
+ for _ in range(self.no_blocks):
+ layer = AngleResnetBlock(c_hidden=self.c_hidden)
+ self.layers.append(layer)
+
+ self.linear_out = Linear(self.c_hidden, self.no_angles * 2)
+
+ self.relu = nn.ReLU()
+
+ def forward(
+ self, s: torch.Tensor, s_initial: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ s:
+ [*, C_hidden] single embedding
+ s_initial:
+ [*, C_hidden] single embedding as of the start of the
+ StructureModule
+ Returns:
+ [*, no_angles, 2] predicted angles
+ """
+ # NOTE: The ReLU's applied to the inputs are absent from the supplement
+ # pseudocode but present in the source. For maximal compatibility with
+ # the pretrained weights, I'm going with the source.
+
+ # [*, C_hidden]
+ s_initial = self.relu(s_initial)
+ s_initial = self.linear_initial(s_initial)
+ s = self.relu(s)
+ s = self.linear_in(s)
+ s = s + s_initial
+
+ for l in self.layers:
+ s = l(s)
+
+ s = self.relu(s)
+
+ # [*, no_angles * 2]
+ s = self.linear_out(s)
+
+ # [*, no_angles, 2]
+ s = s.view(s.shape[:-1] + (-1, 2))
+
+ unnormalized_s = s
+ norm_denom = torch.sqrt(
+ torch.clamp(
+ torch.sum(s ** 2, dim=-1, keepdim=True),
+ min=self.eps,
+ )
+ )
+ s = s / norm_denom
+
+ return unnormalized_s, s
+
+
+class InvariantPointAttention(nn.Module):
+ """
+ Implements Algorithm 22.
+ """
+ def __init__(
+ self,
+ c_s: int,
+ c_z: int,
+ c_hidden: int,
+ no_heads: int,
+ no_qk_points: int,
+ no_v_points: int,
+ inf: float = 1e5,
+ eps: float = 1e-8,
+ ):
+ """
+ Args:
+ c_s:
+ Single representation channel dimension
+ c_z:
+ Pair representation channel dimension
+ c_hidden:
+ Hidden channel dimension
+ no_heads:
+ Number of attention heads
+ no_qk_points:
+ Number of query/key points to generate
+ no_v_points:
+ Number of value points to generate
+ """
+ super(InvariantPointAttention, self).__init__()
+
+ self.c_s = c_s
+ self.c_z = c_z
+ self.c_hidden = c_hidden
+ self.no_heads = no_heads
+ self.no_qk_points = no_qk_points
+ self.no_v_points = no_v_points
+ self.inf = inf
+ self.eps = eps
+
+ # These linear layers differ from their specifications in the
+ # supplement. There, they lack bias and use Glorot initialization.
+ # Here as in the official source, they have bias and use the default
+ # Lecun initialization.
+ hc = self.c_hidden * self.no_heads
+ self.linear_q = Linear(self.c_s, hc)
+ self.linear_kv = Linear(self.c_s, 2 * hc)
+
+ hpq = self.no_heads * self.no_qk_points * 3
+ self.linear_q_points = Linear(self.c_s, hpq)
+
+ hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
+ self.linear_kv_points = Linear(self.c_s, hpkv)
+
+ hpv = self.no_heads * self.no_v_points * 3
+
+ self.linear_b = Linear(self.c_z, self.no_heads)
+
+ self.head_weights = nn.Parameter(torch.zeros((no_heads)))
+ ipa_point_weights_init_(self.head_weights)
+
+ concat_out_dim = self.no_heads * (
+ self.c_z + self.c_hidden + self.no_v_points * 4
+ )
+ self.linear_out = Linear(concat_out_dim, self.c_s, init="final")
+
+ self.softmax = nn.Softmax(dim=-1)
+ self.softplus = nn.Softplus()
+
+ def forward(
+ self,
+ s: torch.Tensor,
+ z: Optional[torch.Tensor],
+ r: Rigid,
+ mask: torch.Tensor,
+ inplace_safe: bool = False,
+ _offload_inference: bool = False,
+ _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ s:
+ [*, N_res, C_s] single representation
+ z:
+ [*, N_res, N_res, C_z] pair representation
+ r:
+ [*, N_res] transformation object
+ mask:
+ [*, N_res] mask
+ Returns:
+ [*, N_res, C_s] single representation update
+ """
+ if(_offload_inference and inplace_safe):
+ z = _z_reference_list
+ else:
+ z = [z]
+
+ #######################################
+ # Generate scalar and point activations
+ #######################################
+ # [*, N_res, H * C_hidden]
+ q = self.linear_q(s)
+ kv = self.linear_kv(s)
+
+ # [*, N_res, H, C_hidden]
+ q = q.view(q.shape[:-1] + (self.no_heads, -1))
+
+ # [*, N_res, H, 2 * C_hidden]
+ kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
+
+ # [*, N_res, H, C_hidden]
+ k, v = torch.split(kv, self.c_hidden, dim=-1)
+
+ # [*, N_res, H * P_q * 3]
+ q_pts = self.linear_q_points(s)
+
+ # This is kind of clunky, but it's how the original does it
+ # [*, N_res, H * P_q, 3]
+ q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
+ q_pts = torch.stack(q_pts, dim=-1)
+ q_pts = r[..., None].apply(q_pts)
+
+ # [*, N_res, H, P_q, 3]
+ q_pts = q_pts.view(
+ q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
+ )
+
+ # [*, N_res, H * (P_q + P_v) * 3]
+ kv_pts = self.linear_kv_points(s)
+
+ # [*, N_res, H * (P_q + P_v), 3]
+ kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
+ kv_pts = torch.stack(kv_pts, dim=-1)
+ kv_pts = r[..., None].apply(kv_pts)
+
+ # [*, N_res, H, (P_q + P_v), 3]
+ kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
+
+ # [*, N_res, H, P_q/P_v, 3]
+ k_pts, v_pts = torch.split(
+ kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
+ )
+
+ ##########################
+ # Compute attention scores
+ ##########################
+ # [*, N_res, N_res, H]
+ b = self.linear_b(z[0])
+
+ if(_offload_inference):
+ assert(sys.getrefcount(z[0]) == 2)
+ z[0] = z[0].cpu()
+
+ # [*, H, N_res, N_res]
+ if(is_fp16_enabled()):
+ with torch.cuda.amp.autocast(enabled=False):
+ a = torch.matmul(
+ permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden]
+ permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res]
+ )
+ else:
+ a = torch.matmul(
+ permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
+ permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
+ )
+
+ a *= math.sqrt(1.0 / (3 * self.c_hidden))
+ a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
+
+ # [*, N_res, N_res, H, P_q, 3]
+ pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
+ if(inplace_safe):
+ pt_att *= pt_att
+ else:
+ pt_att = pt_att ** 2
+
+ # [*, N_res, N_res, H, P_q]
+ pt_att = sum(torch.unbind(pt_att, dim=-1))
+ head_weights = self.softplus(self.head_weights).view(
+ *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
+ )
+ head_weights = head_weights * math.sqrt(
+ 1.0 / (3 * (self.no_qk_points * 9.0 / 2))
+ )
+ if(inplace_safe):
+ pt_att *= head_weights
+ else:
+ pt_att = pt_att * head_weights
+
+ # [*, N_res, N_res, H]
+ pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
+ # [*, N_res, N_res]
+ square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
+ square_mask = self.inf * (square_mask - 1)
+
+ # [*, H, N_res, N_res]
+ pt_att = permute_final_dims(pt_att, (2, 0, 1))
+
+ if(inplace_safe):
+ a += pt_att
+ del pt_att
+ a += square_mask.unsqueeze(-3)
+ # in-place softmax
+ attn_core_inplace_cuda.forward_(
+ a,
+ reduce(mul, a.shape[:-1]),
+ a.shape[-1],
+ )
+ else:
+ a = a + pt_att
+ a = a + square_mask.unsqueeze(-3)
+ a = self.softmax(a)
+
+ ################
+ # Compute output
+ ################
+ # [*, N_res, H, C_hidden]
+ o = torch.matmul(
+ a, v.transpose(-2, -3).to(dtype=a.dtype)
+ ).transpose(-2, -3)
+
+ # [*, N_res, H * C_hidden]
+ o = flatten_final_dims(o, 2)
+
+ # [*, H, 3, N_res, P_v]
+ if(inplace_safe):
+ v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
+ o_pt = [
+ torch.matmul(a, v.to(a.dtype))
+ for v in torch.unbind(v_pts, dim=-3)
+ ]
+ o_pt = torch.stack(o_pt, dim=-3)
+ else:
+ o_pt = torch.sum(
+ (
+ a[..., None, :, :, None]
+ * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
+ ),
+ dim=-2,
+ )
+
+ # [*, N_res, H, P_v, 3]
+ o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
+ o_pt = r[..., None, None].invert_apply(o_pt)
+
+ # [*, N_res, H * P_v]
+ o_pt_norm = flatten_final_dims(
+ torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
+ )
+
+ # [*, N_res, H * P_v, 3]
+ o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
+
+ if(_offload_inference):
+ z[0] = z[0].to(o_pt.device)
+
+ # [*, N_res, H, C_z]
+ o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))
+
+ # [*, N_res, H * C_z]
+ o_pair = flatten_final_dims(o_pair, 2)
+
+ # [*, N_res, C_s]
+ s = self.linear_out(
+ torch.cat(
+ (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
+ ).to(dtype=z[0].dtype)
+ )
+
+ return s
+
+
+class BackboneUpdate(nn.Module):
+ """
+ Implements part of Algorithm 23.
+ """
+
+ def __init__(self, c_s):
+ """
+ Args:
+ c_s:
+ Single representation channel dimension
+ """
+ super(BackboneUpdate, self).__init__()
+
+ self.c_s = c_s
+
+ self.linear = Linear(self.c_s, 6, init="final")
+
+ def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ [*, N_res, C_s] single representation
+ Returns:
+ [*, N_res, 6] update vector
+ """
+ # [*, 6]
+ update = self.linear(s)
+
+ return update
+
+
+class StructureModuleTransitionLayer(nn.Module):
+ def __init__(self, c):
+ super(StructureModuleTransitionLayer, self).__init__()
+
+ self.c = c
+
+ self.linear_1 = Linear(self.c, self.c, init="relu")
+ self.linear_2 = Linear(self.c, self.c, init="relu")
+ self.linear_3 = Linear(self.c, self.c, init="final")
+
+ self.relu = nn.ReLU()
+
+ def forward(self, s):
+ s_initial = s
+ s = self.linear_1(s)
+ s = self.relu(s)
+ s = self.linear_2(s)
+ s = self.relu(s)
+ s = self.linear_3(s)
+
+ s = s + s_initial
+
+ return s
+
+
+class StructureModuleTransition(nn.Module):
+ def __init__(self, c, num_layers, dropout_rate):
+ super(StructureModuleTransition, self).__init__()
+
+ self.c = c
+ self.num_layers = num_layers
+ self.dropout_rate = dropout_rate
+
+ self.layers = nn.ModuleList()
+ for _ in range(self.num_layers):
+ l = StructureModuleTransitionLayer(self.c)
+ self.layers.append(l)
+
+ self.dropout = nn.Dropout(self.dropout_rate)
+ self.layer_norm = LayerNorm(self.c)
+
+ def forward(self, s):
+ for l in self.layers:
+ s = l(s)
+
+ s = self.dropout(s)
+ s = self.layer_norm(s)
+
+ return s
+
+
+class StructureModule(nn.Module):
+ def __init__(
+ self,
+ c_s,
+ c_z,
+ c_ipa,
+ c_resnet,
+ no_heads_ipa,
+ no_qk_points,
+ no_v_points,
+ dropout_rate,
+ no_blocks,
+ no_transition_layers,
+ no_resnet_blocks,
+ no_angles,
+ trans_scale_factor,
+ epsilon,
+ inf,
+ **kwargs,
+ ):
+ """
+ Args:
+ c_s:
+ Single representation channel dimension
+ c_z:
+ Pair representation channel dimension
+ c_ipa:
+ IPA hidden channel dimension
+ c_resnet:
+ Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
+ no_heads_ipa:
+ Number of IPA heads
+ no_qk_points:
+ Number of query/key points to generate during IPA
+ no_v_points:
+ Number of value points to generate during IPA
+ dropout_rate:
+ Dropout rate used throughout the layer
+ no_blocks:
+ Number of structure module blocks
+ no_transition_layers:
+ Number of layers in the single representation transition
+ (Alg. 23 lines 8-9)
+ no_resnet_blocks:
+ Number of blocks in the angle resnet
+ no_angles:
+ Number of angles to generate in the angle resnet
+ trans_scale_factor:
+ Scale of single representation transition hidden dimension
+ epsilon:
+ Small number used in angle resnet normalization
+ inf:
+ Large number used for attention masking
+ """
+ super(StructureModule, self).__init__()
+
+ self.c_s = c_s
+ self.c_z = c_z
+ self.c_ipa = c_ipa
+ self.c_resnet = c_resnet
+ self.no_heads_ipa = no_heads_ipa
+ self.no_qk_points = no_qk_points
+ self.no_v_points = no_v_points
+ self.dropout_rate = dropout_rate
+ self.no_blocks = no_blocks
+ self.no_transition_layers = no_transition_layers
+ self.no_resnet_blocks = no_resnet_blocks
+ self.no_angles = no_angles
+ self.trans_scale_factor = trans_scale_factor
+ self.epsilon = epsilon
+ self.inf = inf
+
+ # Buffers to be lazily initialized later
+ # self.default_frames
+ # self.group_idx
+ # self.atom_mask
+ # self.lit_positions
+
+ self.layer_norm_s = LayerNorm(self.c_s)
+ self.layer_norm_z = LayerNorm(self.c_z)
+
+ self.linear_in = Linear(self.c_s, self.c_s)
+
+ self.ipa = InvariantPointAttention(
+ self.c_s,
+ self.c_z,
+ self.c_ipa,
+ self.no_heads_ipa,
+ self.no_qk_points,
+ self.no_v_points,
+ inf=self.inf,
+ eps=self.epsilon,
+ )
+
+ self.ipa_dropout = nn.Dropout(self.dropout_rate)
+ self.layer_norm_ipa = LayerNorm(self.c_s)
+
+ self.transition = StructureModuleTransition(
+ self.c_s,
+ self.no_transition_layers,
+ self.dropout_rate,
+ )
+
+ self.bb_update = BackboneUpdate(self.c_s)
+
+ self.angle_resnet = AngleResnet(
+ self.c_s,
+ self.c_resnet,
+ self.no_resnet_blocks,
+ self.no_angles,
+ self.epsilon,
+ )
+
+ def forward(
+ self,
+ evoformer_output_dict,
+ aatype,
+ mask=None,
+ inplace_safe=False,
+ _offload_inference=False,
+ ):
+ """
+ Args:
+ evoformer_output_dict:
+ Dictionary containing:
+ "single":
+ [*, N_res, C_s] single representation
+ "pair":
+ [*, N_res, N_res, C_z] pair representation
+ aatype:
+ [*, N_res] amino acid indices
+ mask:
+ Optional [*, N_res] sequence mask
+ Returns:
+ A dictionary of outputs
+ """
+ s = evoformer_output_dict["single"]
+
+ if mask is None:
+ # [*, N]
+ mask = s.new_ones(s.shape[:-1])
+
+ # [*, N, C_s]
+ s = self.layer_norm_s(s)
+
+ # [*, N, N, C_z]
+ z = self.layer_norm_z(evoformer_output_dict["pair"])
+
+ z_reference_list = None
+ if(_offload_inference):
+ assert(sys.getrefcount(evoformer_output_dict["pair"]) == 2)
+ evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
+ z_reference_list = [z]
+ z = None
+
+ # [*, N, C_s]
+ s_initial = s
+ s = self.linear_in(s)
+
+ # [*, N]
+ rigids = Rigid.identity(
+ s.shape[:-1],
+ s.dtype,
+ s.device,
+ self.training,
+ fmt="quat",
+ )
+ outputs = []
+ for i in range(self.no_blocks):
+ # [*, N, C_s]
+ s = s + self.ipa(
+ s,
+ z,
+ rigids,
+ mask,
+ inplace_safe=inplace_safe,
+ _offload_inference=_offload_inference,
+ _z_reference_list=z_reference_list
+ )
+ s = self.ipa_dropout(s)
+ s = self.layer_norm_ipa(s)
+ s = self.transition(s)
+
+ # [*, N]
+ rigids = rigids.compose_q_update_vec(self.bb_update(s))
+
+ # To hew as closely as possible to AlphaFold, we convert our
+ # quaternion-based transformations to rotation-matrix ones
+ # here
+ backb_to_global = Rigid(
+ Rotation(
+ rot_mats=rigids.get_rots().get_rot_mats(),
+ quats=None
+ ),
+ rigids.get_trans(),
+ )
+
+ backb_to_global = backb_to_global.scale_translation(
+ self.trans_scale_factor
+ )
+
+ # [*, N, 7, 2]
+ unnormalized_angles, angles = self.angle_resnet(s, s_initial)
+
+ all_frames_to_global = self.torsion_angles_to_frames(
+ backb_to_global,
+ angles,
+ aatype,
+ )
+
+ pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
+ all_frames_to_global,
+ aatype,
+ )
+
+ scaled_rigids = rigids.scale_translation(self.trans_scale_factor)
+
+ preds = {
+ "frames": scaled_rigids.to_tensor_7(),
+ "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
+ "unnormalized_angles": unnormalized_angles,
+ "angles": angles,
+ "positions": pred_xyz,
+ "states": s,
+ }
+
+ outputs.append(preds)
+
+ rigids = rigids.stop_rot_gradient()
+
+ del z, z_reference_list
+
+ if(_offload_inference):
+ evoformer_output_dict["pair"] = (
+ evoformer_output_dict["pair"].to(s.device)
+ )
+
+ outputs = dict_multimap(torch.stack, outputs)
+ outputs["single"] = s
+
+ return outputs
+
+ def _init_residue_constants(self, float_dtype, device):
+ if not hasattr(self, "default_frames"):
+ self.register_buffer(
+ "default_frames",
+ torch.tensor(
+ restype_rigid_group_default_frame,
+ dtype=float_dtype,
+ device=device,
+ requires_grad=False,
+ ),
+ persistent=False,
+ )
+ if not hasattr(self, "group_idx"):
+ self.register_buffer(
+ "group_idx",
+ torch.tensor(
+ restype_atom14_to_rigid_group,
+ device=device,
+ requires_grad=False,
+ ),
+ persistent=False,
+ )
+ if not hasattr(self, "atom_mask"):
+ self.register_buffer(
+ "atom_mask",
+ torch.tensor(
+ restype_atom14_mask,
+ dtype=float_dtype,
+ device=device,
+ requires_grad=False,
+ ),
+ persistent=False,
+ )
+ if not hasattr(self, "lit_positions"):
+ self.register_buffer(
+ "lit_positions",
+ torch.tensor(
+ restype_atom14_rigid_group_positions,
+ dtype=float_dtype,
+ device=device,
+ requires_grad=False,
+ ),
+ persistent=False,
+ )
+
+ def torsion_angles_to_frames(self, r, alpha, f):
+ # Lazily initialize the residue constants on the correct device
+ self._init_residue_constants(alpha.dtype, alpha.device)
+ # Separated purely to make testing less annoying
+ return torsion_angles_to_frames(r, alpha, f, self.default_frames)
+
+ def frames_and_literature_positions_to_atom14_pos(
+ self, r, f # [*, N, 8] # [*, N]
+ ):
+ # Lazily initialize the residue constants on the correct device
+ self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
+ return frames_and_literature_positions_to_atom14_pos(
+ r,
+ f,
+ self.default_frames,
+ self.group_idx,
+ self.atom_mask,
+ self.lit_positions,
+ )
diff --git a/openfold/model/template.py b/openfold/model/template.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0e1fd95bd0fa073b55be7714a794f14ae0234d0
--- /dev/null
+++ b/openfold/model/template.py
@@ -0,0 +1,333 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+import math
+from typing import Optional, List
+
+import torch
+import torch.nn as nn
+
+from openfold.model.primitives import Linear, LayerNorm, Attention
+from openfold.model.dropout import (
+ DropoutRowwise,
+ DropoutColumnwise,
+)
+from openfold.model.pair_transition import PairTransition
+from openfold.model.triangular_attention import (
+ TriangleAttentionStartingNode,
+ TriangleAttentionEndingNode,
+)
+from openfold.model.triangular_multiplicative_update import (
+ TriangleMultiplicationOutgoing,
+ TriangleMultiplicationIncoming,
+)
+from openfold.utils.checkpointing import checkpoint_blocks
+from openfold.utils.tensor_utils import (
+ chunk_layer,
+ permute_final_dims,
+ flatten_final_dims,
+)
+
+
+class TemplatePointwiseAttention(nn.Module):
+ """
+ Implements Algorithm 17.
+ """
+ def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
+ """
+ Args:
+ c_t:
+ Template embedding channel dimension
+ c_z:
+ Pair embedding channel dimension
+ c_hidden:
+ Hidden channel dimension
+ """
+ super(TemplatePointwiseAttention, self).__init__()
+
+ self.c_t = c_t
+ self.c_z = c_z
+ self.c_hidden = c_hidden
+ self.no_heads = no_heads
+ self.inf = inf
+
+ self.mha = Attention(
+ self.c_z,
+ self.c_t,
+ self.c_t,
+ self.c_hidden,
+ self.no_heads,
+ gating=False,
+ )
+
+ def _chunk(self,
+ z: torch.Tensor,
+ t: torch.Tensor,
+ biases: List[torch.Tensor],
+ chunk_size: int,
+ ) -> torch.Tensor:
+ mha_inputs = {
+ "q_x": z,
+ "kv_x": t,
+ "biases": biases,
+ }
+ return chunk_layer(
+ self.mha,
+ mha_inputs,
+ chunk_size=chunk_size,
+ no_batch_dims=len(z.shape[:-2]),
+ )
+
+
+ def forward(self,
+ t: torch.Tensor,
+ z: torch.Tensor,
+ template_mask: Optional[torch.Tensor] = None,
+ chunk_size: Optional[int] = None
+ ) -> torch.Tensor:
+ """
+ Args:
+ t:
+ [*, N_templ, N_res, N_res, C_t] template embedding
+ z:
+ [*, N_res, N_res, C_t] pair embedding
+ template_mask:
+ [*, N_templ] template mask
+ Returns:
+ [*, N_res, N_res, C_z] pair embedding update
+ """
+ if template_mask is None:
+ template_mask = t.new_ones(t.shape[:-3])
+
+ bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)
+
+ # [*, N_res, N_res, 1, C_z]
+ z = z.unsqueeze(-2)
+
+ # [*, N_res, N_res, N_temp, C_t]
+ t = permute_final_dims(t, (1, 2, 0, 3))
+
+ # [*, N_res, N_res, 1, C_z]
+ biases = [bias]
+ if chunk_size is not None:
+ z = self._chunk(z, t, biases, chunk_size)
+ else:
+ z = self.mha(q_x=z, kv_x=t, biases=biases)
+
+ # [*, N_res, N_res, C_z]
+ z = z.squeeze(-2)
+
+ return z
+
+
+class TemplatePairStackBlock(nn.Module):
+ def __init__(
+ self,
+ c_t: int,
+ c_hidden_tri_att: int,
+ c_hidden_tri_mul: int,
+ no_heads: int,
+ pair_transition_n: int,
+ dropout_rate: float,
+ inf: float,
+ **kwargs,
+ ):
+ super(TemplatePairStackBlock, self).__init__()
+
+ self.c_t = c_t
+ self.c_hidden_tri_att = c_hidden_tri_att
+ self.c_hidden_tri_mul = c_hidden_tri_mul
+ self.no_heads = no_heads
+ self.pair_transition_n = pair_transition_n
+ self.dropout_rate = dropout_rate
+ self.inf = inf
+
+ self.dropout_row = DropoutRowwise(self.dropout_rate)
+ self.dropout_col = DropoutColumnwise(self.dropout_rate)
+
+ self.tri_att_start = TriangleAttentionStartingNode(
+ self.c_t,
+ self.c_hidden_tri_att,
+ self.no_heads,
+ inf=inf,
+ )
+ self.tri_att_end = TriangleAttentionEndingNode(
+ self.c_t,
+ self.c_hidden_tri_att,
+ self.no_heads,
+ inf=inf,
+ )
+
+ self.tri_mul_out = TriangleMultiplicationOutgoing(
+ self.c_t,
+ self.c_hidden_tri_mul,
+ )
+ self.tri_mul_in = TriangleMultiplicationIncoming(
+ self.c_t,
+ self.c_hidden_tri_mul,
+ )
+
+ self.pair_transition = PairTransition(
+ self.c_t,
+ self.pair_transition_n,
+ )
+
+ def forward(self,
+ z: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_size: Optional[int] = None,
+ _mask_trans: bool = True
+ ):
+ single_templates = [
+ t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
+ ]
+ single_templates_masks = [
+ m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
+ ]
+ for i in range(len(single_templates)):
+ single = single_templates[i]
+ single_mask = single_templates_masks[i]
+
+ single = single + self.dropout_row(
+ self.tri_att_start(
+ single,
+ chunk_size=chunk_size,
+ mask=single_mask
+ )
+ )
+ single = single + self.dropout_col(
+ self.tri_att_end(
+ single,
+ chunk_size=chunk_size,
+ mask=single_mask
+ )
+ )
+ single = single + self.dropout_row(
+ self.tri_mul_out(
+ single,
+ mask=single_mask
+ )
+ )
+ single = single + self.dropout_row(
+ self.tri_mul_in(
+ single,
+ mask=single_mask
+ )
+ )
+ single = single + self.pair_transition(
+ single,
+ mask=single_mask if _mask_trans else None,
+ chunk_size=chunk_size,
+ )
+
+ single_templates[i] = single
+
+ z = torch.cat(single_templates, dim=-4)
+
+ return z
+
+
+class TemplatePairStack(nn.Module):
+ """
+ Implements Algorithm 16.
+ """
+ def __init__(
+ self,
+ c_t,
+ c_hidden_tri_att,
+ c_hidden_tri_mul,
+ no_blocks,
+ no_heads,
+ pair_transition_n,
+ dropout_rate,
+ blocks_per_ckpt,
+ inf=1e9,
+ **kwargs,
+ ):
+ """
+ Args:
+ c_t:
+ Template embedding channel dimension
+ c_hidden_tri_att:
+ Per-head hidden dimension for triangular attention
+ c_hidden_tri_att:
+ Hidden dimension for triangular multiplication
+ no_blocks:
+ Number of blocks in the stack
+ pair_transition_n:
+ Scale of pair transition (Alg. 15) hidden dimension
+ dropout_rate:
+ Dropout rate used throughout the stack
+ blocks_per_ckpt:
+ Number of blocks per activation checkpoint. None disables
+ activation checkpointing
+ """
+ super(TemplatePairStack, self).__init__()
+
+ self.blocks_per_ckpt = blocks_per_ckpt
+
+ self.blocks = nn.ModuleList()
+ for _ in range(no_blocks):
+ block = TemplatePairStackBlock(
+ c_t=c_t,
+ c_hidden_tri_att=c_hidden_tri_att,
+ c_hidden_tri_mul=c_hidden_tri_mul,
+ no_heads=no_heads,
+ pair_transition_n=pair_transition_n,
+ dropout_rate=dropout_rate,
+ inf=inf,
+ )
+ self.blocks.append(block)
+
+ self.layer_norm = LayerNorm(c_t)
+
+ def forward(
+ self,
+ t: torch.tensor,
+ mask: torch.tensor,
+ chunk_size: int,
+ _mask_trans: bool = True,
+ ):
+ """
+ Args:
+ t:
+ [*, N_templ, N_res, N_res, C_t] template embedding
+ mask:
+ [*, N_templ, N_res, N_res] mask
+ Returns:
+ [*, N_templ, N_res, N_res, C_t] template embedding update
+ """
+ if(mask.shape[-3] == 1):
+ expand_idx = list(mask.shape)
+ expand_idx[-3] = t.shape[-4]
+ mask = mask.expand(*expand_idx)
+
+ t, = checkpoint_blocks(
+ blocks=[
+ partial(
+ b,
+ mask=mask,
+ chunk_size=chunk_size,
+ _mask_trans=_mask_trans,
+ )
+ for b in self.blocks
+ ],
+ args=(t,),
+ blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
+ )
+
+ t = self.layer_norm(t)
+
+ return t
diff --git a/openfold/model/torchscript.py b/openfold/model/torchscript.py
new file mode 100644
index 0000000000000000000000000000000000000000..cac7e6725b04f9383bd27cb5599b4e5172462de5
--- /dev/null
+++ b/openfold/model/torchscript.py
@@ -0,0 +1,215 @@
+# Copyright 2021 AlQuraishi Laboratory
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Sequence, Tuple
+
+import torch
+import torch.nn as nn
+
+from openfold.model.dropout import (
+ DropoutRowwise,
+ DropoutColumnwise,
+)
+from openfold.model.evoformer import (
+ EvoformerBlock,
+ EvoformerStack,
+)
+from openfold.model.outer_product_mean import OuterProductMean
+from openfold.model.msa import (
+ MSARowAttentionWithPairBias,
+ MSAColumnAttention,
+ MSAColumnGlobalAttention,
+)
+from openfold.model.pair_transition import PairTransition
+from openfold.model.primitives import Attention, GlobalAttention
+from openfold.model.structure_module import (
+ InvariantPointAttention,
+ BackboneUpdate,
+)
+from openfold.model.template import TemplatePairStackBlock
+from openfold.model.triangular_attention import (
+ TriangleAttentionStartingNode,
+ TriangleAttentionEndingNode,
+)
+from openfold.model.triangular_multiplicative_update import (
+ TriangleMultiplicationOutgoing,
+ TriangleMultiplicationIncoming,
+)
+
+
+def script_preset_(model: torch.nn.Module):
+ """
+ TorchScript a handful of low-level but frequently used submodule types
+ that are known to be scriptable.
+
+ Args:
+ model:
+ A torch.nn.Module. It should contain at least some modules from
+ this repository, or this function won't do anything.
+ """
+ script_submodules_(
+ model,
+ [
+ nn.Dropout,
+ Attention,
+ GlobalAttention,
+ EvoformerBlock,
+ #TemplatePairStackBlock,
+ ],
+ attempt_trace=False,
+ batch_dims=None,
+ )
+
+
+def _get_module_device(module: torch.nn.Module) -> torch.device:
+ """
+ Fetches the device of a module, assuming that all of the module's
+ parameters reside on a single device
+
+ Args:
+ module: A torch.nn.Module
+ Returns:
+ The module's device
+ """
+ return next(module.parameters()).device
+
+
+def _trace_module(module, batch_dims=None):
+ if(batch_dims is None):
+ batch_dims = ()
+
+ # Stand-in values
+ n_seq = 10
+ n_res = 10
+
+ device = _get_module_device(module)
+
+ def msa(channel_dim):
+ return torch.rand(
+ (*batch_dims, n_seq, n_res, channel_dim),
+ device=device,
+ )
+
+ def pair(channel_dim):
+ return torch.rand(
+ (*batch_dims, n_res, n_res, channel_dim),
+ device=device,
+ )
+
+ if(isinstance(module, MSARowAttentionWithPairBias)):
+ inputs = {
+ "forward": (
+ msa(module.c_in), # m
+ pair(module.c_z), # z
+ torch.randint(
+ 0, 2,
+ (*batch_dims, n_seq, n_res)
+ ), # mask
+ ),
+ }
+ elif(isinstance(module, MSAColumnAttention)):
+ inputs = {
+ "forward": (
+ msa(module.c_in), # m
+ torch.randint(
+ 0, 2,
+ (*batch_dims, n_seq, n_res)
+ ), # mask
+ ),
+ }
+ elif(isinstance(module, OuterProductMean)):
+ inputs = {
+ "forward": (
+ msa(module.c_m),
+ torch.randint(
+ 0, 2,
+ (*batch_dims, n_seq, n_res)
+ )
+ )
+ }
+ else:
+ raise TypeError(
+ f"tracing is not supported for modules of type {type(module)}"
+ )
+
+ return torch.jit.trace_module(module, inputs)
+
+
+def _script_submodules_helper_(
+ model,
+ types,
+ attempt_trace,
+ to_trace,
+):
+ for name, child in model.named_children():
+ if(types is None or any(isinstance(child, t) for t in types)):
+ try:
+ scripted = torch.jit.script(child)
+ setattr(model, name, scripted)
+ continue
+ except (RuntimeError, torch.jit.frontend.NotSupportedError) as e:
+ if(attempt_trace):
+ to_trace.add(type(child))
+ else:
+ raise e
+
+ _script_submodules_helper_(child, types, attempt_trace, to_trace)
+
+
+def _trace_submodules_(
+ model,
+ types,
+ batch_dims=None,
+):
+ for name, child in model.named_children():
+ if(any(isinstance(child, t) for t in types)):
+ traced = _trace_module(child, batch_dims=batch_dims)
+ setattr(model, name, traced)
+ else:
+ _trace_submodules_(child, types, batch_dims=batch_dims)
+
+
+def script_submodules_(
+ model: nn.Module,
+ types: Optional[Sequence[type]] = None,
+ attempt_trace: Optional[bool] = True,
+ batch_dims: Optional[Tuple[int]] = None,
+):
+ """
+ Convert all submodules whose types match one of those in the input
+ list to recursively scripted equivalents in place. To script the entire
+ model, just call torch.jit.script on it directly.
+
+ When types is None, all submodules are scripted.
+
+ Args:
+ model:
+ A torch.nn.Module
+ types:
+ A list of types of submodules to script
+ attempt_trace:
+ Whether to attempt to trace specified modules if scripting
+ fails. Recall that tracing eliminates all conditional
+ logic---with great tracing comes the mild responsibility of
+ having to remember to ensure that the modules in question
+ perform the same computations no matter what.
+ """
+ to_trace = set()
+
+ # Aggressively script as much as possible first...
+ _script_submodules_helper_(model, types, attempt_trace, to_trace)
+
+ # ... and then trace stragglers.
+ if(attempt_trace and len(to_trace) > 0):
+ _trace_submodules_(model, to_trace, batch_dims=batch_dims)
diff --git a/openfold/model/triangular_attention.py b/openfold/model/triangular_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..91fa2bb62a110825cc584eb2aff02e753ad3b307
--- /dev/null
+++ b/openfold/model/triangular_attention.py
@@ -0,0 +1,139 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partialmethod, partial
+import math
+from typing import Optional, List
+
+import torch
+import torch.nn as nn
+
+from openfold.model.primitives import Linear, LayerNorm, Attention
+from openfold.utils.tensor_utils import (
+ chunk_layer,
+ permute_final_dims,
+ flatten_final_dims,
+)
+
+
+class TriangleAttention(nn.Module):
+ def __init__(
+ self, c_in, c_hidden, no_heads, starting, inf=1e9
+ ):
+ """
+ Args:
+ c_in:
+ Input channel dimension
+ c_hidden:
+ Overall hidden channel dimension (not per-head)
+ no_heads:
+ Number of attention heads
+ """
+ super(TriangleAttention, self).__init__()
+
+ self.c_in = c_in
+ self.c_hidden = c_hidden
+ self.no_heads = no_heads
+ self.starting = starting
+ self.inf = inf
+
+ self.layer_norm = LayerNorm(self.c_in)
+
+ self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
+
+ self.mha = Attention(
+ self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
+ )
+
+ @torch.jit.ignore
+ def _chunk(self,
+ x: torch.Tensor,
+ biases: List[torch.Tensor],
+ chunk_size: int,
+ ) -> torch.Tensor:
+ mha_inputs = {
+ "q_x": x,
+ "kv_x": x,
+ "biases": biases,
+ }
+ return chunk_layer(
+ partial(self.mha),
+ mha_inputs,
+ chunk_size=chunk_size,
+ no_batch_dims=len(x.shape[:-2]),
+ )
+
+ def forward(self,
+ x: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ chunk_size: Optional[int] = None
+ ) -> torch.Tensor:
+ """
+ Args:
+ x:
+ [*, I, J, C_in] input tensor (e.g. the pair representation)
+ Returns:
+ [*, I, J, C_in] output tensor
+ """
+ if mask is None:
+ # [*, I, J]
+ mask = x.new_ones(
+ x.shape[:-1],
+ )
+
+ # Shape annotations assume self.starting. Else, I and J are flipped
+ if not self.starting:
+ x = x.transpose(-2, -3)
+ mask = mask.transpose(-1, -2)
+
+ # [*, I, J, C_in]
+ x = self.layer_norm(x)
+
+ # [*, I, 1, 1, J]
+ mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
+
+ # [*, H, I, J]
+ triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
+
+ # [*, 1, H, I, J]
+ triangle_bias = triangle_bias.unsqueeze(-4)
+
+ biases = [mask_bias, triangle_bias]
+
+ if chunk_size is not None:
+ x = self._chunk(x, biases, chunk_size)
+ else:
+ x = self.mha(q_x=x, kv_x=x, biases=biases)
+
+ if not self.starting:
+ x = x.transpose(-2, -3)
+
+ return x
+
+
+class TriangleAttentionStartingNode(TriangleAttention):
+ """
+ Implements Algorithm 13.
+ """
+
+ __init__ = partialmethod(TriangleAttention.__init__, starting=True)
+
+
+class TriangleAttentionEndingNode(TriangleAttention):
+ """
+ Implements Algorithm 14.
+ """
+
+ __init__ = partialmethod(TriangleAttention.__init__, starting=False)
diff --git a/openfold/model/triangular_multiplicative_update.py b/openfold/model/triangular_multiplicative_update.py
new file mode 100644
index 0000000000000000000000000000000000000000..2edd24f6f8ccab7e73e577ddc2aaa6d2443757ed
--- /dev/null
+++ b/openfold/model/triangular_multiplicative_update.py
@@ -0,0 +1,127 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partialmethod
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from openfold.model.primitives import Linear, LayerNorm
+from openfold.utils.tensor_utils import permute_final_dims
+
+
+class TriangleMultiplicativeUpdate(nn.Module):
+ """
+ Implements Algorithms 11 and 12.
+ """
+ def __init__(self, c_z, c_hidden, _outgoing=True):
+ """
+ Args:
+ c_z:
+ Input channel dimension
+ c:
+ Hidden channel dimension
+ """
+ super(TriangleMultiplicativeUpdate, self).__init__()
+ self.c_z = c_z
+ self.c_hidden = c_hidden
+ self._outgoing = _outgoing
+
+ self.linear_a_p = Linear(self.c_z, self.c_hidden)
+ self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
+ self.linear_b_p = Linear(self.c_z, self.c_hidden)
+ self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
+ self.linear_g = Linear(self.c_z, self.c_z, init="gating")
+ self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
+
+ self.layer_norm_in = LayerNorm(self.c_z)
+ self.layer_norm_out = LayerNorm(self.c_hidden)
+
+ self.sigmoid = nn.Sigmoid()
+
+ def _combine_projections(self,
+ a: torch.Tensor,
+ b: torch.Tensor,
+ ) -> torch.Tensor:
+ raise NotImplementedError("This method needs to be overridden")
+
+ def forward(self,
+ z: torch.Tensor,
+ mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ Args:
+ x:
+ [*, N_res, N_res, C_z] input tensor
+ mask:
+ [*, N_res, N_res] input mask
+ Returns:
+ [*, N_res, N_res, C_z] output tensor
+ """
+ if mask is None:
+ mask = z.new_ones(z.shape[:-1])
+
+ mask = mask.unsqueeze(-1)
+
+ z = self.layer_norm_in(z)
+ a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
+ a = a * mask
+ b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
+ b = b * mask
+ x = self._combine_projections(a, b)
+ x = self.layer_norm_out(x)
+ x = self.linear_z(x)
+ g = self.sigmoid(self.linear_g(z))
+ z = x * g
+
+ return z
+
+
+class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
+ """
+ Implements Algorithm 11.
+ """
+ def _combine_projections(self,
+ a: torch.Tensor, # [*, N_i, N_k, C]
+ b: torch.Tensor, # [*, N_j, N_k, C]
+ ):
+ # [*, C, N_i, N_j]
+ p = torch.matmul(
+ permute_final_dims(a, (2, 0, 1)),
+ permute_final_dims(b, (2, 1, 0)),
+ )
+
+ # [*, N_i, N_j, C]
+ return permute_final_dims(p, (1, 2, 0))
+
+
+class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
+ """
+ Implements Algorithm 12.
+ """
+ def _combine_projections(self,
+ a: torch.Tensor, # [*, N_k, N_i, C]
+ b: torch.Tensor, # [*, N_k, N_j, C]
+ ):
+ # [*, C, N_i, N_j]
+ p = torch.matmul(
+ permute_final_dims(a, (2, 1, 0)),
+ permute_final_dims(b, (2, 0, 1)),
+ )
+
+ # [*, N_i, N_j, C]
+ return permute_final_dims(p, (1, 2, 0))
+
diff --git a/openfold/np/__init__.py b/openfold/np/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..25d1a5ba4be7300076d6f44af1541a33ca9e4ab1
--- /dev/null
+++ b/openfold/np/__init__.py
@@ -0,0 +1,16 @@
+import os
+import glob
+import importlib as importlib
+
+_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
+__all__ = [
+ os.path.basename(f)[:-3]
+ for f in _files
+ if os.path.isfile(f) and not f.endswith("__init__.py")
+]
+_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
+for _m in _modules:
+ globals()[_m[0]] = _m[1]
+
+# Avoid needlessly cluttering the global namespace
+del _files, _m, _modules
diff --git a/openfold/np/__pycache__/__init__.cpython-310.pyc b/openfold/np/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a92827dd828cfadf9345a7606211b74ca975339c
Binary files /dev/null and b/openfold/np/__pycache__/__init__.cpython-310.pyc differ
diff --git a/openfold/np/__pycache__/protein.cpython-310.pyc b/openfold/np/__pycache__/protein.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4502b1209825e70267a93d21b45f8ef9a4b56e7
Binary files /dev/null and b/openfold/np/__pycache__/protein.cpython-310.pyc differ
diff --git a/openfold/np/__pycache__/residue_constants.cpython-310.pyc b/openfold/np/__pycache__/residue_constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b27df6a398b16067a61abd35678619c474276f1
Binary files /dev/null and b/openfold/np/__pycache__/residue_constants.cpython-310.pyc differ
diff --git a/openfold/np/protein.py b/openfold/np/protein.py
new file mode 100644
index 0000000000000000000000000000000000000000..75748b27d768d7765f1823c3eada5f7bcd0e4099
--- /dev/null
+++ b/openfold/np/protein.py
@@ -0,0 +1,438 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Protein data type."""
+import dataclasses
+import io
+from typing import Any, Sequence, Mapping, Optional
+import re
+import string
+
+from openfold.np import residue_constants
+from Bio.PDB import PDBParser
+import numpy as np
+
+
+FeatureDict = Mapping[str, np.ndarray]
+ModelOutput = Mapping[str, Any] # Is a nested dict.
+PICO_TO_ANGSTROM = 0.01
+
+@dataclasses.dataclass(frozen=True)
+class Protein:
+ """Protein structure representation."""
+
+ # Cartesian coordinates of atoms in angstroms. The atom types correspond to
+ # residue_constants.atom_types, i.e. the first three are N, CA, CB.
+ atom_positions: np.ndarray # [num_res, num_atom_type, 3]
+
+ # Amino-acid type for each residue represented as an integer between 0 and
+ # 20, where 20 is 'X'.
+ aatype: np.ndarray # [num_res]
+
+ # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
+ # is present and 0.0 if not. This should be used for loss masking.
+ atom_mask: np.ndarray # [num_res, num_atom_type]
+
+ # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
+ residue_index: np.ndarray # [num_res]
+
+ # B-factors, or temperature factors, of each residue (in sq. angstroms units),
+ # representing the displacement of the residue from its ground truth mean
+ # value.
+ b_factors: np.ndarray # [num_res, num_atom_type]
+
+ # Chain indices for multi-chain predictions
+ chain_index: Optional[np.ndarray] = None
+
+ # Optional remark about the protein. Included as a comment in output PDB
+ # files
+ remark: Optional[str] = None
+
+ # Templates used to generate this protein (prediction-only)
+ parents: Optional[Sequence[str]] = None
+
+ # Chain corresponding to each parent
+ parents_chain_index: Optional[Sequence[int]] = None
+
+
+def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
+ """Takes a PDB string and constructs a Protein object.
+
+ WARNING: All non-standard residue types will be converted into UNK. All
+ non-standard atoms will be ignored.
+
+ Args:
+ pdb_str: The contents of the pdb file
+ chain_id: If None, then the pdb file must contain a single chain (which
+ will be parsed). If chain_id is specified (e.g. A), then only that chain
+ is parsed.
+
+ Returns:
+ A new `Protein` parsed from the pdb contents.
+ """
+ pdb_fh = io.StringIO(pdb_str)
+ parser = PDBParser(QUIET=True)
+ structure = parser.get_structure("none", pdb_fh)
+ models = list(structure.get_models())
+ if len(models) != 1:
+ raise ValueError(
+ f"Only single model PDBs are supported. Found {len(models)} models."
+ )
+ model = models[0]
+
+ atom_positions = []
+ aatype = []
+ atom_mask = []
+ residue_index = []
+ chain_ids = []
+ b_factors = []
+
+ for chain in model:
+ if(chain_id is not None and chain.id != chain_id):
+ continue
+ for res in chain:
+ if res.id[2] != " ":
+ raise ValueError(
+ f"PDB contains an insertion code at chain {chain.id} and residue "
+ f"index {res.id[1]}. These are not supported."
+ )
+ res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
+ restype_idx = residue_constants.restype_order.get(
+ res_shortname, residue_constants.restype_num
+ )
+ pos = np.zeros((residue_constants.atom_type_num, 3))
+ mask = np.zeros((residue_constants.atom_type_num,))
+ res_b_factors = np.zeros((residue_constants.atom_type_num,))
+ for atom in res:
+ if atom.name not in residue_constants.atom_types:
+ continue
+ pos[residue_constants.atom_order[atom.name]] = atom.coord
+ mask[residue_constants.atom_order[atom.name]] = 1.0
+ res_b_factors[
+ residue_constants.atom_order[atom.name]
+ ] = atom.bfactor
+ if np.sum(mask) < 0.5:
+ # If no known atom positions are reported for the residue then skip it.
+ continue
+ aatype.append(restype_idx)
+ atom_positions.append(pos)
+ atom_mask.append(mask)
+ residue_index.append(res.id[1])
+ chain_ids.append(chain.id)
+ b_factors.append(res_b_factors)
+
+ parents = None
+ parents_chain_index = None
+ if("PARENT" in pdb_str):
+ parents = []
+ parents_chain_index = []
+ chain_id = 0
+ for l in pdb_str.split("\n"):
+ if("PARENT" in l):
+ if(not "N/A" in l):
+ parent_names = l.split()[1:]
+ parents.extend(parent_names)
+ parents_chain_index.extend([
+ chain_id for _ in parent_names
+ ])
+ chain_id += 1
+
+ unique_chain_ids = np.unique(chain_ids)
+ chain_id_mapping = {cid: n for n, cid in enumerate(string.ascii_uppercase)}
+ chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
+
+ return Protein(
+ atom_positions=np.array(atom_positions),
+ atom_mask=np.array(atom_mask),
+ aatype=np.array(aatype),
+ residue_index=np.array(residue_index),
+ chain_index=chain_index,
+ b_factors=np.array(b_factors),
+ parents=parents,
+ parents_chain_index=parents_chain_index,
+ )
+
+
+def from_proteinnet_string(proteinnet_str: str) -> Protein:
+ tag_re = r'(\[[A-Z]+\]\n)'
+ tags = [
+ tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
+ ]
+ groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])
+
+ atoms = ['N', 'CA', 'C']
+ aatype = None
+ atom_positions = None
+ atom_mask = None
+ for g in groups:
+ if("[PRIMARY]" == g[0]):
+ seq = g[1][0].strip()
+ for i in range(len(seq)):
+ if(seq[i] not in residue_constants.restypes):
+ seq[i] = 'X'
+ aatype = np.array([
+ residue_constants.restype_order.get(
+ res_symbol, residue_constants.restype_num
+ ) for res_symbol in seq
+ ])
+ elif("[TERTIARY]" == g[0]):
+ tertiary = []
+ for axis in range(3):
+ tertiary.append(list(map(float, g[1][axis].split())))
+ tertiary_np = np.array(tertiary)
+ atom_positions = np.zeros(
+ (len(tertiary[0])//3, residue_constants.atom_type_num, 3)
+ ).astype(np.float32)
+ for i, atom in enumerate(atoms):
+ atom_positions[:, residue_constants.atom_order[atom], :] = (
+ np.transpose(tertiary_np[:, i::3])
+ )
+ atom_positions *= PICO_TO_ANGSTROM
+ elif("[MASK]" == g[0]):
+ mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
+ atom_mask = np.zeros(
+ (len(mask), residue_constants.atom_type_num,)
+ ).astype(np.float32)
+ for i, atom in enumerate(atoms):
+ atom_mask[:, residue_constants.atom_order[atom]] = 1
+ atom_mask *= mask[..., None]
+
+ return Protein(
+ atom_positions=atom_positions,
+ atom_mask=atom_mask,
+ aatype=aatype,
+ residue_index=np.arange(len(aatype)),
+ b_factors=None,
+ )
+
+
+def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
+ pdb_headers = []
+
+ remark = prot.remark
+ if(remark is not None):
+ pdb_headers.append(f"REMARK {remark}")
+
+ parents = prot.parents
+ parents_chain_index = prot.parents_chain_index
+ if(parents_chain_index is not None):
+ parents = [
+ p for i, p in zip(parents_chain_index, parents) if i == chain_id
+ ]
+
+ if(parents is None or len(parents) == 0):
+ parents = ["N/A"]
+
+ pdb_headers.append(f"PARENT {' '.join(parents)}")
+
+ return pdb_headers
+
+
+def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
+ """ Add pdb headers to an existing PDB string. Useful during multi-chain
+ recycling
+ """
+ out_pdb_lines = []
+ lines = pdb_str.split('\n')
+
+ remark = prot.remark
+ if(remark is not None):
+ out_pdb_lines.append(f"REMARK {remark}")
+
+ parents_per_chain = None
+ if(prot.parents is not None and len(prot.parents) > 0):
+ parents_per_chain = []
+ if(prot.parents_chain_index is not None):
+ cur_chain = prot.parents_chain_index[0]
+ parent_dict = {}
+ for p, i in zip(prot.parents, prot.parents_chain_index):
+ parent_dict.setdefault(str(i), [])
+ parent_dict[str(i)].append(p)
+
+ max_idx = max([int(chain_idx) for chain_idx in parent_dict])
+ for i in range(max_idx + 1):
+ chain_parents = parent_dict.get(str(i), ["N/A"])
+ parents_per_chain.append(chain_parents)
+ else:
+ parents_per_chain.append(prot.parents)
+ else:
+ parents_per_chain = [["N/A"]]
+
+ make_parent_line = lambda p: f"PARENT {' '.join(p)}"
+
+ out_pdb_lines.append(make_parent_line(parents_per_chain[0]))
+
+ chain_counter = 0
+ for i, l in enumerate(lines):
+ if("PARENT" not in l and "REMARK" not in l):
+ out_pdb_lines.append(l)
+ if("TER" in l and not "END" in lines[i + 1]):
+ chain_counter += 1
+ if(not chain_counter >= len(parents_per_chain)):
+ chain_parents = parents_per_chain[chain_counter]
+ else:
+ chain_parents = ["N/A"]
+
+ out_pdb_lines.append(make_parent_line(chain_parents))
+
+ return '\n'.join(out_pdb_lines)
+
+
+def to_pdb(prot: Protein) -> str:
+ """Converts a `Protein` instance to a PDB string.
+
+ Args:
+ prot: The protein to convert to PDB.
+
+ Returns:
+ PDB string.
+ """
+ restypes = residue_constants.restypes + ["X"]
+ res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
+ atom_types = residue_constants.atom_types
+
+ pdb_lines = []
+
+ atom_mask = prot.atom_mask
+ aatype = prot.aatype
+ atom_positions = prot.atom_positions
+ residue_index = prot.residue_index.astype(int)
+ b_factors = prot.b_factors
+ chain_index = prot.chain_index
+
+ if np.any(aatype > residue_constants.restype_num):
+ raise ValueError("Invalid aatypes.")
+
+ headers = get_pdb_headers(prot)
+ if(len(headers) > 0):
+ pdb_lines.extend(headers)
+
+ n = aatype.shape[0]
+ atom_index = 1
+ prev_chain_index = 0
+ chain_tags = string.ascii_uppercase
+ # Add all atom sites.
+ for i in range(n):
+ res_name_3 = res_1to3(aatype[i])
+ for atom_name, pos, mask, b_factor in zip(
+ atom_types, atom_positions[i], atom_mask[i], b_factors[i]
+ ):
+ if mask < 0.5:
+ continue
+
+ record_type = "ATOM"
+ name = atom_name if len(atom_name) == 4 else f" {atom_name}"
+ alt_loc = ""
+ insertion_code = ""
+ occupancy = 1.00
+ element = atom_name[
+ 0
+ ] # Protein supports only C, N, O, S, this works.
+ charge = ""
+
+ chain_tag = "A"
+ if(chain_index is not None):
+ chain_tag = chain_tags[chain_index[i]]
+
+ # PDB is a columnar format, every space matters here!
+ atom_line = (
+ f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
+ f"{res_name_3:>3} {chain_tag:>1}"
+ f"{residue_index[i]:>4}{insertion_code:>1} "
+ f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
+ f"{occupancy:>6.2f}{b_factor:>6.2f} "
+ f"{element:>2}{charge:>2}"
+ )
+ pdb_lines.append(atom_line)
+ atom_index += 1
+
+ should_terminate = (i == n - 1)
+ if(chain_index is not None):
+ if(i != n - 1 and chain_index[i + 1] != prev_chain_index):
+ should_terminate = True
+ prev_chain_index = chain_index[i + 1]
+
+ if(should_terminate):
+ # Close the chain.
+ chain_end = "TER"
+ chain_termination_line = (
+ f"{chain_end:<6}{atom_index:>5} "
+ f"{res_1to3(aatype[i]):>3} "
+ f"{chain_tag:>1}{residue_index[i]:>4}"
+ )
+ pdb_lines.append(chain_termination_line)
+ atom_index += 1
+
+ if(i != n - 1):
+ # "prev" is a misnomer here. This happens at the beginning of
+ # each new chain.
+ pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))
+
+ pdb_lines.append("END")
+ pdb_lines.append("")
+ return "\n".join(pdb_lines)
+
+
+def ideal_atom_mask(prot: Protein) -> np.ndarray:
+ """Computes an ideal atom mask.
+
+ `Protein.atom_mask` typically is defined according to the atoms that are
+ reported in the PDB. This function computes a mask according to heavy atoms
+ that should be present in the given sequence of amino acids.
+
+ Args:
+ prot: `Protein` whose fields are `numpy.ndarray` objects.
+
+ Returns:
+ An ideal atom mask.
+ """
+ return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
+
+
+def from_prediction(
+ features: FeatureDict,
+ result: ModelOutput,
+ b_factors: Optional[np.ndarray] = None,
+ chain_index: Optional[np.ndarray] = None,
+ remark: Optional[str] = None,
+ parents: Optional[Sequence[str]] = None,
+ parents_chain_index: Optional[Sequence[int]] = None
+) -> Protein:
+ """Assembles a protein from a prediction.
+
+ Args:
+ features: Dictionary holding model inputs.
+ result: Dictionary holding model outputs.
+ b_factors: (Optional) B-factors to use for the protein.
+ chain_index: (Optional) Chain indices for multi-chain predictions
+ remark: (Optional) Remark about the prediction
+ parents: (Optional) List of template names
+ Returns:
+ A protein instance.
+ """
+ if b_factors is None:
+ b_factors = np.zeros_like(result["final_atom_mask"])
+
+ return Protein(
+ aatype=features["aatype"],
+ atom_positions=result["final_atom_positions"],
+ atom_mask=result["final_atom_mask"],
+ residue_index=features["residue_index"] + 1,
+ b_factors=b_factors,
+ chain_index=chain_index,
+ remark=remark,
+ parents=parents,
+ parents_chain_index=parents_chain_index,
+ )
diff --git a/openfold/np/relax/__init__.py b/openfold/np/relax/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..25d1a5ba4be7300076d6f44af1541a33ca9e4ab1
--- /dev/null
+++ b/openfold/np/relax/__init__.py
@@ -0,0 +1,16 @@
+import os
+import glob
+import importlib as importlib
+
+_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
+__all__ = [
+ os.path.basename(f)[:-3]
+ for f in _files
+ if os.path.isfile(f) and not f.endswith("__init__.py")
+]
+_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
+for _m in _modules:
+ globals()[_m[0]] = _m[1]
+
+# Avoid needlessly cluttering the global namespace
+del _files, _m, _modules
diff --git a/openfold/np/relax/amber_minimize.py b/openfold/np/relax/amber_minimize.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f35a2d9fdf56fca3e89d822b7c6a8f56a265ccd
--- /dev/null
+++ b/openfold/np/relax/amber_minimize.py
@@ -0,0 +1,612 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Restrained Amber Minimization of a structure."""
+
+import io
+import time
+from typing import Collection, Optional, Sequence
+
+from absl import logging
+from openfold.np import (
+ protein,
+ residue_constants,
+)
+import openfold.utils.loss as loss
+from openfold.np.relax import cleanup, utils
+import ml_collections
+import numpy as np
+import openmm
+from openmm import unit
+from openmm import app as openmm_app
+from openmm.app.internal.pdbstructure import PdbStructure
+
+ENERGY = unit.kilocalories_per_mole
+LENGTH = unit.angstroms
+
+
+def will_restrain(atom: openmm_app.Atom, rset: str) -> bool:
+ """Returns True if the atom will be restrained by the given restraint set."""
+
+ if rset == "non_hydrogen":
+ return atom.element.name != "hydrogen"
+ elif rset == "c_alpha":
+ return atom.name == "CA"
+
+
+def _add_restraints(
+ system: openmm.System,
+ reference_pdb: openmm_app.PDBFile,
+ stiffness: unit.Unit,
+ rset: str,
+ exclude_residues: Sequence[int],
+):
+ """Adds a harmonic potential that restrains the system to a structure."""
+ assert rset in ["non_hydrogen", "c_alpha"]
+
+ force = openmm.CustomExternalForce(
+ "0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)"
+ )
+ force.addGlobalParameter("k", stiffness)
+ for p in ["x0", "y0", "z0"]:
+ force.addPerParticleParameter(p)
+
+ for i, atom in enumerate(reference_pdb.topology.atoms()):
+ if atom.residue.index in exclude_residues:
+ continue
+ if will_restrain(atom, rset):
+ force.addParticle(i, reference_pdb.positions[i])
+ logging.info(
+ "Restraining %d / %d particles.",
+ force.getNumParticles(),
+ system.getNumParticles(),
+ )
+ system.addForce(force)
+
+
+def _openmm_minimize(
+ pdb_str: str,
+ max_iterations: int,
+ tolerance: unit.Unit,
+ stiffness: unit.Unit,
+ restraint_set: str,
+ exclude_residues: Sequence[int],
+ use_gpu: bool,
+):
+ """Minimize energy via openmm."""
+
+ pdb_file = io.StringIO(pdb_str)
+ pdb = openmm_app.PDBFile(pdb_file)
+
+ force_field = openmm_app.ForceField("amber99sb.xml")
+ constraints = openmm_app.HBonds
+ system = force_field.createSystem(pdb.topology, constraints=constraints)
+ if stiffness > 0 * ENERGY / (LENGTH ** 2):
+ _add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)
+
+ integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
+ platform = openmm.Platform.getPlatformByName("CUDA" if use_gpu else "CPU")
+ simulation = openmm_app.Simulation(
+ pdb.topology, system, integrator, platform
+ )
+ simulation.context.setPositions(pdb.positions)
+
+ ret = {}
+ state = simulation.context.getState(getEnergy=True, getPositions=True)
+ ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY)
+ ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
+ simulation.minimizeEnergy(maxIterations=max_iterations, tolerance=tolerance)
+ state = simulation.context.getState(getEnergy=True, getPositions=True)
+ ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY)
+ ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
+ ret["min_pdb"] = _get_pdb_string(simulation.topology, state.getPositions())
+ return ret
+
+
+def _get_pdb_string(topology: openmm_app.Topology, positions: unit.Quantity):
+ """Returns a pdb string provided OpenMM topology and positions."""
+ with io.StringIO() as f:
+ openmm_app.PDBFile.writeFile(topology, positions, f)
+ return f.getvalue()
+
+
+def _check_cleaned_atoms(pdb_cleaned_string: str, pdb_ref_string: str):
+ """Checks that no atom positions have been altered by cleaning."""
+ cleaned = openmm_app.PDBFile(io.StringIO(pdb_cleaned_string))
+ reference = openmm_app.PDBFile(io.StringIO(pdb_ref_string))
+
+ cl_xyz = np.array(cleaned.getPositions().value_in_unit(LENGTH))
+ ref_xyz = np.array(reference.getPositions().value_in_unit(LENGTH))
+
+ for ref_res, cl_res in zip(
+ reference.topology.residues(), cleaned.topology.residues()
+ ):
+ assert ref_res.name == cl_res.name
+ for rat in ref_res.atoms():
+ for cat in cl_res.atoms():
+ if cat.name == rat.name:
+ if not np.array_equal(
+ cl_xyz[cat.index], ref_xyz[rat.index]
+ ):
+ raise ValueError(
+ f"Coordinates of cleaned atom {cat} do not match "
+ f"coordinates of reference atom {rat}."
+ )
+
+
+def _check_residues_are_well_defined(prot: protein.Protein):
+ """Checks that all residues contain non-empty atom sets."""
+ if (prot.atom_mask.sum(axis=-1) == 0).any():
+ raise ValueError(
+ "Amber minimization can only be performed on proteins with"
+ " well-defined residues. This protein contains at least"
+ " one residue with no atoms."
+ )
+
+
+def _check_atom_mask_is_ideal(prot):
+ """Sanity-check the atom mask is ideal, up to a possible OXT."""
+ atom_mask = prot.atom_mask
+ ideal_atom_mask = protein.ideal_atom_mask(prot)
+ utils.assert_equal_nonterminal_atom_types(atom_mask, ideal_atom_mask)
+
+
+def clean_protein(prot: protein.Protein, checks: bool = True):
+ """Adds missing atoms to Protein instance.
+
+ Args:
+ prot: A `protein.Protein` instance.
+ checks: A `bool` specifying whether to add additional checks to the cleaning
+ process.
+
+ Returns:
+ pdb_string: A string of the cleaned protein.
+ """
+ _check_atom_mask_is_ideal(prot)
+
+ # Clean pdb.
+ prot_pdb_string = protein.to_pdb(prot)
+ pdb_file = io.StringIO(prot_pdb_string)
+ alterations_info = {}
+ fixed_pdb = cleanup.fix_pdb(pdb_file, alterations_info)
+ fixed_pdb_file = io.StringIO(fixed_pdb)
+ pdb_structure = PdbStructure(fixed_pdb_file)
+ cleanup.clean_structure(pdb_structure, alterations_info)
+
+ logging.info("alterations info: %s", alterations_info)
+
+ # Write pdb file of cleaned structure.
+ as_file = openmm_app.PDBFile(pdb_structure)
+ pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
+ if checks:
+ _check_cleaned_atoms(pdb_string, prot_pdb_string)
+ return pdb_string
+
+
+def make_atom14_positions(prot):
+ """Constructs denser atom positions (14 dimensions instead of 37)."""
+ restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
+ restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
+ restype_atom14_mask = []
+
+ for rt in residue_constants.restypes:
+ atom_names = residue_constants.restype_name_to_atom14_names[
+ residue_constants.restype_1to3[rt]
+ ]
+
+ restype_atom14_to_atom37.append(
+ [
+ (residue_constants.atom_order[name] if name else 0)
+ for name in atom_names
+ ]
+ )
+
+ atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
+ restype_atom37_to_atom14.append(
+ [
+ (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
+ for name in residue_constants.atom_types
+ ]
+ )
+
+ restype_atom14_mask.append(
+ [(1.0 if name else 0.0) for name in atom_names]
+ )
+
+ # Add dummy mapping for restype 'UNK'.
+ restype_atom14_to_atom37.append([0] * 14)
+ restype_atom37_to_atom14.append([0] * 37)
+ restype_atom14_mask.append([0.0] * 14)
+
+ restype_atom14_to_atom37 = np.array(
+ restype_atom14_to_atom37, dtype=int
+ )
+ restype_atom37_to_atom14 = np.array(
+ restype_atom37_to_atom14, dtype=int
+ )
+ restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
+
+ # Create the mapping for (residx, atom14) --> atom37, i.e. an array
+ # with shape (num_res, 14) containing the atom37 indices for this protein.
+ residx_atom14_to_atom37 = restype_atom14_to_atom37[prot["aatype"]]
+ residx_atom14_mask = restype_atom14_mask[prot["aatype"]]
+
+ # Create a mask for known ground truth positions.
+ residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis(
+ prot["all_atom_mask"], residx_atom14_to_atom37, axis=1
+ ).astype(np.float32)
+
+ # Gather the ground truth positions.
+ residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * (
+ np.take_along_axis(
+ prot["all_atom_positions"],
+ residx_atom14_to_atom37[..., None],
+ axis=1,
+ )
+ )
+
+ prot["atom14_atom_exists"] = residx_atom14_mask
+ prot["atom14_gt_exists"] = residx_atom14_gt_mask
+ prot["atom14_gt_positions"] = residx_atom14_gt_positions
+
+ prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37.astype(np.int64)
+
+ # Create the gather indices for mapping back.
+ residx_atom37_to_atom14 = restype_atom37_to_atom14[prot["aatype"]]
+ prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14.astype(np.int64)
+
+ # Create the corresponding mask.
+ restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
+ for restype, restype_letter in enumerate(residue_constants.restypes):
+ restype_name = residue_constants.restype_1to3[restype_letter]
+ atom_names = residue_constants.residue_atoms[restype_name]
+ for atom_name in atom_names:
+ atom_type = residue_constants.atom_order[atom_name]
+ restype_atom37_mask[restype, atom_type] = 1
+
+ residx_atom37_mask = restype_atom37_mask[prot["aatype"]]
+ prot["atom37_atom_exists"] = residx_atom37_mask
+
+ # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
+ # alternative ground truth coordinates where the naming is swapped
+ restype_3 = [
+ residue_constants.restype_1to3[res]
+ for res in residue_constants.restypes
+ ]
+ restype_3 += ["UNK"]
+
+ # Matrices for renaming ambiguous atoms.
+ all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
+ for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
+ correspondences = np.arange(14)
+ for source_atom_swap, target_atom_swap in swap.items():
+ source_index = residue_constants.restype_name_to_atom14_names[
+ resname
+ ].index(source_atom_swap)
+ target_index = residue_constants.restype_name_to_atom14_names[
+ resname
+ ].index(target_atom_swap)
+ correspondences[source_index] = target_index
+ correspondences[target_index] = source_index
+ renaming_matrix = np.zeros((14, 14), dtype=np.float32)
+ for index, correspondence in enumerate(correspondences):
+ renaming_matrix[index, correspondence] = 1.0
+ all_matrices[resname] = renaming_matrix.astype(np.float32)
+ renaming_matrices = np.stack(
+ [all_matrices[restype] for restype in restype_3]
+ )
+
+ # Pick the transformation matrices for the given residue sequence
+ # shape (num_res, 14, 14).
+ renaming_transform = renaming_matrices[prot["aatype"]]
+
+ # Apply it to the ground truth positions. shape (num_res, 14, 3).
+ alternative_gt_positions = np.einsum(
+ "rac,rab->rbc", residx_atom14_gt_positions, renaming_transform
+ )
+ prot["atom14_alt_gt_positions"] = alternative_gt_positions
+
+ # Create the mask for the alternative ground truth (differs from the
+ # ground truth mask, if only one of the atoms in an ambiguous pair has a
+ # ground truth position).
+ alternative_gt_mask = np.einsum(
+ "ra,rab->rb", residx_atom14_gt_mask, renaming_transform
+ )
+
+ prot["atom14_alt_gt_exists"] = alternative_gt_mask
+
+ # Create an ambiguous atoms mask. shape: (21, 14).
+ restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
+ for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
+ for atom_name1, atom_name2 in swap.items():
+ restype = residue_constants.restype_order[
+ residue_constants.restype_3to1[resname]
+ ]
+ atom_idx1 = residue_constants.restype_name_to_atom14_names[
+ resname
+ ].index(atom_name1)
+ atom_idx2 = residue_constants.restype_name_to_atom14_names[
+ resname
+ ].index(atom_name2)
+ restype_atom14_is_ambiguous[restype, atom_idx1] = 1
+ restype_atom14_is_ambiguous[restype, atom_idx2] = 1
+
+ # From this create an ambiguous_mask for the given sequence.
+ prot["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
+ prot["aatype"]
+ ]
+
+ return prot
+
+
+def find_violations(prot_np: protein.Protein):
+ """Analyzes a protein and returns structural violation information.
+
+ Args:
+ prot_np: A protein.
+
+ Returns:
+ violations: A `dict` of structure components with structural violations.
+ violation_metrics: A `dict` of violation metrics.
+ """
+ batch = {
+ "aatype": prot_np.aatype,
+ "all_atom_positions": prot_np.atom_positions.astype(np.float32),
+ "all_atom_mask": prot_np.atom_mask.astype(np.float32),
+ "residue_index": prot_np.residue_index,
+ }
+
+ batch["seq_mask"] = np.ones_like(batch["aatype"], np.float32)
+ batch = make_atom14_positions(batch)
+
+ violations = loss.find_structural_violations_np(
+ batch=batch,
+ atom14_pred_positions=batch["atom14_gt_positions"],
+ config=ml_collections.ConfigDict(
+ {
+ "violation_tolerance_factor": 12, # Taken from model config.
+ "clash_overlap_tolerance": 1.5, # Taken from model config.
+ }
+ ),
+ )
+ violation_metrics = loss.compute_violation_metrics_np(
+ batch=batch,
+ atom14_pred_positions=batch["atom14_gt_positions"],
+ violations=violations,
+ )
+
+ return violations, violation_metrics
+
+
+def get_violation_metrics(prot: protein.Protein):
+ """Computes violation and alignment metrics."""
+ structural_violations, struct_metrics = find_violations(prot)
+ violation_idx = np.flatnonzero(
+ structural_violations["total_per_residue_violations_mask"]
+ )
+
+ struct_metrics["residue_violations"] = violation_idx
+ struct_metrics["num_residue_violations"] = len(violation_idx)
+ struct_metrics["structural_violations"] = structural_violations
+ return struct_metrics
+
+
+def _run_one_iteration(
+ *,
+ pdb_string: str,
+ max_iterations: int,
+ tolerance: float,
+ stiffness: float,
+ restraint_set: str,
+ max_attempts: int,
+ exclude_residues: Optional[Collection[int]] = None,
+ use_gpu: bool,
+):
+ """Runs the minimization pipeline.
+
+ Args:
+ pdb_string: A pdb string.
+ max_iterations: An `int` specifying the maximum number of L-BFGS iterations.
+ A value of 0 specifies no limit.
+ tolerance: kcal/mol, the energy tolerance of L-BFGS.
+ stiffness: kcal/mol A**2, spring constant of heavy atom restraining
+ potential.
+ restraint_set: The set of atoms to restrain.
+ max_attempts: The maximum number of minimization attempts.
+ exclude_residues: An optional list of zero-indexed residues to exclude from
+ restraints.
+ use_gpu: Whether to run relaxation on GPU
+ Returns:
+ A `dict` of minimization info.
+ """
+ exclude_residues = exclude_residues or []
+
+ # Assign physical dimensions.
+ tolerance = tolerance * ENERGY
+ stiffness = stiffness * ENERGY / (LENGTH ** 2)
+
+ start = time.perf_counter()
+ minimized = False
+ attempts = 0
+ while not minimized and attempts < max_attempts:
+ attempts += 1
+ try:
+ logging.info(
+ "Minimizing protein, attempt %d of %d.", attempts, max_attempts
+ )
+ ret = _openmm_minimize(
+ pdb_string,
+ max_iterations=max_iterations,
+ tolerance=tolerance,
+ stiffness=stiffness,
+ restraint_set=restraint_set,
+ exclude_residues=exclude_residues,
+ use_gpu=use_gpu,
+ )
+ minimized = True
+ except Exception as e: # pylint: disable=broad-except
+ print(e)
+ logging.info(e)
+ if not minimized:
+ raise ValueError(f"Minimization failed after {max_attempts} attempts.")
+ ret["opt_time"] = time.perf_counter() - start
+ ret["min_attempts"] = attempts
+ return ret
+
+
+def run_pipeline(
+ prot: protein.Protein,
+ stiffness: float,
+ use_gpu: bool,
+ max_outer_iterations: int = 1,
+ place_hydrogens_every_iteration: bool = True,
+ max_iterations: int = 0,
+ tolerance: float = 2.39,
+ restraint_set: str = "non_hydrogen",
+ max_attempts: int = 100,
+ checks: bool = True,
+ exclude_residues: Optional[Sequence[int]] = None,
+):
+ """Run iterative amber relax.
+
+ Successive relax iterations are performed until all violations have been
+ resolved. Each iteration involves a restrained Amber minimization, with
+ restraint exclusions determined by violation-participating residues.
+
+ Args:
+ prot: A protein to be relaxed.
+ stiffness: kcal/mol A**2, the restraint stiffness.
+ use_gpu: Whether to run on GPU
+ max_outer_iterations: The maximum number of iterative minimization.
+ place_hydrogens_every_iteration: Whether hydrogens are re-initialized
+ prior to every minimization.
+ max_iterations: An `int` specifying the maximum number of L-BFGS steps
+ per relax iteration. A value of 0 specifies no limit.
+ tolerance: kcal/mol, the energy tolerance of L-BFGS.
+ The default value is the OpenMM default.
+ restraint_set: The set of atoms to restrain.
+ max_attempts: The maximum number of minimization attempts per iteration.
+ checks: Whether to perform cleaning checks.
+ exclude_residues: An optional list of zero-indexed residues to exclude from
+ restraints.
+
+ Returns:
+ out: A dictionary of output values.
+ """
+
+ # `protein.to_pdb` will strip any poorly-defined residues so we need to
+ # perform this check before `clean_protein`.
+ _check_residues_are_well_defined(prot)
+ pdb_string = clean_protein(prot, checks=checks)
+
+ exclude_residues = exclude_residues or []
+ exclude_residues = set(exclude_residues)
+ violations = np.inf
+ iteration = 0
+
+ while violations > 0 and iteration < max_outer_iterations:
+ ret = _run_one_iteration(
+ pdb_string=pdb_string,
+ exclude_residues=exclude_residues,
+ max_iterations=max_iterations,
+ tolerance=tolerance,
+ stiffness=stiffness,
+ restraint_set=restraint_set,
+ max_attempts=max_attempts,
+ use_gpu=use_gpu,
+ )
+ prot = protein.from_pdb_string(ret["min_pdb"])
+ if place_hydrogens_every_iteration:
+ pdb_string = clean_protein(prot, checks=True)
+ else:
+ pdb_string = ret["min_pdb"]
+ ret.update(get_violation_metrics(prot))
+ ret.update(
+ {
+ "num_exclusions": len(exclude_residues),
+ "iteration": iteration,
+ }
+ )
+ violations = ret["violations_per_residue"]
+ exclude_residues = exclude_residues.union(ret["residue_violations"])
+
+ logging.info(
+ "Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
+ "num residue violations %d num residue exclusions %d ",
+ ret["einit"],
+ ret["efinal"],
+ ret["opt_time"],
+ ret["num_residue_violations"],
+ ret["num_exclusions"],
+ )
+ iteration += 1
+ return ret
+
+
+def get_initial_energies(
+ pdb_strs: Sequence[str],
+ stiffness: float = 0.0,
+ restraint_set: str = "non_hydrogen",
+ exclude_residues: Optional[Sequence[int]] = None,
+):
+ """Returns initial potential energies for a sequence of PDBs.
+
+ Assumes the input PDBs are ready for minimization, and all have the same
+ topology.
+ Allows time to be saved by not pdbfixing / rebuilding the system.
+
+ Args:
+ pdb_strs: List of PDB strings.
+ stiffness: kcal/mol A**2, spring constant of heavy atom restraining
+ potential.
+ restraint_set: Which atom types to restrain.
+ exclude_residues: An optional list of zero-indexed residues to exclude from
+ restraints.
+
+ Returns:
+ A list of initial energies in the same order as pdb_strs.
+ """
+ exclude_residues = exclude_residues or []
+
+ openmm_pdbs = [
+ openmm_app.PDBFile(PdbStructure(io.StringIO(p))) for p in pdb_strs
+ ]
+ force_field = openmm_app.ForceField("amber99sb.xml")
+ system = force_field.createSystem(
+ openmm_pdbs[0].topology, constraints=openmm_app.HBonds
+ )
+ stiffness = stiffness * ENERGY / (LENGTH ** 2)
+ if stiffness > 0 * ENERGY / (LENGTH ** 2):
+ _add_restraints(
+ system, openmm_pdbs[0], stiffness, restraint_set, exclude_residues
+ )
+ simulation = openmm_app.Simulation(
+ openmm_pdbs[0].topology,
+ system,
+ openmm.LangevinIntegrator(0, 0.01, 0.0),
+ openmm.Platform.getPlatformByName("CPU"),
+ )
+ energies = []
+ for pdb in openmm_pdbs:
+ try:
+ simulation.context.setPositions(pdb.positions)
+ state = simulation.context.getState(getEnergy=True)
+ energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(
+ "Error getting initial energy, returning large value %s", e
+ )
+ energies.append(unit.Quantity(1e20, ENERGY))
+ return energies
diff --git a/openfold/np/relax/cleanup.py b/openfold/np/relax/cleanup.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1f59a48b3b2184c68ccbbce9048b6a137086ca4
--- /dev/null
+++ b/openfold/np/relax/cleanup.py
@@ -0,0 +1,131 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations.
+
+fix_pdb uses a third-party tool. We also support fixing some additional edge
+cases like removing chains of length one (see clean_structure).
+"""
+import io
+
+import pdbfixer
+from simtk.openmm import app
+from simtk.openmm.app import element
+
+
+def fix_pdb(pdbfile, alterations_info):
+ """Apply pdbfixer to the contents of a PDB file; return a PDB string result.
+
+ 1) Replaces nonstandard residues.
+ 2) Removes heterogens (non protein residues) including water.
+ 3) Adds missing residues and missing atoms within existing residues.
+ 4) Adds hydrogens assuming pH=7.0.
+ 5) KeepIds is currently true, so the fixer must keep the existing chain and
+ residue identifiers. This will fail for some files in wider PDB that have
+ invalid IDs.
+
+ Args:
+ pdbfile: Input PDB file handle.
+ alterations_info: A dict that will store details of changes made.
+
+ Returns:
+ A PDB string representing the fixed structure.
+ """
+ fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
+ fixer.findNonstandardResidues()
+ alterations_info["nonstandard_residues"] = fixer.nonstandardResidues
+ fixer.replaceNonstandardResidues()
+ _remove_heterogens(fixer, alterations_info, keep_water=False)
+ fixer.findMissingResidues()
+ alterations_info["missing_residues"] = fixer.missingResidues
+ fixer.findMissingAtoms()
+ alterations_info["missing_heavy_atoms"] = fixer.missingAtoms
+ alterations_info["missing_terminals"] = fixer.missingTerminals
+ fixer.addMissingAtoms(seed=0)
+ fixer.addMissingHydrogens()
+ out_handle = io.StringIO()
+ app.PDBFile.writeFile(
+ fixer.topology, fixer.positions, out_handle, keepIds=True
+ )
+ return out_handle.getvalue()
+
+
+def clean_structure(pdb_structure, alterations_info):
+ """Applies additional fixes to an OpenMM structure, to handle edge cases.
+
+ Args:
+ pdb_structure: An OpenMM structure to modify and fix.
+ alterations_info: A dict that will store details of changes made.
+ """
+ _replace_met_se(pdb_structure, alterations_info)
+ _remove_chains_of_length_one(pdb_structure, alterations_info)
+
+
+def _remove_heterogens(fixer, alterations_info, keep_water):
+ """Removes the residues that Pdbfixer considers to be heterogens.
+
+ Args:
+ fixer: A Pdbfixer instance.
+ alterations_info: A dict that will store details of changes made.
+ keep_water: If True, water (HOH) is not considered to be a heterogen.
+ """
+ initial_resnames = set()
+ for chain in fixer.topology.chains():
+ for residue in chain.residues():
+ initial_resnames.add(residue.name)
+ fixer.removeHeterogens(keepWater=keep_water)
+ final_resnames = set()
+ for chain in fixer.topology.chains():
+ for residue in chain.residues():
+ final_resnames.add(residue.name)
+ alterations_info["removed_heterogens"] = initial_resnames.difference(
+ final_resnames
+ )
+
+
+def _replace_met_se(pdb_structure, alterations_info):
+ """Replace the Se in any MET residues that were not marked as modified."""
+ modified_met_residues = []
+ for res in pdb_structure.iter_residues():
+ name = res.get_name_with_spaces().strip()
+ if name == "MET":
+ s_atom = res.get_atom("SD")
+ if s_atom.element_symbol == "Se":
+ s_atom.element_symbol = "S"
+ s_atom.element = element.get_by_symbol("S")
+ modified_met_residues.append(s_atom.residue_number)
+ alterations_info["Se_in_MET"] = modified_met_residues
+
+
+def _remove_chains_of_length_one(pdb_structure, alterations_info):
+ """Removes chains that correspond to a single amino acid.
+
+ A single amino acid in a chain is both N and C terminus. There is no force
+ template for this case.
+
+ Args:
+ pdb_structure: An OpenMM pdb_structure to modify and fix.
+ alterations_info: A dict that will store details of changes made.
+ """
+ removed_chains = {}
+ for model in pdb_structure.iter_models():
+ valid_chains = [c for c in model.iter_chains() if len(c) > 1]
+ invalid_chain_ids = [
+ c.chain_id for c in model.iter_chains() if len(c) <= 1
+ ]
+ model.chains = valid_chains
+ for chain_id in invalid_chain_ids:
+ model.chains_by_id.pop(chain_id)
+ removed_chains[model.number] = invalid_chain_ids
+ alterations_info["removed_chains"] = removed_chains
diff --git a/openfold/np/relax/relax.py b/openfold/np/relax/relax.py
new file mode 100644
index 0000000000000000000000000000000000000000..8feee742d033ee3165f5975cb14e13c599bfb364
--- /dev/null
+++ b/openfold/np/relax/relax.py
@@ -0,0 +1,90 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Amber relaxation."""
+from typing import Any, Dict, Sequence, Tuple
+from openfold.np import protein
+from openfold.np.relax import amber_minimize, utils
+import numpy as np
+
+
+class AmberRelaxation(object):
+ """Amber relaxation."""
+ def __init__(
+ self,
+ *,
+ max_iterations: int,
+ tolerance: float,
+ stiffness: float,
+ exclude_residues: Sequence[int],
+ max_outer_iterations: int,
+ use_gpu: bool,
+ ):
+ """Initialize Amber Relaxer.
+
+ Args:
+ max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
+ tolerance: kcal/mol, the energy tolerance of L-BFGS.
+ stiffness: kcal/mol A**2, spring constant of heavy atom restraining
+ potential.
+ exclude_residues: Residues to exclude from per-atom restraining.
+ Zero-indexed.
+ max_outer_iterations: Maximum number of violation-informed relax
+ iterations. A value of 1 will run the non-iterative procedure used in
+ CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes
+ as soon as there are no violations, hence in most cases this causes no
+ slowdown. In the worst case we do 20 outer iterations.
+ use_gpu: Whether to run on GPU
+ """
+
+ self._max_iterations = max_iterations
+ self._tolerance = tolerance
+ self._stiffness = stiffness
+ self._exclude_residues = exclude_residues
+ self._max_outer_iterations = max_outer_iterations
+ self._use_gpu = use_gpu
+
+ def process(
+ self, *, prot: protein.Protein
+ ) -> Tuple[str, Dict[str, Any], np.ndarray]:
+ """Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
+ out = amber_minimize.run_pipeline(
+ prot=prot,
+ max_iterations=self._max_iterations,
+ tolerance=self._tolerance,
+ stiffness=self._stiffness,
+ exclude_residues=self._exclude_residues,
+ max_outer_iterations=self._max_outer_iterations,
+ use_gpu=self._use_gpu,
+ )
+ min_pos = out["pos"]
+ start_pos = out["posinit"]
+ rmsd = np.sqrt(np.sum((start_pos - min_pos) ** 2) / start_pos.shape[0])
+ debug_data = {
+ "initial_energy": out["einit"],
+ "final_energy": out["efinal"],
+ "attempts": out["min_attempts"],
+ "rmsd": rmsd,
+ }
+ pdb_str = amber_minimize.clean_protein(prot)
+ min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
+ min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
+ utils.assert_equal_nonterminal_atom_types(
+ protein.from_pdb_string(min_pdb).atom_mask, prot.atom_mask
+ )
+ violations = out["structural_violations"][
+ "total_per_residue_violations_mask"
+ ]
+ return min_pdb, debug_data, violations
diff --git a/openfold/np/relax/utils.py b/openfold/np/relax/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e41aea524a97ddd563dc32963b4671f438c7ae4
--- /dev/null
+++ b/openfold/np/relax/utils.py
@@ -0,0 +1,88 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for minimization."""
+import io
+from openfold.np import residue_constants
+from Bio import PDB
+import numpy as np
+# simtk.openmm is not supported anymore. Remove simtk.
+# https://github.com/openmm/openmm/releases
+from openmm import app as openmm_app
+from openmm.app.internal.pdbstructure import PdbStructure
+
+
+def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
+ pdb_file = io.StringIO(pdb_str)
+ structure = PdbStructure(pdb_file)
+ topology = openmm_app.PDBFile(structure).getTopology()
+ with io.StringIO() as f:
+ openmm_app.PDBFile.writeFile(topology, pos, f)
+ return f.getvalue()
+
+
+def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
+ """Overwrites the B-factors in pdb_str with contents of bfactors array.
+
+ Args:
+ pdb_str: An input PDB string.
+ bfactors: A numpy array with shape [1, n_residues, 37]. We assume that the
+ B-factors are per residue; i.e. that the nonzero entries are identical in
+ [0, i, :].
+
+ Returns:
+ A new PDB string with the B-factors replaced.
+ """
+ if bfactors.shape[-1] != residue_constants.atom_type_num:
+ raise ValueError(
+ f"Invalid final dimension size for bfactors: {bfactors.shape[-1]}."
+ )
+
+ parser = PDB.PDBParser(QUIET=True)
+ handle = io.StringIO(pdb_str)
+ structure = parser.get_structure("", handle)
+
+ curr_resid = ("", "", "")
+ idx = -1
+ for atom in structure.get_atoms():
+ atom_resid = atom.parent.get_id()
+ if atom_resid != curr_resid:
+ idx += 1
+ if idx >= bfactors.shape[0]:
+ raise ValueError(
+ "Index into bfactors exceeds number of residues. "
+ "B-factors shape: {shape}, idx: {idx}."
+ )
+ curr_resid = atom_resid
+ atom.bfactor = bfactors[idx, residue_constants.atom_order["CA"]]
+
+ new_pdb = io.StringIO()
+ pdb_io = PDB.PDBIO()
+ pdb_io.set_structure(structure)
+ pdb_io.save(new_pdb)
+ return new_pdb.getvalue()
+
+
+def assert_equal_nonterminal_atom_types(
+ atom_mask: np.ndarray, ref_atom_mask: np.ndarray
+):
+ """Checks that pre- and post-minimized proteins have same atom set."""
+ # Ignore any terminal OXT atoms which may have been added by minimization.
+ oxt = residue_constants.atom_order["OXT"]
+ no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=np.bool)
+ no_oxt_mask[..., oxt] = False
+ np.testing.assert_almost_equal(
+ ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask]
+ )
diff --git a/openfold/np/residue_constants.py b/openfold/np/residue_constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca3eae6fe48634e2dee88375c4a38c293305c5b0
--- /dev/null
+++ b/openfold/np/residue_constants.py
@@ -0,0 +1,1310 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Constants used in AlphaFold."""
+
+import collections
+import functools
+from typing import Mapping, List, Tuple
+from importlib import resources
+
+import numpy as np
+import tree
+
+# Internal import (35fd).
+
+
+# Distance from one CA to next CA [trans configuration: omega = 180].
+ca_ca = 3.80209737096
+
+# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
+# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
+# chi angles so their chi angle lists are empty.
+chi_angles_atoms = {
+ "ALA": [],
+ # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
+ "ARG": [
+ ["N", "CA", "CB", "CG"],
+ ["CA", "CB", "CG", "CD"],
+ ["CB", "CG", "CD", "NE"],
+ ["CG", "CD", "NE", "CZ"],
+ ],
+ "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
+ "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
+ "CYS": [["N", "CA", "CB", "SG"]],
+ "GLN": [
+ ["N", "CA", "CB", "CG"],
+ ["CA", "CB", "CG", "CD"],
+ ["CB", "CG", "CD", "OE1"],
+ ],
+ "GLU": [
+ ["N", "CA", "CB", "CG"],
+ ["CA", "CB", "CG", "CD"],
+ ["CB", "CG", "CD", "OE1"],
+ ],
+ "GLY": [],
+ "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
+ "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
+ "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
+ "LYS": [
+ ["N", "CA", "CB", "CG"],
+ ["CA", "CB", "CG", "CD"],
+ ["CB", "CG", "CD", "CE"],
+ ["CG", "CD", "CE", "NZ"],
+ ],
+ "MET": [
+ ["N", "CA", "CB", "CG"],
+ ["CA", "CB", "CG", "SD"],
+ ["CB", "CG", "SD", "CE"],
+ ],
+ "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
+ "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
+ "SER": [["N", "CA", "CB", "OG"]],
+ "THR": [["N", "CA", "CB", "OG1"]],
+ "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
+ "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
+ "VAL": [["N", "CA", "CB", "CG1"]],
+}
+
+# If chi angles given in fixed-length array, this matrix determines how to mask
+# them for each AA type. The order is as per restype_order (see below).
+chi_angles_mask = [
+ [0.0, 0.0, 0.0, 0.0], # ALA
+ [1.0, 1.0, 1.0, 1.0], # ARG
+ [1.0, 1.0, 0.0, 0.0], # ASN
+ [1.0, 1.0, 0.0, 0.0], # ASP
+ [1.0, 0.0, 0.0, 0.0], # CYS
+ [1.0, 1.0, 1.0, 0.0], # GLN
+ [1.0, 1.0, 1.0, 0.0], # GLU
+ [0.0, 0.0, 0.0, 0.0], # GLY
+ [1.0, 1.0, 0.0, 0.0], # HIS
+ [1.0, 1.0, 0.0, 0.0], # ILE
+ [1.0, 1.0, 0.0, 0.0], # LEU
+ [1.0, 1.0, 1.0, 1.0], # LYS
+ [1.0, 1.0, 1.0, 0.0], # MET
+ [1.0, 1.0, 0.0, 0.0], # PHE
+ [1.0, 1.0, 0.0, 0.0], # PRO
+ [1.0, 0.0, 0.0, 0.0], # SER
+ [1.0, 0.0, 0.0, 0.0], # THR
+ [1.0, 1.0, 0.0, 0.0], # TRP
+ [1.0, 1.0, 0.0, 0.0], # TYR
+ [1.0, 0.0, 0.0, 0.0], # VAL
+]
+
+# The following chi angles are pi periodic: they can be rotated by a multiple
+# of pi without affecting the structure.
+chi_pi_periodic = [
+ [0.0, 0.0, 0.0, 0.0], # ALA
+ [0.0, 0.0, 0.0, 0.0], # ARG
+ [0.0, 0.0, 0.0, 0.0], # ASN
+ [0.0, 1.0, 0.0, 0.0], # ASP
+ [0.0, 0.0, 0.0, 0.0], # CYS
+ [0.0, 0.0, 0.0, 0.0], # GLN
+ [0.0, 0.0, 1.0, 0.0], # GLU
+ [0.0, 0.0, 0.0, 0.0], # GLY
+ [0.0, 0.0, 0.0, 0.0], # HIS
+ [0.0, 0.0, 0.0, 0.0], # ILE
+ [0.0, 0.0, 0.0, 0.0], # LEU
+ [0.0, 0.0, 0.0, 0.0], # LYS
+ [0.0, 0.0, 0.0, 0.0], # MET
+ [0.0, 1.0, 0.0, 0.0], # PHE
+ [0.0, 0.0, 0.0, 0.0], # PRO
+ [0.0, 0.0, 0.0, 0.0], # SER
+ [0.0, 0.0, 0.0, 0.0], # THR
+ [0.0, 0.0, 0.0, 0.0], # TRP
+ [0.0, 1.0, 0.0, 0.0], # TYR
+ [0.0, 0.0, 0.0, 0.0], # VAL
+ [0.0, 0.0, 0.0, 0.0], # UNK
+]
+
+# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
+# psi and chi angles:
+# 0: 'backbone group',
+# 1: 'pre-omega-group', (empty)
+# 2: 'phi-group', (currently empty, because it defines only hydrogens)
+# 3: 'psi-group',
+# 4,5,6,7: 'chi1,2,3,4-group'
+# The atom positions are relative to the axis-end-atom of the corresponding
+# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
+# is defined such that the dihedral-angle-definiting atom (the last entry in
+# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
+# format: [atomname, group_idx, rel_position]
+rigid_group_atom_positions = {
+ "ALA": [
+ ["N", 0, (-0.525, 1.363, 0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.526, -0.000, -0.000)],
+ ["CB", 0, (-0.529, -0.774, -1.205)],
+ ["O", 3, (0.627, 1.062, 0.000)],
+ ],
+ "ARG": [
+ ["N", 0, (-0.524, 1.362, -0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.525, -0.000, -0.000)],
+ ["CB", 0, (-0.524, -0.778, -1.209)],
+ ["O", 3, (0.626, 1.062, 0.000)],
+ ["CG", 4, (0.616, 1.390, -0.000)],
+ ["CD", 5, (0.564, 1.414, 0.000)],
+ ["NE", 6, (0.539, 1.357, -0.000)],
+ ["NH1", 7, (0.206, 2.301, 0.000)],
+ ["NH2", 7, (2.078, 0.978, -0.000)],
+ ["CZ", 7, (0.758, 1.093, -0.000)],
+ ],
+ "ASN": [
+ ["N", 0, (-0.536, 1.357, 0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.526, -0.000, -0.000)],
+ ["CB", 0, (-0.531, -0.787, -1.200)],
+ ["O", 3, (0.625, 1.062, 0.000)],
+ ["CG", 4, (0.584, 1.399, 0.000)],
+ ["ND2", 5, (0.593, -1.188, 0.001)],
+ ["OD1", 5, (0.633, 1.059, 0.000)],
+ ],
+ "ASP": [
+ ["N", 0, (-0.525, 1.362, -0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.527, 0.000, -0.000)],
+ ["CB", 0, (-0.526, -0.778, -1.208)],
+ ["O", 3, (0.626, 1.062, -0.000)],
+ ["CG", 4, (0.593, 1.398, -0.000)],
+ ["OD1", 5, (0.610, 1.091, 0.000)],
+ ["OD2", 5, (0.592, -1.101, -0.003)],
+ ],
+ "CYS": [
+ ["N", 0, (-0.522, 1.362, -0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.524, 0.000, 0.000)],
+ ["CB", 0, (-0.519, -0.773, -1.212)],
+ ["O", 3, (0.625, 1.062, -0.000)],
+ ["SG", 4, (0.728, 1.653, 0.000)],
+ ],
+ "GLN": [
+ ["N", 0, (-0.526, 1.361, -0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.526, 0.000, 0.000)],
+ ["CB", 0, (-0.525, -0.779, -1.207)],
+ ["O", 3, (0.626, 1.062, -0.000)],
+ ["CG", 4, (0.615, 1.393, 0.000)],
+ ["CD", 5, (0.587, 1.399, -0.000)],
+ ["NE2", 6, (0.593, -1.189, -0.001)],
+ ["OE1", 6, (0.634, 1.060, 0.000)],
+ ],
+ "GLU": [
+ ["N", 0, (-0.528, 1.361, 0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.526, -0.000, -0.000)],
+ ["CB", 0, (-0.526, -0.781, -1.207)],
+ ["O", 3, (0.626, 1.062, 0.000)],
+ ["CG", 4, (0.615, 1.392, 0.000)],
+ ["CD", 5, (0.600, 1.397, 0.000)],
+ ["OE1", 6, (0.607, 1.095, -0.000)],
+ ["OE2", 6, (0.589, -1.104, -0.001)],
+ ],
+ "GLY": [
+ ["N", 0, (-0.572, 1.337, 0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.517, -0.000, -0.000)],
+ ["O", 3, (0.626, 1.062, -0.000)],
+ ],
+ "HIS": [
+ ["N", 0, (-0.527, 1.360, 0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.525, 0.000, 0.000)],
+ ["CB", 0, (-0.525, -0.778, -1.208)],
+ ["O", 3, (0.625, 1.063, 0.000)],
+ ["CG", 4, (0.600, 1.370, -0.000)],
+ ["CD2", 5, (0.889, -1.021, 0.003)],
+ ["ND1", 5, (0.744, 1.160, -0.000)],
+ ["CE1", 5, (2.030, 0.851, 0.002)],
+ ["NE2", 5, (2.145, -0.466, 0.004)],
+ ],
+ "ILE": [
+ ["N", 0, (-0.493, 1.373, -0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.527, -0.000, -0.000)],
+ ["CB", 0, (-0.536, -0.793, -1.213)],
+ ["O", 3, (0.627, 1.062, -0.000)],
+ ["CG1", 4, (0.534, 1.437, -0.000)],
+ ["CG2", 4, (0.540, -0.785, -1.199)],
+ ["CD1", 5, (0.619, 1.391, 0.000)],
+ ],
+ "LEU": [
+ ["N", 0, (-0.520, 1.363, 0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.525, -0.000, -0.000)],
+ ["CB", 0, (-0.522, -0.773, -1.214)],
+ ["O", 3, (0.625, 1.063, -0.000)],
+ ["CG", 4, (0.678, 1.371, 0.000)],
+ ["CD1", 5, (0.530, 1.430, -0.000)],
+ ["CD2", 5, (0.535, -0.774, 1.200)],
+ ],
+ "LYS": [
+ ["N", 0, (-0.526, 1.362, -0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.526, 0.000, 0.000)],
+ ["CB", 0, (-0.524, -0.778, -1.208)],
+ ["O", 3, (0.626, 1.062, -0.000)],
+ ["CG", 4, (0.619, 1.390, 0.000)],
+ ["CD", 5, (0.559, 1.417, 0.000)],
+ ["CE", 6, (0.560, 1.416, 0.000)],
+ ["NZ", 7, (0.554, 1.387, 0.000)],
+ ],
+ "MET": [
+ ["N", 0, (-0.521, 1.364, -0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.525, 0.000, 0.000)],
+ ["CB", 0, (-0.523, -0.776, -1.210)],
+ ["O", 3, (0.625, 1.062, -0.000)],
+ ["CG", 4, (0.613, 1.391, -0.000)],
+ ["SD", 5, (0.703, 1.695, 0.000)],
+ ["CE", 6, (0.320, 1.786, -0.000)],
+ ],
+ "PHE": [
+ ["N", 0, (-0.518, 1.363, 0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.524, 0.000, -0.000)],
+ ["CB", 0, (-0.525, -0.776, -1.212)],
+ ["O", 3, (0.626, 1.062, -0.000)],
+ ["CG", 4, (0.607, 1.377, 0.000)],
+ ["CD1", 5, (0.709, 1.195, -0.000)],
+ ["CD2", 5, (0.706, -1.196, 0.000)],
+ ["CE1", 5, (2.102, 1.198, -0.000)],
+ ["CE2", 5, (2.098, -1.201, -0.000)],
+ ["CZ", 5, (2.794, -0.003, -0.001)],
+ ],
+ "PRO": [
+ ["N", 0, (-0.566, 1.351, -0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.527, -0.000, 0.000)],
+ ["CB", 0, (-0.546, -0.611, -1.293)],
+ ["O", 3, (0.621, 1.066, 0.000)],
+ ["CG", 4, (0.382, 1.445, 0.0)],
+ # ['CD', 5, (0.427, 1.440, 0.0)],
+ ["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
+ ],
+ "SER": [
+ ["N", 0, (-0.529, 1.360, -0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.525, -0.000, -0.000)],
+ ["CB", 0, (-0.518, -0.777, -1.211)],
+ ["O", 3, (0.626, 1.062, -0.000)],
+ ["OG", 4, (0.503, 1.325, 0.000)],
+ ],
+ "THR": [
+ ["N", 0, (-0.517, 1.364, 0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.526, 0.000, -0.000)],
+ ["CB", 0, (-0.516, -0.793, -1.215)],
+ ["O", 3, (0.626, 1.062, 0.000)],
+ ["CG2", 4, (0.550, -0.718, -1.228)],
+ ["OG1", 4, (0.472, 1.353, 0.000)],
+ ],
+ "TRP": [
+ ["N", 0, (-0.521, 1.363, 0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.525, -0.000, 0.000)],
+ ["CB", 0, (-0.523, -0.776, -1.212)],
+ ["O", 3, (0.627, 1.062, 0.000)],
+ ["CG", 4, (0.609, 1.370, -0.000)],
+ ["CD1", 5, (0.824, 1.091, 0.000)],
+ ["CD2", 5, (0.854, -1.148, -0.005)],
+ ["CE2", 5, (2.186, -0.678, -0.007)],
+ ["CE3", 5, (0.622, -2.530, -0.007)],
+ ["NE1", 5, (2.140, 0.690, -0.004)],
+ ["CH2", 5, (3.028, -2.890, -0.013)],
+ ["CZ2", 5, (3.283, -1.543, -0.011)],
+ ["CZ3", 5, (1.715, -3.389, -0.011)],
+ ],
+ "TYR": [
+ ["N", 0, (-0.522, 1.362, 0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.524, -0.000, -0.000)],
+ ["CB", 0, (-0.522, -0.776, -1.213)],
+ ["O", 3, (0.627, 1.062, -0.000)],
+ ["CG", 4, (0.607, 1.382, -0.000)],
+ ["CD1", 5, (0.716, 1.195, -0.000)],
+ ["CD2", 5, (0.713, -1.194, -0.001)],
+ ["CE1", 5, (2.107, 1.200, -0.002)],
+ ["CE2", 5, (2.104, -1.201, -0.003)],
+ ["OH", 5, (4.168, -0.002, -0.005)],
+ ["CZ", 5, (2.791, -0.001, -0.003)],
+ ],
+ "VAL": [
+ ["N", 0, (-0.494, 1.373, -0.000)],
+ ["CA", 0, (0.000, 0.000, 0.000)],
+ ["C", 0, (1.527, -0.000, -0.000)],
+ ["CB", 0, (-0.533, -0.795, -1.213)],
+ ["O", 3, (0.627, 1.062, -0.000)],
+ ["CG1", 4, (0.540, 1.429, -0.000)],
+ ["CG2", 4, (0.533, -0.776, 1.203)],
+ ],
+}
+
+# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
+residue_atoms = {
+ "ALA": ["C", "CA", "CB", "N", "O"],
+ "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
+ "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
+ "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
+ "CYS": ["C", "CA", "CB", "N", "O", "SG"],
+ "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
+ "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
+ "GLY": ["C", "CA", "N", "O"],
+ "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
+ "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
+ "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
+ "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
+ "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
+ "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
+ "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
+ "SER": ["C", "CA", "CB", "N", "O", "OG"],
+ "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
+ "TRP": [
+ "C",
+ "CA",
+ "CB",
+ "CG",
+ "CD1",
+ "CD2",
+ "CE2",
+ "CE3",
+ "CZ2",
+ "CZ3",
+ "CH2",
+ "N",
+ "NE1",
+ "O",
+ ],
+ "TYR": [
+ "C",
+ "CA",
+ "CB",
+ "CG",
+ "CD1",
+ "CD2",
+ "CE1",
+ "CE2",
+ "CZ",
+ "N",
+ "O",
+ "OH",
+ ],
+ "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
+}
+
+# Naming swaps for ambiguous atom names.
+# Due to symmetries in the amino acids the naming of atoms is ambiguous in
+# 4 of the 20 amino acids.
+# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
+# in LEU, VAL and ARG can be resolved by using the 3d constellations of
+# the 'ambiguous' atoms and their neighbours)
+# TODO: ^ interpret this
+residue_atom_renaming_swaps = {
+ "ASP": {"OD1": "OD2"},
+ "GLU": {"OE1": "OE2"},
+ "PHE": {"CD1": "CD2", "CE1": "CE2"},
+ "TYR": {"CD1": "CD2", "CE1": "CE2"},
+}
+
+# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
+van_der_waals_radius = {
+ "C": 1.7,
+ "N": 1.55,
+ "O": 1.52,
+ "S": 1.8,
+}
+
+Bond = collections.namedtuple(
+ "Bond", ["atom1_name", "atom2_name", "length", "stddev"]
+)
+BondAngle = collections.namedtuple(
+ "BondAngle",
+ ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"],
+)
+
+
+@functools.lru_cache(maxsize=None)
+def load_stereo_chemical_props() -> Tuple[
+ Mapping[str, List[Bond]],
+ Mapping[str, List[Bond]],
+ Mapping[str, List[BondAngle]],
+]:
+ """Load stereo_chemical_props.txt into a nice structure.
+
+ Load literature values for bond lengths and bond angles and translate
+ bond angles into the length of the opposite edge of the triangle
+ ("residue_virtual_bonds").
+
+ Returns:
+ residue_bonds: dict that maps resname --> list of Bond tuples
+ residue_virtual_bonds: dict that maps resname --> list of Bond tuples
+ residue_bond_angles: dict that maps resname --> list of BondAngle tuples
+ """
+ # TODO: this file should be downloaded in a setup script
+ stereo_chemical_props = resources.read_text("openfold.resources", "stereo_chemical_props.txt")
+
+ lines_iter = iter(stereo_chemical_props.splitlines())
+ # Load bond lengths.
+ residue_bonds = {}
+ next(lines_iter) # Skip header line.
+ for line in lines_iter:
+ if line.strip() == "-":
+ break
+ bond, resname, length, stddev = line.split()
+ atom1, atom2 = bond.split("-")
+ if resname not in residue_bonds:
+ residue_bonds[resname] = []
+ residue_bonds[resname].append(
+ Bond(atom1, atom2, float(length), float(stddev))
+ )
+ residue_bonds["UNK"] = []
+
+ # Load bond angles.
+ residue_bond_angles = {}
+ next(lines_iter) # Skip empty line.
+ next(lines_iter) # Skip header line.
+ for line in lines_iter:
+ if line.strip() == "-":
+ break
+ bond, resname, angle_degree, stddev_degree = line.split()
+ atom1, atom2, atom3 = bond.split("-")
+ if resname not in residue_bond_angles:
+ residue_bond_angles[resname] = []
+ residue_bond_angles[resname].append(
+ BondAngle(
+ atom1,
+ atom2,
+ atom3,
+ float(angle_degree) / 180.0 * np.pi,
+ float(stddev_degree) / 180.0 * np.pi,
+ )
+ )
+ residue_bond_angles["UNK"] = []
+
+ def make_bond_key(atom1_name, atom2_name):
+ """Unique key to lookup bonds."""
+ return "-".join(sorted([atom1_name, atom2_name]))
+
+ # Translate bond angles into distances ("virtual bonds").
+ residue_virtual_bonds = {}
+ for resname, bond_angles in residue_bond_angles.items():
+ # Create a fast lookup dict for bond lengths.
+ bond_cache = {}
+ for b in residue_bonds[resname]:
+ bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
+ residue_virtual_bonds[resname] = []
+ for ba in bond_angles:
+ bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
+ bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
+
+ # Compute distance between atom1 and atom3 using the law of cosines
+ # c^2 = a^2 + b^2 - 2ab*cos(gamma).
+ gamma = ba.angle_rad
+ length = np.sqrt(
+ bond1.length ** 2
+ + bond2.length ** 2
+ - 2 * bond1.length * bond2.length * np.cos(gamma)
+ )
+
+ # Propagation of uncertainty assuming uncorrelated errors.
+ dl_outer = 0.5 / length
+ dl_dgamma = (
+ 2 * bond1.length * bond2.length * np.sin(gamma)
+ ) * dl_outer
+ dl_db1 = (
+ 2 * bond1.length - 2 * bond2.length * np.cos(gamma)
+ ) * dl_outer
+ dl_db2 = (
+ 2 * bond2.length - 2 * bond1.length * np.cos(gamma)
+ ) * dl_outer
+ stddev = np.sqrt(
+ (dl_dgamma * ba.stddev) ** 2
+ + (dl_db1 * bond1.stddev) ** 2
+ + (dl_db2 * bond2.stddev) ** 2
+ )
+ residue_virtual_bonds[resname].append(
+ Bond(ba.atom1_name, ba.atom3name, length, stddev)
+ )
+
+ return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
+
+
+# Between-residue bond lengths for general bonds (first element) and for Proline
+# (second element).
+between_res_bond_length_c_n = [1.329, 1.341]
+between_res_bond_length_stddev_c_n = [0.014, 0.016]
+
+# Between-residue cos_angles.
+between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315
+between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
+
+# This mapping is used when we need to store atom data in a format that requires
+# fixed atom data size for every residue (e.g. a numpy array).
+atom_types = [
+ "N",
+ "CA",
+ "C",
+ "CB",
+ "O",
+ "CG",
+ "CG1",
+ "CG2",
+ "OG",
+ "OG1",
+ "SG",
+ "CD",
+ "CD1",
+ "CD2",
+ "ND1",
+ "ND2",
+ "OD1",
+ "OD2",
+ "SD",
+ "CE",
+ "CE1",
+ "CE2",
+ "CE3",
+ "NE",
+ "NE1",
+ "NE2",
+ "OE1",
+ "OE2",
+ "CH2",
+ "NH1",
+ "NH2",
+ "OH",
+ "CZ",
+ "CZ2",
+ "CZ3",
+ "NZ",
+ "OXT",
+]
+atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
+atom_type_num = len(atom_types) # := 37.
+
+# A compact atom encoding with 14 columns
+# pylint: disable=line-too-long
+# pylint: disable=bad-whitespace
+restype_name_to_atom14_names = {
+ "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
+ "ARG": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "CG",
+ "CD",
+ "NE",
+ "CZ",
+ "NH1",
+ "NH2",
+ "",
+ "",
+ "",
+ ],
+ "ASN": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "CG",
+ "OD1",
+ "ND2",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ],
+ "ASP": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "CG",
+ "OD1",
+ "OD2",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ],
+ "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
+ "GLN": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "CG",
+ "CD",
+ "OE1",
+ "NE2",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ],
+ "GLU": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "CG",
+ "CD",
+ "OE1",
+ "OE2",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ],
+ "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
+ "HIS": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "CG",
+ "ND1",
+ "CD2",
+ "CE1",
+ "NE2",
+ "",
+ "",
+ "",
+ "",
+ ],
+ "ILE": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "CG1",
+ "CG2",
+ "CD1",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ],
+ "LEU": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "CG",
+ "CD1",
+ "CD2",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ],
+ "LYS": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "CG",
+ "CD",
+ "CE",
+ "NZ",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ],
+ "MET": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "CG",
+ "SD",
+ "CE",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ],
+ "PHE": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "CG",
+ "CD1",
+ "CD2",
+ "CE1",
+ "CE2",
+ "CZ",
+ "",
+ "",
+ "",
+ ],
+ "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
+ "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
+ "THR": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "OG1",
+ "CG2",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ],
+ "TRP": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "CG",
+ "CD1",
+ "CD2",
+ "NE1",
+ "CE2",
+ "CE3",
+ "CZ2",
+ "CZ3",
+ "CH2",
+ ],
+ "TYR": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "CG",
+ "CD1",
+ "CD2",
+ "CE1",
+ "CE2",
+ "CZ",
+ "OH",
+ "",
+ "",
+ ],
+ "VAL": [
+ "N",
+ "CA",
+ "C",
+ "O",
+ "CB",
+ "CG1",
+ "CG2",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ],
+ "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
+}
+# pylint: enable=line-too-long
+# pylint: enable=bad-whitespace
+
+
+# This is the standard residue order when coding AA type as a number.
+# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
+restypes = [
+ "A",
+ "R",
+ "N",
+ "D",
+ "C",
+ "Q",
+ "E",
+ "G",
+ "H",
+ "I",
+ "L",
+ "K",
+ "M",
+ "F",
+ "P",
+ "S",
+ "T",
+ "W",
+ "Y",
+ "V",
+]
+restype_order = {restype: i for i, restype in enumerate(restypes)}
+restype_num = len(restypes) # := 20.
+unk_restype_index = restype_num # Catch-all index for unknown restypes.
+
+restypes_with_x = restypes + ["X"]
+restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
+
+
+def sequence_to_onehot(
+ sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
+) -> np.ndarray:
+ """Maps the given sequence into a one-hot encoded matrix.
+
+ Args:
+ sequence: An amino acid sequence.
+ mapping: A dictionary mapping amino acids to integers.
+ map_unknown_to_x: If True, any amino acid that is not in the mapping will be
+ mapped to the unknown amino acid 'X'. If the mapping doesn't contain
+ amino acid 'X', an error will be thrown. If False, any amino acid not in
+ the mapping will throw an error.
+
+ Returns:
+ A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
+ the sequence.
+
+ Raises:
+ ValueError: If the mapping doesn't contain values from 0 to
+ num_unique_aas - 1 without any gaps.
+ """
+ num_entries = max(mapping.values()) + 1
+
+ if sorted(set(mapping.values())) != list(range(num_entries)):
+ raise ValueError(
+ "The mapping must have values from 0 to num_unique_aas-1 "
+ "without any gaps. Got: %s" % sorted(mapping.values())
+ )
+
+ one_hot_arr = np.zeros((len(sequence), num_entries), dtype=int)
+
+ for aa_index, aa_type in enumerate(sequence):
+ if map_unknown_to_x:
+ if aa_type.isalpha() and aa_type.isupper():
+ aa_id = mapping.get(aa_type, mapping["X"])
+ else:
+ raise ValueError(
+ f"Invalid character in the sequence: {aa_type}"
+ )
+ else:
+ aa_id = mapping[aa_type]
+ one_hot_arr[aa_index, aa_id] = 1
+
+ return one_hot_arr
+
+
+restype_1to3 = {
+ "A": "ALA",
+ "R": "ARG",
+ "N": "ASN",
+ "D": "ASP",
+ "C": "CYS",
+ "Q": "GLN",
+ "E": "GLU",
+ "G": "GLY",
+ "H": "HIS",
+ "I": "ILE",
+ "L": "LEU",
+ "K": "LYS",
+ "M": "MET",
+ "F": "PHE",
+ "P": "PRO",
+ "S": "SER",
+ "T": "THR",
+ "W": "TRP",
+ "Y": "TYR",
+ "V": "VAL",
+}
+
+
+# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
+# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
+# many more, and less common, three letter names as keys and maps many of these
+# to the same one letter name (including 'X' and 'U' which we don't use here).
+restype_3to1 = {v: k for k, v in restype_1to3.items()}
+
+# Define a restype name for all unknown residues.
+unk_restype = "UNK"
+
+resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
+resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
+
+
+# The mapping here uses hhblits convention, so that B is mapped to D, J and O
+# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
+# remaining 20 amino acids are kept in alphabetical order.
+# There are 2 non-amino acid codes, X (representing any amino acid) and
+# "-" representing a missing amino acid in an alignment. The id for these
+# codes is put at the end (20 and 21) so that they can easily be ignored if
+# desired.
+HHBLITS_AA_TO_ID = {
+ "A": 0,
+ "B": 2,
+ "C": 1,
+ "D": 2,
+ "E": 3,
+ "F": 4,
+ "G": 5,
+ "H": 6,
+ "I": 7,
+ "J": 20,
+ "K": 8,
+ "L": 9,
+ "M": 10,
+ "N": 11,
+ "O": 20,
+ "P": 12,
+ "Q": 13,
+ "R": 14,
+ "S": 15,
+ "T": 16,
+ "U": 1,
+ "V": 17,
+ "W": 18,
+ "X": 20,
+ "Y": 19,
+ "Z": 3,
+ "-": 21,
+}
+
+# Partial inversion of HHBLITS_AA_TO_ID.
+ID_TO_HHBLITS_AA = {
+ 0: "A",
+ 1: "C", # Also U.
+ 2: "D", # Also B.
+ 3: "E", # Also Z.
+ 4: "F",
+ 5: "G",
+ 6: "H",
+ 7: "I",
+ 8: "K",
+ 9: "L",
+ 10: "M",
+ 11: "N",
+ 12: "P",
+ 13: "Q",
+ 14: "R",
+ 15: "S",
+ 16: "T",
+ 17: "V",
+ 18: "W",
+ 19: "Y",
+ 20: "X", # Includes J and O.
+ 21: "-",
+}
+
+restypes_with_x_and_gap = restypes + ["X", "-"]
+MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
+ restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
+ for i in range(len(restypes_with_x_and_gap))
+)
+
+
+def _make_standard_atom_mask() -> np.ndarray:
+ """Returns [num_res_types, num_atom_types] mask array."""
+ # +1 to account for unknown (all 0s).
+ mask = np.zeros([restype_num + 1, atom_type_num], dtype=int)
+ for restype, restype_letter in enumerate(restypes):
+ restype_name = restype_1to3[restype_letter]
+ atom_names = residue_atoms[restype_name]
+ for atom_name in atom_names:
+ atom_type = atom_order[atom_name]
+ mask[restype, atom_type] = 1
+ return mask
+
+
+STANDARD_ATOM_MASK = _make_standard_atom_mask()
+
+
+# A one hot representation for the first and second atoms defining the axis
+# of rotation for each chi-angle in each residue.
+def chi_angle_atom(atom_index: int) -> np.ndarray:
+ """Define chi-angle rigid groups via one-hot representations."""
+ chi_angles_index = {}
+ one_hots = []
+
+ for k, v in chi_angles_atoms.items():
+ indices = [atom_types.index(s[atom_index]) for s in v]
+ indices.extend([-1] * (4 - len(indices)))
+ chi_angles_index[k] = indices
+
+ for r in restypes:
+ res3 = restype_1to3[r]
+ one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
+ one_hots.append(one_hot)
+
+ one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
+ one_hot = np.stack(one_hots, axis=0)
+ one_hot = np.transpose(one_hot, [0, 2, 1])
+
+ return one_hot
+
+
+chi_atom_1_one_hot = chi_angle_atom(1)
+chi_atom_2_one_hot = chi_angle_atom(2)
+
+# An array like chi_angles_atoms but using indices rather than names.
+chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
+chi_angles_atom_indices = tree.map_structure(
+ lambda atom_name: atom_order[atom_name], chi_angles_atom_indices
+)
+chi_angles_atom_indices = np.array(
+ [
+ chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
+ for chi_atoms in chi_angles_atom_indices
+ ]
+)
+
+# Mapping from (res_name, atom_name) pairs to the atom's chi group index
+# and atom index within that group.
+chi_groups_for_atom = collections.defaultdict(list)
+for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
+ for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
+ for atom_i, atom in enumerate(chi_group):
+ chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
+chi_groups_for_atom = dict(chi_groups_for_atom)
+
+
+def _make_rigid_transformation_4x4(ex, ey, translation):
+ """Create a rigid 4x4 transformation matrix from two axes and transl."""
+ # Normalize ex.
+ ex_normalized = ex / np.linalg.norm(ex)
+
+ # make ey perpendicular to ex
+ ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
+ ey_normalized /= np.linalg.norm(ey_normalized)
+
+ # compute ez as cross product
+ eznorm = np.cross(ex_normalized, ey_normalized)
+ m = np.stack(
+ [ex_normalized, ey_normalized, eznorm, translation]
+ ).transpose()
+ m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
+ return m
+
+
+# create an array with (restype, atomtype) --> rigid_group_idx
+# and an array with (restype, atomtype, coord) for the atom positions
+# and compute affine transformation matrices (4,4) from one rigid group to the
+# previous group
+restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
+restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
+restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
+restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
+restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
+restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
+restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
+
+
+def _make_rigid_group_constants():
+ """Fill the arrays above."""
+ for restype, restype_letter in enumerate(restypes):
+ resname = restype_1to3[restype_letter]
+ for atomname, group_idx, atom_position in rigid_group_atom_positions[
+ resname
+ ]:
+ atomtype = atom_order[atomname]
+ restype_atom37_to_rigid_group[restype, atomtype] = group_idx
+ restype_atom37_mask[restype, atomtype] = 1
+ restype_atom37_rigid_group_positions[
+ restype, atomtype, :
+ ] = atom_position
+
+ atom14idx = restype_name_to_atom14_names[resname].index(atomname)
+ restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
+ restype_atom14_mask[restype, atom14idx] = 1
+ restype_atom14_rigid_group_positions[
+ restype, atom14idx, :
+ ] = atom_position
+
+ for restype, restype_letter in enumerate(restypes):
+ resname = restype_1to3[restype_letter]
+ atom_positions = {
+ name: np.array(pos)
+ for name, _, pos in rigid_group_atom_positions[resname]
+ }
+
+ # backbone to backbone is the identity transform
+ restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
+
+ # pre-omega-frame to backbone (currently dummy identity matrix)
+ restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
+
+ # phi-frame to backbone
+ mat = _make_rigid_transformation_4x4(
+ ex=atom_positions["N"] - atom_positions["CA"],
+ ey=np.array([1.0, 0.0, 0.0]),
+ translation=atom_positions["N"],
+ )
+ restype_rigid_group_default_frame[restype, 2, :, :] = mat
+
+ # psi-frame to backbone
+ mat = _make_rigid_transformation_4x4(
+ ex=atom_positions["C"] - atom_positions["CA"],
+ ey=atom_positions["CA"] - atom_positions["N"],
+ translation=atom_positions["C"],
+ )
+ restype_rigid_group_default_frame[restype, 3, :, :] = mat
+
+ # chi1-frame to backbone
+ if chi_angles_mask[restype][0]:
+ base_atom_names = chi_angles_atoms[resname][0]
+ base_atom_positions = [
+ atom_positions[name] for name in base_atom_names
+ ]
+ mat = _make_rigid_transformation_4x4(
+ ex=base_atom_positions[2] - base_atom_positions[1],
+ ey=base_atom_positions[0] - base_atom_positions[1],
+ translation=base_atom_positions[2],
+ )
+ restype_rigid_group_default_frame[restype, 4, :, :] = mat
+
+ # chi2-frame to chi1-frame
+ # chi3-frame to chi2-frame
+ # chi4-frame to chi3-frame
+ # luckily all rotation axes for the next frame start at (0,0,0) of the
+ # previous frame
+ for chi_idx in range(1, 4):
+ if chi_angles_mask[restype][chi_idx]:
+ axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
+ axis_end_atom_position = atom_positions[axis_end_atom_name]
+ mat = _make_rigid_transformation_4x4(
+ ex=axis_end_atom_position,
+ ey=np.array([-1.0, 0.0, 0.0]),
+ translation=axis_end_atom_position,
+ )
+ restype_rigid_group_default_frame[
+ restype, 4 + chi_idx, :, :
+ ] = mat
+
+
+_make_rigid_group_constants()
+
+
+def make_atom14_dists_bounds(
+ overlap_tolerance=1.5, bond_length_tolerance_factor=15
+):
+ """compute upper and lower bounds for bonds to assess violations."""
+ restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
+ restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
+ restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
+ residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
+ for restype, restype_letter in enumerate(restypes):
+ resname = restype_1to3[restype_letter]
+ atom_list = restype_name_to_atom14_names[resname]
+
+ # create lower and upper bounds for clashes
+ for atom1_idx, atom1_name in enumerate(atom_list):
+ if not atom1_name:
+ continue
+ atom1_radius = van_der_waals_radius[atom1_name[0]]
+ for atom2_idx, atom2_name in enumerate(atom_list):
+ if (not atom2_name) or atom1_idx == atom2_idx:
+ continue
+ atom2_radius = van_der_waals_radius[atom2_name[0]]
+ lower = atom1_radius + atom2_radius - overlap_tolerance
+ upper = 1e10
+ restype_atom14_bond_lower_bound[
+ restype, atom1_idx, atom2_idx
+ ] = lower
+ restype_atom14_bond_lower_bound[
+ restype, atom2_idx, atom1_idx
+ ] = lower
+ restype_atom14_bond_upper_bound[
+ restype, atom1_idx, atom2_idx
+ ] = upper
+ restype_atom14_bond_upper_bound[
+ restype, atom2_idx, atom1_idx
+ ] = upper
+
+ # overwrite lower and upper bounds for bonds and angles
+ for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
+ atom1_idx = atom_list.index(b.atom1_name)
+ atom2_idx = atom_list.index(b.atom2_name)
+ lower = b.length - bond_length_tolerance_factor * b.stddev
+ upper = b.length + bond_length_tolerance_factor * b.stddev
+ restype_atom14_bond_lower_bound[
+ restype, atom1_idx, atom2_idx
+ ] = lower
+ restype_atom14_bond_lower_bound[
+ restype, atom2_idx, atom1_idx
+ ] = lower
+ restype_atom14_bond_upper_bound[
+ restype, atom1_idx, atom2_idx
+ ] = upper
+ restype_atom14_bond_upper_bound[
+ restype, atom2_idx, atom1_idx
+ ] = upper
+ restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
+ restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
+ return {
+ "lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14)
+ "upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14)
+ "stddev": restype_atom14_bond_stddev, # shape (21,14,14)
+ }
+
+
+restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
+restype_atom14_ambiguous_atoms_swap_idx = np.tile(
+ np.arange(14, dtype=int), (21, 1)
+)
+
+
+def _make_atom14_ambiguity_feats():
+ for res, pairs in residue_atom_renaming_swaps.items():
+ res_idx = restype_order[restype_3to1[res]]
+ for atom1, atom2 in pairs.items():
+ atom1_idx = restype_name_to_atom14_names[res].index(atom1)
+ atom2_idx = restype_name_to_atom14_names[res].index(atom2)
+ restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
+ restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
+ restype_atom14_ambiguous_atoms_swap_idx[
+ res_idx, atom1_idx
+ ] = atom2_idx
+ restype_atom14_ambiguous_atoms_swap_idx[
+ res_idx, atom2_idx
+ ] = atom1_idx
+
+
+_make_atom14_ambiguity_feats()
+
+
+def aatype_to_str_sequence(aatype):
+ return ''.join([
+ restypes_with_x[aatype[i]]
+ for i in range(len(aatype))
+ ])
diff --git a/openfold/utils/__pycache__/checkpointing.cpython-310.pyc b/openfold/utils/__pycache__/checkpointing.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f1849c7538e102c16b4d5faff4c0a109a946f852
Binary files /dev/null and b/openfold/utils/__pycache__/checkpointing.cpython-310.pyc differ
diff --git a/openfold/utils/__pycache__/exponential_moving_average.cpython-310.pyc b/openfold/utils/__pycache__/exponential_moving_average.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fbc8d273987289da03ea78e161e8709db1e9c64c
Binary files /dev/null and b/openfold/utils/__pycache__/exponential_moving_average.cpython-310.pyc differ
diff --git a/openfold/utils/__pycache__/feats.cpython-310.pyc b/openfold/utils/__pycache__/feats.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bae498213af441eba021d71240ee3450b7a20aef
Binary files /dev/null and b/openfold/utils/__pycache__/feats.cpython-310.pyc differ
diff --git a/openfold/utils/__pycache__/loss.cpython-310.pyc b/openfold/utils/__pycache__/loss.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1dc3735c9e452c66675182977db67028923c192c
Binary files /dev/null and b/openfold/utils/__pycache__/loss.cpython-310.pyc differ
diff --git a/openfold/utils/__pycache__/precision_utils.cpython-310.pyc b/openfold/utils/__pycache__/precision_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e169848727ad015efaa0fdbbbc47122d6c96eefa
Binary files /dev/null and b/openfold/utils/__pycache__/precision_utils.cpython-310.pyc differ
diff --git a/openfold/utils/__pycache__/rigid_utils.cpython-310.pyc b/openfold/utils/__pycache__/rigid_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..119fe21dcd80406a3b5d234c28a37f56d6932cdb
Binary files /dev/null and b/openfold/utils/__pycache__/rigid_utils.cpython-310.pyc differ
diff --git a/openfold/utils/__pycache__/rigid_utils.cpython-38.pyc b/openfold/utils/__pycache__/rigid_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b043ca55f494fc84d30c696ac552219247160ed
Binary files /dev/null and b/openfold/utils/__pycache__/rigid_utils.cpython-38.pyc differ
diff --git a/openfold/utils/__pycache__/tensor_utils.cpython-310.pyc b/openfold/utils/__pycache__/tensor_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee386bc0b20fe83c201b59d11cb6da9618a33bd2
Binary files /dev/null and b/openfold/utils/__pycache__/tensor_utils.cpython-310.pyc differ
diff --git a/openfold/utils/argparse.py b/openfold/utils/argparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..d13f895ec418ef74af1c1ffd7ec227cdd691236f
--- /dev/null
+++ b/openfold/utils/argparse.py
@@ -0,0 +1,30 @@
+from argparse import HelpFormatter
+from operator import attrgetter
+
+class ArgparseAlphabetizer(HelpFormatter):
+ """
+ Sorts the optional arguments of an argparse parser alphabetically
+ """
+
+ @staticmethod
+ def sort_actions(actions):
+ return sorted(actions, key=attrgetter("option_strings"))
+
+ # Formats the help message
+ def add_arguments(self, actions):
+ actions = ArgparseAlphabetizer.sort_actions(actions)
+ super(ArgparseAlphabetizer, self).add_arguments(actions)
+
+ # Formats the usage message
+ def add_usage(self, usage, actions, groups, prefix=None):
+ actions = ArgparseAlphabetizer.sort_actions(actions)
+ args = usage, actions, groups, prefix
+ super(ArgparseAlphabetizer, self).add_usage(*args)
+
+
+def remove_arguments(parser, args):
+ for arg in args:
+ for action in parser._actions:
+ opts = vars(action)["option_strings"]
+ if(arg in opts):
+ parser._handle_conflict_resolve(None, [(arg, action)])
diff --git a/openfold/utils/callbacks.py b/openfold/utils/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..872dbdce98bcca08e58368481c05024986b4a71f
--- /dev/null
+++ b/openfold/utils/callbacks.py
@@ -0,0 +1,14 @@
+from pytorch_lightning.utilities import rank_zero_info
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+
+class EarlyStoppingVerbose(EarlyStopping):
+ """
+ The default EarlyStopping callback's verbose mode is too verbose.
+ This class outputs a message only when it's getting ready to stop.
+ """
+ def _evalute_stopping_criteria(self, *args):
+ should_stop, reason = super()._evalute_stopping_criteria(*args)
+ if(should_stop):
+ rank_zero_info(f"{reason}\n")
+
+ return should_stop, reason
diff --git a/openfold/utils/checkpointing.py b/openfold/utils/checkpointing.py
new file mode 100644
index 0000000000000000000000000000000000000000..75a4455ae6b1581d76dab7bcf3adf590e9be15df
--- /dev/null
+++ b/openfold/utils/checkpointing.py
@@ -0,0 +1,88 @@
+# Copyright 2021 AlQuraishi Laboratory
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import deepspeed
+import torch
+import torch.utils.checkpoint
+from typing import Any, Tuple, List, Callable, Optional
+
+
+BLOCK_ARG = Any
+BLOCK_ARGS = List[BLOCK_ARG]
+
+
+def get_checkpoint_fn():
+ if(deepspeed.checkpointing.is_configured()):
+ checkpoint = deepspeed.checkpointing.checkpoint
+ else:
+ checkpoint = torch.utils.checkpoint.checkpoint
+
+ return checkpoint
+
+
+@torch.jit.ignore
+def checkpoint_blocks(
+ blocks: List[Callable],
+ args: BLOCK_ARGS,
+ blocks_per_ckpt: Optional[int],
+) -> BLOCK_ARGS:
+ """
+ Chunk a list of blocks and run each chunk with activation
+ checkpointing. We define a "block" as a callable whose only inputs are
+ the outputs of the previous block.
+
+ Implements Subsection 1.11.8
+
+ Args:
+ blocks:
+ List of blocks
+ args:
+ Tuple of arguments for the first block.
+ blocks_per_ckpt:
+ Size of each chunk. A higher value corresponds to fewer
+ checkpoints, and trades memory for speed. If None, no checkpointing
+ is performed.
+ Returns:
+ The output of the final block
+ """
+ def wrap(a):
+ return (a,) if type(a) is not tuple else a
+
+ def exec(b, a):
+ for block in b:
+ a = wrap(block(*a))
+ return a
+
+ def chunker(s, e):
+ def exec_sliced(*a):
+ return exec(blocks[s:e], a)
+
+ return exec_sliced
+
+ # Avoids mishaps when the blocks take just one argument
+ args = wrap(args)
+
+ if blocks_per_ckpt is None:
+ return exec(blocks, args)
+ elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
+ raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")
+
+ checkpoint = get_checkpoint_fn()
+
+ for s in range(0, len(blocks), blocks_per_ckpt):
+ e = s + blocks_per_ckpt
+ args = checkpoint(chunker(s, e), *args)
+ args = wrap(args)
+
+ return args
\ No newline at end of file
diff --git a/openfold/utils/exponential_moving_average.py b/openfold/utils/exponential_moving_average.py
new file mode 100644
index 0000000000000000000000000000000000000000..53649506a7af46c7530687c14c79c2cb8d43c089
--- /dev/null
+++ b/openfold/utils/exponential_moving_average.py
@@ -0,0 +1,70 @@
+from collections import OrderedDict
+import copy
+import torch
+import torch.nn as nn
+
+from openfold.utils.tensor_utils import tensor_tree_map
+
+
+class ExponentialMovingAverage:
+ """
+ Maintains moving averages of parameters with exponential decay
+
+ At each step, the stored copy `copy` of each parameter `param` is
+ updated as follows:
+
+ `copy = decay * copy + (1 - decay) * param`
+
+ where `decay` is an attribute of the ExponentialMovingAverage object.
+ """
+
+ def __init__(self, model: nn.Module, decay: float):
+ """
+ Args:
+ model:
+ A torch.nn.Module whose parameters are to be tracked
+ decay:
+ A value (usually close to 1.) by which updates are
+ weighted as part of the above formula
+ """
+ super(ExponentialMovingAverage, self).__init__()
+
+ clone_param = lambda t: t.clone().detach()
+ self.params = tensor_tree_map(clone_param, model.state_dict())
+ self.decay = decay
+ self.device = next(model.parameters()).device
+
+ def to(self, device):
+ self.params = tensor_tree_map(lambda t: t.to(device), self.params)
+ self.device = device
+
+ def _update_state_dict_(self, update, state_dict):
+ with torch.no_grad():
+ for k, v in update.items():
+ stored = state_dict[k]
+ if not isinstance(v, torch.Tensor):
+ self._update_state_dict_(v, stored)
+ else:
+ diff = stored - v
+ diff *= 1 - self.decay
+ stored -= diff
+
+ def update(self, model: torch.nn.Module) -> None:
+ """
+ Updates the stored parameters using the state dict of the provided
+ module. The module should have the same structure as that used to
+ initialize the ExponentialMovingAverage object.
+ """
+ self._update_state_dict_(model.state_dict(), self.params)
+
+ def load_state_dict(self, state_dict: OrderedDict) -> None:
+ self.params = state_dict["params"]
+ self.decay = state_dict["decay"]
+
+ def state_dict(self) -> OrderedDict:
+ return OrderedDict(
+ {
+ "params": self.params,
+ "decay": self.decay,
+ }
+ )
diff --git a/openfold/utils/feats.py b/openfold/utils/feats.py
new file mode 100644
index 0000000000000000000000000000000000000000..3be298ff4cfd578d446142826937b5643551db29
--- /dev/null
+++ b/openfold/utils/feats.py
@@ -0,0 +1,267 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+from typing import Dict
+
+from openfold.np import protein
+import openfold.np.residue_constants as rc
+from openfold.utils.rigid_utils import Rotation, Rigid
+from openfold.utils.tensor_utils import (
+ batched_gather,
+ one_hot,
+ tree_map,
+ tensor_tree_map,
+)
+
+
+def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
+ is_gly = aatype == rc.restype_order["G"]
+ ca_idx = rc.atom_order["CA"]
+ cb_idx = rc.atom_order["CB"]
+ pseudo_beta = torch.where(
+ is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
+ all_atom_positions[..., ca_idx, :],
+ all_atom_positions[..., cb_idx, :],
+ )
+
+ if all_atom_masks is not None:
+ pseudo_beta_mask = torch.where(
+ is_gly,
+ all_atom_masks[..., ca_idx],
+ all_atom_masks[..., cb_idx],
+ )
+ return pseudo_beta, pseudo_beta_mask
+ else:
+ return pseudo_beta
+
+
+def atom14_to_atom37(atom14, batch):
+ atom37_data = batched_gather(
+ atom14,
+ batch["residx_atom37_to_atom14"],
+ dim=-2,
+ no_batch_dims=len(atom14.shape[:-2]),
+ )
+
+ atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]
+
+ return atom37_data
+
+
+def build_template_angle_feat(template_feats):
+ template_aatype = template_feats["template_aatype"]
+ torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
+ alt_torsion_angles_sin_cos = template_feats[
+ "template_alt_torsion_angles_sin_cos"
+ ]
+ torsion_angles_mask = template_feats["template_torsion_angles_mask"]
+ template_angle_feat = torch.cat(
+ [
+ nn.functional.one_hot(template_aatype, 22),
+ torsion_angles_sin_cos.reshape(
+ *torsion_angles_sin_cos.shape[:-2], 14
+ ),
+ alt_torsion_angles_sin_cos.reshape(
+ *alt_torsion_angles_sin_cos.shape[:-2], 14
+ ),
+ torsion_angles_mask,
+ ],
+ dim=-1,
+ )
+
+ return template_angle_feat
+
+
+def build_template_pair_feat(
+ batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8
+):
+ template_mask = batch["template_pseudo_beta_mask"]
+ template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
+
+ # Compute distogram (this seems to differ slightly from Alg. 5)
+ tpb = batch["template_pseudo_beta"]
+ dgram = torch.sum(
+ (tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
+ )
+ lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
+ upper = torch.cat([lower[:-1], lower.new_tensor([inf])], dim=-1)
+ dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
+
+ to_concat = [dgram, template_mask_2d[..., None]]
+
+ aatype_one_hot = nn.functional.one_hot(
+ batch["template_aatype"],
+ rc.restype_num + 2,
+ )
+
+ n_res = batch["template_aatype"].shape[-1]
+ to_concat.append(
+ aatype_one_hot[..., None, :, :].expand(
+ *aatype_one_hot.shape[:-2], n_res, -1, -1
+ )
+ )
+ to_concat.append(
+ aatype_one_hot[..., None, :].expand(
+ *aatype_one_hot.shape[:-2], -1, n_res, -1
+ )
+ )
+
+ n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
+ rigids = Rigid.make_transform_from_reference(
+ n_xyz=batch["template_all_atom_positions"][..., n, :],
+ ca_xyz=batch["template_all_atom_positions"][..., ca, :],
+ c_xyz=batch["template_all_atom_positions"][..., c, :],
+ eps=eps,
+ )
+ points = rigids.get_trans()[..., None, :, :]
+ rigid_vec = rigids[..., None].invert_apply(points)
+
+ inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec ** 2, dim=-1))
+
+ t_aa_masks = batch["template_all_atom_mask"]
+ template_mask = (
+ t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
+ )
+ template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
+
+ inv_distance_scalar = inv_distance_scalar * template_mask_2d
+ unit_vector = rigid_vec * inv_distance_scalar[..., None]
+ to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
+ to_concat.append(template_mask_2d[..., None])
+
+ act = torch.cat(to_concat, dim=-1)
+ act = act * template_mask_2d[..., None]
+
+ return act
+
+
+def build_extra_msa_feat(batch):
+ msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23)
+ msa_feat = [
+ msa_1hot,
+ batch["extra_has_deletion"].unsqueeze(-1),
+ batch["extra_deletion_value"].unsqueeze(-1),
+ ]
+ return torch.cat(msa_feat, dim=-1)
+
+
+def torsion_angles_to_frames(
+ r: Rigid,
+ alpha: torch.Tensor,
+ aatype: torch.Tensor,
+ rrgdf: torch.Tensor,
+):
+ # [*, N, 8, 4, 4]
+ default_4x4 = rrgdf[aatype, ...]
+
+ # [*, N, 8] transformations, i.e.
+ # One [*, N, 8, 3, 3] rotation matrix and
+ # One [*, N, 8, 3] translation matrix
+ default_r = r.from_tensor_4x4(default_4x4)
+
+ bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
+ bb_rot[..., 1] = 1
+
+ # [*, N, 8, 2]
+ alpha = torch.cat(
+ [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
+ )
+
+ # [*, N, 8, 3, 3]
+ # Produces rotation matrices of the form:
+ # [
+ # [1, 0 , 0 ],
+ # [0, a_2,-a_1],
+ # [0, a_1, a_2]
+ # ]
+ # This follows the original code rather than the supplement, which uses
+ # different indices.
+
+ all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
+ all_rots[..., 0, 0] = 1
+ all_rots[..., 1, 1] = alpha[..., 1]
+ all_rots[..., 1, 2] = -alpha[..., 0]
+ all_rots[..., 2, 1:] = alpha
+
+ all_rots = Rigid(Rotation(rot_mats=all_rots), None)
+
+ all_frames = default_r.compose(all_rots)
+
+ chi2_frame_to_frame = all_frames[..., 5]
+ chi3_frame_to_frame = all_frames[..., 6]
+ chi4_frame_to_frame = all_frames[..., 7]
+
+ chi1_frame_to_bb = all_frames[..., 4]
+ chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
+ chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
+ chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
+
+ all_frames_to_bb = Rigid.cat(
+ [
+ all_frames[..., :5],
+ chi2_frame_to_bb.unsqueeze(-1),
+ chi3_frame_to_bb.unsqueeze(-1),
+ chi4_frame_to_bb.unsqueeze(-1),
+ ],
+ dim=-1,
+ )
+
+ all_frames_to_global = r[..., None].compose(all_frames_to_bb)
+
+ return all_frames_to_global
+
+
+def frames_and_literature_positions_to_atom14_pos(
+ r: Rigid,
+ aatype: torch.Tensor,
+ default_frames,
+ group_idx,
+ atom_mask,
+ lit_positions,
+):
+ # [*, N, 14, 4, 4]
+ default_4x4 = default_frames[aatype, ...]
+
+ # [*, N, 14]
+ group_mask = group_idx[aatype, ...]
+
+ # [*, N, 14, 8]
+ group_mask = nn.functional.one_hot(
+ group_mask,
+ num_classes=default_frames.shape[-3],
+ )
+
+ # [*, N, 14, 8]
+ t_atoms_to_global = r[..., None, :] * group_mask
+
+ # [*, N, 14]
+ t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
+ lambda x: torch.sum(x, dim=-1)
+ )
+
+ # [*, N, 14, 1]
+ atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
+
+ # [*, N, 14, 3]
+ lit_positions = lit_positions[aatype, ...]
+ pred_positions = t_atoms_to_global.apply(lit_positions)
+ pred_positions = pred_positions * atom_mask
+
+ return pred_positions
diff --git a/openfold/utils/import_weights.py b/openfold/utils/import_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b1ebd6d6adc6d60e9dadbb04c677db013cb7f3a
--- /dev/null
+++ b/openfold/utils/import_weights.py
@@ -0,0 +1,449 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+from dataclasses import dataclass
+from functools import partial
+import numpy as np
+import torch
+from typing import Union, List
+
+
+_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
+
+
+# With Param, a poor man's enum with attributes (Rust-style)
+class ParamType(Enum):
+ LinearWeight = partial( # hack: partial prevents fns from becoming methods
+ lambda w: w.transpose(-1, -2)
+ )
+ LinearWeightMHA = partial(
+ lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)
+ )
+ LinearMHAOutputWeight = partial(
+ lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
+ )
+ LinearBiasMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1))
+ LinearWeightOPM = partial(
+ lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
+ )
+ Other = partial(lambda w: w)
+
+ def __init__(self, fn):
+ self.transformation = fn
+
+
+@dataclass
+class Param:
+ param: Union[torch.Tensor, List[torch.Tensor]]
+ param_type: ParamType = ParamType.Other
+ stacked: bool = False
+
+
+def _process_translations_dict(d, top_layer=True):
+ flat = {}
+ for k, v in d.items():
+ if type(v) == dict:
+ prefix = _NPZ_KEY_PREFIX if top_layer else ""
+ sub_flat = {
+ (prefix + "/".join([k, k_prime])): v_prime
+ for k_prime, v_prime in _process_translations_dict(
+ v, top_layer=False
+ ).items()
+ }
+ flat.update(sub_flat)
+ else:
+ k = "/" + k if not top_layer else k
+ flat[k] = v
+
+ return flat
+
+
+def stacked(param_dict_list, out=None):
+ """
+ Args:
+ param_dict_list:
+ A list of (nested) Param dicts to stack. The structure of
+ each dict must be the identical (down to the ParamTypes of
+ "parallel" Params). There must be at least one dict
+ in the list.
+ """
+ if out is None:
+ out = {}
+ template = param_dict_list[0]
+ for k, _ in template.items():
+ v = [d[k] for d in param_dict_list]
+ if type(v[0]) is dict:
+ out[k] = {}
+ stacked(v, out=out[k])
+ elif type(v[0]) is Param:
+ stacked_param = Param(
+ param=[param.param for param in v],
+ param_type=v[0].param_type,
+ stacked=True,
+ )
+
+ out[k] = stacked_param
+
+ return out
+
+
+def assign(translation_dict, orig_weights):
+ for k, param in translation_dict.items():
+ with torch.no_grad():
+ weights = torch.as_tensor(orig_weights[k])
+ ref, param_type = param.param, param.param_type
+ if param.stacked:
+ weights = torch.unbind(weights, 0)
+ else:
+ weights = [weights]
+ ref = [ref]
+
+ try:
+ weights = list(map(param_type.transformation, weights))
+ for p, w in zip(ref, weights):
+ p.copy_(w)
+ except:
+ print(k)
+ print(ref[0].shape)
+ print(weights[0].shape)
+ raise
+
+
+def import_jax_weights_(model, npz_path, version="model_1"):
+ data = np.load(npz_path)
+
+ #######################
+ # Some templates
+ #######################
+
+ LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
+
+ LinearBias = lambda l: (Param(l))
+
+ LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
+
+ LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
+
+ LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
+
+ LinearParams = lambda l: {
+ "weights": LinearWeight(l.weight),
+ "bias": LinearBias(l.bias),
+ }
+
+ LayerNormParams = lambda l: {
+ "scale": Param(l.weight),
+ "offset": Param(l.bias),
+ }
+
+ AttentionParams = lambda att: {
+ "query_w": LinearWeightMHA(att.linear_q.weight),
+ "key_w": LinearWeightMHA(att.linear_k.weight),
+ "value_w": LinearWeightMHA(att.linear_v.weight),
+ "output_w": Param(
+ att.linear_o.weight,
+ param_type=ParamType.LinearMHAOutputWeight,
+ ),
+ "output_b": LinearBias(att.linear_o.bias),
+ }
+
+ AttentionGatedParams = lambda att: dict(
+ **AttentionParams(att),
+ **{
+ "gating_w": LinearWeightMHA(att.linear_g.weight),
+ "gating_b": LinearBiasMHA(att.linear_g.bias),
+ },
+ )
+
+ GlobalAttentionParams = lambda att: dict(
+ AttentionGatedParams(att),
+ key_w=LinearWeight(att.linear_k.weight),
+ value_w=LinearWeight(att.linear_v.weight),
+ )
+
+ TriAttParams = lambda tri_att: {
+ "query_norm": LayerNormParams(tri_att.layer_norm),
+ "feat_2d_weights": LinearWeight(tri_att.linear.weight),
+ "attention": AttentionGatedParams(tri_att.mha),
+ }
+
+ TriMulOutParams = lambda tri_mul: {
+ "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
+ "left_projection": LinearParams(tri_mul.linear_a_p),
+ "right_projection": LinearParams(tri_mul.linear_b_p),
+ "left_gate": LinearParams(tri_mul.linear_a_g),
+ "right_gate": LinearParams(tri_mul.linear_b_g),
+ "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
+ "output_projection": LinearParams(tri_mul.linear_z),
+ "gating_linear": LinearParams(tri_mul.linear_g),
+ }
+
+ # see commit b88f8da on the Alphafold repo
+ # Alphafold swaps the pseudocode's a and b between the incoming/outcoming
+ # iterations of triangle multiplication, which is confusing and not
+ # reproduced in our implementation.
+ TriMulInParams = lambda tri_mul: {
+ "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
+ "left_projection": LinearParams(tri_mul.linear_b_p),
+ "right_projection": LinearParams(tri_mul.linear_a_p),
+ "left_gate": LinearParams(tri_mul.linear_b_g),
+ "right_gate": LinearParams(tri_mul.linear_a_g),
+ "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
+ "output_projection": LinearParams(tri_mul.linear_z),
+ "gating_linear": LinearParams(tri_mul.linear_g),
+ }
+
+ PairTransitionParams = lambda pt: {
+ "input_layer_norm": LayerNormParams(pt.layer_norm),
+ "transition1": LinearParams(pt.linear_1),
+ "transition2": LinearParams(pt.linear_2),
+ }
+
+ MSAAttParams = lambda matt: {
+ "query_norm": LayerNormParams(matt.layer_norm_m),
+ "attention": AttentionGatedParams(matt.mha),
+ }
+
+ MSAColAttParams = lambda matt: {
+ "query_norm": LayerNormParams(matt._msa_att.layer_norm_m),
+ "attention": AttentionGatedParams(matt._msa_att.mha),
+ }
+
+ MSAGlobalAttParams = lambda matt: {
+ "query_norm": LayerNormParams(matt.layer_norm_m),
+ "attention": GlobalAttentionParams(matt.global_attention),
+ }
+
+ MSAAttPairBiasParams = lambda matt: dict(
+ **MSAAttParams(matt),
+ **{
+ "feat_2d_norm": LayerNormParams(matt.layer_norm_z),
+ "feat_2d_weights": LinearWeight(matt.linear_z.weight),
+ },
+ )
+
+ IPAParams = lambda ipa: {
+ "q_scalar": LinearParams(ipa.linear_q),
+ "kv_scalar": LinearParams(ipa.linear_kv),
+ "q_point_local": LinearParams(ipa.linear_q_points),
+ "kv_point_local": LinearParams(ipa.linear_kv_points),
+ "trainable_point_weights": Param(
+ param=ipa.head_weights, param_type=ParamType.Other
+ ),
+ "attention_2d": LinearParams(ipa.linear_b),
+ "output_projection": LinearParams(ipa.linear_out),
+ }
+
+ TemplatePairBlockParams = lambda b: {
+ "triangle_attention_starting_node": TriAttParams(b.tri_att_start),
+ "triangle_attention_ending_node": TriAttParams(b.tri_att_end),
+ "triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
+ "triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
+ "pair_transition": PairTransitionParams(b.pair_transition),
+ }
+
+ MSATransitionParams = lambda m: {
+ "input_layer_norm": LayerNormParams(m.layer_norm),
+ "transition1": LinearParams(m.linear_1),
+ "transition2": LinearParams(m.linear_2),
+ }
+
+ OuterProductMeanParams = lambda o: {
+ "layer_norm_input": LayerNormParams(o.layer_norm),
+ "left_projection": LinearParams(o.linear_1),
+ "right_projection": LinearParams(o.linear_2),
+ "output_w": LinearWeightOPM(o.linear_out.weight),
+ "output_b": LinearBias(o.linear_out.bias),
+ }
+
+ def EvoformerBlockParams(b, is_extra_msa=False):
+ if is_extra_msa:
+ col_att_name = "msa_column_global_attention"
+ msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
+ else:
+ col_att_name = "msa_column_attention"
+ msa_col_att_params = MSAColAttParams(b.msa_att_col)
+
+ d = {
+ "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
+ b.msa_att_row
+ ),
+ col_att_name: msa_col_att_params,
+ "msa_transition": MSATransitionParams(b.core.msa_transition),
+ "outer_product_mean":
+ OuterProductMeanParams(b.core.outer_product_mean),
+ "triangle_multiplication_outgoing":
+ TriMulOutParams(b.core.tri_mul_out),
+ "triangle_multiplication_incoming":
+ TriMulInParams(b.core.tri_mul_in),
+ "triangle_attention_starting_node":
+ TriAttParams(b.core.tri_att_start),
+ "triangle_attention_ending_node":
+ TriAttParams(b.core.tri_att_end),
+ "pair_transition":
+ PairTransitionParams(b.core.pair_transition),
+ }
+
+ return d
+
+ ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)
+
+ FoldIterationParams = lambda sm: {
+ "invariant_point_attention": IPAParams(sm.ipa),
+ "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
+ "transition": LinearParams(sm.transition.layers[0].linear_1),
+ "transition_1": LinearParams(sm.transition.layers[0].linear_2),
+ "transition_2": LinearParams(sm.transition.layers[0].linear_3),
+ "transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
+ "affine_update": LinearParams(sm.bb_update.linear),
+ "rigid_sidechain": {
+ "input_projection": LinearParams(sm.angle_resnet.linear_in),
+ "input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
+ "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
+ "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
+ "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
+ "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
+ "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
+ },
+ }
+
+ ############################
+ # translations dict overflow
+ ############################
+
+ tps_blocks = model.template_pair_stack.blocks
+ tps_blocks_params = stacked(
+ [TemplatePairBlockParams(b) for b in tps_blocks]
+ )
+
+ ems_blocks = model.extra_msa_stack.blocks
+ ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
+
+ evo_blocks = model.evoformer.blocks
+ evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])
+
+ translations = {
+ "evoformer": {
+ "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
+ "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
+ "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
+ "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
+ "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
+ "prev_msa_first_row_norm": LayerNormParams(
+ model.recycling_embedder.layer_norm_m
+ ),
+ "prev_pair_norm": LayerNormParams(
+ model.recycling_embedder.layer_norm_z
+ ),
+ "pair_activiations": LinearParams(
+ model.input_embedder.linear_relpos
+ ),
+ "template_embedding": {
+ "single_template_embedding": {
+ "embedding2d": LinearParams(
+ model.template_pair_embedder.linear
+ ),
+ "template_pair_stack": {
+ "__layer_stack_no_state": tps_blocks_params,
+ },
+ "output_layer_norm": LayerNormParams(
+ model.template_pair_stack.layer_norm
+ ),
+ },
+ "attention": AttentionParams(model.template_pointwise_att.mha),
+ },
+ "extra_msa_activations": LinearParams(
+ model.extra_msa_embedder.linear
+ ),
+ "extra_msa_stack": ems_blocks_params,
+ "template_single_embedding": LinearParams(
+ model.template_angle_embedder.linear_1
+ ),
+ "template_projection": LinearParams(
+ model.template_angle_embedder.linear_2
+ ),
+ "evoformer_iteration": evo_blocks_params,
+ "single_activations": LinearParams(model.evoformer.linear),
+ },
+ "structure_module": {
+ "single_layer_norm": LayerNormParams(
+ model.structure_module.layer_norm_s
+ ),
+ "initial_projection": LinearParams(
+ model.structure_module.linear_in
+ ),
+ "pair_layer_norm": LayerNormParams(
+ model.structure_module.layer_norm_z
+ ),
+ "fold_iteration": FoldIterationParams(model.structure_module),
+ },
+ "predicted_lddt_head": {
+ "input_layer_norm": LayerNormParams(
+ model.aux_heads.plddt.layer_norm
+ ),
+ "act_0": LinearParams(model.aux_heads.plddt.linear_1),
+ "act_1": LinearParams(model.aux_heads.plddt.linear_2),
+ "logits": LinearParams(model.aux_heads.plddt.linear_3),
+ },
+ "distogram_head": {
+ "half_logits": LinearParams(model.aux_heads.distogram.linear),
+ },
+ "experimentally_resolved_head": {
+ "logits": LinearParams(
+ model.aux_heads.experimentally_resolved.linear
+ ),
+ },
+ "masked_msa_head": {
+ "logits": LinearParams(model.aux_heads.masked_msa.linear),
+ },
+ }
+
+ no_templ = [
+ "model_3",
+ "model_4",
+ "model_5",
+ "model_3_ptm",
+ "model_4_ptm",
+ "model_5_ptm",
+ ]
+ if version in no_templ:
+ evo_dict = translations["evoformer"]
+ keys = list(evo_dict.keys())
+ for k in keys:
+ if "template_" in k:
+ evo_dict.pop(k)
+
+ if "_ptm" in version:
+ translations["predicted_aligned_error_head"] = {
+ "logits": LinearParams(model.aux_heads.tm.linear)
+ }
+
+ # Flatten keys and insert missing key prefixes
+ flat = _process_translations_dict(translations)
+
+ # Sanity check
+ keys = list(data.keys())
+ flat_keys = list(flat.keys())
+ incorrect = [k for k in flat_keys if k not in keys]
+ missing = [k for k in keys if k not in flat_keys]
+ # print(f"Incorrect: {incorrect}")
+ # print(f"Missing: {missing}")
+
+ assert len(incorrect) == 0
+ # assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
+
+ # Set weights
+ assign(flat, data)
diff --git a/openfold/utils/logger.py b/openfold/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2e2223e3020765b8d847c269285a5f5de248b94
--- /dev/null
+++ b/openfold/utils/logger.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import operator
+import time
+
+import dllogger as logger
+import numpy as np
+import torch.cuda.profiler as profiler
+from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
+from pytorch_lightning import Callback
+
+
+def is_main_process():
+ return int(os.getenv("LOCAL_RANK", "0")) == 0
+
+
+class PerformanceLoggingCallback(Callback):
+ def __init__(self, log_file, global_batch_size, warmup_steps: int = 0, profile: bool = False):
+ logger.init(backends=[JSONStreamBackend(Verbosity.VERBOSE, log_file), StdOutBackend(Verbosity.VERBOSE)])
+ self.warmup_steps = warmup_steps
+ self.global_batch_size = global_batch_size
+ self.step = 0
+ self.profile = profile
+ self.timestamps = []
+
+ def do_step(self):
+ self.step += 1
+ if self.profile and self.step == self.warmup_steps:
+ profiler.start()
+ if self.step > self.warmup_steps:
+ self.timestamps.append(time.time())
+
+ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+ self.do_step()
+
+ def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+ self.do_step()
+
+ def process_performance_stats(self, deltas):
+ def _round3(val):
+ return round(val, 3)
+
+ throughput_imgps = _round3(self.global_batch_size / np.mean(deltas))
+ timestamps_ms = 1000 * deltas
+ stats = {
+ f"throughput": throughput_imgps,
+ f"latency_mean": _round3(timestamps_ms.mean()),
+ }
+ for level in [90, 95, 99]:
+ stats.update({f"latency_{level}": _round3(np.percentile(timestamps_ms, level))})
+
+ return stats
+
+ def _log(self):
+ if is_main_process():
+ diffs = list(map(operator.sub, self.timestamps[1:], self.timestamps[:-1]))
+ deltas = np.array(diffs)
+ stats = self.process_performance_stats(deltas)
+ logger.log(step=(), data=stats)
+ logger.flush()
+
+ def on_train_end(self, trainer, pl_module):
+ if self.profile:
+ profiler.stop()
+ self._log()
+
+ def on_epoch_end(self, trainer, pl_module):
+ self._log()
diff --git a/openfold/utils/loss.py b/openfold/utils/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..50809af9e7c3564e7104bfb298f50fc64f8dd597
--- /dev/null
+++ b/openfold/utils/loss.py
@@ -0,0 +1,1637 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+import logging
+import ml_collections
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.distributions.bernoulli import Bernoulli
+from typing import Dict, Optional, Tuple
+
+from openfold.np import residue_constants
+from openfold.utils import feats
+from openfold.utils.rigid_utils import Rotation, Rigid
+from openfold.utils.tensor_utils import (
+ tree_map,
+ tensor_tree_map,
+ masked_mean,
+ permute_final_dims,
+ batched_gather,
+)
+
+
+def softmax_cross_entropy(logits, labels):
+ loss = -1 * torch.sum(
+ labels * torch.nn.functional.log_softmax(logits, dim=-1),
+ dim=-1,
+ )
+ return loss
+
+
+def sigmoid_cross_entropy(logits, labels):
+ log_p = torch.log(torch.sigmoid(logits))
+ log_not_p = torch.log(torch.sigmoid(-logits))
+ loss = -labels * log_p - (1 - labels) * log_not_p
+ return loss
+
+
+def torsion_angle_loss(
+ a, # [*, N, 7, 2]
+ a_gt, # [*, N, 7, 2]
+ a_alt_gt, # [*, N, 7, 2]
+):
+ # [*, N, 7]
+ norm = torch.norm(a, dim=-1)
+
+ # [*, N, 7, 2]
+ a = a / norm.unsqueeze(-1)
+
+ # [*, N, 7]
+ diff_norm_gt = torch.norm(a - a_gt, dim=-1)
+ diff_norm_alt_gt = torch.norm(a - a_alt_gt, dim=-1)
+ min_diff = torch.minimum(diff_norm_gt ** 2, diff_norm_alt_gt ** 2)
+
+ # [*]
+ l_torsion = torch.mean(min_diff, dim=(-1, -2))
+ l_angle_norm = torch.mean(torch.abs(norm - 1), dim=(-1, -2))
+
+ an_weight = 0.02
+ return l_torsion + an_weight * l_angle_norm
+
+
+def compute_fape(
+ pred_frames: Rigid,
+ target_frames: Rigid,
+ frames_mask: torch.Tensor,
+ pred_positions: torch.Tensor,
+ target_positions: torch.Tensor,
+ positions_mask: torch.Tensor,
+ length_scale: float,
+ l1_clamp_distance: Optional[float] = None,
+ eps=1e-8,
+ ignore_nan=True,
+) -> torch.Tensor:
+ """
+ Computes FAPE loss.
+
+ Args:
+ pred_frames:
+ [*, N_frames] Rigid object of predicted frames
+ target_frames:
+ [*, N_frames] Rigid object of ground truth frames
+ frames_mask:
+ [*, N_frames] binary mask for the frames
+ pred_positions:
+ [*, N_pts, 3] predicted atom positions
+ target_positions:
+ [*, N_pts, 3] ground truth positions
+ positions_mask:
+ [*, N_pts] positions mask
+ length_scale:
+ Length scale by which the loss is divided
+ l1_clamp_distance:
+ Cutoff above which distance errors are disregarded
+ eps:
+ Small value used to regularize denominators
+ Returns:
+ [*] loss tensor
+ """
+ # [*, N_frames, N_pts, 3]
+ local_pred_pos = pred_frames.invert()[..., None].apply(
+ pred_positions[..., None, :, :],
+ )
+ local_target_pos = target_frames.invert()[..., None].apply(
+ target_positions[..., None, :, :],
+ )
+
+ error_dist = torch.sqrt(
+ torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
+ )
+
+ if l1_clamp_distance is not None:
+ error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)
+
+ normed_error = error_dist / length_scale
+ normed_error = normed_error * frames_mask[..., None]
+ normed_error = normed_error * positions_mask[..., None, :]
+ if ignore_nan:
+ normed_error = torch.nan_to_num(normed_error)
+
+ # FP16-friendly averaging. Roughly equivalent to:
+ #
+ # norm_factor = (
+ # torch.sum(frames_mask, dim=-1) *
+ # torch.sum(positions_mask, dim=-1)
+ # )
+ # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
+ #
+ # ("roughly" because eps is necessarily duplicated in the latter)
+ normed_error = torch.sum(normed_error, dim=-1)
+ normed_error = (
+ normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
+ )
+ normed_error = torch.sum(normed_error, dim=-1)
+ normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
+ return normed_error
+
+
+def backbone_loss(
+ backbone_rigid_tensor: torch.Tensor,
+ backbone_rigid_mask: torch.Tensor,
+ traj: torch.Tensor,
+ use_clamped_fape: Optional[torch.Tensor] = None,
+ clamp_distance: float = 10.0,
+ loss_unit_distance: float = 10.0,
+ eps: float = 1e-4,
+ **kwargs,
+) -> torch.Tensor:
+ pred_aff = Rigid.from_tensor_7(traj)
+ pred_aff = Rigid(
+ Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
+ pred_aff.get_trans(),
+ )
+
+ # DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
+ # backbone tensor, normalizes it, and then turns it back to a rotation
+ # matrix. To avoid a potentially numerically unstable rotation matrix
+ # to quaternion conversion, we just use the original rotation matrix
+ # outright. This one hasn't been composed a bunch of times, though, so
+ # it might be fine.
+ gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)
+
+ fape_loss = compute_fape(
+ pred_aff,
+ gt_aff[None],
+ backbone_rigid_mask[None],
+ pred_aff.get_trans(),
+ gt_aff[None].get_trans(),
+ backbone_rigid_mask[None],
+ l1_clamp_distance=clamp_distance,
+ length_scale=loss_unit_distance,
+ eps=eps,
+ )
+ if use_clamped_fape is not None:
+ unclamped_fape_loss = compute_fape(
+ pred_aff,
+ gt_aff[None],
+ backbone_rigid_mask[None],
+ pred_aff.get_trans(),
+ gt_aff[None].get_trans(),
+ backbone_rigid_mask[None],
+ l1_clamp_distance=None,
+ length_scale=loss_unit_distance,
+ eps=eps,
+ )
+
+ fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
+ 1 - use_clamped_fape
+ )
+
+ # Average over the batch dimension
+ fape_loss = torch.mean(fape_loss)
+
+ return fape_loss
+
+
+def sidechain_loss(
+ sidechain_frames: torch.Tensor,
+ sidechain_atom_pos: torch.Tensor,
+ rigidgroups_gt_frames: torch.Tensor,
+ rigidgroups_alt_gt_frames: torch.Tensor,
+ rigidgroups_gt_exists: torch.Tensor,
+ renamed_atom14_gt_positions: torch.Tensor,
+ renamed_atom14_gt_exists: torch.Tensor,
+ alt_naming_is_better: torch.Tensor,
+ clamp_distance: float = 10.0,
+ length_scale: float = 10.0,
+ eps: float = 1e-4,
+ **kwargs,
+) -> torch.Tensor:
+ renamed_gt_frames = (
+ 1.0 - alt_naming_is_better[..., None, None, None]
+ ) * rigidgroups_gt_frames + alt_naming_is_better[
+ ..., None, None, None
+ ] * rigidgroups_alt_gt_frames
+
+ # Steamroll the inputs
+ sidechain_frames = sidechain_frames[-1]
+ batch_dims = sidechain_frames.shape[:-4]
+ sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
+ sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames)
+ renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
+ renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
+ rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
+ sidechain_atom_pos = sidechain_atom_pos[-1]
+ sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
+ renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(
+ *batch_dims, -1, 3
+ )
+ renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)
+
+ fape = compute_fape(
+ sidechain_frames,
+ renamed_gt_frames,
+ rigidgroups_gt_exists,
+ sidechain_atom_pos,
+ renamed_atom14_gt_positions,
+ renamed_atom14_gt_exists,
+ l1_clamp_distance=clamp_distance,
+ length_scale=length_scale,
+ eps=eps,
+ )
+
+ return fape
+
+
+def fape_loss(
+ out: Dict[str, torch.Tensor],
+ batch: Dict[str, torch.Tensor],
+ config: ml_collections.ConfigDict,
+) -> torch.Tensor:
+ bb_loss = backbone_loss(
+ traj=out["sm"]["frames"],
+ **{**batch, **config.backbone},
+ )
+
+ sc_loss = sidechain_loss(
+ out["sm"]["sidechain_frames"],
+ out["sm"]["positions"],
+ **{**batch, **config.sidechain},
+ )
+
+ loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss
+
+ # Average over the batch dimension
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def supervised_chi_loss(
+ angles_sin_cos: torch.Tensor,
+ unnormalized_angles_sin_cos: torch.Tensor,
+ aatype: torch.Tensor,
+ seq_mask: torch.Tensor,
+ chi_mask: torch.Tensor,
+ chi_angles_sin_cos: torch.Tensor,
+ chi_weight: float,
+ angle_norm_weight: float,
+ eps=1e-6,
+ **kwargs,
+) -> torch.Tensor:
+ """
+ Implements Algorithm 27 (torsionAngleLoss)
+
+ Args:
+ angles_sin_cos:
+ [*, N, 7, 2] predicted angles
+ unnormalized_angles_sin_cos:
+ The same angles, but unnormalized
+ aatype:
+ [*, N] residue indices
+ seq_mask:
+ [*, N] sequence mask
+ chi_mask:
+ [*, N, 7] angle mask
+ chi_angles_sin_cos:
+ [*, N, 7, 2] ground truth angles
+ chi_weight:
+ Weight for the angle component of the loss
+ angle_norm_weight:
+ Weight for the normalization component of the loss
+ Returns:
+ [*] loss tensor
+ """
+ pred_angles = angles_sin_cos[..., 3:, :]
+ residue_type_one_hot = torch.nn.functional.one_hot(
+ aatype,
+ residue_constants.restype_num + 1,
+ )
+ chi_pi_periodic = torch.einsum(
+ "...ij,jk->ik",
+ residue_type_one_hot.type(angles_sin_cos.dtype),
+ angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
+ )
+
+ true_chi = chi_angles_sin_cos[None]
+
+ shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
+ true_chi_shifted = shifted_mask * true_chi
+ sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
+ sq_chi_error_shifted = torch.sum(
+ (true_chi_shifted - pred_angles) ** 2, dim=-1
+ )
+ sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
+ # The ol' switcheroo
+ sq_chi_error = sq_chi_error.permute(
+ *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
+ )
+ sq_chi_loss = masked_mean(
+ chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
+ )
+
+ loss = chi_weight * sq_chi_loss
+
+ angle_norm = torch.sqrt(
+ torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
+ )
+ norm_error = torch.abs(angle_norm - 1.0)
+ norm_error = norm_error.permute(
+ *range(len(norm_error.shape))[1:-2], 0, -2, -1
+ )
+ angle_norm_loss = masked_mean(
+ seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
+ )
+
+ loss = loss + angle_norm_weight * angle_norm_loss
+
+ # Average over the batch dimension
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
+ num_bins = logits.shape[-1]
+ bin_width = 1.0 / num_bins
+ bounds = torch.arange(
+ start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device
+ )
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+ pred_lddt_ca = torch.sum(
+ probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
+ dim=-1,
+ )
+ return pred_lddt_ca * 100
+
+
+def lddt(
+ all_atom_pred_pos: torch.Tensor,
+ all_atom_positions: torch.Tensor,
+ all_atom_mask: torch.Tensor,
+ cutoff: float = 15.0,
+ eps: float = 1e-10,
+ per_residue: bool = True,
+) -> torch.Tensor:
+ n = all_atom_mask.shape[-2]
+ dmat_true = torch.sqrt(
+ eps
+ + torch.sum(
+ (
+ all_atom_positions[..., None, :]
+ - all_atom_positions[..., None, :, :]
+ )
+ ** 2,
+ dim=-1,
+ )
+ )
+
+ dmat_pred = torch.sqrt(
+ eps
+ + torch.sum(
+ (
+ all_atom_pred_pos[..., None, :]
+ - all_atom_pred_pos[..., None, :, :]
+ )
+ ** 2,
+ dim=-1,
+ )
+ )
+ dists_to_score = (
+ (dmat_true < cutoff)
+ * all_atom_mask
+ * permute_final_dims(all_atom_mask, (1, 0))
+ * (1.0 - torch.eye(n, device=all_atom_mask.device))
+ )
+
+ dist_l1 = torch.abs(dmat_true - dmat_pred)
+
+ score = (
+ (dist_l1 < 0.5).type(dist_l1.dtype)
+ + (dist_l1 < 1.0).type(dist_l1.dtype)
+ + (dist_l1 < 2.0).type(dist_l1.dtype)
+ + (dist_l1 < 4.0).type(dist_l1.dtype)
+ )
+ score = score * 0.25
+
+ dims = (-1,) if per_residue else (-2, -1)
+ norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
+ score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
+
+ return score
+
+
+def lddt_ca(
+ all_atom_pred_pos: torch.Tensor,
+ all_atom_positions: torch.Tensor,
+ all_atom_mask: torch.Tensor,
+ cutoff: float = 15.0,
+ eps: float = 1e-10,
+ per_residue: bool = True,
+) -> torch.Tensor:
+ ca_pos = residue_constants.atom_order["CA"]
+ all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
+ all_atom_positions = all_atom_positions[..., ca_pos, :]
+ all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
+
+ return lddt(
+ all_atom_pred_pos,
+ all_atom_positions,
+ all_atom_mask,
+ cutoff=cutoff,
+ eps=eps,
+ per_residue=per_residue,
+ )
+
+
+def lddt_loss(
+ logits: torch.Tensor,
+ all_atom_pred_pos: torch.Tensor,
+ all_atom_positions: torch.Tensor,
+ all_atom_mask: torch.Tensor,
+ resolution: torch.Tensor,
+ cutoff: float = 15.0,
+ no_bins: int = 50,
+ min_resolution: float = 0.1,
+ max_resolution: float = 3.0,
+ eps: float = 1e-10,
+ **kwargs,
+) -> torch.Tensor:
+ n = all_atom_mask.shape[-2]
+
+ ca_pos = residue_constants.atom_order["CA"]
+ all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
+ all_atom_positions = all_atom_positions[..., ca_pos, :]
+ all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
+
+ score = lddt(
+ all_atom_pred_pos,
+ all_atom_positions,
+ all_atom_mask,
+ cutoff=cutoff,
+ eps=eps
+ )
+
+ score = score.detach()
+
+ bin_index = torch.floor(score * no_bins).long()
+ bin_index = torch.clamp(bin_index, max=(no_bins - 1))
+ lddt_ca_one_hot = torch.nn.functional.one_hot(
+ bin_index, num_classes=no_bins
+ )
+
+ errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
+ all_atom_mask = all_atom_mask.squeeze(-1)
+ loss = torch.sum(errors * all_atom_mask, dim=-1) / (
+ eps + torch.sum(all_atom_mask, dim=-1)
+ )
+
+ loss = loss * (
+ (resolution >= min_resolution) & (resolution <= max_resolution)
+ )
+
+ # Average over the batch dimension
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def distogram_loss(
+ logits,
+ pseudo_beta,
+ pseudo_beta_mask,
+ min_bin=2.3125,
+ max_bin=21.6875,
+ no_bins=64,
+ eps=1e-6,
+ **kwargs,
+):
+ boundaries = torch.linspace(
+ min_bin,
+ max_bin,
+ no_bins - 1,
+ device=logits.device,
+ )
+ boundaries = boundaries ** 2
+
+ dists = torch.sum(
+ (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
+ dim=-1,
+ keepdims=True,
+ )
+
+ true_bins = torch.sum(dists > boundaries, dim=-1)
+
+ errors = softmax_cross_entropy(
+ logits,
+ torch.nn.functional.one_hot(true_bins, no_bins),
+ )
+
+ square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
+
+ # FP16-friendly sum. Equivalent to:
+ # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) /
+ # (eps + torch.sum(square_mask, dim=(-1, -2))))
+ denom = eps + torch.sum(square_mask, dim=(-1, -2))
+ mean = errors * square_mask
+ mean = torch.sum(mean, dim=-1)
+ mean = mean / denom[..., None]
+ mean = torch.sum(mean, dim=-1)
+
+ # Average over the batch dimensions
+ mean = torch.mean(mean)
+
+ return mean
+
+
+def _calculate_bin_centers(boundaries: torch.Tensor):
+ step = boundaries[1] - boundaries[0]
+ bin_centers = boundaries + step / 2
+ bin_centers = torch.cat(
+ [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0
+ )
+ return bin_centers
+
+
+def _calculate_expected_aligned_error(
+ alignment_confidence_breaks: torch.Tensor,
+ aligned_distance_error_probs: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
+ return (
+ torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
+ bin_centers[-1],
+ )
+
+
+def compute_predicted_aligned_error(
+ logits: torch.Tensor,
+ max_bin: int = 31,
+ no_bins: int = 64,
+ **kwargs,
+) -> Dict[str, torch.Tensor]:
+ """Computes aligned confidence metrics from logits.
+
+ Args:
+ logits: [*, num_res, num_res, num_bins] the logits output from
+ PredictedAlignedErrorHead.
+ max_bin: Maximum bin value
+ no_bins: Number of bins
+ Returns:
+ aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
+ aligned error probabilities over bins for each residue pair.
+ predicted_aligned_error: [*, num_res, num_res] the expected aligned distance
+ error for each pair of residues.
+ max_predicted_aligned_error: [*] the maximum predicted error possible.
+ """
+ boundaries = torch.linspace(
+ 0, max_bin, steps=(no_bins - 1), device=logits.device
+ )
+
+ aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
+ (
+ predicted_aligned_error,
+ max_predicted_aligned_error,
+ ) = _calculate_expected_aligned_error(
+ alignment_confidence_breaks=boundaries,
+ aligned_distance_error_probs=aligned_confidence_probs,
+ )
+
+ return {
+ "aligned_confidence_probs": aligned_confidence_probs,
+ "predicted_aligned_error": predicted_aligned_error,
+ "max_predicted_aligned_error": max_predicted_aligned_error,
+ }
+
+
+def compute_tm(
+ logits: torch.Tensor,
+ residue_weights: Optional[torch.Tensor] = None,
+ max_bin: int = 31,
+ no_bins: int = 64,
+ eps: float = 1e-8,
+ **kwargs,
+) -> torch.Tensor:
+ if residue_weights is None:
+ residue_weights = logits.new_ones(logits.shape[-2])
+
+ boundaries = torch.linspace(
+ 0, max_bin, steps=(no_bins - 1), device=logits.device
+ )
+
+ bin_centers = _calculate_bin_centers(boundaries)
+ torch.sum(residue_weights)
+ n = logits.shape[-2]
+ clipped_n = max(n, 19)
+
+ d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
+
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+
+ tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
+ predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
+
+ normed_residue_mask = residue_weights / (eps + residue_weights.sum())
+ per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
+ weighted = per_alignment * residue_weights
+ argmax = (weighted == torch.max(weighted)).nonzero()[0]
+ return per_alignment[tuple(argmax)]
+
+
+def tm_loss(
+ logits,
+ final_affine_tensor,
+ backbone_rigid_tensor,
+ backbone_rigid_mask,
+ resolution,
+ max_bin=31,
+ no_bins=64,
+ min_resolution: float = 0.1,
+ max_resolution: float = 3.0,
+ eps=1e-8,
+ **kwargs,
+):
+ pred_affine = Rigid.from_tensor_7(final_affine_tensor)
+ backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
+
+ def _points(affine):
+ pts = affine.get_trans()[..., None, :, :]
+ return affine.invert()[..., None].apply(pts)
+
+ sq_diff = torch.sum(
+ (_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1
+ )
+
+ sq_diff = sq_diff.detach()
+
+ boundaries = torch.linspace(
+ 0, max_bin, steps=(no_bins - 1), device=logits.device
+ )
+ boundaries = boundaries ** 2
+ true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)
+
+ errors = softmax_cross_entropy(
+ logits, torch.nn.functional.one_hot(true_bins, no_bins)
+ )
+
+ square_mask = (
+ backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]
+ )
+
+ loss = torch.sum(errors * square_mask, dim=-1)
+ scale = 0.5 # hack to help FP16 training along
+ denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
+ loss = loss / denom[..., None]
+ loss = torch.sum(loss, dim=-1)
+ loss = loss * scale
+
+ loss = loss * (
+ (resolution >= min_resolution) & (resolution <= max_resolution)
+ )
+
+ # Average over the loss dimension
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def between_residue_bond_loss(
+ pred_atom_positions: torch.Tensor, # (*, N, 37/14, 3)
+ pred_atom_mask: torch.Tensor, # (*, N, 37/14)
+ residue_index: torch.Tensor, # (*, N)
+ aatype: torch.Tensor, # (*, N)
+ tolerance_factor_soft=12.0,
+ tolerance_factor_hard=12.0,
+ eps=1e-6,
+) -> Dict[str, torch.Tensor]:
+ """Flat-bottom loss to penalize structural violations between residues.
+
+ This is a loss penalizing any violation of the geometry around the peptide
+ bond between consecutive amino acids. This loss corresponds to
+ Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.
+
+ Args:
+ pred_atom_positions: Atom positions in atom37/14 representation
+ pred_atom_mask: Atom mask in atom37/14 representation
+ residue_index: Residue index for given amino acid, this is assumed to be
+ monotonically increasing.
+ aatype: Amino acid type of given residue
+ tolerance_factor_soft: soft tolerance factor measured in standard deviations
+ of pdb distributions
+ tolerance_factor_hard: hard tolerance factor measured in standard deviations
+ of pdb distributions
+
+ Returns:
+ Dict containing:
+ * 'c_n_loss_mean': Loss for peptide bond length violations
+ * 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned
+ by CA, C, N
+ * 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned
+ by C, N, CA
+ * 'per_residue_loss_sum': sum of all losses for each residue
+ * 'per_residue_violation_mask': mask denoting all residues with violation
+ present.
+ """
+ # Get the positions of the relevant backbone atoms.
+ this_ca_pos = pred_atom_positions[..., :-1, 1, :]
+ this_ca_mask = pred_atom_mask[..., :-1, 1]
+ this_c_pos = pred_atom_positions[..., :-1, 2, :]
+ this_c_mask = pred_atom_mask[..., :-1, 2]
+ next_n_pos = pred_atom_positions[..., 1:, 0, :]
+ next_n_mask = pred_atom_mask[..., 1:, 0]
+ next_ca_pos = pred_atom_positions[..., 1:, 1, :]
+ next_ca_mask = pred_atom_mask[..., 1:, 1]
+ has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
+
+ # Compute loss for the C--N bond.
+ c_n_bond_length = torch.sqrt(
+ eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)
+ )
+
+ # The C-N bond to proline has slightly different length because of the ring.
+ next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
+ gt_length = (
+ ~next_is_proline
+ ) * residue_constants.between_res_bond_length_c_n[
+ 0
+ ] + next_is_proline * residue_constants.between_res_bond_length_c_n[
+ 1
+ ]
+ gt_stddev = (
+ ~next_is_proline
+ ) * residue_constants.between_res_bond_length_stddev_c_n[
+ 0
+ ] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
+ 1
+ ]
+ c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
+ c_n_loss_per_residue = torch.nn.functional.relu(
+ c_n_bond_length_error - tolerance_factor_soft * gt_stddev
+ )
+ mask = this_c_mask * next_n_mask * has_no_gap_mask
+ c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (
+ torch.sum(mask, dim=-1) + eps
+ )
+ c_n_violation_mask = mask * (
+ c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
+ )
+
+ # Compute loss for the angles.
+ ca_c_bond_length = torch.sqrt(
+ eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)
+ )
+ n_ca_bond_length = torch.sqrt(
+ eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)
+ )
+
+ c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
+ c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None]
+ n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None]
+
+ ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1)
+ gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
+ gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
+ ca_c_n_cos_angle_error = torch.sqrt(
+ eps + (ca_c_n_cos_angle - gt_angle) ** 2
+ )
+ ca_c_n_loss_per_residue = torch.nn.functional.relu(
+ ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev
+ )
+ mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
+ ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (
+ torch.sum(mask, dim=-1) + eps
+ )
+ ca_c_n_violation_mask = mask * (
+ ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)
+ )
+
+ c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
+ gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
+ gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
+ c_n_ca_cos_angle_error = torch.sqrt(
+ eps + torch.square(c_n_ca_cos_angle - gt_angle)
+ )
+ c_n_ca_loss_per_residue = torch.nn.functional.relu(
+ c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev
+ )
+ mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
+ c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (
+ torch.sum(mask, dim=-1) + eps
+ )
+ c_n_ca_violation_mask = mask * (
+ c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
+ )
+
+ # Compute a per residue loss (equally distribute the loss to both
+ # neighbouring residues).
+ per_residue_loss_sum = (
+ c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
+ )
+ per_residue_loss_sum = 0.5 * (
+ torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
+ + torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
+ )
+
+ # Compute hard violations.
+ violation_mask = torch.max(
+ torch.stack(
+ [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
+ dim=-2,
+ ),
+ dim=-2,
+ )[0]
+ violation_mask = torch.maximum(
+ torch.nn.functional.pad(violation_mask, (0, 1)),
+ torch.nn.functional.pad(violation_mask, (1, 0)),
+ )
+
+ return {
+ "c_n_loss_mean": c_n_loss,
+ "ca_c_n_loss_mean": ca_c_n_loss,
+ "c_n_ca_loss_mean": c_n_ca_loss,
+ "per_residue_loss_sum": per_residue_loss_sum,
+ "per_residue_violation_mask": violation_mask,
+ }
+
+
+def between_residue_clash_loss(
+ atom14_pred_positions: torch.Tensor,
+ atom14_atom_exists: torch.Tensor,
+ atom14_atom_radius: torch.Tensor,
+ residue_index: torch.Tensor,
+ overlap_tolerance_soft=1.5,
+ overlap_tolerance_hard=1.5,
+ eps=1e-10,
+) -> Dict[str, torch.Tensor]:
+ """Loss to penalize steric clashes between residues.
+
+ This is a loss penalizing any steric clashes due to non bonded atoms in
+ different peptides coming too close. This loss corresponds to the part with
+ different residues of
+ Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.
+
+ Args:
+ atom14_pred_positions: Predicted positions of atoms in
+ global prediction frame
+ atom14_atom_exists: Mask denoting whether atom at positions exists for given
+ amino acid type
+ atom14_atom_radius: Van der Waals radius for each atom.
+ residue_index: Residue index for given amino acid.
+ overlap_tolerance_soft: Soft tolerance factor.
+ overlap_tolerance_hard: Hard tolerance factor.
+
+ Returns:
+ Dict containing:
+ * 'mean_loss': average clash loss
+ * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14)
+ * 'per_atom_clash_mask': mask whether atom clashes with any other atom
+ shape (N, 14)
+ """
+ fp_type = atom14_pred_positions.dtype
+
+ # Create the distance matrix.
+ # (N, N, 14, 14)
+ dists = torch.sqrt(
+ eps
+ + torch.sum(
+ (
+ atom14_pred_positions[..., :, None, :, None, :]
+ - atom14_pred_positions[..., None, :, None, :, :]
+ )
+ ** 2,
+ dim=-1,
+ )
+ )
+
+ # Create the mask for valid distances.
+ # shape (N, N, 14, 14)
+ dists_mask = (
+ atom14_atom_exists[..., :, None, :, None]
+ * atom14_atom_exists[..., None, :, None, :]
+ ).type(fp_type)
+
+ # Mask out all the duplicate entries in the lower triangular matrix.
+ # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
+ # are handled separately.
+ dists_mask = dists_mask * (
+ residue_index[..., :, None, None, None]
+ < residue_index[..., None, :, None, None]
+ )
+
+ # Backbone C--N bond between subsequent residues is no clash.
+ c_one_hot = torch.nn.functional.one_hot(
+ residue_index.new_tensor(2), num_classes=14
+ )
+ c_one_hot = c_one_hot.reshape(
+ *((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape
+ )
+ c_one_hot = c_one_hot.type(fp_type)
+ n_one_hot = torch.nn.functional.one_hot(
+ residue_index.new_tensor(0), num_classes=14
+ )
+ n_one_hot = n_one_hot.reshape(
+ *((1,) * len(residue_index.shape[:-1])), *n_one_hot.shape
+ )
+ n_one_hot = n_one_hot.type(fp_type)
+
+ neighbour_mask = (
+ residue_index[..., :, None, None, None] + 1
+ ) == residue_index[..., None, :, None, None]
+ c_n_bonds = (
+ neighbour_mask
+ * c_one_hot[..., None, None, :, None]
+ * n_one_hot[..., None, None, None, :]
+ )
+ dists_mask = dists_mask * (1.0 - c_n_bonds)
+
+ # Disulfide bridge between two cysteines is no clash.
+ cys = residue_constants.restype_name_to_atom14_names["CYS"]
+ cys_sg_idx = cys.index("SG")
+ cys_sg_idx = residue_index.new_tensor(cys_sg_idx)
+ cys_sg_idx = cys_sg_idx.reshape(
+ *((1,) * len(residue_index.shape[:-1])), 1
+ ).squeeze(-1)
+ cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14)
+ disulfide_bonds = (
+ cys_sg_one_hot[..., None, None, :, None]
+ * cys_sg_one_hot[..., None, None, None, :]
+ )
+ dists_mask = dists_mask * (1.0 - disulfide_bonds)
+
+ # Compute the lower bound for the allowed distances.
+ # shape (N, N, 14, 14)
+ dists_lower_bound = dists_mask * (
+ atom14_atom_radius[..., :, None, :, None]
+ + atom14_atom_radius[..., None, :, None, :]
+ )
+
+ # Compute the error.
+ # shape (N, N, 14, 14)
+ dists_to_low_error = dists_mask * torch.nn.functional.relu(
+ dists_lower_bound - overlap_tolerance_soft - dists
+ )
+
+ # Compute the mean loss.
+ # shape ()
+ mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask))
+
+ # Compute the per atom loss sum.
+ # shape (N, 14)
+ per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
+ dists_to_low_error, axis=(-3, -1)
+ )
+
+ # Compute the hard clash mask.
+ # shape (N, N, 14, 14)
+ clash_mask = dists_mask * (
+ dists < (dists_lower_bound - overlap_tolerance_hard)
+ )
+
+ # Compute the per atom clash.
+ # shape (N, 14)
+ per_atom_clash_mask = torch.maximum(
+ torch.amax(clash_mask, axis=(-4, -2)),
+ torch.amax(clash_mask, axis=(-3, -1)),
+ )
+
+ return {
+ "mean_loss": mean_loss, # shape ()
+ "per_atom_loss_sum": per_atom_loss_sum, # shape (N, 14)
+ "per_atom_clash_mask": per_atom_clash_mask, # shape (N, 14)
+ }
+
+
+def within_residue_violations(
+ atom14_pred_positions: torch.Tensor,
+ atom14_atom_exists: torch.Tensor,
+ atom14_dists_lower_bound: torch.Tensor,
+ atom14_dists_upper_bound: torch.Tensor,
+ tighten_bounds_for_loss=0.0,
+ eps=1e-10,
+) -> Dict[str, torch.Tensor]:
+ """Loss to penalize steric clashes within residues.
+
+ This is a loss penalizing any steric violations or clashes of non-bonded atoms
+ in a given peptide. This loss corresponds to the part with
+ the same residues of
+ Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.
+
+ Args:
+ atom14_pred_positions ([*, N, 14, 3]):
+ Predicted positions of atoms in global prediction frame.
+ atom14_atom_exists ([*, N, 14]):
+ Mask denoting whether atom at positions exists for given
+ amino acid type
+ atom14_dists_lower_bound ([*, N, 14]):
+ Lower bound on allowed distances.
+ atom14_dists_upper_bound ([*, N, 14]):
+ Upper bound on allowed distances
+ tighten_bounds_for_loss ([*, N]):
+ Extra factor to tighten loss
+
+ Returns:
+ Dict containing:
+ * 'per_atom_loss_sum' ([*, N, 14]):
+ sum of all clash losses per atom, shape
+ * 'per_atom_clash_mask' ([*, N, 14]):
+ mask whether atom clashes with any other atom shape
+ """
+ # Compute the mask for each residue.
+ dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None]
+ dists_masks = dists_masks.reshape(
+ *((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape
+ )
+ dists_masks = (
+ atom14_atom_exists[..., :, :, None]
+ * atom14_atom_exists[..., :, None, :]
+ * dists_masks
+ )
+
+ # Distance matrix
+ dists = torch.sqrt(
+ eps
+ + torch.sum(
+ (
+ atom14_pred_positions[..., :, :, None, :]
+ - atom14_pred_positions[..., :, None, :, :]
+ )
+ ** 2,
+ dim=-1,
+ )
+ )
+
+ # Compute the loss.
+ dists_to_low_error = torch.nn.functional.relu(
+ atom14_dists_lower_bound + tighten_bounds_for_loss - dists
+ )
+ dists_to_high_error = torch.nn.functional.relu(
+ dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)
+ )
+ loss = dists_masks * (dists_to_low_error + dists_to_high_error)
+
+ # Compute the per atom loss sum.
+ per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)
+
+ # Compute the violations mask.
+ violations = dists_masks * (
+ (dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
+ )
+
+ # Compute the per atom violations.
+ per_atom_violations = torch.maximum(
+ torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0]
+ )
+
+ return {
+ "per_atom_loss_sum": per_atom_loss_sum,
+ "per_atom_violations": per_atom_violations,
+ }
+
+
+def find_structural_violations(
+ batch: Dict[str, torch.Tensor],
+ atom14_pred_positions: torch.Tensor,
+ violation_tolerance_factor: float,
+ clash_overlap_tolerance: float,
+ **kwargs,
+) -> Dict[str, torch.Tensor]:
+ """Computes several checks for structural violations."""
+
+ # Compute between residue backbone violations of bonds and angles.
+ connection_violations = between_residue_bond_loss(
+ pred_atom_positions=atom14_pred_positions,
+ pred_atom_mask=batch["atom14_atom_exists"],
+ residue_index=batch["residue_index"],
+ aatype=batch["aatype"],
+ tolerance_factor_soft=violation_tolerance_factor,
+ tolerance_factor_hard=violation_tolerance_factor,
+ )
+
+ # Compute the Van der Waals radius for every atom
+ # (the first letter of the atom name is the element type).
+ # Shape: (N, 14).
+ atomtype_radius = [
+ residue_constants.van_der_waals_radius[name[0]]
+ for name in residue_constants.atom_types
+ ]
+ atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
+ atom14_atom_radius = (
+ batch["atom14_atom_exists"]
+ * atomtype_radius[batch["residx_atom14_to_atom37"]]
+ )
+
+ # Compute the between residue clash loss.
+ between_residue_clashes = between_residue_clash_loss(
+ atom14_pred_positions=atom14_pred_positions,
+ atom14_atom_exists=batch["atom14_atom_exists"],
+ atom14_atom_radius=atom14_atom_radius,
+ residue_index=batch["residue_index"],
+ overlap_tolerance_soft=clash_overlap_tolerance,
+ overlap_tolerance_hard=clash_overlap_tolerance,
+ )
+
+ # Compute all within-residue violations (clashes,
+ # bond length and angle violations).
+ restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
+ overlap_tolerance=clash_overlap_tolerance,
+ bond_length_tolerance_factor=violation_tolerance_factor,
+ )
+ atom14_atom_exists = batch["atom14_atom_exists"]
+ atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
+ restype_atom14_bounds["lower_bound"]
+ )[batch["aatype"]]
+ atom14_dists_upper_bound = atom14_pred_positions.new_tensor(
+ restype_atom14_bounds["upper_bound"]
+ )[batch["aatype"]]
+ residue_violations = within_residue_violations(
+ atom14_pred_positions=atom14_pred_positions,
+ atom14_atom_exists=batch["atom14_atom_exists"],
+ atom14_dists_lower_bound=atom14_dists_lower_bound,
+ atom14_dists_upper_bound=atom14_dists_upper_bound,
+ tighten_bounds_for_loss=0.0,
+ )
+
+ # Combine them to a single per-residue violation mask (used later for LDDT).
+ per_residue_violations_mask = torch.max(
+ torch.stack(
+ [
+ connection_violations["per_residue_violation_mask"],
+ torch.max(
+ between_residue_clashes["per_atom_clash_mask"], dim=-1
+ )[0],
+ torch.max(residue_violations["per_atom_violations"], dim=-1)[0],
+ ],
+ dim=-1,
+ ),
+ dim=-1,
+ )[0]
+
+ return {
+ "between_residues": {
+ "bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"], # ()
+ "angles_ca_c_n_loss_mean": connection_violations[
+ "ca_c_n_loss_mean"
+ ], # ()
+ "angles_c_n_ca_loss_mean": connection_violations[
+ "c_n_ca_loss_mean"
+ ], # ()
+ "connections_per_residue_loss_sum": connection_violations[
+ "per_residue_loss_sum"
+ ], # (N)
+ "connections_per_residue_violation_mask": connection_violations[
+ "per_residue_violation_mask"
+ ], # (N)
+ "clashes_mean_loss": between_residue_clashes["mean_loss"], # ()
+ "clashes_per_atom_loss_sum": between_residue_clashes[
+ "per_atom_loss_sum"
+ ], # (N, 14)
+ "clashes_per_atom_clash_mask": between_residue_clashes[
+ "per_atom_clash_mask"
+ ], # (N, 14)
+ },
+ "within_residues": {
+ "per_atom_loss_sum": residue_violations[
+ "per_atom_loss_sum"
+ ], # (N, 14)
+ "per_atom_violations": residue_violations[
+ "per_atom_violations"
+ ], # (N, 14),
+ },
+ "total_per_residue_violations_mask": per_residue_violations_mask, # (N)
+ }
+
+
+def find_structural_violations_np(
+ batch: Dict[str, np.ndarray],
+ atom14_pred_positions: np.ndarray,
+ config: ml_collections.ConfigDict,
+) -> Dict[str, np.ndarray]:
+ to_tensor = lambda x: torch.tensor(x)
+ batch = tree_map(to_tensor, batch, np.ndarray)
+ atom14_pred_positions = to_tensor(atom14_pred_positions)
+
+ out = find_structural_violations(batch, atom14_pred_positions, **config)
+
+ to_np = lambda x: np.array(x)
+ np_out = tensor_tree_map(to_np, out)
+
+ return np_out
+
+
+def extreme_ca_ca_distance_violations(
+ pred_atom_positions: torch.Tensor, # (N, 37(14), 3)
+ pred_atom_mask: torch.Tensor, # (N, 37(14))
+ residue_index: torch.Tensor, # (N)
+ max_angstrom_tolerance=1.5,
+ eps=1e-6,
+) -> torch.Tensor:
+ """Counts residues whose Ca is a large distance from its neighbour.
+
+ Measures the fraction of CA-CA pairs between consecutive amino acids that are
+ more than 'max_angstrom_tolerance' apart.
+
+ Args:
+ pred_atom_positions: Atom positions in atom37/14 representation
+ pred_atom_mask: Atom mask in atom37/14 representation
+ residue_index: Residue index for given amino acid, this is assumed to be
+ monotonically increasing.
+ max_angstrom_tolerance: Maximum distance allowed to not count as violation.
+ Returns:
+ Fraction of consecutive CA-CA pairs with violation.
+ """
+ this_ca_pos = pred_atom_positions[..., :-1, 1, :]
+ this_ca_mask = pred_atom_mask[..., :-1, 1]
+ next_ca_pos = pred_atom_positions[..., 1:, 1, :]
+ next_ca_mask = pred_atom_mask[..., 1:, 1]
+ has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
+ ca_ca_distance = torch.sqrt(
+ eps + torch.sum((this_ca_pos - next_ca_pos) ** 2, dim=-1)
+ )
+ violations = (
+ ca_ca_distance - residue_constants.ca_ca
+ ) > max_angstrom_tolerance
+ mask = this_ca_mask * next_ca_mask * has_no_gap_mask
+ mean = masked_mean(mask, violations, -1)
+ return mean
+
+
+def compute_violation_metrics(
+ batch: Dict[str, torch.Tensor],
+ atom14_pred_positions: torch.Tensor, # (N, 14, 3)
+ violations: Dict[str, torch.Tensor],
+) -> Dict[str, torch.Tensor]:
+ """Compute several metrics to assess the structural violations."""
+ ret = {}
+ extreme_ca_ca_violations = extreme_ca_ca_distance_violations(
+ pred_atom_positions=atom14_pred_positions,
+ pred_atom_mask=batch["atom14_atom_exists"],
+ residue_index=batch["residue_index"],
+ )
+ ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations
+ ret["violations_between_residue_bond"] = masked_mean(
+ batch["seq_mask"],
+ violations["between_residues"][
+ "connections_per_residue_violation_mask"
+ ],
+ dim=-1,
+ )
+ ret["violations_between_residue_clash"] = masked_mean(
+ mask=batch["seq_mask"],
+ value=torch.max(
+ violations["between_residues"]["clashes_per_atom_clash_mask"],
+ dim=-1,
+ )[0],
+ dim=-1,
+ )
+ ret["violations_within_residue"] = masked_mean(
+ mask=batch["seq_mask"],
+ value=torch.max(
+ violations["within_residues"]["per_atom_violations"], dim=-1
+ )[0],
+ dim=-1,
+ )
+ ret["violations_per_residue"] = masked_mean(
+ mask=batch["seq_mask"],
+ value=violations["total_per_residue_violations_mask"],
+ dim=-1,
+ )
+ return ret
+
+
+def compute_violation_metrics_np(
+ batch: Dict[str, np.ndarray],
+ atom14_pred_positions: np.ndarray,
+ violations: Dict[str, np.ndarray],
+) -> Dict[str, np.ndarray]:
+ to_tensor = lambda x: torch.tensor(x)
+ batch = tree_map(to_tensor, batch, np.ndarray)
+ atom14_pred_positions = to_tensor(atom14_pred_positions)
+ violations = tree_map(to_tensor, violations, np.ndarray)
+
+ out = compute_violation_metrics(batch, atom14_pred_positions, violations)
+
+ to_np = lambda x: np.array(x)
+ return tree_map(to_np, out, torch.Tensor)
+
+
+def violation_loss(
+ violations: Dict[str, torch.Tensor],
+ atom14_atom_exists: torch.Tensor,
+ eps=1e-6,
+ **kwargs,
+) -> torch.Tensor:
+ num_atoms = torch.sum(atom14_atom_exists)
+ l_clash = torch.sum(
+ violations["between_residues"]["clashes_per_atom_loss_sum"]
+ + violations["within_residues"]["per_atom_loss_sum"]
+ )
+ l_clash = l_clash / (eps + num_atoms)
+ loss = (
+ violations["between_residues"]["bonds_c_n_loss_mean"]
+ + violations["between_residues"]["angles_ca_c_n_loss_mean"]
+ + violations["between_residues"]["angles_c_n_ca_loss_mean"]
+ + l_clash
+ )
+
+ return loss
+
+
+def compute_renamed_ground_truth(
+ batch: Dict[str, torch.Tensor],
+ atom14_pred_positions: torch.Tensor,
+ eps=1e-10,
+) -> Dict[str, torch.Tensor]:
+ """
+ Find optimal renaming of ground truth based on the predicted positions.
+
+ Alg. 26 "renameSymmetricGroundTruthAtoms"
+
+ This renamed ground truth is then used for all losses,
+ such that each loss moves the atoms in the same direction.
+
+ Args:
+ batch: Dictionary containing:
+ * atom14_gt_positions: Ground truth positions.
+ * atom14_alt_gt_positions: Ground truth positions with renaming swaps.
+ * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
+ renaming swaps.
+ * atom14_gt_exists: Mask for which atoms exist in ground truth.
+ * atom14_alt_gt_exists: Mask for which atoms exist in ground truth
+ after renaming.
+ * atom14_atom_exists: Mask for whether each atom is part of the given
+ amino acid type.
+ atom14_pred_positions: Array of atom positions in global frame with shape
+ Returns:
+ Dictionary containing:
+ alt_naming_is_better: Array with 1.0 where alternative swap is better.
+ renamed_atom14_gt_positions: Array of optimal ground truth positions
+ after renaming swaps are performed.
+ renamed_atom14_gt_exists: Mask after renaming swap is performed.
+ """
+
+ pred_dists = torch.sqrt(
+ eps
+ + torch.sum(
+ (
+ atom14_pred_positions[..., None, :, None, :]
+ - atom14_pred_positions[..., None, :, None, :, :]
+ )
+ ** 2,
+ dim=-1,
+ )
+ )
+
+ atom14_gt_positions = batch["atom14_gt_positions"]
+ gt_dists = torch.sqrt(
+ eps
+ + torch.sum(
+ (
+ atom14_gt_positions[..., None, :, None, :]
+ - atom14_gt_positions[..., None, :, None, :, :]
+ )
+ ** 2,
+ dim=-1,
+ )
+ )
+
+ atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
+ alt_gt_dists = torch.sqrt(
+ eps
+ + torch.sum(
+ (
+ atom14_alt_gt_positions[..., None, :, None, :]
+ - atom14_alt_gt_positions[..., None, :, None, :, :]
+ )
+ ** 2,
+ dim=-1,
+ )
+ )
+
+ lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
+ alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)
+
+ atom14_gt_exists = batch["atom14_gt_exists"]
+ atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
+ mask = (
+ atom14_gt_exists[..., None, :, None]
+ * atom14_atom_is_ambiguous[..., None, :, None]
+ * atom14_gt_exists[..., None, :, None, :]
+ * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
+ )
+
+ per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
+ alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3))
+
+ fp_type = atom14_pred_positions.dtype
+ alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)
+
+ renamed_atom14_gt_positions = (
+ 1.0 - alt_naming_is_better[..., None, None]
+ ) * atom14_gt_positions + alt_naming_is_better[
+ ..., None, None
+ ] * atom14_alt_gt_positions
+
+ renamed_atom14_gt_mask = (
+ 1.0 - alt_naming_is_better[..., None]
+ ) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
+ "atom14_alt_gt_exists"
+ ]
+
+ return {
+ "alt_naming_is_better": alt_naming_is_better,
+ "renamed_atom14_gt_positions": renamed_atom14_gt_positions,
+ "renamed_atom14_gt_exists": renamed_atom14_gt_mask,
+ }
+
+
+def experimentally_resolved_loss(
+ logits: torch.Tensor,
+ atom37_atom_exists: torch.Tensor,
+ all_atom_mask: torch.Tensor,
+ resolution: torch.Tensor,
+ min_resolution: float,
+ max_resolution: float,
+ eps: float = 1e-8,
+ **kwargs,
+) -> torch.Tensor:
+ errors = sigmoid_cross_entropy(logits, all_atom_mask)
+ loss = torch.sum(errors * atom37_atom_exists, dim=-1)
+ loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
+ loss = torch.sum(loss, dim=-1)
+
+ loss = loss * (
+ (resolution >= min_resolution) & (resolution <= max_resolution)
+ )
+
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
+ """
+ Computes BERT-style masked MSA loss. Implements subsection 1.9.9.
+
+ Args:
+ logits: [*, N_seq, N_res, 23] predicted residue distribution
+ true_msa: [*, N_seq, N_res] true MSA
+ bert_mask: [*, N_seq, N_res] MSA mask
+ Returns:
+ Masked MSA loss
+ """
+ errors = softmax_cross_entropy(
+ logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
+ )
+
+ # FP16-friendly averaging. Equivalent to:
+ # loss = (
+ # torch.sum(errors * bert_mask, dim=(-1, -2)) /
+ # (eps + torch.sum(bert_mask, dim=(-1, -2)))
+ # )
+ loss = errors * bert_mask
+ loss = torch.sum(loss, dim=-1)
+ scale = 0.5
+ denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2))
+ loss = loss / denom[..., None]
+ loss = torch.sum(loss, dim=-1)
+ loss = loss * scale
+
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def compute_drmsd(structure_1, structure_2, mask=None):
+ if(mask is not None):
+ structure_1 = structure_1 * mask[..., None]
+ structure_2 = structure_2 * mask[..., None]
+
+ d1 = structure_1[..., :, None, :] - structure_1[..., None, :, :]
+ d2 = structure_2[..., :, None, :] - structure_2[..., None, :, :]
+
+ d1 = d1 ** 2
+ d2 = d2 ** 2
+
+ d1 = torch.sqrt(torch.sum(d1, dim=-1))
+ d2 = torch.sqrt(torch.sum(d2, dim=-1))
+
+ drmsd = d1 - d2
+ drmsd = drmsd ** 2
+ drmsd = torch.sum(drmsd, dim=(-1, -2))
+ n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1)
+ drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
+ drmsd = torch.sqrt(drmsd)
+
+ return drmsd
+
+
+def compute_drmsd_np(structure_1, structure_2, mask=None):
+ structure_1 = torch.tensor(structure_1)
+ structure_2 = torch.tensor(structure_2)
+ if(mask is not None):
+ mask = torch.tensor(mask)
+
+ return compute_drmsd(structure_1, structure_2, mask)
+
+
+class AlphaFoldLoss(nn.Module):
+ """Aggregation of the various losses described in the supplement"""
+ def __init__(self, config):
+ super(AlphaFoldLoss, self).__init__()
+ self.config = config
+
+ def forward(self, out, batch, _return_breakdown=False):
+ if "violation" not in out.keys():
+ out["violation"] = find_structural_violations(
+ batch,
+ out["sm"]["positions"][-1],
+ **self.config.violation,
+ )
+
+ if "renamed_atom14_gt_positions" not in out.keys():
+ batch.update(
+ compute_renamed_ground_truth(
+ batch,
+ out["sm"]["positions"][-1],
+ )
+ )
+
+ loss_fns = {
+ "distogram": lambda: distogram_loss(
+ logits=out["distogram_logits"],
+ **{**batch, **self.config.distogram},
+ ),
+ "experimentally_resolved": lambda: experimentally_resolved_loss(
+ logits=out["experimentally_resolved_logits"],
+ **{**batch, **self.config.experimentally_resolved},
+ ),
+ "fape": lambda: fape_loss(
+ out,
+ batch,
+ self.config.fape,
+ ),
+ "lddt": lambda: lddt_loss(
+ logits=out["lddt_logits"],
+ all_atom_pred_pos=out["final_atom_positions"],
+ **{**batch, **self.config.lddt},
+ ),
+ "masked_msa": lambda: masked_msa_loss(
+ logits=out["masked_msa_logits"],
+ **{**batch, **self.config.masked_msa},
+ ),
+ "supervised_chi": lambda: supervised_chi_loss(
+ out["sm"]["angles"],
+ out["sm"]["unnormalized_angles"],
+ **{**batch, **self.config.supervised_chi},
+ ),
+ "violation": lambda: violation_loss(
+ out["violation"],
+ **batch,
+ ),
+ }
+
+ if(self.config.tm.enabled):
+ loss_fns["tm"] = lambda: tm_loss(
+ logits=out["tm_logits"],
+ **{**batch, **out, **self.config.tm},
+ )
+
+ cum_loss = 0.
+ losses = {}
+ for loss_name, loss_fn in loss_fns.items():
+ weight = self.config[loss_name].weight
+ loss = loss_fn()
+ if(torch.isnan(loss) or torch.isinf(loss)):
+ logging.warning(f"{loss_name} loss is NaN. Skipping...")
+ loss = loss.new_tensor(0., requires_grad=True)
+ cum_loss = cum_loss + weight * loss
+ losses[loss_name] = loss.detach().clone()
+
+ losses["unscaled_loss"] = cum_loss.detach().clone()
+
+ # Scale the loss by the square root of the minimum of the crop size and
+ # the (average) sequence length. See subsection 1.9.
+ seq_len = torch.mean(batch["seq_length"].float())
+ crop_len = batch["aatype"].shape[-1]
+ cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))
+
+ losses["loss"] = cum_loss.detach().clone()
+
+ if(not _return_breakdown):
+ return cum_loss
+
+ return cum_loss, losses
diff --git a/openfold/utils/lr_schedulers.py b/openfold/utils/lr_schedulers.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d068be5ec826e7d1f8d49b3459b7cc66547f9f8
--- /dev/null
+++ b/openfold/utils/lr_schedulers.py
@@ -0,0 +1,107 @@
+import torch
+
+
+class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
+ """ Implements the learning rate schedule defined in the AlphaFold 2
+ supplement. A linear warmup is followed by a plateau at the maximum
+ learning rate and then exponential decay.
+
+ Note that the initial learning rate of the optimizer in question is
+ ignored; use this class' base_lr parameter to specify the starting
+ point of the warmup.
+ """
+ def __init__(self,
+ optimizer,
+ last_epoch: int = -1,
+ verbose: bool = False,
+ base_lr: float = 0.,
+ max_lr: float = 0.001,
+ warmup_no_steps: int = 1000,
+ start_decay_after_n_steps: int = 50000,
+ decay_every_n_steps: int = 50000,
+ decay_factor: float = 0.95,
+ ):
+ step_counts = {
+ "warmup_no_steps": warmup_no_steps,
+ "start_decay_after_n_steps": start_decay_after_n_steps,
+ }
+
+ for k,v in step_counts.items():
+ if(v < 0):
+ raise ValueError(f"{k} must be nonnegative")
+
+ if(warmup_no_steps > start_decay_after_n_steps):
+ raise ValueError(
+ "warmup_no_steps must not exceed start_decay_after_n_steps"
+ )
+
+ self.optimizer = optimizer
+ self.last_epoch = last_epoch
+ self.verbose = verbose
+ self.base_lr = base_lr
+ self.max_lr = max_lr
+ self.warmup_no_steps = warmup_no_steps
+ self.start_decay_after_n_steps = start_decay_after_n_steps
+ self.decay_every_n_steps = decay_every_n_steps
+ self.decay_factor = decay_factor
+
+ super(AlphaFoldLRScheduler, self).__init__(
+ optimizer,
+ last_epoch=last_epoch,
+ verbose=verbose,
+ )
+
+ def state_dict(self):
+ state_dict = {
+ k:v for k,v in self.__dict__.items() if k not in ["optimizer"]
+ }
+
+ return state_dict
+
+ def load_state_dict(self, state_dict):
+ self.__dict__.update(state_dict)
+
+ def get_lr(self):
+ if(not self._get_lr_called_within_step):
+ raise RuntimeError(
+ "To get the last learning rate computed by the scheduler, use "
+ "get_last_lr()"
+ )
+
+ step_no = self.last_epoch
+
+ if(step_no <= self.warmup_no_steps):
+ lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr
+ elif(step_no > self.start_decay_after_n_steps):
+ steps_since_decay = step_no - self.start_decay_after_n_steps
+ exp = (steps_since_decay // self.decay_every_n_steps) + 1
+ lr = self.max_lr * (self.decay_factor ** exp)
+ else: # plateau
+ lr = self.max_lr
+
+ return [lr for group in self.optimizer.param_groups]
+
+
+class TestAF2LRScheduler(AlphaFoldLRScheduler):
+ def __init__(self,
+ optimizer,
+ last_epoch: int = -1,
+ verbose: bool = False,
+ base_lr: float = 0.,
+ max_lr: float = 0.0001,
+ warmup_no_steps: int = 10,
+ start_decay_after_n_steps: int = 100,
+ decay_every_n_steps: int = 10,
+ decay_factor: float = 0.95,
+ ):
+ super().__init__(
+ optimizer,
+ last_epoch,
+ verbose,
+ base_lr,
+ max_lr,
+ warmup_no_steps,
+ start_decay_after_n_steps,
+ decay_every_n_steps,
+ decay_factor,
+ )
\ No newline at end of file
diff --git a/openfold/utils/precision_utils.py b/openfold/utils/precision_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..81ef568637a1f6ee30a8fc295edbb60a70f4c602
--- /dev/null
+++ b/openfold/utils/precision_utils.py
@@ -0,0 +1,23 @@
+# Copyright 2022 AlQuraishi Laboratory
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+
+import torch
+
+def is_fp16_enabled():
+ # Autocast world
+ fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
+ fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
+
+ return fp16_enabled
diff --git a/openfold/utils/rigid_utils.py b/openfold/utils/rigid_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a49d48d76764c42516f6e33da9f738fb45d0c7b5
--- /dev/null
+++ b/openfold/utils/rigid_utils.py
@@ -0,0 +1,1468 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple, Any, Sequence, Callable, Optional
+
+import numpy as np
+import torch
+
+
+def rot_matmul(
+ a: torch.Tensor,
+ b: torch.Tensor
+) -> torch.Tensor:
+ """
+ Performs matrix multiplication of two rotation matrix tensors. Written
+ out by hand to avoid AMP downcasting.
+
+ Args:
+ a: [*, 3, 3] left multiplicand
+ b: [*, 3, 3] right multiplicand
+ Returns:
+ The product ab
+ """
+ row_1 = torch.stack(
+ [
+ a[..., 0, 0] * b[..., 0, 0]
+ + a[..., 0, 1] * b[..., 1, 0]
+ + a[..., 0, 2] * b[..., 2, 0],
+ a[..., 0, 0] * b[..., 0, 1]
+ + a[..., 0, 1] * b[..., 1, 1]
+ + a[..., 0, 2] * b[..., 2, 1],
+ a[..., 0, 0] * b[..., 0, 2]
+ + a[..., 0, 1] * b[..., 1, 2]
+ + a[..., 0, 2] * b[..., 2, 2],
+ ],
+ dim=-1,
+ )
+ row_2 = torch.stack(
+ [
+ a[..., 1, 0] * b[..., 0, 0]
+ + a[..., 1, 1] * b[..., 1, 0]
+ + a[..., 1, 2] * b[..., 2, 0],
+ a[..., 1, 0] * b[..., 0, 1]
+ + a[..., 1, 1] * b[..., 1, 1]
+ + a[..., 1, 2] * b[..., 2, 1],
+ a[..., 1, 0] * b[..., 0, 2]
+ + a[..., 1, 1] * b[..., 1, 2]
+ + a[..., 1, 2] * b[..., 2, 2],
+ ],
+ dim=-1,
+ )
+ row_3 = torch.stack(
+ [
+ a[..., 2, 0] * b[..., 0, 0]
+ + a[..., 2, 1] * b[..., 1, 0]
+ + a[..., 2, 2] * b[..., 2, 0],
+ a[..., 2, 0] * b[..., 0, 1]
+ + a[..., 2, 1] * b[..., 1, 1]
+ + a[..., 2, 2] * b[..., 2, 1],
+ a[..., 2, 0] * b[..., 0, 2]
+ + a[..., 2, 1] * b[..., 1, 2]
+ + a[..., 2, 2] * b[..., 2, 2],
+ ],
+ dim=-1,
+ )
+
+ return torch.stack([row_1, row_2, row_3], dim=-2)
+
+
+def rot_vec_mul(
+ r: torch.Tensor,
+ t: torch.Tensor
+) -> torch.Tensor:
+ """
+ Applies a rotation to a vector. Written out by hand to avoid transfer
+ to avoid AMP downcasting.
+
+ Args:
+ r: [*, 3, 3] rotation matrices
+ t: [*, 3] coordinate tensors
+ Returns:
+ [*, 3] rotated coordinates
+ """
+ x = t[..., 0]
+ y = t[..., 1]
+ z = t[..., 2]
+ return torch.stack(
+ [
+ r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
+ r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
+ r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
+ ],
+ dim=-1,
+ )
+
+
+def identity_rot_mats(
+ batch_dims: Tuple[int],
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ requires_grad: bool = True,
+) -> torch.Tensor:
+ rots = torch.eye(
+ 3, dtype=dtype, device=device, requires_grad=requires_grad
+ )
+ rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
+ rots = rots.expand(*batch_dims, -1, -1)
+
+ return rots
+
+
+def identity_trans(
+ batch_dims: Tuple[int],
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ requires_grad: bool = True,
+) -> torch.Tensor:
+ trans = torch.zeros(
+ (*batch_dims, 3),
+ dtype=dtype,
+ device=device,
+ requires_grad=requires_grad
+ )
+ return trans
+
+
+def identity_quats(
+ batch_dims: Tuple[int],
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ requires_grad: bool = True,
+) -> torch.Tensor:
+ quat = torch.zeros(
+ (*batch_dims, 4),
+ dtype=dtype,
+ device=device,
+ requires_grad=requires_grad
+ )
+
+ with torch.no_grad():
+ quat[..., 0] = 1
+
+ return quat
+
+
+_quat_elements = ["a", "b", "c", "d"]
+_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
+_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
+
+
+def _to_mat(pairs):
+ mat = np.zeros((4, 4))
+ for pair in pairs:
+ key, value = pair
+ ind = _qtr_ind_dict[key]
+ mat[ind // 4][ind % 4] = value
+
+ return mat
+
+
+_QTR_MAT = np.zeros((4, 4, 3, 3))
+_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
+_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
+_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
+_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
+_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
+_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
+_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
+_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
+_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
+
+
+def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
+ """
+ Converts a quaternion to a rotation matrix.
+
+ Args:
+ quat: [*, 4] quaternions
+ Returns:
+ [*, 3, 3] rotation matrices
+ """
+ # [*, 4, 4]
+ quat = quat[..., None] * quat[..., None, :]
+
+ # [4, 4, 3, 3]
+ mat = quat.new_tensor(_QTR_MAT, requires_grad=False)
+
+ # [*, 4, 4, 3, 3]
+ shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
+ quat = quat[..., None, None] * shaped_qtr_mat
+
+ # [*, 3, 3]
+ return torch.sum(quat, dim=(-3, -4))
+
+
+def rot_to_quat(
+ rot: torch.Tensor,
+):
+ if(rot.shape[-2:] != (3, 3)):
+ raise ValueError("Input rotation is incorrectly shaped")
+
+ rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
+ [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
+
+ k = [
+ [ xx + yy + zz, zy - yz, xz - zx, yx - xy,],
+ [ zy - yz, xx - yy - zz, xy + yx, xz + zx,],
+ [ xz - zx, xy + yx, yy - xx - zz, yz + zy,],
+ [ yx - xy, xz + zx, yz + zy, zz - xx - yy,]
+ ]
+
+ k = (1./3.) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)
+
+ _, vectors = torch.linalg.eigh(k)
+ return vectors[..., -1]
+
+
+_QUAT_MULTIPLY = np.zeros((4, 4, 4))
+_QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
+ [ 0,-1, 0, 0],
+ [ 0, 0,-1, 0],
+ [ 0, 0, 0,-1]]
+
+_QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
+ [ 1, 0, 0, 0],
+ [ 0, 0, 0, 1],
+ [ 0, 0,-1, 0]]
+
+_QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
+ [ 0, 0, 0,-1],
+ [ 1, 0, 0, 0],
+ [ 0, 1, 0, 0]]
+
+_QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
+ [ 0, 0, 1, 0],
+ [ 0,-1, 0, 0],
+ [ 1, 0, 0, 0]]
+
+_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
+
+
+def quat_multiply(quat1, quat2):
+ """Multiply a quaternion by another quaternion."""
+ mat = quat1.new_tensor(_QUAT_MULTIPLY)
+ reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
+ return torch.sum(
+ reshaped_mat *
+ quat1[..., :, None, None] *
+ quat2[..., None, :, None],
+ dim=(-3, -2)
+ )
+
+
+def quat_multiply_by_vec(quat, vec):
+ """Multiply a quaternion by a pure-vector quaternion."""
+ mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
+ reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
+ return torch.sum(
+ reshaped_mat *
+ quat[..., :, None, None] *
+ vec[..., None, :, None],
+ dim=(-3, -2)
+ )
+
+
+def invert_rot_mat(rot_mat: torch.Tensor):
+ return rot_mat.transpose(-1, -2)
+
+
+def invert_quat(quat: torch.Tensor):
+ quat_prime = quat.clone()
+ quat_prime[..., 1:] *= -1
+ inv = quat_prime / torch.sum(quat ** 2, dim=-1, keepdim=True)
+ return inv
+
+
+class Rotation:
+ """
+ A 3D rotation. Depending on how the object is initialized, the
+ rotation is represented by either a rotation matrix or a
+ quaternion, though both formats are made available by helper functions.
+ To simplify gradient computation, the underlying format of the
+ rotation cannot be changed in-place. Like Rigid, the class is designed
+ to mimic the behavior of a torch Tensor, almost as if each Rotation
+ object were a tensor of rotations, in one format or another.
+ """
+ def __init__(self,
+ rot_mats: Optional[torch.Tensor] = None,
+ quats: Optional[torch.Tensor] = None,
+ normalize_quats: bool = True,
+ ):
+ """
+ Args:
+ rot_mats:
+ A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
+ quats
+ quats:
+ A [*, 4] quaternion. Mutually exclusive with rot_mats. If
+ normalize_quats is not True, must be a unit quaternion
+ normalize_quats:
+ If quats is specified, whether to normalize quats
+ """
+ if((rot_mats is None and quats is None) or
+ (rot_mats is not None and quats is not None)):
+ raise ValueError("Exactly one input argument must be specified")
+
+ if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
+ (quats is not None and quats.shape[-1] != 4)):
+ raise ValueError(
+ "Incorrectly shaped rotation matrix or quaternion"
+ )
+
+ # Force full-precision
+ if(quats is not None):
+ quats = quats.type(torch.float32)
+ if(rot_mats is not None):
+ rot_mats = rot_mats.type(torch.float32)
+
+ if(quats is not None and normalize_quats):
+ quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
+
+ self._rot_mats = rot_mats
+ self._quats = quats
+
+ @staticmethod
+ def identity(
+ shape,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ requires_grad: bool = True,
+ fmt: str = "quat",
+ ):
+ """
+ Returns an identity Rotation.
+
+ Args:
+ shape:
+ The "shape" of the resulting Rotation object. See documentation
+ for the shape property
+ dtype:
+ The torch dtype for the rotation
+ device:
+ The torch device for the new rotation
+ requires_grad:
+ Whether the underlying tensors in the new rotation object
+ should require gradient computation
+ fmt:
+ One of "quat" or "rot_mat". Determines the underlying format
+ of the new object's rotation
+ Returns:
+ A new identity rotation
+ """
+ if(fmt == "rot_mat"):
+ rot_mats = identity_rot_mats(
+ shape, dtype, device, requires_grad,
+ )
+ return Rotation(rot_mats=rot_mats, quats=None)
+ elif(fmt == "quat"):
+ quats = identity_quats(shape, dtype, device, requires_grad)
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
+ else:
+ raise ValueError(f"Invalid format: f{fmt}")
+
+ # Magic methods
+
+ def __getitem__(self, index: Any):
+ """
+ Allows torch-style indexing over the virtual shape of the rotation
+ object. See documentation for the shape property.
+
+ Args:
+ index:
+ A torch index. E.g. (1, 3, 2), or (slice(None,))
+ Returns:
+ The indexed rotation
+ """
+ if type(index) != tuple:
+ index = (index,)
+
+ if(self._rot_mats is not None):
+ rot_mats = self._rot_mats[index + (slice(None), slice(None))]
+ return Rotation(rot_mats=rot_mats)
+ elif(self._quats is not None):
+ quats = self._quats[index + (slice(None),)]
+ return Rotation(quats=quats, normalize_quats=False)
+ else:
+ raise ValueError("Both rotations are None")
+
+ def __mul__(self,
+ right: torch.Tensor,
+ ):
+ """
+ Pointwise left multiplication of the rotation with a tensor. Can be
+ used to e.g. mask the Rotation.
+
+ Args:
+ right:
+ The tensor multiplicand
+ Returns:
+ The product
+ """
+ if not(isinstance(right, torch.Tensor)):
+ raise TypeError("The other multiplicand must be a Tensor")
+
+ if(self._rot_mats is not None):
+ rot_mats = self._rot_mats * right[..., None, None]
+ return Rotation(rot_mats=rot_mats, quats=None)
+ elif(self._quats is not None):
+ quats = self._quats * right[..., None]
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
+ else:
+ raise ValueError("Both rotations are None")
+
+ def __rmul__(self,
+ left: torch.Tensor,
+ ):
+ """
+ Reverse pointwise multiplication of the rotation with a tensor.
+
+ Args:
+ left:
+ The left multiplicand
+ Returns:
+ The product
+ """
+ return self.__mul__(left)
+
+ # Properties
+
+ @property
+ def shape(self) -> torch.Size:
+ """
+ Returns the virtual shape of the rotation object. This shape is
+ defined as the batch dimensions of the underlying rotation matrix
+ or quaternion. If the Rotation was initialized with a [10, 3, 3]
+ rotation matrix tensor, for example, the resulting shape would be
+ [10].
+
+ Returns:
+ The virtual shape of the rotation object
+ """
+ s = None
+ if(self._quats is not None):
+ s = self._quats.shape[:-1]
+ else:
+ s = self._rot_mats.shape[:-2]
+
+ return s
+
+ @property
+ def dtype(self) -> torch.dtype:
+ """
+ Returns the dtype of the underlying rotation.
+
+ Returns:
+ The dtype of the underlying rotation
+ """
+ if(self._rot_mats is not None):
+ return self._rot_mats.dtype
+ elif(self._quats is not None):
+ return self._quats.dtype
+ else:
+ raise ValueError("Both rotations are None")
+
+ @property
+ def device(self) -> torch.device:
+ """
+ The device of the underlying rotation
+
+ Returns:
+ The device of the underlying rotation
+ """
+ if(self._rot_mats is not None):
+ return self._rot_mats.device
+ elif(self._quats is not None):
+ return self._quats.device
+ else:
+ raise ValueError("Both rotations are None")
+
+ @property
+ def requires_grad(self) -> bool:
+ """
+ Returns the requires_grad property of the underlying rotation
+
+ Returns:
+ The requires_grad property of the underlying tensor
+ """
+ if(self._rot_mats is not None):
+ return self._rot_mats.requires_grad
+ elif(self._quats is not None):
+ return self._quats.requires_grad
+ else:
+ raise ValueError("Both rotations are None")
+
+ def get_rot_mats(self) -> torch.Tensor:
+ """
+ Returns the underlying rotation as a rotation matrix tensor.
+
+ Returns:
+ The rotation as a rotation matrix tensor
+ """
+ rot_mats = self._rot_mats
+ if(rot_mats is None):
+ if(self._quats is None):
+ raise ValueError("Both rotations are None")
+ else:
+ rot_mats = quat_to_rot(self._quats)
+
+ return rot_mats
+
+ def get_quats(self) -> torch.Tensor:
+ """
+ Returns the underlying rotation as a quaternion tensor.
+
+ Depending on whether the Rotation was initialized with a
+ quaternion, this function may call torch.linalg.eigh.
+
+ Returns:
+ The rotation as a quaternion tensor.
+ """
+ quats = self._quats
+ if(quats is None):
+ if(self._rot_mats is None):
+ raise ValueError("Both rotations are None")
+ else:
+ quats = rot_to_quat(self._rot_mats)
+
+ return quats
+
+ def get_cur_rot(self) -> torch.Tensor:
+ """
+ Return the underlying rotation in its current form
+
+ Returns:
+ The stored rotation
+ """
+ if(self._rot_mats is not None):
+ return self._rot_mats
+ elif(self._quats is not None):
+ return self._quats
+ else:
+ raise ValueError("Both rotations are None")
+
+ def get_rotvec(self, eps=1e-6) -> torch.Tensor:
+ """
+ Return the underlying axis-angle rotation vector.
+
+ Follow's scipy's implementation:
+ https://github.com/scipy/scipy/blob/HEAD/scipy/spatial/transform/_rotation.pyx#L1385-L1402
+
+ Returns:
+ The stored rotation as a axis-angle vector.
+ """
+ quat = self.get_quats()
+ # w > 0 to ensure 0 <= angle <= pi
+ flip = (quat[..., :1] < 0).float()
+ quat = (-1 * quat) * flip + (1 - flip) * quat
+
+ angle = 2 * torch.atan2(
+ torch.linalg.norm(quat[..., 1:], dim=-1),
+ quat[..., 0]
+ )
+
+ angle2 = angle * angle
+ small_angle_scales = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
+ large_angle_scales = angle / torch.sin(angle / 2 + eps)
+
+ small_angles = (angle <= 1e-3).float()
+ rot_vec_scale = small_angle_scales * small_angles + (1 - small_angles) * large_angle_scales
+ rot_vec = rot_vec_scale[..., None] * quat[..., 1:]
+ return rot_vec
+
+ # Rotation functions
+
+ def compose_q_update_vec(self,
+ q_update_vec: torch.Tensor,
+ normalize_quats: bool = True,
+ update_mask: torch.Tensor = None,
+ ):
+ """
+ Returns a new quaternion Rotation after updating the current
+ object's underlying rotation with a quaternion update, formatted
+ as a [*, 3] tensor whose final three columns represent x, y, z such
+ that (1, x, y, z) is the desired (not necessarily unit) quaternion
+ update.
+
+ Args:
+ q_update_vec:
+ A [*, 3] quaternion update tensor
+ normalize_quats:
+ Whether to normalize the output quaternion
+ Returns:
+ An updated Rotation
+ """
+ quats = self.get_quats()
+ quat_update = quat_multiply_by_vec(quats, q_update_vec)
+ if update_mask is not None:
+ quat_update = quat_update * update_mask
+ new_quats = quats + quat_update
+ return Rotation(
+ rot_mats=None,
+ quats=new_quats,
+ normalize_quats=normalize_quats,
+ )
+
+ def compose_r(self, r):
+ """
+ Compose the rotation matrices of the current Rotation object with
+ those of another.
+
+ Args:
+ r:
+ An update rotation object
+ Returns:
+ An updated rotation object
+ """
+ r1 = self.get_rot_mats()
+ r2 = r.get_rot_mats()
+ new_rot_mats = rot_matmul(r1, r2)
+ return Rotation(rot_mats=new_rot_mats, quats=None)
+
+ def compose_q(self, r, normalize_quats: bool = True):
+ """
+ Compose the quaternions of the current Rotation object with those
+ of another.
+
+ Depending on whether either Rotation was initialized with
+ quaternions, this function may call torch.linalg.eigh.
+
+ Args:
+ r:
+ An update rotation object
+ Returns:
+ An updated rotation object
+ """
+ q1 = self.get_quats()
+ q2 = r.get_quats()
+ new_quats = quat_multiply(q1, q2)
+ return Rotation(
+ rot_mats=None, quats=new_quats, normalize_quats=normalize_quats
+ )
+
+ def apply(self, pts: torch.Tensor) -> torch.Tensor:
+ """
+ Apply the current Rotation as a rotation matrix to a set of 3D
+ coordinates.
+
+ Args:
+ pts:
+ A [*, 3] set of points
+ Returns:
+ [*, 3] rotated points
+ """
+ rot_mats = self.get_rot_mats()
+ return rot_vec_mul(rot_mats, pts)
+
+ def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
+ """
+ The inverse of the apply() method.
+
+ Args:
+ pts:
+ A [*, 3] set of points
+ Returns:
+ [*, 3] inverse-rotated points
+ """
+ rot_mats = self.get_rot_mats()
+ inv_rot_mats = invert_rot_mat(rot_mats)
+ return rot_vec_mul(inv_rot_mats, pts)
+
+ def invert(self) :
+ """
+ Returns the inverse of the current Rotation.
+
+ Returns:
+ The inverse of the current Rotation
+ """
+ if(self._rot_mats is not None):
+ return Rotation(
+ rot_mats=invert_rot_mat(self._rot_mats),
+ quats=None
+ )
+ elif(self._quats is not None):
+ return Rotation(
+ rot_mats=None,
+ quats=invert_quat(self._quats),
+ normalize_quats=False,
+ )
+ else:
+ raise ValueError("Both rotations are None")
+
+ # "Tensor" stuff
+
+ def unsqueeze(self,
+ dim: int,
+ ):
+ """
+ Analogous to torch.unsqueeze. The dimension is relative to the
+ shape of the Rotation object.
+
+ Args:
+ dim: A positive or negative dimension index.
+ Returns:
+ The unsqueezed Rotation.
+ """
+ if dim >= len(self.shape):
+ raise ValueError("Invalid dimension")
+
+ if(self._rot_mats is not None):
+ rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
+ return Rotation(rot_mats=rot_mats, quats=None)
+ elif(self._quats is not None):
+ quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
+ else:
+ raise ValueError("Both rotations are None")
+
+ @staticmethod
+ def cat(
+ rs,
+ dim: int,
+ ):
+ """
+ Concatenates rotations along one of the batch dimensions. Analogous
+ to torch.cat().
+
+ Note that the output of this operation is always a rotation matrix,
+ regardless of the format of input rotations.
+
+ Args:
+ rs:
+ A list of rotation objects
+ dim:
+ The dimension along which the rotations should be
+ concatenated
+ Returns:
+ A concatenated Rotation object in rotation matrix format
+ """
+ rot_mats = [r.get_rot_mats() for r in rs]
+ rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)
+
+ return Rotation(rot_mats=rot_mats, quats=None)
+
+ def map_tensor_fn(self,
+ fn
+ ):
+ """
+ Apply a Tensor -> Tensor function to underlying rotation tensors,
+ mapping over the rotation dimension(s). Can be used e.g. to sum out
+ a one-hot batch dimension.
+
+ Args:
+ fn:
+ A Tensor -> Tensor function to be mapped over the Rotation
+ Returns:
+ The transformed Rotation object
+ """
+ if(self._rot_mats is not None):
+ rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
+ rot_mats = torch.stack(
+ list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
+ )
+ rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
+ return Rotation(rot_mats=rot_mats, quats=None)
+ elif(self._quats is not None):
+ quats = torch.stack(
+ list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
+ )
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
+ else:
+ raise ValueError("Both rotations are None")
+
+ def cuda(self):
+ """
+ Analogous to the cuda() method of torch Tensors
+
+ Returns:
+ A copy of the Rotation in CUDA memory
+ """
+ if(self._rot_mats is not None):
+ return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
+ elif(self._quats is not None):
+ return Rotation(
+ rot_mats=None,
+ quats=self._quats.cuda(),
+ normalize_quats=False
+ )
+ else:
+ raise ValueError("Both rotations are None")
+
+ def to(self,
+ device: Optional[torch.device],
+ dtype: Optional[torch.dtype]
+ ):
+ """
+ Analogous to the to() method of torch Tensors
+
+ Args:
+ device:
+ A torch device
+ dtype:
+ A torch dtype
+ Returns:
+ A copy of the Rotation using the new device and dtype
+ """
+ if(self._rot_mats is not None):
+ return Rotation(
+ rot_mats=self._rot_mats.to(device=device, dtype=dtype),
+ quats=None,
+ )
+ elif(self._quats is not None):
+ return Rotation(
+ rot_mats=None,
+ quats=self._quats.to(device=device, dtype=dtype),
+ normalize_quats=False,
+ )
+ else:
+ raise ValueError("Both rotations are None")
+
+ def detach(self):
+ """
+ Returns a copy of the Rotation whose underlying Tensor has been
+ detached from its torch graph.
+
+ Returns:
+ A copy of the Rotation whose underlying Tensor has been detached
+ from its torch graph
+ """
+ if(self._rot_mats is not None):
+ return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
+ elif(self._quats is not None):
+ return Rotation(
+ rot_mats=None,
+ quats=self._quats.detach(),
+ normalize_quats=False,
+ )
+ else:
+ raise ValueError("Both rotations are None")
+
+
+class Rigid:
+ """
+ A class representing a rigid transformation. Little more than a wrapper
+ around two objects: a Rotation object and a [*, 3] translation
+ Designed to behave approximately like a single torch tensor with the
+ shape of the shared batch dimensions of its component parts.
+ """
+ def __init__(self,
+ rots: Optional[Rotation],
+ trans: Optional[torch.Tensor],
+ ):
+ """
+ Args:
+ rots: A [*, 3, 3] rotation tensor
+ trans: A corresponding [*, 3] translation tensor
+ """
+ # (we need device, dtype, etc. from at least one input)
+
+ batch_dims, dtype, device, requires_grad = None, None, None, None
+ if(trans is not None):
+ batch_dims = trans.shape[:-1]
+ dtype = trans.dtype
+ device = trans.device
+ requires_grad = trans.requires_grad
+ elif(rots is not None):
+ batch_dims = rots.shape
+ dtype = rots.dtype
+ device = rots.device
+ requires_grad = rots.requires_grad
+ else:
+ raise ValueError("At least one input argument must be specified")
+
+ if(rots is None):
+ rots = Rotation.identity(
+ batch_dims, dtype, device, requires_grad,
+ )
+ elif(trans is None):
+ trans = identity_trans(
+ batch_dims, dtype, device, requires_grad,
+ )
+
+ if((rots.shape != trans.shape[:-1]) or
+ (rots.device != trans.device)):
+ raise ValueError("Rots and trans incompatible")
+
+ # Force full precision. Happens to the rotations automatically.
+ trans = trans.type(torch.float32)
+
+ self._rots = rots
+ self._trans = trans
+
+ @staticmethod
+ def identity(
+ shape: Tuple[int],
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ requires_grad: bool = True,
+ fmt: str = "quat",
+ ):
+ """
+ Constructs an identity transformation.
+
+ Args:
+ shape:
+ The desired shape
+ dtype:
+ The dtype of both internal tensors
+ device:
+ The device of both internal tensors
+ requires_grad:
+ Whether grad should be enabled for the internal tensors
+ Returns:
+ The identity transformation
+ """
+ return Rigid(
+ Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
+ identity_trans(shape, dtype, device, requires_grad),
+ )
+
+ def __getitem__(self,
+ index: Any,
+ ):
+ """
+ Indexes the affine transformation with PyTorch-style indices.
+ The index is applied to the shared dimensions of both the rotation
+ and the translation.
+
+ E.g.::
+
+ r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
+ t = Rigid(r, torch.rand(10, 10, 3))
+ indexed = t[3, 4:6]
+ assert(indexed.shape == (2,))
+ assert(indexed.get_rots().shape == (2,))
+ assert(indexed.get_trans().shape == (2, 3))
+
+ Args:
+ index: A standard torch tensor index. E.g. 8, (10, None, 3),
+ or (3, slice(0, 1, None))
+ Returns:
+ The indexed tensor
+ """
+ if type(index) != tuple:
+ index = (index,)
+
+ return Rigid(
+ self._rots[index],
+ self._trans[index + (slice(None),)],
+ )
+
+ def __mul__(self,
+ right: torch.Tensor,
+ ):
+ """
+ Pointwise left multiplication of the transformation with a tensor.
+ Can be used to e.g. mask the Rigid.
+
+ Args:
+ right:
+ The tensor multiplicand
+ Returns:
+ The product
+ """
+ if not(isinstance(right, torch.Tensor)):
+ raise TypeError("The other multiplicand must be a Tensor")
+
+ new_rots = self._rots * right
+ new_trans = self._trans * right[..., None]
+
+ return Rigid(new_rots, new_trans)
+
+ def __rmul__(self,
+ left: torch.Tensor,
+ ):
+ """
+ Reverse pointwise multiplication of the transformation with a
+ tensor.
+
+ Args:
+ left:
+ The left multiplicand
+ Returns:
+ The product
+ """
+ return self.__mul__(left)
+
+ @property
+ def shape(self) -> torch.Size:
+ """
+ Returns the shape of the shared dimensions of the rotation and
+ the translation.
+
+ Returns:
+ The shape of the transformation
+ """
+ s = self._trans.shape[:-1]
+ return s
+
+ @property
+ def device(self) -> torch.device:
+ """
+ Returns the device on which the Rigid's tensors are located.
+
+ Returns:
+ The device on which the Rigid's tensors are located
+ """
+ return self._trans.device
+
+ def get_rots(self) -> Rotation:
+ """
+ Getter for the rotation.
+
+ Returns:
+ The rotation object
+ """
+ return self._rots
+
+ def get_trans(self) -> torch.Tensor:
+ """
+ Getter for the translation.
+
+ Returns:
+ The stored translation
+ """
+ return self._trans
+
+ def compose_q_update_vec(self,
+ q_update_vec: torch.Tensor,
+ update_mask: torch.Tensor=None,
+ ):
+ """
+ Composes the transformation with a quaternion update vector of
+ shape [*, 6], where the final 6 columns represent the x, y, and
+ z values of a quaternion of form (1, x, y, z) followed by a 3D
+ translation.
+
+ Args:
+ q_vec: The quaternion update vector.
+ Returns:
+ The composed transformation.
+ """
+ q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
+ new_rots = self._rots.compose_q_update_vec(
+ q_vec, update_mask=update_mask)
+
+ trans_update = self._rots.apply(t_vec)
+ if update_mask is not None:
+ trans_update = trans_update * update_mask
+ new_translation = self._trans + trans_update
+
+ return Rigid(new_rots, new_translation)
+
+ def compose_tran_update_vec(self,
+ t_vec: torch.Tensor,
+ update_mask: torch.Tensor=None,
+ ):
+ """
+ Composes the transformation with a quaternion update vector of
+ shape [*, 3], where columns represent a 3D translation.
+
+ Args:
+ q_vec: The quaternion update vector.
+ Returns:
+ The composed transformation.
+ """
+ trans_update = self._rots.apply(t_vec)
+ if update_mask is not None:
+ trans_update = trans_update * update_mask
+ new_translation = self._trans + trans_update
+
+ return Rigid(self._rots, new_translation)
+
+ def compose(self,
+ r,
+ ):
+ """
+ Composes the current rigid object with another.
+
+ Args:
+ r:
+ Another Rigid object
+ Returns:
+ The composition of the two transformations
+ """
+ new_rot = self._rots.compose_r(r._rots)
+ new_trans = self._rots.apply(r._trans) + self._trans
+ return Rigid(new_rot, new_trans)
+
+ def compose_r(self,
+ rot,
+ order='right'
+ ):
+ """
+ Composes the current rigid object with another.
+
+ Args:
+ r:
+ Another Rigid object
+ order:
+ Order in which to perform rotation multiplication.
+ Returns:
+ The composition of the two transformations
+ """
+ if order == 'right':
+ new_rot = self._rots.compose_r(rot)
+ elif order == 'left':
+ new_rot = rot.compose_r(self._rots)
+ else:
+ raise ValueError(f'Unrecognized multiplication order: {order}')
+ return Rigid(new_rot, self._trans)
+
+ def apply(self,
+ pts: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Applies the transformation to a coordinate tensor.
+
+ Args:
+ pts: A [*, 3] coordinate tensor.
+ Returns:
+ The transformed points.
+ """
+ rotated = self._rots.apply(pts)
+ return rotated + self._trans
+
+ def invert_apply(self,
+ pts: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Applies the inverse of the transformation to a coordinate tensor.
+
+ Args:
+ pts: A [*, 3] coordinate tensor
+ Returns:
+ The transformed points.
+ """
+ pts = pts - self._trans
+ return self._rots.invert_apply(pts)
+
+ def invert(self):
+ """
+ Inverts the transformation.
+
+ Returns:
+ The inverse transformation.
+ """
+ rot_inv = self._rots.invert()
+ trn_inv = rot_inv.apply(self._trans)
+
+ return Rigid(rot_inv, -1 * trn_inv)
+
+ def map_tensor_fn(self,
+ fn
+ ):
+ """
+ Apply a Tensor -> Tensor function to underlying translation and
+ rotation tensors, mapping over the translation/rotation dimensions
+ respectively.
+
+ Args:
+ fn:
+ A Tensor -> Tensor function to be mapped over the Rigid
+ Returns:
+ The transformed Rigid object
+ """
+ new_rots = self._rots.map_tensor_fn(fn)
+ new_trans = torch.stack(
+ list(map(fn, torch.unbind(self._trans, dim=-1))),
+ dim=-1
+ )
+
+ return Rigid(new_rots, new_trans)
+
+ def to_tensor_4x4(self) -> torch.Tensor:
+ """
+ Converts a transformation to a homogenous transformation tensor.
+
+ Returns:
+ A [*, 4, 4] homogenous transformation tensor
+ """
+ tensor = self._trans.new_zeros((*self.shape, 4, 4))
+ tensor[..., :3, :3] = self._rots.get_rot_mats()
+ tensor[..., :3, 3] = self._trans
+ tensor[..., 3, 3] = 1
+ return tensor
+
+ @staticmethod
+ def from_tensor_4x4(
+ t: torch.Tensor
+ ):
+ """
+ Constructs a transformation from a homogenous transformation
+ tensor.
+
+ Args:
+ t: [*, 4, 4] homogenous transformation tensor
+ Returns:
+ T object with shape [*]
+ """
+ if(t.shape[-2:] != (4, 4)):
+ raise ValueError("Incorrectly shaped input tensor")
+
+ rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
+ trans = t[..., :3, 3]
+
+ return Rigid(rots, trans)
+
+ def to_tensor_7(self) -> torch.Tensor:
+ """
+ Converts a transformation to a tensor with 7 final columns, four
+ for the quaternion followed by three for the translation.
+
+ Returns:
+ A [*, 7] tensor representation of the transformation
+ """
+ tensor = self._trans.new_zeros((*self.shape, 7))
+ tensor[..., :4] = self._rots.get_quats()
+ tensor[..., 4:] = self._trans
+
+ return tensor
+
+ @staticmethod
+ def from_tensor_7(
+ t: torch.Tensor,
+ normalize_quats: bool = False,
+ ):
+ if(t.shape[-1] != 7):
+ raise ValueError("Incorrectly shaped input tensor")
+
+ quats, trans = t[..., :4], t[..., 4:]
+
+ rots = Rotation(
+ rot_mats=None,
+ quats=quats,
+ normalize_quats=normalize_quats
+ )
+
+ return Rigid(rots, trans)
+
+ @staticmethod
+ def from_3_points(
+ p_neg_x_axis: torch.Tensor,
+ origin: torch.Tensor,
+ p_xy_plane: torch.Tensor,
+ eps: float = 1e-8
+ ):
+ """
+ Implements algorithm 21. Constructs transformations from sets of 3
+ points using the Gram-Schmidt algorithm.
+
+ Args:
+ p_neg_x_axis: [*, 3] coordinates
+ origin: [*, 3] coordinates used as frame origins
+ p_xy_plane: [*, 3] coordinates
+ eps: Small epsilon value
+ Returns:
+ A transformation object of shape [*]
+ """
+ p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
+ origin = torch.unbind(origin, dim=-1)
+ p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
+
+ e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
+ e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
+
+ denom = torch.sqrt(sum((c * c for c in e0)) + eps)
+ e0 = [c / denom for c in e0]
+ dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
+ e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
+ denom = torch.sqrt(sum((c * c for c in e1)) + eps)
+ e1 = [c / denom for c in e1]
+ e2 = [
+ e0[1] * e1[2] - e0[2] * e1[1],
+ e0[2] * e1[0] - e0[0] * e1[2],
+ e0[0] * e1[1] - e0[1] * e1[0],
+ ]
+
+ rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
+ rots = rots.reshape(rots.shape[:-1] + (3, 3))
+
+ rot_obj = Rotation(rot_mats=rots, quats=None)
+
+ return Rigid(rot_obj, torch.stack(origin, dim=-1))
+
+ def unsqueeze(self,
+ dim: int,
+ ):
+ """
+ Analogous to torch.unsqueeze. The dimension is relative to the
+ shared dimensions of the rotation/translation.
+
+ Args:
+ dim: A positive or negative dimension index.
+ Returns:
+ The unsqueezed transformation.
+ """
+ if dim >= len(self.shape):
+ raise ValueError("Invalid dimension")
+ rots = self._rots.unsqueeze(dim)
+ trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
+
+ return Rigid(rots, trans)
+
+ @staticmethod
+ def cat(
+ ts,
+ dim: int,
+ ):
+ """
+ Concatenates transformations along a new dimension.
+
+ Args:
+ ts:
+ A list of T objects
+ dim:
+ The dimension along which the transformations should be
+ concatenated
+ Returns:
+ A concatenated transformation object
+ """
+ rots = Rotation.cat([t._rots for t in ts], dim)
+ trans = torch.cat(
+ [t._trans for t in ts], dim=dim if dim >= 0 else dim - 1
+ )
+
+ return Rigid(rots, trans)
+
+ def apply_rot_fn(self, fn):
+ """
+ Applies a Rotation -> Rotation function to the stored rotation
+ object.
+
+ Args:
+ fn: A function of type Rotation -> Rotation
+ Returns:
+ A transformation object with a transformed rotation.
+ """
+ return Rigid(fn(self._rots), self._trans)
+
+ def apply_trans_fn(self, fn):
+ """
+ Applies a Tensor -> Tensor function to the stored translation.
+
+ Args:
+ fn:
+ A function of type Tensor -> Tensor to be applied to the
+ translation
+ Returns:
+ A transformation object with a transformed translation.
+ """
+ return Rigid(self._rots, fn(self._trans))
+
+ def scale_translation(self, trans_scale_factor: float):
+ """
+ Scales the translation by a constant factor.
+
+ Args:
+ trans_scale_factor:
+ The constant factor
+ Returns:
+ A transformation object with a scaled translation.
+ """
+ fn = lambda t: t * trans_scale_factor
+ return self.apply_trans_fn(fn)
+
+ def stop_rot_gradient(self):
+ """
+ Detaches the underlying rotation object
+
+ Returns:
+ A transformation object with detached rotations
+ """
+ fn = lambda r: r.detach()
+ return self.apply_rot_fn(fn)
+
+ @staticmethod
+ def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
+ """
+ Returns a transformation object from reference coordinates.
+
+ Note that this method does not take care of symmetries. If you
+ provide the atom positions in the non-standard way, the N atom will
+ end up not at [-0.527250, 1.359329, 0.0] but instead at
+ [-0.527250, -1.359329, 0.0]. You need to take care of such cases in
+ your code.
+
+ Args:
+ n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
+ ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
+ c_xyz: A [*, 3] tensor of carbon xyz coordinates.
+ Returns:
+ A transformation object. After applying the translation and
+ rotation to the reference backbone, the coordinates will
+ approximately equal to the input coordinates.
+ """
+ translation = -1 * ca_xyz
+ n_xyz = n_xyz + translation
+ c_xyz = c_xyz + translation
+
+ c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
+ norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2)
+ sin_c1 = -c_y / norm
+ cos_c1 = c_x / norm
+ zeros = sin_c1.new_zeros(sin_c1.shape)
+ ones = sin_c1.new_ones(sin_c1.shape)
+
+ c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
+ c1_rots[..., 0, 0] = cos_c1
+ c1_rots[..., 0, 1] = -1 * sin_c1
+ c1_rots[..., 1, 0] = sin_c1
+ c1_rots[..., 1, 1] = cos_c1
+ c1_rots[..., 2, 2] = 1
+
+ norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2)
+ sin_c2 = c_z / norm
+ cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm
+
+ c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
+ c2_rots[..., 0, 0] = cos_c2
+ c2_rots[..., 0, 2] = sin_c2
+ c2_rots[..., 1, 1] = 1
+ c1_rots[..., 2, 0] = -1 * sin_c2
+ c1_rots[..., 2, 2] = cos_c2
+
+ c_rots = rot_matmul(c2_rots, c1_rots)
+ n_xyz = rot_vec_mul(c_rots, n_xyz)
+
+ _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
+ norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2)
+ sin_n = -n_z / norm
+ cos_n = n_y / norm
+
+ n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
+ n_rots[..., 0, 0] = 1
+ n_rots[..., 1, 1] = cos_n
+ n_rots[..., 1, 2] = -1 * sin_n
+ n_rots[..., 2, 1] = sin_n
+ n_rots[..., 2, 2] = cos_n
+
+ rots = rot_matmul(n_rots, c_rots)
+
+ rots = rots.transpose(-1, -2)
+ translation = -1 * translation
+
+ rot_obj = Rotation(rot_mats=rots, quats=None)
+
+ return Rigid(rot_obj, translation)
+
+ def cuda(self):
+ """
+ Moves the transformation object to GPU memory
+
+ Returns:
+ A version of the transformation on GPU
+ """
+ return Rigid(self._rots.cuda(), self._trans.cuda())
diff --git a/openfold/utils/seed.py b/openfold/utils/seed.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45b81379257909e06b00839bbb90848cf3d0b3f
--- /dev/null
+++ b/openfold/utils/seed.py
@@ -0,0 +1,19 @@
+import os
+import logging
+import random
+import numpy as np
+from pytorch_lightning.utilities.seed import seed_everything
+
+from openfold.utils.suppress_output import SuppressLogging
+
+
+def seed_globally(seed=None):
+ if("PL_GLOBAL_SEED" not in os.environ):
+ if(seed is None):
+ seed = random.randint(0, np.iinfo(np.uint32).max)
+ os.environ["PL_GLOBAL_SEED"] = str(seed)
+ logging.info(f'os.environ["PL_GLOBAL_SEED"] set to {seed}')
+
+ # seed_everything is a bit log-happy
+ with SuppressLogging(logging.INFO):
+ seed_everything(seed=None)
diff --git a/openfold/utils/superimposition.py b/openfold/utils/superimposition.py
new file mode 100644
index 0000000000000000000000000000000000000000..835498f014b8083b3f715bfb3b8dd426ce296f3e
--- /dev/null
+++ b/openfold/utils/superimposition.py
@@ -0,0 +1,107 @@
+# Copyright 2021 AlQuraishi Laboratory
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from Bio.SVDSuperimposer import SVDSuperimposer
+import numpy as np
+import torch
+
+
+def _superimpose_np(reference, coords):
+ """
+ Superimposes coordinates onto a reference by minimizing RMSD using SVD.
+
+ Args:
+ reference:
+ [N, 3] reference array
+ coords:
+ [N, 3] array
+ Returns:
+ A tuple of [N, 3] superimposed coords and the final RMSD.
+ """
+ sup = SVDSuperimposer()
+ sup.set(reference, coords)
+ sup.run()
+ return sup
+
+def _superimpose_single(reference, coords):
+ reference_np = reference.detach().cpu().numpy()
+ coords_np = coords.detach().cpu().numpy()
+ sup = _superimpose_np(reference_np, coords_np)
+ rot, tran = sup.get_rotran()
+ superimposed, rmsd = sup.get_transformed(), sup.get_rms()
+ return coords.new_tensor(superimposed), coords.new_tensor(rmsd), rot, tran
+
+
+def superimpose(reference, coords, mask, return_transform=False):
+ """
+ Superimposes coordinates onto a reference by minimizing RMSD using SVD.
+
+ Args:
+ reference:
+ [*, N, 3] reference tensor
+ coords:
+ [*, N, 3] tensor
+ mask:
+ [*, N] tensor
+ Returns:
+ A tuple of [*, N, 3] superimposed coords and [*] final RMSDs.
+ """
+ def select_unmasked_coords(coords, mask):
+ return torch.masked_select(
+ coords,
+ (mask > 0.)[..., None],
+ ).reshape(-1, 3)
+
+ batch_dims = reference.shape[:-2]
+ flat_reference = reference.reshape((-1,) + reference.shape[-2:])
+ flat_coords = coords.reshape((-1,) + reference.shape[-2:])
+ flat_mask = mask.reshape((-1,) + mask.shape[-1:])
+ superimposed_list = []
+ rmsds = []
+ rots = []
+ trans = []
+ for r, c, m in zip(flat_reference, flat_coords, flat_mask):
+ r_unmasked_coords = select_unmasked_coords(r, m)
+ c_unmasked_coords = select_unmasked_coords(c, m)
+ superimposed, rmsd, rot, tran = _superimpose_single(
+ r_unmasked_coords,
+ c_unmasked_coords
+ )
+ rots.append(rot)
+ trans.append(tran)
+ # This is very inelegant, but idk how else to invert the masking
+ # procedure.
+ count = 0
+ superimposed_full_size = torch.zeros_like(r)
+ for i, unmasked in enumerate(m):
+ if(unmasked):
+ superimposed_full_size[i] = superimposed[count]
+ count += 1
+
+ superimposed_list.append(superimposed_full_size)
+ rmsds.append(rmsd)
+
+ superimposed_stacked = torch.stack(superimposed_list, dim=0)
+ rmsds_stacked = torch.stack(rmsds, dim=0)
+ rots_stacked = torch.tensor(np.stack(rots, axis=0), device=coords.device)
+ trans_stacked = torch.tensor(np.stack(trans, axis=0), device=coords.device)
+
+ superimposed_reshaped = superimposed_stacked.reshape(
+ batch_dims + coords.shape[-2:]
+ )
+ rmsds_reshaped = rmsds_stacked.reshape(
+ batch_dims
+ )
+ if return_transform:
+ return superimposed_reshaped, rmsds_reshaped, rots_stacked, trans_stacked
+ return superimposed_reshaped, rmsds_reshaped
diff --git a/openfold/utils/suppress_output.py b/openfold/utils/suppress_output.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7c821b36a6b8d047e541a51faceef037f48ff90
--- /dev/null
+++ b/openfold/utils/suppress_output.py
@@ -0,0 +1,26 @@
+import logging
+import sys
+
+
+class SuppressStdout:
+ def __enter__(self):
+ self.stdout = sys.stdout
+ dev_null = open("/dev/null", "w")
+ sys.stdout = dev_null
+
+ def __exit__(self, typ, value, traceback):
+ fp = sys.stdout
+ sys.stdout = self.stdout
+ fp.close()
+
+
+class SuppressLogging:
+ def __init__(self, level):
+ self.level = level
+
+ def __enter__(self):
+ logging.disable(self.level)
+
+ def __exit__(self, typ, value, traceback):
+ logging.disable(logging.NOTSET)
+
diff --git a/openfold/utils/tensor_utils.py b/openfold/utils/tensor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e5e8e4b6b5e9a26ccf1c5cedf1ff10102e75b2f
--- /dev/null
+++ b/openfold/utils/tensor_utils.py
@@ -0,0 +1,408 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+import torch
+import torch.nn as nn
+from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
+
+
+def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
+ zero_index = -1 * len(inds)
+ first_inds = list(range(len(tensor.shape[:zero_index])))
+ return tensor.permute(first_inds + [zero_index + i for i in inds])
+
+
+def flatten_final_dims(t: torch.Tensor, no_dims: int):
+ return t.reshape(t.shape[:-no_dims] + (-1,))
+
+
+def masked_mean(mask, value, dim, eps=1e-4):
+ mask = mask.expand(*value.shape)
+ return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
+
+
+def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
+ boundaries = torch.linspace(
+ min_bin, max_bin, no_bins - 1, device=pts.device
+ )
+ dists = torch.sqrt(
+ torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
+ )
+ return torch.bucketize(dists, boundaries)
+
+
+def dict_multimap(fn, dicts):
+ first = dicts[0]
+ new_dict = {}
+ for k, v in first.items():
+ all_v = [d[k] for d in dicts]
+ if type(v) is dict:
+ new_dict[k] = dict_multimap(fn, all_v)
+ else:
+ new_dict[k] = fn(all_v)
+
+ return new_dict
+
+
+def one_hot(x, v_bins):
+ reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
+ diffs = x[..., None] - reshaped_bins
+ am = torch.argmin(torch.abs(diffs), dim=-1)
+ return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
+
+
+def batched_gather(data, inds, dim=0, no_batch_dims=0):
+ ranges = []
+ for i, s in enumerate(data.shape[:no_batch_dims]):
+ r = torch.arange(s)
+ r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
+ ranges.append(r)
+
+ remaining_dims = [
+ slice(None) for _ in range(len(data.shape) - no_batch_dims)
+ ]
+ remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
+ ranges.extend(remaining_dims)
+ return data[ranges]
+
+
+# With tree_map, a poor man's JAX tree_map
+def dict_map(fn, dic, leaf_type):
+ new_dict = {}
+ for k, v in dic.items():
+ if type(v) is dict:
+ new_dict[k] = dict_map(fn, v, leaf_type)
+ else:
+ new_dict[k] = tree_map(fn, v, leaf_type)
+
+ return new_dict
+
+
+def tree_map(fn, tree, leaf_type):
+ if isinstance(tree, dict):
+ return dict_map(fn, tree, leaf_type)
+ elif isinstance(tree, list):
+ return [tree_map(fn, x, leaf_type) for x in tree]
+ elif isinstance(tree, tuple):
+ return tuple([tree_map(fn, x, leaf_type) for x in tree])
+ elif isinstance(tree, leaf_type):
+ return fn(tree)
+ else:
+ print(type(tree))
+ raise ValueError("Not supported")
+
+
+tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
+
+def _fetch_dims(tree):
+ shapes = []
+ tree_type = type(tree)
+ if tree_type is dict:
+ for v in tree.values():
+ shapes.extend(_fetch_dims(v))
+ elif tree_type is list or tree_type is tuple:
+ for t in tree:
+ shapes.extend(_fetch_dims(t))
+ elif tree_type is torch.Tensor:
+ shapes.append(tree.shape)
+ else:
+ raise ValueError("Not supported")
+
+ return shapes
+
+
+@torch.jit.ignore
+def _flat_idx_to_idx(
+ flat_idx: int,
+ dims: Tuple[int],
+) -> Tuple[int]:
+ idx = []
+ for d in reversed(dims):
+ idx.append(flat_idx % d)
+ flat_idx = flat_idx // d
+
+ return tuple(reversed(idx))
+
+
+@torch.jit.ignore
+def _get_minimal_slice_set(
+ start: Sequence[int],
+ end: Sequence[int],
+ dims: int,
+ start_edges: Optional[Sequence[bool]] = None,
+ end_edges: Optional[Sequence[bool]] = None,
+) -> Sequence[Tuple[int]]:
+ """
+ Produces an ordered sequence of tensor slices that, when used in
+ sequence on a tensor with shape dims, yields tensors that contain every
+ leaf in the contiguous range [start, end]. Care is taken to yield a
+ short sequence of slices, and perhaps even the shortest possible (I'm
+ pretty sure it's the latter).
+
+ end is INCLUSIVE.
+ """
+ # start_edges and end_edges both indicate whether, starting from any given
+ # dimension, the start/end index is at the top/bottom edge of the
+ # corresponding tensor, modeled as a tree
+ def reduce_edge_list(l):
+ tally = 1
+ for i in range(len(l)):
+ reversed_idx = -1 * (i + 1)
+ l[reversed_idx] *= tally
+ tally = l[reversed_idx]
+
+ if(start_edges is None):
+ start_edges = [s == 0 for s in start]
+ reduce_edge_list(start_edges)
+ if(end_edges is None):
+ end_edges = [e == (d - 1) for e,d in zip(end, dims)]
+ reduce_edge_list(end_edges)
+
+ # Base cases. Either start/end are empty and we're done, or the final,
+ # one-dimensional tensor can be simply sliced
+ if(len(start) == 0):
+ return [tuple()]
+ elif(len(start) == 1):
+ return [(slice(start[0], end[0] + 1),)]
+
+ slices = []
+ path = []
+
+ # Dimensions common to start and end can be selected directly
+ for s,e in zip(start, end):
+ if(s == e):
+ path.append(slice(s, s + 1))
+ else:
+ break
+
+ path = tuple(path)
+ divergence_idx = len(path)
+
+ # start == end, and we're done
+ if(divergence_idx == len(dims)):
+ return [tuple(path)]
+
+ def upper():
+ sdi = start[divergence_idx]
+ return [
+ path + (slice(sdi, sdi + 1),) + s for s in
+ _get_minimal_slice_set(
+ start[divergence_idx + 1:],
+ [d - 1 for d in dims[divergence_idx + 1:]],
+ dims[divergence_idx + 1:],
+ start_edges=start_edges[divergence_idx + 1:],
+ end_edges=[1 for _ in end_edges[divergence_idx + 1:]]
+ )
+ ]
+
+ def lower():
+ edi = end[divergence_idx]
+ return [
+ path + (slice(edi, edi + 1),) + s for s in
+ _get_minimal_slice_set(
+ [0 for _ in start[divergence_idx + 1:]],
+ end[divergence_idx + 1:],
+ dims[divergence_idx + 1:],
+ start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
+ end_edges=end_edges[divergence_idx + 1:],
+ )
+ ]
+
+ # If both start and end are at the edges of the subtree rooted at
+ # divergence_idx, we can just select the whole subtree at once
+ if(start_edges[divergence_idx] and end_edges[divergence_idx]):
+ slices.append(
+ path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
+ )
+ # If just start is at the edge, we can grab almost all of the subtree,
+ # treating only the ragged bottom edge as an edge case
+ elif(start_edges[divergence_idx]):
+ slices.append(
+ path + (slice(start[divergence_idx], end[divergence_idx]),)
+ )
+ slices.extend(lower())
+ # Analogous to the previous case, but the top is ragged this time
+ elif(end_edges[divergence_idx]):
+ slices.extend(upper())
+ slices.append(
+ path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
+ )
+ # If both sides of the range are ragged, we need to handle both sides
+ # separately. If there's contiguous meat in between them, we can index it
+ # in one big chunk
+ else:
+ slices.extend(upper())
+ middle_ground = end[divergence_idx] - start[divergence_idx]
+ if(middle_ground > 1):
+ slices.append(
+ path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
+ )
+ slices.extend(lower())
+
+ return [tuple(s) for s in slices]
+
+
+@torch.jit.ignore
+def _chunk_slice(
+ t: torch.Tensor,
+ flat_start: int,
+ flat_end: int,
+ no_batch_dims: int,
+) -> torch.Tensor:
+ """
+ Equivalent to
+
+ t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
+
+ but without the need for the initial reshape call, which can be
+ memory-intensive in certain situations. The only reshape operations
+ in this function are performed on sub-tensors that scale with
+ (flat_end - flat_start), the chunk size.
+ """
+
+ batch_dims = t.shape[:no_batch_dims]
+ start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
+ # _get_minimal_slice_set is inclusive
+ end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
+
+ # Get an ordered list of slices to perform
+ slices = _get_minimal_slice_set(
+ start_idx,
+ end_idx,
+ batch_dims,
+ )
+
+ sliced_tensors = [t[s] for s in slices]
+
+ return torch.cat(
+ [s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
+ )
+
+
+def chunk_layer(
+ layer: Callable,
+ inputs: Dict[str, Any],
+ chunk_size: int,
+ no_batch_dims: int,
+ low_mem: bool = False,
+) -> Any:
+ """
+ Implements the "chunking" procedure described in section 1.11.8.
+
+ Layer outputs and inputs are assumed to be simple "pytrees,"
+ consisting only of (arbitrarily nested) lists, tuples, and dicts with
+ torch.Tensor leaves.
+
+ Args:
+ layer:
+ The layer to be applied chunk-wise
+ inputs:
+ A (non-nested) dictionary of keyworded inputs. All leaves must
+ be tensors and must share the same batch dimensions.
+ chunk_size:
+ The number of sub-batches per chunk. If multiple batch
+ dimensions are specified, a "sub-batch" is defined as a single
+ indexing of all batch dimensions simultaneously (s.t. the
+ number of sub-batches is the product of the batch dimensions).
+ no_batch_dims:
+ How many of the initial dimensions of each input tensor can
+ be considered batch dimensions.
+ low_mem:
+ Avoids flattening potentially large input tensors. Unnecessary
+ in most cases, and is ever so slightly slower than the default
+ setting.
+ Returns:
+ The reassembled output of the layer on the inputs.
+ """
+ if not (len(inputs) > 0):
+ raise ValueError("Must provide at least one input")
+
+ initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
+ orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
+
+ def _prep_inputs(t):
+ # TODO: make this more memory efficient. This sucks
+ if(not low_mem):
+ if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
+ t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
+ t = t.reshape(-1, *t.shape[no_batch_dims:])
+ else:
+ t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
+ return t
+
+ prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
+
+ flat_batch_dim = 1
+ for d in orig_batch_dims:
+ flat_batch_dim *= d
+
+ no_chunks = flat_batch_dim // chunk_size + (
+ flat_batch_dim % chunk_size != 0
+ )
+
+ i = 0
+ out = None
+ for _ in range(no_chunks):
+ # Chunk the input
+ if(not low_mem):
+ select_chunk = (
+ lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
+ )
+ else:
+ select_chunk = (
+ partial(
+ _chunk_slice,
+ flat_start=i,
+ flat_end=min(flat_batch_dim, i + chunk_size),
+ no_batch_dims=len(orig_batch_dims)
+ )
+ )
+
+ chunks = tensor_tree_map(select_chunk, prepped_inputs)
+
+ # Run the layer on the chunk
+ output_chunk = layer(**chunks)
+
+ # Allocate space for the output
+ if out is None:
+ allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
+ out = tensor_tree_map(allocate, output_chunk)
+
+ # Put the chunk in its pre-allocated space
+ out_type = type(output_chunk)
+ if out_type is dict:
+ def assign(d1, d2):
+ for k, v in d1.items():
+ if type(v) is dict:
+ assign(v, d2[k])
+ else:
+ v[i : i + chunk_size] = d2[k]
+
+ assign(out, output_chunk)
+ elif out_type is tuple:
+ for x1, x2 in zip(out, output_chunk):
+ x1[i : i + chunk_size] = x2
+ elif out_type is torch.Tensor:
+ out[i : i + chunk_size] = output_chunk
+ else:
+ raise ValueError("Not supported")
+
+ i += chunk_size
+
+ reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
+ out = tensor_tree_map(reshape, out)
+
+ return out
diff --git a/openfold/utils/validation_metrics.py b/openfold/utils/validation_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..81b8da9860452030b6e9e1c2f46b8e49fdc17b75
--- /dev/null
+++ b/openfold/utils/validation_metrics.py
@@ -0,0 +1,38 @@
+# Copyright 2021 AlQuraishi Laboratory
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+def gdt(p1, p2, mask, cutoffs):
+ n = torch.sum(mask, dim=-1)
+
+ p1 = p1.float()
+ p2 = p2.float()
+ distances = torch.sqrt(torch.sum((p1 - p2)**2, dim=-1))
+
+ scores = []
+ for c in cutoffs:
+ score = torch.sum((distances <= c) * mask, dim=-1) / n
+ scores.append(score)
+
+ return sum(scores) / len(scores)
+
+
+def gdt_ts(p1, p2, mask):
+ return gdt(p1, p2, mask, [1., 2., 4., 8.])
+
+
+def gdt_ha(p1, p2, mask):
+ return gdt(p1, p2, mask, [0.5, 1., 2., 4.])
+