Irwiny123 commited on
Commit
94391f2
·
1 Parent(s): a7b498a

提交LigUnity初始代码

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -35
  2. .gitignore +165 -0
  3. HGNN/Attention.py +36 -0
  4. HGNN/PL_Aggregator.py +75 -0
  5. HGNN/PL_Encoder.py +51 -0
  6. HGNN/PP_Aggregator.py +43 -0
  7. HGNN/PP_Encoder.py +51 -0
  8. HGNN/align.py +198 -0
  9. HGNN/data/CoreSet.dat +286 -0
  10. HGNN/data/PDBbind_v2020/index/INDEX_general_PL_data.2020 +0 -0
  11. HGNN/data/PDBbind_v2020/index/INDEX_general_PL_name.2020 +0 -0
  12. HGNN/data/PDBbind_v2020/index/INDEX_refined_data.2020 +0 -0
  13. HGNN/data/PDBbind_v2020/index/INDEX_refined_name.2020 +0 -0
  14. HGNN/main.py +318 -0
  15. HGNN/read_fasta.py +112 -0
  16. HGNN/screen_dataset.py +420 -0
  17. HGNN/screening.py +165 -0
  18. HGNN/test_pocket.fasta +2 -0
  19. HGNN/util.py +96 -0
  20. License +159 -0
  21. README.md +206 -3
  22. active_learning_scripts/run_al.sh +22 -0
  23. active_learning_scripts/run_cycle_ensemble.py +334 -0
  24. active_learning_scripts/run_cycle_one_model.py +246 -0
  25. active_learning_scripts/run_model.sh +53 -0
  26. ensemble_result.py +173 -0
  27. py_scripts/__init__.py +0 -0
  28. py_scripts/write_case_study.py +227 -0
  29. test.sh +18 -0
  30. test_fewshot.sh +38 -0
  31. test_fewshot_demo.sh +43 -0
  32. test_zeroshot_demo.sh +20 -0
  33. train.sh +145 -0
  34. unimol/__init__.py +6 -0
  35. unimol/data/__init__.py +50 -0
  36. unimol/data/add_2d_conformer_dataset.py +46 -0
  37. unimol/data/affinity_dataset.py +527 -0
  38. unimol/data/atom_type_dataset.py +34 -0
  39. unimol/data/conformer_sample_dataset.py +315 -0
  40. unimol/data/coord_pad_dataset.py +82 -0
  41. unimol/data/cropping_dataset.py +269 -0
  42. unimol/data/data_utils.py +23 -0
  43. unimol/data/dictionary.py +157 -0
  44. unimol/data/distance_dataset.py +64 -0
  45. unimol/data/from_str_dataset.py +19 -0
  46. unimol/data/key_dataset.py +29 -0
  47. unimol/data/lmdb_dataset.py +49 -0
  48. unimol/data/mask_points_dataset.py +267 -0
  49. unimol/data/normalize_dataset.py +68 -0
  50. unimol/data/pair_dataset.py +144 -0
