Upload 8 files
Browse files
- custom_datasets/__init__.py +2 -0
- custom_datasets/discretized_cifar10.py +0 -0
- custom_datasets/ten_species_dataset.py +0 -0
- notebooks/eval_hyenadna_classifier.ipynb +196 -0
- notebooks/qm9_data_prep.ipynb +428 -0
- notebooks/qm9_vocab.json +32 -0
- notebooks/zinc250k_data_prep.ipynb +411 -0
- notebooks/zinc250k_vocab.json +64 -0
custom_datasets/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import discretized_cifar10
|
| 2 |
+
from . import ten_species_dataset
|
custom_datasets/discretized_cifar10.py
ADDED
|
File without changes
|
custom_datasets/ten_species_dataset.py
ADDED
|
File without changes
|
notebooks/eval_hyenadna_classifier.ipynb
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"id": "5b178466-559f-47ed-bcd1-a171641d47b5",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"import os\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"import hydra\n",
|
| 11 |
+
"import numpy as np\n",
|
| 12 |
+
"import omegaconf\n",
|
| 13 |
+
"import torch\n",
|
| 14 |
+
"import transformers\n",
|
| 15 |
+
"from sklearn.metrics import f1_score, matthews_corrcoef, precision_score, recall_score\n",
|
| 16 |
+
"from tqdm.auto import tqdm\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"import classifier\n",
|
| 19 |
+
"import dataloader"
|
| 20 |
+
],
|
| 21 |
+
"outputs": [],
|
| 22 |
+
"execution_count": null
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"id": "08301e02-d279-426f-8aad-c23eea8fb120",
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"source": [
|
| 29 |
+
"omegaconf.OmegaConf.register_new_resolver(\n",
|
| 30 |
+
" 'cwd', os.getcwd)\n",
|
| 31 |
+
"omegaconf.OmegaConf.register_new_resolver(\n",
|
| 32 |
+
" 'device_count', torch.cuda.device_count)\n",
|
| 33 |
+
"omegaconf.OmegaConf.register_new_resolver(\n",
|
| 34 |
+
" 'eval', eval)\n",
|
| 35 |
+
"omegaconf.OmegaConf.register_new_resolver(\n",
|
| 36 |
+
" 'div_up', lambda x, y: (x + y - 1) // y)\n",
|
| 37 |
+
"omegaconf.OmegaConf.register_new_resolver(\n",
|
| 38 |
+
" 'if_then_else',\n",
|
| 39 |
+
" lambda condition, x, y: x if condition else y\n",
|
| 40 |
+
")"
|
| 41 |
+
],
|
| 42 |
+
"outputs": [],
|
| 43 |
+
"execution_count": null
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"id": "4685c167-63c8-4912-81e0-4ecd635fcc24",
|
| 48 |
+
"metadata": {},
|
| 49 |
+
"source": [
|
| 50 |
+
"# Load classifier\n",
|
| 51 |
+
"with hydra.initialize(version_base=None, config_path='../configs/'):\n",
|
| 52 |
+
" classifier_config = hydra.compose(\n",
|
| 53 |
+
" config_name='config',\n",
|
| 54 |
+
" overrides=[\n",
|
| 55 |
+
" 'hydra.output_subdir=null',\n",
|
| 56 |
+
" f\"hydra.run.dir={os.path.dirname(os.getcwd())}/outputs/ten_species/eval_classifier/hyenadna-small-32k_from-scratch_nlayer-8\",\n",
|
| 57 |
+
" 'hydra/job_logging=disabled',\n",
|
| 58 |
+
" 'hydra/hydra_logging=disabled',\n",
|
| 59 |
+
" '+is_eval_classifier=True',\n",
|
| 60 |
+
" 'mode=train_classifier',\n",
|
| 61 |
+
" 'loader.global_batch_size=32',\n",
|
| 62 |
+
" 'loader.eval_global_batch_size=64',\n",
|
| 63 |
+
" 'loader.batch_size=1',\n",
|
| 64 |
+
" 'loader.eval_batch_size=1',\n",
|
| 65 |
+
" 'data=ten_species',\n",
|
| 66 |
+
" 'data.label_col=species_label',\n",
|
| 67 |
+
" 'data.num_classes=10',\n",
|
| 68 |
+
" 'classifier_model=hyenadna-classifier',\n",
|
| 69 |
+
" 'classifier_model.hyena_model_name_or_path=LongSafari/hyenadna-small-32k-seqlen-hf',\n",
|
| 70 |
+
" 'classifier_model.n_layer=8',\n",
|
| 71 |
+
" 'classifier_backbone=hyenadna',\n",
|
| 72 |
+
" 'model.length=32768',\n",
|
| 73 |
+
" 'diffusion=null',\n",
|
| 74 |
+
" 'T=null',\n",
|
| 75 |
+
" f\"eval.checkpoint_path={os.path.dirname(os.getcwd())}/outputs/ten_species/eval_classifier/hyenadna-small-32k_from-scratch_nlayer-8/checkpoints/best.ckpt\",\n",
|
| 76 |
+
" ]\n",
|
| 77 |
+
" )\n",
|
| 78 |
+
"classifier_config = omegaconf.OmegaConf.create(classifier_config)\n",
|
| 79 |
+
"tokenizer = transformers.AutoTokenizer.from_pretrained(classifier_config.data.tokenizer_name_or_path, trust_remote_code=True)\n",
|
| 80 |
+
"pretrained_classifier = classifier.Classifier.load_from_checkpoint(\n",
|
| 81 |
+
" classifier_config.eval.checkpoint_path,\n",
|
| 82 |
+
" tokenizer=tokenizer,\n",
|
| 83 |
+
" config=classifier_config, logger=False)\n",
|
| 84 |
+
"pretrained_classifier.eval();"
|
| 85 |
+
],
|
| 86 |
+
"outputs": [],
|
| 87 |
+
"execution_count": null
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"cell_type": "code",
|
| 91 |
+
"id": "bf18720b-64a9-4e9e-9e1e-2aa1c12dc6f0",
|
| 92 |
+
"metadata": {},
|
| 93 |
+
"source": [
|
| 94 |
+
"tokenizer = dataloader.get_tokenizer(classifier_config)\n",
|
| 95 |
+
"_, val_dl = dataloader.get_dataloaders(\n",
|
| 96 |
+
" classifier_config, tokenizer, skip_train=True, valid_seed=classifier_config.seed)"
|
| 97 |
+
],
|
| 98 |
+
"outputs": [],
|
| 99 |
+
"execution_count": null
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"cell_type": "code",
|
| 103 |
+
"id": "bdcd3ba7-e26a-4e36-a5fb-ff1fb747cc3c",
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"source": [
|
| 106 |
+
"labels = []\n",
|
| 107 |
+
"preds = []\n",
|
| 108 |
+
"for batch in tqdm(val_dl):\n",
|
| 109 |
+
" preds.append(\n",
|
| 110 |
+
" pretrained_classifier(batch['input_ids'].to(pretrained_classifier.device)).argmax(dim=-1).detach().to(\n",
|
| 111 |
+
" 'cpu', non_blocking=True).numpy()\n",
|
| 112 |
+
" )\n",
|
| 113 |
+
" labels.append(batch['species_label'].numpy())"
|
| 114 |
+
],
|
| 115 |
+
"outputs": [],
|
| 116 |
+
"execution_count": null
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
"cell_type": "code",
|
| 120 |
+
"id": "110ed75e-613c-4b6a-bb79-15517988735c",
|
| 121 |
+
"metadata": {},
|
| 122 |
+
"source": [
|
| 123 |
+
"labels = np.concatenate(labels)\n",
|
| 124 |
+
"preds = np.concatenate(preds)"
|
| 125 |
+
],
|
| 126 |
+
"outputs": [],
|
| 127 |
+
"execution_count": null
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"cell_type": "code",
|
| 131 |
+
"id": "1558ca2e-6454-4c8c-b141-fca77f0025c5",
|
| 132 |
+
"metadata": {},
|
| 133 |
+
"source": [
|
| 134 |
+
"overall_accuracy_score = (preds == labels).sum() / preds.size\n",
|
| 135 |
+
"overall_f1_score = f1_score(y_pred=preds, y_true=labels, average=\"macro\", labels=list(range(classifier_config.data.num_classes)))\n",
|
| 136 |
+
"overall_mcc_score = matthews_corrcoef(y_pred=preds, y_true=labels)\n",
|
| 137 |
+
"\n",
|
| 138 |
+
"print(f\"Overall Acc: {overall_accuracy_score:0.3f}\")\n",
|
| 139 |
+
"print(f\"Overall F1: {overall_f1_score:0.3f}\")\n",
|
| 140 |
+
"print(f\"Overall MCC: {overall_mcc_score:0.3f}\")"
|
| 141 |
+
],
|
| 142 |
+
"outputs": [],
|
| 143 |
+
"execution_count": null
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"cell_type": "code",
|
| 147 |
+
"id": "df8ce828-f6e1-4167-bae2-db4f13900758",
|
| 148 |
+
"metadata": {},
|
| 149 |
+
"source": [
|
| 150 |
+
"f1_scores = f1_score(y_pred=preds, y_true=labels, average=None , labels=list(range(classifier_config.data.num_classes)))\n",
|
| 151 |
+
"precision_scores = precision_score(y_pred=preds, y_true=labels, average=None , labels=list(range(classifier_config.data.num_classes)))\n",
|
| 152 |
+
"recall_scores = recall_score(y_pred=preds, y_true=labels, average=None , labels=list(range(classifier_config.data.num_classes)))\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"species_list = ['Homo_sapiens', 'Mus_musculus', 'Drosophila_melanogaster', 'Danio_rerio',\n",
|
| 155 |
+
" 'Caenorhabditis_elegans', 'Gallus_gallus', 'Gorilla_gorilla', 'Felis_catus',\n",
|
| 156 |
+
" 'Salmo_trutta', 'Arabidopsis_thaliana']\n",
|
| 157 |
+
"for s in range(classifier_config.data.num_classes):\n",
|
| 158 |
+
" print(f\"Class {s} - {species_list[s]}:\")\n",
|
| 159 |
+
" print(f\" F1: {f1_scores[s]:0.3f}\")\n",
|
| 160 |
+
" print(f\" Precision: {precision_scores[s]:0.3f}\")\n",
|
| 161 |
+
" print(f\" Recall: {recall_scores[s]:0.3f}\")"
|
| 162 |
+
],
|
| 163 |
+
"outputs": [],
|
| 164 |
+
"execution_count": null
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"cell_type": "code",
|
| 168 |
+
"id": "d18ca7cc-4fe6-4ba9-9175-1eac9ebca7b1",
|
| 169 |
+
"metadata": {},
|
| 170 |
+
"source": [],
|
| 171 |
+
"outputs": [],
|
| 172 |
+
"execution_count": null
|
| 173 |
+
}
|
| 174 |
+
],
|
| 175 |
+
"metadata": {
|
| 176 |
+
"kernelspec": {
|
| 177 |
+
"display_name": "Python 3 (ipykernel)",
|
| 178 |
+
"language": "python",
|
| 179 |
+
"name": "python3"
|
| 180 |
+
},
|
| 181 |
+
"language_info": {
|
| 182 |
+
"codemirror_mode": {
|
| 183 |
+
"name": "ipython",
|
| 184 |
+
"version": 3
|
| 185 |
+
},
|
| 186 |
+
"file_extension": ".py",
|
| 187 |
+
"mimetype": "text/x-python",
|
| 188 |
+
"name": "python",
|
| 189 |
+
"nbconvert_exporter": "python",
|
| 190 |
+
"pygments_lexer": "ipython3",
|
| 191 |
+
"version": "3.9.18"
|
| 192 |
+
}
|
| 193 |
+
},
|
| 194 |
+
"nbformat": 4,
|
| 195 |
+
"nbformat_minor": 5
|
| 196 |
+
}
|
notebooks/qm9_data_prep.ipynb
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "5fa7908c-b785-49ce-9e5d-7c6ad6b4378b",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"## Imports and setup"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": 3,
|
| 14 |
+
"id": "d0c96204-ea08-4330-b1bb-784b259ec32e",
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"import os\n",
|
| 19 |
+
"import huggingface_hub"
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "code",
|
| 24 |
+
"execution_count": 4,
|
| 25 |
+
"id": "6813e76b",
|
| 26 |
+
"metadata": {},
|
| 27 |
+
"outputs": [
|
| 28 |
+
{
|
| 29 |
+
"name": "stdout",
|
| 30 |
+
"output_type": "stream",
|
| 31 |
+
"text": [
|
| 32 |
+
"The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
|
| 33 |
+
"Token is valid (permission: write).\n",
|
| 34 |
+
"Your token has been saved to /share/kuleshov/yzs2/discrete-guidance/.hf_cache/token\n",
|
| 35 |
+
"Login successful\n"
|
| 36 |
+
]
|
| 37 |
+
}
|
| 38 |
+
],
|
| 39 |
+
"source": [
|
| 40 |
+
"if os.path.exists(os.path.join(os.environ['HF_HOME'], 'token')):\n",
|
| 41 |
+
" with open(os.path.join(os.environ['HF_HOME'], 'token'), 'r') as f:\n",
|
| 42 |
+
" token = f.read().strip()\n",
|
| 43 |
+
"else:\n",
|
| 44 |
+
" token = None\n",
|
| 45 |
+
"huggingface_hub.login(token=token)"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "code",
|
| 50 |
+
"execution_count": 5,
|
| 51 |
+
"id": "61cb2ac4",
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"outputs": [],
|
| 54 |
+
"source": [
|
| 55 |
+
"import json\n",
|
| 56 |
+
"import typing\n",
|
| 57 |
+
"\n",
|
| 58 |
+
"import datasets\n",
|
| 59 |
+
"import numpy as np\n",
|
| 60 |
+
"import pandas as pd\n",
|
| 61 |
+
"import rdkit\n",
|
| 62 |
+
"import transformers\n",
|
| 63 |
+
"from rdkit import Chem as rdChem\n",
|
| 64 |
+
"from rdkit.Chem import Crippen, QED\n",
|
| 65 |
+
"from rdkit.Contrib.NP_Score import npscorer\n",
|
| 66 |
+
"from rdkit.Contrib.SA_Score import sascorer\n",
|
| 67 |
+
"from tqdm.auto import tqdm"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"cell_type": "code",
|
| 72 |
+
"execution_count": 6,
|
| 73 |
+
"id": "24444c85",
|
| 74 |
+
"metadata": {},
|
| 75 |
+
"outputs": [],
|
| 76 |
+
"source": [
|
| 77 |
+
"# TODO: Update to 2024.03.6 release when available instead of suppressing warning!\n",
|
| 78 |
+
"# See: https://github.com/rdkit/rdkit/issues/7625\n",
|
| 79 |
+
"rdkit.rdBase.DisableLog('rdApp.warning')"
|
| 80 |
+
]
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"cell_type": "markdown",
|
| 84 |
+
"id": "902de4c5-dda5-4e4c-a4dd-f3b88015464e",
|
| 85 |
+
"metadata": {},
|
| 86 |
+
"source": [
|
| 87 |
+
"## Create dataset"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "code",
|
| 92 |
+
"execution_count": null,
|
| 93 |
+
"id": "7b7a8986",
|
| 94 |
+
"metadata": {},
|
| 95 |
+
"outputs": [],
|
| 96 |
+
"source": [
|
| 97 |
+
"def parse_float(\n",
|
| 98 |
+
" s: str\n",
|
| 99 |
+
") -> float:\n",
|
| 100 |
+
"    \"\"\"Parses floats potentially written in `*^` exponent notation (e.g. `1.5*^-6` -> 1.5e-6).\n",
|
| 101 |
+
" \n",
|
| 102 |
+
" Copied from https://www.kaggle.com/code/tawe141/extracting-data-from-qm9-xyz-files/code\n",
|
| 103 |
+
" \"\"\"\n",
|
| 104 |
+
" try:\n",
|
| 105 |
+
" return float(s)\n",
|
| 106 |
+
" except ValueError:\n",
|
| 107 |
+
" base, power = s.split('*^')\n",
|
| 108 |
+
" return float(base) * 10**float(power)\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"def count_rings_and_bonds(\n",
|
| 112 |
+
" mol: rdChem.Mol, max_ring_size: int = -1\n",
|
| 113 |
+
") -> typing.Dict[str, int]:\n",
|
| 114 |
+
"    \"\"\"Counts bonds and rings (by type).\"\"\"\n",
|
| 115 |
+
" \n",
|
| 116 |
+
" # Counting rings\n",
|
| 117 |
+
" ssr = rdChem.GetSymmSSSR(mol)\n",
|
| 118 |
+
" ring_count = len(ssr)\n",
|
| 119 |
+
" \n",
|
| 120 |
+
" ring_sizes = {} if max_ring_size < 0 else {i: 0 for i in range(3, max_ring_size+1)}\n",
|
| 121 |
+
" for ring in ssr:\n",
|
| 122 |
+
" ring_size = len(ring)\n",
|
| 123 |
+
" if ring_size not in ring_sizes:\n",
|
| 124 |
+
" ring_sizes[ring_size] = 0\n",
|
| 125 |
+
" ring_sizes[ring_size] += 1\n",
|
| 126 |
+
" \n",
|
| 127 |
+
" # Counting bond types\n",
|
| 128 |
+
" bond_counts = {\n",
|
| 129 |
+
" 'single': 0,\n",
|
| 130 |
+
" 'double': 0,\n",
|
| 131 |
+
" 'triple': 0,\n",
|
| 132 |
+
" 'aromatic': 0\n",
|
| 133 |
+
" }\n",
|
| 134 |
+
" \n",
|
| 135 |
+
" for bond in mol.GetBonds():\n",
|
| 136 |
+
" if bond.GetIsAromatic():\n",
|
| 137 |
+
" bond_counts['aromatic'] += 1\n",
|
| 138 |
+
" elif bond.GetBondType() == rdChem.BondType.SINGLE:\n",
|
| 139 |
+
" bond_counts['single'] += 1\n",
|
| 140 |
+
" elif bond.GetBondType() == rdChem.BondType.DOUBLE:\n",
|
| 141 |
+
" bond_counts['double'] += 1\n",
|
| 142 |
+
" elif bond.GetBondType() == rdChem.BondType.TRIPLE:\n",
|
| 143 |
+
" bond_counts['triple'] += 1\n",
|
| 144 |
+
" result = {\n",
|
| 145 |
+
" 'ring_count': ring_count,\n",
|
| 146 |
+
" }\n",
|
| 147 |
+
" for k, v in ring_sizes.items():\n",
|
| 148 |
+
" result[f\"R{k}\"] = v\n",
|
| 149 |
+
"\n",
|
| 150 |
+
" for k, v in bond_counts.items():\n",
|
| 151 |
+
" result[f\"{k}_bond\"] = v\n",
|
| 152 |
+
" return result\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"\n",
|
| 155 |
+
"def parse_xyz(\n",
|
| 156 |
+
" filename: str,\n",
|
| 157 |
+
" max_ring_size: int = -1,\n",
|
| 158 |
+
" npscorer_model: typing.Optional[dict] = None,\n",
|
| 159 |
+
" array_format: str = 'np'\n",
|
| 160 |
+
") -> typing.Dict[str, typing.Any]:\n",
|
| 161 |
+
"    \"\"\"Parses QM9-specific xyz files.\n",
|
| 162 |
+
" \n",
|
| 163 |
+
" See https://www.nature.com/articles/sdata201422/tables/2 for reference.\n",
|
| 164 |
+
" Adapted from https://www.kaggle.com/code/tawe141/extracting-data-from-qm9-xyz-files/code\n",
|
| 165 |
+
" \"\"\"\n",
|
| 166 |
+
" assert array_format in ['np', 'pt'], \\\n",
|
| 167 |
+
" f\"Invalid array_format: `{array_format}` provided. Must be one of `np` (numpy.array), `pt` (torch.tensor).\"\n",
|
| 168 |
+
" \n",
|
| 169 |
+
" num_atoms = 0\n",
|
| 170 |
+
" scalar_properties = []\n",
|
| 171 |
+
" atomic_symbols = []\n",
|
| 172 |
+
" xyz = []\n",
|
| 173 |
+
" charges = []\n",
|
| 174 |
+
" harmonic_vibrational_frequencies = []\n",
|
| 175 |
+
" smiles = ''\n",
|
| 176 |
+
" inchi = ''\n",
|
| 177 |
+
" with open(filename, 'r') as f:\n",
|
| 178 |
+
" for line_num, line in enumerate(f):\n",
|
| 179 |
+
" if line_num == 0:\n",
|
| 180 |
+
" num_atoms = int(line)\n",
|
| 181 |
+
" elif line_num == 1:\n",
|
| 182 |
+
" scalar_properties = [float(i) for i in line.split()[2:]]\n",
|
| 183 |
+
" elif 2 <= line_num <= 1 + num_atoms:\n",
|
| 184 |
+
" atom_symbol, x, y, z, charge = line.split()\n",
|
| 185 |
+
" atomic_symbols.append(atom_symbol)\n",
|
| 186 |
+
" xyz.append([parse_float(x), parse_float(y), parse_float(z)])\n",
|
| 187 |
+
" charges.append(parse_float(charge))\n",
|
| 188 |
+
" elif line_num == num_atoms + 2:\n",
|
| 189 |
+
" harmonic_vibrational_frequencies = [float(i) for i in line.split()]\n",
|
| 190 |
+
" elif line_num == num_atoms + 3:\n",
|
| 191 |
+
" smiles = line.split()[0]\n",
|
| 192 |
+
" elif line_num == num_atoms + 4:\n",
|
| 193 |
+
" inchi = line.split()[0]\n",
|
| 194 |
+
"\n",
|
| 195 |
+
" array_wrap = np.array if array_format == 'np' else torch.tensor\n",
|
| 196 |
+
" result = {\n",
|
| 197 |
+
" 'num_atoms': num_atoms,\n",
|
| 198 |
+
" 'atomic_symbols': atomic_symbols,\n",
|
| 199 |
+
" 'pos': array_wrap(xyz),\n",
|
| 200 |
+
" 'charges': array_wrap(charges),\n",
|
| 201 |
+
" 'harmonic_oscillator_frequencies': array_wrap(harmonic_vibrational_frequencies),\n",
|
| 202 |
+
" 'smiles': smiles,\n",
|
| 203 |
+
" 'inchi': inchi\n",
|
| 204 |
+
" }\n",
|
| 205 |
+
" scalar_property_labels = [\n",
|
| 206 |
+
" 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u', 'h', 'g', 'cv'\n",
|
| 207 |
+
" ] \n",
|
| 208 |
+
" scalar_properties = dict(zip(scalar_property_labels, scalar_properties))\n",
|
| 209 |
+
" result.update(scalar_properties)\n",
|
| 210 |
+
"\n",
|
| 211 |
+
" # RdKit\n",
|
| 212 |
+
" result['canonical_smiles'] = rdChem.CanonSmiles(result['smiles'])\n",
|
| 213 |
+
" m = rdChem.MolFromSmiles(result['canonical_smiles'])\n",
|
| 214 |
+
" result['logP'] = Crippen.MolLogP(m)\n",
|
| 215 |
+
" result['qed'] = QED.qed(m)\n",
|
| 216 |
+
" if npscorer_model is not None:\n",
|
| 217 |
+
" result['np_score'] = npscorer.scoreMol(m, npscorer_model)\n",
|
| 218 |
+
" result['sa_score'] = sascorer.calculateScore(m)\n",
|
| 219 |
+
" result.update(count_rings_and_bonds(m, max_ring_size=max_ring_size))\n",
|
| 220 |
+
" \n",
|
| 221 |
+
" return result"
|
| 222 |
+
]
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
"cell_type": "code",
|
| 226 |
+
"execution_count": null,
|
| 227 |
+
"id": "72254d85",
|
| 228 |
+
"metadata": {},
|
| 229 |
+
"outputs": [],
|
| 230 |
+
"source": [
|
| 231 |
+
"\"\"\"\n",
|
| 232 |
+
" Download xyz files from:\n",
|
| 233 |
+
" https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904\n",
|
| 234 |
+
" > wget https://figshare.com/ndownloader/files/3195389/dsgdb9nsd.xyz.tar.bz2\n",
|
| 235 |
+
" > mkdir dsgdb9nsd.xyz\n",
|
| 236 |
+
" > tar -xvjf dsgdb9nsd.xyz.tar.bz2 -C dsgdb9nsd.xyz\n",
|
| 237 |
+
"\"\"\"\n",
|
| 238 |
+
"MAX_RING_SIZE = 9\n",
|
| 239 |
+
"fscore = npscorer.readNPModel()\n",
|
| 240 |
+
"xyz_dir_path = '/Users/yairschiff/Downloads/dsgdb9nsd.xyz'\n",
|
| 241 |
+
"parsed_xyz = []\n",
|
| 242 |
+
"for file in tqdm(sorted(os.listdir(xyz_dir_path)), desc='Parsing'):\n",
|
| 243 |
+
" parsed = parse_xyz(os.path.join(xyz_dir_path, file),\n",
|
| 244 |
+
" max_ring_size=MAX_RING_SIZE,\n",
|
| 245 |
+
" npscorer_model=fscore,\n",
|
| 246 |
+
" array_format='np')\n",
|
| 247 |
+
" parsed_xyz.append(parsed)"
|
| 248 |
+
]
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"cell_type": "code",
|
| 252 |
+
"execution_count": null,
|
| 253 |
+
"id": "12969dd2",
|
| 254 |
+
"metadata": {},
|
| 255 |
+
"outputs": [],
|
| 256 |
+
"source": [
|
| 257 |
+
"qm9_df = pd.DataFrame(data=parsed_xyz)"
|
| 258 |
+
]
|
| 259 |
+
},
|
| 260 |
+
{
|
| 261 |
+
"cell_type": "code",
|
| 262 |
+
"execution_count": null,
|
| 263 |
+
"id": "eed4f163",
|
| 264 |
+
"metadata": {},
|
| 265 |
+
"outputs": [],
|
| 266 |
+
"source": [
|
| 267 |
+
"# Conversion below is needed to avoid:\n",
|
| 268 |
+
"# `ArrowInvalid: ('Can only convert 1-dimensional array values',\n",
|
| 269 |
+
"# 'Conversion failed for column pos with type object')`\n",
|
| 270 |
+
"qm9_df['pos'] = qm9_df['pos'].apply(lambda x: [xi for xi in x])"
|
| 271 |
+
]
|
| 272 |
+
},
|
| 273 |
+
{
|
| 274 |
+
"cell_type": "code",
|
| 275 |
+
"execution_count": null,
|
| 276 |
+
"id": "c912d23a",
|
| 277 |
+
"metadata": {},
|
| 278 |
+
"outputs": [],
|
| 279 |
+
"source": [
|
| 280 |
+
"dataset = datasets.Dataset.from_pandas(qm9_df)"
|
| 281 |
+
]
|
| 282 |
+
},
|
| 283 |
+
{
|
| 284 |
+
"cell_type": "code",
|
| 285 |
+
"execution_count": null,
|
| 286 |
+
"id": "7a7df506",
|
| 287 |
+
"metadata": {},
|
| 288 |
+
"outputs": [],
|
| 289 |
+
"source": [
|
| 290 |
+
"dataset.push_to_hub('yairschiff/qm9')"
|
| 291 |
+
]
|
| 292 |
+
},
|
| 293 |
+
{
|
| 294 |
+
"cell_type": "code",
|
| 295 |
+
"execution_count": null,
|
| 296 |
+
"id": "86c4e1ae",
|
| 297 |
+
"metadata": {},
|
| 298 |
+
"outputs": [],
|
| 299 |
+
"source": [
|
| 300 |
+
"# # Random train/test splits as recommended by:\n",
|
| 301 |
+
"# # https://moleculenet.org/datasets-1\n",
|
| 302 |
+
"# test_size = 0.1\n",
|
| 303 |
+
"# seed = 1\n",
|
| 304 |
+
"# dataset.train_test_split(test_size=test_size, seed=seed)"
|
| 305 |
+
]
|
| 306 |
+
},
|
| 307 |
+
{
|
| 308 |
+
"cell_type": "markdown",
|
| 309 |
+
"id": "e982da1b-05ab-493b-bb82-8bf1225dcb2b",
|
| 310 |
+
"metadata": {},
|
| 311 |
+
"source": [
|
| 312 |
+
"## Create tokenizer"
|
| 313 |
+
]
|
| 314 |
+
},
|
| 315 |
+
{
|
| 316 |
+
"cell_type": "code",
|
| 317 |
+
"execution_count": 7,
|
| 318 |
+
"id": "b0504e77",
|
| 319 |
+
"metadata": {},
|
| 320 |
+
"outputs": [],
|
| 321 |
+
"source": [
|
| 322 |
+
"def smi_tokenizer(smi):\n",
|
| 323 |
+
" \"\"\"Tokenize a SMILES molecule or reaction.\n",
|
| 324 |
+
"\n",
|
| 325 |
+
" Copied from https://github.com/pschwllr/MolecularTransformer.\n",
|
| 326 |
+
" \"\"\"\n",
|
| 327 |
+
" import re\n",
|
| 328 |
+
" pattern = \"(\\[[^\\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\\(|\\)|\\.|=|#|-|\\+|\\\\\\\\|\\/|:|~|@|\\?|>|\\*|\\$|\\%[0-9]{2}|[0-9])\"\n",
|
| 329 |
+
" regex = re.compile(pattern)\n",
|
| 330 |
+
" tokens = [token for token in regex.findall(smi)]\n",
|
| 331 |
+
" assert smi == ''.join(tokens)\n",
|
| 332 |
+
" return tokens"
|
| 333 |
+
]
|
| 334 |
+
},
|
| 335 |
+
{
|
| 336 |
+
"cell_type": "code",
|
| 337 |
+
"execution_count": 8,
|
| 338 |
+
"id": "b89a4def-ea08-466a-8779-24acf75a2bd0",
|
| 339 |
+
"metadata": {},
|
| 340 |
+
"outputs": [],
|
| 341 |
+
"source": [
|
| 342 |
+
"dataset = datasets.load_dataset('yairschiff/qm9', split='train')"
|
| 343 |
+
]
|
| 344 |
+
},
|
| 345 |
+
{
|
| 346 |
+
"cell_type": "code",
|
| 347 |
+
"execution_count": 9,
|
| 348 |
+
"id": "6ef61481-9384-4c1c-8361-ab858cb157ba",
|
| 349 |
+
"metadata": {},
|
| 350 |
+
"outputs": [],
|
| 351 |
+
"source": [
|
| 352 |
+
"# # If the vocab file has not been created yet, uncomment and run this cell\n",
|
| 353 |
+
"\n",
|
| 354 |
+
"# tokens = []\n",
|
| 355 |
+
"# for smi in dataset['canonical_smiles']:\n",
|
| 356 |
+
"# tokens.extend(smi_tokenizer(smi))\n",
|
| 357 |
+
"\n",
|
| 358 |
+
"# with open('qm9_vocab.json', 'w', encoding='utf-8') as f:\n",
|
| 359 |
+
"# f.write(\n",
|
| 360 |
+
"# json.dumps(\n",
|
| 361 |
+
"# {t: i for i, t in enumerate(sorted(set(tokens)))},\n",
|
| 362 |
+
"# indent=2,\n",
|
| 363 |
+
"# sort_keys=True,\n",
|
| 364 |
+
"# ensure_ascii=False\n",
|
| 365 |
+
"# ) + '\\n')"
|
| 366 |
+
]
|
| 367 |
+
},
|
| 368 |
+
{
|
| 369 |
+
"cell_type": "code",
|
| 370 |
+
"execution_count": 9,
|
| 371 |
+
"id": "6af7fccb-08ee-4dc6-99dc-cfa4fc38074c",
|
| 372 |
+
"metadata": {},
|
| 373 |
+
"outputs": [],
|
| 374 |
+
"source": [
|
| 375 |
+
"# # If the HF tokenizer has not been published yet, uncomment and run this cell\n",
|
| 376 |
+
"# import tokenizer\n",
|
| 377 |
+
"\n",
|
| 378 |
+
"# tokenizer.QM9Tokenizer.register_for_auto_class()\n",
|
| 379 |
+
"# qm9_tokenizer = tokenizer.QM9Tokenizer(vocab_file='qm9_vocab.json')\n",
|
| 380 |
+
"# qm9_tokenizer.push_to_hub('yairschiff/qm9-tokenizer')"
|
| 381 |
+
]
|
| 382 |
+
},
|
| 383 |
+
{
|
| 384 |
+
"cell_type": "code",
|
| 385 |
+
"execution_count": 23,
|
| 386 |
+
"id": "4cc39f16-b53c-481a-a35e-a42fb1b08378",
|
| 387 |
+
"metadata": {},
|
| 388 |
+
"outputs": [],
|
| 389 |
+
"source": [
|
| 390 |
+
"# Test tokenizer\n",
|
| 391 |
+
"qm9_tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
|
| 392 |
+
" 'yairschiff/qm9-tokenizer', trust_remote_code=True, resume_download=None)\n",
|
| 393 |
+
"print(dataset[1000]['canonical_smiles'])\n",
|
| 394 |
+
"print(qm9_tokenizer.encode(dataset[1000]['canonical_smiles']))\n",
|
| 395 |
+
"print(qm9_tokenizer.decode(qm9_tokenizer.encode(dataset[1000]['canonical_smiles'])))"
|
| 396 |
+
]
|
| 397 |
+
},
|
| 398 |
+
{
|
| 399 |
+
"cell_type": "code",
|
| 400 |
+
"execution_count": null,
|
| 401 |
+
"id": "41752e94-175e-4f40-b9d2-496241eab0c0",
|
| 402 |
+
"metadata": {},
|
| 403 |
+
"outputs": [],
|
| 404 |
+
"source": []
|
| 405 |
+
}
|
| 406 |
+
],
|
| 407 |
+
"metadata": {
|
| 408 |
+
"kernelspec": {
|
| 409 |
+
"display_name": "Python 3 (ipykernel)",
|
| 410 |
+
"language": "python",
|
| 411 |
+
"name": "python3"
|
| 412 |
+
},
|
| 413 |
+
"language_info": {
|
| 414 |
+
"codemirror_mode": {
|
| 415 |
+
"name": "ipython",
|
| 416 |
+
"version": 3
|
| 417 |
+
},
|
| 418 |
+
"file_extension": ".py",
|
| 419 |
+
"mimetype": "text/x-python",
|
| 420 |
+
"name": "python",
|
| 421 |
+
"nbconvert_exporter": "python",
|
| 422 |
+
"pygments_lexer": "ipython3",
|
| 423 |
+
"version": "3.9.18"
|
| 424 |
+
}
|
| 425 |
+
},
|
| 426 |
+
"nbformat": 4,
|
| 427 |
+
"nbformat_minor": 5
|
| 428 |
+
}
|
notebooks/qm9_vocab.json
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"#": 0,
|
| 3 |
+
"(": 1,
|
| 4 |
+
")": 2,
|
| 5 |
+
"-": 3,
|
| 6 |
+
"1": 4,
|
| 7 |
+
"2": 5,
|
| 8 |
+
"3": 6,
|
| 9 |
+
"4": 7,
|
| 10 |
+
"5": 8,
|
| 11 |
+
"=": 9,
|
| 12 |
+
"C": 10,
|
| 13 |
+
"F": 11,
|
| 14 |
+
"N": 12,
|
| 15 |
+
"O": 13,
|
| 16 |
+
"[C-]": 14,
|
| 17 |
+
"[CH-]": 15,
|
| 18 |
+
"[N+]": 16,
|
| 19 |
+
"[N-]": 17,
|
| 20 |
+
"[NH+]": 18,
|
| 21 |
+
"[NH2+]": 19,
|
| 22 |
+
"[NH3+]": 20,
|
| 23 |
+
"[O-]": 21,
|
| 24 |
+
"[c-]": 22,
|
| 25 |
+
"[cH-]": 23,
|
| 26 |
+
"[n-]": 24,
|
| 27 |
+
"[nH+]": 25,
|
| 28 |
+
"[nH]": 26,
|
| 29 |
+
"c": 27,
|
| 30 |
+
"n": 28,
|
| 31 |
+
"o": 29
|
| 32 |
+
}
|
notebooks/zinc250k_data_prep.ipynb
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "fa328603-9e2b-4643-8500-ec11c51b5223",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"## Imports and setup"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": 3,
|
| 14 |
+
"id": "7716fb32-a805-4888-9dac-da4cff4f6e40",
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"import os\n",
|
| 19 |
+
"import huggingface_hub"
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "code",
|
| 24 |
+
"execution_count": 4,
|
| 25 |
+
"id": "432e1636",
|
| 26 |
+
"metadata": {},
|
| 27 |
+
"outputs": [
|
| 28 |
+
{
|
| 29 |
+
"name": "stdout",
|
| 30 |
+
"output_type": "stream",
|
| 31 |
+
"text": [
|
| 32 |
+
"The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
|
| 33 |
+
"Token is valid (permission: write).\n",
|
| 34 |
+
"Your token has been saved to /share/kuleshov/yzs2/discrete-guidance/.hf_cache/token\n",
|
| 35 |
+
"Login successful\n"
|
| 36 |
+
]
|
| 37 |
+
}
|
| 38 |
+
],
|
| 39 |
+
"source": [
|
| 40 |
+
"if os.path.exists(os.path.join(os.environ['HF_HOME'], 'token')):\n",
|
| 41 |
+
" with open(os.path.join(os.environ['HF_HOME'], 'token'), 'r') as f:\n",
|
| 42 |
+
" token = f.read().strip()\n",
|
| 43 |
+
"else:\n",
|
| 44 |
+
" token = None\n",
|
| 45 |
+
"huggingface_hub.login(token=token)"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "code",
|
| 50 |
+
"execution_count": 26,
|
| 51 |
+
"id": "e22e86ae",
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"outputs": [],
|
| 54 |
+
"source": [
|
| 55 |
+
"import json\n",
|
| 56 |
+
"import re\n",
|
| 57 |
+
"import typing\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"import datasets\n",
|
| 60 |
+
"import numpy as np\n",
|
| 61 |
+
"import pandas as pd\n",
|
| 62 |
+
"import rdkit\n",
|
| 63 |
+
"import transformers\n",
|
| 64 |
+
"from rdkit import Chem as rdChem\n",
|
| 65 |
+
"from tqdm.auto import tqdm"
|
| 66 |
+
]
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"cell_type": "code",
|
| 70 |
+
"execution_count": 7,
|
| 71 |
+
"id": "aaa00828",
|
| 72 |
+
"metadata": {},
|
| 73 |
+
"outputs": [],
|
| 74 |
+
"source": [
|
| 75 |
+
"# TODO: Update to 2024.03.6 release when available instead of suppressing warning!\n",
|
| 76 |
+
"# See: https://github.com/rdkit/rdkit/issues/7625#\n",
|
| 77 |
+
"rdkit.rdBase.DisableLog('rdApp.warning')"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "markdown",
|
| 82 |
+
"id": "0a878a71-d33f-43fe-955d-4250950b1eec",
|
| 83 |
+
"metadata": {
|
| 84 |
+
"jp-MarkdownHeadingCollapsed": true
|
| 85 |
+
},
|
| 86 |
+
"source": [
|
| 87 |
+
"## Create dataset"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "code",
|
| 92 |
+
"execution_count": null,
|
| 93 |
+
"id": "26856fe2",
|
| 94 |
+
"metadata": {},
|
| 95 |
+
"outputs": [],
|
| 96 |
+
"source": [
|
| 97 |
+
"def count_rings_and_bonds(\n",
|
| 98 |
+
" mol: rdChem.Mol\n",
|
| 99 |
+
") -> typing.Dict[str, int]:\n",
|
| 100 |
+
"    \"\"\"Counts rings (total and by size) and bonds (by type).\"\"\"\n",
|
| 101 |
+
" \n",
|
| 102 |
+
" # Counting rings\n",
|
| 103 |
+
" ssr = rdChem.GetSymmSSSR(mol)\n",
|
| 104 |
+
" ring_count = len(ssr)\n",
|
| 105 |
+
" \n",
|
| 106 |
+
" ring_sizes = {}\n",
|
| 107 |
+
" for ring in ssr:\n",
|
| 108 |
+
" ring_size = len(ring)\n",
|
| 109 |
+
" if ring_size not in ring_sizes:\n",
|
| 110 |
+
" ring_sizes[ring_size] = 0\n",
|
| 111 |
+
" ring_sizes[ring_size] += 1\n",
|
| 112 |
+
" \n",
|
| 113 |
+
" # Counting bond types\n",
|
| 114 |
+
" bond_counts = {\n",
|
| 115 |
+
" 'single': 0,\n",
|
| 116 |
+
" 'double': 0,\n",
|
| 117 |
+
" 'triple': 0,\n",
|
| 118 |
+
" 'aromatic': 0\n",
|
| 119 |
+
" }\n",
|
| 120 |
+
" \n",
|
| 121 |
+
" for bond in mol.GetBonds():\n",
|
| 122 |
+
" if bond.GetIsAromatic():\n",
|
| 123 |
+
" bond_counts['aromatic'] += 1\n",
|
| 124 |
+
" elif bond.GetBondType() == rdChem.BondType.SINGLE:\n",
|
| 125 |
+
" bond_counts['single'] += 1\n",
|
| 126 |
+
" elif bond.GetBondType() == rdChem.BondType.DOUBLE:\n",
|
| 127 |
+
" bond_counts['double'] += 1\n",
|
| 128 |
+
" elif bond.GetBondType() == rdChem.BondType.TRIPLE:\n",
|
| 129 |
+
" bond_counts['triple'] += 1\n",
|
| 130 |
+
" result = {\n",
|
| 131 |
+
" 'ring_count': ring_count,\n",
|
| 132 |
+
" }\n",
|
| 133 |
+
" for k, v in ring_sizes.items():\n",
|
| 134 |
+
" result[f\"R{k}\"] = v\n",
|
| 135 |
+
"\n",
|
| 136 |
+
" for k, v in bond_counts.items():\n",
|
| 137 |
+
" result[f\"{k}_bond\"] = v\n",
|
| 138 |
+
" return result"
|
| 139 |
+
]
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"cell_type": "code",
|
| 143 |
+
"execution_count": null,
|
| 144 |
+
"id": "fbde53f7",
|
| 145 |
+
"metadata": {},
|
| 146 |
+
"outputs": [],
|
| 147 |
+
"source": [
|
| 148 |
+
"\"\"\"\n",
|
| 149 |
+
" Download data and validation indices from:\n",
|
| 150 |
+
" \"Score-based Generative Modeling of Graphs via the System of Stochastic Differential Equations\"\n",
|
| 151 |
+
" https://github.com/harryjo97/GDSS\n",
|
| 152 |
+
"    > wget https://raw.githubusercontent.com/harryjo97/GDSS/master/data/zinc250k.csv\n",
|
| 153 |
+
" > wget https://raw.githubusercontent.com/harryjo97/GDSS/master/data/valid_idx_zinc250k.json\n",
|
| 154 |
+
"\"\"\"\n",
|
| 155 |
+
"df = pd.read_csv('/Users/yairschiff/Downloads/zinc250k.csv', index_col=0, encoding='utf_8')\n",
|
| 156 |
+
"feats = []\n",
|
| 157 |
+
"for i, row in tqdm(df.iterrows(), total=len(df), desc='RDKit feats', leave=False):\n",
|
| 158 |
+
" feat = {'smiles': row['smiles']}\n",
|
| 159 |
+
" feat['canonical_smiles'] = rdChem.CanonSmiles(feat['smiles'])\n",
|
| 160 |
+
" m = rdChem.MolFromSmiles(feat['canonical_smiles'])\n",
|
| 161 |
+
" feat.update(count_rings_and_bonds(m))\n",
|
| 162 |
+
" feats.append(feat)\n",
|
| 163 |
+
"df = pd.merge(df, pd.DataFrame.from_records(feats), on='smiles')\n",
|
| 164 |
+
"df = df.fillna(0)\n",
|
| 165 |
+
"for col in df.columns: # recast ring counts as int\n",
|
| 166 |
+
" if re.search(\"^R[0-9]+$\", col) is not None:\n",
|
| 167 |
+
" df[col] = df[col].astype(int)\n",
|
| 168 |
+
"# Re-order columns\n",
|
| 169 |
+
"df = df[\n",
|
| 170 |
+
" ['smiles', 'logP', 'qed', 'SAS', 'canonical_smiles',\n",
|
| 171 |
+
" 'single_bond', 'double_bond', 'triple_bond', 'aromatic_bond',\n",
|
| 172 |
+
" 'ring_count','R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R12', 'R13', 'R14', 'R15', 'R18', 'R24']]"
|
| 173 |
+
]
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
"cell_type": "code",
|
| 177 |
+
"execution_count": null,
|
| 178 |
+
"id": "1e2d5955",
|
| 179 |
+
"metadata": {},
|
| 180 |
+
"outputs": [],
|
| 181 |
+
"source": [
|
| 182 |
+
"# Read in validation indices\n",
|
| 183 |
+
"with open('/Users/yairschiff/Downloads/valid_idx_zinc250k.json', 'r') as f:\n",
|
| 184 |
+
" valid_idxs = json.load(f)\n",
|
| 185 |
+
"df['validation'] = df.index.isin(valid_idxs).astype(int)"
|
| 186 |
+
]
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"cell_type": "code",
|
| 190 |
+
"execution_count": null,
|
| 191 |
+
"id": "2b89b732",
|
| 192 |
+
"metadata": {},
|
| 193 |
+
"outputs": [],
|
| 194 |
+
"source": [
|
| 195 |
+
"# Create HF dataset\n",
|
| 196 |
+
"dataset = datasets.DatasetDict({\n",
|
| 197 |
+
" 'train': datasets.Dataset.from_pandas(df[df['validation'] == 0].drop(columns=['validation'])),\n",
|
| 198 |
+
" 'validation': datasets.Dataset.from_pandas(df[df['validation'] == 1].drop(columns=['validation'])),\n",
|
| 199 |
+
"})\n",
|
| 200 |
+
"dataset = dataset.remove_columns('__index_level_0__')"
|
| 201 |
+
]
|
| 202 |
+
},
|
| 203 |
+
{
|
| 204 |
+
"cell_type": "code",
|
| 205 |
+
"execution_count": null,
|
| 206 |
+
"id": "1efb5845",
|
| 207 |
+
"metadata": {},
|
| 208 |
+
"outputs": [],
|
| 209 |
+
"source": [
|
| 210 |
+
"dataset.push_to_hub('yairschiff/zinc250k')"
|
| 211 |
+
]
|
| 212 |
+
},
|
| 213 |
+
{
|
| 214 |
+
"cell_type": "markdown",
|
| 215 |
+
"id": "5c6f357d-20d9-4004-8091-68726b6b4c86",
|
| 216 |
+
"metadata": {},
|
| 217 |
+
"source": [
|
| 218 |
+
"## Create tokenizer"
|
| 219 |
+
]
|
| 220 |
+
},
|
| 221 |
+
{
|
| 222 |
+
"cell_type": "code",
|
| 223 |
+
"execution_count": 8,
|
| 224 |
+
"id": "6642fc9d-4863-4e14-947b-95bae48e192d",
|
| 225 |
+
"metadata": {},
|
| 226 |
+
"outputs": [],
|
| 227 |
+
"source": [
|
| 228 |
+
"def smi_tokenizer(smi):\n",
|
| 229 |
+
" \"\"\"Tokenize a SMILES molecule or reaction.\n",
|
| 230 |
+
"\n",
|
| 231 |
+
" Copied from https://github.com/pschwllr/MolecularTransformer.\n",
|
| 232 |
+
" \"\"\"\n",
|
| 233 |
+
" import re\n",
|
| 234 |
+
" pattern = \"(\\[[^\\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\\(|\\)|\\.|=|#|-|\\+|\\\\\\\\|\\/|:|~|@|\\?|>|\\*|\\$|\\%[0-9]{2}|[0-9])\"\n",
|
| 235 |
+
" regex = re.compile(pattern)\n",
|
| 236 |
+
" tokens = [token for token in regex.findall(smi)]\n",
|
| 237 |
+
" assert smi == ''.join(tokens)\n",
|
| 238 |
+
" return tokens"
|
| 239 |
+
]
|
| 240 |
+
},
|
| 241 |
+
{
|
| 242 |
+
"cell_type": "code",
|
| 243 |
+
"execution_count": 11,
|
| 244 |
+
"id": "3a9e2e60-8596-4a91-acc3-d43e166ce723",
|
| 245 |
+
"metadata": {},
|
| 246 |
+
"outputs": [],
|
| 247 |
+
"source": [
|
| 248 |
+
"dataset = datasets.load_dataset('yairschiff/zinc250k')"
|
| 249 |
+
]
|
| 250 |
+
},
|
| 251 |
+
{
|
| 252 |
+
"cell_type": "code",
|
| 253 |
+
"execution_count": 12,
|
| 254 |
+
"id": "fbd5c2fe-4318-46bb-bc43-6ef7fe76e9fd",
|
| 255 |
+
"metadata": {},
|
| 256 |
+
"outputs": [],
|
| 257 |
+
"source": [
|
| 258 |
+
"# # If vocab file not created yet, uncomment and run this cell\n",
|
| 259 |
+
"\n",
|
| 260 |
+
"# tokens = []\n",
|
| 261 |
+
"# for split in dataset.keys():\n",
|
| 262 |
+
"# for smi in dataset[split]['canonical_smiles']:\n",
|
| 263 |
+
"# tokens.extend(smi_tokenizer(smi))\n",
|
| 264 |
+
"\n",
|
| 265 |
+
"# with open('zinc250k_vocab.json', 'w', encoding='utf-8') as f:\n",
|
| 266 |
+
"# f.write(\n",
|
| 267 |
+
"# json.dumps(\n",
|
| 268 |
+
"# {t: i for i, t in enumerate(sorted(set(tokens)))},\n",
|
| 269 |
+
"# indent=2,\n",
|
| 270 |
+
"# sort_keys=True,\n",
|
| 271 |
+
"# ensure_ascii=False\n",
|
| 272 |
+
"# ) + '\\n')"
|
| 273 |
+
]
|
| 274 |
+
},
|
| 275 |
+
{
|
| 276 |
+
"cell_type": "code",
|
| 277 |
+
"execution_count": 14,
|
| 278 |
+
"id": "4962478b-5343-4838-befe-64a5389625d4",
|
| 279 |
+
"metadata": {},
|
| 280 |
+
"outputs": [
|
| 281 |
+
{
|
| 282 |
+
"data": {
|
| 283 |
+
"text/plain": [
|
| 284 |
+
"CommitInfo(commit_url='https://huggingface.co/yairschiff/zinc250k-tokenizer/commit/7a07b0165a8a4f14f09d6137da8cdabf789397fd', commit_message='Upload tokenizer', commit_description='', oid='7a07b0165a8a4f14f09d6137da8cdabf789397fd', pr_url=None, pr_revision=None, pr_num=None)"
|
| 285 |
+
]
|
| 286 |
+
},
|
| 287 |
+
"execution_count": 14,
|
| 288 |
+
"metadata": {},
|
| 289 |
+
"output_type": "execute_result"
|
| 290 |
+
}
|
| 291 |
+
],
|
| 292 |
+
"source": [
|
| 293 |
+
"# # If HF tokenizer not yet published, uncomment and run this cell\n",
|
| 294 |
+
"# import tokenizer\n",
|
| 295 |
+
"\n",
|
| 296 |
+
"# tokenizer.Zinc250kTokenizer.register_for_auto_class()\n",
|
| 297 |
+
"# zinc250k_tokenizer = tokenizer.Zinc250kTokenizer(vocab_file='zinc250k_vocab.json')\n",
|
| 298 |
+
"# zinc250k_tokenizer.push_to_hub('yairschiff/zinc250k-tokenizer')"
|
| 299 |
+
]
|
| 300 |
+
},
|
| 301 |
+
{
|
| 302 |
+
"cell_type": "code",
|
| 303 |
+
"execution_count": 18,
|
| 304 |
+
"id": "a779aa57-0c9d-4b8c-bf11-ccc5ab4c462e",
|
| 305 |
+
"metadata": {},
|
| 306 |
+
"outputs": [
|
| 307 |
+
{
|
| 308 |
+
"name": "stdout",
|
| 309 |
+
"output_type": "stream",
|
| 310 |
+
"text": [
|
| 311 |
+
"Cn1ncc2c1CCC[C@H]2NC(=O)NC[C@H](O)COc1ccc(F)cc1\n",
|
| 312 |
+
"[0, 25, 69, 15, 69, 68, 68, 16, 68, 15, 25, 25, 25, 35, 16, 29, 25, 11, 23, 30, 12, 29, 25, 35, 11, 30, 12, 25, 30, 68, 15, 68, 68, 68, 11, 27, 12, 68, 68, 15, 1]\n",
|
| 313 |
+
"<bos>Cn1ncc2c1CCC[C@H]2NC(=O)NC[C@H](O)COc1ccc(F)cc1<eos>\n",
|
| 314 |
+
"Cn1ncc2c1CCC[C@H]2NC(=O)NC[C@H](O)COc1ccc(F)cc1\n"
|
| 315 |
+
]
|
| 316 |
+
}
|
| 317 |
+
],
|
| 318 |
+
"source": [
|
| 319 |
+
"# Test tokenizer\n",
|
| 320 |
+
"zinc250k_tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
|
| 321 |
+
" 'yairschiff/zinc250k-tokenizer', trust_remote_code=True, resume_download=None)\n",
|
| 322 |
+
"print(dataset['train'][1000]['canonical_smiles'])\n",
|
| 323 |
+
"print(zinc250k_tokenizer.encode(dataset['train'][1000]['canonical_smiles']))\n",
|
| 324 |
+
"print(zinc250k_tokenizer.decode(zinc250k_tokenizer.encode(dataset['train'][1000]['canonical_smiles'])))\n",
|
| 325 |
+
"print(zinc250k_tokenizer.decode(zinc250k_tokenizer.encode(dataset['train'][1000]['canonical_smiles'], add_special_tokens=False)))"
|
| 326 |
+
]
|
| 327 |
+
},
|
| 328 |
+
{
|
| 329 |
+
"cell_type": "code",
|
| 330 |
+
"execution_count": 28,
|
| 331 |
+
"id": "f3a15585-8e75-409d-9afe-0e7fe4a0bffc",
|
| 332 |
+
"metadata": {},
|
| 333 |
+
"outputs": [
|
| 334 |
+
{
|
| 335 |
+
"data": {
|
| 336 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 337 |
+
"model_id": "",
|
| 338 |
+
"version_major": 2,
|
| 339 |
+
"version_minor": 0
|
| 340 |
+
},
|
| 341 |
+
"text/plain": [
|
| 342 |
+
" 0%| | 0/224568 [00:00<?, ?it/s]"
|
| 343 |
+
]
|
| 344 |
+
},
|
| 345 |
+
"metadata": {},
|
| 346 |
+
"output_type": "display_data"
|
| 347 |
+
},
|
| 348 |
+
{
|
| 349 |
+
"data": {
|
| 350 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 351 |
+
"model_id": "",
|
| 352 |
+
"version_major": 2,
|
| 353 |
+
"version_minor": 0
|
| 354 |
+
},
|
| 355 |
+
"text/plain": [
|
| 356 |
+
" 0%| | 0/24887 [00:00<?, ?it/s]"
|
| 357 |
+
]
|
| 358 |
+
},
|
| 359 |
+
"metadata": {},
|
| 360 |
+
"output_type": "display_data"
|
| 361 |
+
},
|
| 362 |
+
{
|
| 363 |
+
"name": "stdout",
|
| 364 |
+
"output_type": "stream",
|
| 365 |
+
"text": [
|
| 366 |
+
"(array([ 152, 3351, 21311, 47185, 67972, 70367, 25030, 11778, 2179,\n",
|
| 367 |
+
" 130]), array([10. , 16.4, 22.8, 29.2, 35.6, 42. , 48.4, 54.8, 61.2, 67.6, 74. ]))\n",
|
| 368 |
+
"10\n",
|
| 369 |
+
"74\n"
|
| 370 |
+
]
|
| 371 |
+
}
|
| 372 |
+
],
|
| 373 |
+
"source": [
|
| 374 |
+
"lengths = [len(zinc250k_tokenizer.encode(i['canonical_smiles'])) for i in tqdm(dataset['train'], leave=False)]\n",
|
| 375 |
+
"lengths += [len(zinc250k_tokenizer.encode(i['canonical_smiles'])) for i in tqdm(dataset['validation'], leave=False)]\n",
|
| 376 |
+
"print(np.histogram(lengths))\n",
|
| 377 |
+
"print(min(lengths))\n",
|
| 378 |
+
"print(max(lengths))"
|
| 379 |
+
]
|
| 380 |
+
},
|
| 381 |
+
{
|
| 382 |
+
"cell_type": "code",
|
| 383 |
+
"execution_count": null,
|
| 384 |
+
"id": "d7a6e081-4961-4cf4-a19d-0375bedd7dab",
|
| 385 |
+
"metadata": {},
|
| 386 |
+
"outputs": [],
|
| 387 |
+
"source": []
|
| 388 |
+
}
|
| 389 |
+
],
|
| 390 |
+
"metadata": {
|
| 391 |
+
"kernelspec": {
|
| 392 |
+
"display_name": "Python 3 (ipykernel)",
|
| 393 |
+
"language": "python",
|
| 394 |
+
"name": "python3"
|
| 395 |
+
},
|
| 396 |
+
"language_info": {
|
| 397 |
+
"codemirror_mode": {
|
| 398 |
+
"name": "ipython",
|
| 399 |
+
"version": 3
|
| 400 |
+
},
|
| 401 |
+
"file_extension": ".py",
|
| 402 |
+
"mimetype": "text/x-python",
|
| 403 |
+
"name": "python",
|
| 404 |
+
"nbconvert_exporter": "python",
|
| 405 |
+
"pygments_lexer": "ipython3",
|
| 406 |
+
"version": "3.9.18"
|
| 407 |
+
}
|
| 408 |
+
},
|
| 409 |
+
"nbformat": 4,
|
| 410 |
+
"nbformat_minor": 5
|
| 411 |
+
}
|
notebooks/zinc250k_vocab.json
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"#": 0,
|
| 3 |
+
"(": 1,
|
| 4 |
+
")": 2,
|
| 5 |
+
"-": 3,
|
| 6 |
+
"/": 4,
|
| 7 |
+
"1": 5,
|
| 8 |
+
"2": 6,
|
| 9 |
+
"3": 7,
|
| 10 |
+
"4": 8,
|
| 11 |
+
"5": 9,
|
| 12 |
+
"6": 10,
|
| 13 |
+
"7": 11,
|
| 14 |
+
"8": 12,
|
| 15 |
+
"=": 13,
|
| 16 |
+
"Br": 14,
|
| 17 |
+
"C": 15,
|
| 18 |
+
"Cl": 16,
|
| 19 |
+
"F": 17,
|
| 20 |
+
"I": 18,
|
| 21 |
+
"N": 19,
|
| 22 |
+
"O": 20,
|
| 23 |
+
"P": 21,
|
| 24 |
+
"S": 22,
|
| 25 |
+
"[C@@H]": 23,
|
| 26 |
+
"[C@@]": 24,
|
| 27 |
+
"[C@H]": 25,
|
| 28 |
+
"[C@]": 26,
|
| 29 |
+
"[CH-]": 27,
|
| 30 |
+
"[CH2-]": 28,
|
| 31 |
+
"[N+]": 29,
|
| 32 |
+
"[N-]": 30,
|
| 33 |
+
"[NH+]": 31,
|
| 34 |
+
"[NH-]": 32,
|
| 35 |
+
"[NH2+]": 33,
|
| 36 |
+
"[NH3+]": 34,
|
| 37 |
+
"[O+]": 35,
|
| 38 |
+
"[O-]": 36,
|
| 39 |
+
"[OH+]": 37,
|
| 40 |
+
"[P+]": 38,
|
| 41 |
+
"[P@@H]": 39,
|
| 42 |
+
"[P@@]": 40,
|
| 43 |
+
"[P@]": 41,
|
| 44 |
+
"[PH+]": 42,
|
| 45 |
+
"[PH2]": 43,
|
| 46 |
+
"[PH]": 44,
|
| 47 |
+
"[S+]": 45,
|
| 48 |
+
"[S-]": 46,
|
| 49 |
+
"[S@@+]": 47,
|
| 50 |
+
"[S@@]": 48,
|
| 51 |
+
"[S@]": 49,
|
| 52 |
+
"[SH+]": 50,
|
| 53 |
+
"[n+]": 51,
|
| 54 |
+
"[n-]": 52,
|
| 55 |
+
"[nH+]": 53,
|
| 56 |
+
"[nH]": 54,
|
| 57 |
+
"[o+]": 55,
|
| 58 |
+
"[s+]": 56,
|
| 59 |
+
"\\": 57,
|
| 60 |
+
"c": 58,
|
| 61 |
+
"n": 59,
|
| 62 |
+
"o": 60,
|
| 63 |
+
"s": 61
|
| 64 |
+
}
|