Sophia commited on
Commit ·
8019be0
0
Parent(s):
initial commit
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +2 -0
- .gitignore +16 -0
- LICENSE +21 -0
- README.md +62 -0
- a2d2_mol/README.md +132 -0
- a2d2_mol/config_mol.yaml +54 -0
- a2d2_mol/evaluate_mol_table.py +308 -0
- a2d2_mol/finetune_mol.py +747 -0
- a2d2_mol/inference_quality_mol.py +554 -0
- a2d2_mol/mol_dataset.py +379 -0
- a2d2_mol/mol_scoring/oracle/fpscores.pkl +3 -0
- a2d2_mol/mol_scoring/scoring_functions.py +68 -0
- a2d2_mol/mol_utils/bracket_safe_converter.py +159 -0
- a2d2_mol/mol_utils/utils.py +135 -0
- a2d2_mol/mol_utils/utils_chem.py +187 -0
- a2d2_mol/oracle/fpscores.pkl +3 -0
- a2d2_mol/remasking_scheduleaware.py +177 -0
- a2d2_mol/sampling.py +1401 -0
- a2d2_mol/scripts/run_mol_finetune.slurm +200 -0
- a2d2_mol/scripts/train_mol.sh +93 -0
- a2d2_mol/train.py +216 -0
- a2d2_pep/README.md +145 -0
- a2d2_pep/config_pep.yaml +50 -0
- a2d2_pep/data/dataloading_for_dynamic_batching.py +189 -0
- a2d2_pep/data/dataset.py +207 -0
- a2d2_pep/evaluate_peptide_table.py +326 -0
- a2d2_pep/finetune_quality.py +892 -0
- a2d2_pep/inference_quality.py +605 -0
- a2d2_pep/pep_scoring/functions/binding.py +178 -0
- a2d2_pep/pep_scoring/functions/binding_utils.py +290 -0
- a2d2_pep/pep_scoring/functions/hemolysis.py +63 -0
- a2d2_pep/pep_scoring/functions/nonfouling.py +66 -0
- a2d2_pep/pep_scoring/functions/permeability.py +170 -0
- a2d2_pep/pep_scoring/functions/scoring_utils.py +94 -0
- a2d2_pep/pep_scoring/functions/solubility.py +63 -0
- a2d2_pep/pep_scoring/scoring_functions.py +79 -0
- a2d2_pep/pep_scoring/tokenizer/my_tokenizers.py +424 -0
- a2d2_pep/pep_utils/analyzer.py +1274 -0
- a2d2_pep/pep_utils/utils.py +135 -0
- a2d2_pep/remasking_scheduleaware.py +181 -0
- a2d2_pep/sampling.py +1401 -0
- a2d2_pep/scripts/run_peptide_finetune.slurm +210 -0
- a2d2_pep/scripts/train_pep.sh +93 -0
- a2d2_pep/train.py +216 -0
- assets/a2d2.gif +3 -0
- demo/quality_inference_demo.ipynb +0 -0
- environment.yml +57 -0
- lightning_modules/__init__.py +16 -0
- lightning_modules/any_length_remask.py +801 -0
- lightning_modules/any_order.py +417 -0
.gitattributes
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
checkpoints/
|
| 2 |
+
pretrained/
|
| 3 |
+
__pycache__/
|
| 4 |
+
results/
|
| 5 |
+
a2d2_language/
|
| 6 |
+
a2d2_language/wandb/
|
| 7 |
+
a2d2_pep/wandb/
|
| 8 |
+
a2d2_mol/wandb/
|
| 9 |
+
logs/
|
| 10 |
+
*.pt
|
| 11 |
+
*.pyc
|
| 12 |
+
*.out
|
| 13 |
+
*.json
|
| 14 |
+
*.log
|
| 15 |
+
*.txt
|
| 16 |
+
*.wandb
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Sophia Tang
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [A2D2: Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding](https://arxiv.org/abs/2606.13565) 🃏🔮
|
| 2 |
+
|
| 3 |
+
[**Sophia Tang**](https://sophtang.github.io/), [**Yuchen Zhu**](https://yuchen-zhu-zyc.github.io/), [**Molei Tao**](https://mtao8.math.gatech.edu/), and [**Pranam Chatterjee**](https://www.chatterjeelab.com/)
|
| 4 |
+
|
| 5 |
+
<p>
|
| 6 |
+
<a href="https://arxiv.org/abs/2606.13565"><img src="https://img.shields.io/badge/arXiv-6B67EE?style=for-the-badge&logo=arxiv&logoColor=white" alt="arXiv"></a>
|
| 7 |
+
<a href="https://sophtang.github.io/a2d2/"><img src="https://img.shields.io/badge/Project_Page-6B67EE?style=for-the-badge&logo=data:image/svg+xml;base64,PHN2ZyByb2xlPSJpbWciIHZpZXdCb3g9IjAgMCAyNCAyNCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIiBmaWxsPSJ3aGl0ZSI+PHBhdGggZD0iTTEwLjUgMS41QzExIDcuMyAxMy4yIDkuNSAxOSAxMEMxMy4yIDEwLjUgMTEgMTIuNyAxMC41IDE4LjVDMTAgMTIuNyA3LjggMTAuNSAyIDEwQzcuOCA5LjUgMTAgNy4zIDEwLjUgMS41WiIvPjxwYXRoIGQ9Ik0xOC41IDEzLjVDMTguNyAxNS44IDE5LjcgMTYuOCAyMiAxN0MxOS43IDE3LjIgMTguNyAxOC4yIDE4LjUgMjAuNUMxOC4zIDE4LjIgMTcuMyAxNy4yIDE1IDE3QzE3LjMgMTYuOCAxOC4zIDE1LjggMTguNSAxMy41WiIvPjxwYXRoIGQ9Ik01IDE1LjVDNS4xMiAxNyA1LjUgMTcuMzggNyAxNy41QzUuNSAxNy42MiA1LjEyIDE4IDUgMTkuNUM0Ljg4IDE4IDQuNSAxNy42MiAzIDE3LjVDNC41IDE3LjM4IDQuODggMTcgNSAxNS41WiIvPjwvc3ZnPg==" alt="Project Page"></a>
|
| 8 |
+
</p>
|
| 9 |
+
|
| 10 |
+

|
| 11 |
+
|
| 12 |
+
This is the repository for the paper [**A2D2: Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding**](https://arxiv.org/abs/2606.13565).
|
| 13 |
+
|
| 14 |
+
Masked discrete diffusion models (MDMs) offer a simple, stable likelihood-based framework for sequence generation, recently extended to **any-length** settings via token insertion. **A2D2** is a unified framework for reward-guided fine-tuning of any-length MDMs that **jointly optimizes the insertion and unmasking policies together with a quality-based inference schedule**, converging to the intractable reward-tilted distribution without requiring target samples.
|
| 15 |
+
|
| 16 |
+
🃏 We derive the **Radon–Nikodym derivative** for the joint insertion–unmasking path measures, enabling theoretically guaranteed convergence to the reward-tilted sequence distribution.
|
| 17 |
+
|
| 18 |
+
🃏 We establish **unmasking and insertion quality** as tractable approaches for minimizing decoding error (compounding parallelization error), and train lightweight quality predictors alongside the policy.
|
| 19 |
+
|
| 20 |
+
🃏 We introduce the **Adaptive Joint Decoding (AJD)** loss, which provably yields the optimal path measure that generates the reward-tilted distribution while remasking low-quality tokens and dropping low-quality insertions at inference.
|
| 21 |
+
|
| 22 |
+
🃏 Empirically, A2D2 improves reward optimization while enhancing generation **flexibility** and **accuracy** over prior fixed-length fine-tuning and inference-time guidance methods.
|
| 23 |
+
|
| 24 |
+
## Drug-Like Small Molecule Design 🧪
|
| 25 |
+
|
| 26 |
+
We pre-train an any-length MDM on the **SAFE** dataset ([Noutahi et al. 2024](https://arxiv.org/abs/2310.10773), ~950M molecules from ZINC and Unichem in SAFE notation) and fine-tune it with **A2D2** to optimize **QED** (drug-likeness) and **synthetic accessibility (SA)**. A2D2 jointly raises QED and lowers SA over the pre-trained baseline while increasing the fraction of valid, unique, drug-like, and synthesizable molecules. Code and instructions are in [`/a2d2_mol`](a2d2_mol).
|
| 27 |
+
|
| 28 |
+
## Multi-Objective Therapeutic Peptide Generation 💉
|
| 29 |
+
|
| 30 |
+
We pre-train an any-length **peptide SMILES** MDM on ~11M peptides (CycPeptMPDB, SmProt, CycloPs) and fine-tune with **A2D2** on five therapeutic properties simultaneously: **target-protein binding affinity, solubility, non-hemolysis, non-fouling, and permeability**. A2D2 outperforms inference-time multi-objective guidance and fixed-length off-policy RL fine-tuning on almost all objectives, while improving the fraction of valid peptides. Code and instructions are in [`/a2d2_pep`](a2d2_pep).
|
| 31 |
+
|
| 32 |
+
## Language Model Reasoning 🧠
|
| 33 |
+
|
| 34 |
+
We additionally apply **A2D2** to reward fine-tuning of any-length language MDMs (LLaDA / FlexMDM), optimizing math-reasoning correctness and format rewards (GSM8K / MATH), including infilling variants. Code will be released in [`/a2d2_language`](a2d2_language) (coming soon).
|
| 35 |
+
|
| 36 |
+
## Repository Structure
|
| 37 |
+
|
| 38 |
+
| Directory | Experiment |
|
| 39 |
+
|-----------|------------|
|
| 40 |
+
| [`a2d2_mol`](a2d2_mol) | Drug-like small molecule design (QED, SA) |
|
| 41 |
+
| [`a2d2_pep`](a2d2_pep) | Multi-objective therapeutic peptide generation |
|
| 42 |
+
| [`a2d2_language`](a2d2_language) | Language model reasoning reward fine-tuning (code soon) |
|
| 43 |
+
| [`lightning_modules`](lightning_modules) | Any-length insertion MDM Lightning modules (policy + quality predictors) |
|
| 44 |
+
| [`model`](model) | Shared model architecture |
|
| 45 |
+
| [`demo`](demo) | Quality-guided inference demo notebook |
|
| 46 |
+
|
| 47 |
+
Each experiment directory contains its own `README.md` with environment setup, pretrained weight placement, fine-tuning commands, and evaluation instructions.
|
| 48 |
+
|
| 49 |
+
## Citation
|
| 50 |
+
|
| 51 |
+
If you find this repository helpful for your publications, please consider citing our paper:
|
| 52 |
+
|
| 53 |
+
```bibtex
|
| 54 |
+
@article{tang2026a2d2,
|
| 55 |
+
title={A2D2: Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding},
|
| 56 |
+
author={Sophia Tang and Yuchen Zhu and Molei Tao and Pranam Chatterjee},
|
| 57 |
+
journal={arXiv preprint arXiv:2606.13565},
|
| 58 |
+
year={2026}
|
| 59 |
+
}
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
To use this repository, you agree to abide by the MIT License.
|
a2d2_mol/README.md
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# A2D2 for Molecule Generation 🧪
|
| 2 |
+
|
| 3 |
+
This part of the code fine-tunes an **any-length masked diffusion model (MDM)** over molecules with **A2D2** (Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding) to optimize drug-likeness rewards (QED, and optionally synthetic accessibility / SA).
|
| 4 |
+
|
| 5 |
+
A2D2 jointly fine-tunes the insertion and unmasking policies together with **insertion and unmasking quality predictors**, generating molecules via **Adaptive Joint Decoding (AJD)** that remasks low-quality tokens and drops low-quality insertions to sample from the reward-tilted distribution while preserving generation quality.
|
| 6 |
+
|
| 7 |
+
Molecules are represented as [SAFE](https://github.com/datamol-io/safe) strings and tokenized with the `datamol-io/safe-gpt` tokenizer.
|
| 8 |
+
|
| 9 |
+
The codebase is partially built upon [FlexMDM (Kim et al., 2025)](https://github.com/brianlck/FlexMDM/tree/main) and [TR2-D2 (Tang et al., 2025)](https://github.com/sophtang/TR2-D2/tree/main).
|
| 10 |
+
|
| 11 |
+
## Environment Installation
|
| 12 |
+
```
|
| 13 |
+
# from the repository root
|
| 14 |
+
conda env create -f environment.yml
|
| 15 |
+
|
| 16 |
+
conda activate a2d2
|
| 17 |
+
```
|
| 18 |
+
The molecule scripts share the `a2d2` environment with the peptide and language experiments. See the root [`environment.yml`](../environment.yml) for the `flash-attn` install step.
|
| 19 |
+
|
| 20 |
+
## Model Pretrained Weights
|
| 21 |
+
|
| 22 |
+
A2D2 fine-tunes a pretrained any-length insertion MDM trained on drug-like SAFE molecules. Download the base checkpoint and place it at:
|
| 23 |
+
```
|
| 24 |
+
A2D2/pretrained/anylength_mol.ckpt
|
| 25 |
+
```
|
| 26 |
+
```bash
|
| 27 |
+
# from the repository root
|
| 28 |
+
pip install gdown
|
| 29 |
+
mkdir -p pretrained
|
| 30 |
+
gdown 1I5EGiV1I5XZZpB9JAKABFLKVqfCyenxq -O pretrained/anylength_mol.ckpt
|
| 31 |
+
```
|
| 32 |
+
(Or download manually from https://drive.google.com/file/d/1I5EGiV1I5XZZpB9JAKABFLKVqfCyenxq/view?usp=drive_link — a plain `wget`/`curl` of the link saves Google's HTML warning page, not the checkpoint.)
|
| 33 |
+
This is the default `--checkpoint_path` (for fine-tuning) and `--pretrained_ckpt` (for evaluation) used throughout.
|
| 34 |
+
|
| 35 |
+
## Pretraining the Any-Length Model
|
| 36 |
+
|
| 37 |
+
If you only want to fine-tune with A2D2, download the released `anylength_mol.ckpt` above and skip this section. Follow these steps to reproduce the base checkpoint by pretraining the any-length insertion MDM from scratch.
|
| 38 |
+
|
| 39 |
+
### 1. The pretraining dataset
|
| 40 |
+
|
| 41 |
+
The model is pretrained on drug-like [SAFE](https://github.com/datamol-io/safe) molecules from the [`datamol-io/safe-gpt`](https://huggingface.co/datasets/datamol-io/safe-gpt) dataset (~1.1B molecules) on the Hugging Face Hub. **No manual download is required** — the dataset is loaded in streaming mode (`load_dataset(..., streaming=True)`) and tokenized on the fly with the `datamol-io/safe-gpt` tokenizer, both fetched automatically on first run.
|
| 42 |
+
|
| 43 |
+
The dataset is configured in [`config_mol.yaml`](config_mol.yaml):
|
| 44 |
+
|
| 45 |
+
```yaml
|
| 46 |
+
hf_dataset:
|
| 47 |
+
name: "datamol-io/safe-gpt"
|
| 48 |
+
smiles_column: "smiles"
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
To pretrain on a different Hugging Face SMILES/SAFE dataset, change `hf_dataset.name` (and `smiles_column` to match its column).
|
| 52 |
+
|
| 53 |
+
### 2. Configure
|
| 54 |
+
|
| 55 |
+
Pretraining is driven by [`config_mol.yaml`](config_mol.yaml). Key fields:
|
| 56 |
+
|
| 57 |
+
| Field | Default | Notes |
|
| 58 |
+
|-------|---------|-------|
|
| 59 |
+
| `hf_dataset.name` | `datamol-io/safe-gpt` | Streaming HF dataset (auto-downloaded). |
|
| 60 |
+
| `training.devices` | `2` | GPUs per node (DDP). |
|
| 61 |
+
| `training.batch_size` | `2048` | Global batch; gradient accumulation is derived automatically from `per_gpu_batch_size`. |
|
| 62 |
+
| `training.max_steps` | `500000` | Total optimizer steps. |
|
| 63 |
+
| `training.learning_rate` | `3e-4` | AdamW LR with `warmup_steps: 2000`. |
|
| 64 |
+
| `training.save_every_n_steps` | `1000` | Step-based checkpointing (used for streaming datasets). |
|
| 65 |
+
| `training.checkpoint_dir` | `checkpoints/pretrain_mol` | A timestamped subdirectory is created per run. |
|
| 66 |
+
| `interpolant.max_length` | `256` | Max token length. |
|
| 67 |
+
|
| 68 |
+
### 3. Pre-training Any-Length Molecule Model
|
| 69 |
+
|
| 70 |
+
Log in to Weights & Biases once (`wandb login`), or set `export WANDB_MODE=disabled` to skip logging. Then submit the SLURM job:
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
# from a2d2_mol/
|
| 74 |
+
sbatch scripts/train_mol.sh
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
`train_mol.sh` is a SLURM batch script that requests one `dgx-b200` node with 2 full B200 GPUs and launches DDP via `srun` (one task per GPU), running the equivalent of:
|
| 78 |
+
|
| 79 |
+
```bash
|
| 80 |
+
python train.py --task mol
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
It activates the conda env (`CONDA_ENV`, defaults to the `peptune` env) from `CONDA_ROOT` (defaults to the shared miniconda install) — the batch shell does not source `~/.bashrc`, so override these env vars if your install or env path differs. The GPU count is auto-detected from the SLURM allocation and passed to hydra as `training.devices`/`training.nodes`, so to scale just change `--gpus-per-node` and `--ntasks-per-node` together at the top of the script (they must match). `--task mol` makes `train.py` load `config_mol.yaml`.
|
| 84 |
+
|
| 85 |
+
Checkpoints are written to `checkpoints/pretrain_mol/<timestamp>/` (use `last.ckpt` / the best `train_loss` checkpoint as the `--checkpoint_path` / `--pretrained_ckpt` for fine-tuning and evaluation); the run log goes to `logs/<date>_a2d2-mol_<jobid>.log` and SLURM's catch-file to `logs/slurm/`. To resume, add a `training.resume_path: /path/to/last.ckpt` entry to the config.
|
| 86 |
+
|
| 87 |
+
## Fine-Tune with A2D2
|
| 88 |
+
|
| 89 |
+
The canonical run directory is the parent `a2d2/` package (`finetune_mol.py`, `inference_quality_mol.py`, `sampling.py`, and `mol_scoring/` here are the molecule-specific modules used from there). Before running:
|
| 90 |
+
|
| 91 |
+
1. Set `--base_path` to the location of `a2d2`. Results plots are written to `<base_path>/flexible/results/<run_name>/`.
|
| 92 |
+
2. Create the output directories: `a2d2/checkpoints/finetune_mol`, `a2d2/results`, and `a2d2/logs`.
|
| 93 |
+
|
| 94 |
+
### Single run
|
| 95 |
+
|
| 96 |
+
[`scripts/run_mol_finetune.slurm`](scripts/run_mol_finetune.slurm) runs a single `finetune_mol.py` experiment on one MIG GPU, then evaluates the resulting checkpoint. It bundles the full hyperparameter set used in the paper (replicates `R = 16`, pool size `1000`, buffer size `100`, sampling steps `N_steps = 90`, warmup `N_warmup = 20`, alternation frequency `N_alt = 5`, reward scaling `α = 0.01`, quality threshold `μ_min = 0.3`, `--qed_only`), so you don't have to pass them by hand.
|
| 97 |
+
|
| 98 |
+
The script resolves the repo root automatically — `$A2D2_ROOT` if set, else the `sbatch` submit directory, else the script's own location — so either submit from the repo root or export your clone path. Set `CONDA_ROOT` (your miniconda install) and, if needed, `CONDA_ENV` (defaults to `peptune`):
|
| 99 |
+
```bash
|
| 100 |
+
export A2D2_ROOT=/path/to/your/A2D2 # absolute path to your clone
|
| 101 |
+
export CONDA_ROOT=/path/to/miniconda3 # or just have `conda` on PATH
|
| 102 |
+
sbatch scripts/run_mol_finetune.slurm
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
Select which variant to run with `MODE_ID` (default `0`): `0` = A2D2 (full planner), `1` = `--disable_planner`, `2` = `--disable_insertion_planner`, `3` = `--disable_unmasking_planner`. Override at submit time:
|
| 106 |
+
```bash
|
| 107 |
+
sbatch --export=ALL,MODE_ID=2 scripts/run_mol_finetune.slurm
|
| 108 |
+
```
|
| 109 |
+
The pretrained base checkpoint is read from `$A2D2_ROOT/pretrained/anylength_mol.ckpt`. Outputs land in `checkpoints/finetune_mol/<job>_mol_<mode>/` and `results/mol_ablation/<mode>/`.
|
| 110 |
+
|
| 111 |
+
### Ablation flags
|
| 112 |
+
| Flag | Variant |
|
| 113 |
+
|------|---------|
|
| 114 |
+
| *(none)* | A2D2 w/ insertion + unmasking quality (alternation) |
|
| 115 |
+
| `--disable_planner` | A2D2 w/o quality (policy only, no remasking) |
|
| 116 |
+
| `--disable_insertion_planner` | A2D2 w/o insertion quality |
|
| 117 |
+
| `--disable_unmasking_planner` | A2D2 w/o unmasking/remasking quality |
|
| 118 |
+
| `--joint_training` | train policy + quality heads jointly (no alternation) |
|
| 119 |
+
|
| 120 |
+
## Evaluation
|
| 121 |
+
|
| 122 |
+
Evaluation runs automatically at the end of the SLURM job. To evaluate a checkpoint manually:
|
| 123 |
+
```
|
| 124 |
+
python evaluate_mol_table.py \
|
| 125 |
+
--checkpoint_path /path/to/a2d2/checkpoints/finetune_mol/my_run/last.ckpt \
|
| 126 |
+
--pretrained_ckpt /path/to/A2D2/pretrained/anylength_mol.ckpt \
|
| 127 |
+
--output_dir /path/to/results \
|
| 128 |
+
--num_samples 1000 --batch_size 50 \
|
| 129 |
+
--max_length 256 --total_num_steps 256 \
|
| 130 |
+
--num_remasking 2 --quality_threshold 0.3 --seed 42 --device cuda:0
|
| 131 |
+
```
|
| 132 |
+
This reports QED, SA, validity, uniqueness, diversity, and mean unmasking/insertion quality over the generated molecules and writes `eval_metrics_<mode>.csv`.
|
a2d2_mol/config_mol.yaml
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
trainer: "any-order-flow"
|
| 2 |
+
dataset: "safe-drugs"
|
| 3 |
+
|
| 4 |
+
# HuggingFace dataset configuration
|
| 5 |
+
hf_dataset:
|
| 6 |
+
name: "datamol-io/safe-gpt"
|
| 7 |
+
smiles_column: "smiles" # Adjust based on actual column name in the dataset
|
| 8 |
+
|
| 9 |
+
model:
|
| 10 |
+
hidden_size: 768
|
| 11 |
+
n_heads: 12
|
| 12 |
+
cond_dim: 128
|
| 13 |
+
dropout: 0.05
|
| 14 |
+
n_blocks: 12
|
| 15 |
+
torch_dtype: 'float32' # Options: 'float32', 'float16', 'bfloat16'
|
| 16 |
+
|
| 17 |
+
interpolant:
|
| 18 |
+
type: "any-order"
|
| 19 |
+
tokens: null # filled in automatically
|
| 20 |
+
pad_token: null # filled in automatically
|
| 21 |
+
mask_token: null # filled in automatically
|
| 22 |
+
max_length: 256
|
| 23 |
+
insert_schedule:
|
| 24 |
+
type: "linear"
|
| 25 |
+
unmask_schedule:
|
| 26 |
+
type: "linear"
|
| 27 |
+
|
| 28 |
+
training:
|
| 29 |
+
only_embed_insert: true
|
| 30 |
+
batch_size: 2048
|
| 31 |
+
per_gpu_batch_size: 64 # Gradient accumulation happens automatically
|
| 32 |
+
cpus: 4
|
| 33 |
+
learning_rate: 3e-4
|
| 34 |
+
nodes: 1
|
| 35 |
+
devices: 2
|
| 36 |
+
max_steps: 500000
|
| 37 |
+
weight_decay: 0.03
|
| 38 |
+
checkpoint_dir: "checkpoints/pretrain_mol"
|
| 39 |
+
save_top_k: 3
|
| 40 |
+
save_every_n_steps: 1000 # Save checkpoint every 1k steps (for streaming datasets)
|
| 41 |
+
# save_every_n_epochs: 1 # Not used with streaming datasets
|
| 42 |
+
loss_fn:
|
| 43 |
+
unmask: "elbo"
|
| 44 |
+
insert: "expectation"
|
| 45 |
+
reset_lr: false
|
| 46 |
+
warmup_steps: 2000
|
| 47 |
+
ema_decay: 0.9999
|
| 48 |
+
filter_max_length: false
|
| 49 |
+
|
| 50 |
+
wandb:
|
| 51 |
+
entity: null # set to your W&B entity, or leave null to use the default
|
| 52 |
+
project: "a2d2-mol"
|
| 53 |
+
name: "a2d2-mol"
|
| 54 |
+
path: "./wandb"
|
a2d2_mol/evaluate_mol_table.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluate a finetuned molecule model checkpoint by sampling sequences
|
| 3 |
+
and computing metrics for the De Novo Small Molecule Generation table:
|
| 4 |
+
Validity (%), Uniqueness (%), QED (↑), SA (↓), Quality (%), Diversity (↑), Sampling Time (↓)
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import argparse
|
| 10 |
+
import time
|
| 11 |
+
import torch
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from tdc import Oracle, Evaluator
|
| 15 |
+
|
| 16 |
+
# add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
|
| 17 |
+
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 18 |
+
sys.path.insert(0, REPO_ROOT)
|
| 19 |
+
|
| 20 |
+
from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
|
| 21 |
+
from lightning_modules import AnyOrderInsertionFlowModule
|
| 22 |
+
from inference_quality_mol import sample_mol_eval
|
| 23 |
+
from mol_scoring.scoring_functions import MolScoringFunctions
|
| 24 |
+
from finetune_mol import MolFinetuner, get_tokenizer
|
| 25 |
+
from mol_utils.utils import str2bool, set_seed
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def load_finetuned_model(checkpoint_path, pretrained_ckpt_path, device='cuda'):
    """Rebuild the fine-tuned policy model and load its weights for evaluation.

    The model config is read from the pretrained base checkpoint, the
    fine-tuning ``args`` (if any) from the fine-tuned checkpoint, and the
    ``policy_model.*`` entries of the fine-tuned ``state_dict`` are loaded
    into a freshly constructed ``AnyOrderInsertionFlowModuleFT``.

    Args:
        checkpoint_path: Path to the fine-tuned Lightning checkpoint. Its
            ``state_dict`` is required; ``hyper_parameters['args']`` is
            optional.
        pretrained_ckpt_path: Path to the pretrained base checkpoint. The
            config is taken from ``hyper_parameters['config']`` or, failing
            that, from a top-level ``config`` entry. The path is also passed
            to the model constructor as ``pretrained_checkpoint``.
        device: Device the policy model is moved to before returning.

    Returns:
        A ``(policy_model, args, config)`` tuple. ``policy_model`` is in eval
        mode; ``args`` is whatever was stored under
        ``hyper_parameters['args']`` (``None`` when absent).

    Raises:
        ValueError: If the base checkpoint contains no config.
    """
    import warnings

    from omegaconf import OmegaConf, DictConfig

    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints that come from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    hparams = ckpt.get('hyper_parameters', {})
    args = hparams.get('args', None)

    # The pretrained base checkpoint is the source of the model config.
    base_ckpt = torch.load(pretrained_ckpt_path, map_location='cpu', weights_only=False)
    if 'hyper_parameters' in base_ckpt:
        config = base_ckpt['hyper_parameters']['config']
    elif 'config' in base_ckpt:
        config = base_ckpt['config']
    else:
        raise ValueError("Cannot find config in base checkpoint")

    if not OmegaConf.is_config(config):
        config = DictConfig(config)

    # Temporarily lift struct mode so the fine-tuning keys below can be added.
    OmegaConf.set_struct(config, False)

    # Schedule settings come from the saved args; getattr falls back to the
    # default when args is None or lacks the attribute.
    config.training.use_adaptive_schedule = getattr(args, 'use_adaptive_schedule', True)
    config.training.schedule_hidden_dim = getattr(args, 'schedule_hidden_dim', 256)
    config.training.schedule_num_layers = getattr(args, 'schedule_num_layers', 2)
    config.training.schedule_loss_weight = getattr(args, 'schedule_loss_weight', 0.1)
    config.training.freeze_base_model = getattr(args, 'freeze_base_model', False)
    config.training.schedule_warmup_epochs = getattr(args, 'schedule_warmup_epochs', 0)
    config.training.use_bracket_safe = True
    OmegaConf.set_struct(config, True)

    # The insertion planner is only built when the run did not disable it.
    disable_planner = getattr(args, 'disable_planner', False)

    policy_model = AnyOrderInsertionFlowModuleFT(
        config=config,
        args=args,
        pretrained_checkpoint=pretrained_ckpt_path,
        insertion_planner=not disable_planner,
    )

    # The fine-tuned checkpoint stores the policy under a 'policy_model.'
    # prefix; strip it so the keys match the sub-module's own state dict.
    prefix = 'policy_model.'
    policy_state = {
        key[len(prefix):]: value
        for key, value in ckpt['state_dict'].items()
        if key.startswith(prefix)
    }
    if not policy_state:
        # With strict=False an empty dict loads silently, which would leave
        # the model at its constructor-initialised weights without any hint.
        warnings.warn(
            f"No '{prefix}*' weights found in {checkpoint_path}; "
            "no fine-tuned weights were loaded into the policy model."
        )
    policy_model.load_state_dict(policy_state, strict=False)
    policy_model = policy_model.to(device)
    policy_model.eval()

    return policy_model, args, config
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@torch.no_grad()
def evaluate_checkpoint(policy_model, tokenizer, reward_model, evaluator,
                        num_samples=1000, batch_size=50, max_length=256,
                        total_num_steps=256, quality_mode="both", num_remasking=2,
                        quality_threshold=0.5, unmask_quality_threshold=None, device='cuda'):
    """Sample ``num_samples`` molecules in batches and compute the table metrics.

    Args:
        policy_model: Model passed to ``sample_mol_eval``; its
            ``interpolant.mask_token`` / ``interpolant.pad_token`` are read.
        tokenizer: Tokenizer forwarded to ``sample_mol_eval``.
        reward_model: Reward model forwarded to ``sample_mol_eval``.
        evaluator: Callable forwarded to ``sample_mol_eval`` and also applied
            to the pooled unique sequences to obtain the diversity value.
        num_samples: Total number of molecules to sample.
        batch_size: Maximum number of molecules per sampling call; must be
            positive.
        max_length, total_num_steps, quality_mode, num_remasking,
        quality_threshold, unmask_quality_threshold: Forwarded unchanged to
            ``sample_mol_eval`` (``total_num_steps`` as ``steps``).
        device: Accepted for interface compatibility; not used in this
            function.

    Returns:
        ``(metrics, all_unique, qed_vals, sa_vals)`` where ``metrics`` is a
        dict of the reported values, ``all_unique`` is the list of
        de-duplicated sequences, and ``qed_vals`` / ``sa_vals`` are the oracle
        scores for ``all_unique`` (empty lists when nothing was collected).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    collected_seqs = []
    total_time = 0.0

    num_batches = (num_samples + batch_size - 1) // batch_size

    for b in range(num_batches):
        # The final batch may be smaller than batch_size.
        bs = min(batch_size, num_samples - b * batch_size)

        t_start = time.time()
        result = sample_mol_eval(
            model=policy_model,
            reward_model=reward_model,
            tokenizer=tokenizer,
            steps=total_num_steps,
            mask=policy_model.interpolant.mask_token,
            pad=policy_model.interpolant.pad_token,
            batch_size=bs,
            max_length=max_length,
            quality_mode=quality_mode,
            num_remasking=num_remasking,
            quality_threshold=quality_threshold,
            unmask_quality_threshold=unmask_quality_threshold,
            evaluator=evaluator,
            dataframe=True,
        )
        elapsed = time.time() - t_start

        # Result layout: unique sequences, qed, sa, valid fraction, uniqueness,
        # diversity, quality, dataframe. Only the sequences are reused here;
        # every metric is recomputed over the pooled set below.
        unique_seqs = result[0]

        collected_seqs.extend(unique_seqs)
        total_time += elapsed

        print(f"  Batch {b+1}/{num_batches}: {len(unique_seqs)} valid unique, "
              f"time={elapsed:.1f}s")

    # --- Aggregate metrics over all samples ---
    total_generated = num_samples

    # NOTE(review): each batch contributes the sequences labelled "unique" by
    # sample_mol_eval, so duplicates within a batch are presumably already
    # removed before the validity count -- confirm against sample_mol_eval.
    all_unique = list(set(collected_seqs))
    num_valid = len(collected_seqs)  # pooled across batches, before cross-batch dedup
    num_unique = len(all_unique)

    # Guard the ratios so an empty run reports zeros instead of dividing by zero.
    validity = num_valid / total_generated * 100.0 if total_generated > 0 else 0.0
    uniqueness = num_unique / num_valid * 100.0 if num_valid > 0 else 0.0

    # Diversity on unique SMILES (needs at least two molecules).
    diversity = evaluator(all_unique) if num_unique > 1 else 0.0

    # QED and SA on unique sequences.
    qed_vals, sa_vals = [], []
    mean_qed = 0.0
    mean_sa = 0.0
    quality = 0.0
    if num_unique > 0:
        qed_vals = Oracle('qed')(all_unique)
        sa_vals = Oracle('sa')(all_unique)
        mean_qed = np.mean(qed_vals)
        mean_sa = np.mean(sa_vals)

        # Quality: unique sequences with QED >= 0.6 AND SA <= 4, reported as a
        # percentage of all generated samples.
        num_quality = sum(1 for q, s in zip(qed_vals, sa_vals) if q >= 0.6 and s <= 4)
        quality = num_quality / total_generated * 100.0

    metrics = {
        'Validity (%)': validity,
        'Uniqueness (%)': uniqueness,
        'QED': mean_qed,
        'Synthetic Accessibility': mean_sa,
        'Quality (%)': quality,
        'Diversity': diversity,
        'Sampling Time (s)': total_time,
        'Num Generated': total_generated,
        'Num Valid': num_valid,
        'Num Unique': num_unique,
    }

    return metrics, all_unique, qed_vals, sa_vals
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _build_arg_parser():
    """Build the command-line parser for checkpoint evaluation."""
    parser = argparse.ArgumentParser(description="Evaluate a finetuned mol checkpoint")
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to the finetuned Lightning checkpoint (e.g., last.ckpt)')
    parser.add_argument('--pretrained_ckpt', type=str,
                        default=os.path.join(REPO_ROOT, 'pretrained', 'anylength_mol.ckpt'),
                        help='Path to the pretrained base model checkpoint '
                             '(defaults to <repo>/pretrained/anylength_mol.ckpt)')
    parser.add_argument('--num_samples', type=int, default=1000,
                        help='Number of molecules to sample')
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Batch size for sampling')
    parser.add_argument('--max_length', type=int, default=256)
    parser.add_argument('--total_num_steps', type=int, default=256)
    parser.add_argument('--num_remasking', type=int, default=2)
    parser.add_argument('--disable_planner', action='store_true',
                        help='If set, disable remasking during evaluation (matches training mode)')
    parser.add_argument('--disable_insertion_planner', action='store_true',
                        help='If set, disable insertion quality filtering during evaluation')
    parser.add_argument('--disable_unmasking_planner', action='store_true',
                        help='If set, disable unmasking confidence planner during evaluation')
    parser.add_argument('--quality_threshold', type=float, default=0.5,
                        help='Threshold for insertion quality filtering during sampling')
    parser.add_argument('--unmask_quality_threshold', type=float, default=None,
                        help='If set, gate unmasking remasking on confidence: remask clean '
                             'tokens whose remasking_conf < threshold (overrides the '
                             'schedule-driven count). Default None = schedule-driven behavior.')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Directory to save results CSV. Defaults to checkpoint directory.')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--seed', type=int, default=42)
    return parser


def _resolve_quality_mode(args):
    """Map the planner ablation flags to the ``quality_mode`` string."""
    if args.disable_planner:
        return "none"
    if args.disable_insertion_planner and args.disable_unmasking_planner:
        return "none"
    if args.disable_insertion_planner:
        return "unmasking_only"
    if args.disable_unmasking_planner:
        return "insertion_only"
    return "both"


# Output-file tag per quality mode. Deriving the tag from the mode (rather than
# from the raw flags) keeps file names in step with what was actually sampled:
# disabling both sub-planners runs with quality_mode="none", so it is tagged
# "no_planner" instead of overwriting the "no_insertion_planner" results.
_TAG_BY_QUALITY_MODE = {
    "none": "no_planner",
    "unmasking_only": "no_insertion_planner",
    "insertion_only": "no_unmasking_planner",
    "both": "with_planner",
}


def main():
    """Evaluate a finetuned molecule checkpoint from the command line.

    Parses the CLI arguments, loads the finetuned model, samples molecules via
    ``evaluate_checkpoint``, prints the aggregated metrics and writes two CSV
    files into the output directory:

    * ``eval_metrics_<tag>.csv`` -- a single row with the metrics dict;
    * ``eval_smiles_<tag>.csv``  -- the unique SMILES with their QED / SA
      values (only written when at least one unique SMILES was returned).

    ``<tag>`` reflects the planner configuration used for sampling.
    """
    args = _build_arg_parser().parse_args()

    set_seed(args.seed, use_cuda=True)
    # Fall back to CPU when CUDA is not available, whatever --device says.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    print(f"Loading checkpoint: {args.checkpoint_path}")
    print(f"Pretrained base: {args.pretrained_ckpt}")
    print(f"Disable planner (no remasking): {args.disable_planner}")
    print(f"Disable insertion planner: {args.disable_insertion_planner}")
    print(f"Disable unmasking planner: {args.disable_unmasking_planner}")

    # The loader also returns the training args and config; they are not
    # needed for evaluation.
    policy_model, _train_args, _config = load_finetuned_model(
        args.checkpoint_path, args.pretrained_ckpt, device=device
    )

    tokenizer = get_tokenizer()
    score_func_names = ['qed', 'sa']
    reward_model = MolScoringFunctions(score_func_names, device=device)
    evaluator = Evaluator('diversity')

    quality_mode = _resolve_quality_mode(args)

    print(f"\nSampling {args.num_samples} molecules (quality_mode={quality_mode})...")

    metrics, unique_smiles, qed_vals, sa_vals = evaluate_checkpoint(
        policy_model=policy_model,
        tokenizer=tokenizer,
        reward_model=reward_model,
        evaluator=evaluator,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        max_length=args.max_length,
        total_num_steps=args.total_num_steps,
        quality_mode=quality_mode,
        num_remasking=args.num_remasking,
        quality_threshold=args.quality_threshold,
        unmask_quality_threshold=args.unmask_quality_threshold,
        device=device,
    )

    # Print summary table
    print("\n" + "=" * 60)
    print(" De Novo Small Molecule Generation Results")
    print("=" * 60)
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f" {k:<30s}: {v:.4f}")
        else:
            print(f" {k:<30s}: {v}")
    print("=" * 60)

    # Save results. A bare checkpoint filename has an empty dirname, and
    # os.makedirs('') raises FileNotFoundError, so fall back to the current
    # directory in that case.
    output_dir = args.output_dir or os.path.dirname(args.checkpoint_path) or '.'
    os.makedirs(output_dir, exist_ok=True)

    tag = _TAG_BY_QUALITY_MODE[quality_mode]
    metrics_path = os.path.join(output_dir, f'eval_metrics_{tag}.csv')
    pd.DataFrame([metrics]).to_csv(metrics_path, index=False)
    print(f"Metrics saved to: {metrics_path}")

    if unique_smiles:
        smiles_path = os.path.join(output_dir, f'eval_smiles_{tag}.csv')
        df = pd.DataFrame({
            'SMILES': unique_smiles,
            'QED': qed_vals,
            'SA': sa_vals,
        })
        df.to_csv(smiles_path, index=False)
        print(f"SMILES saved to: {smiles_path}")
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
|
a2d2_mol/finetune_mol.py
ADDED
|
@@ -0,0 +1,747 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from datetime import datetime
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import pytorch_lightning as pl
|
| 6 |
+
from pytorch_lightning.strategies import DDPStrategy
|
| 7 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
| 8 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 9 |
+
import wandb
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import pandas as pd
|
| 14 |
+
|
| 15 |
+
# add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
|
| 16 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 17 |
+
|
| 18 |
+
# imports
|
| 19 |
+
from inference_quality_mol import sample_mol_buffer, sample_mol_eval
|
| 20 |
+
from mol_utils.utils import str2bool, set_seed
|
| 21 |
+
from mol_scoring.scoring_functions import MolScoringFunctions
|
| 22 |
+
from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
|
| 23 |
+
from lightning_modules import AnyOrderInsertionFlowModule
|
| 24 |
+
from safe.tokenizer import SAFETokenizer
|
| 25 |
+
from tdc import Evaluator
|
| 26 |
+
|
| 27 |
+
# Repository root (two levels up from this file: A2D2/a2d2_mol/finetune_mol.py)
|
| 28 |
+
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_tokenizer():
    """Return the pretrained SAFE tokenizer extended with bracket tokens.

    Loads the ``datamol-io/safe-gpt`` SAFE tokenizer, takes its pretrained
    tokenizer object and registers ``'<'`` and ``'>'`` as additional tokens.
    """
    safe_tokenizer = SAFETokenizer.from_pretrained('datamol-io/safe-gpt')
    tokenizer = safe_tokenizer.get_pretrained()
    # The angle brackets are required by the bracket_safe representation.
    bracket_tokens = ['<', '>']
    tokenizer.add_tokens(bracket_tokens)
    return tokenizer
|
| 36 |
+
|
| 37 |
+
class MolFinetuner(pl.LightningModule):
|
| 38 |
+
"""Lightning module for distributed molecule finetuning."""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
args,
|
| 43 |
+
policy_model,
|
| 44 |
+
reward_model,
|
| 45 |
+
tokenizer,
|
| 46 |
+
pretrained=None,
|
| 47 |
+
mcts=None,
|
| 48 |
+
filename=None,
|
| 49 |
+
eps=1e-5
|
| 50 |
+
):
|
| 51 |
+
super().__init__()
|
| 52 |
+
self.args = args
|
| 53 |
+
self.policy_model = policy_model
|
| 54 |
+
self.reward_model = reward_model
|
| 55 |
+
self.tokenizer = tokenizer
|
| 56 |
+
self.pretrained = pretrained
|
| 57 |
+
self.mcts = mcts
|
| 58 |
+
self.filename = filename
|
| 59 |
+
self.eps = eps
|
| 60 |
+
|
| 61 |
+
self.evaluator = Evaluator("diversity")
|
| 62 |
+
|
| 63 |
+
# Save hyperparameters
|
| 64 |
+
self.save_hyperparameters(ignore=['policy_model', 'reward_model', 'tokenizer', 'pretrained', 'mcts'])
|
| 65 |
+
|
| 66 |
+
# Buffer for sequences
|
| 67 |
+
self.x_saved = None
|
| 68 |
+
self.log_rnd_saved = None
|
| 69 |
+
self.final_rewards_saved = None
|
| 70 |
+
|
| 71 |
+
# initialize logs
|
| 72 |
+
self.valid_fraction_log = []
|
| 73 |
+
self.diversity_log = []
|
| 74 |
+
self.qed_log = []
|
| 75 |
+
self.sa_log = []
|
| 76 |
+
self.quality_log = []
|
| 77 |
+
self.uniqueness_log = []
|
| 78 |
+
|
| 79 |
+
# Alternating training between policy and planner
|
| 80 |
+
self.train_policy = True # Start by training policy
|
| 81 |
+
self.alternation_frequency = getattr(args, 'alternation_frequency', 1) # Alternate every N epochs
|
| 82 |
+
|
| 83 |
+
def freeze_policy_model(self):
|
| 84 |
+
"""Freeze policy model parameters (but not planner)."""
|
| 85 |
+
for name, param in self.policy_model.named_parameters():
|
| 86 |
+
if not name.startswith('planner.'):
|
| 87 |
+
param.requires_grad = False
|
| 88 |
+
|
| 89 |
+
def unfreeze_policy_model(self):
|
| 90 |
+
"""Unfreeze policy model parameters (but not planner)."""
|
| 91 |
+
for name, param in self.policy_model.named_parameters():
|
| 92 |
+
if not name.startswith('planner.'):
|
| 93 |
+
param.requires_grad = True
|
| 94 |
+
|
| 95 |
+
def freeze_planner_model(self):
|
| 96 |
+
"""Freeze planner parameters."""
|
| 97 |
+
if hasattr(self.policy_model, 'planner'):
|
| 98 |
+
for param in self.policy_model.planner.parameters():
|
| 99 |
+
param.requires_grad = False
|
| 100 |
+
|
| 101 |
+
def unfreeze_planner_model(self):
|
| 102 |
+
"""Unfreeze planner parameters."""
|
| 103 |
+
if hasattr(self.policy_model, 'planner'):
|
| 104 |
+
for param in self.policy_model.planner.parameters():
|
| 105 |
+
param.requires_grad = True
|
| 106 |
+
|
| 107 |
+
def configure_optimizers(self):
|
| 108 |
+
# Separate parameter groups for policy backbone vs planner heads
|
| 109 |
+
planner_lr = getattr(self.args, 'planner_learning_rate', self.args.learning_rate)
|
| 110 |
+
planner_params = []
|
| 111 |
+
policy_params = []
|
| 112 |
+
for name, param in self.policy_model.named_parameters():
|
| 113 |
+
if name.startswith('planner.'):
|
| 114 |
+
planner_params.append(param)
|
| 115 |
+
else:
|
| 116 |
+
policy_params.append(param)
|
| 117 |
+
|
| 118 |
+
param_groups = [
|
| 119 |
+
{'params': policy_params, 'lr': self.args.learning_rate},
|
| 120 |
+
{'params': planner_params, 'lr': planner_lr},
|
| 121 |
+
]
|
| 122 |
+
optimizer = torch.optim.AdamW(param_groups)
|
| 123 |
+
return optimizer
|
| 124 |
+
|
| 125 |
+
def _get_quality_mode(self):
|
| 126 |
+
"""Map ablation flags + warmup state to quality_mode string."""
|
| 127 |
+
if self.args.disable_planner:
|
| 128 |
+
return "none"
|
| 129 |
+
if self.current_epoch < self.args.schedule_warmup_epochs:
|
| 130 |
+
return "none"
|
| 131 |
+
di = getattr(self.args, 'disable_insertion_planner', False)
|
| 132 |
+
du = getattr(self.args, 'disable_unmasking_planner', False)
|
| 133 |
+
if di and du:
|
| 134 |
+
return "none"
|
| 135 |
+
if di:
|
| 136 |
+
return "unmasking_only"
|
| 137 |
+
if du:
|
| 138 |
+
return "insertion_only"
|
| 139 |
+
return "both"
|
| 140 |
+
|
| 141 |
+
def on_train_epoch_start(self):
|
| 142 |
+
"""Called at the start of each training epoch."""
|
| 143 |
+
|
| 144 |
+
# If disable_planner mode, only train policy (no alternation)
|
| 145 |
+
if self.args.disable_planner:
|
| 146 |
+
self.train_policy = True
|
| 147 |
+
self.unfreeze_policy_model()
|
| 148 |
+
self.freeze_planner_model()
|
| 149 |
+
if self.global_rank == 0 and self.current_epoch == 0:
|
| 150 |
+
print(f"[FINETUNE_QUALITY] Training ONLY policy model (planner frozen, no remasking)")
|
| 151 |
+
|
| 152 |
+
elif getattr(self.args, 'joint_training', False):
|
| 153 |
+
# Joint mode: train policy + planner together every step (no alternation)
|
| 154 |
+
self.train_policy = True # marker; training_step adds planner loss when joint_training is set
|
| 155 |
+
self.unfreeze_policy_model()
|
| 156 |
+
self.unfreeze_planner_model()
|
| 157 |
+
if self.global_rank == 0 and self.current_epoch == 0:
|
| 158 |
+
print(f"[FINETUNE_QUALITY] JOINT TRAINING: policy + planner trained together (no alternation)")
|
| 159 |
+
|
| 160 |
+
else:
|
| 161 |
+
# Alternate between training policy and planner from epoch 0
|
| 162 |
+
# Determine which model to train this epoch
|
| 163 |
+
cycle_position = (self.current_epoch // self.alternation_frequency) % 2
|
| 164 |
+
self.train_policy = (cycle_position == 0)
|
| 165 |
+
|
| 166 |
+
if self.train_policy:
|
| 167 |
+
# Train policy, freeze planner
|
| 168 |
+
self.unfreeze_policy_model()
|
| 169 |
+
self.freeze_planner_model()
|
| 170 |
+
if self.global_rank == 0:
|
| 171 |
+
print(f"[ALTERNATION] Epoch {self.current_epoch}: Training POLICY model (planner frozen)")
|
| 172 |
+
else:
|
| 173 |
+
# Train planner, freeze policy
|
| 174 |
+
self.freeze_policy_model()
|
| 175 |
+
self.unfreeze_planner_model()
|
| 176 |
+
if self.global_rank == 0:
|
| 177 |
+
print(f"[ALTERNATION] Epoch {self.current_epoch}: Training PLANNER model (policy frozen)")
|
| 178 |
+
|
| 179 |
+
# Resample buffer if needed
|
| 180 |
+
if self.x_saved is None or self.current_epoch % self.args.resample_every_n_step == 0:
|
| 181 |
+
if self.global_rank == 0:
|
| 182 |
+
print(f"[BUFFER] Starting buffer generation for epoch {self.current_epoch}")
|
| 183 |
+
self._generate_buffer()
|
| 184 |
+
# Synchronize all ranks after buffer generation
|
| 185 |
+
if self.trainer and self.trainer.world_size > 1:
|
| 186 |
+
if self.global_rank == 0:
|
| 187 |
+
print(f"[BUFFER] All ranks completed buffer generation, synchronizing...")
|
| 188 |
+
torch.distributed.barrier()
|
| 189 |
+
if self.global_rank == 0:
|
| 190 |
+
print(f"[BUFFER] Synchronization complete!")
|
| 191 |
+
|
| 192 |
+
def _generate_buffer(self):
|
| 193 |
+
"""Generate buffer of sequences for training.
|
| 194 |
+
|
| 195 |
+
When pool_size > 0, maintains a persistent pool and refreshes a fraction
|
| 196 |
+
each time instead of regenerating the entire buffer from scratch.
|
| 197 |
+
"""
|
| 198 |
+
rank = self.global_rank if self.trainer else 0
|
| 199 |
+
world_size = self.trainer.world_size if self.trainer else 1
|
| 200 |
+
|
| 201 |
+
pool_size = getattr(self.args, 'pool_size', 0)
|
| 202 |
+
is_pool = pool_size > 0
|
| 203 |
+
is_init = self.x_saved is None
|
| 204 |
+
|
| 205 |
+
# Determine how many molecules to sample this call
|
| 206 |
+
if is_pool:
|
| 207 |
+
refresh_frac = getattr(self.args, 'pool_refresh_fraction', 0.2)
|
| 208 |
+
if is_init:
|
| 209 |
+
samples_per_gpu = pool_size
|
| 210 |
+
else:
|
| 211 |
+
samples_per_gpu = max(1, int(pool_size * refresh_frac))
|
| 212 |
+
if rank == 0:
|
| 213 |
+
if is_init:
|
| 214 |
+
print(f"\n[POOL] Initializing pool with {pool_size} molecules at epoch {self.current_epoch}")
|
| 215 |
+
else:
|
| 216 |
+
print(f"\n[POOL] Refreshing {samples_per_gpu}/{pool_size} molecules ({refresh_frac*100:.0f}%) at epoch {self.current_epoch}")
|
| 217 |
+
else:
|
| 218 |
+
samples_per_gpu = self.args.buffer_size // world_size
|
| 219 |
+
if rank == 0:
|
| 220 |
+
samples_per_gpu += self.args.buffer_size % world_size
|
| 221 |
+
|
| 222 |
+
if rank == 0:
|
| 223 |
+
print(f"\n[BUFFER] Starting buffer generation at epoch {self.current_epoch}")
|
| 224 |
+
|
| 225 |
+
accumulated_x = []
|
| 226 |
+
accumulated_log_rnd = []
|
| 227 |
+
accumulated_rewards = []
|
| 228 |
+
total_accumulated = 0
|
| 229 |
+
|
| 230 |
+
max_attempts = 100 # Prevent infinite loop
|
| 231 |
+
attempts = 0
|
| 232 |
+
|
| 233 |
+
import time
|
| 234 |
+
while total_accumulated < samples_per_gpu and attempts < max_attempts:
|
| 235 |
+
attempts += 1
|
| 236 |
+
if rank == 0:
|
| 237 |
+
print(f"[BUFFER] rank={rank} starting sampling attempt {attempts} at {time.strftime('%H:%M:%S')}")
|
| 238 |
+
|
| 239 |
+
start_time = time.time()
|
| 240 |
+
|
| 241 |
+
x_final, log_rnd, final_rewards, trace = \
|
| 242 |
+
sample_mol_buffer(
|
| 243 |
+
self.policy_model,
|
| 244 |
+
self.pretrained,
|
| 245 |
+
self.reward_model,
|
| 246 |
+
self.tokenizer,
|
| 247 |
+
steps=self.args.total_num_steps,
|
| 248 |
+
mask=self.policy_model.interpolant.mask_token,
|
| 249 |
+
pad=self.policy_model.interpolant.pad_token,
|
| 250 |
+
batch_size=self.args.batch_size,
|
| 251 |
+
max_length=self.args.max_length,
|
| 252 |
+
quality_mode=self._get_quality_mode(),
|
| 253 |
+
alpha=self.args.alpha,
|
| 254 |
+
num_remasking=self.args.num_remasking,
|
| 255 |
+
quality_threshold=self.args.quality_threshold,
|
| 256 |
+
use_quality_filter=self.args.use_quality_filter,
|
| 257 |
+
)
|
| 258 |
+
if self.args.elbo_rnd:
|
| 259 |
+
# Override trajectory log_rnd with forward ELBO estimate
|
| 260 |
+
if x_final.shape[0] > 0:
|
| 261 |
+
with torch.no_grad():
|
| 262 |
+
noised = self.policy_model.prepare_noised_sample(
|
| 263 |
+
x_final, num_samples=self.args.elbo_rnd_num_samples)
|
| 264 |
+
policy_loss = self.policy_model.compute_loss_from_noised(noised)
|
| 265 |
+
pretrained_loss = self.pretrained.compute_loss_from_noised(noised)
|
| 266 |
+
log_rnd = (pretrained_loss - policy_loss) + (final_rewards / self.args.alpha)
|
| 267 |
+
|
| 268 |
+
elapsed = time.time() - start_time
|
| 269 |
+
if rank == 0:
|
| 270 |
+
print(f"[BUFFER] rank={rank} sampling took {elapsed:.1f}s")
|
| 271 |
+
|
| 272 |
+
n_valid = x_final.shape[0]
|
| 273 |
+
if n_valid > 0:
|
| 274 |
+
accumulated_x.append(x_final)
|
| 275 |
+
accumulated_log_rnd.append(log_rnd)
|
| 276 |
+
accumulated_rewards.append(final_rewards)
|
| 277 |
+
total_accumulated += n_valid
|
| 278 |
+
|
| 279 |
+
if rank == 0:
|
| 280 |
+
qm = self._get_quality_mode()
|
| 281 |
+
print(f"[BUFFER] rank={rank} epoch={self.current_epoch} quality_mode={qm} accumulated={total_accumulated} / {samples_per_gpu} (batch yielded {n_valid} valid) attempt={attempts}")
|
| 282 |
+
|
| 283 |
+
if total_accumulated == 0:
|
| 284 |
+
raise RuntimeError(f"[BUFFER ERROR] Rank {rank}: No valid sequences generated after {attempts} attempts. Check sampling function and reward model.")
|
| 285 |
+
|
| 286 |
+
if total_accumulated < samples_per_gpu:
|
| 287 |
+
print(f"[BUFFER WARNING] Rank {rank}: Only generated {total_accumulated}/{samples_per_gpu} sequences after {attempts} attempts")
|
| 288 |
+
|
| 289 |
+
new_x = torch.cat(accumulated_x, dim=0)[:samples_per_gpu]
|
| 290 |
+
new_log_rnd = torch.cat(accumulated_log_rnd, dim=0)[:samples_per_gpu]
|
| 291 |
+
new_rewards = torch.cat(accumulated_rewards, dim=0)[:samples_per_gpu]
|
| 292 |
+
|
| 293 |
+
del accumulated_x, accumulated_log_rnd, accumulated_rewards
|
| 294 |
+
torch.cuda.empty_cache()
|
| 295 |
+
|
| 296 |
+
# add to buffer: pool mode replaces a random subset, classic mode overwrites
|
| 297 |
+
if is_pool and not is_init:
|
| 298 |
+
actual_new = min(new_x.shape[0], self.x_saved.shape[0])
|
| 299 |
+
indices = torch.randperm(self.x_saved.shape[0], device=self.x_saved.device)[:actual_new]
|
| 300 |
+
self.x_saved[indices] = new_x[:actual_new]
|
| 301 |
+
self.log_rnd_saved[indices] = new_log_rnd[:actual_new]
|
| 302 |
+
self.final_rewards_saved[indices] = new_rewards[:actual_new]
|
| 303 |
+
if rank == 0:
|
| 304 |
+
print(f"[POOL] Replaced {actual_new}/{self.x_saved.shape[0]} molecules, reward mean={self.final_rewards_saved.mean():.4f}")
|
| 305 |
+
else:
|
| 306 |
+
self.x_saved = new_x
|
| 307 |
+
self.log_rnd_saved = new_log_rnd
|
| 308 |
+
self.final_rewards_saved = new_rewards
|
| 309 |
+
|
| 310 |
+
if rank == 0:
|
| 311 |
+
print(f"[BUFFER] After cleanup - GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
|
| 312 |
+
|
| 313 |
+
def training_step(self, batch, batch_idx):
|
| 314 |
+
"""Training step - batch is ignored, we use saved buffer."""
|
| 315 |
+
# Process buffer in mini-batches to avoid OOM
|
| 316 |
+
mini_batch_size = getattr(self.args, 'training_mini_batch_size', 8)
|
| 317 |
+
buffer_size = self.x_saved.shape[0]
|
| 318 |
+
|
| 319 |
+
# Randomly sample a mini-batch from buffer
|
| 320 |
+
indices = torch.randperm(buffer_size, device=self.x_saved.device)[:mini_batch_size]
|
| 321 |
+
x_final = self.x_saved[indices]
|
| 322 |
+
|
| 323 |
+
# get log_rnd values
|
| 324 |
+
log_rnd = self.log_rnd_saved[indices]
|
| 325 |
+
|
| 326 |
+
sm_temp = getattr(self.args, 'softmax_temperature', 1.0)
|
| 327 |
+
|
| 328 |
+
joint = getattr(self.args, 'joint_training', False)
|
| 329 |
+
policy_loss = None
|
| 330 |
+
planner_loss = None
|
| 331 |
+
|
| 332 |
+
if self.train_policy:
|
| 333 |
+
# Train policy with WDCE loss
|
| 334 |
+
policy_loss = self.policy_model.loss_wdce_flexible(
|
| 335 |
+
log_rnd,
|
| 336 |
+
x_final,
|
| 337 |
+
num_replicates=self.args.wdce_num_replicates,
|
| 338 |
+
centering=self.args.centering,
|
| 339 |
+
centering_strength=self.args.centering_strength,
|
| 340 |
+
softmax_temperature=sm_temp,
|
| 341 |
+
)
|
| 342 |
+
self.log('train/policy_loss', policy_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
|
| 343 |
+
|
| 344 |
+
if (not self.train_policy) or joint:
|
| 345 |
+
# Train planner with appropriate loss based on ablation flags
|
| 346 |
+
if self.args.disable_insertion_planner:
|
| 347 |
+
# Ablation: only train unmasking planner (no insertion head)
|
| 348 |
+
planner_loss = self.policy_model.loss_planner_flexible(
|
| 349 |
+
log_rnd,
|
| 350 |
+
x_final,
|
| 351 |
+
num_replicates=self.args.wdce_num_replicates,
|
| 352 |
+
centering=self.args.centering,
|
| 353 |
+
centering_strength=self.args.centering_strength,
|
| 354 |
+
softmax_temperature=sm_temp,
|
| 355 |
+
)
|
| 356 |
+
self.log('train/planner_unmask_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
|
| 357 |
+
self.log('train/planner_insert_loss', 0.0, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
|
| 358 |
+
self.log('train/planner_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
|
| 359 |
+
elif self.args.disable_unmasking_planner:
|
| 360 |
+
# only train insertion planner (no remasking head)
|
| 361 |
+
unmask_loss, insert_loss, _ = self.policy_model.loss_insert_planner_flexible(
|
| 362 |
+
log_rnd,
|
| 363 |
+
x_final,
|
| 364 |
+
num_replicates=self.args.wdce_num_replicates,
|
| 365 |
+
centering=self.args.centering,
|
| 366 |
+
centering_strength=self.args.centering_strength,
|
| 367 |
+
softmax_temperature=sm_temp,
|
| 368 |
+
)
|
| 369 |
+
# Zero out the unmasking component - only backprop insertion loss
|
| 370 |
+
planner_loss = insert_loss
|
| 371 |
+
self.log('train/planner_unmask_loss', 0.0, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
|
| 372 |
+
self.log('train/planner_insert_loss', insert_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
|
| 373 |
+
self.log('train/planner_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
|
| 374 |
+
else:
|
| 375 |
+
# Full planner: train both remasking + insertion
|
| 376 |
+
unmask_loss, insert_loss, planner_loss = self.policy_model.loss_insert_planner_flexible(
|
| 377 |
+
log_rnd,
|
| 378 |
+
x_final,
|
| 379 |
+
num_replicates=self.args.wdce_num_replicates,
|
| 380 |
+
centering=self.args.centering,
|
| 381 |
+
centering_strength=self.args.centering_strength,
|
| 382 |
+
softmax_temperature=sm_temp,
|
| 383 |
+
)
|
| 384 |
+
self.log('train/planner_unmask_loss', unmask_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
|
| 385 |
+
self.log('train/planner_insert_loss', insert_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
|
| 386 |
+
self.log('train/planner_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
|
| 387 |
+
|
| 388 |
+
# Combine losses depending on mode
|
| 389 |
+
if joint:
|
| 390 |
+
loss = policy_loss + planner_loss
|
| 391 |
+
mode_value = 0.5
|
| 392 |
+
elif self.train_policy:
|
| 393 |
+
loss = policy_loss
|
| 394 |
+
mode_value = 0.0
|
| 395 |
+
else:
|
| 396 |
+
loss = planner_loss
|
| 397 |
+
mode_value = 1.0
|
| 398 |
+
|
| 399 |
+
# Log overall loss and mode
|
| 400 |
+
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
|
| 401 |
+
self.log('train/mode', mode_value, prog_bar=True, sync_dist=True)
|
| 402 |
+
|
| 403 |
+
return loss
|
| 404 |
+
|
| 405 |
+
def on_train_epoch_end(self):
|
| 406 |
+
"""Called at the end of each training epoch - only rank 0 evaluates."""
|
| 407 |
+
# Only evaluate every N epochs to save time
|
| 408 |
+
eval_frequency = getattr(self.args, 'eval_every_n_epochs', 5)
|
| 409 |
+
is_last_epoch = (self.trainer and self.current_epoch == self.trainer.max_epochs - 1)
|
| 410 |
+
if self.global_rank == 0 and (self.current_epoch % eval_frequency == 0 or is_last_epoch):
|
| 411 |
+
# Sample eval batch with updated policy
|
| 412 |
+
x_eval, qed, sa, uniqueness, diversity, quality, valid_fraction = \
|
| 413 |
+
sample_mol_eval(
|
| 414 |
+
self.policy_model, self.reward_model,
|
| 415 |
+
self.tokenizer,
|
| 416 |
+
steps=self.args.total_num_steps,
|
| 417 |
+
mask=self.policy_model.interpolant.mask_token,
|
| 418 |
+
pad=self.policy_model.interpolant.pad_token,
|
| 419 |
+
batch_size=50,
|
| 420 |
+
max_length=self.args.max_length,
|
| 421 |
+
quality_mode=self._get_quality_mode(),
|
| 422 |
+
num_remasking=self.args.num_remasking,
|
| 423 |
+
evaluator=self.evaluator,
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
# Append to logs
|
| 427 |
+
self.valid_fraction_log.append(valid_fraction)
|
| 428 |
+
self.uniqueness_log.append(uniqueness)
|
| 429 |
+
self.diversity_log.append(diversity)
|
| 430 |
+
self.qed_log.append(qed)
|
| 431 |
+
self.sa_log.append(sa)
|
| 432 |
+
self.quality_log.append(quality)
|
| 433 |
+
|
| 434 |
+
# Compute reward stats
|
| 435 |
+
mean_reward = self.final_rewards_saved.mean().item()
|
| 436 |
+
min_reward = self.final_rewards_saved.min().item()
|
| 437 |
+
max_reward = self.final_rewards_saved.max().item()
|
| 438 |
+
median_reward = self.final_rewards_saved.median().item()
|
| 439 |
+
|
| 440 |
+
# Log metrics
|
| 441 |
+
self.log_dict({
|
| 442 |
+
"eval/valid_fraction": valid_fraction,
|
| 443 |
+
"eval/uniqueness": np.mean(uniqueness),
|
| 444 |
+
"eval/diversity": np.mean(diversity),
|
| 445 |
+
"eval/qed": np.mean(qed),
|
| 446 |
+
"eval/sa": np.mean(sa),
|
| 447 |
+
"eval/quality": np.mean(quality),
|
| 448 |
+
"eval/mean_reward_search": mean_reward,
|
| 449 |
+
"eval/min_reward_search": min_reward,
|
| 450 |
+
"eval/max_reward_search": max_reward,
|
| 451 |
+
"eval/median_reward_search": median_reward
|
| 452 |
+
})
|
| 453 |
+
|
| 454 |
+
print(f"epoch {self.current_epoch} | validity {valid_fraction:.4f} | uniqueness {np.mean(uniqueness):.4f} | diversity {np.mean(diversity):.4f} | "
|
| 455 |
+
f"QED {np.mean(qed):.4f} | SA {np.mean(sa):.4f} | quality {np.mean(quality):.4f} | ")
|
| 456 |
+
|
| 457 |
+
def on_fit_end(self):
    """Write the accumulated evaluation logs and a final generation table.

    Runs only on global rank 0. Two CSV files are written under
    ``{args.base_path}/results/{args.run_name}``:

    * ``log_{self.filename}.csv`` - the per-evaluation metric lists collected
      on this module, written through ``save_logs_to_file``.
    * ``mol_generation_results.csv`` - the dataframe returned by one final
      ``sample_mol_eval`` call made with ``dataframe=True``.
    """
    if self.global_rank != 0:
        return

    # Save logs and plot
    base_path = self.args.base_path
    plot_path = f'{base_path}/results/{self.args.run_name}'
    os.makedirs(plot_path, exist_ok=True)

    output_log_path = f'{plot_path}/log_{self.filename}.csv'
    save_logs_to_file(self.valid_fraction_log, self.uniqueness_log,
                      self.diversity_log, self.qed_log, self.sa_log,
                      self.quality_log, output_log_path)

    # Final generation. Only the trailing dataframe is needed here, so the
    # metric values are discarded instead of being bound to unused locals.
    # NOTE(review): the per-epoch evaluation call in this class unpacks the
    # metrics as (x_eval, qed, sa, uniqueness, diversity, quality,
    # valid_fraction), whereas this call site previously used
    # (x_eval, qed, sa, valid_fraction, uniqueness, diversity, quality, df).
    # Confirm the real return order against sample_mol_eval.
    *_, df = sample_mol_eval(
        self.policy_model, self.reward_model,
        self.tokenizer,
        steps=self.args.total_num_steps,
        mask=self.policy_model.interpolant.mask_token,
        pad=self.policy_model.interpolant.pad_token,
        batch_size=50,
        max_length=self.args.max_length,
        quality_mode=self._get_quality_mode(),
        num_remasking=self.args.num_remasking,
        evaluator=self.evaluator,
        dataframe=True,
    )
    df.to_csv(f'{plot_path}/mol_generation_results.csv', index=False)
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def save_logs_to_file(valid_fraction_log, uniqueness_log,
                      diversity_log, qed_log, sa_log,
                      quality_log, output_path):
    """Save the evaluation metric logs to a CSV file.

    Each argument except ``output_path`` is a sequence with one entry per
    evaluation; all sequences must have the same length (pandas raises
    ``ValueError`` otherwise). An ``Iteration`` column numbered from 1 is
    added in front of the metric columns.

    Args:
        valid_fraction_log: Values for the "Valid Fraction" column.
        uniqueness_log: Values for the "Uniqueness" column.
        diversity_log: Values for the "Diversity" column.
        qed_log: Values for the "QED" column.
        sa_log: Values for the "Synthetic Accessibility" column.
        quality_log: Values for the "Quality" column.
        output_path: Destination CSV path. Missing parent directories are
            created; a bare filename is written to the current directory.
    """
    # os.path.dirname returns '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when one is named.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    log_data = {
        "Iteration": list(range(1, len(valid_fraction_log) + 1)),
        "Valid Fraction": valid_fraction_log,
        "Uniqueness": uniqueness_log,
        "Diversity": diversity_log,
        "QED": qed_log,
        "Synthetic Accessibility": sa_log,
        "Quality": quality_log,
    }

    pd.DataFrame(log_data).to_csv(output_path, index=False)
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
class DummyDataset(torch.utils.data.Dataset):
    """Placeholder dataset handed to the Lightning trainer.

    The items carry no information (training data comes from a buffer
    instead); the dataset only fixes how many items one pass contains.
    """

    def __init__(self, size=100):
        # Number of items reported by __len__.
        self.size = size

    def __len__(self):
        """Return the configured number of placeholder items."""
        return self.size

    def __getitem__(self, idx):
        """Return a one-element zero tensor; ``idx`` is ignored."""
        placeholder = torch.zeros(1)  # Dummy data
        return placeholder
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def main():
    """Main entry point for distributed training.

    Parses the command line, builds the run name and output directory, loads
    the frozen pretrained model and its config from the checkpoint, creates
    the fine-tunable policy model, the reward function and the Lightning
    module, and finally runs ``trainer.fit`` on a dummy dataloader whose
    length is ``--num_training_steps_per_epoch``.

    Raises:
        ValueError: If ``--joint_training`` is combined with
            ``--disable_planner``, or if the checkpoint contains no config.
    """
    # Disable DDP optimizer for higher-order ops like flex_attention
    import torch._dynamo
    torch._dynamo.config.optimize_ddp = False

    argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    argparser.add_argument('--base_path', type=str, default=REPO_ROOT)
    argparser.add_argument('--learning_rate', type=float, default=1e-4)
    argparser.add_argument('--num_epochs', type=int, default=100)
    argparser.add_argument('--num_accum_steps', type=int, default=4)
    argparser.add_argument('--truncate_steps', type=int, default=50)
    argparser.add_argument("--truncate_kl", type=str2bool, default=False)
    argparser.add_argument('--gumbel_temp', type=float, default=1.0)
    argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
    argparser.add_argument('--batch_size', type=int, default=50)
    argparser.add_argument('--name', type=str, default='debug')
    argparser.add_argument('--total_num_steps', type=int, default=128)
    argparser.add_argument('--copy_flag_temp', type=float, default=None)
    argparser.add_argument('--save_every_n_epochs', type=int, default=10)
    argparser.add_argument('--eval_every_n_epochs', type=int, default=5, help='Evaluate only every N epochs to save time')
    argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
    argparser.add_argument("--seed", type=int, default=0)
    # new
    # NOTE(review): any value passed for --run_name is overwritten below by
    # the generated run name; confirm whether that is intended.
    argparser.add_argument('--run_name', type=str, default='mol')
    argparser.add_argument("--save_path_dir", default="", type=str)
    # mcts
    argparser.add_argument('--num_sequences', type=int, default=10)
    argparser.add_argument('--max_length', type=int, default=1024)
    argparser.add_argument('--num_children', type=int, default=50)
    argparser.add_argument('--num_iter', type=int, default=30)  # iterations of mcts
    argparser.add_argument('--seq_length', type=int, default=1024)
    argparser.add_argument('--time_conditioning', action='store_true', default=False)
    argparser.add_argument('--mcts_sampling', type=int, default=0)  # for batched categorical sampling: '0' means gumbel noise
    argparser.add_argument('--buffer_size', type=int, default=100)
    argparser.add_argument('--wdce_num_replicates', type=int, default=16)
    argparser.add_argument('--noise_removal', action='store_true', default=False)
    argparser.add_argument('--grad_clip', action='store_true', default=False)
    argparser.add_argument('--resample_every_n_step', type=int, default=3)
    argparser.add_argument('--exploration', type=float, default=0.1)
    argparser.add_argument('--reset_every_n_step', type=int, default=100)
    argparser.add_argument('--alpha', type=float, default=0.01)
    argparser.add_argument('--scalarization', type=str, default='sum')
    argparser.add_argument('--no_mcts', action='store_true', default=False)
    argparser.add_argument("--centering", action='store_true', default=False)
    argparser.add_argument("--centering_strength", type=float, default=1.0)

    # adaptive schedule parameters
    # Previously declared as store_true with default=True, which made the
    # option impossible to switch off. The bare flag still yields True
    # (const=True); an explicit value is parsed with str2bool, the same
    # converter used for --truncate_kl, so the schedule can now be disabled.
    argparser.add_argument('--use_adaptive_schedule', type=str2bool, nargs='?', const=True, default=True,
                           help='Enable the adaptive schedule; pass an explicit false value to disable it')
    argparser.add_argument('--schedule_hidden_dim', type=int, default=256)
    argparser.add_argument('--schedule_num_layers', type=int, default=2)
    argparser.add_argument('--schedule_loss_weight', type=float, default=0.1)
    argparser.add_argument('--adaptive_threshold', type=float, default=0.5)
    argparser.add_argument('--freeze_base_model', action='store_true', default=False)
    argparser.add_argument('--schedule_warmup_epochs', type=int, default=20, help='Number of initial epochs to train WITHOUT remasking in buffer generation')
    argparser.add_argument('--alternation_frequency', type=int, default=5, help='Number of epochs to train each model before alternating (1=alternate every epoch)')
    argparser.add_argument('--planner_learning_rate', type=float, default=None, help='Separate learning rate for planner heads (defaults to --learning_rate if not set)')

    # objectives
    argparser.add_argument('--num_obj', type=int, default=2)
    argparser.add_argument('--devices', type=int, default=-1)
    argparser.add_argument('--checkpoint_path', type=str, default=None)

    # ELBO-based log_rnd estimation
    argparser.add_argument('--elbo_rnd', action='store_true', default=False,
                           help='If set, compute log_rnd via forward ELBO instead of trajectory rollout')
    argparser.add_argument('--elbo_rnd_num_samples', type=int, default=4,
                           help='Number of noisy time samples per sequence for ELBO-based log_rnd estimation')

    # remasking
    argparser.add_argument('--num_remasking', type=int, default=5)
    argparser.add_argument('--quality_threshold', type=float, default=1)
    argparser.add_argument('--use_quality_filter', action='store_true', help='If set, filter buffer to only include molecules with QED>=0.6 and SA<=4')
    argparser.add_argument('--training_mini_batch_size', type=int, default=8, help='Mini-batch size for training step to avoid OOM')
    argparser.add_argument('--disable_planner', action='store_true', help='If set, disable remasking completely and only train policy (not planner) for quality optimization')
    argparser.add_argument('--disable_insertion_planner', action='store_true', help='Ablation: disable insertion quality filtering but keep unmasking/remasking planner')
    argparser.add_argument('--disable_unmasking_planner', action='store_true', help='Ablation: disable unmasking/remasking planner but keep insertion quality filtering')
    argparser.add_argument('--joint_training', action='store_true', help='Ablation: train policy and planner jointly each step (no alternation, both unfrozen, summed loss). Incompatible with --disable_planner.')
    argparser.add_argument('--qed_only', action='store_true', help='If set, optimize only for QED score (no SA)')
    argparser.add_argument('--softmax_temperature', type=float, default=1.0,
                           help='Temperature for softmax on importance weights (>1 smooths, prevents concentration)')
    argparser.add_argument('--pool_size', type=int, default=0,
                           help='If >0, maintain a persistent pool of this size and refresh a fraction each resample step (0=disabled, classic buffer)')
    argparser.add_argument('--pool_refresh_fraction', type=float, default=0.2,
                           help='Fraction of pool to replace each resample step (only used when pool_size>0)')
    argparser.add_argument('--num_training_steps_per_epoch', type=int, default=10,
                           help='Number of gradient updates per epoch (1=original, 10=recommended)')

    args = argparser.parse_args()

    # Default planner LR to policy LR if not specified
    if args.planner_learning_rate is None:
        args.planner_learning_rate = args.learning_rate

    # Set seed
    pl.seed_everything(args.seed)

    # Load models
    checkpoint_path = args.checkpoint_path if args.checkpoint_path else \
        os.path.join(REPO_ROOT, 'pretrained', 'anylength_mol.ckpt')

    curr_time = datetime.now().strftime("%Y%m%d_%H%M%S")

    if args.no_mcts:
        args.run_name = f'mol_al_resample{args.resample_every_n_step}_no-mcts_{curr_time}'
    else:
        args.run_name = f'mol_al_resample{args.resample_every_n_step}_buffer{args.buffer_size}_numiter{args.num_iter}_children{args.num_children}_{curr_time}'

    # append ablation tags to run name for easy identification
    if args.disable_planner:
        args.run_name += '_no_planner'
    if args.disable_insertion_planner:
        args.run_name += '_no_insertion_planner'
    if args.disable_unmasking_planner:
        args.run_name += '_no_unmasking_planner'
    if args.joint_training:
        if args.disable_planner:
            raise ValueError("--joint_training is incompatible with --disable_planner (no planner to train)")
        args.run_name += '_joint_training'

    args.save_path = os.path.join(args.save_path_dir, args.run_name)
    os.makedirs(args.save_path, exist_ok=True)
    set_seed(args.seed, use_cuda=False)  # Don't init CUDA before Lightning spawns DDP workers

    # Initialize the model
    print("Loading models..")

    # Load pretrained model for reference (frozen)
    pretrained = AnyOrderInsertionFlowModule.load_from_checkpoint(checkpoint_path,
                                                                  map_location='cpu',
                                                                  weights_only=False)
    pretrained.eval()
    for param in pretrained.parameters():
        param.requires_grad = False

    # Load checkpoint to extract config
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    if 'hyper_parameters' in checkpoint:
        config = checkpoint['hyper_parameters']['config']
    elif 'config' in checkpoint:
        config = checkpoint['config']
    else:
        raise ValueError("Cannot find config in checkpoint")
    # Only the config is needed from this second load; drop the reference so
    # the rest of the checkpoint dict is not kept alive for the whole run.
    del checkpoint

    # Update config for adaptive schedule
    from omegaconf import OmegaConf
    if not OmegaConf.is_config(config):
        from omegaconf import DictConfig
        config = DictConfig(config)

    # Struct mode is lifted temporarily so new training keys can be added.
    OmegaConf.set_struct(config, False)

    config.training.use_adaptive_schedule = args.use_adaptive_schedule
    config.training.schedule_hidden_dim = args.schedule_hidden_dim
    config.training.schedule_num_layers = args.schedule_num_layers
    config.training.schedule_loss_weight = args.schedule_loss_weight
    config.training.freeze_base_model = args.freeze_base_model
    config.training.schedule_warmup_epochs = args.schedule_warmup_epochs
    config.training.use_bracket_safe = True

    OmegaConf.set_struct(config, True)

    # initialize policy model with adaptive schedule
    policy_model = AnyOrderInsertionFlowModuleFT(
        config=config,
        args=args,
        pretrained_checkpoint=checkpoint_path,
        insertion_planner=True,
    )

    # define mcts
    if args.qed_only:
        score_func_names = ['qed']
    else:
        score_func_names = ['qed', 'sa']

    tokenizer = get_tokenizer()

    filename = args.run_name

    # Device will be set by Lightning automatically in DDP
    reward_model = MolScoringFunctions(score_func_names, device='cpu')
    model = MolFinetuner(
        args=args,
        policy_model=policy_model,
        reward_model=reward_model,
        tokenizer=tokenizer,
        pretrained=pretrained,
        mcts=None,
        filename=filename,
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath=args.save_path,
        filename='model-{epoch:02d}-{train_loss:.4f}',
        every_n_epochs=args.save_every_n_epochs,
        save_top_k=-1,  # Save all checkpoints
        save_last=True,  # Also save last.ckpt
        auto_insert_metric_name=False
    )

    # Defaults to your default wandb entity; override with the WANDB_ENTITY env var.
    wandb_logger = WandbLogger(entity=os.environ.get('WANDB_ENTITY'), project='a2d2-mol', name=args.run_name)

    # create dummy dataloader
    dataset = DummyDataset(size=args.num_training_steps_per_epoch)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

    # setup trainer with DDP
    trainer = pl.Trainer(
        max_epochs=args.num_epochs,
        accelerator='gpu',
        devices=args.devices,
        strategy=DDPStrategy(find_unused_parameters=True) if args.devices != 1 else 'auto',
        gradient_clip_val=args.gradnorm_clip if args.grad_clip else None,
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
        enable_progress_bar=True,
        log_every_n_steps=1
    )

    # Train
    trainer.fit(model, dataloader)


if __name__ == '__main__':
    main()
|
a2d2_mol/inference_quality_mol.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unified molecule sampling with quality-guided planning.
|
| 2 |
+
|
| 3 |
+
Supports 4 quality modes and optional RND (importance weight) computation.
|
| 4 |
+
|
| 5 |
+
Quality modes:
|
| 6 |
+
"none" - No planner, no remasking (policy-only)
|
| 7 |
+
"both" - Both unmasking + insertion planners active
|
| 8 |
+
"unmasking_only" - Only unmasking/remasking planner (insertion planner disabled)
|
| 9 |
+
"insertion_only" - Only insertion planner (unmasking planner disabled)
|
| 10 |
+
|
| 11 |
+
RND toggle:
|
| 12 |
+
compute_rnd=True - Run pretrained model in parallel, compute step-wise log importance weights
|
| 13 |
+
compute_rnd=False - Run policy model only (use with ELBO-based RND or eval)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import numpy as np
|
| 18 |
+
import pandas as pd
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from sampling import SamplingResult, SamplingTraceDatapoint, _sample_tokens
|
| 21 |
+
from remasking_scheduleaware import apply_schedule_aware_remasking, apply_schedule_aware_insertion
|
| 22 |
+
from mol_utils.utils_chem import batch_safe_to_smiles, batch_validate_and_extract
|
| 23 |
+
from tdc import Evaluator, Oracle
|
| 24 |
+
|
| 25 |
+
# Accepted values for the ``quality_mode`` argument; ``_diffusion_loop``
# asserts membership in this set before sampling.
QUALITY_MODES = {"none", "both", "unmasking_only", "insertion_only"}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@torch.no_grad()
def _diffusion_loop(
    model, steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    compute_rnd=False,
    pretrained=None,
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    temperature=1.0,
    return_trace=False,
    unmask_quality_threshold=None,
):
    """Core discrete diffusion sampling loop for molecule generation.

    Starts from an all-pad canvas and, for ``steps`` Euler steps of size
    ``1 / steps``: (1) samples tokens for masked positions from the policy's
    unmasking rates, (2) optionally remasks clean tokens, and (3) on every
    step but the last, inserts new mask tokens drawn from a Poisson
    distribution over the policy's length rates, optionally filtered by the
    insertion planner.

    Args:
        model: Finetuned policy model.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: One of "none", "both", "unmasking_only", "insertion_only".
        compute_rnd: Whether to compute step-wise log importance weights.
        pretrained: Frozen pretrained model (required if compute_rnd=True).
        remasking_mode: Remasking strategy. Only "schedule_aware" is
            implemented; any other value raises ``ValueError`` as soon as a
            remasking step runs (quality_mode != "none" and steps > 1).
        num_remasking: Number of tokens to remask per step. Only read by the
            top-k remasking branch, which the "schedule_aware" mode never
            enters.
        quality_threshold: Threshold for insertion quality filtering. None if schedule-driven.
        temperature: Sampling temperature (1.0 = no scaling).
        return_trace: Whether to record sampling trace.
        unmask_quality_threshold: Forwarded unchanged to
            ``apply_schedule_aware_remasking``.

    Returns:
        (xt, log_rnd, sampling_trace)
        log_rnd is None when compute_rnd=False; otherwise it is the sum over
        all steps of the per-step log importance weights, shape (B,).
        sampling_trace is None when return_trace=False.

    Raises:
        AssertionError: If ``quality_mode`` is unknown, or ``compute_rnd`` is
            set without a ``pretrained`` model.
        ValueError: If a remasking step runs with an unsupported
            ``remasking_mode``.
    """
    assert quality_mode in QUALITY_MODES, f"quality_mode must be one of {QUALITY_MODES}"
    if compute_rnd:
        assert pretrained is not None, "pretrained model required when compute_rnd=True"

    # Derive flags from quality_mode
    use_remasking = quality_mode != "none"
    disable_unmasking_planner = quality_mode in ("none", "insertion_only")
    disable_insertion_planner = quality_mode in ("none", "unmasking_only")

    device = next(model.parameters()).device

    # Initialize all-pad sequence
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Precompute index tensors
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    # Gap indices 0..max_length (one more gap than positions); loop-invariant.
    gaps = torch.arange(max_length + 1, device=device).view(1, -1)
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    neg_inf = torch.tensor(-np.inf, device=device)

    if use_remasking and remasking_mode == "remdm_conf":
        remasking_score = torch.zeros((batch_size, max_length), device=device)

    # Accumulated over steps when compute_rnd=True; stays None otherwise.
    log_rnd = None

    for i in range(steps):
        # --- Policy model forward ---
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # --- Pretrained model forward (for RND) ---
        if compute_rnd:
            pretrained_pred = pretrained(xt, t)
            pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
            pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
            pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

        # --- Unmask step (Euler) ---
        # Rates are zeroed everywhere except masked positions; at masked
        # positions the mask entry is set to minus the row sum, so after the
        # clamp below it contributes no probability mass of its own.
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        if compute_rnd:
            pretrained_unmask_rate[xt != mask] = 0
            pretrained_unmask_rate[mask_pos + (mask,)] = 0
            pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
            pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

        # Add "stay" probability (pad positions are indexed as mask so the
        # scatter index is always a valid token id; they are reset to pad
        # after sampling).
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )
        if compute_rnd:
            pretrained_trans_prob.scatter_add_(
                2,
                _xt.unsqueeze(-1),
                torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
            )

        # Temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        # Final step: remove mask token from sampling
        if i == steps - 1:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0

            prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
            mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
            if mask_has_zero_prob.any():
                # Masked positions left with no mass fall back to a uniform
                # distribution over token ids below the mask id.
                # NOTE(review): when this branch is taken the remaining
                # masked rows are not renormalised; confirm that is intended.
                num_zero_prob = mask_has_zero_prob.sum().item()
                uniform_prob = torch.zeros((num_zero_prob, trans_prob.shape[-1]), device=device, dtype=trans_prob.dtype)
                uniform_prob[:, :mask] = 1.0 / mask
                trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
            else:
                trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        # Already-clean tokens are never changed by the unmask step.
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # Update remasking_score buffer for remdm_conf mode
        if use_remasking and remasking_mode == "remdm_conf" and i < steps - 1:
            token_probs = F.softmax(unmask_rate, dim=-1)  # (B, L, V)
            chosen_probs = torch.gather(token_probs, dim=-1, index=new_xt.unsqueeze(-1)).squeeze(-1)  # (B, L)
            changed_mask_to_token = (xt == mask) & (new_xt != mask) & (new_xt != pad)
            remasking_score = torch.where(changed_mask_to_token, chosen_probs, remasking_score)

        # --- Remasking step ---
        if use_remasking and i < steps - 1:
            if disable_unmasking_planner or not (hasattr(model, 'planner') and model.planner is not None):
                remasking_conf = torch.zeros((batch_size, max_length), device=device)
            else:
                planner_out = model.planner(new_xt, t)
                remasking_conf = planner_out["remasking_conf"].squeeze(-1)  # (B, L)

            clean_index = (new_xt != mask) & (new_xt != pad)  # (B, L)

            if remasking_mode == "schedule_aware":
                new_xt = apply_schedule_aware_remasking(
                    model, new_xt, t, dt, remasking_conf, clean_index,
                    mask, neg_inf, batch_size,
                    unmask_quality_threshold=unmask_quality_threshold,
                )
                remasking_score_temp = None
            else:
                raise ValueError(f"Unknown remasking_mode: {remasking_mode}")

            # Top-k remasking path; not reached by "schedule_aware", which
            # leaves remasking_score_temp as None.
            if remasking_score_temp is not None:
                remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf)
                for j in range(batch_size):
                    k = min(num_remasking, int(clean_index[j].sum().item()))
                    if k > 0:
                        _, select_indices = torch.topk(remasking_score_temp[j], k=k)
                        new_xt[j, select_indices] = mask

            if return_trace:
                # Record every clean token that was turned back into a mask.
                for batch_idx in range(batch_size):
                    for pos in range(max_length):
                        if clean_index[batch_idx, pos] and new_xt[batch_idx, pos] == mask:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=t[batch_idx].item(),
                                    event_type="change",
                                    position=pos,
                                    token=mask,
                                )
                            )

        # --- Compute log probabilities for RND ---
        if compute_rnd:
            lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
            lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

            # Only positions that went mask -> token in this step contribute.
            changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

            log_policy_step = (lp * changed_mask).sum(dim=1)
            log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

            # Accumulate across steps. The previous code re-assigned log_rnd
            # here on every iteration, so the returned weight reflected the
            # final step only and discarded all earlier unmask and insertion
            # terms.
            step_log_rnd = log_pretrained_step - log_policy_step  # (B,)
            log_rnd = step_log_rnd if log_rnd is None else log_rnd + step_log_rnd

        # --- Insertion step ---
        if i != steps - 1:
            ext = torch.poisson(len_rate * dt).long()  # (B, L+1)

            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            # Insertions are only allowed in gaps inside the current sequence.
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            total_ext = ext.sum(dim=1)
            # Rows whose insertions would overflow max_length get none.
            valid = xt_len + total_ext <= max_length
            ext = ext * valid.view(batch_size, 1).long()
            # NOTE(review): total_ext is not recomputed after the overflow
            # filter, so for a filtered row new_len below still exceeds
            # max_length and the row is padded out with trailing masks.
            # Confirm whether total_ext should be ext.sum(dim=1) again here.

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # Shift each original token right by the insertions before it.
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]

            # Schedule-aware insertion quality filtering
            if use_remasking and not disable_insertion_planner:
                xt_tmp = apply_schedule_aware_insertion(
                    model, xt_tmp, new_xt, t, dt, ext, mask, pad, max_length,
                    orig_mask, new_pos_orig, quality_threshold
                )

                if compute_rnd:
                    # Compute corrected ext based on what actually stayed:
                    # the per-gap counts are scaled by the fraction of
                    # insertions that survived (truncated to integers).
                    ext_corrected = torch.zeros_like(ext)
                    for b in range(batch_size):
                        after_len = xt_tmp[b].ne(pad).sum().item()
                        orig_len = xt_len[b].item()
                        surviving_insertions = after_len - orig_len
                        if total_ext[b] > 0:
                            ratio = surviving_insertions / total_ext[b].item()
                            ext_corrected[b] = (ext[b].float() * ratio).long()
                else:
                    ext_corrected = ext
            else:
                ext_corrected = ext

            # Compute insertion log_rnd
            if compute_rnd:
                insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
                pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

                # Poisson log-likelihoods up to the log(k!) term, which is
                # identical for both models and cancels in the difference.
                log_policy_insert = (ext_corrected * torch.log(insertion_rate) - insertion_rate).sum(dim=1)
                log_pretrained_insert = (ext_corrected * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)

                log_insert_diff = log_pretrained_insert - log_policy_insert
                log_rnd = log_rnd + log_insert_diff
        else:
            xt_tmp = new_xt

        if return_trace:
            for batch_idx in range(batch_size):
                for j in range(max_length):
                    if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=t[batch_idx].item(),
                                event_type="change",
                                position=j,
                                token=new_xt[batch_idx, j].item(),
                            )
                        )

                if i != steps - 1:
                    # Insertions are logged from the right-most gap to the
                    # left; gap index max_length is not visited.
                    for gap_idx in reversed(range(max_length)):
                        if ext[batch_idx, gap_idx]:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=t[batch_idx].item(),
                                    event_type="insertion",
                                    position=gap_idx,
                                    token=mask,
                                )
                            )

        xt = xt_tmp
        t = t + dt

    return xt, log_rnd, sampling_trace
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def _decode_and_validate(model, tokenizer, samples):
    """Turn sampled token IDs into SMILES strings and keep the non-empty ones.

    Args:
        model: Object whose ``config.training`` mapping supplies the
            ``use_bracket_safe`` flag (treated as False when absent).
        tokenizer: Tokenizer exposing ``batch_decode``.
        samples: Batch of token-ID sequences to decode.

    Returns:
        (validSequences, valid_indices): for every sample whose converted
        SMILES is truthy, its longest '.'-separated fragment and the
        sample's position in the batch.
    """
    decoded = tokenizer.batch_decode(samples, skip_special_tokens=True)
    bracket_mode = model.config.training.get('use_bracket_safe', False)
    smiles_batch = batch_safe_to_smiles(decoded, use_bracket_safe=bracket_mode, fix=True)

    # Collect (batch position, longest fragment) pairs. The sort is stable,
    # so among equally long fragments the last one in the string is chosen.
    kept = [
        (position, sorted(smi.split('.'), key=len)[-1])
        for position, smi in enumerate(smiles_batch)
        if smi
    ]
    validSequences = [fragment for _, fragment in kept]
    valid_indices = [position for position, _ in kept]
    return validSequences, valid_indices
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
@torch.no_grad()
def sample_mol_buffer(
    model, pretrained, reward_model, tokenizer,
    steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    alpha=0.1,
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    temperature=1.0,
    use_quality_filter=True,
):
    """Generate molecules for training buffer. Always computes step-wise RND.

    Args:
        model: Finetuned policy model.
        pretrained: Frozen pretrained model.
        reward_model: Molecule scoring function. Called as
            ``reward_model(input_seqs=...)`` and expected to return a 2-D
            array with one row per sequence; column 0 is read as QED.
        tokenizer: SAFE tokenizer for decoding.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: "none", "both", "unmasking_only", or "insertion_only".
        alpha: RND scaling factor (rewards are divided by it).
        remasking_mode: Remasking strategy.
        num_remasking: Number of tokens to remask per step.
        quality_threshold: Threshold for insertion quality filtering. None if schedule-driven.
        temperature: Sampling temperature.
        use_quality_filter: If True, filter to QED>=0.6 and SA<=4.

    Returns:
        (valid_x, log_rnd, scalar_rewards, sampling_trace). The first three
        are tensors restricted to the kept molecules; they are empty
        (leading dimension 0) when nothing survives validation/filtering.
    """
    xt, log_rnd, trace = _diffusion_loop(
        model, steps, mask, pad, batch_size, max_length,
        quality_mode=quality_mode,
        compute_rnd=True,
        pretrained=pretrained,
        remasking_mode=remasking_mode,
        num_remasking=num_remasking,
        quality_threshold=quality_threshold,
        temperature=temperature,
    )

    device = xt.device
    samples = xt

    def _empty_result():
        # Shared shape for the "nothing to add to the buffer" outcomes.
        empty_x = torch.empty((0, max_length), dtype=torch.long, device=device)
        empty_log_rnd = torch.empty((0,), dtype=torch.float32, device=device)
        empty_rewards = torch.empty((0,), dtype=torch.float32, device=device)
        return empty_x, empty_log_rnd, empty_rewards, trace

    validSequences, valid_indices = _decode_and_validate(model, tokenizer, samples)
    print("len valid sequences:", len(validSequences))

    if len(validSequences) == 0:
        print("[WARNING] No valid molecules generated in this batch")
        return _empty_result()

    # Compute multi-objective rewards (sum over the score columns).
    score_vectors = reward_model(input_seqs=validSequences)
    scalar_rewards = torch.as_tensor(
        np.sum(score_vectors, axis=-1), dtype=torch.float32, device=device
    )
    print(f"scalar reward dim{len(scalar_rewards)}")

    valid_x_final = torch.stack([samples[idx] for idx in valid_indices], dim=0)
    valid_log_rnd = torch.stack([log_rnd[idx] for idx in valid_indices], dim=0)
    log_rnd = valid_log_rnd + (scalar_rewards / alpha)

    if not use_quality_filter:
        print(f"No quality filtering applied - using all {len(validSequences)} valid molecules")
        return valid_x_final, log_rnd, scalar_rewards, trace

    # Keep only quality sequences (QED >= 0.6 and SA <= 4).
    qed_scores = score_vectors[:, 0]
    # The SA column of the reward vector may have been rescaled to be
    # higher-is-better in roughly [0, 1], which makes "<= 4" vacuous.
    # Always threshold the raw SA oracle value (1-10 scale) instead, the
    # same way sample_mol_eval does.
    sa_scores = np.array(Oracle('sa')(validSequences))
    quality_mask = (qed_scores >= 0.6) & (sa_scores <= 4)

    n_quality = int(quality_mask.sum())
    print(f"Quality filtering: {n_quality}/{len(validSequences)} sequences pass (QED>=0.6, SA<=4)")

    if n_quality == 0:
        print("[WARNING] No quality molecules in this batch")
        return _empty_result()

    quality_mask_torch = torch.as_tensor(quality_mask, dtype=torch.bool, device=device)
    return (
        valid_x_final[quality_mask_torch],
        log_rnd[quality_mask_torch],
        scalar_rewards[quality_mask_torch],
        trace,
    )
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
@torch.no_grad()
def sample_mol_eval(
    model, reward_model, tokenizer,
    steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    temperature=1.0,
    evaluator=None,
    dataframe=False,
    unmask_quality_threshold=None,
):
    """Generate molecules for evaluation.

    Args:
        model: Finetuned policy model.
        reward_model: Molecule scoring function. Called as
            ``reward_model(input_seqs=...)``; column 0 of its output is
            read as raw QED.
        tokenizer: SAFE tokenizer for decoding.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: "none", "both", "unmasking_only", or "insertion_only".
        remasking_mode: Remasking strategy.
        num_remasking: Number of tokens to remask per step.
        quality_threshold: Threshold for insertion quality filtering. Pass None
            to use schedule-driven deletion with no threshold gate
        temperature: Sampling temperature.
        evaluator: TDC Evaluator for diversity (created if None).
        dataframe: If True, include a pandas DataFrame in the return.
        unmask_quality_threshold: Forwarded unchanged to the diffusion loop.

    Returns:
        Without dataframe:
            (validSequences, qed, sa, uniqueness, diversity, quality, valid_fraction)
        With dataframe:
            (validSequences, qed, sa, valid_fraction, uniqueness, diversity, quality, df)
        validSequences is the raw list including duplicates; qed/sa are scored
        on the unique set (in first-seen order). Caller can dedup with
        set(validSequences). The dataframe (when requested) has one row per
        unique molecule and is empty when no valid molecule was produced.
        When there are no valid molecules, qed and sa are each ``[0.0]``.
    """
    if evaluator is None:
        evaluator = Evaluator('diversity')

    xt, _, _ = _diffusion_loop(
        model, steps, mask, pad, batch_size, max_length,
        quality_mode=quality_mode,
        compute_rnd=False,
        remasking_mode=remasking_mode,
        num_remasking=num_remasking,
        quality_threshold=quality_threshold,
        temperature=temperature,
        unmask_quality_threshold=unmask_quality_threshold,
    )

    # Decode and keep the largest fragment of every non-empty SMILES.
    validSequences, _ = _decode_and_validate(model, tokenizer, xt)

    print("len valid sequences:", len(validSequences))
    valid_fraction = len(validSequences) / batch_size
    # dict.fromkeys de-duplicates while keeping a deterministic order.
    uniqueSequences = list(dict.fromkeys(validSequences))
    uniqueness = len(uniqueSequences) / len(validSequences) if validSequences else 0
    diversity = evaluator(uniqueSequences) if uniqueSequences else 0

    # Calculate quality (unique sequences with QED >= 0.6 and SA <= 4)
    if uniqueSequences:
        score_vectors = reward_model(input_seqs=uniqueSequences)
        qed = score_vectors[:, 0]  # Raw QED (0-1)

        # Always use raw SA (1-10 scale) for quality filtering
        sa = np.array(Oracle('sa')(uniqueSequences))

        quality_count = int(np.sum((qed >= 0.6) & (sa <= 4)))
        # Normalised by the requested batch size, not by the valid count.
        quality = quality_count / batch_size
        print(f'Quality:\t{quality}')
    else:
        qed = [0.0]
        sa = [0.0]
        quality = 0.0

    if dataframe:
        # All columns must have equal length; with no molecules the table is
        # empty rather than carrying the [0.0] placeholders.
        has_rows = bool(uniqueSequences)
        df = pd.DataFrame({
            "Mol Sequence": uniqueSequences,
            "QED": qed if has_rows else [],
            "SA": sa if has_rows else [],
        })
        return validSequences, qed, sa, valid_fraction, uniqueness, diversity, quality, df

    return validSequences, qed, sa, uniqueness, diversity, quality, valid_fraction
|
a2d2_mol/mol_dataset.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""
|
| 3 |
+
Adapter to use HuggingFace datasets with the any-length discrete diffusion model.
|
| 4 |
+
This module converts HuggingFace datasets (like datamol-io/safe-drugs) into the format
|
| 5 |
+
expected by the training pipeline.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import Dataset, DataLoader
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
import pytorch_lightning as pl
|
| 12 |
+
from safe.tokenizer import SAFETokenizer
|
| 13 |
+
from mol_utils.bracket_safe_converter import safe2bracketsafe
|
| 14 |
+
from typing import Optional, List
|
| 15 |
+
import re
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_tokenizer():
    """Load the pretrained SAFE tokenizer and register the bracket tokens."""
    safe_tokenizer = SAFETokenizer.from_pretrained('datamol-io/safe-gpt')
    tokenizer = safe_tokenizer.get_pretrained()
    # '<' and '>' are the extra symbols needed for bracket_safe.
    tokenizer.add_tokens(['<', '>'])
    return tokenizer
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Collator:
    """Batch collator that tokenizes SAFE / bracket-SAFE strings."""

    def __init__(self, config, tokenizer=None):
        """Store the tokenizer (building the default one if omitted) and the
        ``max_length`` / ``use_bracket_safe`` settings read from *config*."""
        if tokenizer is None:
            tokenizer = get_tokenizer()
        self.tokenizer = tokenizer
        self.max_length = config.interpolant.max_length
        self.use_bracket_safe = config.training.get('use_bracket_safe', False)

    @staticmethod
    def _raw_text(example):
        """Return the text of one example.

        Dict examples are probed for 'input', then 'labels', then 'smiles'
        (first key present wins, '' if none is); anything else is used as-is.
        """
        if not isinstance(example, dict):
            return example
        for key in ('input', 'labels', 'smiles'):
            if key in example:
                return example[key]
        return ''

    def __call__(self, examples):
        """Tokenize *examples* into a dict of 'input_ids' and 'attention_mask'."""
        texts = []
        for example in examples:
            text = self._raw_text(example)
            texts.append(safe2bracketsafe(text) if self.use_bracket_safe else text)

        encoded = self.tokenizer(
            texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )

        # Return a plain dict with only the two fields the model consumes;
        # token_type_ids (if the tokenizer produced them) are dropped.
        return {field: encoded[field] for field in ('input_ids', 'attention_mask')}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class HFDatasetAdapter(Dataset):
    """Expose a HuggingFace dataset as ``{'input': ..., 'labels': ...}`` records."""

    def __init__(self, hf_dataset, tokenizer, smiles_column='smiles', max_length=1024, convert_to_safe=False, is_streaming=False):
        """
        Args:
            hf_dataset: HuggingFace dataset object (streaming or regular)
            tokenizer: Tokenizer instance (stored only; not used here)
            smiles_column: Name of the column holding the molecule strings
            max_length: Maximum sequence length (stored only)
            convert_to_safe: Stored flag; no conversion is performed here
            is_streaming: Whether dataset is in streaming mode
        """
        self.tokenizer = tokenizer
        self.smiles_column = smiles_column
        self.max_length = max_length
        self.convert_to_safe = convert_to_safe
        self.is_streaming = is_streaming

        if not is_streaming:
            # Materialise the raw strings now; tokenization happens in the collator.
            print(f'Initializing HF dataset adapter with {len(hf_dataset)} samples...')
            self.data = [
                {'input': text, 'labels': text}
                for text in (row[smiles_column] for row in hf_dataset)
                if text  # rows with an empty string are dropped
            ]
            print(f'Processed {len(self.data)} valid samples')
            return

        # Streaming: keep a handle on the stream; its length is unknown.
        self.data = hf_dataset
        self._length = None
        print(f'Initialized streaming dataset adapter')

    def __len__(self):
        """Real length in memory mode; a placeholder (10M) for unsized streams."""
        if not self.is_streaming:
            return len(self.data)
        if self._length is None:
            # Large stand-in so length-based consumers do not break.
            return 10_000_000
        return self._length

    def __getitem__(self, idx):
        """Index into the materialised records; unsupported when streaming."""
        if not self.is_streaming:
            return self.data[idx]
        raise NotImplementedError("Streaming datasets should be iterated, not indexed")

    def __iter__(self):
        """Yield records, converting and filtering on the fly when streaming."""
        if not self.is_streaming:
            yield from self.data
            return
        column = self.smiles_column
        for row in self.data:
            text = row[column]
            if text:
                yield {'input': text, 'labels': text}
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class HFDataModule(pl.LightningDataModule):
    """PyTorch Lightning DataModule for HuggingFace datasets."""

    # Streams have no known length, so the validation slice is sized as
    # ``val_split`` of this assumed nominal sample count.
    _STREAMING_NOMINAL_SIZE = 100_000

    def __init__(
        self,
        config,
        dataset_name: str,
        tokenizer: SAFETokenizer,
        smiles_column: str = 'smiles',
        val_split: float = 0.1,
        test_split: Optional[float] = None,
        streaming: bool = True,
        max_train_samples: Optional[int] = None,
        max_val_samples: Optional[int] = None,
    ):
        """
        Args:
            config: Configuration object containing training parameters
                (reads ``interpolant.max_length``,
                ``training.per_gpu_batch_size`` and ``training.cpus``)
            dataset_name: HuggingFace dataset identifier (e.g., "datamol-io/safe-gpt")
            tokenizer: Tokenizer instance handed to the collator and adapters
            smiles_column: Name of column containing SMILES strings
            val_split: Fraction of data to use for validation
            test_split: Optional fraction of data to use for testing
                (only consulted in non-streaming mode)
            streaming: Whether to use streaming mode (recommended for large datasets)
            max_train_samples: Maximum number of training samples to use (for non-streaming)
            max_val_samples: Maximum number of validation samples to use (for non-streaming)
        """
        super().__init__()
        self.config = config
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.smiles_column = smiles_column
        self.max_length = config.interpolant.max_length
        self.batch_size = config.training.per_gpu_batch_size
        self.num_workers = config.training.get('cpus', 4)
        self.val_split = val_split
        self.test_split = test_split
        self.streaming = streaming
        self.max_train_samples = max_train_samples
        self.max_val_samples = max_val_samples

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        # Initialize collator
        self.collator = Collator(config, tokenizer)

    # ------------------------------------------------------------------ setup

    def setup(self, stage: Optional[str] = None):
        """Load and split the dataset."""
        print(f'Loading dataset: {self.dataset_name} (streaming={self.streaming})')
        if self.streaming:
            self._setup_streaming()
        else:
            self._setup_in_memory()

    @staticmethod
    def _base_split(raw_dataset):
        """Return the 'train' split, or the first available split if absent."""
        if 'train' in raw_dataset:
            return raw_dataset['train']
        return raw_dataset[next(iter(raw_dataset.keys()))]

    def _adapt(self, data, is_streaming):
        """Wrap *data* in an HFDatasetAdapter using this module's settings."""
        return HFDatasetAdapter(
            data,
            self.tokenizer,
            self.smiles_column,
            self.max_length,
            is_streaming=is_streaming,
        )

    def _setup_streaming(self):
        """Carve train/val streams out of a single streamed split."""
        raw_dataset = load_dataset(self.dataset_name, streaming=True)
        stream = self._base_split(raw_dataset)

        # Validation = first val_size samples; training = everything after.
        val_size = int(self._STREAMING_NOMINAL_SIZE * self.val_split)
        self.train_dataset = self._adapt(stream.skip(val_size), is_streaming=True)
        self.val_dataset = self._adapt(stream.take(val_size), is_streaming=True)

        print('Streaming dataset initialized - samples will be loaded on-the-fly')

    def _setup_in_memory(self):
        """Load the full dataset and build train/val(/test) adapters."""
        raw_dataset = load_dataset(self.dataset_name)
        train_data = self._base_split(raw_dataset)

        # Limit samples if specified (applied before any splitting)
        if self.max_train_samples:
            train_data = train_data.select(range(min(self.max_train_samples, len(train_data))))

        # Use an existing validation split when there is one, else split off
        # val_split of the training data.
        if 'validation' in raw_dataset or 'val' in raw_dataset:
            val_key = 'validation' if 'validation' in raw_dataset else 'val'
            val_data = raw_dataset[val_key]
        else:
            split_dataset = train_data.train_test_split(test_size=self.val_split, seed=42)
            train_data = split_dataset['train']
            val_data = split_dataset['test']

        # Limit validation samples if specified
        if self.max_val_samples:
            val_data = val_data.select(range(min(self.max_val_samples, len(val_data))))

        # Test set: split one off on request, otherwise reuse an existing one.
        if self.test_split and 'test' not in raw_dataset:
            split_dataset = train_data.train_test_split(test_size=self.test_split, seed=42)
            train_data = split_dataset['train']
            self.test_dataset = self._adapt(split_dataset['test'], is_streaming=False)
        elif 'test' in raw_dataset:
            self.test_dataset = self._adapt(raw_dataset['test'], is_streaming=False)

        self.train_dataset = self._adapt(train_data, is_streaming=False)
        self.val_dataset = self._adapt(val_data, is_streaming=False)

        print(f'Dataset splits - Train: {len(self.train_dataset)}, Val: {len(self.val_dataset)}')
        if self.test_dataset:
            print(f'Test: {len(self.test_dataset)}')

    # ------------------------------------------------------------ dataloaders

    def _streaming_loader(self, adapter):
        """DataLoader over the raw HF stream held by *adapter*.

        NOTE(review): the raw stream is passed instead of the adapter, so the
        adapter's column mapping and empty-string filtering are bypassed and
        the collator sees the dataset's own rows - confirm this is intended
        when ``smiles_column`` is not one of the keys the collator looks up.
        """
        return DataLoader(
            adapter.data,  # Use the raw HF streaming dataset
            batch_size=self.batch_size,
            collate_fn=self.collator,
            num_workers=0,  # Required for streaming with skip/take operations
            pin_memory=True,
            shuffle=False,  # Cannot shuffle streaming datasets
        )

    def _in_memory_loader(self, dataset, shuffle):
        """DataLoader over a materialised adapter, using the configured workers."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=self.collator,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        """Training loader (shuffled unless streaming)."""
        if self.streaming:
            return self._streaming_loader(self.train_dataset)
        return self._in_memory_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Validation loader (never shuffled)."""
        if self.streaming:
            return self._streaming_loader(self.val_dataset)
        return self._in_memory_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Test loader, or None when no non-empty test dataset was set up."""
        if self.test_dataset:
            return self._in_memory_loader(self.test_dataset, shuffle=False)
        return None
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def setup_hf_data_and_update_config(config, dataset_name="datamol-io/safe-gpt", smiles_column="smiles", streaming=True):
    """
    Build the HuggingFace data module and record tokenizer facts on the config.

    The config's ``interpolant.tokens``, ``interpolant.pad_token`` and
    ``interpolant.mask_token`` fields are overwritten in place from the
    tokenizer before the data module is created.

    Args:
        config: Hydra config object (mutated in place)
        dataset_name: HuggingFace dataset identifier
        smiles_column: Name of column containing SMILES strings
        streaming: Whether to use streaming mode (recommended for large datasets like safe-gpt)

    Returns:
        HFDataModule instance (created with a validation split of 0.1)
    """
    tok = get_tokenizer()

    # Publish vocabulary size and special-token IDs to the rest of the pipeline.
    interpolant_cfg = config.interpolant
    interpolant_cfg.tokens = len(tok)
    interpolant_cfg.pad_token = tok.pad_token_id
    interpolant_cfg.mask_token = tok.mask_token_id

    return HFDataModule(
        config=config,
        dataset_name=dataset_name,
        tokenizer=tok,
        smiles_column=smiles_column,
        val_split=0.1,
        streaming=streaming,
    )
|
a2d2_mol/mol_scoring/oracle/fpscores.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24a4392f5c673e79c0446af3c4d8e458293b5fecaa244328e76741ead9d21dbf
|
| 3 |
+
size 9048931
|
a2d2_mol/mol_scoring/scoring_functions.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoModelForMaskedLM
|
| 2 |
+
import numpy as np
|
| 3 |
+
from tdc import Oracle, Evaluator
|
| 4 |
+
|
| 5 |
+
class MolScoringFunctions:
    """Evaluate a configurable list of molecular property oracles on SMILES."""

    def __init__(self, score_func_names=None, device=None, sa_transform='inverse'):
        """
        Class for generating score vectors given generated sequence

        Args:
            score_func_names: list of scoring function names to be evaluated,
                in output-column order. Available names: 'qed', 'sa'.
                None means no scoring functions.
            device: currently unused; accepted so existing callers keep working.
            sa_transform: how to transform SA scores to higher-is-better ~[0,1]:
                'linear': (10-SA)/9 — range ~0-1, stronger gradient
                'inverse' (default) or any other value: 1/(1+SA) —
                    range ~0.09-0.5, weak gradient
        """
        self.score_func_names = [] if score_func_names is None else score_func_names
        self.sa_transform = sa_transform

        self.all_funcs = {
            'qed': Oracle('qed'),
            'sa': Oracle('sa'),
        }

    def forward(self, input_seqs):
        """Score *input_seqs* with every configured function.

        Returns:
            float32 numpy array of shape (num_sequences, num_functions);
            the second dimension is 0 when no functions are configured.
        """
        if not self.score_func_names:
            # Keep the documented 2-D shape even with nothing to evaluate.
            return np.zeros((len(input_seqs), 0), dtype=np.float32)

        columns = []
        for name in self.score_func_names:
            score = self.all_funcs[name](input_seqs)

            # Transform SA to be maximized and normalized (original SA: 1-10,
            # lower is better) so it is comparable to QED.
            if name == 'sa':
                raw_sa = np.array(score)
                if self.sa_transform == 'linear':
                    score = np.maximum((10.0 - raw_sa) / 9.0, 0.0)  # clipped at 0
                else:
                    score = 1.0 / (1.0 + raw_sa)

            columns.append(score)

        # Stack per-function rows, then transpose to one row per sequence.
        return np.asarray(columns, dtype=np.float32).T

    def __call__(self, input_seqs: list):
        """Alias for :meth:`forward`."""
        return self.forward(input_seqs)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def unittest():
    """Smoke check: score one example molecule with QED and SA and print the result."""
    example_smiles = ['CCOc1cc(ccc1NC(=O)N[C@@H]2CCCC[C@@H]2O)F']
    scorer = MolScoringFunctions(score_func_names=['qed', 'sa'])

    result = scorer(input_seqs=example_smiles)
    print(result)
    print(len(result))


# Run the smoke check when the module is executed as a script.
if __name__ == '__main__':
    unittest()
|
a2d2_mol/mol_utils/bracket_safe_converter.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
# Patch: stub out `auto_docstring` if missing from transformers.utils
|
| 17 |
+
# (needed by safe.trainer.model in newer safe versions)
|
| 18 |
+
import transformers.utils as _tu
|
| 19 |
+
if not hasattr(_tu, 'auto_docstring'):
|
| 20 |
+
_tu.auto_docstring = lambda *a, **kw: (lambda fn: fn)
|
| 21 |
+
|
| 22 |
+
from safe.converter import *
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class BracketSAFEConverter(SAFEConverter):
    """``SAFEConverter`` variant whose encoder writes attachment points as ``<n>``.

    Only ``encoder`` is overridden. Attachment points between fragments are
    emitted as bracketed labels (``<1>`` ... ``<9>``, then ``<%10>`` and up)
    instead of bare ring-closure digits.
    """

    def encoder(
        self,
        inp: Union[str, dm.Mol],
        canonical: bool = True,
        randomize: Optional[bool] = False,
        seed: Optional[int] = None,
        constraints: Optional[List[dm.Mol]] = None,
        allow_empty: bool = False,
        rdkit_safe: bool = True,
    ):
        """Fragment a molecule and return its bracket-SAFE string.

        Args:
            inp: SMILES string or molecule to encode.
            canonical: Forwarded to ``dm.to_smiles``. When True and
                ``randomize`` is falsy, fragments are ordered by decreasing
                atom count; when True, attachment labels are assigned in
                sorted order of the attachment tokens.
            randomize: When truthy, fragments are shuffled with a generator
                seeded by ``seed``; if ``canonical`` is also False the input
                is first passed through ``self.randomize``.
            seed: Seed given to ``np.random.default_rng``.
            constraints: Optional query molecules. A candidate bond whose two
                atoms both belong to one match of a constraint is not cut.
            allow_empty: Forwarded to ``self._fragment``.
            rdkit_safe: When True, parenthesised ring-closure labels matching
                the final clean-up pattern are unwrapped.

        Returns:
            The fragment SMILES joined by ``.`` with every attachment point
            rewritten as a ``<n>`` label.
        """
        rng = None
        if randomize:
            rng = np.random.default_rng(seed)
            if not canonical:
                inp = self.randomize(dm.to_mol(inp, remove_hs=False), rng)

        if isinstance(inp, dm.Mol):
            inp = dm.to_smiles(inp, canonical=canonical, randomize=False, ordered=False)

        # NOTE(review): the returned value is not used further down.
        branch_numbers = self._find_branch_number(inp)

        mol = dm.to_mol(inp, remove_hs=False)
        if self.ignore_stereo:
            mol = dm.remove_stereochemistry(mol)

        # Dummy atoms already present in the input get isotope labels 1..k and
        # have their atom-map number cleared; new cut points are numbered
        # starting right after them.
        bond_map_id = 1
        for atom in mol.GetAtoms():
            if atom.GetAtomicNum() != 0:
                continue
            atom.SetAtomMapNum(0)
            atom.SetIsotope(bond_map_id)
            bond_map_id += 1

        if self.require_hs:
            mol = dm.add_hs(mol)
        matching_bonds = self._fragment(mol, allow_empty=allow_empty)

        protected_matches = []
        if constraints is not None:
            protected_matches = list(
                itertools.chain.from_iterable(
                    mol.GetSubstructMatches(constraint, uniquify=True)
                    for constraint in constraints
                )
            )

        cut_bond_ids = []
        for atom_a, atom_b in matching_bonds:
            # A bond lying entirely inside one protected substructure must not
            # be cut; a bond linking two separately preserved substructures
            # is fine.
            inside_protected = any(
                (atom_a in match and atom_b in match) for match in protected_matches
            )
            if not inside_protected:
                cut_bond_ids.append(mol.GetBondBetweenAtoms(atom_a, atom_b).GetIdx())

        if cut_bond_ids:
            labels = [
                (bond_map_id + offset, bond_map_id + offset)
                for offset in range(len(cut_bond_ids))
            ]
            mol = Chem.FragmentOnBonds(mol, cut_bond_ids, dummyLabels=labels)

        frags = list(Chem.GetMolFrags(mol, asMols=True))
        if randomize:
            frags = rng.permutation(frags).tolist()
        elif canonical:
            frags.sort(key=lambda frag: frag.GetNumAtoms(), reverse=True)

        frags_str = []
        for frag in frags:
            # Root each fragment SMILES at its first non-dummy atom.
            real_atom_ids = [
                atom.GetIdx() for atom in frag.GetAtoms() if atom.GetAtomicNum() != 0
            ]
            frags_str.append(
                Chem.MolToSmiles(
                    frag,
                    isomericSmiles=True,
                    canonical=True,  # needs to always be true
                    rootedAtAtom=real_atom_ids[0],
                )
            )
        scaffold_str = ".".join(frags_str)

        # don't capture atom mapping in the scaffold
        attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str))
        if canonical:
            attach_pos = sorted(attach_pos)
        for label, attach in enumerate(attach_pos, start=1):
            digits = str(label) if label < 10 else f"%{label}"
            bracketed = '<' + digits + '>'  # bracket added
            # we cannot have anything of the form "\([@=-#-$/\]*\d+\)"
            scaffold_str = re.sub(r"(" + re.escape(attach) + r")", bracketed, scaffold_str)

        # Drop the parentheses around a bracketed label standing alone.
        scaffold_str = re.sub(r"\((<[\%\d+]*>)\)", r"\g<1>", scaffold_str)  # bracket added
        # furthermore, we autoapply rdkit-compatible digit standardization.
        if rdkit_safe:
            scaffold_str = re.sub(
                r"\(([=-@#\/\\]{0,2})(%?\d{1,2})\)", r"\g<1>\g<2>", scaffold_str
            )
        return scaffold_str
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def safe2bracketsafe(safe_str):
    """Convert a SAFE/SMILES string to bracket-SAFE, falling back to the input.

    The string is parsed with ``Chem.MolFromSmiles`` and re-encoded by
    ``BracketSAFEConverter`` with ``allow_empty=True``, ``canonical=False``
    and ``randomize=True``.

    Args:
        safe_str: String to convert.

    Returns:
        The bracket-SAFE encoding, or ``safe_str`` unchanged if any step of
        the conversion raises.
    """
    try:
        mol = Chem.MolFromSmiles(safe_str)
        return BracketSAFEConverter().encoder(
            mol, allow_empty=True, canonical=False, randomize=True
        )
    except Exception:
        # Best-effort conversion: hand the input back on failure. Catching
        # ``Exception`` rather than using a bare ``except`` lets
        # KeyboardInterrupt / SystemExit propagate.
        return safe_str
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def bracketsafe2safe(safe_str):
    """Rewrite bracket-SAFE attachment labels as plain ring-closure labels.

    Every ``<n>`` token (also ``<%n>``, which is what the bracket encoder
    emits for labels of 10 and above) is replaced by the number
    ``n + starting_num``, written as a single digit below 10 and as ``%NN``
    otherwise. ``starting_num`` is one more than the largest ring-closure
    label already used outside the bracket tokens, or 0 when there is none,
    so the new labels cannot collide with existing ones. All spaces are
    removed from the result.

    Args:
        safe_str: Bracket-SAFE string.

    Returns:
        The string with bracket tokens replaced by ring-closure labels.
    """
    bracket_label = re.compile(r'<%?(\d+)>')

    # Blank out the bracket tokens before collecting the labels already in
    # use, so that digits inside ``<12>`` or ``<%10>`` are not mistaken for
    # in-fragment ring closures.
    outside_brackets = bracket_label.sub(' ', safe_str)
    used_labels = [int(d) for d in re.findall(r'(?<!%)\d', outside_brackets)]
    used_labels += [int(d) for d in re.findall(r'%(\d+)', outside_brackets)]
    starting_num = max(used_labels) + 1 if used_labels else 0

    def _renumber(match):
        label = int(match.group(1)) + starting_num
        return '%' + str(label) if label >= 10 else str(label)

    # Spaces are stripped from the whole string, as before.
    return bracket_label.sub(_renumber, safe_str).replace(' ', '')
|
a2d2_mol/mol_utils/utils.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Console logger utilities.
|
| 2 |
+
|
| 3 |
+
Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
|
| 4 |
+
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import fsspec
|
| 9 |
+
import lightning
|
| 10 |
+
import torch
|
| 11 |
+
from timm.scheduler import CosineLRScheduler
|
| 12 |
+
import argparse
|
| 13 |
+
import numpy as np
|
| 14 |
+
import random
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
def sample_categorical_logits(logits, dtype=torch.float64):
    """Draw one index along the last dimension by perturbing logits with Gumbel noise.

    The logits do not need to be log-softmaxed. Uniform noise is drawn in
    ``dtype`` with the shape of ``logits``, turned into Gumbel noise (with a
    small epsilon guarding both logarithms), added to the logits, and the
    argmax over the last dimension is returned.
    """
    eps = 1e-10
    uniform = torch.rand_like(logits, dtype=dtype)
    gumbel_noise = -(eps - (uniform + eps).log()).log()
    perturbed = logits + gumbel_noise
    return perturbed.argmax(dim=-1)
|
| 21 |
+
|
| 22 |
+
def fsspec_exists(filename):
    """Return whether ``filename`` exists on the fsspec filesystem resolved from it."""
    filesystem = fsspec.core.url_to_fs(filename)[0]
    return filesystem.exists(filename)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def fsspec_listdir(dirname):
    """List ``dirname`` through the fsspec filesystem resolved from it."""
    filesystem = fsspec.core.url_to_fs(dirname)[0]
    return filesystem.ls(dirname)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def fsspec_mkdirs(dirname, exist_ok=True):
    """Create ``dirname`` through the fsspec filesystem resolved from it.

    ``exist_ok`` is forwarded to the filesystem's ``makedirs``.
    """
    filesystem = fsspec.core.url_to_fs(dirname)[0]
    filesystem.makedirs(dirname, exist_ok=exist_ok)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def print_nans(tensor, name):
    """Print ``name`` followed by ``tensor`` when the tensor contains a NaN."""
    contains_nan = torch.isnan(tensor).any()
    if not contains_nan:
        return
    print(name, tensor)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class CosineDecayWarmupLRScheduler(
        CosineLRScheduler,
        torch.optim.lr_scheduler._LRScheduler):
    """``CosineLRScheduler`` wrapper that keeps its own step counter.

    The counter ``_last_epoch`` starts at -1 and the scheduler is stepped
    once to position 0 during construction.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        """Advance the counter (or set it to ``epoch``) and update the rate."""
        self._last_epoch = self._last_epoch + 1 if epoch is None else epoch
        # We call either step or step_update, depending on whether the
        # scheduler is used every epoch or every step. Otherwise, lightning
        # will always call step (i.e., meant for each epoch), and if the
        # scheduler interval is set to "step", the learning rate update
        # will be wrong.
        if not self.t_in_epochs:
            super().step_update(num_updates=self._last_epoch)
            return
        super().step(epoch=self._last_epoch)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class LoggingContext:
    """Context manager that temporarily changes a logger's level and/or handler.

    On entry the level (if given) is applied and the handler (if given) is
    attached; on exit the previous level is restored, the handler is removed
    and, when ``close`` is true, closed.
    """

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger, self.level = logger, level
        self.handler, self.close = handler, close

    def __enter__(self):
        if self.level is not None:
            # Remember the current level so __exit__ can put it back.
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if not self.handler:
            return
        self.logger.removeHandler(self.handler)
        if self.close:
            self.handler.close()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    Args:
        name: Name passed to ``logging.getLogger``.
        level: Level set on the returned logger.

    Returns:
        The logger, with each of its logging methods wrapped in Lightning's
        ``rank_zero_only`` decorator.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    # The loop variable is deliberately not named ``level`` so that the
    # ``level`` parameter is not shadowed.
    for method_name in ('debug', 'info', 'warning', 'error',
                        'exception', 'fatal', 'critical'):
        setattr(logger,
                method_name,
                lightning.pytorch.utilities.rank_zero_only(
                    getattr(logger, method_name)))

    return logger
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def str2bool(v):
    """Parse a command-line style boolean.

    Booleans are returned unchanged. Otherwise the lower-cased value must be
    one of yes/true/t/y/1 (True) or no/false/f/n/0 (False).

    Raises:
        argparse.ArgumentTypeError: If the value is none of the accepted words.
    """
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def set_seed(seed, use_cuda):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED`` and, when ``use_cuda`` is true, seeds the
    CUDA generators. A confirmation line is printed at the end.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    # torch.backends.cudnn.deterministic = True
    if use_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    print(f'=> Seed of the run set to {seed}')
|
| 135 |
+
|
a2d2_mol/mol_utils/utils_chem.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import random
|
| 18 |
+
import safe as sf
|
| 19 |
+
import datamol as dm
|
| 20 |
+
from contextlib import suppress
|
| 21 |
+
from rdkit import Chem, RDLogger
|
| 22 |
+
RDLogger.DisableLog('rdApp.*')
|
| 23 |
+
|
| 24 |
+
# https://github.com/datamol-io/safe/blob/main/safe/sample.py
|
| 25 |
+
# https://github.com/jensengroup/GB_GA/blob/master/crossover.py
|
| 26 |
+
def safe_to_smiles(safe_str, fix=True):
    """Decode a SAFE string to canonical SMILES with ``sf.decode``.

    With ``fix`` enabled, the string is first split on ``.`` and only the
    fragments that ``sf.decode`` can decode on their own are kept. Decoding
    errors are ignored (``ignore_errors=True``).
    """
    if fix:
        decodable = []
        for frag in safe_str.split('.'):
            if sf.decode(frag, ignore_errors=True) is not None:
                decodable.append(frag)
        safe_str = '.'.join(decodable)
    return sf.decode(safe_str, canonical=True, ignore_errors=True)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _safe_to_smiles_worker(args):
    """Convert one SAFE string to SMILES; used as the per-item batch worker.

    ``args`` is a ``(safe_str, use_bracket_safe, fix)`` tuple. Any exception
    raised during the conversion (including the import of the bracket
    converter) yields ``None``.
    """
    raw, from_bracket, fix = args
    try:
        from mol_utils.bracket_safe_converter import bracketsafe2safe
        plain = bracketsafe2safe(raw) if from_bracket else raw
        return safe_to_smiles(plain, fix=fix)
    except Exception:
        return None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def batch_safe_to_smiles(safe_strings, use_bracket_safe=False, fix=True, num_workers=None):
    """
    Convert a batch of SAFE strings to SMILES, using a thread pool for larger batches.

    Every item is converted by ``_safe_to_smiles_worker``, so a failing item
    yields ``None`` regardless of the batch size.

    Args:
        safe_strings: List of SAFE format strings
        use_bracket_safe: Whether to convert from bracket SAFE format first
        fix: Whether to fix invalid fragments
        num_workers: Number of worker threads (default: min(cpu_count, len(safe_strings), 8))

    Returns:
        List of SMILES strings (None for invalid molecules), in input order
    """
    from concurrent.futures import ThreadPoolExecutor
    import os

    n = len(safe_strings)
    if n == 0:
        return []

    args_list = [(s, use_bracket_safe, fix) for s in safe_strings]

    # For small batches, use sequential processing (pool overhead not worth
    # it), but through the same worker as the pooled path so that error
    # handling does not depend on the batch size.
    if n <= 4:
        return [_safe_to_smiles_worker(item) for item in args_list]

    if num_workers is None:
        num_workers = min(os.cpu_count() or 4, n, 8)

    # Threads rather than processes: no pickling of arguments/results and a
    # lower startup cost.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        results = list(executor.map(_safe_to_smiles_worker, args_list))

    return results
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def batch_validate_and_extract(smiles_list, samples_tensor, log_rnd_tensor):
    """
    Keep the truthy SMILES entries and reduce each to its longest fragment.

    Args:
        smiles_list: List of SMILES strings (may contain None for invalid)
        samples_tensor: Not read by this function
        log_rnd_tensor: Not read by this function

    Returns:
        valid_sequences: For every kept entry, the longest ``.``-separated
            fragment by string length (the last one on ties)
        valid_indices: Positions of the kept entries in ``smiles_list``
    """
    kept = [(idx, smiles) for idx, smiles in enumerate(smiles_list) if smiles]

    valid_indices = [idx for idx, _ in kept]
    # ``max`` returns the first maximal element, so scan the fragments in
    # reverse to pick the last of several equally long ones.
    valid_sequences = [
        max(reversed(smiles.split('.')), key=len) for _, smiles in kept
    ]
    return valid_sequences, valid_indices
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def filter_by_substructure(sequences, substruct):
    """Filter ``sequences`` with ``sf.utils.filter_by_substructure_constraints``.

    The query is built from ``substruct`` by standardizing its attachment
    points, deleting ``*`` atoms, and re-parsing the resulting SMILES as SMARTS.
    """
    standardized = sf.utils.standardize_attach(substruct)
    without_dummies = Chem.DeleteSubstructs(
        Chem.MolFromSmarts(standardized), Chem.MolFromSmiles('*')
    )
    query = Chem.MolFromSmarts(Chem.MolToSmiles(without_dummies))
    return sf.utils.filter_by_substructure_constraints(sequences, query)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def mix_sequences(prefix_sequences, suffix_sequences, prefix, suffix, num_samples=1):
    """Link ``prefix`` and ``suffix`` with linkers sliced out of sample molecules.

    Each sequence is sliced against the prefix (respectively suffix) query
    with ``sf.utils.MolSlicer``; the second element of every slicing result
    is collected as a linker. Sequences that fail to parse or slice are
    skipped silently. The linkers are then used to join ``prefix`` and
    ``suffix``.

    Args:
        prefix_sequences: Molecules to slice against ``prefix``.
        suffix_sequences: Molecules to slice against ``suffix``.
        prefix: SMARTS of the first fragment.
        suffix: SMARTS of the second fragment.
        num_samples: Maximum number of linked molecules to return.

    Returns:
        Up to ``num_samples`` non-empty linked results.
    """
    mol_linker_slicer = sf.utils.MolSlicer(require_ring_system=False)
    prefix_query = dm.from_smarts(prefix)
    suffix_query = dm.from_smarts(suffix)

    def _collect_linkers(sequences, query):
        # Best effort: any failure for one sequence just skips that sequence.
        found = []
        for seq in sequences:
            with suppress(Exception):
                found.append(mol_linker_slicer(dm.to_mol(seq), query)[1])
        return found

    linkers = _collect_linkers(prefix_sequences, prefix_query)
    linkers += _collect_linkers(suffix_sequences, suffix_query)
    linkers = [linker for linker in linkers if linker is not None]

    linked = []
    for linker_idx, linker in enumerate(linkers):
        linked.extend(mol_linker_slicer.link_fragments(linker, prefix, suffix))
        # Stop once the index of the processed linker exceeds num_samples.
        if linker_idx > num_samples:
            break

    # Drop empty results once, after the loop, so that the returned list
    # never contains falsy entries even when the loop ended via ``break``.
    linked = [x for x in linked if x]
    return linked[:num_samples]
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def cut(smiles):
    """Collect fragment SMILES obtained by cutting random non-ring single bonds.

    Three attempts are made; each picks one non-ring single bond at random,
    breaks it (dummy atoms labelled 1 are added at the cut), and adds the
    SMILES of the resulting fragments to the result.

    Args:
        smiles: SMILES string of the molecule to cut.

    Returns:
        Set of fragment SMILES. Empty when the SMILES cannot be parsed, when
        the molecule has no non-ring single bond, or when the fragments of
        every attempt fail sanitization.
    """
    # Compile the query once instead of twice per attempt.
    nonring_single_bond = Chem.MolFromSmarts('[*]-;!@[*]')  # single bond not in ring

    def cut_nonring(mol):
        matches = mol.GetSubstructMatches(nonring_single_bond)
        if not matches:
            return None

        bis = random.choice(matches)
        bs = [mol.GetBondBetweenAtoms(bis[0], bis[1]).GetIdx()]
        fragments_mol = Chem.FragmentOnBonds(mol, bs, addDummies=True, dummyLabels=[(1, 1)])

        try:
            return Chem.GetMolFrags(fragments_mol, asMols=True, sanitizeFrags=True)
        except ValueError:
            return None

    frags = set()
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Unparsable input: nothing to cut (avoids an AttributeError on None).
        return frags

    # non-ring cut
    for _ in range(3):
        frags_nonring = cut_nonring(mol)
        if frags_nonring is not None:
            frags.update(Chem.MolToSmiles(f) for f in frags_nonring)
    return frags
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class Slicer:
    """Callable that yields the matches of the non-ring single-bond query."""

    def __call__(self, mol):
        """Yield every match of ``[*]-;!@[*]`` in ``mol``.

        A string argument is first parsed with ``Chem.MolFromSmiles``.
        """
        if isinstance(mol, str):
            mol = Chem.MolFromSmiles(mol)

        # non-ring single bonds
        pattern = Chem.MolFromSmarts('[*]-;!@[*]')
        yield from mol.GetSubstructMatches(pattern)
|
a2d2_mol/oracle/fpscores.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24a4392f5c673e79c0446af3c4d8e458293b5fecaa244328e76741ead9d21dbf
|
| 3 |
+
size 9048931
|
a2d2_mol/remasking_scheduleaware.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Schedule-aware remasking and insertion logic that ensures the number of masked tokens
|
| 3 |
+
follows the interpolant schedule.
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
def apply_schedule_aware_insertion(
    model,
    xt_tmp,
    new_xt,
    t,
    dt,
    ext,
    mask,
    pad,
    max_length,
    orig_mask,
    new_pos_orig,
    quality_threshold=1,
):
    """
    Drop low-confidence freshly inserted mask tokens without undershooting
    the length implied by the insertion schedule.

    Args:
        model: Provides ``planner(xt, t)`` (a mapping that may contain
            "insertion_conf") and ``interpolant.insertion_schedule.at(t)``
        xt_tmp: Sequence after insertion [B, L]
        new_xt: Sequence before insertion [B, L] (not read here)
        t: Current time [B]
        dt: Time step size
        ext: Number of insertions per gap [B, L+1]
        mask: Mask token ID
        pad: Pad token ID
        max_length: Maximum sequence length (not read here)
        orig_mask: Mask of original token positions [B, L]
        new_pos_orig: New positions of original tokens [B, L]
        quality_threshold: Inserted masks with confidence below this value
            are candidates for removal

    Returns:
        ``xt_tmp`` itself when nothing is removed; otherwise a new tensor in
        which the removed positions are squeezed out and the tail is padded.
    """
    device = xt_tmp.device
    _, seq_len = xt_tmp.shape

    # Only proceed if there were insertions
    if ext.sum(dim=1).sum() == 0:
        return xt_tmp

    # The planner is conditioned on the pre-step time t; t_next is only used
    # for the length schedule below.
    t_next = t + dt
    insertion_conf = model.planner(xt_tmp, t).get("insertion_conf", None)
    if insertion_conf is None:
        return xt_tmp
    insertion_conf = insertion_conf.squeeze(-1)  # (B, L)

    # Length the schedule expects at the next timestep.
    cur_len = xt_tmp.ne(pad).sum(dim=1).float()  # [B]
    progress = model.interpolant.insertion_schedule.at(t_next)  # [B]
    projected_final_len = cur_len / (progress.clamp(min=0.1))
    scheduled_len = projected_final_len * progress  # [B]

    # Flag the positions of xt_tmp that hold tokens carried over from before
    # the insertion; masks anywhere else are fresh insertions.
    rows, cols = orig_mask.nonzero(as_tuple=True)
    landed = new_pos_orig[rows, cols].long().clamp_(0, seq_len - 1)
    is_original = torch.zeros_like(xt_tmp, dtype=torch.bool)
    is_original[rows, landed] = True
    inserted = (xt_tmp == mask) & ~is_original

    # Candidates are inserted masks below the threshold; the number removed
    # per row is capped so the length never drops below the scheduled minimum.
    candidates = inserted & (insertion_conf < quality_threshold)
    num_bad = candidates.sum(dim=1)  # [B], long
    cur_len_int = cur_len.long()
    min_length = scheduled_len.long().clamp(min=1)  # [B]
    max_removable = (cur_len_int - min_length).clamp(min=0)
    would_undershoot = (cur_len_int - num_bad) < min_length
    k_per_row = torch.where(would_undershoot, max_removable, num_bad)
    k_per_row = torch.where(num_bad > 0, k_per_row, torch.zeros_like(k_per_row))

    if not candidates.any():
        return xt_tmp

    # Rank candidates from lowest to highest confidence; non-candidates sink
    # to the end because their score is -inf.
    neg_inf = torch.tensor(float('-inf'), device=device, dtype=insertion_conf.dtype)
    badness = torch.where(candidates, -insertion_conf, neg_inf)  # higher = worse
    worst_first = badness.sort(dim=1, descending=True).indices
    positions = torch.arange(seq_len, device=device).unsqueeze(0)  # [1, L]
    within_budget = positions < k_per_row.unsqueeze(1)  # [B, L]
    final_bad = torch.zeros_like(candidates)
    final_bad.scatter_(1, worst_first, within_budget)

    if not final_bad.any():
        return xt_tmp

    # A stable sort on the removal flag moves removed positions to the right
    # while keeping the order of the survivors; the tail is then padded.
    perm = torch.sort(final_bad.long(), dim=1, stable=True).indices
    compacted = torch.gather(xt_tmp, 1, perm)
    num_keep = (~final_bad).sum(dim=1)  # [B]
    tail_mask = positions >= num_keep.unsqueeze(1)  # [B, L]
    return torch.where(tail_mask, torch.full_like(compacted, pad), compacted)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def apply_schedule_aware_remasking(
    model,
    new_xt,
    t,
    dt,
    remasking_conf,
    clean_index,
    mask,
    neg_inf,
    batch_size,
    unmask_quality_threshold=None,
):
    """
    Remask clean tokens, either by confidence threshold or so that the number
    of clean tokens matches the unmask schedule at the next timestep.

    Args:
        model: Model with interpolant that has an unmask_schedule
        new_xt: Current sequence [B, L]
        t: Current time [B]
        dt: Time step size
        remasking_conf: Confidence scores for tokens [B, L]
        clean_index: Boolean mask of clean tokens (not mask, not pad) [B, L]
        mask: Mask token ID
        neg_inf: Negative infinity tensor
        batch_size: Batch size (not read here)
        unmask_quality_threshold: When not None, the schedule is ignored and
            every clean token with confidence below this value is remasked

    Returns:
        The input tensor itself when the schedule asks for no remasking,
        otherwise a new tensor with the selected positions set to ``mask``
    """
    # Threshold gate: overrides the schedule-driven count when set. A higher
    # threshold remasks more tokens.
    if unmask_quality_threshold is not None:
        low_quality = clean_index & (remasking_conf < unmask_quality_threshold)
        return torch.where(low_quality, torch.full_like(new_xt, mask), new_xt)

    t_next = t + dt
    num_clean = clean_index.sum(dim=1)  # [B], long
    num_masked = (new_xt == mask).sum(dim=1)
    seq_len_now = (num_clean + num_masked).float()  # [B]
    target_clean = model.interpolant.unmask_schedule.at(t_next) * seq_len_now  # [B]
    surplus_clean = (num_clean.float() - target_clean).round().long()  # [B]

    # Per-row budget: the surplus, floored at 0 and capped at num_clean.
    k_per_row = torch.minimum(surplus_clean.clamp(min=0), num_clean)  # [B]
    if k_per_row.sum() == 0:
        return new_xt

    # Lowest confidence is remasked first; non-clean positions get neg_inf
    # and therefore sort last.
    ranking = torch.where(clean_index, -1.0 * remasking_conf, neg_inf)
    order = ranking.sort(dim=1, descending=True).indices
    width = ranking.shape[1]
    slots = torch.arange(width, device=new_xt.device).unsqueeze(0)  # [1, L]
    within_budget = slots < k_per_row.unsqueeze(1)  # [B, L]
    to_mask = torch.zeros_like(clean_index)
    to_mask.scatter_(1, order, within_budget)

    return torch.where(to_mask, torch.full_like(new_xt, mask), new_xt)
|
a2d2_mol/sampling.py
ADDED
|
@@ -0,0 +1,1401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # add repo root to path
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Literal, Optional
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
|
| 11 |
+
from lightning_modules.mdm import MaskedDiffusionModule
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class SamplingTraceDatapoint:
    """A single event recorded while sampling one sequence."""

    # Diffusion time at which the event was recorded.
    t: float
    # "insertion" when a new slot was added, "change" when a token was replaced.
    event_type: Literal["insertion", "change"]
    # Index associated with the event (token position or insertion gap).
    position: int
    # Token written by the event.
    token: Any
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class SamplingResult:
    """Samples plus an optional event trace.

    Iterating an instance yields ``samples`` and then ``trace``, so it can be
    unpacked like a 2-tuple: ``samples, trace = result``.
    """

    samples: torch.Tensor
    # Trace is supposed to be processed sequentially as updates are not commutative
    trace: Optional[list[SamplingTraceDatapoint]]

    def __iter__(self):
        """Yield ``samples`` followed by ``trace``."""
        yield from (self.samples, self.trace)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Sample from categorical distribution for each position using the transition probabilities
|
| 33 |
+
def _sample_tokens(probs: torch.Tensor) -> torch.Tensor:
    """Sample one token per position from a categorical distribution.

    Args:
        probs: [batch_size, seq_len, vocab_size] non-negative weights for each
            position. Rows are passed to ``torch.multinomial`` as-is, so they
            act as (possibly unnormalized) weights.

    Returns:
        [batch_size, seq_len] sampled token indices.
    """
    batch_size, seq_len, vocab_size = probs.shape
    # ``reshape`` instead of ``view``: ``view`` raises on non-contiguous
    # inputs (e.g. transposed or sliced tensors), whereas ``reshape`` only
    # copies when it has to and is otherwise equivalent.
    flat_probs = probs.reshape(-1, vocab_size)
    samples = torch.multinomial(flat_probs, num_samples=1)
    return samples.view(batch_size, seq_len)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _sample_batched_tokens(probs: torch.Tensor) -> torch.Tensor:
    """Sample one token per position using Gumbel noise and an argmax.

    Args:
        probs: [batch_size, seq_len, vocab_size] non-negative weights.

    Returns:
        [batch_size, seq_len] sampled token indices (``torch.long``).
    """
    batch_size, seq_len, vocab_size = probs.shape

    # Draw the uniform noise directly on the device of ``probs`` rather than
    # on the CPU followed by a host-to-device copy on every call.
    uniform = torch.rand(batch_size, seq_len, vocab_size, device=probs.device)
    # The 1e-10 offsets keep both logarithms finite when ``uniform`` is 0.
    gumbel_noise = -torch.log(-torch.log(uniform + 1e-10) + 1e-10)
    noisy_logits = torch.log(probs + 1e-10) + gumbel_noise  # add Gumbel noise to log probabilities

    # select the highest score (most likely category after Gumbel noise)
    samples = noisy_logits.argmax(dim=-1).to(dtype=torch.long)

    return samples.view(batch_size, seq_len)
|
| 57 |
+
|
| 58 |
+
@torch.no_grad()
def mdm_euler_sampling(
    model: MaskedDiffusionModule,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
):
    """Sample fixed-length sequences by Euler-stepping the unmasking rates.

    Starts from an all-``mask`` batch of shape ``(batch_size, max_length)`` and
    runs ``steps`` uniform time steps from ``t = 0``. Positions that already
    hold a non-mask token are never changed.

    Args:
        model: Callable as ``model(xt, t)``; must expose
            ``model.interpolant.to_actual_rate(xt, pred, t)`` whose result has
            an ``unmask_rate`` tensor indexable as ``[batch, position, token]``.
        steps: Number of Euler steps (``dt = 1 / steps``).
        mask: Token id of the mask token.
        pad: Unused by this sampler; kept for signature parity with the other
            samplers in this module.
        batch_size: Number of sequences to sample.
        max_length: Length of every sampled sequence.
        return_trace: Must be ``False``; tracing is not implemented here.
        temperature: When not ``1.0``, the per-position weights are re-scaled
            as ``softmax(log(p + 1e-10) / temperature)``.

    Returns:
        ``(xt, [])`` where ``xt`` is the ``(batch_size, max_length)`` int64
        tensor of sampled tokens and the second element is an empty trace.
    """
    assert not return_trace, "Trace is not yet implemented in MDM Euler sampling"
    device = next(model.parameters()).device
    xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    for i in range(steps):
        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        # NOTE(review): the clamp runs before the "stay" weight is added, so a
        # negative mask-entry is zeroed and the stay weight below ends up as 1
        # rather than 1 - sum(rate * dt). multinomial renormalizes the row.
        # Confirm this discretization is intended.
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" weight: +1 on the token currently held at each position
        stay_index = xt.unsqueeze(-1)
        trans_prob.scatter_add_(
            2,
            stay_index,
            torch.ones_like(stay_index, dtype=trans_prob.dtype),
        )

        # Apply temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if i == steps - 1:
            print("Final step, removing mask token from sampling")
            # Forbid staying masked on the last step so every position resolves.
            trans_prob[mask_pos + (mask,)] = 0.0

        new_xt = _sample_tokens(trans_prob)
        # Already-unmasked positions keep their token.
        new_xt = torch.where(xt != mask, xt, new_xt)

        xt = new_xt
        t = t + dt

    return xt, []
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@torch.no_grad()
def any_order_mask_insertion_euler_sampling(
    model: torch.nn.Module,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, Optional[list]]:
    """Sample variable-length sequences with joint unmasking and mask insertion.

    Starts from an all-``pad`` batch. Every step (a) unmasks some mask tokens
    according to ``unmask_rate`` and (b), except on the last step, inserts new
    mask tokens into the gaps of each sequence according to ``length_rate``.
    On the last step the mask token is removed from the sampling weights so
    that every remaining mask position resolves to a token.

    Args:
        model: Callable as ``model(xt, t)``; must expose
            ``model.interpolant.to_actual_rate(xt, pred, t)`` whose result has
            ``unmask_rate`` (B, L, V) and ``length_rate`` (B, L+1).
        steps: Number of Euler steps (``dt = 1 / steps``).
        mask: Token id of the mask token. The final-step fallback treats ids
            ``0 .. mask - 1`` as the valid tokens.
        pad: Token id used for padding.
        batch_size: Number of sequences to sample.
        max_length: Maximum sequence length (width of the returned tensor).
        return_trace: When true, record per-sequence change/insertion events.
        temperature: Currently unused by this sampler.

    Returns:
        ``(xt, sampling_trace)``: ``xt`` is the ``(batch_size, max_length)``
        int64 token tensor; ``sampling_trace`` is a list with one list of
        ``SamplingTraceDatapoint`` per sequence, or ``None`` when
        ``return_trace`` is false.
    """
    device = next(model.parameters()).device

    # 1) Initialize all-pad sequence and trace
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        is_last_step = i == steps - 1

        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" probability (pad positions are mapped onto the mask id so
        # the scatter index stays valid; they are restored to pad below)
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        if is_last_step:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

            # Renormalize every masked row. Rows whose remaining weights are
            # all zero cannot be normalized (and would make multinomial fail),
            # so they fall back to a uniform distribution over tokens
            # 0 .. mask-1.
            masked_probs = trans_prob[mask_pos]  # (K, V)
            prob_sum = masked_probs.sum(dim=-1, keepdim=True)  # (K, 1)
            zero_rows = prob_sum.squeeze(-1) == 0.0  # (K,)
            safe_sum = torch.where(zero_rows.unsqueeze(-1), torch.ones_like(prob_sum), prob_sum)
            masked_probs = masked_probs / safe_sum
            if zero_rows.any():
                # The fallback must be a vector over the vocabulary axis (V,),
                # not a (L, V) slab sliced along the position axis.
                uniform_prob = torch.zeros_like(trans_prob[0, 0])
                uniform_prob[:mask] = 1.0 / mask  # Uniform over tokens 0 to mask-1
                masked_probs[zero_rows] = uniform_prob
            trans_prob[mask_pos] = masked_probs

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        if not is_last_step:
            # ——— gap-wise insertion — compute new length, fill masks, scatter tokens ———
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # Only gaps inside (or right after) the current sequence may grow.
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            total_ext = ext.sum(dim=1)
            # Drop all insertions for a row if they would overflow max_length.
            valid = xt_len + total_ext <= max_length
            ext = ext * valid.view(batch_size, 1).long()

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # Each original token shifts right by the number of insertions at
            # or before its position; untouched slots remain mask.
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            # No insertion happens on the last step; make that explicit so the
            # trace below never reads a stale (or undefined) ``ext``.
            ext = None
            xt_tmp = new_xt

        if return_trace:
            changed = ((xt != pad) & (xt != new_xt)).tolist()
            new_tokens = new_xt.tolist()
            step_times = t.tolist()
            inserted = ext[:, :max_length].tolist() if ext is not None else None

            for batch_idx in range(batch_size):
                # Check if the token was changed
                for j in range(max_length):
                    if changed[batch_idx][j]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=step_times[batch_idx],
                                event_type="change",
                                position=j,
                                token=new_tokens[batch_idx][j],
                            )
                        )

                # Check if a new token was inserted (reported right-to-left)
                if inserted is not None:
                    for gap_idx in reversed(range(max_length)):
                        if inserted[batch_idx][gap_idx]:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=step_times[batch_idx],
                                    event_type="insertion",
                                    position=gap_idx,
                                    token=mask,
                                )
                            )

        xt = xt_tmp
        t = t + dt

    return xt, sampling_trace
|
| 249 |
+
|
| 250 |
+
@torch.no_grad()
def batch_mcts_reverse_step(
    xt: torch.Tensor,
    t: torch.Tensor,
    dt: float,
    model: torch.nn.Module,
    pretrained: torch.nn.Module,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    last_step: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Expand one state into ``batch_size`` children with a single reverse step.

    ``xt`` is tiled ``batch_size`` times along dim 0 and one unmask (+ mask
    insertion unless ``last_step``) step is sampled independently per copy
    under ``model``. Log-weights of the sampled transitions are computed
    under both ``model`` and ``pretrained``.

    Args:
        xt: State to expand; tiled with ``xt.repeat(batch_size, 1)``.
        t: Current time; squeezed and expanded to ``(batch_size,)``.
        dt: Step size.
        model: Policy model; callable as ``model(xt, t)`` and exposing
            ``interpolant.to_actual_rate`` with ``unmask_rate`` (B, L, V) and
            ``length_rate`` (B, L+1).
        pretrained: Reference model with the same interface as ``model``.
        mask: Token id of the mask token. The final-step fallback treats ids
            ``0 .. mask - 1`` as the valid tokens.
        pad: Token id used for padding.
        batch_size: Number of children to sample.
        max_length: Maximum sequence length.
        last_step: When true, the mask token is removed from the sampling
            weights and no insertion is performed.
        temperature: Currently unused.

    Returns:
        ``(xt_next, log_rnd, log_policy_step, log_pretrained_step)`` where
        ``xt_next`` is (B, L) and the three log terms are (B,), with
        ``log_rnd = log_pretrained_step - log_policy_step``.
    """
    device = next(model.parameters()).device

    xt = xt.repeat(batch_size, 1)

    # squeeze to remove extra dimensions, then expand to batch_size
    t = t.squeeze().expand(batch_size)

    # precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )

    # ——— predict and convert rates ———
    pred_rate = model(xt, t)
    pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
    unmask_rate = pred_rate.unmask_rate  # (B, L, V)
    len_rate = pred_rate.length_rate  # (B, L+1)

    # ——— get pretrained model rates for log_rnd computation ———
    pretrained_pred = pretrained(xt, t)
    pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
    pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
    pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

    # ——— unmask step (Euler) ———
    mask_pos = (xt == mask).nonzero(as_tuple=True)
    unmask_rate[xt != mask] = 0
    unmask_rate[mask_pos + (mask,)] = 0
    unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

    # Same for pretrained
    pretrained_unmask_rate[xt != mask] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

    # add "stay" probability (pad positions are mapped onto the mask id so the
    # scatter index stays valid; they are restored to pad after sampling)
    _xt = xt.clone()
    _xt[xt == pad] = mask
    trans_prob.scatter_add_(
        2,
        _xt.unsqueeze(-1),
        torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
    )
    pretrained_trans_prob.scatter_add_(
        2,
        _xt.unsqueeze(-1),
        torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
    )

    if last_step:
        print("Final step, removing mask token from sampling")
        trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

        # Renormalize every masked row. Rows whose remaining weights are all
        # zero cannot be normalized (and would make multinomial fail), so they
        # fall back to a uniform distribution over tokens 0 .. mask-1.
        # NOTE(review): only the policy weights are adjusted here; the
        # pretrained weights used for lp_pre are left as-is. Confirm intended.
        masked_probs = trans_prob[mask_pos]  # (K, V)
        prob_sum = masked_probs.sum(dim=-1, keepdim=True)  # (K, 1)
        zero_rows = prob_sum.squeeze(-1) == 0.0  # (K,)
        safe_sum = torch.where(zero_rows.unsqueeze(-1), torch.ones_like(prob_sum), prob_sum)
        masked_probs = masked_probs / safe_sum
        if zero_rows.any():
            # The fallback must be a vector over the vocabulary axis (V,),
            # not a (L, V) slab sliced along the position axis.
            uniform_prob = torch.zeros_like(trans_prob[0, 0])
            uniform_prob[:mask] = 1.0 / mask  # Uniform over tokens 0 to mask-1
            masked_probs[zero_rows] = uniform_prob
        trans_prob[mask_pos] = masked_probs

    new_xt = _sample_tokens(trans_prob)
    new_xt[xt == pad] = pad
    new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

    # ——— compute log probabilities for RND ———
    lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
    lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

    # Only positions that were unmasked in this step contribute.
    changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

    log_policy_step = (lp * changed_mask).sum(dim=1)
    log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

    log_rnd = log_pretrained_step - log_policy_step  # (B,)

    if not last_step:
        # ——— gap-wise insertion — compute new length, fill masks, scatter tokens ———
        ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)

        insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
        pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

        # log P(ext; λ) = ext*log(λ) - λ
        log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1)  # (B,)
        log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)  # (B,)

        log_insert_diff = log_pretrained_insert - log_policy_insert  # (B,)
        log_rnd += log_insert_diff
        log_pretrained_step += log_pretrained_insert
        log_policy_step += log_policy_insert

        xt_len = xt.ne(pad).sum(dim=1)  # (B,)
        seq_dim = ext.size(1)  # Use actual ext dimension to avoid mismatch
        gaps = torch.arange(seq_dim, device=device).view(1, -1)
        # Only gaps inside (or right after) the current sequence may grow.
        ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
        total_ext = ext.sum(dim=1)
        # Drop all insertions for a row if they would overflow max_length.
        valid = xt_len + total_ext <= max_length
        ext = ext * valid.view(batch_size, 1).long()

        ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
        new_len = xt_len + total_ext  # (B,)

        xt_tmp = torch.full_like(xt, pad)
        mask_fill = pos_idx_L < new_len.view(batch_size, 1)
        xt_tmp[mask_fill] = mask

        # Each original token shifts right by the number of insertions at or
        # before its position; untouched slots remain mask.
        new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
        orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
        flat_b = batch_idx_L[orig_mask]
        flat_p = new_pos_orig[orig_mask]
        xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
    else:
        xt_tmp = new_xt

    return xt_tmp, log_rnd, log_policy_step, log_pretrained_step
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
@torch.no_grad()
def mcts_reverse_step(
    xt: torch.Tensor,
    t: torch.Tensor,
    dt: float,
    model: torch.nn.Module,
    pretrained: torch.nn.Module,
    mask: int,
    pad: int,
    max_length: int,
    last_step: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample one reverse step for a batch of states and score it.

    Performs one unmask (+ mask insertion unless ``last_step``) step under
    ``model`` for every row of ``xt`` and computes log-weights of the sampled
    transitions under both ``model`` and ``pretrained``.

    Args:
        xt: (B, L) current token tensor; the batch size is ``xt.size(0)``.
        t: Current time, passed unchanged to both models.
        dt: Step size.
        model: Policy model; callable as ``model(xt, t)`` and exposing
            ``interpolant.to_actual_rate`` with ``unmask_rate`` (B, L, V) and
            ``length_rate`` (B, L+1).
        pretrained: Reference model with the same interface as ``model``.
        mask: Token id of the mask token. The final-step fallback treats ids
            ``0 .. mask - 1`` as the valid tokens.
        pad: Token id used for padding.
        max_length: Maximum sequence length.
        last_step: When true, the mask token is removed from the sampling
            weights and no insertion is performed.
        temperature: Currently unused.

    Returns:
        ``(xt_next, log_rnd, log_policy_step, log_pretrained_step)`` where
        ``xt_next`` is (B, L) and the three log terms are (B,), with
        ``log_rnd = log_pretrained_step - log_policy_step``.
    """
    device = next(model.parameters()).device

    batch_size = xt.size(0)

    # precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )

    # ——— predict and convert rates ———
    pred_rate = model(xt, t)
    pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
    unmask_rate = pred_rate.unmask_rate  # (B, L, V)
    len_rate = pred_rate.length_rate  # (B, L+1)

    # ——— get pretrained model rates for log_rnd computation ———
    pretrained_pred = pretrained(xt, t)
    pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
    pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
    pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

    # ——— unmask step (Euler) ———
    mask_pos = (xt == mask).nonzero(as_tuple=True)
    unmask_rate[xt != mask] = 0
    unmask_rate[mask_pos + (mask,)] = 0
    unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

    # same for pretrained
    pretrained_unmask_rate[xt != mask] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

    # add "stay" probability (pad positions are mapped onto the mask id so the
    # scatter index stays valid; they are restored to pad after sampling)
    _xt = xt.clone()
    _xt[xt == pad] = mask
    trans_prob.scatter_add_(
        2,
        _xt.unsqueeze(-1),
        torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
    )
    pretrained_trans_prob.scatter_add_(
        2,
        _xt.unsqueeze(-1),
        torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
    )

    if last_step:
        print("Final step, removing mask token from sampling")
        trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

        # Renormalize every masked row. Rows whose remaining weights are all
        # zero cannot be normalized (and would make multinomial fail), so they
        # fall back to a uniform distribution over tokens 0 .. mask-1.
        # NOTE(review): only the policy weights are adjusted here; the
        # pretrained weights used for lp_pre are left as-is. Confirm intended.
        masked_probs = trans_prob[mask_pos]  # (K, V)
        prob_sum = masked_probs.sum(dim=-1, keepdim=True)  # (K, 1)
        zero_rows = prob_sum.squeeze(-1) == 0.0  # (K,)
        safe_sum = torch.where(zero_rows.unsqueeze(-1), torch.ones_like(prob_sum), prob_sum)
        masked_probs = masked_probs / safe_sum
        if zero_rows.any():
            # The fallback must be a vector over the vocabulary axis (V,),
            # not a (L, V) slab sliced along the position axis.
            uniform_prob = torch.zeros_like(trans_prob[0, 0])
            uniform_prob[:mask] = 1.0 / mask  # Uniform over tokens 0 to mask-1
            masked_probs[zero_rows] = uniform_prob
        trans_prob[mask_pos] = masked_probs

    new_xt = _sample_tokens(trans_prob)
    new_xt[xt == pad] = pad
    new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

    # ——— compute log probabilities for RND ———
    lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
    lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

    # Only positions that were unmasked in this step contribute.
    changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

    log_policy_step = (lp * changed_mask).sum(dim=1)
    log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

    log_rnd = log_pretrained_step - log_policy_step  # (B,)

    if not last_step:
        # ——— gap-wise insertion — compute new length, fill masks, scatter tokens ———
        ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)

        insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
        pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

        # log P(ext; λ) = ext*log(λ) - λ
        log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1)  # (B,)
        log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)  # (B,)

        log_insert_diff = log_pretrained_insert - log_policy_insert  # (B,)
        log_rnd += log_insert_diff
        log_pretrained_step += log_pretrained_insert
        log_policy_step += log_policy_insert

        xt_len = xt.ne(pad).sum(dim=1)  # (B,)
        seq_dim = ext.size(1)  # Use actual ext dimension to avoid mismatch
        gaps = torch.arange(seq_dim, device=device).view(1, -1)
        # Only gaps inside (or right after) the current sequence may grow.
        ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
        total_ext = ext.sum(dim=1)
        # Drop all insertions for a row if they would overflow max_length.
        valid = xt_len + total_ext <= max_length
        ext = ext * valid.view(batch_size, 1).long()

        ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
        new_len = xt_len + total_ext  # (B,)

        xt_tmp = torch.full_like(xt, pad)
        mask_fill = pos_idx_L < new_len.view(batch_size, 1)
        xt_tmp[mask_fill] = mask

        # Each original token shifts right by the number of insertions at or
        # before its position; untouched slots remain mask.
        new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
        orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
        flat_b = batch_idx_L[orig_mask]
        flat_p = new_pos_orig[orig_mask]
        xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
    else:
        xt_tmp = new_xt

    return xt_tmp, log_rnd, log_policy_step, log_pretrained_step
|
| 535 |
+
|
| 536 |
+
@torch.no_grad()
def any_order_euler_sampling_with_schedule(
    model: torch.nn.Module,
    time_schedule: torch.Tensor,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
) -> SamplingResult:
    """Euler sampler for the unmask/insert process on an explicit time grid.

    Starts from an all-``pad`` batch and, for every consecutive pair of
    entries in ``time_schedule``, (1) resamples the tokens at mask positions
    from the Euler transition probabilities and (2) on every step except the
    last, inserts new mask tokens into the gaps of each sequence.

    Args:
        model: Called as ``model(xt, t)``; must expose
            ``model.interpolant.to_actual_rate(xt, pred, t)`` returning an
            object with ``unmask_rate`` (B, L, V) and ``length_rate``
            (B, L+1).
        time_schedule: 1-D tensor of time points. If it is ascending it is
            flipped so that it is iterated in descending order.
        mask: Token id of the mask token.
        pad: Token id of the padding token.
        batch_size: Number of sequences sampled in parallel.
        max_length: Width of the token buffer; insertions that would exceed
            it are rejected.
        return_trace: If true, a per-sequence list of
            ``SamplingTraceDatapoint`` events is built and returned.
        temperature: Values other than 1.0 re-scale the transition
            probabilities through a tempered softmax.

    Returns:
        ``(xt, sampling_trace)`` where ``xt`` is the (B, L) token tensor and
        ``sampling_trace`` is ``None`` when ``return_trace`` is false.
    """
    device = next(model.parameters()).device

    time_schedule = time_schedule.to(device)
    if time_schedule[0] < time_schedule[-1]:
        time_schedule = torch.flip(time_schedule, [0])  # descending order

    steps = len(time_schedule) - 1

    # initialize all-pad sequence and trace
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    # precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        is_last = i == steps - 1

        # use scheduled timesteps
        t = time_schedule[i].repeat(batch_size)
        t_next = time_schedule[i + 1]
        dt = (t - t_next).abs()  # timestep difference, shape (B,)

        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        # diagonal entry = minus the total outgoing rate
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt[:, None, None]).clamp(0.0, 1.0)

        # add "stay" probability (pad positions are routed to the mask slot;
        # they are overwritten with pad again after sampling)
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        # Apply temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if is_last:
            # NOTE(review): the original indentation was not recoverable; the
            # renormalisation is kept inside the final-step branch because a
            # zero row can only arise once the mask entry has been removed.
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

            masked_probs = trans_prob[mask_pos]  # (K, V)
            prob_sum = masked_probs.sum(dim=-1, keepdim=True)  # (K, 1)
            dead = prob_sum == 0.0

            # Fallback for rows with no mass: uniform over token ids below
            # ``mask``, excluding ``pad``. Built along the vocabulary axis.
            fallback = torch.zeros(
                trans_prob.size(-1), dtype=trans_prob.dtype, device=trans_prob.device
            )
            fallback[:mask] = 1.0
            if 0 <= pad < mask:
                fallback[pad] = 0.0
            fallback = fallback / fallback.sum().clamp(min=1.0)

            # Every row is handled individually: healthy rows are normalised,
            # dead rows get the fallback.
            normalised = masked_probs / prob_sum.masked_fill(dead, 1.0)
            trans_prob[mask_pos] = torch.where(dead, fallback.unsqueeze(0), normalised)

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        if not is_last:
            # ——— gap-wise insertion — compute new length, fill masks, scatter tokens ———
            ext = torch.bernoulli((len_rate * dt[:, None]).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            total_ext = ext.sum(dim=1)
            valid = xt_len + total_ext <= max_length
            ext = ext * valid.view(batch_size, 1).long()

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            # NOTE(review): total_ext was computed before rejected rows were
            # zeroed, so an overflowing row is filled with masks up to
            # max_length instead of keeping its length — confirm intent.
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            # no insertion on the final step; make that explicit for the trace
            ext = None
            xt_tmp = new_xt

        if return_trace:
            # Pull everything to Python lists once instead of one device sync
            # per (sequence, position).
            changed = ((xt != pad) & (xt != new_xt)).tolist()
            sampled = new_xt.tolist()
            times = t.tolist()
            inserted = ext[:, :max_length].tolist() if ext is not None else None

            for batch_idx in range(batch_size):
                # Check if the token was changed
                for j in range(max_length):
                    if changed[batch_idx][j]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=times[batch_idx],
                                event_type="change",
                                position=j,
                                token=sampled[batch_idx][j],
                            )
                        )

                # Check if a new token was inserted (right-most gap first)
                if inserted is not None:
                    for gap in range(max_length - 1, -1, -1):
                        if inserted[batch_idx][gap]:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=times[batch_idx],
                                    event_type="insertion",
                                    position=gap,
                                    token=mask,
                                )
                            )

        xt = xt_tmp

    return xt, sampling_trace
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
@torch.no_grad()
def any_order_mask_insertion_euler_sampling_with_rnd(
    model, pretrained, reward_model, analyzer,
    tokenizer, steps,
    mask,
    pad,
    batch_size,
    max_length,
    return_trace = False,
    alpha = 0.1,
    temperature: float = 1.0,
):
    """Sample with the policy and accumulate the log ratio against ``pretrained``.

    Runs ``steps`` Euler steps of the unmask/insert process from an all-pad
    batch using ``model``. At every step the log-probability of the sampled
    unmask events (and, except on the last step, of the sampled insertions)
    is evaluated under both ``model`` and ``pretrained``; the difference
    ``log p_pretrained - log p_policy`` is summed over steps. Decoded
    sequences rejected by ``analyzer.is_peptide`` are dropped, and the reward
    of the remaining ones, divided by ``alpha``, is added to the log ratio.

    Args:
        model: Policy; called as ``model(xt, t)`` and must expose
            ``model.interpolant.to_actual_rate``.
        pretrained: Reference model with the same calling convention.
        reward_model: Called as ``reward_model(input_seqs=list_of_str)``; the
            result is summed over its last axis with ``np.sum``.
        analyzer: Object providing ``is_peptide(seq) -> bool``.
        tokenizer: Object providing ``batch_decode(token_tensor)``.
        steps: Number of Euler steps (``dt = 1 / steps``).
        mask: Token id of the mask token.
        pad: Token id of the padding token.
        batch_size: Number of sequences sampled in parallel.
        max_length: Width of the token buffer.
        return_trace: If true, per-sequence event lists are recorded.
        alpha: Divisor applied to the scalar reward.
        temperature: Values other than 1.0 temper the policy transition
            probabilities.

    Returns:
        ``(valid_x_final, log_rnd, scalar_rewards, sampling_trace)``; the
        first three only contain the sequences accepted by the analyzer.

    Raises:
        RuntimeError: If no sampled sequence is accepted by the analyzer.
    """
    device = next(model.parameters()).device

    # initialize all-pad sequence
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    # running sum of log p_pretrained - log p_policy over the trajectory
    log_rnd = torch.zeros(batch_size, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        is_last = i == steps - 1

        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # ——— get pretrained model rates for log_rnd computation ———
        pretrained_pred = pretrained(xt, t)
        pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
        pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
        pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # Same for pretrained
        pretrained_unmask_rate[xt != mask] = 0
        pretrained_unmask_rate[mask_pos + (mask,)] = 0
        pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" probability
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )
        pretrained_trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
        )

        # Apply temperature scaling (policy only, as in the original)
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if is_last:
            # NOTE(review): the original indentation was not recoverable; the
            # renormalisation is kept inside the final-step branch because a
            # zero row can only arise once the mask entry has been removed.
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

            # renormalize probabilities to ensure they sum to 1
            masked_probs = trans_prob[mask_pos]  # (K, V)
            prob_sum = masked_probs.sum(dim=-1, keepdim=True)  # (K, 1)
            dead = prob_sum == 0.0

            # avoid division by zero: rows without any mass fall back to a
            # uniform distribution over token ids below ``mask`` (excluding
            # pad), built along the vocabulary axis.
            fallback = torch.zeros(
                trans_prob.size(-1), dtype=trans_prob.dtype, device=trans_prob.device
            )
            fallback[:mask] = 1.0
            if 0 <= pad < mask:
                fallback[pad] = 0.0
            fallback = fallback / fallback.sum().clamp(min=1.0)

            normalised = masked_probs / prob_sum.masked_fill(dead, 1.0)
            trans_prob[mask_pos] = torch.where(dead, fallback.unsqueeze(0), normalised)

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # ——— compute log probabilities for RND ———
        lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
        lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

        # only positions that were actually unmasked in this step contribute
        changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

        log_policy_step = (lp * changed_mask).sum(dim=1)
        log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

        # accumulate over steps (the original overwrote the running sum here,
        # so only the final step contributed)
        log_rnd = log_rnd + (log_pretrained_step - log_policy_step)  # (B,)

        if not is_last:
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)

            insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
            pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

            # log P(ext; λ) = ext*log(λ) - λ, evaluated on the raw draw
            # (before gap/length masking), as in the original
            log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1)  # (B,)
            log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)  # (B,)

            log_insert_diff = log_pretrained_insert - log_policy_insert  # (B,)
            log_rnd = log_rnd + log_insert_diff

            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            total_ext = ext.sum(dim=1)
            valid = xt_len + total_ext <= max_length
            ext = ext * valid.view(batch_size, 1).long()

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            # NOTE(review): total_ext predates the rejection of overflowing
            # rows, so such rows are filled with masks up to max_length —
            # confirm intent.
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            # no insertion on the final step; make that explicit for the trace
            ext = None
            xt_tmp = new_xt

        if return_trace:
            changed = ((xt != pad) & (xt != new_xt)).tolist()
            sampled = new_xt.tolist()
            times = t.tolist()
            inserted = ext[:, :max_length].tolist() if ext is not None else None

            # a dedicated row index is used so the step counter ``i`` is not
            # shadowed
            for row in range(batch_size):
                # check if the token was changed
                for j in range(max_length):
                    if changed[row][j]:
                        sampling_trace[row].append(
                            SamplingTraceDatapoint(
                                t=times[row],
                                event_type="change",
                                position=j,
                                token=sampled[row][j],
                            )
                        )

                # check if a new token was inserted (right-most gap first)
                if inserted is not None:
                    for gap in range(max_length - 1, -1, -1):
                        if inserted[row][gap]:
                            sampling_trace[row].append(
                                SamplingTraceDatapoint(
                                    t=times[row],
                                    event_type="insertion",
                                    position=gap,
                                    token=mask,
                                )
                            )

        xt = xt_tmp
        t = t + dt

    # ——— score the final samples ———
    samples = xt.to(device)

    # Decode samples and keep only those the analyzer accepts
    decoded_samples = tokenizer.batch_decode(samples)

    valid_x_final = []
    validSequences = []
    valid_log_rnd = []

    for idx, seq in enumerate(decoded_samples):
        # check if the peptide is valid
        if analyzer.is_peptide(seq):
            valid_x_final.append(xt[idx])
            validSequences.append(seq)
            valid_log_rnd.append(log_rnd[idx])

    print("len valid sequences:", len(validSequences))
    if not validSequences:
        # torch.stack on an empty list would fail further down with a less
        # helpful message
        raise RuntimeError("no valid peptide sequences were sampled; cannot compute rewards")

    # compute multi-objective rewards and collapse them to one scalar each
    score_vectors = reward_model(input_seqs=validSequences)
    scalar_rewards = np.sum(score_vectors, axis=-1)
    scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device)

    print(f"scalar reward dim{len(scalar_rewards)}")
    valid_log_rnd = torch.stack(valid_log_rnd, dim=0)

    log_rnd = valid_log_rnd + (scalar_rewards / alpha)  # reward divided by alpha
    valid_x_final = torch.stack(valid_x_final, dim=0)

    return valid_x_final, log_rnd, scalar_rewards, sampling_trace
|
| 887 |
+
|
| 888 |
+
@torch.no_grad()
def any_order_finetuned_euler_sampler(
    model, reward_model, analyzer,
    tokenizer, steps,
    mask,
    pad,
    batch_size,
    max_length,
    return_trace = False,
    dataframe = False,
    temperature: float = 1.0,
):
    """Sample peptides with ``model`` and score the valid ones.

    Runs ``steps`` Euler steps of the unmask/insert process from an all-pad
    batch, decodes the result with ``tokenizer`` and evaluates
    ``reward_model`` on the sequences accepted by ``analyzer.is_peptide``.

    Args:
        model: Called as ``model(xt, t)``; must expose
            ``model.interpolant.to_actual_rate``.
        reward_model: Called as ``reward_model(input_seqs=list_of_str)``; its
            result is transposed and rows 0..4 are read as binding affinity,
            solubility, hemolysis, nonfouling and permeability.
        analyzer: Object providing ``is_peptide(seq) -> bool``.
        tokenizer: Object providing ``batch_decode(token_tensor)``.
        steps: Number of Euler steps (``dt = 1 / steps``).
        mask: Token id of the mask token.
        pad: Token id of the padding token.
        batch_size: Number of sequences sampled in parallel.
        max_length: Width of the token buffer.
        return_trace: If true, per-sequence event lists are recorded while
            sampling. NOTE(review): the trace is not part of the return
            value.
        dataframe: If true, a ``pandas.DataFrame`` with one row per valid
            sequence is appended to the returned tuple.
        temperature: Values other than 1.0 temper the transition
            probabilities.

    Returns:
        ``(samples, affinity, sol, hemo, nf, permeability, valid_fraction)``
        and, when ``dataframe`` is true, the data frame as an eighth element.
        ``samples`` holds every sampled sequence, valid or not. When no
        sequence is valid each metric is ``[0.0]``.
    """
    device = next(model.parameters()).device

    # initialize all-pad sequence
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        is_last = i == steps - 1

        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" probability
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        # Apply temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if is_last:
            # NOTE(review): the original indentation was not recoverable; the
            # renormalisation is kept inside the final-step branch because a
            # zero row can only arise once the mask entry has been removed.
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

            # renormalize probabilities to ensure they sum to 1
            masked_probs = trans_prob[mask_pos]  # (K, V)
            prob_sum = masked_probs.sum(dim=-1, keepdim=True)  # (K, 1)
            dead = prob_sum == 0.0

            # avoid division by zero: rows without any mass fall back to a
            # uniform distribution over token ids below ``mask`` (excluding
            # pad), built along the vocabulary axis.
            fallback = torch.zeros(
                trans_prob.size(-1), dtype=trans_prob.dtype, device=trans_prob.device
            )
            fallback[:mask] = 1.0
            if 0 <= pad < mask:
                fallback[pad] = 0.0
            fallback = fallback / fallback.sum().clamp(min=1.0)

            normalised = masked_probs / prob_sum.masked_fill(dead, 1.0)
            trans_prob[mask_pos] = torch.where(dead, fallback.unsqueeze(0), normalised)

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        if not is_last:
            # gap-wise insertion — compute new length, fill masks, scatter tokens
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            total_ext = ext.sum(dim=1)
            valid = xt_len + total_ext <= max_length
            ext = ext * valid.view(batch_size, 1).long()

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            # NOTE(review): total_ext predates the rejection of overflowing
            # rows, so such rows are filled with masks up to max_length —
            # confirm intent.
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            # no insertion on the final step; make that explicit for the trace
            ext = None
            xt_tmp = new_xt

        if return_trace:
            changed = ((xt != pad) & (xt != new_xt)).tolist()
            sampled = new_xt.tolist()
            times = t.tolist()
            inserted = ext[:, :max_length].tolist() if ext is not None else None

            for batch_idx in range(batch_size):
                # check if the token was changed
                for j in range(max_length):
                    if changed[batch_idx][j]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=times[batch_idx],
                                event_type="change",
                                position=j,
                                token=sampled[batch_idx][j],
                            )
                        )

                # check if a new token was inserted (right-most gap first)
                if inserted is not None:
                    for gap in range(max_length - 1, -1, -1):
                        if inserted[batch_idx][gap]:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=times[batch_idx],
                                    event_type="insertion",
                                    position=gap,
                                    token=mask,
                                )
                            )

        xt = xt_tmp
        t = t + dt

    # start eval
    samples = xt.to(device)

    decoded_samples = tokenizer.batch_decode(samples)
    validSequences = [seq for seq in decoded_samples if analyzer.is_peptide(seq)]

    print("len valid sequences:", len(validSequences))
    valid_fraction = len(validSequences) / batch_size

    if validSequences:
        score_vectors = reward_model(input_seqs=validSequences)  # (num_children, num_objectives)
        average_scores = score_vectors.T

        affinity = average_scores[0]
        sol = average_scores[1]
        hemo = average_scores[2]
        nf = average_scores[3]
        permeability = average_scores[4]
        table_columns = (affinity, sol, hemo, nf, permeability)
    else:
        # independent placeholder lists (the original shared one list object)
        affinity, sol, hemo, nf, permeability = [0.0], [0.0], [0.0], [0.0], [0.0]
        # The table must have as many score rows as sequence rows (zero);
        # mixing a 0-length column with 1-length columns makes pandas raise.
        table_columns = ([], [], [], [], [])

    if dataframe:
        df = pd.DataFrame({
            "Peptide Sequence": validSequences,
            "Binding Affinity": table_columns[0],
            "Solubility": table_columns[1],
            "Hemolysis": table_columns[2],
            "Nonfouling": table_columns[3],
            "Permeability": table_columns[4],
        })
        return samples, affinity, sol, hemo, nf, permeability, valid_fraction, df

    return samples, affinity, sol, hemo, nf, permeability, valid_fraction
|
| 1074 |
+
|
| 1075 |
+
@torch.no_grad()
def mdm_tau_leaping_sampling(
    model: MaskedDiffusionModule,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
):
    """Tau-leaping sampler for a fixed-length masked diffusion model.

    Starts from a (batch_size, max_length) tensor filled with ``mask``. On
    each of the first ``steps - 1`` steps, Poisson event counts are drawn
    from ``unmask_rate * dt`` and a mask position is unmasked only when
    exactly one event fired for it. On the last step every remaining mask
    position is set to the highest-rate token.

    Args:
        model: Called as ``model(xt, t)``; must expose
            ``model.interpolant.to_actual_rate`` whose result has an
            ``unmask_rate`` of shape (B, L, V).
        steps: Number of steps (``dt = 1 / steps``).
        mask: Token id of the mask token.
        pad: Accepted for signature parity with the other samplers; not read
            here.
        batch_size: Number of sequences.
        max_length: Sequence length.
        return_trace: Must be false; tracing is not implemented.
        temperature: Accepted for signature parity; not read here.

    Returns:
        ``(xt, [])`` — the sampled tokens and an empty trace.

    Raises:
        AssertionError: If ``return_trace`` is true.
    """
    assert not return_trace, "Trace is not yet supported"
    device = next(model.parameters()).device
    xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    for i in range(steps):
        # ——— predict and convert rates ———
        pred = model(xt, t)
        pred = model.interpolant.to_actual_rate(xt, pred, t)
        unmask_rate = pred.unmask_rate  # (B, L, V)
        mask_pos = xt == mask  # (B, L)

        if i == steps - 1:
            # last step: deterministic unmask via argmax. The mask entry is
            # excluded so a position can never be "unmasked" to mask.
            final_rate = unmask_rate.clone()
            final_rate[..., mask] = float("-inf")
            new_token = final_rate.argmax(dim=2)  # (B, L)
            xt = torch.where(mask_pos, new_token, xt)
            t = t + dt
            continue

        # tau-leaping via Poisson counts; the rate is floored at zero because
        # a Poisson rate must be non-negative
        counts = torch.poisson((unmask_rate * dt).clamp(min=0.0)).long()
        # zero out non-mask positions and mask→mask
        counts[~mask_pos] = 0
        counts[..., mask] = 0
        # only accept exactly one event
        one_event = counts.sum(dim=2) == 1  # (B, L)
        new_token = counts.argmax(dim=2)  # (B, L)

        # already-unmasked tokens are kept; masks with != 1 event stay masked
        xt = torch.where(mask_pos & one_event, new_token, xt)
        t = t + dt

    return xt, []
|
| 1128 |
+
|
| 1129 |
+
# Not used in production, for debugging purposes
|
| 1130 |
+
# Discrete distribution over sequence lengths, mapping length -> probability.
# Read by calculate_rate / calculate_rate_batch below.
lengths = {4: 0.1, 16: 0.4, 32: 0.4, 64: 0.1}
|
| 1131 |
+
|
| 1132 |
+
def binomial_mass(k, n, p):
    """
    Calculate the probability mass function (PMF) for a binomial distribution.

    Args:
        k (int): Number of successes
        n (int): Number of trials
        p (float): Probability of success in a single trial

    Returns:
        float: Probability mass P(X = k); 0.0 when ``k`` lies outside
        ``[0, n]`` (which includes any negative ``n``).
    """
    import math

    # Outside the support the mass is zero. Checking the range up front
    # replaces the previous reliance on math.factorial raising ValueError.
    if k < 0 or k > n:
        return 0.0

    # math.comb computes the exact integer binomial coefficient; the float()
    # keeps the return type a float even for integer p.
    return float(math.comb(n, k)) * (p ** k) * ((1 - p) ** (n - k))
|
| 1155 |
+
|
| 1156 |
+
def calculate_rate_batch(alpha_t, len_t, length_dist=None):
    """
    Calculate rate for a batch of alpha_t and len_t values.

    For each batch element the result is the average of
    ``length - len_t`` over the candidate lengths, weighted by
    ``probability * Binomial(length, alpha_t).pmf(len_t)``. Only lengths
    with ``0 <= len_t <= length`` take part.

    Args:
        alpha_t (torch.Tensor): Tensor of shape (batch_size,)
        len_t (torch.Tensor): Tensor of shape (batch_size,)
        length_dist (dict | None): Mapping ``length -> probability``.
            Defaults to the module-level ``lengths`` table.

    Returns:
        torch.Tensor: Tensor of shape (batch_size,) in the default float
        dtype containing calculated rates; 0 where no candidate length
        contributes any mass.
    """
    if length_dist is None:
        length_dist = lengths

    batch_size = alpha_t.shape[0]
    device = alpha_t.device

    # Initialize tensors for numerator and denominator
    nom = torch.zeros(batch_size, device=device)
    denom = torch.zeros(batch_size, device=device)

    for length, probability in length_dist.items():
        # Entries for which this length is reachable: 0 <= len_t <= length
        valid_mask = (len_t <= length) & (len_t >= 0)

        if not valid_mask.any():
            continue

        valid_indices = valid_mask.nonzero(as_tuple=True)[0]
        valid_len_t = len_t[valid_indices]
        valid_alpha_t = alpha_t[valid_indices]

        # Binomial pmf of observing len_t tokens out of `length`
        binom_dist = torch.distributions.Binomial(total_count=length, probs=valid_alpha_t)
        binom_probs = binom_dist.log_prob(valid_len_t).exp()

        # Cast to the accumulator dtype so inputs that are not float32 do not
        # trip an indexed-assignment dtype mismatch.
        nom[valid_indices] += ((length - valid_len_t) * probability * binom_probs).to(nom.dtype)
        denom[valid_indices] += (probability * binom_probs).to(denom.dtype)

    # Divide only where the denominator is positive; the rest stays 0
    result = torch.zeros_like(nom)
    div_mask = denom > 0
    result[div_mask] = nom[div_mask] / (denom[div_mask])

    return result
|
| 1199 |
+
|
| 1200 |
+
# Keep the original function for backward compatibility
|
| 1201 |
+
def calculate_rate(alpha_t, len_t):
    """Legacy scalar version of calculate_rate.

    Tensor inputs with at least one dimension are forwarded to
    ``calculate_rate_batch``. Otherwise the rate is computed in plain Python
    as the average of ``length - len_t`` over the entries of the module-level
    ``lengths`` table with ``length >= len_t``, weighted by
    ``probability * binomial_mass(len_t, length, alpha_t)``.

    Args:
        alpha_t: Success probability (scalar or batched tensor).
        len_t: Current length (scalar, or tensor matching ``alpha_t``).

    Returns:
        The rate, or 0.0 when no length carries any weight.
    """
    if isinstance(alpha_t, torch.Tensor) and alpha_t.ndim > 0:
        return calculate_rate_batch(alpha_t, len_t)

    nom, denom = 0, 0
    for length, probability in lengths.items():
        if length >= len_t:
            # evaluate the pmf once and reuse it for both accumulators
            pmf = binomial_mass(len_t, length, alpha_t)
            nom += (length - len_t) * probability * pmf
            denom += probability * pmf

    if denom == 0:
        return 0.0

    return nom / denom
|
| 1216 |
+
|
| 1217 |
+
|
| 1218 |
+
@torch.no_grad()
|
| 1219 |
+
def any_order_mask_insertion_tau_leaping_sampling(
|
| 1220 |
+
model: torch.nn.Module,
|
| 1221 |
+
steps: int,
|
| 1222 |
+
mask: int,
|
| 1223 |
+
pad: int,
|
| 1224 |
+
batch_size: int,
|
| 1225 |
+
max_length: int,
|
| 1226 |
+
return_trace: bool = False,
|
| 1227 |
+
confidence_based_sampling: bool = True, # whether to use confidence-based decoding
|
| 1228 |
+
alpha: float = 5.0, # hyperparameter for window size calculation
|
| 1229 |
+
max_window: int = 32, # Maximum window size for sliding window
|
| 1230 |
+
confidence_method: str = "prob_diff", # "position", "top_prob", "prob_diff", "entropy"
|
| 1231 |
+
use_sliding_window: bool = False, # whether to use sliding window for position selection
|
| 1232 |
+
temperature: float = 1.0,
|
| 1233 |
+
) -> SamplingResult:
|
| 1234 |
+
|
| 1235 |
+
device = next(model.parameters()).device
|
| 1236 |
+
xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
|
| 1237 |
+
sampling_trace = []
|
| 1238 |
+
dt = 1.0 / steps
|
| 1239 |
+
t = torch.zeros(batch_size, device=device)
|
| 1240 |
+
|
| 1241 |
+
# Precompute row indices for scatter
|
| 1242 |
+
batch_idx_L = (
|
| 1243 |
+
torch.arange(batch_size, device=device)
|
| 1244 |
+
.view(batch_size, 1)
|
| 1245 |
+
.expand(batch_size, max_length)
|
| 1246 |
+
)
|
| 1247 |
+
pos_idx_L = (
|
| 1248 |
+
torch.arange(max_length, device=device)
|
| 1249 |
+
.view(1, max_length)
|
| 1250 |
+
.expand(batch_size, max_length)
|
| 1251 |
+
)
|
| 1252 |
+
|
| 1253 |
+
for i in range(steps):
|
| 1254 |
+
# --- predict rates ---
|
| 1255 |
+
pred = model(xt, t)
|
| 1256 |
+
xt_len = (xt != pad).sum(dim=1)
|
| 1257 |
+
pred = model.interpolant.to_actual_rate(xt, pred, t)
|
| 1258 |
+
unmask_rate = pred.unmask_rate # (B, L, V)
|
| 1259 |
+
len_rate = pred.length_rate # (B, L+1)
|
| 1260 |
+
|
| 1261 |
+
if i == steps - 1:
|
| 1262 |
+
# last step: deterministic unmask via argmax
|
| 1263 |
+
mask_pos = xt == mask
|
| 1264 |
+
new_token = unmask_rate.argmax(dim=2)
|
| 1265 |
+
new_xt = xt.clone()
|
| 1266 |
+
new_xt[mask_pos] = new_token[mask_pos]
|
| 1267 |
+
new_xt = torch.where(xt == pad, pad, new_xt)
|
| 1268 |
+
new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
|
| 1269 |
+
xt = new_xt
|
| 1270 |
+
t = t + dt
|
| 1271 |
+
continue
|
| 1272 |
+
|
| 1273 |
+
# --- confidence-based decoding ---
|
| 1274 |
+
if confidence_based_sampling > 0.0:
|
| 1275 |
+
# Confidence-based unmasking (vectorized)
|
| 1276 |
+
mask_positions = (xt == mask) # (B, L)
|
| 1277 |
+
num_mask_positions = mask_positions.sum(dim=1) # (B,)
|
| 1278 |
+
|
| 1279 |
+
# 1. Determine number of tokens to unmask using Poisson
|
| 1280 |
+
unmask_counts = torch.poisson(num_mask_positions.float() * dt).long() # (B,)
|
| 1281 |
+
|
| 1282 |
+
# 2. Calculate confidence based on selected method
|
| 1283 |
+
if confidence_method == "position":
|
| 1284 |
+
# Position-based confidence: position i / len(xt)
|
| 1285 |
+
xt_len = (xt != pad).sum(dim=1) # (B,) - current sequence lengths
|
| 1286 |
+
position_indices = torch.arange(max_length, device=device).unsqueeze(0).expand(batch_size, -1) # (B, L)
|
| 1287 |
+
confidence = 1.0 - (position_indices.float() / xt_len.unsqueeze(1).float().clamp(min=1)) # (B, L)
|
| 1288 |
+
|
| 1289 |
+
elif confidence_method == "top_prob":
|
| 1290 |
+
# Top probability confidence
|
| 1291 |
+
import torch.nn.functional as F
|
| 1292 |
+
token_logits = unmask_rate # (B, L, V) - use the unmask_rate as logits
|
| 1293 |
+
unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V)
|
| 1294 |
+
confidence = unmask_probs.max(dim=-1)[0] # (B, L)
|
| 1295 |
+
|
| 1296 |
+
elif confidence_method == "prob_diff":
|
| 1297 |
+
# Probability difference confidence (top - second top)
|
| 1298 |
+
import torch.nn.functional as F
|
| 1299 |
+
token_logits = unmask_rate # (B, L, V)
|
| 1300 |
+
unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V)
|
| 1301 |
+
top2_probs, _ = torch.topk(unmask_probs, k=2, dim=-1) # (B, L, 2)
|
| 1302 |
+
confidence = top2_probs[:, :, 0] - top2_probs[:, :, 1] # (B, L)
|
| 1303 |
+
|
| 1304 |
+
elif confidence_method == "entropy":
|
| 1305 |
+
# Entropy-based confidence (lower entropy = higher confidence)
|
| 1306 |
+
import torch.nn.functional as F
|
| 1307 |
+
token_logits = unmask_rate # (B, L, V)
|
| 1308 |
+
unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V)
|
| 1309 |
+
entropy = -torch.sum(unmask_probs * torch.log(unmask_probs + 1e-10), dim=-1) # (B, L)
|
| 1310 |
+
confidence = -entropy # (B, L) - negative entropy so lower entropy gives higher confidence
|
| 1311 |
+
|
| 1312 |
+
else:
|
| 1313 |
+
raise ValueError(f"Unknown confidence_method: {confidence_method}")
|
| 1314 |
+
|
| 1315 |
+
# 3. Apply window constraint if enabled
|
| 1316 |
+
if use_sliding_window:
|
| 1317 |
+
# Calculate dynamic k for each batch
|
| 1318 |
+
k_values = torch.minimum(
|
| 1319 |
+
torch.minimum(
|
| 1320 |
+
(alpha * unmask_counts).long(),
|
| 1321 |
+
torch.tensor(max_window, device=device)
|
| 1322 |
+
), num_mask_positions) # (B,)
|
| 1323 |
+
|
| 1324 |
+
# Get cumulative count of mask positions
|
| 1325 |
+
mask_cumsum = mask_positions.cumsum(dim=1) # (B, L)
|
| 1326 |
+
|
| 1327 |
+
# Create window mask: position is eligible if it's a mask and within first k masks
|
| 1328 |
+
is_within_window = mask_cumsum <= k_values.unsqueeze(1) # (B, L)
|
| 1329 |
+
window_mask = mask_positions & is_within_window # (B, L)
|
| 1330 |
+
|
| 1331 |
+
# Set confidence to -inf for positions outside the window or non-mask positions
|
| 1332 |
+
confidence = torch.where(window_mask, confidence, torch.tensor(-float('inf'), device=device))
|
| 1333 |
+
else:
|
| 1334 |
+
# No window constraint - only mask positions are eligible
|
| 1335 |
+
confidence = torch.where(mask_positions, confidence, torch.tensor(-float('inf'), device=device))
|
| 1336 |
+
|
| 1337 |
+
new_xt = xt.clone()
|
| 1338 |
+
|
| 1339 |
+
# vectorized unmasking
|
| 1340 |
+
max_unmask = unmask_counts.max().item()
|
| 1341 |
+
if max_unmask > 0:
|
| 1342 |
+
_, all_top_indices = torch.topk(confidence, k=max_unmask, dim=1, largest=True) # (B, max_unmask)
|
| 1343 |
+
|
| 1344 |
+
# create mask for valid unmask operations
|
| 1345 |
+
unmask_mask = torch.arange(max_unmask, device=device).unsqueeze(0) < unmask_counts.unsqueeze(1) # (B, max_unmask)
|
| 1346 |
+
|
| 1347 |
+
most_likely_tokens = unmask_rate.argmax(dim=-1) # (B, L)
|
| 1348 |
+
|
| 1349 |
+
selected_positions = all_top_indices[unmask_mask]
|
| 1350 |
+
batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, max_unmask)[unmask_mask]
|
| 1351 |
+
|
| 1352 |
+
new_xt[batch_indices, selected_positions] = most_likely_tokens[batch_indices, selected_positions]
|
| 1353 |
+
else:
|
| 1354 |
+
# --- tau-leaping unmask via Poisson ---
|
| 1355 |
+
counts = torch.poisson(unmask_rate * dt).long()
|
| 1356 |
+
mask_pos = xt == mask
|
| 1357 |
+
counts[~mask_pos.unsqueeze(-1).expand_as(counts)] = 0
|
| 1358 |
+
counts[..., mask] = 0
|
| 1359 |
+
sum_c = counts.sum(dim=2)
|
| 1360 |
+
one_event = sum_c == 1
|
| 1361 |
+
new_token = counts.argmax(dim=2)
|
| 1362 |
+
new_xt = xt.clone()
|
| 1363 |
+
new_xt[one_event] = new_token[one_event]
|
| 1364 |
+
new_xt = torch.where(xt == pad, pad, new_xt)
|
| 1365 |
+
new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
|
| 1366 |
+
|
| 1367 |
+
# insertion only on non-last
|
| 1368 |
+
if i != steps - 1:
|
| 1369 |
+
# --- Poisson insertion, compute new lengths and fill masks ---
|
| 1370 |
+
ext = torch.poisson(len_rate * dt).long() # (B, L+1)
|
| 1371 |
+
xt_len = xt.ne(pad).sum(dim=1) # (B,)
|
| 1372 |
+
gaps = torch.arange(max_length + 1, device=device).view(1, -1)
|
| 1373 |
+
ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
|
| 1374 |
+
total_ext = ext.sum(dim=1)
|
| 1375 |
+
valid = xt_len + total_ext <= max_length
|
| 1376 |
+
ext = ext * valid.view(batch_size, 1).long()
|
| 1377 |
+
|
| 1378 |
+
# compute prefix sums of insertions
|
| 1379 |
+
ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
|
| 1380 |
+
new_len = xt_len + total_ext # (B,)
|
| 1381 |
+
|
| 1382 |
+
# initialize with pads, then fill mask up to new_len
|
| 1383 |
+
xt_tmp = torch.full_like(xt, pad)
|
| 1384 |
+
mask_pos = pos_idx_L < new_len.view(batch_size, 1)
|
| 1385 |
+
xt_tmp[mask_pos] = mask
|
| 1386 |
+
|
| 1387 |
+
# shift and scatter original tokens
|
| 1388 |
+
new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
|
| 1389 |
+
orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
|
| 1390 |
+
flat_b = batch_idx_L[orig_mask]
|
| 1391 |
+
flat_p = new_pos_orig[orig_mask]
|
| 1392 |
+
xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
|
| 1393 |
+
else:
|
| 1394 |
+
xt_tmp = new_xt
|
| 1395 |
+
|
| 1396 |
+
xt = xt_tmp
|
| 1397 |
+
t = t + dt
|
| 1398 |
+
if return_trace:
|
| 1399 |
+
sampling_trace.append(xt)
|
| 1400 |
+
|
| 1401 |
+
return xt, sampling_trace
|
a2d2_mol/scripts/run_mol_finetune.slurm
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# NOTE: --partition below is specific to our cluster. Change it (or remove it
# and pass `--partition` on the `sbatch` command line) to match the partitions
# available on yours.
#SBATCH --job-name=mol-finetune
#SBATCH --partition=dgx-b200
#SBATCH --nodes=1
#SBATCH --gpus-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --ntasks-per-node=1
#SBATCH --mem=80GB
#SBATCH --time=02-00:00:00
#SBATCH --output=logs/slurm-%A.%x.log

# =====================================================================
# run_mol_finetune.slurm
#
# Single-mode job (1 MIG GPU) running ONE finetune_mol experiment.
# Select which mode to run via the MODE_ID variable below (or override
# at submit time with `sbatch --export=ALL,MODE_ID=2 ...`):
#   0) A2D2 (Ours)                   – with full planner (alternating)
#   1) A2D2 w/o quality              – --disable_planner
#   2) A2D2 w/o insertion planner    – --disable_insertion_planner
#   3) A2D2 w/o unmasking planner    – --disable_unmasking_planner
#
# The job trains the selected mode then evaluates the resulting
# checkpoint on the same GPU.
# =====================================================================

# Abort on the first failing command (a failed training run must not be
# followed by an evaluation). `pipefail` is deliberately NOT enabled: the
# checkpoint lookup further down pipes `ls` (which exits non-zero whenever one
# of its two glob patterns has no match) into `head`.
set -e

# --- Mode selection ---------------------------------------------------
# Which experiment to run (0-3). Override with `--export=ALL,MODE_ID=N`.
MODE_ID="${MODE_ID:-0}"

# Run prefix: the SLURM job id, or a timestamp when run outside SLURM.
PREFIX=${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}

# --- Paths ------------------------------------------------------------
# Repo root is resolved at submit time so the job runs from any clone:
#   - set A2D2_ROOT explicitly, OR
#   - submit with `sbatch` from the repo root (SLURM sets SLURM_SUBMIT_DIR;
#     note sbatch copies the script to a spool dir, so we can't rely on the
#     script's own path here), OR
#   - run the script directly, falling back to its location on disk.
if [ -n "${A2D2_ROOT:-}" ]; then
    HOME_LOC="$A2D2_ROOT"
elif [ -n "${SLURM_SUBMIT_DIR:-}" ]; then
    HOME_LOC="$SLURM_SUBMIT_DIR"
else
    # This script lives in a2d2_mol/scripts/, so the repo root is two levels up.
    HOME_LOC="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
fi
SCRIPT_LOC="$HOME_LOC/a2d2_mol"
LOG_LOC="$HOME_LOC/logs"
SAVE_DIR="$HOME_LOC/checkpoints/finetune_mol"
RESULTS_DIR="$HOME_LOC/results/mol_ablation"

mkdir -p "$LOG_LOC" "$SAVE_DIR" "$RESULTS_DIR"

# --- Environment setup ------------------------------------------------
# Set WANDB_API_KEY in your shell/secret store before submitting (do NOT commit it):
#   export WANDB_API_KEY=...   or   `wandb login`
export WANDB_DIR="$HOME_LOC/.wandb"
export WANDB_CONFIG_DIR="$HOME_LOC/.config/wandb"
export WANDB_CACHE_DIR="$HOME_LOC/.cache/wandb"
mkdir -p "$WANDB_DIR" "$WANDB_CONFIG_DIR" "$WANDB_CACHE_DIR"

export TRITON_CACHE_DIR="$HOME_LOC/.triton/cache"
mkdir -p "$TRITON_CACHE_DIR"

export TORCHINDUCTOR_CACHE_DIR="$HOME_LOC/.torchinductor/cache"
mkdir -p "$TORCHINDUCTOR_CACHE_DIR"

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Force unbuffered stdout/stderr so live training output is flushed to the
# redirected RUN_LOG (Python block-buffers stdout when it's a file, not a TTY).
export PYTHONUNBUFFERED=1

# Activate conda env. Override CONDA_ROOT to point at your conda/miniconda
# install, or just have `conda` on PATH; override CONDA_ENV if your env name
# differs from the one created by environment.yml.
CONDA_ENV="${CONDA_ENV:-a2d2}"
if [ -n "${CONDA_ROOT:-}" ]; then
    source "$CONDA_ROOT/bin/activate" "$CONDA_ENV"
elif command -v conda >/dev/null 2>&1; then
    source "$(conda info --base)/bin/activate" "$CONDA_ENV"
else
    echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2
    exit 1
fi
# Resolve the interpreter of the activated env once; it is reused for both the
# training and the evaluation step.
PYTHON_EXECUTABLE=$(command -v python)

cd "$SCRIPT_LOC"

# Pretrained base checkpoint
PRETRAINED_CKPT="$HOME_LOC/pretrained/anylength_mol.ckpt"

# --- Shared training hyperparameters ----------------------------------
COMMON_ARGS=(
    --base_path "$HOME_LOC"
    --use_quality_filter
    --noise_removal
    --wdce_num_replicates 16
    --pool_size 1000
    --pool_refresh_fraction 0.3
    --buffer_size 100
    --batch_size 200
    --training_mini_batch_size 20
    --max_length 256
    --total_num_steps 256
    --num_iter 20
    --resample_every_n_step 10
    --num_epochs 1000
    --save_every_n_epochs 100
    --reset_every_n_step 1
    --alpha 0.01
    --no_mcts
    --schedule_warmup_epochs 20
    --alternation_frequency 5
    --num_remasking 3
    --quality_threshold 0.3
    --checkpoint_path "$PRETRAINED_CKPT"
    --grad_clip
    --qed_only
    --seed 42
    --num_training_steps_per_epoch 25
)

# --- Shared evaluation hyperparameters --------------------------------
EVAL_COMMON_ARGS=(
    --pretrained_ckpt "$PRETRAINED_CKPT"
    --num_samples 1000
    --batch_size 50
    --max_length 256
    --total_num_steps 256
    --num_remasking 2
    --quality_threshold 0.3
    --seed 42
)

# =====================================================================
# Pick experiment from $MODE_ID
# =====================================================================
# EXTRA_ARGS holds the ablation flag for the mode; the same flag is passed to
# both the training and the evaluation command below.
case "$MODE_ID" in
    0) MODE="with_planner";         EXTRA_ARGS=() ;;
    1) MODE="no_planner";           EXTRA_ARGS=(--disable_planner) ;;
    2) MODE="no_insertion_planner"; EXTRA_ARGS=(--disable_insertion_planner) ;;
    3) MODE="no_unmasking_planner"; EXTRA_ARGS=(--disable_unmasking_planner) ;;
    *) echo "Unknown MODE_ID=$MODE_ID (expected 0-3)" >&2; exit 1 ;;
esac

RUN_NAME="${PREFIX}_mol_${MODE}"
RUN_LOG="$LOG_LOC/${RUN_NAME}.log"
RUN_SAVE_DIR="$SAVE_DIR/${RUN_NAME}"
RESULTS_SUBDIR="$RESULTS_DIR/${MODE}"
mkdir -p "$RUN_SAVE_DIR" "$RESULTS_SUBDIR"

echo "=== Mol finetune (MODE_ID=$MODE_ID) ==="
echo "Job: ${SLURM_JOB_ID:-local}  Node: ${SLURM_NODELIST:-$(hostname)}"
echo "Mode: $MODE"
echo "Save dir: $RUN_SAVE_DIR"
echo "Results dir: $RESULTS_SUBDIR"
echo "Python: $PYTHON_EXECUTABLE"
echo "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES:-(unset)}"

# =====================================================================
# Train
# =====================================================================
# Interpreter and script path are quoted so a repo cloned under a path
# containing spaces does not get word-split.
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/finetune_mol.py" \
    "${COMMON_ARGS[@]}" \
    --devices 1 \
    "${EXTRA_ARGS[@]}" \
    --save_path_dir "$RUN_SAVE_DIR" \
    >> "$RUN_LOG" 2>&1

echo "Training finished for $MODE. Log: $RUN_LOG"

# =====================================================================
# Evaluate
# =====================================================================
# Newest last.ckpt, either directly in the run dir or one level below it.
RUN_CKPT=$(ls -t "$RUN_SAVE_DIR"/*/last.ckpt "$RUN_SAVE_DIR"/last.ckpt 2>/dev/null | head -1)
if [ -z "$RUN_CKPT" ]; then
    echo "No checkpoint found in $RUN_SAVE_DIR — skipping eval." >&2
    exit 1
fi

echo "Evaluating checkpoint: $RUN_CKPT"
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/evaluate_mol_table.py" \
    --checkpoint_path "$RUN_CKPT" \
    "${EVAL_COMMON_ARGS[@]}" \
    "${EXTRA_ARGS[@]}" \
    --output_dir "$RESULTS_SUBDIR" \
    --device cuda:0 \
    >> "$RESULTS_SUBDIR/eval.log" 2>&1

echo "Eval finished for $MODE. CSV: $RESULTS_SUBDIR/eval_metrics_${MODE}.csv"

conda deactivate
|
a2d2_mol/scripts/train_mol.sh
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --job-name=a2d2-mol-pretrain
#SBATCH --partition=dgx-b200
#SBATCH --nodes=1
#SBATCH --gpus-per-node=2
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=8
#SBATCH --mem=512GB
#SBATCH --time=7-00:00:00
# SLURM's own catch-file (anything printed before the exec redirect below, plus
# slurm-infra messages). Relative to the submit dir, so submit this script from
# the a2d2_mol/ directory; the real run output is redirected via exec below.
#SBATCH --output=logs/slurm/%x_%j.out
#SBATCH --error=logs/slurm/%x_%j.err
#
# Pretrain the any-length insertion MDM on drug-like SAFE molecules on a dgx-b200 node.
# Submit with:  sbatch scripts/train_mol.sh   (from the a2d2_mol/ directory).
#
# DDP is launched by SLURM: one srun task per GPU. --gpus-per-node and
# --ntasks-per-node must match; change both together (and they override the
# training.devices value baked into config_mol.yaml via the hydra override below).

DATE=$(date +%Y%m%d)
SPECIAL_PREFIX='a2d2-mol'

# Resolve a2d2_mol/ (which holds train.py + config_mol.yaml) so paths are
# repo-relative. This script lives in a2d2_mol/scripts/, so the direct-run
# fallback goes one level up. Under sbatch, BASH_SOURCE points at the spooled
# copy, so we rely on SLURM_SUBMIT_DIR (submit from the a2d2_mol/ directory).
if [ -n "${SLURM_SUBMIT_DIR:-}" ]; then
    SCRIPT_DIR="$SLURM_SUBMIT_DIR"
else
    SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
fi
# There is no `set -e` in this script, so stop explicitly if the directory is
# unusable rather than launching train.py from the wrong place.
cd "$SCRIPT_DIR" || { echo "ERROR: cannot cd to $SCRIPT_DIR" >&2; exit 1; }

# Auto-detect GPUs from the SLURM allocation (falls back to 2 for `bash` runs).
DEVICES=${SLURM_GPUS_ON_NODE:-${SLURM_GPUS_PER_NODE:-2}}
NTASKS=${SLURM_NTASKS_PER_NODE:-$DEVICES}
NODES=${SLURM_NNODES:-1}

LOG_LOC="$SCRIPT_DIR/logs"
mkdir -p "$LOG_LOC/slurm"
# From here on, all stdout/stderr of this script goes to the per-run log file.
exec > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_${SLURM_JOB_ID:-local}.log" 2>&1

# ---------------------------------------------------------------------------
# Weights & Biases: log in once on your machine before running this script with
#   `wandb login`   (or `export WANDB_API_KEY=<your-key>`).
# Do NOT hardcode your API key here. To disable W&B entirely, uncomment:
#   export WANDB_MODE=disabled
# ---------------------------------------------------------------------------

# NOTE(review): run_mol_finetune.slurm exports PYTORCH_CUDA_ALLOC_CONF instead
# of PYTORCH_ALLOC_CONF — confirm which variable name the installed PyTorch
# version actually reads.
export PYTORCH_ALLOC_CONF=expandable_segments:True

# Activate the conda env that has the deps (torch / pytorch_lightning / hydra).
# The batch shell does NOT source ~/.bashrc, so conda is not on PATH. Override
# CONDA_ROOT to point at your conda/miniconda install, or just have `conda` on
# PATH; override CONDA_ENV if your env name differs from the one created by
# environment.yml.
CONDA_ENV="${CONDA_ENV:-a2d2}"
if [ -n "${CONDA_ROOT:-}" ]; then
    source "$CONDA_ROOT/bin/activate" "$CONDA_ENV"
elif command -v conda >/dev/null 2>&1; then
    source "$(conda info --base)/bin/activate" "$CONDA_ENV"
else
    echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2
    exit 1
fi

# --- Distributed / NCCL setup (single node, intra-node NVLink) --------------
# First non-loopback IPv4 interface with a conventional ethernet-style name;
# otherwise the first non-loopback interface that is not named ibp*.
ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -E "ens|eth|enp|bond" | head -1 | awk '{print $2}')
if [ -z "$ETH_IFACE" ]; then
    ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -v "ibp" | head -1 | awk '{print $2}')
fi
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_FAMILY=AF_INET
# Only pin the interface when one was found; an empty NCCL_SOCKET_IFNAME is
# not exported (presumably NCCL then picks an interface itself — TODO confirm).
if [ -n "$ETH_IFACE" ]; then
    export NCCL_SOCKET_IFNAME="$ETH_IFACE"
fi
export NCCL_P2P_LEVEL=NVL

# Assign first, export second, so a failing command substitution is not hidden
# behind the exit status of `export`.
MASTER_ADDR=$(scontrol show hostnames "${SLURM_NODELIST:-$(hostname)}" | head -n 1)
MASTER_PORT=$(shuf -i 15000-59999 -n 1)
export MASTER_ADDR MASTER_PORT
export NODE_RANK=${SLURM_NODEID:-0}

echo "=== a2d2 molecule pretraining (dgx-b200) ==="
echo "Job ID: ${SLURM_JOB_ID:-local}  Node: ${SLURM_NODELIST:-$(hostname)}  GPUs: $DEVICES  Tasks: $NTASKS"

# --task mol makes train.py load config_mol.yaml; the hydra overrides pin
# devices/nodes to the SLURM allocation so the two never drift apart.
srun --ntasks-per-node="$NTASKS" python train.py --task mol \
    "training.devices=$DEVICES" \
    "training.nodes=$NODES"
# Remember the training result: without this the script would exit with the
# status of `conda deactivate` and a failed run would look successful.
TRAIN_STATUS=$?

conda deactivate

exit "$TRAIN_STATUS"
|
a2d2_mol/train.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pytorch_lightning as pl
|
| 3 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 4 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import argparse
|
| 8 |
+
import hydra
|
| 9 |
+
from omegaconf import OmegaConf
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
# Directory containing this file and the config_*.yaml files (used by Hydra below).
|
| 12 |
+
CONFIG_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 13 |
+
# Add the repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve.
|
| 14 |
+
sys.path.insert(0, os.path.dirname(CONFIG_DIR))
|
| 15 |
+
|
| 16 |
+
import wandb
|
| 17 |
+
from lightning_modules import AnyOrderInsertionFlowModule
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
torch.set_printoptions(threshold=10_000)
|
| 21 |
+
torch.set_float32_matmul_precision("high")
|
| 22 |
+
|
| 23 |
+
# Disable DDP optimizer due to incompatibility with flex_attention higher-order ops
|
| 24 |
+
torch._dynamo.config.optimize_ddp = False
|
| 25 |
+
|
| 26 |
+
def train(config):
    """Run one training job described by a Hydra/OmegaConf ``config``.

    The function seeds the RNGs, optionally starts a Weights & Biases run,
    appends a timestamp to ``config.training.checkpoint_dir``, builds the data
    module and the ``AnyOrderInsertionFlowModule``, configures a Lightning
    ``Trainer`` with two ``ModelCheckpoint`` callbacks, and calls
    ``trainer.fit``.

    Args:
        config: OmegaConf config. Keys read here: ``wandb.{project,name,path}``,
            ``model.torch_dtype`` (optional), ``hf_dataset`` (optional), and
            ``training.{checkpoint_dir,nodes,devices,batch_size,
            per_gpu_batch_size,max_steps,num_epochs,warmup_steps,save_top_k}``
            plus optional ``training.{save_every_n_steps,save_every_n_epochs,
            resume_path}``. The config is mutated in place (checkpoint_dir,
            warmup_steps).

    Raises:
        ValueError: if neither ``max_steps`` nor ``num_epochs`` is set, if
            neither periodic-save interval is set, or if ``warmup_steps`` is
            unset while training by epochs (no step budget to derive it from).
    """
    wandb_logger = None

    # set the random seed
    pl.seed_everything(42)
    torch.manual_seed(42)

    # Only initialize wandb on local rank 0 to avoid multiple runs. LOCAL_RANK
    # is used when present; SLURM_LOCALID is the fallback for tasks started
    # directly by `srun`, where LOCAL_RANK may not be exported.
    local_rank = int(
        os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", 0))
    )
    is_local_rank_zero = local_rank == 0

    if is_local_rank_zero:
        wandb.init(
            project=config.wandb.project,
            name=config.wandb.name,
            config=OmegaConf.to_container(config, resolve=True),  # Convert to dict
            dir=config.wandb.path
        )
        wandb_logger = WandbLogger(
            project=wandb.run.project,
            name=wandb.run.name,
            log_model=False,  # Disable checkpoint uploading to save disk space
        )

    # Modify config to add timestamp to checkpoint directory
    OmegaConf.set_struct(config, False)
    time_string = datetime.now().strftime("%Y%m%d-%H%M%S")
    config.training.checkpoint_dir = os.path.join(
        config.training.checkpoint_dir, time_string
    )
    OmegaConf.set_struct(config, True)

    # Create checkpoint directory
    os.makedirs(config.training.checkpoint_dir, exist_ok=True)

    # Setup data module - check if using HuggingFace dataset
    if hasattr(config, 'hf_dataset'):
        # Imported lazily: the HF/SAFE path is only used by the molecule configs,
        # which keep mol_dataset.py (and its `safe` dependency) in a2d2_mol/.
        from mol_dataset import setup_hf_data_and_update_config
        print(f"Using HuggingFace dataset: {config.hf_dataset.name}")
        data_module = setup_hf_data_and_update_config(
            config,
            dataset_name=config.hf_dataset.name,
            smiles_column=config.hf_dataset.get('smiles_column', 'smiles')
        )
    else:
        # Imported lazily: the local (arrow) path is used by the peptide config,
        # which keeps dataloading_for_dynamic_batching.py in a2d2_pep/.
        from dataloading_for_dynamic_batching import setup_data_and_update_config
        print("Using local dataset")
        data_module = setup_data_and_update_config(config)

    module = AnyOrderInsertionFlowModule(config)

    # Map torch_dtype to Lightning precision (unknown values fall back to bf16-mixed)
    dtype_str = config.model.get('torch_dtype', 'bfloat16')
    precision_map = {
        'float32': '32-true',
        'float16': '16-mixed',
        'bfloat16': 'bf16-mixed'
    }
    precision = precision_map.get(dtype_str, 'bf16-mixed')

    # Configure trainer arguments
    trainer_kwargs = dict(
        num_nodes=config.training.nodes,
        accelerator="gpu",
        devices=config.training.devices,
        strategy="ddp",
        precision=precision,
        # Global batch size divided by the number of samples processed per
        # optimizer micro-step across all GPUs.
        accumulate_grad_batches=(
            config.training.batch_size
            // (
                config.training.per_gpu_batch_size
                * config.training.nodes
                * config.training.devices
            )
        ),
        log_every_n_steps=10,
        enable_checkpointing=True,
        default_root_dir=config.training.checkpoint_dir,
        gradient_clip_val=1.0,
    )
    # Only one of max_steps or max_epochs will be used
    if config.training.max_steps is not None:
        trainer_kwargs["max_steps"] = config.training.max_steps
    elif config.training.num_epochs is not None:
        trainer_kwargs["max_epochs"] = config.training.num_epochs
    else:
        raise ValueError(
            "Either max_steps or num_epochs must be specified in the config"
        )

    if config.training.warmup_steps is None:
        # The default warmup is 1% of the step budget, which only exists when
        # training by steps. Fail with a clear message instead of the
        # `None * 0.01` TypeError.
        if config.training.max_steps is None:
            raise ValueError(
                "training.warmup_steps must be set explicitly when training "
                "by num_epochs (max_steps is not specified)"
            )
        # NOTE(review): `module` was already constructed from this config
        # above; if it copies warmup_steps at construction time, this default
        # arrives too late — confirm against AnyOrderInsertionFlowModule.
        config.training.warmup_steps = int(config.training.max_steps * 0.01)

    # Add ModelCheckpoint callback to save the checkpoint when the training
    # loss ("train/total_loss") is at a new low
    checkpoint_callback = ModelCheckpoint(
        monitor="train/total_loss",
        mode="min",
        save_top_k=config.training.save_top_k,
        save_last=True,
        filename="epoch-{epoch:02d}-train_loss-{train/total_loss:.4f}",
        dirpath=config.training.checkpoint_dir,
        # Don't use val_loss in filename for periodic saves - causes failures when val doesn't run
        auto_insert_metric_name=False
    )

    # Add separate callback for periodic saves (no val_loss dependency). Use
    # step-based saves for streaming datasets (save_every_n_steps) and epoch-based
    # saves otherwise (save_every_n_epochs); whichever the config provides.
    save_every_n_steps = config.training.get('save_every_n_steps', None)
    save_every_n_epochs = config.training.get('save_every_n_epochs', None)
    if save_every_n_steps is not None:
        periodic_checkpoint_callback = ModelCheckpoint(
            save_top_k=-1,  # Save all periodic checkpoints
            filename="step-{step:08d}",
            dirpath=config.training.checkpoint_dir,
            every_n_train_steps=save_every_n_steps,
            auto_insert_metric_name=False
        )
    elif save_every_n_epochs is not None:
        periodic_checkpoint_callback = ModelCheckpoint(
            save_top_k=-1,  # Save all periodic checkpoints
            filename="epoch-{epoch:02d}",
            dirpath=config.training.checkpoint_dir,
            every_n_epochs=save_every_n_epochs,
            auto_insert_metric_name=False
        )
    else:
        raise ValueError(
            "Either save_every_n_steps or save_every_n_epochs must be specified in the config"
        )

    trainer_kwargs["callbacks"] = [checkpoint_callback, periodic_checkpoint_callback]

    if wandb_logger is not None:
        trainer_kwargs["logger"] = wandb_logger

    trainer = pl.Trainer(**trainer_kwargs)

    # Train the model, resuming from training.resume_path when the key is present
    ckpt_path = None
    if "resume_path" in config.training:
        ckpt_path = config.training.resume_path

    trainer.fit(module,
                datamodule=data_module,
                ckpt_path=ckpt_path)

    # Only finish wandb on the rank that started it
    if is_local_rank_zero:
        wandb.finish()
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
if __name__ == '__main__':
    # Script entry point: pick the Hydra config name from --task / --config_name,
    # hand every other argument to Hydra untouched, then run train().
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_name', type=str, default='config',
                        help='Name of the config file to use')
    parser.add_argument('--task', type=str, default=None,
                        help='Task name (uses config_{task}.yaml)')

    # Parse known args (hydra will handle the rest)
    args, unknown = parser.parse_known_args()

    # Determine config name from task or config_name
    if args.task:
        config_name = f'config_{args.task}'
    else:
        config_name = args.config_name

    print(f"Using config: {config_name}.yaml")

    # Add config name to Hydra overrides (this persists across DDP subprocesses),
    # unless the caller already passed Hydra's own flag in any of its spellings
    # (`--config-name X`, `--config-name=X`, `-cn X`, `-cn=X`). Matching only
    # the exact value we would insert ourselves would prepend a second,
    # conflicting --config-name next to the caller's.
    user_set_config_name = any(
        arg in ('--config-name', '-cn')
        or arg.startswith(('--config-name=', '-cn='))
        for arg in unknown
    )
    if not user_set_config_name:
        unknown.insert(0, f'--config-name={config_name}')

    # Reconstruct sys.argv for hydra
    sys.argv = [sys.argv[0]] + unknown

    # Define main function with default config (will be overridden by command line)
    @hydra.main(version_base=None,
                config_path=CONFIG_DIR,
                config_name='config')
    def main(config):
        """Main entry point for training"""
        train(config)

    main()
|
a2d2_pep/README.md
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# A2D2 for Multi-Objective Therapeutic Peptide Generation 🧫
|
| 2 |
+
|
| 3 |
+
This part of the code fine-tunes an **any-length masked diffusion model (MDM)** over peptide SMILES with **A2D2** (Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding) to optimize **five therapeutic properties simultaneously**: binding affinity to a target protein, solubility, non-hemolysis, non-fouling, and cell-membrane permeability.
|
| 4 |
+
|
| 5 |
+
A2D2 jointly fine-tunes the insertion and unmasking policies together with **insertion and unmasking quality predictors**, generating peptides via **Adaptive Joint Decoding (AJD)** that remasks low-quality tokens and drops low-quality insertions to sample from the reward-tilted distribution while preserving generation quality.
|
| 6 |
+
|
| 7 |
+
Peptides are represented as **SMILES** strings and tokenized with the SMILES Pair Encoding tokenizer (vocabulary size `V = 587`) from [PeptideCLM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01443). Generated SMILES are decoded and validity-checked with the `SMILES2PEPTIDE` filter from [PepTune](https://arxiv.org/abs/2412.17780).
|
| 8 |
+
|
| 9 |
+
The codebase is partially built upon [FlexMDM (Kim et al., 2025)](https://github.com/brianlck/FlexMDM/tree/main) and [TR2-D2 (Tang et al., 2025)](https://github.com/sophtang/TR2-D2/tree/main).
|
| 10 |
+
|
| 11 |
+
## Environment Installation
|
| 12 |
+
```
|
| 13 |
+
# from the repository root
|
| 14 |
+
conda env create -f environment.yml
|
| 15 |
+
|
| 16 |
+
conda activate a2d2
|
| 17 |
+
```
|
| 18 |
+
The peptide scripts share the `a2d2` environment with the molecule and language experiments. See the root [`environment.yml`](../environment.yml) for the `flash-attn` install step.
|
| 19 |
+
|
| 20 |
+
## Model Pretrained Weights
|
| 21 |
+
|
| 22 |
+
A2D2 fine-tunes a pretrained any-length insertion MDM trained on ~11M peptide SMILES (7,451 sequences from CycPeptMPDB, 825,632 from SmProt, and ~10M modified peptides from CycloPs). Download the base checkpoint and place it at:
|
| 23 |
+
```
|
| 24 |
+
A2D2/pretrained/anylength_pep.ckpt
|
| 25 |
+
```
|
| 26 |
+
```bash
|
| 27 |
+
# from the repository root
|
| 28 |
+
pip install gdown
|
| 29 |
+
mkdir -p pretrained
|
| 30 |
+
gdown 1K8yxM-omh-MuPo0EG6UyxHZLk3HehoJc -O pretrained/anylength_pep.ckpt
|
| 31 |
+
```
|
| 32 |
+
(Or download manually from https://drive.google.com/file/d/1K8yxM-omh-MuPo0EG6UyxHZLk3HehoJc/view?usp=drive_link — a plain `wget`/`curl` of the link saves Google's HTML warning page, not the checkpoint.)
|
| 33 |
+
This is the default `--checkpoint_path`; pass `--checkpoint_path` to override it.
|
| 34 |
+
|
| 35 |
+
The reward classifiers (binding-affinity Transformer, plus XGBoost predictors for solubility, hemolysis, non-fouling, and permeability) and the SMILES PE tokenizer ship with the repo under [`pep_scoring/`](pep_scoring); no separate download is required. The PeptideCLM embedding model is fetched automatically from the Hugging Face Hub (`aaronfeller/PeptideCLM-23M-all`) on first run.
|
| 36 |
+
|
| 37 |
+
## Pretraining the Any-Length Model
|
| 38 |
+
|
| 39 |
+
If you only want to fine-tune with A2D2, download the released `anylength_pep.ckpt` above and skip this section. Follow these steps to reproduce the base checkpoint by pretraining the any-length insertion MDM from scratch.
|
| 40 |
+
|
| 41 |
+
### 1. Download the pretraining dataset
|
| 42 |
+
|
| 43 |
+
The pretraining corpus is ~11M peptide SMILES (7,451 from CycPeptMPDB, 825,632 from SmProt, and ~10M modified peptides from CycloPs), already tokenized with the in-repo SMILES PE tokenizer and saved as a Hugging Face `arrow` dataset (with `train`/`val` splits) via `save_to_disk`.
|
| 44 |
+
|
| 45 |
+
Download the archive and unpack it into [`data/`](data):
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
# from a2d2_pep/
|
| 49 |
+
pip install gdown
|
| 50 |
+
gdown https://drive.google.com/uc?id=1yCDr641WVjCtECg3nbG0nsMNu8j7d7gp -O 11M_peptide_smiles.zip
|
| 51 |
+
mkdir -p data
|
| 52 |
+
unzip 11M_peptide_smiles.zip -d data/
|
| 53 |
+
# result: a2d2_pep/data/11M_peptide_smiles/{train,val}/...
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
This is the default `training.data_path` in [`config_pep.yaml`](config_pep.yaml). To store the dataset elsewhere, set `training.data_path` (absolute, or relative to `a2d2_pep/`).
|
| 57 |
+
|
| 58 |
+
### 2. Configure
|
| 59 |
+
|
| 60 |
+
Pretraining is driven by [`config_pep.yaml`](config_pep.yaml). Key fields:
|
| 61 |
+
|
| 62 |
+
| Field | Default | Notes |
|
| 63 |
+
|-------|---------|-------|
|
| 64 |
+
| `training.data_path` | `data/11M_peptide_smiles` | Preprocessed arrow dataset from step 1. |
|
| 65 |
+
| `training.devices` | `4` | GPUs per node (DDP). |
|
| 66 |
+
| `training.batch_size` | `1024` | Global batch; gradient accumulation is derived automatically from `per_gpu_batch_size`. |
|
| 67 |
+
| `training.max_steps` | `1000000` | Total optimizer steps. |
|
| 68 |
+
| `training.learning_rate` | `3e-4` | AdamW LR with `warmup_steps: 2000`. |
|
| 69 |
+
| `training.checkpoint_dir` | `checkpoints/peptides` | A timestamped subdirectory is created per run. |
|
| 70 |
+
| `interpolant.max_length` | `1024` | Max token length. |
|
| 71 |
+
|
| 72 |
+
### 3. Pretrain the Any-Length Peptide Model
|
| 73 |
+
|
| 74 |
+
Log in to Weights & Biases once (`wandb login`), or set `export WANDB_MODE=disabled` to skip logging. Then submit the SLURM job:
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
# from a2d2_pep/
|
| 78 |
+
sbatch scripts/train_pep.sh
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
`train_pep.sh` is a SLURM batch script that requests one `dgx-b200` node with 4 full B200 GPUs and launches DDP via `srun` (one task per GPU), running the equivalent of:
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
python train.py --task pep
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
It activates the conda env (`CONDA_ENV`, defaults to the `peptune` env) from `CONDA_ROOT` (defaults to the shared miniconda install) — the batch shell does not source `~/.bashrc`, so override these env vars if your install or env path differs. The GPU count is auto-detected from the SLURM allocation and passed to hydra as `training.devices`/`training.nodes`, so to scale just change `--gpus-per-node` and `--ntasks-per-node` together at the top of the script (they must match). `--task pep` makes `train.py` load `config_pep.yaml`.
|
| 88 |
+
|
| 89 |
+
Checkpoints are written to `checkpoints/peptides/<timestamp>/` (use `last.ckpt` / the best `train_loss` checkpoint as the `--checkpoint_path` for fine-tuning); the run log goes to `logs/<date>_a2d2-peptide_<jobid>.log` and SLURM's catch-file to `logs/slurm/`. To resume, add a `training.resume_path: /path/to/last.ckpt` entry to the config.
|
| 90 |
+
|
| 91 |
+
## Fine-Tune with A2D2
|
| 92 |
+
|
| 93 |
+
All paths resolve relative to the repository, so the scripts run from any checkout. You may create the output directories `A2D2/checkpoints`, `A2D2/results`, and `A2D2/logs` beforehand, but this is optional: the script creates them on demand. Fine-tuning curves and a `<prot_name>_generation_results.csv` are written to `<base_path>/results/<run_name>/`, and checkpoints to `--save_path_dir`.
|
| 94 |
+
|
| 95 |
+
Choose a target protein with `--prot_name` (looked up in the built-in `PROTEINS` table — e.g. `glp1` for GLP-1R or `glast` for GLAST), or supply an arbitrary target with `--prot_seq <amino acid sequence>`.
|
| 96 |
+
|
| 97 |
+
#### Available `--prot_name` targets
|
| 98 |
+
|
| 99 |
+
The named targets and their amino-acid sequences are defined in the `PROTEINS` dict in [`finetune_quality.py`](finetune_quality.py) (search for `PROTEINS = {`). The default is `glast`; passing a name not in the table raises an error listing the valid keys. To add a new target, add a `'<key>': '<sequence>'` entry there, or skip the table entirely with `--prot_seq`.
|
| 100 |
+
|
| 101 |
+
| `--prot_name` | Target |
|
| 102 |
+
|---------------|--------|
|
| 103 |
+
| `tfr` | Transferrin receptor (TfR) |
|
| 104 |
+
| `glp1` | GLP-1 receptor (GLP-1R) |
|
| 105 |
+
|
| 106 |
+
### Single run
|
| 107 |
+
|
| 108 |
+
[`scripts/run_peptide_finetune.slurm`](scripts/run_peptide_finetune.slurm) runs a single `finetune_quality.py` experiment on one MIG GPU, then evaluates the resulting checkpoint. It bundles the hyperparameter set from the peptide column of the fine-tuning table in the paper — replicates `R = 8`, buffer size `B = 50`, resample interval `N_resample = 10`, gradient steps per iteration `N_update = 10`, alternation frequency `N_alt = 5`, warmup `N_warmup = 20`, sampling steps `N_steps = 256`, training mini-batch `10`, reward scaling `α = 0.1`, quality threshold `μ_min = 0.5`, and `--num_obj 5` — so you don't have to pass them by hand.
|
| 109 |
+
|
| 110 |
+
The script resolves the repo root automatically — `$A2D2_ROOT` if set, else the `sbatch` submit directory, else the script's own location — so either submit from the repo root or export your clone path. Set `CONDA_ROOT` (your miniconda install) and, if needed, `CONDA_ENV` (defaults to `peptune`) and `WANDB_ENTITY`:
|
| 111 |
+
```bash
|
| 112 |
+
export A2D2_ROOT=/path/to/your/A2D2 # absolute path to your clone
|
| 113 |
+
export CONDA_ROOT=/path/to/miniconda3 # or just have `conda` on PATH
|
| 114 |
+
export WANDB_ENTITY=your_wandb_entity # optional
|
| 115 |
+
sbatch scripts/run_peptide_finetune.slurm
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
Select which variant to run with `MODE_ID` (default `0`): `0` = A2D2 (full planner), `1` = `--disable_planner`, `2` = `--disable_insertion_planner`, `3` = `--disable_unmasking_planner`. Override at submit time:
|
| 119 |
+
```bash
|
| 120 |
+
sbatch --export=ALL,MODE_ID=2 scripts/run_peptide_finetune.slurm
|
| 121 |
+
```
|
| 122 |
+
The target protein is set by the `PROT_NAME` variable near the top of the script (default `tfr`); edit it to one of the named targets above (or any key in the `PROTEINS` table). The pretrained base checkpoint is read from `$A2D2_ROOT/pretrained/anylength_pep.ckpt`. Outputs land in `checkpoints/finetune_test_peptides_<prot>/<job>_peptide_<prot>_<mode>/` and `results/peptide_test_ablation_<prot>/<mode>/`.
|
| 123 |
+
|
| 124 |
+
### Key arguments
|
| 125 |
+
- `--prot_name` / `--prot_seq` — target protein (named lookup, or a raw amino-acid sequence).
|
| 126 |
+
- `--alternation_frequency` — epochs to train each of {policy, planner} before alternating.
|
| 127 |
+
- `--alpha` — reward-tilting temperature (smaller = stronger reward optimization).
|
| 128 |
+
- `--buffer_size`, `--resample_every_n_step` — replay-buffer size and how often it is regenerated.
|
| 129 |
+
|
| 130 |
+
### Ablation flags
|
| 131 |
+
| Flag | Variant |
|
| 132 |
+
|------|---------|
|
| 133 |
+
| *(none)* | A2D2 w/ insertion + unmasking quality (alternation) |
|
| 134 |
+
| `--disable_planner` | A2D2 w/o quality (policy only, no remasking) |
|
| 135 |
+
| `--disable_insertion_planner` | A2D2 w/o insertion quality |
|
| 136 |
+
| `--disable_unmasking_planner` | A2D2 w/o unmasking/remasking quality |
|
| 137 |
+
| `--joint_training` | train policy + quality heads jointly (no alternation) |
|
| 138 |
+
|
| 139 |
+
During buffer generation only sequences passing the `SMILES2PEPTIDE` validity filter are retained; the scalarized multi-objective reward is added to the log Radon–Nikodym derivative of each sequence. Fine-tuning runs on a single GPU (`--devices 1`).
|
| 140 |
+
|
| 141 |
+
## Evaluation
|
| 142 |
+
|
| 143 |
+
Evaluation runs automatically every `--eval_every_n_epochs` epochs and at the end of training. It samples from the current model and reports the fraction of valid peptides along with the five therapeutic rewards (binding affinity, solubility, non-hemolysis, non-fouling, permeability), saving per-objective curves and `<prot_name>_generation_results.csv` under `<base_path>/results/<run_name>/`.
|
| 144 |
+
|
| 145 |
+
To resume a run, pass `--resume_ckpt /path/to/last.ckpt` (restores epoch, optimizer, and planner state; new checkpoints continue in the same directory).
|
a2d2_pep/config_pep.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
trainer: "any-order-flow"
|
| 2 |
+
dataset: "peptides"
|
| 3 |
+
|
| 4 |
+
model:
|
| 5 |
+
hidden_size: 768
|
| 6 |
+
n_heads: 12
|
| 7 |
+
cond_dim: 128
|
| 8 |
+
dropout: 0.05
|
| 9 |
+
n_blocks: 12
|
| 10 |
+
|
| 11 |
+
interpolant:
|
| 12 |
+
type: "any-order"
|
| 13 |
+
tokens: null # filled in automatically
|
| 14 |
+
pad_token: null # filled in automatically
|
| 15 |
+
mask_token: null # filled in automatically
|
| 16 |
+
max_length: 1024
|
| 17 |
+
insert_schedule:
|
| 18 |
+
type: "linear"
|
| 19 |
+
unmask_schedule:
|
| 20 |
+
type: "linear"
|
| 21 |
+
|
| 22 |
+
training:
|
| 23 |
+
only_embed_insert: true
|
| 24 |
+
batch_size: 1024
|
| 25 |
+
per_gpu_batch_size: 64 # Gradient accumulation happens automatically
|
| 26 |
+
cpus: 4
|
| 27 |
+
learning_rate: 3e-4
|
| 28 |
+
nodes: 1
|
| 29 |
+
devices: 4
|
| 30 |
+
max_steps: 1000000
|
| 31 |
+
weight_decay: 0.03
|
| 32 |
+
# Path to the preprocessed (arrow) pretraining dataset; see README for the download link.
|
| 33 |
+
# Relative paths resolve against a2d2_pep/. Defaults to a2d2_pep/data/11M_peptide_smiles.
|
| 34 |
+
data_path: "data/11M_peptide_smiles"
|
| 35 |
+
checkpoint_dir: "checkpoints/peptides"
|
| 36 |
+
save_top_k: 1
|
| 37 |
+
save_every_n_epochs: 1
|
| 38 |
+
loss_fn:
|
| 39 |
+
unmask: "elbo"
|
| 40 |
+
insert: "expectation"
|
| 41 |
+
reset_lr: false
|
| 42 |
+
warmup_steps: 2000
|
| 43 |
+
ema_decay: 0.9999
|
| 44 |
+
filter_max_length: false
|
| 45 |
+
|
| 46 |
+
wandb:
|
| 47 |
+
entity: null # set to your W&B entity, or leave null to use the default
|
| 48 |
+
project: "a2d2-pep"
|
| 49 |
+
name: "a2d2-pep"
|
| 50 |
+
path: "./wandb"
|
a2d2_pep/data/dataloading_for_dynamic_batching.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
from datasets import Dataset,load_from_disk
|
| 6 |
+
import sys
|
| 7 |
+
import pytorch_lightning as pl
|
| 8 |
+
from pep_scoring.tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 9 |
+
from functools import partial
|
| 10 |
+
import re
|
| 11 |
+
|
| 12 |
+
# Directory containing this file; used to resolve the in-repo tokenizer files.
|
| 13 |
+
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class DynamicBatchingDataset(Dataset):
    """In-memory, map-style view over a dataset with tensor-converted columns.

    On construction every entry of the ``attention_mask`` and ``input_ids``
    columns is converted to a ``torch.Tensor``; the ``labels`` column is kept
    exactly as given. Items can be fetched by a single ``int`` index or by a
    ``list`` of indices.
    """

    def __init__(self, dataset_dict, tokenizer):
        """Convert the numeric columns of ``dataset_dict`` to tensors up front.

        Args:
            dataset_dict: Mapping with ``'attention_mask'``, ``'input_ids'``
                and ``'labels'`` columns.
            tokenizer: Stored on the instance; not used by this class itself.
        """
        print('Initializing dataset...')
        mask_tensors = [torch.tensor(entry) for entry in dataset_dict['attention_mask']]
        id_tensors = [torch.tensor(entry) for entry in dataset_dict['input_ids']]
        self.dataset_dict = {
            'attention_mask': mask_tensors,
            'input_ids': id_tensors,
            'labels': dataset_dict['labels'],
        }
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of stored entries."""
        return len(self.dataset_dict['attention_mask'])

    def __getitem__(self, idx):
        """Fetch one entry (``int`` index) or several (``list`` of indices).

        Returns:
            dict: Keys ``'input_ids'``, ``'attention_mask'``, ``'labels'`` (in
            that order). For an ``int`` index the values are single entries;
            for a ``list`` they are lists of entries in the requested order.

        Raises:
            ValueError: If ``idx`` is neither an ``int`` nor a ``list``.
        """
        columns = self.dataset_dict
        field_order = ('input_ids', 'attention_mask', 'labels')

        if isinstance(idx, int):
            return {name: columns[name][idx] for name in field_order}
        if isinstance(idx, list):
            return {name: [columns[name][i] for i in idx] for name in field_order}
        raise ValueError(f"Expected idx to be int or list, but got {type(idx)}")
|
| 44 |
+
|
| 45 |
+
class CustomDataModule(pl.LightningDataModule):
    """Lightning data module over a dataset loaded with ``load_from_disk``.

    The loaded dataset is indexed with ``'train'`` and ``'val'``. Both loaders
    use ``batch_size=1`` and ``collate_fn`` forwards the ``input_ids`` and
    ``attention_mask`` of the single fetched item unchanged.
    NOTE(review): this suggests every stored item is already a whole batch
    ("dynamic batching") -- confirm against the preprocessing script.
    """

    # Fixed width, in characters, of the character-level bond mask. Matches
    # that lie beyond this width are dropped by the slice assignment.
    _MAX_SMILES_CHARS = 1035

    # Bond regexes in priority order: a later pattern is ignored wherever it
    # overlaps characters already claimed by an earlier one.
    _BOND_PATTERNS = (
        (re.compile(r'OC\(=O\)'), 'ester'),
        (re.compile(r'N\(C\)C\(=O\)'), 'n_methyl'),
        (re.compile(r'N[12]C\(=O\)'), 'peptide'),  # Pro peptide bonds
        (re.compile(r'NC\(=O\)'), 'peptide'),  # Regular peptide bonds
        (re.compile(r'C\(=O\)N\(C\)'), 'n_methyl'),
        (re.compile(r'C\(=O\)N[12]?'), 'peptide'),
    )

    def __init__(self, dataset_path, tokenizer):
        """Load the dataset from ``dataset_path`` and keep the tokenizer.

        Args:
            dataset_path: Path passed to ``load_from_disk``.
            tokenizer: Object providing ``get_token_split``; used by
                ``collate_fn`` and handed to the per-split datasets.
        """
        super().__init__()
        self.dataset = load_from_disk(dataset_path)
        self.tokenizer = tokenizer

    def peptide_bond_mask(self, smiles_list):
        """Mark the characters of each SMILES string that belong to a bond motif.

        Args:
            smiles_list: List of peptide SMILES strings (one per batch entry).

        Returns:
            torch.Tensor: Integer mask of shape ``(len(smiles_list), 1035)``
            with 1 at the character positions covered by a recognised
            ester / N-methyl / peptide bond pattern and 0 elsewhere.
        """
        mask = torch.zeros((len(smiles_list), self._MAX_SMILES_CHARS), dtype=torch.int)

        for batch_idx, smiles in enumerate(smiles_list):
            used = set()  # character indices already claimed by a match
            for pattern, _bond_type in self._BOND_PATTERNS:
                for match in pattern.finditer(smiles):
                    span = range(match.start(), match.end())
                    # Accept the match only if none of its characters is taken.
                    if used.isdisjoint(span):
                        mask[batch_idx, match.start():match.end()] = 1
                        used.update(span)

        return mask

    def peptide_token_mask(self, smiles_list, token_lists):
        """Project the character-level bond mask onto token positions.

        Args:
            smiles_list: List of peptide SMILES strings (one per batch entry).
            token_lists: For each SMILES string, its list of token strings.

        Returns:
            torch.Tensor: Integer mask of shape
            ``(len(smiles_list), longest token list)`` with 1 for every token
            that overlaps at least one bond character and 0 elsewhere. The
            first and last token of each sequence are never marked.
        """
        batch_size = len(smiles_list)
        token_seq_length = max(len(tokens) for tokens in token_lists)
        tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)
        atomwise_masks = self.peptide_bond_mask(smiles_list)

        for batch_idx, atomwise_mask in enumerate(atomwise_masks):
            token_seq = token_lists[batch_idx]
            last_idx = len(token_seq) - 1
            atom_idx = 0  # character offset of the current token in the SMILES

            for token_idx, token in enumerate(token_seq):
                # NOTE(review): the offset advances only for non-boundary
                # tokens, on the assumption that the first and last tokens are
                # special tokens absent from the SMILES string -- confirm
                # against the tokenizer.
                if token_idx == 0 or token_idx == last_idx:
                    continue
                token_end = atom_idx + len(token)
                if torch.sum(atomwise_mask[atom_idx:token_end]) >= 1:
                    tokenized_masks[batch_idx][token_idx] = 1
                atom_idx = token_end

        return tokenized_masks

    def collate_fn(self, batch):
        """Attach a token-level bond mask to the single item of ``batch``.

        Only ``batch[0]`` is read (the loaders use ``batch_size=1``). Its
        ``labels`` entry is passed to ``peptide_token_mask`` as the list of
        SMILES strings.

        Returns:
            dict: ``'input_ids'`` and ``'attention_mask'`` taken from the item,
            plus the computed ``'bond_mask'``.
        """
        item = batch[0]

        token_array = self.tokenizer.get_token_split(item['input_ids'])
        bond_mask = self.peptide_token_mask(item['labels'], token_array)

        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'bond_mask': bond_mask
        }

    def train_dataloader(self):
        """Build a shuffled loader over the ``'train'`` split."""
        train_dataset = DynamicBatchingDataset(self.dataset['train'], tokenizer=self.tokenizer)
        return DataLoader(
            train_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,  # Use the instance method
            shuffle=True,
            num_workers=12,
            pin_memory=True
        )

    def val_dataloader(self):
        """Build an unshuffled loader over the ``'val'`` split."""
        val_dataset = DynamicBatchingDataset(self.dataset['val'], tokenizer=self.tokenizer)
        return DataLoader(
            val_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,  # Use the instance method
            num_workers=8,
            pin_memory=True
        )
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def setup_data_and_update_config(config):
    """Build the peptide data module and fill tokenizer fields into ``config``.

    Sets ``config.interpolant.tokens``, ``config.interpolant.pad_token`` and
    ``config.interpolant.mask_token`` from the SMILES Pair Encoding tokenizer,
    then loads the preprocessed dataset named by ``config.training.data_path``
    (default ``data/11M_peptide_smiles``).

    Relative paths are resolved against the package directory (the parent of
    the directory holding this file), which is where ``pep_scoring/`` lives
    and what the config documents for ``training.data_path``. If a path does
    not exist there but does exist relative to this file's own directory (the
    previous behaviour), that location is used instead.

    Args:
        config: Config object with ``interpolant`` and ``training`` sections.

    Returns:
        CustomDataModule: Data module wrapping the dataset and tokenizer.

    Raises:
        FileNotFoundError: If the dataset directory does not exist.
    """
    package_dir = os.path.dirname(_THIS_DIR)

    def _resolve(*parts):
        """Join ``parts`` onto the package dir, falling back to this file's dir."""
        preferred = os.path.join(package_dir, *parts)
        legacy = os.path.join(_THIS_DIR, *parts)
        if not os.path.exists(preferred) and os.path.exists(legacy):
            return legacy
        return preferred

    # SMILES Pair Encoding tokenizer ships with the repo under pep_scoring/tokenizer/.
    tokenizer = SMILES_SPE_Tokenizer(
        _resolve('pep_scoring', 'tokenizer', 'new_vocab.txt'),
        _resolve('pep_scoring', 'tokenizer', 'new_splits.txt'),
    )

    config.interpolant.tokens = len(tokenizer)
    config.interpolant.pad_token = tokenizer.pad_token_id
    config.interpolant.mask_token = tokenizer.mask_token_id

    # Path to the preprocessed (arrow) pretraining dataset saved via `save_to_disk`.
    # Download instructions are in the README; override with `training.data_path` in the config.
    data_path = config.training.get('data_path', os.path.join('data', '11M_peptide_smiles'))
    if not os.path.isabs(data_path):
        data_path = _resolve(data_path)
    if not os.path.exists(data_path):
        raise FileNotFoundError(
            f"Pretraining dataset not found at '{data_path}'. Download it (see a2d2_pep/README.md, "
            "'Pretraining the Any-Length Model') and set `training.data_path` in config_pep.yaml."
        )
    data_module = CustomDataModule(data_path, tokenizer)

    return data_module
|
a2d2_pep/data/dataset.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import re
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
import utils
|
| 6 |
+
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
import pytorch_lightning as pl
|
| 9 |
+
from functools import partial
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
class CustomDataset(Dataset):
    """Subset view of ``dataset`` restricted to the given ``indices``."""

    def __init__(self, dataset, indices):
        """Keep references to the backing dataset and the selected indices."""
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        """Return how many indices were selected."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the backing-dataset item at position ``indices[idx]``."""
        # Cast to a plain int before indexing the backing dataset.
        return self.dataset[int(self.indices[idx])]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# for weighting losses of peptide bonds
|
| 27 |
+
def peptide_bond_mask(smiles_list):
    """Mark the characters of each SMILES string that belong to a bond motif.

    Patterns are tried in priority order (ester, N-methyl, proline peptide,
    regular peptide, then the reversed C(=O)N forms); a match is ignored if it
    overlaps any character already claimed by an earlier match.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).

    Returns:
        torch.Tensor: Integer mask of shape ``(batch_size, longest SMILES)``
        with 1 at bond character positions and 0 elsewhere. An empty batch
        yields a ``(0, 0)`` tensor.
    """
    batch_size = len(smiles_list)
    # default=0 keeps an empty batch from raising inside max().
    max_seq_length = max((len(smiles) for smiles in smiles_list), default=0)
    mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)

    bond_patterns = [
        (r'OC\(=O\)', 'ester'),
        (r'N\(C\)C\(=O\)', 'n_methyl'),
        (r'N[12]C\(=O\)', 'peptide'),  # Pro peptide bonds
        (r'NC\(=O\)', 'peptide'),  # Regular peptide bonds
        (r'C\(=O\)N\(C\)', 'n_methyl'),
        (r'C\(=O\)N[12]?', 'peptide')
    ]

    for batch_idx, smiles in enumerate(smiles_list):
        used = set()  # character indices already claimed by a match

        for pattern, _bond_type in bond_patterns:
            for match in re.finditer(pattern, smiles):
                span = range(match.start(), match.end())
                # Accept the match only if none of its characters is taken.
                if used.isdisjoint(span):
                    mask[batch_idx, match.start():match.end()] = 1
                    used.update(span)

    return mask
|
| 73 |
+
|
| 74 |
+
def peptide_token_mask(smiles_list, token_lists):
    """Project the character-level bond mask onto token positions.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).
        token_lists: List of tokenized SMILES strings (split into tokens).

    Returns:
        torch.Tensor: Integer mask of shape ``(batch_size, longest token list)``
        with 1 for every token that overlaps at least one bond character and 0
        elsewhere. The first and last token of each sequence are never marked.
        An empty batch yields a ``(0, 0)`` tensor.
    """
    batch_size = len(smiles_list)
    if batch_size == 0:
        # Nothing to mask; avoids calling max() on an empty sequence.
        return torch.zeros((0, 0), dtype=torch.int)

    token_seq_length = max(len(tokens) for tokens in token_lists)
    tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)
    atomwise_masks = peptide_bond_mask(smiles_list)

    for batch_idx, atomwise_mask in enumerate(atomwise_masks):
        token_seq = token_lists[batch_idx]
        last_idx = len(token_seq) - 1
        atom_idx = 0  # character offset of the current token in the SMILES

        for token_idx, token in enumerate(token_seq):
            # NOTE(review): the offset advances only for non-boundary tokens,
            # on the assumption that the first and last tokens are special
            # tokens absent from the SMILES string -- confirm against the
            # tokenizer.
            if token_idx == 0 or token_idx == last_idx:
                continue
            token_end = atom_idx + len(token)
            if torch.sum(atomwise_mask[atom_idx:token_end]) >= 1:
                tokenized_masks[batch_idx][token_idx] = 1
            atom_idx = token_end

    return tokenized_masks
|
| 104 |
+
|
| 105 |
+
def extract_amino_acid_sequence(helm_string):
    """
    Pull the monomer names out of every ``PEPTIDE<n>{...}`` section of a HELM string.

    Args:
        helm_string (str): The HELM notation string for a peptide.

    Returns:
        list: The monomers of all peptide sections, in order, with any square
        brackets removed. If no peptide section is found, an error message
        string is returned instead of a list.
    """
    # Capture the dot-separated monomer list inside each PEPTIDE<n>{...} group.
    chains = re.findall(r'PEPTIDE\d+\{([^}]+)\}', helm_string)

    if not chains:
        return "Invalid HELM notation or no peptide sequence found."

    return [
        monomer
        for chain in chains
        for monomer in chain.replace('[', '').replace(']', '').split('.')
    ]
|
| 128 |
+
|
| 129 |
+
def helm_collate_fn(batch, tokenizer):
    """Tokenize the ``'HELM'`` strings of a batch into padded tensors.

    Args:
        batch: Iterable of items, each providing a ``'HELM'`` string.
        tokenizer: Callable invoked as
            ``tokenizer(sequences, return_tensors='pt', padding=True,
            truncation=True, max_length=1024)``; its result must provide
            ``'input_ids'`` and ``'attention_mask'``.

    Returns:
        dict: ``'input_ids'`` and ``'attention_mask'`` from the tokenizer output.
    """
    sequences = [item['HELM'] for item in batch]

    # The tokenizer pads and truncates on its own, so no per-sequence length
    # scan is needed here.
    tokens = tokenizer(sequences, return_tensors='pt', padding=True, truncation=True, max_length=1024)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask']
    }
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def collate_fn(batch, tokenizer):
    """Tokenize the ``'SMILES'`` strings of a batch and attach a bond mask.

    Each item is first tokenized on its own; any item whose tokenization
    raises is reported on stdout and left out of the batch (best-effort
    filtering). The remaining sequences are tokenized together with padding
    and truncation to ``max_length=1035``.

    Args:
        batch: Iterable of items, each providing a ``'SMILES'`` string.
        tokenizer: Callable tokenizer that also provides ``get_token_split``.

    Returns:
        dict: ``'input_ids'`` and ``'attention_mask'`` from the tokenizer, plus
        ``'bond_mask'`` computed by ``peptide_token_mask``.
    """
    valid_sequences = []

    for item in batch:
        try:
            # Trial run only: the result is discarded, we just need to know
            # whether the tokenizer accepts this sequence.
            tokenizer([item['SMILES']], return_tensors='pt', padding=False, truncation=True, max_length=1035)
        except Exception as e:
            print(f"Skipping sequence due to: {str(e)}")
            continue
        valid_sequences.append(item['SMILES'])

    tokens = tokenizer(valid_sequences, return_tensors='pt', padding=True, truncation=True, max_length=1035)

    token_array = tokenizer.get_token_split(tokens['input_ids'])
    bond_mask = peptide_token_mask(valid_sequences, token_array)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask'],
        'bond_mask': bond_mask
    }
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class CustomDataModule(pl.LightningDataModule):
    """Lightning data module exposing train and validation dataloaders.

    Both loaders share the same batch size, worker settings and collate
    function; the collate function is bound to ``tokenizer`` via ``partial``.
    No ``shuffle`` argument is passed to either loader.

    Args:
        train_dataset: Dataset served by ``train_dataloader``.
        val_dataset: Dataset served by ``val_dataloader``.
        test_dataset: Accepted for interface compatibility; currently not
            stored and no test dataloader is defined.
        tokenizer: Tokenizer forwarded to ``collate_fn`` as ``tokenizer=``.
        batch_size: Batch size for both loaders.
        collate_fn: Callable ``collate_fn(batch, tokenizer=...)``.
        num_workers: Worker processes per loader (default 8, the previously
            hard-coded value).
        pin_memory: Whether loaders pin memory (default True, the previously
            hard-coded value).
    """

    def __init__(self, train_dataset, val_dataset, test_dataset, tokenizer, batch_size,
                 collate_fn=collate_fn, num_workers=8, pin_memory=True):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        # test_dataset is intentionally not kept: test loading is disabled.
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def _make_loader(self, dataset):
        """Build a DataLoader for ``dataset`` with the module's shared settings."""
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=self.num_workers,
                          pin_memory=self.pin_memory
                          )

    def train_dataloader(self):
        """Return the dataloader over the training dataset."""
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        """Return the dataloader over the validation dataset."""
        return self._make_loader(self.val_dataset)
|
a2d2_pep/evaluate_peptide_table.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluate a finetuned peptide model checkpoint by sampling sequences
|
| 3 |
+
and computing metrics for the De Novo Peptide Generation table:
|
| 4 |
+
Validity (%), Affinity (↑), Solubility (↑), Hemolysis (↑),
|
| 5 |
+
Nonfouling (↑), Permeability (↑), Sampling Time (↓)
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import argparse
|
| 11 |
+
import time
|
| 12 |
+
import torch
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
|
| 16 |
+
# add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
|
| 17 |
+
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 18 |
+
sys.path.insert(0, REPO_ROOT)
|
| 19 |
+
|
| 20 |
+
from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
|
| 21 |
+
from lightning_modules import AnyOrderInsertionFlowModule
|
| 22 |
+
from inference_quality import sample_peptides_eval
|
| 23 |
+
from pep_scoring.scoring_functions import ScoringFunctions
|
| 24 |
+
from pep_utils.analyzer import PeptideAnalyzer
|
| 25 |
+
from pep_scoring.tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 26 |
+
from finetune_quality import PeptideFinetuner
|
| 27 |
+
from pep_utils.utils import str2bool, set_seed
|
| 28 |
+
from tdc import Evaluator
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Protein sequences
|
| 32 |
+
PROTEINS = {
|
| 33 |
+
'amhr': 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV',
|
| 34 |
+
'tfr': 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF',
|
| 35 |
+
'gfap': 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM',
|
| 36 |
+
'glp1': 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS',
|
| 37 |
+
'glast': 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM',
|
| 38 |
+
'ncam': 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF',
|
| 39 |
+
'cereblon': 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL',
|
| 40 |
+
'ligase': 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS',
|
| 41 |
+
'skp2': 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL',
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def load_finetuned_model(checkpoint_path, pretrained_ckpt_path, device='cuda'):
    """Load a finetuned policy model from a Lightning checkpoint.

    The model config is read from the pretrained base checkpoint, the
    schedule/planner settings are filled in from the finetuning ``args``
    stored in the finetuned checkpoint (with defaults when absent), and the
    ``policy_model.*`` weights of the finetuned checkpoint are loaded
    non-strictly on top of the freshly built module.

    Args:
        checkpoint_path: Path to the finetuned Lightning checkpoint.
        pretrained_ckpt_path: Path to the pretrained base checkpoint; also
            passed to the module as ``pretrained_checkpoint``.
        device: Device the model is moved to.

    Returns:
        tuple: ``(policy_model, args, config)`` where ``policy_model`` is in
        eval mode, ``args`` is the stored finetuning args object (or ``None``)
        and ``config`` is the OmegaConf config used to build the model.

    Raises:
        ValueError: If the base checkpoint contains no config.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    hparams = ckpt.get('hyper_parameters', {})
    args = hparams.get('args', None)

    # Load pretrained base checkpoint to get config
    base_ckpt = torch.load(pretrained_ckpt_path, map_location='cpu', weights_only=False)
    if 'hyper_parameters' in base_ckpt:
        config = base_ckpt['hyper_parameters']['config']
    elif 'config' in base_ckpt:
        config = base_ckpt['config']
    else:
        raise ValueError("Cannot find config in base checkpoint")

    from omegaconf import OmegaConf, DictConfig
    if not OmegaConf.is_config(config):
        config = DictConfig(config)

    # Finetuning-only training options; getattr also copes with args=None.
    training_defaults = {
        'use_adaptive_schedule': True,
        'schedule_hidden_dim': 256,
        'schedule_num_layers': 2,
        'schedule_loss_weight': 0.1,
        'freeze_base_model': False,
        'schedule_warmup_epochs': 0,
    }
    # Struct mode is lifted temporarily so new keys can be added.
    OmegaConf.set_struct(config, False)
    for key, default in training_defaults.items():
        setattr(config.training, key, getattr(args, key, default))
    OmegaConf.set_struct(config, True)

    disable_planner = getattr(args, 'disable_planner', False)

    policy_model = AnyOrderInsertionFlowModuleFT(
        config=config,
        args=args,
        pretrained_checkpoint=pretrained_ckpt_path,
        insertion_planner=not disable_planner,
    )

    # Load finetuned weights: keep only the policy sub-module's entries and
    # strip the prefix they were saved under.
    prefix = 'policy_model.'
    policy_state = {
        key[len(prefix):]: value
        for key, value in ckpt['state_dict'].items()
        if key.startswith(prefix)
    }
    if not policy_state:
        # strict=False would otherwise hide this: nothing would be loaded and
        # the un-finetuned model would be evaluated silently.
        print(f"WARNING: no '{prefix}*' weights found in {checkpoint_path}; "
              "the returned model carries no finetuned parameters.")
    policy_model.load_state_dict(policy_state, strict=False)
    policy_model = policy_model.to(device)
    policy_model.eval()

    return policy_model, args, config
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _score_list(values):
    """Return ``values`` itself if it is a list, otherwise ``values.tolist()``."""
    return values if isinstance(values, list) else values.tolist()


def _score_mean_std(values):
    """Return ``(mean, std)`` of ``values``, or ``(0.0, 0.0)`` when empty."""
    if not values:
        return 0.0, 0.0
    return np.mean(values), np.std(values)


@torch.no_grad()
def evaluate_checkpoint(policy_model, tokenizer, reward_model, analyzer,
                        num_samples=1000, batch_size=50, max_length=512,
                        total_num_steps=256, quality_mode="both", num_remasking=3,
                        quality_threshold=0.5, unmask_quality_threshold=None, device='cuda'):
    """Sample ``num_samples`` peptides in batches and compute all table metrics.

    Args:
        policy_model: Model to sample from; its ``interpolant.mask_token`` and
            ``interpolant.pad_token`` are forwarded to the sampler.
        tokenizer, reward_model, analyzer: Forwarded to ``sample_peptides_eval``.
        num_samples: Total number of peptides to generate.
        batch_size: Maximum number generated per sampler call (must be > 0).
        max_length, total_num_steps, quality_mode, num_remasking,
        quality_threshold, unmask_quality_threshold: Forwarded to the sampler.
        device: Unused here; kept for interface compatibility.

    Returns:
        dict: Keys ``'Validity (%)'``, ``'Uniqueness (%)'``, ``'Diversity'``,
        then a mean and ``'<name> Std'`` entry for each of ``'Affinity'``,
        ``'Solubility'``, ``'Hemolysis'``, ``'Nonfouling'``, ``'Permeability'``,
        followed by ``'Sampling Time (s)'``, ``'Num Generated'``,
        ``'Num Valid'`` and ``'Num Unique'``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    all_affinity = []
    all_sol = []
    all_hemo = []
    all_nf = []
    all_permeability = []
    all_valid_seqs = []
    total_valid = 0
    total_generated = 0
    total_time = 0.0

    num_batches = (num_samples + batch_size - 1) // batch_size
    remaining = num_samples

    for b in range(num_batches):
        bs = min(batch_size, remaining)
        remaining -= bs

        # perf_counter is monotonic, so the interval cannot be distorted by
        # wall-clock adjustments.
        t_start = time.perf_counter()
        result = sample_peptides_eval(
            model=policy_model,
            reward_model=reward_model,
            analyzer=analyzer,
            tokenizer=tokenizer,
            steps=total_num_steps,
            mask=policy_model.interpolant.mask_token,
            pad=policy_model.interpolant.pad_token,
            batch_size=bs,
            max_length=max_length,
            quality_mode=quality_mode,
            num_remasking=num_remasking,
            quality_threshold=quality_threshold,
            unmask_quality_threshold=unmask_quality_threshold,
            return_valid=True,
        )
        elapsed = time.perf_counter() - t_start

        # Unpack: validSequences, affinity, sol, hemo, nf, permeability, valid_fraction
        # (validity is recomputed below from counts, so the fraction is unused).
        valid_seqs, affinity, sol, hemo, nf, permeability, _valid_fraction = result

        batch_valid = len(valid_seqs)
        total_valid += batch_valid
        total_generated += bs
        total_time += elapsed
        all_valid_seqs.extend(valid_seqs)

        # Scores are only collected when the affinity result is a non-empty
        # list/array; the other score containers are gated on that same check.
        if isinstance(affinity, (list, np.ndarray)) and len(affinity) > 0:
            all_affinity.extend(_score_list(affinity))
            all_sol.extend(_score_list(sol))
            all_hemo.extend(_score_list(hemo))
            all_nf.extend(_score_list(nf))
            all_permeability.extend(_score_list(permeability))

        print(f" Batch {b+1}/{num_batches}: {batch_valid}/{bs} valid, "
              f"time={elapsed:.1f}s")

    validity = total_valid / total_generated * 100.0 if total_generated > 0 else 0.0

    # Uniqueness (% of valid sequences that are unique) and
    # Diversity (1 - mean pairwise Tanimoto on Morgan FPs of unique sequences).
    # Matches the convention used in evaluate_mol_table.py.
    all_unique = list(set(all_valid_seqs))
    num_unique = len(all_unique)
    uniqueness = num_unique / total_valid * 100.0 if total_valid > 0 else 0.0
    if num_unique > 1:
        diversity = Evaluator('diversity')(all_unique)
    else:
        diversity = 0.0

    metrics = {
        'Validity (%)': validity,
        'Uniqueness (%)': uniqueness,
        'Diversity': diversity,
    }
    score_columns = (
        ('Affinity', all_affinity),
        ('Solubility', all_sol),
        ('Hemolysis', all_hemo),
        ('Nonfouling', all_nf),
        ('Permeability', all_permeability),
    )
    for label, values in score_columns:
        mean, std = _score_mean_std(values)
        metrics[label] = mean
        metrics[f'{label} Std'] = std
    metrics['Sampling Time (s)'] = total_time
    metrics['Num Generated'] = total_generated
    metrics['Num Valid'] = total_valid
    metrics['Num Unique'] = num_unique

    return metrics
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _resolve_quality_mode(args):
    """Map the planner-ablation CLI flags to the sampler's quality_mode string."""
    if args.disable_planner:
        return "none"
    if args.disable_insertion_planner and args.disable_unmasking_planner:
        return "none"
    if args.disable_insertion_planner:
        return "unmasking_only"
    if args.disable_unmasking_planner:
        return "insertion_only"
    return "both"


def _result_tag(quality_mode, unmask_quality_threshold):
    """Build the output-file tag from the resolved quality mode and threshold."""
    # Derived from quality_mode (not the raw flags) so that runs which sample
    # identically share a tag and runs that differ never collide.
    tag = {
        "none": "no_planner",
        "unmasking_only": "no_insertion_planner",
        "insertion_only": "no_unmasking_planner",
    }.get(quality_mode, "with_planner")
    if unmask_quality_threshold is not None:
        tag += f"_ut{unmask_quality_threshold:g}"
    return tag


def main():
    """CLI entry point: load a finetuned checkpoint, sample, print and save metrics.

    Parses the command line, builds the model, tokenizer, reward model and
    analyzer, runs ``evaluate_checkpoint`` and writes a one-row CSV named
    ``eval_metrics_<tag>_<prot_name>.csv`` into ``--output_dir`` (or the
    checkpoint's directory, or the current directory if the checkpoint path
    has no directory part).

    Raises:
        ValueError: If ``--prot_name`` is unknown and no ``--prot_seq`` is given.
    """
    parser = argparse.ArgumentParser(description="Evaluate a finetuned peptide checkpoint")
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to the finetuned Lightning checkpoint (e.g., last.ckpt)')
    parser.add_argument('--pretrained_ckpt', type=str,
                        default=os.path.join(REPO_ROOT, 'pretrained', 'anylength_pep.ckpt'),
                        help='Path to the pretrained base model checkpoint')
    parser.add_argument('--num_samples', type=int, default=500,
                        help='Number of peptides to sample')
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Batch size for sampling')
    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--total_num_steps', type=int, default=256)
    parser.add_argument('--num_remasking', type=int, default=3)
    parser.add_argument('--quality_threshold', type=float, default=0.5,
                        help='Threshold for insertion quality filtering during sampling')
    parser.add_argument('--unmask_quality_threshold', type=float, default=None,
                        help='If set, gate unmasking/remasking by confidence: remask '
                             'ALL clean tokens whose unmasking confidence is below this '
                             'threshold, regardless of the schedule budget. If unset '
                             '(default), remasking is purely schedule-driven (count-based).')
    parser.add_argument('--prot_name', type=str, default='glast',
                        help='Target protein name (must be one of: ' + ', '.join(PROTEINS.keys()) + ')')
    parser.add_argument('--prot_seq', type=str, default=None,
                        help='Custom protein sequence (overrides --prot_name)')
    parser.add_argument('--disable_planner', action='store_true',
                        help='If set, disable remasking during evaluation')
    parser.add_argument('--disable_insertion_planner', action='store_true',
                        help='If set, disable insertion quality filtering during evaluation')
    parser.add_argument('--disable_unmasking_planner', action='store_true',
                        help='If set, disable unmasking confidence planner during evaluation')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Directory to save results CSV. Defaults to checkpoint directory.')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    set_seed(args.seed, use_cuda=True)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # Map flags to quality_mode
    quality_mode = _resolve_quality_mode(args)

    print(f"Loading checkpoint: {args.checkpoint_path}")
    print(f"Pretrained base: {args.pretrained_ckpt}")
    print(f"Quality mode: {quality_mode}")

    # Only the model is needed here; the stored training args/config are unused.
    policy_model, _train_args, _config = load_finetuned_model(
        args.checkpoint_path, args.pretrained_ckpt, device=device
    )

    # Setup tokenizer, reward model, analyzer
    tokenizer = SMILES_SPE_Tokenizer(
        os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer', 'new_vocab.txt'),
        os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer', 'new_splits.txt')
    )

    # prot_name always labels the run; with --prot_seq it is a label only and
    # is not validated against PROTEINS.
    prot_name = args.prot_name
    if args.prot_seq is not None:
        prot = args.prot_seq
    else:
        if prot_name not in PROTEINS:
            raise ValueError(f"Unknown protein: {prot_name}. Choose from: {list(PROTEINS.keys())}")
        prot = PROTEINS[prot_name]

    score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
    reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device=device)
    analyzer = PeptideAnalyzer()

    print(f"\nSampling {args.num_samples} peptides (quality_mode={quality_mode}, target={prot_name})...")

    metrics = evaluate_checkpoint(
        policy_model=policy_model,
        tokenizer=tokenizer,
        reward_model=reward_model,
        analyzer=analyzer,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        max_length=args.max_length,
        total_num_steps=args.total_num_steps,
        quality_mode=quality_mode,
        num_remasking=args.num_remasking,
        quality_threshold=args.quality_threshold,
        unmask_quality_threshold=args.unmask_quality_threshold,
        device=device,
    )

    # Print summary table
    print("\n" + "=" * 60)
    print(" De Novo Peptide Generation Results")
    print("=" * 60)
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f" {k:<30s}: {v:.4f}")
        else:
            print(f" {k:<30s}: {v}")
    print("=" * 60)

    # Save results. dirname() is '' for a bare filename such as "last.ckpt",
    # and os.makedirs('') raises, so fall back to the current directory.
    output_dir = args.output_dir or os.path.dirname(args.checkpoint_path) or '.'
    os.makedirs(output_dir, exist_ok=True)

    tag = _result_tag(quality_mode, args.unmask_quality_threshold)
    # Record the sweep parameter in the saved row for traceability.
    metrics['unmask_quality_threshold'] = args.unmask_quality_threshold
    metrics['quality_threshold'] = args.quality_threshold
    metrics_path = os.path.join(output_dir, f'eval_metrics_{tag}_{prot_name}.csv')
    pd.DataFrame([metrics]).to_csv(metrics_path, index=False)
    print(f"Metrics saved to: {metrics_path}")
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
if __name__ == '__main__':
|
| 326 |
+
main()
|
a2d2_pep/finetune_quality.py
ADDED
|
@@ -0,0 +1,892 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Distributed Data Parallel (DDP) finetuning for peptide generation using PyTorch Lightning
|
| 2 |
+
import argparse
|
| 3 |
+
import math
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import pytorch_lightning as pl
|
| 8 |
+
from pytorch_lightning.strategies import DDPStrategy
|
| 9 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
| 10 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 11 |
+
import wandb
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
# add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
|
| 18 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 19 |
+
|
| 20 |
+
from inference_quality import sample_peptides_buffer, sample_peptides_eval
|
| 21 |
+
from pep_utils.analyzer import PeptideAnalyzer
|
| 22 |
+
from pep_utils.utils import str2bool, set_seed
|
| 23 |
+
from pep_scoring.scoring_functions import ScoringFunctions
|
| 24 |
+
from pep_scoring.tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 25 |
+
from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
|
| 26 |
+
from lightning_modules import AnyOrderInsertionFlowModule
|
| 27 |
+
from tdc import Evaluator
|
| 28 |
+
|
| 29 |
+
# Repository root (two levels up from this file: A2D2/a2d2_pep/finetune_quality.py)
|
| 30 |
+
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 31 |
+
|
| 32 |
+
class PeptideFinetuner(pl.LightningModule):
|
| 33 |
+
"""Lightning module for distributed peptide finetuning."""
|
| 34 |
+
|
| 35 |
+
def __init__(
    self,
    args,
    policy_model,
    reward_model,
    tokenizer,
    pretrained=None,
    mcts=None,
    filename=None,
    prot_name=None,
    eps=1e-5
):
    """Store the collaborators, build the analyzer and initialise training state.

    Args:
        args: Run configuration; ``min_peptide_bonds`` (default 4) and
            ``alternation_frequency`` (default 1) are read from it here.
        policy_model: Model being finetuned.
        reward_model: Scoring model.
        tokenizer: Tokenizer.
        pretrained: Optional pretrained reference model.
        mcts: Optional search object.
        filename: Optional file name kept on the module.
        prot_name: Optional target protein name.
        eps: Small constant kept on the module.
    """
    super().__init__()

    # Collaborators and plain settings, stored as given.
    self.args = args
    self.policy_model = policy_model
    self.reward_model = reward_model
    self.tokenizer = tokenizer
    self.pretrained = pretrained
    self.mcts = mcts
    self.filename = filename
    self.prot_name = prot_name
    self.eps = eps

    # Peptide-bond length cutoff, tunable from the CLI:
    #   --min_peptide_bonds N  requires >=N peptide bonds (filters out
    #                          degenerate short reward-hacked molecules);
    #   --min_peptide_bonds 0  switches the cutoff off.
    bond_cutoff = getattr(args, 'min_peptide_bonds', 4)
    cutoff_enabled = bond_cutoff > 0
    self.analyzer = PeptideAnalyzer(
        min_peptide_bonds=max(0, bond_cutoff),
        enforce_min_peptide_bonds=cutoff_enabled,
    )

    # Record hyperparameters, leaving out the heavyweight objects.
    self.save_hyperparameters(ignore=['policy_model', 'reward_model', 'tokenizer', 'pretrained', 'mcts'])

    # Sequence buffer, empty until first filled.
    for buffer_name in ('x_saved', 'log_rnd_saved', 'final_rewards_saved'):
        setattr(self, buffer_name, None)

    # One history list per tracked metric.
    for log_name in (
        'valid_fraction_log',
        'uniqueness_log',
        'diversity_log',
        'affinity_log',
        'sol_log',
        'hemo_log',
        'nf_log',
        'permeability_log',
    ):
        setattr(self, log_name, [])
    self._diversity_evaluator = Evaluator('diversity')

    # Policy/planner alternation: begin on the policy and switch every
    # `alternation_frequency` epochs.
    self.train_policy = True
    self.alternation_frequency = getattr(args, 'alternation_frequency', 1)
|
| 89 |
+
|
| 90 |
+
def freeze_policy_model(self):
    """Turn off gradients for every policy parameter outside ``planner.``."""
    non_planner = (
        tensor
        for param_name, tensor in self.policy_model.named_parameters()
        if not param_name.startswith('planner.')
    )
    for tensor in non_planner:
        tensor.requires_grad = False
|
| 95 |
+
|
| 96 |
+
def unfreeze_policy_model(self):
    """Turn gradients back on for every policy parameter outside ``planner.``."""
    non_planner = (
        tensor
        for param_name, tensor in self.policy_model.named_parameters()
        if not param_name.startswith('planner.')
    )
    for tensor in non_planner:
        tensor.requires_grad = True
|
| 101 |
+
|
| 102 |
+
def freeze_planner_model(self):
    """Turn off gradients for the planner; a no-op if the policy has none."""
    if not hasattr(self.policy_model, 'planner'):
        return
    for weight in self.policy_model.planner.parameters():
        weight.requires_grad = False
|
| 107 |
+
|
| 108 |
+
def unfreeze_planner_model(self):
    """Turn gradients back on for the planner; a no-op if the policy has none."""
    if not hasattr(self.policy_model, 'planner'):
        return
    for weight in self.policy_model.planner.parameters():
        weight.requires_grad = True
|
| 113 |
+
|
| 114 |
+
def configure_optimizers(self):
    """Build one AdamW optimizer with separate learning rates per group.

    Group 0 holds every policy parameter whose name does not start with
    ``planner.`` and uses ``args.learning_rate``. Group 1 holds the
    ``planner.`` parameters and uses ``args.planner_learning_rate``,
    falling back to ``args.learning_rate`` when that field is absent.

    Returns:
        The configured ``torch.optim.AdamW`` instance.
    """
    base_lr = self.args.learning_rate
    head_lr = getattr(self.args, 'planner_learning_rate', base_lr)

    named = list(self.policy_model.named_parameters())
    backbone_weights = [p for n, p in named if not n.startswith('planner.')]
    planner_weights = [p for n, p in named if n.startswith('planner.')]

    return torch.optim.AdamW([
        {'params': backbone_weights, 'lr': base_lr},
        {'params': planner_weights, 'lr': head_lr},
    ])
|
| 131 |
+
|
| 132 |
+
def _get_quality_mode(self):
|
| 133 |
+
"""Map ablation flags + warmup state to quality_mode string."""
|
| 134 |
+
if self.args.disable_planner:
|
| 135 |
+
return "none"
|
| 136 |
+
if self.current_epoch < self.args.schedule_warmup_epochs:
|
| 137 |
+
return "none"
|
| 138 |
+
di = getattr(self.args, 'disable_insertion_planner', False)
|
| 139 |
+
du = getattr(self.args, 'disable_unmasking_planner', False)
|
| 140 |
+
if di and du:
|
| 141 |
+
return "none"
|
| 142 |
+
if di:
|
| 143 |
+
return "unmasking_only"
|
| 144 |
+
if du:
|
| 145 |
+
return "insertion_only"
|
| 146 |
+
return "both"
|
| 147 |
+
|
| 148 |
+
def on_save_checkpoint(self, checkpoint):
    """Embed extra policy-model state directly in the checkpoint dict.

    Adds, when present on the policy model: its ``config`` (so loading
    does not have to follow references), its non-empty ``ema_params``,
    and the planner's ``state_dict()`` under ``planner_state``.

    Args:
        checkpoint: Checkpoint dictionary, mutated in place.
    """
    model = self.policy_model

    # Config stored inline for easier loading.
    if hasattr(model, 'config'):
        checkpoint['config'] = model.config
        print("Saved config to checkpoint for easier loading")

    # EMA weights, only when the attribute exists and is truthy.
    ema = getattr(model, 'ema_params', None)
    if ema:
        checkpoint['ema_params'] = ema
        print("Saved EMA params to checkpoint")

    # Planner weights.
    if hasattr(model, 'planner'):
        checkpoint['planner_state'] = model.planner.state_dict()
        print("Saved planner state to checkpoint")
|
| 167 |
+
|
| 168 |
+
def on_train_epoch_start(self):
    """Choose what trains this epoch and refresh the sample buffer if due.

    Three regimes, checked in this order:
      * ``args.disable_planner``: policy only, planner frozen.
      * ``args.joint_training``: policy and planner both trainable.
      * otherwise: alternate policy / planner every
        ``alternation_frequency`` epochs, starting with the policy.

    The buffer is regenerated when it is still empty or when the epoch
    index is a multiple of ``args.resample_every_n_step``.
    """
    if self.args.disable_planner:
        # Policy-only mode: no alternation at all.
        policy_on, planner_on = True, False
        banner = "[FINETUNE_QUALITY] Training ONLY policy model (planner frozen, no remasking)"
        every_epoch = False
    elif getattr(self.args, 'joint_training', False):
        # Joint mode: train_policy is just a marker here; training_step adds
        # the planner loss whenever joint_training is set.
        policy_on, planner_on = True, True
        banner = "[FINETUNE_QUALITY] JOINT TRAINING: policy + planner trained together (no alternation)"
        every_epoch = False
    else:
        # Alternation from epoch 0: even cycles train the policy, odd
        # cycles train the planner.
        phase = (self.current_epoch // self.alternation_frequency) % 2
        policy_on = (phase == 0)
        planner_on = not policy_on
        if policy_on:
            banner = f"[ALTERNATION] Epoch {self.current_epoch}: Training POLICY model (planner frozen)"
        else:
            banner = f"[ALTERNATION] Epoch {self.current_epoch}: Training PLANNER model (policy frozen)"
        every_epoch = True

    self.train_policy = policy_on
    if policy_on:
        self.unfreeze_policy_model()
    else:
        self.freeze_policy_model()
    if planner_on:
        self.unfreeze_planner_model()
    else:
        self.freeze_planner_model()

    # The one-off banners print at epoch 0 only; alternation banners print
    # every epoch. Rank 0 does the printing.
    if self.global_rank == 0 and (every_epoch or self.current_epoch == 0):
        print(banner)

    # Resample buffer if needed
    if self.x_saved is None or self.current_epoch % self.args.resample_every_n_step == 0:
        self._generate_buffer()
        # Line all ranks up after buffer generation to prevent NCCL timeout
        if self.trainer and self.trainer.world_size > 1:
            torch.distributed.barrier()
|
| 209 |
+
|
| 210 |
+
def _generate_buffer(self):
    """Fill, or partially refresh, the training buffer with policy samples.

    Every rank samples independently. With ``args.pool_size > 0`` a
    persistent pool is kept: the first call fills it and later calls
    overwrite a random subset sized by ``pool_refresh_fraction``, instead
    of regenerating everything. This keeps diversity/uniqueness up across
    training by not replacing the whole buffer with samples from an
    increasingly mode-collapsed policy. Without a pool, ``buffer_size`` is
    split across ranks (rank 0 takes the remainder) and the buffer is
    overwritten.

    Raises:
        RuntimeError: If no valid sequence was produced at all.
    """
    import time

    n_ranks = self.trainer.world_size if self.trainer else 1
    this_rank = self.global_rank if self.trainer else 0
    is_main = this_rank == 0

    pool_capacity = getattr(self.args, 'pool_size', 0)
    pool_mode = pool_capacity > 0
    first_fill = self.x_saved is None

    # How many sequences this rank has to produce in this call.
    if pool_mode:
        refresh_frac = getattr(self.args, 'pool_refresh_fraction', 0.2)
        quota = pool_capacity if first_fill else max(1, int(pool_capacity * refresh_frac))
        if is_main:
            if first_fill:
                print(f"\n[POOL] Initializing pool with {pool_capacity} sequences at epoch {self.current_epoch}")
            else:
                print(f"\n[POOL] Refreshing {quota}/{pool_capacity} sequences ({refresh_frac*100:.0f}%) at epoch {self.current_epoch}")
    else:
        quota, leftover = divmod(self.args.buffer_size, n_ranks)
        if is_main:
            quota += leftover

    def _draw_batch():
        """Run one sampler call; return (sequences, log_rnd, rewards)."""
        models = (self.policy_model, self.reward_model, self.analyzer, self.tokenizer)
        shared = dict(
            steps=self.args.total_num_steps,
            mask=self.policy_model.interpolant.mask_token,
            pad=self.policy_model.interpolant.pad_token,
            batch_size=self.args.batch_size,
            max_length=self.args.max_length,
            # Buffer generation never uses the quality heads (planner): the
            # backbone must train on raw policy samples so that a
            # poorly-trained planner can't corrupt the backbone's data.
            quality_mode="none",
            alpha=self.args.alpha,
            num_remasking=self.args.num_remasking,
            min_length=self.args.min_length,
        )
        if self.args.elbo_rnd:
            # ELBO variant: sample without RND weights, then derive them
            # from the loss gap between the pretrained and current models
            # on the same noised samples.
            seqs, _, rewards, _ = sample_peptides_buffer(*models, compute_rnd=False, **shared)
            if seqs.shape[0] > 0:
                with torch.no_grad():
                    noised = self.policy_model.prepare_noised_sample(
                        seqs, num_samples=self.args.elbo_rnd_num_samples)
                    current_loss = self.policy_model.compute_loss_from_noised(noised)
                    reference_loss = self.pretrained.compute_loss_from_noised(noised)
                    weights = (reference_loss - current_loss) + (rewards / self.args.alpha)
            else:
                weights = torch.empty((0,), dtype=torch.float32, device=seqs.device)
        else:
            seqs, weights, rewards, _ = sample_peptides_buffer(
                *models, compute_rnd=True, pretrained=self.pretrained, **shared)
        return seqs, weights, rewards

    x_chunks, rnd_chunks, reward_chunks = [], [], []
    collected = 0

    if is_main:
        print(f"\n[BUFFER] Starting buffer generation at epoch {self.current_epoch}")
        print(f"[BUFFER] GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"[BUFFER] GPU memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
        if not pool_mode:
            print(f"[BUFFER] Each of {n_ranks} ranks will generate {quota} samples")

    # Hard cap on sampler calls: bounds wasted GPU time / an infinite loop.
    max_attempts = getattr(self.args, 'max_buffer_attempts', 100)
    starvation_patience = getattr(self.args, 'buffer_starvation_patience', 10)
    attempts = 0

    while collected < quota and attempts < max_attempts:
        attempts += 1
        if is_main:
            print(f"[BUFFER] rank={this_rank} starting sampling attempt {attempts} at {time.strftime('%H:%M:%S')}")

        tick = time.time()
        seqs, weights, rewards = _draw_batch()
        elapsed = time.time() - tick
        if is_main:
            print(f"[BUFFER] rank={this_rank} sampling took {elapsed:.1f}s")

        n_valid = seqs.shape[0]
        if n_valid > 0:
            x_chunks.append(seqs)
            rnd_chunks.append(weights)
            reward_chunks.append(rewards)
            collected += n_valid

        if is_main:
            print(f"[BUFFER] rank={this_rank} epoch={self.current_epoch} quality_mode=none (heads disabled for buffer gen) accumulated={collected} / {quota} (batch yielded {n_valid} valid) attempt={attempts}")

        # Starvation guard: if nothing valid comes through (e.g. the length
        # cutoff is too aggressive for a collapsed policy), stop burning GPU
        # hours and fail fast with an actionable message.
        if attempts >= starvation_patience and collected == 0:
            if is_main:
                print(f"[BUFFER STARVATION] 0 valid samples after {attempts} attempts "
                      f"(min_peptide_bonds={getattr(self.args, 'min_peptide_bonds', 4)}). "
                      f"Aborting refill early — lower --min_peptide_bonds or check the policy.")
            break

    if collected == 0:
        raise RuntimeError(f"[BUFFER ERROR] Rank {this_rank}: No valid sequences generated after {attempts} attempts. Check sampling function and reward model.")

    if collected < quota:
        print(f"[BUFFER WARNING] Rank {this_rank}: Only generated {collected}/{quota} sequences after {attempts} attempts")

    # Concatenate the chunks and trim any overshoot from the last batch.
    fresh_x = torch.cat(x_chunks, dim=0)[:quota]
    fresh_rnd = torch.cat(rnd_chunks, dim=0)[:quota]
    fresh_rewards = torch.cat(reward_chunks, dim=0)[:quota]

    del x_chunks, rnd_chunks, reward_chunks
    torch.cuda.empty_cache()

    if pool_mode and not first_fill:
        # Pool refresh: overwrite a random subset of the existing pool.
        n_swap = min(fresh_x.shape[0], self.x_saved.shape[0])
        slots = torch.randperm(self.x_saved.shape[0], device=self.x_saved.device)[:n_swap]
        self.x_saved[slots] = fresh_x[:n_swap]
        self.log_rnd_saved[slots] = fresh_rnd[:n_swap]
        self.final_rewards_saved[slots] = fresh_rewards[:n_swap]
        if is_main:
            print(f"[POOL] Replaced {n_swap}/{self.x_saved.shape[0]} sequences, reward mean={self.final_rewards_saved.mean():.4f}")
    else:
        # Classic mode or first pool fill: replace the buffer outright.
        self.x_saved = fresh_x
        self.log_rnd_saved = fresh_rnd
        self.final_rewards_saved = fresh_rewards

    # Sanity check: length statistics (non-pad tokens) of the buffered peptides.
    if is_main:
        pad_id = self.policy_model.interpolant.pad_token
        lengths = (self.x_saved != pad_id).sum(dim=1)
        print(f"[BUFFER] peptide token length: median={lengths.median().item()} "
              f"min={lengths.min().item()} max={lengths.max().item()} "
              f"(n={lengths.shape[0]})")
|
| 376 |
+
|
| 377 |
+
def training_step(self, batch, batch_idx):
    """Run one optimisation step on a mini-batch drawn from the saved buffer.

    The dataloader ``batch`` is ignored. Up to
    ``args.training_mini_batch_size`` (default 6) rows are drawn at random
    from ``self.x_saved`` / ``self.log_rnd_saved`` to avoid OOM.

    Which losses are computed:
      * policy WDCE loss when ``self.train_policy`` is set;
      * planner loss when the planner is the training target, or in joint
        mode. Joint mode is ignored when ``args.disable_planner`` is set,
        matching ``on_train_epoch_start``, which gives ``disable_planner``
        precedence and freezes the planner.

    The ablation flags ``disable_insertion_planner`` and
    ``disable_unmasking_planner`` are read with a ``False`` fallback, the
    same way ``_get_quality_mode`` reads them.

    Args:
        batch: Unused dataloader output.
        batch_idx: Unused batch index.

    Returns:
        The loss to backpropagate: policy loss, planner loss, or their sum
        in joint mode.
    """
    # Mini-batch sampling from the buffer to avoid OOM.
    buffer_size = self.x_saved.shape[0]
    mini_batch_size = getattr(self.args, 'training_mini_batch_size', 6)
    if buffer_size > mini_batch_size:
        indices = torch.randperm(buffer_size, device=self.x_saved.device)[:mini_batch_size]
        x_final = self.x_saved[indices]
        log_rnd = self.log_rnd_saved[indices]
    else:
        # Buffer no larger than a mini-batch: use all of it.
        x_final = self.x_saved
        log_rnd = self.log_rnd_saved

    # disable_planner wins over joint_training: the planner is frozen in
    # that mode, so its loss must not be added to the objective.
    planner_disabled = getattr(self.args, 'disable_planner', False)
    joint = bool(getattr(self.args, 'joint_training', False)) and not planner_disabled

    loss_kwargs = dict(
        num_replicates=self.args.wdce_num_replicates,
        centering=self.args.centering,
        centering_strength=self.args.centering_strength,
    )
    log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

    policy_loss = None
    planner_loss = None

    if self.train_policy:
        # Policy objective: WDCE loss.
        policy_loss = self.policy_model.loss_wdce_flexible(log_rnd, x_final, **loss_kwargs)
        self.log('train/policy_loss', policy_loss, **log_kwargs)

    if (not self.train_policy) or joint:
        if getattr(self.args, 'disable_insertion_planner', False):
            # Ablation: only the unmasking/remasking planner (no insertion head).
            planner_loss = self.policy_model.loss_planner_flexible(log_rnd, x_final, **loss_kwargs)
            unmask_loss, insert_loss = planner_loss, 0.0
        elif getattr(self.args, 'disable_unmasking_planner', False):
            # Ablation: only the insertion planner; the unmasking component
            # is dropped so only the insertion loss is backpropagated.
            _, insert_loss, _ = self.policy_model.loss_insert_planner_flexible(
                log_rnd, x_final, **loss_kwargs)
            unmask_loss = 0.0
            planner_loss = insert_loss
        else:
            # Full planner: remasking + insertion.
            unmask_loss, insert_loss, planner_loss = self.policy_model.loss_insert_planner_flexible(
                log_rnd, x_final, **loss_kwargs)

        self.log('train/planner_unmask_loss', unmask_loss, **log_kwargs)
        self.log('train/planner_insert_loss', insert_loss, **log_kwargs)
        self.log('train/planner_loss', planner_loss, **log_kwargs)

    # Combine losses; mode_value encodes the regime for logging
    # (0.0 = policy, 1.0 = planner, 0.5 = joint).
    if joint:
        loss = policy_loss + planner_loss
        mode_value = 0.5
    elif self.train_policy:
        loss = policy_loss
        mode_value = 0.0
    else:
        loss = planner_loss
        mode_value = 1.0

    self.log('train_loss', loss, **log_kwargs)
    self.log('train/mode', mode_value, prog_bar=True, sync_dist=True)

    return loss
|
| 466 |
+
|
| 467 |
+
def on_train_epoch_end(self):
    """Periodically evaluate the updated policy; only rank 0 does the work.

    Evaluation runs when the epoch index is a multiple of
    ``args.eval_every_n_epochs`` (default 5) or on the trainer's last
    epoch. It samples a batch of 50, appends the metrics to the
    in-memory logs, and logs the means together with reward statistics
    of the current buffer.
    """
    # Evaluate only every N epochs to save time.
    eval_frequency = getattr(self.args, 'eval_every_n_epochs', 5)
    is_last_epoch = (self.trainer and self.current_epoch == self.trainer.max_epochs - 1)

    if self.global_rank != 0:
        return
    if not (self.current_epoch % eval_frequency == 0 or is_last_epoch):
        return

    # Draw an evaluation batch from the updated policy.
    interpolant = self.policy_model.interpolant
    valid_seqs, affinity, sol, hemo, nf, permeability, valid_fraction = sample_peptides_eval(
        self.policy_model, self.reward_model, self.analyzer,
        self.tokenizer,
        steps=self.args.total_num_steps,
        mask=interpolant.mask_token,
        pad=interpolant.pad_token,
        batch_size=50,
        max_length=self.args.max_length,
        quality_mode=self._get_quality_mode(),
        num_remasking=self.args.num_remasking,
        return_valid=True,
    )

    # Uniqueness (% of valid that are unique) and Diversity
    # (1 - mean pairwise Tanimoto on Morgan FPs of unique sequences),
    # matching evaluate_peptide_table.py / evaluate_mol_table.py.
    num_valid = len(valid_seqs)
    unique_seqs = list(set(valid_seqs))
    num_unique = len(unique_seqs)
    if num_valid > 0:
        uniqueness = num_unique / num_valid * 100.0
    else:
        uniqueness = 0.0
    if num_unique > 1:
        diversity = self._diversity_evaluator(unique_seqs)
    else:
        diversity = 0.0

    # Extend the histories that on_fit_end writes to CSV.
    history_updates = (
        (self.affinity_log, affinity),
        (self.sol_log, sol),
        (self.hemo_log, hemo),
        (self.nf_log, nf),
        (self.permeability_log, permeability),
        (self.valid_fraction_log, valid_fraction),
        (self.uniqueness_log, uniqueness),
        (self.diversity_log, diversity),
    )
    for history, value in history_updates:
        history.append(value)

    # Reward statistics of the current training buffer.
    buffer_rewards = self.final_rewards_saved
    mean_reward = buffer_rewards.mean().item()
    min_reward = buffer_rewards.min().item()
    max_reward = buffer_rewards.max().item()
    median_reward = buffer_rewards.median().item()

    # Per-property means, shared by the logger call and the console line.
    mean_affinity = np.mean(affinity)
    mean_sol = np.mean(sol)
    mean_hemo = np.mean(hemo)
    mean_nf = np.mean(nf)
    mean_permeability = np.mean(permeability)

    self.log_dict({
        "eval/affinity": mean_affinity,
        "eval/sol": mean_sol,
        "eval/hemo": mean_hemo,
        "eval/nf": mean_nf,
        "eval/permeability": mean_permeability,
        "eval/valid_fraction": valid_fraction,
        "eval/uniqueness": uniqueness,
        "eval/diversity": diversity,
        "eval/mean_reward_search": mean_reward,
        "eval/min_reward_search": min_reward,
        "eval/max_reward_search": max_reward,
        "eval/median_reward_search": median_reward
    })

    print(f"epoch {self.current_epoch} | affinity {mean_affinity:.4f} | "
          f"sol {mean_sol:.4f} | hemo {mean_hemo:.4f} | "
          f"nf {mean_nf:.4f} | permeability {mean_permeability:.4f} | "
          f"valid {valid_fraction:.4f} | uniq {uniqueness:.2f}% | div {diversity:.4f}")
|
| 533 |
+
|
| 534 |
+
def on_fit_end(self):
    """Write the metric history and a final generation table (rank 0 only).

    Output goes under ``{args.base_path}/results/{args.run_name}``: a
    ``log_{filename}.csv`` with the per-evaluation metrics, and a
    ``{prot_name}_generation_results.csv`` produced from one final
    sampling run of batch size 50.
    """
    if self.global_rank != 0:
        return

    # Output directory for logs and results.
    results_dir = f'{self.args.base_path}/results/{self.args.run_name}'
    os.makedirs(results_dir, exist_ok=True)

    save_logs_to_file(
        self.valid_fraction_log,
        self.affinity_log,
        self.sol_log,
        self.hemo_log,
        self.nf_log,
        self.permeability_log,
        f'{results_dir}/log_{self.filename}.csv',
        uniqueness_log=self.uniqueness_log,
        diversity_log=self.diversity_log,
    )

    # Final generation; only the returned dataframe is kept.
    interpolant = self.policy_model.interpolant
    _x_eval, _affinity, _sol, _hemo, _nf, _permeability, _valid_fraction, df = sample_peptides_eval(
        self.policy_model, self.reward_model, self.analyzer,
        self.tokenizer,
        steps=self.args.total_num_steps,
        mask=interpolant.mask_token,
        pad=interpolant.pad_token,
        batch_size=50,
        max_length=self.args.max_length,
        quality_mode=self._get_quality_mode(),
        num_remasking=self.args.num_remasking,
        dataframe=True,
    )
    df.to_csv(f'{results_dir}/{self.prot_name}_generation_results.csv', index=False)
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def save_logs_to_file(valid_fraction_log, affinity_log,
                      sol_log, hemo_log, nf_log,
                      permeability_log, output_path,
                      uniqueness_log=None, diversity_log=None):
    """Write the per-evaluation metric logs to a CSV file.

    One row is written per entry of ``valid_fraction_log``, with a 1-based
    ``Iteration`` column. The parent directory of ``output_path`` is
    created when missing.

    Args:
        valid_fraction_log: Valid-fraction value per evaluation.
        affinity_log: Binding-affinity entry per evaluation.
        sol_log: Solubility entry per evaluation.
        hemo_log: Hemolysis entry per evaluation.
        nf_log: Nonfouling entry per evaluation.
        permeability_log: Permeability entry per evaluation.
        output_path: Destination CSV path. A bare file name (no directory
            part) is written to the current working directory.
        uniqueness_log: Optional uniqueness percentages; adds the
            ``Uniqueness (%)`` column when given.
        diversity_log: Optional diversity values; adds the ``Diversity``
            column when given.
    """
    # os.path.dirname returns '' for a bare file name and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when one is named.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    log_data = {
        "Iteration": list(range(1, len(valid_fraction_log) + 1)),
        "Valid Fraction": valid_fraction_log,
        "Binding Affinity": affinity_log,
        "Solubility": sol_log,
        "Hemolysis": hemo_log,
        "Nonfouling": nf_log,
        "Permeability": permeability_log,
    }
    if uniqueness_log is not None:
        log_data["Uniqueness (%)"] = uniqueness_log
    if diversity_log is not None:
        log_data["Diversity"] = diversity_log

    pd.DataFrame(log_data).to_csv(output_path, index=False)
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
class DummyDataset(torch.utils.data.Dataset):
    """Placeholder dataset that only drives the trainer's step count.

    Its items are never meaningful training data: each one is a single
    zero, because the training step works from the saved buffer instead.
    """

    def __init__(self, size: int = 10) -> None:
        # Reported length, i.e. how many dummy items one epoch yields.
        self.size = size

    def __len__(self) -> int:
        """Return the configured number of dummy items."""
        return self.size

    def __getitem__(self, idx):
        """Return a one-element zero tensor whatever ``idx`` is."""
        return torch.zeros((1,))
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
def main():
    """Main entry point for distributed training.

    Parses the command-line options, resolves the target protein sequence and
    the run/checkpoint directory, loads the frozen pretrained reference model
    and the trainable policy model from the same checkpoint, builds the reward
    model and the ``PeptideFinetuner`` module, and runs ``pl.Trainer.fit``
    (optionally resuming from ``--resume_ckpt``).

    Raises:
        ValueError: If ``--prot_name`` is not a known protein and no
            ``--prot_seq`` is given, if ``--joint_training`` is combined with
            ``--disable_planner``, or if the checkpoint carries no config.
    """
    # Disable DDP optimizer for higher-order ops like flex_attention
    import torch._dynamo
    torch._dynamo.config.optimize_ddp = False

    argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    argparser.add_argument('--base_path', type=str, default=REPO_ROOT)
    argparser.add_argument('--learning_rate', type=float, default=1e-4)
    argparser.add_argument('--num_epochs', type=int, default=100)
    argparser.add_argument('--num_accum_steps', type=int, default=4)
    argparser.add_argument('--truncate_steps', type=int, default=50)
    argparser.add_argument("--truncate_kl", type=str2bool, default=False)
    argparser.add_argument('--gumbel_temp', type=float, default=1.0)
    argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
    argparser.add_argument('--batch_size', type=int, default=50)
    argparser.add_argument('--name', type=str, default='debug')
    argparser.add_argument('--total_num_steps', type=int, default=128)
    argparser.add_argument('--copy_flag_temp', type=float, default=None)
    argparser.add_argument('--save_every_n_epochs', type=int, default=10)
    argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
    argparser.add_argument("--seed", type=int, default=0)
    # new
    argparser.add_argument('--run_name', type=str, default='peptides')
    argparser.add_argument("--save_path_dir", default=os.path.join(REPO_ROOT, "checkpoints", "finetune_peptides"), type=str)
    # mcts
    argparser.add_argument('--num_sequences', type=int, default=10)
    argparser.add_argument('--max_length', type=int, default=1024)
    argparser.add_argument('--min_length', type=int, default=0,
                           help='Minimum sequence length (in SMILES SPE tokens). '
                                'Samples shorter than this are dropped from the buffer. 0 disables the filter.')
    argparser.add_argument('--num_children', type=int, default=50)
    argparser.add_argument('--num_iter', type=int, default=30)
    argparser.add_argument('--seq_length', type=int, default=1024)
    argparser.add_argument('--time_conditioning', action='store_true', default=False)
    argparser.add_argument('--mcts_sampling', type=int, default=0)  # for batched categorical sampling: '0' means gumbel noise
    argparser.add_argument('--buffer_size', type=int, default=100)
    argparser.add_argument('--wdce_num_replicates', type=int, default=16)
    argparser.add_argument('--noise_removal', action='store_true', default=False)
    argparser.add_argument('--grad_clip', action='store_true', default=False)
    argparser.add_argument('--resample_every_n_step', type=int, default=10)
    argparser.add_argument('--exploration', type=float, default=0.1)
    argparser.add_argument('--reset_every_n_step', type=int, default=100)
    argparser.add_argument('--alpha', type=float, default=0.01)
    argparser.add_argument('--scalarization', type=str, default='sum')
    argparser.add_argument('--no_mcts', action='store_true', default=False)
    argparser.add_argument("--centering", action='store_true', default=False)
    argparser.add_argument("--centering_strength", type=float, default=1.0)

    # ELBO-based log_rnd estimation
    argparser.add_argument('--elbo_rnd', action='store_true', default=False,
                           help='If set, compute log_rnd via ELBO instead of trajectory rollout')
    argparser.add_argument('--elbo_rnd_num_samples', type=int, default=16,
                           help='Number of noisy time samples per sequence for ELBO-based log_rnd estimation')

    # adaptive schedule parameters
    # A store_true flag whose default is already True can never be turned
    # off. BooleanOptionalAction keeps the bare `--use_adaptive_schedule`
    # working and adds `--no-use_adaptive_schedule` to disable it.
    argparser.add_argument('--use_adaptive_schedule', action=argparse.BooleanOptionalAction, default=True,
                           help='Enable the adaptive schedule; pass --no-use_adaptive_schedule to turn it off')
    argparser.add_argument('--schedule_hidden_dim', type=int, default=256)
    argparser.add_argument('--schedule_num_layers', type=int, default=2)
    argparser.add_argument('--schedule_loss_weight', type=float, default=0.1)
    argparser.add_argument('--adaptive_threshold', type=float, default=0.5)
    argparser.add_argument('--freeze_base_model', action='store_true', default=False)
    argparser.add_argument('--schedule_warmup_epochs', type=int, default=0, help='Number of initial epochs to train WITHOUT remasking in buffer generation')
    argparser.add_argument('--alternation_frequency', type=int, default=20, help='Number of epochs to train each model before alternating (1=alternate every epoch)')
    argparser.add_argument('--planner_learning_rate', type=float, default=None, help='Separate learning rate for planner heads (defaults to --learning_rate if not set)')

    # objectives
    argparser.add_argument('--num_obj', type=int, default=5)
    argparser.add_argument('--prot_seq', type=str, default=None)
    argparser.add_argument('--prot_name', type=str, default='glast',
                           help='Protein target name. Looked up in PROTEINS dict unless --prot_seq is given.')
    argparser.add_argument('--devices', type=int, default=-1)
    argparser.add_argument('--checkpoint_path', type=str, default=None)
    argparser.add_argument('--resume_ckpt', type=str, default=None,
                           help='Path to a Lightning last.ckpt to resume training from (restores epoch/optimizer/planner state). '
                                'New checkpoints continue in the same directory as this checkpoint.')

    # remasking
    argparser.add_argument('--num_remasking', type=int, default=5)
    argparser.add_argument('--quality_threshold', type=float, default=1)

    # length cutoff (peptide-bond filter) + buffer starvation guard
    argparser.add_argument('--min_peptide_bonds', type=int, default=4,
                           help='Minimum backbone peptide bonds for a sample to count as valid. '
                                '0 disables the cutoff. Filters degenerate short reward-hacked molecules.')
    argparser.add_argument('--max_buffer_attempts', type=int, default=100,
                           help='Max sampling rounds per buffer refill before giving up (caps wasted GPU when validity is low).')
    argparser.add_argument('--buffer_starvation_patience', type=int, default=10,
                           help='If 0 valid samples after this many rounds, abort the refill early (starvation guard).')

    # planner ablation flags
    argparser.add_argument('--disable_planner', action='store_true', help='If set, disable remasking completely and only train policy (not planner) for quality optimization')
    argparser.add_argument('--disable_insertion_planner', action='store_true', help='Ablation: disable insertion quality filtering but keep unmasking/remasking planner')
    argparser.add_argument('--disable_unmasking_planner', action='store_true', help='Ablation: disable unmasking/remasking planner but keep insertion quality filtering')
    argparser.add_argument('--joint_training', action='store_true', help='Ablation: train policy and planner jointly each step (no alternation, both unfrozen, summed loss). Incompatible with --disable_planner.')

    # performance optimization
    argparser.add_argument('--eval_every_n_epochs', type=int, default=5, help='Evaluate only every N epochs to save time')
    argparser.add_argument('--num_training_steps_per_epoch', type=int, default=10, help='Number of gradient updates per epoch')
    argparser.add_argument('--training_mini_batch_size', type=int, default=6, help='Mini-batch size for training from buffer to avoid OOM')
    argparser.add_argument('--pool_size', type=int, default=0,
                           help='If >0, maintain a persistent pool of this size and refresh a fraction each resample step (0=disabled, classic buffer). Helps preserve uniqueness/diversity over training.')
    argparser.add_argument('--pool_refresh_fraction', type=float, default=0.2,
                           help='Fraction of pool to replace each resample step (only used when pool_size>0)')

    args = argparser.parse_args()

    # Default planner LR to policy LR if not specified
    if args.planner_learning_rate is None:
        args.planner_learning_rate = args.learning_rate

    # Set seed
    pl.seed_everything(args.seed)

    # Load models
    checkpoint_path = args.checkpoint_path if args.checkpoint_path else \
        os.path.join(REPO_ROOT, 'pretrained', 'anylength_pep.ckpt')

    # Update args.checkpoint_path to ensure it's saved in hyperparameters for later inference
    args.checkpoint_path = checkpoint_path

    PROTEINS = {
        'amhr': 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV',
        'tfr': 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF',
        'gfap': 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM',
        'glp1': 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS',
        'glast': 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM',
        'ncam': 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF',
        'cereblon': 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL',
        'ligase': 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS',
        'skp2': 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL',
    }

    # An explicit --prot_seq wins; otherwise --prot_name must be a PROTEINS key.
    if args.prot_seq is not None:
        prot = args.prot_seq
        prot_name = args.prot_name
    else:
        prot_name = args.prot_name
        if prot_name not in PROTEINS:
            raise ValueError(f"Unknown protein: {prot_name}. Choose from: {list(PROTEINS.keys())}")
        prot = PROTEINS[prot_name]
    filename = prot_name

    curr_time = datetime.now().strftime("%Y%m%d_%H%M%S")

    if args.no_mcts:
        args.run_name = f'{curr_time}_adaptive_{prot_name}_resample{args.resample_every_n_step}_no-mcts'
    else:
        args.run_name = f'{curr_time}_adaptive_{prot_name}_resample{args.resample_every_n_step}_buffer{args.buffer_size}_numiter{args.num_iter}_children{args.num_children}'

    # Append ablation tags to run name for easy identification
    if args.disable_planner:
        args.run_name += '_no_planner'
    if args.disable_insertion_planner:
        args.run_name += '_no_insertion_planner'
    if args.disable_unmasking_planner:
        args.run_name += '_no_unmasking_planner'
    if args.joint_training:
        if args.disable_planner:
            raise ValueError("--joint_training is incompatible with --disable_planner (no planner to train)")
        args.run_name += '_joint_training'

    # When resuming, continue writing checkpoints into the SAME directory as the
    # checkpoint we resume from (keeps model-{epoch}.ckpt contiguous) instead of
    # spawning a fresh timestamped run directory.
    if args.resume_ckpt:
        args.save_path = os.path.dirname(os.path.abspath(args.resume_ckpt))
        args.run_name = os.path.basename(args.save_path)
    else:
        args.save_path = os.path.join(args.save_path_dir, args.run_name)
    os.makedirs(args.save_path, exist_ok=True)
    set_seed(args.seed, use_cuda=False)  # Don't init CUDA before Lightning spawns DDP workers

    # Initialize the model
    print("Loading models..")

    # Load pretrained model for reference (frozen)
    pretrained = AnyOrderInsertionFlowModule.load_from_checkpoint(checkpoint_path,
                                                                  map_location='cpu',
                                                                  weights_only=False)
    pretrained.eval()
    for param in pretrained.parameters():
        param.requires_grad = False

    # Load checkpoint to extract config
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    if 'hyper_parameters' in checkpoint:
        config = checkpoint['hyper_parameters']['config']
    elif 'config' in checkpoint:
        config = checkpoint['config']
    else:
        raise ValueError("Cannot find config in checkpoint")
    # Only the config is needed from here on; drop the raw checkpoint dict so
    # its contents are not kept alive in host memory for the rest of the run.
    del checkpoint

    # Update config for adaptive schedule
    from omegaconf import OmegaConf
    if not OmegaConf.is_config(config):
        from omegaconf import DictConfig
        config = DictConfig(config)

    # Disable struct mode to allow adding new keys
    OmegaConf.set_struct(config, False)

    config.training.use_adaptive_schedule = args.use_adaptive_schedule
    config.training.schedule_hidden_dim = args.schedule_hidden_dim
    config.training.schedule_num_layers = args.schedule_num_layers
    config.training.schedule_loss_weight = args.schedule_loss_weight
    config.training.freeze_base_model = args.freeze_base_model
    config.training.schedule_warmup_epochs = args.schedule_warmup_epochs

    # Re-enable struct mode
    OmegaConf.set_struct(config, True)

    # Initialize policy model with adaptive schedule
    policy_model = AnyOrderInsertionFlowModuleFT(
        config=config,
        args=args,
        pretrained_checkpoint=checkpoint_path,
        insertion_planner=True,
    )

    # define mcts
    score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']

    tokenizer = SMILES_SPE_Tokenizer(
        os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer', 'new_vocab.txt'),
        os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer', 'new_splits.txt')
    )

    # Device will be set by Lightning automatically in DDP
    reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device='cpu')
    model = PeptideFinetuner(
        args=args,
        policy_model=policy_model,
        reward_model=reward_model,
        tokenizer=tokenizer,
        pretrained=pretrained,
        mcts=None,
        filename=filename,
        prot_name=prot_name
    )

    # Setup checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.save_path,
        filename='model-{epoch:02d}',
        every_n_epochs=args.save_every_n_epochs,
        save_top_k=-1,
        save_last=True,  # Also save last.ckpt
        auto_insert_metric_name=False
    )

    # Setup wandb logger - only on rank 0 to avoid multiple runs
    # Check if we're in a spawned DDP process
    rank = int(os.environ.get('LOCAL_RANK', 0))
    if rank == 0:
        # Defaults to your default wandb entity; override with the WANDB_ENTITY env var.
        wandb_logger = WandbLogger(entity=os.environ.get('WANDB_ENTITY'), project='a2d2-pep', name=args.run_name)
    else:
        wandb_logger = None

    # Create dummy dataloader
    dataset = DummyDataset(size=args.num_training_steps_per_epoch)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

    # Setup trainer with DDP
    trainer = pl.Trainer(
        max_epochs=args.num_epochs,
        accelerator='gpu',
        devices=args.devices,
        strategy=DDPStrategy(find_unused_parameters=True) if args.devices != 1 else 'auto',
        gradient_clip_val=args.gradnorm_clip if args.grad_clip else None,
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
        enable_progress_bar=True,
        log_every_n_steps=1
    )

    # Train (resume full training state from --resume_ckpt if provided).
    # weights_only=False is required when resuming because these checkpoints
    # store argparse.Namespace / OmegaConf objects in hyper_parameters, which
    # PyTorch 2.6's default weights_only=True unpickler rejects.
    if args.resume_ckpt:
        trainer.fit(model, dataloader, ckpt_path=args.resume_ckpt, weights_only=False)
    else:
        trainer.fit(model, dataloader)


# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|
a2d2_pep/inference_quality.py
ADDED
|
@@ -0,0 +1,605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unified peptide sampling with quality-guided planning.
|
| 2 |
+
|
| 3 |
+
Supports 4 quality modes and optional RND (importance weight) computation.
|
| 4 |
+
|
| 5 |
+
Quality modes:
|
| 6 |
+
"none" - No planner, no remasking (policy-only)
|
| 7 |
+
"both" - Both unmasking + insertion planners active
|
| 8 |
+
"unmasking_only" - Only unmasking/remasking planner (insertion planner disabled)
|
| 9 |
+
"insertion_only" - Only insertion planner (unmasking planner disabled)
|
| 10 |
+
|
| 11 |
+
RND toggle:
|
| 12 |
+
compute_rnd=True - Run pretrained model in parallel, compute step-wise log importance weights
|
| 13 |
+
compute_rnd=False - Run policy model only (use with ELBO-based RND or eval)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import torch
|
| 18 |
+
import numpy as np
|
| 19 |
+
import pandas as pd
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from sampling import SamplingResult, SamplingTraceDatapoint, _sample_tokens
|
| 22 |
+
from remasking_scheduleaware import apply_schedule_aware_remasking, apply_schedule_aware_insertion
|
| 23 |
+
|
| 24 |
+
# Accepted values for the `quality_mode` argument of the sampling functions.
QUALITY_MODES = {"none", "both", "unmasking_only", "insertion_only"}

# When set (e.g. A2D2_QUALITY_DEBUG=1), the diffusion loop prints, per step, how
# many already-unmasked tokens get remasked and how many proposed insertions get
# filtered by the quality planner, plus a per-batch total. Off by default so it
# never spams training/eval runs.
# The value is normalised (whitespace stripped, lower-cased) before the check so
# that spellings such as "FALSE", "False ", "no" or "off" keep debugging off
# instead of silently enabling it.
_QUALITY_DEBUG = os.environ.get("A2D2_QUALITY_DEBUG", "").strip().lower() not in ("", "0", "false", "no", "off")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@torch.no_grad()
def _diffusion_loop(
    model, steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    compute_rnd=False,
    pretrained=None,
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    unmask_quality_threshold=None,
    unmask_all=False,
    freq_penalty=0.0,
    return_trace=False,
):
    """Core discrete diffusion sampling loop for peptide generation.

    Starting from an all-pad canvas, every step (1) takes an Euler step on the
    policy's unmask rate to decode masked positions, (2) optionally remasks
    already-decoded tokens, and (3) on every step except the last inserts fresh
    mask tokens whose per-gap counts are Poisson draws on ``length_rate * dt``,
    optionally filtered by the insertion planner.

    Args:
        model: Finetuned policy model. Called as ``model(xt, t)``; must expose
            ``model.interpolant.to_actual_rate`` and, for the unmasking
            planner, an optional ``model.planner``.
        steps: Number of diffusion steps (``dt = 1 / steps``); must be >= 1.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: One of "none", "both", "unmasking_only", "insertion_only".
        compute_rnd: Whether to compute step-wise log importance weights.
        pretrained: Frozen pretrained model (required if compute_rnd=True).
        remasking_mode: Remasking strategy ("schedule_aware", "remdm", "remdm_conf").
        num_remasking: Number of tokens to remask per step (used by the
            "remdm" and "remdm_conf" modes).
        quality_threshold: Threshold for insertion quality filtering. None if schedule-driven.
        unmask_quality_threshold: Forwarded unchanged to
            ``apply_schedule_aware_remasking``.
        unmask_all: If True, the mask token is removed from the sampling
            distribution on every step (not only the final one), so every
            masked position is decoded and remasking brings the mask count
            back down.
        freq_penalty: If > 0, token probabilities used for sampling are scaled
            by ``exp(-freq_penalty * count)`` where ``count`` is how often the
            token already appears in the sequence. The RND log-probabilities
            are computed from the unpenalized distribution.
        return_trace: Whether to record sampling trace.

    Returns:
        (xt, log_rnd, sampling_trace)
        log_rnd is a (batch_size,) tensor holding the sum over all steps of
        ``log p_pretrained - log p_policy`` for the unmask and insertion
        events; it is None when compute_rnd=False. sampling_trace is None
        when return_trace=False.
    """
    assert quality_mode in QUALITY_MODES, f"quality_mode must be one of {QUALITY_MODES}"
    if compute_rnd:
        assert pretrained is not None, "pretrained model required when compute_rnd=True"

    # Derive flags from quality_mode
    use_remasking = quality_mode != "none"
    disable_unmasking_planner = quality_mode in ("none", "insertion_only")
    disable_insertion_planner = quality_mode in ("none", "unmasking_only")

    device = next(model.parameters()).device

    # Initialize all-pad sequence
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Precompute index tensors
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    neg_inf = torch.tensor(-np.inf, device=device)

    if use_remasking and remasking_mode == "remdm_conf":
        # Per-position probability of the token chosen when it was unmasked.
        # NOTE(review): this buffer is maintained below but nothing in this
        # function reads it; the remdm_conf ranking uses the planner output.
        remasking_score = torch.zeros((batch_size, max_length), device=device)

    # Running sum of per-step log importance weights. It has to be accumulated
    # across steps: overwriting it each step would return only the final
    # step's unmask term and silently drop every insertion term.
    log_rnd = torch.zeros(batch_size, device=device) if compute_rnd else None

    dbg_total_remasked = 0
    dbg_total_proposed_ins = 0
    dbg_total_filtered = 0

    for i in range(steps):
        step_remasked = 0
        step_proposed_ins = 0
        step_filtered = 0

        # --- Policy model forward ---
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # --- Pretrained model forward (for RND) ---
        if compute_rnd:
            pretrained_pred = pretrained(xt, t)
            pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
            pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
            pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

        # --- Unmask step (Euler) ---
        # Only masked positions may transition; the mask column of each masked
        # row is set to minus the row's total outgoing rate.
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        if compute_rnd:
            pretrained_unmask_rate[xt != mask] = 0
            pretrained_unmask_rate[mask_pos + (mask,)] = 0
            pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
            pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

        # Add "stay" probability (pad positions are routed to the mask column;
        # they are reset to pad after sampling).
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )
        if compute_rnd:
            pretrained_trans_prob.scatter_add_(
                2,
                _xt.unsqueeze(-1),
                torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
            )

        # Remove mask token from sampling so every masked position is decoded.
        # The final step always does this; unmask_all does it every step, so the
        # schedule-aware remasking below re-masks the lowest-quality tokens back
        # down to the schedule's expected mask count.
        if i == steps - 1 or unmask_all:
            if i == steps - 1:
                print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0

            masked_probs = trans_prob[mask_pos]  # (N, V) copy of the masked rows
            prob_sum = masked_probs.sum(dim=-1, keepdim=True)
            mask_has_zero_prob = prob_sum.squeeze(-1) == 0.0
            # Renormalize every row that still has mass, even when some other
            # row has none; rows without mass are handled just below.
            safe_sum = torch.where(prob_sum == 0.0, torch.ones_like(prob_sum), prob_sum)
            masked_probs = masked_probs / safe_sum
            if mask_has_zero_prob.any():
                # Fall back to a uniform distribution over token ids [0, mask).
                uniform_row = torch.zeros(trans_prob.shape[-1], device=device, dtype=trans_prob.dtype)
                uniform_row[:mask] = 1.0 / mask
                masked_probs[mask_has_zero_prob] = uniform_row
            trans_prob[mask_pos] = masked_probs

        # --- Frequency penalty: down-weight residues already abundant in the
        # sequence so (re)decoded masked positions don't collapse onto the modal
        # token (glycine). Only masked positions are sampled; clean positions are
        # overwritten below, so penalizing the whole tensor is harmless. mask/pad
        # never accumulate counts, so their entries stay untouched. Applied to a
        # copy so trans_prob (used for RND log-probs) is unchanged.
        sample_prob = trans_prob
        if freq_penalty > 0.0:
            V = trans_prob.shape[-1]
            clean_tok = (xt != mask) & (xt != pad)  # (B, L)
            counts = torch.zeros(batch_size, V, device=device, dtype=trans_prob.dtype)
            counts.scatter_add_(
                1,
                torch.where(clean_tok, xt, torch.zeros_like(xt)),
                clean_tok.to(trans_prob.dtype),
            )
            sample_prob = trans_prob * torch.exp(-freq_penalty * counts).unsqueeze(1)

        new_xt = _sample_tokens(sample_prob)
        new_xt[xt == pad] = pad
        # Already-decoded tokens are never resampled by the unmask step.
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # Update remasking_score buffer for remdm_conf mode
        if use_remasking and remasking_mode == "remdm_conf" and i < steps - 1:
            token_probs = F.softmax(unmask_rate, dim=-1)  # (B, L, V)
            chosen_probs = torch.gather(token_probs, dim=-1, index=new_xt.unsqueeze(-1)).squeeze(-1)  # (B, L)
            changed_mask_to_token = (xt == mask) & (new_xt != mask) & (new_xt != pad)
            remasking_score = torch.where(changed_mask_to_token, chosen_probs, remasking_score)

        # --- Remasking step ---
        if use_remasking and i < steps - 1:
            if disable_unmasking_planner or not (hasattr(model, 'planner') and model.planner is not None):
                remasking_conf = torch.zeros((batch_size, max_length), device=device)
            else:
                planner_out = model.planner(new_xt, t)
                remasking_conf = planner_out["remasking_conf"].squeeze(-1)  # (B, L)

            clean_index = (new_xt != mask) & (new_xt != pad)  # (B, L)

            if remasking_mode == "remdm":
                remasking_score_temp = torch.rand(remasking_conf.shape, device=device)
            elif remasking_mode == "remdm_conf":
                remasking_score_temp = -1.0 * remasking_conf
            elif remasking_mode == "schedule_aware":
                # Only remask when the unmasking planner is active. Otherwise
                # (e.g. insertion_only / no_unmasking_planner) remasking_conf is
                # all zeros, so this would remask schedule-excess tokens by
                # position rather than by quality.
                if not disable_unmasking_planner:
                    new_xt = apply_schedule_aware_remasking(
                        model, new_xt, t, dt, remasking_conf, clean_index,
                        mask, neg_inf, batch_size,
                        unmask_quality_threshold=unmask_quality_threshold,
                    )
                # Always defined here so the top-k path below is skipped for
                # schedule-aware remasking, whether or not the planner ran.
                remasking_score_temp = None
            else:
                raise ValueError(f"Unknown remasking_mode: {remasking_mode}")

            if remasking_score_temp is not None:
                # Top-k remasking: only clean tokens are eligible.
                remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf)
                for j in range(batch_size):
                    k = min(num_remasking, int(clean_index[j].sum().item()))
                    if k > 0:
                        _, select_indices = torch.topk(remasking_score_temp[j], k=k)
                        new_xt[j, select_indices] = mask

            if _QUALITY_DEBUG:
                # Positions that were clean before this remasking block and are
                # now mask are exactly the unmasked tokens that got remasked.
                step_remasked = int((clean_index & (new_xt == mask)).sum().item())

            if return_trace:
                for batch_idx in range(batch_size):
                    for pos in range(max_length):
                        if clean_index[batch_idx, pos] and new_xt[batch_idx, pos] == mask:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=t[batch_idx].item(),
                                    event_type="change",
                                    position=pos,
                                    token=mask,
                                )
                            )

        # --- Compute log probabilities for RND ---
        if compute_rnd:
            lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
            lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

            # Only positions that went mask -> token (and were not remasked)
            # contribute to this step's weight.
            changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

            log_policy_step = (lp * changed_mask).sum(dim=1)
            log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

            log_rnd = log_rnd + (log_pretrained_step - log_policy_step)  # (B,)

        # --- Insertion step ---
        if i != steps - 1:
            ext = torch.poisson(len_rate * dt).long()  # (B, L+1)

            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # Only gaps inside the current sequence can receive insertions.
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            total_ext = ext.sum(dim=1)
            # Reject a row's insertions entirely if they would overflow.
            valid = xt_len + total_ext <= max_length
            ext = ext * valid.view(batch_size, 1).long()
            # Keep total_ext consistent with the masked ext; otherwise rejected
            # rows would still grow to max_length with fresh mask tokens.
            total_ext = total_ext * valid.long()

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # Shift every original token right by the insertions before it.
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]

            if _QUALITY_DEBUG:
                # ext has been masked by the max-length validity check above, so
                # this is the number of fresh mask tokens actually inserted.
                step_proposed_ins = int(ext.sum().item())

            # Insertion counts used for the RND term; replaced below when the
            # planner filters some of the proposed insertions.
            ext_corrected = ext

            # Schedule-aware insertion quality filtering
            if use_remasking and not disable_insertion_planner:
                dbg_nonpad_before = int((xt_tmp != pad).sum().item()) if _QUALITY_DEBUG else 0

                xt_tmp = apply_schedule_aware_insertion(
                    model, xt_tmp, new_xt, t, dt, ext, mask, pad, max_length,
                    orig_mask, new_pos_orig, quality_threshold
                )

                if _QUALITY_DEBUG:
                    # Filtering only drops/compacts tokens, so the drop in
                    # non-pad count is the number of insertions filtered out.
                    step_filtered = dbg_nonpad_before - int((xt_tmp != pad).sum().item())

                if compute_rnd:
                    # Compute corrected ext based on what actually stayed:
                    # scale each row's per-gap counts by the surviving fraction.
                    ext_corrected = torch.zeros_like(ext)
                    for b in range(batch_size):
                        after_len = xt_tmp[b].ne(pad).sum().item()
                        orig_len = xt_len[b].item()
                        surviving_insertions = after_len - orig_len
                        if total_ext[b] > 0:
                            ratio = surviving_insertions / total_ext[b].item()
                            ext_corrected[b] = (ext[b].float() * ratio).long()

            # Compute insertion log_rnd (Poisson log-likelihood up to the
            # log-factorial term, which is identical for both models).
            if compute_rnd:
                insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
                pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

                log_policy_insert = (ext_corrected * torch.log(insertion_rate) - insertion_rate).sum(dim=1)
                log_pretrained_insert = (ext_corrected * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)

                log_insert_diff = log_pretrained_insert - log_policy_insert
                log_rnd = log_rnd + log_insert_diff
        else:
            xt_tmp = new_xt

        if return_trace:
            for batch_idx in range(batch_size):
                for j in range(max_length):
                    if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=t[batch_idx].item(),
                                event_type="change",
                                position=j,
                                token=new_xt[batch_idx, j].item(),
                            )
                        )

                if i != steps - 1:
                    # Record insertions from the rightmost gap to the leftmost.
                    for gap in reversed(range(max_length)):
                        if ext[batch_idx, gap]:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=t[batch_idx].item(),
                                    event_type="insertion",
                                    position=gap,
                                    token=mask,
                                )
                            )

        if _QUALITY_DEBUG:
            dbg_total_remasked += step_remasked
            dbg_total_proposed_ins += step_proposed_ins
            dbg_total_filtered += step_filtered
            print(
                f"[QUALITY {quality_mode}] step {i+1}/{steps}: "
                f"remasked {step_remasked} unmasked tokens -> mask | "
                f"insertions proposed {step_proposed_ins}, "
                f"filtered {step_filtered}, kept {step_proposed_ins - step_filtered}"
            )

        xt = xt_tmp
        t = t + dt

    if _QUALITY_DEBUG:
        print(
            f"[QUALITY {quality_mode}] TOTAL over {steps} steps (batch_size={batch_size}): "
            f"remasked {dbg_total_remasked} unmasked tokens | "
            f"insertions proposed {dbg_total_proposed_ins}, "
            f"filtered {dbg_total_filtered}, kept {dbg_total_proposed_ins - dbg_total_filtered}"
        )

    return xt, log_rnd, sampling_trace
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
@torch.no_grad()
def sample_peptides_buffer(
    model, reward_model, analyzer, tokenizer,
    steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    compute_rnd=False,
    pretrained=None,
    alpha=0.1,
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    min_length=0,
):
    """Sample a batch of peptides, keep the valid ones and score them.

    Runs the diffusion loop, decodes the result, discards sequences that the
    analyzer rejects (or that are shorter than ``min_length`` tokens), and
    scores the survivors with the reward model.

    Args:
        model: Finetuned policy model.
        reward_model: Multi-objective scoring function, called as
            ``reward_model(input_seqs=...)``; its per-objective scores are
            summed over the last axis into one scalar reward per sequence.
        analyzer: PeptideAnalyzer for validation (``is_peptide``).
        tokenizer: Tokenizer for decoding (``batch_decode``).
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: "none", "both", "unmasking_only", or "insertion_only".
        compute_rnd: If True, use the log importance weights from the
            diffusion loop (requires pretrained). If False, a zero placeholder
            is returned instead (for ELBO-based RND).
        pretrained: Frozen pretrained model (required when compute_rnd=True).
        alpha: RND scaling factor; rewards are divided by it before being
            added to the log importance weights.
        remasking_mode: Remasking strategy.
        num_remasking: Number of tokens to remask per step.
        quality_threshold: Threshold for insertion quality filtering.
        min_length: When > 0, sequences with fewer non-pad tokens are dropped.

    Returns:
        (valid_x, log_rnd, scalar_rewards, sampling_trace). When no sequence
        survives filtering, the three tensors are empty.
    """
    xt, step_log_rnd, trace = _diffusion_loop(
        model, steps, mask, pad, batch_size, max_length,
        quality_mode=quality_mode,
        compute_rnd=compute_rnd,
        pretrained=pretrained,
        remasking_mode=remasking_mode,
        num_remasking=num_remasking,
        quality_threshold=quality_threshold,
    )

    device = xt.device

    # Filter: keep rows that decode to a peptide and meet the length floor.
    kept_rows = []
    kept_sequences = []
    for row, sequence in enumerate(tokenizer.batch_decode(xt)):
        if not analyzer.is_peptide(sequence):
            continue
        n_tokens = int((xt[row] != pad).sum().item())
        too_short = min_length > 0 and n_tokens < min_length
        if too_short:
            continue
        kept_rows.append(row)
        kept_sequences.append(sequence)

    n_kept = len(kept_sequences)
    print("len valid sequences:", n_kept)

    if n_kept == 0:
        print("[WARNING] No valid peptides generated in this batch")
        return (
            torch.empty((0, max_length), dtype=torch.long, device=device),
            torch.empty((0,), dtype=torch.float32, device=device),
            torch.empty((0,), dtype=torch.float32, device=device),
            trace,
        )

    # One scalar reward per sequence: sum of all objective scores.
    per_objective = reward_model(input_seqs=kept_sequences)
    scalar_rewards = torch.as_tensor(
        np.sum(per_objective, axis=-1), dtype=torch.float32, device=device
    )

    print(f"scalar reward dim{len(scalar_rewards)}")
    valid_x = torch.stack([xt[row] for row in kept_rows], dim=0)

    if not compute_rnd:
        # Placeholder weights; the caller is expected to compute RND elsewhere.
        placeholder = torch.zeros(n_kept, dtype=torch.float32, device=device)
        return valid_x, placeholder, scalar_rewards, trace

    kept_log_rnd = torch.stack([step_log_rnd[row] for row in kept_rows], dim=0)
    return valid_x, kept_log_rnd + (scalar_rewards / alpha), scalar_rewards, trace
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
@torch.no_grad()
def sample_peptides_eval(
    model, reward_model, analyzer, tokenizer,
    steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    unmask_quality_threshold=None,
    unmask_all=False,
    freq_penalty=0.0,
    dataframe=False,
    return_valid=False,
):
    """Generate peptides for evaluation.

    Args:
        model: Finetuned policy model.
        reward_model: Multi-objective scoring function. If it has a
            ``score_func_names`` attribute its length is the number of
            objectives; otherwise five objectives are assumed.
        analyzer: PeptideAnalyzer for validation.
        tokenizer: Tokenizer for decoding.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: "none", "both", "unmasking_only", or "insertion_only".
        remasking_mode: Remasking strategy.
        num_remasking: Number of tokens to remask per step.
        quality_threshold: Threshold for insertion quality filtering.
        unmask_quality_threshold: Forwarded to the diffusion loop.
        unmask_all: Forwarded to the diffusion loop.
        freq_penalty: Forwarded to the diffusion loop.
        dataframe: If True, include a pandas DataFrame in the return.
        return_valid: If True, return decoded valid sequences instead of raw token tensors.

    Returns:
        For multi-objective (5 objectives):
            (samples, affinity, sol, hemo, nf, permeability, valid_fraction[, df])
        For single objective:
            (samples, sol, valid_fraction[, df])
        When return_valid=True, samples is replaced with validSequences list.
        When no sequence is valid, every score entry is the placeholder
        ``[0.0]`` and the DataFrame (if requested) has the usual columns but
        zero rows.
    """
    xt, _, _ = _diffusion_loop(
        model, steps, mask, pad, batch_size, max_length,
        quality_mode=quality_mode,
        compute_rnd=False,
        remasking_mode=remasking_mode,
        num_remasking=num_remasking,
        quality_threshold=quality_threshold,
        unmask_quality_threshold=unmask_quality_threshold,
        unmask_all=unmask_all,
        freq_penalty=freq_penalty,
    )

    samples = xt
    decoded_samples = tokenizer.batch_decode(samples)

    valid_sequences = [seq for seq in decoded_samples if analyzer.is_peptide(seq)]

    print("len valid sequences:", len(valid_sequences))

    valid_fraction = len(valid_sequences) / batch_size

    # Determine number of objectives from reward model
    num_objectives = len(reward_model.score_func_names) if hasattr(reward_model, 'score_func_names') else 5

    # Column order matches the row order of the transposed score matrix.
    if num_objectives == 1:
        column_names = ("Solubility",)
    else:
        column_names = (
            "Binding Affinity",
            "Solubility",
            "Hemolysis",
            "Nonfouling",
            "Permeability",
        )

    if valid_sequences:
        score_vectors = reward_model(input_seqs=valid_sequences)  # (N, num_objectives)
        per_objective = score_vectors.T
        scores = [per_objective[k] for k in range(len(column_names))]
    else:
        # Placeholder kept for backward compatibility of the returned tuple.
        zeros = [0.0]
        scores = [zeros] * len(column_names)

    first = valid_sequences if return_valid else samples
    result = (first, *scores, valid_fraction)

    if dataframe:
        # With no valid sequences the score placeholders have length 1 while
        # the sequence column has length 0; pandas rejects columns of unequal
        # length, so build an empty frame with the same columns instead.
        columns = {"Peptide Sequence": valid_sequences}
        for column_name, column_scores in zip(column_names, scores):
            columns[column_name] = column_scores if valid_sequences else []
        result = result + (pd.DataFrame(columns),)

    return result
|
a2d2_pep/pep_scoring/functions/binding.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os, torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import esm
|
| 8 |
+
from transformers import AutoModelForMaskedLM
|
| 9 |
+
|
| 10 |
+
class ImprovedBindingPredictor(nn.Module):
    """Cross-attention model that scores a (protein, peptide-SMILES) pair.

    Both embeddings are projected to ``hidden_dim``, refined by a stack of
    cross-attention blocks, mean-pooled over their first dimension,
    concatenated, and fed to a regression head (one affinity value) and a
    3-way classification head (tight / medium / weak binding).
    """

    def __init__(self,
                 esm_dim=1280,
                 smiles_dim=768,
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1):
        """Build the projections, cross-attention stack and output heads.

        Args:
            esm_dim: Feature size of the protein embedding.
            smiles_dim: Feature size of the peptide (SMILES) embedding.
            hidden_dim: Shared width after projection.
            n_heads: Number of attention heads per block.
            n_layers: Number of cross-attention blocks.
            dropout: Dropout rate used in attention, the FFNs and the shared head.
        """
        super().__init__()

        # Affinity cut-offs used by get_binding_class.
        self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
        self.weak_threshold = 6.0  # Kd/Ki/IC50 > 1μM

        # Bring both modalities to the same width, then normalize.
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        # One block = attention + post-norm, then FFN + post-norm.
        blocks = []
        for _ in range(n_layers):
            blocks.append(nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }))
        self.cross_attention_layers = nn.ModuleList(blocks)

        # Trunk shared by both heads; input is the two pooled vectors concatenated.
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Scalar affinity prediction.
        self.regression_head = nn.Linear(hidden_dim, 1)

        # Logits over the 3 classes: tight, medium, weak binding.
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Map affinity values to class indices.

        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)

        Accepts either a tensor (returns a long tensor of the same shape) or a
        plain number (returns an int).
        """
        if not isinstance(affinity, torch.Tensor):
            if affinity >= self.tight_threshold:
                return 0  # tight binding
            if affinity < self.weak_threshold:
                return 2  # weak binding
            return 1  # medium binding

        is_tight = affinity >= self.tight_threshold
        is_weak = affinity < self.weak_threshold
        in_between = ~is_tight & ~is_weak

        # Start from class 0 (tight) and overwrite the other two classes.
        labels = torch.zeros_like(affinity, dtype=torch.long)
        labels.masked_fill_(in_between, 1)
        labels.masked_fill_(is_weak, 2)
        return labels

    def forward(self, protein_emb, smiles_emb):
        """Return ``(regression_output, classification_logits)`` for one pair.

        The two inputs are attended against each other in every block (the
        same block weights serve both directions), then mean-pooled over
        dim 0 before the heads.
        """
        def cross_update(block, query, context):
            # Attention with residual + norm, then FFN with residual + norm.
            attended, _ = block['attention'](query, context, context)
            query = block['norm1'](query + attended)
            return block['norm2'](query + block['ffn'](query))

        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(smiles_emb))

        for block in self.cross_attention_layers:
            # Protein attends to SMILES first; SMILES then attends to the
            # already-updated protein representation.
            protein = cross_update(block, protein, smiles)
            smiles = cross_update(block, smiles, protein)

        # Sequence-level representations, concatenated for the shared trunk.
        pooled = torch.cat([protein.mean(dim=0), smiles.mean(dim=0)], dim=-1)
        shared_features = self.shared_head(pooled)

        return (
            self.regression_head(shared_features),
            self.classification_head(shared_features),
        )
|
| 118 |
+
|
| 119 |
+
class BindingAffinity:
    """Score peptides by predicted binding affinity to one fixed target protein.

    The target protein is embedded once at construction time with ESM-2; each
    call embeds the given peptides and runs ``ImprovedBindingPredictor`` on
    the (protein, peptide) pair.
    """

    def __init__(self, prot_seq, tokenizer, base_path, device=None, emb_model=None):
        """Load all models and precompute the target protein embedding.

        Args:
            prot_seq: Target protein sequence passed to the ESM batch converter.
            tokenizer: Peptide tokenizer, called on each input sequence.
            base_path: Directory containing
                ``functions/classifiers/binding-affinity.pt``.
            device: Torch device; defaults to CUDA when available, else CPU.
            emb_model: Optional peptide embedding model; when omitted the
                ``roformer`` backbone of ``aaronfeller/PeptideCLM-23M-all``
                is loaded.
        """
        super().__init__()
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # Peptide embedding model.
        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.pep_model = emb_model.to(self.device).eval()

        self.pep_tokenizer = tokenizer

        # Binding predictor weights.
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from a trusted source.
        self.model = ImprovedBindingPredictor().to(self.device)
        checkpoint = torch.load(f'{base_path}/functions/classifiers/binding-affinity.pt',
                                map_location=self.device,
                                weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # ESM-2 model and its batch converter, used as the protein tokenizer.
        esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm_model = esm_model.to(self.device).eval()
        self.prot_tokenizer = alphabet.get_batch_converter()

        # Embed the target once: layer-33 representations, averaged over the
        # token dimension into a single (1, dim) row.
        _, _, prot_tokens = self.prot_tokenizer([("target", prot_seq)])
        prot_tokens = prot_tokens.to(self.device)
        with torch.no_grad():
            results = self.esm_model.forward(prot_tokens, repr_layers=[33])
        token_reprs = results["representations"][33][0].to(self.device)
        self.prot_emb = torch.mean(token_reprs, dim=0, keepdim=True)

    def forward(self, input_seqs):
        """Return one predicted affinity (float) per sequence in ``input_seqs``."""
        affinities = []
        with torch.no_grad():
            for sequence in input_seqs:
                encoded = self.pep_tokenizer(sequence, return_tensors='pt', padding=True)
                encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

                outputs = self.pep_model(input_ids=encoded['input_ids'],
                                         attention_mask=encoded['attention_mask'],
                                         output_hidden_states=True)

                # Mean over the token dimension -> one (1, dim) peptide vector.
                token_states = outputs.last_hidden_state.squeeze(0)
                pep_emb = torch.mean(token_states, dim=0, keepdim=True)

                # Only the regression output is used; the class logits are ignored.
                prediction, _ = self.model.forward(self.prot_emb, pep_emb)
                affinities.append(prediction.item())
        return affinities

    def __call__(self, input_seqs: list):
        """Alias for :meth:`forward`."""
        return self.forward(input_seqs)
|
a2d2_pep/pep_scoring/functions/binding_utils.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
def to_var(x):
    """Return ``x`` moved to CUDA when a GPU is available, else ``x`` unchanged."""
    return x.cuda() if torch.cuda.is_available() else x
|
| 9 |
+
|
| 10 |
+
class MultiHeadAttentionSequence(nn.Module):
    """Multi-head scaled dot-product attention with residual + LayerNorm.

    Args:
        n_head: number of attention heads.
        d_model: feature size of the inputs and of the output.
        d_k: per-head size of the query/key projections.
        d_v: per-head size of the value projection.
        dropout: dropout probability applied to the projected output.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, n_head*d_k)
        self.W_K = nn.Linear(d_model, n_head*d_k)
        self.W_V = nn.Linear(d_model, n_head*d_v)
        self.W_O = nn.Linear(n_head*d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        """Attend from ``q`` over ``k``/``v``.

        Returns:
            ``(output, attention)`` where ``output`` has the shape of ``q`` and
            ``attention`` is ``(batch, n_head, len_q, len_k)`` with rows that
            sum to 1.
        """
        batch, len_q, _ = q.size()
        batch, len_k, _ = k.size()
        batch, len_v, _ = v.size()

        # project and split into heads: (batch, n_head, len, d)
        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k]).transpose(1, 2)
        # keys are additionally transposed to (batch, n_head, d_k, len_k) for the matmul
        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k]).transpose(1, 2).transpose(2, 3)
        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v]).transpose(1, 2)

        attention = torch.matmul(Q, K)
        attention = attention / np.sqrt(self.d_k)

        # torch.softmax is used because this module never imports
        # torch.nn.functional as F; the previous F.softmax raised NameError.
        attention = torch.softmax(attention, dim=-1)

        output = torch.matmul(attention, V)
        # merge heads back: (batch, len_q, n_head * d_v)
        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])

        output = self.W_O(output)
        output = self.dropout(output)
        output = self.layer_norm(output + q)

        return output, attention
|
| 61 |
+
|
| 62 |
+
class MultiHeadAttentionReciprocal(nn.Module):
    """Two-way cross attention sharing one query/key score matrix.

    The scaled ``Q K^T`` scores are soft-maxed along the key axis to update the
    ``q`` stream, and their transpose is soft-maxed along the query axis to
    update the ``k`` stream. Each direction has its own value and output
    projections.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, n_head*d_k)
        self.W_K = nn.Linear(d_model, n_head*d_k)
        self.W_V = nn.Linear(d_model, n_head*d_v)
        self.W_O = nn.Linear(n_head*d_v, d_model)
        self.W_V_2 = nn.Linear(d_model, n_head*d_v)
        self.W_O_2 = nn.Linear(n_head*d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

        self.layer_norm_2 = nn.LayerNorm(d_model)

        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, q, k, v, v_2):
        """Run both attention directions.

        Args:
            q: query-side input, ``(batch, len_q, d_model)``.
            k: key-side input, ``(batch, len_k, d_model)``.
            v: values paired with ``k`` (same length as ``k``).
            v_2: values paired with ``q`` (same length as ``q``).

        Returns:
            ``(output, output_2, attention, attention_2)``: the updated ``q``
            stream, the updated ``k`` stream, and the two attention maps.
        """
        batch, len_q, _ = q.size()
        batch, len_k, _ = k.size()
        batch, len_v, _ = v.size()
        batch, len_v_2, _ = v_2.size()

        # project and split into heads: (batch, n_head, len, d)
        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k]).transpose(1, 2)
        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k]).transpose(1, 2).transpose(2, 3)
        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v]).transpose(1, 2)
        V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v]).transpose(1, 2)

        attention = torch.matmul(Q, K)
        attention = attention / np.sqrt(self.d_k)

        # reverse direction reuses the same scores, transposed
        attention_2 = attention.transpose(-2, -1)

        # torch.softmax is used because this module never imports
        # torch.nn.functional as F; the previous F.softmax raised NameError.
        attention = torch.softmax(attention, dim=-1)
        attention_2 = torch.softmax(attention_2, dim=-1)

        output = torch.matmul(attention, V)
        output_2 = torch.matmul(attention_2, V_2)

        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
        output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head])

        output = self.W_O(output)
        output_2 = self.W_O_2(output_2)

        output = self.dropout(output)
        output = self.layer_norm(output + q)

        # NOTE(review): the second branch reuses self.dropout / self.layer_norm
        # while dropout_2 / layer_norm_2 are defined but never used. This is
        # kept as-is because switching would change results for any weights
        # trained with the current wiring; confirm intent before changing.
        output_2 = self.dropout(output_2)
        output_2 = self.layer_norm(output_2 + k)

        return output, output_2, attention, attention_2
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class FFN(nn.Module):
    """Position-wise feed-forward block built from two kernel-size-1 Conv1d layers.

    The input is added back as a residual and the sum is layer-normalised.
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()

        self.layer_1 = nn.Conv1d(d_in, d_hid, 1)
        self.layer_2 = nn.Conv1d(d_hid, d_in, 1)
        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(d_in)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, length, d_in)``."""
        # Conv1d expects channels on axis 1, so swap axes going in and out
        hidden = self.relu(self.layer_1(x.transpose(1, 2)))
        hidden = self.dropout(self.layer_2(hidden))
        return self.layer_norm(hidden.transpose(1, 2) + x)
|
| 170 |
+
|
| 171 |
+
class ConvLayer(nn.Module):
    """A single ``Conv1d`` followed by ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv(x))``."""
        return self.relu(self.conv(x))
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class DilatedCNN(nn.Module):
    """Three parallel stacks of dilated convolutions whose outputs are summed.

    The stacks use kernel sizes 3, 5 and 7; each stack has three ``ConvLayer``s
    with dilation rates 1, 2, 3 and padding chosen as ``(kernel // 2) * rate``.
    """

    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.first_ = nn.ModuleList()
        self.second_ = nn.ModuleList()
        self.third_ = nn.ModuleList()

        dilations = (1, 2, 3)
        # only the first layer of a stack sees d_model channels
        channels_in = (d_model, d_hidden, d_hidden)

        for stack, kernel in ((self.first_, 3), (self.second_, 5), (self.third_, 7)):
            for c_in, rate in zip(channels_in, dilations):
                stack.append(ConvLayer(c_in, d_hidden, kernel_size=kernel,
                                       padding=(kernel // 2) * rate, dilation=rate))

    def forward(self, protein_seq_enc):
        """Map ``B*L*d_model`` input to the summed stack outputs, transposed back."""
        # B*L*d_model -> B*d_model*L, the layout Conv1d expects
        channels_first = protein_seq_enc.transpose(1, 2)

        stack_outputs = []
        for stack in (self.first_, self.second_, self.third_):
            hidden = channels_first
            for layer in stack:
                hidden = layer(hidden)
            stack_outputs.append(hidden)

        merged = stack_outputs[0] + stack_outputs[1] + stack_outputs[2]
        return merged.transpose(1, 2)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class ReciprocalLayerwithCNN(nn.Module):
    """Reciprocal attention layer whose protein input first passes a ``DilatedCNN``.

    Pipeline: CNN on the protein encoding, self-attention on each stream,
    reciprocal cross attention between the streams, then one FFN per stream.
    """

    def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v):
        super().__init__()

        self.cnn = DilatedCNN(d_model, d_hidden)

        self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
        self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, d_k, d_v)

        self.ffn_seq = FFN(d_hidden, d_inner)
        self.ffn_protein = FFN(d_hidden, d_inner)

    def forward(self, sequence_enc, protein_seq_enc):
        """Return both updated encodings followed by the four attention maps."""
        # protein_seq_enc.shape = B * L * d_model
        conv_protein = self.cnn(protein_seq_enc)
        protein_out, protein_self_attn = self.protein_attention_layer(conv_protein, conv_protein, conv_protein)

        sequence_out, sequence_self_attn = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)

        # protein stream acts as the query side, sequence stream as the key side
        protein_out, sequence_out, protein_to_seq_attn, seq_to_protein_attn = self.reciprocal_attention_layer(
            protein_out, sequence_out, sequence_out, protein_out)

        protein_out = self.ffn_protein(protein_out)
        sequence_out = self.ffn_seq(sequence_out)

        return (protein_out, sequence_out, protein_self_attn, sequence_self_attn,
                protein_to_seq_attn, seq_to_protein_attn)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class ReciprocalLayer(nn.Module):
    """Self-attention on each stream, reciprocal cross attention, then per-stream FFN."""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v):
        super().__init__()

        self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
        self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model, d_k, d_v)

        self.ffn_seq = FFN(d_model, d_inner)
        self.ffn_protein = FFN(d_model, d_inner)

    def forward(self, sequence_enc, protein_seq_enc):
        """Return both updated encodings followed by the four attention maps."""
        protein_out, protein_self_attn = self.protein_attention_layer(
            protein_seq_enc, protein_seq_enc, protein_seq_enc)

        sequence_out, sequence_self_attn = self.sequence_attention_layer(
            sequence_enc, sequence_enc, sequence_enc)

        # protein stream acts as the query side, sequence stream as the key side
        protein_out, sequence_out, protein_to_seq_attn, seq_to_protein_attn = self.reciprocal_attention_layer(
            protein_out, sequence_out, sequence_out, protein_out)

        protein_out = self.ffn_protein(protein_out)
        sequence_out = self.ffn_seq(sequence_out)

        return (protein_out, sequence_out, protein_self_attn, sequence_self_attn,
                protein_to_seq_attn, seq_to_protein_attn)
|
a2d2_pep/pep_scoring/functions/hemolysis.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import xgboost as xgb
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from transformers import AutoModelForMaskedLM
|
| 5 |
+
import warnings
|
| 6 |
+
import numpy as np
|
| 7 |
+
from rdkit import rdBase
|
| 8 |
+
|
| 9 |
+
rdBase.DisableLog('rdApp.error')
|
| 10 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 11 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 12 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 13 |
+
|
| 14 |
+
class Hemolysis:
    """Score sequences with an XGBoost hemolysis classifier on pooled embeddings.

    Args:
        tokenizer: callable turning one sequence into a dict of tensors
            (called with ``return_tensors='pt'``).
        base_path: directory containing ``functions/classifiers/hemolysis-xgboost.json``.
        device: torch device; defaults to CUDA when available, else CPU.
        emb_model: optional pre-loaded embedding model; when omitted the
            PeptideCLM roformer encoder is loaded.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/hemolysis-xgboost.json')
        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        # Move to self.device (not the raw ``device`` argument, which may be
        # None) so the model lives where generate_embeddings sends its inputs.
        self.emb_model = emb_model.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return one mean-pooled embedding per sequence as a NumPy array."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Return ``1 - p`` per sequence, where ``p`` is the classifier output."""
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        # nothing to score: return the (empty) placeholder array
        if len(features) == 0:
            return scores

        # sanitise features before handing them to XGBoost
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        probs = self.predictor.predict(features)
        # return the probability of it being not hemolytic
        return scores - probs

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
|
| 53 |
+
|
| 54 |
+
def unittest():
    """Smoke test: score one example sequence and print the result."""
    # NOTE(review): Hemolysis.__init__ requires ``tokenizer`` and ``base_path``;
    # as written this call raises TypeError. Supply them before relying on this.
    scorer = Hemolysis()
    example = ["[te]NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    print(scorer.tokenizer.vocab_size)
    result = scorer(input_seqs=example)
    print(result)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
if __name__ == '__main__':
|
| 63 |
+
unittest()
|
a2d2_pep/pep_scoring/functions/nonfouling.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import xgboost as xgb
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from transformers import AutoModelForMaskedLM
|
| 7 |
+
import warnings
|
| 8 |
+
import numpy as np
|
| 9 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
rdBase.DisableLog('rdApp.error')
|
| 13 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 14 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 15 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 16 |
+
|
| 17 |
+
class Nonfouling:
    """Score sequences with an XGBoost nonfouling classifier on pooled embeddings.

    Args:
        tokenizer: callable turning one sequence into a dict of tensors
            (called with ``return_tensors='pt'``).
        base_path: directory containing ``functions/classifiers/nonfouling-xgboost.json``.
        device: torch device; defaults to CUDA when available, else CPU.
        emb_model: optional pre-loaded embedding model; when omitted the
            PeptideCLM roformer encoder is loaded.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/nonfouling-xgboost.json')
        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        # Move to self.device (not the raw ``device`` argument, which may be
        # None) so the model lives where generate_embeddings sends its inputs.
        self.emb_model = emb_model.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return one mean-pooled embedding per sequence as a NumPy array."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Return the classifier's raw prediction for each sequence."""
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        # nothing to score: return the (empty) placeholder array
        if len(features) == 0:
            return scores

        # sanitise features before handing them to XGBoost
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        # the booster output is returned unchanged (no ``1 - p`` inversion here)
        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
|
| 56 |
+
|
| 57 |
+
def unittest():
    """Smoke test: score one example sequence and print the result."""
    # NOTE(review): Nonfouling.__init__ requires ``tokenizer`` and ``base_path``;
    # as written this call raises TypeError. Supply them before relying on this.
    scorer = Nonfouling()
    example = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]

    result = scorer(input_seqs=example)
    print(result)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if __name__ == '__main__':
|
| 66 |
+
unittest()
|
a2d2_pep/pep_scoring/functions/permeability.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import xgboost as xgb
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from transformers import AutoModelForMaskedLM
|
| 7 |
+
import warnings
|
| 8 |
+
import numpy as np
|
| 9 |
+
from rdkit.Chem import Descriptors, rdMolDescriptors
|
| 10 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 11 |
+
from rdkit.Chem import AllChem
|
| 12 |
+
from typing import List
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
rdBase.DisableLog('rdApp.error')
|
| 16 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 17 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 18 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 19 |
+
|
| 20 |
+
def fingerprints_from_smiles(smiles: List, size=2048):
    """Create ECFP fingerprints of smiles, with validity check.

    Returns:
        ``(fps, valid_mask)``: ``fps`` is a ``(len(smiles), size)`` array with
        an all-zero row for every SMILES that RDKit could not parse, and
        ``valid_mask`` holds 1/0 per input. An empty input yields a
        ``(0, size)`` array instead of raising inside ``np.concatenate``.
    """
    fps = []
    valid_mask = []
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)
        valid_mask.append(int(mol is not None))
        fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
        fps.append(fp)

    # np.concatenate rejects an empty list, so handle that case explicitly
    fps = np.concatenate(fps, axis=0) if len(fps) > 0 else np.zeros((0, size))
    return fps, valid_mask
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
    """Create ECFP fingerprint of a molecule, returned as a ``(1, n)`` array."""
    make_fp = AllChem.GetHashedMorganFingerprint if hashed else AllChem.GetMorganFingerprintAsBitVect
    bits = make_fp(molecule, radius, nBits=size)
    # placeholder buffer handed to ConvertToNumpyArray as the destination
    buffer = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(bits, buffer)
    return buffer.reshape(1, -1)
|
| 43 |
+
|
| 44 |
+
def getMolDescriptors(mol, missingVal=0):
    """Calculate the full list of descriptors for a molecule.

    Evaluates every entry of ``Descriptors._descList`` followed by three
    Lipinski/rotatable-bond descriptors. A descriptor that raises is recorded
    as ``missingVal``.

    Returns:
        ``(values, names)``: parallel lists of descriptor values and names.
    """
    custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
                          'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
                          'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}

    values, names = [], []
    for nm, fn in list(Descriptors._descList) + list(custom_descriptors.items()):
        try:
            val = fn(mol)
        except Exception:
            # ``except Exception`` (not a bare except) so KeyboardInterrupt and
            # SystemExit still propagate
            val = missingVal
        values.append(val)
        names.append(nm)
    return values, names
|
| 68 |
+
|
| 69 |
+
def get_pep_dps_from_smi(smi):
    """Return the descriptor vector of one SMILES string as a NumPy array.

    If parsing raises, a message is printed and the descriptors are computed
    for ``None``.
    """
    try:
        mol = Chem.MolFromSmiles(smi)
    except Exception:
        # ``except Exception`` (not a bare except) so KeyboardInterrupt and
        # SystemExit still propagate
        print(f"convert smi {smi} to molecule failed!")
        mol = None

    dps, _ = getMolDescriptors(mol)
    return np.array(dps)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_pep_dps(smi_list):
    """Return a ``(len(smi_list), n_descriptors)`` array of descriptor vectors.

    For an empty input the width is derived from the descriptor list instead
    of a hard-coded number, so it follows the installed RDKit.
    """
    if len(smi_list) == 0:
        # getMolDescriptors emits every Descriptors._descList entry plus its
        # 3 custom descriptors
        return np.zeros((0, len(Descriptors._descList) + 3))
    return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
|
| 84 |
+
|
| 85 |
+
def check_smi_validity(smiles: list):
    """Return ``(valid_smiles, valid_indices)`` for the entries RDKit can parse.

    Falsy entries are skipped without parsing; entries whose parsing raises
    are silently dropped.
    """
    kept, positions = [], []
    for pos, candidate in enumerate(smiles):
        if not candidate:
            continue
        try:
            parsed = Chem.MolFromSmiles(candidate)
        except Exception:
            # best-effort filter: an unparsable entry is simply not kept
            continue
        if parsed:
            kept.append(candidate)
            positions.append(pos)
    return kept, positions
|
| 97 |
+
|
| 98 |
+
class Permeability:
    """Score sequences with an XGBoost permeability model.

    Features are mean-pooled embeddings, optionally prefixed with ECFP
    fingerprints and RDKit descriptors.

    Args:
        tokenizer: callable turning one sequence into a dict of tensors
            (called with ``return_tensors='pt'``).
        base_path: directory containing ``functions/classifiers/permeability-xgboost.json``.
        device: torch device; defaults to CUDA when available, else CPU.
        emb_model: optional pre-loaded embedding model; when omitted the
            PeptideCLM roformer encoder is loaded.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/permeability-xgboost.json')
        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        # Move to self.device (not the raw ``device`` argument, which may be
        # None) so the model lives where generate_embeddings sends its inputs.
        self.emb_model = emb_model.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return one mean-pooled embedding per sequence as a NumPy array."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_features(self, input_seqs: list, dps=False, fps=False):
        """Build the feature matrix ``[fingerprints | descriptors | embeddings]``.

        Args:
            input_seqs: sequences to featurise.
            dps: include RDKit descriptors when True.
            fps: include ECFP fingerprints when True.

        Returns:
            A 2-D array with one row per input sequence; an empty input yields
            an empty ``(0, 0)`` array.
        """
        n_seqs = len(input_seqs)
        if n_seqs == 0:
            # Empty embeddings are 1-D and cannot be concatenated with the
            # 2-D placeholders below, so short-circuit here.
            return np.empty((0, 0), dtype=np.float32)

        # zero-width float32 placeholders keep the concatenated dtype unchanged
        if fps:
            fingerprints = fingerprints_from_smiles(input_seqs)[0]
        else:
            fingerprints = np.empty((n_seqs, 0), dtype=np.float32)

        if dps:
            descriptors = get_pep_dps(input_seqs)
        else:
            descriptors = np.empty((n_seqs, 0), dtype=np.float32)

        embeddings = self.generate_embeddings(input_seqs)

        return np.concatenate([fingerprints, descriptors, embeddings], axis=1)

    def get_scores(self, input_seqs: list):
        """Return the model's prediction for each sequence."""
        # placeholder, returned as-is only when there is nothing to score
        scores = -10 * np.ones(len(input_seqs))
        features = self.get_features(input_seqs)

        if len(features) == 0:
            return scores

        # sanitise features before handing them to XGBoost
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
|
| 161 |
+
|
| 162 |
+
def unittest():
    """Smoke test: score one example sequence and print the result."""
    # NOTE(review): Permeability.__init__ requires ``tokenizer`` and ``base_path``;
    # as written this call raises TypeError. Supply them before relying on this.
    scorer = Permeability()
    example = ['N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1cNc2c1cc(O)cc2)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H]([C@@H](O)C(C)C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O']
    result = scorer(input_seqs=example)
    print(result)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
if __name__ == '__main__':
|
| 170 |
+
unittest()
|
a2d2_pep/pep_scoring/functions/scoring_utils.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
import numpy as np
|
| 3 |
+
from loguru import logger
|
| 4 |
+
from sklearn.ensemble import RandomForestRegressor
|
| 5 |
+
from rdkit.Chem import Descriptors, rdMolDescriptors
|
| 6 |
+
import joblib
|
| 7 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 8 |
+
from rdkit.Chem import AllChem
|
| 9 |
+
from typing import List
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
    """
    Create ECFP fingerprint of a molecule, returned as a ``(1, n)`` array
    """
    make_fp = AllChem.GetHashedMorganFingerprint if hashed else AllChem.GetMorganFingerprintAsBitVect
    bits = make_fp(molecule, radius, nBits=size)
    # placeholder buffer handed to ConvertToNumpyArray as the destination
    buffer = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(bits, buffer)
    return buffer.reshape(1, -1)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def fingerprints_from_smiles(smiles: List, size=2048):
    """Create ECFP fingerprints of smiles, with validity check.

    Returns ``(fps, valid_mask)``; unparsable SMILES contribute an all-zero
    row and a 0 in the mask, and an empty input gives a ``(0, size)`` array.
    """
    rows = []
    valid_mask = []
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)
        valid_mask.append(int(mol is not None))
        if mol:
            rows.append(fingerprints_from_mol(mol, size=size))
        else:
            rows.append(np.zeros((1, size)))

    if len(rows) > 0:
        stacked = np.concatenate(rows, axis=0)
    else:
        stacked = np.zeros((0, size))
    return stacked, valid_mask
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def getMolDescriptors(mol, missingVal=0):
    """Calculate the full list of descriptors for a molecule.

    Evaluates every entry of ``Descriptors._descList`` followed by three
    Lipinski/rotatable-bond descriptors. A descriptor that raises is recorded
    as ``missingVal``.

    Returns:
        ``(values, names)``: parallel lists of descriptor values and names.
    """
    custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
                          'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
                          'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}

    values, names = [], []
    for nm, fn in list(Descriptors._descList) + list(custom_descriptors.items()):
        try:
            val = fn(mol)
        except Exception:
            # ``except Exception`` (not a bare except) so KeyboardInterrupt and
            # SystemExit still propagate
            val = missingVal
        values.append(val)
        names.append(nm)
    return values, names
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_pep_dps_from_smi(smi):
    """Return the descriptor vector of one SMILES string as a NumPy array.

    If parsing raises, a message is printed and the descriptors are computed
    for ``None``.
    """
    try:
        mol = Chem.MolFromSmiles(smi)
    except Exception:
        # ``except Exception`` (not a bare except) so KeyboardInterrupt and
        # SystemExit still propagate
        print(f"convert smi {smi} to molecule failed!")
        mol = None

    dps, _ = getMolDescriptors(mol)
    return np.array(dps)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def get_pep_dps(smi_list):
    """Return a ``(len(smi_list), n_descriptors)`` array of descriptor vectors.

    For an empty input the width is derived from the descriptor list instead
    of a hard-coded number, so it follows the installed RDKit.
    """
    if len(smi_list) == 0:
        # getMolDescriptors emits every Descriptors._descList entry plus its
        # 3 custom descriptors
        return np.zeros((0, len(Descriptors._descList) + 3))
    return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def check_smi_validity(smiles: list):
    """Filter a list of SMILES strings down to the ones that parse.

    Falsy entries (empty string, ``None``) are treated as invalid without
    being parsed. Any exception raised while handling an entry is ignored and
    the entry is simply left out.

    Returns:
        ``(valid_smi, valid_idx)``: the accepted strings and their positions
        in the input list.
    """
    valid_smi, valid_idx = [], []
    for position, candidate in enumerate(smiles):
        try:
            parsed = None if not candidate else Chem.MolFromSmiles(candidate)
            if parsed:
                valid_smi.append(candidate)
                valid_idx.append(position)
        except Exception:
            # Best effort: an entry that cannot be processed counts as invalid.
            continue
    return valid_smi, valid_idx
|
a2d2_pep/pep_scoring/functions/solubility.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import xgboost as xgb
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from transformers import AutoModelForMaskedLM
|
| 5 |
+
import warnings
|
| 6 |
+
import numpy as np
|
| 7 |
+
from rdkit import rdBase
|
| 8 |
+
|
| 9 |
+
# Import-time side effects: turn off RDKit's 'rdApp.error' log channel and
# install global "ignore" filters for Deprecation/User/Future warnings.
rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 13 |
+
|
| 14 |
+
class Solubility:
    """Solubility scorer: embedding model features fed to an XGBoost booster."""

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """Load the booster and set up the embedding model.

        Args:
            tokenizer: Callable applied to each sequence with
                ``return_tensors='pt'``.
            base_path: Directory that contains
                ``functions/classifiers/solubility-xgboost.json``.
            device: Torch device; when ``None``, CUDA is used if available,
                otherwise CPU.
            emb_model: Optional embedding model to reuse; when ``None`` the
                ``roformer`` sub-module of 'aaronfeller/PeptideCLM-23M-all'
                is loaded.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/solubility-xgboost.json')

        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.emb_model = emb_model.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Embed each sequence separately and return the stacked vectors.

        One vector per sequence is obtained by averaging the model's
        ``last_hidden_state`` over dimension 1.
        """
        pooled = []
        for seq in sequences:
            batch = self.tokenizer(seq, return_tensors='pt')
            batch = {name: tensor.to(self.device) for name, tensor in batch.items()}
            with torch.no_grad():
                hidden = self.emb_model(**batch).last_hidden_state
            # Mean pooling across sequence length
            pooled.append(hidden.mean(dim=1).squeeze(0).cpu().numpy())
        return np.array(pooled)

    def get_scores(self, input_seqs: list):
        """Return the booster's prediction for every sequence in ``input_seqs``.

        When no embeddings are produced, an all-zero array of length
        ``len(input_seqs)`` is returned instead of calling the booster.
        """
        fallback = np.zeros(len(input_seqs))
        feats = self.generate_embeddings(input_seqs)
        if len(feats) == 0:
            return fallback

        # Replace NaNs and keep every value inside the float32 range.
        f32 = np.finfo(np.float32)
        feats = np.clip(np.nan_to_num(feats, nan=0.), f32.min, f32.max)

        return self.predictor.predict(xgb.DMatrix(feats))

    def __call__(self, input_seqs: list):
        """Alias for :meth:`get_scores`."""
        return self.get_scores(input_seqs)
|
| 55 |
+
|
| 56 |
+
def unittest():
    """Smoke test: score a single peptide SMILES string and print the result."""
    # NOTE(review): Solubility.__init__ requires `tokenizer` and `base_path`,
    # so this argument-less call raises TypeError as written - supply them
    # before relying on this entry point.
    scorer = Solubility()
    peptides = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    print(scorer(input_seqs=peptides))


if __name__ == '__main__':
    unittest()
|
a2d2_pep/pep_scoring/scoring_functions.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from .tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 3 |
+
from transformers import AutoModelForMaskedLM
|
| 4 |
+
import numpy as np
|
| 5 |
+
from .functions.binding import BindingAffinity
|
| 6 |
+
from .functions.permeability import Permeability
|
| 7 |
+
from .functions.solubility import Solubility
|
| 8 |
+
from .functions.hemolysis import Hemolysis
|
| 9 |
+
from .functions.nonfouling import Nonfouling
|
| 10 |
+
|
| 11 |
+
# Base path: the directory containing this module. The tokenizer files are
# read from `{base_path}/tokenizer/` and the same path is handed to every
# scoring function (which look under `functions/classifiers/`).
base_path = os.path.dirname(os.path.abspath(__file__))
|
| 13 |
+
|
| 14 |
+
class ScoringFunctions:
    """Evaluate a selectable set of peptide scoring functions on input sequences."""

    def __init__(self, score_func_names=None, prot_seqs=None, device=None):
        """
        Build the scoring functions and record which ones to evaluate.

        Args:
            score_func_names: list of scoring function names to be evaluated,
                in output-column order. Valid names are the keys of
                ``self.all_funcs``. ``None`` selects no function.
            prot_seqs: target protein sequences for the binding-affinity
                scorers. 'binding_affinity1' is built when it is requested
                and exactly one sequence is given; both binding scorers are
                built when both are requested and exactly two sequences are
                given. Otherwise both entries stay ``None``.
            device: device passed to the embedding model and to each scorer.
        """
        emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device).eval()
        tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/tokenizer/new_vocab.txt',
                                         f'{base_path}/tokenizer/new_splits.txt')
        prot_seqs = prot_seqs if prot_seqs is not None else []

        # With no names given, no scoring function is evaluated.
        self.score_func_names = [] if score_func_names is None else score_func_names

        # binding affinities
        self.target_protein = prot_seqs

        # Membership is tested on the normalised list: the raw argument may be
        # None, which would raise TypeError on `in`.
        wants_binding1 = 'binding_affinity1' in self.score_func_names
        wants_binding2 = 'binding_affinity2' in self.score_func_names

        binding_affinity1 = None
        binding_affinity2 = None
        if wants_binding1 and len(prot_seqs) == 1:
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
        elif wants_binding1 and wants_binding2 and len(prot_seqs) == 2:
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
            binding_affinity2 = BindingAffinity(prot_seqs[1], tokenizer=tokenizer, base_path=base_path, device=device)

        permeability = Permeability(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        sol = Solubility(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        nonfouling = Nonfouling(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        hemo = Hemolysis(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)

        self.all_funcs = {'binding_affinity1': binding_affinity1,
                          'binding_affinity2': binding_affinity2,
                          'permeability': permeability,
                          'nonfouling': nonfouling,
                          'solubility': sol,
                          'hemolysis': hemo
                          }

    def forward(self, input_seqs):
        """Score ``input_seqs`` with every selected function.

        Returns:
            float32 array, transposed so that rows follow the sequences and
            columns follow ``self.score_func_names``.

        Raises:
            KeyError: if a selected name is not a key of ``self.all_funcs``.
            TypeError: if a selected binding scorer was left as ``None``.
        """
        scores = [self.all_funcs[name](input_seqs=input_seqs)
                  for name in self.score_func_names]

        # convert to numpy arrays with shape (num_sequences, num_functions)
        return np.float32(scores).T

    def __call__(self, input_seqs: list):
        """Alias for :meth:`forward`."""
        return self.forward(input_seqs)
|
a2d2_pep/pep_scoring/tokenizer/my_tokenizers.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
from transformers import PreTrainedTokenizer
|
| 6 |
+
from SmilesPE.tokenizer import SPE_Tokenizer
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
def load_vocab(vocab_file):
    """Read a one-token-per-line vocabulary file into an ordered mapping.

    Each line, stripped of its trailing newline, becomes a key whose value is
    the zero-based line number. If a token occurs more than once, the last
    line number wins.
    """
    with open(vocab_file, "r", encoding="utf-8") as handle:
        lines = handle.readlines()
    return collections.OrderedDict(
        (line.rstrip("\n"), line_no) for line_no, line in enumerate(lines)
    )
|
| 18 |
+
|
| 19 |
+
class Atomwise_Tokenizer(object):
    """Regex-based, atom-level tokenizer for SMILES strings."""

    def __init__(self):
        """Compile the tokenization pattern.

        The first alternative keeps a parenthesised group with at most four
        inner characters (e.g. ``(=O)``) as a single token; the remaining
        alternatives match bracket atoms, element symbols, bonds, ring-closure
        digits and other single SMILES symbols.
        """
        self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
        self.regex = re.compile(self.regex_pattern)

    def tokenize(self, text):
        """Return every pattern match in ``text``, in order, as a list of strings."""
        return list(self.regex.findall(text))
|
| 35 |
+
|
| 36 |
+
class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer`; refer to the superclass for
    the methods that are not overridden here. :meth:`encode`, :meth:`decode` and :meth:`batch_decode` are
    replaced by simplified versions whose signatures differ from the superclass.
    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary, one token per line.
        spe_file (:obj:`string`):
            File containing the trained SMILES Pair Encoding vocabulary.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token, substituted for tokens missing from the vocabulary.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token; also the last token of a sequence built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token; the first token of a sequence built with special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values.
    """

    def __init__(self, vocab_file, spe_file,
                 unk_token="[UNK]",
                 sep_token="[SEP]",
                 pad_token="[PAD]",
                 cls_token="[CLS]",
                 mask_token="[MASK]",
                 **kwargs):
        if not os.path.isfile(vocab_file):
            raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
        if not os.path.isfile(spe_file):
            raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))

        # The vocabulary is populated before super().__init__() is called.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])

        # Close the SPE file once the tokenizer has been built instead of
        # keeping the handle open for the lifetime of this object.
        # NOTE(review): relies on SPE_Tokenizer reading the file inside its
        # constructor - confirm against the installed SmilesPE version.
        with open(spe_file, 'r', encoding='utf-8') as spe_handle:
            self.spe_tokenizer = SPE_Tokenizer(spe_handle)
        # Attribute kept for backward compatibility; the handle is now closed.
        self.spe_vocab = spe_handle

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs)

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the base vocabulary merged with the added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        """Split ``text`` with the SPE tokenizer (its output is space separated)."""
        return self.spe_tokenizer.tokenize(text).split(' ')

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    # changed encode and decode functions
    def encode(self, token_array):
        """Convert pre-split tokens to a single-row id tensor.

        The ids are framed by the hard-coded ids 2 (start) and 3 (end).
        NOTE(review): 2 and 3 presumably are the [CLS]/[SEP] rows of the
        vocabulary file - confirm.

        Returns:
            dict with 'input_ids' of shape ``(1, len(token_array) + 2)`` and
            an all-ones 'attention_mask' of the same shape.
        """
        inner_ids = [self._convert_token_to_id(token) for token in token_array]
        token_ids = torch.tensor([[2, *inner_ids, 3]])
        attn_mask = torch.ones_like(token_ids)
        return {'input_ids': token_ids, 'attention_mask': attn_mask}

    def decode(self, token_ids, skip_special_tokens=True):
        """Turn one id tensor back into a string.

        Decoding stops at the first id 3. With ``skip_special_tokens`` the
        ids listed in ``self.all_special_ids`` are dropped. Tokens are joined
        without separators.
        """
        token_ids = token_ids.squeeze(0).cpu().tolist()
        # Looked up once rather than on every loop iteration.
        special_ids = set(self.all_special_ids) if skip_special_tokens else set()
        token_array = []
        for idx in token_ids:
            if idx == 3:  # Stop decoding when token ID 3 is encountered
                break
            if idx in special_ids:
                continue
            token_array.append(self._convert_id_to_token(idx))
        return "".join(token_array)

    def batch_decode(self, batch_token_ids, skip_special_tokens=True):
        """Decode every entry of ``batch_token_ids`` with :meth:`decode`."""
        # skip_special_tokens is forwarded; it used to be silently ignored.
        return [self.decode(token_ids, skip_special_tokens=skip_special_tokens)
                for token_ids in batch_token_ids]

    def get_token_split(self, token_ids):
        """Map a batch of id sequences to the corresponding token strings.

        Accepts a tensor or nested lists; returns a list of token lists.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.cpu().tolist()

        return [[self._convert_id_to_token(token_id) for token_id in seq_ids]
                for seq_ids in token_ids]

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Add special tokens around one sequence or a pair of sequences.
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: input IDs with the special tokens added.
        """
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Build a mask marking special-token positions (1) versus sequence tokens (0).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model
        Returns:
            :obj:`List[int]`: 1 for a special token, 0 for a sequence token.
        Raises:
            ValueError: if ``token_ids_1`` is given together with ``already_has_special_tokens``.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            markers = [self.sep_token_id, self.cls_token_id]
            return [1 if token_id in markers else 0 for token_id in token_ids_0]

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create token type ids for one sequence or a pair of sequences.
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |
        The first segment covers ``[CLS] A [SEP]`` (zeros); the second covers ``B [SEP]`` (ones).
        If token_ids_1 is None, only the first portion (0's) is returned.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: token type IDs for the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        first_segment = len(cls + token_ids_0 + sep) * [0]
        if token_ids_1 is None:
            return first_segment
        return first_segment + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path, filename_prefix=None):
        """
        Write the vocabulary, one token per line, ordered by token index.
        Args:
            vocab_path (:obj:`str`):
                Target file path, or an existing directory (the file is then named ``vocab.txt``).
            filename_prefix (:obj:`str`, `optional`):
                Prefix prepended (with a dash) to the file name, as passed by the superclass API.
        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        prefix = filename_prefix + "-" if filename_prefix else ""
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, prefix + "vocab.txt")
        elif prefix:
            head, tail = os.path.split(vocab_path)
            vocab_file = os.path.join(head, prefix + tail)
        else:
            vocab_file = vocab_path

        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, _ in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                writer.write(token + "\n")
        return (vocab_file,)
|
| 252 |
+
|
| 253 |
+
class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs an atom-level SMILES tokenizer that splits text with :class:`Atomwise_Tokenizer`.
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding methods.
    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary, one token per line.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token, substituted for tokens missing from the vocabulary.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token; also the last token of a sequence built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token; the first token of a sequence built with special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'.".format(vocab_file)
            )
        # Vocabulary state is created BEFORE super().__init__(), matching
        # SMILES_SPE_Tokenizer in this file.
        # NOTE(review): recent transformers releases appear to query the
        # vocabulary from the base constructor, which would fail with the
        # attributes assigned afterwards - confirm for the pinned version.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.tokenizer = Atomwise_Tokenizer()

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the base vocabulary merged with the added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        """Split ``text`` into atom-level tokens."""
        return self.tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Add special tokens around one sequence or a pair of sequences.
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: input IDs with the special tokens added.
        """
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Build a mask marking special-token positions (1) versus sequence tokens (0).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model
        Returns:
            :obj:`List[int]`: 1 for a special token, 0 for a sequence token.
        Raises:
            ValueError: if ``token_ids_1`` is given together with ``already_has_special_tokens``.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            markers = [self.sep_token_id, self.cls_token_id]
            return [1 if token_id in markers else 0 for token_id in token_ids_0]

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create token type ids for one sequence or a pair of sequences.
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |
        The first segment covers ``[CLS] A [SEP]`` (zeros); the second covers ``B [SEP]`` (ones).
        If token_ids_1 is None, only the first portion (0's) is returned.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: token type IDs for the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        first_segment = len(cls + token_ids_0 + sep) * [0]
        if token_ids_1 is None:
            return first_segment
        return first_segment + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path, filename_prefix=None):
        """
        Write the vocabulary, one token per line, ordered by token index.
        Args:
            vocab_path (:obj:`str`):
                Target file path, or an existing directory (the file is then named ``vocab.txt``).
            filename_prefix (:obj:`str`, `optional`):
                Prefix prepended (with a dash) to the file name, as passed by the superclass API.
        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        prefix = filename_prefix + "-" if filename_prefix else ""
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, prefix + "vocab.txt")
        elif prefix:
            head, tail = os.path.split(vocab_path)
            vocab_file = os.path.join(head, prefix + tail)
        else:
            vocab_file = vocab_path

        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, _ in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                writer.write(token + "\n")
        return (vocab_file,)
|
a2d2_pep/pep_utils/analyzer.py
ADDED
|
@@ -0,0 +1,1274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from io import StringIO
|
| 5 |
+
import rdkit
|
| 6 |
+
from rdkit import Chem
|
| 7 |
+
from rdkit.Chem import AllChem, Draw
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import matplotlib.patches as patches
|
| 12 |
+
from io import BytesIO
|
| 13 |
+
import tempfile
|
| 14 |
+
from rdkit import Chem
|
| 15 |
+
|
| 16 |
+
class PeptideAnalyzer:
|
| 17 |
+
def __init__(self, min_peptide_bonds=2, enforce_min_peptide_bonds=True):
|
| 18 |
+
# length cutoff: minimum number of backbone residues (N-Cα-C(=O) units)
|
| 19 |
+
|
| 20 |
+
self.min_peptide_bonds = min_peptide_bonds
|
| 21 |
+
self.enforce_min_peptide_bonds = enforce_min_peptide_bonds
|
| 22 |
+
self.bond_patterns = [
|
| 23 |
+
(r'OC\(=O\)', 'ester'), # Ester bond
|
| 24 |
+
(r'N\(C\)C\(=O\)', 'n_methyl'), # N-methylated peptide bond
|
| 25 |
+
(r'N[0-9]C\(=O\)', 'proline'), # Proline peptide bond
|
| 26 |
+
(r'NC\(=O\)', 'peptide'), # Standard peptide bond
|
| 27 |
+
(r'C\(=O\)N\(C\)', 'n_methyl_reverse'), # Reverse N-methylated
|
| 28 |
+
(r'C\(=O\)N[12]?', 'peptide_reverse') # Reverse peptide bond
|
| 29 |
+
]
|
| 30 |
+
# Three to one letter code mapping
|
| 31 |
+
self.three_to_one = {
|
| 32 |
+
'Ala': 'A', 'Cys': 'C', 'Asp': 'D', 'Glu': 'E',
|
| 33 |
+
'Phe': 'F', 'Gly': 'G', 'His': 'H', 'Ile': 'I',
|
| 34 |
+
'Lys': 'K', 'Leu': 'L', 'Met': 'M', 'Asn': 'N',
|
| 35 |
+
'Pro': 'P', 'Gln': 'Q', 'Arg': 'R', 'Ser': 'S',
|
| 36 |
+
'Thr': 'T', 'Val': 'V', 'Trp': 'W', 'Tyr': 'Y'
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
def count_peptide_bonds(self, smiles):
    """Count backbone residues, i.e. N-Cα-C(=O) units, in a SMILES string.

    Despite the method name, the value returned is the number of unique
    matches of the SMARTS pattern ``[NX3][CX4][CX3](=O)``: a three-connected
    nitrogen bonded to a four-connected carbon bonded to a carbonyl carbon.
    Requiring the sp3 carbon between N and C=O means an amide whose nitrogen
    is attached directly to the carbonyl (N-C(=O)-N, for example) does not
    match on its own. ``uniquify=True`` makes RDKit report each matched atom
    set once.

    Args:
        smiles: SMILES string to analyse.

    Returns:
        The number of matches, or 0 when RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Unparseable input is treated as "no backbone" rather than an error.
        return 0
    backbone_pattern = Chem.MolFromSmarts('[NX3][CX4][CX3](=O)')
    backbone_matches = mol.GetSubstructMatches(backbone_pattern, uniquify=True)
    return len(backbone_matches)
|
| 54 |
+
|
| 55 |
+
def is_peptide(self, smiles):
    """Check whether the SMILES looks like a peptide.

    The decision is based solely on ``count_peptide_bonds``, which returns 0
    for SMILES that RDKit cannot parse, so invalid input is rejected here
    without parsing the string a second time.

    Args:
        smiles: SMILES string to test.

    Returns:
        False when no backbone unit is found, or when the length cutoff is
        enforced and the count is below ``self.min_peptide_bonds``; True
        otherwise.
    """
    # Count backbone residues (N-Cα-C(=O) units); zero also covers invalid SMILES.
    n_residues = self.count_peptide_bonds(smiles)
    if n_residues == 0:
        return False

    # Length cutoff: reject molecules with too few backbone residues.
    if self.enforce_min_peptide_bonds and n_residues < self.min_peptide_bonds:
        return False

    return True
|
| 72 |
+
|
| 73 |
+
def is_cyclic(self, smiles):
    """Heuristically decide from the SMILES text whether the peptide is cyclic.

    Args:
        smiles: SMILES string to inspect.

    Returns:
        A tuple ``(is_cyclic, peptide_cycles, aromatic_cycles)``:

        * ``is_cyclic`` - True when at least one non-aromatic ring token was
          found. Always False when ``smiles`` ends with ``C(=O)O``.
        * ``peptide_cycles`` - ring tokens not listed in ``aromatic_cycles``.
        * ``aromatic_cycles`` - the digits found inside benzene-like
          (``cNcccccN``) or ``cNc[nH]cN`` fragments.
    """
    # A free C-terminal carboxyl is taken to mean the peptide is linear.
    if smiles.endswith('C(=O)O'):
        return False, [], []

    # Ring-closure digits that do not follow an aromatic carbon.
    # NOTE(review): the pattern has no capture group, so each token includes
    # the character before the digit (e.g. 'N1'), except for a digit at index
    # 0. Such tokens never equal the bare digits in `aromatic_cycles`, so the
    # filter below does not actually remove anything - confirm whether a
    # capture group was intended.
    ring_numbers = re.findall(r'(?:^|[^c])[0-9](?=[A-Z@\(\)])', smiles)

    # Digits used by simple aromatic rings.
    aromatic_cycles = []
    for fragment in re.findall(r'c[0-9](?:ccccc|c\[nH\]c)[0-9]', smiles):
        aromatic_cycles.extend(re.findall(r'[0-9]', fragment))

    # Numbers that aren't part of aromatic rings are peptide cycles.
    peptide_cycles = [n for n in ring_numbers if n not in aromatic_cycles]

    return len(peptide_cycles) > 0, peptide_cycles, aromatic_cycles
|
| 94 |
+
|
| 95 |
+
def split_on_bonds(self, smiles):
    """Split a SMILES string into residue segments at bond-pattern matches.

    The Gly motif ``NCC(=O)`` is searched first, then every pattern in
    ``self.bond_patterns`` in order. A match is accepted only if none of its
    characters were claimed by an earlier accepted match.

    Args:
        smiles: SMILES string to split.

    Returns:
        A list of dicts in string order. Each has ``'content'`` plus
        ``'bond_before'`` and/or ``'bond_after'`` holding the matched text of
        the neighbouring bond. The list is empty when no pattern matches.

        Behaviour worth knowing when reading the output:

        * a Gly match yields a segment with the fixed content ``'NCC(=O)'``,
          and the text between it and the next bond is not emitted;
        * the last bond never yields a Gly segment - only the text after it
          is emitted;
        * empty stretches between two adjacent bonds are skipped.
    """
    claimed = set()  # string indices already covered by an accepted match
    bonds = []

    # Gly first so that it wins any overlap with the generic bond patterns.
    for pattern, bond_type in [(r'NCC\(=O\)', 'gly'), *self.bond_patterns]:
        for match in re.finditer(pattern, smiles):
            span = range(match.start(), match.end())
            if claimed.isdisjoint(span):
                bonds.append({
                    'start': match.start(),
                    'end': match.end(),
                    'type': bond_type,
                    'pattern': match.group(),
                })
                claimed.update(span)

    # Sort by position (stable, so acceptance order breaks ties).
    bonds.sort(key=lambda bond: bond['start'])

    if not bonds:
        return []

    segments = []
    first_bond, last_bond = bonds[0], bonds[-1]

    # Text before the first bond.
    if first_bond['start'] > 0:
        segments.append({
            'content': smiles[0:first_bond['start']],
            'bond_after': first_bond['pattern'],
        })

    # One segment per consecutive pair of bonds.
    for i, (current, next_pos) in enumerate(zip(bonds, bonds[1:])):
        if current['type'] == 'gly':
            segments.append({
                'content': 'NCC(=O)',
                'bond_before': bonds[i - 1]['pattern'] if i > 0 else None,
                'bond_after': next_pos['pattern'],
            })
            continue

        content = smiles[current['end']:next_pos['start']]
        if content:
            segments.append({
                'content': content,
                'bond_before': current['pattern'],
                'bond_after': next_pos['pattern'],
            })

    # Text after the last bond.
    if last_bond['end'] < len(smiles):
        segments.append({
            'content': smiles[last_bond['end']:],
            'bond_before': last_bond['pattern'],
        })

    return segments
|
| 165 |
+
|
| 166 |
+
def clean_terminal_carboxyl(self, segment):
    """Strip a parenthesised carboxyl group from the last segment.

    Cleaning is attempted only when the content contains ``C(=O)O`` and the
    segment has no ``bond_after`` (i.e. nothing follows it). In that case every
    literal ``(C(=O)O)`` is removed, wherever it occurs in the content, and
    any ``()`` left behind is removed as well. A ``C(=O)O`` that is not
    wrapped in parentheses is left in place.

    Args:
        segment: Segment dict with a ``'content'`` string and an optional
            ``'bond_after'`` entry.

    Returns:
        The cleaned content, or the original content when the segment does
        not qualify.
    """
    content = segment['content']

    if 'C(=O)O' not in content or segment.get('bond_after'):
        return content

    # Debug trace, kept as in the existing code.
    print('recognized?')
    cleaned = re.sub(r'\(C\(=O\)O\)', '', content)
    # Remove any leftover empty parentheses.
    cleaned = re.sub(r'\(\)', '', cleaned)
    print(cleaned)
    return cleaned
|
| 183 |
+
|
| 184 |
+
def identify_residue(self, segment):
    """Identify the residue encoded by one segment of a peptide SMILES.

    The segment content (after ``clean_terminal_carboxyl``) is tested against
    a long, ordered list of literal substrings; the first test that succeeds
    decides the label. Unusual amino acids are tested before the regular
    ones, so ordering is part of the behaviour.

    Args:
        segment: Segment dict as produced by ``split_on_bonds``
            (``'content'`` plus optional ``'bond_before'``/``'bond_after'``).

    Returns:
        ``(label, mods)`` where ``label`` is the residue code, or ``None``
        when nothing matched, and ``mods`` is the list returned by
        ``get_modifications`` for the segment.

    NOTE(review): several later tests are shadowed by earlier ones and can
    never fire as written - the second ``[SeH]`` and ``[NH3]CC[C@@H]`` tests,
    ``'LVG'`` (same substring as ``'2AG'``), ``'MOT5'`` (its substring
    contains the ``'0AF'`` one), ``'His'`` (its substring contains the
    ``'HIS'`` one) and the ``CCC[C@H]n`` proline forms (caught by ``'NLE'``).
    They are kept unchanged here; confirm the intended priority.
    """
    # Only clean terminal carboxyl if this is the last segment
    content = self.clean_terminal_carboxyl(segment)
    mods = self.get_modifications(segment)

    # UAA pattern matching section - before regular residues
    # Phenylglycine and derivatives
    if 'c1ccccc1' in content:
        if '[C@@H](c1ccccc1)' in content or '[C@H](c1ccccc1)' in content:
            return '4', mods  # Base phenylglycine

    # 4-substituted phenylalanines
    if 'Cc1ccc' in content:
        if 'OMe' in content or 'OCc1ccc' in content:
            return '0A1', mods  # 4-methoxy-Phenylalanine
        elif 'Clc1ccc' in content:
            return '200', mods  # 4-chloro-Phenylalanine
        elif 'Brc1ccc' in content:
            return '4BF', mods  # 4-Bromo-phenylalanine
        elif 'C#Nc1ccc' in content:
            return '4CF', mods  # 4-cyano-phenylalanine
        elif 'Ic1ccc' in content:
            return 'PHI', mods  # 4-Iodo-phenylalanine
        elif 'Fc1ccc' in content:
            return 'PFF', mods  # 4-Fluoro-phenylalanine

    # Modified tryptophans
    if 'c[nH]c2' in content:
        if 'Oc2cccc2' in content:
            return '0AF', mods  # 7-hydroxy-tryptophan
        elif 'Fc2cccc2' in content:
            return '4FW', mods  # 4-fluoro-tryptophan
        elif 'Clc2cccc2' in content:
            return '6CW', mods  # 6-chloro-tryptophan
        elif 'Brc2cccc2' in content:
            return 'BTR', mods  # 6-bromo-tryptophan
        elif 'COc2cccc2' in content:
            return 'MOT5', mods  # 5-Methoxy-tryptophan
        elif 'Cc2cccc2' in content:
            return 'MTR5', mods  # 5-Methyl-tryptophan

    # Special amino acids
    if 'CC(C)(C)[C@@H]' in content or 'CC(C)(C)[C@H]' in content:
        return 'BUG', mods  # Tertleucine

    # Citrulline carries a ureido group (C=O). The guanidine substring
    # 'CCCNC(=N)N' belongs to arginine and previously sent every Arg here,
    # leaving both 'Arg' branches further down unreachable.
    if 'CCCNC(=O)N' in content:
        return 'CIR', mods  # Citrulline

    if '[SeH]' in content:
        return 'CSE', mods  # Selenocysteine

    if '[NH3]CC[C@@H]' in content or '[NH3]CC[C@H]' in content:
        return 'DAB', mods  # Diaminobutyric acid

    if 'C1CCCCC1' in content:
        if 'C1CCCCC1[C@@H]' in content or 'C1CCCCC1[C@H]' in content:
            return 'CHG', mods  # Cyclohexylglycine
        elif 'C1CCCCC1C[C@@H]' in content or 'C1CCCCC1C[C@H]' in content:
            return 'ALC', mods  # 3-cyclohexyl-alanine

    # Naphthalene derivatives
    if 'c1cccc2c1cccc2' in content:
        if 'c1cccc2c1cccc2[C@@H]' in content or 'c1cccc2c1cccc2[C@H]' in content:
            return 'NAL', mods  # 2-Naphthyl-alanine

    # Heteroaromatic derivatives
    if 'c1cncc' in content:
        return 'PYR4', mods  # 3-(4-Pyridyl)-alanine
    if 'c1cscc' in content:
        return 'THA3', mods  # 3-(3-thienyl)-alanine
    if 'c1nnc' in content:
        return 'TRZ4', mods  # 3-(1,2,4-Triazol-1-yl)-alanine

    # Modified serines and threonines
    if 'OP(O)(O)O' in content:
        if '[C@@H](COP' in content or '[C@H](COP' in content:
            return 'SEP', mods  # phosphoserine
        elif '[C@@H](OP' in content or '[C@H](OP' in content:
            return 'TPO', mods  # phosphothreonine

    # Specialized ring systems
    if 'c1c2ccccc2cc2c1cccc2' in content:
        return 'ANTH', mods  # 3-(9-anthryl)-alanine
    if 'c1csc2c1cccc2' in content:
        return 'BTH3', mods  # 3-(3-benzothienyl)-alanine
    if '[C@]12C[C@H]3C[C@@H](C2)C[C@@H](C1)C3' in content:
        return 'ADAM', mods  # Adamanthane

    # Fluorinated derivatives
    if 'FC(F)(F)' in content:
        if 'CC(F)(F)F' in content:
            return 'FLA', mods  # Trifluoro-alanine
        if 'C(F)(F)F)c1' in content:
            if 'c1ccccc1C(F)(F)F' in content:
                return 'TFG2', mods  # 2-(Trifluoromethyl)-phenylglycine
            if 'c1cccc(c1)C(F)(F)F' in content:
                return 'TFG3', mods  # 3-(Trifluoromethyl)-phenylglycine
            if 'c1ccc(cc1)C(F)(F)F' in content:
                return 'TFG4', mods  # 4-(Trifluoromethyl)-phenylglycine

    # Multiple halogen patterns
    if 'F' in content and 'c1' in content:
        if 'c1ccc(c(c1)F)F' in content:
            return 'F2F', mods  # 3,4-Difluoro-phenylalanine
        if 'cc(F)cc(c1)F' in content:
            return 'WFP', mods  # 3,5-Difluoro-phenylalanine
    if 'Cl' in content and 'c1' in content:
        if 'c1ccc(cc1Cl)Cl' in content:
            return 'CP24', mods  # 2,4-dichloro-phenylalanine
        if 'c1ccc(c(c1)Cl)Cl' in content:
            return 'CP34', mods  # 3,4-dichloro-phenylalanine

    # Hydroxy and amino derivatives
    if 'O' in content and 'c1' in content:
        if 'c1cc(O)cc(c1)O' in content:
            return '3FG', mods  # (2s)-amino(3,5-dihydroxyphenyl)-ethanoic acid
        if 'c1ccc(c(c1)O)O' in content:
            return 'DAH', mods  # 3,4-Dihydroxy-phenylalanine

    # Cyclic amino acids
    if 'C1CCCC1' in content:
        return 'CPA3', mods  # 3-Cyclopentyl-alanine
    if 'C1CCCCC1' in content:
        if 'CC1CCCCC1' in content:
            return 'ALC', mods  # 3-cyclohexyl-alanine
        else:
            return 'CHG', mods  # Cyclohexylglycine

    # Chain-length variants
    if 'CCC[C@@H]' in content or 'CCC[C@H]' in content:
        return 'NLE', mods  # Norleucine
    if 'CC[C@@H]' in content or 'CC[C@H]' in content:
        if not any(x in content for x in ['CC(C)', 'COC', 'CN(']):
            return 'ABA', mods  # 2-Aminobutyric acid

    # Modified histidines
    if 'c1cnc' in content:
        if '[C@@H]1CN[C@@H](N1)F' in content:
            return '2HF', mods  # 2-fluoro-l-histidine
        if 'c1cnc([nH]1)F' in content:
            return '2HF1', mods  # 2-fluoro-l-histidine variant
        if 'c1c[nH]c(n1)F' in content:
            return '2HF2', mods  # 2-fluoro-l-histidine variant

    # Sulfur and selenium containing
    if '[SeH]' in content:
        return 'CSE', mods  # Selenocysteine
    if 'S' in content:
        if 'CSCc1ccccc1' in content:
            return 'BCS', mods  # benzylcysteine
        if 'CCSC' in content:
            return 'ESC', mods  # Ethionine
        if 'CCS' in content:
            return 'HCS', mods  # homocysteine

    # Additional modifications
    if 'CN=[N]=N' in content:
        return 'AZDA', mods  # azido-alanine
    if '[NH]=[C](=[NH2])=[NH2]' in content:
        if 'CCC[NH]=' in content:
            return 'AGM', mods  # 5-methyl-arginine
        if 'CC[NH]=' in content:
            return 'GDPR', mods  # 2-Amino-3-guanidinopropionic acid

    if 'CCON' in content:
        return 'CAN', mods  # canaline
    if '[C@@H]1C=C[C@@H](C=C1)' in content:
        return 'ACZ', mods  # cis-amiclenomycin
    if 'CCC(=O)[NH3]' in content:
        return 'ONL', mods  # 5-oxo-l-norleucine
    if 'c1ccncc1' in content:
        return 'PYR4', mods  # 3-(4-Pyridyl)-alanine
    if 'c1ccco1' in content:
        return 'FUA2', mods  # (2-furyl)-alanine

    if 'c1ccc' in content:
        if 'c1ccc(cc1)c1ccccc1' in content:
            return 'BIF', mods  # 4,4-biphenylalanine
        if 'c1ccc(cc1)C(=O)c1ccccc1' in content:
            return 'PBF', mods  # 4-benzoyl-phenylalanine
        if 'c1ccc(cc1)C(C)(C)C' in content:
            return 'TBP4', mods  # 4-tert-butyl-phenylalanine
        if 'c1ccc(cc1)[C](=[NH2])=[NH2]' in content:
            return '0BN', mods  # 4-carbamimidoyl-l-phenylalanine
        if 'c1cccc(c1)[C](=[NH2])=[NH2]' in content:
            return 'APM', mods  # m-amidinophenyl-3-alanine

    # Multiple hydroxy patterns
    if 'O' in content:
        if '[C@H]([C@H](C)O)O' in content:
            return 'ILX', mods  # 4,5-dihydroxy-isoleucine
        if '[C@H]([C@@H](C)O)O' in content:
            return 'ALO', mods  # Allo-threonine
        if '[C@H](COP(O)(O)O)' in content:
            return 'SEP', mods  # phosphoserine
        if '[C@H]([C@@H](C)OP(O)(O)O)' in content:
            return 'TPO', mods  # phosphothreonine
        if '[C@H](c1ccc(O)cc1)O' in content:
            return 'OMX', mods  # (betar)-beta-hydroxy-l-tyrosine
        if '[C@H](c1ccc(c(Cl)c1)O)O' in content:
            return 'OMY', mods  # (betar)-3-chloro-beta-hydroxy-l-tyrosine

    # Heterocyclic patterns
    if 'n1' in content:
        if 'n1cccn1' in content:
            return 'PYZ1', mods  # 3-(1-Pyrazolyl)-alanine
        if 'n1nncn1' in content:
            return 'TEZA', mods  # 3-(2-Tetrazolyl)-alanine
        if 'c2c(n1)cccc2' in content:
            return 'QU32', mods  # 3-(2-Quinolyl)-alanine
        if 'c1cnc2c(c1)cccc2' in content:
            return 'QU33', mods  # 3-(3-quinolyl)-alanine
        if 'c1ccnc2c1cccc2' in content:
            return 'QU34', mods  # 3-(4-quinolyl)-alanine
        if 'c1ccc2c(c1)nccc2' in content:
            return 'QU35', mods  # 3-(5-Quinolyl)-alanine
        if 'c1ccc2c(c1)cncc2' in content:
            return 'QU36', mods  # 3-(6-Quinolyl)-alanine
        if 'c1cnc2c(n1)cccc2' in content:
            return 'QX32', mods  # 3-(2-quinoxalyl)-alanine

    # Multiple nitrogen patterns
    if 'N' in content:
        if '[NH3]CC[C@@H]' in content:
            return 'DAB', mods  # Diaminobutyric acid
        if '[NH3]C[C@@H]' in content:
            return 'DPP', mods  # 2,3-Diaminopropanoic acid
        if '[NH3]CCCCCC[C@@H]' in content:
            return 'HHK', mods  # (2s)-2,8-diaminooctanoic acid
        if 'CCC[NH]=[C](=[NH2])=[NH2]' in content:
            return 'GBUT', mods  # 2-Amino-4-guanidinobutryric acid
        if '[NH]=[C](=S)=[NH2]' in content:
            return 'THIC', mods  # Thio-citrulline

    # Chain modified amino acids
    if 'CC' in content:
        if 'CCCC[C@@H]' in content:
            return 'AHP', mods  # 2-Aminoheptanoic acid
        if 'CCC([C@@H])(C)C' in content:
            return 'I2M', mods  # 3-methyl-l-alloisoleucine
        if 'CC[C@H]([C@@H])C' in content:
            return 'IIL', mods  # Allo-Isoleucine
        if '[C@H](CCC(C)C)' in content:
            return 'HLEU', mods  # Homoleucine
        if '[C@@H]([C@@H](C)O)C' in content:
            return 'HLU', mods  # beta-hydroxyleucine

    # Modified glutamate/aspartate patterns
    if '[C@@H]' in content:
        if '[C@@H](C[C@@H](F))' in content:
            return 'FGA4', mods  # 4-Fluoro-glutamic acid
        if '[C@@H](C[C@@H](O))' in content:
            return '3GL', mods  # 4-hydroxy-glutamic-acid
        if '[C@@H](C[C@H](C))' in content:
            return 'LME', mods  # (3r)-3-methyl-l-glutamic acid
        if '[C@@H](CC[C@H](C))' in content:
            return 'MEG', mods  # (3s)-3-methyl-l-glutamic acid

    # Sulfur and selenium modifications
    if 'S' in content:
        if 'SCC[C@@H]' in content:
            return 'HSER', mods  # homoserine
        if 'SCCN' in content:
            return 'SLZ', mods  # thialysine
        if 'SC(=O)' in content:
            return 'CSA', mods  # s-acetonylcysteine
        if '[S@@](=O)' in content:
            return 'SME', mods  # Methionine sulfoxide
        if 'S(=O)(=O)' in content:
            return 'OMT', mods  # Methionine sulfone

    # Double bond containing
    if 'C=' in content:
        if 'C=C[C@@H]' in content:
            return '2AG', mods  # 2-Allyl-glycine
        if 'C=C[C@@H]' in content:
            return 'LVG', mods  # vinylglycine
        if 'C=Cc1ccccc1' in content:
            return 'STYA', mods  # Styrylalanine

    # Special cases
    if '[C@@H]1Cc2c(C1)cccc2' in content:
        return 'IGL', mods  # alpha-amino-2-indanacetic acid
    if '[C](=[C](=O)=O)=O' in content:
        return '26P', mods  # 2-amino-6-oxopimelic acid
    if '[C](=[C](=O)=O)=C' in content:
        return '2NP', mods  # l-2-amino-6-methylene-pimelic acid
    if 'c2cnc[nH]2' in content:
        return 'HIS', mods  # histidine core
    if 'c1cccc2c1cc(O)cc2' in content:
        return 'NAO1', mods  # 5-hydroxy-1-naphthalene
    if 'c1ccc2c(c1)cc(O)cc2' in content:
        return 'NAO2', mods  # 6-hydroxy-2-naphthalene

    # Proline (P) - flexible ring numbers
    if any([
        # Check for any ring number in bond patterns
        (segment.get('bond_after', '').startswith(f'N{n}C(=O)') and 'CCC' in content and
         any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
        for n in '123456789'
    ]) or any([
        # Check ending patterns with any ring number
        (f'CCCN{n}' in content and content.endswith('=O') and
         any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
        for n in '123456789'
    ]) or any([
        # Handle CCC[C@H]n patterns
        (content == f'CCC[C@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
        (content == f'CCC[C@@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
        # N-terminal Pro with any ring number
        (f'N{n}CCC[C@H]{n}' in content) or
        (f'N{n}CCC[C@@H]{n}' in content)
        for n in '123456789'
    ]):
        return 'Pro', mods

    # Tryptophan (W) - more specific indole pattern
    if re.search(r'c[0-9]c\[nH\]c[0-9]ccccc[0-9][0-9]', content) and \
            'c[nH]c' in content.replace(' ', ''):
        return 'Trp', mods

    # Lysine (K) - both patterns
    if '[C@@H](CCCCN)' in content or '[C@H](CCCCN)' in content:
        return 'Lys', mods

    # Arginine (R) - both patterns
    if '[C@@H](CCCNC(=N)N)' in content or '[C@H](CCCNC(=N)N)' in content:
        return 'Arg', mods

    if ('C[C@H](CCCC)' in content or 'C[C@@H](CCCC)' in content) and 'CC(C)' not in content:
        return 'Nle', mods

    # Ornithine (Orn) - 3-carbon chain with NH2
    if ('C[C@H](CCCN)' in content or 'C[C@@H](CCCN)' in content) and 'CC(C)' not in content:
        return 'Orn', mods

    # 2-Naphthylalanine (2Nal) - distinct from Phe pattern
    if ('Cc3cc2ccccc2c3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return '2Nal', mods

    # Cyclohexylalanine (Cha)
    if 'N2CCCCC2' in content or 'CCCCC2' in content:
        return 'Cha', mods

    # Aminobutyric acid (Abu) - 2-carbon chain
    if ('C[C@H](CC)' in content or 'C[C@@H](CC)' in content) and not any(p in content for p in ['CC(C)', 'CCCC', 'CCC(C)']):
        return 'Abu', mods

    # Pipecolic acid (Pip) - 6-membered ring like Pro
    if ('N3CCCCC3' in content or 'CCCCC3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Pip', mods

    # Cyclohexylglycine (Chg) - direct cyclohexyl without CH2
    if ('C[C@H](C1CCCCC1)' in content or 'C[C@@H](C1CCCCC1)' in content):
        return 'Chg', mods

    # 4-Fluorophenylalanine (4F-Phe)
    if ('Cc2ccc(F)cc2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return '4F-Phe', mods

    # Regular residue identification
    if ('NCC(=O)' in content) or (content == 'C'):
        # Middle case - between bonds
        if segment.get('bond_before') and segment.get('bond_after'):
            if ('C(=O)N' in segment['bond_before'] or 'C(=O)N(C)' in segment['bond_before']):
                return 'Gly', mods
        # Terminal case - at the end
        elif segment.get('bond_before') and segment.get('bond_before').startswith('C(=O)N'):
            return 'Gly', mods

    if 'CC(C)C[C@H]' in content or 'CC(C)C[C@@H]' in content:
        return 'Leu', mods
    if '[C@@H](CC(C)C)' in content or '[C@H](CC(C)C)' in content:
        return 'Leu', mods

    if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content:
        return 'Thr', mods

    if '[C@H](Cc2ccccc2)' in content or '[C@@H](Cc2ccccc2)' in content:
        return 'Phe', mods

    if ('[C@H](C(C)C)' in content or  # With outer parentheses
            '[C@@H](C(C)C)' in content or  # With outer parentheses
            '[C@H]C(C)C' in content or  # Without outer parentheses
            '[C@@H]C(C)C' in content):  # Without outer parentheses
        if not any(p in content for p in ['CC(C)C[C@H]', 'CC(C)C[C@@H]']):  # Still check not Leu
            return 'Val', mods

    if '[C@H](COC(C)(C)C)' in content or '[C@@H](COC(C)(C)C)' in content:
        return 'O-tBu', mods

    if any([
        'CC[C@H](C)' in content,
        'CC[C@@H](C)' in content,
        'C(C)C[C@H]' in content and 'CC(C)C' not in content,
        'C(C)C[C@@H]' in content and 'CC(C)C' not in content
    ]):
        return 'Ile', mods

    if ('[C@H](C)' in content or '[C@@H](C)' in content):
        if not any(p in content for p in ['C(C)C', 'COC', 'CN(', 'C(C)O', 'CC[C@H]', 'CC[C@@H]']):
            return 'Ala', mods

    # Tyrosine (Tyr) - 4-hydroxybenzyl side chain
    if re.search(r'Cc[0-9]ccc\(O\)cc[0-9]', content):
        return 'Tyr', mods

    # Serine (Ser) - Hydroxymethyl side chain
    if '[C@H](CO)' in content or '[C@@H](CO)' in content:
        if not ('C(C)O' in content or 'COC' in content):
            return 'Ser', mods

    # Threonine (Thr) - 1-hydroxyethyl side chain
    if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content or '[C@@H](C)O' in content or '[C@H](C)O' in content:
        return 'Thr', mods

    # Cysteine (Cys) - Thiol side chain
    if '[C@H](CS)' in content or '[C@@H](CS)' in content:
        return 'Cys', mods

    # Methionine (Met) - Methylthioethyl side chain
    if ('C[C@H](CCSC)' in content or 'C[C@@H](CCSC)' in content):
        return 'Met', mods

    # Asparagine (Asn) - Carbamoylmethyl side chain
    if ('CC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Asn', mods

    # Glutamine (Gln) - Carbamoylethyl side chain
    if ('CCC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Gln', mods

    # Aspartic acid (Asp) - Carboxymethyl side chain
    if ('CC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Asp', mods

    # Glutamic acid (Glu) - Carboxyethyl side chain
    if ('CCC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Glu', mods

    # Arginine (Arg) - 3-guanidinopropyl side chain
    if ('CCCNC(=N)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Arg', mods

    # Histidine (His) - Imidazole side chain
    if ('Cc2cnc[nH]2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'His', mods

    return None, mods
|
| 635 |
+
|
| 636 |
+
def get_modifications(self, segment):
    """Return the modification tags implied by the bond that follows a segment.

    Args:
        segment: Mapping describing one segment; only its optional
            ``'bond_after'`` entry is read.

    Returns:
        A list that may contain ``'N-Me'`` (bond text contains ``N(C)``)
        and/or ``'O-linked'`` (bond text contains ``OC(=O)``), in that
        order. Empty when there is no following bond.
    """
    bond_after = segment.get('bond_after')
    found = []
    if not bond_after:
        # Missing, None or empty bond: nothing to report.
        return found

    is_n_methylated = (
        'N(C)' in bond_after or bond_after.startswith('C(=O)N(C)')
    )
    if is_n_methylated:
        found.append('N-Me')

    is_ester = 'OC(=O)' in bond_after
    if is_ester:
        found.append('O-linked')

    return found
|
| 645 |
+
|
| 646 |
+
def analyze_structure(self, smiles):
    """Analyse a peptide SMILES string and print a per-segment debug trace.

    The SMILES is split with ``self.split_on_bonds``; every segment is
    passed to ``self.identify_residue``. Identified residues (with their
    modifications in parentheses) are joined with ``-``; the result is
    wrapped in ``cyclo(...)`` when ``self.is_cyclic`` reports a cycle.

    Args:
        smiles: SMILES string of the peptide.

    Returns:
        ``(three_letter, n_segments)``: the three-letter sequence string
        and the number of segments produced by the split (including
        segments that could not be identified).
    """
    print("\nAnalyzing structure:", smiles)

    # Split into segments
    segments = self.split_on_bonds(smiles)

    print("\nSegment Analysis:")
    sequence = []
    for i, segment in enumerate(segments):
        print(f"\nSegment {i}:")
        print(f"Content: {segment['content']}")
        print(f"Bond before: {segment.get('bond_before', 'None')}")
        print(f"Bond after: {segment.get('bond_after', 'None')}")

        residue, mods = self.identify_residue(segment)
        if not residue:
            print(f"Warning: Could not identify residue in segment: {segment['content']}")
            continue

        sequence.append(f"{residue}({','.join(mods)})" if mods else residue)
        print(f"Identified as: {residue}")
        print(f"Modifications: {mods}")

    # Only the boolean flag is used here; the cycle lists are ignored.
    is_cyclic, _peptide_cycles, _aromatic_cycles = self.is_cyclic(smiles)
    three_letter = '-'.join(sequence)
    # Strip any "(mods)" suffix before the lookup; unknown residues map to 'X'.
    one_letter = ''.join(
        self.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence
    )

    if is_cyclic:
        three_letter = f"cyclo({three_letter})"
        one_letter = f"cyclo({one_letter})"

    print(f"\nFinal sequence: {three_letter}")
    print(f"One-letter code: {one_letter}")
    print(f"Is cyclic: {is_cyclic}")

    return three_letter, len(segments)
|
| 693 |
+
|
| 694 |
+
def return_sequence(self, smiles):
    """Return the list of identified residues for a peptide SMILES.

    Prints the same per-segment debug trace as the full analysis, but
    performs no cyclicity check and does not join the residues.

    Args:
        smiles: SMILES string of the peptide.

    Returns:
        List of residue labels in segment order; a residue with
        modifications is rendered as ``"Res(mod1,mod2)"``. Segments that
        cannot be identified are reported with a warning and omitted.
    """
    print("\nAnalyzing structure:", smiles)

    pieces = self.split_on_bonds(smiles)

    print("\nSegment Analysis:")
    labels = []
    for index, piece in enumerate(pieces):
        print(f"\nSegment {index}:")
        print(f"Content: {piece['content']}")
        print(f"Bond before: {piece.get('bond_before', 'None')}")
        print(f"Bond after: {piece.get('bond_after', 'None')}")

        residue, mods = self.identify_residue(piece)
        if not residue:
            print(f"Warning: Could not identify residue in segment: {piece['content']}")
            continue

        label = f"{residue}({','.join(mods)})" if mods else residue
        labels.append(label)
        print(f"Identified as: {residue}")
        print(f"Modifications: {mods}")

    return labels
|
| 721 |
+
|
| 722 |
+
"""
|
| 723 |
+
def annotate_cyclic_structure(mol, sequence):
|
| 724 |
+
'''Create annotated 2D structure with clear, non-overlapping residue labels'''
|
| 725 |
+
# Generate 2D coordinates
|
| 726 |
+
# Generate 2D coordinates
|
| 727 |
+
AllChem.Compute2DCoords(mol)
|
| 728 |
+
|
| 729 |
+
# Create drawer with larger size for annotations
|
| 730 |
+
drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000) # Even larger size
|
| 731 |
+
|
| 732 |
+
# Get residue list and reverse it to match structural representation
|
| 733 |
+
if sequence.startswith('cyclo('):
|
| 734 |
+
residues = sequence[6:-1].split('-')
|
| 735 |
+
else:
|
| 736 |
+
residues = sequence.split('-')
|
| 737 |
+
residues = list(reversed(residues)) # Reverse the sequence
|
| 738 |
+
|
| 739 |
+
# Draw molecule first to get its bounds
|
| 740 |
+
drawer.drawOptions().addAtomIndices = False
|
| 741 |
+
drawer.DrawMolecule(mol)
|
| 742 |
+
drawer.FinishDrawing()
|
| 743 |
+
|
| 744 |
+
# Convert to PIL Image
|
| 745 |
+
img = Image.open(BytesIO(drawer.GetDrawingText()))
|
| 746 |
+
draw = ImageDraw.Draw(img)
|
| 747 |
+
|
| 748 |
+
try:
|
| 749 |
+
# Try to use DejaVuSans as it's commonly available on Linux systems
|
| 750 |
+
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
|
| 751 |
+
small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
|
| 752 |
+
except OSError:
|
| 753 |
+
try:
|
| 754 |
+
# Fallback to Arial if available (common on Windows)
|
| 755 |
+
font = ImageFont.truetype("arial.ttf", 60)
|
| 756 |
+
small_font = ImageFont.truetype("arial.ttf", 60)
|
| 757 |
+
except OSError:
|
| 758 |
+
# If no TrueType fonts are available, fall back to default
|
| 759 |
+
print("Warning: TrueType fonts not available, using default font")
|
| 760 |
+
font = ImageFont.load_default()
|
| 761 |
+
small_font = ImageFont.load_default()
|
| 762 |
+
# Get molecule bounds
|
| 763 |
+
conf = mol.GetConformer()
|
| 764 |
+
positions = []
|
| 765 |
+
for i in range(mol.GetNumAtoms()):
|
| 766 |
+
pos = conf.GetAtomPosition(i)
|
| 767 |
+
positions.append((pos.x, pos.y))
|
| 768 |
+
|
| 769 |
+
x_coords = [p[0] for p in positions]
|
| 770 |
+
y_coords = [p[1] for p in positions]
|
| 771 |
+
min_x, max_x = min(x_coords), max(x_coords)
|
| 772 |
+
min_y, max_y = min(y_coords), max(y_coords)
|
| 773 |
+
|
| 774 |
+
# Calculate scaling factors
|
| 775 |
+
scale = 150 # Increased scale factor
|
| 776 |
+
center_x = 1000 # Image center
|
| 777 |
+
center_y = 1000
|
| 778 |
+
|
| 779 |
+
# Add residue labels in a circular arrangement around the structure
|
| 780 |
+
n_residues = len(residues)
|
| 781 |
+
radius = 700 # Distance of labels from center
|
| 782 |
+
|
| 783 |
+
# Start from the rightmost point (3 o'clock position) and go counterclockwise
|
| 784 |
+
# Offset by -3 positions to align with structure
|
| 785 |
+
offset = 0 # Adjust this value to match the structure alignment
|
| 786 |
+
for i, residue in enumerate(residues):
|
| 787 |
+
# Calculate position in a circle around the structure
|
| 788 |
+
# Start from 0 (3 o'clock) and go counterclockwise
|
| 789 |
+
angle = -(2 * np.pi * ((i + offset) % n_residues) / n_residues)
|
| 790 |
+
|
| 791 |
+
# Calculate label position
|
| 792 |
+
label_x = center_x + radius * np.cos(angle)
|
| 793 |
+
label_y = center_y + radius * np.sin(angle)
|
| 794 |
+
|
| 795 |
+
# Draw residue label
|
| 796 |
+
text = f"{i+1}. {residue}"
|
| 797 |
+
bbox = draw.textbbox((label_x, label_y), text, font=font)
|
| 798 |
+
padding = 10
|
| 799 |
+
draw.rectangle([bbox[0]-padding, bbox[1]-padding,
|
| 800 |
+
bbox[2]+padding, bbox[3]+padding],
|
| 801 |
+
fill='white', outline='white')
|
| 802 |
+
draw.text((label_x, label_y), text,
|
| 803 |
+
font=font, fill='black', anchor="mm")
|
| 804 |
+
|
| 805 |
+
# Add sequence at the top with white background
|
| 806 |
+
seq_text = f"Sequence: {sequence}"
|
| 807 |
+
bbox = draw.textbbox((center_x, 100), seq_text, font=small_font)
|
| 808 |
+
padding = 10
|
| 809 |
+
draw.rectangle([bbox[0]-padding, bbox[1]-padding,
|
| 810 |
+
bbox[2]+padding, bbox[3]+padding],
|
| 811 |
+
fill='white', outline='white')
|
| 812 |
+
draw.text((center_x, 100), seq_text,
|
| 813 |
+
font=small_font, fill='black', anchor="mm")
|
| 814 |
+
|
| 815 |
+
return img
|
| 816 |
+
|
| 817 |
+
"""
|
| 818 |
+
def annotate_cyclic_structure(mol, sequence):
    """Render a 2D depiction of ``mol`` with the sequence as a header.

    Args:
        mol: RDKit molecule to draw. ``AllChem.Compute2DCoords`` is called
            on it, so the caller's object receives 2D coordinates.
        sequence: Sequence string shown at the top as ``"Sequence: ..."``.

    Returns:
        A PIL image (2000x2000 drawing canvas) with the header text
        centred at (1000, 100) on a white background box.
    """
    # Generate 2D coordinates
    AllChem.Compute2DCoords(mol)

    # Create drawer with larger size for annotations
    drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000)

    # Draw molecule first
    drawer.drawOptions().addAtomIndices = False
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()

    # Convert to PIL Image
    img = Image.open(BytesIO(drawer.GetDrawingText()))
    draw = ImageDraw.Draw(img)

    # Font lookup: DejaVu path first, then Arial, then PIL's built-in font.
    try:
        small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
    except OSError:
        try:
            small_font = ImageFont.truetype("arial.ttf", 60)
        except OSError:
            print("Warning: TrueType fonts not available, using default font")
            small_font = ImageFont.load_default()

    # Add just the sequence header at the top
    seq_text = f"Sequence: {sequence}"
    header_xy = (1000, 100)
    header_anchor = "mm"
    # Measure the text with the SAME anchor used to draw it; otherwise the
    # white box is computed for top-left placement and ends up offset from
    # the centred text it is supposed to sit behind.
    bbox = draw.textbbox(header_xy, seq_text, font=small_font, anchor=header_anchor)
    padding = 10
    draw.rectangle([bbox[0]-padding, bbox[1]-padding,
                    bbox[2]+padding, bbox[3]+padding],
                   fill='white', outline='white')
    draw.text(header_xy, seq_text,
              font=small_font, fill='black', anchor=header_anchor)

    return img
|
| 854 |
+
|
| 855 |
+
def create_enhanced_linear_viz(sequence, smiles):
    """Create a two-panel matplotlib figure describing a peptide linearly.

    The top panel draws one box per residue joined by bond lines (dashed
    red for ester bonds, solid black for peptide bonds, with an "(N-Me)"
    note for N-methylated bonds). The bottom panel lists every segment
    returned by ``PeptideAnalyzer.split_on_bonds`` with its SMILES content.

    Args:
        sequence: Three-letter sequence joined with ``-``, optionally
            wrapped in ``cyclo(...)``. Residues may carry a parenthesised
            modification list such as ``Ser(N-Me)``.
        smiles: SMILES string the sequence was derived from.

    Returns:
        The matplotlib ``Figure``; the caller is responsible for closing it.
    """
    import re  # local import: only needed for the residue split below

    def bond_traits(segment):
        """Return (is_ester, is_n_methylated) for the bond after a segment."""
        # 'bond_after' holds a SMILES fragment (or None), not a label, so
        # test for the fragment markers rather than 'O-linked' / 'N-Me'.
        bond_after = segment.get('bond_after') or ''
        return 'OC(=O)' in bond_after, 'N(C)' in bond_after

    analyzer = PeptideAnalyzer()  # Create analyzer instance

    # Create figure with two subplots
    fig = plt.figure(figsize=(15, 10))
    gs = fig.add_gridspec(2, 1, height_ratios=[1, 2])
    ax_struct = fig.add_subplot(gs[0])
    ax_detail = fig.add_subplot(gs[1])

    # Parse sequence and get residues
    is_cyclic = sequence.startswith('cyclo(')
    body = sequence[6:-1] if is_cyclic else sequence
    # Split on '-' only outside parentheses so that a residue such as
    # "Ser(N-Me)" is kept whole instead of becoming "Ser(N" and "Me)".
    residues = re.split(r'-(?![^()]*\))', body)

    # Get segments using analyzer
    segments = analyzer.split_on_bonds(smiles)

    # Debug print
    print(f"Number of residues: {len(residues)}")
    print(f"Number of segments: {len(segments)}")

    # Top subplot - Basic structure
    ax_struct.set_xlim(0, 10)
    ax_struct.set_ylim(0, 2)

    num_residues = len(residues)
    spacing = 9.0 / (num_residues - 1) if num_residues > 1 else 9.0

    # Draw basic structure
    y_pos = 1.5
    for i in range(num_residues):
        x_pos = 0.5 + i * spacing

        # Draw amino acid box
        rect = patches.Rectangle((x_pos-0.3, y_pos-0.2), 0.6, 0.4,
                                 facecolor='lightblue', edgecolor='black')
        ax_struct.add_patch(rect)

        # Draw connecting bonds if not the last residue
        if i < num_residues - 1:
            segment = segments[i] if i < len(segments) else None
            if segment:
                is_ester, is_n_methylated = bond_traits(segment)
                bond_type = 'ester' if is_ester else 'peptide'

                bond_color = 'red' if is_ester else 'black'
                linestyle = '--' if is_ester else '-'

                # Draw bond line
                ax_struct.plot([x_pos+0.3, x_pos+spacing-0.3], [y_pos, y_pos],
                               color=bond_color, linestyle=linestyle, linewidth=2)

                # Add bond type label
                mid_x = x_pos + spacing/2
                bond_label = f"{bond_type}"
                if is_n_methylated:
                    bond_label += "\n(N-Me)"
                ax_struct.text(mid_x, y_pos+0.1, bond_label,
                               ha='center', va='bottom', fontsize=10,
                               color=bond_color)

        # Add residue label
        ax_struct.text(x_pos, y_pos-0.5, residues[i],
                       ha='center', va='top', fontsize=14)

    # Bottom subplot - Detailed breakdown
    ax_detail.set_ylim(0, len(segments)+1)
    ax_detail.set_xlim(0, 1)

    # Create detailed breakdown
    segment_y = len(segments)  # Start from top
    for i, segment in enumerate(segments):
        y = segment_y - i

        # Check if this is a bond or residue
        residue, mods = analyzer.identify_residue(segment)
        if residue:
            text = f"Residue {i+1}: {residue}"
            if mods:
                text += f" ({', '.join(mods)})"
            color = 'blue'
        else:
            # Must be a bond
            is_ester, is_n_methylated = bond_traits(segment)
            text = f"Bond {i}: "
            if is_ester:
                text += "ester"
            elif is_n_methylated:
                text += "peptide (N-methylated)"
            else:
                text += "peptide"
            color = 'red'

        # Add segment analysis
        ax_detail.text(0.05, y, text, fontsize=12, color=color)
        ax_detail.text(0.5, y, f"SMILES: {segment.get('content', '')}", fontsize=10, color='gray')

    # If cyclic, add connection indicator
    if is_cyclic:
        ax_struct.annotate('', xy=(9.5, y_pos), xytext=(0.5, y_pos),
                           arrowprops=dict(arrowstyle='<->', color='red', lw=2))
        ax_struct.text(5, y_pos+0.3, 'Cyclic Connection',
                       ha='center', color='red', fontsize=14)

    # Add titles and adjust layout
    ax_struct.set_title("Peptide Structure Overview", pad=20)
    ax_detail.set_title("Segment Analysis Breakdown", pad=20)

    # Remove axes
    for ax in [ax_struct, ax_detail]:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')

    plt.tight_layout()
    return fig
|
| 973 |
+
|
| 974 |
+
class PeptideStructureGenerator:
    """Generate 3D peptide structures via ETKDG embedding, optionally UFF-optimized."""

    @staticmethod
    def prepare_molecule(smiles):
        """Parse ``smiles`` into a partially sanitized molecule with explicit hydrogens.

        Raises:
            ValueError: If RDKit cannot build a molecule from the SMILES.
        """
        mol = Chem.MolFromSmiles(smiles, sanitize=False)
        if mol is None:
            raise ValueError("Failed to create molecule from SMILES")

        # Calculate valence for each atom (non-strict, so unusual valences
        # do not raise here).
        for atom in mol.GetAtoms():
            atom.UpdatePropertyCache(strict=False)

        # Sanitize with reduced requirements: only the listed operations run.
        Chem.SanitizeMol(mol,
                         sanitizeOps=Chem.SANITIZE_FINDRADICALS|
                         Chem.SANITIZE_KEKULIZE|
                         Chem.SANITIZE_SETAROMATICITY|
                         Chem.SANITIZE_SETCONJUGATION|
                         Chem.SANITIZE_SETHYBRIDIZATION|
                         Chem.SANITIZE_CLEANUPCHIRALITY)

        mol = Chem.AddHs(mol)
        return mol

    @staticmethod
    def get_etkdg_params(attempt=0):
        """Build ETKDGv3 embedding parameters for the given retry number.

        Args:
            attempt: Zero-based retry index. For ``attempt > 10`` the
                settings are loosened: ``bondLength`` grows by 0.02 per
                extra attempt and experimental torsion preferences are
                switched off.
        """
        params = AllChem.ETKDGv3()
        params.randomSeed = -1
        params.maxIterations = 200
        params.numThreads = 4  # Reduced for web interface
        params.useBasicKnowledge = True
        params.enforceChirality = True
        params.useExpTorsionAnglePrefs = True
        params.useSmallRingTorsions = True
        params.useMacrocycleTorsions = True
        params.ETversion = 2
        params.pruneRmsThresh = -1
        params.embedRmsThresh = 0.5

        if attempt > 10:
            params.bondLength = 1.5 + (attempt - 10) * 0.02
            params.useExpTorsionAnglePrefs = False

        return params

    def generate_structure_etkdg(self, smiles, max_attempts=20):
        """Generate a 3D structure using ETKDG without UFF optimization.

        Retries up to ``max_attempts`` times with freshly prepared
        molecules and attempt-dependent parameters.

        Returns:
            The first molecule for which ``EmbedMolecule`` returns 0.

        Raises:
            ValueError: If no attempt succeeds. The last exception raised
                during an attempt (if any) is attached as ``__cause__``.
        """
        last_error = None

        for attempt in range(max_attempts):
            try:
                mol = self.prepare_molecule(smiles)
                params = self.get_etkdg_params(attempt)

                if AllChem.EmbedMolecule(mol, params) == 0:
                    return mol
            except Exception as exc:
                # Best-effort retry, but remember why the attempt failed so
                # the final error is diagnosable instead of silently lost.
                last_error = exc

        raise ValueError("Failed to generate structure with ETKDG") from last_error

    def generate_structure_uff(self, smiles, max_attempts=20):
        """Generate a 3D structure using ETKDG followed by UFF optimization.

        Every attempt is run (no early exit); among attempts whose
        optimization returns 0, the one with the lowest UFF energy wins.

        Returns:
            A copy of the lowest-energy optimized molecule.

        Raises:
            ValueError: If no attempt yields an optimized structure. The
                last exception raised during an attempt (if any) is
                attached as ``__cause__``.
        """
        best_mol = None
        lowest_energy = float('inf')
        last_error = None

        for attempt in range(max_attempts):
            try:
                test_mol = self.prepare_molecule(smiles)
                params = self.get_etkdg_params(attempt)

                if AllChem.EmbedMolecule(test_mol, params) != 0:
                    continue

                res = AllChem.UFFOptimizeMolecule(test_mol, maxIters=2000,
                                                  vdwThresh=10.0, confId=0,
                                                  ignoreInterfragInteractions=True)
                if res != 0:
                    continue

                ff = AllChem.UFFGetMoleculeForceField(test_mol)
                if not ff:
                    continue

                current_energy = ff.CalcEnergy()
                if current_energy < lowest_energy:
                    lowest_energy = current_energy
                    best_mol = Chem.Mol(test_mol)
            except Exception as exc:
                # Keep trying, but retain the failure for the final error.
                last_error = exc

        if best_mol is None:
            raise ValueError("Failed to generate optimized structure") from last_error

        return best_mol

    @staticmethod
    def mol_to_sdf_bytes(mol):
        """Convert an RDKit molecule to the UTF-8 bytes of an SDF file."""
        # First write to StringIO in text mode
        sio = StringIO()
        writer = Chem.SDWriter(sio)
        try:
            writer.write(mol)
        finally:
            # Closing flushes the record into the buffer even on error paths.
            writer.close()

        # Convert the string to bytes
        return sio.getvalue().encode('utf-8')
|
| 1084 |
+
|
| 1085 |
+
def _describe_segments(analyzer, segments, verbose=False, warn_unidentified=False):
    """Identify every segment; return (sequence_parts, report_text)."""
    sequence_parts = []
    report = []
    for i, segment in enumerate(segments):
        if verbose:
            report.append(f"\nSegment {i}:\n")
            report.append(f"Content: {segment['content']}\n")
            report.append(f"Bond before: {segment.get('bond_before', 'None')}\n")
            report.append(f"Bond after: {segment.get('bond_after', 'None')}\n")

        residue, mods = analyzer.identify_residue(segment)
        if residue:
            sequence_parts.append(f"{residue}({','.join(mods)})" if mods else residue)
            if verbose:
                report.append(f"Identified as: {residue}\n")
                report.append(f"Modifications: {mods}\n")
        elif verbose and warn_unidentified:
            report.append(f"Warning: Could not identify residue in segment: {segment['content']}\n")
    return sequence_parts, "".join(report)


def _write_sdf(mol, path):
    """Write a single molecule to an SDF file at ``path``."""
    writer = Chem.SDWriter(path)
    try:
        writer.write(mol)
    finally:
        writer.close()


def _process_single_smiles(analyzer, smiles, show_linear, show_segment_details,
                           generate_3d, use_uff):
    """Analyse one SMILES string; return the 4-tuple documented on process_input."""
    import shutil  # local import: only needed to clean up after a 3D failure

    # First check if it's a peptide using analyzer's method
    if not analyzer.is_peptide(smiles):
        return "Error: Input SMILES does not appear to be a peptide structure.", None, None, None

    try:
        # Create molecule
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Error: Invalid SMILES notation.", None, None, None

        # Generate 3D structures if requested
        structure_files = []
        if generate_3d:
            # Created only once validation has passed, so early returns
            # above do not leave an empty temp directory behind.
            temp_dir = tempfile.mkdtemp()
            try:
                generator = PeptideStructureGenerator()

                # Generate ETKDG structure
                etkdg_path = os.path.join(temp_dir, "structure_etkdg.sdf")
                _write_sdf(generator.generate_structure_etkdg(smiles), etkdg_path)
                structure_files.append(etkdg_path)

                # Generate UFF structure if requested
                if use_uff:
                    uff_path = os.path.join(temp_dir, "structure_uff.sdf")
                    _write_sdf(generator.generate_structure_uff(smiles), uff_path)
                    structure_files.append(uff_path)

            except Exception as e:
                # No file is handed back on failure, so drop the directory.
                shutil.rmtree(temp_dir, ignore_errors=True)
                return f"Error generating 3D structures: {str(e)}", None, None, None

        # Use analyzer to get sequence
        segments = analyzer.split_on_bonds(smiles)
        sequence_parts, report = _describe_segments(
            analyzer, segments,
            verbose=show_segment_details, warn_unidentified=True)

        # Only include segment analysis in output if requested
        output_text = ""
        if show_segment_details:
            output_text = "Segment Analysis:\n" + report + "\n"

        # Check if cyclic using analyzer's method
        is_cyclic, _peptide_cycles, _aromatic_cycles = analyzer.is_cyclic(smiles)
        three_letter = '-'.join(sequence_parts)
        one_letter = ''.join(
            analyzer.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence_parts)

        if is_cyclic:
            three_letter = f"cyclo({three_letter})"
            one_letter = f"cyclo({one_letter})"

        # Create cyclic structure visualization
        img_cyclic = annotate_cyclic_structure(mol, three_letter)

        # Create linear representation if requested
        img_linear = None
        if show_linear:
            fig_linear = create_enhanced_linear_viz(three_letter, smiles)
            buf = BytesIO()
            fig_linear.savefig(buf, format='png', bbox_inches='tight', dpi=300)
            buf.seek(0)
            img_linear = Image.open(buf)
            plt.close(fig_linear)

        # Add summary to output
        summary = "Summary:\n"
        summary += f"Sequence: {three_letter}\n"
        summary += f"One-letter code: {one_letter}\n"
        summary += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"

        if structure_files:
            summary += "\n3D Structures Generated:\n"
            for filepath in structure_files:
                summary += f"- {os.path.basename(filepath)}\n"

        return summary + output_text, img_cyclic, img_linear, structure_files if structure_files else None

    except Exception as e:
        return f"Error processing SMILES: {str(e)}", None, None, None


def _process_smiles_file(analyzer, file_obj, show_segment_details):
    """Analyse one SMILES per line of a file-like input; text-only result."""
    try:
        # Handle file content
        if hasattr(file_obj, 'name'):
            with open(file_obj.name, 'r') as f:
                content = f.read()
        else:
            content = file_obj.decode('utf-8') if isinstance(file_obj, bytes) else str(file_obj)

        output_text = ""
        for line in content.splitlines():
            smiles = line.strip()
            if not smiles:
                continue

            # Check if it's a peptide
            if not analyzer.is_peptide(smiles):
                output_text += f"Skipping non-peptide SMILES: {smiles}\n"
                continue

            # Process this SMILES
            segments = analyzer.split_on_bonds(smiles)
            sequence_parts, report = _describe_segments(
                analyzer, segments,
                verbose=show_segment_details, warn_unidentified=False)

            # Add segment details if requested
            if show_segment_details:
                output_text += f"\nSegment Analysis for SMILES: {smiles}\n" + report

            # Get cyclicity and create sequence
            is_cyclic, peptide_cycles, _aromatic_cycles = analyzer.is_cyclic(smiles)
            sequence = f"cyclo({'-'.join(sequence_parts)})" if is_cyclic else '-'.join(sequence_parts)

            output_text += f"\nSummary for SMILES: {smiles}\n"
            output_text += f"Sequence: {sequence}\n"
            output_text += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
            if is_cyclic:
                output_text += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
            output_text += "-" * 50 + "\n"

        return output_text, None, None, None

    except Exception as e:
        return f"Error processing file: {str(e)}", None, None, None


def process_input(smiles_input=None, file_obj=None, show_linear=False,
                  show_segment_details=False, generate_3d=False, use_uff=False):
    """Process a SMILES string or a file of SMILES using PeptideAnalyzer.

    A non-empty ``smiles_input`` takes precedence over ``file_obj``.

    Args:
        smiles_input: Single SMILES string to analyse and visualise.
        file_obj: Object with a ``name`` path attribute, raw ``bytes``, or
            anything convertible with ``str``; one SMILES per line.
        show_linear: Also render the linear two-panel figure (SMILES input only).
        show_segment_details: Include the per-segment breakdown in the text.
        generate_3d: Write an ETKDG structure to a temporary SDF file
            (SMILES input only).
        use_uff: Additionally write a UFF-optimized SDF file when
            ``generate_3d`` is set.

    Returns:
        Always a 4-tuple ``(text, cyclic_image, linear_image, structure_files)``.
        Entries that do not apply are ``None``; failures are reported in
        ``text`` rather than raised (except errors from the initial
        ``is_peptide`` check on a SMILES input, which propagate).
    """
    analyzer = PeptideAnalyzer()

    # Handle direct SMILES input
    if smiles_input:
        return _process_single_smiles(
            analyzer, smiles_input.strip(), show_linear,
            show_segment_details, generate_3d, use_uff)

    # Handle file input
    if file_obj is not None:
        return _process_smiles_file(analyzer, file_obj, show_segment_details)

    return "No input provided.", None, None, None
|
a2d2_pep/pep_utils/utils.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Console logger utilities.
|
| 2 |
+
|
| 3 |
+
Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
|
| 4 |
+
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import fsspec
|
| 9 |
+
import lightning
|
| 10 |
+
import torch
|
| 11 |
+
from timm.scheduler import CosineLRScheduler
|
| 12 |
+
import argparse
|
| 13 |
+
import numpy as np
|
| 14 |
+
import random
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
def sample_categorical_logits(logits, dtype=torch.float64):
    """Sample one index along the last dimension by adding Gumbel noise and taking argmax.

    The logits do not need to be log-softmaxed. Noise is generated in
    ``dtype`` (float64 by default) from uniform samples shaped like
    ``logits``.
    """
    eps = 1e-10  # keeps both logarithms away from log(0)
    uniform = torch.rand_like(logits, dtype=dtype)
    log_uniform = (uniform + eps).log()
    gumbel_noise = -(eps - log_uniform).log()
    perturbed = logits + gumbel_noise
    return perturbed.argmax(dim=-1)
|
| 21 |
+
|
| 22 |
+
def fsspec_exists(filename):
    """Return whether ``filename`` exists on the fsspec filesystem resolved from it."""
    filesystem = fsspec.core.url_to_fs(filename)[0]
    found = filesystem.exists(filename)
    return found
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def fsspec_listdir(dirname):
    """List ``dirname`` through the fsspec filesystem resolved from it."""
    filesystem = fsspec.core.url_to_fs(dirname)[0]
    entries = filesystem.ls(dirname)
    return entries
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def fsspec_mkdirs(dirname, exist_ok=True):
    """Create ``dirname`` through the fsspec filesystem resolved from it.

    ``exist_ok`` is forwarded unchanged to the filesystem's ``makedirs``.
    """
    filesystem = fsspec.core.url_to_fs(dirname)[0]
    filesystem.makedirs(dirname, exist_ok=exist_ok)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def print_nans(tensor, name):
    """Print ``name`` and ``tensor`` when the tensor contains at least one NaN."""
    contains_nan = torch.isnan(tensor).any()
    if not contains_nan:
        return
    print(name, tensor)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class CosineDecayWarmupLRScheduler(
        CosineLRScheduler,
        torch.optim.lr_scheduler._LRScheduler):
    """Cosine LR scheduler that tracks its own step counter.

    ``step`` forwards to the parent's ``step`` or ``step_update``
    depending on ``self.t_in_epochs``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        # Advance the internal counter, or jump to the explicit value.
        self._last_epoch = (self._last_epoch + 1) if epoch is None else epoch

        # We call either step or step_update, depending on whether we're
        # using the scheduler every epoch or every step. Otherwise,
        # lightning will always call step (i.e., meant for each epoch),
        # and if we set scheduler interval to "step", then the learning
        # rate update will be wrong.
        if not self.t_in_epochs:
            super().step_update(num_updates=self._last_epoch)
            return
        super().step(epoch=self._last_epoch)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class LoggingContext:
    """Temporarily change a logger's level and/or attach a handler.

    On exit the previous level is restored and the handler is removed;
    the handler is also closed when ``close`` is true.
    """

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger, self.level = logger, level
        self.handler, self.close = handler, close

    def __enter__(self):
        override_level = self.level is not None
        if override_level:
            # Remember the current level so __exit__ can put it back.
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if not self.handler:
            return
        self.logger.removeHandler(self.handler)
        if self.close:
            self.handler.close()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Return a logger whose logging methods are wrapped with ``rank_zero_only``.

    The logger named ``name`` is set to ``level``; then each of its
    standard logging methods is replaced by the result of passing it
    through ``lightning.pytorch.utilities.rank_zero_only``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    method_names = ('debug', 'info', 'warning', 'error',
                    'exception', 'fatal', 'critical')
    for method_name in method_names:
        original_method = getattr(logger, method_name)
        wrapped_method = lightning.pytorch.utilities.rank_zero_only(original_method)
        setattr(logger, method_name, wrapped_method)

    return logger
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def str2bool(v):
    """Convert a command-line style value to ``bool``.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        v: A ``bool`` (returned unchanged) or a string. Matching is
            case-insensitive: ``yes/true/t/y/1`` map to ``True`` and
            ``no/false/f/n/0`` map to ``False``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is neither a bool nor one of the
            recognised strings. This includes non-string inputs such as
            ``None`` or integers, which previously escaped as
            ``AttributeError`` from ``.lower()``.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError('Boolean value expected.')

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def set_seed(seed, use_cuda):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator.
        use_cuda: When true, the CUDA generators are seeded as well.

    Side effects:
        Sets ``PYTHONHASHSEED`` in ``os.environ`` and prints a confirmation
        line. cuDNN deterministic mode is not enabled here.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The CPU-side generators are independent, so the order does not matter.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)

    if use_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    print(f'=> Seed of the run set to {seed}')
|
| 135 |
+
|
a2d2_pep/remasking_scheduleaware.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Schedule-aware remasking and insertion logic that ensures the number of masked tokens
|
| 3 |
+
follows the interpolant schedule.
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
def apply_schedule_aware_insertion(
    model,
    xt_tmp,
    new_xt,
    t,
    dt,
    ext,
    mask,
    pad,
    max_length,
    orig_mask,
    new_pos_orig,
    quality_threshold=1,
):
    """
    Remove low-quality insertions based on insertion confidence while respecting
    the interpolant schedule for expected sequence length.

    Args:
        model: Model exposing ``planner(xt, t)`` (returns a mapping that may
            contain ``"insertion_conf"``) and
            ``interpolant.insertion_schedule.at(t)``
        xt_tmp: Sequence after insertion [B, L]
        new_xt: Sequence before insertion [B, L] (not read here; kept for
            interface compatibility)
        t: Current time [B]
        dt: Time step size
        ext: Number of insertions per gap [B, L+1]
        mask: Mask token ID
        pad: Pad token ID
        max_length: Maximum sequence length (not read here; kept for
            interface compatibility)
        orig_mask: Mask of original token positions [B, L]
        new_pos_orig: New positions of original tokens [B, L]
        quality_threshold: If a number, only insertions whose confidence is
            strictly below it may be dropped. If None, every fresh insertion
            is a candidate and the schedule alone decides how many are
            dropped (lowest confidence first). In both modes at most
            ``current_length - scheduled_minimum_length`` tokens are removed
            per row.

    Returns:
        xt_tmp: Sequence with the selected insertions removed, compacted to
            the left and padded on the right. The input tensor itself is
            returned when nothing is removed.
    """
    seq_len = xt_tmp.shape[1]

    # Only proceed if there were insertions anywhere in the batch.
    if ext.sum() == 0:
        return xt_tmp

    # Get planner predictions on the inserted state. The insertion head is
    # trained with the pre-step time t (see loss_insert_planner_flexible), so
    # condition on t here too; t_next is only used for the length schedule.
    t_next = t + dt
    planner_out = model.planner(xt_tmp, t)
    insertion_conf = planner_out.get("insertion_conf", None)
    if insertion_conf is None:
        return xt_tmp
    insertion_conf = insertion_conf.squeeze(-1)  # (B, L)

    # Expected sequence length at the next timestep according to the schedule.
    # NOTE(review): whenever expected_progress >= 0.1 this evaluates to
    # current_length_after again (up to float rounding), so removals are only
    # permitted while the schedule value is below the 0.1 clamp. Confirm that
    # this is the intended estimate.
    current_length_after = xt_tmp.ne(pad).sum(dim=1).float()  # [B]
    expected_progress = model.interpolant.insertion_schedule.at(t_next)  # [B]
    estimated_final_length = current_length_after / expected_progress.clamp(min=0.1)
    expected_length = estimated_final_length * expected_progress  # [B]

    # Mark positions in xt_tmp that came from new_xt (originals) vs. fresh
    # insertions. Fancy-indexing scatter avoids a per-batch python loop.
    valid_b, valid_l = orig_mask.nonzero(as_tuple=True)
    valid_p = new_pos_orig[valid_b, valid_l].long().clamp(0, seq_len - 1)
    is_original = torch.zeros_like(xt_tmp, dtype=torch.bool)
    is_original[valid_b, valid_p] = True
    inserted_positions = (xt_tmp == mask) & ~is_original

    # Two deletion modes, selected by `quality_threshold`:
    #   * number: candidates are insertions whose confidence is below it;
    #   * None:   every insertion is a candidate (schedule-driven deletion).
    # Previously None reached the `<` comparison and raised TypeError.
    if quality_threshold is None:
        candidates = inserted_positions
    else:
        candidates = inserted_positions & (insertion_conf < quality_threshold)

    if not candidates.any():
        return xt_tmp

    # Cap removals so the length never falls below the scheduled minimum.
    num_bad = candidates.sum(dim=1)  # [B], long
    current_len = current_length_after.long()
    min_length = expected_length.long().clamp(min=1)  # [B]
    max_removable = (current_len - min_length).clamp(min=0)
    # Equivalent to "num_bad, unless that would violate the schedule, in
    # which case max_removable".
    k_per_row = torch.minimum(num_bad, max_removable)

    # Select the k lowest-confidence candidates per row via a sort.
    # Non-candidates score -inf, so they sort after every candidate; since
    # k_per_row <= num_bad, they are never selected.
    neg_inf = torch.tensor(
        float('-inf'), device=xt_tmp.device, dtype=insertion_conf.dtype
    )
    scores = torch.where(candidates, -insertion_conf, neg_inf)  # higher = worse
    _, sorted_indices = scores.sort(dim=1, descending=True)
    positions = torch.arange(seq_len, device=xt_tmp.device).unsqueeze(0)  # [1, L]
    keep_in_topk = positions < k_per_row.unsqueeze(1)  # [B, L]
    final_bad = torch.zeros_like(candidates)
    final_bad.scatter_(1, sorted_indices, keep_in_topk)

    if not final_bad.any():
        return xt_tmp

    # Compact each row to the left (keep good, drop bad), then pad the tail.
    # A stable sort on the bad flag pushes bad positions to the right while
    # preserving the relative order of everything else.
    _, perm = torch.sort(final_bad.long(), dim=1, stable=True)
    xt_tmp = torch.gather(xt_tmp, 1, perm)
    num_keep = (~final_bad).sum(dim=1)  # [B]
    tail_mask = positions >= num_keep.unsqueeze(1)  # [B, L]
    xt_tmp = torch.where(tail_mask, torch.full_like(xt_tmp, pad), xt_tmp)

    return xt_tmp
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def apply_schedule_aware_remasking(
    model,
    new_xt,
    t,
    dt,
    remasking_conf,
    clean_index,
    mask,
    neg_inf,
    batch_size,
    unmask_quality_threshold=None,
):
    """
    Remask clean tokens so the number of masks follows the unmask schedule.

    Args:
        model: Model whose ``interpolant.unmask_schedule.at(t)`` gives the
            expected unmasked fraction
        new_xt: Current sequence [B, L]
        t: Current time [B]
        dt: Time step size
        remasking_conf: Confidence scores for tokens [B, L]
        clean_index: Boolean mask of clean tokens (not mask, not pad) [B, L]
        mask: Mask token ID
        neg_inf: Negative infinity tensor used as the score of non-clean
            positions
        batch_size: Batch size (not read here)
        unmask_quality_threshold: If None (default), remask exactly the
            schedule excess (count-based), lowest confidence first. If a
            float, the schedule is ignored and EVERY clean token whose
            confidence is below the threshold is remasked; a higher threshold
            means more aggressive remasking.

    Returns:
        new_xt: Sequence with the selected clean tokens replaced by ``mask``
            (the input tensor itself when the schedule asks for no remasking)
    """
    mask_fill = torch.full_like(new_xt, mask)

    # Threshold gate: overrides the schedule-driven count entirely.
    if unmask_quality_threshold is not None:
        below_threshold = remasking_conf < unmask_quality_threshold
        return torch.where(clean_index & below_threshold, mask_fill, new_xt)

    # Schedule-driven count: how many clean tokens exceed what the schedule
    # expects at the next time step.
    clean_count = clean_index.sum(dim=1)  # [B], long
    masked_count = (new_xt == mask).sum(dim=1)  # [B], long
    seq_len = (clean_count + masked_count).float()  # [B]
    target_clean = model.interpolant.unmask_schedule.at(t + dt) * seq_len  # [B]
    surplus = (clean_count.float() - target_clean).round().long()  # [B]

    # Per-row k = min(surplus, clean_count), never negative.
    k_per_row = torch.minimum(surplus.clamp(min=0), clean_count)  # [B]
    if k_per_row.sum() == 0:
        return new_xt

    # Rank clean tokens by ascending confidence (negated so that a descending
    # sort puts the least confident first); non-clean positions score neg_inf.
    ranking = torch.where(clean_index, -1.0 * remasking_conf, neg_inf)
    order = ranking.sort(dim=1, descending=True).indices
    rank = torch.arange(ranking.shape[1], device=new_xt.device).unsqueeze(0)  # [1, L]

    # Flag the first k entries of each row's ranking.
    selected = torch.zeros_like(clean_index)
    selected.scatter_(1, order, rank < k_per_row.unsqueeze(1))

    return torch.where(selected, mask_fill, new_xt)
|
a2d2_pep/sampling.py
ADDED
|
@@ -0,0 +1,1401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # add repo root to path
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Literal, Optional
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
|
| 11 |
+
from lightning_modules.mdm import MaskedDiffusionModule
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class SamplingTraceDatapoint:
    """One event recorded while sampling a single sequence."""

    # Sampler time at which the event was recorded.
    t: float
    # "insertion": a mask token was inserted; "change": a token was replaced.
    event_type: Literal["insertion", "change"]
    # Index the event refers to (a gap index for insertions, a token index
    # for changes).
    position: int
    # Token id associated with the event.
    token: Any
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class SamplingResult:
    """Samples plus an optional event trace; unpacks as ``samples, trace``."""

    samples: torch.Tensor
    # Trace is supposed to be processed sequentially as updates are not commutative
    trace: Optional[list[SamplingTraceDatapoint]]

    def __iter__(self):
        """Yield ``samples`` then ``trace`` so the result unpacks like a pair."""
        yield self.samples
        yield self.trace
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Sample from categorical distribution for each position using the transition probabilities
|
| 33 |
+
def _sample_tokens(probs: torch.Tensor) -> torch.Tensor:
|
| 34 |
+
"""Sample one token per position from probability distribution.
|
| 35 |
+
Args:
|
| 36 |
+
probs: [batch_size, seq_len, vocab_size] transition probabilities
|
| 37 |
+
Returns:
|
| 38 |
+
[batch_size, seq_len] sampled token indices
|
| 39 |
+
"""
|
| 40 |
+
batch_size, seq_len, vocab_size = probs.shape
|
| 41 |
+
flat_probs = probs.view(-1, vocab_size)
|
| 42 |
+
samples = torch.multinomial(flat_probs, num_samples=1)
|
| 43 |
+
return samples.view(batch_size, seq_len)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _sample_batched_tokens(probs: torch.Tensor) -> torch.Tensor:
|
| 47 |
+
|
| 48 |
+
batch_size, seq_len, vocab_size = probs.shape
|
| 49 |
+
|
| 50 |
+
gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, seq_len, vocab_size) + 1e-10) + 1e-10)).to(probs.device)
|
| 51 |
+
noisy_logits = torch.log(probs + 1e-10) + gumbel_noise # add Gumbel noise to log probabilities
|
| 52 |
+
|
| 53 |
+
# select the highest score (most likely category after Gumbel noise)
|
| 54 |
+
samples = noisy_logits.argmax(dim=-1).to(dtype=torch.long)
|
| 55 |
+
|
| 56 |
+
return samples.view(batch_size, seq_len)
|
| 57 |
+
|
| 58 |
+
@torch.no_grad()
def mdm_euler_sampling(
    model: MaskedDiffusionModule,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
):
    """Euler sampler for a fixed-length masked diffusion model.

    Starts from an all-mask sequence and runs ``steps`` unmasking updates of
    size ``dt = 1 / steps``. On the final step the mask token is removed from
    the sampling distribution.

    Args:
        model: Called as ``model(xt, t)``; its ``interpolant.to_actual_rate``
            must return an object with an ``unmask_rate`` tensor, which is
            modified in place here.
        steps: Number of Euler steps.
        mask: Mask token ID.
        pad: Pad token ID (not read here).
        batch_size: Number of sequences to sample.
        max_length: Length of every sampled sequence.
        return_trace: Must be False; tracing is not implemented.
        temperature: When not 1.0, the per-position distribution is re-scaled
            as ``softmax(log(p + 1e-10) / temperature)``.

    Returns:
        ``(xt, [])``: the sampled token ids [batch_size, max_length] and an
        empty trace list.
    """
    assert not return_trace, "Trace is not yet implemented in MDM Euler sampling"
    device = next(model.parameters()).device
    xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    for i in range(steps):
        # The original printed the literal text "i-th" without the index.
        print(f"{i}-th sampling step")
        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        # Already-unmasked positions never change.
        unmask_rate[xt != mask] = 0
        # Mask column := -(sum of the other rates), so each masked row sums to 0.
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        # Clamping drops the negative mask entry; the scatter below then
        # gives "stay" a weight of 1. Rows are used as weights by
        # torch.multinomial in _sample_tokens, so they need not sum to 1.
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" weight on the current token of every position
        stay_index = xt.unsqueeze(-1)
        trans_prob.scatter_add_(
            2,
            stay_index,
            torch.ones_like(stay_index, dtype=trans_prob.dtype),
        )

        # Apply temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if i == steps - 1:
            print("Final step, removing mask token from sampling")
            # NOTE(review): unlike the insertion samplers in this file there
            # is no fallback if a masked row ends up with zero total weight;
            # torch.multinomial would raise in that case.
            trans_prob[mask_pos + (mask,)] = 0.0

        new_xt = _sample_tokens(trans_prob)
        # Keep tokens that were already unmasked before this step.
        new_xt = torch.where(xt != mask, xt, new_xt)

        xt = new_xt
        t = t + dt

    return xt, []
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@torch.no_grad()
def any_order_mask_insertion_euler_sampling(
    model: torch.nn.Module,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, Optional[list[list[SamplingTraceDatapoint]]]]:
    """Euler sampler that alternates unmasking with gap-wise mask insertion.

    Starts from an all-pad sequence. Each step unmasks tokens according to
    the predicted unmask rate, then (except on the last step) inserts mask
    tokens into gaps according to the predicted length rate. On the last
    step the mask token is removed from the sampling distribution.

    Args:
        model: Called as ``model(xt, t)``; ``interpolant.to_actual_rate`` must
            return an object with ``unmask_rate`` (B, L, V) and
            ``length_rate`` (B, L+1). ``unmask_rate`` is modified in place.
        steps: Number of Euler steps (``dt = 1 / steps``).
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to sample.
        max_length: Maximum sequence length L.
        return_trace: When True, record per-sequence change/insertion events.
        temperature: Not read by this sampler.

    Returns:
        ``(xt, sampling_trace)``: token ids [batch_size, max_length] and,
        when ``return_trace`` is set, one event list per batch row (otherwise
        ``None``). Within a step each row records its changes in ascending
        position order, then its insertions in descending gap order.
    """
    device = next(model.parameters()).device

    # 1) Initialize all-pad sequence and trace
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )

    for i in range(steps):
        is_last = i == steps - 1

        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        # Mask column := -(sum of the other rates), so each masked row sums to 0.
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" probability; pad positions are routed to the mask column
        # here and reset to pad after sampling
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        if is_last:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

            # renormalize probabilities of masked rows to ensure they sum to 1
            masked_rows = trans_prob[mask_pos]  # (n_masked, V)
            prob_sum = masked_rows.sum(dim=-1, keepdim=True)  # (n_masked, 1)
            dead = prob_sum == 0.0
            if dead.any():
                # Rows with no mass left fall back to a uniform distribution
                # over token ids [0, mask) excluding pad. The vector is built
                # over the vocabulary axis; the previous code allocated an
                # (L, V) tensor and sliced its sequence axis instead.
                allowed = torch.zeros(
                    trans_prob.size(-1), dtype=torch.bool, device=trans_prob.device
                )
                allowed[:mask] = True
                if 0 <= pad < mask:
                    allowed[pad] = False
                uniform_prob = allowed.to(trans_prob.dtype) / allowed.sum().clamp(min=1)
                masked_rows = torch.where(dead, uniform_prob, masked_rows)
                prob_sum = torch.where(dead, torch.ones_like(prob_sum), prob_sum)
            trans_prob[mask_pos] = masked_rows / prob_sum

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        # Tokens that were already clean are never resampled.
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        ext = None
        if not is_last:
            # ——— gap-wise insertion — compute new length, fill masks, scatter tokens ———
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # Only gaps inside the current sequence may receive an insertion.
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            total_ext = ext.sum(dim=1)
            # Rows that would overflow max_length get no insertions at all.
            valid = xt_len + total_ext <= max_length
            ext = ext * valid.view(batch_size, 1).long()

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            # total_ext is intentionally not recomputed: for rows rejected by
            # `valid`, new_len exceeds max_length, so mask_fill covers the row
            # and every original position is overwritten below.
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # Each original token shifts right by the insertions before it.
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            xt_tmp = new_xt

        if return_trace:
            t_values = t.tolist()
            new_tokens = new_xt.tolist()

            # Check if the token was changed (nonzero() is row-major, so
            # positions are ascending within each batch row).
            changed = (xt != pad) & (xt != new_xt)
            for b, j in changed.nonzero().tolist():
                sampling_trace[b].append(
                    SamplingTraceDatapoint(
                        t=t_values[b],
                        event_type="change",
                        position=j,
                        token=new_tokens[b][j],
                    )
                )

            # Check if a new token was inserted. No insertion happens on the
            # last step; previously a stale (or undefined) `ext` was read.
            if ext is not None:
                inserted = ext[:, :max_length].nonzero().tolist()
                # Reversed: gaps are reported from right to left per row.
                for b, gap in reversed(inserted):
                    sampling_trace[b].append(
                        SamplingTraceDatapoint(
                            t=t_values[b],
                            event_type="insertion",
                            position=gap,
                            token=mask,
                        )
                    )

        xt = xt_tmp
        t = t + dt

    return xt, sampling_trace
|
| 249 |
+
|
| 250 |
+
@torch.no_grad()
def batch_mcts_reverse_step(
    xt: torch.Tensor,
    t: torch.Tensor,
    dt: float,
    model: torch.nn.Module,
    pretrained: torch.nn.Module,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    last_step: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run one reverse (unmask + insert) step on ``batch_size`` copies of ``xt``.

    The step is sampled from ``model``; log-probabilities of the sampled
    transition are evaluated under both ``model`` and ``pretrained``.

    Args:
        xt: Current sequence; repeated ``batch_size`` times along dim 0
            (presumably shape (1, L) — confirm against the caller).
        t: Current time; squeezed and expanded to ``batch_size``.
        dt: Time step size.
        model: Policy model, called as ``model(xt, t)``. Its
            ``interpolant.to_actual_rate`` must return ``unmask_rate``
            (B, L, V, modified in place) and ``length_rate`` (B, L+1).
        pretrained: Reference model with the same interface.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of children to sample.
        max_length: Maximum sequence length L.
        last_step: When True, the mask token is removed from the sampling
            distribution and no insertion is performed.
        temperature: Not read by this function.

    Returns:
        ``(xt_next, log_rnd, log_policy_step, log_pretrained_step)``:
        the next sequences [B, L] and three [B] tensors, where
        ``log_rnd = log_pretrained_step - log_policy_step``.
    """
    device = next(model.parameters()).device

    xt = xt.repeat(batch_size, 1)

    # squeeze to remove extra dimensions, then expand to batch_size
    t = t.squeeze().expand(batch_size)
    # precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )

    # ——— predict and convert rates ———
    pred_rate = model(xt, t)
    pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
    unmask_rate = pred_rate.unmask_rate  # (B, L, V)
    len_rate = pred_rate.length_rate  # (B, L+1)

    # ——— get pretrained model rates for log_rnd computation ———
    pretrained_pred = pretrained(xt, t)
    pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
    pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
    pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

    # ——— unmask step (Euler) ———
    mask_pos = (xt == mask).nonzero(as_tuple=True)
    unmask_rate[xt != mask] = 0
    # Mask column := -(sum of the other rates), so each masked row sums to 0.
    unmask_rate[mask_pos + (mask,)] = 0
    unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

    # Same for pretrained
    pretrained_unmask_rate[xt != mask] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

    # add "stay" probability; pad positions are routed to the mask column
    # here and reset to pad after sampling
    _xt = xt.clone()
    _xt[xt == pad] = mask
    trans_prob.scatter_add_(
        2,
        _xt.unsqueeze(-1),
        torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
    )
    pretrained_trans_prob.scatter_add_(
        2,
        _xt.unsqueeze(-1),
        torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
    )

    if last_step:
        print("Final step, removing mask token from sampling")
        trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

        # renormalize probabilities of masked rows to ensure they sum to 1
        masked_rows = trans_prob[mask_pos]  # (n_masked, V)
        prob_sum = masked_rows.sum(dim=-1, keepdim=True)  # (n_masked, 1)
        dead = prob_sum == 0.0
        if dead.any():
            # Rows with no mass left fall back to a uniform distribution over
            # token ids [0, mask) excluding pad. The vector is built over the
            # vocabulary axis; the previous code allocated an (L, V) tensor
            # and sliced its sequence axis instead.
            allowed = torch.zeros(
                trans_prob.size(-1), dtype=torch.bool, device=trans_prob.device
            )
            allowed[:mask] = True
            if 0 <= pad < mask:
                allowed[pad] = False
            uniform_prob = allowed.to(trans_prob.dtype) / allowed.sum().clamp(min=1)
            masked_rows = torch.where(dead, uniform_prob, masked_rows)
            prob_sum = torch.where(dead, torch.ones_like(prob_sum), prob_sum)
        trans_prob[mask_pos] = masked_rows / prob_sum

    new_xt = _sample_tokens(trans_prob)
    new_xt[xt == pad] = pad
    # Tokens that were already clean are never resampled.
    new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

    # ——— compute log probabilities for RND ———
    lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
    lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

    # Only positions that were unmasked in this step contribute.
    changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

    log_policy_step = (lp * changed_mask).sum(dim=1)
    log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

    log_rnd = log_pretrained_step - log_policy_step  # (B,)

    if not last_step:
        # ——— gap-wise insertion — compute new length, fill masks, scatter tokens ———
        ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)

        insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
        pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

        # log P(ext; λ) = ext*log(λ) - λ
        # NOTE(review): evaluated on the raw draw, before `ext` is restricted
        # to valid gaps and lengths below — confirm this is intended.
        log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1)  # (B,)
        log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)  # (B,)

        log_insert_diff = log_pretrained_insert - log_policy_insert  # (B,)
        log_rnd += log_insert_diff
        log_pretrained_step += log_pretrained_insert
        log_policy_step += log_policy_insert

        xt_len = xt.ne(pad).sum(dim=1)  # (B,)
        seq_dim = ext.size(1)  # Use actual ext dimension to avoid mismatch
        gaps = torch.arange(seq_dim, device=device).view(1, -1)
        # Only gaps inside the current sequence may receive an insertion.
        ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
        total_ext = ext.sum(dim=1)
        # Rows that would overflow max_length get no insertions at all.
        valid = xt_len + total_ext <= max_length
        ext = ext * valid.view(batch_size, 1).long()

        ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
        new_len = xt_len + total_ext  # (B,)

        xt_tmp = torch.full_like(xt, pad)
        mask_fill = pos_idx_L < new_len.view(batch_size, 1)
        xt_tmp[mask_fill] = mask

        # Each original token shifts right by the insertions before it.
        new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
        orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
        flat_b = batch_idx_L[orig_mask]
        flat_p = new_pos_orig[orig_mask]
        xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
    else:
        xt_tmp = new_xt

    return xt_tmp, log_rnd, log_policy_step, log_pretrained_step
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
@torch.no_grad()
def mcts_reverse_step(
    xt: torch.Tensor,
    t: torch.Tensor,
    dt: float,
    model: torch.nn.Module,
    pretrained: torch.nn.Module,
    mask: int,
    pad: int,
    max_length: int,
    last_step: bool = False,
    temperature: float = 1.0,
) -> SamplingResult:
    """Run one reverse (unmask + insert) Euler step and score it under both models.

    Args:
        xt: Current token ids; indexed as (B, max_length) by the code below.
        t: Time passed to both models.
        dt: Euler step size.
        model: Policy model; must expose ``interpolant.to_actual_rate``.
        pretrained: Reference model with the same interface as ``model``.
        mask: Id of the mask token.
        pad: Id of the padding token.
        max_length: Maximum sequence length.
        last_step: When True, masked positions may not stay masked and no
            insertion is performed.
        temperature: Accepted for signature compatibility; this function does
            not currently use it.

    Returns:
        ``(next_xt, log_rnd, log_policy_step, log_pretrained_step)`` where the
        last three are per-sample sums over the unmasking events (plus the
        insertion term when ``last_step`` is False) and
        ``log_rnd = log_pretrained_step - log_policy_step``.
        NOTE(review): the ``SamplingResult`` return annotation is kept from
        the original signature although a 4-tuple is returned.
    """
    device = next(model.parameters()).device
    batch_size = xt.size(0)

    # precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )

    # ——— predict and convert rates ———
    pred_rate = model(xt, t)
    pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
    unmask_rate = pred_rate.unmask_rate  # (B, L, V)
    len_rate = pred_rate.length_rate  # (B, L+1)

    # ——— get pretrained model rates for log_rnd computation ———
    pretrained_pred = pretrained(xt, t)
    pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
    pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
    pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

    # ——— unmask step (Euler) ———
    mask_pos = (xt == mask).nonzero(as_tuple=True)
    unmask_rate[xt != mask] = 0
    unmask_rate[mask_pos + (mask,)] = 0
    unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    # The negative mask->mask entry is clamped to 0 here, so the "stay" mass
    # added below is exactly 1 for every position.
    trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

    # same for pretrained
    pretrained_unmask_rate[xt != mask] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[
        mask_pos + (slice(None),)
    ].sum(dim=1)
    pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

    # add "stay" probability (pad positions are treated as mask for indexing)
    _xt = xt.clone()
    _xt[xt == pad] = mask
    stay_index = _xt.unsqueeze(-1)
    trans_prob.scatter_add_(
        2,
        stay_index,
        torch.ones_like(stay_index, dtype=trans_prob.dtype),
    )
    pretrained_trans_prob.scatter_add_(
        2,
        stay_index,
        torch.ones_like(stay_index, dtype=pretrained_trans_prob.dtype),
    )

    if last_step:
        print("Final step, removing mask token from sampling")
        # remove mask token from sampling at the last step
        trans_prob[mask_pos + (mask,)] = 0.0

        # Renormalise every masked row. Rows with no mass left fall back to a
        # uniform distribution over token ids [0, mask), excluding pad.
        masked_rows = trans_prob[mask_pos]  # one row per masked position
        row_sum = masked_rows.sum(dim=-1, keepdim=True)
        dead = row_sum == 0.0

        fallback = torch.zeros(
            trans_prob.size(-1), dtype=trans_prob.dtype, device=trans_prob.device
        )
        fallback[:mask] = 1.0
        if 0 <= pad < mask:
            fallback[pad] = 0.0
        fallback = fallback / fallback.sum().clamp(min=1.0)

        safe_sum = torch.where(dead, torch.ones_like(row_sum), row_sum)
        trans_prob[mask_pos] = torch.where(dead, fallback, masked_rows / safe_sum)

    new_xt = _sample_tokens(trans_prob)
    new_xt[xt == pad] = pad
    new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

    # ——— compute log probabilities for RND ———
    lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
    lp_pre = torch.gather(
        torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)
    ).squeeze(-1)

    # only positions that were actually unmasked in this step contribute
    changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

    log_policy_step = (lp * changed_mask).sum(dim=1)
    log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

    log_rnd = log_pretrained_step - log_policy_step  # (B,)

    if not last_step:
        # ——— gap-wise insertion — compute new length, fill masks, scatter tokens ———
        ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)

        insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
        pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

        # log P(ext; λ) = ext*log(λ) - λ
        # Scored on the raw draw, before the gap/length filtering below.
        log_policy_insert = (
            ext * torch.log(insertion_rate) - insertion_rate
        ).sum(dim=1)  # (B,)
        log_pretrained_insert = (
            ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate
        ).sum(dim=1)  # (B,)

        log_rnd += log_pretrained_insert - log_policy_insert
        log_pretrained_step += log_pretrained_insert
        log_policy_step += log_policy_insert

        xt_len = xt.ne(pad).sum(dim=1)  # (B,)
        seq_dim = ext.size(1)  # Use actual ext dimension to avoid mismatch
        gaps = torch.arange(seq_dim, device=device).view(1, -1)
        # keep only gaps that exist in the current sequence
        ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
        # reject all insertions of a row that would overflow max_length
        valid = xt_len + ext.sum(dim=1) <= max_length
        ext = ext * valid.view(batch_size, 1).long()
        # recount after rejection so rejected rows keep their length
        total_ext = ext.sum(dim=1)

        ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
        new_len = xt_len + total_ext  # (B,)

        xt_tmp = torch.full_like(xt, pad)
        mask_fill = pos_idx_L < new_len.view(batch_size, 1)
        xt_tmp[mask_fill] = mask

        # shift each original token right by the number of insertions before it
        new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
        orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
        flat_b = batch_idx_L[orig_mask]
        flat_p = new_pos_orig[orig_mask]
        xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
    else:
        xt_tmp = new_xt

    return xt_tmp, log_rnd, log_policy_step, log_pretrained_step
|
| 535 |
+
|
| 536 |
+
@torch.no_grad()
def any_order_euler_sampling_with_schedule(
    model: torch.nn.Module,
    time_schedule: torch.Tensor,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
) -> SamplingResult:
    """Sample sequences with Euler unmask/insert steps on a custom time grid.

    Args:
        model: Model exposing ``interpolant.to_actual_rate``.
        time_schedule: 1-D tensor of time points; ``len(time_schedule) - 1``
            steps are taken. An ascending schedule is flipped to descending.
            NOTE(review): the fixed-step samplers in this file start at t=0
            and increase t — confirm the descending convention is intended.
        mask: Id of the mask token.
        pad: Id of the padding token.
        batch_size: Number of sequences to sample.
        max_length: Maximum sequence length.
        return_trace: When True, record per-sample change/insertion events.
        temperature: Temperature applied to the per-position transition
            distribution (skipped when equal to 1.0).

    Returns:
        ``(xt, sampling_trace)``; ``sampling_trace`` is ``None`` unless
        ``return_trace`` is True.
    """
    device = next(model.parameters()).device

    time_schedule = time_schedule.to(device)
    if time_schedule[0] < time_schedule[-1]:
        time_schedule = torch.flip(time_schedule, [0])  # descending order

    steps = len(time_schedule) - 1

    # initialize all-pad sequence and trace
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    # precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        is_last = i == steps - 1

        # use scheduled timesteps
        t = time_schedule[i].repeat(batch_size)
        t_next = time_schedule[i + 1]
        dt = (t - t_next).abs()  # timestep difference, shape (B,)

        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt[:, None, None]).clamp(0.0, 1.0)

        # add "stay" probability (pad positions are treated as mask for indexing)
        _xt = xt.clone()
        _xt[xt == pad] = mask
        stay_index = _xt.unsqueeze(-1)
        trans_prob.scatter_add_(
            2,
            stay_index,
            torch.ones_like(stay_index, dtype=trans_prob.dtype),
        )

        # Apply temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if is_last:
            print("Final step, removing mask token from sampling")
            # remove mask token from sampling at the last step
            trans_prob[mask_pos + (mask,)] = 0.0

            # Renormalise every masked row. Rows with no mass left fall back to
            # a uniform distribution over token ids [0, mask), excluding pad.
            masked_rows = trans_prob[mask_pos]
            row_sum = masked_rows.sum(dim=-1, keepdim=True)
            dead = row_sum == 0.0

            fallback = torch.zeros(
                trans_prob.size(-1), dtype=trans_prob.dtype, device=trans_prob.device
            )
            fallback[:mask] = 1.0
            if 0 <= pad < mask:
                fallback[pad] = 0.0
            fallback = fallback / fallback.sum().clamp(min=1.0)

            safe_sum = torch.where(dead, torch.ones_like(row_sum), row_sum)
            trans_prob[mask_pos] = torch.where(dead, fallback, masked_rows / safe_sum)

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        if not is_last:
            # ——— gap-wise insertion — compute new length, fill masks, scatter tokens ———
            ext = torch.bernoulli((len_rate * dt[:, None]).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # keep only gaps that exist in the current sequence
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # reject all insertions of a row that would overflow max_length
            valid = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * valid.view(batch_size, 1).long()
            # recount after rejection so rejected rows keep their length
            total_ext = ext.sum(dim=1)

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # shift each original token right by the insertions before it
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            ext = None  # no insertion happens on the final step
            xt_tmp = new_xt

        if return_trace:
            # Pull everything to the host once instead of per-element reads.
            step_times = t.tolist()
            changed = ((xt != pad) & (xt != new_xt)).cpu()
            new_tokens = new_xt.cpu()
            inserted = None if ext is None else ext[:, :max_length].cpu()

            for batch_idx in range(batch_size):
                # Check if the token was changed
                for j in changed[batch_idx].nonzero(as_tuple=True)[0].tolist():
                    sampling_trace[batch_idx].append(
                        SamplingTraceDatapoint(
                            t=step_times[batch_idx],
                            event_type="change",
                            position=j,
                            token=new_tokens[batch_idx, j].item(),
                        )
                    )

                # Check if a new token was inserted (reported right-to-left)
                if inserted is None:
                    continue
                gap_ids = inserted[batch_idx].nonzero(as_tuple=True)[0].tolist()
                for gap in reversed(gap_ids):
                    sampling_trace[batch_idx].append(
                        SamplingTraceDatapoint(
                            t=step_times[batch_idx],
                            event_type="insertion",
                            position=gap,
                            token=mask,
                        )
                    )

        xt = xt_tmp

    return xt, sampling_trace
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
@torch.no_grad()
def any_order_mask_insertion_euler_sampling_with_rnd(
    model, pretrained, reward_model, analyzer,
    tokenizer, steps,
    mask,
    pad,
    batch_size,
    max_length,
    return_trace = False,
    alpha = 0.1,
    temperature: float = 1.0,
):
    """Sample with the policy and accumulate the log-ratio against ``pretrained``.

    Runs ``steps`` Euler unmask/insert steps from an all-pad batch, summing at
    every step ``log p_pretrained - log p_policy`` of the sampled events. The
    final sequences are decoded, filtered with ``analyzer.is_peptide`` and
    scored with ``reward_model``; the summed reward divided by ``alpha`` is
    added to the accumulated log-ratio.

    Args:
        model: Policy model exposing ``interpolant.to_actual_rate``.
        pretrained: Reference model with the same interface.
        reward_model: Callable taking ``input_seqs=`` and returning one score
            vector per sequence.
        analyzer: Object providing ``is_peptide(seq)``.
        tokenizer: Object providing ``batch_decode``.
        steps: Number of Euler steps (``dt = 1 / steps``).
        mask: Id of the mask token.
        pad: Id of the padding token.
        batch_size: Number of sequences to sample.
        max_length: Maximum sequence length.
        return_trace: When True, record per-sample change/insertion events.
        alpha: Divisor applied to the scalar reward.
        temperature: Temperature applied to the policy transition
            distribution (skipped when equal to 1.0).

    Returns:
        ``(valid_x_final, log_rnd, scalar_rewards, sampling_trace)`` restricted
        to the valid sequences. When no sequence is valid the three tensors
        are empty along the first dimension.
    """
    device = next(model.parameters()).device

    # initialize all-pad sequence and trace
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    # initialize log_rnd to accumulate log probability ratios
    log_rnd = torch.zeros(batch_size, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        is_last = i == steps - 1

        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # ——— get pretrained model rates for log_rnd computation ———
        pretrained_pred = pretrained(xt, t)
        pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
        pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
        pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # Same for pretrained
        pretrained_unmask_rate[xt != mask] = 0
        pretrained_unmask_rate[mask_pos + (mask,)] = 0
        pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[
            mask_pos + (slice(None),)
        ].sum(dim=1)
        pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" probability (pad positions are treated as mask for indexing)
        _xt = xt.clone()
        _xt[xt == pad] = mask
        stay_index = _xt.unsqueeze(-1)
        trans_prob.scatter_add_(
            2,
            stay_index,
            torch.ones_like(stay_index, dtype=trans_prob.dtype),
        )
        pretrained_trans_prob.scatter_add_(
            2,
            stay_index,
            torch.ones_like(stay_index, dtype=pretrained_trans_prob.dtype),
        )

        # Apply temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if is_last:
            print("Final step, removing mask token from sampling")
            # remove mask token from sampling at the last step
            trans_prob[mask_pos + (mask,)] = 0.0

            # Renormalise every masked row. Rows with no mass left fall back to
            # a uniform distribution over token ids [0, mask), excluding pad.
            masked_rows = trans_prob[mask_pos]
            row_sum = masked_rows.sum(dim=-1, keepdim=True)
            dead = row_sum == 0.0

            fallback = torch.zeros(
                trans_prob.size(-1), dtype=trans_prob.dtype, device=trans_prob.device
            )
            fallback[:mask] = 1.0
            if 0 <= pad < mask:
                fallback[pad] = 0.0
            fallback = fallback / fallback.sum().clamp(min=1.0)

            safe_sum = torch.where(dead, torch.ones_like(row_sum), row_sum)
            trans_prob[mask_pos] = torch.where(dead, fallback, masked_rows / safe_sum)

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # ——— compute log probabilities for RND ———
        lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
        lp_pre = torch.gather(
            torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)
        ).squeeze(-1)

        # only positions that were actually unmasked in this step contribute
        changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

        log_policy_step = (lp * changed_mask).sum(dim=1)
        log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

        # accumulate across steps (plain assignment would drop earlier steps)
        log_rnd = log_rnd + (log_pretrained_step - log_policy_step)  # (B,)

        if not is_last:
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)

            insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
            pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

            # log P(ext; λ) = ext*log(λ) - λ, scored on the raw draw
            log_policy_insert = (
                ext * torch.log(insertion_rate) - insertion_rate
            ).sum(dim=1)  # (B,)
            log_pretrained_insert = (
                ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate
            ).sum(dim=1)  # (B,)

            log_rnd = log_rnd + (log_pretrained_insert - log_policy_insert)

            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # keep only gaps that exist in the current sequence
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # reject all insertions of a row that would overflow max_length
            valid = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * valid.view(batch_size, 1).long()
            # recount after rejection so rejected rows keep their length
            total_ext = ext.sum(dim=1)

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # shift each original token right by the insertions before it
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            ext = None  # no insertion happens on the final step
            xt_tmp = new_xt

        if return_trace:
            # Pull everything to the host once instead of per-element reads.
            step_times = t.tolist()
            changed = ((xt != pad) & (xt != new_xt)).cpu()
            new_tokens = new_xt.cpu()
            inserted = None if ext is None else ext[:, :max_length].cpu()

            for b in range(batch_size):
                # check if the token was changed
                for j in changed[b].nonzero(as_tuple=True)[0].tolist():
                    sampling_trace[b].append(
                        SamplingTraceDatapoint(
                            t=step_times[b],
                            event_type="change",
                            position=j,
                            token=new_tokens[b, j].item(),
                        )
                    )

                # check if a new token was inserted (reported right-to-left)
                if inserted is None:
                    continue
                gap_ids = inserted[b].nonzero(as_tuple=True)[0].tolist()
                for gap in reversed(gap_ids):
                    sampling_trace[b].append(
                        SamplingTraceDatapoint(
                            t=step_times[b],
                            event_type="insertion",
                            position=gap,
                            token=mask,
                        )
                    )

        xt = xt_tmp
        t = t + dt

    # change rewards for peptides
    samples = xt.to(device)

    # Decode the raw token ids and keep only sequences the analyzer accepts
    decoded_samples = tokenizer.batch_decode(samples)

    valid_x_final = []
    validSequences = []
    valid_log_rnd = []

    for idx, seq in enumerate(decoded_samples):
        # check if the peptide is valid
        if analyzer.is_peptide(seq):
            valid_x_final.append(xt[idx])
            validSequences.append(seq)
            valid_log_rnd.append(log_rnd[idx])

    print("len valid sequences:", len(validSequences))

    if not validSequences:
        # Nothing to score: return empty tensors instead of failing in torch.stack.
        empty = torch.zeros(0, dtype=torch.float32, device=device)
        return xt[:0], empty, empty.clone(), sampling_trace

    # compute multi-objective rewards and collapse them to one scalar per sequence
    score_vectors = reward_model(input_seqs=validSequences)
    scalar_rewards = np.sum(score_vectors, axis=-1)
    scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device)

    print(f"scalar reward dim{len(scalar_rewards)}")
    valid_log_rnd = torch.stack(valid_log_rnd, dim=0)

    log_rnd = valid_log_rnd + (scalar_rewards / alpha)  # scale down by alpha
    valid_x_final = torch.stack(valid_x_final, dim=0)

    return valid_x_final, log_rnd, scalar_rewards, sampling_trace
|
| 887 |
+
|
| 888 |
+
@torch.no_grad()
def any_order_finetuned_euler_sampler(
    model, reward_model, analyzer,
    tokenizer, steps,
    mask,
    pad,
    batch_size,
    max_length,
    return_trace = False,
    dataframe = False,
    temperature: float = 1.0,
):
    """Sample from ``model`` with fixed-step Euler updates and score the results.

    Runs ``steps`` unmask/insert steps from an all-pad batch, decodes the
    final tokens, keeps the sequences accepted by ``analyzer.is_peptide`` and
    scores them with ``reward_model``.

    Args:
        model: Model exposing ``interpolant.to_actual_rate``.
        reward_model: Callable taking ``input_seqs=``; its result is transposed
            and rows 0-4 are read as affinity, solubility, hemolysis,
            nonfouling and permeability.
        analyzer: Object providing ``is_peptide(seq)``.
        tokenizer: Object providing ``batch_decode``.
        steps: Number of Euler steps (``dt = 1 / steps``).
        mask: Id of the mask token.
        pad: Id of the padding token.
        batch_size: Number of sequences to sample.
        max_length: Maximum sequence length.
        return_trace: When True, per-sample events are collected internally
            (the trace is not part of the return value).
        dataframe: When True, also return a DataFrame with one row per valid
            sequence.
        temperature: Temperature applied to the transition distribution
            (skipped when equal to 1.0).

    Returns:
        ``(samples, affinity, sol, hemo, nf, permeability, valid_fraction)``,
        with the DataFrame appended when ``dataframe`` is True. When no
        sequence is valid, each score is ``[0.0]`` and the DataFrame is empty.
    """
    device = next(model.parameters()).device

    # initialize all-pad sequence and trace
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        is_last = i == steps - 1

        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" probability (pad positions are treated as mask for indexing)
        _xt = xt.clone()
        _xt[xt == pad] = mask
        stay_index = _xt.unsqueeze(-1)
        trans_prob.scatter_add_(
            2,
            stay_index,
            torch.ones_like(stay_index, dtype=trans_prob.dtype),
        )

        # Apply temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if is_last:
            print("Final step, removing mask token from sampling")
            # remove mask token from sampling at the last step
            trans_prob[mask_pos + (mask,)] = 0.0

            # Renormalise every masked row. Rows with no mass left fall back to
            # a uniform distribution over token ids [0, mask), excluding pad.
            masked_rows = trans_prob[mask_pos]
            row_sum = masked_rows.sum(dim=-1, keepdim=True)
            dead = row_sum == 0.0

            fallback = torch.zeros(
                trans_prob.size(-1), dtype=trans_prob.dtype, device=trans_prob.device
            )
            fallback[:mask] = 1.0
            if 0 <= pad < mask:
                fallback[pad] = 0.0
            fallback = fallback / fallback.sum().clamp(min=1.0)

            safe_sum = torch.where(dead, torch.ones_like(row_sum), row_sum)
            trans_prob[mask_pos] = torch.where(dead, fallback, masked_rows / safe_sum)

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        if not is_last:
            # gap-wise insertion — compute new length, fill masks, scatter tokens
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # keep only gaps that exist in the current sequence
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # reject all insertions of a row that would overflow max_length
            valid = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * valid.view(batch_size, 1).long()
            # recount after rejection so rejected rows keep their length
            total_ext = ext.sum(dim=1)

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # shift each original token right by the insertions before it
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            ext = None  # no insertion happens on the final step
            xt_tmp = new_xt

        if return_trace:
            # Pull everything to the host once instead of per-element reads.
            step_times = t.tolist()
            changed = ((xt != pad) & (xt != new_xt)).cpu()
            new_tokens = new_xt.cpu()
            inserted = None if ext is None else ext[:, :max_length].cpu()

            for batch_idx in range(batch_size):
                # check if the token was changed
                for j in changed[batch_idx].nonzero(as_tuple=True)[0].tolist():
                    sampling_trace[batch_idx].append(
                        SamplingTraceDatapoint(
                            t=step_times[batch_idx],
                            event_type="change",
                            position=j,
                            token=new_tokens[batch_idx, j].item(),
                        )
                    )

                # check if a new token was inserted (reported right-to-left)
                if inserted is None:
                    continue
                gap_ids = inserted[batch_idx].nonzero(as_tuple=True)[0].tolist()
                for gap in reversed(gap_ids):
                    sampling_trace[batch_idx].append(
                        SamplingTraceDatapoint(
                            t=step_times[batch_idx],
                            event_type="insertion",
                            position=gap,
                            token=mask,
                        )
                    )

        xt = xt_tmp
        t = t + dt

    # start eval
    samples = xt.to(device)

    decoded_samples = tokenizer.batch_decode(samples)
    validSequences = [seq for seq in decoded_samples if analyzer.is_peptide(seq)]

    print("len valid sequences:", len(validSequences))
    valid_fraction = len(validSequences) / batch_size
    has_valid = len(validSequences) != 0

    if has_valid:
        # add scores to log
        score_vectors = reward_model(input_seqs=validSequences)  # (num_children, num_objectives)
        average_scores = score_vectors.T

        affinity = average_scores[0]
        sol = average_scores[1]
        hemo = average_scores[2]
        nf = average_scores[3]
        permeability = average_scores[4]
    else:
        zeros = [0.0]

        affinity = zeros
        sol = zeros
        hemo = zeros
        nf = zeros
        permeability = zeros

    if dataframe:
        # Every column must have one entry per valid sequence; with no valid
        # sequences all columns are empty.
        df = pd.DataFrame({
            "Peptide Sequence": validSequences,
            "Binding Affinity": affinity if has_valid else [],
            "Solubility": sol if has_valid else [],
            "Hemolysis": hemo if has_valid else [],
            "Nonfouling": nf if has_valid else [],
            "Permeability": permeability if has_valid else [],
        })
        return samples, affinity, sol, hemo, nf, permeability, valid_fraction, df

    return samples, affinity, sol, hemo, nf, permeability, valid_fraction
|
| 1074 |
+
|
| 1075 |
+
@torch.no_grad()
def mdm_tau_leaping_sampling(
    model: "MaskedDiffusionModule",
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
):
    """Sample fixed-length sequences from a masked diffusion model by tau-leaping.

    Starts from an all-mask batch of shape ``(batch_size, max_length)``. On
    every step but the last, Poisson event counts are drawn from
    ``unmask_rate * dt`` and a masked position is unmasked only when exactly
    one event fired for it. On the last step every remaining masked position
    takes the highest-rate token, with the mask token itself excluded.

    Args:
        model: Model exposing ``interpolant.to_actual_rate``.
        steps: Number of steps (``dt = 1 / steps``).
        mask: Id of the mask token.
        pad: Accepted for signature compatibility; not used here.
        batch_size: Number of sequences to sample.
        max_length: Length of every sequence.
        return_trace: Must be False; tracing is not supported.
        temperature: Accepted for signature compatibility; not used here.

    Returns:
        ``(xt, [])`` where ``xt`` holds the sampled token ids.
    """
    assert not return_trace, "Trace is not yet supported"
    device = next(model.parameters()).device
    xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    for i in range(steps):
        # ——— predict and convert rates ———
        pred = model(xt, t)
        pred = model.interpolant.to_actual_rate(xt, pred, t)
        unmask_rate = pred.unmask_rate  # (B, L, V)

        still_masked = xt == mask  # (B, L)

        if i == steps - 1:
            # last step: deterministic unmask via argmax; the mask id is
            # excluded so no position can remain masked in the output
            final_rate = unmask_rate.clone()
            final_rate[..., mask] = float("-inf")
            proposal = final_rate.argmax(dim=2)  # (B, L)
            accept = still_masked
        else:
            # tau-leaping via Poisson counts
            counts = torch.poisson(unmask_rate * dt).long()
            # zero out non-mask positions and mask→mask
            counts[~still_masked] = 0
            counts[..., mask] = 0
            # only accept exactly one event
            accept = (counts.sum(dim=2) == 1) & still_masked
            proposal = counts.argmax(dim=2)  # (B, L)

        # keep already-unmasked tokens; update accepted masked positions only
        xt = torch.where(accept, proposal, xt)
        t = t + dt

    return xt, []
|
| 1128 |
+
|
| 1129 |
+
# Not used in production, for debugging purposes
# Debug-only table mapping a sequence length to its weight (the four weights
# sum to 1.0). calculate_rate_batch iterates it as (length, probability) pairs.
lengths = {4: 0.1, 16: 0.4, 32: 0.4, 64: 0.1}
|
| 1131 |
+
|
| 1132 |
+
def binomial_mass(k, n, p):
    """
    Calculate the probability mass function (PMF) for a binomial distribution.

    Args:
        k (int): Number of successes
        n (int): Number of trials
        p (float): Probability of success in a single trial

    Returns:
        float: Probability mass P(X = k); 0.0 when k is outside [0, n]
    """
    import math

    # Outside the support (k < 0, k > n, or n < 0) the mass is zero.
    if not 0 <= k <= n:
        return 0.0

    # Exact integer binomial coefficient (n choose k)
    binom_coef = math.comb(n, k)

    try:
        return float(binom_coef * (p ** k) * ((1 - p) ** (n - k)))
    except OverflowError:
        # The coefficient does not fit in a float (large n). A coefficient
        # that large implies 0 < k < n, so the mass is 0 at p == 0 or p == 1.
        if p == 0 or p == 1:
            return 0.0
        if not 0.0 < p < 1.0:
            raise
        # Evaluate in log space instead.
        log_mass = (
            math.log(binom_coef)
            + k * math.log(p)
            + (n - k) * math.log1p(-p)
        )
        return math.exp(log_mass)
|
| 1155 |
+
|
| 1156 |
+
def calculate_rate_batch(alpha_t, len_t):
    """
    Vectorized rate computation over a batch of (alpha_t, len_t) pairs.

    For every entry, the candidate lengths in the module-level ``lengths``
    table that are >= the entry's current length contribute a term weighted by
    the table probability times the Binomial(length, alpha_t) mass at len_t.
    The rate is the weighted mean of ``length - len_t`` over those candidates.

    Args:
        alpha_t (torch.Tensor): Tensor of shape (batch_size,)
        len_t (torch.Tensor): Tensor of shape (batch_size,)

    Returns:
        torch.Tensor: Tensor of shape (batch_size,) with the rates; entries
        whose accumulated weight is not positive are left at 0.
    """
    numerator = torch.zeros(alpha_t.shape[0], device=alpha_t.device)
    denominator = torch.zeros_like(numerator)

    for total_len, prior in lengths.items():
        # Only entries with 0 <= len_t <= total_len can reach this length.
        in_range = (len_t >= 0) & (len_t <= total_len)
        if in_range.any():
            rows = in_range.nonzero(as_tuple=True)[0]
            cur_len = len_t[rows]

            # Binomial mass of observing cur_len out of total_len trials.
            pmf = (
                torch.distributions.Binomial(total_count=total_len, probs=alpha_t[rows])
                .log_prob(cur_len)
                .exp()
            )

            numerator[rows] += (total_len - cur_len) * prior * pmf
            denominator[rows] += prior * pmf

    # Divide only where the accumulated weight is positive; the rest stay 0.
    positive = denominator > 0
    rate = torch.zeros_like(numerator)
    rate[positive] = numerator[positive] / denominator[positive]

    return rate
|
| 1199 |
+
|
| 1200 |
+
# Keep the original function for backward compatibility
def calculate_rate(alpha_t, len_t):
    """Legacy scalar version of calculate_rate.

    Tensor inputs with at least one dimension are forwarded to
    calculate_rate_batch; everything else is handled by the scalar loop below.

    Args:
        alpha_t: Success probability passed to binomial_mass (scalar, or a
            tensor that triggers the batched path).
        len_t: Current length; only table lengths >= len_t contribute.

    Returns:
        The weighted mean of ``length - len_t`` over the contributing entries
        of ``lengths``, or 0.0 when their total weight is zero.
    """
    if isinstance(alpha_t, torch.Tensor) and alpha_t.ndim > 0:
        return calculate_rate_batch(alpha_t, len_t)

    nom, denom = 0, 0
    for length, probability in lengths.items():
        if length >= len_t:
            # The mass is shared by numerator and denominator: compute it once.
            mass = binomial_mass(len_t, length, alpha_t)
            nom += (length - len_t) * probability * mass
            denom += probability * mass

    if denom == 0:
        return 0.0

    return nom / denom
|
| 1216 |
+
|
| 1217 |
+
|
| 1218 |
+
def _unmask_confidence(unmask_rate, xt, pad, confidence_method):
    """Score every position for confidence-based unmasking (higher = unmask first).

    Args:
        unmask_rate: (B, L, V) per-position token rates. For every method
            except "position" they are passed through a softmax as logits.
        xt: (B, L) current token ids.
        pad: Padding token id, used to measure the current sequence length.
        confidence_method: "position", "top_prob", "prob_diff" or "entropy".

    Returns:
        (B, L) tensor of confidence scores.

    Raises:
        ValueError: If ``confidence_method`` is not one of the names above.
    """
    if confidence_method == "position":
        # Earlier positions score higher: 1 - index / current sequence length.
        xt_len = (xt != pad).sum(dim=1)  # (B,)
        position_indices = (
            torch.arange(xt.shape[1], device=xt.device).unsqueeze(0).expand(xt.shape[0], -1)
        )  # (B, L)
        return 1.0 - (position_indices.float() / xt_len.unsqueeze(1).float().clamp(min=1))

    if confidence_method not in ("top_prob", "prob_diff", "entropy"):
        raise ValueError(f"Unknown confidence_method: {confidence_method}")

    unmask_probs = torch.softmax(unmask_rate, dim=-1)  # (B, L, V)

    if confidence_method == "top_prob":
        return unmask_probs.max(dim=-1)[0]

    if confidence_method == "prob_diff":
        # Margin between the best and second-best token probability.
        top2_probs, _ = torch.topk(unmask_probs, k=2, dim=-1)  # (B, L, 2)
        return top2_probs[:, :, 0] - top2_probs[:, :, 1]

    # "entropy": negative entropy, so a lower entropy gives a higher confidence.
    entropy = -torch.sum(unmask_probs * torch.log(unmask_probs + 1e-10), dim=-1)
    return -entropy


@torch.no_grad()
def any_order_mask_insertion_tau_leaping_sampling(
    model: torch.nn.Module,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    confidence_based_sampling: bool = True,  # whether to use confidence-based decoding
    alpha: float = 5.0,  # hyperparameter for window size calculation
    max_window: int = 32,  # Maximum window size for sliding window
    confidence_method: str = "prob_diff",  # "position", "top_prob", "prob_diff", "entropy"
    use_sliding_window: bool = False,  # whether to use sliding window for position selection
    temperature: float = 1.0,
) -> SamplingResult:
    """Sample variable-length sequences by alternating unmasking and mask insertion.

    The batch starts as all ``pad``. Each of the first ``steps - 1`` iterations
    (1) replaces some ``mask`` tokens with predicted tokens and (2) inserts new
    ``mask`` tokens into the gaps of the current sequence via Poisson counts.
    The last iteration resolves every remaining ``mask`` with an argmax over
    the unmask rates and performs no insertion.

    Args:
        model: Called as ``model(xt, t)``. Must have parameters (used to pick
            the device) and ``model.interpolant.to_actual_rate(xt, pred, t)``
            returning an object with ``unmask_rate`` (B, L, V) and
            ``length_rate`` (B, L+1).
        steps: Number of iterations; the step size is ``1 / steps``.
        mask: Mask token id.
        pad: Padding token id.
        batch_size: Number of sequences to sample.
        max_length: Width of the token buffer; insertions that would exceed it
            are rejected for that row.
        return_trace: When true, the state after every step (including the
            final one) is appended to the returned trace.
        confidence_based_sampling: When true, unmask the most confident mask
            positions; otherwise use per-token Poisson tau-leaping.
        alpha: With the sliding window, the window covers the first
            ``min(alpha * unmask_count, max_window, num_masks)`` mask positions.
        max_window: Upper bound on the sliding-window size.
        confidence_method: Scoring rule, see ``_unmask_confidence``.
        use_sliding_window: Restrict confidence-based unmasking to the window.
        temperature: Accepted for interface compatibility; not used here.

    Returns:
        ``(xt, sampling_trace)``: the final (B, L) token tensor and the list of
        per-step states (empty unless ``return_trace`` is true).
    """
    device = next(model.parameters()).device
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
    sampling_trace = []
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Precompute row / column index grids used for length masks and scatter.
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    neg_inf = torch.tensor(-float('inf'), device=device)

    for i in range(steps):
        # --- predict rates ---
        pred = model(xt, t)
        pred = model.interpolant.to_actual_rate(xt, pred, t)
        unmask_rate = pred.unmask_rate  # (B, L, V)
        len_rate = pred.length_rate  # (B, L+1)

        if i == steps - 1:
            # last step: deterministic unmask via argmax, no insertion
            xt = torch.where(xt == mask, unmask_rate.argmax(dim=2), xt)
            t = t + dt
            if return_trace:
                # Record the final state too, so the trace has one entry per step.
                sampling_trace.append(xt)
            continue

        if confidence_based_sampling > 0.0:
            # --- confidence-based decoding (vectorized) ---
            mask_positions = (xt == mask)  # (B, L)
            num_mask_positions = mask_positions.sum(dim=1)  # (B,)

            # 1. Determine number of tokens to unmask using Poisson
            unmask_counts = torch.poisson(num_mask_positions.float() * dt).long()  # (B,)

            # 2. Calculate confidence based on selected method
            confidence = _unmask_confidence(unmask_rate, xt, pad, confidence_method)  # (B, L)

            # 3. Decide which positions may be unmasked at all
            if use_sliding_window:
                # Dynamic window size k for each batch row
                k_values = torch.minimum(
                    torch.minimum(
                        (alpha * unmask_counts).long(),
                        torch.tensor(max_window, device=device),
                    ),
                    num_mask_positions,
                )  # (B,)
                # Eligible = a mask position among the first k masks of the row
                mask_cumsum = mask_positions.cumsum(dim=1)  # (B, L)
                eligible = mask_positions & (mask_cumsum <= k_values.unsqueeze(1))  # (B, L)
            else:
                # No window constraint - every mask position is eligible
                eligible = mask_positions

            confidence = torch.where(eligible, confidence, neg_inf)

            # The Poisson draw can exceed the number of eligible positions.
            # Clamp it so topk never returns -inf (ineligible) positions and
            # never asks for more than L entries.
            unmask_counts = torch.minimum(unmask_counts, eligible.sum(dim=1))

            new_xt = xt.clone()

            max_unmask = int(unmask_counts.max().item())
            if max_unmask > 0:
                _, all_top_indices = torch.topk(confidence, k=max_unmask, dim=1, largest=True)  # (B, max_unmask)

                # Row b only uses its first unmask_counts[b] candidates.
                unmask_mask = (
                    torch.arange(max_unmask, device=device).unsqueeze(0) < unmask_counts.unsqueeze(1)
                )  # (B, max_unmask)

                most_likely_tokens = unmask_rate.argmax(dim=-1)  # (B, L)

                selected_positions = all_top_indices[unmask_mask]
                batch_indices = (
                    torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, max_unmask)[unmask_mask]
                )

                new_xt[batch_indices, selected_positions] = most_likely_tokens[batch_indices, selected_positions]
        else:
            # --- tau-leaping unmask via Poisson ---
            counts = torch.poisson(unmask_rate * dt).long()
            mask_pos = xt == mask
            # zero out non-mask positions and mask -> mask transitions
            counts[~mask_pos.unsqueeze(-1).expand_as(counts)] = 0
            counts[..., mask] = 0
            # only accept positions where exactly one event fired
            one_event = counts.sum(dim=2) == 1
            new_token = counts.argmax(dim=2)
            new_xt = xt.clone()
            new_xt[one_event] = new_token[one_event]

        # Unmasking may only change positions that currently hold the mask
        # token; pads and already-unmasked tokens are kept as they are.
        new_xt = torch.where(xt == mask, new_xt, xt)

        # --- Poisson insertion, compute new lengths and fill masks ---
        ext = torch.poisson(len_rate * dt).long()  # (B, L+1)
        xt_len = xt.ne(pad).sum(dim=1)  # (B,)
        gaps = torch.arange(max_length + 1, device=device).view(1, -1)
        # gaps beyond the current sequence end do not exist
        ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
        # reject all insertions for rows that would overflow the buffer
        valid = xt_len + ext.sum(dim=1) <= max_length
        ext = ext * valid.view(batch_size, 1).long()
        # Count insertions AFTER the rejection so that rejected rows keep their
        # current length instead of being padded out with masks.
        total_ext = ext.sum(dim=1)

        # compute prefix sums of insertions
        ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
        new_len = xt_len + total_ext  # (B,)

        # initialize with pads, then fill mask up to new_len
        xt_tmp = torch.full_like(xt, pad)
        xt_tmp[pos_idx_L < new_len.view(batch_size, 1)] = mask

        # shift and scatter original tokens
        new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
        orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
        flat_b = batch_idx_L[orig_mask]
        flat_p = new_pos_orig[orig_mask]
        xt_tmp[flat_b, flat_p] = new_xt[orig_mask]

        xt = xt_tmp
        t = t + dt
        if return_trace:
            sampling_trace.append(xt)

    return xt, sampling_trace
|
a2d2_pep/scripts/run_peptide_finetune.slurm
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# NOTE: --partition and --qos below are specific to our cluster. Change them
# (or remove them and pass `--partition` on the `sbatch` command line) to match
# the partitions/QOS available on yours.
#SBATCH --job-name=peptide-finetune-len256
#SBATCH --partition=b200-mig90
#SBATCH --qos=mig
#SBATCH --nodes=1
#SBATCH --gpus-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --ntasks-per-node=1
#SBATCH --mem=80GB
#SBATCH --time=02-00:00:00
#SBATCH --output=logs/peptide_finetune_%A.log

# =====================================================================
# run_peptide_finetune.slurm
#
# Single-mode job (1 MIG GPU) running ONE finetune_quality (peptide)
# experiment. Select which mode to run via the MODE_ID variable below
# (or override at submit time with `sbatch --export=ALL,MODE_ID=2 ...`):
#   0) A2D2 (Ours)                  – with full planner (alternating)
#   1) A2D2 w/o quality             – --disable_planner
#   2) A2D2 w/o insertion planner   – --disable_insertion_planner
#   3) A2D2 w/o unmasking planner   – --disable_unmasking_planner
#
# The job trains the selected mode then evaluates the resulting
# checkpoint on the same GPU.
# =====================================================================

# Abort on the first failing command. `pipefail` is intentionally NOT enabled:
# the checkpoint lookup below relies on `ls` being allowed to fail inside a
# pipeline so that the "No checkpoint found" message can be printed.
set -e

# --- Mode selection ---------------------------------------------------
# Which experiment to run (0-3). Override with `--export=ALL,MODE_ID=N`.
MODE_ID="${MODE_ID:-0}"

# Run prefix: YYYYMMDD + SLURM job ID
DATE_STAMP=$(date +%Y%m%d)
PREFIX="${DATE_STAMP}_job${SLURM_JOB_ID:-local$(date +%H%M%S)}"

# Default protein target (must be defined before path definitions below)
PROT_NAME=tfr

# --- Paths ------------------------------------------------------------
# Repo root is resolved at submit time so the script works from any clone:
#   - set A2D2_ROOT explicitly, OR
#   - run `sbatch` from the repo root (SLURM sets SLURM_SUBMIT_DIR), OR
#   - fall back to this script's location (a2d2_pep/scripts/ -> two levels up).
if [ -n "${A2D2_ROOT:-}" ]; then
    HOME_LOC="$A2D2_ROOT"
elif [ -n "${SLURM_SUBMIT_DIR:-}" ]; then
    HOME_LOC="$SLURM_SUBMIT_DIR"
else
    HOME_LOC="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
fi
SCRIPT_LOC="$HOME_LOC/a2d2_pep"
LOG_LOC="$HOME_LOC/logs"
SAVE_DIR="$HOME_LOC/checkpoints/finetune_test_peptides_${PROT_NAME}"
RESULTS_DIR="$HOME_LOC/results/peptide_test_ablation_${PROT_NAME}"

cd "$SCRIPT_LOC"

# BASE_PATH is passed as --base_path to finetune_quality.py: it's used
# to build the plot output path at $BASE_PATH/flexible/results/<run_name>
# (see finetune_quality.py:421). The pretrained checkpoint is now passed
# explicitly via --checkpoint_path below, so base_path no longer needs
# to follow the legacy /scratch layout.
BASE_PATH="${A2D2_BASE_PATH:-$HOME_LOC}"

mkdir -p "$LOG_LOC" "$SAVE_DIR" "$RESULTS_DIR"

# --- Environment setup ------------------------------------------------
# Do NOT hardcode your W&B key. Either `wandb login` once on the cluster,
# export WANDB_API_KEY in your shell/SLURM environment before submitting,
# or set WANDB_MODE=offline to skip logging entirely.
export WANDB_DIR="$HOME_LOC/.wandb"
export WANDB_CONFIG_DIR="$HOME_LOC/.config/wandb"
export WANDB_CACHE_DIR="$HOME_LOC/.cache/wandb"
# Stop wandb from hijacking stdout/stderr (its default fd-redirect mode sends
# all output to wandb/run-*/files/output.log and freezes the RUN_LOG below).
# With console off, everything flows to the `>> "$RUN_LOG" 2>&1` redirect.
export WANDB_CONSOLE=off
mkdir -p "$WANDB_DIR" "$WANDB_CONFIG_DIR" "$WANDB_CACHE_DIR"

export TRITON_CACHE_DIR="$HOME_LOC/.triton/cache"
mkdir -p "$TRITON_CACHE_DIR"

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Activate conda env. Override CONDA_ROOT to point at your conda/miniconda
# install, or just have `conda` on PATH; override CONDA_ENV if your env name
# differs from the one created by environment.yml.
CONDA_ENV="${CONDA_ENV:-a2d2}"
if [ -n "${CONDA_ROOT:-}" ]; then
    source "$CONDA_ROOT/bin/activate" "$CONDA_ENV"
elif command -v conda >/dev/null 2>&1; then
    source "$(conda info --base)/bin/activate" "$CONDA_ENV"
else
    echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2
    exit 1
fi
# `command -v` is the POSIX way to resolve the interpreter of the active env.
PYTHON_EXECUTABLE=$(command -v python)

# Pretrained base checkpoint
PRETRAINED_CKPT="$HOME_LOC/pretrained/anylength_pep.ckpt"

# --- Shared training hyperparameters ----------------------------------
COMMON_ARGS=(
    --base_path "$BASE_PATH"
    --checkpoint_path "$PRETRAINED_CKPT"
    --prot_name "$PROT_NAME"
    --noise_removal
    --wdce_num_replicates 8
    --pool_size 100
    --pool_refresh_fraction 1.0
    --buffer_size 50
    --batch_size 200
    --total_num_steps 256
    --num_iter 20
    --resample_every_n_step 10
    --num_epochs 1000
    --save_every_n_epochs 50
    --reset_every_n_step 1
    --alpha 0.1
    --no_mcts
    --schedule_warmup_epochs 20
    --alternation_frequency 5
    --num_remasking 3
    --quality_threshold 0.2
    --training_mini_batch_size 10
    --max_length 256
    --eval_every_n_epochs 50
    --min_peptide_bonds 4
    --grad_clip
    --seed 42
)

# --- Shared evaluation hyperparameters --------------------------------
EVAL_COMMON_ARGS=(
    --pretrained_ckpt "$PRETRAINED_CKPT"
    --num_samples 50
    --batch_size 200
    --max_length 256
    --total_num_steps 256
    --num_remasking 3
    --quality_threshold 0.2
    --prot_name "$PROT_NAME"
    --seed 42
)

# =====================================================================
# Pick experiment from $MODE_ID
# =====================================================================
case "$MODE_ID" in
    0) MODE="with_planner";          EXTRA_ARGS=() ;;
    1) MODE="no_planner";            EXTRA_ARGS=(--disable_planner) ;;
    2) MODE="no_insertion_planner";  EXTRA_ARGS=(--disable_insertion_planner) ;;
    3) MODE="no_unmasking_planner";  EXTRA_ARGS=(--disable_unmasking_planner) ;;
    *) echo "Unknown MODE_ID=$MODE_ID (expected 0-3)" >&2; exit 1 ;;
esac

RUN_NAME="${PREFIX}_peptide_${PROT_NAME}_${MODE}"
RUN_LOG="$LOG_LOC/${RUN_NAME}.log"
RUN_SAVE_DIR="$SAVE_DIR/${RUN_NAME}"
RESULTS_SUBDIR="$RESULTS_DIR/${MODE}"
mkdir -p "$RUN_SAVE_DIR" "$RESULTS_SUBDIR"

echo "=== Peptide finetune (MODE_ID=$MODE_ID) ==="
# Defaults keep the banner meaningful when the script is run outside SLURM.
echo "Job: ${SLURM_JOB_ID:-local} Node: ${SLURM_NODELIST:-$(hostname)}"
echo "Mode: $MODE"
echo "Save dir: $RUN_SAVE_DIR"
echo "Results dir: $RESULTS_SUBDIR"
echo "Python: $PYTHON_EXECUTABLE"
echo "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES:-(unset)}"

# =====================================================================
# Train
# =====================================================================
# Interpreter and script paths are quoted so a repo path containing spaces
# is not word-split.
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/finetune_quality.py" \
    "${COMMON_ARGS[@]}" \
    --devices 1 \
    "${EXTRA_ARGS[@]}" \
    --save_path_dir "$RUN_SAVE_DIR" \
    >> "$RUN_LOG" 2>&1

echo "Training finished for $MODE. Log: $RUN_LOG"

# =====================================================================
# Evaluate
# =====================================================================
# finetune_quality.py saves to $RUN_SAVE_DIR/<auto_run_name>/last.ckpt,
# so glob the run_name subdir.
RUN_CKPT=$(ls -t "$RUN_SAVE_DIR"/*/last.ckpt 2>/dev/null | head -1)
if [ -z "$RUN_CKPT" ]; then
    echo "No checkpoint found in $RUN_SAVE_DIR — skipping eval." >&2
    exit 1
fi

echo "Evaluating checkpoint: $RUN_CKPT"
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/evaluate_peptide_table.py" \
    --checkpoint_path "$RUN_CKPT" \
    "${EVAL_COMMON_ARGS[@]}" \
    "${EXTRA_ARGS[@]}" \
    --output_dir "$RESULTS_SUBDIR" \
    --device cuda:0 \
    >> "$RESULTS_SUBDIR/${RUN_NAME}_eval.log" 2>&1

echo "Eval finished for $MODE. CSV: $RESULTS_SUBDIR/eval_metrics_${MODE}_${PROT_NAME}.csv"

conda deactivate
|
a2d2_pep/scripts/train_pep.sh
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --job-name=a2d2-pep-pretrain
#SBATCH --partition=dgx-b200
#SBATCH --nodes=1
#SBATCH --gpus-per-node=4
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=8
#SBATCH --mem=512GB
#SBATCH --time=7-00:00:00
# SLURM's own catch-file (anything printed before the exec redirect below, plus
# slurm-infra messages). Relative to the submit dir, so submit this script from
# the a2d2_pep/ directory; the real run output is redirected via exec below.
#SBATCH --output=logs/slurm/%x_%j.out
#SBATCH --error=logs/slurm/%x_%j.err
#
# Pretrain the any-length insertion MDM on ~11M peptide SMILES on a dgx-b200 node.
# Submit with:  sbatch scripts/train_pep.sh   (from the a2d2_pep/ directory).
#
# DDP is launched by SLURM: one srun task per GPU. --gpus-per-node and
# --ntasks-per-node must match; change both together (and they override the
# training.devices value baked into config_pep.yaml via the hydra override below).

DATE=$(date +%Y%m%d)
SPECIAL_PREFIX='a2d2-peptide'

# Resolve a2d2_pep/ (which holds train.py + config_pep.yaml) so paths are
# repo-relative. This script lives in a2d2_pep/scripts/, so the direct-run
# fallback goes one level up. Under sbatch, BASH_SOURCE points at the spooled
# copy, so we rely on SLURM_SUBMIT_DIR (submit from the a2d2_pep/ directory).
if [ -n "${SLURM_SUBMIT_DIR:-}" ]; then
    SCRIPT_DIR="$SLURM_SUBMIT_DIR"
else
    SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
fi
# This script does not use `set -e`, so stop explicitly if the directory is
# unusable instead of launching train.py from the wrong working directory.
cd "$SCRIPT_DIR" || { echo "ERROR: cannot cd to $SCRIPT_DIR" >&2; exit 1; }

# Auto-detect GPUs from the SLURM allocation (falls back to 4 for `bash` runs).
DEVICES=${SLURM_GPUS_ON_NODE:-${SLURM_GPUS_PER_NODE:-4}}
NTASKS=${SLURM_NTASKS_PER_NODE:-$DEVICES}
NODES=${SLURM_NNODES:-1}

LOG_LOC="$SCRIPT_DIR/logs"
mkdir -p "$LOG_LOC/slurm"
exec > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_${SLURM_JOB_ID:-local}.log" 2>&1

# ---------------------------------------------------------------------------
# Weights & Biases: log in once on your machine before running this script with
# `wandb login` (or `export WANDB_API_KEY=<your-key>`).
# Do NOT hardcode your API key here. To disable W&B entirely, uncomment:
#   export WANDB_MODE=disabled
# ---------------------------------------------------------------------------

export PYTORCH_ALLOC_CONF=expandable_segments:True

# Activate the conda env that has the deps (torch / pytorch_lightning / hydra).
# The batch shell does NOT source ~/.bashrc, so conda is not on PATH. Override
# CONDA_ROOT to point at your conda/miniconda install, or just have `conda` on
# PATH; override CONDA_ENV if your env name differs from the one created by
# environment.yml.
CONDA_ENV="${CONDA_ENV:-a2d2}"
if [ -n "${CONDA_ROOT:-}" ]; then
    source "$CONDA_ROOT/bin/activate" "$CONDA_ENV"
elif command -v conda >/dev/null 2>&1; then
    source "$(conda info --base)/bin/activate" "$CONDA_ENV"
else
    echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2
    exit 1
fi

# --- Distributed / NCCL setup (single node, intra-node NVLink) --------------
# Prefer an ethernet-like interface; otherwise take the first non-loopback,
# non-"ibp" interface that has an IPv4 address.
ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -E "ens|eth|enp|bond" | head -1 | awk '{print $2}')
if [ -z "$ETH_IFACE" ]; then
    ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -v "ibp" | head -1 | awk '{print $2}')
fi
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_FAMILY=AF_INET
# Only pin the interface when one was actually detected; exporting an empty
# name would override NCCL's own interface selection with nothing.
if [ -n "$ETH_IFACE" ]; then
    export NCCL_SOCKET_IFNAME="$ETH_IFACE"
else
    echo "WARNING: no network interface detected; leaving NCCL_SOCKET_IFNAME unset." >&2
fi
export NCCL_P2P_LEVEL=NVL

export MASTER_ADDR=$(scontrol show hostnames "${SLURM_NODELIST:-$(hostname)}" | head -n 1)
export MASTER_PORT=$(shuf -i 15000-59999 -n 1)
export NODE_RANK=${SLURM_NODEID:-0}

echo "=== a2d2 peptide pretraining (dgx-b200) ==="
echo "Job ID: ${SLURM_JOB_ID:-local}  Node: ${SLURM_NODELIST:-$(hostname)}  GPUs: $DEVICES  Tasks: $NTASKS"

# --task pep makes train.py load config_pep.yaml; the hydra overrides pin
# devices/nodes to the SLURM allocation so the two never drift apart.
srun --ntasks-per-node="$NTASKS" python train.py --task pep \
    "training.devices=$DEVICES" \
    "training.nodes=$NODES"

conda deactivate
|
a2d2_pep/train.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pytorch_lightning as pl
|
| 3 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 4 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import argparse
|
| 8 |
+
import hydra
|
| 9 |
+
from omegaconf import OmegaConf
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
# Directory containing this file and the config_*.yaml files (used by Hydra below).
|
| 12 |
+
CONFIG_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 13 |
+
# Add the repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve.
|
| 14 |
+
sys.path.insert(0, os.path.dirname(CONFIG_DIR))
|
| 15 |
+
|
| 16 |
+
import wandb
|
| 17 |
+
from lightning_modules import AnyOrderInsertionFlowModule
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
torch.set_printoptions(threshold=10_000)
|
| 21 |
+
torch.set_float32_matmul_precision("high")
|
| 22 |
+
|
| 23 |
+
# Disable DDP optimizer due to incompatibility with flex_attention higher-order ops
|
| 24 |
+
torch._dynamo.config.optimize_ddp = False
|
| 25 |
+
|
| 26 |
+
def train(config):
    """Run a full training job described by ``config``.

    Steps, in order: seed the RNGs, start a wandb run (local rank 0 only),
    append a timestamp to ``config.training.checkpoint_dir``, build the data
    module (HuggingFace path when ``config.hf_dataset`` exists, local path
    otherwise), build the Lightning module and trainer, fit, and finish the
    wandb run.

    Args:
        config: OmegaConf config. Reads ``wandb.*``, ``model.torch_dtype``,
            ``training.*`` and optionally ``hf_dataset.*``. The config is
            mutated in place (``training.checkpoint_dir`` and, when unset,
            ``training.warmup_steps``).

    Raises:
        ValueError: if neither ``training.max_steps`` nor
            ``training.num_epochs`` is set; if ``training.warmup_steps`` is
            unset while training by epochs (the 1% default needs
            ``max_steps``); or if neither ``save_every_n_steps`` nor
            ``save_every_n_epochs`` is set.
    """
    wandb_logger = None

    # set the random seed
    pl.seed_everything(42)
    torch.manual_seed(42)

    # Only initialize wandb on rank 0 to avoid multiple runs.
    # NOTE(review): LOCAL_RANK is per node, so a multi-node job presumably
    # starts one wandb run per node -- confirm whether a global rank is wanted.
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        wandb.init(
            project=config.wandb.project,
            name=config.wandb.name,
            config=OmegaConf.to_container(config, resolve=True),  # Convert to dict
            dir=config.wandb.path,
        )
        wandb_logger = WandbLogger(
            project=wandb.run.project,
            name=wandb.run.name,
            log_model=False,  # Disable checkpoint uploading to save disk space
        )

    # Modify config to add timestamp to checkpoint directory.
    # NOTE(review): every process computes its own timestamp; ranks that start
    # in different seconds would get different directories -- confirm.
    OmegaConf.set_struct(config, False)
    time_string = datetime.now().strftime("%Y%m%d-%H%M%S")
    config.training.checkpoint_dir = os.path.join(
        config.training.checkpoint_dir, time_string
    )
    OmegaConf.set_struct(config, True)

    # Create checkpoint directory
    os.makedirs(config.training.checkpoint_dir, exist_ok=True)

    # Setup data module - check if using HuggingFace dataset
    if hasattr(config, 'hf_dataset'):
        # Imported lazily: the HF/SAFE path is only used by the molecule configs,
        # which keep mol_dataset.py (and its `safe` dependency) in a2d2_mol/.
        from mol_dataset import setup_hf_data_and_update_config
        print(f"Using HuggingFace dataset: {config.hf_dataset.name}")
        data_module = setup_hf_data_and_update_config(
            config,
            dataset_name=config.hf_dataset.name,
            smiles_column=config.hf_dataset.get('smiles_column', 'smiles'),
        )
    else:
        # Imported lazily: the local (arrow) path is used by the peptide config,
        # which keeps dataloading_for_dynamic_batching.py in a2d2_pep/.
        from data.dataloading_for_dynamic_batching import setup_data_and_update_config
        print("Using local dataset")
        data_module = setup_data_and_update_config(config)

    module = AnyOrderInsertionFlowModule(config)

    # Map torch_dtype to Lightning precision (unknown values fall back to bf16-mixed)
    dtype_str = config.model.get('torch_dtype', 'bfloat16')
    precision_map = {
        'float32': '32-true',
        'float16': '16-mixed',
        'bfloat16': 'bf16-mixed',
    }
    precision = precision_map.get(dtype_str, 'bf16-mixed')

    # Configure trainer arguments
    trainer_kwargs = dict(
        num_nodes=config.training.nodes,
        accelerator="gpu",
        devices=config.training.devices,
        strategy="ddp",
        precision=precision,
        # Global batch size divided by the number of samples processed per
        # optimizer micro-step across all devices.
        accumulate_grad_batches=(
            config.training.batch_size
            // (
                config.training.per_gpu_batch_size
                * config.training.nodes
                * config.training.devices
            )
        ),
        log_every_n_steps=10,
        enable_checkpointing=True,
        default_root_dir=config.training.checkpoint_dir,
        gradient_clip_val=1.0,
    )

    # Only one of max_steps or max_epochs will be used
    if config.training.max_steps is not None:
        trainer_kwargs["max_steps"] = config.training.max_steps
    elif config.training.num_epochs is not None:
        trainer_kwargs["max_epochs"] = config.training.num_epochs
    else:
        raise ValueError(
            "Either max_steps or num_epochs must be specified in the config"
        )

    if config.training.warmup_steps is None:
        # The default warmup is 1% of max_steps, which is undefined when the
        # run length is given in epochs; fail with a clear message instead of
        # a TypeError from `None * 0.01`.
        if config.training.max_steps is None:
            raise ValueError(
                "training.warmup_steps must be set explicitly when training by "
                "num_epochs, because the default (1% of max_steps) needs max_steps"
            )
        config.training.warmup_steps = int(config.training.max_steps * 0.01)

    # Add ModelCheckpoint callback to keep the checkpoints with the lowest
    # training loss (the monitored metric is train/total_loss, not a val metric)
    checkpoint_callback = ModelCheckpoint(
        monitor="train/total_loss",
        mode="min",
        save_top_k=config.training.save_top_k,
        save_last=True,
        filename="epoch-{epoch:02d}-train_loss-{train/total_loss:.4f}",
        dirpath=config.training.checkpoint_dir,
        # Don't use val_loss in filename for periodic saves - causes failures when val doesn't run
        auto_insert_metric_name=False,
    )

    # Add separate callback for periodic saves (no val_loss dependency). Use
    # step-based saves for streaming datasets (save_every_n_steps) and epoch-based
    # saves otherwise (save_every_n_epochs); whichever the config provides.
    save_every_n_steps = config.training.get('save_every_n_steps', None)
    save_every_n_epochs = config.training.get('save_every_n_epochs', None)
    if save_every_n_steps is not None:
        periodic_checkpoint_callback = ModelCheckpoint(
            save_top_k=-1,  # Save all periodic checkpoints
            filename="step-{step:08d}",
            dirpath=config.training.checkpoint_dir,
            every_n_train_steps=save_every_n_steps,
            auto_insert_metric_name=False,
        )
    elif save_every_n_epochs is not None:
        periodic_checkpoint_callback = ModelCheckpoint(
            save_top_k=-1,  # Save all periodic checkpoints
            filename="epoch-{epoch:02d}",
            dirpath=config.training.checkpoint_dir,
            every_n_epochs=save_every_n_epochs,
            auto_insert_metric_name=False,
        )
    else:
        raise ValueError(
            "Either save_every_n_steps or save_every_n_epochs must be specified in the config"
        )

    trainer_kwargs["callbacks"] = [checkpoint_callback, periodic_checkpoint_callback]

    if wandb_logger is not None:
        trainer_kwargs["logger"] = wandb_logger

    trainer = pl.Trainer(**trainer_kwargs)

    # Train the model, resuming from training.resume_path when the key is present
    ckpt_path = None
    if "resume_path" in config.training:
        ckpt_path = config.training.resume_path

    trainer.fit(module,
                datamodule=data_module,
                ckpt_path=ckpt_path)

    # Only finish wandb on rank 0
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        wandb.finish()
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
if __name__ == '__main__':
    # Parse the script-specific flags; everything else is forwarded to Hydra.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_name', type=str, default='config',
                        help='Name of the config file to use')
    parser.add_argument('--task', type=str, default=None,
                        help='Task name (uses config_{task}.yaml)')

    # Parse known args (hydra will handle the rest)
    args, unknown = parser.parse_known_args()

    # Determine config name from task or config_name
    config_name = f'config_{args.task}' if args.task else args.config_name

    # Detect a Hydra config name in any accepted spelling
    # ("--config-name X", "--config-name=X", "-cn X", "-cn=X"). Re-launched DDP
    # subprocesses already carry the "--config-name=..." injected below but no
    # longer carry --task, so matching only the freshly computed name would
    # prepend a second, conflicting override.
    hydra_name_flags = ('--config-name', '-cn')
    hydra_name_prefixes = tuple(f'{flag}=' for flag in hydra_name_flags)
    config_name_given = any(
        arg in hydra_name_flags or arg.startswith(hydra_name_prefixes)
        for arg in unknown
    )

    if config_name_given:
        print("Using Hydra config name supplied on the command line")
    else:
        print(f"Using config: {config_name}.yaml")
        # Add config name to Hydra overrides (this persists across DDP subprocesses)
        unknown.insert(0, f'--config-name={config_name}')

    # Reconstruct sys.argv for hydra
    sys.argv = [sys.argv[0]] + unknown

    # Define main function with default config (will be overridden by command line)
    @hydra.main(version_base=None,
                config_path=CONFIG_DIR,
                config_name='config')
    def main(config):
        """Main entry point for training"""
        train(config)

    main()
|
assets/a2d2.gif
ADDED
|
Git LFS Details
|
demo/quality_inference_demo.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
environment.yml
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Conda environment shared across the molecule, peptide, and language experiments.
|
| 2 |
+
# Create with:
|
| 3 |
+
# conda env create -f environment.yml
|
| 4 |
+
# conda activate a2d2
|
| 5 |
+
#
|
| 6 |
+
# NOTE: flash-attn is hardware-specific and must be built against your installed torch
|
| 7 |
+
# and CUDA, so it is not listed below. It is imported by the shared transformer backbone
|
| 8 |
+
# (model/casual_transformer.py, model/rotary.py) and is required for all experiments.
|
| 9 |
+
# After creating the env, install it with:
|
| 10 |
+
# pip install flash-attn==2.8.3 --no-build-isolation
|
| 11 |
+
# Adjust pytorch-cuda below to match your CUDA toolkit / GPU.
|
| 12 |
+
name: a2d2
|
| 13 |
+
channels:
|
| 14 |
+
- pytorch
|
| 15 |
+
- nvidia
|
| 16 |
+
- conda-forge
|
| 17 |
+
dependencies:
|
| 18 |
+
- python=3.11
|
| 19 |
+
- pip
|
| 20 |
+
- pytorch
|
| 21 |
+
- pytorch-cuda=12.1
|
| 22 |
+
- rdkit=2023.9.6
|
| 23 |
+
- jupyterlab # for demo/quality_inference_demo.ipynb
|
| 24 |
+
- pip:
|
| 25 |
+
# --- core scientific / DL stack ---
|
| 26 |
+
- numpy==1.26.4
|
| 27 |
+
- scipy==1.17.1
|
| 28 |
+
- pandas==2.1.4
|
| 29 |
+
- scikit-learn==1.8.0
|
| 30 |
+
- pytorch-lightning==2.6.0
|
| 31 |
+
- lightning==2.6.1
|
| 32 |
+
- transformers==4.55.4
|
| 33 |
+
- tokenizers==0.21.4
|
| 34 |
+
- safetensors==0.7.0
|
| 35 |
+
- accelerate==0.33.0
|
| 36 |
+
- peft==0.15.1 # LoRA adapters (language experiment)
|
| 37 |
+
- datasets==2.19.2
|
| 38 |
+
- huggingface-hub==0.36.2
|
| 39 |
+
- einops==0.8.2
|
| 40 |
+
- timm==1.0.26
|
| 41 |
+
- omegaconf==2.3.0
|
| 42 |
+
- wandb==0.26.1
|
| 43 |
+
# --- molecule experiment ---
|
| 44 |
+
- safe-mol==0.1.14
|
| 45 |
+
- datamol==0.12.5
|
| 46 |
+
- PyTDC==1.1.15
|
| 47 |
+
# --- peptide experiment ---
|
| 48 |
+
- SmilesPE==0.0.3
|
| 49 |
+
- fair-esm==2.0.0
|
| 50 |
+
- xgboost==3.2.0
|
| 51 |
+
# --- plotting / utilities ---
|
| 52 |
+
- matplotlib==3.10.6
|
| 53 |
+
- seaborn==0.13.2
|
| 54 |
+
- tqdm==4.67.1
|
| 55 |
+
- joblib==1.5.3
|
| 56 |
+
- loguru==0.7.3
|
| 57 |
+
- fsspec==2024.3.1
|
lightning_modules/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .mdm import MaskedDiffusionModule
|
| 2 |
+
from .any_order import AnyOrderInsertionFlowModule
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"MaskedDiffusionModule",
|
| 7 |
+
"AutoregressiveModule",
|
| 8 |
+
"AnyOrderInsertionFlowModule",
|
| 9 |
+
]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def __getattr__(name):
    """Module-level attribute hook that imports ``AutoregressiveModule`` lazily.

    Only ``AutoregressiveModule`` is resolved here (by importing it from
    ``.autoregressive`` on first access); any other missing name raises
    ``AttributeError`` with the standard module-attribute message.
    """
    if name != "AutoregressiveModule":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from .autoregressive import AutoregressiveModule
    return AutoregressiveModule
|
lightning_modules/any_length_remask.py
ADDED
|
@@ -0,0 +1,801 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
from omegaconf import DictConfig
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from model.transformer import AnyOrderMaskInsertionFlow
|
| 8 |
+
from model.interpolant import AnyOrderMaskInsertionInterpolant, ModelPrediction
|
| 9 |
+
from .bregman import jump_kernel_elbo, mse
|
| 10 |
+
from .schedule import get_schedule_from_config
|
| 11 |
+
from lightning_modules.any_order import AnyOrderInsertionFlowModule
|
| 12 |
+
from model.model_wrapper import RemaskingAnyOrder
|
| 13 |
+
from sampling import _sample_tokens
|
| 14 |
+
|
| 15 |
+
import re
|
| 16 |
+
from typing import Dict, Any
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
|
| 19 |
+
def strip_orig_mod_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build a copy of ``state_dict`` whose keys have every '._orig_mod.' segment
    collapsed to a single '.', e.g.
    'model._orig_mod.vocab_embed.embedding' -> 'model.vocab_embed.embedding'.

    Values are carried over unchanged and the input mapping is not modified.
    """
    return {
        re.sub(r"\._orig_mod\.", ".", original_key): tensor
        for original_key, tensor in state_dict.items()
    }
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@torch.no_grad()
def _binary_auc(scores: torch.Tensor, labels: torch.Tensor) -> float:
    """Compute AUROC from ranks (Mann-Whitney U statistic).

    Both inputs are flattened and cast to float. Entries whose label equals 1
    are treated as positives; the positive count is the sum of the labels.
    The result estimates P(score[pos] > score[neg]), so 0.5 means no
    discrimination. Returns NaN when either class is absent. Tied scores get
    distinct consecutive ranks (no averaging).
    """
    flat_scores = scores.float().reshape(-1)
    flat_labels = labels.float().reshape(-1)

    n_pos = flat_labels.sum()
    n_neg = flat_labels.numel() - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # Assign 1-based ranks in ascending score order.
    total = flat_scores.numel()
    ranks = torch.empty_like(flat_scores)
    ranks[torch.argsort(flat_scores)] = torch.arange(
        1, total + 1, device=flat_scores.device, dtype=flat_scores.dtype
    )

    # U statistic of the positives, normalised by the number of pos/neg pairs.
    positive_rank_sum = ranks[flat_labels == 1].sum()
    u_statistic = positive_rank_sum - n_pos * (n_pos + 1) / 2
    return (u_statistic / (n_pos * n_neg)).item()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class AnyOrderInsertionFlowModuleFT(AnyOrderInsertionFlowModule):
|
| 57 |
+
"""
|
| 58 |
+
Wrapper around AnyOrderInsertionFlowModule that adds adaptive schedule model
|
| 59 |
+
for fine-tuning. Can load a pretrained AnyOrderInsertionFlowModule checkpoint
|
| 60 |
+
and add the schedule model on top.
|
| 61 |
+
"""
|
| 62 |
+
def __init__(self, config, args, pretrained_checkpoint, insertion_planner=False):
    """Build the base module, optionally load pretrained weights, then attach the planner.

    Args:
        config: passed unchanged to the parent constructor.
        args: stored on ``self.args`` (excluded from saved hyperparameters).
        pretrained_checkpoint: path handed to ``load_pretrained_model``; the
            load is skipped when this is ``None``.
        insertion_planner: stored on the instance and forwarded to
            ``RemaskingAnyOrder``.
    """
    # Parent class must be initialised before anything is attached to self.
    super().__init__(config)

    self.insertion_planner = insertion_planner
    self.args = args

    # Record hyperparameters for this class (overrides the parent's save).
    self.save_hyperparameters(ignore=['pretrained_checkpoint', 'args'])

    # Pretrained weights are loaded BEFORE the planner exists to avoid a
    # circular reference (the planner receives this module as its backbone).
    if pretrained_checkpoint is not None:
        self.load_pretrained_model(pretrained_checkpoint)

    # The planner is created only after the pretrained weights are in place.
    planner_kwargs = dict(
        backbone=self,
        d_model=self.config.model.hidden_size,
        insertion_planner=insertion_planner,
    )
    self.planner = RemaskingAnyOrder(**planner_kwargs)
|
| 81 |
+
|
| 82 |
+
def load_pretrained_model(self, checkpoint_path: str):
    """
    Load base-model weights from ``checkpoint_path`` into this module.

    The checkpoint may be a Lightning-style dict with a 'state_dict' entry or
    a bare state dict. '._orig_mod.' key segments are stripped and any
    'planner.*' entries are dropped before a non-strict load. A summary of
    missing/unexpected keys is printed. When
    ``config.training.freeze_base_model`` is truthy, every parameter whose
    name does not start with 'planner.' has ``requires_grad`` set to False.
    """
    print(f"Loading pretrained model from {checkpoint_path}")
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Accept both wrapped ({'state_dict': ...}) and bare state dicts.
    raw_state = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    # Normalise key names, then discard planner weights from earlier fine-tuning.
    cleaned_state = strip_orig_mod_keys(raw_state)
    base_state_dict = {
        key: value
        for key, value in cleaned_state.items()
        if not key.startswith('planner.')
    }

    # strict=False: keys absent from the checkpoint are tolerated and reported below.
    load_result = self.load_state_dict(base_state_dict, strict=False)

    # Split missing keys into expected (planner) and unexpected (everything else).
    unexpected_missing = []
    planner_missing = []
    for key in load_result.missing_keys:
        bucket = planner_missing if key.startswith('planner.') else unexpected_missing
        bucket.append(key)

    if unexpected_missing:
        print(f"Warning: Unexpected missing keys from pretrained checkpoint: {unexpected_missing}")
    if planner_missing:
        print(f"Note: Planner will be trained from scratch ({len(planner_missing)} parameters)")
    if load_result.unexpected_keys:
        print(f"Warning: Unexpected keys in pretrained checkpoint: {load_result.unexpected_keys}")

    # Optionally freeze everything except the planner.
    if not self.config.training.get('freeze_base_model', False):
        return

    print("Freezing base model parameters")
    for name, param in self.named_parameters():
        if name.startswith('planner.'):
            continue
        param.requires_grad = False
|
| 126 |
+
|
| 127 |
+
def forward(self, x, t, return_features=False):
    """Delegate to the parent class ``forward`` with the same arguments."""
    parent_forward = super().forward
    return parent_forward(x, t, return_features=return_features)
|
| 130 |
+
|
| 131 |
+
def training_loss(self, x1, t):
    """Return the parent's ``(unmask_loss, insertion_loss, total_loss)`` triple.

    Only the base-model loss is computed here; per the original author's
    note, the planner is trained separately via ``loss_planner_flexible``
    with reward gradients.
    """
    unmask_term, insertion_term, combined = super().training_loss(x1, t)
    return unmask_term, insertion_term, combined
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def training_step(self, batch, batch_idx):
    """Compute and log the base-model training loss for one batch.

    ``batch`` may be the input itself or a dict carrying it under
    "input_ids". The three loss components are logged under ``train/*`` and
    the total loss is returned. The planner is not trained here.
    """
    # Unwrap dict-style batches.
    x1 = batch["input_ids"] if isinstance(batch, dict) else batch
    t = self.sample_time(x1.shape[0], x1.device)

    # Base model loss only (planner trained separately, not here).
    unmask_loss, len_loss, loss = self.training_loss(x1, t)

    # Log component losses in a fixed order.
    metrics = (
        ("train/unmask_loss", unmask_loss),
        ("train/len_loss", len_loss),
        ("train/total_loss", loss),
    )
    for metric_name, metric_value in metrics:
        self.log(metric_name, metric_value, prog_bar=True)

    return loss
|
| 155 |
+
|
| 156 |
+
def validation_step(self, batch, batch_idx):
    """Compute and log the base-model loss for one validation batch.

    ``batch`` may be the input itself or a dict carrying it under
    "input_ids". Metrics are logged with ``sync_dist=True`` under
    "val/unmask_loss", "val/len_loss" and "val_loss"; the total loss is
    returned.
    """
    x1 = batch["input_ids"] if isinstance(batch, dict) else batch
    t = self.sample_time(x1.shape[0], x1.device)
    unmask_loss, len_loss, loss = self.training_loss(x1, t)

    metrics = (
        ("val/unmask_loss", unmask_loss),
        ("val/len_loss", len_loss),
        ("val_loss", loss),
    )
    for metric_name, metric_value in metrics:
        self.log(metric_name, metric_value, prog_bar=True, sync_dist=True)

    return loss
|
| 169 |
+
|
| 170 |
+
@classmethod
def load_from_checkpoint(cls, checkpoint_path, map_location=None, strict=True, **kwargs):
    """
    Rebuild a model from a finetuned checkpoint whose weights are stored
    under a 'policy_model.' prefix (the wrapper format the original comments
    attribute to PeptideFinetuner).

    The config is taken from the checkpoint itself when present, otherwise
    from the original pretrained checkpoint named in the saved ``args`` (or a
    default under ``<repo_root>/pretrained``). ``strict`` and ``**kwargs`` are
    accepted for signature compatibility but are not used by this
    implementation.

    Raises:
        NotImplementedError: no state-dict key starts with 'policy_model.'.
        ValueError: 'args' is missing from the saved hyperparameters, or the
            fallback checkpoint carries no 'config'.
    """
    print(f"Loading finetuned checkpoint from {checkpoint_path}")
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
    checkpoint = torch.load(checkpoint_path, map_location=map_location or 'cpu', weights_only=False)

    hparams = checkpoint.get('hyper_parameters', {})
    state_dict = checkpoint.get('state_dict', {})

    # A 'policy_model.' prefix marks a wrapped finetuned checkpoint; anything
    # else is rejected up front.
    policy_prefix = 'policy_model.'
    if not any(key.startswith(policy_prefix) for key in state_dict):
        raise NotImplementedError(
            "Direct finetuned checkpoints (not wrapped by PeptideFinetuner) are not yet supported. "
            "Please provide config and args as kwargs."
        )

    # Guess the model family from the first vocab-embedding tensor: a vocab
    # size above 1000 is treated as a molecule model, otherwise peptide.
    vocab_size = next(
        (tensor.shape[0] for key, tensor in state_dict.items() if 'vocab_embed.embedding' in key),
        None,
    )
    is_molecule_model = vocab_size is not None and vocab_size > 1000
    model_type = "MolFinetuner" if is_molecule_model else "PeptideFinetuner"
    print(f"Detected wrapped finetuned checkpoint ({model_type}, vocab_size={vocab_size})")

    # The finetuning args are stored with the hyperparameters.
    if 'args' not in hparams:
        raise ValueError(f"Cannot find 'args' in hyperparameters. This checkpoint may not be from {model_type}.")

    args = hparams['args']
    print(f"Found args in hyperparameters, type: {type(args)}")

    # args may be attribute-style (e.g. Namespace) or a plain dict.
    original_ckpt_path = None
    if hasattr(args, 'checkpoint_path'):
        original_ckpt_path = args.checkpoint_path
    elif isinstance(args, dict) and 'checkpoint_path' in args:
        original_ckpt_path = args['checkpoint_path']

    # Fall back to the bundled pretrained checkpoint for the detected family.
    if original_ckpt_path is None:
        _repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        if is_molecule_model:
            original_ckpt_path = os.path.join(_repo_root, 'pretrained', 'anylength_mol.ckpt')
            print(f"Warning: checkpoint_path not found in args, using default molecule pretrained checkpoint")
        else:
            original_ckpt_path = os.path.join(_repo_root, 'pretrained', 'anylength_pep.ckpt')
            print(f"Warning: checkpoint_path not found in args, using default peptide pretrained checkpoint")

    # Newer checkpoints embed the config; older ones need the original checkpoint.
    if 'config' in checkpoint:
        print("Found config directly in checkpoint")
        config = checkpoint['config']
    else:
        print(f"Config not in checkpoint, loading from original checkpoint: {original_ckpt_path}")
        orig_ckpt = torch.load(original_ckpt_path, map_location='cpu', weights_only=False)
        if 'config' not in orig_ckpt:
            raise ValueError(f"Original checkpoint {original_ckpt_path} does not contain config")
        config = orig_ckpt['config']

    # Force the adaptive-schedule flag on; struct mode is lifted temporarily
    # so the key can be added to the OmegaConf config.
    from omegaconf import OmegaConf
    if hasattr(config, 'training'):
        OmegaConf.set_struct(config, False)
        config.training.use_adaptive_schedule = True
        OmegaConf.set_struct(config, True)

    # Objects without __dict__ (plain dicts) are turned into attribute-style args.
    if not hasattr(args, '__dict__'):
        class Args:
            pass
        args_obj = Args()
        for field_name, field_value in args.items():
            setattr(args_obj, field_name, field_value)
        args = args_obj

    # Weights come from this checkpoint, so no pretrained checkpoint is reloaded.
    model = cls(
        config=config,
        args=args,
        pretrained_checkpoint=None,
        insertion_planner=getattr(args, 'insertion_planner', False),
    )

    # Keep only policy_model.* entries, with the prefix removed.
    policy_state = {
        key[len(policy_prefix):]: tensor
        for key, tensor in state_dict.items()
        if key.startswith(policy_prefix)
    }

    incompatible = model.load_state_dict(policy_state, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"Warning: Incompatible keys when loading finetuned weights:")
        if incompatible.missing_keys:
            print(f"  Missing: {incompatible.missing_keys[:5]}...")
        if incompatible.unexpected_keys:
            print(f"  Unexpected: {incompatible.unexpected_keys[:5]}...")

    # EMA parameters: restore from the checkpoint, or seed from current weights.
    if not model.use_ema:
        model.ema_params = {}
    elif "ema_params" in checkpoint:
        model.ema_params = checkpoint["ema_params"]
        print("Loaded EMA params from checkpoint")
    else:
        model.ema_params = {
            name: param.clone().detach()
            for name, param in model.named_parameters()
        }
        print("Initialized EMA params from current model state")

    # A separately stored planner state takes effect after the policy weights.
    if "planner_state" in checkpoint and hasattr(model, 'planner'):
        model.planner.load_state_dict(checkpoint["planner_state"], strict=False)
        print("Loaded planner state from checkpoint")

    return model
|
| 312 |
+
|
| 313 |
+
def on_save_checkpoint(self, checkpoint):
    """Run the parent's save hook, then store the planner's state dict.

    The planner weights are written under the "planner_state" key when the
    module has a ``planner`` attribute.
    """
    # Parent hook first (per the original comment it saves config and base-model EMA).
    super().on_save_checkpoint(checkpoint)

    if not hasattr(self, 'planner'):
        return
    checkpoint["planner_state"] = self.planner.state_dict()
|
| 321 |
+
|
| 322 |
+
def on_load_checkpoint(self, checkpoint):
    """Restore state from ``checkpoint``, including the planner.

    When the checkpoint carries a "config" entry the parent hook is run.
    Otherwise (config already set at construction time) only the EMA
    parameters are restored, and only if EMA is enabled and they are present.
    In both cases a stored "planner_state" is loaded into ``self.planner``
    when the module has one.
    """
    if "config" not in checkpoint:
        # Config was already set during __init__ via load_from_checkpoint;
        # just pick up EMA params if available.
        if self.use_ema and "ema_params" in checkpoint:
            self.ema_params = checkpoint["ema_params"]
    else:
        # Parent restores config and interpolant.
        super().on_load_checkpoint(checkpoint)

    # Restore planner weights when both sides have them.
    if "planner_state" in checkpoint and hasattr(self, 'planner'):
        self.planner.load_state_dict(checkpoint["planner_state"])
        print("Loaded planner from checkpoint")
|
| 339 |
+
|
| 340 |
+
def loss_wdce_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0):
|
| 341 |
+
r"""
|
| 342 |
+
Weighted denoising cross entropy loss
|
| 343 |
+
X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
|
| 344 |
+
|
| 345 |
+
log_rnd: [B] — pre-computed importance weights (already softmax-normalized over the full buffer)
|
| 346 |
+
x: [B, L] (no mask)
|
| 347 |
+
num_replicates: R, number of replicates of each row in x
|
| 348 |
+
weight_func: w(lambda) for each sample, 1/lambda by default
|
| 349 |
+
centering_strength: float, controls how much of the mean is subtracted (DMPO-style)
|
| 350 |
+
softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights)
|
| 351 |
+
"""
|
| 352 |
+
|
| 353 |
+
batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
|
| 354 |
+
|
| 355 |
+
batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) # [B]
|
| 356 |
+
if centering:
|
| 357 |
+
batch_weights = batch_weights - centering_strength * batch_weights.mean()
|
| 358 |
+
|
| 359 |
+
batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)
|
| 360 |
+
|
| 361 |
+
lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
|
| 362 |
+
lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
|
| 363 |
+
|
| 364 |
+
t = lamda
|
| 365 |
+
|
| 366 |
+
# compute unmasking and insertion loss
|
| 367 |
+
interpolant_sample = self.interpolant.sample_interpolant(t, batch)
|
| 368 |
+
unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch)
|
| 369 |
+
|
| 370 |
+
prediction: ModelPrediction = self(interpolant_sample.xt, t)
|
| 371 |
+
|
| 372 |
+
scale_factor = self.config.interpolant.max_length
|
| 373 |
+
|
| 374 |
+
match self.unmask_loss_fn:
|
| 375 |
+
case "elbo":
|
| 376 |
+
mask_indices = interpolant_sample.mask_indices
|
| 377 |
+
unmask_loss_all = torch.zeros_like(unmask_weight) # [B*R, L]
|
| 378 |
+
unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy(
|
| 379 |
+
prediction.token_logits[mask_indices],
|
| 380 |
+
interpolant_sample.unmasked[mask_indices],
|
| 381 |
+
reduction="none",
|
| 382 |
+
)
|
| 383 |
+
unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor # [B*R]
|
| 384 |
+
case _:
|
| 385 |
+
raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}")
|
| 386 |
+
|
| 387 |
+
match self.insert_loss_fn:
|
| 388 |
+
case "expectation":
|
| 389 |
+
gaps, gaps_mask = interpolant_sample.gaps_and_mask
|
| 390 |
+
insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1]
|
| 391 |
+
insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo(
|
| 392 |
+
gaps[gaps_mask], prediction.expected_gaps[gaps_mask]
|
| 393 |
+
)
|
| 394 |
+
insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R]
|
| 395 |
+
|
| 396 |
+
case "distribution":
|
| 397 |
+
gaps, gaps_mask = interpolant_sample.gaps_and_mask
|
| 398 |
+
insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1]
|
| 399 |
+
insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy(
|
| 400 |
+
prediction.length_posterior[gaps_mask], gaps[gaps_mask]
|
| 401 |
+
)
|
| 402 |
+
insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R]
|
| 403 |
+
|
| 404 |
+
total_loss = unmask_loss + insertion_loss # [B*R]
|
| 405 |
+
# end compute unmasking and insertion loss
|
| 406 |
+
|
| 407 |
+
weighted_loss = total_loss * batch_weights # [B*R]
|
| 408 |
+
return weighted_loss.mean()
|
| 409 |
+
|
| 410 |
+
def one_step_sampler(self, xt, t, pred_rate=None):
    """
    Sample one step of unmasking (and optionally mask insertion) using model predictions.

    Every masked position is forced to decode: the mask column of the
    transition matrix is zeroed before sampling, so a masked position cannot
    stay masked.

    Args:
        xt: Current state [B, L]
        t: Time [B]
        pred_rate: Optional pre-computed ModelPrediction. If None, will compute from model.

    Returns:
        When ``self.insertion_planner is False``, a 2-tuple:
            new_xt: Next state after unmasking [B, L]
            update_ids: Boolean mask of updated positions [B, L]
        Otherwise a 4-tuple ``(new_xt, update_ids, new_ins_xt, update_ins_ids)``:
            new_ins_xt: ``new_xt`` with freshly inserted mask tokens, original
                tokens shifted right past the insertions [B, L]
            update_ins_ids: Boolean mask over ``new_ins_xt`` marking inserted
                masks that land on positions which held a real token in ``xt`` [B, L]
    """
    mask = self.interpolant.mask_token
    pad = self.interpolant.pad_token
    batch_size, L = xt.shape
    device = xt.device
    dt = 1.0 / self.args.total_num_steps

    # Index grids use the actual tensor width L (not max_length) so
    # replicated batches are handled correctly.
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, L)
    )
    pos_idx_L = (
        torch.arange(L, device=device)
        .view(1, L)
        .expand(batch_size, L)
    )

    # ——— predict and convert rates ———
    if pred_rate is None:
        pred_rate = self(xt, t)
    pred_rate = self.interpolant.to_actual_rate(xt, pred_rate, t)
    unmask_rate = pred_rate.unmask_rate  # (B, L, V)
    len_rate = pred_rate.length_rate  # (B, L+1)

    # ——— unmask step (Euler) ———
    # NOTE: unmask_rate is modified in place below.
    mask_pos = (xt == mask).nonzero(as_tuple=True)
    unmask_rate[xt != mask] = 0
    unmask_rate[mask_pos + (mask,)] = 0
    # Diagonal entry of the rate matrix: minus the total outgoing rate.
    unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

    # add "stay" probability (pad positions are pointed at the mask column so
    # the scatter index stays valid; they are restored to pad after sampling)
    _xt = xt.clone()
    _xt[xt == pad] = mask
    trans_prob.scatter_add_(
        2,
        _xt.unsqueeze(-1),
        torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
    )

    trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

    # Renormalize the rows of masked positions so they sum to 1. Rows with
    # zero total mass fall back to a uniform distribution. Both steps are
    # applied per row: previously the presence of a single zero-mass row
    # caused the normalization of every other row to be skipped.
    masked_probs = trans_prob[mask_pos]  # (N, V)
    prob_sum = masked_probs.sum(dim=-1, keepdim=True)  # (N, 1)
    mask_has_zero_prob = prob_sum.squeeze(-1) == 0.0  # (N,)
    safe_sum = torch.where(prob_sum == 0.0, torch.ones_like(prob_sum), prob_sum)
    masked_probs = masked_probs / safe_sum
    if mask_has_zero_prob.any():
        # Uniform over token ids 0 .. mask-1.
        # NOTE(review): this presumably relies on mask/pad having the largest
        # ids so that neither is sampled — confirm against the tokenizer.
        uniform_row = torch.zeros(trans_prob.shape[-1], device=device, dtype=trans_prob.dtype)
        uniform_row[:mask] = 1.0 / mask
        masked_probs[mask_has_zero_prob] = uniform_row
    trans_prob[mask_pos] = masked_probs

    new_xt = _sample_tokens(trans_prob)
    new_xt[xt == pad] = pad
    # Positions that already held a real token are never changed.
    new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

    # Updated positions: the token changed and the position is not padding.
    # Since real tokens are copied through above, these are exactly the
    # masked positions that were decoded.
    update_ids = (xt != new_xt) & (xt != pad)

    if self.insertion_planner is False:
        return new_xt, update_ids

    # ——— Poisson insertion (tau-leaping) — can insert multiple masks per gap ———
    ext = torch.poisson(len_rate * dt).long()  # (B, L+1)
    xt_len = xt.ne(pad).sum(dim=1)  # (B,)
    # ext is (B, L+1), so the maximum sequence length is ext.shape[1] - 1
    actual_max_length = ext.shape[1] - 1
    gaps = torch.arange(ext.shape[1], device=device).view(1, -1)
    # Only gaps up to the current sequence end may receive insertions.
    ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
    # Reject all insertions of a row that would overflow the maximum length.
    valid = xt_len + ext.sum(dim=1) <= actual_max_length
    ext = ext * valid.view(batch_size, 1).long()

    ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
    # Recompute the insertion count AFTER rejection. Using the pre-rejection
    # total made rejected rows look longer than they are, which overwrote
    # their padding with mask tokens.
    new_len = xt_len + ext.sum(dim=1)  # (B,)

    # Start from all-pad, mark the first new_len positions as mask, then drop
    # the original tokens back in at their shifted positions.
    xt_tmp = torch.full_like(xt, pad)
    pos_idx_for_fill = torch.arange(xt_tmp.shape[1], device=device).view(1, -1).expand(batch_size, -1)
    mask_fill = pos_idx_for_fill < new_len.view(batch_size, 1)
    xt_tmp[mask_fill] = mask

    new_pos_orig = pos_idx_L + ext_ex[:, :actual_max_length]  # (B, L)
    orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
    flat_b = batch_idx_L[orig_mask]
    flat_p = new_pos_orig[orig_mask]
    xt_tmp[flat_b, flat_p] = new_xt[orig_mask]

    new_ins_xt = xt_tmp

    # Newly inserted masks: mask now, real token (neither mask nor pad) in xt.
    # NOTE(review): insertions that land on former mask/pad positions (e.g.
    # appended at the sequence end) are excluded by this filter — confirm
    # that this is intended.
    update_ins_ids = (new_ins_xt == mask) & (xt != mask) & (xt != pad)

    return new_xt, update_ids, new_ins_xt, update_ins_ids
|
| 534 |
+
|
| 535 |
+
def loss_planner_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0):
    r"""
    Importance-weighted training loss for the remasking head of the planner.

    X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)

    Each row of ``x`` is replicated, noised at a random time, decoded one step
    with ``one_step_sampler`` and the planner's ``remasking_conf`` logits are
    trained (binary cross entropy) to predict whether each freshly decoded
    token equals the ground-truth token.

    log_rnd: [B] — importance log-weights. A temperature-scaled softmax over
        this batch is applied here (the tensor is detached first).
        NOTE(review): an earlier docstring said these were "already
        softmax-normalized over the full buffer" — confirm what callers pass.
    x: [B, L] (no mask)
    num_replicates: R, number of replicates of each row in x
    weight_func: accepted for backward compatibility; currently unused.
    eps: accepted for backward compatibility; currently unused.
    centering: if True, subtract ``centering_strength * mean`` from the weights.
    centering_strength: float, controls how much of the mean is subtracted (DMPO-style)
    softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights)

    Returns:
        Scalar tensor: mean over the B*R replicates of weight * per-sample BCE.

    Side effects:
        Stores diagnostics (AUC, label mean, mean confidence, count) in
        ``self._last_planner_metrics``; the dict is empty when no position
        was updated.
    """
    batch = x.repeat_interleave(num_replicates, dim=0)  # [B*R, L]

    batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1)  # [B]
    if centering:
        batch_weights = batch_weights - centering_strength * batch_weights.mean()
    batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)  # [B*R]

    # One independent noise level per replicate.
    t = torch.rand(batch.shape[0], device=batch.device)  # [B*R]

    interpolant_sample = self.interpolant.sample_interpolant(t, batch)

    # The base-model prediction only feeds the sampler, which is not
    # differentiated, so the forward pass runs without building a graph.
    with torch.no_grad():
        prediction: ModelPrediction = self(interpolant_sample.xt, t)
        sampler_out = self.one_step_sampler(interpolant_sample.xt, t, prediction)
        # one_step_sampler returns (xs, update_ids) or (xs, update_ids, new_ins_xt, update_ins_ids)
        xs, update_ids = sampler_out[0], sampler_out[1]

    # The remasking head scores the freshly-decoded tokens to decide which to
    # remask, so it reads the POST-unmask state xs (matching inference, which
    # calls the planner on the decoded new_xt). Gradients flow through this call.
    planner = self.planner(xs, t)
    remasking_conf = planner["remasking_conf"]  # [B*R, L, 1]

    # interpolant_sample.xt has been reordered via the st permutation, so the
    # ground truth is permuted the same way before comparing.
    st = interpolant_sample.st  # [B*R, L] permutation indices
    batch_reordered = torch.gather(batch, 1, st)

    # Label 1 where the decoded token matches the ground truth.
    binary_label = (xs == batch_reordered).float()

    per_token_loss = F.binary_cross_entropy_with_logits(
        remasking_conf.squeeze(-1),  # [B*R, L]
        binary_label,  # [B*R, L]
        reduction="none",  # [B*R, L]
    )

    # Only positions decoded in this step contribute to the loss.
    per_token_loss = per_token_loss * update_ids.float()  # [B*R, L]

    # Average over updated positions; the epsilon guards rows with none.
    per_sample_loss = per_token_loss.sum(dim=1) / (update_ids.sum(dim=1).float() + 1e-8)  # [B*R]

    # Weight by importance sampling weights
    weighted_loss = per_sample_loss * batch_weights  # [B*R]

    # ——— AUC / label-balance diagnostics ———
    with torch.no_grad():
        metrics = {}
        sel_u = update_ids.bool()
        if sel_u.any():
            u_scores = remasking_conf.squeeze(-1)[sel_u]
            u_labels = binary_label[sel_u]
            metrics["unmask_auc"] = _binary_auc(u_scores, u_labels)
            metrics["unmask_label_mean"] = u_labels.mean().item()
            metrics["unmask_conf_mean"] = torch.sigmoid(u_scores).mean().item()
            metrics["unmask_n"] = float(sel_u.sum().item())
        self._last_planner_metrics = metrics

    return weighted_loss.mean()
|
| 617 |
+
|
| 618 |
+
def loss_insert_planner_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0):
    r"""
    Importance-weighted training loss for the remasking and insertion heads of the planner.

    X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)

    The remasking head is trained as in ``loss_planner_flexible``. When
    ``self.planner.insertion_planner`` is truthy, the insertion head is
    additionally trained (BCE with soft targets) to predict, at each freshly
    inserted mask, the probability mass the base model assigns to the
    ground-truth tokens associated with that gap.

    log_rnd: [B] — importance log-weights; a temperature-scaled softmax over
        this batch is applied here (the tensor is detached first).
    x: [B, L] (no mask)
    num_replicates: R, number of replicates of each row in x
    weight_func: accepted for backward compatibility; currently unused.
    eps: accepted for backward compatibility; currently unused.
    centering: if True, subtract ``centering_strength * mean`` from the weights.
    centering_strength: float, controls how much of the mean is subtracted (DMPO-style)
    softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights)

    Returns:
        Tuple of three scalar tensors:
        ``(mean unmask-head loss, mean insertion-head loss, mean weighted total loss)``.
        The first two are unweighted means over the B*R replicates.

    Raises:
        ValueError: if ``one_step_sampler`` does not return the insertion
            outputs (it returns only two values when insertion is disabled).

    Side effects:
        Stores diagnostics in ``self._last_planner_metrics``.
    """
    batch = x.repeat_interleave(num_replicates, dim=0)  # [B*R, L]
    batch_size = batch.shape[0]

    batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1)  # [B]
    if centering:
        batch_weights = batch_weights - centering_strength * batch_weights.mean()
    batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)  # [B*R]

    # One independent noise level per replicate.
    t = torch.rand(batch.shape[0], device=batch.device)  # [B*R]

    # gap_assignment: [B*R, max_gaps, L] maps x1 positions to gap indices.
    # The second return value (mask of deleted ground-truth tokens) is unused here.
    interpolant_sample, _deleted_mask, gap_assignment = self.interpolant.sample_interpolant_plan(t, batch)

    # The base-model prediction only feeds the sampler, which is not
    # differentiated, so the forward pass runs without building a graph.
    with torch.no_grad():
        prediction: ModelPrediction = self(interpolant_sample.xt, t)
        sampler_out = self.one_step_sampler(interpolant_sample.xt, t, prediction)
    if len(sampler_out) != 4:
        raise ValueError(
            "loss_insert_planner_flexible needs one_step_sampler to return "
            "(new_xt, update_ids, new_ins_xt, update_ins_ids); got "
            f"{len(sampler_out)} values"
        )
    xs_unmask, update_unmask_ids, xs_insert, update_ins_ids = sampler_out

    # The remasking head scores the freshly-decoded tokens to decide which to
    # remask, so it must see the POST-unmask state xs_unmask (matching
    # inference, which calls the planner on the decoded new_xt). Grad stays on
    # here since this head is what we train.
    planner = self.planner(xs_unmask, t)
    remasking_conf = planner["remasking_conf"]  # [B*R, L, 1]

    # The insertion-quality head scores the freshly-inserted mask tokens, so
    # it must see the POST-insertion state xs_insert (aligned with
    # update_ins_ids / insertion_quality below). Grad stays on here since this
    # head is what we are training.
    if self.planner.insertion_planner:
        insertion_conf = self.planner(xs_insert, t)["insertion_conf"]  # [B*R, L, 1]
    else:
        insertion_conf = None

    # interpolant_sample.xt has been reordered via the st permutation, so the
    # ground truth is permuted the same way before comparing.
    st = interpolant_sample.st  # [B*R, L] permutation indices
    batch_reordered = torch.gather(batch, 1, st)

    # Label 1 where the decoded token matches the ground truth.
    binary_label = (xs_unmask == batch_reordered).float()

    per_token_loss = F.binary_cross_entropy_with_logits(
        remasking_conf.squeeze(-1),  # [B*R, L]
        binary_label,  # [B*R, L]
        reduction="none",  # [B*R, L]
    )

    # Only positions decoded in this step contribute to the loss.
    per_token_loss = per_token_loss * update_unmask_ids.float()  # [B*R, L]

    # Average over updated positions; the epsilon guards rows with none.
    unmask_per_sample_loss = per_token_loss.sum(dim=1) / (update_unmask_ids.sum(dim=1).float() + 1e-8)  # [B*R]

    # ——— insertion planner target ———
    # The target at an inserted mask is the probability mass the base model
    # puts on the ground-truth tokens of the corresponding gap. Predictions
    # are recomputed on xs_insert because the earlier prediction was made on
    # xt, before insertion.
    with torch.no_grad():
        prediction_after_insert: ModelPrediction = self(xs_insert, t)

        token_probs = F.softmax(prediction_after_insert.token_logits, dim=-1)  # [B*R, L, V]

        vocab_size = token_probs.shape[-1]
        L = token_probs.shape[1]
        max_gaps = gap_assignment.shape[1]

        # gap_vocab_mask[b, gap_idx, token_id] = 1 if token_id belongs to gap gap_idx
        gap_vocab_mask = torch.zeros(batch_size, max_gaps, vocab_size, device=batch.device, dtype=torch.float)

        # tokens_expanded[b, gap_idx, pos] = batch[b, pos] for all positions
        tokens_expanded = batch.unsqueeze(1).expand(batch_size, max_gaps, L)  # [B*R, max_gaps, L]

        # valid_mask[b, gap_idx, pos] = True if pos belongs to gap gap_idx and is not pad
        valid_mask = (gap_assignment > 0) & (tokens_expanded != self.interpolant.pad_token)  # [B*R, max_gaps, L]

        # Scatter into the vocabulary dimension: count which tokens appear in each gap
        gap_vocab_mask.scatter_add_(
            2,  # scatter along vocabulary dimension
            tokens_expanded.clamp(0, vocab_size - 1),  # token indices [B*R, max_gaps, L]
            valid_mask.float(),  # values to add [B*R, max_gaps, L]
        )

        # Binarize: a token either appears in the gap or not
        gap_vocab_mask = (gap_vocab_mask > 0).float()  # [B*R, max_gaps, V]

        # Position p in xs_insert is scored against gap p.
        # NOTE(review): this pairing assumes max_gaps >= L and that position
        # index and gap index coincide after insertion — confirm.
        insertion_quality_full = (token_probs * gap_vocab_mask[:, :L, :]).sum(dim=-1)  # [B*R, L]

        # Only keep quality at positions where masks were actually inserted
        insertion_quality = insertion_quality_full * update_ins_ids.float()  # [B*R, L]

    if insertion_conf is not None:
        # BCE with insertion_quality as a continuous target in [0, 1].
        ins_per_token_loss = F.binary_cross_entropy_with_logits(
            insertion_conf.squeeze(-1),  # [B*R, L] - planner's insertion confidence logits
            insertion_quality,  # [B*R, L] - soft target
            reduction="none",
        )

        # Only compute loss where masks were actually inserted
        ins_per_token_loss = ins_per_token_loss * update_ins_ids.float()

        # Average per sample
        ins_per_sample_loss = ins_per_token_loss.sum(dim=1) / (update_ins_ids.sum(dim=1).float() + 1e-8)
    else:
        # No insertion planner - set loss to zero
        ins_per_sample_loss = torch.zeros_like(unmask_per_sample_loss)

    per_sample_loss = unmask_per_sample_loss + ins_per_sample_loss

    # Weight by importance sampling weights
    weighted_loss = per_sample_loss * batch_weights  # [B*R]

    # ——— AUC / label-balance diagnostics (the loss alone hides degenerate
    # targets; near-0 BCE can mean "all labels one class", not "learned") ———
    with torch.no_grad():
        metrics = {}
        sel_u = update_unmask_ids.bool()
        if sel_u.any():
            u_scores = remasking_conf.squeeze(-1)[sel_u]
            u_labels = binary_label[sel_u]
            metrics["unmask_auc"] = _binary_auc(u_scores, u_labels)
            metrics["unmask_label_mean"] = u_labels.mean().item()
            metrics["unmask_conf_mean"] = torch.sigmoid(u_scores).mean().item()
            metrics["unmask_n"] = float(sel_u.sum().item())
        if insertion_conf is not None:
            sel_i = update_ins_ids.bool()
            if sel_i.any():
                i_scores = insertion_conf.squeeze(-1)[sel_i]
                i_targets = insertion_quality[sel_i]
                # Soft targets are thresholded at 0.5 for the AUC only.
                i_labels = (i_targets > 0.5).float()
                metrics["insert_auc"] = _binary_auc(i_scores, i_labels)
                metrics["insert_target_mean"] = i_targets.mean().item()
                metrics["insert_conf_mean"] = torch.sigmoid(i_scores).mean().item()
                metrics["insert_n"] = float(sel_i.sum().item())
        self._last_planner_metrics = metrics

    return unmask_per_sample_loss.mean(), ins_per_sample_loss.mean(), weighted_loss.mean()
|
lightning_modules/any_order.py
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pytorch_lightning as pl
|
| 3 |
+
from omegaconf import DictConfig
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from model.transformer import AnyOrderMaskInsertionFlow
|
| 6 |
+
from model.interpolant import AnyOrderMaskInsertionInterpolant, ModelPrediction
|
| 7 |
+
from .bregman import jump_kernel_elbo, mse
|
| 8 |
+
from .schedule import get_schedule_from_config
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import re
|
| 12 |
+
from typing import Dict, Any
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def strip_orig_mod_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Returns a new state_dict with every '_orig_mod' path segment removed
    from the keys, e.g.
        'model._orig_mod.vocab_embed.embedding'
    becomes
        'model.vocab_embed.embedding'
    and
        '_orig_mod.vocab_embed.embedding'
    becomes
        'vocab_embed.embedding'

    Only whole dotted segments are removed: a name that merely ends in
    '_orig_mod' (such as 'my_orig_mod.x') is left untouched. Values are
    carried over unchanged; the input dict is not modified.
    """
    # Match '_orig_mod.' only at the start of the key or right after a dot.
    # The lookbehind does not consume the dot, so a leading segment and
    # consecutive segments ('a._orig_mod._orig_mod.b') are stripped too.
    segment = re.compile(r"(?:^|(?<=\.))_orig_mod\.")
    return {segment.sub("", key): value for key, value in state_dict.items()}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class AnyOrderInsertionFlowModule(pl.LightningModule):
|
| 32 |
+
def __init__(self, config: DictConfig):
    """Build the backbone model and the interpolant from ``config``.

    Args:
        config: Experiment configuration. The ``interpolant`` and
            ``training`` sections are read here.
    """
    super().__init__()
    self.config = config

    interp_cfg = config.interpolant
    train_cfg = config.training

    self.model_type = interp_cfg.type
    self.learning_rate = train_cfg.learning_rate
    self.unmask_loss_fn = train_cfg.loss_fn.unmask
    self.insert_loss_fn = train_cfg.loss_fn.insert

    # Backbone network.
    self.model = AnyOrderMaskInsertionFlow(config)
    # self.model = torch.compile(self.model)  # Disabled: incompatible with flex_attention nested functions

    # Interpolant, driven by the two configured schedules.
    self.interpolant = AnyOrderMaskInsertionInterpolant(
        insertion_schedule=get_schedule_from_config(interp_cfg.insert_schedule),
        unmask_schedule=get_schedule_from_config(interp_cfg.unmask_schedule),
        vocab_size=interp_cfg.tokens,
        mask_token=interp_cfg.mask_token,
        pad_token=interp_cfg.pad_token,
        max_length=interp_cfg.max_length,
    )

    # Record the constructor arguments as hyperparameters.
    self.save_hyperparameters()

    # A missing/zero decay disables EMA.
    configured_decay = train_cfg.ema_decay
    self.ema_decay = configured_decay if configured_decay else 0.0
    self.use_ema = self.ema_decay > 0
    self._orig_params = {}
|
| 63 |
+
|
| 64 |
+
def forward(self, x, t, return_features: bool = False):
    """Run the backbone on ``x`` at time ``t``.

    When ``config.training.only_embed_insert`` is set, the backbone receives
    the insertion schedule evaluated at ``t`` instead of ``t`` itself.

    Args:
        x: Input passed to the backbone as-is.
        t: Time value(s).
        return_features: Forwarded to the backbone unchanged.

    Returns:
        Whatever the backbone returns.
    """
    if self.config.training.only_embed_insert:
        time_input = self.interpolant.insertion_schedule.at(t)
    else:
        time_input = t
    return self.model(x, time_input, return_features=return_features)
|
| 70 |
+
|
| 71 |
+
def get_hidden_states(self, indices: torch.Tensor, t: torch.Tensor):
    """Forward the call to the backbone's ``get_hidden_states`` (kept for RemaskingAnyOrder compatibility)."""
    backbone = self.model
    return backbone.get_hidden_states(indices, t)
|
| 74 |
+
|
| 75 |
+
def training_loss(self, x1, t):
|
| 76 |
+
interpolant_sample = self.interpolant.sample_interpolant(t, x1)
|
| 77 |
+
unmask_weight, insert_weight = self.interpolant.elbo_weight(t, x1)
|
| 78 |
+
|
| 79 |
+
prediction: ModelPrediction = self(interpolant_sample.xt, t)
|
| 80 |
+
|
| 81 |
+
scale_factor = x1.shape[0] * self.config.interpolant.max_length
|
| 82 |
+
|
| 83 |
+
match self.unmask_loss_fn:
|
| 84 |
+
case "elbo":
|
| 85 |
+
mask_indices = interpolant_sample.mask_indices
|
| 86 |
+
unmask_loss = unmask_weight[mask_indices] * F.cross_entropy(
|
| 87 |
+
prediction.token_logits[mask_indices],
|
| 88 |
+
interpolant_sample.unmasked[mask_indices],
|
| 89 |
+
reduction="none",
|
| 90 |
+
)
|
| 91 |
+
unmask_loss = unmask_loss.sum() / scale_factor
|
| 92 |
+
case _:
|
| 93 |
+
raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}")
|
| 94 |
+
|
| 95 |
+
match self.insert_loss_fn:
|
| 96 |
+
case "expectation":
|
| 97 |
+
gaps, gaps_mask = interpolant_sample.gaps_and_mask
|
| 98 |
+
insertion_loss = insert_weight[gaps_mask] * jump_kernel_elbo(
|
| 99 |
+
gaps[gaps_mask], prediction.expected_gaps[gaps_mask]
|
| 100 |
+
)
|
| 101 |
+
insertion_loss = insertion_loss.sum() / scale_factor
|
| 102 |
+
|
| 103 |
+
case "distribution":
|
| 104 |
+
gaps, gaps_mask = interpolant_sample.gaps_and_mask
|
| 105 |
+
insertion_loss = insert_weight[gaps_mask] * F.cross_entropy(
|
| 106 |
+
prediction.length_posterior[gaps_mask], gaps[gaps_mask]
|
| 107 |
+
)
|
| 108 |
+
insertion_loss = insertion_loss.sum() / scale_factor
|
| 109 |
+
|
| 110 |
+
total_loss = unmask_loss + insertion_loss
|
| 111 |
+
return unmask_loss, insertion_loss, total_loss
|
| 112 |
+
|
| 113 |
+
def prepare_noised_sample(self, x, num_samples=1, t=None):
    """
    Noise clean sequences once so several models can score the same data.

    Every row of ``x`` is repeated ``num_samples`` times (consecutively),
    the interpolant is sampled at the given or freshly drawn times, and
    everything compute_loss_from_noised needs is bundled into a dict.

    Args:
        x: [B, L] clean token sequences (no mask tokens)
        num_samples: K, number of noisy time samples per sequence
        t: [B*K] optional time values. If None, drawn with ``torch.rand``
           on ``x.device``.

    Returns:
        dict with keys ``interpolant_sample``, ``unmask_weight``,
        ``insert_weight``, ``t``, ``scale_factor`` (the configured
        ``interpolant.max_length``), ``num_samples`` and ``batch_size``.
    """
    batch_size = x.shape[0]
    replicated = x.repeat_interleave(num_samples, dim=0)  # [B*K, L]
    times = (
        torch.rand(batch_size * num_samples, device=x.device) if t is None else t
    )

    noised_sample = self.interpolant.sample_interpolant(times, replicated)
    unmask_w, insert_w = self.interpolant.elbo_weight(times, replicated)
    max_len = self.config.interpolant.max_length

    return {
        "interpolant_sample": noised_sample,
        "unmask_weight": unmask_w,
        "insert_weight": insert_w,
        "t": times,
        "scale_factor": max_len,
        "num_samples": num_samples,
        "batch_size": batch_size,
    }
|
| 145 |
+
|
| 146 |
+
def compute_loss_from_noised(self, noised):
|
| 147 |
+
"""
|
| 148 |
+
Compute per-sample denoising loss given pre-noised data.
|
| 149 |
+
Each model runs its own forward pass on the shared noised xt.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
noised: dict from prepare_noised_sample()
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
total_loss: [B] per-sample loss averaged over K noisy samples
|
| 156 |
+
"""
|
| 157 |
+
interpolant_sample = noised["interpolant_sample"]
|
| 158 |
+
unmask_weight = noised["unmask_weight"]
|
| 159 |
+
insert_weight = noised["insert_weight"]
|
| 160 |
+
t = noised["t"]
|
| 161 |
+
scale_factor = noised["scale_factor"]
|
| 162 |
+
num_samples = noised["num_samples"]
|
| 163 |
+
B = noised["batch_size"]
|
| 164 |
+
|
| 165 |
+
prediction: ModelPrediction = self(interpolant_sample.xt, t)
|
| 166 |
+
|
| 167 |
+
match self.unmask_loss_fn:
|
| 168 |
+
case "elbo":
|
| 169 |
+
mask_indices = interpolant_sample.mask_indices
|
| 170 |
+
unmask_loss_all = torch.zeros_like(unmask_weight) # [B*K, L]
|
| 171 |
+
unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy(
|
| 172 |
+
prediction.token_logits[mask_indices],
|
| 173 |
+
interpolant_sample.unmasked[mask_indices],
|
| 174 |
+
reduction="none",
|
| 175 |
+
)
|
| 176 |
+
unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor # [B*K]
|
| 177 |
+
case _:
|
| 178 |
+
raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}")
|
| 179 |
+
|
| 180 |
+
match self.insert_loss_fn:
|
| 181 |
+
case "expectation":
|
| 182 |
+
gaps, gaps_mask = interpolant_sample.gaps_and_mask
|
| 183 |
+
insertion_loss_all = torch.zeros_like(insert_weight) # [B*K, L+1]
|
| 184 |
+
insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo(
|
| 185 |
+
gaps[gaps_mask], prediction.expected_gaps[gaps_mask]
|
| 186 |
+
)
|
| 187 |
+
insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*K]
|
| 188 |
+
case "distribution":
|
| 189 |
+
gaps, gaps_mask = interpolant_sample.gaps_and_mask
|
| 190 |
+
insertion_loss_all = torch.zeros_like(insert_weight) # [B*K, L+1]
|
| 191 |
+
insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy(
|
| 192 |
+
prediction.length_posterior[gaps_mask], gaps[gaps_mask]
|
| 193 |
+
)
|
| 194 |
+
insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*K]
|
| 195 |
+
|
| 196 |
+
per_replicate_loss = unmask_loss + insertion_loss # [B*K]
|
| 197 |
+
per_sample_loss = per_replicate_loss.view(B, num_samples).mean(dim=1) # [B]
|
| 198 |
+
return per_sample_loss
|
| 199 |
+
|
| 200 |
+
def loss_wdce_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False):
|
| 201 |
+
r"""
|
| 202 |
+
Weighted denoising cross entropy loss
|
| 203 |
+
X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
|
| 204 |
+
|
| 205 |
+
log_rnd: [B]; x: [B, L] (no mask)
|
| 206 |
+
num_replicates: R, number of replicates of each row in x
|
| 207 |
+
weight_func: w(lambda) for each sample, 1/lambda by default
|
| 208 |
+
"""
|
| 209 |
+
|
| 210 |
+
print("logrnd shape:", log_rnd.shape)
|
| 211 |
+
print("x shape:", x.shape)
|
| 212 |
+
|
| 213 |
+
batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
|
| 214 |
+
|
| 215 |
+
batch_weights = log_rnd.detach().softmax(dim=-1) # [B*R]
|
| 216 |
+
if centering:
|
| 217 |
+
batch_weights = batch_weights - batch_weights.mean(dim=-1, keepdim=True)
|
| 218 |
+
|
| 219 |
+
batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)
|
| 220 |
+
|
| 221 |
+
lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
|
| 222 |
+
lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
|
| 223 |
+
|
| 224 |
+
t = lamda
|
| 225 |
+
|
| 226 |
+
# compute unmasking and insertion loss
|
| 227 |
+
interpolant_sample = self.interpolant.sample_interpolant(t, batch)
|
| 228 |
+
unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch)
|
| 229 |
+
|
| 230 |
+
prediction: ModelPrediction = self(interpolant_sample.xt, t)
|
| 231 |
+
|
| 232 |
+
scale_factor = self.config.interpolant.max_length
|
| 233 |
+
|
| 234 |
+
match self.unmask_loss_fn:
|
| 235 |
+
case "elbo":
|
| 236 |
+
mask_indices = interpolant_sample.mask_indices
|
| 237 |
+
unmask_loss_all = torch.zeros_like(unmask_weight) # [B*R, L]
|
| 238 |
+
unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy(
|
| 239 |
+
prediction.token_logits[mask_indices],
|
| 240 |
+
interpolant_sample.unmasked[mask_indices],
|
| 241 |
+
reduction="none",
|
| 242 |
+
)
|
| 243 |
+
unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor # [B*R]
|
| 244 |
+
case _:
|
| 245 |
+
raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}")
|
| 246 |
+
|
| 247 |
+
match self.insert_loss_fn:
|
| 248 |
+
case "expectation":
|
| 249 |
+
gaps, gaps_mask = interpolant_sample.gaps_and_mask
|
| 250 |
+
insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1]
|
| 251 |
+
insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo(
|
| 252 |
+
gaps[gaps_mask], prediction.expected_gaps[gaps_mask]
|
| 253 |
+
)
|
| 254 |
+
insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R]
|
| 255 |
+
|
| 256 |
+
case "distribution":
|
| 257 |
+
gaps, gaps_mask = interpolant_sample.gaps_and_mask
|
| 258 |
+
insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1]
|
| 259 |
+
insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy(
|
| 260 |
+
prediction.length_posterior[gaps_mask], gaps[gaps_mask]
|
| 261 |
+
)
|
| 262 |
+
insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R]
|
| 263 |
+
|
| 264 |
+
total_loss = unmask_loss + insertion_loss # [B*R]
|
| 265 |
+
# end compute unmasking and insertion loss
|
| 266 |
+
|
| 267 |
+
weighted_loss = total_loss * batch_weights # [B*R]
|
| 268 |
+
return weighted_loss.mean()
|
| 269 |
+
|
| 270 |
+
def sample_time(self, batch_size: int, device: torch.device) -> torch.Tensor:
    """Draw one stratified random time per batch element.

    The range ``[0, 1 - 1e-6)`` is split into ``batch_size`` equal strata;
    element ``i`` receives a uniform draw from stratum ``i``, so the
    returned times are increasing with the batch index.

    Args:
        batch_size: number of times to draw.
        device: device on which the random draw and result are created.

    Returns:
        1-D tensor of length ``batch_size``.
    """
    stratum_width = (1.0 - 1e-6) / batch_size
    jitter = torch.rand(batch_size, device=device)
    stratum_index = torch.arange(batch_size, device=device, dtype=jitter.dtype)
    return (stratum_index + jitter) * stratum_width
|
| 276 |
+
|
| 277 |
+
def training_step(self, batch, batch_idx):
    """Run one training iteration and return the total loss.

    Accepts either the token tensor itself or a dict holding it under
    ``"input_ids"``. Times come from ``self.sample_time`` and the losses
    from ``self.training_loss``; the three loss values are logged under
    ``train/unmask_loss``, ``train/len_loss`` and ``train/total_loss``.
    ``batch_idx`` is not used.
    """
    # Unwrap dict-style batches.
    x1 = batch["input_ids"] if isinstance(batch, dict) else batch
    t = self.sample_time(x1.shape[0], x1.device)

    unmask_loss, len_loss, loss = self.training_loss(x1, t)

    # Log each component, in the same order as before.
    metrics = (
        ("train/unmask_loss", unmask_loss),
        ("train/len_loss", len_loss),
        ("train/total_loss", loss),
    )
    for metric_name, metric_value in metrics:
        self.log(metric_name, metric_value, prog_bar=True)

    return loss
|
| 295 |
+
|
| 296 |
+
def validation_step(self, batch, batch_idx):
    """Evaluate one validation batch and return the total loss.

    Accepts either the token tensor itself or a dict holding it under
    ``"input_ids"``. Uses the same time sampling and loss as training and
    logs ``val/unmask_loss``, ``val/len_loss`` and ``val_loss`` with
    ``sync_dist=True``. ``batch_idx`` is not used.
    """
    x1 = batch["input_ids"] if isinstance(batch, dict) else batch
    t = self.sample_time(x1.shape[0], x1.device)
    unmask_loss, len_loss, loss = self.training_loss(x1, t)

    metrics = (
        ("val/unmask_loss", unmask_loss),
        ("val/len_loss", len_loss),
        ("val_loss", loss),
    )
    for metric_name, metric_value in metrics:
        self.log(metric_name, metric_value, prog_bar=True, sync_dist=True)

    return loss
|
| 309 |
+
|
| 310 |
+
def configure_optimizers(self):
    """Build the AdamW optimizer and its warmup-then-cosine LR schedule.

    The learning rate ramps linearly from ``1e-6 * lr`` to ``lr`` over
    ``config.training.warmup_steps`` scheduler steps, then follows a cosine
    annealing curve towards 0 over the remaining
    ``max_steps - warmup_steps`` steps.

    Returns:
        ``([optimizer], [{"scheduler": scheduler, "interval": "step"}])``.
    """
    train_cfg = self.config.training
    optimizer = torch.optim.AdamW(
        self.parameters(),
        lr=self.learning_rate,
        weight_decay=train_cfg.weight_decay,
    )

    warmup_steps = train_cfg.warmup_steps
    max_steps = train_cfg.max_steps
    lr_sched = torch.optim.lr_scheduler

    # Always create a fresh schedule starting from step 0
    # This allows extending training beyond original max_steps
    warmup = lr_sched.LinearLR(
        optimizer,
        start_factor=1e-6,
        end_factor=1.0,
        total_iters=warmup_steps,
        last_epoch=-1,
    )
    decay = lr_sched.CosineAnnealingLR(
        optimizer,
        T_max=max_steps - warmup_steps,
        eta_min=0.0,
        last_epoch=-1,
    )
    # Switch from warmup to cosine decay once warmup_steps is reached.
    schedule = lr_sched.SequentialLR(
        optimizer,
        schedulers=[warmup, decay],
        milestones=[warmup_steps],
        last_epoch=-1,
    )

    return [optimizer], [{"scheduler": schedule, "interval": "step"}]
|
| 344 |
+
|
| 345 |
+
def optimizer_step(
    self,
    epoch: int,
    batch_idx: int,
    optimizer,
    optimizer_closure=None,
):
    """Run the parent optimizer step, then log diagnostics and update EMA.

    After delegating to the base-class ``optimizer_step`` this logs the
    first parameter group's learning rate (``train/lr``) and the global
    L2 norm of all available gradients (``train/grad_norm``). When
    ``self.use_ema`` is set, each entry of ``self.ema_params`` is updated
    in place as ``ema = ema_decay * ema + (1 - ema_decay) * param``.
    """
    super().optimizer_step(
        epoch, batch_idx, optimizer, optimizer_closure=optimizer_closure
    )

    # log learning rate and gradient norm
    current_lr = optimizer.param_groups[0]["lr"]
    self.log("train/lr", current_lr, on_step=True, prog_bar=True)

    squared_norms = (
        p.grad.norm(2) ** 2 for p in self.parameters() if p.grad is not None
    )
    grad_norm = torch.sqrt(sum(squared_norms))
    self.log("train/grad_norm", grad_norm, on_step=True, prog_bar=True)

    # update EMA
    if not self.use_ema:
        return
    for name, param in self.named_parameters():
        current = param.data.clone().detach()
        self.ema_params[name].mul_(self.ema_decay).add_(
            current, alpha=1 - self.ema_decay
        )
|
| 369 |
+
|
| 370 |
+
def on_save_checkpoint(self, checkpoint):
    """Add the config and, if EMA is enabled, the EMA weights to a checkpoint.

    ``checkpoint`` is modified in place: ``"config"`` always receives
    ``self.config``; ``"ema_params"`` receives a cloned copy of every
    tensor in ``self.ema_params`` only when ``self.use_ema`` is true.
    """
    checkpoint["config"] = self.config

    # save EMA state
    if not self.use_ema:
        return
    ema_snapshot = {}
    for name, value in self.ema_params.items():
        ema_snapshot[name] = value.clone()
    checkpoint["ema_params"] = ema_snapshot
|
| 377 |
+
|
| 378 |
+
def on_load_checkpoint(self, checkpoint):
    """Restore config, interpolant and EMA weights from a checkpoint.

    ``self.config`` is replaced by ``checkpoint["config"]`` and the
    interpolant is rebuilt from that config's ``interpolant`` section.
    ``self.ema_params`` is taken from ``checkpoint["ema_params"]`` when
    ``self.use_ema`` is true, otherwise reset to an empty dict.
    """
    self.config = checkpoint["config"]
    interp_cfg = self.config.interpolant

    insert_schedule = get_schedule_from_config(interp_cfg.insert_schedule)
    unmask_schedule = get_schedule_from_config(interp_cfg.unmask_schedule)

    self.interpolant = AnyOrderMaskInsertionInterpolant(
        insertion_schedule=insert_schedule,
        unmask_schedule=unmask_schedule,
        vocab_size=interp_cfg.tokens,
        mask_token=interp_cfg.mask_token,
        pad_token=interp_cfg.pad_token,
        max_length=interp_cfg.max_length,
    )

    if self.use_ema:
        self.ema_params = checkpoint["ema_params"]
    else:
        self.ema_params = {}
|
| 398 |
+
|
| 399 |
+
def swap_to_ema(self):
    """Overwrite the live parameters with their EMA values.

    Before each overwrite, a clone of the current parameter data is stored
    in ``self._orig_params`` under the parameter's name. The EMA tensor is
    moved to the parameter's device before being copied in.
    """
    backup = self._orig_params
    averaged = self.ema_params
    for name, param in self.named_parameters():
        live = param.data
        backup[name] = live.clone()
        live.copy_(averaged[name].to(live.device))
|
| 403 |
+
|
| 404 |
+
def restore_original(self):
    """Copy the values saved in ``self._orig_params`` back into the parameters.

    Every named parameter is overwritten in place with the tensor stored
    under its name; the saved copies are then discarded.
    """
    saved = self._orig_params
    for name, param in self.named_parameters():
        param.data.copy_(saved[name])
    saved.clear()
|
| 408 |
+
|
| 409 |
+
def on_train_start(self):
|
| 410 |
+
# initialize and move EMA buffers once model is on correct device
|
| 411 |
+
if self.use_ema:
|
| 412 |
+
self.ema_params = {
|
| 413 |
+
name: param.clone().detach().to(self.device)
|
| 414 |
+
for name, param in self.named_parameters()
|
| 415 |
+
}
|
| 416 |
+
for buf in self.ema_params.values():
|
| 417 |
+
buf.requires_grad = False
|