AlienChen committed on
Commit
a06a951
·
verified ·
1 Parent(s): c4b1fea

Upload 8 files

Browse files
custom_datasets/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from . import discretized_cifar10
2
+ from . import ten_species_dataset
custom_datasets/discretized_cifar10.py ADDED
File without changes
custom_datasets/ten_species_dataset.py ADDED
File without changes
notebooks/eval_hyenadna_classifier.ipynb ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "id": "5b178466-559f-47ed-bcd1-a171641d47b5",
6
+ "metadata": {},
7
+ "source": [
8
+ "import os\n",
9
+ "\n",
10
+ "import hydra\n",
11
+ "import numpy as np\n",
12
+ "import omegaconf\n",
13
+ "import torch\n",
14
+ "import transformers\n",
15
+ "from sklearn.metrics import f1_score, matthews_corrcoef, precision_score, recall_score\n",
16
+ "from tqdm.auto import tqdm\n",
17
+ "\n",
18
+ "import classifier\n",
19
+ "import dataloader"
20
+ ],
21
+ "outputs": [],
22
+ "execution_count": null
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "id": "08301e02-d279-426f-8aad-c23eea8fb120",
27
+ "metadata": {},
28
+ "source": [
29
+ "omegaconf.OmegaConf.register_new_resolver(\n",
30
+ " 'cwd', os.getcwd)\n",
31
+ "omegaconf.OmegaConf.register_new_resolver(\n",
32
+ " 'device_count', torch.cuda.device_count)\n",
33
+ "omegaconf.OmegaConf.register_new_resolver(\n",
34
+ " 'eval', eval)\n",
35
+ "omegaconf.OmegaConf.register_new_resolver(\n",
36
+ " 'div_up', lambda x, y: (x + y - 1) // y)\n",
37
+ "omegaconf.OmegaConf.register_new_resolver(\n",
38
+ " 'if_then_else',\n",
39
+ " lambda condition, x, y: x if condition else y\n",
40
+ ")"
41
+ ],
42
+ "outputs": [],
43
+ "execution_count": null
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "id": "4685c167-63c8-4912-81e0-4ecd635fcc24",
48
+ "metadata": {},
49
+ "source": [
50
+ "# Load classifier\n",
51
+ "with hydra.initialize(version_base=None, config_path='../configs/'):\n",
52
+ " classifier_config = hydra.compose(\n",
53
+ " config_name='config',\n",
54
+ " overrides=[\n",
55
+ " 'hydra.output_subdir=null',\n",
56
+ " f\"hydra.run.dir={os.path.dirname(os.getcwd())}/outputs/ten_species/eval_classifier/hyenadna-small-32k_from-scratch_nlayer-8\",\n",
57
+ " 'hydra/job_logging=disabled',\n",
58
+ " 'hydra/hydra_logging=disabled',\n",
59
+ " '+is_eval_classifier=True',\n",
60
+ " 'mode=train_classifier',\n",
61
+ " 'loader.global_batch_size=32',\n",
62
+ " 'loader.eval_global_batch_size=64',\n",
63
+ " 'loader.batch_size=1',\n",
64
+ " 'loader.eval_batch_size=1',\n",
65
+ " 'data=ten_species',\n",
66
+ " 'data.label_col=species_label',\n",
67
+ " 'data.num_classes=10',\n",
68
+ " 'classifier_model=hyenadna-classifier',\n",
69
+ " 'classifier_model.hyena_model_name_or_path=LongSafari/hyenadna-small-32k-seqlen-hf',\n",
70
+ " 'classifier_model.n_layer=8',\n",
71
+ " 'classifier_backbone=hyenadna',\n",
72
+ " 'model.length=32768',\n",
73
+ " 'diffusion=null',\n",
74
+ " 'T=null',\n",
75
+ " f\"eval.checkpoint_path={os.path.dirname(os.getcwd())}/outputs/ten_species/eval_classifier/hyenadna-small-32k_from-scratch_nlayer-8/checkpoints/best.ckpt\",\n",
76
+ " ]\n",
77
+ " )\n",
78
+ "classifier_config = omegaconf.OmegaConf.create(classifier_config)\n",
79
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(classifier_config.data.tokenizer_name_or_path, trust_remote_code=True)\n",
80
+ "pretrained_classifier = classifier.Classifier.load_from_checkpoint(\n",
81
+ " classifier_config.eval.checkpoint_path,\n",
82
+ " tokenizer=tokenizer,\n",
83
+ " config=classifier_config, logger=False)\n",
84
+ "pretrained_classifier.eval();"
85
+ ],
86
+ "outputs": [],
87
+ "execution_count": null
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "id": "bf18720b-64a9-4e9e-9e1e-2aa1c12dc6f0",
92
+ "metadata": {},
93
+ "source": [
94
+ "tokenizer = dataloader.get_tokenizer(classifier_config)\n",
95
+ "_, val_dl = dataloader.get_dataloaders(\n",
96
+ " classifier_config, tokenizer, skip_train=True, valid_seed=classifier_config.seed)"
97
+ ],
98
+ "outputs": [],
99
+ "execution_count": null
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "id": "bdcd3ba7-e26a-4e36-a5fb-ff1fb747cc3c",
104
+ "metadata": {},
105
+ "source": [
106
+ "labels = []\n",
107
+ "preds = []\n",
108
+ "for batch in tqdm(val_dl):\n",
109
+ " preds.append(\n",
110
+ " pretrained_classifier(batch['input_ids'].to(pretrained_classifier.device)).argmax(dim=-1).detach().to(\n",
111
+ " 'cpu', non_blocking=True).numpy()\n",
112
+ " )\n",
113
+ " labels.append(batch['species_label'].numpy())"
114
+ ],
115
+ "outputs": [],
116
+ "execution_count": null
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "id": "110ed75e-613c-4b6a-bb79-15517988735c",
121
+ "metadata": {},
122
+ "source": [
123
+ "labels = np.concatenate(labels)\n",
124
+ "preds = np.concatenate(preds)"
125
+ ],
126
+ "outputs": [],
127
+ "execution_count": null
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "id": "1558ca2e-6454-4c8c-b141-fca77f0025c5",
132
+ "metadata": {},
133
+ "source": [
134
+ "overall_accuracy_score = (preds == labels).sum() / preds.size\n",
135
+ "overall_f1_score = f1_score(y_pred=preds, y_true=labels, average=\"macro\", labels=list(range(classifier_config.data.num_classes)))\n",
136
+ "overall_mcc_score = matthews_corrcoef(y_pred=preds, y_true=labels)\n",
137
+ "\n",
138
+ "print(f\"Overall Acc: {overall_accuracy_score:0.3f}\")\n",
139
+ "print(f\"Overall F1: {overall_f1_score:0.3f}\")\n",
140
+ "print(f\"Overall MCC: {overall_mcc_score:0.3f}\")"
141
+ ],
142
+ "outputs": [],
143
+ "execution_count": null
144
+ },
145
+ {
146
+ "cell_type": "code",
147
+ "id": "df8ce828-f6e1-4167-bae2-db4f13900758",
148
+ "metadata": {},
149
+ "source": [
150
+ "f1_scores = f1_score(y_pred=preds, y_true=labels, average=None , labels=list(range(classifier_config.data.num_classes)))\n",
151
+ "precision_scores = precision_score(y_pred=preds, y_true=labels, average=None , labels=list(range(classifier_config.data.num_classes)))\n",
152
+ "recall_scores = recall_score(y_pred=preds, y_true=labels, average=None , labels=list(range(classifier_config.data.num_classes)))\n",
153
+ "\n",
154
+ "species_list = ['Homo_sapiens', 'Mus_musculus', 'Drosophila_melanogaster', 'Danio_rerio',\n",
155
+ " 'Caenorhabditis_elegans', 'Gallus_gallus', 'Gorilla_gorilla', 'Felis_catus',\n",
156
+ " 'Salmo_trutta', 'Arabidopsis_thaliana']\n",
157
+ "for s in range(classifier_config.data.num_classes):\n",
158
+ " print(f\"Class {s} - {species_list[s]}:\")\n",
159
+ " print(f\" F1: {f1_scores[s]:0.3f}\")\n",
160
+ " print(f\" Precision: {precision_scores[s]:0.3f}\")\n",
161
+ " print(f\" Recall: {recall_scores[s]:0.3f}\")"
162
+ ],
163
+ "outputs": [],
164
+ "execution_count": null
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "id": "d18ca7cc-4fe6-4ba9-9175-1eac9ebca7b1",
169
+ "metadata": {},
170
+ "source": [],
171
+ "outputs": [],
172
+ "execution_count": null
173
+ }
174
+ ],
175
+ "metadata": {
176
+ "kernelspec": {
177
+ "display_name": "Python 3 (ipykernel)",
178
+ "language": "python",
179
+ "name": "python3"
180
+ },
181
+ "language_info": {
182
+ "codemirror_mode": {
183
+ "name": "ipython",
184
+ "version": 3
185
+ },
186
+ "file_extension": ".py",
187
+ "mimetype": "text/x-python",
188
+ "name": "python",
189
+ "nbconvert_exporter": "python",
190
+ "pygments_lexer": "ipython3",
191
+ "version": "3.9.18"
192
+ }
193
+ },
194
+ "nbformat": 4,
195
+ "nbformat_minor": 5
196
+ }
notebooks/qm9_data_prep.ipynb ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "5fa7908c-b785-49ce-9e5d-7c6ad6b4378b",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Imports and setup"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 3,
14
+ "id": "d0c96204-ea08-4330-b1bb-784b259ec32e",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import os\n",
19
+ "import huggingface_hub"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 4,
25
+ "id": "6813e76b",
26
+ "metadata": {},
27
+ "outputs": [
28
+ {
29
+ "name": "stdout",
30
+ "output_type": "stream",
31
+ "text": [
32
+ "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
33
+ "Token is valid (permission: write).\n",
34
+ "Your token has been saved to /share/kuleshov/yzs2/discrete-guidance/.hf_cache/token\n",
35
+ "Login successful\n"
36
+ ]
37
+ }
38
+ ],
39
+ "source": [
40
+ "if os.path.exists(os.path.join(os.environ['HF_HOME'], 'token')):\n",
41
+ " with open(os.path.join(os.environ['HF_HOME'], 'token'), 'r') as f:\n",
42
+ " token = f.read().strip()\n",
43
+ "else:\n",
44
+ " token = None\n",
45
+ "huggingface_hub.login(token=token)"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": 5,
51
+ "id": "61cb2ac4",
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "import json\n",
56
+ "import typing\n",
57
+ "\n",
58
+ "import datasets\n",
59
+ "import numpy as np\n",
60
+ "import pandas as pd\n",
61
+ "import rdkit\n",
62
+ "import transformers\n",
63
+ "from rdkit import Chem as rdChem\n",
64
+ "from rdkit.Chem import Crippen, QED\n",
65
+ "from rdkit.Contrib.NP_Score import npscorer\n",
66
+ "from rdkit.Contrib.SA_Score import sascorer\n",
67
+ "from tqdm.auto import tqdm"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": 6,
73
+ "id": "24444c85",
74
+ "metadata": {},
75
+ "outputs": [],
76
+ "source": [
77
+ "# TODO: Update to 2024.03.6 release when available instead of suppressing warning!\n",
78
+ "# See: https://github.com/rdkit/rdkit/issues/7625#\n",
79
+ "rdkit.rdBase.DisableLog('rdApp.warning')"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "markdown",
84
+ "id": "902de4c5-dda5-4e4c-a4dd-f3b88015464e",
85
+ "metadata": {},
86
+ "source": [
87
+ "## Create dataset"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "id": "7b7a8986",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "def parse_float(\n",
98
+ " s: str\n",
99
+ ") -> float:\n",
100
+ "    \"\"\"Parses a float that may be written in `base*^power` exponent notation.\n",
101
+ " \n",
102
+ " Copied from https://www.kaggle.com/code/tawe141/extracting-data-from-qm9-xyz-files/code\n",
103
+ " \"\"\"\n",
104
+ " try:\n",
105
+ " return float(s)\n",
106
+ " except ValueError:\n",
107
+ " base, power = s.split('*^')\n",
108
+ " return float(base) * 10**float(power)\n",
109
+ "\n",
110
+ "\n",
111
+ "def count_rings_and_bonds(\n",
112
+ " mol: rdChem.Mol, max_ring_size: int = -1\n",
113
+ ") -> typing.Dict[str, int]:\n",
114
+ "    \"\"\"Counts bonds and rings (by type).\"\"\"\n",
115
+ " \n",
116
+ " # Counting rings\n",
117
+ " ssr = rdChem.GetSymmSSSR(mol)\n",
118
+ " ring_count = len(ssr)\n",
119
+ " \n",
120
+ " ring_sizes = {} if max_ring_size < 0 else {i: 0 for i in range(3, max_ring_size+1)}\n",
121
+ " for ring in ssr:\n",
122
+ " ring_size = len(ring)\n",
123
+ " if ring_size not in ring_sizes:\n",
124
+ " ring_sizes[ring_size] = 0\n",
125
+ " ring_sizes[ring_size] += 1\n",
126
+ " \n",
127
+ " # Counting bond types\n",
128
+ " bond_counts = {\n",
129
+ " 'single': 0,\n",
130
+ " 'double': 0,\n",
131
+ " 'triple': 0,\n",
132
+ " 'aromatic': 0\n",
133
+ " }\n",
134
+ " \n",
135
+ " for bond in mol.GetBonds():\n",
136
+ " if bond.GetIsAromatic():\n",
137
+ " bond_counts['aromatic'] += 1\n",
138
+ " elif bond.GetBondType() == rdChem.BondType.SINGLE:\n",
139
+ " bond_counts['single'] += 1\n",
140
+ " elif bond.GetBondType() == rdChem.BondType.DOUBLE:\n",
141
+ " bond_counts['double'] += 1\n",
142
+ " elif bond.GetBondType() == rdChem.BondType.TRIPLE:\n",
143
+ " bond_counts['triple'] += 1\n",
144
+ " result = {\n",
145
+ " 'ring_count': ring_count,\n",
146
+ " }\n",
147
+ " for k, v in ring_sizes.items():\n",
148
+ " result[f\"R{k}\"] = v\n",
149
+ "\n",
150
+ " for k, v in bond_counts.items():\n",
151
+ " result[f\"{k}_bond\"] = v\n",
152
+ " return result\n",
153
+ "\n",
154
+ "\n",
155
+ "def parse_xyz(\n",
156
+ " filename: str,\n",
157
+ " max_ring_size: int = -1,\n",
158
+ " npscorer_model: typing.Optional[dict] = None,\n",
159
+ " array_format: str = 'np'\n",
160
+ ") -> typing.Dict[str, typing.Any]:\n",
161
+ "    \"\"\"Parses QM9-specific xyz files.\n",
162
+ " \n",
163
+ " See https://www.nature.com/articles/sdata201422/tables/2 for reference.\n",
164
+ " Adapted from https://www.kaggle.com/code/tawe141/extracting-data-from-qm9-xyz-files/code\n",
165
+ " \"\"\"\n",
166
+ " assert array_format in ['np', 'pt'], \\\n",
167
+ " f\"Invalid array_format: `{array_format}` provided. Must be one of `np` (numpy.array), `pt` (torch.tensor).\"\n",
168
+ " \n",
169
+ " num_atoms = 0\n",
170
+ " scalar_properties = []\n",
171
+ " atomic_symbols = []\n",
172
+ " xyz = []\n",
173
+ " charges = []\n",
174
+ " harmonic_vibrational_frequencies = []\n",
175
+ " smiles = ''\n",
176
+ " inchi = ''\n",
177
+ " with open(filename, 'r') as f:\n",
178
+ " for line_num, line in enumerate(f):\n",
179
+ " if line_num == 0:\n",
180
+ " num_atoms = int(line)\n",
181
+ " elif line_num == 1:\n",
182
+ " scalar_properties = [float(i) for i in line.split()[2:]]\n",
183
+ " elif 2 <= line_num <= 1 + num_atoms:\n",
184
+ " atom_symbol, x, y, z, charge = line.split()\n",
185
+ " atomic_symbols.append(atom_symbol)\n",
186
+ " xyz.append([parse_float(x), parse_float(y), parse_float(z)])\n",
187
+ " charges.append(parse_float(charge))\n",
188
+ " elif line_num == num_atoms + 2:\n",
189
+ " harmonic_vibrational_frequencies = [float(i) for i in line.split()]\n",
190
+ " elif line_num == num_atoms + 3:\n",
191
+ " smiles = line.split()[0]\n",
192
+ " elif line_num == num_atoms + 4:\n",
193
+ " inchi = line.split()[0]\n",
194
+ "\n",
195
+ " array_wrap = np.array if array_format == 'np' else torch.tensor\n",
196
+ " result = {\n",
197
+ " 'num_atoms': num_atoms,\n",
198
+ " 'atomic_symbols': atomic_symbols,\n",
199
+ " 'pos': array_wrap(xyz),\n",
200
+ " 'charges': array_wrap(charges),\n",
201
+ " 'harmonic_oscillator_frequencies': array_wrap(harmonic_vibrational_frequencies),\n",
202
+ " 'smiles': smiles,\n",
203
+ " 'inchi': inchi\n",
204
+ " }\n",
205
+ " scalar_property_labels = [\n",
206
+ " 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u', 'h', 'g', 'cv'\n",
207
+ " ] \n",
208
+ " scalar_properties = dict(zip(scalar_property_labels, scalar_properties))\n",
209
+ " result.update(scalar_properties)\n",
210
+ "\n",
211
+ "    # RDKit\n",
212
+ " result['canonical_smiles'] = rdChem.CanonSmiles(result['smiles'])\n",
213
+ " m = rdChem.MolFromSmiles(result['canonical_smiles'])\n",
214
+ " result['logP'] = Crippen.MolLogP(m)\n",
215
+ " result['qed'] = QED.qed(m)\n",
216
+ " if npscorer_model is not None:\n",
217
+ " result['np_score'] = npscorer.scoreMol(m, npscorer_model)\n",
218
+ " result['sa_score'] = sascorer.calculateScore(m)\n",
219
+ " result.update(count_rings_and_bonds(m, max_ring_size=max_ring_size))\n",
220
+ " \n",
221
+ " return result"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": null,
227
+ "id": "72254d85",
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": [
231
+ "\"\"\"\n",
232
+ " Download xyz files from:\n",
233
+ " https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904\n",
234
+ " > wget https://figshare.com/ndownloader/files/3195389/dsgdb9nsd.xyz.tar.bz2\n",
235
+ " > mkdir dsgdb9nsd.xyz\n",
236
+ " > tar -xvjf dsgdb9nsd.xyz.tar.bz2 -C dsgdb9nsd.xyz\n",
237
+ "\"\"\"\n",
238
+ "MAX_RING_SIZE = 9\n",
239
+ "fscore = npscorer.readNPModel()\n",
240
+ "xyz_dir_path = '/Users/yairschiff/Downloads/dsgdb9nsd.xyz'\n",
241
+ "parsed_xyz = []\n",
242
+ "for file in tqdm(sorted(os.listdir(xyz_dir_path)), desc='Parsing'):\n",
243
+ " parsed = parse_xyz(os.path.join(xyz_dir_path, file),\n",
244
+ " max_ring_size=MAX_RING_SIZE,\n",
245
+ " npscorer_model=fscore,\n",
246
+ " array_format='np')\n",
247
+ " parsed_xyz.append(parsed)"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "code",
252
+ "execution_count": null,
253
+ "id": "12969dd2",
254
+ "metadata": {},
255
+ "outputs": [],
256
+ "source": [
257
+ "qm9_df = pd.DataFrame(data=parsed_xyz)"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": null,
263
+ "id": "eed4f163",
264
+ "metadata": {},
265
+ "outputs": [],
266
+ "source": [
267
+ "# Conversion below is needed to avoid:\n",
268
+ "# `ArrowInvalid: ('Can only convert 1-dimensional array values',\n",
269
+ "# 'Conversion failed for column pos with type object')`\n",
270
+ "qm9_df['pos'] = qm9_df['pos'].apply(lambda x: [xi for xi in x])"
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": null,
276
+ "id": "c912d23a",
277
+ "metadata": {},
278
+ "outputs": [],
279
+ "source": [
280
+ "dataset = datasets.Dataset.from_pandas(qm9_df)"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "code",
285
+ "execution_count": null,
286
+ "id": "7a7df506",
287
+ "metadata": {},
288
+ "outputs": [],
289
+ "source": [
290
+ "dataset.push_to_hub('yairschiff/qm9')"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": null,
296
+ "id": "86c4e1ae",
297
+ "metadata": {},
298
+ "outputs": [],
299
+ "source": [
300
+ "# # Random train/test splits as recommended by:\n",
301
+ "# # https://moleculenet.org/datasets-1\n",
302
+ "# test_size = 0.1\n",
303
+ "# seed = 1\n",
304
+ "# dataset.train_test_split(test_size=test_size, seed=seed)"
305
+ ]
306
+ },
307
+ {
308
+ "cell_type": "markdown",
309
+ "id": "e982da1b-05ab-493b-bb82-8bf1225dcb2b",
310
+ "metadata": {},
311
+ "source": [
312
+ "## Create tokenizer"
313
+ ]
314
+ },
315
+ {
316
+ "cell_type": "code",
317
+ "execution_count": 7,
318
+ "id": "b0504e77",
319
+ "metadata": {},
320
+ "outputs": [],
321
+ "source": [
322
+ "def smi_tokenizer(smi):\n",
323
+ " \"\"\"Tokenize a SMILES molecule or reaction.\n",
324
+ "\n",
325
+ " Copied from https://github.com/pschwllr/MolecularTransformer.\n",
326
+ " \"\"\"\n",
327
+ " import re\n",
328
+ " pattern = \"(\\[[^\\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\\(|\\)|\\.|=|#|-|\\+|\\\\\\\\|\\/|:|~|@|\\?|>|\\*|\\$|\\%[0-9]{2}|[0-9])\"\n",
329
+ " regex = re.compile(pattern)\n",
330
+ " tokens = [token for token in regex.findall(smi)]\n",
331
+ " assert smi == ''.join(tokens)\n",
332
+ " return tokens"
333
+ ]
334
+ },
335
+ {
336
+ "cell_type": "code",
337
+ "execution_count": 8,
338
+ "id": "b89a4def-ea08-466a-8779-24acf75a2bd0",
339
+ "metadata": {},
340
+ "outputs": [],
341
+ "source": [
342
+ "dataset = datasets.load_dataset('yairschiff/qm9', split='train')"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "code",
347
+ "execution_count": 9,
348
+ "id": "6ef61481-9384-4c1c-8361-ab858cb157ba",
349
+ "metadata": {},
350
+ "outputs": [],
351
+ "source": [
352
+ "# # If vocab file not created yet, uncomment and run this cell\n",
353
+ "\n",
354
+ "# tokens = []\n",
355
+ "# for smi in dataset['canonical_smiles']:\n",
356
+ "# tokens.extend(smi_tokenizer(smi))\n",
357
+ "\n",
358
+ "# with open('qm9_vocab.json', 'w', encoding='utf-8') as f:\n",
359
+ "# f.write(\n",
360
+ "# json.dumps(\n",
361
+ "# {t: i for i, t in enumerate(sorted(set(tokens)))},\n",
362
+ "# indent=2,\n",
363
+ "# sort_keys=True,\n",
364
+ "# ensure_ascii=False\n",
365
+ "# ) + '\\n')"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "code",
370
+ "execution_count": 9,
371
+ "id": "6af7fccb-08ee-4dc6-99dc-cfa4fc38074c",
372
+ "metadata": {},
373
+ "outputs": [],
374
+ "source": [
375
+ "# # If HF tokenizer not yet published, uncomment and run this cell\n",
376
+ "# import tokenizer\n",
377
+ "\n",
378
+ "# tokenizer.QM9Tokenizer.register_for_auto_class()\n",
379
+ "# qm9_tokenizer = tokenizer.QM9Tokenizer(vocab_file='qm9_vocab.json')\n",
380
+ "# qm9_tokenizer.push_to_hub('yairschiff/qm9-tokenizer')"
381
+ ]
382
+ },
383
+ {
384
+ "cell_type": "code",
385
+ "execution_count": 23,
386
+ "id": "4cc39f16-b53c-481a-a35e-a42fb1b08378",
387
+ "metadata": {},
388
+ "outputs": [],
389
+ "source": [
390
+ "# Test tokenizer\n",
391
+ "qm9_tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
392
+ " 'yairschiff/qm9-tokenizer', trust_remote_code=True, resume_download=None)\n",
393
+ "print(dataset[1000]['canonical_smiles'])\n",
394
+ "print(qm9_tokenizer.encode(dataset[1000]['canonical_smiles']))\n",
395
+ "print(qm9_tokenizer.decode(qm9_tokenizer.encode(dataset[1000]['canonical_smiles'])))"
396
+ ]
397
+ },
398
+ {
399
+ "cell_type": "code",
400
+ "execution_count": null,
401
+ "id": "41752e94-175e-4f40-b9d2-496241eab0c0",
402
+ "metadata": {},
403
+ "outputs": [],
404
+ "source": []
405
+ }
406
+ ],
407
+ "metadata": {
408
+ "kernelspec": {
409
+ "display_name": "Python 3 (ipykernel)",
410
+ "language": "python",
411
+ "name": "python3"
412
+ },
413
+ "language_info": {
414
+ "codemirror_mode": {
415
+ "name": "ipython",
416
+ "version": 3
417
+ },
418
+ "file_extension": ".py",
419
+ "mimetype": "text/x-python",
420
+ "name": "python",
421
+ "nbconvert_exporter": "python",
422
+ "pygments_lexer": "ipython3",
423
+ "version": "3.9.18"
424
+ }
425
+ },
426
+ "nbformat": 4,
427
+ "nbformat_minor": 5
428
+ }
notebooks/qm9_vocab.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "#": 0,
3
+ "(": 1,
4
+ ")": 2,
5
+ "-": 3,
6
+ "1": 4,
7
+ "2": 5,
8
+ "3": 6,
9
+ "4": 7,
10
+ "5": 8,
11
+ "=": 9,
12
+ "C": 10,
13
+ "F": 11,
14
+ "N": 12,
15
+ "O": 13,
16
+ "[C-]": 14,
17
+ "[CH-]": 15,
18
+ "[N+]": 16,
19
+ "[N-]": 17,
20
+ "[NH+]": 18,
21
+ "[NH2+]": 19,
22
+ "[NH3+]": 20,
23
+ "[O-]": 21,
24
+ "[c-]": 22,
25
+ "[cH-]": 23,
26
+ "[n-]": 24,
27
+ "[nH+]": 25,
28
+ "[nH]": 26,
29
+ "c": 27,
30
+ "n": 28,
31
+ "o": 29
32
+ }
notebooks/zinc250k_data_prep.ipynb ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "fa328603-9e2b-4643-8500-ec11c51b5223",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Imports and setup"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 3,
14
+ "id": "7716fb32-a805-4888-9dac-da4cff4f6e40",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import os\n",
19
+ "import huggingface_hub"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 4,
25
+ "id": "432e1636",
26
+ "metadata": {},
27
+ "outputs": [
28
+ {
29
+ "name": "stdout",
30
+ "output_type": "stream",
31
+ "text": [
32
+ "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
33
+ "Token is valid (permission: write).\n",
34
+ "Your token has been saved to /share/kuleshov/yzs2/discrete-guidance/.hf_cache/token\n",
35
+ "Login successful\n"
36
+ ]
37
+ }
38
+ ],
39
+ "source": [
40
+ "if os.path.exists(os.path.join(os.environ['HF_HOME'], 'token')):\n",
41
+ " with open(os.path.join(os.environ['HF_HOME'], 'token'), 'r') as f:\n",
42
+ " token = f.read().strip()\n",
43
+ "else:\n",
44
+ " token = None\n",
45
+ "huggingface_hub.login(token=token)"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": 26,
51
+ "id": "e22e86ae",
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "import json\n",
56
+ "import re\n",
57
+ "import typing\n",
58
+ "\n",
59
+ "import datasets\n",
60
+ "import numpy as np\n",
61
+ "import pandas as pd\n",
62
+ "import rdkit\n",
63
+ "import transformers\n",
64
+ "from rdkit import Chem as rdChem\n",
65
+ "from tqdm.auto import tqdm"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": 7,
71
+ "id": "aaa00828",
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "# TODO: Update to 2024.03.6 release when available instead of suppressing warning!\n",
76
+ "# See: https://github.com/rdkit/rdkit/issues/7625#\n",
77
+ "rdkit.rdBase.DisableLog('rdApp.warning')"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "markdown",
82
+ "id": "0a878a71-d33f-43fe-955d-4250950b1eec",
83
+ "metadata": {
84
+ "jp-MarkdownHeadingCollapsed": true
85
+ },
86
+ "source": [
87
+ "## Create dataset"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "id": "26856fe2",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "def count_rings_and_bonds(\n",
98
+ " mol: rdChem.Mol\n",
99
+ ") -> typing.Dict[str, int]:\n",
100
+ "    \"\"\"Counts bonds and rings (by type).\"\"\"\n",
101
+ " \n",
102
+ " # Counting rings\n",
103
+ " ssr = rdChem.GetSymmSSSR(mol)\n",
104
+ " ring_count = len(ssr)\n",
105
+ " \n",
106
+ " ring_sizes = {}\n",
107
+ " for ring in ssr:\n",
108
+ " ring_size = len(ring)\n",
109
+ " if ring_size not in ring_sizes:\n",
110
+ " ring_sizes[ring_size] = 0\n",
111
+ " ring_sizes[ring_size] += 1\n",
112
+ " \n",
113
+ " # Counting bond types\n",
114
+ " bond_counts = {\n",
115
+ " 'single': 0,\n",
116
+ " 'double': 0,\n",
117
+ " 'triple': 0,\n",
118
+ " 'aromatic': 0\n",
119
+ " }\n",
120
+ " \n",
121
+ " for bond in mol.GetBonds():\n",
122
+ " if bond.GetIsAromatic():\n",
123
+ " bond_counts['aromatic'] += 1\n",
124
+ " elif bond.GetBondType() == rdChem.BondType.SINGLE:\n",
125
+ " bond_counts['single'] += 1\n",
126
+ " elif bond.GetBondType() == rdChem.BondType.DOUBLE:\n",
127
+ " bond_counts['double'] += 1\n",
128
+ " elif bond.GetBondType() == rdChem.BondType.TRIPLE:\n",
129
+ " bond_counts['triple'] += 1\n",
130
+ " result = {\n",
131
+ " 'ring_count': ring_count,\n",
132
+ " }\n",
133
+ " for k, v in ring_sizes.items():\n",
134
+ " result[f\"R{k}\"] = v\n",
135
+ "\n",
136
+ " for k, v in bond_counts.items():\n",
137
+ " result[f\"{k}_bond\"] = v\n",
138
+ " return result"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "id": "fbde53f7",
145
+ "metadata": {},
146
+ "outputs": [],
147
+ "source": [
148
+ "\"\"\"\n",
149
+ " Download data and validation indices from:\n",
150
+ " \"Score-based Generative Modeling of Graphs via the System of Stochastic Differential Equations\"\n",
151
+ " https://github.com/harryjo97/GDSS\n",
152
+ "    > wget https://raw.githubusercontent.com/harryjo97/GDSS/master/data/zinc250k.csv\n",
153
+ " > wget https://raw.githubusercontent.com/harryjo97/GDSS/master/data/valid_idx_zinc250k.json\n",
154
+ "\"\"\"\n",
155
+ "df = pd.read_csv('/Users/yairschiff/Downloads/zinc250k.csv', index_col=0, encoding='utf_8')\n",
156
+ "feats = []\n",
157
+ "for i, row in tqdm(df.iterrows(), total=len(df), desc='RDKit feats', leave=False):\n",
158
+ " feat = {'smiles': row['smiles']}\n",
159
+ " feat['canonical_smiles'] = rdChem.CanonSmiles(feat['smiles'])\n",
160
+ " m = rdChem.MolFromSmiles(feat['canonical_smiles'])\n",
161
+ " feat.update(count_rings_and_bonds(m))\n",
162
+ " feats.append(feat)\n",
163
+ "df = pd.merge(df, pd.DataFrame.from_records(feats), on='smiles')\n",
164
+ "df = df.fillna(0)\n",
165
+ "for col in df.columns: # recast ring counts as int\n",
166
+ " if re.search(\"^R[0-9]+$\", col) is not None:\n",
167
+ " df[col] = df[col].astype(int)\n",
168
+ "# Re-order columns\n",
169
+ "df = df[\n",
170
+ " ['smiles', 'logP', 'qed', 'SAS', 'canonical_smiles',\n",
171
+ " 'single_bond', 'double_bond', 'triple_bond', 'aromatic_bond',\n",
172
+ " 'ring_count','R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R12', 'R13', 'R14', 'R15', 'R18', 'R24']]"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": null,
178
+ "id": "1e2d5955",
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": [
182
+ "# Read in validation indices\n",
183
+ "with open('/Users/yairschiff/Downloads/valid_idx_zinc250k.json', 'r') as f:\n",
184
+ " valid_idxs = json.load(f)\n",
185
+ "df['validation'] = df.index.isin(valid_idxs).astype(int)"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "id": "2b89b732",
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "# Create HF dataset\n",
196
+ "dataset = datasets.DatasetDict({\n",
197
+ " 'train': datasets.Dataset.from_pandas(df[df['validation'] == 0].drop(columns=['validation'])),\n",
198
+ " 'validation': datasets.Dataset.from_pandas(df[df['validation'] == 1].drop(columns=['validation'])),\n",
199
+ "})\n",
200
+ "dataset = dataset.remove_columns('__index_level_0__')"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": null,
206
+ "id": "1efb5845",
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": [
210
+ "dataset.push_to_hub('yairschiff/zinc250k')"
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "markdown",
215
+ "id": "5c6f357d-20d9-4004-8091-68726b6b4c86",
216
+ "metadata": {},
217
+ "source": [
218
+ "## Create tokenizer"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 8,
224
+ "id": "6642fc9d-4863-4e14-947b-95bae48e192d",
225
+ "metadata": {},
226
+ "outputs": [],
227
+ "source": [
228
+ "def smi_tokenizer(smi):\n",
229
+ " \"\"\"Tokenize a SMILES molecule or reaction.\n",
230
+ "\n",
231
+ " Copied from https://github.com/pschwllr/MolecularTransformer.\n",
232
+ " \"\"\"\n",
233
+ " import re\n",
234
+ " pattern = \"(\\[[^\\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\\(|\\)|\\.|=|#|-|\\+|\\\\\\\\|\\/|:|~|@|\\?|>|\\*|\\$|\\%[0-9]{2}|[0-9])\"\n",
235
+ " regex = re.compile(pattern)\n",
236
+ " tokens = [token for token in regex.findall(smi)]\n",
237
+ " assert smi == ''.join(tokens)\n",
238
+ " return tokens"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": 11,
244
+ "id": "3a9e2e60-8596-4a91-acc3-d43e166ce723",
245
+ "metadata": {},
246
+ "outputs": [],
247
+ "source": [
248
+ "dataset = datasets.load_dataset('yairschiff/zinc250k')"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "code",
253
+ "execution_count": 12,
254
+ "id": "fbd5c2fe-4318-46bb-bc43-6ef7fe76e9fd",
255
+ "metadata": {},
256
+ "outputs": [],
257
+ "source": [
258
+ "# # If vocab file not created yet, uncomment and run this cell\n",
259
+ "\n",
260
+ "# tokens = []\n",
261
+ "# for split in dataset.keys():\n",
262
+ "# for smi in dataset[split]['canonical_smiles']:\n",
263
+ "# tokens.extend(smi_tokenizer(smi))\n",
264
+ "\n",
265
+ "# with open('zinc250k_vocab.json', 'w', encoding='utf-8') as f:\n",
266
+ "# f.write(\n",
267
+ "# json.dumps(\n",
268
+ "# {t: i for i, t in enumerate(sorted(set(tokens)))},\n",
269
+ "# indent=2,\n",
270
+ "# sort_keys=True,\n",
271
+ "# ensure_ascii=False\n",
272
+ "# ) + '\\n')"
273
+ ]
274
+ },
275
+ {
276
+ "cell_type": "code",
277
+ "execution_count": 14,
278
+ "id": "4962478b-5343-4838-befe-64a5389625d4",
279
+ "metadata": {},
280
+ "outputs": [
281
+ {
282
+ "data": {
283
+ "text/plain": [
284
+ "CommitInfo(commit_url='https://huggingface.co/yairschiff/zinc250k-tokenizer/commit/7a07b0165a8a4f14f09d6137da8cdabf789397fd', commit_message='Upload tokenizer', commit_description='', oid='7a07b0165a8a4f14f09d6137da8cdabf789397fd', pr_url=None, pr_revision=None, pr_num=None)"
285
+ ]
286
+ },
287
+ "execution_count": 14,
288
+ "metadata": {},
289
+ "output_type": "execute_result"
290
+ }
291
+ ],
292
+ "source": [
293
+ "# # If HF tokenizer not yet published, uncomment and run this cell\n",
294
+ "# import tokenizer\n",
295
+ "\n",
296
+ "# tokenizer.Zinc250kTokenizer.register_for_auto_class()\n",
297
+ "# zinc250k_tokenizer = tokenizer.Zinc250kTokenizer(vocab_file='zinc250k_vocab.json')\n",
298
+ "# zinc250k_tokenizer.push_to_hub('yairschiff/zinc250k-tokenizer')"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": 18,
304
+ "id": "a779aa57-0c9d-4b8c-bf11-ccc5ab4c462e",
305
+ "metadata": {},
306
+ "outputs": [
307
+ {
308
+ "name": "stdout",
309
+ "output_type": "stream",
310
+ "text": [
311
+ "Cn1ncc2c1CCC[C@H]2NC(=O)NC[C@H](O)COc1ccc(F)cc1\n",
312
+ "[0, 25, 69, 15, 69, 68, 68, 16, 68, 15, 25, 25, 25, 35, 16, 29, 25, 11, 23, 30, 12, 29, 25, 35, 11, 30, 12, 25, 30, 68, 15, 68, 68, 68, 11, 27, 12, 68, 68, 15, 1]\n",
313
+ "<bos>Cn1ncc2c1CCC[C@H]2NC(=O)NC[C@H](O)COc1ccc(F)cc1<eos>\n",
314
+ "Cn1ncc2c1CCC[C@H]2NC(=O)NC[C@H](O)COc1ccc(F)cc1\n"
315
+ ]
316
+ }
317
+ ],
318
+ "source": [
319
+ "# Test tokenizer\n",
320
+ "zinc250k_tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
321
+ " 'yairschiff/zinc250k-tokenizer', trust_remote_code=True, resume_download=None)\n",
322
+ "print(dataset['train'][1000]['canonical_smiles'])\n",
323
+ "print(zinc250k_tokenizer.encode(dataset['train'][1000]['canonical_smiles']))\n",
324
+ "print(zinc250k_tokenizer.decode(zinc250k_tokenizer.encode(dataset['train'][1000]['canonical_smiles'])))\n",
325
+ "print(zinc250k_tokenizer.decode(zinc250k_tokenizer.encode(dataset['train'][1000]['canonical_smiles'], add_special_tokens=False)))"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": 28,
331
+ "id": "f3a15585-8e75-409d-9afe-0e7fe4a0bffc",
332
+ "metadata": {},
333
+ "outputs": [
334
+ {
335
+ "data": {
336
+ "application/vnd.jupyter.widget-view+json": {
337
+ "model_id": "",
338
+ "version_major": 2,
339
+ "version_minor": 0
340
+ },
341
+ "text/plain": [
342
+ " 0%| | 0/224568 [00:00<?, ?it/s]"
343
+ ]
344
+ },
345
+ "metadata": {},
346
+ "output_type": "display_data"
347
+ },
348
+ {
349
+ "data": {
350
+ "application/vnd.jupyter.widget-view+json": {
351
+ "model_id": "",
352
+ "version_major": 2,
353
+ "version_minor": 0
354
+ },
355
+ "text/plain": [
356
+ " 0%| | 0/24887 [00:00<?, ?it/s]"
357
+ ]
358
+ },
359
+ "metadata": {},
360
+ "output_type": "display_data"
361
+ },
362
+ {
363
+ "name": "stdout",
364
+ "output_type": "stream",
365
+ "text": [
366
+ "(array([ 152, 3351, 21311, 47185, 67972, 70367, 25030, 11778, 2179,\n",
367
+ " 130]), array([10. , 16.4, 22.8, 29.2, 35.6, 42. , 48.4, 54.8, 61.2, 67.6, 74. ]))\n",
368
+ "10\n",
369
+ "74\n"
370
+ ]
371
+ }
372
+ ],
373
+ "source": [
374
+ "lengths = [len(zinc250k_tokenizer.encode(i['canonical_smiles'])) for i in tqdm(dataset['train'], leave=False)]\n",
375
+ "lengths += [len(zinc250k_tokenizer.encode(i['canonical_smiles'])) for i in tqdm(dataset['validation'], leave=False)]\n",
376
+ "print(np.histogram(lengths))\n",
377
+ "print(min(lengths))\n",
378
+ "print(max(lengths))"
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "execution_count": null,
384
+ "id": "d7a6e081-4961-4cf4-a19d-0375bedd7dab",
385
+ "metadata": {},
386
+ "outputs": [],
387
+ "source": []
388
+ }
389
+ ],
390
+ "metadata": {
391
+ "kernelspec": {
392
+ "display_name": "Python 3 (ipykernel)",
393
+ "language": "python",
394
+ "name": "python3"
395
+ },
396
+ "language_info": {
397
+ "codemirror_mode": {
398
+ "name": "ipython",
399
+ "version": 3
400
+ },
401
+ "file_extension": ".py",
402
+ "mimetype": "text/x-python",
403
+ "name": "python",
404
+ "nbconvert_exporter": "python",
405
+ "pygments_lexer": "ipython3",
406
+ "version": "3.9.18"
407
+ }
408
+ },
409
+ "nbformat": 4,
410
+ "nbformat_minor": 5
411
+ }
notebooks/zinc250k_vocab.json ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "#": 0,
3
+ "(": 1,
4
+ ")": 2,
5
+ "-": 3,
6
+ "/": 4,
7
+ "1": 5,
8
+ "2": 6,
9
+ "3": 7,
10
+ "4": 8,
11
+ "5": 9,
12
+ "6": 10,
13
+ "7": 11,
14
+ "8": 12,
15
+ "=": 13,
16
+ "Br": 14,
17
+ "C": 15,
18
+ "Cl": 16,
19
+ "F": 17,
20
+ "I": 18,
21
+ "N": 19,
22
+ "O": 20,
23
+ "P": 21,
24
+ "S": 22,
25
+ "[C@@H]": 23,
26
+ "[C@@]": 24,
27
+ "[C@H]": 25,
28
+ "[C@]": 26,
29
+ "[CH-]": 27,
30
+ "[CH2-]": 28,
31
+ "[N+]": 29,
32
+ "[N-]": 30,
33
+ "[NH+]": 31,
34
+ "[NH-]": 32,
35
+ "[NH2+]": 33,
36
+ "[NH3+]": 34,
37
+ "[O+]": 35,
38
+ "[O-]": 36,
39
+ "[OH+]": 37,
40
+ "[P+]": 38,
41
+ "[P@@H]": 39,
42
+ "[P@@]": 40,
43
+ "[P@]": 41,
44
+ "[PH+]": 42,
45
+ "[PH2]": 43,
46
+ "[PH]": 44,
47
+ "[S+]": 45,
48
+ "[S-]": 46,
49
+ "[S@@+]": 47,
50
+ "[S@@]": 48,
51
+ "[S@]": 49,
52
+ "[SH+]": 50,
53
+ "[n+]": 51,
54
+ "[n-]": 52,
55
+ "[nH+]": 53,
56
+ "[nH]": 54,
57
+ "[o+]": 55,
58
+ "[s+]": 56,
59
+ "\\": 57,
60
+ "c": 58,
61
+ "n": 59,
62
+ "o": 60,
63
+ "s": 61
64
+ }