提交LigUnity初始代码
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +2 -35
- .gitignore +165 -0
- HGNN/Attention.py +36 -0
- HGNN/PL_Aggregator.py +75 -0
- HGNN/PL_Encoder.py +51 -0
- HGNN/PP_Aggregator.py +43 -0
- HGNN/PP_Encoder.py +51 -0
- HGNN/align.py +198 -0
- HGNN/data/CoreSet.dat +286 -0
- HGNN/data/PDBbind_v2020/index/INDEX_general_PL_data.2020 +0 -0
- HGNN/data/PDBbind_v2020/index/INDEX_general_PL_name.2020 +0 -0
- HGNN/data/PDBbind_v2020/index/INDEX_refined_data.2020 +0 -0
- HGNN/data/PDBbind_v2020/index/INDEX_refined_name.2020 +0 -0
- HGNN/main.py +318 -0
- HGNN/read_fasta.py +112 -0
- HGNN/screen_dataset.py +420 -0
- HGNN/screening.py +165 -0
- HGNN/test_pocket.fasta +2 -0
- HGNN/util.py +96 -0
- License +159 -0
- README.md +206 -3
- active_learning_scripts/run_al.sh +22 -0
- active_learning_scripts/run_cycle_ensemble.py +334 -0
- active_learning_scripts/run_cycle_one_model.py +246 -0
- active_learning_scripts/run_model.sh +53 -0
- ensemble_result.py +173 -0
- py_scripts/__init__.py +0 -0
- py_scripts/write_case_study.py +227 -0
- test.sh +18 -0
- test_fewshot.sh +38 -0
- test_fewshot_demo.sh +43 -0
- test_zeroshot_demo.sh +20 -0
- train.sh +145 -0
- unimol/__init__.py +6 -0
- unimol/data/__init__.py +50 -0
- unimol/data/add_2d_conformer_dataset.py +46 -0
- unimol/data/affinity_dataset.py +527 -0
- unimol/data/atom_type_dataset.py +34 -0
- unimol/data/conformer_sample_dataset.py +315 -0
- unimol/data/coord_pad_dataset.py +82 -0
- unimol/data/cropping_dataset.py +269 -0
- unimol/data/data_utils.py +23 -0
- unimol/data/dictionary.py +157 -0
- unimol/data/distance_dataset.py +64 -0
- unimol/data/from_str_dataset.py +19 -0
- unimol/data/key_dataset.py +29 -0
- unimol/data/lmdb_dataset.py +49 -0
- unimol/data/mask_points_dataset.py +267 -0
- unimol/data/normalize_dataset.py +68 -0
- unimol/data/pair_dataset.py +144 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,2 @@
|
|
| 1 |
-
|
| 2 |
-
*
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
# Auto detect text files and perform LF normalization
|
| 2 |
+
* text=auto
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
env/
|
| 126 |
+
venv/
|
| 127 |
+
ENV/
|
| 128 |
+
env.bak/
|
| 129 |
+
venv.bak/
|
| 130 |
+
|
| 131 |
+
# Spyder project settings
|
| 132 |
+
.spyderproject
|
| 133 |
+
.spyproject
|
| 134 |
+
|
| 135 |
+
# Rope project settings
|
| 136 |
+
.ropeproject
|
| 137 |
+
|
| 138 |
+
# mkdocs documentation
|
| 139 |
+
/site
|
| 140 |
+
|
| 141 |
+
# mypy
|
| 142 |
+
.mypy_cache/
|
| 143 |
+
.dmypy.json
|
| 144 |
+
dmypy.json
|
| 145 |
+
|
| 146 |
+
# Pyre type checker
|
| 147 |
+
.pyre/
|
| 148 |
+
|
| 149 |
+
# pytype static type analyzer
|
| 150 |
+
.pytype/
|
| 151 |
+
|
| 152 |
+
# Cython debug symbols
|
| 153 |
+
cython_debug/
|
| 154 |
+
|
| 155 |
+
# PyCharm
|
| 156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 160 |
+
#.idea/
|
| 161 |
+
|
| 162 |
+
tmp/
|
| 163 |
+
**/*.ipynb
|
| 164 |
+
*.ipynb
|
| 165 |
+
results/
|
HGNN/Attention.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn import init
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Attention(nn.Module):
    """Score neighbor embeddings against a query representation.

    A small MLP over each [neighbor; query] pair produces one logit per
    neighbor; the logits are softmax-normalized over the neighbor axis,
    yielding an attention column vector of shape (num_neighs, 1).
    """

    def __init__(self, embedding_dims):
        super(Attention, self).__init__()
        self.embed_dim = embedding_dims
        # NOTE(review): `bilinear` and `softmax` are never used by forward();
        # they are kept so the parameter/state-dict layout stays unchanged.
        self.bilinear = nn.Bilinear(self.embed_dim, self.embed_dim, 1)
        self.att1 = nn.Linear(self.embed_dim * 2, self.embed_dim)
        self.att2 = nn.Linear(self.embed_dim, self.embed_dim)
        self.att3 = nn.Linear(self.embed_dim, 1)
        self.softmax = nn.Softmax(0)

    def forward(self, node1, u_rep, num_neighs):
        """Return softmax attention weights of shape (num_neighs, 1).

        node1:      (num_neighs, embed_dim) neighbor embeddings.
        u_rep:      query embedding, tiled across the neighbors.
        num_neighs: number of neighbor rows in `node1`.
        """
        # Pair every neighbor with a copy of the query.
        tiled_query = u_rep.repeat(num_neighs, 1)
        hidden = torch.cat((node1, tiled_query), 1)
        # Two relu+dropout layers, then a scalar logit per neighbor.
        for layer in (self.att1, self.att2):
            hidden = F.dropout(F.relu(layer(hidden)), training=self.training)
        logits = self.att3(hidden)
        # Normalize across the neighbor dimension.
        return F.softmax(logits, dim=0)
|
HGNN/PL_Aggregator.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.autograd import Variable
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import numpy as np
|
| 6 |
+
import random
|
| 7 |
+
from Attention import Attention
|
| 8 |
+
|
| 9 |
+
class PLAggregator(nn.Module):
    """Pocket-ligand aggregator.

    For each pocket node, the embeddings of its interacting ligands
    (each fused with the embedding of its activity label) are
    attention-weighted against the pocket embedding and averaged with it.
    """

    def __init__(self, v2e=None, r2e=None, u2e=None, embed_dim=128, cuda="cpu", uv=True):
        """
        Args:
            v2e: ligand embedding table (exposes a ``.weight`` tensor).
            r2e: activity-label (relation) embedding table.
            u2e: pocket embedding table.
            embed_dim: embedding dimensionality.
            cuda: device the output matrix is placed on.
            uv: unused flag kept for interface compatibility.
        """
        super(PLAggregator, self).__init__()
        self.uv = uv
        self.v2e = v2e
        self.r2e = r2e
        self.u2e = u2e
        self.device = cuda
        self.embed_dim = embed_dim
        self.w_r1 = nn.Linear(self.embed_dim * 2, self.embed_dim)
        self.w_r2 = nn.Linear(self.embed_dim, self.embed_dim)
        self.att = Attention(self.embed_dim)
        # FIX: assigning ``requires_grad`` on the module object is a plain
        # attribute write and does not freeze anything for autograd; freeze
        # the underlying weight tensor instead.
        if self.v2e is not None:
            self.v2e.weight.requires_grad_(False)
        if self.u2e is not None:
            self.u2e.weight.requires_grad_(False)

    def forward(self, nodes_u, input_hist):
        """Aggregate each pocket's ligand-interaction history.

        Args:
            nodes_u: pocket indices, one per row of ``input_hist``.
            input_hist: per-pocket list of (ligand_id, label_id) pairs.

        Returns:
            Tensor of shape (len(input_hist), embed_dim).
        """
        embed_matrix = torch.zeros(len(input_hist), self.embed_dim, dtype=torch.float).to(self.device)

        for i in range(len(input_hist)):
            history = [pair[0] for pair in input_hist[i]]
            label = [pair[1] for pair in input_hist[i]]
            num_history_item = len(history)

            if num_history_item > 0:
                e_uv = self.v2e.weight[history]
                uv_rep = self.u2e.weight[nodes_u[i]]

                # Fuse ligand and label embeddings through a 2-layer MLP.
                e_r = self.r2e.weight[label]
                x = torch.cat((e_uv, e_r), 1)
                x = F.relu(self.w_r1(x))
                o_history = F.relu(self.w_r2(x))

                att_w = self.att(o_history, uv_rep, num_history_item)
                # Attention-weighted sum over the history, then average
                # with the pocket's own embedding.
                att_history = torch.mm(o_history.t(), att_w).t()
                embed_matrix[i] = (att_history + uv_rep) / 2
            else:
                # No interaction history: fall back to the raw pocket embedding.
                embed_matrix[i] = self.u2e.weight[nodes_u[i]]

        return embed_matrix

    def forward_inference(self, pocket_embed, neighbor_list):
        """Inference-time variant over precomputed neighbor triples.

        Args:
            pocket_embed: (1, embed_dim) pocket representation.
            neighbor_list: triples whose element [1] is a ligand embedding
                and element [2] an activity-label index tensor.
        """
        neighbor_embed = torch.stack([x[1] for x in neighbor_list])
        rel_embed = self.r2e.weight[torch.stack([x[2] for x in neighbor_list])]
        x = torch.cat((neighbor_embed, rel_embed), 1)
        x = F.relu(self.w_r1(x))
        o_neighbor = F.relu(self.w_r2(x))

        att_w = self.att(o_neighbor, pocket_embed, len(neighbor_list))
        att_res = torch.mm(o_neighbor.t(), att_w).t()
        return (att_res + pocket_embed) / 2
|
HGNN/PL_Encoder.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn import init
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
class PLEncoder(nn.Module):
    """Pocket-ligand encoder.

    Builds, for each pocket node, a list of (ligand_idx, discretized
    similarity score) neighbors from a pocket-similarity graph, then
    delegates the actual feature aggregation to ``aggregator``.
    """

    def __init__(self, embed_dim, pocket_graph=None, aggregator=None,
                 idx2assayid=None, assayid_lst_train=None, mol_smi=None,
                 train_label_lst=None, cuda="cpu", uv=True):
        """
        Args:
            embed_dim: embedding dimensionality.
            pocket_graph: {assay_id: [(neighbor_assay_id, score), ...]}.
            aggregator: object providing ``forward(nodes, to_neighs)`` and
                ``forward_inference(pocket_embed, neighbor_list)``.
            idx2assayid: pocket node index -> assay id.
            assayid_lst_train: assay ids present in the training split.
            mol_smi: ligand SMILES indexed by ligand id.
            train_label_lst: assay dicts carrying "assay_id" and "ligands".
            cuda: device tag (stored, unused here).
            uv: unused flag kept for interface compatibility.
        """
        super(PLEncoder, self).__init__()

        # FIX: the original used mutable default arguments ({} / []),
        # which are shared across every instance; use None sentinels.
        idx2assayid = {} if idx2assayid is None else idx2assayid
        assayid_lst_train = [] if assayid_lst_train is None else assayid_lst_train
        mol_smi = {} if mol_smi is None else mol_smi
        train_label_lst = [] if train_label_lst is None else train_label_lst

        self.uv = uv
        self.pocket_graph = pocket_graph
        self.aggregator = aggregator
        self.embed_dim = embed_dim
        self.device = cuda
        smi2idx = {smi: idx for idx, smi in enumerate(mol_smi)}
        self.idx2assayid, self.assayid_lst_train, self.smi2idx, self.mol_smi, self.train_label_lst = idx2assayid, assayid_lst_train, smi2idx, mol_smi, train_label_lst
        self.assayid_set_train = set(assayid_lst_train)
        # assay_id -> assay record, for looking up each assay's top ligand.
        self.label_dicts = {x["assay_id"]: x for x in self.train_label_lst}
        self.linear1 = nn.Linear(2 * self.embed_dim, self.embed_dim)

    def forward(self, nodes_pocket, nodes_lig=None, max_sample=10):
        """Collect neighbor ligands for each pocket node and aggregate.

        ``max_sample`` is accepted for interface parity but unused here.
        Returns whatever ``self.aggregator.forward`` returns.
        """
        to_neighs = []
        if nodes_lig is None:
            # Placeholder that can never match a real SMILES string.
            lig_smi_lst = ["----"] * len(nodes_pocket)
        else:
            lig_smi_lst = [self.mol_smi[lig_id] for lig_id in nodes_lig]

        for node, smi in zip(nodes_pocket, lig_smi_lst):
            assayid = self.idx2assayid[node]
            neighbors = []
            for n_assayid, score in self.pocket_graph.get(assayid, []):
                # FIX: perform the skip checks BEFORE the label lookup, so a
                # neighbor that is absent from label_dicts (e.g. the node
                # itself, or a non-training assay) cannot raise a KeyError.
                if assayid == n_assayid:
                    continue
                if n_assayid not in self.assayid_set_train:
                    continue
                nbr_smi = self.label_dicts[n_assayid]["ligands"][0]["smi"]
                if smi == nbr_smi:
                    continue
                # Discretize the similarity score into an integer bucket.
                neighbors.append((self.smi2idx[nbr_smi], int((score - 0.5) * 10)))
            to_neighs.append(neighbors)

        neigh_feats = self.aggregator.forward(nodes_pocket, to_neighs)  # pocket-ligand network
        return neigh_feats

    def refine_pocket(self, pocket_embed, neighbor_list=None):
        """Refine a single pocket embedding from precomputed neighbors."""
        return self.aggregator.forward_inference(pocket_embed, neighbor_list)
|
HGNN/PP_Aggregator.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.autograd import Variable
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
from Attention import Attention
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PPAggregator(nn.Module):
    """Pocket-pocket (social) aggregator.

    For every node, averages the node's own embedding with an
    attention-weighted sum of its neighbors' embeddings.
    """

    def __init__(self, u2e=None, embed_dim=128, cuda="cpu"):
        super(PPAggregator, self).__init__()
        self.device = cuda
        self.u2e = u2e
        self.embed_dim = embed_dim
        self.att = Attention(self.embed_dim)

    def forward(self, nodes, to_neighs):
        """Return a (len(nodes), embed_dim) matrix of aggregated features.

        Nodes without neighbors simply keep their own embedding.
        """
        out = torch.zeros(len(nodes), self.embed_dim, dtype=torch.float).to(self.device)
        self_feats = self.u2e.weight[nodes]
        for row, adj in enumerate(to_neighs):
            if not adj:
                # Isolated node: its own embedding passes through unchanged.
                out[row] = self_feats[row]
                continue
            nbr_embed = self.u2e.weight[[pair[0] for pair in adj]]
            own = self.u2e.weight[nodes[row]]
            weights = self.att(nbr_embed, own, len(adj))
            pooled = torch.mm(nbr_embed.t(), weights).t()
            out[row] = (pooled + self_feats[row]) / 2
        return out

    def forward_inference(self, pocket_embed, neighbor_list):
        """Single-pocket inference variant over precomputed embeddings.

        ``neighbor_list`` items carry the neighbor embedding at index 0.
        """
        nbr_embed = torch.stack([item[0] for item in neighbor_list])
        weights = self.att(nbr_embed, pocket_embed, len(neighbor_list))
        pooled = torch.mm(nbr_embed.t(), weights).t()
        return (pooled + pocket_embed) / 2
|
HGNN/PP_Encoder.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn import init
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import random
|
| 6 |
+
import copy
|
| 7 |
+
|
| 8 |
+
class PPEncoder(nn.Module):
    """Pocket-pocket encoder.

    Mixes a pocket's own encoding (from ``pocket_encoder``) with an
    attention aggregate of similar training pockets drawn from a
    pocket-similarity graph.
    """

    def __init__(self, pocket_encoder, embed_dim, pocket_graph=None, aggregator=None,
                 assayid_lst_all=None, assayid_lst_train=None, base_model=None, cuda="cpu"):
        """
        Args:
            pocket_encoder: callable producing self-features for pocket nodes.
            embed_dim: embedding dimensionality.
            pocket_graph: {assay_id: [(neighbor_assay_id, score), ...]}.
            aggregator: neighbor aggregator (e.g. PPAggregator).
            assayid_lst_all: assay id of every pocket node, by node index.
            assayid_lst_train: assay ids present in the training split.
            base_model: optional underlying model, stored only if given.
            cuda: device tag (stored, unused here).
        """
        super(PPEncoder, self).__init__()

        # FIX: the original used mutable default list arguments, which are
        # shared across every instance; use None sentinels.
        assayid_lst_all = [] if assayid_lst_all is None else assayid_lst_all
        assayid_lst_train = [] if assayid_lst_train is None else assayid_lst_train

        self.pocket_encoder = pocket_encoder
        self.pocket_graph = pocket_graph
        self.aggregator = aggregator
        # FIX: compare against None with ``is not None`` rather than ``!=``.
        if base_model is not None:
            self.base_model = base_model
        self.embed_dim = embed_dim
        self.device = cuda
        self.linear1 = nn.Linear(2 * self.embed_dim, self.embed_dim)
        self.assayid_lst_all, self.assayid_set_train = assayid_lst_all, set(assayid_lst_train)
        # assay id -> every node index that carries that assay id.
        self.assayid2idxes = {}
        for idx, assayid in enumerate(assayid_lst_all):
            self.assayid2idxes.setdefault(assayid, []).append(idx)

    def forward(self, nodes_pocket, nodes_lig=None, max_sample=10):
        """Average each pocket's self-encoding with its neighborhood feature.

        ``nodes_lig`` and ``max_sample`` are forwarded to ``pocket_encoder``.
        """
        to_neighs = []
        for node in nodes_pocket:
            assayid = self.assayid_lst_all[node]
            neighbors = []
            for n_assayid, score in self.pocket_graph.get(assayid, []):
                if n_assayid == assayid:
                    continue
                if n_assayid not in self.assayid_set_train:
                    continue
                # Sample one concrete node index for the neighboring assay.
                neighbors.append((random.choices(self.assayid2idxes[n_assayid])[0], score))
            to_neighs.append(neighbors)

        neigh_feats = self.aggregator.forward(nodes_pocket, to_neighs)  # pocket-pocket network
        self_feats = self.pocket_encoder(nodes_pocket, nodes_lig, max_sample)

        return (self_feats + neigh_feats) / 2

    def refine_pocket(self, pocket_embed, neighbor_list=None):
        """Inference-time refinement from precomputed neighbor features."""
        neigh_feats = self.aggregator.forward_inference(pocket_embed, neighbor_list)
        self_feats = self.pocket_encoder.refine_pocket(pocket_embed, neighbor_list)
        return (self_feats + neigh_feats) / 2
|
HGNN/align.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import skbio
|
| 3 |
+
import json, pickle, os
|
| 4 |
+
from skbio import alignment
|
| 5 |
+
from skbio import Protein
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from multiprocessing import Pool
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
cutoff = 5.0
|
| 13 |
+
blosum50 = \
|
| 14 |
+
{
|
| 15 |
+
'*': {'*': 1, 'A': -5, 'C': -5, 'B': -5, 'E': -5, 'D': -5, 'G': -5,
|
| 16 |
+
'F': -5, 'I': -5, 'H': -5, 'K': -5, 'M': -5, 'L': -5,
|
| 17 |
+
'N': -5, 'Q': -5, 'P': -5, 'S': -5, 'R': -5, 'T': -5,
|
| 18 |
+
'W': -5, 'V': -5, 'Y': -5, 'X': -5, 'Z': -5},
|
| 19 |
+
'A': {'*': -5, 'A': 5, 'C': -1, 'B': -2, 'E': -1, 'D': -2, 'G': 0,
|
| 20 |
+
'F': -3, 'I': -1, 'H': -2, 'K': -1, 'M': -1, 'L': -2,
|
| 21 |
+
'N': -1, 'Q': -1, 'P': -1, 'S': 1, 'R': -2, 'T': 0, 'W': -3,
|
| 22 |
+
'V': 0, 'Y': -2, 'X': -1, 'Z': -1},
|
| 23 |
+
'C': {'*': -5, 'A': -1, 'C': 13, 'B': -3, 'E': -3, 'D': -4,
|
| 24 |
+
'G': -3, 'F': -2, 'I': -2, 'H': -3, 'K': -3, 'M': -2,
|
| 25 |
+
'L': -2, 'N': -2, 'Q': -3, 'P': -4, 'S': -1, 'R': -4,
|
| 26 |
+
'T': -1, 'W': -5, 'V': -1, 'Y': -3, 'X': -1, 'Z': -3},
|
| 27 |
+
'B': {'*': -5, 'A': -2, 'C': -3, 'B': 6, 'E': 1, 'D': 6, 'G': -1,
|
| 28 |
+
'F': -4, 'I': -4, 'H': 0, 'K': 0, 'M': -3, 'L': -4, 'N': 5,
|
| 29 |
+
'Q': 0, 'P': -2, 'S': 0, 'R': -1, 'T': 0, 'W': -5, 'V': -3,
|
| 30 |
+
'Y': -3, 'X': -1, 'Z': 1},
|
| 31 |
+
'E': {'*': -5, 'A': -1, 'C': -3, 'B': 1, 'E': 6, 'D': 2, 'G': -3,
|
| 32 |
+
'F': -3, 'I': -4, 'H': 0, 'K': 1, 'M': -2, 'L': -3, 'N': 0,
|
| 33 |
+
'Q': 2, 'P': -1, 'S': -1, 'R': 0, 'T': -1, 'W': -3, 'V': -3,
|
| 34 |
+
'Y': -2, 'X': -1, 'Z': 5},
|
| 35 |
+
'D': {'*': -5, 'A': -2, 'C': -4, 'B': 6, 'E': 2, 'D': 8, 'G': -1,
|
| 36 |
+
'F': -5, 'I': -4, 'H': -1, 'K': -1, 'M': -4, 'L': -4, 'N': 2,
|
| 37 |
+
'Q': 0, 'P': -1, 'S': 0, 'R': -2, 'T': -1, 'W': -5, 'V': -4,
|
| 38 |
+
'Y': -3, 'X': -1, 'Z': 1},
|
| 39 |
+
'G': {'*': -5, 'A': 0, 'C': -3, 'B': -1, 'E': -3, 'D': -1, 'G': 8,
|
| 40 |
+
'F': -4, 'I': -4, 'H': -2, 'K': -2, 'M': -3, 'L': -4, 'N': 0,
|
| 41 |
+
'Q': -2, 'P': -2, 'S': 0, 'R': -3, 'T': -2, 'W': -3, 'V': -4,
|
| 42 |
+
'Y': -3, 'X': -1, 'Z': -2},
|
| 43 |
+
'F': {'*': -5, 'A': -3, 'C': -2, 'B': -4, 'E': -3, 'D': -5,
|
| 44 |
+
'G': -4, 'F': 8, 'I': 0, 'H': -1, 'K': -4, 'M': 0, 'L': 1,
|
| 45 |
+
'N': -4, 'Q': -4, 'P': -4, 'S': -3, 'R': -3, 'T': -2, 'W': 1,
|
| 46 |
+
'V': -1, 'Y': 4, 'X': -1, 'Z': -4},
|
| 47 |
+
'I': {'*': -5, 'A': -1, 'C': -2, 'B': -4, 'E': -4, 'D': -4,
|
| 48 |
+
'G': -4, 'F': 0, 'I': 5, 'H': -4, 'K': -3, 'M': 2, 'L': 2,
|
| 49 |
+
'N': -3, 'Q': -3, 'P': -3, 'S': -3, 'R': -4, 'T': -1,
|
| 50 |
+
'W': -3, 'V': 4, 'Y': -1, 'X': -1, 'Z': -3},
|
| 51 |
+
'H': {'*': -5, 'A': -2, 'C': -3, 'B': 0, 'E': 0, 'D': -1, 'G': -2,
|
| 52 |
+
'F': -1, 'I': -4, 'H': 10, 'K': 0, 'M': -1, 'L': -3, 'N': 1,
|
| 53 |
+
'Q': 1, 'P': -2, 'S': -1, 'R': 0, 'T': -2, 'W': -3, 'V': -4,
|
| 54 |
+
'Y': 2, 'X': -1, 'Z': 0},
|
| 55 |
+
'K': {'*': -5, 'A': -1, 'C': -3, 'B': 0, 'E': 1, 'D': -1, 'G': -2,
|
| 56 |
+
'F': -4, 'I': -3, 'H': 0, 'K': 6, 'M': -2, 'L': -3, 'N': 0,
|
| 57 |
+
'Q': 2, 'P': -1, 'S': 0, 'R': 3, 'T': -1, 'W': -3, 'V': -3,
|
| 58 |
+
'Y': -2, 'X': -1, 'Z': 1},
|
| 59 |
+
'M': {'*': -5, 'A': -1, 'C': -2, 'B': -3, 'E': -2, 'D': -4,
|
| 60 |
+
'G': -3, 'F': 0, 'I': 2, 'H': -1, 'K': -2, 'M': 7, 'L': 3,
|
| 61 |
+
'N': -2, 'Q': 0, 'P': -3, 'S': -2, 'R': -2, 'T': -1, 'W': -1,
|
| 62 |
+
'V': 1, 'Y': 0, 'X': -1, 'Z': -1},
|
| 63 |
+
'L': {'*': -5, 'A': -2, 'C': -2, 'B': -4, 'E': -3, 'D': -4,
|
| 64 |
+
'G': -4, 'F': 1, 'I': 2, 'H': -3, 'K': -3, 'M': 3, 'L': 5,
|
| 65 |
+
'N': -4, 'Q': -2, 'P': -4, 'S': -3, 'R': -3, 'T': -1,
|
| 66 |
+
'W': -2, 'V': 1, 'Y': -1, 'X': -1, 'Z': -3},
|
| 67 |
+
'N': {'*': -5, 'A': -1, 'C': -2, 'B': 5, 'E': 0, 'D': 2, 'G': 0,
|
| 68 |
+
'F': -4, 'I': -3, 'H': 1, 'K': 0, 'M': -2, 'L': -4, 'N': 7,
|
| 69 |
+
'Q': 0, 'P': -2, 'S': 1, 'R': -1, 'T': 0, 'W': -4, 'V': -3,
|
| 70 |
+
'Y': -2, 'X': -1, 'Z': 0},
|
| 71 |
+
'Q': {'*': -5, 'A': -1, 'C': -3, 'B': 0, 'E': 2, 'D': 0, 'G': -2,
|
| 72 |
+
'F': -4, 'I': -3, 'H': 1, 'K': 2, 'M': 0, 'L': -2, 'N': 0,
|
| 73 |
+
'Q': 7, 'P': -1, 'S': 0, 'R': 1, 'T': -1, 'W': -1, 'V': -3,
|
| 74 |
+
'Y': -1, 'X': -1, 'Z': 4},
|
| 75 |
+
'P': {'*': -5, 'A': -1, 'C': -4, 'B': -2, 'E': -1, 'D': -1,
|
| 76 |
+
'G': -2, 'F': -4, 'I': -3, 'H': -2, 'K': -1, 'M': -3,
|
| 77 |
+
'L': -4, 'N': -2, 'Q': -1, 'P': 10, 'S': -1, 'R': -3,
|
| 78 |
+
'T': -1, 'W': -4, 'V': -3, 'Y': -3, 'X': -1, 'Z': -1},
|
| 79 |
+
'S': {'*': -5, 'A': 1, 'C': -1, 'B': 0, 'E': -1, 'D': 0, 'G': 0,
|
| 80 |
+
'F': -3, 'I': -3, 'H': -1, 'K': 0, 'M': -2, 'L': -3, 'N': 1,
|
| 81 |
+
'Q': 0, 'P': -1, 'S': 5, 'R': -1, 'T': 2, 'W': -4, 'V': -2,
|
| 82 |
+
'Y': -2, 'X': -1, 'Z': 0},
|
| 83 |
+
'R': {'*': -5, 'A': -2, 'C': -4, 'B': -1, 'E': 0, 'D': -2, 'G': -3,
|
| 84 |
+
'F': -3, 'I': -4, 'H': 0, 'K': 3, 'M': -2, 'L': -3, 'N': -1,
|
| 85 |
+
'Q': 1, 'P': -3, 'S': -1, 'R': 7, 'T': -1, 'W': -3, 'V': -3,
|
| 86 |
+
'Y': -1, 'X': -1, 'Z': 0},
|
| 87 |
+
'T': {'*': -5, 'A': 0, 'C': -1, 'B': 0, 'E': -1, 'D': -1, 'G': -2,
|
| 88 |
+
'F': -2, 'I': -1, 'H': -2, 'K': -1, 'M': -1, 'L': -1, 'N': 0,
|
| 89 |
+
'Q': -1, 'P': -1, 'S': 2, 'R': -1, 'T': 5, 'W': -3, 'V': 0,
|
| 90 |
+
'Y': -2, 'X': -1, 'Z': -1},
|
| 91 |
+
'W': {'*': -5, 'A': -3, 'C': -5, 'B': -5, 'E': -3, 'D': -5,
|
| 92 |
+
'G': -3, 'F': 1, 'I': -3, 'H': -3, 'K': -3, 'M': -1, 'L': -2,
|
| 93 |
+
'N': -4, 'Q': -1, 'P': -4, 'S': -4, 'R': -3, 'T': -3,
|
| 94 |
+
'W': 15, 'V': -3, 'Y': 2, 'X': -1, 'Z': -2},
|
| 95 |
+
'V': {'*': -5, 'A': 0, 'C': -1, 'B': -3, 'E': -3, 'D': -4, 'G': -4,
|
| 96 |
+
'F': -1, 'I': 4, 'H': -4, 'K': -3, 'M': 1, 'L': 1, 'N': -3,
|
| 97 |
+
'Q': -3, 'P': -3, 'S': -2, 'R': -3, 'T': 0, 'W': -3, 'V': 5,
|
| 98 |
+
'Y': -1, 'X': -1, 'Z': -3},
|
| 99 |
+
'Y': {'*': -5, 'A': -2, 'C': -3, 'B': -3, 'E': -2, 'D': -3,
|
| 100 |
+
'G': -3, 'F': 4, 'I': -1, 'H': 2, 'K': -2, 'M': 0, 'L': -1,
|
| 101 |
+
'N': -2, 'Q': -1, 'P': -3, 'S': -2, 'R': -1, 'T': -2, 'W': 2,
|
| 102 |
+
'V': -1, 'Y': 8, 'X': -1, 'Z': -2},
|
| 103 |
+
'X': {'*': -5, 'A': -1, 'C': -1, 'B': -1, 'E': -1, 'D': -1,
|
| 104 |
+
'G': -1, 'F': -1, 'I': -1, 'H': -1, 'K': -1, 'M': -1,
|
| 105 |
+
'L': -1, 'N': -1, 'Q': -1, 'P': -1, 'S': -1, 'R': -1,
|
| 106 |
+
'T': -1, 'W': -1, 'V': -1, 'Y': -1, 'X': -1, 'Z': -1},
|
| 107 |
+
'Z': {'*': -5, 'A': -1, 'C': -3, 'B': 1, 'E': 5, 'D': 1, 'G': -2,
|
| 108 |
+
'F': -4, 'I': -3, 'H': 0, 'K': 1, 'M': -1, 'L': -3, 'N': 0,
|
| 109 |
+
'Q': 4, 'P': -1, 'S': 0, 'R': 0, 'T': -1, 'W': -2, 'V': -3,
|
| 110 |
+
'Y': -2, 'X': -1, 'Z': 5}}
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
import math
|
| 114 |
+
def get_align_score(fasta_1, fasta_2):
    """Return the Smith-Waterman local-alignment score of two sequences.

    Aligns ``fasta_2`` against ``fasta_1`` with scikit-bio's
    StripedSmithWaterman using the module-level BLOSUM50 matrix, and
    returns the optimal alignment score as a float.
    """
    query = alignment.StripedSmithWaterman(
        fasta_1,
        suppress_sequences=False,
        zero_index=True,
        protein=True,
        substitution_matrix=blosum50,
    )
    result = query(fasta_2)
    return float(result.optimal_alignment_score)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def read_data(data_root, result_root):
    """Load FASTA dicts and precomputed features for BindingDB and PDBBind.

    Returns a 6-tuple:
        (bdb_fastas, bdb_pocket_feats, bdb_mol_feats,
         pdbbind_fastas, pdbbind_pocket_feats, pdbbind_mol_feats)
    where each *_feats entry maps a pocket name / ligand SMILES to the
    corresponding feature row loaded from the .npy arrays.
    """
    def _feature_dicts(save_dir, mol_file, pocket_file, names_file, smis_file):
        # Pair each saved name/SMILES with its feature row by position.
        mol_feat = np.load(f"{save_dir}/{mol_file}")
        pocket_feat = np.load(f"{save_dir}/{pocket_file}")
        names = json.load(open(f"{save_dir}/{names_file}"))
        smis = json.load(open(f"{save_dir}/{smis_file}"))
        pocket_dict = {names[i]: pocket_feat[i] for i in range(len(names))}
        mol_dict = {smis[i]: mol_feat[i] for i in range(len(smis))}
        return pocket_dict, mol_dict

    training_data_fastas = json.load(open(f"{data_root}/align_fastas_dict_10.23.json"))
    bdb_fastas_dict = training_data_fastas['bdb_fastas']
    pdbbind_fastas_dict = training_data_fastas['pdb_fastas']

    bdb_pocket_feat_dict, bdb_mol_feat_dict = _feature_dicts(
        f"{result_root}/BDB",
        "bdb_mol_reps.npy", "bdb_pocket_reps.npy",
        "bdb_pocket_names.json", "bdb_mol_smis.json")

    pdbbind_pocket_feat_dict, pdbbind_mol_feat_dict = _feature_dicts(
        f"{result_root}/PDBBind",
        "train_mol_reps.npy", "train_pocket_reps.npy",
        "train_pdbbind_ids.json", "train_mol_smis.json")

    return bdb_fastas_dict, bdb_pocket_feat_dict, bdb_mol_feat_dict, pdbbind_fastas_dict, pdbbind_pocket_feat_dict, pdbbind_mol_feat_dict
|
| 148 |
+
|
| 149 |
+
def get_neighbor_pocket(test_fasta, data_root, result_root, device):
    """Collect training pockets whose sequence is similar to ``test_fasta``.

    Aligns ``test_fasta`` against every training pocket sequence (PDBBind and
    BindingDB), keeps pairs whose normalized alignment score is >= 0.5, and
    returns each neighbor's pocket embedding, its most-active ligand's
    embedding, and a bucketed similarity level.

    Args:
        test_fasta: query pocket sequence.
        data_root: directory with the training label/fasta json files.
        result_root: directory with the precomputed feature files.
        device: torch device the returned tensors are moved to.

    Returns:
        List of ``(pocket_feat, mol_feat, level)`` tensor triples, where
        ``level = int((score - 0.5) * 10)`` discretizes the similarity.
    """
    # 1. Read data file
    print("reading datas")
    bdb_fastas_dict, bdb_pocket_feat_dict, bdb_mol_feat_dict, pdbbind_fastas_dict, \
        pdbbind_pocket_feat_dict, pdbbind_mol_feat_dict = read_data(data_root, result_root)

    training_assay = json.load(open(f"{data_root}/train_label_blend_seq_full.json"))
    training_assay += json.load(open(f"{data_root}/train_label_pdbbind_seq.json"))
    assay_dict = {}
    for assay in training_assay:
        # Most-active ligand first, so ligands[0] is the strongest binder.
        assay["ligands"] = sorted(assay["ligands"], key=lambda x: x["act"], reverse=True)
        if "assay_id" in assay:
            assay_dict[assay["assay_id"]] = assay
        else:
            assay_dict[assay["pockets"][0][:4]] = assay

    skip = 0
    # Hoist the loop-invariant self-alignment score; the original recomputed
    # get_align_score(test_fasta, test_fasta) for every training sequence.
    self_score = get_align_score(test_fasta, test_fasta)

    # 2. run alignment
    print("running alignment pdbbind")
    align_res_list = []
    for a_name, fasta in tqdm(pdbbind_fastas_dict.items()):
        # Guard assay_dict membership too: the original indexed assay_dict
        # unconditionally and could raise KeyError.
        if a_name not in pdbbind_pocket_feat_dict or a_name not in assay_dict:
            skip += 1
            continue
        p_name = a_name
        l_smi = assay_dict[a_name]["ligands"][0]["smi"]
        if l_smi not in pdbbind_mol_feat_dict:
            skip += 1
            continue
        align_score = get_align_score(test_fasta, fasta) / self_score
        if align_score >= 0.5:
            align_res_list.append((pdbbind_pocket_feat_dict[p_name], pdbbind_mol_feat_dict[l_smi], align_score, a_name))

    print("running alignment bindingdb")
    for a_name, fasta in tqdm(bdb_fastas_dict.items()):
        if a_name not in assay_dict:
            skip += 1
            continue
        p_name = assay_dict[a_name]["pockets"][0]
        l_smi = assay_dict[a_name]["ligands"][0]["smi"]
        # Also guard the pocket-feature lookup used in the append below.
        if l_smi not in bdb_mol_feat_dict or p_name not in bdb_pocket_feat_dict:
            continue
        align_score = get_align_score(test_fasta, fasta) / self_score
        if align_score >= 0.5:
            align_res_list.append((bdb_pocket_feat_dict[p_name], bdb_mol_feat_dict[l_smi], align_score, a_name))

    # Convert to tensors; the assay name is dropped and the score is bucketed
    # into integer similarity levels (0.5 -> 0, 0.6 -> 1, ..., 1.0 -> 5).
    for i, res in enumerate(align_res_list):
        align_res_list[i] = (torch.tensor(res[0]).float().to(device),
                             torch.tensor(res[1]).float().to(device),
                             torch.tensor(int((res[2] - 0.5) * 10)).to(device))

    return align_res_list
|
| 198 |
+
|
HGNN/data/CoreSet.dat
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#code resl year logKa Ka target
|
| 2 |
+
4llx 1.75 2014 2.89 Ki=1300uM 1
|
| 3 |
+
5c28 1.56 2015 5.66 Ki=2.2uM 1
|
| 4 |
+
3uuo 2.11 2012 7.96 Ki=11nM 1
|
| 5 |
+
3ui7 2.28 2011 9.00 Ki=1nM 1
|
| 6 |
+
5c2h 2.09 2015 11.09 Ki=8.2pM 1
|
| 7 |
+
2v00 1.55 2007 3.66 Kd=0.22mM 2
|
| 8 |
+
3wz8 1.45 2015 5.82 Ki=1.5uM 2
|
| 9 |
+
3pww 1.22 2011 7.32 Ki=48nM 2
|
| 10 |
+
3prs 1.38 2011 7.82 Ki=15nM 2
|
| 11 |
+
3uri 2.10 2012 9.00 Ki=1nM 2
|
| 12 |
+
4m0z 2.00 2014 5.19 Kd=6.4uM 3
|
| 13 |
+
4m0y 1.70 2014 6.46 Kd=0.35uM 3
|
| 14 |
+
3qgy 2.10 2011 7.80 Ki=16nM 3
|
| 15 |
+
4qd6 2.45 2015 8.64 Ki=2.3nM 3
|
| 16 |
+
4rfm 2.10 2015 10.05 Ki=90pM 3
|
| 17 |
+
4cr9 1.70 2015 4.10 Ki=80uM 4
|
| 18 |
+
4cra 1.80 2015 7.22 Ki=0.06uM 4
|
| 19 |
+
4x6p 1.93 2015 8.30 Ki=5nM 4
|
| 20 |
+
4crc 1.60 2015 8.72 Ki=0.0019uM 4
|
| 21 |
+
4ty7 2.09 2014 9.52 Ki=0.3nM 4
|
| 22 |
+
5aba 1.62 2015 2.98 Kd=1040uM 5
|
| 23 |
+
5a7b 1.40 2015 3.57 Kd=271uM 5
|
| 24 |
+
4agn 1.60 2012 3.97 Kd=107uM 5
|
| 25 |
+
4agp 1.50 2012 4.69 Kd=20.6uM 5
|
| 26 |
+
4agq 1.42 2012 5.01 Kd=9.7uM 5
|
| 27 |
+
3bgz 2.40 2007 6.26 Ki=0.55uM 6
|
| 28 |
+
3jya 2.10 2009 6.89 Ki=0.1301uM 6
|
| 29 |
+
2c3i 1.90 2005 7.60 Kd=25nM 6
|
| 30 |
+
4k18 2.05 2013 8.96 Ki=1.1nM 6
|
| 31 |
+
5dwr 2.00 2015 11.22 Ki=6pM 6
|
| 32 |
+
3mss 1.95 2010 4.66 Kd=22uM 7
|
| 33 |
+
3k5v 1.74 2010 6.30 Kd=0.5uM 7
|
| 34 |
+
3pyy 1.85 2011 6.86 Kd=137nM 7
|
| 35 |
+
2v7a 2.50 2007 8.30 Kd=0.005uM 7
|
| 36 |
+
4twp 2.40 2015 10.00 Ki=100pM 7
|
| 37 |
+
3wtj 2.24 2015 6.53 Kd=0.297uM 8
|
| 38 |
+
3zdg 2.48 2013 7.10 Ki=79nM 8
|
| 39 |
+
3u8k 2.47 2011 8.66 Ki=2.2nM 8
|
| 40 |
+
4qac 2.10 2014 9.40 Kd=0.4nM 8
|
| 41 |
+
3u8n 2.35 2011 10.17 Ki=0.067nM 8
|
| 42 |
+
1a30 2.00 1998 4.30 Ki=50uM 9
|
| 43 |
+
2qnq 2.30 2008 6.11 Ki=0.77uM 9
|
| 44 |
+
1g2k 1.95 2001 7.96 Ki=11nM 9
|
| 45 |
+
1eby 2.29 2002 9.70 Ki=0.20nM 9
|
| 46 |
+
3o9i 1.45 2011 11.82 Ki=1.5pM 9
|
| 47 |
+
4lzs 2.20 2014 4.80 Kd=16uM 10
|
| 48 |
+
3u5j 1.60 2011 5.61 Kd=2.46uM 10
|
| 49 |
+
4wiv 1.56 2014 6.26 Kd=550nM 10
|
| 50 |
+
4ogj 1.65 2014 6.79 Kd=164nM 10
|
| 51 |
+
3p5o 1.60 2010 7.30 Kd=50.5nM 10
|
| 52 |
+
1ps3 1.80 2003 2.28 Ki=5.2mM 11
|
| 53 |
+
3dx1 1.21 2009 3.58 Ki=265uM 11
|
| 54 |
+
3d4z 1.39 2008 4.89 Ki=13uM 11
|
| 55 |
+
3dx2 1.40 2009 6.82 Ki=150nM 11
|
| 56 |
+
3ejr 1.27 2009 8.57 Ki=2.7nM 11
|
| 57 |
+
3l7b 2.00 2010 2.40 Ki=4.01mM 12
|
| 58 |
+
4eky 2.45 2012 3.52 Ki=303.0uM 12
|
| 59 |
+
3g2n 2.10 2010 4.09 Ki=81uM 12
|
| 60 |
+
3syr 2.40 2012 5.10 Ki=7.9uM 12
|
| 61 |
+
3ebp 2.00 2009 5.91 Ki=1.24uM 12
|
| 62 |
+
2w66 2.27 2009 4.05 Ki=89uM 13
|
| 63 |
+
2w4x 2.42 2009 4.85 Kd=14uM 13
|
| 64 |
+
2wca 2.30 2009 5.60 Ki=2.5uM 13
|
| 65 |
+
2xj7 2.00 2010 6.66 Ki=220nM 13
|
| 66 |
+
2vvn 1.85 2008 7.30 Kd=50nM 13
|
| 67 |
+
3aru 1.90 2011 3.22 Kd=600uM 14
|
| 68 |
+
3arv 1.50 2011 5.64 Kd=2.3uM 14
|
| 69 |
+
3ary 1.35 2011 6.00 Kd=1.0uM 14
|
| 70 |
+
3arq 1.50 2011 6.40 Kd=0.4uM 14
|
| 71 |
+
3arp 1.55 2011 7.15 Kd=0.07uM 14
|
| 72 |
+
4ih5 1.90 2013 4.11 Kd=78uM 15
|
| 73 |
+
4ih7 2.30 2013 5.24 Kd=5.8uM 15
|
| 74 |
+
3cj4 2.07 2008 6.51 Kd=0.31uM 15
|
| 75 |
+
4eo8 1.80 2012 8.15 Kd=7nM 15
|
| 76 |
+
3gnw 2.39 2009 9.10 Kd=0.79nM 15
|
| 77 |
+
1gpk 2.10 2002 5.37 Ki=4.3uM 16
|
| 78 |
+
1gpn 2.35 2002 6.48 Ki=0.334uM 16
|
| 79 |
+
1h23 2.15 2002 8.35 Ki=4.5nM 16
|
| 80 |
+
1h22 2.15 2002 9.10 Ki=0.8nM 16
|
| 81 |
+
1e66 2.10 2001 9.89 Ki=0.13nM 16
|
| 82 |
+
3f3a 2.00 2008 4.19 Ki=64.8uM 17
|
| 83 |
+
3f3c 2.10 2008 6.02 Ki=950nM 17
|
| 84 |
+
4mme 2.50 2013 6.50 Kd=318nM 17
|
| 85 |
+
3f3d 2.30 2008 7.16 Kd=69nM 17
|
| 86 |
+
3f3e 1.80 2008 7.70 Kd=20nM 17
|
| 87 |
+
2wbg 1.85 2009 4.45 Ki=35.2uM 18
|
| 88 |
+
2cbv 1.95 2006 5.48 Kd=3.3uM 18
|
| 89 |
+
2j78 1.65 2006 6.42 Kd=384nM 18
|
| 90 |
+
2j7h 1.95 2006 7.19 Kd=65nM 18
|
| 91 |
+
2cet 1.97 2006 8.02 Kd=9.6nM 18
|
| 92 |
+
3udh 1.70 2012 2.85 Kd=1.4mM 19
|
| 93 |
+
3rsx 2.48 2011 4.41 Kd=38.8uM 19
|
| 94 |
+
4djv 1.73 2012 6.72 Ki=0.19uM 19
|
| 95 |
+
2vkm 2.05 2008 8.74 Ki=1.8nM 19
|
| 96 |
+
4gid 2.00 2012 10.77 Ki=0.017nM 19
|
| 97 |
+
4jfs 2.00 2013 5.27 Ki=5.4uM 20
|
| 98 |
+
4j28 1.73 2013 5.70 Ki=2.0uM 20
|
| 99 |
+
2wvt 1.80 2010 6.12 Kd=755nM 20
|
| 100 |
+
2xii 1.80 2010 7.20 Kd=63.3nM 20
|
| 101 |
+
4pcs 1.77 2014 7.85 Ki=14nM 20
|
| 102 |
+
3rr4 1.68 2012 4.55 Ki=28.05uM 21
|
| 103 |
+
1s38 1.81 2004 5.15 Ki=7.0uM 21
|
| 104 |
+
1r5y 1.20 2004 6.46 Ki=0.35uM 21
|
| 105 |
+
3gc5 1.40 2009 7.26 Ki=55nM 21
|
| 106 |
+
3ge7 1.50 2009 8.70 Ki=2nM 21
|
| 107 |
+
4dli 1.91 2013 5.62 Kd=2.40uM 22
|
| 108 |
+
2zb1 2.50 2008 6.32 Kd=0.48uM 22
|
| 109 |
+
4f9w 2.00 2013 6.94 Ki=114nM 22
|
| 110 |
+
3e92 2.00 2008 8.00 Ki=10nM 22
|
| 111 |
+
3e93 2.00 2008 8.85 Ki=1.4nM 22
|
| 112 |
+
4owm 1.99 2014 2.96 Ki=1090uM 23
|
| 113 |
+
3twp 1.83 2012 3.92 Ki=119uM 23
|
| 114 |
+
3r88 1.73 2012 4.82 Ki=15uM 23
|
| 115 |
+
4gkm 1.67 2013 5.17 Ki=6.8uM 23
|
| 116 |
+
3qqs 1.97 2012 5.82 Ki=1.5uM 23
|
| 117 |
+
3gv9 1.80 2009 2.12 Ki=7.5mM 24
|
| 118 |
+
3gr2 1.80 2009 2.52 Ki=3mM 24
|
| 119 |
+
4kz6 1.68 2014 3.10 Ki=0.8mM 24
|
| 120 |
+
4jxs 1.90 2014 4.74 Ki=18uM 24
|
| 121 |
+
2r9w 1.80 2008 5.10 Ki=8uM 24
|
| 122 |
+
2hb1 2.00 2006 3.80 Ki=160uM 25
|
| 123 |
+
1bzc 2.35 1999 4.92 Ki=12uM 25
|
| 124 |
+
2qbr 2.30 2008 6.33 Ki=0.47uM 25
|
| 125 |
+
2qbq 2.10 2008 7.44 Ki=0.036uM 25
|
| 126 |
+
2qbp 2.50 2008 8.40 Ki=0.004uM 25
|
| 127 |
+
1q8t 2.00 2003 4.76 Kd=17.5uM 26
|
| 128 |
+
1ydr 2.20 1997 5.52 Ki=3.0uM 26
|
| 129 |
+
1q8u 1.90 2003 5.96 Kd=1.1uM 26
|
| 130 |
+
1ydt 2.30 1997 7.32 Ki=48nM 26
|
| 131 |
+
3ag9 2.00 2010 8.05 Ki=9nM 26
|
| 132 |
+
3fcq 1.75 2009 2.77 Ki=1.7mM 27
|
| 133 |
+
1z9g 1.70 2005 5.64 Ki=2.3uM 27
|
| 134 |
+
1qf1 2.00 1999 7.32 Ki=48nM 27
|
| 135 |
+
5tmn 1.60 1989 8.04 Ki=9.1nM 27
|
| 136 |
+
4tmn 1.70 1989 10.17 Ki=0.068nM 27
|
| 137 |
+
4ddk 1.75 2013 2.29 Kd=5.13mM 28
|
| 138 |
+
4ddh 2.07 2013 3.32 Kd=0.48mM 28
|
| 139 |
+
3ivg 1.95 2009 4.30 Kd=50uM 28
|
| 140 |
+
3coz 2.00 2008 5.57 Kd=2.7uM 28
|
| 141 |
+
3coy 2.03 2008 6.02 Kd=0.96uM 28
|
| 142 |
+
3pxf 1.80 2011 4.43 Kd=37uM 29
|
| 143 |
+
4eor 2.20 2013 6.30 Ki=500nM 29
|
| 144 |
+
2xnb 1.85 2010 6.83 Ki=149nM 29
|
| 145 |
+
1pxn 2.50 2004 7.15 Ki=0.07uM 29
|
| 146 |
+
2fvd 1.85 2006 8.52 Ki=3nM 29
|
| 147 |
+
4k77 2.40 2013 6.63 Ki=235nM 30
|
| 148 |
+
4e5w 1.86 2012 7.66 Ki=22nM 30
|
| 149 |
+
4ivb 1.90 2013 8.72 Ki=1.9nM 30
|
| 150 |
+
4ivd 1.93 2013 9.52 Ki=0.3nM 30
|
| 151 |
+
4ivc 2.35 2013 10.00 Ki=0.1nM 30
|
| 152 |
+
4f09 2.40 2012 6.70 Ki=200nM 31
|
| 153 |
+
4gfm 2.30 2013 7.22 Ki=0.06uM 31
|
| 154 |
+
4hge 2.30 2012 7.92 Ki=11.9nM 31
|
| 155 |
+
4e6q 1.95 2012 8.36 Ki=4.4nM 31
|
| 156 |
+
4jia 1.85 2013 9.22 Ki=0.6nM 31
|
| 157 |
+
2brb 2.10 2005 4.86 Ki=13.7uM 32
|
| 158 |
+
2br1 2.00 2005 5.14 Ki=7.2uM 32
|
| 159 |
+
3jvr 1.76 2009 5.72 Ki=1.89uM 32
|
| 160 |
+
3jvs 1.90 2009 6.54 Kd=0.29uM 32
|
| 161 |
+
1nvq 2.00 2003 8.25 Ki=5.6nM 32
|
| 162 |
+
3acw 1.63 2010 4.76 Ki=17.5uM 33
|
| 163 |
+
4ea2 2.05 2012 6.44 Ki=0.36uM 33
|
| 164 |
+
2zcr 1.92 2008 6.87 Ki=135nM 33
|
| 165 |
+
2zy1 1.78 2009 7.40 Ki=0.04uM 33
|
| 166 |
+
2zcq 2.38 2008 8.82 Ki=1.5nM 33
|
| 167 |
+
1bcu 2.00 1998 3.28 Kd=0.53mM 34
|
| 168 |
+
3bv9 1.80 2008 5.36 Ki=4.4uM 34
|
| 169 |
+
1oyt 1.67 2003 7.24 Ki=0.057uM 34
|
| 170 |
+
2zda 1.73 2008 8.40 Ki=4nM 34
|
| 171 |
+
3utu 1.55 2012 10.92 Ki=0.012nM 34
|
| 172 |
+
3u9q 1.52 2011 4.38 Ki=41.7uM 35
|
| 173 |
+
2yfe 2.00 2012 6.63 Ki=0.236uM 35
|
| 174 |
+
3fur 2.30 2009 8.00 Ki=10nM 35
|
| 175 |
+
3b1m 1.60 2011 8.48 Ki=3.3nM 35
|
| 176 |
+
2p4y 2.25 2008 9.00 Ki=1nM 35
|
| 177 |
+
3uo4 2.45 2012 6.52 Kd=299nM 36
|
| 178 |
+
3up2 2.30 2012 7.40 Kd=40nM 36
|
| 179 |
+
3e5a 2.30 2008 8.23 Ki=5.9nM 36
|
| 180 |
+
2wtv 2.40 2010 8.74 Ki=1.8nM 36
|
| 181 |
+
3myg 2.40 2010 10.70 Kd=0.02nM 36
|
| 182 |
+
3kgp 2.35 2009 2.57 Ki=2.68mM 37
|
| 183 |
+
1c5z 1.85 2000 4.01 Ki=97uM 37
|
| 184 |
+
1o5b 1.85 2004 5.77 Ki=1.7uM 37
|
| 185 |
+
1owh 1.61 2003 7.40 Ki=40nM 37
|
| 186 |
+
1sqa 2.00 2004 9.21 Ki=0.62nM 37
|
| 187 |
+
4jsz 1.90 2013 2.30 Ki=5000uM 38
|
| 188 |
+
3kwa 2.00 2010 4.08 Ki=84uM 38
|
| 189 |
+
2weg 1.10 2009 6.50 Kd=314nM 38
|
| 190 |
+
3ryj 1.39 2011 7.80 Kd=16nM 38
|
| 191 |
+
3dd0 1.48 2009 9.00 Ki=1nM 38
|
| 192 |
+
2xdl 1.98 2010 3.10 Kd=790uM 39
|
| 193 |
+
3b27 1.50 2011 5.16 Kd=6.9uM 39
|
| 194 |
+
1yc1 1.70 2005 6.17 Kd=680nM 39
|
| 195 |
+
3rlr 1.70 2011 7.52 Ki=30nM 39
|
| 196 |
+
2yki 1.67 2011 9.46 Kd=0.35nM 39
|
| 197 |
+
1z95 1.80 2005 7.12 Ki=76nM 40
|
| 198 |
+
3b68 1.90 2008 8.40 Ki=4nM 40
|
| 199 |
+
3b5r 1.80 2008 8.77 Ki=1.7nM 40
|
| 200 |
+
3b65 1.80 2008 9.27 Ki=0.54nM 40
|
| 201 |
+
3g0w 1.95 2009 9.52 Ki=0.3nM 40
|
| 202 |
+
4u4s 1.90 2014 2.92 Kd=1200uM 41
|
| 203 |
+
1p1q 2.00 2003 4.89 Kd=12.8uM 41
|
| 204 |
+
1syi 2.10 2005 5.44 Ki=3590nM 41
|
| 205 |
+
1p1n 1.60 2003 6.80 Kd=0.16uM 41
|
| 206 |
+
2al5 1.65 2005 8.40 Ki=4nM 41
|
| 207 |
+
3g2z 1.50 2009 2.36 Ki=4.4mM 42
|
| 208 |
+
3g31 1.70 2009 2.89 Ki=1.3mM 42
|
| 209 |
+
4de2 1.40 2012 4.12 Ki=76.0uM 42
|
| 210 |
+
4de3 1.44 2012 5.52 Ki=3.0uM 42
|
| 211 |
+
4de1 1.26 2012 5.96 Ki=1.1uM 42
|
| 212 |
+
1vso 1.85 2007 4.72 Ki=18.98uM 43
|
| 213 |
+
4dld 2.00 2012 5.82 Ki=1.5uM 43
|
| 214 |
+
3gbb 2.10 2009 6.90 Ki=126nM 43
|
| 215 |
+
3fv2 1.50 2010 8.11 Ki=7.7nM 43
|
| 216 |
+
3fv1 1.50 2010 9.30 Ki=0.5nM 43
|
| 217 |
+
4mgd 1.90 2014 4.69 Kd=20.19uM 44
|
| 218 |
+
2qe4 2.40 2007 7.96 Ki=11.0nM 44
|
| 219 |
+
1qkt 2.20 2000 9.04 Kd=0.92nM 44
|
| 220 |
+
2pog 1.84 2007 9.54 Ki=0.29nM 44
|
| 221 |
+
2p15 1.94 2007 10.30 Kd=50pM 44
|
| 222 |
+
2y5h 1.33 2011 5.79 Ki=1620nM 45
|
| 223 |
+
1lpg 2.00 2003 7.09 Ki=82nM 45
|
| 224 |
+
2xbv 1.66 2010 8.43 Kd=3.7nM 45
|
| 225 |
+
1z6e 1.80 2006 9.72 Ki=0.19nM 45
|
| 226 |
+
1mq6 2.10 2003 11.15 Ki=7pM 45
|
| 227 |
+
1nc3 2.20 2003 5.00 Ki=10uM 46
|
| 228 |
+
1nc1 2.00 2003 6.12 Ki=0.75uM 46
|
| 229 |
+
1y6r 2.20 2005 10.11 Ki=77pM 46
|
| 230 |
+
4f2w 2.00 2013 11.30 Ki=5.0pM 46
|
| 231 |
+
4f3c 1.93 2013 11.82 Ki=1.5pM 46
|
| 232 |
+
1uto 1.15 2004 2.27 Kd=5.32mM 47
|
| 233 |
+
4abg 1.52 2012 3.57 Kd=271uM 47
|
| 234 |
+
3gy4 1.55 2010 5.10 Kd=8uM 47
|
| 235 |
+
1k1i 2.20 2001 6.58 Kd=264nM 47
|
| 236 |
+
1o3f 1.55 2003 7.96 Ki=0.011uM 47
|
| 237 |
+
2yge 1.96 2011 5.06 Kd=8.62uM 48
|
| 238 |
+
2fxs 2.00 2007 6.06 Kd=0.87uM 48
|
| 239 |
+
2iwx 1.50 2006 6.68 Kd=0.21uM 48
|
| 240 |
+
2wer 1.60 2009 7.05 Kd=90nM 48
|
| 241 |
+
2vw5 1.90 2008 8.52 Kd=3nM 48
|
| 242 |
+
4kzq 2.25 2013 6.10 Kd=788nM 49
|
| 243 |
+
4kzu 2.10 2013 6.50 Ki=313nM 49
|
| 244 |
+
4j21 1.93 2013 7.41 Kd=39nM 49
|
| 245 |
+
4j3l 2.09 2013 7.80 Kd=16nM 49
|
| 246 |
+
3kr8 2.10 2009 8.10 Kd=8nM 49
|
| 247 |
+
2ymd 1.96 2012 3.16 Kd=693uM 50
|
| 248 |
+
2wnc 2.20 2009 6.32 Kd=479nM 50
|
| 249 |
+
2xys 1.91 2011 7.42 Ki=38nM 50
|
| 250 |
+
2wn9 1.75 2009 8.52 Kd=3.0nM 50
|
| 251 |
+
2x00 2.40 2010 11.33 Kd=4.7pM 50
|
| 252 |
+
3ozt 1.48 2011 4.13 Ki=74.9uM 51
|
| 253 |
+
3ozs 1.44 2011 5.33 Ki=4645nM 51
|
| 254 |
+
3oe5 1.52 2011 6.88 Ki=132nM 51
|
| 255 |
+
3oe4 1.49 2011 7.47 Ki=34nM 51
|
| 256 |
+
3nw9 1.65 2011 9.00 Ki=1nM 51
|
| 257 |
+
3ao4 1.95 2011 2.07 Kd=8.5mM 52
|
| 258 |
+
3zt2 1.70 2012 2.84 Kd=1435uM 52
|
| 259 |
+
3zsx 1.95 2012 3.28 Kd=519uM 52
|
| 260 |
+
4cig 1.70 2014 3.67 Kd=214uM 52
|
| 261 |
+
3zso 1.75 2012 5.12 Kd=7.6uM 52
|
| 262 |
+
3n7a 2.00 2011 3.70 Ki=200uM 53
|
| 263 |
+
4ciw 2.20 2014 4.82 Ki=15.0uM 53
|
| 264 |
+
3n86 1.90 2011 5.64 Ki=2.3uM 53
|
| 265 |
+
3n76 1.90 2011 6.85 Ki=0.14uM 53
|
| 266 |
+
2xb8 2.40 2010 7.59 Ki=26nM 53
|
| 267 |
+
4bkt 2.35 2013 3.62 Kd=240uM 54
|
| 268 |
+
4w9c 2.20 2014 4.65 Kd=22.2uM 54
|
| 269 |
+
4w9l 2.20 2014 5.02 Kd=9.52uM 54
|
| 270 |
+
4w9i 2.40 2014 5.96 Kd=1.10uM 54
|
| 271 |
+
4w9h 2.10 2014 6.73 Kd=0.185uM 54
|
| 272 |
+
3nq9 1.90 2010 4.03 Kd=92.6uM 55
|
| 273 |
+
3ueu 2.10 2011 5.24 Kd=5.81uM 55
|
| 274 |
+
3uev 1.90 2011 5.89 Kd=1.29uM 55
|
| 275 |
+
3uew 2.00 2011 6.31 Kd=0.49uM 55
|
| 276 |
+
3uex 2.10 2011 6.92 Kd=0.12uM 55
|
| 277 |
+
3lka 1.80 2010 2.82 Kd=1.5mM 56
|
| 278 |
+
3ehy 1.90 2009 5.85 Ki=1.4uM 56
|
| 279 |
+
3tsk 2.00 2012 7.17 Kd=67nM 56
|
| 280 |
+
3nx7 1.80 2010 8.10 Kd=7.88nM 56
|
| 281 |
+
4gr0 1.50 2013 9.55 Ki=0.28nM 56
|
| 282 |
+
3dxg 1.39 2009 2.40 Ki=4.0mM 57
|
| 283 |
+
3d6q 1.60 2009 3.76 Ki=172uM 57
|
| 284 |
+
1w4o 1.60 2005 5.22 Ki=6uM 57
|
| 285 |
+
1o0h 1.20 2003 5.92 Ki=1.2uM 57
|
| 286 |
+
1u1b 2.00 2005 7.80 Kd=16nM 57
|
HGNN/data/PDBbind_v2020/index/INDEX_general_PL_data.2020
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
HGNN/data/PDBbind_v2020/index/INDEX_general_PL_name.2020
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
HGNN/data/PDBbind_v2020/index/INDEX_refined_data.2020
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
HGNN/data/PDBbind_v2020/index/INDEX_refined_name.2020
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
HGNN/main.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from PL_Encoder import PLEncoder
|
| 8 |
+
from PL_Aggregator import PLAggregator
|
| 9 |
+
from PP_Encoder import PPEncoder
|
| 10 |
+
from PP_Aggregator import PPAggregator
|
| 11 |
+
from screen_dataset import *
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torch.utils.data
|
| 14 |
+
import argparse
|
| 15 |
+
import os
|
| 16 |
+
from util import cal_metrics
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class HGNN(nn.Module):
    """Heterogeneous GNN for pocket-ligand retrieval.

    Wraps a pocket encoder (``enc_u``) and a ligand encoder (``enc_v``),
    scores pocket/ligand pairs by dot product of their embeddings, and is
    trained with a symmetric in-batch contrastive loss.
    """

    def __init__(self, enc_u, enc_v, r2e):
        super(HGNN, self).__init__()
        self.enc_u = enc_u
        self.enc_v = enc_v
        self.embed_dim = enc_u.embed_dim

        # Kept for state-dict compatibility even though forward() does not
        # use them directly.
        self.w_ur1 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_ur2 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_vr1 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_vr2 = nn.Linear(self.embed_dim, self.embed_dim)

        self.r2e = r2e
        self.bn1 = nn.BatchNorm1d(self.embed_dim, momentum=0.5)
        self.bn2 = nn.BatchNorm1d(self.embed_dim, momentum=0.5)

        # Learnable logit temperature initialized to log(14).
        # NOTE(review): created directly on CUDA, so constructing this model
        # requires a GPU — confirm whether CPU-only use is ever needed.
        self.logit_scale = nn.Parameter(torch.ones([1], device="cuda") * np.log(14))

    def trainable_parameters(self):
        """Yield only the parameters with ``requires_grad=True``."""
        for name, param in self.named_parameters(recurse=True):
            if param.requires_grad:
                yield param

    def forward(self, nodes_u, nodes_v):
        """Encode pocket nodes and ligand nodes into embedding matrices."""
        embeds_u = self.enc_u(nodes_u, nodes_v)
        embeds_v = self.enc_v(nodes_v)
        return embeds_u, embeds_v

    def criterion(self, x_u, x_v, labels):
        """Symmetric contrastive loss plus screening metrics for one batch.

        Args:
            x_u: pocket embeddings, one row per batch element.
            x_v: ligand embeddings, one row per batch element.
            labels: square 0/1 matrix marking known active pocket/ligand pairs
                (diagonal entries are the batch's own pairs).

        Returns:
            Tuple ``(loss, metric_means, netout)`` where ``netout`` is the raw
            pocket-by-ligand score matrix.
        """
        netout = torch.matmul(x_u, torch.transpose(x_v, 0, 1))
        score = netout * self.logit_scale.exp().detach()
        # Mask off-diagonal known positives with a large negative value so
        # they are not treated as negatives in the softmax.
        score = (labels - torch.eye(len(labels)).to(labels.device)) * -1e6 + score

        lprobs_pocket = F.log_softmax(score.float(), dim=-1)
        lprobs_pocket = lprobs_pocket.view(-1, lprobs_pocket.size(-1))
        sample_size = lprobs_pocket.size(0)
        # Diagonal entries are the matching pairs.  Use the scores' device
        # instead of the original hard-coded .cuda() so CPU evaluation works.
        targets = torch.arange(sample_size, dtype=torch.long, device=lprobs_pocket.device).view(-1)

        # pocket retrieve mol
        loss_pocket = F.nll_loss(
            lprobs_pocket,
            targets,
            reduction="mean"
        )

        lprobs_mol = F.log_softmax(torch.transpose(score.float(), 0, 1), dim=-1)
        lprobs_mol = lprobs_mol.view(-1, lprobs_mol.size(-1))
        lprobs_mol = lprobs_mol[:sample_size]

        # mol retrieve pocket
        loss_mol = F.nll_loss(
            lprobs_mol,
            targets,
            reduction="mean"
        )

        loss = 0.5 * loss_pocket + 0.5 * loss_mol

        # Per-pocket screening metrics computed on the raw (untempered)
        # scores.  The original also computed an unused top-1 activity here;
        # that dead code has been removed.
        ef_all = []
        for i in range(len(netout)):
            act_pocket = labels[i]
            affi_pocket = netout[i]
            ef_all.append(cal_metrics(affi_pocket.detach().cpu().numpy(), act_pocket.detach().cpu().numpy()))
        ef_mean = {k: np.mean([x[k] for x in ef_all]) for k in ef_all[0].keys()}

        return loss, ef_mean, netout

    def loss(self, nodes_u, nodes_v, labels):
        """Encode a batch and return ``(contrastive loss, mean metrics)``."""
        x_u, x_v = self.forward(nodes_u, nodes_v)
        loss, ef_mean, netout = self.criterion(x_u, x_v, labels)
        return loss, ef_mean
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def train(model, device, train_loader, optimizer, epoch, valid_idxes, valid_molidxes, valid_labels):
    # Run one epoch of contrastive training, logging the running loss and
    # validation metrics every 200 steps.
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # Each loader item is a pre-batched (pocket ids, mol ids, label
        # matrix) triple, hence the [0] indexing below.
        batch_nodes_u, batch_nodes_v, labels = data
        optimizer.zero_grad()
        loss, _ = model.loss(batch_nodes_u[0].to(device), batch_nodes_v[0].to(device), labels[0].to(device))
        # NOTE(review): retain_graph=True keeps the autograd graph alive after
        # the step and increases memory use — confirm it is actually required.
        loss.backward(retain_graph=True)
        optimizer.step()
        running_loss += loss.item()
        if i % 200 == 0:
            # NOTE(review): at i == 0 this divides a single batch's loss by
            # 200, which under-reports the first log line.
            print('[%d, %5d] loss: %.3f '%(epoch, i, running_loss / 200))
            running_loss = 0.0
            # Periodic validation.  NOTE(review): indentation reconstructed
            # from a mangled rendering — confirm whether this runs every 200
            # steps (as written here) or once per epoch.
            avg_loss, avg_acc = valid(model,
                                      device,
                                      torch.tensor(valid_idxes).to(device),
                                      torch.tensor(valid_molidxes).to(device),
                                      torch.tensor(valid_labels).to(device))
            print('Valid set results:', avg_loss.item(), avg_acc)
    return 0
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def valid(model, device, valid_idxes, valid_molidxes, valid_labels):
    """Evaluate the model on the validation split without tracking gradients.

    Puts the model in eval mode, computes the loss/metrics, then restores
    train mode before returning.
    """
    model.eval()
    with torch.no_grad():
        val_loss, val_metrics = model.loss(
            valid_idxes.to(device),
            valid_molidxes.to(device),
            valid_labels.to(device),
        )
    model.train()
    return val_loss, val_metrics
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def test_dekois(model, device, epoch, result_root, dekois_pocket_name, dekois_idxes):
    """Evaluate virtual screening on the DEKOIS benchmark.

    For every DEKOIS target, scores the precomputed ligand embeddings against
    both the HGNN-refined pocket embedding and the raw (pre-GNN) pocket
    embedding, saves both score arrays next to the target's data, and prints
    the metrics averaged over all targets.

    Args:
        model: trained HGNN model.
        device: torch device for the ligand embeddings.
        epoch: current epoch, used in the output file name.
        result_root: directory containing the ``DEKOIS`` feature folders.
        dekois_pocket_name: benchmark target ids.
        dekois_idxes: pocket node ids aligned with ``dekois_pocket_name``.
    """
    model.eval()
    # The original also created loss_all/loss_raw_all lists that were never
    # used; they have been removed.
    ef_all, ef_raw_all = [], []
    dekois_dir = f"{result_root}/DEKOIS"
    with torch.no_grad():
        for dekois_id, pocket_node_id in zip(dekois_pocket_name, dekois_idxes):
            embeds_pocket = model.enc_u([pocket_node_id], None, max_sample=-1)
            embeds_lig = torch.tensor(np.load(f"{dekois_dir}/{dekois_id}/saved_mols_embed.npy")).to(device).float()
            labels = np.load(f"{dekois_dir}/{dekois_id}/saved_labels.npy")
            # Raw pocket embedding straight from the embedding table, i.e.
            # without graph aggregation.
            embeds_pocket_raw = model.enc_u.aggregator.u2e(torch.tensor([pocket_node_id]).to(device))

            score = torch.matmul(embeds_pocket, torch.transpose(embeds_lig, 0, 1)).squeeze().detach().cpu().numpy()
            score_raw = torch.matmul(embeds_pocket_raw, torch.transpose(embeds_lig, 0, 1)).squeeze().detach().cpu().numpy()
            np.save(f"{dekois_dir}/{dekois_id}/GNN_res_epoch{epoch}.npy", score)
            np.save(f"{dekois_dir}/{dekois_id}/noGNN_res.npy", score_raw)
            ef_all.append(cal_metrics(score, labels))
            ef_raw_all.append(cal_metrics(score_raw, labels))

    model.train()
    # Average every metric over the benchmark targets.
    ef_all = {k: np.mean([x[k] for x in ef_all]) for k in ef_all[0].keys()}
    ef_raw_all = {k: np.mean([x[k] for x in ef_raw_all]) for k in ef_raw_all[0].keys()}
    print('Test on dekois:', ef_all)
    print('No HGNN on dekois:', ef_raw_all)
|
| 153 |
+
|
| 154 |
+
def test_dude(model, device, epoch, result_root, dude_pocket_name, dude_idxes):
    """Evaluate virtual screening on the DUD-E benchmark.

    For every DUD-E target, scores the precomputed ligand embeddings against
    both the HGNN-refined pocket embedding and the raw (pre-GNN) pocket
    embedding, saves both score arrays next to the target's data, and prints
    the metrics averaged over all targets.

    Args:
        model: trained HGNN model.
        device: torch device for the ligand embeddings.
        epoch: current epoch, used in the output file name.
        result_root: directory containing the ``DUDE`` feature folders.
        dude_pocket_name: benchmark target ids.
        dude_idxes: pocket node ids aligned with ``dude_pocket_name``.
    """
    model.eval()
    # The original also created loss_all/loss_raw_all lists that were never
    # used; they have been removed.
    ef_all, ef_raw_all = [], []
    dude_dir = f"{result_root}/DUDE"
    with torch.no_grad():
        for dude_id, pocket_node_id in zip(dude_pocket_name, dude_idxes):
            embeds_pocket = model.enc_u([pocket_node_id], None, max_sample=-1)
            embeds_lig = torch.tensor(np.load(f"{dude_dir}/{dude_id}/saved_mols_embed.npy")).to(device).float()
            labels = np.load(f"{dude_dir}/{dude_id}/saved_labels.npy")
            # Raw pocket embedding straight from the embedding table, i.e.
            # without graph aggregation.
            embeds_pocket_raw = model.enc_u.aggregator.u2e(torch.tensor([pocket_node_id]).to(device))

            score = torch.matmul(embeds_pocket, torch.transpose(embeds_lig, 0, 1)).squeeze().detach().cpu().numpy()
            score_raw = torch.matmul(embeds_pocket_raw, torch.transpose(embeds_lig, 0, 1)).squeeze().detach().cpu().numpy()
            np.save(f"{dude_dir}/{dude_id}/GNN_res_epoch{epoch}.npy", score)
            np.save(f"{dude_dir}/{dude_id}/noGNN_res.npy", score_raw)
            ef_all.append(cal_metrics(score, labels))
            ef_raw_all.append(cal_metrics(score_raw, labels))

    model.train()
    # Average every metric over the benchmark targets.
    ef_all = {k: np.mean([x[k] for x in ef_all]) for k in ef_all[0].keys()}
    ef_raw_all = {k: np.mean([x[k] for x in ef_raw_all]) for k in ef_raw_all[0].keys()}
    print('Test on dude:', ef_all)
    print('No HGNN on dude:', ef_raw_all)
|
| 181 |
+
|
| 182 |
+
def test_pcba(model, device, epoch, result_root, pcba_idxes):
    """Evaluate virtual screening on the LIT-PCBA benchmark.

    Each PCBA target can have several pockets; per-pocket score vectors are
    aggregated into one score per ligand before computing metrics.  Scores
    are saved per target and the averaged metrics are printed.

    Args:
        model: trained HGNN model.
        device: torch device for the ligand embeddings.
        epoch: current epoch, used in the output file name.
        result_root: directory containing the ``PCBA`` feature folders.
        pcba_idxes: flat list of pocket node ids, consumed in the order the
            targets' pockets are enumerated.

    Returns:
        Mean EF1 over all PCBA targets (HGNN scores).
    """
    model.eval()
    # The original also created loss_all/loss_raw_all lists that were never
    # used; they have been removed.
    ef_all, ef_raw_all = [], []
    pcba_dir = f"{result_root}/PCBA"
    with torch.no_grad():
        pocket_idx = 0
        for pcba_id in sorted(list(os.listdir(pcba_dir))):
            pocket_names = []
            for names in json.load(open(f"{pcba_dir}/{pcba_id}/saved_pocket_names.json")):
                pocket_names += names
            embeds_lig = torch.tensor(np.load(f"{pcba_dir}/{pcba_id}/saved_mols_embed.npy")).to(device).float()
            labels = np.load(f"{pcba_dir}/{pcba_id}/saved_labels.npy")
            score_all_pocket = []
            score_raw_pocket = []

            for i, pocket_name in enumerate(pocket_names):
                pcba_test_idx = pcba_idxes[pocket_idx]
                embeds_pocket = model.enc_u([pcba_test_idx], None, max_sample=-1)
                netout = torch.matmul(embeds_pocket, torch.transpose(embeds_lig, 0, 1))
                # Raw pocket embedding straight from the embedding table.
                embeds_pocket_raw = model.enc_u.aggregator.u2e(torch.tensor([pcba_test_idx]).to(device))
                netout_raw = torch.matmul(embeds_pocket_raw, torch.transpose(embeds_lig, 0, 1))
                score_all_pocket.append(netout.squeeze().detach().cpu().numpy())
                score_raw_pocket.append(netout_raw.squeeze().detach().cpu().numpy())
                pocket_idx += 1

            # NOTE(review): HGNN scores are aggregated with mean(...) while
            # the raw scores use max(...), despite the "score_max" name —
            # confirm this asymmetry is intentional.
            score_max = np.stack(score_all_pocket, axis=0).mean(axis=0)
            score_raw_max = np.stack(score_raw_pocket, axis=0).max(axis=0)
            metric = cal_metrics(score_max, labels)
            print(pcba_id, metric["EF1"], metric["BEDROC"], metric["AUC"])
            np.save(f"{pcba_dir}/{pcba_id}/GNN_res_epoch{epoch}.npy", score_max)
            np.save(f"{pcba_dir}/{pcba_id}/noGNN_res.npy", score_raw_max)
            # Reuse the metrics computed above instead of calling cal_metrics
            # a second time with the same inputs.
            ef_all.append(metric)
            ef_raw_all.append(cal_metrics(score_raw_max, labels))

    model.train()
    print(f"saving to {pcba_dir}")
    # Average every metric over the benchmark targets.
    ef_all = {k: np.mean([x[k] for x in ef_all]) for k in ef_all[0].keys()}
    ef_raw_all = {k: np.mean([x[k] for x in ef_raw_all]) for k in ef_raw_all[0].keys()}
    print('Test on pcba:', ef_all)
    print('No HGNN on pcba:', ef_raw_all)
    return ef_all["EF1"]
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def main():
    """Entry point: train the HGNN pocket-ranking model or evaluate a checkpoint.

    Parses CLI arguments, seeds all RNGs, loads precomputed pocket/molecule
    features, assembles the pocket encoder stack (PL aggregation + PP
    neighbour aggregation) and either evaluates a saved checkpoint
    (``--test_ckpt``) or runs the training loop with per-epoch evaluation
    on the DUD-E, DEKOIS and PCBA benchmarks.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='HGNN model training')
    parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')
    parser.add_argument('--embed_dim', type=int, default=128, metavar='N', help='embedding size')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learning rate')
    parser.add_argument('--test_batch_size', type=int, default=1000, metavar='N', help='input batch size for testing')
    parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train')
    parser.add_argument("--test_ckpt", type=str, default=None)
    parser.add_argument("--data_root", type=str, default="../data")
    parser.add_argument("--result_root", type=str, default="../result/pocket_ranking")
    args = parser.parse_args()
    data_root = args.data_root
    # BUG FIX: result_root was used below (load_datas, test_* calls) but never
    # assigned, raising NameError; bind it from the CLI argument.
    result_root = args.result_root

    # Reproducibility: seed every RNG the pipeline touches.
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    print("begin load dataset")
    assayinfo_lst, pocket_feat, mol_feat, assayid_lst_all, mol_smi_lst, \
        assayid_lst_train, assayid_lst_test, dude_pocket_name, pcba_pocket_name, dekois_pocket_name, valid_molidxes = load_datas(data_root, result_root)
    print("begin load pocket-pocket graph")
    pocket_graph = load_pocket_pocket_graph(data_root, assayid_lst_all, assayid_lst_train)

    screen_dataset = ScreenDataset(args.batch_size, pocket_graph, assayinfo_lst, assayid_lst_all, mol_smi_lst, assayid_lst_train)
    num_pockets = len(assayid_lst_all)
    num_ligs = mol_feat.shape[0]

    embed_dim = args.embed_dim
    # Frozen embedding tables initialised from precomputed representations.
    pocket2e = nn.Embedding(num_pockets, embed_dim).to(device)
    pocket2e.weight.data.copy_(torch.tensor(pocket_feat).to(device))
    for param in pocket2e.parameters():
        param.requires_grad = False

    lig2e = nn.Embedding(num_ligs, embed_dim).to(device)
    for param in lig2e.parameters():
        param.requires_grad = False
    type2e = nn.Embedding(10, embed_dim).to(device)

    # Pocket encoder: pocket-ligand aggregation, then pocket-pocket
    # (neighbour) aggregation stacked on top of it.
    agg_pocket = PLAggregator(lig2e, type2e, pocket2e, embed_dim, cuda=device, uv=True)
    enc_pocket = PLEncoder(embed_dim, pocket_graph, agg_pocket, assayid_lst_all, assayid_lst_train, mol_smi_lst, assayinfo_lst, cuda=device, uv=True)
    agg_pocket_sim = PPAggregator(pocket2e, embed_dim, cuda=device)
    enc_pocket = PPEncoder(enc_pocket, embed_dim, pocket_graph, agg_pocket_sim, assayid_lst_all, assayid_lst_train,
                           base_model=enc_pocket, cuda=device)
    enc_lig = lig2e
    # Full model.
    graphrec = HGNN(enc_pocket, enc_lig, type2e).to(device)
    print("trainable parameters")
    for name, param in graphrec.named_parameters(recurse=True):
        if param.requires_grad:
            print(name, param.shape)
    optimizer = torch.optim.RMSprop(graphrec.trainable_parameters(), lr=args.lr, alpha=0.9)

    # Index ranges of the benchmark pockets inside the global pocket list;
    # the order must match load_datas' concatenation order
    # (train + test + DUD-E + PCBA + DEKOIS).
    begin = len(assayid_lst_train + assayid_lst_test)
    end = begin + len(dude_pocket_name)
    dude_idxes = range(begin, end)
    begin = end
    end += len(pcba_pocket_name)
    pcba_idxes = range(begin, end)
    begin = end
    end += len(dekois_pocket_name)
    dekois_idxes = range(begin, end)

    if args.test_ckpt is not None:
        # Evaluation-only mode: load the checkpoint and run all benchmarks once.
        graphrec.load_state_dict(torch.load(args.test_ckpt, weights_only=True))
        test_dude(graphrec, device, 0, result_root, dude_pocket_name, dude_idxes)
        test_dekois(graphrec, device, 0, result_root, dekois_pocket_name, dekois_idxes)
        test_pcba(graphrec, device, 0, result_root, pcba_idxes)
    else:
        # Validation labels and indices are epoch-invariant; hoisted out of
        # the loop (the original re-read CoreSet.dat every epoch).
        valid_labels = load_valid_label(assayid_lst_test)
        valid_idxes = range(len(assayid_lst_train), len(assayid_lst_train + assayid_lst_test))
        for epoch in range(args.epochs):
            screen_dataset.set_epoch(epoch)
            train_loader = torch.utils.data.DataLoader(screen_dataset, batch_size=1, shuffle=True, num_workers=8)
            # Refresh the frozen ligand table every epoch (the encoders may
            # have been fed modified weights during evaluation).
            lig2e.weight.data.copy_(torch.tensor(mol_feat).to(device))
            train(graphrec, device, train_loader, optimizer, epoch, valid_idxes, valid_molidxes, valid_labels)
            test_dude(graphrec, device, epoch + 1, result_root, dude_pocket_name, dude_idxes)
            test_dekois(graphrec, device, epoch + 1, result_root, dekois_pocket_name, dekois_idxes)
            test_pcba(graphrec, device, epoch + 1, result_root, pcba_idxes)

            # Portable replacement for os.system("mkdir -p ...").
            os.makedirs(f"{result_root}/HGNN_save", exist_ok=True)
            torch.save(graphrec.state_dict(), f"{result_root}/HGNN_save/model_{epoch}.pt")
|
| 316 |
+
|
| 317 |
+
if __name__ == "__main__":
|
| 318 |
+
main()
|
HGNN/read_fasta.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, re
|
| 2 |
+
import prody as pr
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from multiprocessing import Pool
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# import subprocess
|
| 8 |
+
# os.environ["BABEL_LIBDIR"] = "/home/shenchao/.conda/envs/my2/lib/openbabel/3.1.0"
|
| 9 |
+
|
| 10 |
+
def write_file(output_file, outline):
    """Write the string *outline* to *output_file*, replacing any existing content.

    Uses a context manager so the handle is closed even if the write raises
    (the original left the handle open on error).
    """
    with open(output_file, 'w') as handle:
        handle.write(outline)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def lig_rename(infile, outfile):
    """Rewrite a ligand PDB file so every ATOM/HETATM residue name is "LIG".

    Some peptide ligands keep amino-acid residue names, which impedes the
    downstream pocket-extraction selection; forcing a single residue name
    avoids that.  PDB columns 18-20 (0-based slice 17:20) hold the residue
    name.  Non-coordinate records are copied unchanged.

    Fixes the original's leaked input handle and drops the dependency on
    the sibling write_file helper (output is written directly, with the
    same bytes).
    """
    with open(infile, 'r') as src:
        lines = src.readlines()
    newlines = []
    for line in lines:
        # Only coordinate records carry a residue name field.
        if re.search(r'^HETATM|^ATOM', line):
            newlines.append(line[:17] + "LIG" + line[20:])
        else:
            newlines.append(line)
    with open(outfile, 'w') as dst:
        dst.write(''.join(newlines))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def check_mol(infile, outfile):
    """Copy *infile* to *outfile*, dropping every line containing "LIG".

    Some metals may share the ligand's residue ID, which would leave
    ligand atoms inside the extracted pocket; stripping "LIG" records
    removes them.  Pure-Python replacement for the original
    ``cat infile | sed '/LIG/d' > outfile`` shell pipeline: portable and
    not vulnerable to shell metacharacters in file names.
    """
    with open(infile, 'r') as src, open(outfile, 'w') as dst:
        for line in src:
            if "LIG" not in line:
                dst.write(line)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def extract_pocket(protpath,
                   ligpath,
                   cutoff=5.0,
                   protname=None,
                   ligname=None,
                   pdb_pocket_file=None,
                   workdir='.'):
    """Extract the binding pocket of *ligpath* from *protpath* into *pdb_pocket_file*.

    protpath: the path of protein file (.pdb).
    ligpath: the path of ligand file (.sdf|.mol2|.pdb).
    cutoff: the distance range (Angstrom, passed to ProDy's selection) within
        the ligand used to determine the pocket.
    protname: the name of the protein (defaults to the file stem of protpath).
    ligname: the name of the ligand (defaults to the file stem of ligpath).
    pdb_pocket_file: output path for the final pocket PDB (required; no default).
    workdir: working directory for intermediate files.

    NOTE(review): relies on external tools being on PATH (``obabel`` for
    format conversion, ``cp``) and on the prody package imported as ``pr``.
    Intermediate files in *workdir* are created and removed as a side effect.
    """
    if protname is None:
        protname = os.path.basename(protpath).split('.')[0]
    if ligname is None:
        ligname = os.path.basename(ligpath).split('.')[0]

    # Convert the ligand to PDB if needed; otherwise copy it into workdir.
    if not re.search(r'.pdb$', ligpath):
        os.system(f"obabel {ligpath} -O {workdir}/{ligname}.pdb")
    else:
        os.system(f"cp {ligpath} {workdir}/{ligname}.pdb")

    xprot = pr.parsePDB(protpath)
    # xlig = pr.parsePDB("%s/%s.pdb"%(workdir, ligname))

    # if (xlig.getResnames() == xlig.getResnames()[0]).all():
    #     lresname = xlig.getResnames()[0]
    # else:
    # Force every ligand atom to residue name "LIG" so peptide ligands do
    # not break the residue-name-based selection below.
    lig_rename("%s/%s.pdb" % (workdir, ligname), "%s/%s2.pdb" % (workdir, ligname))
    os.remove("%s/%s.pdb" % (workdir, ligname))
    os.rename("%s/%s2.pdb" % (workdir, ligname), "%s/%s.pdb" % (workdir, ligname))
    xlig = pr.parsePDB("%s/%s.pdb" % (workdir, ligname))
    lresname = xlig.getResnames()[0]
    # Combined structure so the selection can measure ligand-protein distances.
    xcom = xlig + xprot

    # select ONLY atoms that belong to the protein
    ret = xcom.select(f'same residue as exwithin %s of resname %s' % (cutoff, lresname))

    pr.writePDB("%s/%s_pocket_%s_temp.pdb" % (workdir, protname, cutoff), ret)
    # ret = pr.parsePDB("%s/%s_pocket_%s.pdb"%(workdir, protname, cutoff))

    # Drop any leftover "LIG" records (e.g. metals sharing the ligand's ID)
    # from the temporary pocket, writing the final pocket file.
    check_mol("%s/%s_pocket_%s_temp.pdb" % (workdir, protname, cutoff), pdb_pocket_file)
    os.remove("%s/%s_pocket_%s_temp.pdb" % (workdir, protname, cutoff))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_fasta_seq(fasta_file):
    """Return the sequence stored in *fasta_file* as a single string.

    The first line (the FASTA header) is skipped; the remaining lines are
    stripped of surrounding whitespace and concatenated.
    """
    with open(fasta_file) as handle:
        stripped_rows = [row.strip() for row in handle]
    return "".join(stripped_rows[1:])
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def read_fasta_from_protein(pdb_file, lig_file, target_id="test", cutoff=5.0, pdb_pocket_file="test_pocket.pdb", fasta_pocket_file="test_pocket.fasta"):
    """Extract the ligand-binding pocket from *pdb_file* and return its FASTA sequence.

    Returns None if *pdb_file* does not exist or pocket extraction fails;
    otherwise writes the pocket PDB to *pdb_pocket_file*, converts it to
    FASTA via the external ``./pdb2fasta`` binary (must exist in the
    current working directory), writes *fasta_pocket_file*, and returns
    the sequence string.
    """
    if not os.path.exists(pdb_file):
        return

    try:
        extract_pocket(pdb_file,
                       lig_file,
                       cutoff=cutoff,
                       protname=target_id,
                       pdb_pocket_file=pdb_pocket_file,
                       ligname=f"{target_id}_ligand")
    except Exception as e:
        # Best-effort: extraction failures (bad files, missing tools) are
        # reported and swallowed so batch processing can continue.
        print(e)
        return

    os.system(f"./pdb2fasta {pdb_pocket_file} > {fasta_pocket_file}")
    return get_fasta_seq(fasta_pocket_file)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def read_fasta_from_pocket(pocket_pdb_file, fasta_pocket_file="test_pocket.fasta"):
    """Convert an existing pocket PDB to FASTA and return the sequence string.

    Invokes the external ``./pdb2fasta`` binary (must exist in the current
    working directory); *fasta_pocket_file* is written as a side effect.
    """
    os.system(f"./pdb2fasta {pocket_pdb_file} > {fasta_pocket_file}")
    return get_fasta_seq(fasta_pocket_file)
|
HGNN/screen_dataset.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
import contextlib
|
| 6 |
+
import copy
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import Dataset, sampler, DataLoader
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_ligname():
    """Map PDB id -> ligand id from the PDBbind v2020 index files.

    Parses the general index first, then the refined index, so refined
    entries overwrite general ones (same as the original's ordering).
    The ligand id is the last whitespace-separated column wrapped in
    parentheses, e.g. "(STI)"; rows whose ligand id is empty and '#'
    comment lines are skipped.

    Refactor: the two verbatim-duplicated parsing loops are replaced by a
    single shared helper.
    """
    def _parse_index(index_file, out):
        # Shared parser for both index files.
        with open(index_file) as f:
            for line in f:
                if line.startswith('#'):
                    continue
                fields = line.strip().split()
                lig = fields[-1][1:-1]  # strip the surrounding parentheses
                if lig != "":
                    out[fields[0]] = lig

    pdbbind_lig_dict = {}
    _parse_index("./data/PDBbind_v2020/index/INDEX_general_PL_data.2020", pdbbind_lig_dict)
    _parse_index("./data/PDBbind_v2020/index/INDEX_refined_data.2020", pdbbind_lig_dict)
    return pdbbind_lig_dict
|
| 37 |
+
|
| 38 |
+
def load_uniprotid():
    """Map PDB id -> UniProt id from the PDBbind v2020 name index files.

    Parses the refined index first, then the general one, so general
    entries overwrite refined ones (same as the original's ordering).
    The UniProt id is the third whitespace-separated column; "------"
    marks a missing id and is skipped, as are '#' comment lines.

    Refactor: the two verbatim-duplicated parsing loops are replaced by a
    single shared helper.
    """
    def _parse_index(index_file, out):
        # Shared parser for both name-index files.
        with open(index_file) as f:
            for line in f:
                if line.startswith('#'):
                    continue
                fields = line.strip().split()
                uniprot_id = fields[2]
                if uniprot_id != "" and uniprot_id != "------":
                    out[fields[0]] = uniprot_id

    uniprot_id_dict = {}
    _parse_index("./data/PDBbind_v2020/index/INDEX_refined_name.2020", uniprot_id_dict)
    _parse_index("./data/PDBbind_v2020/index/INDEX_general_PL_name.2020", uniprot_id_dict)
    return uniprot_id_dict
|
| 61 |
+
|
| 62 |
+
def load_pocket_dude(result_root):
    """Load precomputed DUD-E target embeddings.

    Returns ``(features, names)``: *features* is the row-wise concatenation
    of each target's ``saved_target_embed.npy``; *names* lists the target
    directory names in sorted order (matching the feature rows' order).
    """
    dude_dir = f"{result_root}/DUDE"
    target_names = sorted(os.listdir(dude_dir))
    feature_chunks = [
        np.load(f"{dude_dir}/{target}/saved_target_embed.npy", allow_pickle=True)
        for target in target_names
    ]
    return np.concatenate(feature_chunks, axis=0), target_names
|
| 73 |
+
|
| 74 |
+
def load_pocket_dekois(result_root):
    """Load precomputed DEKOIS target embeddings.

    Returns ``(features, names)``: *features* is the row-wise concatenation
    of each target's ``saved_target_embed.npy``; *names* lists the target
    directory names in sorted order (matching the feature rows' order).
    """
    dekois_dir = f"{result_root}/DEKOIS"
    target_names = sorted(os.listdir(dekois_dir))
    feature_chunks = [
        np.load(f"{dekois_dir}/{target}/saved_target_embed.npy", allow_pickle=True)
        for target in target_names
    ]
    return np.concatenate(feature_chunks, axis=0), target_names
|
| 85 |
+
|
| 86 |
+
def load_pocket_pcba(result_root):
    """Load precomputed PCBA pocket embeddings and their qualified names.

    Each target directory holds an embedding matrix
    (``saved_target_embed.npy``) and a JSON list of pocket-name groups
    (``saved_pocket_names.json``).  Pocket names are prefixed with the
    target id.  When a target's embedding has a single row it is
    row-repeated so one row corresponds to each pocket name.
    Returns ``(features, names)`` with rows/names in matching order.
    """
    pcba_dir = f"{result_root}/PCBA"
    feature_chunks = []
    all_names = []
    for target in sorted(os.listdir(pcba_dir)):
        embed = np.load(f"{pcba_dir}/{target}/saved_target_embed.npy", allow_pickle=True)
        with open(f"{pcba_dir}/{target}/saved_pocket_names.json") as fh:
            name_groups = json.load(fh)
        target_names = [f"{target}_{n}" for group in name_groups for n in group]

        if embed.shape[0] == 1:
            # Single shared embedding: repeat it once per pocket name.
            embed = np.concatenate([embed] * len(target_names), axis=0)
        feature_chunks.append(embed)
        all_names.extend(target_names)

    return np.concatenate(feature_chunks, axis=0), all_names
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def read_cluster_file(cluster_file):
    """Parse a CD-HIT ``.clstr`` file into ``{member -> [cluster members]}``.

    Every member of a cluster maps to the full member list of its cluster
    (including itself).  Member ids are taken from the second '|'-field of
    each record line (e.g. ``... >sp|P12345|NAME ...`` -> ``P12345``);
    lines starting with '>' are cluster headers.

    BUG FIX: the original only flushed a cluster when the NEXT '>' header
    appeared, silently dropping the final cluster in the file; a trailing
    flush after the loop fixes that.
    """
    protein_clstr_dict = {}

    def _flush(members):
        # Record every member of a finished cluster against every member.
        for a in members:
            for b in members:
                protein_clstr_dict.setdefault(a, []).append(b)

    members_in_cluster = []
    with open(cluster_file) as f:
        for line in f:
            if line.startswith(">"):
                _flush(members_in_cluster)
                members_in_cluster = []
            else:
                members_in_cluster.append(line.split('|')[1])
    _flush(members_in_cluster)  # don't lose the last cluster
    return protein_clstr_dict
|
| 121 |
+
|
| 122 |
+
def load_assayinfo(data_root, result_root):
    """Assemble the assay list: PDBbind/CASF labels plus filtered BindingDB assays.

    PDBbind labels get ``assay_id`` (derived from the first pocket name) and
    ``domain="pdbbind"``.  BindingDB assays are dropped when they belong to
    FEP assays or to UniProt ids that appear in the DUD-E/PCBA/DEKOIS test
    sets; their ligands are restricted to known SMILES with activity >= 5.
    Each surviving assay's ligands are sorted by descending activity.
    Mutates the loaded label dicts in place and returns the combined list.
    """
    labels = json.load(open(f"{data_root}/train_label_pdbbind_seq.json")) + \
        json.load(open("../test_datasets/casf_label_seq.json"))
    save_dir_bdb = f"{result_root}/BDB"
    bdb_mol_smi = json.load(open(f"{save_dir_bdb}/bdb_mol_smis.json"))
    bdb_mol_smi = set(bdb_mol_smi)
    for label in labels:
        # Pocket names look like "<pdbid>_...", so the assay id is the prefix.
        label["assay_id"] = label["pockets"][0].split("_")[0]
        label["domain"] = "pdbbind"

    # breakpoint()
    labels_bdb = json.load(open(f"{data_root}/train_label_blend_seq_full.json"))
    # UniProt ids present in any of the three benchmark test sets; assays on
    # these targets are excluded from training to avoid leakage.
    non_repeat_uniprot = []
    testset_uniport_root = "../test_datasets"
    non_repeat_uniprot += [x[0] for x in json.load(open(f"{testset_uniport_root}/dude.json"))]
    non_repeat_uniprot += [x[0] for x in json.load(open(f"{testset_uniport_root}/PCBA.json"))]
    non_repeat_uniprot += [x[0] for x in json.load(open(f"{testset_uniport_root}/dekois.json"))]
    # NOTE(review): the "strict" exclusion list (80%-identity cluster
    # expansion) is built but unused because the filter at the bottom is
    # commented out; protein_clstr_dict_40 and old_len are likewise unused
    # leftovers.
    non_repeat_uniprot_strict = []
    protein_clstr_dict_40 = read_cluster_file(f"{data_root}/uniport40.clstr")
    protein_clstr_dict_80 = read_cluster_file(f"{data_root}/uniport80.clstr")
    for uniprot in non_repeat_uniprot:
        non_repeat_uniprot_strict += protein_clstr_dict_80.get(uniprot, [])
        non_repeat_uniprot_strict.append(uniprot)
    old_len = len(labels_bdb)
    non_repeat_assayids = json.load(open(os.path.join(data_root, "fep_assays.json")))
    labels_bdb = [x for x in labels_bdb if (x["assay_id"] not in non_repeat_assayids)]
    labels_bdb = [x for x in labels_bdb if (x["uniprot"] not in non_repeat_uniprot)]

    # Keep only ligands with a known embedding (SMILES in the BDB set) and
    # activity >= 5; drop assays left with no ligands.
    labels_bdb_new = []
    for label in labels_bdb:
        ligands = label["ligands"]
        ligands_new = []
        for lig in ligands:
            if lig["smi"] in bdb_mol_smi and lig["act"] >= 5:
                ligands_new.append(lig)
        label["ligands"] = ligands_new
        if len(ligands_new) > 0:
            labels_bdb_new.append(label)

    labels += labels_bdb_new
    for label in labels:
        # Most-active ligand first.
        label["ligands"] = sorted(label["ligands"], key=lambda x: x["act"], reverse=True)

    # labels = [x for x in labels if (x["uniprot"] not in non_repeat_uniprot_strict)]
    return labels
|
| 167 |
+
|
| 168 |
+
def load_id_dict(result_root, assayinfo_lst):
    """Build the train/test pocket id lists and their feature matrices.

    BindingDB assays contribute sqrt(#ligands) randomly chosen pocket
    embeddings each (seeded, so the sampling is reproducible); PDBbind
    assays contribute their precomputed training pocket embeddings.

    Returns ``(train_pocket_ids, test_pocket_ids, pocket_feat_train,
    pocket_feat_test)`` with ids and feature rows in matching order.
    """
    # Local import + fixed seed: random.choice below must be reproducible
    # and independent of the caller's RNG usage ordering.
    import random
    random.seed(42)
    bdb_dir = f"{result_root}/BDB"
    pdbbind_dir = f"{result_root}/PDBBind"

    pocket_names = json.load(open(f"{bdb_dir}/bdb_pocket_names.json"))
    pocket_embed = np.load(f"{bdb_dir}/bdb_pocket_reps.npy")
    name2idx = {name:i for i, name in enumerate(pocket_names)}

    assay_feat_lst = []
    bdb_assayid_lst = []
    for assay in assayinfo_lst:
        assay_id = assay["assay_id"]
        if assay.get("domain", None) == "pdbbind":
            continue
        pockets = assay["pockets"]
        # Repeat each assay sqrt(#ligands) times so assays with more data
        # appear more often, sampling a (possibly different) pocket each time.
        repeat_num = len(assay["ligands"])
        repeat_num = int(np.sqrt(repeat_num))
        for i in range(repeat_num):
            pocket = random.choice(pockets)
            assay_feat_lst.append(pocket_embed[name2idx[pocket]])
            bdb_assayid_lst.append(assay_id)

    bdb_assay_feat = np.stack(assay_feat_lst)

    # Keep only the PDBbind training pockets whose assay survived filtering
    # in load_assayinfo.
    train_pdbbind_ids = json.load(open(f'{pdbbind_dir}/train_pdbbind_ids.json'))
    train_pdbbind_pocket_embed = np.load(f"{pdbbind_dir}/train_pocket_reps.npy")
    train_pdbbind_ids_new = []
    train_pdbbind_pocket_embed_new = []
    pdbbind_aidlist = [assay["assay_id"] for assay in assayinfo_lst if assay.get("domain", None) == "pdbbind"]
    pdbbind_aidset = set(pdbbind_aidlist)
    for id, embed in zip(train_pdbbind_ids, train_pdbbind_pocket_embed):
        if id in pdbbind_aidset:
            train_pdbbind_ids_new.append(id)
            train_pdbbind_pocket_embed_new.append(embed)

    train_pdbbind_ids = train_pdbbind_ids_new
    train_pdbbind_pocket_embed = np.stack(train_pdbbind_pocket_embed_new)

    # Concatenation order (BDB first, then PDBbind) must match between the
    # id list and the feature matrix.
    train_pocket = bdb_assayid_lst + train_pdbbind_ids
    pocket_feat_train = np.concatenate([bdb_assay_feat, train_pdbbind_pocket_embed])
    test_pocket = json.load(open(f'{pdbbind_dir}/test_pdbbind_ids.json'))
    pocket_feat_test = np.load(f'{pdbbind_dir}/test_pocket_reps.npy')

    return train_pocket, test_pocket, pocket_feat_train, pocket_feat_test
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def load_datas(data_root, result_root):
    """Load every assay, pocket and molecule table needed for training/eval.

    The global pocket feature matrix and id list are concatenations in the
    fixed order train + test + DUD-E + PCBA + DEKOIS; callers compute
    benchmark index ranges from this ordering.  Similarly, molecule
    features/SMILES are concatenated as BDB-train + PDBbind-train +
    PDBbind-test, with ``test_molidxes`` covering the final test segment.

    Returns (assayinfo_lst, pocket_feat, mol_feat, assayid_lst_all,
    mol_smi_lst, assayid_lst_train, assayid_lst_test, dude_pocket_name,
    pcba_pocket_name, dekois_pocket_name, test_molidxes).
    """
    assayinfo_lst = load_assayinfo(data_root, result_root)
    assayid_lst_train, assayid_lst_test, pocket_feat_train, pocket_feat_test = load_id_dict(result_root, assayinfo_lst)

    dude_pocket_feat, dude_pocket_name = load_pocket_dude(result_root)

    pcba_pocket_feat, pcba_pocket_name = load_pocket_pcba(result_root)

    dekois_pocket_feat, dekois_pocket_name = load_pocket_dekois(result_root)

    # Order matters: index ranges for the benchmarks are derived from it.
    pocket_feat = np.concatenate((pocket_feat_train, pocket_feat_test, dude_pocket_feat, pcba_pocket_feat, dekois_pocket_feat), axis=0)
    assayid_lst_all = assayid_lst_train + assayid_lst_test + dude_pocket_name + pcba_pocket_name + dekois_pocket_name

    save_dir_bdb = f"{result_root}/BDB"
    save_dir_pdbbind = f"{result_root}/PDBBind"
    mol_feat_train_bdb = np.load(f'{save_dir_bdb}/bdb_mol_reps.npy')
    mol_feat_train_pdbbind = np.load(f'{save_dir_pdbbind}/train_mol_reps.npy')
    mol_feat_test = np.load(f'{save_dir_pdbbind}/test_mol_reps.npy')
    mol_feat = np.concatenate((mol_feat_train_bdb, mol_feat_train_pdbbind, mol_feat_test), axis=0)
    # SMILES list is parallel to mol_feat's row order.
    mol_smi_lst = json.load(open(f"{save_dir_bdb}/bdb_mol_smis.json")) + json.load(open(f"{save_dir_pdbbind}/train_mol_smis.json")) + json.load(open(f"{save_dir_pdbbind}/test_mol_smis.json"))
    test_len = len(json.load(open(f"{save_dir_pdbbind}/test_mol_smis.json")))
    test_molidxes = range(len(mol_smi_lst)-test_len, len(mol_smi_lst))
    return assayinfo_lst, pocket_feat, mol_feat, assayid_lst_all, mol_smi_lst, \
        assayid_lst_train, assayid_lst_test, dude_pocket_name, pcba_pocket_name, dekois_pocket_name, test_molidxes
|
| 240 |
+
|
| 241 |
+
def load_valid_label(assayid_lst_test):
    """Build the CASF cluster-agreement matrix for the validation pockets.

    Reads ``./data/CoreSet.dat`` (first line is a header; first column is
    the PDB id, last column its cluster id) and returns an NxN 0/1 matrix
    where entry (i, j) is 1 iff test pockets i and j belong to the same
    cluster.
    """
    pdbid2cluster = {}
    with open("./data/CoreSet.dat") as fh:
        for row in fh.readlines()[1:]:
            fields = row.strip().split()
            pdbid2cluster[fields[0]] = fields[-1]

    n = len(assayid_lst_test)
    labels = np.zeros((n, n))
    for i, pid_a in enumerate(assayid_lst_test):
        for j, pid_b in enumerate(assayid_lst_test):
            labels[i, j] = 1 if pdbid2cluster[pid_a] == pdbid2cluster[pid_b] else 0
    return labels
|
| 258 |
+
|
| 259 |
+
def load_pocket_pocket_graph(data_root, assayid_lst_all, assayid_lst_train):
    """Build the pocket-pocket similarity graph from precomputed alignments.

    Returns ``{assay_id -> [(train_assay_id, score), ...]}`` sorted by
    descending score, keeping only neighbours that are training pockets
    with alignment score >= 0.5.  Test-pocket entries override training
    entries on key collision (matching the original dict merge).

    Cleanup: removed an unused ``import pickle`` and factored the
    duplicated sort/threshold filtering into one helper; behavior is
    unchanged.
    """
    assayid_set = set(assayid_lst_all)
    assayid_set_train = set(assayid_lst_train)

    def _filter_neighbors(pairs):
        # Keep high-scoring training-pocket neighbours, best first.
        kept = []
        for nbr, score in sorted(pairs, key=lambda x: x[1], reverse=True):
            if nbr in assayid_set_train and score >= 0.5:
                kept.append((nbr, score))
        return kept

    # Training-side alignments are keyed by index into a separate key list.
    neighbor_dict_train = json.load(open(f"{data_root}/align_pair_res_train_10.23.json"))
    train_keys = json.load(open(f"{data_root}/align_train_keys_10.23.json"))
    PPGraph = {}
    for idx, neighbors in neighbor_dict_train.items():
        assayid_1 = train_keys[int(idx)]
        if assayid_1 in assayid_set:
            PPGraph[assayid_1] = _filter_neighbors(neighbors)

    # Test-side alignments are keyed directly by assay id; neighbour names
    # carry a file extension that must be stripped before the lookup.
    align_res_test = json.load(open(f"{data_root}/align_pair_res_test_10.23.json"))
    for test_id, pairs in align_res_test.items():
        if test_id in assayid_set:
            PPGraph[test_id] = _filter_neighbors([(t.split('.')[0], s) for t, s in pairs])

    return PPGraph
|
| 308 |
+
|
| 309 |
+
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Temporarily seed NumPy's global PRNG, restoring its state on exit.

    With ``seed=None`` this is a no-op.  Additional seeds are mixed in by
    hashing, letting nested loops derive distinct deterministic streams
    from one base seed.
    """
    if seed is None:
        yield
        return
    if addl_seeds:
        seed = int(hash((seed, *addl_seeds)) % 1e6)
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # Always restore the caller's RNG state, even on exceptions.
        np.random.set_state(saved_state)
|
| 325 |
+
class ScreenDataset(Dataset):
|
| 326 |
+
def __init__(self, batch_size, assay_graph, assayinfo_lst, assayid_lst_all, mol_smi_lst, assayid_lst_train):
|
| 327 |
+
self.batch_size = batch_size
|
| 328 |
+
self.train_idxes = list(range(len(assayid_lst_train)))
|
| 329 |
+
self.assayid_set_train = set(assayid_lst_train)
|
| 330 |
+
self.train_idxes_epoch = copy.deepcopy(self.train_idxes)
|
| 331 |
+
self.assay_graph = assay_graph
|
| 332 |
+
self.assayinfo_dicts = {x["assay_id"]: x for x in assayinfo_lst}
|
| 333 |
+
self.smi2idx = {smi:idx for idx, smi in enumerate(mol_smi_lst)}
|
| 334 |
+
self.uniprotid_dict = load_uniprotid()
|
| 335 |
+
self.pocket_lig_graph = self.load_graph()
|
| 336 |
+
self.seed = 66
|
| 337 |
+
self.assayid2idxes = {}
|
| 338 |
+
for idx, assayid in enumerate(assayid_lst_all):
|
| 339 |
+
if assayid not in self.assayid2idxes:
|
| 340 |
+
self.assayid2idxes[assayid] = []
|
| 341 |
+
self.assayid2idxes[assayid].append(idx)
|
| 342 |
+
self.idx2assayid = assayid_lst_all
|
| 343 |
+
self.epoch = 0
|
| 344 |
+
|
| 345 |
+
def set_epoch(self, epoch):
|
| 346 |
+
self.epoch = epoch
|
| 347 |
+
with numpy_seed(self.seed, epoch):
|
| 348 |
+
self.train_idxes_epoch = copy.deepcopy(self.train_idxes)
|
| 349 |
+
np.random.shuffle(self.train_idxes_epoch)
|
| 350 |
+
|
| 351 |
+
def load_graph(self):
|
| 352 |
+
pocket_lig_graph = {}
|
| 353 |
+
if os.path.exists("./data/pocket_lig_graph.json"):
|
| 354 |
+
pocket_lig_graph = json.load(open("./data/pocket_lig_graph.json"))
|
| 355 |
+
else:
|
| 356 |
+
from tqdm import tqdm
|
| 357 |
+
for assayid in tqdm(self.assayid2idxes.keys()):
|
| 358 |
+
if assayid not in self.assayid_set_train:
|
| 359 |
+
continue
|
| 360 |
+
ligands = self.assayinfo_dicts[assayid]["ligands"]
|
| 361 |
+
lig_candidate = []
|
| 362 |
+
if len(ligands) > 1:
|
| 363 |
+
lig_assay = [x["smi"] for x in ligands if x["act"] >= 5]
|
| 364 |
+
else:
|
| 365 |
+
lig_assay = [x["smi"] for x in ligands]
|
| 366 |
+
lig_candidate += lig_assay
|
| 367 |
+
lig_assay = set(lig_assay)
|
| 368 |
+
uniprot = self.assayinfo_dicts[assayid]["uniprot"]
|
| 369 |
+
|
| 370 |
+
for assayid_nbr, score in self.assay_graph.get(assayid, []):
|
| 371 |
+
if assayid_nbr not in self.assayinfo_dicts:
|
| 372 |
+
continue
|
| 373 |
+
assay_nbr = self.assayinfo_dicts[assayid_nbr]
|
| 374 |
+
uniprot_nbr = assay_nbr["uniprot"]
|
| 375 |
+
ligands_nbr = assay_nbr["ligands"]
|
| 376 |
+
if len(ligands) > 1:
|
| 377 |
+
lig_candidate_nbr = [x["smi"] for x in ligands_nbr if x["act"] >= 5]
|
| 378 |
+
else:
|
| 379 |
+
lig_candidate_nbr = [x["smi"] for x in ligands_nbr]
|
| 380 |
+
if assayid_nbr not in self.assayid_set_train:
|
| 381 |
+
continue
|
| 382 |
+
if len(lig_assay & set(lig_candidate_nbr)) > 0:
|
| 383 |
+
lig_candidate += lig_candidate_nbr
|
| 384 |
+
elif uniprot == uniprot_nbr:
|
| 385 |
+
lig_candidate += lig_candidate_nbr
|
| 386 |
+
|
| 387 |
+
pocket_lig_graph[assayid] = [x for x in set(lig_candidate) if x in self.smi2idx]
|
| 388 |
+
|
| 389 |
+
json.dump(pocket_lig_graph, open("./data/pocket_lig_graph.json", "w"))
|
| 390 |
+
return pocket_lig_graph
|
| 391 |
+
|
| 392 |
+
def __getitem__(self, item):
|
| 393 |
+
pocket_idx_batch = self.train_idxes_epoch[item*self.batch_size:(item+1)*self.batch_size]
|
| 394 |
+
pocket_batch = [self.idx2assayid[idx] for idx in pocket_idx_batch]
|
| 395 |
+
lig_batch = []
|
| 396 |
+
lig_idx_batch = []
|
| 397 |
+
epoch = self.epoch
|
| 398 |
+
for pocket in pocket_batch:
|
| 399 |
+
lig_candidate = self.pocket_lig_graph[pocket]
|
| 400 |
+
with numpy_seed(self.seed, epoch, item):
|
| 401 |
+
lig = np.random.choice(lig_candidate)
|
| 402 |
+
lig_batch.append(lig)
|
| 403 |
+
|
| 404 |
+
lig_idx_batch.append(self.smi2idx[lig])
|
| 405 |
+
|
| 406 |
+
labels = np.zeros((self.batch_size, self.batch_size))
|
| 407 |
+
for i, pocket in enumerate(pocket_batch):
|
| 408 |
+
for j, lig in enumerate(lig_batch):
|
| 409 |
+
if lig in self.pocket_lig_graph[pocket]:
|
| 410 |
+
labels[i, j] = 1
|
| 411 |
+
else:
|
| 412 |
+
labels[i, j] = 0
|
| 413 |
+
|
| 414 |
+
return torch.tensor(pocket_idx_batch), torch.tensor(lig_idx_batch), torch.tensor(labels)
|
| 415 |
+
|
| 416 |
+
def __len__(self):
    """Number of complete batches available in this epoch (drop-last)."""
    n_batches, _ = divmod(len(self.train_idxes_epoch), self.batch_size)
    return n_batches
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
|
HGNN/screening.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from PL_Encoder import PLEncoder
|
| 8 |
+
from PL_Aggregator import PLAggregator
|
| 9 |
+
from PP_Encoder import PPEncoder
|
| 10 |
+
from PP_Aggregator import PPAggregator
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import torch.utils.data
|
| 13 |
+
import argparse
|
| 14 |
+
import os
|
| 15 |
+
from util import cal_metrics
|
| 16 |
+
from read_fasta import read_fasta_from_pocket, read_fasta_from_protein
|
| 17 |
+
from align import get_neighbor_pocket
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class HGNN(nn.Module):
    """Hierarchical GNN scoring pockets against ligands in a shared embedding space.

    Wraps a pocket-side encoder ``enc_u`` and an optional ligand-side encoder
    ``enc_v``. The two embeddings are compared by dot product and trained with
    a symmetric (CLIP-style) contrastive loss in ``criterion``. For
    inference-only use (e.g. pocket refinement in screening.py) ``enc_v`` may
    be None and only ``refine_pocket`` is called.
    """

    def __init__(self, enc_u, enc_v=None, r2e=None):
        super(HGNN, self).__init__()
        self.enc_u = enc_u          # pocket encoder
        self.enc_v = enc_v          # ligand encoder (None for inference-only use)
        self.embed_dim = enc_u.embed_dim

        self.w_ur1 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_ur2 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_vr1 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_vr2 = nn.Linear(self.embed_dim, self.embed_dim)

        self.r2e = r2e
        self.bn1 = nn.BatchNorm1d(self.embed_dim, momentum=0.5)
        self.bn2 = nn.BatchNorm1d(self.embed_dim, momentum=0.5)

        # CLIP-style learnable temperature, initialised so exp() == 14.
        # Fix: created on the default device instead of hard-coded "cuda",
        # so the module can be constructed on CPU-only machines; the caller's
        # .to(device) moves it along with every other parameter.
        self.logit_scale = nn.Parameter(torch.ones(1) * np.log(14))

    def trainable_parameters(self):
        """Yield only parameters with requires_grad=True (for the optimizer)."""
        for name, param in self.named_parameters(recurse=True):
            if param.requires_grad:
                yield param

    def forward(self, nodes_u, nodes_v):
        """Encode a batch of pockets (conditioned on ligands) and ligands."""
        embeds_u = self.enc_u(nodes_u, nodes_v)
        embeds_v = self.enc_v(nodes_v)
        return embeds_u, embeds_v

    def criterion(self, x_u, x_v, labels):
        """Symmetric contrastive loss plus per-pocket screening metrics.

        Args:
            x_u: pocket embeddings, shape (B, D).
            x_v: ligand embeddings, shape (B, D).
            labels: (B, B) 0/1 matrix of known pocket-ligand binding.

        Returns:
            (loss, ef_mean, netout) where ef_mean averages the metrics of
            cal_metrics over the pockets of the batch and netout is the raw
            (B, B) similarity matrix.
        """
        netout = torch.matmul(x_u, torch.transpose(x_v, 0, 1))
        score = netout * self.logit_scale.exp().detach()
        # Mask off-diagonal known positives with a large negative value so
        # they are not treated as negatives by the softmax; the diagonal
        # pair remains the retrieval target.
        score = (labels - torch.eye(len(labels)).to(labels.device)) * -1e6 + score

        lprobs_pocket = F.log_softmax(score.float(), dim=-1)
        lprobs_pocket = lprobs_pocket.view(-1, lprobs_pocket.size(-1))
        sample_size = lprobs_pocket.size(0)
        # Fix: build targets on the batch's device instead of hard-coded .cuda().
        targets = torch.arange(sample_size, dtype=torch.long, device=labels.device)

        # pocket retrieve mol
        loss_pocket = F.nll_loss(
            lprobs_pocket,
            targets,
            reduction="mean"
        )

        lprobs_mol = F.log_softmax(torch.transpose(score.float(), 0, 1), dim=-1)
        lprobs_mol = lprobs_mol.view(-1, lprobs_mol.size(-1))
        lprobs_mol = lprobs_mol[:sample_size]

        # mol retrieve pocket
        loss_mol = F.nll_loss(
            lprobs_mol,
            targets,
            reduction="mean"
        )

        loss = 0.5 * loss_pocket + 0.5 * loss_mol

        # Per-pocket screening metrics (BEDROC/AUC/EF), averaged over the
        # batch. (The original also computed the top-1 activity here but
        # never used it; that dead code is removed.)
        ef_all = []
        for i in range(len(netout)):
            act_pocket = labels[i]
            affi_pocket = netout[i]
            ef_all.append(cal_metrics(affi_pocket.detach().cpu().numpy(), act_pocket.detach().cpu().numpy()))
        ef_mean = {k: np.mean([x[k] for x in ef_all]) for k in ef_all[0].keys()}

        return loss, ef_mean, netout

    def loss(self, nodes_u, nodes_v, labels):
        """Convenience wrapper: encode both sides and return (loss, metrics)."""
        x_u, x_v = self.forward(nodes_u, nodes_v)
        loss, ef_mean, netout = self.criterion(x_u, x_v, labels)
        return loss, ef_mean

    def refine_pocket(self, pocket_embed, neighbor_pocket_list):
        """Refine a pocket embedding using its neighbor pockets (delegates to enc_u)."""
        embeds_u = self.enc_u.refine_pocket(pocket_embed, neighbor_pocket_list)
        return embeds_u
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def main():
    """Inference entry point: refine one pocket embedding with HGNN.

    Loads a trained HGNN checkpoint, reads the pocket embedding (.npy) and
    the pocket sequence (from a pocket PDB, or derived from protein+ligand
    PDBs), retrieves neighbor pockets, refines the embedding through the
    model, and saves the result to ``--save_file``.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='HGNN model inference')
    parser.add_argument('--embed_dim', type=int, default=128, metavar='N', help='embedding size')
    parser.add_argument("--test_ckpt", type=str, default=None)
    parser.add_argument("--data_root", type=str, default="../data")
    parser.add_argument("--result_root", type=str, default="../result/pocket_ranking")
    parser.add_argument("--pocket_embed", type=str, default="../example/pocket_embed.npy")
    parser.add_argument("--save_file", type=str, default="../example/refined_pocket.npy")
    parser.add_argument("--pocket_pdb", type=str, default=None)
    parser.add_argument("--protein_pdb", type=str, default="../example/protein.pdb")
    parser.add_argument("--ligand_pdb", type=str, default="../example/ligand.pdb")

    args = parser.parse_args()

    # Fix all RNG sources for reproducible inference.
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    use_cuda = False
    if torch.cuda.is_available():
        use_cuda = True
    device = torch.device("cuda" if use_cuda else "cpu")

    embed_dim = args.embed_dim
    # Shared relation-type embedding table passed into the PL aggregator.
    type2e = nn.Embedding(10, embed_dim).to(device)

    # load model: pocket-ligand encoder wrapped by a pocket-pocket encoder
    agg_pocket = PLAggregator(r2e=type2e, embed_dim=embed_dim, cuda=device, uv=True)
    enc_pocket = PLEncoder(embed_dim=embed_dim, aggregator=agg_pocket, cuda=device, uv=True)
    agg_pocket_sim = PPAggregator(embed_dim=embed_dim, cuda=device)
    enc_pocket = PPEncoder(enc_pocket, embed_dim=embed_dim, aggregator=agg_pocket_sim, cuda=device)

    # strict=False: the checkpoint may contain keys for parts (e.g. enc_v)
    # that this inference-only model does not instantiate.
    model = HGNN(enc_pocket).to(device)
    model.load_state_dict(torch.load(args.test_ckpt, weights_only=True), strict=False)
    model.eval()

    # load pocket embedding and fasta
    pocket_embed = torch.tensor(np.load(args.pocket_embed)).to(device)

    if args.pocket_pdb is not None:
        pocket_fasta = read_fasta_from_pocket(args.pocket_pdb)
    else:
        pocket_fasta = read_fasta_from_protein(args.protein_pdb, args.ligand_pdb)

    # get neighbor pocket
    neighbor_pocket_list = get_neighbor_pocket(pocket_fasta, args.data_root, args.result_root, device)  # [(pocket_embed, ligand_embed, similarity)]

    # get refined pocket; if no neighbors were found, fall back to the
    # unrefined input embedding.
    if len(neighbor_pocket_list) > 0:
        with torch.no_grad():
            refined_pocket = model.refine_pocket(pocket_embed, neighbor_pocket_list)
            refined_pocket = refined_pocket.cpu().numpy()
    else:
        refined_pocket = pocket_embed.cpu().numpy()

    print("finished, saving refined pocket embedding into:", args.save_file)
    np.save(args.save_file, refined_pocket)


if __name__ == "__main__":
    main()
|
HGNN/test_pocket.fasta
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
>aa2ar_pocket_5:_ 19
|
| 2 |
+
VLTLFEMMNWLXNHALMYI
|
HGNN/util.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from rdkit.ML.Scoring.Scoring import CalcBEDROC, CalcAUC, CalcEnrichment
|
| 3 |
+
from sklearn.metrics import roc_curve
|
| 4 |
+
|
| 5 |
+
def re_new(y_true, y_score, ratio):
|
| 6 |
+
fp = 0
|
| 7 |
+
tp = 0
|
| 8 |
+
p = sum(y_true)
|
| 9 |
+
n = len(y_true) - p
|
| 10 |
+
num = ratio * n
|
| 11 |
+
sort_index = np.argsort(y_score)[::-1]
|
| 12 |
+
for i in range(len(sort_index)):
|
| 13 |
+
index = sort_index[i]
|
| 14 |
+
if y_true[index] == 1:
|
| 15 |
+
tp += 1
|
| 16 |
+
else:
|
| 17 |
+
fp += 1
|
| 18 |
+
if fp >= num:
|
| 19 |
+
break
|
| 20 |
+
return (tp * n) / (p * fp)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def calc_re(y_true, y_score, ratio_list):
    """Compute ROC-enrichment (via re_new) at each ratio in ratio_list.

    Args:
        y_true: true binary labels (0 or 1).
        y_score: predicted scores (higher = more likely active).
        ratio_list: false-positive ratios at which to evaluate enrichment.

    Returns:
        dict mapping str(ratio) -> enrichment value.
    """
    # The original also called sklearn's roc_curve and tracked several
    # totals here, but all of that output was dead (the code consuming it
    # was commented out), so the call and the unused bookkeeping -- and the
    # runtime sklearn dependency -- are dropped.
    return {str(ratio): re_new(y_true, y_score, ratio) for ratio in ratio_list}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def cal_metrics(y_score, y_true, alpha=80.5):
    """Compute virtual-screening metrics for one target.

    Parameters:
    - y_score: predicted scores or probabilities (higher = more active)
    - y_true: true binary labels (0 or 1)
    - alpha: BEDROC early-recognition parameter (80.5 emphasises the top ~2%)

    Returns:
    - dict with BEDROC, AUC, and enrichment factors at 0.5%, 1% and 5%.
    """
    # Stack (score, label) pairs and sort by score, highest first, as
    # required by the RDKit scoring helpers.
    scores = np.expand_dims(y_score, axis=1)
    y_true = np.expand_dims(y_true, axis=1)
    scores = np.concatenate((scores, y_true), axis=1)
    scores = scores[scores[:, 0].argsort()[::-1]]
    # Fix: honor the alpha argument -- the original hard-coded 80.5 and
    # silently ignored the parameter. (A dead top-0.5% hit-count loop was
    # also removed; its result was never returned.)
    bedroc = CalcBEDROC(scores, 1, alpha)
    auc = CalcAUC(scores, 1)
    ef_list = CalcEnrichment(scores, 1, [0.005, 0.01, 0.05])
    return {
        "BEDROC": bedroc,
        "AUC": auc,
        "EF0.5": ef_list[0],
        "EF1": ef_list[1],
        "EF5": ef_list[2]
    }
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# import torch
|
| 87 |
+
# torch.multiprocessing.set_start_method('spawn', force=True)
|
| 88 |
+
# def mycollator(input_batch):
|
| 89 |
+
# for data in input_batch:
|
| 90 |
+
# node, neighbors = data
|
| 91 |
+
# node["pocket_data"] = torch.tensor(node["pocket_data"]).cuda()
|
| 92 |
+
# node["lig_data"] = torch.tensor(node["lig_data"]).cuda()
|
| 93 |
+
# for neighbor in neighbors:
|
| 94 |
+
# neighbor["pocket_data"] = torch.tensor(node["pocket_data"]).cuda()
|
| 95 |
+
# neighbor["lig_data"] = torch.tensor(node["lig_data"]).cuda()
|
| 96 |
+
# return input_batch
|
License
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Attribution-NonCommercial 4.0 International
|
| 2 |
+
|
| 3 |
+
> *Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.*
|
| 4 |
+
>
|
| 5 |
+
> ### Using Creative Commons Public Licenses
|
| 6 |
+
>
|
| 7 |
+
> Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
|
| 8 |
+
>
|
| 9 |
+
> * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
|
| 10 |
+
>
|
| 11 |
+
> * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
|
| 12 |
+
|
| 13 |
+
## Creative Commons Attribution-NonCommercial 4.0 International Public License
|
| 14 |
+
|
| 15 |
+
By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
|
| 16 |
+
|
| 17 |
+
### Section 1 – Definitions.
|
| 18 |
+
|
| 19 |
+
a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
|
| 20 |
+
|
| 21 |
+
b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
|
| 22 |
+
|
| 23 |
+
c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
|
| 24 |
+
|
| 25 |
+
d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
|
| 26 |
+
|
| 27 |
+
e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
|
| 28 |
+
|
| 29 |
+
f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
|
| 30 |
+
|
| 31 |
+
g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
|
| 32 |
+
|
| 33 |
+
h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
|
| 34 |
+
|
| 35 |
+
i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
|
| 36 |
+
|
| 37 |
+
j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
|
| 38 |
+
|
| 39 |
+
k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
|
| 40 |
+
|
| 41 |
+
l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
|
| 42 |
+
|
| 43 |
+
### Section 2 – Scope.
|
| 44 |
+
|
| 45 |
+
a. ___License grant.___
|
| 46 |
+
|
| 47 |
+
1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
|
| 48 |
+
|
| 49 |
+
A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
|
| 50 |
+
|
| 51 |
+
B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
|
| 52 |
+
|
| 53 |
+
2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
|
| 54 |
+
|
| 55 |
+
3. __Term.__ The term of this Public License is specified in Section 6(a).
|
| 56 |
+
|
| 57 |
+
4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
|
| 58 |
+
|
| 59 |
+
5. __Downstream recipients.__
|
| 60 |
+
|
| 61 |
+
A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
|
| 62 |
+
|
| 63 |
+
B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
|
| 64 |
+
|
| 65 |
+
6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
|
| 66 |
+
|
| 67 |
+
b. ___Other rights.___
|
| 68 |
+
|
| 69 |
+
1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
|
| 70 |
+
|
| 71 |
+
2. Patent and trademark rights are not licensed under this Public License.
|
| 72 |
+
|
| 73 |
+
3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
|
| 74 |
+
|
| 75 |
+
### Section 3 – License Conditions.
|
| 76 |
+
|
| 77 |
+
Your exercise of the Licensed Rights is expressly made subject to the following conditions.
|
| 78 |
+
|
| 79 |
+
a. ___Attribution.___
|
| 80 |
+
|
| 81 |
+
1. If You Share the Licensed Material (including in modified form), You must:
|
| 82 |
+
|
| 83 |
+
A. retain the following if it is supplied by the Licensor with the Licensed Material:
|
| 84 |
+
|
| 85 |
+
i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
|
| 86 |
+
|
| 87 |
+
ii. a copyright notice;
|
| 88 |
+
|
| 89 |
+
iii. a notice that refers to this Public License;
|
| 90 |
+
|
| 91 |
+
iv. a notice that refers to the disclaimer of warranties;
|
| 92 |
+
|
| 93 |
+
v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
|
| 94 |
+
|
| 95 |
+
B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
|
| 96 |
+
|
| 97 |
+
C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
|
| 98 |
+
|
| 99 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
|
| 100 |
+
|
| 101 |
+
3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
|
| 102 |
+
|
| 103 |
+
4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
|
| 104 |
+
|
| 105 |
+
### Section 4 – Sui Generis Database Rights.
|
| 106 |
+
|
| 107 |
+
Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
|
| 108 |
+
|
| 109 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
|
| 110 |
+
|
| 111 |
+
b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
|
| 112 |
+
|
| 113 |
+
c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
|
| 114 |
+
|
| 115 |
+
For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
|
| 116 |
+
|
| 117 |
+
### Section 5 – Disclaimer of Warranties and Limitation of Liability.
|
| 118 |
+
|
| 119 |
+
a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
|
| 120 |
+
|
| 121 |
+
b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
|
| 122 |
+
|
| 123 |
+
c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
|
| 124 |
+
|
| 125 |
+
### Section 6 – Term and Termination.
|
| 126 |
+
|
| 127 |
+
a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
|
| 128 |
+
|
| 129 |
+
b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
|
| 130 |
+
|
| 131 |
+
1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
|
| 132 |
+
|
| 133 |
+
2. upon express reinstatement by the Licensor.
|
| 134 |
+
|
| 135 |
+
For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
|
| 136 |
+
|
| 137 |
+
c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
|
| 138 |
+
|
| 139 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
|
| 140 |
+
|
| 141 |
+
### Section 7 – Other Terms and Conditions.
|
| 142 |
+
|
| 143 |
+
a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
|
| 144 |
+
|
| 145 |
+
b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
|
| 146 |
+
|
| 147 |
+
### Section 8 – Interpretation.
|
| 148 |
+
|
| 149 |
+
a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
|
| 150 |
+
|
| 151 |
+
b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
|
| 152 |
+
|
| 153 |
+
c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
|
| 154 |
+
|
| 155 |
+
d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
|
| 156 |
+
|
| 157 |
+
> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
|
| 158 |
+
>
|
| 159 |
+
> Creative Commons may be contacted at creativecommons.org
|
README.md
CHANGED
|
@@ -1,3 +1,206 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## General
|
| 2 |
+
This repository contains the code for **LigUnity**: **Hierarchical affinity landscape navigation through learning a shared pocket-ligand space.**
|
| 3 |
+
|
| 4 |
+
**We are excited to announce that our paper has been accepted by Patterns and is featured as the cover article for the October 2025 issue!**
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
[](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE)
|
| 8 |
+
[](https://github.com/tatsu-lab/stanford_alpaca/blob/main/DATA_LICENSE)
|
| 9 |
+
[](https://doi.org/10.1016/j.patter.2025.101371)
|
| 10 |
+
[](https://github.com/IDEA-XL/LigUnity)
|
| 11 |
+
|
| 12 |
+
<table>
|
| 13 |
+
<tr>
|
| 14 |
+
<td width="250px" valign="top">
|
| 15 |
+
<a href="https://www.cell.com/patterns/fulltext/S2666-3899(25)00219-3">
|
| 16 |
+
<img src="https://github.com/user-attachments/assets/5ab7f659-0b56-4cf1-8db0-7129d71ea9d5" alt="LigUnity Patterns Cover Image" width="230px" />
|
| 17 |
+
</a>
|
| 18 |
+
</td>
|
| 19 |
+
<td valign="top">
|
| 20 |
+
<p>
|
| 21 |
+
<strong>On the cover:</strong> This ocean symbolizes the human proteome—the complete set of proteins that carry out essential functions in our bodies. For medicine to work, it often needs to interact with a specific protein. For an estimated 90% of these proteins, however, they lack known small-molecule ligands with high activity. In the image, these proteins are represented as sailboats drifting in the dark.
|
| 22 |
+
</p>
|
| 23 |
+
<p>
|
| 24 |
+
At the center, stands a lighthouse symbolizing the AI method <strong>LigUnity</strong>. Its beam illuminates several sailboats, guiding them toward glowing buoys, which symbolize ligands with high activity found by LigUnity. The work by Feng et al. highlights the power of AI-driven computational methods to efficiently find active ligands and optimize their activity, opening up new therapeutic avenues for various diseases.
|
| 25 |
+
</p>
|
| 26 |
+
</td>
|
| 27 |
+
</tr>
|
| 28 |
+
</table>
|
| 29 |
+
|
| 30 |
+
## Instruction on running our model
|
| 31 |
+
|
| 32 |
+
### Virtual Screening
|
| 33 |
+
Colab demo for virtual screening with given protein pocket and candidate ligands.
|
| 34 |
+
|
| 35 |
+
https://colab.research.google.com/drive/1F0QSPjkKKLAfBexmIQotcs-jm87ohHeG?usp=sharing
|
| 36 |
+
|
| 37 |
+
### Hit-to-lead optimization
|
| 38 |
+
**Direct inference**
|
| 39 |
+
Colab demo for code inference with given protein and unmeasured ligands.
|
| 40 |
+
|
| 41 |
+
https://colab.research.google.com/drive/11Fx6mO51rRkPvq71qupuUmscfBw8Dw5R?usp=sharing
|
| 42 |
+
|
| 43 |
+
**Few-shot fine-tuning**
|
| 44 |
+
Colab demo for few-shot fine-tuning with a given protein, a few measured ligands for fine-tuning, and unmeasured ligands for testing.
|
| 45 |
+
|
| 46 |
+
https://colab.research.google.com/drive/1gf0HhgyqI4qBjUAUICCvDa-FnTaARmR_?usp=sharing
|
| 47 |
+
|
| 48 |
+
Please feel free to contact me by email if there is any problem with the code or paper: fengbin@idea.edu.cn.
|
| 49 |
+
|
| 50 |
+
### Resource availability
|
| 51 |
+
|
| 52 |
+
The datasets for LigUnity were collected from ChEMBL version 34 and BindingDB version 2024m5. Our training dataset is available on figshare (https://doi.org/10.6084/m9.figshare.27966819). Our PocketAffDB with protein and pocket PDB structures is available on figshare (https://doi.org/10.6084/m9.figshare.29379161).
|
| 53 |
+
|
| 54 |
+
## Abstract
|
| 55 |
+
|
| 56 |
+
Protein-ligand binding affinity plays an important role in drug discovery, especially during virtual screening and hit-to-lead optimization. Computational chemistry and machine learning methods have been developed to investigate these tasks. Despite the encouraging performance, virtual screening and hit-to-lead optimization are often studied separately by existing methods, partially because they are performed sequentially in the existing drug discovery pipeline, thereby overlooking their interdependency and complementarity. To address this problem, we propose LigUnity, a foundation model for protein-ligand binding prediction by jointly optimizing virtual screening and hit-to-lead optimization.
|
| 57 |
+
In particular, LigUnity learns coarse-grained active/inactive distinction for virtual screening, and fine-grained pocket-specific ligand preference for hit-to-lead optimization.
|
| 58 |
+
We demonstrate the effectiveness and versatility of LigUnity on eight benchmarks across virtual screening and hit-to-lead optimization. In virtual screening, LigUnity outperforms 24 competing methods with more than 50% improvement on the DUD-E and Dekois 2.0 benchmarks, and shows robust generalization to novel proteins. In hit-to-lead optimization, LigUnity achieves the best performance on split-by-time, split-by-scaffold, and split-by-unit settings, further demonstrating its potential as a cost-effective alternative to free energy perturbation (FEP) calculations. We further showcase how LigUnity can be employed in an active learning framework to efficiently identify active ligands for TYK2, a therapeutic target for autoimmune diseases, yielding over 40% improved prediction performance. Collectively, these comprehensive results establish LigUnity as a versatile foundation model for both virtual screening and hit-to-lead optimization, offering broad applicability across the drug discovery pipeline through accurate protein-ligand affinity predictions.
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
## Reproduce results in our paper
|
| 63 |
+
|
| 64 |
+
### Reproduce results on virtual screening benchmarks
|
| 65 |
+
|
| 66 |
+
Please first download checkpoints and processed dataset before running
|
| 67 |
+
- Download our processed Dekois 2.0 dataset from https://doi.org/10.6084/m9.figshare.27967422
|
| 68 |
+
- Download LIT-PCBA and DUD-E datasets from https://drive.google.com/drive/folders/1zW1MGpgunynFxTKXC2Q4RgWxZmg6CInV?usp=sharing
|
| 69 |
+
- Clone model checkpoint from https://huggingface.co/fengb/LigUnity_VS (test proteins in DUD-E, Dekois, and LIT-PCBA are removed from the training set)
|
| 70 |
+
- Clone dataset from https://figshare.com/articles/dataset/LigUnity_project_data/27966819 and unzip them all (you can ignore .lmdb file if you only want to reproduce test result).
|
| 71 |
+
|
| 72 |
+
```
|
| 73 |
+
# run pocket/protein and ligand encoder model
|
| 74 |
+
path2weight="absolute path to the checkpoint of pocket_ranking"
|
| 75 |
+
CUDA_VISIBLE_DEVICES=0 bash test.sh ALL pocket_ranking ${path2weight} "./result/pocket_ranking"
|
| 76 |
+
CUDA_VISIBLE_DEVICES=0 bash test.sh BDB pocket_ranking ${path2weight} "./result/pocket_ranking"
|
| 77 |
+
CUDA_VISIBLE_DEVICES=0 bash test.sh PDB pocket_ranking ${path2weight} "./result/pocket_ranking"
|
| 78 |
+
|
| 79 |
+
path2weight="absolute path to the checkpoint of protein_ranking"
|
| 80 |
+
CUDA_VISIBLE_DEVICES=0 bash test.sh ALL protein_ranking ${path2weight} "./result/protein_ranking"
|
| 81 |
+
CUDA_VISIBLE_DEVICES=0 bash test.sh BDB protein_ranking ${path2weight} "./result/protein_ranking"
|
| 82 |
+
CUDA_VISIBLE_DEVICES=0 bash test.sh PDB protein_ranking ${path2weight} "./result/protein_ranking"
|
| 83 |
+
|
| 84 |
+
# train H-GNN model
|
| 85 |
+
cd ./HGNN
|
| 86 |
+
path2weight_HGNN="absolute path to the checkpoint of HGNN pocket"
|
| 87 |
+
python main.py --data_root ${path2data} --result_root "../result/pocket_ranking" --test_ckpt ${path2weight_HGNN}
|
| 88 |
+
path2weight_HGNN="absolute path to the checkpoint of HGNN protein"
|
| 89 |
+
python main.py --data_root ${path2data} --result_root "../result/protein_ranking" --test_ckpt ${path2weight_HGNN}
|
| 90 |
+
|
| 91 |
+
# get final prediction of our model
|
| 92 |
+
python ensemble_result.py DUDE PCBA DEKOIS
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
### Reproduce results on FEP benchmarks (zero-shot)
|
| 97 |
+
|
| 98 |
+
Please first download checkpoints before running
|
| 99 |
+
- Clone model checkpoint from https://huggingface.co/fengb/LigUnity_pocket_ranking and https://huggingface.co/fengb/LigUnity_protein_ranking (test ligands and assays in FEP benchmarks are removed from the training set)
|
| 100 |
+
|
| 101 |
+
```
|
| 102 |
+
# run pocket/protein and ligand encoder model
|
| 103 |
+
for r in {1..6}; do
|
| 104 |
+
path2weight="path to checkpoint of pocket_ranking"
|
| 105 |
+
path2result="./result/pocket_ranking/FEP/repeat_${r}"
|
| 106 |
+
CUDA_VISIBLE_DEVICES=0 bash test.sh FEP pocket_ranking ${path2weight} ${path2result}
|
| 107 |
+
|
| 108 |
+
path2weight="path to checkpoint of protein_ranking"
|
| 109 |
+
path2result="./result/protein_ranking/FEP/repeat_${r}"
|
| 110 |
+
CUDA_VISIBLE_DEVICES=0 bash test.sh FEP protein_ranking ${path2weight} ${path2result}
|
| 111 |
+
done
|
| 112 |
+
|
| 113 |
+
# get final prediction of our model
|
| 114 |
+
python ensemble_result.py FEP
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
### Reproduce results on FEP benchmarks (few-shot)
|
| 118 |
+
```
|
| 119 |
+
# use the same checkpoints as in zero-shot
|
| 120 |
+
# run few-shot fine-tuning
|
| 121 |
+
for r in {1..6}; do
|
| 122 |
+
path2weight="path to checkpoint of pocket_ranking"
|
| 123 |
+
path2result="./result/pocket_ranking/FEP_fewshot/repeat_${r}"
|
| 124 |
+
support_num=0.6
|
| 125 |
+
CUDA_VISIBLE_DEVICES=0 bash test_fewshot.sh FEP pocket_ranking ${support_num} ${path2weight} ${path2result}
|
| 126 |
+
|
| 127 |
+
path2weight="path to checkpoint of protein_ranking"
|
| 128 |
+
path2result="./result/protein_ranking/FEP_fewshot/repeat_${r}"
|
| 129 |
+
CUDA_VISIBLE_DEVICES=0 bash test_fewshot.sh FEP protein_ranking ${support_num} ${path2weight} ${path2result}
|
| 130 |
+
done
|
| 131 |
+
|
| 132 |
+
# get final prediction of our model
|
| 133 |
+
python ensemble_result_fewshot.py FEP_fewshot ${support_num}
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
### Reproduce results on active learning
|
| 137 |
+
to speed up the active learning process, you should modify the unicore code
|
| 138 |
+
1. find the installed dir of unicore (root-to-unicore)
|
| 139 |
+
```
|
| 140 |
+
python -c "import unicore; print('/'.join(unicore.__file__.split('/')[:-2]))"
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
2. goto root-to-unicore/unicore/options.py line 250, add following line
|
| 144 |
+
```
|
| 145 |
+
group.add_argument('--validate-begin-epoch', type=int, default=0, metavar='N',
|
| 146 |
+
help='validate begin epoch')
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
3. goto root-to-unicore/unicore_cli/train.py line 303, add one line
|
| 150 |
+
```
|
| 151 |
+
do_validate = (
|
| 152 |
+
(not end_of_epoch and do_save)
|
| 153 |
+
or (
|
| 154 |
+
end_of_epoch
|
| 155 |
+
and epoch_itr.epoch >= args.validate_begin_epoch # !!!! add this line
|
| 156 |
+
and epoch_itr.epoch % args.validate_interval == 0
|
| 157 |
+
and not args.no_epoch_checkpoints
|
| 158 |
+
)
|
| 159 |
+
or should_stop
|
| 160 |
+
or (
|
| 161 |
+
args.validate_interval_updates > 0
|
| 162 |
+
and num_updates > 0
|
| 163 |
+
and num_updates % args.validate_interval_updates == 0
|
| 164 |
+
)
|
| 165 |
+
) and not args.disable_validation
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
4. run the active learning procedure
|
| 169 |
+
```
|
| 170 |
+
# use the same checkpoints as in FEP experiments
|
| 171 |
+
path1="path to checkpoint of pocket_ranking"
|
| 172 |
+
path2="path to checkpoint of protein_ranking"
|
| 173 |
+
result1="./result/pocket_ranking/TYK2"
|
| 174 |
+
result2="./result/protein_ranking/TYK2"
|
| 175 |
+
|
| 176 |
+
# run active learning cycle for 5 iters with pure greedy strategy
|
| 177 |
+
bash ./active_learning_scripts/run_al.sh 5 0 ${path1} ${path2} ${result1} ${result2}
|
| 178 |
+
```
|
| 179 |
+
## Citation
|
| 180 |
+
|
| 181 |
+
```
|
| 182 |
+
@article{feng2025hierarchical,
|
| 183 |
+
title={Hierarchical affinity landscape navigation through learning a shared pocket-ligand space},
|
| 184 |
+
author={Feng, Bin and Liu, Zijing and Li, Hao and Yang, Mingjun and Zou, Junjie and Cao, He and Li, Yu and Zhang, Lei and Wang, Sheng},
|
| 185 |
+
journal={Patterns},
|
| 186 |
+
year={2025},
|
| 187 |
+
publisher={Elsevier}
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
@article{feng2024bioactivity,
|
| 191 |
+
title={A bioactivity foundation model using pairwise meta-learning},
|
| 192 |
+
author={Feng, Bin and Liu, Zequn and Huang, Nanlan and Xiao, Zhiping and Zhang, Haomiao and Mirzoyan, Srbuhi and Xu, Hanwen and Hao, Jiaran and Xu, Yinghui and Zhang, Ming and others},
|
| 193 |
+
journal={Nature Machine Intelligence},
|
| 194 |
+
volume={6},
|
| 195 |
+
number={8},
|
| 196 |
+
pages={962--974},
|
| 197 |
+
year={2024},
|
| 198 |
+
publisher={Nature Publishing Group UK London}
|
| 199 |
+
}
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
## Acknowledgments
|
| 203 |
+
|
| 204 |
+
This project was built based on Uni-Mol (https://github.com/deepmodeling/Uni-Mol)
|
| 205 |
+
|
| 206 |
+
Parts of our code reference the implementation from DrugCLIP (https://github.com/bowen-gao/DrugCLIP) by bowen-gao
|
active_learning_scripts/run_al.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Active-learning driver: runs the two-model ensemble AL cycle
# (pocket_ranking + protein_ranking) on the TYK2 FEP case study.
#
# Usage:
#   bash run_al.sh <num_cycles> <begin_greedy> <weight_path1> <weight_path2> <result_path1> <result_path2>
num_cycles=${1}
begin_greedy=${2}
weight_path1=${3}
weight_path2=${4}
result_path1=${5}
result_path2=${6}

# NOTE(review): the repository ships run_cycle_ensemble.py (and
# run_cycle_one_model.py); the original pointed at a non-existent
# run_cycle_ours.py. run_cycle_ensemble.py accepts exactly these flags.
python ./active_learning_scripts/run_cycle_ensemble.py \
    --input_file ../PARank_data_curation/case_study/tyk2_fep_label.csv \
    --results_dir_1 ${result_path1} \
    --results_dir_2 ${result_path2} \
    --al_batch_size 100 \
    --num_cycles ${num_cycles} \
    --arch_1 pocket_ranking \
    --arch_2 protein_ranking \
    --weight_path_1 ${weight_path1} \
    --weight_path_2 ${weight_path2} \
    --lr 0.0001 \
    --device 0 \
    --master_port 10071 \
    --base_seed 42 \
    --begin_greedy ${begin_greedy}
|
active_learning_scripts/run_cycle_ensemble.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import subprocess
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import random
|
| 7 |
+
import argparse
|
| 8 |
+
import json
|
| 9 |
+
import subprocess
|
| 10 |
+
from concurrent.futures import ThreadPoolExecutor, wait
|
| 11 |
+
|
| 12 |
+
def parse_arguments():
    """Parse command-line options for the ensemble active-learning cycle.

    Returns:
        argparse.Namespace holding the input/output paths, experiment
        configuration, model configuration, and the base random seed.
    """
    parser = argparse.ArgumentParser(description='Active Learning Cycle for Ligand Prediction')
    add = parser.add_argument

    # Input/Output arguments
    add('--input_file', type=str, required=True,
        help='Input CSV file containing ligand data (e.g., tyk2_fep.csv)')
    add('--results_dir_1', type=str, required=True,
        help='Results directory for first model')
    add('--results_dir_2', type=str, required=True,
        help='Results directory for second model')
    add('--al_batch_size', type=int, required=True,
        help='Number of samples for each active learning batch')

    # Experiment configuration
    add('--num_repeats', type=int, default=5,
        help='Number of repeated experiments (default: 5)')
    add('--num_cycles', type=int, required=True,
        help='Number of active learning cycles')

    # Model configuration
    add('--arch_1', type=str, required=True,
        help='First model architecture')
    add('--arch_2', type=str, required=True,
        help='Second model architecture')
    add('--weight_path_1', type=str, required=True,
        help='Path to first model pretrained weights')
    add('--weight_path_2', type=str, required=True,
        help='Path to second model pretrained weights')
    add('--lr', type=float, default=0.001,
        help='Learning rate (default: 0.001)')
    add('--master_port', type=int, default=29500,
        help='Master port for distributed training (default: 29500)')
    add('--device', type=int, default=0,
        help='Base device to run the models on (default: 0)')
    add('--begin_greedy', type=int, default=0,
        help='iter of begin to be pure greedy, using half greedy before')

    # Random seed
    add('--base_seed', type=int, default=42,
        help='Base random seed (default: 42)')

    return parser.parse_args()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _run(cmd):
|
| 57 |
+
import os
|
| 58 |
+
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 59 |
+
subprocess.run(cmd, check=True, cwd=project_root)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def run_model(arch_1, arch_2, weight_path_1, weight_path_2, results_path_1, results_path_2, result_file, lr,
              master_port, train_ligf, test_ligf, device):
    """Launch both model runs (one per architecture) concurrently.

    Each run is delegated to ``active_learning_scripts/run_model.sh``; the
    second model receives the next master port and the next GPU device so
    the two jobs do not collide. Blocks until both jobs have finished.
    """
    def build_cmd(arch, weight_path, results_path, port, dev):
        # Positional argument order must match run_model.sh.
        return [
            "bash", "./active_learning_scripts/run_model.sh",
            arch,
            weight_path,
            results_path,
            result_file,
            str(lr),
            str(port),
            train_ligf,
            test_ligf,
            str(dev),
        ]

    cmd1 = build_cmd(arch_1, weight_path_1, results_path_1, master_port, device)
    cmd2 = build_cmd(arch_2, weight_path_2, results_path_2, master_port + 1, device + 1)

    # Run the two shell jobs in parallel and wait for both to complete.
    with ThreadPoolExecutor(max_workers=2) as executor:
        pending = [executor.submit(_run, cmd=cmd1), executor.submit(_run, cmd=cmd2)]
        wait(pending)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def read_predictions(results_path, result_file):
    """Load a jsonl result file and return the mean prediction per SMILES.

    The first line holds ``{"tyk2": {"smiles": [...]}}``; every following
    line holds one prediction vector ``{"tyk2": {"pred": [...]}}`` aligned
    with that SMILES list. Vectors are averaged element-wise over lines.

    Returns:
        dict mapping each SMILES string to its mean predicted score (float).
    """
    jsonl_path = os.path.join(results_path, result_file)
    with open(jsonl_path, 'r') as handle:
        header = json.loads(handle.readline().strip())
        smiles_list = header["tyk2"]["smiles"]
        sampled = [json.loads(row.strip())["tyk2"]["pred"] for row in handle]

    # Element-wise mean across all prediction lines.
    averaged = np.mean(np.array(sampled), axis=0)

    return {smile: float(score) for smile, score in zip(smiles_list, averaged)}
|
| 120 |
+
|
| 121 |
+
def prepare_initial_split(input_file, results_dir_1, results_dir_2, al_batch_size, repeat_idx, cycle_idx, base_seed):
    """Create the cycle-0 train/test split and mirror it into both result dirs.

    A reproducible random subset of ``al_batch_size`` ligands becomes the
    initial training pool; every remaining ligand goes to the test pool.
    Identical CSVs are written under both result directories.

    Returns:
        (train_file_1, test_file_1, train_file_2, test_file_2) CSV paths.
    """
    df = pd.read_csv(input_file)

    # Seed depends on the repeat index, so each repeat draws its own split
    # while the whole experiment stays reproducible.
    random.seed(base_seed + repeat_idx)

    all_indices = list(range(len(df)))
    train_indices = random.sample(all_indices, al_batch_size)
    test_indices = [i for i in all_indices if i not in train_indices]

    train_df = df.iloc[train_indices]
    test_df = df.iloc[test_indices]

    def paths_for(results_dir):
        # Both files of one cycle share the repeat/cycle prefix.
        stem = f"repeat_{repeat_idx}_cycle_{cycle_idx}"
        return (os.path.join(results_dir, f"{stem}_train.csv"),
                os.path.join(results_dir, f"{stem}_test.csv"))

    train_file_1, test_file_1 = paths_for(results_dir_1)
    train_file_2, test_file_2 = paths_for(results_dir_2)

    # Write the same split under each model's directory.
    for train_file, test_file in ((train_file_1, test_file_1),
                                  (train_file_2, test_file_2)):
        os.makedirs(os.path.dirname(train_file), exist_ok=True)
        train_df.to_csv(train_file, index=False)
        test_df.to_csv(test_file, index=False)

    return train_file_1, test_file_1, train_file_2, test_file_2
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def read_and_combine_predictions(results_path_1, results_path_2, result_file):
    """Ensemble the two models' jsonl outputs into one score per SMILES.

    Each model's per-line prediction vectors are averaged independently,
    then the two model means are averaged again (equal-weight ensemble).
    The SMILES order is assumed identical in both files and is taken from
    the first model's header line.

    Returns:
        dict mapping each SMILES string to its ensembled score (float).
    """
    def load_file(results_path):
        # First line = header with SMILES; remaining lines = pred vectors.
        path = os.path.join(results_path, result_file)
        with open(path, 'r') as handle:
            header = json.loads(handle.readline().strip())
            rows = [json.loads(line.strip())["tyk2"]["pred"] for line in handle]
        return header, rows

    header_1, rows_1 = load_file(results_path_1)
    smiles_list = header_1["tyk2"]["smiles"]
    _, rows_2 = load_file(results_path_2)

    # Average within each model first, then across the two models.
    mean_predictions = (np.mean(np.array(rows_1), axis=0)
                        + np.mean(np.array(rows_2), axis=0)) / 2

    return {smile: float(score) for smile, score in zip(smiles_list, mean_predictions)}
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def update_splits(results_dir_1, results_dir_2, predictions_1, predictions_2,
                  prev_train_file_1, prev_test_file_1,
                  prev_train_file_2, prev_test_file_2,
                  repeat_idx, cycle_idx, al_batch_size, begin_greedy):
    """Select the next active-learning batch and write the new train/test CSVs.

    The previous test pool is scored with the ensemble (mean) of both
    models' predictions; the top-scoring compounds (pure greedy) or a
    half-greedy / half-random mix (for cycles before ``begin_greedy``) are
    moved into the training pool. Identical files are written under both
    result directories.

    Returns:
        (new_train_file_1, new_test_file_1, new_train_file_2, new_test_file_2)
    """
    # Read previous test files
    test_df_1 = pd.read_csv(prev_test_file_1)
    # NOTE(review): test_df_2 is loaded but never used below — both result
    # directories always hold identical splits, so test_df_1 suffices.
    test_df_2 = pd.read_csv(prev_test_file_2)

    # Add predictions to test_df (ensemble score = mean of both models)
    test_df_1['prediction_1'] = test_df_1['Smiles'].map(predictions_1)
    test_df_1['prediction_2'] = test_df_1['Smiles'].map(predictions_2)
    test_df_1['prediction'] = (test_df_1['prediction_1'] + test_df_1['prediction_2']) / 2

    # Sort by average predictions (high to low)
    test_df_sorted = test_df_1.sort_values('prediction', ascending=False)

    # Read previous train files
    train_df_1 = pd.read_csv(prev_train_file_1)
    # NOTE(review): train_df_2 is also unused (see note above).
    train_df_2 = pd.read_csv(prev_train_file_2)

    # Create new file names for both directories
    new_train_file_1 = os.path.join(results_dir_1, f"repeat_{repeat_idx}_cycle_{cycle_idx}_train.csv")
    new_test_file_1 = os.path.join(results_dir_1, f"repeat_{repeat_idx}_cycle_{cycle_idx}_test.csv")
    new_train_file_2 = os.path.join(results_dir_2, f"repeat_{repeat_idx}_cycle_{cycle_idx}_train.csv")
    new_test_file_2 = os.path.join(results_dir_2, f"repeat_{repeat_idx}_cycle_{cycle_idx}_test.csv")

    # Create directories if they don't exist
    os.makedirs(os.path.dirname(new_train_file_1), exist_ok=True)
    os.makedirs(os.path.dirname(new_train_file_2), exist_ok=True)

    if cycle_idx >= begin_greedy:
        # Take top al_batch_size compounds for training (pure greedy)
        new_train_compounds = test_df_sorted.head(al_batch_size)
        remaining_test_compounds = test_df_sorted.iloc[al_batch_size:]
    else:
        # use half greedy approach: half of the batch from the top of the
        # ranking, the other half sampled uniformly from the remainder.
        # NOTE(review): random.sample here relies on the global RNG state
        # (seeded in prepare_initial_split) — confirm this is intended.
        new_train_compounds_tmp_1 = test_df_sorted.head(al_batch_size//2)
        remaining_test_compounds_tmp = test_df_sorted.iloc[al_batch_size//2:]
        all_indices = list(range(len(remaining_test_compounds_tmp)))

        train_indices = random.sample(all_indices, al_batch_size - al_batch_size//2)
        test_indices = [i for i in all_indices if i not in train_indices]
        remaining_test_compounds = remaining_test_compounds_tmp.iloc[test_indices]
        new_train_compounds_tmp_2 = remaining_test_compounds_tmp.iloc[train_indices]
        new_train_compounds = pd.concat([new_train_compounds_tmp_1, new_train_compounds_tmp_2])

    # Combine with previous training data
    combined_train_df = pd.concat([train_df_1, new_train_compounds])

    # Progress log (printed three times for visibility in long log files).
    # The /100, /200, /500 denominators presumably reflect the fixed size of
    # the TYK2 ligand pool's top-1%/2%/5% tiers — TODO confirm vs input data.
    for _ in range(3):
        print("########################################")
        print("Cycling: ", cycle_idx)
        print("top_1p: {}/100".format(combined_train_df['top_1p'].sum()))
        print("top_2p: {}/200".format(combined_train_df['top_2p'].sum()))
        print("top_5p: {}/500".format(combined_train_df['top_5p'].sum()))

    # Save files for both models (same content, different directories)
    combined_train_df.to_csv(new_train_file_1, index=False)
    remaining_test_compounds.to_csv(new_test_file_1, index=False)
    combined_train_df.to_csv(new_train_file_2, index=False)
    remaining_test_compounds.to_csv(new_test_file_2, index=False)

    return (new_train_file_1, new_test_file_1,
            new_train_file_2, new_test_file_2)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def run_active_learning(args):
    """Run the full ensemble active-learning experiment.

    For each repeat: draw an initial random split, then alternate between
    (a) training/evaluating both models on the current split via
    ``run_model`` and (b) promoting the best-scored test compounds into
    the training pool via ``update_splits``.

    Side effects: deletes and recreates both result directories.
    """
    import shutil  # local import; keeps the module-level import block untouched

    # Start from clean result directories.
    # Fix: replaced os.system(f"rm -rf {path}") which was shell-dependent
    # and unsafe for paths containing spaces or shell metacharacters.
    shutil.rmtree(args.results_dir_1, ignore_errors=True)
    shutil.rmtree(args.results_dir_2, ignore_errors=True)
    os.makedirs(args.results_dir_1, exist_ok=True)
    os.makedirs(args.results_dir_2, exist_ok=True)

    for repeat_idx in range(args.num_repeats):
        print(f"Starting repeat {repeat_idx}")

        # Initial split for this repeat (cycle 0)
        train_file_1, test_file_1, train_file_2, test_file_2 = prepare_initial_split(
            args.input_file,
            args.results_dir_1,
            args.results_dir_2,
            args.al_batch_size,
            repeat_idx,
            0,  # First cycle
            args.base_seed
        )

        for cycle_idx in range(args.num_cycles):
            print(f"Running cycle {cycle_idx} for repeat {repeat_idx}")

            # Remove any stale result file from a previous (partial) run.
            result_file = f"repeat_{repeat_idx}_cycle_{cycle_idx}_results.jsonl"
            for results_dir in (args.results_dir_1, args.results_dir_2):
                stale = os.path.join(results_dir, result_file)
                if os.path.exists(stale):
                    os.remove(stale)

            # Run both models in parallel on the current split.
            run_model(
                arch_1=args.arch_1,
                arch_2=args.arch_2,
                weight_path_1=args.weight_path_1,
                weight_path_2=args.weight_path_2,
                results_path_1=args.results_dir_1,
                results_path_2=args.results_dir_2,
                result_file=result_file,
                lr=args.lr,
                master_port=args.master_port,
                train_ligf=train_file_1,
                test_ligf=test_file_1,
                device=args.device
            )

            # Update splits for next cycle (not needed after the last one).
            if cycle_idx < args.num_cycles - 1:
                # Read predictions from both models separately
                predictions_1 = read_predictions(args.results_dir_1, result_file)
                predictions_2 = read_predictions(args.results_dir_2, result_file)

                # Update splits for both models
                train_file_1, test_file_1, train_file_2, test_file_2 = update_splits(
                    args.results_dir_1,
                    args.results_dir_2,
                    predictions_1,
                    predictions_2,
                    train_file_1,
                    test_file_1,
                    train_file_2,
                    test_file_2,
                    repeat_idx,
                    cycle_idx + 1,
                    args.al_batch_size,
                    args.begin_greedy
                )
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
if __name__ == "__main__":
    # Script entry point: parse CLI options, then run the experiment.
    run_active_learning(parse_arguments())
|
active_learning_scripts/run_cycle_one_model.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import subprocess
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import random
|
| 7 |
+
import argparse
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def parse_arguments():
    """Build and parse the CLI for the single-model active-learning driver."""
    parser = argparse.ArgumentParser(description='Active Learning Cycle for Ligand Prediction')
    add = parser.add_argument

    # Input/Output
    add('--input_file', type=str, required=True,
        help='Input CSV file containing ligand data (e.g., tyk2_fep.csv)')
    add('--results_dir', type=str, required=True,
        help='Base directory for storing all results')
    add('--al_batch_size', type=int, required=True,
        help='Number of samples for each active learning batch')

    # Experiment configuration
    add('--num_repeats', type=int, default=5,
        help='Number of repeated experiments (default: 5)')
    add('--num_cycles', type=int, required=True,
        help='Number of active learning cycles')

    # Model configuration
    add('--arch', type=str, required=True,
        help='Model architecture')
    add('--weight_path', type=str, required=True,
        help='Path to pretrained model weights')
    add('--lr', type=float, default=0.001,
        help='Learning rate (default: 0.001)')
    add('--master_port', type=int, default=29500,
        help='Master port for distributed training (default: 29500)')
    add('--device', type=int, default=0,
        help='Device to run the model on (default: cuda:0)')
    add('--begin_greedy', type=int, default=0,
        help='iter of begin to be pure greedy, using half greedy before')

    # Reproducibility
    add('--base_seed', type=int, default=42,
        help='Base random seed (default: 42)')

    return parser.parse_args()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def run_model(arch, weight_path, results_path, result_file, lr, master_port, train_ligf, test_ligf, device):
    """Launch one fine-tune/evaluate run by shelling out to run_model.sh.

    Numeric arguments are stringified because subprocess argv entries must
    be strings. Raises subprocess.CalledProcessError if the script fails
    (check=True), which aborts the active-learning loop early.

    Fix: dropped the redundant function-local `import os` — os is already
    imported at module level.
    """
    # Run from the repository root so the script's relative paths resolve.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    cmd = [
        "bash", "./active_learning_scripts/run_model.sh",
        arch,
        weight_path,
        results_path,
        result_file,
        str(lr),
        str(master_port),
        train_ligf,
        test_ligf,
        str(device),
    ]
    # List form with shell=False: arguments are passed verbatim, no shell parsing.
    subprocess.run(cmd, check=True, cwd=project_root)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def prepare_initial_split(input_file, results_dir, al_batch_size, repeat_idx, cycle_idx, base_seed):
    """Randomly split the ligand library into an initial train set and the remaining pool.

    Seeding with base_seed + repeat_idx makes each repeat reproducible while
    keeping the repeats independent of one another. Writes
    repeat_<r>_cycle_<c>_train.csv / _test.csv under results_dir and returns
    the two file paths.
    """
    df = pd.read_csv(input_file)
    random.seed(base_seed + repeat_idx)  # per-repeat deterministic sampling

    pool_indices = list(range(len(df)))
    chosen = random.sample(pool_indices, al_batch_size)
    leftover = [i for i in pool_indices if i not in chosen]

    train_file = os.path.join(results_dir, f"repeat_{repeat_idx}_cycle_{cycle_idx}_train.csv")
    test_file = os.path.join(results_dir, f"repeat_{repeat_idx}_cycle_{cycle_idx}_test.csv")
    os.makedirs(os.path.dirname(train_file), exist_ok=True)

    df.iloc[chosen].to_csv(train_file, index=False)
    df.iloc[leftover].to_csv(test_file, index=False)

    return train_file, test_file
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def read_jsonl_predictions(results_path, result_file, target_key="tyk2"):
    """Average the per-epoch predictions stored in a results .jsonl file.

    File layout: the first line holds {"<target>": {"smiles": [...]}} and
    every following line holds {"<target>": {"pred": [...]}} — one line per
    saved epoch, each aligned with the SMILES list of the first line.

    Args:
        results_path: directory containing the results file.
        result_file: name of the .jsonl file.
        target_key: top-level key in every json line. Defaults to "tyk2",
            matching the value the original implementation hard-coded, so
            existing callers are unaffected.

    Returns:
        dict mapping SMILES string -> mean predicted score across epochs.
    """
    jsonl_path = os.path.join(results_path, result_file)
    all_predictions = []
    with open(jsonl_path, 'r') as f:
        # First line: the SMILES ordering shared by every prediction row.
        smiles_list = json.loads(f.readline().strip())[target_key]["smiles"]
        # Remaining lines: one prediction vector per saved epoch.
        for line in f:
            all_predictions.append(json.loads(line.strip())[target_key]["pred"])

    # Mean over epochs, position-aligned with smiles_list.
    mean_predictions = np.mean(np.array(all_predictions), axis=0)
    return {smile: float(pred) for smile, pred in zip(smiles_list, mean_predictions)}
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def update_splits(results_dir, results_path, result_file, prev_train_file, prev_test_file, repeat_idx, cycle_idx,
                  al_batch_size, begin_greedy):
    """Select the next active-learning batch and write the new train/test CSVs.

    Reads the model's averaged predictions for the current pool, moves the
    top-scoring compounds (greedy) — or half greedy / half random before
    `begin_greedy` — into the training set, and saves the updated splits as
    repeat_<r>_cycle_<c>_train.csv / _test.csv. Returns the two new paths.
    """
    # Average per-epoch predictions for every pool compound (SMILES -> score).
    predictions = read_jsonl_predictions(results_path, result_file)

    # Current pool of unlabeled compounds.
    test_df = pd.read_csv(prev_test_file)

    # Attach each compound's predicted score by SMILES.
    # assumes the CSV has a 'Smiles' column matching the jsonl SMILES — TODO confirm
    test_df['prediction'] = test_df['Smiles'].map(predictions)

    # Rank pool compounds best-first.
    test_df_sorted = test_df.sort_values('prediction', ascending=False)

    # Compounds already acquired in previous cycles.
    train_df = pd.read_csv(prev_train_file)

    # Output paths for the updated splits.
    new_train_file = os.path.join(results_dir, f"repeat_{repeat_idx}_cycle_{cycle_idx}_train.csv")
    new_test_file = os.path.join(results_dir, f"repeat_{repeat_idx}_cycle_{cycle_idx}_test.csv")

    os.makedirs(os.path.dirname(new_train_file), exist_ok=True)

    if cycle_idx >= begin_greedy:
        # Pure greedy: acquire the al_batch_size highest-scoring compounds.
        new_train_compounds = test_df_sorted.head(al_batch_size)
        remaining_test_compounds = test_df_sorted.iloc[al_batch_size:]
    else:
        # Half greedy: top half by score, the other half sampled at random
        # from the rest of the pool (exploration in early cycles).
        new_train_compounds_tmp_1 = test_df_sorted.head(al_batch_size//2)
        remaining_test_compounds_tmp = test_df_sorted.iloc[al_batch_size//2:]
        all_indices = list(range(len(remaining_test_compounds_tmp)))

        # NOTE(review): uses the module-level RNG state seeded in
        # prepare_initial_split — sampling stays reproducible per repeat.
        train_indices = random.sample(all_indices, al_batch_size - al_batch_size//2)
        test_indices = [i for i in all_indices if i not in train_indices]
        remaining_test_compounds = remaining_test_compounds_tmp.iloc[test_indices]
        new_train_compounds_tmp_2 = remaining_test_compounds_tmp.iloc[train_indices]
        new_train_compounds = pd.concat([new_train_compounds_tmp_1, new_train_compounds_tmp_2])


    # Everything acquired so far = previous train set + this cycle's batch.
    combined_train_df = pd.concat([train_df, new_train_compounds])

    # Repeat the hit statistics three times so they stand out in noisy
    # (distributed-training) logs.
    # NOTE(review): the /100, /200, /500 denominators assume a 10,000-compound
    # library (1% = 100, 2% = 200, 5% = 500) — confirm for other libraries.
    for _ in range(3):
        print("########################################")
        print("Cycling: ", cycle_idx)
        print("top_1p: {}/100".format(combined_train_df['top_1p'].sum()))
        print("top_2p: {}/200".format(combined_train_df['top_2p'].sum()))
        print("top_5p: {}/500".format(combined_train_df['top_5p'].sum()))

    # Persist the updated splits for the next cycle.
    combined_train_df.to_csv(new_train_file, index=False)
    remaining_test_compounds.to_csv(new_test_file, index=False)

    return new_train_file, new_test_file
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def run_active_learning(args):
    """Drive the experiment: num_repeats independent runs of num_cycles each.

    Each repeat starts from a fresh random split; each cycle trains/evaluates
    the model and then moves the best-scoring pool compounds into the train
    set for the next cycle.
    """
    import shutil
    # Start from a clean results directory. shutil.rmtree replaces the
    # original `os.system(f"rm -rf {...}")`: no shell is spawned, so paths
    # containing spaces or shell metacharacters cannot break (or widen)
    # the delete.
    if os.path.exists(args.results_dir):
        shutil.rmtree(args.results_dir)
    os.makedirs(args.results_dir, exist_ok=True)

    for repeat_idx in range(args.num_repeats):
        print(f"Starting repeat {repeat_idx}")

        # Initial random train/pool split for this repeat (cycle 0).
        train_file, test_file = prepare_initial_split(
            args.input_file,
            args.results_dir,
            args.al_batch_size,
            repeat_idx,
            0,  # first cycle
            args.base_seed
        )

        for cycle_idx in range(args.num_cycles):
            print(f"Running cycle {cycle_idx} for repeat {repeat_idx}")

            results_path = args.results_dir

            # The training script appends to this file; remove any stale copy.
            result_file = f"repeat_{repeat_idx}_cycle_{cycle_idx}_results.jsonl"
            stale = os.path.join(args.results_dir, result_file)
            if os.path.exists(stale):
                os.remove(stale)

            # Train + evaluate; predictions land in result_file.
            run_model(
                arch=args.arch,
                weight_path=args.weight_path,
                results_path=results_path,
                result_file=result_file,
                lr=args.lr,
                master_port=args.master_port,
                train_ligf=train_file,
                test_ligf=test_file,
                device=args.device
            )

            # Acquire the next batch for the following cycle (skip after the
            # last cycle — nothing left to train on).
            if cycle_idx < args.num_cycles - 1:
                train_file, test_file = update_splits(
                    args.results_dir,
                    results_path,
                    result_file,
                    train_file,
                    test_file,
                    repeat_idx,
                    cycle_idx + 1,
                    args.al_batch_size,
                    args.begin_greedy
                )
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
if __name__ == "__main__":
    # CLI entry point: parse arguments and launch the active-learning loop.
    args = parse_arguments()
    run_active_learning(args)
|
active_learning_scripts/run_model.sh
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Launch one fine-tuning + evaluation run of the ranking model via unicore-train.
# Invoked by the active-learning drivers as:
#   bash run_model.sh <arch> <weight_path> <results_path> <result_file> <lr> <master_port> <train_csv> <test_csv> <device>
data_path="./test_datasets"

n_gpu=1

batch_size=1
batch_size_valid=1
epoch=20
update_freq=1
#lr=1e-3
#MASTER_PORT=10075
#arch=pocket_ranking

export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1

arch=${1} # model architecture
weight_path=${2} # path for pretrained model
results_path=${3} # directory results are written into
result_file=${4} # jsonl file name for per-epoch predictions
lr=${5} # learning rate
MASTER_PORT=${6} # master port for torch.distributed
train_ligf=${7} # !! input path for training ligands file (.csv format)
test_ligf=${8} # !! input path for test ligands file (.csv format)
device=${9} # cuda device

# Regression architectures train with MSE; everything else with softmax rank loss.
if [[ "$arch" == "pocketregression" ]] || [[ "$arch" == "DTA" ]]; then
    loss="mseloss"
else
    loss="rank_softmax"
fi


CUDA_VISIBLE_DEVICES=${device} python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \
    --results-path $results_path \
    --num-workers 8 --ddp-backend=c10d \
    --task train_task --loss ${loss} --arch $arch \
    --max-pocket-atoms 256 \
    --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-8 --clip-norm 1.0 \
    --lr-scheduler polynomial_decay --lr $lr --max-epoch $epoch --batch-size $batch_size --batch-size-valid $batch_size_valid \
    --update-freq $update_freq --seed 1 \
    --log-interval 1 --log-format simple \
    --validate-interval 1 --validate-begin-epoch 15 \
    --best-checkpoint-metric valid_mean_r2 --patience 100 --all-gather-list-size 2048000 \
    --no-save --save-dir $results_path --tmp-save-dir $results_path \
    --find-unused-parameters \
    --maximize-best-checkpoint-metric \
    --valid-set TYK2 \
    --max-lignum 512 --test-max-lignum 10000 \
    --restore-model $weight_path --few-shot true \
    --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \
    --active-learning-resfile ${result_file} \
    --case-train-ligfile ${train_ligf} --case-test-ligfile ${test_ligf}
|
| 53 |
+
|
ensemble_result.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import copy
|
| 5 |
+
import numpy as np
|
| 6 |
+
import scipy.stats as stats
|
| 7 |
+
import math
|
| 8 |
+
from rdkit.ML.Scoring.Scoring import CalcBEDROC, CalcAUC, CalcEnrichment
|
| 9 |
+
|
| 10 |
+
def cal_metrics(y_score, y_true):
    """Compute virtual-screening metrics for one target.

    Args:
        y_score: predicted scores, higher = more likely active.
        y_true: binary activity labels (1 = active).

    Returns:
        dict with BEDROC (alpha=80.5, emphasizes early recognition), AUROC,
        and enrichment factors at 0.5%, 1% and 5%.

    Fix: removed the dead top-0.5% `count` loop — the counter was computed
    but never used or returned.
    """
    # Pair each score with its label and sort best-first: the rdkit scoring
    # helpers expect ranked rows with the label in the given column (index 1).
    scores = np.expand_dims(y_score, axis=1)
    y_true = np.expand_dims(y_true, axis=1)
    scores = np.concatenate((scores, y_true), axis=1)
    scores = scores[scores[:, 0].argsort()[::-1]]
    bedroc = CalcBEDROC(scores, 1, 80.5)
    auc = CalcAUC(scores, 1)
    # EF at 2% (index 2) is computed but intentionally not returned,
    # mirroring the original result dict.
    ef_list = CalcEnrichment(scores, 1, [0.005, 0.01, 0.02, 0.05])

    return {
        "BEDROC": bedroc,
        "AUROC": auc,
        "EF0.5": ef_list[0],
        "EF1": ef_list[1],
        "EF5": ef_list[3]
    }
|
| 34 |
+
|
| 35 |
+
def print_avg_metric(metric_dict, name):
    """Average each metric across all entries of metric_dict and print `name <averages>`."""
    entries = list(metric_dict.values())
    count = len(entries)
    averaged = copy.deepcopy(entries[0])
    for entry in entries[1:]:
        for key in entry:
            averaged[key] += entry[key]
    for key in averaged:
        averaged[key] /= count
    print(name, averaged)
|
| 45 |
+
|
| 46 |
+
def read_zeroshot_res(res_dir):
    """Load per-target zero-shot results from res_dir.

    Each target subdirectory must contain saved_labels.npy; predictions come
    from saved_preds.npy when present, otherwise they are reconstructed as
    the maximum pocket-vs-molecule similarity from the saved embeddings.
    """
    res_dict = {}
    for target in sorted(os.listdir(res_dir)):
        target_dir = f"{res_dir}/{target}"
        labels = np.load(f"{target_dir}/saved_labels.npy")
        preds_path = f"{target_dir}/saved_preds.npy"
        if os.path.exists(preds_path):
            preds = np.load(preds_path)
        else:
            mol_reps = np.load(f"{target_dir}/saved_mols_embed.npy")
            pocket_reps = np.load(f"{target_dir}/saved_target_embed.npy")
            # Best score over all pocket embeddings for each molecule.
            preds = (pocket_reps @ mol_reps.T).max(axis=0)
        res_dict[target] = {"pred": preds, "exp": labels}
    return res_dict
|
| 63 |
+
|
| 64 |
+
def get_ensemble_res(res_list, begin=0, end=-1):
    """Average the per-target "pred" arrays of res_list[begin:end].

    "exp" (and any other fields) are taken from res_list[begin] unchanged.
    end == -1 means "through the end of the list".
    """
    if end == -1:
        end = len(res_list)
    merged = copy.deepcopy(res_list[begin])
    for other in res_list[begin + 1:end]:
        for target in merged:
            merged[target]["pred"] = np.array(merged[target]["pred"]) + np.array(other[target]["pred"])

    n_members = end - begin
    for target in merged:
        merged[target]["pred"] = np.array(merged[target]["pred"]) / n_members

    return merged
|
| 76 |
+
|
| 77 |
+
def avg_metric(metric_lst_all):
    """Average pearsonr/spearmanr/r2 over repeats for each target.

    metric_lst_all: list of per-target lists; each inner list holds one
    metric dict per repeat (same "target" field throughout).
    Returns {target: averaged metric dict}.
    """
    averaged_by_target = {}
    for per_target in metric_lst_all:
        acc = copy.deepcopy(per_target[0])
        for entry in per_target[1:]:
            for key in ("pearsonr", "spearmanr", "r2"):
                acc[key] += entry[key]
        for key in ("spearmanr", "pearsonr", "r2"):
            acc[key] /= len(per_target)
        averaged_by_target[acc["target"]] = acc
    return averaged_by_target
|
| 88 |
+
|
| 89 |
+
def get_metric(res):
    """Compute per-target correlation metrics from {target: {"pred", "exp"}}.

    Returns {target: {pearsonr, spearmanr, r2, target}} where
    r2 = max(pearsonr, 0)**2, i.e. negative correlations score zero.
    """
    metric_dict = {}
    for target in sorted(res.keys()):
        pred = res[target]["pred"]
        exp = res[target]["exp"]
        rho = stats.spearmanr(exp, pred).statistic
        r = stats.pearsonr(exp, pred).statistic
        # Degenerate inputs (e.g. constant arrays) produce NaN correlations;
        # treat those as zero, as the original did.
        r = 0 if math.isnan(r) else r
        rho = 0 if math.isnan(rho) else rho
        metric_dict[target] = {
            "pearsonr": r,
            "spearmanr": rho,
            "r2": max(r, 0) ** 2,
            "target": target,
        }
    return metric_dict
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
if __name__ == '__main__':
    # Usage:
    #   python ensemble_result.py zeroshot <test_set> [<test_set> ...]
    #   python ensemble_result.py fewshot <test_set> <support_num>
    mode = sys.argv[1]
    if mode == "zeroshot":
        test_sets = sys.argv[2:]
        for test_set in test_sets:
            if test_set in ["DUDE", "PCBA", "DEKOIS"]:
                # Screening benchmarks: fuse pocket- and protein-ranking
                # scores per target and compute BEDROC/AUROC/EF metrics.
                metrics = {}
                target_id_list = sorted(list(os.listdir(f"./result/pocket_ranking/{test_set}")))
                for target_id in target_id_list:
                    lig_act = np.load(f"./result/pocket_ranking/{test_set}/{target_id}/saved_labels.npy")
                    score_1 = np.load(f"./result/pocket_ranking/{test_set}/{target_id}/GNN_res_epoch9.npy")
                    score_2 = np.load(f"./result/protein_ranking/{test_set}/{target_id}/GNN_res_epoch9.npy")

                    # Simple two-model score ensemble.
                    score = (score_1 + score_2)/2
                    metrics[target_id] = cal_metrics(score, lig_act)

                # NOTE(review): open() handles passed to json.dump are never
                # closed explicitly (relies on CPython refcounting).
                json.dump(metrics, open(f"./result/pocket_ranking/{test_set}_metrics.json", "w"))
                print_avg_metric(metrics, "Ours")
            elif test_set in ["FEP"]:
                # NOTE(review): target_id_list is computed but never used in
                # this branch — read_zeroshot_res lists the targets itself.
                target_id_list = sorted(list(os.listdir(f"./result/pocket_ranking/{test_set}")))
                res_all_pocket, res_all_protein = [], []
                # Five training repeats of each model; ensemble all ten.
                for repeat in range(1, 6):
                    res_pocket = read_zeroshot_res(f"./result/pocket_ranking/{test_set}/repeat_{repeat}")
                    res_protein = read_zeroshot_res(f"./result/protein_ranking/{test_set}/repeat_{repeat}")
                    res_all_pocket.append(res_pocket)
                    res_all_protein.append(res_protein)
                res_all_fusion = get_ensemble_res(res_all_pocket + res_all_protein)
                metrics = get_metric(res_all_fusion)
                json.dump(metrics, open(f"./result/pocket_ranking/{test_set}_metrics.json", "w"))
                print_avg_metric(metrics, "Ours")
    elif mode == "fewshot":
        test_set = sys.argv[2]
        support_num = sys.argv[3]
        # Ensemble predictions saved for epochs [begin, end) of each run.
        begin = 15
        end = 20
        metric_fusion_all = []
        # Ten random support-set seeds.
        for seed in range(1, 11):
            res_repeat_pocket = []
            res_repeat_seq = []

            if test_set in ["TIME", "OOD"]:
                # NOTE(review): res_file_seq points at the same
                # pocket_ranking file as res_file_pocket; the zeroshot branch
                # pairs pocket_ranking with protein_ranking — confirm whether
                # this path should read protein_ranking instead.
                res_file_pocket = f"./result/pocket_ranking/{test_set}/random_{seed}_sup{support_num}.jsonl"
                res_file_seq = f"./result/pocket_ranking/{test_set}/random_{seed}_sup{support_num}.jsonl"
                if not os.path.exists(res_file_pocket):
                    continue
                # First jsonl line is metadata (skipped with [1:]).
                res_repeat_pocket = [json.loads(line) for line in open(res_file_pocket)][1:]
                res_repeat_seq = [json.loads(line) for line in open(res_file_seq)][1:]
            elif test_set in ["FEP_fewshot"]:
                for repeat in range(1, 6):
                    # NOTE(review): same pocket/seq path duplication as above.
                    res_file_pocket = f"./result/pocket_ranking/{test_set}/repeat_{repeat}/random_{seed}_sup{support_num}.jsonl"
                    res_file_seq = f"./result/pocket_ranking/{test_set}/repeat_{repeat}/random_{seed}_sup{support_num}.jsonl"
                    if not os.path.exists(res_file_pocket):
                        continue
                    res_pocket = [json.loads(line) for line in open(res_file_pocket)][1:]
                    res_seq = [json.loads(line) for line in open(res_file_seq)][1:]
                    # Average the chosen late-epoch predictions per run.
                    res_pocket = get_ensemble_res(res_pocket, begin, end)
                    res_seq = get_ensemble_res(res_seq, begin, end)
                    res_repeat_pocket.append(res_pocket)
                    res_repeat_seq.append(res_seq)

            # Fuse pocket and sequence results across repeats for this seed.
            res_repeat_fusion = get_ensemble_res(res_repeat_pocket + res_repeat_seq)
            metric_fusion_all.append(get_metric(res_repeat_fusion))
        # Transpose seed-major metrics into target-major lists, then average.
        metric_fusion_all = avg_metric(list(map(list, zip(*metric_fusion_all))))
        json.dump(metric_fusion_all, open(f"./result/pocket_ranking/{test_set}_metrics.json", "w"))
        print_avg_metric(metric_fusion_all, "Ours")
|
py_scripts/__init__.py
ADDED
|
File without changes
|
py_scripts/write_case_study.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import gzip
|
| 3 |
+
import json
|
| 4 |
+
import multiprocessing as mp
|
| 5 |
+
import os
|
| 6 |
+
import pickle
|
| 7 |
+
import random
|
| 8 |
+
|
| 9 |
+
import lmdb
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import rdkit
|
| 13 |
+
import rdkit.Chem.AllChem as AllChem
|
| 14 |
+
import torch
|
| 15 |
+
import tqdm
|
| 16 |
+
from biopandas.mol2 import PandasMol2
|
| 17 |
+
from biopandas.pdb import PandasPdb
|
| 18 |
+
from rdkit import Chem, RDLogger
|
| 19 |
+
from rdkit.Chem.MolStandardize import rdMolStandardize
|
| 20 |
+
|
| 21 |
+
RDLogger.DisableLog('rdApp.*')
|
| 22 |
+
|
| 23 |
+
def gen_conformation(mol, num_conf=20, num_worker=8):
    """Embed up to num_conf 3D conformers for mol and MMFF-optimize them.

    Returns the molecule with hydrogens removed, or None when embedding
    fails or yields no conformer.

    Fix: bare `except:` clauses narrowed to `except Exception:` — a bare
    except also swallows KeyboardInterrupt/SystemExit.
    """
    try:
        mol = Chem.AddHs(mol)
        AllChem.EmbedMultipleConfs(mol, numConfs=num_conf, numThreads=num_worker, pruneRmsThresh=1, maxAttempts=10000, useRandomCoords=False)
        try:
            AllChem.MMFFOptimizeMoleculeConfs(mol, numThreads=num_worker)
        except Exception:
            # Optimization is best-effort; unoptimized conformers are usable.
            pass
        mol = Chem.RemoveHs(mol)
    except Exception:
        print("cannot gen conf", Chem.MolToSmiles(mol))
        return None
    if mol.GetNumConformers() == 0:
        print("cannot gen conf", Chem.MolToSmiles(mol))
        return None
    return mol
|
| 39 |
+
|
| 40 |
+
def convert_2Dmol_to_data(smi, num_conf=1, num_worker=5):
    """Parse a SMILES, embed 3D conformers, and return a coords/atom-types record.

    Returns None when parsing or conformer generation fails.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return None
    mol = gen_conformation(mol, num_conf, num_worker)
    if mol is None:
        return None
    conf_coords = [np.array(mol.GetConformer(idx).GetPositions()) for idx in range(mol.GetNumConformers())]
    symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
    return {'coords': conf_coords, 'atom_types': symbols, 'smi': smi, 'mol': mol}
|
| 51 |
+
|
| 52 |
+
def convert_3Dmol_to_data(mol):
    """Extract per-conformer coordinates and atom symbols from an RDKit mol that already has 3D conformers."""
    if mol is None:
        return None
    n_conf = mol.GetNumConformers()
    conf_coords = [np.array(mol.GetConformer(c).GetPositions()) for c in range(n_conf)]
    symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
    return {'coords': conf_coords, 'atom_types': symbols, 'smi': Chem.MolToSmiles(mol), 'mol': mol}
|
| 59 |
+
|
| 60 |
+
def read_pdb(path):
    """Load a PDB file into a dict of coords, atom types, residue ids and residue types.

    'residue_name' is chain id + residue number (a unique per-residue key);
    'residue_type' is the three-letter amino-acid code.
    """
    atom_df = PandasPdb().read_pdb(path).df['ATOM']
    return {
        'coord': np.array(atom_df[['x_coord', 'y_coord', 'z_coord']]),
        'atom_type': list(atom_df['atom_name']),
        'residue_name': list(atom_df['chain_id'] + atom_df['residue_number'].astype(str)),
        'residue_type': list(atom_df['residue_name']),
    }
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def read_sdf_gz_3d(path):
    """Read a gzipped SDF: fix charges, sanitize, strip Hs, and uncharge each molecule."""
    handle = gzip.open(path)
    uncharger = rdMolStandardize.Uncharger()
    mols = []
    with Chem.ForwardSDMolSupplier(handle, removeHs=False, sanitize=False) as supplier:
        for raw in supplier:
            if raw is None:
                continue
            fixed = add_charges(raw)
            if fixed is None:
                continue
            mols.append(uncharger.uncharge(Chem.RemoveHs(fixed)))
    return mols
|
| 80 |
+
|
| 81 |
+
def add_charges(m):
    """Best-effort sanitization: patch common valence problems, then sanitize.

    Returns the sanitized molecule, or None if it still fails to sanitize
    after the fixes.
    """
    m.UpdatePropertyCache(strict=False)
    ps = Chem.DetectChemistryProblems(m)
    if not ps:
        # Nothing to fix — sanitize and return as-is.
        Chem.SanitizeMol(m)
        return m
    for p in ps:
        if p.GetType()=='AtomValenceException':
            at = m.GetAtomWithIdx(p.GetAtomIdx())
            # Tetravalent neutral nitrogen (quaternary N written without charge): set +1.
            if at.GetAtomicNum()==7 and at.GetFormalCharge()==0 and at.GetExplicitValence()==4:
                at.SetFormalCharge(1)
            # Pentavalent carbon: demote one double bond to a single bond.
            if at.GetAtomicNum()==6 and at.GetExplicitValence()==5:
                #remove a bond
                for b in at.GetBonds():
                    if b.GetBondType()==Chem.rdchem.BondType.DOUBLE:
                        b.SetBondType(Chem.rdchem.BondType.SINGLE)
                        break
            # Trivalent neutral oxygen (oxocarbenium-like): set +1.
            if at.GetAtomicNum()==8 and at.GetFormalCharge()==0 and at.GetExplicitValence()==3:
                at.SetFormalCharge(1)
            # Tetravalent neutral boron (borate-like): set -1.
            if at.GetAtomicNum()==5 and at.GetFormalCharge()==0 and at.GetExplicitValence()==4:
                at.SetFormalCharge(-1)
    try:
        Chem.SanitizeMol(m)
    except:
        # Still broken after the fixes — drop the molecule.
        return None
    return m
|
| 107 |
+
|
| 108 |
+
def get_different_raid(protein, ligand, raid=6):
    """Return the set of residue names with any atom within `raid` Å of any ligand atom.

    Args:
        protein: dict with 'coord' (N,3 array-like) and 'residue_name' (list of N ids).
        ligand: dict with 'coord' (M,3 array-like).
        raid: strict distance cutoff in Å (distance < raid counts as contact).

    Improvement: replaced the O(N*M) pair of Python loops (one
    np.linalg.norm call per atom pair) with a single NumPy broadcast —
    identical result, vastly fewer Python-level iterations.
    """
    protein_coord = np.asarray(protein['coord'], dtype=float)
    ligand_coord = np.asarray(ligand['coord'], dtype=float)
    protein_residue_name = protein['residue_name']
    if len(protein_coord) == 0 or len(ligand_coord) == 0:
        return set()
    # (N, M) pairwise distances in one shot.
    diff = protein_coord[:, None, :] - ligand_coord[None, :, :]
    dist = np.sqrt((diff * diff).sum(axis=-1))
    close = (dist < raid).any(axis=1)
    return {protein_residue_name[i] for i in np.nonzero(close)[0]}
|
| 118 |
+
|
| 119 |
+
def read_mol2_ligand(path):
    """Load a mol2 ligand: coordinates, atom names, and the RDKit molecule."""
    mol2_df = PandasMol2().read_mol2(path).df
    return {
        'coord': np.array(mol2_df[['x', 'y', 'z']]),
        'atom_type': list(mol2_df['atom_name']),
        'mol': Chem.MolFromMol2File(path),
    }
|
| 125 |
+
|
| 126 |
+
def read_smi_mol(path):
    """Parse a .smi file (one "SMILES [name]" per line) into RDKit molecules.

    Fix: split on any whitespace instead of `split(' ')`, which kept tabs
    and trailing newlines attached to the SMILES token when no space was
    present; blank lines are skipped instead of yielding empty molecules.
    Entries whose SMILES fails to parse are still returned as None,
    matching the original behavior callers filter on.
    """
    with open(path, 'r') as f:
        lines = f.readlines()
    smis = [line.split()[0] for line in lines if line.strip()]
    return [Chem.MolFromSmiles(smi) for smi in smis]
|
| 132 |
+
|
| 133 |
+
def parser(protein_path, mol_path, ligand_path, activity, pocket_index, raid=6):
    """Build screening records: pocket atoms near the crystal ligand plus 3D-embedded library molecules.

    Fixes: the multiprocessing pool is now a context manager (the original
    `mp.Pool(32)` was never closed/joined, leaking worker processes), and
    the unused `pocket_residue_type` local was removed.
    """
    protein = read_pdb(protein_path)
    data_mols = read_smi_mol(mol_path)

    ligand = read_mol2_ligand(ligand_path)
    pocket_residue = get_different_raid(protein, ligand, raid=raid)
    pocket_atom_idx = [i for i, r in enumerate(protein['residue_name']) if r in pocket_residue]
    pocket_atom_type = [protein['atom_type'][i] for i in pocket_atom_idx]
    pocket_coord = [protein['coord'][i] for i in pocket_atom_idx]
    # Pocket id is the parent directory name of the protein file.
    pocket_name = protein_path.split('/')[-2]
    data_mols = [m for m in data_mols if m is not None]
    # NOTE(review): convert_2Dmol_to_data expects a SMILES string, but
    # read_smi_mol returns RDKit Mol objects — confirm the intended input
    # type before relying on this path.
    with mp.Pool(32) as pool:
        mols = list(pool.imap_unordered(convert_2Dmol_to_data, data_mols))
    mols = [m for m in mols if m is not None]

    return [{'atoms': m['atom_types'],
             'coordinates': m['coords'],
             'smi': m['smi'],
             'mol': ligand,
             'pocket_name': pocket_name,
             'pocket_index': pocket_index,
             'activity': activity,
             "pocket_atom_type": pocket_atom_type,
             "pocket_coord": pocket_coord} for m in mols]
|
| 159 |
+
|
| 160 |
+
def mol_parser(ligand_smis):
    """Embed a list of SMILES into conformer records for the ligand lmdb.

    Args:
        ligand_smis: iterable of SMILES strings.

    Returns:
        list of dicts (atoms / coordinates / smi / mol / label=1), one per
        molecule convert_2Dmol_to_data could embed; failures (None results)
        are silently dropped.
    """
    # fix: use the pool as a context manager so the 16 workers are
    # terminated instead of being leaked on every call
    with mp.Pool(16) as pool:
        mols = list(pool.imap_unordered(convert_2Dmol_to_data, tqdm.tqdm(ligand_smis)))
    mols = [m for m in mols if m is not None]
    return [{'atoms': m['atom_types'],
             'coordinates': m['coords'],
             'smi': m['smi'],
             'mol': m['mol'],
             # every ligand passed here is treated as an active (label 1)
             'label': 1,
             } for m in mols]
|
| 170 |
+
|
| 171 |
+
def pocket_parser(protein_path, ligand_path, pocket_index, pocket_name, raid=6):
    """Extract the binding pocket of `protein_path` around the crystal
    ligand in `ligand_path` and pack it into a single record dict.

    Args:
        protein_path: protein PDB path.
        ligand_path: crystal ligand .mol2 file used to locate the pocket.
        pocket_index: identifier copied into the record.
        pocket_name: name copied into the record.
        raid: pocket radius cutoff passed to get_different_raid.

    Returns:
        dict with the pocket name/index and per-atom pocket fields.
    """
    protein = read_pdb(protein_path)
    crystal_ligand = read_mol2_ligand(ligand_path)
    contact_residues = get_different_raid(protein, crystal_ligand, raid=raid)

    # indices of every protein atom that belongs to a contact residue
    selected = [idx for idx, res in enumerate(protein['residue_name'])
                if res in contact_residues]

    def _gather(field):
        # project a per-atom protein field onto the selected pocket atoms
        values = protein[field]
        return [values[idx] for idx in selected]

    return {
        'pocket': pocket_name,
        'pocket_index': pocket_index,
        "pocket_atoms": _gather('atom_type'),
        "pocket_coordinates": _gather('coord'),
        "pocket_residue_type": _gather('residue_type'),
        "pocket_residue_name": _gather('residue_name'),
    }
|
| 186 |
+
|
| 187 |
+
def write_lmdb(data, lmdb_path):
    """Serialize `data` into a fresh single-file (subdir=False) lmdb.

    Keys are consecutive ascii indices "0", "1", ...; values are pickled
    items. Any existing file at `lmdb_path` is removed first, so the
    database is always rebuilt from scratch.

    Args:
        data: iterable of picklable records.
        lmdb_path: destination lmdb file path.
    """
    # fix: delete via os.remove instead of shelling out `rm` — portable and
    # immune to paths containing spaces or shell metacharacters
    if os.path.exists(lmdb_path):
        os.remove(lmdb_path)
    env = lmdb.open(lmdb_path, subdir=False, readonly=False, lock=False, readahead=False, meminit=False, map_size=1099511627776)
    try:
        with env.begin(write=True) as txn:
            for num, d in enumerate(data):
                txn.put(str(num).encode('ascii'), pickle.dumps(d))
    finally:
        # fix: the environment was previously never closed
        env.close()
|
| 197 |
+
|
| 198 |
+
import sys
if __name__ == '__main__':
    # CLI entry point. First argument selects the preprocessing mode:
    #   mol    <lig_json> <lig_lmdb>            — build the ligand lmdb
    #   pocket <protein.pdb> <lig.mol2> <lmdb>  — build a one-pocket lmdb
    mode = sys.argv[1]

    if mode == 'mol':
        lig_file = sys.argv[2]       # JSON file holding a list of SMILES
        lig_write_file = sys.argv[3] # output lmdb path

        # read the ligands smiles into a list
        smis = json.load(open(lig_file))
        data = []
        # duplicates are removed before embedding (note: set order is
        # not deterministic across runs, so record order may vary)
        print("number of ligands", len(set(smis)))
        d_active = (mol_parser(list(set(smis))))
        data.extend(d_active)

        # write ligands lmdb
        write_lmdb(data, lig_write_file)
    elif mode == 'pocket':
        prot_file = sys.argv[2]
        crystal_lig_file = sys.argv[3] # must be .mol2 file
        prot_write_file = sys.argv[4]

        # write pocket
        # single demo pocket: fixed index 1 and name "demo"
        data = []
        d = pocket_parser(prot_file, crystal_lig_file, 1, "demo")
        data.append(d)
        write_lmdb(data, prot_write_file)
test.sh
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Zero-shot evaluation on a prepared benchmark.
# Usage: bash test.sh <TASK> <arch> <weight_path> <results_path>
batch_size=256

TASK=${1}          # benchmark name passed to --test-task
arch=${2}          # model architecture name
weight_path=${3}   # checkpoint to evaluate
results_path=${4}  # output directory for predictions
echo "writing to ${results_path}"

mkdir -p $results_path
python ./unimol/test.py "./test_datasets" --user-dir ./unimol --valid-subset test \
       --results-path $results_path \
       --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \
       --task test_task --loss rank_softmax --arch $arch \
       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --seed 1 \
       --path $weight_path \
       --log-interval 100 --log-format simple \
       --max-pocket-atoms 511 \
       --test-task $TASK
|
test_fewshot.sh
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Few-shot evaluation: fine-tune from a pretrained checkpoint on a small
# support set, then validate on the chosen task.
# Usage: bash test_fewshot.sh <TASK> <arch> <sup_num> <weight_path> <results_path>
data_path="./test_datasets"

TASK=${1}          # validation set name
arch=${2}          # model architecture name
sup_num=${3}       # number of support (few-shot) examples
weight_path=${4}   # checkpoint to restore before fine-tuning
results_path=${5}  # output directory

n_gpu=1
batch_size=8
batch_size_valid=16
epoch=10
update_freq=1
lr=1e-4
MASTER_PORT=10092
export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1
seed=1

# --no-save: checkpoints are not kept; only predictions in $results_path
torchrun --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \
       --results-path $results_path \
       --num-workers 8 --ddp-backend=c10d \
       --task train_task --loss rank_softmax --arch $arch \
       --max-pocket-atoms 256 \
       --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-8 --clip-norm 1.0 \
       --lr-scheduler polynomial_decay --lr $lr --max-epoch $epoch --batch-size $batch_size --batch-size-valid $batch_size_valid \
       --update-freq $update_freq --seed $seed \
       --log-interval 1 --log-format simple \
       --validate-interval 1 \
       --best-checkpoint-metric valid_mean_r2 --patience 100 --all-gather-list-size 2048000 \
       --no-save --save-dir $results_path --tmp-save-dir $results_path \
       --find-unused-parameters \
       --maximize-best-checkpoint-metric \
       --split-method random --valid-set $TASK \
       --max-lignum 512 \
       --sup-num $sup_num \
       --restore-model $weight_path --few-shot true \
       --fp16 --fp16-init-scale 4 --fp16-scale-window 256
|
test_fewshot_demo.sh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Few-shot demo run on user-provided ligand/protein/split files.
# Usage: bash test_fewshot_demo.sh <arch> <weight_path> <results_path> <lig_file> <prot_file> <split_file>
data_path="./vocab"

n_gpu=1
batch_size=1
batch_size_valid=1
epoch=20
update_freq=1
lr=1e-4
MASTER_PORT=10092

export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1

arch=${1}          # model architecture name
weight_path=${2}   # checkpoint to restore before fine-tuning
results_path=${3}  # output directory
lig_file=${4}      # demo ligand lmdb (see preprocessing 'mol' mode)
prot_file=${5}     # demo pocket lmdb (see preprocessing 'pocket' mode)
split_file=${6}    # train/valid split description for the demo data

sup_num=16         # fixed support-set size for the demo
seed=1

torchrun --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \
       --results-path $results_path \
       --num-workers 8 --ddp-backend=c10d \
       --task train_task --loss rank_softmax --arch $arch \
       --max-pocket-atoms 256 \
       --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-8 --clip-norm 1.0 \
       --lr-scheduler polynomial_decay --lr $lr --max-epoch $epoch --batch-size $batch_size --batch-size-valid $batch_size_valid \
       --update-freq $update_freq --seed $seed \
       --log-interval 1 --log-format simple \
       --validate-interval 1 \
       --best-checkpoint-metric valid_mean_r2 --patience 100 --all-gather-list-size 2048000 \
       --no-save --save-dir ./tmp --tmp-save-dir ./tmp \
       --find-unused-parameters \
       --maximize-best-checkpoint-metric \
       --split-method random --valid-set DEMO \
       --max-lignum 512 \
       --sup-num $sup_num \
       --restore-model $weight_path --few-shot true \
       --demo-lig-file $lig_file --demo-prot-file $prot_file --demo-split-file $split_file \
       --fp16 --fp16-init-scale 4 --fp16-scale-window 256
|
test_zeroshot_demo.sh
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Zero-shot demo inference on user-provided ligand/protein lmdb files.
# Usage: bash test_zeroshot_demo.sh <lig_file> <prot_file> <uniprot> <arch> <weight_path> <results_path>
batch_size=128

lig_file=${1}      # demo ligand lmdb (see preprocessing 'mol' mode)
prot_file=${2}     # demo pocket lmdb (see preprocessing 'pocket' mode)
uniprot=${3}       # UniProt identifier of the demo target
arch=${4}          # model architecture name
weight_path=${5}   # checkpoint to evaluate
results_path=${6}  # output directory
echo "writing to ${results_path}"

mkdir -p $results_path
python ./unimol/test.py "./vocab" --user-dir ./unimol --valid-subset test \
       --results-path $results_path \
       --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \
       --task test_task --loss rank_softmax --arch $arch \
       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --seed 1 \
       --path $weight_path \
       --log-interval 100 --log-format simple \
       --max-pocket-atoms 511 --demo-lig-file $lig_file --demo-prot-file $prot_file --demo-uniprot $uniprot \
       --test-task DEMO
|
train.sh
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training script: runs four virtual-screening training configurations in
# sequence (baseline, no rank loss, and two protein-similarity ablations).
# Logs go to ${save_root}/train_log/train_log_<save_name>.txt.
data_path="./data"

save_root="./save"
# fix: every section redirects stdout into ${save_root}/train_log/, but
# that directory was never created, so the redirection failed on a fresh
# checkout. Create it once up front.
mkdir -p ${save_root}/train_log

# ---- config 1: baseline screening model ----
save_name="screen_pocket"
save_dir="${save_root}/${save_name}/savedir_screen"
tmp_save_dir="${save_root}/${save_name}/tmp_save_dir_screen"
tsb_dir="${save_root}/${save_name}/tsb_dir_screen"
mkdir -p ${save_dir}
n_gpu=2
MASTER_PORT=10062
finetune_mol_model="./pretrain/mol_pre_no_h_220816.pt" # unimol pretrained mol model
finetune_pocket_model="./pretrain/pocket_pre_220816.pt" # unimol pretrained pocket model


batch_size=24
batch_size_valid=32
epoch=50
dropout=0.0
warmup=0.06
update_freq=1
dist_threshold=8.0
recycling=3
lr=1e-4

export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1
CUDA_VISIBLE_DEVICES="0,1" python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \
       --num-workers 8 --ddp-backend=c10d \
       --task train_task --loss rank_softmax --arch pocketscreen \
       --max-pocket-atoms 256 \
       --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-8 --clip-norm 1.0 \
       --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $batch_size --batch-size-valid $batch_size_valid \
       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --update-freq $update_freq --seed 1 \
       --tensorboard-logdir $tsb_dir \
       --log-interval 100 --log-format simple \
       --validate-interval 1 \
       --best-checkpoint-metric valid_bedroc --patience 2000 --all-gather-list-size 2048000 \
       --save-dir $save_dir --tmp-save-dir $tmp_save_dir --keep-best-checkpoints 8 --keep-last-epochs 10 \
       --find-unused-parameters \
       --maximize-best-checkpoint-metric \
       --finetune-pocket-model $finetune_pocket_model \
       --finetune-mol-model $finetune_mol_model \
       --valid-set CASF \
       --max-lignum 16 \
       --protein-similarity-thres 1.0 > ${save_root}/train_log/train_log_${save_name}.txt


# ---- config 2: ablation without the ranking loss (--rank-weight 0.0) ----
save_name="screen_pocket_norank"
save_dir="${save_root}/${save_name}/savedir_screen"
tmp_save_dir="${save_root}/${save_name}/tmp_save_dir_screen"
tsb_dir="${save_root}/${save_name}/tsb_dir_screen"
mkdir -p ${save_dir}
n_gpu=2
MASTER_PORT=10062
finetune_mol_model="./pretrain/mol_pre_no_h_220816.pt" # unimol pretrained mol model
finetune_pocket_model="./pretrain/pocket_pre_220816.pt" # unimol pretrained pocket model

export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1
CUDA_VISIBLE_DEVICES="0,1" python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \
       --num-workers 8 --ddp-backend=c10d \
       --task train_task --loss rank_softmax --arch pocketscreen \
       --max-pocket-atoms 256 \
       --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-8 --clip-norm 1.0 \
       --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $batch_size --batch-size-valid $batch_size_valid \
       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --update-freq $update_freq --seed 1 \
       --tensorboard-logdir $tsb_dir \
       --log-interval 100 --log-format simple \
       --validate-interval 1 \
       --best-checkpoint-metric valid_bedroc --patience 2000 --all-gather-list-size 2048000 \
       --save-dir $save_dir --tmp-save-dir $tmp_save_dir --keep-best-checkpoints 8 --keep-last-epochs 10 \
       --find-unused-parameters \
       --maximize-best-checkpoint-metric \
       --finetune-pocket-model $finetune_pocket_model \
       --finetune-mol-model $finetune_mol_model \
       --valid-set CASF \
       --max-lignum 16 \
       --protein-similarity-thres 1.0 \
       --rank-weight 0.0 > ${save_root}/train_log/train_log_${save_name}.txt


# ---- config 3: exclude training proteins with similarity > 0.8 ----
save_name="screen_pocket_no_similar_protein0.8"
save_dir="${save_root}/${save_name}/savedir_screen"
tmp_save_dir="${save_root}/${save_name}/tmp_save_dir_screen"
tsb_dir="${save_root}/${save_name}/tsb_dir_screen"
mkdir -p ${save_dir}
n_gpu=2
MASTER_PORT=10062
finetune_mol_model="./pretrain/mol_pre_no_h_220816.pt" # unimol pretrained mol model
finetune_pocket_model="./pretrain/pocket_pre_220816.pt" # unimol pretrained pocket model

export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1
CUDA_VISIBLE_DEVICES="0,1" python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \
       --num-workers 8 --ddp-backend=c10d \
       --task train_task --loss rank_softmax --arch pocketscreen \
       --max-pocket-atoms 256 \
       --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-8 --clip-norm 1.0 \
       --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $batch_size --batch-size-valid $batch_size_valid \
       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --update-freq $update_freq --seed 1 \
       --tensorboard-logdir $tsb_dir \
       --log-interval 100 --log-format simple \
       --validate-interval 1 \
       --best-checkpoint-metric valid_bedroc --patience 2000 --all-gather-list-size 2048000 \
       --save-dir $save_dir --tmp-save-dir $tmp_save_dir --keep-best-checkpoints 8 --keep-last-epochs 10 \
       --find-unused-parameters \
       --maximize-best-checkpoint-metric \
       --finetune-pocket-model $finetune_pocket_model \
       --finetune-mol-model $finetune_mol_model \
       --valid-set CASF \
       --max-lignum 16 \
       --protein-similarity-thres 0.8 > ${save_root}/train_log/train_log_${save_name}.txt


# ---- config 4: exclude training proteins with similarity > 0.4 ----
save_name="screen_pocket_no_similar_protein"
save_dir="${save_root}/${save_name}/savedir_screen"
tmp_save_dir="${save_root}/${save_name}/tmp_save_dir_screen"
tsb_dir="${save_root}/${save_name}/tsb_dir_screen"
mkdir -p ${save_dir}
n_gpu=2
MASTER_PORT=10062
finetune_mol_model="./pretrain/mol_pre_no_h_220816.pt" # unimol pretrained mol model
finetune_pocket_model="./pretrain/pocket_pre_220816.pt" # unimol pretrained pocket model

export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1
CUDA_VISIBLE_DEVICES="0,1" python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \
       --num-workers 8 --ddp-backend=c10d \
       --task train_task --loss rank_softmax --arch pocketscreen \
       --max-pocket-atoms 256 \
       --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-8 --clip-norm 1.0 \
       --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $batch_size --batch-size-valid $batch_size_valid \
       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --update-freq $update_freq --seed 1 \
       --tensorboard-logdir $tsb_dir \
       --log-interval 100 --log-format simple \
       --validate-interval 1 \
       --best-checkpoint-metric valid_bedroc --patience 2000 --all-gather-list-size 2048000 \
       --save-dir $save_dir --tmp-save-dir $tmp_save_dir --keep-best-checkpoints 8 --keep-last-epochs 10 \
       --find-unused-parameters \
       --maximize-best-checkpoint-metric \
       --finetune-pocket-model $finetune_pocket_model \
       --finetune-mol-model $finetune_mol_model \
       --valid-set CASF \
       --max-lignum 16 \
       --protein-similarity-thres 0.4 > ${save_root}/train_log/train_log_${save_name}.txt
|
unimol/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import unimol.tasks
|
| 3 |
+
import unimol.data
|
| 4 |
+
import unimol.models
|
| 5 |
+
import unimol.losses
|
| 6 |
+
import unimol.utils
|
unimol/data/__init__.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .key_dataset import KeyDataset, LengthDataset
|
| 2 |
+
from .normalize_dataset import (
|
| 3 |
+
NormalizeDataset,
|
| 4 |
+
NormalizeDockingPoseDataset,
|
| 5 |
+
)
|
| 6 |
+
from .remove_hydrogen_dataset import (
|
| 7 |
+
RemoveHydrogenDataset,
|
| 8 |
+
RemoveHydrogenResiduePocketDataset,
|
| 9 |
+
RemoveHydrogenPocketDataset,
|
| 10 |
+
)
|
| 11 |
+
from .tta_dataset import (
|
| 12 |
+
TTADataset,
|
| 13 |
+
TTADecoderDataset,
|
| 14 |
+
TTADockingPoseDataset,
|
| 15 |
+
)
|
| 16 |
+
from .cropping_dataset import (
|
| 17 |
+
CroppingDataset,
|
| 18 |
+
CroppingPocketDataset,
|
| 19 |
+
CroppingResiduePocketDataset,
|
| 20 |
+
CroppingPocketDockingPoseDataset,
|
| 21 |
+
CroppingPocketDockingPoseTestDataset,
|
| 22 |
+
)
|
| 23 |
+
from .atom_type_dataset import AtomTypeDataset
|
| 24 |
+
from .add_2d_conformer_dataset import Add2DConformerDataset
|
| 25 |
+
from .distance_dataset import (
|
| 26 |
+
DistanceDataset,
|
| 27 |
+
EdgeTypeDataset,
|
| 28 |
+
CrossDistanceDataset,
|
| 29 |
+
CrossEdgeTypeDataset
|
| 30 |
+
)
|
| 31 |
+
from .conformer_sample_dataset import (
|
| 32 |
+
ConformerSampleDataset,
|
| 33 |
+
ConformerSampleDecoderDataset,
|
| 34 |
+
ConformerSamplePocketDataset,
|
| 35 |
+
ConformerSamplePocketFinetuneDataset,
|
| 36 |
+
ConformerSampleConfGDataset,
|
| 37 |
+
ConformerSampleConfGV2Dataset,
|
| 38 |
+
ConformerSampleDockingPoseDataset,
|
| 39 |
+
)
|
| 40 |
+
from .mask_points_dataset import MaskPointsDataset, MaskPointsPocketDataset
|
| 41 |
+
from .coord_pad_dataset import RightPadDatasetCoord, RightPadDatasetCross2D
|
| 42 |
+
from .from_str_dataset import FromStrLabelDataset
|
| 43 |
+
from .lmdb_dataset import LMDBDataset
|
| 44 |
+
from .prepend_and_append_2d_dataset import PrependAndAppend2DDataset
|
| 45 |
+
from .affinity_dataset import AffinityDataset, AffinityTestDataset, AffinityValidDataset, AffinityMolDataset, AffinityPocketDataset, AffinityHNSDataset, AffinityAugDataset
|
| 46 |
+
from .pocket2mol_dataset import FragmentConformationDataset
|
| 47 |
+
from .vae_binding_dataset import VAEBindingDataset, VAEBindingTestDataset, VAEGenerationTestDataset
|
| 48 |
+
from .resampling_dataset import ResamplingDataset
|
| 49 |
+
from .pair_dataset import PairDataset
|
| 50 |
+
__all__ = []
|
unimol/data/add_2d_conformer_dataset.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) DP Technology.
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
from unicore.data import BaseWrapperDataset
|
| 8 |
+
from rdkit import Chem
|
| 9 |
+
from rdkit.Chem import AllChem
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Add2DConformerDataset(BaseWrapperDataset):
    """Wrapper dataset that appends an RDKit 2D conformer (computed from the
    item's SMILES) to each item's list of coordinate sets."""

    def __init__(self, dataset, smi, atoms, coordinates):
        # dataset: wrapped dataset of dict items
        # smi / atoms / coordinates: key names into each item dict
        self.dataset = dataset
        self.smi = smi
        self.atoms = atoms
        self.coordinates = coordinates
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        # propagate to the wrapped dataset and remember the epoch locally;
        # the epoch is part of the __cached_item__ cache key
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        # NOTE(review): lru_cache on a bound method keys on `self` and keeps
        # the instance alive for the cache's lifetime; this pattern is used
        # consistently across the project's datasets.
        atoms = np.array(self.dataset[index][self.atoms])
        assert len(atoms) > 0
        smi = self.dataset[index][self.smi]
        coordinates_2d = smi2_2Dcoords(smi)
        coordinates = self.dataset[index][self.coordinates]
        # NOTE(review): this appends in place to the list stored in the
        # underlying item — if the cache entry is evicted and the item is
        # revisited, the 2D conformer may be appended again. Confirm the
        # wrapped dataset returns fresh copies.
        coordinates.append(coordinates_2d)
        return {"smi": smi, "atoms": atoms, "coordinates": coordinates}

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def smi2_2Dcoords(smi):
    """Compute 2D coordinates for a SMILES string with explicit hydrogens.

    Args:
        smi: SMILES string (must be parseable by RDKit).

    Returns:
        float32 numpy array of shape (n_atoms, 3) from the generated
        2D conformer.

    Raises:
        AssertionError: if the conformer's coordinate count does not match
            the molecule's atom count.
    """
    mol = Chem.MolFromSmiles(smi)
    mol = AllChem.AddHs(mol)
    AllChem.Compute2DCoords(mol)
    coordinates = mol.GetConformer().GetPositions().astype(np.float32)
    # fix: the original wrote a bare comparison expression with a trailing
    # message tuple — a no-op; the intended sanity check is an assert.
    assert len(mol.GetAtoms()) == len(
        coordinates
    ), "2D coordinates shape is not align with {}".format(smi)
    return coordinates
|
unimol/data/affinity_dataset.py
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) DP Technology.
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from unicore.data import BaseWrapperDataset
|
| 10 |
+
import pickle
|
| 11 |
+
from . import data_utils
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AffinityDataset(BaseWrapperDataset):
    """Wrapper dataset producing one (ligand conformer, pocket, affinity)
    sample per item, with the conformer chosen by a seeded random draw."""

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        pocket_atoms,
        pocket_coordinates,
        affinity,
        is_train=False,
        pocket="pocket"
    ):
        # dataset: wrapped dataset of dict items
        # seed: base RNG seed for conformer sampling
        # atoms/coordinates/pocket_atoms/pocket_coordinates/affinity:
        #   key names into each item dict
        # is_train: if True, the sampled conformer varies with the epoch;
        #   otherwise a fixed draw (epoch constant 1) is used
        # pocket: key name holding the pocket identifier
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.pocket_atoms = pocket_atoms
        self.pocket_coordinates = pocket_coordinates
        self.affinity = affinity
        self.is_train = is_train
        self.pocket=pocket
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        # propagate to the wrapped dataset; the epoch participates in the
        # __cached_item__ cache key and the training-time conformer draw
        super().set_epoch(epoch)
        self.epoch = epoch

    def pocket_atom(self, atom):
        # Normalize a pocket atom name to a single element character:
        # names like "1H..." start with a digit — take the second char.
        if atom[0] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
            return atom[1]
        else:
            return atom[0]

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        atoms = np.array(self.dataset[index][self.atoms])
        ori_mol_length = len(atoms)
        #coordinates = self.dataset[index][self.coordinates]
        # number of stored conformers for this molecule
        size = len(self.dataset[index][self.coordinates])
        if self.is_train:
            # epoch-dependent draw: a different conformer each epoch
            with data_utils.numpy_seed(self.seed, epoch, index):
                sample_idx = np.random.randint(size)
        else:
            # deterministic draw for evaluation (epoch fixed to 1)
            with data_utils.numpy_seed(self.seed, 1, index):
                sample_idx = np.random.randint(size)
        #print(len(self.dataset[index][self.coordinates][sample_idx]))
        coordinates = self.dataset[index][self.coordinates][sample_idx]
        #print(coordinates.shape)
        pocket_atoms = np.array(
            [self.pocket_atom(item) for item in self.dataset[index][self.pocket_atoms]]
        )
        ori_pocket_length = len(pocket_atoms)
        pocket_coordinates = np.stack(self.dataset[index][self.pocket_coordinates])

        smi = self.dataset[index]["smi"]
        pocket = self.dataset[index][self.pocket]
        # items without an affinity value (e.g. screening decoys/actives)
        # default to 1
        if self.affinity in self.dataset[index]:
            affinity = float(self.dataset[index][self.affinity])
        else:
            affinity = 1
        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "holo_coordinates": coordinates.astype(np.float32),#placeholder
            "pocket_atoms": pocket_atoms,
            "pocket_coordinates": pocket_coordinates.astype(np.float32),
            "holo_pocket_coordinates": pocket_coordinates.astype(np.float32),#placeholder
            "smi": smi,
            "pocket": pocket,
            "affinity": affinity,
            "ori_mol_length": ori_mol_length,
            "ori_pocket_length": ori_pocket_length
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class AffinityAugDataset(BaseWrapperDataset):
    """Augmented variant of AffinityDataset: each item stores LISTS of
    molecules/pockets, and one molecule, one of its conformers, and one
    pocket are drawn per epoch with seeded randomness."""

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        pocket_atoms,
        pocket_coordinates,
        affinity,
        is_train=False,
        pocket="pocket_id"
    ):
        # Same key-name parameters as AffinityDataset, but each keyed field
        # is a list of alternatives; `pocket` defaults to "pocket_id" here.
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.pocket_atoms = pocket_atoms
        self.pocket_coordinates = pocket_coordinates
        self.affinity = affinity
        self.is_train = is_train
        self.pocket=pocket
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        # propagate to the wrapped dataset; the epoch drives the per-epoch
        # molecule/conformer/pocket draws below
        super().set_epoch(epoch)
        self.epoch = epoch

    def pocket_atom(self, atom):
        # Normalize a pocket atom name to a single element character:
        # names like "1H..." start with a digit — take the second char.
        if atom[0] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
            return atom[1]
        else:
            return atom[0]

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        #mol_atoms_list = self.dataset[index][self.atoms]
        # draw one molecule out of the item's alternatives
        with data_utils.numpy_seed(self.seed, epoch, index):
            mol_idx = np.random.randint(len(self.dataset[index][self.atoms]))
        atoms = np.array(self.dataset[index][self.atoms][mol_idx])
        ori_mol_length = len(atoms)
        #coordinates = self.dataset[index][self.coordinates]
        # then draw one conformer of that molecule (epoch-dependent during
        # training; fixed draw with epoch constant 1 otherwise)
        size = len(self.dataset[index][self.coordinates][mol_idx])
        if self.is_train:
            with data_utils.numpy_seed(self.seed, epoch, index):
                sample_idx = np.random.randint(size)
        else:
            with data_utils.numpy_seed(self.seed, 1, index):
                sample_idx = np.random.randint(size)
        #print(len(self.dataset[index][self.coordinates][sample_idx]))
        coordinates = self.dataset[index][self.coordinates][mol_idx][sample_idx]


        #pocket_list = self.dataset[index][self.pocket_atoms]
        # independently draw one pocket variant (same seed tuple as the
        # molecule draw, so the RNG sequence is reproducible per item/epoch)
        with data_utils.numpy_seed(self.seed, epoch, index):
            pocket_idx = np.random.randint(len(self.dataset[index][self.pocket_atoms]))
        pocket_atoms = np.array(
            [self.pocket_atom(item) for item in self.dataset[index][self.pocket_atoms][pocket_idx]]
        )

        ori_pocket_length = len(pocket_atoms)
        pocket_coordinates = np.stack(self.dataset[index][self.pocket_coordinates][pocket_idx])

        smi = self.dataset[index]["smiles"][mol_idx]
        # the pocket identifier is always the first entry, regardless of
        # which pocket variant was sampled
        pocket = self.dataset[index][self.pocket][0]
        # items without an affinity value default to 1
        if self.affinity in self.dataset[index]:
            affinity = float(self.dataset[index][self.affinity])
        else:
            affinity = 1
        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "holo_coordinates": coordinates.astype(np.float32),#placeholder
            "pocket_atoms": pocket_atoms,
            "pocket_coordinates": pocket_coordinates.astype(np.float32),
            "holo_pocket_coordinates": pocket_coordinates.astype(np.float32),#placeholder
            "smi": smi,
            "pocket": pocket,
            "affinity": affinity,
            "ori_mol_length": ori_mol_length,
            "ori_pocket_length": ori_pocket_length
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class AffinityHNSDataset(BaseWrapperDataset):
    """Affinity dataset that additionally returns a second molecule under
    ``atoms_hns``/``coordinates_hns`` (HNS presumably "hard negative sample"
    — TODO confirm against the task that consumes these keys).

    Conformer sampling for the primary molecule is epoch-dependent during
    training and pinned (constant 1 replaces the epoch) otherwise; the HNS
    molecule always uses its first conformer.
    """

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        atoms_hns,
        coordinates_hns,
        pocket_atoms,
        pocket_coordinates,
        affinity,
        is_train=False,
        pocket="pocket"
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.atoms_hns = atoms_hns
        self.coordinates_hns = coordinates_hns
        self.pocket_atoms = pocket_atoms
        self.pocket_coordinates = pocket_coordinates
        self.affinity = affinity
        self.is_train = is_train
        self.pocket=pocket
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        # Record the epoch; it becomes part of the cache key in __getitem__.
        super().set_epoch(epoch)
        self.epoch = epoch

    def pocket_atom(self, atom):
        # Strip a leading digit from pocket atom names (e.g. "1H..." -> "H").
        if atom[0] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
            return atom[1]
        else:
            return atom[0]

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        atoms = np.array(self.dataset[index][self.atoms])
        ori_mol_length = len(atoms)
        size = len(self.dataset[index][self.coordinates])
        # Conformer choice: varies per epoch in training, fixed otherwise.
        if self.is_train:
            with data_utils.numpy_seed(self.seed, epoch, index):
                sample_idx = np.random.randint(size)
        else:
            with data_utils.numpy_seed(self.seed, 1, index):
                sample_idx = np.random.randint(size)
        coordinates = self.dataset[index][self.coordinates][sample_idx]
        # The HNS molecule always uses its first conformer.
        atoms_hns = np.array(self.dataset[index][self.atoms_hns])
        coordinates_hns = self.dataset[index][self.coordinates_hns][0]

        pocket_atoms = np.array(
            [self.pocket_atom(item) for item in self.dataset[index][self.pocket_atoms]]
        )
        ori_pocket_length = len(pocket_atoms)
        pocket_coordinates = np.stack(self.dataset[index][self.pocket_coordinates])

        smi = self.dataset[index]["smi"]
        pocket = self.dataset[index][self.pocket]
        if self.affinity in self.dataset[index]:
            affinity = float(self.dataset[index][self.affinity])
        else:
            affinity = 1  # placeholder label when no affinity is provided
        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "atoms_hns": atoms_hns,
            "coordinates_hns": coordinates_hns.astype(np.float32),
            "holo_coordinates": coordinates.astype(np.float32),#placeholder
            "pocket_atoms": pocket_atoms,
            "pocket_coordinates": pocket_coordinates.astype(np.float32),
            "holo_pocket_coordinates": pocket_coordinates.astype(np.float32),#placeholder
            "smi": smi,
            "pocket": pocket,
            "affinity": affinity,
            "ori_mol_length": ori_mol_length,
            "ori_pocket_length": ori_pocket_length
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 267 |
+
|
| 268 |
+
class AffinityTestDataset(BaseWrapperDataset):
    """Test-time affinity dataset: one seeded conformer per item, and the
    affinity label read directly from the entry (assumed to support
    ``.astype`` — i.e. a NumPy value; TODO confirm with the data pipeline).
    """

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        pocket_atoms,
        pocket_coordinates,
        affinity=None,
        is_train=False,
        pocket="pocket"
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.pocket_atoms = pocket_atoms
        self.pocket_coordinates = pocket_coordinates
        self.affinity = affinity
        self.is_train = is_train
        self.pocket=pocket
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        # Record the epoch; it becomes part of the cache key in __getitem__.
        super().set_epoch(epoch)
        self.epoch = epoch

    def pocket_atom(self, atom):
        # Strip a leading digit from pocket atom names (e.g. "1H..." -> "H").
        if atom[0] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
            return atom[1]
        else:
            return atom[0]

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        atoms = np.array(self.dataset[index][self.atoms])
        ori_length = len(atoms)
        size = len(self.dataset[index][self.coordinates])
        # Deterministic conformer choice per (seed, epoch, index).
        with data_utils.numpy_seed(self.seed, epoch, index):
            sample_idx = np.random.randint(size)
        coordinates = self.dataset[index][self.coordinates][sample_idx]
        pocket_atoms = np.array(
            [self.pocket_atom(item) for item in self.dataset[index][self.pocket_atoms]]
        )
        pocket_coordinates = np.stack(self.dataset[index][self.pocket_coordinates])

        smi = self.dataset[index]["smi"]
        pocket = self.dataset[index][self.pocket]
        affinity = self.dataset[index][self.affinity]
        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "holo_coordinates": coordinates.astype(np.float32),#placeholder
            "pocket_atoms": pocket_atoms,
            "pocket_coordinates": pocket_coordinates.astype(np.float32),
            "holo_pocket_coordinates": pocket_coordinates.astype(np.float32),#placeholder
            "smi": smi,
            "pocket": pocket,
            "affinity": affinity.astype(np.float32),
            "ori_length": ori_length
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class AffinityMolDataset(BaseWrapperDataset):
    """Molecule-only view of an affinity entry: sampled conformer, SMILES,
    optional name, and the (pickled) RDKit mol object.

    Note: during training the conformer index is drawn WITHOUT a fixed seed
    (see the authors' TODO below), so training samples differ across runs.
    """

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        is_train=False,
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.is_train = is_train
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        # Record the epoch; it becomes part of the cache key in __getitem__.
        super().set_epoch(epoch)
        self.epoch = epoch

    def pocket_atom(self, atom):
        # Strip a leading digit from pocket atom names; unused in this class
        # but kept for interface parity with the sibling datasets.
        if atom[0] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
            return atom[1]
        else:
            return atom[0]

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        atoms = np.array(self.dataset[index][self.atoms])
        ori_length = len(atoms)
        size = len(self.dataset[index][self.coordinates])

        # TODO: FB: introduce enough random when training using pairwise data
        # with data_utils.numpy_seed(self.seed, epoch, index):
        if self.is_train:
            sample_idx = np.random.randint(size)
        else:
            with data_utils.numpy_seed(self.seed, index):
                sample_idx = np.random.randint(size)
        # A 2-D element is a per-conformer coordinate array; otherwise the
        # value itself is taken as the coordinates.
        if len(self.dataset[index][self.coordinates][sample_idx].shape) == 2:
            coordinates = self.dataset[index][self.coordinates][sample_idx]
        else:
            coordinates = self.dataset[index][self.coordinates]

        smi = self.dataset[index]["smi"]
        name = self.dataset[index].get("name", None)
        # Pickled so the mol object survives the collation/batching pipeline.
        mol = pickle.dumps(self.dataset[index].get("mol", None))
        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "holo_coordinates": coordinates.astype(np.float32),#placeholder
            "smi": smi,
            "ori_length": ori_length,
            "name": name,
            "mol": mol
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
class AffinityPocketDataset(BaseWrapperDataset):
    """Pocket-only view: normalized pocket atoms, coordinates, pocket id and
    residue names with hydrogen entries filtered out."""

    def __init__(
        self,
        dataset,
        seed,
        pocket_atoms,
        pocket_coordinates,
        is_train=False,
        pocket="pocket"
    ):
        self.dataset = dataset
        self.seed = seed
        self.pocket_atoms = pocket_atoms
        self.pocket_coordinates = pocket_coordinates
        self.is_train = is_train
        self.pocket=pocket
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        # Record the epoch; it becomes part of the cache key in __getitem__.
        super().set_epoch(epoch)
        self.epoch = epoch

    def pocket_atom(self, atom):
        # Strip a leading digit from pocket atom names (e.g. "1H..." -> "H").
        if atom[0] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
            return atom[1]
        else:
            return atom[0]

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        pocket_atoms = np.array(
            [self.pocket_atom(item) for item in self.dataset[index][self.pocket_atoms]]
        )
        ori_length = len(pocket_atoms)
        pocket_coordinates = np.stack(self.dataset[index][self.pocket_coordinates])
        if self.pocket in self.dataset[index]:
            pocket = self.dataset[index][self.pocket]
        else:
            pocket = ""
        if "pocket_residue_name" in self.dataset[index]:
            pocket_residue_name = self.dataset[index]["pocket_residue_name"]
            # Drop residue names whose atom is hydrogen, keeping the residue
            # list aligned with the heavy atoms.
            pocket_residue_name_noH = []
            for res, atom in zip(pocket_residue_name, pocket_atoms):
                if atom == "H":
                    continue
                pocket_residue_name_noH.append(res)
        else:
            pocket_residue_name_noH = [""]  # placeholder when names are absent
        return {
            "pocket_atoms": pocket_atoms,
            "pocket_coordinates": pocket_coordinates.astype(np.float32),
            "holo_pocket_coordinates": pocket_coordinates.astype(np.float32),#placeholder
            "pocket": pocket,
            "pocket_residue_name": pocket_residue_name_noH,
            "ori_length": ori_length
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 464 |
+
|
| 465 |
+
class AffinityValidDataset(BaseWrapperDataset):
    """Validation affinity dataset: like AffinityTestDataset but without an
    affinity label in the returned sample."""

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        pocket_atoms,
        pocket_coordinates,
        pocket="pocket"
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.pocket_atoms = pocket_atoms
        self.pocket_coordinates = pocket_coordinates
        self.pocket=pocket
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        # Record the epoch; it becomes part of the cache key in __getitem__.
        super().set_epoch(epoch)
        self.epoch = epoch

    def pocket_atom(self, atom):
        # Strip a leading digit from pocket atom names (e.g. "1H..." -> "H").
        if atom[0] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
            return atom[1]
        else:
            return atom[0]

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        atoms = np.array(self.dataset[index][self.atoms])
        ori_mol_length = len(atoms)

        size = len(self.dataset[index][self.coordinates])
        # Deterministic conformer choice per (seed, epoch, index).
        with data_utils.numpy_seed(self.seed, epoch, index):
            sample_idx = np.random.randint(size)
        coordinates = self.dataset[index][self.coordinates][sample_idx]
        pocket_atoms = np.array(
            [self.pocket_atom(item) for item in self.dataset[index][self.pocket_atoms]]
        )
        ori_pocket_length = len(pocket_atoms)
        pocket_coordinates = np.stack(self.dataset[index][self.pocket_coordinates])

        smi = self.dataset[index]["smi"]
        pocket = self.dataset[index][self.pocket]
        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "holo_coordinates": coordinates.astype(np.float32),#placeholder
            "pocket_atoms": pocket_atoms,
            "pocket_coordinates": pocket_coordinates.astype(np.float32),
            "holo_pocket_coordinates": pocket_coordinates.astype(np.float32),#placeholder
            "smi": smi,
            "pocket": pocket,
            "ori_mol_length": ori_mol_length,
            "ori_pocket_length": ori_pocket_length
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
unimol/data/atom_type_dataset.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) DP Technology.
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
|
| 5 |
+
from functools import lru_cache
|
| 6 |
+
from unicore.data import BaseWrapperDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AtomTypeDataset(BaseWrapperDataset):
    """Wrapper that repairs entries whose atom and coordinate lists disagree
    in length (seen with old RDKit versions) by truncating both to the
    shorter length, in place on the wrapped entry."""

    def __init__(
        self,
        raw_dataset,
        dataset,
        smi="smi",
        atoms="atoms",
    ):
        self.raw_dataset = raw_dataset
        self.dataset = dataset
        self.smi = smi
        self.atoms = atoms

    @lru_cache(maxsize=16)
    def __getitem__(self, index: int):
        # for low rdkit version: lengths can disagree; clip both lists.
        if len(self.dataset[index]["atoms"]) != len(self.dataset[index]["coordinates"]):
            keep = min(
                len(self.dataset[index]["atoms"]),
                len(self.dataset[index]["coordinates"]),
            )
            self.dataset[index]["atoms"] = self.dataset[index]["atoms"][:keep]
            self.dataset[index]["coordinates"] = self.dataset[index]["coordinates"][:keep]
        return self.dataset[index]
|
unimol/data/conformer_sample_dataset.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) DP Technology.
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
from unicore.data import BaseWrapperDataset
|
| 8 |
+
from . import data_utils
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ConformerSampleDataset(BaseWrapperDataset):
    """Pick one conformer per molecule, deterministically per
    (seed, epoch, index)."""

    def __init__(self, dataset, seed, atoms, coordinates):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        entry = self.dataset[index]
        atom_arr = np.array(entry[self.atoms])
        assert len(atom_arr) > 0
        conformers = entry[self.coordinates]
        with data_utils.numpy_seed(self.seed, epoch, index):
            pick = np.random.randint(len(conformers))
        return {
            "atoms": atom_arr,
            "coordinates": conformers[pick].astype(np.float32),
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 37 |
+
|
| 38 |
+
class ConformerSampleDecoderDataset(BaseWrapperDataset):
    """Pick one conformer per molecule and also return its SELFIES tokens."""

    def __init__(self, dataset, seed, atoms, coordinates, selfies):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.selfies = selfies
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        entry = self.dataset[index]
        atom_arr = np.array(entry[self.atoms])
        assert len(atom_arr) > 0
        selfies_arr = np.array(entry[self.selfies])
        assert len(selfies_arr) > 0
        conformers = entry[self.coordinates]
        with data_utils.numpy_seed(self.seed, epoch, index):
            pick = np.random.randint(len(conformers))
        return {
            "atoms": atom_arr,
            "selfies": selfies_arr,
            "coordinates": conformers[pick].astype(np.float32),
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 69 |
+
|
| 70 |
+
class ConformerSamplePocketDataset(BaseWrapperDataset):
    """Pick one pocket conformer per item and return atoms, residue ids and
    the fpocket score.

    Atom-name normalization depends on the dictionary in use:
      * ``dict_coarse.txt``: first character of each atom name.
      * ``dict_fine.txt``: up to two characters, except single-character
        names and hydrogens which keep one character.
    """

    def __init__(self, dataset, seed, atoms, coordinates, dict_name):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.dict_name = dict_name
        self.coordinates = coordinates
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        # Record the epoch; it becomes part of the cache key in __getitem__.
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        if self.dict_name == "dict_coarse.txt":
            atoms = np.array([a[0] for a in self.dataset[index][self.atoms]])
        elif self.dict_name == "dict_fine.txt":
            atoms = np.array(
                [
                    a[0] if len(a) == 1 or a[0] == "H" else a[:2]
                    for a in self.dataset[index][self.atoms]
                ]
            )
        assert len(atoms) > 0
        size = len(self.dataset[index][self.coordinates])
        # Deterministic conformer choice per (seed, epoch, index).
        with data_utils.numpy_seed(self.seed, epoch, index):
            sample_idx = np.random.randint(size)
        coordinates = self.dataset[index][self.coordinates][sample_idx]
        residue = np.array(self.dataset[index]["residue"])
        # BUGFIX: ``np.float`` was removed in NumPy 1.24; it was only an
        # alias for the builtin ``float``, which we use directly.
        score = float(self.dataset[index]["meta_info"]["fpocket"]["Score"])
        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "residue": residue,
            "score": score,
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class ConformerSamplePocketFinetuneDataset(BaseWrapperDataset):
    """Pocket dataset for finetuning: normalized atoms, one set of
    coordinates, and (optional) residue labels.

    Returned dict is keyed by the configured key names themselves
    (``self.atoms`` etc.), not by fixed strings.
    """

    def __init__(self, dataset, seed, atoms, residues, coordinates):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.residues = residues
        self.coordinates = coordinates
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        # Record the epoch; it becomes part of the cache key in __getitem__.
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        atoms = np.array(
            [a[0] for a in self.dataset[index][self.atoms]]
        )  # only 'C H O N S'
        assert len(atoms) > 0
        # This judgment is reserved for possible future expansion.
        # The number of pocket conformations is 1, and the 'sample' does not work.
        if isinstance(self.dataset[index][self.coordinates], list):
            size = len(self.dataset[index][self.coordinates])
            with data_utils.numpy_seed(self.seed, epoch, index):
                sample_idx = np.random.randint(size)
            coordinates = self.dataset[index][self.coordinates][sample_idx]
        else:
            coordinates = self.dataset[index][self.coordinates]

        # Residue labels are optional; None signals their absence downstream.
        if self.residues in self.dataset[index]:
            residues = np.array(self.dataset[index][self.residues])
        else:
            residues = None
        assert len(atoms) == len(coordinates)
        return {
            self.atoms: atoms,
            self.coordinates: coordinates.astype(np.float32),
            self.residues: residues,
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class ConformerSampleConfGDataset(BaseWrapperDataset):
    """Pick one source conformer and pair it with the item's target
    coordinates; output keys mirror the configured key names."""

    def __init__(self, dataset, seed, atoms, coordinates, tgt_coordinates):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.tgt_coordinates = tgt_coordinates
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        entry = self.dataset[index]
        atom_arr = np.array(entry[self.atoms])
        assert len(atom_arr) > 0
        with data_utils.numpy_seed(self.seed, epoch, index):
            pick = np.random.randint(len(entry[self.coordinates]))
        src_coords = entry[self.coordinates][pick]
        tgt_coords = entry[self.tgt_coordinates]
        return {
            self.atoms: atom_arr,
            self.coordinates: src_coords.astype(np.float32),
            self.tgt_coordinates: tgt_coords.astype(np.float32),
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class ConformerSampleConfGV2Dataset(BaseWrapperDataset):
    """Sample an (rdkit conformer, target conformer) pair for conformer
    generation.

    Per item: choose one target conformer group ("gid") uniformly from the
    entry's ``meta`` dataframe, keep its topN lowest-score rows, then
    importance-sample one row with weight 1 / (score**beta + smooth),
    normalized over the kept rows.
    """

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        tgt_coordinates,
        beta=1.0,
        smooth=0.1,
        topN=10,
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.tgt_coordinates = tgt_coordinates
        self.beta = beta       # exponent on the score in the sampling weight
        self.smooth = smooth   # additive smoothing to avoid division by zero
        self.topN = topN       # number of lowest-score rows kept per group
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        # Record the epoch; it becomes part of the cache key in __getitem__.
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        atoms = np.array(self.dataset[index][self.atoms])
        assert len(atoms) > 0
        meta_df = self.dataset[index]["meta"]
        tgt_conf_ids = meta_df["gid"].unique()
        # randomly choose one conf
        with data_utils.numpy_seed(self.seed, epoch, index):
            conf_id = np.random.choice(tgt_conf_ids)
        conf_df = meta_df[meta_df["gid"] == conf_id]
        conf_df = conf_df.sort_values("score").reset_index(drop=False)[
            : self.topN
        ]  # only use top 5 confs for sampling...
        # importance sampling with rmsd inverse score

        def normalize(x, beta=1.0, smooth=0.1):
            x = 1.0 / (x**beta + smooth)
            return x / x.sum()

        rmsd_score = conf_df["score"].values
        weight = normalize(
            rmsd_score, beta=self.beta, smooth=self.smooth
        )  # for smoothing purpose
        with data_utils.numpy_seed(self.seed, epoch, index):
            idx = np.random.choice(len(conf_df), 1, replace=False, p=weight)
        # idx = [np.argmax(weight)]
        coordinates = conf_df.iloc[idx]["rdkit_coords"].values[0]
        tgt_coordinates = conf_df.iloc[idx]["tgt_coords"].values[0]
        return {
            self.atoms: atoms,
            self.coordinates: coordinates.astype(np.float32),
            self.tgt_coordinates: tgt_coordinates.astype(np.float32),
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class ConformerSampleDockingPoseDataset(BaseWrapperDataset):
    """Docking-pose dataset: sampled ligand conformer plus the pocket's
    first conformation.  During training, the true (holo) coordinates are
    taken from the entry; at eval time the input coordinates themselves are
    used as placeholders for the holo fields.
    """

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        pocket_atoms,
        pocket_coordinates,
        holo_coordinates,
        holo_pocket_coordinates,
        is_train=True,
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.pocket_atoms = pocket_atoms
        self.pocket_coordinates = pocket_coordinates
        self.holo_coordinates = holo_coordinates
        self.holo_pocket_coordinates = holo_pocket_coordinates
        self.is_train = is_train
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        # Record the epoch; it becomes part of the cache key in __getitem__.
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        atoms = np.array(self.dataset[index][self.atoms])
        size = len(self.dataset[index][self.coordinates])
        # Deterministic ligand conformer choice per (seed, epoch, index).
        with data_utils.numpy_seed(self.seed, epoch, index):
            sample_idx = np.random.randint(size)
        coordinates = self.dataset[index][self.coordinates][sample_idx]
        # Pocket atoms reduced to their first character (element letter).
        pocket_atoms = np.array(
            [item[0] for item in self.dataset[index][self.pocket_atoms]]
        )
        pocket_coordinates = self.dataset[index][self.pocket_coordinates][0]
        if self.is_train:
            holo_coordinates = self.dataset[index][self.holo_coordinates][0]
            holo_pocket_coordinates = self.dataset[index][self.holo_pocket_coordinates][
                0
            ]
        else:
            # No labels at eval time; reuse the inputs as placeholders.
            holo_coordinates = coordinates
            holo_pocket_coordinates = pocket_coordinates

        smi = self.dataset[index]["smi"]
        pocket = self.dataset[index]["pocket"]

        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "pocket_atoms": pocket_atoms,
            "pocket_coordinates": pocket_coordinates.astype(np.float32),
            "holo_coordinates": holo_coordinates.astype(np.float32),
            "holo_pocket_coordinates": holo_pocket_coordinates.astype(np.float32),
            "smi": smi,
            "pocket": pocket,
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
unimol/data/coord_pad_dataset.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) DP Technology.
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
|
| 5 |
+
from unicore.data import BaseWrapperDataset
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def collate_tokens_coords(
    values,
    pad_idx,
    left_pad=False,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Pad a list of per-sample ``(n_i, 3)`` coordinate tensors into one
    ``(batch, max_n, 3)`` tensor filled with ``pad_idx``.

    ``pad_to_multiple`` is accepted for interface compatibility but
    deliberately ignored (the rounding logic was disabled upstream).
    """
    target = max(v.size(0) for v in values)
    if pad_to_length is not None:
        target = max(target, pad_to_length)
    batch = values[0].new(len(values), target, 3).fill_(pad_idx)
    for row, coords in zip(batch, values):
        n = coords.size(0)
        # Left-padding places the sample at the end of the row.
        dst = row[target - n :, :] if left_pad else row[:n, :]
        assert dst.numel() == coords.numel()
        dst.copy_(coords)
    return batch
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class RightPadDatasetCoord(BaseWrapperDataset):
    """Wraps a coordinate dataset so batching pads every sample with
    ``pad_idx`` up to the longest sample in the batch (right-padded by
    default)."""

    def __init__(self, dataset, pad_idx, left_pad=False):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad

    def collater(self, samples):
        """Collate *samples* into one padded ``(batch, max_n, 3)`` tensor."""
        return collate_tokens_coords(
            samples,
            self.pad_idx,
            left_pad=self.left_pad,
            pad_to_multiple=8,
        )
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def collate_cross_2d(
    values,
    pad_idx,
    left_pad=False,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Pad a list of 2-D tensors (e.g. mol-by-pocket distance maps) into a
    single ``(batch, H, W)`` tensor filled with ``pad_idx``.

    H and W are each rounded up to a multiple of ``pad_to_multiple``.
    ``pad_to_length`` is accepted for interface compatibility but unused.
    """

    def _round_up(n):
        # Same rounding as the original: next multiple of pad_to_multiple.
        if pad_to_multiple != 1 and n % pad_to_multiple != 0:
            n = int(((n - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
        return n

    height = _round_up(max(v.size(0) for v in values))
    width = _round_up(max(v.size(1) for v in values))
    out = values[0].new(len(values), height, width).fill_(pad_idx)
    for row, v in zip(out, values):
        dst = (
            row[height - v.size(0) :, width - v.size(1) :]
            if left_pad
            else row[: v.size(0), : v.size(1)]
        )
        assert dst.numel() == v.numel()
        dst.copy_(v)
    return out
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class RightPadDatasetCross2D(BaseWrapperDataset):
    """Batches 2-D samples (e.g. cross distance maps) by padding both
    dimensions with ``pad_idx``; batch sizes are rounded up to multiples
    of 8."""

    def __init__(self, dataset, pad_idx, left_pad=False):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad

    def collater(self, samples):
        """Collate *samples* into one padded ``(batch, H, W)`` tensor."""
        return collate_cross_2d(
            samples,
            self.pad_idx,
            left_pad=self.left_pad,
            pad_to_multiple=8,
        )
|
unimol/data/cropping_dataset.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) DP Technology.
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
import logging
|
| 8 |
+
from unicore.data import BaseWrapperDataset
|
| 9 |
+
from . import data_utils
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class CroppingDataset(BaseWrapperDataset):
    """Randomly subsamples each molecule down to at most ``max_atoms``
    atoms.  Selection is uniform and seeded per (seed, epoch, index) so
    every worker sees the same crop within an epoch."""

    def __init__(self, dataset, seed, atoms, coordinates, max_atoms=256):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms              # dict key of the atom-type field
        self.coordinates = coordinates  # dict key of the coordinate field
        self.max_atoms = max_atoms      # falsy value disables cropping
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        sample = self.dataset[index].copy()
        atom_arr = sample[self.atoms]
        coords = sample[self.coordinates]
        if self.max_atoms and len(atom_arr) > self.max_atoms:
            with data_utils.numpy_seed(self.seed, epoch, index):
                picked = np.random.choice(
                    len(atom_arr), self.max_atoms, replace=False
                )
                atom_arr = np.array(atom_arr)[picked]
                coords = coords[picked]
        sample[self.atoms] = atom_arr
        sample[self.coordinates] = coords.astype(np.float32)
        return sample

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class CroppingPocketDataset(BaseWrapperDataset):
    """Crops a pocket to at most ``max_atoms`` atoms, sampling atoms with
    probability given by a softmax over the reciprocal of their distance to
    the pocket centroid, so atoms close to the centre tend to survive."""

    def __init__(self, dataset, seed, atoms, coordinates, max_atoms=256):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        # Max number of atoms kept per pocket; None/0 disables cropping.
        self.max_atoms = max_atoms
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        sample = self.dataset[index].copy()
        atom_arr = sample[self.atoms]
        coords = sample[self.coordinates]

        # Crop atoms according to their distance to the pocket centre.
        if self.max_atoms and len(atom_arr) > self.max_atoms:
            with data_utils.numpy_seed(self.seed, epoch, index):
                # +1 avoids division by zero when taking the reciprocal.
                dist = np.linalg.norm(coords - coords.mean(axis=0), axis=1) + 1
                inv = np.reciprocal(dist)
                # Softmax over inverse distances -> sampling probabilities.
                shifted = inv - np.max(inv)
                weight = np.exp(shifted) / np.sum(np.exp(shifted))
                keep = np.random.choice(
                    len(atom_arr), self.max_atoms, replace=False, p=weight
                )
                atom_arr = atom_arr[keep]
                coords = coords[keep]

        sample[self.atoms] = atom_arr
        sample[self.coordinates] = coords.astype(np.float32)
        return sample

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class CroppingResiduePocketDataset(BaseWrapperDataset):
    """Crops a pocket at *residue* granularity: whole residues are kept or
    dropped together, sampled with probability inversely related to the
    residue's mean distance from the pocket centroid.
    """

    def __init__(self, dataset, seed, atoms, residues, coordinates, max_atoms=256):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms            # dict key of the atom-type field
        self.residues = residues      # dict key of the per-atom residue ids
        self.coordinates = coordinates  # dict key of the coordinate field
        self.max_atoms = (
            max_atoms  # max number of atoms in a molecule, None indicates no limit.
        )

        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Return a copy of sample *index* cropped to whole residues whose
        total atom count is at most roughly ``max_atoms``."""
        dd = self.dataset[index].copy()
        atoms = dd[self.atoms]
        residues = dd[self.residues]
        coordinates = dd[self.coordinates]

        # NOTE(review): never populated or read below — dead variable.
        residues_distance_map = {}

        # crop atoms according to their distance to the center of pockets
        if self.max_atoms and len(atoms) > self.max_atoms:
            with data_utils.numpy_seed(self.seed, epoch, index):
                distance = np.linalg.norm(
                    coordinates - coordinates.mean(axis=0), axis=1
                )
                # One entry per distinct residue (first-seen order): its id
                # and the mean centroid-distance of its atoms.
                residues_ids, residues_distance = [], []
                for res in residues:
                    if res not in residues_ids:
                        residues_ids.append(res)
                        residues_distance.append(distance[residues == res].mean())
                residues_ids = np.array(residues_ids)
                residues_distance = np.array(residues_distance)

                def softmax(x):
                    # Numerically-stable softmax (mutates x in place).
                    x -= np.max(x)
                    x = np.exp(x) / np.sum(np.exp(x))
                    return x

                residues_distance += 1  # prevent inf and smoothing out the distance
                weight = softmax(np.reciprocal(residues_distance))
                # Budget: roughly max_atoms divided by mean atoms-per-residue
                # (the +1 guards against division by zero below).
                max_residues = self.max_atoms // (len(atoms) // (len(residues_ids) + 1))
                if max_residues < 1:
                    max_residues += 1
                max_residues = min(max_residues, len(residues_ids))
                residue_index = np.random.choice(
                    len(residues_ids), max_residues, replace=False, p=weight
                )
                # Keep every atom belonging to a selected residue.
                index = [
                    i
                    for i in range(len(atoms))
                    if residues[i] in residues_ids[residue_index]
                ]
                atoms = atoms[index]
                coordinates = coordinates[index]
                residues = residues[index]

        dd[self.atoms] = atoms
        dd[self.coordinates] = coordinates.astype(np.float32)
        dd[self.residues] = residues
        return dd

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class CroppingPocketDockingPoseDataset(BaseWrapperDataset):
    """Crops a pocket together with its matched holo (bound-state)
    coordinates, keeping at most ``max_atoms`` atoms.

    Atoms are sampled with probability given by a softmax over the
    reciprocal of their distance to the pocket centroid; the same kept
    indices are applied to both coordinate arrays so they stay aligned.
    """

    def __init__(
        self, dataset, seed, atoms, coordinates, holo_coordinates, max_atoms=256
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        # Bug fix: this key was accepted but never stored, so
        # __cached_item__ crashed with AttributeError on
        # ``self.holo_coordinates`` for any pocket that needed cropping.
        self.holo_coordinates = holo_coordinates
        self.max_atoms = max_atoms

        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        dd = self.dataset[index].copy()
        atoms = dd[self.atoms]
        coordinates = dd[self.coordinates]
        holo_coordinates = dd[self.holo_coordinates]

        # crop atoms according to their distance to the center of pockets
        if self.max_atoms and len(atoms) > self.max_atoms:
            # NOTE(review): seeded with (seed, 1) rather than
            # (seed, epoch, index), so the RNG draw does not vary per epoch —
            # presumably intentional for pose evaluation; confirm.
            with data_utils.numpy_seed(self.seed, 1):
                distance = np.linalg.norm(
                    coordinates - coordinates.mean(axis=0), axis=1
                )

                def softmax(x):
                    x -= np.max(x)
                    x = np.exp(x) / np.sum(np.exp(x))
                    return x

                distance += 1  # prevent division by zero in the reciprocal
                weight = softmax(np.reciprocal(distance))
                keep = np.random.choice(
                    len(atoms), self.max_atoms, replace=False, p=weight
                )
                atoms = atoms[keep]
                coordinates = coordinates[keep]
                holo_coordinates = holo_coordinates[keep]

        dd[self.atoms] = atoms
        dd[self.coordinates] = coordinates.astype(np.float32)
        dd[self.holo_coordinates] = holo_coordinates.astype(np.float32)
        return dd

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
| 221 |
+
|
| 222 |
+
class CroppingPocketDockingPoseTestDataset(BaseWrapperDataset):
    """Inference-time pocket cropping: same centroid-weighted sampling as
    the training variant, but seeded with constants (1, 1) so the crop is
    fully deterministic and independent of epoch and seed."""

    def __init__(
        self, dataset, seed, atoms, coordinates, max_atoms=256
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.max_atoms = max_atoms

        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        sample = self.dataset[index].copy()
        atom_arr = sample[self.atoms]
        coords = sample[self.coordinates]

        # Keep atoms nearest the pocket centroid with higher probability.
        if self.max_atoms and len(atom_arr) > self.max_atoms:
            with data_utils.numpy_seed(1, 1):
                # +1 avoids division by zero in the reciprocal.
                dist = np.linalg.norm(coords - coords.mean(axis=0), axis=1) + 1
                inv = np.reciprocal(dist)
                shifted = inv - np.max(inv)
                weight = np.exp(shifted) / np.sum(np.exp(shifted))
                keep = np.random.choice(
                    len(atom_arr), self.max_atoms, replace=False, p=weight
                )
                atom_arr = atom_arr[keep]
                coords = coords[keep]

        sample[self.atoms] = atom_arr
        sample[self.coordinates] = coords.astype(np.float32)
        return sample

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
unimol/data/data_utils.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) DP Technology.
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import contextlib
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Temporarily reseed the global NumPy RNG.

    Inside the ``with`` block the RNG is seeded from *seed* (mixed with any
    *addl_seeds* via ``hash``); on exit the previous RNG state is restored.
    A ``None`` seed turns the context manager into a no-op.
    """
    if seed is None:
        yield
        return
    if addl_seeds:
        seed = int(hash((seed, *addl_seeds)) % 1e6)
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(saved_state)
|
unimol/data/dictionary.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) DP Technology.
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
| 11 |
+
|
| 12 |
+
class DecoderDictionary:
    """A mapping from symbols to consecutive integers"""

    def __init__(
        self,
        *,  # begin keyword-only arguments
        bos="[CLS]",
        pad="[PAD]",
        eos="[SEP]",
        unk="[UNK]",
        extra_special_symbols=None,
    ):
        # NOTE(review): ``extra_special_symbols`` is accepted but unused here —
        # confirm whether callers expect it to be registered.
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.symbols = []      # index -> symbol string
        self.count = []        # index -> accumulated occurrence count
        self.indices = {}      # symbol string -> index
        self.idx2sym = {}      # index -> symbol string (mirror of symbols)
        self.specials = set()  # symbols treated as special tokens
        self.specials.add(bos)
        self.specials.add(unk)
        self.specials.add(pad)
        self.specials.add(eos)

    def __eq__(self, other):
        # Equality compares only the symbol->index mapping; counts and
        # specials are ignored.
        return self.indices == other.indices

    def __getitem__(self, idx):
        # Out-of-range indices fall back to the unknown-word symbol.
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def __contains__(self, sym):
        return sym in self.indices

    def vec_index(self, a):
        # Element-wise index lookup over an array-like of symbols.
        return np.vectorize(self.index)(a)

    def index(self, sym):
        """Returns the index of the specified symbol"""
        assert isinstance(sym, str)
        if sym in self.indices:
            return self.indices[sym]
        # Unknown symbols map to the index of the unk word (which must have
        # been added via add_symbol first).
        return self.indices[self.unk_word]

    def index2symbol(self, idx):
        """Returns the corresponding symbol of the specified index"""
        assert isinstance(idx, int)
        if idx in self.idx2sym:
            return self.idx2sym[idx]
        return self.unk_word

    def special_index(self):
        # Indices of all registered special symbols.
        return [self.index(x) for x in self.specials]

    def add_symbol(self, word, n=1, overwrite=False, is_special=False):
        """Adds a word to the dictionary"""
        if is_special:
            self.specials.add(word)
        if word in self.indices and not overwrite:
            # Existing word: just accumulate its count.
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            # New (or overwritten) word gets the next consecutive index.
            idx = len(self.symbols)
            self.indices[word] = idx
            self.idx2sym[idx] = word
            self.symbols.append(word)
            self.count.append(n)
            return idx

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.index(self.bos_word)

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.index(self.pad_word)

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.index(self.eos_word)

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.index(self.unk_word)

    @classmethod
    def load(cls, f):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```
        """
        d = cls()
        d.add_from_file(f)
        return d

    def add_from_file(self, f):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.
        """
        if isinstance(f, str):
            try:
                with open(f, "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except FileNotFoundError as fnfe:
                raise fnfe
            except UnicodeError:
                raise Exception(
                    "Incorrect encoding detected in {}, please "
                    "rebuild the dataset".format(f)
                )
            return

        lines = f.readlines()

        for line_idx, line in enumerate(lines):
            try:
                # Split off the trailing count; lines without one default to
                # a descending synthetic count based on file position.
                splits = line.rstrip().rsplit(" ", 1)
                line = splits[0]
                field = splits[1] if len(splits) > 1 else str(len(lines) - line_idx)
                if field == "#overwrite":
                    # "<word> <count> #overwrite" re-registers the word.
                    overwrite = True
                    line, field = line.rsplit(" ", 1)
                else:
                    overwrite = False
                count = int(field)
                word = line
                if word in self and not overwrite:
                    logger.info(
                        "Duplicate word found when loading Dictionary: '{}', index is {}.".format(word, self.indices[word])
                    )
                else:
                    self.add_symbol(word, n=count, overwrite=overwrite)
            except ValueError:
                raise ValueError(
                    "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
                )
|
unimol/data/distance_dataset.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) DP Technology.
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from scipy.spatial import distance_matrix
|
| 8 |
+
from functools import lru_cache
|
| 9 |
+
from unicore.data import BaseWrapperDataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DistanceDataset(BaseWrapperDataset):
    """Yields the pairwise ``(N, N)`` Euclidean distance matrix of each
    item's coordinates (items are reshaped to ``(N, 3)`` first)."""

    def __init__(self, dataset):
        super().__init__(dataset)
        self.dataset = dataset

    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        coords = self.dataset[idx].view(-1, 3).numpy()
        pairwise = distance_matrix(coords, coords).astype(np.float32)
        return torch.from_numpy(pairwise)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class EdgeTypeDataset(BaseWrapperDataset):
    """Maps each ordered pair of atom tokens (i, j) to a single edge-type
    id, ``token_i * num_types + token_j``, as an ``(N, N)`` tensor."""

    def __init__(self, dataset: torch.utils.data.Dataset, num_types: int):
        self.dataset = dataset
        self.num_types = num_types

    @lru_cache(maxsize=16)
    def __getitem__(self, index: int):
        tokens = self.dataset[index].clone()
        return tokens.view(-1, 1) * self.num_types + tokens.view(1, -1)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class CrossDistanceDataset(BaseWrapperDataset):
    """Distance matrix between molecule atoms (rows) and pocket atoms
    (columns) for the same index in two parallel coordinate datasets."""

    def __init__(self, mol_dataset, pocket_dataset):
        super().__init__(mol_dataset)
        self.dataset = mol_dataset
        self.mol_dataset = mol_dataset
        self.pocket_dataset = pocket_dataset

    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        mol_xyz = self.mol_dataset[idx].view(-1, 3).numpy()
        pocket_xyz = self.pocket_dataset[idx].view(-1, 3).numpy()
        dist = distance_matrix(mol_xyz, pocket_xyz).astype(np.float32)
        # Sanity: rows/cols line up with the two source datasets.
        assert dist.shape[0] == self.mol_dataset[idx].shape[0]
        assert dist.shape[1] == self.pocket_dataset[idx].shape[0]
        return torch.from_numpy(dist)
|
| 51 |
+
|
| 52 |
+
class CrossEdgeTypeDataset(BaseWrapperDataset):
    """Pair-type ids between molecule tokens (rows) and pocket tokens
    (columns): ``mol_token * num_types + pocket_token``."""

    def __init__(self, mol_dataset, pocket_dataset, num_types: int):
        self.dataset = mol_dataset
        self.mol_dataset = mol_dataset
        self.pocket_dataset = pocket_dataset
        self.num_types = num_types

    @lru_cache(maxsize=16)
    def __getitem__(self, index: int):
        rows = self.mol_dataset[index].clone()
        cols = self.pocket_dataset[index].clone()
        return rows.view(-1, 1) * self.num_types + cols.view(1, -1)
|
unimol/data/from_str_dataset.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from functools import lru_cache
|
| 3 |
+
from unicore.data import UnicoreDataset
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class FromStrLabelDataset(UnicoreDataset):
    """Wraps a sequence of string-encoded numeric labels; batches collate
    into a 1-D float tensor."""

    def __init__(self, labels):
        super().__init__()
        self.labels = labels

    @lru_cache(maxsize=16)
    def __getitem__(self, index):
        return self.labels[index]

    def __len__(self):
        return len(self.labels)

    def collater(self, samples):
        """Parse each sample as float and stack into one tensor."""
        return torch.tensor([float(s) for s in samples])
|
unimol/data/key_dataset.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) DP Technology.
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
|
| 5 |
+
from functools import lru_cache
|
| 6 |
+
from unicore.data import BaseWrapperDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class KeyDataset(BaseWrapperDataset):
    """Projects each dict-valued item of the wrapped dataset onto a single
    key."""

    def __init__(self, dataset, key):
        self.dataset = dataset
        self.key = key

    def __len__(self):
        return len(self.dataset)

    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        item = self.dataset[idx]
        return item[self.key]
|
| 20 |
+
|
| 21 |
+
class LengthDataset(BaseWrapperDataset):
    """Yields ``len(item)`` for each item of the wrapped dataset."""

    def __init__(self, dataset):
        super().__init__(dataset)

    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        return len(self.dataset[idx])
|
unimol/data/lmdb_dataset.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) DP Technology.
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
import lmdb
|
| 7 |
+
import os
|
| 8 |
+
import pickle
|
| 9 |
+
from functools import lru_cache
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LMDBDataset:
    """Read-only random-access view over a single-file LMDB database whose
    values are pickled samples keyed by the ASCII-encoded integer index."""

    def __init__(self, db_path):
        self.db_path = db_path
        assert os.path.isfile(self.db_path), "{} not found".format(self.db_path)
        env = self.connect_db(self.db_path)
        with env.begin() as txn:
            # Snapshot the key list once so __len__ stays cheap.
            self._keys = list(txn.cursor().iternext(values=False))

    def connect_db(self, lmdb_path, save_to_self=False):
        """Open the LMDB environment read-only.  Either return it, or stash
        it on ``self`` for lazy per-worker reuse in ``__getitem__``."""
        env = lmdb.open(
            lmdb_path,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=256,
        )
        if save_to_self:
            self.env = env
        else:
            return env

    def __len__(self):
        return len(self._keys)

    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        # Open lazily so each dataloader worker gets its own environment.
        if not hasattr(self, "env"):
            self.connect_db(self.db_path, save_to_self=True)
        raw = self.env.begin().get(f"{idx}".encode("ascii"))
        return pickle.loads(raw)
|
unimol/data/mask_points_dataset.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) DP Technology.
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
|
| 5 |
+
from functools import lru_cache
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from unicore.data import Dictionary
|
| 10 |
+
from unicore.data import BaseWrapperDataset
|
| 11 |
+
from . import data_utils
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class MaskPointsDataset(BaseWrapperDataset):
    """BERT-style masked-token corruption over atom tokens, with matching
    coordinate noise on the masked positions.

    For each sample, roughly ``mask_prob`` of positions are selected.  A
    selected token becomes ``mask_idx``, is left unchanged with probability
    ``leave_unmasked_prob``, or is replaced by a random non-special vocab
    token with probability ``random_token_prob``.  The returned dict holds:

    - ``targets``: original tokens at masked positions, ``pad_idx`` elsewhere
    - ``atoms``: the corrupted token sequence
    - ``coordinates``: coordinates with noise added at masked positions
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        coord_dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        noise_type: str,
        noise: float = 1.0,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
    ):
        assert 0.0 < mask_prob < 1.0
        assert 0.0 <= random_token_prob <= 1.0
        assert 0.0 <= leave_unmasked_prob <= 1.0
        assert random_token_prob + leave_unmasked_prob <= 1.0

        self.dataset = dataset
        self.coord_dataset = coord_dataset
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.noise_type = noise_type
        self.noise = noise
        self.seed = seed
        self.mask_prob = mask_prob
        self.leave_unmasked_prob = leave_unmasked_prob
        self.random_token_prob = random_token_prob

        if random_token_prob > 0.0:
            # Uniform distribution over non-special tokens for the
            # random-replacement branch.
            weights = np.ones(len(self.vocab))
            weights[vocab.special_index()] = 0
            self.weights = weights / weights.sum()

        self.epoch = None
        # noise_f: num_mask -> (num_mask, 3) coordinate perturbation
        # (or scalar 0.0 for unknown noise types, i.e. "no noise").
        if self.noise_type == "trunc_normal":
            self.noise_f = lambda num_mask: np.clip(
                np.random.randn(num_mask, 3) * self.noise,
                a_min=-self.noise * 2.0,
                a_max=self.noise * 2.0,
            )
        elif self.noise_type == "normal":
            self.noise_f = lambda num_mask: np.random.randn(num_mask, 3) * self.noise
        elif self.noise_type == "uniform":
            self.noise_f = lambda num_mask: np.random.uniform(
                low=-self.noise, high=self.noise, size=(num_mask, 3)
            )
        else:
            self.noise_f = lambda num_mask: 0.0

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.coord_dataset.set_epoch(epoch)
        self.dataset.set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index: int):
        return self.__getitem_cached__(self.epoch, index)

    @lru_cache(maxsize=16)
    def __getitem_cached__(self, epoch: int, index: int):
        # NOTE(review): lru_cache on an instance method keeps `self` alive in
        # the cache (ruff B019); acceptable here since datasets are long-lived.
        ret = {}
        with data_utils.numpy_seed(self.seed, epoch, index):
            item = self.dataset[index]
            coord = self.coord_dataset[index]
            sz = len(item)
            # don't allow empty sequence
            assert sz > 0
            # decide elements to mask
            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_prob * sz
                + np.random.rand()
            )
            mask_idc = np.random.choice(sz, num_mask, replace=False)
            mask = np.full(sz, False)
            mask[mask_idc] = True
            # targets: original token at masked positions, pad_idx elsewhere
            ret["targets"] = np.full(len(mask), self.pad_idx)
            ret["targets"][mask] = item[mask]
            ret["targets"] = torch.from_numpy(ret["targets"]).long()
            # decide unmasking and random replacement
            rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
            if rand_or_unmask_prob > 0.0:
                rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
                if self.random_token_prob == 0.0:
                    unmask = rand_or_unmask
                    rand_mask = None
                elif self.leave_unmasked_prob == 0.0:
                    unmask = None
                    rand_mask = rand_or_unmask
                else:
                    unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                    decision = np.random.rand(sz) < unmask_prob
                    unmask = rand_or_unmask & decision
                    rand_mask = rand_or_unmask & (~decision)
            else:
                unmask = rand_mask = None

            if unmask is not None:
                # "leave unmasked" positions stay prediction targets but keep
                # their original token: clear their bit from the mask.
                mask = mask ^ unmask

            new_item = np.copy(item)
            new_item[mask] = self.mask_idx

            num_mask = mask.astype(np.int32).sum()
            new_coord = np.copy(coord)
            new_coord[mask, :] += self.noise_f(num_mask)

            if rand_mask is not None:
                num_rand = rand_mask.sum()
                if num_rand > 0:
                    new_item[rand_mask] = np.random.choice(
                        len(self.vocab),
                        num_rand,
                        p=self.weights,
                    )
            ret["atoms"] = torch.from_numpy(new_item).long()
            ret["coordinates"] = torch.from_numpy(new_coord).float()
        return ret
| 137 |
+
|
| 138 |
+
class MaskPointsPocketDataset(BaseWrapperDataset):
    """Residue-level variant of MaskPointsDataset for protein pockets.

    Masking decisions are made per residue — every atom of a selected
    residue is masked together — and then token replacement and coordinate
    noise follow the same scheme as the atom-level dataset.  The returned
    dict holds ``targets``, ``atoms`` and ``coordinates`` (see
    MaskPointsDataset).
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        coord_dataset: torch.utils.data.Dataset,
        residue_dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        noise_type: str,
        noise: float = 1.0,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
    ):
        assert 0.0 < mask_prob < 1.0
        assert 0.0 <= random_token_prob <= 1.0
        assert 0.0 <= leave_unmasked_prob <= 1.0
        assert random_token_prob + leave_unmasked_prob <= 1.0

        self.dataset = dataset
        self.coord_dataset = coord_dataset
        self.residue_dataset = residue_dataset
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.noise_type = noise_type
        self.noise = noise
        self.seed = seed
        self.mask_prob = mask_prob
        self.leave_unmasked_prob = leave_unmasked_prob
        self.random_token_prob = random_token_prob

        if random_token_prob > 0.0:
            # Uniform distribution over non-special tokens for the
            # random-replacement branch.
            weights = np.ones(len(self.vocab))
            weights[vocab.special_index()] = 0
            self.weights = weights / weights.sum()

        self.epoch = None
        # noise_f: num_mask -> (num_mask, 3) coordinate perturbation
        # (or scalar 0.0 for unknown noise types, i.e. "no noise").
        if self.noise_type == "trunc_normal":
            self.noise_f = lambda num_mask: np.clip(
                np.random.randn(num_mask, 3) * self.noise,
                a_min=-self.noise * 2.0,
                a_max=self.noise * 2.0,
            )
        elif self.noise_type == "normal":
            self.noise_f = lambda num_mask: np.random.randn(num_mask, 3) * self.noise
        elif self.noise_type == "uniform":
            self.noise_f = lambda num_mask: np.random.uniform(
                low=-self.noise, high=self.noise, size=(num_mask, 3)
            )
        else:
            self.noise_f = lambda num_mask: 0.0

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.coord_dataset.set_epoch(epoch)
        self.dataset.set_epoch(epoch)
        # Keep the residue wrapper in sync too (the original skipped it);
        # guarded because plain datasets may not expose set_epoch.
        if hasattr(self.residue_dataset, "set_epoch"):
            self.residue_dataset.set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index: int):
        return self.__getitem_cached__(self.epoch, index)

    @lru_cache(maxsize=16)
    def __getitem_cached__(self, epoch: int, index: int):
        # NOTE(review): lru_cache on an instance method keeps `self` alive in
        # the cache (ruff B019); acceptable here since datasets are long-lived.
        ret = {}
        with data_utils.numpy_seed(self.seed, epoch, index):
            item = self.dataset[index]
            coord = self.coord_dataset[index]
            sz = len(item)
            # don't allow empty sequence
            assert sz > 0

            # mask on the level of residues
            residue = self.residue_dataset[index]
            # np.unique yields a deterministically (sorted) ordered array.
            # The original list(set(residue)) depended on hash order, so the
            # residues drawn by np.random.choice differed between runs even
            # under the explicit numpy seed above.
            res_list = np.unique(residue)
            res_sz = len(res_list)

            # decide elements to mask
            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_prob * res_sz
                + np.random.rand()
            )
            mask_res = np.random.choice(res_list, num_mask, replace=False).tolist()
            mask = np.isin(residue, mask_res)

            # targets: original token at masked positions, pad_idx elsewhere
            ret["targets"] = np.full(len(mask), self.pad_idx)
            ret["targets"][mask] = item[mask]
            ret["targets"] = torch.from_numpy(ret["targets"]).long()
            # decide unmasking and random replacement
            rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
            if rand_or_unmask_prob > 0.0:
                rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
                if self.random_token_prob == 0.0:
                    unmask = rand_or_unmask
                    rand_mask = None
                elif self.leave_unmasked_prob == 0.0:
                    unmask = None
                    rand_mask = rand_or_unmask
                else:
                    unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                    decision = np.random.rand(sz) < unmask_prob
                    unmask = rand_or_unmask & decision
                    rand_mask = rand_or_unmask & (~decision)
            else:
                unmask = rand_mask = None

            if unmask is not None:
                # "leave unmasked" positions stay prediction targets but keep
                # their original token: clear their bit from the mask.
                mask = mask ^ unmask

            new_item = np.copy(item)
            new_item[mask] = self.mask_idx

            num_mask = mask.astype(np.int32).sum()
            new_coord = np.copy(coord)
            new_coord[mask, :] += self.noise_f(num_mask)

            if rand_mask is not None:
                num_rand = rand_mask.sum()
                if num_rand > 0:
                    new_item[rand_mask] = np.random.choice(
                        len(self.vocab),
                        num_rand,
                        p=self.weights,
                    )
            ret["atoms"] = torch.from_numpy(new_item).long()
            ret["coordinates"] = torch.from_numpy(new_coord).float()
        return ret
unimol/data/normalize_dataset.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) DP Technology.
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
from unicore.data import BaseWrapperDataset
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class NormalizeDataset(BaseWrapperDataset):
    """Wrapper that re-centers each sample's coordinate array at its own
    centroid (zero-mean translation) and casts it to float32."""

    def __init__(self, dataset, coordinates, normalize_coord=True):
        self.dataset = dataset
        self.coordinates = coordinates  # dict key under which coords live
        self.normalize_coord = normalize_coord  # if False, only cast dtype
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        sample = self.dataset[index].copy()
        coords = sample[self.coordinates]
        # Translate so the centroid sits at the origin (when enabled).
        centered = coords - coords.mean(axis=0) if self.normalize_coord else coords
        sample[self.coordinates] = centered.astype(np.float32)
        return sample

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class NormalizeDockingPoseDataset(BaseWrapperDataset):
    """Shifts ligand and pocket coordinates into the pocket-centroid frame.

    Both coordinate arrays are translated by the pocket centroid, which is
    also stored under ``center_coordinates`` so the original frame can be
    recovered downstream.
    """

    def __init__(
        self,
        dataset,
        coordinates,
        pocket_coordinates,
        center_coordinates="center_coordinates",
    ):
        self.dataset = dataset
        self.coordinates = coordinates
        self.pocket_coordinates = pocket_coordinates
        self.center_coordinates = center_coordinates
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        sample = self.dataset[index].copy()
        lig_coords = sample[self.coordinates]
        pkt_coords = sample[self.pocket_coordinates]
        # Align both point sets with the pocket's center of geometry.
        center = pkt_coords.mean(axis=0)
        sample[self.coordinates] = (lig_coords - center).astype(np.float32)
        sample[self.pocket_coordinates] = (pkt_coords - center).astype(np.float32)
        sample[self.center_coordinates] = center.astype(np.float32)
        return sample

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
|
unimol/data/pair_dataset.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os.path
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
from functools import lru_cache
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from unicore.data import UnicoreDataset
|
| 9 |
+
import numpy as np
|
| 10 |
+
from . import data_utils
|
| 11 |
+
import rdkit
|
| 12 |
+
from rdkit import Chem
|
| 13 |
+
from rdkit import DataStructs
|
| 14 |
+
from rdkit.Chem import rdFingerprintGenerator
|
| 15 |
+
from multiprocessing import Pool
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
|
| 18 |
+
def get_fp(smiles):
    """Compute a Morgan count fingerprint for a SMILES string.

    Returns a numpy int8 array, or None when RDKit cannot parse the SMILES.
    (Original code allocated the target buffer before the None check;
    reordered so the unparsable path does no wasted work.)
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    fp_numpy = np.zeros((0,), np.int8)  # target buffer filled in place by RDKit
    fingerprints_vect = rdFingerprintGenerator.GetCountFPs(
        [mol], fpType=rdFingerprintGenerator.MorganFP
    )[0]
    DataStructs.ConvertToNumpyArray(fingerprints_vect, fp_numpy)
    return fp_numpy
|
| 28 |
+
|
| 29 |
+
class PairDataset(UnicoreDataset):
    """Assay-level dataset pairing one pocket with a sampled set of ligands.

    Each entry of ``labels`` describes one assay: candidate pocket names, a
    ligand list (SMILES + activity), the target's UniProt id and sequence.
    ``__getitem__`` picks one pocket and up to ``max_lignum`` ligands for
    that assay, resolving them to samples from ``pocket_dataset`` /
    ``mol_dataset`` via name/SMILES lookup tables.
    """

    def __init__(self, args, pocket_dataset, mol_dataset, labels, split, use_cache=True, cache_dir=None):
        self.args = args
        self.pocket_dataset = pocket_dataset
        self.mol_dataset = mol_dataset
        self.labels = labels

        # Building these lookup tables needs one full pass over the (lazy)
        # datasets, which is very slow; cache them on disk when allowed.
        self.pocket_name2idx = self._load_or_build_index(
            f"{cache_dir}/cache/pocket_name2idx_train_blend.json" if use_cache else None,
            lambda: {x["pocket_name"]: i for i, x in enumerate(self.pocket_dataset)},
        )
        self.mol_smi2idx = self._load_or_build_index(
            f"{cache_dir}/cache/mol_smi2idx_train_blend.json" if use_cache else None,
            lambda: {x["smi_name"]: i for i, x in enumerate(self.mol_dataset)},
        )

        uniprot_ids = [x["uniprot"] for x in labels]
        self.uniprot_id_dict = {x: i for i, x in enumerate(set(uniprot_ids))}
        self.split = split
        if self.split == "train":
            self.max_lignum = args.max_lignum  # default=16
        else:
            self.max_lignum = args.test_max_lignum  # default 512

        if self.split == "train":
            # Repeat each assay's index proportionally to its ligand count so
            # large assays are visited multiple times per epoch.
            trainidxmap = []
            for idx, assay_item in enumerate(self.labels):
                lig_info = assay_item["ligands"]
                trainidxmap += [idx] * math.ceil(len(lig_info) / max(self.max_lignum, 32))
            self.trainidxmap = trainidxmap

        self.epoch = 0

    @staticmethod
    def _load_or_build_index(cache_file, build):
        """Return the mapping stored at ``cache_file`` if it exists; otherwise
        build it via ``build()`` and persist it (when ``cache_file`` is set).

        Fixes vs. original: file handles are closed (`with open`), and the
        cache directory is created instead of crashing when it is missing.
        """
        if cache_file is not None and os.path.exists(cache_file):
            with open(cache_file) as f:
                return json.load(f)
        index = build()
        if cache_file is not None:
            os.makedirs(os.path.dirname(cache_file), exist_ok=True)
            with open(cache_file, "w") as f:
                json.dump(index, f)
        return index

    def __len__(self):
        if self.split == "train":
            # WORLD_SIZE is set by torch.distributed launchers; default to 1
            # so single-process runs don't crash with KeyError.
            world_size = int(os.environ.get("WORLD_SIZE", 1))
            div = self.args.batch_size * world_size
            # Truncate so every rank sees the same number of full batches.
            return (len(self.trainidxmap) // div) * div
        else:
            return len(self.labels)

    def set_epoch(self, epoch):
        self.epoch = epoch
        self.pocket_dataset.set_epoch(epoch)
        self.mol_dataset.set_epoch(epoch)
        super().set_epoch(epoch)

    def collater(self, samples):
        """Collate assay tuples into one batch.

        Ligands of all assays are flattened into a single molecule batch;
        ``batch_list`` records each assay's [start, end) span in it.
        Returns {} for an empty sample list.
        """
        ret_pocket = []
        ret_lig = []
        batch_list = []
        act_list = []
        uniprot_list = []
        ret_protein = []
        assay_id_list = []

        if len(samples) == 0:
            return {}
        for pocket, ligs, acts, uniprot, assay_id, prot_seq in samples:
            ret_pocket.append(pocket)
            lignum_old = len(ret_lig)
            ret_lig += ligs
            batch_list.append([lignum_old, len(ret_lig)])
            uniprot_list.append(self.uniprot_id_dict[uniprot])
            assay_id_list.append(assay_id)
            act_list.append(acts)
            ret_protein.append(prot_seq)

        ret_pocket = self.pocket_dataset.collater(ret_pocket)
        ret_lig = self.mol_dataset.collater(ret_lig)
        return {"pocket": ret_pocket, "lig": ret_lig, "protein": ret_protein,
                "batch_list": batch_list, "act_list": act_list,
                "uniprot_list": uniprot_list, "assay_id_list": assay_id_list}

    def __getitem__(self, idx):
        """Return (pocket_data, lig_data, lig_act, uniprot, assay_id,
        prot_seq) for one assay; train indices go through ``trainidxmap``."""
        if self.split == "train":
            t_idx = self.trainidxmap[idx]
        else:
            t_idx = idx

        with data_utils.numpy_seed(1111, idx, self.epoch):
            pocket_name = np.random.choice(self.labels[t_idx]["pockets"], 1, replace=False)[0]

        lig_info = self.labels[t_idx]["ligands"]
        # Drop ligands whose SMILES never made it into the molecule dataset.
        lig_info = [x for x in lig_info if x["smi"] in self.mol_smi2idx]
        uniprot = self.labels[t_idx]["uniprot"]
        assay_id = self.labels[t_idx].get("assay_id", "none")
        prot_seq = self.labels[t_idx]["sequence"]
        if len(lig_info) > self.max_lignum:
            # NOTE(review): reuses the same seed tuple as the pocket draw
            # above, so both draws see an identical freshly-seeded stream
            # per (idx, epoch) — presumably intentional for reproducibility.
            with data_utils.numpy_seed(1111, idx, self.epoch):
                lig_idxes = np.random.choice(list(range(len(lig_info))), self.max_lignum, replace=False)
                lig_idxes = sorted(lig_idxes)
                # renamed loop var: the original reused (shadowed) `idx` here
                lig_info = [lig_info[j] for j in lig_idxes]

        lig_idxes = [self.mol_smi2idx[info["smi"]] for info in lig_info]
        pocket_idx = self.pocket_name2idx[pocket_name]
        lig_act = [info["act"] for info in lig_info]
        pocket_data = self.pocket_dataset[pocket_idx]
        lig_data = [self.mol_dataset[x] for x in lig_idxes]

        return pocket_data, lig_data, lig_act, uniprot, assay_id, prot_seq
|