Sophia committed on
Commit
8019be0
·
0 Parent(s):

initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .gitignore +16 -0
  3. LICENSE +21 -0
  4. README.md +62 -0
  5. a2d2_mol/README.md +132 -0
  6. a2d2_mol/config_mol.yaml +54 -0
  7. a2d2_mol/evaluate_mol_table.py +308 -0
  8. a2d2_mol/finetune_mol.py +747 -0
  9. a2d2_mol/inference_quality_mol.py +554 -0
  10. a2d2_mol/mol_dataset.py +379 -0
  11. a2d2_mol/mol_scoring/oracle/fpscores.pkl +3 -0
  12. a2d2_mol/mol_scoring/scoring_functions.py +68 -0
  13. a2d2_mol/mol_utils/bracket_safe_converter.py +159 -0
  14. a2d2_mol/mol_utils/utils.py +135 -0
  15. a2d2_mol/mol_utils/utils_chem.py +187 -0
  16. a2d2_mol/oracle/fpscores.pkl +3 -0
  17. a2d2_mol/remasking_scheduleaware.py +177 -0
  18. a2d2_mol/sampling.py +1401 -0
  19. a2d2_mol/scripts/run_mol_finetune.slurm +200 -0
  20. a2d2_mol/scripts/train_mol.sh +93 -0
  21. a2d2_mol/train.py +216 -0
  22. a2d2_pep/README.md +145 -0
  23. a2d2_pep/config_pep.yaml +50 -0
  24. a2d2_pep/data/dataloading_for_dynamic_batching.py +189 -0
  25. a2d2_pep/data/dataset.py +207 -0
  26. a2d2_pep/evaluate_peptide_table.py +326 -0
  27. a2d2_pep/finetune_quality.py +892 -0
  28. a2d2_pep/inference_quality.py +605 -0
  29. a2d2_pep/pep_scoring/functions/binding.py +178 -0
  30. a2d2_pep/pep_scoring/functions/binding_utils.py +290 -0
  31. a2d2_pep/pep_scoring/functions/hemolysis.py +63 -0
  32. a2d2_pep/pep_scoring/functions/nonfouling.py +66 -0
  33. a2d2_pep/pep_scoring/functions/permeability.py +170 -0
  34. a2d2_pep/pep_scoring/functions/scoring_utils.py +94 -0
  35. a2d2_pep/pep_scoring/functions/solubility.py +63 -0
  36. a2d2_pep/pep_scoring/scoring_functions.py +79 -0
  37. a2d2_pep/pep_scoring/tokenizer/my_tokenizers.py +424 -0
  38. a2d2_pep/pep_utils/analyzer.py +1274 -0
  39. a2d2_pep/pep_utils/utils.py +135 -0
  40. a2d2_pep/remasking_scheduleaware.py +181 -0
  41. a2d2_pep/sampling.py +1401 -0
  42. a2d2_pep/scripts/run_peptide_finetune.slurm +210 -0
  43. a2d2_pep/scripts/train_pep.sh +93 -0
  44. a2d2_pep/train.py +216 -0
  45. assets/a2d2.gif +3 -0
  46. demo/quality_inference_demo.ipynb +0 -0
  47. environment.yml +57 -0
  48. lightning_modules/__init__.py +16 -0
  49. lightning_modules/any_length_remask.py +801 -0
  50. lightning_modules/any_order.py +417 -0
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.gif filter=lfs diff=lfs merge=lfs -text
2
+ *.pkl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ checkpoints/
2
+ pretrained/
3
+ __pycache__/
4
+ results/
5
+ a2d2_language/
6
+ a2d2_language/wandb/
7
+ a2d2_pep/wandb/
8
+ a2d2_mol/wandb/
9
+ logs/
10
+ *.pt
11
+ *.pyc
12
+ *.out
13
+ *.json
14
+ *.log
15
+ *.txt
16
+ *.wandb
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Sophia Tang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [A2D2: Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding](https://arxiv.org/abs/2606.13565) 🃏🔮
2
+
3
+ [**Sophia Tang**](https://sophtang.github.io/), [**Yuchen Zhu**](https://yuchen-zhu-zyc.github.io/), [**Molei Tao**](https://mtao8.math.gatech.edu/), and [**Pranam Chatterjee**](https://www.chatterjeelab.com/)
4
+
5
+ <p>
6
+ <a href="https://arxiv.org/abs/2606.13565"><img src="https://img.shields.io/badge/arXiv-6B67EE?style=for-the-badge&logo=arxiv&logoColor=white" alt="arXiv"></a>
7
+ <a href="https://sophtang.github.io/a2d2/"><img src="https://img.shields.io/badge/Project_Page-6B67EE?style=for-the-badge&logo=data:image/svg+xml;base64,PHN2ZyByb2xlPSJpbWciIHZpZXdCb3g9IjAgMCAyNCAyNCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIiBmaWxsPSJ3aGl0ZSI+PHBhdGggZD0iTTEwLjUgMS41QzExIDcuMyAxMy4yIDkuNSAxOSAxMEMxMy4yIDEwLjUgMTEgMTIuNyAxMC41IDE4LjVDMTAgMTIuNyA3LjggMTAuNSAyIDEwQzcuOCA5LjUgMTAgNy4zIDEwLjUgMS41WiIvPjxwYXRoIGQ9Ik0xOC41IDEzLjVDMTguNyAxNS44IDE5LjcgMTYuOCAyMiAxN0MxOS43IDE3LjIgMTguNyAxOC4yIDE4LjUgMjAuNUMxOC4zIDE4LjIgMTcuMyAxNy4yIDE1IDE3QzE3LjMgMTYuOCAxOC4zIDE1LjggMTguNSAxMy41WiIvPjxwYXRoIGQ9Ik01IDE1LjVDNS4xMiAxNyA1LjUgMTcuMzggNyAxNy41QzUuNSAxNy42MiA1LjEyIDE4IDUgMTkuNUM0Ljg4IDE4IDQuNSAxNy42MiAzIDE3LjVDNC41IDE3LjM4IDQuODggMTcgNSAxNS41WiIvPjwvc3ZnPg==" alt="Project Page"></a>
8
+ </p>
9
+
10
+ ![A2D2](assets/a2d2.gif)
11
+
12
+ This is the repository for the paper [**A2D2: Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding**](https://arxiv.org/abs/2606.13565).
13
+
14
+ Masked discrete diffusion models (MDMs) offer a simple, stable likelihood-based framework for sequence generation, recently extended to **any-length** settings via token insertion. **A2D2** is a unified framework for reward-guided fine-tuning of any-length MDMs that **jointly optimizes the insertion and unmasking policies together with a quality-based inference schedule**, converging to the intractable reward-tilted distribution without requiring target samples.
15
+
16
+ 🃏 We derive the **Radon–Nikodym derivative** for the joint insertion–unmasking path measures, enabling theoretically guaranteed convergence to the reward-tilted sequence distribution.
17
+
18
+ 🃏 We establish **unmasking and insertion quality** as tractable approaches for minimizing decoding error (compounding parallelization error), and train lightweight quality predictors alongside the policy.
19
+
20
+ 🃏 We introduce the **Adaptive Joint Decoding (AJD)** loss, which provably yields the optimal path measure that generates the reward-tilted distribution while remasking low-quality tokens and dropping low-quality insertions at inference.
21
+
22
+ 🃏 Empirically, A2D2 improves reward optimization while enhancing generation **flexibility** and **accuracy** over prior fixed-length fine-tuning and inference-time guidance methods.
23
+
24
+ ## Drug-Like Small Molecule Design 🧪
25
+
26
+ We pre-train an any-length MDM on the **SAFE** dataset ([Noutahi et al. 2024](https://arxiv.org/abs/2310.10773), ~950M molecules from ZINC and Unichem in SAFE notation) and fine-tune it with **A2D2** to optimize **QED** (drug-likeness) and **synthetic accessibility (SA)**. A2D2 jointly raises QED and lowers SA over the pre-trained baseline while increasing the fraction of valid, unique, drug-like, and synthesizable molecules. Code and instructions are in [`/a2d2_mol`](a2d2_mol).
27
+
28
+ ## Multi-Objective Therapeutic Peptide Generation 💉
29
+
30
+ We pre-train an any-length **peptide SMILES** MDM on ~11M peptides (CycPeptMPDB, SmProt, CycloPs) and fine-tune with **A2D2** on five therapeutic properties simultaneously: **target-protein binding affinity, solubility, non-hemolysis, non-fouling, and permeability**. A2D2 outperforms inference-time multi-objective guidance and fixed-length off-policy RL fine-tuning on almost all objectives, while improving the fraction of valid peptides. Code and instructions are in [`/a2d2_pep`](a2d2_pep).
31
+
32
+ ## Language Model Reasoning 🧠
33
+
34
+ We additionally apply **A2D2** to reward fine-tuning of any-length language MDMs (LLaDA / FlexMDM), optimizing math-reasoning correctness and format rewards (GSM8K / MATH), including infilling variants. Code will be released in [`/a2d2_language`](a2d2_language).
35
+
36
+ ## Repository Structure
37
+
38
+ | Directory | Experiment |
39
+ |-----------|------------|
40
+ | [`a2d2_mol`](a2d2_mol) | Drug-like small molecule design (QED, SA) |
41
+ | [`a2d2_pep`](a2d2_pep) | Multi-objective therapeutic peptide generation |
42
+ | [`a2d2_language`](a2d2_language) | Language model reasoning reward fine-tuning (code soon) |
43
+ | [`lightning_modules`](lightning_modules) | Any-length insertion MDM Lightning modules (policy + quality predictors) |
44
+ | [`model`](model) | Shared model architecture |
45
+ | [`demo`](demo) | Quality-guided inference demo notebook |
46
+
47
+ Each experiment directory contains its own `README.md` with environment setup, pretrained weight placement, fine-tuning commands, and evaluation instructions.
48
+
49
+ ## Citation
50
+
51
+ If you find this repository helpful for your publications, please consider citing our paper:
52
+
53
+ ```bibtex
54
+ @article{tang2026a2d2,
55
+ title={A2D2: Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding},
56
+ author={Sophia Tang and Yuchen Zhu and Molei Tao and Pranam Chatterjee},
57
+ journal={arXiv preprint arXiv:2606.13565},
58
+ year={2026}
59
+ }
60
+ ```
61
+
62
+ To use this repository, you agree to abide by the MIT License.
a2d2_mol/README.md ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A2D2 for Molecule Generation 🧪
2
+
3
+ This part of the code fine-tunes an **any-length masked diffusion model (MDM)** over molecules with **A2D2** (Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding) to optimize drug-likeness rewards (QED, and optionally synthetic accessibility / SA).
4
+
5
+ A2D2 jointly fine-tunes the insertion and unmasking policies together with **insertion and unmasking quality predictors**, generating molecules via **Adaptive Joint Decoding (AJD)** that remasks low-quality tokens and drops low-quality insertions to sample from the reward-tilted distribution while preserving generation quality.
6
+
7
+ Molecules are represented as [SAFE](https://github.com/datamol-io/safe) strings and tokenized with the `datamol-io/safe-gpt` tokenizer.
8
+
9
+ The codebase is partially built upon [FlexMDM (Kim et al., 2025)](https://github.com/brianlck/FlexMDM/tree/main) and [TR2-D2 (Tang et al., 2025)](https://github.com/sophtang/TR2-D2/tree/main).
10
+
11
+ ## Environment Installation
12
+ ```
13
+ # from the repository root
14
+ conda env create -f environment.yml
15
+
16
+ conda activate a2d2
17
+ ```
18
+ The molecule scripts share the `a2d2` environment with the peptide and language experiments. See the root [`environment.yml`](../environment.yml) for the `flash-attn` install step.
19
+
20
+ ## Model Pretrained Weights
21
+
22
+ A2D2 fine-tunes a pretrained any-length insertion MDM trained on drug-like SAFE molecules. Download the base checkpoint and place it at:
23
+ ```
24
+ A2D2/pretrained/anylength_mol.ckpt
25
+ ```
26
+ ```bash
27
+ # from the repository root
28
+ pip install gdown
29
+ mkdir -p pretrained
30
+ gdown 1I5EGiV1I5XZZpB9JAKABFLKVqfCyenxq -O pretrained/anylength_mol.ckpt
31
+ ```
32
+ (Or download manually from https://drive.google.com/file/d/1I5EGiV1I5XZZpB9JAKABFLKVqfCyenxq/view?usp=drive_link — a plain `wget`/`curl` of the link saves Google's HTML warning page, not the checkpoint.)
33
+ This is the default `--checkpoint_path` (for fine-tuning) and `--pretrained_ckpt` (for evaluation) used throughout.
34
+
35
+ ## Pretraining the Any-Length Model
36
+
37
+ If you only want to fine-tune with A2D2, download the released `anylength_mol.ckpt` above and skip this section. Follow these steps to reproduce the base checkpoint by pretraining the any-length insertion MDM from scratch.
38
+
39
+ ### 1. The pretraining dataset
40
+
41
+ The model is pretrained on drug-like [SAFE](https://github.com/datamol-io/safe) molecules from the [`datamol-io/safe-gpt`](https://huggingface.co/datasets/datamol-io/safe-gpt) dataset (~1.1B molecules) on the Hugging Face Hub. **No manual download is required** — the dataset is loaded in streaming mode (`load_dataset(..., streaming=True)`) and tokenized on the fly with the `datamol-io/safe-gpt` tokenizer, both fetched automatically on first run.
42
+
43
+ The dataset is configured in [`config_mol.yaml`](config_mol.yaml):
44
+
45
+ ```yaml
46
+ hf_dataset:
47
+ name: "datamol-io/safe-gpt"
48
+ smiles_column: "smiles"
49
+ ```
50
+
51
+ To pretrain on a different Hugging Face SMILES/SAFE dataset, change `hf_dataset.name` (and `smiles_column` to match its column).
52
+
53
+ ### 2. Configure
54
+
55
+ Pretraining is driven by [`config_mol.yaml`](config_mol.yaml). Key fields:
56
+
57
+ | Field | Default | Notes |
58
+ |-------|---------|-------|
59
+ | `hf_dataset.name` | `datamol-io/safe-gpt` | Streaming HF dataset (auto-downloaded). |
60
+ | `training.devices` | `2` | GPUs per node (DDP). |
61
+ | `training.batch_size` | `2048` | Global batch; gradient accumulation is derived automatically from `per_gpu_batch_size`. |
62
+ | `training.max_steps` | `500000` | Total optimizer steps. |
63
+ | `training.learning_rate` | `3e-4` | AdamW LR with `warmup_steps: 2000`. |
64
+ | `training.save_every_n_steps` | `1000` | Step-based checkpointing (used for streaming datasets). |
65
+ | `training.checkpoint_dir` | `checkpoints/pretrain_mol` | A timestamped subdirectory is created per run. |
66
+ | `interpolant.max_length` | `256` | Max token length. |
67
+
68
+ ### 3. Pre-training Any-Length Molecule Model
69
+
70
+ Log in to Weights & Biases once (`wandb login`), or set `export WANDB_MODE=disabled` to skip logging. Then submit the SLURM job:
71
+
72
+ ```bash
73
+ # from a2d2_mol/
74
+ sbatch scripts/train_mol.sh
75
+ ```
76
+
77
+ `train_mol.sh` is a SLURM batch script that requests one `dgx-b200` node with 2 full B200 GPUs and launches DDP via `srun` (one task per GPU), running the equivalent of:
78
+
79
+ ```bash
80
+ python train.py --task mol
81
+ ```
82
+
83
+ It activates the conda env (`CONDA_ENV`, defaults to the `peptune` env) from `CONDA_ROOT` (defaults to the shared miniconda install) — the batch shell does not source `~/.bashrc`, so override these env vars if your install or env path differs. The GPU count is auto-detected from the SLURM allocation and passed to hydra as `training.devices`/`training.nodes`, so to scale just change `--gpus-per-node` and `--ntasks-per-node` together at the top of the script (they must match). `--task mol` makes `train.py` load `config_mol.yaml`.
84
+
85
+ Checkpoints are written to `checkpoints/pretrain_mol/<timestamp>/` (use `last.ckpt` / the best `train_loss` checkpoint as the `--checkpoint_path` / `--pretrained_ckpt` for fine-tuning and evaluation); the run log goes to `logs/<date>_a2d2-mol_<jobid>.log` and SLURM's catch-file to `logs/slurm/`. To resume, add a `training.resume_path: /path/to/last.ckpt` entry to the config.
86
+
87
+ ## Fine-Tune with A2D2
88
+
89
+ The canonical run directory is the parent `a2d2/` package (`finetune_mol.py`, `inference_quality_mol.py`, `sampling.py`, and `mol_scoring/` here are the molecule-specific modules used from there). Before running:
90
+
91
+ 1. Set `--base_path` to the location of `a2d2`. Results plots are written to `<base_path>/flexible/results/<run_name>/`.
92
+ 2. Create the output directories: `a2d2/checkpoints/finetune_mol`, `a2d2/results`, and `a2d2/logs`.
93
+
94
+ ### Single run
95
+
96
+ [`scripts/run_mol_finetune.slurm`](scripts/run_mol_finetune.slurm) runs a single `finetune_mol.py` experiment on one MIG GPU, then evaluates the resulting checkpoint. It bundles the full hyperparameter set used in the paper (replicates `R = 16`, pool size `1000`, buffer size `100`, sampling steps `N_steps = 90`, warmup `N_warmup = 20`, alternation frequency `N_alt = 5`, reward scaling `α = 0.01`, quality threshold `μ_min = 0.3`, `--qed_only`), so you don't have to pass them by hand.
97
+
98
+ The script resolves the repo root automatically — `$A2D2_ROOT` if set, else the `sbatch` submit directory, else the script's own location — so either submit from the repo root or export your clone path. Set `CONDA_ROOT` (your miniconda install) and, if needed, `CONDA_ENV` (defaults to `peptune`):
99
+ ```bash
100
+ export A2D2_ROOT=/path/to/your/A2D2 # absolute path to your clone
101
+ export CONDA_ROOT=/path/to/miniconda3 # or just have `conda` on PATH
102
+ sbatch scripts/run_mol_finetune.slurm
103
+ ```
104
+
105
+ Select which variant to run with `MODE_ID` (default `0`): `0` = A2D2 (full planner), `1` = `--disable_planner`, `2` = `--disable_insertion_planner`, `3` = `--disable_unmasking_planner`. Override at submit time:
106
+ ```bash
107
+ sbatch --export=ALL,MODE_ID=2 scripts/run_mol_finetune.slurm
108
+ ```
109
+ The pretrained base checkpoint is read from `$A2D2_ROOT/pretrained/anylength_mol.ckpt`. Outputs land in `checkpoints/finetune_mol/<job>_mol_<mode>/` and `results/mol_ablation/<mode>/`.
110
+
111
+ ### Ablation flags
112
+ | Flag | Variant |
113
+ |------|---------|
114
+ | *(none)* | A2D2 w/ insertion + unmasking quality (alternation) |
115
+ | `--disable_planner` | A2D2 w/o quality (policy only, no remasking) |
116
+ | `--disable_insertion_planner` | A2D2 w/o insertion quality |
117
+ | `--disable_unmasking_planner` | A2D2 w/o unmasking/remasking quality |
118
+ | `--joint_training` | train policy + quality heads jointly (no alternation) |
119
+
120
+ ## Evaluation
121
+
122
+ Evaluation runs automatically at the end of the SLURM job. To evaluate a checkpoint manually:
123
+ ```
124
+ python evaluate_mol_table.py \
125
+ --checkpoint_path /path/to/a2d2/checkpoints/finetune_mol/my_run/last.ckpt \
126
+ --pretrained_ckpt /path/to/A2D2/pretrained/anylength_mol.ckpt \
127
+ --output_dir /path/to/results \
128
+ --num_samples 1000 --batch_size 50 \
129
+ --max_length 256 --total_num_steps 256 \
130
+ --num_remasking 2 --quality_threshold 0.3 --seed 42 --device cuda:0
131
+ ```
132
+ This reports QED, SA, validity, uniqueness, diversity, and mean unmasking/insertion quality over the generated molecules and writes `eval_metrics_<mode>.csv`.
a2d2_mol/config_mol.yaml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ trainer: "any-order-flow"
2
+ dataset: "safe-drugs"
3
+
4
+ # HuggingFace dataset configuration
5
+ hf_dataset:
6
+ name: "datamol-io/safe-gpt"
7
+ smiles_column: "smiles" # Adjust based on actual column name in the dataset
8
+
9
+ model:
10
+ hidden_size: 768
11
+ n_heads: 12
12
+ cond_dim: 128
13
+ dropout: 0.05
14
+ n_blocks: 12
15
+ torch_dtype: 'float32' # Options: 'float32', 'float16', 'bfloat16'
16
+
17
+ interpolant:
18
+ type: "any-order"
19
+ tokens: null # filled in automatically
20
+ pad_token: null # filled in automatically
21
+ mask_token: null # filled in automatically
22
+ max_length: 256
23
+ insert_schedule:
24
+ type: "linear"
25
+ unmask_schedule:
26
+ type: "linear"
27
+
28
+ training:
29
+ only_embed_insert: true
30
+ batch_size: 2048
31
+ per_gpu_batch_size: 64 # Gradient accumulation happens automatically
32
+ cpus: 4
33
+ learning_rate: 3e-4
34
+ nodes: 1
35
+ devices: 2
36
+ max_steps: 500000
37
+ weight_decay: 0.03
38
+ checkpoint_dir: "checkpoints/pretrain_mol"
39
+ save_top_k: 3
40
+ save_every_n_steps: 1000 # Save checkpoint every 1k steps (for streaming datasets)
41
+ # save_every_n_epochs: 1 # Not used with streaming datasets
42
+ loss_fn:
43
+ unmask: "elbo"
44
+ insert: "expectation"
45
+ reset_lr: false
46
+ warmup_steps: 2000
47
+ ema_decay: 0.9999
48
+ filter_max_length: false
49
+
50
+ wandb:
51
+ entity: null # set to your W&B entity, or leave null to use the default
52
+ project: "a2d2-mol"
53
+ name: "a2d2-mol"
54
+ path: "./wandb"
a2d2_mol/evaluate_mol_table.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluate a finetuned molecule model checkpoint by sampling sequences
3
+ and computing metrics for the De Novo Small Molecule Generation table:
4
+ Validity (%), Uniqueness (%), QED (↑), SA (↓), Quality (%), Diversity (↑), Sampling Time (↓)
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import argparse
10
+ import time
11
+ import torch
12
+ import numpy as np
13
+ import pandas as pd
14
+ from tdc import Oracle, Evaluator
15
+
16
+ # add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
17
+ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
18
+ sys.path.insert(0, REPO_ROOT)
19
+
20
+ from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
21
+ from lightning_modules import AnyOrderInsertionFlowModule
22
+ from inference_quality_mol import sample_mol_eval
23
+ from mol_scoring.scoring_functions import MolScoringFunctions
24
+ from finetune_mol import MolFinetuner, get_tokenizer
25
+ from mol_utils.utils import str2bool, set_seed
26
+
27
+
28
def load_finetuned_model(checkpoint_path, pretrained_ckpt_path, device='cuda'):
    """Rebuild the fine-tuned policy model from a Lightning checkpoint.

    The model is constructed from the config stored in the pretrained base
    checkpoint and the training ``args`` stored in the fine-tuned checkpoint.
    The ``policy_model.*`` entries of the fine-tuned ``state_dict`` are then
    loaded into it (non-strictly), and the model is moved to ``device`` and
    put in eval mode.

    Args:
        checkpoint_path: Path to the fine-tuned Lightning checkpoint.
        pretrained_ckpt_path: Path to the pretrained base checkpoint. Its
            config is read from ``hyper_parameters['config']`` or, failing
            that, from the top-level ``config`` entry.
        device: Device the policy model is moved to.

    Returns:
        Tuple ``(policy_model, args, config)``. ``args`` is ``None`` when the
        fine-tuned checkpoint stores no ``hyper_parameters['args']``; in that
        case the defaults given in the ``getattr`` calls below are used.

    Raises:
        ValueError: If the base checkpoint contains no config.
    """
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that come from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    hparams = ckpt.get('hyper_parameters', {})
    args = hparams.get('args', None)

    # The architecture/interpolant config lives in the pretrained base checkpoint.
    base_ckpt = torch.load(pretrained_ckpt_path, map_location='cpu', weights_only=False)
    if 'hyper_parameters' in base_ckpt:
        config = base_ckpt['hyper_parameters']['config']
    elif 'config' in base_ckpt:
        config = base_ckpt['config']
    else:
        raise ValueError("Cannot find config in base checkpoint")

    from omegaconf import OmegaConf, DictConfig
    if not OmegaConf.is_config(config):
        config = DictConfig(config)

    # Temporarily lift struct mode so keys absent from the pretraining config
    # can be added, then lock the config again.
    OmegaConf.set_struct(config, False)
    config.training.use_adaptive_schedule = getattr(args, 'use_adaptive_schedule', True)
    config.training.schedule_hidden_dim = getattr(args, 'schedule_hidden_dim', 256)
    config.training.schedule_num_layers = getattr(args, 'schedule_num_layers', 2)
    config.training.schedule_loss_weight = getattr(args, 'schedule_loss_weight', 0.1)
    config.training.freeze_base_model = getattr(args, 'freeze_base_model', False)
    config.training.schedule_warmup_epochs = getattr(args, 'schedule_warmup_epochs', 0)
    config.training.use_bracket_safe = True
    OmegaConf.set_struct(config, True)

    # The planner head is only built when training did not disable it.
    disable_planner = getattr(args, 'disable_planner', False)

    policy_model = AnyOrderInsertionFlowModuleFT(
        config=config,
        args=args,
        pretrained_checkpoint=pretrained_ckpt_path,
        insertion_planner=not disable_planner,
    )

    # The Lightning wrapper stores the sub-module under 'policy_model.<name>';
    # strip that prefix to obtain the sub-module's own state dict.
    prefix = 'policy_model.'
    policy_state = {
        key[len(prefix):]: value
        for key, value in ckpt['state_dict'].items()
        if key.startswith(prefix)
    }
    if not policy_state:
        # With strict=False an empty state dict loads without error, so make
        # the situation visible instead of silently evaluating other weights.
        print(f"WARNING: no '{prefix}*' weights found in {checkpoint_path}; "
              "the model keeps the weights it was constructed with.")
    policy_model.load_state_dict(policy_state, strict=False)
    policy_model = policy_model.to(device)
    policy_model.eval()

    return policy_model, args, config
83
+
84
+
85
@torch.no_grad()
def evaluate_checkpoint(policy_model, tokenizer, reward_model, evaluator,
                        num_samples=1000, batch_size=50, max_length=256,
                        total_num_steps=256, quality_mode="both", num_remasking=2,
                        quality_threshold=0.5, unmask_quality_threshold=None, device='cuda'):
    """Sample ``num_samples`` molecules in batches and compute the table metrics.

    Args:
        policy_model: Model passed to ``sample_mol_eval``; its
            ``interpolant.mask_token`` / ``interpolant.pad_token`` are used.
        tokenizer: Tokenizer forwarded to ``sample_mol_eval``.
        reward_model: Reward model forwarded to ``sample_mol_eval``.
        evaluator: Callable applied to the list of unique sequences to obtain
            the diversity score; also forwarded to ``sample_mol_eval``.
        num_samples: Total number of molecules to sample.
        batch_size: Maximum number of molecules sampled per call.
        max_length: Forwarded to ``sample_mol_eval``.
        total_num_steps: Forwarded to ``sample_mol_eval`` as ``steps``.
        quality_mode: Forwarded to ``sample_mol_eval``.
        num_remasking: Forwarded to ``sample_mol_eval``.
        quality_threshold: Forwarded to ``sample_mol_eval``.
        unmask_quality_threshold: Forwarded to ``sample_mol_eval``.
        device: Not used inside this function; kept for interface
            compatibility.

    Returns:
        Tuple ``(metrics, unique_sequences, qed_vals, sa_vals)`` where
        ``metrics`` is a dict of the table entries and the last two are the
        per-sequence oracle scores (empty lists if nothing valid was sampled).
    """
    all_valid_seqs = []
    total_time = 0.0

    num_batches = (num_samples + batch_size - 1) // batch_size
    remaining = num_samples

    for b in range(num_batches):
        bs = min(batch_size, remaining)
        remaining -= bs

        t_start = time.perf_counter()
        result = sample_mol_eval(
            model=policy_model,
            reward_model=reward_model,
            tokenizer=tokenizer,
            steps=total_num_steps,
            mask=policy_model.interpolant.mask_token,
            pad=policy_model.interpolant.pad_token,
            batch_size=bs,
            max_length=max_length,
            quality_mode=quality_mode,
            num_remasking=num_remasking,
            quality_threshold=quality_threshold,
            unmask_quality_threshold=unmask_quality_threshold,
            evaluator=evaluator,
            dataframe=True,
        )
        elapsed = time.perf_counter() - t_start

        # Unpack: uniqueSequences, qed, sa, valid_fraction, uniqueness,
        # diversity, quality, df. Only the sequences are needed here because
        # every metric is recomputed over the pooled samples below.
        unique_seqs, _qed, _sa, _valid_frac, _uniq, _div, _qual, _df = result

        all_valid_seqs.extend(unique_seqs)
        total_time += elapsed

        print(f" Batch {b+1}/{num_batches}: {len(unique_seqs)} valid unique, "
              f"time={elapsed:.1f}s")

    # --- Aggregate metrics over all samples ---
    total_generated = num_samples

    # Order-preserving de-duplication: unlike set(), this keeps the output
    # order identical from run to run, so the saved results are reproducible.
    all_unique = list(dict.fromkeys(all_valid_seqs))
    # NOTE(review): the per-batch sequences appear to be de-duplicated already
    # (first return value is named "uniqueSequences"), so this count may
    # exclude within-batch duplicates - confirm against sample_mol_eval.
    num_valid = len(all_valid_seqs)
    num_unique = len(all_unique)

    # Guard both ratios so that an empty run reports zeros instead of raising.
    validity = num_valid / total_generated * 100.0 if total_generated > 0 else 0.0
    uniqueness = num_unique / num_valid * 100.0 if num_valid > 0 else 0.0

    # Diversity on unique SMILES
    diversity = evaluator(all_unique) if num_unique > 1 else 0.0

    # QED and SA on unique sequences
    if num_unique > 0:
        qed_vals = Oracle('qed')(all_unique)
        sa_vals = Oracle('sa')(all_unique)
        mean_qed = np.mean(qed_vals)
        mean_sa = np.mean(sa_vals)

        # Quality: unique sequences with QED >= 0.6 AND SA <= 4,
        # as a percentage of everything that was generated.
        num_quality = sum(1 for q, s in zip(qed_vals, sa_vals) if q >= 0.6 and s <= 4)
        quality = num_quality / total_generated * 100.0
    else:
        qed_vals = []
        sa_vals = []
        mean_qed = 0.0
        mean_sa = 0.0
        quality = 0.0

    metrics = {
        'Validity (%)': validity,
        'Uniqueness (%)': uniqueness,
        'QED': mean_qed,
        'Synthetic Accessibility': mean_sa,
        'Quality (%)': quality,
        'Diversity': diversity,
        'Sampling Time (s)': total_time,
        'Num Generated': total_generated,
        'Num Valid': num_valid,
        'Num Unique': num_unique,
    }

    return metrics, all_unique, qed_vals, sa_vals
182
+
183
+
184
def main():
    """Command-line entry point: evaluate a fine-tuned molecule checkpoint.

    Parses the CLI arguments, rebuilds the policy model, samples molecules,
    prints the metric table and writes ``eval_metrics_<tag>.csv`` (plus
    ``eval_smiles_<tag>.csv`` when at least one unique molecule was produced)
    to ``--output_dir`` or, by default, the checkpoint's directory.
    """
    parser = argparse.ArgumentParser(description="Evaluate a finetuned mol checkpoint")
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to the finetuned Lightning checkpoint (e.g., last.ckpt)')
    parser.add_argument('--pretrained_ckpt', type=str,
                        default=os.path.join(REPO_ROOT, 'pretrained', 'anylength_mol.ckpt'),
                        help='Path to the pretrained base model checkpoint '
                             '(defaults to <repo>/pretrained/anylength_mol.ckpt)')
    parser.add_argument('--num_samples', type=int, default=1000,
                        help='Number of molecules to sample')
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Batch size for sampling')
    parser.add_argument('--max_length', type=int, default=256)
    parser.add_argument('--total_num_steps', type=int, default=256)
    parser.add_argument('--num_remasking', type=int, default=2)
    parser.add_argument('--disable_planner', action='store_true',
                        help='If set, disable remasking during evaluation (matches training mode)')
    parser.add_argument('--disable_insertion_planner', action='store_true',
                        help='If set, disable insertion quality filtering during evaluation')
    parser.add_argument('--disable_unmasking_planner', action='store_true',
                        help='If set, disable unmasking confidence planner during evaluation')
    parser.add_argument('--quality_threshold', type=float, default=0.5,
                        help='Threshold for insertion quality filtering during sampling')
    parser.add_argument('--unmask_quality_threshold', type=float, default=None,
                        help='If set, gate unmasking remasking on confidence: remask clean '
                             'tokens whose remasking_conf < threshold (overrides the '
                             'schedule-driven count). Default None = schedule-driven behavior.')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Directory to save results CSV. Defaults to checkpoint directory.')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    set_seed(args.seed, use_cuda=True)
    # Fall back to CPU when CUDA is not available.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    print(f"Loading checkpoint: {args.checkpoint_path}")
    print(f"Pretrained base: {args.pretrained_ckpt}")
    print(f"Disable planner (no remasking): {args.disable_planner}")
    print(f"Disable insertion planner: {args.disable_insertion_planner}")
    print(f"Disable unmasking planner: {args.disable_unmasking_planner}")

    # Only the model is needed here; the stored training args and config
    # returned alongside it are not used by the evaluation.
    policy_model, _train_args, _config = load_finetuned_model(
        args.checkpoint_path, args.pretrained_ckpt, device=device
    )

    tokenizer = get_tokenizer()
    score_func_names = ['qed', 'sa']
    reward_model = MolScoringFunctions(score_func_names, device=device)
    evaluator = Evaluator('diversity')

    # Map the ablation flags to the sampler's quality_mode. Disabling both
    # individual planners is equivalent to disabling the planner entirely.
    if args.disable_planner or (args.disable_insertion_planner
                                and args.disable_unmasking_planner):
        quality_mode = "none"
    elif args.disable_insertion_planner:
        quality_mode = "unmasking_only"
    elif args.disable_unmasking_planner:
        quality_mode = "insertion_only"
    else:
        quality_mode = "both"

    print(f"\nSampling {args.num_samples} molecules (quality_mode={quality_mode})...")

    metrics, unique_smiles, qed_vals, sa_vals = evaluate_checkpoint(
        policy_model=policy_model,
        tokenizer=tokenizer,
        reward_model=reward_model,
        evaluator=evaluator,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        max_length=args.max_length,
        total_num_steps=args.total_num_steps,
        quality_mode=quality_mode,
        num_remasking=args.num_remasking,
        quality_threshold=args.quality_threshold,
        unmask_quality_threshold=args.unmask_quality_threshold,
        device=device,
    )

    # Print summary table
    print("\n" + "=" * 60)
    print(" De Novo Small Molecule Generation Results")
    print("=" * 60)
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f" {k:<30s}: {v:.4f}")
        else:
            print(f" {k:<30s}: {v}")
    print("=" * 60)

    # Save results. abspath() matters: for a bare filename such as
    # "last.ckpt", dirname() alone returns '' and os.makedirs('') raises.
    output_dir = args.output_dir or os.path.dirname(os.path.abspath(args.checkpoint_path))
    os.makedirs(output_dir, exist_ok=True)

    # File-name tag for the evaluated variant. When both individual planners
    # are disabled, the insertion tag is used (first matching branch).
    if args.disable_planner:
        tag = "no_planner"
    elif args.disable_insertion_planner:
        tag = "no_insertion_planner"
    elif args.disable_unmasking_planner:
        tag = "no_unmasking_planner"
    else:
        tag = "with_planner"

    metrics_path = os.path.join(output_dir, f'eval_metrics_{tag}.csv')
    pd.DataFrame([metrics]).to_csv(metrics_path, index=False)
    print(f"Metrics saved to: {metrics_path}")

    if unique_smiles:
        smiles_path = os.path.join(output_dir, f'eval_smiles_{tag}.csv')
        df = pd.DataFrame({
            'SMILES': unique_smiles,
            'QED': qed_vals,
            'SA': sa_vals,
        })
        df.to_csv(smiles_path, index=False)
        print(f"SMILES saved to: {smiles_path}")
305
+
306
+
307
+ if __name__ == '__main__':
308
+ main()
a2d2_mol/finetune_mol.py ADDED
@@ -0,0 +1,747 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from datetime import datetime
3
+ import numpy as np
4
+ import torch
5
+ import pytorch_lightning as pl
6
+ from pytorch_lightning.strategies import DDPStrategy
7
+ from pytorch_lightning.callbacks import ModelCheckpoint
8
+ from pytorch_lightning.loggers import WandbLogger
9
+ import wandb
10
+ import os
11
+ import sys
12
+ from tqdm import tqdm
13
+ import pandas as pd
14
+
15
+ # add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
16
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
17
+
18
+ # imports
19
+ from inference_quality_mol import sample_mol_buffer, sample_mol_eval
20
+ from mol_utils.utils import str2bool, set_seed
21
+ from mol_scoring.scoring_functions import MolScoringFunctions
22
+ from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
23
+ from lightning_modules import AnyOrderInsertionFlowModule
24
+ from safe.tokenizer import SAFETokenizer
25
+ from tdc import Evaluator
26
+
27
+ # Repository root (two levels up from this file: A2D2/a2d2_mol/finetune_mol.py)
28
+ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
29
+
30
+
31
def get_tokenizer():
    """Build the SAFE tokenizer used for molecule generation.

    Loads the 'datamol-io/safe-gpt' SAFE tokenizer, unwraps it with
    ``get_pretrained()``, and registers the two extra tokens ``<`` and ``>``
    that the bracket_safe representation relies on.

    Returns:
        The unwrapped tokenizer with the two bracket tokens added.
    """
    safe_wrapper = SAFETokenizer.from_pretrained('datamol-io/safe-gpt')
    tokenizer = safe_wrapper.get_pretrained()
    # '<' and '>' are the extra symbols needed for bracket_safe strings.
    bracket_tokens = ['<', '>']
    tokenizer.add_tokens(bracket_tokens)
    return tokenizer
36
+
37
class MolFinetuner(pl.LightningModule):
    """Lightning module for distributed molecule finetuning.

    Training does not consume the dataloader batches. Instead, a buffer of
    sampled sequences (``x_saved``) with their log importance weights
    (``log_rnd_saved``) and rewards (``final_rewards_saved``) is generated by
    ``sample_mol_buffer`` and re-used for several epochs. Each training step
    draws a random mini-batch from that buffer.

    Three training regimes are selected from ``args``:

    * ``disable_planner``: only the policy backbone is trained.
    * ``joint_training``: policy and planner are trained together every step.
    * otherwise: policy and planner alternate every
      ``alternation_frequency`` epochs.

    Parameters whose name starts with ``'planner.'`` inside ``policy_model``
    are treated as planner parameters; everything else is the policy backbone.
    """

    def __init__(
        self,
        args,
        policy_model,
        reward_model,
        tokenizer,
        pretrained=None,
        mcts=None,
        filename=None,
        eps=1e-5
    ):
        """Store the collaborating objects and initialise buffers and logs.

        Args:
            args: Parsed command-line namespace with the training options.
            policy_model: Model being finetuned; must expose ``interpolant``,
                ``loss_wdce_flexible`` and the planner loss methods used in
                ``training_step``.
            reward_model: Scoring object forwarded to the sampling functions.
            tokenizer: Tokenizer forwarded to the sampling functions.
            pretrained: Frozen reference model forwarded to
                ``sample_mol_buffer`` and used for the ELBO-based log_rnd.
            mcts: Stored but not used anywhere in this class.
            filename: Suffix of the CSV log written in ``on_fit_end``.
            eps: Stored but not used anywhere in this class.
        """
        super().__init__()
        self.args = args
        self.policy_model = policy_model
        self.reward_model = reward_model
        self.tokenizer = tokenizer
        self.pretrained = pretrained
        self.mcts = mcts
        self.filename = filename
        self.eps = eps

        self.evaluator = Evaluator("diversity")

        # Save hyperparameters (the heavy objects are excluded on purpose)
        self.save_hyperparameters(ignore=['policy_model', 'reward_model', 'tokenizer', 'pretrained', 'mcts'])

        # Buffer for sequences
        self.x_saved = None
        self.log_rnd_saved = None
        self.final_rewards_saved = None

        # initialize logs
        self.valid_fraction_log = []
        self.diversity_log = []
        self.qed_log = []
        self.sa_log = []
        self.quality_log = []
        self.uniqueness_log = []

        # Alternating training between policy and planner
        self.train_policy = True  # Start by training policy
        # Alternate every N epochs. Clamped to >= 1 because the value is used
        # as a divisor in on_train_epoch_start; 0 would raise ZeroDivisionError.
        self.alternation_frequency = max(1, getattr(args, 'alternation_frequency', 1))

    def _set_policy_requires_grad(self, flag):
        """Set ``requires_grad`` on every non-planner parameter of the policy model."""
        for name, param in self.policy_model.named_parameters():
            if not name.startswith('planner.'):
                param.requires_grad = flag

    def _set_planner_requires_grad(self, flag):
        """Set ``requires_grad`` on the planner parameters, if a planner exists."""
        if hasattr(self.policy_model, 'planner'):
            for param in self.policy_model.planner.parameters():
                param.requires_grad = flag

    def freeze_policy_model(self):
        """Freeze policy model parameters (but not planner)."""
        self._set_policy_requires_grad(False)

    def unfreeze_policy_model(self):
        """Unfreeze policy model parameters (but not planner)."""
        self._set_policy_requires_grad(True)

    def freeze_planner_model(self):
        """Freeze planner parameters."""
        self._set_planner_requires_grad(False)

    def unfreeze_planner_model(self):
        """Unfreeze planner parameters."""
        self._set_planner_requires_grad(True)

    def configure_optimizers(self):
        """Create a single AdamW optimizer with two parameter groups.

        The policy backbone uses ``args.learning_rate``; parameters whose name
        starts with ``'planner.'`` use ``args.planner_learning_rate`` (falling
        back to ``args.learning_rate`` when that attribute is absent).

        Returns:
            The configured ``torch.optim.AdamW`` instance.
        """
        # Separate parameter groups for policy backbone vs planner heads
        planner_lr = getattr(self.args, 'planner_learning_rate', self.args.learning_rate)
        planner_params = []
        policy_params = []
        for name, param in self.policy_model.named_parameters():
            if name.startswith('planner.'):
                planner_params.append(param)
            else:
                policy_params.append(param)

        param_groups = [
            {'params': policy_params, 'lr': self.args.learning_rate},
            {'params': planner_params, 'lr': planner_lr},
        ]
        return torch.optim.AdamW(param_groups)

    def _get_quality_mode(self):
        """Map ablation flags + warmup state to quality_mode string.

        Returns:
            One of ``"none"``, ``"unmasking_only"``, ``"insertion_only"`` or
            ``"both"``. ``"none"`` is returned while the planner is disabled
            or the current epoch is still inside ``schedule_warmup_epochs``.
        """
        if self.args.disable_planner:
            return "none"
        if self.current_epoch < self.args.schedule_warmup_epochs:
            return "none"
        di = getattr(self.args, 'disable_insertion_planner', False)
        du = getattr(self.args, 'disable_unmasking_planner', False)
        if di and du:
            return "none"
        if di:
            return "unmasking_only"
        if du:
            return "insertion_only"
        return "both"

    def on_train_epoch_start(self):
        """Select which sub-model trains this epoch and refresh the buffer if due."""

        if self.args.disable_planner:
            # If disable_planner mode, only train policy (no alternation)
            self.train_policy = True
            self.unfreeze_policy_model()
            self.freeze_planner_model()
            if self.global_rank == 0 and self.current_epoch == 0:
                print(f"[FINETUNE_QUALITY] Training ONLY policy model (planner frozen, no remasking)")

        elif getattr(self.args, 'joint_training', False):
            # Joint mode: train policy + planner together every step (no alternation)
            self.train_policy = True  # marker; training_step adds planner loss when joint_training is set
            self.unfreeze_policy_model()
            self.unfreeze_planner_model()
            if self.global_rank == 0 and self.current_epoch == 0:
                print(f"[FINETUNE_QUALITY] JOINT TRAINING: policy + planner trained together (no alternation)")

        else:
            # Alternate between training policy and planner from epoch 0.
            # Even cycles train the policy, odd cycles train the planner.
            cycle_position = (self.current_epoch // self.alternation_frequency) % 2
            self.train_policy = (cycle_position == 0)

            if self.train_policy:
                # Train policy, freeze planner
                self.unfreeze_policy_model()
                self.freeze_planner_model()
                if self.global_rank == 0:
                    print(f"[ALTERNATION] Epoch {self.current_epoch}: Training POLICY model (planner frozen)")
            else:
                # Train planner, freeze policy
                self.freeze_policy_model()
                self.unfreeze_planner_model()
                if self.global_rank == 0:
                    print(f"[ALTERNATION] Epoch {self.current_epoch}: Training PLANNER model (policy frozen)")

        # Resample buffer if needed
        if self.x_saved is None or self.current_epoch % self.args.resample_every_n_step == 0:
            if self.global_rank == 0:
                print(f"[BUFFER] Starting buffer generation for epoch {self.current_epoch}")
            self._generate_buffer()
            # Synchronize all ranks after buffer generation
            if self.trainer and self.trainer.world_size > 1:
                if self.global_rank == 0:
                    print(f"[BUFFER] All ranks completed buffer generation, synchronizing...")
                torch.distributed.barrier()
                if self.global_rank == 0:
                    print(f"[BUFFER] Synchronization complete!")

    def _generate_buffer(self):
        """Generate buffer of sequences for training.

        When pool_size > 0, maintains a persistent pool and refreshes a fraction
        each time instead of regenerating the entire buffer from scratch.

        Raises:
            RuntimeError: If no valid sequence was produced within the
                maximum number of sampling attempts.
        """
        import time

        rank = self.global_rank if self.trainer else 0
        world_size = self.trainer.world_size if self.trainer else 1

        pool_size = getattr(self.args, 'pool_size', 0)
        is_pool = pool_size > 0
        is_init = self.x_saved is None

        # Determine how many molecules to sample this call
        if is_pool:
            refresh_frac = getattr(self.args, 'pool_refresh_fraction', 0.2)
            if is_init:
                samples_per_gpu = pool_size
            else:
                samples_per_gpu = max(1, int(pool_size * refresh_frac))
            if rank == 0:
                if is_init:
                    print(f"\n[POOL] Initializing pool with {pool_size} molecules at epoch {self.current_epoch}")
                else:
                    print(f"\n[POOL] Refreshing {samples_per_gpu}/{pool_size} molecules ({refresh_frac*100:.0f}%) at epoch {self.current_epoch}")
        else:
            samples_per_gpu = self.args.buffer_size // world_size
            if rank == 0:
                # rank 0 absorbs the remainder so the ranks sum to buffer_size
                samples_per_gpu += self.args.buffer_size % world_size

        if rank == 0:
            print(f"\n[BUFFER] Starting buffer generation at epoch {self.current_epoch}")

        accumulated_x = []
        accumulated_log_rnd = []
        accumulated_rewards = []
        total_accumulated = 0

        max_attempts = 100  # Prevent infinite loop
        attempts = 0

        # The quality mode depends only on args and the current epoch, so it
        # is constant for the whole sampling loop.
        quality_mode = self._get_quality_mode()

        while total_accumulated < samples_per_gpu and attempts < max_attempts:
            attempts += 1
            if rank == 0:
                print(f"[BUFFER] rank={rank} starting sampling attempt {attempts} at {time.strftime('%H:%M:%S')}")

            start_time = time.time()

            x_final, log_rnd, final_rewards, _trace = \
                sample_mol_buffer(
                    self.policy_model,
                    self.pretrained,
                    self.reward_model,
                    self.tokenizer,
                    steps=self.args.total_num_steps,
                    mask=self.policy_model.interpolant.mask_token,
                    pad=self.policy_model.interpolant.pad_token,
                    batch_size=self.args.batch_size,
                    max_length=self.args.max_length,
                    quality_mode=quality_mode,
                    alpha=self.args.alpha,
                    num_remasking=self.args.num_remasking,
                    quality_threshold=self.args.quality_threshold,
                    use_quality_filter=self.args.use_quality_filter,
                )
            if self.args.elbo_rnd:
                # Override trajectory log_rnd with forward ELBO estimate
                if x_final.shape[0] > 0:
                    with torch.no_grad():
                        noised = self.policy_model.prepare_noised_sample(
                            x_final, num_samples=self.args.elbo_rnd_num_samples)
                        policy_loss = self.policy_model.compute_loss_from_noised(noised)
                        pretrained_loss = self.pretrained.compute_loss_from_noised(noised)
                        log_rnd = (pretrained_loss - policy_loss) + (final_rewards / self.args.alpha)

            elapsed = time.time() - start_time
            if rank == 0:
                print(f"[BUFFER] rank={rank} sampling took {elapsed:.1f}s")

            n_valid = x_final.shape[0]
            if n_valid > 0:
                accumulated_x.append(x_final)
                accumulated_log_rnd.append(log_rnd)
                accumulated_rewards.append(final_rewards)
                total_accumulated += n_valid

            if rank == 0:
                print(f"[BUFFER] rank={rank} epoch={self.current_epoch} quality_mode={quality_mode} accumulated={total_accumulated} / {samples_per_gpu} (batch yielded {n_valid} valid) attempt={attempts}")

        if total_accumulated == 0:
            raise RuntimeError(f"[BUFFER ERROR] Rank {rank}: No valid sequences generated after {attempts} attempts. Check sampling function and reward model.")

        if total_accumulated < samples_per_gpu:
            print(f"[BUFFER WARNING] Rank {rank}: Only generated {total_accumulated}/{samples_per_gpu} sequences after {attempts} attempts")

        # The last batch may overshoot the target, so truncate to the quota.
        new_x = torch.cat(accumulated_x, dim=0)[:samples_per_gpu]
        new_log_rnd = torch.cat(accumulated_log_rnd, dim=0)[:samples_per_gpu]
        new_rewards = torch.cat(accumulated_rewards, dim=0)[:samples_per_gpu]

        del accumulated_x, accumulated_log_rnd, accumulated_rewards
        torch.cuda.empty_cache()

        # add to buffer: pool mode replaces a random subset, classic mode overwrites
        if is_pool and not is_init:
            actual_new = min(new_x.shape[0], self.x_saved.shape[0])
            indices = torch.randperm(self.x_saved.shape[0], device=self.x_saved.device)[:actual_new]
            self.x_saved[indices] = new_x[:actual_new]
            self.log_rnd_saved[indices] = new_log_rnd[:actual_new]
            self.final_rewards_saved[indices] = new_rewards[:actual_new]
            if rank == 0:
                print(f"[POOL] Replaced {actual_new}/{self.x_saved.shape[0]} molecules, reward mean={self.final_rewards_saved.mean():.4f}")
        else:
            self.x_saved = new_x
            self.log_rnd_saved = new_log_rnd
            self.final_rewards_saved = new_rewards

        if rank == 0:
            print(f"[BUFFER] After cleanup - GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

    def training_step(self, batch, batch_idx):
        """Training step - batch is ignored, we use saved buffer.

        Draws a random mini-batch from the buffer and computes the policy
        loss, the planner loss, or their sum (joint mode).

        Args:
            batch: Dummy batch from the dataloader; unused.
            batch_idx: Index of the dummy batch; unused.

        Returns:
            The scalar loss to backpropagate for this step.
        """
        # Process buffer in mini-batches to avoid OOM
        mini_batch_size = getattr(self.args, 'training_mini_batch_size', 8)
        buffer_size = self.x_saved.shape[0]

        # Randomly sample a mini-batch from buffer
        indices = torch.randperm(buffer_size, device=self.x_saved.device)[:mini_batch_size]
        x_final = self.x_saved[indices]
        log_rnd = self.log_rnd_saved[indices]

        sm_temp = getattr(self.args, 'softmax_temperature', 1.0)
        joint = getattr(self.args, 'joint_training', False)

        # Keyword arguments shared by every loss call below.
        loss_kwargs = dict(
            num_replicates=self.args.wdce_num_replicates,
            centering=self.args.centering,
            centering_strength=self.args.centering_strength,
            softmax_temperature=sm_temp,
        )
        # Logging options shared by every per-step loss metric.
        log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

        policy_loss = None
        planner_loss = None

        if self.train_policy:
            # Train policy with WDCE loss
            policy_loss = self.policy_model.loss_wdce_flexible(log_rnd, x_final, **loss_kwargs)
            self.log('train/policy_loss', policy_loss, **log_kwargs)

        if (not self.train_policy) or joint:
            # Train planner with appropriate loss based on ablation flags
            if self.args.disable_insertion_planner:
                # Ablation: only train unmasking planner (no insertion head)
                planner_loss = self.policy_model.loss_planner_flexible(log_rnd, x_final, **loss_kwargs)
                unmask_loss = planner_loss
                insert_loss = 0.0
            elif self.args.disable_unmasking_planner:
                # only train insertion planner (no remasking head);
                # the unmasking component is dropped so only insertion loss is backpropagated
                _, insert_loss, _ = self.policy_model.loss_insert_planner_flexible(
                    log_rnd, x_final, **loss_kwargs)
                unmask_loss = 0.0
                planner_loss = insert_loss
            else:
                # Full planner: train both remasking + insertion
                unmask_loss, insert_loss, planner_loss = self.policy_model.loss_insert_planner_flexible(
                    log_rnd, x_final, **loss_kwargs)

            self.log('train/planner_unmask_loss', unmask_loss, **log_kwargs)
            self.log('train/planner_insert_loss', insert_loss, **log_kwargs)
            self.log('train/planner_loss', planner_loss, **log_kwargs)

        # Combine losses depending on mode
        if joint:
            loss = policy_loss + planner_loss
            mode_value = 0.5
        elif self.train_policy:
            loss = policy_loss
            mode_value = 0.0
        else:
            loss = planner_loss
            mode_value = 1.0

        # Log overall loss and mode (0.0 = policy, 1.0 = planner, 0.5 = joint)
        self.log('train_loss', loss, **log_kwargs)
        self.log('train/mode', mode_value, prog_bar=True, sync_dist=True)

        return loss

    def on_train_epoch_end(self):
        """Called at the end of each training epoch - only rank 0 evaluates."""
        # Only evaluate every N epochs to save time
        eval_frequency = getattr(self.args, 'eval_every_n_epochs', 5)
        is_last_epoch = (self.trainer and self.current_epoch == self.trainer.max_epochs - 1)
        if self.global_rank == 0 and (self.current_epoch % eval_frequency == 0 or is_last_epoch):
            # Sample eval batch with updated policy.
            # NOTE(review): the unpack order below must match the return order
            # of sample_mol_eval, which is not visible from this file — confirm,
            # in particular the position of valid_fraction.
            x_eval, qed, sa, uniqueness, diversity, quality, valid_fraction = \
                sample_mol_eval(
                    self.policy_model, self.reward_model,
                    self.tokenizer,
                    steps=self.args.total_num_steps,
                    mask=self.policy_model.interpolant.mask_token,
                    pad=self.policy_model.interpolant.pad_token,
                    batch_size=50,
                    max_length=self.args.max_length,
                    quality_mode=self._get_quality_mode(),
                    num_remasking=self.args.num_remasking,
                    evaluator=self.evaluator,
                )

            # Append to logs
            self.valid_fraction_log.append(valid_fraction)
            self.uniqueness_log.append(uniqueness)
            self.diversity_log.append(diversity)
            self.qed_log.append(qed)
            self.sa_log.append(sa)
            self.quality_log.append(quality)

            # Compute reward stats over the current training buffer
            mean_reward = self.final_rewards_saved.mean().item()
            min_reward = self.final_rewards_saved.min().item()
            max_reward = self.final_rewards_saved.max().item()
            median_reward = self.final_rewards_saved.median().item()

            # Log metrics. This branch runs on rank 0 only, so Lightning is
            # told explicitly via rank_zero_only that other ranks never log
            # these keys.
            self.log_dict({
                "eval/valid_fraction": valid_fraction,
                "eval/uniqueness": np.mean(uniqueness),
                "eval/diversity": np.mean(diversity),
                "eval/qed": np.mean(qed),
                "eval/sa": np.mean(sa),
                "eval/quality": np.mean(quality),
                "eval/mean_reward_search": mean_reward,
                "eval/min_reward_search": min_reward,
                "eval/max_reward_search": max_reward,
                "eval/median_reward_search": median_reward
            }, rank_zero_only=True)

            print(f"epoch {self.current_epoch} | validity {valid_fraction:.4f} | uniqueness {np.mean(uniqueness):.4f} | diversity {np.mean(diversity):.4f} | "
                  f"QED {np.mean(qed):.4f} | SA {np.mean(sa):.4f} | quality {np.mean(quality):.4f} | ")

    def on_fit_end(self):
        """Called at the end of training - save results.

        On rank 0, writes the per-evaluation metric logs to
        ``{base_path}/results/{run_name}/log_{filename}.csv`` and the final
        generation dataframe to ``mol_generation_results.csv`` in the same
        directory.
        """
        if self.global_rank == 0:
            # Save logs and plot
            base_path = self.args.base_path
            plot_path = f'{base_path}/results/{self.args.run_name}'
            os.makedirs(plot_path, exist_ok=True)

            output_log_path = f'{plot_path}/log_{self.filename}.csv'
            save_logs_to_file(self.valid_fraction_log, self.uniqueness_log,
                              self.diversity_log, self.qed_log, self.sa_log,
                              self.quality_log, output_log_path)

            # Final generation. With dataframe=True the dataframe is the last
            # returned element; the preceding values are not needed here.
            *_, df = \
                sample_mol_eval(
                    self.policy_model, self.reward_model,
                    self.tokenizer,
                    steps=self.args.total_num_steps,
                    mask=self.policy_model.interpolant.mask_token,
                    pad=self.policy_model.interpolant.pad_token,
                    batch_size=50,
                    max_length=self.args.max_length,
                    quality_mode=self._get_quality_mode(),
                    num_remasking=self.args.num_remasking,
                    evaluator=self.evaluator,
                    dataframe=True,
                )
            df.to_csv(f'{plot_path}/mol_generation_results.csv', index=False)
486
+
487
+
488
def save_logs_to_file(valid_fraction_log, uniqueness_log,
                      diversity_log, qed_log, sa_log,
                      quality_log, output_path):
    """
    Saves the logs to a CSV file.

    One row is written per logged evaluation, numbered from 1 in the
    "Iteration" column. All log sequences must have the same length as
    ``valid_fraction_log``.

    Args:
        valid_fraction_log: Sequence written to the "Valid Fraction" column.
        uniqueness_log: Sequence written to the "Uniqueness" column.
        diversity_log: Sequence written to the "Diversity" column.
        qed_log: Sequence written to the "QED" column.
        sa_log: Sequence written to the "Synthetic Accessibility" column.
        quality_log: Sequence written to the "Quality" column.
        output_path: Destination CSV path; missing parent directories are
            created.
    """
    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    log_data = {
        "Iteration": list(range(1, len(valid_fraction_log) + 1)),
        "Valid Fraction": valid_fraction_log,
        "Uniqueness": uniqueness_log,
        "Diversity": diversity_log,
        "QED": qed_log,
        "Synthetic Accessibility": sa_log,
        "Quality": quality_log
    }

    df = pd.DataFrame(log_data)
    df.to_csv(output_path, index=False)
508
+
509
+
510
class DummyDataset(torch.utils.data.Dataset):
    """Placeholder dataset that only fixes the number of steps per epoch.

    The Lightning trainer needs a dataloader, but the training step draws its
    data from an internal buffer, so every item here is a constant tensor.
    """

    def __init__(self, size=100):
        # Number of items reported by __len__.
        self.size = size

    def __len__(self):
        """Return the configured number of dummy items."""
        return self.size

    def __getitem__(self, idx):
        """Return a single-element zero tensor regardless of ``idx``."""
        placeholder = torch.zeros(1)
        return placeholder
520
+
521
+
522
def main():
    """Main entry point for distributed training.

    Parses the command line, loads the frozen pretrained reference model and
    the trainable policy model from the same checkpoint, wraps them in a
    ``MolFinetuner`` and runs ``pl.Trainer.fit`` with a dummy dataloader
    whose length sets the number of gradient updates per epoch.

    Raises:
        ValueError: If ``--joint_training`` is combined with
            ``--disable_planner``, or if the checkpoint contains no config.
    """
    # Disable DDP optimizer for higher-order ops like flex_attention
    import torch._dynamo
    torch._dynamo.config.optimize_ddp = False

    argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    argparser.add_argument('--base_path', type=str, default=REPO_ROOT)
    argparser.add_argument('--learning_rate', type=float, default=1e-4)
    argparser.add_argument('--num_epochs', type=int, default=100)
    argparser.add_argument('--num_accum_steps', type=int, default=4)
    argparser.add_argument('--truncate_steps', type=int, default=50)
    argparser.add_argument("--truncate_kl", type=str2bool, default=False)
    argparser.add_argument('--gumbel_temp', type=float, default=1.0)
    argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
    argparser.add_argument('--batch_size', type=int, default=50)
    argparser.add_argument('--name', type=str, default='debug')
    argparser.add_argument('--total_num_steps', type=int, default=128)
    argparser.add_argument('--copy_flag_temp', type=float, default=None)
    argparser.add_argument('--save_every_n_epochs', type=int, default=10)
    argparser.add_argument('--eval_every_n_epochs', type=int, default=5, help='Evaluate only every N epochs to save time')
    argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
    argparser.add_argument("--seed", type=int, default=0)
    # new
    # NOTE(review): --run_name is accepted but always overwritten by the
    # auto-generated name below — confirm whether a user-supplied value
    # should be honoured.
    argparser.add_argument('--run_name', type=str, default='mol')
    argparser.add_argument("--save_path_dir", default="", type=str)
    # mcts
    argparser.add_argument('--num_sequences', type=int, default=10)
    argparser.add_argument('--max_length', type=int, default=1024)
    argparser.add_argument('--num_children', type=int, default=50)
    argparser.add_argument('--num_iter', type=int, default=30)  # iterations of mcts
    argparser.add_argument('--seq_length', type=int, default=1024)
    argparser.add_argument('--time_conditioning', action='store_true', default=False)
    argparser.add_argument('--mcts_sampling', type=int, default=0)  # for batched categorical sampling: '0' means gumbel noise
    argparser.add_argument('--buffer_size', type=int, default=100)
    argparser.add_argument('--wdce_num_replicates', type=int, default=16)
    argparser.add_argument('--noise_removal', action='store_true', default=False)
    argparser.add_argument('--grad_clip', action='store_true', default=False)
    argparser.add_argument('--resample_every_n_step', type=int, default=3)
    argparser.add_argument('--exploration', type=float, default=0.1)
    argparser.add_argument('--reset_every_n_step', type=int, default=100)
    argparser.add_argument('--alpha', type=float, default=0.01)
    argparser.add_argument('--scalarization', type=str, default='sum')
    argparser.add_argument('--no_mcts', action='store_true', default=False)
    argparser.add_argument("--centering", action='store_true', default=False)
    argparser.add_argument("--centering_strength", type=float, default=1.0)

    # adaptive schedule parameters
    # The original `store_true` with default=True could never be switched
    # off. The bare flag still means True; an explicit value such as
    # `--use_adaptive_schedule false` now disables it.
    argparser.add_argument('--use_adaptive_schedule', type=str2bool, nargs='?', const=True, default=True,
                           help='Enable the adaptive schedule (pass "false" to disable)')
    argparser.add_argument('--schedule_hidden_dim', type=int, default=256)
    argparser.add_argument('--schedule_num_layers', type=int, default=2)
    argparser.add_argument('--schedule_loss_weight', type=float, default=0.1)
    argparser.add_argument('--adaptive_threshold', type=float, default=0.5)
    argparser.add_argument('--freeze_base_model', action='store_true', default=False)
    argparser.add_argument('--schedule_warmup_epochs', type=int, default=20, help='Number of initial epochs to train WITHOUT remasking in buffer generation')
    argparser.add_argument('--alternation_frequency', type=int, default=5, help='Number of epochs to train each model before alternating (1=alternate every epoch)')
    argparser.add_argument('--planner_learning_rate', type=float, default=None, help='Separate learning rate for planner heads (defaults to --learning_rate if not set)')

    # objectives
    argparser.add_argument('--num_obj', type=int, default=2)
    argparser.add_argument('--devices', type=int, default=-1)
    argparser.add_argument('--checkpoint_path', type=str, default=None)

    # ELBO-based log_rnd estimation
    argparser.add_argument('--elbo_rnd', action='store_true', default=False,
                           help='If set, compute log_rnd via forward ELBO instead of trajectory rollout')
    argparser.add_argument('--elbo_rnd_num_samples', type=int, default=4,
                           help='Number of noisy time samples per sequence for ELBO-based log_rnd estimation')

    # remasking
    argparser.add_argument('--num_remasking', type=int, default=5)
    argparser.add_argument('--quality_threshold', type=float, default=1)
    argparser.add_argument('--use_quality_filter', action='store_true', help='If set, filter buffer to only include molecules with QED>=0.6 and SA<=4')
    argparser.add_argument('--training_mini_batch_size', type=int, default=8, help='Mini-batch size for training step to avoid OOM')
    argparser.add_argument('--disable_planner', action='store_true', help='If set, disable remasking completely and only train policy (not planner) for quality optimization')
    argparser.add_argument('--disable_insertion_planner', action='store_true', help='Ablation: disable insertion quality filtering but keep unmasking/remasking planner')
    argparser.add_argument('--disable_unmasking_planner', action='store_true', help='Ablation: disable unmasking/remasking planner but keep insertion quality filtering')
    argparser.add_argument('--joint_training', action='store_true', help='Ablation: train policy and planner jointly each step (no alternation, both unfrozen, summed loss). Incompatible with --disable_planner.')
    argparser.add_argument('--qed_only', action='store_true', help='If set, optimize only for QED score (no SA)')
    argparser.add_argument('--softmax_temperature', type=float, default=1.0,
                           help='Temperature for softmax on importance weights (>1 smooths, prevents concentration)')
    argparser.add_argument('--pool_size', type=int, default=0,
                           help='If >0, maintain a persistent pool of this size and refresh a fraction each resample step (0=disabled, classic buffer)')
    argparser.add_argument('--pool_refresh_fraction', type=float, default=0.2,
                           help='Fraction of pool to replace each resample step (only used when pool_size>0)')
    argparser.add_argument('--num_training_steps_per_epoch', type=int, default=10,
                           help='Number of gradient updates per epoch (1=original, 10=recommended)')

    args = argparser.parse_args()

    # Fail fast on an invalid flag combination, before seeding or touching disk.
    if args.joint_training and args.disable_planner:
        raise ValueError("--joint_training is incompatible with --disable_planner (no planner to train)")

    # Default planner LR to policy LR if not specified
    if args.planner_learning_rate is None:
        args.planner_learning_rate = args.learning_rate

    # Set seed
    pl.seed_everything(args.seed)

    # Load models
    checkpoint_path = args.checkpoint_path if args.checkpoint_path else \
        os.path.join(REPO_ROOT, 'pretrained', 'anylength_mol.ckpt')

    curr_time = datetime.now().strftime("%Y%m%d_%H%M%S")

    if args.no_mcts:
        args.run_name = f'mol_al_resample{args.resample_every_n_step}_no-mcts_{curr_time}'
    else:
        args.run_name = f'mol_al_resample{args.resample_every_n_step}_buffer{args.buffer_size}_numiter{args.num_iter}_children{args.num_children}_{curr_time}'

    # append ablation tags to run name for easy identification
    if args.disable_planner:
        args.run_name += '_no_planner'
    if args.disable_insertion_planner:
        args.run_name += '_no_insertion_planner'
    if args.disable_unmasking_planner:
        args.run_name += '_no_unmasking_planner'
    if args.joint_training:
        args.run_name += '_joint_training'

    args.save_path = os.path.join(args.save_path_dir, args.run_name)
    os.makedirs(args.save_path, exist_ok=True)
    set_seed(args.seed, use_cuda=False)  # Don't init CUDA before Lightning spawns DDP workers

    # Initialize the model
    print("Loading models..")

    # Load pretrained model for reference (frozen)
    pretrained = AnyOrderInsertionFlowModule.load_from_checkpoint(checkpoint_path,
                                                                  map_location='cpu',
                                                                  weights_only=False)
    pretrained.eval()
    for param in pretrained.parameters():
        param.requires_grad = False

    # Load checkpoint to extract config
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    if 'hyper_parameters' in checkpoint:
        config = checkpoint['hyper_parameters']['config']
    elif 'config' in checkpoint:
        config = checkpoint['config']
    else:
        raise ValueError("Cannot find config in checkpoint")
    # Only the config is needed from here on; drop the reference to the full
    # checkpoint dict so it is not kept alive for the whole training run.
    del checkpoint

    # Update config for adaptive schedule
    from omegaconf import OmegaConf
    if not OmegaConf.is_config(config):
        from omegaconf import DictConfig
        config = DictConfig(config)

    # Temporarily lift struct mode so new keys can be added to the config.
    OmegaConf.set_struct(config, False)

    config.training.use_adaptive_schedule = args.use_adaptive_schedule
    config.training.schedule_hidden_dim = args.schedule_hidden_dim
    config.training.schedule_num_layers = args.schedule_num_layers
    config.training.schedule_loss_weight = args.schedule_loss_weight
    config.training.freeze_base_model = args.freeze_base_model
    config.training.schedule_warmup_epochs = args.schedule_warmup_epochs
    config.training.use_bracket_safe = True

    OmegaConf.set_struct(config, True)

    # initialize policy model with adaptive schedule
    policy_model = AnyOrderInsertionFlowModuleFT(
        config=config,
        args=args,
        pretrained_checkpoint=checkpoint_path,
        insertion_planner=True,
    )

    # choose the reward objectives
    if args.qed_only:
        score_func_names = ['qed']
    else:
        score_func_names = ['qed', 'sa']

    tokenizer = get_tokenizer()

    filename = args.run_name

    # Device will be set by Lightning automatically in DDP
    reward_model = MolScoringFunctions(score_func_names, device='cpu')
    model = MolFinetuner(
        args=args,
        policy_model=policy_model,
        reward_model=reward_model,
        tokenizer=tokenizer,
        pretrained=pretrained,
        mcts=None,
        filename=filename,
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath=args.save_path,
        filename='model-{epoch:02d}-{train_loss:.4f}',
        every_n_epochs=args.save_every_n_epochs,
        save_top_k=-1,  # Save all checkpoints
        save_last=True,  # Also save last.ckpt
        auto_insert_metric_name=False
    )

    # Defaults to your default wandb entity; override with the WANDB_ENTITY env var.
    wandb_logger = WandbLogger(entity=os.environ.get('WANDB_ENTITY'), project='a2d2-mol', name=args.run_name)

    # create dummy dataloader; its length is the number of training steps per epoch
    dataset = DummyDataset(size=args.num_training_steps_per_epoch)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

    # setup trainer with DDP
    trainer = pl.Trainer(
        max_epochs=args.num_epochs,
        accelerator='gpu',
        devices=args.devices,
        strategy=DDPStrategy(find_unused_parameters=True) if args.devices != 1 else 'auto',
        gradient_clip_val=args.gradnorm_clip if args.grad_clip else None,
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
        enable_progress_bar=True,
        log_every_n_steps=1
    )

    # Train
    trainer.fit(model, dataloader)
745
+
746
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
a2d2_mol/inference_quality_mol.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unified molecule sampling with quality-guided planning.
2
+
3
+ Supports 4 quality modes and optional RND (importance weight) computation.
4
+
5
+ Quality modes:
6
+ "none" - No planner, no remasking (policy-only)
7
+ "both" - Both unmasking + insertion planners active
8
+ "unmasking_only" - Only unmasking/remasking planner (insertion planner disabled)
9
+ "insertion_only" - Only insertion planner (unmasking planner disabled)
10
+
11
+ RND toggle:
12
+ compute_rnd=True - Run pretrained model in parallel, compute step-wise log importance weights
13
+ compute_rnd=False - Run policy model only (use with ELBO-based RND or eval)
14
+ """
15
+
16
+ import torch
17
+ import numpy as np
18
+ import pandas as pd
19
+ import torch.nn.functional as F
20
+ from sampling import SamplingResult, SamplingTraceDatapoint, _sample_tokens
21
+ from remasking_scheduleaware import apply_schedule_aware_remasking, apply_schedule_aware_insertion
22
+ from mol_utils.utils_chem import batch_safe_to_smiles, batch_validate_and_extract
23
+ from tdc import Evaluator, Oracle
24
+
25
+ QUALITY_MODES = {"none", "both", "unmasking_only", "insertion_only"}
26
+
27
+
28
@torch.no_grad()
def _diffusion_loop(
    model, steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    compute_rnd=False,
    pretrained=None,
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    temperature=1.0,
    return_trace=False,
    unmask_quality_threshold=None,
):
    """Core discrete diffusion sampling loop for molecule generation.

    Starts from an all-pad batch and alternates, for ``steps`` Euler steps,
    an unmasking step, an optional remasking step and an insertion step
    (no remasking or insertion on the final step).

    Args:
        model: Finetuned policy model.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: One of "none", "both", "unmasking_only", "insertion_only".
        compute_rnd: Whether to compute step-wise log importance weights.
        pretrained: Frozen pretrained model (required if compute_rnd=True).
        remasking_mode: Remasking strategy. Only "schedule_aware" is
            implemented in this loop; any other value raises ValueError at
            the first remasking step (i.e. when quality_mode != "none" and
            steps > 1).
        num_remasking: Number of tokens to remask per step. Only consulted by
            the score-based top-k path, which "schedule_aware" never enters.
        quality_threshold: Threshold for insertion quality filtering. None if schedule-driven.
        temperature: Sampling temperature (1.0 = no scaling). Applied to the
            policy transition probabilities only.
        return_trace: Whether to record sampling trace.
        unmask_quality_threshold: Forwarded unchanged to
            apply_schedule_aware_remasking.

    Returns:
        (xt, log_rnd, sampling_trace)
        log_rnd is the per-sequence sum over all steps of
        (log pretrained - log policy) for the unmasking and insertion events;
        it is None when compute_rnd=False. sampling_trace is None when
        return_trace=False.

    Raises:
        ValueError: If a remasking step runs with an unsupported remasking_mode.
    """
    assert quality_mode in QUALITY_MODES, f"quality_mode must be one of {QUALITY_MODES}"
    if compute_rnd:
        assert pretrained is not None, "pretrained model required when compute_rnd=True"

    # Derive flags from quality_mode
    use_remasking = quality_mode != "none"
    disable_unmasking_planner = quality_mode in ("none", "insertion_only")
    disable_insertion_planner = quality_mode in ("none", "unmasking_only")

    device = next(model.parameters()).device

    # Initialize all-pad sequence
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Precompute index tensors
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    neg_inf = torch.tensor(-np.inf, device=device)

    if use_remasking and remasking_mode == "remdm_conf":
        remasking_score = torch.zeros((batch_size, max_length), device=device)

    # Running sum of per-step log importance weights; stays None unless
    # compute_rnd is set.
    log_rnd = None

    for i in range(steps):
        is_last_step = i == steps - 1

        # --- Policy model forward ---
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # --- Pretrained model forward (for RND) ---
        if compute_rnd:
            pretrained_pred = pretrained(xt, t)
            pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
            pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
            pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

        # --- Unmask step (Euler) ---
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        # The mask entry carries minus the total outgoing rate of each masked position.
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        if compute_rnd:
            pretrained_unmask_rate[xt != mask] = 0
            pretrained_unmask_rate[mask_pos + (mask,)] = 0
            pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
            pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

        # Add "stay" probability (pad positions are pointed at the mask index
        # here and restored to pad after sampling).
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )
        if compute_rnd:
            pretrained_trans_prob.scatter_add_(
                2,
                _xt.unsqueeze(-1),
                torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
            )

        # Temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        # Final step: remove mask token from sampling
        if is_last_step:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0

            prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
            mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
            if mask_has_zero_prob.any():
                # Masked positions left with no probability mass fall back to
                # a uniform distribution over token IDs below the mask ID.
                num_zero_prob = mask_has_zero_prob.sum().item()
                uniform_prob = torch.zeros((num_zero_prob, trans_prob.shape[-1]), device=device, dtype=trans_prob.dtype)
                uniform_prob[:, :mask] = 1.0 / mask
                trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
            else:
                trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        # Already-decoded tokens are never resampled.
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # Update remasking_score buffer for remdm_conf mode
        if use_remasking and remasking_mode == "remdm_conf" and not is_last_step:
            token_probs = F.softmax(unmask_rate, dim=-1)  # (B, L, V)
            chosen_probs = torch.gather(token_probs, dim=-1, index=new_xt.unsqueeze(-1)).squeeze(-1)  # (B, L)
            changed_mask_to_token = (xt == mask) & (new_xt != mask) & (new_xt != pad)
            remasking_score = torch.where(changed_mask_to_token, chosen_probs, remasking_score)

        # --- Remasking step ---
        if use_remasking and not is_last_step:
            if disable_unmasking_planner or not (hasattr(model, 'planner') and model.planner is not None):
                remasking_conf = torch.zeros((batch_size, max_length), device=device)
            else:
                planner_out = model.planner(new_xt, t)
                remasking_conf = planner_out["remasking_conf"].squeeze(-1)  # (B, L)

            clean_index = (new_xt != mask) & (new_xt != pad)  # (B, L)

            if remasking_mode == "schedule_aware":
                new_xt = apply_schedule_aware_remasking(
                    model, new_xt, t, dt, remasking_conf, clean_index,
                    mask, neg_inf, batch_size,
                    unmask_quality_threshold=unmask_quality_threshold,
                )
                remasking_score_temp = None
            else:
                raise ValueError(f"Unknown remasking_mode: {remasking_mode}")

            # Score-based top-k remasking; not reached by "schedule_aware",
            # which always leaves remasking_score_temp as None.
            if remasking_score_temp is not None:
                remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf)
                for j in range(batch_size):
                    k = min(num_remasking, int(clean_index[j].sum().item()))
                    if k > 0:
                        _, select_indices = torch.topk(remasking_score_temp[j], k=k)
                        new_xt[j, select_indices] = mask

            if return_trace:
                for batch_idx in range(batch_size):
                    for pos in range(max_length):
                        if clean_index[batch_idx, pos] and new_xt[batch_idx, pos] == mask:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=t[batch_idx].item(),
                                    event_type="change",
                                    position=pos,
                                    token=mask,
                                )
                            )

        # --- Compute log probabilities for RND ---
        if compute_rnd:
            lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
            lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

            # Only positions that went mask -> token in this step contribute.
            changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

            log_policy_step = (lp * changed_mask).sum(dim=1)
            log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

            step_log_rnd = log_pretrained_step - log_policy_step  # (B,)
            # Accumulate over steps. Plain assignment here would discard every
            # earlier step and return the last step's ratio only.
            log_rnd = step_log_rnd if log_rnd is None else log_rnd + step_log_rnd

        # --- Insertion step ---
        if not is_last_step:
            ext = torch.poisson(len_rate * dt).long()  # (B, L+1)

            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # Only gaps inside the current sequence may receive insertions.
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            total_ext = ext.sum(dim=1)
            # Rows whose insertions would overflow max_length get none at all.
            valid = xt_len + total_ext <= max_length
            ext = ext * valid.view(batch_size, 1).long()

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            # NOTE(review): total_ext was computed before the `valid` zeroing
            # above, so for rejected rows new_len still includes the dropped
            # insertions and the row is mask-filled up to max_length. Confirm
            # whether that is intended before changing it.
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # Shift every existing token right by the insertions before it.
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]

            # Schedule-aware insertion quality filtering
            if use_remasking and not disable_insertion_planner:
                xt_tmp = apply_schedule_aware_insertion(
                    model, xt_tmp, new_xt, t, dt, ext, mask, pad, max_length,
                    orig_mask, new_pos_orig, quality_threshold
                )

                if compute_rnd:
                    # Compute corrected ext based on what actually stayed:
                    # each row's counts are scaled by surviving / proposed.
                    ext_corrected = torch.zeros_like(ext)
                    for b in range(batch_size):
                        after_len = xt_tmp[b].ne(pad).sum().item()
                        orig_len = xt_len[b].item()
                        surviving_insertions = after_len - orig_len
                        if total_ext[b] > 0:
                            ratio = surviving_insertions / total_ext[b].item()
                            ext_corrected[b] = (ext[b].float() * ratio).long()
                else:
                    ext_corrected = ext
            else:
                ext_corrected = ext

            # Compute insertion log_rnd
            if compute_rnd:
                insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
                pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

                # Poisson log-likelihoods up to the log(k!) term, which is
                # identical for both models and cancels in the difference.
                log_policy_insert = (ext_corrected * torch.log(insertion_rate) - insertion_rate).sum(dim=1)
                log_pretrained_insert = (ext_corrected * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)

                log_insert_diff = log_pretrained_insert - log_policy_insert
                log_rnd = log_rnd + log_insert_diff
        else:
            xt_tmp = new_xt

        if return_trace:
            for batch_idx in range(batch_size):
                for j in range(max_length):
                    if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=t[batch_idx].item(),
                                event_type="change",
                                position=j,
                                token=new_xt[batch_idx, j].item(),
                            )
                        )

                if not is_last_step:
                    # Insertions are recorded from the rightmost gap to the leftmost.
                    for j in range(max_length):
                        gap = max_length - j - 1
                        if ext[batch_idx, gap]:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=t[batch_idx].item(),
                                    event_type="insertion",
                                    position=gap,
                                    token=mask,
                                )
                            )

        xt = xt_tmp
        t = t + dt

    return xt, log_rnd, sampling_trace
318
+
319
+
320
def _decode_and_validate(model, tokenizer, samples):
    """Turn sampled token IDs into SMILES strings and keep the non-empty ones.

    The tokens are decoded with ``tokenizer.batch_decode`` (special tokens
    skipped) and passed through ``batch_safe_to_smiles``; whether the
    bracket-SAFE variant is used is read from
    ``model.config.training['use_bracket_safe']`` (default False). For every
    non-empty result only the longest '.'-separated fragment is kept.

    Returns:
        (validSequences, valid_indices): the kept SMILES strings and, in the
        same order, the batch index each one came from.
    """
    decoded = tokenizer.batch_decode(samples, skip_special_tokens=True)
    bracket_safe = model.config.training.get('use_bracket_safe', False)
    smiles_batch = batch_safe_to_smiles(decoded, use_bracket_safe=bracket_safe, fix=True)

    # Pair each surviving molecule with its batch index; falsy entries are
    # treated as invalid and dropped.
    kept = [
        (batch_index, sorted(smiles.split('.'), key=len)[-1])
        for batch_index, smiles in enumerate(smiles_batch)
        if smiles
    ]
    validSequences = [fragment for _, fragment in kept]
    valid_indices = [batch_index for batch_index, _ in kept]
    return validSequences, valid_indices
341
+
342
+
343
@torch.no_grad()
def sample_mol_buffer(
    model, pretrained, reward_model, tokenizer,
    steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    alpha=0.1,
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    temperature=1.0,
    use_quality_filter=True,
):
    """Sample molecules for the training buffer, with step-wise RND enabled.

    Args:
        model: Finetuned policy model.
        pretrained: Frozen pretrained model.
        reward_model: Molecule scoring function; called as
            ``reward_model(input_seqs=...)`` and expected to return one score
            vector per molecule.
        tokenizer: SAFE tokenizer for decoding.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: "none", "both", "unmasking_only", or "insertion_only".
        alpha: RND scaling factor; the summed reward is divided by it before
            being added to the sampled log-RND.
        remasking_mode: Remasking strategy.
        num_remasking: Number of tokens to remask per step.
        quality_threshold: Threshold for insertion quality filtering. None if schedule-driven.
        temperature: Sampling temperature.
        use_quality_filter: If True, filter to QED>=0.6 and SA<=4.

    Returns:
        (valid_x, log_rnd, scalar_rewards, sampling_trace). The three tensors
        have zero rows when no molecule survives validation/filtering.
    """
    tokens, step_log_rnd, trace = _diffusion_loop(
        model, steps, mask, pad, batch_size, max_length,
        quality_mode=quality_mode,
        compute_rnd=True,
        pretrained=pretrained,
        remasking_mode=remasking_mode,
        num_remasking=num_remasking,
        quality_threshold=quality_threshold,
        temperature=temperature,
    )
    device = tokens.device

    def _empty_result():
        # Zero-row placeholders in the same layout as a normal return.
        return (
            torch.empty((0, max_length), dtype=torch.long, device=device),
            torch.empty((0,), dtype=torch.float32, device=device),
            torch.empty((0,), dtype=torch.float32, device=device),
            trace,
        )

    validSequences, valid_indices = _decode_and_validate(model, tokenizer, tokens)
    print("len valid sequences:", len(validSequences))

    if len(validSequences) == 0:
        print("[WARNING] No valid molecules generated in this batch")
        return _empty_result()

    # Multi-objective rewards, collapsed to one scalar per molecule.
    score_vectors = reward_model(input_seqs=validSequences)
    scalar_rewards = torch.as_tensor(
        np.sum(score_vectors, axis=-1), dtype=torch.float32, device=device
    )
    print(f"scalar reward dim{len(scalar_rewards)}")

    sampled_log_rnd = torch.stack([step_log_rnd[idx] for idx in valid_indices], dim=0)
    log_rnd = sampled_log_rnd + (scalar_rewards / alpha)
    valid_x_final = torch.stack([tokens[idx] for idx in valid_indices], dim=0)

    if not use_quality_filter:
        print(f"No quality filtering applied - using all {len(validSequences)} valid molecules")
        return valid_x_final, log_rnd, scalar_rewards, trace

    # Keep only quality sequences (QED >= 0.6 and SA <= 4).
    qed_scores = score_vectors[:, 0]
    if score_vectors.shape[1] > 1:
        # NOTE(review): this takes SA from the reward model, whereas
        # sample_mol_eval always thresholds the raw TDC 'sa' oracle value.
        # Confirm the reward model's SA column is on the same scale.
        sa_scores = score_vectors[:, 1]
    else:
        sa_oracle = Oracle('sa')
        sa_scores = np.array(sa_oracle(validSequences))
    quality_mask = (qed_scores >= 0.6) & (sa_scores <= 4)

    n_quality = quality_mask.sum()
    print(f"Quality filtering: {n_quality}/{len(validSequences)} sequences pass (QED>=0.6, SA<=4)")

    if n_quality == 0:
        print("[WARNING] No quality molecules in this batch")
        return _empty_result()

    keep = torch.as_tensor(quality_mask, dtype=torch.bool, device=device)
    return valid_x_final[keep], log_rnd[keep], scalar_rewards[keep], trace
450
+
451
+
452
@torch.no_grad()
def sample_mol_eval(
    model, reward_model, tokenizer,
    steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    temperature=1.0,
    evaluator=None,
    dataframe=False,
    unmask_quality_threshold=None,
):
    """Generate molecules for evaluation.

    Args:
        model: Finetuned policy model.
        reward_model: Molecule scoring function.
        tokenizer: SAFE tokenizer for decoding.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: "none", "both", "unmasking_only", or "insertion_only".
        remasking_mode: Remasking strategy.
        num_remasking: Number of tokens to remask per step.
        quality_threshold: Threshold for insertion quality filtering. Pass None
            to use schedule-driven deletion with no threshold gate
        temperature: Sampling temperature.
        evaluator: TDC Evaluator for diversity (created if None).
        dataframe: If True, include a pandas DataFrame in the return.
        unmask_quality_threshold: Forwarded unchanged to the sampling loop.

    Returns:
        Without dataframe:
            (validSequences, qed, sa, uniqueness, diversity, quality, valid_fraction)
        With dataframe:
            (validSequences, qed, sa, valid_fraction, uniqueness, diversity, quality, df)
        Note that the two tuples order valid_fraction differently.
        validSequences is the raw list including duplicates; qed/sa are scored
        on the unique set. Caller can dedup with set(validSequences). The
        dataframe (when requested) has one row per unique molecule, and zero
        rows when nothing valid was generated. In that case qed and sa are
        the placeholder list [0.0] and quality is 0.0.
    """
    if evaluator is None:
        evaluator = Evaluator('diversity')

    xt, _, _ = _diffusion_loop(
        model, steps, mask, pad, batch_size, max_length,
        quality_mode=quality_mode,
        compute_rnd=False,
        remasking_mode=remasking_mode,
        num_remasking=num_remasking,
        quality_threshold=quality_threshold,
        temperature=temperature,
        unmask_quality_threshold=unmask_quality_threshold,
    )

    # Same decode/validate path as the training sampler (largest fragment of
    # every non-empty SMILES); the batch indices are not needed here.
    validSequences, _ = _decode_and_validate(model, tokenizer, xt)

    print("len valid sequences:", len(validSequences))
    valid_fraction = len(validSequences) / batch_size
    uniqueSequences = list(set(validSequences))
    has_molecules = len(uniqueSequences) > 0
    uniqueness = len(uniqueSequences) / len(validSequences) if len(validSequences) > 0 else 0
    diversity = evaluator(uniqueSequences) if has_molecules else 0

    # Calculate quality (unique sequences with QED >= 0.6 and SA <= 4)
    if has_molecules:
        score_vectors_temp = reward_model(input_seqs=list(uniqueSequences))
        qed_scores = score_vectors_temp[:, 0]  # Raw QED (0-1)

        # Always use raw SA (1-10 scale) for quality filtering
        _oracle_sa = Oracle('sa')
        raw_sa_scores = np.array(_oracle_sa(list(uniqueSequences)))

        quality_count = sum((qed_scores >= 0.6) & (raw_sa_scores <= 4))
        # Normalised by the requested batch size, not by the number of valid
        # or unique molecules.
        quality = quality_count / batch_size
        print(f'Quality:\t{quality}')

        qed = qed_scores
        sa = raw_sa_scores
    else:
        zeros = [0.0]
        qed = zeros
        sa = zeros
        quality = 0.0

    if dataframe:
        # All columns must have the same length: with no molecules the frame
        # is empty. Pairing a zero-length sequence column with the [0.0]
        # placeholders would make pandas raise ValueError.
        df = pd.DataFrame({
            "Mol Sequence": uniqueSequences,
            "QED": qed if has_molecules else [],
            "SA": sa if has_molecules else [],
        })
        return validSequences, qed, sa, valid_fraction, uniqueness, diversity, quality, df

    return validSequences, qed, sa, uniqueness, diversity, quality, valid_fraction
a2d2_mol/mol_dataset.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Adapter to use HuggingFace datasets with the any-length discrete diffusion model.
4
+ This module converts HuggingFace datasets (like datamol-io/safe-drugs) into the format
5
+ expected by the training pipeline.
6
+ """
7
+
8
+ import torch
9
+ from torch.utils.data import Dataset, DataLoader
10
+ from datasets import load_dataset
11
+ import pytorch_lightning as pl
12
+ from safe.tokenizer import SAFETokenizer
13
+ from mol_utils.bracket_safe_converter import safe2bracketsafe
14
+ from typing import Optional, List
15
+ import re
16
+
17
+
18
def get_tokenizer():
    """Load the 'datamol-io/safe-gpt' SAFE tokenizer and add the bracket tokens.

    Returns the object produced by ``SAFETokenizer.get_pretrained()`` after
    '<' and '>' have been added to it.
    """
    safe_tokenizer = SAFETokenizer.from_pretrained('datamol-io/safe-gpt')
    tokenizer = safe_tokenizer.get_pretrained()
    # '<' and '>' are the extra symbols needed for bracket_safe.
    tokenizer.add_tokens(['<', '>'])
    return tokenizer
23
+
24
+
25
class Collator:
    """Batch collator that tokenizes SAFE / bracket-SAFE strings."""

    # Keys probed, in priority order, when an example is a dict.
    _TEXT_KEYS = ('input', 'labels', 'smiles')

    def __init__(self, config, tokenizer=None):
        """Store the tokenizer (building the default one if none is given)
        and read ``interpolant.max_length`` and the optional
        ``training['use_bracket_safe']`` flag from ``config``."""
        if tokenizer is None:
            tokenizer = get_tokenizer()
        self.tokenizer = tokenizer
        self.max_length = config.interpolant.max_length
        self.use_bracket_safe = config.training.get('use_bracket_safe', False)

    def _text_of(self, example):
        """Return the raw string carried by one example.

        Dict examples yield the value of the first present key among
        'input', 'labels', 'smiles' (or '' when none is present); any other
        example is returned unchanged.
        """
        if not isinstance(example, dict):
            return example
        for key in self._TEXT_KEYS:
            if key in example:
                return example[key]
        return ''

    def __call__(self, examples):
        """Tokenize a list of examples into a padded, truncated batch.

        Returns:
            A plain dict holding only 'input_ids' and 'attention_mask';
            anything else the tokenizer returns (e.g. token_type_ids) is
            dropped.
        """
        texts = [self._text_of(example) for example in examples]
        if self.use_bracket_safe:
            texts = [safe2bracketsafe(text) for text in texts]

        encoded = self.tokenizer(
            texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )

        return {key: encoded[key] for key in ('input_ids', 'attention_mask')}
64
+
65
+
66
class HFDatasetAdapter(Dataset):
    """Wrap a HuggingFace dataset so it yields ``{'input', 'labels'}`` records."""

    def __init__(self, hf_dataset, tokenizer, smiles_column='smiles', max_length=1024, convert_to_safe=False, is_streaming=False):
        """
        Args:
            hf_dataset: HuggingFace dataset object (streaming or regular)
            tokenizer: Tokenizer instance; stored on the adapter only
            smiles_column: Name of the column containing the molecule strings
            max_length: Maximum sequence length; stored on the adapter only
            convert_to_safe: Stored on the adapter; not acted upon in this class
            is_streaming: Whether dataset is in streaming mode
        """
        self.tokenizer = tokenizer
        self.smiles_column = smiles_column
        self.max_length = max_length
        self.convert_to_safe = convert_to_safe
        self.is_streaming = is_streaming

        if is_streaming:
            # Streaming: keep the source as-is and read from it lazily.
            self.data = hf_dataset
            self._length = None  # length is not known up front
            print(f'Initialized streaming dataset adapter')
            return

        # Non-streaming: materialise the raw strings now; tokenization is
        # left to the collator.
        print(f'Initializing HF dataset adapter with {len(hf_dataset)} samples...')
        column_values = (record[smiles_column] for record in hf_dataset)
        self.data = [
            {'input': value, 'labels': value}
            for value in column_values
            if value  # empty entries are skipped
        ]
        print(f'Processed {len(self.data)} valid samples')

    def __len__(self):
        """Number of stored samples, or a stand-in length when streaming."""
        if not self.is_streaming:
            return len(self.data)
        # A streaming source has no length; report a large number so samplers
        # do not trip over it, unless a length has been set on the adapter.
        return self._length if self._length is not None else 10_000_000

    def __getitem__(self, idx):
        """Return the idx-th stored record; not available when streaming."""
        if self.is_streaming:
            raise NotImplementedError("Streaming datasets should be iterated, not indexed")
        return self.data[idx]

    def __iter__(self):
        """Iterate over records; streaming sources are converted on the fly."""
        if not self.is_streaming:
            yield from self.data
            return
        for record in self.data:
            value = record[self.smiles_column]
            if value:  # empty entries are skipped
                yield {'input': value, 'labels': value}
123
+
124
+
125
+ class HFDataModule(pl.LightningDataModule):
126
+ """PyTorch Lightning DataModule for HuggingFace datasets."""
127
+
128
+ def __init__(
129
+ self,
130
+ config,
131
+ dataset_name: str,
132
+ tokenizer: SAFETokenizer,
133
+ smiles_column: str = 'smiles',
134
+ val_split: float = 0.1,
135
+ test_split: Optional[float] = None,
136
+ streaming: bool = True,
137
+ max_train_samples: Optional[int] = None,
138
+ max_val_samples: Optional[int] = None,
139
+ ):
140
+ """
141
+ Args:
142
+ config: Configuration object containing training parameters
143
+ dataset_name: HuggingFace dataset identifier (e.g., "datamol-io/safe-gpt")
144
+ tokenizer: SMILES tokenizer instance
145
+ smiles_column: Name of column containing SMILES strings
146
+ val_split: Fraction of data to use for validation
147
+ test_split: Optional fraction of data to use for testing
148
+ streaming: Whether to use streaming mode (recommended for large datasets)
149
+ max_train_samples: Maximum number of training samples to use (for non-streaming)
150
+ max_val_samples: Maximum number of validation samples to use (for non-streaming)
151
+ """
152
+ super().__init__()
153
+ self.config = config
154
+ self.dataset_name = dataset_name
155
+ self.tokenizer = tokenizer
156
+ self.smiles_column = smiles_column
157
+ self.max_length = config.interpolant.max_length
158
+ self.batch_size = config.training.per_gpu_batch_size
159
+ self.num_workers = config.training.get('cpus', 4)
160
+ self.val_split = val_split
161
+ self.test_split = test_split
162
+ self.streaming = streaming
163
+ self.max_train_samples = max_train_samples
164
+ self.max_val_samples = max_val_samples
165
+
166
+ self.train_dataset = None
167
+ self.val_dataset = None
168
+ self.test_dataset = None
169
+
170
+ # Initialize collator
171
+ self.collator = Collator(config, tokenizer)
172
+
173
+ def setup(self, stage: Optional[str] = None):
174
+ """Load and split the dataset."""
175
+ print(f'Loading dataset: {self.dataset_name} (streaming={self.streaming})')
176
+
177
+ if self.streaming:
178
+ # Load dataset in streaming mode
179
+ raw_dataset = load_dataset(self.dataset_name, streaming=True)
180
+
181
+ # Handle different dataset structures
182
+ if 'train' in raw_dataset:
183
+ train_stream = raw_dataset['train']
184
+ else:
185
+ # If no splits exist, use the entire dataset
186
+ train_stream = raw_dataset[list(raw_dataset.keys())[0]]
187
+
188
+ # For streaming, we need to manually split train/val
189
+ # Skip validation samples, then take training samples
190
+ val_size = int(100000 * self.val_split) # Assume ~100k samples for val split calculation
191
+ train_size = 100000 - val_size
192
+
193
+ # Create validation stream (take first val_size samples)
194
+ val_stream = train_stream.take(val_size)
195
+
196
+ # Create training stream (skip val_size samples, then iterate)
197
+ train_stream_shifted = train_stream.skip(val_size)
198
+
199
+ # Create adapted datasets
200
+ self.train_dataset = HFDatasetAdapter(
201
+ train_stream_shifted,
202
+ self.tokenizer,
203
+ self.smiles_column,
204
+ self.max_length,
205
+ is_streaming=True
206
+ )
207
+
208
+ self.val_dataset = HFDatasetAdapter(
209
+ val_stream,
210
+ self.tokenizer,
211
+ self.smiles_column,
212
+ self.max_length,
213
+ is_streaming=True
214
+ )
215
+
216
+ print(f'Streaming dataset initialized - samples will be loaded on-the-fly')
217
+
218
+ else:
219
+ # Traditional non-streaming mode with full dataset loading
220
+ raw_dataset = load_dataset(self.dataset_name)
221
+
222
+ # Handle different dataset structures
223
+ if 'train' in raw_dataset:
224
+ train_data = raw_dataset['train']
225
+ else:
226
+ # If no splits exist, use the entire dataset and split it
227
+ train_data = raw_dataset[list(raw_dataset.keys())[0]]
228
+
229
+ # Limit samples if specified
230
+ if self.max_train_samples:
231
+ train_data = train_data.select(range(min(self.max_train_samples, len(train_data))))
232
+
233
+ # Check if dataset already has validation split
234
+ if 'validation' in raw_dataset or 'val' in raw_dataset:
235
+ val_key = 'validation' if 'validation' in raw_dataset else 'val'
236
+ val_data = raw_dataset[val_key]
237
+ else:
238
+ # Create train/val split
239
+ split_dataset = train_data.train_test_split(test_size=self.val_split, seed=42)
240
+ train_data = split_dataset['train']
241
+ val_data = split_dataset['test']
242
+
243
+ # Limit validation samples if specified
244
+ if self.max_val_samples:
245
+ val_data = val_data.select(range(min(self.max_val_samples, len(val_data))))
246
+
247
+ # Create test split if requested
248
+ if self.test_split and 'test' not in raw_dataset:
249
+ split_dataset = train_data.train_test_split(test_size=self.test_split, seed=42)
250
+ train_data = split_dataset['train']
251
+ self.test_dataset = HFDatasetAdapter(
252
+ split_dataset['test'],
253
+ self.tokenizer,
254
+ self.smiles_column,
255
+ self.max_length,
256
+ is_streaming=False
257
+ )
258
+ elif 'test' in raw_dataset:
259
+ self.test_dataset = HFDatasetAdapter(
260
+ raw_dataset['test'],
261
+ self.tokenizer,
262
+ self.smiles_column,
263
+ self.max_length,
264
+ is_streaming=False
265
+ )
266
+
267
+ # Create adapted datasets
268
+ self.train_dataset = HFDatasetAdapter(
269
+ train_data,
270
+ self.tokenizer,
271
+ self.smiles_column,
272
+ self.max_length,
273
+ is_streaming=False
274
+ )
275
+
276
+ self.val_dataset = HFDatasetAdapter(
277
+ val_data,
278
+ self.tokenizer,
279
+ self.smiles_column,
280
+ self.max_length,
281
+ is_streaming=False
282
+ )
283
+
284
+ print(f'Dataset splits - Train: {len(self.train_dataset)}, Val: {len(self.val_dataset)}')
285
+ if self.test_dataset:
286
+ print(f'Test: {len(self.test_dataset)}')
287
+
288
def train_dataloader(self):
    """Return the DataLoader used for training.

    Streaming mode hands the wrapped streaming dataset
    (``self.train_dataset.data``) straight to the loader, unshuffled and with
    no worker processes. Otherwise the adapted dataset is loaded with
    shuffling and ``self.num_workers`` workers.
    """
    shared = dict(
        batch_size=self.batch_size,
        collate_fn=self.collator,
        pin_memory=True,
    )

    if not self.streaming:
        return DataLoader(
            self.train_dataset,
            shuffle=True,
            num_workers=self.num_workers,
            # Persistent workers are only requested when workers exist.
            persistent_workers=self.num_workers > 0,
            **shared,
        )

    # Streaming: use the raw HF streaming dataset. num_workers must be 0 when
    # the stream was built with .skip() / .take(), and streaming datasets
    # cannot be shuffled by the loader.
    return DataLoader(
        self.train_dataset.data,
        shuffle=False,
        num_workers=0,
        **shared,
    )
310
+
311
def val_dataloader(self):
    """Return the DataLoader used for validation (never shuffled).

    Streaming mode iterates the wrapped streaming dataset
    (``self.val_dataset.data``) in the main process. Otherwise the adapted
    dataset is loaded with ``self.num_workers`` workers.
    """
    shared = dict(
        batch_size=self.batch_size,
        collate_fn=self.collator,
        pin_memory=True,
        shuffle=False,
    )

    if not self.streaming:
        return DataLoader(
            self.val_dataset,
            num_workers=self.num_workers,
            # Persistent workers are only requested when workers exist.
            persistent_workers=self.num_workers > 0,
            **shared,
        )

    # Streaming: use the raw HF streaming dataset. num_workers must be 0 when
    # the stream was built with .skip() / .take().
    return DataLoader(self.val_dataset.data, num_workers=0, **shared)
333
+
334
+ def test_dataloader(self):
335
+ if self.test_dataset:
336
+ return DataLoader(
337
+ self.test_dataset,
338
+ batch_size=self.batch_size,
339
+ collate_fn=self.collator,
340
+ shuffle=False,
341
+ num_workers=self.num_workers,
342
+ pin_memory=True,
343
+ persistent_workers=True if self.num_workers > 0 else False
344
+ )
345
+ return None
346
+
347
+
348
def setup_hf_data_and_update_config(
    config,
    dataset_name="datamol-io/safe-gpt",
    smiles_column="smiles",
    streaming=True,
    val_split=0.1,
):
    """
    Setup HuggingFace dataset and update config with token information.

    The config is mutated in place: ``config.interpolant.tokens``,
    ``config.interpolant.pad_token`` and ``config.interpolant.mask_token`` are
    overwritten from the tokenizer before the data module is built.

    Args:
        config: Hydra config object
        dataset_name: HuggingFace dataset identifier
        smiles_column: Name of column containing SMILES strings
        streaming: Whether to use streaming mode (recommended for large datasets like safe-gpt)
        val_split: Validation split value forwarded to HFDataModule
            (default 0.1, the value that used to be hard-coded here)

    Returns:
        HFDataModule instance
    """
    # Initialize tokenizer
    tokenizer = get_tokenizer()

    # Update config with tokenizer info
    config.interpolant.tokens = len(tokenizer)
    config.interpolant.pad_token = tokenizer.pad_token_id
    config.interpolant.mask_token = tokenizer.mask_token_id

    # Create data module
    return HFDataModule(
        config=config,
        dataset_name=dataset_name,
        tokenizer=tokenizer,
        smiles_column=smiles_column,
        val_split=val_split,
        streaming=streaming,
    )
a2d2_mol/mol_scoring/oracle/fpscores.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24a4392f5c673e79c0446af3c4d8e458293b5fecaa244328e76741ead9d21dbf
3
+ size 9048931
a2d2_mol/mol_scoring/scoring_functions.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForMaskedLM
2
+ import numpy as np
3
+ from tdc import Oracle, Evaluator
4
+
5
class MolScoringFunctions:
    """Score batches of molecule strings with the selected oracles ('qed', 'sa')."""

    def __init__(self, score_func_names=None, device=None, sa_transform='inverse'):
        """
        Class for generating score vectors given generated sequence

        Args:
            score_func_names: list of scoring function names to be evaluated;
                each must be a key of ``self.all_funcs`` ('qed' or 'sa').
                ``None`` selects no scoring functions.
            device: not used by this class; accepted and ignored.
            sa_transform: how to transform SA scores to higher-is-better ~[0,1]:
                'inverse' (default): 1/(1+SA) — range ~0.09-0.5, weak gradient
                'linear': (10-SA)/9 — range ~0-1, stronger gradient
        """
        self.score_func_names = [] if score_func_names is None else score_func_names
        self.sa_transform = sa_transform

        # Name -> callable oracle; forward() looks scoring functions up here.
        self.all_funcs = {
            'qed': Oracle('qed'),
            'sa': Oracle('sa'),
        }

    def forward(self, input_seqs):
        """Return a float32 array of scores with shape (num_sequences, num_functions).

        Columns follow the order of ``self.score_func_names``. An unknown name
        raises ``KeyError`` from the ``self.all_funcs`` lookup.
        """
        scores = []

        for score_func in self.score_func_names:
            score = self.all_funcs[score_func](input_seqs)

            # Transform SA to be maximized and normalized (original SA: 1-10, lower is better)
            # Convert to: higher is better, normalized to ~0-1 range like QED
            if score_func == 'sa':
                sa_raw = np.array(score)
                if self.sa_transform == 'linear':
                    # range ~0-1, clipped at 0
                    score = np.maximum((10.0 - sa_raw) / 9.0, 0.0)
                else:
                    # range ~0.09-0.5
                    score = 1.0 / (1.0 + sa_raw)

            scores.append(score)

        if not scores:
            # No scoring functions selected: keep the documented 2-D layout
            # instead of collapsing to a shape-(0,) array.
            return np.zeros((len(input_seqs), 0), dtype=np.float32)

        # convert to numpy arrays with shape (num_sequences, num_functions)
        return np.float32(scores).T

    def __call__(self, input_seqs: list):
        """Alias for :meth:`forward`."""
        return self.forward(input_seqs)
56
+
57
+
58
def unittest():
    """Smoke test: score one example SMILES with QED and SA and print the result."""
    scorer = MolScoringFunctions(score_func_names=['qed', 'sa'])
    example = ['CCOc1cc(ccc1NC(=O)N[C@@H]2CCCC[C@@H]2O)F']

    result = scorer(input_seqs=example)
    print(result)
    print(len(result))


if __name__ == '__main__':
    unittest()
a2d2_mol/mol_utils/bracket_safe_converter.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Patch: stub out `auto_docstring` if missing from transformers.utils
17
+ # (needed by safe.trainer.model in newer safe versions)
18
+ import transformers.utils as _tu
19
+ if not hasattr(_tu, 'auto_docstring'):
20
+ _tu.auto_docstring = lambda *a, **kw: (lambda fn: fn)
21
+
22
+ from safe.converter import *
23
+
24
+
25
class BracketSAFEConverter(SAFEConverter):
    """SAFEConverter variant whose encoder writes attachment labels as ``<n>`` / ``<%n>``.

    Instead of bare ring-closure digits, every attachment point created by
    fragmentation is emitted wrapped in angle brackets, so inter-fragment
    links are distinguishable from ordinary ring closures in the string.
    """

    def encoder(
        self,
        inp: Union[str, dm.Mol],
        canonical: bool = True,
        randomize: Optional[bool] = False,
        seed: Optional[int] = None,
        constraints: Optional[List[dm.Mol]] = None,
        allow_empty: bool = False,
        rdkit_safe: bool = True,
    ):
        """Encode a molecule as a bracketed SAFE string.

        Args:
            inp: SMILES string or molecule object to encode.
            canonical: passed to ``dm.to_smiles``; when true (and ``randomize``
                is false) fragments are ordered by decreasing atom count and
                attachment labels are assigned in sorted order.
            randomize: when true, fragments are permuted with a generator
                seeded by ``seed``; if ``canonical`` is also false the input
                molecule is first passed through ``self.randomize``.
            seed: seed for ``np.random.default_rng``.
            constraints: optional substructures; a candidate bond whose two
                atoms both fall inside one match is not cut.
            allow_empty: forwarded to ``self._fragment``.
            rdkit_safe: when true, parentheses around a lone (optionally
                bond-prefixed) ring-closure number are removed.

        Returns:
            The dot-joined fragment string with attachment points rewritten as
            ``<n>`` (n < 10) or ``<%n>`` (n >= 10).
        """
        rng = None
        if randomize:
            rng = np.random.default_rng(seed)
            if not canonical:
                inp = dm.to_mol(inp, remove_hs=False)
                inp = self.randomize(inp, rng)

        if isinstance(inp, dm.Mol):
            inp = dm.to_smiles(inp, canonical=canonical, randomize=False, ordered=False)

        # NOTE(review): branch_numbers is computed but never read below;
        # attachment numbering always starts at 1.
        branch_numbers = self._find_branch_number(inp)

        mol = dm.to_mol(inp, remove_hs=False)
        if self.ignore_stereo:
            mol = dm.remove_stereochemistry(mol)

        # Dummy atoms (atomic number 0) already present in the input get their
        # atom-map number cleared and a running isotope label instead.
        bond_map_id = 1
        for atom in mol.GetAtoms():
            if atom.GetAtomicNum() == 0:
                atom.SetAtomMapNum(0)
                atom.SetIsotope(bond_map_id)
                bond_map_id += 1

        if self.require_hs:
            mol = dm.add_hs(mol)
        matching_bonds = self._fragment(mol, allow_empty=allow_empty)
        substructed_ignored = []
        if constraints is not None:
            substructed_ignored = list(
                itertools.chain(
                    *[
                        mol.GetSubstructMatches(constraint, uniquify=True)
                        for constraint in constraints
                    ]
                )
            )

        bonds = []
        for i_a, i_b in matching_bonds:
            # if both atoms of the bond are found in a disallowed substructure, we cannot consider them
            # on the other end, a bond between two substructure to preserved independently is perfectly fine
            if any((i_a in ignore_x and i_b in ignore_x) for ignore_x in substructed_ignored):
                continue
            obond = mol.GetBondBetweenAtoms(i_a, i_b)
            bonds.append(obond.GetIdx())

        if len(bonds) > 0:
            # Each cut bond gets a fresh label, continuing after the labels
            # already given to pre-existing dummy atoms.
            mol = Chem.FragmentOnBonds(
                mol,
                bonds,
                dummyLabels=[(i + bond_map_id, i + bond_map_id) for i in range(len(bonds))],
            )

        frags = list(Chem.GetMolFrags(mol, asMols=True))
        if randomize:
            frags = rng.permutation(frags).tolist()
        elif canonical:
            frags = sorted(
                frags,
                key=lambda x: x.GetNumAtoms(),
                reverse=True,
            )

        frags_str = []
        for frag in frags:
            # Root each fragment SMILES at its first non-dummy atom.
            # NOTE(review): a fragment made only of dummy atoms would make this
            # list empty and the [0] below raise IndexError.
            non_map_atom_idxs = [
                atom.GetIdx() for atom in frag.GetAtoms() if atom.GetAtomicNum() != 0
            ]
            frags_str.append(
                Chem.MolToSmiles(
                    frag,
                    isomericSmiles=True,
                    canonical=True,  # needs to always be true
                    rootedAtAtom=non_map_atom_idxs[0],
                )
            )

        scaffold_str = ".".join(frags_str)

        # don't capture atom mapping in the scaffold
        attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str))
        if canonical:
            attach_pos = sorted(attach_pos)
        starting_num = 1
        for attach in attach_pos:
            val = str(starting_num) if starting_num < 10 else f"%{starting_num}"
            val = '<' + val + '>'  # bracket added
            # we cannot have anything of the form "\([@=-#-$/\]*\d+\)"
            attach_regexp = re.compile(r"(" + re.escape(attach) + r")")
            scaffold_str = attach_regexp.sub(val, scaffold_str)
            starting_num += 1

        # now we need to remove all the parenthesis around digit only number
        wrong_attach = re.compile(r"\((<[\%\d+]*>)\)")  # bracket added
        scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str)
        # furthermore, we autoapply rdkit-compatible digit standardization.
        if rdkit_safe:
            pattern = r"\(([=-@#\/\\]{0,2})(%?\d{1,2})\)"
            replacement = r"\g<1>\g<2>"
            scaffold_str = re.sub(pattern, replacement, scaffold_str)
        return scaffold_str
137
+
138
+
139
def safe2bracketsafe(safe_str):
    """Re-encode *safe_str* in bracketed form, best effort.

    The string is parsed with ``Chem.MolFromSmiles`` and re-encoded by
    ``BracketSAFEConverter`` with a random (non-canonical) fragment order.
    If anything in that conversion fails, the input is returned unchanged.
    """
    try:
        return BracketSAFEConverter().encoder(
            Chem.MolFromSmiles(safe_str),
            allow_empty=True,
            canonical=False,
            randomize=True,
        )
    except Exception:
        # Keep the best-effort fallback, but let KeyboardInterrupt and
        # SystemExit propagate (the previous bare `except:` swallowed them).
        return safe_str
144
+
145
+
146
def bracketsafe2safe(safe_str):
    """Convert bracketed attachment labels back to plain ring-closure numbers.

    Every ``<n>`` or ``<%n>`` token is replaced by ``n + offset``, written as
    a single digit when below 10 and as ``%NN`` otherwise. ``offset`` is one
    more than the largest number found among the unbracketed digits and
    ``%NN`` labels in the string (0 when there are none), so the new labels
    do not collide with those numbers. All space characters are removed from
    the result.

    The ``<%n>`` form is what ``BracketSAFEConverter.encoder`` emits for
    labels >= 10; the previous pattern (``<\\d+>``) did not match it and left
    such tokens unconverted.

    NOTE(review): labels of 100 or more are written as ``%NNN``; whether
    downstream parsers accept that form is not checked here.
    """
    # Numbers already in use: single digits not preceded by '%' and not
    # directly followed by '>', plus every '%NN' label.
    intrafrag_points = [m.group(0) for m in re.finditer(r'(?<!%)\d(?!>)', safe_str)] + \
                       [m.group(0).lstrip('%') for m in re.finditer(r'%\d+', safe_str)]
    starting_num = max(int(i) for i in intrafrag_points) + 1 if intrafrag_points else 0

    def _renumber(match):
        # Shift the bracketed label past the numbers already in use.
        label = int(match.group(1)) + starting_num
        return '%' + str(label) if label >= 10 else str(label)

    converted = re.sub(r'<%?(\d+)>', _renumber, safe_str)
    return converted.replace(' ', '')
a2d2_mol/mol_utils/utils.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Console logger utilities.
2
+
3
+ Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
4
+ Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
5
+ """
6
+
7
+ import logging
8
+ import fsspec
9
+ import lightning
10
+ import torch
11
+ from timm.scheduler import CosineLRScheduler
12
+ import argparse
13
+ import numpy as np
14
+ import random
15
+ import os
16
+
17
def sample_categorical_logits(logits, dtype=torch.float64):
    """Sample one index along the last dimension by adding Gumbel noise and taking the argmax.

    The logits do not need to be log-softmaxed. Noise is drawn in ``dtype``.
    """
    eps = 1e-10
    uniform = torch.rand_like(logits, dtype=dtype)
    # eps keeps both logarithms away from log(0).
    noise = -(eps - (uniform + eps).log()).log()
    perturbed = logits + noise
    return perturbed.argmax(dim=-1)
21
+
22
def fsspec_exists(filename):
    """Return whether *filename* exists on the filesystem fsspec resolves for it."""
    filesystem, _unused_path = fsspec.core.url_to_fs(filename)
    found = filesystem.exists(filename)
    return found
26
+
27
+
28
def fsspec_listdir(dirname):
    """List the entries of *dirname* via the filesystem fsspec resolves for it."""
    filesystem, _unused_path = fsspec.core.url_to_fs(dirname)
    entries = filesystem.ls(dirname)
    return entries
32
+
33
+
34
def fsspec_mkdirs(dirname, exist_ok=True):
    """Create *dirname* via the filesystem fsspec resolves for it.

    ``exist_ok`` is forwarded to the filesystem's ``makedirs``.
    """
    filesystem, _unused_path = fsspec.core.url_to_fs(dirname)
    filesystem.makedirs(dirname, exist_ok=exist_ok)
38
+
39
+
40
def print_nans(tensor, name):
    """Print *name* and *tensor* when the tensor contains at least one NaN."""
    contains_nan = torch.isnan(tensor).any()
    if not contains_nan:
        return
    print(name, tensor)
43
+
44
+
45
class CosineDecayWarmupLRScheduler(
        CosineLRScheduler,
        torch.optim.lr_scheduler._LRScheduler):
    """CosineLRScheduler that tracks its own step counter and accepts argument-free ``step()`` calls."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Start one before zero, then apply the schedule for position 0.
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        """Advance the counter (or set it to *epoch*) and apply the schedule."""
        self._last_epoch = self._last_epoch + 1 if epoch is None else epoch

        # We call either step or step_update, depending on whether we're
        # using the scheduler every epoch or every step. Otherwise, lightning
        # will always call step (i.e., meant for each epoch), and if we set
        # scheduler interval to "step", then the learning rate update will be
        # wrong.
        if not self.t_in_epochs:
            super().step_update(num_updates=self._last_epoch)
            return
        super().step(epoch=self._last_epoch)
70
+
71
+
72
class LoggingContext:
    """Context manager for selective logging.

    On entry it optionally sets a temporary level on the logger and attaches
    a handler; on exit it restores the previous level, detaches the handler
    and, if ``close`` is true, closes it.
    """

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        if self.level is not None:
            # Remember the current level so __exit__ can restore it.
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if not self.handler:
            return
        self.logger.removeHandler(self.handler)
        if self.close:
            self.handler.close()
94
+
95
+
96
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    Returns the logger named *name*, set to *level*, with each of its logging
    methods wrapped in lightning's ``rank_zero_only`` decorator.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    rank_zero_only = lightning.pytorch.utilities.rank_zero_only
    method_names = ('debug', 'info', 'warning', 'error',
                    'exception', 'fatal', 'critical')
    for method_name in method_names:
        wrapped = rank_zero_only(getattr(logger, method_name))
        setattr(logger, method_name, wrapped)

    return logger
112
+
113
+
114
def str2bool(v):
    """Parse a boolean-like value (for use as an argparse ``type``).

    Booleans are returned unchanged; strings are matched case-insensitively
    against the accepted spellings. Anything else raises
    ``argparse.ArgumentTypeError``.
    """
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
123
+
124
+
125
def set_seed(seed, use_cuda):
    """Seed Python, NumPy and torch RNGs (plus CUDA RNGs when *use_cuda*) and print the seed.

    Also exports the seed as the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # torch.backends.cudnn.deterministic is left untouched here.

    if use_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    print(f'=> Seed of the run set to {seed}')
135
+
a2d2_mol/mol_utils/utils_chem.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import random
18
+ import safe as sf
19
+ import datamol as dm
20
+ from contextlib import suppress
21
+ from rdkit import Chem, RDLogger
22
+ RDLogger.DisableLog('rdApp.*')
23
+
24
+ # https://github.com/datamol-io/safe/blob/main/safe/sample.py
25
+ # https://github.com/jensengroup/GB_GA/blob/master/crossover.py
26
def safe_to_smiles(safe_str, fix=True):
    """Decode a SAFE string with ``sf.decode`` (canonical output, errors ignored).

    With ``fix`` enabled, dot-separated fragments that do not decode on their
    own are dropped before the whole string is decoded.
    """
    if fix:
        decodable = [
            fragment
            for fragment in safe_str.split('.')
            if sf.decode(fragment, ignore_errors=True) is not None
        ]
        safe_str = '.'.join(decodable)
    return sf.decode(safe_str, canonical=True, ignore_errors=True)
31
+
32
+
33
def _safe_to_smiles_worker(args):
    """Worker function for parallel SAFE to SMILES conversion.

    *args* is a ``(safe_str, use_bracket_safe, fix)`` tuple. Returns the
    decoded SMILES, or ``None`` if any exception is raised during conversion.
    """
    safe_str, use_bracket_safe, fix = args
    try:
        from mol_utils.bracket_safe_converter import bracketsafe2safe
        plain_safe = bracketsafe2safe(safe_str) if use_bracket_safe else safe_str
        return safe_to_smiles(plain_safe, fix=fix)
    except Exception:
        return None
43
+
44
+
45
def batch_safe_to_smiles(safe_strings, use_bracket_safe=False, fix=True, num_workers=None):
    """
    Convert a batch of SAFE strings to SMILES.

    Batches of up to four strings are converted sequentially; larger batches
    are mapped over a thread pool. In both paths a string whose conversion
    raises yields ``None`` (previously only the threaded path did this, and
    the sequential path let the exception propagate).

    Args:
        safe_strings: List of SAFE format strings
        use_bracket_safe: Whether to convert from bracket SAFE format first
        fix: Whether to fix invalid fragments
        num_workers: Number of worker threads (default: min(cpu_count, len(safe_strings), 8))

    Returns:
        List of SMILES strings (None for invalid molecules), in input order
    """
    from concurrent.futures import ThreadPoolExecutor
    import os

    n = len(safe_strings)
    if n == 0:
        return []

    # For small batches, use sequential processing (overhead not worth it)
    if n <= 4:
        if use_bracket_safe:
            from mol_utils.bracket_safe_converter import bracketsafe2safe
        results = []
        for s in safe_strings:
            try:
                candidate = bracketsafe2safe(s) if use_bracket_safe else s
                results.append(safe_to_smiles(candidate, fix=fix))
            except Exception:
                # Match the threaded path: a failed conversion is reported as None.
                results.append(None)
        return results

    if num_workers is None:
        num_workers = min(os.cpu_count() or 4, n, 8)

    args_list = [(s, use_bracket_safe, fix) for s in safe_strings]

    # Threads rather than processes: no pickling of arguments and a lower
    # startup cost. NOTE(review): the original rationale also assumes RDKit
    # releases the GIL during computation; not verified here.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        return list(executor.map(_safe_to_smiles_worker, args_list))
88
+
89
+
90
def batch_validate_and_extract(smiles_list, samples_tensor, log_rnd_tensor):
    """
    Keep the truthy SMILES entries and reduce each to its longest fragment.

    Args:
        smiles_list: List of SMILES strings (may contain None for invalid)
        samples_tensor: Not read by this function
        log_rnd_tensor: Not read by this function

    Returns:
        valid_sequences: For each kept entry, the longest dot-separated
            fragment by string length (the last one on ties)
        valid_indices: Positions of the kept entries in ``smiles_list``
    """
    kept = [
        (position, sorted(entry.split('.'), key=len)[-1])
        for position, entry in enumerate(smiles_list)
        if entry  # None and empty strings are skipped
    ]
    valid_sequences = [fragment for _, fragment in kept]
    valid_indices = [position for position, _ in kept]
    return valid_sequences, valid_indices
114
+
115
+
116
def filter_by_substructure(sequences, substruct):
    """Filter *sequences* with ``sf.utils.filter_by_substructure_constraints``.

    The query is built from *substruct* by standardizing its attachment
    points, deleting ``*`` atoms, and re-parsing the resulting SMILES as
    SMARTS.
    """
    standardized = sf.utils.standardize_attach(substruct)
    without_dummies = Chem.DeleteSubstructs(
        Chem.MolFromSmarts(standardized),
        Chem.MolFromSmiles('*'),
    )
    query = Chem.MolFromSmarts(Chem.MolToSmiles(without_dummies))
    return sf.utils.filter_by_substructure_constraints(sequences, query)
121
+
122
+
123
def mix_sequences(prefix_sequences, suffix_sequences, prefix, suffix, num_samples=1):
    """Extract linkers from the given sequences and join them with *prefix* and *suffix*.

    Each sequence is sliced with a ``MolSlicer`` against the prefix (or
    suffix) query and the second element of the result is kept as a linker;
    sequences that raise are skipped. Linkers are then combined with
    ``link_fragments`` and at most ``num_samples`` truthy results are
    returned.
    """
    slicer = sf.utils.MolSlicer(require_ring_system=False)
    prefix_query = dm.from_smarts(prefix)
    suffix_query = dm.from_smarts(suffix)

    def _collect_linkers(sequences, query):
        # Second element of the slicer output per sequence; failures are ignored.
        found = []
        for sequence in sequences:
            try:
                found.append(slicer(dm.to_mol(sequence), query)[1])
            except Exception:
                continue
        return found

    candidates = (_collect_linkers(prefix_sequences, prefix_query)
                  + _collect_linkers(suffix_sequences, suffix_query))
    usable = [linker for linker in candidates if linker is not None]

    linked = []
    for position, linker in enumerate(usable):
        linked.extend(slicer.link_fragments(linker, prefix, suffix))
        # Stops on the number of linkers processed, not on results collected.
        if position > num_samples:
            break

    return [molecule for molecule in linked if molecule][:num_samples]
153
+
154
+
155
def cut(smiles):
    """Return the set of fragment SMILES from up to three random non-ring single-bond cuts.

    Each attempt picks one random acyclic single bond, breaks it with dummy
    atoms labelled 1, and adds the SMILES of the resulting fragments to the
    set. Attempts whose fragments fail sanitization are skipped.

    Returns:
        A set of fragment SMILES. It is empty when *smiles* cannot be parsed
        (previously this raised ``AttributeError`` on the ``None`` molecule)
        or when the molecule has no acyclic single bond.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return set()

    # single bond not in ring; compiled once instead of twice per attempt
    nonring_single = Chem.MolFromSmarts('[*]-;!@[*]')

    def cut_nonring(mol):
        matches = mol.GetSubstructMatches(nonring_single)
        if not matches:
            return None

        bis = random.choice(matches)
        bs = [mol.GetBondBetweenAtoms(bis[0], bis[1]).GetIdx()]
        fragments_mol = Chem.FragmentOnBonds(mol, bs, addDummies=True, dummyLabels=[(1, 1)])

        try:
            return Chem.GetMolFrags(fragments_mol, asMols=True, sanitizeFrags=True)
        except ValueError:
            return None

    frags = set()
    # non-ring cut
    for _ in range(3):
        frags_nonring = cut_nonring(mol)
        if frags_nonring is not None:
            frags |= {Chem.MolToSmiles(f) for f in frags_nonring}
    return frags
177
+
178
+
179
class Slicer:
    """Callable that yields the atom-index pairs of non-ring single bonds of a molecule."""

    def __call__(self, mol):
        # Accept a SMILES string as well as a molecule object.
        if isinstance(mol, str):
            mol = Chem.MolFromSmiles(mol)

        # non-ring single bonds
        pattern = Chem.MolFromSmarts('[*]-;!@[*]')
        yield from mol.GetSubstructMatches(pattern)
a2d2_mol/oracle/fpscores.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24a4392f5c673e79c0446af3c4d8e458293b5fecaa244328e76741ead9d21dbf
3
+ size 9048931
a2d2_mol/remasking_scheduleaware.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Schedule-aware remasking and insertion logic that ensures the number of masked tokens
3
+ follows the interpolant schedule.
4
+ """
5
+ import torch
6
+ import numpy as np
7
+
8
def apply_schedule_aware_insertion(
    model,
    xt_tmp,
    new_xt,
    t,
    dt,
    ext,
    mask,
    pad,
    max_length,
    orig_mask,
    new_pos_orig,
    quality_threshold=1,
):
    """
    Remove low-quality insertions based on insertion confidence while respecting
    the interpolant schedule for expected sequence length.

    Args:
        model: Model with planner and interpolant
        xt_tmp: Sequence after insertion [B, L]
        new_xt: Sequence before insertion [B, L] (not read by this function)
        t: Current time [B]
        dt: Time step size
        ext: Number of insertions per gap [B, L+1]
        mask: Mask token ID
        pad: Pad token ID
        max_length: Maximum sequence length (not read by this function)
        orig_mask: Mask of original token positions [B, L]
        new_pos_orig: New positions of original tokens [B, L]
        quality_threshold: Inserted mask tokens whose planner confidence is
            below this value are candidates for removal (default 1)

    Returns:
        xt_tmp: Modified sequence with low-quality insertions removed (respecting schedule).
            The input tensor is returned unchanged when there were no
            insertions, when the planner output has no "insertion_conf"
            entry, or when nothing is selected for removal.
    """
    device = xt_tmp.device
    batch_size, L = xt_tmp.shape  # batch_size is not used below
    total_ext = ext.sum(dim=1)

    # Only proceed if there were insertions
    if total_ext.sum() == 0:
        return xt_tmp

    # Get planner predictions on inserted state. The insertion head is trained
    # with the pre-step time t (see loss_insert_planner_flexible), so condition
    # on t here too; t_next is still used below for the length schedule.
    t_next = t + dt
    planner_out = model.planner(xt_tmp, t)
    insertion_conf = planner_out.get("insertion_conf", None)

    if insertion_conf is None:
        return xt_tmp

    insertion_conf = insertion_conf.squeeze(-1)  # (B, L)

    # Expected sequence length at next timestep according to schedule
    current_length_after = xt_tmp.ne(pad).sum(dim=1).float()  # [B]
    expected_progress = model.interpolant.insertion_schedule.at(t_next)  # [B]
    estimated_final_length = current_length_after / (expected_progress.clamp(min=0.1))
    # NOTE(review): whenever expected_progress >= 0.1 the clamp is a no-op, so
    # this product is algebraically current_length_after again. min_length
    # below then (up to float rounding before .long()) equals the current
    # length, leaving max_removable at ~0. Only for progress < 0.1 does the
    # cap allow removals. Confirm this is the intended schedule behaviour.
    expected_length = estimated_final_length * expected_progress  # [B]

    # Mark positions in xt_tmp that came from new_xt (originals) vs. fresh insertions.
    # Fancy-indexing scatter avoids the per-batch python loop.
    valid_b, valid_l = orig_mask.nonzero(as_tuple=True)
    valid_p = new_pos_orig[valid_b, valid_l].long().clamp_(0, L - 1)
    is_original = torch.zeros_like(xt_tmp, dtype=torch.bool)
    is_original[valid_b, valid_p] = True
    # Only mask tokens that do not sit at an original position count as insertions.
    inserted_positions = (xt_tmp == mask) & ~is_original

    # Threshold mode: drop insertions whose confidence is below
    # `quality_threshold`, capped so the length never falls below the
    # scheduled minimum.
    candidates = inserted_positions & (insertion_conf < quality_threshold)
    num_bad = candidates.sum(dim=1)  # [B], long
    min_length = expected_length.long().clamp(min=1)  # [B]
    max_removable = (current_length_after.long() - min_length).clamp(min=0)
    length_after_removal = current_length_after.long() - num_bad
    schedule_violates = length_after_removal < min_length
    # Per row: remove every candidate, unless that would undershoot
    # min_length, in which case only max_removable of them.
    k_per_row = torch.where(schedule_violates, max_removable, num_bad)
    k_per_row = torch.where(num_bad > 0, k_per_row, torch.zeros_like(k_per_row))

    if not candidates.any():
        return xt_tmp

    # Select the lowest-confidence candidates per row via a sort.
    neg_inf = torch.tensor(float('-inf'), device=device, dtype=insertion_conf.dtype)
    scores = torch.where(candidates, -insertion_conf, neg_inf)  # higher = worse
    _, sorted_indices = scores.sort(dim=1, descending=True)
    positions = torch.arange(L, device=device).unsqueeze(0)  # [1, L]
    keep_in_topk = positions < k_per_row.unsqueeze(1)  # [B, L]
    # Scatter the "rank < k" flags back to the original column positions.
    final_bad = torch.zeros_like(candidates)
    final_bad.scatter_(1, sorted_indices, keep_in_topk)

    if not final_bad.any():
        return xt_tmp

    # Compact each row to the left (keep good, drop bad), then pad the tail.
    # Stable sort by the bad flag pushes bad positions to the right.
    sort_key = final_bad.long()
    _, perm = torch.sort(sort_key, dim=1, stable=True)
    xt_tmp = torch.gather(xt_tmp, 1, perm)
    num_keep = (~final_bad).sum(dim=1)  # [B]
    tail_mask = positions >= num_keep.unsqueeze(1)  # [B, L]
    xt_tmp = torch.where(tail_mask, torch.full_like(xt_tmp, pad), xt_tmp)

    return xt_tmp
114
+
115
+
116
def apply_schedule_aware_remasking(
    model,
    new_xt,
    t,
    dt,
    remasking_conf,
    clean_index,
    mask,
    neg_inf,
    batch_size,
    unmask_quality_threshold=None,
):
    """
    Remask clean tokens, either by confidence threshold or to match the unmask schedule.

    Args:
        model: Model with interpolant that has an unmask_schedule
        new_xt: Current sequence [B, L]
        t: Current time [B]
        dt: Time step size
        remasking_conf: Confidence scores for tokens [B, L]
        clean_index: Boolean mask of clean tokens (not mask, not pad) [B, L]
        mask: Mask token ID
        neg_inf: Negative infinity tensor
        batch_size: Batch size (not read by this function)
        unmask_quality_threshold: When not None, the schedule is bypassed and
            every clean token with confidence below this value is remasked

    Returns:
        new_xt: Sequence with the selected clean tokens replaced by ``mask``
    """
    # Threshold gate: overrides the schedule-driven count when set. A higher
    # threshold remasks more tokens.
    if unmask_quality_threshold is not None:
        below_threshold = remasking_conf < unmask_quality_threshold
        return torch.where(
            clean_index & below_threshold,
            torch.full_like(new_xt, mask),
            new_xt,
        )

    clean_count = clean_index.sum(dim=1)  # [B], long
    masked_count = (new_xt == mask).sum(dim=1)
    seq_len = (clean_count + masked_count).float()  # [B]

    # Number of clean tokens the schedule expects at the next time step.
    target_clean = model.interpolant.unmask_schedule.at(t + dt) * seq_len  # [B]
    surplus = (clean_count.float() - target_clean).round().long()  # [B]

    # Per row: remask min(surplus, clean_count) tokens, never a negative count.
    num_to_remask = torch.minimum(surplus.clamp(min=0), clean_count)  # [B]
    if num_to_remask.sum() == 0:
        return new_xt

    # Rank clean tokens so the lowest confidence comes first; non-clean
    # positions get neg_inf and sort last.
    ranking = torch.where(clean_index, -1.0 * remasking_conf, neg_inf)
    _, order = ranking.sort(dim=1, descending=True)

    rank = torch.arange(ranking.shape[1], device=new_xt.device).unsqueeze(0)  # [1, L]
    within_budget = rank < num_to_remask.unsqueeze(1)  # [B, L]

    # Scatter the "rank < k" flags back to the original column positions.
    selected = torch.zeros_like(clean_index)
    selected.scatter_(1, order, within_budget)

    return torch.where(selected, torch.full_like(new_xt, mask), new_xt)
a2d2_mol/sampling.py ADDED
@@ -0,0 +1,1401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # add repo root to path
4
+
5
+ import torch
6
+ from dataclasses import dataclass
7
+ from typing import Any, Literal, Optional
8
+ import numpy as np
9
+ import pandas as pd
10
+
11
+ from lightning_modules.mdm import MaskedDiffusionModule
12
+
13
+
14
@dataclass
class SamplingTraceDatapoint:
    """A single event recorded for one sequence during sampling.

    The samplers in this module append one datapoint per event when they
    are called with ``return_trace=True``.
    """

    # Value of the sampler's time variable at the step that produced the event.
    t: float
    # "change": the token at a non-pad position differs after the unmask step.
    # "insertion": a mask token was inserted at a gap.
    event_type: Literal["insertion", "change"]
    # Sequence index (for "change") or gap index (for "insertion"), both
    # measured in the sequence as it was before the step's insertions.
    position: int
    # Token id written by the event; insertions record the mask id.
    token: Any
20
+
21
+
22
@dataclass
class SamplingResult:
    """Bundle of sampled sequences and the optional sampling trace.

    Iterating an instance yields ``samples`` and then ``trace``, so it can
    be unpacked like a 2-tuple: ``samples, trace = result``.
    """

    # Sampled token ids.
    samples: torch.Tensor
    # Trace is supposed to be processed sequentially as updates are not
    # commutative. ``None`` when no trace was recorded.
    trace: Optional[list[SamplingTraceDatapoint]]

    def __iter__(self):
        """Yield ``samples`` followed by ``trace``."""
        yield self.samples
        yield self.trace
30
+
31
+
32
+ # Sample from categorical distribution for each position using the transition probabilities
33
+ def _sample_tokens(probs: torch.Tensor) -> torch.Tensor:
34
+ """Sample one token per position from probability distribution.
35
+ Args:
36
+ probs: [batch_size, seq_len, vocab_size] transition probabilities
37
+ Returns:
38
+ [batch_size, seq_len] sampled token indices
39
+ """
40
+ batch_size, seq_len, vocab_size = probs.shape
41
+ flat_probs = probs.view(-1, vocab_size)
42
+ samples = torch.multinomial(flat_probs, num_samples=1)
43
+ return samples.view(batch_size, seq_len)
44
+
45
+
46
+ def _sample_batched_tokens(probs: torch.Tensor) -> torch.Tensor:
47
+
48
+ batch_size, seq_len, vocab_size = probs.shape
49
+
50
+ gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, seq_len, vocab_size) + 1e-10) + 1e-10)).to(probs.device)
51
+ noisy_logits = torch.log(probs + 1e-10) + gumbel_noise # add Gumbel noise to log probabilities
52
+
53
+ # select the highest score (most likely category after Gumbel noise)
54
+ samples = noisy_logits.argmax(dim=-1).to(dtype=torch.long)
55
+
56
+ return samples.view(batch_size, seq_len)
57
+
58
@torch.no_grad()
def mdm_euler_sampling(
    model: MaskedDiffusionModule,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
):
    """Euler sampler for a fixed-length masked diffusion model.

    Starts from an all-mask batch of shape (batch_size, max_length) and runs
    ``steps`` uniform Euler steps from t = 0 with dt = 1 / steps. At the final
    step the mask token is removed from the candidates of still-masked
    positions so they are unmasked.

    Args:
        model: Called as ``model(xt, t)``; ``model.interpolant.to_actual_rate``
            must return an object whose ``unmask_rate`` is indexed as
            (B, L, V). That tensor is modified in place.
        steps: Number of Euler steps.
        mask: Mask token id.
        pad: Pad token id (not used by this sampler).
        batch_size: Number of sequences to sample.
        max_length: Length of every sampled sequence.
        return_trace: Must be False; tracing is not implemented here.
        temperature: When not 1.0, per-position transition weights are
            re-scaled as ``softmax(log(p + 1e-10) / temperature)``.

    Returns:
        Tuple ``(xt, [])`` where ``xt`` is the (batch_size, max_length) int64
        tensor of sampled token ids and the list is an empty trace.
    """
    assert not return_trace, "Trace is not yet implemented in MDM Euler sampling"
    device = next(model.parameters()).device
    xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    for i in range(steps):
        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        # The negative diagonal entry is clamped away here; the "stay" weight
        # is supplied by the scatter_add_ below.
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # Add weight 1 for keeping the current token. xt is only read as an
        # index, so no defensive copy is needed.
        current = xt.unsqueeze(-1)
        trans_prob.scatter_add_(2, current, torch.ones_like(current, dtype=trans_prob.dtype))

        # Apply temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if i == steps - 1:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0

            # A masked position whose remaining weights are all zero would make
            # torch.multinomial raise; fall back to a uniform distribution.
            dead = trans_prob[mask_pos].sum(dim=-1) == 0.0
            if dead.any():
                # NOTE(review): assumes regular tokens are ids 0..mask-1, as the
                # other samplers in this file do — confirm for this vocabulary.
                uniform_prob = trans_prob.new_zeros(trans_prob.size(-1))
                uniform_prob[:mask] = 1.0 / mask
                trans_prob[mask_pos[0][dead], mask_pos[1][dead]] = uniform_prob

        # Rows are unnormalised weights; multinomial normalises them.
        new_xt = _sample_tokens(trans_prob)
        # Already-unmasked positions never change.
        new_xt = torch.where(xt != mask, xt, new_xt)

        xt = new_xt
        t = t + dt

    return xt, []
114
+
115
+
116
@torch.no_grad()
def any_order_mask_insertion_euler_sampling(
    model: torch.nn.Module,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, Optional[list[list[SamplingTraceDatapoint]]]]:
    """Euler sampler with any-order unmasking and mask insertion.

    Starts from an all-pad batch. Every step (dt = 1 / steps, t from 0)
    first samples tokens for masked positions, then — except on the final
    step — inserts new mask tokens at gaps drawn from
    ``Bernoulli(length_rate * dt)``. On the final step the mask token is
    excluded so every remaining mask is resolved.

    Args:
        model: Called as ``model(xt, t)``; ``model.interpolant.to_actual_rate``
            must return an object with ``unmask_rate`` (B, L, V) and
            ``length_rate`` (B, L+1). ``unmask_rate`` is modified in place.
        steps: Number of Euler steps.
        mask: Mask token id.
        pad: Pad token id; pads are assumed to sit at the tail of each row.
        batch_size: Number of sequences to sample.
        max_length: Maximum (padded) sequence length L.
        return_trace: Record per-sequence change/insertion events.
        temperature: Currently unused by this function.

    Returns:
        Tuple ``(xt, trace)``: ``xt`` is the (batch_size, max_length) int64
        result; ``trace`` is a list with one event list per sequence, or
        ``None`` when ``return_trace`` is False.
    """
    device = next(model.parameters()).device

    # Initialize all-pad sequence
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Precompute row/column indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        is_last_step = i == steps - 1

        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" probability (pad positions are routed to the mask column
        # so that every row has positive mass; their sample is discarded below)
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        if is_last_step:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

            # renormalize masked rows; rows with no mass left fall back to a
            # uniform distribution over token ids 0..mask-1
            masked_probs = trans_prob[mask_pos]  # (K, V)
            prob_sum = masked_probs.sum(dim=-1, keepdim=True)  # (K, 1)
            dead = prob_sum == 0.0
            if dead.any():
                # The fallback must be a vocab-sized vector so it broadcasts
                # over the selected rows.
                uniform_prob = trans_prob.new_zeros(trans_prob.size(-1))
                uniform_prob[:mask] = 1.0 / mask
                masked_probs = torch.where(dead, uniform_prob, masked_probs)
                prob_sum = masked_probs.sum(dim=-1, keepdim=True)
            trans_prob[mask_pos] = masked_probs / prob_sum

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        if not is_last_step:
            # ——— gap-wise insertion: compute new length, fill masks, scatter tokens ———
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # only gaps inside the current sequence may receive an insertion
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # rows that would overflow max_length reject all their insertions
            fits = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * fits.view(batch_size, 1).long()
            # count insertions AFTER rejection so rejected rows keep their length
            total_ext = ext.sum(dim=1)

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1) shift of each original position
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            # no insertion happens on the final step
            ext = None
            xt_tmp = new_xt

        if return_trace:
            # Move everything to Python lists once instead of one .item() per cell.
            changed = ((xt != pad) & (xt != new_xt)).tolist()
            new_tokens = new_xt.tolist()
            inserted = ext[:, :max_length].tolist() if ext is not None else None
            times = t.tolist()

            for batch_idx in range(batch_size):
                events = sampling_trace[batch_idx]

                # Check if the token was changed
                for j in range(max_length):
                    if changed[batch_idx][j]:
                        events.append(
                            SamplingTraceDatapoint(
                                t=times[batch_idx],
                                event_type="change",
                                position=j,
                                token=new_tokens[batch_idx][j],
                            )
                        )

                if inserted is None:
                    continue

                # Check if a new token was inserted (highest gap first)
                for gap in range(max_length - 1, -1, -1):
                    if inserted[batch_idx][gap]:
                        events.append(
                            SamplingTraceDatapoint(
                                t=times[batch_idx],
                                event_type="insertion",
                                position=gap,
                                token=mask,
                            )
                        )

        xt = xt_tmp
        t = t + dt

    return xt, sampling_trace
249
+
250
@torch.no_grad()
def batch_mcts_reverse_step(
    xt: torch.Tensor,
    t: torch.Tensor,
    dt: float,
    model: torch.nn.Module,
    pretrained: torch.nn.Module,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    last_step: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run one reverse step on ``batch_size`` copies of a single state.

    The input sequence is tiled ``batch_size`` times along the batch axis,
    the time is broadcast to match, and the step itself is delegated to
    ``mcts_reverse_step`` so both entry points share one implementation.

    Args:
        xt: Current state; tiled with ``xt.repeat(batch_size, 1)``.
            Presumably a single sequence of shape (1, L) or (L,) — verify
            against callers.
        t: Current time; squeezed and expanded to shape (batch_size,).
        dt: Euler step size.
        model: Policy model, forwarded unchanged.
        pretrained: Reference model, forwarded unchanged.
        mask: Mask token id.
        pad: Pad token id.
        batch_size: Number of copies (children) to sample.
        max_length: Maximum (padded) sequence length.
        last_step: Forwarded unchanged.
        temperature: Forwarded unchanged.

    Returns:
        The 4-tuple returned by ``mcts_reverse_step``:
        ``(next_xt, log_rnd, log_policy_step, log_pretrained_step)``.
    """
    tiled_xt = xt.repeat(batch_size, 1)
    # squeeze to remove extra dimensions, then expand to batch_size
    tiled_t = t.squeeze().expand(batch_size)

    return mcts_reverse_step(
        xt=tiled_xt,
        t=tiled_t,
        dt=dt,
        model=model,
        pretrained=pretrained,
        mask=mask,
        pad=pad,
        max_length=max_length,
        last_step=last_step,
        temperature=temperature,
    )
393
+
394
+
395
@torch.no_grad()
def mcts_reverse_step(
    xt: torch.Tensor,
    t: torch.Tensor,
    dt: float,
    model: torch.nn.Module,
    pretrained: torch.nn.Module,
    mask: int,
    pad: int,
    max_length: int,
    last_step: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run one Euler reverse step and score it under policy and reference.

    Samples the next state from ``model`` (unmasking followed, unless
    ``last_step``, by mask insertion) and accumulates per-sequence
    log-weights of the sampled transition under both ``model`` and
    ``pretrained``.

    Args:
        xt: (B, L) current token ids; pads are assumed to sit at the tail.
        t: Current time, passed to both models.
        dt: Euler step size.
        model: Policy; ``model(xt, t)`` followed by
            ``model.interpolant.to_actual_rate`` must yield ``unmask_rate``
            (B, L, V) and ``length_rate`` (B, L+1). ``unmask_rate`` is
            modified in place.
        pretrained: Reference model with the same interface; its
            ``unmask_rate`` is cloned before modification.
        mask: Mask token id.
        pad: Pad token id.
        max_length: Maximum (padded) sequence length L.
        last_step: When True the mask token is excluded from sampling and
            no insertion is performed.
        temperature: Currently unused by this function.

    Returns:
        ``(next_xt, log_rnd, log_policy_step, log_pretrained_step)`` where
        ``next_xt`` is (B, L) and the other three are (B,), with
        ``log_rnd = log_pretrained_step - log_policy_step``.
    """
    device = next(model.parameters()).device

    batch_size = xt.size(0)

    # precompute row/column indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )

    # ——— predict and convert rates ———
    pred_rate = model(xt, t)
    pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
    unmask_rate = pred_rate.unmask_rate  # (B, L, V)
    len_rate = pred_rate.length_rate  # (B, L+1)

    # ——— get pretrained model rates for log_rnd computation ———
    pretrained_pred = pretrained(xt, t)
    pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
    pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
    pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

    # ——— unmask step (Euler) ———
    mask_pos = (xt == mask).nonzero(as_tuple=True)
    unmask_rate[xt != mask] = 0
    unmask_rate[mask_pos + (mask,)] = 0
    unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

    # same for pretrained
    pretrained_unmask_rate[xt != mask] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

    # add "stay" probability (pad positions are routed to the mask column)
    _xt = xt.clone()
    _xt[xt == pad] = mask
    stay_index = _xt.unsqueeze(-1)
    trans_prob.scatter_add_(
        2,
        stay_index,
        torch.ones_like(stay_index, dtype=trans_prob.dtype),
    )
    pretrained_trans_prob.scatter_add_(
        2,
        stay_index,
        torch.ones_like(stay_index, dtype=pretrained_trans_prob.dtype),
    )

    if last_step:
        print("Final step, removing mask token from sampling")
        trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

        # renormalize masked rows; rows with no mass left fall back to a
        # uniform distribution over token ids 0..mask-1
        masked_probs = trans_prob[mask_pos]  # (K, V)
        prob_sum = masked_probs.sum(dim=-1, keepdim=True)  # (K, 1)
        dead = prob_sum == 0.0
        if dead.any():
            # The fallback must be a vocab-sized vector so it broadcasts over
            # the selected rows.
            uniform_prob = trans_prob.new_zeros(trans_prob.size(-1))
            uniform_prob[:mask] = 1.0 / mask
            masked_probs = torch.where(dead, uniform_prob, masked_probs)
            prob_sum = masked_probs.sum(dim=-1, keepdim=True)
        trans_prob[mask_pos] = masked_probs / prob_sum
        # NOTE(review): pretrained_trans_prob is not adjusted on the last
        # step, so lp_pre below uses the unnormalised weights — confirm intended.

    new_xt = _sample_tokens(trans_prob)
    new_xt[xt == pad] = pad
    new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

    # ——— compute log probabilities for RND ———
    lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
    lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

    # only positions that were unmasked in this step contribute
    changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

    log_policy_step = (lp * changed_mask).sum(dim=1)
    log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

    log_rnd = log_pretrained_step - log_policy_step  # (B,)

    if not last_step:
        # ——— gap-wise insertion: compute new length, fill masks, scatter tokens ———
        ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)

        insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
        pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

        # log P(ext; λ) = ext*log(λ) - λ
        log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1)  # (B,)
        log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)  # (B,)

        log_insert_diff = log_pretrained_insert - log_policy_insert  # (B,)
        log_rnd += log_insert_diff
        log_pretrained_step += log_pretrained_insert
        log_policy_step += log_policy_insert

        xt_len = xt.ne(pad).sum(dim=1)  # (B,)
        seq_dim = ext.size(1)  # Use actual ext dimension to avoid mismatch
        gaps = torch.arange(seq_dim, device=device).view(1, -1)
        # only gaps inside the current sequence may receive an insertion
        ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
        # rows that would overflow max_length reject all their insertions
        fits = xt_len + ext.sum(dim=1) <= max_length
        ext = ext * fits.view(batch_size, 1).long()
        # count insertions AFTER rejection so rejected rows keep their length
        total_ext = ext.sum(dim=1)

        ext_ex = ext.int().cumsum(dim=1)  # (B, L+1) shift of each original position
        new_len = xt_len + total_ext  # (B,)

        xt_tmp = torch.full_like(xt, pad)
        mask_fill = pos_idx_L < new_len.view(batch_size, 1)
        xt_tmp[mask_fill] = mask

        new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
        orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
        flat_b = batch_idx_L[orig_mask]
        flat_p = new_pos_orig[orig_mask]
        xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
    else:
        xt_tmp = new_xt

    return xt_tmp, log_rnd, log_policy_step, log_pretrained_step
535
+
536
@torch.no_grad()
def any_order_euler_sampling_with_schedule(
    model: torch.nn.Module,
    time_schedule: torch.Tensor,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, Optional[list[list[SamplingTraceDatapoint]]]]:
    """Any-order unmask/insert Euler sampler driven by an explicit time grid.

    Same procedure as the uniform-step sampler, but step ``i`` is evaluated
    at ``time_schedule[i]`` with step size
    ``|time_schedule[i] - time_schedule[i + 1]|``.

    Args:
        model: Called as ``model(xt, t)``; ``model.interpolant.to_actual_rate``
            must return an object with ``unmask_rate`` (B, L, V) and
            ``length_rate`` (B, L+1). ``unmask_rate`` is modified in place.
        time_schedule: 1-D tensor of time points; ``len - 1`` steps are run.
            If it is ascending it is flipped to descending order first.
        mask: Mask token id.
        pad: Pad token id; pads are assumed to sit at the tail of each row.
        batch_size: Number of sequences to sample.
        max_length: Maximum (padded) sequence length L.
        return_trace: Record per-sequence change/insertion events.
        temperature: When not 1.0, per-position transition weights are
            re-scaled as ``softmax(log(p + 1e-10) / temperature)``.

    Returns:
        Tuple ``(xt, trace)``: ``xt`` is the (batch_size, max_length) int64
        result; ``trace`` is a list with one event list per sequence, or
        ``None`` when ``return_trace`` is False.
    """
    device = next(model.parameters()).device

    time_schedule = time_schedule.to(device)
    # NOTE(review): the schedule is forced into descending order, whereas the
    # uniform-step samplers in this file advance t upwards from 0 — confirm
    # which time convention callers of this function rely on.
    if time_schedule[0] < time_schedule[-1]:
        time_schedule = torch.flip(time_schedule, [0])  # descending order

    steps = len(time_schedule) - 1

    # initialize all-pad sequence
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    # precompute row/column indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        is_last_step = i == steps - 1

        # use scheduled timesteps
        t = time_schedule[i].repeat(batch_size)
        t_next = time_schedule[i + 1]
        dt = (t - t_next).abs()  # (B,) timestep difference

        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt[:, None, None]).clamp(0.0, 1.0)

        # add "stay" probability (pad positions are routed to the mask column)
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        # Apply temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if is_last_step:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

            # renormalize masked rows; rows with no mass left fall back to a
            # uniform distribution over token ids 0..mask-1
            masked_probs = trans_prob[mask_pos]  # (K, V)
            prob_sum = masked_probs.sum(dim=-1, keepdim=True)  # (K, 1)
            dead = prob_sum == 0.0
            if dead.any():
                # The fallback must be a vocab-sized vector so it broadcasts
                # over the selected rows.
                uniform_prob = trans_prob.new_zeros(trans_prob.size(-1))
                uniform_prob[:mask] = 1.0 / mask
                masked_probs = torch.where(dead, uniform_prob, masked_probs)
                prob_sum = masked_probs.sum(dim=-1, keepdim=True)
            trans_prob[mask_pos] = masked_probs / prob_sum

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        if not is_last_step:
            # ——— gap-wise insertion: compute new length, fill masks, scatter tokens ———
            ext = torch.bernoulli((len_rate * dt[:, None]).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # only gaps inside the current sequence may receive an insertion
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # rows that would overflow max_length reject all their insertions
            fits = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * fits.view(batch_size, 1).long()
            # count insertions AFTER rejection so rejected rows keep their length
            total_ext = ext.sum(dim=1)

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1) shift of each original position
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            # no insertion happens on the final step
            ext = None
            xt_tmp = new_xt

        if return_trace:
            # Move everything to Python lists once instead of one .item() per cell.
            changed = ((xt != pad) & (xt != new_xt)).tolist()
            new_tokens = new_xt.tolist()
            inserted = ext[:, :max_length].tolist() if ext is not None else None
            times = t.tolist()

            for batch_idx in range(batch_size):
                events = sampling_trace[batch_idx]

                # Check if the token was changed
                for j in range(max_length):
                    if changed[batch_idx][j]:
                        events.append(
                            SamplingTraceDatapoint(
                                t=times[batch_idx],
                                event_type="change",
                                position=j,
                                token=new_tokens[batch_idx][j],
                            )
                        )

                if inserted is None:
                    continue

                # Check if a new token was inserted (highest gap first)
                for gap in range(max_length - 1, -1, -1):
                    if inserted[batch_idx][gap]:
                        events.append(
                            SamplingTraceDatapoint(
                                t=times[batch_idx],
                                event_type="insertion",
                                position=gap,
                                token=mask,
                            )
                        )

        xt = xt_tmp

    return xt, sampling_trace
677
+
678
+
679
+ @torch.no_grad()
680
+ def any_order_mask_insertion_euler_sampling_with_rnd(
681
+ model, pretrained, reward_model, analyzer,
682
+ tokenizer, steps,
683
+ mask,
684
+ pad,
685
+ batch_size,
686
+ max_length,
687
+ return_trace = False,
688
+ alpha = 0.1,
689
+ temperature: float = 1.0,
690
+ ):
691
+ device = next(model.parameters()).device
692
+
693
+ # initialize all‑pad sequence and trace
694
+ xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
695
+ sampling_trace = []
696
+
697
+ # initialize log_rnd to accumulate log probability ratios
698
+ log_rnd = torch.zeros(batch_size, device=device)
699
+
700
+ dt = 1.0 / steps
701
+ t = torch.zeros(batch_size, device=device)
702
+
703
+ # precompute row indices for scatter
704
+ batch_idx_L = (
705
+ torch.arange(batch_size, device=device)
706
+ .view(batch_size, 1)
707
+ .expand(batch_size, max_length)
708
+ )
709
+ pos_idx_L = (
710
+ torch.arange(max_length, device=device)
711
+ .view(1, max_length)
712
+ .expand(batch_size, max_length)
713
+ )
714
+ sampling_trace = [[] for _ in range(batch_size)] if return_trace else None
715
+
716
+ for i in range(steps):
717
+ # ——— predict and convert rates ———
718
+ pred_rate = model(xt, t)
719
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
720
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
721
+ len_rate = pred_rate.length_rate # (B, L+1)
722
+
723
+ # ——— get pretrained model rates for log_rnd computation ———
724
+ pretrained_pred = pretrained(xt, t)
725
+ pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
726
+ pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V)
727
+ pretrained_len_rate = pretrained_rate.length_rate # (B, L+1)
728
+
729
+ # ——— unmask step (Euler) ———
730
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
731
+ unmask_rate[xt != mask] = 0
732
+ unmask_rate[mask_pos + (mask,)] = 0
733
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
734
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
735
+
736
+ # Same for pretrained
737
+ pretrained_unmask_rate[xt != mask] = 0
738
+ pretrained_unmask_rate[mask_pos + (mask,)] = 0
739
+ pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
740
+ pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)
741
+
742
+ # add “stay” probability
743
+ _xt = xt.clone()
744
+ _xt[xt == pad] = mask
745
+ trans_prob.scatter_add_(
746
+ 2,
747
+ _xt.unsqueeze(-1),
748
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
749
+ )
750
+ pretrained_trans_prob.scatter_add_(
751
+ 2,
752
+ _xt.unsqueeze(-1),
753
+ torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
754
+ )
755
+
756
+ # Apply temperature scaling
757
+ if temperature != 1.0:
758
+ logits = torch.log(trans_prob + 1e-10) / temperature
759
+ trans_prob = torch.softmax(logits, dim=-1)
760
+
761
+ if i == steps - 1:
762
+ print("Final step, removing mask token from sampling")
763
+ trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
764
+
765
+ # renormalize probabilities to ensure they sum to 1
766
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
767
+ # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad)
768
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
769
+ if mask_has_zero_prob.any():
770
+ # create uniform distribution over valid tokens (excluding mask and pad)
771
+ uniform_prob = torch.zeros_like(trans_prob[0])
772
+ uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
773
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
774
+ else:
775
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
776
+
777
+ new_xt = _sample_tokens(trans_prob)
778
+ new_xt[xt == pad] = pad
779
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
780
+
781
+ # ——— compute log probabilities for RND ———
782
+ lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
783
+ lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
784
+
785
+ changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)
786
+
787
+ log_policy_step = (lp * changed_mask).sum(dim=1)
788
+ log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
789
+
790
+ log_rnd = log_pretrained_step - log_policy_step # (B,)
791
+
792
+ if i != steps - 1:
793
+ ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1)
794
+
795
+ insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1)
796
+ pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1)
797
+
798
+ log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1) # (B,)
799
+ log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1) # (B,)
800
+
801
+ log_insert_diff = log_pretrained_insert - log_policy_insert # (B,)
802
+ log_rnd += log_insert_diff
803
+
804
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
805
+ gaps = torch.arange(max_length + 1, device=device).view(1, -1)
806
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
807
+ total_ext = ext.sum(dim=1)
808
+ valid = xt_len + total_ext <= max_length
809
+ ext = ext * valid.view(batch_size, 1).long()
810
+
811
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
812
+ new_len = xt_len + total_ext # (B,)
813
+
814
+ xt_tmp = torch.full_like(xt, pad)
815
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
816
+ xt_tmp[mask_fill] = mask
817
+
818
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
819
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
820
+ flat_b = batch_idx_L[orig_mask]
821
+ flat_p = new_pos_orig[orig_mask]
822
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
823
+ else:
824
+ xt_tmp = new_xt
825
+
826
+ if return_trace:
827
+ # check if the token was changed
828
+ for i in range(batch_size):
829
+ for j in range(max_length):
830
+ if xt[i, j] != pad and xt[i, j] != new_xt[i, j]:
831
+ sampling_trace[i].append(
832
+ SamplingTraceDatapoint(
833
+ t=t[i].item(),
834
+ event_type="change",
835
+ position=j,
836
+ token=new_xt[i, j].item(),
837
+ )
838
+ )
839
+
840
+ # check if a new token was inserted
841
+ for j in range(max_length):
842
+ id = max_length - j - 1
843
+ if ext[i, id]:
844
+ sampling_trace[i].append(
845
+ SamplingTraceDatapoint(
846
+ t=t[i].item(),
847
+ event_type="insertion",
848
+ position=id,
849
+ token=mask,
850
+ )
851
+ )
852
+
853
+ xt = xt_tmp
854
+ t = t + dt
855
+
856
+ # change rewards for peptides
857
+ samples = xt.to(device)
858
+
859
+ # store raw token IDs
860
+ # Decode and strip samples
861
+ decoded_samples = tokenizer.batch_decode(samples)
862
+
863
+ valid_x_final = []
864
+ validSequences = []
865
+ valid_log_rnd = []
866
+
867
+ for idx, seq in enumerate(decoded_samples):
868
+ # check if the peptide is valid
869
+ if analyzer.is_peptide(seq):
870
+ valid_x_final.append(xt[idx])
871
+ validSequences.append(seq)
872
+ valid_log_rnd.append(log_rnd[idx])
873
+
874
+ print("len valid sequences:", len(validSequences))
875
+ # compute multi-objective rewards
876
+ score_vectors = reward_model(input_seqs=validSequences)
877
+ scalar_rewards = np.sum(score_vectors, axis=-1)
878
+ scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device)
879
+
880
+ print(f"scalar reward dim{len(scalar_rewards)}")
881
+ valid_log_rnd = torch.stack(valid_log_rnd, dim=0)
882
+
883
+ log_rnd = valid_log_rnd + (scalar_rewards / alpha) # scale down by alpha
884
+ valid_x_final = torch.stack(valid_x_final, dim=0)
885
+
886
+ return valid_x_final, log_rnd, scalar_rewards, sampling_trace
887
+
888
@torch.no_grad()
def any_order_finetuned_euler_sampler(
    model, reward_model, analyzer,
    tokenizer, steps,
    mask,
    pad,
    batch_size,
    max_length,
    return_trace = False,
    dataframe = False,
    temperature: float = 1.0,
):
    """Sample sequences with the Euler any-order insertion sampler and score them.

    Starting from an all-``pad`` canvas, each of the ``steps`` iterations
    (1) converts the model's unmasking rates into per-position transition
    probabilities and samples tokens for the masked positions, and (2) on all
    but the last iteration draws Bernoulli insertions of new ``mask`` tokens
    into the gaps of every sequence. On the last iteration the ``mask`` token
    is removed from the sampling distribution. The final token ids are decoded
    with ``tokenizer.batch_decode``, filtered with ``analyzer.is_peptide`` and
    scored with ``reward_model``.

    Args:
        model: Module called as ``model(xt, t)``; must expose
            ``interpolant.to_actual_rate(xt, pred, t)`` returning an object
            with ``unmask_rate`` and ``length_rate`` attributes.
        reward_model: Callable invoked as ``reward_model(input_seqs=...)``; its
            result is transposed and rows 0-4 are read as affinity,
            solubility, hemolysis, nonfouling and permeability.
        analyzer: Object providing ``is_peptide(seq)``.
        tokenizer: Object providing ``batch_decode(token_ids)``.
        steps: Number of Euler steps (``dt = 1 / steps``).
        mask: Token id of the mask token. The zero-probability fallback
            assumes all regular tokens have ids below ``mask``.
        pad: Token id of the padding token.
        batch_size: Number of sequences sampled in parallel.
        max_length: Width of the token canvas.
        return_trace: When true, per-sequence change/insertion events are
            collected while sampling (the trace is not part of the return
            value, matching the original behaviour).
        dataframe: When true, additionally return a ``pandas.DataFrame`` with
            one row per valid sequence.
        temperature: Softmax temperature applied to the transition
            probabilities when different from ``1.0``.

    Returns:
        ``(samples, affinity, sol, hemo, nf, permeability, valid_fraction)``
        and, if ``dataframe`` is true, the DataFrame as an eighth element.
        When no decoded sequence is valid the five score entries are
        ``[0.0]``.
    """
    device = next(model.parameters()).device

    # Start from an all-pad canvas; sequences grow through mask insertions.
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Precomputed (batch, position) index grids used to scatter tokens.
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )

    for i in range(steps):
        is_last_step = i == steps - 1

        # --- predict and convert rates ---
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V) per the original annotations
        len_rate = pred_rate.length_rate  # (B, L+1) per the original annotations

        # --- unmask step (Euler) ---
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # Add the "stay" probability; pad positions are routed to the mask
        # column and restored to pad after sampling.
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        # Apply temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if is_last_step:
            print("Final step, removing mask token from sampling")
            # Remove the mask token from sampling so no mask survives.
            trans_prob[mask_pos + (mask,)] = 0.0

            # Renormalise every masked row. Rows whose mass is exactly zero
            # fall back to a uniform distribution over token ids below `mask`.
            # FIX: the fallback is a (V,) vector (the original built an (L, V)
            # tensor and sliced the wrong axis), and non-zero rows are now
            # normalised even when some other row needs the fallback.
            rows = trans_prob[mask_pos]  # (K, V)
            prob_sum = rows.sum(dim=-1, keepdim=True)
            zero_rows = prob_sum == 0.0  # (K, 1)
            uniform_prob = trans_prob.new_zeros(trans_prob.shape[-1])
            uniform_prob[:mask] = 1.0 / mask  # Uniform over tokens 0 to mask-1
            safe_sum = prob_sum.masked_fill(zero_rows, 1.0)
            trans_prob[mask_pos] = torch.where(zero_rows, uniform_prob, rows / safe_sum)

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        # Tokens that are already decoded never change.
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        if not is_last_step:
            # Gap-wise insertion: at most one new mask per gap and step.
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # Only gaps inside (or right after) the current sequence are valid.
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # Reject all insertions of a row that would overflow the canvas.
            valid = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * valid.view(batch_size, 1).long()
            # FIX: count insertions *after* rejection. The original reused the
            # pre-rejection total, so rejected rows were still extended with
            # trailing mask tokens up to max_length.
            # NOTE(review): sibling samplers in this file use the same
            # pattern and should be aligned with this fix.
            total_ext = ext.sum(dim=1)

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            new_len = xt_len + total_ext  # (B,)

            # Fill the new extent with masks, then scatter the shifted tokens.
            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            # FIX: no insertions happen on the last step. Without this the
            # trace below read the previous step's `ext` (or raised NameError
            # when steps == 1).
            ext = torch.zeros((batch_size, max_length + 1), dtype=torch.long, device=device)
            xt_tmp = new_xt

        if return_trace:
            # Convert once to Python lists instead of indexing tensors
            # element by element.
            old_rows = xt.tolist()
            new_rows = new_xt.tolist()
            ext_rows = ext.tolist()
            times = t.tolist()
            for batch_idx in range(batch_size):
                # Token changes, in ascending position order.
                for j in range(max_length):
                    if old_rows[batch_idx][j] != pad and old_rows[batch_idx][j] != new_rows[batch_idx][j]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=times[batch_idx],
                                event_type="change",
                                position=j,
                                token=new_rows[batch_idx][j],
                            )
                        )

                # Insertions, from the right-most gap to the left-most.
                for gap in range(max_length - 1, -1, -1):
                    if ext_rows[batch_idx][gap]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=times[batch_idx],
                                event_type="insertion",
                                position=gap,
                                token=mask,
                            )
                        )

        xt = xt_tmp
        t = t + dt

    # --- evaluation ---
    samples = xt.to(device)
    decoded_samples = tokenizer.batch_decode(samples)
    validSequences = [seq for seq in decoded_samples if analyzer.is_peptide(seq)]

    print("len valid sequences:", len(validSequences))
    valid_fraction = len(validSequences) / batch_size

    if validSequences:
        score_vectors = reward_model(input_seqs=validSequences)  # (num_children, num_objectives)
        average_scores = score_vectors.T

        affinity = average_scores[0]
        sol = average_scores[1]
        hemo = average_scores[2]
        nf = average_scores[3]
        permeability = average_scores[4]
    else:
        zeros = [0.0]
        affinity = sol = hemo = nf = permeability = zeros

    if dataframe:
        # FIX: with no valid sequence the original paired an empty sequence
        # column with length-1 score columns, which pandas rejects. Use empty
        # score columns so the frame has one row per valid sequence.
        has_valid = bool(validSequences)
        df = pd.DataFrame({
            "Peptide Sequence": validSequences,
            "Binding Affinity": affinity if has_valid else [],
            "Solubility": sol if has_valid else [],
            "Hemolysis": hemo if has_valid else [],
            "Nonfouling": nf if has_valid else [],
            "Permeability": permeability if has_valid else [],
        })
        return samples, affinity, sol, hemo, nf, permeability, valid_fraction, df

    return samples, affinity, sol, hemo, nf, permeability, valid_fraction
1074
+
1075
@torch.no_grad()
def mdm_tau_leaping_sampling(
    model: MaskedDiffusionModule,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
):
    """Tau-leaping sampler for a fixed-length masked diffusion model.

    The canvas starts fully masked. On every step but the last, Poisson event
    counts are drawn from ``unmask_rate * dt`` and a masked position is
    decoded only when exactly one event fired for it. On the last step every
    remaining masked position takes the arg-max of its unmask rate.

    ``pad`` and ``temperature`` are accepted but not read by this sampler, and
    tracing is not supported.

    Returns:
        ``(xt, [])`` - the final token ids and an empty trace list.
    """
    assert not return_trace, "Trace is not yet supported"
    device = next(model.parameters()).device
    dt = 1.0 / steps
    xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)
    t = torch.zeros(batch_size, device=device)

    for step in range(steps):
        # --- predict and convert rates ---
        pred = model(xt, t)
        pred = model.interpolant.to_actual_rate(xt, pred, t)
        unmask_rate = pred.unmask_rate  # (B, L, V) per the original annotations

        if step == steps - 1:
            # Last step: deterministic unmask via argmax; decoded tokens stay.
            still_masked = xt == mask  # (B, L)
            proposal = unmask_rate.argmax(dim=2)  # (B, L)
            xt = torch.where(still_masked, proposal, xt)
        else:
            # Tau-leaping via Poisson counts.
            counts = torch.poisson(unmask_rate * dt).long()
            still_masked = xt == mask  # (B, L)
            # Ignore events at decoded positions and mask -> mask events.
            counts[~still_masked.unsqueeze(-1).expand_as(counts)] = 0
            counts[..., mask] = 0
            # Accept a jump only when exactly one event fired; such positions
            # are necessarily masked because all other counts were zeroed.
            exactly_one = counts.sum(dim=2) == 1  # (B, L)
            proposal = counts.argmax(dim=2)  # (B, L)
            xt = torch.where(exactly_one, proposal, xt)

        t = t + dt

    return xt, []
1128
+
1129
+ # Not used in production, for debugging purposes
1130
+ lengths = {4: 0.1, 16: 0.4, 32: 0.4, 64: 0.1}
1131
+
1132
def binomial_mass(k, n, p):
    """
    Calculate the probability mass function (PMF) for a binomial distribution.

    Args:
        k (int): Number of successes
        n (int): Number of trials
        p (float): Probability of success in a single trial

    Returns:
        float: Probability mass P(X = k); 0.0 when ``k`` is outside ``[0, n]``
        or ``n`` is negative.
    """
    import math

    # math.comb gives the exact binomial coefficient without building three
    # full factorials; it raises ValueError for negative arguments.
    try:
        binom_coef = math.comb(n, k)
    except ValueError:
        return 0.0

    # math.comb returns 0 when k > n. Return early so the negative exponent
    # (n - k) below is never evaluated.
    if binom_coef == 0:
        return 0.0

    # float() keeps the result a float even for integer p.
    return float(binom_coef) * (p ** k) * ((1 - p) ** (n - k))
1155
+
1156
def calculate_rate_batch(alpha_t, len_t):
    """
    Vectorised rate computation over a batch of (alpha_t, len_t) pairs.

    For every candidate ``length`` in the module-level ``lengths`` table the
    binomial probability of observing ``len_t`` successes out of ``length``
    trials (success probability ``alpha_t``) weights both the remaining
    length ``length - len_t`` (numerator) and the prior (denominator).

    Args:
        alpha_t (torch.Tensor): Tensor of shape (batch_size,)
        len_t (torch.Tensor): Tensor of shape (batch_size,)

    Returns:
        torch.Tensor: Tensor of shape (batch_size,) with numerator /
        denominator, and 0 wherever the denominator is not positive.
    """
    numerator = torch.zeros(alpha_t.shape[0], device=alpha_t.device)
    denominator = torch.zeros_like(numerator)

    for length, probability in lengths.items():
        # Entries contribute only when 0 <= len_t <= length.
        in_range = (len_t >= 0) & (len_t <= length)
        if not in_range.any():
            continue

        cur_len = len_t[in_range]
        cur_alpha = alpha_t[in_range]

        # Binomial PMF via torch.distributions (log-prob, then exp).
        pmf = (
            torch.distributions.Binomial(total_count=length, probs=cur_alpha)
            .log_prob(cur_len)
            .exp()
        )

        numerator[in_range] += (length - cur_len) * probability * pmf
        denominator[in_range] += probability * pmf

    # Guard the division: entries with a non-positive denominator yield 0.
    positive = denominator > 0
    return torch.where(positive, numerator / denominator, torch.zeros_like(numerator))
1199
+
1200
+ # Keep the original function for backward compatibility
1201
def calculate_rate(alpha_t, len_t):
    """Scalar rate computation; delegates to the batch version for tensors.

    If ``alpha_t`` is a tensor with at least one dimension, the call is
    forwarded to ``calculate_rate_batch``. Otherwise the rate is accumulated
    in plain Python over the module-level ``lengths`` table, returning 0.0
    when the denominator is zero.
    """
    if isinstance(alpha_t, torch.Tensor) and alpha_t.ndim > 0:
        return calculate_rate_batch(alpha_t, len_t)

    nom, denom = 0, 0
    for length, probability in lengths.items():
        # Lengths shorter than the current length cannot contribute.
        if not (length >= len_t):
            continue
        mass = binomial_mass(len_t, length, alpha_t)
        nom += (length - len_t) * probability * mass
        denom += probability * mass

    return 0.0 if denom == 0 else nom / denom
1216
+
1217
+
1218
@torch.no_grad()
def any_order_mask_insertion_tau_leaping_sampling(
    model: torch.nn.Module,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    confidence_based_sampling: bool = True,  # whether to use confidence-based decoding
    alpha: float = 5.0,  # hyperparameter for window size calculation
    max_window: int = 32,  # Maximum window size for sliding window
    confidence_method: str = "prob_diff",  # "position", "top_prob", "prob_diff", "entropy"
    use_sliding_window: bool = False,  # whether to use sliding window for position selection
    temperature: float = 1.0,
) -> SamplingResult:
    """Tau-leaping sampler with mask insertion and optional confidence decoding.

    The canvas starts as all ``pad``. Every step except the last one
    (1) unmasks tokens - either the most confident masked positions
    (``confidence_based_sampling``) or by Poisson tau-leaping on the unmask
    rates - and (2) inserts a Poisson-distributed number of new ``mask``
    tokens into each gap. The last step decodes every remaining mask with the
    arg-max of its unmask rate and performs no insertion.

    Args:
        model: Module called as ``model(xt, t)``; must expose
            ``interpolant.to_actual_rate(xt, pred, t)`` returning an object
            with ``unmask_rate`` and ``length_rate``.
        steps: Number of steps (``dt = 1 / steps``).
        mask: Token id of the mask token.
        pad: Token id of the padding token.
        batch_size: Number of sequences sampled in parallel.
        max_length: Width of the token canvas.
        return_trace: When true, the canvas after every non-final step is
            appended to the returned trace.
        confidence_based_sampling: Use confidence-ranked unmasking instead of
            tau-leaping unmasking.
        alpha: Window-size multiplier (window = ``alpha * unmask_count``,
            capped by ``max_window`` and the number of masks).
        max_window: Upper bound on the sliding-window size.
        confidence_method: One of ``"position"``, ``"top_prob"``,
            ``"prob_diff"``, ``"entropy"``.
        use_sliding_window: Restrict candidates to the first ``k`` masked
            positions of each sequence.
        temperature: Accepted for interface parity; not read by this sampler.

    Returns:
        ``(xt, sampling_trace)``.

    Raises:
        ValueError: If ``confidence_method`` is unknown (raised on the first
            confidence-based step).
    """
    device = next(model.parameters()).device
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
    sampling_trace = []
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Precomputed (batch, position) index grids used to scatter tokens.
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )

    for i in range(steps):
        # --- predict rates ---
        pred = model(xt, t)
        pred = model.interpolant.to_actual_rate(xt, pred, t)
        unmask_rate = pred.unmask_rate  # (B, L, V) per the original annotations
        len_rate = pred.length_rate  # (B, L+1) per the original annotations

        if i == steps - 1:
            # Last step: deterministic unmask via argmax, no insertion.
            # NOTE(review): as in the original, this final state is not
            # appended to the trace.
            xt = torch.where(xt == mask, unmask_rate.argmax(dim=2), xt)
            t = t + dt
            continue

        # `> 0.0` is kept from the original (the flag is annotated as bool).
        if confidence_based_sampling > 0.0:
            # --- confidence-based decoding (vectorized) ---
            mask_positions = xt == mask  # (B, L)
            num_mask_positions = mask_positions.sum(dim=1)  # (B,)

            # 1. Number of tokens to unmask, drawn from a Poisson.
            unmask_counts = torch.poisson(num_mask_positions.float() * dt).long()  # (B,)

            # 2. Confidence score per position.
            if confidence_method == "position":
                # Earlier positions are more confident: 1 - pos / len.
                cur_len = (xt != pad).sum(dim=1)  # (B,)
                confidence = 1.0 - (pos_idx_L.float() / cur_len.unsqueeze(1).float().clamp(min=1))
            elif confidence_method in ("top_prob", "prob_diff", "entropy"):
                # The unmask rates are used as logits.
                unmask_probs = torch.softmax(unmask_rate, dim=-1)  # (B, L, V)
                if confidence_method == "top_prob":
                    confidence = unmask_probs.max(dim=-1)[0]
                elif confidence_method == "prob_diff":
                    # Margin between the two most likely tokens.
                    top2_probs, _ = torch.topk(unmask_probs, k=2, dim=-1)  # (B, L, 2)
                    confidence = top2_probs[:, :, 0] - top2_probs[:, :, 1]
                else:
                    # Negative entropy: lower entropy means higher confidence.
                    entropy = -torch.sum(unmask_probs * torch.log(unmask_probs + 1e-10), dim=-1)
                    confidence = -entropy
            else:
                raise ValueError(f"Unknown confidence_method: {confidence_method}")

            # 3. Eligible positions: masks, optionally only the first k masks.
            if use_sliding_window:
                k_values = torch.minimum(
                    torch.minimum(
                        (alpha * unmask_counts).long(),
                        torch.tensor(max_window, device=device)
                    ), num_mask_positions)  # (B,)
                mask_cumsum = mask_positions.cumsum(dim=1)  # (B, L)
                eligible = mask_positions & (mask_cumsum <= k_values.unsqueeze(1))
            else:
                eligible = mask_positions
            confidence = torch.where(eligible, confidence, torch.tensor(-float('inf'), device=device))

            # FIX: never unmask more positions than are eligible. The Poisson
            # draw can exceed that number (or max_length), in which case the
            # original either crashed in topk (k > L) or selected -inf
            # positions and overwrote pads / already decoded tokens.
            unmask_counts = torch.minimum(unmask_counts, eligible.sum(dim=1))

            new_xt = xt.clone()

            max_unmask = int(unmask_counts.max().item())
            if max_unmask > 0:
                _, all_top_indices = torch.topk(confidence, k=max_unmask, dim=1, largest=True)  # (B, max_unmask)

                # Row b keeps only its first unmask_counts[b] candidates.
                unmask_mask = torch.arange(max_unmask, device=device).unsqueeze(0) < unmask_counts.unsqueeze(1)

                most_likely_tokens = unmask_rate.argmax(dim=-1)  # (B, L)

                selected_positions = all_top_indices[unmask_mask]
                batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, max_unmask)[unmask_mask]

                new_xt[batch_indices, selected_positions] = most_likely_tokens[batch_indices, selected_positions]
        else:
            # --- tau-leaping unmask via Poisson ---
            counts = torch.poisson(unmask_rate * dt).long()
            mask_pos = xt == mask
            counts[~mask_pos.unsqueeze(-1).expand_as(counts)] = 0
            counts[..., mask] = 0
            # Accept a jump only when exactly one event fired.
            one_event = counts.sum(dim=2) == 1
            new_token = counts.argmax(dim=2)
            new_xt = xt.clone()
            new_xt[one_event] = new_token[one_event]

        # Pads and already decoded tokens are never modified by unmasking.
        new_xt = torch.where(xt != mask, xt, new_xt)

        # --- Poisson insertion, compute new lengths and fill masks ---
        # (Only reached on non-final steps; the final step `continue`s above.)
        ext = torch.poisson(len_rate * dt).long()  # (B, L+1)
        xt_len = xt.ne(pad).sum(dim=1)  # (B,)
        gaps = torch.arange(max_length + 1, device=device).view(1, -1)
        # Only gaps inside (or right after) the current sequence are valid.
        ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
        # Reject all insertions of a row that would overflow the canvas.
        valid = xt_len + ext.sum(dim=1) <= max_length
        ext = ext * valid.view(batch_size, 1).long()
        # FIX: count insertions *after* rejection. The original reused the
        # pre-rejection total, so rejected rows were still extended with
        # trailing mask tokens up to max_length.
        total_ext = ext.sum(dim=1)

        # Prefix sums of insertions give the shift of every original token.
        ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
        new_len = xt_len + total_ext  # (B,)

        # Initialize with pads, then fill masks up to new_len.
        xt_tmp = torch.full_like(xt, pad)
        xt_tmp[pos_idx_L < new_len.view(batch_size, 1)] = mask

        # Shift and scatter the original tokens.
        new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
        orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
        flat_b = batch_idx_L[orig_mask]
        flat_p = new_pos_orig[orig_mask]
        xt_tmp[flat_b, flat_p] = new_xt[orig_mask]

        xt = xt_tmp
        t = t + dt
        if return_trace:
            sampling_trace.append(xt)

    return xt, sampling_trace
a2d2_mol/scripts/run_mol_finetune.slurm ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# NOTE: --partition and --qos below are specific to our cluster. Change them
# (or remove them and pass `--partition` on the `sbatch` command line) to match
# the partitions/QOS available on yours.
#SBATCH --job-name=mol-finetune
#SBATCH --partition=dgx-b200
#SBATCH --nodes=1
#SBATCH --gpus-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --ntasks-per-node=1
#SBATCH --mem=80GB
#SBATCH --time=02-00:00:00
#SBATCH --output=logs/slurm-%A.%x.log

# =====================================================================
# run_mol_finetune.slurm
#
# Single-mode job (1 MIG GPU) running ONE finetune_mol experiment.
# Select which mode to run via the MODE_ID variable below (or override
# at submit time with `sbatch --export=ALL,MODE_ID=2 ...`):
#   0) A2D2 (Ours)                 - with full planner (alternating)
#   1) A2D2 w/o quality            - --disable_planner
#   2) A2D2 w/o insertion planner  - --disable_insertion_planner
#   3) A2D2 w/o unmasking planner  - --disable_unmasking_planner
#
# The job trains the selected mode then evaluates the resulting
# checkpoint on the same GPU.
# =====================================================================

# Abort on the first failing command. `pipefail` is deliberately NOT enabled:
# the checkpoint lookup below relies on `ls` being allowed to fail when one of
# its two glob patterns has no match.
set -e

# --- Mode selection ---------------------------------------------------
# Which experiment to run (0-3). Override with `--export=ALL,MODE_ID=N`.
MODE_ID="${MODE_ID:-0}"

# Run prefix: the SLURM job id, or a timestamp when run outside SLURM.
PREFIX="${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}"

# --- Paths ------------------------------------------------------------
# Repo root is resolved at submit time so the job runs from any clone:
#   - set A2D2_ROOT explicitly, OR
#   - submit with `sbatch` from the repo root (SLURM sets SLURM_SUBMIT_DIR;
#     note sbatch copies the script to a spool dir, so we can't rely on the
#     script's own path here), OR
#   - run the script directly, falling back to its location on disk.
if [ -n "${A2D2_ROOT:-}" ]; then
    HOME_LOC="$A2D2_ROOT"
elif [ -n "${SLURM_SUBMIT_DIR:-}" ]; then
    HOME_LOC="$SLURM_SUBMIT_DIR"
else
    # This script lives in a2d2_mol/scripts/, so the repo root is two levels up.
    HOME_LOC="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
fi
SCRIPT_LOC="$HOME_LOC/a2d2_mol"
LOG_LOC="$HOME_LOC/logs"
SAVE_DIR="$HOME_LOC/checkpoints/finetune_mol"
RESULTS_DIR="$HOME_LOC/results/mol_ablation"

mkdir -p "$LOG_LOC" "$SAVE_DIR" "$RESULTS_DIR"

# --- Environment setup ------------------------------------------------
# Set WANDB_API_KEY in your shell/secret store before submitting (do NOT commit it):
#   export WANDB_API_KEY=...   or   `wandb login`
export WANDB_DIR="$HOME_LOC/.wandb"
export WANDB_CONFIG_DIR="$HOME_LOC/.config/wandb"
export WANDB_CACHE_DIR="$HOME_LOC/.cache/wandb"
mkdir -p "$WANDB_DIR" "$WANDB_CONFIG_DIR" "$WANDB_CACHE_DIR"

export TRITON_CACHE_DIR="$HOME_LOC/.triton/cache"
mkdir -p "$TRITON_CACHE_DIR"

export TORCHINDUCTOR_CACHE_DIR="$HOME_LOC/.torchinductor/cache"
mkdir -p "$TORCHINDUCTOR_CACHE_DIR"

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Force unbuffered stdout/stderr so live training output is flushed to the
# redirected RUN_LOG (Python block-buffers stdout when it's a file, not a TTY).
export PYTHONUNBUFFERED=1

# Activate conda env. Override CONDA_ROOT to point at your conda/miniconda
# install, or just have `conda` on PATH; override CONDA_ENV if your env name
# differs from the one created by environment.yml.
CONDA_ENV="${CONDA_ENV:-a2d2}"
if [ -n "${CONDA_ROOT:-}" ]; then
    source "$CONDA_ROOT/bin/activate" "$CONDA_ENV"
elif command -v conda >/dev/null 2>&1; then
    source "$(conda info --base)/bin/activate" "$CONDA_ENV"
else
    echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2
    exit 1
fi
# `command -v` is the portable shell builtin; `which` is an external tool.
PYTHON_EXECUTABLE="$(command -v python)"

cd "$SCRIPT_LOC"

# Pretrained base checkpoint
PRETRAINED_CKPT="$HOME_LOC/pretrained/anylength_mol.ckpt"

# --- Shared training hyperparameters ----------------------------------
COMMON_ARGS=(
    --base_path "$HOME_LOC"
    --use_quality_filter
    --noise_removal
    --wdce_num_replicates 16
    --pool_size 1000
    --pool_refresh_fraction 0.3
    --buffer_size 100
    --batch_size 200
    --training_mini_batch_size 20
    --max_length 256
    --total_num_steps 256
    --num_iter 20
    --resample_every_n_step 10
    --num_epochs 1000
    --save_every_n_epochs 100
    --reset_every_n_step 1
    --alpha 0.01
    --no_mcts
    --schedule_warmup_epochs 20
    --alternation_frequency 5
    --num_remasking 3
    --quality_threshold 0.3
    --checkpoint_path "$PRETRAINED_CKPT"
    --grad_clip
    --qed_only
    --seed 42
    --num_training_steps_per_epoch 25
)

# --- Shared evaluation hyperparameters --------------------------------
EVAL_COMMON_ARGS=(
    --pretrained_ckpt "$PRETRAINED_CKPT"
    --num_samples 1000
    --batch_size 50
    --max_length 256
    --total_num_steps 256
    --num_remasking 2
    --quality_threshold 0.3
    --seed 42
)

# =====================================================================
# Pick experiment from $MODE_ID
# =====================================================================
case "$MODE_ID" in
    0) MODE="with_planner";          EXTRA_ARGS=() ;;
    1) MODE="no_planner";            EXTRA_ARGS=(--disable_planner) ;;
    2) MODE="no_insertion_planner";  EXTRA_ARGS=(--disable_insertion_planner) ;;
    3) MODE="no_unmasking_planner";  EXTRA_ARGS=(--disable_unmasking_planner) ;;
    *) echo "Unknown MODE_ID=$MODE_ID (expected 0-3)"; exit 1 ;;
esac

RUN_NAME="${PREFIX}_mol_${MODE}"
RUN_LOG="$LOG_LOC/${RUN_NAME}.log"
RUN_SAVE_DIR="$SAVE_DIR/${RUN_NAME}"
RESULTS_SUBDIR="$RESULTS_DIR/${MODE}"
mkdir -p "$RUN_SAVE_DIR" "$RESULTS_SUBDIR"

echo "=== Mol finetune (MODE_ID=$MODE_ID) ==="
echo "Job: ${SLURM_JOB_ID} Node: $SLURM_NODELIST"
echo "Mode: $MODE"
echo "Save dir: $RUN_SAVE_DIR"
echo "Results dir: $RESULTS_SUBDIR"
echo "Python: $PYTHON_EXECUTABLE"
echo "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES:-(unset)}"

# =====================================================================
# Train
# =====================================================================
# Interpreter and script path are quoted so a repo path containing spaces
# (e.g. via SLURM_SUBMIT_DIR) is not word-split.
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/finetune_mol.py" \
    "${COMMON_ARGS[@]}" \
    --devices 1 \
    "${EXTRA_ARGS[@]}" \
    --save_path_dir "$RUN_SAVE_DIR" \
    >> "$RUN_LOG" 2>&1

echo "Training finished for $MODE. Log: $RUN_LOG"

# =====================================================================
# Evaluate
# =====================================================================
# Newest last.ckpt, either directly in the run dir or one level below it.
# `ls` may fail for the pattern that has no match; errors are discarded.
RUN_CKPT=$(ls -t "$RUN_SAVE_DIR"/*/last.ckpt "$RUN_SAVE_DIR"/last.ckpt 2>/dev/null | head -1)
if [ -z "$RUN_CKPT" ]; then
    # A missing checkpoint is a failure (exit 1), not a skipped evaluation.
    echo "ERROR: no checkpoint found in $RUN_SAVE_DIR — cannot run eval." >&2
    exit 1
fi

echo "Evaluating checkpoint: $RUN_CKPT"
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/evaluate_mol_table.py" \
    --checkpoint_path "$RUN_CKPT" \
    "${EVAL_COMMON_ARGS[@]}" \
    "${EXTRA_ARGS[@]}" \
    --output_dir "$RESULTS_SUBDIR" \
    --device cuda:0 \
    >> "$RESULTS_SUBDIR/eval.log" 2>&1

echo "Eval finished for $MODE. CSV: $RESULTS_SUBDIR/eval_metrics_${MODE}.csv"

conda deactivate
a2d2_mol/scripts/train_mol.sh ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
#SBATCH --job-name=a2d2-mol-pretrain
#SBATCH --partition=dgx-b200
#SBATCH --nodes=1
#SBATCH --gpus-per-node=2
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=8
#SBATCH --mem=512GB
#SBATCH --time=7-00:00:00
# SLURM's own catch-file (anything printed before the exec redirect below, plus
# slurm-infra messages). Relative to the submit dir, so submit this script from
# the a2d2_mol/ directory; the real run output is redirected via exec below.
# NOTE(review): SLURM is not expected to create logs/slurm/ on its own -- confirm,
# and run `mkdir -p logs/slurm` once before the first submission if needed.
#SBATCH --output=logs/slurm/%x_%j.out
#SBATCH --error=logs/slurm/%x_%j.err
#
# Pretrain the any-length insertion MDM on drug-like SAFE molecules on a dgx-b200 node.
# Submit with:  sbatch scripts/train_mol.sh   (from the a2d2_mol/ directory).
#
# DDP is launched by SLURM: one srun task per GPU. --gpus-per-node and
# --ntasks-per-node must match; change both together (and they override the
# training.devices value baked into config_mol.yaml via the hydra override below).
#
# The script exits with the status of the training command, so a failed run is
# reported as FAILED by SLURM instead of being masked by the final cleanup step.

DATE=$(date +%Y%m%d)
SPECIAL_PREFIX='a2d2-mol'

# Resolve a2d2_mol/ (which holds train.py + config_mol.yaml) so paths are
# repo-relative. This script lives in a2d2_mol/scripts/, so the direct-run
# fallback goes one level up. Under sbatch, BASH_SOURCE points at the spooled
# copy, so we rely on SLURM_SUBMIT_DIR (submit from the a2d2_mol/ directory).
if [ -n "${SLURM_SUBMIT_DIR:-}" ]; then
    SCRIPT_DIR="$SLURM_SUBMIT_DIR"
else
    SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
fi
# Abort early: every later path (logs, train.py) is relative to this directory.
cd "$SCRIPT_DIR" || { echo "ERROR: cannot cd to '$SCRIPT_DIR'." >&2; exit 1; }

# Auto-detect GPUs from the SLURM allocation (falls back to 2 for `bash` runs).
DEVICES=${SLURM_GPUS_ON_NODE:-${SLURM_GPUS_PER_NODE:-2}}
# SLURM_GPUS_PER_NODE may be given as "<type>:<count>"; keep only the count.
DEVICES=${DEVICES##*:}
NTASKS=${SLURM_NTASKS_PER_NODE:-$DEVICES}
NODES=${SLURM_NNODES:-1}

LOG_LOC="$SCRIPT_DIR/logs"
mkdir -p "$LOG_LOC/slurm"
exec > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_${SLURM_JOB_ID:-local}.log" 2>&1

# ---------------------------------------------------------------------------
# Weights & Biases: log in once on your machine before running this script with
#   `wandb login`   (or `export WANDB_API_KEY=<your-key>`).
# Do NOT hardcode your API key here. To disable W&B entirely, uncomment:
#   export WANDB_MODE=disabled
# ---------------------------------------------------------------------------

export PYTORCH_ALLOC_CONF=expandable_segments:True

# Activate the conda env that has the deps (torch / pytorch_lightning / hydra).
# The batch shell does NOT source ~/.bashrc, so conda is not on PATH. Override
# CONDA_ROOT to point at your conda/miniconda install, or just have `conda` on
# PATH; override CONDA_ENV if your env name differs from the one created by
# environment.yml.
CONDA_ENV="${CONDA_ENV:-a2d2}"
if [ -n "${CONDA_ROOT:-}" ]; then
    source "$CONDA_ROOT/bin/activate" "$CONDA_ENV"
elif command -v conda >/dev/null 2>&1; then
    source "$(conda info --base)/bin/activate" "$CONDA_ENV"
else
    echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2
    exit 1
fi

# --- Distributed / NCCL setup (single node, intra-node NVLink) --------------
# Prefer a conventional ethernet-style interface; otherwise take the first
# non-loopback, non-"ibp" interface that has an IPv4 address.
ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -E "ens|eth|enp|bond" | head -1 | awk '{print $2}')
if [ -z "$ETH_IFACE" ]; then
    ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -v "ibp" | head -1 | awk '{print $2}')
fi
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_FAMILY=AF_INET
# Only pin the interface when one was found; an empty value would be exported
# as-is, whereas leaving it unset lets NCCL pick an interface itself.
if [ -n "$ETH_IFACE" ]; then
    export NCCL_SOCKET_IFNAME=$ETH_IFACE
else
    echo "WARNING: no usable IPv4 interface detected; leaving NCCL_SOCKET_IFNAME unset." >&2
fi
export NCCL_P2P_LEVEL=NVL

export MASTER_ADDR=$(scontrol show hostnames "${SLURM_NODELIST:-$(hostname)}" | head -n 1)
export MASTER_PORT=$(shuf -i 15000-59999 -n 1)
export NODE_RANK=${SLURM_NODEID:-0}

echo "=== a2d2 molecule pretraining (dgx-b200) ==="
echo "Job ID: ${SLURM_JOB_ID:-local}  Node: ${SLURM_NODELIST:-$(hostname)}  GPUs: $DEVICES  Tasks: $NTASKS"

# --task mol makes train.py load config_mol.yaml; the hydra overrides pin
# devices/nodes to the SLURM allocation so the two never drift apart.
srun --ntasks-per-node="$NTASKS" python train.py --task mol \
    training.devices="$DEVICES" \
    training.nodes="$NODES"
# Remember the training result before any cleanup command overwrites $?.
TRAIN_STATUS=$?

conda deactivate
exit "$TRAIN_STATUS"
a2d2_mol/train.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ from pytorch_lightning.loggers import WandbLogger
4
+ from pytorch_lightning.callbacks import ModelCheckpoint
5
+ import os
6
+ import sys
7
+ import argparse
8
+ import hydra
9
+ from omegaconf import OmegaConf
10
+ from datetime import datetime
11
+ # Directory containing this file and the config_*.yaml files (used by Hydra below).
12
+ CONFIG_DIR = os.path.dirname(os.path.abspath(__file__))
13
+ # Add the repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve.
14
+ sys.path.insert(0, os.path.dirname(CONFIG_DIR))
15
+
16
+ import wandb
17
+ from lightning_modules import AnyOrderInsertionFlowModule
18
+
19
+
20
+ torch.set_printoptions(threshold=10_000)
21
+ torch.set_float32_matmul_precision("high")
22
+
23
+ # Disable DDP optimizer due to incompatibility with flex_attention higher-order ops
24
+ torch._dynamo.config.optimize_ddp = False
25
+
26
def train(config):
    """Run any-length insertion-MDM pretraining described by ``config``.

    Seeds the RNGs, starts Weights & Biases on local rank 0, builds the data
    module (HuggingFace path when ``config.hf_dataset`` exists, local arrow
    path otherwise), constructs the Lightning module and trainer, and fits.

    Args:
        config: OmegaConf/Hydra config. Reads ``config.wandb``, ``config.model``
            and ``config.training``; ``config.training.checkpoint_dir`` is
            rewritten in place to a timestamped subdirectory and
            ``config.training.warmup_steps`` is filled in when it is null.

    Raises:
        ValueError: if neither ``max_steps`` nor ``num_epochs`` is set, if
            ``warmup_steps`` is null while training by epochs (the default
            warmup is 1% of ``max_steps`` and cannot be derived), or if neither
            ``save_every_n_steps`` nor ``save_every_n_epochs`` is set.
    """
    wandb_logger = None

    # set the random seed
    pl.seed_everything(42)
    torch.manual_seed(42)

    # Only initialize wandb on rank 0 to avoid multiple runs
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        wandb.init(
            project=config.wandb.project,
            name=config.wandb.name,
            config=OmegaConf.to_container(config, resolve=True),  # Convert to dict
            dir=config.wandb.path,
        )
        wandb_logger = WandbLogger(
            project=wandb.run.project,
            name=wandb.run.name,
            log_model=False,  # Disable checkpoint uploading to save disk space
        )

    # Modify config to add timestamp to checkpoint directory
    OmegaConf.set_struct(config, False)
    time_string = datetime.now().strftime("%Y%m%d-%H%M%S")
    config.training.checkpoint_dir = os.path.join(
        config.training.checkpoint_dir, time_string
    )
    OmegaConf.set_struct(config, True)

    # Create checkpoint directory
    os.makedirs(config.training.checkpoint_dir, exist_ok=True)

    # Setup data module - check if using HuggingFace dataset
    if hasattr(config, 'hf_dataset'):
        # Imported lazily: the HF/SAFE path is only used by the molecule configs,
        # which keep mol_dataset.py (and its `safe` dependency) in a2d2_mol/.
        from mol_dataset import setup_hf_data_and_update_config
        print(f"Using HuggingFace dataset: {config.hf_dataset.name}")
        data_module = setup_hf_data_and_update_config(
            config,
            dataset_name=config.hf_dataset.name,
            smiles_column=config.hf_dataset.get('smiles_column', 'smiles'),
        )
    else:
        # Imported lazily: the local (arrow) path is used by the peptide config,
        # which keeps dataloading_for_dynamic_batching.py in a2d2_pep/.
        from dataloading_for_dynamic_batching import setup_data_and_update_config
        print("Using local dataset")
        data_module = setup_data_and_update_config(config)

    module = AnyOrderInsertionFlowModule(config)

    # Map torch_dtype to Lightning precision (unknown values fall back to bf16).
    dtype_str = config.model.get('torch_dtype', 'bfloat16')
    precision_map = {
        'float32': '32-true',
        'float16': '16-mixed',
        'bfloat16': 'bf16-mixed',
    }
    precision = precision_map.get(dtype_str, 'bf16-mixed')

    # Global batch = per-GPU batch * GPUs * nodes * accumulation steps. Clamp to
    # at least 1 so a global batch smaller than one pass over all GPUs does not
    # produce an accumulation factor of 0.
    samples_per_step = (
        config.training.per_gpu_batch_size
        * config.training.nodes
        * config.training.devices
    )
    accumulate_grad_batches = max(1, config.training.batch_size // samples_per_step)

    trainer_kwargs = dict(
        num_nodes=config.training.nodes,
        accelerator="gpu",
        devices=config.training.devices,
        strategy="ddp",
        precision=precision,
        accumulate_grad_batches=accumulate_grad_batches,
        log_every_n_steps=10,
        enable_checkpointing=True,
        default_root_dir=config.training.checkpoint_dir,
        gradient_clip_val=1.0,
    )

    # Only one of max_steps or max_epochs will be used. `.get` keeps a missing
    # key from raising in struct mode, so the ValueError below stays reachable.
    max_steps = config.training.get('max_steps', None)
    num_epochs = config.training.get('num_epochs', None)
    if max_steps is not None:
        trainer_kwargs["max_steps"] = max_steps
    elif num_epochs is not None:
        trainer_kwargs["max_epochs"] = num_epochs
    else:
        raise ValueError(
            "Either max_steps or num_epochs must be specified in the config"
        )

    if config.training.get('warmup_steps', None) is None:
        if max_steps is None:
            # The default warmup is 1% of max_steps; with epoch-based training
            # there is no step budget to derive it from.
            raise ValueError(
                "training.warmup_steps must be set explicitly when training.max_steps "
                "is not specified (the default warmup is 1% of max_steps)"
            )
        OmegaConf.set_struct(config, False)
        config.training.warmup_steps = int(max_steps * 0.01)
        OmegaConf.set_struct(config, True)

    # Save a checkpoint whenever the monitored *training* loss reaches a new low.
    checkpoint_callback = ModelCheckpoint(
        monitor="train/total_loss",
        mode="min",
        save_top_k=config.training.save_top_k,
        save_last=True,
        filename="epoch-{epoch:02d}-train_loss-{train/total_loss:.4f}",
        dirpath=config.training.checkpoint_dir,
        # The metric name is already spelled out in `filename`.
        auto_insert_metric_name=False,
    )

    # Add separate callback for periodic saves (no val_loss dependency). Use
    # step-based saves for streaming datasets (save_every_n_steps) and epoch-based
    # saves otherwise (save_every_n_epochs); whichever the config provides.
    save_every_n_steps = config.training.get('save_every_n_steps', None)
    save_every_n_epochs = config.training.get('save_every_n_epochs', None)
    if save_every_n_steps is not None:
        periodic_checkpoint_callback = ModelCheckpoint(
            save_top_k=-1,  # Save all periodic checkpoints
            filename="step-{step:08d}",
            dirpath=config.training.checkpoint_dir,
            every_n_train_steps=save_every_n_steps,
            auto_insert_metric_name=False,
        )
    elif save_every_n_epochs is not None:
        periodic_checkpoint_callback = ModelCheckpoint(
            save_top_k=-1,  # Save all periodic checkpoints
            filename="epoch-{epoch:02d}",
            dirpath=config.training.checkpoint_dir,
            every_n_epochs=save_every_n_epochs,
            auto_insert_metric_name=False,
        )
    else:
        raise ValueError(
            "Either save_every_n_steps or save_every_n_epochs must be specified in the config"
        )

    trainer_kwargs["callbacks"] = [checkpoint_callback, periodic_checkpoint_callback]

    if wandb_logger is not None:
        trainer_kwargs["logger"] = wandb_logger

    trainer = pl.Trainer(**trainer_kwargs)

    # Train the model, resuming from training.resume_path when the config has one.
    ckpt_path = None
    if "resume_path" in config.training:
        ckpt_path = config.training.resume_path

    trainer.fit(module,
                datamodule=data_module,
                ckpt_path=ckpt_path)

    # Only finish wandb on rank 0
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        wandb.finish()
180
+
181
+
182
if __name__ == '__main__':
    # Two script-level options pick the YAML file; everything argparse does not
    # recognise is left untouched and handed on to Hydra as overrides.
    cli = argparse.ArgumentParser()
    cli.add_argument('--config_name', type=str, default='config',
                     help='Name of the config file to use')
    cli.add_argument('--task', type=str, default=None,
                     help='Task name (uses config_{task}.yaml)')
    known, hydra_args = cli.parse_known_args()

    # --task wins over --config_name when both are given.
    config_name = f'config_{known.task}' if known.task else known.config_name

    print(f"Using config: {config_name}.yaml")

    # Inject the config name as a Hydra flag (this persists across DDP
    # subprocesses) unless the caller already supplied it.
    name_flag = f'--config-name={config_name}'
    already_named = '--config-name' in hydra_args or name_flag in hydra_args
    if not already_named:
        hydra_args.insert(0, name_flag)

    # Hydra reads sys.argv directly, so rebuild it from the pass-through args.
    sys.argv = [sys.argv[0], *hydra_args]

    # The decorator's config_name is only a placeholder default; the
    # --config-name flag placed in sys.argv above takes precedence.
    @hydra.main(version_base=None,
                config_path=CONFIG_DIR,
                config_name='config')
    def main(config):
        """Main entry point for training"""
        train(config)

    main()
a2d2_pep/README.md ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A2D2 for Multi-Objective Therapeutic Peptide Generation 🧫
2
+
3
+ This part of the code fine-tunes an **any-length masked diffusion model (MDM)** over peptide SMILES with **A2D2** (Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding) to optimize **five therapeutic properties simultaneously**: binding affinity to a target protein, solubility, non-hemolysis, non-fouling, and cell-membrane permeability.
4
+
5
+ A2D2 jointly fine-tunes the insertion and unmasking policies together with **insertion and unmasking quality predictors**, generating peptides via **Adaptive Joint Decoding (AJD)** that remasks low-quality tokens and drops low-quality insertions to sample from the reward-tilted distribution while preserving generation quality.
6
+
7
+ Peptides are represented as **SMILES** strings and tokenized with the SMILES Pair Encoding tokenizer (vocabulary size `V = 587`) from [PeptideCLM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01443). Generated SMILES are decoded and validity-checked with the `SMILES2PEPTIDE` filter from [PepTune](https://arxiv.org/abs/2412.17780).
8
+
9
+ The codebase is partially built upon [FlexMDM (Kim et al., 2025)](https://github.com/brianlck/FlexMDM/tree/main) and [TR2-D2 (Tang et al., 2025)](https://github.com/sophtang/TR2-D2/tree/main).
10
+
11
+ ## Environment Installation
12
+ ```
13
+ # from the repository root
14
+ conda env create -f environment.yml
15
+
16
+ conda activate a2d2
17
+ ```
18
+ The peptide scripts share the `a2d2` environment with the molecule and language experiments. See the root [`environment.yml`](../environment.yml) for the `flash-attn` install step.
19
+
20
+ ## Model Pretrained Weights
21
+
22
+ A2D2 fine-tunes a pretrained any-length insertion MDM trained on ~11M peptide SMILES (7,451 sequences from CycPeptMPDB, 825,632 from SmProt, and ~10M modified peptides from CycloPs). Download the base checkpoint and place it at:
23
+ ```
24
+ A2D2/pretrained/anylength_pep.ckpt
25
+ ```
26
+ ```bash
27
+ # from the repository root
28
+ pip install gdown
29
+ mkdir -p pretrained
30
+ gdown 1K8yxM-omh-MuPo0EG6UyxHZLk3HehoJc -O pretrained/anylength_pep.ckpt
31
+ ```
32
+ (Or download manually from https://drive.google.com/file/d/1K8yxM-omh-MuPo0EG6UyxHZLk3HehoJc/view?usp=drive_link — a plain `wget`/`curl` of the link saves Google's HTML warning page, not the checkpoint.)
33
+ This is the default `--checkpoint_path`; pass `--checkpoint_path` to override it.
34
+
35
+ The reward classifiers (binding-affinity Transformer, plus XGBoost predictors for solubility, hemolysis, non-fouling, and permeability) and the SMILES PE tokenizer ship with the repo under [`pep_scoring/`](pep_scoring); no separate download is required. The PeptideCLM embedding model is fetched automatically from the Hugging Face Hub (`aaronfeller/PeptideCLM-23M-all`) on first run.
36
+
37
+ ## Pretraining the Any-Length Model
38
+
39
+ If you only want to fine-tune with A2D2, download the released `anylength_pep.ckpt` above and skip this section. Follow these steps to reproduce the base checkpoint by pretraining the any-length insertion MDM from scratch.
40
+
41
+ ### 1. Download the pretraining dataset
42
+
43
+ The pretraining corpus is ~11M peptide SMILES (7,451 from CycPeptMPDB, 825,632 from SmProt, and ~10M modified peptides from CycloPs), already tokenized with the in-repo SMILES PE tokenizer and saved as a Hugging Face `arrow` dataset (with `train`/`val` splits) via `save_to_disk`.
44
+
45
+ Download the archive and unpack it into [`data/`](data):
46
+
47
+ ```bash
48
+ # from a2d2_pep/
49
+ pip install gdown
50
+ gdown https://drive.google.com/uc?id=1yCDr641WVjCtECg3nbG0nsMNu8j7d7gp -O 11M_peptide_smiles.zip
51
+ mkdir -p data
52
+ unzip 11M_peptide_smiles.zip -d data/
53
+ # result: a2d2_pep/data/11M_peptide_smiles/{train,val}/...
54
+ ```
55
+
56
+ This is the default `training.data_path` in [`config_pep.yaml`](config_pep.yaml). To store the dataset elsewhere, set `training.data_path` (absolute, or relative to `a2d2_pep/`).
57
+
58
+ ### 2. Configure
59
+
60
+ Pretraining is driven by [`config_pep.yaml`](config_pep.yaml). Key fields:
61
+
62
+ | Field | Default | Notes |
63
+ |-------|---------|-------|
64
+ | `training.data_path` | `data/11M_peptide_smiles` | Preprocessed arrow dataset from step 1. |
65
+ | `training.devices` | `4` | GPUs per node (DDP). |
66
+ | `training.batch_size` | `1024` | Global batch; gradient accumulation is derived automatically from `per_gpu_batch_size`. |
67
+ | `training.max_steps` | `1000000` | Total optimizer steps. |
68
+ | `training.learning_rate` | `3e-4` | AdamW LR with `warmup_steps: 2000`. |
69
+ | `training.checkpoint_dir` | `checkpoints/peptides` | A timestamped subdirectory is created per run. |
70
+ | `interpolant.max_length` | `1024` | Max token length. |
71
+
72
+ ### 3. Pre-training Any-Length Peptide Model
73
+
74
+ Log in to Weights & Biases once (`wandb login`), or set `export WANDB_MODE=disabled` to skip logging. Then submit the SLURM job:
75
+
76
+ ```bash
77
+ # from a2d2_pep/
78
+ sbatch train_pep.sh
79
+ ```
80
+
81
+ `train_pep.sh` is a SLURM batch script that requests one `dgx-b200` node with 4 full B200 GPUs and launches DDP via `srun` (one task per GPU), running the equivalent of:
82
+
83
+ ```bash
84
+ python train.py --task pep
85
+ ```
86
+
87
+ It activates the conda env (`CONDA_ENV`, defaults to the `peptune` env) from `CONDA_ROOT` (defaults to the shared miniconda install) — the batch shell does not source `~/.bashrc`, so override these env vars if your install or env path differs. The GPU count is auto-detected from the SLURM allocation and passed to hydra as `training.devices`/`training.nodes`, so to scale just change `--gpus-per-node` and `--ntasks-per-node` together at the top of the script (they must match). `--task pep` makes `train.py` load `config_pep.yaml`.
88
+
89
+ Checkpoints are written to `checkpoints/peptides/<timestamp>/` (use `last.ckpt` / the best `train_loss` checkpoint as the `--checkpoint_path` for fine-tuning); the run log goes to `logs/<date>_a2d2-peptide_<jobid>.log` and SLURM's catch-file to `logs/slurm/`. To resume, add a `training.resume_path: /path/to/last.ckpt` entry to the config.
90
+
91
+ ## Fine-Tune with A2D2
92
+
93
+ All paths resolve relative to the repository, so the scripts run from any checkout. The output directories `A2D2/checkpoints`, `A2D2/results`, and `A2D2/logs` are created on demand by the script; you may also create them beforehand. Fine-tuning curves and a `<prot_name>_generation_results.csv` are written to `<base_path>/results/<run_name>/`, and checkpoints to `--save_path_dir`.
94
+
95
+ Choose a target protein with `--prot_name` (looked up in the built-in `PROTEINS` table — e.g. `glp1` for GLP-1R or `glast` for GLAST), or supply an arbitrary target with `--prot_seq <amino acid sequence>`.
96
+
97
+ #### Available `--prot_name` targets
98
+
99
+ The named targets and their amino-acid sequences are defined in the `PROTEINS` dict in [`finetune_quality.py`](finetune_quality.py) (search for `PROTEINS = {`). The default is `glast`; passing a name not in the table raises an error listing the valid keys. To add a new target, add a `'<key>': '<sequence>'` entry there, or skip the table entirely with `--prot_seq`.
100
+
101
+ | `--prot_name` | Target |
102
+ |---------------|--------|
103
+ | `tfr` | Transferrin receptor (TfR) |
104
+ | `glp1` | GLP-1 receptor (GLP-1R) |
105
+
106
+ ### Single run
107
+
108
+ [`scripts/run_peptide_finetune.slurm`](scripts/run_peptide_finetune.slurm) runs a single `finetune_quality.py` experiment on one MIG GPU, then evaluates the resulting checkpoint. It bundles the hyperparameter set from the peptide column of the fine-tuning table in the paper — replicates `R = 8`, buffer size `B = 50`, resample interval `N_resample = 10`, gradient steps per iteration `N_update = 10`, alternation frequency `N_alt = 5`, warmup `N_warmup = 20`, sampling steps `N_steps = 256`, training mini-batch `10`, reward scaling `α = 0.1`, quality threshold `μ_min = 0.5`, and `--num_obj 5` — so you don't have to pass them by hand.
109
+
110
+ The script resolves the repo root automatically — `$A2D2_ROOT` if set, else the `sbatch` submit directory, else the script's own location — so either submit from the repo root or export your clone path. Set `CONDA_ROOT` (your miniconda install) and, if needed, `CONDA_ENV` (defaults to `peptune`) and `WANDB_ENTITY`:
111
+ ```bash
112
+ export A2D2_ROOT=/path/to/your/A2D2 # absolute path to your clone
113
+ export CONDA_ROOT=/path/to/miniconda3 # or just have `conda` on PATH
114
+ export WANDB_ENTITY=your_wandb_entity # optional
115
+ sbatch scripts/run_peptide_finetune.slurm
116
+ ```
117
+
118
+ Select which variant to run with `MODE_ID` (default `0`): `0` = A2D2 (full planner), `1` = `--disable_planner`, `2` = `--disable_insertion_planner`, `3` = `--disable_unmasking_planner`. Override at submit time:
119
+ ```bash
120
+ sbatch --export=ALL,MODE_ID=2 scripts/run_peptide_finetune.slurm
121
+ ```
122
+ The target protein is set by the `PROT_NAME` variable near the top of the script (default `tfr`); edit it to one of the named targets above (or any key in the `PROTEINS` table). The pretrained base checkpoint is read from `$A2D2_ROOT/pretrained/anylength_pep.ckpt`. Outputs land in `checkpoints/finetune_test_peptides_<prot>/<job>_peptide_<prot>_<mode>/` and `results/peptide_test_ablation_<prot>/<mode>/`.
123
+
124
+ ### Key arguments
125
+ - `--prot_name` / `--prot_seq` — target protein (named lookup, or a raw amino-acid sequence).
126
+ - `--alternation_frequency` — epochs to train each of {policy, planner} before alternating.
127
+ - `--alpha` — reward-tilting temperature (smaller = stronger reward optimization).
128
+ - `--buffer_size`, `--resample_every_n_step` — replay-buffer size and how often it is regenerated.
129
+
130
+ ### Ablation flags
131
+ | Flag | Variant |
132
+ |------|---------|
133
+ | *(none)* | A2D2 w/ insertion + unmasking quality (alternation) |
134
+ | `--disable_planner` | A2D2 w/o quality (policy only, no remasking) |
135
+ | `--disable_insertion_planner` | A2D2 w/o insertion quality |
136
+ | `--disable_unmasking_planner` | A2D2 w/o unmasking/remasking quality |
137
+ | `--joint_training` | train policy + quality heads jointly (no alternation) |
138
+
139
+ During buffer generation only sequences passing the `SMILES2PEPTIDE` validity filter are retained; the scalarized multi-objective reward is added to the log Radon–Nikodym derivative of each sequence. Fine-tuning runs on a single GPU (`--devices 1`).
140
+
141
+ ## Evaluation
142
+
143
+ Evaluation runs automatically every `--eval_every_n_epochs` epochs and at the end of training. It samples from the current model and reports the fraction of valid peptides along with the five therapeutic rewards (binding affinity, solubility, non-hemolysis, non-fouling, permeability), saving per-objective curves and `<prot_name>_generation_results.csv` under `<base_path>/results/<run_name>/`.
144
+
145
+ To resume a run, pass `--resume_ckpt /path/to/last.ckpt` (restores epoch, optimizer, and planner state; new checkpoints continue in the same directory).
a2d2_pep/config_pep.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ trainer: "any-order-flow"
2
+ dataset: "peptides"
3
+
4
+ model:
5
+ hidden_size: 768
6
+ n_heads: 12
7
+ cond_dim: 128
8
+ dropout: 0.05
9
+ n_blocks: 12
10
+
11
+ interpolant:
12
+ type: "any-order"
13
+ tokens: null # filled in automatically
14
+ pad_token: null # filled in automatically
15
+ mask_token: null # filled in automatically
16
+ max_length: 1024
17
+ insert_schedule:
18
+ type: "linear"
19
+ unmask_schedule:
20
+ type: "linear"
21
+
22
+ training:
23
+ only_embed_insert: true
24
+ batch_size: 1024
25
+ per_gpu_batch_size: 64 # Gradient accumulation happens automatically
26
+ cpus: 4
27
+ learning_rate: 3e-4
28
+ nodes: 1
29
+ devices: 4
30
+ max_steps: 1000000
31
+ weight_decay: 0.03
32
+ # Path to the preprocessed (arrow) pretraining dataset; see README for the download link.
33
+ # Relative paths resolve against a2d2_pep/. Defaults to a2d2_pep/data/11M_peptide_smiles.
34
+ data_path: "data/11M_peptide_smiles"
35
+ checkpoint_dir: "checkpoints/peptides"
36
+ save_top_k: 1
37
+ save_every_n_epochs: 1
38
+ loss_fn:
39
+ unmask: "elbo"
40
+ insert: "expectation"
41
+ reset_lr: false
42
+ warmup_steps: 2000
43
+ ema_decay: 0.9999
44
+ filter_max_length: false
45
+
46
+ wandb:
47
+ entity: null # set to your W&B entity, or leave null to use the default
48
+ project: "a2d2-pep"
49
+ name: "a2d2-pep"
50
+ path: "./wandb"
a2d2_pep/data/dataloading_for_dynamic_batching.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ import os
3
+ import torch
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from datasets import Dataset,load_from_disk
6
+ import sys
7
+ import pytorch_lightning as pl
8
+ from pep_scoring.tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
9
+ from functools import partial
10
+ import re
11
+
12
+ # Directory containing this file; used to resolve the in-repo tokenizer files.
13
+ _THIS_DIR = os.path.dirname(os.path.abspath(__file__))
14
+
15
+
16
class DynamicBatchingDataset(Dataset):
    """In-memory view of a pre-tokenized split, indexable by int or list of ints.

    ``attention_mask`` and ``input_ids`` entries are converted to tensors once,
    at construction time; ``labels`` is stored exactly as provided.
    """

    def __init__(self, dataset_dict, tokenizer):
        print('Initializing dataset...')
        masks = [torch.tensor(row) for row in dataset_dict['attention_mask']]
        ids = [torch.tensor(row) for row in dataset_dict['input_ids']]
        self.dataset_dict = {
            'attention_mask': masks,
            'input_ids': ids,
            'labels': dataset_dict['labels'],
        }
        self.tokenizer = tokenizer

    def __len__(self):
        # One entry per stored attention mask.
        return len(self.dataset_dict['attention_mask'])

    def __getitem__(self, idx):
        """Return one record for an int index, or per-field lists for a list of indices.

        Raises:
            ValueError: if ``idx`` is neither an ``int`` nor a ``list``.
        """
        columns = ('input_ids', 'attention_mask', 'labels')
        store = self.dataset_dict
        if isinstance(idx, int):
            return {name: store[name][idx] for name in columns}
        if isinstance(idx, list):
            return {name: [store[name][i] for i in idx] for name in columns}
        raise ValueError(f"Expected idx to be int or list, but got {type(idx)}")
44
+
45
class CustomDataModule(pl.LightningDataModule):
    """Lightning data module over a dataset loaded with ``load_from_disk``.

    The loaders use ``batch_size=1`` and ``collate_fn`` forwards the single
    item's ``input_ids``/``attention_mask`` unchanged, adding a token-level
    ``bond_mask``.
    NOTE(review): this presumably means each stored row is already a whole
    batch (dynamic batching) -- confirm against the preprocessing code.
    """

    # Fixed width, in SMILES characters, of the character-level bond mask.
    _BOND_MASK_WIDTH = 1035

    # Bond motifs, tried in this order; an earlier pattern claims its characters
    # first and later matches that overlap them are discarded.
    _BOND_PATTERNS = tuple(
        re.compile(pattern)
        for pattern in (
            r'OC\(=O\)',           # ester
            r'N\(C\)C\(=O\)',      # n_methyl
            r'N[12]C\(=O\)',       # peptide (Pro peptide bonds)
            r'NC\(=O\)',           # peptide (regular peptide bonds)
            r'C\(=O\)N\(C\)',      # n_methyl
            r'C\(=O\)N[12]?',      # peptide
        )
    )

    def __init__(self, dataset_path, tokenizer):
        super().__init__()
        self.dataset = load_from_disk(dataset_path)
        self.tokenizer = tokenizer

    def peptide_bond_mask(self, smiles_list):
        """Mark the SMILES characters that belong to a recognized bond motif.

        Args:
            smiles_list: List of peptide SMILES strings (batch of SMILES strings).

        Returns:
            torch.Tensor: int mask of shape ``(batch_size, 1035)`` with 1s at
            bond positions and 0 elsewhere. Characters past column 1035 are
            not represented.
        """
        batch_size = len(smiles_list)
        mask = torch.zeros((batch_size, self._BOND_MASK_WIDTH), dtype=torch.int)

        for batch_idx, smiles in enumerate(smiles_list):
            used = set()  # character positions already claimed by a bond
            for pattern in self._BOND_PATTERNS:
                for match in pattern.finditer(smiles):
                    span = range(match.start(), match.end())
                    # Accept a match only when it shares no position with an
                    # earlier accepted one.
                    if used.isdisjoint(span):
                        used.update(span)
                        mask[batch_idx, match.start():match.end()] = 1

        return mask

    def peptide_token_mask(self, smiles_list, token_lists):
        """Mark the tokens that overlap a recognized bond motif.

        Args:
            smiles_list: List of peptide SMILES strings (batch of SMILES strings).
            token_lists: List of tokenized SMILES strings (split into tokens).

        Returns:
            torch.Tensor: int mask of shape ``(batch_size, num_tokens)``, where
            ``num_tokens`` is the longest token list, with 1s for tokens whose
            characters overlap a bond and 0 elsewhere.
        """
        batch_size = len(smiles_list)
        token_seq_length = max(len(tokens) for tokens in token_lists)
        tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)
        atomwise_masks = self.peptide_bond_mask(smiles_list)

        for batch_idx, atomwise_mask in enumerate(atomwise_masks):
            token_seq = token_lists[batch_idx]
            last_idx = len(token_seq) - 1
            atom_idx = 0  # character offset of the current token in the SMILES

            for token_idx, token in enumerate(token_seq):
                # The first and last tokens are never marked.
                # NOTE(review): presumably these are special start/end tokens
                # with no SMILES characters, so the offset is only advanced for
                # interior tokens -- confirm against the tokenizer.
                if token_idx != 0 and token_idx != last_idx:
                    if torch.sum(atomwise_mask[atom_idx:atom_idx + len(token)]) >= 1:
                        tokenized_masks[batch_idx][token_idx] = 1
                    atom_idx += len(token)

        return tokenized_masks

    def collate_fn(self, batch):
        """Unwrap the single item of ``batch`` and attach its token-level bond mask."""
        item = batch[0]

        token_array = self.tokenizer.get_token_split(item['input_ids'])
        bond_mask = self.peptide_token_mask(item['labels'], token_array)

        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'bond_mask': bond_mask
        }

    def train_dataloader(self):
        """Return the shuffled loader over the ``train`` split."""
        train_dataset = DynamicBatchingDataset(self.dataset['train'], tokenizer=self.tokenizer)
        return DataLoader(
            train_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,  # Use the instance method
            shuffle=True,
            num_workers=12,
            pin_memory=True
        )

    def val_dataloader(self):
        """Return the (unshuffled) loader over the ``val`` split."""
        val_dataset = DynamicBatchingDataset(self.dataset['val'], tokenizer=self.tokenizer)
        return DataLoader(
            val_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,  # Use the instance method
            num_workers=8,
            pin_memory=True
        )
161
+
162
+
163
def setup_data_and_update_config(config):
    """Build the peptide data module and fill tokenizer-derived config fields.

    Sets ``config.interpolant.tokens``, ``pad_token`` and ``mask_token`` from
    the in-repo SMILES Pair Encoding tokenizer, then loads the preprocessed
    dataset named by ``config.training.data_path``.

    Tokenizer files and relative dataset paths are resolved against the
    directory that contains ``pep_scoring/``. This module lives in a ``data/``
    subdirectory, so that directory is normally this file's parent; resolving
    against the file's own directory would look for ``data/pep_scoring/...``
    and ``data/data/...``.

    Args:
        config: OmegaConf config with ``interpolant`` and ``training`` sections.

    Returns:
        CustomDataModule wrapping the dataset at the resolved path.

    Raises:
        FileNotFoundError: if the resolved dataset path does not exist.
    """
    package_root = _find_package_root(_THIS_DIR)

    # SMILES Pair Encoding tokenizer ships with the repo under pep_scoring/tokenizer/.
    tokenizer_dir = os.path.join(package_root, 'pep_scoring', 'tokenizer')
    tokenizer = SMILES_SPE_Tokenizer(
        os.path.join(tokenizer_dir, 'new_vocab.txt'),
        os.path.join(tokenizer_dir, 'new_splits.txt'),
    )

    config.interpolant.tokens = len(tokenizer)
    config.interpolant.pad_token = tokenizer.pad_token_id
    config.interpolant.mask_token = tokenizer.mask_token_id

    # Path to the preprocessed (arrow) pretraining dataset saved via `save_to_disk`.
    # Download instructions are in the README; override with `training.data_path` in the config.
    data_path = config.training.get('data_path', os.path.join('data', '11M_peptide_smiles'))
    if not os.path.isabs(data_path):
        data_path = os.path.join(package_root, data_path)
    if not os.path.exists(data_path):
        raise FileNotFoundError(
            f"Pretraining dataset not found at '{data_path}'. Download it (see a2d2_pep/README.md, "
            "'Pretraining the Any-Length Model') and set `training.data_path` in config_pep.yaml."
        )
    data_module = CustomDataModule(data_path, tokenizer)

    return data_module


def _find_package_root(start_dir):
    """Return ``start_dir`` or its parent, whichever contains ``pep_scoring/``.

    ``start_dir`` is checked first, so a layout that keeps ``pep_scoring/``
    next to this module resolves exactly as before. Falls back to
    ``start_dir`` when neither directory contains it.
    """
    for candidate in (start_dir, os.path.dirname(start_dir)):
        if os.path.isdir(os.path.join(candidate, 'pep_scoring')):
            return candidate
    return start_dir
a2d2_pep/data/dataset.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import re
3
+ import torch
4
+
5
+ import utils
6
+
7
+ from torch.utils.data import Dataset, DataLoader
8
+ import pytorch_lightning as pl
9
+ from functools import partial
10
+ import sys
11
+
12
class CustomDataset(Dataset):
    """Index-remapped view over another dataset.

    Item ``i`` of this dataset is item ``int(indices[i])`` of ``dataset``;
    the length is the number of indices.
    """

    def __init__(self, dataset, indices):
        self.dataset, self.indices = dataset, indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Cast so that non-builtin integer index types are accepted by the
        # underlying dataset.
        return self.dataset[int(self.indices[idx])]
24
+
25
+
26
+ # for weighting losses of peptide bonds
27
def peptide_bond_mask(smiles_list):
    """Mark the characters of each SMILES string that belong to a recognized bond.

    Patterns are tried in priority order (ester, N-methylated amide, proline
    amide, regular amide, then the reversed amide forms). A match is accepted
    only if none of its characters were already claimed by an earlier accepted
    match in the same string; accepted matches are set to 1 in the mask.

    Args:
        smiles_list: List of peptide SMILES strings (a batch).

    Returns:
        torch.Tensor: Integer (``torch.int``) mask of shape
        ``(batch_size, max_smiles_length)`` with 1 at bond characters and 0
        elsewhere (including padding past the end of shorter strings). An
        empty batch yields a ``(0, 0)`` tensor.
    """
    batch_size = len(smiles_list)
    # default=0 keeps an empty batch from raising inside max().
    max_seq_length = max((len(smiles) for smiles in smiles_list), default=0)
    mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)

    # Order matters: earlier patterns win when matches overlap.
    bond_patterns = (
        r'OC\(=O\)',        # ester
        r'N\(C\)C\(=O\)',   # N-methylated amide
        r'N[12]C\(=O\)',    # Pro peptide bonds
        r'NC\(=O\)',        # regular peptide bonds
        r'C\(=O\)N\(C\)',   # N-methylated amide (reversed)
        r'C\(=O\)N[12]?',   # peptide bond (reversed)
    )

    for batch_idx, smiles in enumerate(smiles_list):
        used = set()  # character positions already claimed in this string
        for pattern in bond_patterns:
            for match in re.finditer(pattern, smiles):
                span = range(match.start(), match.end())
                # Cost is proportional to the match length, not to the number
                # of positions claimed so far.
                if used.isdisjoint(span):
                    used.update(span)
                    mask[batch_idx, match.start():match.end()] = 1

    return mask
73
+
74
def peptide_token_mask(smiles_list, token_lists):
    """Project the character-level bond mask onto tokens.

    A token is flagged when at least one of the characters it covers is
    flagged by ``peptide_bond_mask``. The first and last token of every
    sequence are never flagged and do not consume characters.
    NOTE(review): presumably those are the tokenizer's start/end special
    tokens -- confirm against the tokenizer.

    Args:
        smiles_list: List of peptide SMILES strings (a batch).
        token_lists: Per-sequence lists of token strings, aligned with
            ``smiles_list``.

    Returns:
        torch.Tensor: Integer (``torch.int``) mask of shape
        ``(batch_size, max_token_count)`` with 1 for tokens overlapping a bond.
    """
    n_rows = len(smiles_list)
    widest = max(len(tokens) for tokens in token_lists)
    tokenized_masks = torch.zeros((n_rows, widest), dtype=torch.int)
    atomwise_masks = peptide_bond_mask(smiles_list)

    for row, char_mask in enumerate(atomwise_masks):
        tokens = token_lists[row]
        final_col = len(tokens) - 1
        cursor = 0  # character offset of the current token in the SMILES string
        for col, token in enumerate(tokens):
            if col == 0 or col == final_col:
                continue
            width = len(token)
            if bool(char_mask[cursor:cursor + width].any()):
                tokenized_masks[row, col] = 1
            cursor += width

    return tokenized_masks
104
+
105
def extract_amino_acid_sequence(helm_string):
    """Return the monomers listed in the ``PEPTIDE<n>{...}`` parts of a HELM string.

    Every ``PEPTIDE<digits>{...}`` group is located, square brackets are
    stripped from its content, and the content is split on ``.``. Monomers
    from all groups are concatenated in order of appearance.

    Args:
        helm_string (str): The HELM notation string for a peptide.

    Returns:
        list[str] with one entry per monomer, or -- when no peptide group is
        found -- the message string
        ``"Invalid HELM notation or no peptide sequence found."``.
    """
    peptide_bodies = re.findall(r'PEPTIDE\d+\{([^}]+)\}', helm_string)

    if not peptide_bodies:
        return "Invalid HELM notation or no peptide sequence found."

    return [
        residue
        for body in peptide_bodies
        for residue in body.replace('[', '').replace(']', '').split('.')
    ]
128
+
129
def helm_collate_fn(batch, tokenizer):
    """Collate a batch of HELM records into padded token tensors.

    Args:
        batch: Iterable of mappings, each with a ``'HELM'`` entry.
        tokenizer: Callable tokenizer; it is invoked on the list of HELM
            strings with ``return_tensors='pt'``, ``padding=True``,
            ``truncation=True`` and ``max_length=1024`` and must return a
            mapping containing ``'input_ids'`` and ``'attention_mask'``.

    Returns:
        dict with the tokenizer's ``'input_ids'`` and ``'attention_mask'``.
    """
    sequences = [item['HELM'] for item in batch]

    # The previous implementation also parsed every HELM string to compute a
    # maximum monomer count that was never used; that dead work was removed.
    tokens = tokenizer(sequences, return_tensors='pt', padding=True, truncation=True, max_length=1024)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask'],
    }
144
+
145
+
146
def collate_fn(batch, tokenizer):
    """Tokenize a batch of SMILES records and attach a peptide-bond token mask.

    Each record is first tokenized on its own; records for which that raises
    (including records without a ``'SMILES'`` entry) are reported on stdout
    and dropped. The surviving sequences are then tokenized together with
    padding and truncation to ``max_length=1035``.

    Args:
        batch: Iterable of mappings, each expected to have a ``'SMILES'`` entry.
        tokenizer: Tokenizer that is callable on a list of strings and
            provides ``get_token_split(input_ids)``.

    Returns:
        dict with ``'input_ids'``, ``'attention_mask'`` (from the tokenizer)
        and ``'bond_mask'`` (from ``peptide_token_mask``).

    Raises:
        ValueError: If no sequence in the batch could be tokenized. Previously
            this surfaced later as an opaque ``max()`` error while building
            the bond mask.
    """
    valid_sequences = []

    for item in batch:
        try:
            smiles = item['SMILES']
            # Dry run on a single sequence; only success/failure matters, so
            # the result is discarded.
            tokenizer([smiles], return_tensors='pt', padding=False, truncation=True, max_length=1035)
        except Exception as e:
            # Deliberate best-effort: one bad record must not kill the batch.
            print(f"Skipping sequence due to: {str(e)}")
            continue
        valid_sequences.append(smiles)

    if not valid_sequences:
        raise ValueError("collate_fn: no sequence in the batch could be tokenized")

    tokens = tokenizer(valid_sequences, return_tensors='pt', padding=True, truncation=True, max_length=1035)

    # Token strings are needed to map character-level bond positions to tokens.
    token_array = tokenizer.get_token_split(tokens['input_ids'])
    bond_mask = peptide_token_mask(valid_sequences, token_array)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask'],
        'bond_mask': bond_mask,
    }
175
+
176
+
177
class CustomDataModule(pl.LightningDataModule):
    """Lightning data module serving pre-built train/validation datasets.

    Both loaders use ``batch_size`` items per batch, 8 workers, pinned memory,
    and ``collate_fn`` with ``tokenizer`` bound as a keyword argument.
    ``test_dataset`` is accepted for signature compatibility but not stored,
    and no ``test_dataloader`` is defined.
    """

    def __init__(self, train_dataset, val_dataset, test_dataset, tokenizer, batch_size, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        # test_dataset is intentionally not kept (test loading is disabled).
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn

    def _build_loader(self, dataset):
        """Create a ``DataLoader`` over ``dataset`` with the shared settings."""
        bound_collate = partial(self.collate_fn, tokenizer=self.tokenizer)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=bound_collate,
            num_workers=8,
            pin_memory=True,
        )

    def train_dataloader(self):
        return self._build_loader(self.train_dataset)

    def val_dataloader(self):
        return self._build_loader(self.val_dataset)
a2d2_pep/evaluate_peptide_table.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluate a finetuned peptide model checkpoint by sampling sequences
3
+ and computing metrics for the De Novo Peptide Generation table:
4
+ Validity (%), Affinity (↑), Solubility (↑), Hemolysis (↑),
5
+ Nonfouling (↑), Permeability (↑), Sampling Time (↓)
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import argparse
11
+ import time
12
+ import torch
13
+ import numpy as np
14
+ import pandas as pd
15
+
16
+ # add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
17
+ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
18
+ sys.path.insert(0, REPO_ROOT)
19
+
20
+ from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
21
+ from lightning_modules import AnyOrderInsertionFlowModule
22
+ from inference_quality import sample_peptides_eval
23
+ from pep_scoring.scoring_functions import ScoringFunctions
24
+ from pep_utils.analyzer import PeptideAnalyzer
25
+ from pep_scoring.tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
26
+ from finetune_quality import PeptideFinetuner
27
+ from pep_utils.utils import str2bool, set_seed
28
+ from tdc import Evaluator
29
+
30
+
31
+ # Protein sequences
32
+ PROTEINS = {
33
+ 'amhr': 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV',
34
+ 'tfr': 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF',
35
+ 'gfap': 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM',
36
+ 'glp1': 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS',
37
+ 'glast': 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM',
38
+ 'ncam': 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF',
39
+ 'cereblon': 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL',
40
+ 'ligase': 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS',
41
+ 'skp2': 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL',
42
+ }
43
+
44
+
45
def load_finetuned_model(checkpoint_path, pretrained_ckpt_path, device='cuda'):
    """Load a finetuned policy model from a Lightning checkpoint.

    The model configuration is read from the pretrained base checkpoint, the
    schedule/planner training options are copied from the finetuning ``args``
    stored in the finetuned checkpoint (with defaults when absent), and the
    ``policy_model.*`` weights of the finetuned checkpoint are loaded on top.

    Args:
        checkpoint_path: Path to the finetuned Lightning checkpoint.
        pretrained_ckpt_path: Path to the pretrained base checkpoint; must
            contain the config under ``hyper_parameters.config`` or ``config``.
        device: Device the model is moved to.

    Returns:
        Tuple ``(policy_model, args, config)``; the model is in eval mode.

    Raises:
        ValueError: If the base checkpoint has no config, or if the finetuned
            checkpoint contains no ``policy_model.*`` weights.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    hparams = ckpt.get('hyper_parameters', {})
    args = hparams.get('args', None)

    # Load pretrained base checkpoint to get config
    base_ckpt = torch.load(pretrained_ckpt_path, map_location='cpu', weights_only=False)
    if 'hyper_parameters' in base_ckpt:
        config = base_ckpt['hyper_parameters']['config']
    elif 'config' in base_ckpt:
        config = base_ckpt['config']
    else:
        raise ValueError("Cannot find config in base checkpoint")

    from omegaconf import OmegaConf, DictConfig
    if not OmegaConf.is_config(config):
        config = DictConfig(config)

    # Struct mode is lifted so new training keys can be added, then restored.
    # getattr on args=None simply yields the default.
    training_defaults = (
        ('use_adaptive_schedule', True),
        ('schedule_hidden_dim', 256),
        ('schedule_num_layers', 2),
        ('schedule_loss_weight', 0.1),
        ('freeze_base_model', False),
        ('schedule_warmup_epochs', 0),
    )
    OmegaConf.set_struct(config, False)
    for option, default in training_defaults:
        setattr(config.training, option, getattr(args, option, default))
    OmegaConf.set_struct(config, True)

    disable_planner = getattr(args, 'disable_planner', False)

    policy_model = AnyOrderInsertionFlowModuleFT(
        config=config,
        args=args,
        pretrained_checkpoint=pretrained_ckpt_path,
        insertion_planner=not disable_planner,
    )

    # Keep only the policy weights and strip the Lightning attribute prefix.
    prefix = 'policy_model.'
    policy_state = {
        key[len(prefix):]: value
        for key, value in ckpt['state_dict'].items()
        if key.startswith(prefix)
    }
    if not policy_state:
        # With strict=False an empty dict would load nothing and silently
        # evaluate the pretrained base model instead of the finetuned one.
        raise ValueError(
            f"No '{prefix}*' weights found in checkpoint '{checkpoint_path}'"
        )

    load_result = policy_model.load_state_dict(policy_state, strict=False)
    missing = list(getattr(load_result, 'missing_keys', ()))
    unexpected = list(getattr(load_result, 'unexpected_keys', ()))
    if missing or unexpected:
        # strict=False tolerates mismatches; surface them so they are not
        # mistaken for a clean load.
        print(f"Warning: finetuned weights loaded with {len(missing)} missing "
              f"and {len(unexpected)} unexpected keys")

    policy_model = policy_model.to(device)
    policy_model.eval()

    return policy_model, args, config
93
+
94
+
95
@torch.no_grad()
def evaluate_checkpoint(policy_model, tokenizer, reward_model, analyzer,
                        num_samples=1000, batch_size=50, max_length=512,
                        total_num_steps=256, quality_mode="both", num_remasking=3,
                        quality_threshold=0.5, unmask_quality_threshold=None, device='cuda'):
    """Sample peptides in batches and aggregate the evaluation-table metrics.

    ``sample_peptides_eval`` is called ``ceil(num_samples / batch_size)``
    times; only the wall-clock time spent inside those calls is accumulated.
    ``device`` is accepted but not used by this function.

    Returns:
        dict with, in this order: 'Validity (%)', 'Uniqueness (%)',
        'Diversity', mean and 'Std' for each of 'Affinity', 'Solubility',
        'Hemolysis', 'Nonfouling', 'Permeability', then 'Sampling Time (s)'
        (total over all batches), 'Num Generated', 'Num Valid', 'Num Unique'.
        Score statistics are 0.0 when no scores were collected.
    """
    score_names = ('Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability')
    collected = {name: [] for name in score_names}
    valid_pool = []
    n_valid = 0
    n_generated = 0
    elapsed_total = 0.0

    num_batches = (num_samples + batch_size - 1) // batch_size
    remaining = num_samples

    for b in range(num_batches):
        bs = min(batch_size, remaining)
        remaining -= bs

        t_start = time.time()
        result = sample_peptides_eval(
            model=policy_model,
            reward_model=reward_model,
            analyzer=analyzer,
            tokenizer=tokenizer,
            steps=total_num_steps,
            mask=policy_model.interpolant.mask_token,
            pad=policy_model.interpolant.pad_token,
            batch_size=bs,
            max_length=max_length,
            quality_mode=quality_mode,
            num_remasking=num_remasking,
            quality_threshold=quality_threshold,
            unmask_quality_threshold=unmask_quality_threshold,
            return_valid=True,
        )
        t_end = time.time()

        # result = (valid sequences, affinity, sol, hemo, nf, permeability,
        # valid_fraction); the trailing fraction is not needed here.
        valid_seqs, affinity, sol, hemo, nf, permeability, valid_fraction = result

        batch_valid = len(valid_seqs)
        n_valid += batch_valid
        n_generated += bs
        elapsed_total += (t_end - t_start)
        valid_pool.extend(valid_seqs)

        # Scores are only collected when the affinity container is a
        # non-empty list/array; each container is converted to a plain list.
        if isinstance(affinity, (list, np.ndarray)) and len(affinity) > 0:
            batch_scores = (affinity, sol, hemo, nf, permeability)
            for name, values in zip(score_names, batch_scores):
                collected[name].extend(values if isinstance(values, list) else values.tolist())

        print(f" Batch {b+1}/{num_batches}: {batch_valid}/{bs} valid, "
              f"time={t_end - t_start:.1f}s")

    validity = n_valid / n_generated * 100.0 if n_generated > 0 else 0.0

    # Uniqueness: share of valid sequences that are distinct.
    # Diversity: TDC 'diversity' evaluator over the distinct sequences
    # (same convention as evaluate_mol_table.py).
    distinct = list(set(valid_pool))
    num_unique = len(distinct)
    uniqueness = num_unique / n_valid * 100.0 if n_valid > 0 else 0.0
    diversity = Evaluator('diversity')(distinct) if num_unique > 1 else 0.0

    metrics = {
        'Validity (%)': validity,
        'Uniqueness (%)': uniqueness,
        'Diversity': diversity,
    }
    for name in score_names:
        values = collected[name]
        metrics[name] = np.mean(values) if values else 0.0
        metrics[f'{name} Std'] = np.std(values) if values else 0.0
    metrics['Sampling Time (s)'] = elapsed_total
    metrics['Num Generated'] = n_generated
    metrics['Num Valid'] = n_valid
    metrics['Num Unique'] = num_unique

    return metrics
193
+
194
+
195
def main():
    """Command-line entry point: evaluate a finetuned peptide checkpoint.

    Parses the CLI options, loads the finetuned policy model, samples peptides
    against the selected target protein, prints the metric table, and writes
    it as a one-row CSV named ``eval_metrics_<tag>_<prot_name>.csv`` into
    ``--output_dir`` (default: the checkpoint's directory).

    Raises:
        ValueError: If ``--prot_name`` is not a known protein and no
            ``--prot_seq`` is given.
    """
    parser = argparse.ArgumentParser(description="Evaluate a finetuned peptide checkpoint")
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to the finetuned Lightning checkpoint (e.g., last.ckpt)')
    parser.add_argument('--pretrained_ckpt', type=str,
                        default=os.path.join(REPO_ROOT, 'pretrained', 'anylength_pep.ckpt'),
                        help='Path to the pretrained base model checkpoint')
    parser.add_argument('--num_samples', type=int, default=500,
                        help='Number of peptides to sample')
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Batch size for sampling')
    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--total_num_steps', type=int, default=256)
    parser.add_argument('--num_remasking', type=int, default=3)
    parser.add_argument('--quality_threshold', type=float, default=0.5,
                        help='Threshold for insertion quality filtering during sampling')
    parser.add_argument('--unmask_quality_threshold', type=float, default=None,
                        help='If set, gate unmasking/remasking by confidence: remask '
                             'ALL clean tokens whose unmasking confidence is below this '
                             'threshold, regardless of the schedule budget. If unset '
                             '(default), remasking is purely schedule-driven (count-based).')
    parser.add_argument('--prot_name', type=str, default='glast',
                        help='Target protein name (must be one of: ' + ', '.join(PROTEINS.keys()) + ')')
    parser.add_argument('--prot_seq', type=str, default=None,
                        help='Custom protein sequence (overrides --prot_name)')
    parser.add_argument('--disable_planner', action='store_true',
                        help='If set, disable remasking during evaluation')
    parser.add_argument('--disable_insertion_planner', action='store_true',
                        help='If set, disable insertion quality filtering during evaluation')
    parser.add_argument('--disable_unmasking_planner', action='store_true',
                        help='If set, disable unmasking confidence planner during evaluation')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Directory to save results CSV. Defaults to checkpoint directory.')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    set_seed(args.seed, use_cuda=True)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # Map flags to quality_mode
    if args.disable_planner:
        quality_mode = "none"
    elif args.disable_insertion_planner and args.disable_unmasking_planner:
        quality_mode = "none"
    elif args.disable_insertion_planner:
        quality_mode = "unmasking_only"
    elif args.disable_unmasking_planner:
        quality_mode = "insertion_only"
    else:
        quality_mode = "both"

    print(f"Loading checkpoint: {args.checkpoint_path}")
    print(f"Pretrained base: {args.pretrained_ckpt}")
    print(f"Quality mode: {quality_mode}")

    # Only the model is needed here; the stored training args and config are
    # returned by the loader but unused.
    policy_model, _train_args, _config = load_finetuned_model(
        args.checkpoint_path, args.pretrained_ckpt, device=device
    )

    # Setup tokenizer, reward model, analyzer
    tokenizer = SMILES_SPE_Tokenizer(
        os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer', 'new_vocab.txt'),
        os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer', 'new_splits.txt')
    )

    # With a custom --prot_seq, prot_name is only a label for logs and the
    # output file name, so it is not validated against PROTEINS.
    prot_name = args.prot_name
    if args.prot_seq is not None:
        prot = args.prot_seq
    else:
        if prot_name not in PROTEINS:
            raise ValueError(f"Unknown protein: {prot_name}. Choose from: {list(PROTEINS.keys())}")
        prot = PROTEINS[prot_name]

    score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
    reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device=device)
    analyzer = PeptideAnalyzer()

    print(f"\nSampling {args.num_samples} peptides (quality_mode={quality_mode}, target={prot_name})...")

    metrics = evaluate_checkpoint(
        policy_model=policy_model,
        tokenizer=tokenizer,
        reward_model=reward_model,
        analyzer=analyzer,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        max_length=args.max_length,
        total_num_steps=args.total_num_steps,
        quality_mode=quality_mode,
        num_remasking=args.num_remasking,
        quality_threshold=args.quality_threshold,
        unmask_quality_threshold=args.unmask_quality_threshold,
        device=device,
    )

    # Print summary table
    print("\n" + "=" * 60)
    print(" De Novo Peptide Generation Results")
    print("=" * 60)
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f" {k:<30s}: {v:.4f}")
        else:
            print(f" {k:<30s}: {v}")
    print("=" * 60)

    # Save results. dirname() is '' for a bare file name, and os.makedirs('')
    # raises, so fall back to the current directory.
    output_dir = args.output_dir or os.path.dirname(args.checkpoint_path) or '.'
    os.makedirs(output_dir, exist_ok=True)

    # Derive the file tag from the effective quality mode so that disabling
    # both sub-planners (mode "none") is not mislabelled as
    # "no_insertion_planner" and cannot overwrite that run's results.
    tag = {
        "none": "no_planner",
        "unmasking_only": "no_insertion_planner",
        "insertion_only": "no_unmasking_planner",
        "both": "with_planner",
    }[quality_mode]
    if args.unmask_quality_threshold is not None:
        tag += f"_ut{args.unmask_quality_threshold:g}"
    # Record the sweep parameter in the saved row for traceability.
    metrics['unmask_quality_threshold'] = args.unmask_quality_threshold
    metrics['quality_threshold'] = args.quality_threshold
    metrics_path = os.path.join(output_dir, f'eval_metrics_{tag}_{prot_name}.csv')
    pd.DataFrame([metrics]).to_csv(metrics_path, index=False)
    print(f"Metrics saved to: {metrics_path}")
323
+
324
+
325
# Run the evaluation CLI only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()
a2d2_pep/finetune_quality.py ADDED
@@ -0,0 +1,892 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Distributed Data Parallel (DDP) finetuning for peptide generation using PyTorch Lightning
2
+ import argparse
3
+ import math
4
+ from datetime import datetime
5
+ import numpy as np
6
+ import torch
7
+ import pytorch_lightning as pl
8
+ from pytorch_lightning.strategies import DDPStrategy
9
+ from pytorch_lightning.callbacks import ModelCheckpoint
10
+ from pytorch_lightning.loggers import WandbLogger
11
+ import wandb
12
+ import os
13
+ import sys
14
+ from tqdm import tqdm
15
+ import pandas as pd
16
+
17
+ # add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
18
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
19
+
20
+ from inference_quality import sample_peptides_buffer, sample_peptides_eval
21
+ from pep_utils.analyzer import PeptideAnalyzer
22
+ from pep_utils.utils import str2bool, set_seed
23
+ from pep_scoring.scoring_functions import ScoringFunctions
24
+ from pep_scoring.tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
25
+ from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
26
+ from lightning_modules import AnyOrderInsertionFlowModule
27
+ from tdc import Evaluator
28
+
29
+ # Repository root (two levels up from this file: A2D2/a2d2_pep/finetune_quality.py)
30
+ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
31
+
32
+ class PeptideFinetuner(pl.LightningModule):
33
+ """Lightning module for distributed peptide finetuning."""
34
+
35
def __init__(
    self,
    args,
    policy_model,
    reward_model,
    tokenizer,
    pretrained=None,
    mcts=None,
    filename=None,
    prot_name=None,
    eps=1e-5
):
    """Store the finetuning components and initialise buffers and logs.

    Args:
        args: Run options; read here for ``min_peptide_bonds`` (default 4)
            and ``alternation_frequency`` (default 1) when present.
        policy_model: Model being finetuned; stored as ``self.policy_model``.
        reward_model: Stored as ``self.reward_model``.
        tokenizer: Stored as ``self.tokenizer``.
        pretrained: Optional; stored as ``self.pretrained``.
        mcts: Optional; stored as ``self.mcts``.
        filename: Optional; stored as ``self.filename``.
        prot_name: Optional; stored as ``self.prot_name``.
        eps: Stored as ``self.eps``.
    """
    super().__init__()
    # NOTE: keep the assignment order below stable -- attribute assignment
    # order determines the order in which any module-typed attributes are
    # registered on this module.
    self.args = args
    self.policy_model = policy_model
    self.reward_model = reward_model
    self.tokenizer = tokenizer
    self.pretrained = pretrained
    self.mcts = mcts
    self.filename = filename
    self.prot_name = prot_name
    self.eps = eps

    # Length cutoff is tunable from the CLI: --min_peptide_bonds N enforces
    # >=N peptide bonds (filters degenerate short reward-hacked molecules);
    # --min_peptide_bonds 0 disables the cutoff.
    min_bonds = getattr(args, 'min_peptide_bonds', 4)
    self.analyzer = PeptideAnalyzer(
        min_peptide_bonds=max(0, min_bonds),  # clamp negative values to 0
        enforce_min_peptide_bonds=min_bonds > 0,
    )

    # Save hyperparameters; the listed constructor arguments are excluded
    # from what gets recorded.
    self.save_hyperparameters(ignore=['policy_model', 'reward_model', 'tokenizer', 'pretrained', 'mcts'])

    # Buffer for sequences (all start empty)
    self.x_saved = None
    self.log_rnd_saved = None
    self.final_rewards_saved = None

    # Per-metric history lists
    self.valid_fraction_log = []
    self.uniqueness_log = []
    self.diversity_log = []
    self.affinity_log = []
    self.sol_log = []
    self.hemo_log = []
    self.nf_log = []
    self.permeability_log = []
    # Built once here so it can be reused instead of re-created per use.
    self._diversity_evaluator = Evaluator('diversity')

    # Alternating training between policy and planner
    self.train_policy = True  # Start by training policy
    self.alternation_frequency = getattr(args, 'alternation_frequency', 1)  # Alternate every N epochs
89
+
90
def freeze_policy_model(self):
    """Turn off gradients for all policy parameters outside ``planner.``.

    Parameters whose name starts with ``'planner.'`` are left untouched.
    """
    non_planner_weights = (
        weight
        for param_name, weight in self.policy_model.named_parameters()
        if not param_name.startswith('planner.')
    )
    for weight in non_planner_weights:
        weight.requires_grad = False
95
+
96
def unfreeze_policy_model(self):
    """Turn gradients back on for all policy parameters outside ``planner.``.

    Parameters whose name starts with ``'planner.'`` are left untouched.
    """
    non_planner_weights = (
        weight
        for param_name, weight in self.policy_model.named_parameters()
        if not param_name.startswith('planner.')
    )
    for weight in non_planner_weights:
        weight.requires_grad = True
101
+
102
def freeze_planner_model(self):
    """Turn off gradients for the planner's parameters.

    Does nothing when the policy model has no ``planner`` attribute.
    """
    if not hasattr(self.policy_model, 'planner'):
        return
    for weight in self.policy_model.planner.parameters():
        weight.requires_grad = False
107
+
108
def unfreeze_planner_model(self):
    """Turn gradients back on for the planner's parameters.

    Does nothing when the policy model has no ``planner`` attribute.
    """
    if not hasattr(self.policy_model, 'planner'):
        return
    for weight in self.policy_model.planner.parameters():
        weight.requires_grad = True
113
+
114
def configure_optimizers(self):
    """Build one AdamW optimizer with separate groups for policy and planner.

    Group 0 holds every policy-model parameter whose name does not start with
    ``'planner.'`` and uses ``args.learning_rate``. Group 1 holds the
    ``'planner.'`` parameters and uses ``args.planner_learning_rate``.

    The planner learning rate falls back to ``args.learning_rate`` both when
    the attribute is missing and when it is ``None`` (the argparse default),
    so the module does not depend on the caller having normalised the value.

    Returns:
        The configured ``torch.optim.AdamW`` instance.
    """
    base_lr = self.args.learning_rate
    planner_lr = getattr(self.args, 'planner_learning_rate', None)
    if planner_lr is None:
        planner_lr = base_lr

    planner_params = []
    policy_params = []
    for name, param in self.policy_model.named_parameters():
        target = planner_params if name.startswith('planner.') else policy_params
        target.append(param)

    # Both groups are always emitted, in this order, so the optimizer's
    # group layout stays the same whether or not a planner exists.
    return torch.optim.AdamW([
        {'params': policy_params, 'lr': base_lr},
        {'params': planner_params, 'lr': planner_lr},
    ])
131
+
132
def _get_quality_mode(self):
    """Translate the ablation flags and warmup state into a quality_mode string.

    Returns ``"none"`` when the planner is disabled, while the current epoch
    is still below ``args.schedule_warmup_epochs``, or when both planner
    heads are ablated. Otherwise returns ``"unmasking_only"``,
    ``"insertion_only"`` or ``"both"`` depending on which head is ablated.
    """
    args = self.args
    if args.disable_planner or self.current_epoch < args.schedule_warmup_epochs:
        return "none"

    skip_insertion = bool(getattr(args, 'disable_insertion_planner', False))
    skip_unmasking = bool(getattr(args, 'disable_unmasking_planner', False))

    # Keyed by (insertion head ablated, unmasking head ablated).
    mode_by_ablation = {
        (True, True): "none",
        (True, False): "unmasking_only",
        (False, True): "insertion_only",
        (False, False): "both",
    }
    return mode_by_ablation[(skip_insertion, skip_unmasking)]
147
+
148
def on_save_checkpoint(self, checkpoint):
    """Copy extra policy-model state into the checkpoint dict.

    Adds, when available on ``self.policy_model``:
      * ``checkpoint['config']``        - the model's ``config`` attribute,
      * ``checkpoint['ema_params']``    - its ``ema_params`` (only if truthy),
      * ``checkpoint['planner_state']`` - ``planner.state_dict()``.

    Storing these directly means a loader can read them from the checkpoint
    itself. A short message is printed for each entry that is written.
    """
    model = self.policy_model

    if hasattr(model, 'config'):
        checkpoint['config'] = model.config
        print("Saved config to checkpoint for easier loading")

    # Missing and empty EMA params are both skipped.
    if getattr(model, 'ema_params', None):
        checkpoint['ema_params'] = model.ema_params
        print("Saved EMA params to checkpoint")

    if hasattr(model, 'planner'):
        checkpoint['planner_state'] = model.planner.state_dict()
        print("Saved planner state to checkpoint")
167
+
168
def on_train_epoch_start(self):
    """Choose what to train this epoch and refill the sample buffer if due.

    Three training modes, checked in this order:
      * ``args.disable_planner``: policy only; the planner stays frozen.
      * ``args.joint_training``: policy and planner are both unfrozen.
      * otherwise: alternate between policy and planner every
        ``self.alternation_frequency`` epochs, starting with the policy.

    Afterwards the buffer is regenerated when it does not exist yet or when
    the epoch index is a multiple of ``args.resample_every_n_step``.
    """
    is_rank_zero = self.global_rank == 0
    first_epoch_banner = is_rank_zero and self.current_epoch == 0

    if self.args.disable_planner:
        # Policy-only mode: no alternation.
        self.train_policy = True
        self.unfreeze_policy_model()
        self.freeze_planner_model()
        if first_epoch_banner:
            print("[FINETUNE_QUALITY] Training ONLY policy model (planner frozen, no remasking)")
    elif getattr(self.args, 'joint_training', False):
        # train_policy is only a marker here; training_step adds the planner
        # loss on top when joint_training is set.
        self.train_policy = True
        self.unfreeze_policy_model()
        self.unfreeze_planner_model()
        if first_epoch_banner:
            print("[FINETUNE_QUALITY] JOINT TRAINING: policy + planner trained together (no alternation)")
    else:
        # Even phases train the policy, odd phases train the planner.
        phase = (self.current_epoch // self.alternation_frequency) % 2
        self.train_policy = (phase == 0)

        if not self.train_policy:
            self.freeze_policy_model()
            self.unfreeze_planner_model()
            if is_rank_zero:
                print(f"[ALTERNATION] Epoch {self.current_epoch}: Training PLANNER model (policy frozen)")
        else:
            self.unfreeze_policy_model()
            self.freeze_planner_model()
            if is_rank_zero:
                print(f"[ALTERNATION] Epoch {self.current_epoch}: Training POLICY model (planner frozen)")

    refill_due = (
        self.x_saved is None
        or self.current_epoch % self.args.resample_every_n_step == 0
    )
    if refill_due:
        self._generate_buffer()
        # Synchronize all ranks after buffer generation to prevent NCCL timeout.
        if self.trainer and self.trainer.world_size > 1:
            torch.distributed.barrier()
209
+
210
def _generate_buffer(self):
    """Generate buffer of sequences for training - all ranks generate in parallel.

    When pool_size > 0, maintains a persistent pool and refreshes a fraction
    each time instead of regenerating the entire buffer from scratch. This
    preserves diversity/uniqueness across training by avoiding wholesale
    replacement with samples from an increasingly mode-collapsed policy.

    Sampling is repeated until this rank has collected its quota of valid
    sequences, ``args.max_buffer_attempts`` rounds have run, or the
    starvation guard fires (no valid sample at all after
    ``args.buffer_starvation_patience`` rounds).

    Side effects:
        Updates ``self.x_saved``, ``self.log_rnd_saved`` and
        ``self.final_rewards_saved``.

    Raises:
        RuntimeError: if this rank ends up with zero valid sequences.
    """
    import time

    world_size = self.trainer.world_size if self.trainer else 1
    rank = self.global_rank if self.trainer else 0

    pool_size = getattr(self.args, 'pool_size', 0)
    is_pool = pool_size > 0
    is_init = self.x_saved is None

    # Determine how many sequences to sample this call
    if is_pool:
        refresh_frac = getattr(self.args, 'pool_refresh_fraction', 0.2)
        if is_init:
            samples_per_gpu = pool_size
        else:
            samples_per_gpu = max(1, int(pool_size * refresh_frac))
        if rank == 0:
            if is_init:
                print(f"\n[POOL] Initializing pool with {pool_size} sequences at epoch {self.current_epoch}")
            else:
                print(f"\n[POOL] Refreshing {samples_per_gpu}/{pool_size} sequences ({refresh_frac*100:.0f}%) at epoch {self.current_epoch}")
    else:
        samples_per_gpu = self.args.buffer_size // world_size
        if rank == 0:
            # Rank 0 absorbs the remainder of the integer split.
            samples_per_gpu += self.args.buffer_size % world_size
        # When buffer_size < world_size the integer split gives non-zero
        # ranks a quota of 0; the sampling loop would then never run and the
        # rank would abort below with a misleading "no valid sequences"
        # error. Every rank needs at least one sample to train on.
        samples_per_gpu = max(1, samples_per_gpu)

    accumulated_x = []
    accumulated_log_rnd = []
    accumulated_rewards = []
    total_accumulated = 0

    if rank == 0:
        print(f"\n[BUFFER] Starting buffer generation at epoch {self.current_epoch}")
        print(f"[BUFFER] GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"[BUFFER] GPU memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
        if not is_pool:
            print(f"[BUFFER] Each of {world_size} ranks will generate {samples_per_gpu} samples")

    max_attempts = getattr(self.args, 'max_buffer_attempts', 100)  # cap wasted GPU / infinite loop
    starvation_patience = getattr(self.args, 'buffer_starvation_patience', 10)
    attempts = 0

    # Keyword arguments shared by both sampling paths; only compute_rnd and
    # pretrained differ between them.
    # Buffer generation never uses the quality heads (planner): the backbone
    # must train on raw policy samples so that a poorly-trained planner can't
    # corrupt the backbone's data.
    sampler_kwargs = dict(
        steps=self.args.total_num_steps,
        mask=self.policy_model.interpolant.mask_token,
        pad=self.policy_model.interpolant.pad_token,
        batch_size=self.args.batch_size,
        max_length=self.args.max_length,
        quality_mode="none",
        alpha=self.args.alpha,
        num_remasking=self.args.num_remasking,
        min_length=self.args.min_length,
    )

    while total_accumulated < samples_per_gpu and attempts < max_attempts:
        attempts += 1
        if rank == 0:
            print(f"[BUFFER] rank={rank} starting sampling attempt {attempts} at {time.strftime('%H:%M:%S')}")

        start_time = time.time()

        if self.args.elbo_rnd:
            # ELBO path: sample without trajectory log_rnd, then estimate
            # log_rnd from the loss gap between pretrained and policy models
            # on the same noised samples, plus the scaled reward.
            x_final, _, final_rewards, _ = sample_peptides_buffer(
                self.policy_model,
                self.reward_model, self.analyzer,
                self.tokenizer,
                compute_rnd=False,
                **sampler_kwargs,
            )
            if x_final.shape[0] > 0:
                with torch.no_grad():
                    noised = self.policy_model.prepare_noised_sample(
                        x_final, num_samples=self.args.elbo_rnd_num_samples)
                    policy_loss = self.policy_model.compute_loss_from_noised(noised)
                    pretrained_loss = self.pretrained.compute_loss_from_noised(noised)
                    log_rnd = (pretrained_loss - policy_loss) + (final_rewards / self.args.alpha)
            else:
                log_rnd = torch.empty((0,), dtype=torch.float32, device=x_final.device)
        else:
            # Trajectory path: the sampler returns log_rnd itself.
            x_final, log_rnd, final_rewards, _ = sample_peptides_buffer(
                self.policy_model,
                self.reward_model, self.analyzer,
                self.tokenizer,
                compute_rnd=True,
                pretrained=self.pretrained,
                **sampler_kwargs,
            )

        elapsed = time.time() - start_time
        if rank == 0:
            print(f"[BUFFER] rank={rank} sampling took {elapsed:.1f}s")

        n_valid = x_final.shape[0]
        if n_valid > 0:
            accumulated_x.append(x_final)
            accumulated_log_rnd.append(log_rnd)
            accumulated_rewards.append(final_rewards)
            total_accumulated += n_valid

        if rank == 0:
            print(f"[BUFFER] rank={rank} epoch={self.current_epoch} quality_mode=none (heads disabled for buffer gen) accumulated={total_accumulated} / {samples_per_gpu} (batch yielded {n_valid} valid) attempt={attempts}")

        # Starvation guard: if nothing valid comes through (e.g. the length
        # cutoff is too aggressive for a collapsed policy), stop grinding GPU
        # hours and fail fast with an actionable message.
        if attempts >= starvation_patience and total_accumulated == 0:
            if rank == 0:
                print(f"[BUFFER STARVATION] 0 valid samples after {attempts} attempts "
                      f"(min_peptide_bonds={getattr(self.args, 'min_peptide_bonds', 4)}). "
                      f"Aborting refill early — lower --min_peptide_bonds or check the policy.")
            break

    if total_accumulated == 0:
        raise RuntimeError(f"[BUFFER ERROR] Rank {rank}: No valid sequences generated after {attempts} attempts. Check sampling function and reward model.")

    if total_accumulated < samples_per_gpu:
        print(f"[BUFFER WARNING] Rank {rank}: Only generated {total_accumulated}/{samples_per_gpu} sequences after {attempts} attempts")

    # The last batch may overshoot the quota; trim to exactly samples_per_gpu.
    new_x = torch.cat(accumulated_x, dim=0)[:samples_per_gpu]
    new_log_rnd = torch.cat(accumulated_log_rnd, dim=0)[:samples_per_gpu]
    new_rewards = torch.cat(accumulated_rewards, dim=0)[:samples_per_gpu]

    del accumulated_x, accumulated_log_rnd, accumulated_rewards
    torch.cuda.empty_cache()

    # Pool mode (after init): replace a random subset of the existing pool.
    # Classic mode / pool init: overwrite the buffer.
    if is_pool and not is_init:
        actual_new = min(new_x.shape[0], self.x_saved.shape[0])
        indices = torch.randperm(self.x_saved.shape[0], device=self.x_saved.device)[:actual_new]
        self.x_saved[indices] = new_x[:actual_new]
        self.log_rnd_saved[indices] = new_log_rnd[:actual_new]
        self.final_rewards_saved[indices] = new_rewards[:actual_new]
        if rank == 0:
            print(f"[POOL] Replaced {actual_new}/{self.x_saved.shape[0]} sequences, reward mean={self.final_rewards_saved.mean():.4f}")
    else:
        self.x_saved = new_x
        self.log_rnd_saved = new_log_rnd
        self.final_rewards_saved = new_rewards

    # Sanity check: median length (non-pad tokens) of buffered peptides.
    if rank == 0:
        pad = self.policy_model.interpolant.pad_token
        token_lens = (self.x_saved != pad).sum(dim=1)
        print(f"[BUFFER] peptide token length: median={token_lens.median().item()} "
              f"min={token_lens.min().item()} max={token_lens.max().item()} "
              f"(n={token_lens.shape[0]})")
376
+
377
def training_step(self, batch, batch_idx):
    """Run one optimisation step on a mini-batch drawn from the saved buffer.

    The dataloader ``batch`` is ignored. Up to
    ``args.training_mini_batch_size`` rows (default 6) are drawn at random
    from ``self.x_saved`` / ``self.log_rnd_saved``; a smaller buffer is used
    whole.

    Loss selection:
      * ``self.train_policy``: WDCE policy loss.
      * planner phase (``not self.train_policy``) or joint training: planner
        loss, chosen by the insertion/unmasking ablation flags.
      * joint training returns the sum of both.

    ``--joint_training`` is ignored when ``--disable_planner`` is set, which
    matches ``on_train_epoch_start`` keeping the planner frozen in that case.
    The ablation flags are read with ``getattr`` defaults, as in
    ``_get_quality_mode``.

    Returns:
        The scalar loss to back-propagate.
    """
    args = self.args

    # Use mini-batch sampling from buffer to avoid OOM
    buffer_size = self.x_saved.shape[0]
    mini_batch_size = getattr(args, 'training_mini_batch_size', 6)
    if buffer_size > mini_batch_size:
        indices = torch.randperm(buffer_size, device=self.x_saved.device)[:mini_batch_size]
        x_final = self.x_saved[indices]
        log_rnd = self.log_rnd_saved[indices]
    else:
        x_final = self.x_saved
        log_rnd = self.log_rnd_saved

    planner_disabled = getattr(args, 'disable_planner', False)
    joint = getattr(args, 'joint_training', False) and not planner_disabled
    skip_insertion = getattr(args, 'disable_insertion_planner', False)
    skip_unmasking = getattr(args, 'disable_unmasking_planner', False)

    loss_kwargs = dict(
        num_replicates=args.wdce_num_replicates,
        centering=args.centering,
        centering_strength=args.centering_strength,
    )
    log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

    policy_loss = None
    planner_loss = None

    if self.train_policy:
        # Train policy with WDCE loss
        policy_loss = self.policy_model.loss_wdce_flexible(log_rnd, x_final, **loss_kwargs)
        self.log('train/policy_loss', policy_loss, **log_kwargs)

    if (not self.train_policy) or joint:
        if skip_insertion:
            # Ablation: only train unmasking/remasking planner (no insertion head)
            planner_loss = self.policy_model.loss_planner_flexible(log_rnd, x_final, **loss_kwargs)
            unmask_loss, insert_loss = planner_loss, 0.0
        elif skip_unmasking:
            # Ablation: only the insertion loss is back-propagated; the
            # unmasking component is logged as zero.
            _, insert_loss, _ = self.policy_model.loss_insert_planner_flexible(
                log_rnd, x_final, **loss_kwargs)
            planner_loss = insert_loss
            unmask_loss = 0.0
        else:
            # Full planner: train both remasking + insertion
            unmask_loss, insert_loss, planner_loss = \
                self.policy_model.loss_insert_planner_flexible(log_rnd, x_final, **loss_kwargs)

        self.log('train/planner_unmask_loss', unmask_loss, **log_kwargs)
        self.log('train/planner_insert_loss', insert_loss, **log_kwargs)
        self.log('train/planner_loss', planner_loss, **log_kwargs)

    # Combine losses depending on mode; mode_value encodes the mode for
    # logging (0.0 = policy, 1.0 = planner, 0.5 = joint).
    if joint:
        loss = policy_loss + planner_loss
        mode_value = 0.5
    elif self.train_policy:
        loss = policy_loss
        mode_value = 0.0
    else:
        loss = planner_loss
        mode_value = 1.0

    # Log overall loss and mode
    self.log('train_loss', loss, **log_kwargs)
    self.log('train/mode', mode_value, prog_bar=True, sync_dist=True)

    return loss
466
+
467
def on_train_epoch_end(self):
    """Evaluate the current policy on rank 0 every few epochs.

    Runs when the epoch index is a multiple of ``args.eval_every_n_epochs``
    (default 5) or on the final epoch. A 50-sample evaluation batch is drawn,
    uniqueness and diversity are computed from the valid sequences, all
    metrics are appended to the in-memory logs, and a summary is logged
    together with reward statistics of the current training buffer.
    """
    eval_frequency = getattr(self.args, 'eval_every_n_epochs', 5)
    is_last_epoch = (self.trainer and self.current_epoch == self.trainer.max_epochs - 1)

    # Only rank 0 evaluates, and only on scheduled epochs.
    if self.global_rank != 0:
        return
    if not (self.current_epoch % eval_frequency == 0 or is_last_epoch):
        return

    # Sample eval batch with updated policy
    eval_outputs = sample_peptides_eval(
        self.policy_model, self.reward_model, self.analyzer,
        self.tokenizer,
        steps=self.args.total_num_steps,
        mask=self.policy_model.interpolant.mask_token,
        pad=self.policy_model.interpolant.pad_token,
        batch_size=50,
        max_length=self.args.max_length,
        quality_mode=self._get_quality_mode(),
        num_remasking=self.args.num_remasking,
        return_valid=True,
    )
    valid_seqs, affinity, sol, hemo, nf, permeability, valid_fraction = eval_outputs

    # Uniqueness (% of valid that are unique) and Diversity
    # (1 - mean pairwise Tanimoto on Morgan FPs of unique sequences),
    # matching evaluate_peptide_table.py / evaluate_mol_table.py.
    num_valid = len(valid_seqs)
    unique_seqs = list(set(valid_seqs))
    num_unique = len(unique_seqs)
    if num_valid > 0:
        uniqueness = num_unique / num_valid * 100.0
    else:
        uniqueness = 0.0
    if num_unique > 1:
        diversity = self._diversity_evaluator(unique_seqs)
    else:
        diversity = 0.0

    # Record this evaluation round in the per-metric histories.
    for history, value in (
        (self.affinity_log, affinity),
        (self.sol_log, sol),
        (self.hemo_log, hemo),
        (self.nf_log, nf),
        (self.permeability_log, permeability),
        (self.valid_fraction_log, valid_fraction),
        (self.uniqueness_log, uniqueness),
        (self.diversity_log, diversity),
    ):
        history.append(value)

    # Reward statistics of the buffer currently used for training.
    rewards = self.final_rewards_saved
    mean_reward = rewards.mean().item()
    min_reward = rewards.min().item()
    max_reward = rewards.max().item()
    median_reward = rewards.median().item()

    mean_affinity = np.mean(affinity)
    mean_sol = np.mean(sol)
    mean_hemo = np.mean(hemo)
    mean_nf = np.mean(nf)
    mean_permeability = np.mean(permeability)

    self.log_dict({
        "eval/affinity": mean_affinity,
        "eval/sol": mean_sol,
        "eval/hemo": mean_hemo,
        "eval/nf": mean_nf,
        "eval/permeability": mean_permeability,
        "eval/valid_fraction": valid_fraction,
        "eval/uniqueness": uniqueness,
        "eval/diversity": diversity,
        "eval/mean_reward_search": mean_reward,
        "eval/min_reward_search": min_reward,
        "eval/max_reward_search": max_reward,
        "eval/median_reward_search": median_reward
    })

    print(f"epoch {self.current_epoch} | affinity {mean_affinity:.4f} | "
          f"sol {mean_sol:.4f} | hemo {mean_hemo:.4f} | "
          f"nf {mean_nf:.4f} | permeability {mean_permeability:.4f} | "
          f"valid {valid_fraction:.4f} | uniq {uniqueness:.2f}% | div {diversity:.4f}")
533
+
534
def on_fit_end(self):
    """Write the metric history and a final generation table (rank 0 only).

    Creates ``<base_path>/results/<run_name>``, stores the accumulated
    evaluation logs there as ``log_<filename>.csv``, then samples one more
    evaluation batch and saves its dataframe as
    ``<prot_name>_generation_results.csv`` in the same directory.
    """
    if self.global_rank != 0:
        return

    plot_path = f'{self.args.base_path}/results/{self.args.run_name}'
    os.makedirs(plot_path, exist_ok=True)

    # Persist the per-evaluation metric history.
    output_log_path = f'{plot_path}/log_{self.filename}.csv'
    save_logs_to_file(
        self.valid_fraction_log,
        self.affinity_log,
        self.sol_log,
        self.hemo_log,
        self.nf_log,
        self.permeability_log,
        output_log_path,
        uniqueness_log=self.uniqueness_log,
        diversity_log=self.diversity_log,
    )

    # Final generation: only the returned dataframe is kept.
    _x_eval, _affinity, _sol, _hemo, _nf, _permeability, _valid_fraction, df = \
        sample_peptides_eval(
            self.policy_model, self.reward_model, self.analyzer,
            self.tokenizer,
            steps=self.args.total_num_steps,
            mask=self.policy_model.interpolant.mask_token,
            pad=self.policy_model.interpolant.pad_token,
            batch_size=50,
            max_length=self.args.max_length,
            quality_mode=self._get_quality_mode(),
            num_remasking=self.args.num_remasking,
            dataframe=True,
        )
    df.to_csv(f'{plot_path}/{self.prot_name}_generation_results.csv', index=False)
564
+
565
+
566
def save_logs_to_file(valid_fraction_log, affinity_log,
                      sol_log, hemo_log, nf_log,
                      permeability_log, output_path,
                      uniqueness_log=None, diversity_log=None):
    """
    Saves the logs to a CSV file.

    Each positional log is written as one column, with a 1-based
    ``Iteration`` column whose length is taken from ``valid_fraction_log``.
    The ``Uniqueness (%)`` and ``Diversity`` columns are added only when the
    corresponding log is provided.

    The parent directory of ``output_path`` is created when needed. A bare
    filename (no directory part) is written to the current working directory.
    """
    parent_dir = os.path.dirname(output_path)
    # os.makedirs('') raises FileNotFoundError, so skip it for bare filenames.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    log_data = {
        "Iteration": list(range(1, len(valid_fraction_log) + 1)),
        "Valid Fraction": valid_fraction_log,
        "Binding Affinity": affinity_log,
        "Solubility": sol_log,
        "Hemolysis": hemo_log,
        "Nonfouling": nf_log,
        "Permeability": permeability_log,
    }
    if uniqueness_log is not None:
        log_data["Uniqueness (%)"] = uniqueness_log
    if diversity_log is not None:
        log_data["Diversity"] = diversity_log

    pd.DataFrame(log_data).to_csv(output_path, index=False)
591
+
592
+
593
class DummyDataset(torch.utils.data.Dataset):
    """Dummy dataset for Lightning trainer (we use buffer instead).

    Every valid index returns the same placeholder tensor; the dataset only
    exists so the trainer has something of length ``size`` to iterate over.
    """

    def __init__(self, size=10):
        # Number of placeholder items reported by __len__.
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Out-of-range indices must raise IndexError: Python's fallback
        # iteration protocol (``for x in ds`` / ``list(ds)``) stops only on
        # that exception, so without it plain iteration never terminates.
        if not -self.size <= idx < self.size:
            raise IndexError(f"index {idx} out of range for DummyDataset of size {self.size}")
        return torch.zeros(1)  # Dummy data
603
+
604
+
605
+ def main():
606
+ """Main entry point for distributed training."""
607
+ # Disable DDP optimizer for higher-order ops like flex_attention
608
+ import torch._dynamo
609
+ torch._dynamo.config.optimize_ddp = False
610
+
611
+ argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
612
+ argparser.add_argument('--base_path', type=str, default=REPO_ROOT)
613
+ argparser.add_argument('--learning_rate', type=float, default=1e-4)
614
+ argparser.add_argument('--num_epochs', type=int, default=100)
615
+ argparser.add_argument('--num_accum_steps', type=int, default=4)
616
+ argparser.add_argument('--truncate_steps', type=int, default=50)
617
+ argparser.add_argument("--truncate_kl", type=str2bool, default=False)
618
+ argparser.add_argument('--gumbel_temp', type=float, default=1.0)
619
+ argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
620
+ argparser.add_argument('--batch_size', type=int, default=50)
621
+ argparser.add_argument('--name', type=str, default='debug')
622
+ argparser.add_argument('--total_num_steps', type=int, default=128)
623
+ argparser.add_argument('--copy_flag_temp', type=float, default=None)
624
+ argparser.add_argument('--save_every_n_epochs', type=int, default=10)
625
+ argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
626
+ argparser.add_argument("--seed", type=int, default=0)
627
+ # new
628
+ argparser.add_argument('--run_name', type=str, default='peptides')
629
+ argparser.add_argument("--save_path_dir", default=os.path.join(REPO_ROOT, "checkpoints", "finetune_peptides"), type=str)
630
+ # mcts
631
+ argparser.add_argument('--num_sequences', type=int, default=10)
632
+ argparser.add_argument('--max_length', type=int, default=1024)
633
+ argparser.add_argument('--min_length', type=int, default=0,
634
+ help='Minimum sequence length (in SMILES SPE tokens). '
635
+ 'Samples shorter than this are dropped from the buffer. 0 disables the filter.')
636
+ argparser.add_argument('--num_children', type=int, default=50)
637
+ argparser.add_argument('--num_iter', type=int, default=30)
638
+ argparser.add_argument('--seq_length', type=int, default=1024)
639
+ argparser.add_argument('--time_conditioning', action='store_true', default=False)
640
+ argparser.add_argument('--mcts_sampling', type=int, default=0) # for batched categorical sampling: '0' means gumbel noise
641
+ argparser.add_argument('--buffer_size', type=int, default=100)
642
+ argparser.add_argument('--wdce_num_replicates', type=int, default=16)
643
+ argparser.add_argument('--noise_removal', action='store_true', default=False)
644
+ argparser.add_argument('--grad_clip', action='store_true', default=False)
645
+ argparser.add_argument('--resample_every_n_step', type=int, default=10)
646
+ argparser.add_argument('--exploration', type=float, default=0.1)
647
+ argparser.add_argument('--reset_every_n_step', type=int, default=100)
648
+ argparser.add_argument('--alpha', type=float, default=0.01)
649
+ argparser.add_argument('--scalarization', type=str, default='sum')
650
+ argparser.add_argument('--no_mcts', action='store_true', default=False)
651
+ argparser.add_argument("--centering", action='store_true', default=False)
652
+ argparser.add_argument("--centering_strength", type=float, default=1.0)
653
+
654
+ # ELBO-based log_rnd estimation
655
+ argparser.add_argument('--elbo_rnd', action='store_true', default=False,
656
+ help='If set, compute log_rnd via ELBO instead of trajectory rollout')
657
+ argparser.add_argument('--elbo_rnd_num_samples', type=int, default=16,
658
+ help='Number of noisy time samples per sequence for ELBO-based log_rnd estimation')
659
+
660
+ # adaptive schedule parameters
661
+ argparser.add_argument('--use_adaptive_schedule', action='store_true', default=True)
662
+ argparser.add_argument('--schedule_hidden_dim', type=int, default=256)
663
+ argparser.add_argument('--schedule_num_layers', type=int, default=2)
664
+ argparser.add_argument('--schedule_loss_weight', type=float, default=0.1)
665
+ argparser.add_argument('--adaptive_threshold', type=float, default=0.5)
666
+ argparser.add_argument('--freeze_base_model', action='store_true', default=False)
667
+ argparser.add_argument('--schedule_warmup_epochs', type=int, default=0, help='Number of initial epochs to train WITHOUT remasking in buffer generation')
668
+ argparser.add_argument('--alternation_frequency', type=int, default=20, help='Number of epochs to train each model before alternating (1=alternate every epoch)')
669
+ argparser.add_argument('--planner_learning_rate', type=float, default=None, help='Separate learning rate for planner heads (defaults to --learning_rate if not set)')
670
+
671
+ # objectives
672
+ argparser.add_argument('--num_obj', type=int, default=5)
673
+ argparser.add_argument('--prot_seq', type=str, default=None)
674
+ argparser.add_argument('--prot_name', type=str, default='glast',
675
+ help='Protein target name. Looked up in PROTEINS dict unless --prot_seq is given.')
676
+ argparser.add_argument('--devices', type=int, default=-1)
677
+ argparser.add_argument('--checkpoint_path', type=str, default=None)
678
+ argparser.add_argument('--resume_ckpt', type=str, default=None,
679
+ help='Path to a Lightning last.ckpt to resume training from (restores epoch/optimizer/planner state). '
680
+ 'New checkpoints continue in the same directory as this checkpoint.')
681
+
682
+ # remasking
683
+ argparser.add_argument('--num_remasking', type=int, default=5)
684
+ argparser.add_argument('--quality_threshold', type=float, default=1)
685
+
686
+ # length cutoff (peptide-bond filter) + buffer starvation guard
687
+ argparser.add_argument('--min_peptide_bonds', type=int, default=4,
688
+ help='Minimum backbone peptide bonds for a sample to count as valid. '
689
+ '0 disables the cutoff. Filters degenerate short reward-hacked molecules.')
690
+ argparser.add_argument('--max_buffer_attempts', type=int, default=100,
691
+ help='Max sampling rounds per buffer refill before giving up (caps wasted GPU when validity is low).')
692
+ argparser.add_argument('--buffer_starvation_patience', type=int, default=10,
693
+ help='If 0 valid samples after this many rounds, abort the refill early (starvation guard).')
694
+
695
+ # planner ablation flags
696
+ argparser.add_argument('--disable_planner', action='store_true', help='If set, disable remasking completely and only train policy (not planner) for quality optimization')
697
+ argparser.add_argument('--disable_insertion_planner', action='store_true', help='Ablation: disable insertion quality filtering but keep unmasking/remasking planner')
698
+ argparser.add_argument('--disable_unmasking_planner', action='store_true', help='Ablation: disable unmasking/remasking planner but keep insertion quality filtering')
699
+ argparser.add_argument('--joint_training', action='store_true', help='Ablation: train policy and planner jointly each step (no alternation, both unfrozen, summed loss). Incompatible with --disable_planner.')
700
+
701
+ # performance optimization
702
+ argparser.add_argument('--eval_every_n_epochs', type=int, default=5, help='Evaluate only every N epochs to save time')
703
+ argparser.add_argument('--num_training_steps_per_epoch', type=int, default=10, help='Number of gradient updates per epoch')
704
+ argparser.add_argument('--training_mini_batch_size', type=int, default=6, help='Mini-batch size for training from buffer to avoid OOM')
705
+ argparser.add_argument('--pool_size', type=int, default=0,
706
+ help='If >0, maintain a persistent pool of this size and refresh a fraction each resample step (0=disabled, classic buffer). Helps preserve uniqueness/diversity over training.')
707
+ argparser.add_argument('--pool_refresh_fraction', type=float, default=0.2,
708
+ help='Fraction of pool to replace each resample step (only used when pool_size>0)')
709
+
710
+ args = argparser.parse_args()
711
+
712
+ # Default planner LR to policy LR if not specified
713
+ if args.planner_learning_rate is None:
714
+ args.planner_learning_rate = args.learning_rate
715
+
716
+ # Set seed
717
+ pl.seed_everything(args.seed)
718
+
719
+ # Load models
720
+ checkpoint_path = args.checkpoint_path if args.checkpoint_path else \
721
+ os.path.join(REPO_ROOT, 'pretrained', 'anylength_pep.ckpt')
722
+
723
+ # Update args.checkpoint_path to ensure it's saved in hyperparameters for later inference
724
+ args.checkpoint_path = checkpoint_path
725
+
726
+ PROTEINS = {
727
+ 'amhr': 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV',
728
+ 'tfr': 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF',
729
+ 'gfap': 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM',
730
+ 'glp1': 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS',
731
+ 'glast': 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM',
732
+ 'ncam': 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF',
733
+ 'cereblon': 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL',
734
+ 'ligase': 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS',
735
+ 'skp2': 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL',
736
+ }
737
+
738
+ if args.prot_seq is not None:
739
+ prot = args.prot_seq
740
+ prot_name = args.prot_name
741
+ else:
742
+ prot_name = args.prot_name
743
+ if prot_name not in PROTEINS:
744
+ raise ValueError(f"Unknown protein: {prot_name}. Choose from: {list(PROTEINS.keys())}")
745
+ prot = PROTEINS[prot_name]
746
+ filename = prot_name
747
+
748
+ curr_time = datetime.now().strftime("%Y%m%d_%H%M%S")
749
+
750
+ if args.no_mcts:
751
+ args.run_name = f'{curr_time}_adaptive_{prot_name}_resample{args.resample_every_n_step}_no-mcts'
752
+ else:
753
+ args.run_name = f'{curr_time}_adaptive_{prot_name}_resample{args.resample_every_n_step}_buffer{args.buffer_size}_numiter{args.num_iter}_children{args.num_children}'
754
+
755
+ # Append ablation tags to run name for easy identification
756
+ if args.disable_planner:
757
+ args.run_name += '_no_planner'
758
+ if args.disable_insertion_planner:
759
+ args.run_name += '_no_insertion_planner'
760
+ if args.disable_unmasking_planner:
761
+ args.run_name += '_no_unmasking_planner'
762
+ if args.joint_training:
763
+ if args.disable_planner:
764
+ raise ValueError("--joint_training is incompatible with --disable_planner (no planner to train)")
765
+ args.run_name += '_joint_training'
766
+
767
+ # When resuming, continue writing checkpoints into the SAME directory as the
768
+ # checkpoint we resume from (keeps model-{epoch}.ckpt contiguous) instead of
769
+ # spawning a fresh timestamped run directory.
770
+ if args.resume_ckpt:
771
+ args.save_path = os.path.dirname(os.path.abspath(args.resume_ckpt))
772
+ args.run_name = os.path.basename(args.save_path)
773
+ else:
774
+ args.save_path = os.path.join(args.save_path_dir, args.run_name)
775
+ os.makedirs(args.save_path, exist_ok=True)
776
+ set_seed(args.seed, use_cuda=False) # Don't init CUDA before Lightning spawns DDP workers
777
+
778
+ # Initialize the model
779
+ print("Loading models..")
780
+
781
+ # Load pretrained model for reference (frozen)
782
+ pretrained = AnyOrderInsertionFlowModule.load_from_checkpoint(checkpoint_path,
783
+ map_location='cpu',
784
+ weights_only=False)
785
+ pretrained.eval()
786
+ for param in pretrained.parameters():
787
+ param.requires_grad = False
788
+
789
+ # Load checkpoint to extract config
790
+ checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
791
+ if 'hyper_parameters' in checkpoint:
792
+ config = checkpoint['hyper_parameters']['config']
793
+ elif 'config' in checkpoint:
794
+ config = checkpoint['config']
795
+ else:
796
+ raise ValueError("Cannot find config in checkpoint")
797
+
798
+ # Update config for adaptive schedule
799
+ from omegaconf import OmegaConf
800
+ if not OmegaConf.is_config(config):
801
+ from omegaconf import DictConfig
802
+ config = DictConfig(config)
803
+
804
+ # Disable struct mode to allow adding new keys
805
+ OmegaConf.set_struct(config, False)
806
+
807
+ config.training.use_adaptive_schedule = args.use_adaptive_schedule
808
+ config.training.schedule_hidden_dim = args.schedule_hidden_dim
809
+ config.training.schedule_num_layers = args.schedule_num_layers
810
+ config.training.schedule_loss_weight = args.schedule_loss_weight
811
+ config.training.freeze_base_model = args.freeze_base_model
812
+ config.training.schedule_warmup_epochs = args.schedule_warmup_epochs
813
+
814
+ # Re-enable struct mode
815
+ OmegaConf.set_struct(config, True)
816
+
817
+ # Initialize policy model with adaptive schedule
818
+ policy_model = AnyOrderInsertionFlowModuleFT(
819
+ config=config,
820
+ args=args,
821
+ pretrained_checkpoint=checkpoint_path,
822
+ insertion_planner=True,
823
+ )
824
+
825
+ # define mcts
826
+ score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
827
+
828
+ tokenizer = SMILES_SPE_Tokenizer(
829
+ os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer', 'new_vocab.txt'),
830
+ os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer', 'new_splits.txt')
831
+ )
832
+
833
+ # Device will be set by Lightning automatically in DDP
834
+ reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device='cpu')
835
+ model = PeptideFinetuner(
836
+ args=args,
837
+ policy_model=policy_model,
838
+ reward_model=reward_model,
839
+ tokenizer=tokenizer,
840
+ pretrained=pretrained,
841
+ mcts=None,
842
+ filename=filename,
843
+ prot_name=prot_name
844
+ )
845
+
846
+ # Setup checkpoint callback
847
+ checkpoint_callback = ModelCheckpoint(
848
+ dirpath=args.save_path,
849
+ filename='model-{epoch:02d}',
850
+ every_n_epochs=args.save_every_n_epochs,
851
+ save_top_k=-1,
852
+ save_last=True, # Also save last.ckpt
853
+ auto_insert_metric_name=False
854
+ )
855
+
856
+ # Setup wandb logger - only on rank 0 to avoid multiple runs
857
+ # Check if we're in a spawned DDP process
858
+ rank = int(os.environ.get('LOCAL_RANK', 0))
859
+ if rank == 0:
860
+ # Defaults to your default wandb entity; override with the WANDB_ENTITY env var.
861
+ wandb_logger = WandbLogger(entity=os.environ.get('WANDB_ENTITY'), project='a2d2-pep', name=args.run_name)
862
+ else:
863
+ wandb_logger = None
864
+
865
+ # Create dummy dataloader
866
+ dataset = DummyDataset(size=args.num_training_steps_per_epoch)
867
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
868
+
869
+ # Setup trainer with DDP
870
+ trainer = pl.Trainer(
871
+ max_epochs=args.num_epochs,
872
+ accelerator='gpu',
873
+ devices=args.devices,
874
+ strategy=DDPStrategy(find_unused_parameters=True) if args.devices != 1 else 'auto',
875
+ gradient_clip_val=args.gradnorm_clip if args.grad_clip else None,
876
+ logger=wandb_logger,
877
+ callbacks=[checkpoint_callback],
878
+ enable_progress_bar=True,
879
+ log_every_n_steps=1
880
+ )
881
+
882
+ # Train (resume full training state from --resume_ckpt if provided).
883
+ # weights_only=False is required when resuming because these checkpoints
884
+ # store argparse.Namespace / OmegaConf objects in hyper_parameters, which
885
+ # PyTorch 2.6's default weights_only=True unpickler rejects.
886
+ if args.resume_ckpt:
887
+ trainer.fit(model, dataloader, ckpt_path=args.resume_ckpt, weights_only=False)
888
+ else:
889
+ trainer.fit(model, dataloader)
890
+
891
+ if __name__ == '__main__':
892
+ main()
a2d2_pep/inference_quality.py ADDED
@@ -0,0 +1,605 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unified peptide sampling with quality-guided planning.
2
+
3
+ Supports 4 quality modes and optional RND (importance weight) computation.
4
+
5
+ Quality modes:
6
+ "none" - No planner, no remasking (policy-only)
7
+ "both" - Both unmasking + insertion planners active
8
+ "unmasking_only" - Only unmasking/remasking planner (insertion planner disabled)
9
+ "insertion_only" - Only insertion planner (unmasking planner disabled)
10
+
11
+ RND toggle:
12
+ compute_rnd=True - Run pretrained model in parallel, compute step-wise log importance weights
13
+ compute_rnd=False - Run policy model only (use with ELBO-based RND or eval)
14
+ """
15
+
16
+ import os
17
+ import torch
18
+ import numpy as np
19
+ import pandas as pd
20
+ import torch.nn.functional as F
21
+ from sampling import SamplingResult, SamplingTraceDatapoint, _sample_tokens
22
+ from remasking_scheduleaware import apply_schedule_aware_remasking, apply_schedule_aware_insertion
23
+
24
+ QUALITY_MODES = {"none", "both", "unmasking_only", "insertion_only"}
25
+
26
+ # When set (e.g. A2D2_QUALITY_DEBUG=1), the diffusion loop prints, per step, how
27
+ # many already-unmasked tokens get remasked and how many proposed insertions get
28
+ # filtered by the quality planner, plus a per-batch total. Off by default so it
29
+ # never spams training/eval runs.
30
+ _QUALITY_DEBUG = os.environ.get("A2D2_QUALITY_DEBUG", "") not in ("", "0", "false", "False")
31
+
32
+
33
def _euler_transition_probs(rate, xt, mask_pos, mask, pad, dt):
    """Convert an unmasking rate into one-step (Euler) transition weights.

    ``rate`` is modified in place: rows of non-mask positions are zeroed and,
    for mask positions, the entry of the mask token is set to minus the sum of
    the other entries. The result is ``(rate * dt)`` clamped to ``[0, 1]`` with
    1 added at the index of the current token (pad positions use the mask
    index), i.e. the weight of "staying".
    """
    rate[xt != mask] = 0
    rate[mask_pos + (mask,)] = 0
    rate[mask_pos + (mask,)] = -rate[mask_pos + (slice(None),)].sum(dim=1)
    probs = (rate * dt).clamp(0.0, 1.0)

    stay_index = xt.clone()
    stay_index[xt == pad] = mask
    stay_index = stay_index.unsqueeze(-1)
    probs.scatter_add_(2, stay_index, torch.ones_like(stay_index, dtype=probs.dtype))
    return probs


@torch.no_grad()
def _diffusion_loop(
    model, steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    compute_rnd=False,
    pretrained=None,
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    unmask_quality_threshold=None,
    unmask_all=False,
    freq_penalty=0.0,
    return_trace=False,
):
    """Core discrete diffusion sampling loop for peptide generation.

    Starting from an all-pad batch, each step (1) unmasks tokens with an Euler
    step on the predicted unmask rate, (2) optionally remasks tokens,
    (3) optionally accumulates the log importance weight, and (4) inserts new
    mask tokens drawn from a Poisson on the predicted length rate (skipped on
    the last step).

    Args:
        model: Finetuned policy model.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: One of "none", "both", "unmasking_only", "insertion_only".
        compute_rnd: Whether to compute step-wise log importance weights.
        pretrained: Frozen pretrained model (required if compute_rnd=True).
        remasking_mode: Remasking strategy ("schedule_aware", "remdm", "remdm_conf").
        num_remasking: Number of tokens to remask per step ("remdm"/"remdm_conf").
        quality_threshold: Threshold for insertion quality filtering. None if schedule-driven.
        unmask_quality_threshold: Forwarded unchanged to
            ``apply_schedule_aware_remasking``.
        unmask_all: If True, the mask token is removed from the sampling
            distribution on every step (not only the last one).
        freq_penalty: If > 0, sampling weights of each token are multiplied by
            ``exp(-freq_penalty * count)`` where ``count`` is how often the
            token already occurs in that sequence.
        return_trace: Whether to record sampling trace.

    Returns:
        (xt, log_rnd, sampling_trace)
        log_rnd is None when compute_rnd=False; otherwise it is the sum over
        all steps of the unmasking and insertion log-ratio terms, shape (B,).
        sampling_trace is None when return_trace=False.

    Raises:
        AssertionError: On an unknown ``quality_mode`` or a missing
            ``pretrained`` model when ``compute_rnd=True``.
        ValueError: On an unknown ``remasking_mode``.

    Notes:
        Compared with the previous revision this version
        * accumulates ``log_rnd`` across steps (it used to be overwritten on
          every step, so only the last step contributed),
        * no longer aliases/mutates ``ext`` when correcting insertion counts,
        * recomputes the per-row insertion total after the max-length gate, so
          a rejected row is left unchanged instead of being filled with masks,
        * renormalises every masked row when the mask token is removed, even
          if some other row in the batch has zero mass.
    """
    assert quality_mode in QUALITY_MODES, f"quality_mode must be one of {QUALITY_MODES}"
    if compute_rnd:
        assert pretrained is not None, "pretrained model required when compute_rnd=True"

    # Derive flags from quality_mode
    use_remasking = quality_mode != "none"
    disable_unmasking_planner = quality_mode in ("none", "insertion_only")
    disable_insertion_planner = quality_mode in ("none", "unmasking_only")

    device = next(model.parameters()).device

    # Initialize all-pad sequence
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Precompute index tensors
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    neg_inf = torch.tensor(-np.inf, device=device)

    if use_remasking and remasking_mode == "remdm_conf":
        remasking_score = torch.zeros((batch_size, max_length), device=device)

    # Running sum of per-step log importance weights (trajectory-level weight).
    log_rnd = torch.zeros(batch_size, device=device) if compute_rnd else None

    dbg_total_remasked = 0
    dbg_total_proposed_ins = 0
    dbg_total_filtered = 0

    for i in range(steps):
        is_last_step = i == steps - 1
        step_remasked = 0
        step_proposed_ins = 0
        step_filtered = 0

        # --- Policy model forward ---
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # --- Pretrained model forward (for RND) ---
        if compute_rnd:
            pretrained_pred = pretrained(xt, t)
            pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
            pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
            pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

        # --- Unmask step (Euler), including the "stay" weight ---
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        trans_prob = _euler_transition_probs(unmask_rate, xt, mask_pos, mask, pad, dt)
        if compute_rnd:
            pretrained_trans_prob = _euler_transition_probs(
                pretrained_unmask_rate, xt, mask_pos, mask, pad, dt
            )

        # Remove mask token from sampling so every masked position is decoded.
        # The final step always does this; unmask_all does it every step, so the
        # schedule-aware remasking below re-masks the lowest-quality tokens back
        # down to the schedule's expected mask count.
        if is_last_step or unmask_all:
            if is_last_step:
                print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0

            masked_rows = trans_prob[mask_pos]  # (N_masked, V)
            prob_sum = masked_rows.sum(dim=-1, keepdim=True)
            # Fallback for rows with no mass left: uniform over token ids < mask.
            uniform_row = torch.zeros(trans_prob.shape[-1], device=device, dtype=trans_prob.dtype)
            uniform_row[:mask] = 1.0 / mask
            # Normalise every row; the 0/0 rows produce NaN in the unselected
            # branch of torch.where and are replaced by the uniform fallback.
            trans_prob[mask_pos] = torch.where(prob_sum == 0.0, uniform_row, masked_rows / prob_sum)

        # --- Frequency penalty: down-weight residues already abundant in the
        # sequence so (re)decoded masked positions don't collapse onto the modal
        # token. Only masked positions are sampled; clean positions are
        # overwritten below, so penalizing the whole tensor is harmless. mask/pad
        # never accumulate counts, so their entries stay untouched. Applied to a
        # copy so trans_prob (used for RND log-probs) is unchanged.
        sample_prob = trans_prob
        if freq_penalty > 0.0:
            vocab_size = trans_prob.shape[-1]
            clean_tok = (xt != mask) & (xt != pad)  # (B, L)
            counts = torch.zeros(batch_size, vocab_size, device=device, dtype=trans_prob.dtype)
            counts.scatter_add_(
                1,
                torch.where(clean_tok, xt, torch.zeros_like(xt)),
                clean_tok.to(trans_prob.dtype),
            )
            sample_prob = trans_prob * torch.exp(-freq_penalty * counts).unsqueeze(1)

        new_xt = _sample_tokens(sample_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # Update remasking_score buffer for remdm_conf mode.
        # NOTE(review): this buffer is written here but never read below (the
        # remdm_conf branch scores with the planner's remasking_conf instead) —
        # confirm whether it is still needed.
        if use_remasking and remasking_mode == "remdm_conf" and not is_last_step:
            token_probs = F.softmax(unmask_rate, dim=-1)  # (B, L, V)
            chosen_probs = torch.gather(token_probs, dim=-1, index=new_xt.unsqueeze(-1)).squeeze(-1)  # (B, L)
            changed_mask_to_token = (xt == mask) & (new_xt != mask) & (new_xt != pad)
            remasking_score = torch.where(changed_mask_to_token, chosen_probs, remasking_score)

        # --- Remasking step ---
        if use_remasking and not is_last_step:
            has_planner = hasattr(model, 'planner') and model.planner is not None
            if disable_unmasking_planner or not has_planner:
                remasking_conf = torch.zeros((batch_size, max_length), device=device)
            else:
                planner_out = model.planner(new_xt, t)
                remasking_conf = planner_out["remasking_conf"].squeeze(-1)  # (B, L)

            clean_index = (new_xt != mask) & (new_xt != pad)  # (B, L)

            if remasking_mode == "remdm":
                remasking_score_temp = torch.rand(remasking_conf.shape, device=device)
            elif remasking_mode == "remdm_conf":
                remasking_score_temp = -1.0 * remasking_conf
            elif remasking_mode == "schedule_aware":
                # Only remask when the unmasking planner is active. Otherwise
                # (e.g. insertion_only / no_unmasking_planner) remasking_conf is
                # all zeros, so this would remask schedule-excess tokens by
                # position rather than by quality.
                if not disable_unmasking_planner:
                    new_xt = apply_schedule_aware_remasking(
                        model, new_xt, t, dt, remasking_conf, clean_index,
                        mask, neg_inf, batch_size,
                        unmask_quality_threshold=unmask_quality_threshold,
                    )
                remasking_score_temp = None
            else:
                raise ValueError(f"Unknown remasking_mode: {remasking_mode}")

            if remasking_score_temp is not None:
                # Top-k remasking among clean tokens only.
                remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf)
                for j in range(batch_size):
                    k = min(num_remasking, int(clean_index[j].sum().item()))
                    if k > 0:
                        _, select_indices = torch.topk(remasking_score_temp[j], k=k)
                        new_xt[j, select_indices] = mask

            if _QUALITY_DEBUG:
                # Positions that were clean before this remasking block and are
                # now mask are exactly the unmasked tokens that got remasked.
                step_remasked = int((clean_index & (new_xt == mask)).sum().item())

            if return_trace:
                for batch_idx in range(batch_size):
                    for pos in range(max_length):
                        if clean_index[batch_idx, pos] and new_xt[batch_idx, pos] == mask:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=t[batch_idx].item(),
                                    event_type="change",
                                    position=pos,
                                    token=mask,
                                )
                            )

        # --- Compute log probabilities for RND ---
        if compute_rnd:
            chosen = new_xt.unsqueeze(-1)
            lp = torch.gather(torch.log(trans_prob + 1e-10), 2, chosen).squeeze(-1)
            lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, chosen).squeeze(-1)

            # Only positions that went mask -> real token contribute.
            changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

            log_policy_step = (lp * changed_mask).sum(dim=1)
            log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

            # Accumulate (not overwrite) so the weight covers the whole trajectory.
            log_rnd = log_rnd + (log_pretrained_step - log_policy_step)  # (B,)

        # --- Insertion step ---
        if not is_last_step:
            ext = torch.poisson(len_rate * dt).long()  # (B, L+1)

            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # Insertions are only allowed in gaps of the current sequence.
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # Reject all insertions of a row that would overflow max_length.
            valid = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * valid.view(batch_size, 1).long()
            # Taken AFTER the gate, so rejected rows really insert nothing.
            total_ext = ext.sum(dim=1)

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # Shift each existing token right by the insertions made before it.
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]

            if _QUALITY_DEBUG:
                # ext has been masked by the max-length validity check above, so
                # this is the number of fresh mask tokens actually inserted.
                step_proposed_ins = int(ext.sum().item())

            ext_corrected = ext
            # Schedule-aware insertion quality filtering
            if use_remasking and not disable_insertion_planner:
                dbg_nonpad_before = int((xt_tmp != pad).sum().item()) if _QUALITY_DEBUG else 0

                xt_tmp = apply_schedule_aware_insertion(
                    model, xt_tmp, new_xt, t, dt, ext, mask, pad, max_length,
                    orig_mask, new_pos_orig, quality_threshold
                )

                if _QUALITY_DEBUG:
                    # Filtering only drops/compacts tokens, so the drop in
                    # non-pad count is the number of insertions filtered out.
                    step_filtered = dbg_nonpad_before - int((xt_tmp != pad).sum().item())

                if compute_rnd:
                    # Scale each row's insertion counts by the fraction of its
                    # insertions that survived filtering. Rows without any
                    # insertion keep their (all-zero) counts.
                    ext_corrected = torch.zeros_like(ext)
                    for b in range(batch_size):
                        proposed = total_ext[b].item()
                        if proposed > 0:
                            after_len = xt_tmp[b].ne(pad).sum().item()
                            surviving_insertions = after_len - xt_len[b].item()
                            ratio = surviving_insertions / proposed
                            ext_corrected[b] = (ext[b].float() * ratio).long()

            # Compute insertion log_rnd
            if compute_rnd:
                insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
                pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

                log_policy_insert = (ext_corrected * torch.log(insertion_rate) - insertion_rate).sum(dim=1)
                log_pretrained_insert = (
                    ext_corrected * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate
                ).sum(dim=1)

                log_rnd = log_rnd + (log_pretrained_insert - log_policy_insert)
        else:
            xt_tmp = new_xt

        if return_trace:
            for batch_idx in range(batch_size):
                for j in range(max_length):
                    if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=t[batch_idx].item(),
                                event_type="change",
                                position=j,
                                token=new_xt[batch_idx, j].item(),
                            )
                        )

                if not is_last_step:
                    # Walk the gaps right-to-left when recording insertions.
                    for gap_idx in range(max_length - 1, -1, -1):
                        if ext[batch_idx, gap_idx]:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=t[batch_idx].item(),
                                    event_type="insertion",
                                    position=gap_idx,
                                    token=mask,
                                )
                            )

        if _QUALITY_DEBUG:
            dbg_total_remasked += step_remasked
            dbg_total_proposed_ins += step_proposed_ins
            dbg_total_filtered += step_filtered
            print(
                f"[QUALITY {quality_mode}] step {i+1}/{steps}: "
                f"remasked {step_remasked} unmasked tokens -> mask | "
                f"insertions proposed {step_proposed_ins}, "
                f"filtered {step_filtered}, kept {step_proposed_ins - step_filtered}"
            )

        xt = xt_tmp
        t = t + dt

    if _QUALITY_DEBUG:
        print(
            f"[QUALITY {quality_mode}] TOTAL over {steps} steps (batch_size={batch_size}): "
            f"remasked {dbg_total_remasked} unmasked tokens | "
            f"insertions proposed {dbg_total_proposed_ins}, "
            f"filtered {dbg_total_filtered}, kept {dbg_total_proposed_ins - dbg_total_filtered}"
        )

    return xt, log_rnd, sampling_trace
389
+
390
+
391
@torch.no_grad()
def sample_peptides_buffer(
    model, reward_model, analyzer, tokenizer,
    steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    compute_rnd=False,
    pretrained=None,
    alpha=0.1,
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    min_length=0,
):
    """Sample a batch of peptides and score them for the training buffer.

    Runs ``_diffusion_loop``, decodes the result, keeps only sequences that
    ``analyzer.is_peptide`` accepts (and that have at least ``min_length``
    non-pad tokens when ``min_length > 0``), then sums the reward model's
    per-objective scores into one scalar reward per kept sequence.

    Args:
        model: Finetuned policy model.
        reward_model: Multi-objective scoring function, called as
            ``reward_model(input_seqs=[...])``.
        analyzer: PeptideAnalyzer for validation.
        tokenizer: Tokenizer for decoding.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: "none", "both", "unmasking_only", or "insertion_only".
        compute_rnd: If True, compute step-wise log importance weights (requires pretrained).
            If False, returns placeholder zero log_rnd (for ELBO-based RND).
        pretrained: Frozen pretrained model (required when compute_rnd=True).
        alpha: RND scaling factor; the reward is added as ``reward / alpha``.
        remasking_mode: Remasking strategy.
        num_remasking: Number of tokens to remask per step.
        quality_threshold: Threshold for insertion quality filtering.
        min_length: Minimum number of non-pad tokens for a sample to be kept;
            0 disables the length filter.

    Returns:
        (valid_x, log_rnd, scalar_rewards, sampling_trace)
        All three tensors are empty (leading dimension 0) when no sample
        passes the filters. The trace is whatever ``_diffusion_loop`` returns
        with its default ``return_trace`` setting.
    """
    xt, log_rnd, trace = _diffusion_loop(
        model, steps, mask, pad, batch_size, max_length,
        quality_mode=quality_mode,
        compute_rnd=compute_rnd,
        pretrained=pretrained,
        remasking_mode=remasking_mode,
        num_remasking=num_remasking,
        quality_threshold=quality_threshold,
    )

    device = xt.device
    enforce_min_length = min_length > 0

    # Remember which batch rows survive validation, and their decoded text.
    kept_rows = []
    kept_sequences = []
    for row, sequence in enumerate(tokenizer.batch_decode(xt)):
        if not analyzer.is_peptide(sequence):
            continue
        n_tokens = int((xt[row] != pad).sum().item())
        if enforce_min_length and n_tokens < min_length:
            continue
        kept_rows.append(row)
        kept_sequences.append(sequence)

    n_kept = len(kept_sequences)
    print("len valid sequences:", n_kept)

    if n_kept == 0:
        print("[WARNING] No valid peptides generated in this batch")
        return (
            torch.empty((0, max_length), dtype=torch.long, device=device),
            torch.empty((0,), dtype=torch.float32, device=device),
            torch.empty((0,), dtype=torch.float32, device=device),
            trace,
        )

    # One scalar reward per sequence: the sum over all objectives.
    score_vectors = reward_model(input_seqs=kept_sequences)
    scalar_rewards = torch.as_tensor(
        np.sum(score_vectors, axis=-1), dtype=torch.float32, device=device
    )

    print(f"scalar reward dim{len(scalar_rewards)}")
    valid_x_final = torch.stack([xt[row] for row in kept_rows], dim=0)

    if not compute_rnd:
        placeholder = torch.zeros(n_kept, dtype=torch.float32, device=device)
        return valid_x_final, placeholder, scalar_rewards, trace

    kept_log_rnd = torch.stack([log_rnd[row] for row in kept_rows], dim=0)
    log_rnd_out = kept_log_rnd + (scalar_rewards / alpha)
    return valid_x_final, log_rnd_out, scalar_rewards, trace
479
+
480
+
481
@torch.no_grad()
def sample_peptides_eval(
    model, reward_model, analyzer, tokenizer,
    steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    unmask_quality_threshold=None,
    unmask_all=False,
    freq_penalty=0.0,
    dataframe=False,
    return_valid=False,
):
    """Generate peptides for evaluation.

    Runs ``_diffusion_loop`` without RND, keeps the decoded samples accepted by
    ``analyzer.is_peptide`` and scores them with ``reward_model``.

    Args:
        model: Finetuned policy model.
        reward_model: Multi-objective scoring function. Its
            ``score_func_names`` attribute (if present) gives the number of
            objectives; 5 is assumed otherwise.
        analyzer: PeptideAnalyzer for validation.
        tokenizer: Tokenizer for decoding.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: "none", "both", "unmasking_only", or "insertion_only".
        remasking_mode: Remasking strategy.
        num_remasking: Number of tokens to remask per step.
        quality_threshold: Threshold for insertion quality filtering.
        unmask_quality_threshold: Forwarded to ``_diffusion_loop``.
        unmask_all: Forwarded to ``_diffusion_loop``.
        freq_penalty: Forwarded to ``_diffusion_loop``.
        dataframe: If True, include a pandas DataFrame in the return.
        return_valid: If True, return decoded valid sequences instead of raw token tensors.

    Returns:
        For multi-objective (5 objectives):
            (samples, affinity, sol, hemo, nf, permeability, valid_fraction[, df])
        For single objective:
            (samples, sol, valid_fraction[, df])
        When return_valid=True, samples is replaced with validSequences list.
        When no sample is valid, every score entry is ``[0.0]`` and the
        DataFrame (if requested) has the usual columns but no rows.
    """
    xt, _, _ = _diffusion_loop(
        model, steps, mask, pad, batch_size, max_length,
        quality_mode=quality_mode,
        compute_rnd=False,
        remasking_mode=remasking_mode,
        num_remasking=num_remasking,
        quality_threshold=quality_threshold,
        unmask_quality_threshold=unmask_quality_threshold,
        unmask_all=unmask_all,
        freq_penalty=freq_penalty,
    )

    samples = xt
    valid_sequences = [
        seq for seq in tokenizer.batch_decode(samples) if analyzer.is_peptide(seq)
    ]

    print("len valid sequences:", len(valid_sequences))

    valid_fraction = len(valid_sequences) / batch_size

    # Determine number of objectives from reward model
    num_objectives = len(reward_model.score_func_names) if hasattr(reward_model, 'score_func_names') else 5

    # Column order matches the order of the reward model's score vector.
    if num_objectives == 1:
        column_names = ("Solubility",)
    else:
        column_names = (
            "Binding Affinity", "Solubility", "Hemolysis", "Nonfouling", "Permeability",
        )

    if valid_sequences:
        score_vectors = reward_model(input_seqs=valid_sequences)  # (N, num_objectives)
        per_objective = score_vectors.T
        scores = [per_objective[k] for k in range(len(column_names))]
    else:
        # Placeholder kept for backward compatibility with callers that
        # aggregate the returned scores.
        scores = [[0.0] for _ in column_names]

    head = valid_sequences if return_valid else samples
    result = (head, *scores, valid_fraction)

    if dataframe:
        columns = {"Peptide Sequence": valid_sequences}
        for name, values in zip(column_names, scores):
            # With no valid peptide the score columns must be empty as well;
            # mixing an empty sequence column with the [0.0] placeholders makes
            # pandas raise "All arrays must be of the same length".
            columns[name] = values if valid_sequences else []
        result += (pd.DataFrame(columns),)

    return result
a2d2_pep/pep_scoring/functions/binding.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os, torch
3
+ import numpy as np
4
+ import torch
5
+ import pandas as pd
6
+ import torch.nn as nn
7
+ import esm
8
+ from transformers import AutoModelForMaskedLM
9
+
10
class ImprovedBindingPredictor(nn.Module):
    """Cross-attention network that scores a protein / peptide embedding pair.

    Both inputs are projected to ``hidden_dim``, refined by ``n_layers``
    bidirectional cross-attention blocks, mean-pooled over dim 0,
    concatenated, and fed to a regression head (one value) and a 3-way
    classification head (tight / medium / weak binding).
    """

    def __init__(self,
                 esm_dim=1280,
                 smiles_dim=768,
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1):
        super().__init__()

        # Affinity cut-offs used by get_binding_class.
        self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
        self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1μM

        # Bring both modalities into a shared hidden size.
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        # Each block owns one attention module, one feed-forward network and
        # two layer norms; forward() uses the same block for both directions.
        blocks = []
        for _ in range(n_layers):
            blocks.append(nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim),
                ),
                'norm2': nn.LayerNorm(hidden_dim),
            }))
        self.cross_attention_layers = nn.ModuleList(blocks)

        # Trunk shared by the two prediction heads.
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.regression_head = nn.Linear(hidden_dim, 1)
        # 3 classes: tight, medium, weak binding.
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Map affinity values to class indices.

        0: tight binding (>= ``tight_threshold``)
        1: medium binding (between the two thresholds)
        2: weak binding (< ``weak_threshold``)

        Accepts a tensor (returns a ``long`` tensor of the same shape) or a
        single number (returns an ``int``).
        """
        if not isinstance(affinity, torch.Tensor):
            if affinity >= self.tight_threshold:
                return 0
            if affinity < self.weak_threshold:
                return 2
            return 1

        # Start from "medium" and overwrite the two extremes.
        labels = torch.ones_like(affinity, dtype=torch.long)
        labels[affinity >= self.tight_threshold] = 0
        labels[affinity < self.weak_threshold] = 2
        return labels

    @staticmethod
    def _attend(block, query, context):
        """Run one attention + feed-forward update of ``query`` against ``context``."""
        attended = block['attention'](query, context, context)[0]
        query = block['norm1'](query + attended)
        return block['norm2'](query + block['ffn'](query))

    def forward(self, protein_emb, smiles_emb):
        """Return ``(regression_output, classification_logits)`` for one pair."""
        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(smiles_emb))

        for block in self.cross_attention_layers:
            # Protein attends to SMILES first; SMILES then attends to the
            # already-updated protein representation.
            protein = self._attend(block, protein, smiles)
            smiles = self._attend(block, smiles, protein)

        # Pool each stream over dim 0 and join them.
        pooled = torch.cat([protein.mean(dim=0), smiles.mean(dim=0)], dim=-1)
        shared_features = self.shared_head(pooled)

        return (self.regression_head(shared_features),
                self.classification_head(shared_features))
118
+
119
class BindingAffinity:
    """Score peptide SMILES strings against one fixed target protein.

    The target protein is embedded once with ESM-2 at construction time;
    each peptide is embedded with ``pep_model`` and the pair is scored by a
    pretrained ``ImprovedBindingPredictor``.
    """

    def __init__(self, prot_seq, tokenizer, base_path, device=None, emb_model=None):
        super().__init__()
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # Peptide embedding model: reuse the caller's model when supplied.
        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.pep_model = emb_model.to(self.device).eval()
        self.pep_tokenizer = tokenizer

        # Pretrained affinity predictor.
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from a trusted source.
        self.model = ImprovedBindingPredictor().to(self.device)
        checkpoint = torch.load(f'{base_path}/functions/classifiers/binding-affinity.pt',
                                map_location=self.device,
                                weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # ESM-2 model and its batch converter for the target protein.
        self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm_model = self.esm_model.to(self.device).eval()
        self.prot_tokenizer = alphabet.get_batch_converter()

        _, _, prot_tokens = self.prot_tokenizer([("target", prot_seq)])
        prot_tokens = prot_tokens.to(self.device)
        with torch.no_grad():
            results = self.esm_model.forward(prot_tokens, repr_layers=[33])
        layer_33 = results["representations"][33]

        # Average over token positions, keeping a leading length-1 axis.
        self.prot_emb = layer_33[0].to(self.device).mean(dim=0, keepdim=True)

    def forward(self, input_seqs):
        """Return one regression score (float) per sequence in ``input_seqs``."""
        scores = []
        with torch.no_grad():
            for seq in input_seqs:
                encoded = self.pep_tokenizer(seq, return_tensors='pt', padding=True)
                encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

                hidden = self.pep_model(input_ids=encoded['input_ids'],
                                        attention_mask=encoded['attention_mask'],
                                        output_hidden_states=True).last_hidden_state
                # Mean-pool the token embeddings of this single peptide.
                pep_emb = hidden.squeeze(0).mean(dim=0, keepdim=True)

                score, _ = self.model.forward(self.prot_emb, pep_emb)
                scores.append(score.item())
        return scores

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)
a2d2_pep/pep_scoring/functions/binding_utils.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ import numpy as np
4
+
5
def to_var(x):
    """Return ``x`` moved to CUDA when CUDA is available, otherwise ``x`` unchanged."""
    return x.cuda() if torch.cuda.is_available() else x
9
+
10
class MultiHeadAttentionSequence(nn.Module):
    """Multi-head scaled dot-product attention with a residual layer norm.

    ``forward(q, k, v)`` takes three rank-3 inputs whose last axis has
    ``d_model`` features (the code unpacks ``batch, length, _``) and returns
    ``(output, attention)`` where ``output`` has the shape of ``q`` and
    ``attention`` is ``(batch, n_head, len_q, len_k)``.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, n_head * d_k)
        self.W_K = nn.Linear(d_model, n_head * d_k)
        self.W_V = nn.Linear(d_model, n_head * d_v)
        self.W_O = nn.Linear(n_head * d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        batch, len_q, _ = q.size()
        len_k = k.size(1)
        len_v = v.size(1)

        # Split the projections into heads: (batch, n_head, length, dim).
        Q = self.W_Q(q).view(batch, len_q, self.n_head, self.d_k).transpose(1, 2)
        # Keys are additionally transposed so Q @ K yields (len_q, len_k) scores.
        K = self.W_K(k).view(batch, len_k, self.n_head, self.d_k).transpose(1, 2).transpose(2, 3)
        V = self.W_V(v).view(batch, len_v, self.n_head, self.d_v).transpose(1, 2)

        attention = torch.matmul(Q, K) / np.sqrt(self.d_k)
        # torch.softmax is used because this module never imports
        # torch.nn.functional as F (the original F.softmax raised NameError).
        attention = torch.softmax(attention, dim=-1)

        output = torch.matmul(attention, V)
        # Merge the heads back into one feature axis.
        output = output.transpose(1, 2).reshape(batch, len_q, self.d_v * self.n_head)
        output = self.dropout(self.W_O(output))

        # Residual connection onto the queries, then layer norm.
        output = self.layer_norm(output + q)

        return output, attention
61
+
62
class MultiHeadAttentionReciprocal(nn.Module):
    """Two-way multi-head attention sharing one score matrix.

    The query/key score matrix is computed once. Softmax over its last axis
    lets the ``q`` stream attend to ``v``; softmax over its transpose lets
    the ``k`` stream attend to ``v_2``. ``forward`` returns
    ``(output, output_2, attention, attention_2)`` where ``output`` has the
    shape of ``q`` and ``output_2`` the shape of ``k``.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, n_head * d_k)
        self.W_K = nn.Linear(d_model, n_head * d_k)
        self.W_V = nn.Linear(d_model, n_head * d_v)
        self.W_O = nn.Linear(n_head * d_v, d_model)
        self.W_V_2 = nn.Linear(d_model, n_head * d_v)
        self.W_O_2 = nn.Linear(n_head * d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # NOTE(review): layer_norm_2 / dropout_2 are registered but forward()
        # reuses layer_norm / dropout for the second stream. Left unchanged so
        # existing checkpoints keep producing the same outputs; confirm intent.
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, q, k, v, v_2):
        batch, len_q, _ = q.size()
        len_k = k.size(1)
        len_v = v.size(1)
        len_v_2 = v_2.size(1)

        # Split the projections into heads: (batch, n_head, length, dim).
        Q = self.W_Q(q).view(batch, len_q, self.n_head, self.d_k).transpose(1, 2)
        K = self.W_K(k).view(batch, len_k, self.n_head, self.d_k).transpose(1, 2).transpose(2, 3)
        V = self.W_V(v).view(batch, len_v, self.n_head, self.d_v).transpose(1, 2)
        V_2 = self.W_V_2(v_2).view(batch, len_v_2, self.n_head, self.d_v).transpose(1, 2)

        scores = torch.matmul(Q, K) / np.sqrt(self.d_k)

        # torch.softmax is used because this module never imports
        # torch.nn.functional as F (the original F.softmax raised NameError).
        attention = torch.softmax(scores, dim=-1)
        attention_2 = torch.softmax(scores.transpose(-2, -1), dim=-1)

        output = torch.matmul(attention, V)
        output_2 = torch.matmul(attention_2, V_2)

        # Merge the heads back into one feature axis.
        output = output.transpose(1, 2).reshape(batch, len_q, self.d_v * self.n_head)
        output_2 = output_2.transpose(1, 2).reshape(batch, len_k, self.d_v * self.n_head)

        output = self.W_O(output)
        output_2 = self.W_O_2(output_2)

        # Residual + layer norm for each stream (q-stream first, as before).
        output = self.layer_norm(self.dropout(output) + q)
        output_2 = self.layer_norm(self.dropout(output_2) + k)

        return output, output_2, attention, attention_2
142
+
143
+
144
class FFN(nn.Module):
    """Position-wise feed-forward block built from two kernel-size-1 convolutions.

    The input is transposed so the feature axis becomes the channel axis for
    ``Conv1d``; the result is added back to the input and layer-normalised.
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()

        self.layer_1 = nn.Conv1d(d_in, d_hid, 1)
        self.layer_2 = nn.Conv1d(d_hid, d_in, 1)
        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Conv1d expects channels on axis 1, so swap axes 1 and 2 going in
        # and swap them back coming out.
        hidden = self.relu(self.layer_1(x.transpose(1, 2)))
        projected = self.dropout(self.layer_2(hidden)).transpose(1, 2)
        return self.layer_norm(projected + x)
170
+
171
class ConvLayer(nn.Module):
    """A single ``Conv1d`` followed by ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              padding=padding, dilation=dilation)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))
181
+
182
+
183
class DilatedCNN(nn.Module):
    """Three parallel stacks of dilated convolutions whose outputs are summed.

    The stacks use kernel sizes 3, 5 and 7. Every stack has three
    ``ConvLayer`` stages with dilation 1, 2 and 3, and padding chosen as
    ``(kernel_size - 1) // 2 * dilation``.
    """

    def __init__(self, d_model, d_hidden):
        super().__init__()

        dilations = (1, 2, 3)
        in_dims = (d_model, d_hidden, d_hidden)
        out_dims = (d_hidden, d_hidden, d_hidden)

        def build_branch(kernel_size):
            half_width = (kernel_size - 1) // 2
            return nn.ModuleList([
                ConvLayer(c_in, c_out, kernel_size=kernel_size,
                          padding=half_width * rate, dilation=rate)
                for c_in, c_out, rate in zip(in_dims, out_dims, dilations)
            ])

        self.first_ = build_branch(3)
        self.second_ = build_branch(5)
        self.third_ = build_branch(7)

    def forward(self, protein_seq_enc):
        # B*L*d_model -> B*d_model*L so the feature axis is the conv channel axis.
        channels_first = protein_seq_enc.transpose(1, 2)

        branch_outputs = []
        for branch in (self.first_, self.second_, self.third_):
            features = channels_first
            for conv in branch:
                features = conv(features)
            branch_outputs.append(features)

        combined = branch_outputs[0] + branch_outputs[1] + branch_outputs[2]
        return combined.transpose(1, 2)
228
+
229
+
230
class ReciprocalLayerwithCNN(nn.Module):
    """Reciprocal attention layer whose protein stream first passes through a DilatedCNN.

    The CNN maps the protein encoding from ``d_model`` to ``d_hidden``
    features. All attention and feed-forward sub-layers are built with
    ``d_hidden`` as their model size, for both streams.
    """

    def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v):
        super().__init__()

        self.cnn = DilatedCNN(d_model, d_hidden)

        attention_args = (n_head, d_hidden, d_k, d_v)
        self.sequence_attention_layer = MultiHeadAttentionSequence(*attention_args)
        self.protein_attention_layer = MultiHeadAttentionSequence(*attention_args)
        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(*attention_args)

        self.ffn_seq = FFN(d_hidden, d_inner)
        self.ffn_protein = FFN(d_hidden, d_inner)

    def forward(self, sequence_enc, protein_seq_enc):
        # protein_seq_enc.shape = B * L * d_model
        conv_protein = self.cnn(protein_seq_enc)

        # Self-attention on each stream.
        prot_enc, prot_attention = self.protein_attention_layer(
            conv_protein, conv_protein, conv_protein)
        seq_enc, sequence_attention = self.sequence_attention_layer(
            sequence_enc, sequence_enc, sequence_enc)

        # Cross-attention in both directions at once.
        prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = \
            self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)

        prot_enc = self.ffn_protein(prot_enc)
        seq_enc = self.ffn_seq(seq_enc)

        return (prot_enc, seq_enc, prot_attention, sequence_attention,
                prot_seq_attention, seq_prot_attention)
261
+
262
+
263
class ReciprocalLayer(nn.Module):
    """Self-attention on each stream, reciprocal cross-attention, then per-stream FFN."""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v):
        super().__init__()

        attention_args = (n_head, d_model, d_k, d_v)
        self.sequence_attention_layer = MultiHeadAttentionSequence(*attention_args)
        self.protein_attention_layer = MultiHeadAttentionSequence(*attention_args)
        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(*attention_args)

        self.ffn_seq = FFN(d_model, d_inner)
        self.ffn_protein = FFN(d_model, d_inner)

    def forward(self, sequence_enc, protein_seq_enc):
        # Self-attention on each stream.
        prot_enc, prot_attention = self.protein_attention_layer(
            protein_seq_enc, protein_seq_enc, protein_seq_enc)
        seq_enc, sequence_attention = self.sequence_attention_layer(
            sequence_enc, sequence_enc, sequence_enc)

        # Cross-attention in both directions at once.
        prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = \
            self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)

        prot_enc = self.ffn_protein(prot_enc)
        seq_enc = self.ffn_seq(seq_enc)

        return (prot_enc, seq_enc, prot_attention, sequence_attention,
                prot_seq_attention, seq_prot_attention)
a2d2_pep/pep_scoring/functions/hemolysis.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import xgboost as xgb
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoModelForMaskedLM
5
+ import warnings
6
+ import numpy as np
7
+ from rdkit import rdBase
8
+
9
+ rdBase.DisableLog('rdApp.error')
10
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
11
+ warnings.filterwarnings("ignore", category=UserWarning)
12
+ warnings.filterwarnings("ignore", category=FutureWarning)
13
+
14
class Hemolysis:
    """Score peptide SMILES strings with a pretrained XGBoost hemolysis model.

    Sequences are embedded with ``emb_model`` (mean-pooled last hidden
    state) and passed to the booster; the returned score is
    ``1 - prediction``, i.e. the probability of *not* being hemolytic.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/hemolysis-xgboost.json')
        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        # Always place the model on self.device: the raw ``device`` argument
        # may be None, while inputs are moved to self.device in
        # generate_embeddings, which would otherwise cause a device mismatch.
        self.emb_model = emb_model.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return an array with one mean-pooled embedding row per sequence."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Return ``1 - predicted hemolysis probability`` for each sequence."""
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        # Nothing to score: skip the booster entirely.
        if len(features) == 0:
            return scores

        # Sanitise NaNs and clamp to the float32 range before building the DMatrix.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        probs = self.predictor.predict(xgb.DMatrix(features))
        # return the probability of it being not hemolytic
        return scores - probs

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
53
+
54
def unittest(tokenizer=None, base_path=None):
    """Smoke-test ``Hemolysis`` on a single peptide SMILES string.

    ``Hemolysis`` requires a tokenizer and a base path, so calling it with no
    arguments (as this function used to) raised ``TypeError``.

    Args:
        tokenizer: tokenizer to use; when None, ``SMILES_SPE_Tokenizer`` is
            loaded from ``<base_path>/tokenizer/my_tokenizers.py``.
        base_path: directory containing ``functions/classifiers`` and
            ``tokenizer``; when None, the parent of this file's directory.
    """
    import importlib.util
    import os

    if base_path is None:
        # Assumes the layout used by scoring_functions.py, where this module
        # sits in '<base_path>/functions/' — TODO confirm if files are moved.
        base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    if tokenizer is None:
        # Load by file path so this also works when run as a plain script.
        spec = importlib.util.spec_from_file_location(
            'my_tokenizers', os.path.join(base_path, 'tokenizer', 'my_tokenizers.py'))
        tokenizers = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(tokenizers)
        tokenizer = tokenizers.SMILES_SPE_Tokenizer(f'{base_path}/tokenizer/new_vocab.txt',
                                                    f'{base_path}/tokenizer/new_splits.txt')

    hemo = Hemolysis(tokenizer=tokenizer, base_path=base_path)
    seq = ["[te]NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    print(hemo.tokenizer.vocab_size)
    scores = hemo(input_seqs=seq)
    print(scores)
60
+
61
+
62
+ if __name__ == '__main__':
63
+ unittest()
a2d2_pep/pep_scoring/functions/nonfouling.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import xgboost as xgb
4
+ import torch
5
+ import numpy as np
6
+ from transformers import AutoModelForMaskedLM
7
+ import warnings
8
+ import numpy as np
9
+ from rdkit import Chem, rdBase, DataStructs
10
+
11
+
12
+ rdBase.DisableLog('rdApp.error')
13
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
14
+ warnings.filterwarnings("ignore", category=UserWarning)
15
+ warnings.filterwarnings("ignore", category=FutureWarning)
16
+
17
class Nonfouling:
    """Score peptide SMILES strings with a pretrained XGBoost nonfouling model.

    Sequences are embedded with ``emb_model`` (mean-pooled last hidden
    state) and the booster's prediction is returned unchanged.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/nonfouling-xgboost.json')
        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        # Always place the model on self.device: the raw ``device`` argument
        # may be None, while inputs are moved to self.device in
        # generate_embeddings, which would otherwise cause a device mismatch.
        self.emb_model = emb_model.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return an array with one mean-pooled embedding row per sequence."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Return the booster's prediction for each sequence (zeros-shaped empty array if none)."""
        features = self.generate_embeddings(input_seqs)

        # Nothing to score: skip the booster entirely.
        if len(features) == 0:
            return np.zeros(len(input_seqs))

        # Sanitise NaNs and clamp to the float32 range before building the DMatrix.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        # The nonfouling prediction is returned as-is (no 1 - p inversion).
        return self.predictor.predict(xgb.DMatrix(features))

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
56
+
57
def unittest(tokenizer=None, base_path=None):
    """Smoke-test ``Nonfouling`` on a single peptide SMILES string.

    ``Nonfouling`` requires a tokenizer and a base path, so calling it with
    no arguments (as this function used to) raised ``TypeError``.

    Args:
        tokenizer: tokenizer to use; when None, ``SMILES_SPE_Tokenizer`` is
            loaded from ``<base_path>/tokenizer/my_tokenizers.py``.
        base_path: directory containing ``functions/classifiers`` and
            ``tokenizer``; when None, the parent of this file's directory.
    """
    import importlib.util
    import os

    if base_path is None:
        # Assumes the layout used by scoring_functions.py, where this module
        # sits in '<base_path>/functions/' — TODO confirm if files are moved.
        base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    if tokenizer is None:
        # Load by file path so this also works when run as a plain script.
        spec = importlib.util.spec_from_file_location(
            'my_tokenizers', os.path.join(base_path, 'tokenizer', 'my_tokenizers.py'))
        tokenizers = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(tokenizers)
        tokenizer = tokenizers.SMILES_SPE_Tokenizer(f'{base_path}/tokenizer/new_vocab.txt',
                                                    f'{base_path}/tokenizer/new_splits.txt')

    nf = Nonfouling(tokenizer=tokenizer, base_path=base_path)
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]

    scores = nf(input_seqs=seq)
    print(scores)
63
+
64
+
65
+ if __name__ == '__main__':
66
+ unittest()
a2d2_pep/pep_scoring/functions/permeability.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import xgboost as xgb
4
+ import torch
5
+ import numpy as np
6
+ from transformers import AutoModelForMaskedLM
7
+ import warnings
8
+ import numpy as np
9
+ from rdkit.Chem import Descriptors, rdMolDescriptors
10
+ from rdkit import Chem, rdBase, DataStructs
11
+ from rdkit.Chem import AllChem
12
+ from typing import List
13
+
14
+
15
+ rdBase.DisableLog('rdApp.error')
16
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
17
+ warnings.filterwarnings("ignore", category=UserWarning)
18
+ warnings.filterwarnings("ignore", category=FutureWarning)
19
+
20
def fingerprints_from_smiles(smiles: List, size=2048):
    """Create ECFP fingerprints of smiles, with validity check.

    Returns ``(fps, valid_mask)``: ``fps`` has one row of ``size`` entries per
    input string (all zeros when the string could not be parsed) and
    ``valid_mask`` holds 1/0 per input. An empty input yields a ``(0, size)``
    array instead of failing inside ``np.concatenate``.
    """
    fps = []
    valid_mask = []
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)
        valid_mask.append(int(mol is not None))
        fps.append(fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size)))

    # np.concatenate rejects an empty list, so handle "no input" explicitly.
    if not fps:
        return np.zeros((0, size)), valid_mask
    return np.concatenate(fps, axis=0), valid_mask
32
+
33
+
34
def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
    """Create the ECFP (Morgan) fingerprint of a molecule as a ``(1, n)`` array.

    ``hashed=True`` uses ``GetHashedMorganFingerprint``; otherwise
    ``GetMorganFingerprintAsBitVect`` is used.
    """
    make_fingerprint = (AllChem.GetHashedMorganFingerprint if hashed
                        else AllChem.GetMorganFingerprintAsBitVect)
    fp_bits = make_fingerprint(molecule, radius, nBits=size)

    # Destination array filled in place by RDKit's converter.
    fp_np = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
    return fp_np.reshape(1, -1)
43
+
44
def getMolDescriptors(mol, missingVal=0):
    """Calculate the full list of descriptors for a molecule.

    Evaluates every entry of ``Descriptors._descList`` followed by three
    Lipinski-style extras. A descriptor whose function raises is recorded as
    ``missingVal``. Returns ``(values, names)`` in matching order.
    """
    custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
                          'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
                          'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}

    values, names = [], []
    for nm, fn in [*Descriptors._descList, *custom_descriptors.items()]:
        try:
            val = fn(mol)
        except Exception:
            # Deliberate best-effort: one failing descriptor must not abort
            # the rest. ``Exception`` (rather than a bare except) lets
            # KeyboardInterrupt / SystemExit propagate.
            val = missingVal
        values.append(val)
        names.append(nm)
    return values, names
68
+
69
def get_pep_dps_from_smi(smi):
    """Return the descriptor vector for one SMILES string as a numpy array.

    If parsing raises, a message is printed and the descriptors are computed
    on ``None``, which ``getMolDescriptors`` turns into its missing value for
    every entry.
    """
    try:
        mol = Chem.MolFromSmiles(smi)
    except Exception:
        # ``Exception`` (rather than a bare except) lets KeyboardInterrupt /
        # SystemExit propagate.
        print(f"convert smi {smi} to molecule failed!")
        mol = None

    dps, _ = getMolDescriptors(mol)
    return np.array(dps)
78
+
79
+
80
def get_pep_dps(smi_list):
    """Return a ``(len(smi_list), n_descriptors)`` array of molecular descriptors.

    For an empty list the column count is taken from ``getMolDescriptors``
    itself rather than a hard-coded number, so the empty result always has
    the same width as a non-empty one regardless of how many descriptors
    ``Descriptors._descList`` provides.
    """
    if len(smi_list) == 0:
        # getMolDescriptors(None) yields one (missing) value per descriptor.
        n_descriptors = len(getMolDescriptors(None)[1])
        return np.zeros((0, n_descriptors))
    return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
84
+
85
def check_smi_validity(smiles: list):
    """Return ``(valid_smi, valid_idx)``: the parseable SMILES and their positions.

    Empty / falsy entries and entries whose parsing raises are skipped.
    """
    valid_smi, valid_idx = [], []
    for position, candidate in enumerate(smiles):
        try:
            if not candidate:
                continue
            if not Chem.MolFromSmiles(candidate):
                continue
            valid_smi.append(candidate)
            valid_idx.append(position)
        except Exception:
            # Best-effort filter: anything that fails to parse is just dropped.
            continue
    return valid_smi, valid_idx
97
+
98
class Permeability:
    """Score peptide SMILES strings with a pretrained XGBoost permeability model.

    Features are the mean-pooled ``emb_model`` embeddings, optionally
    preceded by ECFP fingerprints and/or molecular descriptors.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/permeability-xgboost.json')
        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        # Always use self.device: the raw ``device`` argument may be None,
        # while inputs are moved to self.device in generate_embeddings.
        self.emb_model = emb_model.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return an array with one mean-pooled embedding row per sequence."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_features(self, input_seqs: list, dps=False, fps=False):
        """Return the feature matrix ``[fingerprints | descriptors | embeddings]``.

        ``fps`` / ``dps`` switch the fingerprint and descriptor columns on;
        by default only the embeddings are used. An empty input returns an
        empty ``(0, 0)`` array.
        """
        n_seqs = len(input_seqs)
        if n_seqs == 0:
            # generate_embeddings([]) is 1-D and cannot be concatenated with
            # the 2-D placeholders below, so short-circuit here.
            return np.empty((0, 0), dtype=np.float32)

        # Zero-width numpy placeholders (float32, matching the previous
        # torch.empty default) keep the concatenation uniform.
        placeholder = np.empty((n_seqs, 0), dtype=np.float32)
        fingerprints = fingerprints_from_smiles(input_seqs)[0] if fps else placeholder
        descriptors = get_pep_dps(input_seqs) if dps else placeholder
        embeddings = self.generate_embeddings(input_seqs)

        return np.concatenate([fingerprints, descriptors, embeddings], axis=1)

    def get_scores(self, input_seqs: list):
        """Return the booster's prediction per sequence (empty array for no input)."""
        scores = -10 * np.ones(len(input_seqs))
        features = self.get_features(input_seqs)

        # Nothing to score: skip the booster entirely.
        if len(features) == 0:
            return scores

        # Sanitise NaNs and clamp to the float32 range before building the DMatrix.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        return self.predictor.predict(xgb.DMatrix(features))

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
161
+
162
def unittest(tokenizer=None, base_path=None):
    """Smoke-test ``Permeability`` on a single peptide SMILES string.

    ``Permeability`` requires a tokenizer and a base path, so calling it with
    no arguments (as this function used to) raised ``TypeError``.

    Args:
        tokenizer: tokenizer to use; when None, ``SMILES_SPE_Tokenizer`` is
            loaded from ``<base_path>/tokenizer/my_tokenizers.py``.
        base_path: directory containing ``functions/classifiers`` and
            ``tokenizer``; when None, the parent of this file's directory.
    """
    import importlib.util
    import os

    if base_path is None:
        # Assumes the layout used by scoring_functions.py, where this module
        # sits in '<base_path>/functions/' — TODO confirm if files are moved.
        base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    if tokenizer is None:
        # Load by file path so this also works when run as a plain script.
        spec = importlib.util.spec_from_file_location(
            'my_tokenizers', os.path.join(base_path, 'tokenizer', 'my_tokenizers.py'))
        tokenizers = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(tokenizers)
        tokenizer = tokenizers.SMILES_SPE_Tokenizer(f'{base_path}/tokenizer/new_vocab.txt',
                                                    f'{base_path}/tokenizer/new_splits.txt')

    permeability = Permeability(tokenizer=tokenizer, base_path=base_path)
    seq = ['N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1cNc2c1cc(O)cc2)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H]([C@@H](O)C(C)C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O']
    scores = permeability(input_seqs=seq)
    print(scores)
167
+
168
+
169
+ if __name__ == '__main__':
170
+ unittest()
a2d2_pep/pep_scoring/functions/scoring_utils.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ import numpy as np
3
+ from loguru import logger
4
+ from sklearn.ensemble import RandomForestRegressor
5
+ from rdkit.Chem import Descriptors, rdMolDescriptors
6
+ import joblib
7
+ from rdkit import Chem, rdBase, DataStructs
8
+ from rdkit.Chem import AllChem
9
+ from typing import List
10
+
11
+
12
def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
    """
    Create the ECFP (Morgan) fingerprint of a molecule as a ``(1, n)`` array.

    ``hashed=True`` uses ``GetHashedMorganFingerprint``; otherwise
    ``GetMorganFingerprintAsBitVect`` is used.
    """
    make_fingerprint = (AllChem.GetHashedMorganFingerprint if hashed
                        else AllChem.GetMorganFingerprintAsBitVect)
    fp_bits = make_fingerprint(molecule, radius, nBits=size)

    # Destination array filled in place by RDKit's converter.
    fp_np = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
    return fp_np.reshape(1, -1)
23
+
24
+
25
def fingerprints_from_smiles(smiles: List, size=2048):
    """Create ECFP fingerprints of smiles, with validity check.

    Returns ``(fps, valid_mask)``: one fingerprint row per input string (all
    zeros when parsing failed) and a 1/0 flag per input. An empty input gives
    a ``(0, size)`` array.
    """
    rows = []
    valid_mask = []
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)
        valid_mask.append(int(mol is not None))
        if mol:
            rows.append(fingerprints_from_mol(mol, size=size))
        else:
            rows.append(np.zeros((1, size)))

    if len(rows) > 0:
        stacked = np.concatenate(rows, axis=0)
    else:
        stacked = np.zeros((0, size))
    return stacked, valid_mask
37
+
38
+
39
def getMolDescriptors(mol, missingVal=0):
    """Calculate the full list of descriptors for a molecule.

    Evaluates every entry of ``Descriptors._descList`` followed by three
    Lipinski-style extras. A descriptor whose function raises is recorded as
    ``missingVal``. Returns ``(values, names)`` in matching order.
    """
    custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
                          'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
                          'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}

    values, names = [], []
    for nm, fn in [*Descriptors._descList, *custom_descriptors.items()]:
        try:
            val = fn(mol)
        except Exception:
            # Deliberate best-effort: one failing descriptor must not abort
            # the rest. ``Exception`` (rather than a bare except) lets
            # KeyboardInterrupt / SystemExit propagate.
            val = missingVal
        values.append(val)
        names.append(nm)
    return values, names
63
+
64
+
65
def get_pep_dps_from_smi(smi):
    """Return the descriptor vector for one SMILES string as a numpy array.

    If parsing raises, a message is printed and the descriptors are computed
    on ``None``, which ``getMolDescriptors`` turns into its missing value for
    every entry.
    """
    try:
        mol = Chem.MolFromSmiles(smi)
    except Exception:
        # ``Exception`` (rather than a bare except) lets KeyboardInterrupt /
        # SystemExit propagate.
        print(f"convert smi {smi} to molecule failed!")
        mol = None

    dps, _ = getMolDescriptors(mol)
    return np.array(dps)
74
+
75
+
76
def get_pep_dps(smi_list):
    """Return a ``(len(smi_list), n_descriptors)`` array of molecular descriptors.

    For an empty list the column count is taken from ``getMolDescriptors``
    itself rather than a hard-coded number, so the empty result always has
    the same width as a non-empty one regardless of how many descriptors
    ``Descriptors._descList`` provides.
    """
    if len(smi_list) == 0:
        # getMolDescriptors(None) yields one (missing) value per descriptor.
        n_descriptors = len(getMolDescriptors(None)[1])
        return np.zeros((0, n_descriptors))
    return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
80
+
81
+
82
+
83
def check_smi_validity(smiles: list):
    """Return ``(valid_smi, valid_idx)``: the parseable SMILES and their positions.

    Empty / falsy entries and entries whose parsing raises are skipped.
    """
    valid_smi, valid_idx = [], []
    for position, candidate in enumerate(smiles):
        try:
            if not candidate:
                continue
            if not Chem.MolFromSmiles(candidate):
                continue
            valid_smi.append(candidate)
            valid_idx.append(position)
        except Exception:
            # Best-effort filter: anything that fails to parse is just dropped.
            continue
    return valid_smi, valid_idx
a2d2_pep/pep_scoring/functions/solubility.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import xgboost as xgb
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoModelForMaskedLM
5
+ import warnings
6
+ import numpy as np
7
+ from rdkit import rdBase
8
+
9
+ rdBase.DisableLog('rdApp.error')
10
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
11
+ warnings.filterwarnings("ignore", category=UserWarning)
12
+ warnings.filterwarnings("ignore", category=FutureWarning)
13
+
14
class Solubility:
    """Score peptide SMILES strings with a pretrained XGBoost solubility model.

    Sequences are embedded with ``emb_model`` (mean-pooled last hidden
    state) and the booster's prediction is returned unchanged.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/solubility-xgboost.json')

        # Reuse the caller's embedding model when given, else load the default.
        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.emb_model = emb_model.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return an array with one mean-pooled embedding row per sequence."""
        pooled = []
        for smiles in sequences:
            batch = {key: value.to(self.device)
                     for key, value in self.tokenizer(smiles, return_tensors='pt').items()}
            with torch.no_grad():
                hidden = self.emb_model(**batch).last_hidden_state
            # Average over the token axis, then drop the batch axis.
            pooled.append(hidden.mean(dim=1).squeeze(0).cpu().numpy())
        return np.array(pooled)

    def get_scores(self, input_seqs: list):
        """Return the booster's prediction per sequence (empty array for no input)."""
        features = self.generate_embeddings(input_seqs)

        # Nothing to score: skip the booster entirely.
        if len(features) == 0:
            return np.zeros(len(input_seqs))

        # Sanitise NaNs and clamp to the float32 range before building the DMatrix.
        float32_limits = np.finfo(np.float32)
        features = np.clip(np.nan_to_num(features, nan=0.),
                           float32_limits.min, float32_limits.max)

        return self.predictor.predict(xgb.DMatrix(features))

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
55
+
56
def unittest(tokenizer=None, base_path=None):
    """Smoke-test ``Solubility`` on a single peptide SMILES string.

    ``Solubility`` requires a tokenizer and a base path, so calling it with
    no arguments (as this function used to) raised ``TypeError``.

    Args:
        tokenizer: tokenizer to use; when None, ``SMILES_SPE_Tokenizer`` is
            loaded from ``<base_path>/tokenizer/my_tokenizers.py``.
        base_path: directory containing ``functions/classifiers`` and
            ``tokenizer``; when None, the parent of this file's directory.
    """
    import importlib.util
    import os

    if base_path is None:
        # Assumes the layout used by scoring_functions.py, where this module
        # sits in '<base_path>/functions/' — TODO confirm if files are moved.
        base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    if tokenizer is None:
        # Load by file path so this also works when run as a plain script.
        spec = importlib.util.spec_from_file_location(
            'my_tokenizers', os.path.join(base_path, 'tokenizer', 'my_tokenizers.py'))
        tokenizers = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(tokenizers)
        tokenizer = tokenizers.SMILES_SPE_Tokenizer(f'{base_path}/tokenizer/new_vocab.txt',
                                                    f'{base_path}/tokenizer/new_splits.txt')

    solubility = Solubility(tokenizer=tokenizer, base_path=base_path)
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    scores = solubility(input_seqs=seq)
    print(scores)
61
+
62
+ if __name__ == '__main__':
63
+ unittest()
a2d2_pep/pep_scoring/scoring_functions.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
3
+ from transformers import AutoModelForMaskedLM
4
+ import numpy as np
5
+ from .functions.binding import BindingAffinity
6
+ from .functions.permeability import Permeability
7
+ from .functions.solubility import Solubility
8
+ from .functions.hemolysis import Hemolysis
9
+ from .functions.nonfouling import Nonfouling
10
+
11
# Absolute path of the directory containing this module. The tokenizer files
# (tokenizer/new_vocab.txt, tokenizer/new_splits.txt) are resolved against it,
# and it is handed to every scorer as ``base_path``.
base_path = os.path.dirname(os.path.abspath(__file__))
13
+
14
class ScoringFunctions:
    """Bundle of peptide property scorers evaluated on generated sequences."""

    def __init__(self, score_func_names=None, prot_seqs=None, device=None):
        """
        Build the shared embedding model/tokenizer and every scorer.

        Args:
            score_func_names: list of scorer names to evaluate, in output
                column order. Valid names are the keys of ``self.all_funcs``.
                ``None`` means no scorer is evaluated.
            prot_seqs: list of target protein sequences for the binding
                affinity scorers (one or two entries). ``None`` means no
                targets.
            device: device handed to the embedding model and to each scorer.
        """
        emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device).eval()
        tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/tokenizer/new_vocab.txt',
                                         f'{base_path}/tokenizer/new_splits.txt')
        prot_seqs = prot_seqs if prot_seqs is not None else []

        # Normalize None to an empty list once, and use the normalized value
        # for every membership test below (testing the raw argument raised
        # TypeError when score_func_names was left at its default).
        self.score_func_names = [] if score_func_names is None else score_func_names
        requested = self.score_func_names

        # binding affinities
        self.target_protein = prot_seqs

        wants_first = 'binding_affinity1' in requested
        wants_second = 'binding_affinity2' in requested

        if wants_first and len(prot_seqs) == 1:
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
            binding_affinity2 = None
        elif wants_first and wants_second and len(prot_seqs) == 2:
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
            binding_affinity2 = BindingAffinity(prot_seqs[1], tokenizer=tokenizer, base_path=base_path, device=device)
        else:
            # No usable (names, targets) combination: binding scorers stay unset.
            binding_affinity1 = None
            binding_affinity2 = None

        permeability = Permeability(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        sol = Solubility(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        nonfouling = Nonfouling(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        hemo = Hemolysis(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)

        self.all_funcs = {'binding_affinity1': binding_affinity1,
                          'binding_affinity2': binding_affinity2,
                          'permeability': permeability,
                          'nonfouling': nonfouling,
                          'solubility': sol,
                          'hemolysis': hemo
                          }

    def forward(self, input_seqs):
        """
        Evaluate every requested scorer on ``input_seqs``.

        Args:
            input_seqs: list of sequences passed to each scorer.

        Returns:
            float32 numpy array with shape (num_sequences, num_functions),
            assuming each scorer returns one value per sequence.

        Raises:
            KeyError: a requested name is not a key of ``self.all_funcs``.
            TypeError: a requested binding scorer was not constructed because
                the names/``prot_seqs`` combination did not match.
        """
        scores = []

        for score_func in self.score_func_names:
            scorer = self.all_funcs[score_func]
            if scorer is None:
                raise TypeError(
                    f"scoring function '{score_func}' was requested but not initialised; "
                    "check that prot_seqs matches the requested binding_affinity names"
                )
            scores.append(scorer(input_seqs=input_seqs))

        # convert to numpy arrays with shape (num_sequences, num_functions)
        return np.float32(scores).T

    def __call__(self, input_seqs: list):
        """Alias for ``forward``."""
        return self.forward(input_seqs)
a2d2_pep/pep_scoring/tokenizer/my_tokenizers.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import os
3
+ import re
4
+ from typing import List, Optional
5
+ from transformers import PreTrainedTokenizer
6
+ from SmilesPE.tokenizer import SPE_Tokenizer
7
+ import torch
8
+
9
def load_vocab(vocab_file):
    """Read a one-token-per-line vocabulary file into a token -> line-index map.

    Only the trailing newline is stripped from each line. If a token appears
    more than once, it keeps its first position in the mapping but maps to the
    index of its last occurrence.
    """
    with open(vocab_file, "r", encoding="utf-8") as handle:
        lines = handle.readlines()
    return collections.OrderedDict(
        (line.rstrip("\n"), position) for position, line in enumerate(lines)
    )
18
+
19
class Atomwise_Tokenizer(object):
    """Regex-driven, atom-level SMILES tokenizer."""

    def __init__(self):
        """Compile the tokenization pattern.

        Alternatives are tried left to right, so a parenthesised group holding
        at most four non-parenthesis characters, and any bracket atom, are each
        emitted as a single token before the single-symbol alternatives apply.
        """
        self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
        self.regex = re.compile(self.regex_pattern)

    def tokenize(self, text):
        """Return the list of tokens matched in ``text``; unmatched characters are dropped."""
        return [found.group(1) for found in self.regex.finditer(text)]
35
+
36
class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
    r"""
    SMILES tokenizer based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).

    Inherits from :class:`~transformers.PreTrainedTokenizer`; refer to the
    superclass for the methods not overridden here. Note that ``encode`` and
    ``decode`` are overridden with a different contract than the superclass
    (see their docstrings).

    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary, one token per line.
        spe_file (:obj:`string`):
            File containing the trained SMILES Pair Encoding vocabulary.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token. A token that is not in the vocabulary is mapped
            to this token's id.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token, also used as the last token of a sequence
            built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The first token of a sequence built with special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values.
    """

    def __init__(self, vocab_file, spe_file,
                 unk_token="[UNK]",
                 sep_token="[SEP]",
                 pad_token="[PAD]",
                 cls_token="[CLS]",
                 mask_token="[MASK]",
                 **kwargs):
        """Load both vocabularies, then initialise the base tokenizer.

        Raises:
            ValueError: ``vocab_file`` or ``spe_file`` is not an existing file.
        """
        if not os.path.isfile(vocab_file):
            raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
        if not os.path.isfile(spe_file):
            raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))

        # The vocabulary is populated before super().__init__ so it is already
        # available if the base constructor looks tokens up.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])

        # SPE_Tokenizer reads the merge table while it is being constructed
        # (see SmilesPE's tokenizer module), so the handle is closed right
        # afterwards instead of staying open for the tokenizer's lifetime.
        with open(spe_file, 'r', encoding='utf-8') as spe_codes:
            self.spe_tokenizer = SPE_Tokenizer(spe_codes)
        # Kept for backward compatibility; now holds the path rather than an
        # open file object.
        self.spe_vocab = spe_file

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs)

    @property
    def vocab_size(self):
        """Number of entries loaded from the vocabulary file."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the base vocabulary merged with any added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        """Split ``text`` into SPE tokens (the SPE tokenizer returns them space-joined)."""
        return self.spe_tokenizer.tokenize(text).split(' ')

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    # changed encode and decode functions
    def encode(self, token_array):
        """Convert already-split tokens into model inputs.

        Unlike the superclass ``encode``, this takes a sequence of token
        strings (not raw text) and returns a dict.

        Args:
            token_array: iterable of token strings.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` tensors of shape
            (1, len(token_array) + 2); the mask is all ones.
        """
        # NOTE(review): ids 2 and 3 are hard-coded as the leading/trailing
        # markers — presumably [CLS]/[SEP] in the shipped vocab; confirm.
        ids = [2]
        ids.extend(self._convert_token_to_id(token) for token in token_array)
        ids.append(3)
        token_ids = torch.tensor([ids])
        attn_mask = torch.ones_like(token_ids)
        return {'input_ids': token_ids, 'attention_mask': attn_mask}

    def decode(self, token_ids, skip_special_tokens=True):
        """Convert one sequence of ids back into a string.

        Args:
            token_ids: tensor of ids; a leading axis of size one is squeezed.
            skip_special_tokens: drop ids listed in ``self.all_special_ids``.

        Returns:
            The tokens concatenated without separators, truncated at the first
            id 3 (the end marker written by ``encode``).
        """
        # Looked up once instead of on every iteration.
        special_ids = self.all_special_ids if skip_special_tokens else ()
        pieces = []
        for idx in token_ids.squeeze(0).cpu().tolist():
            if idx == 3:  # Stop decoding when token ID 3 is encountered
                break
            if idx in special_ids:
                continue
            pieces.append(self._convert_id_to_token(idx))
        return "".join(pieces)

    def batch_decode(self, batch_token_ids, skip_special_tokens=True):
        """Decode every sequence in ``batch_token_ids`` with ``decode``.

        ``skip_special_tokens`` is forwarded to ``decode`` (it was previously
        accepted but ignored).
        """
        return [
            self.decode(token_ids, skip_special_tokens=skip_special_tokens)
            for token_ids in batch_token_ids
        ]

    def get_token_split(self, token_ids):
        """Map a batch of id sequences to the corresponding token strings.

        Args:
            token_ids: tensor or nested list of ids, one inner sequence per row.

        Returns:
            list of lists of token strings, same nesting as the input.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.cpu().tolist()

        return [
            [self._convert_id_to_token(token_id) for token_id in seq_ids]
            for seq_ids in token_ids
        ]

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by adding
        special tokens.

        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: list of input IDs with the special tokens added.
        """
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Return a mask marking special-token positions (1) versus sequence
        tokens (0).

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model
        Returns:
            :obj:`List[int]`: 1 for a special token, 0 for a sequence token.
        Raises:
            ValueError: ``token_ids_1`` is given together with
                ``already_has_special_tokens=True``.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            markers = [self.sep_token_id, self.cls_token_id]
            return [1 if token_id in markers else 0 for token_id in token_ids_0]

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create token type ids for a sequence or a sequence pair.

        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |

        If ``token_ids_1`` is None, only the first portion (0's) is returned.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: token type ids covering the tokens plus the
            ``[CLS]``/``[SEP]`` positions.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path, filename_prefix=None):
        """
        Write the vocabulary, one token per line in id order.

        Args:
            vocab_path (:obj:`str`):
                Target file path, or an existing directory in which
                ``vocab.txt`` (optionally prefixed) is created.
            filename_prefix (:obj:`str`, `optional`):
                Prefix for the file name when ``vocab_path`` is a directory.
                Accepted so that ``save_pretrained`` can call this method.
        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(vocab_path):
            prefix = filename_prefix + "-" if filename_prefix else ""
            vocab_file = os.path.join(vocab_path, prefix + "vocab.txt")
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, _ in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                writer.write(token + "\n")
        return (vocab_file,)
252
+
253
class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
    r"""
    Atom-level SMILES tokenizer built on :class:`Atomwise_Tokenizer`.

    Inherits from :class:`~transformers.PreTrainedTokenizer`; refer to the
    superclass for the methods not overridden here.

    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary, one token per line.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token. A token that is not in the vocabulary is mapped
            to this token's id.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token, also used as the last token of a sequence
            built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The first token of a sequence built with special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        """Load the vocabulary, then initialise the base tokenizer.

        Raises:
            ValueError: ``vocab_file`` is not an existing file.
        """
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'.".format(vocab_file)
            )
        # Populate the vocabulary BEFORE super().__init__ (same order as
        # SMILES_SPE_Tokenizer): recent transformers releases look special
        # tokens up in the vocabulary from inside the base constructor, which
        # fails if self.vocab does not exist yet.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.tokenizer = Atomwise_Tokenizer()

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of entries loaded from the vocabulary file."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the base vocabulary merged with any added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        """Split ``text`` into atom-level tokens."""
        return self.tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by adding
        special tokens.

        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: list of input IDs with the special tokens added.
        """
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Return a mask marking special-token positions (1) versus sequence
        tokens (0).

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model
        Returns:
            :obj:`List[int]`: 1 for a special token, 0 for a sequence token.
        Raises:
            ValueError: ``token_ids_1`` is given together with
                ``already_has_special_tokens=True``.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            markers = [self.sep_token_id, self.cls_token_id]
            return [1 if token_id in markers else 0 for token_id in token_ids_0]

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create token type ids for a sequence or a sequence pair.

        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |

        If ``token_ids_1`` is None, only the first portion (0's) is returned.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: token type ids covering the tokens plus the
            ``[CLS]``/``[SEP]`` positions.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path, filename_prefix=None):
        """
        Write the vocabulary, one token per line in id order.

        Args:
            vocab_path (:obj:`str`):
                Target file path, or an existing directory in which
                ``vocab.txt`` (optionally prefixed) is created.
            filename_prefix (:obj:`str`, `optional`):
                Prefix for the file name when ``vocab_path`` is a directory.
                Accepted so that ``save_pretrained`` can call this method.
        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(vocab_path):
            prefix = filename_prefix + "-" if filename_prefix else ""
            vocab_file = os.path.join(vocab_path, prefix + "vocab.txt")
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, _ in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                writer.write(token + "\n")
        return (vocab_file,)
a2d2_pep/pep_utils/analyzer.py ADDED
@@ -0,0 +1,1274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import pandas as pd
4
+ from io import StringIO
5
+ import rdkit
6
+ from rdkit import Chem
7
+ from rdkit.Chem import AllChem, Draw
8
+ import numpy as np
9
+ from PIL import Image, ImageDraw, ImageFont
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.patches as patches
12
+ from io import BytesIO
13
+ import tempfile
14
+ from rdkit import Chem
15
+
16
+ class PeptideAnalyzer:
17
+ def __init__(self, min_peptide_bonds=2, enforce_min_peptide_bonds=True):
18
+ # length cutoff: minimum number of backbone residues (N-Cα-C(=O) units)
19
+
20
+ self.min_peptide_bonds = min_peptide_bonds
21
+ self.enforce_min_peptide_bonds = enforce_min_peptide_bonds
22
+ self.bond_patterns = [
23
+ (r'OC\(=O\)', 'ester'), # Ester bond
24
+ (r'N\(C\)C\(=O\)', 'n_methyl'), # N-methylated peptide bond
25
+ (r'N[0-9]C\(=O\)', 'proline'), # Proline peptide bond
26
+ (r'NC\(=O\)', 'peptide'), # Standard peptide bond
27
+ (r'C\(=O\)N\(C\)', 'n_methyl_reverse'), # Reverse N-methylated
28
+ (r'C\(=O\)N[12]?', 'peptide_reverse') # Reverse peptide bond
29
+ ]
30
+ # Three to one letter code mapping
31
+ self.three_to_one = {
32
+ 'Ala': 'A', 'Cys': 'C', 'Asp': 'D', 'Glu': 'E',
33
+ 'Phe': 'F', 'Gly': 'G', 'His': 'H', 'Ile': 'I',
34
+ 'Lys': 'K', 'Leu': 'L', 'Met': 'M', 'Asn': 'N',
35
+ 'Pro': 'P', 'Gln': 'Q', 'Arg': 'R', 'Ser': 'S',
36
+ 'Thr': 'T', 'Val': 'V', 'Trp': 'W', 'Tyr': 'Y'
37
+ }
38
+
39
def count_peptide_bonds(self, smiles):
    """Count N-Cα-C(=O) backbone units in a SMILES string.

    Despite the name, this counts matches of the SMARTS
    ``[NX3][CX4][CX3](=O)``: a three-connected nitrogen bonded to a
    four-connected carbon bonded to a three-connected carbonyl carbon.
    Amide nitrogens that are not attached to such a carbon (for example the
    N-C(=O)-N of a urea) do not match. Matches are uniquified, so each atom
    set is counted once.

    Returns:
        Number of matches, or 0 when the SMILES cannot be parsed.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        return 0
    backbone_unit = Chem.MolFromSmarts('[NX3][CX4][CX3](=O)')
    hits = parsed.GetSubstructMatches(backbone_unit, uniquify=True)
    return len(hits)
54
+
55
def is_peptide(self, smiles):
    """Check if the SMILES represents a peptide structure.

    A molecule qualifies when ``count_peptide_bonds`` finds at least one
    backbone unit and, if ``enforce_min_peptide_bonds`` is set, at least
    ``min_peptide_bonds`` of them.

    Returns:
        bool. Unparseable SMILES yield False.
    """
    # count_peptide_bonds already returns 0 for SMILES that fail to parse,
    # so a separate parse here would only duplicate work.
    n_residues = self.count_peptide_bonds(smiles)
    if n_residues == 0:
        return False

    # length cutoff: reject molecules with too few backbone residues
    if self.enforce_min_peptide_bonds and n_residues < self.min_peptide_bonds:
        return False

    return True
72
+
73
def is_cyclic(self, smiles):
    """Text-based cyclic peptide detection.

    Returns:
        Tuple ``(is_cyclic, peptide_cycles, aromatic_cycles)``. A SMILES
        ending in ``C(=O)O`` is always reported as ``(False, [], [])``.
    """
    # A free C-terminal carboxyl is treated as proof of a linear peptide.
    if smiles.endswith('C(=O)O'):
        return False, [], []

    # Ring-closure digits not preceded by an aromatic 'c' and followed by an
    # uppercase letter, '@' or a parenthesis.
    # NOTE(review): each match includes the character before the digit
    # (e.g. 'N1'), while aromatic_cycles holds bare digits, so the filter
    # below looks like it only removes a digit sitting at position 0 —
    # confirm whether that is intended.
    ring_numbers = re.findall(r'(?:^|[^c])[0-9](?=[A-Z@\(\)])', smiles)

    # Digits used by benzene-like or [nH]-containing aromatic ring closures.
    aromatic_cycles = [
        digit
        for hit in re.findall(r'c[0-9](?:ccccc|c\[nH\]c)[0-9]', smiles)
        for digit in re.findall(r'[0-9]', hit)
    ]

    peptide_cycles = [closure for closure in ring_numbers if closure not in aromatic_cycles]

    return bool(peptide_cycles), peptide_cycles, aromatic_cycles
94
+
95
def split_on_bonds(self, smiles):
    """Cut a SMILES string into segments at textual bond patterns.

    Glycine units (``NCC(=O)``) are located first, then every pattern in
    ``self.bond_patterns`` in order. A match is ignored if it overlaps
    characters already claimed by an earlier match.

    Returns:
        List of dicts with ``content`` plus ``bond_before`` and/or
        ``bond_after`` (the matched pattern text on that side). Empty when no
        pattern matches.
    """
    positions = []
    used = set()

    def claim(match, bond_type):
        # Record the match unless any of its characters is already taken.
        span = range(match.start(), match.end())
        if used.isdisjoint(span):
            positions.append({
                'start': match.start(),
                'end': match.end(),
                'type': bond_type,
                'pattern': match.group()
            })
            used.update(span)

    # Gly first, so its backbone is not split by the generic patterns.
    for match in re.finditer(r'NCC\(=O\)', smiles):
        claim(match, 'gly')

    for pattern, bond_type in self.bond_patterns:
        for match in re.finditer(pattern, smiles):
            claim(match, bond_type)

    positions.sort(key=lambda x: x['start'])

    segments = []
    if not positions:
        return segments

    first, last = positions[0], positions[-1]

    # Text before the first bond.
    if first['start'] > 0:
        segments.append({
            'content': smiles[0:first['start']],
            'bond_after': first['pattern']
        })

    # Text between consecutive bonds; a Gly match is emitted as its own segment.
    for i, (current, following) in enumerate(zip(positions, positions[1:])):
        if current['type'] == 'gly':
            segments.append({
                'content': 'NCC(=O)',
                'bond_before': positions[i - 1]['pattern'] if i > 0 else None,
                'bond_after': following['pattern']
            })
            continue

        content = smiles[current['end']:following['start']]
        if content:
            segments.append({
                'content': content,
                'bond_before': current['pattern'],
                'bond_after': following['pattern']
            })

    # Text after the last bond.
    if last['end'] < len(smiles):
        segments.append({
            'content': smiles[last['end']:],
            'bond_before': last['pattern']
        })

    return segments
165
+
166
def clean_terminal_carboxyl(self, segment):
    """Strip a parenthesised carboxyl group from the last segment.

    Cleaning only happens when the segment has no ``bond_after`` (i.e. it is
    the final segment) and its content contains ``C(=O)O``. In that case every
    ``(C(=O)O)`` group is removed, wherever it sits in the content, followed
    by any empty ``()`` left behind. An unparenthesised ``C(=O)O`` is left
    untouched.

    Args:
        segment: dict with ``content`` and optionally ``bond_after``.

    Returns:
        The cleaned content string, or the original content when the
        conditions above are not met.
    """
    content = segment['content']
    if 'C(=O)O' not in content or segment.get('bond_after'):
        return content

    cleaned = re.sub(r'\(C\(=O\)O\)', '', content)
    # Remove any leftover empty parentheses
    return re.sub(r'\(\)', '', cleaned)
183
+
184
def identify_residue(self, segment):
    """Identify the residue encoded by one segment (with Pro reconstruction).

    The segment content is first passed through
    ``clean_terminal_carboxyl`` and then tested against a long, ordered
    series of SMILES substring / regex patterns. The first matching check
    returns, so the order of the checks below is significant: an earlier,
    broader pattern wins over any later, more specific one.

    Args:
        segment: dict with a ``'content'`` SMILES fragment and optional
            ``'bond_before'`` / ``'bond_after'`` bond patterns.

    Returns:
        A ``(code, mods)`` tuple. ``code`` is the residue label string, or
        ``None`` when no pattern matched; ``mods`` is the list returned by
        ``get_modifications`` for the same segment.
    """
    # Only clean terminal carboxyl if this is the last segment
    content = self.clean_terminal_carboxyl(segment)
    mods = self.get_modifications(segment)

    # UAA pattern matching section - before regular residues
    # Phenylglycine and derivatives
    if 'c1ccccc1' in content:
        if '[C@@H](c1ccccc1)' in content or '[C@H](c1ccccc1)' in content:
            return '4', mods  # Base phenylglycine

    # 4-substituted phenylalanines
    if 'Cc1ccc' in content:
        if 'OMe' in content or 'OCc1ccc' in content:
            return '0A1', mods  # 4-methoxy-Phenylalanine
        elif 'Clc1ccc' in content:
            return '200', mods  # 4-chloro-Phenylalanine
        elif 'Brc1ccc' in content:
            return '4BF', mods  # 4-Bromo-phenylalanine
        elif 'C#Nc1ccc' in content:
            return '4CF', mods  # 4-cyano-phenylalanine
        elif 'Ic1ccc' in content:
            return 'PHI', mods  # 4-Iodo-phenylalanine
        elif 'Fc1ccc' in content:
            return 'PFF', mods  # 4-Fluoro-phenylalanine

    # Modified tryptophans
    if 'c[nH]c2' in content:
        if 'Oc2cccc2' in content:
            return '0AF', mods  # 7-hydroxy-tryptophan
        elif 'Fc2cccc2' in content:
            return '4FW', mods  # 4-fluoro-tryptophan
        elif 'Clc2cccc2' in content:
            return '6CW', mods  # 6-chloro-tryptophan
        elif 'Brc2cccc2' in content:
            return 'BTR', mods  # 6-bromo-tryptophan
        elif 'COc2cccc2' in content:
            return 'MOT5', mods  # 5-Methoxy-tryptophan
        elif 'Cc2cccc2' in content:
            return 'MTR5', mods  # 5-Methyl-tryptophan

    # Special amino acids
    if 'CC(C)(C)[C@@H]' in content or 'CC(C)(C)[C@H]' in content:
        return 'BUG', mods  # Tertleucine

    # NOTE(review): this substring also occurs inside both arginine
    # patterns tested further down ('[C@@H](CCCNC(=N)N)' and the later
    # 'CCCNC(=N)N' check), so those checks can never be reached - confirm
    # whether arginine is meant to be reported as 'CIR'.
    if 'CCCNC(=N)N' in content:
        return 'CIR', mods  # Citrulline

    if '[SeH]' in content:
        return 'CSE', mods  # Selenocysteine

    if '[NH3]CC[C@@H]' in content or '[NH3]CC[C@H]' in content:
        return 'DAB', mods  # Diaminobutyric acid

    if 'C1CCCCC1' in content:
        if 'C1CCCCC1[C@@H]' in content or 'C1CCCCC1[C@H]' in content:
            return 'CHG', mods  # Cyclohexylglycine
        elif 'C1CCCCC1C[C@@H]' in content or 'C1CCCCC1C[C@H]' in content:
            return 'ALC', mods  # 3-cyclohexyl-alanine

    # Naphthalene derivatives
    if 'c1cccc2c1cccc2' in content:
        if 'c1cccc2c1cccc2[C@@H]' in content or 'c1cccc2c1cccc2[C@H]' in content:
            return 'NAL', mods  # 2-Naphthyl-alanine

    # Heteroaromatic derivatives
    if 'c1cncc' in content:
        return 'PYR4', mods  # 3-(4-Pyridyl)-alanine
    if 'c1cscc' in content:
        return 'THA3', mods  # 3-(3-thienyl)-alanine
    if 'c1nnc' in content:
        return 'TRZ4', mods  # 3-(1,2,4-Triazol-1-yl)-alanine

    # Modified serines and threonines
    if 'OP(O)(O)O' in content:
        if '[C@@H](COP' in content or '[C@H](COP' in content:
            return 'SEP', mods  # phosphoserine
        elif '[C@@H](OP' in content or '[C@H](OP' in content:
            return 'TPO', mods  # phosphothreonine

    # Specialized ring systems
    if 'c1c2ccccc2cc2c1cccc2' in content:
        return 'ANTH', mods  # 3-(9-anthryl)-alanine
    if 'c1csc2c1cccc2' in content:
        return 'BTH3', mods  # 3-(3-benzothienyl)-alanine
    if '[C@]12C[C@H]3C[C@@H](C2)C[C@@H](C1)C3' in content:
        return 'ADAM', mods  # Adamanthane

    # Fluorinated derivatives
    if 'FC(F)(F)' in content:
        if 'CC(F)(F)F' in content:
            return 'FLA', mods  # Trifluoro-alanine
        if 'C(F)(F)F)c1' in content:
            if 'c1ccccc1C(F)(F)F' in content:
                return 'TFG2', mods  # 2-(Trifluoromethyl)-phenylglycine
            if 'c1cccc(c1)C(F)(F)F' in content:
                return 'TFG3', mods  # 3-(Trifluoromethyl)-phenylglycine
            if 'c1ccc(cc1)C(F)(F)F' in content:
                return 'TFG4', mods  # 4-(Trifluoromethyl)-phenylglycine

    # Multiple halogen patterns
    if 'F' in content and 'c1' in content:
        if 'c1ccc(c(c1)F)F' in content:
            return 'F2F', mods  # 3,4-Difluoro-phenylalanine
        if 'cc(F)cc(c1)F' in content:
            return 'WFP', mods  # 3,5-Difluoro-phenylalanine
    if 'Cl' in content and 'c1' in content:
        if 'c1ccc(cc1Cl)Cl' in content:
            return 'CP24', mods  # 2,4-dichloro-phenylalanine
        if 'c1ccc(c(c1)Cl)Cl' in content:
            return 'CP34', mods  # 3,4-dichloro-phenylalanine

    # Hydroxy and amino derivatives
    if 'O' in content and 'c1' in content:
        if 'c1cc(O)cc(c1)O' in content:
            return '3FG', mods  # (2s)-amino(3,5-dihydroxyphenyl)-ethanoic acid
        if 'c1ccc(c(c1)O)O' in content:
            return 'DAH', mods  # 3,4-Dihydroxy-phenylalanine

    # Cyclic amino acids
    if 'C1CCCC1' in content:
        return 'CPA3', mods  # 3-Cyclopentyl-alanine
    # Reached only when the earlier cyclohexyl block found the ring but
    # neither of its chiral-centre sub-patterns.
    if 'C1CCCCC1' in content:
        if 'CC1CCCCC1' in content:
            return 'ALC', mods  # 3-cyclohexyl-alanine
        else:
            return 'CHG', mods  # Cyclohexylglycine

    # Chain-length variants
    if 'CCC[C@@H]' in content or 'CCC[C@H]' in content:
        return 'NLE', mods  # Norleucine
    if 'CC[C@@H]' in content or 'CC[C@H]' in content:
        if not any(x in content for x in ['CC(C)', 'COC', 'CN(']):
            return 'ABA', mods  # 2-Aminobutyric acid

    # Modified histidines
    if 'c1cnc' in content:
        if '[C@@H]1CN[C@@H](N1)F' in content:
            return '2HF', mods  # 2-fluoro-l-histidine
        if 'c1cnc([nH]1)F' in content:
            return '2HF1', mods  # 2-fluoro-l-histidine variant
        if 'c1c[nH]c(n1)F' in content:
            return '2HF2', mods  # 2-fluoro-l-histidine variant

    # Sulfur and selenium containing
    # NOTE(review): '[SeH]' already returned 'CSE' in the special amino
    # acids section above, so this repeat cannot be reached.
    if '[SeH]' in content:
        return 'CSE', mods  # Selenocysteine
    if 'S' in content:
        if 'CSCc1ccccc1' in content:
            return 'BCS', mods  # benzylcysteine
        if 'CCSC' in content:
            return 'ESC', mods  # Ethionine
        if 'CCS' in content:
            return 'HCS', mods  # homocysteine

    # Additional modifications
    if 'CN=[N]=N' in content:
        return 'AZDA', mods  # azido-alanine
    if '[NH]=[C](=[NH2])=[NH2]' in content:
        if 'CCC[NH]=' in content:
            return 'AGM', mods  # 5-methyl-arginine
        if 'CC[NH]=' in content:
            return 'GDPR', mods  # 2-Amino-3-guanidinopropionic acid

    if 'CCON' in content:
        return 'CAN', mods  # canaline
    if '[C@@H]1C=C[C@@H](C=C1)' in content:
        return 'ACZ', mods  # cis-amiclenomycin
    if 'CCC(=O)[NH3]' in content:
        return 'ONL', mods  # 5-oxo-l-norleucine
    if 'c1ccncc1' in content:
        return 'PYR4', mods  # 3-(4-Pyridyl)-alanine
    if 'c1ccco1' in content:
        return 'FUA2', mods  # (2-furyl)-alanine

    if 'c1ccc' in content:
        if 'c1ccc(cc1)c1ccccc1' in content:
            return 'BIF', mods  # 4,4-biphenylalanine
        if 'c1ccc(cc1)C(=O)c1ccccc1' in content:
            return 'PBF', mods  # 4-benzoyl-phenylalanine
        if 'c1ccc(cc1)C(C)(C)C' in content:
            return 'TBP4', mods  # 4-tert-butyl-phenylalanine
        if 'c1ccc(cc1)[C](=[NH2])=[NH2]' in content:
            return '0BN', mods  # 4-carbamimidoyl-l-phenylalanine
        if 'c1cccc(c1)[C](=[NH2])=[NH2]' in content:
            return 'APM', mods  # m-amidinophenyl-3-alanine

    # Multiple hydroxy patterns
    if 'O' in content:
        if '[C@H]([C@H](C)O)O' in content:
            return 'ILX', mods  # 4,5-dihydroxy-isoleucine
        if '[C@H]([C@@H](C)O)O' in content:
            return 'ALO', mods  # Allo-threonine
        # NOTE(review): content matching this pattern also matches the
        # earlier 'OP(O)(O)O' / '[C@H](COP' phosphoserine check, so this
        # repeat cannot be reached.
        if '[C@H](COP(O)(O)O)' in content:
            return 'SEP', mods  # phosphoserine
        if '[C@H]([C@@H](C)OP(O)(O)O)' in content:
            return 'TPO', mods  # phosphothreonine
        if '[C@H](c1ccc(O)cc1)O' in content:
            return 'OMX', mods  # (betar)-beta-hydroxy-l-tyrosine
        if '[C@H](c1ccc(c(Cl)c1)O)O' in content:
            return 'OMY', mods  # (betar)-3-chloro-beta-hydroxy-l-tyrosine

    # Heterocyclic patterns
    if 'n1' in content:
        if 'n1cccn1' in content:
            return 'PYZ1', mods  # 3-(1-Pyrazolyl)-alanine
        if 'n1nncn1' in content:
            return 'TEZA', mods  # 3-(2-Tetrazolyl)-alanine
        if 'c2c(n1)cccc2' in content:
            return 'QU32', mods  # 3-(2-Quinolyl)-alanine
        if 'c1cnc2c(c1)cccc2' in content:
            return 'QU33', mods  # 3-(3-quinolyl)-alanine
        if 'c1ccnc2c1cccc2' in content:
            return 'QU34', mods  # 3-(4-quinolyl)-alanine
        if 'c1ccc2c(c1)nccc2' in content:
            return 'QU35', mods  # 3-(5-Quinolyl)-alanine
        if 'c1ccc2c(c1)cncc2' in content:
            return 'QU36', mods  # 3-(6-Quinolyl)-alanine
        if 'c1cnc2c(n1)cccc2' in content:
            return 'QX32', mods  # 3-(2-quinoxalyl)-alanine

    # Multiple nitrogen patterns
    if 'N' in content:
        # NOTE(review): same pattern as the earlier DAB check, which has
        # already returned for any content that would match here.
        if '[NH3]CC[C@@H]' in content:
            return 'DAB', mods  # Diaminobutyric acid
        if '[NH3]C[C@@H]' in content:
            return 'DPP', mods  # 2,3-Diaminopropanoic acid
        if '[NH3]CCCCCC[C@@H]' in content:
            return 'HHK', mods  # (2s)-2,8-diaminooctanoic acid
        if 'CCC[NH]=[C](=[NH2])=[NH2]' in content:
            return 'GBUT', mods  # 2-Amino-4-guanidinobutryric acid
        if '[NH]=[C](=S)=[NH2]' in content:
            return 'THIC', mods  # Thio-citrulline

    # Chain modified amino acids
    if 'CC' in content:
        if 'CCCC[C@@H]' in content:
            return 'AHP', mods  # 2-Aminoheptanoic acid
        if 'CCC([C@@H])(C)C' in content:
            return 'I2M', mods  # 3-methyl-l-alloisoleucine
        if 'CC[C@H]([C@@H])C' in content:
            return 'IIL', mods  # Allo-Isoleucine
        if '[C@H](CCC(C)C)' in content:
            return 'HLEU', mods  # Homoleucine
        if '[C@@H]([C@@H](C)O)C' in content:
            return 'HLU', mods  # beta-hydroxyleucine

    # Modified glutamate/aspartate patterns
    if '[C@@H]' in content:
        if '[C@@H](C[C@@H](F))' in content:
            return 'FGA4', mods  # 4-Fluoro-glutamic acid
        if '[C@@H](C[C@@H](O))' in content:
            return '3GL', mods  # 4-hydroxy-glutamic-acid
        if '[C@@H](C[C@H](C))' in content:
            return 'LME', mods  # (3r)-3-methyl-l-glutamic acid
        if '[C@@H](CC[C@H](C))' in content:
            return 'MEG', mods  # (3s)-3-methyl-l-glutamic acid

    # Sulfur and selenium modifications
    if 'S' in content:
        if 'SCC[C@@H]' in content:
            return 'HSER', mods  # homoserine
        if 'SCCN' in content:
            return 'SLZ', mods  # thialysine
        if 'SC(=O)' in content:
            return 'CSA', mods  # s-acetonylcysteine
        if '[S@@](=O)' in content:
            return 'SME', mods  # Methionine sulfoxide
        if 'S(=O)(=O)' in content:
            return 'OMT', mods  # Methionine sulfone

    # Double bond containing
    if 'C=' in content:
        if 'C=C[C@@H]' in content:
            return '2AG', mods  # 2-Allyl-glycine
        # NOTE(review): identical condition to the 2AG check just above,
        # so 'LVG' can never be returned - the pattern likely needs fixing.
        if 'C=C[C@@H]' in content:
            return 'LVG', mods  # vinylglycine
        if 'C=Cc1ccccc1' in content:
            return 'STYA', mods  # Styrylalanine

    # Special cases
    if '[C@@H]1Cc2c(C1)cccc2' in content:
        return 'IGL', mods  # alpha-amino-2-indanacetic acid
    if '[C](=[C](=O)=O)=O' in content:
        return '26P', mods  # 2-amino-6-oxopimelic acid
    if '[C](=[C](=O)=O)=C' in content:
        return '2NP', mods  # l-2-amino-6-methylene-pimelic acid
    if 'c2cnc[nH]2' in content:
        return 'HIS', mods  # histidine core
    if 'c1cccc2c1cc(O)cc2' in content:
        return 'NAO1', mods  # 5-hydroxy-1-naphthalene
    if 'c1ccc2c(c1)cc(O)cc2' in content:
        return 'NAO2', mods  # 6-hydroxy-2-naphthalene

    # Proline (P) - flexible ring numbers
    # NOTE(review): in the first two groups the inner generator rebinds
    # `n`, so the chiral-centre test runs over every digit independently
    # of the ring number matched by the outer loop.
    if any([
        # Check for any ring number in bond patterns
        (segment.get('bond_after', '').startswith(f'N{n}C(=O)') and 'CCC' in content and
         any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
        for n in '123456789'
    ]) or any([
        # Check ending patterns with any ring number
        (f'CCCN{n}' in content and content.endswith('=O') and
         any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
        for n in '123456789'
    ]) or any([
        # Handle CCC[C@H]n patterns
        (content == f'CCC[C@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
        (content == f'CCC[C@@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
        # N-terminal Pro with any ring number
        (f'N{n}CCC[C@H]{n}' in content) or
        (f'N{n}CCC[C@@H]{n}' in content)
        for n in '123456789'
    ]):
        return 'Pro', mods

    # Tryptophan (W) - more specific indole pattern
    if re.search(r'c[0-9]c\[nH\]c[0-9]ccccc[0-9][0-9]', content) and \
            'c[nH]c' in content.replace(' ', ''):
        return 'Trp', mods

    # Lysine (K) - both patterns
    if '[C@@H](CCCCN)' in content or '[C@H](CCCCN)' in content:
        return 'Lys', mods

    # Arginine (R) - both patterns
    if '[C@@H](CCCNC(=N)N)' in content or '[C@H](CCCNC(=N)N)' in content:
        return 'Arg', mods

    if ('C[C@H](CCCC)' in content or 'C[C@@H](CCCC)' in content) and 'CC(C)' not in content:
        return 'Nle', mods

    # Ornithine (Orn) - 3-carbon chain with NH2
    if ('C[C@H](CCCN)' in content or 'C[C@@H](CCCN)' in content) and 'CC(C)' not in content:
        return 'Orn', mods

    # 2-Naphthylalanine (2Nal) - distinct from Phe pattern
    if ('Cc3cc2ccccc2c3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return '2Nal', mods

    # Cyclohexylalanine (Cha) - already in your code but moved here for clarity
    if 'N2CCCCC2' in content or 'CCCCC2' in content:
        return 'Cha', mods

    # Aminobutyric acid (Abu) - 2-carbon chain
    if ('C[C@H](CC)' in content or 'C[C@@H](CC)' in content) and not any(p in content for p in ['CC(C)', 'CCCC', 'CCC(C)']):
        return 'Abu', mods

    # Pipecolic acid (Pip) - 6-membered ring like Pro
    if ('N3CCCCC3' in content or 'CCCCC3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Pip', mods

    # Cyclohexylglycine (Chg) - direct cyclohexyl without CH2
    if ('C[C@H](C1CCCCC1)' in content or 'C[C@@H](C1CCCCC1)' in content):
        return 'Chg', mods

    # 4-Fluorophenylalanine (4F-Phe)
    if ('Cc2ccc(F)cc2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return '4F-Phe', mods

    # Regular residue identification
    if ('NCC(=O)' in content) or (content == 'C'):
        # Middle case - between bonds
        if segment.get('bond_before') and segment.get('bond_after'):
            if ('C(=O)N' in segment['bond_before'] or 'C(=O)N(C)' in segment['bond_before']):
                return 'Gly', mods
        # Terminal case - at the end
        elif segment.get('bond_before') and segment.get('bond_before').startswith('C(=O)N'):
            return 'Gly', mods

    if 'CC(C)C[C@H]' in content or 'CC(C)C[C@@H]' in content:
        return 'Leu', mods
    if '[C@@H](CC(C)C)' in content or '[C@H](CC(C)C)' in content:
        return 'Leu', mods

    if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content:
        return 'Thr', mods

    if '[C@H](Cc2ccccc2)' in content or '[C@@H](Cc2ccccc2)' in content:
        return 'Phe', mods

    if ('[C@H](C(C)C)' in content or  # With outer parentheses
            '[C@@H](C(C)C)' in content or  # With outer parentheses
            '[C@H]C(C)C' in content or  # Without outer parentheses
            '[C@@H]C(C)C' in content):  # Without outer parentheses
        if not any(p in content for p in ['CC(C)C[C@H]', 'CC(C)C[C@@H]']):  # Still check not Leu
            return 'Val', mods

    if '[C@H](COC(C)(C)C)' in content or '[C@@H](COC(C)(C)C)' in content:
        return 'O-tBu', mods

    if any([
        'CC[C@H](C)' in content,
        'CC[C@@H](C)' in content,
        'C(C)C[C@H]' in content and 'CC(C)C' not in content,
        'C(C)C[C@@H]' in content and 'CC(C)C' not in content
    ]):
        return 'Ile', mods

    if ('[C@H](C)' in content or '[C@@H](C)' in content):
        if not any(p in content for p in ['C(C)C', 'COC', 'CN(', 'C(C)O', 'CC[C@H]', 'CC[C@@H]']):
            return 'Ala', mods

    # Tyrosine (Tyr) - 4-hydroxybenzyl side chain
    if re.search(r'Cc[0-9]ccc\(O\)cc[0-9]', content):
        return 'Tyr', mods


    # Serine (Ser) - Hydroxymethyl side chain
    if '[C@H](CO)' in content or '[C@@H](CO)' in content:
        if not ('C(C)O' in content or 'COC' in content):
            return 'Ser', mods

    # Threonine (Thr) - 1-hydroxyethyl side chain
    if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content or '[C@@H](C)O' in content or '[C@H](C)O' in content:
        return 'Thr', mods

    # Cysteine (Cys) - Thiol side chain
    if '[C@H](CS)' in content or '[C@@H](CS)' in content:
        return 'Cys', mods

    # Methionine (Met) - Methylthioethyl side chain
    if ('C[C@H](CCSC)' in content or 'C[C@@H](CCSC)' in content):
        return 'Met', mods

    # Asparagine (Asn) - Carbamoylmethyl side chain
    if ('CC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Asn', mods

    # Glutamine (Gln) - Carbamoylethyl side chain
    if ('CCC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Gln', mods

    # Aspartic acid (Asp) - Carboxymethyl side chain
    if ('CC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Asp', mods

    # Glutamic acid (Glu) - Carboxyethyl side chain
    if ('CCC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Glu', mods

    # Arginine (Arg) - 3-guanidinopropyl side chain
    if ('CCCNC(=N)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Arg', mods

    # Histidine (His) - Imidazole side chain
    if ('Cc2cnc[nH]2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'His', mods

    # Nothing matched: the caller decides how to report the segment.
    return None, mods
635
+
636
def get_modifications(self, segment):
    """Derive modification tags from the bond pattern that follows a segment.

    Args:
        segment: dict that may carry a ``'bond_after'`` SMILES bond pattern.

    Returns:
        A list containing ``'N-Me'`` when the following bond contains
        ``N(C)`` and ``'O-linked'`` when it contains ``OC(=O)``, in that
        order; an empty list when there is no following bond.
    """
    bond = segment.get('bond_after')
    if not bond:
        return []

    is_n_methyl = 'N(C)' in bond or bond.startswith('C(=O)N(C)')
    is_ester = 'OC(=O)' in bond
    flags = (('N-Me', is_n_methyl), ('O-linked', is_ester))
    return [tag for tag, present in flags if present]
645
+
646
def analyze_structure(self, smiles):
    """Segment a peptide SMILES, name each residue and print a trace.

    Args:
        smiles: SMILES string handed to ``split_on_bonds`` and ``is_cyclic``.

    Returns:
        A ``(three_letter, segment_count)`` tuple: the dash-joined residue
        labels (wrapped in ``cyclo(...)`` when ``is_cyclic`` reports a
        cycle) and the number of segments found, identified or not.
    """
    print("\nAnalyzing structure:", smiles)

    pieces = self.split_on_bonds(smiles)

    print("\nSegment Analysis:")
    labels = []
    for idx, piece in enumerate(pieces):
        print(f"\nSegment {idx}:")
        print(f"Content: {piece['content']}")
        print(f"Bond before: {piece.get('bond_before', 'None')}")
        print(f"Bond after: {piece.get('bond_after', 'None')}")

        name, mods = self.identify_residue(piece)
        if not name:
            print(f"Warning: Could not identify residue in segment: {piece['content']}")
            continue

        # Modifications, if any, are appended in parentheses.
        labels.append(f"{name}({','.join(mods)})" if mods else name)
        print(f"Identified as: {name}")
        print(f"Modifications: {mods}")

    # Only the cyclic flag is used; the cycle lists are ignored here.
    cyclic, _peptide_cycles, _aromatic_cycles = self.is_cyclic(smiles)
    three_letter = '-'.join(labels)
    # Strip any "(mods)" suffix before the one-letter lookup; unknown -> 'X'.
    one_letter = ''.join(
        self.three_to_one.get(label.split('(')[0], 'X') for label in labels
    )

    if cyclic:
        three_letter = f"cyclo({three_letter})"
        one_letter = f"cyclo({one_letter})"

    print(f"\nFinal sequence: {three_letter}")
    print(f"One-letter code: {one_letter}")
    print(f"Is cyclic: {cyclic}")

    return three_letter, len(pieces)
693
+
694
def return_sequence(self, smiles):
    """Return the list of residue labels found in a peptide SMILES.

    Prints the same per-segment trace as ``analyze_structure`` but skips
    the cyclic check and the final summary.

    Args:
        smiles: SMILES string handed to ``split_on_bonds``.

    Returns:
        A list of residue labels in segment order, each followed by its
        comma-joined modifications in parentheses when it has any.
        Segments that could not be identified are omitted.
    """
    print("\nAnalyzing structure:", smiles)

    pieces = self.split_on_bonds(smiles)

    print("\nSegment Analysis:")
    labels = []
    for idx, piece in enumerate(pieces):
        print(f"\nSegment {idx}:")
        print(f"Content: {piece['content']}")
        print(f"Bond before: {piece.get('bond_before', 'None')}")
        print(f"Bond after: {piece.get('bond_after', 'None')}")

        name, mods = self.identify_residue(piece)
        if not name:
            print(f"Warning: Could not identify residue in segment: {piece['content']}")
            continue

        labels.append(f"{name}({','.join(mods)})" if mods else name)
        print(f"Identified as: {name}")
        print(f"Modifications: {mods}")

    return labels
721
+
722
+ """
723
+ def annotate_cyclic_structure(mol, sequence):
724
+ '''Create annotated 2D structure with clear, non-overlapping residue labels'''
725
+ # Generate 2D coordinates
726
+ # Generate 2D coordinates
727
+ AllChem.Compute2DCoords(mol)
728
+
729
+ # Create drawer with larger size for annotations
730
+ drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000) # Even larger size
731
+
732
+ # Get residue list and reverse it to match structural representation
733
+ if sequence.startswith('cyclo('):
734
+ residues = sequence[6:-1].split('-')
735
+ else:
736
+ residues = sequence.split('-')
737
+ residues = list(reversed(residues)) # Reverse the sequence
738
+
739
+ # Draw molecule first to get its bounds
740
+ drawer.drawOptions().addAtomIndices = False
741
+ drawer.DrawMolecule(mol)
742
+ drawer.FinishDrawing()
743
+
744
+ # Convert to PIL Image
745
+ img = Image.open(BytesIO(drawer.GetDrawingText()))
746
+ draw = ImageDraw.Draw(img)
747
+
748
+ try:
749
+ # Try to use DejaVuSans as it's commonly available on Linux systems
750
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
751
+ small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
752
+ except OSError:
753
+ try:
754
+ # Fallback to Arial if available (common on Windows)
755
+ font = ImageFont.truetype("arial.ttf", 60)
756
+ small_font = ImageFont.truetype("arial.ttf", 60)
757
+ except OSError:
758
+ # If no TrueType fonts are available, fall back to default
759
+ print("Warning: TrueType fonts not available, using default font")
760
+ font = ImageFont.load_default()
761
+ small_font = ImageFont.load_default()
762
+ # Get molecule bounds
763
+ conf = mol.GetConformer()
764
+ positions = []
765
+ for i in range(mol.GetNumAtoms()):
766
+ pos = conf.GetAtomPosition(i)
767
+ positions.append((pos.x, pos.y))
768
+
769
+ x_coords = [p[0] for p in positions]
770
+ y_coords = [p[1] for p in positions]
771
+ min_x, max_x = min(x_coords), max(x_coords)
772
+ min_y, max_y = min(y_coords), max(y_coords)
773
+
774
+ # Calculate scaling factors
775
+ scale = 150 # Increased scale factor
776
+ center_x = 1000 # Image center
777
+ center_y = 1000
778
+
779
+ # Add residue labels in a circular arrangement around the structure
780
+ n_residues = len(residues)
781
+ radius = 700 # Distance of labels from center
782
+
783
+ # Start from the rightmost point (3 o'clock position) and go counterclockwise
784
+ # Offset by -3 positions to align with structure
785
+ offset = 0 # Adjust this value to match the structure alignment
786
+ for i, residue in enumerate(residues):
787
+ # Calculate position in a circle around the structure
788
+ # Start from 0 (3 o'clock) and go counterclockwise
789
+ angle = -(2 * np.pi * ((i + offset) % n_residues) / n_residues)
790
+
791
+ # Calculate label position
792
+ label_x = center_x + radius * np.cos(angle)
793
+ label_y = center_y + radius * np.sin(angle)
794
+
795
+ # Draw residue label
796
+ text = f"{i+1}. {residue}"
797
+ bbox = draw.textbbox((label_x, label_y), text, font=font)
798
+ padding = 10
799
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
800
+ bbox[2]+padding, bbox[3]+padding],
801
+ fill='white', outline='white')
802
+ draw.text((label_x, label_y), text,
803
+ font=font, fill='black', anchor="mm")
804
+
805
+ # Add sequence at the top with white background
806
+ seq_text = f"Sequence: {sequence}"
807
+ bbox = draw.textbbox((center_x, 100), seq_text, font=small_font)
808
+ padding = 10
809
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
810
+ bbox[2]+padding, bbox[3]+padding],
811
+ fill='white', outline='white')
812
+ draw.text((center_x, 100), seq_text,
813
+ font=small_font, fill='black', anchor="mm")
814
+
815
+ return img
816
+
817
+ """
818
def annotate_cyclic_structure(mol, sequence):
    """Create a structure image with just the sequence header on top.

    Args:
        mol: RDKit molecule to draw; 2D coordinates are computed in place.
        sequence: text shown after ``"Sequence: "`` in the header.

    Returns:
        A PIL image of the 2000x2000 drawing with the header text centred
        at (1000, 100) on a white background box.
    """
    # Generate 2D coordinates
    AllChem.Compute2DCoords(mol)

    # Create drawer with larger size for annotations
    drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000)

    # Draw molecule first
    drawer.drawOptions().addAtomIndices = False
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()

    # Convert to PIL Image
    img = Image.open(BytesIO(drawer.GetDrawingText()))
    draw = ImageDraw.Draw(img)

    # Font fallback chain: DejaVu, then Arial, then Pillow's default.
    try:
        small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
    except OSError:
        try:
            small_font = ImageFont.truetype("arial.ttf", 60)
        except OSError:
            print("Warning: TrueType fonts not available, using default font")
            small_font = ImageFont.load_default()

    # Add just the sequence header at the top
    seq_text = f"Sequence: {sequence}"
    header_xy = (1000, 100)
    header_anchor = "mm"
    # Measure with the same anchor the text is drawn with; otherwise the
    # box is computed for the default anchor and the white background ends
    # up offset from the centred text.
    bbox = draw.textbbox(header_xy, seq_text, font=small_font, anchor=header_anchor)
    padding = 10
    draw.rectangle([bbox[0]-padding, bbox[1]-padding,
                    bbox[2]+padding, bbox[3]+padding],
                   fill='white', outline='white')
    draw.text(header_xy, seq_text,
              font=small_font, fill='black', anchor=header_anchor)

    return img
854
+
855
def create_enhanced_linear_viz(sequence, smiles):
    """Create an enhanced linear representation using PeptideAnalyzer.

    The figure has two panels: a row of residue boxes joined by labelled
    bonds, and a per-segment text breakdown beneath it.

    Bond type (ester / N-methylated peptide) is taken from the analyzer's
    modification tags. ``segment['bond_after']`` holds the raw bond SMILES
    pattern, so it never contains the tag names ``'O-linked'`` / ``'N-Me'``
    themselves.

    Args:
        sequence: dash-separated residue labels, optionally wrapped in
            ``cyclo(...)``.
        smiles: SMILES string that is split into segments by the analyzer.

    Returns:
        The matplotlib figure.
    """
    analyzer = PeptideAnalyzer()  # Create analyzer instance

    # Create figure with two subplots
    fig = plt.figure(figsize=(15, 10))
    gs = fig.add_gridspec(2, 1, height_ratios=[1, 2])
    ax_struct = fig.add_subplot(gs[0])
    ax_detail = fig.add_subplot(gs[1])

    # Parse sequence and get residues
    if sequence.startswith('cyclo('):
        residues = sequence[6:-1].split('-')
    else:
        residues = sequence.split('-')

    # Get segments using analyzer
    segments = analyzer.split_on_bonds(smiles)

    # Debug print
    print(f"Number of residues: {len(residues)}")
    print(f"Number of segments: {len(segments)}")

    # Top subplot - Basic structure
    ax_struct.set_xlim(0, 10)
    ax_struct.set_ylim(0, 2)

    num_residues = len(residues)
    spacing = 9.0 / (num_residues - 1) if num_residues > 1 else 9.0

    # Draw basic structure
    y_pos = 1.5
    for i, residue_label in enumerate(residues):
        x_pos = 0.5 + i * spacing

        # Draw amino acid box
        rect = patches.Rectangle((x_pos-0.3, y_pos-0.2), 0.6, 0.4,
                                 facecolor='lightblue', edgecolor='black')
        ax_struct.add_patch(rect)

        # Draw connecting bonds if not the last residue
        if i < num_residues - 1:
            segment = segments[i] if i < len(segments) else None
            if segment:
                # Determine bond type from the modification tags of the
                # bond that follows this segment.
                bond_mods = analyzer.get_modifications(segment)
                bond_type = 'ester' if 'O-linked' in bond_mods else 'peptide'
                is_n_methylated = 'N-Me' in bond_mods

                bond_color = 'red' if bond_type == 'ester' else 'black'
                linestyle = '--' if bond_type == 'ester' else '-'

                # Draw bond line
                ax_struct.plot([x_pos+0.3, x_pos+spacing-0.3], [y_pos, y_pos],
                               color=bond_color, linestyle=linestyle, linewidth=2)

                # Add bond type label
                mid_x = x_pos + spacing/2
                bond_label = f"{bond_type}"
                if is_n_methylated:
                    bond_label += "\n(N-Me)"
                ax_struct.text(mid_x, y_pos+0.1, bond_label,
                               ha='center', va='bottom', fontsize=10,
                               color=bond_color)

        # Add residue label
        ax_struct.text(x_pos, y_pos-0.5, residue_label,
                       ha='center', va='top', fontsize=14)

    # Bottom subplot - Detailed breakdown
    ax_detail.set_ylim(0, len(segments)+1)
    ax_detail.set_xlim(0, 1)

    # Create detailed breakdown
    segment_y = len(segments)  # Start from top
    for i, segment in enumerate(segments):
        y = segment_y - i

        # Check if this is a bond or residue; mods is returned either way.
        residue, mods = analyzer.identify_residue(segment)
        if residue:
            text = f"Residue {i+1}: {residue}"
            if mods:
                text += f" ({', '.join(mods)})"
            color = 'blue'
        else:
            # Must be a bond
            text = f"Bond {i}: "
            if 'O-linked' in mods:
                text += "ester"
            elif 'N-Me' in mods:
                text += "peptide (N-methylated)"
            else:
                text += "peptide"
            color = 'red'

        # Add segment analysis
        ax_detail.text(0.05, y, text, fontsize=12, color=color)
        ax_detail.text(0.5, y, f"SMILES: {segment.get('content', '')}", fontsize=10, color='gray')

    # If cyclic, add connection indicator
    if sequence.startswith('cyclo('):
        ax_struct.annotate('', xy=(9.5, y_pos), xytext=(0.5, y_pos),
                           arrowprops=dict(arrowstyle='<->', color='red', lw=2))
        ax_struct.text(5, y_pos+0.3, 'Cyclic Connection',
                       ha='center', color='red', fontsize=14)

    # Add titles and adjust layout
    ax_struct.set_title("Peptide Structure Overview", pad=20)
    ax_detail.set_title("Segment Analysis Breakdown", pad=20)

    # Remove axes
    for ax in [ax_struct, ax_detail]:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')

    plt.tight_layout()
    return fig
973
+
974
+ class PeptideStructureGenerator:
975
+ """A class to generate 3D structures of peptides using different embedding methods"""
976
+
977
+ @staticmethod
978
+ def prepare_molecule(smiles):
979
+ """Prepare molecule with proper hydrogen handling"""
980
+ mol = Chem.MolFromSmiles(smiles, sanitize=False)
981
+ if mol is None:
982
+ raise ValueError("Failed to create molecule from SMILES")
983
+
984
+ # Calculate valence for each atom
985
+ for atom in mol.GetAtoms():
986
+ atom.UpdatePropertyCache(strict=False)
987
+
988
+ # Sanitize with reduced requirements
989
+ Chem.SanitizeMol(mol,
990
+ sanitizeOps=Chem.SANITIZE_FINDRADICALS|
991
+ Chem.SANITIZE_KEKULIZE|
992
+ Chem.SANITIZE_SETAROMATICITY|
993
+ Chem.SANITIZE_SETCONJUGATION|
994
+ Chem.SANITIZE_SETHYBRIDIZATION|
995
+ Chem.SANITIZE_CLEANUPCHIRALITY)
996
+
997
+ mol = Chem.AddHs(mol)
998
+ return mol
999
+
1000
+ @staticmethod
1001
+ def get_etkdg_params(attempt=0):
1002
+ """Get ETKDG parameters with optional modifications based on attempt number"""
1003
+ params = AllChem.ETKDGv3()
1004
+ params.randomSeed = -1
1005
+ params.maxIterations = 200
1006
+ params.numThreads = 4 # Reduced for web interface
1007
+ params.useBasicKnowledge = True
1008
+ params.enforceChirality = True
1009
+ params.useExpTorsionAnglePrefs = True
1010
+ params.useSmallRingTorsions = True
1011
+ params.useMacrocycleTorsions = True
1012
+ params.ETversion = 2
1013
+ params.pruneRmsThresh = -1
1014
+ params.embedRmsThresh = 0.5
1015
+
1016
+ if attempt > 10:
1017
+ params.bondLength = 1.5 + (attempt - 10) * 0.02
1018
+ params.useExpTorsionAnglePrefs = False
1019
+
1020
+ return params
1021
+
1022
+ def generate_structure_etkdg(self, smiles, max_attempts=20):
1023
+ """Generate 3D structure using ETKDG without UFF optimization"""
1024
+ success = False
1025
+ mol = None
1026
+
1027
+ for attempt in range(max_attempts):
1028
+ try:
1029
+ mol = self.prepare_molecule(smiles)
1030
+ params = self.get_etkdg_params(attempt)
1031
+
1032
+ if AllChem.EmbedMolecule(mol, params) == 0:
1033
+ success = True
1034
+ break
1035
+ except Exception as e:
1036
+ continue
1037
+
1038
+ if not success:
1039
+ raise ValueError("Failed to generate structure with ETKDG")
1040
+
1041
+ return mol
1042
+
1043
def generate_structure_uff(self, smiles, max_attempts=20):
    """Generate a 3D structure using ETKDG followed by UFF optimization.

    Runs up to ``max_attempts`` embed-and-optimize rounds and keeps the
    candidate with the lowest UFF energy. A round is skipped when embedding
    does not return 0, when the optimizer does not return 0, when no force
    field is available, or when anything in the round raises.

    Args:
        smiles: SMILES string handed to ``self.prepare_molecule``.
        max_attempts: Number of rounds to run.

    Returns:
        A copy of the lowest-energy optimized molecule.

    Raises:
        ValueError: If no round produced an optimized structure.
    """
    winner = None
    winner_energy = float('inf')

    for attempt in range(max_attempts):
        try:
            candidate = self.prepare_molecule(smiles)
            settings = self.get_etkdg_params(attempt)

            if AllChem.EmbedMolecule(candidate, settings) != 0:
                continue

            status = AllChem.UFFOptimizeMolecule(
                candidate,
                maxIters=2000,
                vdwThresh=10.0,
                confId=0,
                ignoreInterfragInteractions=True,
            )
            if status != 0:
                continue

            force_field = AllChem.UFFGetMoleculeForceField(candidate)
            if not force_field:
                continue

            energy = force_field.CalcEnergy()
            if energy < winner_energy:
                winner_energy = energy
                winner = Chem.Mol(candidate)
        except Exception:
            continue

    if winner is None:
        raise ValueError("Failed to generate optimized structure")

    return winner
1072
+
1073
@staticmethod
def mol_to_sdf_bytes(mol):
    """Serialize an RDKit molecule to SDF and return it as UTF-8 bytes.

    The SD writer emits text, so the record is written into an in-memory
    text buffer first and encoded afterwards.
    """
    text_buffer = StringIO()
    sd_writer = Chem.SDWriter(text_buffer)
    sd_writer.write(mol)
    sd_writer.close()

    sdf_text = text_buffer.getvalue()
    return sdf_text.encode('utf-8')
1084
+
1085
def process_input(smiles_input=None, file_obj=None, show_linear=False,
                  show_segment_details=False, generate_3d=False, use_uff=False):
    """Process input and create visualizations using PeptideAnalyzer.

    Exactly one input mode is handled per call: a direct SMILES string takes
    precedence over a file object.

    Args:
        smiles_input: A single peptide SMILES string.
        file_obj: An object with a ``name`` attribute pointing at a text file,
            raw ``bytes``, or anything ``str()``-able; one SMILES per line.
        show_linear: Also render the linear representation (SMILES mode only).
        show_segment_details: Include the per-segment analysis in the text.
        generate_3d: Write 3D structure SDF files (SMILES mode only).
        use_uff: Additionally write a UFF-optimized structure when
            ``generate_3d`` is set.

    Returns:
        Always a 4-tuple ``(text, cyclic_image, linear_image, structure_files)``.
        Error paths and the file mode fill the unused slots with ``None``.
        Previously some paths returned 3 items and others 4, so no caller
        could unpack every outcome; all paths now match the success shape.
    """
    analyzer = PeptideAnalyzer()
    structure_files = []

    # Handle direct SMILES input
    if smiles_input:
        smiles = smiles_input.strip()

        # First check if it's a peptide using analyzer's method
        if not analyzer.is_peptide(smiles):
            return "Error: Input SMILES does not appear to be a peptide structure.", None, None, None

        try:
            # Create molecule
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return "Error: Invalid SMILES notation.", None, None, None

            # Generate 3D structures if requested
            if generate_3d:
                generator = PeptideStructureGenerator()

                try:
                    # Create the scratch directory only when files will
                    # actually be written into it.
                    temp_dir = tempfile.mkdtemp()

                    # Generate ETKDG structure
                    mol_etkdg = generator.generate_structure_etkdg(smiles)
                    etkdg_path = os.path.join(temp_dir, "structure_etkdg.sdf")
                    writer = Chem.SDWriter(etkdg_path)
                    writer.write(mol_etkdg)
                    writer.close()
                    structure_files.append(etkdg_path)

                    # Generate UFF structure if requested
                    if use_uff:
                        mol_uff = generator.generate_structure_uff(smiles)
                        uff_path = os.path.join(temp_dir, "structure_uff.sdf")
                        writer = Chem.SDWriter(uff_path)
                        writer.write(mol_uff)
                        writer.close()
                        structure_files.append(uff_path)

                except Exception as e:
                    return f"Error generating 3D structures: {str(e)}", None, None, None

            # Use analyzer to get sequence
            segments = analyzer.split_on_bonds(smiles)

            # Process segments and build sequence
            sequence_parts = []
            output_text = ""

            # Only include segment analysis in output if requested
            if show_segment_details:
                output_text += "Segment Analysis:\n"
                for i, segment in enumerate(segments):
                    output_text += f"\nSegment {i}:\n"
                    output_text += f"Content: {segment['content']}\n"
                    output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
                    output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"

                    residue, mods = analyzer.identify_residue(segment)
                    if residue:
                        if mods:
                            sequence_parts.append(f"{residue}({','.join(mods)})")
                        else:
                            sequence_parts.append(residue)
                        output_text += f"Identified as: {residue}\n"
                        output_text += f"Modifications: {mods}\n"
                    else:
                        output_text += f"Warning: Could not identify residue in segment: {segment['content']}\n"
                output_text += "\n"
            else:
                # Just build sequence without detailed analysis in output
                for segment in segments:
                    residue, mods = analyzer.identify_residue(segment)
                    if residue:
                        if mods:
                            sequence_parts.append(f"{residue}({','.join(mods)})")
                        else:
                            sequence_parts.append(residue)

            # Check if cyclic using analyzer's method
            is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
            three_letter = '-'.join(sequence_parts)
            # Strip any "(mods)" suffix before the lookup; unknown residues map to 'X'.
            one_letter = ''.join(analyzer.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence_parts)

            if is_cyclic:
                three_letter = f"cyclo({three_letter})"
                one_letter = f"cyclo({one_letter})"

            # Create cyclic structure visualization
            img_cyclic = annotate_cyclic_structure(mol, three_letter)

            # Create linear representation if requested
            img_linear = None
            if show_linear:
                fig_linear = create_enhanced_linear_viz(three_letter, smiles)
                buf = BytesIO()
                fig_linear.savefig(buf, format='png', bbox_inches='tight', dpi=300)
                buf.seek(0)
                img_linear = Image.open(buf)
                plt.close(fig_linear)

            # Add summary to output
            summary = "Summary:\n"
            summary += f"Sequence: {three_letter}\n"
            summary += f"One-letter code: {one_letter}\n"
            summary += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"

            if structure_files:
                summary += "\n3D Structures Generated:\n"
                for filepath in structure_files:
                    summary += f"- {os.path.basename(filepath)}\n"

            return summary + output_text, img_cyclic, img_linear, (structure_files if structure_files else None)

        except Exception as e:
            return f"Error processing SMILES: {str(e)}", None, None, None

    # Handle file input
    if file_obj is not None:
        try:
            # Handle file content
            if hasattr(file_obj, 'name'):
                with open(file_obj.name, 'r') as f:
                    content = f.read()
            else:
                content = file_obj.decode('utf-8') if isinstance(file_obj, bytes) else str(file_obj)

            output_text = ""
            for line in content.splitlines():
                smiles = line.strip()
                if smiles:
                    # Check if it's a peptide
                    if not analyzer.is_peptide(smiles):
                        output_text += f"Skipping non-peptide SMILES: {smiles}\n"
                        continue

                    # Process this SMILES
                    segments = analyzer.split_on_bonds(smiles)
                    sequence_parts = []

                    # Add segment details if requested
                    if show_segment_details:
                        output_text += f"\nSegment Analysis for SMILES: {smiles}\n"
                        for i, segment in enumerate(segments):
                            output_text += f"\nSegment {i}:\n"
                            output_text += f"Content: {segment['content']}\n"
                            output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
                            output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
                            residue, mods = analyzer.identify_residue(segment)
                            if residue:
                                if mods:
                                    sequence_parts.append(f"{residue}({','.join(mods)})")
                                else:
                                    sequence_parts.append(residue)
                                output_text += f"Identified as: {residue}\n"
                                output_text += f"Modifications: {mods}\n"
                    else:
                        for segment in segments:
                            residue, mods = analyzer.identify_residue(segment)
                            if residue:
                                if mods:
                                    sequence_parts.append(f"{residue}({','.join(mods)})")
                                else:
                                    sequence_parts.append(residue)

                    # Get cyclicity and create sequence
                    is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
                    sequence = f"cyclo({'-'.join(sequence_parts)})" if is_cyclic else '-'.join(sequence_parts)

                    output_text += f"\nSummary for SMILES: {smiles}\n"
                    output_text += f"Sequence: {sequence}\n"
                    output_text += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
                    if is_cyclic:
                        output_text += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
                    output_text += "-" * 50 + "\n"

            return output_text, None, None, None

        except Exception as e:
            return f"Error processing file: {str(e)}", None, None, None

    return "No input provided.", None, None, None
a2d2_pep/pep_utils/utils.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Console logger utilities.
2
+
3
+ Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
4
+ Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
5
+ """
6
+
7
+ import logging
8
+ import fsspec
9
+ import lightning
10
+ import torch
11
+ from timm.scheduler import CosineLRScheduler
12
+ import argparse
13
+ import numpy as np
14
+ import random
15
+ import os
16
+
17
def sample_categorical_logits(logits, dtype=torch.float64):
    """Draw one category index per row of ``logits`` by adding noise and taking argmax.

    The logits do not need to be log-softmaxed. Noise is generated in
    ``dtype`` and the winning index along the last dimension is returned.
    """
    uniform = torch.rand_like(logits, dtype=dtype)
    # The 1e-10 offsets keep both logarithms away from log(0).
    neg_log_uniform = 1e-10 - torch.log(uniform + 1e-10)
    noise = -torch.log(neg_log_uniform)
    perturbed = logits + noise
    return perturbed.argmax(dim=-1)
21
+
22
def fsspec_exists(filename):
    """Return whether ``filename`` exists on the filesystem fsspec resolves for it."""
    filesystem = fsspec.core.url_to_fs(filename)[0]
    return filesystem.exists(filename)
26
+
27
+
28
def fsspec_listdir(dirname):
    """List ``dirname`` through the filesystem fsspec resolves for it."""
    filesystem = fsspec.core.url_to_fs(dirname)[0]
    return filesystem.ls(dirname)
32
+
33
+
34
def fsspec_mkdirs(dirname, exist_ok=True):
    """Create ``dirname`` through the filesystem fsspec resolves for it.

    ``exist_ok`` is forwarded unchanged to the filesystem's ``makedirs``.
    """
    filesystem = fsspec.core.url_to_fs(dirname)[0]
    filesystem.makedirs(dirname, exist_ok=exist_ok)
38
+
39
+
40
def print_nans(tensor, name):
    """Print ``name`` followed by ``tensor`` when the tensor holds any NaN."""
    has_nan = torch.isnan(tensor).any()
    if not has_nan:
        return
    print(name, tensor)
43
+
44
+
45
class CosineDecayWarmupLRScheduler(
    CosineLRScheduler,
    torch.optim.lr_scheduler._LRScheduler):
    """timm ``CosineLRScheduler`` that also subclasses the torch ``_LRScheduler``.

    The class keeps its own ``_last_epoch`` counter and routes ``step`` to
    either the per-epoch or the per-update method of the parent, depending
    on ``self.t_in_epochs``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent and apply the schedule for step 0."""
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        # Immediately apply the schedule at position 0.
        self.step(epoch=0)

    def step(self, epoch=None):
        """Advance the schedule.

        Args:
            epoch: Position to jump to. When ``None`` the internal counter is
                incremented by one instead.
        """
        if epoch is None:
            self._last_epoch += 1
        else:
            self._last_epoch = epoch
        # We call either step or step_update, depending on
        # whether we're using the scheduler every epoch or every
        # step.
        # Otherwise, lightning will always call step (i.e.,
        # meant for each epoch), and if we set scheduler
        # interval to "step", then the learning rate update will
        # be wrong.
        if self.t_in_epochs:
            super().step(epoch=self._last_epoch)
        else:
            super().step_update(num_updates=self._last_epoch)
70
+
71
+
72
class LoggingContext:
    """Context manager for selective logging.

    While active, optionally overrides the logger's level and/or attaches an
    extra handler; both changes are undone on exit.
    """

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger, self.level = logger, level
        self.handler, self.close = handler, close

    def __enter__(self):
        # Remember the previous level only when we are about to replace it.
        if self.level is not None:
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if not self.handler:
            return
        self.logger.removeHandler(self.handler)
        if self.close:
            self.handler.close()
94
+
95
+
96
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    Fetches the named logger, sets its level, and wraps each of its logging
    methods with Lightning's ``rank_zero_only`` decorator.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    rank_zero_only = lightning.pytorch.utilities.rank_zero_only
    method_names = ('debug', 'info', 'warning', 'error',
                    'exception', 'fatal', 'critical')
    for method_name in method_names:
        wrapped = rank_zero_only(getattr(logger, method_name))
        setattr(logger, method_name, wrapped)

    return logger
112
+
113
+
114
def str2bool(v):
    """Convert a command-line style value to ``bool``.

    Args:
        v: A ``bool`` (returned unchanged) or a string such as ``"yes"``,
            ``"true"``, ``"t"``, ``"y"``, ``"1"`` / ``"no"``, ``"false"``,
            ``"f"``, ``"n"``, ``"0"``. Matching is case-insensitive and
            ignores surrounding whitespace.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised spelling.
            This now also covers non-string inputs such as ``None``, which
            previously escaped as ``AttributeError``.
    """
    if isinstance(v, bool):
        return v
    if isinstance(v, str):
        text = v.strip().lower()
        if text in ('yes', 'true', 't', 'y', '1'):
            return True
        if text in ('no', 'false', 'f', 'n', '0'):
            return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
123
+
124
+
125
def set_seed(seed, use_cuda):
    """Seed Python hashing, NumPy, ``random`` and torch with ``seed``.

    Args:
        seed: Seed value applied to every generator.
        use_cuda: When truthy, the CUDA generators are seeded as well.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)

    # NOTE: torch.backends.cudnn.deterministic is not modified here.
    if use_cuda:
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)

    print(f'=> Seed of the run set to {seed}')
135
+
a2d2_pep/remasking_scheduleaware.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Schedule-aware remasking and insertion logic that ensures the number of masked tokens
3
+ follows the interpolant schedule.
4
+ """
5
+ import torch
6
+ import numpy as np
7
+
8
def apply_schedule_aware_insertion(
    model,
    xt_tmp,
    new_xt,
    t,
    dt,
    ext,
    mask,
    pad,
    max_length,
    orig_mask,
    new_pos_orig,
    quality_threshold=1,
):
    """
    Remove low-quality insertions based on insertion confidence while respecting
    the interpolant schedule for expected sequence length.

    Args:
        model: Model with planner and interpolant
        xt_tmp: Sequence after insertion [B, L]
        new_xt: Sequence before insertion [B, L] (not read by this function)
        t: Current time [B]
        dt: Time step size
        ext: Number of insertions per gap [B, L+1]
        mask: Mask token ID
        pad: Pad token ID
        max_length: Maximum sequence length (not read by this function)
        orig_mask: Mask of original token positions [B, L]
        new_pos_orig: New positions of original tokens [B, L]
        quality_threshold: Freshly inserted mask tokens whose insertion
            confidence is strictly below this value are removal candidates.
            NOTE(review): only this threshold mode is implemented here; a
            ``None`` value would reach the ``insertion_conf < quality_threshold``
            comparison below - confirm callers never pass ``None``.

    Returns:
        xt_tmp: Modified sequence with low-quality insertions removed (respecting schedule).
        The input is returned unchanged when nothing was inserted, when the
        planner provides no "insertion_conf", or when nothing qualifies for removal.
    """
    device = xt_tmp.device
    batch_size, L = xt_tmp.shape
    total_ext = ext.sum(dim=1)

    # Only proceed if there were insertions
    if total_ext.sum() == 0:
        return xt_tmp

    # Get planner predictions on inserted state. The insertion head is trained
    # with the pre-step time t (see loss_insert_planner_flexible), so condition
    # on t here too; t_next is still used below for the length schedule.
    t_next = t + dt
    planner_out = model.planner(xt_tmp, t)
    insertion_conf = planner_out.get("insertion_conf", None)

    if insertion_conf is None:
        return xt_tmp

    insertion_conf = insertion_conf.squeeze(-1)  # (B, L)

    # Expected sequence length at next timestep according to schedule
    # NOTE(review): whenever expected_progress >= 0.1 the clamp is inactive and
    # expected_length equals current_length_after, which makes max_removable 0
    # below - confirm this is the intended schedule behaviour.
    current_length_after = xt_tmp.ne(pad).sum(dim=1).float()  # [B]
    expected_progress = model.interpolant.insertion_schedule.at(t_next)  # [B]
    estimated_final_length = current_length_after / (expected_progress.clamp(min=0.1))
    expected_length = estimated_final_length * expected_progress  # [B]

    # Mark positions in xt_tmp that came from new_xt (originals) vs. fresh insertions.
    # Fancy-indexing scatter avoids the per-batch python loop.
    valid_b, valid_l = orig_mask.nonzero(as_tuple=True)
    valid_p = new_pos_orig[valid_b, valid_l].long().clamp_(0, L - 1)
    is_original = torch.zeros_like(xt_tmp, dtype=torch.bool)
    is_original[valid_b, valid_p] = True
    inserted_positions = (xt_tmp == mask) & ~is_original

    # Threshold-based deletion: drop insertions whose confidence is below
    # `quality_threshold`, capped so the length never falls below the
    # scheduled minimum.
    candidates = inserted_positions & (insertion_conf < quality_threshold)
    num_bad = candidates.sum(dim=1)  # [B], long
    min_length = expected_length.long().clamp(min=1)  # [B]
    max_removable = (current_length_after.long() - min_length).clamp(min=0)
    length_after_removal = current_length_after.long() - num_bad
    schedule_violates = length_after_removal < min_length
    # Per row: remove every candidate, or only as many as the schedule allows.
    k_per_row = torch.where(schedule_violates, max_removable, num_bad)
    k_per_row = torch.where(num_bad > 0, k_per_row, torch.zeros_like(k_per_row))

    if not candidates.any():
        return xt_tmp

    # Select the lowest-confidence candidates per row via a sort.
    neg_inf = torch.tensor(float('-inf'), device=device, dtype=insertion_conf.dtype)
    scores = torch.where(candidates, -insertion_conf, neg_inf)  # higher = worse
    _, sorted_indices = scores.sort(dim=1, descending=True)
    positions = torch.arange(L, device=device).unsqueeze(0)  # [1, L]
    # Rank r (column r of the sorted order) is selected when r < k for that row.
    keep_in_topk = positions < k_per_row.unsqueeze(1)  # [B, L]
    final_bad = torch.zeros_like(candidates)
    final_bad.scatter_(1, sorted_indices, keep_in_topk)

    if not final_bad.any():
        return xt_tmp

    # Compact each row to the left (keep good, drop bad), then pad the tail.
    # Stable sort by the bad flag pushes bad positions to the right.
    sort_key = final_bad.long()
    _, perm = torch.sort(sort_key, dim=1, stable=True)
    xt_tmp = torch.gather(xt_tmp, 1, perm)
    num_keep = (~final_bad).sum(dim=1)  # [B]
    tail_mask = positions >= num_keep.unsqueeze(1)  # [B, L]
    xt_tmp = torch.where(tail_mask, torch.full_like(xt_tmp, pad), xt_tmp)

    return xt_tmp
114
+
115
+
116
def apply_schedule_aware_remasking(
    model,
    new_xt,
    t,
    dt,
    remasking_conf,
    clean_index,
    mask,
    neg_inf,
    batch_size,
    unmask_quality_threshold=None,
):
    """
    Remask clean tokens so the number of masks follows the unmask schedule.

    Args:
        model: Model with interpolant that has an unmask_schedule
        new_xt: Current sequence [B, L]
        t: Current time [B]
        dt: Time step size
        remasking_conf: Confidence scores for tokens [B, L]
        clean_index: Boolean mask of clean tokens (not mask, not pad) [B, L]
        mask: Mask token ID
        neg_inf: Negative infinity tensor
        batch_size: Batch size (not read by this function)
        unmask_quality_threshold: If None (default), remask exactly the schedule
            excess (count-based), choosing the lowest-confidence clean tokens.
            If a float, the schedule is ignored and EVERY clean token whose
            confidence is below the threshold is remasked. Higher threshold
            => more aggressive remasking.

    Returns:
        new_xt: Sequence with the selected clean tokens replaced by ``mask``
    """
    # Threshold mode short-circuits the schedule entirely.
    if unmask_quality_threshold is not None:
        below_threshold = remasking_conf < unmask_quality_threshold
        chosen = clean_index & below_threshold
        return torch.where(chosen, torch.full_like(new_xt, mask), new_xt)

    # Count-based mode: how many clean tokens exceed what the schedule expects
    # at the next time step?
    clean_count = clean_index.sum(dim=1)  # [B], long
    masked_count = (new_xt == mask).sum(dim=1)
    seq_len = (clean_count + masked_count).float()  # [B]
    target_clean_frac = model.interpolant.unmask_schedule.at(t + dt)  # [B]
    target_clean = target_clean_frac * seq_len  # [B]
    surplus = (clean_count.float() - target_clean).round().long()  # [B]

    # Never remask a negative amount, nor more tokens than are clean.
    budget = torch.minimum(surplus.clamp(min=0), clean_count)  # [B]
    if budget.sum() == 0:
        return new_xt

    # Rank clean tokens from least to most confident; non-clean positions get
    # neg_inf so they sort last.
    ranking = torch.where(clean_index, -1.0 * remasking_conf, neg_inf)
    order = ranking.sort(dim=1, descending=True)[1]

    width = ranking.shape[1]
    rank_ids = torch.arange(width, device=new_xt.device).unsqueeze(0)  # [1, L]
    within_budget = rank_ids < budget.unsqueeze(1)  # [B, L]

    chosen = torch.zeros_like(clean_index)
    chosen.scatter_(1, order, within_budget)
    return torch.where(chosen, torch.full_like(new_xt, mask), new_xt)
a2d2_pep/sampling.py ADDED
@@ -0,0 +1,1401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # add repo root to path
4
+
5
+ import torch
6
+ from dataclasses import dataclass
7
+ from typing import Any, Literal, Optional
8
+ import numpy as np
9
+ import pandas as pd
10
+
11
+ from lightning_modules.mdm import MaskedDiffusionModule
12
+
13
+
14
@dataclass
class SamplingTraceDatapoint:
    """One recorded event of a sampling trajectory."""
    # Time value of the sampling step at which the event was recorded.
    t: float
    # Kind of event: a newly inserted token or a changed existing token.
    event_type: Literal["insertion", "change"]
    # Index the event refers to (presumably a sequence position for "change"
    # and a gap index for "insertion" - verify against the sampler).
    position: int
    # Token associated with the event.
    token: Any
20
+
21
+
22
@dataclass
class SamplingResult:
    """Samples plus an optional event trace; iterates as ``(samples, trace)``."""
    samples: torch.Tensor
    # Trace is supposed to be processed sequentially as updates are not commutative
    trace: Optional[list[SamplingTraceDatapoint]]

    def __iter__(self):
        # Allows tuple-style unpacking: ``samples, trace = result``.
        yield self.samples
        yield self.trace
30
+
31
+
32
+ # Sample from categorical distribution for each position using the transition probabilities
33
+ def _sample_tokens(probs: torch.Tensor) -> torch.Tensor:
34
+ """Sample one token per position from probability distribution.
35
+ Args:
36
+ probs: [batch_size, seq_len, vocab_size] transition probabilities
37
+ Returns:
38
+ [batch_size, seq_len] sampled token indices
39
+ """
40
+ batch_size, seq_len, vocab_size = probs.shape
41
+ flat_probs = probs.view(-1, vocab_size)
42
+ samples = torch.multinomial(flat_probs, num_samples=1)
43
+ return samples.view(batch_size, seq_len)
44
+
45
+
46
+ def _sample_batched_tokens(probs: torch.Tensor) -> torch.Tensor:
47
+
48
+ batch_size, seq_len, vocab_size = probs.shape
49
+
50
+ gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, seq_len, vocab_size) + 1e-10) + 1e-10)).to(probs.device)
51
+ noisy_logits = torch.log(probs + 1e-10) + gumbel_noise # add Gumbel noise to log probabilities
52
+
53
+ # select the highest score (most likely category after Gumbel noise)
54
+ samples = noisy_logits.argmax(dim=-1).to(dtype=torch.long)
55
+
56
+ return samples.view(batch_size, seq_len)
57
+
58
@torch.no_grad()
def mdm_euler_sampling(
    model: MaskedDiffusionModule,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
):
    """Sample fixed-length sequences from a masked diffusion model with Euler steps.

    Starts from an all-mask sequence and, at each of ``steps`` steps, samples a
    token for every masked position from the model's unmasking rates. On the
    final step the mask token is excluded from sampling.

    Args:
        model: Callable as ``model(xt, t)`` and exposing
            ``model.interpolant.to_actual_rate``.
        steps: Number of Euler steps; the step size is ``1 / steps``.
        mask: Mask token ID.
        pad: Pad token ID (not read by this function).
        batch_size: Number of sequences to sample.
        max_length: Length of every sampled sequence.
        return_trace: Must be False; tracing is not implemented here.
        temperature: When not 1.0, transition probabilities are re-softmaxed
            from ``log(p + 1e-10) / temperature``.

    Returns:
        ``(xt, [])`` where ``xt`` is the [batch_size, max_length] sample tensor
        and the second element is an empty trace placeholder.
    """
    assert not return_trace, "Trace is not yet implemented in MDM Euler sampling"
    device = next(model.parameters()).device
    xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    for i in range(steps):
        # The original message was the literal text "i-th sampling step";
        # include the actual step index.
        print(f"{i}-th sampling step")
        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" probability on the current token of every position
        _xt = xt.clone()
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        # Apply temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if i == steps - 1:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0

        new_xt = _sample_tokens(trans_prob)
        # Already-unmasked positions keep their token.
        new_xt = torch.where(xt != mask, xt, new_xt)

        xt = new_xt
        t = t + dt

    return xt, []
114
+
115
+
116
@torch.no_grad()
def any_order_mask_insertion_euler_sampling(
    model: torch.nn.Module,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
) -> SamplingResult:
    """Sample variable-length sequences with Euler unmasking and mask insertion.

    Starts from an all-pad sequence. Every step samples tokens for masked
    positions and, except on the final step, inserts new mask tokens into the
    gaps of the sequence. On the final step the mask token is excluded from
    sampling and no insertion takes place.

    Args:
        model: Callable as ``model(xt, t)`` and exposing
            ``model.interpolant.to_actual_rate``.
        steps: Number of Euler steps; the step size is ``1 / steps``.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to sample.
        max_length: Maximum sequence length.
        return_trace: When True, record per-sequence change/insertion events.
        temperature: Not read by this function.

    Returns:
        The tuple ``(xt, sampling_trace)``: the [batch_size, max_length]
        samples and either a list of per-sequence event lists or ``None``.
    """
    device = next(model.parameters()).device

    # 1) Initialize all-pad sequence
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" probability; pad positions are routed to the mask index
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        if i == steps - 1:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

            # renormalize probabilities of masked positions to sum to 1
            masked_probs = trans_prob[mask_pos]  # (N, V) copy
            prob_sum = masked_probs.sum(dim=-1, keepdim=True)
            zero_rows = prob_sum.squeeze(-1) == 0.0
            if zero_rows.any():
                # Rows with no probability mass fall back to a uniform
                # distribution over token ids 0 .. mask-1. The vector must be
                # vocab-sized (V,); the original built an (L, V) tensor and
                # sliced its first axis, which either failed to broadcast or
                # produced all-zero rows.
                uniform_prob = torch.zeros_like(masked_probs[0])
                uniform_prob[:mask] = 1.0 / mask
                masked_probs[zero_rows] = uniform_prob
                prob_sum = masked_probs.sum(dim=-1, keepdim=True)
            trans_prob[mask_pos] = masked_probs / prob_sum

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        if i != steps - 1:
            # ——— gap-wise insertion: compute new length, fill masks, scatter tokens ———
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # Only gaps inside the current sequence may receive an insertion.
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            total_ext = ext.sum(dim=1)
            # Drop all insertions for rows that would exceed max_length.
            valid = xt_len + total_ext <= max_length
            ext = ext * valid.view(batch_size, 1).long()

            # Recompute the row totals after the validity filter so new_len
            # matches what is actually inserted.
            total_ext = ext.sum(dim=1)
            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # Shift each original token right by the insertions before it.
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            # No insertion on the final step. Reset `ext` so the trace below
            # does not replay the previous step's insertions (or hit an
            # undefined name when steps == 1).
            ext = torch.zeros(batch_size, max_length + 1, dtype=torch.long, device=device)
            xt_tmp = new_xt

        if return_trace:
            for batch_idx in range(batch_size):
                # Check if the token was changed
                for j in range(max_length):
                    if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=t[batch_idx].item(),
                                event_type="change",
                                position=j,
                                token=new_xt[batch_idx, j].item(),
                            )
                        )

                # Check if a new token was inserted (highest gap index first)
                for gap in reversed(range(max_length)):
                    if ext[batch_idx, gap]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=t[batch_idx].item(),
                                event_type="insertion",
                                position=gap,
                                token=mask,
                            )
                        )

        xt = xt_tmp
        t = t + dt

    return xt, sampling_trace
249
+
250
@torch.no_grad()
def batch_mcts_reverse_step(
    xt: torch.Tensor,
    t: torch.Tensor,
    dt: float,
    model: torch.nn.Module,
    pretrained: torch.nn.Module,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    last_step: bool = False,
    temperature: float = 1.0,
) -> SamplingResult:
    """Replicate one state ``batch_size`` times and take one reverse step on each copy.

    The state is tiled with ``xt.repeat(batch_size, 1)`` and ``t`` is squeezed
    and expanded to ``batch_size`` entries; the step itself is delegated to
    ``mcts_reverse_step`` so the two entry points cannot drift apart.

    Args:
        xt: Token ids of the state to expand.
        t: Current time; must squeeze to something expandable to ``batch_size``.
        dt: Step size forwarded to ``mcts_reverse_step``.
        model: Policy network (must expose ``interpolant.to_actual_rate``).
        pretrained: Reference network with the same interface.
        mask: Mask token id.
        pad: Pad token id.
        batch_size: Number of copies of ``xt`` to step.
        max_length: Sequence length of the working tensors.
        last_step: Forwarded unchanged.
        temperature: Forwarded unchanged.

    Returns:
        Whatever ``mcts_reverse_step`` returns: the 4-tuple
        ``(next_xt, log_rnd, log_policy_step, log_pretrained_step)``.
        NOTE(review): the ``SamplingResult`` annotation does not obviously
        describe a 4-tuple - confirm.
    """
    tiled_xt = xt.repeat(batch_size, 1)
    # squeeze to remove extra dimensions, then expand to batch_size
    tiled_t = t.squeeze().expand(batch_size)

    return mcts_reverse_step(
        tiled_xt,
        tiled_t,
        dt,
        model,
        pretrained,
        mask,
        pad,
        max_length,
        last_step=last_step,
        temperature=temperature,
    )
393
+
394
+
395
@torch.no_grad()
def mcts_reverse_step(
    xt: torch.Tensor,
    t: torch.Tensor,
    dt: float,
    model: torch.nn.Module,
    pretrained: torch.nn.Module,
    mask: int,
    pad: int,
    max_length: int,
    last_step: bool = False,
    temperature: float = 1.0,
) -> SamplingResult:
    """Take one Euler reverse step (unmask, then insert) and score it.

    Tokens at masked positions are resampled from the policy ``model``; unless
    ``last_step`` is set, new mask tokens are then inserted into the gaps
    according to the predicted length rate. The log-probability of the
    realised step is evaluated under both ``model`` and ``pretrained``.

    Args:
        xt: Current token ids; the batch size is ``xt.size(0)``.
            (assumes shape (B, max_length) with trailing padding - TODO confirm)
        t: Current time passed to both networks.
        dt: Step size multiplying the predicted rates.
        model: Policy network; must expose ``interpolant.to_actual_rate``.
        pretrained: Reference network with the same interface.
        mask: Mask token id.
        pad: Pad token id.
        max_length: Sequence length of the working tensors.
        last_step: If true, masked positions may not stay masked and no
            insertion is performed.
        temperature: Accepted for signature compatibility; not used here.

    Returns:
        ``(next_xt, log_rnd, log_policy_step, log_pretrained_step)`` where
        ``log_rnd = log_pretrained_step - log_policy_step``, one value per row.
        NOTE(review): the ``SamplingResult`` annotation does not obviously
        describe a 4-tuple - confirm.
    """
    device = next(model.parameters()).device
    batch_size = xt.size(0)

    # precompute row / column indices for the scatter below
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )

    # --- predict and convert rates (policy) ---
    pred_rate = model(xt, t)
    pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
    unmask_rate = pred_rate.unmask_rate  # (B, L, V)
    len_rate = pred_rate.length_rate  # (B, L+1)

    # --- same for the pretrained reference ---
    pretrained_pred = pretrained(xt, t)
    pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
    pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
    pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

    # --- unmask step (Euler) ---
    mask_pos = (xt == mask).nonzero(as_tuple=True)
    unmask_rate[xt != mask] = 0
    unmask_rate[mask_pos + (mask,)] = 0
    # the mask column holds minus the total outgoing rate (clamped away below)
    unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

    pretrained_unmask_rate[xt != mask] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

    # add "stay" probability; pad positions are pointed at the mask column
    # and restored to pad after sampling
    _xt = xt.clone()
    _xt[xt == pad] = mask
    stay_index = _xt.unsqueeze(-1)
    trans_prob.scatter_add_(2, stay_index, torch.ones_like(stay_index, dtype=trans_prob.dtype))
    pretrained_trans_prob.scatter_add_(
        2, stay_index, torch.ones_like(stay_index, dtype=pretrained_trans_prob.dtype)
    )

    if last_step:
        print("Final step, removing mask token from sampling")
        trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

    # Renormalise every masked position. Rows whose mass is exactly zero
    # (possible once the mask column has been removed) fall back to a uniform
    # distribution over token ids [0, mask). The fallback is a vocabulary
    # vector, and non-zero rows are normalised whether or not a zero row exists.
    masked_probs = trans_prob[mask_pos]  # (K, V)
    prob_sum = masked_probs.sum(dim=-1, keepdim=True)
    zero_rows = prob_sum.squeeze(-1) == 0.0
    masked_probs = masked_probs / torch.where(
        zero_rows.unsqueeze(-1), torch.ones_like(prob_sum), prob_sum
    )
    if zero_rows.any():
        uniform_prob = trans_prob.new_zeros(trans_prob.size(-1))
        uniform_prob[:mask] = 1.0 / mask  # uniform over tokens 0 to mask-1
        masked_probs[zero_rows] = uniform_prob
    trans_prob[mask_pos] = masked_probs

    new_xt = _sample_tokens(trans_prob)
    new_xt[xt == pad] = pad
    new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

    # --- log probabilities of the sampled tokens under both models ---
    lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
    lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

    # only positions that were actually unmasked contribute
    changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

    log_policy_step = (lp * changed_mask).sum(dim=1)
    log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
    log_rnd = log_pretrained_step - log_policy_step  # (B,)

    if last_step:
        return new_xt, log_rnd, log_policy_step, log_pretrained_step

    # --- gap-wise insertion: sample, score, then scatter into the new layout ---
    ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)

    insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
    pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

    # log P(ext; lambda) = ext * log(lambda) - lambda, scored on the raw draw
    log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1)
    log_pretrained_insert = (
        ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate
    ).sum(dim=1)

    log_rnd += log_pretrained_insert - log_policy_insert
    log_pretrained_step += log_pretrained_insert
    log_policy_step += log_policy_insert

    xt_len = xt.ne(pad).sum(dim=1)  # (B,)
    gaps = torch.arange(ext.size(1), device=device).view(1, -1)
    # only gaps 0..len exist for a sequence of length len
    ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
    # reject all insertions of a row that would overflow max_length
    fits = xt_len + ext.sum(dim=1) <= max_length
    ext = ext * fits.view(batch_size, 1).long()

    ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
    # recount after rejection so a rejected row keeps its length instead of
    # being filled with masks up to max_length
    new_len = xt_len + ext.sum(dim=1)  # (B,)

    xt_tmp = torch.full_like(xt, pad)
    xt_tmp[pos_idx_L < new_len.view(batch_size, 1)] = mask

    # each original token moves right by the number of insertions at or before it
    new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
    orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
    xt_tmp[batch_idx_L[orig_mask], new_pos_orig[orig_mask]] = new_xt[orig_mask]

    return xt_tmp, log_rnd, log_policy_step, log_pretrained_step
535
+
536
@torch.no_grad()
def any_order_euler_sampling_with_schedule(
    model: torch.nn.Module,
    time_schedule: torch.Tensor,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
) -> SamplingResult:
    """Run the mask-insertion Euler sampler on an explicit time grid.

    Starting from an all-pad batch, every step resamples masked positions and
    (except on the final step) inserts new mask tokens into the gaps.

    Args:
        model: Network exposing ``interpolant.to_actual_rate``.
        time_schedule: 1-D tensor of time points; ``len - 1`` steps are taken.
            It is flipped to descending order if given ascending.
            NOTE(review): the fixed-step samplers in this file run t from 0
            upwards - confirm the descending direction is intended.
        mask: Mask token id.
        pad: Pad token id.
        batch_size: Number of sequences to sample.
        max_length: Sequence length of the working tensors.
        return_trace: If true, record per-sequence change / insertion events.
        temperature: Values other than 1.0 re-sharpen the transition
            probabilities via ``softmax(log(p + 1e-10) / temperature)``.

    Returns:
        ``(xt, sampling_trace)``; ``sampling_trace`` is ``None`` unless
        ``return_trace`` is set.
    """
    device = next(model.parameters()).device

    time_schedule = time_schedule.to(device)
    if time_schedule[0] < time_schedule[-1]:
        time_schedule = torch.flip(time_schedule, [0])  # descending order

    steps = len(time_schedule) - 1

    # initialize all-pad sequence and trace
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    # precompute row / column indices for the scatter below
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        is_last = i == steps - 1

        # use scheduled timesteps
        t = time_schedule[i].repeat(batch_size)
        t_next = time_schedule[i + 1]
        dt = (t - t_next).abs()  # (B,) timestep difference

        # --- predict and convert rates ---
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # --- unmask step (Euler) ---
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt[:, None, None]).clamp(0.0, 1.0)

        # add "stay" probability; pad positions are pointed at the mask column
        # and restored to pad after sampling
        _xt = xt.clone()
        _xt[xt == pad] = mask
        stay_index = _xt.unsqueeze(-1)
        trans_prob.scatter_add_(2, stay_index, torch.ones_like(stay_index, dtype=trans_prob.dtype))

        # apply temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if is_last:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

        # Renormalise every masked position; rows with exactly zero mass fall
        # back to a uniform distribution over token ids [0, mask). The
        # fallback is a vocabulary vector, and non-zero rows are normalised
        # whether or not a zero row exists.
        masked_probs = trans_prob[mask_pos]  # (K, V)
        prob_sum = masked_probs.sum(dim=-1, keepdim=True)
        zero_rows = prob_sum.squeeze(-1) == 0.0
        masked_probs = masked_probs / torch.where(
            zero_rows.unsqueeze(-1), torch.ones_like(prob_sum), prob_sum
        )
        if zero_rows.any():
            uniform_prob = trans_prob.new_zeros(trans_prob.size(-1))
            uniform_prob[:mask] = 1.0 / mask  # uniform over tokens 0 to mask-1
            masked_probs[zero_rows] = uniform_prob
        trans_prob[mask_pos] = masked_probs

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        ext = None  # stays None on the final step: nothing is inserted there
        if not is_last:
            # --- gap-wise insertion: compute new length, fill masks, scatter tokens ---
            ext = torch.bernoulli((len_rate * dt[:, None]).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # only gaps 0..len exist for a sequence of length len
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # reject all insertions of a row that would overflow max_length
            fits = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * fits.view(batch_size, 1).long()

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            # recount after rejection so a rejected row keeps its length
            new_len = xt_len + ext.sum(dim=1)  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            xt_tmp[pos_idx_L < new_len.view(batch_size, 1)] = mask

            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            xt_tmp[batch_idx_L[orig_mask], new_pos_orig[orig_mask]] = new_xt[orig_mask]
        else:
            xt_tmp = new_xt

        if return_trace:
            # pull everything to Python lists once instead of one .item() per cell
            changed = ((xt != pad) & (xt != new_xt)).tolist()
            new_tokens = new_xt.tolist()
            t_values = t.tolist()
            inserted = None if ext is None else ext[:, :max_length].tolist()

            for batch_idx in range(batch_size):
                events = sampling_trace[batch_idx]

                # tokens that changed this step, left to right
                for j in range(max_length):
                    if changed[batch_idx][j]:
                        events.append(
                            SamplingTraceDatapoint(
                                t=t_values[batch_idx],
                                event_type="change",
                                position=j,
                                token=new_tokens[batch_idx][j],
                            )
                        )

                if inserted is None:
                    continue

                # insertions of this step, right to left
                for gap in reversed(range(max_length)):
                    if inserted[batch_idx][gap]:
                        events.append(
                            SamplingTraceDatapoint(
                                t=t_values[batch_idx],
                                event_type="insertion",
                                position=gap,
                                token=mask,
                            )
                        )

        xt = xt_tmp

    return xt, sampling_trace
677
+
678
+
679
@torch.no_grad()
def any_order_mask_insertion_euler_sampling_with_rnd(
    model, pretrained, reward_model, analyzer,
    tokenizer, steps,
    mask,
    pad,
    batch_size,
    max_length,
    return_trace = False,
    alpha = 0.1,
    temperature: float = 1.0,
):
    """Sample with the policy and return reward-tilted log density ratios.

    Runs ``steps`` Euler steps from an all-pad batch. At every step the
    log-probability of the realised unmasking and insertion under
    ``pretrained`` minus that under ``model`` is added to a running total.
    Sequences rejected by ``analyzer.is_peptide`` are dropped; for the rest
    the summed reward divided by ``alpha`` is added to the total.

    Args:
        model: Policy network exposing ``interpolant.to_actual_rate``.
        pretrained: Reference network with the same interface.
        reward_model: Callable invoked as ``reward_model(input_seqs=...)``;
            its output is summed over the last axis with ``np.sum``.
        analyzer: Object providing ``is_peptide(seq)``.
        tokenizer: Object providing ``batch_decode``.
        steps: Number of Euler steps (``dt = 1 / steps``).
        mask: Mask token id.
        pad: Pad token id.
        batch_size: Number of sequences to sample.
        max_length: Sequence length of the working tensors.
        return_trace: If true, record per-sequence change / insertion events.
        alpha: Divisor applied to the scalar reward.
        temperature: Values other than 1.0 re-sharpen the policy transition
            probabilities (the pretrained ones are left untouched).

    Returns:
        ``(valid_x_final, log_rnd, scalar_rewards, sampling_trace)`` restricted
        to the valid sequences; ``sampling_trace`` is ``None`` unless
        ``return_trace`` is set.

    NOTE(review): if no sampled sequence is valid, ``torch.stack`` raises on
    the empty list; that pre-existing behaviour is kept.
    """
    device = next(model.parameters()).device

    # initialize all-pad sequence
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    # running sum of log(pretrained / policy) over all steps
    log_rnd = torch.zeros(batch_size, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # precompute row / column indices for the scatter below
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        is_last = i == steps - 1

        # --- predict and convert rates (policy) ---
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # --- same for the pretrained reference ---
        pretrained_pred = pretrained(xt, t)
        pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
        pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
        pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

        # --- unmask step (Euler) ---
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        pretrained_unmask_rate[xt != mask] = 0
        pretrained_unmask_rate[mask_pos + (mask,)] = 0
        pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" probability; pad positions are pointed at the mask column
        # and restored to pad after sampling
        _xt = xt.clone()
        _xt[xt == pad] = mask
        stay_index = _xt.unsqueeze(-1)
        trans_prob.scatter_add_(2, stay_index, torch.ones_like(stay_index, dtype=trans_prob.dtype))
        pretrained_trans_prob.scatter_add_(
            2, stay_index, torch.ones_like(stay_index, dtype=pretrained_trans_prob.dtype)
        )

        # apply temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if is_last:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

        # Renormalise every masked position; rows with exactly zero mass fall
        # back to a uniform distribution over token ids [0, mask). The
        # fallback is a vocabulary vector, and non-zero rows are normalised
        # whether or not a zero row exists.
        masked_probs = trans_prob[mask_pos]  # (K, V)
        prob_sum = masked_probs.sum(dim=-1, keepdim=True)
        zero_rows = prob_sum.squeeze(-1) == 0.0
        masked_probs = masked_probs / torch.where(
            zero_rows.unsqueeze(-1), torch.ones_like(prob_sum), prob_sum
        )
        if zero_rows.any():
            uniform_prob = trans_prob.new_zeros(trans_prob.size(-1))
            uniform_prob[:mask] = 1.0 / mask  # uniform over tokens 0 to mask-1
            masked_probs[zero_rows] = uniform_prob
        trans_prob[mask_pos] = masked_probs

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # --- log probabilities of the sampled tokens under both models ---
        lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
        lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

        # only positions that were actually unmasked contribute
        changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

        log_policy_step = (lp * changed_mask).sum(dim=1)
        log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

        # accumulate across steps (an assignment here would discard every
        # earlier step's contribution)
        log_rnd += log_pretrained_step - log_policy_step  # (B,)

        ext = None  # stays None on the final step: nothing is inserted there
        if not is_last:
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)

            insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
            pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

            # log P(ext; lambda) = ext * log(lambda) - lambda, scored on the raw draw
            log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1)
            log_pretrained_insert = (
                ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate
            ).sum(dim=1)
            log_rnd += log_pretrained_insert - log_policy_insert

            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # only gaps 0..len exist for a sequence of length len
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # reject all insertions of a row that would overflow max_length
            fits = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * fits.view(batch_size, 1).long()

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            # recount after rejection so a rejected row keeps its length
            new_len = xt_len + ext.sum(dim=1)  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            xt_tmp[pos_idx_L < new_len.view(batch_size, 1)] = mask

            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            xt_tmp[batch_idx_L[orig_mask], new_pos_orig[orig_mask]] = new_xt[orig_mask]
        else:
            xt_tmp = new_xt

        if return_trace:
            # pull everything to Python lists once instead of one .item() per cell
            changed = ((xt != pad) & (xt != new_xt)).tolist()
            new_tokens = new_xt.tolist()
            t_values = t.tolist()
            inserted = None if ext is None else ext[:, :max_length].tolist()

            # distinct index name: the step counter `i` must not be shadowed
            for batch_idx in range(batch_size):
                events = sampling_trace[batch_idx]

                # tokens that changed this step, left to right
                for j in range(max_length):
                    if changed[batch_idx][j]:
                        events.append(
                            SamplingTraceDatapoint(
                                t=t_values[batch_idx],
                                event_type="change",
                                position=j,
                                token=new_tokens[batch_idx][j],
                            )
                        )

                if inserted is None:
                    continue

                # insertions of this step, right to left
                for gap in reversed(range(max_length)):
                    if inserted[batch_idx][gap]:
                        events.append(
                            SamplingTraceDatapoint(
                                t=t_values[batch_idx],
                                event_type="insertion",
                                position=gap,
                                token=mask,
                            )
                        )

        xt = xt_tmp
        t = t + dt

    # --- decode and keep only sequences the analyzer accepts ---
    samples = xt.to(device)
    decoded_samples = tokenizer.batch_decode(samples)

    valid_x_final = []
    validSequences = []
    valid_log_rnd = []

    for idx, seq in enumerate(decoded_samples):
        if analyzer.is_peptide(seq):
            valid_x_final.append(xt[idx])
            validSequences.append(seq)
            valid_log_rnd.append(log_rnd[idx])

    print("len valid sequences:", len(validSequences))

    # compute multi-objective rewards and collapse them to one scalar each
    score_vectors = reward_model(input_seqs=validSequences)
    scalar_rewards = np.sum(score_vectors, axis=-1)
    scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device)

    print(f"scalar reward dim{len(scalar_rewards)}")
    valid_log_rnd = torch.stack(valid_log_rnd, dim=0)

    log_rnd = valid_log_rnd + (scalar_rewards / alpha)  # reward scaled by 1 / alpha
    valid_x_final = torch.stack(valid_x_final, dim=0)

    return valid_x_final, log_rnd, scalar_rewards, sampling_trace
887
+
888
@torch.no_grad()
def any_order_finetuned_euler_sampler(
    model, reward_model, analyzer,
    tokenizer, steps,
    mask,
    pad,
    batch_size,
    max_length,
    return_trace = False,
    dataframe = False,
    temperature: float = 1.0,
):
    """Sample from ``model`` and score the valid sequences for evaluation.

    Runs ``steps`` Euler steps from an all-pad batch, decodes the result,
    keeps the sequences accepted by ``analyzer.is_peptide`` and evaluates
    ``reward_model`` on them.

    Args:
        model: Network exposing ``interpolant.to_actual_rate``.
        reward_model: Callable invoked as ``reward_model(input_seqs=...)``;
            rows 0-4 of its transposed output are reported as binding
            affinity, solubility, hemolysis, nonfouling and permeability.
        analyzer: Object providing ``is_peptide(seq)``.
        tokenizer: Object providing ``batch_decode``.
        steps: Number of Euler steps (``dt = 1 / steps``).
        mask: Mask token id.
        pad: Pad token id.
        batch_size: Number of sequences to sample.
        max_length: Sequence length of the working tensors.
        return_trace: If true, change / insertion events are collected
            internally. NOTE(review): the trace is never returned - confirm
            whether callers expect it.
        dataframe: If true, also return a ``pandas.DataFrame`` with one row
            per valid sequence.
        temperature: Values other than 1.0 re-sharpen the transition
            probabilities via ``softmax(log(p + 1e-10) / temperature)``.

    Returns:
        ``(samples, affinity, sol, hemo, nf, permeability, valid_fraction)``
        and, when ``dataframe`` is set, the DataFrame as an eighth element.
        ``samples`` holds all final sequences, valid or not. When no sequence
        is valid the five score entries are ``[0.0]`` and the DataFrame is
        empty.
    """
    device = next(model.parameters()).device

    # initialize all-pad sequence
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # precompute row / column indices for the scatter below
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        is_last = i == steps - 1

        # --- predict and convert rates ---
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # --- unmask step (Euler) ---
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" probability; pad positions are pointed at the mask column
        # and restored to pad after sampling
        _xt = xt.clone()
        _xt[xt == pad] = mask
        stay_index = _xt.unsqueeze(-1)
        trans_prob.scatter_add_(2, stay_index, torch.ones_like(stay_index, dtype=trans_prob.dtype))

        # apply temperature scaling
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if is_last:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

        # Renormalise every masked position; rows with exactly zero mass fall
        # back to a uniform distribution over token ids [0, mask). The
        # fallback is a vocabulary vector, and non-zero rows are normalised
        # whether or not a zero row exists.
        masked_probs = trans_prob[mask_pos]  # (K, V)
        prob_sum = masked_probs.sum(dim=-1, keepdim=True)
        zero_rows = prob_sum.squeeze(-1) == 0.0
        masked_probs = masked_probs / torch.where(
            zero_rows.unsqueeze(-1), torch.ones_like(prob_sum), prob_sum
        )
        if zero_rows.any():
            uniform_prob = trans_prob.new_zeros(trans_prob.size(-1))
            uniform_prob[:mask] = 1.0 / mask  # uniform over tokens 0 to mask-1
            masked_probs[zero_rows] = uniform_prob
        trans_prob[mask_pos] = masked_probs

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        ext = None  # stays None on the final step: nothing is inserted there
        if not is_last:
            # --- gap-wise insertion: compute new length, fill masks, scatter tokens ---
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # only gaps 0..len exist for a sequence of length len
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # reject all insertions of a row that would overflow max_length
            fits = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * fits.view(batch_size, 1).long()

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            # recount after rejection so a rejected row keeps its length
            new_len = xt_len + ext.sum(dim=1)  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            xt_tmp[pos_idx_L < new_len.view(batch_size, 1)] = mask

            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            xt_tmp[batch_idx_L[orig_mask], new_pos_orig[orig_mask]] = new_xt[orig_mask]
        else:
            xt_tmp = new_xt

        if return_trace:
            # pull everything to Python lists once instead of one .item() per cell
            changed = ((xt != pad) & (xt != new_xt)).tolist()
            new_tokens = new_xt.tolist()
            t_values = t.tolist()
            inserted = None if ext is None else ext[:, :max_length].tolist()

            for batch_idx in range(batch_size):
                events = sampling_trace[batch_idx]

                # tokens that changed this step, left to right
                for j in range(max_length):
                    if changed[batch_idx][j]:
                        events.append(
                            SamplingTraceDatapoint(
                                t=t_values[batch_idx],
                                event_type="change",
                                position=j,
                                token=new_tokens[batch_idx][j],
                            )
                        )

                if inserted is None:
                    continue

                # insertions of this step, right to left
                for gap in reversed(range(max_length)):
                    if inserted[batch_idx][gap]:
                        events.append(
                            SamplingTraceDatapoint(
                                t=t_values[batch_idx],
                                event_type="insertion",
                                position=gap,
                                token=mask,
                            )
                        )

        xt = xt_tmp
        t = t + dt

    # --- evaluation ---
    samples = xt.to(device)
    decoded_samples = tokenizer.batch_decode(samples)

    validSequences = [seq for seq in decoded_samples if analyzer.is_peptide(seq)]

    print("len valid sequences:", len(validSequences))
    valid_fraction = len(validSequences) / batch_size

    has_valid = bool(validSequences)
    if has_valid:
        score_vectors = reward_model(input_seqs=validSequences)  # (num_children, num_objectives)
        per_objective = score_vectors.T

        affinity = per_objective[0]
        sol = per_objective[1]
        hemo = per_objective[2]
        nf = per_objective[3]
        permeability = per_objective[4]
    else:
        zeros = [0.0]

        affinity = zeros
        sol = zeros
        hemo = zeros
        nf = zeros
        permeability = zeros

    if dataframe:
        # With no valid sequence every column must be empty: mixing the empty
        # sequence list with one-element score lists makes pandas raise.
        df = pd.DataFrame({
            "Peptide Sequence": validSequences,
            "Binding Affinity": affinity if has_valid else [],
            "Solubility": sol if has_valid else [],
            "Hemolysis": hemo if has_valid else [],
            "Nonfouling": nf if has_valid else [],
            "Permeability": permeability if has_valid else [],
        })
        return samples, affinity, sol, hemo, nf, permeability, valid_fraction, df

    return samples, affinity, sol, hemo, nf, permeability, valid_fraction
1074
+
1075
@torch.no_grad()
def mdm_tau_leaping_sampling(
    model: MaskedDiffusionModule,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
):
    """Sample fixed-length sequences from a masked diffusion model by tau-leaping.

    Starts from an all-mask batch. On every step but the last, a Poisson count
    is drawn per (position, token) from ``unmask_rate * dt`` and a masked
    position is unmasked only if exactly one event fired there. On the last
    step every remaining mask is replaced by the highest-rate token.

    Args:
        model: Network exposing ``interpolant.to_actual_rate``.
        steps: Number of steps (``dt = 1 / steps``).
        mask: Mask token id.
        pad: Unused here; kept for a signature parallel to the other samplers.
        batch_size: Number of sequences to sample.
        max_length: Length of every sampled sequence.
        return_trace: Must be false; tracing is not implemented.
        temperature: Unused here.

    Returns:
        ``(xt, [])`` - the sampled token ids and an empty trace.
    """
    assert not return_trace, "Trace is not yet supported"
    device = next(model.parameters()).device
    xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    for i in range(steps):
        # --- predict and convert rates ---
        pred = model(xt, t)
        pred = model.interpolant.to_actual_rate(xt, pred, t)
        unmask_rate = pred.unmask_rate  # (B, L, V)

        still_masked = xt == mask  # (B, L)

        if i == steps - 1:
            # last step: deterministic unmask via argmax. The mask column is
            # excluded, as in the tau-leaping branch, so no position can
            # "unmask" to the mask token and stay masked in the output.
            final_rate = unmask_rate.clone()
            final_rate[..., mask] = float("-inf")
            new_token = final_rate.argmax(dim=2)  # (B, L)
            accept = still_masked
        else:
            # tau-leaping via Poisson counts; mask -> mask is not an event
            counts = torch.poisson(unmask_rate * dt).long()
            counts[..., mask] = 0
            # only accept exactly one event, and only at masked positions
            accept = still_masked & (counts.sum(dim=2) == 1)
            new_token = counts.argmax(dim=2)  # (B, L)

        # already-unmasked tokens are kept as they are
        xt = torch.where(accept, new_token, xt)
        t = t + dt

    return xt, []
1128
+
1129
# Not used in production, for debugging purposes
lengths = {4: 0.1, 16: 0.4, 32: 0.4, 64: 0.1}


def binomial_mass(k, n, p):
    """
    Calculate the probability mass function (PMF) for a binomial distribution.

    Args:
        k (int): Number of successes
        n (int): Number of trials
        p (float): Probability of success in a single trial

    Returns:
        float: Probability mass P(X = k); 0.0 when k is negative or exceeds n
    """
    import math

    # Outside the support the mass is zero. Checking explicitly (rather than
    # relying on factorial() raising) also keeps (1 - p) ** (n - k) from being
    # evaluated with a negative exponent.
    if k < 0 or k > n:
        return 0.0

    # math.comb gives the exact integer coefficient without building three
    # large factorials; converting to float keeps the return type unchanged.
    binom_coef = float(math.comb(n, k))

    # Calculate probability mass
    return binom_coef * (p ** k) * ((1 - p) ** (n - k))
1155
+
1156
def calculate_rate_batch(alpha_t, len_t):
    """
    Batched rate: expected remaining length given the current length.

    For every candidate total length in the module-level ``lengths`` prior,
    the binomial probability of observing ``len_t`` out of that many tokens
    (success probability ``alpha_t``) weights both the remaining length
    ``length - len_t`` (numerator) and the normaliser (denominator).

    Args:
        alpha_t (torch.Tensor): Tensor of shape (batch_size,)
        len_t (torch.Tensor): Tensor of shape (batch_size,)

    Returns:
        torch.Tensor: Tensor of shape (batch_size,); entries whose
        denominator is not positive are 0.
    """
    numerator = torch.zeros(alpha_t.shape[0], device=alpha_t.device)
    denominator = torch.zeros_like(numerator)

    for total_len, prior in lengths.items():
        # Only rows whose current length fits inside this candidate length contribute.
        in_support = (len_t >= 0) & (len_t <= total_len)
        if not in_support.any():
            continue

        rows = in_support.nonzero(as_tuple=True)[0]
        observed = len_t[rows]

        # Binomial PMF via torch.distributions (log-prob then exp)
        pmf = (
            torch.distributions.Binomial(total_count=total_len, probs=alpha_t[rows])
            .log_prob(observed)
            .exp()
        )

        numerator[rows] += (total_len - observed) * prior * pmf
        denominator[rows] += prior * pmf

    # Divide only where the denominator is positive; everything else stays 0.
    rate = torch.zeros_like(numerator)
    positive = denominator > 0
    rate[positive] = numerator[positive] / (denominator[positive])
    return rate
1199
+
1200
# Keep the original function for backward compatibility
def calculate_rate(alpha_t, len_t):
    """Legacy scalar version of calculate_rate.

    Tensors with at least one dimension are forwarded to
    ``calculate_rate_batch``; anything else is handled with plain Python
    arithmetic over the module-level ``lengths`` prior. Returns 0.0 when no
    candidate length carries probability mass.
    """
    is_batched = isinstance(alpha_t, torch.Tensor) and alpha_t.ndim > 0
    if is_batched:
        return calculate_rate_batch(alpha_t, len_t)

    weighted_remaining = 0
    normaliser = 0
    for total_len, prior in lengths.items():
        if total_len < len_t:
            continue
        # binomial_mass is pure, so evaluate it once and reuse the value.
        mass = binomial_mass(len_t, total_len, alpha_t)
        weighted_remaining += (total_len - len_t) * prior * mass
        normaliser += prior * mass

    return 0.0 if normaliser == 0 else weighted_remaining / normaliser
1216
+
1217
+
1218
@torch.no_grad()
def any_order_mask_insertion_tau_leaping_sampling(
    model: torch.nn.Module,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    confidence_based_sampling: bool = True,  # whether to use confidence-based decoding
    alpha: float = 5.0,  # hyperparameter for window size calculation
    max_window: int = 32,  # Maximum window size for sliding window
    confidence_method: str = "prob_diff",  # "position", "top_prob", "prob_diff", "entropy"
    use_sliding_window: bool = False,  # whether to use sliding window for position selection
    temperature: float = 1.0,
) -> SamplingResult:
    """Sample variable-length sequences by alternating unmasking and mask insertion.

    Sequences start as all ``pad``. Every non-final step (1) unmasks some
    masked positions, either by confidence ranking or by per-token Poisson
    tau-leaping, and (2) inserts new mask tokens into the gaps of the current
    sequence according to a Poisson draw from the model's length rate. The
    final step fills every remaining mask with the arg-max token and performs
    no insertion.

    Args:
        model: Callable as ``model(xt, t)`` and exposing
            ``model.interpolant.to_actual_rate``, whose result provides
            ``unmask_rate`` and ``length_rate``.
        steps: Number of sampling steps; the step size is ``1 / steps``.
        mask: Token id for masked positions.
        pad: Token id for padding.
        batch_size: Number of sequences.
        max_length: Maximum sequence length (width of the token tensor).
        return_trace: When True, the state after every step (including the
            final one) is appended to the returned trace.
        confidence_based_sampling: Confidence-ranked unmasking when truthy
            (``> 0.0``), otherwise Poisson tau-leaping unmasking.
        alpha: Multiplier on the per-row unmask count used to size the
            sliding window.
        max_window: Upper bound on the sliding-window size.
        confidence_method: One of "position", "top_prob", "prob_diff",
            "entropy".
        use_sliding_window: Restrict candidates to the first k masked
            positions of each row.
        temperature: Accepted but not used in this function.

    Returns:
        Tuple ``(xt, sampling_trace)``.

    Raises:
        ValueError: If ``confidence_method`` is not recognised.
    """
    device = next(model.parameters()).device
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
    sampling_trace = []
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)
    neg_inf = torch.tensor(-float('inf'), device=device)

    # Precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )

    for i in range(steps):
        is_last_step = i == steps - 1

        # --- predict rates ---
        pred = model(xt, t)
        pred = model.interpolant.to_actual_rate(xt, pred, t)
        unmask_rate = pred.unmask_rate  # (B, L, V)
        len_rate = pred.length_rate  # (B, L+1)

        if is_last_step:
            # last step: deterministic unmask via argmax
            mask_pos = xt == mask
            new_token = unmask_rate.argmax(dim=2)
            new_xt = xt.clone()
            new_xt[mask_pos] = new_token[mask_pos]
            new_xt = torch.where(xt == pad, pad, new_xt)
            new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # --- confidence-based decoding ---
        elif confidence_based_sampling > 0.0:
            mask_positions = (xt == mask)  # (B, L)
            num_mask_positions = mask_positions.sum(dim=1)  # (B,)

            # 1. Determine number of tokens to unmask using Poisson
            unmask_counts = torch.poisson(num_mask_positions.float() * dt).long()  # (B,)

            # 2. Calculate confidence based on selected method
            if confidence_method == "position":
                # Position-based confidence: earlier positions rank higher
                xt_len = (xt != pad).sum(dim=1)  # (B,) - current sequence lengths
                confidence = 1.0 - (
                    pos_idx_L.float() / xt_len.unsqueeze(1).float().clamp(min=1)
                )  # (B, L)

            elif confidence_method == "top_prob":
                # Top probability confidence (unmask_rate is used as logits)
                unmask_probs = torch.softmax(unmask_rate, dim=-1)  # (B, L, V)
                confidence = unmask_probs.max(dim=-1)[0]  # (B, L)

            elif confidence_method == "prob_diff":
                # Probability difference confidence (top - second top)
                unmask_probs = torch.softmax(unmask_rate, dim=-1)  # (B, L, V)
                top2_probs, _ = torch.topk(unmask_probs, k=2, dim=-1)  # (B, L, 2)
                confidence = top2_probs[:, :, 0] - top2_probs[:, :, 1]  # (B, L)

            elif confidence_method == "entropy":
                # Entropy-based confidence (lower entropy = higher confidence)
                unmask_probs = torch.softmax(unmask_rate, dim=-1)  # (B, L, V)
                entropy = -torch.sum(unmask_probs * torch.log(unmask_probs + 1e-10), dim=-1)  # (B, L)
                confidence = -entropy  # (B, L)

            else:
                raise ValueError(f"Unknown confidence_method: {confidence_method}")

            # 3. Work out which positions may be unmasked this step
            if use_sliding_window:
                # Dynamic window size k for each row
                k_values = torch.minimum(
                    torch.minimum(
                        (alpha * unmask_counts).long(),
                        torch.tensor(max_window, device=device)
                    ), num_mask_positions)  # (B,)

                # A position is eligible if it is a mask and within the first k masks
                mask_cumsum = mask_positions.cumsum(dim=1)  # (B, L)
                eligible = mask_positions & (mask_cumsum <= k_values.unsqueeze(1))  # (B, L)
            else:
                # No window constraint - only mask positions are eligible
                eligible = mask_positions

            confidence = torch.where(eligible, confidence, neg_inf)

            # The Poisson draw can exceed the number of eligible positions. Without
            # this clamp top-k would pick -inf (non-mask) positions and overwrite
            # already-committed tokens, or fail outright when k > max_length.
            unmask_counts = torch.minimum(unmask_counts, eligible.sum(dim=1))

            new_xt = xt.clone()

            # vectorized unmasking
            max_unmask = unmask_counts.max().item()
            if max_unmask > 0:
                _, all_top_indices = torch.topk(confidence, k=max_unmask, dim=1, largest=True)  # (B, max_unmask)

                # keep only the first unmask_counts[b] picks of every row
                unmask_mask = torch.arange(max_unmask, device=device).unsqueeze(0) < unmask_counts.unsqueeze(1)  # (B, max_unmask)

                most_likely_tokens = unmask_rate.argmax(dim=-1)  # (B, L)

                selected_positions = all_top_indices[unmask_mask]
                batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, max_unmask)[unmask_mask]

                new_xt[batch_indices, selected_positions] = most_likely_tokens[batch_indices, selected_positions]
        else:
            # --- tau-leaping unmask via Poisson ---
            counts = torch.poisson(unmask_rate * dt).long()
            mask_pos = xt == mask
            counts[~mask_pos.unsqueeze(-1).expand_as(counts)] = 0
            counts[..., mask] = 0
            sum_c = counts.sum(dim=2)
            one_event = sum_c == 1
            new_token = counts.argmax(dim=2)
            new_xt = xt.clone()
            new_xt[one_event] = new_token[one_event]
            new_xt = torch.where(xt == pad, pad, new_xt)
            new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # insertion only on non-last
        if not is_last_step:
            # --- Poisson insertion, compute new lengths and fill masks ---
            ext = torch.poisson(len_rate * dt).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # only gaps inside (or at either end of) the current sequence can grow
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # reject the whole insertion for rows that would overflow max_length
            valid = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * valid.view(batch_size, 1).long()
            # recompute AFTER rejection so rejected rows keep their length
            # (a stale total would extend them with masks anyway)
            total_ext = ext.sum(dim=1)

            # compute prefix sums of insertions
            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            new_len = xt_len + total_ext  # (B,)

            # initialize with pads, then fill mask up to new_len
            xt_tmp = torch.full_like(xt, pad)
            mask_pos = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_pos] = mask

            # shift and scatter original tokens
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            xt_tmp = new_xt

        xt = xt_tmp
        t = t + dt
        if return_trace:
            # recorded for every step, the final one included
            sampling_trace.append(xt)

    return xt, sampling_trace
a2d2_pep/scripts/run_peptide_finetune.slurm ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# NOTE: --partition and --qos below are specific to our cluster. Change them
# (or remove them and pass `--partition` on the `sbatch` command line) to match
# the partitions/QOS available on yours.
#SBATCH --job-name=peptide-finetune-len256
#SBATCH --partition=b200-mig90
#SBATCH --qos=mig
#SBATCH --nodes=1
#SBATCH --gpus-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --ntasks-per-node=1
#SBATCH --mem=80GB
#SBATCH --time=02-00:00:00
#SBATCH --output=logs/peptide_finetune_%A.log

# =====================================================================
# run_peptide_finetune.slurm
#
# Single-mode job (1 MIG GPU) running ONE finetune_quality (peptide)
# experiment. Select which mode to run via the MODE_ID variable below
# (or override at submit time with `sbatch --export=ALL,MODE_ID=2 ...`):
#   0) A2D2 (Ours)                  – with full planner (alternating)
#   1) A2D2 w/o quality             – --disable_planner
#   2) A2D2 w/o insertion planner   – --disable_insertion_planner
#   3) A2D2 w/o unmasking planner   – --disable_unmasking_planner
#
# The job trains the selected mode then evaluates the resulting
# checkpoint on the same GPU.
# =====================================================================

# Abort on the first failing command so a failed training run is never
# followed by an evaluation of a stale checkpoint.
set -e

# --- Mode selection ---------------------------------------------------
# Which experiment to run (0-3). Override with `--export=ALL,MODE_ID=N`.
MODE_ID="${MODE_ID:-0}"

# Run prefix: YYYYMMDD + SLURM job ID (or a local HHMMSS stamp outside SLURM)
DATE_STAMP=$(date +%Y%m%d)
PREFIX="${DATE_STAMP}_job${SLURM_JOB_ID:-local$(date +%H%M%S)}"

# Default protein target (must be defined before path definitions below)
PROT_NAME=tfr

# --- Paths ------------------------------------------------------------
# Repo root is resolved at submit time so the script works from any clone:
#   - set A2D2_ROOT explicitly, OR
#   - run `sbatch` from the repo root (SLURM sets SLURM_SUBMIT_DIR), OR
#   - fall back to this script's location (a2d2_pep/scripts/ -> two levels up).
if [ -n "${A2D2_ROOT:-}" ]; then
    HOME_LOC="$A2D2_ROOT"
elif [ -n "${SLURM_SUBMIT_DIR:-}" ]; then
    HOME_LOC="$SLURM_SUBMIT_DIR"
else
    HOME_LOC="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
fi
SCRIPT_LOC="$HOME_LOC/a2d2_pep"
LOG_LOC="$HOME_LOC/logs"
SAVE_DIR="$HOME_LOC/checkpoints/finetune_test_peptides_${PROT_NAME}"
RESULTS_DIR="$HOME_LOC/results/peptide_test_ablation_${PROT_NAME}"

cd "$SCRIPT_LOC"

# BASE_PATH is passed as --base_path to finetune_quality.py: it's used
# to build the plot output path at $BASE_PATH/flexible/results/<run_name>
# (see finetune_quality.py:421). The pretrained checkpoint is now passed
# explicitly via --checkpoint_path below, so base_path no longer needs
# to follow the legacy /scratch layout.
BASE_PATH="${A2D2_BASE_PATH:-$HOME_LOC}"

mkdir -p "$LOG_LOC" "$SAVE_DIR" "$RESULTS_DIR"

# --- Environment setup ------------------------------------------------
# Do NOT hardcode your W&B key. Either `wandb login` once on the cluster,
# export WANDB_API_KEY in your shell/SLURM environment before submitting,
# or set WANDB_MODE=offline to skip logging entirely.
export WANDB_DIR="$HOME_LOC/.wandb"
export WANDB_CONFIG_DIR="$HOME_LOC/.config/wandb"
export WANDB_CACHE_DIR="$HOME_LOC/.cache/wandb"
# Stop wandb from hijacking stdout/stderr (its default fd-redirect mode sends
# all output to wandb/run-*/files/output.log and freezes the RUN_LOG below).
# With console off, everything flows to the `>> "$RUN_LOG" 2>&1` redirect.
export WANDB_CONSOLE=off
mkdir -p "$WANDB_DIR" "$WANDB_CONFIG_DIR" "$WANDB_CACHE_DIR"

export TRITON_CACHE_DIR="$HOME_LOC/.triton/cache"
mkdir -p "$TRITON_CACHE_DIR"

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Activate conda env. Override CONDA_ROOT to point at your conda/miniconda
# install, or just have `conda` on PATH; override CONDA_ENV if your env name
# differs from the one created by environment.yml.
CONDA_ENV="${CONDA_ENV:-a2d2}"
if [ -n "${CONDA_ROOT:-}" ]; then
    source "$CONDA_ROOT/bin/activate" "$CONDA_ENV"
elif command -v conda >/dev/null 2>&1; then
    source "$(conda info --base)/bin/activate" "$CONDA_ENV"
else
    echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2
    exit 1
fi
# `command -v` is the POSIX-specified lookup; `which` is an external tool
# that is not guaranteed to exist on every compute node.
PYTHON_EXECUTABLE=$(command -v python)

# Pretrained base checkpoint
PRETRAINED_CKPT="$HOME_LOC/pretrained/anylength_pep.ckpt"

# --- Shared training hyperparameters ----------------------------------
COMMON_ARGS=(
    --base_path "$BASE_PATH"
    --checkpoint_path "$PRETRAINED_CKPT"
    --prot_name "$PROT_NAME"
    --noise_removal
    --wdce_num_replicates 8
    --pool_size 100
    --pool_refresh_fraction 1.0
    --buffer_size 50
    --batch_size 200
    --total_num_steps 256
    --num_iter 20
    --resample_every_n_step 10
    --num_epochs 1000
    --save_every_n_epochs 50
    --reset_every_n_step 1
    --alpha 0.1
    --no_mcts
    --schedule_warmup_epochs 20
    --alternation_frequency 5
    --num_remasking 3
    --quality_threshold 0.2
    --training_mini_batch_size 10
    --max_length 256
    --eval_every_n_epochs 50
    --min_peptide_bonds 4
    --grad_clip
    --seed 42
)

# --- Shared evaluation hyperparameters --------------------------------
EVAL_COMMON_ARGS=(
    --pretrained_ckpt "$PRETRAINED_CKPT"
    --num_samples 50
    --batch_size 200
    --max_length 256
    --total_num_steps 256
    --num_remasking 3
    --quality_threshold 0.2
    --prot_name "$PROT_NAME"
    --seed 42
)

# =====================================================================
# Pick experiment from $MODE_ID
# =====================================================================
case "$MODE_ID" in
    0) MODE="with_planner";         EXTRA_ARGS=() ;;
    1) MODE="no_planner";           EXTRA_ARGS=(--disable_planner) ;;
    2) MODE="no_insertion_planner"; EXTRA_ARGS=(--disable_insertion_planner) ;;
    3) MODE="no_unmasking_planner"; EXTRA_ARGS=(--disable_unmasking_planner) ;;
    *) echo "Unknown MODE_ID=$MODE_ID (expected 0-3)"; exit 1 ;;
esac

RUN_NAME="${PREFIX}_peptide_${PROT_NAME}_${MODE}"
RUN_LOG="$LOG_LOC/${RUN_NAME}.log"
RUN_SAVE_DIR="$SAVE_DIR/${RUN_NAME}"
RESULTS_SUBDIR="$RESULTS_DIR/${MODE}"
mkdir -p "$RUN_SAVE_DIR" "$RESULTS_SUBDIR"

echo "=== Peptide finetune (MODE_ID=$MODE_ID) ==="
echo "Job: ${SLURM_JOB_ID:-local}  Node: ${SLURM_NODELIST:-$(hostname)}"
echo "Mode: $MODE"
echo "Save dir:    $RUN_SAVE_DIR"
echo "Results dir: $RESULTS_SUBDIR"
echo "Python: $PYTHON_EXECUTABLE"
echo "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES:-(unset)}"

# =====================================================================
# Train
# =====================================================================
# Interpreter and script paths are quoted so a clone path containing
# spaces does not get word-split.
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/finetune_quality.py" \
    "${COMMON_ARGS[@]}" \
    --devices 1 \
    "${EXTRA_ARGS[@]}" \
    --save_path_dir "$RUN_SAVE_DIR" \
    >> "$RUN_LOG" 2>&1

echo "Training finished for $MODE. Log: $RUN_LOG"

# =====================================================================
# Evaluate
# =====================================================================
# finetune_quality.py saves to $RUN_SAVE_DIR/<auto_run_name>/last.ckpt,
# so glob the run_name subdir and take the most recently modified match.
RUN_CKPT=$(ls -t "$RUN_SAVE_DIR"/*/last.ckpt 2>/dev/null | head -1)
if [ -z "$RUN_CKPT" ]; then
    echo "No checkpoint found in $RUN_SAVE_DIR — skipping eval."
    exit 1
fi

echo "Evaluating checkpoint: $RUN_CKPT"
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/evaluate_peptide_table.py" \
    --checkpoint_path "$RUN_CKPT" \
    "${EVAL_COMMON_ARGS[@]}" \
    "${EXTRA_ARGS[@]}" \
    --output_dir "$RESULTS_SUBDIR" \
    --device cuda:0 \
    >> "$RESULTS_SUBDIR/${RUN_NAME}_eval.log" 2>&1

echo "Eval finished for $MODE. CSV: $RESULTS_SUBDIR/eval_metrics_${MODE}_${PROT_NAME}.csv"

conda deactivate
a2d2_pep/scripts/train_pep.sh ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
#SBATCH --job-name=a2d2-pep-pretrain
#SBATCH --partition=dgx-b200
#SBATCH --nodes=1
#SBATCH --gpus-per-node=4
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=8
#SBATCH --mem=512GB
#SBATCH --time=7-00:00:00
# SLURM's own catch-file (anything printed before the exec redirect below, plus
# slurm-infra messages). Relative to the submit dir, so submit this script from
# the a2d2_pep/ directory; the real run output is redirected via exec below.
#SBATCH --output=logs/slurm/%x_%j.out
#SBATCH --error=logs/slurm/%x_%j.err
#
# Pretrain the any-length insertion MDM on ~11M peptide SMILES on a dgx-b200 node.
# Submit with:  sbatch scripts/train_pep.sh   (from the a2d2_pep/ directory).
#
# DDP is launched by SLURM: one srun task per GPU. --gpus-per-node and
# --ntasks-per-node must match; change both together (and they override the
# training.devices value baked into config_pep.yaml via the hydra override below).

DATE=$(date +%Y%m%d)
SPECIAL_PREFIX='a2d2-peptide'

# Resolve a2d2_pep/ (which holds train.py + config_pep.yaml) so paths are
# repo-relative. This script lives in a2d2_pep/scripts/, so the direct-run
# fallback goes one level up. Under sbatch, BASH_SOURCE points at the spooled
# copy, so we rely on SLURM_SUBMIT_DIR (submit from the a2d2_pep/ directory).
if [ -n "${SLURM_SUBMIT_DIR:-}" ]; then
    SCRIPT_DIR="$SLURM_SUBMIT_DIR"
else
    SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
fi
cd "$SCRIPT_DIR"

# Auto-detect GPUs from the SLURM allocation (falls back to 4 for `bash` runs).
DEVICES=${SLURM_GPUS_ON_NODE:-${SLURM_GPUS_PER_NODE:-4}}
NTASKS=${SLURM_NTASKS_PER_NODE:-$DEVICES}
NODES=${SLURM_NNODES:-1}

LOG_LOC="$SCRIPT_DIR/logs"
mkdir -p "$LOG_LOC/slurm"
exec > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_${SLURM_JOB_ID:-local}.log" 2>&1

# ---------------------------------------------------------------------------
# Weights & Biases: log in once on your machine before running this script with
# `wandb login` (or `export WANDB_API_KEY=<your-key>`).
# Do NOT hardcode your API key here. To disable W&B entirely, uncomment:
#   export WANDB_MODE=disabled
# ---------------------------------------------------------------------------

export PYTORCH_ALLOC_CONF=expandable_segments:True

# Activate the conda env that has the deps (torch / pytorch_lightning / hydra).
# The batch shell does NOT source ~/.bashrc, so conda is not on PATH. Override
# CONDA_ROOT to point at your conda/miniconda install, or just have `conda` on
# PATH; override CONDA_ENV if your env name differs from the one created by
# environment.yml.
CONDA_ENV="${CONDA_ENV:-a2d2}"
if [ -n "${CONDA_ROOT:-}" ]; then
    source "$CONDA_ROOT/bin/activate" "$CONDA_ENV"
elif command -v conda >/dev/null 2>&1; then
    source "$(conda info --base)/bin/activate" "$CONDA_ENV"
else
    echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2
    exit 1
fi

# --- Distributed / NCCL setup (single node, intra-node NVLink) --------------
# Pick the first non-loopback IPv4 interface that looks like Ethernet; if none
# matches, fall back to the first non-loopback, non-"ibp" interface.
ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -E "ens|eth|enp|bond" | head -1 | awk '{print $2}')
if [ -z "$ETH_IFACE" ]; then
    ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -v "ibp" | head -1 | awk '{print $2}')
fi
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_FAMILY=AF_INET
export NCCL_SOCKET_IFNAME=$ETH_IFACE
export NCCL_P2P_LEVEL=NVL

export MASTER_ADDR=$(scontrol show hostnames "${SLURM_NODELIST:-$(hostname)}" | head -n 1)
export MASTER_PORT=$(shuf -i 15000-59999 -n 1)
export NODE_RANK=${SLURM_NODEID:-0}

echo "=== a2d2 peptide pretraining (dgx-b200) ==="
echo "Job ID: ${SLURM_JOB_ID:-local}  Node: ${SLURM_NODELIST:-$(hostname)}  GPUs: $DEVICES  Tasks: $NTASKS"

# --task pep makes train.py load config_pep.yaml; the hydra overrides pin
# devices/nodes to the SLURM allocation so the two never drift apart.
srun --ntasks-per-node="$NTASKS" python train.py --task pep \
    training.devices="$DEVICES" \
    training.nodes="$NODES"
# Remember the training result: without this the script's exit status would be
# that of `conda deactivate`, and SLURM would report a crashed run as COMPLETED.
TRAIN_STATUS=$?

conda deactivate

exit "$TRAIN_STATUS"
a2d2_pep/train.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ from pytorch_lightning.loggers import WandbLogger
4
+ from pytorch_lightning.callbacks import ModelCheckpoint
5
+ import os
6
+ import sys
7
+ import argparse
8
+ import hydra
9
+ from omegaconf import OmegaConf
10
+ from datetime import datetime
11
+ # Directory containing this file and the config_*.yaml files (used by Hydra below).
12
+ CONFIG_DIR = os.path.dirname(os.path.abspath(__file__))
13
+ # Add the repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve.
14
+ sys.path.insert(0, os.path.dirname(CONFIG_DIR))
15
+
16
+ import wandb
17
+ from lightning_modules import AnyOrderInsertionFlowModule
18
+
19
+
20
# Process-wide torch settings, applied once at import time.
# Print up to 10,000 tensor elements before torch summarises the output.
torch.set_printoptions(threshold=10_000)
# Set the float32 matmul precision mode to "high".
torch.set_float32_matmul_precision("high")

# Disable DDP optimizer due to incompatibility with flex_attention higher-order ops
torch._dynamo.config.optimize_ddp = False
25
+
26
def _is_rank_zero():
    """Return True when this process is the rank-0 worker.

    The launcher-provided rank variables are checked in the same order
    Lightning uses for its own rank detection. ``LOCAL_RANK`` alone is not
    enough: it is unset in tasks started directly by ``srun``, which would
    make every task believe it is rank 0.
    """
    for key in ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK"):
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank) == 0
    # No launcher variables at all: single-process run.
    return True


def train(config):
    """Run one training job described by ``config``.

    Seeds the RNGs, starts a W&B run on the rank-0 process, appends a
    timestamp to ``config.training.checkpoint_dir``, builds the data module
    (HuggingFace path when ``config.hf_dataset`` exists, local path
    otherwise), constructs ``AnyOrderInsertionFlowModule`` and fits it with a
    DDP ``pl.Trainer``.

    Args:
        config: OmegaConf config providing the ``wandb``, ``model`` and
            ``training`` sections read below.

    Raises:
        ValueError: If neither ``max_steps`` nor ``num_epochs`` is set, if
            ``warmup_steps`` must be derived but ``max_steps`` is unset, or
            if neither ``save_every_n_steps`` nor ``save_every_n_epochs`` is
            set.
    """
    wandb_logger = None

    # set the random seed
    pl.seed_everything(42)
    torch.manual_seed(42)

    # Only initialize wandb on rank 0 to avoid multiple runs
    is_rank_zero = _is_rank_zero()
    if is_rank_zero:
        wandb.init(
            project=config.wandb.project,
            name=config.wandb.name,
            config=OmegaConf.to_container(config, resolve=True),  # Convert to dict
            dir=config.wandb.path
        )
        wandb_logger = WandbLogger(
            project=wandb.run.project,
            name=wandb.run.name,
            log_model=False,  # Disable checkpoint uploading to save disk space
        )

    # Modify config to add timestamp to checkpoint directory
    OmegaConf.set_struct(config, False)
    time_string = datetime.now().strftime("%Y%m%d-%H%M%S")
    config.training.checkpoint_dir = os.path.join(
        config.training.checkpoint_dir, time_string
    )
    OmegaConf.set_struct(config, True)

    # Create checkpoint directory
    os.makedirs(config.training.checkpoint_dir, exist_ok=True)

    # Setup data module - check if using HuggingFace dataset
    if hasattr(config, 'hf_dataset'):
        # Imported lazily: the HF/SAFE path is only used by the molecule configs,
        # which keep mol_dataset.py (and its `safe` dependency) in a2d2_mol/.
        from mol_dataset import setup_hf_data_and_update_config
        print(f"Using HuggingFace dataset: {config.hf_dataset.name}")
        data_module = setup_hf_data_and_update_config(
            config,
            dataset_name=config.hf_dataset.name,
            smiles_column=config.hf_dataset.get('smiles_column', 'smiles')
        )
    else:
        # Imported lazily: the local (arrow) path is used by the peptide config,
        # which keeps dataloading_for_dynamic_batching.py in a2d2_pep/.
        from data.dataloading_for_dynamic_batching import setup_data_and_update_config
        print("Using local dataset")
        data_module = setup_data_and_update_config(config)

    module = AnyOrderInsertionFlowModule(config)

    # Configure trainer arguments
    # Map torch_dtype to Lightning precision (unknown values fall back to bf16-mixed)
    dtype_str = config.model.get('torch_dtype', 'bfloat16')
    precision_map = {
        'float32': '32-true',
        'float16': '16-mixed',
        'bfloat16': 'bf16-mixed'
    }
    precision = precision_map.get(dtype_str, 'bf16-mixed')

    trainer_kwargs = dict(
        num_nodes=config.training.nodes,
        accelerator="gpu",
        devices=config.training.devices,
        strategy="ddp",
        precision=precision,
        # global batch size divided by the number of samples per optimizer micro-step
        accumulate_grad_batches=(
            config.training.batch_size
            // (
                config.training.per_gpu_batch_size
                * config.training.nodes
                * config.training.devices
            )
        ),
        log_every_n_steps=10,
        enable_checkpointing=True,
        default_root_dir=config.training.checkpoint_dir,
        gradient_clip_val=1.0,
    )
    # Only one of max_steps or max_epochs will be used
    if config.training.max_steps is not None:
        trainer_kwargs["max_steps"] = config.training.max_steps
    elif config.training.num_epochs is not None:
        trainer_kwargs["max_epochs"] = config.training.num_epochs
    else:
        raise ValueError(
            "Either max_steps or num_epochs must be specified in the config"
        )

    if config.training.warmup_steps is None:
        # The default warmup is 1% of max_steps, which does not exist for
        # epoch-bounded runs; fail with a clear message instead of the
        # TypeError that None * 0.01 would raise.
        if config.training.max_steps is None:
            raise ValueError(
                "training.warmup_steps must be set explicitly when training is "
                "bounded by num_epochs (max_steps is unset)"
            )
        config.training.warmup_steps = int(config.training.max_steps * 0.01)

    # Keep the best checkpoints by training loss (the monitored metric is
    # train/total_loss, not a validation loss), plus the most recent one.
    checkpoint_callback = ModelCheckpoint(
        monitor="train/total_loss",
        mode="min",
        save_top_k=config.training.save_top_k,
        save_last=True,
        filename="epoch-{epoch:02d}-train_loss-{train/total_loss:.4f}",
        dirpath=config.training.checkpoint_dir,
        # Don't use val_loss in filename for periodic saves - causes failures when val doesn't run
        auto_insert_metric_name=False
    )

    # Add separate callback for periodic saves (no val_loss dependency). Use
    # step-based saves for streaming datasets (save_every_n_steps) and epoch-based
    # saves otherwise (save_every_n_epochs); whichever the config provides.
    save_every_n_steps = config.training.get('save_every_n_steps', None)
    save_every_n_epochs = config.training.get('save_every_n_epochs', None)
    if save_every_n_steps is not None:
        periodic_checkpoint_callback = ModelCheckpoint(
            save_top_k=-1,  # Save all periodic checkpoints
            filename="step-{step:08d}",
            dirpath=config.training.checkpoint_dir,
            every_n_train_steps=save_every_n_steps,
            auto_insert_metric_name=False
        )
    elif save_every_n_epochs is not None:
        periodic_checkpoint_callback = ModelCheckpoint(
            save_top_k=-1,  # Save all periodic checkpoints
            filename="epoch-{epoch:02d}",
            dirpath=config.training.checkpoint_dir,
            every_n_epochs=save_every_n_epochs,
            auto_insert_metric_name=False
        )
    else:
        raise ValueError(
            "Either save_every_n_steps or save_every_n_epochs must be specified in the config"
        )

    trainer_kwargs["callbacks"] = [checkpoint_callback, periodic_checkpoint_callback]

    if wandb_logger is not None:
        trainer_kwargs["logger"] = wandb_logger

    trainer = pl.Trainer(**trainer_kwargs)

    # Train the model, resuming from training.resume_path when the key is present
    ckpt_path = None
    if "resume_path" in config.training:
        ckpt_path = config.training.resume_path

    trainer.fit(module,
                datamodule=data_module,
                ckpt_path=ckpt_path)

    # Only finish wandb on rank 0
    if is_rank_zero:
        wandb.finish()
180
+
181
+
182
if __name__ == '__main__':
    # Only --config_name / --task are handled here; every other argument is
    # left untouched and forwarded to Hydra.
    cli = argparse.ArgumentParser()
    cli.add_argument('--config_name', type=str, default='config',
                     help='Name of the config file to use')
    cli.add_argument('--task', type=str, default=None,
                     help='Task name (uses config_{task}.yaml)')
    args, hydra_args = cli.parse_known_args()

    # --task wins over --config_name when both are given
    config_name = f'config_{args.task}' if args.task else args.config_name

    print(f"Using config: {config_name}.yaml")

    # Pass the config name to Hydra as a flag (this persists across DDP
    # subprocesses); it is inserted first so later flags can still override it.
    config_flag = f'--config-name={config_name}'
    already_given = '--config-name' in hydra_args or config_flag in hydra_args
    if not already_given:
        hydra_args.insert(0, config_flag)

    # Hand Hydra an argv that no longer contains our own options
    sys.argv = [sys.argv[0], *hydra_args]

    # The decorator's config_name is only a default; the flag above overrides it
    @hydra.main(version_base=None,
                config_path=CONFIG_DIR,
                config_name='config')
    def main(config):
        """Main entry point for training"""
        train(config)

    main()
assets/a2d2.gif ADDED

Git LFS Details

  • SHA256: 178ca7850ca39365492fea70cfc5e4f2e8653ceeda9a13dcd0438af61e1a83bb
  • Pointer size: 132 Bytes
  • Size of remote file: 7.83 MB
demo/quality_inference_demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
environment.yml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Conda environment shared across the molecule, peptide, and language experiments.
# Create with:
#   conda env create -f environment.yml
#   conda activate a2d2
#
# NOTE: flash-attn is hardware-specific and must be built against your installed torch
# and CUDA, so it is not listed below. It is imported by the shared transformer backbone
# (model/casual_transformer.py, model/rotary.py) and is required for all experiments.
# After creating the env, install it with:
#   pip install flash-attn==2.8.3 --no-build-isolation
# Adjust pytorch-cuda below to match your CUDA toolkit / GPU.
name: a2d2
channels:
  - pytorch
  - nvidia
  - conda-forge
dependencies:
  - python=3.11
  - pip
  - pytorch
  - pytorch-cuda=12.1
  - rdkit=2023.9.6
  - jupyterlab  # for demo/quality_inference_demo.ipynb
  - pip:
      # --- core scientific / DL stack ---
      - numpy==1.26.4
      - scipy==1.17.1
      - pandas==2.1.4
      - scikit-learn==1.8.0
      - pytorch-lightning==2.6.0
      - lightning==2.6.1
      - transformers==4.55.4
      - tokenizers==0.21.4
      - safetensors==0.7.0
      - accelerate==0.33.0
      - peft==0.15.1  # LoRA adapters (language experiment)
      - datasets==2.19.2
      - huggingface-hub==0.36.2
      - einops==0.8.2
      - timm==1.0.26
      - omegaconf==2.3.0
      - hydra-core==1.3.2  # train.py entry points use @hydra.main(version_base=None, ...)
      - wandb==0.26.1
      # --- molecule experiment ---
      - safe-mol==0.1.14
      - datamol==0.12.5
      - PyTDC==1.1.15
      # --- peptide experiment ---
      - SmilesPE==0.0.3
      - fair-esm==2.0.0
      - xgboost==3.2.0
      # --- plotting / utilities ---
      - matplotlib==3.10.6
      - seaborn==0.13.2
      - tqdm==4.67.1
      - joblib==1.5.3
      - loguru==0.7.3
      - fsspec==2024.3.1
lightning_modules/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .mdm import MaskedDiffusionModule
2
+ from .any_order import AnyOrderInsertionFlowModule
3
+
4
+
5
+ __all__ = [
6
+ "MaskedDiffusionModule",
7
+ "AutoregressiveModule",
8
+ "AnyOrderInsertionFlowModule",
9
+ ]
10
+
11
+
12
def __getattr__(name):
    """Resolve ``AutoregressiveModule`` lazily on first attribute access.

    The import of ``.autoregressive`` is deferred until the name is actually
    requested; any other unknown attribute raises ``AttributeError``.
    """
    if name != "AutoregressiveModule":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from .autoregressive import AutoregressiveModule
    return AutoregressiveModule
lightning_modules/any_length_remask.py ADDED
@@ -0,0 +1,801 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import pytorch_lightning as pl
5
+ from omegaconf import DictConfig
6
+ import torch.nn.functional as F
7
+ from model.transformer import AnyOrderMaskInsertionFlow
8
+ from model.interpolant import AnyOrderMaskInsertionInterpolant, ModelPrediction
9
+ from .bregman import jump_kernel_elbo, mse
10
+ from .schedule import get_schedule_from_config
11
+ from lightning_modules.any_order import AnyOrderInsertionFlowModule
12
+ from model.model_wrapper import RemaskingAnyOrder
13
+ from sampling import _sample_tokens
14
+
15
+ import re
16
+ from typing import Dict, Any
17
+ from dataclasses import dataclass
18
+
19
def strip_orig_mod_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build a copy of ``state_dict`` whose keys have every '._orig_mod.' segment
    collapsed to a single '.', e.g.
        'model._orig_mod.vocab_embed.embedding'
    turns into
        'model.vocab_embed.embedding'

    Values are carried over unchanged (not copied). Keys without the segment,
    including keys that merely start with '_orig_mod.', are left as they are.
    """
    return {
        name.replace("._orig_mod.", "."): tensor
        for name, tensor in state_dict.items()
    }
33
+
34
+
35
+ @torch.no_grad()
36
+ def _binary_auc(scores: torch.Tensor, labels: torch.Tensor) -> float:
37
+ """Rank-based AUROC (Mann-Whitney U statistic).
38
+
39
+ AUC = P(score[pos] > score[neg]); 0.5 means no discrimination. Returns NaN
40
+ when only one class is present (AUC undefined). Ties are not averaged, which
41
+ is fine for continuous logits used here.
42
+ """
43
+ scores = scores.float().reshape(-1)
44
+ labels = labels.float().reshape(-1)
45
+ n_pos = labels.sum()
46
+ n_neg = labels.numel() - n_pos
47
+ if n_pos == 0 or n_neg == 0:
48
+ return float("nan")
49
+ order = torch.argsort(scores)
50
+ ranks = torch.empty_like(scores)
51
+ ranks[order] = torch.arange(1, scores.numel() + 1, device=scores.device, dtype=scores.dtype)
52
+ auc = (ranks[labels == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
53
+ return auc.item()
54
+
55
+
56
+ class AnyOrderInsertionFlowModuleFT(AnyOrderInsertionFlowModule):
57
+ """
58
+ Wrapper around AnyOrderInsertionFlowModule that adds adaptive schedule model
59
+ for fine-tuning. Can load a pretrained AnyOrderInsertionFlowModule checkpoint
60
+ and add the schedule model on top.
61
+ """
62
def __init__(self, config, args, pretrained_checkpoint, insertion_planner=False):
    """Build the base module, optionally load pretrained weights, then attach the planner.

    Args:
        config: Configuration forwarded to the parent constructor; this class
            also reads ``config.model.hidden_size`` (via ``self.config``).
        args: Run arguments stored on ``self.args`` (excluded from saved
            hyperparameters).
        pretrained_checkpoint: Path handed to ``load_pretrained_model``, or
            ``None`` to skip loading (excluded from saved hyperparameters).
        insertion_planner: Stored on ``self.insertion_planner`` and forwarded
            to ``RemaskingAnyOrder``.
    """
    # Initialize parent class first
    super().__init__(config)

    self.args = args
    self.insertion_planner = insertion_planner

    # Save hyperparameters for this class (overrides parent's save)
    self.save_hyperparameters(ignore=['pretrained_checkpoint', 'args'])

    # Load pretrained model weights BEFORE initializing planner to avoid circular reference
    if pretrained_checkpoint is not None:
        self.load_pretrained_model(pretrained_checkpoint)

    # Initialize adaptive schedule model AFTER loading pretrained weights
    # (the planner receives this module itself as its backbone).
    self.planner = RemaskingAnyOrder(
        backbone=self,
        d_model=self.config.model.hidden_size,
        insertion_planner=insertion_planner)
81
+
82
def load_pretrained_model(self, checkpoint_path: str):
    """
    Restore base-model weights from a pretrained checkpoint file.

    Keys under 'planner.' are dropped before loading, loading is non-strict,
    and any key mismatches are reported on stdout. When
    ``config.training.freeze_base_model`` is truthy, every parameter whose
    name does not start with 'planner.' has ``requires_grad`` switched off.
    """
    print(f"Loading pretrained model from {checkpoint_path}")
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # A checkpoint may wrap its weights under 'state_dict' or be the weights themselves.
    raw_state = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt

    # Normalise '._orig_mod.' key segments, then discard planner entries
    # (possibly left over from an earlier fine-tuning checkpoint).
    base_state_dict = {}
    for key, value in strip_orig_mod_keys(raw_state).items():
        if key.startswith('planner.'):
            continue
        base_state_dict[key] = value

    # Non-strict so keys absent from the checkpoint do not abort the load.
    load_result = self.load_state_dict(base_state_dict, strict=False)

    # Split the missing keys into planner ones (expected) and everything else.
    planner_missing = []
    unexpected_missing = []
    for key in load_result.missing_keys:
        bucket = planner_missing if key.startswith('planner.') else unexpected_missing
        bucket.append(key)

    if unexpected_missing:
        print(f"Warning: Unexpected missing keys from pretrained checkpoint: {unexpected_missing}")
    if planner_missing:
        print(f"Note: Planner will be trained from scratch ({len(planner_missing)} parameters)")
    if load_result.unexpected_keys:
        print(f"Warning: Unexpected keys in pretrained checkpoint: {load_result.unexpected_keys}")

    # Optional freezing of everything outside the planner.
    if not self.config.training.get('freeze_base_model', False):
        return

    print("Freezing base model parameters")
    for name, param in self.named_parameters():
        if name.startswith('planner.'):
            continue
        param.requires_grad = False
126
+
127
def forward(self, x, t, return_features=False):
    """Delegate unchanged to the parent class's ``forward``.

    Args:
        x: Input passed through to the parent.
        t: Time argument passed through to the parent.
        return_features: Forwarded as the keyword of the same name.
    """
    # Use parent class forward method
    return super().forward(x, t, return_features=return_features)
130
+
131
def training_loss(self, x1, t):
    """Return ``(unmask_loss, insertion_loss, total_loss)`` from the parent's ``training_loss``.

    The planner contributes nothing here; it is optimised through the separate
    planner loss methods.
    """
    # Use parent class training_loss for base model loss
    # Planner is trained separately via loss_planner_flexible with reward gradients
    unmask_loss, insertion_loss, total_loss = super().training_loss(x1, t)
    return unmask_loss, insertion_loss, total_loss
136
+
137
+
138
def training_step(self, batch, batch_idx):
    """Run one base-model training step and return the total loss.

    ``batch`` is either the input tensor itself or a dict holding it under
    "input_ids". One time value per row is drawn with ``sample_time``; the
    unmask, length and total losses are logged under ``train/...``.
    The planner is not trained here.
    """
    x1 = batch["input_ids"] if isinstance(batch, dict) else batch
    t = self.sample_time(x1.shape[0], x1.device)

    unmask_loss, len_loss, loss = self.training_loss(x1, t)

    # Log each component in a fixed order.
    for tag, value in (
        ("train/unmask_loss", unmask_loss),
        ("train/len_loss", len_loss),
        ("train/total_loss", loss),
    ):
        self.log(tag, value, prog_bar=True)

    return loss
155
+
156
def validation_step(self, batch, batch_idx):
    """Evaluate the base-model loss on one validation batch and return it.

    Accepts the same batch formats as ``training_step``. The three losses are
    logged with ``sync_dist=True``; the total is logged under "val_loss".
    """
    x1 = batch["input_ids"] if isinstance(batch, dict) else batch
    t = self.sample_time(x1.shape[0], x1.device)
    unmask_loss, len_loss, loss = self.training_loss(x1, t)

    for tag, value in (
        ("val/unmask_loss", unmask_loss),
        ("val/len_loss", len_loss),
        ("val_loss", loss),
    ):
        self.log(tag, value, prog_bar=True, sync_dist=True)

    return loss
169
+
170
@classmethod
def load_from_checkpoint(cls, checkpoint_path, map_location=None, strict=True, **kwargs):
    """
    Custom checkpoint loading that handles finetuned checkpoints wrapped by PeptideFinetuner.
    Extracts config from original pretrained checkpoint and loads finetuned weights.

    Args:
        checkpoint_path: Path of the finetuned checkpoint file.
        map_location: Passed to ``torch.load``; falls back to 'cpu' when falsy.
        strict: Accepted but not used by this implementation.
        **kwargs: Accepted but not used by this implementation.

    Returns:
        A new instance of ``cls`` with the ``policy_model.*`` weights loaded
        non-strictly, plus EMA and planner state when present.

    Raises:
        ValueError: If the wrapped checkpoint has no 'args' hyperparameter, or
            the fallback original checkpoint has no 'config' entry.
        NotImplementedError: If no state-dict key starts with 'policy_model.'.
    """
    print(f"Loading finetuned checkpoint from {checkpoint_path}")
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=map_location or 'cpu', weights_only=False)

    # Check if this is a wrapped checkpoint (from PeptideFinetuner)
    hparams = checkpoint.get('hyper_parameters', {})
    state_dict = checkpoint.get('state_dict', {})

    # Check for policy_model prefix in state_dict (indicates PeptideFinetuner wrapper)
    has_policy_prefix = any(k.startswith('policy_model.') for k in state_dict.keys())

    if has_policy_prefix:
        # Detect model type (molecule vs peptide) based on vocab size in checkpoint
        # Molecule models have vocab size ~1882, peptide models have ~587
        vocab_size = None
        for k, v in state_dict.items():
            if 'vocab_embed.embedding' in k:
                vocab_size = v.shape[0]
                break

        # Heuristic threshold between the two vocabularies; a missing embedding
        # key is treated as a peptide model.
        is_molecule_model = vocab_size is not None and vocab_size > 1000
        model_type = "MolFinetuner" if is_molecule_model else "PeptideFinetuner"
        print(f"Detected wrapped finetuned checkpoint ({model_type}, vocab_size={vocab_size})")

        # Extract args from hyperparameters
        if 'args' not in hparams:
            raise ValueError(f"Cannot find 'args' in hyperparameters. This checkpoint may not be from {model_type}.")

        args = hparams['args']
        print(f"Found args in hyperparameters, type: {type(args)}")

        # Get original checkpoint path from args
        # Handle both Namespace (hasattr) and dict (get) access patterns
        original_ckpt_path = None
        if hasattr(args, 'checkpoint_path'):
            original_ckpt_path = args.checkpoint_path
        elif isinstance(args, dict) and 'checkpoint_path' in args:
            original_ckpt_path = args['checkpoint_path']

        # If checkpoint_path is not set or is None, use default pretrained checkpoint
        # Select appropriate default based on detected model type
        if original_ckpt_path is None:
            # Two directory levels above this file.
            _repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            if is_molecule_model:
                original_ckpt_path = os.path.join(_repo_root, 'pretrained', 'anylength_mol.ckpt')
                print(f"Warning: checkpoint_path not found in args, using default molecule pretrained checkpoint")
            else:
                original_ckpt_path = os.path.join(_repo_root, 'pretrained', 'anylength_pep.ckpt')
                print(f"Warning: checkpoint_path not found in args, using default peptide pretrained checkpoint")

        # Try to load config directly from checkpoint first (new checkpoints)
        # Fall back to loading from original checkpoint (old checkpoints)
        if 'config' in checkpoint:
            print("Found config directly in checkpoint")
            config = checkpoint['config']
        else:
            print(f"Config not in checkpoint, loading from original checkpoint: {original_ckpt_path}")

            # Load config from original pretrained checkpoint
            orig_ckpt = torch.load(original_ckpt_path, map_location='cpu', weights_only=False)
            if 'config' not in orig_ckpt:
                raise ValueError(f"Original checkpoint {original_ckpt_path} does not contain config")

            config = orig_ckpt['config']

        # Ensure adaptive schedule is enabled
        # Need to disable struct mode to add new keys to OmegaConf config
        from omegaconf import OmegaConf
        if hasattr(config, 'training'):
            OmegaConf.set_struct(config, False)
            config.training.use_adaptive_schedule = True
            # NOTE(review): struct mode is switched on unconditionally here,
            # even if the config was not in struct mode before — confirm intended.
            OmegaConf.set_struct(config, True)

        # Create args object if needed
        if not hasattr(args, '__dict__'):
            # Convert dict to object with attributes
            class Args:
                pass
            args_obj = Args()
            for k, v in args.items():
                setattr(args_obj, k, v)
            args = args_obj

        # Initialize model with config and args
        model = cls(
            config=config,
            args=args,
            pretrained_checkpoint=None,  # Don't reload pretrained, weights already in checkpoint
            insertion_planner=getattr(args, 'insertion_planner', False)
        )

        # Extract policy_model weights from state_dict
        policy_state = {}
        for k, v in state_dict.items():
            if k.startswith('policy_model.'):
                # Strip 'policy_model.' prefix
                new_key = k[len('policy_model.'):]
                policy_state[new_key] = v

        # Load the finetuned weights
        incompatible = model.load_state_dict(policy_state, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(f"Warning: Incompatible keys when loading finetuned weights:")
            # Only the first five keys of each kind are printed.
            if incompatible.missing_keys:
                print(f"  Missing: {incompatible.missing_keys[:5]}...")
            if incompatible.unexpected_keys:
                print(f"  Unexpected: {incompatible.unexpected_keys[:5]}...")

        # Initialize or load EMA params
        if model.use_ema:
            if "ema_params" in checkpoint:
                # Load EMA params from checkpoint
                model.ema_params = checkpoint["ema_params"]
                print("Loaded EMA params from checkpoint")
            else:
                # No EMA in the checkpoint: seed it with detached copies of the current parameters
                model.ema_params = {
                    name: param.clone().detach()
                    for name, param in model.named_parameters()
                }
                print("Initialized EMA params from current model state")
        else:
            model.ema_params = {}

        # Load planner state if it exists
        if "planner_state" in checkpoint and hasattr(model, 'planner'):
            model.planner.load_state_dict(checkpoint["planner_state"], strict=False)
            print("Loaded planner state from checkpoint")

        return model
    else:
        # Not a wrapped checkpoint, use default Lightning loading
        # But we still need to provide required __init__ arguments
        raise NotImplementedError(
            "Direct finetuned checkpoints (not wrapped by PeptideFinetuner) are not yet supported. "
            "Please provide config and args as kwargs."
        )
312
+
313
def on_save_checkpoint(self, checkpoint):
    """Save config and EMA params, including planner state.

    Runs the parent hook first, then stores the planner's ``state_dict`` in
    ``checkpoint["planner_state"]`` when a planner attribute exists.
    """
    # Call parent to save config and base model EMA
    super().on_save_checkpoint(checkpoint)

    # Explicitly save planner state
    if hasattr(self, 'planner'):
        checkpoint["planner_state"] = self.planner.state_dict()
321
+
322
def on_load_checkpoint(self, checkpoint):
    """Restore config/interpolant (or just EMA) and the planner from a checkpoint.

    If the checkpoint carries a "config" entry the parent hook does the
    restoring. Otherwise the config is assumed to have been set at construction
    time and only the EMA parameters are taken over (when EMA is enabled and
    they are present). Planner weights are loaded last if both a planner and a
    "planner_state" entry exist.
    """
    if "config" in checkpoint:
        # Parent restores config and interpolant.
        super().on_load_checkpoint(checkpoint)
    elif self.use_ema and "ema_params" in checkpoint:
        # Finetuned checkpoint without config: only EMA needs restoring.
        self.ema_params = checkpoint["ema_params"]

    if not hasattr(self, 'planner') or "planner_state" not in checkpoint:
        return

    self.planner.load_state_dict(checkpoint["planner_state"])
    print("Loaded planner from checkpoint")
339
+
340
+ def loss_wdce_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0):
341
+ r"""
342
+ Weighted denoising cross entropy loss
343
+ X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
344
+
345
+ log_rnd: [B] — pre-computed importance weights (already softmax-normalized over the full buffer)
346
+ x: [B, L] (no mask)
347
+ num_replicates: R, number of replicates of each row in x
348
+ weight_func: w(lambda) for each sample, 1/lambda by default
349
+ centering_strength: float, controls how much of the mean is subtracted (DMPO-style)
350
+ softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights)
351
+ """
352
+
353
+ batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
354
+
355
+ batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) # [B]
356
+ if centering:
357
+ batch_weights = batch_weights - centering_strength * batch_weights.mean()
358
+
359
+ batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)
360
+
361
+ lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
362
+ lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
363
+
364
+ t = lamda
365
+
366
+ # compute unmasking and insertion loss
367
+ interpolant_sample = self.interpolant.sample_interpolant(t, batch)
368
+ unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch)
369
+
370
+ prediction: ModelPrediction = self(interpolant_sample.xt, t)
371
+
372
+ scale_factor = self.config.interpolant.max_length
373
+
374
+ match self.unmask_loss_fn:
375
+ case "elbo":
376
+ mask_indices = interpolant_sample.mask_indices
377
+ unmask_loss_all = torch.zeros_like(unmask_weight) # [B*R, L]
378
+ unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy(
379
+ prediction.token_logits[mask_indices],
380
+ interpolant_sample.unmasked[mask_indices],
381
+ reduction="none",
382
+ )
383
+ unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor # [B*R]
384
+ case _:
385
+ raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}")
386
+
387
+ match self.insert_loss_fn:
388
+ case "expectation":
389
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
390
+ insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1]
391
+ insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo(
392
+ gaps[gaps_mask], prediction.expected_gaps[gaps_mask]
393
+ )
394
+ insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R]
395
+
396
+ case "distribution":
397
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
398
+ insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1]
399
+ insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy(
400
+ prediction.length_posterior[gaps_mask], gaps[gaps_mask]
401
+ )
402
+ insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R]
403
+
404
+ total_loss = unmask_loss + insertion_loss # [B*R]
405
+ # end compute unmasking and insertion loss
406
+
407
+ weighted_loss = total_loss * batch_weights # [B*R]
408
+ return weighted_loss.mean()
409
+
410
def one_step_sampler(self, xt, t, pred_rate=None):
    """
    Sample one step of unmasking (and optionally mask insertion) using model predictions.

    Args:
        xt: Current state [B, L]
        t: Time [B]
        pred_rate: Optional pre-computed ModelPrediction. If None, will compute from model.

    Returns:
        When ``self.insertion_planner`` is False:
            new_xt: Next state [B, L]
            update_ids: Boolean mask of updated positions [B, L]
        Otherwise a 4-tuple that additionally contains:
            new_ins_xt: State after inserting new mask tokens [B, L]
            update_ins_ids: Boolean mask derived from the inserted masks [B, L]
    """
    mask = self.interpolant.mask_token
    pad = self.interpolant.pad_token
    batch_size, L = xt.shape
    device = xt.device
    steps = self.args.total_num_steps
    dt = 1.0 / steps

    # ——— predict and convert rates ———
    if pred_rate is None:
        pred_rate = self(xt, t)
    pred_rate = self.interpolant.to_actual_rate(xt, pred_rate, t)
    unmask_rate = pred_rate.unmask_rate  # (B, L, V)
    len_rate = pred_rate.length_rate  # (B, L+1)

    # ——— unmask step (Euler) ———
    mask_pos = (xt == mask).nonzero(as_tuple=True)
    unmask_rate[xt != mask] = 0
    unmask_rate[mask_pos + (mask,)] = 0
    # Diagonal entry of the rate matrix: minus the total outgoing rate.
    unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

    # add "stay" probability (pad positions are treated as mask for indexing)
    _xt = xt.clone()
    _xt[xt == pad] = mask
    trans_prob.scatter_add_(
        2,
        _xt.unsqueeze(-1),
        torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
    )

    trans_prob[mask_pos + (mask,)] = 0.0  # remove mask token from sampling at the last step

    # Renormalize probabilities at masked positions to ensure they sum to 1.
    # Rows with non-zero mass are always normalised; rows whose mass is exactly
    # zero fall back to a uniform distribution over token ids 0..mask-1.
    # (Previously the non-zero rows were skipped whenever any zero row existed.)
    masked_probs = trans_prob[mask_pos]  # (N, V)
    prob_sum = masked_probs.sum(dim=-1, keepdim=True)  # (N, 1)
    mask_has_zero_prob = prob_sum.squeeze(-1) == 0.0  # (N,)
    # Divide zero rows by 1 to avoid 0/0; they are overwritten right below.
    safe_sum = torch.where(prob_sum == 0.0, torch.ones_like(prob_sum), prob_sum)
    masked_probs = masked_probs / safe_sum
    if mask_has_zero_prob.any():
        uniform_row = torch.zeros(trans_prob.shape[-1], device=device, dtype=trans_prob.dtype)
        uniform_row[:mask] = 1.0 / mask  # Uniform over tokens 0 to mask-1
        masked_probs[mask_has_zero_prob] = uniform_row
    trans_prob[mask_pos] = masked_probs

    new_xt = _sample_tokens(trans_prob)
    new_xt[xt == pad] = pad
    # Tokens that were already decoded are never changed by this step.
    new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

    # Updated positions: non-pad positions whose token changed, i.e. mask
    # tokens that were just decoded.
    update_ids = (xt != new_xt) & (xt != pad)

    if self.insertion_planner is False:
        return new_xt, update_ids

    # Index grids sized by the actual tensor width L (not max_length) so that
    # replicated batches are handled.
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, L)
    )
    pos_idx_L = (
        torch.arange(L, device=device)
        .view(1, L)
        .expand(batch_size, L)
    )

    # ——— Poisson insertion (tau-leaping) — can insert multiple masks per gap ———
    ext = torch.poisson(len_rate * dt).long()  # (B, L+1)
    xt_len = xt.ne(pad).sum(dim=1)  # (B,)
    # Use ext.shape[1] to get the actual max_length dimension from the data
    actual_max_length = ext.shape[1] - 1  # ext is (B, L+1), so L = ext.shape[1] - 1
    gaps = torch.arange(ext.shape[1], device=device).view(1, -1)
    # Only gaps inside the current sequence may receive insertions.
    ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
    total_ext = ext.sum(dim=1)
    # Drop all insertions for rows that would exceed the available length.
    valid = xt_len + total_ext <= actual_max_length
    ext = ext * valid.view(batch_size, 1).long()

    ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
    new_len = xt_len + total_ext  # (B,)

    # Start from all-pad, fill the new sequence length with masks, then write
    # the existing tokens at their shifted positions.
    xt_tmp = torch.full_like(xt, pad)
    pos_idx_for_fill = torch.arange(xt_tmp.shape[1], device=device).view(1, -1).expand(batch_size, -1)
    mask_fill = pos_idx_for_fill < new_len.view(batch_size, 1)
    xt_tmp[mask_fill] = mask

    new_pos_orig = pos_idx_L + ext_ex[:, :actual_max_length]  # (B, L)
    orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
    flat_b = batch_idx_L[orig_mask]
    flat_p = new_pos_orig[orig_mask]
    xt_tmp[flat_b, flat_p] = new_xt[orig_mask]

    new_ins_xt = xt_tmp

    # Newly inserted masks: positions that are mask now but weren't before.
    # NOTE(review): the `xt != pad` term excludes positions that were padding
    # in xt, so masks landing beyond the previous sequence end are not
    # flagged — confirm this is intended.
    newly_inserted_masks = (new_ins_xt == mask) & (xt != mask) & (xt != pad)

    update_ins_ids = newly_inserted_masks

    return new_xt, update_ids, new_ins_xt, update_ins_ids
534
+
535
def loss_planner_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0):
    r"""
    Importance-weighted binary cross entropy loss for the remasking planner head.

    Each row of ``x`` is replicated, noised by the interpolant at a uniformly
    sampled time, decoded for one step with ``one_step_sampler``, and the
    planner is trained to predict, for every position updated in that step,
    whether the decoded token equals the ground-truth token.

    log_rnd: [B] — log importance weights; detached, divided by
        ``softmax_temperature`` and softmax-normalised over the batch here
    x: [B, L] (no mask)
    num_replicates: R, number of replicates of each row in x
    weight_func: kept for interface compatibility; not used
    eps: kept for interface compatibility; not used
    centering: if True, subtract ``centering_strength * mean`` from the
        normalised weights
    centering_strength: float, controls how much of the mean is subtracted (DMPO-style)
    softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights)

    Returns the scalar mean of the weighted per-replicate losses. As a side
    effect, diagnostic metrics are stored in ``self._last_planner_metrics``
    (an empty dict when no position was updated).
    """
    batch = x.repeat_interleave(num_replicates, dim=0)  # [B*R, L]

    batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1)  # [B]
    if centering:
        batch_weights = batch_weights - centering_strength * batch_weights.mean()
    batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)  # [B*R]

    # One uniformly sampled time per replicate.
    t = torch.rand(batch.shape[0], device=batch.device)  # [B*R]

    interpolant_sample = self.interpolant.sample_interpolant(t, batch)

    # The backbone prediction is only consumed by the sampling step, which
    # needs no gradient, so the forward pass runs under no_grad as well to
    # avoid building an autograd graph that is never back-propagated.
    with torch.no_grad():
        prediction: ModelPrediction = self(interpolant_sample.xt, t)
        sampler_out = self.one_step_sampler(interpolant_sample.xt, t, prediction)
        # one_step_sampler returns (xs, update_ids) or (xs, update_ids, new_ins_xt, update_ins_ids)
        xs, update_ids = sampler_out[0], sampler_out[1]

    # The remasking head scores the freshly-decoded tokens to decide which to
    # remask, so it reads the POST-unmask state xs (matching inference, which
    # calls the planner on the decoded new_xt). Gradients flow from here.
    planner = self.planner(xs, t)
    remasking_conf = planner["remasking_conf"]  # [B*R, L, 1]

    # IMPORTANT: interpolant_sample.xt has been reordered via st permutation,
    # so the ground truth is gathered with the same indices before comparing.
    st = interpolant_sample.st  # [B*R, L] permutation indices
    batch_reordered = torch.gather(batch, 1, st)

    # Label 1 where the decoded token matches the ground truth.
    binary_label = (xs == batch_reordered).float()

    per_token_loss = F.binary_cross_entropy_with_logits(
        remasking_conf.squeeze(-1),  # [B*R, L]
        binary_label,  # [B*R, L]
        reduction="none"  # [B*R, L]
    )

    # Keep only positions updated in this step and average per sample.
    per_token_loss = per_token_loss * update_ids.float()  # [B*R, L]
    per_sample_loss = per_token_loss.sum(dim=1) / (update_ids.sum(dim=1).float() + 1e-8)  # [B*R]

    # Weight by importance sampling weights
    weighted_loss = per_sample_loss * batch_weights  # [B*R]

    # ——— AUC / label-balance diagnostics over the updated positions ———
    with torch.no_grad():
        metrics = {}
        sel_u = update_ids.bool()
        if sel_u.any():
            u_scores = remasking_conf.squeeze(-1)[sel_u]
            u_labels = binary_label[sel_u]
            metrics["unmask_auc"] = _binary_auc(u_scores, u_labels)
            metrics["unmask_label_mean"] = u_labels.mean().item()
            metrics["unmask_conf_mean"] = torch.sigmoid(u_scores).mean().item()
            metrics["unmask_n"] = float(sel_u.sum().item())
        self._last_planner_metrics = metrics

    return weighted_loss.mean()
617
+
618
def loss_insert_planner_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0):
    r"""
    Importance-weighted training loss for the planner heads.

    Every row of ``x`` is replicated ``num_replicates`` times, noised at a
    uniformly drawn time and decoded by one call to ``self.one_step_sampler``.
    The remasking head is trained with binary cross-entropy to predict whether
    each freshly unmasked token equals the ground truth. When
    ``self.planner.insertion_planner`` is truthy, the insertion head is trained
    towards a soft target: the probability mass the backbone assigns to the
    tokens that were deleted from the corresponding gap.

    X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)

    Args:
        log_rnd: [B] pre-computed importance log-weights; detached and
            softmax-normalised over the batch before use.
        x: [B, L] clean sequences (no mask).
        num_replicates: R, number of replicates of each row in x.
        weight_func: accepted for signature compatibility; not used here.
        eps: accepted for signature compatibility; not used here.
        centering: if True, subtract ``centering_strength * mean`` from the
            softmax weights.
        centering_strength: float, controls how much of the mean is
            subtracted (DMPO-style).
        softmax_temperature: float, temperature for softmax on log_rnd
            (>1 smooths weights).

    Returns:
        Tuple ``(unmask_loss, insert_loss, weighted_loss)`` of scalar tensors:
        the unweighted means of the two per-sample losses and the mean of
        their importance-weighted sum. ``insert_loss`` is zero when the
        insertion planner is disabled.

    Side effects:
        Stores AUC / label-balance diagnostics in ``self._last_planner_metrics``.
    """
    batch = x.repeat_interleave(num_replicates, dim=0)  # [B*R, L]
    batch_size = batch.shape[0]

    batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1)  # [B]
    if centering:
        batch_weights = batch_weights - centering_strength * batch_weights.mean()
    batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)  # [B*R]

    # One uniform time per replicate.
    t = torch.rand(batch_size, device=batch.device)  # [B*R]

    # gap_assignment: [B*R, max_gaps, L] maps x1 positions to gap indices.
    # The deleted-token mask returned alongside it is not needed here.
    interpolant_sample, _deleted_mask, gap_assignment = self.interpolant.sample_interpolant_plan(t, batch)

    prediction: ModelPrediction = self(interpolant_sample.xt, t)

    with torch.no_grad():  # the sampler step itself is not trained
        xs_unmask, update_unmask_ids, xs_insert, update_ins_ids = self.one_step_sampler(interpolant_sample.xt, t, prediction)

    # The remasking head scores the freshly-decoded tokens to decide which to
    # remask, so it must see the POST-unmask state xs_unmask (matching
    # inference in inference_quality.py, which calls the planner on the
    # decoded new_xt). Grad stays on here since this head is what we train.
    remasking_conf = self.planner(xs_unmask, t)["remasking_conf"]  # [B*R, L, 1]

    # The insertion-quality head scores the freshly-inserted mask tokens, so
    # it must see the POST-insertion state xs_insert (aligned with
    # update_ins_ids / insertion_quality below, and matching inference in
    # remasking_scheduleaware.apply_schedule_aware_insertion). Grad stays on
    # here since this head is what we are training.
    if self.planner.insertion_planner:
        insertion_conf = self.planner(xs_insert, t)["insertion_conf"]  # [B*R, L, 1]
    else:
        insertion_conf = None

    # ——— remasking (unmask) loss ———
    # IMPORTANT: interpolant_sample.xt has been reordered via the st
    # permutation, so the ground truth is gathered with the same permutation
    # before it is compared with the decoded tokens.
    st = interpolant_sample.st  # [B*R, L] permutation indices
    batch_reordered = torch.gather(batch, 1, st)
    binary_label = (xs_unmask == batch_reordered).float()

    per_token_loss = F.binary_cross_entropy_with_logits(
        remasking_conf.squeeze(-1),  # [B*R, L]
        binary_label,                # [B*R, L]
        reduction="none",
    )
    # Only positions updated by the sampler contribute.
    per_token_loss = per_token_loss * update_unmask_ids.float()
    unmask_per_sample_loss = per_token_loss.sum(dim=1) / (update_unmask_ids.sum(dim=1).float() + 1e-8)  # [B*R]

    # ——— insertion planner loss ———
    if insertion_conf is not None:
        # The insertion-quality target needs a second backbone pass on
        # xs_insert (the first prediction was computed from xt, before
        # insertion). It is only computed when the insertion head exists,
        # because nothing else consumes it.
        with torch.no_grad():
            prediction_after_insert: ModelPrediction = self(xs_insert, t)

        token_probs = F.softmax(prediction_after_insert.token_logits, dim=-1)  # [B*R, L, V]

        vocab_size = token_probs.shape[-1]
        L = token_probs.shape[1]
        max_gaps = gap_assignment.shape[1]

        # gap_vocab_mask[b, gap_idx, token_id] = 1 if token_id was deleted in gap gap_idx
        gap_vocab_mask = torch.zeros(batch_size, max_gaps, vocab_size, device=batch.device, dtype=torch.float)

        # tokens_expanded[b, gap_idx, pos] = batch[b, pos] for all positions
        tokens_expanded = batch.unsqueeze(1).expand(batch_size, max_gaps, L)  # [B*R, max_gaps, L]

        # valid_mask[b, gap_idx, pos] = 1 if position pos belongs to gap gap_idx and is not pad
        valid_mask = (gap_assignment > 0) & (tokens_expanded != self.interpolant.pad_token)

        # Scatter tokens into the vocabulary dimension, then binarize: a
        # token either appears in the gap or it does not.
        gap_vocab_mask.scatter_add_(
            2,
            tokens_expanded.clamp(0, vocab_size - 1),
            valid_mask.float(),
        )
        gap_vocab_mask = (gap_vocab_mask > 0).float()  # [B*R, max_gaps, V]

        # Position p in xs_insert corresponds to gap p (insertions occur
        # between existing tokens), so gaps 0..L-1 line up with positions.
        insertion_quality_full = (token_probs * gap_vocab_mask[:, :L, :]).sum(dim=-1)  # [B*R, L]

        # Only consider quality at positions where masks were actually inserted.
        insertion_quality = insertion_quality_full * update_ins_ids.float()  # [B*R, L]

        # Bernoulli cross-entropy with insertion_quality as a continuous
        # label in [0, 1].
        ins_per_token_loss = F.binary_cross_entropy_with_logits(
            insertion_conf.squeeze(-1),  # [B*R, L] insertion confidence logits
            insertion_quality,           # [B*R, L] soft target
            reduction="none",
        )
        ins_per_token_loss = ins_per_token_loss * update_ins_ids.float()
        ins_per_sample_loss = ins_per_token_loss.sum(dim=1) / (update_ins_ids.sum(dim=1).float() + 1e-8)
    else:
        # No insertion planner - the insertion loss is zero.
        insertion_quality = None
        ins_per_sample_loss = torch.zeros_like(unmask_per_sample_loss)

    per_sample_loss = unmask_per_sample_loss + ins_per_sample_loss

    # Weight by importance sampling weights.
    weighted_loss = per_sample_loss * batch_weights  # [B*R]

    # ——— AUC / label-balance diagnostics (the loss alone hides degenerate
    # targets; near-0 BCE can mean "all labels one class", not "learned") ———
    with torch.no_grad():
        metrics = {}
        sel_u = update_unmask_ids.bool()
        if sel_u.any():
            u_scores = remasking_conf.squeeze(-1)[sel_u]
            u_labels = binary_label[sel_u]
            metrics["unmask_auc"] = _binary_auc(u_scores, u_labels)
            metrics["unmask_label_mean"] = u_labels.mean().item()
            metrics["unmask_conf_mean"] = torch.sigmoid(u_scores).mean().item()
            metrics["unmask_n"] = float(sel_u.sum().item())
        if insertion_conf is not None:
            sel_i = update_ins_ids.bool()
            if sel_i.any():
                i_scores = insertion_conf.squeeze(-1)[sel_i]
                i_targets = insertion_quality[sel_i]
                # Soft targets are thresholded at 0.5 for the AUC only.
                i_labels = (i_targets > 0.5).float()
                metrics["insert_auc"] = _binary_auc(i_scores, i_labels)
                metrics["insert_target_mean"] = i_targets.mean().item()
                metrics["insert_conf_mean"] = torch.sigmoid(i_scores).mean().item()
                metrics["insert_n"] = float(sel_i.sum().item())
        self._last_planner_metrics = metrics

    return unmask_per_sample_loss.mean(), ins_per_sample_loss.mean(), weighted_loss.mean()
lightning_modules/any_order.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ from omegaconf import DictConfig
4
+ import torch.nn.functional as F
5
+ from model.transformer import AnyOrderMaskInsertionFlow
6
+ from model.interpolant import AnyOrderMaskInsertionInterpolant, ModelPrediction
7
+ from .bregman import jump_kernel_elbo, mse
8
+ from .schedule import get_schedule_from_config
9
+
10
+
11
+ import re
12
+ from typing import Dict, Any
13
+
14
+
15
def strip_orig_mod_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a copy of ``state_dict`` with every '._orig_mod.' key segment
    collapsed to a single '.', e.g.
      'model._orig_mod.vocab_embed.embedding'
    becomes
      'model.vocab_embed.embedding'

    Values are carried over untouched and the input mapping is not modified.
    Keys without a '._orig_mod.' segment are kept as they are.
    """
    return {
        name.replace("._orig_mod.", "."): entry
        for name, entry in state_dict.items()
    }
29
+
30
+
31
+ class AnyOrderInsertionFlowModule(pl.LightningModule):
32
def __init__(self, config: DictConfig):
    """
    Build the backbone model and the interpolant from ``config``.

    Reads ``config.interpolant`` (type, insert/unmask schedules, vocabulary
    size, mask/pad token ids, max length) and ``config.training`` (learning
    rate, loss-function names, EMA decay).
    """
    super().__init__()
    self.config = config
    self.model_type = config.interpolant.type
    self.learning_rate = config.training.learning_rate
    # Names of the loss variants dispatched on by the loss methods.
    self.unmask_loss_fn = config.training.loss_fn.unmask
    self.insert_loss_fn = config.training.loss_fn.insert

    # Initialize model based on type
    self.model = AnyOrderMaskInsertionFlow(config)
    # self.model = torch.compile(self.model)  # Disabled: incompatible with flex_attention nested functions

    insert_schedule = get_schedule_from_config(config.interpolant.insert_schedule)
    unmask_schedule = get_schedule_from_config(config.interpolant.unmask_schedule)

    # Initialize interpolant
    self.interpolant = AnyOrderMaskInsertionInterpolant(
        insertion_schedule=insert_schedule,
        unmask_schedule=unmask_schedule,
        vocab_size=config.interpolant.tokens,
        mask_token=config.interpolant.mask_token,
        pad_token=config.interpolant.pad_token,
        max_length=config.interpolant.max_length,
    )

    # Save hyperparameters
    self.save_hyperparameters()

    # A missing/None/zero ema_decay disables EMA tracking.
    self.ema_decay = config.training.ema_decay or 0.0
    self.use_ema = self.ema_decay > 0
    # Scratch space used by swap_to_ema / restore_original to hold the raw weights.
    self._orig_params = {}
63
+
64
def forward(self, x, t, return_features: bool = False):
    """
    Run the backbone on ``x``.

    When ``config.training.only_embed_insert`` is set, the model is
    conditioned on ``interpolant.insertion_schedule.at(t)`` instead of the
    raw time ``t``. ``return_features`` is forwarded to the model unchanged.
    """
    conditioning = (
        self.interpolant.insertion_schedule.at(t)
        if self.config.training.only_embed_insert
        else t
    )
    return self.model(x, conditioning, return_features=return_features)
70
+
71
def get_hidden_states(self, indices: torch.Tensor, t: torch.Tensor):
    """
    Delegate to backbone transformer for RemaskingAnyOrder compatibility.

    Forwards ``indices`` and ``t`` unchanged to ``self.model.get_hidden_states``
    and returns its result. Note that, unlike ``forward``, no
    ``only_embed_insert`` time remapping is applied here.
    """
    return self.model.get_hidden_states(indices, t)
74
+
75
+ def training_loss(self, x1, t):
76
+ interpolant_sample = self.interpolant.sample_interpolant(t, x1)
77
+ unmask_weight, insert_weight = self.interpolant.elbo_weight(t, x1)
78
+
79
+ prediction: ModelPrediction = self(interpolant_sample.xt, t)
80
+
81
+ scale_factor = x1.shape[0] * self.config.interpolant.max_length
82
+
83
+ match self.unmask_loss_fn:
84
+ case "elbo":
85
+ mask_indices = interpolant_sample.mask_indices
86
+ unmask_loss = unmask_weight[mask_indices] * F.cross_entropy(
87
+ prediction.token_logits[mask_indices],
88
+ interpolant_sample.unmasked[mask_indices],
89
+ reduction="none",
90
+ )
91
+ unmask_loss = unmask_loss.sum() / scale_factor
92
+ case _:
93
+ raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}")
94
+
95
+ match self.insert_loss_fn:
96
+ case "expectation":
97
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
98
+ insertion_loss = insert_weight[gaps_mask] * jump_kernel_elbo(
99
+ gaps[gaps_mask], prediction.expected_gaps[gaps_mask]
100
+ )
101
+ insertion_loss = insertion_loss.sum() / scale_factor
102
+
103
+ case "distribution":
104
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
105
+ insertion_loss = insert_weight[gaps_mask] * F.cross_entropy(
106
+ prediction.length_posterior[gaps_mask], gaps[gaps_mask]
107
+ )
108
+ insertion_loss = insertion_loss.sum() / scale_factor
109
+
110
+ total_loss = unmask_loss + insertion_loss
111
+ return unmask_loss, insertion_loss, total_loss
112
+
113
def prepare_noised_sample(self, x, num_samples=1, t=None):
    """
    Apply the forward noising process to clean sequences.

    Each row of ``x`` is repeated ``num_samples`` times (consecutively) so
    that several models can later be evaluated on the very same noised data.

    Args:
        x: [B, L] clean token sequences (no mask tokens)
        num_samples: K, number of noisy time samples per sequence
        t: [B*K] optional time values; drawn uniformly from [0, 1) on
            ``x.device`` when omitted.

    Returns:
        dict with the interpolant sample, the unmask/insert ELBO weights, the
        times, the scale factor (``config.interpolant.max_length``), K and B,
        i.e. everything compute_loss_from_noised reads.
    """
    n_sequences = x.shape[0]
    replicated = x.repeat_interleave(num_samples, dim=0)  # [B*K, L]
    if t is None:
        t = torch.rand(n_sequences * num_samples, device=x.device)

    noised = self.interpolant.sample_interpolant(t, replicated)
    unmask_w, insert_w = self.interpolant.elbo_weight(t, replicated)

    return dict(
        interpolant_sample=noised,
        unmask_weight=unmask_w,
        insert_weight=insert_w,
        t=t,
        scale_factor=self.config.interpolant.max_length,
        num_samples=num_samples,
        batch_size=n_sequences,
    )
145
+
146
+ def compute_loss_from_noised(self, noised):
147
+ """
148
+ Compute per-sample denoising loss given pre-noised data.
149
+ Each model runs its own forward pass on the shared noised xt.
150
+
151
+ Args:
152
+ noised: dict from prepare_noised_sample()
153
+
154
+ Returns:
155
+ total_loss: [B] per-sample loss averaged over K noisy samples
156
+ """
157
+ interpolant_sample = noised["interpolant_sample"]
158
+ unmask_weight = noised["unmask_weight"]
159
+ insert_weight = noised["insert_weight"]
160
+ t = noised["t"]
161
+ scale_factor = noised["scale_factor"]
162
+ num_samples = noised["num_samples"]
163
+ B = noised["batch_size"]
164
+
165
+ prediction: ModelPrediction = self(interpolant_sample.xt, t)
166
+
167
+ match self.unmask_loss_fn:
168
+ case "elbo":
169
+ mask_indices = interpolant_sample.mask_indices
170
+ unmask_loss_all = torch.zeros_like(unmask_weight) # [B*K, L]
171
+ unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy(
172
+ prediction.token_logits[mask_indices],
173
+ interpolant_sample.unmasked[mask_indices],
174
+ reduction="none",
175
+ )
176
+ unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor # [B*K]
177
+ case _:
178
+ raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}")
179
+
180
+ match self.insert_loss_fn:
181
+ case "expectation":
182
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
183
+ insertion_loss_all = torch.zeros_like(insert_weight) # [B*K, L+1]
184
+ insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo(
185
+ gaps[gaps_mask], prediction.expected_gaps[gaps_mask]
186
+ )
187
+ insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*K]
188
+ case "distribution":
189
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
190
+ insertion_loss_all = torch.zeros_like(insert_weight) # [B*K, L+1]
191
+ insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy(
192
+ prediction.length_posterior[gaps_mask], gaps[gaps_mask]
193
+ )
194
+ insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*K]
195
+
196
+ per_replicate_loss = unmask_loss + insertion_loss # [B*K]
197
+ per_sample_loss = per_replicate_loss.view(B, num_samples).mean(dim=1) # [B]
198
+ return per_sample_loss
199
+
200
+ def loss_wdce_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False):
201
+ r"""
202
+ Weighted denoising cross entropy loss
203
+ X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
204
+
205
+ log_rnd: [B]; x: [B, L] (no mask)
206
+ num_replicates: R, number of replicates of each row in x
207
+ weight_func: w(lambda) for each sample, 1/lambda by default
208
+ """
209
+
210
+ print("logrnd shape:", log_rnd.shape)
211
+ print("x shape:", x.shape)
212
+
213
+ batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
214
+
215
+ batch_weights = log_rnd.detach().softmax(dim=-1) # [B*R]
216
+ if centering:
217
+ batch_weights = batch_weights - batch_weights.mean(dim=-1, keepdim=True)
218
+
219
+ batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)
220
+
221
+ lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
222
+ lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
223
+
224
+ t = lamda
225
+
226
+ # compute unmasking and insertion loss
227
+ interpolant_sample = self.interpolant.sample_interpolant(t, batch)
228
+ unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch)
229
+
230
+ prediction: ModelPrediction = self(interpolant_sample.xt, t)
231
+
232
+ scale_factor = self.config.interpolant.max_length
233
+
234
+ match self.unmask_loss_fn:
235
+ case "elbo":
236
+ mask_indices = interpolant_sample.mask_indices
237
+ unmask_loss_all = torch.zeros_like(unmask_weight) # [B*R, L]
238
+ unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy(
239
+ prediction.token_logits[mask_indices],
240
+ interpolant_sample.unmasked[mask_indices],
241
+ reduction="none",
242
+ )
243
+ unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor # [B*R]
244
+ case _:
245
+ raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}")
246
+
247
+ match self.insert_loss_fn:
248
+ case "expectation":
249
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
250
+ insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1]
251
+ insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo(
252
+ gaps[gaps_mask], prediction.expected_gaps[gaps_mask]
253
+ )
254
+ insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R]
255
+
256
+ case "distribution":
257
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
258
+ insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1]
259
+ insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy(
260
+ prediction.length_posterior[gaps_mask], gaps[gaps_mask]
261
+ )
262
+ insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R]
263
+
264
+ total_loss = unmask_loss + insertion_loss # [B*R]
265
+ # end compute unmasking and insertion loss
266
+
267
+ weighted_loss = total_loss * batch_weights # [B*R]
268
+ return weighted_loss.mean()
269
+
270
def sample_time(self, batch_size: int, device: torch.device) -> torch.Tensor:
    """
    Draw ``batch_size`` stratified times.

    The range [0, 1 - 1e-6) is cut into ``batch_size`` equal-width bins and
    element ``i`` of the result is ``(i + u_i) * bin_width`` with ``u_i``
    drawn by ``torch.rand``, so the returned times are in ascending bin order.
    """
    upper = 1.0 - 1e-6
    bin_width = upper / batch_size
    jitter = torch.rand(batch_size, device=device)
    bin_index = torch.arange(batch_size, device=device, dtype=jitter.dtype)
    return (bin_index + jitter) * bin_width
276
+
277
def training_step(self, batch, batch_idx):
    """
    Run one training step and return the total loss.

    ``batch`` is either the token tensor itself or a dict holding it under
    ``"input_ids"``. Times come from ``sample_time``; the unmask, length and
    total losses are logged under ``train/``.
    """
    x1 = batch["input_ids"] if isinstance(batch, dict) else batch
    t = self.sample_time(x1.shape[0], x1.device)

    unmask_loss, len_loss, loss = self.training_loss(x1, t)

    logged = (
        ("train/unmask_loss", unmask_loss),
        ("train/len_loss", len_loss),
        ("train/total_loss", loss),
    )
    for tag, value in logged:
        self.log(tag, value, prog_bar=True)

    return loss
295
+
296
def validation_step(self, batch, batch_idx):
    """
    Evaluate the training loss on a validation batch and return it.

    ``batch`` is either the token tensor itself or a dict holding it under
    ``"input_ids"``. The three losses are logged with ``sync_dist=True``; the
    total is logged as ``val_loss``.
    """
    x1 = batch["input_ids"] if isinstance(batch, dict) else batch
    t = self.sample_time(x1.shape[0], x1.device)
    unmask_loss, len_loss, loss = self.training_loss(x1, t)

    logged = (
        ("val/unmask_loss", unmask_loss),
        ("val/len_loss", len_loss),
        ("val_loss", loss),
    )
    for tag, value in logged:
        self.log(tag, value, prog_bar=True, sync_dist=True)

    return loss
309
+
310
def configure_optimizers(self):
    """
    Create the AdamW optimizer and a per-step warmup + cosine LR schedule.

    The learning rate ramps linearly from ``1e-6 * lr`` to ``lr`` over
    ``config.training.warmup_steps`` steps, then follows a cosine annealing
    to 0 over the remaining ``max_steps - warmup_steps`` steps.

    Returns:
        ``([optimizer], [{"scheduler": ..., "interval": "step"}])``
    """
    train_cfg = self.config.training
    lr_sched = torch.optim.lr_scheduler

    optimizer = torch.optim.AdamW(
        self.parameters(),
        lr=self.learning_rate,
        weight_decay=train_cfg.weight_decay,
    )

    warmup_steps = train_cfg.warmup_steps
    decay_steps = train_cfg.max_steps - warmup_steps

    # Always create a fresh schedule starting from step 0
    # This allows extending training beyond original max_steps
    warmup = lr_sched.LinearLR(
        optimizer,
        start_factor=1e-6,
        end_factor=1.0,
        total_iters=warmup_steps,
        last_epoch=-1,
    )
    decay = lr_sched.CosineAnnealingLR(
        optimizer,
        T_max=decay_steps,
        eta_min=0.0,
        last_epoch=-1,
    )
    combined = lr_sched.SequentialLR(
        optimizer,
        schedulers=[warmup, decay],
        milestones=[warmup_steps],
        last_epoch=-1,
    )
    return [optimizer], [{"scheduler": combined, "interval": "step"}]
344
+
345
def optimizer_step(
    self,
    epoch: int,
    batch_idx: int,
    optimizer,
    optimizer_closure=None,
):
    """
    Perform the optimizer step, then log diagnostics and update the EMA.

    After delegating to the parent implementation this logs the first param
    group's learning rate and the global L2 norm of all gradients that are
    not None, and, when EMA is enabled, updates every EMA buffer in place as
    ``ema = decay * ema + (1 - decay) * param``.
    """
    super().optimizer_step(
        epoch, batch_idx, optimizer, optimizer_closure=optimizer_closure
    )

    # log learning rate and gradient norm
    current_lr = optimizer.param_groups[0]["lr"]
    self.log("train/lr", current_lr, on_step=True, prog_bar=True)

    squared_norms = (
        p.grad.norm(2) ** 2 for p in self.parameters() if p.grad is not None
    )
    total_norm = torch.sqrt(sum(squared_norms))
    self.log("train/grad_norm", total_norm, on_step=True, prog_bar=True)

    # update EMA
    if not self.use_ema:
        return
    decay = self.ema_decay
    for name, param in self.named_parameters():
        shadow = self.ema_params[name]
        shadow.mul_(decay).add_(param.data.clone().detach(), alpha=1 - decay)
369
+
370
def on_save_checkpoint(self, checkpoint):
    """
    Add the config and, when EMA is enabled, a cloned copy of every EMA
    tensor (under ``"ema_params"``) to the checkpoint dict.
    """
    checkpoint["config"] = self.config
    if not self.use_ema:
        return
    # save EMA state
    snapshot = {}
    for name, tensor in self.ema_params.items():
        snapshot[name] = tensor.clone()
    checkpoint["ema_params"] = snapshot
377
+
378
def on_load_checkpoint(self, checkpoint):
    """
    Restore the config stored in the checkpoint, rebuild the interpolant
    from it, and load the saved EMA tensors (an empty dict when EMA is
    disabled on this module).
    """
    self.config = checkpoint["config"]
    interp_cfg = self.config.interpolant

    self.interpolant = AnyOrderMaskInsertionInterpolant(
        insertion_schedule=get_schedule_from_config(interp_cfg.insert_schedule),
        unmask_schedule=get_schedule_from_config(interp_cfg.unmask_schedule),
        vocab_size=interp_cfg.tokens,
        mask_token=interp_cfg.mask_token,
        pad_token=interp_cfg.pad_token,
        max_length=interp_cfg.max_length,
    )

    if self.use_ema:
        self.ema_params = checkpoint["ema_params"]
    else:
        self.ema_params = {}
398
+
399
def swap_to_ema(self):
    """
    Overwrite every parameter in place with its EMA value, keeping a clone
    of the raw weights in ``self._orig_params`` so that restore_original can
    undo the swap.
    """
    backup = self._orig_params
    for name, param in self.named_parameters():
        backup[name] = param.data.clone()
        ema_value = self.ema_params[name].to(param.device)
        param.data.copy_(ema_value)
403
+
404
def restore_original(self):
    """
    Copy the weights saved by swap_to_ema back into the parameters in place,
    then empty ``self._orig_params``.
    """
    saved = self._orig_params
    for name, param in self.named_parameters():
        param.data.copy_(saved[name])
    saved.clear()
408
+
409
def on_train_start(self):
    """
    Prepare the EMA buffers once the model is on its training device.

    Does nothing when EMA is disabled. Otherwise every parameter gets a
    detached, non-trainable buffer on ``self.device``. A buffer that already
    exists under the parameter's name in ``self.ema_params`` (for example one
    restored by on_load_checkpoint when resuming) is kept and only moved to
    the device; parameters without an existing buffer start from a copy of
    their current value. Previously the buffers were always rebuilt from the
    live weights, which silently discarded restored EMA state.
    """
    if not self.use_ema:
        return

    # ema_params is not created in __init__, so it may not exist yet.
    restored = getattr(self, "ema_params", None) or {}

    ema_params = {}
    for name, param in self.named_parameters():
        if name in restored:
            buf = restored[name].detach().to(self.device)
        else:
            buf = param.clone().detach().to(self.device)
        buf.requires_grad = False
        ema_params[name] = buf
    self.ema_params = ema_params