.gitattributes CHANGED
@@ -1,35 +1,2 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ tmp/
163
+ **/*.ipynb
164
+ *.ipynb
165
+ results/
HGNN/Attention.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torch.nn as nn
from torch.nn import init
import numpy as np
import random
import torch.nn.functional as F


class Attention(nn.Module):
    """MLP attention scoring a set of neighbor embeddings against one center embedding.

    Given ``num_neighs`` neighbor vectors and a single center vector, produces a
    ``(num_neighs, 1)`` column of weights that is softmax-normalized over the
    neighbor axis (the weights sum to 1).
    """

    def __init__(self, embedding_dims):
        super(Attention, self).__init__()
        self.embed_dim = embedding_dims
        # NOTE(review): ``bilinear`` and ``softmax`` are created but never used in
        # forward(); they are kept so parameter names in checkpoints stay identical.
        self.bilinear = nn.Bilinear(self.embed_dim, self.embed_dim, 1)
        self.att1 = nn.Linear(self.embed_dim * 2, self.embed_dim)
        self.att2 = nn.Linear(self.embed_dim, self.embed_dim)
        self.att3 = nn.Linear(self.embed_dim, 1)
        self.softmax = nn.Softmax(0)

    def forward(self, node1, u_rep, num_neighs):
        """Return a (num_neighs, 1) attention weight column for ``node1`` rows."""
        # Tile the center vector so it can be paired with every neighbor row.
        tiled_center = u_rep.repeat(num_neighs, 1)
        scores = torch.cat((node1, tiled_center), 1)
        scores = F.dropout(F.relu(self.att1(scores)), training=self.training)
        scores = F.dropout(F.relu(self.att2(scores)), training=self.training)
        scores = self.att3(scores)
        # Normalize across the neighbor dimension.
        return F.softmax(scores, dim=0)
HGNN/PL_Aggregator.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.autograd import Variable
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import random
7
+ from Attention import Attention
8
+
9
class PLAggregator(nn.Module):
    """
    Pocket-ligand aggregator: attention-pools a pocket's interaction history
    (ligand, activity-label) pairs into a single embedding and averages it with
    the pocket's own embedding.
    """

    def __init__(self, v2e=None, r2e=None, u2e=None, embed_dim=128, cuda="cpu", uv=True):
        super(PLAggregator, self).__init__()
        self.uv = uv
        self.v2e = v2e  # ligand embedding table (presumably nn.Embedding — weight is indexed below)
        self.r2e = r2e  # relation/label embedding table
        self.u2e = u2e  # pocket embedding table
        self.device = cuda
        self.embed_dim = embed_dim
        self.w_r1 = nn.Linear(self.embed_dim * 2, self.embed_dim)
        self.w_r2 = nn.Linear(self.embed_dim, self.embed_dim)
        self.att = Attention(self.embed_dim)
        # BUG FIX: the original code did ``self.v2e.requires_grad = False`` which
        # only attaches a plain bool attribute to the nn.Module and freezes
        # nothing.  Module.requires_grad_(False) actually disables gradients for
        # the (pretrained) embedding parameters, as the code clearly intended.
        if self.v2e is not None:
            self.v2e.requires_grad_(False)
        if self.u2e is not None:
            self.u2e.requires_grad_(False)

    def forward(self, nodes_u, input_hist):
        """Aggregate each pocket's history into one embedding.

        Args:
            nodes_u: indices into the pocket embedding table, one per pocket.
            input_hist: per pocket, a list of (ligand_index, label_index) pairs.

        Returns:
            Tensor of shape (len(input_hist), embed_dim); pockets with an empty
            history fall back to their own embedding.
        """
        embed_matrix = torch.zeros(len(input_hist), self.embed_dim, dtype=torch.float).to(self.device)

        for i in range(len(input_hist)):
            history = [pair[0] for pair in input_hist[i]]
            label = [pair[1] for pair in input_hist[i]]
            num_history = len(history)  # typo fix: was ``num_histroy_item``

            if num_history > 0:
                e_uv = self.v2e.weight[history]
                uv_rep = self.u2e.weight[nodes_u[i]]

                # Fuse each ligand embedding with its label embedding.
                e_r = self.r2e.weight[label]
                x = torch.cat((e_uv, e_r), 1)
                x = F.relu(self.w_r1(x))
                o_history = F.relu(self.w_r2(x))

                # Attention-weighted pooling over the history items.
                att_w = self.att(o_history, uv_rep, num_history)
                att_history = torch.mm(o_history.t(), att_w).t()

                embed_matrix[i] = (att_history + uv_rep) / 2
            else:
                embed_matrix[i] = self.u2e.weight[nodes_u[i]]

        return embed_matrix

    def forward_inference(self, pocket_embed, neighbor_list):
        """Same pooling at inference time, with neighbor embeddings given directly.

        ``neighbor_list`` items are tuples whose element [1] is a ligand
        embedding tensor and [2] a relation index tensor (matches the tuples
        produced by align.get_neighbor_pocket — TODO confirm).
        """
        neighbor_embed = torch.stack([x[1] for x in neighbor_list])
        rel_embed = self.r2e.weight[torch.stack([x[2] for x in neighbor_list])]
        x = torch.cat((neighbor_embed, rel_embed), 1)
        x = F.relu(self.w_r1(x))
        o_neighbor = F.relu(self.w_r2(x))

        att_w = self.att(o_neighbor, pocket_embed, len(neighbor_list))
        att_res = torch.mm(o_neighbor.t(), att_w).t()
        return (att_res + pocket_embed) / 2
HGNN/PL_Encoder.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import random

class PLEncoder(nn.Module):
    """Encodes a pocket via the ligands of its aligned training neighbors.

    For each query pocket it looks up neighbor assays in ``pocket_graph``,
    turns each neighbor's top ligand into a (molecule_index, bucketed_score)
    pair, and hands the resulting neighbor lists to ``aggregator``.
    """

    def __init__(self, embed_dim, pocket_graph=None, aggregator=None, idx2assayid={}, assayid_lst_train=[], mol_smi={}, train_label_lst=[], cuda="cpu", uv=True):
        super(PLEncoder, self).__init__()

        self.uv = uv
        self.pocket_graph = pocket_graph
        self.aggregator = aggregator
        self.embed_dim = embed_dim
        self.device = cuda
        self.idx2assayid = idx2assayid
        self.assayid_lst_train = assayid_lst_train
        self.mol_smi = mol_smi
        self.train_label_lst = train_label_lst
        # Reverse lookup: SMILES string -> molecule index.
        self.smi2idx = {smi: idx for idx, smi in enumerate(mol_smi)}
        self.assayid_set_train = set(assayid_lst_train)
        # assay_id -> full label record (ligands presumably pre-sorted by activity
        # upstream, so [0] is the most active one — TODO confirm).
        self.label_dicts = {record["assay_id"]: record for record in train_label_lst}
        self.linear1 = nn.Linear(2 * self.embed_dim, self.embed_dim)

    def forward(self, nodes_pocket, nodes_lig=None, max_sample=10):
        """Build neighbor lists for ``nodes_pocket`` and aggregate them."""
        if nodes_lig is None:
            # Placeholder SMILES that cannot collide with a real ligand.
            lig_smi_lst = ["----"] * len(nodes_pocket)
        else:
            lig_smi_lst = [self.mol_smi[lig_id] for lig_id in nodes_lig]

        to_neighs = []
        for node, smi in zip(nodes_pocket, lig_smi_lst):
            assayid = self.idx2assayid[node]
            neighbors = []
            for n_assayid, score in self.pocket_graph.get(assayid, []):
                nbr_smi = self.label_dicts[n_assayid]["ligands"][0]["smi"]
                # Skip self-matches, the query ligand itself, and assays that
                # are not part of the training split.
                if (assayid == n_assayid
                        or smi == nbr_smi
                        or n_assayid not in self.assayid_set_train):
                    continue
                # Bucket the alignment score (>= 0.5) into a small integer.
                neighbors.append((self.smi2idx[nbr_smi], int((score - 0.5) * 10)))
            to_neighs.append(neighbors)

        return self.aggregator.forward(nodes_pocket, to_neighs)

    def refine_pocket(self, pocket_embed, neighbor_list=None):
        """Inference-time variant: delegate straight to the aggregator."""
        return self.aggregator.forward_inference(pocket_embed, neighbor_list)
HGNN/PP_Aggregator.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.autograd import Variable
4
+ import numpy as np
5
+ import random
6
+ from Attention import Attention
7
+
8
+
9
+ class PPAggregator(nn.Module):
10
+ """
11
+ Social Aggregator: for aggregating embeddings of social neighbors.
12
+ """
13
+
14
+ def __init__(self, u2e=None, embed_dim=128, cuda="cpu"):
15
+ super(PPAggregator, self).__init__()
16
+ self.device = cuda
17
+ self.u2e = u2e
18
+ self.embed_dim = embed_dim
19
+ self.att = Attention(self.embed_dim)
20
+
21
+ def forward(self, nodes, to_neighs):
22
+ embed_matrix = torch.zeros(len(nodes), self.embed_dim, dtype=torch.float).to(self.device)
23
+ self_feats = self.u2e.weight[nodes]
24
+ for i in range(len(nodes)):
25
+ tmp_adj = to_neighs[i]
26
+
27
+ num_neighs = len(tmp_adj)
28
+
29
+ if num_neighs > 0:
30
+ e_u = self.u2e.weight[[x[0] for x in tmp_adj]] # fast: user embedding
31
+ u_rep = self.u2e.weight[nodes[i]]
32
+ att_w = self.att(e_u, u_rep, num_neighs)
33
+ att_history = torch.mm(e_u.t(), att_w).t()
34
+ embed_matrix[i] = (att_history + self_feats[i]) / 2
35
+ else:
36
+ embed_matrix[i] = self_feats[i]
37
+ return embed_matrix
38
+
39
+ def forward_inference(self, pocket_embed, neighbor_list):
40
+ neighbor_embed = torch.stack([x[0] for x in neighbor_list])
41
+ att_w = self.att(neighbor_embed, pocket_embed, len(neighbor_list))
42
+ att_res = torch.mm(neighbor_embed.t(), att_w).t()
43
+ return (att_res + pocket_embed) / 2
HGNN/PP_Encoder.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import random
import copy

class PPEncoder(nn.Module):
    """Pocket-pocket encoder: averages a pocket's own encoding with an
    attention-pooled encoding of its aligned training neighbors."""

    def __init__(self, pocket_encoder, embed_dim, pocket_graph=None, aggregator=None, assayid_lst_all=[], assayid_lst_train=[], base_model=None, cuda="cpu"):
        super(PPEncoder, self).__init__()

        self.pocket_encoder = pocket_encoder
        self.pocket_graph = pocket_graph
        self.aggregator = aggregator
        # Attribute is only present when a base model was supplied (as in the original).
        if base_model is not None:
            self.base_model = base_model
        self.embed_dim = embed_dim
        self.device = cuda
        self.linear1 = nn.Linear(2 * self.embed_dim, self.embed_dim)
        self.assayid_lst_all = assayid_lst_all
        self.assayid_set_train = set(assayid_lst_train)
        # assay id -> every position where it occurs in assayid_lst_all.
        self.assayid2idxes = {}
        for idx, aid in enumerate(assayid_lst_all):
            self.assayid2idxes.setdefault(aid, []).append(idx)

    def forward(self, nodes_pocket, nodes_lig=None, max_sample=10):
        """Return the mean of the self encoding and the neighbor aggregation."""
        to_neighs = []
        for node in nodes_pocket:
            assayid = self.assayid_lst_all[node]
            neighbors = []
            for n_assayid, score in self.pocket_graph.get(assayid, []):
                # Skip self-matches and assays outside the training split.
                if n_assayid == assayid or n_assayid not in self.assayid_set_train:
                    continue
                # A neighbor assay may appear at several positions; sample one.
                neighbors.append((random.choices(self.assayid2idxes[n_assayid])[0], score))
            to_neighs.append(neighbors)

        neigh_feats = self.aggregator.forward(nodes_pocket, to_neighs)
        self_feats = self.pocket_encoder(nodes_pocket, nodes_lig, max_sample)
        return (self_feats + neigh_feats) / 2

    def refine_pocket(self, pocket_embed, neighbor_list=None):
        """Inference-time variant operating on precomputed embeddings."""
        neigh_feats = self.aggregator.forward_inference(pocket_embed, neighbor_list)
        self_feats = self.pocket_encoder.refine_pocket(pocket_embed, neighbor_list)
        return (self_feats + neigh_feats) / 2
HGNN/align.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import skbio
3
+ import json, pickle, os
4
+ from skbio import alignment
5
+ from skbio import Protein
6
+ from tqdm import tqdm
7
+ from multiprocessing import Pool
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ cutoff = 5.0
13
+ blosum50 = \
14
+ {
15
+ '*': {'*': 1, 'A': -5, 'C': -5, 'B': -5, 'E': -5, 'D': -5, 'G': -5,
16
+ 'F': -5, 'I': -5, 'H': -5, 'K': -5, 'M': -5, 'L': -5,
17
+ 'N': -5, 'Q': -5, 'P': -5, 'S': -5, 'R': -5, 'T': -5,
18
+ 'W': -5, 'V': -5, 'Y': -5, 'X': -5, 'Z': -5},
19
+ 'A': {'*': -5, 'A': 5, 'C': -1, 'B': -2, 'E': -1, 'D': -2, 'G': 0,
20
+ 'F': -3, 'I': -1, 'H': -2, 'K': -1, 'M': -1, 'L': -2,
21
+ 'N': -1, 'Q': -1, 'P': -1, 'S': 1, 'R': -2, 'T': 0, 'W': -3,
22
+ 'V': 0, 'Y': -2, 'X': -1, 'Z': -1},
23
+ 'C': {'*': -5, 'A': -1, 'C': 13, 'B': -3, 'E': -3, 'D': -4,
24
+ 'G': -3, 'F': -2, 'I': -2, 'H': -3, 'K': -3, 'M': -2,
25
+ 'L': -2, 'N': -2, 'Q': -3, 'P': -4, 'S': -1, 'R': -4,
26
+ 'T': -1, 'W': -5, 'V': -1, 'Y': -3, 'X': -1, 'Z': -3},
27
+ 'B': {'*': -5, 'A': -2, 'C': -3, 'B': 6, 'E': 1, 'D': 6, 'G': -1,
28
+ 'F': -4, 'I': -4, 'H': 0, 'K': 0, 'M': -3, 'L': -4, 'N': 5,
29
+ 'Q': 0, 'P': -2, 'S': 0, 'R': -1, 'T': 0, 'W': -5, 'V': -3,
30
+ 'Y': -3, 'X': -1, 'Z': 1},
31
+ 'E': {'*': -5, 'A': -1, 'C': -3, 'B': 1, 'E': 6, 'D': 2, 'G': -3,
32
+ 'F': -3, 'I': -4, 'H': 0, 'K': 1, 'M': -2, 'L': -3, 'N': 0,
33
+ 'Q': 2, 'P': -1, 'S': -1, 'R': 0, 'T': -1, 'W': -3, 'V': -3,
34
+ 'Y': -2, 'X': -1, 'Z': 5},
35
+ 'D': {'*': -5, 'A': -2, 'C': -4, 'B': 6, 'E': 2, 'D': 8, 'G': -1,
36
+ 'F': -5, 'I': -4, 'H': -1, 'K': -1, 'M': -4, 'L': -4, 'N': 2,
37
+ 'Q': 0, 'P': -1, 'S': 0, 'R': -2, 'T': -1, 'W': -5, 'V': -4,
38
+ 'Y': -3, 'X': -1, 'Z': 1},
39
+ 'G': {'*': -5, 'A': 0, 'C': -3, 'B': -1, 'E': -3, 'D': -1, 'G': 8,
40
+ 'F': -4, 'I': -4, 'H': -2, 'K': -2, 'M': -3, 'L': -4, 'N': 0,
41
+ 'Q': -2, 'P': -2, 'S': 0, 'R': -3, 'T': -2, 'W': -3, 'V': -4,
42
+ 'Y': -3, 'X': -1, 'Z': -2},
43
+ 'F': {'*': -5, 'A': -3, 'C': -2, 'B': -4, 'E': -3, 'D': -5,
44
+ 'G': -4, 'F': 8, 'I': 0, 'H': -1, 'K': -4, 'M': 0, 'L': 1,
45
+ 'N': -4, 'Q': -4, 'P': -4, 'S': -3, 'R': -3, 'T': -2, 'W': 1,
46
+ 'V': -1, 'Y': 4, 'X': -1, 'Z': -4},
47
+ 'I': {'*': -5, 'A': -1, 'C': -2, 'B': -4, 'E': -4, 'D': -4,
48
+ 'G': -4, 'F': 0, 'I': 5, 'H': -4, 'K': -3, 'M': 2, 'L': 2,
49
+ 'N': -3, 'Q': -3, 'P': -3, 'S': -3, 'R': -4, 'T': -1,
50
+ 'W': -3, 'V': 4, 'Y': -1, 'X': -1, 'Z': -3},
51
+ 'H': {'*': -5, 'A': -2, 'C': -3, 'B': 0, 'E': 0, 'D': -1, 'G': -2,
52
+ 'F': -1, 'I': -4, 'H': 10, 'K': 0, 'M': -1, 'L': -3, 'N': 1,
53
+ 'Q': 1, 'P': -2, 'S': -1, 'R': 0, 'T': -2, 'W': -3, 'V': -4,
54
+ 'Y': 2, 'X': -1, 'Z': 0},
55
+ 'K': {'*': -5, 'A': -1, 'C': -3, 'B': 0, 'E': 1, 'D': -1, 'G': -2,
56
+ 'F': -4, 'I': -3, 'H': 0, 'K': 6, 'M': -2, 'L': -3, 'N': 0,
57
+ 'Q': 2, 'P': -1, 'S': 0, 'R': 3, 'T': -1, 'W': -3, 'V': -3,
58
+ 'Y': -2, 'X': -1, 'Z': 1},
59
+ 'M': {'*': -5, 'A': -1, 'C': -2, 'B': -3, 'E': -2, 'D': -4,
60
+ 'G': -3, 'F': 0, 'I': 2, 'H': -1, 'K': -2, 'M': 7, 'L': 3,
61
+ 'N': -2, 'Q': 0, 'P': -3, 'S': -2, 'R': -2, 'T': -1, 'W': -1,
62
+ 'V': 1, 'Y': 0, 'X': -1, 'Z': -1},
63
+ 'L': {'*': -5, 'A': -2, 'C': -2, 'B': -4, 'E': -3, 'D': -4,
64
+ 'G': -4, 'F': 1, 'I': 2, 'H': -3, 'K': -3, 'M': 3, 'L': 5,
65
+ 'N': -4, 'Q': -2, 'P': -4, 'S': -3, 'R': -3, 'T': -1,
66
+ 'W': -2, 'V': 1, 'Y': -1, 'X': -1, 'Z': -3},
67
+ 'N': {'*': -5, 'A': -1, 'C': -2, 'B': 5, 'E': 0, 'D': 2, 'G': 0,
68
+ 'F': -4, 'I': -3, 'H': 1, 'K': 0, 'M': -2, 'L': -4, 'N': 7,
69
+ 'Q': 0, 'P': -2, 'S': 1, 'R': -1, 'T': 0, 'W': -4, 'V': -3,
70
+ 'Y': -2, 'X': -1, 'Z': 0},
71
+ 'Q': {'*': -5, 'A': -1, 'C': -3, 'B': 0, 'E': 2, 'D': 0, 'G': -2,
72
+ 'F': -4, 'I': -3, 'H': 1, 'K': 2, 'M': 0, 'L': -2, 'N': 0,
73
+ 'Q': 7, 'P': -1, 'S': 0, 'R': 1, 'T': -1, 'W': -1, 'V': -3,
74
+ 'Y': -1, 'X': -1, 'Z': 4},
75
+ 'P': {'*': -5, 'A': -1, 'C': -4, 'B': -2, 'E': -1, 'D': -1,
76
+ 'G': -2, 'F': -4, 'I': -3, 'H': -2, 'K': -1, 'M': -3,
77
+ 'L': -4, 'N': -2, 'Q': -1, 'P': 10, 'S': -1, 'R': -3,
78
+ 'T': -1, 'W': -4, 'V': -3, 'Y': -3, 'X': -1, 'Z': -1},
79
+ 'S': {'*': -5, 'A': 1, 'C': -1, 'B': 0, 'E': -1, 'D': 0, 'G': 0,
80
+ 'F': -3, 'I': -3, 'H': -1, 'K': 0, 'M': -2, 'L': -3, 'N': 1,
81
+ 'Q': 0, 'P': -1, 'S': 5, 'R': -1, 'T': 2, 'W': -4, 'V': -2,
82
+ 'Y': -2, 'X': -1, 'Z': 0},
83
+ 'R': {'*': -5, 'A': -2, 'C': -4, 'B': -1, 'E': 0, 'D': -2, 'G': -3,
84
+ 'F': -3, 'I': -4, 'H': 0, 'K': 3, 'M': -2, 'L': -3, 'N': -1,
85
+ 'Q': 1, 'P': -3, 'S': -1, 'R': 7, 'T': -1, 'W': -3, 'V': -3,
86
+ 'Y': -1, 'X': -1, 'Z': 0},
87
+ 'T': {'*': -5, 'A': 0, 'C': -1, 'B': 0, 'E': -1, 'D': -1, 'G': -2,
88
+ 'F': -2, 'I': -1, 'H': -2, 'K': -1, 'M': -1, 'L': -1, 'N': 0,
89
+ 'Q': -1, 'P': -1, 'S': 2, 'R': -1, 'T': 5, 'W': -3, 'V': 0,
90
+ 'Y': -2, 'X': -1, 'Z': -1},
91
+ 'W': {'*': -5, 'A': -3, 'C': -5, 'B': -5, 'E': -3, 'D': -5,
92
+ 'G': -3, 'F': 1, 'I': -3, 'H': -3, 'K': -3, 'M': -1, 'L': -2,
93
+ 'N': -4, 'Q': -1, 'P': -4, 'S': -4, 'R': -3, 'T': -3,
94
+ 'W': 15, 'V': -3, 'Y': 2, 'X': -1, 'Z': -2},
95
+ 'V': {'*': -5, 'A': 0, 'C': -1, 'B': -3, 'E': -3, 'D': -4, 'G': -4,
96
+ 'F': -1, 'I': 4, 'H': -4, 'K': -3, 'M': 1, 'L': 1, 'N': -3,
97
+ 'Q': -3, 'P': -3, 'S': -2, 'R': -3, 'T': 0, 'W': -3, 'V': 5,
98
+ 'Y': -1, 'X': -1, 'Z': -3},
99
+ 'Y': {'*': -5, 'A': -2, 'C': -3, 'B': -3, 'E': -2, 'D': -3,
100
+ 'G': -3, 'F': 4, 'I': -1, 'H': 2, 'K': -2, 'M': 0, 'L': -1,
101
+ 'N': -2, 'Q': -1, 'P': -3, 'S': -2, 'R': -1, 'T': -2, 'W': 2,
102
+ 'V': -1, 'Y': 8, 'X': -1, 'Z': -2},
103
+ 'X': {'*': -5, 'A': -1, 'C': -1, 'B': -1, 'E': -1, 'D': -1,
104
+ 'G': -1, 'F': -1, 'I': -1, 'H': -1, 'K': -1, 'M': -1,
105
+ 'L': -1, 'N': -1, 'Q': -1, 'P': -1, 'S': -1, 'R': -1,
106
+ 'T': -1, 'W': -1, 'V': -1, 'Y': -1, 'X': -1, 'Z': -1},
107
+ 'Z': {'*': -5, 'A': -1, 'C': -3, 'B': 1, 'E': 5, 'D': 1, 'G': -2,
108
+ 'F': -4, 'I': -3, 'H': 0, 'K': 1, 'M': -1, 'L': -3, 'N': 0,
109
+ 'Q': 4, 'P': -1, 'S': 0, 'R': 0, 'T': -1, 'W': -2, 'V': -3,
110
+ 'Y': -2, 'X': -1, 'Z': 5}}
111
+
112
+
113
+ import math
114
+ def get_align_score(fasta_1, fasta_2):
115
+ kwargs = {}
116
+ kwargs['suppress_sequences'] = False
117
+ kwargs['zero_index'] = True
118
+ kwargs['protein'] = True
119
+ kwargs['substitution_matrix'] = blosum50
120
+ query = alignment.StripedSmithWaterman(fasta_1, **kwargs)
121
+ align = query(fasta_2)
122
+ score = align.optimal_alignment_score
123
+ return float(score)
124
+
125
+
126
+ def read_data(data_root, result_root):
127
+ training_data_fastas = json.load(open(f"{data_root}/align_fastas_dict_10.23.json"))
128
+ bdb_fastas_dict = training_data_fastas['bdb_fastas']
129
+ pdbbind_fastas_dict = training_data_fastas['pdb_fastas']
130
+
131
+ save_dir_bdb = f"{result_root}/BDB"
132
+ save_dir_pdbbind = f"{result_root}/PDBBind"
133
+ mol_feat_train_bdb = np.load(f'{save_dir_bdb}/bdb_mol_reps.npy')
134
+ pocket_feat_train_bdb = np.load(f'{save_dir_bdb}/bdb_pocket_reps.npy')
135
+ pocket_names_bdb = json.load(open(f"{save_dir_bdb}/bdb_pocket_names.json"))
136
+ mol_smis_bdb = json.load(open(f"{save_dir_bdb}/bdb_mol_smis.json"))
137
+ bdb_pocket_feat_dict = {pocket_names_bdb[i]: pocket_feat_train_bdb[i] for i in range(len(pocket_names_bdb))}
138
+ bdb_mol_feat_dict = {mol_smis_bdb[i]: mol_feat_train_bdb[i] for i in range(len(mol_smis_bdb))}
139
+
140
+ mol_feat_train_pdbbind = np.load(f'{save_dir_pdbbind}/train_mol_reps.npy')
141
+ pocket_feat_train_pdbbind = np.load(f'{save_dir_pdbbind}/train_pocket_reps.npy')
142
+ pocket_names_pdbbind = json.load(open(f"{save_dir_pdbbind}/train_pdbbind_ids.json"))
143
+ mol_smis_pdbbind = json.load(open(f"{save_dir_pdbbind}/train_mol_smis.json"))
144
+ pdbbind_pocket_feat_dict = {pocket_names_pdbbind[i]: pocket_feat_train_pdbbind[i] for i in range(len(pocket_names_pdbbind))}
145
+ pdbbind_mol_feat_dict = {mol_smis_pdbbind[i]: mol_feat_train_pdbbind[i] for i in range(len(mol_smis_pdbbind))}
146
+
147
+ return bdb_fastas_dict, bdb_pocket_feat_dict, bdb_mol_feat_dict, pdbbind_fastas_dict, pdbbind_pocket_feat_dict, pdbbind_mol_feat_dict
148
+
149
+ def get_neighbor_pocket(test_fasta, data_root, result_root, device):
150
+ # 1. Read data file
151
+ print("reading datas")
152
+ bdb_fastas_dict, bdb_pocket_feat_dict, bdb_mol_feat_dict, pdbbind_fastas_dict, \
153
+ pdbbind_pocket_feat_dict, pdbbind_mol_feat_dict = read_data(data_root, result_root)
154
+
155
+ training_assay = json.load(open(f"{data_root}/train_label_blend_seq_full.json"))
156
+ training_assay += json.load(open(f"{data_root}/train_label_pdbbind_seq.json"))
157
+ assay_dict = {}
158
+ for assay in training_assay:
159
+ assay["ligands"] = sorted(assay["ligands"], key=lambda x: x["act"], reverse=True)
160
+ if "assay_id" in assay:
161
+ assay_dict[assay["assay_id"]] = assay
162
+ else:
163
+ assay_dict[assay["pockets"][0][:4]] = assay
164
+
165
+ skip = 0
166
+ # 2. run alignment
167
+ print("running alignment pdbbind")
168
+ align_res_list = []
169
+ for a_name, fasta in tqdm(pdbbind_fastas_dict.items()):
170
+ if a_name not in pdbbind_pocket_feat_dict:
171
+ skip += 1
172
+ continue
173
+ p_name = a_name
174
+ l_smi = assay_dict[a_name]["ligands"][0]["smi"]
175
+ align_score = get_align_score(test_fasta, fasta) / get_align_score(test_fasta, test_fasta)
176
+ if align_score >= 0.5:
177
+ align_res_list.append((pdbbind_pocket_feat_dict[p_name], pdbbind_mol_feat_dict[l_smi], align_score, a_name))
178
+
179
+ print("running alignment bindingdb")
180
+ for a_name, fasta in tqdm(bdb_fastas_dict.items()):
181
+ if a_name not in assay_dict:
182
+ skip += 1
183
+ continue
184
+ p_name = assay_dict[a_name]["pockets"][0]
185
+ l_smi = assay_dict[a_name]["ligands"][0]["smi"]
186
+ if l_smi not in bdb_mol_feat_dict:
187
+ continue
188
+ align_score = get_align_score(test_fasta, fasta) / get_align_score(test_fasta, test_fasta)
189
+ if align_score >= 0.5:
190
+ align_res_list.append((bdb_pocket_feat_dict[p_name], bdb_mol_feat_dict[l_smi], align_score, a_name))
191
+
192
+ for i, res in enumerate(align_res_list):
193
+ align_res_list[i] = (torch.tensor(res[0]).float().to(device),
194
+ torch.tensor(res[1]).float().to(device),
195
+ torch.tensor(int((res[2] - 0.5) * 10)).to(device))
196
+
197
+ return align_res_list
198
+
HGNN/data/CoreSet.dat ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #code resl year logKa Ka target
2
+ 4llx 1.75 2014 2.89 Ki=1300uM 1
3
+ 5c28 1.56 2015 5.66 Ki=2.2uM 1
4
+ 3uuo 2.11 2012 7.96 Ki=11nM 1
5
+ 3ui7 2.28 2011 9.00 Ki=1nM 1
6
+ 5c2h 2.09 2015 11.09 Ki=8.2pM 1
7
+ 2v00 1.55 2007 3.66 Kd=0.22mM 2
8
+ 3wz8 1.45 2015 5.82 Ki=1.5uM 2
9
+ 3pww 1.22 2011 7.32 Ki=48nM 2
10
+ 3prs 1.38 2011 7.82 Ki=15nM 2
11
+ 3uri 2.10 2012 9.00 Ki=1nM 2
12
+ 4m0z 2.00 2014 5.19 Kd=6.4uM 3
13
+ 4m0y 1.70 2014 6.46 Kd=0.35uM 3
14
+ 3qgy 2.10 2011 7.80 Ki=16nM 3
15
+ 4qd6 2.45 2015 8.64 Ki=2.3nM 3
16
+ 4rfm 2.10 2015 10.05 Ki=90pM 3
17
+ 4cr9 1.70 2015 4.10 Ki=80uM 4
18
+ 4cra 1.80 2015 7.22 Ki=0.06uM 4
19
+ 4x6p 1.93 2015 8.30 Ki=5nM 4
20
+ 4crc 1.60 2015 8.72 Ki=0.0019uM 4
21
+ 4ty7 2.09 2014 9.52 Ki=0.3nM 4
22
+ 5aba 1.62 2015 2.98 Kd=1040uM 5
23
+ 5a7b 1.40 2015 3.57 Kd=271uM 5
24
+ 4agn 1.60 2012 3.97 Kd=107uM 5
25
+ 4agp 1.50 2012 4.69 Kd=20.6uM 5
26
+ 4agq 1.42 2012 5.01 Kd=9.7uM 5
27
+ 3bgz 2.40 2007 6.26 Ki=0.55uM 6
28
+ 3jya 2.10 2009 6.89 Ki=0.1301uM 6
29
+ 2c3i 1.90 2005 7.60 Kd=25nM 6
30
+ 4k18 2.05 2013 8.96 Ki=1.1nM 6
31
+ 5dwr 2.00 2015 11.22 Ki=6pM 6
32
+ 3mss 1.95 2010 4.66 Kd=22uM 7
33
+ 3k5v 1.74 2010 6.30 Kd=0.5uM 7
34
+ 3pyy 1.85 2011 6.86 Kd=137nM 7
35
+ 2v7a 2.50 2007 8.30 Kd=0.005uM 7
36
+ 4twp 2.40 2015 10.00 Ki=100pM 7
37
+ 3wtj 2.24 2015 6.53 Kd=0.297uM 8
38
+ 3zdg 2.48 2013 7.10 Ki=79nM 8
39
+ 3u8k 2.47 2011 8.66 Ki=2.2nM 8
40
+ 4qac 2.10 2014 9.40 Kd=0.4nM 8
41
+ 3u8n 2.35 2011 10.17 Ki=0.067nM 8
42
+ 1a30 2.00 1998 4.30 Ki=50uM 9
43
+ 2qnq 2.30 2008 6.11 Ki=0.77uM 9
44
+ 1g2k 1.95 2001 7.96 Ki=11nM 9
45
+ 1eby 2.29 2002 9.70 Ki=0.20nM 9
46
+ 3o9i 1.45 2011 11.82 Ki=1.5pM 9
47
+ 4lzs 2.20 2014 4.80 Kd=16uM 10
48
+ 3u5j 1.60 2011 5.61 Kd=2.46uM 10
49
+ 4wiv 1.56 2014 6.26 Kd=550nM 10
50
+ 4ogj 1.65 2014 6.79 Kd=164nM 10
51
+ 3p5o 1.60 2010 7.30 Kd=50.5nM 10
52
+ 1ps3 1.80 2003 2.28 Ki=5.2mM 11
53
+ 3dx1 1.21 2009 3.58 Ki=265uM 11
54
+ 3d4z 1.39 2008 4.89 Ki=13uM 11
55
+ 3dx2 1.40 2009 6.82 Ki=150nM 11
56
+ 3ejr 1.27 2009 8.57 Ki=2.7nM 11
57
+ 3l7b 2.00 2010 2.40 Ki=4.01mM 12
58
+ 4eky 2.45 2012 3.52 Ki=303.0uM 12
59
+ 3g2n 2.10 2010 4.09 Ki=81uM 12
60
+ 3syr 2.40 2012 5.10 Ki=7.9uM 12
61
+ 3ebp 2.00 2009 5.91 Ki=1.24uM 12
62
+ 2w66 2.27 2009 4.05 Ki=89uM 13
63
+ 2w4x 2.42 2009 4.85 Kd=14uM 13
64
+ 2wca 2.30 2009 5.60 Ki=2.5uM 13
65
+ 2xj7 2.00 2010 6.66 Ki=220nM 13
66
+ 2vvn 1.85 2008 7.30 Kd=50nM 13
67
+ 3aru 1.90 2011 3.22 Kd=600uM 14
68
+ 3arv 1.50 2011 5.64 Kd=2.3uM 14
69
+ 3ary 1.35 2011 6.00 Kd=1.0uM 14
70
+ 3arq 1.50 2011 6.40 Kd=0.4uM 14
71
+ 3arp 1.55 2011 7.15 Kd=0.07uM 14
72
+ 4ih5 1.90 2013 4.11 Kd=78uM 15
73
+ 4ih7 2.30 2013 5.24 Kd=5.8uM 15
74
+ 3cj4 2.07 2008 6.51 Kd=0.31uM 15
75
+ 4eo8 1.80 2012 8.15 Kd=7nM 15
76
+ 3gnw 2.39 2009 9.10 Kd=0.79nM 15
77
+ 1gpk 2.10 2002 5.37 Ki=4.3uM 16
78
+ 1gpn 2.35 2002 6.48 Ki=0.334uM 16
79
+ 1h23 2.15 2002 8.35 Ki=4.5nM 16
80
+ 1h22 2.15 2002 9.10 Ki=0.8nM 16
81
+ 1e66 2.10 2001 9.89 Ki=0.13nM 16
82
+ 3f3a 2.00 2008 4.19 Ki=64.8uM 17
83
+ 3f3c 2.10 2008 6.02 Ki=950nM 17
84
+ 4mme 2.50 2013 6.50 Kd=318nM 17
85
+ 3f3d 2.30 2008 7.16 Kd=69nM 17
86
+ 3f3e 1.80 2008 7.70 Kd=20nM 17
87
+ 2wbg 1.85 2009 4.45 Ki=35.2uM 18
88
+ 2cbv 1.95 2006 5.48 Kd=3.3uM 18
89
+ 2j78 1.65 2006 6.42 Kd=384nM 18
90
+ 2j7h 1.95 2006 7.19 Kd=65nM 18
91
+ 2cet 1.97 2006 8.02 Kd=9.6nM 18
92
+ 3udh 1.70 2012 2.85 Kd=1.4mM 19
93
+ 3rsx 2.48 2011 4.41 Kd=38.8uM 19
94
+ 4djv 1.73 2012 6.72 Ki=0.19uM 19
95
+ 2vkm 2.05 2008 8.74 Ki=1.8nM 19
96
+ 4gid 2.00 2012 10.77 Ki=0.017nM 19
97
+ 4jfs 2.00 2013 5.27 Ki=5.4uM 20
98
+ 4j28 1.73 2013 5.70 Ki=2.0uM 20
99
+ 2wvt 1.80 2010 6.12 Kd=755nM 20
100
+ 2xii 1.80 2010 7.20 Kd=63.3nM 20
101
+ 4pcs 1.77 2014 7.85 Ki=14nM 20
102
+ 3rr4 1.68 2012 4.55 Ki=28.05uM 21
103
+ 1s38 1.81 2004 5.15 Ki=7.0uM 21
104
+ 1r5y 1.20 2004 6.46 Ki=0.35uM 21
105
+ 3gc5 1.40 2009 7.26 Ki=55nM 21
106
+ 3ge7 1.50 2009 8.70 Ki=2nM 21
107
+ 4dli 1.91 2013 5.62 Kd=2.40uM 22
108
+ 2zb1 2.50 2008 6.32 Kd=0.48uM 22
109
+ 4f9w 2.00 2013 6.94 Ki=114nM 22
110
+ 3e92 2.00 2008 8.00 Ki=10nM 22
111
+ 3e93 2.00 2008 8.85 Ki=1.4nM 22
112
+ 4owm 1.99 2014 2.96 Ki=1090uM 23
113
+ 3twp 1.83 2012 3.92 Ki=119uM 23
114
+ 3r88 1.73 2012 4.82 Ki=15uM 23
115
+ 4gkm 1.67 2013 5.17 Ki=6.8uM 23
116
+ 3qqs 1.97 2012 5.82 Ki=1.5uM 23
117
+ 3gv9 1.80 2009 2.12 Ki=7.5mM 24
118
+ 3gr2 1.80 2009 2.52 Ki=3mM 24
119
+ 4kz6 1.68 2014 3.10 Ki=0.8mM 24
120
+ 4jxs 1.90 2014 4.74 Ki=18uM 24
121
+ 2r9w 1.80 2008 5.10 Ki=8uM 24
122
+ 2hb1 2.00 2006 3.80 Ki=160uM 25
123
+ 1bzc 2.35 1999 4.92 Ki=12uM 25
124
+ 2qbr 2.30 2008 6.33 Ki=0.47uM 25
125
+ 2qbq 2.10 2008 7.44 Ki=0.036uM 25
126
+ 2qbp 2.50 2008 8.40 Ki=0.004uM 25
127
+ 1q8t 2.00 2003 4.76 Kd=17.5uM 26
128
+ 1ydr 2.20 1997 5.52 Ki=3.0uM 26
129
+ 1q8u 1.90 2003 5.96 Kd=1.1uM 26
130
+ 1ydt 2.30 1997 7.32 Ki=48nM 26
131
+ 3ag9 2.00 2010 8.05 Ki=9nM 26
132
+ 3fcq 1.75 2009 2.77 Ki=1.7mM 27
133
+ 1z9g 1.70 2005 5.64 Ki=2.3uM 27
134
+ 1qf1 2.00 1999 7.32 Ki=48nM 27
135
+ 5tmn 1.60 1989 8.04 Ki=9.1nM 27
136
+ 4tmn 1.70 1989 10.17 Ki=0.068nM 27
137
+ 4ddk 1.75 2013 2.29 Kd=5.13mM 28
138
+ 4ddh 2.07 2013 3.32 Kd=0.48mM 28
139
+ 3ivg 1.95 2009 4.30 Kd=50uM 28
140
+ 3coz 2.00 2008 5.57 Kd=2.7uM 28
141
+ 3coy 2.03 2008 6.02 Kd=0.96uM 28
142
+ 3pxf 1.80 2011 4.43 Kd=37uM 29
143
+ 4eor 2.20 2013 6.30 Ki=500nM 29
144
+ 2xnb 1.85 2010 6.83 Ki=149nM 29
145
+ 1pxn 2.50 2004 7.15 Ki=0.07uM 29
146
+ 2fvd 1.85 2006 8.52 Ki=3nM 29
147
+ 4k77 2.40 2013 6.63 Ki=235nM 30
148
+ 4e5w 1.86 2012 7.66 Ki=22nM 30
149
+ 4ivb 1.90 2013 8.72 Ki=1.9nM 30
150
+ 4ivd 1.93 2013 9.52 Ki=0.3nM 30
151
+ 4ivc 2.35 2013 10.00 Ki=0.1nM 30
152
+ 4f09 2.40 2012 6.70 Ki=200nM 31
153
+ 4gfm 2.30 2013 7.22 Ki=0.06uM 31
154
+ 4hge 2.30 2012 7.92 Ki=11.9nM 31
155
+ 4e6q 1.95 2012 8.36 Ki=4.4nM 31
156
+ 4jia 1.85 2013 9.22 Ki=0.6nM 31
157
+ 2brb 2.10 2005 4.86 Ki=13.7uM 32
158
+ 2br1 2.00 2005 5.14 Ki=7.2uM 32
159
+ 3jvr 1.76 2009 5.72 Ki=1.89uM 32
160
+ 3jvs 1.90 2009 6.54 Kd=0.29uM 32
161
+ 1nvq 2.00 2003 8.25 Ki=5.6nM 32
162
+ 3acw 1.63 2010 4.76 Ki=17.5uM 33
163
+ 4ea2 2.05 2012 6.44 Ki=0.36uM 33
164
+ 2zcr 1.92 2008 6.87 Ki=135nM 33
165
+ 2zy1 1.78 2009 7.40 Ki=0.04uM 33
166
+ 2zcq 2.38 2008 8.82 Ki=1.5nM 33
167
+ 1bcu 2.00 1998 3.28 Kd=0.53mM 34
168
+ 3bv9 1.80 2008 5.36 Ki=4.4uM 34
169
+ 1oyt 1.67 2003 7.24 Ki=0.057uM 34
170
+ 2zda 1.73 2008 8.40 Ki=4nM 34
171
+ 3utu 1.55 2012 10.92 Ki=0.012nM 34
172
+ 3u9q 1.52 2011 4.38 Ki=41.7uM 35
173
+ 2yfe 2.00 2012 6.63 Ki=0.236uM 35
174
+ 3fur 2.30 2009 8.00 Ki=10nM 35
175
+ 3b1m 1.60 2011 8.48 Ki=3.3nM 35
176
+ 2p4y 2.25 2008 9.00 Ki=1nM 35
177
+ 3uo4 2.45 2012 6.52 Kd=299nM 36
178
+ 3up2 2.30 2012 7.40 Kd=40nM 36
179
+ 3e5a 2.30 2008 8.23 Ki=5.9nM 36
180
+ 2wtv 2.40 2010 8.74 Ki=1.8nM 36
181
+ 3myg 2.40 2010 10.70 Kd=0.02nM 36
182
+ 3kgp 2.35 2009 2.57 Ki=2.68mM 37
183
+ 1c5z 1.85 2000 4.01 Ki=97uM 37
184
+ 1o5b 1.85 2004 5.77 Ki=1.7uM 37
185
+ 1owh 1.61 2003 7.40 Ki=40nM 37
186
+ 1sqa 2.00 2004 9.21 Ki=0.62nM 37
187
+ 4jsz 1.90 2013 2.30 Ki=5000uM 38
188
+ 3kwa 2.00 2010 4.08 Ki=84uM 38
189
+ 2weg 1.10 2009 6.50 Kd=314nM 38
190
+ 3ryj 1.39 2011 7.80 Kd=16nM 38
191
+ 3dd0 1.48 2009 9.00 Ki=1nM 38
192
+ 2xdl 1.98 2010 3.10 Kd=790uM 39
193
+ 3b27 1.50 2011 5.16 Kd=6.9uM 39
194
+ 1yc1 1.70 2005 6.17 Kd=680nM 39
195
+ 3rlr 1.70 2011 7.52 Ki=30nM 39
196
+ 2yki 1.67 2011 9.46 Kd=0.35nM 39
197
+ 1z95 1.80 2005 7.12 Ki=76nM 40
198
+ 3b68 1.90 2008 8.40 Ki=4nM 40
199
+ 3b5r 1.80 2008 8.77 Ki=1.7nM 40
200
+ 3b65 1.80 2008 9.27 Ki=0.54nM 40
201
+ 3g0w 1.95 2009 9.52 Ki=0.3nM 40
202
+ 4u4s 1.90 2014 2.92 Kd=1200uM 41
203
+ 1p1q 2.00 2003 4.89 Kd=12.8uM 41
204
+ 1syi 2.10 2005 5.44 Ki=3590nM 41
205
+ 1p1n 1.60 2003 6.80 Kd=0.16uM 41
206
+ 2al5 1.65 2005 8.40 Ki=4nM 41
207
+ 3g2z 1.50 2009 2.36 Ki=4.4mM 42
208
+ 3g31 1.70 2009 2.89 Ki=1.3mM 42
209
+ 4de2 1.40 2012 4.12 Ki=76.0uM 42
210
+ 4de3 1.44 2012 5.52 Ki=3.0uM 42
211
+ 4de1 1.26 2012 5.96 Ki=1.1uM 42
212
+ 1vso 1.85 2007 4.72 Ki=18.98uM 43
213
+ 4dld 2.00 2012 5.82 Ki=1.5uM 43
214
+ 3gbb 2.10 2009 6.90 Ki=126nM 43
215
+ 3fv2 1.50 2010 8.11 Ki=7.7nM 43
216
+ 3fv1 1.50 2010 9.30 Ki=0.5nM 43
217
+ 4mgd 1.90 2014 4.69 Kd=20.19uM 44
218
+ 2qe4 2.40 2007 7.96 Ki=11.0nM 44
219
+ 1qkt 2.20 2000 9.04 Kd=0.92nM 44
220
+ 2pog 1.84 2007 9.54 Ki=0.29nM 44
221
+ 2p15 1.94 2007 10.30 Kd=50pM 44
222
+ 2y5h 1.33 2011 5.79 Ki=1620nM 45
223
+ 1lpg 2.00 2003 7.09 Ki=82nM 45
224
+ 2xbv 1.66 2010 8.43 Kd=3.7nM 45
225
+ 1z6e 1.80 2006 9.72 Ki=0.19nM 45
226
+ 1mq6 2.10 2003 11.15 Ki=7pM 45
227
+ 1nc3 2.20 2003 5.00 Ki=10uM 46
228
+ 1nc1 2.00 2003 6.12 Ki=0.75uM 46
229
+ 1y6r 2.20 2005 10.11 Ki=77pM 46
230
+ 4f2w 2.00 2013 11.30 Ki=5.0pM 46
231
+ 4f3c 1.93 2013 11.82 Ki=1.5pM 46
232
+ 1uto 1.15 2004 2.27 Kd=5.32mM 47
233
+ 4abg 1.52 2012 3.57 Kd=271uM 47
234
+ 3gy4 1.55 2010 5.10 Kd=8uM 47
235
+ 1k1i 2.20 2001 6.58 Kd=264nM 47
236
+ 1o3f 1.55 2003 7.96 Ki=0.011uM 47
237
+ 2yge 1.96 2011 5.06 Kd=8.62uM 48
238
+ 2fxs 2.00 2007 6.06 Kd=0.87uM 48
239
+ 2iwx 1.50 2006 6.68 Kd=0.21uM 48
240
+ 2wer 1.60 2009 7.05 Kd=90nM 48
241
+ 2vw5 1.90 2008 8.52 Kd=3nM 48
242
+ 4kzq 2.25 2013 6.10 Kd=788nM 49
243
+ 4kzu 2.10 2013 6.50 Ki=313nM 49
244
+ 4j21 1.93 2013 7.41 Kd=39nM 49
245
+ 4j3l 2.09 2013 7.80 Kd=16nM 49
246
+ 3kr8 2.10 2009 8.10 Kd=8nM 49
247
+ 2ymd 1.96 2012 3.16 Kd=693uM 50
248
+ 2wnc 2.20 2009 6.32 Kd=479nM 50
249
+ 2xys 1.91 2011 7.42 Ki=38nM 50
250
+ 2wn9 1.75 2009 8.52 Kd=3.0nM 50
251
+ 2x00 2.40 2010 11.33 Kd=4.7pM 50
252
+ 3ozt 1.48 2011 4.13 Ki=74.9uM 51
253
+ 3ozs 1.44 2011 5.33 Ki=4645nM 51
254
+ 3oe5 1.52 2011 6.88 Ki=132nM 51
255
+ 3oe4 1.49 2011 7.47 Ki=34nM 51
256
+ 3nw9 1.65 2011 9.00 Ki=1nM 51
257
+ 3ao4 1.95 2011 2.07 Kd=8.5mM 52
258
+ 3zt2 1.70 2012 2.84 Kd=1435uM 52
259
+ 3zsx 1.95 2012 3.28 Kd=519uM 52
260
+ 4cig 1.70 2014 3.67 Kd=214uM 52
261
+ 3zso 1.75 2012 5.12 Kd=7.6uM 52
262
+ 3n7a 2.00 2011 3.70 Ki=200uM 53
263
+ 4ciw 2.20 2014 4.82 Ki=15.0uM 53
264
+ 3n86 1.90 2011 5.64 Ki=2.3uM 53
265
+ 3n76 1.90 2011 6.85 Ki=0.14uM 53
266
+ 2xb8 2.40 2010 7.59 Ki=26nM 53
267
+ 4bkt 2.35 2013 3.62 Kd=240uM 54
268
+ 4w9c 2.20 2014 4.65 Kd=22.2uM 54
269
+ 4w9l 2.20 2014 5.02 Kd=9.52uM 54
270
+ 4w9i 2.40 2014 5.96 Kd=1.10uM 54
271
+ 4w9h 2.10 2014 6.73 Kd=0.185uM 54
272
+ 3nq9 1.90 2010 4.03 Kd=92.6uM 55
273
+ 3ueu 2.10 2011 5.24 Kd=5.81uM 55
274
+ 3uev 1.90 2011 5.89 Kd=1.29uM 55
275
+ 3uew 2.00 2011 6.31 Kd=0.49uM 55
276
+ 3uex 2.10 2011 6.92 Kd=0.12uM 55
277
+ 3lka 1.80 2010 2.82 Kd=1.5mM 56
278
+ 3ehy 1.90 2009 5.85 Ki=1.4uM 56
279
+ 3tsk 2.00 2012 7.17 Kd=67nM 56
280
+ 3nx7 1.80 2010 8.10 Kd=7.88nM 56
281
+ 4gr0 1.50 2013 9.55 Ki=0.28nM 56
282
+ 3dxg 1.39 2009 2.40 Ki=4.0mM 57
283
+ 3d6q 1.60 2009 3.76 Ki=172uM 57
284
+ 1w4o 1.60 2005 5.22 Ki=6uM 57
285
+ 1o0h 1.20 2003 5.92 Ki=1.2uM 57
286
+ 1u1b 2.00 2005 7.80 Kd=16nM 57
HGNN/data/PDBbind_v2020/index/INDEX_general_PL_data.2020 ADDED
The diff for this file is too large to render. See raw diff
 
HGNN/data/PDBbind_v2020/index/INDEX_general_PL_name.2020 ADDED
The diff for this file is too large to render. See raw diff
 
HGNN/data/PDBbind_v2020/index/INDEX_refined_data.2020 ADDED
The diff for this file is too large to render. See raw diff
 
HGNN/data/PDBbind_v2020/index/INDEX_refined_name.2020 ADDED
The diff for this file is too large to render. See raw diff
 
HGNN/main.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from PL_Encoder import PLEncoder
8
+ from PL_Aggregator import PLAggregator
9
+ from PP_Encoder import PPEncoder
10
+ from PP_Aggregator import PPAggregator
11
+ from screen_dataset import *
12
+ import torch.nn.functional as F
13
+ import torch.utils.data
14
+ import argparse
15
+ import os
16
+ from util import cal_metrics
17
+
18
+
19
class HGNN(nn.Module):
    """Heterogeneous graph network scoring pocket-ligand compatibility.

    ``enc_u`` embeds pocket nodes (optionally conditioned on ligand nodes) and
    ``enc_v`` embeds ligand nodes; affinity is the dot product of the two
    embeddings, scaled by a learnable CLIP-style temperature for the loss.
    """

    def __init__(self, enc_u, enc_v, r2e):
        super(HGNN, self).__init__()
        self.enc_u = enc_u
        self.enc_v = enc_v
        self.embed_dim = enc_u.embed_dim

        self.w_ur1 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_ur2 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_vr1 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_vr2 = nn.Linear(self.embed_dim, self.embed_dim)

        self.r2e = r2e
        self.bn1 = nn.BatchNorm1d(self.embed_dim, momentum=0.5)
        self.bn2 = nn.BatchNorm1d(self.embed_dim, momentum=0.5)

        # Learnable logit temperature, initialised to log(14) (CLIP-style).
        # FIX: was created with device="cuda", which crashed on CPU-only
        # machines; as a Parameter it now follows the module via .to(device).
        self.logit_scale = nn.Parameter(torch.ones([1]) * np.log(14))

    def trainable_parameters(self):
        """Yield only parameters with requires_grad=True (frozen embeddings excluded)."""
        for name, param in self.named_parameters(recurse=True):
            if param.requires_grad:
                yield param

    def forward(self, nodes_u, nodes_v):
        """Return (pocket_embeddings, ligand_embeddings) for the given node ids."""
        embeds_u = self.enc_u(nodes_u, nodes_v)
        embeds_v = self.enc_v(nodes_v)
        return embeds_u, embeds_v

    def criterion(self, x_u, x_v, labels):
        """Symmetric contrastive (InfoNCE-style) loss plus screening metrics.

        labels[i][j] == 1 marks ligand j as a true binder of pocket i; all
        positives except the diagonal are masked out of the softmax so each
        pocket/ligand is contrasted only against its own pairing.

        Returns (loss, mean_metrics_dict, raw_score_matrix).
        """
        netout = torch.matmul(x_u, torch.transpose(x_v, 0, 1))
        score = netout * self.logit_scale.exp().detach()
        # Mask off-diagonal positives with a large negative logit so they do
        # not act as negatives in the softmax.
        score = (labels - torch.eye(len(labels)).to(labels.device)) * -1e6 + score

        lprobs_pocket = F.log_softmax(score.float(), dim=-1)
        lprobs_pocket = lprobs_pocket.view(-1, lprobs_pocket.size(-1))
        sample_size = lprobs_pocket.size(0)
        # FIX: targets follow the logits' device instead of hard-coded .cuda().
        targets = torch.arange(sample_size, dtype=torch.long,
                               device=lprobs_pocket.device).view(-1)

        # pocket retrieve mol
        loss_pocket = F.nll_loss(
            lprobs_pocket,
            targets,
            reduction="mean"
        )

        lprobs_mol = F.log_softmax(torch.transpose(score.float(), 0, 1), dim=-1)
        lprobs_mol = lprobs_mol.view(-1, lprobs_mol.size(-1))
        lprobs_mol = lprobs_mol[:sample_size]

        # mol retrieve pocket
        loss_mol = F.nll_loss(
            lprobs_mol,
            targets,
            reduction="mean"
        )

        loss = 0.5 * loss_pocket + 0.5 * loss_mol

        # Per-pocket screening metrics computed on the unscaled scores.
        # (Unused top-1 bookkeeping from the original was removed.)
        ef_all = []
        for i in range(len(netout)):
            act_pocket = labels[i]
            affi_pocket = netout[i]
            ef_all.append(cal_metrics(affi_pocket.detach().cpu().numpy(),
                                      act_pocket.detach().cpu().numpy()))
        ef_mean = {k: np.mean([x[k] for x in ef_all]) for k in ef_all[0].keys()}

        return loss, ef_mean, netout

    def loss(self, nodes_u, nodes_v, labels):
        """Encode both sides and return (contrastive_loss, mean_metrics)."""
        x_u, x_v = self.forward(nodes_u, nodes_v)
        loss, ef_mean, netout = self.criterion(x_u, x_v, labels)
        return loss, ef_mean
94
+
95
+
96
def train(model, device, train_loader, optimizer, epoch, valid_idxes, valid_molidxes, valid_labels):
    """Run one training epoch over *train_loader*, then evaluate on the validation split.

    Each loader item is a singleton batch (pocket ids, ligand ids, label matrix);
    the running loss is printed every 200 steps. Always returns 0.
    """
    model.train()
    accum = 0.0
    for step, batch in enumerate(train_loader, 0):
        nodes_u, nodes_v, labels = batch
        optimizer.zero_grad()
        loss, _ = model.loss(nodes_u[0].to(device),
                             nodes_v[0].to(device),
                             labels[0].to(device))
        loss.backward(retain_graph=True)
        optimizer.step()
        accum += loss.item()
        if step % 200 == 0:
            print('[%d, %5d] loss: %.3f '%(epoch, step, accum / 200))
            accum = 0.0
    avg_loss, avg_acc = valid(
        model,
        device,
        torch.tensor(valid_idxes).to(device),
        torch.tensor(valid_molidxes).to(device),
        torch.tensor(valid_labels).to(device),
    )
    print('Valid set results:', avg_loss.item(), avg_acc)
    return 0
116
+
117
+
118
def valid(model, device, valid_idxes, valid_molidxes, valid_labels):
    """Compute loss and metrics on the validation split without gradients.

    FIX: ``model.train()`` now runs in a ``finally`` block so an exception in
    ``model.loss`` no longer leaves the model stuck in eval mode.
    """
    model.eval()
    try:
        with torch.no_grad():
            loss, ef = model.loss(valid_idxes.to(device),
                                  valid_molidxes.to(device),
                                  valid_labels.to(device))
    finally:
        model.train()
    return loss, ef
124
+
125
+
126
def test_dekois(model, device, epoch, result_root, dekois_pocket_name, dekois_idxes):
    """Evaluate virtual-screening performance on the DEKOIS benchmark.

    For every target, scores the precomputed ligand embeddings against both the
    HGNN pocket embedding and the raw (pre-GNN) pocket embedding, saves both
    score arrays next to the benchmark data, and prints the averaged metrics.
    """
    model.eval()
    ef_all, ef_raw_all = [], []  # unused loss accumulators from the original removed
    dekois_dir = f"{result_root}/DEKOIS"
    with torch.no_grad():
        for dekois_id, pocket_node_id in zip(dekois_pocket_name, dekois_idxes):
            embeds_pocket = model.enc_u([pocket_node_id], None, max_sample=-1)
            embeds_lig = torch.tensor(np.load(f"{dekois_dir}/{dekois_id}/saved_mols_embed.npy")).to(device).float()
            labels = np.load(f"{dekois_dir}/{dekois_id}/saved_labels.npy")
            # Raw pocket embedding straight from the embedding table (no GNN aggregation).
            embeds_pocket_raw = model.enc_u.aggregator.u2e(torch.tensor([pocket_node_id]).to(device))

            score = torch.matmul(embeds_pocket, torch.transpose(embeds_lig, 0, 1)).squeeze().detach().cpu().numpy()
            score_raw = torch.matmul(embeds_pocket_raw, torch.transpose(embeds_lig, 0, 1)).squeeze().detach().cpu().numpy()
            np.save(f"{dekois_dir}/{dekois_id}/GNN_res_epoch{epoch}.npy", score)
            np.save(f"{dekois_dir}/{dekois_id}/noGNN_res.npy", score_raw)
            ef_all.append(cal_metrics(score, labels))
            ef_raw_all.append(cal_metrics(score_raw, labels))

    model.train()
    # Average each metric (EF1, BEDROC, AUC, ...) over all targets.
    ef_all = {k: np.mean([x[k] for x in ef_all]) for k in ef_all[0].keys()}
    ef_raw_all = {k: np.mean([x[k] for x in ef_raw_all]) for k in ef_raw_all[0].keys()}
    print('Test on dekois:', ef_all)
    print('No HGNN on dekois:', ef_raw_all)
153
+
154
def test_dude(model, device, epoch, result_root, dude_pocket_name, dude_idxes):
    """Evaluate virtual-screening performance on the DUD-E benchmark.

    Mirrors ``test_dekois``: scores precomputed ligand embeddings against the
    HGNN and raw pocket embeddings, saves both score arrays, and prints the
    averaged metrics.
    """
    model.eval()
    ef_all, ef_raw_all = [], []  # unused loss accumulators from the original removed
    dude_dir = f"{result_root}/DUDE"
    with torch.no_grad():
        for dude_id, pocket_node_id in zip(dude_pocket_name, dude_idxes):
            embeds_pocket = model.enc_u([pocket_node_id], None, max_sample=-1)
            embeds_lig = torch.tensor(np.load(f"{dude_dir}/{dude_id}/saved_mols_embed.npy")).to(device).float()
            labels = np.load(f"{dude_dir}/{dude_id}/saved_labels.npy")
            # Raw pocket embedding straight from the embedding table (no GNN aggregation).
            embeds_pocket_raw = model.enc_u.aggregator.u2e(torch.tensor([pocket_node_id]).to(device))

            score = torch.matmul(embeds_pocket, torch.transpose(embeds_lig, 0, 1)).squeeze().detach().cpu().numpy()
            score_raw = torch.matmul(embeds_pocket_raw, torch.transpose(embeds_lig, 0, 1)).squeeze().detach().cpu().numpy()
            np.save(f"{dude_dir}/{dude_id}/GNN_res_epoch{epoch}.npy", score)
            np.save(f"{dude_dir}/{dude_id}/noGNN_res.npy", score_raw)
            ef_all.append(cal_metrics(score, labels))
            ef_raw_all.append(cal_metrics(score_raw, labels))

    model.train()
    # Average each metric (EF1, BEDROC, AUC, ...) over all targets.
    ef_all = {k: np.mean([x[k] for x in ef_all]) for k in ef_all[0].keys()}
    ef_raw_all = {k: np.mean([x[k] for x in ef_raw_all]) for k in ef_raw_all[0].keys()}
    print('Test on dude:', ef_all)
    print('No HGNN on dude:', ef_raw_all)
181
+
182
def test_pcba(model, device, epoch, result_root, pcba_idxes):
    """Evaluate virtual-screening performance on the PCBA benchmark.

    Each PCBA target may have several pockets; per-pocket score vectors are
    aggregated across pockets before metrics are computed. Saves the
    aggregated score arrays and returns the mean EF1 over all targets.
    """
    model.eval()
    ef_all, ef_raw_all = [], []  # unused loss accumulators from the original removed
    pcba_dir = f"{result_root}/PCBA"
    with torch.no_grad():
        pocket_idx = 0  # flat cursor into pcba_idxes across all targets
        for pcba_id in sorted(list(os.listdir(pcba_dir))):
            pocket_names = []
            for names in json.load(open(f"{pcba_dir}/{pcba_id}/saved_pocket_names.json")):
                pocket_names += names
            embeds_lig = torch.tensor(np.load(f"{pcba_dir}/{pcba_id}/saved_mols_embed.npy")).to(device).float()
            labels = np.load(f"{pcba_dir}/{pcba_id}/saved_labels.npy")
            score_all_pocket = []
            score_raw_pocket = []

            for i, pocket_name in enumerate(pocket_names):
                pcba_test_idx = pcba_idxes[pocket_idx]
                embeds_pocket = model.enc_u([pcba_test_idx], None, max_sample=-1)
                netout = torch.matmul(embeds_pocket, torch.transpose(embeds_lig, 0, 1))
                embeds_pocket_raw = model.enc_u.aggregator.u2e(torch.tensor([pcba_test_idx]).to(device))
                netout_raw = torch.matmul(embeds_pocket_raw, torch.transpose(embeds_lig, 0, 1))
                score_all_pocket.append(netout.squeeze().detach().cpu().numpy())
                score_raw_pocket.append(netout_raw.squeeze().detach().cpu().numpy())
                pocket_idx += 1

            # NOTE(review): GNN scores are averaged over pockets while raw scores
            # take the per-pocket max (despite the "score_max" name) — confirm
            # this asymmetry is intentional before changing it.
            score_max = np.stack(score_all_pocket, axis=0).mean(axis=0)
            score_raw_max = np.stack(score_raw_pocket, axis=0).max(axis=0)
            metric = cal_metrics(score_max, labels)
            metric_raw = cal_metrics(score_raw_max, labels)
            print(pcba_id, metric["EF1"], metric["BEDROC"], metric["AUC"])
            np.save(f"{pcba_dir}/{pcba_id}/GNN_res_epoch{epoch}.npy", score_max)
            np.save(f"{pcba_dir}/{pcba_id}/noGNN_res.npy", score_raw_max)
            # FIX: reuse the metrics computed above instead of calling
            # cal_metrics twice per target.
            ef_all.append(metric)
            ef_raw_all.append(metric_raw)

    model.train()
    print(f"saving to {pcba_dir}")
    # Average each metric over all targets.
    ef_all = {k: np.mean([x[k] for x in ef_all]) for k in ef_all[0].keys()}
    ef_raw_all = {k: np.mean([x[k] for x in ef_raw_all]) for k in ef_raw_all[0].keys()}
    print('Test on pcba:', ef_all)
    print('No HGNN on pcba:', ef_raw_all)
    return ef_all["EF1"]
224
+
225
+
226
def main():
    """Entry point: build the HGNN over the pocket graph, then either evaluate
    a checkpoint (``--test_ckpt``) or train for ``--epochs`` epochs, testing on
    DUD-E / DEKOIS / PCBA after every epoch and saving a checkpoint."""
    # Training settings
    parser = argparse.ArgumentParser(description='HGNN model training')
    parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')
    parser.add_argument('--embed_dim', type=int, default=128, metavar='N', help='embedding size')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learning rate')
    parser.add_argument('--test_batch_size', type=int, default=1000, metavar='N', help='input batch size for testing')
    parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train')
    parser.add_argument("--test_ckpt", type=str, default=None)
    parser.add_argument("--data_root", type=str, default="../data")
    parser.add_argument("--result_root", type=str, default="../result/pocket_ranking")
    args = parser.parse_args()
    data_root = args.data_root
    # BUG FIX: result_root was used throughout this function but never
    # assigned, raising NameError at runtime.
    result_root = args.result_root

    # Fix all RNG seeds for reproducibility.
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    use_cuda = False
    if torch.cuda.is_available():
        use_cuda = True
    device = torch.device("cuda" if use_cuda else "cpu")

    print("begin load dataset")
    assayinfo_lst, pocket_feat, mol_feat, assayid_lst_all, mol_smi_lst, \
        assayid_lst_train, assayid_lst_test, dude_pocket_name, pcba_pocket_name, dekois_pocket_name, valid_molidxes = load_datas(data_root, result_root)
    print("begin load pocket-pocket graph")
    pocket_graph = load_pocket_pocket_graph(data_root, assayid_lst_all, assayid_lst_train)

    screen_dataset = ScreenDataset(args.batch_size, pocket_graph, assayinfo_lst, assayid_lst_all, mol_smi_lst, assayid_lst_train)
    num_pockets = len(assayid_lst_all)
    num_ligs = mol_feat.shape[0]

    # Frozen embedding tables initialised from the precomputed features:
    # pocket embeddings are copied in now, ligand embeddings are refreshed
    # each epoch inside the training loop below.
    embed_dim = args.embed_dim
    pocket2e = nn.Embedding(num_pockets, embed_dim).to(device)
    pocket2e.weight.data.copy_(torch.tensor(pocket_feat).to(device))
    for param in pocket2e.parameters():
        param.requires_grad = False

    lig2e = nn.Embedding(num_ligs, embed_dim).to(device)
    for param in lig2e.parameters():
        param.requires_grad = False
    type2e = nn.Embedding(10, embed_dim).to(device)

    # Pocket encoder: pocket-ligand aggregation wrapped by pocket-pocket
    # (similarity-graph) aggregation.
    agg_pocket = PLAggregator(lig2e, type2e, pocket2e, embed_dim, cuda=device, uv=True)
    enc_pocket = PLEncoder(embed_dim, pocket_graph, agg_pocket, assayid_lst_all, assayid_lst_train, mol_smi_lst, assayinfo_lst, cuda=device, uv=True)
    # neighbors
    agg_pocket_sim = PPAggregator(pocket2e, embed_dim, cuda=device)
    enc_pocket = PPEncoder(enc_pocket, embed_dim, pocket_graph, agg_pocket_sim, assayid_lst_all, assayid_lst_train,
                           base_model=enc_pocket, cuda=device)
    enc_lig = lig2e
    # model
    graphrec = HGNN(enc_pocket, enc_lig, type2e).to(device)
    print("trainable parameters")
    for name, param in graphrec.named_parameters(recurse=True):
        if param.requires_grad:
            print(name, param.shape)
    optimizer = torch.optim.RMSprop(graphrec.trainable_parameters(), lr=args.lr, alpha=0.9)

    # Benchmark pockets are appended after train+test assays in the node
    # index space: [train | test | DUD-E | PCBA | DEKOIS].
    begin = len(assayid_lst_train+assayid_lst_test)
    end = begin + len(dude_pocket_name)
    dude_idxes = range(begin, end)
    begin = end
    end += len(pcba_pocket_name)
    pcba_idxes = range(begin, end)
    begin = end
    end += len(dekois_pocket_name)
    dekois_idxes = range(begin, end)

    if args.test_ckpt is not None:
        # Evaluation-only mode: load the checkpoint and run all benchmarks once.
        graphrec.load_state_dict(torch.load(args.test_ckpt, weights_only=True))
        test_dude(graphrec, device, 0, result_root, dude_pocket_name, dude_idxes)
        test_dekois(graphrec, device, 0, result_root, dekois_pocket_name, dekois_idxes)
        test_pcba(graphrec, device, 0, result_root, pcba_idxes)
    else:
        for epoch in range(args.epochs):
            screen_dataset.set_epoch(epoch)
            train_loader = torch.utils.data.DataLoader(screen_dataset, batch_size=1, shuffle=True, num_workers=8)
            lig2e.weight.data.copy_(torch.tensor(mol_feat).to(device))
            valid_labels = load_valid_label(assayid_lst_test)
            valid_idxes = range(len(assayid_lst_train), len(assayid_lst_train+assayid_lst_test))
            train(graphrec, device, train_loader, optimizer, epoch, valid_idxes, valid_molidxes, valid_labels)
            test_dude(graphrec, device, epoch+1, result_root, dude_pocket_name, dude_idxes)
            test_dekois(graphrec, device, epoch+1, result_root, dekois_pocket_name, dekois_idxes)
            test_pcba(graphrec, device, epoch+1, result_root, pcba_idxes)

            os.makedirs(f"{result_root}/HGNN_save", exist_ok=True)
            torch.save(graphrec.state_dict(), f"{result_root}/HGNN_save/model_{epoch}.pt")

if __name__ == "__main__":
    main()
HGNN/read_fasta.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, re
2
+ import prody as pr
3
+ from tqdm import tqdm
4
+ from multiprocessing import Pool
5
+
6
+
7
+ # import subprocess
8
+ # os.environ["BABEL_LIBDIR"] = "/home/shenchao/.conda/envs/my2/lib/openbabel/3.1.0"
9
+
10
def write_file(output_file, outline):
    """Write the string *outline* to *output_file*, replacing any existing content.

    Uses a context manager so the handle is closed even if the write fails.
    """
    with open(output_file, 'w') as handle:
        handle.write(outline)
14
+
15
+
16
def lig_rename(infile, outfile):
    """Rewrite *infile* so every ATOM/HETATM record carries residue name "LIG".

    Some peptide ligands keep their amino-acid residue names, which impedes
    pocket generation; forcing one residue name (PDB columns 18-20) avoids
    that. Non-coordinate records are copied unchanged.
    """
    is_coord = re.compile(r'^HETATM|^ATOM')
    with open(infile, 'r') as handle:
        lines = handle.readlines()
    newlines = [
        line[:17] + "LIG" + line[20:] if is_coord.search(line) else line
        for line in lines
    ]
    write_file(outfile, ''.join(newlines))
26
+
27
+
28
def check_mol(infile, outfile):
    """Copy *infile* to *outfile*, dropping every line that contains "LIG".

    Some metals can share the ligand's residue ID, which would pull the ligand
    itself into the extracted pocket; stripping LIG records prevents that.
    Re-implemented in pure Python (was ``cat | sed '/LIG/d'``) so it no longer
    depends on a POSIX shell and is safe for paths with spaces/special chars.
    """
    with open(infile, 'r') as src:
        kept = [line for line in src if "LIG" not in line]
    with open(outfile, 'w') as dst:
        dst.writelines(kept)
31
+
32
+
33
def extract_pocket(protpath,
                   ligpath,
                   cutoff=5.0,
                   protname=None,
                   ligname=None,
                   pdb_pocket_file=None,
                   workdir='.'):
    """Carve the binding pocket out of a protein around a ligand and write it to *pdb_pocket_file*.

    protpath: the path of protein file (.pdb).
    ligpath: the path of ligand file (.sdf|.mol2|.pdb); non-PDB formats are
        converted with the external `obabel` binary.
    cutoff: the distance range within the ligand to determine the pocket.
    protname: the name of the protein (defaults to the file's basename).
    ligname: the name of the ligand (defaults to the file's basename).
    pdb_pocket_file: output path for the final pocket PDB.
    workdir: working directory for intermediate files.

    NOTE(review): relies on external tools (`obabel`, `cp`) via os.system and
    on ProDy for parsing/selection; intermediate files are created and removed
    in *workdir*, so the statement order below matters.
    """
    if protname is None:
        protname = os.path.basename(protpath).split('.')[0]
    if ligname is None:
        ligname = os.path.basename(ligpath).split('.')[0]


    # Ensure the ligand is available as PDB: convert with obabel, or copy as-is.
    if not re.search(r'.pdb$', ligpath):
        os.system(f"obabel {ligpath} -O {workdir}/{ligname}.pdb")
    else:
        os.system(f"cp {ligpath} {workdir}/{ligname}.pdb")

    xprot = pr.parsePDB(protpath)
    # xlig = pr.parsePDB("%s/%s.pdb"%(workdir, ligname))

    # if (xlig.getResnames() == xlig.getResnames()[0]).all():
    #     lresname = xlig.getResnames()[0]
    # else:
    # Force a single residue name ("LIG") on the ligand, then swap the renamed
    # file into place so the parse below sees a uniform residue name.
    lig_rename("%s/%s.pdb" % (workdir, ligname), "%s/%s2.pdb" % (workdir, ligname))
    os.remove("%s/%s.pdb" % (workdir, ligname))
    os.rename("%s/%s2.pdb" % (workdir, ligname), "%s/%s.pdb" % (workdir, ligname))
    xlig = pr.parsePDB("%s/%s.pdb" % (workdir, ligname))
    lresname = xlig.getResnames()[0]
    xcom = xlig + xprot

    # select ONLY atoms that belong to the protein
    # (residues with any atom strictly within `cutoff` of the ligand residue).
    ret = xcom.select(f'same residue as exwithin %s of resname %s' % (cutoff, lresname))

    pr.writePDB("%s/%s_pocket_%s_temp.pdb" % (workdir, protname, cutoff), ret)
    # ret = pr.parsePDB("%s/%s_pocket_%s.pdb"%(workdir, protname, cutoff))

    # Drop any LIG lines that slipped into the selection, then clean up.
    check_mol("%s/%s_pocket_%s_temp.pdb" % (workdir, protname, cutoff), pdb_pocket_file)
    os.remove("%s/%s_pocket_%s_temp.pdb" % (workdir, protname, cutoff))
80
+
81
+
82
def get_fasta_seq(fasta_file):
    """Return the concatenated amino-acid sequence from *fasta_file*.

    FIX: all header lines (starting with '>') are skipped, not just the first
    line, so multi-record FASTA files no longer leak header text into the
    returned sequence. Blank lines are ignored; for a single-record file the
    result is identical to the original implementation.
    """
    with open(fasta_file) as f:
        chunks = [
            line.strip()
            for line in f
            if line.strip() and not line.startswith(">")
        ]
    return "".join(chunks)
89
+
90
+
91
def read_fasta_from_protein(pdb_file, lig_file, target_id="test", cutoff=5.0, pdb_pocket_file="test_pocket.pdb", fasta_pocket_file="test_pocket.fasta"):
    """Extract the binding pocket around *lig_file* from *pdb_file* and return its FASTA sequence.

    Returns None when the protein file is missing or pocket extraction fails.
    Relies on the external `./pdb2fasta` binary being present in the current
    working directory — TODO confirm it is shipped alongside this script.
    """
    if not os.path.exists(pdb_file):
        return

    try:
        extract_pocket(pdb_file,
                       lig_file,
                       cutoff=cutoff,
                       protname=target_id,
                       pdb_pocket_file=pdb_pocket_file,
                       ligname=f"{target_id}_ligand")
    except Exception as e:
        # Best-effort: extraction failures (bad ligand file, ProDy parse
        # errors) are printed but not fatal; the caller receives None.
        print(e)
        return

    # Convert the extracted pocket PDB to FASTA and read the sequence back.
    os.system(f"./pdb2fasta {pdb_pocket_file} > {fasta_pocket_file}")
    return get_fasta_seq(fasta_pocket_file)
108
+
109
+
110
def read_fasta_from_pocket(pocket_pdb_file, fasta_pocket_file="test_pocket.fasta"):
    """Convert an already-extracted pocket PDB to FASTA via ./pdb2fasta and return the sequence."""
    command = f"./pdb2fasta {pocket_pdb_file} > {fasta_pocket_file}"
    os.system(command)
    return get_fasta_seq(fasta_pocket_file)
HGNN/screen_dataset.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import math
4
+ import numpy as np
5
+ import contextlib
6
+ import copy
7
+ import torch
8
+ from torch.utils.data import Dataset, sampler, DataLoader
9
+
10
+
11
def load_ligname():
    """Map pdbid -> ligand code from the PDBbind v2020 data index files.

    The refined index is read second, so its entries win on collisions.
    Entries whose ligand field (stripped of surrounding parentheses) is
    empty are skipped.
    """
    index_files = (
        "./data/PDBbind_v2020/index/INDEX_general_PL_data.2020",
        "./data/PDBbind_v2020/index/INDEX_refined_data.2020",
    )
    pdbbind_lig_dict = {}
    for path in index_files:
        with open(path) as handle:
            for raw in handle:
                if raw.startswith('#'):
                    continue
                fields = raw.strip().split()
                # Last column looks like "(LIG)"; strip the parentheses.
                lig_code = fields[-1][1:-1]
                if lig_code:
                    pdbbind_lig_dict[fields[0]] = lig_code
    return pdbbind_lig_dict
37
+
38
def load_uniprotid():
    """Map pdbid -> UniProt id from the PDBbind v2020 name index files.

    The general index is read second, so its entries override the refined
    ones on collisions.  "------" marks a missing UniProt annotation and is
    skipped.
    """
    name_indexes = (
        "./data/PDBbind_v2020/index/INDEX_refined_name.2020",
        "./data/PDBbind_v2020/index/INDEX_general_PL_name.2020",
    )
    uniprot_id_dict = {}
    for path in name_indexes:
        with open(path) as handle:
            for raw in handle:
                if raw.startswith('#'):
                    continue
                fields = raw.strip().split()
                uniprot_id = fields[2]
                if uniprot_id and uniprot_id != "------":
                    uniprot_id_dict[fields[0]] = uniprot_id
    return uniprot_id_dict
61
+
62
def load_pocket_dude(result_root):
    """Load precomputed DUD-E pocket embeddings.

    Returns (feats, names): a row-concatenated embedding array and the
    matching list of target directory names, in sorted order.
    """
    data_root = f"{result_root}/DUDE"
    feat_chunks = []
    target_names = []
    for target in sorted(os.listdir(data_root)):
        feat_chunks.append(np.load(f"{data_root}/{target}/saved_target_embed.npy", allow_pickle=True))
        target_names.append(target)
    return np.concatenate(feat_chunks, axis=0), target_names
73
+
74
def load_pocket_dekois(result_root):
    """Load precomputed DEKOIS pocket embeddings.

    Returns (feats, names): a row-concatenated embedding array and the
    matching list of target directory names, in sorted order.
    """
    data_root = f"{result_root}/DEKOIS"
    feat_chunks = []
    target_names = []
    for target in sorted(os.listdir(data_root)):
        feat_chunks.append(np.load(f"{data_root}/{target}/saved_target_embed.npy", allow_pickle=True))
        target_names.append(target)
    return np.concatenate(feat_chunks, axis=0), target_names
85
+
86
def load_pocket_pcba(result_root):
    """Load precomputed LIT-PCBA pocket embeddings.

    Returns (feats, names) where names are "<target>_<pocket>" strings.  A
    target that ships a single shared embedding gets it duplicated once per
    pocket name so rows and names stay aligned.
    """
    data_root = f"{result_root}/PCBA"
    feat_chunks = []
    pocket_names = []
    for target in sorted(os.listdir(data_root)):
        embeds = np.load(f"{data_root}/{target}/saved_target_embed.npy", allow_pickle=True)
        target_names = []
        for name_group in json.load(open(f"{data_root}/{target}/saved_pocket_names.json")):
            target_names.extend(f"{target}_{name}" for name in name_group)

        if embeds.shape[0] == 1:
            # One shared embedding for every pocket of this target: tile it.
            embeds = np.concatenate([embeds] * len(target_names), axis=0)
        feat_chunks.append(embeds)
        pocket_names.extend(target_names)

    return np.concatenate(feat_chunks, axis=0), pocket_names
103
+
104
+
105
def read_cluster_file(cluster_file):
    """Parse a cd-hit ``.clstr`` file into {member_id: [ids in its cluster]}.

    Member lines look like ``0  123aa, >sp|P12345|NAME... *``; the UniProt id
    is the second ``|``-separated field.  Every member (including itself) is
    listed under every other member of the same cluster.

    Fix: the original only flushed a cluster upon seeing the NEXT ``>`` header,
    so the file's final cluster was silently dropped; we now flush it too.
    """
    protein_clstr_dict = {}

    def _flush(members):
        # Record the full pairwise membership for one cluster.
        for a in members:
            for b in members:
                protein_clstr_dict.setdefault(a, []).append(b)

    with open(cluster_file) as f:
        line_in_clstr = []
        for line in f.readlines():
            if line.startswith(">"):
                _flush(line_in_clstr)
                line_in_clstr = []
            else:
                line_in_clstr.append(line.split('|')[1])
        _flush(line_in_clstr)  # fix: don't lose the trailing cluster
    return protein_clstr_dict
121
+
122
def load_assayinfo(data_root, result_root):
    """Assemble the training assay list: PDBbind/CASF labels plus filtered
    BindingDB assays.

    Each returned assay dict carries "assay_id", "pockets" and "ligands"
    (sorted by activity, descending); PDBbind entries also get "domain".
    """
    # PDBbind training labels plus the CASF test labels.
    labels = json.load(open(f"{data_root}/train_label_pdbbind_seq.json")) + \
        json.load(open("../test_datasets/casf_label_seq.json"))
    save_dir_bdb = f"{result_root}/BDB"
    # Only BindingDB ligands whose SMILES have precomputed embeddings are usable.
    bdb_mol_smi = json.load(open(f"{save_dir_bdb}/bdb_mol_smis.json"))
    bdb_mol_smi = set(bdb_mol_smi)
    for label in labels:
        # PDBbind assay id is the pdbid prefix of the first pocket name.
        label["assay_id"] = label["pockets"][0].split("_")[0]
        label["domain"] = "pdbbind"

    # breakpoint()
    labels_bdb = json.load(open(f"{data_root}/train_label_blend_seq_full.json"))
    # UniProt ids of benchmark targets (DUD-E / LIT-PCBA / DEKOIS) are excluded
    # from BindingDB training data to avoid target leakage.
    non_repeat_uniprot = []
    testset_uniport_root = "../test_datasets"
    non_repeat_uniprot += [x[0] for x in json.load(open(f"{testset_uniport_root}/dude.json"))]
    non_repeat_uniprot += [x[0] for x in json.load(open(f"{testset_uniport_root}/PCBA.json"))]
    non_repeat_uniprot += [x[0] for x in json.load(open(f"{testset_uniport_root}/dekois.json"))]
    # Stricter exclusion set (80%-identity cluster mates of benchmark targets);
    # computed but currently NOT applied — see the commented filter below.
    non_repeat_uniprot_strict = []
    protein_clstr_dict_40 = read_cluster_file(f"{data_root}/uniport40.clstr")  # NOTE(review): loaded but unused
    protein_clstr_dict_80 = read_cluster_file(f"{data_root}/uniport80.clstr")
    for uniprot in non_repeat_uniprot:
        non_repeat_uniprot_strict += protein_clstr_dict_80.get(uniprot, [])
        non_repeat_uniprot_strict.append(uniprot)
    old_len = len(labels_bdb)  # NOTE(review): unused, presumably a leftover debug counter
    # Drop assays overlapping the FEP benchmark and the benchmark uniprots.
    non_repeat_assayids = json.load(open(os.path.join(data_root, "fep_assays.json")))
    labels_bdb = [x for x in labels_bdb if (x["assay_id"] not in non_repeat_assayids)]
    labels_bdb = [x for x in labels_bdb if (x["uniprot"] not in non_repeat_uniprot)]

    # Keep only ligands with embeddings and activity >= 5 (pActivity-style
    # threshold — TODO confirm units); drop assays left with no ligands.
    labels_bdb_new = []
    for label in labels_bdb:
        ligands = label["ligands"]
        ligands_new = []
        for lig in ligands:
            if lig["smi"] in bdb_mol_smi and lig["act"] >= 5:
                ligands_new.append(lig)
        label["ligands"] = ligands_new
        if len(ligands_new) > 0:
            labels_bdb_new.append(label)

    labels += labels_bdb_new
    # Sort each assay's ligands by activity, most active first.
    for label in labels:
        label["ligands"] = sorted(label["ligands"], key=lambda x: x["act"], reverse=True)

    # labels = [x for x in labels if (x["uniprot"] not in non_repeat_uniprot_strict)]
    return labels
167
+
168
def load_id_dict(result_root, assayinfo_lst):
    """Collect train/test pocket id lists and their embedding matrices.

    Each BindingDB assay contributes sqrt(#ligands) randomly chosen pocket
    embeddings; PDBbind train pockets are filtered to assays present in
    assayinfo_lst.  Returns
    (train_pocket_ids, test_pocket_ids, pocket_feat_train, pocket_feat_test),
    with ids and feature rows kept in matching order.
    """
    import random
    random.seed(42)  # deterministic pocket sampling
    bdb_dir = f"{result_root}/BDB"
    pdbbind_dir = f"{result_root}/PDBBind"

    pocket_names = json.load(open(f"{bdb_dir}/bdb_pocket_names.json"))
    pocket_embed = np.load(f"{bdb_dir}/bdb_pocket_reps.npy")
    name2idx = {name:i for i, name in enumerate(pocket_names)}

    assay_feat_lst = []
    bdb_assayid_lst = []
    for assay in assayinfo_lst:
        assay_id = assay["assay_id"]
        if assay.get("domain", None) == "pdbbind":
            continue  # PDBbind assays are handled separately below
        pockets = assay["pockets"]
        # Repeat an assay sqrt(#ligands) times so larger assays get more
        # (randomly sampled) pocket representatives.
        repeat_num = len(assay["ligands"])
        repeat_num = int(np.sqrt(repeat_num))
        for i in range(repeat_num):
            pocket = random.choice(pockets)
            assay_feat_lst.append(pocket_embed[name2idx[pocket]])
            bdb_assayid_lst.append(assay_id)

    bdb_assay_feat = np.stack(assay_feat_lst)

    train_pdbbind_ids = json.load(open(f'{pdbbind_dir}/train_pdbbind_ids.json'))
    train_pdbbind_pocket_embed = np.load(f"{pdbbind_dir}/train_pocket_reps.npy")
    train_pdbbind_ids_new = []
    train_pdbbind_pocket_embed_new = []
    pdbbind_aidlist = [assay["assay_id"] for assay in assayinfo_lst if assay.get("domain", None) == "pdbbind"]
    pdbbind_aidset = set(pdbbind_aidlist)
    # Keep only PDBbind pockets whose assay survived the upstream filtering.
    # (NOTE: "id" shadows the builtin; left unchanged here.)
    for id, embed in zip(train_pdbbind_ids, train_pdbbind_pocket_embed):
        if id in pdbbind_aidset:
            train_pdbbind_ids_new.append(id)
            train_pdbbind_pocket_embed_new.append(embed)

    train_pdbbind_ids = train_pdbbind_ids_new
    train_pdbbind_pocket_embed = np.stack(train_pdbbind_pocket_embed_new)

    # BindingDB rows first, then PDBbind rows — ids and features aligned.
    train_pocket = bdb_assayid_lst + train_pdbbind_ids
    pocket_feat_train = np.concatenate([bdb_assay_feat, train_pdbbind_pocket_embed])
    test_pocket = json.load(open(f'{pdbbind_dir}/test_pdbbind_ids.json'))
    pocket_feat_test = np.load(f'{pdbbind_dir}/test_pocket_reps.npy')

    return train_pocket, test_pocket, pocket_feat_train, pocket_feat_test
214
+
215
+
216
def load_datas(data_root, result_root):
    """Load every precomputed embedding table and id list used for screening.

    Alignment contract: pocket feature rows follow the order
    train -> test -> DUD-E -> PCBA -> DEKOIS, matching assayid_lst_all;
    molecule feature rows follow BDB -> PDBbind train -> PDBbind test,
    matching mol_smi_lst.  test_molidxes indexes the tail (test) molecules.
    """
    assayinfo_lst = load_assayinfo(data_root, result_root)
    assayid_lst_train, assayid_lst_test, pocket_feat_train, pocket_feat_test = load_id_dict(result_root, assayinfo_lst)

    dude_pocket_feat, dude_pocket_name = load_pocket_dude(result_root)

    pcba_pocket_feat, pcba_pocket_name = load_pocket_pcba(result_root)

    dekois_pocket_feat, dekois_pocket_name = load_pocket_dekois(result_root)

    # Keep feature rows and id list in the same concatenation order.
    pocket_feat = np.concatenate((pocket_feat_train, pocket_feat_test, dude_pocket_feat, pcba_pocket_feat, dekois_pocket_feat), axis=0)
    assayid_lst_all = assayid_lst_train + assayid_lst_test + dude_pocket_name + pcba_pocket_name + dekois_pocket_name

    save_dir_bdb = f"{result_root}/BDB"
    save_dir_pdbbind = f"{result_root}/PDBBind"
    mol_feat_train_bdb = np.load(f'{save_dir_bdb}/bdb_mol_reps.npy')
    mol_feat_train_pdbbind = np.load(f'{save_dir_pdbbind}/train_mol_reps.npy')
    mol_feat_test = np.load(f'{save_dir_pdbbind}/test_mol_reps.npy')
    mol_feat = np.concatenate((mol_feat_train_bdb, mol_feat_train_pdbbind, mol_feat_test), axis=0)
    mol_smi_lst = json.load(open(f"{save_dir_bdb}/bdb_mol_smis.json")) + json.load(open(f"{save_dir_pdbbind}/train_mol_smis.json")) + json.load(open(f"{save_dir_pdbbind}/test_mol_smis.json"))
    # Test molecules occupy the tail of mol_smi_lst / mol_feat.
    test_len = len(json.load(open(f"{save_dir_pdbbind}/test_mol_smis.json")))
    test_molidxes = range(len(mol_smi_lst)-test_len, len(mol_smi_lst))
    return assayinfo_lst, pocket_feat, mol_feat, assayid_lst_all, mol_smi_lst, \
        assayid_lst_train, assayid_lst_test, dude_pocket_name, pcba_pocket_name, dekois_pocket_name, test_molidxes
240
+
241
def load_valid_label(assayid_lst_test):
    """Build pairwise CASF cluster-identity labels.

    Reads ./data/CoreSet.dat (pdbid in the first column, cluster id in the
    last) and returns an NxN matrix with labels[i, j] = 1 iff the i-th and
    j-th test pdbids belong to the same CASF cluster.
    """
    pdbid2cluster = {}
    with open("./data/CoreSet.dat") as handle:
        for row in handle.readlines()[1:]:  # skip the header line
            fields = row.strip().split()
            pdbid2cluster[fields[0]] = fields[-1]

    n_test = len(assayid_lst_test)
    labels = np.zeros((n_test, n_test))
    for i, pdbid_a in enumerate(assayid_lst_test):
        for j, pdbid_b in enumerate(assayid_lst_test):
            labels[i, j] = 1 if pdbid2cluster[pdbid_a] == pdbid2cluster[pdbid_b] else 0
    return labels
258
+
259
def load_pocket_pocket_graph(data_root, assayid_lst_all, assayid_lst_train):
    """Build the pocket-pocket similarity graph used for neighbor aggregation.

    Returns {assay_id: [(train_assay_id, score), ...]} keeping only neighbors
    that are training assays with alignment score >= 0.5, sorted by score
    descending.  Test-side entries override train-side ones on key collision.

    (Cleanup: removed an unused local ``import pickle`` and a dead
    ``# breakpoint()`` left over from debugging.)
    """
    neighbor_dict_train = json.load(
        open(f"{data_root}/align_pair_res_train_10.23.json"))
    train_keys = json.load(
        open(f"{data_root}/align_train_keys_10.23.json"))
    # The train alignment file is keyed by integer index; remap to assay ids.
    neighbor_dict_train_new = {}
    for idx, neighbors in neighbor_dict_train.items():
        neighbor_dict_train_new[train_keys[int(idx)]] = neighbors
    neighbor_dict_train = neighbor_dict_train_new
    assayid_set = set(assayid_lst_all)
    assayid_set_train = set(assayid_lst_train)
    PPGraph = {}

    for assayid_1 in neighbor_dict_train.keys():
        if assayid_1 not in assayid_set:
            continue
        # Highest-similarity neighbors first.
        neighbor_dict_train[assayid_1] = sorted(neighbor_dict_train[assayid_1], key=lambda x: x[1], reverse=True)

        score_new = []
        for assayid_2, score in neighbor_dict_train[assayid_1]:
            if assayid_2 not in assayid_set_train:
                continue
            if score < 0.5:  # similarity threshold for keeping an edge
                continue
            score_new.append((assayid_2, score))
        PPGraph[assayid_1] = score_new

    align_res_test = json.load(open(f"{data_root}/align_pair_res_test_10.23.json"))
    align_score_test = {}

    for test_id in align_res_test.keys():
        if test_id not in assayid_set:
            continue
        pocket_sim_infos = align_res_test[test_id]
        pocket_sim_infos = sorted(pocket_sim_infos, key=lambda x: x[1], reverse=True)
        score_new = []
        for test_target, score in pocket_sim_infos:
            # Test-side neighbor names carry a file extension; strip it.
            test_target = test_target.split('.')[0]
            if test_target not in assayid_set_train:
                continue
            if score < 0.5:
                continue
            score_new.append((test_target, score))
        align_score_test[test_id] = score_new

    PPGraph = {**PPGraph, **align_score_test}
    return PPGraph
308
+
309
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Temporarily seed NumPy's global PRNG; restore the prior state on exit.

    A ``None`` seed is a no-op.  Extra seeds are folded in by hashing the
    whole tuple and reducing it modulo 1e6.
    """
    if seed is None:
        yield
        return
    if addl_seeds:
        seed = int(hash((seed,) + addl_seeds) % 1e6)
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        np.random.set_state(saved_state)
324
+
325
class ScreenDataset(Dataset):
    """Batch sampler for contrastive pocket-ligand screening training.

    Each ``__getitem__`` returns one full batch: pocket indices, one sampled
    positive-ligand index per pocket, and a pocket-by-ligand 0/1 label matrix.
    """

    def __init__(self, batch_size, assay_graph, assayinfo_lst, assayid_lst_all, mol_smi_lst, assayid_lst_train):
        self.batch_size = batch_size
        # Train pockets occupy the first len(assayid_lst_train) slots of
        # assayid_lst_all (see load_datas), so plain range indices suffice.
        self.train_idxes = list(range(len(assayid_lst_train)))
        self.assayid_set_train = set(assayid_lst_train)
        self.train_idxes_epoch = copy.deepcopy(self.train_idxes)
        self.assay_graph = assay_graph
        self.assayinfo_dicts = {x["assay_id"]: x for x in assayinfo_lst}
        self.smi2idx = {smi:idx for idx, smi in enumerate(mol_smi_lst)}
        self.uniprotid_dict = load_uniprotid()
        self.pocket_lig_graph = self.load_graph()
        self.seed = 66  # base seed for all deterministic shuffles/draws
        self.assayid2idxes = {}
        for idx, assayid in enumerate(assayid_lst_all):
            if assayid not in self.assayid2idxes:
                self.assayid2idxes[assayid] = []
            self.assayid2idxes[assayid].append(idx)
        self.idx2assayid = assayid_lst_all
        self.epoch = 0

    def set_epoch(self, epoch):
        """Reshuffle the training order deterministically for this epoch."""
        self.epoch = epoch
        with numpy_seed(self.seed, epoch):
            self.train_idxes_epoch = copy.deepcopy(self.train_idxes)
            np.random.shuffle(self.train_idxes_epoch)

    def load_graph(self):
        """Build (or load the cached) pocket -> positive-ligand-SMILES map.

        A ligand is positive for a pocket if it is active (act >= 5) in the
        pocket's own assay, or comes from a neighboring assay that either
        shares a ligand or targets the same UniProt id.  The result is cached
        at ./data/pocket_lig_graph.json.
        """
        pocket_lig_graph = {}
        if os.path.exists("./data/pocket_lig_graph.json"):
            pocket_lig_graph = json.load(open("./data/pocket_lig_graph.json"))
        else:
            from tqdm import tqdm
            for assayid in tqdm(self.assayid2idxes.keys()):
                if assayid not in self.assayid_set_train:
                    continue
                ligands = self.assayinfo_dicts[assayid]["ligands"]
                lig_candidate = []
                # Single-ligand assays keep their ligand regardless of activity.
                if len(ligands) > 1:
                    lig_assay = [x["smi"] for x in ligands if x["act"] >= 5]
                else:
                    lig_assay = [x["smi"] for x in ligands]
                lig_candidate += lig_assay
                lig_assay = set(lig_assay)
                uniprot = self.assayinfo_dicts[assayid]["uniprot"]

                for assayid_nbr, score in self.assay_graph.get(assayid, []):
                    if assayid_nbr not in self.assayinfo_dicts:
                        continue
                    assay_nbr = self.assayinfo_dicts[assayid_nbr]
                    uniprot_nbr = assay_nbr["uniprot"]
                    ligands_nbr = assay_nbr["ligands"]
                    # NOTE(review): this tests len(ligands) (the anchor assay),
                    # not len(ligands_nbr) — possibly meant to mirror the
                    # single-ligand rule on the neighbor; confirm.
                    if len(ligands) > 1:
                        lig_candidate_nbr = [x["smi"] for x in ligands_nbr if x["act"] >= 5]
                    else:
                        lig_candidate_nbr = [x["smi"] for x in ligands_nbr]
                    if assayid_nbr not in self.assayid_set_train:
                        continue
                    # Merge the neighbor's ligands when the assays share a
                    # ligand or target the same protein.
                    if len(lig_assay & set(lig_candidate_nbr)) > 0:
                        lig_candidate += lig_candidate_nbr
                    elif uniprot == uniprot_nbr:
                        lig_candidate += lig_candidate_nbr

                # Keep only ligands that have a molecule embedding.
                pocket_lig_graph[assayid] = [x for x in set(lig_candidate) if x in self.smi2idx]

            json.dump(pocket_lig_graph, open("./data/pocket_lig_graph.json", "w"))
        return pocket_lig_graph

    def __getitem__(self, item):
        """Return batch ``item``: (pocket idx tensor, ligand idx tensor, labels)."""
        pocket_idx_batch = self.train_idxes_epoch[item*self.batch_size:(item+1)*self.batch_size]
        pocket_batch = [self.idx2assayid[idx] for idx in pocket_idx_batch]
        lig_batch = []
        lig_idx_batch = []
        epoch = self.epoch
        for pocket in pocket_batch:
            lig_candidate = self.pocket_lig_graph[pocket]
            # NOTE(review): the RNG is re-seeded with the same (seed, epoch,
            # item) for every pocket in the batch, so the per-pocket draws are
            # correlated — confirm whether this is intended.
            with numpy_seed(self.seed, epoch, item):
                lig = np.random.choice(lig_candidate)
            lig_batch.append(lig)

            lig_idx_batch.append(self.smi2idx[lig])

        # labels[i, j] = 1 iff ligand j is a known positive of pocket i.
        labels = np.zeros((self.batch_size, self.batch_size))
        for i, pocket in enumerate(pocket_batch):
            for j, lig in enumerate(lig_batch):
                if lig in self.pocket_lig_graph[pocket]:
                    labels[i, j] = 1
                else:
                    labels[i, j] = 0

        return torch.tensor(pocket_idx_batch), torch.tensor(lig_idx_batch), torch.tensor(labels)

    def __len__(self):
        # Drop the final partial batch so the label matrix stays square.
        return len(self.train_idxes_epoch) // self.batch_size
418
+
419
+
420
+
HGNN/screening.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from PL_Encoder import PLEncoder
8
+ from PL_Aggregator import PLAggregator
9
+ from PP_Encoder import PPEncoder
10
+ from PP_Aggregator import PPAggregator
11
+ import torch.nn.functional as F
12
+ import torch.utils.data
13
+ import argparse
14
+ import os
15
+ from util import cal_metrics
16
+ from read_fasta import read_fasta_from_pocket, read_fasta_from_protein
17
+ from align import get_neighbor_pocket
18
+
19
+
20
class HGNN(nn.Module):
    """Pocket/ligand model: ``enc_u`` encodes pockets (optionally ``enc_v``
    ligands); a symmetric contrastive (CLIP-style) loss aligns the spaces."""

    def __init__(self, enc_u, enc_v=None, r2e=None):
        super(HGNN, self).__init__()
        self.enc_u = enc_u  # pocket-side encoder
        self.enc_v = enc_v  # ligand-side encoder (may be None at inference)
        self.embed_dim = enc_u.embed_dim

        # NOTE(review): these projections and BatchNorms are not referenced in
        # forward()/criterion(); presumably kept for checkpoint compatibility.
        self.w_ur1 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_ur2 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_vr1 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_vr2 = nn.Linear(self.embed_dim, self.embed_dim)

        self.r2e = r2e
        self.bn1 = nn.BatchNorm1d(self.embed_dim, momentum=0.5)
        self.bn2 = nn.BatchNorm1d(self.embed_dim, momentum=0.5)

        # Learnable softmax temperature, initialised to log(14).
        # NOTE(review): allocated on "cuda" unconditionally — breaks CPU-only
        # runs; confirm before running without a GPU.
        self.logit_scale = nn.Parameter(torch.ones([1], device="cuda") * np.log(14))

    def trainable_parameters(self):
        """Yield only parameters with ``requires_grad=True``."""
        for name, param in self.named_parameters(recurse=True):
            if param.requires_grad:
                yield param

    def forward(self, nodes_u, nodes_v):
        """Encode pockets (conditioned on ligands) and ligands; return both."""
        embeds_u = self.enc_u(nodes_u, nodes_v)
        embeds_v = self.enc_v(nodes_v)
        return embeds_u, embeds_v

    def criterion(self, x_u, x_v, labels):
        """Symmetric contrastive loss over the pocket-ligand similarity matrix.

        ``labels`` is a 0/1 matrix of known positives; off-diagonal positives
        are masked to -1e6 so each row/column has exactly one target (its
        diagonal pair).  Returns (loss, mean screening metrics, raw
        similarity matrix).
        """
        netout = torch.matmul(x_u, torch.transpose(x_v, 0, 1))
        # Temperature applied with a detached scale (not trained here).
        score = netout * self.logit_scale.exp().detach()
        # Mask every known positive except the diagonal pair.
        score = (labels - torch.eye(len(labels)).to(labels.device)) * -1e6 + score

        lprobs_pocket = F.log_softmax(score.float(), dim=-1)
        lprobs_pocket = lprobs_pocket.view(-1, lprobs_pocket.size(-1))
        sample_size = lprobs_pocket.size(0)
        # Diagonal targets: row i's positive is column i.
        targets = torch.arange(sample_size, dtype=torch.long).view(-1).cuda()

        # pocket retrieve mol
        loss_pocket = F.nll_loss(
            lprobs_pocket,
            targets,
            reduction="mean"
        )

        lprobs_mol = F.log_softmax(torch.transpose(score.float(), 0, 1), dim=-1)
        lprobs_mol = lprobs_mol.view(-1, lprobs_mol.size(-1))
        lprobs_mol = lprobs_mol[:sample_size]

        # mol retrieve pocket
        loss_mol = F.nll_loss(
            lprobs_mol,
            targets,
            reduction="mean"
        )

        loss = 0.5 * loss_pocket + 0.5 * loss_mol

        # Per-pocket virtual-screening metrics on the unscaled similarities.
        ef_all = []
        for i in range(len(netout)):
            act_pocket = labels[i]
            affi_pocket = netout[i]
            top1_index = torch.argmax(affi_pocket)
            top1_act = act_pocket[top1_index]  # NOTE(review): computed but unused
            ef_all.append(cal_metrics(affi_pocket.detach().cpu().numpy(), act_pocket.detach().cpu().numpy()))
        ef_mean = {k: np.mean([x[k] for x in ef_all]) for k in ef_all[0].keys()}

        return loss, ef_mean, netout

    def loss(self, nodes_u, nodes_v, labels):
        """Convenience wrapper: forward + criterion; returns (loss, metrics)."""
        x_u, x_v = self.forward(nodes_u, nodes_v)
        loss, ef_mean, netout = self.criterion(x_u, x_v, labels)
        return loss, ef_mean

    def refine_pocket(self, pocket_embed, neighbor_pocket_list):
        """Refine one pocket embedding by aggregating its neighbor pockets."""
        embeds_u = self.enc_u.refine_pocket(pocket_embed, neighbor_pocket_list)
        return embeds_u
99
+
100
+
101
+
102
def main():
    """CLI entry: load a trained HGNN and refine one pocket embedding.

    Reads the raw pocket embedding (--pocket_embed), derives the pocket FASTA
    either from a pocket PDB (--pocket_pdb) or from a protein+ligand PDB pair,
    retrieves similar training pockets, and saves the refined embedding to
    --save_file (falling back to the raw embedding when no neighbor is found).
    """
    # Training settings
    parser = argparse.ArgumentParser(description='HGNN model inference')
    parser.add_argument('--embed_dim', type=int, default=128, metavar='N', help='embedding size')
    parser.add_argument("--test_ckpt", type=str, default=None)
    parser.add_argument("--data_root", type=str, default="../data")
    parser.add_argument("--result_root", type=str, default="../result/pocket_ranking")
    parser.add_argument("--pocket_embed", type=str, default="../example/pocket_embed.npy")
    parser.add_argument("--save_file", type=str, default="../example/refined_pocket.npy")
    parser.add_argument("--pocket_pdb", type=str, default=None)
    parser.add_argument("--protein_pdb", type=str, default="../example/protein.pdb")
    parser.add_argument("--ligand_pdb", type=str, default="../example/ligand.pdb")

    args = parser.parse_args()

    # Seed every RNG for reproducible inference.
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    use_cuda = False
    if torch.cuda.is_available():
        use_cuda = True
    device = torch.device("cuda" if use_cuda else "cpu")

    embed_dim = args.embed_dim
    type2e = nn.Embedding(10, embed_dim).to(device)

    # load model: PL (pocket-ligand) encoder wrapped by the PP (pocket-pocket)
    # encoder, matching the training-time architecture.
    agg_pocket = PLAggregator(r2e=type2e, embed_dim=embed_dim, cuda=device, uv=True)
    enc_pocket = PLEncoder(embed_dim=embed_dim, aggregator=agg_pocket, cuda=device, uv=True)
    agg_pocket_sim = PPAggregator(embed_dim=embed_dim, cuda=device)
    enc_pocket = PPEncoder(enc_pocket, embed_dim=embed_dim, aggregator=agg_pocket_sim, cuda=device)

    model = HGNN(enc_pocket).to(device)
    # strict=False: the checkpoint may lack layers that are unused at inference.
    model.load_state_dict(torch.load(args.test_ckpt, weights_only=True), strict=False)
    model.eval()

    # load pocket embedding and fasta
    pocket_embed = torch.tensor(np.load(args.pocket_embed)).to(device)

    if args.pocket_pdb is not None:
        pocket_fasta = read_fasta_from_pocket(args.pocket_pdb)
    else:
        pocket_fasta = read_fasta_from_protein(args.protein_pdb, args.ligand_pdb)

    # get neighbor pocket
    neighbor_pocket_list = get_neighbor_pocket(pocket_fasta, args.data_root, args.result_root, device)  # [(pocket_embed, ligand_embed, similarity)]

    # get refined pocket; fall back to the raw embedding when no neighbor exists.
    if len(neighbor_pocket_list) > 0:
        with torch.no_grad():
            refined_pocket = model.refine_pocket(pocket_embed, neighbor_pocket_list)
        refined_pocket = refined_pocket.cpu().numpy()
    else:
        refined_pocket = pocket_embed.cpu().numpy()

    print("finished, saving refined pocket embedding into:", args.save_file)
    np.save(args.save_file, refined_pocket)


if __name__ == "__main__":
    main()
HGNN/test_pocket.fasta ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ >aa2ar_pocket_5:_ 19
2
+ VLTLFEMMNWLXNHALMYI
HGNN/util.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from rdkit.ML.Scoring.Scoring import CalcBEDROC, CalcAUC, CalcEnrichment
3
+ from sklearn.metrics import roc_curve
4
+
5
def re_new(y_true, y_score, ratio):
    """Robust enrichment at a false-positive ratio.

    Walks predictions from highest to lowest score, counting true/false
    positives, and stops once false positives reach ratio * negatives.
    Returns (tp * negatives) / (positives * fp), i.e. TPR/FPR at that cutoff.
    """
    positives = sum(y_true)
    negatives = len(y_true) - positives
    fp_cap = ratio * negatives
    tp = fp = 0
    for idx in np.argsort(y_score)[::-1]:
        if y_true[idx] == 1:
            tp += 1
            continue
        fp += 1
        if fp >= fp_cap:
            break
    return (tp * negatives) / (positives * fp)
21
+
22
+
23
def calc_re(y_true, y_score, ratio_list):
    """Compute enrichment (via re_new) at each false-positive ratio.

    Parameters:
    - y_true: true binary labels (0 or 1)
    - y_score: predicted scores
    - ratio_list: false-positive ratios to evaluate at

    Returns:
    - dict mapping str(ratio) -> enrichment value

    Cleanup: the previous version also ran sklearn's roc_curve and kept a
    commented-out interpolation variant, none of which contributed to the
    returned value; only the re_new path was live.
    """
    return {str(ratio): re_new(y_true, y_score, ratio) for ratio in ratio_list}
47
+
48
+
49
def cal_metrics(y_score, y_true, alpha=80.5):
    """Compute virtual-screening metrics for one target.

    Parameters:
    - y_score: predicted scores or probabilities (1-D array-like)
    - y_true: true binary labels (0 or 1)
    - alpha: BEDROC parameter controlling early-retrieval emphasis

    Returns:
    - dict with BEDROC, AUC and enrichment factors at 0.5%, 1% and 5%
    """
    # Pair each score with its label and sort rows by score, descending —
    # the rdkit scoring helpers expect ranked rows with the label in column 1.
    scores = np.expand_dims(y_score, axis=1)
    y_true = np.expand_dims(y_true, axis=1)
    scores = np.concatenate((scores, y_true), axis=1)
    scores = scores[scores[:, 0].argsort()[::-1]]
    # Fix: honor the alpha argument (it was hard-coded to 80.5 before);
    # also dropped an unused top-0.5% counting loop.
    bedroc = CalcBEDROC(scores, 1, alpha)
    auc = CalcAUC(scores, 1)
    ef_list = CalcEnrichment(scores, 1, [0.005, 0.01, 0.05])
    return {
        "BEDROC": bedroc,
        "AUC": auc,
        "EF0.5": ef_list[0],
        "EF1": ef_list[1],
        "EF5": ef_list[2]
    }
84
+
85
+
86
+ # import torch
87
+ # torch.multiprocessing.set_start_method('spawn', force=True)
88
+ # def mycollator(input_batch):
89
+ # for data in input_batch:
90
+ # node, neighbors = data
91
+ # node["pocket_data"] = torch.tensor(node["pocket_data"]).cuda()
92
+ # node["lig_data"] = torch.tensor(node["lig_data"]).cuda()
93
+ # for neighbor in neighbors:
94
+ # neighbor["pocket_data"] = torch.tensor(node["pocket_data"]).cuda()
95
+ # neighbor["lig_data"] = torch.tensor(node["lig_data"]).cuda()
96
+ # return input_batch
License ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Attribution-NonCommercial 4.0 International
2
+
3
+ > *Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.*
4
+ >
5
+ > ### Using Creative Commons Public Licenses
6
+ >
7
+ > Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
8
+ >
9
+ > * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
10
+ >
11
+ > * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
12
+
13
+ ## Creative Commons Attribution-NonCommercial 4.0 International Public License
14
+
15
+ By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
16
+
17
+ ### Section 1 – Definitions.
18
+
19
+ a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
20
+
21
+ b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
22
+
23
+ c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
24
+
25
+ d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
26
+
27
+ e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
28
+
29
+ f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
30
+
31
+ g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
32
+
33
+ h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
34
+
35
+ i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
36
+
37
+ j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
38
+
39
+ k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
40
+
41
+ l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
42
+
43
+ ### Section 2 – Scope.
44
+
45
+ a. ___License grant.___
46
+
47
+ 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
48
+
49
+ A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
50
+
51
+ B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
52
+
53
+ 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
54
+
55
+ 3. __Term.__ The term of this Public License is specified in Section 6(a).
56
+
57
+ 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
58
+
59
+ 5. __Downstream recipients.__
60
+
61
+ A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
62
+
63
+ B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
64
+
65
+ 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
66
+
67
+ b. ___Other rights.___
68
+
69
+ 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
70
+
71
+ 2. Patent and trademark rights are not licensed under this Public License.
72
+
73
+ 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
74
+
75
+ ### Section 3 – License Conditions.
76
+
77
+ Your exercise of the Licensed Rights is expressly made subject to the following conditions.
78
+
79
+ a. ___Attribution.___
80
+
81
+ 1. If You Share the Licensed Material (including in modified form), You must:
82
+
83
+ A. retain the following if it is supplied by the Licensor with the Licensed Material:
84
+
85
+ i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
86
+
87
+ ii. a copyright notice;
88
+
89
+ iii. a notice that refers to this Public License;
90
+
91
+ iv. a notice that refers to the disclaimer of warranties;
92
+
93
+ v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
94
+
95
+ B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
96
+
97
+ C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
98
+
99
+ 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
100
+
101
+ 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
102
+
103
+ 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
104
+
105
+ ### Section 4 – Sui Generis Database Rights.
106
+
107
+ Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
108
+
109
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
110
+
111
+ b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
112
+
113
+ c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
114
+
115
+ For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
116
+
117
+ ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
118
+
119
+ a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
120
+
121
+ b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
122
+
123
+ c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
124
+
125
+ ### Section 6 – Term and Termination.
126
+
127
+ a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
128
+
129
+ b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
130
+
131
+ 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
132
+
133
+ 2. upon express reinstatement by the Licensor.
134
+
135
+ For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
136
+
137
+ c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
138
+
139
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
140
+
141
+ ### Section 7 – Other Terms and Conditions.
142
+
143
+ a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
144
+
145
+ b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
146
+
147
+ ### Section 8 – Interpretation.
148
+
149
+ a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
150
+
151
+ b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
152
+
153
+ c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
154
+
155
+ d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
156
+
157
+ > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
158
+ >
159
+ > Creative Commons may be contacted at creativecommons.org
README.md CHANGED
@@ -1,3 +1,206 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## General
2
+ This repository contains the code for **LigUnity**: **Hierarchical affinity landscape navigation through learning a shared pocket-ligand space.**
3
+
4
+ **We are excited to announce that our paper has been accepted by Patterns and is featured as the cover article for the October 2025 issue!**
5
+
6
+
7
+ [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green?style=flat-square)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE)
8
+ [![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red?style=flat-square)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/DATA_LICENSE)
9
+ [![DOI:10.1016/j.patter.2025.101371](http://img.shields.io/badge/DOI-10.1101/2025.02.17.638554-B31B1B.svg)](https://doi.org/10.1016/j.patter.2025.101371)
10
+ [![GitHub Link](https://img.shields.io/badge/GitHub-blue?style=flat-square&logo=github)](https://github.com/IDEA-XL/LigUnity)
11
+
12
+ <table>
13
+ <tr>
14
+ <td width="250px" valign="top">
15
+ <a href="https://www.cell.com/patterns/fulltext/S2666-3899(25)00219-3">
16
+ <img src="https://github.com/user-attachments/assets/5ab7f659-0b56-4cf1-8db0-7129d71ea9d5" alt="LigUnity Patterns Cover Image" width="230px" />
17
+ </a>
18
+ </td>
19
+ <td valign="top">
20
+ <p>
21
+ <strong>On the cover:</strong> This ocean symbolizes the human proteome—the complete set of proteins that carry out essential functions in our bodies. For medicine to work, it often needs to interact with a specific protein. For an estimated 90% of these proteins, however, they lack known small-molecule ligands with high activity. In the image, these proteins are represented as sailboats drifting in the dark.
22
+ </p>
23
+ <p>
24
+ At the center, stands a lighthouse symbolizing the AI method <strong>LigUnity</strong>. Its beam illuminates several sailboats, guiding them toward glowing buoys, which symbolize ligands with high activity found by LigUnity. The work by Feng et al. highlights the power of AI-driven computational methods to efficiently find active ligands and optimize their activity, opening up new therapeutic avenues for various diseases.
25
+ </p>
26
+ </td>
27
+ </tr>
28
+ </table>
29
+
30
+ ## Instruction on running our model
31
+
32
+ ### Virtual Screening
33
+ Colab demo for virtual screening with given protein pocket and candidate ligands.
34
+
35
+ https://colab.research.google.com/drive/1F0QSPjkKKLAfBexmIQotcs-jm87ohHeG?usp=sharing
36
+
37
+ ### Hit-to-lead optimization
38
+ **Direct inference**
39
+ Colab demo for code inference with given protein and unmeasured ligands.
40
+
41
+ https://colab.research.google.com/drive/11Fx6mO51rRkPvq71qupuUmscfBw8Dw5R?usp=sharing
42
+
43
+ **Few-shot fine-tuning**
44
+ Colab demo for few-shot fine-tuning with given protein, few measure ligands for fine-tuning and unmeasured ligands for testing.
45
+
46
+ https://colab.research.google.com/drive/1gf0HhgyqI4qBjUAUICCvDa-FnTaARmR_?usp=sharing
47
+
48
+ Please feel free to contact me by email if there is any problem with the code or paper: fengbin@idea.edu.cn.
49
+
50
+ ### Resource availability
51
+
52
+ The datasets for LigUnity were collected from ChEMBL version 34 and BindingDB version 2024m5. Our training dataset is available on figshare (https://doi.org/10.6084/m9.figshare.27966819). Our PocketAffDB with protein and pocket PDB structures is available on figshare (https://doi.org/10.6084/m9.figshare.29379161).
53
+
54
+ ## Abstract
55
+
56
+ Protein-ligand binding affinity plays an important role in drug discovery, especially during virtual screening and hit-to-lead optimization. Computational chemistry and machine learning methods have been developed to investigate these tasks. Despite the encouraging performance, virtual screening and hit-to-lead optimization are often studied separately by existing methods, partially because they are performed sequentially in the existing drug discovery pipeline, thereby overlooking their interdependency and complementarity. To address this problem, we propose LigUnity, a foundation model for protein-ligand binding prediction by jointly optimizing virtual screening and hit-to-lead optimization.
57
+ In particular, LigUnity learns coarse-grained active/inactive distinction for virtual screening, and fine-grained pocket-specific ligand preference for hit-to-lead optimization.
58
+ We demonstrate the effectiveness and versatility of LigUnity on eight benchmarks across virtual screening and hit-to-lead optimization. In virtual screening, LigUnity outperforms 24 competing methods with more than 50% improvement on the DUD-E and Dekois 2.0 benchmarks, and shows robust generalization to novel proteins. In hit-to-lead optimization, LigUnity achieves the best performance on split-by-time, split-by-scaffold, and split-by-unit settings, further demonstrating its potential as a cost-effective alternative to free energy perturbation (FEP) calculations. We further showcase how LigUnity can be employed in an active learning framework to efficiently identify active ligands for TYK2, a therapeutic target for autoimmune diseases, yielding over 40% improved prediction performance. Collectively, these comprehensive results establish LigUnity as a versatile foundation model for both virtual screening and hit-to-lead optimization, offering broad applicability across the drug discovery pipeline through accurate protein-ligand affinity predictions.
59
+
60
+
61
+
62
+ ## Reproduce results in our paper
63
+
64
+ ### Reproduce results on virtual screening benchmarks
65
+
66
+ Please first download checkpoints and processed dataset before running
67
+ - Download our processed Dekois 2.0 dataset from https://doi.org/10.6084/m9.figshare.27967422
68
+ - Download LIT-PCBA and DUD-E datasets from https://drive.google.com/drive/folders/1zW1MGpgunynFxTKXC2Q4RgWxZmg6CInV?usp=sharing
69
+ - Clone model checkpoint from https://huggingface.co/fengb/LigUnity_VS (test proteins in DUD-E, Dekois, and LIT-PCBA are removed from the training set)
70
+ - Clone dataset from https://figshare.com/articles/dataset/LigUnity_project_data/27966819 and unzip them all (you can ignore .lmdb file if you only want to reproduce test result).
71
+
72
+ ```
73
+ # run pocket/protein and ligand encoder model
74
+ path2weight="absolute path to the checkpoint of pocket_ranking"
75
+ CUDA_VISIBLE_DEVICES=0 bash test.sh ALL pocket_ranking ${path2weight} "./result/pocket_ranking"
76
+ CUDA_VISIBLE_DEVICES=0 bash test.sh BDB pocket_ranking ${path2weight} "./result/pocket_ranking"
77
+ CUDA_VISIBLE_DEVICES=0 bash test.sh PDB pocket_ranking ${path2weight} "./result/pocket_ranking"
78
+
79
+ path2weight="absolute path to the checkpoint of protein_ranking"
80
+ CUDA_VISIBLE_DEVICES=0 bash test.sh ALL protein_ranking ${path2weight} "./result/protein_ranking"
81
+ CUDA_VISIBLE_DEVICES=0 bash test.sh BDB protein_ranking ${path2weight} "./result/protein_ranking"
82
+ CUDA_VISIBLE_DEVICES=0 bash test.sh PDB protein_ranking ${path2weight} "./result/protein_ranking"
83
+
84
+ # train H-GNN model
85
+ cd ./HGNN
86
+ path2weight_HGNN="absolute path to the checkpoint of HGNN pocket"
87
+ python main.py --data_root ${path2data} --result_root "../result/pocket_ranking" --test_ckpt ${path2weight_HGNN}
88
+ path2weight_HGNN="absolute path to the checkpoint of HGNN protein"
89
+ python main.py --data_root ${path2data} --result_root "../result/protein_ranking" --test_ckpt ${path2weight_HGNN}
90
+
91
+ # get final prediction of our model
92
+ python ensemble_result.py DUDE PCBA DEKOIS
93
+ ```
94
+
95
+
96
+ ### Reproduce results on FEP benchmarks (zero-shot)
97
+
98
+ Please first download checkpoints before running
99
+ - Clone model checkpoint from https://huggingface.co/fengb/LigUnity_pocket_ranking and https://huggingface.co/fengb/LigUnity_protein_ranking (test ligands and assays in FEP benchmarks are removed from the training set)
100
+
101
+ ```
102
+ # run pocket/protein and ligand encoder model
103
+ for r in {1..6}; do
104
+ path2weight="path to checkpoint of pocket_ranking"
105
+ path2result="./result/pocket_ranking/FEP/repeat_${r}"
106
+ CUDA_VISIBLE_DEVICES=0 bash test.sh FEP pocket_ranking ${path2weight} ${path2result}
107
+
108
+ path2weight="path to checkpoint of protein_ranking"
109
+ path2result="./result/protein_ranking/FEP/repeat_${r}"
110
+ CUDA_VISIBLE_DEVICES=0 bash test.sh FEP protein_ranking ${path2weight} ${path2result}
111
+ done
112
+
113
+ # get final prediction of our model
114
+ python ensemble_result.py FEP
115
+ ```
116
+
117
+ ### Reproduce results on FEP benchmarks (few-shot)
118
+ ```
119
+ # use the same checkpoints as in zero-shot
120
+ # run few-shot fine-tuning
121
+ for r in {1..6}; do
122
+ path2weight="path to checkpoint of pocket_ranking"
123
+ path2result="./result/pocket_ranking/FEP_fewshot/repeat_${r}"
124
+ support_num=0.6
125
+ CUDA_VISIBLE_DEVICES=0 bash test_fewshot.sh FEP pocket_ranking ${support_num} ${path2weight} ${path2result}
126
+
127
+ path2weight="path to checkpoint of protein_ranking"
128
+ path2result="./result/protein_ranking/FEP_fewshot/repeat_${r}"
129
+ CUDA_VISIBLE_DEVICES=0 bash test_fewshot.sh FEP protein_ranking ${support_num} ${path2weight} ${path2result}
130
+ done
131
+
132
+ # get final prediction of our model
133
+ python ensemble_result_fewshot.py FEP_fewshot ${support_num}
134
+ ```
135
+
136
+ ### Reproduce results on active learning
137
+ to speed up the active learning process, you should modify the unicore code
138
+ 1. find the installed dir of unicore (root-to-unicore)
139
+ ```
140
+ python -c "import unicore; print('/'.join(unicore.__file__.split('/')[:-2]))"
141
+ ```
142
+
143
+ 2. goto root-to-unicore/unicore/options.py line 250, add following line
144
+ ```
145
+ group.add_argument('--validate-begin-epoch', type=int, default=0, metavar='N',
146
+ help='validate begin epoch')
147
+ ```
148
+
149
+ 3. goto root-to-unicore/unicore_cli/train.py line 303, add one line
150
+ ```
151
+ do_validate = (
152
+ (not end_of_epoch and do_save)
153
+ or (
154
+ end_of_epoch
155
+ and epoch_itr.epoch >= args.validate_begin_epoch # !!!! add this line
156
+ and epoch_itr.epoch % args.validate_interval == 0
157
+ and not args.no_epoch_checkpoints
158
+ )
159
+ or should_stop
160
+ or (
161
+ args.validate_interval_updates > 0
162
+ and num_updates > 0
163
+ and num_updates % args.validate_interval_updates == 0
164
+ )
165
+ ) and not args.disable_validation
166
+ ```
167
+
168
+ 4. run the active learning procedure
169
+ ```
170
+ # use the same checkpoints as in FEP experiments
171
+ path1="path to checkpoint of pocket_ranking"
172
+ path2="path to checkpoint of protein_ranking"
173
+ result1="./result/pocket_ranking/TYK2"
174
+ result2="./result/protein_ranking/TYK2"
175
+
176
+ # run active learning cycle for 5 iters with pure greedy strategy
177
+ bash ./active_learning_scripts/run_al.sh 5 0 path1 path2 result1 result2
178
+ ```
179
+ ## Citation
180
+
181
+ ```
182
+ @article{feng2025hierarchical,
183
+ title={Hierarchical affinity landscape navigation through learning a shared pocket-ligand space},
184
+ author={Feng, Bin and Liu, Zijing and Li, Hao and Yang, Mingjun and Zou, Junjie and Cao, He and Li, Yu and Zhang, Lei and Wang, Sheng},
185
+ journal={Patterns},
186
+ year={2025},
187
+ publisher={Elsevier}
188
+ }
189
+
190
+ @article{feng2024bioactivity,
191
+ title={A bioactivity foundation model using pairwise meta-learning},
192
+ author={Feng, Bin and Liu, Zequn and Huang, Nanlan and Xiao, Zhiping and Zhang, Haomiao and Mirzoyan, Srbuhi and Xu, Hanwen and Hao, Jiaran and Xu, Yinghui and Zhang, Ming and others},
193
+ journal={Nature Machine Intelligence},
194
+ volume={6},
195
+ number={8},
196
+ pages={962--974},
197
+ year={2024},
198
+ publisher={Nature Publishing Group UK London}
199
+ }
200
+ ```
201
+
202
+ ## Acknowledgments
203
+
204
+ This project was built based on Uni-Mol (https://github.com/deepmodeling/Uni-Mol)
205
+
206
+ Parts of our code reference the implementation from DrugCLIP (https://github.com/bowen-gao/DrugCLIP) by bowen-gao
active_learning_scripts/run_al.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch the two-model ensemble active-learning loop
# (pocket_ranking + protein_ranking) on the TYK2 case study.
#
# Usage: run_al.sh NUM_CYCLES BEGIN_GREEDY WEIGHT_PATH1 WEIGHT_PATH2 RESULT_PATH1 RESULT_PATH2
set -euo pipefail

# Fail fast with a usage message instead of passing empty paths downstream.
if [ "$#" -ne 6 ]; then
    echo "Usage: $0 NUM_CYCLES BEGIN_GREEDY WEIGHT_PATH1 WEIGHT_PATH2 RESULT_PATH1 RESULT_PATH2" >&2
    exit 1
fi

num_cycles=${1}
begin_greedy=${2}
weight_path1=${3}
weight_path2=${4}
result_path1=${5}
result_path2=${6}

# NOTE(review): this commit ships run_cycle_ensemble.py, whose argparse
# (--results_dir_1/_2, --arch_1/_2, --weight_path_1/_2, ...) matches the
# flags below; the original called run_cycle_ours.py, which is absent.
python ./active_learning_scripts/run_cycle_ensemble.py \
    --input_file ../PARank_data_curation/case_study/tyk2_fep_label.csv \
    --results_dir_1 "${result_path1}" \
    --results_dir_2 "${result_path2}" \
    --al_batch_size 100 \
    --num_cycles "${num_cycles}" \
    --arch_1 pocket_ranking \
    --arch_2 protein_ranking \
    --weight_path_1 "${weight_path1}" \
    --weight_path_2 "${weight_path2}" \
    --lr 0.0001 \
    --device 0 \
    --master_port 10071 \
    --base_seed 42 \
    --begin_greedy "${begin_greedy}"
active_learning_scripts/run_cycle_ensemble.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import subprocess
4
+ import os
5
+ from pathlib import Path
6
+ import random
7
+ import argparse
8
+ import json
9
+ import subprocess
10
+ from concurrent.futures import ThreadPoolExecutor, wait
11
+
12
def parse_arguments(argv=None):
    """Parse command-line options for the ensemble active-learning driver.

    Args:
        argv: Optional list of argument strings; ``None`` (the default)
            means ``sys.argv[1:]``. Accepting an explicit list keeps the
            CLI behavior unchanged while making the parser unit-testable.

    Returns:
        argparse.Namespace with all options documented below.
    """
    parser = argparse.ArgumentParser(description='Active Learning Cycle for Ligand Prediction')

    # Input/Output arguments
    parser.add_argument('--input_file', type=str, required=True,
                        help='Input CSV file containing ligand data (e.g., tyk2_fep.csv)')
    parser.add_argument('--results_dir_1', type=str, required=True,
                        help='Results directory for first model')
    parser.add_argument('--results_dir_2', type=str, required=True,
                        help='Results directory for second model')
    parser.add_argument('--al_batch_size', type=int, required=True,
                        help='Number of samples for each active learning batch')

    # Experiment configuration
    parser.add_argument('--num_repeats', type=int, default=5,
                        help='Number of repeated experiments (default: 5)')
    parser.add_argument('--num_cycles', type=int, required=True,
                        help='Number of active learning cycles')

    # Model configuration
    parser.add_argument('--arch_1', type=str, required=True,
                        help='First model architecture')
    parser.add_argument('--arch_2', type=str, required=True,
                        help='Second model architecture')
    parser.add_argument('--weight_path_1', type=str, required=True,
                        help='Path to first model pretrained weights')
    parser.add_argument('--weight_path_2', type=str, required=True,
                        help='Path to second model pretrained weights')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate (default: 0.001)')
    parser.add_argument('--master_port', type=int, default=29500,
                        help='Master port for distributed training (default: 29500)')
    parser.add_argument('--device', type=int, default=0,
                        help='Base device to run the models on (default: 0)')
    parser.add_argument('--begin_greedy', type=int, default=0,
                        help='iter of begin to be pure greedy, using half greedy before')

    # Random seed
    parser.add_argument('--base_seed', type=int, default=42,
                        help='Base random seed (default: 42)')

    return parser.parse_args(argv)
54
+
55
+
56
+ def _run(cmd):
57
+ import os
58
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
59
+ subprocess.run(cmd, check=True, cwd=project_root)
60
+
61
+
62
def run_model(arch_1, arch_2, weight_path_1, weight_path_2, results_path_1, results_path_2, result_file, lr,
              master_port, train_ligf, test_ligf, device):
    """Launch both ranking models in parallel via run_model.sh.

    The second model is started on the next CUDA device and the next master
    port so the two distributed training runs do not collide. Blocks until
    both subprocesses have finished.
    """
    def build_cmd(arch, weight_path, results_path, port, dev):
        # Positional argument order expected by active_learning_scripts/run_model.sh.
        return [
            "bash", "./active_learning_scripts/run_model.sh",
            arch,
            weight_path,
            results_path,
            result_file,
            str(lr),
            str(port),
            train_ligf,
            test_ligf,
            str(dev),
        ]

    cmd1 = build_cmd(arch_1, weight_path_1, results_path_1, master_port, device)
    cmd2 = build_cmd(arch_2, weight_path_2, results_path_2, master_port + 1, device + 1)

    with ThreadPoolExecutor(max_workers=2) as executor:
        pending = [executor.submit(_run, cmd=c) for c in (cmd1, cmd2)]
        wait(pending)
94
+
95
+
96
def read_predictions(results_path, result_file):
    """Read one model's JSONL results and average predictions per molecule.

    The first line of the file carries the SMILES list under
    ``["tyk2"]["smiles"]``; every following line carries one epoch's
    prediction vector under ``["tyk2"]["pred"]``.

    Returns:
        dict mapping SMILES -> mean prediction (float).
    """
    jsonl_path = os.path.join(results_path, result_file)
    with open(jsonl_path, 'r') as fh:
        header = json.loads(fh.readline().strip())
        smiles_list = header["tyk2"]["smiles"]
        per_epoch = [json.loads(row.strip())["tyk2"]["pred"] for row in fh]

    # Average across epochs (rows), one value per molecule (columns).
    mean_pred = np.mean(np.array(per_epoch), axis=0)
    return {smi: float(val) for smi, val in zip(smiles_list, mean_pred)}
120
+
121
def prepare_initial_split(input_file, results_dir_1, results_dir_2, al_batch_size, repeat_idx, cycle_idx, base_seed):
    """Create the cycle-0 train/test split and write it to both result dirs.

    A random sample of ``al_batch_size`` ligands becomes the initial training
    set; everything else is the test pool. The global RNG is seeded with
    ``base_seed + repeat_idx`` so each repeat draws a different but
    reproducible split. Both directories receive identical CSVs.

    Returns:
        (train_file_1, test_file_1, train_file_2, test_file_2)
    """
    df = pd.read_csv(input_file)

    # Reproducible per-repeat selection.
    random.seed(base_seed + repeat_idx)

    pool = list(range(len(df)))
    chosen = random.sample(pool, al_batch_size)
    held_out = [idx for idx in pool if idx not in chosen]

    train_df = df.iloc[chosen]
    test_df = df.iloc[held_out]

    written = []
    for results_dir in (results_dir_1, results_dir_2):
        train_path = os.path.join(results_dir, f"repeat_{repeat_idx}_cycle_{cycle_idx}_train.csv")
        test_path = os.path.join(results_dir, f"repeat_{repeat_idx}_cycle_{cycle_idx}_test.csv")
        os.makedirs(os.path.dirname(train_path), exist_ok=True)
        train_df.to_csv(train_path, index=False)
        test_df.to_csv(test_path, index=False)
        written.extend([train_path, test_path])

    return written[0], written[1], written[2], written[3]
155
+
156
+
157
def read_and_combine_predictions(results_path_1, results_path_2, result_file):
    """Average the per-epoch predictions of both models into one score per SMILES.

    Each JSONL file has the SMILES list on line 1 and one epoch's
    prediction vector on every following line. The SMILES ordering is
    taken from the first model's file.

    Returns:
        dict mapping SMILES -> averaged prediction (float).
    """
    def load(results_path, want_smiles):
        # Returns (smiles or None, per-epoch prediction matrix).
        path = os.path.join(results_path, result_file)
        with open(path, 'r') as fh:
            header = fh.readline()
            smiles = json.loads(header.strip())["tyk2"]["smiles"] if want_smiles else None
            preds = [json.loads(row.strip())["tyk2"]["pred"] for row in fh]
        return smiles, np.array(preds)

    smiles_list, preds_1 = load(results_path_1, True)
    _, preds_2 = load(results_path_2, False)

    # Mean over epochs for each model, then the two models' means averaged.
    combined = (np.mean(preds_1, axis=0) + np.mean(preds_2, axis=0)) / 2
    return {smi: float(val) for smi, val in zip(smiles_list, combined)}
194
+
195
+
196
def update_splits(results_dir_1, results_dir_2, predictions_1, predictions_2,
                  prev_train_file_1, prev_test_file_1,
                  prev_train_file_2, prev_test_file_2,
                  repeat_idx, cycle_idx, al_batch_size, begin_greedy):
    """Select the next active-learning batch from the ensemble's predictions.

    Averages the two models' per-SMILES scores, moves ``al_batch_size`` test
    compounds into the training set (pure greedy once
    ``cycle_idx >= begin_greedy``, otherwise half greedy / half random), and
    writes the new cycle's train/test CSVs into both results directories.

    Returns:
        (new_train_file_1, new_test_file_1, new_train_file_2, new_test_file_2)
    """
    # Read previous test files
    test_df_1 = pd.read_csv(prev_test_file_1)
    # NOTE(review): test_df_2 is loaded but never used — both result
    # directories receive splits derived from test_df_1 only. Confirm the
    # second read is intentional.
    test_df_2 = pd.read_csv(prev_test_file_2)

    # Attach each model's score, then ensemble by simple averaging.
    test_df_1['prediction_1'] = test_df_1['Smiles'].map(predictions_1)
    test_df_1['prediction_2'] = test_df_1['Smiles'].map(predictions_2)
    test_df_1['prediction'] = (test_df_1['prediction_1'] + test_df_1['prediction_2']) / 2

    # Sort by average predictions (high to low)
    test_df_sorted = test_df_1.sort_values('prediction', ascending=False)

    # Read previous train files
    train_df_1 = pd.read_csv(prev_train_file_1)
    # NOTE(review): train_df_2 is likewise read but unused below.
    train_df_2 = pd.read_csv(prev_train_file_2)

    # Create new file names for both directories
    new_train_file_1 = os.path.join(results_dir_1, f"repeat_{repeat_idx}_cycle_{cycle_idx}_train.csv")
    new_test_file_1 = os.path.join(results_dir_1, f"repeat_{repeat_idx}_cycle_{cycle_idx}_test.csv")
    new_train_file_2 = os.path.join(results_dir_2, f"repeat_{repeat_idx}_cycle_{cycle_idx}_train.csv")
    new_test_file_2 = os.path.join(results_dir_2, f"repeat_{repeat_idx}_cycle_{cycle_idx}_test.csv")

    # Create directories if they don't exist
    os.makedirs(os.path.dirname(new_train_file_1), exist_ok=True)
    os.makedirs(os.path.dirname(new_train_file_2), exist_ok=True)

    if cycle_idx >= begin_greedy:
        # Pure exploitation: take the top-scored compounds for training.
        new_train_compounds = test_df_sorted.head(al_batch_size)
        remaining_test_compounds = test_df_sorted.iloc[al_batch_size:]
    else:
        # Half-greedy: top half of the batch by score, the other half
        # sampled uniformly from the remainder (exploration).
        # NOTE(review): uses the module-level `random` state seeded in
        # prepare_initial_split — selection is reproducible only if the
        # call order is unchanged.
        new_train_compounds_tmp_1 = test_df_sorted.head(al_batch_size//2)
        remaining_test_compounds_tmp = test_df_sorted.iloc[al_batch_size//2:]
        all_indices = list(range(len(remaining_test_compounds_tmp)))

        train_indices = random.sample(all_indices, al_batch_size - al_batch_size//2)
        test_indices = [i for i in all_indices if i not in train_indices]
        remaining_test_compounds = remaining_test_compounds_tmp.iloc[test_indices]
        new_train_compounds_tmp_2 = remaining_test_compounds_tmp.iloc[train_indices]
        new_train_compounds = pd.concat([new_train_compounds_tmp_1, new_train_compounds_tmp_2])

    # Combine with previous training data
    combined_train_df = pd.concat([train_df_1, new_train_compounds])

    # Progress report: how many of the dataset's top-1/2/5% actives the
    # growing training set has recovered so far.
    # NOTE(review): the 100/200/500 denominators look hard-coded for the
    # TYK2 case-study size — confirm before reusing on other inputs.
    for _ in range(3):
        print("########################################")
    print("Cycling: ", cycle_idx)
    print("top_1p: {}/100".format(combined_train_df['top_1p'].sum()))
    print("top_2p: {}/200".format(combined_train_df['top_2p'].sum()))
    print("top_5p: {}/500".format(combined_train_df['top_5p'].sum()))

    # Save files for both models (same content, different directories)
    combined_train_df.to_csv(new_train_file_1, index=False)
    remaining_test_compounds.to_csv(new_test_file_1, index=False)
    combined_train_df.to_csv(new_train_file_2, index=False)
    remaining_test_compounds.to_csv(new_test_file_2, index=False)

    return (new_train_file_1, new_test_file_1,
            new_train_file_2, new_test_file_2)
260
+
261
+
262
def run_active_learning(args):
    """Drive the full ensemble active-learning experiment.

    For each repeat: build a fresh random initial split, then alternate
    parallel model training and greedy batch selection for
    ``args.num_cycles`` cycles. All intermediate CSVs and result files live
    under the two results directories, which are wiped at the start.

    Args:
        args: argparse.Namespace produced by parse_arguments().
    """
    import shutil

    # Recreate the results directories from scratch. shutil.rmtree replaces
    # the previous shell `rm -rf` (os.system), which was non-portable and
    # unsafe with special characters in paths.
    for results_dir in (args.results_dir_1, args.results_dir_2):
        shutil.rmtree(results_dir, ignore_errors=True)
        os.makedirs(results_dir, exist_ok=True)

    for repeat_idx in range(args.num_repeats):
        print(f"Starting repeat {repeat_idx}")

        # Initial random split for this repeat (cycle 0).
        train_file_1, test_file_1, train_file_2, test_file_2 = prepare_initial_split(
            args.input_file,
            args.results_dir_1,
            args.results_dir_2,
            args.al_batch_size,
            repeat_idx,
            0,  # first cycle
            args.base_seed
        )

        for cycle_idx in range(args.num_cycles):
            print(f"Running cycle {cycle_idx} for repeat {repeat_idx}")

            # Remove any stale result file left over from an aborted run.
            result_file = f"repeat_{repeat_idx}_cycle_{cycle_idx}_results.jsonl"
            for results_dir in (args.results_dir_1, args.results_dir_2):
                stale = os.path.join(results_dir, result_file)
                if os.path.exists(stale):
                    os.remove(stale)

            # Train/evaluate both models in parallel.
            run_model(
                arch_1=args.arch_1,
                arch_2=args.arch_2,
                weight_path_1=args.weight_path_1,
                weight_path_2=args.weight_path_2,
                results_path_1=args.results_dir_1,
                results_path_2=args.results_dir_2,
                result_file=result_file,
                lr=args.lr,
                master_port=args.master_port,
                train_ligf=train_file_1,
                test_ligf=test_file_1,
                device=args.device
            )

            # Every cycle except the last feeds its predictions into the
            # next cycle's train/test split.
            if cycle_idx < args.num_cycles - 1:
                predictions_1 = read_predictions(args.results_dir_1, result_file)
                predictions_2 = read_predictions(args.results_dir_2, result_file)

                train_file_1, test_file_1, train_file_2, test_file_2 = update_splits(
                    args.results_dir_1,
                    args.results_dir_2,
                    predictions_1,
                    predictions_2,
                    train_file_1,
                    test_file_1,
                    train_file_2,
                    test_file_2,
                    repeat_idx,
                    cycle_idx + 1,
                    args.al_batch_size,
                    args.begin_greedy
                )
330
+
331
+
332
if __name__ == "__main__":
    # CLI entry point: parse options and launch the experiment.
    run_active_learning(parse_arguments())
active_learning_scripts/run_cycle_one_model.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import subprocess
4
+ import os
5
+ from pathlib import Path
6
+ import random
7
+ import argparse
8
+ import json
9
+
10
+
11
def parse_arguments(argv=None):
    """Parse command-line options for the single-model active-learning driver.

    Args:
        argv: Optional list of argument strings; ``None`` (the default)
            means ``sys.argv[1:]``. Accepting an explicit list keeps the
            CLI behavior unchanged while making the parser unit-testable.

    Returns:
        argparse.Namespace with all options documented below.
    """
    parser = argparse.ArgumentParser(description='Active Learning Cycle for Ligand Prediction')

    # Input/Output arguments
    parser.add_argument('--input_file', type=str, required=True,
                        help='Input CSV file containing ligand data (e.g., tyk2_fep.csv)')
    parser.add_argument('--results_dir', type=str, required=True,
                        help='Base directory for storing all results')
    parser.add_argument('--al_batch_size', type=int, required=True,
                        help='Number of samples for each active learning batch')

    # Experiment configuration
    parser.add_argument('--num_repeats', type=int, default=5,
                        help='Number of repeated experiments (default: 5)')
    parser.add_argument('--num_cycles', type=int, required=True,
                        help='Number of active learning cycles')

    # Model configuration
    parser.add_argument('--arch', type=str, required=True,
                        help='Model architecture')
    parser.add_argument('--weight_path', type=str, required=True,
                        help='Path to pretrained model weights')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate (default: 0.001)')
    parser.add_argument('--master_port', type=int, default=29500,
                        help='Master port for distributed training (default: 29500)')
    parser.add_argument('--device', type=int, default=0,
                        help='Device to run the model on (default: cuda:0)')
    parser.add_argument('--begin_greedy', type=int, default=0,
                        help='iter of begin to be pure greedy, using half greedy before')

    # Random seed
    parser.add_argument('--base_seed', type=int, default=42,
                        help='Base random seed (default: 42)')

    return parser.parse_args(argv)
47
+
48
+
49
def run_model(arch, weight_path, results_path, result_file, lr, master_port, train_ligf, test_ligf, device):
    """Train/evaluate one model by invoking run_model.sh from the project root.

    Args:
        arch: model architecture name passed through to the shell script.
        weight_path: path to the pretrained checkpoint.
        results_path: directory receiving checkpoints/results.
        result_file: name of the per-epoch prediction JSONL file.
        lr: learning rate.
        master_port: torch.distributed master port.
        train_ligf / test_ligf: train/test ligand CSV paths.
        device: CUDA device index.

    Raises:
        subprocess.CalledProcessError: if the launched script fails.
    """
    # Fixed: removed the redundant function-local `import os` — os is
    # already imported at module level.
    # Project root = parent of this script's directory, so the bash
    # script's relative paths (./active_learning_scripts, ./unimol) resolve.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    cmd = [
        "bash", "./active_learning_scripts/run_model.sh",
        arch,
        weight_path,
        results_path,
        result_file,
        str(lr),
        str(master_port),
        train_ligf,
        test_ligf,
        str(device)
    ]
    subprocess.run(cmd, check=True, cwd=project_root)
65
+
66
+
67
def prepare_initial_split(input_file, results_dir, al_batch_size, repeat_idx, cycle_idx, base_seed):
    """Create the initial random train/test split for one repeat.

    Seeds the global RNG with ``base_seed + repeat_idx`` so each repeat
    draws a different but reproducible starting batch of ``al_batch_size``
    ligands; the remainder becomes the test pool.

    Returns:
        (train_csv_path, test_csv_path)
    """
    df = pd.read_csv(input_file)

    # Different seed for each repeat, reproducible across runs.
    random.seed(base_seed + repeat_idx)

    pool = list(range(len(df)))
    chosen = random.sample(pool, al_batch_size)
    held_out = [idx for idx in pool if idx not in chosen]

    train_path = os.path.join(results_dir, f"repeat_{repeat_idx}_cycle_{cycle_idx}_train.csv")
    test_path = os.path.join(results_dir, f"repeat_{repeat_idx}_cycle_{cycle_idx}_test.csv")
    os.makedirs(os.path.dirname(train_path), exist_ok=True)

    df.iloc[chosen].to_csv(train_path, index=False)
    df.iloc[held_out].to_csv(test_path, index=False)

    return train_path, test_path
95
+
96
+
97
def read_jsonl_predictions(results_path, result_file):
    """Average per-epoch predictions from a results JSONL file.

    Line 1 holds the SMILES list under ``["tyk2"]["smiles"]``; every
    subsequent line holds one epoch's prediction vector under
    ``["tyk2"]["pred"]``.

    Returns:
        dict mapping SMILES -> mean prediction (float).
    """
    jsonl_path = os.path.join(results_path, result_file)
    with open(jsonl_path, 'r') as fh:
        smiles_list = json.loads(fh.readline().strip())["tyk2"]["smiles"]
        epoch_preds = [json.loads(row.strip())["tyk2"]["pred"] for row in fh]

    # Mean over epochs (rows) -> one score per molecule (columns).
    averaged = np.mean(np.array(epoch_preds), axis=0)
    return {smi: float(val) for smi, val in zip(smiles_list, averaged)}
127
+
128
+
129
def update_splits(results_dir, results_path, result_file, prev_train_file, prev_test_file, repeat_idx, cycle_idx,
                  al_batch_size, begin_greedy):
    """Select the next active-learning batch from the model's predictions.

    Moves ``al_batch_size`` test compounds into the training set (pure
    greedy once ``cycle_idx >= begin_greedy``, otherwise half greedy /
    half random) and writes the new cycle's train/test CSVs.

    Returns:
        (new_train_file, new_test_file)
    """
    # Read predictions from jsonl file
    predictions = read_jsonl_predictions(results_path, result_file)

    # Read previous test file
    test_df = pd.read_csv(prev_test_file)

    # Attach the averaged per-SMILES prediction score.
    test_df['prediction'] = test_df['Smiles'].map(predictions)

    # Sort by predictions (high to low)
    test_df_sorted = test_df.sort_values('prediction', ascending=False)

    # Read previous train file
    train_df = pd.read_csv(prev_train_file)

    # Create new file names
    new_train_file = os.path.join(results_dir, f"repeat_{repeat_idx}_cycle_{cycle_idx}_train.csv")
    new_test_file = os.path.join(results_dir, f"repeat_{repeat_idx}_cycle_{cycle_idx}_test.csv")

    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(new_train_file), exist_ok=True)

    if cycle_idx >= begin_greedy:
        # Pure exploitation: take the top-scored compounds for training.
        new_train_compounds = test_df_sorted.head(al_batch_size)
        remaining_test_compounds = test_df_sorted.iloc[al_batch_size:]
    else:
        # Half-greedy: top half of the batch by score, the other half
        # sampled uniformly from the remainder (exploration).
        # NOTE(review): uses the module-level `random` state seeded in
        # prepare_initial_split — selection is reproducible only if the
        # call order is unchanged.
        new_train_compounds_tmp_1 = test_df_sorted.head(al_batch_size//2)
        remaining_test_compounds_tmp = test_df_sorted.iloc[al_batch_size//2:]
        all_indices = list(range(len(remaining_test_compounds_tmp)))

        train_indices = random.sample(all_indices, al_batch_size - al_batch_size//2)
        test_indices = [i for i in all_indices if i not in train_indices]
        remaining_test_compounds = remaining_test_compounds_tmp.iloc[test_indices]
        new_train_compounds_tmp_2 = remaining_test_compounds_tmp.iloc[train_indices]
        new_train_compounds = pd.concat([new_train_compounds_tmp_1, new_train_compounds_tmp_2])


    # Combine with previous training data
    combined_train_df = pd.concat([train_df, new_train_compounds])

    # Progress report: how many of the dataset's top-1/2/5% actives the
    # growing training set has recovered so far.
    # NOTE(review): the 100/200/500 denominators look hard-coded for the
    # TYK2 case-study size — confirm before reusing on other inputs.
    for _ in range(3):
        print("########################################")
    print("Cycling: ", cycle_idx)
    print("top_1p: {}/100".format(combined_train_df['top_1p'].sum()))
    print("top_2p: {}/200".format(combined_train_df['top_2p'].sum()))
    print("top_5p: {}/500".format(combined_train_df['top_5p'].sum()))

    # Save files
    combined_train_df.to_csv(new_train_file, index=False)
    remaining_test_compounds.to_csv(new_test_file, index=False)

    return new_train_file, new_test_file
185
+
186
+
187
def run_active_learning(args):
    """Drive the single-model active-learning experiment.

    For each repeat: build a fresh random initial split, then alternate
    model training and greedy batch selection for ``args.num_cycles``
    cycles. The results directory is wiped and recreated at the start.

    Args:
        args: argparse.Namespace produced by parse_arguments().
    """
    import shutil

    # Start from a clean results directory. shutil.rmtree replaces the
    # previous `os.system("rm -rf ...")`, which was non-portable and
    # unsafe with special characters in the path.
    shutil.rmtree(args.results_dir, ignore_errors=True)
    os.makedirs(args.results_dir, exist_ok=True)

    for repeat_idx in range(args.num_repeats):
        print(f"Starting repeat {repeat_idx}")

        # Initial random split for this repeat (cycle 0).
        train_file, test_file = prepare_initial_split(
            args.input_file,
            args.results_dir,
            args.al_batch_size,
            repeat_idx,
            0,  # first cycle
            args.base_seed
        )

        for cycle_idx in range(args.num_cycles):
            print(f"Running cycle {cycle_idx} for repeat {repeat_idx}")

            # Remove any stale result file left over from an aborted run.
            result_file = f"repeat_{repeat_idx}_cycle_{cycle_idx}_results.jsonl"
            stale = os.path.join(args.results_dir, result_file)
            if os.path.exists(stale):
                os.remove(stale)

            # Train/evaluate the model for this cycle.
            run_model(
                arch=args.arch,
                weight_path=args.weight_path,
                results_path=args.results_dir,
                result_file=result_file,
                lr=args.lr,
                master_port=args.master_port,
                train_ligf=train_file,
                test_ligf=test_file,
                device=args.device
            )

            # Every cycle except the last selects the next training batch.
            if cycle_idx < args.num_cycles - 1:
                train_file, test_file = update_splits(
                    args.results_dir,
                    args.results_dir,
                    result_file,
                    train_file,
                    test_file,
                    repeat_idx,
                    cycle_idx + 1,
                    args.al_batch_size,
                    args.begin_greedy
                )
242
+
243
+
244
if __name__ == "__main__":
    # CLI entry point: parse options and launch the experiment.
    run_active_learning(parse_arguments())
active_learning_scripts/run_model.sh ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Train/evaluate one ranking model for a single active-learning cycle.
# Invoked by run_cycle_*.py with nine positional arguments (documented below).

data_path="./test_datasets"

n_gpu=1

batch_size=1
batch_size_valid=1
epoch=20
update_freq=1

export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1

arch=${1}          # model architecture
weight_path=${2}   # path for pretrained model
results_path=${3}  # output directory for checkpoints/results
result_file=${4}   # jsonl file (inside results_path) for per-epoch predictions
lr=${5}            # learning rate
MASTER_PORT=${6}   # torch.distributed master port
train_ligf=${7}    # !! input path for training ligands file (.csv format)
test_ligf=${8}     # !! input path for test ligands file (.csv format)
device=${9}        # cuda device index

# Regression architectures train with MSE; ranking architectures use the
# softmax ranking loss.
if [[ "$arch" == "pocketregression" ]] || [[ "$arch" == "DTA" ]]; then
    loss="mseloss"
else
    loss="rank_softmax"
fi


CUDA_VISIBLE_DEVICES=${device} python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \
    --results-path $results_path \
    --num-workers 8 --ddp-backend=c10d \
    --task train_task --loss ${loss} --arch $arch \
    --max-pocket-atoms 256 \
    --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-8 --clip-norm 1.0 \
    --lr-scheduler polynomial_decay --lr $lr --max-epoch $epoch --batch-size $batch_size --batch-size-valid $batch_size_valid \
    --update-freq $update_freq --seed 1 \
    --log-interval 1 --log-format simple \
    --validate-interval 1 --validate-begin-epoch 15 \
    --best-checkpoint-metric valid_mean_r2 --patience 100 --all-gather-list-size 2048000 \
    --no-save --save-dir $results_path --tmp-save-dir $results_path \
    --find-unused-parameters \
    --maximize-best-checkpoint-metric \
    --valid-set TYK2 \
    --max-lignum 512 --test-max-lignum 10000 \
    --restore-model $weight_path --few-shot true \
    --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \
    --active-learning-resfile ${result_file} \
    --case-train-ligfile ${train_ligf} --case-test-ligfile ${test_ligf}
53
+
ensemble_result.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import json
4
+ import copy
5
+ import numpy as np
6
+ import scipy.stats as stats
7
+ import math
8
+ from rdkit.ML.Scoring.Scoring import CalcBEDROC, CalcAUC, CalcEnrichment
9
+
10
def cal_metrics(y_score, y_true):
    """Compute virtual-screening metrics for one target.

    Args:
        y_score: 1-D array-like of predicted scores (higher = more active).
        y_true: 1-D array-like of binary activity labels (1 = active).

    Returns:
        dict with BEDROC (alpha=80.5), AUROC and enrichment factors at
        0.5%, 1% and 5% of the ranked list.
    """
    # RDKit's scoring functions expect rows of [score, label], sorted by
    # score descending, with the label column index passed as 1.
    scores = np.expand_dims(y_score, axis=1)
    y_true = np.expand_dims(y_true, axis=1)
    scores = np.concatenate((scores, y_true), axis=1)
    # inverse sort scores based on first column
    scores = scores[scores[:, 0].argsort()[::-1]]
    bedroc = CalcBEDROC(scores, 1, 80.5)
    # Fixed: removed a dead loop that counted actives in the top 0.5% into
    # a `count` variable that was never used.
    auc = CalcAUC(scores, 1)
    ef_list = CalcEnrichment(scores, 1, [0.005, 0.01, 0.02, 0.05])

    return {
        "BEDROC": bedroc,
        "AUROC": auc,
        "EF0.5": ef_list[0],
        "EF1": ef_list[1],
        # ef_list[2] (EF at 2%) is computed but not reported, matching the
        # original output schema.
        "EF5": ef_list[3]
    }
34
+
35
def print_avg_metric(metric_dict, name):
    """Print (and return) the key-wise average of a {target: metrics} dict.

    Args:
        metric_dict: mapping target -> {metric_name: value}; all entries
            must share the same metric keys.
        name: label printed before the averaged metrics.

    Returns:
        The averaged metrics dict. Non-numeric fields (e.g. the "target"
        string emitted by get_metric/avg_metric) are carried over from the
        first entry instead of being summed — the original code raised a
        TypeError when it tried to divide the concatenated strings.
    """
    metric_lst = list(metric_dict.values())
    avg = copy.deepcopy(metric_lst[0])
    for m in metric_lst[1:]:
        for k in m:
            if isinstance(avg[k], (int, float)):
                avg[k] += m[k]

    for k in avg:
        if isinstance(avg[k], (int, float)):
            avg[k] = avg[k] / len(metric_lst)
    print(name, avg)
    return avg
45
+
46
def read_zeroshot_res(res_dir):
    """Load zero-shot results for every target directory under *res_dir*.

    Each target directory must contain saved_labels.npy; predictions come
    from saved_preds.npy when present, otherwise they are recomputed as the
    per-ligand maximum of the pocket-ligand embedding similarity matrix.

    Returns:
        dict mapping target name -> {"pred": array, "exp": array}.
    """
    res_dict = {}
    for target in sorted(os.listdir(res_dir)):
        labels = np.load(f"{res_dir}/{target}/saved_labels.npy")
        pred_file = f"{res_dir}/{target}/saved_preds.npy"
        if os.path.exists(pred_file):
            preds = np.load(pred_file)
        else:
            mol_reps = np.load(f"{res_dir}/{target}/saved_mols_embed.npy")
            pocket_reps = np.load(f"{res_dir}/{target}/saved_target_embed.npy")
            # Best score over all pockets for each ligand.
            preds = (pocket_reps @ mol_reps.T).max(axis=0)
        res_dict[target] = {"pred": preds, "exp": labels}
    return res_dict
63
+
64
def get_ensemble_res(res_list, begin=0, end=-1):
    """Average the "pred" arrays of ``res_list[begin:end]`` key by key.

    ``end == -1`` (the default) means "through the last element". The
    "exp" labels are taken unchanged from ``res_list[begin]``; inputs are
    not mutated.
    """
    if end == -1:
        end = len(res_list)
    merged = copy.deepcopy(res_list[begin])
    for other in res_list[begin + 1:end]:
        for key in merged:
            merged[key]["pred"] = np.array(merged[key]["pred"]) + np.array(other[key]["pred"])

    for key in merged:
        merged[key]["pred"] = np.array(merged[key]["pred"]) / (end - begin)

    return merged
76
+
77
def avg_metric(metric_lst_all):
    """Average pearsonr/spearmanr/r2 across repeats for each target.

    Args:
        metric_lst_all: one list per target, each containing the metric
            dicts of the individual repeats; every dict carries a
            "target" key naming the target.

    Returns:
        dict mapping target name -> averaged metric dict.
    """
    averaged = {}
    numeric_keys = ("pearsonr", "spearmanr", "r2")
    for per_target in metric_lst_all:
        acc = copy.deepcopy(per_target[0])
        for repeat in per_target[1:]:
            for k in numeric_keys:
                acc[k] += repeat[k]
        for k in numeric_keys:
            acc[k] = acc[k] / len(per_target)
        averaged[acc["target"]] = acc
    return averaged
88
+
89
def get_metric(res):
    """Per-target correlation metrics between predicted and experimental values.

    NaN correlations (e.g. from constant inputs) are mapped to 0, and r2 is
    the square of the non-negative Pearson correlation.

    Returns:
        dict mapping target -> {"pearsonr", "spearmanr", "r2", "target"}.
    """
    metric_dict = {}
    for target in sorted(res):
        pred = res[target]["pred"]
        exp = res[target]["exp"]
        rho = stats.spearmanr(exp, pred).statistic
        r = stats.pearsonr(exp, pred).statistic
        if math.isnan(r):
            r = 0
        if math.isnan(rho):
            rho = 0
        metric_dict[target] = {
            "pearsonr": r,
            "spearmanr": rho,
            "r2": max(r, 0) ** 2,
            "target": target,
        }
    return metric_dict
107
+
108
+
109
if __name__ == '__main__':
    # CLI: ensemble_result.py zeroshot <test_set...> | fewshot <test_set> <support_num>
    mode = sys.argv[1]
    if mode == "zeroshot":
        test_sets = sys.argv[2:]
        for test_set in test_sets:
            if test_set in ["DUDE", "PCBA", "DEKOIS"]:
                # Screening benchmarks: ensemble the two models' scores per
                # target and compute BEDROC/AUROC/EF metrics.
                metrics = {}
                target_id_list = sorted(list(os.listdir(f"./result/pocket_ranking/{test_set}")))
                for target_id in target_id_list:
                    lig_act = np.load(f"./result/pocket_ranking/{test_set}/{target_id}/saved_labels.npy")
                    score_1 = np.load(f"./result/pocket_ranking/{test_set}/{target_id}/GNN_res_epoch9.npy")
                    score_2 = np.load(f"./result/protein_ranking/{test_set}/{target_id}/GNN_res_epoch9.npy")

                    # Simple average of the two models' scores.
                    score = (score_1 + score_2)/2
                    metrics[target_id] = cal_metrics(score, lig_act)

                json.dump(metrics, open(f"./result/pocket_ranking/{test_set}_metrics.json", "w"))
                print_avg_metric(metrics, "Ours")
            elif test_set in ["FEP"]:
                # FEP benchmark: ensemble all repeats of both models, then
                # report correlation metrics.
                # NOTE(review): target_id_list is assigned but never used in
                # this branch.
                target_id_list = sorted(list(os.listdir(f"./result/pocket_ranking/{test_set}")))
                res_all_pocket, res_all_protein = [], []
                for repeat in range(1, 6):
                    res_pocket = read_zeroshot_res(f"./result/pocket_ranking/{test_set}/repeat_{repeat}")
                    res_protein = read_zeroshot_res(f"./result/protein_ranking/{test_set}/repeat_{repeat}")
                    res_all_pocket.append(res_pocket)
                    res_all_protein.append(res_protein)
                res_all_fusion = get_ensemble_res(res_all_pocket + res_all_protein)
                metrics = get_metric(res_all_fusion)
                json.dump(metrics, open(f"./result/pocket_ranking/{test_set}_metrics.json", "w"))
                print_avg_metric(metrics, "Ours")
    elif mode == "fewshot":
        test_set = sys.argv[2]
        support_num = sys.argv[3]
        # Only epochs [begin, end) of each run are ensembled.
        begin = 15
        end = 20
        metric_fusion_all = []
        for seed in range(1, 11):
            res_repeat_pocket = []
            res_repeat_seq = []

            if test_set in ["TIME", "OOD"]:
                res_file_pocket = f"./result/pocket_ranking/{test_set}/random_{seed}_sup{support_num}.jsonl"
                # NOTE(review): res_file_seq points at pocket_ranking — the
                # same file as res_file_pocket — in both fewshot branches.
                # By analogy with the zeroshot branch it likely should read
                # protein_ranking; as written the "fusion" just averages a
                # model with itself. Confirm before relying on these numbers.
                res_file_seq = f"./result/pocket_ranking/{test_set}/random_{seed}_sup{support_num}.jsonl"
                if not os.path.exists(res_file_pocket):
                    continue
                # First JSONL line is a header; per-epoch results follow.
                res_repeat_pocket = [json.loads(line) for line in open(res_file_pocket)][1:]
                res_repeat_seq = [json.loads(line) for line in open(res_file_seq)][1:]
            elif test_set in ["FEP_fewshot"]:
                for repeat in range(1, 6):
                    res_file_pocket = f"./result/pocket_ranking/{test_set}/repeat_{repeat}/random_{seed}_sup{support_num}.jsonl"
                    # NOTE(review): same pocket_ranking path issue as above.
                    res_file_seq = f"./result/pocket_ranking/{test_set}/repeat_{repeat}/random_{seed}_sup{support_num}.jsonl"
                    if not os.path.exists(res_file_pocket):
                        continue
                    res_pocket = [json.loads(line) for line in open(res_file_pocket)][1:]
                    res_seq = [json.loads(line) for line in open(res_file_seq)][1:]
                    # Ensemble over the selected epoch window per repeat.
                    res_pocket = get_ensemble_res(res_pocket, begin, end)
                    res_seq = get_ensemble_res(res_seq, begin, end)
                    res_repeat_pocket.append(res_pocket)
                    res_repeat_seq.append(res_seq)

            # Fuse all repeats of both result sets, then score this seed.
            res_repeat_fusion = get_ensemble_res(res_repeat_pocket + res_repeat_seq)
            metric_fusion_all.append(get_metric(res_repeat_fusion))
        # Transpose seed-major -> target-major before averaging over seeds.
        metric_fusion_all = avg_metric(list(map(list, zip(*metric_fusion_all))))
        json.dump(metric_fusion_all, open(f"./result/pocket_ranking/{test_set}_metrics.json", "w"))
        print_avg_metric(metric_fusion_all, "Ours")
py_scripts/__init__.py ADDED
File without changes
py_scripts/write_case_study.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gzip
3
+ import json
4
+ import multiprocessing as mp
5
+ import os
6
+ import pickle
7
+ import random
8
+
9
+ import lmdb
10
+ import numpy as np
11
+ import pandas as pd
12
+ import rdkit
13
+ import rdkit.Chem.AllChem as AllChem
14
+ import torch
15
+ import tqdm
16
+ from biopandas.mol2 import PandasMol2
17
+ from biopandas.pdb import PandasPdb
18
+ from rdkit import Chem, RDLogger
19
+ from rdkit.Chem.MolStandardize import rdMolStandardize
20
+
21
+ RDLogger.DisableLog('rdApp.*')
22
+
23
def gen_conformation(mol, num_conf=20, num_worker=8):
    """Embed up to `num_conf` 3D conformers for `mol` and MMFF-optimize them.

    Returns the molecule (hydrogens removed) with conformers attached, or
    None when embedding fails or produces no conformer.
    """
    try:
        mol = Chem.AddHs(mol)
        AllChem.EmbedMultipleConfs(mol, numConfs=num_conf, numThreads=num_worker, pruneRmsThresh=1, maxAttempts=10000, useRandomCoords=False)
        try:
            AllChem.MMFFOptimizeMoleculeConfs(mol, numThreads=num_worker)
        except Exception:
            # MMFF parameters are missing for some molecules; keep the
            # unoptimized embedded conformers in that case.
            pass
        mol = Chem.RemoveHs(mol)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate.
        print("cannot gen conf", Chem.MolToSmiles(mol))
        return None
    if mol.GetNumConformers() == 0:
        print("cannot gen conf", Chem.MolToSmiles(mol))
        return None
    return mol
39
+
40
def convert_2Dmol_to_data(smi, num_conf=1, num_worker=5):
    """Parse a SMILES string, embed 3D conformers, and return a record dict.

    Returns None when the SMILES cannot be parsed or no conformer can be
    generated; otherwise a dict with 'coords', 'atom_types', 'smi', 'mol'.
    """
    molecule = Chem.MolFromSmiles(smi)
    if molecule is None:
        return None
    molecule = gen_conformation(molecule, num_conf, num_worker)
    if molecule is None:
        return None
    conformer_coords = [
        np.array(molecule.GetConformer(conf_id).GetPositions())
        for conf_id in range(molecule.GetNumConformers())
    ]
    symbols = [atom.GetSymbol() for atom in molecule.GetAtoms()]
    return {'coords': conformer_coords, 'atom_types': symbols, 'smi': smi, 'mol': molecule}
51
+
52
def convert_3Dmol_to_data(mol):
    """Turn an RDKit molecule that already has conformers into a record dict.

    Returns None for a None input; otherwise a dict with 'coords',
    'atom_types', 'smi' (canonical SMILES) and 'mol'.
    """
    if mol is None:
        return None
    conformer_coords = [
        np.array(mol.GetConformer(conf_id).GetPositions())
        for conf_id in range(mol.GetNumConformers())
    ]
    symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
    return {'coords': conformer_coords, 'atom_types': symbols, 'smi': Chem.MolToSmiles(mol), 'mol': mol}
59
+
60
def read_pdb(path):
    """Load a PDB file into a dict of ATOM-record coordinates and names.

    'residue_name' is chain id + residue number (e.g. 'A42'), which uniquely
    identifies a residue; 'residue_type' is the 3-letter amino-acid code.
    """
    atom_df = PandasPdb().read_pdb(path).df['ATOM']
    residue_ids = atom_df['chain_id'] + atom_df['residue_number'].astype(str)
    return {
        'coord': np.array(atom_df[['x_coord', 'y_coord', 'z_coord']]),
        'atom_type': list(atom_df['atom_name']),
        'residue_name': list(residue_ids),
        'residue_type': list(atom_df['residue_name']),
    }
72
+
73
+
74
def read_sdf_gz_3d(path):
    """Read a gzipped SDF, fix common valence problems, strip Hs and uncharge.

    Returns the list of molecules that survived sanitization.
    """
    # Fix: the gzip handle was previously opened and never closed (leak);
    # close it deterministically once the supplier is exhausted.
    with gzip.open(path) as inf:
        with Chem.ForwardSDMolSupplier(inf, removeHs=False, sanitize=False) as gzsuppl:
            ms = [add_charges(x) for x in gzsuppl if x is not None]
    ms = [rdMolStandardize.Uncharger().uncharge(Chem.RemoveHs(m)) for m in ms if m is not None]
    return ms
80
+
81
def add_charges(m):
    """Best-effort sanitization: patch common valence errors, then sanitize.

    Fixes applied before sanitization: quaternary N -> +1, pentavalent C ->
    demote one double bond, trivalent O -> +1, tetravalent B -> -1.
    Returns the sanitized molecule, or None if sanitization still fails.
    """
    m.UpdatePropertyCache(strict=False)
    ps = Chem.DetectChemistryProblems(m)
    if not ps:
        Chem.SanitizeMol(m)
        return m
    for p in ps:
        if p.GetType() == 'AtomValenceException':
            at = m.GetAtomWithIdx(p.GetAtomIdx())
            if at.GetAtomicNum() == 7 and at.GetFormalCharge() == 0 and at.GetExplicitValence() == 4:
                at.SetFormalCharge(1)
            if at.GetAtomicNum() == 6 and at.GetExplicitValence() == 5:
                # Demote one double bond to single to bring carbon back to
                # a legal valence.
                for b in at.GetBonds():
                    if b.GetBondType() == Chem.rdchem.BondType.DOUBLE:
                        b.SetBondType(Chem.rdchem.BondType.SINGLE)
                        break
            if at.GetAtomicNum() == 8 and at.GetFormalCharge() == 0 and at.GetExplicitValence() == 3:
                at.SetFormalCharge(1)
            if at.GetAtomicNum() == 5 and at.GetFormalCharge() == 0 and at.GetExplicitValence() == 4:
                at.SetFormalCharge(-1)
    try:
        Chem.SanitizeMol(m)
    except Exception:
        # Narrowed from a bare `except:`; this is a deliberate best-effort
        # path — molecules that still fail sanitization are dropped.
        return None
    return m
107
+
108
def get_different_raid(protein, ligand, raid=6):
    """Return the set of protein residue names with any atom within `raid`
    Angstroms of any ligand atom.

    `protein` must provide 'coord' (N, 3) and 'residue_name' (length N);
    `ligand` must provide 'coord' (M, 3).  ("raid" keeps the original
    parameter spelling; it is the contact radius, strict `<` comparison.)
    """
    protein_coord = np.asarray(protein['coord'], dtype=float)
    ligand_coord = np.asarray(ligand['coord'], dtype=float)
    protein_residue_name = protein['residue_name']
    # Guard empty inputs — broadcasting over a 0-length axis would be fine,
    # but reshaping a truly empty array to (0, 3) is not guaranteed.
    if len(protein_coord) == 0 or len(ligand_coord) == 0:
        return set()
    # Vectorized pairwise distances replace the original O(N*M) Python
    # double loop with per-pair np.linalg.norm calls.
    dist = np.linalg.norm(protein_coord[:, None, :] - ligand_coord[None, :, :], axis=-1)
    near = (dist < raid).any(axis=1)
    return {protein_residue_name[i] for i in np.nonzero(near)[0]}
118
+
119
def read_mol2_ligand(path):
    """Load a mol2 ligand: coordinates, atom names, and the RDKit molecule."""
    mol2 = PandasMol2().read_mol2(path)
    return {
        'coord': np.array(mol2.df[['x', 'y', 'z']]),
        'atom_type': list(mol2.df['atom_name']),
        'mol': Chem.MolFromMol2File(path),
    }
125
+
126
def read_smi_mol(path):
    """Read a .smi-style file (SMILES is the first whitespace-separated token
    of each line) and parse each SMILES with RDKit.

    Blank lines are skipped (previously they produced an empty-string SMILES).
    Unparsable SMILES still yield None entries, as before, so downstream
    `is not None` filters keep working.
    """
    with open(path, 'r') as f:
        # str.split() (no argument) also strips the trailing newline that
        # `split(' ')` left attached when a line had no name column.
        smis = [line.split()[0] for line in f if line.strip()]
    return [Chem.MolFromSmiles(smi) for smi in smis]
132
+
133
def parser(protein_path, mol_path, ligand_path, activity, pocket_index, raid=6):
    """Pair every ligand in `mol_path` with the pocket of `protein_path`.

    The pocket is the set of protein atoms whose residue lies within `raid`
    Angstroms of the crystal ligand in `ligand_path` (.mol2).  Returns one
    record dict per ligand for which 3D conformers could be generated.
    """
    protein = read_pdb(protein_path)
    data_mols = read_smi_mol(mol_path)

    ligand = read_mol2_ligand(ligand_path)
    pocket_residue = get_different_raid(protein, ligand, raid=raid)
    pocket_atom_idx = [i for i, r in enumerate(protein['residue_name']) if r in pocket_residue]
    pocket_atom_type = [protein['atom_type'][i] for i in pocket_atom_idx]
    pocket_coord = [protein['coord'][i] for i in pocket_atom_idx]
    # Pocket name is the parent directory of the protein file.
    pocket_name = protein_path.split('/')[-2]
    # Fix: convert_2Dmol_to_data expects a SMILES string (it calls
    # Chem.MolFromSmiles), but read_smi_mol returns Mol objects — passing
    # them directly made every worker raise.  Convert back to SMILES first.
    smis = [Chem.MolToSmiles(m) for m in data_mols if m is not None]
    # Close worker processes deterministically instead of leaking the pool.
    with mp.Pool(32) as pool:
        mols = list(pool.imap_unordered(convert_2Dmol_to_data, smis))
    mols = [m for m in mols if m is not None]

    return [{'atoms': m['atom_types'],
             'coordinates': m['coords'],
             'smi': m['smi'],
             'mol': ligand,
             'pocket_name': pocket_name,
             'pocket_index': pocket_index,
             'activity': activity,
             "pocket_atom_type": pocket_atom_type,
             "pocket_coord": pocket_coord} for m in mols]
159
+
160
def mol_parser(ligand_smis):
    """Convert SMILES strings to conformer records labeled as active (1).

    SMILES that cannot be parsed or embedded are silently dropped.
    """
    # Close worker processes deterministically instead of leaking the pool.
    with mp.Pool(16) as pool:
        mols = list(pool.imap_unordered(convert_2Dmol_to_data, tqdm.tqdm(ligand_smis)))
    mols = [m for m in mols if m is not None]
    return [{'atoms': m['atom_types'],
             'coordinates': m['coords'],
             'smi': m['smi'],
             'mol': m['mol'],
             'label': 1,
             } for m in mols]
170
+
171
def pocket_parser(protein_path, ligand_path, pocket_index, pocket_name, raid=6):
    """Extract the binding pocket of a protein around its crystal ligand.

    Selects every protein atom whose residue has at least one atom within
    `raid` Angstroms of the ligand, and returns the pocket record dict.
    """
    protein = read_pdb(protein_path)
    crystal_ligand = read_mol2_ligand(ligand_path)
    contact_residues = get_different_raid(protein, crystal_ligand, raid=raid)
    selected = [i for i, name in enumerate(protein['residue_name']) if name in contact_residues]
    return {
        'pocket': pocket_name,
        'pocket_index': pocket_index,
        "pocket_atoms": [protein['atom_type'][i] for i in selected],
        "pocket_coordinates": [protein['coord'][i] for i in selected],
        "pocket_residue_type": [protein['residue_type'][i] for i in selected],
        "pocket_residue_name": [protein['residue_name'][i] for i in selected],
    }
186
+
187
def write_lmdb(data, lmdb_path):
    """Write an iterable of picklable records to a fresh LMDB file.

    Keys are sequential integers as ASCII bytes; any existing file at
    `lmdb_path` is replaced.
    """
    # Fix: replace `os.system(f"rm {path}")` — non-portable and unsafe for
    # paths containing spaces/shell metacharacters — with os.remove.
    if os.path.exists(lmdb_path):
        os.remove(lmdb_path)
    env = lmdb.open(lmdb_path, subdir=False, readonly=False, lock=False, readahead=False, meminit=False, map_size=1099511627776)
    try:
        with env.begin(write=True) as txn:
            for num, d in enumerate(data):
                txn.put(str(num).encode('ascii'), pickle.dumps(d))
    finally:
        # Close the environment so the file handle is not leaked.
        env.close()
197
+
198
import sys

if __name__ == '__main__':
    # Command-line entry point with two sub-modes:
    #   mol    <lig_json> <lig_lmdb>                 build a ligand LMDB from a
    #                                                JSON list of SMILES
    #   pocket <prot_pdb> <crystal_mol2> <out_lmdb>  build a one-pocket LMDB
    mode = sys.argv[1]

    if mode == 'mol':
        smiles_path = sys.argv[2]
        output_path = sys.argv[3]

        # Read the ligand SMILES and de-duplicate them.
        unique_smis = list(set(json.load(open(smiles_path))))
        print("number of ligands", len(unique_smis))
        records = []
        records.extend(mol_parser(unique_smis))

        # Write the ligand LMDB.
        write_lmdb(records, output_path)
    elif mode == 'pocket':
        protein_path = sys.argv[2]
        crystal_ligand_path = sys.argv[3]  # must be a .mol2 file
        output_path = sys.argv[4]

        # Write the single-pocket LMDB.
        pocket_record = pocket_parser(protein_path, crystal_ligand_path, 1, "demo")
        write_lmdb([pocket_record], output_path)
225
+
226
+
227
+
test.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size=256
2
+
3
+ TASK=${1}
4
+ arch=${2}
5
+ weight_path=${3}
6
+ results_path=${4}
7
+ echo "writing to ${results_path}"
8
+
9
+ mkdir -p $results_path
10
+ python ./unimol/test.py "./test_datasets" --user-dir ./unimol --valid-subset test \
11
+ --results-path $results_path \
12
+ --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \
13
+ --task test_task --loss rank_softmax --arch $arch \
14
+ --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --seed 1 \
15
+ --path $weight_path \
16
+ --log-interval 100 --log-format simple \
17
+ --max-pocket-atoms 511 \
18
+ --test-task $TASK
test_fewshot.sh ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_path="./test_datasets"
2
+
3
+ TASK=${1}
4
+ arch=${2}
5
+ sup_num=${3}
6
+ weight_path=${4}
7
+ results_path=${5}
8
+
9
+ n_gpu=1
10
+ batch_size=8
11
+ batch_size_valid=16
12
+ epoch=10
13
+ update_freq=1
14
+ lr=1e-4
15
+ MASTER_PORT=10092
16
+ export NCCL_ASYNC_ERROR_HANDLING=1
17
+ export OMP_NUM_THREADS=1
18
+ seed=1
19
+
20
+ torchrun --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \
21
+ --results-path $results_path \
22
+ --num-workers 8 --ddp-backend=c10d \
23
+ --task train_task --loss rank_softmax --arch $arch \
24
+ --max-pocket-atoms 256 \
25
+ --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-8 --clip-norm 1.0 \
26
+ --lr-scheduler polynomial_decay --lr $lr --max-epoch $epoch --batch-size $batch_size --batch-size-valid $batch_size_valid \
27
+ --update-freq $update_freq --seed $seed \
28
+ --log-interval 1 --log-format simple \
29
+ --validate-interval 1 \
30
+ --best-checkpoint-metric valid_mean_r2 --patience 100 --all-gather-list-size 2048000 \
31
+ --no-save --save-dir $results_path --tmp-save-dir $results_path \
32
+ --find-unused-parameters \
33
+ --maximize-best-checkpoint-metric \
34
+ --split-method random --valid-set $TASK \
35
+ --max-lignum 512 \
36
+ --sup-num $sup_num \
37
+ --restore-model $weight_path --few-shot true \
38
+ --fp16 --fp16-init-scale 4 --fp16-scale-window 256
test_fewshot_demo.sh ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_path="./vocab"
2
+
3
+ n_gpu=1
4
+ batch_size=1
5
+ batch_size_valid=1
6
+ epoch=20
7
+ update_freq=1
8
+ lr=1e-4
9
+ MASTER_PORT=10092
10
+
11
+ export NCCL_ASYNC_ERROR_HANDLING=1
12
+ export OMP_NUM_THREADS=1
13
+
14
+ arch=${1}
15
+ weight_path=${2}
16
+ results_path=${3}
17
+ lig_file=${4}
18
+ prot_file=${5}
19
+ split_file=${6}
20
+
21
+ sup_num=16
22
+ seed=1
23
+
24
+ torchrun --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \
25
+ --results-path $results_path \
26
+ --num-workers 8 --ddp-backend=c10d \
27
+ --task train_task --loss rank_softmax --arch $arch \
28
+ --max-pocket-atoms 256 \
29
+ --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-8 --clip-norm 1.0 \
30
+ --lr-scheduler polynomial_decay --lr $lr --max-epoch $epoch --batch-size $batch_size --batch-size-valid $batch_size_valid \
31
+ --update-freq $update_freq --seed $seed \
32
+ --log-interval 1 --log-format simple \
33
+ --validate-interval 1 \
34
+ --best-checkpoint-metric valid_mean_r2 --patience 100 --all-gather-list-size 2048000 \
35
+ --no-save --save-dir ./tmp --tmp-save-dir ./tmp \
36
+ --find-unused-parameters \
37
+ --maximize-best-checkpoint-metric \
38
+ --split-method random --valid-set DEMO \
39
+ --max-lignum 512 \
40
+ --sup-num $sup_num \
41
+ --restore-model $weight_path --few-shot true \
42
+ --demo-lig-file $lig_file --demo-prot-file $prot_file --demo-split-file $split_file \
43
+ --fp16 --fp16-init-scale 4 --fp16-scale-window 256
test_zeroshot_demo.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size=128
2
+
3
+ lig_file=${1}
4
+ prot_file=${2}
5
+ uniprot=${3}
6
+ arch=${4}
7
+ weight_path=${5}
8
+ results_path=${6}
9
+ echo "writing to ${results_path}"
10
+
11
+ mkdir -p $results_path
12
+ python ./unimol/test.py "./vocab" --user-dir ./unimol --valid-subset test \
13
+ --results-path $results_path \
14
+ --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \
15
+ --task test_task --loss rank_softmax --arch $arch \
16
+ --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --seed 1 \
17
+ --path $weight_path \
18
+ --log-interval 100 --log-format simple \
19
+ --max-pocket-atoms 511 --demo-lig-file $lig_file --demo-prot-file $prot_file --demo-uniprot $uniprot \
20
+ --test-task DEMO
train.sh ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Train the four pocket-screening model variants sequentially.  The original
# script repeated the same ~30-line launch block four times; the shared part
# now lives in run_training, and each variant only states what differs.
data_path="./data"
save_root="./save"

n_gpu=2
MASTER_PORT=10062
finetune_mol_model="./pretrain/mol_pre_no_h_220816.pt"     # unimol pretrained mol model
finetune_pocket_model="./pretrain/pocket_pre_220816.pt"    # unimol pretrained pocket model

batch_size=24
batch_size_valid=32
epoch=50
dropout=0.0
warmup=0.06
update_freq=1
dist_threshold=8.0
recycling=3
lr=1e-4

export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1

# Fix: training logs are redirected into ${save_root}/train_log, but that
# directory was never created, so every redirection failed.
mkdir -p "${save_root}/train_log"

# run_training <save_name> [extra unicore-train flags...]
# Creates the per-variant save directories and launches distributed training,
# logging to ${save_root}/train_log/train_log_<save_name>.txt.
run_training() {
    local save_name=$1
    shift
    local save_dir="${save_root}/${save_name}/savedir_screen"
    local tmp_save_dir="${save_root}/${save_name}/tmp_save_dir_screen"
    local tsb_dir="${save_root}/${save_name}/tsb_dir_screen"
    mkdir -p "${save_dir}"
    CUDA_VISIBLE_DEVICES="0,1" python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \
           --num-workers 8 --ddp-backend=c10d \
           --task train_task --loss rank_softmax --arch pocketscreen \
           --max-pocket-atoms 256 \
           --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-8 --clip-norm 1.0 \
           --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $batch_size --batch-size-valid $batch_size_valid \
           --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --update-freq $update_freq --seed 1 \
           --tensorboard-logdir $tsb_dir \
           --log-interval 100 --log-format simple \
           --validate-interval 1 \
           --best-checkpoint-metric valid_bedroc --patience 2000 --all-gather-list-size 2048000 \
           --save-dir $save_dir --tmp-save-dir $tmp_save_dir --keep-best-checkpoints 8 --keep-last-epochs 10 \
           --find-unused-parameters \
           --maximize-best-checkpoint-metric \
           --finetune-pocket-model $finetune_pocket_model \
           --finetune-mol-model $finetune_mol_model \
           --valid-set CASF \
           --max-lignum 16 \
           "$@" > "${save_root}/train_log/train_log_${save_name}.txt"
}

# Full model.
run_training "screen_pocket" --protein-similarity-thres 1.0
# Ablation: ranking loss disabled.
run_training "screen_pocket_norank" --protein-similarity-thres 1.0 --rank-weight 0.0
# Ablations: exclude training proteins similar to the validation set.
run_training "screen_pocket_no_similar_protein0.8" --protein-similarity-thres 0.8
run_training "screen_pocket_no_similar_protein" --protein-similarity-thres 0.4
unimol/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import importlib
2
+ import unimol.tasks
3
+ import unimol.data
4
+ import unimol.models
5
+ import unimol.losses
6
+ import unimol.utils
unimol/data/__init__.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .key_dataset import KeyDataset, LengthDataset
2
+ from .normalize_dataset import (
3
+ NormalizeDataset,
4
+ NormalizeDockingPoseDataset,
5
+ )
6
+ from .remove_hydrogen_dataset import (
7
+ RemoveHydrogenDataset,
8
+ RemoveHydrogenResiduePocketDataset,
9
+ RemoveHydrogenPocketDataset,
10
+ )
11
+ from .tta_dataset import (
12
+ TTADataset,
13
+ TTADecoderDataset,
14
+ TTADockingPoseDataset,
15
+ )
16
+ from .cropping_dataset import (
17
+ CroppingDataset,
18
+ CroppingPocketDataset,
19
+ CroppingResiduePocketDataset,
20
+ CroppingPocketDockingPoseDataset,
21
+ CroppingPocketDockingPoseTestDataset,
22
+ )
23
+ from .atom_type_dataset import AtomTypeDataset
24
+ from .add_2d_conformer_dataset import Add2DConformerDataset
25
+ from .distance_dataset import (
26
+ DistanceDataset,
27
+ EdgeTypeDataset,
28
+ CrossDistanceDataset,
29
+ CrossEdgeTypeDataset
30
+ )
31
+ from .conformer_sample_dataset import (
32
+ ConformerSampleDataset,
33
+ ConformerSampleDecoderDataset,
34
+ ConformerSamplePocketDataset,
35
+ ConformerSamplePocketFinetuneDataset,
36
+ ConformerSampleConfGDataset,
37
+ ConformerSampleConfGV2Dataset,
38
+ ConformerSampleDockingPoseDataset,
39
+ )
40
+ from .mask_points_dataset import MaskPointsDataset, MaskPointsPocketDataset
41
+ from .coord_pad_dataset import RightPadDatasetCoord, RightPadDatasetCross2D
42
+ from .from_str_dataset import FromStrLabelDataset
43
+ from .lmdb_dataset import LMDBDataset
44
+ from .prepend_and_append_2d_dataset import PrependAndAppend2DDataset
45
+ from .affinity_dataset import AffinityDataset, AffinityTestDataset, AffinityValidDataset, AffinityMolDataset, AffinityPocketDataset, AffinityHNSDataset, AffinityAugDataset
46
+ from .pocket2mol_dataset import FragmentConformationDataset
47
+ from .vae_binding_dataset import VAEBindingDataset, VAEBindingTestDataset, VAEGenerationTestDataset
48
+ from .resampling_dataset import ResamplingDataset
49
+ from .pair_dataset import PairDataset
50
+ __all__ = []
unimol/data/add_2d_conformer_dataset.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) DP Technology.
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ import numpy as np
6
+ from functools import lru_cache
7
+ from unicore.data import BaseWrapperDataset
8
+ from rdkit import Chem
9
+ from rdkit.Chem import AllChem
10
+
11
+
12
class Add2DConformerDataset(BaseWrapperDataset):
    """Wrapper dataset that appends an RDKit 2D-layout conformer to each
    item's conformer list, so downstream samplers can draw the 2D pose in
    addition to the stored 3D ones.

    `smi`, `atoms` and `coordinates` are the keys under which the wrapped
    dataset stores the SMILES string, atom symbols and list of conformers.
    """

    def __init__(self, dataset, smi, atoms, coordinates):
        self.dataset = dataset
        self.smi = smi                  # key of the SMILES string
        self.atoms = atoms              # key of the atom-symbol list
        self.coordinates = coordinates  # key of the conformer list
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        # NOTE(review): lru_cache on a method keeps `self` alive for the
        # cache lifetime; this mirrors the pattern used by the other
        # dataset wrappers in this package.
        atoms = np.array(self.dataset[index][self.atoms])
        assert len(atoms) > 0
        smi = self.dataset[index][self.smi]
        coordinates_2d = smi2_2Dcoords(smi)
        coordinates = self.dataset[index][self.coordinates]
        # Appends to the underlying item's conformer list in place — the
        # wrapped dataset's entry is mutated, not copied.
        coordinates.append(coordinates_2d)
        return {"smi": smi, "atoms": atoms, "coordinates": coordinates}

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
36
+
37
+
38
def smi2_2Dcoords(smi):
    """Compute 2D layout coordinates for a SMILES string (Hs added first).

    Returns a float32 (n_atoms, 3) array from RDKit's 2D coordinate
    generator.  Raises AssertionError if the atom/coordinate counts
    disagree.
    """
    mol = Chem.MolFromSmiles(smi)
    mol = AllChem.AddHs(mol)
    AllChem.Compute2DCoords(mol)
    coordinates = mol.GetConformer().GetPositions().astype(np.float32)
    # Bug fix: the original comparison was a bare expression with no effect —
    # the `assert` keyword was missing, so the check never ran.
    assert len(mol.GetAtoms()) == len(
        coordinates
    ), "2D coordinates shape is not align with {}".format(smi)
    return coordinates
unimol/data/affinity_dataset.py ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) DP Technology.
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+
6
+ from functools import lru_cache
7
+
8
+ import numpy as np
9
+ from unicore.data import BaseWrapperDataset
10
+ import pickle
11
+ from . import data_utils
12
+
13
+
14
class AffinityDataset(BaseWrapperDataset):
    """Wraps a pocket/ligand dataset and yields one item per complex with a
    deterministically sampled ligand conformer, pocket atoms/coordinates and
    the affinity label.

    The `atoms`/`coordinates`/`pocket_atoms`/`pocket_coordinates`/`affinity`
    arguments are the keys under which the wrapped dataset stores each field.
    When training, the conformer choice varies per epoch; otherwise it is
    fixed so validation is reproducible.
    """

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        pocket_atoms,
        pocket_coordinates,
        affinity,
        is_train=False,
        pocket="pocket"
    ):
        self.dataset = dataset
        self.seed = seed  # base seed for deterministic conformer sampling
        self.atoms = atoms
        self.coordinates = coordinates
        self.pocket_atoms = pocket_atoms
        self.pocket_coordinates = pocket_coordinates
        self.affinity = affinity
        self.is_train = is_train
        self.pocket=pocket
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    def pocket_atom(self, atom):
        # PDB atom names like '1HB' start with a digit; in that case the
        # element letter is the second character, otherwise the first.
        if atom[0] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
            return atom[1]
        else:
            return atom[0]

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        atoms = np.array(self.dataset[index][self.atoms])
        ori_mol_length = len(atoms)
        # Number of stored conformers for this ligand.
        size = len(self.dataset[index][self.coordinates])
        if self.is_train:
            # Per-epoch sampling: a different conformer may be drawn each epoch.
            with data_utils.numpy_seed(self.seed, epoch, index):
                sample_idx = np.random.randint(size)
        else:
            # Fixed epoch value (1) -> the same conformer on every pass.
            with data_utils.numpy_seed(self.seed, 1, index):
                sample_idx = np.random.randint(size)
        coordinates = self.dataset[index][self.coordinates][sample_idx]
        pocket_atoms = np.array(
            [self.pocket_atom(item) for item in self.dataset[index][self.pocket_atoms]]
        )
        ori_pocket_length = len(pocket_atoms)
        pocket_coordinates = np.stack(self.dataset[index][self.pocket_coordinates])

        smi = self.dataset[index]["smi"]
        pocket = self.dataset[index][self.pocket]
        if self.affinity in self.dataset[index]:
            affinity = float(self.dataset[index][self.affinity])
        else:
            # Items without an affinity entry get 1 — presumably a
            # placeholder label for unlabeled/screening data; confirm with
            # the loss implementation.
            affinity = 1
        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "holo_coordinates": coordinates.astype(np.float32),#placeholder
            "pocket_atoms": pocket_atoms,
            "pocket_coordinates": pocket_coordinates.astype(np.float32),
            "holo_pocket_coordinates": pocket_coordinates.astype(np.float32),#placeholder
            "smi": smi,
            "pocket": pocket,
            "affinity": affinity,
            "ori_mol_length": ori_mol_length,
            "ori_pocket_length": ori_pocket_length
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
91
+
92
+
93
+ class AffinityAugDataset(BaseWrapperDataset):
94
+ def __init__(
95
+ self,
96
+ dataset,
97
+ seed,
98
+ atoms,
99
+ coordinates,
100
+ pocket_atoms,
101
+ pocket_coordinates,
102
+ affinity,
103
+ is_train=False,
104
+ pocket="pocket_id"
105
+ ):
106
+ self.dataset = dataset
107
+ self.seed = seed
108
+ self.atoms = atoms
109
+ self.coordinates = coordinates
110
+ self.pocket_atoms = pocket_atoms
111
+ self.pocket_coordinates = pocket_coordinates
112
+ self.affinity = affinity
113
+ self.is_train = is_train
114
+ self.pocket=pocket
115
+ self.set_epoch(None)
116
+
117
+ def set_epoch(self, epoch, **unused):
118
+ super().set_epoch(epoch)
119
+ self.epoch = epoch
120
+
121
+ def pocket_atom(self, atom):
122
+ if atom[0] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
123
+ return atom[1]
124
+ else:
125
+ return atom[0]
126
+
127
@lru_cache(maxsize=16)
def __cached_item__(self, index: int, epoch: int):
    """Sample one (ligand copy, conformer, pocket copy) combination for `index`.

    Each entry may hold several ligand copies and several pocket copies.
    Every random draw re-seeds from (seed, epoch, index) so results are
    reproducible; the conformer draw uses the real epoch while training and
    a fixed epoch of 1 otherwise.
    """
    entry = self.dataset[index]
    # Choose which ligand copy to use.
    with data_utils.numpy_seed(self.seed, epoch, index):
        mol_idx = np.random.randint(len(entry[self.atoms]))
    atoms = np.array(entry[self.atoms][mol_idx])
    ori_mol_length = len(atoms)
    # Choose which conformer of that ligand copy to use.
    n_conf = len(entry[self.coordinates][mol_idx])
    conf_epoch = epoch if self.is_train else 1
    with data_utils.numpy_seed(self.seed, conf_epoch, index):
        sample_idx = np.random.randint(n_conf)
    coordinates = entry[self.coordinates][mol_idx][sample_idx]

    # Choose which pocket copy to use (fresh re-seeded context).
    with data_utils.numpy_seed(self.seed, epoch, index):
        pocket_idx = np.random.randint(len(entry[self.pocket_atoms]))
    pocket_atoms = np.array(
        [self.pocket_atom(a) for a in entry[self.pocket_atoms][pocket_idx]]
    )
    ori_pocket_length = len(pocket_atoms)
    pocket_coordinates = np.stack(entry[self.pocket_coordinates][pocket_idx])

    smi = entry["smiles"][mol_idx]
    pocket = entry[self.pocket][0]
    # Entries without a label (e.g. inference) fall back to a dummy affinity.
    affinity = float(entry[self.affinity]) if self.affinity in entry else 1
    return {
        "atoms": atoms,
        "coordinates": coordinates.astype(np.float32),
        "holo_coordinates": coordinates.astype(np.float32),  # placeholder
        "pocket_atoms": pocket_atoms,
        "pocket_coordinates": pocket_coordinates.astype(np.float32),
        "holo_pocket_coordinates": pocket_coordinates.astype(np.float32),  # placeholder
        "smi": smi,
        "pocket": pocket,
        "affinity": affinity,
        "ori_mol_length": ori_mol_length,
        "ori_pocket_length": ori_pocket_length,
    }
175
+
176
def __getitem__(self, index: int):
    """Delegate to the (index, epoch)-cached sampler."""
    return self.__cached_item__(index, self.epoch)
178
+
179
+
180
class AffinityHNSDataset(BaseWrapperDataset):
    """Affinity dataset that also carries hard-negative-sample (HNS) ligands.

    Each item bundles a sampled ligand conformer, the HNS ligand, and the
    protein pocket. All coordinate fields are cast to float32 and the
    "holo_*" entries are placeholders mirroring the sampled coordinates.
    """

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        atoms_hns,
        coordinates_hns,
        pocket_atoms,
        pocket_coordinates,
        affinity,
        is_train=False,
        pocket="pocket",
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.atoms_hns = atoms_hns
        self.coordinates_hns = coordinates_hns
        self.pocket_atoms = pocket_atoms
        self.pocket_coordinates = pocket_coordinates
        self.affinity = affinity
        self.is_train = is_train
        self.pocket = pocket
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        """Record the active epoch; it seeds per-epoch conformer sampling."""
        super().set_epoch(epoch)
        self.epoch = epoch

    def pocket_atom(self, atom):
        """Drop a leading digit from a pocket atom name (e.g. '1HB' -> 'H')."""
        return atom[1] if atom[0] in "0123456789" else atom[0]

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Assemble one sample (ligand + HNS ligand + pocket) for `index`."""
        entry = self.dataset[index]
        atoms = np.array(entry[self.atoms])
        ori_mol_length = len(atoms)
        n_conf = len(entry[self.coordinates])
        # Conformer choice varies per epoch while training, fixed otherwise.
        conf_epoch = epoch if self.is_train else 1
        with data_utils.numpy_seed(self.seed, conf_epoch, index):
            sample_idx = np.random.randint(n_conf)
        coordinates = entry[self.coordinates][sample_idx]
        atoms_hns = np.array(entry[self.atoms_hns])
        coordinates_hns = entry[self.coordinates_hns][0]

        pocket_atoms = np.array(
            [self.pocket_atom(a) for a in entry[self.pocket_atoms]]
        )
        ori_pocket_length = len(pocket_atoms)
        pocket_coordinates = np.stack(entry[self.pocket_coordinates])

        smi = entry["smi"]
        pocket = entry[self.pocket]
        # Unlabelled entries fall back to a dummy affinity of 1.
        affinity = float(entry[self.affinity]) if self.affinity in entry else 1
        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "atoms_hns": atoms_hns,
            "coordinates_hns": coordinates_hns.astype(np.float32),
            "holo_coordinates": coordinates.astype(np.float32),  # placeholder
            "pocket_atoms": pocket_atoms,
            "pocket_coordinates": pocket_coordinates.astype(np.float32),
            "holo_pocket_coordinates": pocket_coordinates.astype(np.float32),  # placeholder
            "smi": smi,
            "pocket": pocket,
            "affinity": affinity,
            "ori_mol_length": ori_mol_length,
            "ori_pocket_length": ori_pocket_length,
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
267
+
268
class AffinityTestDataset(BaseWrapperDataset):
    """Test-time affinity dataset.

    Returns one seeded conformer per item. Unlike the training variants, the
    raw affinity value is returned via ``.astype(np.float32)`` (the stored
    value is expected to support that call — presumably a numpy value).
    """

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        pocket_atoms,
        pocket_coordinates,
        affinity=None,
        is_train=False,
        pocket="pocket",
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.pocket_atoms = pocket_atoms
        self.pocket_coordinates = pocket_coordinates
        self.affinity = affinity
        self.is_train = is_train
        self.pocket = pocket
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        """Record the active epoch for reproducible conformer choice."""
        super().set_epoch(epoch)
        self.epoch = epoch

    def pocket_atom(self, atom):
        """Drop a leading digit from a pocket atom name (e.g. '1HB' -> 'H')."""
        return atom[1] if atom[0] in "0123456789" else atom[0]

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Build one evaluation sample for `index` at `epoch`."""
        entry = self.dataset[index]
        atoms = np.array(entry[self.atoms])
        ori_length = len(atoms)
        with data_utils.numpy_seed(self.seed, epoch, index):
            pick = np.random.randint(len(entry[self.coordinates]))
        coordinates = entry[self.coordinates][pick]
        pocket_atoms = np.array(
            [self.pocket_atom(a) for a in entry[self.pocket_atoms]]
        )
        pocket_coordinates = np.stack(entry[self.pocket_coordinates])

        smi = entry["smi"]
        pocket = entry[self.pocket]
        affinity = entry[self.affinity]
        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "holo_coordinates": coordinates.astype(np.float32),  # placeholder
            "pocket_atoms": pocket_atoms,
            "pocket_coordinates": pocket_coordinates.astype(np.float32),
            "holo_pocket_coordinates": pocket_coordinates.astype(np.float32),  # placeholder
            "smi": smi,
            "pocket": pocket,
            "affinity": affinity.astype(np.float32),
            "ori_length": ori_length,
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
335
+
336
+
337
class AffinityMolDataset(BaseWrapperDataset):
    """Molecule-only view of an affinity dataset (no pocket fields)."""

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        is_train=False,
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.is_train = is_train
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        """Record the active epoch for seeded evaluation sampling."""
        super().set_epoch(epoch)
        self.epoch = epoch

    def pocket_atom(self, atom):
        """Drop a leading digit from an atom name (kept for interface parity)."""
        return atom[1] if atom[0] in "0123456789" else atom[0]

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Sample one conformer and package the molecule fields."""
        entry = self.dataset[index]
        atoms = np.array(entry[self.atoms])
        ori_length = len(atoms)
        n_conf = len(entry[self.coordinates])

        # TODO: FB: introduce enough random when training using pairwise data
        # Training draws from the global RNG on purpose (extra randomness);
        # evaluation is seeded so the same conformer is always picked.
        if self.is_train:
            pick = np.random.randint(n_conf)
        else:
            with data_utils.numpy_seed(self.seed, index):
                pick = np.random.randint(n_conf)
        # Entries may store a list of 2-D conformers or a single 2-D array.
        candidate = entry[self.coordinates][pick]
        if len(candidate.shape) == 2:
            coordinates = candidate
        else:
            coordinates = entry[self.coordinates]

        smi = entry["smi"]
        name = entry.get("name", None)
        # Pickle the raw mol object so it survives batching/collation.
        mol = pickle.dumps(entry.get("mol", None))
        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "holo_coordinates": coordinates.astype(np.float32),  # placeholder
            "smi": smi,
            "ori_length": ori_length,
            "name": name,
            "mol": mol,
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
402
+
403
+
404
class AffinityPocketDataset(BaseWrapperDataset):
    """Pocket-only view of an affinity dataset.

    When the entry carries a "pocket_residue_name" field, the residue labels
    are filtered down to non-hydrogen atoms; otherwise a single empty string
    is returned as a placeholder.
    """

    def __init__(
        self,
        dataset,
        seed,
        pocket_atoms,
        pocket_coordinates,
        is_train=False,
        pocket="pocket",
    ):
        self.dataset = dataset
        self.seed = seed
        self.pocket_atoms = pocket_atoms
        self.pocket_coordinates = pocket_coordinates
        self.is_train = is_train
        self.pocket = pocket
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        """Record the active epoch (kept for interface parity)."""
        super().set_epoch(epoch)
        self.epoch = epoch

    def pocket_atom(self, atom):
        """Drop a leading digit from a pocket atom name (e.g. '1HB' -> 'H')."""
        return atom[1] if atom[0] in "0123456789" else atom[0]

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Package the pocket fields of entry `index`."""
        entry = self.dataset[index]
        pocket_atoms = np.array(
            [self.pocket_atom(a) for a in entry[self.pocket_atoms]]
        )
        ori_length = len(pocket_atoms)
        pocket_coordinates = np.stack(entry[self.pocket_coordinates])
        pocket = entry[self.pocket] if self.pocket in entry else ""
        if "pocket_residue_name" in entry:
            # Keep residue labels only for non-hydrogen atoms.
            pocket_residue_name_noH = [
                res
                for res, atom in zip(entry["pocket_residue_name"], pocket_atoms)
                if atom != "H"
            ]
        else:
            pocket_residue_name_noH = [""]
        return {
            "pocket_atoms": pocket_atoms,
            "pocket_coordinates": pocket_coordinates.astype(np.float32),
            "holo_pocket_coordinates": pocket_coordinates.astype(np.float32),  # placeholder
            "pocket": pocket,
            "pocket_residue_name": pocket_residue_name_noH,
            "ori_length": ori_length,
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
464
+
465
class AffinityValidDataset(BaseWrapperDataset):
    """Validation affinity dataset: seeded conformer choice, no affinity label."""

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        pocket_atoms,
        pocket_coordinates,
        pocket="pocket",
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.pocket_atoms = pocket_atoms
        self.pocket_coordinates = pocket_coordinates
        self.pocket = pocket
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        """Record the active epoch for reproducible conformer choice."""
        super().set_epoch(epoch)
        self.epoch = epoch

    def pocket_atom(self, atom):
        """Drop a leading digit from a pocket atom name (e.g. '1HB' -> 'H')."""
        return atom[1] if atom[0] in "0123456789" else atom[0]

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Build one validation sample for `index` at `epoch`."""
        entry = self.dataset[index]
        atoms = np.array(entry[self.atoms])
        ori_mol_length = len(atoms)

        with data_utils.numpy_seed(self.seed, epoch, index):
            pick = np.random.randint(len(entry[self.coordinates]))
        coordinates = entry[self.coordinates][pick]
        pocket_atoms = np.array(
            [self.pocket_atom(a) for a in entry[self.pocket_atoms]]
        )
        ori_pocket_length = len(pocket_atoms)
        pocket_coordinates = np.stack(entry[self.pocket_coordinates])

        smi = entry["smi"]
        pocket = entry[self.pocket]
        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "holo_coordinates": coordinates.astype(np.float32),  # placeholder
            "pocket_atoms": pocket_atoms,
            "pocket_coordinates": pocket_coordinates.astype(np.float32),
            "holo_pocket_coordinates": pocket_coordinates.astype(np.float32),  # placeholder
            "smi": smi,
            "pocket": pocket,
            "ori_mol_length": ori_mol_length,
            "ori_pocket_length": ori_pocket_length,
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
unimol/data/atom_type_dataset.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) DP Technology.
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ from functools import lru_cache
6
+ from unicore.data import BaseWrapperDataset
7
+
8
+
9
class AtomTypeDataset(BaseWrapperDataset):
    """Align `atoms` and `coordinates` lists to a common length.

    Older RDKit versions could emit atom and coordinate lists of different
    lengths; this wrapper truncates both to the shorter prefix so downstream
    code can zip them safely.
    """

    def __init__(
        self,
        raw_dataset,
        dataset,
        smi="smi",
        atoms="atoms",
    ):
        self.raw_dataset = raw_dataset
        self.dataset = dataset
        self.smi = smi
        self.atoms = atoms

    @lru_cache(maxsize=16)
    def __getitem__(self, index: int):
        # for low rdkit version: repair mismatched list lengths in place
        entry = self.dataset[index]
        keep = min(len(entry["atoms"]), len(entry["coordinates"]))
        if len(entry["atoms"]) != len(entry["coordinates"]):
            entry["atoms"] = entry["atoms"][:keep]
            entry["coordinates"] = entry["coordinates"][:keep]
        return entry
unimol/data/conformer_sample_dataset.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) DP Technology.
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ import numpy as np
6
+ from functools import lru_cache
7
+ from unicore.data import BaseWrapperDataset
8
+ from . import data_utils
9
+
10
+
11
class ConformerSampleDataset(BaseWrapperDataset):
    """Pick one conformer per molecule, reproducibly per (seed, epoch, index)."""

    def __init__(self, dataset, seed, atoms, coordinates):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        """Record the active epoch; it seeds per-epoch conformer sampling."""
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Return the atom list plus one seeded conformer as float32."""
        entry = self.dataset[index]
        atoms = np.array(entry[self.atoms])
        assert len(atoms) > 0
        with data_utils.numpy_seed(self.seed, epoch, index):
            pick = np.random.randint(len(entry[self.coordinates]))
        coordinates = entry[self.coordinates][pick]
        return {"atoms": atoms, "coordinates": coordinates.astype(np.float32)}

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
37
+
38
class ConformerSampleDecoderDataset(BaseWrapperDataset):
    """Conformer sampler that also passes through the SELFIES tokens."""

    def __init__(self, dataset, seed, atoms, coordinates, selfies):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.selfies = selfies
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        """Record the active epoch; it seeds per-epoch conformer sampling."""
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Return atoms, SELFIES and one seeded conformer (float32)."""
        entry = self.dataset[index]
        atoms = np.array(entry[self.atoms])
        assert len(atoms) > 0
        selfies = np.array(entry[self.selfies])
        assert len(selfies) > 0
        with data_utils.numpy_seed(self.seed, epoch, index):
            pick = np.random.randint(len(entry[self.coordinates]))
        coordinates = entry[self.coordinates][pick]
        return {
            "atoms": atoms,
            "selfies": selfies,
            "coordinates": coordinates.astype(np.float32),
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
69
+
70
class ConformerSamplePocketDataset(BaseWrapperDataset):
    """Sample one pocket conformer plus residue labels and an fpocket score.

    `dict_name` selects atom-naming granularity:
    - "dict_coarse.txt": first character only (element symbol).
    - "dict_fine.txt": up to two characters, except hydrogens keep one.
    Any other value raises ValueError (the old code fell through to a
    confusing NameError).
    """

    def __init__(self, dataset, seed, atoms, coordinates, dict_name):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.dict_name = dict_name
        self.coordinates = coordinates
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        """Record the active epoch; it seeds per-epoch conformer sampling."""
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Return atoms (per dict granularity), a seeded conformer, residues
        and the fpocket score for entry `index`."""
        entry = self.dataset[index]
        if self.dict_name == "dict_coarse.txt":
            atoms = np.array([a[0] for a in entry[self.atoms]])
        elif self.dict_name == "dict_fine.txt":
            atoms = np.array(
                [
                    a[0] if len(a) == 1 or a[0] == "H" else a[:2]
                    for a in entry[self.atoms]
                ]
            )
        else:
            # Fail fast with a clear message instead of an unbound `atoms`.
            raise ValueError(f"unknown dict_name: {self.dict_name!r}")
        assert len(atoms) > 0
        size = len(entry[self.coordinates])
        with data_utils.numpy_seed(self.seed, epoch, index):
            sample_idx = np.random.randint(size)
        coordinates = entry[self.coordinates][sample_idx]
        residue = np.array(entry["residue"])
        # np.float was removed in NumPy 1.24; the builtin float is equivalent.
        score = float(entry["meta_info"]["fpocket"]["Score"])
        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "residue": residue,
            "score": score,
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
110
+
111
+
112
class ConformerSamplePocketFinetuneDataset(BaseWrapperDataset):
    """Pocket sampler for finetuning.

    Returned dict keys are the constructor-supplied field names, not fixed
    strings, so downstream pipelines can be configured by key.
    """

    def __init__(self, dataset, seed, atoms, residues, coordinates):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.residues = residues
        self.coordinates = coordinates
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        """Record the active epoch; it seeds per-epoch conformer sampling."""
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Return element symbols, one conformation and optional residues."""
        entry = self.dataset[index]
        # Only the element symbol ('C H O N S') is kept per atom.
        atoms = np.array([a[0] for a in entry[self.atoms]])
        assert len(atoms) > 0
        # Pockets currently store a single conformation; the list branch is
        # reserved for possible future expansion, where sampling would apply.
        if isinstance(entry[self.coordinates], list):
            with data_utils.numpy_seed(self.seed, epoch, index):
                pick = np.random.randint(len(entry[self.coordinates]))
            coordinates = entry[self.coordinates][pick]
        else:
            coordinates = entry[self.coordinates]

        residues = (
            np.array(entry[self.residues]) if self.residues in entry else None
        )
        assert len(atoms) == len(coordinates)
        return {
            self.atoms: atoms,
            self.coordinates: coordinates.astype(np.float32),
            self.residues: residues,
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
154
+
155
+
156
class ConformerSampleConfGDataset(BaseWrapperDataset):
    """Sample an input conformer alongside its fixed target coordinates."""

    def __init__(self, dataset, seed, atoms, coordinates, tgt_coordinates):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.tgt_coordinates = tgt_coordinates
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        """Record the active epoch; it seeds per-epoch conformer sampling."""
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Return atoms, one seeded input conformer, and target coordinates."""
        entry = self.dataset[index]
        atoms = np.array(entry[self.atoms])
        assert len(atoms) > 0
        with data_utils.numpy_seed(self.seed, epoch, index):
            pick = np.random.randint(len(entry[self.coordinates]))
        coordinates = entry[self.coordinates][pick]
        tgt_coordinates = entry[self.tgt_coordinates]
        return {
            self.atoms: atoms,
            self.coordinates: coordinates.astype(np.float32),
            self.tgt_coordinates: tgt_coordinates.astype(np.float32),
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
186
+
187
+
188
class ConformerSampleConfGV2Dataset(BaseWrapperDataset):
    """Conformer sampler with importance sampling over RMSD scores.

    For each item: pick one target conformer group ("gid") uniformly, keep
    the `topN` best-scoring candidates, then draw one candidate with
    probability proportional to the smoothed inverse of its score.
    """

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        tgt_coordinates,
        beta=1.0,
        smooth=0.1,
        topN=10,
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.tgt_coordinates = tgt_coordinates
        self.beta = beta
        self.smooth = smooth
        self.topN = topN
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        """Record the active epoch; it seeds per-epoch conformer sampling."""
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Importance-sample one (input, target) conformer pair."""
        entry = self.dataset[index]
        atoms = np.array(entry[self.atoms])
        assert len(atoms) > 0
        meta_df = entry["meta"]
        tgt_conf_ids = meta_df["gid"].unique()
        # Randomly choose one conformer group.
        with data_utils.numpy_seed(self.seed, epoch, index):
            conf_id = np.random.choice(tgt_conf_ids)
        conf_df = meta_df[meta_df["gid"] == conf_id]
        # Keep only the topN best (lowest) scoring candidates.
        conf_df = conf_df.sort_values("score").reset_index(drop=False)[
            : self.topN
        ]

        def _inverse_weight(x, beta=1.0, smooth=0.1):
            # Smoothed inverse-score weights, normalised to sum to 1.
            w = 1.0 / (x**beta + smooth)
            return w / w.sum()

        weight = _inverse_weight(
            conf_df["score"].values, beta=self.beta, smooth=self.smooth
        )
        with data_utils.numpy_seed(self.seed, epoch, index):
            pick = np.random.choice(len(conf_df), 1, replace=False, p=weight)
        # pick = [np.argmax(weight)]
        coordinates = conf_df.iloc[pick]["rdkit_coords"].values[0]
        tgt_coordinates = conf_df.iloc[pick]["tgt_coords"].values[0]
        return {
            self.atoms: atoms,
            self.coordinates: coordinates.astype(np.float32),
            self.tgt_coordinates: tgt_coordinates.astype(np.float32),
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
250
+
251
+
252
class ConformerSampleDockingPoseDataset(BaseWrapperDataset):
    """Sampler for docking-pose training/eval.

    At train time the holo (ground-truth) ligand and pocket coordinates come
    from the entry; at eval time they are unavailable, so the sampled input
    coordinates are reused as stand-ins.
    """

    def __init__(
        self,
        dataset,
        seed,
        atoms,
        coordinates,
        pocket_atoms,
        pocket_coordinates,
        holo_coordinates,
        holo_pocket_coordinates,
        is_train=True,
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        self.pocket_atoms = pocket_atoms
        self.pocket_coordinates = pocket_coordinates
        self.holo_coordinates = holo_coordinates
        self.holo_pocket_coordinates = holo_pocket_coordinates
        self.is_train = is_train
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        """Record the active epoch; it seeds per-epoch conformer sampling."""
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Return one sampled ligand conformer with its pocket and holo data."""
        entry = self.dataset[index]
        atoms = np.array(entry[self.atoms])
        with data_utils.numpy_seed(self.seed, epoch, index):
            pick = np.random.randint(len(entry[self.coordinates]))
        coordinates = entry[self.coordinates][pick]
        # Pocket atom names are reduced to their element symbol.
        pocket_atoms = np.array([item[0] for item in entry[self.pocket_atoms]])
        pocket_coordinates = entry[self.pocket_coordinates][0]
        if self.is_train:
            holo_coordinates = entry[self.holo_coordinates][0]
            holo_pocket_coordinates = entry[self.holo_pocket_coordinates][0]
        else:
            # No ground-truth pose at inference: reuse the sampled inputs.
            holo_coordinates = coordinates
            holo_pocket_coordinates = pocket_coordinates

        smi = entry["smi"]
        pocket = entry["pocket"]

        return {
            "atoms": atoms,
            "coordinates": coordinates.astype(np.float32),
            "pocket_atoms": pocket_atoms,
            "pocket_coordinates": pocket_coordinates.astype(np.float32),
            "holo_coordinates": holo_coordinates.astype(np.float32),
            "holo_pocket_coordinates": holo_pocket_coordinates.astype(np.float32),
            "smi": smi,
            "pocket": pocket,
        }

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
unimol/data/coord_pad_dataset.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) DP Technology.
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ from unicore.data import BaseWrapperDataset
6
+
7
+
8
def collate_tokens_coords(
    values,
    pad_idx,
    left_pad=False,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Pad a list of (len_i, 3) coordinate tensors into one (B, T, 3) tensor.

    Padding positions are filled with `pad_idx`; data is placed at the right
    end when `left_pad` is true, at the left end otherwise. `pad_to_multiple`
    is accepted for interface parity but deliberately not applied (see the
    disabled rounding below).
    """
    target = max(v.size(0) for v in values)
    if pad_to_length is not None:
        target = max(target, pad_to_length)
    # NOTE: rounding up to a multiple of `pad_to_multiple` is disabled:
    # if pad_to_multiple != 1 and target % pad_to_multiple != 0:
    #     target = int(((target - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
    out = values[0].new(len(values), target, 3).fill_(pad_idx)
    for row, v in enumerate(values):
        dst = out[row][target - len(v):, :] if left_pad else out[row][: len(v), :]
        assert dst.numel() == v.numel()
        dst.copy_(v)
    return out
+
30
+
31
class RightPadDatasetCoord(BaseWrapperDataset):
    """Dataset wrapper whose collater pads coordinate samples with `pad_idx`."""

    def __init__(self, dataset, pad_idx, left_pad=False):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad

    def collater(self, samples):
        """Batch `samples` via collate_tokens_coords (pad_to_multiple=8)."""
        return collate_tokens_coords(
            samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8
        )
+
42
+
43
def collate_cross_2d(
    values,
    pad_idx,
    left_pad=False,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Pad a list of 2-d tensors into one (B, H, W) tensor filled with pad_idx.

    Both height and width are rounded up to a multiple of `pad_to_multiple`.
    (`pad_to_length` is accepted for interface parity but unused.)
    """
    height = max(v.size(0) for v in values)
    width = max(v.size(1) for v in values)

    def _round_up(n):
        # Round up to the next multiple; the -0.1 keeps exact multiples as-is.
        if pad_to_multiple != 1 and n % pad_to_multiple != 0:
            n = int(((n - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
        return n

    height = _round_up(height)
    width = _round_up(width)
    out = values[0].new(len(values), height, width).fill_(pad_idx)
    for row, v in enumerate(values):
        dst = (
            out[row][height - v.size(0):, width - v.size(1):]
            if left_pad
            else out[row][: v.size(0), : v.size(1)]
        )
        assert dst.numel() == v.numel()
        dst.copy_(v)
    return out
+
72
+
73
class RightPadDatasetCross2D(BaseWrapperDataset):
    """Dataset wrapper whose collater pads 2-d samples with `pad_idx`."""

    def __init__(self, dataset, pad_idx, left_pad=False):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad

    def collater(self, samples):
        """Batch `samples` via collate_cross_2d (pad_to_multiple=8)."""
        return collate_cross_2d(
            samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8
        )
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) DP Technology.
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ import numpy as np
6
+ from functools import lru_cache
7
+ import logging
8
+ from unicore.data import BaseWrapperDataset
9
+ from . import data_utils
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class CroppingDataset(BaseWrapperDataset):
    """Uniformly subsample molecules that exceed `max_atoms` atoms."""

    def __init__(self, dataset, seed, atoms, coordinates, max_atoms=256):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        # Max number of atoms kept per molecule; falsy means no limit.
        self.max_atoms = max_atoms
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        """Record the active epoch; it seeds the per-epoch subsampling."""
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Return a (possibly cropped) shallow copy of entry `index`."""
        dd = self.dataset[index].copy()
        atoms = dd[self.atoms]
        coordinates = dd[self.coordinates]
        if self.max_atoms and len(atoms) > self.max_atoms:
            with data_utils.numpy_seed(self.seed, epoch, index):
                keep = np.random.choice(len(atoms), self.max_atoms, replace=False)
                atoms = np.array(atoms)[keep]
                coordinates = coordinates[keep]
        dd[self.atoms] = atoms
        dd[self.coordinates] = coordinates.astype(np.float32)
        return dd

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
+
44
+
45
class CroppingPocketDataset(BaseWrapperDataset):
    """Crop pockets above `max_atoms` atoms, biased toward the centroid.

    Atoms are kept with probability proportional to a softmax over the
    reciprocal of their (smoothed) distance to the pocket centroid, so
    atoms near the centre are more likely to survive the crop.
    """

    def __init__(self, dataset, seed, atoms, coordinates, max_atoms=256):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms
        self.coordinates = coordinates
        # Max number of atoms kept per pocket; falsy means no limit.
        self.max_atoms = max_atoms
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        """Record the active epoch; it seeds the per-epoch subsampling."""
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        """Return a (possibly cropped) shallow copy of entry `index`."""
        dd = self.dataset[index].copy()
        atoms = dd[self.atoms]
        coordinates = dd[self.coordinates]

        # Crop atoms according to their distance to the pocket centre.
        if self.max_atoms and len(atoms) > self.max_atoms:
            with data_utils.numpy_seed(self.seed, epoch, index):
                distance = np.linalg.norm(
                    coordinates - coordinates.mean(axis=0), axis=1
                )

                def _softmax(x):
                    shifted = x - np.max(x)
                    e = np.exp(shifted)
                    return e / e.sum()

                distance += 1  # prevent inf when inverting
                weight = _softmax(np.reciprocal(distance))
                keep = np.random.choice(
                    len(atoms), self.max_atoms, replace=False, p=weight
                )
                atoms = atoms[keep]
                coordinates = coordinates[keep]

        dd[self.atoms] = atoms
        dd[self.coordinates] = coordinates.astype(np.float32)
        return dd

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
+
96
+
97
class CroppingResiduePocketDataset(BaseWrapperDataset):
    """Crops a pocket at residue granularity: whole residues are kept or
    dropped so that roughly ``max_atoms`` atoms remain, with residues closer
    to the pocket center favored."""

    def __init__(self, dataset, seed, atoms, residues, coordinates, max_atoms=256):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms  # key of the atom-name array in each sample dict
        self.residues = residues  # key of the per-atom residue-id array
        self.coordinates = coordinates  # key of the (N, 3) coordinate array
        self.max_atoms = (
            max_atoms  # max number of atoms in a molecule, None indicates no limit.
        )

        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        dd = self.dataset[index].copy()
        atoms = dd[self.atoms]
        residues = dd[self.residues]
        coordinates = dd[self.coordinates]

        # NOTE(review): unused local -- candidate for removal.
        residues_distance_map = {}

        # crop atoms according to their distance to the center of pockets
        if self.max_atoms and len(atoms) > self.max_atoms:
            # Residue sampling is deterministic given (seed, epoch, index).
            with data_utils.numpy_seed(self.seed, epoch, index):
                distance = np.linalg.norm(
                    coordinates - coordinates.mean(axis=0), axis=1
                )
                # Mean center-distance per residue, in first-occurrence order.
                residues_ids, residues_distance = [], []
                for res in residues:
                    if res not in residues_ids:
                        residues_ids.append(res)
                        residues_distance.append(distance[residues == res].mean())
                residues_ids = np.array(residues_ids)
                residues_distance = np.array(residues_distance)

                def softmax(x):
                    x -= np.max(x)
                    x = np.exp(x) / np.sum(np.exp(x))
                    return x

                residues_distance += 1  # prevent inf and smoothing out the distance
                # Nearer residues get higher sampling probability.
                weight = softmax(np.reciprocal(residues_distance))
                # Residue budget = max_atoms / (approximate atoms per residue).
                # NOTE(review): if len(atoms) < len(residues_ids) + 1 the inner
                # floor division is 0 and this raises ZeroDivisionError -- confirm
                # inputs always have at least one atom per residue on average.
                max_residues = self.max_atoms // (len(atoms) // (len(residues_ids) + 1))
                if max_residues < 1:
                    max_residues += 1
                max_residues = min(max_residues, len(residues_ids))
                residue_index = np.random.choice(
                    len(residues_ids), max_residues, replace=False, p=weight
                )
                # Keep every atom belonging to a sampled residue.
                index = [
                    i
                    for i in range(len(atoms))
                    if residues[i] in residues_ids[residue_index]
                ]
                atoms = atoms[index]
                coordinates = coordinates[index]
                residues = residues[index]

        dd[self.atoms] = atoms
        dd[self.coordinates] = coordinates.astype(np.float32)
        dd[self.residues] = residues
        return dd

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
167
+
168
+
169
class CroppingPocketDockingPoseDataset(BaseWrapperDataset):
    """Crops a pocket to ``max_atoms`` atoms, keeping the predicted and holo
    (ground-truth) coordinate arrays index-aligned.

    Bug fix: ``holo_coordinates`` was accepted by ``__init__`` but never
    stored on ``self``, so ``__cached_item__`` crashed with AttributeError
    at ``self.holo_coordinates``.
    """

    def __init__(
        self, dataset, seed, atoms, coordinates, holo_coordinates, max_atoms=256
    ):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms  # key of the atom-name array in each sample dict
        self.coordinates = coordinates  # key of the input coordinate array
        self.holo_coordinates = holo_coordinates  # key of the holo (ground-truth) coordinates
        self.max_atoms = max_atoms  # max atoms kept; None/0 disables cropping

        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        dd = self.dataset[index].copy()
        atoms = dd[self.atoms]
        coordinates = dd[self.coordinates]
        holo_coordinates = dd[self.holo_coordinates]

        # crop atoms according to their distance to the center of pockets
        if self.max_atoms and len(atoms) > self.max_atoms:
            # NOTE: seeded with (seed, 1) only -- the crop does not vary with
            # epoch or item index, presumably for reproducible docking poses.
            with data_utils.numpy_seed(self.seed, 1):
                distance = np.linalg.norm(
                    coordinates - coordinates.mean(axis=0), axis=1
                )

                def softmax(x):
                    x -= np.max(x)
                    x = np.exp(x) / np.sum(np.exp(x))
                    return x

                distance += 1  # prevent inf
                # Atoms closer to the pocket center are more likely to be kept.
                weight = softmax(np.reciprocal(distance))
                index = np.random.choice(
                    len(atoms), self.max_atoms, replace=False, p=weight
                )
                atoms = atoms[index]
                coordinates = coordinates[index]
                holo_coordinates = holo_coordinates[index]

        dd[self.atoms] = atoms
        dd[self.coordinates] = coordinates.astype(np.float32)
        dd[self.holo_coordinates] = holo_coordinates.astype(np.float32)
        return dd

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
221
+
222
class CroppingPocketDockingPoseTestDataset(BaseWrapperDataset):
    """Test-time pocket cropping to ``max_atoms`` atoms using a fixed random
    stream, so evaluation results are reproducible."""

    def __init__(self, dataset, seed, atoms, coordinates, max_atoms=256):
        self.dataset = dataset
        self.seed = seed
        self.atoms = atoms  # key of the atom-name array in each sample dict
        self.coordinates = coordinates  # key of the (N, 3) coordinate array
        self.max_atoms = max_atoms  # max atoms kept; None/0 disables cropping

        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        sample = self.dataset[index].copy()
        atom_arr = sample[self.atoms]
        coord_arr = sample[self.coordinates]

        n_atoms = len(atom_arr)
        if self.max_atoms and n_atoms > self.max_atoms:
            # Hard-coded seed (1, 1): identical crop across epochs, items
            # and runs -- self.seed is deliberately not used here.
            with data_utils.numpy_seed(1, 1):
                center_dist = np.linalg.norm(
                    coord_arr - coord_arr.mean(axis=0), axis=1
                )
                center_dist += 1  # prevent inf in the reciprocal below
                inv_dist = np.reciprocal(center_dist)
                # Softmax over inverse distances favors atoms near the center.
                weight = np.exp(inv_dist - np.max(inv_dist))
                weight = weight / np.sum(weight)
                keep = np.random.choice(
                    n_atoms, self.max_atoms, replace=False, p=weight
                )
                atom_arr = atom_arr[keep]
                coord_arr = coord_arr[keep]

        sample[self.atoms] = atom_arr
        sample[self.coordinates] = coord_arr.astype(np.float32)
        return sample

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
unimol/data/data_utils.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) DP Technology.
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ import numpy as np
6
+ import contextlib
7
+
8
+
9
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Temporarily seed NumPy's global PRNG, restoring the prior state on exit.

    A ``seed`` of ``None`` makes this a no-op. Any ``addl_seeds`` are hashed
    together with ``seed`` (mod 1e6) to derive the effective seed.
    """
    if seed is None:
        yield
        return
    if addl_seeds:
        seed = int(hash((seed, *addl_seeds)) % 1e6)
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        np.random.set_state(saved_state)
unimol/data/dictionary.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) DP Technology.
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import logging
7
+
8
+ import numpy as np
9
+
10
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
11
+
12
+ class DecoderDictionary:
13
+ """A mapping from symbols to consecutive integers"""
14
+
15
+ def __init__(
16
+ self,
17
+ *, # begin keyword-only arguments
18
+ bos="[CLS]",
19
+ pad="[PAD]",
20
+ eos="[SEP]",
21
+ unk="[UNK]",
22
+ extra_special_symbols=None,
23
+ ):
24
+ self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
25
+ self.symbols = []
26
+ self.count = []
27
+ self.indices = {}
28
+ self.idx2sym = {}
29
+ self.specials = set()
30
+ self.specials.add(bos)
31
+ self.specials.add(unk)
32
+ self.specials.add(pad)
33
+ self.specials.add(eos)
34
+
35
+ def __eq__(self, other):
36
+ return self.indices == other.indices
37
+
38
+ def __getitem__(self, idx):
39
+ if idx < len(self.symbols):
40
+ return self.symbols[idx]
41
+ return self.unk_word
42
+
43
+ def __len__(self):
44
+ """Returns the number of symbols in the dictionary"""
45
+ return len(self.symbols)
46
+
47
+ def __contains__(self, sym):
48
+ return sym in self.indices
49
+
50
+ def vec_index(self, a):
51
+ return np.vectorize(self.index)(a)
52
+
53
+ def index(self, sym):
54
+ """Returns the index of the specified symbol"""
55
+ assert isinstance(sym, str)
56
+ if sym in self.indices:
57
+ return self.indices[sym]
58
+ return self.indices[self.unk_word]
59
+
60
+ def index2symbol(self, idx):
61
+ """Returns the corresponding symbol of the specified index"""
62
+ assert isinstance(idx, int)
63
+ if idx in self.idx2sym:
64
+ return self.idx2sym[idx]
65
+ return self.unk_word
66
+
67
+ def special_index(self):
68
+ return [self.index(x) for x in self.specials]
69
+
70
+ def add_symbol(self, word, n=1, overwrite=False, is_special=False):
71
+ """Adds a word to the dictionary"""
72
+ if is_special:
73
+ self.specials.add(word)
74
+ if word in self.indices and not overwrite:
75
+ idx = self.indices[word]
76
+ self.count[idx] = self.count[idx] + n
77
+ return idx
78
+ else:
79
+ idx = len(self.symbols)
80
+ self.indices[word] = idx
81
+ self.idx2sym[idx] = word
82
+ self.symbols.append(word)
83
+ self.count.append(n)
84
+ return idx
85
+
86
+ def bos(self):
87
+ """Helper to get index of beginning-of-sentence symbol"""
88
+ return self.index(self.bos_word)
89
+
90
+ def pad(self):
91
+ """Helper to get index of pad symbol"""
92
+ return self.index(self.pad_word)
93
+
94
+ def eos(self):
95
+ """Helper to get index of end-of-sentence symbol"""
96
+ return self.index(self.eos_word)
97
+
98
+ def unk(self):
99
+ """Helper to get index of unk symbol"""
100
+ return self.index(self.unk_word)
101
+
102
+ @classmethod
103
+ def load(cls, f):
104
+ """Loads the dictionary from a text file with the format:
105
+
106
+ ```
107
+ <symbol0> <count0>
108
+ <symbol1> <count1>
109
+ ...
110
+ ```
111
+ """
112
+ d = cls()
113
+ d.add_from_file(f)
114
+ return d
115
+
116
+ def add_from_file(self, f):
117
+ """
118
+ Loads a pre-existing dictionary from a text file and adds its symbols
119
+ to this instance.
120
+ """
121
+ if isinstance(f, str):
122
+ try:
123
+ with open(f, "r", encoding="utf-8") as fd:
124
+ self.add_from_file(fd)
125
+ except FileNotFoundError as fnfe:
126
+ raise fnfe
127
+ except UnicodeError:
128
+ raise Exception(
129
+ "Incorrect encoding detected in {}, please "
130
+ "rebuild the dataset".format(f)
131
+ )
132
+ return
133
+
134
+ lines = f.readlines()
135
+
136
+ for line_idx, line in enumerate(lines):
137
+ try:
138
+ splits = line.rstrip().rsplit(" ", 1)
139
+ line = splits[0]
140
+ field = splits[1] if len(splits) > 1 else str(len(lines) - line_idx)
141
+ if field == "#overwrite":
142
+ overwrite = True
143
+ line, field = line.rsplit(" ", 1)
144
+ else:
145
+ overwrite = False
146
+ count = int(field)
147
+ word = line
148
+ if word in self and not overwrite:
149
+ logger.info(
150
+ "Duplicate word found when loading Dictionary: '{}', index is {}.".format(word, self.indices[word])
151
+ )
152
+ else:
153
+ self.add_symbol(word, n=count, overwrite=overwrite)
154
+ except ValueError:
155
+ raise ValueError(
156
+ "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
157
+ )
unimol/data/distance_dataset.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) DP Technology.
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ import numpy as np
6
+ import torch
7
+ from scipy.spatial import distance_matrix
8
+ from functools import lru_cache
9
+ from unicore.data import BaseWrapperDataset
10
+
11
+
12
class DistanceDataset(BaseWrapperDataset):
    """Yields the pairwise Euclidean distance matrix of each item's
    (N, 3) coordinates as a float32 tensor."""

    def __init__(self, dataset):
        super().__init__(dataset)
        self.dataset = dataset

    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        coords = self.dataset[idx].view(-1, 3).numpy()
        pairwise = distance_matrix(coords, coords).astype(np.float32)
        return torch.from_numpy(pairwise)
22
+
23
+
24
class EdgeTypeDataset(BaseWrapperDataset):
    """Encodes each (token_i, token_j) pair as a single integer edge type:
    ``row_token * num_types + col_token``."""

    def __init__(self, dataset: torch.utils.data.Dataset, num_types: int):
        self.dataset = dataset
        self.num_types = num_types

    @lru_cache(maxsize=16)
    def __getitem__(self, index: int):
        tokens = self.dataset[index].clone()
        # Broadcasting produces the full (N, N) pair-id matrix.
        return tokens.view(-1, 1) * self.num_types + tokens.view(1, -1)
34
+
35
+
36
class CrossDistanceDataset(BaseWrapperDataset):
    """Distance matrix between molecule atoms (rows) and pocket atoms
    (columns), as a float32 tensor."""

    def __init__(self, mol_dataset, pocket_dataset):
        super().__init__(mol_dataset)
        self.dataset = mol_dataset
        self.mol_dataset = mol_dataset
        self.pocket_dataset = pocket_dataset

    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        lig_pos = self.mol_dataset[idx].view(-1, 3).numpy()
        poc_pos = self.pocket_dataset[idx].view(-1, 3).numpy()
        cross = distance_matrix(lig_pos, poc_pos).astype(np.float32)
        # Sanity-check the (mol, pocket) orientation of the result.
        assert cross.shape[0] == self.mol_dataset[idx].shape[0]
        assert cross.shape[1] == self.pocket_dataset[idx].shape[0]
        return torch.from_numpy(cross)
51
+
52
class CrossEdgeTypeDataset(BaseWrapperDataset):
    """Pair-type ids between molecule tokens (rows) and pocket tokens
    (columns): ``mol_token * num_types + pocket_token``."""

    def __init__(self, mol_dataset, pocket_dataset, num_types: int):
        self.dataset = mol_dataset
        self.mol_dataset = mol_dataset
        self.pocket_dataset = pocket_dataset
        self.num_types = num_types

    @lru_cache(maxsize=16)
    def __getitem__(self, index: int):
        mol_tokens = self.mol_dataset[index].clone()
        poc_tokens = self.pocket_dataset[index].clone()
        # Broadcasting yields the (N_mol, N_pocket) pair-id matrix.
        return mol_tokens.view(-1, 1) * self.num_types + poc_tokens.view(1, -1)
unimol/data/from_str_dataset.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from functools import lru_cache
3
+ from unicore.data import UnicoreDataset
4
+
5
+
6
class FromStrLabelDataset(UnicoreDataset):
    """Holds scalar labels (possibly as strings) and collates a batch of
    them into a float tensor."""

    def __init__(self, labels):
        super().__init__()
        self.labels = labels

    @lru_cache(maxsize=16)
    def __getitem__(self, index):
        return self.labels[index]

    def __len__(self):
        return len(self.labels)

    def collater(self, samples):
        # Coerce each label (e.g. "1.5") to float before stacking.
        return torch.tensor([float(s) for s in samples])
unimol/data/key_dataset.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) DP Technology.
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ from functools import lru_cache
6
+ from unicore.data import BaseWrapperDataset
7
+
8
+
9
class KeyDataset(BaseWrapperDataset):
    """Projects each sample dict of the wrapped dataset onto a single key."""

    def __init__(self, dataset, key):
        self.dataset = dataset
        self.key = key  # dict key extracted from every item

    def __len__(self):
        return len(self.dataset)

    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        return self.dataset[idx][self.key]
20
+
21
class LengthDataset(BaseWrapperDataset):
    """Yields ``len(item)`` for each item of the wrapped dataset."""

    def __init__(self, dataset):
        super().__init__(dataset)

    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        return len(self.dataset[idx])
unimol/data/lmdb_dataset.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) DP Technology.
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+
6
+ import lmdb
7
+ import os
8
+ import pickle
9
+ from functools import lru_cache
10
+ import logging
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class LMDBDataset:
    """Read-only random-access view over a single-file LMDB database whose
    values are pickled samples keyed by ASCII-encoded integer indices."""

    def __init__(self, db_path):
        self.db_path = db_path
        assert os.path.isfile(self.db_path), "{} not found".format(self.db_path)
        # Open once up front only to enumerate the keys.
        env = self.connect_db(self.db_path)
        with env.begin() as txn:
            self._keys = list(txn.cursor().iternext(values=False))

    def connect_db(self, lmdb_path, save_to_self=False):
        """Open the database read-only; either return the environment or
        stash it on ``self.env``."""
        env = lmdb.open(
            lmdb_path,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=256,
        )
        if save_to_self:
            self.env = env
        else:
            return env

    def __len__(self):
        return len(self._keys)

    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        # Open lazily on first access so forked workers each get a handle.
        if not hasattr(self, "env"):
            self.connect_db(self.db_path, save_to_self=True)
        raw = self.env.begin().get(f"{idx}".encode("ascii"))
        return pickle.loads(raw)
unimol/data/mask_points_dataset.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) DP Technology.
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ from functools import lru_cache
6
+
7
+ import numpy as np
8
+ import torch
9
+ from unicore.data import Dictionary
10
+ from unicore.data import BaseWrapperDataset
11
+ from . import data_utils
12
+
13
+
14
class MaskPointsDataset(BaseWrapperDataset):
    """BERT-style masking for 3D point clouds.

    Masks a random fraction of atom tokens and perturbs the coordinates of
    the masked atoms with ``noise_type`` noise. Each item is a dict with the
    corrupted ``atoms``/``coordinates`` plus ``targets`` holding the original
    token at masked positions and ``pad_idx`` elsewhere.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        coord_dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        noise_type: str,
        noise: float = 1.0,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
    ):
        assert 0.0 < mask_prob < 1.0
        assert 0.0 <= random_token_prob <= 1.0
        assert 0.0 <= leave_unmasked_prob <= 1.0
        assert random_token_prob + leave_unmasked_prob <= 1.0

        self.dataset = dataset
        self.coord_dataset = coord_dataset
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.noise_type = noise_type
        self.noise = noise
        self.seed = seed
        self.mask_prob = mask_prob
        self.leave_unmasked_prob = leave_unmasked_prob
        self.random_token_prob = random_token_prob

        if random_token_prob > 0.0:
            # Uniform replacement distribution over non-special tokens.
            weights = np.ones(len(self.vocab))
            weights[vocab.special_index()] = 0
            self.weights = weights / weights.sum()

        self.epoch = None
        if self.noise_type == "trunc_normal":
            # Gaussian noise clipped to +/- 2*noise.
            self.noise_f = lambda num_mask: np.clip(
                np.random.randn(num_mask, 3) * self.noise,
                a_min=-self.noise * 2.0,
                a_max=self.noise * 2.0,
            )
        elif self.noise_type == "normal":
            self.noise_f = lambda num_mask: np.random.randn(num_mask, 3) * self.noise
        elif self.noise_type == "uniform":
            self.noise_f = lambda num_mask: np.random.uniform(
                low=-self.noise, high=self.noise, size=(num_mask, 3)
            )
        else:
            # Unknown noise type: coordinates are left unperturbed.
            self.noise_f = lambda num_mask: 0.0

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.coord_dataset.set_epoch(epoch)
        self.dataset.set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index: int):
        return self.__getitem_cached__(self.epoch, index)

    @lru_cache(maxsize=16)
    def __getitem_cached__(self, epoch: int, index: int):
        ret = {}
        # Corruption is deterministic given (seed, epoch, index).
        with data_utils.numpy_seed(self.seed, epoch, index):
            item = self.dataset[index]
            coord = self.coord_dataset[index]
            sz = len(item)
            # don't allow empty sequence
            assert sz > 0
            # decide elements to mask
            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_prob * sz
                + np.random.rand()
            )
            mask_idc = np.random.choice(sz, num_mask, replace=False)
            mask = np.full(sz, False)
            mask[mask_idc] = True
            # Targets: original token at masked positions, pad_idx elsewhere.
            ret["targets"] = np.full(len(mask), self.pad_idx)
            ret["targets"][mask] = item[mask]
            ret["targets"] = torch.from_numpy(ret["targets"]).long()
            # decide unmasking and random replacement
            rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
            if rand_or_unmask_prob > 0.0:
                rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
                if self.random_token_prob == 0.0:
                    unmask = rand_or_unmask
                    rand_mask = None
                elif self.leave_unmasked_prob == 0.0:
                    unmask = None
                    rand_mask = rand_or_unmask
                else:
                    # Split the combined set between "leave as-is" and
                    # "replace with random token".
                    unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                    decision = np.random.rand(sz) < unmask_prob
                    unmask = rand_or_unmask & decision
                    rand_mask = rand_or_unmask & (~decision)
            else:
                unmask = rand_mask = None

            if unmask is not None:
                # Remove the "leave unmasked" positions from the mask.
                mask = mask ^ unmask

            new_item = np.copy(item)
            new_item[mask] = self.mask_idx

            num_mask = mask.astype(np.int32).sum()
            new_coord = np.copy(coord)
            # Only (still-)masked atoms get their coordinates perturbed.
            new_coord[mask, :] += self.noise_f(num_mask)

            if rand_mask is not None:
                num_rand = rand_mask.sum()
                if num_rand > 0:
                    # Replace with random non-special vocabulary tokens.
                    new_item[rand_mask] = np.random.choice(
                        len(self.vocab),
                        num_rand,
                        p=self.weights,
                    )
            ret["atoms"] = torch.from_numpy(new_item).long()
            ret["coordinates"] = torch.from_numpy(new_coord).float()
            return ret
136
+
137
+
138
class MaskPointsPocketDataset(BaseWrapperDataset):
    """Residue-level masking for protein pockets.

    Like :class:`MaskPointsDataset` but whole residues are masked together:
    a residue is sampled, then every atom belonging to it is masked and
    noised jointly.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        coord_dataset: torch.utils.data.Dataset,
        residue_dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        noise_type: str,
        noise: float = 1.0,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
    ):
        assert 0.0 < mask_prob < 1.0
        assert 0.0 <= random_token_prob <= 1.0
        assert 0.0 <= leave_unmasked_prob <= 1.0
        assert random_token_prob + leave_unmasked_prob <= 1.0

        self.dataset = dataset
        self.coord_dataset = coord_dataset
        self.residue_dataset = residue_dataset
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.noise_type = noise_type
        self.noise = noise
        self.seed = seed
        self.mask_prob = mask_prob
        self.leave_unmasked_prob = leave_unmasked_prob
        self.random_token_prob = random_token_prob

        if random_token_prob > 0.0:
            # Uniform replacement distribution over non-special tokens.
            weights = np.ones(len(self.vocab))
            weights[vocab.special_index()] = 0
            self.weights = weights / weights.sum()

        self.epoch = None
        if self.noise_type == "trunc_normal":
            # Gaussian noise clipped to +/- 2*noise.
            self.noise_f = lambda num_mask: np.clip(
                np.random.randn(num_mask, 3) * self.noise,
                a_min=-self.noise * 2.0,
                a_max=self.noise * 2.0,
            )
        elif self.noise_type == "normal":
            self.noise_f = lambda num_mask: np.random.randn(num_mask, 3) * self.noise
        elif self.noise_type == "uniform":
            self.noise_f = lambda num_mask: np.random.uniform(
                low=-self.noise, high=self.noise, size=(num_mask, 3)
            )
        else:
            # Unknown noise type: coordinates are left unperturbed.
            self.noise_f = lambda num_mask: 0.0

    def set_epoch(self, epoch, **unused):
        # NOTE(review): residue_dataset.set_epoch is not propagated here,
        # unlike dataset/coord_dataset -- confirm this is intentional.
        super().set_epoch(epoch)
        self.coord_dataset.set_epoch(epoch)
        self.dataset.set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index: int):
        return self.__getitem_cached__(self.epoch, index)

    @lru_cache(maxsize=16)
    def __getitem_cached__(self, epoch: int, index: int):
        ret = {}
        # Corruption is deterministic given (seed, epoch, index).
        with data_utils.numpy_seed(self.seed, epoch, index):
            item = self.dataset[index]
            coord = self.coord_dataset[index]
            sz = len(item)
            # don't allow empty sequence
            assert sz > 0

            # mask on the level of residues
            residue = self.residue_dataset[index]
            # NOTE(review): set iteration order depends on hashing; for string
            # residue ids this varies with PYTHONHASHSEED, so the residues fed
            # to np.random.choice below may differ across runs even under
            # numpy_seed -- confirm hash randomization is pinned.
            res_list = list(set(residue))
            res_sz = len(res_list)

            # decide elements to mask
            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_prob * res_sz
                + np.random.rand()
            )
            mask_res = np.random.choice(res_list, num_mask, replace=False).tolist()
            # Atom-level mask: True for every atom of a sampled residue.
            mask = np.isin(residue, mask_res)

            # Targets: original token at masked positions, pad_idx elsewhere.
            ret["targets"] = np.full(len(mask), self.pad_idx)
            ret["targets"][mask] = item[mask]
            ret["targets"] = torch.from_numpy(ret["targets"]).long()
            # decide unmasking and random replacement
            rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
            if rand_or_unmask_prob > 0.0:
                rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
                if self.random_token_prob == 0.0:
                    unmask = rand_or_unmask
                    rand_mask = None
                elif self.leave_unmasked_prob == 0.0:
                    unmask = None
                    rand_mask = rand_or_unmask
                else:
                    # Split the combined set between "leave as-is" and
                    # "replace with random token".
                    unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                    decision = np.random.rand(sz) < unmask_prob
                    unmask = rand_or_unmask & decision
                    rand_mask = rand_or_unmask & (~decision)
            else:
                unmask = rand_mask = None

            if unmask is not None:
                # Remove the "leave unmasked" positions from the mask.
                mask = mask ^ unmask

            new_item = np.copy(item)
            new_item[mask] = self.mask_idx

            num_mask = mask.astype(np.int32).sum()
            new_coord = np.copy(coord)
            # Only (still-)masked atoms get their coordinates perturbed.
            new_coord[mask, :] += self.noise_f(num_mask)

            if rand_mask is not None:
                num_rand = rand_mask.sum()
                if num_rand > 0:
                    # Replace with random non-special vocabulary tokens.
                    new_item[rand_mask] = np.random.choice(
                        len(self.vocab),
                        num_rand,
                        p=self.weights,
                    )
            ret["atoms"] = torch.from_numpy(new_item).long()
            ret["coordinates"] = torch.from_numpy(new_coord).float()
            return ret
unimol/data/normalize_dataset.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) DP Technology.
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ import numpy as np
6
+ from functools import lru_cache
7
+ from unicore.data import BaseWrapperDataset
8
+
9
+
10
class NormalizeDataset(BaseWrapperDataset):
    """Translates each item's coordinates to be zero-mean (centered)."""

    def __init__(self, dataset, coordinates, normalize_coord=True):
        self.dataset = dataset
        self.coordinates = coordinates  # key of the (N, 3) coordinate array
        self.normalize_coord = normalize_coord  # skip centering when False
        self.set_epoch(None)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=16)
    def __cached_item__(self, index: int, epoch: int):
        sample = self.dataset[index].copy()
        coords = sample[self.coordinates]
        if self.normalize_coord:
            # Subtract the centroid so the point cloud is centered at origin.
            coords = coords - coords.mean(axis=0)
            sample[self.coordinates] = coords.astype(np.float32)
        return sample

    def __getitem__(self, index: int):
        return self.__cached_item__(index, self.epoch)
+ return self.__cached_item__(index, self.epoch)
33
+
34
+
35
+ class NormalizeDockingPoseDataset(BaseWrapperDataset):
36
+ def __init__(
37
+ self,
38
+ dataset,
39
+ coordinates,
40
+ pocket_coordinates,
41
+ center_coordinates="center_coordinates",
42
+ ):
43
+ self.dataset = dataset
44
+ self.coordinates = coordinates
45
+ self.pocket_coordinates = pocket_coordinates
46
+ self.center_coordinates = center_coordinates
47
+ self.set_epoch(None)
48
+
49
+ def set_epoch(self, epoch, **unused):
50
+ super().set_epoch(epoch)
51
+ self.epoch = epoch
52
+
53
+ @lru_cache(maxsize=16)
54
+ def __cached_item__(self, index: int, epoch: int):
55
+ dd = self.dataset[index].copy()
56
+ coordinates = dd[self.coordinates]
57
+ pocket_coordinates = dd[self.pocket_coordinates]
58
+ # normalize coordinates and pocket coordinates ,align with pocket center coordinates
59
+ center_coordinates = pocket_coordinates.mean(axis=0)
60
+ coordinates = coordinates - center_coordinates
61
+ pocket_coordinates = pocket_coordinates - center_coordinates
62
+ dd[self.coordinates] = coordinates.astype(np.float32)
63
+ dd[self.pocket_coordinates] = pocket_coordinates.astype(np.float32)
64
+ dd[self.center_coordinates] = center_coordinates.astype(np.float32)
65
+ return dd
66
+
67
+ def __getitem__(self, index: int):
68
+ return self.__cached_item__(index, self.epoch)
unimol/data/pair_dataset.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os.path
3
+
4
+ import math
5
+ from functools import lru_cache
6
+
7
+ import torch
8
+ from unicore.data import UnicoreDataset
9
+ import numpy as np
10
+ from . import data_utils
11
+ import rdkit
12
+ from rdkit import Chem
13
+ from rdkit import DataStructs
14
+ from rdkit.Chem import rdFingerprintGenerator
15
+ from multiprocessing import Pool
16
+ from tqdm import tqdm
17
+
18
def get_fp(smiles):
    """Compute a Morgan count fingerprint for a SMILES string.

    Returns a numpy int8 array filled in place by RDKit, or ``None`` when the
    SMILES cannot be parsed into a molecule.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    count_fp = rdFingerprintGenerator.GetCountFPs(
        [mol], fpType=rdFingerprintGenerator.MorganFP
    )[0]
    # Zero-length target buffer; RDKit's converter populates it in place
    # (presumably resizing to the fingerprint length — upstream RDKit idiom).
    out = np.zeros((0,), np.int8)
    DataStructs.ConvertToNumpyArray(count_fp, out)
    return out
28
+
29
class PairDataset(UnicoreDataset):
    """Pairs a pocket dataset with a ligand dataset via per-assay labels.

    Each entry of ``labels`` describes one assay and is expected to carry:
    ``"pockets"`` (candidate pocket names), ``"ligands"`` (list of dicts with
    ``"smi"`` and ``"act"``), ``"uniprot"``, ``"sequence"`` and optionally
    ``"assay_id"``.  ``__getitem__`` deterministically samples one pocket and
    up to ``max_lignum`` ligands for the assay (seeded by index and epoch).
    """

    def __init__(self, args, pocket_dataset, mol_dataset, labels, split,
                 use_cache=True, cache_dir=None):
        self.args = args
        self.pocket_dataset = pocket_dataset
        self.mol_dataset = mol_dataset
        self.labels = labels

        # Building these lookup tables requires a full pass over the wrapped
        # datasets, which is very slow for large corpora — cache them on disk
        # when allowed.
        pocket_cache = (
            f"{cache_dir}/cache/pocket_name2idx_train_blend.json" if use_cache else None
        )
        self.pocket_name2idx = self._load_or_build_index(
            pocket_cache,
            lambda: {x["pocket_name"]: i for i, x in enumerate(self.pocket_dataset)},
        )
        mol_cache = (
            f"{cache_dir}/cache/mol_smi2idx_train_blend.json" if use_cache else None
        )
        self.mol_smi2idx = self._load_or_build_index(
            mol_cache,
            lambda: {x["smi_name"]: i for i, x in enumerate(self.mol_dataset)},
        )

        # Integer id per distinct UniProt accession (order is set-dependent,
        # i.e. not stable across runs — only used as an in-run grouping key).
        uniprot_ids = [x["uniprot"] for x in labels]
        self.uniprot_id_dict = {x: i for i, x in enumerate(set(uniprot_ids))}
        self.split = split
        if self.split == "train":
            self.max_lignum = args.max_lignum  # default 16
        else:
            self.max_lignum = args.test_max_lignum  # default 512

        if self.split == "train":
            # Assays with more ligands are drawn proportionally more often:
            # one epoch index per chunk of up to max(max_lignum, 32) ligands.
            trainidxmap = []
            for idx, assay_item in enumerate(self.labels):
                lig_info = assay_item["ligands"]
                trainidxmap += [idx] * math.ceil(len(lig_info) / max(self.max_lignum, 32))
            self.trainidxmap = trainidxmap

        self.epoch = 0

    @staticmethod
    def _load_or_build_index(cache_file, build):
        """Load a JSON lookup table from ``cache_file`` if present; otherwise
        build it with ``build()`` and persist it (when caching is enabled).

        ``cache_file`` is ``None`` when caching is disabled.
        """
        if cache_file is not None and os.path.exists(cache_file):
            with open(cache_file) as f:
                return json.load(f)
        index = build()
        if cache_file is not None:
            # Ensure the cache directory exists before writing.
            os.makedirs(os.path.dirname(cache_file), exist_ok=True)
            with open(cache_file, "w") as f:
                json.dump(index, f)
        return index

    def __len__(self):
        if self.split == "train":
            # Truncate to a multiple of the global batch size so every rank
            # sees the same number of full batches.  Fall back to a single
            # process when WORLD_SIZE is unset (non-distributed runs) instead
            # of raising KeyError.
            world_size = int(os.environ.get("WORLD_SIZE", 1))
            div = self.args.batch_size * world_size
            return (len(self.trainidxmap) // div) * div
        else:
            return len(self.labels)

    def set_epoch(self, epoch):
        """Propagate the epoch to the wrapped datasets (it seeds sampling)."""
        self.epoch = epoch
        self.pocket_dataset.set_epoch(epoch)
        self.mol_dataset.set_epoch(epoch)
        super().set_epoch(epoch)

    def collater(self, samples):
        """Collate a list of ``__getitem__`` tuples into one model batch.

        ``batch_list`` holds [start, end) ligand slices per assay so the loss
        can group the flat ligand batch back by assay.
        """
        ret_pocket = []
        ret_lig = []
        batch_list = []
        act_list = []
        uniprot_list = []
        ret_protein = []
        assay_id_list = []

        if len(samples) == 0:
            return {}
        for pocket, ligs, acts, uniprot, assay_id, prot_seq in samples:
            ret_pocket.append(pocket)
            lignum_old = len(ret_lig)
            ret_lig += ligs
            batch_list.append([lignum_old, len(ret_lig)])
            uniprot_list.append(self.uniprot_id_dict[uniprot])
            assay_id_list.append(assay_id)
            act_list.append(acts)
            ret_protein.append(prot_seq)

        ret_pocket = self.pocket_dataset.collater(ret_pocket)
        ret_lig = self.mol_dataset.collater(ret_lig)
        return {"pocket": ret_pocket, "lig": ret_lig, "protein": ret_protein,
                "batch_list": batch_list, "act_list": act_list,
                "uniprot_list": uniprot_list, "assay_id_list": assay_id_list}

    def __getitem__(self, idx):
        """Return (pocket_data, lig_data, lig_act, uniprot, assay_id, prot_seq)
        for one assay, with deterministic per-(idx, epoch) sampling."""
        if self.split == "train":
            t_idx = self.trainidxmap[idx]
        else:
            t_idx = idx

        # Seeded choice => the same (idx, epoch) always picks the same pocket.
        with data_utils.numpy_seed(1111, idx, self.epoch):
            pocket_name = np.random.choice(self.labels[t_idx]["pockets"], 1, replace=False)[0]

        lig_info = self.labels[t_idx]["ligands"]
        # Drop ligands whose SMILES are absent from the molecule dataset.
        lig_info = [x for x in lig_info if x["smi"] in self.mol_smi2idx]
        uniprot = self.labels[t_idx]["uniprot"]
        assay_id = self.labels[t_idx].get("assay_id", "none")
        prot_seq = self.labels[t_idx]["sequence"]
        if len(lig_info) > self.max_lignum:
            # Subsample to max_lignum ligands, deterministically per (idx, epoch).
            with data_utils.numpy_seed(1111, idx, self.epoch):
                lig_idxes = np.random.choice(list(range(len(lig_info))), self.max_lignum, replace=False)
            lig_idxes = sorted(lig_idxes)
            lig_info = [lig_info[i] for i in lig_idxes]

        lig_idxes = [self.mol_smi2idx[info["smi"]] for info in lig_info]
        pocket_idx = self.pocket_name2idx[pocket_name]
        lig_act = [info["act"] for info in lig_info]
        pocket_data = self.pocket_dataset[pocket_idx]
        lig_data = [self.mol_dataset[x] for x in lig_idxes]

        return pocket_data, lig_data, lig_act, uniprot, assay_id, prot_seq