diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..6e8075b1ea694b70e967356b554bddda7b77f13d
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..01080488b905578eb4f1196ac2213c50a5682c4b
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,16 @@
+checkpoints/
+pretrained/
+__pycache__/
+results/
+a2d2_language/
+a2d2_language/wandb/
+a2d2_pep/wandb/
+a2d2_mol/wandb/
+logs/
+*.pt
+*.pyc
+*.out
+*.json
+*.log
+*.txt
+*.wandb
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..839c375c02da75f7ae46b3c50f8689a788fcdcee
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2026 Sophia Tang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fef6ef0d7a5f8213300b5372dfe9d0a198de0db8
--- /dev/null
+++ b/README.md
@@ -0,0 +1,62 @@
+# [A2D2: Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding](https://arxiv.org/abs/2606.13565) 🃏🔮
+
+[**Sophia Tang**](https://sophtang.github.io/), [**Yuchen Zhu**](https://yuchen-zhu-zyc.github.io/), [**Molei Tao**](https://mtao8.math.gatech.edu/), and [**Pranam Chatterjee**](https://www.chatterjeelab.com/)
+
+
+
+
+
+
+
+
+This is the repository for the paper [**A2D2: Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding**](https://arxiv.org/abs/2606.13565).
+
+Masked discrete diffusion models (MDMs) offer a simple, stable likelihood-based framework for sequence generation, recently extended to **any-length** settings via token insertion. **A2D2** is a unified framework for reward-guided fine-tuning of any-length MDMs that **jointly optimizes the insertion and unmasking policies together with a quality-based inference schedule**, converging to the intractable reward-tilted distribution without requiring target samples.
+
+🃏 We derive the **Radon–Nikodym derivative** for the joint insertion–unmasking path measures, enabling theoretically guaranteed convergence to the reward-tilted sequence distribution.
+
+🃏 We establish **unmasking and insertion quality** as tractable approaches for minimizing decoding error (compounding parallelization error), and train lightweight quality predictors alongside the policy.
+
+🃏 We introduce the **Adaptive Joint Decoding (AJD)** loss, which provably yields the optimal path measure that generates the reward-tilted distribution while remasking low-quality tokens and dropping low-quality insertions at inference.
+
+🃏 Empirically, A2D2 improves reward optimization while enhancing generation **flexibility** and **accuracy** over prior fixed-length fine-tuning and inference-time guidance methods.
+
+## Drug-Like Small Molecule Design 🧪
+
+We pre-train an any-length MDM on the **SAFE** dataset ([Noutahi et al. 2024](https://arxiv.org/abs/2310.10773), ~950M molecules from ZINC and Unichem in SAFE notation) and fine-tune it with **A2D2** to optimize **QED** (drug-likeness) and **synthetic accessibility (SA)**. A2D2 jointly raises QED and lowers SA over the pre-trained baseline while increasing the fraction of valid, unique, drug-like, and synthesizable molecules. Code and instructions are in [`/a2d2_mol`](a2d2_mol).
+
+## Multi-Objective Therapeutic Peptide Generation 💉
+
+We pre-train an any-length **peptide SMILES** MDM on ~11M peptides (CycPeptMPDB, SmProt, CycloPs) and fine-tune with **A2D2** on five therapeutic properties simultaneously: **target-protein binding affinity, solubility, non-hemolysis, non-fouling, and permeability**. A2D2 outperforms inference-time multi-objective guidance and fixed-length off-policy RL fine-tuning on almost all objectives, while improving the fraction of valid peptides. Code and instructions are in [`/a2d2_pep`](a2d2_pep).
+
+## Language Model Reasoning 🧠
+
+We additionally apply **A2D2** to reward fine-tuning of any-length language MDMs (LLaDA / FlexMDM), optimizing math-reasoning correctness and format rewards (GSM8K / MATH), including infilling variants. Code is in [`/a2d2_language`](a2d2_language).
+
+## Repository Structure
+
+| Directory | Experiment |
+|-----------|------------|
+| [`a2d2_mol`](a2d2_mol) | Drug-like small molecule design (QED, SA) |
+| [`a2d2_pep`](a2d2_pep) | Multi-objective therapeutic peptide generation |
+| [`a2d2_language`](a2d2_language) | Language model reasoning reward fine-tuning (code soon) |
+| [`lightning_modules`](lightning_modules) | Any-length insertion MDM Lightning modules (policy + quality predictors) |
+| [`model`](model) | Shared model architecture |
+| [`demo`](demo) | Quality-guided inference demo notebook |
+
+Each experiment directory contains its own `README.md` with environment setup, pretrained weight placement, fine-tuning commands, and evaluation instructions.
+
+## Citation
+
+If you find this repository helpful for your publications, please consider citing our paper:
+
+```python
+@article{tang2026a2d2,
+ title={A2D2: Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding},
+ author={Sophia Tang and Yuchen Zhu and Molei Tao and Pranam Chatterjee},
+ journal={arXiv preprint arXiv:2606.13565},
+ year={2026}
+}
+```
+
+To use this repository, you agree to abide by the MIT License.
\ No newline at end of file
diff --git a/a2d2_mol/README.md b/a2d2_mol/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..61db0ba8e8ffbe7b929b1dd17480e9248fbe81ca
--- /dev/null
+++ b/a2d2_mol/README.md
@@ -0,0 +1,132 @@
+# A2D2 for Molecule Generation 🧪
+
+This part of the code fine-tunes an **any-length masked diffusion model (MDM)** over molecules with **A2D2** (Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding) to optimize drug-likeness rewards (QED, and optionally synthetic accessibility / SA).
+
+A2D2 jointly fine-tunes the insertion and unmasking policies together with **insertion and unmasking quality predictors**, generating molecules via **Adaptive Joint Decoding (AJD)** that remasks low-quality tokens and drops low-quality insertions to sample from the reward-tilted distribution while preserving generation quality.
+
+Molecules are represented as [SAFE](https://github.com/datamol-io/safe) strings and tokenized with the `datamol-io/safe-gpt` tokenizer.
+
+The codebase is partially built upon [FlexMDM (Kim et.al, 2025)](https://github.com/brianlck/FlexMDM/tree/main) and [TR2-D2 (Tang et.al, 2025)](https://github.com/sophtang/TR2-D2/tree/main).
+
+## Environment Installation
+```
+# from the repository root
+conda env create -f environment.yml
+
+conda activate a2d2
+```
+The molecule scripts share the `a2d2` environment with the peptide and language experiments. See the root [`environment.yml`](../environment.yml) for the `flash-attn` install step.
+
+## Model Pretrained Weights
+
+A2D2 fine-tunes a pretrained any-length insertion MDM trained on drug-like SAFE molecules. Download the base checkpoint and place it at:
+```
+A2D2/pretrained/anylength_mol.ckpt
+```
+```bash
+# from the repository root
+pip install gdown
+mkdir -p pretrained
+gdown 1I5EGiV1I5XZZpB9JAKABFLKVqfCyenxq -O pretrained/anylength_mol.ckpt
+```
+(Or download manually from https://drive.google.com/file/d/1I5EGiV1I5XZZpB9JAKABFLKVqfCyenxq/view?usp=drive_link — a plain `wget`/`curl` of the link saves Google's HTML warning page, not the checkpoint.)
+This is the default `--checkpoint_path` (for fine-tuning) and `--pretrained_ckpt` (for evaluation) used throughout.
+
+## Pretraining the Any-Length Model
+
+If you only want to fine-tune with A2D2, download the released `anylength_mol.ckpt` above and skip this section. Follow these steps to reproduce the base checkpoint by pretraining the any-length insertion MDM from scratch.
+
+### 1. The pretraining dataset
+
+The model is pretrained on drug-like [SAFE](https://github.com/datamol-io/safe) molecules from the [`datamol-io/safe-gpt`](https://huggingface.co/datasets/datamol-io/safe-gpt) dataset (~1.1B molecules) on the Hugging Face Hub. **No manual download is required** — the dataset is loaded in streaming mode (`load_dataset(..., streaming=True)`) and tokenized on the fly with the `datamol-io/safe-gpt` tokenizer, both fetched automatically on first run.
+
+The dataset is configured in [`config_mol.yaml`](config_mol.yaml):
+
+```yaml
+hf_dataset:
+ name: "datamol-io/safe-gpt"
+ smiles_column: "smiles"
+```
+
+To pretrain on a different Hugging Face SMILES/SAFE dataset, change `hf_dataset.name` (and `smiles_column` to match its column).
+
+### 2. Configure
+
+Pretraining is driven by [`config_mol.yaml`](config_mol.yaml). Key fields:
+
+| Field | Default | Notes |
+|-------|---------|-------|
+| `hf_dataset.name` | `datamol-io/safe-gpt` | Streaming HF dataset (auto-downloaded). |
+| `training.devices` | `2` | GPUs per node (DDP). |
+| `training.batch_size` | `2048` | Global batch; gradient accumulation is derived automatically from `per_gpu_batch_size`. |
+| `training.max_steps` | `500000` | Total optimizer steps. |
+| `training.learning_rate` | `3e-4` | AdamW LR with `warmup_steps: 2000`. |
+| `training.save_every_n_steps` | `1000` | Step-based checkpointing (used for streaming datasets). |
+| `training.checkpoint_dir` | `checkpoints/pretrain_mol` | A timestamped subdirectory is created per run. |
+| `interpolant.max_length` | `256` | Max token length. |
+
+### 3. Pre-training Any-Length Molecule Model
+
+Log in to Weights & Biases once (`wandb login`), or set `export WANDB_MODE=disabled` to skip logging. Then submit the SLURM job:
+
+```bash
+# from a2d2_mol/
+sbatch train_mol.sh
+```
+
+`train_mol.sh` is a SLURM batch script that requests one `dgx-b200` node with 2 full B200 GPUs and launches DDP via `srun` (one task per GPU), running the equivalent of:
+
+```bash
+python train.py --task mol
+```
+
+It activates the conda env (`CONDA_ENV`, defaults to the `peptune` env) from `CONDA_ROOT` (defaults to the shared miniconda install) — the batch shell does not source `~/.bashrc`, so override these env vars if your install or env path differs. The GPU count is auto-detected from the SLURM allocation and passed to hydra as `training.devices`/`training.nodes`, so to scale just change `--gpus-per-node` and `--ntasks-per-node` together at the top of the script (they must match). `--task mol` makes `train.py` load `config_mol.yaml`.
+
+Checkpoints are written to `checkpoints/pretrain_mol//` (use `last.ckpt` / the best `train_loss` checkpoint as the `--checkpoint_path` / `--pretrained_ckpt` for fine-tuning and evaluation); the run log goes to `logs/_a2d2-mol_.log` and SLURM's catch-file to `logs/slurm/`. To resume, add a `training.resume_path: /path/to/last.ckpt` entry to the config.
+
+## Fine-Tune with A2D2
+
+The canonical run directory is the parent `a2d2/` package (`finetune_mol.py`, `inference_quality_mol.py`, `sampling.py`, and `mol_scoring/` here are the molecule-specific modules used from there). Before running:
+
+1. Set `--base_path` to the location of `a2d2`. Results plots are written to `/flexible/results//`.
+2. Create the output directories: `a2d2/checkpoints/finetune_mol`, `a2d2/results`, and `a2d2/logs`.
+
+### Single run
+
+[`scripts/run_mol_finetune.slurm`](scripts/run_mol_finetune.slurm) runs a single `finetune_mol.py` experiment on one MIG GPU, then evaluates the resulting checkpoint. It bundles the full hyperparameter set used in the paper (replicates `R = 16`, pool size `1000`, buffer size `100`, sampling steps `N_steps = 90`, warmup `N_warmup = 20`, alternation frequency `N_alt = 5`, reward scaling `α = 0.01`, quality threshold `μ_min = 0.3`, `--qed_only`), so you don't have to pass them by hand.
+
+The script resolves the repo root automatically — `$A2D2_ROOT` if set, else the `sbatch` submit directory, else the script's own location — so either submit from the repo root or export your clone path. Set `CONDA_ROOT` (your miniconda install) and, if needed, `CONDA_ENV` (defaults to `peptune`):
+```bash
+export A2D2_ROOT=/path/to/your/A2D2 # absolute path to your clone
+export CONDA_ROOT=/path/to/miniconda3 # or just have `conda` on PATH
+sbatch scripts/run_mol_finetune.slurm
+```
+
+Select which variant to run with `MODE_ID` (default `0`): `0` = A2D2 (full planner), `1` = `--disable_planner`, `2` = `--disable_insertion_planner`, `3` = `--disable_unmasking_planner`. Override at submit time:
+```bash
+sbatch --export=ALL,MODE_ID=2 scripts/run_mol_finetune.slurm
+```
+The pretrained base checkpoint is read from `$A2D2_ROOT/pretrained/anylength_mol.ckpt`. Outputs land in `checkpoints/finetune_mol/_mol_/` and `results/mol_ablation//`.
+
+### Ablation flags
+| Flag | Variant |
+|------|---------|
+| *(none)* | A2D2 w/ insertion + unmasking quality (alternation) |
+| `--disable_planner` | A2D2 w/o quality (policy only, no remasking) |
+| `--disable_insertion_planner` | A2D2 w/o insertion quality |
+| `--disable_unmasking_planner` | A2D2 w/o unmasking/remasking quality |
+| `--joint_training` | train policy + quality heads jointly (no alternation) |
+
+## Evaluation
+
+Evaluation runs automatically at the end of the SLURM job. To evaluate a checkpoint manually:
+```
+python evaluate_mol_table.py \
+ --checkpoint_path /path/to/a2d2/checkpoints/finetune_mol/my_run/last.ckpt \
+ --pretrained_ckpt /path/to/A2D2/pretrained/anylength_mol.ckpt \
+ --output_dir /path/to/results \
+ --num_samples 1000 --batch_size 50 \
+ --max_length 256 --total_num_steps 256 \
+ --num_remasking 2 --quality_threshold 0.3 --seed 42 --device cuda:0
+```
+This reports QED, SA, validity, uniqueness, diversity, and mean unmasking/insertion quality over the generated molecules and writes `eval_metrics_.csv`.
diff --git a/a2d2_mol/config_mol.yaml b/a2d2_mol/config_mol.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9a2a659d8417aba012326c69218188b80aa9e4a4
--- /dev/null
+++ b/a2d2_mol/config_mol.yaml
@@ -0,0 +1,54 @@
+trainer: "any-order-flow"
+dataset: "safe-drugs"
+
+# HuggingFace dataset configuration
+hf_dataset:
+ name: "datamol-io/safe-gpt"
+ smiles_column: "smiles" # Adjust based on actual column name in the dataset
+
+model:
+ hidden_size: 768
+ n_heads: 12
+ cond_dim: 128
+ dropout: 0.05
+ n_blocks: 12
+ torch_dtype: 'float32' # Options: 'float32', 'float16', 'bfloat16'
+
+interpolant:
+ type: "any-order"
+ tokens: null # filled in automatically
+ pad_token: null # filled in automatically
+ mask_token: null # filled in automatically
+ max_length: 256
+ insert_schedule:
+ type: "linear"
+ unmask_schedule:
+ type: "linear"
+
+training:
+ only_embed_insert: true
+ batch_size: 2048
+ per_gpu_batch_size: 64 # Gradient accumulation happens automatically
+ cpus: 4
+ learning_rate: 3e-4
+ nodes: 1
+ devices: 2
+ max_steps: 500000
+ weight_decay: 0.03
+ checkpoint_dir: "checkpoints/pretrain_mol"
+ save_top_k: 3
+ save_every_n_steps: 1000 # Save checkpoint every 1k steps (for streaming datasets)
+ # save_every_n_epochs: 1 # Not used with streaming datasets
+ loss_fn:
+ unmask: "elbo"
+ insert: "expectation"
+ reset_lr: false
+ warmup_steps: 2000
+ ema_decay: 0.9999
+ filter_max_length: false
+
+wandb:
+ entity: null # set to your W&B entity, or leave null to use the default
+ project: "a2d2-mol"
+ name: "a2d2-mol"
+ path: "./wandb"
diff --git a/a2d2_mol/evaluate_mol_table.py b/a2d2_mol/evaluate_mol_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..103b277a1d0fa634be4e4b8e465bfe524e242327
--- /dev/null
+++ b/a2d2_mol/evaluate_mol_table.py
@@ -0,0 +1,308 @@
+"""
+Evaluate a finetuned molecule model checkpoint by sampling sequences
+and computing metrics for the De Novo Small Molecule Generation table:
+ Validity (%), Uniqueness (%), QED (↑), SA (↓), Quality (%), Diversity (↑), Sampling Time (↓)
+"""
+
+import os
+import sys
+import argparse
+import time
+import torch
+import numpy as np
+import pandas as pd
+from tdc import Oracle, Evaluator
+
+# add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
+REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.insert(0, REPO_ROOT)
+
+from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
+from lightning_modules import AnyOrderInsertionFlowModule
+from inference_quality_mol import sample_mol_eval
+from mol_scoring.scoring_functions import MolScoringFunctions
+from finetune_mol import MolFinetuner, get_tokenizer
+from mol_utils.utils import str2bool, set_seed
+
+
+def load_finetuned_model(checkpoint_path, pretrained_ckpt_path, device='cuda'):
+ """Load a finetuned MolFinetuner from a Lightning checkpoint."""
+ # We need to reconstruct the model the same way main() does, then load state
+ # Load from Lightning checkpoint directly
+ ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
+ hparams = ckpt.get('hyper_parameters', {})
+ args = hparams.get('args', None)
+
+ # Load pretrained base checkpoint to get config
+ base_ckpt = torch.load(pretrained_ckpt_path, map_location='cpu', weights_only=False)
+ if 'hyper_parameters' in base_ckpt:
+ config = base_ckpt['hyper_parameters']['config']
+ elif 'config' in base_ckpt:
+ config = base_ckpt['config']
+ else:
+ raise ValueError("Cannot find config in base checkpoint")
+
+ from omegaconf import OmegaConf, DictConfig
+ if not OmegaConf.is_config(config):
+ config = DictConfig(config)
+ OmegaConf.set_struct(config, False)
+
+ # Set adaptive schedule config from args or defaults
+ config.training.use_adaptive_schedule = getattr(args, 'use_adaptive_schedule', True)
+ config.training.schedule_hidden_dim = getattr(args, 'schedule_hidden_dim', 256)
+ config.training.schedule_num_layers = getattr(args, 'schedule_num_layers', 2)
+ config.training.schedule_loss_weight = getattr(args, 'schedule_loss_weight', 0.1)
+ config.training.freeze_base_model = getattr(args, 'freeze_base_model', False)
+ config.training.schedule_warmup_epochs = getattr(args, 'schedule_warmup_epochs', 0)
+ config.training.use_bracket_safe = True
+ OmegaConf.set_struct(config, True)
+
+ # Determine if planner should be loaded based on disable_planner flag
+ disable_planner = getattr(args, 'disable_planner', False)
+
+ # Initialize policy model
+ policy_model = AnyOrderInsertionFlowModuleFT(
+ config=config,
+ args=args,
+ pretrained_checkpoint=pretrained_ckpt_path,
+ insertion_planner=not disable_planner,
+ )
+
+ # Load policy model weights from the finetuned checkpoint
+ state_dict = ckpt['state_dict']
+ # Lightning wraps the model: 'policy_model.xxx' -> remove prefix for the sub-module
+ policy_state = {}
+ for k, v in state_dict.items():
+ if k.startswith('policy_model.'):
+ policy_state[k[len('policy_model.'):]] = v
+ policy_model.load_state_dict(policy_state, strict=False)
+ policy_model = policy_model.to(device)
+ policy_model.eval()
+
+ return policy_model, args, config
+
+
+@torch.no_grad()
+def evaluate_checkpoint(policy_model, tokenizer, reward_model, evaluator,
+ num_samples=1000, batch_size=50, max_length=256,
+ total_num_steps=256, quality_mode="both", num_remasking=2,
+ quality_threshold=0.5, unmask_quality_threshold=None, device='cuda'):
+ """
+ Sample `num_samples` molecules and compute all table metrics.
+ Returns a dict with: validity, uniqueness, qed, sa, quality, diversity, sampling_time
+ """
+ all_valid_seqs = []
+ all_smiles_generated = 0
+ total_time = 0.0
+
+ num_batches = (num_samples + batch_size - 1) // batch_size
+ remaining = num_samples
+
+ for b in range(num_batches):
+ bs = min(batch_size, remaining)
+ remaining -= bs
+
+ t_start = time.time()
+ result = sample_mol_eval(
+ model=policy_model,
+ reward_model=reward_model,
+ tokenizer=tokenizer,
+ steps=total_num_steps,
+ mask=policy_model.interpolant.mask_token,
+ pad=policy_model.interpolant.pad_token,
+ batch_size=bs,
+ max_length=max_length,
+ quality_mode=quality_mode,
+ num_remasking=num_remasking,
+ quality_threshold=quality_threshold,
+ unmask_quality_threshold=unmask_quality_threshold,
+ evaluator=evaluator,
+ dataframe=True,
+ )
+ t_end = time.time()
+
+ # Unpack: uniqueSequences, qed, sa, valid_fraction, uniqueness, diversity, quality, df
+ unique_seqs, qed_scores, sa_scores, valid_frac, uniq, div, qual, df = result
+
+ all_valid_seqs.extend(list(unique_seqs) if not isinstance(unique_seqs, list) else unique_seqs)
+ all_smiles_generated += bs
+ total_time += (t_end - t_start)
+
+ print(f" Batch {b+1}/{num_batches}: {len(unique_seqs)} valid unique, "
+ f"time={t_end - t_start:.1f}s")
+
+ # --- Aggregate metrics over all samples ---
+ total_generated = num_samples
+
+ # Valid sequences (keeping duplicates for validity count)
+ # Re-evaluate from scratch on all collected valid sequences
+ all_unique = list(set(all_valid_seqs))
+ num_valid = len(all_valid_seqs) # total valid across batches (before dedup)
+ num_unique = len(all_unique)
+
+ validity = num_valid / total_generated * 100.0
+ uniqueness = num_unique / num_valid * 100.0 if num_valid > 0 else 0.0
+
+ # Diversity on unique SMILES
+ diversity = evaluator(all_unique) if num_unique > 1 else 0.0
+
+ # QED and SA on unique sequences
+ if num_unique > 0:
+ oracle_qed = Oracle('qed')
+ oracle_sa = Oracle('sa')
+ qed_vals = oracle_qed(all_unique)
+ sa_vals = oracle_sa(all_unique)
+ mean_qed = np.mean(qed_vals)
+ mean_sa = np.mean(sa_vals)
+
+ # Quality: unique sequences with QED >= 0.6 AND SA <= 4
+ quality_mask = [(q >= 0.6 and s <= 4) for q, s in zip(qed_vals, sa_vals)]
+ quality = sum(quality_mask) / total_generated * 100.0
+ else:
+ mean_qed = 0.0
+ mean_sa = 0.0
+ quality = 0.0
+
+ sampling_time = total_time
+
+ metrics = {
+ 'Validity (%)': validity,
+ 'Uniqueness (%)': uniqueness,
+ 'QED': mean_qed,
+ 'Synthetic Accessibility': mean_sa,
+ 'Quality (%)': quality,
+ 'Diversity': diversity,
+ 'Sampling Time (s)': sampling_time,
+ 'Num Generated': total_generated,
+ 'Num Valid': num_valid,
+ 'Num Unique': num_unique,
+ }
+
+ return metrics, all_unique, qed_vals if num_unique > 0 else [], sa_vals if num_unique > 0 else []
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Evaluate a finetuned mol checkpoint")
+ parser.add_argument('--checkpoint_path', type=str, required=True,
+ help='Path to the finetuned Lightning checkpoint (e.g., last.ckpt)')
+ parser.add_argument('--pretrained_ckpt', type=str,
+ default=os.path.join(REPO_ROOT, 'pretrained', 'anylength_mol.ckpt'),
+ help='Path to the pretrained base model checkpoint '
+ '(defaults to /pretrained/anylength_mol.ckpt)')
+ parser.add_argument('--num_samples', type=int, default=1000,
+ help='Number of molecules to sample')
+ parser.add_argument('--batch_size', type=int, default=50,
+ help='Batch size for sampling')
+ parser.add_argument('--max_length', type=int, default=256)
+ parser.add_argument('--total_num_steps', type=int, default=256)
+ parser.add_argument('--num_remasking', type=int, default=2)
+ parser.add_argument('--disable_planner', action='store_true',
+ help='If set, disable remasking during evaluation (matches training mode)')
+ parser.add_argument('--disable_insertion_planner', action='store_true',
+ help='If set, disable insertion quality filtering during evaluation')
+ parser.add_argument('--disable_unmasking_planner', action='store_true',
+ help='If set, disable unmasking confidence planner during evaluation')
+ parser.add_argument('--quality_threshold', type=float, default=0.5,
+ help='Threshold for insertion quality filtering during sampling')
+ parser.add_argument('--unmask_quality_threshold', type=float, default=None,
+ help='If set, gate unmasking remasking on confidence: remask clean '
+ 'tokens whose remasking_conf < threshold (overrides the '
+ 'schedule-driven count). Default None = schedule-driven behavior.')
+ parser.add_argument('--output_dir', type=str, default=None,
+ help='Directory to save results CSV. Defaults to checkpoint directory.')
+ parser.add_argument('--device', type=str, default='cuda:0')
+ parser.add_argument('--seed', type=int, default=42)
+ args = parser.parse_args()
+
+ set_seed(args.seed, use_cuda=True)
+ device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
+
+ print(f"Loading checkpoint: {args.checkpoint_path}")
+ print(f"Pretrained base: {args.pretrained_ckpt}")
+ print(f"Disable planner (no remasking): {args.disable_planner}")
+ print(f"Disable insertion planner: {args.disable_insertion_planner}")
+ print(f"Disable unmasking planner: {args.disable_unmasking_planner}")
+
+ policy_model, train_args, config = load_finetuned_model(
+ args.checkpoint_path, args.pretrained_ckpt, device=device
+ )
+
+ tokenizer = get_tokenizer()
+ score_func_names = ['qed', 'sa']
+ reward_model = MolScoringFunctions(score_func_names, device=device)
+ evaluator = Evaluator('diversity')
+
+ use_remasking = not args.disable_planner
+ disable_insertion_planner = args.disable_insertion_planner
+ disable_unmasking_planner = args.disable_unmasking_planner
+
+ # Map flags to quality_mode
+ if args.disable_planner:
+ quality_mode = "none"
+ elif args.disable_insertion_planner and args.disable_unmasking_planner:
+ quality_mode = "none"
+ elif args.disable_insertion_planner:
+ quality_mode = "unmasking_only"
+ elif args.disable_unmasking_planner:
+ quality_mode = "insertion_only"
+ else:
+ quality_mode = "both"
+
+ print(f"\nSampling {args.num_samples} molecules (quality_mode={quality_mode})...")
+
+ metrics, unique_smiles, qed_vals, sa_vals = evaluate_checkpoint(
+ policy_model=policy_model,
+ tokenizer=tokenizer,
+ reward_model=reward_model,
+ evaluator=evaluator,
+ num_samples=args.num_samples,
+ batch_size=args.batch_size,
+ max_length=args.max_length,
+ total_num_steps=args.total_num_steps,
+ quality_mode=quality_mode,
+ num_remasking=args.num_remasking,
+ quality_threshold=getattr(args, 'quality_threshold', 0.5),
+ unmask_quality_threshold=args.unmask_quality_threshold,
+ device=device,
+ )
+
+ # Print summary table
+ print("\n" + "=" * 60)
+ print(" De Novo Small Molecule Generation Results")
+ print("=" * 60)
+ for k, v in metrics.items():
+ if isinstance(v, float):
+ print(f" {k:<30s}: {v:.4f}")
+ else:
+ print(f" {k:<30s}: {v}")
+ print("=" * 60)
+
+ # Save results
+ output_dir = args.output_dir or os.path.dirname(args.checkpoint_path)
+ os.makedirs(output_dir, exist_ok=True)
+
+ if args.disable_planner:
+ tag = "no_planner"
+ elif args.disable_insertion_planner:
+ tag = "no_insertion_planner"
+ elif args.disable_unmasking_planner:
+ tag = "no_unmasking_planner"
+ else:
+ tag = "with_planner"
+ metrics_path = os.path.join(output_dir, f'eval_metrics_{tag}.csv')
+ pd.DataFrame([metrics]).to_csv(metrics_path, index=False)
+ print(f"Metrics saved to: {metrics_path}")
+
+ if unique_smiles:
+ smiles_path = os.path.join(output_dir, f'eval_smiles_{tag}.csv')
+ df = pd.DataFrame({
+ 'SMILES': unique_smiles,
+ 'QED': qed_vals,
+ 'SA': sa_vals,
+ })
+ df.to_csv(smiles_path, index=False)
+ print(f"SMILES saved to: {smiles_path}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/a2d2_mol/finetune_mol.py b/a2d2_mol/finetune_mol.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0d84ebba28cd76e46ef5545662f59a47206d089
--- /dev/null
+++ b/a2d2_mol/finetune_mol.py
@@ -0,0 +1,747 @@
+import argparse
+from datetime import datetime
+import numpy as np
+import torch
+import pytorch_lightning as pl
+from pytorch_lightning.strategies import DDPStrategy
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.loggers import WandbLogger
+import wandb
+import os
+import sys
+from tqdm import tqdm
+import pandas as pd
+
+# add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+# imports
+from inference_quality_mol import sample_mol_buffer, sample_mol_eval
+from mol_utils.utils import str2bool, set_seed
+from mol_scoring.scoring_functions import MolScoringFunctions
+from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
+from lightning_modules import AnyOrderInsertionFlowModule
+from safe.tokenizer import SAFETokenizer
+from tdc import Evaluator
+
+# Repository root (two levels up from this file: A2D2/a2d2_mol/finetune_mol.py)
+REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
+def get_tokenizer():
+ """Get SAFE tokenizer with added special tokens."""
+ tk = SAFETokenizer.from_pretrained('datamol-io/safe-gpt').get_pretrained()
+ tk.add_tokens(['<', '>']) # for bracket_safe
+ return tk
+
+class MolFinetuner(pl.LightningModule):
+ """Lightning module for distributed molecule finetuning."""
+
+ def __init__(
+ self,
+ args,
+ policy_model,
+ reward_model,
+ tokenizer,
+ pretrained=None,
+ mcts=None,
+ filename=None,
+ eps=1e-5
+ ):
+ super().__init__()
+ self.args = args
+ self.policy_model = policy_model
+ self.reward_model = reward_model
+ self.tokenizer = tokenizer
+ self.pretrained = pretrained
+ self.mcts = mcts
+ self.filename = filename
+ self.eps = eps
+
+ self.evaluator = Evaluator("diversity")
+
+ # Save hyperparameters
+ self.save_hyperparameters(ignore=['policy_model', 'reward_model', 'tokenizer', 'pretrained', 'mcts'])
+
+ # Buffer for sequences
+ self.x_saved = None
+ self.log_rnd_saved = None
+ self.final_rewards_saved = None
+
+ # initialize logs
+ self.valid_fraction_log = []
+ self.diversity_log = []
+ self.qed_log = []
+ self.sa_log = []
+ self.quality_log = []
+ self.uniqueness_log = []
+
+ # Alternating training between policy and planner
+ self.train_policy = True # Start by training policy
+ self.alternation_frequency = getattr(args, 'alternation_frequency', 1) # Alternate every N epochs
+
+ def freeze_policy_model(self):
+ """Freeze policy model parameters (but not planner)."""
+ for name, param in self.policy_model.named_parameters():
+ if not name.startswith('planner.'):
+ param.requires_grad = False
+
+ def unfreeze_policy_model(self):
+ """Unfreeze policy model parameters (but not planner)."""
+ for name, param in self.policy_model.named_parameters():
+ if not name.startswith('planner.'):
+ param.requires_grad = True
+
+ def freeze_planner_model(self):
+ """Freeze planner parameters."""
+ if hasattr(self.policy_model, 'planner'):
+ for param in self.policy_model.planner.parameters():
+ param.requires_grad = False
+
+ def unfreeze_planner_model(self):
+ """Unfreeze planner parameters."""
+ if hasattr(self.policy_model, 'planner'):
+ for param in self.policy_model.planner.parameters():
+ param.requires_grad = True
+
+ def configure_optimizers(self):
+ # Separate parameter groups for policy backbone vs planner heads
+ planner_lr = getattr(self.args, 'planner_learning_rate', self.args.learning_rate)
+ planner_params = []
+ policy_params = []
+ for name, param in self.policy_model.named_parameters():
+ if name.startswith('planner.'):
+ planner_params.append(param)
+ else:
+ policy_params.append(param)
+
+ param_groups = [
+ {'params': policy_params, 'lr': self.args.learning_rate},
+ {'params': planner_params, 'lr': planner_lr},
+ ]
+ optimizer = torch.optim.AdamW(param_groups)
+ return optimizer
+
+ def _get_quality_mode(self):
+ """Map ablation flags + warmup state to quality_mode string."""
+ if self.args.disable_planner:
+ return "none"
+ if self.current_epoch < self.args.schedule_warmup_epochs:
+ return "none"
+ di = getattr(self.args, 'disable_insertion_planner', False)
+ du = getattr(self.args, 'disable_unmasking_planner', False)
+ if di and du:
+ return "none"
+ if di:
+ return "unmasking_only"
+ if du:
+ return "insertion_only"
+ return "both"
+
+ def on_train_epoch_start(self):
+ """Called at the start of each training epoch."""
+
+ # If disable_planner mode, only train policy (no alternation)
+ if self.args.disable_planner:
+ self.train_policy = True
+ self.unfreeze_policy_model()
+ self.freeze_planner_model()
+ if self.global_rank == 0 and self.current_epoch == 0:
+ print(f"[FINETUNE_QUALITY] Training ONLY policy model (planner frozen, no remasking)")
+
+ elif getattr(self.args, 'joint_training', False):
+ # Joint mode: train policy + planner together every step (no alternation)
+ self.train_policy = True # marker; training_step adds planner loss when joint_training is set
+ self.unfreeze_policy_model()
+ self.unfreeze_planner_model()
+ if self.global_rank == 0 and self.current_epoch == 0:
+ print(f"[FINETUNE_QUALITY] JOINT TRAINING: policy + planner trained together (no alternation)")
+
+ else:
+ # Alternate between training policy and planner from epoch 0
+ # Determine which model to train this epoch
+ cycle_position = (self.current_epoch // self.alternation_frequency) % 2
+ self.train_policy = (cycle_position == 0)
+
+ if self.train_policy:
+ # Train policy, freeze planner
+ self.unfreeze_policy_model()
+ self.freeze_planner_model()
+ if self.global_rank == 0:
+ print(f"[ALTERNATION] Epoch {self.current_epoch}: Training POLICY model (planner frozen)")
+ else:
+ # Train planner, freeze policy
+ self.freeze_policy_model()
+ self.unfreeze_planner_model()
+ if self.global_rank == 0:
+ print(f"[ALTERNATION] Epoch {self.current_epoch}: Training PLANNER model (policy frozen)")
+
+ # Resample buffer if needed
+ if self.x_saved is None or self.current_epoch % self.args.resample_every_n_step == 0:
+ if self.global_rank == 0:
+ print(f"[BUFFER] Starting buffer generation for epoch {self.current_epoch}")
+ self._generate_buffer()
+ # Synchronize all ranks after buffer generation
+ if self.trainer and self.trainer.world_size > 1:
+ if self.global_rank == 0:
+ print(f"[BUFFER] All ranks completed buffer generation, synchronizing...")
+ torch.distributed.barrier()
+ if self.global_rank == 0:
+ print(f"[BUFFER] Synchronization complete!")
+
+ def _generate_buffer(self):
+ """Generate buffer of sequences for training.
+
+ When pool_size > 0, maintains a persistent pool and refreshes a fraction
+ each time instead of regenerating the entire buffer from scratch.
+ """
+ rank = self.global_rank if self.trainer else 0
+ world_size = self.trainer.world_size if self.trainer else 1
+
+ pool_size = getattr(self.args, 'pool_size', 0)
+ is_pool = pool_size > 0
+ is_init = self.x_saved is None
+
+ # Determine how many molecules to sample this call
+ if is_pool:
+ refresh_frac = getattr(self.args, 'pool_refresh_fraction', 0.2)
+ if is_init:
+ samples_per_gpu = pool_size
+ else:
+ samples_per_gpu = max(1, int(pool_size * refresh_frac))
+ if rank == 0:
+ if is_init:
+ print(f"\n[POOL] Initializing pool with {pool_size} molecules at epoch {self.current_epoch}")
+ else:
+ print(f"\n[POOL] Refreshing {samples_per_gpu}/{pool_size} molecules ({refresh_frac*100:.0f}%) at epoch {self.current_epoch}")
+ else:
+ samples_per_gpu = self.args.buffer_size // world_size
+ if rank == 0:
+ samples_per_gpu += self.args.buffer_size % world_size
+
+ if rank == 0:
+ print(f"\n[BUFFER] Starting buffer generation at epoch {self.current_epoch}")
+
+ accumulated_x = []
+ accumulated_log_rnd = []
+ accumulated_rewards = []
+ total_accumulated = 0
+
+ max_attempts = 100 # Prevent infinite loop
+ attempts = 0
+
+ import time
+ while total_accumulated < samples_per_gpu and attempts < max_attempts:
+ attempts += 1
+ if rank == 0:
+ print(f"[BUFFER] rank={rank} starting sampling attempt {attempts} at {time.strftime('%H:%M:%S')}")
+
+ start_time = time.time()
+
+ x_final, log_rnd, final_rewards, trace = \
+ sample_mol_buffer(
+ self.policy_model,
+ self.pretrained,
+ self.reward_model,
+ self.tokenizer,
+ steps=self.args.total_num_steps,
+ mask=self.policy_model.interpolant.mask_token,
+ pad=self.policy_model.interpolant.pad_token,
+ batch_size=self.args.batch_size,
+ max_length=self.args.max_length,
+ quality_mode=self._get_quality_mode(),
+ alpha=self.args.alpha,
+ num_remasking=self.args.num_remasking,
+ quality_threshold=self.args.quality_threshold,
+ use_quality_filter=self.args.use_quality_filter,
+ )
+ if self.args.elbo_rnd:
+ # Override trajectory log_rnd with forward ELBO estimate
+ if x_final.shape[0] > 0:
+ with torch.no_grad():
+ noised = self.policy_model.prepare_noised_sample(
+ x_final, num_samples=self.args.elbo_rnd_num_samples)
+ policy_loss = self.policy_model.compute_loss_from_noised(noised)
+ pretrained_loss = self.pretrained.compute_loss_from_noised(noised)
+ log_rnd = (pretrained_loss - policy_loss) + (final_rewards / self.args.alpha)
+
+ elapsed = time.time() - start_time
+ if rank == 0:
+ print(f"[BUFFER] rank={rank} sampling took {elapsed:.1f}s")
+
+ n_valid = x_final.shape[0]
+ if n_valid > 0:
+ accumulated_x.append(x_final)
+ accumulated_log_rnd.append(log_rnd)
+ accumulated_rewards.append(final_rewards)
+ total_accumulated += n_valid
+
+ if rank == 0:
+ qm = self._get_quality_mode()
+ print(f"[BUFFER] rank={rank} epoch={self.current_epoch} quality_mode={qm} accumulated={total_accumulated} / {samples_per_gpu} (batch yielded {n_valid} valid) attempt={attempts}")
+
+ if total_accumulated == 0:
+ raise RuntimeError(f"[BUFFER ERROR] Rank {rank}: No valid sequences generated after {attempts} attempts. Check sampling function and reward model.")
+
+ if total_accumulated < samples_per_gpu:
+ print(f"[BUFFER WARNING] Rank {rank}: Only generated {total_accumulated}/{samples_per_gpu} sequences after {attempts} attempts")
+
+ new_x = torch.cat(accumulated_x, dim=0)[:samples_per_gpu]
+ new_log_rnd = torch.cat(accumulated_log_rnd, dim=0)[:samples_per_gpu]
+ new_rewards = torch.cat(accumulated_rewards, dim=0)[:samples_per_gpu]
+
+ del accumulated_x, accumulated_log_rnd, accumulated_rewards
+ torch.cuda.empty_cache()
+
+ # add to buffer: pool mode replaces a random subset, classic mode overwrites
+ if is_pool and not is_init:
+ actual_new = min(new_x.shape[0], self.x_saved.shape[0])
+ indices = torch.randperm(self.x_saved.shape[0], device=self.x_saved.device)[:actual_new]
+ self.x_saved[indices] = new_x[:actual_new]
+ self.log_rnd_saved[indices] = new_log_rnd[:actual_new]
+ self.final_rewards_saved[indices] = new_rewards[:actual_new]
+ if rank == 0:
+ print(f"[POOL] Replaced {actual_new}/{self.x_saved.shape[0]} molecules, reward mean={self.final_rewards_saved.mean():.4f}")
+ else:
+ self.x_saved = new_x
+ self.log_rnd_saved = new_log_rnd
+ self.final_rewards_saved = new_rewards
+
+ if rank == 0:
+ print(f"[BUFFER] After cleanup - GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
+
+ def training_step(self, batch, batch_idx):
+ """Training step - batch is ignored, we use saved buffer."""
+ # Process buffer in mini-batches to avoid OOM
+ mini_batch_size = getattr(self.args, 'training_mini_batch_size', 8)
+ buffer_size = self.x_saved.shape[0]
+
+ # Randomly sample a mini-batch from buffer
+ indices = torch.randperm(buffer_size, device=self.x_saved.device)[:mini_batch_size]
+ x_final = self.x_saved[indices]
+
+ # get log_rnd values
+ log_rnd = self.log_rnd_saved[indices]
+
+ sm_temp = getattr(self.args, 'softmax_temperature', 1.0)
+
+ joint = getattr(self.args, 'joint_training', False)
+ policy_loss = None
+ planner_loss = None
+
+ if self.train_policy:
+ # Train policy with WDCE loss
+ policy_loss = self.policy_model.loss_wdce_flexible(
+ log_rnd,
+ x_final,
+ num_replicates=self.args.wdce_num_replicates,
+ centering=self.args.centering,
+ centering_strength=self.args.centering_strength,
+ softmax_temperature=sm_temp,
+ )
+ self.log('train/policy_loss', policy_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+
+ if (not self.train_policy) or joint:
+ # Train planner with appropriate loss based on ablation flags
+ if self.args.disable_insertion_planner:
+ # Ablation: only train unmasking planner (no insertion head)
+ planner_loss = self.policy_model.loss_planner_flexible(
+ log_rnd,
+ x_final,
+ num_replicates=self.args.wdce_num_replicates,
+ centering=self.args.centering,
+ centering_strength=self.args.centering_strength,
+ softmax_temperature=sm_temp,
+ )
+ self.log('train/planner_unmask_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ self.log('train/planner_insert_loss', 0.0, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ self.log('train/planner_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ elif self.args.disable_unmasking_planner:
+ # only train insertion planner (no remasking head)
+ unmask_loss, insert_loss, _ = self.policy_model.loss_insert_planner_flexible(
+ log_rnd,
+ x_final,
+ num_replicates=self.args.wdce_num_replicates,
+ centering=self.args.centering,
+ centering_strength=self.args.centering_strength,
+ softmax_temperature=sm_temp,
+ )
+ # Zero out the unmasking component - only backprop insertion loss
+ planner_loss = insert_loss
+ self.log('train/planner_unmask_loss', 0.0, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ self.log('train/planner_insert_loss', insert_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ self.log('train/planner_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ else:
+ # Full planner: train both remasking + insertion
+ unmask_loss, insert_loss, planner_loss = self.policy_model.loss_insert_planner_flexible(
+ log_rnd,
+ x_final,
+ num_replicates=self.args.wdce_num_replicates,
+ centering=self.args.centering,
+ centering_strength=self.args.centering_strength,
+ softmax_temperature=sm_temp,
+ )
+ self.log('train/planner_unmask_loss', unmask_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ self.log('train/planner_insert_loss', insert_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ self.log('train/planner_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+
+ # Combine losses depending on mode
+ if joint:
+ loss = policy_loss + planner_loss
+ mode_value = 0.5
+ elif self.train_policy:
+ loss = policy_loss
+ mode_value = 0.0
+ else:
+ loss = planner_loss
+ mode_value = 1.0
+
+ # Log overall loss and mode
+ self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ self.log('train/mode', mode_value, prog_bar=True, sync_dist=True)
+
+ return loss
+
+ def on_train_epoch_end(self):
+ """Called at the end of each training epoch - only rank 0 evaluates."""
+ # Only evaluate every N epochs to save time
+ eval_frequency = getattr(self.args, 'eval_every_n_epochs', 5)
+ is_last_epoch = (self.trainer and self.current_epoch == self.trainer.max_epochs - 1)
+ if self.global_rank == 0 and (self.current_epoch % eval_frequency == 0 or is_last_epoch):
+ # Sample eval batch with updated policy
+ x_eval, qed, sa, uniqueness, diversity, quality, valid_fraction = \
+ sample_mol_eval(
+ self.policy_model, self.reward_model,
+ self.tokenizer,
+ steps=self.args.total_num_steps,
+ mask=self.policy_model.interpolant.mask_token,
+ pad=self.policy_model.interpolant.pad_token,
+ batch_size=50,
+ max_length=self.args.max_length,
+ quality_mode=self._get_quality_mode(),
+ num_remasking=self.args.num_remasking,
+ evaluator=self.evaluator,
+ )
+
+ # Append to logs
+ self.valid_fraction_log.append(valid_fraction)
+ self.uniqueness_log.append(uniqueness)
+ self.diversity_log.append(diversity)
+ self.qed_log.append(qed)
+ self.sa_log.append(sa)
+ self.quality_log.append(quality)
+
+ # Compute reward stats
+ mean_reward = self.final_rewards_saved.mean().item()
+ min_reward = self.final_rewards_saved.min().item()
+ max_reward = self.final_rewards_saved.max().item()
+ median_reward = self.final_rewards_saved.median().item()
+
+ # Log metrics
+ self.log_dict({
+ "eval/valid_fraction": valid_fraction,
+ "eval/uniqueness": np.mean(uniqueness),
+ "eval/diversity": np.mean(diversity),
+ "eval/qed": np.mean(qed),
+ "eval/sa": np.mean(sa),
+ "eval/quality": np.mean(quality),
+ "eval/mean_reward_search": mean_reward,
+ "eval/min_reward_search": min_reward,
+ "eval/max_reward_search": max_reward,
+ "eval/median_reward_search": median_reward
+ })
+
+ print(f"epoch {self.current_epoch} | validity {valid_fraction:.4f} | uniqueness {np.mean(uniqueness):.4f} | diversity {np.mean(diversity):.4f} | "
+ f"QED {np.mean(qed):.4f} | SA {np.mean(sa):.4f} | quality {np.mean(quality):.4f} | ")
+
+ def on_fit_end(self):
+ """Called at the end of training - save results."""
+ if self.global_rank == 0:
+ # Save logs and plot
+ base_path = self.args.base_path
+ plot_path = f'{base_path}/results/{self.args.run_name}'
+ os.makedirs(plot_path, exist_ok=True)
+
+ output_log_path = f'{plot_path}/log_{self.filename}.csv'
+ save_logs_to_file(self.valid_fraction_log, self.uniqueness_log,
+ self.diversity_log, self.qed_log, self.sa_log,
+ self.quality_log, output_log_path)
+
+ # Final generation
+ x_eval, qed, sa, valid_fraction, uniqueness, diversity, quality, df = \
+ sample_mol_eval(
+ self.policy_model, self.reward_model,
+ self.tokenizer,
+ steps=self.args.total_num_steps,
+ mask=self.policy_model.interpolant.mask_token,
+ pad=self.policy_model.interpolant.pad_token,
+ batch_size=50,
+ max_length=self.args.max_length,
+ quality_mode=self._get_quality_mode(),
+ num_remasking=self.args.num_remasking,
+ evaluator=self.evaluator,
+ dataframe=True,
+ )
+ df.to_csv(f'{plot_path}/mol_generation_results.csv', index=False)
+
+
+def save_logs_to_file(valid_fraction_log, uniqueness_log,
+ diversity_log, qed_log, sa_log,
+ quality_log, output_path):
+ """
+ Saves the logs to a CSV file.
+ """
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+ log_data = {
+ "Iteration": list(range(1, len(valid_fraction_log) + 1)),
+ "Valid Fraction": valid_fraction_log,
+ "Uniqueness": uniqueness_log,
+ "Diversity": diversity_log,
+ "QED": qed_log,
+ "Synthetic Accessibility": sa_log,
+ "Quality": quality_log
+ }
+
+ df = pd.DataFrame(log_data)
+ df.to_csv(output_path, index=False)
+
+
+class DummyDataset(torch.utils.data.Dataset):
+ """Dummy dataset for Lightning trainer (we use buffer instead)."""
+ def __init__(self, size=100):
+ self.size = size
+
+ def __len__(self):
+ return self.size
+
+ def __getitem__(self, idx):
+ return torch.zeros(1) # Dummy data
+
+
+def main():
+ """Main entry point for distributed training."""
+ # Disable DDP optimizer for higher-order ops like flex_attention
+ import torch._dynamo
+ torch._dynamo.config.optimize_ddp = False
+
+ argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ argparser.add_argument('--base_path', type=str, default=REPO_ROOT)
+ argparser.add_argument('--learning_rate', type=float, default=1e-4)
+ argparser.add_argument('--num_epochs', type=int, default=100)
+ argparser.add_argument('--num_accum_steps', type=int, default=4)
+ argparser.add_argument('--truncate_steps', type=int, default=50)
+ argparser.add_argument("--truncate_kl", type=str2bool, default=False)
+ argparser.add_argument('--gumbel_temp', type=float, default=1.0)
+ argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
+ argparser.add_argument('--batch_size', type=int, default=50)
+ argparser.add_argument('--name', type=str, default='debug')
+ argparser.add_argument('--total_num_steps', type=int, default=128)
+ argparser.add_argument('--copy_flag_temp', type=float, default=None)
+ argparser.add_argument('--save_every_n_epochs', type=int, default=10)
+ argparser.add_argument('--eval_every_n_epochs', type=int, default=5, help='Evaluate only every N epochs to save time')
+ argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
+ argparser.add_argument("--seed", type=int, default=0)
+ # new
+ argparser.add_argument('--run_name', type=str, default='mol')
+ argparser.add_argument("--save_path_dir", default="", type=str)
+ # mcts
+ argparser.add_argument('--num_sequences', type=int, default=10)
+ argparser.add_argument('--max_length', type=int, default=1024)
+ argparser.add_argument('--num_children', type=int, default=50)
+ argparser.add_argument('--num_iter', type=int, default=30) # iterations of mcts
+ argparser.add_argument('--seq_length', type=int, default=1024)
+ argparser.add_argument('--time_conditioning', action='store_true', default=False)
+ argparser.add_argument('--mcts_sampling', type=int, default=0) # for batched categorical sampling: '0' means gumbel noise
+ argparser.add_argument('--buffer_size', type=int, default=100)
+ argparser.add_argument('--wdce_num_replicates', type=int, default=16)
+ argparser.add_argument('--noise_removal', action='store_true', default=False)
+ argparser.add_argument('--grad_clip', action='store_true', default=False)
+ argparser.add_argument('--resample_every_n_step', type=int, default=3)
+ argparser.add_argument('--exploration', type=float, default=0.1)
+ argparser.add_argument('--reset_every_n_step', type=int, default=100)
+ argparser.add_argument('--alpha', type=float, default=0.01)
+ argparser.add_argument('--scalarization', type=str, default='sum')
+ argparser.add_argument('--no_mcts', action='store_true', default=False)
+ argparser.add_argument("--centering", action='store_true', default=False)
+ argparser.add_argument("--centering_strength", type=float, default=1.0)
+
+ # adaptive schedule parameters
+ argparser.add_argument('--use_adaptive_schedule', action='store_true', default=True)
+ argparser.add_argument('--schedule_hidden_dim', type=int, default=256)
+ argparser.add_argument('--schedule_num_layers', type=int, default=2)
+ argparser.add_argument('--schedule_loss_weight', type=float, default=0.1)
+ argparser.add_argument('--adaptive_threshold', type=float, default=0.5)
+ argparser.add_argument('--freeze_base_model', action='store_true', default=False)
+ argparser.add_argument('--schedule_warmup_epochs', type=int, default=20, help='Number of initial epochs to train WITHOUT remasking in buffer generation')
+ argparser.add_argument('--alternation_frequency', type=int, default=5, help='Number of epochs to train each model before alternating (1=alternate every epoch)')
+ argparser.add_argument('--planner_learning_rate', type=float, default=None, help='Separate learning rate for planner heads (defaults to --learning_rate if not set)')
+
+ # objectives
+ argparser.add_argument('--num_obj', type=int, default=2)
+ argparser.add_argument('--devices', type=int, default=-1)
+ argparser.add_argument('--checkpoint_path', type=str, default=None)
+
+ # ELBO-based log_rnd estimation
+ argparser.add_argument('--elbo_rnd', action='store_true', default=False,
+ help='If set, compute log_rnd via forward ELBO instead of trajectory rollout')
+ argparser.add_argument('--elbo_rnd_num_samples', type=int, default=4,
+ help='Number of noisy time samples per sequence for ELBO-based log_rnd estimation')
+
+ # remasking
+ argparser.add_argument('--num_remasking', type=int, default=5)
+ argparser.add_argument('--quality_threshold', type=float, default=1)
+ argparser.add_argument('--use_quality_filter', action='store_true', help='If set, filter buffer to only include molecules with QED>=0.6 and SA<=4')
+ argparser.add_argument('--training_mini_batch_size', type=int, default=8, help='Mini-batch size for training step to avoid OOM')
+ argparser.add_argument('--disable_planner', action='store_true', help='If set, disable remasking completely and only train policy (not planner) for quality optimization')
+ argparser.add_argument('--disable_insertion_planner', action='store_true', help='Ablation: disable insertion quality filtering but keep unmasking/remasking planner')
+ argparser.add_argument('--disable_unmasking_planner', action='store_true', help='Ablation: disable unmasking/remasking planner but keep insertion quality filtering')
+ argparser.add_argument('--joint_training', action='store_true', help='Ablation: train policy and planner jointly each step (no alternation, both unfrozen, summed loss). Incompatible with --disable_planner.')
+ argparser.add_argument('--qed_only', action='store_true', help='If set, optimize only for QED score (no SA)')
+ argparser.add_argument('--softmax_temperature', type=float, default=1.0,
+ help='Temperature for softmax on importance weights (>1 smooths, prevents concentration)')
+ argparser.add_argument('--pool_size', type=int, default=0,
+ help='If >0, maintain a persistent pool of this size and refresh a fraction each resample step (0=disabled, classic buffer)')
+ argparser.add_argument('--pool_refresh_fraction', type=float, default=0.2,
+ help='Fraction of pool to replace each resample step (only used when pool_size>0)')
+ argparser.add_argument('--num_training_steps_per_epoch', type=int, default=10,
+ help='Number of gradient updates per epoch (1=original, 10=recommended)')
+
+ args = argparser.parse_args()
+
+ # Default planner LR to policy LR if not specified
+ if args.planner_learning_rate is None:
+ args.planner_learning_rate = args.learning_rate
+
+ # Set seed
+ pl.seed_everything(args.seed)
+
+ # Load models
+ checkpoint_path = args.checkpoint_path if args.checkpoint_path else \
+ os.path.join(REPO_ROOT, 'pretrained', 'anylength_mol.ckpt')
+
+ curr_time = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+ if args.no_mcts:
+ args.run_name = f'mol_al_resample{args.resample_every_n_step}_no-mcts_{curr_time}'
+ else:
+ args.run_name = f'mol_al_resample{args.resample_every_n_step}_buffer{args.buffer_size}_numiter{args.num_iter}_children{args.num_children}_{curr_time}'
+
+ # append ablation tags to run name for easy identification
+ if args.disable_planner:
+ args.run_name += '_no_planner'
+ if args.disable_insertion_planner:
+ args.run_name += '_no_insertion_planner'
+ if args.disable_unmasking_planner:
+ args.run_name += '_no_unmasking_planner'
+ if args.joint_training:
+ if args.disable_planner:
+ raise ValueError("--joint_training is incompatible with --disable_planner (no planner to train)")
+ args.run_name += '_joint_training'
+
+ args.save_path = os.path.join(args.save_path_dir, args.run_name)
+ os.makedirs(args.save_path, exist_ok=True)
+ set_seed(args.seed, use_cuda=False) # Don't init CUDA before Lightning spawns DDP workers
+
+ # Initialize the model
+ print("Loading models..")
+
+ # Load pretrained model for reference (frozen)
+ pretrained = AnyOrderInsertionFlowModule.load_from_checkpoint(checkpoint_path,
+ map_location='cpu',
+ weights_only=False)
+ pretrained.eval()
+ for param in pretrained.parameters():
+ param.requires_grad = False
+
+ # Load checkpoint to extract config
+ checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
+ if 'hyper_parameters' in checkpoint:
+ config = checkpoint['hyper_parameters']['config']
+ elif 'config' in checkpoint:
+ config = checkpoint['config']
+ else:
+ raise ValueError("Cannot find config in checkpoint")
+
+ # Update config for adaptive schedule
+ from omegaconf import OmegaConf
+ if not OmegaConf.is_config(config):
+ from omegaconf import DictConfig
+ config = DictConfig(config)
+
+ OmegaConf.set_struct(config, False)
+
+ config.training.use_adaptive_schedule = args.use_adaptive_schedule
+ config.training.schedule_hidden_dim = args.schedule_hidden_dim
+ config.training.schedule_num_layers = args.schedule_num_layers
+ config.training.schedule_loss_weight = args.schedule_loss_weight
+ config.training.freeze_base_model = args.freeze_base_model
+ config.training.schedule_warmup_epochs = args.schedule_warmup_epochs
+ config.training.use_bracket_safe = True
+
+ OmegaConf.set_struct(config, True)
+
+ # initialize policy model with adaptive schedule
+ policy_model = AnyOrderInsertionFlowModuleFT(
+ config=config,
+ args=args,
+ pretrained_checkpoint=checkpoint_path,
+ insertion_planner=True,
+ )
+
+ # define mcts
+ if args.qed_only:
+ score_func_names = ['qed']
+ else:
+ score_func_names = ['qed', 'sa']
+
+ tokenizer = get_tokenizer()
+
+ filename = args.run_name
+
+ # Device will be set by Lightning automatically in DDP
+ reward_model = MolScoringFunctions(score_func_names, device='cpu')
+ model = MolFinetuner(
+ args=args,
+ policy_model=policy_model,
+ reward_model=reward_model,
+ tokenizer=tokenizer,
+ pretrained=pretrained,
+ mcts=None,
+ filename=filename,
+ )
+
+ checkpoint_callback = ModelCheckpoint(
+ dirpath=args.save_path,
+ filename='model-{epoch:02d}-{train_loss:.4f}',
+ every_n_epochs=args.save_every_n_epochs,
+ save_top_k=-1, # Save all checkpoints
+ save_last=True, # Also save last.ckpt
+ auto_insert_metric_name=False
+ )
+
+ # Defaults to your default wandb entity; override with the WANDB_ENTITY env var.
+ wandb_logger = WandbLogger(entity=os.environ.get('WANDB_ENTITY'), project='a2d2-mol', name=args.run_name)
+
+ # create dummy dataloader
+ dataset = DummyDataset(size=args.num_training_steps_per_epoch)
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
+
+ # setup trainer with DDP
+ trainer = pl.Trainer(
+ max_epochs=args.num_epochs,
+ accelerator='gpu',
+ devices=args.devices,
+ strategy=DDPStrategy(find_unused_parameters=True) if args.devices != 1 else 'auto',
+ gradient_clip_val=args.gradnorm_clip if args.grad_clip else None,
+ logger=wandb_logger,
+ callbacks=[checkpoint_callback],
+ enable_progress_bar=True,
+ log_every_n_steps=1
+ )
+
+ # Train
+ trainer.fit(model, dataloader)
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/a2d2_mol/inference_quality_mol.py b/a2d2_mol/inference_quality_mol.py
new file mode 100644
index 0000000000000000000000000000000000000000..57e564f1a8ed4aa3570375097bd37c458abb8123
--- /dev/null
+++ b/a2d2_mol/inference_quality_mol.py
@@ -0,0 +1,554 @@
+"""Unified molecule sampling with quality-guided planning.
+
+Supports 4 quality modes and optional RND (importance weight) computation.
+
+Quality modes:
+ "none" - No planner, no remasking (policy-only)
+ "both" - Both unmasking + insertion planners active
+ "unmasking_only" - Only unmasking/remasking planner (insertion planner disabled)
+ "insertion_only" - Only insertion planner (unmasking planner disabled)
+
+RND toggle:
+ compute_rnd=True - Run pretrained model in parallel, compute step-wise log importance weights
+ compute_rnd=False - Run policy model only (use with ELBO-based RND or eval)
+"""
+
+import torch
+import numpy as np
+import pandas as pd
+import torch.nn.functional as F
+from sampling import SamplingResult, SamplingTraceDatapoint, _sample_tokens
+from remasking_scheduleaware import apply_schedule_aware_remasking, apply_schedule_aware_insertion
+from mol_utils.utils_chem import batch_safe_to_smiles, batch_validate_and_extract
+from tdc import Evaluator, Oracle
+
+QUALITY_MODES = {"none", "both", "unmasking_only", "insertion_only"}
+
+
+@torch.no_grad()
+def _diffusion_loop(
+ model, steps, mask, pad, batch_size, max_length,
+ quality_mode="both",
+ compute_rnd=False,
+ pretrained=None,
+ remasking_mode="schedule_aware",
+ num_remasking=1,
+ quality_threshold=1,
+ temperature=1.0,
+ return_trace=False,
+ unmask_quality_threshold=None,
+):
+ """Core discrete diffusion sampling loop for molecule generation.
+
+ Args:
+ model: Finetuned policy model.
+ steps: Number of diffusion steps.
+ mask: Mask token ID.
+ pad: Pad token ID.
+ batch_size: Number of sequences to generate.
+ max_length: Maximum sequence length.
+ quality_mode: One of "none", "both", "unmasking_only", "insertion_only".
+ compute_rnd: Whether to compute step-wise log importance weights.
+ pretrained: Frozen pretrained model (required if compute_rnd=True).
+ remasking_mode: Remasking strategy ("schedule_aware", "remdm", "remdm_conf").
+ num_remasking: Number of tokens to remask per step.
+ quality_threshold: Threshold for insertion quality filtering. None if schedule-driven.
+ temperature: Sampling temperature (1.0 = no scaling).
+ return_trace: Whether to record sampling trace.
+
+ Returns:
+ (xt, log_rnd, sampling_trace)
+ log_rnd is None when compute_rnd=False.
+ """
+ assert quality_mode in QUALITY_MODES, f"quality_mode must be one of {QUALITY_MODES}"
+ if compute_rnd:
+ assert pretrained is not None, "pretrained model required when compute_rnd=True"
+
+ # Derive flags from quality_mode
+ use_remasking = quality_mode != "none"
+ disable_unmasking_planner = quality_mode in ("none", "insertion_only")
+ disable_insertion_planner = quality_mode in ("none", "unmasking_only")
+
+ device = next(model.parameters()).device
+
+ # Initialize all-pad sequence
+ xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
+
+ dt = 1.0 / steps
+ t = torch.zeros(batch_size, device=device)
+
+ # Precompute index tensors
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+ sampling_trace = [[] for _ in range(batch_size)] if return_trace else None
+
+ neg_inf = torch.tensor(-np.inf, device=device)
+
+ if use_remasking and remasking_mode == "remdm_conf":
+ remasking_score = torch.zeros((batch_size, max_length), device=device)
+
+ log_rnd = None
+
+ for i in range(steps):
+ # --- Policy model forward ---
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # --- Pretrained model forward (for RND) ---
+ if compute_rnd:
+ pretrained_pred = pretrained(xt, t)
+ pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
+ pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V)
+ pretrained_len_rate = pretrained_rate.length_rate # (B, L+1)
+
+ # --- Unmask step (Euler) ---
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ if compute_rnd:
+ pretrained_unmask_rate[xt != mask] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)
+
+ # Add "stay" probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+ if compute_rnd:
+ pretrained_trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
+ )
+
+ # Temperature scaling
+ if temperature != 1.0:
+ logits = torch.log(trans_prob + 1e-10) / temperature
+ trans_prob = torch.softmax(logits, dim=-1)
+
+ # Final step: remove mask token from sampling
+ if i == steps - 1:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0
+
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+ if mask_has_zero_prob.any():
+ num_zero_prob = mask_has_zero_prob.sum().item()
+ uniform_prob = torch.zeros((num_zero_prob, trans_prob.shape[-1]), device=device, dtype=trans_prob.dtype)
+ uniform_prob[:, :mask] = 1.0 / mask
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ # Update remasking_score buffer for remdm_conf mode
+ if use_remasking and remasking_mode == "remdm_conf" and i < steps - 1:
+ token_probs = F.softmax(unmask_rate, dim=-1) # (B, L, V)
+ chosen_probs = torch.gather(token_probs, dim=-1, index=new_xt.unsqueeze(-1)).squeeze(-1) # (B, L)
+ changed_mask_to_token = (xt == mask) & (new_xt != mask) & (new_xt != pad)
+ remasking_score = torch.where(changed_mask_to_token, chosen_probs, remasking_score)
+
+ # --- Remasking step ---
+ if use_remasking and i < steps - 1:
+ if disable_unmasking_planner or not (hasattr(model, 'planner') and model.planner is not None):
+ remasking_conf = torch.zeros((batch_size, max_length), device=device)
+ else:
+ planner_out = model.planner(new_xt, t)
+ remasking_conf = planner_out["remasking_conf"].squeeze(-1) # (B, L)
+
+ clean_index = (new_xt != mask) & (new_xt != pad) # (B, L)
+
+ if remasking_mode == "schedule_aware":
+ new_xt = apply_schedule_aware_remasking(
+ model, new_xt, t, dt, remasking_conf, clean_index,
+ mask, neg_inf, batch_size,
+ unmask_quality_threshold=unmask_quality_threshold,
+ )
+ remasking_score_temp = None
+ else:
+ raise ValueError(f"Unknown remasking_mode: {remasking_mode}")
+
+ if remasking_score_temp is not None:
+ remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf)
+ for j in range(batch_size):
+ k = min(num_remasking, int(clean_index[j].sum().item()))
+ if k > 0:
+ _, select_indices = torch.topk(remasking_score_temp[j], k=k)
+ new_xt[j, select_indices] = mask
+
+ if return_trace:
+ for batch_idx in range(batch_size):
+ for pos in range(max_length):
+ if clean_index[batch_idx, pos] and new_xt[batch_idx, pos] == mask:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="change",
+ position=pos,
+ token=mask,
+ )
+ )
+
+ # --- Compute log probabilities for RND ---
+ if compute_rnd:
+ lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+ lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+
+ changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)
+
+ log_policy_step = (lp * changed_mask).sum(dim=1)
+ log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
+
+ log_rnd = log_pretrained_step - log_policy_step # (B,)
+
+ # --- Insertion step ---
+ if i != steps - 1:
+ ext = torch.poisson(len_rate * dt).long() # (B, L+1)
+
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ gaps = torch.arange(max_length + 1, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+
+ # Schedule-aware insertion quality filtering
+ if use_remasking and not disable_insertion_planner:
+ if compute_rnd:
+ xt_tmp_before = xt_tmp.clone()
+
+ xt_tmp = apply_schedule_aware_insertion(
+ model, xt_tmp, new_xt, t, dt, ext, mask, pad, max_length,
+ orig_mask, new_pos_orig, quality_threshold
+ )
+
+ if compute_rnd:
+ # Compute corrected ext based on what actually stayed
+ ext_corrected = torch.zeros_like(ext)
+ for b in range(batch_size):
+ after_len = xt_tmp[b].ne(pad).sum().item()
+ orig_len = xt_len[b].item()
+ surviving_insertions = after_len - orig_len
+ if total_ext[b] > 0:
+ ratio = surviving_insertions / total_ext[b].item()
+ ext_corrected[b] = (ext[b].float() * ratio).long()
+ else:
+ ext_corrected = ext
+ else:
+ ext_corrected = ext
+
+ # Compute insertion log_rnd
+ if compute_rnd:
+ insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1)
+ pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1)
+
+ log_policy_insert = (ext_corrected * torch.log(insertion_rate) - insertion_rate).sum(dim=1)
+ log_pretrained_insert = (ext_corrected * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)
+
+ log_insert_diff = log_pretrained_insert - log_policy_insert
+ log_rnd += log_insert_diff
+ else:
+ xt_tmp = new_xt
+
+ if return_trace:
+ for batch_idx in range(batch_size):
+ for j in range(max_length):
+ if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="change",
+ position=j,
+ token=new_xt[batch_idx, j].item(),
+ )
+ )
+
+ if i != steps - 1:
+ for j in range(max_length):
+ id = max_length - j - 1
+ if ext[batch_idx, id]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="insertion",
+ position=id,
+ token=mask,
+ )
+ )
+
+ xt = xt_tmp
+ t = t + dt
+
+ return xt, log_rnd, sampling_trace
+
+
+def _decode_and_validate(model, tokenizer, samples):
+ """Decode token IDs to SMILES and validate.
+
+ Returns:
+ (validSequences, valid_indices): list of valid SMILES, list of batch indices.
+ """
+ decoded_samples = tokenizer.batch_decode(samples, skip_special_tokens=True)
+
+ use_bracket_safe = model.config.training.get('use_bracket_safe', False)
+ smiles_samples = batch_safe_to_smiles(decoded_samples, use_bracket_safe=use_bracket_safe, fix=True)
+
+ # Extract valid sequences (take largest fragment)
+ validSequences = []
+ valid_indices = []
+ for idx, s in enumerate(smiles_samples):
+ if s:
+ largest_frag = sorted(s.split('.'), key=len)[-1]
+ validSequences.append(largest_frag)
+ valid_indices.append(idx)
+
+ return validSequences, valid_indices
+
+
+@torch.no_grad()
+def sample_mol_buffer(
+ model, pretrained, reward_model, tokenizer,
+ steps, mask, pad, batch_size, max_length,
+ quality_mode="both",
+ alpha=0.1,
+ remasking_mode="schedule_aware",
+ num_remasking=1,
+ quality_threshold=1,
+ temperature=1.0,
+ use_quality_filter=True,
+):
+ """Generate molecules for training buffer. Always computes step-wise RND.
+
+ Args:
+ model: Finetuned policy model.
+ pretrained: Frozen pretrained model.
+ reward_model: Molecule scoring function.
+ tokenizer: SAFE tokenizer for decoding.
+ steps: Number of diffusion steps.
+ mask: Mask token ID.
+ pad: Pad token ID.
+ batch_size: Number of sequences to generate.
+ max_length: Maximum sequence length.
+ quality_mode: "none", "both", "unmasking_only", or "insertion_only".
+ alpha: RND scaling factor.
+ remasking_mode: Remasking strategy.
+ num_remasking: Number of tokens to remask per step.
+ quality_threshold: Threshold for insertion quality filtering. None if schedule-driven.
+ temperature: Sampling temperature.
+ use_quality_filter: If True, filter to QED>=0.6 and SA<=4.
+
+ Returns:
+ (valid_x, log_rnd, scalar_rewards, sampling_trace)
+ """
+ xt, log_rnd, trace = _diffusion_loop(
+ model, steps, mask, pad, batch_size, max_length,
+ quality_mode=quality_mode,
+ compute_rnd=True,
+ pretrained=pretrained,
+ remasking_mode=remasking_mode,
+ num_remasking=num_remasking,
+ quality_threshold=quality_threshold,
+ temperature=temperature,
+ )
+
+ device = xt.device
+ samples = xt.to(device)
+
+ validSequences, valid_indices = _decode_and_validate(model, tokenizer, samples)
+
+ valid_x_final = [samples[idx] for idx in valid_indices]
+ valid_log_rnd = [log_rnd[idx] for idx in valid_indices]
+
+ print("len valid sequences:", len(validSequences))
+
+ if len(validSequences) == 0:
+ print("[WARNING] No valid molecules generated in this batch")
+ empty_x = torch.empty((0, max_length), dtype=torch.long, device=device)
+ empty_log_rnd = torch.empty((0,), dtype=torch.float32, device=device)
+ empty_rewards = torch.empty((0,), dtype=torch.float32, device=device)
+ return empty_x, empty_log_rnd, empty_rewards, trace
+
+ # Compute multi-objective rewards
+ score_vectors = reward_model(input_seqs=validSequences)
+ scalar_rewards = np.sum(score_vectors, axis=-1)
+ scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device)
+
+ print(f"scalar reward dim{len(scalar_rewards)}")
+ valid_log_rnd = torch.stack(valid_log_rnd, dim=0)
+
+ log_rnd = valid_log_rnd + (scalar_rewards / alpha)
+ valid_x_final = torch.stack(valid_x_final, dim=0)
+
+ # Optionally filter to only keep quality sequences (QED >= 0.6 and SA <= 4)
+ if use_quality_filter:
+ qed_scores = score_vectors[:, 0]
+ if score_vectors.shape[1] > 1:
+ sa_scores = score_vectors[:, 1]
+ else:
+ _oracle_sa = Oracle('sa')
+ raw_sa = np.array(_oracle_sa(validSequences))
+ sa_scores = raw_sa
+ quality_mask = (qed_scores >= 0.6) & (sa_scores <= 4)
+
+ n_quality = quality_mask.sum()
+ print(f"Quality filtering: {n_quality}/{len(validSequences)} sequences pass (QED>=0.6, SA<=4)")
+
+ if n_quality == 0:
+ print("[WARNING] No quality molecules in this batch")
+ empty_x = torch.empty((0, max_length), dtype=torch.long, device=device)
+ empty_log_rnd = torch.empty((0,), dtype=torch.float32, device=device)
+ empty_rewards = torch.empty((0,), dtype=torch.float32, device=device)
+ return empty_x, empty_log_rnd, empty_rewards, trace
+
+ quality_mask_torch = torch.as_tensor(quality_mask, dtype=torch.bool, device=device)
+
+ quality_x_final = valid_x_final[quality_mask_torch]
+ quality_log_rnd = log_rnd[quality_mask_torch]
+ quality_rewards = scalar_rewards[quality_mask_torch]
+ else:
+ print(f"No quality filtering applied - using all {len(validSequences)} valid molecules")
+ quality_x_final = valid_x_final
+ quality_log_rnd = log_rnd
+ quality_rewards = scalar_rewards
+
+ return quality_x_final, quality_log_rnd, quality_rewards, trace
+
+
+@torch.no_grad()
+def sample_mol_eval(
+ model, reward_model, tokenizer,
+ steps, mask, pad, batch_size, max_length,
+ quality_mode="both",
+ remasking_mode="schedule_aware",
+ num_remasking=1,
+ quality_threshold=1,
+ temperature=1.0,
+ evaluator=None,
+ dataframe=False,
+ unmask_quality_threshold=None,
+):
+ """Generate molecules for evaluation.
+
+ Args:
+ model: Finetuned policy model.
+ reward_model: Molecule scoring function.
+ tokenizer: SAFE tokenizer for decoding.
+ steps: Number of diffusion steps.
+ mask: Mask token ID.
+ pad: Pad token ID.
+ batch_size: Number of sequences to generate.
+ max_length: Maximum sequence length.
+ quality_mode: "none", "both", "unmasking_only", or "insertion_only".
+ remasking_mode: Remasking strategy.
+ num_remasking: Number of tokens to remask per step.
+ quality_threshold: Threshold for insertion quality filtering. Pass None
+ to use schedule-driven deletion with no threshold gate
+ temperature: Sampling temperature.
+ evaluator: TDC Evaluator for diversity (created if None).
+ dataframe: If True, include a pandas DataFrame in the return.
+
+ Returns:
+ Without dataframe:
+ (validSequences, qed, sa, uniqueness, diversity, quality, valid_fraction)
+ With dataframe:
+ (validSequences, qed, sa, valid_fraction, uniqueness, diversity, quality, df)
+ validSequences is the raw list including duplicates; qed/sa are scored
+ on the unique set. Caller can dedup with set(validSequences). The
+ dataframe (when requested) has one row per unique molecule.
+ """
+ if evaluator is None:
+ evaluator = Evaluator('diversity')
+
+ xt, _, trace = _diffusion_loop(
+ model, steps, mask, pad, batch_size, max_length,
+ quality_mode=quality_mode,
+ compute_rnd=False,
+ remasking_mode=remasking_mode,
+ num_remasking=num_remasking,
+ quality_threshold=quality_threshold,
+ temperature=temperature,
+ unmask_quality_threshold=unmask_quality_threshold,
+ )
+
+ device = xt.device
+ samples = xt.to(device)
+
+ decoded_samples = tokenizer.batch_decode(samples, skip_special_tokens=True)
+
+ use_bracket_safe = model.config.training.get('use_bracket_safe', False)
+ smiles_samples = batch_safe_to_smiles(decoded_samples, use_bracket_safe=use_bracket_safe, fix=True)
+
+ # Extract valid sequences (take largest fragment)
+ validSequences = [sorted(s.split('.'), key=len)[-1] for s in smiles_samples if s]
+
+ print("len valid sequences:", len(validSequences))
+ valid_fraction = len(validSequences) / batch_size
+ uniqueSequences = list(set(validSequences))
+ uniqueness = len(uniqueSequences) / len(validSequences) if len(validSequences) > 0 else 0
+ diversity = evaluator(uniqueSequences) if len(uniqueSequences) > 0 else 0
+
+ # Calculate quality (unique sequences with QED >= 0.6 and SA <= 4)
+ if len(uniqueSequences) > 0:
+ score_vectors_temp = reward_model(input_seqs=list(uniqueSequences))
+ qed_scores = score_vectors_temp[:, 0] # Raw QED (0-1)
+
+ # Always use raw SA (1-10 scale) for quality filtering
+ _oracle_sa = Oracle('sa')
+ raw_sa_scores = np.array(_oracle_sa(list(uniqueSequences)))
+
+ quality_count = sum((qed_scores >= 0.6) & (raw_sa_scores <= 4))
+ quality = quality_count / batch_size
+ print(f'Quality:\t{quality}')
+
+ qed = qed_scores
+ sa = raw_sa_scores
+ else:
+ zeros = [0.0]
+ qed = zeros
+ sa = zeros
+ quality = 0.0
+
+ if dataframe:
+ df = pd.DataFrame({
+ "Mol Sequence": uniqueSequences,
+ "QED": qed if len(uniqueSequences) else [0.0],
+ "SA": sa if len(uniqueSequences) else [0.0],
+ })
+ return validSequences, qed, sa, valid_fraction, uniqueness, diversity, quality, df
+
+ return validSequences, qed, sa, uniqueness, diversity, quality, valid_fraction
diff --git a/a2d2_mol/mol_dataset.py b/a2d2_mol/mol_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdb4da02094d9865a89bf8f98e7c50dcd8a122c3
--- /dev/null
+++ b/a2d2_mol/mol_dataset.py
@@ -0,0 +1,379 @@
+#!/usr/bin/env python
+"""
+Adapter to use HuggingFace datasets with the any-length discrete diffusion model.
+This module converts HuggingFace datasets (like datamol-io/safe-drugs) into the format
+expected by the training pipeline.
+"""
+
+import torch
+from torch.utils.data import Dataset, DataLoader
+from datasets import load_dataset
+import pytorch_lightning as pl
+from safe.tokenizer import SAFETokenizer
+from mol_utils.bracket_safe_converter import safe2bracketsafe
+from typing import Optional, List
+import re
+
+
+def get_tokenizer():
+ """Get SAFE tokenizer with added special tokens."""
+ tk = SAFETokenizer.from_pretrained('datamol-io/safe-gpt').get_pretrained()
+ tk.add_tokens(['<', '>']) # for bracket_safe
+ return tk
+
+
+class Collator:
+ """Data collator for SAFE/bracket-SAFE format."""
+
+ def __init__(self, config, tokenizer=None):
+ self.tokenizer = tokenizer if tokenizer is not None else get_tokenizer()
+ self.max_length = config.interpolant.max_length
+ self.use_bracket_safe = config.training.get('use_bracket_safe', False)
+
+ def __call__(self, examples):
+ # Handle both dict with 'labels' and direct string format
+ inputs = []
+ for example in examples:
+ if isinstance(example, dict):
+ # Try different key names: 'input', 'labels', 'smiles'
+ input_text = example.get('input', example.get('labels', example.get('smiles', '')))
+ else:
+ input_text = example
+
+ if self.use_bracket_safe:
+ input_text = safe2bracketsafe(input_text)
+
+ inputs.append(input_text)
+
+ batch = self.tokenizer(
+ inputs,
+ return_tensors='pt',
+ padding=True,
+ truncation=True,
+ max_length=self.max_length
+ )
+
+ # Convert BatchEncoding to plain dict with tensors
+ # Remove token_type_ids if present (not needed for diffusion models)
+ result = {
+ 'input_ids': batch['input_ids'],
+ 'attention_mask': batch['attention_mask']
+ }
+
+ return result
+
+
+class HFDatasetAdapter(Dataset):
+ """Adapts HuggingFace datasets to the format expected by the diffusion model."""
+
+ def __init__(self, hf_dataset, tokenizer, smiles_column='smiles', max_length=1024, convert_to_safe=False, is_streaming=False):
+ """
+ Args:
+ hf_dataset: HuggingFace dataset object (streaming or regular)
+ tokenizer: SMILES tokenizer instance
+ smiles_column: Name of the column containing SMILES strings
+ max_length: Maximum sequence length
+ convert_to_safe: Whether to convert SMILES to SAFE format
+ is_streaming: Whether dataset is in streaming mode
+ """
+ self.tokenizer = tokenizer
+ self.smiles_column = smiles_column
+ self.max_length = max_length
+ self.convert_to_safe = convert_to_safe
+ self.is_streaming = is_streaming
+
+ if is_streaming:
+ # For streaming datasets, we don't pre-load the data
+ self.data = hf_dataset
+ self._length = None # Unknown length for streaming
+ print(f'Initialized streaming dataset adapter')
+ else:
+ # Store raw data without pre-tokenization (tokenization will happen in collator)
+ print(f'Initializing HF dataset adapter with {len(hf_dataset)} samples...')
+ self.data = []
+ for item in hf_dataset:
+ smiles = item[smiles_column]
+ if smiles: # Skip empty SMILES
+ self.data.append({'input': smiles, 'labels': smiles})
+ print(f'Processed {len(self.data)} valid samples')
+
+ def __len__(self):
+ if self.is_streaming:
+ # Streaming datasets don't have a length
+ # Return a large number to prevent issues with samplers
+ return 10_000_000 if self._length is None else self._length
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ if self.is_streaming:
+ # For streaming, iteration happens differently
+ raise NotImplementedError("Streaming datasets should be iterated, not indexed")
+ return self.data[idx]
+
+ def __iter__(self):
+ """Support iteration for streaming datasets."""
+ if self.is_streaming:
+ for item in self.data:
+ smiles = item[self.smiles_column]
+ if smiles: # Skip empty SMILES
+ yield {'input': smiles, 'labels': smiles}
+ else:
+ for item in self.data:
+ yield item
+
+
+class HFDataModule(pl.LightningDataModule):
+ """PyTorch Lightning DataModule for HuggingFace datasets."""
+
+ def __init__(
+ self,
+ config,
+ dataset_name: str,
+ tokenizer: SAFETokenizer,
+ smiles_column: str = 'smiles',
+ val_split: float = 0.1,
+ test_split: Optional[float] = None,
+ streaming: bool = True,
+ max_train_samples: Optional[int] = None,
+ max_val_samples: Optional[int] = None,
+ ):
+ """
+ Args:
+ config: Configuration object containing training parameters
+ dataset_name: HuggingFace dataset identifier (e.g., "datamol-io/safe-gpt")
+ tokenizer: SMILES tokenizer instance
+ smiles_column: Name of column containing SMILES strings
+ val_split: Fraction of data to use for validation
+ test_split: Optional fraction of data to use for testing
+ streaming: Whether to use streaming mode (recommended for large datasets)
+ max_train_samples: Maximum number of training samples to use (for non-streaming)
+ max_val_samples: Maximum number of validation samples to use (for non-streaming)
+ """
+ super().__init__()
+ self.config = config
+ self.dataset_name = dataset_name
+ self.tokenizer = tokenizer
+ self.smiles_column = smiles_column
+ self.max_length = config.interpolant.max_length
+ self.batch_size = config.training.per_gpu_batch_size
+ self.num_workers = config.training.get('cpus', 4)
+ self.val_split = val_split
+ self.test_split = test_split
+ self.streaming = streaming
+ self.max_train_samples = max_train_samples
+ self.max_val_samples = max_val_samples
+
+ self.train_dataset = None
+ self.val_dataset = None
+ self.test_dataset = None
+
+ # Initialize collator
+ self.collator = Collator(config, tokenizer)
+
+ def setup(self, stage: Optional[str] = None):
+ """Load and split the dataset."""
+ print(f'Loading dataset: {self.dataset_name} (streaming={self.streaming})')
+
+ if self.streaming:
+ # Load dataset in streaming mode
+ raw_dataset = load_dataset(self.dataset_name, streaming=True)
+
+ # Handle different dataset structures
+ if 'train' in raw_dataset:
+ train_stream = raw_dataset['train']
+ else:
+ # If no splits exist, use the entire dataset
+ train_stream = raw_dataset[list(raw_dataset.keys())[0]]
+
+ # For streaming, we need to manually split train/val
+ # Skip validation samples, then take training samples
+ val_size = int(100000 * self.val_split) # Assume ~100k samples for val split calculation
+ train_size = 100000 - val_size
+
+ # Create validation stream (take first val_size samples)
+ val_stream = train_stream.take(val_size)
+
+ # Create training stream (skip val_size samples, then iterate)
+ train_stream_shifted = train_stream.skip(val_size)
+
+ # Create adapted datasets
+ self.train_dataset = HFDatasetAdapter(
+ train_stream_shifted,
+ self.tokenizer,
+ self.smiles_column,
+ self.max_length,
+ is_streaming=True
+ )
+
+ self.val_dataset = HFDatasetAdapter(
+ val_stream,
+ self.tokenizer,
+ self.smiles_column,
+ self.max_length,
+ is_streaming=True
+ )
+
+ print(f'Streaming dataset initialized - samples will be loaded on-the-fly')
+
+ else:
+ # Traditional non-streaming mode with full dataset loading
+ raw_dataset = load_dataset(self.dataset_name)
+
+ # Handle different dataset structures
+ if 'train' in raw_dataset:
+ train_data = raw_dataset['train']
+ else:
+ # If no splits exist, use the entire dataset and split it
+ train_data = raw_dataset[list(raw_dataset.keys())[0]]
+
+ # Limit samples if specified
+ if self.max_train_samples:
+ train_data = train_data.select(range(min(self.max_train_samples, len(train_data))))
+
+ # Check if dataset already has validation split
+ if 'validation' in raw_dataset or 'val' in raw_dataset:
+ val_key = 'validation' if 'validation' in raw_dataset else 'val'
+ val_data = raw_dataset[val_key]
+ else:
+ # Create train/val split
+ split_dataset = train_data.train_test_split(test_size=self.val_split, seed=42)
+ train_data = split_dataset['train']
+ val_data = split_dataset['test']
+
+ # Limit validation samples if specified
+ if self.max_val_samples:
+ val_data = val_data.select(range(min(self.max_val_samples, len(val_data))))
+
+ # Create test split if requested
+ if self.test_split and 'test' not in raw_dataset:
+ split_dataset = train_data.train_test_split(test_size=self.test_split, seed=42)
+ train_data = split_dataset['train']
+ self.test_dataset = HFDatasetAdapter(
+ split_dataset['test'],
+ self.tokenizer,
+ self.smiles_column,
+ self.max_length,
+ is_streaming=False
+ )
+ elif 'test' in raw_dataset:
+ self.test_dataset = HFDatasetAdapter(
+ raw_dataset['test'],
+ self.tokenizer,
+ self.smiles_column,
+ self.max_length,
+ is_streaming=False
+ )
+
+ # Create adapted datasets
+ self.train_dataset = HFDatasetAdapter(
+ train_data,
+ self.tokenizer,
+ self.smiles_column,
+ self.max_length,
+ is_streaming=False
+ )
+
+ self.val_dataset = HFDatasetAdapter(
+ val_data,
+ self.tokenizer,
+ self.smiles_column,
+ self.max_length,
+ is_streaming=False
+ )
+
+ print(f'Dataset splits - Train: {len(self.train_dataset)}, Val: {len(self.val_dataset)}')
+ if self.test_dataset:
+ print(f'Test: {len(self.test_dataset)}')
+
+ def train_dataloader(self):
+ if self.streaming:
+ # Pass streaming dataset directly to DataLoader (HF IterableDataset)
+ # Must use num_workers=0 when using .skip() or .take() operations
+ return DataLoader(
+ self.train_dataset.data, # Use the raw HF streaming dataset
+ batch_size=self.batch_size,
+ collate_fn=self.collator,
+ num_workers=0, # Required for streaming with skip/take operations
+ pin_memory=True,
+ shuffle=False, # Cannot shuffle streaming datasets
+ )
+ else:
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ collate_fn=self.collator,
+ shuffle=True,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ persistent_workers=True if self.num_workers > 0 else False
+ )
+
+ def val_dataloader(self):
+ if self.streaming:
+ # Pass streaming dataset directly to DataLoader (HF IterableDataset)
+ # Must use num_workers=0 when using .skip() or .take() operations
+ return DataLoader(
+ self.val_dataset.data, # Use the raw HF streaming dataset
+ batch_size=self.batch_size,
+ collate_fn=self.collator,
+ num_workers=0, # Required for streaming with skip/take operations
+ pin_memory=True,
+ shuffle=False, # Cannot shuffle streaming datasets
+ )
+ else:
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.batch_size,
+ collate_fn=self.collator,
+ shuffle=False,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ persistent_workers=True if self.num_workers > 0 else False
+ )
+
+ def test_dataloader(self):
+ if self.test_dataset:
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ collate_fn=self.collator,
+ shuffle=False,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ persistent_workers=True if self.num_workers > 0 else False
+ )
+ return None
+
+
+def setup_hf_data_and_update_config(config, dataset_name="datamol-io/safe-gpt", smiles_column="smiles", streaming=True):
+ """
+ Setup HuggingFace dataset and update config with token information.
+
+ Args:
+ config: Hydra config object
+ dataset_name: HuggingFace dataset identifier
+ smiles_column: Name of column containing SMILES strings
+ streaming: Whether to use streaming mode (recommended for large datasets like safe-gpt)
+
+ Returns:
+ HFDataModule instance
+ """
+ # Initialize tokenizer
+ tokenizer = get_tokenizer()
+
+ # Update config with tokenizer info
+ config.interpolant.tokens = len(tokenizer)
+ config.interpolant.pad_token = tokenizer.pad_token_id
+ config.interpolant.mask_token = tokenizer.mask_token_id
+
+ # Create data module
+ data_module = HFDataModule(
+ config=config,
+ dataset_name=dataset_name,
+ tokenizer=tokenizer,
+ smiles_column=smiles_column,
+ val_split=0.1,
+ streaming=streaming,
+ )
+
+ return data_module
diff --git a/a2d2_mol/mol_scoring/oracle/fpscores.pkl b/a2d2_mol/mol_scoring/oracle/fpscores.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..24e7f60a5b12606184f8642bb5ac9a84c291484b
--- /dev/null
+++ b/a2d2_mol/mol_scoring/oracle/fpscores.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24a4392f5c673e79c0446af3c4d8e458293b5fecaa244328e76741ead9d21dbf
+size 9048931
diff --git a/a2d2_mol/mol_scoring/scoring_functions.py b/a2d2_mol/mol_scoring/scoring_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f18620debb54206868ad9a07e2769274d08bb40
--- /dev/null
+++ b/a2d2_mol/mol_scoring/scoring_functions.py
@@ -0,0 +1,68 @@
+from transformers import AutoModelForMaskedLM
+import numpy as np
+from tdc import Oracle, Evaluator
+
+class MolScoringFunctions:
+ def __init__(self, score_func_names=None, device=None, sa_transform='inverse'):
+ """
+ Class for generating score vectors given generated sequence
+
+ Args:
+ score_func_names: list of scoring function names to be evaluated
+ score_weights: weights to scale scores (default: 1)
+ sa_transform: how to transform SA scores to higher-is-better ~[0,1]:
+ 'inverse' (default): 1/(1+SA) — range ~0.09-0.5, weak gradient
+ 'linear': (10-SA)/9 — range ~0-1, stronger gradient
+ """
+ if score_func_names is None:
+ # just do unmasking based on validity of peptide bonds
+ self.score_func_names = []
+ else:
+ self.score_func_names = score_func_names
+
+ self.sa_transform = sa_transform
+
+ oracle_qed = Oracle('qed')
+ oracle_sa = Oracle('sa')
+
+ self.all_funcs = {'qed': oracle_qed,
+ 'sa': oracle_sa,
+ }
+
+ def forward(self, input_seqs):
+ scores = []
+
+ for i, score_func in enumerate(self.score_func_names):
+ score = self.all_funcs[score_func](input_seqs)
+
+ # Transform SA to be maximized and normalized (original SA: 1-10, lower is better)
+ # Convert to: higher is better, normalized to ~0-1 range like QED
+ if score_func == 'sa':
+ if self.sa_transform == 'linear':
+ score = (10.0 - np.array(score)) / 9.0 # range ~0-1, clipped at 0
+ score = np.maximum(score, 0.0)
+ else:
+ score = 1.0 / (1.0 + np.array(score)) # range ~0.09-0.5
+
+ scores.append(score)
+
+ # convert to numpy arrays with shape (num_sequences, num_functions)
+ scores = np.float32(scores).T
+
+ return scores
+
+ def __call__(self, input_seqs: list):
+ return self.forward(input_seqs)
+
+
+def unittest():
+ scoring = MolScoringFunctions(score_func_names=['qed', 'sa'])
+
+ smiles = ['CCOc1cc(ccc1NC(=O)N[C@@H]2CCCC[C@@H]2O)F']
+
+ scores = scoring(input_seqs=smiles)
+ print(scores)
+ print(len(scores))
+
+if __name__ == '__main__':
+ unittest()
\ No newline at end of file
diff --git a/a2d2_mol/mol_utils/bracket_safe_converter.py b/a2d2_mol/mol_utils/bracket_safe_converter.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ccc9c5d0fdf8ba3ddd55dbb60838e74acacb452
--- /dev/null
+++ b/a2d2_mol/mol_utils/bracket_safe_converter.py
@@ -0,0 +1,159 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Patch: stub out `auto_docstring` if missing from transformers.utils
+# (needed by safe.trainer.model in newer safe versions)
+import transformers.utils as _tu
+if not hasattr(_tu, 'auto_docstring'):
+ _tu.auto_docstring = lambda *a, **kw: (lambda fn: fn)
+
+from safe.converter import *
+
+
+class BracketSAFEConverter(SAFEConverter):
+ def encoder(
+ self,
+ inp: Union[str, dm.Mol],
+ canonical: bool = True,
+ randomize: Optional[bool] = False,
+ seed: Optional[int] = None,
+ constraints: Optional[List[dm.Mol]] = None,
+ allow_empty: bool = False,
+ rdkit_safe: bool = True,
+ ):
+ rng = None
+ if randomize:
+ rng = np.random.default_rng(seed)
+ if not canonical:
+ inp = dm.to_mol(inp, remove_hs=False)
+ inp = self.randomize(inp, rng)
+
+ if isinstance(inp, dm.Mol):
+ inp = dm.to_smiles(inp, canonical=canonical, randomize=False, ordered=False)
+
+ branch_numbers = self._find_branch_number(inp)
+
+ mol = dm.to_mol(inp, remove_hs=False)
+ if self.ignore_stereo:
+ mol = dm.remove_stereochemistry(mol)
+
+ bond_map_id = 1
+ for atom in mol.GetAtoms():
+ if atom.GetAtomicNum() == 0:
+ atom.SetAtomMapNum(0)
+ atom.SetIsotope(bond_map_id)
+ bond_map_id += 1
+
+ if self.require_hs:
+ mol = dm.add_hs(mol)
+ matching_bonds = self._fragment(mol, allow_empty=allow_empty)
+ substructed_ignored = []
+ if constraints is not None:
+ substructed_ignored = list(
+ itertools.chain(
+ *[
+ mol.GetSubstructMatches(constraint, uniquify=True)
+ for constraint in constraints
+ ]
+ )
+ )
+
+ bonds = []
+ for i_a, i_b in matching_bonds:
+ # if both atoms of the bond are found in a disallowed substructure, we cannot consider them
+ # on the other end, a bond between two substructure to preserved independently is perfectly fine
+ if any((i_a in ignore_x and i_b in ignore_x) for ignore_x in substructed_ignored):
+ continue
+ obond = mol.GetBondBetweenAtoms(i_a, i_b)
+ bonds.append(obond.GetIdx())
+
+ if len(bonds) > 0:
+ mol = Chem.FragmentOnBonds(
+ mol,
+ bonds,
+ dummyLabels=[(i + bond_map_id, i + bond_map_id) for i in range(len(bonds))],
+ )
+
+ frags = list(Chem.GetMolFrags(mol, asMols=True))
+ if randomize:
+ frags = rng.permutation(frags).tolist()
+ elif canonical:
+ frags = sorted(
+ frags,
+ key=lambda x: x.GetNumAtoms(),
+ reverse=True,
+ )
+
+ frags_str = []
+ for frag in frags:
+ non_map_atom_idxs = [
+ atom.GetIdx() for atom in frag.GetAtoms() if atom.GetAtomicNum() != 0
+ ]
+ frags_str.append(
+ Chem.MolToSmiles(
+ frag,
+ isomericSmiles=True,
+ canonical=True, # needs to always be true
+ rootedAtAtom=non_map_atom_idxs[0],
+ )
+ )
+
+ scaffold_str = ".".join(frags_str)
+
+ # don't capture atom mapping in the scaffold
+ attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str))
+ if canonical:
+ attach_pos = sorted(attach_pos)
+ starting_num = 1
+ for attach in attach_pos:
+ val = str(starting_num) if starting_num < 10 else f"%{starting_num}"
+ val = '<' + val + '>' # bracket added
+ # we cannot have anything of the form "\([@=-#-$/\]*\d+\)"
+ attach_regexp = re.compile(r"(" + re.escape(attach) + r")")
+ scaffold_str = attach_regexp.sub(val, scaffold_str)
+ starting_num += 1
+
+ # now we need to remove all the parenthesis around digit only number
+ wrong_attach = re.compile(r"\((<[\%\d+]*>)\)") # bracket added
+ scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str)
+ # furthermore, we autoapply rdkit-compatible digit standardization.
+ if rdkit_safe:
+ pattern = r"\(([=-@#\/\\]{0,2})(%?\d{1,2})\)"
+ replacement = r"\g<1>\g<2>"
+ scaffold_str = re.sub(pattern, replacement, scaffold_str)
+ return scaffold_str
+
+
+def safe2bracketsafe(safe_str):
+ try:
+ return BracketSAFEConverter().encoder(Chem.MolFromSmiles(safe_str), allow_empty=True, canonical=False, randomize=True)
+ except:
+ return safe_str
+
+
+def bracketsafe2safe(safe_str):
+ intrafrag_points = [m.group(0) for m in re.finditer(r'(?)', safe_str)] + \
+ [m.group(0).lstrip('%') for m in re.finditer(r'%\d+', safe_str)]
+ starting_num = max([int(i) for i in intrafrag_points]) + 1 if intrafrag_points else 0
+ interfrag_points = [(m.start(0), m.end(0)) for m in re.finditer(r'<\d+>', safe_str)]
+
+ safe_str = list(safe_str)
+ for start, end in interfrag_points:
+ safe_str[start] = safe_str[end-1] = ' ' # '<', '>' -> ''
+ num_to_replace = int(''.join(safe_str[start+1 : end-1])) + starting_num
+ num_to_replace = '%' + str(num_to_replace) if num_to_replace >= 10 else str(num_to_replace)
+ safe_str[start+1 : end-1] = [num_to_replace] + [' '] * (end - start - 3)
+ safe_str = re.sub(' ', '', ''.join(safe_str))
+ return safe_str
\ No newline at end of file
diff --git a/a2d2_mol/mol_utils/utils.py b/a2d2_mol/mol_utils/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..0795a73ff2ed657e5136ab99f83e109f5ea9174e
--- /dev/null
+++ b/a2d2_mol/mol_utils/utils.py
@@ -0,0 +1,135 @@
+"""Console logger utilities.
+
+Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
+Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
+"""
+
+import logging
+import fsspec
+import lightning
+import torch
+from timm.scheduler import CosineLRScheduler
+import argparse
+import numpy as np
+import random
+import os
+
+def sample_categorical_logits(logits, dtype=torch.float64):
+ # do not require logits to be log-softmaxed
+ gumbel_noise = -(1e-10 - (torch.rand_like(logits, dtype=dtype) + 1e-10).log()).log()
+ return (logits + gumbel_noise).argmax(dim=-1)
+
+def fsspec_exists(filename):
+ """Check if a file exists using fsspec."""
+ fs, _ = fsspec.core.url_to_fs(filename)
+ return fs.exists(filename)
+
+
+def fsspec_listdir(dirname):
+ """Listdir in manner compatible with fsspec."""
+ fs, _ = fsspec.core.url_to_fs(dirname)
+ return fs.ls(dirname)
+
+
+def fsspec_mkdirs(dirname, exist_ok=True):
+ """Mkdirs in manner compatible with fsspec."""
+ fs, _ = fsspec.core.url_to_fs(dirname)
+ fs.makedirs(dirname, exist_ok=exist_ok)
+
+
+def print_nans(tensor, name):
+ if torch.isnan(tensor).any():
+ print(name, tensor)
+
+
+class CosineDecayWarmupLRScheduler(
+ CosineLRScheduler,
+ torch.optim.lr_scheduler._LRScheduler):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._last_epoch = -1
+ self.step(epoch=0)
+
+ def step(self, epoch=None):
+ if epoch is None:
+ self._last_epoch += 1
+ else:
+ self._last_epoch = epoch
+ # We call either step or step_update, depending on
+ # whether we're using the scheduler every epoch or every
+ # step.
+ # Otherwise, lightning will always call step (i.e.,
+ # meant for each epoch), and if we set scheduler
+ # interval to "step", then the learning rate update will
+ # be wrong.
+ if self.t_in_epochs:
+ super().step(epoch=self._last_epoch)
+ else:
+ super().step_update(num_updates=self._last_epoch)
+
+
+class LoggingContext:
+ """Context manager for selective logging."""
+ def __init__(self, logger, level=None, handler=None, close=True):
+ self.logger = logger
+ self.level = level
+ self.handler = handler
+ self.close = close
+
+ def __enter__(self):
+ if self.level is not None:
+ self.old_level = self.logger.level
+ self.logger.setLevel(self.level)
+ if self.handler:
+ self.logger.addHandler(self.handler)
+
+ def __exit__(self, et, ev, tb):
+ if self.level is not None:
+ self.logger.setLevel(self.old_level)
+ if self.handler:
+ self.logger.removeHandler(self.handler)
+ if self.handler and self.close:
+ self.handler.close()
+
+
+def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
+ """Initializes multi-GPU-friendly python logger."""
+
+ logger = logging.getLogger(name)
+ logger.setLevel(level)
+
+ # this ensures all logging levels get marked with the rank zero decorator
+ # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+ for level in ('debug', 'info', 'warning', 'error',
+ 'exception', 'fatal', 'critical'):
+ setattr(logger,
+ level,
+ lightning.pytorch.utilities.rank_zero_only(
+ getattr(logger, level)))
+
+ return logger
+
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+def set_seed(seed, use_cuda):
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.manual_seed(seed)
+ # torch.backends.cudnn.deterministic = True
+ if use_cuda:
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ print(f'=> Seed of the run set to {seed}')
+
diff --git a/a2d2_mol/mol_utils/utils_chem.py b/a2d2_mol/mol_utils/utils_chem.py
new file mode 100644
index 0000000000000000000000000000000000000000..8682e03d49ca11ee0b28950f3b8cab471d86cc61
--- /dev/null
+++ b/a2d2_mol/mol_utils/utils_chem.py
@@ -0,0 +1,187 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import random
+import safe as sf
+import datamol as dm
+from contextlib import suppress
+from rdkit import Chem, RDLogger
+RDLogger.DisableLog('rdApp.*')
+
+# https://github.com/datamol-io/safe/blob/main/safe/sample.py
+# https://github.com/jensengroup/GB_GA/blob/master/crossover.py
+def safe_to_smiles(safe_str, fix=True):
+ if fix:
+ safe_str = '.'.join([frag for frag in safe_str.split('.')
+ if sf.decode(frag, ignore_errors=True) is not None])
+ return sf.decode(safe_str, canonical=True, ignore_errors=True)
+
+
+def _safe_to_smiles_worker(args):
+ """Worker function for parallel SAFE to SMILES conversion."""
+ safe_str, use_bracket_safe, fix = args
+ try:
+ from mol_utils.bracket_safe_converter import bracketsafe2safe
+ if use_bracket_safe:
+ safe_str = bracketsafe2safe(safe_str)
+ return safe_to_smiles(safe_str, fix=fix)
+ except Exception:
+ return None
+
+
+def batch_safe_to_smiles(safe_strings, use_bracket_safe=False, fix=True, num_workers=None):
+ """
+ Convert a batch of SAFE strings to SMILES in parallel using multiprocessing.
+
+ Args:
+ safe_strings: List of SAFE format strings
+ use_bracket_safe: Whether to convert from bracket SAFE format first
+ fix: Whether to fix invalid fragments
+ num_workers: Number of parallel workers (default: min(cpu_count, len(safe_strings), 8))
+
+ Returns:
+ List of SMILES strings (None for invalid molecules)
+ """
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+ import os
+
+ n = len(safe_strings)
+ if n == 0:
+ return []
+
+ # For small batches, use sequential processing (overhead not worth it)
+ if n <= 4:
+ if use_bracket_safe:
+ from mol_utils.bracket_safe_converter import bracketsafe2safe
+ return [safe_to_smiles(bracketsafe2safe(s), fix=fix) for s in safe_strings]
+ else:
+ return [safe_to_smiles(s, fix=fix) for s in safe_strings]
+
+ # Use ThreadPoolExecutor for I/O bound tasks (RDKit releases GIL)
+ # ProcessPoolExecutor has too much overhead for this use case
+ if num_workers is None:
+ num_workers = min(os.cpu_count() or 4, n, 8)
+
+ args_list = [(s, use_bracket_safe, fix) for s in safe_strings]
+
+ # ThreadPoolExecutor is faster here because:
+ # 1. No pickle serialization overhead
+ # 2. RDKit releases the GIL during computation
+ # 3. Lower startup cost
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
+ results = list(executor.map(_safe_to_smiles_worker, args_list))
+
+ return results
+
+
+def batch_validate_and_extract(smiles_list, samples_tensor, log_rnd_tensor):
+ """
+ Batch validate SMILES and extract valid samples efficiently.
+
+ Args:
+ smiles_list: List of SMILES strings (may contain None for invalid)
+ samples_tensor: Tensor of token IDs (B, L)
+ log_rnd_tensor: Tensor of log random values (B,)
+
+ Returns:
+ valid_sequences: List of valid SMILES (largest fragment)
+ valid_indices: List of indices of valid samples
+ """
+ valid_sequences = []
+ valid_indices = []
+
+ for idx, smiles in enumerate(smiles_list):
+ if smiles: # Valid SMILES
+ # Take largest fragment if multiple
+ largest_fragment = sorted(smiles.split('.'), key=len)[-1]
+ valid_sequences.append(largest_fragment)
+ valid_indices.append(idx)
+
+ return valid_sequences, valid_indices
+
+
+def filter_by_substructure(sequences, substruct):
+ substruct = sf.utils.standardize_attach(substruct)
+ substruct = Chem.DeleteSubstructs(Chem.MolFromSmarts(substruct), Chem.MolFromSmiles('*'))
+ substruct = Chem.MolFromSmarts(Chem.MolToSmiles(substruct))
+ return sf.utils.filter_by_substructure_constraints(sequences, substruct)
+
+
+def mix_sequences(prefix_sequences, suffix_sequences, prefix, suffix, num_samples=1):
+ mol_linker_slicer = sf.utils.MolSlicer(require_ring_system=False)
+
+ prefix_linkers = []
+ suffix_linkers = []
+ prefix_query = dm.from_smarts(prefix)
+ suffix_query = dm.from_smarts(suffix)
+
+ for x in prefix_sequences:
+ with suppress(Exception):
+ x = dm.to_mol(x)
+ out = mol_linker_slicer(x, prefix_query)
+ prefix_linkers.append(out[1])
+
+ for x in suffix_sequences:
+ with suppress(Exception):
+ x = dm.to_mol(x)
+ out = mol_linker_slicer(x, suffix_query)
+ suffix_linkers.append(out[1])
+
+ n_linked = 0
+ linked = []
+ linkers = prefix_linkers + suffix_linkers
+ linkers = [x for x in linkers if x is not None]
+ for n_linked, linker in enumerate(linkers):
+ linked.extend(mol_linker_slicer.link_fragments(linker, prefix, suffix))
+ if n_linked > num_samples:
+ break
+ linked = [x for x in linked if x]
+ return linked[:num_samples]
+
+
+def cut(smiles):
+ def cut_nonring(mol):
+ if not mol.HasSubstructMatch(Chem.MolFromSmarts('[*]-;!@[*]')):
+ return None
+
+ bis = random.choice(mol.GetSubstructMatches(Chem.MolFromSmarts('[*]-;!@[*]'))) # single bond not in ring
+ bs = [mol.GetBondBetweenAtoms(bis[0], bis[1]).GetIdx()]
+ fragments_mol = Chem.FragmentOnBonds(mol, bs, addDummies=True, dummyLabels=[(1, 1)])
+
+ try:
+ return Chem.GetMolFrags(fragments_mol, asMols=True, sanitizeFrags=True)
+ except ValueError:
+ return None
+
+ mol = Chem.MolFromSmiles(smiles)
+ frags = set()
+ # non-ring cut
+ for _ in range(3):
+ frags_nonring = cut_nonring(mol)
+ if frags_nonring is not None:
+ frags |= set([Chem.MolToSmiles(f) for f in frags_nonring])
+ return frags
+
+
+class Slicer:
+ def __call__(self, mol):
+ if isinstance(mol, str):
+ mol = Chem.MolFromSmiles(mol)
+
+ # non-ring single bonds
+ bonds = mol.GetSubstructMatches(Chem.MolFromSmarts('[*]-;!@[*]'))
+ for bond in bonds:
+ yield bond
\ No newline at end of file
diff --git a/a2d2_mol/oracle/fpscores.pkl b/a2d2_mol/oracle/fpscores.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..24e7f60a5b12606184f8642bb5ac9a84c291484b
--- /dev/null
+++ b/a2d2_mol/oracle/fpscores.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24a4392f5c673e79c0446af3c4d8e458293b5fecaa244328e76741ead9d21dbf
+size 9048931
diff --git a/a2d2_mol/remasking_scheduleaware.py b/a2d2_mol/remasking_scheduleaware.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8637a514828dd6336caa996c133b2250748064b
--- /dev/null
+++ b/a2d2_mol/remasking_scheduleaware.py
@@ -0,0 +1,177 @@
+"""
+Schedule-aware remasking and insertion logic that ensures the number of masked tokens
+follows the interpolant schedule.
+"""
+import torch
+import numpy as np
+
+def apply_schedule_aware_insertion(
+ model,
+ xt_tmp,
+ new_xt,
+ t,
+ dt,
+ ext,
+ mask,
+ pad,
+ max_length,
+ orig_mask,
+ new_pos_orig,
+ quality_threshold=1,
+):
+ """
+ Remove low-quality insertions based on insertion confidence while respecting
+ the interpolant schedule for expected sequence length.
+
+ Args:
+ model: Model with planner and interpolant
+ xt_tmp: Sequence after insertion [B, L]
+ new_xt: Sequence before insertion [B, L]
+ t: Current time [B]
+ dt: Time step size
+ ext: Number of insertions per gap [B, L+1]
+ mask: Mask token ID
+ pad: Pad token ID
+ max_length: Maximum sequence length
+ orig_mask: Mask of original token positions [B, L]
+ new_pos_orig: New positions of original tokens [B, L]
+ quality_threshold: If a float, drop insertions with confidence below it
+
+ Returns:
+ xt_tmp: Modified sequence with low-quality insertions removed (respecting schedule)
+ """
+ device = xt_tmp.device
+ batch_size, L = xt_tmp.shape
+ total_ext = ext.sum(dim=1)
+
+ # Only proceed if there were insertions
+ if total_ext.sum() == 0:
+ return xt_tmp
+
+ # Get planner predictions on inserted state. The insertion head is trained
+ # with the pre-step time t (see loss_insert_planner_flexible), so condition
+ # on t here too; t_next is still used below for the length schedule.
+ t_next = t + dt
+ planner_out = model.planner(xt_tmp, t)
+ insertion_conf = planner_out.get("insertion_conf", None)
+
+ if insertion_conf is None:
+ return xt_tmp
+
+ insertion_conf = insertion_conf.squeeze(-1) # (B, L)
+
+ # Expected sequence length at next timestep according to schedule
+ current_length_after = xt_tmp.ne(pad).sum(dim=1).float() # [B]
+ expected_progress = model.interpolant.insertion_schedule.at(t_next) # [B]
+ estimated_final_length = current_length_after / (expected_progress.clamp(min=0.1))
+ expected_length = estimated_final_length * expected_progress # [B]
+
+ # Mark positions in xt_tmp that came from new_xt (originals) vs. fresh insertions.
+ # Fancy-indexing scatter avoids the per-batch python loop.
+ valid_b, valid_l = orig_mask.nonzero(as_tuple=True)
+ valid_p = new_pos_orig[valid_b, valid_l].long().clamp_(0, L - 1)
+ is_original = torch.zeros_like(xt_tmp, dtype=torch.bool)
+ is_original[valid_b, valid_p] = True
+ inserted_positions = (xt_tmp == mask) & ~is_original
+
+ # Two deletion modes, selected by `quality_threshold`:
+ # * float: drop insertions whose confidence is below the threshold, capped
+ # so the length never falls below the scheduled minimum.
+ candidates = inserted_positions & (insertion_conf < quality_threshold)
+ num_bad = candidates.sum(dim=1) # [B], long
+ min_length = expected_length.long().clamp(min=1) # [B]
+ max_removable = (current_length_after.long() - min_length).clamp(min=0)
+ length_after_removal = current_length_after.long() - num_bad
+ schedule_violates = length_after_removal < min_length
+ k_per_row = torch.where(schedule_violates, max_removable, num_bad)
+ k_per_row = torch.where(num_bad > 0, k_per_row, torch.zeros_like(k_per_row))
+
+ if not candidates.any():
+ return xt_tmp
+
+ # Select the lowest-confidence candidates per row via a sort.
+ neg_inf = torch.tensor(float('-inf'), device=device, dtype=insertion_conf.dtype)
+ scores = torch.where(candidates, -insertion_conf, neg_inf) # higher = worse
+ _, sorted_indices = scores.sort(dim=1, descending=True)
+ positions = torch.arange(L, device=device).unsqueeze(0) # [1, L]
+ keep_in_topk = positions < k_per_row.unsqueeze(1) # [B, L]
+ final_bad = torch.zeros_like(candidates)
+ final_bad.scatter_(1, sorted_indices, keep_in_topk)
+
+ if not final_bad.any():
+ return xt_tmp
+
+ # Compact each row to the left (keep good, drop bad), then pad the tail.
+ # Stable sort by the bad flag pushes bad positions to the right.
+ sort_key = final_bad.long()
+ _, perm = torch.sort(sort_key, dim=1, stable=True)
+ xt_tmp = torch.gather(xt_tmp, 1, perm)
+ num_keep = (~final_bad).sum(dim=1) # [B]
+ tail_mask = positions >= num_keep.unsqueeze(1) # [B, L]
+ xt_tmp = torch.where(tail_mask, torch.full_like(xt_tmp, pad), xt_tmp)
+
+ return xt_tmp
+
+
+def apply_schedule_aware_remasking(
+ model,
+ new_xt,
+ t,
+ dt,
+ remasking_conf,
+ clean_index,
+ mask,
+ neg_inf,
+ batch_size,
+ unmask_quality_threshold=None,
+):
+ """
+ Apply schedule-aware remasking: adjust number of masks to match expected count from schedule.
+
+ Args:
+ model: Model with interpolant that has an unmask_schedule
+ new_xt: Current sequence [B, L]
+ t: Current time [B]
+ dt: Time step size
+ remasking_conf: Confidence scores for tokens [B, L]
+ clean_index: Boolean mask of clean tokens (not mask, not pad) [B, L]
+ mask: Mask token ID
+ neg_inf: Negative infinity tensor
+ batch_size: Batch size
+
+ Returns:
+ new_xt: Modified sequence with schedule-aware remasking applied
+ """
+ # Optional AJD threshold gate (overrides the schedule-driven count when set):
+ # remask every clean token whose unmasking-quality confidence is below the
+ # threshold. Higher threshold => more aggressive remasking.
+ if unmask_quality_threshold is not None:
+ to_mask = clean_index & (remasking_conf < unmask_quality_threshold)
+ return torch.where(to_mask, torch.full_like(new_xt, mask), new_xt)
+
+ t_next = t + dt
+ num_clean = clean_index.sum(dim=1) # [B], long
+ current_seq_len = (num_clean + (new_xt == mask).sum(dim=1)).float() # [B]
+ expected_unmasked_frac = model.interpolant.unmask_schedule.at(t_next) # [B]
+ expected_num_clean = expected_unmasked_frac * current_seq_len # [B]
+ masks_to_add = (num_clean.float() - expected_num_clean).round().long() # [B]
+
+ # Per-row k = min(masks_to_add, num_clean), clamped to >= 0.
+ k_per_row = torch.minimum(masks_to_add.clamp(min=0), num_clean) # [B]
+
+ if k_per_row.sum() == 0:
+ return new_xt
+
+ # Use confidence to decide which clean tokens to remask: lowest conf first.
+ remasking_score_temp = -1.0 * remasking_conf # low conf = high score
+ remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf)
+
+ _, sorted_indices = remasking_score_temp.sort(dim=1, descending=True)
+ L = remasking_score_temp.shape[1]
+ positions = torch.arange(L, device=new_xt.device).unsqueeze(0) # [1, L]
+ keep_in_topk = positions < k_per_row.unsqueeze(1) # [B, L]
+ to_mask = torch.zeros_like(clean_index)
+ to_mask.scatter_(1, sorted_indices, keep_in_topk)
+ new_xt = torch.where(to_mask, torch.full_like(new_xt, mask), new_xt)
+
+ return new_xt
diff --git a/a2d2_mol/sampling.py b/a2d2_mol/sampling.py
new file mode 100755
index 0000000000000000000000000000000000000000..2bdd2eee730434849c946259f801ad527fd15723
--- /dev/null
+++ b/a2d2_mol/sampling.py
@@ -0,0 +1,1401 @@
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # add repo root to path
+
+import torch
+from dataclasses import dataclass
+from typing import Any, Literal, Optional
+import numpy as np
+import pandas as pd
+
+from lightning_modules.mdm import MaskedDiffusionModule
+
+
+@dataclass
+class SamplingTraceDatapoint:
+ t: float
+ event_type: Literal["insertion", "change"]
+ position: int
+ token: Any
+
+
+@dataclass
+class SamplingResult:
+ samples: torch.Tensor
+ # Trace is supposed to be processed sequentially as updates are not commutative
+ trace: Optional[list[SamplingTraceDatapoint]]
+
+ def __iter__(self):
+ yield from [self.samples, self.trace]
+
+
+# Sample from categorical distribution for each position using the transition probabilities
+def _sample_tokens(probs: torch.Tensor) -> torch.Tensor:
+ """Sample one token per position from probability distribution.
+ Args:
+ probs: [batch_size, seq_len, vocab_size] transition probabilities
+ Returns:
+ [batch_size, seq_len] sampled token indices
+ """
+ batch_size, seq_len, vocab_size = probs.shape
+ flat_probs = probs.view(-1, vocab_size)
+ samples = torch.multinomial(flat_probs, num_samples=1)
+ return samples.view(batch_size, seq_len)
+
+
+def _sample_batched_tokens(probs: torch.Tensor) -> torch.Tensor:
+
+ batch_size, seq_len, vocab_size = probs.shape
+
+ gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, seq_len, vocab_size) + 1e-10) + 1e-10)).to(probs.device)
+ noisy_logits = torch.log(probs + 1e-10) + gumbel_noise # add Gumbel noise to log probabilities
+
+ # select the highest score (most likely category after Gumbel noise)
+ samples = noisy_logits.argmax(dim=-1).to(dtype=torch.long)
+
+ return samples.view(batch_size, seq_len)
+
+@torch.no_grad()
+def mdm_euler_sampling(
+ model: MaskedDiffusionModule,
+ steps: int,
+ mask: int,
+ pad: int,
+ batch_size: int,
+ max_length: int,
+ return_trace: bool = False,
+ temperature: float = 1.0,
+):
+ assert not return_trace, "Trace is not yet implemented in MDM Euler sampling"
+ device = next(model.parameters()).device
+ xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)
+
+ dt = 1.0 / steps
+ t = torch.zeros(batch_size, device=device)
+
+ for i in range(steps):
+ print("i-th sampling step")
+ # ——— predict and convert rates ———
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ _xt = xt.clone()
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+
+ # Apply temperature scaling
+ if temperature != 1.0:
+ logits = torch.log(trans_prob + 1e-10) / temperature
+ trans_prob = torch.softmax(logits, dim=-1)
+
+ if i == steps - 1:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0
+ print(trans_prob[mask_pos + (mask,)])
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt = torch.where(xt != mask, xt, new_xt)
+
+ xt = new_xt
+ t = t + dt
+
+ return xt, []
+
+
+@torch.no_grad()
+def any_order_mask_insertion_euler_sampling(
+ model: torch.nn.Module,
+ steps: int,
+ mask: int,
+ pad: int,
+ batch_size: int,
+ max_length: int,
+ return_trace: bool = False,
+ temperature: float = 1.0,
+) -> SamplingResult:
+ device = next(model.parameters()).device
+
+ # 1) Initialize all‑pad sequence and trace
+ xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
+ sampling_trace = []
+
+ dt = 1.0 / steps
+ t = torch.zeros(batch_size, device=device)
+
+ # Precompute row indices for scatter
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+ sampling_trace = [[] for _ in range(batch_size)] if return_trace else None
+
+ for i in range(steps):
+ # ——— predict and convert rates ———
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ # add “stay” probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+
+ if i == steps - 1:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
+
+ # renormalize probabilities to ensure they sum to 1
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+ if mask_has_zero_prob.any():
+ # create uniform distribution over valid tokens (excluding mask and pad)
+ uniform_prob = torch.zeros_like(trans_prob[0])
+ uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ # normalize to sum to 1
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ if i != steps - 1:
+ # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ———
+ ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1)
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ gaps = torch.arange(max_length + 1, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+ else:
+ xt_tmp = new_xt
+
+ if return_trace:
+ # Check if the token was changed
+ for batch_idx in range(batch_size):
+ for j in range(max_length):
+ if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="change",
+ position=j,
+ token=new_xt[batch_idx, j].item(),
+ )
+ )
+
+ # Check if a new token was inserted
+ for j in range(max_length):
+ id = max_length - j - 1
+ if ext[batch_idx, id]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="insertion",
+ position=id,
+ token=mask,
+ )
+ )
+
+ xt = xt_tmp
+ t = t + dt
+
+ return xt, sampling_trace
+
+@torch.no_grad()
+def batch_mcts_reverse_step(
+ xt: torch.Tensor,
+ t: torch.Tensor,
+ dt: float,
+ model: torch.nn.Module,
+ pretrained: torch.nn.Module,
+ mask: int,
+ pad: int,
+ batch_size: int,
+ max_length: int,
+ last_step: bool = False,
+ temperature: float = 1.0,
+) -> SamplingResult:
+ device = next(model.parameters()).device
+
+ xt = xt.repeat(batch_size, 1)
+
+ # squeeze to remove extra dimensions, then expand to batch_size
+ t = t.squeeze().expand(batch_size)
+ # precompute row indices for scatter
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+
+ # ——— predict and convert rates ———
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # ——— get pretrained model rates for log_rnd computation ———
+ pretrained_pred = pretrained(xt, t)
+ pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
+ pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V)
+ pretrained_len_rate = pretrained_rate.length_rate # (B, L+1)
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ # Same for pretrained
+ pretrained_unmask_rate[xt != mask] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)
+
+ # add “stay” probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+ pretrained_trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
+ )
+
+ if last_step:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
+
+ # renormalize probabilities to ensure they sum to 1
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+ if mask_has_zero_prob.any():
+ # create uniform distribution over valid tokens (excluding mask and pad)
+ uniform_prob = torch.zeros_like(trans_prob[0])
+ uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ # normalize to sum to 1
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ # ——— compute log probabilities for RND ———
+ lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+ lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+
+ changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)
+
+ log_policy_step = (lp * changed_mask).sum(dim=1)
+ log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
+
+ log_rnd = log_pretrained_step - log_policy_step # (B,)
+
+ if not last_step:
+ # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ———
+ ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1)
+
+ insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1)
+ pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1)
+
+ # log P(ext; λ) = ext*log(λ) - λ
+ log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1) # (B,)
+ log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1) # (B,)
+
+ log_insert_diff = log_pretrained_insert - log_policy_insert # (B,)
+ log_rnd += log_insert_diff
+ log_pretrained_step += log_pretrained_insert
+ log_policy_step += log_policy_insert
+
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ seq_dim = ext.size(1) # Use actual ext dimension to avoid mismatch
+ gaps = torch.arange(seq_dim, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+ else:
+ xt_tmp = new_xt
+
+ return xt_tmp, log_rnd, log_policy_step, log_pretrained_step
+
+
+@torch.no_grad()
+def mcts_reverse_step(
+ xt: torch.Tensor,
+ t: torch.Tensor,
+ dt: float,
+ model: torch.nn.Module,
+ pretrained: torch.nn.Module,
+ mask: int,
+ pad: int,
+ max_length: int,
+ last_step: bool = False,
+ temperature: float = 1.0,
+) -> SamplingResult:
+ device = next(model.parameters()).device
+
+ batch_size = xt.size(0)
+
+ # precompute row indices for scatter
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+
+ # ——— predict and convert rates ———
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # ——— get pretrained model rates for log_rnd computation ———
+ pretrained_pred = pretrained(xt, t)
+ pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
+ pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V)
+ pretrained_len_rate = pretrained_rate.length_rate # (B, L+1)
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ # same for pretrained
+ pretrained_unmask_rate[xt != mask] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)
+
+ # add “stay” probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+ pretrained_trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
+ )
+
+ if last_step:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
+
+ # renormalize probabilities to ensure they sum to 1
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ # avoid division by zero - if all probs are 0, use uniform distribution (excluding mask and pad)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+ if mask_has_zero_prob.any():
+ # create uniform distribution over valid tokens (excluding mask and pad)
+ uniform_prob = torch.zeros_like(trans_prob[0])
+ uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ # normalize to sum to 1
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ # ——— compute log probabilities for RND ———
+ lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+ lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+
+ changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)
+
+ log_policy_step = (lp * changed_mask).sum(dim=1)
+ log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
+
+ log_rnd = log_pretrained_step - log_policy_step # (B,)
+
+ if not last_step:
+ # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ———
+ ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1)
+
+ insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1)
+ pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1)
+
+ # log P(ext; λ) = ext*log(λ) - λ
+ log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1) # (B,)
+ log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1) # (B,)
+
+ log_insert_diff = log_pretrained_insert - log_policy_insert # (B,)
+ log_rnd += log_insert_diff
+ log_pretrained_step += log_pretrained_insert
+ log_policy_step += log_policy_insert
+
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ seq_dim = ext.size(1) # Use actual ext dimension to avoid mismatch
+ gaps = torch.arange(seq_dim, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+ else:
+ xt_tmp = new_xt
+
+ return xt_tmp, log_rnd, log_policy_step, log_pretrained_step
+
+@torch.no_grad()
+def any_order_euler_sampling_with_schedule(
+ model: torch.nn.Module,
+ time_schedule: torch.Tensor,
+ mask: int,
+ pad: int,
+ batch_size: int,
+ max_length: int,
+ return_trace: bool = False,
+ temperature: float = 1.0,
+) -> SamplingResult:
+ device = next(model.parameters()).device
+
+ time_schedule = time_schedule.to(device)
+ if time_schedule[0] < time_schedule[-1]:
+ time_schedule = torch.flip(time_schedule, [0]) # descending order
+
+ steps = len(time_schedule) - 1
+
+ # initialize all-pad sequence and trace
+ xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
+
+ # precompute row indices for scatter
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+ sampling_trace = [[] for _ in range(batch_size)] if return_trace else None
+
+ for i in range(steps):
+ # use scheduled timesteps
+ t = time_schedule[i].repeat(batch_size)
+ t_next = time_schedule[i + 1]
+ dt = (t - t_next).abs() # timestep difference
+
+ # ——— predict and convert rates ———
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt[:, None, None]).clamp(0.0, 1.0)
+
+ # add "stay" probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+
+ # Apply temperature scaling
+ if temperature != 1.0:
+ logits = torch.log(trans_prob + 1e-10) / temperature
+ trans_prob = torch.softmax(logits, dim=-1)
+
+ if i == steps - 1:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
+
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+
+ if mask_has_zero_prob.any():
+ uniform_prob = torch.zeros_like(trans_prob[0])
+ uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ if i != steps - 1:
+ # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ———
+ ext = torch.bernoulli((len_rate * dt[:, None]).clamp(0.0, 1.0)).long() # (B, L+1)
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ gaps = torch.arange(max_length + 1, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+ else:
+ xt_tmp = new_xt
+
+ if return_trace:
+ # Check if the token was changed
+ for batch_idx in range(batch_size):
+ for j in range(max_length):
+ if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="change",
+ position=j,
+ token=new_xt[batch_idx, j].item(),
+ )
+ )
+
+ # Check if a new token was inserted
+ for j in range(max_length):
+ id = max_length - j - 1
+ if ext[batch_idx, id]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="insertion",
+ position=id,
+ token=mask,
+ )
+ )
+
+ xt = xt_tmp
+
+ return xt, sampling_trace
+
+
+@torch.no_grad()
+def any_order_mask_insertion_euler_sampling_with_rnd(
+ model, pretrained, reward_model, analyzer,
+ tokenizer, steps,
+ mask,
+ pad,
+ batch_size,
+ max_length,
+ return_trace = False,
+ alpha = 0.1,
+ temperature: float = 1.0,
+):
+ device = next(model.parameters()).device
+
+ # initialize all‑pad sequence and trace
+ xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
+ sampling_trace = []
+
+ # initialize log_rnd to accumulate log probability ratios
+ log_rnd = torch.zeros(batch_size, device=device)
+
+ dt = 1.0 / steps
+ t = torch.zeros(batch_size, device=device)
+
+ # precompute row indices for scatter
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+ sampling_trace = [[] for _ in range(batch_size)] if return_trace else None
+
+ for i in range(steps):
+ # ——— predict and convert rates ———
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # ——— get pretrained model rates for log_rnd computation ———
+ pretrained_pred = pretrained(xt, t)
+ pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
+ pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V)
+ pretrained_len_rate = pretrained_rate.length_rate # (B, L+1)
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ # Same for pretrained
+ pretrained_unmask_rate[xt != mask] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)
+
+ # add “stay” probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+ pretrained_trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
+ )
+
+ # Apply temperature scaling
+ if temperature != 1.0:
+ logits = torch.log(trans_prob + 1e-10) / temperature
+ trans_prob = torch.softmax(logits, dim=-1)
+
+ if i == steps - 1:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
+
+ # renormalize probabilities to ensure they sum to 1
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+ if mask_has_zero_prob.any():
+ # create uniform distribution over valid tokens (excluding mask and pad)
+ uniform_prob = torch.zeros_like(trans_prob[0])
+ uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ # ——— compute log probabilities for RND ———
+ lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+ lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+
+ changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)
+
+ log_policy_step = (lp * changed_mask).sum(dim=1)
+ log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
+
+ log_rnd = log_pretrained_step - log_policy_step # (B,)
+
+ if i != steps - 1:
+ ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1)
+
+ insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1)
+ pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1)
+
+ log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1) # (B,)
+ log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1) # (B,)
+
+ log_insert_diff = log_pretrained_insert - log_policy_insert # (B,)
+ log_rnd += log_insert_diff
+
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ gaps = torch.arange(max_length + 1, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+ else:
+ xt_tmp = new_xt
+
+ if return_trace:
+ # check if the token was changed
+ for i in range(batch_size):
+ for j in range(max_length):
+ if xt[i, j] != pad and xt[i, j] != new_xt[i, j]:
+ sampling_trace[i].append(
+ SamplingTraceDatapoint(
+ t=t[i].item(),
+ event_type="change",
+ position=j,
+ token=new_xt[i, j].item(),
+ )
+ )
+
+ # check if a new token was inserted
+ for j in range(max_length):
+ id = max_length - j - 1
+ if ext[i, id]:
+ sampling_trace[i].append(
+ SamplingTraceDatapoint(
+ t=t[i].item(),
+ event_type="insertion",
+ position=id,
+ token=mask,
+ )
+ )
+
+ xt = xt_tmp
+ t = t + dt
+
+ # change rewards for peptides
+ samples = xt.to(device)
+
+ # store raw token IDs
+ # Decode and strip samples
+ decoded_samples = tokenizer.batch_decode(samples)
+
+ valid_x_final = []
+ validSequences = []
+ valid_log_rnd = []
+
+ for idx, seq in enumerate(decoded_samples):
+ # check if the peptide is valid
+ if analyzer.is_peptide(seq):
+ valid_x_final.append(xt[idx])
+ validSequences.append(seq)
+ valid_log_rnd.append(log_rnd[idx])
+
+ print("len valid sequences:", len(validSequences))
+ # compute multi-objective rewards
+ score_vectors = reward_model(input_seqs=validSequences)
+ scalar_rewards = np.sum(score_vectors, axis=-1)
+ scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device)
+
+ print(f"scalar reward dim{len(scalar_rewards)}")
+ valid_log_rnd = torch.stack(valid_log_rnd, dim=0)
+
+ log_rnd = valid_log_rnd + (scalar_rewards / alpha) # scale down by alpha
+ valid_x_final = torch.stack(valid_x_final, dim=0)
+
+ return valid_x_final, log_rnd, scalar_rewards, sampling_trace
+
+@torch.no_grad()
+def any_order_finetuned_euler_sampler(
+ model, reward_model, analyzer,
+ tokenizer, steps,
+ mask,
+ pad,
+ batch_size,
+ max_length,
+ return_trace = False,
+ dataframe = False,
+ temperature: float = 1.0,
+ ):
+ device = next(model.parameters()).device
+
+ # initialize all‑pad sequence and trace
+ xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
+ sampling_trace = []
+
+ dt = 1.0 / steps
+ t = torch.zeros(batch_size, device=device)
+
+ # precompute row indices for scatter
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+ sampling_trace = [[] for _ in range(batch_size)] if return_trace else None
+
+ for i in range(steps):
+ # ——— predict and convert rates ———
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ # add “stay” probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+
+ # Apply temperature scaling
+ if temperature != 1.0:
+ logits = torch.log(trans_prob + 1e-10) / temperature
+ trans_prob = torch.softmax(logits, dim=-1)
+
+ if i == steps - 1:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
+
+ # renormalize probabilities to ensure they sum to 1
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+ if mask_has_zero_prob.any():
+ # create uniform distribution over valid tokens (excluding mask and pad)
+ uniform_prob = torch.zeros_like(trans_prob[0])
+ uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ # normalize to sum to 1
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ if i != steps - 1:
+ # gap-wise insertion refactored — compute new length, fill masks, scatter tokens
+ ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1)
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ gaps = torch.arange(max_length + 1, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+ else:
+ xt_tmp = new_xt
+
+ if return_trace:
+ # check if the token was changed
+ for batch_idx in range(batch_size):
+ for j in range(max_length):
+ if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="change",
+ position=j,
+ token=new_xt[batch_idx, j].item(),
+ )
+ )
+
+ # check if a new token was inserted
+ for j in range(max_length):
+ id = max_length - j - 1
+ if ext[batch_idx, id]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="insertion",
+ position=id,
+ token=mask,
+ )
+ )
+
+ xt = xt_tmp
+ t = t + dt
+
+ # start eval
+ samples = xt.to(device)
+
+ decoded_samples = tokenizer.batch_decode(samples)
+
+ valid_x_final = []
+ validSequences = []
+
+ for idx, seq in enumerate(decoded_samples):
+ if analyzer.is_peptide(seq):
+ valid_x_final.append(samples[idx])
+ validSequences.append(seq)
+
+ print("len valid sequences:", len(validSequences))
+ valid_fraction = len(validSequences) / batch_size
+
+ if (len(validSequences) != 0):
+ # add scores to log
+ score_vectors = reward_model(input_seqs=validSequences) # (num_children, num_objectives)
+ average_scores = score_vectors.T
+
+ affinity = average_scores[0]
+ sol = average_scores[1]
+ hemo = average_scores[2]
+ nf = average_scores[3]
+ permeability = average_scores[4]
+
+ else:
+ zeros = [0.0]
+
+ affinity = zeros
+ sol = zeros
+ hemo = zeros
+ nf = zeros
+ permeability = zeros
+
+ if dataframe:
+ df = pd.DataFrame({
+ "Peptide Sequence": validSequences,
+ "Binding Affinity": affinity if len(validSequences) else [0.0],
+ "Solubility": sol if len(validSequences) else [0.0],
+ "Hemolysis": hemo if len(validSequences) else [0.0],
+ "Nonfouling": nf if len(validSequences) else [0.0],
+ "Permeability": permeability if len(validSequences) else [0.0],
+ })
+ return samples, affinity, sol, hemo, nf, permeability, valid_fraction, df
+
+ return samples, affinity, sol, hemo, nf, permeability, valid_fraction
+
+@torch.no_grad()
+def mdm_tau_leaping_sampling(
+ model: MaskedDiffusionModule,
+ steps: int,
+ mask: int,
+ pad: int,
+ batch_size: int,
+ max_length: int,
+ return_trace: bool = False,
+ temperature: float = 1.0,
+):
+ assert not return_trace, "Trace is not yet supported"
+ device = next(model.parameters()).device
+ xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)
+ dt = 1.0 / steps
+ t = torch.zeros(batch_size, device=device)
+
+ for i in range(steps):
+ # ——— predict and convert rates ———
+ pred = model(xt, t)
+ pred = model.interpolant.to_actual_rate(xt, pred, t)
+ unmask_rate = pred.unmask_rate # (B, L, V)
+
+ if i == steps - 1:
+ # last step: deterministic unmask via argmax
+ mask_pos = xt == mask # (B, L)
+ new_token = unmask_rate.argmax(dim=2) # (B, L)
+ new_xt = xt.clone()
+ new_xt[mask_pos] = new_token[mask_pos]
+ new_xt = torch.where(xt != mask, xt, new_xt)
+ xt = new_xt
+ t = t + dt
+ continue
+ # tau-leaping via Poisson counts
+ counts = torch.poisson(unmask_rate * dt).long()
+ mask_pos = xt == mask # (B, L)
+ # zero out non-mask positions and mask→mask
+ counts[~mask_pos.unsqueeze(-1).expand_as(counts)] = 0
+ counts[..., mask] = 0
+ # only accept exactly one event
+ sum_c = counts.sum(dim=2) # (B, L)
+ one_event = sum_c == 1
+ new_token = counts.argmax(dim=2) # (B, L)
+
+ # build new xt
+ new_xt = xt.clone()
+ new_xt[one_event] = new_token[one_event]
+ # keep pads and already-unmasked tokens
+ new_xt = torch.where(xt != mask, xt, new_xt)
+ xt = new_xt
+ t = t + dt
+
+ return xt, []
+
+# Not used in production, for debugging purposes
+lengths = {4: 0.1, 16: 0.4, 32: 0.4, 64: 0.1}
+
+def binomial_mass(k, n, p):
+ """
+ Calculate the probability mass function (PMF) for a binomial distribution.
+
+ Args:
+ k (int): Number of successes
+ n (int): Number of trials
+ p (float): Probability of success in a single trial
+
+ Returns:
+ float: Probability mass P(X = k)
+ """
+ import math
+
+ # Calculate binomial coefficient (n choose k)
+ try:
+ binom_coef = math.factorial(n) / (math.factorial(k) * math.factorial(n - k))
+ except ValueError:
+ # Handle cases where k > n or negative values
+ return 0.0
+
+ # Calculate probability mass
+ return binom_coef * (p ** k) * ((1 - p) ** (n - k))
+
+def calculate_rate_batch(alpha_t, len_t):
+ """
+ Calculate rate for a batch of alpha_t and len_t values.
+
+ Args:
+ alpha_t (torch.Tensor): Tensor of shape (batch_size,)
+ len_t (torch.Tensor): Tensor of shape (batch_size,)
+
+ Returns:
+ torch.Tensor: Tensor of shape (batch_size,) containing calculated rates
+ """
+ batch_size = alpha_t.shape[0]
+ device = alpha_t.device
+
+ # Initialize tensors for numerator and denominator
+ nom = torch.zeros(batch_size, device=device)
+ denom = torch.zeros(batch_size, device=device)
+
+ for length, probability in lengths.items():
+ # Create mask for valid entries where len_t <= length
+ valid_mask = (len_t <= length) & (len_t >= 0)
+
+ if not valid_mask.any():
+ continue
+
+ valid_indices = valid_mask.nonzero(as_tuple=True)[0]
+ valid_len_t = len_t[valid_indices]
+ valid_alpha_t = alpha_t[valid_indices]
+
+ # Calculate binomial probabilities efficiently using torch distribution
+ binom_dist = torch.distributions.Binomial(total_count=length, probs=valid_alpha_t)
+ binom_probs = binom_dist.log_prob(valid_len_t).exp()
+
+ # Update numerator and denominator for valid indices
+ nom[valid_indices] += (length - valid_len_t) * probability * binom_probs
+ denom[valid_indices] += probability * binom_probs
+
+ # Handle division by zero in a vectorized way
+ result = torch.zeros_like(nom)
+ div_mask = denom > 0
+ result[div_mask] = nom[div_mask] / (denom[div_mask])
+
+ return result
+
+# Keep the original function for backward compatibility
+def calculate_rate(alpha_t, len_t):
+ """Legacy scalar version of calculate_rate"""
+ if isinstance(alpha_t, torch.Tensor) and alpha_t.ndim > 0:
+ return calculate_rate_batch(alpha_t, len_t)
+
+ nom, denom = 0, 0
+ for length, probability in lengths.items():
+ if length >= len_t:
+ nom += (length - len_t) * probability * binomial_mass(len_t, length, alpha_t)
+ denom += probability * binomial_mass(len_t, length, alpha_t)
+
+ if denom == 0:
+ return 0.0
+
+ return nom /denom
+
+
+@torch.no_grad()
+def any_order_mask_insertion_tau_leaping_sampling(
+ model: torch.nn.Module,
+ steps: int,
+ mask: int,
+ pad: int,
+ batch_size: int,
+ max_length: int,
+ return_trace: bool = False,
+ confidence_based_sampling: bool = True, # whether to use confidence-based decoding
+ alpha: float = 5.0, # hyperparameter for window size calculation
+ max_window: int = 32, # Maximum window size for sliding window
+ confidence_method: str = "prob_diff", # "position", "top_prob", "prob_diff", "entropy"
+ use_sliding_window: bool = False, # whether to use sliding window for position selection
+ temperature: float = 1.0,
+) -> SamplingResult:
+
+ device = next(model.parameters()).device
+ xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
+ sampling_trace = []
+ dt = 1.0 / steps
+ t = torch.zeros(batch_size, device=device)
+
+ # Precompute row indices for scatter
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+
+ for i in range(steps):
+ # --- predict rates ---
+ pred = model(xt, t)
+ xt_len = (xt != pad).sum(dim=1)
+ pred = model.interpolant.to_actual_rate(xt, pred, t)
+ unmask_rate = pred.unmask_rate # (B, L, V)
+ len_rate = pred.length_rate # (B, L+1)
+
+ if i == steps - 1:
+ # last step: deterministic unmask via argmax
+ mask_pos = xt == mask
+ new_token = unmask_rate.argmax(dim=2)
+ new_xt = xt.clone()
+ new_xt[mask_pos] = new_token[mask_pos]
+ new_xt = torch.where(xt == pad, pad, new_xt)
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+ xt = new_xt
+ t = t + dt
+ continue
+
+ # --- confidence-based decoding ---
+ if confidence_based_sampling > 0.0:
+ # Confidence-based unmasking (vectorized)
+ mask_positions = (xt == mask) # (B, L)
+ num_mask_positions = mask_positions.sum(dim=1) # (B,)
+
+ # 1. Determine number of tokens to unmask using Poisson
+ unmask_counts = torch.poisson(num_mask_positions.float() * dt).long() # (B,)
+
+ # 2. Calculate confidence based on selected method
+ if confidence_method == "position":
+ # Position-based confidence: position i / len(xt)
+ xt_len = (xt != pad).sum(dim=1) # (B,) - current sequence lengths
+ position_indices = torch.arange(max_length, device=device).unsqueeze(0).expand(batch_size, -1) # (B, L)
+ confidence = 1.0 - (position_indices.float() / xt_len.unsqueeze(1).float().clamp(min=1)) # (B, L)
+
+ elif confidence_method == "top_prob":
+ # Top probability confidence
+ import torch.nn.functional as F
+ token_logits = unmask_rate # (B, L, V) - use the unmask_rate as logits
+ unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V)
+ confidence = unmask_probs.max(dim=-1)[0] # (B, L)
+
+ elif confidence_method == "prob_diff":
+ # Probability difference confidence (top - second top)
+ import torch.nn.functional as F
+ token_logits = unmask_rate # (B, L, V)
+ unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V)
+ top2_probs, _ = torch.topk(unmask_probs, k=2, dim=-1) # (B, L, 2)
+ confidence = top2_probs[:, :, 0] - top2_probs[:, :, 1] # (B, L)
+
+ elif confidence_method == "entropy":
+ # Entropy-based confidence (lower entropy = higher confidence)
+ import torch.nn.functional as F
+ token_logits = unmask_rate # (B, L, V)
+ unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V)
+ entropy = -torch.sum(unmask_probs * torch.log(unmask_probs + 1e-10), dim=-1) # (B, L)
+ confidence = -entropy # (B, L) - negative entropy so lower entropy gives higher confidence
+
+ else:
+ raise ValueError(f"Unknown confidence_method: {confidence_method}")
+
+ # 3. Apply window constraint if enabled
+ if use_sliding_window:
+ # Calculate dynamic k for each batch
+ k_values = torch.minimum(
+ torch.minimum(
+ (alpha * unmask_counts).long(),
+ torch.tensor(max_window, device=device)
+ ), num_mask_positions) # (B,)
+
+ # Get cumulative count of mask positions
+ mask_cumsum = mask_positions.cumsum(dim=1) # (B, L)
+
+ # Create window mask: position is eligible if it's a mask and within first k masks
+ is_within_window = mask_cumsum <= k_values.unsqueeze(1) # (B, L)
+ window_mask = mask_positions & is_within_window # (B, L)
+
+ # Set confidence to -inf for positions outside the window or non-mask positions
+ confidence = torch.where(window_mask, confidence, torch.tensor(-float('inf'), device=device))
+ else:
+ # No window constraint - only mask positions are eligible
+ confidence = torch.where(mask_positions, confidence, torch.tensor(-float('inf'), device=device))
+
+ new_xt = xt.clone()
+
+ # vectorized unmasking
+ max_unmask = unmask_counts.max().item()
+ if max_unmask > 0:
+ _, all_top_indices = torch.topk(confidence, k=max_unmask, dim=1, largest=True) # (B, max_unmask)
+
+ # create mask for valid unmask operations
+ unmask_mask = torch.arange(max_unmask, device=device).unsqueeze(0) < unmask_counts.unsqueeze(1) # (B, max_unmask)
+
+ most_likely_tokens = unmask_rate.argmax(dim=-1) # (B, L)
+
+ selected_positions = all_top_indices[unmask_mask]
+ batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, max_unmask)[unmask_mask]
+
+ new_xt[batch_indices, selected_positions] = most_likely_tokens[batch_indices, selected_positions]
+ else:
+ # --- tau-leaping unmask via Poisson ---
+ counts = torch.poisson(unmask_rate * dt).long()
+ mask_pos = xt == mask
+ counts[~mask_pos.unsqueeze(-1).expand_as(counts)] = 0
+ counts[..., mask] = 0
+ sum_c = counts.sum(dim=2)
+ one_event = sum_c == 1
+ new_token = counts.argmax(dim=2)
+ new_xt = xt.clone()
+ new_xt[one_event] = new_token[one_event]
+ new_xt = torch.where(xt == pad, pad, new_xt)
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ # insertion only on non-last
+ if i != steps - 1:
+ # --- Poisson insertion, compute new lengths and fill masks ---
+ ext = torch.poisson(len_rate * dt).long() # (B, L+1)
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ gaps = torch.arange(max_length + 1, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ # compute prefix sums of insertions
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ # initialize with pads, then fill mask up to new_len
+ xt_tmp = torch.full_like(xt, pad)
+ mask_pos = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_pos] = mask
+
+ # shift and scatter original tokens
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+ else:
+ xt_tmp = new_xt
+
+ xt = xt_tmp
+ t = t + dt
+ if return_trace:
+ sampling_trace.append(xt)
+
+ return xt, sampling_trace
diff --git a/a2d2_mol/scripts/run_mol_finetune.slurm b/a2d2_mol/scripts/run_mol_finetune.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..f629a011f69d7ee13fe38d0b50fd054d06ac69c6
--- /dev/null
+++ b/a2d2_mol/scripts/run_mol_finetune.slurm
@@ -0,0 +1,200 @@
+#!/bin/bash
+# NOTE: --partition and --qos below are specific to our cluster. Change them
+# (or remove them and pass `--partition` on the `sbatch` command line) to match
+# the partitions/QOS available on yours.
+#SBATCH --job-name=mol-finetune
+#SBATCH --partition=dgx-b200
+#SBATCH --nodes=1
+#SBATCH --gpus-per-node=1
+#SBATCH --cpus-per-task=8
+#SBATCH --ntasks-per-node=1
+#SBATCH --mem=80GB
+#SBATCH --time=02-00:00:00
+#SBATCH --output=logs/slurm-%A.%x.log
+
+# =====================================================================
+# run_mol_finetune.slurm
+#
+# Single-mode job (1 MIG GPU) running ONE finetune_mol experiment.
+# Select which mode to run via the MODE_ID variable below (or override
+# at submit time with `sbatch --export=ALL,MODE_ID=2 ...`):
+# 0) A2D2 (Ours) – with full planner (alternating)
+# 1) A2D2 w/o quality – --disable_planner
+# 2) A2D2 w/o insertion planner – --disable_insertion_planner
+# 3) A2D2 w/o unmasking planner – --disable_unmasking_planner
+#
+# The job trains the selected mode then evaluates the resulting
+# checkpoint on the same GPU.
+# =====================================================================
+
+set -e
+
+# --- Mode selection ---------------------------------------------------
+# Which experiment to run (0-3). Override with `--export=ALL,MODE_ID=N`.
+MODE_ID="${MODE_ID:-0}"
+
+# Run prefix
+PREFIX=${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}
+
+# --- Paths ------------------------------------------------------------
+# Repo root is resolved at submit time so the job runs from any clone:
+# - set A2D2_ROOT explicitly, OR
+# - submit with `sbatch` from the repo root (SLURM sets SLURM_SUBMIT_DIR;
+# note sbatch copies the script to a spool dir, so we can't rely on the
+# script's own path here), OR
+# - run the script directly, falling back to its location on disk.
+if [ -n "${A2D2_ROOT:-}" ]; then
+ HOME_LOC="$A2D2_ROOT"
+elif [ -n "${SLURM_SUBMIT_DIR:-}" ]; then
+ HOME_LOC="$SLURM_SUBMIT_DIR"
+else
+ # This script lives in a2d2_mol/scripts/, so the repo root is two levels up.
+ HOME_LOC="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
+fi
+SCRIPT_LOC="$HOME_LOC/a2d2_mol"
+LOG_LOC=$HOME_LOC/logs
+SAVE_DIR=$HOME_LOC/checkpoints/finetune_mol
+RESULTS_DIR=$HOME_LOC/results/mol_ablation
+
+mkdir -p "$LOG_LOC" "$SAVE_DIR" "$RESULTS_DIR"
+
+# --- Environment setup ------------------------------------------------
+# Set WANDB_API_KEY in your shell/secret store before submitting (do NOT commit it):
+# export WANDB_API_KEY=... or `wandb login`
+export WANDB_DIR=$HOME_LOC/.wandb
+export WANDB_CONFIG_DIR=$HOME_LOC/.config/wandb
+export WANDB_CACHE_DIR=$HOME_LOC/.cache/wandb
+mkdir -p "$WANDB_DIR" "$WANDB_CONFIG_DIR" "$WANDB_CACHE_DIR"
+
+export TRITON_CACHE_DIR=$HOME_LOC/.triton/cache
+mkdir -p "$TRITON_CACHE_DIR"
+
+export TORCHINDUCTOR_CACHE_DIR=$HOME_LOC/.torchinductor/cache
+mkdir -p "$TORCHINDUCTOR_CACHE_DIR"
+
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+
+# Force unbuffered stdout/stderr so live training output is flushed to the
+# redirected RUN_LOG (Python block-buffers stdout when it's a file, not a TTY).
+export PYTHONUNBUFFERED=1
+
+# Activate conda env. Override CONDA_ROOT to point at your conda/miniconda
+# install, or just have `conda` on PATH; override CONDA_ENV if your env name
+# differs from the one created by environment.yml.
+CONDA_ENV="${CONDA_ENV:-a2d2}"
+if [ -n "${CONDA_ROOT:-}" ]; then
+ source "$CONDA_ROOT/bin/activate" "$CONDA_ENV"
+elif command -v conda >/dev/null 2>&1; then
+ source "$(conda info --base)/bin/activate" "$CONDA_ENV"
+else
+ echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2
+ exit 1
+fi
+PYTHON_EXECUTABLE=$(which python)
+
+cd "$SCRIPT_LOC"
+
+# Pretrained base checkpoint
+PRETRAINED_CKPT="$HOME_LOC/pretrained/anylength_mol.ckpt"
+
+# --- Shared training hyperparameters ----------------------------------
+COMMON_ARGS=(
+ --base_path "$HOME_LOC"
+ --use_quality_filter
+ --noise_removal
+ --wdce_num_replicates 16
+ --pool_size 1000
+ --pool_refresh_fraction 0.3
+ --buffer_size 100
+ --batch_size 200
+ --training_mini_batch_size 20
+ --max_length 256
+ --total_num_steps 256
+ --num_iter 20
+ --resample_every_n_step 10
+ --num_epochs 1000
+ --save_every_n_epochs 100
+ --reset_every_n_step 1
+ --alpha 0.01
+ --no_mcts
+ --schedule_warmup_epochs 20
+ --alternation_frequency 5
+ --num_remasking 3
+ --quality_threshold 0.3
+ --checkpoint_path "$PRETRAINED_CKPT"
+ --grad_clip
+ --qed_only
+ --seed 42
+ --num_training_steps_per_epoch 25
+)
+
+# --- Shared evaluation hyperparameters --------------------------------
+EVAL_COMMON_ARGS=(
+ --pretrained_ckpt "$PRETRAINED_CKPT"
+ --num_samples 1000
+ --batch_size 50
+ --max_length 256
+ --total_num_steps 256
+ --num_remasking 2
+ --quality_threshold 0.3
+ --seed 42
+)
+
+# =====================================================================
+# Pick experiment from $MODE_ID
+# =====================================================================
+case "$MODE_ID" in
+ 0) MODE="with_planner"; EXTRA_ARGS=() ;;
+ 1) MODE="no_planner"; EXTRA_ARGS=(--disable_planner) ;;
+ 2) MODE="no_insertion_planner"; EXTRA_ARGS=(--disable_insertion_planner) ;;
+ 3) MODE="no_unmasking_planner"; EXTRA_ARGS=(--disable_unmasking_planner) ;;
+ *) echo "Unknown MODE_ID=$MODE_ID (expected 0-3)"; exit 1 ;;
+esac
+
+RUN_NAME="${PREFIX}_mol_${MODE}"
+RUN_LOG="$LOG_LOC/${RUN_NAME}.log"
+RUN_SAVE_DIR="$SAVE_DIR/${RUN_NAME}"
+RESULTS_SUBDIR="$RESULTS_DIR/${MODE}"
+mkdir -p "$RUN_SAVE_DIR" "$RESULTS_SUBDIR"
+
+echo "=== Mol finetune (MODE_ID=$MODE_ID) ==="
+echo "Job: ${SLURM_JOB_ID} Node: $SLURM_NODELIST"
+echo "Mode: $MODE"
+echo "Save dir: $RUN_SAVE_DIR"
+echo "Results dir: $RESULTS_SUBDIR"
+echo "Python: $PYTHON_EXECUTABLE"
+echo "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES:-(unset)}"
+
+# =====================================================================
+# Train
+# =====================================================================
+$PYTHON_EXECUTABLE $SCRIPT_LOC/finetune_mol.py \
+ "${COMMON_ARGS[@]}" \
+ --devices 1 \
+ "${EXTRA_ARGS[@]}" \
+ --save_path_dir "$RUN_SAVE_DIR" \
+ >> "$RUN_LOG" 2>&1
+
+echo "Training finished for $MODE. Log: $RUN_LOG"
+
+# =====================================================================
+# Evaluate
+# =====================================================================
+RUN_CKPT=$(ls -t "$RUN_SAVE_DIR"/*/last.ckpt "$RUN_SAVE_DIR"/last.ckpt 2>/dev/null | head -1)
+if [ -z "$RUN_CKPT" ]; then
+ echo "No checkpoint found in $RUN_SAVE_DIR — skipping eval."
+ exit 1
+fi
+
+echo "Evaluating checkpoint: $RUN_CKPT"
+$PYTHON_EXECUTABLE $SCRIPT_LOC/evaluate_mol_table.py \
+ --checkpoint_path "$RUN_CKPT" \
+ "${EVAL_COMMON_ARGS[@]}" \
+ "${EXTRA_ARGS[@]}" \
+ --output_dir "$RESULTS_SUBDIR" \
+ --device cuda:0 \
+ >> "$RESULTS_SUBDIR/eval.log" 2>&1
+
+echo "Eval finished for $MODE. CSV: $RESULTS_SUBDIR/eval_metrics_${MODE}.csv"
+
+conda deactivate
diff --git a/a2d2_mol/scripts/train_mol.sh b/a2d2_mol/scripts/train_mol.sh
new file mode 100755
index 0000000000000000000000000000000000000000..33bc79330da861c95e033e8b615a1f667a2b1725
--- /dev/null
+++ b/a2d2_mol/scripts/train_mol.sh
@@ -0,0 +1,93 @@
+#!/bin/bash
+#SBATCH --job-name=a2d2-mol-pretrain
+#SBATCH --partition=dgx-b200
+#SBATCH --nodes=1
+#SBATCH --gpus-per-node=2
+#SBATCH --ntasks-per-node=2
+#SBATCH --cpus-per-task=8
+#SBATCH --mem=512GB
+#SBATCH --time=7-00:00:00
+# SLURM's own catch-file (anything printed before the exec redirect below, plus
+# slurm-infra messages). Relative to the submit dir, so submit this script from
+# the a2d2_mol/ directory; the real run output is redirected via exec below.
+#SBATCH --output=logs/slurm/%x_%j.out
+#SBATCH --error=logs/slurm/%x_%j.err
+#
+# Pretrain the any-length insertion MDM on drug-like SAFE molecules on a dgx-b200 node.
+# Submit with: sbatch scripts/train_mol.sh (from the a2d2_mol/ directory).
+#
+# DDP is launched by SLURM: one srun task per GPU. --gpus-per-node and
+# --ntasks-per-node must match; change both together (and they override the
+# training.devices value baked into config_mol.yaml via the hydra override below).
+
+DATE=$(date +%Y%m%d)
+SPECIAL_PREFIX='a2d2-mol'
+
+# Resolve a2d2_mol/ (which holds train.py + config_mol.yaml) so paths are
+# repo-relative. This script lives in a2d2_mol/scripts/, so the direct-run
+# fallback goes one level up. Under sbatch, BASH_SOURCE points at the spooled
+# copy, so we rely on SLURM_SUBMIT_DIR (submit from the a2d2_mol/ directory).
+if [ -n "${SLURM_SUBMIT_DIR:-}" ]; then
+ SCRIPT_DIR="$SLURM_SUBMIT_DIR"
+else
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+fi
+cd "$SCRIPT_DIR"
+
+# Auto-detect GPUs from the SLURM allocation (falls back to 2 for `bash` runs).
+DEVICES=${SLURM_GPUS_ON_NODE:-${SLURM_GPUS_PER_NODE:-2}}
+NTASKS=${SLURM_NTASKS_PER_NODE:-$DEVICES}
+NODES=${SLURM_NNODES:-1}
+
+LOG_LOC="$SCRIPT_DIR/logs"
+mkdir -p "$LOG_LOC/slurm"
+exec > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_${SLURM_JOB_ID:-local}.log" 2>&1
+
+# ---------------------------------------------------------------------------
+# Weights & Biases: log in once on your machine before running this script with
+# `wandb login` (or `export WANDB_API_KEY=`).
+# Do NOT hardcode your API key here. To disable W&B entirely, uncomment:
+# export WANDB_MODE=disabled
+# ---------------------------------------------------------------------------
+
+export PYTORCH_ALLOC_CONF=expandable_segments:True
+
+# Activate the conda env that has the deps (torch / pytorch_lightning / hydra).
+# The batch shell does NOT source ~/.bashrc, so conda is not on PATH. Override
+# CONDA_ROOT to point at your conda/miniconda install, or just have `conda` on
+# PATH; override CONDA_ENV if your env name differs from the one created by
+# environment.yml.
+CONDA_ENV="${CONDA_ENV:-a2d2}"
+if [ -n "${CONDA_ROOT:-}" ]; then
+ source "$CONDA_ROOT/bin/activate" "$CONDA_ENV"
+elif command -v conda >/dev/null 2>&1; then
+ source "$(conda info --base)/bin/activate" "$CONDA_ENV"
+else
+ echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2
+ exit 1
+fi
+
+# --- Distributed / NCCL setup (single node, intra-node NVLink) --------------
+ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -E "ens|eth|enp|bond" | head -1 | awk '{print $2}')
+if [ -z "$ETH_IFACE" ]; then
+ ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -v "ibp" | head -1 | awk '{print $2}')
+fi
+export NCCL_IB_DISABLE=1
+export NCCL_SOCKET_FAMILY=AF_INET
+export NCCL_SOCKET_IFNAME=$ETH_IFACE
+export NCCL_P2P_LEVEL=NVL
+
+export MASTER_ADDR=$(scontrol show hostnames "${SLURM_NODELIST:-$(hostname)}" | head -n 1)
+export MASTER_PORT=$(shuf -i 15000-59999 -n 1)
+export NODE_RANK=${SLURM_NODEID:-0}
+
+echo "=== a2d2 molecule pretraining (dgx-b200) ==="
+echo "Job ID: ${SLURM_JOB_ID:-local} Node: ${SLURM_NODELIST:-$(hostname)} GPUs: $DEVICES Tasks: $NTASKS"
+
+# --task mol makes train.py load config_mol.yaml; the hydra overrides pin
+# devices/nodes to the SLURM allocation so the two never drift apart.
+srun --ntasks-per-node=$NTASKS python train.py --task mol \
+ training.devices=$DEVICES \
+ training.nodes=$NODES
+
+conda deactivate
diff --git a/a2d2_mol/train.py b/a2d2_mol/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..84e8e0196fc50b1fa496a1827394a7518bd25fdf
--- /dev/null
+++ b/a2d2_mol/train.py
@@ -0,0 +1,216 @@
+import torch
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.callbacks import ModelCheckpoint
+import os
+import sys
+import argparse
+import hydra
+from omegaconf import OmegaConf
+from datetime import datetime
+# Directory containing this file and the config_*.yaml files (used by Hydra below).
+CONFIG_DIR = os.path.dirname(os.path.abspath(__file__))
+# Add the repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve.
+sys.path.insert(0, os.path.dirname(CONFIG_DIR))
+
+import wandb
+from lightning_modules import AnyOrderInsertionFlowModule
+
+
+torch.set_printoptions(threshold=10_000)
+torch.set_float32_matmul_precision("high")
+
+# Disable DDP optimizer due to incompatibility with flex_attention higher-order ops
+torch._dynamo.config.optimize_ddp = False
+
+def train(config):
+ wandb_logger = None
+
+ # set the random seed
+ pl.seed_everything(42)
+ torch.manual_seed(42)
+
+ # Only initialize wandb on rank 0 to avoid multiple runs
+ if int(os.environ.get("LOCAL_RANK", 0)) == 0:
+ wandb.init(
+ project=config.wandb.project,
+ name=config.wandb.name,
+ config=OmegaConf.to_container(config, resolve=True), # Convert to dict
+ dir=config.wandb.path
+ )
+ wandb_logger = WandbLogger(
+ project=wandb.run.project,
+ name=wandb.run.name,
+ log_model=False, # Disable checkpoint uploading to save disk space
+ )
+
+ # Modify config to add timestamp to checkpoint directory
+ OmegaConf.set_struct(config, False)
+ time_string = datetime.now().strftime("%Y%m%d-%H%M%S")
+ config.training.checkpoint_dir = os.path.join(
+ config.training.checkpoint_dir, time_string
+ )
+ OmegaConf.set_struct(config, True)
+
+ # Create checkpoint directory
+ os.makedirs(config.training.checkpoint_dir, exist_ok=True)
+
+ # Setup data module - check if using HuggingFace dataset
+ if hasattr(config, 'hf_dataset'):
+ # Imported lazily: the HF/SAFE path is only used by the molecule configs,
+ # which keep mol_dataset.py (and its `safe` dependency) in a2d2_mol/.
+ from mol_dataset import setup_hf_data_and_update_config
+ print(f"Using HuggingFace dataset: {config.hf_dataset.name}")
+ data_module = setup_hf_data_and_update_config(
+ config,
+ dataset_name=config.hf_dataset.name,
+ smiles_column=config.hf_dataset.get('smiles_column', 'smiles')
+ )
+ else:
+ # Imported lazily: the local (arrow) path is used by the peptide config,
+ # which keeps dataloading_for_dynamic_batching.py in a2d2_pep/.
+ from dataloading_for_dynamic_batching import setup_data_and_update_config
+ print("Using local dataset")
+ data_module = setup_data_and_update_config(config)
+
+ module = AnyOrderInsertionFlowModule(config)
+
+ # Initialize trainer
+
+ # Configure trainer arguments
+ # Map torch_dtype to Lightning precision
+ dtype_str = config.model.get('torch_dtype', 'bfloat16')
+ precision_map = {
+ 'float32': '32-true',
+ 'float16': '16-mixed',
+ 'bfloat16': 'bf16-mixed'
+ }
+ precision = precision_map.get(dtype_str, 'bf16-mixed')
+
+ trainer_kwargs = dict(
+ num_nodes=config.training.nodes,
+ accelerator="gpu",
+ devices=config.training.devices,
+ strategy="ddp",
+ precision=precision,
+ accumulate_grad_batches=(
+ config.training.batch_size
+ // (
+ config.training.per_gpu_batch_size
+ * config.training.nodes
+ * config.training.devices
+ )
+ ),
+ log_every_n_steps=10,
+ enable_checkpointing=True,
+ default_root_dir=config.training.checkpoint_dir,
+ gradient_clip_val=1.0,
+ )
+ # Only one of max_steps or max_epochs will be used
+ if config.training.max_steps is not None:
+ trainer_kwargs["max_steps"] = config.training.max_steps
+ elif config.training.num_epochs is not None:
+ trainer_kwargs["max_epochs"] = config.training.num_epochs
+ config.training.max_steps = config.training.max_steps
+ else:
+ raise ValueError(
+ "Either max_steps or num_epochs must be specified in the config"
+ )
+
+ if config.training.warmup_steps is None:
+ config.training.warmup_steps = int(config.training.max_steps * 0.01)
+
+ # Add ModelCheckpoint callback to save the checkpoint when validation loss is at a new low
+ checkpoint_callback = ModelCheckpoint(
+ monitor="train/total_loss",
+ mode="min",
+ save_top_k=config.training.save_top_k,
+ save_last=True,
+ filename="epoch-{epoch:02d}-train_loss-{train/total_loss:.4f}",
+ dirpath=config.training.checkpoint_dir,
+ # Don't use val_loss in filename for periodic saves - causes failures when val doesn't run
+ auto_insert_metric_name=False
+ )
+
+ # Add separate callback for periodic saves (no val_loss dependency). Use
+ # step-based saves for streaming datasets (save_every_n_steps) and epoch-based
+ # saves otherwise (save_every_n_epochs); whichever the config provides.
+ save_every_n_steps = config.training.get('save_every_n_steps', None)
+ save_every_n_epochs = config.training.get('save_every_n_epochs', None)
+ if save_every_n_steps is not None:
+ periodic_checkpoint_callback = ModelCheckpoint(
+ save_top_k=-1, # Save all periodic checkpoints
+ filename="step-{step:08d}",
+ dirpath=config.training.checkpoint_dir,
+ every_n_train_steps=save_every_n_steps,
+ auto_insert_metric_name=False
+ )
+ elif save_every_n_epochs is not None:
+ periodic_checkpoint_callback = ModelCheckpoint(
+ save_top_k=-1, # Save all periodic checkpoints
+ filename="epoch-{epoch:02d}",
+ dirpath=config.training.checkpoint_dir,
+ every_n_epochs=save_every_n_epochs,
+ auto_insert_metric_name=False
+ )
+ else:
+ raise ValueError(
+ "Either save_every_n_steps or save_every_n_epochs must be specified in the config"
+ )
+
+ trainer_kwargs["callbacks"] = [checkpoint_callback, periodic_checkpoint_callback]
+
+ if wandb_logger is not None:
+ trainer_kwargs["logger"] = wandb_logger
+
+ trainer = pl.Trainer(**trainer_kwargs)
+
+ # Train the model
+ ckpt_path = None
+ if "resume_path" in config.training:
+ ckpt_path = config.training.resume_path
+
+ trainer.fit(module,
+ datamodule=data_module,
+ ckpt_path=ckpt_path)
+
+ # Only finish wandb on rank 0
+ if int(os.environ.get("LOCAL_RANK", 0)) == 0:
+ wandb.finish()
+
+
+if __name__ == '__main__':
+ # Parse arguments to get config name
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config_name', type=str, default='config',
+ help='Name of the config file to use')
+ parser.add_argument('--task', type=str, default=None,
+ help='Task name (uses config_{task}.yaml)')
+
+ # Parse known args (hydra will handle the rest)
+ args, unknown = parser.parse_known_args()
+
+ # Determine config name from task or config_name
+ if args.task:
+ config_name = f'config_{args.task}'
+ else:
+ config_name = args.config_name
+
+ print(f"Using config: {config_name}.yaml")
+
+ # Add config name to Hydra overrides (this persists across DDP subprocesses)
+ if '--config-name' not in unknown and f'--config-name={config_name}' not in unknown:
+ unknown.insert(0, f'--config-name={config_name}')
+
+ # Reconstruct sys.argv for hydra
+ sys.argv = [sys.argv[0]] + unknown
+
+ # Define main function with default config (will be overridden by command line)
+ @hydra.main(version_base=None,
+ config_path=CONFIG_DIR,
+ config_name='config')
+ def main(config):
+ """Main entry point for training"""
+ train(config)
+
+ main()
\ No newline at end of file
diff --git a/a2d2_pep/README.md b/a2d2_pep/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..28f6d169706537f88ec3836ed9d406bc1d087854
--- /dev/null
+++ b/a2d2_pep/README.md
@@ -0,0 +1,145 @@
+# A2D2 for Multi-Objective Therapeutic Peptide Generation 🧫
+
+This part of the code fine-tunes an **any-length masked diffusion model (MDM)** over peptide SMILES with **A2D2** (Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding) to optimize **five therapeutic properties simultaneously**: binding affinity to a target protein, solubility, non-hemolysis, non-fouling, and cell-membrane permeability.
+
+A2D2 jointly fine-tunes the insertion and unmasking policies together with **insertion and unmasking quality predictors**, generating peptides via **Adaptive Joint Decoding (AJD)** that remasks low-quality tokens and drops low-quality insertions to sample from the reward-tilted distribution while preserving generation quality.
+
+Peptides are represented as **SMILES** strings and tokenized with the SMILES Pair Encoding tokenizer (vocabulary size `V = 587`) from [PeptideCLM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01443). Generated SMILES are decoded and validity-checked with the `SMILES2PEPTIDE` filter from [PepTune](https://arxiv.org/abs/2412.17780).
+
+The codebase is partially built upon [FlexMDM (Kim et.al, 2025)](https://github.com/brianlck/FlexMDM/tree/main) and [TR2-D2 (Tang et.al, 2025)](https://github.com/sophtang/TR2-D2/tree/main).
+
+## Environment Installation
+```
+# from the repository root
+conda env create -f environment.yml
+
+conda activate a2d2
+```
+The peptide scripts share the `a2d2` environment with the molecule and language experiments. See the root [`environment.yml`](../environment.yml) for the `flash-attn` install step.
+
+## Model Pretrained Weights
+
+A2D2 fine-tunes a pretrained any-length insertion MDM trained on ~11M peptide SMILES (7,451 sequences from CycPeptMPDB, 825,632 from SmProt, and ~10M modified peptides from CycloPs). Download the base checkpoint and place it at:
+```
+A2D2/pretrained/anylength_pep.ckpt
+```
+```bash
+# from the repository root
+pip install gdown
+mkdir -p pretrained
+gdown 1K8yxM-omh-MuPo0EG6UyxHZLk3HehoJc -O pretrained/anylength_pep.ckpt
+```
+(Or download manually from https://drive.google.com/file/d/1K8yxM-omh-MuPo0EG6UyxHZLk3HehoJc/view?usp=drive_link — a plain `wget`/`curl` of the link saves Google's HTML warning page, not the checkpoint.)
+This is the default `--checkpoint_path`; pass `--checkpoint_path` to override it.
+
+The reward classifiers (binding-affinity Transformer, plus XGBoost predictors for solubility, hemolysis, non-fouling, and permeability) and the SMILES PE tokenizer ship with the repo under [`pep_scoring/`](pep_scoring); no separate download is required. The PeptideCLM embedding model is fetched automatically from the Hugging Face Hub (`aaronfeller/PeptideCLM-23M-all`) on first run.
+
+## Pretraining the Any-Length Model
+
+If you only want to fine-tune with A2D2, download the released `anylength_pep.ckpt` above and skip this section. Follow these steps to reproduce the base checkpoint by pretraining the any-length insertion MDM from scratch.
+
+### 1. Download the pretraining dataset
+
+The pretraining corpus is ~11M peptide SMILES (7,451 from CycPeptMPDB, 825,632 from SmProt, and ~10M modified peptides from CycloPs), already tokenized with the in-repo SMILES PE tokenizer and saved as a Hugging Face `arrow` dataset (with `train`/`val` splits) via `save_to_disk`.
+
+Download the archive and unpack it into [`data/`](data):
+
+```bash
+# from a2d2_pep/
+pip install gdown
+gdown https://drive.google.com/uc?id=1yCDr641WVjCtECg3nbG0nsMNu8j7d7gp -O 11M_peptide_smiles.zip
+mkdir -p data
+unzip 11M_peptide_smiles.zip -d data/
+# result: a2d2_pep/data/11M_peptide_smiles/{train,val}/...
+```
+
+This is the default `training.data_path` in [`config_pep.yaml`](config_pep.yaml). To store the dataset elsewhere, set `training.data_path` (absolute, or relative to `a2d2_pep/`).
+
+### 2. Configure
+
+Pretraining is driven by [`config_pep.yaml`](config_pep.yaml). Key fields:
+
+| Field | Default | Notes |
+|-------|---------|-------|
+| `training.data_path` | `data/11M_peptide_smiles` | Preprocessed arrow dataset from step 1. |
+| `training.devices` | `4` | GPUs per node (DDP). |
+| `training.batch_size` | `1024` | Global batch; gradient accumulation is derived automatically from `per_gpu_batch_size`. |
+| `training.max_steps` | `1000000` | Total optimizer steps. |
+| `training.learning_rate` | `3e-4` | AdamW LR with `warmup_steps: 2000`. |
+| `training.checkpoint_dir` | `checkpoints/peptides` | A timestamped subdirectory is created per run. |
+| `interpolant.max_length` | `1024` | Max token length. |
+
+### 3. Pre-training Any-Length Peptide Model
+
+Log in to Weights & Biases once (`wandb login`), or set `export WANDB_MODE=disabled` to skip logging. Then submit the SLURM job:
+
+```bash
+# from a2d2_pep/
+sbatch train_pep.sh
+```
+
+`train_pep.sh` is a SLURM batch script that requests one `dgx-b200` node with 4 full B200 GPUs and launches DDP via `srun` (one task per GPU), running the equivalent of:
+
+```bash
+python train.py --task pep
+```
+
+It activates the conda env (`CONDA_ENV`, defaults to the `peptune` env) from `CONDA_ROOT` (defaults to the shared miniconda install) — the batch shell does not source `~/.bashrc`, so override these env vars if your install or env path differs. The GPU count is auto-detected from the SLURM allocation and passed to hydra as `training.devices`/`training.nodes`, so to scale just change `--gpus-per-node` and `--ntasks-per-node` together at the top of the script (they must match). `--task pep` makes `train.py` load `config_pep.yaml`.
+
+Checkpoints are written to `checkpoints/peptides//` (use `last.ckpt` / the best `train_loss` checkpoint as the `--checkpoint_path` for fine-tuning); the run log goes to `logs/_a2d2-peptide_.log` and SLURM's catch-file to `logs/slurm/`. To resume, add a `training.resume_path: /path/to/last.ckpt` entry to the config.
+
+## Fine-Tune with A2D2
+
+All paths resolve relative to the repository, so the scripts run from any checkout. Before running, create the output directories `A2D2/checkpoints`, `A2D2/results`, and `A2D2/logs` (the script also creates them on demand). Fine-tuning curves and a `_generation_results.csv` are written to `/results//`, and checkpoints to `--save_path_dir`.
+
+Choose a target protein with `--prot_name` (looked up in the built-in `PROTEINS` table — e.g. `glp1` for GLP-1R or `glast` for GLAST), or supply an arbitrary target with `--prot_seq `.
+
+#### Available `--prot_name` targets
+
+The named targets and their amino-acid sequences are defined in the `PROTEINS` dict in [`finetune_quality.py`](finetune_quality.py) (search for `PROTEINS = {`). The default is `glast`; passing a name not in the table raises an error listing the valid keys. To add a new target, add a `'': ''` entry there, or skip the table entirely with `--prot_seq`.
+
+| `--prot_name` | Target |
+|---------------|--------|
+| `tfr` | Transferrin receptor (TfR) |
+| `glp1` | GLP-1 receptor (GLP-1R) |
+
+### Single run
+
+[`scripts/run_peptide_finetune.slurm`](scripts/run_peptide_finetune.slurm) runs a single `finetune_quality.py` experiment on one MIG GPU, then evaluates the resulting checkpoint. It bundles the hyperparameter set from the peptide column of the fine-tuning table in the paper — replicates `R = 8`, buffer size `B = 50`, resample interval `N_resample = 10`, gradient steps per iteration `N_update = 10`, alternation frequency `N_alt = 5`, warmup `N_warmup = 20`, sampling steps `N_steps = 256`, training mini-batch `10`, reward scaling `α = 0.1`, quality threshold `μ_min = 0.5`, and `--num_obj 5` — so you don't have to pass them by hand.
+
+The script resolves the repo root automatically — `$A2D2_ROOT` if set, else the `sbatch` submit directory, else the script's own location — so either submit from the repo root or export your clone path. Set `CONDA_ROOT` (your miniconda install) and, if needed, `CONDA_ENV` (defaults to `peptune`) and `WANDB_ENTITY`:
+```bash
+export A2D2_ROOT=/path/to/your/A2D2 # absolute path to your clone
+export CONDA_ROOT=/path/to/miniconda3 # or just have `conda` on PATH
+export WANDB_ENTITY=your_wandb_entity # optional
+sbatch scripts/run_peptide_finetune.slurm
+```
+
+Select which variant to run with `MODE_ID` (default `0`): `0` = A2D2 (full planner), `1` = `--disable_planner`, `2` = `--disable_insertion_planner`, `3` = `--disable_unmasking_planner`. Override at submit time:
+```bash
+sbatch --export=ALL,MODE_ID=2 scripts/run_peptide_finetune.slurm
+```
+The target protein is set by the `PROT_NAME` variable near the top of the script (default `tfr`); edit it to one of the named targets above (or any key in the `PROTEINS` table). The pretrained base checkpoint is read from `$A2D2_ROOT/pretrained/anylength_pep.ckpt`. Outputs land in `checkpoints/finetune_test_peptides_/_peptide__/` and `results/peptide_test_ablation_//`.
+
+### Key arguments
+- `--prot_name` / `--prot_seq` — target protein (named lookup, or a raw amino-acid sequence).
+- `--alternation_frequency` — epochs to train each of {policy, planner} before alternating.
+- `--alpha` — reward-tilting temperature (smaller = stronger reward optimization).
+- `--buffer_size`, `--resample_every_n_step` — replay-buffer size and how often it is regenerated.
+
+### Ablation flags
+| Flag | Variant |
+|------|---------|
+| *(none)* | A2D2 w/ insertion + unmasking quality (alternation) |
+| `--disable_planner` | A2D2 w/o quality (policy only, no remasking) |
+| `--disable_insertion_planner` | A2D2 w/o insertion quality |
+| `--disable_unmasking_planner` | A2D2 w/o unmasking/remasking quality |
+| `--joint_training` | train policy + quality heads jointly (no alternation) |
+
+During buffer generation only sequences passing the `SMILES2PEPTIDE` validity filter are retained; the scalarized multi-objective reward is added to the log Radon–Nikodym derivative of each sequence. Fine-tuning runs on a single GPU (`--devices 1`).
+
+## Evaluation
+
+Evaluation runs automatically every `--eval_every_n_epochs` epochs and at the end of training. It samples from the current model and reports the fraction of valid peptides along with the five therapeutic rewards (binding affinity, solubility, non-hemolysis, non-fouling, permeability), saving per-objective curves and `_generation_results.csv` under `/results//`.
+
+To resume a run, pass `--resume_ckpt /path/to/last.ckpt` (restores epoch, optimizer, and planner state; new checkpoints continue in the same directory).
diff --git a/a2d2_pep/config_pep.yaml b/a2d2_pep/config_pep.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..b7877e1a8fd72fd8bde7d391e5f6dda073df060f
--- /dev/null
+++ b/a2d2_pep/config_pep.yaml
@@ -0,0 +1,50 @@
+trainer: "any-order-flow"
+dataset: "peptides"
+
+model:
+ hidden_size: 768
+ n_heads: 12
+ cond_dim: 128
+ dropout: 0.05
+ n_blocks: 12
+
+interpolant:
+ type: "any-order"
+ tokens: null # filled in automatically
+ pad_token: null # filled in automatically
+ mask_token: null # filled in automatically
+ max_length: 1024
+ insert_schedule:
+ type: "linear"
+ unmask_schedule:
+ type: "linear"
+
+training:
+ only_embed_insert: true
+ batch_size: 1024
+ per_gpu_batch_size: 64 # Gradient accumulation happens automatically
+ cpus: 4
+ learning_rate: 3e-4
+ nodes: 1
+ devices: 4
+ max_steps: 1000000
+ weight_decay: 0.03
+ # Path to the preprocessed (arrow) pretraining dataset; see README for the download link.
+ # Relative paths resolve against a2d2_pep/. Defaults to a2d2_pep/data/11M_peptide_smiles.
+ data_path: "data/11M_peptide_smiles"
+ checkpoint_dir: "checkpoints/peptides"
+ save_top_k: 1
+ save_every_n_epochs: 1
+ loss_fn:
+ unmask: "elbo"
+ insert: "expectation"
+ reset_lr: false
+ warmup_steps: 2000
+ ema_decay: 0.9999
+ filter_max_length: false
+
+wandb:
+ entity: null # set to your W&B entity, or leave null to use the default
+ project: "a2d2-pep"
+ name: "a2d2-pep"
+ path: "./wandb"
diff --git a/a2d2_pep/data/dataloading_for_dynamic_batching.py b/a2d2_pep/data/dataloading_for_dynamic_batching.py
new file mode 100755
index 0000000000000000000000000000000000000000..ceffa4370e08296941116dc57932dec358eaec66
--- /dev/null
+++ b/a2d2_pep/data/dataloading_for_dynamic_batching.py
@@ -0,0 +1,189 @@
+#!/usr/bin/env
+import os
+import torch
+from torch.utils.data import Dataset, DataLoader
+from datasets import Dataset,load_from_disk
+import sys
+import pytorch_lightning as pl
+from pep_scoring.tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
+from functools import partial
+import re
+
+# Directory containing this file; used to resolve the in-repo tokenizer files.
+_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+class DynamicBatchingDataset(Dataset):
+ def __init__(self, dataset_dict, tokenizer):
+ print('Initializing dataset...')
+ self.dataset_dict = {
+ 'attention_mask': [torch.tensor(item) for item in dataset_dict['attention_mask']],
+ 'input_ids': [torch.tensor(item) for item in dataset_dict['input_ids']],
+ 'labels': dataset_dict['labels']
+ }
+ self.tokenizer = tokenizer
+
+ def __len__(self):
+ return len(self.dataset_dict['attention_mask'])
+
+ def __getitem__(self, idx):
+ if isinstance(idx, int):
+ return {
+ 'input_ids': self.dataset_dict['input_ids'][idx],
+ 'attention_mask': self.dataset_dict['attention_mask'][idx],
+ 'labels': self.dataset_dict['labels'][idx]
+ }
+ elif isinstance(idx, list):
+ return {
+ 'input_ids': [self.dataset_dict['input_ids'][i] for i in idx],
+ 'attention_mask': [self.dataset_dict['attention_mask'][i] for i in idx],
+ 'labels': [self.dataset_dict['labels'][i] for i in idx]
+ }
+ else:
+ raise ValueError(f"Expected idx to be int or list, but got {type(idx)}")
+
+class CustomDataModule(pl.LightningDataModule):
+ def __init__(self, dataset_path, tokenizer):
+ super().__init__()
+ self.dataset = load_from_disk(dataset_path)
+ self.tokenizer = tokenizer
+
+ def peptide_bond_mask(self, smiles_list):
+ """
+ Returns a mask with shape (batch_size, seq_length) that has 1 at the locations
+ of recognized bonds in the positions dictionary and 0 elsewhere.
+
+ Args:
+ smiles_list: List of peptide SMILES strings (batch of SMILES strings).
+
+ Returns:
+ np.ndarray: A mask of shape (batch_size, seq_length) with 1s at bond positions.
+ """
+ # Initialize the batch mask
+ batch_size = len(smiles_list)
+ max_seq_length = 1035 #max(len(smiles) for smiles in smiles_list) # Find the longest SMILES
+ mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int) # Mask filled with zeros
+
+ bond_patterns = [
+ (r'OC\(=O\)', 'ester'),
+ (r'N\(C\)C\(=O\)', 'n_methyl'),
+ (r'N[12]C\(=O\)', 'peptide'), # Pro peptide bonds
+ (r'NC\(=O\)', 'peptide'), # Regular peptide bonds
+ (r'C\(=O\)N\(C\)', 'n_methyl'),
+ (r'C\(=O\)N[12]?', 'peptide')
+ ]
+
+ for batch_idx, smiles in enumerate(smiles_list):
+ positions = []
+ used = set()
+
+ # Identify bonds
+ for pattern, bond_type in bond_patterns:
+ for match in re.finditer(pattern, smiles):
+ if not any(p in range(match.start(), match.end()) for p in used):
+ positions.append({
+ 'start': match.start(),
+ 'end': match.end(),
+ 'type': bond_type,
+ 'pattern': match.group()
+ })
+ used.update(range(match.start(), match.end()))
+
+ # Update the mask for the current SMILES
+ for pos in positions:
+ mask[batch_idx, pos['start']:pos['end']] = 1
+
+ return mask
+
+ def peptide_token_mask(self, smiles_list, token_lists):
+ """
+ Returns a mask with shape (batch_size, num_tokens) that has 1 for tokens
+ where any part of the token overlaps with a peptide bond, and 0 elsewhere.
+
+ Args:
+ smiles_list: List of peptide SMILES strings (batch of SMILES strings).
+ token_lists: List of tokenized SMILES strings (split into tokens).
+
+ Returns:
+ np.ndarray: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens.
+ """
+ # Initialize the batch mask
+ batch_size = len(smiles_list)
+ token_seq_length = max(len(tokens) for tokens in token_lists) # Find the longest tokenized sequence
+ tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int) # Mask filled with zeros
+ atomwise_masks = self.peptide_bond_mask(smiles_list)
+
+
+ for batch_idx, atomwise_mask in enumerate(atomwise_masks):
+ token_seq = token_lists[batch_idx]
+ atom_idx = 0
+
+ for token_idx, token in enumerate(token_seq):
+ if token_idx != 0 and token_idx != len(token_seq) - 1:
+ if torch.sum(atomwise_mask[atom_idx:atom_idx+len(token)]) >= 1:
+ tokenized_masks[batch_idx][token_idx] = 1
+ atom_idx += len(token)
+
+ return tokenized_masks
+
+ def collate_fn(self, batch):
+ item = batch[0]
+
+ token_array = self.tokenizer.get_token_split(item['input_ids'])
+ bond_mask = self.peptide_token_mask(item['labels'], token_array)
+
+ return {
+ 'input_ids': item['input_ids'],
+ 'attention_mask': item['attention_mask'],
+ 'bond_mask': bond_mask
+ }
+
+ def train_dataloader(self):
+ train_dataset = DynamicBatchingDataset(self.dataset['train'], tokenizer=self.tokenizer)
+ return DataLoader(
+ train_dataset,
+ batch_size=1,
+ collate_fn=self.collate_fn, # Use the instance method
+ shuffle=True,
+ num_workers=12,
+ pin_memory=True
+ )
+
+ def val_dataloader(self):
+ val_dataset = DynamicBatchingDataset(self.dataset['val'], tokenizer=self.tokenizer)
+ return DataLoader(
+ val_dataset,
+ batch_size=1,
+ collate_fn=self.collate_fn, # Use the instance method
+ num_workers=8,
+ pin_memory=True
+ )
+
+
+def setup_data_and_update_config(config):
+ """
+ Get the dataset and update the config with token information for text datasets.
+ """
+ # SMILES Pair Encoding tokenizer ships with the repo under pep_scoring/tokenizer/.
+ tokenizer = SMILES_SPE_Tokenizer(
+ os.path.join(_THIS_DIR, 'pep_scoring', 'tokenizer', 'new_vocab.txt'),
+ os.path.join(_THIS_DIR, 'pep_scoring', 'tokenizer', 'new_splits.txt'),
+ )
+
+ config.interpolant.tokens = len(tokenizer)
+ config.interpolant.pad_token = tokenizer.pad_token_id
+ config.interpolant.mask_token = tokenizer.mask_token_id
+
+ # Path to the preprocessed (arrow) pretraining dataset saved via `save_to_disk`.
+ # Download instructions are in the README; override with `training.data_path` in the config.
+ data_path = config.training.get('data_path', os.path.join('data', '11M_peptide_smiles'))
+ if not os.path.isabs(data_path):
+ data_path = os.path.join(_THIS_DIR, data_path)
+ if not os.path.exists(data_path):
+ raise FileNotFoundError(
+ f"Pretraining dataset not found at '{data_path}'. Download it (see a2d2_pep/README.md, "
+ "'Pretraining the Any-Length Model') and set `training.data_path` in config_pep.yaml."
+ )
+ data_module = CustomDataModule(data_path, tokenizer)
+
+ return data_module
diff --git a/a2d2_pep/data/dataset.py b/a2d2_pep/data/dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..a6cdd558951a4faaaaf53e109fca5a2bd71b4011
--- /dev/null
+++ b/a2d2_pep/data/dataset.py
@@ -0,0 +1,207 @@
+
+import re
+import torch
+
+import utils
+
+from torch.utils.data import Dataset, DataLoader
+import pytorch_lightning as pl
+from functools import partial
+import sys
+
+class CustomDataset(Dataset):
+ def __init__(self, dataset, indices):
+ self.dataset = dataset
+ self.indices = indices
+
+ def __len__(self):
+ return len(self.indices)
+
+ def __getitem__(self, idx):
+ actual_idx = int(self.indices[idx])
+ item = self.dataset[actual_idx]
+ return item
+
+
+# for weighting losses of peptide bonds
+def peptide_bond_mask(smiles_list):
+ """
+ Returns a mask with shape (batch_size, seq_length) that has 1 at the locations
+ of recognized bonds in the positions dictionary and 0 elsewhere.
+
+ Args:
+ smiles_list: List of peptide SMILES strings (batch of SMILES strings).
+
+ Returns:
+ np.ndarray: A mask of shape (batch_size, seq_length) with 1s at bond positions.
+ """
+ # Initialize the batch mask
+ batch_size = len(smiles_list)
+ max_seq_length = max(len(smiles) for smiles in smiles_list) # Find the longest SMILES
+ mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int) # Mask filled with zeros
+
+ bond_patterns = [
+ (r'OC\(=O\)', 'ester'),
+ (r'N\(C\)C\(=O\)', 'n_methyl'),
+ (r'N[12]C\(=O\)', 'peptide'), # Pro peptide bonds
+ (r'NC\(=O\)', 'peptide'), # Regular peptide bonds
+ (r'C\(=O\)N\(C\)', 'n_methyl'),
+ (r'C\(=O\)N[12]?', 'peptide')
+ ]
+
+ for batch_idx, smiles in enumerate(smiles_list):
+ positions = []
+ used = set()
+
+ # Identify bonds
+ for pattern, bond_type in bond_patterns:
+ for match in re.finditer(pattern, smiles):
+ if not any(p in range(match.start(), match.end()) for p in used):
+ positions.append({
+ 'start': match.start(),
+ 'end': match.end(),
+ 'type': bond_type,
+ 'pattern': match.group()
+ })
+ used.update(range(match.start(), match.end()))
+
+ # Update the mask for the current SMILES
+ for pos in positions:
+ mask[batch_idx, pos['start']:pos['end']] = 1
+
+ return mask
+
+def peptide_token_mask(smiles_list, token_lists):
+ """
+ Returns a mask with shape (batch_size, num_tokens) that has 1 for tokens
+ where any part of the token overlaps with a peptide bond, and 0 elsewhere.
+
+ Args:
+ smiles_list: List of peptide SMILES strings (batch of SMILES strings).
+ token_lists: List of tokenized SMILES strings (split into tokens).
+
+ Returns:
+ np.ndarray: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens.
+ """
+ # Initialize the batch mask
+ batch_size = len(smiles_list)
+ token_seq_length = max(len(tokens) for tokens in token_lists) # Find the longest tokenized sequence
+ tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int) # Mask filled with zeros
+ atomwise_masks = peptide_bond_mask(smiles_list)
+
+
+ for batch_idx, atomwise_mask in enumerate(atomwise_masks):
+ token_seq = token_lists[batch_idx]
+ atom_idx = 0
+
+ for token_idx, token in enumerate(token_seq):
+ if token_idx != 0 and token_idx != len(token_seq) - 1:
+ if torch.sum(atomwise_mask[atom_idx:atom_idx+len(token)]) >= 1:
+ tokenized_masks[batch_idx][token_idx] = 1
+ atom_idx += len(token)
+
+ return tokenized_masks
+
+def extract_amino_acid_sequence(helm_string):
+ """
+ Extracts the amino acid sequence from a HELM peptide notation and outputs it as an array,
+ removing any brackets around each amino acid.
+
+ Args:
+ helm_string (str): The HELM notation string for a peptide.
+
+ Returns:
+ list: A list containing each amino acid in sequence without brackets.
+ """
+ # Use regex to find the pattern within `{}` brackets following "PEPTIDE" followed by a number
+ matches = re.findall(r'PEPTIDE\d+\{([^}]+)\}', helm_string)
+
+ if matches:
+ # Join all matched sequences and split by dots to get individual amino acids
+ amino_acid_sequence = []
+ for match in matches:
+ sequence = match.replace('[', '').replace(']', '').split('.')
+ amino_acid_sequence.extend(sequence)
+ return amino_acid_sequence
+ else:
+ return "Invalid HELM notation or no peptide sequence found."
+
+def helm_collate_fn(batch, tokenizer):
+ sequences = [item['HELM'] for item in batch]
+
+ max_len = 0
+ for sequence in sequences:
+ seq_len = len(extract_amino_acid_sequence(sequence))
+ if seq_len > max_len:
+ max_len = seq_len
+
+ tokens = tokenizer(sequences, return_tensors='pt', padding=True, truncation=True, max_length=1024)
+
+ return {
+ 'input_ids': tokens['input_ids'],
+ 'attention_mask': tokens['attention_mask']
+ }
+
+
+def collate_fn(batch, tokenizer):
+ """Standard data collator that truncates/pad sequences based on max_length"""
+ valid_sequences = []
+ valid_items = []
+
+ for item in batch:
+ try:
+ test_tokens = tokenizer([item['SMILES']], return_tensors='pt', padding=False, truncation=True, max_length=1035)
+ valid_sequences.append(item['SMILES'])
+ valid_items.append(item)
+ except Exception as e:
+ print(f"Skipping sequence due to: {str(e)}")
+ continue
+
+ #sequences = [item['SMILES'] for item in batch]
+ #max_len = max([len(seq) for seq in sequences])
+ #labels = torch.tensor([item['labels'] for item in batch], dtype=torch.float32)
+
+ tokens = tokenizer(valid_sequences, return_tensors='pt', padding=True, truncation=True, max_length=1035)
+
+ token_array = tokenizer.get_token_split(tokens['input_ids'])
+ bond_mask = peptide_token_mask(valid_sequences, token_array)
+ #attention_masks = torch.ones(tokens.size()[:2], dtype=torch.bool)
+
+ return {
+ 'input_ids': tokens['input_ids'],
+ 'attention_mask': tokens['attention_mask'],
+ 'bond_mask': bond_mask
+ }
+
+
+class CustomDataModule(pl.LightningDataModule):
+ def __init__(self, train_dataset, val_dataset, test_dataset, tokenizer, batch_size, collate_fn=collate_fn):
+ super().__init__()
+ self.train_dataset = train_dataset
+ self.val_dataset = val_dataset
+ #self.test_dataset = test_dataset
+ self.batch_size = batch_size
+ self.tokenizer = tokenizer
+ self.collate_fn = collate_fn
+
+ def train_dataloader(self):
+ return DataLoader(self.train_dataset,
+ batch_size=self.batch_size,
+ collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
+ num_workers=8,
+ pin_memory=True
+ )
+
+
+ def val_dataloader(self):
+ return DataLoader(self.val_dataset,
+ batch_size=self.batch_size,
+ collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
+ num_workers=8,
+ pin_memory=True
+ )
+
+ """def test_dataloader(self):
+ return DataLoader(self.test_dataset, batch_size=self.batch_size,
+ collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
+ num_workers=8, pin_memory=True)"""
\ No newline at end of file
diff --git a/a2d2_pep/evaluate_peptide_table.py b/a2d2_pep/evaluate_peptide_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce15303670f8a38e59e0f3055bd589f26c4ae075
--- /dev/null
+++ b/a2d2_pep/evaluate_peptide_table.py
@@ -0,0 +1,326 @@
+"""
+Evaluate a finetuned peptide model checkpoint by sampling sequences
+and computing metrics for the De Novo Peptide Generation table:
+ Validity (%), Affinity (↑), Solubility (↑), Hemolysis (↑),
+ Nonfouling (↑), Permeability (↑), Sampling Time (↓)
+"""
+
+import os
+import sys
+import argparse
+import time
+import torch
+import numpy as np
+import pandas as pd
+
+# add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
+REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.insert(0, REPO_ROOT)
+
+from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
+from lightning_modules import AnyOrderInsertionFlowModule
+from inference_quality import sample_peptides_eval
+from pep_scoring.scoring_functions import ScoringFunctions
+from pep_utils.analyzer import PeptideAnalyzer
+from pep_scoring.tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
+from finetune_quality import PeptideFinetuner
+from pep_utils.utils import str2bool, set_seed
+from tdc import Evaluator
+
+
+# Protein sequences
+PROTEINS = {
+ 'amhr': 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV',
+ 'tfr': 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF',
+ 'gfap': 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM',
+ 'glp1': 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS',
+ 'glast': 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM',
+ 'ncam': 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF',
+ 'cereblon': 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL',
+ 'ligase': 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS',
+ 'skp2': 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL',
+}
+
+
+def load_finetuned_model(checkpoint_path, pretrained_ckpt_path, device='cuda'):
+ """Load a finetuned PeptideFinetuner from a Lightning checkpoint."""
+ ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
+ hparams = ckpt.get('hyper_parameters', {})
+ args = hparams.get('args', None)
+
+ # Load pretrained base checkpoint to get config
+ base_ckpt = torch.load(pretrained_ckpt_path, map_location='cpu', weights_only=False)
+ if 'hyper_parameters' in base_ckpt:
+ config = base_ckpt['hyper_parameters']['config']
+ elif 'config' in base_ckpt:
+ config = base_ckpt['config']
+ else:
+ raise ValueError("Cannot find config in base checkpoint")
+
+ from omegaconf import OmegaConf, DictConfig
+ if not OmegaConf.is_config(config):
+ config = DictConfig(config)
+ OmegaConf.set_struct(config, False)
+
+ config.training.use_adaptive_schedule = getattr(args, 'use_adaptive_schedule', True)
+ config.training.schedule_hidden_dim = getattr(args, 'schedule_hidden_dim', 256)
+ config.training.schedule_num_layers = getattr(args, 'schedule_num_layers', 2)
+ config.training.schedule_loss_weight = getattr(args, 'schedule_loss_weight', 0.1)
+ config.training.freeze_base_model = getattr(args, 'freeze_base_model', False)
+ config.training.schedule_warmup_epochs = getattr(args, 'schedule_warmup_epochs', 0)
+ OmegaConf.set_struct(config, True)
+
+ disable_planner = getattr(args, 'disable_planner', False)
+
+ policy_model = AnyOrderInsertionFlowModuleFT(
+ config=config,
+ args=args,
+ pretrained_checkpoint=pretrained_ckpt_path,
+ insertion_planner=not disable_planner,
+ )
+
+ # Load finetuned weights
+ state_dict = ckpt['state_dict']
+ policy_state = {}
+ for k, v in state_dict.items():
+ if k.startswith('policy_model.'):
+ policy_state[k[len('policy_model.'):]] = v
+ policy_model.load_state_dict(policy_state, strict=False)
+ policy_model = policy_model.to(device)
+ policy_model.eval()
+
+ return policy_model, args, config
+
+
+@torch.no_grad()
+def evaluate_checkpoint(policy_model, tokenizer, reward_model, analyzer,
+ num_samples=1000, batch_size=50, max_length=512,
+ total_num_steps=256, quality_mode="both", num_remasking=3,
+ quality_threshold=0.5, unmask_quality_threshold=None, device='cuda'):
+ """
+ Sample `num_samples` peptides and compute all table metrics.
+ Returns a dict with: validity, affinity, sol, hemo, nf, permeability, sampling_time
+ """
+ all_affinity = []
+ all_sol = []
+ all_hemo = []
+ all_nf = []
+ all_permeability = []
+ all_valid_seqs = []
+ total_valid = 0
+ total_generated = 0
+ total_time = 0.0
+
+ num_batches = (num_samples + batch_size - 1) // batch_size
+ remaining = num_samples
+
+ for b in range(num_batches):
+ bs = min(batch_size, remaining)
+ remaining -= bs
+
+ t_start = time.time()
+ result = sample_peptides_eval(
+ model=policy_model,
+ reward_model=reward_model,
+ analyzer=analyzer,
+ tokenizer=tokenizer,
+ steps=total_num_steps,
+ mask=policy_model.interpolant.mask_token,
+ pad=policy_model.interpolant.pad_token,
+ batch_size=bs,
+ max_length=max_length,
+ quality_mode=quality_mode,
+ num_remasking=num_remasking,
+ quality_threshold=quality_threshold,
+ unmask_quality_threshold=unmask_quality_threshold,
+ return_valid=True,
+ )
+ t_end = time.time()
+
+ # Unpack: validSequences, affinity, sol, hemo, nf, permeability, valid_fraction
+ valid_seqs, affinity, sol, hemo, nf, permeability, valid_fraction = result
+
+ batch_valid = len(valid_seqs)
+ total_valid += batch_valid
+ total_generated += bs
+ total_time += (t_end - t_start)
+ all_valid_seqs.extend(valid_seqs)
+
+ if isinstance(affinity, (list, np.ndarray)) and len(affinity) > 0:
+ all_affinity.extend(affinity if isinstance(affinity, list) else affinity.tolist())
+ all_sol.extend(sol if isinstance(sol, list) else sol.tolist())
+ all_hemo.extend(hemo if isinstance(hemo, list) else hemo.tolist())
+ all_nf.extend(nf if isinstance(nf, list) else nf.tolist())
+ all_permeability.extend(permeability if isinstance(permeability, list) else permeability.tolist())
+
+ print(f" Batch {b+1}/{num_batches}: {batch_valid}/{bs} valid, "
+ f"time={t_end - t_start:.1f}s")
+
+ validity = total_valid / total_generated * 100.0 if total_generated > 0 else 0.0
+
+ # Uniqueness (% of valid sequences that are unique) and
+ # Diversity (1 - mean pairwise Tanimoto on Morgan FPs of unique sequences).
+ # Matches the convention used in evaluate_mol_table.py.
+ all_unique = list(set(all_valid_seqs))
+ num_unique = len(all_unique)
+ uniqueness = num_unique / total_valid * 100.0 if total_valid > 0 else 0.0
+ if num_unique > 1:
+ diversity = Evaluator('diversity')(all_unique)
+ else:
+ diversity = 0.0
+
+ metrics = {
+ 'Validity (%)': validity,
+ 'Uniqueness (%)': uniqueness,
+ 'Diversity': diversity,
+ 'Affinity': np.mean(all_affinity) if all_affinity else 0.0,
+ 'Affinity Std': np.std(all_affinity) if all_affinity else 0.0,
+ 'Solubility': np.mean(all_sol) if all_sol else 0.0,
+ 'Solubility Std': np.std(all_sol) if all_sol else 0.0,
+ 'Hemolysis': np.mean(all_hemo) if all_hemo else 0.0,
+ 'Hemolysis Std': np.std(all_hemo) if all_hemo else 0.0,
+ 'Nonfouling': np.mean(all_nf) if all_nf else 0.0,
+ 'Nonfouling Std': np.std(all_nf) if all_nf else 0.0,
+ 'Permeability': np.mean(all_permeability) if all_permeability else 0.0,
+ 'Permeability Std': np.std(all_permeability) if all_permeability else 0.0,
+ 'Sampling Time (s)': total_time,
+ 'Num Generated': total_generated,
+ 'Num Valid': total_valid,
+ 'Num Unique': num_unique,
+ }
+
+ return metrics
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Evaluate a finetuned peptide checkpoint")
+ parser.add_argument('--checkpoint_path', type=str, required=True,
+ help='Path to the finetuned Lightning checkpoint (e.g., last.ckpt)')
+ parser.add_argument('--pretrained_ckpt', type=str,
+ default=os.path.join(REPO_ROOT, 'pretrained', 'anylength_pep.ckpt'),
+ help='Path to the pretrained base model checkpoint')
+ parser.add_argument('--num_samples', type=int, default=500,
+ help='Number of peptides to sample')
+ parser.add_argument('--batch_size', type=int, default=50,
+ help='Batch size for sampling')
+ parser.add_argument('--max_length', type=int, default=512)
+ parser.add_argument('--total_num_steps', type=int, default=256)
+ parser.add_argument('--num_remasking', type=int, default=3)
+ parser.add_argument('--quality_threshold', type=float, default=0.5,
+ help='Threshold for insertion quality filtering during sampling')
+ parser.add_argument('--unmask_quality_threshold', type=float, default=None,
+ help='If set, gate unmasking/remasking by confidence: remask '
+ 'ALL clean tokens whose unmasking confidence is below this '
+ 'threshold, regardless of the schedule budget. If unset '
+ '(default), remasking is purely schedule-driven (count-based).')
+ parser.add_argument('--prot_name', type=str, default='glast',
+ help='Target protein name (must be one of: ' + ', '.join(PROTEINS.keys()) + ')')
+ parser.add_argument('--prot_seq', type=str, default=None,
+ help='Custom protein sequence (overrides --prot_name)')
+ parser.add_argument('--disable_planner', action='store_true',
+ help='If set, disable remasking during evaluation')
+ parser.add_argument('--disable_insertion_planner', action='store_true',
+ help='If set, disable insertion quality filtering during evaluation')
+ parser.add_argument('--disable_unmasking_planner', action='store_true',
+ help='If set, disable unmasking confidence planner during evaluation')
+ parser.add_argument('--output_dir', type=str, default=None,
+ help='Directory to save results CSV. Defaults to checkpoint directory.')
+ parser.add_argument('--device', type=str, default='cuda:0')
+ parser.add_argument('--seed', type=int, default=42)
+ args = parser.parse_args()
+
+ set_seed(args.seed, use_cuda=True)
+ device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
+
+ # Map flags to quality_mode
+ if args.disable_planner:
+ quality_mode = "none"
+ elif args.disable_insertion_planner and args.disable_unmasking_planner:
+ quality_mode = "none"
+ elif args.disable_insertion_planner:
+ quality_mode = "unmasking_only"
+ elif args.disable_unmasking_planner:
+ quality_mode = "insertion_only"
+ else:
+ quality_mode = "both"
+
+ print(f"Loading checkpoint: {args.checkpoint_path}")
+ print(f"Pretrained base: {args.pretrained_ckpt}")
+ print(f"Quality mode: {quality_mode}")
+
+ policy_model, train_args, config = load_finetuned_model(
+ args.checkpoint_path, args.pretrained_ckpt, device=device
+ )
+
+ # Setup tokenizer, reward model, analyzer
+ tokenizer = SMILES_SPE_Tokenizer(
+ os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer', 'new_vocab.txt'),
+ os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer', 'new_splits.txt')
+ )
+
+ if args.prot_seq is not None:
+ prot = args.prot_seq
+ prot_name = args.prot_name
+ else:
+ prot_name = args.prot_name
+ if prot_name not in PROTEINS:
+ raise ValueError(f"Unknown protein: {prot_name}. Choose from: {list(PROTEINS.keys())}")
+ prot = PROTEINS[prot_name]
+
+ score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
+ reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device=device)
+ analyzer = PeptideAnalyzer()
+
+ print(f"\nSampling {args.num_samples} peptides (quality_mode={quality_mode}, target={prot_name})...")
+
+ metrics = evaluate_checkpoint(
+ policy_model=policy_model,
+ tokenizer=tokenizer,
+ reward_model=reward_model,
+ analyzer=analyzer,
+ num_samples=args.num_samples,
+ batch_size=args.batch_size,
+ max_length=args.max_length,
+ total_num_steps=args.total_num_steps,
+ quality_mode=quality_mode,
+ num_remasking=args.num_remasking,
+ quality_threshold=args.quality_threshold,
+ unmask_quality_threshold=args.unmask_quality_threshold,
+ device=device,
+ )
+
+ # Print summary table
+ print("\n" + "=" * 60)
+ print(" De Novo Peptide Generation Results")
+ print("=" * 60)
+ for k, v in metrics.items():
+ if isinstance(v, float):
+ print(f" {k:<30s}: {v:.4f}")
+ else:
+ print(f" {k:<30s}: {v}")
+ print("=" * 60)
+
+ # Save results
+ output_dir = args.output_dir or os.path.dirname(args.checkpoint_path)
+ os.makedirs(output_dir, exist_ok=True)
+
+ if args.disable_planner:
+ tag = "no_planner"
+ elif args.disable_insertion_planner:
+ tag = "no_insertion_planner"
+ elif args.disable_unmasking_planner:
+ tag = "no_unmasking_planner"
+ else:
+ tag = "with_planner"
+ if args.unmask_quality_threshold is not None:
+ tag += f"_ut{args.unmask_quality_threshold:g}"
+ # Record the sweep parameter in the saved row for traceability.
+ metrics['unmask_quality_threshold'] = args.unmask_quality_threshold
+ metrics['quality_threshold'] = args.quality_threshold
+ metrics_path = os.path.join(output_dir, f'eval_metrics_{tag}_{prot_name}.csv')
+ pd.DataFrame([metrics]).to_csv(metrics_path, index=False)
+ print(f"Metrics saved to: {metrics_path}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/a2d2_pep/finetune_quality.py b/a2d2_pep/finetune_quality.py
new file mode 100644
index 0000000000000000000000000000000000000000..89673c782a1e0db1c07653d56e1096ab73cbced5
--- /dev/null
+++ b/a2d2_pep/finetune_quality.py
@@ -0,0 +1,892 @@
+# Distributed Data Parallel (DDP) finetuning for peptide generation using PyTorch Lightning
+import argparse
+import math
+from datetime import datetime
+import numpy as np
+import torch
+import pytorch_lightning as pl
+from pytorch_lightning.strategies import DDPStrategy
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.loggers import WandbLogger
+import wandb
+import os
+import sys
+from tqdm import tqdm
+import pandas as pd
+
+# add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from inference_quality import sample_peptides_buffer, sample_peptides_eval
+from pep_utils.analyzer import PeptideAnalyzer
+from pep_utils.utils import str2bool, set_seed
+from pep_scoring.scoring_functions import ScoringFunctions
+from pep_scoring.tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
+from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
+from lightning_modules import AnyOrderInsertionFlowModule
+from tdc import Evaluator
+
+# Repository root (two levels up from this file: A2D2/a2d2_pep/finetune_quality.py)
+REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+class PeptideFinetuner(pl.LightningModule):
+ """Lightning module for distributed peptide finetuning."""
+
+ def __init__(
+ self,
+ args,
+ policy_model,
+ reward_model,
+ tokenizer,
+ pretrained=None,
+ mcts=None,
+ filename=None,
+ prot_name=None,
+ eps=1e-5
+ ):
+ super().__init__()
+ self.args = args
+ self.policy_model = policy_model
+ self.reward_model = reward_model
+ self.tokenizer = tokenizer
+ self.pretrained = pretrained
+ self.mcts = mcts
+ self.filename = filename
+ self.prot_name = prot_name
+ self.eps = eps
+
+ # Length cutoff is tunable from the CLI: --min_peptide_bonds N enforces
+ # >=N peptide bonds (filters degenerate short reward-hacked molecules);
+ # --min_peptide_bonds 0 disables the cutoff.
+ min_bonds = getattr(args, 'min_peptide_bonds', 4)
+ self.analyzer = PeptideAnalyzer(
+ min_peptide_bonds=max(0, min_bonds),
+ enforce_min_peptide_bonds=min_bonds > 0,
+ )
+
+ # Save hyperparameters
+ self.save_hyperparameters(ignore=['policy_model', 'reward_model', 'tokenizer', 'pretrained', 'mcts'])
+
+ # Buffer for sequences
+ self.x_saved = None
+ self.log_rnd_saved = None
+ self.final_rewards_saved = None
+
+ # Logs
+ self.valid_fraction_log = []
+ self.uniqueness_log = []
+ self.diversity_log = []
+ self.affinity_log = []
+ self.sol_log = []
+ self.hemo_log = []
+ self.nf_log = []
+ self.permeability_log = []
+ self._diversity_evaluator = Evaluator('diversity')
+
+ # Alternating training between policy and planner
+ self.train_policy = True # Start by training policy
+ self.alternation_frequency = getattr(args, 'alternation_frequency', 1) # Alternate every N epochs
+
+ def freeze_policy_model(self):
+ """Freeze policy model parameters (but not planner)."""
+ for name, param in self.policy_model.named_parameters():
+ if not name.startswith('planner.'):
+ param.requires_grad = False
+
+ def unfreeze_policy_model(self):
+ """Unfreeze policy model parameters (but not planner)."""
+ for name, param in self.policy_model.named_parameters():
+ if not name.startswith('planner.'):
+ param.requires_grad = True
+
+ def freeze_planner_model(self):
+ """Freeze planner parameters."""
+ if hasattr(self.policy_model, 'planner'):
+ for param in self.policy_model.planner.parameters():
+ param.requires_grad = False
+
+ def unfreeze_planner_model(self):
+ """Unfreeze planner parameters."""
+ if hasattr(self.policy_model, 'planner'):
+ for param in self.policy_model.planner.parameters():
+ param.requires_grad = True
+
+ def configure_optimizers(self):
+ # Separate parameter groups for policy backbone vs planner heads
+ planner_lr = getattr(self.args, 'planner_learning_rate', self.args.learning_rate)
+ planner_params = []
+ policy_params = []
+ for name, param in self.policy_model.named_parameters():
+ if name.startswith('planner.'):
+ planner_params.append(param)
+ else:
+ policy_params.append(param)
+
+ param_groups = [
+ {'params': policy_params, 'lr': self.args.learning_rate},
+ {'params': planner_params, 'lr': planner_lr},
+ ]
+ optimizer = torch.optim.AdamW(param_groups)
+ return optimizer
+
+ def _get_quality_mode(self):
+ """Map ablation flags + warmup state to quality_mode string."""
+ if self.args.disable_planner:
+ return "none"
+ if self.current_epoch < self.args.schedule_warmup_epochs:
+ return "none"
+ di = getattr(self.args, 'disable_insertion_planner', False)
+ du = getattr(self.args, 'disable_unmasking_planner', False)
+ if di and du:
+ return "none"
+ if di:
+ return "unmasking_only"
+ if du:
+ return "insertion_only"
+ return "both"
+
+ def on_save_checkpoint(self, checkpoint):
+ """
+ Save additional metadata to make loading easier.
+ Saves the config directly in the checkpoint so loading doesn't need to follow references.
+ """
+ # Save the config from the policy model directly in the checkpoint
+ if hasattr(self.policy_model, 'config'):
+ checkpoint['config'] = self.policy_model.config
+ print(f"Saved config to checkpoint for easier loading")
+
+ # Save EMA params if they exist in the policy model
+ if hasattr(self.policy_model, 'ema_params') and self.policy_model.ema_params:
+ checkpoint['ema_params'] = self.policy_model.ema_params
+ print(f"Saved EMA params to checkpoint")
+
+ # Save planner state if it exists
+ if hasattr(self.policy_model, 'planner'):
+ checkpoint['planner_state'] = self.policy_model.planner.state_dict()
+ print(f"Saved planner state to checkpoint")
+
+ def on_train_epoch_start(self):
+ """Called at the start of each training epoch."""
+ # If disable_planner mode, only train policy (no alternation)
+ if self.args.disable_planner:
+ self.train_policy = True
+ self.unfreeze_policy_model()
+ self.freeze_planner_model()
+ if self.global_rank == 0 and self.current_epoch == 0:
+ print(f"[FINETUNE_QUALITY] Training ONLY policy model (planner frozen, no remasking)")
+ elif getattr(self.args, 'joint_training', False):
+ # Joint mode: train policy + planner together every step (no alternation)
+ self.train_policy = True # marker; training_step adds planner loss when joint_training is set
+ self.unfreeze_policy_model()
+ self.unfreeze_planner_model()
+ if self.global_rank == 0 and self.current_epoch == 0:
+ print(f"[FINETUNE_QUALITY] JOINT TRAINING: policy + planner trained together (no alternation)")
+ else:
+ # Alternate between training policy and planner from epoch 0
+ # Determine which model to train this epoch
+ cycle_position = (self.current_epoch // self.alternation_frequency) % 2
+ self.train_policy = (cycle_position == 0)
+
+ if self.train_policy:
+ # Train policy, freeze planner
+ self.unfreeze_policy_model()
+ self.freeze_planner_model()
+ if self.global_rank == 0:
+ print(f"[ALTERNATION] Epoch {self.current_epoch}: Training POLICY model (planner frozen)")
+ else:
+ # Train planner, freeze policy
+ self.freeze_policy_model()
+ self.unfreeze_planner_model()
+ if self.global_rank == 0:
+ print(f"[ALTERNATION] Epoch {self.current_epoch}: Training PLANNER model (policy frozen)")
+
+ # Resample buffer if needed
+ if self.x_saved is None or self.current_epoch % self.args.resample_every_n_step == 0:
+ self._generate_buffer()
+ # Synchronize all ranks after buffer generation to prevent NCCL timeout
+ if self.trainer and self.trainer.world_size > 1:
+ torch.distributed.barrier()
+
+ def _generate_buffer(self):
+ """Generate buffer of sequences for training - all ranks generate in parallel.
+
+ When pool_size > 0, maintains a persistent pool and refreshes a fraction
+ each time instead of regenerating the entire buffer from scratch. This
+ preserves diversity/uniqueness across training by avoiding wholesale
+ replacement with samples from an increasingly mode-collapsed policy.
+ """
+ world_size = self.trainer.world_size if self.trainer else 1
+ rank = self.global_rank if self.trainer else 0
+
+ pool_size = getattr(self.args, 'pool_size', 0)
+ is_pool = pool_size > 0
+ is_init = self.x_saved is None
+
+ # Determine how many sequences to sample this call
+ if is_pool:
+ refresh_frac = getattr(self.args, 'pool_refresh_fraction', 0.2)
+ if is_init:
+ samples_per_gpu = pool_size
+ else:
+ samples_per_gpu = max(1, int(pool_size * refresh_frac))
+ if rank == 0:
+ if is_init:
+ print(f"\n[POOL] Initializing pool with {pool_size} sequences at epoch {self.current_epoch}")
+ else:
+ print(f"\n[POOL] Refreshing {samples_per_gpu}/{pool_size} sequences ({refresh_frac*100:.0f}%) at epoch {self.current_epoch}")
+ else:
+ samples_per_gpu = self.args.buffer_size // world_size
+ if rank == 0:
+ samples_per_gpu += self.args.buffer_size % world_size
+
+ accumulated_x = []
+ accumulated_log_rnd = []
+ accumulated_rewards = []
+ total_accumulated = 0
+
+ if rank == 0:
+ print(f"\n[BUFFER] Starting buffer generation at epoch {self.current_epoch}")
+ print(f"[BUFFER] GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
+ print(f"[BUFFER] GPU memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
+ if not is_pool:
+ print(f"[BUFFER] Each of {world_size} ranks will generate {samples_per_gpu} samples")
+
+ max_attempts = getattr(self.args, 'max_buffer_attempts', 100) # cap wasted GPU / infinite loop
+ starvation_patience = getattr(self.args, 'buffer_starvation_patience', 10)
+ attempts = 0
+
+ import time
+ while total_accumulated < samples_per_gpu and attempts < max_attempts:
+ attempts += 1
+ if rank == 0:
+ print(f"[BUFFER] rank={rank} starting sampling attempt {attempts} at {time.strftime('%H:%M:%S')}")
+
+ start_time = time.time()
+
+ # new elbo loss
+ if self.args.elbo_rnd:
+ x_final, _, final_rewards, trace = \
+ sample_peptides_buffer(
+ self.policy_model,
+ self.reward_model, self.analyzer,
+ self.tokenizer,
+ steps=self.args.total_num_steps,
+ mask=self.policy_model.interpolant.mask_token,
+ pad=self.policy_model.interpolant.pad_token,
+ batch_size=self.args.batch_size,
+ max_length=self.args.max_length,
+ # Buffer generation never uses the quality heads (planner):
+ # the backbone must train on raw policy samples so that a
+ # poorly-trained planner can't corrupt the backbone's data.
+ quality_mode="none",
+ compute_rnd=False,
+ alpha=self.args.alpha,
+ num_remasking=self.args.num_remasking,
+ min_length=self.args.min_length,
+ )
+ if x_final.shape[0] > 0:
+ with torch.no_grad():
+ noised = self.policy_model.prepare_noised_sample(
+ x_final, num_samples=self.args.elbo_rnd_num_samples)
+ policy_loss = self.policy_model.compute_loss_from_noised(noised)
+ pretrained_loss = self.pretrained.compute_loss_from_noised(noised)
+ log_rnd = (pretrained_loss - policy_loss) + (final_rewards / self.args.alpha)
+ else:
+ log_rnd = torch.empty((0,), dtype=torch.float32, device=x_final.device)
+ else:
+ x_final, log_rnd, final_rewards, trace = \
+ sample_peptides_buffer(
+ self.policy_model,
+ self.reward_model, self.analyzer,
+ self.tokenizer,
+ steps=self.args.total_num_steps,
+ mask=self.policy_model.interpolant.mask_token,
+ pad=self.policy_model.interpolant.pad_token,
+ batch_size=self.args.batch_size,
+ max_length=self.args.max_length,
+ # Buffer generation never uses the quality heads (planner):
+ # the backbone must train on raw policy samples so that a
+ # poorly-trained planner can't corrupt the backbone's data.
+ quality_mode="none",
+ compute_rnd=True,
+ pretrained=self.pretrained,
+ alpha=self.args.alpha,
+ num_remasking=self.args.num_remasking,
+ min_length=self.args.min_length,
+ )
+ elapsed = time.time() - start_time
+ if rank == 0:
+ print(f"[BUFFER] rank={rank} sampling took {elapsed:.1f}s")
+
+ n_valid = x_final.shape[0]
+ if n_valid > 0:
+ accumulated_x.append(x_final)
+ accumulated_log_rnd.append(log_rnd)
+ accumulated_rewards.append(final_rewards)
+ total_accumulated += n_valid
+
+ if rank == 0:
+ print(f"[BUFFER] rank={rank} epoch={self.current_epoch} quality_mode=none (heads disabled for buffer gen) accumulated={total_accumulated} / {samples_per_gpu} (batch yielded {n_valid} valid) attempt={attempts}")
+
+ # Starvation guard: if nothing valid comes through (e.g. the length
+ # cutoff is too aggressive for a collapsed policy), stop grinding GPU
+ # hours and fail fast with an actionable message.
+ if attempts >= starvation_patience and total_accumulated == 0:
+ if rank == 0:
+ print(f"[BUFFER STARVATION] 0 valid samples after {attempts} attempts "
+ f"(min_peptide_bonds={getattr(self.args, 'min_peptide_bonds', 4)}). "
+ f"Aborting refill early — lower --min_peptide_bonds or check the policy.")
+ break
+
+ if total_accumulated == 0:
+ raise RuntimeError(f"[BUFFER ERROR] Rank {rank}: No valid sequences generated after {attempts} attempts. Check sampling function and reward model.")
+
+ if total_accumulated < samples_per_gpu:
+ print(f"[BUFFER WARNING] Rank {rank}: Only generated {total_accumulated}/{samples_per_gpu} sequences after {attempts} attempts")
+
+ new_x = torch.cat(accumulated_x, dim=0)[:samples_per_gpu]
+ new_log_rnd = torch.cat(accumulated_log_rnd, dim=0)[:samples_per_gpu]
+ new_rewards = torch.cat(accumulated_rewards, dim=0)[:samples_per_gpu]
+
+ del accumulated_x, accumulated_log_rnd, accumulated_rewards
+ torch.cuda.empty_cache()
+
+ # Pool mode (after init): replace a random subset of the existing pool.
+ # Classic mode / pool init: overwrite the buffer.
+ if is_pool and not is_init:
+ actual_new = min(new_x.shape[0], self.x_saved.shape[0])
+ indices = torch.randperm(self.x_saved.shape[0], device=self.x_saved.device)[:actual_new]
+ self.x_saved[indices] = new_x[:actual_new]
+ self.log_rnd_saved[indices] = new_log_rnd[:actual_new]
+ self.final_rewards_saved[indices] = new_rewards[:actual_new]
+ if rank == 0:
+ print(f"[POOL] Replaced {actual_new}/{self.x_saved.shape[0]} sequences, reward mean={self.final_rewards_saved.mean():.4f}")
+ else:
+ self.x_saved = new_x
+ self.log_rnd_saved = new_log_rnd
+ self.final_rewards_saved = new_rewards
+
+ # Sanity check: median length (non-pad tokens) of buffered peptides.
+ if rank == 0:
+ pad = self.policy_model.interpolant.pad_token
+ token_lens = (self.x_saved != pad).sum(dim=1)
+ print(f"[BUFFER] peptide token length: median={token_lens.median().item()} "
+ f"min={token_lens.min().item()} max={token_lens.max().item()} "
+ f"(n={token_lens.shape[0]})")
+
+ def training_step(self, batch, batch_idx):
+ """Training step - batch is ignored, we use saved buffer."""
+ # Use mini-batch sampling from buffer to avoid OOM
+ buffer_size = self.x_saved.shape[0]
+ mini_batch_size = getattr(self.args, 'training_mini_batch_size', 6)
+
+ # Randomly sample mini_batch_size sequences from buffer
+ if buffer_size > mini_batch_size:
+ indices = torch.randperm(buffer_size, device=self.x_saved.device)[:mini_batch_size]
+ x_final = self.x_saved[indices]
+ log_rnd = self.log_rnd_saved[indices]
+ else:
+ # If buffer is smaller than mini_batch_size, use all
+ x_final = self.x_saved
+ log_rnd = self.log_rnd_saved
+
+ joint = getattr(self.args, 'joint_training', False)
+ policy_loss = None
+ planner_loss = None
+
+ if self.train_policy:
+ # Train policy with WDCE loss
+ policy_loss = self.policy_model.loss_wdce_flexible(
+ log_rnd,
+ x_final,
+ num_replicates=self.args.wdce_num_replicates,
+ centering=self.args.centering,
+ centering_strength=self.args.centering_strength
+ )
+ self.log('train/policy_loss', policy_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+
+ if (not self.train_policy) or joint:
+ # Train planner with appropriate loss based on ablation flags
+ if self.args.disable_insertion_planner:
+ # Ablation: only train unmasking/remasking planner (no insertion head)
+ planner_loss = self.policy_model.loss_planner_flexible(
+ log_rnd,
+ x_final,
+ num_replicates=self.args.wdce_num_replicates,
+ centering=self.args.centering,
+ centering_strength=self.args.centering_strength
+ )
+ self.log('train/planner_unmask_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ self.log('train/planner_insert_loss', 0.0, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ self.log('train/planner_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ elif self.args.disable_unmasking_planner:
+ # Ablation: only train insertion planner (no remasking head)
+ unmask_loss, insert_loss, _ = self.policy_model.loss_insert_planner_flexible(
+ log_rnd,
+ x_final,
+ num_replicates=self.args.wdce_num_replicates,
+ centering=self.args.centering,
+ centering_strength=self.args.centering_strength
+ )
+ # Zero out the unmasking component - only backprop insertion loss
+ planner_loss = insert_loss
+ self.log('train/planner_unmask_loss', 0.0, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ self.log('train/planner_insert_loss', insert_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ self.log('train/planner_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ else:
+ # Full planner: train both remasking + insertion
+ unmask_loss, insert_loss, planner_loss = self.policy_model.loss_insert_planner_flexible(
+ log_rnd,
+ x_final,
+ num_replicates=self.args.wdce_num_replicates,
+ centering=self.args.centering,
+ centering_strength=self.args.centering_strength
+ )
+ self.log('train/planner_unmask_loss', unmask_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ self.log('train/planner_insert_loss', insert_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ self.log('train/planner_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+
+
+ # Combine losses depending on mode
+ if joint:
+ loss = policy_loss + planner_loss
+ mode_value = 0.5
+ elif self.train_policy:
+ loss = policy_loss
+ mode_value = 0.0
+ else:
+ loss = planner_loss
+ mode_value = 1.0
+
+ # Log overall loss and mode
+ self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+ self.log('train/mode', mode_value, prog_bar=True, sync_dist=True)
+
+ return loss
+
+ def on_train_epoch_end(self):
+ """Called at the end of each training epoch - only rank 0 evaluates."""
+ # Only evaluate every N epochs to save time
+ eval_frequency = getattr(self.args, 'eval_every_n_epochs', 5)
+ is_last_epoch = (self.trainer and self.current_epoch == self.trainer.max_epochs - 1)
+ if self.global_rank == 0 and (self.current_epoch % eval_frequency == 0 or is_last_epoch):
+ # Sample eval batch with updated policy
+ valid_seqs, affinity, sol, hemo, nf, permeability, valid_fraction = \
+ sample_peptides_eval(
+ self.policy_model, self.reward_model, self.analyzer,
+ self.tokenizer,
+ steps=self.args.total_num_steps,
+ mask=self.policy_model.interpolant.mask_token,
+ pad=self.policy_model.interpolant.pad_token,
+ batch_size=50,
+ max_length=self.args.max_length,
+ quality_mode=self._get_quality_mode(),
+ num_remasking=self.args.num_remasking,
+ return_valid=True,
+ )
+
+ # Uniqueness (% of valid that are unique) and Diversity
+ # (1 - mean pairwise Tanimoto on Morgan FPs of unique sequences),
+ # matching evaluate_peptide_table.py / evaluate_mol_table.py.
+ num_valid = len(valid_seqs)
+ unique_seqs = list(set(valid_seqs))
+ num_unique = len(unique_seqs)
+ uniqueness = num_unique / num_valid * 100.0 if num_valid > 0 else 0.0
+ diversity = self._diversity_evaluator(unique_seqs) if num_unique > 1 else 0.0
+
+ # Append to logs
+ self.affinity_log.append(affinity)
+ self.sol_log.append(sol)
+ self.hemo_log.append(hemo)
+ self.nf_log.append(nf)
+ self.permeability_log.append(permeability)
+ self.valid_fraction_log.append(valid_fraction)
+ self.uniqueness_log.append(uniqueness)
+ self.diversity_log.append(diversity)
+
+ # Compute reward stats
+ mean_reward = self.final_rewards_saved.mean().item()
+ min_reward = self.final_rewards_saved.min().item()
+ max_reward = self.final_rewards_saved.max().item()
+ median_reward = self.final_rewards_saved.median().item()
+
+ # Log metrics
+ self.log_dict({
+ "eval/affinity": np.mean(affinity),
+ "eval/sol": np.mean(sol),
+ "eval/hemo": np.mean(hemo),
+ "eval/nf": np.mean(nf),
+ "eval/permeability": np.mean(permeability),
+ "eval/valid_fraction": valid_fraction,
+ "eval/uniqueness": uniqueness,
+ "eval/diversity": diversity,
+ "eval/mean_reward_search": mean_reward,
+ "eval/min_reward_search": min_reward,
+ "eval/max_reward_search": max_reward,
+ "eval/median_reward_search": median_reward
+ })
+
+ print(f"epoch {self.current_epoch} | affinity {np.mean(affinity):.4f} | "
+ f"sol {np.mean(sol):.4f} | hemo {np.mean(hemo):.4f} | "
+ f"nf {np.mean(nf):.4f} | permeability {np.mean(permeability):.4f} | "
+ f"valid {valid_fraction:.4f} | uniq {uniqueness:.2f}% | div {diversity:.4f}")
+
+ def on_fit_end(self):
+ """Called at the end of training - save results."""
+ if self.global_rank == 0:
+ # Save logs and plot
+ base_path = self.args.base_path
+ plot_path = f'{base_path}/results/{self.args.run_name}'
+ os.makedirs(plot_path, exist_ok=True)
+
+ output_log_path = f'{plot_path}/log_{self.filename}.csv'
+ save_logs_to_file(self.valid_fraction_log, self.affinity_log,
+ self.sol_log, self.hemo_log, self.nf_log,
+ self.permeability_log, output_log_path,
+ uniqueness_log=self.uniqueness_log,
+ diversity_log=self.diversity_log)
+
+ # Final generation
+ x_eval, affinity, sol, hemo, nf, permeability, valid_fraction, df = \
+ sample_peptides_eval(
+ self.policy_model, self.reward_model, self.analyzer,
+ self.tokenizer,
+ steps=self.args.total_num_steps,
+ mask=self.policy_model.interpolant.mask_token,
+ pad=self.policy_model.interpolant.pad_token,
+ batch_size=50,
+ max_length=self.args.max_length,
+ quality_mode=self._get_quality_mode(),
+ num_remasking=self.args.num_remasking,
+ dataframe=True,
+ )
+ df.to_csv(f'{plot_path}/{self.prot_name}_generation_results.csv', index=False)
+
+
+def save_logs_to_file(valid_fraction_log, affinity_log,
+ sol_log, hemo_log, nf_log,
+ permeability_log, output_path,
+ uniqueness_log=None, diversity_log=None):
+ """
+ Saves the logs to a CSV file.
+ """
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+ log_data = {
+ "Iteration": list(range(1, len(valid_fraction_log) + 1)),
+ "Valid Fraction": valid_fraction_log,
+ "Binding Affinity": affinity_log,
+ "Solubility": sol_log,
+ "Hemolysis": hemo_log,
+ "Nonfouling": nf_log,
+ "Permeability": permeability_log,
+ }
+ if uniqueness_log is not None:
+ log_data["Uniqueness (%)"] = uniqueness_log
+ if diversity_log is not None:
+ log_data["Diversity"] = diversity_log
+
+ df = pd.DataFrame(log_data)
+ df.to_csv(output_path, index=False)
+
+
+class DummyDataset(torch.utils.data.Dataset):
+ """Dummy dataset for Lightning trainer (we use buffer instead)."""
+ def __init__(self, size=10):
+ self.size = size
+
+ def __len__(self):
+ return self.size
+
+ def __getitem__(self, idx):
+ return torch.zeros(1) # Dummy data
+
+
+def main():
+ """Main entry point for distributed training."""
+ # Disable DDP optimizer for higher-order ops like flex_attention
+ import torch._dynamo
+ torch._dynamo.config.optimize_ddp = False
+
+ argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ argparser.add_argument('--base_path', type=str, default=REPO_ROOT)
+ argparser.add_argument('--learning_rate', type=float, default=1e-4)
+ argparser.add_argument('--num_epochs', type=int, default=100)
+ argparser.add_argument('--num_accum_steps', type=int, default=4)
+ argparser.add_argument('--truncate_steps', type=int, default=50)
+ argparser.add_argument("--truncate_kl", type=str2bool, default=False)
+ argparser.add_argument('--gumbel_temp', type=float, default=1.0)
+ argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
+ argparser.add_argument('--batch_size', type=int, default=50)
+ argparser.add_argument('--name', type=str, default='debug')
+ argparser.add_argument('--total_num_steps', type=int, default=128)
+ argparser.add_argument('--copy_flag_temp', type=float, default=None)
+ argparser.add_argument('--save_every_n_epochs', type=int, default=10)
+ argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
+ argparser.add_argument("--seed", type=int, default=0)
+ # new
+ argparser.add_argument('--run_name', type=str, default='peptides')
+ argparser.add_argument("--save_path_dir", default=os.path.join(REPO_ROOT, "checkpoints", "finetune_peptides"), type=str)
+ # mcts
+ argparser.add_argument('--num_sequences', type=int, default=10)
+ argparser.add_argument('--max_length', type=int, default=1024)
+ argparser.add_argument('--min_length', type=int, default=0,
+ help='Minimum sequence length (in SMILES SPE tokens). '
+ 'Samples shorter than this are dropped from the buffer. 0 disables the filter.')
+ argparser.add_argument('--num_children', type=int, default=50)
+ argparser.add_argument('--num_iter', type=int, default=30)
+ argparser.add_argument('--seq_length', type=int, default=1024)
+ argparser.add_argument('--time_conditioning', action='store_true', default=False)
+ argparser.add_argument('--mcts_sampling', type=int, default=0) # for batched categorical sampling: '0' means gumbel noise
+ argparser.add_argument('--buffer_size', type=int, default=100)
+ argparser.add_argument('--wdce_num_replicates', type=int, default=16)
+ argparser.add_argument('--noise_removal', action='store_true', default=False)
+ argparser.add_argument('--grad_clip', action='store_true', default=False)
+ argparser.add_argument('--resample_every_n_step', type=int, default=10)
+ argparser.add_argument('--exploration', type=float, default=0.1)
+ argparser.add_argument('--reset_every_n_step', type=int, default=100)
+ argparser.add_argument('--alpha', type=float, default=0.01)
+ argparser.add_argument('--scalarization', type=str, default='sum')
+ argparser.add_argument('--no_mcts', action='store_true', default=False)
+ argparser.add_argument("--centering", action='store_true', default=False)
+ argparser.add_argument("--centering_strength", type=float, default=1.0)
+
+ # ELBO-based log_rnd estimation
+ argparser.add_argument('--elbo_rnd', action='store_true', default=False,
+ help='If set, compute log_rnd via ELBO instead of trajectory rollout')
+ argparser.add_argument('--elbo_rnd_num_samples', type=int, default=16,
+ help='Number of noisy time samples per sequence for ELBO-based log_rnd estimation')
+
+ # adaptive schedule parameters
+ argparser.add_argument('--use_adaptive_schedule', action='store_true', default=True)
+ argparser.add_argument('--schedule_hidden_dim', type=int, default=256)
+ argparser.add_argument('--schedule_num_layers', type=int, default=2)
+ argparser.add_argument('--schedule_loss_weight', type=float, default=0.1)
+ argparser.add_argument('--adaptive_threshold', type=float, default=0.5)
+ argparser.add_argument('--freeze_base_model', action='store_true', default=False)
+ argparser.add_argument('--schedule_warmup_epochs', type=int, default=0, help='Number of initial epochs to train WITHOUT remasking in buffer generation')
+ argparser.add_argument('--alternation_frequency', type=int, default=20, help='Number of epochs to train each model before alternating (1=alternate every epoch)')
+ argparser.add_argument('--planner_learning_rate', type=float, default=None, help='Separate learning rate for planner heads (defaults to --learning_rate if not set)')
+
+ # objectives
+ argparser.add_argument('--num_obj', type=int, default=5)
+ argparser.add_argument('--prot_seq', type=str, default=None)
+ argparser.add_argument('--prot_name', type=str, default='glast',
+ help='Protein target name. Looked up in PROTEINS dict unless --prot_seq is given.')
+ argparser.add_argument('--devices', type=int, default=-1)
+ argparser.add_argument('--checkpoint_path', type=str, default=None)
+ argparser.add_argument('--resume_ckpt', type=str, default=None,
+ help='Path to a Lightning last.ckpt to resume training from (restores epoch/optimizer/planner state). '
+ 'New checkpoints continue in the same directory as this checkpoint.')
+
+ # remasking
+ argparser.add_argument('--num_remasking', type=int, default=5)
+ argparser.add_argument('--quality_threshold', type=float, default=1)
+
+ # length cutoff (peptide-bond filter) + buffer starvation guard
+ argparser.add_argument('--min_peptide_bonds', type=int, default=4,
+ help='Minimum backbone peptide bonds for a sample to count as valid. '
+ '0 disables the cutoff. Filters degenerate short reward-hacked molecules.')
+ argparser.add_argument('--max_buffer_attempts', type=int, default=100,
+ help='Max sampling rounds per buffer refill before giving up (caps wasted GPU when validity is low).')
+ argparser.add_argument('--buffer_starvation_patience', type=int, default=10,
+ help='If 0 valid samples after this many rounds, abort the refill early (starvation guard).')
+
+ # planner ablation flags
+ argparser.add_argument('--disable_planner', action='store_true', help='If set, disable remasking completely and only train policy (not planner) for quality optimization')
+ argparser.add_argument('--disable_insertion_planner', action='store_true', help='Ablation: disable insertion quality filtering but keep unmasking/remasking planner')
+ argparser.add_argument('--disable_unmasking_planner', action='store_true', help='Ablation: disable unmasking/remasking planner but keep insertion quality filtering')
+ argparser.add_argument('--joint_training', action='store_true', help='Ablation: train policy and planner jointly each step (no alternation, both unfrozen, summed loss). Incompatible with --disable_planner.')
+
+ # performance optimization
+ argparser.add_argument('--eval_every_n_epochs', type=int, default=5, help='Evaluate only every N epochs to save time')
+ argparser.add_argument('--num_training_steps_per_epoch', type=int, default=10, help='Number of gradient updates per epoch')
+ argparser.add_argument('--training_mini_batch_size', type=int, default=6, help='Mini-batch size for training from buffer to avoid OOM')
+ argparser.add_argument('--pool_size', type=int, default=0,
+ help='If >0, maintain a persistent pool of this size and refresh a fraction each resample step (0=disabled, classic buffer). Helps preserve uniqueness/diversity over training.')
+ argparser.add_argument('--pool_refresh_fraction', type=float, default=0.2,
+ help='Fraction of pool to replace each resample step (only used when pool_size>0)')
+
+ args = argparser.parse_args()
+
+ # Default planner LR to policy LR if not specified
+ if args.planner_learning_rate is None:
+ args.planner_learning_rate = args.learning_rate
+
+ # Set seed
+ pl.seed_everything(args.seed)
+
+ # Load models
+ checkpoint_path = args.checkpoint_path if args.checkpoint_path else \
+ os.path.join(REPO_ROOT, 'pretrained', 'anylength_pep.ckpt')
+
+ # Update args.checkpoint_path to ensure it's saved in hyperparameters for later inference
+ args.checkpoint_path = checkpoint_path
+
+ PROTEINS = {
+ 'amhr': 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV',
+ 'tfr': 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF',
+ 'gfap': 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM',
+ 'glp1': 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS',
+ 'glast': 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM',
+ 'ncam': 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF',
+ 'cereblon': 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL',
+ 'ligase': 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS',
+ 'skp2': 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL',
+ }
+
+ if args.prot_seq is not None:
+ prot = args.prot_seq
+ prot_name = args.prot_name
+ else:
+ prot_name = args.prot_name
+ if prot_name not in PROTEINS:
+ raise ValueError(f"Unknown protein: {prot_name}. Choose from: {list(PROTEINS.keys())}")
+ prot = PROTEINS[prot_name]
+ filename = prot_name
+
+ curr_time = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+ if args.no_mcts:
+ args.run_name = f'{curr_time}_adaptive_{prot_name}_resample{args.resample_every_n_step}_no-mcts'
+ else:
+ args.run_name = f'{curr_time}_adaptive_{prot_name}_resample{args.resample_every_n_step}_buffer{args.buffer_size}_numiter{args.num_iter}_children{args.num_children}'
+
+ # Append ablation tags to run name for easy identification
+ if args.disable_planner:
+ args.run_name += '_no_planner'
+ if args.disable_insertion_planner:
+ args.run_name += '_no_insertion_planner'
+ if args.disable_unmasking_planner:
+ args.run_name += '_no_unmasking_planner'
+ if args.joint_training:
+ if args.disable_planner:
+ raise ValueError("--joint_training is incompatible with --disable_planner (no planner to train)")
+ args.run_name += '_joint_training'
+
+ # When resuming, continue writing checkpoints into the SAME directory as the
+ # checkpoint we resume from (keeps model-{epoch}.ckpt contiguous) instead of
+ # spawning a fresh timestamped run directory.
+ if args.resume_ckpt:
+ args.save_path = os.path.dirname(os.path.abspath(args.resume_ckpt))
+ args.run_name = os.path.basename(args.save_path)
+ else:
+ args.save_path = os.path.join(args.save_path_dir, args.run_name)
+ os.makedirs(args.save_path, exist_ok=True)
+ set_seed(args.seed, use_cuda=False) # Don't init CUDA before Lightning spawns DDP workers
+
+ # Initialize the model
+ print("Loading models..")
+
+ # Load pretrained model for reference (frozen)
+ pretrained = AnyOrderInsertionFlowModule.load_from_checkpoint(checkpoint_path,
+ map_location='cpu',
+ weights_only=False)
+ pretrained.eval()
+ for param in pretrained.parameters():
+ param.requires_grad = False
+
+ # Load checkpoint to extract config
+ checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
+ if 'hyper_parameters' in checkpoint:
+ config = checkpoint['hyper_parameters']['config']
+ elif 'config' in checkpoint:
+ config = checkpoint['config']
+ else:
+ raise ValueError("Cannot find config in checkpoint")
+
+ # Update config for adaptive schedule
+ from omegaconf import OmegaConf
+ if not OmegaConf.is_config(config):
+ from omegaconf import DictConfig
+ config = DictConfig(config)
+
+ # Disable struct mode to allow adding new keys
+ OmegaConf.set_struct(config, False)
+
+ config.training.use_adaptive_schedule = args.use_adaptive_schedule
+ config.training.schedule_hidden_dim = args.schedule_hidden_dim
+ config.training.schedule_num_layers = args.schedule_num_layers
+ config.training.schedule_loss_weight = args.schedule_loss_weight
+ config.training.freeze_base_model = args.freeze_base_model
+ config.training.schedule_warmup_epochs = args.schedule_warmup_epochs
+
+ # Re-enable struct mode
+ OmegaConf.set_struct(config, True)
+
+ # Initialize policy model with adaptive schedule
+ policy_model = AnyOrderInsertionFlowModuleFT(
+ config=config,
+ args=args,
+ pretrained_checkpoint=checkpoint_path,
+ insertion_planner=True,
+ )
+
+ # define mcts
+ score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
+
+ tokenizer = SMILES_SPE_Tokenizer(
+ os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer', 'new_vocab.txt'),
+ os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer', 'new_splits.txt')
+ )
+
+ # Device will be set by Lightning automatically in DDP
+ reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device='cpu')
+ model = PeptideFinetuner(
+ args=args,
+ policy_model=policy_model,
+ reward_model=reward_model,
+ tokenizer=tokenizer,
+ pretrained=pretrained,
+ mcts=None,
+ filename=filename,
+ prot_name=prot_name
+ )
+
+ # Setup checkpoint callback
+ checkpoint_callback = ModelCheckpoint(
+ dirpath=args.save_path,
+ filename='model-{epoch:02d}',
+ every_n_epochs=args.save_every_n_epochs,
+ save_top_k=-1,
+ save_last=True, # Also save last.ckpt
+ auto_insert_metric_name=False
+ )
+
+ # Setup wandb logger - only on rank 0 to avoid multiple runs
+ # Check if we're in a spawned DDP process
+ rank = int(os.environ.get('LOCAL_RANK', 0))
+ if rank == 0:
+ # Defaults to your default wandb entity; override with the WANDB_ENTITY env var.
+ wandb_logger = WandbLogger(entity=os.environ.get('WANDB_ENTITY'), project='a2d2-pep', name=args.run_name)
+ else:
+ wandb_logger = None
+
+ # Create dummy dataloader
+ dataset = DummyDataset(size=args.num_training_steps_per_epoch)
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
+
+ # Setup trainer with DDP
+ trainer = pl.Trainer(
+ max_epochs=args.num_epochs,
+ accelerator='gpu',
+ devices=args.devices,
+ strategy=DDPStrategy(find_unused_parameters=True) if args.devices != 1 else 'auto',
+ gradient_clip_val=args.gradnorm_clip if args.grad_clip else None,
+ logger=wandb_logger,
+ callbacks=[checkpoint_callback],
+ enable_progress_bar=True,
+ log_every_n_steps=1
+ )
+
+ # Train (resume full training state from --resume_ckpt if provided).
+ # weights_only=False is required when resuming because these checkpoints
+ # store argparse.Namespace / OmegaConf objects in hyper_parameters, which
+ # PyTorch 2.6's default weights_only=True unpickler rejects.
+ if args.resume_ckpt:
+ trainer.fit(model, dataloader, ckpt_path=args.resume_ckpt, weights_only=False)
+ else:
+ trainer.fit(model, dataloader)
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/a2d2_pep/inference_quality.py b/a2d2_pep/inference_quality.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed9ddc5cabc3bedc557e2c4a045362c8af2e4efa
--- /dev/null
+++ b/a2d2_pep/inference_quality.py
@@ -0,0 +1,605 @@
+"""Unified peptide sampling with quality-guided planning.
+
+Supports 4 quality modes and optional RND (importance weight) computation.
+
+Quality modes:
+ "none" - No planner, no remasking (policy-only)
+ "both" - Both unmasking + insertion planners active
+ "unmasking_only" - Only unmasking/remasking planner (insertion planner disabled)
+ "insertion_only" - Only insertion planner (unmasking planner disabled)
+
+RND toggle:
+ compute_rnd=True - Run pretrained model in parallel, compute step-wise log importance weights
+ compute_rnd=False - Run policy model only (use with ELBO-based RND or eval)
+"""
+
+import os
+import torch
+import numpy as np
+import pandas as pd
+import torch.nn.functional as F
+from sampling import SamplingResult, SamplingTraceDatapoint, _sample_tokens
+from remasking_scheduleaware import apply_schedule_aware_remasking, apply_schedule_aware_insertion
+
+QUALITY_MODES = {"none", "both", "unmasking_only", "insertion_only"}
+
+# When set (e.g. A2D2_QUALITY_DEBUG=1), the diffusion loop prints, per step, how
+# many already-unmasked tokens get remasked and how many proposed insertions get
+# filtered by the quality planner, plus a per-batch total. Off by default so it
+# never spams training/eval runs.
+_QUALITY_DEBUG = os.environ.get("A2D2_QUALITY_DEBUG", "") not in ("", "0", "false", "False")
+
+
+@torch.no_grad()
+def _diffusion_loop(
+ model, steps, mask, pad, batch_size, max_length,
+ quality_mode="both",
+ compute_rnd=False,
+ pretrained=None,
+ remasking_mode="schedule_aware",
+ num_remasking=1,
+ quality_threshold=1,
+ unmask_quality_threshold=None,
+ unmask_all=False,
+ freq_penalty=0.0,
+ return_trace=False,
+):
+ """Core discrete diffusion sampling loop for peptide generation.
+
+ Args:
+ model: Finetuned policy model.
+ steps: Number of diffusion steps.
+ mask: Mask token ID.
+ pad: Pad token ID.
+ batch_size: Number of sequences to generate.
+ max_length: Maximum sequence length.
+ quality_mode: One of "none", "both", "unmasking_only", "insertion_only".
+ compute_rnd: Whether to compute step-wise log importance weights.
+ pretrained: Frozen pretrained model (required if compute_rnd=True).
+ remasking_mode: Remasking strategy ("schedule_aware", "remdm", "remdm_conf").
+ num_remasking: Number of tokens to remask per step.
+ quality_threshold: Threshold for insertion quality filtering. None if schedule-driven.
+ return_trace: Whether to record sampling trace.
+
+ Returns:
+ (xt, log_rnd, sampling_trace)
+ log_rnd is None when compute_rnd=False.
+ """
+ assert quality_mode in QUALITY_MODES, f"quality_mode must be one of {QUALITY_MODES}"
+ if compute_rnd:
+ assert pretrained is not None, "pretrained model required when compute_rnd=True"
+
+ # Derive flags from quality_mode
+ use_remasking = quality_mode != "none"
+ disable_unmasking_planner = quality_mode in ("none", "insertion_only")
+ disable_insertion_planner = quality_mode in ("none", "unmasking_only")
+
+ device = next(model.parameters()).device
+
+ # Initialize all-pad sequence
+ xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
+
+ dt = 1.0 / steps
+ t = torch.zeros(batch_size, device=device)
+
+ # Precompute index tensors
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+ sampling_trace = [[] for _ in range(batch_size)] if return_trace else None
+
+ neg_inf = torch.tensor(-np.inf, device=device)
+
+ if use_remasking and remasking_mode == "remdm_conf":
+ remasking_score = torch.zeros((batch_size, max_length), device=device)
+
+ log_rnd = None
+
+ dbg_total_remasked = 0
+ dbg_total_proposed_ins = 0
+ dbg_total_filtered = 0
+
+ for i in range(steps):
+ step_remasked = 0
+ step_proposed_ins = 0
+ step_filtered = 0
+ # --- Policy model forward ---
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # --- Pretrained model forward (for RND) ---
+ if compute_rnd:
+ pretrained_pred = pretrained(xt, t)
+ pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
+ pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V)
+ pretrained_len_rate = pretrained_rate.length_rate # (B, L+1)
+
+ # --- Unmask step (Euler) ---
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ if compute_rnd:
+ pretrained_unmask_rate[xt != mask] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)
+
+ # Add "stay" probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+ if compute_rnd:
+ pretrained_trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
+ )
+
+ # Remove mask token from sampling so every masked position is decoded.
+ # The final step always does this; unmask_all does it every step, so the
+ # schedule-aware remasking below re-masks the lowest-quality tokens back
+ # down to the schedule's expected mask count.
+ if i == steps - 1 or unmask_all:
+ if i == steps - 1:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0
+
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+ if mask_has_zero_prob.any():
+ num_zero_prob = mask_has_zero_prob.sum().item()
+ uniform_prob = torch.zeros((num_zero_prob, trans_prob.shape[-1]), device=device, dtype=trans_prob.dtype)
+ uniform_prob[:, :mask] = 1.0 / mask
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ # --- Frequency penalty: down-weight residues already abundant in the
+ # sequence so (re)decoded masked positions don't collapse onto the modal
+ # token (glycine). Only masked positions are sampled; clean positions are
+ # overwritten below, so penalizing the whole tensor is harmless. mask/pad
+ # never accumulate counts, so their entries stay untouched. Applied to a
+ # copy so trans_prob (used for RND log-probs) is unchanged.
+ sample_prob = trans_prob
+ if freq_penalty > 0.0:
+ V = trans_prob.shape[-1]
+ clean_tok = (xt != mask) & (xt != pad) # (B, L)
+ counts = torch.zeros(batch_size, V, device=device, dtype=trans_prob.dtype)
+ counts.scatter_add_(1, torch.where(clean_tok, xt, torch.zeros_like(xt)),
+ clean_tok.to(trans_prob.dtype))
+ sample_prob = trans_prob * torch.exp(-freq_penalty * counts).unsqueeze(1)
+
+ new_xt = _sample_tokens(sample_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ # Update remasking_score buffer for remdm_conf mode
+ if use_remasking and remasking_mode == "remdm_conf" and i < steps - 1:
+ token_probs = F.softmax(unmask_rate, dim=-1) # (B, L, V)
+ chosen_probs = torch.gather(token_probs, dim=-1, index=new_xt.unsqueeze(-1)).squeeze(-1) # (B, L)
+ changed_mask_to_token = (xt == mask) & (new_xt != mask) & (new_xt != pad)
+ remasking_score = torch.where(changed_mask_to_token, chosen_probs, remasking_score)
+
+ # --- Remasking step ---
+ if use_remasking and i < steps - 1:
+ if disable_unmasking_planner or not (hasattr(model, 'planner') and model.planner is not None):
+ remasking_conf = torch.zeros((batch_size, max_length), device=device)
+ else:
+ planner_out = model.planner(new_xt, t)
+ remasking_conf = planner_out["remasking_conf"].squeeze(-1) # (B, L)
+
+ clean_index = (new_xt != mask) & (new_xt != pad) # (B, L)
+
+ if remasking_mode == "remdm":
+ remasking_score_temp = torch.rand(remasking_conf.shape, device=device)
+ elif remasking_mode == "remdm_conf":
+ remasking_score_temp = -1.0 * remasking_conf
+ elif remasking_mode == "schedule_aware":
+ # Only remask when the unmasking planner is active. Otherwise
+ # (e.g. insertion_only / no_unmasking_planner) remasking_conf is
+ # all zeros, so this would remask schedule-excess tokens by
+ # position rather than by quality.
+ if not disable_unmasking_planner:
+ new_xt = apply_schedule_aware_remasking(
+ model, new_xt, t, dt, remasking_conf, clean_index,
+ mask, neg_inf, batch_size,
+ unmask_quality_threshold=unmask_quality_threshold,
+ )
+ remasking_score_temp = None
+ else:
+ raise ValueError(f"Unknown remasking_mode: {remasking_mode}")
+
+ if remasking_score_temp is not None:
+ remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf)
+ for j in range(batch_size):
+ k = min(num_remasking, int(clean_index[j].sum().item()))
+ if k > 0:
+ _, select_indices = torch.topk(remasking_score_temp[j], k=k)
+ new_xt[j, select_indices] = mask
+
+ if _QUALITY_DEBUG:
+ # Positions that were clean before this remasking block and are
+ # now mask are exactly the unmasked tokens that got remasked.
+ step_remasked = int((clean_index & (new_xt == mask)).sum().item())
+
+ if return_trace:
+ for batch_idx in range(batch_size):
+ for pos in range(max_length):
+ if clean_index[batch_idx, pos] and new_xt[batch_idx, pos] == mask:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="change",
+ position=pos,
+ token=mask,
+ )
+ )
+
+ # --- Compute log probabilities for RND ---
+ if compute_rnd:
+ lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+ lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+
+ changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)
+
+ log_policy_step = (lp * changed_mask).sum(dim=1)
+ log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
+
+ log_rnd = log_pretrained_step - log_policy_step # (B,)
+
+ # --- Insertion step ---
+ if i != steps - 1:
+ ext = torch.poisson(len_rate * dt).long() # (B, L+1)
+
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ gaps = torch.arange(max_length + 1, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+
+ if _QUALITY_DEBUG:
+ # ext has been masked by the max-length validity check above, so
+ # this is the number of fresh mask tokens actually inserted.
+ step_proposed_ins = int(ext.sum().item())
+
+ # Schedule-aware insertion quality filtering
+ if use_remasking and not disable_insertion_planner:
+ if compute_rnd:
+ xt_tmp_before = xt_tmp.clone()
+
+ dbg_nonpad_before = int((xt_tmp != pad).sum().item()) if _QUALITY_DEBUG else 0
+
+ xt_tmp = apply_schedule_aware_insertion(
+ model, xt_tmp, new_xt, t, dt, ext, mask, pad, max_length,
+ orig_mask, new_pos_orig, quality_threshold
+ )
+
+ if _QUALITY_DEBUG:
+ # Filtering only drops/compacts tokens, so the drop in
+ # non-pad count is the number of insertions filtered out.
+ step_filtered = dbg_nonpad_before - int((xt_tmp != pad).sum().item())
+
+ if compute_rnd:
+ # Compute corrected ext based on what actually stayed
+ ext_corrected = torch.zeros_like(ext)
+ for b in range(batch_size):
+ after_len = xt_tmp[b].ne(pad).sum().item()
+ orig_len = xt_len[b].item()
+ surviving_insertions = after_len - orig_len
+ if total_ext[b] > 0:
+ ratio = surviving_insertions / total_ext[b].item()
+ ext_corrected[b] = (ext[b].float() * ratio).long()
+ else:
+ ext_corrected = ext
+ else:
+ ext_corrected = ext
+
+ # Compute insertion log_rnd
+ if compute_rnd:
+ insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1)
+ pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1)
+
+ log_policy_insert = (ext_corrected * torch.log(insertion_rate) - insertion_rate).sum(dim=1)
+ log_pretrained_insert = (ext_corrected * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)
+
+ log_insert_diff = log_pretrained_insert - log_policy_insert
+ log_rnd += log_insert_diff
+ else:
+ xt_tmp = new_xt
+
+ if return_trace:
+ for batch_idx in range(batch_size):
+ for j in range(max_length):
+ if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="change",
+ position=j,
+ token=new_xt[batch_idx, j].item(),
+ )
+ )
+
+ if i != steps - 1:
+ for j in range(max_length):
+ id = max_length - j - 1
+ if ext[batch_idx, id]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="insertion",
+ position=id,
+ token=mask,
+ )
+ )
+
+ if _QUALITY_DEBUG:
+ dbg_total_remasked += step_remasked
+ dbg_total_proposed_ins += step_proposed_ins
+ dbg_total_filtered += step_filtered
+ print(
+ f"[QUALITY {quality_mode}] step {i+1}/{steps}: "
+ f"remasked {step_remasked} unmasked tokens -> mask | "
+ f"insertions proposed {step_proposed_ins}, "
+ f"filtered {step_filtered}, kept {step_proposed_ins - step_filtered}"
+ )
+
+ xt = xt_tmp
+ t = t + dt
+
+ if _QUALITY_DEBUG:
+ print(
+ f"[QUALITY {quality_mode}] TOTAL over {steps} steps (batch_size={batch_size}): "
+ f"remasked {dbg_total_remasked} unmasked tokens | "
+ f"insertions proposed {dbg_total_proposed_ins}, "
+ f"filtered {dbg_total_filtered}, kept {dbg_total_proposed_ins - dbg_total_filtered}"
+ )
+
+ return xt, log_rnd, sampling_trace
+
+
+@torch.no_grad()
+def sample_peptides_buffer(
+ model, reward_model, analyzer, tokenizer,
+ steps, mask, pad, batch_size, max_length,
+ quality_mode="both",
+ compute_rnd=False,
+ pretrained=None,
+ alpha=0.1,
+ remasking_mode="schedule_aware",
+ num_remasking=1,
+ quality_threshold=1,
+ min_length=0,
+):
+ """Generate peptides for training buffer.
+
+ Args:
+ model: Finetuned policy model.
+ reward_model: Multi-objective scoring function.
+ analyzer: PeptideAnalyzer for validation.
+ tokenizer: Tokenizer for decoding.
+ steps: Number of diffusion steps.
+ mask: Mask token ID.
+ pad: Pad token ID.
+ batch_size: Number of sequences to generate.
+ max_length: Maximum sequence length.
+ quality_mode: "none", "both", "unmasking_only", or "insertion_only".
+ compute_rnd: If True, compute step-wise log importance weights (requires pretrained).
+ If False, returns placeholder zero log_rnd (for ELBO-based RND).
+ pretrained: Frozen pretrained model (required when compute_rnd=True).
+ alpha: RND scaling factor.
+ remasking_mode: Remasking strategy.
+ num_remasking: Number of tokens to remask per step.
+ quality_threshold: Threshold for insertion quality filtering.
+
+ Returns:
+ (valid_x, log_rnd, scalar_rewards, sampling_trace)
+ """
+ xt, log_rnd, trace = _diffusion_loop(
+ model, steps, mask, pad, batch_size, max_length,
+ quality_mode=quality_mode,
+ compute_rnd=compute_rnd,
+ pretrained=pretrained,
+ remasking_mode=remasking_mode,
+ num_remasking=num_remasking,
+ quality_threshold=quality_threshold,
+ )
+
+ device = xt.device
+ decoded_samples = tokenizer.batch_decode(xt)
+
+ valid_x_final = []
+ validSequences = []
+ valid_log_rnd = []
+
+ for idx, seq in enumerate(decoded_samples):
+ if not analyzer.is_peptide(seq):
+ continue
+ token_len = int((xt[idx] != pad).sum().item())
+ if min_length > 0 and token_len < min_length:
+ continue
+ valid_x_final.append(xt[idx])
+ validSequences.append(seq)
+ if compute_rnd:
+ valid_log_rnd.append(log_rnd[idx])
+
+ print("len valid sequences:", len(validSequences))
+
+ if len(validSequences) == 0:
+ print("[WARNING] No valid peptides generated in this batch")
+ empty_x = torch.empty((0, max_length), dtype=torch.long, device=device)
+ empty_log_rnd = torch.empty((0,), dtype=torch.float32, device=device)
+ empty_rewards = torch.empty((0,), dtype=torch.float32, device=device)
+ return empty_x, empty_log_rnd, empty_rewards, trace
+
+ score_vectors = reward_model(input_seqs=validSequences)
+ scalar_rewards = np.sum(score_vectors, axis=-1)
+ scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device)
+
+ print(f"scalar reward dim{len(scalar_rewards)}")
+ valid_x_final = torch.stack(valid_x_final, dim=0)
+
+ if compute_rnd:
+ valid_log_rnd = torch.stack(valid_log_rnd, dim=0)
+ log_rnd_out = valid_log_rnd + (scalar_rewards / alpha)
+ else:
+ log_rnd_out = torch.zeros(len(validSequences), dtype=torch.float32, device=device)
+
+ return valid_x_final, log_rnd_out, scalar_rewards, trace
+
+
+@torch.no_grad()
+def sample_peptides_eval(
+ model, reward_model, analyzer, tokenizer,
+ steps, mask, pad, batch_size, max_length,
+ quality_mode="both",
+ remasking_mode="schedule_aware",
+ num_remasking=1,
+ quality_threshold=1,
+ unmask_quality_threshold=None,
+ unmask_all=False,
+ freq_penalty=0.0,
+ dataframe=False,
+ return_valid=False,
+):
+ """Generate peptides for evaluation.
+
+ Args:
+ model: Finetuned policy model.
+ reward_model: Multi-objective scoring function.
+ analyzer: PeptideAnalyzer for validation.
+ tokenizer: Tokenizer for decoding.
+ steps: Number of diffusion steps.
+ mask: Mask token ID.
+ pad: Pad token ID.
+ batch_size: Number of sequences to generate.
+ max_length: Maximum sequence length.
+ quality_mode: "none", "both", "unmasking_only", or "insertion_only".
+ remasking_mode: Remasking strategy.
+ num_remasking: Number of tokens to remask per step.
+ quality_threshold: Threshold for insertion quality filtering.
+ dataframe: If True, include a pandas DataFrame in the return.
+ return_valid: If True, return decoded valid sequences instead of raw token tensors.
+
+ Returns:
+ For multi-objective (5 objectives):
+ (samples, affinity, sol, hemo, nf, permeability, valid_fraction[, df])
+ For single objective:
+ (samples, sol, valid_fraction[, df])
+ When return_valid=True, samples is replaced with validSequences list.
+ """
+ xt, _, trace = _diffusion_loop(
+ model, steps, mask, pad, batch_size, max_length,
+ quality_mode=quality_mode,
+ compute_rnd=False,
+ remasking_mode=remasking_mode,
+ num_remasking=num_remasking,
+ quality_threshold=quality_threshold,
+ unmask_quality_threshold=unmask_quality_threshold,
+ unmask_all=unmask_all,
+ freq_penalty=freq_penalty,
+ )
+
+ device = xt.device
+ samples = xt.to(device)
+ decoded_samples = tokenizer.batch_decode(samples)
+
+ valid_x_final = []
+ validSequences = []
+
+ for idx, seq in enumerate(decoded_samples):
+ if analyzer.is_peptide(seq):
+ valid_x_final.append(samples[idx])
+ validSequences.append(seq)
+
+ print("len valid sequences:", len(validSequences))
+
+ valid_fraction = len(validSequences) / batch_size
+
+ # Determine number of objectives from reward model
+ num_objectives = len(reward_model.score_func_names) if hasattr(reward_model, 'score_func_names') else 5
+
+ if len(validSequences) != 0:
+ score_vectors = reward_model(input_seqs=validSequences) # (N, num_objectives)
+ average_scores = score_vectors.T
+
+ if num_objectives == 1:
+ sol = average_scores[0]
+ else:
+ affinity = average_scores[0]
+ sol = average_scores[1]
+ hemo = average_scores[2]
+ nf = average_scores[3]
+ permeability = average_scores[4]
+ else:
+ zeros = [0.0]
+
+ if num_objectives == 1:
+ sol = zeros
+ else:
+ affinity = zeros
+ sol = zeros
+ hemo = zeros
+ nf = zeros
+ permeability = zeros
+
+ if num_objectives == 1:
+ if dataframe:
+ df = pd.DataFrame({
+ "Peptide Sequence": validSequences,
+ "Solubility": sol if len(validSequences) else [0.0],
+ })
+ if return_valid:
+ return validSequences, sol, valid_fraction, df
+ return samples, sol, valid_fraction, df
+
+ if return_valid:
+ return validSequences, sol, valid_fraction
+ return samples, sol, valid_fraction
+
+ if dataframe:
+ df = pd.DataFrame({
+ "Peptide Sequence": validSequences,
+ "Binding Affinity": affinity if len(validSequences) else [0.0],
+ "Solubility": sol if len(validSequences) else [0.0],
+ "Hemolysis": hemo if len(validSequences) else [0.0],
+ "Nonfouling": nf if len(validSequences) else [0.0],
+ "Permeability": permeability if len(validSequences) else [0.0],
+ })
+ if return_valid:
+ return validSequences, affinity, sol, hemo, nf, permeability, valid_fraction, df
+ return samples, affinity, sol, hemo, nf, permeability, valid_fraction, df
+
+ if return_valid:
+ return validSequences, affinity, sol, hemo, nf, permeability, valid_fraction
+ return samples, affinity, sol, hemo, nf, permeability, valid_fraction
diff --git a/a2d2_pep/pep_scoring/functions/binding.py b/a2d2_pep/pep_scoring/functions/binding.py
new file mode 100755
index 0000000000000000000000000000000000000000..c192b79f6387c8847341362f56a253264b61fb6e
--- /dev/null
+++ b/a2d2_pep/pep_scoring/functions/binding.py
@@ -0,0 +1,178 @@
+import sys
+import os, torch
+import numpy as np
+import torch
+import pandas as pd
+import torch.nn as nn
+import esm
+from transformers import AutoModelForMaskedLM
+
+class ImprovedBindingPredictor(nn.Module):
+ def __init__(self,
+ esm_dim=1280,
+ smiles_dim=768,
+ hidden_dim=512,
+ n_heads=8,
+ n_layers=3,
+ dropout=0.1):
+ super().__init__()
+
+ # Define binding thresholds
+ self.tight_threshold = 7.5 # Kd/Ki/IC50 ≤ ~30nM
+ self.weak_threshold = 6.0 # Kd/Ki/IC50 > 1μM
+
+ # Project to same dimension
+ self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
+ self.protein_projection = nn.Linear(esm_dim, hidden_dim)
+ self.protein_norm = nn.LayerNorm(hidden_dim)
+ self.smiles_norm = nn.LayerNorm(hidden_dim)
+
+ # Cross attention blocks with layer norm
+ self.cross_attention_layers = nn.ModuleList([
+ nn.ModuleDict({
+ 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
+ 'norm1': nn.LayerNorm(hidden_dim),
+ 'ffn': nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim * 4),
+ nn.ReLU(),
+ nn.Dropout(dropout),
+ nn.Linear(hidden_dim * 4, hidden_dim)
+ ),
+ 'norm2': nn.LayerNorm(hidden_dim)
+ }) for _ in range(n_layers)
+ ])
+
+ # Prediction heads
+ self.shared_head = nn.Sequential(
+ nn.Linear(hidden_dim * 2, hidden_dim),
+ nn.ReLU(),
+ nn.Dropout(dropout),
+ )
+
+ # Regression head
+ self.regression_head = nn.Linear(hidden_dim, 1)
+
+ # Classification head (3 classes: tight, medium, loose binding)
+ self.classification_head = nn.Linear(hidden_dim, 3)
+
+ def get_binding_class(self, affinity):
+ """Convert affinity values to class indices
+ 0: tight binding (>= 7.5)
+ 1: medium binding (6.0-7.5)
+ 2: weak binding (< 6.0)
+ """
+ if isinstance(affinity, torch.Tensor):
+ tight_mask = affinity >= self.tight_threshold
+ weak_mask = affinity < self.weak_threshold
+ medium_mask = ~(tight_mask | weak_mask)
+
+ classes = torch.zeros_like(affinity, dtype=torch.long)
+ classes[medium_mask] = 1
+ classes[weak_mask] = 2
+ return classes
+ else:
+ if affinity >= self.tight_threshold:
+ return 0 # tight binding
+ elif affinity < self.weak_threshold:
+ return 2 # weak binding
+ else:
+ return 1 # medium binding
+
+ def forward(self, protein_emb, smiles_emb):
+ protein = self.protein_norm(self.protein_projection(protein_emb))
+ smiles = self.smiles_norm(self.smiles_projection(smiles_emb))
+
+ #protein = protein.transpose(0, 1)
+ #smiles = smiles.transpose(0, 1)
+
+ # Cross attention layers
+ for layer in self.cross_attention_layers:
+ # Protein attending to SMILES
+ attended_protein = layer['attention'](
+ protein, smiles, smiles
+ )[0]
+ protein = layer['norm1'](protein + attended_protein)
+ protein = layer['norm2'](protein + layer['ffn'](protein))
+
+ # SMILES attending to protein
+ attended_smiles = layer['attention'](
+ smiles, protein, protein
+ )[0]
+ smiles = layer['norm1'](smiles + attended_smiles)
+ smiles = layer['norm2'](smiles + layer['ffn'](smiles))
+
+ # Get sequence-level representations
+ protein_pool = torch.mean(protein, dim=0)
+ smiles_pool = torch.mean(smiles, dim=0)
+
+ # Concatenate both representations
+ combined = torch.cat([protein_pool, smiles_pool], dim=-1)
+
+ # Shared features
+ shared_features = self.shared_head(combined)
+
+ regression_output = self.regression_head(shared_features)
+ classification_logits = self.classification_head(shared_features)
+
+ return regression_output, classification_logits
+
+class BindingAffinity:
+ def __init__(self, prot_seq, tokenizer, base_path, device=None, emb_model=None):
+ super().__init__()
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
+
+ # peptide embeddings
+ if emb_model is not None:
+ self.pep_model = emb_model.to(self.device).eval()
+ else:
+ self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
+
+ self.pep_tokenizer = tokenizer
+
+ self.model = ImprovedBindingPredictor().to(self.device)
+ checkpoint = torch.load(f'{base_path}/functions/classifiers/binding-affinity.pt',
+ map_location=self.device,
+ weights_only=False)
+ self.model.load_state_dict(checkpoint['model_state_dict'])
+
+ self.model.eval()
+
+ self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() # load ESM-2 model
+ self.esm_model = self.esm_model.to(self.device).eval()
+ self.prot_tokenizer = alphabet.get_batch_converter() # load esm tokenizer
+
+ data = [("target", prot_seq)]
+ # get tokenized protein
+ _, _, prot_tokens = self.prot_tokenizer(data)
+ prot_tokens = prot_tokens.to(self.device)
+ with torch.no_grad():
+ results = self.esm_model.forward(prot_tokens, repr_layers=[33]) # Example with ESM-2
+ prot_emb = results["representations"][33]
+
+ self.prot_emb = prot_emb[0].to(self.device)
+ self.prot_emb = torch.mean(self.prot_emb, dim=0, keepdim=True)
+
+
+ def forward(self, input_seqs):
+ with torch.no_grad():
+ scores = []
+ for seq in input_seqs:
+ pep_tokens = self.pep_tokenizer(seq, return_tensors='pt', padding=True)
+
+ pep_tokens = {k: v.to(self.device) for k, v in pep_tokens.items()}
+
+ with torch.no_grad():
+ emb = self.pep_model(input_ids=pep_tokens['input_ids'],
+ attention_mask=pep_tokens['attention_mask'],
+ output_hidden_states=True)
+
+ #emb = self.pep_model(input_ids=pep_tokens['input_ids'], attention_mask=pep_tokens['attention_mask'])
+ pep_emb = emb.last_hidden_state.squeeze(0)
+ pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)
+
+ score, logits = self.model.forward(self.prot_emb, pep_emb)
+ scores.append(score.item())
+ return scores
+
+ def __call__(self, input_seqs: list):
+ return self.forward(input_seqs)
\ No newline at end of file
diff --git a/a2d2_pep/pep_scoring/functions/binding_utils.py b/a2d2_pep/pep_scoring/functions/binding_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..5f5b7fe1d0ab2c2be70a11a27d3f1f4bddcd8ff7
--- /dev/null
+++ b/a2d2_pep/pep_scoring/functions/binding_utils.py
@@ -0,0 +1,290 @@
+from torch import nn
+import torch
+import numpy as np
+
+def to_var(x):
+ if torch.cuda.is_available():
+ x = x.cuda()
+ return x
+
+class MultiHeadAttentionSequence(nn.Module):
+
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
+
+ super().__init__()
+
+ self.n_head = n_head
+ self.d_model = d_model
+ self.d_k = d_k
+ self.d_v = d_v
+
+ self.W_Q = nn.Linear(d_model, n_head*d_k)
+ self.W_K = nn.Linear(d_model, n_head*d_k)
+ self.W_V = nn.Linear(d_model, n_head*d_v)
+ self.W_O = nn.Linear(n_head*d_v, d_model)
+
+ self.layer_norm = nn.LayerNorm(d_model)
+
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, q, k, v):
+
+ batch, len_q, _ = q.size()
+ batch, len_k, _ = k.size()
+ batch, len_v, _ = v.size()
+
+ Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
+ K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
+ V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
+
+ Q = Q.transpose(1, 2)
+ K = K.transpose(1, 2).transpose(2, 3)
+ V = V.transpose(1, 2)
+
+ attention = torch.matmul(Q, K)
+
+ attention = attention / np.sqrt(self.d_k)
+
+ attention = F.softmax(attention, dim=-1)
+
+ output = torch.matmul(attention, V)
+
+ output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
+
+ output = self.W_O(output)
+
+ output = self.dropout(output)
+
+ output = self.layer_norm(output + q)
+
+ return output, attention
+
+class MultiHeadAttentionReciprocal(nn.Module):
+
+
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
+
+ super().__init__()
+
+ self.n_head = n_head
+ self.d_model = d_model
+ self.d_k = d_k
+ self.d_v = d_v
+
+ self.W_Q = nn.Linear(d_model, n_head*d_k)
+ self.W_K = nn.Linear(d_model, n_head*d_k)
+ self.W_V = nn.Linear(d_model, n_head*d_v)
+ self.W_O = nn.Linear(n_head*d_v, d_model)
+ self.W_V_2 = nn.Linear(d_model, n_head*d_v)
+ self.W_O_2 = nn.Linear(n_head*d_v, d_model)
+
+ self.layer_norm = nn.LayerNorm(d_model)
+
+ self.dropout = nn.Dropout(dropout)
+
+ self.layer_norm_2 = nn.LayerNorm(d_model)
+
+ self.dropout_2 = nn.Dropout(dropout)
+
+ def forward(self, q, k, v, v_2):
+
+ batch, len_q, _ = q.size()
+ batch, len_k, _ = k.size()
+ batch, len_v, _ = v.size()
+ batch, len_v_2, _ = v_2.size()
+
+ Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
+ K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
+ V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
+ V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v])
+
+ Q = Q.transpose(1, 2)
+ K = K.transpose(1, 2).transpose(2, 3)
+ V = V.transpose(1, 2)
+ V_2 = V_2.transpose(1,2)
+
+ attention = torch.matmul(Q, K)
+
+
+ attention = attention /np.sqrt(self.d_k)
+
+ attention_2 = attention.transpose(-2, -1)
+
+
+
+ attention = F.softmax(attention, dim=-1)
+
+ attention_2 = F.softmax(attention_2, dim=-1)
+
+
+ output = torch.matmul(attention, V)
+
+ output_2 = torch.matmul(attention_2, V_2)
+
+ output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
+
+ output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head])
+
+ output = self.W_O(output)
+
+ output_2 = self.W_O_2(output_2)
+
+ output = self.dropout(output)
+
+ output = self.layer_norm(output + q)
+
+ output_2 = self.dropout(output_2)
+
+ output_2 = self.layer_norm(output_2 + k)
+
+
+ return output, output_2, attention, attention_2
+
+
+class FFN(nn.Module):
+
+ def __init__(self, d_in, d_hid, dropout=0.1):
+ super().__init__()
+
+ self.layer_1 = nn.Conv1d(d_in, d_hid,1)
+ self.layer_2 = nn.Conv1d(d_hid, d_in,1)
+ self.relu = nn.ReLU()
+ self.layer_norm = nn.LayerNorm(d_in)
+
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+
+ residual = x
+ output = self.layer_1(x.transpose(1, 2))
+
+ output = self.relu(output)
+
+ output = self.layer_2(output)
+
+ output = self.dropout(output)
+
+ output = self.layer_norm(output.transpose(1, 2)+residual)
+
+ return output
+
+class ConvLayer(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
+ super(ConvLayer, self).__init__()
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
+ self.relu = nn.ReLU()
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.relu(out)
+ return out
+
+
+class DilatedCNN(nn.Module):
+ def __init__(self, d_model, d_hidden):
+ super(DilatedCNN, self).__init__()
+ self.first_ = nn.ModuleList()
+ self.second_ = nn.ModuleList()
+ self.third_ = nn.ModuleList()
+
+ dilation_tuple = (1, 2, 3)
+ dim_in_tuple = (d_model, d_hidden, d_hidden)
+ dim_out_tuple = (d_hidden, d_hidden, d_hidden)
+
+ for i, dilation_rate in enumerate(dilation_tuple):
+ self.first_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=3, padding=dilation_rate,
+ dilation=dilation_rate))
+
+ for i, dilation_rate in enumerate(dilation_tuple):
+ self.second_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=5, padding=2*dilation_rate,
+ dilation=dilation_rate))
+
+ for i, dilation_rate in enumerate(dilation_tuple):
+ self.third_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=7, padding=3*dilation_rate,
+ dilation=dilation_rate))
+
+ def forward(self, protein_seq_enc):
+ # pdb.set_trace()
+ protein_seq_enc = protein_seq_enc.transpose(1, 2) # protein_seq_enc's shape: B*L*d_model -> B*d_model*L
+
+ first_embedding = protein_seq_enc
+ second_embedding = protein_seq_enc
+ third_embedding = protein_seq_enc
+
+ for i in range(len(self.first_)):
+ first_embedding = self.first_[i](first_embedding)
+
+ for i in range(len(self.second_)):
+ second_embedding = self.second_[i](second_embedding)
+
+ for i in range(len(self.third_)):
+ third_embedding = self.third_[i](third_embedding)
+
+ # pdb.set_trace()
+
+ protein_seq_enc = first_embedding + second_embedding + third_embedding
+
+ return protein_seq_enc.transpose(1, 2)
+
+
+class ReciprocalLayerwithCNN(nn.Module):
+
+ def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v):
+ super().__init__()
+
+ self.cnn = DilatedCNN(d_model, d_hidden)
+
+ self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
+
+ self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
+
+ self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, d_k, d_v)
+
+ self.ffn_seq = FFN(d_hidden, d_inner)
+
+ self.ffn_protein = FFN(d_hidden, d_inner)
+
+ def forward(self, sequence_enc, protein_seq_enc):
+ # pdb.set_trace() # protein_seq_enc.shape = B * L * d_model
+ protein_seq_enc = self.cnn(protein_seq_enc)
+ prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
+
+ seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
+
+ prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
+
+ prot_enc = self.ffn_protein(prot_enc)
+
+ seq_enc = self.ffn_seq(seq_enc)
+
+ return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
+
+
+class ReciprocalLayer(nn.Module):
+
+ def __init__(self, d_model, d_inner, n_head, d_k, d_v):
+
+ super().__init__()
+
+ self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
+
+ self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
+
+ self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model, d_k, d_v)
+
+ self.ffn_seq = FFN(d_model, d_inner)
+
+ self.ffn_protein = FFN(d_model, d_inner)
+
+ def forward(self, sequence_enc, protein_seq_enc):
+ prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
+
+ seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
+
+
+ prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
+ prot_enc = self.ffn_protein(prot_enc)
+
+ seq_enc = self.ffn_seq(seq_enc)
+
+ return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
\ No newline at end of file
diff --git a/a2d2_pep/pep_scoring/functions/hemolysis.py b/a2d2_pep/pep_scoring/functions/hemolysis.py
new file mode 100755
index 0000000000000000000000000000000000000000..04615bc0339f0ac675a135a51bfe88ff475e4e62
--- /dev/null
+++ b/a2d2_pep/pep_scoring/functions/hemolysis.py
@@ -0,0 +1,63 @@
+import xgboost as xgb
+import torch
+import numpy as np
+from transformers import AutoModelForMaskedLM
+import warnings
+import numpy as np
+from rdkit import rdBase
+
+rdBase.DisableLog('rdApp.error')
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+class Hemolysis:
+
+ def __init__(self, tokenizer, base_path, device=None, emb_model=None):
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
+ self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/hemolysis-xgboost.json')
+ self.emb_model = emb_model if emb_model is not None else AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device).eval()
+ self.tokenizer = tokenizer
+
+ def generate_embeddings(self, sequences):
+ embeddings = []
+ for sequence in sequences:
+ tokenized = self.tokenizer(sequence, return_tensors='pt')
+ tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
+ with torch.no_grad():
+ output = self.emb_model(**tokenized)
+ # Mean pooling across sequence length
+ embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
+ embeddings.append(embedding)
+ return np.array(embeddings)
+
+ def get_scores(self, input_seqs: list):
+ scores = np.ones(len(input_seqs))
+ features = self.generate_embeddings(input_seqs)
+
+ if len(features) == 0:
+ return scores
+
+ features = np.nan_to_num(features, nan=0.)
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+
+ features = xgb.DMatrix(features)
+
+ probs = self.predictor.predict(features)
+ # return the probability of it being not hemolytic
+ return scores - probs
+
+ def __call__(self, input_seqs: list):
+ scores = self.get_scores(input_seqs)
+ return scores
+
+def unittest():
+ hemo = Hemolysis()
+ seq = ["[te]NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
+ print(hemo.tokenizer.vocab_size)
+ scores = hemo(input_seqs=seq)
+ print(scores)
+
+
+if __name__ == '__main__':
+ unittest()
\ No newline at end of file
diff --git a/a2d2_pep/pep_scoring/functions/nonfouling.py b/a2d2_pep/pep_scoring/functions/nonfouling.py
new file mode 100755
index 0000000000000000000000000000000000000000..4e76d8e99ca2e02d0bcac9f3eb00393e9f74e11c
--- /dev/null
+++ b/a2d2_pep/pep_scoring/functions/nonfouling.py
@@ -0,0 +1,66 @@
+import sys
+import os
+import xgboost as xgb
+import torch
+import numpy as np
+from transformers import AutoModelForMaskedLM
+import warnings
+import numpy as np
+from rdkit import Chem, rdBase, DataStructs
+
+
+rdBase.DisableLog('rdApp.error')
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+class Nonfouling:
+
+ def __init__(self, tokenizer, base_path, device=None, emb_model=None):
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
+ self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/nonfouling-xgboost.json')
+ self.emb_model = emb_model if emb_model is not None else AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device).eval()
+ self.tokenizer = tokenizer
+
+ def generate_embeddings(self, sequences):
+ embeddings = []
+ for sequence in sequences:
+ tokenized = self.tokenizer(sequence, return_tensors='pt')
+ tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
+ with torch.no_grad():
+ output = self.emb_model(**tokenized)
+ # Mean pooling across sequence length
+ embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
+ embeddings.append(embedding)
+ return np.array(embeddings)
+
+ def get_scores(self, input_seqs: list):
+ scores = np.zeros(len(input_seqs))
+ features = self.generate_embeddings(input_seqs)
+
+ if len(features) == 0:
+ return scores
+
+ features = np.nan_to_num(features, nan=0.)
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+
+ features = xgb.DMatrix(features)
+
+ scores = self.predictor.predict(features)
+ # return the probability of it being not hemolytic
+ return scores
+
+ def __call__(self, input_seqs: list):
+ scores = self.get_scores(input_seqs)
+ return scores
+
+def unittest():
+ nf = Nonfouling()
+ seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
+
+ scores = nf(input_seqs=seq)
+ print(scores)
+
+
+if __name__ == '__main__':
+ unittest()
\ No newline at end of file
diff --git a/a2d2_pep/pep_scoring/functions/permeability.py b/a2d2_pep/pep_scoring/functions/permeability.py
new file mode 100755
index 0000000000000000000000000000000000000000..1a9909a42706078679c9cb1c3bc9db26eea60525
--- /dev/null
+++ b/a2d2_pep/pep_scoring/functions/permeability.py
@@ -0,0 +1,170 @@
+import sys
+import os
+import xgboost as xgb
+import torch
+import numpy as np
+from transformers import AutoModelForMaskedLM
+import warnings
+import numpy as np
+from rdkit.Chem import Descriptors, rdMolDescriptors
+from rdkit import Chem, rdBase, DataStructs
+from rdkit.Chem import AllChem
+from typing import List
+
+
+rdBase.DisableLog('rdApp.error')
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+def fingerprints_from_smiles(smiles: List, size=2048):
+ """ Create ECFP fingerprints of smiles, with validity check """
+ fps = []
+ valid_mask = []
+ for i, smile in enumerate(smiles):
+ mol = Chem.MolFromSmiles(smile)
+ valid_mask.append(int(mol is not None))
+ fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
+ fps.append(fp)
+
+ fps = np.concatenate(fps, axis=0)
+ return fps, valid_mask
+
+
+def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
+ """ Create ECFP fingerprint of a molecule """
+ if hashed:
+ fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
+ else:
+ fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
+ fp_np = np.zeros((1,))
+ DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
+ return fp_np.reshape(1, -1)
+
+def getMolDescriptors(mol, missingVal=0):
+ """ calculate the full list of descriptors for a molecule """
+
+ values, names = [], []
+ for nm, fn in Descriptors._descList:
+ try:
+ val = fn(mol)
+ except:
+ val = missingVal
+ values.append(val)
+ names.append(nm)
+
+ custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
+ 'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
+ 'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}
+
+ for nm, fn in custom_descriptors.items():
+ try:
+ val = fn(mol)
+ except:
+ val = missingVal
+ values.append(val)
+ names.append(nm)
+ return values, names
+
+def get_pep_dps_from_smi(smi):
+ try:
+ mol = Chem.MolFromSmiles(smi)
+ except:
+ print(f"convert smi {smi} to molecule failed!")
+ mol = None
+
+ dps, _ = getMolDescriptors(mol)
+ return np.array(dps)
+
+
+def get_pep_dps(smi_list):
+ if len(smi_list) == 0:
+ return np.zeros((0, 213))
+ return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
+
+def check_smi_validity(smiles: list):
+ valid_smi, valid_idx = [], []
+ for idx, smi in enumerate(smiles):
+ try:
+ mol = Chem.MolFromSmiles(smi) if smi else None
+ if mol:
+ valid_smi.append(smi)
+ valid_idx.append(idx)
+ except Exception as e:
+ # logger.debug(f'Error: {e} in smiles {smi}')
+ pass
+ return valid_smi, valid_idx
+
+class Permeability:
+
+ def __init__(self, tokenizer, base_path, device=None, emb_model=None):
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
+ self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/permeability-xgboost.json')
+ if emb_model is not None:
+ self.emb_model = emb_model.to(self.device).eval()
+ else:
+ self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device).eval()
+
+ self.tokenizer = tokenizer
+
+ def generate_embeddings(self, sequences):
+ embeddings = []
+ for sequence in sequences:
+ tokenized = self.tokenizer(sequence, return_tensors='pt')
+ tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
+ with torch.no_grad():
+ output = self.emb_model(**tokenized)
+ # Mean pooling across sequence length
+ embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
+ embeddings.append(embedding)
+ return np.array(embeddings)
+
+ def get_features(self, input_seqs: list, dps=False, fps=False):
+ #valid_smiles, valid_idxes = check_smi_validity(input_seqs)
+
+
+ if fps:
+ fingerprints = fingerprints_from_smiles(input_seqs)[0]
+ else:
+ fingerprints = torch.empty((len(input_seqs), 0))
+
+ if dps:
+ descriptors = get_pep_dps(input_seqs)
+ else:
+ descriptors = torch.empty((len(input_seqs), 0))
+
+ embeddings = self.generate_embeddings(input_seqs)
+ # logger.debug(f'X_fps.shape: {X_fps.shape}, X_dps.shape: {X_dps.shape}')
+
+ features = np.concatenate([fingerprints, descriptors, embeddings], axis=1)
+
+ return features
+
+ def get_scores(self, input_seqs: list):
+ scores = -10 * np.ones(len(input_seqs))
+ features = self.get_features(input_seqs)
+
+ if len(features) == 0:
+ return scores
+
+ features = np.nan_to_num(features, nan=0.)
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+
+ features = xgb.DMatrix(features)
+
+ scores = self.predictor.predict(features)
+ return scores
+
+ def __call__(self, input_seqs: list):
+ scores = self.get_scores(input_seqs)
+ return scores
+
+def unittest():
+ permeability = Permeability()
+ seq = ['N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1cNc2c1cc(O)cc2)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H]([C@@H](O)C(C)C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O']
+ scores = permeability(input_seqs=seq)
+ print(scores)
+
+
+if __name__ == '__main__':
+ unittest()
\ No newline at end of file
diff --git a/a2d2_pep/pep_scoring/functions/scoring_utils.py b/a2d2_pep/pep_scoring/functions/scoring_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..911c29708d61288723865f71f2915990ca2e165f
--- /dev/null
+++ b/a2d2_pep/pep_scoring/functions/scoring_utils.py
@@ -0,0 +1,94 @@
+import warnings
+import numpy as np
+from loguru import logger
+from sklearn.ensemble import RandomForestRegressor
+from rdkit.Chem import Descriptors, rdMolDescriptors
+import joblib
+from rdkit import Chem, rdBase, DataStructs
+from rdkit.Chem import AllChem
+from typing import List
+
+
+def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
+ """
+ Create ECFP fingerprint of a molecule
+ """
+ if hashed:
+ fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
+ else:
+ fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
+ fp_np = np.zeros((1,))
+ DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
+ return fp_np.reshape(1, -1)
+
+
+def fingerprints_from_smiles(smiles: List, size=2048):
+ """ Create ECFP fingerprints of smiles, with validity check """
+ fps = []
+ valid_mask = []
+ for i, smile in enumerate(smiles):
+ mol = Chem.MolFromSmiles(smile)
+ valid_mask.append(int(mol is not None))
+ fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
+ fps.append(fp)
+
+ fps = np.concatenate(fps, axis=0) if len(fps) > 0 else np.zeros((0, size))
+ return fps, valid_mask
+
+
+def getMolDescriptors(mol, missingVal=0):
+ """ calculate the full list of descriptors for a molecule """
+
+ values, names = [], []
+ for nm, fn in Descriptors._descList:
+ try:
+ val = fn(mol)
+ except:
+ val = missingVal
+ values.append(val)
+ names.append(nm)
+
+ custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
+ 'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
+ 'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}
+
+ for nm, fn in custom_descriptors.items():
+ try:
+ val = fn(mol)
+ except:
+ val = missingVal
+ values.append(val)
+ names.append(nm)
+ return values, names
+
+
+def get_pep_dps_from_smi(smi):
+ try:
+ mol = Chem.MolFromSmiles(smi)
+ except:
+ print(f"convert smi {smi} to molecule failed!")
+ mol = None
+
+ dps, _ = getMolDescriptors(mol)
+ return np.array(dps)
+
+
+def get_pep_dps(smi_list):
+ if len(smi_list) == 0:
+ return np.zeros((0, 211))
+ return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
+
+
+
+def check_smi_validity(smiles: list):
+ valid_smi, valid_idx = [], []
+ for idx, smi in enumerate(smiles):
+ try:
+ mol = Chem.MolFromSmiles(smi) if smi else None
+ if mol:
+ valid_smi.append(smi)
+ valid_idx.append(idx)
+ except Exception as e:
+ # logger.debug(f'Error: {e} in smiles {smi}')
+ pass
+ return valid_smi, valid_idx
\ No newline at end of file
diff --git a/a2d2_pep/pep_scoring/functions/solubility.py b/a2d2_pep/pep_scoring/functions/solubility.py
new file mode 100755
index 0000000000000000000000000000000000000000..80a71d0cb7ee6fc5a3ad4a2f04dc631932bb8ff0
--- /dev/null
+++ b/a2d2_pep/pep_scoring/functions/solubility.py
@@ -0,0 +1,63 @@
+import xgboost as xgb
+import torch
+import numpy as np
+from transformers import AutoModelForMaskedLM
+import warnings
+import numpy as np
+from rdkit import rdBase
+
+rdBase.DisableLog('rdApp.error')
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+class Solubility:
+ def __init__(self, tokenizer, base_path, device=None, emb_model=None):
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
+ self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/solubility-xgboost.json')
+ if emb_model is not None:
+ self.emb_model = emb_model.to(self.device).eval()
+ else:
+ self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
+
+ self.tokenizer = tokenizer
+
+ def generate_embeddings(self, sequences):
+ embeddings = []
+ for sequence in sequences:
+ tokenized = self.tokenizer(sequence, return_tensors='pt')
+ tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
+ with torch.no_grad():
+ output = self.emb_model(**tokenized)
+ # Mean pooling across sequence length
+ embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
+ embeddings.append(embedding)
+ return np.array(embeddings)
+
+ def get_scores(self, input_seqs: list):
+ scores = np.zeros(len(input_seqs))
+ features = self.generate_embeddings(input_seqs)
+
+ if len(features) == 0:
+ return scores
+
+ features = np.nan_to_num(features, nan=0.)
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+
+ features = xgb.DMatrix(features)
+
+ scores = self.predictor.predict(features)
+ return scores
+
+ def __call__(self, input_seqs: list):
+ scores = self.get_scores(input_seqs)
+ return scores
+
+def unittest():
+ solubility = Solubility()
+ seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
+ scores = solubility(input_seqs=seq)
+ print(scores)
+
+if __name__ == '__main__':
+ unittest()
\ No newline at end of file
diff --git a/a2d2_pep/pep_scoring/scoring_functions.py b/a2d2_pep/pep_scoring/scoring_functions.py
new file mode 100755
index 0000000000000000000000000000000000000000..45f3cdc14245a05d29bcd0177fe39a3e6c819903
--- /dev/null
+++ b/a2d2_pep/pep_scoring/scoring_functions.py
@@ -0,0 +1,79 @@
+import os
+from .tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
+from transformers import AutoModelForMaskedLM
+import numpy as np
+from .functions.binding import BindingAffinity
+from .functions.permeability import Permeability
+from .functions.solubility import Solubility
+from .functions.hemolysis import Hemolysis
+from .functions.nonfouling import Nonfouling
+
+# base path: this package directory (holds tokenizer/ and functions/classifiers/)
+base_path = os.path.dirname(os.path.abspath(__file__))
+
+class ScoringFunctions:
+ def __init__(self, score_func_names=None, prot_seqs=None, device=None):
+ """
+ Class for generating score vectors given generated sequence
+
+ Args:
+ score_func_names: list of scoring function names to be evaluated
+ score_weights: weights to scale scores (default: 1)
+ target_protein: sequence of target protein binder
+ """
+ emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device).eval()
+ tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/tokenizer/new_vocab.txt',
+ f'{base_path}/tokenizer/new_splits.txt')
+ prot_seqs = prot_seqs if prot_seqs is not None else []
+
+ if score_func_names is None:
+ # just do unmasking based on validity of peptide bonds
+ self.score_func_names = []
+ else:
+ self.score_func_names = score_func_names
+
+ # self.weights = np.array([1] * len(self.score_func_names) if score_weights is None else score_weights)
+
+ # binding affinities
+ self.target_protein = prot_seqs
+ print(len(prot_seqs))
+
+ if ('binding_affinity1' in score_func_names) and (len(prot_seqs) == 1):
+ binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
+ binding_affinity2 = None
+ elif ('binding_affinity1' in score_func_names) and ('binding_affinity2' in score_func_names) and (len(prot_seqs) == 2):
+ binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
+ binding_affinity2 = BindingAffinity(prot_seqs[1], tokenizer=tokenizer, base_path=base_path, device=device)
+ else:
+ print("here")
+ binding_affinity1 = None
+ binding_affinity2 = None
+
+ permeability = Permeability(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
+ sol = Solubility(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
+ nonfouling = Nonfouling(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
+ hemo = Hemolysis(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
+
+ self.all_funcs = {'binding_affinity1': binding_affinity1,
+ 'binding_affinity2': binding_affinity2,
+ 'permeability': permeability,
+ 'nonfouling': nonfouling,
+ 'solubility': sol,
+ 'hemolysis': hemo
+ }
+
+ def forward(self, input_seqs):
+ scores = []
+
+ for i, score_func in enumerate(self.score_func_names):
+ score = self.all_funcs[score_func](input_seqs = input_seqs)
+
+ scores.append(score)
+
+ # convert to numpy arrays with shape (num_sequences, num_functions)
+ scores = np.float32(scores).T
+
+ return scores
+
+ def __call__(self, input_seqs: list):
+ return self.forward(input_seqs)
\ No newline at end of file
diff --git a/a2d2_pep/pep_scoring/tokenizer/my_tokenizers.py b/a2d2_pep/pep_scoring/tokenizer/my_tokenizers.py
new file mode 100755
index 0000000000000000000000000000000000000000..0ada3e41d16041e8e22d37b4e5fd8303b9f00491
--- /dev/null
+++ b/a2d2_pep/pep_scoring/tokenizer/my_tokenizers.py
@@ -0,0 +1,424 @@
+import collections
+import os
+import re
+from typing import List, Optional
+from transformers import PreTrainedTokenizer
+from SmilesPE.tokenizer import SPE_Tokenizer
+import torch
+
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ with open(vocab_file, "r", encoding="utf-8") as reader:
+ tokens = reader.readlines()
+ for index, token in enumerate(tokens):
+ token = token.rstrip("\n")
+ vocab[token] = index
+ return vocab
+
+class Atomwise_Tokenizer(object):
+ """Run atom-level SMILES tokenization"""
+
+ def __init__(self):
+ """ Constructs a atom-level Tokenizer.
+ """
+ # self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
+ self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
+
+ self.regex = re.compile(self.regex_pattern)
+
+ def tokenize(self, text):
+ """ Basic Tokenization of a SMILES.
+ """
+ tokens = [token for token in self.regex.findall(text)]
+ return tokens
+
+class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
+ r"""
+ Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
+ should refer to the superclass for more information regarding methods.
+ Args:
+ vocab_file (:obj:`string`):
+ File containing the vocabulary.
+ spe_file (:obj:`string`):
+ File containing the trained SMILES Pair Encoding vocabulary.
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+ for sequence classification or for a text and a question for question answering.
+ It is also used as the last token of a sequence built with special tokens.
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
+ The classifier token which is used when doing sequence classification (classification of the whole
+ sequence instead of per-token classification). It is the first token of the sequence when built with
+ special tokens.
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ """
+
+ def __init__(self, vocab_file, spe_file,
+ unk_token="[UNK]",
+ sep_token="[SEP]",
+ pad_token="[PAD]",
+ cls_token="[CLS]",
+ mask_token="[MASK]",
+ **kwargs):
+ if not os.path.isfile(vocab_file):
+ raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
+ if not os.path.isfile(spe_file):
+ raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))
+
+ self.vocab = load_vocab(vocab_file)
+ self.spe_vocab = open(spe_file, 'r', encoding='utf-8')
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+ self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab)
+
+ super().__init__(
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ **kwargs)
+
+ @property
+ def vocab_size(self):
+ return len(self.vocab)
+
+ def get_vocab(self):
+ return dict(self.vocab, **self.added_tokens_encoder)
+
+ def _tokenize(self, text):
+ return self.spe_tokenizer.tokenize(text).split(' ')
+
+ def _convert_token_to_id(self, token):
+ """ Converts a token (str) in an id using the vocab. """
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+ # changed encode and decode functions
+ def encode(self, token_array):
+ token_ids = []
+ token_ids.append(2)
+ for token in token_array:
+ id = self._convert_token_to_id(token)
+ token_ids.append(id)
+ token_ids.append(3)
+ token_ids = torch.tensor([token_ids])
+ attn_mask = torch.ones_like(token_ids)
+ return {'input_ids': token_ids, 'attention_mask': attn_mask}
+
+ def decode(self, token_ids, skip_special_tokens=True):
+ token_ids = token_ids.squeeze(0).cpu().tolist()
+ token_array = []
+ for idx in token_ids:
+ if idx == 3: # Stop decoding when token ID 3 is encountered
+ break
+ if skip_special_tokens and idx in self.all_special_ids:
+ continue
+ token = self._convert_id_to_token(idx)
+ token_array.append(token)
+ sequence = "".join(token_array)
+ return sequence
+
+ def batch_decode(self, batch_token_ids, skip_special_tokens=True):
+ sequences = []
+ for token_ids in batch_token_ids:
+ sequences.append(self.decode(token_ids))
+ return sequences
+
+ def get_token_split(self, token_ids):
+ if isinstance(token_ids, torch.Tensor):
+ token_ids = token_ids.cpu().tolist()
+
+ token_array = []
+ for seq_ids in token_ids:
+ seq_array = []
+ for id in seq_ids:
+ token = self._convert_id_to_token(id)
+ seq_array.append(token)
+ token_array.append(seq_array)
+
+ return token_array
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.ids_to_tokens.get(index, self.unk_token)
+
+ def convert_tokens_to_string(self, tokens):
+ """ Converts a sequence of tokens (string) in a single string. """
+ out_string = " ".join(tokens).replace(" ##", "").strip()
+ return out_string
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+ by concatenating and adding special tokens.
+ A BERT sequence has the following format:
+ - single sequence: ``[CLS] X [SEP]``
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs to which the special tokens will be added
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second list of IDs for sequence pairs.
+ Returns:
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer ``prepare_for_model`` method.
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of ids.
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Set to True if the token list is already formatted with special tokens for the model
+ Returns:
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ if token_ids_1 is not None:
+ raise ValueError(
+ "You should not supply a second sequence if the provided sequence of "
+ "ids is already formated with special tokens for the model."
+ )
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
+ A BERT sequence pair mask has the following format:
+ ::
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of ids.
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second list of IDs for sequence pairs.
+ Returns:
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
+ sequence(s).
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
+ def save_vocabulary(self, vocab_path):
+ """
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
+ Args:
+ vocab_path (:obj:`str`):
+ The directory in which to save the vocabulary.
+ Returns:
+ :obj:`Tuple(str)`: Paths to the files saved.
+ """
+ index = 0
+ vocab_file = vocab_path
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ index = token_index
+ writer.write(token + "\n")
+ index += 1
+ return (vocab_file,)
+
+class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
+ r"""
+ Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
+ should refer to the superclass for more information regarding methods.
+ Args:
+ vocab_file (:obj:`string`):
+ File containing the vocabulary.
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+ for sequence classification or for a text and a question for question answering.
+ It is also used as the last token of a sequence built with special tokens.
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
+ The classifier token which is used when doing sequence classification (classification of the whole
+ sequence instead of per-token classification). It is the first token of the sequence when built with
+ special tokens.
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ """
+
+ def __init__(
+ self,
+ vocab_file,
+ unk_token="[UNK]",
+ sep_token="[SEP]",
+ pad_token="[PAD]",
+ cls_token="[CLS]",
+ mask_token="[MASK]",
+ **kwargs
+ ):
+ super().__init__(
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ **kwargs,
+ )
+
+ if not os.path.isfile(vocab_file):
+ raise ValueError(
+ "Can't find a vocabulary file at path '{}'.".format(vocab_file)
+ )
+ self.vocab = load_vocab(vocab_file)
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+ self.tokenizer = Atomwise_Tokenizer()
+
+ @property
+ def vocab_size(self):
+ return len(self.vocab)
+
+ def get_vocab(self):
+ return dict(self.vocab, **self.added_tokens_encoder)
+
+
+ def _tokenize(self, text):
+ return self.tokenizer.tokenize(text)
+
+ def _convert_token_to_id(self, token):
+ """ Converts a token (str) in an id using the vocab. """
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.ids_to_tokens.get(index, self.unk_token)
+
+ def convert_tokens_to_string(self, tokens):
+ """ Converts a sequence of tokens (string) in a single string. """
+ out_string = " ".join(tokens).replace(" ##", "").strip()
+ return out_string
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+ by concatenating and adding special tokens.
+ A BERT sequence has the following format:
+ - single sequence: ``[CLS] X [SEP]``
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs to which the special tokens will be added
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second list of IDs for sequence pairs.
+ Returns:
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer ``prepare_for_model`` method.
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of ids.
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Set to True if the token list is already formatted with special tokens for the model
+ Returns:
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ if token_ids_1 is not None:
+ raise ValueError(
+ "You should not supply a second sequence if the provided sequence of "
+ "ids is already formated with special tokens for the model."
+ )
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
+ A BERT sequence pair mask has the following format:
+ ::
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of ids.
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second list of IDs for sequence pairs.
+ Returns:
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
+ sequence(s).
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
+ def save_vocabulary(self, vocab_path):
+ """
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
+ Args:
+ vocab_path (:obj:`str`):
+ The directory in which to save the vocabulary.
+ Returns:
+ :obj:`Tuple(str)`: Paths to the files saved.
+ """
+ index = 0
+ vocab_file = vocab_path
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ index = token_index
+ writer.write(token + "\n")
+ index += 1
+ return (vocab_file,)
diff --git a/a2d2_pep/pep_utils/analyzer.py b/a2d2_pep/pep_utils/analyzer.py
new file mode 100755
index 0000000000000000000000000000000000000000..1ba3d3c6fe69ea73b933225148cf9f2c8844a129
--- /dev/null
+++ b/a2d2_pep/pep_utils/analyzer.py
@@ -0,0 +1,1274 @@
+import os
+import re
+import pandas as pd
+from io import StringIO
+import rdkit
+from rdkit import Chem
+from rdkit.Chem import AllChem, Draw
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+import matplotlib.pyplot as plt
+import matplotlib.patches as patches
+from io import BytesIO
+import tempfile
+from rdkit import Chem
+
+class PeptideAnalyzer:
+ def __init__(self, min_peptide_bonds=2, enforce_min_peptide_bonds=True):
+ # length cutoff: minimum number of backbone residues (N-Cα-C(=O) units)
+
+ self.min_peptide_bonds = min_peptide_bonds
+ self.enforce_min_peptide_bonds = enforce_min_peptide_bonds
+ self.bond_patterns = [
+ (r'OC\(=O\)', 'ester'), # Ester bond
+ (r'N\(C\)C\(=O\)', 'n_methyl'), # N-methylated peptide bond
+ (r'N[0-9]C\(=O\)', 'proline'), # Proline peptide bond
+ (r'NC\(=O\)', 'peptide'), # Standard peptide bond
+ (r'C\(=O\)N\(C\)', 'n_methyl_reverse'), # Reverse N-methylated
+ (r'C\(=O\)N[12]?', 'peptide_reverse') # Reverse peptide bond
+ ]
+ # Three to one letter code mapping
+ self.three_to_one = {
+ 'Ala': 'A', 'Cys': 'C', 'Asp': 'D', 'Glu': 'E',
+ 'Phe': 'F', 'Gly': 'G', 'His': 'H', 'Ile': 'I',
+ 'Lys': 'K', 'Leu': 'L', 'Met': 'M', 'Asn': 'N',
+ 'Pro': 'P', 'Gln': 'Q', 'Arg': 'R', 'Ser': 'S',
+ 'Thr': 'T', 'Val': 'V', 'Trp': 'W', 'Tyr': 'Y'
+ }
+
+ def count_peptide_bonds(self, smiles):
+ """Count backbone peptide residues via N-Cα-C(=O) units.
+
+ Matches the backbone pattern [NX3][CX4][CX3](=O): an amide nitrogen
+ bonded to an sp3 alpha-carbon bonded to a carbonyl. Requiring the sp3
+ Cα excludes non-backbone amides — ureas/biurets (N-C(=O)-N, no Cα),
+ sulfonamides, and side-chain amides (Asn/Gln) — and uniquify=True
+ avoids the multiple-mapping over-count of symmetric N-methyl groups.
+ Each match corresponds to one backbone residue.
+ """
+ mol = Chem.MolFromSmiles(smiles)
+ if mol is None:
+ return 0
+ backbone_pattern = Chem.MolFromSmarts('[NX3][CX4][CX3](=O)')
+ return len(mol.GetSubstructMatches(backbone_pattern, uniquify=True))
+
+ def is_peptide(self, smiles):
+ """Check if the SMILES represents a peptide structure"""
+ mol = Chem.MolFromSmiles(smiles)
+ if mol is None:
+ return False
+
+ # Count backbone residues (N-Cα-C(=O) units). Requiring a real backbone
+ # unit rejects ureas/biurets and side-chain-only amides outright.
+ n_residues = self.count_peptide_bonds(smiles)
+ if n_residues == 0:
+ return False
+
+ # length cutoff: reject molecules with too few backbone residues
+ if self.enforce_min_peptide_bonds and n_residues < self.min_peptide_bonds:
+ return False
+
+ return True
+
+ def is_cyclic(self, smiles):
+ """Improved cyclic peptide detection"""
+ # Check for C-terminal carboxyl
+ if smiles.endswith('C(=O)O'):
+ return False, [], []
+
+ # Find all numbers used in ring closures
+ ring_numbers = re.findall(r'(?:^|[^c])[0-9](?=[A-Z@\(\)])', smiles)
+
+ # Find aromatic ring numbers
+ aromatic_matches = re.findall(r'c[0-9](?:ccccc|c\[nH\]c)[0-9]', smiles)
+ aromatic_cycles = []
+ for match in aromatic_matches:
+ numbers = re.findall(r'[0-9]', match)
+ aromatic_cycles.extend(numbers)
+
+ # Numbers that aren't part of aromatic rings are peptide cycles
+ peptide_cycles = [n for n in ring_numbers if n not in aromatic_cycles]
+
+ is_cyclic = len(peptide_cycles) > 0 and not smiles.endswith('C(=O)O')
+ return is_cyclic, peptide_cycles, aromatic_cycles
+
+ def split_on_bonds(self, smiles):
+ """Split SMILES into segments with simplified Pro handling"""
+ positions = []
+ used = set()
+
+ # Find Gly pattern first
+ gly_pattern = r'NCC\(=O\)'
+ for match in re.finditer(gly_pattern, smiles):
+ if not any(p in range(match.start(), match.end()) for p in used):
+ positions.append({
+ 'start': match.start(),
+ 'end': match.end(),
+ 'type': 'gly',
+ 'pattern': match.group()
+ })
+ used.update(range(match.start(), match.end()))
+
+ for pattern, bond_type in self.bond_patterns:
+ for match in re.finditer(pattern, smiles):
+ if not any(p in range(match.start(), match.end()) for p in used):
+ positions.append({
+ 'start': match.start(),
+ 'end': match.end(),
+ 'type': bond_type,
+ 'pattern': match.group()
+ })
+ used.update(range(match.start(), match.end()))
+
+ # Sort by position
+ positions.sort(key=lambda x: x['start'])
+
+ # Create segments
+ segments = []
+
+ if positions:
+ # First segment
+ if positions[0]['start'] > 0:
+ segments.append({
+ 'content': smiles[0:positions[0]['start']],
+ 'bond_after': positions[0]['pattern']
+ })
+
+ # Process segments
+ for i in range(len(positions)-1):
+ current = positions[i]
+ next_pos = positions[i+1]
+
+ if current['type'] == 'gly':
+ segments.append({
+ 'content': 'NCC(=O)',
+ 'bond_before': positions[i-1]['pattern'] if i > 0 else None,
+ 'bond_after': next_pos['pattern']
+ })
+ else:
+ content = smiles[current['end']:next_pos['start']]
+ if content:
+ segments.append({
+ 'content': content,
+ 'bond_before': current['pattern'],
+ 'bond_after': next_pos['pattern']
+ })
+
+ # Last segment
+ if positions[-1]['end'] < len(smiles):
+ segments.append({
+ 'content': smiles[positions[-1]['end']:],
+ 'bond_before': positions[-1]['pattern']
+ })
+
+ return segments
+
+ def clean_terminal_carboxyl(self, segment):
+ """Remove C-terminal carboxyl only if it's the true terminus"""
+ content = segment['content']
+
+ # Only clean if:
+ # 1. Contains C(=O)O
+ # 2. No bond_after exists (meaning it's the last segment)
+ # 3. C(=O)O is at the end of the content
+ if 'C(=O)O' in content and not segment.get('bond_after'):
+ print('recognized?')
+ # Remove C(=O)O pattern regardless of position
+ cleaned = re.sub(r'\(C\(=O\)O\)', '', content)
+ # Remove any leftover empty parentheses
+ cleaned = re.sub(r'\(\)', '', cleaned)
+ print(cleaned)
+ return cleaned
+ return content
+
+ def identify_residue(self, segment):
+ """Identify residue with Pro reconstruction"""
+ # Only clean terminal carboxyl if this is the last segment
+ content = self.clean_terminal_carboxyl(segment)
+ mods = self.get_modifications(segment)
+
+ # UAA pattern matching section - before regular residues
+ # Phenylglycine and derivatives
+ if 'c1ccccc1' in content:
+ if '[C@@H](c1ccccc1)' in content or '[C@H](c1ccccc1)' in content:
+ return '4', mods # Base phenylglycine
+
+ # 4-substituted phenylalanines
+ if 'Cc1ccc' in content:
+ if 'OMe' in content or 'OCc1ccc' in content:
+ return '0A1', mods # 4-methoxy-Phenylalanine
+ elif 'Clc1ccc' in content:
+ return '200', mods # 4-chloro-Phenylalanine
+ elif 'Brc1ccc' in content:
+ return '4BF', mods # 4-Bromo-phenylalanine
+ elif 'C#Nc1ccc' in content:
+ return '4CF', mods # 4-cyano-phenylalanine
+ elif 'Ic1ccc' in content:
+ return 'PHI', mods # 4-Iodo-phenylalanine
+ elif 'Fc1ccc' in content:
+ return 'PFF', mods # 4-Fluoro-phenylalanine
+
+ # Modified tryptophans
+ if 'c[nH]c2' in content:
+ if 'Oc2cccc2' in content:
+ return '0AF', mods # 7-hydroxy-tryptophan
+ elif 'Fc2cccc2' in content:
+ return '4FW', mods # 4-fluoro-tryptophan
+ elif 'Clc2cccc2' in content:
+ return '6CW', mods # 6-chloro-tryptophan
+ elif 'Brc2cccc2' in content:
+ return 'BTR', mods # 6-bromo-tryptophan
+ elif 'COc2cccc2' in content:
+ return 'MOT5', mods # 5-Methoxy-tryptophan
+ elif 'Cc2cccc2' in content:
+ return 'MTR5', mods # 5-Methyl-tryptophan
+
+ # Special amino acids
+ if 'CC(C)(C)[C@@H]' in content or 'CC(C)(C)[C@H]' in content:
+ return 'BUG', mods # Tertleucine
+
+ if 'CCCNC(=N)N' in content:
+ return 'CIR', mods # Citrulline
+
+ if '[SeH]' in content:
+ return 'CSE', mods # Selenocysteine
+
+ if '[NH3]CC[C@@H]' in content or '[NH3]CC[C@H]' in content:
+ return 'DAB', mods # Diaminobutyric acid
+
+ if 'C1CCCCC1' in content:
+ if 'C1CCCCC1[C@@H]' in content or 'C1CCCCC1[C@H]' in content:
+ return 'CHG', mods # Cyclohexylglycine
+ elif 'C1CCCCC1C[C@@H]' in content or 'C1CCCCC1C[C@H]' in content:
+ return 'ALC', mods # 3-cyclohexyl-alanine
+
+ # Naphthalene derivatives
+ if 'c1cccc2c1cccc2' in content:
+ if 'c1cccc2c1cccc2[C@@H]' in content or 'c1cccc2c1cccc2[C@H]' in content:
+ return 'NAL', mods # 2-Naphthyl-alanine
+
+ # Heteroaromatic derivatives
+ if 'c1cncc' in content:
+ return 'PYR4', mods # 3-(4-Pyridyl)-alanine
+ if 'c1cscc' in content:
+ return 'THA3', mods # 3-(3-thienyl)-alanine
+ if 'c1nnc' in content:
+ return 'TRZ4', mods # 3-(1,2,4-Triazol-1-yl)-alanine
+
+ # Modified serines and threonines
+ if 'OP(O)(O)O' in content:
+ if '[C@@H](COP' in content or '[C@H](COP' in content:
+ return 'SEP', mods # phosphoserine
+ elif '[C@@H](OP' in content or '[C@H](OP' in content:
+ return 'TPO', mods # phosphothreonine
+
+ # Specialized ring systems
+ if 'c1c2ccccc2cc2c1cccc2' in content:
+ return 'ANTH', mods # 3-(9-anthryl)-alanine
+ if 'c1csc2c1cccc2' in content:
+ return 'BTH3', mods # 3-(3-benzothienyl)-alanine
+ if '[C@]12C[C@H]3C[C@@H](C2)C[C@@H](C1)C3' in content:
+ return 'ADAM', mods # Adamanthane
+
+ # Fluorinated derivatives
+ if 'FC(F)(F)' in content:
+ if 'CC(F)(F)F' in content:
+ return 'FLA', mods # Trifluoro-alanine
+ if 'C(F)(F)F)c1' in content:
+ if 'c1ccccc1C(F)(F)F' in content:
+ return 'TFG2', mods # 2-(Trifluoromethyl)-phenylglycine
+ if 'c1cccc(c1)C(F)(F)F' in content:
+ return 'TFG3', mods # 3-(Trifluoromethyl)-phenylglycine
+ if 'c1ccc(cc1)C(F)(F)F' in content:
+ return 'TFG4', mods # 4-(Trifluoromethyl)-phenylglycine
+
+ # Multiple halogen patterns
+ if 'F' in content and 'c1' in content:
+ if 'c1ccc(c(c1)F)F' in content:
+ return 'F2F', mods # 3,4-Difluoro-phenylalanine
+ if 'cc(F)cc(c1)F' in content:
+ return 'WFP', mods # 3,5-Difluoro-phenylalanine
+ if 'Cl' in content and 'c1' in content:
+ if 'c1ccc(cc1Cl)Cl' in content:
+ return 'CP24', mods # 2,4-dichloro-phenylalanine
+ if 'c1ccc(c(c1)Cl)Cl' in content:
+ return 'CP34', mods # 3,4-dichloro-phenylalanine
+
+ # Hydroxy and amino derivatives
+ if 'O' in content and 'c1' in content:
+ if 'c1cc(O)cc(c1)O' in content:
+ return '3FG', mods # (2s)-amino(3,5-dihydroxyphenyl)-ethanoic acid
+ if 'c1ccc(c(c1)O)O' in content:
+ return 'DAH', mods # 3,4-Dihydroxy-phenylalanine
+
+ # Cyclic amino acids
+ if 'C1CCCC1' in content:
+ return 'CPA3', mods # 3-Cyclopentyl-alanine
+ if 'C1CCCCC1' in content:
+ if 'CC1CCCCC1' in content:
+ return 'ALC', mods # 3-cyclohexyl-alanine
+ else:
+ return 'CHG', mods # Cyclohexylglycine
+
+ # Chain-length variants
+ if 'CCC[C@@H]' in content or 'CCC[C@H]' in content:
+ return 'NLE', mods # Norleucine
+ if 'CC[C@@H]' in content or 'CC[C@H]' in content:
+ if not any(x in content for x in ['CC(C)', 'COC', 'CN(']):
+ return 'ABA', mods # 2-Aminobutyric acid
+
+ # Modified histidines
+ if 'c1cnc' in content:
+ if '[C@@H]1CN[C@@H](N1)F' in content:
+ return '2HF', mods # 2-fluoro-l-histidine
+ if 'c1cnc([nH]1)F' in content:
+ return '2HF1', mods # 2-fluoro-l-histidine variant
+ if 'c1c[nH]c(n1)F' in content:
+ return '2HF2', mods # 2-fluoro-l-histidine variant
+
+ # Sulfur and selenium containing
+ if '[SeH]' in content:
+ return 'CSE', mods # Selenocysteine
+ if 'S' in content:
+ if 'CSCc1ccccc1' in content:
+ return 'BCS', mods # benzylcysteine
+ if 'CCSC' in content:
+ return 'ESC', mods # Ethionine
+ if 'CCS' in content:
+ return 'HCS', mods # homocysteine
+
+ # Additional modifications
+ if 'CN=[N]=N' in content:
+ return 'AZDA', mods # azido-alanine
+ if '[NH]=[C](=[NH2])=[NH2]' in content:
+ if 'CCC[NH]=' in content:
+ return 'AGM', mods # 5-methyl-arginine
+ if 'CC[NH]=' in content:
+ return 'GDPR', mods # 2-Amino-3-guanidinopropionic acid
+
+ if 'CCON' in content:
+ return 'CAN', mods # canaline
+ if '[C@@H]1C=C[C@@H](C=C1)' in content:
+ return 'ACZ', mods # cis-amiclenomycin
+ if 'CCC(=O)[NH3]' in content:
+ return 'ONL', mods # 5-oxo-l-norleucine
+ if 'c1ccncc1' in content:
+ return 'PYR4', mods # 3-(4-Pyridyl)-alanine
+ if 'c1ccco1' in content:
+ return 'FUA2', mods # (2-furyl)-alanine
+
+ if 'c1ccc' in content:
+ if 'c1ccc(cc1)c1ccccc1' in content:
+ return 'BIF', mods # 4,4-biphenylalanine
+ if 'c1ccc(cc1)C(=O)c1ccccc1' in content:
+ return 'PBF', mods # 4-benzoyl-phenylalanine
+ if 'c1ccc(cc1)C(C)(C)C' in content:
+ return 'TBP4', mods # 4-tert-butyl-phenylalanine
+ if 'c1ccc(cc1)[C](=[NH2])=[NH2]' in content:
+ return '0BN', mods # 4-carbamimidoyl-l-phenylalanine
+ if 'c1cccc(c1)[C](=[NH2])=[NH2]' in content:
+ return 'APM', mods # m-amidinophenyl-3-alanine
+
+ # Multiple hydroxy patterns
+ if 'O' in content:
+ if '[C@H]([C@H](C)O)O' in content:
+ return 'ILX', mods # 4,5-dihydroxy-isoleucine
+ if '[C@H]([C@@H](C)O)O' in content:
+ return 'ALO', mods # Allo-threonine
+ if '[C@H](COP(O)(O)O)' in content:
+ return 'SEP', mods # phosphoserine
+ if '[C@H]([C@@H](C)OP(O)(O)O)' in content:
+ return 'TPO', mods # phosphothreonine
+ if '[C@H](c1ccc(O)cc1)O' in content:
+ return 'OMX', mods # (betar)-beta-hydroxy-l-tyrosine
+ if '[C@H](c1ccc(c(Cl)c1)O)O' in content:
+ return 'OMY', mods # (betar)-3-chloro-beta-hydroxy-l-tyrosine
+
+ # Heterocyclic patterns
+ if 'n1' in content:
+ if 'n1cccn1' in content:
+ return 'PYZ1', mods # 3-(1-Pyrazolyl)-alanine
+ if 'n1nncn1' in content:
+ return 'TEZA', mods # 3-(2-Tetrazolyl)-alanine
+ if 'c2c(n1)cccc2' in content:
+ return 'QU32', mods # 3-(2-Quinolyl)-alanine
+ if 'c1cnc2c(c1)cccc2' in content:
+ return 'QU33', mods # 3-(3-quinolyl)-alanine
+ if 'c1ccnc2c1cccc2' in content:
+ return 'QU34', mods # 3-(4-quinolyl)-alanine
+ if 'c1ccc2c(c1)nccc2' in content:
+ return 'QU35', mods # 3-(5-Quinolyl)-alanine
+ if 'c1ccc2c(c1)cncc2' in content:
+ return 'QU36', mods # 3-(6-Quinolyl)-alanine
+ if 'c1cnc2c(n1)cccc2' in content:
+ return 'QX32', mods # 3-(2-quinoxalyl)-alanine
+
+ # Multiple nitrogen patterns
+ if 'N' in content:
+ if '[NH3]CC[C@@H]' in content:
+ return 'DAB', mods # Diaminobutyric acid
+ if '[NH3]C[C@@H]' in content:
+ return 'DPP', mods # 2,3-Diaminopropanoic acid
+ if '[NH3]CCCCCC[C@@H]' in content:
+ return 'HHK', mods # (2s)-2,8-diaminooctanoic acid
+ if 'CCC[NH]=[C](=[NH2])=[NH2]' in content:
+ return 'GBUT', mods # 2-Amino-4-guanidinobutryric acid
+ if '[NH]=[C](=S)=[NH2]' in content:
+ return 'THIC', mods # Thio-citrulline
+
+ # Chain modified amino acids
+ if 'CC' in content:
+ if 'CCCC[C@@H]' in content:
+ return 'AHP', mods # 2-Aminoheptanoic acid
+ if 'CCC([C@@H])(C)C' in content:
+ return 'I2M', mods # 3-methyl-l-alloisoleucine
+ if 'CC[C@H]([C@@H])C' in content:
+ return 'IIL', mods # Allo-Isoleucine
+ if '[C@H](CCC(C)C)' in content:
+ return 'HLEU', mods # Homoleucine
+ if '[C@@H]([C@@H](C)O)C' in content:
+ return 'HLU', mods # beta-hydroxyleucine
+
+ # Modified glutamate/aspartate patterns
+ if '[C@@H]' in content:
+ if '[C@@H](C[C@@H](F))' in content:
+ return 'FGA4', mods # 4-Fluoro-glutamic acid
+ if '[C@@H](C[C@@H](O))' in content:
+ return '3GL', mods # 4-hydroxy-glutamic-acid
+ if '[C@@H](C[C@H](C))' in content:
+ return 'LME', mods # (3r)-3-methyl-l-glutamic acid
+ if '[C@@H](CC[C@H](C))' in content:
+ return 'MEG', mods # (3s)-3-methyl-l-glutamic acid
+
+ # Sulfur and selenium modifications
+ if 'S' in content:
+ if 'SCC[C@@H]' in content:
+ return 'HSER', mods # homoserine
+ if 'SCCN' in content:
+ return 'SLZ', mods # thialysine
+ if 'SC(=O)' in content:
+ return 'CSA', mods # s-acetonylcysteine
+ if '[S@@](=O)' in content:
+ return 'SME', mods # Methionine sulfoxide
+ if 'S(=O)(=O)' in content:
+ return 'OMT', mods # Methionine sulfone
+
+ # Double bond containing
+ if 'C=' in content:
+ if 'C=C[C@@H]' in content:
+ return '2AG', mods # 2-Allyl-glycine
+ if 'C=C[C@@H]' in content:
+ return 'LVG', mods # vinylglycine
+ if 'C=Cc1ccccc1' in content:
+ return 'STYA', mods # Styrylalanine
+
+ # Special cases
+ if '[C@@H]1Cc2c(C1)cccc2' in content:
+ return 'IGL', mods # alpha-amino-2-indanacetic acid
+ if '[C](=[C](=O)=O)=O' in content:
+ return '26P', mods # 2-amino-6-oxopimelic acid
+ if '[C](=[C](=O)=O)=C' in content:
+ return '2NP', mods # l-2-amino-6-methylene-pimelic acid
+ if 'c2cnc[nH]2' in content:
+ return 'HIS', mods # histidine core
+ if 'c1cccc2c1cc(O)cc2' in content:
+ return 'NAO1', mods # 5-hydroxy-1-naphthalene
+ if 'c1ccc2c(c1)cc(O)cc2' in content:
+ return 'NAO2', mods # 6-hydroxy-2-naphthalene
+
+ # Proline (P) - flexible ring numbers
+ if any([
+ # Check for any ring number in bond patterns
+ (segment.get('bond_after', '').startswith(f'N{n}C(=O)') and 'CCC' in content and
+ any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
+ for n in '123456789'
+ ]) or any([
+ # Check ending patterns with any ring number
+ (f'CCCN{n}' in content and content.endswith('=O') and
+ any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
+ for n in '123456789'
+ ]) or any([
+ # Handle CCC[C@H]n patterns
+ (content == f'CCC[C@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
+ (content == f'CCC[C@@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
+ # N-terminal Pro with any ring number
+ (f'N{n}CCC[C@H]{n}' in content) or
+ (f'N{n}CCC[C@@H]{n}' in content)
+ for n in '123456789'
+ ]):
+ return 'Pro', mods
+
+ # Tryptophan (W) - more specific indole pattern
+ if re.search(r'c[0-9]c\[nH\]c[0-9]ccccc[0-9][0-9]', content) and \
+ 'c[nH]c' in content.replace(' ', ''):
+ return 'Trp', mods
+
+ # Lysine (K) - both patterns
+ if '[C@@H](CCCCN)' in content or '[C@H](CCCCN)' in content:
+ return 'Lys', mods
+
+ # Arginine (R) - both patterns
+ if '[C@@H](CCCNC(=N)N)' in content or '[C@H](CCCNC(=N)N)' in content:
+ return 'Arg', mods
+
+ if ('C[C@H](CCCC)' in content or 'C[C@@H](CCCC)' in content) and 'CC(C)' not in content:
+ return 'Nle', mods
+
+ # Ornithine (Orn) - 3-carbon chain with NH2
+ if ('C[C@H](CCCN)' in content or 'C[C@@H](CCCN)' in content) and 'CC(C)' not in content:
+ return 'Orn', mods
+
+ # 2-Naphthylalanine (2Nal) - distinct from Phe pattern
+ if ('Cc3cc2ccccc2c3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
+ return '2Nal', mods
+
+ # Cyclohexylalanine (Cha) - already in your code but moved here for clarity
+ if 'N2CCCCC2' in content or 'CCCCC2' in content:
+ return 'Cha', mods
+
+ # Aminobutyric acid (Abu) - 2-carbon chain
+ if ('C[C@H](CC)' in content or 'C[C@@H](CC)' in content) and not any(p in content for p in ['CC(C)', 'CCCC', 'CCC(C)']):
+ return 'Abu', mods
+
+ # Pipecolic acid (Pip) - 6-membered ring like Pro
+ if ('N3CCCCC3' in content or 'CCCCC3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
+ return 'Pip', mods
+
+ # Cyclohexylglycine (Chg) - direct cyclohexyl without CH2
+ if ('C[C@H](C1CCCCC1)' in content or 'C[C@@H](C1CCCCC1)' in content):
+ return 'Chg', mods
+
+ # 4-Fluorophenylalanine (4F-Phe)
+ if ('Cc2ccc(F)cc2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
+ return '4F-Phe', mods
+
+ # Regular residue identification
+ if ('NCC(=O)' in content) or (content == 'C'):
+ # Middle case - between bonds
+ if segment.get('bond_before') and segment.get('bond_after'):
+ if ('C(=O)N' in segment['bond_before'] or 'C(=O)N(C)' in segment['bond_before']):
+ return 'Gly', mods
+ # Terminal case - at the end
+ elif segment.get('bond_before') and segment.get('bond_before').startswith('C(=O)N'):
+ return 'Gly', mods
+
+ if 'CC(C)C[C@H]' in content or 'CC(C)C[C@@H]' in content:
+ return 'Leu', mods
+ if '[C@@H](CC(C)C)' in content or '[C@H](CC(C)C)' in content:
+ return 'Leu', mods
+
+ if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content:
+ return 'Thr', mods
+
+ if '[C@H](Cc2ccccc2)' in content or '[C@@H](Cc2ccccc2)' in content:
+ return 'Phe', mods
+
+ if ('[C@H](C(C)C)' in content or # With outer parentheses
+ '[C@@H](C(C)C)' in content or # With outer parentheses
+ '[C@H]C(C)C' in content or # Without outer parentheses
+ '[C@@H]C(C)C' in content): # Without outer parentheses
+ if not any(p in content for p in ['CC(C)C[C@H]', 'CC(C)C[C@@H]']): # Still check not Leu
+ return 'Val', mods
+
+ if '[C@H](COC(C)(C)C)' in content or '[C@@H](COC(C)(C)C)' in content:
+ return 'O-tBu', mods
+
+ if any([
+ 'CC[C@H](C)' in content,
+ 'CC[C@@H](C)' in content,
+ 'C(C)C[C@H]' in content and 'CC(C)C' not in content,
+ 'C(C)C[C@@H]' in content and 'CC(C)C' not in content
+ ]):
+ return 'Ile', mods
+
+ if ('[C@H](C)' in content or '[C@@H](C)' in content):
+ if not any(p in content for p in ['C(C)C', 'COC', 'CN(', 'C(C)O', 'CC[C@H]', 'CC[C@@H]']):
+ return 'Ala', mods
+
+ # Tyrosine (Tyr) - 4-hydroxybenzyl side chain
+ if re.search(r'Cc[0-9]ccc\(O\)cc[0-9]', content):
+ return 'Tyr', mods
+
+
+ # Serine (Ser) - Hydroxymethyl side chain
+ if '[C@H](CO)' in content or '[C@@H](CO)' in content:
+ if not ('C(C)O' in content or 'COC' in content):
+ return 'Ser', mods
+
+ # Threonine (Thr) - 1-hydroxyethyl side chain
+ if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content or '[C@@H](C)O' in content or '[C@H](C)O' in content:
+ return 'Thr', mods
+
+ # Cysteine (Cys) - Thiol side chain
+ if '[C@H](CS)' in content or '[C@@H](CS)' in content:
+ return 'Cys', mods
+
+ # Methionine (Met) - Methylthioethyl side chain
+ if ('C[C@H](CCSC)' in content or 'C[C@@H](CCSC)' in content):
+ return 'Met', mods
+
+ # Asparagine (Asn) - Carbamoylmethyl side chain
+ if ('CC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
+ return 'Asn', mods
+
+ # Glutamine (Gln) - Carbamoylethyl side chain
+ if ('CCC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
+ return 'Gln', mods
+
+ # Aspartic acid (Asp) - Carboxymethyl side chain
+ if ('CC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
+ return 'Asp', mods
+
+ # Glutamic acid (Glu) - Carboxyethyl side chain
+ if ('CCC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
+ return 'Glu', mods
+
+ # Arginine (Arg) - 3-guanidinopropyl side chain
+ if ('CCCNC(=N)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
+ return 'Arg', mods
+
+ # Histidine (His) - Imidazole side chain
+ if ('Cc2cnc[nH]2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
+ return 'His', mods
+
+ return None, mods
+
+ def get_modifications(self, segment):
+ """Get modifications based on bond types"""
+ mods = []
+ if segment.get('bond_after'):
+ if 'N(C)' in segment['bond_after'] or segment['bond_after'].startswith('C(=O)N(C)'):
+ mods.append('N-Me')
+ if 'OC(=O)' in segment['bond_after']:
+ mods.append('O-linked')
+ return mods
+
+ def analyze_structure(self, smiles):
+ """Main analysis function with debug output"""
+ print("\nAnalyzing structure:", smiles)
+
+ # Split into segments
+ segments = self.split_on_bonds(smiles)
+
+ print("\nSegment Analysis:")
+ sequence = []
+ for i, segment in enumerate(segments):
+ print(f"\nSegment {i}:")
+ print(f"Content: {segment['content']}")
+ print(f"Bond before: {segment.get('bond_before', 'None')}")
+ print(f"Bond after: {segment.get('bond_after', 'None')}")
+
+ residue, mods = self.identify_residue(segment)
+ if residue:
+ if mods:
+ sequence.append(f"{residue}({','.join(mods)})")
+ else:
+ sequence.append(residue)
+ print(f"Identified as: {residue}")
+ print(f"Modifications: {mods}")
+ else:
+ print(f"Warning: Could not identify residue in segment: {segment['content']}")
+
+ # Check if cyclic
+ is_cyclic, peptide_cycles, aromatic_cycles = self.is_cyclic(smiles)
+ three_letter = '-'.join(sequence)
+ one_letter = ''.join(self.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence)
+
+ if is_cyclic:
+ three_letter = f"cyclo({three_letter})"
+ one_letter = f"cyclo({one_letter})"
+
+ print(f"\nFinal sequence: {three_letter}")
+ print(f"One-letter code: {one_letter}")
+ print(f"Is cyclic: {is_cyclic}")
+ #print(f"Peptide cycles: {peptide_cycles}")
+ #print(f"Aromatic cycles: {aromatic_cycles}")
+
+ return three_letter, len(segments)
+ """return {
+ 'three_letter': three_letter,
+ #'one_letter': one_letter,
+ 'is_cyclic': is_cyclic
+ }"""
+
+ def return_sequence(self, smiles):
+ """Main analysis function with debug output"""
+ print("\nAnalyzing structure:", smiles)
+
+ # Split into segments
+ segments = self.split_on_bonds(smiles)
+
+ print("\nSegment Analysis:")
+ sequence = []
+ for i, segment in enumerate(segments):
+ print(f"\nSegment {i}:")
+ print(f"Content: {segment['content']}")
+ print(f"Bond before: {segment.get('bond_before', 'None')}")
+ print(f"Bond after: {segment.get('bond_after', 'None')}")
+
+ residue, mods = self.identify_residue(segment)
+ if residue:
+ if mods:
+ sequence.append(f"{residue}({','.join(mods)})")
+ else:
+ sequence.append(residue)
+ print(f"Identified as: {residue}")
+ print(f"Modifications: {mods}")
+ else:
+ print(f"Warning: Could not identify residue in segment: {segment['content']}")
+
+ return sequence
+
+"""
+def annotate_cyclic_structure(mol, sequence):
+ '''Create annotated 2D structure with clear, non-overlapping residue labels'''
+ # Generate 2D coordinates
+ # Generate 2D coordinates
+ AllChem.Compute2DCoords(mol)
+
+ # Create drawer with larger size for annotations
+ drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000) # Even larger size
+
+ # Get residue list and reverse it to match structural representation
+ if sequence.startswith('cyclo('):
+ residues = sequence[6:-1].split('-')
+ else:
+ residues = sequence.split('-')
+ residues = list(reversed(residues)) # Reverse the sequence
+
+ # Draw molecule first to get its bounds
+ drawer.drawOptions().addAtomIndices = False
+ drawer.DrawMolecule(mol)
+ drawer.FinishDrawing()
+
+ # Convert to PIL Image
+ img = Image.open(BytesIO(drawer.GetDrawingText()))
+ draw = ImageDraw.Draw(img)
+
+ try:
+ # Try to use DejaVuSans as it's commonly available on Linux systems
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
+ small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
+ except OSError:
+ try:
+ # Fallback to Arial if available (common on Windows)
+ font = ImageFont.truetype("arial.ttf", 60)
+ small_font = ImageFont.truetype("arial.ttf", 60)
+ except OSError:
+ # If no TrueType fonts are available, fall back to default
+ print("Warning: TrueType fonts not available, using default font")
+ font = ImageFont.load_default()
+ small_font = ImageFont.load_default()
+ # Get molecule bounds
+ conf = mol.GetConformer()
+ positions = []
+ for i in range(mol.GetNumAtoms()):
+ pos = conf.GetAtomPosition(i)
+ positions.append((pos.x, pos.y))
+
+ x_coords = [p[0] for p in positions]
+ y_coords = [p[1] for p in positions]
+ min_x, max_x = min(x_coords), max(x_coords)
+ min_y, max_y = min(y_coords), max(y_coords)
+
+ # Calculate scaling factors
+ scale = 150 # Increased scale factor
+ center_x = 1000 # Image center
+ center_y = 1000
+
+ # Add residue labels in a circular arrangement around the structure
+ n_residues = len(residues)
+ radius = 700 # Distance of labels from center
+
+ # Start from the rightmost point (3 o'clock position) and go counterclockwise
+ # Offset by -3 positions to align with structure
+ offset = 0 # Adjust this value to match the structure alignment
+ for i, residue in enumerate(residues):
+ # Calculate position in a circle around the structure
+ # Start from 0 (3 o'clock) and go counterclockwise
+ angle = -(2 * np.pi * ((i + offset) % n_residues) / n_residues)
+
+ # Calculate label position
+ label_x = center_x + radius * np.cos(angle)
+ label_y = center_y + radius * np.sin(angle)
+
+ # Draw residue label
+ text = f"{i+1}. {residue}"
+ bbox = draw.textbbox((label_x, label_y), text, font=font)
+ padding = 10
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
+ bbox[2]+padding, bbox[3]+padding],
+ fill='white', outline='white')
+ draw.text((label_x, label_y), text,
+ font=font, fill='black', anchor="mm")
+
+ # Add sequence at the top with white background
+ seq_text = f"Sequence: {sequence}"
+ bbox = draw.textbbox((center_x, 100), seq_text, font=small_font)
+ padding = 10
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
+ bbox[2]+padding, bbox[3]+padding],
+ fill='white', outline='white')
+ draw.text((center_x, 100), seq_text,
+ font=small_font, fill='black', anchor="mm")
+
+ return img
+
+"""
+def annotate_cyclic_structure(mol, sequence):
+ """Create structure visualization with just the sequence header"""
+ # Generate 2D coordinates
+ AllChem.Compute2DCoords(mol)
+
+ # Create drawer with larger size for annotations
+ drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000)
+
+ # Draw molecule first
+ drawer.drawOptions().addAtomIndices = False
+ drawer.DrawMolecule(mol)
+ drawer.FinishDrawing()
+
+ # Convert to PIL Image
+ img = Image.open(BytesIO(drawer.GetDrawingText()))
+ draw = ImageDraw.Draw(img)
+ try:
+ small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
+ except OSError:
+ try:
+ small_font = ImageFont.truetype("arial.ttf", 60)
+ except OSError:
+ print("Warning: TrueType fonts not available, using default font")
+ small_font = ImageFont.load_default()
+
+ # Add just the sequence header at the top
+ seq_text = f"Sequence: {sequence}"
+ bbox = draw.textbbox((1000, 100), seq_text, font=small_font)
+ padding = 10
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
+ bbox[2]+padding, bbox[3]+padding],
+ fill='white', outline='white')
+ draw.text((1000, 100), seq_text,
+ font=small_font, fill='black', anchor="mm")
+
+ return img
+
+def create_enhanced_linear_viz(sequence, smiles):
+ """Create an enhanced linear representation using PeptideAnalyzer"""
+ analyzer = PeptideAnalyzer() # Create analyzer instance
+
+ # Create figure with two subplots
+ fig = plt.figure(figsize=(15, 10))
+ gs = fig.add_gridspec(2, 1, height_ratios=[1, 2])
+ ax_struct = fig.add_subplot(gs[0])
+ ax_detail = fig.add_subplot(gs[1])
+
+ # Parse sequence and get residues
+ if sequence.startswith('cyclo('):
+ residues = sequence[6:-1].split('-')
+ else:
+ residues = sequence.split('-')
+
+ # Get segments using analyzer
+ segments = analyzer.split_on_bonds(smiles)
+
+ # Debug print
+ print(f"Number of residues: {len(residues)}")
+ print(f"Number of segments: {len(segments)}")
+
+ # Top subplot - Basic structure
+ ax_struct.set_xlim(0, 10)
+ ax_struct.set_ylim(0, 2)
+
+ num_residues = len(residues)
+ spacing = 9.0 / (num_residues - 1) if num_residues > 1 else 9.0
+
+ # Draw basic structure
+ y_pos = 1.5
+ for i in range(num_residues):
+ x_pos = 0.5 + i * spacing
+
+ # Draw amino acid box
+ rect = patches.Rectangle((x_pos-0.3, y_pos-0.2), 0.6, 0.4,
+ facecolor='lightblue', edgecolor='black')
+ ax_struct.add_patch(rect)
+
+ # Draw connecting bonds if not the last residue
+ if i < num_residues - 1:
+ segment = segments[i] if i < len(segments) else None
+ if segment:
+ # Determine bond type from segment info
+ bond_type = 'ester' if 'O-linked' in segment.get('bond_after', '') else 'peptide'
+ is_n_methylated = 'N-Me' in segment.get('bond_after', '')
+
+ bond_color = 'red' if bond_type == 'ester' else 'black'
+ linestyle = '--' if bond_type == 'ester' else '-'
+
+ # Draw bond line
+ ax_struct.plot([x_pos+0.3, x_pos+spacing-0.3], [y_pos, y_pos],
+ color=bond_color, linestyle=linestyle, linewidth=2)
+
+ # Add bond type label
+ mid_x = x_pos + spacing/2
+ bond_label = f"{bond_type}"
+ if is_n_methylated:
+ bond_label += "\n(N-Me)"
+ ax_struct.text(mid_x, y_pos+0.1, bond_label,
+ ha='center', va='bottom', fontsize=10,
+ color=bond_color)
+
+ # Add residue label
+ ax_struct.text(x_pos, y_pos-0.5, residues[i],
+ ha='center', va='top', fontsize=14)
+
+ # Bottom subplot - Detailed breakdown
+ ax_detail.set_ylim(0, len(segments)+1)
+ ax_detail.set_xlim(0, 1)
+
+ # Create detailed breakdown
+ segment_y = len(segments) # Start from top
+ for i, segment in enumerate(segments):
+ y = segment_y - i
+
+ # Check if this is a bond or residue
+ residue, mods = analyzer.identify_residue(segment)
+ if residue:
+ text = f"Residue {i+1}: {residue}"
+ if mods:
+ text += f" ({', '.join(mods)})"
+ color = 'blue'
+ else:
+ # Must be a bond
+ text = f"Bond {i}: "
+ if 'O-linked' in segment.get('bond_after', ''):
+ text += "ester"
+ elif 'N-Me' in segment.get('bond_after', ''):
+ text += "peptide (N-methylated)"
+ else:
+ text += "peptide"
+ color = 'red'
+
+ # Add segment analysis
+ ax_detail.text(0.05, y, text, fontsize=12, color=color)
+ ax_detail.text(0.5, y, f"SMILES: {segment.get('content', '')}", fontsize=10, color='gray')
+
+ # If cyclic, add connection indicator
+ if sequence.startswith('cyclo('):
+ ax_struct.annotate('', xy=(9.5, y_pos), xytext=(0.5, y_pos),
+ arrowprops=dict(arrowstyle='<->', color='red', lw=2))
+ ax_struct.text(5, y_pos+0.3, 'Cyclic Connection',
+ ha='center', color='red', fontsize=14)
+
+ # Add titles and adjust layout
+ ax_struct.set_title("Peptide Structure Overview", pad=20)
+ ax_detail.set_title("Segment Analysis Breakdown", pad=20)
+
+ # Remove axes
+ for ax in [ax_struct, ax_detail]:
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.axis('off')
+
+ plt.tight_layout()
+ return fig
+
+class PeptideStructureGenerator:
+ """A class to generate 3D structures of peptides using different embedding methods"""
+
+ @staticmethod
+ def prepare_molecule(smiles):
+ """Prepare molecule with proper hydrogen handling"""
+ mol = Chem.MolFromSmiles(smiles, sanitize=False)
+ if mol is None:
+ raise ValueError("Failed to create molecule from SMILES")
+
+ # Calculate valence for each atom
+ for atom in mol.GetAtoms():
+ atom.UpdatePropertyCache(strict=False)
+
+ # Sanitize with reduced requirements
+ Chem.SanitizeMol(mol,
+ sanitizeOps=Chem.SANITIZE_FINDRADICALS|
+ Chem.SANITIZE_KEKULIZE|
+ Chem.SANITIZE_SETAROMATICITY|
+ Chem.SANITIZE_SETCONJUGATION|
+ Chem.SANITIZE_SETHYBRIDIZATION|
+ Chem.SANITIZE_CLEANUPCHIRALITY)
+
+ mol = Chem.AddHs(mol)
+ return mol
+
+ @staticmethod
+ def get_etkdg_params(attempt=0):
+ """Get ETKDG parameters with optional modifications based on attempt number"""
+ params = AllChem.ETKDGv3()
+ params.randomSeed = -1
+ params.maxIterations = 200
+ params.numThreads = 4 # Reduced for web interface
+ params.useBasicKnowledge = True
+ params.enforceChirality = True
+ params.useExpTorsionAnglePrefs = True
+ params.useSmallRingTorsions = True
+ params.useMacrocycleTorsions = True
+ params.ETversion = 2
+ params.pruneRmsThresh = -1
+ params.embedRmsThresh = 0.5
+
+ if attempt > 10:
+ params.bondLength = 1.5 + (attempt - 10) * 0.02
+ params.useExpTorsionAnglePrefs = False
+
+ return params
+
+ def generate_structure_etkdg(self, smiles, max_attempts=20):
+ """Generate 3D structure using ETKDG without UFF optimization"""
+ success = False
+ mol = None
+
+ for attempt in range(max_attempts):
+ try:
+ mol = self.prepare_molecule(smiles)
+ params = self.get_etkdg_params(attempt)
+
+ if AllChem.EmbedMolecule(mol, params) == 0:
+ success = True
+ break
+ except Exception as e:
+ continue
+
+ if not success:
+ raise ValueError("Failed to generate structure with ETKDG")
+
+ return mol
+
+ def generate_structure_uff(self, smiles, max_attempts=20):
+ """Generate 3D structure using ETKDG followed by UFF optimization"""
+ best_mol = None
+ lowest_energy = float('inf')
+
+ for attempt in range(max_attempts):
+ try:
+ test_mol = self.prepare_molecule(smiles)
+ params = self.get_etkdg_params(attempt)
+
+ if AllChem.EmbedMolecule(test_mol, params) == 0:
+ res = AllChem.UFFOptimizeMolecule(test_mol, maxIters=2000,
+ vdwThresh=10.0, confId=0,
+ ignoreInterfragInteractions=True)
+
+ if res == 0:
+ ff = AllChem.UFFGetMoleculeForceField(test_mol)
+ if ff:
+ current_energy = ff.CalcEnergy()
+ if current_energy < lowest_energy:
+ lowest_energy = current_energy
+ best_mol = Chem.Mol(test_mol)
+ except Exception:
+ continue
+
+ if best_mol is None:
+ raise ValueError("Failed to generate optimized structure")
+
+ return best_mol
+
+ @staticmethod
+ def mol_to_sdf_bytes(mol):
+ """Convert RDKit molecule to SDF file bytes"""
+ # First write to StringIO in text mode
+ sio = StringIO()
+ writer = Chem.SDWriter(sio)
+ writer.write(mol)
+ writer.close()
+
+ # Convert the string to bytes
+ return sio.getvalue().encode('utf-8')
+
+def process_input(smiles_input=None, file_obj=None, show_linear=False,
+ show_segment_details=False, generate_3d=False, use_uff=False):
+ """Process input and create visualizations using PeptideAnalyzer"""
+ analyzer = PeptideAnalyzer()
+ temp_dir = tempfile.mkdtemp() if generate_3d else None
+ structure_files = []
+
+ # Handle direct SMILES input
+ if smiles_input:
+ smiles = smiles_input.strip()
+
+ # First check if it's a peptide using analyzer's method
+ if not analyzer.is_peptide(smiles):
+ return "Error: Input SMILES does not appear to be a peptide structure.", None, None
+
+ try:
+ # Create molecule
+ mol = Chem.MolFromSmiles(smiles)
+ if mol is None:
+ return "Error: Invalid SMILES notation.", None, None
+
+ # Generate 3D structures if requested
+ if generate_3d:
+ generator = PeptideStructureGenerator()
+
+ try:
+ # Generate ETKDG structure
+ mol_etkdg = generator.generate_structure_etkdg(smiles)
+ etkdg_path = os.path.join(temp_dir, "structure_etkdg.sdf")
+ writer = Chem.SDWriter(etkdg_path)
+ writer.write(mol_etkdg)
+ writer.close()
+ structure_files.append(etkdg_path)
+
+ # Generate UFF structure if requested
+ if use_uff:
+ mol_uff = generator.generate_structure_uff(smiles)
+ uff_path = os.path.join(temp_dir, "structure_uff.sdf")
+ writer = Chem.SDWriter(uff_path)
+ writer.write(mol_uff)
+ writer.close()
+ structure_files.append(uff_path)
+
+ except Exception as e:
+ return f"Error generating 3D structures: {str(e)}", None, None, None
+
+ # Use analyzer to get sequence
+ segments = analyzer.split_on_bonds(smiles)
+
+ # Process segments and build sequence
+ sequence_parts = []
+ output_text = ""
+
+ # Only include segment analysis in output if requested
+ if show_segment_details:
+ output_text += "Segment Analysis:\n"
+ for i, segment in enumerate(segments):
+ output_text += f"\nSegment {i}:\n"
+ output_text += f"Content: {segment['content']}\n"
+ output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
+ output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
+
+ residue, mods = analyzer.identify_residue(segment)
+ if residue:
+ if mods:
+ sequence_parts.append(f"{residue}({','.join(mods)})")
+ else:
+ sequence_parts.append(residue)
+ output_text += f"Identified as: {residue}\n"
+ output_text += f"Modifications: {mods}\n"
+ else:
+ output_text += f"Warning: Could not identify residue in segment: {segment['content']}\n"
+ output_text += "\n"
+ else:
+ # Just build sequence without detailed analysis in output
+ for segment in segments:
+ residue, mods = analyzer.identify_residue(segment)
+ if residue:
+ if mods:
+ sequence_parts.append(f"{residue}({','.join(mods)})")
+ else:
+ sequence_parts.append(residue)
+
+ # Check if cyclic using analyzer's method
+ is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
+ three_letter = '-'.join(sequence_parts)
+ one_letter = ''.join(analyzer.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence_parts)
+
+ if is_cyclic:
+ three_letter = f"cyclo({three_letter})"
+ one_letter = f"cyclo({one_letter})"
+
+ # Create cyclic structure visualization
+ img_cyclic = annotate_cyclic_structure(mol, three_letter)
+
+ # Create linear representation if requested
+ img_linear = None
+ if show_linear:
+ fig_linear = create_enhanced_linear_viz(three_letter, smiles)
+ buf = BytesIO()
+ fig_linear.savefig(buf, format='png', bbox_inches='tight', dpi=300)
+ buf.seek(0)
+ img_linear = Image.open(buf)
+ plt.close(fig_linear)
+
+ # Add summary to output
+ summary = "Summary:\n"
+ summary += f"Sequence: {three_letter}\n"
+ summary += f"One-letter code: {one_letter}\n"
+ summary += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
+ #if is_cyclic:
+ #summary += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
+ #summary += f"Aromatic Cycles: {', '.join(aromatic_cycles)}\n"
+
+ if structure_files:
+ summary += "\n3D Structures Generated:\n"
+ for filepath in structure_files:
+ summary += f"- {os.path.basename(filepath)}\n"
+
+ return summary + output_text, img_cyclic, img_linear, structure_files if structure_files else None
+
+ except Exception as e:
+ return f"Error processing SMILES: {str(e)}", None, None, None
+
+ # Handle file input
+ if file_obj is not None:
+ try:
+ # Handle file content
+ if hasattr(file_obj, 'name'):
+ with open(file_obj.name, 'r') as f:
+ content = f.read()
+ else:
+ content = file_obj.decode('utf-8') if isinstance(file_obj, bytes) else str(file_obj)
+
+ output_text = ""
+ for line in content.splitlines():
+ smiles = line.strip()
+ if smiles:
+ # Check if it's a peptide
+ if not analyzer.is_peptide(smiles):
+ output_text += f"Skipping non-peptide SMILES: {smiles}\n"
+ continue
+
+ # Process this SMILES
+ segments = analyzer.split_on_bonds(smiles)
+ sequence_parts = []
+
+ # Add segment details if requested
+ if show_segment_details:
+ output_text += f"\nSegment Analysis for SMILES: {smiles}\n"
+ for i, segment in enumerate(segments):
+ output_text += f"\nSegment {i}:\n"
+ output_text += f"Content: {segment['content']}\n"
+ output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
+ output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
+ residue, mods = analyzer.identify_residue(segment)
+ if residue:
+ if mods:
+ sequence_parts.append(f"{residue}({','.join(mods)})")
+ else:
+ sequence_parts.append(residue)
+ output_text += f"Identified as: {residue}\n"
+ output_text += f"Modifications: {mods}\n"
+ else:
+ for segment in segments:
+ residue, mods = analyzer.identify_residue(segment)
+ if residue:
+ if mods:
+ sequence_parts.append(f"{residue}({','.join(mods)})")
+ else:
+ sequence_parts.append(residue)
+
+ # Get cyclicity and create sequence
+ is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
+ sequence = f"cyclo({'-'.join(sequence_parts)})" if is_cyclic else '-'.join(sequence_parts)
+
+ output_text += f"\nSummary for SMILES: {smiles}\n"
+ output_text += f"Sequence: {sequence}\n"
+ output_text += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
+ if is_cyclic:
+ output_text += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
+ #output_text += f"Aromatic Cycles: {', '.join(aromatic_cycles)}\n"
+ output_text += "-" * 50 + "\n"
+
+ return output_text, None, None
+
+ except Exception as e:
+ return f"Error processing file: {str(e)}", None, None
+
+ return "No input provided.", None, None
\ No newline at end of file
diff --git a/a2d2_pep/pep_utils/utils.py b/a2d2_pep/pep_utils/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..0795a73ff2ed657e5136ab99f83e109f5ea9174e
--- /dev/null
+++ b/a2d2_pep/pep_utils/utils.py
@@ -0,0 +1,135 @@
+"""Console logger utilities.
+
+Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
+Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
+"""
+
+import logging
+import fsspec
+import lightning
+import torch
+from timm.scheduler import CosineLRScheduler
+import argparse
+import numpy as np
+import random
+import os
+
+def sample_categorical_logits(logits, dtype=torch.float64):
+ # do not require logits to be log-softmaxed
+ gumbel_noise = -(1e-10 - (torch.rand_like(logits, dtype=dtype) + 1e-10).log()).log()
+ return (logits + gumbel_noise).argmax(dim=-1)
+
+def fsspec_exists(filename):
+ """Check if a file exists using fsspec."""
+ fs, _ = fsspec.core.url_to_fs(filename)
+ return fs.exists(filename)
+
+
+def fsspec_listdir(dirname):
+ """Listdir in manner compatible with fsspec."""
+ fs, _ = fsspec.core.url_to_fs(dirname)
+ return fs.ls(dirname)
+
+
+def fsspec_mkdirs(dirname, exist_ok=True):
+ """Mkdirs in manner compatible with fsspec."""
+ fs, _ = fsspec.core.url_to_fs(dirname)
+ fs.makedirs(dirname, exist_ok=exist_ok)
+
+
+def print_nans(tensor, name):
+ if torch.isnan(tensor).any():
+ print(name, tensor)
+
+
+class CosineDecayWarmupLRScheduler(
+ CosineLRScheduler,
+ torch.optim.lr_scheduler._LRScheduler):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._last_epoch = -1
+ self.step(epoch=0)
+
+ def step(self, epoch=None):
+ if epoch is None:
+ self._last_epoch += 1
+ else:
+ self._last_epoch = epoch
+ # We call either step or step_update, depending on
+ # whether we're using the scheduler every epoch or every
+ # step.
+ # Otherwise, lightning will always call step (i.e.,
+ # meant for each epoch), and if we set scheduler
+ # interval to "step", then the learning rate update will
+ # be wrong.
+ if self.t_in_epochs:
+ super().step(epoch=self._last_epoch)
+ else:
+ super().step_update(num_updates=self._last_epoch)
+
+
+class LoggingContext:
+ """Context manager for selective logging."""
+ def __init__(self, logger, level=None, handler=None, close=True):
+ self.logger = logger
+ self.level = level
+ self.handler = handler
+ self.close = close
+
+ def __enter__(self):
+ if self.level is not None:
+ self.old_level = self.logger.level
+ self.logger.setLevel(self.level)
+ if self.handler:
+ self.logger.addHandler(self.handler)
+
+ def __exit__(self, et, ev, tb):
+ if self.level is not None:
+ self.logger.setLevel(self.old_level)
+ if self.handler:
+ self.logger.removeHandler(self.handler)
+ if self.handler and self.close:
+ self.handler.close()
+
+
+def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
+ """Initializes multi-GPU-friendly python logger."""
+
+ logger = logging.getLogger(name)
+ logger.setLevel(level)
+
+ # this ensures all logging levels get marked with the rank zero decorator
+ # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+ for level in ('debug', 'info', 'warning', 'error',
+ 'exception', 'fatal', 'critical'):
+ setattr(logger,
+ level,
+ lightning.pytorch.utilities.rank_zero_only(
+ getattr(logger, level)))
+
+ return logger
+
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+def set_seed(seed, use_cuda):
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.manual_seed(seed)
+ # torch.backends.cudnn.deterministic = True
+ if use_cuda:
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ print(f'=> Seed of the run set to {seed}')
+
diff --git a/a2d2_pep/remasking_scheduleaware.py b/a2d2_pep/remasking_scheduleaware.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c5d1d35e76e1b52c531060b0c2a87f884a062a0
--- /dev/null
+++ b/a2d2_pep/remasking_scheduleaware.py
@@ -0,0 +1,181 @@
+"""
+Schedule-aware remasking and insertion logic that ensures the number of masked tokens
+follows the interpolant schedule.
+"""
+import torch
+import numpy as np
+
+def apply_schedule_aware_insertion(
+ model,
+ xt_tmp,
+ new_xt,
+ t,
+ dt,
+ ext,
+ mask,
+ pad,
+ max_length,
+ orig_mask,
+ new_pos_orig,
+ quality_threshold=1,
+):
+ """
+ Remove low-quality insertions based on insertion confidence while respecting
+ the interpolant schedule for expected sequence length.
+
+ Args:
+ model: Model with planner and interpolant
+ xt_tmp: Sequence after insertion [B, L]
+ new_xt: Sequence before insertion [B, L]
+ t: Current time [B]
+ dt: Time step size
+ ext: Number of insertions per gap [B, L+1]
+ mask: Mask token ID
+ pad: Pad token ID
+ max_length: Maximum sequence length
+ orig_mask: Mask of original token positions [B, L]
+ new_pos_orig: New positions of original tokens [B, L]
+ quality_threshold: If a float, drop insertions with confidence below it; if None, use schedule-driven deletion
+
+ Returns:
+ xt_tmp: Modified sequence with low-quality insertions removed (respecting schedule)
+ """
+ device = xt_tmp.device
+ batch_size, L = xt_tmp.shape
+ total_ext = ext.sum(dim=1)
+
+ # Only proceed if there were insertions
+ if total_ext.sum() == 0:
+ return xt_tmp
+
+ # Get planner predictions on inserted state. The insertion head is trained
+ # with the pre-step time t (see loss_insert_planner_flexible), so condition
+ # on t here too; t_next is still used below for the length schedule.
+ t_next = t + dt
+ planner_out = model.planner(xt_tmp, t)
+ insertion_conf = planner_out.get("insertion_conf", None)
+
+ if insertion_conf is None:
+ return xt_tmp
+
+ insertion_conf = insertion_conf.squeeze(-1) # (B, L)
+
+ # Expected sequence length at next timestep according to schedule
+ current_length_after = xt_tmp.ne(pad).sum(dim=1).float() # [B]
+ expected_progress = model.interpolant.insertion_schedule.at(t_next) # [B]
+ estimated_final_length = current_length_after / (expected_progress.clamp(min=0.1))
+ expected_length = estimated_final_length * expected_progress # [B]
+
+ # Mark positions in xt_tmp that came from new_xt (originals) vs. fresh insertions.
+ # Fancy-indexing scatter avoids the per-batch python loop.
+ valid_b, valid_l = orig_mask.nonzero(as_tuple=True)
+ valid_p = new_pos_orig[valid_b, valid_l].long().clamp_(0, L - 1)
+ is_original = torch.zeros_like(xt_tmp, dtype=torch.bool)
+ is_original[valid_b, valid_p] = True
+ inserted_positions = (xt_tmp == mask) & ~is_original
+
+ # Two deletion modes, selected by `quality_threshold`:
+ # * float: drop insertions whose confidence is below the threshold, capped
+ # so the length never falls below the scheduled minimum.
+ candidates = inserted_positions & (insertion_conf < quality_threshold)
+ num_bad = candidates.sum(dim=1) # [B], long
+ min_length = expected_length.long().clamp(min=1) # [B]
+ max_removable = (current_length_after.long() - min_length).clamp(min=0)
+ length_after_removal = current_length_after.long() - num_bad
+ schedule_violates = length_after_removal < min_length
+ k_per_row = torch.where(schedule_violates, max_removable, num_bad)
+ k_per_row = torch.where(num_bad > 0, k_per_row, torch.zeros_like(k_per_row))
+
+ if not candidates.any():
+ return xt_tmp
+
+ # Select the lowest-confidence candidates per row via a sort.
+ neg_inf = torch.tensor(float('-inf'), device=device, dtype=insertion_conf.dtype)
+ scores = torch.where(candidates, -insertion_conf, neg_inf) # higher = worse
+ _, sorted_indices = scores.sort(dim=1, descending=True)
+ positions = torch.arange(L, device=device).unsqueeze(0) # [1, L]
+ keep_in_topk = positions < k_per_row.unsqueeze(1) # [B, L]
+ final_bad = torch.zeros_like(candidates)
+ final_bad.scatter_(1, sorted_indices, keep_in_topk)
+
+ if not final_bad.any():
+ return xt_tmp
+
+ # Compact each row to the left (keep good, drop bad), then pad the tail.
+ # Stable sort by the bad flag pushes bad positions to the right.
+ sort_key = final_bad.long()
+ _, perm = torch.sort(sort_key, dim=1, stable=True)
+ xt_tmp = torch.gather(xt_tmp, 1, perm)
+ num_keep = (~final_bad).sum(dim=1) # [B]
+ tail_mask = positions >= num_keep.unsqueeze(1) # [B, L]
+ xt_tmp = torch.where(tail_mask, torch.full_like(xt_tmp, pad), xt_tmp)
+
+ return xt_tmp
+
+
+def apply_schedule_aware_remasking(
+ model,
+ new_xt,
+ t,
+ dt,
+ remasking_conf,
+ clean_index,
+ mask,
+ neg_inf,
+ batch_size,
+ unmask_quality_threshold=None,
+):
+ """
+ Apply schedule-aware remasking: adjust number of masks to match expected count from schedule.
+
+ Args:
+ model: Model with interpolant that has an unmask_schedule
+ new_xt: Current sequence [B, L]
+ t: Current time [B]
+ dt: Time step size
+ remasking_conf: Confidence scores for tokens [B, L]
+ clean_index: Boolean mask of clean tokens (not mask, not pad) [B, L]
+ mask: Mask token ID
+ neg_inf: Negative infinity tensor
+ batch_size: Batch size
+ unmask_quality_threshold: If None (default), remask exactly the schedule
+ excess (count-based). If a float, ignore the schedule budget entirely
+ and remask EVERY clean token whose unmasking-quality confidence is
+ below the threshold. Higher threshold => more aggressive remasking.
+
+ Returns:
+ new_xt: Modified sequence with schedule-aware remasking applied
+ """
+ # Threshold gate (overrides the schedule-driven count when set): remask every
+ # clean token whose unmasking-quality confidence is below the threshold,
+ # regardless of the schedule budget. Higher threshold => more remasking.
+ if unmask_quality_threshold is not None:
+ to_mask = clean_index & (remasking_conf < unmask_quality_threshold)
+ return torch.where(to_mask, torch.full_like(new_xt, mask), new_xt)
+
+ t_next = t + dt
+ num_clean = clean_index.sum(dim=1) # [B], long
+ current_seq_len = (num_clean + (new_xt == mask).sum(dim=1)).float() # [B]
+ expected_unmasked_frac = model.interpolant.unmask_schedule.at(t_next) # [B]
+ expected_num_clean = expected_unmasked_frac * current_seq_len # [B]
+ masks_to_add = (num_clean.float() - expected_num_clean).round().long() # [B]
+
+ # Per-row k = min(masks_to_add, num_clean), clamped to >= 0.
+ k_per_row = torch.minimum(masks_to_add.clamp(min=0), num_clean) # [B]
+
+ if k_per_row.sum() == 0:
+ return new_xt
+
+ # Use confidence to decide which clean tokens to remask: lowest conf first.
+ remasking_score_temp = -1.0 * remasking_conf # low conf = high score
+ remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf)
+
+ _, sorted_indices = remasking_score_temp.sort(dim=1, descending=True)
+ L = remasking_score_temp.shape[1]
+ positions = torch.arange(L, device=new_xt.device).unsqueeze(0) # [1, L]
+ keep_in_topk = positions < k_per_row.unsqueeze(1) # [B, L]
+ to_mask = torch.zeros_like(clean_index)
+ to_mask.scatter_(1, sorted_indices, keep_in_topk)
+ new_xt = torch.where(to_mask, torch.full_like(new_xt, mask), new_xt)
+
+ return new_xt
diff --git a/a2d2_pep/sampling.py b/a2d2_pep/sampling.py
new file mode 100755
index 0000000000000000000000000000000000000000..2bdd2eee730434849c946259f801ad527fd15723
--- /dev/null
+++ b/a2d2_pep/sampling.py
@@ -0,0 +1,1401 @@
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # add repo root to path
+
+import torch
+from dataclasses import dataclass
+from typing import Any, Literal, Optional
+import numpy as np
+import pandas as pd
+
+from lightning_modules.mdm import MaskedDiffusionModule
+
+
+@dataclass
+class SamplingTraceDatapoint:
+ t: float
+ event_type: Literal["insertion", "change"]
+ position: int
+ token: Any
+
+
+@dataclass
+class SamplingResult:
+ samples: torch.Tensor
+ # Trace is supposed to be processed sequentially as updates are not commutative
+ trace: Optional[list[SamplingTraceDatapoint]]
+
+ def __iter__(self):
+ yield from [self.samples, self.trace]
+
+
+# Sample from categorical distribution for each position using the transition probabilities
+def _sample_tokens(probs: torch.Tensor) -> torch.Tensor:
+ """Sample one token per position from probability distribution.
+ Args:
+ probs: [batch_size, seq_len, vocab_size] transition probabilities
+ Returns:
+ [batch_size, seq_len] sampled token indices
+ """
+ batch_size, seq_len, vocab_size = probs.shape
+ flat_probs = probs.view(-1, vocab_size)
+ samples = torch.multinomial(flat_probs, num_samples=1)
+ return samples.view(batch_size, seq_len)
+
+
+def _sample_batched_tokens(probs: torch.Tensor) -> torch.Tensor:
+
+ batch_size, seq_len, vocab_size = probs.shape
+
+ gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, seq_len, vocab_size) + 1e-10) + 1e-10)).to(probs.device)
+ noisy_logits = torch.log(probs + 1e-10) + gumbel_noise # add Gumbel noise to log probabilities
+
+ # select the highest score (most likely category after Gumbel noise)
+ samples = noisy_logits.argmax(dim=-1).to(dtype=torch.long)
+
+ return samples.view(batch_size, seq_len)
+
+@torch.no_grad()
+def mdm_euler_sampling(
+ model: MaskedDiffusionModule,
+ steps: int,
+ mask: int,
+ pad: int,
+ batch_size: int,
+ max_length: int,
+ return_trace: bool = False,
+ temperature: float = 1.0,
+):
+ assert not return_trace, "Trace is not yet implemented in MDM Euler sampling"
+ device = next(model.parameters()).device
+ xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)
+
+ dt = 1.0 / steps
+ t = torch.zeros(batch_size, device=device)
+
+ for i in range(steps):
+ print("i-th sampling step")
+ # ——— predict and convert rates ———
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ _xt = xt.clone()
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+
+ # Apply temperature scaling
+ if temperature != 1.0:
+ logits = torch.log(trans_prob + 1e-10) / temperature
+ trans_prob = torch.softmax(logits, dim=-1)
+
+ if i == steps - 1:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0
+ print(trans_prob[mask_pos + (mask,)])
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt = torch.where(xt != mask, xt, new_xt)
+
+ xt = new_xt
+ t = t + dt
+
+ return xt, []
+
+
+@torch.no_grad()
+def any_order_mask_insertion_euler_sampling(
+ model: torch.nn.Module,
+ steps: int,
+ mask: int,
+ pad: int,
+ batch_size: int,
+ max_length: int,
+ return_trace: bool = False,
+ temperature: float = 1.0,
+) -> SamplingResult:
+ device = next(model.parameters()).device
+
+ # 1) Initialize all‑pad sequence and trace
+ xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
+ sampling_trace = []
+
+ dt = 1.0 / steps
+ t = torch.zeros(batch_size, device=device)
+
+ # Precompute row indices for scatter
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+ sampling_trace = [[] for _ in range(batch_size)] if return_trace else None
+
+ for i in range(steps):
+ # ——— predict and convert rates ———
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ # add “stay” probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+
+ if i == steps - 1:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
+
+ # renormalize probabilities to ensure they sum to 1
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+ if mask_has_zero_prob.any():
+ # create uniform distribution over valid tokens (excluding mask and pad)
+ uniform_prob = torch.zeros_like(trans_prob[0])
+ uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ # normalize to sum to 1
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ if i != steps - 1:
+ # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ———
+ ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1)
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ gaps = torch.arange(max_length + 1, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+ else:
+ xt_tmp = new_xt
+
+ if return_trace:
+ # Check if the token was changed
+ for batch_idx in range(batch_size):
+ for j in range(max_length):
+ if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="change",
+ position=j,
+ token=new_xt[batch_idx, j].item(),
+ )
+ )
+
+ # Check if a new token was inserted
+ for j in range(max_length):
+ id = max_length - j - 1
+ if ext[batch_idx, id]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="insertion",
+ position=id,
+ token=mask,
+ )
+ )
+
+ xt = xt_tmp
+ t = t + dt
+
+ return xt, sampling_trace
+
+@torch.no_grad()
+def batch_mcts_reverse_step(
+ xt: torch.Tensor,
+ t: torch.Tensor,
+ dt: float,
+ model: torch.nn.Module,
+ pretrained: torch.nn.Module,
+ mask: int,
+ pad: int,
+ batch_size: int,
+ max_length: int,
+ last_step: bool = False,
+ temperature: float = 1.0,
+) -> SamplingResult:
+ device = next(model.parameters()).device
+
+ xt = xt.repeat(batch_size, 1)
+
+ # squeeze to remove extra dimensions, then expand to batch_size
+ t = t.squeeze().expand(batch_size)
+ # precompute row indices for scatter
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+
+ # ——— predict and convert rates ———
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # ——— get pretrained model rates for log_rnd computation ———
+ pretrained_pred = pretrained(xt, t)
+ pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
+ pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V)
+ pretrained_len_rate = pretrained_rate.length_rate # (B, L+1)
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ # Same for pretrained
+ pretrained_unmask_rate[xt != mask] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)
+
+ # add “stay” probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+ pretrained_trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
+ )
+
+ if last_step:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
+
+ # renormalize probabilities to ensure they sum to 1
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+ if mask_has_zero_prob.any():
+ # create uniform distribution over valid tokens (excluding mask and pad)
+ uniform_prob = torch.zeros_like(trans_prob[0])
+ uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ # normalize to sum to 1
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ # ——— compute log probabilities for RND ———
+ lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+ lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+
+ changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)
+
+ log_policy_step = (lp * changed_mask).sum(dim=1)
+ log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
+
+ log_rnd = log_pretrained_step - log_policy_step # (B,)
+
+ if not last_step:
+ # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ———
+ ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1)
+
+ insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1)
+ pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1)
+
+ # log P(ext; λ) = ext*log(λ) - λ
+ log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1) # (B,)
+ log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1) # (B,)
+
+ log_insert_diff = log_pretrained_insert - log_policy_insert # (B,)
+ log_rnd += log_insert_diff
+ log_pretrained_step += log_pretrained_insert
+ log_policy_step += log_policy_insert
+
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ seq_dim = ext.size(1) # Use actual ext dimension to avoid mismatch
+ gaps = torch.arange(seq_dim, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+ else:
+ xt_tmp = new_xt
+
+ return xt_tmp, log_rnd, log_policy_step, log_pretrained_step
+
+
+@torch.no_grad()
+def mcts_reverse_step(
+ xt: torch.Tensor,
+ t: torch.Tensor,
+ dt: float,
+ model: torch.nn.Module,
+ pretrained: torch.nn.Module,
+ mask: int,
+ pad: int,
+ max_length: int,
+ last_step: bool = False,
+ temperature: float = 1.0,
+) -> SamplingResult:
+ device = next(model.parameters()).device
+
+ batch_size = xt.size(0)
+
+ # precompute row indices for scatter
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+
+ # ——— predict and convert rates ———
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # ——— get pretrained model rates for log_rnd computation ———
+ pretrained_pred = pretrained(xt, t)
+ pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
+ pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V)
+ pretrained_len_rate = pretrained_rate.length_rate # (B, L+1)
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ # same for pretrained
+ pretrained_unmask_rate[xt != mask] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)
+
+ # add “stay” probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+ pretrained_trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
+ )
+
+ if last_step:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
+
+ # renormalize probabilities to ensure they sum to 1
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ # avoid division by zero - if all probs are 0, use uniform distribution (excluding mask and pad)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+ if mask_has_zero_prob.any():
+ # create uniform distribution over valid tokens (excluding mask and pad)
+ uniform_prob = torch.zeros_like(trans_prob[0])
+ uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ # normalize to sum to 1
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ # ——— compute log probabilities for RND ———
+ lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+ lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+
+ changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)
+
+ log_policy_step = (lp * changed_mask).sum(dim=1)
+ log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
+
+ log_rnd = log_pretrained_step - log_policy_step # (B,)
+
+ if not last_step:
+ # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ———
+ ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1)
+
+ insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1)
+ pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1)
+
+ # log P(ext; λ) = ext*log(λ) - λ
+ log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1) # (B,)
+ log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1) # (B,)
+
+ log_insert_diff = log_pretrained_insert - log_policy_insert # (B,)
+ log_rnd += log_insert_diff
+ log_pretrained_step += log_pretrained_insert
+ log_policy_step += log_policy_insert
+
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ seq_dim = ext.size(1) # Use actual ext dimension to avoid mismatch
+ gaps = torch.arange(seq_dim, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+ else:
+ xt_tmp = new_xt
+
+ return xt_tmp, log_rnd, log_policy_step, log_pretrained_step
+
+@torch.no_grad()
+def any_order_euler_sampling_with_schedule(
+ model: torch.nn.Module,
+ time_schedule: torch.Tensor,
+ mask: int,
+ pad: int,
+ batch_size: int,
+ max_length: int,
+ return_trace: bool = False,
+ temperature: float = 1.0,
+) -> SamplingResult:
+ device = next(model.parameters()).device
+
+ time_schedule = time_schedule.to(device)
+ if time_schedule[0] < time_schedule[-1]:
+ time_schedule = torch.flip(time_schedule, [0]) # descending order
+
+ steps = len(time_schedule) - 1
+
+ # initialize all-pad sequence and trace
+ xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
+
+ # precompute row indices for scatter
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+ sampling_trace = [[] for _ in range(batch_size)] if return_trace else None
+
+ for i in range(steps):
+ # use scheduled timesteps
+ t = time_schedule[i].repeat(batch_size)
+ t_next = time_schedule[i + 1]
+ dt = (t - t_next).abs() # timestep difference
+
+ # ——— predict and convert rates ———
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt[:, None, None]).clamp(0.0, 1.0)
+
+ # add "stay" probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+
+ # Apply temperature scaling
+ if temperature != 1.0:
+ logits = torch.log(trans_prob + 1e-10) / temperature
+ trans_prob = torch.softmax(logits, dim=-1)
+
+ if i == steps - 1:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
+
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+
+ if mask_has_zero_prob.any():
+ uniform_prob = torch.zeros_like(trans_prob[0])
+ uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ if i != steps - 1:
+ # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ———
+ ext = torch.bernoulli((len_rate * dt[:, None]).clamp(0.0, 1.0)).long() # (B, L+1)
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ gaps = torch.arange(max_length + 1, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+ else:
+ xt_tmp = new_xt
+
+ if return_trace:
+ # Check if the token was changed
+ for batch_idx in range(batch_size):
+ for j in range(max_length):
+ if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="change",
+ position=j,
+ token=new_xt[batch_idx, j].item(),
+ )
+ )
+
+ # Check if a new token was inserted
+ for j in range(max_length):
+ id = max_length - j - 1
+ if ext[batch_idx, id]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="insertion",
+ position=id,
+ token=mask,
+ )
+ )
+
+ xt = xt_tmp
+
+ return xt, sampling_trace
+
+
+@torch.no_grad()
+def any_order_mask_insertion_euler_sampling_with_rnd(
+ model, pretrained, reward_model, analyzer,
+ tokenizer, steps,
+ mask,
+ pad,
+ batch_size,
+ max_length,
+ return_trace = False,
+ alpha = 0.1,
+ temperature: float = 1.0,
+):
+ device = next(model.parameters()).device
+
+ # initialize all‑pad sequence and trace
+ xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
+ sampling_trace = []
+
+ # initialize log_rnd to accumulate log probability ratios
+ log_rnd = torch.zeros(batch_size, device=device)
+
+ dt = 1.0 / steps
+ t = torch.zeros(batch_size, device=device)
+
+ # precompute row indices for scatter
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+ sampling_trace = [[] for _ in range(batch_size)] if return_trace else None
+
+ for i in range(steps):
+ # ——— predict and convert rates ———
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # ——— get pretrained model rates for log_rnd computation ———
+ pretrained_pred = pretrained(xt, t)
+ pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
+ pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V)
+ pretrained_len_rate = pretrained_rate.length_rate # (B, L+1)
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ # Same for pretrained
+ pretrained_unmask_rate[xt != mask] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = 0
+ pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)
+
+ # add “stay” probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+ pretrained_trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
+ )
+
+ # Apply temperature scaling
+ if temperature != 1.0:
+ logits = torch.log(trans_prob + 1e-10) / temperature
+ trans_prob = torch.softmax(logits, dim=-1)
+
+ if i == steps - 1:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
+
+ # renormalize probabilities to ensure they sum to 1
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+ if mask_has_zero_prob.any():
+ # create uniform distribution over valid tokens (excluding mask and pad)
+ uniform_prob = torch.zeros_like(trans_prob[0])
+ uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ # ——— compute log probabilities for RND ———
+ lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+ lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
+
+ changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)
+
+ log_policy_step = (lp * changed_mask).sum(dim=1)
+ log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
+
+ log_rnd = log_pretrained_step - log_policy_step # (B,)
+
+ if i != steps - 1:
+ ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1)
+
+ insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1)
+ pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1)
+
+ log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1) # (B,)
+ log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1) # (B,)
+
+ log_insert_diff = log_pretrained_insert - log_policy_insert # (B,)
+ log_rnd += log_insert_diff
+
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ gaps = torch.arange(max_length + 1, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+ else:
+ xt_tmp = new_xt
+
+ if return_trace:
+ # check if the token was changed
+ for i in range(batch_size):
+ for j in range(max_length):
+ if xt[i, j] != pad and xt[i, j] != new_xt[i, j]:
+ sampling_trace[i].append(
+ SamplingTraceDatapoint(
+ t=t[i].item(),
+ event_type="change",
+ position=j,
+ token=new_xt[i, j].item(),
+ )
+ )
+
+ # check if a new token was inserted
+ for j in range(max_length):
+ id = max_length - j - 1
+ if ext[i, id]:
+ sampling_trace[i].append(
+ SamplingTraceDatapoint(
+ t=t[i].item(),
+ event_type="insertion",
+ position=id,
+ token=mask,
+ )
+ )
+
+ xt = xt_tmp
+ t = t + dt
+
+ # change rewards for peptides
+ samples = xt.to(device)
+
+ # store raw token IDs
+ # Decode and strip samples
+ decoded_samples = tokenizer.batch_decode(samples)
+
+ valid_x_final = []
+ validSequences = []
+ valid_log_rnd = []
+
+ for idx, seq in enumerate(decoded_samples):
+ # check if the peptide is valid
+ if analyzer.is_peptide(seq):
+ valid_x_final.append(xt[idx])
+ validSequences.append(seq)
+ valid_log_rnd.append(log_rnd[idx])
+
+ print("len valid sequences:", len(validSequences))
+ # compute multi-objective rewards
+ score_vectors = reward_model(input_seqs=validSequences)
+ scalar_rewards = np.sum(score_vectors, axis=-1)
+ scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device)
+
+ print(f"scalar reward dim{len(scalar_rewards)}")
+ valid_log_rnd = torch.stack(valid_log_rnd, dim=0)
+
+ log_rnd = valid_log_rnd + (scalar_rewards / alpha) # scale down by alpha
+ valid_x_final = torch.stack(valid_x_final, dim=0)
+
+ return valid_x_final, log_rnd, scalar_rewards, sampling_trace
+
+@torch.no_grad()
+def any_order_finetuned_euler_sampler(
+ model, reward_model, analyzer,
+ tokenizer, steps,
+ mask,
+ pad,
+ batch_size,
+ max_length,
+ return_trace = False,
+ dataframe = False,
+ temperature: float = 1.0,
+ ):
+ device = next(model.parameters()).device
+
+ # initialize all‑pad sequence and trace
+ xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
+ sampling_trace = []
+
+ dt = 1.0 / steps
+ t = torch.zeros(batch_size, device=device)
+
+ # precompute row indices for scatter
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+ sampling_trace = [[] for _ in range(batch_size)] if return_trace else None
+
+ for i in range(steps):
+ # ——— predict and convert rates ———
+ pred_rate = model(xt, t)
+ pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == mask).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ # add “stay” probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+
+ # Apply temperature scaling
+ if temperature != 1.0:
+ logits = torch.log(trans_prob + 1e-10) / temperature
+ trans_prob = torch.softmax(logits, dim=-1)
+
+ if i == steps - 1:
+ print("Final step, removing mask token from sampling")
+ trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
+
+ # renormalize probabilities to ensure they sum to 1
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+ if mask_has_zero_prob.any():
+ # create uniform distribution over valid tokens (excluding mask and pad)
+ uniform_prob = torch.zeros_like(trans_prob[0])
+ uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ # normalize to sum to 1
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ if i != steps - 1:
+ # gap-wise insertion refactored — compute new length, fill masks, scatter tokens
+ ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1)
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ gaps = torch.arange(max_length + 1, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ mask_fill = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+ else:
+ xt_tmp = new_xt
+
+ if return_trace:
+ # check if the token was changed
+ for batch_idx in range(batch_size):
+ for j in range(max_length):
+ if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="change",
+ position=j,
+ token=new_xt[batch_idx, j].item(),
+ )
+ )
+
+ # check if a new token was inserted
+ for j in range(max_length):
+ id = max_length - j - 1
+ if ext[batch_idx, id]:
+ sampling_trace[batch_idx].append(
+ SamplingTraceDatapoint(
+ t=t[batch_idx].item(),
+ event_type="insertion",
+ position=id,
+ token=mask,
+ )
+ )
+
+ xt = xt_tmp
+ t = t + dt
+
+ # start eval
+ samples = xt.to(device)
+
+ decoded_samples = tokenizer.batch_decode(samples)
+
+ valid_x_final = []
+ validSequences = []
+
+ for idx, seq in enumerate(decoded_samples):
+ if analyzer.is_peptide(seq):
+ valid_x_final.append(samples[idx])
+ validSequences.append(seq)
+
+ print("len valid sequences:", len(validSequences))
+ valid_fraction = len(validSequences) / batch_size
+
+ if (len(validSequences) != 0):
+ # add scores to log
+ score_vectors = reward_model(input_seqs=validSequences) # (num_children, num_objectives)
+ average_scores = score_vectors.T
+
+ affinity = average_scores[0]
+ sol = average_scores[1]
+ hemo = average_scores[2]
+ nf = average_scores[3]
+ permeability = average_scores[4]
+
+ else:
+ zeros = [0.0]
+
+ affinity = zeros
+ sol = zeros
+ hemo = zeros
+ nf = zeros
+ permeability = zeros
+
+ if dataframe:
+ df = pd.DataFrame({
+ "Peptide Sequence": validSequences,
+ "Binding Affinity": affinity if len(validSequences) else [0.0],
+ "Solubility": sol if len(validSequences) else [0.0],
+ "Hemolysis": hemo if len(validSequences) else [0.0],
+ "Nonfouling": nf if len(validSequences) else [0.0],
+ "Permeability": permeability if len(validSequences) else [0.0],
+ })
+ return samples, affinity, sol, hemo, nf, permeability, valid_fraction, df
+
+ return samples, affinity, sol, hemo, nf, permeability, valid_fraction
+
+@torch.no_grad()
+def mdm_tau_leaping_sampling(
+ model: MaskedDiffusionModule,
+ steps: int,
+ mask: int,
+ pad: int,
+ batch_size: int,
+ max_length: int,
+ return_trace: bool = False,
+ temperature: float = 1.0,
+):
+ assert not return_trace, "Trace is not yet supported"
+ device = next(model.parameters()).device
+ xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)
+ dt = 1.0 / steps
+ t = torch.zeros(batch_size, device=device)
+
+ for i in range(steps):
+ # ——— predict and convert rates ———
+ pred = model(xt, t)
+ pred = model.interpolant.to_actual_rate(xt, pred, t)
+ unmask_rate = pred.unmask_rate # (B, L, V)
+
+ if i == steps - 1:
+ # last step: deterministic unmask via argmax
+ mask_pos = xt == mask # (B, L)
+ new_token = unmask_rate.argmax(dim=2) # (B, L)
+ new_xt = xt.clone()
+ new_xt[mask_pos] = new_token[mask_pos]
+ new_xt = torch.where(xt != mask, xt, new_xt)
+ xt = new_xt
+ t = t + dt
+ continue
+ # tau-leaping via Poisson counts
+ counts = torch.poisson(unmask_rate * dt).long()
+ mask_pos = xt == mask # (B, L)
+ # zero out non-mask positions and mask→mask
+ counts[~mask_pos.unsqueeze(-1).expand_as(counts)] = 0
+ counts[..., mask] = 0
+ # only accept exactly one event
+ sum_c = counts.sum(dim=2) # (B, L)
+ one_event = sum_c == 1
+ new_token = counts.argmax(dim=2) # (B, L)
+
+ # build new xt
+ new_xt = xt.clone()
+ new_xt[one_event] = new_token[one_event]
+ # keep pads and already-unmasked tokens
+ new_xt = torch.where(xt != mask, xt, new_xt)
+ xt = new_xt
+ t = t + dt
+
+ return xt, []
+
+# Not used in production, for debugging purposes
+lengths = {4: 0.1, 16: 0.4, 32: 0.4, 64: 0.1}
+
+def binomial_mass(k, n, p):
+ """
+ Calculate the probability mass function (PMF) for a binomial distribution.
+
+ Args:
+ k (int): Number of successes
+ n (int): Number of trials
+ p (float): Probability of success in a single trial
+
+ Returns:
+ float: Probability mass P(X = k)
+ """
+ import math
+
+ # Calculate binomial coefficient (n choose k)
+ try:
+ binom_coef = math.factorial(n) / (math.factorial(k) * math.factorial(n - k))
+ except ValueError:
+ # Handle cases where k > n or negative values
+ return 0.0
+
+ # Calculate probability mass
+ return binom_coef * (p ** k) * ((1 - p) ** (n - k))
+
+def calculate_rate_batch(alpha_t, len_t):
+ """
+ Calculate rate for a batch of alpha_t and len_t values.
+
+ Args:
+ alpha_t (torch.Tensor): Tensor of shape (batch_size,)
+ len_t (torch.Tensor): Tensor of shape (batch_size,)
+
+ Returns:
+ torch.Tensor: Tensor of shape (batch_size,) containing calculated rates
+ """
+ batch_size = alpha_t.shape[0]
+ device = alpha_t.device
+
+ # Initialize tensors for numerator and denominator
+ nom = torch.zeros(batch_size, device=device)
+ denom = torch.zeros(batch_size, device=device)
+
+ for length, probability in lengths.items():
+ # Create mask for valid entries where len_t <= length
+ valid_mask = (len_t <= length) & (len_t >= 0)
+
+ if not valid_mask.any():
+ continue
+
+ valid_indices = valid_mask.nonzero(as_tuple=True)[0]
+ valid_len_t = len_t[valid_indices]
+ valid_alpha_t = alpha_t[valid_indices]
+
+ # Calculate binomial probabilities efficiently using torch distribution
+ binom_dist = torch.distributions.Binomial(total_count=length, probs=valid_alpha_t)
+ binom_probs = binom_dist.log_prob(valid_len_t).exp()
+
+ # Update numerator and denominator for valid indices
+ nom[valid_indices] += (length - valid_len_t) * probability * binom_probs
+ denom[valid_indices] += probability * binom_probs
+
+ # Handle division by zero in a vectorized way
+ result = torch.zeros_like(nom)
+ div_mask = denom > 0
+ result[div_mask] = nom[div_mask] / (denom[div_mask])
+
+ return result
+
+# Keep the original function for backward compatibility
+def calculate_rate(alpha_t, len_t):
+ """Legacy scalar version of calculate_rate"""
+ if isinstance(alpha_t, torch.Tensor) and alpha_t.ndim > 0:
+ return calculate_rate_batch(alpha_t, len_t)
+
+ nom, denom = 0, 0
+ for length, probability in lengths.items():
+ if length >= len_t:
+ nom += (length - len_t) * probability * binomial_mass(len_t, length, alpha_t)
+ denom += probability * binomial_mass(len_t, length, alpha_t)
+
+ if denom == 0:
+ return 0.0
+
+ return nom /denom
+
+
+@torch.no_grad()
+def any_order_mask_insertion_tau_leaping_sampling(
+ model: torch.nn.Module,
+ steps: int,
+ mask: int,
+ pad: int,
+ batch_size: int,
+ max_length: int,
+ return_trace: bool = False,
+ confidence_based_sampling: bool = True, # whether to use confidence-based decoding
+ alpha: float = 5.0, # hyperparameter for window size calculation
+ max_window: int = 32, # Maximum window size for sliding window
+ confidence_method: str = "prob_diff", # "position", "top_prob", "prob_diff", "entropy"
+ use_sliding_window: bool = False, # whether to use sliding window for position selection
+ temperature: float = 1.0,
+) -> SamplingResult:
+
+ device = next(model.parameters()).device
+ xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
+ sampling_trace = []
+ dt = 1.0 / steps
+ t = torch.zeros(batch_size, device=device)
+
+ # Precompute row indices for scatter
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, max_length)
+ )
+ pos_idx_L = (
+ torch.arange(max_length, device=device)
+ .view(1, max_length)
+ .expand(batch_size, max_length)
+ )
+
+ for i in range(steps):
+ # --- predict rates ---
+ pred = model(xt, t)
+ xt_len = (xt != pad).sum(dim=1)
+ pred = model.interpolant.to_actual_rate(xt, pred, t)
+ unmask_rate = pred.unmask_rate # (B, L, V)
+ len_rate = pred.length_rate # (B, L+1)
+
+ if i == steps - 1:
+ # last step: deterministic unmask via argmax
+ mask_pos = xt == mask
+ new_token = unmask_rate.argmax(dim=2)
+ new_xt = xt.clone()
+ new_xt[mask_pos] = new_token[mask_pos]
+ new_xt = torch.where(xt == pad, pad, new_xt)
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+ xt = new_xt
+ t = t + dt
+ continue
+
+ # --- confidence-based decoding ---
+ if confidence_based_sampling > 0.0:
+ # Confidence-based unmasking (vectorized)
+ mask_positions = (xt == mask) # (B, L)
+ num_mask_positions = mask_positions.sum(dim=1) # (B,)
+
+ # 1. Determine number of tokens to unmask using Poisson
+ unmask_counts = torch.poisson(num_mask_positions.float() * dt).long() # (B,)
+
+ # 2. Calculate confidence based on selected method
+ if confidence_method == "position":
+ # Position-based confidence: position i / len(xt)
+ xt_len = (xt != pad).sum(dim=1) # (B,) - current sequence lengths
+ position_indices = torch.arange(max_length, device=device).unsqueeze(0).expand(batch_size, -1) # (B, L)
+ confidence = 1.0 - (position_indices.float() / xt_len.unsqueeze(1).float().clamp(min=1)) # (B, L)
+
+ elif confidence_method == "top_prob":
+ # Top probability confidence
+ import torch.nn.functional as F
+ token_logits = unmask_rate # (B, L, V) - use the unmask_rate as logits
+ unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V)
+ confidence = unmask_probs.max(dim=-1)[0] # (B, L)
+
+ elif confidence_method == "prob_diff":
+ # Probability difference confidence (top - second top)
+ import torch.nn.functional as F
+ token_logits = unmask_rate # (B, L, V)
+ unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V)
+ top2_probs, _ = torch.topk(unmask_probs, k=2, dim=-1) # (B, L, 2)
+ confidence = top2_probs[:, :, 0] - top2_probs[:, :, 1] # (B, L)
+
+ elif confidence_method == "entropy":
+ # Entropy-based confidence (lower entropy = higher confidence)
+ import torch.nn.functional as F
+ token_logits = unmask_rate # (B, L, V)
+ unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V)
+ entropy = -torch.sum(unmask_probs * torch.log(unmask_probs + 1e-10), dim=-1) # (B, L)
+ confidence = -entropy # (B, L) - negative entropy so lower entropy gives higher confidence
+
+ else:
+ raise ValueError(f"Unknown confidence_method: {confidence_method}")
+
+ # 3. Apply window constraint if enabled
+ if use_sliding_window:
+ # Calculate dynamic k for each batch
+ k_values = torch.minimum(
+ torch.minimum(
+ (alpha * unmask_counts).long(),
+ torch.tensor(max_window, device=device)
+ ), num_mask_positions) # (B,)
+
+ # Get cumulative count of mask positions
+ mask_cumsum = mask_positions.cumsum(dim=1) # (B, L)
+
+ # Create window mask: position is eligible if it's a mask and within first k masks
+ is_within_window = mask_cumsum <= k_values.unsqueeze(1) # (B, L)
+ window_mask = mask_positions & is_within_window # (B, L)
+
+ # Set confidence to -inf for positions outside the window or non-mask positions
+ confidence = torch.where(window_mask, confidence, torch.tensor(-float('inf'), device=device))
+ else:
+ # No window constraint - only mask positions are eligible
+ confidence = torch.where(mask_positions, confidence, torch.tensor(-float('inf'), device=device))
+
+ new_xt = xt.clone()
+
+ # vectorized unmasking
+ max_unmask = unmask_counts.max().item()
+ if max_unmask > 0:
+ _, all_top_indices = torch.topk(confidence, k=max_unmask, dim=1, largest=True) # (B, max_unmask)
+
+ # create mask for valid unmask operations
+ unmask_mask = torch.arange(max_unmask, device=device).unsqueeze(0) < unmask_counts.unsqueeze(1) # (B, max_unmask)
+
+ most_likely_tokens = unmask_rate.argmax(dim=-1) # (B, L)
+
+ selected_positions = all_top_indices[unmask_mask]
+ batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, max_unmask)[unmask_mask]
+
+ new_xt[batch_indices, selected_positions] = most_likely_tokens[batch_indices, selected_positions]
+ else:
+ # --- tau-leaping unmask via Poisson ---
+ counts = torch.poisson(unmask_rate * dt).long()
+ mask_pos = xt == mask
+ counts[~mask_pos.unsqueeze(-1).expand_as(counts)] = 0
+ counts[..., mask] = 0
+ sum_c = counts.sum(dim=2)
+ one_event = sum_c == 1
+ new_token = counts.argmax(dim=2)
+ new_xt = xt.clone()
+ new_xt[one_event] = new_token[one_event]
+ new_xt = torch.where(xt == pad, pad, new_xt)
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ # insertion only on non-last
+ if i != steps - 1:
+ # --- Poisson insertion, compute new lengths and fill masks ---
+ ext = torch.poisson(len_rate * dt).long() # (B, L+1)
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ gaps = torch.arange(max_length + 1, device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ # compute prefix sums of insertions
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ # initialize with pads, then fill mask up to new_len
+ xt_tmp = torch.full_like(xt, pad)
+ mask_pos = pos_idx_L < new_len.view(batch_size, 1)
+ xt_tmp[mask_pos] = mask
+
+ # shift and scatter original tokens
+ new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+ else:
+ xt_tmp = new_xt
+
+ xt = xt_tmp
+ t = t + dt
+ if return_trace:
+ sampling_trace.append(xt)
+
+ return xt, sampling_trace
diff --git a/a2d2_pep/scripts/run_peptide_finetune.slurm b/a2d2_pep/scripts/run_peptide_finetune.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..2a42d480dd9ec3a4dc7d3a7642af0996dfdd6088
--- /dev/null
+++ b/a2d2_pep/scripts/run_peptide_finetune.slurm
@@ -0,0 +1,210 @@
+#!/bin/bash
+# NOTE: --partition and --qos below are specific to our cluster. Change them
+# (or remove them and pass `--partition` on the `sbatch` command line) to match
+# the partitions/QOS available on yours.
+#SBATCH --job-name=peptide-finetune-len256
+#SBATCH --partition=b200-mig90
+#SBATCH --qos=mig
+#SBATCH --nodes=1
+#SBATCH --gpus-per-node=1
+#SBATCH --cpus-per-task=8
+#SBATCH --ntasks-per-node=1
+#SBATCH --mem=80GB
+#SBATCH --time=02-00:00:00
+#SBATCH --output=logs/peptide_finetune_%A.log
+
+# =====================================================================
+# run_peptide_finetune.slurm
+#
+# Single-mode job (1 MIG GPU) running ONE finetune_quality (peptide)
+# experiment. Select which mode to run via the MODE_ID variable below
+# (or override at submit time with `sbatch --export=ALL,MODE_ID=2 ...`):
+# 0) A2D2 (Ours) – with full planner (alternating)
+# 1) A2D2 w/o quality – --disable_planner
+# 2) A2D2 w/o insertion planner – --disable_insertion_planner
+# 3) A2D2 w/o unmasking planner – --disable_unmasking_planner
+#
+# The job trains the selected mode then evaluates the resulting
+# checkpoint on the same GPU.
+# =====================================================================
+
+set -e
+
+# --- Mode selection ---------------------------------------------------
+# Which experiment to run (0-3). Override with `--export=ALL,MODE_ID=N`.
+MODE_ID="${MODE_ID:-0}"
+
+# Run prefix: YYYYMMDD + SLURM job ID
+DATE_STAMP=$(date +%Y%m%d)
+PREFIX="${DATE_STAMP}_job${SLURM_JOB_ID:-local$(date +%H%M%S)}"
+
+# Default protein target (must be defined before path definitions below)
+PROT_NAME=tfr
+
+# --- Paths ------------------------------------------------------------
+# Repo root is resolved at submit time so the script works from any clone:
+# - set A2D2_ROOT explicitly, OR
+# - run `sbatch` from the repo root (SLURM sets SLURM_SUBMIT_DIR), OR
+# - fall back to this script's location (a2d2_pep/scripts/ -> two levels up).
+if [ -n "${A2D2_ROOT:-}" ]; then
+ HOME_LOC="$A2D2_ROOT"
+elif [ -n "${SLURM_SUBMIT_DIR:-}" ]; then
+ HOME_LOC="$SLURM_SUBMIT_DIR"
+else
+ HOME_LOC="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
+fi
+SCRIPT_LOC="$HOME_LOC/a2d2_pep"
+LOG_LOC="$HOME_LOC/logs"
+SAVE_DIR="$HOME_LOC/checkpoints/finetune_test_peptides_${PROT_NAME}"
+RESULTS_DIR="$HOME_LOC/results/peptide_test_ablation_${PROT_NAME}"
+
+cd "$SCRIPT_LOC"
+
+# BASE_PATH is passed as --base_path to finetune_quality.py: it's used
+# to build the plot output path at $BASE_PATH/flexible/results/
+# (see finetune_quality.py:421). The pretrained checkpoint is now passed
+# explicitly via --checkpoint_path below, so base_path no longer needs
+# to follow the legacy /scratch layout.
+BASE_PATH="${A2D2_BASE_PATH:-$HOME_LOC}"
+
+mkdir -p "$LOG_LOC" "$SAVE_DIR" "$RESULTS_DIR"
+
+# --- Environment setup ------------------------------------------------
+# Do NOT hardcode your W&B key. Either `wandb login` once on the cluster,
+# export WANDB_API_KEY in your shell/SLURM environment before submitting,
+# or set WANDB_MODE=offline to skip logging entirely.
+export WANDB_DIR=$HOME_LOC/.wandb
+export WANDB_CONFIG_DIR=$HOME_LOC/.config/wandb
+export WANDB_CACHE_DIR=$HOME_LOC/.cache/wandb
+# Stop wandb from hijacking stdout/stderr (its default fd-redirect mode sends
+# all output to wandb/run-*/files/output.log and freezes the RUN_LOG below).
+# With console off, everything flows to the `>> "$RUN_LOG" 2>&1` redirect.
+export WANDB_CONSOLE=off
+mkdir -p "$WANDB_DIR" "$WANDB_CONFIG_DIR" "$WANDB_CACHE_DIR"
+
+export TRITON_CACHE_DIR=$HOME_LOC/.triton/cache
+mkdir -p "$TRITON_CACHE_DIR"
+
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+
+# Activate conda env. Override CONDA_ROOT to point at your conda/miniconda
+# install, or just have `conda` on PATH; override CONDA_ENV if your env name
+# differs from the one created by environment.yml.
+CONDA_ENV="${CONDA_ENV:-a2d2}"
+if [ -n "${CONDA_ROOT:-}" ]; then
+ source "$CONDA_ROOT/bin/activate" "$CONDA_ENV"
+elif command -v conda >/dev/null 2>&1; then
+ source "$(conda info --base)/bin/activate" "$CONDA_ENV"
+else
+ echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2
+ exit 1
+fi
+PYTHON_EXECUTABLE=$(which python)
+
+# Pretrained base checkpoint
+PRETRAINED_CKPT="$HOME_LOC/pretrained/anylength_pep.ckpt"
+
+# --- Shared training hyperparameters ----------------------------------
+COMMON_ARGS=(
+ --base_path "$BASE_PATH"
+ --checkpoint_path "$PRETRAINED_CKPT"
+ --prot_name "$PROT_NAME"
+ --noise_removal
+ --wdce_num_replicates 8
+ --pool_size 100
+ --pool_refresh_fraction 1.0
+ --buffer_size 50
+ --batch_size 200
+ --total_num_steps 256
+ --num_iter 20
+ --resample_every_n_step 10
+ --num_epochs 1000
+ --save_every_n_epochs 50
+ --reset_every_n_step 1
+ --alpha 0.1
+ --no_mcts
+ --schedule_warmup_epochs 20
+ --alternation_frequency 5
+ --num_remasking 3
+ --quality_threshold 0.2
+ --training_mini_batch_size 10
+ --max_length 256
+ --eval_every_n_epochs 50
+ --min_peptide_bonds 4
+ --grad_clip
+ --seed 42
+)
+
+# --- Shared evaluation hyperparameters --------------------------------
+EVAL_COMMON_ARGS=(
+ --pretrained_ckpt "$PRETRAINED_CKPT"
+ --num_samples 50
+ --batch_size 200
+ --max_length 256
+ --total_num_steps 256
+ --num_remasking 3
+ --quality_threshold 0.2
+ --prot_name "$PROT_NAME"
+ --seed 42
+)
+
+# =====================================================================
+# Pick experiment from $MODE_ID
+# =====================================================================
+case "$MODE_ID" in
+ 0) MODE="with_planner"; EXTRA_ARGS=() ;;
+ 1) MODE="no_planner"; EXTRA_ARGS=(--disable_planner) ;;
+ 2) MODE="no_insertion_planner"; EXTRA_ARGS=(--disable_insertion_planner) ;;
+ 3) MODE="no_unmasking_planner"; EXTRA_ARGS=(--disable_unmasking_planner) ;;
+ *) echo "Unknown MODE_ID=$MODE_ID (expected 0-3)"; exit 1 ;;
+esac
+
+RUN_NAME="${PREFIX}_peptide_${PROT_NAME}_${MODE}"
+RUN_LOG="$LOG_LOC/${RUN_NAME}.log"
+RUN_SAVE_DIR="$SAVE_DIR/${RUN_NAME}"
+RESULTS_SUBDIR="$RESULTS_DIR/${MODE}"
+mkdir -p "$RUN_SAVE_DIR" "$RESULTS_SUBDIR"
+
+echo "=== Peptide finetune (MODE_ID=$MODE_ID) ==="
+echo "Job: ${SLURM_JOB_ID} Node: $SLURM_NODELIST"
+echo "Mode: $MODE"
+echo "Save dir: $RUN_SAVE_DIR"
+echo "Results dir: $RESULTS_SUBDIR"
+echo "Python: $PYTHON_EXECUTABLE"
+echo "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES:-(unset)}"
+
+# =====================================================================
+# Train
+# =====================================================================
+$PYTHON_EXECUTABLE $SCRIPT_LOC/finetune_quality.py \
+ "${COMMON_ARGS[@]}" \
+ --devices 1 \
+ "${EXTRA_ARGS[@]}" \
+ --save_path_dir "$RUN_SAVE_DIR" \
+ >> "$RUN_LOG" 2>&1
+
+echo "Training finished for $MODE. Log: $RUN_LOG"
+
+# =====================================================================
+# Evaluate
+# =====================================================================
+# finetune_quality.py saves to $RUN_SAVE_DIR//last.ckpt,
+# so glob the run_name subdir.
+RUN_CKPT=$(ls -t "$RUN_SAVE_DIR"/*/last.ckpt 2>/dev/null | head -1)
+if [ -z "$RUN_CKPT" ]; then
+ echo "No checkpoint found in $RUN_SAVE_DIR — skipping eval."
+ exit 1
+fi
+
+echo "Evaluating checkpoint: $RUN_CKPT"
+$PYTHON_EXECUTABLE $SCRIPT_LOC/evaluate_peptide_table.py \
+ --checkpoint_path "$RUN_CKPT" \
+ "${EVAL_COMMON_ARGS[@]}" \
+ "${EXTRA_ARGS[@]}" \
+ --output_dir "$RESULTS_SUBDIR" \
+ --device cuda:0 \
+ >> "$RESULTS_SUBDIR/${RUN_NAME}_eval.log" 2>&1
+
+echo "Eval finished for $MODE. CSV: $RESULTS_SUBDIR/eval_metrics_${MODE}_${PROT_NAME}.csv"
+
+conda deactivate
diff --git a/a2d2_pep/scripts/train_pep.sh b/a2d2_pep/scripts/train_pep.sh
new file mode 100755
index 0000000000000000000000000000000000000000..5888c59889c44134c85585edc8ed124d6437053e
--- /dev/null
+++ b/a2d2_pep/scripts/train_pep.sh
@@ -0,0 +1,93 @@
+#!/bin/bash
+#SBATCH --job-name=a2d2-pep-pretrain
+#SBATCH --partition=dgx-b200
+#SBATCH --nodes=1
+#SBATCH --gpus-per-node=4
+#SBATCH --ntasks-per-node=4
+#SBATCH --cpus-per-task=8
+#SBATCH --mem=512GB
+#SBATCH --time=7-00:00:00
+# SLURM's own catch-file (anything printed before the exec redirect below, plus
+# slurm-infra messages). Relative to the submit dir, so submit this script from
+# the a2d2_pep/ directory; the real run output is redirected via exec below.
+#SBATCH --output=logs/slurm/%x_%j.out
+#SBATCH --error=logs/slurm/%x_%j.err
+#
+# Pretrain the any-length insertion MDM on ~11M peptide SMILES on a dgx-b200 node.
+# Submit with: sbatch scripts/train_pep.sh (from the a2d2_pep/ directory).
+#
+# DDP is launched by SLURM: one srun task per GPU. --gpus-per-node and
+# --ntasks-per-node must match; change both together (and they override the
+# training.devices value baked into config_pep.yaml via the hydra override below).
+
+DATE=$(date +%Y%m%d)
+SPECIAL_PREFIX='a2d2-peptide'
+
+# Resolve a2d2_pep/ (which holds train.py + config_pep.yaml) so paths are
+# repo-relative. This script lives in a2d2_pep/scripts/, so the direct-run
+# fallback goes one level up. Under sbatch, BASH_SOURCE points at the spooled
+# copy, so we rely on SLURM_SUBMIT_DIR (submit from the a2d2_pep/ directory).
+if [ -n "${SLURM_SUBMIT_DIR:-}" ]; then
+ SCRIPT_DIR="$SLURM_SUBMIT_DIR"
+else
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+fi
+cd "$SCRIPT_DIR"
+
+# Auto-detect GPUs from the SLURM allocation (falls back to 4 for `bash` runs).
+DEVICES=${SLURM_GPUS_ON_NODE:-${SLURM_GPUS_PER_NODE:-4}}
+NTASKS=${SLURM_NTASKS_PER_NODE:-$DEVICES}
+NODES=${SLURM_NNODES:-1}
+
+LOG_LOC="$SCRIPT_DIR/logs"
+mkdir -p "$LOG_LOC/slurm"
+exec > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_${SLURM_JOB_ID:-local}.log" 2>&1
+
+# ---------------------------------------------------------------------------
+# Weights & Biases: log in once on your machine before running this script with
+# `wandb login` (or `export WANDB_API_KEY=`).
+# Do NOT hardcode your API key here. To disable W&B entirely, uncomment:
+# export WANDB_MODE=disabled
+# ---------------------------------------------------------------------------
+
+export PYTORCH_ALLOC_CONF=expandable_segments:True
+
+# Activate the conda env that has the deps (torch / pytorch_lightning / hydra).
+# The batch shell does NOT source ~/.bashrc, so conda is not on PATH. Override
+# CONDA_ROOT to point at your conda/miniconda install, or just have `conda` on
+# PATH; override CONDA_ENV if your env name differs from the one created by
+# environment.yml.
+CONDA_ENV="${CONDA_ENV:-a2d2}"
+if [ -n "${CONDA_ROOT:-}" ]; then
+ source "$CONDA_ROOT/bin/activate" "$CONDA_ENV"
+elif command -v conda >/dev/null 2>&1; then
+ source "$(conda info --base)/bin/activate" "$CONDA_ENV"
+else
+ echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2
+ exit 1
+fi
+
+# --- Distributed / NCCL setup (single node, intra-node NVLink) --------------
+ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -E "ens|eth|enp|bond" | head -1 | awk '{print $2}')
+if [ -z "$ETH_IFACE" ]; then
+ ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -v "ibp" | head -1 | awk '{print $2}')
+fi
+export NCCL_IB_DISABLE=1
+export NCCL_SOCKET_FAMILY=AF_INET
+export NCCL_SOCKET_IFNAME=$ETH_IFACE
+export NCCL_P2P_LEVEL=NVL
+
+export MASTER_ADDR=$(scontrol show hostnames "${SLURM_NODELIST:-$(hostname)}" | head -n 1)
+export MASTER_PORT=$(shuf -i 15000-59999 -n 1)
+export NODE_RANK=${SLURM_NODEID:-0}
+
+echo "=== a2d2 peptide pretraining (dgx-b200) ==="
+echo "Job ID: ${SLURM_JOB_ID:-local} Node: ${SLURM_NODELIST:-$(hostname)} GPUs: $DEVICES Tasks: $NTASKS"
+
+# --task pep makes train.py load config_pep.yaml; the hydra overrides pin
+# devices/nodes to the SLURM allocation so the two never drift apart.
+srun --ntasks-per-node=$NTASKS python train.py --task pep \
+ training.devices=$DEVICES \
+ training.nodes=$NODES
+
+conda deactivate
diff --git a/a2d2_pep/train.py b/a2d2_pep/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..9823bf0fd4ced62852356a7b6ce5f83d1156301d
--- /dev/null
+++ b/a2d2_pep/train.py
@@ -0,0 +1,216 @@
+import torch
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.callbacks import ModelCheckpoint
+import os
+import sys
+import argparse
+import hydra
+from omegaconf import OmegaConf
+from datetime import datetime
+# Directory containing this file and the config_*.yaml files (used by Hydra below).
+CONFIG_DIR = os.path.dirname(os.path.abspath(__file__))
+# Add the repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve.
+sys.path.insert(0, os.path.dirname(CONFIG_DIR))
+
+import wandb
+from lightning_modules import AnyOrderInsertionFlowModule
+
+
+torch.set_printoptions(threshold=10_000)
+torch.set_float32_matmul_precision("high")
+
+# Disable DDP optimizer due to incompatibility with flex_attention higher-order ops
+torch._dynamo.config.optimize_ddp = False
+
+def train(config):
+ wandb_logger = None
+
+ # set the random seed
+ pl.seed_everything(42)
+ torch.manual_seed(42)
+
+ # Only initialize wandb on rank 0 to avoid multiple runs
+ if int(os.environ.get("LOCAL_RANK", 0)) == 0:
+ wandb.init(
+ project=config.wandb.project,
+ name=config.wandb.name,
+ config=OmegaConf.to_container(config, resolve=True), # Convert to dict
+ dir=config.wandb.path
+ )
+ wandb_logger = WandbLogger(
+ project=wandb.run.project,
+ name=wandb.run.name,
+ log_model=False, # Disable checkpoint uploading to save disk space
+ )
+
+ # Modify config to add timestamp to checkpoint directory
+ OmegaConf.set_struct(config, False)
+ time_string = datetime.now().strftime("%Y%m%d-%H%M%S")
+ config.training.checkpoint_dir = os.path.join(
+ config.training.checkpoint_dir, time_string
+ )
+ OmegaConf.set_struct(config, True)
+
+ # Create checkpoint directory
+ os.makedirs(config.training.checkpoint_dir, exist_ok=True)
+
+ # Setup data module - check if using HuggingFace dataset
+ if hasattr(config, 'hf_dataset'):
+ # Imported lazily: the HF/SAFE path is only used by the molecule configs,
+ # which keep mol_dataset.py (and its `safe` dependency) in a2d2_mol/.
+ from mol_dataset import setup_hf_data_and_update_config
+ print(f"Using HuggingFace dataset: {config.hf_dataset.name}")
+ data_module = setup_hf_data_and_update_config(
+ config,
+ dataset_name=config.hf_dataset.name,
+ smiles_column=config.hf_dataset.get('smiles_column', 'smiles')
+ )
+ else:
+ # Imported lazily: the local (arrow) path is used by the peptide config,
+ # which keeps dataloading_for_dynamic_batching.py in a2d2_pep/.
+ from data.dataloading_for_dynamic_batching import setup_data_and_update_config
+ print("Using local dataset")
+ data_module = setup_data_and_update_config(config)
+
+ module = AnyOrderInsertionFlowModule(config)
+
+ # Initialize trainer
+
+ # Configure trainer arguments
+ # Map torch_dtype to Lightning precision
+ dtype_str = config.model.get('torch_dtype', 'bfloat16')
+ precision_map = {
+ 'float32': '32-true',
+ 'float16': '16-mixed',
+ 'bfloat16': 'bf16-mixed'
+ }
+ precision = precision_map.get(dtype_str, 'bf16-mixed')
+
+ trainer_kwargs = dict(
+ num_nodes=config.training.nodes,
+ accelerator="gpu",
+ devices=config.training.devices,
+ strategy="ddp",
+ precision=precision,
+ accumulate_grad_batches=(
+ config.training.batch_size
+ // (
+ config.training.per_gpu_batch_size
+ * config.training.nodes
+ * config.training.devices
+ )
+ ),
+ log_every_n_steps=10,
+ enable_checkpointing=True,
+ default_root_dir=config.training.checkpoint_dir,
+ gradient_clip_val=1.0,
+ )
+ # Only one of max_steps or max_epochs will be used
+ if config.training.max_steps is not None:
+ trainer_kwargs["max_steps"] = config.training.max_steps
+ elif config.training.num_epochs is not None:
+ trainer_kwargs["max_epochs"] = config.training.num_epochs
+ config.training.max_steps = config.training.max_steps
+ else:
+ raise ValueError(
+ "Either max_steps or num_epochs must be specified in the config"
+ )
+
+ if config.training.warmup_steps is None:
+ config.training.warmup_steps = int(config.training.max_steps * 0.01)
+
+ # Add ModelCheckpoint callback to save the checkpoint when validation loss is at a new low
+ checkpoint_callback = ModelCheckpoint(
+ monitor="train/total_loss",
+ mode="min",
+ save_top_k=config.training.save_top_k,
+ save_last=True,
+ filename="epoch-{epoch:02d}-train_loss-{train/total_loss:.4f}",
+ dirpath=config.training.checkpoint_dir,
+ # Don't use val_loss in filename for periodic saves - causes failures when val doesn't run
+ auto_insert_metric_name=False
+ )
+
+ # Add separate callback for periodic saves (no val_loss dependency). Use
+ # step-based saves for streaming datasets (save_every_n_steps) and epoch-based
+ # saves otherwise (save_every_n_epochs); whichever the config provides.
+ save_every_n_steps = config.training.get('save_every_n_steps', None)
+ save_every_n_epochs = config.training.get('save_every_n_epochs', None)
+ if save_every_n_steps is not None:
+ periodic_checkpoint_callback = ModelCheckpoint(
+ save_top_k=-1, # Save all periodic checkpoints
+ filename="step-{step:08d}",
+ dirpath=config.training.checkpoint_dir,
+ every_n_train_steps=save_every_n_steps,
+ auto_insert_metric_name=False
+ )
+ elif save_every_n_epochs is not None:
+ periodic_checkpoint_callback = ModelCheckpoint(
+ save_top_k=-1, # Save all periodic checkpoints
+ filename="epoch-{epoch:02d}",
+ dirpath=config.training.checkpoint_dir,
+ every_n_epochs=save_every_n_epochs,
+ auto_insert_metric_name=False
+ )
+ else:
+ raise ValueError(
+ "Either save_every_n_steps or save_every_n_epochs must be specified in the config"
+ )
+
+ trainer_kwargs["callbacks"] = [checkpoint_callback, periodic_checkpoint_callback]
+
+ if wandb_logger is not None:
+ trainer_kwargs["logger"] = wandb_logger
+
+ trainer = pl.Trainer(**trainer_kwargs)
+
+ # Train the model
+ ckpt_path = None
+ if "resume_path" in config.training:
+ ckpt_path = config.training.resume_path
+
+ trainer.fit(module,
+ datamodule=data_module,
+ ckpt_path=ckpt_path)
+
+ # Only finish wandb on rank 0
+ if int(os.environ.get("LOCAL_RANK", 0)) == 0:
+ wandb.finish()
+
+
+if __name__ == '__main__':
+ # Parse arguments to get config name
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config_name', type=str, default='config',
+ help='Name of the config file to use')
+ parser.add_argument('--task', type=str, default=None,
+ help='Task name (uses config_{task}.yaml)')
+
+ # Parse known args (hydra will handle the rest)
+ args, unknown = parser.parse_known_args()
+
+ # Determine config name from task or config_name
+ if args.task:
+ config_name = f'config_{args.task}'
+ else:
+ config_name = args.config_name
+
+ print(f"Using config: {config_name}.yaml")
+
+ # Add config name to Hydra overrides (this persists across DDP subprocesses)
+ if '--config-name' not in unknown and f'--config-name={config_name}' not in unknown:
+ unknown.insert(0, f'--config-name={config_name}')
+
+ # Reconstruct sys.argv for hydra
+ sys.argv = [sys.argv[0]] + unknown
+
+ # Define main function with default config (will be overridden by command line)
+ @hydra.main(version_base=None,
+ config_path=CONFIG_DIR,
+ config_name='config')
+ def main(config):
+ """Main entry point for training"""
+ train(config)
+
+ main()
\ No newline at end of file
diff --git a/assets/a2d2.gif b/assets/a2d2.gif
new file mode 100644
index 0000000000000000000000000000000000000000..2d30aa8009f5e608aa559974ef07a7a7ed0a98fa
--- /dev/null
+++ b/assets/a2d2.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:178ca7850ca39365492fea70cfc5e4f2e8653ceeda9a13dcd0438af61e1a83bb
+size 7826144
diff --git a/demo/quality_inference_demo.ipynb b/demo/quality_inference_demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..d1b22bdb1fd856d37cd2ec573bdc2c35cee27870
--- /dev/null
+++ b/demo/quality_inference_demo.ipynb
@@ -0,0 +1,785 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# **A2D2 Unmasking + Insertion Quality Inference on a Toy Example**\n",
+ "\n",
+ "This notebook is a self-contained, runnable illustration of the two quantities at the heart of **A2D2: Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding**, and how they are used at *inference* time to decode adaptively:\n",
+ "\n",
+ "| Quantity | Definition | Inference use |\n",
+ "|---|---|---|\n",
+ "| **Unmasking quality** $\\mu_\\star^\\ell(\\boldsymbol{y})=p(\\boldsymbol{y}^\\ell=\\boldsymbol{x}_1^{s_t[\\ell]}\\mid \\boldsymbol{y})=f_\\theta(\\tilde{\\boldsymbol{x}}_t,t)[\\ell,\\boldsymbol{x}_1^\\ell]$ | probability an unmasked token matches the true token given context | **re-mask** low-quality tokens so only mutually high-quality tokens stay unmasked in parallel |\n",
+ "| **Insertion quality** $\\nu_\\star^\\ell(\\boldsymbol{y})=\\sum_{\\boldsymbol{v}\\in\\mathcal{S}_\\ell}p(\\boldsymbol{y}^\\ell=\\boldsymbol{v}\\mid \\boldsymbol{y})$ | probability an inserted mask decodes to *some* true token belonging in its gap | **drop** low-quality insertions that would otherwise cause length/spacing errors |\n",
+ "\n",
+ "In this notebook, we build a tiny, fully-analytic **toy \"language\"** whose true denoising posterior $f_\\theta$ is known in closed form. That lets us (1) compute $\\mu_\\star,\\nu_\\star$ exactly, (2) train the light-weight predictors $\\mu_\\phi,\\nu_\\phi$ with the paper's **UQL/IQL** BCE losses and watch them recover the truth, and (3) drive the **actual** A2D2 inference routines `apply_schedule_aware_remasking` / `apply_schedule_aware_insertion` (from `remasking_scheduleaware.py`) on hand-crafted states."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 0. Setup\n",
+ "\n",
+ "The two schedule-aware routines below are from the repo's `remasking_scheduleaware.py`. They are the exact code the A2D2 sampler calls. We only feed them a tiny toy `model` exposing the same interface (`model.planner(seq, t)` → confidences, `model.interpolant.{unmask,insertion}_schedule.at(t)`)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "torch.manual_seed(0)\n",
+ "np.random.seed(0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "vendored A2D2 schedule-aware routines loaded\n"
+ ]
+ }
+ ],
+ "source": [
+ "def _debias_by_token(conf, tokens, clean_index):\n",
+ " \"\"\"Subtract the per-token-identity mean confidence so a token's score reflects\n",
+ " how confident it is RELATIVE to others of the same type (see repo docstring).\"\"\"\n",
+ " if clean_index.sum() == 0:\n",
+ " return conf\n",
+ " V = int(tokens.max().item()) + 1\n",
+ " flat_tok = tokens[clean_index]\n",
+ " flat_conf = conf[clean_index]\n",
+ " sums = torch.zeros(V, device=conf.device, dtype=conf.dtype).scatter_add_(0, flat_tok, flat_conf)\n",
+ " cnts = torch.zeros(V, device=conf.device, dtype=conf.dtype).scatter_add_(\n",
+ " 0, flat_tok, torch.ones_like(flat_conf))\n",
+ " mean_v = sums / cnts.clamp(min=1.0)\n",
+ " baseline = mean_v[tokens]\n",
+ " return torch.where(clean_index, conf - baseline, conf)\n",
+ "\n",
+ "\n",
+ "def apply_schedule_aware_insertion(model, xt_tmp, new_xt, t, dt, ext, mask, pad,\n",
+ " max_length, orig_mask, new_pos_orig, quality_threshold=1):\n",
+ " \"\"\"Remove low-quality insertions (insertion_conf < threshold) while keeping the\n",
+ " sequence length no shorter than the interpolant schedule allows.\"\"\"\n",
+ " device = xt_tmp.device\n",
+ " batch_size, L = xt_tmp.shape\n",
+ " total_ext = ext.sum(dim=1)\n",
+ " if total_ext.sum() == 0:\n",
+ " return xt_tmp\n",
+ "\n",
+ " t_next = t + dt\n",
+ " planner_out = model.planner(xt_tmp, t)\n",
+ " insertion_conf = planner_out.get(\"insertion_conf\", None)\n",
+ " if insertion_conf is None:\n",
+ " return xt_tmp\n",
+ " insertion_conf = insertion_conf.squeeze(-1) # (B, L)\n",
+ "\n",
+ " current_length_after = xt_tmp.ne(pad).sum(dim=1).float()\n",
+ " expected_progress = model.interpolant.insertion_schedule.at(t_next)\n",
+ " estimated_final_length = current_length_after / (expected_progress.clamp(min=0.1))\n",
+ " expected_length = estimated_final_length * expected_progress\n",
+ "\n",
+ " valid_b, valid_l = orig_mask.nonzero(as_tuple=True)\n",
+ " valid_p = new_pos_orig[valid_b, valid_l].long().clamp_(0, L - 1)\n",
+ " is_original = torch.zeros_like(xt_tmp, dtype=torch.bool)\n",
+ " is_original[valid_b, valid_p] = True\n",
+ " inserted_positions = (xt_tmp == mask) & ~is_original\n",
+ "\n",
+ " candidates = inserted_positions & (insertion_conf < quality_threshold)\n",
+ " num_bad = candidates.sum(dim=1)\n",
+ " min_length = expected_length.long().clamp(min=1)\n",
+ " max_removable = (current_length_after.long() - min_length).clamp(min=0)\n",
+ " length_after_removal = current_length_after.long() - num_bad\n",
+ " schedule_violates = length_after_removal < min_length\n",
+ " k_per_row = torch.where(schedule_violates, max_removable, num_bad)\n",
+ " k_per_row = torch.where(num_bad > 0, k_per_row, torch.zeros_like(k_per_row))\n",
+ "\n",
+ " if not candidates.any():\n",
+ " return xt_tmp\n",
+ "\n",
+ " neg_inf = torch.tensor(float('-inf'), device=device, dtype=insertion_conf.dtype)\n",
+ " scores = torch.where(candidates, -insertion_conf, neg_inf)\n",
+ " _, sorted_indices = scores.sort(dim=1, descending=True)\n",
+ " positions = torch.arange(L, device=device).unsqueeze(0)\n",
+ " keep_in_topk = positions < k_per_row.unsqueeze(1)\n",
+ " final_bad = torch.zeros_like(candidates)\n",
+ " final_bad.scatter_(1, sorted_indices, keep_in_topk)\n",
+ " if not final_bad.any():\n",
+ " return xt_tmp\n",
+ "\n",
+ " sort_key = final_bad.long()\n",
+ " _, perm = torch.sort(sort_key, dim=1, stable=True)\n",
+ " xt_tmp = torch.gather(xt_tmp, 1, perm)\n",
+ " num_keep = (~final_bad).sum(dim=1)\n",
+ " tail_mask = positions >= num_keep.unsqueeze(1)\n",
+ " xt_tmp = torch.where(tail_mask, torch.full_like(xt_tmp, pad), xt_tmp)\n",
+ " return xt_tmp\n",
+ "\n",
+ "\n",
+ "def apply_schedule_aware_remasking(model, new_xt, t, dt, remasking_conf, clean_index,\n",
+ " mask, neg_inf, batch_size, debias=False):\n",
+ " \"\"\"Re-mask the lowest-confidence clean tokens so the unmasked count tracks the\n",
+ " unmask schedule (FlexMDM-style schedule-aware remasking).\"\"\"\n",
+ " t_next = t + dt\n",
+ " num_clean = clean_index.sum(dim=1)\n",
+ " current_seq_len = (num_clean + (new_xt == mask).sum(dim=1)).float()\n",
+ " expected_unmasked_frac = model.interpolant.unmask_schedule.at(t_next)\n",
+ " expected_num_clean = expected_unmasked_frac * current_seq_len\n",
+ " masks_to_add = (num_clean.float() - expected_num_clean).round().long()\n",
+ "\n",
+ " k_per_row = torch.minimum(masks_to_add.clamp(min=0), num_clean)\n",
+ " if k_per_row.sum() == 0:\n",
+ " return new_xt\n",
+ "\n",
+ " conf = _debias_by_token(remasking_conf, new_xt, clean_index) if debias else remasking_conf\n",
+ " remasking_score_temp = -1.0 * conf\n",
+ " remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf)\n",
+ " _, sorted_indices = remasking_score_temp.sort(dim=1, descending=True)\n",
+ " L = remasking_score_temp.shape[1]\n",
+ " positions = torch.arange(L, device=new_xt.device).unsqueeze(0)\n",
+ " keep_in_topk = positions < k_per_row.unsqueeze(1)\n",
+ " to_mask = torch.zeros_like(clean_index)\n",
+ " to_mask.scatter_(1, sorted_indices, keep_in_topk)\n",
+ " new_xt = torch.where(to_mask, torch.full_like(new_xt, mask), new_xt)\n",
+ " return new_xt\n",
+ "\n",
+ "print(\"vendored A2D2 schedule-aware routines loaded\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. A toy \"language\" with a known denoising posterior $f_\\theta$\n",
+ "\n",
+ "A real any-length MDM parameterizes an **unmasking posterior** $f_\\theta(\\boldsymbol{x}_t,t)[\\ell]\\in\\Delta^V$. Here we replace it with an analytic stand-in: a 5-letter **bigram** model with a peaked transition matrix $T$. Given a partially-masked sequence, the posterior at a masked position is the (normalized) product of the constraints from its clean left/right neighbours, which is exactly the kind of context-dependent posterior $f_\\theta$ learns, but in closed form so we know $\\mu_\\star,\\nu_\\star$ exactly.\n",
+ "\n",
+ "Special tokens: `MASK` and `PAD` (matching the sampler's two reserved ids)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "a clean target x1 : BCDEABCDEA\n",
+ "posterior at a masked middle position (neighbours clean):\n",
+ " state: BCDEA_CDEA f_theta[5] = [0.009 0.963 0.009 0.009 0.009]\n"
+ ]
+ }
+ ],
+ "source": [
+ "LETTERS = [\"A\", \"B\", \"C\", \"D\", \"E\"]\n",
+ "V = len(LETTERS)\n",
+ "MASK, PAD = V, V + 1 # reserved ids, as in the sampler\n",
+ "NTOK = V + 2\n",
+ "\n",
+ "# Peaked bigram transition matrix -> context is informative.\n",
+ "T = np.array([\n",
+ " [0.05, 0.70, 0.10, 0.10, 0.05], # A -> mostly B\n",
+ " [0.05, 0.05, 0.75, 0.10, 0.05], # B -> mostly C\n",
+ " [0.10, 0.05, 0.05, 0.70, 0.10], # C -> mostly D\n",
+ " [0.10, 0.10, 0.05, 0.05, 0.70], # D -> mostly E\n",
+ " [0.70, 0.05, 0.10, 0.10, 0.05], # E -> mostly A\n",
+ "], dtype=np.float64)\n",
+ "T /= T.sum(1, keepdims=True)\n",
+ "pi0 = np.array([0.4, 0.2, 0.2, 0.1, 0.1]) # initial-token distribution\n",
+ "\n",
+ "def sample_seq(n):\n",
+ " \"\"\"Draw a clean sequence x_1 ~ p_target from the bigram chain.\"\"\"\n",
+ " s = [np.random.choice(V, p=pi0)]\n",
+ " for _ in range(n - 1):\n",
+ " s.append(np.random.choice(V, p=T[s[-1]]))\n",
+ " return s\n",
+ "\n",
+ "def posterior_at(seq, ell):\n",
+ " \"\"\"Toy f_theta(seq, t)[ell] in Delta^V: bigram constraints from clean neighbours.\"\"\"\n",
+ " factor = np.ones(V)\n",
+ " if ell - 1 >= 0 and seq[ell - 1] < V: # clean left neighbour a -> T[a, :]\n",
+ " factor *= T[seq[ell - 1], :]\n",
+ " if ell + 1 < len(seq) and seq[ell + 1] < V: # clean right neighbour c -> T[:, c]\n",
+ " factor *= T[:, seq[ell + 1]]\n",
+ " if factor.sum() == 0:\n",
+ " factor = np.ones(V)\n",
+ " return factor / factor.sum()\n",
+ "\n",
+ "def decode(seq):\n",
+ " return \"\".join(\"_\" if t == MASK else (\".\" if t == PAD else LETTERS[t]) for t in seq)\n",
+ "\n",
+ "x1 = sample_seq(10)\n",
+ "print(\"a clean target x1 :\", decode(x1))\n",
+ "print(\"posterior at a masked middle position (neighbours clean):\")\n",
+ "probe = x1.copy(); probe[5] = MASK\n",
+ "print(\" state:\", decode(probe), \" f_theta[5] =\", np.round(posterior_at(probe, 5), 3))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Unmasking quality $\\mu_\\star$\n",
+ "\n",
+ "Given a clean target $\\boldsymbol{x}_1$, an interpolant sample $\\tilde{\\boldsymbol{x}}_t$ (partially masked), and a candidate unmasking $\\boldsymbol{y}$ with $\\boldsymbol{y}^\\ell \\sim f_\\theta(\\tilde{\\boldsymbol{x}}_t,t)[\\ell]$, the **unmasking quality** is the posterior probability of the *true* token:\n",
+ "\n",
+ "$$\\mu_\\star^\\ell(\\boldsymbol{y}) = p\\!\\left(\\boldsymbol{y}^\\ell=\\boldsymbol{x}_1^{s_t[\\ell]}\\mid \\boldsymbol{y}\\right)=f_\\theta(\\tilde{\\boldsymbol{x}}_t,t)[\\ell,\\boldsymbol{x}_1^\\ell]$$\n",
+ "\n",
+ "High when the context pins down the token; low when it doesn't. Below, the middle masked position (both neighbours clean) is high-quality, while an isolated masked position (no clean neighbours) has a near-uniform posterior → low quality."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "clean x1 : CDEDABAC\n",
+ "masked x~t: C_ED_B_C\n",
+ "masked isolated: ___DABAC\n",
+ "\n",
+ " pos 1: true D mu* = 0.966 posterior=[0.01 0. 0.01 0.97 0.01]\n",
+ " pos 4: true A mu* = 0.596 posterior=[0.6 0.04 0.02 0.04 0.3 ]\n",
+ " pos 6: true A mu* = 0.056 posterior=[0.06 0.42 0.42 0.06 0.06]\n",
+ " isolated pos 1: true D mu* = 0.200 (near-uniform -> low quality)\n"
+ ]
+ }
+ ],
+ "source": [
+ "x1 = sample_seq(8)\n",
+ "xt = x1.copy()\n",
+ "for ell in [1, 4, 6]:\n",
+ " xt[ell] = MASK\n",
+ "print(\"clean x1 :\", decode(x1))\n",
+ "print(\"masked x~t:\", decode(xt))\n",
+ "print(\"masked isolated: \", end=\"\")\n",
+ "iso = [MASK if i in (0, 1, 2) else x1[i] for i in range(len(x1))] # left token has no clean nbrs\n",
+ "print(decode(iso))\n",
+ "print()\n",
+ "for ell in [1, 4, 6]:\n",
+ " p = posterior_at(xt, ell)\n",
+ " print(f\" pos {ell}: true {LETTERS[x1[ell]]} mu* = {p[x1[ell]]:.3f} posterior={np.round(p,2)}\")\n",
+ "p_iso = posterior_at(iso, 1)\n",
+ "print(f\" isolated pos 1: true {LETTERS[x1[1]]} mu* = {p_iso[x1[1]]:.3f} (near-uniform -> low quality)\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.1 Training the unmasking-quality predictor $\\mu_\\phi$ (UQL)\n",
+ "\n",
+ "At inference we have no $\\boldsymbol{x}_1$, so A2D2 trains a light head $\\mu_\\phi$ to predict quality from the *post-unmasking* sequence $\\boldsymbol{y}$ alone, via the **Unmasking Quality Loss** (binary cross-entropy against the 0/1 hit indicator):\n",
+ "\n",
+ "$$\\mathcal{L}_{\\text{UQL}}(\\phi)=\\mathbb{E}\\Big[\\textstyle\\sum_{\\ell\\in\\mathcal{M}}\\text{BCE}\\big(\\mathbf{1}[\\boldsymbol{y}^\\ell=\\boldsymbol{x}_1^{s_t[\\ell]}],\\,\\mu_\\phi^\\ell(\\boldsymbol{y})\\big)\\Big].$$\n",
+ "\n",
+ "The unique minimizer is the true quality: $\\mu_{\\phi^\\star}=\\mu_\\star$. Note the head sees the *still-masked* neighbours (the realistic decode state), so the BCE minimizer is $P(\\text{correct}\\mid \\boldsymbol{y}) = f_\\theta(\\cdot)[\\ell,\\boldsymbol{y}^\\ell]$. We verify $\\mu_\\phi$ recovers it."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " step 500 UQL running avg = 0.5410 (irreducible floor = 0.4777)\n",
+ " step 1000 UQL running avg = 0.4752 (irreducible floor = 0.4777)\n",
+ " step 1500 UQL running avg = 0.4573 (irreducible floor = 0.4777)\n",
+ " step 2000 UQL running avg = 0.4625 (irreducible floor = 0.4777)\n",
+ "\n",
+ "running loss 0.4625 sits at the floor 0.4777 -> converged\n",
+ "corr(mu_phi, mu_star) on held-out states = 0.959\n"
+ ]
+ }
+ ],
+ "source": [
+ "WIN = 2 # context window (each side) fed to the head\n",
+ "\n",
+ "def window_feat(seq, ell):\n",
+ " feats = []\n",
+ " for d in range(-WIN, WIN + 1):\n",
+ " j = ell + d\n",
+ " tok = seq[j] if 0 <= j < len(seq) else PAD\n",
+ " feats.append(F.one_hot(torch.tensor(tok), NTOK).float())\n",
+ " return torch.cat(feats)\n",
+ "\n",
+ "def make_unmask_example():\n",
+ " n = np.random.randint(6, 11)\n",
+ " x1 = sample_seq(n)\n",
+ " ctx = x1.copy()\n",
+ " for i in range(n): # interpolant: mask ~half the tokens\n",
+ " if np.random.rand() < 0.5:\n",
+ " ctx[i] = MASK\n",
+ " ell = np.random.randint(n)\n",
+ " p = posterior_at(ctx, ell)\n",
+ " samp = int(np.random.choice(V, p=p)) # y^ell ~ f_theta(ctx)[ell]\n",
+ " y = ctx.copy(); y[ell] = samp # predictor sees still-masked neighbours\n",
+ " label = float(samp == x1[ell]) # 1[y^ell == x_1]\n",
+ " return window_feat(y, ell), torch.tensor([label]), p[samp] # p[samp] = true quality of y\n",
+ "\n",
+ "class QualityHead(nn.Module):\n",
+ " def __init__(self, din):\n",
+ " super().__init__()\n",
+ " self.net = nn.Sequential(nn.Linear(din, 64), nn.GELU(), nn.Linear(64, 1))\n",
+ " def forward(self, x):\n",
+ " return torch.sigmoid(self.net(x))\n",
+ " \n",
+ "din = (2 * WIN + 1) * NTOK\n",
+ "\n",
+ "\"\"\"\n",
+ "The UQL label 1[y^l == x_1] is a coin flip with probability mu_star, so even a PERFECT\n",
+ "mu_phi cannot beat the entropy floor E[H(mu_star)]. The per-batch loss is also noisy\n",
+ "(a fresh random minibatch each step), so we report an EMA-smoothed running loss + the floor. \n",
+ "\n",
+ "Convergence means the running loss sits AT the floor, not that it reaches 0\n",
+ "\"\"\" \n",
+ "_, _, q_pool = zip(*[make_unmask_example() for _ in range(4000)])\n",
+ "uql_floor = binary_entropy(q_pool)\n",
+ "\n",
+ "mu_phi = QualityHead(din)\n",
+ "opt = torch.optim.Adam(mu_phi.parameters(), lr=2e-3)\n",
+ "run = None\n",
+ "for step in range(2000):\n",
+ " feats, labels, _ = zip(*[make_unmask_example() for _ in range(64)])\n",
+ " loss = F.binary_cross_entropy(mu_phi(torch.stack(feats)), torch.stack(labels))\n",
+ " opt.zero_grad(); loss.backward(); opt.step()\n",
+ " run = loss.item() if run is None else 0.98 * run + 0.02 * loss.item()\n",
+ " if (step + 1) % 500 == 0:\n",
+ " print(f\" step {step+1:4d} UQL running avg = {run:.4f}\")\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " fs, _, qs = zip(*[make_unmask_example() for _ in range(800)])\n",
+ " pred = mu_phi(torch.stack(fs)).squeeze(-1).numpy()\n",
+ "qs = np.array(qs)\n",
+ "print(f\"corr(mu_phi, mu_star) on held-out states = {np.corrcoef(pred, qs)[0,1]:.3f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZkAAAGZCAYAAABbpUzOAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAlEJJREFUeJzs3Wd0FFUfgPFn+2bTewNC6J1AkKaIIkWKKChFRJSiYkNAVBAbyAsCgtgAGyhKFREsWFCkSJFi6B0CAdJ72WTrvB9CVlJJ2U02cH/ncHRmZ+7c2STzn9tlkiRJCIIgCIIDyGs6A4IgCMLNSwQZQRAEwWFEkBEEQRAcRgQZQRAEwWFEkBEEQRAcRgQZQRAEwWFEkBEEQRAcRgQZQRAEwWFEkBEEQRAcRgQZO3rrrbeQyWQkJyeX+HmrVq246667iu2/fPkyzz33HA0bNkSr1eLt7U2PHj1Yu3ZtsWMvXryITCbj3XfftXf2hRKsWrWKRYsWFdsvfg72U79+fQYMGHDD42QyGW+99ZbjM+Rk7rrrrmLPjaLfxYkTJ3jrrbe4ePFiteatPJQ1nYFb3a5duxgwYABubm689NJLtGnThoyMDNatW8fw4cPZvHkzX375JTKZrKazektatWoVx44dY+LEiTWdlVvenj17qFOnTk1nwykU/S5OnDjBjBkzuOuuu6hfv37NZawEIsjUoPT0dAYPHoynpyf//PMPgYGBts/uv/9+2rRpw9SpU4mIiGDSpEkOyYNer0en0zkkbWdws9/fraRz5841nQWnUZu+C1FdVoM+//xzEhMTeeeddwoFmAIvv/wyzZo1Y86cOZjN5ipf76677qJVq1bs2LGDrl27otPpGDNmDACZmZlMmTKF8PBw1Go1oaGhTJw4kZycnEJpWK1WPvzwQyIiInBxccHLy4vOnTvzww8/FDpm3rx5NGvWDI1GQ0BAAKNGjeLKlSu2YyZOnIirqyuZmZnF8jls2DACAwMxmUy2fWvXrqVLly64urri5uZGnz59iIqKKnTe448/jpubG0ePHqV37964u7tzzz33ABAVFcWAAQMICAhAo9EQEhJC//79C+WppO/r559/5tKlS8hkMtu/ohYuXEh4eDhubm506dKFvXv3FjvmwIEDDBw4EB8fH7RaLe3atWPdunWlXrvAtm3bkMlkbNu2rdD+guq6L7/8stj9nzt3jn79+uHm5kbdunV58cUXMRgMxc6dP38+c+fOpX79+ri4uHDXXXdx5swZTCYTU6dOJSQkBE9PTwYNGkRiYmKh669du5bevXsTHByMi4sLzZs3Z+rUqcV+Xy5cuMDw4cMJCQlBo9EQGBjIPffcw6FDh8q878WLF6NUKnnzzTdt+4pWERWU8P/66y+efvpp/Pz88PX1ZfDgwcTGxhZKz2Aw8OKLLxIUFIROp+POO+/k4MGD1K9fn8cff7zMvADExsYydOhQ3N3d8fT0ZNiwYezdu7fYz6Ckqi3I/9kULWHMmDGDTp064ePjg4eHB+3bt+eLL76gPHMWX/9dfPnllwwZMgSAu+++2/Z7+uWXX/L222+jVCq5fPlysTTGjBmDr68veXl5N7xeVYggU4O2bNmCQqHgvvvuK/FzmUzGwIEDSUpKKvZAray4uDhGjhzJiBEj2Lx5M8888wx6vZ7u3bvz1VdfMWHCBH755RdeeeUVvvzySwYOHFjol/7xxx/nhRde4LbbbmPt2rWsWbOGgQMHFqoLfvrpp3nllVfo1asXP/zwA2+//Ta//vorXbt2tbVXjRkzBr1eX+xBm56ezqZNmxg5ciQqlQqA2bNn8/DDD9OiRQvWrVvH119/TVZWFt26dePEiROFzjcajQwcOJAePXqwadMmZsyYQU5ODr169SIhIYGPP/6YLVu2sGjRIurVq0dWVlap39XixYu5/fbbCQoKYs+ePbZ/17s+vZUrV5KTk0O/fv3IyMiwHfPXX39x++23k56eztKlS9m0aRMREREMGzas0APKHkwmEwMHDuSee+5h06ZNjBkzhvfee4+5c+cWO/bjjz9m165dfPzxx3z++eecOnWK++67j7Fjx5KUlMSyZcuYN28ef/zxB+PGjSt07tmzZ+nXrx9ffPEFv/76KxMnTmTdunXFfpf79evHwYMHmTdvHlu2bGHJkiW0a9eO9PT0EvMvSRJTpkxh4sSJfP7558yYMeOG9zxu3DhUKhWrVq1i3rx5bNu2jZEjRxY6ZvTo0SxatIjRo0ezadMmHnzwQQYNGlRqPq6Xm5tLz549+f3335kzZw7ffvstQUFBDBs27IbnluXixYs89dRTrFu3jg0bNjB48GCef/553n777Qql079/f2bPng3k/0wLfk/79+/PU089hVKp5JNPPil0TmpqKmvWrGHs2LFotdoq3ccNSYLdvPnmmxIgJSUllfh5y5Ytpe7du9u2mzVrJgUFBZWZ5pIlSyRA+vbbbyVJkqTo6GgJkObPn1/h/HXv3l0CpD///LPQ/jlz5khyuVzav39/of3r16+XAGnz5s2SJEnSjh07JECaPn16qdc4efKkBEjPPPNMof3//POPBEivvvqqbV/79u2lrl27Fjpu8eLFEiAdPXpUkiRJiomJkZRKpfT8888XOi4rK0sKCgqShg4datv32GOPSYC0bNmyQsceOHBAAqSNGzeWmu/S9O/fXwoLCyu2v+Dn0Lp1a8lsNtv279u3TwKk1atX2/Y1a9ZMateunWQymQqlMWDAACk4OFiyWCylXv+vv/6SAOmvv/4q8frLly+37Su4/3Xr1hU6tl+/flLTpk2Lndu2bdtC1160aJEESAMHDix0/sSJEyVAysjIKDGPVqtVMplM0vbt2yVAOnz4sCRJkpScnCwB0qJFi0q9P0mSpLCwMKl///6SXq+XHnzwQcnT01P6448/ih0HSG+++aZte/ny5SX+rs2bN08CpLi4OEmSJOn48eMSIL3yyiuFjlu9erUESI899liZ+Sv4G9y0aVOh/U888USxn0H37t0L/Y0XeOyxx0r8PSpgsVgkk8kkzZw5U/L19ZWsVmuZaRb9Lr799tsSf08Krh0QECAZDAbbvrlz50pyuVyKjo4uNU/2IkoyTk66VoqwV8N/Qc+16/3000+0atWKiIgIzGaz7V+fPn0KVdX88ssvADz77LOlpv/XX38BFKuC6NixI82bN+fPP/+07Rs9ejS7d+/m9OnTtn3Lly/ntttuo1WrVgD89ttvmM1mRo0aVShvWq2W7t27F6tGAnjwwQcLbTdq1Ahvb29eeeUVli5dWqz0UxX9+/dHoVDYttu0aQPApUuXADh37hynTp3ikUceASh0D/369SMuLq7Q/VeVTCYrVppo06aNLT/X69evH3L5f4+A5s2b2+7pegX7Y2JibPsuXLjAiBEjCAoKQqFQoFKp6N69OwAnT54EwMfHh4YNGzJ//nwWLlxIVFQUVqu1xHynpKTQo0cP9u3bx99//22r5iyPgQMHFrtf+O9nsH37dgCGDh1a6LiHHnoIpfLGzdJ//fUX7u7uxa4zYsSIcuexJFu3bqVnz554enravsM33niDlJSUYtWTVfHCCy+QmJjIt99+C+RXZy9ZsoT+/ftXSycBEWTsqOAX1mKxlPi52Wy2VQEB1KtXj6SkpGL12NcrqIaqW7euXfIYHBxcbF9CQgJHjhxBpVIV+ufu7o4kSbYqrqSkJBQKBUFBQaWmn5KSUup1QkJCbJ8DPPLII2g0GluV0YkTJ9i/fz+jR48ulDeA2267rVj+1q5dW6y7uE6nw8PDo9A+T09Ptm/fTkREBK+++iotW7YkJCSEN998s1C7T2X4+voW2tZoNEB+Fcv1+Z8yZUqx/D/zzDMApXZ5rwydTles+kOj0ZRY7+7j41NoW61Wl7m/II3s7Gy6devGP//8w6xZs9i2bRv79+9nw4YNwH/3LpPJ+PPPP+nTpw/z5s2jffv2+Pv7M2HChGLVlGfOnOGff/6hb9++theM8rrRz6Dgd65ou6dSqSx2bklSUlJKbDMt6+/gRvbt20fv3r0B+Oyzz9i1axf79+9n+vTphfJuD+3ataNbt258/PHHQP5L5cWLF3nuuefsdo2yiN5ldlTwi3j16tViv5SSJBEXF0eHDh1s+3r37s3vv//Ojz/+yPDhw4ulJ0kSP/zwA76+vrRt29YueSypROTn54eLiwvLli0r8Rw/Pz8A/P39sVgsxMfHlxhE4L8/+Li4uGLdTWNjY21pQX6p6v7772fFihXMmjWL5cuXo9Vqefjhh4tde/369YSFhVXq/gBat27NmjVrkCSJI0eO8OWXXzJz5kxcXFyYOnXqDdOtrIL8T5s2jcGDB5d4TNOmTUs9vyBgXN9wD/YNTBW1detWYmNj2bZtm630ApTYvhEWFsYXX3wB5AeSdevW8dZbb2E0Glm6dKntuC5dujBkyBDGjh0LwJIlSwqVsqqi4HcyISGB0NBQ236z2Vzopaes8/ft21dsf3x8fLF9Wq22UHtcgaI/rzVr1qBSqfjpp58KvRRs3LjxhvmpjAkTJjBkyBD+/fdfPvroI5o0aUKvXr0ccq2iREnGjnr06IFMJitxEOWvv/5KZmYmPXv2tO0bO3YsgYGBTJs2rcTi8bx58zh16hTjx4+3vZ05woABAzh//jy+vr506NCh2L+CInXfvn2B/AdAaQqq4r755ptC+/fv38/JkyeLVYOMHj2a2NhYNm/ezDfffMOgQYPw8vKyfd6nTx+USiXnz58vMW/XB+3ykMlktG3blvfeew8vLy/+/fffMo/XaDRVeqts2rQpjRs35vDhw6Xm393dvdTzC777I0eOFNp/fW++6lYQyIv+ThZtXC6qSZMmvPbaa7Ru3brE7/2xxx5jzZo1LF++nFGjRpVaI1BRd955J0Cxv8v169eXq9fm3XffTVZWVrHvfNWqVcWOrV+/PmfOnCn0UpCSksLu3bsLHSeTyVAqlYWqWnNzc/n6669vfEMlKFp6K2rQoEHUq1ePF198kT/++INnnnmm2sbeiZKMHTVs2JDnnnuO+fPnk56eTr9+/XBxcWH//v288847dOjQoVA9rpeXF9999x0DBgwgMjKSl156ibZt25KZmcnatWtZuXIlvXr1KnGU89GjR1m/fn2x/bfddlu53vivN3HiRL777jvuvPNOJk2aRJs2bbBarcTExPD777/z4osv0qlTJ7p168ajjz7KrFmzSEhIYMCAAWg0GqKiotDpdDz//PM0bdqUJ598kg8//BC5XE7fvn25ePEir7/+OnXr1i023qd3797UqVOHZ555hvj4+EJVZZD/Rztz5kymT5/OhQsXuPfee/H29iYhIYF9+/bh6up6wx5IP/30E4sXL+aBBx6gQYMGSJLEhg0bSE9Pv+HbXOvWrdmwYQNLliwhMjISuVxe4cD2ySef0LdvX/r06cPjjz9OaGgoqampnDx5kn///ddWV16SoKAgevbsyZw5c/D29iYsLIw///zTVjVVE7p27Yq3tzfjx4/nzTffRKVSsXLlSg4fPlzouCNHjvDcc88xZMgQGjdujFqtZuvWrRw5cqTU0uNDDz2ETqfjoYceIjc3l9WrV9uq6yqrZcuWPPzwwyxYsACFQkGPHj04fvw4CxYswNPT84YlplGjRvHee+8xatQo/ve//9G4cWM2b97Mb7/9VuzYRx99lE8++YSRI0fyxBNPkJKSwrx584pV4fbv35+FCxcyYsQInnzySVJSUnj33Xcr/TJZUMX46aef4u7ujlarJTw83FaKUygUPPvss7zyyiu4urqWq9u23Ti8a8Etxmq1SkuWLJE6dOgg6XQ6Sa1WS40bN5ZeeeUVKSsrq8RzLl26JD3zzDNSeHi4pFKpJEACpJkzZxbquSRJ//UMKu3f9T1diurevbvUsmXLEj/Lzs6WXnvtNalp06aSWq2WPD09pdatW0uTJk2S4uPjbcdZLBbpvffek1q1amU7rkuXLtKPP/5Y6Ji5c+dKTZo0kVQqleTn5yeNHDlSunz5conXfvXVVyVAqlu3bqk9rTZu3CjdfffdkoeHh6TRaKSwsDDpoYceKtQL6bHHHpNcXV2LnXvq1Cnp4Ycflho2bCi5uLhInp6eUseOHaUvv/yy1O+qQGpqqvTQQw9JXl5ekkwmkwr+ZMrq5UeRnj+SJEmHDx+Whg4dKgUEBEgqlUoKCgqSevToIS1duvSGeYiLi5MeeughycfHR/L09JRGjhxp6zFXtHdZSfdf0OuxQGl5L+jJVtCTsUBBL67rex/u3r1b6tKli6TT6SR/f39p3Lhx0r///lsoTwkJCdLjjz8uNWvWTHJ1dZXc3NykNm3aSO+9916h3+uC3mVF8+Lm5ibde++9kl6vL/F7LSlf19/H9T2t8vLypMmTJ0sBAQGSVquVOnfuLO3Zs0fy9PSUJk2aVMK3XtiVK1ekBx98UHJzc5Pc3d2lBx98UNq9e3eJf3NfffWV1Lx5c0mr1UotWrSQ1q5dW2LvsmXLlklNmzaVNBqN1KBBA2nOnDnSF198IQGFen2Vp3eZJOX3DgwPD5cUCkWJ+bp48aIESOPHj7/h/dqTTJLKMfJHqFZHjx6lW7duRERE8Msvv+Di4lLTWRKEm87u3bu5/fbbWblyZaV6il28eJHw8HCWL19evSWDSvrwww+ZMGECx44do2XLltV2XVFd5oRat27Npk2b6NOnD4MHD2bTpk1VrjIQhFvZli1b2LNnD5GRkbi4uHD48GHeeecdGjduXGqHjJtFVFQU0dHRzJw5k/vvv79aAwyIIOO0unfv7vDpHgThVuHh4cHvv//OokWLyMrKws/Pj759+zJnzhzHj3ivYYMGDSI+Pp5u3boV6tFXXUR1mSAIguAwoguzIAiC4DAiyAiCIAgOI9pkrrFarcTGxuLu7i4WCBMEQbgBSZLIysoiJCSkzLFGIshcExsba7f5wQRBEG4Vly9fLnPFUhFkrimY2uPy5cvFRucKgiAIhWVmZlK3bt0yp0UCEWRsCqrIPDw8RJARBEEopxs1L4iGf0EQBMFhRJARBEEQHEZUl5WDxWKp8uJWgvNTqVSFpl4XBKHqRJC5gezsbK5cuYKYGOHmJ5PJqFOnDm5ubjWdFUG4aYggUwaLxcKVK1fQ6XT4+/uL8TM3MUmSSEpK4sqVKzRu3FiUaATBTkSQKYPJZEKSJPz9/cV0+7cAf39/Ll68iMlkEkFGEOxENPyXgyjB3BrEz1kQ7E8EGUEQBMFhnC7I7Nixg/vuu4+QkBBkMhkbN2684Tnbt28nMjISrVZLgwYNamTNhOokk8nIzs6263n169fn2LFjVc1aIV9++SVnzpyp0Dnp6enMmzev3Me/9dZbGI3GimZNEIRq4nRBJicnh7Zt2/LRRx+V6/jo6Gj69etHt27diIqK4tVXX2XChAl89913Ds6pcCPVEWRmzJghgowgODGnCzJ9+/Zl1qxZ5V4SdenSpdSrV49FixbRvHlzxo0bx5gxY3j33XcdnNOymSxWTsVlsud8MqfiMjFZrHZN/+OPP6ZTp062NcYLnD17lv79+3PbbbfRtm1bFi9eXOL5O3fupHXr1nTs2JHnnnuuzC7aCQkJDBo0iNatW9OqVSs+/fRT22dFS0AdOnRg27ZtfP755xw4cIAJEyYQERHB5s2bC6VptVp57rnnaNasGW3btiUyMpK8vDzGjx9Peno6ERERdOjQAYCFCxdy22230a5dOzp27Mg///wDwPjx4wHo2rUrERERJCYmkpWVxRNPPEHHjh1p06YN48ePt41xmjVrFs2bNyciIoKIiAguXbpUka9cEITKkJwYIH3//fdlHtOtWzdpwoQJhfZt2LBBUiqVktFoLPW8vLw8KSMjw/bv8uXLEiBlZGTYjsnNzZVOnDgh5ebmVjjvJ2MzpDX7YqQNB69Iq/+5JJ2MzbjxSeUESIsWLZIkSZJOnDghubm5SSaTSTKbzVKHDh2kkydPSpIkSTk5OVLr1q2lgwcP2s7LysqS8vLypJCQEOmvv/6SJEmS1q5dKwHS0aNHS7ze0KFDpalTp0qSJEkJCQlSnTp1pH/++UeSJEkKCwsrdF5kZKQt3e7du0s//vhjiWn++++/UrNmzSSLxSJJkiSlp6dLFotFio6Olnx9fQsdm5iYaPv/PXv2SC1btiz0XWRlZdm2n3jiCWnFihWSJEmS1WqVxo4dKy1cuFBKTU2VPD09Jb1eb/tuiv5cq/LzFoTaxmQy2f5WKyMjI6PYM7MkTleSqaj4+HgCAwML7QsMDMRsNpOcnFzqeXPmzMHT09P2z97T/KfpjagVcoI8tWiUCtL09q3SeeSRRwBo3rw5SqWS+Ph4Tp8+zfHjxxk+fDgRERF07dqVrKwsTpw4Uejc06dPo9PpuOuuuwAYOnQonp6epV7rjz/+4NlnnwUgICCAwYMH8+eff1Yp/w0aNMBkMjFmzBi++uorTCZTqWtSREVF0b17d1q1asX48eM5ceJEqVVkGzduZP78+URERNCuXTt27tzJ2bNn8fDwoHHjxowcOZJPPvmE1NTUm35td0EojSRJPPLII/Tp04fLly879Fo3xTiZol1PpWtVP2V1SZ02bRqTJ0+2bRdMW20v3jo1l1L0xGfkYTBb8Nap7ZY2UOgBqVAoMJvNSJKEn58fhw4dKvNcqYyqsT/++IMpU6YAMGTIEKZPnw4U/y4LtpVKJRaLxbY/Ly+vXPn39PTk+PHjbN++nb/++otp06axY8cOlMrCv5JGo5EHH3yQbdu2ERkZSWZmJp6enhiNRtTq4t+pJEls3LiRBg0aFPts79697N69m23bttG5c2dWr15Nt27dypVfQbiZyGQyunXrxogRIxy+jlatL8kEBQURHx9faF9iYiJKpRJfX99Sz9NoNLZp/R0xvX/DADci6noR5Kkhoq4XDQMcP1VJ06ZN0el0rFixwrbv3LlzpKamFjquWbNm5ObmsmPHDgDWr19PRkYGAD179uTQoUMcOnTIFmB69uxpa4dJSkri+++/p0ePHvn32bChrY1k3759nD592nYdDw8PW7pFJSUlkZOTQ+/evZk9ezb169fnxIkTeHh4oNfrMZvNQH7QMplMtj+EDz/8sFA67u7uha4xcOBA3nnnHdv5aWlpnDt3jqysLBISEujWrRuvv/46d9xxB1FRUeX+bgXhZqDX6/n2228BeO6557j//vsdfs1aH2S6dOnCli1bCu37/fff6dChAyqVqoZyBSqFnGbBHnRp6EezYA9UCsd/1Uqlkh9//JF169bRpk0bWrZsybhx48jNzS10nEajYfXq1Tz77LN07NiRffv2Ua9evVLT/eCDDzhy5Aht2rTh7rvvZvr06XTs2BGA//3vf7z//vt06tSJ5cuX07JlS9t5Tz75JDNnziyx4f/y5cv06tWLNm3a2DoU9O3bFx8fHx555BFat25Nhw4d8PDwYObMmXTs2JE777wTjUZTKJ0XX3yRHj162Br+Fy1ahFKpJCIigjZt2tCzZ08uXrxIRkYGgwcPpnXr1rRp0waTycRjjz1W1a9cEGqN7OxsBgwYwOOPP87Vq1er7boyqay6kxqQnZ3NuXPnAGjXrh0LFy7k7rvvxsfHh3r16jFt2jSuXr1qe1uPjo6mVatWPPXUUzzxxBPs2bOH8ePHs3r1ah588MFyX7egGiYjI8NWqsnLyyM6Oprw8HBRf38LED9v4WaVmZlJv379OHz4MJs3b7ZLNXFJz8ySOF2bzIEDB7j77rtt2wXtJo899hhffvklcXFxxMTE2D4PDw9n8+bNTJo0iY8//piQkBA++OCDCgUYQRCEm1VaWhr33nsvp0+fZsuWLXTu3Llar+90Qeauu+4qs2H6yy+/LLave/fu/Pvvvw7MlSAIQu2kVCrx9/dnyZIltG/fvvqvX+1XFARBEBziYlImr6w/zOU0A37ybCbdHcbdndrx008/1VieRJARBEG4STy38iDH4vWYM5PZv3Y6O5d6kBp9vEZnGBdBRhAE4SZxPF6POSOBhNWvIklWXPtMrvElLESQcRDJYiFv7xHMCSkoA33Rdm6DTCyEJQiCg+iNZoypV0lY8xoyhYKgh+ei9Ayo6WyJIOMI2T9tJ3n6+1hik2z7FCH++P3vBdwGdK/BnAmCcLNat/8i1oxEXJVq+nR9HIM+lxhf+07MWxm1fjCms8n+aTsJY14rFGAALHFJJIx5jeyftlf5GmWtCxMREVFs8KUj3XXXXTXaqCgIQv54wYOf/cSGf3azX16fOf9sZ+Hm1Sxf84ldnjlVIYJMBUiShDUnt9R/lsxskl9dBCX1wL62L/nV97FkZpeaRlXHxh46dAgXF5cqpSEIQu1x8OBBIttGoPp2Ef45WSiua4Nxzciw28ttZYnqsgqQ9HlE1+9dhQTySzQXG/Yt9ZDwi78jc71xkHj33XfZsmULSUlJzJgxg4cffhjIL+VkZWXh5uZG/fr1GT16NL/99htxcXGMHTuW1157DcgvgXTq1Indu3cTGxtLr169bCuKZmVlMXnyZA4fPkxeXh5du3blww8/RKVSceLECUaPHo3JZKJ58+alToi5bds2Jk6cSOfOndm1axcqlYoVK1bw9ttvc/ToUUJDQ/n+++9xc3PDZDLx+uuvs3XrVoxGI82aNWPp0qV4eXmxatUq3n//fYxGI5IkMXv2bPr16wdQ5v0Jwq1gz5499O3bl3CLgpHaYIo28ReUIpJf+wDXvnfUSLuwKMnUUjKZjF27dvHrr7/y/PPPlzpdd3p6Ort372bfvn3Mnz+/0JxF58+fZ9u2bRw7dozffvuNPXv2APnzgd15553s27ePw4cPYzabbSuVPvroozzzzDP8+++/PP/88+zfv7/UPB4/fpzx48dz9OhRunTpwr333suCBQs4ceIEKpWKVatWATB//nzc3NzYt28fhw4domXLlrz55psA9OnTh7179xIVFcXGjRsZN26cbRGyG92fINzMduzYQe/evWkZFs5ybVM85aWUGSSwXE0kb++R6s3gNaIkUwEynZbwi7+X+nnunsPEP/zSDdMJWj0fly5tS71GeYwbNw7IX5fljjvuYOfOnYwYMaLYcQXrzvj7+9OgQQOio6MJDQ0FYPjw4SgUClxcXIiIiOD8+fN06dKFjRs3snfvXhYsWJB/X7m5qNVqMjMzOXbsGI8++igAnTt3pnXr1qXmsWnTpkRERADQvn17Ll26RJ06dQCIjIzkwoULQP4aMJmZmaxfvx7In96/YcOGQH5d8yOPPMKVK1dQKpUkJydz6dIlGjVqdMP7E4Sb2cqVK+nUqRPfjHmBnBduvGS5OSGlGnJVnAgyFSCTycqsytLdfRuKEH8scUklt8vIQBESgO7u2+xebC2tL3xJ687c6LPS1mTJzMysUJ/7oukX3S7ooCBJEosXL7YtH3C94cOH8+677/LAAw8A4OPjU6iKrqz7E4SbUXp6Ol5eXnz88cf560gdPElOOc5TBpa+9IkjieoyO5IpFPj974VrG0U/zP+P36wJdgkwy5YtA+DixYv8/fff3HHHHVVOs0Bpa7J4eHjQqlUrVq5cCeSvH3P06FG7XG/hwoXo9Xogf82L48eP265dv359AL755hvS0tKqfD1BqK2+//576tevz6FDh1AqlWi1Wswp6WWfJANFaADazm2qJY9FiSBjZ24DuhO4bBaKYP9C+xUhAQQum2W3cTIajYbbb7+d3r178+GHH9p1dbvS1mQBWLFiBR999BHt27fn008/pVOnTlW+3tSpU4mIiKBTp060adOGzp0721b3fP/99xk0aBB33HEHhw8fLnPdG0G4ma1Zs4YhQ4Zw77330rJlSyRJIn3xGhLH5bdfSkDRUTEFFSr2ermtDKdbT6am2Hs9GTHiv/YR68kIzuqrr75izJgxjBw5kmXLliGXJJJffZ/M5RsB8Bg9iBOhYegWLcMnO9N2nt7Hm/AFLzpkEHitXU/mZiFTKHC5vV1NZ0MQhFouOzubV199lbFjx+YPM9DnET/uTfR/7gWZDN8Zz+A5fhiu0SlM0bsRcPYCAXk5mH29qNOrE68NKL1zTnUQQUYQBMFJmUwm3Nzc2L9/P8HBwVjik4l7+GWMx88hc9EQsOQN3PrfCcD5lBx0LlrOhzXgrAw0CjmNNTW3BH0B0SYjCILghObOncs999yDwWAgJCQE4/HzXOnzFMbj51D4exOy8QNbgAFAkuGmUSBXyDGZJeRyGfV9dTV3A9eIICMIguBEJElixowZTJ06lbvvvhu1Wk3OH3u5OuAZLHFJqJqEEfrLUrTtWxQ6r2WIB+5aNW5qBSHeWur66jCYa36CTFFdJgiC4CQkSeLVV1/lnXfeYfbs2UybNo2MLzeSPHURWCxo72hP0PJZKLzci53bNNiD5qHuZOSZ8NAoUSjlSDW7lAwggowgCILT2L59O++88w7vvfceL0yYQPKbH5OxeA0A7sP74r/gJWTqkttZVAo5DXzc2HYykXS9Ca1ShosT9GgV1WW1UFnT+devX59jx45VOu1t27bRoUOHSp8vCELFFYwkueuuuzh48CATnhxPwpg3bAHGZ+o4/D+YVmqAKRCdkoMkk+HhosSKjOiU8swF4FgiyNRCYjp/Qbh5WCwWxo4dyyeffAJAmzr1iR38Ajk/bwe1ioAlr+P94mPlmtLJYLbgplZRz9sVd40Kg9ni6OzfkAgytdD1i5bt3LmT1q1b07FjR5577rlC69GcPXuW/v37c9ttt9G2bVsWL15s+2zkyJF06NCBNm3aMGDAABITE6v9PgThVmcymXj00UdZsWIF7u7uGM9c5Grf8RgOnkDu5U7Itwtxf6j8y4u0CfXCTasgy2DGTaugTaiX4zJfTqJNphLi4uKIi4srtM/b25vw8HDy8vI4ceJEsXPat28PwOnTp8nJKVyErV+/Pj4+PhXOh8FgYPjw4axcuZK77rqLdevW8fHHHwP5b0cjRozg66+/plmzZuj1ejp37kznzp1p3749ixYtws/PD4B33nmHmTNn2qbzF4TaKCPXyLr9MVxK1hPmp2PobfXwdFHXdLZKZTQaGT58OD/++CNr166lX2A4V/s9jTUjG2X9EIJXz0fdqGLTKPVoEYhCIeNqWi6h3i50bxrgoNyXnwgylfDJJ58wY8aMQvseeeQRvvnmG65cuUJkZGSxcwpKGI8//jh79+4t9NnXX3/NyJEjK5yP06dPo9PpuOuuuwAYOnQoTz75pO2z48ePM3z4cNvxWVlZnDhxgvbt27Ny5Uq+/vprDAYDubm5BAUFVfj6guBM1u2P4ZdjCWiVco7HZQHwxJ2NajhXpXvttdf4+eef2bBhA3fpVcQOfRFMZjS3tSJ4xWwUft4VTlOnVtK3dYgDclt5IshUwlNPPcXAgQML7fP2zv+FqFOnDgcPHiz13C+//LLEkkxllDXtnCRJ+Pn52SaavN7ff//NRx99xO7du/H39+eHH35g5syZlcqDIDiLS8l6tEo5DQPcOJ+YzaVkfU1nqUxTp06lX79+tNl3gcT5ywFwHXg3AR9NR+6iqXB6JouV84nZpOmNeOvUNAxwQ6Wo+RYREWQqITg4mODg4BI/02q1tqqxkjRt2tRu+WjWrBm5ubns2LGDO++8k/Xr15ORkWG7jk6nY8WKFYwaNQqAc+fO4ePjQ1paGh4eHvj4+GA0Gm0NjoJQmxStHgv01HI8LovzidnkmqyE+dX8aPeisrKyePrpp5k9ezZ1A4Novm4nad/mL4ToNeERfKY/iUxeucBwPjGbw1cyUCvkXErJD7DNgkufuLK61HyYEypNo9GwevVqnn32WTp27Mi+fftsU+ErlUp+/PFH1q1bR5s2bWjZsiXjxo0jNzeXvn370qhRI5o1a0afPn1sq1cKQm1SUD12ITmHzUcTAOjXOpAGfq70ax3I0Nuca1mI9PR0evfuzQ8//MDlU2eJHTKZ7G9/B4UC/4Uv4fv6+EoHGIA0vRG1Qk6QpxaNUkGa3mjH3FeemOr/GntP9S/UPuLnXbu8tuEIF5JzbNVjQZ5afF1VRCfpCffX8dRdjfBzc46fY0pKCn369OHChQv8vOxrQueuwnQuBpmbjqBlb6O7u2OVr3EqLpODMWmk642k6020r+dFzxZBDqsyE1P9C4JwU6vrq+PApTTSL6ZilcBsMXPgkgm1Qs7pxPwu/tMHtKrhXILVaqVv375cunSJXz/8DP+pSzClZKAMDSBo1Tw0LRra5ToNA9yITs5vi/JyUZOaY+J8YnaNV5mJICMIgtOIz9Az/9eTRCflEu7vwkv3NifIs+S2lch63hy5nE6a3oiXTk18Zi5qhZxwP1eik3OITnKOhn+5XM4bb7xB0MUkfKYuxppnRN26McGr5qEM8rPbdVQKOV46Nc2DPQjy1BKfkecUVWaiTUYQBKcx/9eT/HU6hSvpuWw9lcL8X0+WeqzJKtGzRTDP39OUXi2CCfLQYjBbiU7OwWC2Eu5fsw3/MTExzJw5E6vVyu3RGXjPWIaUZ0TX53ZCf/jIrgGmgLdOjcFsIT4jD4PZgreu5scJiZJMOYhmq1uD+DnXvOikXJRyGXV9dFxO1ROdVPIcfZD/QL2Uorc9UEd1qc+fJxMKtcnUlAsXLtCjRw9kMhlD4q2ov/0TAM9xD+I763mHLcVez1dHdHK2bTBmPSdYT0YEmTKoVCpkMhlJSUn4+/uXa+4goXaSJImkpCRkMhkqVc2vJnirCvd3ITpFz+VUPSaLRLh/6XP0eemU7I9OslWt3dW0uVO0wZw+fZp77rkHF42WVc175AcYmQzft5/H66khDr12TIqezDwLvm5aMnLNxKToRZuMM1MoFNSpU4crV65w8eLFms6O4GAymYw6deqgcILp0W9Vz/dozJVUPZfTDIT7anm+R+NSj/3oj7P8diwBixVOJ2Sikino3jyg0JQqOnX1PuKio6Pp3r07Ph4erPBpi/c/J5G5aAj85E1c+3Zz+PWv78bsLG0yIsjcgJubG40bN8ZkMtV0VgQHU6lUIsDUsJPx2dTxdadRoBe5Jgsn47Op71/ym/i2M0nkmiU0Shl5ZonfTsZhArRKBSeuTStT3VOs1KlTh5F972P43qt4n49H4e9D0Kq5aCOaVcv1i1YhijaZWkKhUIiHjyBUg6tpuWiVChr4u3IhKYeraaW3ycgKui1JMkDCaqXc59rb/v37MRqNROTIeGbrJSR9Lqpm4QSvmoeqbvXNC9gwwA2g0NQyNU0EGUEQnEaotwsn4rK4kJRDrslCqHfpbTLdG/ux8VAcFquEi0pOm1APck2Wcp1rT7t27aJfv350rNeQpYnuYLXi0r0DgV/MROFZfJnkW40IMoIgOI3WoR58dyCGIym51Pd1oXVo6Y3WE3s3w0WttPUme6xrfY5ezazWae63bdvGgAEDaOsbzLuxGpBbcR/RH/93pyBTVf/j1RnnLxNBRhAEp/HV7oucSdJfG7Wv56vdF0vtMebnpi32WR2f6qse2rJlCwMHDqSTTzAf5fjjIlfg8+oTeE18tMZ6ojpjw78YjCkIgtM4m5CDwWRBBhjNFs4m1Pwa9aXxU2np712HxYZgXDRaAj55E+9Jo2p0qIMYjCkIglAWyUpKtpGUbKNtuzRXUrN5a9MxLl6rWnvr/lbVUpLZunUrkb7B+Ez+gLmmIOQ+HgStmINL5zYOv/aNiIZ/QRCEMpiRkMtALiN/0ktKn4XhrU3H2B2djgK4mmHgrU3H+Hx0Z4fmb+XKlYwaNYqpXo0ZI/dHFV6HoNXzUDes69DrlpdKIa/xNpiiRHWZIAhOQyaBh05NXV9XPHRqZGXM9HMxJRcFEOylRSnL33akL774gkcffZTBmgAek/mh7dia0F+WOE2AcVYiyAiC4HAmi5VTcZnsOZ/MqbhMTJaSq8G6NPLHTa3AYLLiplbQpZF/qWnW93XBLEFceh5mKX/bUT766CPGjRvHw5pAZusa4jm4J8HfvYfC18th17xZiOoyQRAcrrxdax/pHIZKIbMtqVzW6pZTejchetW/JGabCHFXMaV3E4fkXTIYiV6+nse1IUx3Dcd70ih8po2r0iqWtxIRZARBcDhHdK3deT4VT1cXAj1dyTVZ2Xk+lWahPnbI7X+idu4iYN5qnr5kBs9GBCx4GY9H+tv1Gjc7EWQEQXC48s6ptXLvJb49cBmrVeLvczJMFoln7i55ksxLyXq0Srlt+eWziTn8cjTWLhNkSpLEaxMmMeej9/nZqx3NvAMIXD4LXfcOlUrvVua05b3Fixfb1lqPjIxk586dZR6/cuVK2rZti06nIzg4mNGjR5OSklJNuRUEoSwNA9yIqOtFkKeGiLpepXat3XU2iTS9CbMkkZ5rYtfZpFLTDPFyITYjl+1nkojNyMVgNPPHyUSik/VsOZHI9tOJlcqrJElMfnQ0sz96n5d09WlRvwEhPy8WAaaSnDLIrF27lokTJzJ9+nSioqLo1q0bffv2JSYmpsTj//77b0aNGsXYsWM5fvw43377Lfv372fcuHHVnHNBEEpS0LW2S0M/mgV7oFKU/OgxWSVMFisGkxWTxYrJWnr3skAPDe4aJWqFDDe1EpNkJddoQSaDPJOFmJSKL79stVp5uv8gFq38itddG/B8l3sI/WUpmuYNKpyWkM8pg8zChQsZO3Ys48aNo3nz5ixatIi6deuyZMmSEo/fu3cv9evXZ8KECYSHh3PHHXfw1FNPceDAgWrOuSAIVdG+njfeOjVqpRxvnZr29bxLPTZNb6JNHW8e7hhG27reGI1W4jLyuJCUTVxGHhVtl5ckicvzPuev37cwy60RTw8eSsimDx2yTPKtxOmCjNFo5ODBg/Tu3bvQ/t69e7N79+4Sz+natStXrlxh8+bNSJJEQkIC69evp3//0hvoDAYDmZmZhf4JglCzOjfwxlUtx2yx4qqW07lB6UEm1Nul0KzLTYI9iKjrQQN/VyLqetAooPwzIJty8zg5/g3MC75mo1cET73wPEHLZyF3rZ6ZnG9mThdkkpOTsVgsBAYGFtofGBhIfHx8ied07dqVlStXMmzYMNRqNUFBQXh5efHhhx+Wep05c+bg6elp+1e3rhhQJQg17Y8TSVxNzyMrz8yV9Dz+OFF6m0zXRn40D3ZDhkTzYDd6NAugaZAnkWG+NA3yJNizfAEiLyWNwU3a8MAXCzHLZITMmYTfrAnIxBpSduF0QaZA0UnmJEkqdeK5EydOMGHCBN544w0OHjzIr7/+SnR0NOPHjy81/WnTppGRkWH7d/nyZbvmXxCEitt2Ko5so5Vcs0SO0cq2U3GlHns2LpNDMemcT84mKiYds9lSrs4F18s+H8P9Tdvy25VzTPJuRJ2v5+D5xEP2vKVbntN1Yfbz80OhUBQrtSQmJhYr3RSYM2cOt99+Oy+99BIAbdq0wdXVlW7dujFr1iyCg4OLnaPRaNBoNPa/AUEQKi0201zm9vU2Hr7K8dhMdGolsel5/HQ0jlmD2pb7Wun/HOaBe3qzJyeJT+p2ZMSPX6Fp27TSeRdK5nQlGbVaTWRkJFu2bCm0f8uWLXTt2rXEc/R6PfIirXwFyyVLUhmTHwmC4FSK/rWW9deblmNGLpPh7apGIZORllN6QCoq57dd/DRwDAf0KXzZqicjd28SAcZBnC7IAEyePJnPP/+cZcuWcfLkSSZNmkRMTIyt+mvatGmMGjXKdvx9993Hhg0bWLJkCRcuXGDXrl1MmDCBjh07EhISUlO3IQhCBQXo5GVuX69ZkDt5JjNnE7LIM5lpFlS+hv74j1cS9+g0ulh17B3wBEN3bUBVp+RaEqHqnK66DGDYsGGkpKQwc+ZM4uLiaNWqFZs3byYsLAyAuLi4QmNmHn/8cbKysvjoo4948cUX8fLyokePHsydO7embkEQhEqYMqAFr60/hsEKGnn+dmlCvLX4uWvQG624qOWEeGvLTFuyWDj/0nwe+vB/dFf7MP2pZ2kwd3KNLJN8K5FJoj4JgMzMTDw9PcnIyMDDw7nWYxCEW8XirWf59uBlLFL+mjJDI+vyTI+Sp5X5fMd5opP1NPB35UJSDuF+Osbd2bDEY605uZx4fCrDvv+CWKuB7198k+7zptfoKpa1XXmfmU5ZXSYIwq3peGwGcpmcut4uKJBzPDaj1GOLjpMJ9S65y7I5PpnD947jwe8/J0Ey8ut7S7hr/msiwFQTUU4UBMFpaJVyEjL0xGfmIpMkIuqW/obcvWkAgG1CzK6N/DgVl1lo6WHrmYvEj3iZRad2k4GFP1evp+3QgdV1OwIiyAiCUA1MFivnE7MLBYCS5i8zmS2YpWvj4oBcg7nUmZX1RjP/XkolOklPuL8Ob52KS6l5tjVrZHuiUL44G7L1TG/djdcXTKJRN8cuzywUJ4KMIAgOV95FyzKNFjxdVHhoVWTmmbiYpuePk4lolQpOxGUB0Ld1fo/RxX+dZfPReOTIOBqbwemETJKzDKTnWrj39CH8f/qW59OPs6hrf+7Z+CkKH8/qu2HBRgQZQRAcrryLlrmqFKTpjaRmG5HJwctFiVapsDXuX03LtR17/GoWViv4eahJzDRwIDoNixUe27+dDgf+ZETmMdzd3Wnx5TsiwNQgEWQEQXC48i5almc2gwQyGSCBUiaRkJlLQmYeINGlwX8rX/q6KFBcuoBfbg4eLq4c9Qhg8s5fCDl7gBEZx1C5ebLz5JESZ/wQqo8IMoIgOFw9Xx3Rydm2tpV6vroSj8s2Snjp1AR4aEnMzEOSyQjy1JKZZ8Zdo6TOtR5k2T9t5/l3FqJOTrWda5IrkFnM9Mk6CR5+eI15VwQYJyCCjCAIDheToiczz4Kvm5aMXDMxKfoS22RCPdX8E51GYnZ+dVojfx2RYb62arZcs5Xsn7aTMOY1VEVG+KmsFiSZjF4dhvNr53tx8xLj3ZyBGCcjCILDXd8mo1EqSm2TOR5beF2nq6l6LqZks/NsEhdTsnFXykie/n5+ldp1x+03ZTAp6zQGycqY+ATUah0eLioH3pFQXiLICILgcN46NQaz5YZtMomZBuC/AJKSa0J+bVsGyKNOYIktvMbMbmM6YzKOk2w1YkUiICeLrulX6dlCzFvoDER1mSAIDlewtsv142RKkmOwAv/NvpxjlNh5NpH0XDNeLkraG+MIve74bcZUnsk8SWeVF4s9mqGV5c++3tVTSf16Xg66G6EiRJARBMHhVAp5iW0wNyIBJ+JzUABX0w3systm6LXPzphzeDrzJN3V3rzv3gyN7L+KmeZtwuh4bUYAoWaJICMI1aS8o95vRldSs3lr0zEupuRS39eFt+5vRR2f4qUZlRKMpsL7tEoZvm4aVFfiuHPzj7b9jRU6Zrs1YoDGH1VBgJGBIiSA7iN6iOWTncSt8RsuCE6gYNR7fIaBQ5fTOZ+YXdNZqjZvbTrG7uh0EjIN7LqQzlubjpV4XLOgwqUdTzVIkoy6p04ze90yglKS2SjPZLMhGZlcxiBtYKEAA+A3a4IIME5EBBlBqCbl7WF1M7qYkosCCPbSopTlb5fEU6dCKQMFoJRBy2APXog/xtSNq3A35LHWD15KOsrRbs1QBPkXOlcK8OOLwcO58988Biz6i3Px6Q6/L+HGRHWZIFST8o56vxnV93XhaoaBuPQ8zFL+9vUki4W8vUdouO8gFklNStPGpGTkMnDjBm779wAAa5p58trfP/P888/z/vvvc/pqOhd+24dLegZ6Tw/eS5ETnW5ClmsmI9fMxDVR/DTx7pq4XeE6IsgIQjUpbw+rm9HUvs2YuCaK2EwjYR5qpvZtZvss+6ftJE9/H0tsEsOv7UtzdUOv1hCaloJVJmNiiBeb//6RMU89x/vvv49MJiMp10R0eEM0SjkGs5X4y2eRAV46Fel6E7GZt05J0ZmJICMI1aSyPaxuBmeT9DQN8aZtPQW5Jgtnk/Q0CvKyjd6nyOh9r5xsvHOyMSiVvHXPYI54euDtG0h0g4G2xcay88xcSMpGq8pP00sjJ85oJV1vQgJCPG6dkqIzE0FGEASHu5qWW2w2ZclisY3eL0oGWCWJL/Li2efljbefH0r3B0jKMduOcdMqaRDgilapIM9soWN4c77ZfZ7YTCMhHmoWDW9XfTcolEoEGUEQHC7U24UTcVmFlkrO23uk2Oj9ApIkMScnmuV5sbQ7tIW0O/JHxwS4/TdVjL+bFl/XPDRKBQazhVYhHrw7LNJWHRnmf+tURzozEWQEQXC4oksld28agPmnEyUea5Uk3so5z6q8eN5ybUhu4wh+1sgJcFPx8Yj2tuOKtnGZLdZyLYwmVC8RZARBcDidWmlb0bJAbqBvice+ln2Obw0JzHFrxBBtECHj7uKd24tXfRVt49pzPrlcC6MJ1UsEGUGooFt55L49aTu3QRHijyUuqVC7THuVB51UntynDSDDy4sGnduUKz0XpZyDF1PIMuSvPfNQZB0H5VyoCPGXIQgVdCuP3LcnmUKB3/9eAAmMkpWfDfntMw9pA7lPG4AM2HRvv3KP3r+Slkt8loFsg4W4zDyupJU84FOoXqIkIwgVVN716oUbc+nWHqOLimfjD/G3KZ2WSjfqK1xIcfPgky49SW3SstxpXcnIRa2U4+uqJiXHyJUMEWScgQgyglBBt/LIfXuLW7SCJ+Oj2G/J4tv573P4iopPMq0cD66HVSbjDg9NudNSyCA+I4/0HCO5Jivt6no6MOdCeYkgIwgVdCuP3LenjIuXeWjO6xw1ZbLhfwvoN+V53vz+MMcOXsVslVDIZIR4aMudXqNAN9qk5WG2SCgVMhoFip+LMxBBRqgVnKmx/VYeuW9PmZ+ux0OS8U1kP/pOewGAs0lZmCUJuVyGRZI4k5jFqbjMcv3cgz10NAv2sI2bCfbQVeftCKUQQUaoFQoa28UYiNovNTWVK0dP4vb1TyzxaEHw3NdsU8Wk5piwWECmkLBY8qu/yvtzFyVM5ySCjFAriMb26uHoEmNSUhI9e/bEGJvID1JDdJ3a4NKjo+1zq8WKBbBY8rcNJku5f+6ihOmcRJARagXR2F49HFlijIuL45577iEtOYVlhCGXyfB59QlbKQZAUSSgqRRyDlw39mWIGPtS64ggI9QKoiqkejiqxHj58mV69OhBXl4eG/qMJODXfbh074BLkZH8BrMVhSx/gkwJ0BstJGbmISFDbzBzJS2XiDC7ZEmoJiLICLWCqAqpHo4qMV65cgWtVstPny5HOXw6AD6vPlHsuAB3F66m56GQy7FYrWg0SgI8XGyzN8dn5tklP0L1ESP+BUGwaRjgRkRdL4I8NUTU9apyiTEmJgaTyUSXLl04dOgQXqu2gNWK7t470LZvUez4RtfagCRJQqWQ08DbhVyTpdDszULtIkoygiDY2LPEePz4ce655x4ef/xx3nnnHcynosneuBUAn1fGlniOq1aJl4sKqwzkVmgS5ElkuE+h2ZuF2kUEGUEQ7O7w4cP07NmTkJAQJk+eDEDqO18A4PZADzStGpV4nmSVCPV2JchLQ3y6AblcVmz2ZqF2EdVlgiDY1f79+7n77rsJCwtj69atBAQEkPfvCfS//g1yOd4vjyn13DZ1vXDTKMjONeOmUdAy1JNTcZnsOZ/MqbhMTBZrNd6JYA+iJCMIgl2tW7eOZs2asXnzZry8vABInfM5AO5D+6BuXHr3sG5N/InLyOVSsp4wPx2hXi5iEG4tJ4KMIAh2kZ6ejpeXF3PnziUvLw+dLn9al9xdUeRu2w8qJd5THi8zjbj0PDxcNETWd8FgtnA+ORu1Qllil2q90cz204mF2mt0avFIczZ2qS778ccfadeuHY0aNWLw4MH8/vvv9khWEIRa4rfffqN+/frs3r0buVxuCzCSJNlKMR4jB6AKK7t95fpxOhqlAqPJysGLKazZd4mDF1NwUf73yNp+OpE/TiYSnaxny4lEtp5MEFVrTqhKYX/RokW0b9+eKVOmsH79eho1asThw4eZPXs2Fy5cYPz48fbKpyAITurHH3/koYceolevXrRv377QZ7lb95H3zxFkWjXekx+7YVpFx+nojRbiswxIEmQXGYx5NS0XrVJhG0Nz5HI6OUarqFpzMlUqySiVSr7++msuX77MsGHDGD58OD/++CNDhgzh/fffR5KkGyciCEKttX79egYPHsyAAQPYsGEDWu1/U/Pnl2I+A8Bj9CCUQX43TM/PXc35xEx+PnyV84mZZBpM5JnMyGRgMFuIScuxHRtaZAyNTqMoVAoS89s5hyqVZJ577jkAzp8/z6efforJZOLo0aMcOXKEuLg4WrRogZubG/v377dLZgXBkZxpOYHawGAw8MorrzBkyBBWrFiBUln4cZKzeSeGw6eR6VzwnvBIudJcvTeGn4/EYZLgeFwGwZ4aknPMqBVyDGYrzYP+CxwFY2YK2mSCPV04FZ8l5rdzMnZpJfvggw8YMmQIXbt2pXXr1uTk5NCqVSv+/vtvMjIy7HEJQXA4sZxA+ZnNZjQaDX///TcBAQEoFIpCn0sWC6nv5LfFeD41BIWfd7nS3Xoqnsw8C64aBVl5FiRrLo0CPdGq5OSZrHi5/rdSpk6tLDSGxmSxolUpxPx2TqZSr2mXL18utN2qVSv27NlDr169SEhIoF69emzatAkAT0+xBKpQOxRtdBbVLSVbsmQJXbt2JScnh+Dg4GIBBiB741ZMp6KRe7rh9ezwcqcts02NCSChVSkJ8NAS6OFCgIeWBn6upZ5bMFtBl4Z+NAv2EKVQJ1GpkkxYWBje3t60bduWtm3bEhERQUREBMHBwWzYsIEVK1bYO5+C4HBiOYEbe++995g8eTITJ0609SArSjKZSZu7DACvZx9G4elOcnYen2w7R3SSnnB/HU/d1Qg/t+JLK3dv6s/ltFwMZisuKgUD2obQMtRTTCtTi1UqyFy4cIFDhw5x6NAhoqKiWL9+PVevXkUmk+HhIaoXhNqpnq+O6ORs2wOtnq9Yvvd6s2fPZvr06UydOpXZs2cXWgfmellrfsEUfQW5nxeeTzwEwCfbzvHbiUTUCjmnE7MBmD6gVbFzb2/kz5mE/HYxL52aOxr74aXT4KVT461Ti9JJLVSpIFO/fn3q16/PAw88YNu3Z88eHnvsMebOnWuvvAlCtYpJ0ZOZZ8HXTUtGrpmYFL1ok7nmwIEDTJ8+nRkzZvD666+XGmAkg5G0BV8C4P3CSORu+YE6OkmPWiEn3M+V6OQcopP0JZ5vskr0bBFsG3x5KSWXS6l5op2sFrPba0GXLl14//33mTVrll3SW7x4MeHh4Wi1WiIjI9m5c2eZxxsMBqZPn05YWBgajYaGDRuybNkyu+RFuDWINpniCoYhdOjQgQMHDvDGG2+UGmAAMlf8gPlqIoogPzwee8C2P9xfh8FsJTo5B4PZSrh/yaVEF5WCA5dSWLPvIgcupWCwmMXPpJarVJAxmUwl7m/cuDHHjx+vUoYA1q5dy8SJE5k+fTpRUVF069aNvn37EhMTU+o5Q4cO5c8//+SLL77g9OnTrF69mmbNmlU5L8Ktw1unxmC2iDaZa6xWKy+88AILFiwAIDIysuzj9Xmkvfc1AN4vPobc5b+eYMM61MFTIyM1x4CnRsawDiUvo3wlTU9CRh7ZeRbi0/PQGyziZ1LLVaq6zNXVlRYtWtCuXTsiIiJo164dISEhfPjhh/Tu3bvKmVq4cCFjx45l3LhxQP7MAr/99htLlixhzpw5xY7/9ddf2b59OxcuXMDHxwfIr9Iri8FgwGAw2LYzMzOrnG+hdhNLPP/HarUyfvx4Pv/8c5YuXVquczK++A5LUirKsGA8RvQv9Nmq/ZdJ1pvRKhUk5ZhZtf8yb9znVSyNq2m5qJUKfFzVpOYYsUgQUddL/ExqsUqVZLZu3coTTzyBSqVi5cqV9O3blyZNmvDhhx9iNBqZPn06a9eu5eTJkxVO22g0cvDgwWLBqnfv3uzevbvEc3744Qc6dOjAvHnzCA0NpUmTJkyZMoXc3NxSrzNnzhw8PT1t/+rWrVvhvAo3F9EFNp/ZbGb06NF88cUXLF++nCeffPKG51gys0n/cBUA3lNGI1OrCn1+/GoWViv4uKmRpPztkuSZzOy/mMpPh6+y/2IqZrNF/ExquUqVZO644w7uuOMO27bVauX06dO2HmcHDx5k2bJlJCYmYrFYKpR2cnIyFouFwMDAQvsDAwOJj48v8ZwLFy7w999/o9Vq+f7770lOTuaZZ54hNTW11HaZadOm2RZTgvySjAg0N3Yzz3wrRvzne+edd1i5ciWrVq1i2LBh5Ton45NvsaZlomochvuQ4rUZvm4qziZKpOUYMVslfN1UJaQCx69mkpVnRCaDPJOF41dFDUNtZ5eng1wup3nz5jRv3pyHH37Ytj8hIaHSaRZtXJQkqdQGR6vVikwmY+XKlbbBnwsXLuShhx7i448/xsWl+LrgGo0GjUZTbL9QtoKZb7VKBSfi8t9Gb5aVCx094r+2BLEJEybQuXNnevbsWa7jLakZpC9eA4DPy2OQlTA4c1SX+mTmmUnNNhHur2JUl/olphWbnotSrsDTRUlGrpnY9NJrI4TawaG/4UVLI+Xh5+eHQqEoVmpJTEwsNb3g4GBCQ0MLzS7QvHlzJEniypUrFc6DULrrZ751USm4mnbzPAQc3busIIjFZxg4dDmd89fGiziD3NxcRo8ezdmzZ/Hw8Ch3gAFI/2g1UrYedctGuA68q8RjIuv78lr/lrx+Xwte69+SyPq+JR5Xx9cFuQxyTVbksvxtoXZzutcotVpNZGQkW7ZsKbR/y5YtdO3atcRzbr/9dmJjY8nO/u+P9syZM8jlcurUKbkXi1A5RWe+DfW+eR4Cju5d5qxdpHNycrjvvvtYu3ZthV/KzAkpZHy+HgCfV8chk5f8SClve9djt4fTtq4nQR4a2tb15LHbwyt2M4LTccrK9MmTJ/Poo4/SoUMHunTpwqeffkpMTIxtfZpp06Zx9epV2/Q1I0aM4O2332b06NHMmDGD5ORkXnrpJcaMGVNiVZlQeUVnvr2ZpvlwdO8yZ5y2JjMzkwEDBhAVFcWvv/7KnXfeWaHz0xd9jZRrQBPZAl2vkl8CK6JFsCd3NvG3Lb/cIljMfVjbOWWQGTZsGCkpKcycOZO4uDhatWrF5s2bCQvLX60oLi6u0JgZNzc3tmzZwvPPP0+HDh3w9fVl6NChdhsYKvyn6My3N5OCt21HcbYu0pIkMWjQII4cOcKWLVvo3Llzhc43XUkgY8UPAPi8+kSZgzTLa/e5ZE7FZ6NVKjgZl83O00k0CnR3+nYsoXQySawsBuS/0Xl6epKRkSHmXxNuGX/88Qfe3t43HGhZksRJc8n65ie0t7cj5Pv37RJkPt9xnuhkvW21S1eNnIYBHtfWk7EQUddLTCvjJMr7zBSvBIJwi0lISOCNN97AarXSs2fPSgUY04UrZK3+BQBfO5ViAPzc1JxLzOLno7GcS8xCpZA7ZTuWUH4iyAjCLeTq1at0796dzz//nNjY2Eqnkzp/GVgs6Hp2Rtuxtd3yJ0MGMpAkGRISrmqlmFamlnPKNhlBEKqmpDE5sVcu06NHD0wmEzt27Kh0z0vDyQtkf/cHAD5Tx9kz28Rl5eHrqvlvWhmseLooxfILtZgIMoJwEyo6sDQxIZ5H7++FUqlkx44dN5zbryxpc5eBJOE6oDuatk3tl2lAIYP4jDzS9UZyTVZCvTRi+YVartxB5vopWG5k4cKFlcqMIAj2cf2YnPiMPOQ6Tx555BGeeeaZKo0dMxw+Tc7P20Emw+eVsXbMcb5GgW60Sc/FbAGlArxcNYXuQ7TJ1D7lDjJRUVGFtg8ePIjFYqFp0/w3mTNnzqBQKCrViCgIgn0VjMn550AUycmJjBw8gNmzZ1c53ZTZnwHg9lAv1M3sP1Ay2ENHsyBPNEoFBrPFNr1MSW0ytWWanltduYPMX3/9Zfv/hQsX4u7uzldffYW3tzcAaWlpjB49mm7dutk/l4IgVEjDADeOHznEm+OH0qhxU954qnwTXZYld+8Rcrf+AwoFPi+NsUMuiys6lqier46YFH2JY4scPdecYB+VGicTGhrK77//TsuWLQvtP3bsGL17965Sr5WaIsbJCDeTf/75hz59+tC0aVN+/fVX28tgZUmSROwDE8jbfQj3R+8jYOHLdspp5e05n0x8hsFWlRbkqaFLQ7+aztYtw6HjZDIzM0ucYTkxMZGsrJLXiRCEW5XJYuVUXCZ7zidzKi4Tk8Xq0Ovt2rWLnj170rp1a7Zs2VLlAAOQu+MgebsPgVqFz4uPVT2TdiBWMq0dKtW7bNCgQYwePZoFCxbYpqLYu3cvL730EoMHD7ZrBgWhtqvuah1/f3/uu+8+PvvsM1xdXaucniRJpM7+FADPx+5HGVrx2dUdwdmm6RFKVqkgs3TpUqZMmcLIkSMxmUz5CSmVjB07lvnz59s1g4JQXcrbkFzRBueiPb0c1UNqx44dRERE0KRJE1atWmW3dPW/7cLw70lkOi1eEx+1W7pV5ei55gT7qFR1mU6nY/HixaSkpBAVFcW///5LamoqixcvtsubkyDUhPKu91LRdWGqo1pn48aN9OzZk3fffdeu6UpWK6nvfA6A57gHUQb42DV94eZXpcGYrq6utGnTxl55EYQaVd4SR0VLJo6u1lm7di2PPPIIgwcP5vXXX7dr2jmb/sJ4/Dxyd1e8nhth17SFW0OlO5Xv3LmTkSNH0qVLF65evQrA119/zd9//223zAlCdSpviaOiJZPyLthVGStWrGDEiBGMGDGCVatWoVKp7Ja2ZDaTOm8ZAJ7PDEPhLaqmhIqr1G/7d999R58+fXBxcSEqKgqDwQBAVlaWXQZ8CUJNaBjgRkRdL4I8NUTU9Sq1xFHe46pDYmIiY8aMYfny5SiV9p0lKuvb3zGdi0Hu44nXU0PtmnZp9EYzvxyN5fMd5/nlaCx6o7laris4TqXGybRr145JkyYxatQo3N3dOXz4MA0aNODQoUPce++9xMfHOyKvDiXGyQi1ydGjR2ndOn/2Y0mS7DbVfgHJaCKmyyOYY+LwefNpvKupquyXo7H8cTIRrVJBrslCrxYBN+0iebWdQ8fJnD59usRlWj08PEhPT69MkoIglNP8+fNp06YNe/bsAbBbgJEsFnJ3RZG14Q+SZy7BHBOHIsAHzzHVNyzhalouWqWCBv6uuKgUXE3LrbZrC45RqfJ1cHAw586dKzaT699//02DBg3skS9BEErw9ttv88Ybb/Daa69VeLnksmT/tJ3k6e9jiU0qtF/XqytyndZu17mRUG8XTsRlcSEph1yThVBvl2q7tuAYlQoyTz31FC+88ALLli1DJpMRGxvLnj17mDJlCm+88Ya98ygItzxJknjttdeYPXs2s2bNYvr06XZLO/un7SSMeQ1KqDjPWvUTup6dcRvQ3W7XK0v3pgEAtvVjCraF2qtSbTIA06dP57333iMvLw8AjUbDlClTePvtt+2aweoi2mQEZ6bX67nzzjt5+OGHefHFF+2WrmSxcKn9kGIlGBsZKEICCDu4DplCUenriBmTbz7lfWZWKsjExMRQp04d8vLyOHHiBFarlRYtWuDq6srly5epV69elTJfE0SQESqiuh6aVquV5ORkAgICMBqNqNX2HciZuyuK2Acm3PC4kI0f4HJ7u0pf51Rcpm1qHYPZQkRdr3KN1hfByXmV95lZqeqy8PBw4uLiCAgIoEOHDrb9KSkphIeHY7FYKpOsINQKJouVP07EExWTgaeLCi9d/tgUe09xYrFYePLJJ9m2bRvHjx9Hq7V/24g5IcWux5WmslPriOn8a79KBZnSCj/Z2dkO+UMQBGdyPjGbqJh0cgwWpGsNGfaej8xsNvPYY4+xZs0avvrqK4f8XVkys8nZsqdcxyoDfat0rYJF1Co6tU51zfsmOE6FgkzBEswymYw33ngDnU5n+8xisfDPP/8QERFh1wwKgrNJ0xvx1KmQAL3BApLRrvORGY1GRowYwaZNm1izZg1DhgyxW9oA1lwDmcs2kPb+N1jTMss++FqbjLZz1aaPquzUOpUNToLzqFCQKViCWZIkjh49Wqh+WK1W07ZtW6ZMmWLfHAqCk/HWqfG67mHXvp59R/1HRUXx22+/8d133zFw4EC7pSuZzWSt3kzq/C+xxOU39KuahKHrfTsZH1+btfm6SgpJBjLAb9aEKjX6Q+VnTK7nqyM6OdvW26yer+7GJwlOpUJBpmAJ5tGjR/P++++LBnLhllTSW3l5GqNv1Iidl5eHWq2mU6dOXLx4EV/fqlVRFZCsVnJ+3E7qnM8wnb8MgLJOIN4vj8F9aB9kCgXayBbFxskYfX1QvDq+2rovQ/HvyGyxkplnwddNS0aumZgUvWiTqWUq3YX5ZiN6lwkVUZleTwU9rBRyuJKqJ9hLS9s63jQMcMOQq+e+++4jMjLSbtP1S5JE7rb9pP7vUwyHTwMg9/XEe+IoPB6/H7lWU/h4i4WD3+0k63IiymAf1ls9yTFJdGrow9Db6uHp4viqqmNX0vnjVCJGsxWVQkaguxYXtVIsseyEHNq7bM6cOQQGBjJmzJhC+5ctW0ZSUhKvvPJKZZIVBKdUUkCpTK+ngkZso8VCXKYBo0VCkmRkZWYwccxwjh49yqxZs+yS57yDx0mZ9Sl5f/8LgMzVBa9nH8Zr/FDk7iWv+SRTKHDr1p5zl9PZfS6JQ1cz8HZVsflo/lLrT9zZyC55K8vxuAzi0vPwc1MTn2FAJpMI8tCJNplarFIdzj/55BOaNWtWbH/Lli1ZunRplTMlCM6kpEXKru/1pFEqytXrqWCJgJgUPUgQ5qvDmJPJqCEDOXHiBH/88Qe33357lfJqPB1N/GOvcvXe8fkBRq3C86khhB1Yi89Lo0sNMAUKZpjWGyx4u6poFeqJi0rOpWR9lfJVbpLM1mNPQiLATes0M14LlVOpkkx8fDzBwcHF9vv7+xMXF1flTAmCMympG21lej0VPCBlgFIhQyGT8/3XnxB/JYatW7fSrl3lBzuaLseTNm8ZWet+A6sV5HLch92L90ujUdUNKnc6BQ30nRr6sPloAtFJOeSarIT5VU+De8sQD+IycjFZJEI8tbSpU75Bm4LzqlSQqVu3Lrt27SI8PLzQ/l27dhESIqblFm4uJQWUynTJLXiANwxw42x8Jhl5Zma+9RbqaS/QtHHlqqLMSWmkL/qajC83gtEEgGv/O/GZNg510/CyTy7D0NvyZ+24lKwnzE9n23a0psEeKBVyh60iKlS/SgWZcePGMXHiREwmEz169ADgzz//5OWXX7brvEqC4AzTipTWm6yyb9jxsVd59IEH+Oijj2jdpQtQ8V5k1qwc0pesJX3xGqSc/OnwtXe0x/e1J9FGtqxUvq7n6aKuljaYoqryvQrOqVJB5uWXXyY1NZVnnnkGozG/Llqr1fLKK68wbdo0u2ZQuLU5w7QiVX3wXR8oc5JjeWL4/UiSRGBgYIXTsuYZyPxyE2mLVmBNyQBA07YpPq89hUv3DnZfvEwQqqpKXZizs7M5efIkLi4uNG7cGI1Gc+OTnJTowuyc9pxPJj7DUKu7sBZ0XU6+epE3nhqGm6sLO7f9VaGJZCWzmax1v5E2bxnmq4kAqBrWxefVJ3C97y67BxdnKEEKzs2hXZgLuLm5cdttt1UlCUEo080wrUhSdh6JGbnMe+kpVFoXFq34vtwBRpIkcn7ekT+Q8swlABTB/vi8NBr3h/siU1bpT7hUzlCCFG4O5f4NnTx5Mm+//Taurq62OcxKs3DhwipnTBCg8nNeOZPsPDMXU/Tc++wsFDovtJ7lK4npdx4kddYnGP49CYDc2wPviY/iMXoQchfH1hqIiSkFeyl3kImKisJkMtn+XxCqQ3U2BDuiiujAgQPMe30mA1/4H67BbYhLz+VSqp5TcZmlpp936BSpsz4hd/sBAGQ6F7zGD8Xz2eEoPKonyN4MJUjBOZQ7yBTMW1b0/wXhZmHvKqLdu3fTt29fwhs3xUMFeRaJzDwLfm5w6HJ6sfSN52JInf0ZOT9uy9+hUuL52P14TRqFMsCn0vmoDGcpQYq2odqvQtVl5SGTyViwYEGlMyQINcWeVUTbt2+nf//+REZG8v2mH0jMlfH32SQa+LnRpq4nyVlGW/rm2ERS5y8na/UvYLGATIbbkN74vDwGVVjNjDurqa7EJU2QeTwuS7QN1WIVqi673sGDB7FYLDRt2hSAM2fOoFAoiIyMtG8OBaGalFZFVNG36ejoaPr27cvtt9/Oxg0bkB85R2hCCt2UWqL865CcZcxP35hH8psfk/nFBiRDfsDR3XsHPtPGoWnRsFru2dkULU3KZBJqhVK0DdVilaouW7hwIe7u7nz11Vd4e3sDkJaWxujRo+nWrZv9cykI1aC0KqLTcZn8eTIBo0VCpZDRs3kgrep4lZpOeHg4n376KffqAki64zHb9PlKoEOgH5lPj8QtKRnliu/JyMoBQNulLb6vPYW2Y2uH3qOzK1qazM4zcuRKCtl5Jty0KoZE1q3pLAoVVKlxMqGhofz++++0bFl4ZPGxY8fo3bs3sbGxdstgdRHjZITSrN0Xw7+X0/FzVZOUbSCynjfDOhbvgvzdd9+RlZXF448/TvZP20kY81qhRcBKom7VGN/XnsKlR8ebeiBleUuDp+IyOXQ5HY1SgcFsISPXyP7oVAqeUoPahzKgbWiVriHYR3mfmZX6CWRmZpKQkFBsf2JiIllZWZVJUhCcl0yyBQsZsvztIlatWsWwYcP4448/sJrNJE9/v+wAo1Dgv/R16vz5Obp7Ot3UAQZKnsm6JAWzQBfMuowEgR4udG3kR5CnC/EZeVW+hlC9KhVkBg0axOjRo1m/fj1XrlzhypUrrF+/nrFjxzJ48GB751EQalTLYE+CPDVIQJCnhpbBnoU+X758OSNHjuTRRx/lq6++wvDP0UIrTJbIYkEV5I9Mfmu8aZd3aYSCDgddGvrRLNiDer46ck0WLiTlkGuyEOrtUuVrCNWrUsOFly5dypQpUxg5cqRt7IxSqWTs2LHMnz/frhkUhJpW1szAa9euZcyYMTz11FMsXrwYuVyOOe4GAeaamFOXSQ2qe0tU7VR23E33pgEAXE3LJdTbxbZtz2sIjlWluctycnI4f/48kiTRqFEjXF3LXhDJmYk2mfKJz9Az/9eTRCflEu7vwkv3NifIs3rWGnFGiYmJLF++nJdffhksFrK//5OUWZ9iiU284bmn33kNQ7tWGMwWIure3OumVEd7iWiTqV7lfWZWKcjcTESQKZ8X1x7kr9MpKOUyTBaJHs18WTDs5ui2XpGH1JIlSxg4cCChoaH5k1eu30L6eyswXbiSf4BMBqX9acnAEuDH8eUfEOSjc8qJP53lge0s+RCKc2jDP8DOnTsZOXIkXbp04erVqwB8/fXX/P3335VNspDFixcTHh6OVqslMjKSnTt3luu8Xbt2oVQqiYiIsEs+hMKik3JRymXU9dGhUsiITsqtluuaLFZOxWWy53wyp+IyMVmsdr9GeRqOJUnizTff5JlnnuH7774jc+XPxHR5hKTnZ2O6cAW5jyc+058k4OPp+UtgFmnPlwr+vfIkBkly2qodZ2lEd5Z8CJVXqSDz3Xff0adPH1xcXIiKisJgMACQlZXF7Nmzq5yptWvXMnHiRKZPn05UVBTdunWjb9++xMTElHleRkYGo0aN4p577qlyHoSShflpyTVaOJ+YRa7RQpiftlquWx0Pmxs1HEuSxNSpU5k5cyZvDnmU+776m6SJ72C+GIvczwufN8YTdnAd3hMfxX1IHwKXzUIR7F8ojTxfH45NeQ7T3V1oFeLhtGvXO0sjurPkQ6i8SlWXtWvXjkmTJjFq1Cjc3d05fPgwDRo04NChQ9x7773Ex8dXKVOdOnWiffv2LFmyxLavefPmPPDAA8yZM6fU84YPH07jxo1RKBRs3LiRQ4cOlfuaorqsfLafjuedzadIyTHh66piar9mdG9a/jXkK6s61pUpOkajaDvJ5IkTee/993k9pA2PGfP3K/x98HruYTweux+5a/GeT5LFwpmf93Dp1BVSdK4c9AslPMADXze1U7fD3Oi7uNXyIRTn0PVkTp8+zZ133llsv4eHB+np6ZVJ0sZoNHLw4EGmTp1aaH/v3r3ZvXt3qectX76c8+fP88033zBr1qwbXsdgMNhKYJD/hQk3Fp9hpE09H9vAxPiM6nmzrI6eQw18XZDtP0LO1SRcQ/0JjwgG8lejzPrmJxps2M1M14aMMHqgCPTF6/lH8Hj0PuS60ktzMoWCBv27It2WTczZJMLNUrG5y5yRs0yQ6Sz5ECqvUkEmODiYc+fOUb9+/UL7//77bxo0aFClDCUnJ2OxWIotTRsYGFhqCens2bNMnTqVnTt3oiznIk5z5sxhxowZVcrrLakcAxMdwdEPm+yftpM8/X2UsUkUjIKJDfZD3f02vtv4PX1zNfSVaVA0aIP3hEdwf2RAudd0uX6yyUOX0/+bu8zJ2mGuV1MTZDprPoTKq1SQeeqpp3jhhRdYtmwZMpmM2NhY9uzZw5QpU3jjjTfskrGiI6AlSSpxVLTFYmHEiBHMmDGDJk2alDv9adOmFZpZOjMzk7p1xbxIN9Iy2JPY9DxMFqnEgYmO4siHTWlTwOTFJvLcp+/xqzGZho17cPvU53Af0Q+5tnILhom3cuFWVKkg8/LLL5ORkcHdd99NXl4ed955JxqNhilTpvDcc89VKUN+fn4oFIpipZbExMRipRvI72xw4MABoqKibNe2Wq1IkoRSqeT333+nR48exc7TaDRoNI5dXfBmVNbAxNpIslhKnALGIFmZmHWKbcY0Pgpqx72HfkKuK320eVkyco2s2x/DpWQ9YX46ht5WD51aKbrmCreESi8Q/r///Y/p06dz4sQJrFYrLVq0wM2t6g8ctVpNZGQkW7ZsYdCgQbb9W7Zs4f777y92vIeHB0ePHi20b/HixWzdupX169cTHh5e5TwJ/7nZqi/y9h4pNgVMnmTh2cxT7DGls9ijOXebXTFEncLl9naVusa6/TH8ciwBrVLO8bj8uf26NQ6w6wJpguCsKhxkTCYTvXv35pNPPqFJkyZ06NDB7pmaPHkyjz76KB06dKBLly58+umnxMTEMH78eCC/quvq1ausWLECuVxOq1atCp0fEBCAVqsttl8QijInpBTbJ0eGTqbgE48WdFN7l3pceV1K1qNVymkY4Mb5xGwuJetpFWq/BdJuJmLw5c2nwkFGpVJx7Ngxh84aO2zYMFJSUpg5cyZxcXG0atWKzZs3ExYWBkBcXNwNx8wIwo2YLFZilVrbH0G21cwVq4FmSlc+9GhW6NhYpZbU88mVevCF+ek4HpfF+cRsck1Wwvx0Yp6tUth7CWyh5lVqnMyLL76ISqXinXfecUSeaoQYJ1M+N9Ob5qm4TOI+XE29z74h02pmdOZxUqxGtnhHopLl35MkA6OvD1sWvUOonxtmi1ThsRqiTab8qmM8lGAfDh0nYzQa+fzzz9myZQsdOnQoNjHmwoULK5OsUAv8eTKOtzYdIyvPgrtWwVv3t+LeViUvIuXs9Bv/pN5n35BmNfF4xjGuWA185dmqUIBBgn9HDONqlpEgH6lSo849XdQ8cWejYvvFG3pxooR386lUkDl27Bjt27cH4MyZM4U+u9kXX7rVzf7pOPFZZmRAjsnM7J+O18ogo/9rHx4zPyDZauQR6SLJLgq+CexG8wyT7RhrgB+XnxyFqkskskvpxKToqeujEw8+BxLdvG8+lQoyf/31l+3/C2rbRHC5NaTozQColTIMZsm2XZvk/XuC+MdfQ2Y2E9+5NZbT6axc9g29Okdg2X8Mc0IKykBfLoXVJyk2C6VMTpCnhhBPF9reYJ4xvdHM9tOJhdY/0akr3YnzlnOz9V4UqtCF+YsvvuC9997j7NmzADRu3JiJEycybtw4u2VOcD4+OjU5RgMGs2Tbrk2M52KIe/hl4rLSCeneiYFrFjBAqUChUOQfcF035YYWKyjyq8dahXqUq91k++lE/jiZiFap4MS17sp9W4c47H4EwdlVKsi8/vrrvPfeezz//PN06dIFgD179jBp0iQuXrxYrrnDhNrp3laBrNgTg9kCSkX+dm1hjksibshkopPiGZV7mgF+EUxMzSu18b0yb9VX03LRKhU08HflQlIOV9OqZykEQXBWlQoyS5Ys4bPPPuPhhx+27Rs4cCBt2rTh+eefF0HmJubn5sI9zYPwcVWTmmPEz61yo+CrmyU9i7hhUzhzMZpHc07iUTeE4eMn2b27bKi3Cyfissq1Jr0g3AoqFWQsFkuJgzAjIyMxm2tfHb1QfvV8dZxOyEaSQKtSUM/X+ZdeturziH/kFY4dPcqo7BP4h4fx5/ZtXNSrUF/XXdYeAyIrsib9reRm6vouVEylgszIkSNZsmRJsa7Kn376KY888ohdMiY4p9r2EJXMZhKeeJO8fUf5XZZFaKMGbNmxHX9/fzLiMu3eXVanVoo2mBKIQZa3rio1/P/+++907twZgL1793L58mVGjRpVaHZjMWbm5mKyWLmSpudSsh4JySHLINuLJEkkTZ5P/K878NTp+N/ab7G2aWSbY89Zu8vejG/9169wKabRubVUeZzM+fPnAfD398ff359jx47ZjhPdmm8+JU32WNJAQ2eQ+vYnbF2xmqcyT/LNi/No0DWi0OdFG/ZNFiun4jJr/OF+M771i0GWt64qj5MRbi0lTfZoL/Z8g09fsoZfFnzMk5kn6NC8FXdPeOKG5zjLw/1mfOt31lKj4HhilJhQISVN9mgv9nrIZ637lQ1TZ/FM5klub9aSnw7uxcXlxr28nOXhfjO+9YtBlrcuEWSEChnUvg6JmXlEJ+lpGeLOoPZ17Ja2PR7yOX/sJW7CHOblRHN34xZ8H7UPrVZbrnOd5eEu3vqFm4kIMkKFJGcZaRToSYsQbwxmC8lZRvzcyvcQv5GqPuTzDhzn6pjpKCxW1o94liafvY26AqufVtfD/UbVguKtX7iZiCAjVEhcRi6n4tIxWSRUChmBHhq7PRCr8pA3nrnI4gGP8FnyWTYMeJyWn/8PmVpVoetX18PdWdp+BKE6iCAjVMjJ2Ey2n07CYgW5HPzdtNzdzD5Ty1T2IW++msCCHoOZnnCYYaFNafTVOxUOMFVVkU4LztL2IwjVQQQZoUJOxWaSqjeikMmxWK2cis2s0fxYUjOY3W0Ab16N4tHAxnwRtRuVp3ul0qpK77aKlE6cpe1HEKqDCDJChSRk52GySCiUYLJKJGTn1VherDm57Lz/SWZEH2CcXyM+3r8Dlb9PpdOrSjVWRUonomFfuJWIICNUiL+rBjn5b/3ya9s1wWo0ET/mdeqeiuX7Ol3p8/vXqOsGVSnNqlRjVaR0Ihr2hVuJCDJChfh7aFCrFCABsvzt6ma1WJjcpSfWk9FM9GtC7/XvoWneoMrpVqUaS5ROBKFkIsgIFeLvoaVFsDsalQKDyYK/h326L5eXJEk8d0dvlvy7g6luDQj8fCbajq3tknZVAoUonQhCyUSQESqknreOf5UKkGRolArqeVffVP9Wq5Un7+rDF3u38oZrA1787CNce3e1W/qlBYqbccJKQaguIsgIFVLH24UgTy2ZeWbcNUrqVOOiXO+OeYZlO//gf26NeHrODDyG962W64pxLYJQeSLI1FI19Xada7YSGeZraxzPNVfPVP85v++i3y/H8fNowcCJz+D93IhquS6IcS2CUBUiyNRSNfV2Xd1jPEwmE08PHcGDey7TxKqi/8iH8X3zaYdesygxrkUQKk9ULNdS179da5SKanu7ruerw9NFSUp2Hp4uSocuv2wwGBh8bz++3vgdV3Kz0PXsTMCiqcjk1ftr2zDAjYi6XgR5aoio6yV6jglCBYiSTC1VU2/XMSl6MvMs+Lppycg1E5Oid0gJSq/XM6jfALbv2M4Sjxb07noHgZ/PRKZy3K9saVWQjug5JjoTCLcKEWRqqZoal1Fd7RMPPzSEnTt38Ll7C+5s1ZbglXORuzq2k0F1VkGKzgTCrUIEmVqqpsZlVEcJypqt57FkBQ+7t6BzWGNC1i1A4eNp9+sUVVoAdUSpQ3QmEG4VIsgIFeLIElRaWhrvzpvHU2fzaBudhjygLiHfLkAZap9Znm+ktADqiFKH6Ewg3CpEkBEqxFElqKSkJHr37k3MqTP0cGlOuLs3wavmoW5S3+7XKk1pAdQRpQ4xDY1wqxBBRqgQR1QdxcfHc88995B06TLfaJtSX+NG4Bdvo+3QssTj9UYz208ncjUtl1BvF7o3DUCnrvqvcmkB1BGlDjENjXCrEEFGqBB7Vx2lpqbSvXt3shKSWKlqTAOljoAPpuHas3Op52w/ncgfJxPRKhWciMsCoG/rkErn4UZEqUMQKk8EGaFC4jL1nIxLx2wBpQICPdVVCjLe3t4MahlJ76QzhClc8J35LO5D+pR5ztW0XLRKBQ38XbmQlMPVtNxKX788RKlDECpPBBmhQs4lZHPkSiZapZxck5VQLxfublrxdM6ePcvZs2fpLvfgiV1XQeGC1/Mj8Hp6+A3PDfV24URcFheScsg1WQitxvnTBEGoGBFkhArJM1uQJCt5JgmQyDNbKpzGiRMn6NmzJ35u7nynD0VuteI+vC8+r48v1/ndmwYAFGqTEQTBOYkgI1RIpt5ESrYJtVKOwWwlU2+q0PmHDx+mV69eBHr78IUhBLnBhK53V/zfexmZTFauNHRqpUPbYARBsB8RZIQKyTNYuJKWi5X8ie/yDOUvyfz777/07NmT+qF1+MJUB4+cLLQdWxP42QxkSvGrKAg3IzFZklAhPx25QsHk/tZr2+Xl6+vLPd3u5CtNEzxSslA3b0DQyrnIddW7uqYgCNVHBBmhQnKMZW+XZO/evaSmplLX15/39P7oLiWgrBtE8Np3UXi5OyajgiA4BVFHIVSIWSp7u6gtW7Zw//3388TYcbwSp8Bw6BRyX0+C1y1AGezvuIwKguAURJARKsRDA+mGwtul+fnnn3nwwQe5p0cPJuW4k7t9BzKdC8Gr56NuVM/xma2g0mYzENPyC0LliSAjVIhaIQesRbaL+/777xk2bBj9+/fno0a3k/flJlApCfpyFtp2zasptxVT2mwGYlp+Qag88TomVEigZ/7AR1mR7aISEhJ48MEH+aTrffkBBgj4aDq6uztWRzYrpbTVRmtqFVJBuBmIICNUiEqRH16kItsFjh49CsD48eNZ2nc4WfOWA+D7vxdwH9yz2vJZGd46NQazpdhEmKXtLy+TxcqpuEz2nE/mVFwmJov1xicJwk1CVJfZ0a1Qd3/5WnVRSdtLly7l6aef5vfff6eLUUPySwsB8Jr4KF5PPlSt+ayM0ibCrOoEmaK6TbiViSBjR9X1MKnJYKY35b+FK2Rgkf7bXrRoEZMmTWLChAnc7uJL/OiXwGrFfeQAfF59olryVlWlTYRZ1QkyxSqYwq1MBBk7qq6HSU2+GbtrleRkm7FI/22/8847TJs2jZdffpkZI8cRd//zSAYjrv264T//xXJPF3OzclHKOXgxhSyDGXeNkoci69R0lgSh2jhtXc7ixYsJDw9Hq9USGRnJzp07Sz12w4YN9OrVC39/fzw8POjSpQu//fZbNeY2X1Xr7surJhuiH+5UHw+NDK0CPDQyhrQP4aeffuLNN9/k7fETiB8+BWtWDtrObQlY+qaYLga4kpZLfJaBbIOFuMw8rjh4aQJBcCZOGWTWrl3LxIkTmT59OlFRUXTr1o2+ffsSExNT4vE7duygV69ebN68mYMHD3L33Xdz3333ERUVVa35bhjgRkRdL4I8NUTU9XLY4lbVFcxKElnfi4YB7gS4awnVmujUOICtW7fy2rMvED9sCpbEVNQtGxL0zRzkLmUMormFxGfmEeiupWtDX4I8XIjPzKvpLAlCtZFJknSDMdvVr1OnTrRv354lS5bY9jVv3pwHHniAOXPmlCuNli1bMmzYMN54440SPzcYDBgM/40qzMzMpG7dumRkZODh4dyNsjXZJjP6811sPZtG2p+foj+9i2GzV/P147dz9YEJGI+cQVkvmNCfF6MM8quW/NQGvxyNZcuJRFxUCnJNFnq1CBCzSAu1XmZmJp6enjd8ZjpdXYbRaOTgwYNMnTq10P7evXuze/fucqVhtVrJysrCx8en1GPmzJnDjBkzqpTXmlKTKzVuPZtK2m+LyTr8K53bDSLvn9PEbf0J45EzyP288qeLEQGmELH+jXArc7ogk5ycjMViITAwsND+wMBA4uPjy5XGggULyMnJYejQoaUeM23aNCZPnmzbLijJCKWzWCwo171N9sX9vOPWmIcuJ8HldeQBaNSErHkXdUPxHRYl1r8RbmVOF2QKFO2RJElSuXoprV69mrfeeotNmzYREFD6G6NGo0Gjqb1tBjVRZXZgyQoSLv3Lu25NuF9b5Ls1GDFdjkfTthJrMQuCcNNyuiDj5+eHQqEoVmpJTEwsVropau3atYwdO5Zvv/2Wnj2de3R5VRV0Y1bIZOy9kEKIpwttr3U2sHewMRqNyIHAz35kq3cHfOWq4gfJIPm1D3DtewcyhcKu1xcEofZyut5larWayMhItmzZUmj/li1b6Nq1a6nnrV69mscff5xVq1bRv39/R2ezxhV0Y7ZIVuIzDMSk6jl0OZ3zidl2vU5eXh6DBg3iyYeGY4lNKjnAAEhguZpI3t4jdr2+IAi1m9MFGYDJkyfz+eefs2zZMk6ePMmkSZOIiYlh/PjxQH57yqhRo2zHr169mlGjRrFgwQI6d+5MfHw88fHxZGRk1NQtOFxBN+ZLKXokJOr56uw+ZiYnJ4f77ruPv/76i4ER5ZvY0pyQYrfrC4JQ+zlddRnAsGHDSElJYebMmcTFxdGqVSs2b95MWFgYAHFxcYXGzHzyySeYzWaeffZZnn32Wdv+xx57jC+//LK6s18tCsbgyGQSaoUMhVxm1zEzWVlZDBgwgIMHD/LLL79wG67ELfn5hucpA33tcn1BEG4OTjlOpiaUt8+3s3FUB4C5c+cye/ZsfvnlFzpHtCd+7Ovk/rEXif+m+S9EBoqQAMIOrhNtMoJwCyjvM1MEmWtqW5BxVHAp6MVnsVi4cOECDQKCiH9kKnn/HAG1EsloRqJIPeu1qBO4bBZuA7pXOQ+CIDi/8j4znbJNRrixgt5l8RkGuzX4JyYm0rlzZ/78808UCgXhnr7EPvACef8cQe7hRsh377Ph0UdJdnUvdJ4iJEAEGEEQSuSUbTK3ksqWSOw943NsbCz33HMPGRkZBAcHY4qJI+6hyZiir6Dw9yF43QI0rRqxeONllg59mtYJl/HR55Cqc2XjsqdEFZkgCCUSQaaGVXbafneNkr0pKZxNzEKlkNE4wLXSeYiJiaFHjx4YDAa2b99OmFXJ1f7PYIlPRlkvmJBvF6Jq8N/09Fa5nMPBYbZtEWAEQSiNqC6rYVWZtl9OfnNIVVZrkSSJESNGYLVa2bFjB3WzTFy97zks8cmomtYn9KePCwUYQRCEihAlmRrmrVNzKUVf4Wn703KNKBUKXDVyDGYrabmVqy6TyWQsW7YMFxcXfC8kEDvqVSR9LprIFgSvmofCx7NS6QqCIIAoydQok8WK2WJFJpMwmi20CvEo9xo0qVlGdpxJYPPROHacSSA1q2JB5tixYwwaNIisrCyaNGmC9+ELxI14GUmfi0v3DoSsf6/EAOOjlZW5LQiCcD1RkqlB5xOzOR6XhVqhxGC2oFTIy9Xob7JY2RudTHymARe1ArkJziVllfu6UVFR9OrVizp16mAwGJA2bSfpxflgteI6oDuBS99Apim5ROXvqSU1L7fQtiAIQmlESaYGVbY95nxiNlfSclHK5GiUcowWiQtJOZyKy8RksZZ57r59++jRowcNGjRg69atKNb8TtKkuWC14v5IfwI/n1FqgAFILLJ0cNFtQRCE64mSjB1VtDtypdtj9EbCfV3JMVpIzzFhtkj4uKk5dDkdKL13Wnx8PL169aJVq1b8/PPPWD5cQ9oHKwHwen4EPq+Pv+FyCmnGsrcFQRCuJ4KMHZ2Oy+SPU4kYzVZUChm9mgfSqo5XiccWtMdYrVYupeoJcHPBbLFislhvWGXmrVPTKMgDuULGybgsAj003NM8kLQcc5mloaCgIJYuXcqAfv3Im/kpmSt+AMDnjfF4P/9Ipe9bEAShNCLI2NHxuAzi0vPwc1MTn2HgeFxGqUHmfGI2R65mcDYxi3OJOTQMMCHJQFmOpZULOgc0DnCjU7gvqTkm0nLMpZaGfvvtNy5cuMDTTz/N8AcfIuHpt8n54S+Qy/F/dwoej95X7nvUAIYi24IgCKURbTL2JMmQyJ8KTkICqfSqpzS9kTS9kaQsExarRGKWkfRr+25EpZDTMMANb50arVqO3mgiISsXTxcl9Xx1hY794YcfGDhwIL/++ivmzGziHnklP8ColAR+9laFAgxA0Rhmp0mfBUG4SYkgY0ctQzwIudbbKsRTS8uQ0ksk3jo16XoTVklCIZeBBOl6U7nbZQpmCjhxNZtTcdlYrBIZuWZirs0aAPDtt9/y4IMPct9997Hm0y+IHzaF3G37kelcCF41D7eBd1f4HkWbjCAIFSGqy+yoabAHSoW8UMN/aRoGuNE21JPkrDysagUapYy2oZ7lHidT0DNNo5SjVSnQKhWFeqj9/PPPDB8+nIcffpjP33mXpCEvYjx5AbmXO8Gr56Pt0NIu9ywIglAWEWTsSFWO9pTrj63v50qYn5uto0B9P9dyj5NJyMjlp8NxZOQakMnk+LupcdUobSWhrl27MnPmTF4c9igJD0zAfCkORaAvwd8uRNO8QZXuUxAEobxEkKlBSTl5pGQZSM0xYrJY8XVT0zTY44aB5nxiNnsvpHAxJQejxYpaLsNotRJR14u/N3+LpkcPwsPDmfLAMOLufx5LYirK+qGErF+IKiykmu5OEARBtMnYlcli5VRcJnvOJ5drYOS5hGz2XUzl8JV0Dl1J5/fj8ZyOy7zhddL0RuIy8rACWpUCqyTDYLby86rPeGLcONasWUPuP0eIvRZg1C0b5k90aYcAE6Are1sQBOF6oiRjRxWdtt9ksWKxWrBKEiqFjAy9ieOxmaV2ey7grVOTZ7KQmWtCrZAjIfHnqiVsX/0xr776KhNuu5u4IZORcg1oO7YmaNVcFJ7uZaZZXtnGsrcFQRCuJ4KMHVV4ITEJ0nJM5BgtyGVytAo5Zqns0g/kdxpoF+ZNUlb+iJXYv77h9C/Lefvtt3mhZWfiH50KZgu6ezoTuOxt5Dr7zS9msZS9LQiCcD0RZOygYDqZmBQ9sRm5WKwSZqv1ht2RkzIMZOWZMVoBLBhy8wg+dYasq+dQBvqi7dymxAXBVAo5zQLdOXwxldALF6gj09J96BM8F9KCxKdmgiThNrgnAR++ikytsuu9KhVgMBfeFgRBKI0IMnZQUE2mUMiQARbJQkRd7xt2R953KRmLBAoZ3B59mmf/+QO/7CwSr32uCPHH738v4Dage6HzTBYrOT9vp9VHCxhqcUV+bb6xlK0nAfAYPQi/dyYik9u/yS3ES83ZZGOhbUEQhNKIhn87KKgmC/XSEebrRj0fV5rdoJeYyWIlI9eCVcoPMG/8+T2+2YWn67fEJZEw5jWyf9peaP+pFT/zw/+m8HriIaLMxaf4d+nW3iEBBkBxregiK7ItCIJQElGSsYPKzKZ8PjEbtVKGUrLyzN4/gBKWUc6foYakSXOxJKchU8gxmcxMn/AcPxsSmOfWhEiVR7FTkl//ENd+3UqsaquqLL3h+qzZtgVBEEoigowdFFSLlWekf4E0vZGWdType/Yc/jllLzhmTc8i+aUFmCQrk7POsMWYwnvuTemv8S92rAywXE0kb+8RXG5vV6n7KUuq3lrmtiAIwvVEkLGDioz0L+CtU6OWyfHOzSnX8eq2TVEH+qL9M4EP9f700viWebw5IaVC+Skvo7nsbUEQhOuJNpka0jDAjWyjmViVyw2PzZMsXB7Rg7or57Lyu29vGGAAlIE3PqYyVLKytwVBEK4ngkwNUSnk6A0WToWEkezqTmmVTjlYeMJwjsGvTkav16Pt3AZrgG+pxyMDRWgA2s5tHJJvq1T2tiAIwvVEkHGQ8kwxE+qjA4WcT7r2QgbFAkeWZGZM+nGOWXL49ttv0el0yBQKDFOeKPH4gp4DfrMmOKTRnxKuKVpkBEEoiwgyDlIwdiY+w8Chy+mcT8wudsxDkXUI99exr1FT5t37IBnu/7XrZFjNPJ57mnMaC1v+2kq3bt1snx1u0ZLZPQeR7Fp4qhhLgB/mBa+i6dsNQRAEZyAa/h2kPFPMtAjxpHmQBxk5Zi5FtOXdLh0YrsjkHn812YYsDDNfZevKlbRv377QeVdSctka1pRtdRvTOuEyPvoczN4eDH6iHwZJQkrMrnBHhPJy08pIz5MKbQuCIJRGBBkHKc/YmZgUPYlZRoxWKxazDHOOicN1PejVpx1tXV05NnQgihKqvS6m5Hd5tsrlHA4OA8BDDQOQUMrl5VrCubK8XJSk55kKbQuCIJRGVJc5SMMANyLqehHkqSGirlexsTMmi5XDV9LI0BtAkpAhQ5+WyMcvPcr48eMBSgwwAMiKlx7yTLDp0FUuJGeVewnnylAp5WVuC4IgXE+8hjrIjcbOnE/MJj49DwkZIEPKTODoZ1PQKmW89dZbZabdOsSLbScTyb1uBmQPnZLkLCMZemO5l3CuDLM5f342pRzM1vxtQRCE0ojX0BqSpjcS5OVCsyB35FnxHFwyEa1KwarvN9OwYcMyz+3YwIcADy2Ka893OeB2bellV7WyXEs4V5aHTolcBpIEcln+tiAIQmlEkKkh3jo1sem55JmtZJz+B4Vay9g5y+nUttkNz801WfB2U+OmUaBV5M8jlpJjIMdooqmDGvwLNAv2wlUjR6uU46qR0yzYy6HXEwShdhOvoTWknq+OvJxMLqbk4dv5AQI79eNstoadp5MYEBFa5rnZeWYy9CYsVpDJZICESi5Hp1KglDu2+spNo0CnViInfyCmm0bMwiwIQulEkKkmBQubFUyieTjqX/73+GAaD34RRYOO+Op0GM1WjlzJKBZkip6rVcsJ99ORZ7SQkWvBRQkd6vvg7+5iWy3TUTx1aoI9tMhkMiRJwtOBnQwEQaj9RHVZNbl+cOa6zX8yZtj9hNStT/PIjlitVvRGExZJQq0uXhIpOrAzJ9eCVqlErVLgplWhVSnJMVjINVkI9b7xXGhV4aKSI5PJ0aoUIJPjohK/QoIglE6UZKpJweDMuNP/MmfCo9Rv0pKXFy7nfLqFuJwk1Eo57holjf3dSz23YGBnaq4Rg9mKi1qBl06Fh1ZJHW8ddzbxp3vTAIfeR6NAN9qk52K25C+93CjQcT3ZBEGo/USQqSbeOjUXk3N4f85bNGoZwRffrMXD3R3DmUS6N/MnxFOHyWrFx6149VPRgZ1JmXnkmixolQrMVit+bloe6xrusFH+1wv20NEsyBONUoHBbCHYQ+fwawqCUHuJIOMgRdtRQjzVUM+bRV+sItjXi5Zh/rauxocup9se2v5u2mJpFV0U7Up6DgazFY1Sjl5vwVUjd+jYmLLyUl3XFQShdhJBxkEK2lHUCjnrv9vAps8XsvvvHTS7rWmh4+r56ohOzuZqWi6h3i7U8y1eMig6sPPfS6n4uanRqpR46VS0CPZ06NiYsvIiCIJQFtFq6yAF7ShHd25m0avPUL9Jc7y9vYsdF5OiJzPPgq+bloxcMzEp+hum3aaOF82DPajj7ULzYA/a1PFywB0IgiBUnSjJOIi3Ts2KFSv4dNZL3H7vIN5f+hkqlarYceWZrbmopsEeKBVyUWUlCILTE0HGQVR5qXwxZyr3DR3J7Hffp3FQyVVM5ZmtuVjaospKEIRaQgQZB5AkiYbh9dmzZw/t27fHbJUKdQJoGOBma0MRDemCINzMRJCxs7lz55KUlMT8+fOJjIwE4Hxilq0TwKVrbS4FJRFRKhEE4WbmtA3/ixcvJjw8HK1WS2RkJDt37izz+O3btxMZGYlWq6VBgwYsXbq0mnKaT5IkZsyYwdSpU3F1dS302fXtLhqlwqGLigmCIDgTpwwya9euZeLEiUyfPp2oqCi6detG3759iYmJKfH46Oho+vXrR7du3YiKiuLVV19lwoQJfPfdd9WSX0mSePXVV3nrrbeYPXs2M2bMuDZxZT5vnRqD2VKhdhdBEISbgUySJOnGh1WvTp060b59e5YsWWLb17x5cx544AHmzJlT7PhXXnmFH374gZMnT9r2jR8/nsOHD7Nnz54Sr2EwGDAY/ptMMjMzk7p165KRkYGHR8Wqr7788ktGjx7Ne++9x8SJE4t9XnRg5vVtMoIgCLVRZmYmnp6eN3xmOl2bjNFo5ODBg0ydOrXQ/t69e7N79+4Sz9mzZw+9e/cutK9Pnz588cUXmEymErsOz5kzhxkzZtglzw8//DB+fn4MGDCgxM9Fu4sgCLcqp3udTk5OxmKxEBgYWGh/YGAg8fHxJZ4THx9f4vFms5nk5OQSz5k2bRoZGRm2f5cvX650njUaTakBRhAE4VbmdCWZAte3aUB+u0fRfTc6vqT9BTQaDRqNpoq5FARBEMridCUZPz8/FApFsVJLYmJisdJKgaCgoBKPVyqV+Pr6OiyvgiAIQtmcLsio1WoiIyPZsmVLof1btmyha9euJZ7TpUuXYsf//vvvdOjQocT2GEEQBKF6OF2QAZg8eTKff/45y5Yt4+TJk0yaNImYmBjGjx8P5LenjBo1ynb8+PHjuXTpEpMnT+bkyZMsW7aML774gilTptTULQiCIAg4aZvMsGHDSElJYebMmcTFxdGqVSs2b95MWFgYAHFxcYXGzISHh7N582YmTZrExx9/TEhICB988AEPPvhgTd2CIAiCgJOOk6kJ5e3zLQiCIJT/memU1WWCIAjCzUEEGUEQBMFhRJARBEEQHEYEGUEQBMFhRJARBEEQHEYEGUEQBMFhnHKcTE0o6MmdmZlZwzkRBEFwfgXPyhuNghFB5pqsrCwA6tatW8M5EQRBqD2ysrLw9PQs9XMxGPMaq9VKbGws7u7uZc72XJKCBc8uX758SwzkvNXuF8Q93wr3fKvdL1TtniVJIisri5CQEOTy0lteREnmGrlcTp06daqUhoeHxy3zywm33v2CuOdbwa12v1D5ey6rBFNANPwLgiAIDiOCjCAIguAwIsjYgUaj4c0337xlVtq81e4XxD3fCm61+4XquWfR8C8IgiA4jCjJCIIgCA4jgowgCILgMCLICIIgCA4jgowgCILgMCLIlNPixYsJDw9Hq9USGRnJzp07yzx++/btREZGotVqadCgAUuXLq2mnNpHRe53w4YN9OrVC39/fzw8POjSpQu//fZbNebWPir6My6wa9culEolERERjs2gnVX0fg0GA9OnTycsLAyNRkPDhg1ZtmxZNeXWPip6zytXrqRt27bodDqCg4MZPXo0KSkp1ZTbqtmxYwf33XcfISEhyGQyNm7ceMNzHPLckoQbWrNmjaRSqaTPPvtMOnHihPTCCy9Irq6u0qVLl0o8/sKFC5JOp5NeeOEF6cSJE9Jnn30mqVQqaf369dWc88qp6P2+8MIL0ty5c6V9+/ZJZ86ckaZNmyapVCrp33//reacV15F77lAenq61KBBA6l3795S27ZtqyezdlCZ+x04cKDUqVMnacuWLVJ0dLT0zz//SLt27arGXFdNRe95586dklwul95//33pwoUL0s6dO6WWLVtKDzzwQDXnvHI2b94sTZ8+Xfruu+8kQPr+++/LPN5Rzy0RZMqhY8eO0vjx4wvta9asmTR16tQSj3/55ZelZs2aFdr31FNPSZ07d3ZYHu2povdbkhYtWkgzZsywd9YcprL3PGzYMOm1116T3nzzzVoVZCp6v7/88ovk6ekppaSkVEf2HKKi9zx//nypQYMGhfZ98MEHUp06dRyWR0cpT5Bx1HNLVJfdgNFo5ODBg/Tu3bvQ/t69e7N79+4Sz9mzZ0+x4/v06cOBAwcwmUwOy6s9VOZ+i7JarWRlZeHj4+OILNpdZe95+fLlnD9/njfffNPRWbSrytzvDz/8QIcOHZg3bx6hoaE0adKEKVOmkJubWx1ZrrLK3HPXrl25cuUKmzdvRpIkEhISWL9+Pf3796+OLFc7Rz23xASZN5CcnIzFYiEwMLDQ/sDAQOLj40s8Jz4+vsTjzWYzycnJBAcHOyy/VVWZ+y1qwYIF5OTkMHToUEdk0e4qc89nz55l6tSp7Ny5E6Wydv0ZVeZ+L1y4wN9//41Wq+X777/n/+3de1BU5RsH8O+yy9rGTVgEVnAkrrMMd4jiopAsMRGal6gJUjBkxhphgKDxgnipGUyDUkRGGVphBtTJctSamGW4BVpYsGsOkBZCpFFmUiBgCJzfHw47Lgt7wwPi7/nM7B+e877vPs/L8Tx73j1wbt++jXfeeQd37tyZF9/LGJJzSEgIKioq8Prrr+PevXsYHR3FqlWrUFhYOBshzzq2zlt0JaOjyX/+n2EYjY8EmKr9VNsfV/rmO+HEiRPYvXs3Tp06BRsbG7bCY4WuOY+NjSE+Ph579uyBm5vbbIX3yOnzMx4fHweHw0FFRQWCgoIQExODgoICHD9+fN5czQD65dze3o60tDTk5uaipaUFVVVV6OrqwubNm2cj1DnBxnlrfn0EmwPW1tbgcrlqn3Zu3bqlVvUn2NnZTdmex+NBKBSyFuujYEi+E06dOoXk5GR89tlnkEgkbIb5SOmb88DAAH744QfI5XJs2bIFwIOTMMMw4PF4kMlkWLFixazEbghDfsYikQj29vYqf9pdLBaDYRjcuHEDrq6urMY8U4bknJeXh9DQUGRnZwMAvL29YWJigmXLluGDDz54rFckDMHWeYuuZLTg8/kICAhAdXW1yvbq6mqEhIRM2Sc4OFitvUwmQ2BgIIyNjVmL9VEwJF/gwRVMUlISKisr592atb45m5ub48qVK1AoFMrX5s2b4e7uDoVCgeeee262QjeIIT/j0NBQ/P7777h7965y27Vr1x7Jc5hmgyE5Dw0NqT2Mi8vlAtD+yOH5iLXz1oxuG/g/MXHrY2lpKdPe3s6kp6czJiYmTHd3N8MwDLN161Zm/fr1yvYTtwJmZGQw7e3tTGlp6by8hVnXfCsrKxkej8cUFRUxvb29ytc///wzVynoTd+cJ5tvd5fpm+/AwADj4ODAvPrqq0xbWxvT0NDAuLq6Mps2bZqrFPSmb85SqZTh8XjMkSNHmM7OTqapqYkJDAxkgoKC5ioFvQwMDDByuZyRy+UMAKagoICRy+XKW7Zn67xFRUZHRUVFzNKlSxk+n8/4+/szDQ0Nyn2JiYlMeHi4Svv6+nrGz8+P4fP5jKOjI1NcXDzLEc+MPvmGh4czANReiYmJsx/4DOj7M37YfCsyDKN/vh0dHYxEImEEAgHj4ODAZGZmMkNDQ7Mc9czom/OhQ4cYDw8PRiAQMCKRiElISGBu3Lgxy1Ebpq6uTuP/y9k6b9Gf+ieEEMIa+k6GEEIIa6jIEEIIYQ0VGUIIIayhIkMIIYQ1VGQIIYSwhooMIYQQ1lCRIYQQwhoqMoQQQlhDRYYQQghrqMgQQghhDRUZQmZRREQE0tPTDd4/30zO50nLj2hHz5MheouIiICvry8++eSTuQ7lifPFF1889o+DmInJ+dGx9OSjIkNYMTIyAj6fP9dhzDtWVlZzHQKrnvT8iDpaLiN6SUpKQkNDAw4ePAgOhwMOh4Pu7m5ERERgy5YtyMzMhLW1NaKiogAAjo6Oap9SfX19sXv3bgAPHv60f/9+ODk5QSAQwMfHB6dPn9YYg7YxgQefkNPS0vDee+/BysoKdnZ2Kvsn2qSmpiI9PR2WlpawtbXFsWPHMDg4iI0bN8LMzAzOzs74+uuvVfpVVVUhLCwMCxcuhFAoRGxsLDo7O5X7T58+DS8vLwgEAgiFQkgkEgwODk6ZS1VVFSwsLFBeXq6MafLykrY8BgYGkJCQABMTE4hEInz88cdal6UGBwexYcMGmJqaQiQSIT8/X62PLvOsbS4me/g9pjuWysvLIRQK8d9//6n0XbduHTZs2DDt2JM1NTXB2NhYZZyuri5wOBz8+uuvOo9DZoaKDNHLwYMHERwcjJSUFPT29qK3txdLliwBAJSVlYHH4+HChQs4evSoTuPl5ORAKpWiuLgYbW1tyMjIwJtvvomGhoYZx1pWVgYTExM0Nzdj//792Lt3r9qT/8rKymBtbY1Lly4hNTUVb7/9NuLi4hASEoLW1lZER0dj/fr1GBoaUvYZHBxEZmYmvv/+e9TU1MDIyAhr1qzB+Pg4ent78cYbb+Ctt95CR0cH6uvrsXbt2imfpHjy5Em89tprKC8v13jy1JZHZmYmLly4gHPnzqG6uhqNjY1obW3VODfZ2dmoq6vDmTNnIJPJUF9fj5aWFl2nVqe50Ga6YykuLg5jY2M4d+6csu3t27fx5ZdfYuPGjTrHplAoIBaLsWDBApVtCxcuxNKlS/VLlBiMlsuIXiwsLMDn8/H000/Dzs5OZZ+Liwv279+v81iDg4MoKChAbW0tgoODAQBOTk5oamrC0aNHER4ePqNYvb29sWvXLgCAq6srDh8+jJqaGuVVFgD4+PggJycHALBt2zbs27cP1tbWSElJAQDk5uaiuLgYP/74I55//nkADz5RP6y0tBQ2NjZob2/HyMgIRkdHsXbtWuWJzMvLSy22I0eOYPv27Th79ixeeOEFg/MYGBhAWVkZKisrERkZCQCQSqVYvHjxtOPdvXsXpaWlKC8vV85FWVmZQY9R1jQXnp6eGvtOdywJBALEx8dDKpUiLi4OAFBRUQEHBwdEREToHNvly5fh5+ensk2hUMDHx0fnMcjM0ZUMeWQCAwP1at/e3o579+4hKioKpqamyld5ebnGJRddeXt7q/xbJBLh1q1b07bhcrkQCoUqRcHW1hYAVPp1dnYiPj4eTk5OMDc3xzPPPAMA6OnpgY+PDyIjI+Hl5YW4uDiUlJSgr69P5T0///xzpKenQyaTaS0w2vK4fv067t+/j6CgIOV+CwsLuLu7TzteZ2cnRkZGlIUdePBdiaY+msaabi5mIiUlBTKZDDdv3gTwoHAmJSWBw+HoPIZCoYCvr6/KNrlcTkVmltGVDHlkTExM1LYZGRmpLRXdv38fAJRLKl999RXs7e1V2jy8xKHPmA+bfJcWh8NRW8aZqs3D2yZOag/3W7lyJZYsWYKSkhIsXrwY4+Pj8PT0xMjICLhcLqqrq3Hx4kXIZDIUFhZix44daG5uVp6AfX190draCqlUimeffVbriVNTHhPzMHkMTQ+81fVhuLrMs6a5mAk/Pz/4+PigvLwc0dHRuHLlCs6fP69z/7GxMbS1taldybS2tmLNmjUzio3oh65kiN74fD7GxsZ0arto0SL09vYq/93f34+uri4AgIeHBxYsWICenh64uLiovCa+59F3TLb9/fff6OjoQE5ODiIjIyEWi9WuVDgcDkJDQ7Fnzx7I5XLw+XycOXNGud/Z2Rl1dXU4e/YsUlNTZxSPs7MzjI2NcenSJeW2/v5+/Pzzz9P2cXFxgbGxMb777jvltr6+Ply7dk2lnbZ51mUutNF0LG3atAlSqRSffvopJBKJxmNisqtXr2J4eFhl2fDbb7/FzZs3p7yS+e2337R+j0UMQ1cyRG+Ojo5obm5Gd3c3TE1NNd6WumLFChw/fhwrV66EpaUldu7cCS6XCwAwMzNDVlYWMjIyMD4+jrCwMPT39+PixYswNTVFYmKi3mOyzdLSEkKhEMeOHYNIJEJPTw+2bt2q3N/c3Iyamhq8+OKLsLGxQXNzM/766y+IxWKVcdzc3FBXV4eIiAjweDyDf0/EzMwMiYmJyM7OhpWVFWxsbLBr1y4YGRlNe4VkamqK5ORkZGdnQygUwtbWFjt27ICRkepnTm3zrG0udDHVsTQRR0JCArKyslBSUqK8+05XCoUCAFBYWIi0tDT88ssvSEtLAwC1u9YaGxuRm5uL4eFh7Ny5Ey+//LJe70U0oysZoresrCxwuVx4eHhg0aJFGtfft23bhuXLlyM2NhYxMTFYvXo1nJ2dlfvff/995ObmIi8vD2KxGNHR0Th//rxyacmQMdlkZGSEkydPoqWlBZ6ensjIyMCBAweU+83NzfHNN98gJiYGbm5uyMnJQX5+Pl566SW1sdzd3VFbW4sTJ07g3XffNTimgoICBAcHIzY2FhKJBKGhoRCLxXjqqaem7XPgwAEsX74cq1atgkQiQVhYGAICAlTaaJtnbXOhC03Hkrm5OdatWwdTU1OsXr1ar3EVCgWioqLQ1dUFT09PbN++Hfv27YO5uTmKiopU2i5btgyjo6MYHh6mAsMCDqPrAi0hZF4YHByEvb098vPzkZycrHO/x/G376OioiAWi3Ho0CG9+kVHR8Pf3x95eXla246NjeH69esQCASwtLSc8rtFYjhaLiNknpPL5fjpp58QFBSEf//9F3v37gUAvPLKK3McmeHu3LkDmUyG2tpaHD58WO/+ly9fRlJSkk5tuVwuXF1d9X4PohsqMoQ8AT766CNcvXoVfD4fAQEBaGxshLW19VyHZTB/f3/09fXhww8/1PvW6j/++AN//vmn2q3fZG7QchkhhBDW0Bf/hBBCWENFhhBCCGuoyBBCCGENFRlCCCGsoSJDCCGENVRkCCGEsIaKDCGEENZQkSGEEMIaKjKEEEJYQ0WGEEIIa/4HzuqdTc/4qEkAAAAASUVORK5CYII=",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# predicted quality vs. true quality mu_star\n",
+ "plt.figure(figsize=(4.2, 4.2))\n",
+ "plt.scatter(qs, pred, s=6, alpha=0.25, label=\"held-out states\")\n",
+ "bins = np.linspace(0, 1, 11)\n",
+ "idx = np.digitize(qs, bins) - 1\n",
+ "bx = [qs[idx == b].mean() for b in range(10) if (idx == b).any()]\n",
+ "by = [pred[idx == b].mean() for b in range(10) if (idx == b).any()]\n",
+ "plt.plot(bx, by, \"o-\", color=\"crimson\", label=\"binned mean\")\n",
+ "plt.plot([0, 1], [0, 1], \"k--\", lw=1, label=\"ideal\")\n",
+ "plt.xlabel(r\"true unmasking quality $\\mu_\\star$\")\n",
+ "plt.ylabel(r\"predicted $\\mu_\\phi$\")\n",
+ "plt.title(\"UQL recovers the unmasking quality\")\n",
+ "plt.legend(fontsize=8); plt.tight_layout(); plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Insertion quality $\\nu_\\star$\n",
+ "\n",
+ "When the model inserts a mask into the gap between surviving tokens $\\ell-1$ and $\\ell$,\n",
+ "its **insertion quality** is the probability that the mask decodes to *some* token that\n",
+ "genuinely belongs in that gap, $\\mathcal{S}_\\ell=\\{\\boldsymbol{x}_1^i: s_t[\\ell-1] empty gap\n",
+ "print(f\" insert into an EMPTY gap (nothing belongs): nu* = {nu_empty:.3f} gap={gap_empty}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3.1 Training the insertion-quality predictor $\\nu_\\phi$ (IQL)\n",
+ "\n",
+ "Similarly, $\\nu_\\phi$ is trained with the **Insertion Quality Loss**, which is a BCE against the (soft) target $\\nu_\\star$, taking only the post-insertion sequence $\\boldsymbol{y}$ as input:\n",
+ "\n",
+ "$$\\mathcal{L}_{\\text{IQL}}(\\phi)=\\mathbb{E}\\Big[\\textstyle\\sum_{i\\in\\mathcal{I}}\\text{BCE}\\big(\\nu_\\star^i(\\boldsymbol{y}),\\,\\nu_\\phi^i(\\boldsymbol{y})\\big)\\Big],$$\n",
+ "\n",
+ "whose unique minimizer is $\\nu_\\star$. The predictor learns to assign near-zero quality to inserts whose neighbours are already adjacent in the target."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " step 500 IQL running avg = 0.4728 (irreducible floor = 0.2009)\n",
+ " step 1000 IQL running avg = 0.4763 (irreducible floor = 0.2009)\n",
+ " step 1500 IQL running avg = 0.4646 (irreducible floor = 0.2009)\n",
+ " step 2000 IQL running avg = 0.4717 (irreducible floor = 0.2009)\n",
+ "\n",
+ "running loss 0.4717 sits at the floor 0.2009 -> converged\n",
+ "corr(nu_phi, nu_star) = 0.644\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZkAAAGZCAYAAABbpUzOAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAhuNJREFUeJzt3Xd4U2UbwOFfmo50UzqgFChlyJANsveUoTIEZMoUBGU7EGUIioAgyMaPqYAoKioiUFbZKFOgiAJltLR0QfdM3u+P0kjozmjS9r2vi0vPyRnPOUnz5LxTIYQQSJIkSZIJWJk7AEmSJKn4kklGkiRJMhmZZCRJkiSTkUlGkiRJMhmZZCRJkiSTkUlGkiRJMhmZZCRJkiSTkUlGkiRJMhmZZCRJkiSTkUnGTDZv3oxCoeDcuXNZXtu3bx89evTA09MTOzs7KlasyIgRI7hx40aWbefMmYNCoSAyMrIwwi7REhMTmTNnDkePHs3yWmG9D5mfmzt37pj0PIYIDAxkzpw52cY4fPhwKlWqVOgxmVu7du1o166dzjqFQsGcOXO0y7ndt6JMJhkL8+6779KtWzc0Gg2rV6/G39+fWbNmcfbsWRo0aMCePXvMHWKJlZiYyNy5c7NNMoWlR48enD59Gm9vb7PFkJfAwEDmzp2b7ZflRx99xE8//VT4QVmg06dPM3r0aO1ybvetKLM2dwDSf3bs2MHixYt58803Wb16tXZ9mzZtGDhwIG3btmXQoEFcvXqVihUrGv38arWa9PR07OzsjH5sSyCEIDk5GXt7e3OHojdPT088PT3NHUa20tLSUCgUuW5TpUqVQorG8jVr1szcIRQK+SRjQT755BPc3Nz4/PPPs7zm6OjIihUriIuLY9myZQaf686dOygUChYtWsT8+fPx8/PDzs6OI0eOAHDu3DlefvllSpcujUqlokGDBnz33XdZjhMSEsIbb7xBhQoVsLW1pVy5crz66qs8fPhQu829e/cYMmQIXl5e2NnZUbNmTZYsWYJGowEyvpy8vLwYOnRoluM/fvwYe3t7pk6dql0XGxvL9OnT8fPzw9bWFh8fHyZPnkxCQoLOvgqFgrfeeou1a9dSs2ZN7Ozs2LJlCwBr1qyhXr16ODk54ezsTI0aNfjggw9yvV+ZX+5z585FoVCgUCgYPny4znYPHz5k4MCBuLq6UqZMGUaOHElMTIzONkIIVq9eTf369bG3t8fNzY1XX32V27dv53j+TNkVl7Vr147atWvz559/0rp1axwcHKhcuTKfffaZ9h4DaDQa5s+fT/Xq1bG3t6dUqVLUrVuX5cuX65zj33//ZdCgQTrv16pVq3S2OXr0KAqFgq+//ppp06bh4+ODnZ0d//vf/+jXrx8A7du3196nzZs3A9kXlyUnJzNjxgyd93PChAk8fvxYZ7tKlSrRs2dP9u3bR8OGDbG3t6dGjRps3Lgxz/sG8ODBA/r374+zszOurq4MGDCAM2fO6MSXeT+fLdrKKfa5c+fStGlTSpcujYuLCw0bNmTDhg3kZ9zhp4vLNm/enON9mzdvHtbW1ty/fz/LMUaOHIm7uzvJycn5ugdmISSz2LRpkwDEn3/+KYQQ4sGDBwIQAwYMyHU/Ly8vUatWLe3y7NmzBSAiIiIKdP6goCABCB8fH9G+fXuxa9cuceDAAREUFCQOHz4sbG1tRevWrcXOnTvFvn37xPDhwwUgNm3apD1GcHCw8Pb2Fh4eHmLp0qXi4MGDYufOnWLkyJHi+vXrQgghwsPDhY+Pj/D09BRr164V+/btE2+99ZYAxJtvvqk91pQpU4S9vb2IiYnRiXP16tUCEH/99ZcQQoiEhARRv359nXMuX75cuLq6ig4dOgiNRqPdN/P66tatK7Zv3y4OHz4srl69Knbs2CEA8fbbb4sDBw6IgwcPirVr14qJEyfmeL+Sk5PFvn37BCBGjRolTp8+LU6fPi1u3ryp8z5Ur15dzJo1S/j7+4ulS5cKOzs7MWLECJ1jjRkzRtjY2Ihp06aJffv2ie3bt4saNWqIMmXKiLCwsFzft8zPTVBQkHZd27Zthbu7u6hWrZpYu3at8Pf3F+PHjxeA2LJli3a7BQsWCKVSKWbPni0OHTok9u3bJ5YtWybmzJmj3ebatWvC1dVV1KlTR2zdulUcOHBATJs2TVhZWelsd+TIEe39ffXVV8Uvv/wi9uzZI8LCwsSnn34qALFq1SrtfQoPDxdCCPH6668LX19f7XE0Go3o2rWrsLa2Fh999JE4cOCA+Pzzz4Wjo6No0KCBSE5O1m7r6+srypcvL2rVqiW2bt0q9u/fL/r16ycAERAQkOt9S0xMFDVr1hSurq5ixYoVYv/+/WLixImiYsWKWT7Xbdu2FW3bts1yjGdjF0KI4cOHiw0bNgh/f3/h7+8v5s2bJ+zt7cXcuXN1tsvumICYPXu2ECLj7ySn+/bw4UNhZ2cnZs6cqbN/VFSUsLe3F++8806u125uMsmYybNJ5syZMwIQ77//fq77NW3aVDg6OmqXDU0yVapUEampqTqv1ahRQzRo0ECkpaXprO/Zs6fw9vYWarVaCCHEyJEjhY2NjQgMDMzxPO+//74AxNmzZ3XWv/nmm0KhUIgbN24IIYT466+/BCDWr1+vs12TJk1Eo0aNtMsLFiwQVlZW2vuWadeuXQIQe/fu1a4DhKurq4iOjtbZ9q233hKlSpXKMeacRERE6HwxPC3zfVi0aJHO+vHjxwuVSqVNfqdPnxaAWLJkic529+/fF/b29uLdd9/NNYackkx297hWrVqia9eu2uWePXuK+vXr53r8rl27ivLly2dJ9m+99ZZQqVTae5mZZNq0aZPlGN9//70AxJEjR7K89uwXdWbifva+7dy5M8vnwdfXV6hUKnH37l3tuqSkJFG6dGkxduzYXK9rzZo1AhA///yzzvoxY8YYlGSeplarRVpamvj444+Fu7u7zg+evJKMEHnfNy8vL5GSkqJdt3DhQmFlZaXzWbBEsrisiBFC5FnuXRAvv/wyNjY22uWbN2/y999/M3jwYADS09O1/7p3705oaKi2ldvvv/9O+/btqVmzZo7HP3z4MLVq1aJJkyY664cPH44QgsOHDwNQp04dGjVqxKZNm7TbXL9+nT/++IORI0dq1+3Zs4fatWtTv359ndi6du2KQqHIUinfoUMH3NzcdNY1adKEx48fM3DgQH7++Wejtgh7+eWXdZbr1q1LcnIy4eHh2vgVCgVDhgzRib9s2bLUq1dP70YFZcuWzXKP69aty927d7XLTZo04fLly4wfP579+/cTGxurs31ycjKHDh2id+/eODg4ZHnvk5OTOXPmjM4+ffv21SveTJnv/7PFjv369cPR0ZFDhw7prK9fv75OfaRKpeK5557Tuc7sHDlyBGdn5yzvz6BBgwyIPiP+Tp064erqilKpxMbGhlmzZhEVFaV9z41h0qRJhIeH8/333wMZRZ9r1qyhR48eFt9aTyYZC5H5hxMUFJTrdnfv3qVChQpGO++zrZQy61KmT5+OjY2Nzr/x48cDaL+UIyIiKF++fK7Hj4qKyrYlVLly5bSvZxo5ciSnT5/m77//BmDTpk3Y2dkxcOBAnfj++uuvLLE5OzsjhMiSMLI799ChQ9m4cSN3796lb9++eHl50bRpU/z9/XO9lvxwd3fXWc5sRJGUlKSNXwhBmTJlslzDmTNn9E54z54389yZ5wWYMWMGn3/+OWfOnKFbt264u7vTsWNHbTP6qKgo0tPTWbFiRZbYunfvDpCv+1sQUVFRWFtbZ2nMoFAoKFu2rM7nI7/XmdN5ypQpk2V92bJl9Yg6wx9//EGXLl0A+Oqrrzh58iR//vknM2fOBMgzpoJo0KABrVu31taN7dmzhzt37vDWW28Z7RymIluXWQhvb29q167NgQMHSExMxMHBIcs2p0+f5uHDh7z66qtGO++zT0UeHh5AxhdSnz59st2nevXqQEZLp+Dg4FyP7+7uTmhoaJb1Dx480DkfwMCBA5k6dSqbN2/mk08+4euvv6ZXr146TyIeHh7Y29vnWNn79PGyu75MI0aMYMSIESQkJHDs2DFmz55Nz549+eeff/D19c31mgzh4eGBQqHg+PHj2bbiM2XLPmtra6ZOncrUqVN5/PgxBw8e5IMPPqBr167cv38fNzc3lEolQ4cOZcKECdkew8/PT2fZ0Kdqd3d30tPTiYiI0Ek0QgjCwsJ44YUXDDr+0+f5448/sqwPCwvLsk6lUmVprAFZE+y3336LjY0Ne/bsQaVSadfv3r3b8ICzMXHiRPr168eFCxdYuXIlzz33HJ07dzbJuYxJPslYkJkzZ/Lo0SOmT5+e5bWEhAQmTpyIra2t9onCFKpXr061atW4fPkyjRs3zvafs7MzAN26dePIkSPZdhLN1LFjRwIDA7lw4YLO+q1bt6JQKGjfvr12nZubG7169WLr1q3s2bOHsLAwnaIygJ49e3Lr1i3c3d2zja2gRQeOjo5069aNmTNnkpqayrVr13Lc9tmnEn307NkTIQQhISHZxl+nTh29j10QpUqV4tVXX2XChAlER0dz584dHBwcaN++PRcvXqRu3brZxpfdk8SzCnKfOnbsCMA333yjs/6HH34gISFB+7qh2rdvT1xcHL/88ovO+u3bt2fZtlKlSvzzzz+kpKRo10VFRXHq1Cmd7RQKBdbW1iiVSu26pKQkvv76a71izOu+9e7dm4oVKzJt2jQOHjzI+PHjjVp0biryScaCvPbaa5w/f57PP/+cO3fuMHLkSMqUKcONGzf44osv+Pvvv9mwYQO1atXKsu+vv/6q/fJ/mj5PPevWraNbt2507dqV4cOH4+PjQ3R0NNevX+fChQvacuGPP/6Y33//nTZt2vDBBx9Qp04dHj9+zL59+5g6dSo1atRgypQpbN26lR49evDxxx/j6+vLb7/9xurVq3nzzTd57rnndM49cuRIdu7cyVtvvUX58uXp1KmTzuuTJ0/mhx9+oE2bNkyZMoW6deui0Wi4d+8eBw4cYNq0aTRt2jTX6xszZgz29va0bNkSb29vwsLCWLBgAa6urrn+cnZ2dsbX15eff/6Zjh07Urp0aTw8PAqU2Fq2bMkbb7zBiBEjOHfuHG3atMHR0ZHQ0FBOnDhBnTp1ePPNN/N9vIJ46aWXqF27No0bN8bT05O7d++ybNkyfH19qVatGgDLly+nVatWtG7dmjfffJNKlSoRFxfHzZs3+fXXX7V1KLmpXbs2AOvXr8fZ2RmVSoWfn1+2Capz58507dqV9957j9jYWFq2bMlff/3F7NmzadCgQbbN2vUxbNgwvvjiC4YNG8Ynn3xCtWrV2Lt3L/v378+y7dChQ1m3bh1DhgxhzJgxREVFsWjRIlxcXHS269GjB0uXLmXQoEG88cYbREVF8fnnn+v9NJrXfVMqlUyYMIH33nsPR0fHLPVYFsucrQ5Ksmdblz3tt99+E926dROlS5cWCoVCAMLLy0ucOXMmy7aZrZpy+peTzNZlixcvzvb1y5cvi/79+wsvLy9hY2MjypYtKzp06CDWrl2rs939+/fFyJEjRdmyZYWNjY0oV66c6N+/v3j48KF2m7t374pBgwYJd3d3YWNjI6pXry4WL16sbaX2NLVaLSpUqCCALE02M8XHx4sPP/xQVK9eXdja2mqb3E6ZMkWnCTAgJkyYkGX/LVu2iPbt24syZcoIW1tbbcyZzaRzc/DgQdGgQQNhZ2cnAPH6668LIXJu5ZddazAhhNi4caO2paC9vb2oUqWKGDZsmDh37lyu58+pddnzzz+fZdtnW0MtWbJEtGjRQnh4eAhbW1tRsWJFMWrUKHHnzh2d/YKCgsTIkSOFj4+PsLGxEZ6enqJFixZi/vz52m0yW5d9//332ca5bNky4efnJ5RKpU7rrexaaCUlJYn33ntP+Pr6ChsbG+Ht7S3efPNN8ejRI53tfH19RY8ePbKcK6fWYM8KDg4Wffv2FU5OTsLZ2Vn07dtXnDp1KkvrMiEyPiM1a9YUKpVK1KpVS+zcuTPb2Ddu3CiqV68u7OzsROXKlcWCBQvEhg0bsn2P8mpdJkTO9y3TnTt3BCDGjRuX5/VaCoUQ+eg1JJnVxx9/zOzZs1m1apVJi8okqaS5c+cOfn5+bNq0qUg8GaxYsYKJEydy9epVnn/+eXOHky+yuKwImDVrFqGhobz11ls4Ojry+uuvmzskSZIK0cWLFwkKCuLjjz/mlVdeKTIJBmSSKTLWrFnDmjVrzB2GJElm0Lt3b8LCwmjdujVr1641dzgFIovLJEmSJJORTZglSZIkk5FJRpIkSTIZWSfzhEaj4cGDBzg7OxeJDk6SJEmFSQhBXFwc5cqVw8oq/88nMsk88eDBA6OOCSZJklQc3b9/P88xC58mk8wTmb3l79+/n6VnryRJUkkXGxtLhQoVsh1ZJDcyyTyRWUTm4uIik4wkSVIOClqdICv+JUmSJJOxuCRz7NgxXnrpJcqVK4dCocjXsNkBAQE0atQIlUpF5cqVi1xnJUmSpOLK4orLEhISqFevHiNGjMjXrHtBQUF0796dMWPG8M0333Dy5EnGjx+Pp6enwbP2ZVKr1aSlpRnlWFLhUiqVWFtbyxaDkmQmFpdkunXrRrdu3fK9/dq1a6lYsSLLli0DoGbNmpw7d47PP//cKEkmPj6e4OBg5MAIRZeDgwPe3t7Y2tqaOxRJKnEsLskU1OnTp7VToGbq2rUrGzZsIC0tTWf++qelpKToTEr07HznkPEEExwcjIODA56envLXcBEjhCA1NZWIiAiCgoKoVq1agdr3S5JkuCKfZMLCwrLM3V2mTBnS09OJjIzMcQ7yBQsWMHfu3FyPnZaWhhACT09P7O3tjRazVHjs7e2xsbHh7t27pKam6kyTK0mS6RWLn3XPPmFkFm3l9uQxY8YMYmJitP/u37+f7+Pnh1CrSTp5kbgfD5J08iJCrS7wMSTjkE8vkmQ+Rf5JpmzZsoSFhemsCw8Px9raOtf5yO3s7PSeJjUv8XsCiJy5HPWDCO06ZTlPPD6ZhFPPtiY5pyRJkiUq8j/xmjdvjr+/v866AwcO0Lhx4xzrY0wpfk8AD0d+qJNgANShETwc+SHxewIMPodCoSA+Pj7b1+rXr09SUpLB58ivdu3asWfPnkI7nyRJRYvFPcnEx8dz8+ZN7XJQUBCXLl2idOnSVKxYkRkzZhASEsLWrVsBGDduHCtXrmTq1KmMGTOG06dPs2HDBnbs2GH02IQQiMTknF9Xq4n8YBlk1xBNAAqI/GA59m0aoVAqsz2GwkFlUAODS5cu6b2vJEmSsVlckjl37hzt27fXLk+dOhWA119/nc2bNxMaGsq9e/e0r/v5+bF3716mTJnCqlWrKFeuHF9++aXR+sg8TSQmE1SpS94b5niAjCeaO1VybqLtd+cACse8Gxl8/vnn+Pv7ExERwdy5cxk4cCCQ8ZQTFxeHk5MTlSpVYsSIEezfv5/Q0FBGjRrFhx9+CGQ8gTRt2pRTp07x4MEDOnfurO3EGhcXx9SpU7l8+TLJycm0aNGCFStWYGNjQ2BgICNGjCAtLY2aNWuSnJx90j169CiTJ0+mWbNmnDx5EhsbG7Zu3cq8efO4cuUKPj4+/PTTTzg5OZGWlsZHH33E4cOHSU1NpUaNGqxdu5ZSpUqxfft2li9fTmpqKkIIPv30U7p37w6Q6/VJkmQZLK64rF27dhlPDM/827x5MwCbN2/m6NGjOvu0bduWCxcukJKSQlBQEOPGjSv8wAuZQqHg5MmT7Nu3j7fffjvHhguPHz/m1KlT/PHHHyxevJiQkBDta7du3eLo0aNcvXqV/fv3c/r0aQCmTZtGmzZt+OOPP7h8+TLp6emsXLkSgKFDhzJ+/HguXLjA22+/zZ9//pljjNeuXWPcuHFcuXKF5s2b8+KLL7JkyRICAwOxsbFh+/btACxevBgnJyf++OMPLl26xPPPP8/s2bOBjOboZ86c4eLFi+zevZvRo0frdIzN7fokSTI/i3uSsWQKBxV+dw7k+HrS6cuEDXwnz+OU3bEY++b1cjxHfowePRqAypUr06pVK44fP86gQYOybDd48GAAPD09qVy5MkFBQfj4+ADw2muvoVQqsbe3p379+ty6dYvmzZuze/duzpw5w5IlSzKuKykJW1tbYmNjuXr1KkOHDgWgWbNm1KlTJ8cYq1evTv369QFo2LAhd+/e1Q4R3qhRI27fvg3A7t27iY2NZdeuXQCkpqZSpUoVIKO4dPDgwQQHB2NtbU1kZCR3796latWqeV6fJEn/iYiIwNPTs9DPK5NMASgUilyLshzav4CynCfq0Ijs62UUoCznhUP7F3KskzEktuw83S9EqVSSnp6e52tCCHbv3k3lypV1jhUbG1ug+qJnj//scmYDBSEEq1evpkOHDlmO8dprr/H555/Tq1cvAEqXLq1TRJfb9UmSlGHPnj0MGDCAvXv30rZt4bZwtbjisqJMoVTi8cmkJwvPvpjxH4/5E42SYDZu3AjAnTt3OHHiBK1atTL4mJlefvllPvvsM+0X9qNHj7h58yYuLi7Url2bbdu2AfDHH39w5coVo5xv6dKlJCYmApCYmMi1a9e0565UqRIA33zzDY8ePTL4fJJUkuzevZs+ffrQtWtXmjdvXujnl0nGyJx6tqXMxvkovXUfS5XlvCizcb7R+snY2dnRsmVLunTpwooVK4w6q+eyZcuwtramfv361K1bl06dOnHnzh0Atm7dysqVK2nYsCHr16+nadOmBp/v/fffp379+jRt2pS6devSrFkzbSu55cuX07t3b1q1asXly5epWLGiweeTpJJi165d9OvXj169erFz506zjN+nEHLkRyCjKMjV1ZWYmBjtpGXJyckEBQXh5+dX4OFIhFpN8pm/SH8YhXUZd1TN6hq9iEzKH0PeR0kqqtLS0qhfvz7169dny5YtWFsbVjuS3Xdkfsg6GRNRKJXYt2xg7jAkSSqBMgcHPnr0KKVLl0Zpxh+4srhMkiSpGNm4cSONGjXi8ePHeHp6mjXBgEwykiRJxca6desYNWoULVu2LFCRlinJJCNJklQMrFy5knHjxjFx4kRWr15tMaOPW0YUkiRJkt7+/vtvJk2axLRp01i2bJlFTbAoK/4lSZKKuBo1anD27FkaNWpkUQkG5JOMJElSkTV//nzmzJkDQOPGjS0uwYBMMkVSbnPGVKpUiatXr+p97KNHj9K4cWO995ckyfSEEMyaNYuPPvrI7K3H8iKLy4ogOWeMJJVcQgg++OADPvvsMz777DPee+89c4eUK5lk9BAaGkpoaKjOOjc3N/z8/EhOTiYwMDDLPg0bNgTgxo0bJCQk6LxWqVIlSpcune/zPz1nzPHjxxk/fjz29vY0adKEpwdw+Pfff5k8eTLh4eGkpqYyduxYxo8fD8CQIUP4+++/SU1NpWLFimzcuBEvL698xyBJknn873//47PPPmPp0qVMmTLF3OHkTUhCCCFiYmIEIGJiYrTrkpKSRGBgoEhKStLZdvbs2YKMcZa1/wYPHiyEEOLff//N8trTt7lZs2ZZXvv6668LFCsg4uLiRHJysihXrpw4cuSIEEKInTt3CkBcuXJFpKeni8aNG4vr168LIYRISEgQderUEefPnxdCCBEREaE93oIFC8SECROEEEIcOXJENGrUqEDxWLqc3kdJKooSEhLEDz/8UOjnze47Mj/kk4wexo4dy8svv6yzzs3NDYDy5ctz/vz5HPfdvHlztk8y+rhx4wYODg60a9cOgP79+/PGG29oX7t27Rqvvfaadvu4uDgCAwNp2LAh27Zt4+uvvyYlJYWkpCTKli2rVwySJJmeRqNhxowZDB06lNq1a9OnTx9zh5RvMsnowdvbG29v72xfU6lU2qKx7FSvXt1ocYhcxjYVQuDh4ZFt/c2JEydYuXIlp06dwtPTk19++YWPP/7YaHFJkmQ8arWaN954g02bNlGnTh1q165t7pAKRLYuK8Jq1KhBUlISx44dAzKG9Y6JiQEykpmDgwNbt27Vbn/z5k2io6N59OgRLi4ulC5dmtTUVNatW2eW+CVJyp1arWbEiBFs3ryZrVu3MmTIEHOHVGAyyRRhdnZ27NixgwkTJtCkSRP++OMP7Xwr1tbW/Prrr3z33XfUrVuX559/ntGjR5OUlES3bt2oWrUqNWrUoGvXrtopkiVJsixvvvkm27dvZ9u2bUUywYCcT0bL2PPJSJZDvo9SUXXixAkePnxI3759zR2K3vPJyCcZSZIkC5KSksKSJUtIT0+nVatWFpFgDCGTjCRJkoVITk6mb9++zJw5s9h0upaty/JBligWbRqNxtwhSFKekpKS6N27NwEBAfzyyy/FZngnmWRyYWNjg0KhICIiAk9PT4scfE7KmRCC1NRUIiIisLKywtbW1twhSVK2kpOT6dmzJ2fOnOG3336jQ4cO5g7JaGSSyYVSqaR8+fIEBwdz584dc4cj6cnBwYGKFStazCROkvQsOzs76tevz+zZs2nTpo25wzEq2brsidxaTqjVatLS0swUmWQIpVKJtbW1fAqVLFJsbCwXLlzQjtphyfRtXSafZPJBqVRa/HDakiQVLY8fP+bFF1/kzp073Lp1C0dHR3OHZBIyyUiSJBWy6OhounTpwu3btzl48GCxTTAgk4wkSVKhioyMpHPnzgQHB3PkyBHq1atn7pBMSiYZSZKkQpSQkICNjQ1HjhwpcoNd6kMmGUmSpEIQFhaGra0tvr6+nD17tsQ0RpFtOiVJkkwsJCSEtm3bMnLkSIASk2BAJhlJkiSTunfvHm3btiU5OZklS5aYO5xCJ4vLJEmSTOTOnTu0b98egICAAL1nwS3K5JOMJEmSiQQEBGBtbc2xY8dKZIIB2eNfS9/erJIkSc96/PgxpUqVAiAxMREHBwfzBmQEcj4ZSZIkCxAYGEjNmjXZtGkTQLFIMIaQSUaSJMlIrly5Qrt27fD09KRHjx7mDsciyCQjSZJkBJcuXaJ9+/b4+Phw+PBhvLy8zB2SRZBJRpIkyQhmzZqFn58fhw4dwsPDw9zhWAzZhFmSJMkAarUapVLJN998g0aj0Vb4Sxnkk4wkSZKeTpw4wfPPP8/t27dxcXGRCSYbMslIkiTpISAggBdffJGyZcvK+pdcyCQjSZJUQIcOHaJbt240a9aMvXv34uTkZO6QLJZMMpIkSQUQGxtLv379aNOmDb/++muJ7weTF1nxL0mSVAAuLi7s37+fOnXqoFKpzB2OxZNPMpIkSfmwe/du3njjDTQaDS+88IJMMPkkk4wkSVIedu3aRb9+/Xj8+DFqtdrc4RQpFptkVq9ejZ+fHyqVikaNGnH8+PFct9+2bRv16tXDwcEBb29vRowYQVRUVCFFK0lScbVjxw5ee+01+vfvz/bt27GxsTF3SEWKRSaZnTt3MnnyZGbOnMnFixdp3bo13bp14969e9luf+LECYYNG8aoUaO4du0a33//PX/++SejR48u5MglSSpOTp48yZAhQxg8eDBbt27F2lpWYxeURQ7137RpUxo2bMiaNWu062rWrEmvXr1YsGBBlu0///xz1qxZw61bt7TrVqxYwaJFi7h//36+zimH+pck6VlqtZqtW7fy+uuvY2Vlkb/JC02xGeo/NTWV8+fP06VLF531Xbp04dSpU9nu06JFC4KDg9m7dy9CCB4+fMiuXbtyHQU1JSWF2NhYnX+SJEkA69ev59ChQyiVSkaMGFHiE4whLO7ORUZGolarKVOmjM76MmXKEBYWlu0+LVq0YNu2bQwYMABbW1vKli1LqVKlWLFiRY7nWbBgAa6urtp/FSpUMOp1SJJUNK1YsYKxY8eyf/9+c4dSLFhcksmkUCh0loUQWdZlCgwMZOLEicyaNYvz58+zb98+goKCGDduXI7HnzFjBjExMdp/+S1WkySp+Fq6dCkTJ05k2rRpLFy40NzhFAsWV4vl4eGBUqnM8tQSHh6e5ekm04IFC2jZsiXvvPMOAHXr1sXR0ZHWrVszf/58vL29s+xjZ2eHnZ2d8S9AkqQiadWqVUybNo0ZM2bwySef5PijVioYi3uSsbW1pVGjRvj7++us9/f3p0WLFtnuk5iYmKXMVKlUAhlPQJIkSXnp1KkTixcvlgnGyCzuSQZg6tSpDB06lMaNG9O8eXPWr1/PvXv3tMVfM2bMICQkhK1btwLw0ksvMWbMGNasWUPXrl0JDQ1l8uTJNGnShHLlypnzUiRJsmBCCDZs2MCAAQOoXr061atXN3dIxY5FJpkBAwYQFRXFxx9/TGhoKLVr12bv3r34+voCEBoaqtNnZvjw4cTFxbFy5UqmTZtGqVKl6NChgyxTlSQpR0IIPvjgAz777DMcHBwYNGiQuUMqliyyn4w5yH4yklRyCCGYPn06S5cuZenSpUyZMsXcIVk8fb8jLfJJRpIkyVSEEEyePJkvv/ySlStXMmHCBHOHVKzJJCNJUomiUCjw8fFh3bp1vPHGG+YOp9iTxWVP6PsoeOafUEZvuUC8GlRWMK5DZRr5uuPppKKKlxM2SotrwCdJJZJarSYgIIAOHTqYO5RCs/PMbd7bfV27PKldZZpV88TNwbbA30/FZliZombM1owEA5CsgS8P3uav+7Fcuv+YW+Hx5g1OkiQgI8GMGDGCLl26cPPmTXOHU2ieTjAAy4/eJiwmpVC/n2SSMVB8esZ/M1vVawA7ayvsrJU8Skw1V1iSJD2Rnp7OkCFD2L59O9988w1Vq1Y1d0hmVdZVVajfT7JOxkBO1hCXDplljlZASrqGlHQ1bg625gxNkkq8tLQ0Bg0axO7du/n222959dVXzR2S2YXFJBfq95N8kjHQV8Ma4pQxuAAqK5jYqTJ1K7hQv0Ipqng5mTc4SSrhEhMTCQ4OZteuXSUywSzsVVNneVK7ypR1tSvU7ydZ8f+E7CcjScVHcnIykZGRlC9fHo1GI4fqNwJZ8S9JkgQkJSXxyiuv0KVLF9LT02WCMTNZJ2OgxNR0Am6EE/IoCR83e9pW98LBVt5WSTKHhIQEXn75Zc6cOcOvv/5aoqdLPhoYwhtbL5FKxtPE3Fdq8VoT30LvViFTvIECboRz8Ho4QZGJ+AeGE3Aj3NwhSVKJFB8fT/fu3Tl79iy///57ieoPk53MBAMZrV4/+jnQLN0qZJIxUMijJFTWSip7OmJvoyTkUZK5Q5KkEunChQsEBgZy4MAB2rRpY+5wzC67Bsrm6FZRcp8ljcTHzZ7A0DhuRySQlKbGx83e3CFJUokSHx+Po6Mjbdq04fbt2zg7O5s7JItgS9ZEY45uFTLJGKhtdS8AnTqZoiBNreFWeDyPElMLPMSEIftKkjFFR0fTpUsXOnfuzIIFC7QJxpC60r9Dopmw/QLh8Wl4OdmwalBDaviUNuVlmMT6YfWz1MmYo1uFbML8RElrwvx3aCyXg2OwVVqRkq6mfoVS1PDO33Ubsq8kGUtkZCSdO3cmODiYgwcPUq9ePe1rv195wMHr4aislSSlqelcy4tudfI3gWHHxQe5HZWiXa7sbsehdzoZPf6iRjZhlgrkUWIqtkorvYaYMGRfSTKG8PBwOnTowIMHDzhy5IhOggHD6krD49MAcFYpdZYl/cgkU0K5OdiSkq7Wa4gJQ/aVJGNYsmQJERERHD16lNq1a2d53cfNnqQ0tV51pV5ONgDEJat1liX9yOKyJ0pacZmsk5GKosze+2lpaYSGhlKxYsVst5N1Msan73ekTDJPlLQkI0lFzb1793jllVdYs2YNzZo1M3c4JY6cflmSpGIrKChI27mybNmyZo5GKghZxiFJkkW7efMmbdu2RalUcuzYMSpVqmTukKQCkE8yBjrx9wPGbL1IkgZUCnijbSVUdrY8SkjBzcEOP09HOZ6ZJOlJCEG/fv2wt7fn8OHD+Pj4mDukImXp3mt8eeyOdnlim0pM7f58ocYg62Se0Le8seYHv5Gk+W9ZAbSqWpq70UlUdHPAy0VVoDb6kiTpunLlCp6enrKYTA+V3v8ty7o7n/XQ61iyTsZMMhOMFRmD0AnAzkaJrdIKlY2VycYzM7SFl2whJlmyK1eusGDBAjZs2ECdOnUKvL8hn++YpFS++/MedyMT8fVwoP8LFXG1z38zfUP3L25kkjGQvVVGosl8mFEAKWlqUtI1JKdpTDae2a3weG2v+7tRiQAF6nVv6P6SZCqXLl2iU6dOVKxYkaSkJOztC/73Y8jn+7s/7/H71YeorK24FhoHwJg2VfN9bkP3L27kT1cDfTWsAfZP7qJKAW+3q0TzKh68+LwXLaq407mWl0nGMzO0173stS9ZonPnztGhQwf8/Pw4dOgQpUvr1z/FkM/33chEVNZWVPFywt7GiruRiQU6t6H7G9PENpVyXS4M8knGQK1qlOP6p4Vf3+LmYMvdqES9e90bur8kGVtwcDCdOnWiZs2a7Nu3D1dXV72PZcjn29fDgWuhcdwKjycpTYOvh0OBzm3o/sY0tfvzhV7R/yyZZIqozNFUny5zLsz9JcnYfHx8+Pzzz+nfv7/BHaIN+Xz3fyFjFIGn61QKwtD9ixvZuuwJ2eNfkswjICCAe/fuMXToUHOHIuVCjsIsSVKRc+jQIbp168a2bduQv3eLJ5lkJEkyi3379tGzZ0/atm3L7t27USgU5g5JMgGZZCRJKnSHDh3ilVdeoXPnzuzevRuVSmXukCQTkUlGkqRCV6dOHd5++2127dqFnZ2ducORTEgmGUmSCs2ePXt48OABXl5efP7559jayqbzxZ1MMpIkFYodO3bQq1cvVqxYYe5QpEIkk4wkSSa3detWhgwZwpAhQ5g/f765w5EKkeyMaaDspmqtUraUHHxSkp7YuHEjo0ePZtSoUaxbtw4rK/m3UFh2/RHE9B8Dtcs9a5emRz3fQp1+RL7bBpqw/QK3o1KIT9FwOyqFCdsvaAfnC4tJ4dL9x9wKjzd3mJJkNo6OjkyYMEEmGDN4OsEA7LkajX9gOAE3wgstBvkkY6Dw+DQAnFVK4pLVhMen6QzOFxaTbJGDT0bGJ7Pu6E2CIhLx83RgbLuqeDjlrxlpYmo6ATfCCXmUhI+bvZyUTcrWiRMnaNmyJQMGDGDAgAEF3t+Q4foN+XwX92kwTDX9SE6Kz50zEy8nGwDiktXaZTcHW1LS1RY9+OS6ozfZHxjOnehE9l0LZ93Rm/neN+BGOAevhxMUmVjov4qkouGLL76gdevW7Nu3T+9jGFIiYMjnu7iXRJhq+pGcyCRjoFWDGlLZ3Q4nOysqu9tl1Ml4OVG/QinKutpRv0Ipixx8MigiEVulFX4ejthZWxEUkf/hyEMeJaGyVlLZ07HQfxVJlm/hwoVMnTqVGTNm8OKLL+p9HEOG6zfk812cpsH4vE8tneWetUubbPqRnMgyDgPV8CnNoXc6ZV1v4ROA+Xk6cCM8nqDIBFLSNfh55n84ch83ewJD47gdkVDov4okyzZv3jxmzZrF7NmzmT17tkFDxRgyXL8hn+/iNA3Gq038eLWJn1ljkEmmiDK03Hhsu4yZ+p4us86vzF9BT9fJSFJ6ejpnzpxh3rx5fPjhhwYfz5Dh+g35fHs423IrPFa7b6da8vNtCDnU/xNFbaj/v0NjtdPLpqSrqV+hlMU/PUnFkxCCkJAQypcvj1qtRqlUmjskg3x17KZ2+uSkNA3d65Qp0dMnZ5JD/ZcwxancWCq6hBC888471K1bl8jIyCKfYMCypk/OD6FWk3TyInE/HiTp5EWEWm3ukHQYJcn8+uuvNGjQgKpVq9KnTx8OHDhgjMNKuSgKLdik4k0IwaRJk1iyZAnz5s3Dw8PD3CEZha+HA0lpGouYPjkv8XsCuNuwHw96TSR87Fwe9JrI3Yb9iN8TYO7QtAyqk1m2bBkNGzZk+vTp7Nq1i6pVq3L58mU+/fRTbt++zbhx44wVp8U6808oY7ZeID4dHK1hVs+aeLs7EZ+cjpPKGk8nlUna2cvpk4uu7PoZ2SitstSxpak1FtsfSaPRMGHCBNauXcu6det44403zB2S0RSV6ZPj9wTwcOSH8EyFR9qDCMJGfMjHHXtzrFJ17frBjXzwK+Nc6J8lg+pkVq5cyeXLl9m2bRuVKlWiWrVq1K5dmxo1avDpp58SGBhYZCYi0re8sc6HvxGX/t+yrRVM6lSd2xHxVPZyxN3RrsjUlwi1muQzf5H+MArrMu6omtVFUQyKPyzN71cecPB6OCprJUlpajrX8sLPwylLHVtQZHyW7brVKWfu8AG4ffs2jRs3ZsmSJYwYMcLc4ZQ4Qq3mbsN+qB9EZPu6Boh0dGZw/zfRPDXKwuAmFfX+LOn7HWlQKnvrrbcAuHXrFuvXryctLY0rV67w119/ERoaSq1atXBycuLPP/805DQWLf5JglEqQC0gVQN21laobJSorJVFpr4kfk8AkTOX63xoleU88fhkEk4925oxsuLn6X5GtyMSCHmURCkH2yyjRGS3nbmlp6ejVqupXLkyt27dws3NzdwhlUjJZ/7KMcFARj2IV0IcdR7e57K3r3a9OT5LRinD+fLLL+nXrx8rV64kOjqahIQEateuzfXr1zl48KBex1y9ejV+fn6oVCoaNWrE8ePHc90+JSWFmTNn4uvri52dHVWqVGHjxo16nbsgnJ6kafWT50FbK0hJ15CUpiY5XV0k6ksyH7uf/dCqQyN4OPJDiyrfLQ583OxJSlPr9DPKro4tu+3MKT09naFDhzJo0CCEEDLBmIk6Ooa4HXvztW3pxASdZXN8lvR6ktFoNJw8eZKIiAhatGhB7dq1OX36NPv27ePy5ctUrFhR207e1dW1wMffuXMnkydPZvXq1bRs2ZJ169bRrVs3AgMDqVgx+/LR/v378/DhQzZs2EDVqlUJDw8nPT09222N6athDbOtk6ns6aBTJ2OphFpN5MzlWcp1M14EFBD54Zc4dmsli86MJLt+Rpl1dk/XyVR0d8iynbmkpaUxcOBAfv75Z7799tsiUwxeXAi1mqSAc8Ru30vC78chNS1f+0U7OGr/f3AjH/w8HAr9s6RXnYyHhwcpKSmoVCoeP37MkCFDWLlyJY6OjnnvnA9NmzalYcOGrFmzRruuZs2a9OrViwULFmTZft++fbz22mvcvn2b0qVL63XOotZPxliSTl7kQa+JeW5XbveX2LdsUAgRSZYmJSWFAQMGsHfvXr7//nteeeUVc4dUYqQFhRD37e/Efvs76gf/jRFoU7sq6vthaGLjs/+BqABlOS98z39ntB+HhdpP5vvvvycuLo6IiAj+/PNP7ty5wwsvvEBYWJg+h9ORmprK+fPn6dKli876Ll26cOrUqWz3+eWXX2jcuDGLFi3Cx8eH5557junTp5OUlHO5Y0pKCrGxsTr/SqL0h1FG3U4qfrZv387vv//OTz/9JBNMIdAkJhP33T5Cek3kXpPXeLR0C+oH4ViVcsZ1dF/KH9pAxSOb8Fz2fsYOzz5UPln2mD/RIkof9Coua9++vfb/69evz+HDh3n33Xdp1aoVx48fx9vbW++AIiMjUavVlClTRmd9mTJlckxit2/f5sSJE6hUKn766SciIyMZP3480dHROdbLLFiwgLlz5+odZ3FhXcbdqNtJxYcQAoVCwfDhw2natCm1atXKeydJL0IIUi4EErd9L3E/HkTEP+kAqlBg3+4FXAb1wOHFllip7LT7OPVsCxvnZ9NgxwuP+RMtpsGOURpKJyQkMGrUKK5du0aXLl24cuWKwcd8tsw38wOfHY1Gg0KhYNu2bdo6oKVLl/Lqq6+yatUq7O2zVnLNmDGDqVOnapdjY2OpUKGCwXEXNapmdVGW88y5pcqTx25Vs7qFG5hkVgkJCfTt25cxY8bQt29fmWBMJD08mvhdB4jd/htpN+5o11tXKofLa91xfu1FrH3K5Li/U8+2OHZrZdFdD/RKMiNGjCAkJITg4GBCQkKIj8+Yb0EIgUqVv4mBcuLh4YFSqczy1BIeHp7l6SaTt7c3Pj4+Oo0MatasiRCC4OBgqlWrlmUfOzs77OzssqwvaRRKJR6fTOLhiBwGNBTgMe9ti/rQSqYVFxdHz549OX/+PB988IG5wyl2RHo6iYfOErf9NxIOnIL0jGFgFPZ2OL7UDpdBPVA1r4cin7OIKpRKi64v1SvJPHz4EF9fX1q0aIGPj4/OP0OHlrC1taVRo0b4+/vTu3dv7Xp/f/8cy4NbtmzJ999/T3x8PE5OGS25/vnnH6ysrChfvrxB8ZQEds/nPvifSE4ppEgkc4uNjaVbt25cuXKFAwcO0KJFC3OHVGyk/nuXuB17idu5D3V4tHa9XaNauAzqgWOvDihdLLclqr70SjJ79+avjba+pk6dytChQ2ncuDHNmzdn/fr13Lt3TztMzYwZMwgJCWHr1q0ADBo0iHnz5jFixAjmzp1LZGQk77zzDiNHjsy2qEzSFbN5NwCq9k0oPWmI9rE76cxlHn22gcgPV+DQvglKD9kvorh78803uXbtGgcPHqRJkybmDqfI08QnEr/7MHHbfyP5z6va9VYepXDu/yIuA7tjW8O8872YmmUMhPSMAQMGEBUVxccff0xoaCi1a9dm7969+Ppm9FwNDQ3l3r172u2dnJzw9/fn7bffpnHjxri7u9O/f3/mz59vrksoMjSJycRt2wNAqdF9dB67VU3qkPDrUVKv3SJy5peUWTfbTFFKhWXhwoVMmzaNhg0bmjuUIksIQfKZv4jb/hvxvxxBJCZnvKBU4tC5OS6DuuPQqTkKG4v8+jU6OZ/MEyW1n0zsN3uImLIQa19vKp7dkaXuJfnS34R0HQsaDWW3LcSxiyw+KW4iIyOZNGkSy5cvLzYjKZtDemgEcTv3Ebd9L2lBwdr1NlUr4jy4B879uhbpVppmGbtMKtqEEMRs+BEAl+G9sq3cV9Wvgeu4/sSs/pbId5dg36I+Vk6WO/S5VDDh4eF07NiR8PBwwsPDZZIpIJGSSsL+k8Rt30vikT9AowFA4WiPU++OuAzqgV3j50v0CAkyyZRgKX9eJfXqvyhUtrgM6pHjdqXfG0XC3mOk33lA1Lx1eC6cUohRSqYSGhpKx44defToEUePHqVmzZrmDsns8jsSecq1mxl9WnYdQBMdo12val4P50E9cHqpHVaOsj4YZJIp0WI2/gSAU+9OKEvnPMaclYMKzyXvENp3CrGbfsKpT0fsm8p+M0VZSkoKHTp0IC4ujoCAAJ577jlzh2R2eY1Ern4cR/yPB4nb/hspl2/8t01ZD5xf64bza92wrVLy+trlRdbJPFHS6mTSw6O5W78vpKVT/uD/sKtXPc99wicuIG7HXmyq+VLhyEYUdpY9urSUu+3bt9O0aVOqVKli7lDMLqcJwFAAAlRN65Jy6W9EypNpO2yscXyxFc6DeuDQ/oUS0Y+sUMcuk4q+uK9/hbR07Bo/n68EA+D+8VsoPUuT9u9dHn2x1cQRSqYQFBTEl19+CWQ0/ZcJJh8jkQPJZ/9CpKRiW6sy7vMnUunKT5TdOA/HTs1KRIIxhEwyJZBITydmy88AuI7sncfW/1GWcsbjs8kAPFr+DSmBt0wRnmQiN2/epG3btnz55ZcldkDY7OQ1AVgmj8+nU/7oZkqN7YfSvZTpAysmZJIpgRJ+P4E6NAIrj1I4vdw+7x2e4vhSOxy7t4Z0NRFTFiLUahNFKRnTjRs3aNu2LQ4ODgQEBJSIIuH8yu8I41bOjiW6lZi+ZJIpgTIr/F2GvFTgehWFQoHHwqlYOTuScuE6MV/tMkWIkhHdvHmTdu3aUapUKY4ePYqPj4+5Q7IoSo9S+dquKPdxMad8ty57esTivCxdulSvYCTTS70RRPKJC2Blhcvr+s0NYl3WA/c544mYtpjoBf/DsVtrbHzLGTlSyVjKlStHr169mDt3Ll5e5ptd0xKl3X1A9IL/5b6RHIncIPlOMhcvXtRZPn/+PGq1murVMyqN//nnH5RKJY0aNTJuhJJRxWzIeIpxfLElNuVzHkI8L85DehL3gz/Jpy4RMf1zvL9bIosSLMylS5ewtramdu3aOrPMShkdkeO/20/E+18g4hNRqOwyBoJ90ppMy8ImACuK8l1cduTIEe2/l156iXbt2hEcHMyFCxe4cOEC9+/fp3379vTokXOnPsm8NHEJxH23DwCXUX0MOpbCygrPpe+iUNmSdPRP4nbuM0aIkpGcO3eODh068P7775s7FIujfhTLwzFzCH/rE0R8Iqqmdalw8mvKbJqP0ttTZ1tlOS/KbJxvMROAFUV69ZPx8fHhwIEDPP/88zrrr169SpcuXXjw4IHRAiwsJaGfTMyGH4l8/4uMfi4nvzbKk8ejL7cRPW8tVqWcqXDyG6y9ShshUskQZ86coWvXrtSqVYt9+/bpzLNU0iUeP0/4hE9Qh0aAtZLS746k1MTB2qeU/Pb4L4kKtZ9MbGwsDx8+zLI+PDycuLg4fQ4pmZgQgpiNGeOUuY7oZbSirVJvDsC2djU0j+OImrncKMeU9Hfy5Em6dOlC3bp12b9/v0wwT4iUVCLnrCK07xTUoRHYVKmAz941uE0ZppNEMicAc+7TCfuWDWSCMQK9kkzv3r0ZMWIEu3btIjg4mODgYHbt2sWoUaPo08ewYhjJNJJOXCDtn7soHO1xfq2b0Y6rsLHGa9l7oFQSv/swCftPGu3YUsFpNBratGnD77//XmyfyAsq9UYQwS+OI2bVtyAELsNepvyhDagayLHaCoNeSWbt2rX06NGDIUOG4Ovri6+vL4MHD6Zbt26sXr3a2DFKRhD7pMLfuX9XrJwdjXpsu3rVKTV+AAAR7yxBHRtv1ONLebt48SLp6em0bt2aPXv2aGeILcmEEMT87weCO40m9eq/WLm7Unbrp3gueUcOXlmI9EoyDg4OrF69mqioKC5evMiFCxeIjo5m9erVODoa9wtMMlx6yEMSfj8OgOtI0zxpur0zEutKPqhDI4iet9Yk55Cyt2/fPlq0aMGyZcvMHYrFSH8YRdjAd4mcsQyRnIp9h6ZUCNiCY7fW5g6txDGoM6ajoyN169alXr16MrlYsJgtv4BGg6plA5NN9Wplb4fXF+8CELv5Z5JOXzbJeSRde/bs4ZVXXqFz5868/fbb5g7HIiTsO8H9tq+TeOgMCpUtHgsm4/3tYtmZ0kz0TjLHjx9nyJAhNG/enJCQEAC+/vprTpw4YbTgJMOJlFTivvkVKNg4Zfqwb9UQ5yE9AYiYshBNcopJz1fS7d69mz59+tCjRw927dqFnZ2duUMyK01CEhHTFhM2dAaaqBhsn69Kef//4Tq6r+zDZUZ6JZkffviBrl27Ym9vz8WLF0lJyfgyiYuL49NPPzVqgJJh4n89ijriEUpvz0IpKnCfMx5lGXfSbt3n0ZItJj9fSXbw4EF69+7Nzp07sbUt2dMuJF/6m+COo4jd+gsoFLhOeI3y+9eZ7Mldyj+9+sk0aNCAKVOmMGzYMJydnbl8+TKVK1fm0qVLvPjii4SFhZkiVpPStw14cHQ8c36+yp2oJCq52zPnldqUL23cStc0tYZb4fE8SkzFzcGWKl5O2Cjz9/sguPubpPx5Fbf3R1F62nCjxpWT+N+O8XD4TLBWUt7/f9jVrmr0cySmphNwI5yQR0n4uNnTtroXDrYlYw6+sLAwypYti0ajQaPRYG2t33Xr87ky5LNoCkKt5vHybUQv3gjpapTennitmolDa8sdeeRSUDhvfH2ex8kaSqmsWD+0EfX9LH+4n0LtJ3Pjxg3atGmTZb2LiwuPHz/W55BF1pyfr3Iq6DEPY1M4efsxc36+avRz3AqP53JwDGExKVy6/5hb4flrvZVy+QYpf14FG2tchrxk9Lhy4tSjDY492v43UnN6utHPEXAjnIPXwwmKTMQ/MJyAG+FGP4cl2rp1q/YHnZWVld4JBvT7XOn7WTSFtHuhPHhlItELvoJ0NY6vdKBCwGaLTjAAb3x9nvBEDakaCE/U8MbX580dkknplWS8vb25efNmlvUnTpygcuXKBgdVlNyJSkIJeJdSYa3IWDa2R4mp2CqtKOuqws5ayaPE1Hztp51e+aV2hV7p6fHZZKxcnUi59Dcx640/UnPIoyRU1koqezpib6Mk5JHx77ul2bhxI8OHD2fw4MHUrWv4YI36fK70/SwakxCCuO/3E9xuBMln/0Lh5IDXypmU+WoOSjfL7xv0OFkDgJ1SobNcXOmVZMaOHcukSZM4e/YsCoWCBw8esG3bNqZPn8748eONHaNFq+RuT7qA0MfJpIuMZWNzc7AlJV1NWEwyKelq3BzyLn9XP4ol/kd/AFxMXOGfnYyRmicAEP3Z/0gLCjHq8X3c7ElKU3M7IoGkNDU+bsW738O6desYNWoU48aNY926dVhZGV5Epc/nSp99jEkdE0f42LmEj5+PJi4B1Qu1qXB0E84DXiwylfulVBnvXYpa6CwXV3o9a7/77rvExMTQvn17kpOTadOmDXZ2dkyfPp233nrL2DFatDmv1M5SJ2NsVbwy6nieLgfPS9yOvYjkVGyfr4qqSR2jx5QfzoN7EP+jP0nHLxAxbRHePywz2hdB2+oZZdhP18kUVwkJCSxYsIBJkybxxRdfGO0e6vO50mcfY0k6eZHwCfNJDwkHpZLS74yg1KTBKAwoMjSH9UMbZamTKc70qvi/d+8e5cuXJzk5mcDAQDQaDbVq1cLR0ZH79+9TsWJFU8RqUvpWat0Me8zkby/yIDaVci62LHutAVXLlsrXvqaqvBYaDfeaDiT9zgM8l76Ly9DCq495VtrtYO63fR2RnIrn8vdxGSRH6c5vY5E0tYbrwdEkpINIiqVxDV9srUveWFoiNY3ozzbweOV2EAIbv/J4rfkQVaPn8965mDgaGMIbWy+RCtgC64fVp12twp18rlAr/v38/IiMjMTBwYHGjRvTpEkTnJyciI6Oxs+vZDUZnPztRa4/TCQ2KZ3rDxOZ/O3FvHd6wlSV14mHzpJ+5wFWrk449e1slGPqy6ZyedzeGwVA1KyV+Z7qtjjLb2OR92fNo0/3Ttx9+Jh7idbcjkgo5EjNL/WfOwS/OJbHK7aBEDgP6Un5wxtKVIIBtAkGIPXJclGhV5LJ6eEnPj4elUplUEBFzYPYVBRAKQcbFE+W88tUldexGzJGW3Ye2B0rB/O/H6XG9ceuXnU0MfFEzlhm7nDMLj+NRebNm8fST+fQqGV7KnqVMlslu7lkjBr+U8a4Y1f+xaq0K2U2f4LXF+9h5eRg7vAKXeY7r3hmuSgoUNlM5hTMCoWCWbNm4eDw35utVqs5e/Ys9evXN2qAlq6ciy0xSek8TkxDPFnOLx83ewJD44xaeZ0WFELi4bMAuI4o/Ar/7CisrfH84j2CO48h4dejJOw9hmP3rE3gS4pK7vaExKRk21hECMHs2bOZN28eE9/9kOZ9x/AwNsUslezmkh4eTcTkz0j0Pw2AfbsX8FrxAdZlPcwcmfnYkpFYxFPLRUWBkkzmFMxCCK5cuaLTy9jW1pZ69eoxffp040Zo4Za91iBLnUx+maLyOmbzbhAC+w5Nsalc3uDjGYtdnWqUmvAaj7/cRsS7S1G1bIDS1dncYZlFbo1Fzp49y7x581i4cCFTpk3P0vGxuEs4cJLwSZ+hiXyMws6W0rPexHV0HxRGaE1XlK0fVj9LnUxRUaAkc+TIEQBGjBjB8uXL5XwVQLpaQ1KamrQn/01X57/Nu4OtNd3qlDNaLJrEZOK27QHAdZRlPMU8zW36CBL2BJB2O5joj9fiueQdc4dUaCLjk1l39CZBEYn4eTrwWb/6eDj9V5QphEChUFC3YWNW7NyPqmwVDgaG0ba6FzW8i//fmSYxmag5q4jdtBsA2+er4LVmFnY1S1a/O8h+RIB2tXz457PCreg3Fr1+HmzatEkmmCcmbL/A7agU4lM03I5KYcL2C2aLJf7Hg2hi4rH29cahYzOzxZETK3s7PL94D4DYrb+QdDL/jSSKunVHb7I/MJw70YnsuxbOuqP/dWYWQjBp0iSWLVtGwI1wgvAqUSMZpFy+kTHu2JME4/rmAHz2rSuRCQaK34gAeiWZBQsWsHHjxizrN27cyMKFCw0OqigJj08DwFml1FkubEIIYp5U+LsM72Wx08bat6iPy+uvABAxZRGapJIxUnNQRCK2Siv8PByxs7YiKCIRyJjJcvz48axYsQIHB4cSNZKBUKt5tPwbgl8cS9rNeyjLeuC96ws8Pn4LK1XJHVG6uI0IoFeSWbduHTVq1Miy/vnnn2ft2pI1YVVplTUCiE1WI54sQ0Yfh79DYzl9K5K/Q2NJK0Axmj5S/rxK6tV/UahsLb4vSulZ41CW9SAtKJhHizfle7/E1HR+v/KA/x27xe9XHpCYavwx0UzFz9OBlHQNQZEJpKRr8PN0QK1WM2bMGNatW8fGjRt54403LGokA2Pdb6FWk3TyInE/HiTp5EWEWk1a8EMe9JlM9Px1GeOO9WybMe5Y28ZGvoqi4eq9SFp9up+aH/5G6pOviuIyIoBePf/CwsLw9vbOst7T05PQ0FCDgypKqno5cjcmVWcZ/htI0FZpxd2ojF+tpixb145T1rsTytKuJjuPMShdnPBcNJWwYR/wePW3OPXqgF3d5/LcL7NfkcpaSWBoHIBR67RMaWy7jJGoM+tkxraryqJFi9i8eTNbt25lyJAhgGWNZGCM+x2/J4DImctRP4jQrrNyc0EkpyKSklE42uOxYDLOr3UrMsPCmMK4b84THKubxG2tKBYjAuiVZCpUqMDJkyezdLw8efIk5coVjT96Y3mUpMHLyZYKpR24H53Io6SMnyFPDyQYFpNs0j4O6eHRxP+S0SjDdZRpplc2NsdurXF8uT0JvxwhfMpCyu9fl+fwIE8XJd2OSChSRUkeTipm9tQdcmj8+PHUr1+fbt26adcZuzGIIQy93/F7Ang48sP/2t0+oXkUC4B1lQqU27EYG7+iWaFtTFGJGQnGwcaKxDQN9tZwfb5ll0jkl17PYaNHj2by5Mls2rSJu3fvcvfuXTZu3MiUKVMYM2aMsWO0aH6e9qSpBfejE0lTC/w8M4o3CnMgwbivf4W0dOwa1cKuXnWTncfYPBZMxqqUM6l//cPjNTvz3N6SipL0lZaWxoQJE7h58yaurq46CcbSGHK/hVpN5MzlWRKMzjZJyVhXLGuESIs+d4eMH1iJaRqd5eJA7wEyo6OjGT9+PKmpGb/QVSoV7733HjNmzDBqgJbunRdrAtcJikjCz9P+yXLhDSQo0tOJ2foLUHSeYjJZe5XG/eO3iJi4gEeLNuLYvQ22VSrkuL0lFSXpIyUlhQEDBrB371569OhB1arGn8zNmAy538ln/tIpIsuO+kEEyWf+wr5l/vuWFVdrhzRi3DfniUpMx93BmrVDinYR2dP0GiAzU3x8PNevX8fe3p5q1aoV6TnG9R38zdzi9wTwcMSHWHmUotKlH1DYFaW+wBmt4kL7TSUp4Byqlg0o99PyYlk2n5yczKuvvsrBgwf58ccf6d69u7lDMhlNcgqRM5YR982ePLf1Wjcb5z6dCiEqyVCFOkBmJicnJ1544QVq165dpBNMUaZttjzkpSKXYCBjiCLPz99B4aAi+eTFfH0xFUUDBw7k0KFD/PLLL8U2waSHRxO9cAN3G7ya7/exsCfTkwpfvp9kpk6dyrx583B0dNSOYZaTpUuXGiW4wlQUn2RSbwRxv9UwsLKi4vnvsClfxtwh6e3xmm+JmrUKKxcnKpz8utiNU7V//35sbGzo0KGDuUMxupTAW8Ss/Y64H/whNaOfmLJ8GURsPJq4hOzrZRSgLOeF7/nvLLZPl6RL3+/IfNfJXLx4kbS0NO3/S+YXs3E3AI4vtizSCQbA9Y1+xP90mJSL14l8/wvKbv7E3CEZLD4+nrVr1zJ16lS6du1q7nCMSmg0JB46S8y670gKOKddb/dCbUqN7Y9jj9Yk7DuZ0bpMgW6ieVIa6jF/okwwJYBBdTLFSVF7ktHEJXCnTm9EQhLeP3yBQ5ui34kt5dpNgjuNhnQ1ZTbNx6lnW3OHpLfY2Fi6devGlStXOHfuHM89l3c/oKJAk5hM3Hf7iFn3PWk372WsVCpx7NmWUuP6o2qsO89Ldv1klD5eeMyfWKTf35LI5E8yeRWRZVIoFCxZsiTfAUj6iftuPyIhCZtqvti3Lh4tUeyer0qptwfz+IutRL63FPtWDVGWKnojNT9+/JgXX3yRv//+m4MHDxaLBJMeFknMhh+J3fKztp+LlbMjzkNfwnV0X2wqZN8U2alnWxy7tSL5zF+kP4zCuow7qmZ15RNMCVKg4rKnnT9/HrVaTfXqGf0y/vnnH5RKJY0aFY8vPEuWMaFTRoW/64hexao1ltvUYST8epS0m/eImrMKr2XvmzukAomJiaFTp04EBQVx+PBhGjZsaO6QDJLy1z88Xvcd8T8dgrSMDoPWvt64vtEfl0Hd8zWBmEKplM2US7B8J5nMYf4ho2Lf2dmZLVu24ObmBsCjR48YMWIErVu3Nn6Uko6kExdI++cuCkd7nF+z3M58+rBSZYzU/OClCcRt+w2nvp1xKEJPak5OTjRv3pwNGzZQr149c4ejF6HRkHjgFI/X7CT51CXtelWzeriO64/jiy3lk4iUb3rVyfj4+HDgwAGef163/PXq1at06dKFBw8eGC3AwlKU6mTChn9Iwm8BuIzoheeiaeYOxyQi3l1C7KbdWFcqR4WALRYxjXRuwsPD+eeff2jVqpW5Q9GbJj6RuG9/J2b9LtKCgjNWWitxeqUDruP6o6qfdVBcqeQweZ3Msyd7+PBhliQTHh5OXFycPoeU8ik95CEJvx8HwHVk0erhXxDuH40jYf8p0u884NHijbjPHm/ukHIUGhpKx44dSUtLIzAwEBsbG3OHVCDpIQ8z6lu2/oImJh4AK1cnXIa9guvoPliXK1ojK0iWRa8k07t3b0aMGMGSJUto1ixjcqwzZ87wzjvv0KdP8f3iswQxW34BjQZVywbY1vDLe4ciysrZEc/FUwkb/D6PV+/EqVdHixyXLSQkhA4dOpCQkMCRI0csKsEItTrXCvfki9eJWfsd8T8fAbUaAJvK5XEd2x/nAS9i5Vj0xoaTLI9eSWbt2rVMnz6dIUOGaPvOWFtbM2rUKBYvXmzUAKX/iJRU4r75FQDXkZY3vbKxOXZpiVPvjsT/dIjwSZ9R3v8rFDaWM3DgvXv36NChA2lpaQQEBFClShVzh6SVbdPhcp54zHsbrKyIWbOT5D+uaF9TtWpIqXH9cejcHIVV0Z6/RLIsen2aHBwcWL16NVFRUVy8eJELFy4QHR3N6tWrcXR0NEpgq1evxs/PD5VKRaNGjTh+/Hi+9jt58iTW1tbUr1/fKHFYkvhfj6KOeISyrAeO3UpGAwv3TyZh5eZC6rWbPF61w9zh6EhKSsLLy8siE8zDkR9mGaBS/SCCh6Nm8XDEhxkJxsYap/4vUv7wRnx+Wo5j15YywUhGZ9AnytHRkbp161KvXj2jJReAnTt3MnnyZGbOnMnFixdp3bo13bp14969e7nuFxMTw7Bhw+jYsaPRYrEkmROTuQx/xaJ+0ZuStadbxq9v4NHnm0m9lftnoDAEBQURFxdH9erVOXnyJJUqVTJ3SFr5GWIfhQLXSUPwvfA9ZVbNxK5OtUKLTyp59E4yx48fZ8iQITRv3pyQkBAAvv76a06cOGFwUEuXLmXUqFGMHj2amjVrsmzZMipUqMCaNWty3W/s2LEMGjSI5s2bGxyDpUm5fIOUP6+CjTUuQ14ydziFyql/V+zbvYBISSViyiKExnxznt+4cYNWrVoxefJkAIvro5SfIfYRAsf2TYrd+HCSZdIryfzwww907doVe3t7Ll68SEpKCgBxcXF8+umnBgWUmprK+fPn6dKli876Ll26cOrUqRz327RpE7du3WL27Nn5Ok9KSgqxsbE6/yyZdnrlnm1L3Mi1CoUCzyVPRmo+fZnYr381SxyBgYG0a9eOUqVK8cknljm2WvrDKKNuJ0mG0ivJzJ8/n7Vr1/LVV1/ptKZp0aIFFy5cMCigyMhI1Go1ZcroDvhYpkwZwsLCst3n33//5f3332fbtm1Y5zGFb6YFCxbg6uqq/VehQs6TZZmb+lEs8T/6A+BSxCYmMxabit6UnpEx62r03DWkh+bxa93Irly5Qrt27fD09OTIkSOULWuZMzrm9wdISfuhIpmPXgX7N27coE2bNlnWu7i48PjxY0NjArIWQwghsi2aUKvVDBo0iLlz5xZojKgZM2bojMcWGxurV6IJjo5nzs9XuROVRCV3e+a8UpvypY07C2bcjr2I5FRsn6+Kqkkdox67KEhMTSfgRjgh1RvQrGZV7K/fJOLdpZTZNI+Us1eMNiaW9jxPzQTpYJvxJ3L48GHKly+Pv78/7u6W+wVtU7UCWCshXZ39Bk+G2Fc1q5vlpdyu31gK4xyWLDI+mXVHbxIUkYifpwNj21XFw8myOxobSq9319vbm5s3b2ap8Dxx4gSVK1c2KCAPDw+USmWWp5bw8PAsTzeQUUR37tw5Ll68yFtvvQWARqNBCIG1tTUHDhzIdg4POzs7o0y0NuGbc1x+kADAzchEIr45x88T2xl83ExCoyFmU0ZRmeuo3hZXB1AYAm6Ec/B6OCprJWG9+/Hav4tI3HeCu7Ve0Q7WCE+a6H4yqcCj+6apNdwKj2fftQdcC47Fw0VFYGhGp+Im5exwd3dn0qRJjB07FpUq6xdC5v5PT7Vto8y7kODvkGgmbL/Aw/g0HJTQs355XvBzz/LFGxaTyOJ9ulN8l3XNOmaYOuoxof2n5ZpgIOch9p++z5nX361OuTyvoyAK4xzZ0fc9MrZ1R2+yPzAcW6UVN8IzOr7O7Fm7wMe5ExHLe7suc/9RChXc7Fj4aj0qeVrmSCV63eWxY8cyadIkzp49i0Kh4MGDB2zbto3p06czfrxhPbNtbW1p1KgR/v7+Ouv9/f1p0aJFlu1dXFy4cuUKly5d0v4bN24c1atX59KlSzRt2tSgePKSmWByWjZU4qGzpN95gJWrE059Ohv12EVFyKMkVNZKKns6kuRbgfAXMgadfDrBAKhDI3g48kPi9wQU6Pi3wuO5HBzD7YhEYpLTcVZZY2+j5Njxk1StWpWffspI8tklmKf3D4tJ4dL9x9x68uWRlwnbL3A7KoX4FA3hiRq+/eMe/oHhBNwI19lu8b7rHLkRRfDjJA7/HcXifdezHEsd9ZgHfSeTGngbZRl33BdMRlnOU2cbZTkvymzMeQqFp++zvY2SkEdJ+bqOgiiMc2RH3/fI2IIiErFVWuHn4YidtRVBEYl6Hee9XZe5GBxHVEIqF4LjeG/XZSNHajx6Pcm8++67xMTE0L59e5KTk2nTpg12dnZMnz5d+zRhiKlTpzJ06FAaN25M8+bNWb9+Pffu3WPcuHFARlFXSEgIW7duxcrKitq1dX8JeHl5oVKpsqwvimKfjLbsPLB7ie2B7eNmT2BoHLcjEkhOSaP0v7ey31AACoj88Escu7XKd9HZo8RUbJVWVPZ05MHjZG6HJxB39yq/LnqbRg0b5NkkPnP/sq4qwmKSeZSYmq/zhsdndGS2UypIUQvSNWT7xRsUkYS1lYIKpR24H51IUITu6+roGB70nULqtVsovUpTbveX2FatiOuIXgUaYv/p+5yUpsbHzfift8I4R3b0fY+Mzc/TgRvh8QRFJpCSrsHPM+9RrLNz/1EKCsDT2Y6IuBTuP0oxbqBGpHdh6CeffMLMmTMJDAxEo9FQq1YtnJyMUxcxYMAAoqKi+PjjjwkNDaV27drs3bsXX19fIGOsqLz6zBQWGyDtmWVjSQsKIfHQWQBcRxT/Hv45aVs9Y+yskEdJVL17G5vI6Jw3FqAOCSf5zF/5Hl7ezcGWu1GJeLs4UL1MCg9vXOCXhW/RpMkL/LZnT56f68z9w2KSSUlX4+Zgm6/zejnZEJ+SQoo6o1OLtRXZfvH6edoTFJXI/ehE0tQCP8//Xlc/in2SYG6i9PwvwUDBh9h/+j5n1pcYW2GcIzv6vkfGNrZdVQCdOhl9VHCzIyIhlYi4FDRPli1VgUdhTktLo0uXLqxbt65YTMaUSd8RRnf9EcT0HwO1y5/3qcWrTYwzpljk7FXErP4W+w5NKbfzc6Mcs6iL+/Eg4WPn5rmd17rZOPfplK9jPl1eX8rehtH9euDgYM/PP/+Mg0PevzSzK++Pik/Osx7l75Boxn1zjnuP0lAANcraM7xFVXrUL5evOpmMBDOZ1Cv/ovR0w+uHZdwr5WH2egdLZCl1MvmVV7zmqJMptFGYbWxsuHr1aomsgM7O9+eDsVUqUACaJ8vGSDKaxGTitv8GZFT4SxlM0UTXRmlFDW8X0tPTsba25tdff8He3h57+/wV5WTu/7TMehRrKwVBUYnAdZYM0J0Xp4ZPaTo/X05bERybouGf8Fj62VbU2a6sq0OWfdWP4wjtN5XUK/9i5VGKcj8t57aLO5eDY7BVWnE3KqOs/9m4Sqrs3iNLllmHlNN7WcnThZ1vFo2hpfRK5cOGDWPDhg3GjqVIerps1OrJsjHE/3gQzeM4rH29cejYzCjHLA5UzepmVGjn8Rsn4fBZNPH5r1Tds2cPderUITQ0lNKlS+c7weTk6XoUG6UiSz3Kf9sVvCJYHZORYFIu38hIMD8ux7a6n069g5210mz1DpLhitN7qVedTGpqKv/73//w9/encePGWcYtW7p0qVGCKwpMUTb69PTKLsN7yVkIn6JQKvH4ZBIPR36YkWhyKOyN+XIb8d/+TukZY3Ae2C3Xe/jTTz8xYMAAevbsabQ+MLnVo+huV7CK4IwEM42US39j5e5KuR+XY1czo9uApdQ7SIYrTu+lXjNjtm/fPucDKhQcPnzYoKDMQd/yRlOUjSb/eZWQ7m+iUNnie/lHlKVdDTpecZTtUPY+XnjMexuF0oqoOWu0szvaPl8V93lvZTuN8/fff8+gQYPo3bs327ZtM9p8MPnt21KQznnq2PiMJ5gL17Eq7Uq5n5ZjV+u/0Z+LWr2DlDNLfC/1/Y7UK8k8LXP3ol5HY0nTLz8c9zHxP/jjPLA7Xl/OMGssliy3SblEahoxG3/k0eebtbM9OrzYCvc5b2JbJaPOIywsjMqVK9O7d2+2bNmS7yGJzEETl8CDflNJOR+YkWB+WIZdbf1aJkmSPgo9yWzYsIEvvviCf//9F4Bq1aoxefJkRo8erc/hzE7fGxiTlMp3f97jbmQivh4O9H+hIq72+j/apodHc7d+X0hLp/zB/+U4G6Ql/tIxJmPdV3XUY6IXbyJ2888Zsz9aK3Ed2QflW4PY9e9jTp/5g6YvNOK1Zn4GvW+mpIlL4MGA6aT8eRUrN5eMBFNIw/OXxGFQCkNR/PsttNZlAB999BFffPEFb7/9tnZY/dOnTzNlyhTu3LnD/Pnz9TlskfTdn/f4/epDVNZWXHsyTMaYNvr/woz7Zg+kpWPXqFau0w3n1fqkqDPWfVW6l8Lzsym4juxN1OxVJB48w1dfruD6iiXUfWU8qS1asi8wEiul0qD3zVQ08YmEZiaYUs6U2/VFoc7/YqxhUCRdxf3v92l6pc41a9bw1VdfsWDBAl5++WVefvllFixYwPr161m7dq2xY7RodyMTUVlbUcXLCXsbK+5G6jdMBIBITydmy88AuOYx2nJxan2SHWPeVwDb5yrhvWMxvw1pyQfxN1GnpfLK778yddUX1L/5N3cjjDsckDFkJpjkpxNM3cLtm2asYVAkXcX97/dpeiUZtVpN48aNs6xv1KgR6enpBgdVlPh6OJCUlvHom5SmwddDv2EiABL2nUT9IBwrj1I4vZxz4wrIaH2Skq4uFq1PsmPM+5ppxYoVTFq2kIlvv83AqXNJdHTC9WE4Q77ZzCurVpNy7aYRIjcOTXwioa+9Q/IfV7BydcpIMLk82ZqKn6cDKekag4dBkXQV97/fp+lVXDZkyBDWrFmTpany+vXrGTx4sFECKyr6v5BRifx03YG+YjY8abY8uCcKu9w/dFW8MoY6ebpMtzgx5n0F+PXXX5k4cSLTp09n0aJFxCan8UPr5rh98xO1Dx2h9JVAgjuMwnlQd0q/P9qs861o4hMJHfguyWf/wsrFCW8zJRgw3jAokq7i/vf7NL0q/t9++222bt1KhQoVaNYso6PgmTNnuH//PsOGDdNpBlpU+szoW6mV0/wYBa3YS70RxP1Ww8DKiornv8OmfJlcj19cmaqiOS0tje+++44evV/l+3P3dZKXQ3gkUfPWkfBzRtN7haM9bpOH4jq2P1b2hTsmlCYhidCB75B8+jJWzo54//AFqgY19T6eqSuYi2IFtqFK2t9kpkKt+L969SoNG2YMt37rVsaIuJ6ennh6enL16lXtdkW9WXN+5DQ/RkEr9mI27gbA8cWW2gST2/GLK2NXNC9cuJA2bdrQvHlzBg8ezFfHbmbboKDs/+aSNKYvUR+tJOXidaI/WU/s1l8o/dFYnHp1LJTPsiYhidDB7/2XYL5fYlCCAdNXMJekCuxMJe1v0lB6JZkjR44YO44i6+n5MW5HJGiHaS/I0OKauATidv4OgMtI3Qr/nI5fXD1d0RwUmaB3RbMQglmzZjF//nwWL16sbQX5dIOCW+HxOg0K7JvWxWffWuJ/PEjUvHWk3w8j/I25xKzfhce8t1E1ft4o15gdTWIyoUPeJ/nkRRRODnh/twRVI8PPZ+oh7i1lCP3CVNL+Jg1VvJ9rC4GPmz1Jaeos82MUpGIv7rv9iIQkbKpWxL6Nbq/0nI5fXBmjolkIwYwZM5g/fz4LFy5k+vTp2tfyalCgsLLC+dUuVDy9Dbf3R6FwUJFy7hoh3cbxcOxc0oIfGnyNz9IkJhM25D2ST1xA4eRAue+WGC2hmbqCuSRVYGcqaX+Thir+BYkmltP8GPmt2Ht6nDLXkVmnVzbX/BvmYoyK5rlz57Jw4UK++OILJk+erPNafhsUWDmoKD1tOC6DexL96VfEffs78T8eJGHvMVzHDcBt0hCsnAxvaaVJSiFs2AySjl9A4WhPuZ2fo3rBeP1QTF3BXJIqsDOVtL9JQxk8rExxYa5hZRKPnye0z2QUjvZUuvITVs6Oee8k5erSpUv88ccfvPHGG0Y7ZsrlG0TOWknyqUsAKL1K52vwzdxoE8zRP1E42OP93efYN61rtJglyZj0/Y6UxWVmFrshY/545/5dZYIxgEajYeXKlSQnJ1O/fn2jJhgAu3rVKbf7S8pu+QQbv/Kow6OJmLKQ4I6jSTx+vuDxJj+TYL5dLBOMVCzJJGNG6SEPSdh3AgDXkbn38JdyplarGTNmDBMnTuTo0aMmO49CocCxexsqnNiK+7y3sHJ1IvXaTUL7TCZ06AxSb+VvSnBNcgphr8/8L8HsWIR983omi1uSzEkmGTOK2fILqNWoWjbAtoZxpmwuadRqNSNGjGDz5s1s3bqVF1980eTnVNjaUGrcACqe3YHLqD6gVJK47wT3Ww0j8sMvUT+K1W4r1GqSTl4k7seDJJ28iCYxiYfDPyTp8FkUDiq8ty/EvkV9k8csSeYi62SeKOw6GZGSyt0Gr6KOeESZDR/nOYyMlJVGo2Hw4MF8//33fPPNN7z22mtmiSP1nzvawTcBrNxccJs+AusypYmctVJnzhuFnS0iJRWFvR3eOxZj37KBWWKWpIIq1M6YkuHi9wSgjniEsqwHjt2KxlzdlkahUPDcc8+xc+dO+vbta7Y4MgffTDzyB1GzV5F6/TZRM5dnu61IyehH4vr2IJlgpBJBFpeZiXacsuGvoLCRub4gUlJSOHToEAqFgrlz55o1wTzNoX0Tyh/egPuiqWCV+wgBcdt+Q6jVhRSZJJmP/HYzkD5jbaX89Q8pf14FG2tchrxUSJEWD8nJybz66qscOXKE27dvU6ZMmbx3MpK8xunKHNMqIcmGBprcS6HVIeEkn/nLZE8zxp5MT9JVUscv04e8KwbSZ6ytzM6XTj3bmnW036ImKSmJXr16cezYMX7++edCTTCQ9zhdmWNa1boXnq/jpT+MMkmcYPzJ9CRdcvyy/JPFZQYq6KRO6kexxP/gD5DRMknKl4SEBHr27MmJEyf47bff6NKlS6HHkNdEU5ljWpXyzV/yM+UPDGNP+ibpenr8MnsbpRy/LBcyyRiooGNtxe3Yi0hOxfb5qqia1CmkKIu+xMRE4uPj2bdvHx06dDBLDHmN05U5ptVFz/LElSpFjgVmClD6eKFqZrrOl6aY9E36jxy/LP9kcZmBCjLWltBoiNmU0cPfdVTWccqkrGJjY0lISMDb25szZ86Y9Z7lNU7X02NaxU0eifPcJ3MpPZ1tnoTvMX+i3sPR5IexJ32TdMnxy/JPJhkD2Sit8HJRkZSqwctFleuETYmHzpJ+5wFWrk449elciFEWTY8fP6Zr165YWVlx6tQpsydlG6VVrnOlONha/1cu36YK8b6liZy5XKefTHwpN4JGD8KjY/NsjxEcHc+cn69yJyqJSu72zHmlNuVLF3zQSVd7W1kHY0I677WUK5lkDFSQCtbYJxX+zgO7Y+UoH69zEx0dTZcuXQgKCsLf39/sCUYfTj3b4titFcln/sL/6DWOPdYQWqUKiWqI+PNetp+TOT9f5VTQY5RASEwKc36+yv9GNCv84CXJSGSdjIHyW8GaFhRC4qGzALgM71WIERY9kZGRdOjQgbt373L48GHtLKxFkUKpxL5lA/6oWZcH1apRuaxLrp+TO1FJKAHvUiqsFRnLklSUySRjoPxWsMZs3g1CYN+hKbZVKhRukEXMmTNniIiI4MiRI9SrVzwGjszv56SSuz3pAkIfJ5MuMpYlqSiTxWUGyk8FqyYxmbjtvwEZFf5S9jLHROrZsyf//PMPjo7FZ+qD/FbEz3mldpY6GUkqyuQAmU+YcoDM2G/2EDFlIdYVvan4xw6TtioqqkJCQujQoQOjRo3i3XffNXc4kiQ9Q05aZqGenl7ZZUQvmWCyce/ePdq2bUtycrLFjEMmSZJxyOIyE0s5d43UK/+iUNniMqiHucOxOHfu3KF9+/YoFAoCAgKoVKmSuUOSJMmI5JOMiWWOtuzUuxPK0q5mjsbyfPLJJ1hbW8sEI0nFlHySMaH08GjifzkCgKscp0yHRqPBysqKFStW8PjxY8qWLWvukCRJMgGZZAwUFpPI4n3XCYpIws/TnnderElZ14zmqXHf7IG0dOwa1cKuXnWTxVDUhh0PDAykX79+7Nixg7p161pUgrGEIfLzmlJAyj95L83Pcr+JiojF+65z5EYU1lYKgqISgessGdAIkZ5OzJafAdM/xRSlYcevXLlCx44dKVu2rEUll0yWMER+XlMKSPkn76X5yZRuoKCIJKytFFQo7YCNUkFQREYP7YR9J1E/CMfKoxROL7c3aQxFZdjxixcv0r59e8qXL8+RI0fw8rK8QQUtYYj8vKYUkPJP3kvzk0nGQH6e9qSpBfejE0lTC/w8M3poZ45T5jK4Jwo70xa3FIVhx9PS0ujTpw+VK1fm0KFDuLtb5mRtljBEfl5TCkj5J++l+cniMgO982JNQLdOJvVGEEnHL4CVVaGMU1YUhh23sbHhhx9+oEqVKri6Wm4rO0sYIj+vKQWk/JP30vxkj/8njNnjP+K9L4jd+COO3VtTdsunRoqwaDpx4gTr1q1j48aN2NjYmDscSZL0JHv8WwhNXAJxO38HwGVkyW62fPToUV588UWCg4NJTZVl4ZJUEskkY2Rx3+1HJCRhU7Ui9m0amTscszl48CDdu3enefPm/Pbbb8VqsEtJkvJPJhkjenqcMteRJXd65evXr9OzZ0/atm3LL7/8goODnF9ekkoqWfFvBEKtJvnMXyQeO0faP3fBXoXTgBfNHZbZ1KhRg5UrVzJ06FDs7OzMHY4kSWYkk4yB4vcEZJnHXaGApGPncerZ1oyRFb7du3ej0Wjo06cPo0ePNnc4kiRZAFlcZoD4PQE8HPmhToIBEInJPBz5IfF7AswUWeH7/vvv6devHz/99JO5Q5EkyYJYbJJZvXo1fn5+qFQqGjVqxPHjx3Pc9scff6Rz5854enri4uJC8+bN2b9/v0njE2o1kTOXQy4NwCM//BKhVps0Dkuwfft2XnvtNQYMGMCmTZvMHY4kSRbEIpPMzp07mTx5MjNnzuTixYu0bt2abt26ce/evWy3P3bsGJ07d2bv3r2cP3+e9u3b89JLL3Hx4kWTxZh85q8sTzA6BKhDwkk+85fJYrAEu3btYujQoQwbNowtW7ZgbS1LYCVJ+o9FdsZs2rQpDRs2ZM2aNdp1NWvWpFevXixYsCBfx3j++ecZMGAAs2bNytf2Be1oFPfjQcLHzs1zO691s3Hu0ylfMRRF9+/fZ/369cydOxcrK4v8zSJJkhEUm86YqampnD9/ni5duuis79KlC6dOncrXMTQaDXFxcZQuXTrHbVJSUoiNjdX5VxDWZfI39lZ+tytqvv32WyIjI6lQoQLz5s2TCUaSpGxZ3DdDZGQkarWaMmXK6KwvU6YMYWFh+TrGkiVLSEhIoH///jlus2DBAlxdXbX/KlSoUKA4Vc3qoiznCTl1hVGA0scLVbO6BTpuUbBixQoGDhzIli1bzB2KJEkWzuKSTKZnOzIKIfLVuXHHjh3MmTOHnTt35jqU/IwZM4iJidH+u3//fsHiUyrx+GTSk4VnX8z4j8f8iSiUygId19ItXbqUiRMnMn36dKZOnWrucCRJsnAWl2Q8PDxQKpVZnlrCw8OzPN08a+fOnYwaNYrvvvuOTp1yrwexs7PDxcVF519BOfVsS5mN81F6e+qsV5bzoszG+cWun8zChQuZNm0aH3zwAYsWLSqxIxpIkpR/FtcUyNbWlkaNGuHv70/v3r216/39/XnllVdy3G/Hjh2MHDmSHTt20KNHj8IIFchINI7dWpF85i/SH0ZhXcYdVbO6xe4JBsDd3Z3Zs2cze/ZsmWAkScoXi0syAFOnTmXo0KE0btyY5s2bs379eu7du8e4ceOAjKKukJAQtm7dCmQkmGHDhrF8+XKaNWumfQqyt7cvlLlLFEol9i0bmPw85iCE4NixY7Rt21b24pckqcAsrrgMYMCAASxbtoyPP/6Y+vXrc+zYMfbu3Yuvry8AoaGhOn1m1q1bR3p6OhMmTMDb21v7b9KkSea6hGJBCMEHH3xAu3btOHv2rLnDkSSpCLLIfjLmYMxJy4oDIQTTp09n6dKlfPHFF0yePNncIUmSZEb6fkdaZHGZZF5CCCZOnMjKlStZuXIlEyZMMHdIkiQVUTLJSFmkpKRw9epV1q1bxxtvvGHucCRJKsJkkpG01Go1ISEhVKxYkYMHD6Ishi3kJEkqXBZZ8S8VPrVazYgRI2jevDkJCQkywUiSZBTySUYiPT2doUOH8v333/PNN9/g6Oho7pAkSSomZJIp4dLS0hg4cCA///wzO3fupG/fvuYOSZKkYkQmmRIuMDCQQ4cOsWvXrlxHVJAkSdKHTDIlVHJyMtbW1tSrV487d+4UysgIkiSVPLLivwRKSkrilVde0Q7TIxOMJEmmIp9kSpiEhARefvllzpw5w549e8wdjiRJxZxMMiVIXFwcPXv25MKFC+zbt4/WrVubOyRJkoo5mWRKkPXr13Px4kX2799PixYtzB2OJEklgBwg84niPEBm5qyiGo2GW7duUa1aNXOHJElSEaPvd6Ss+C/moqOjadWqFQcOHMDKykomGEmSCpUsLivGIiMj6dSpEyEhIXlOXS1JkmQKMskUU+Hh4XTs2JHw8HCOHDlC7dq1zR2SJEklkEwyxdSwYcOIiooiICCAGjVqmDscSZJKKJlkiqlVq1ahVqt57rnnzB2KJEklmKz4L0bu3bvHwIEDiYmJoUqVKjLBSJJkdvJJppgICgqiffv2WFlZERMTI4eKkSTJIsgnmWLg5s2btG3bFhsbGwICAqhYsaK5Q5IkSQJkkinyYmNjadeuHQ4ODhw9epQKFSqYOyRJkiQtmWSKOBcXFz755BOOHj2Kj4+PucORJEnSIZNMEXXlyhXWrVsHwOuvv07ZsmXNHJEkSVJWMskUQZcuXaJ9+/asW7eO1NRUc4cjSZKUI5lkiphz587RoUMHKleuzKFDh7C1tTV3SJIkSTmSSaYIuXTpEp06daJGjRr4+/vj5uZm7pAkSZJyJZNMEeLn58eQIUPYv3+/7AcjSVKRIDtjFgHHjh2jXLlyVK1alZUrV5o7HEmSpHyTTzIW7uDBg7z44ovMnz/f3KFIkiQVmEwyFmzfvn289NJLtGvXjrVr15o7HEmSpAKTxWUG6rV8L5dC/5vBur63gt2Tuhf4OGlqDbfC43mUmIqbgy1//3GUAf370bVrV77//nvs7OyMGbZkoMTUdAJuhBPyKAkfN3vaVvfCwbbgf07GOo5UeJ79W63i5YSNUv5ez4n8NBvo6QST3XJ+3QqP53JwDLZKK+5GJRKfmE7fvn3ZvHmzbKZsgQJuhHPwejgqayWBoXEAdKtTzmzHkQrPs3+rADW88z/nfUkjk4yZZf4qOvFvBClqDTZRt3GvVIvnm7Rh9MA+ee4fk5TKd3/e425kIr4eDvR/oSKu9jIpFYQ+TxMhj5JQWSup7OnI7YgEQh4l6XVuYx1HKjyPElOxVVpR1lVFWEwyjxJlh+jcyGc8M8v8VZSSLvjtx12MG9Ad/5934uaQv0Tx3Z/3+P3qQ25HJrD3ykO++/OeiSMufjKfJoIiE/EPDCfgRnie+/i42ZOUpuZ2RAJJaWp83Oz1OrexjiMVHjcHW1LS1YTFJJOSrs7332pJJZ9kzCzzV9GDc/vYt/pDmnXpxdtjR1HFyylf+9+NTERlbUUVLyduhcdzNzLRxBEXP/o8TbSt7qXdN/PpRx/GOo5UeDL/Np+uk5FyJpOMmbk52LJ500a++vR92r08gJWr1lDLp1S+9/f1cOBaaBy3wuNJStPg6+FgumCLKR83ewJD4wr0NOFga22UuhNjHUcqPDZKK1kHUwAyyZhZZU9HQq+fp/eg4cxbuJRqZQv24e3/QkXS0gWnb0Xh7mhNaXsbElPTZQulAnj6acLDyZbUNA3/O3YrX/UzxqgTi4xPZt3RmwRFJOLn6cCo1pV5nJguWy8VAbKlWd7kN5EZPXjwgHLlyvHjzm0olUoUCkWBj+Fqb4ufpyO3IhNQWSs5efsRDiobvX8dl8Q/mqefJn6/8iDH1l7ZNRDIrBNTWVtx9UEs4bHJdKhZtkD3bt3Rm+wPDMdWacWN8HiiE1JoUtlLtl4yk4I0BJEtzfJWvL89LNjSpUt57rnnuHXrFtbW1nolmExP1ynY2ygNaqGU+UcTFpPCpfuPuRUer/exiqLc7mV2DQSerhMTQnAlJLbA9y4oIhFbpRV+Ho7YWVsRFJGkbb1kZ62UrZcKWUEagjzd0ky+V9mTScYMPvvsM6ZNm8bEiROpXLmywcczZgulkv5Hk9u9zC4B+Xo4kJSW8fQXn6qmjItdge+dn6cDKekagiITSEnX4OdpL1svmVFBfrTJlmZ5k8VlhWzevHnMmjWL2bNnM3v2bIOeYDIZs4WSm4Mtd6MSS+wfTW73MrsGAi2qegAZrfyeL+dMOTfHAt+7se2qAuRaJyMVnoI0BJEtzfKmEELo10W9mImNjcXV1ZWYmBhcXPJfpjp03WGOB/33S6e1nz1fj+2Q7baRkZHUrl2bt99+m5kzZxocsymUxDqZ/MqrrF7eu+JBDvWTPX2/I2WSeULfG5gfQghSUlJQqVRERUXh7u5u1ONLkiSZmr7fkfJnlokJIZg+fTrdunVDrVbLBCNJUokik4wJCSGYOHEiS5cu5dVXX0WpVJo7JEmSpEIlCxpNRKPRMH78eNatW8e6det44403zB2SJElSoZNJxkT27dvH+vXr2bhxIyNGjDB3OJIkSWYhk4yRCSFQKBR0796dy5cvU6dOHXOHJEmSZDayTsaI0tPTGTx4MBs2bACQCUaSpBLPYpPM6tWr8fPzQ6VS0ahRI44fP57r9gEBATRq1AiVSkXlypVZu3ZtIUWaIS0tjddee43vv/+eUqVKFeq5JUmSLJVFJpmdO3cyefJkZs6cycWLF2ndujXdunXj3r3sJ+QKCgqie/futG7dmosXL/LBBx8wceJEfvjhh0KJNyUlhX79+vHLL7/www8/0Ldv30I5ryRJkqWzyM6YTZs2pWHDhqxZs0a7rmbNmvTq1YsFCxZk2f69997jl19+4fr169p148aN4/Lly5w+fTrbc6SkpJCSkqJdjo2NpUKFCnp1xpw2bRqrVq3ixx9/pHv37gXaV5IkqSgoNp0xU1NTOX/+PF26dNFZ36VLF06dOpXtPqdPn86yfdeuXTl37hxpaWnZ7rNgwQJcXV21/ypUqKB3zDNmzGD//v0ywUiSJD3D4pJMZGQkarWaMmXK6KwvU6YMYWFh2e4TFhaW7fbp6elERkZmu8+MGTOIiYnR/rt//77eMXt4eNC2bVu995ckSSquLLYJ87OjE2c2DS7I9tmtz2RnZ4ednZ2BUUqSJEm5sbgnGQ8PD5RKZZanlvDw8CxPK5nKli2b7fbW1tZyrDBJkiQzsrgkY2trS6NGjfD399dZ7+/vT4sWLbLdp3nz5lm2P3DgAI0bN8bGxsZksUqSJEm5s7gkAzB16lT+97//sXHjRq5fv86UKVO4d+8e48aNAzLqU4YNG6bdfty4cdy9e5epU6dy/fp1Nm7cyIYNG5g+fbq5LkGSJEnCQutkBgwYQFRUFB9//DGhoaHUrl2bvXv34uvrC0BoaKhOnxk/Pz/27t3LlClTWLVqFeXKlePLL7+U/VUkSZLMzCL7yZiDKSctkyRJKuqKTT8ZSZIkqfiQSUaSJEkyGZlkJEmSJJORSUaSJEkyGZlkJEmSJJORSUaSJEkyGYvsJ2MOmS25Y2NjzRyJJEmS5cn8bixorxeZZJ6Ii4sDMGjIf0mSpOIuLi4OV1fXfG8vO2M+odFoePDgAc7OzrmO9pydzAnP7t+/X2I6cpbEa4aSed3ymkvGNUPu1y2EIC4ujnLlymFllf+aFvkk84SVlRXly5c36BguLi4l6gMJJfOaoWRet7zmkiOn6y7IE0wmWfEvSZIkmYxMMpIkSZLJyCRjBHZ2dsyePbtEzbRZEq8ZSuZ1y2suOUxx3bLiX5IkSTIZ+SQjSZIkmYxMMpIkSZLJyCQjSZIkmYxMMpIkSZLJyCSTD6tXr8bPzw+VSkWjRo04fvx4rtsHBATQqFEjVCoVlStXZu3atYUUqXEV5Lp//PFHOnfujKenJy4uLjRv3pz9+/cXYrTGUdD3OtPJkyextramfv36pg3QRAp63SkpKcycORNfX1/s7OyoUqUKGzduLKRojaOg17xt2zbq1auHg4MD3t7ejBgxgqioqEKK1nDHjh3jpZdeoly5cigUCnbv3p3nPkb5LhNSrr799lthY2MjvvrqKxEYGCgmTZokHB0dxd27d7Pd/vbt28LBwUFMmjRJBAYGiq+++krY2NiIXbt2FXLkhinodU+aNEksXLhQ/PHHH+Kff/4RM2bMEDY2NuLChQuFHLn+CnrNmR4/fiwqV64sunTpIurVq1c4wRqRPtf98ssvi6ZNmwp/f38RFBQkzp49K06ePFmIURumoNd8/PhxYWVlJZYvXy5u374tjh8/Lp5//nnRq1evQo5cf3v37hUzZ84UP/zwgwDETz/9lOv2xvouk0kmD02aNBHjxo3TWVejRg3x/vvvZ7v9u+++K2rUqKGzbuzYsaJZs2Ymi9EUCnrd2alVq5aYO3eusUMzGX2vecCAAeLDDz8Us2fPLpJJpqDX/fvvvwtXV1cRFRVVGOGZREGvefHixaJy5co667788ktRvnx5k8VoSvlJMsb6LpPFZblITU3l/PnzdOnSRWd9ly5dOHXqVLb7nD59Osv2Xbt25dy5c6SlpZksVmPS57qfpdFoiIuLo3Tp0qYI0ej0veZNmzZx69YtZs+ebeoQTUKf6/7ll19o3LgxixYtwsfHh+eee47p06eTlJRUGCEbTJ9rbtGiBcHBwezduxchBA8fPmTXrl306NGjMEI2C2N9l8kBMnMRGRmJWq2mTJkyOuvLlClDWFhYtvuEhYVlu316ejqRkZF4e3ubLF5j0ee6n7VkyRISEhLo37+/KUI0On2u+d9//+X999/n+PHjWFsXzT8lfa779u3bnDhxApVKxU8//URkZCTjx48nOjq6SNTL6HPNLVq0YNu2bQwYMIDk5GTS09N5+eWXWbFiRWGEbBbG+i6TTzL58OzQ/0KIXKcDyG777NZbuoJed6YdO3YwZ84cdu7ciZeXl6nCM4n8XrNarWbQoEHMnTuX5557rrDCM5mCvNcajQaFQsG2bdto0qQJ3bt3Z+nSpWzevLnIPM1Awa45MDCQiRMnMmvWLM6fP8++ffsICgpi3LhxhRGq2Rjju6xo/vwqJB4eHiiVyiy/bsLDw7Nk+Exly5bNdntra2vc3d1NFqsx6XPdmXbu3MmoUaP4/vvv6dSpkynDNKqCXnNcXBznzp3j4sWLvPXWW0DGl68QAmtraw4cOECHDh0KJXZD6PNee3t74+PjozPse82aNRFCEBwcTLVq1Uwas6H0ueYFCxbQsmVL3nnnHQDq1q2Lo6MjrVu3Zv78+UWihKKgjPVdJp9kcmFra0ujRo3w9/fXWe/v70+LFi2y3ad58+ZZtj9w4ACNGzfGxsbGZLEakz7XDRlPMMOHD2f79u1Frqy6oNfs4uLClStXuHTpkvbfuHHjqF69OpcuXaJp06aFFbpB9HmvW7ZsyYMHD4iPj9eu++eff4wyJ1Nh0OeaExMTs0zUpVQqgYJPR1xUGO27rEDNBEqgzKaOGzZsEIGBgWLy5MnC0dFR3LlzRwghxPvvvy+GDh2q3T6z2d+UKVNEYGCg2LBhQ5Fuwpzf696+fbuwtrYWq1atEqGhodp/jx8/NtclFFhBr/lZRbV1WUGvOy4uTpQvX168+uqr4tq1ayIgIEBUq1ZNjB492lyXUGAFveZNmzYJa2trsXr1anHr1i1x4sQJ0bhxY9GkSRNzXUKBxcXFiYsXL4qLFy8KQCxdulRcvHhR22zbVN9lMsnkw6pVq4Svr6+wtbUVDRs2FAEBAdrXXn/9ddG2bVud7Y8ePSoaNGggbG1tRaVKlcSaNWsKOWLjKMh1t23bVgBZ/r3++uuFH7gBCvpeP62oJhkhCn7d169fF506dRL29vaifPnyYurUqSIxMbGQozZMQa/5yy+/FLVq1RL29vbC29tbDB48WAQHBxdy1Po7cuRIrn+jpvouk0P9S5IkSSYj62QkSZIkk5FJRpIkSTIZmWQkSZIkk5FJRpIkSTIZmWQkSZIkk5FJRpIkSTIZmWQkSZIkk5FJRpIkSTIZmWQkSZIkk5FJRpIkSTIZmWSkYqVdu3ZMnjzZ3GHkyNLjK6hnr6e4XZ9kODmfjGRU7dq1o379+ixbtsws5//xxx8tZkqF7O6FJcVnCk9fn7k/C5JlkElGKnSpqanY2tqa5NilS5c2yXELIrfrs4T4TKm4X59UcLK4TDKa4cOHExAQwPLly1EoFCgUCu7cuUO7du146623mDp1Kh4eHnTu3BmASpUqZfmVW79+febMmQNkTAa1aNEiKleujL29PfXq1WPXrl25xpBd8c3EiRN59913KV26NGXLltUeP9OuXbuoU6cO9vb2uLu706lTJxISEvIdQ3bXl9u9eDq+lJQUJk6ciJeXFyqVilatWvHnn38WKP7sJCQkMGzYMJycnPD29mbJkiU6587r3gPs27ePVq1aUapUKdzd3enZsye3bt3K9byZ58jp+rdu3Yq7uzspKSk6+/Xt25dhw4bleV2ZunXrxuuvv65dPnz4MO7u7qSnp+f7GFLhkElGMprly5fTvHlzxowZQ2hoKKGhoVSoUAGALVu2YG1tzcmTJ1m3bl2+jvfhhx+yadMm1qxZw7Vr15gyZQpDhgwhICCgQHFt2bIFR0dHzp49y6JFi/j444+1M/6FhoYycOBARo4cyfXr1zl69Ch9+vTRznaY3xievb7c7sXT3n33XX744Qe2bNnChQsXqFq1Kl27diU6Ojpf8efknXfe4ciRI/z0008cOHCAo0ePcv78+QLdt4SEBKZOncqff/7JoUOHsLKyonfv3mg0mjz3zen6+/Xrh1qt5pdfftFuGxkZyZ49exgxYkS+Y/Px8SEkJES73L59e1JSUjh58mSBrlEyPVlcJhmNq6srtra2ODg4ULZsWZ3XqlatyqJFi/J9rISEBJYuXcrhw4dp3rw5AJUrV+bEiROsW7eOtm3b5vtYdevWZfbs2QBUq1aNlStXcujQITp37kxoaCjp6en06dMHX19fAOrUqVPgGLK7vpzuxdPXuGbNGjZv3ky3bt0A+Oqrr/D392fDhg0688nnFH924uPj2bBhA1u3btVus2XLlgJPjdy3b1+d5Q0bNuDl5UVgYCC1a9fOdd+cPgv29vYMGjSITZs20a9fPwC2bdtG+fLladeuXb5j8/Hx4cSJE9plhUKBSqUiIiIi38eQCodMMlKhaNy4cYG2DwwMJDk5OcsXaWpqKg0aNCjQserWrauz7O3tTXh4OAD16tWjY8eO1KlTh65du9KlSxdeffVV3NzcChRDQa8P4NatW6SlpdGyZUvtOhsbG5o0acL169fzFX9Ox01NTdUmRsioK6levXqB4/voo484c+YMkZGR2ieYe/fu5ZlkcjNmzBheeOEFQkJC8PHxYdOmTQwfPhyFQpHvYzz7JHPp0iUePXpEixYt9I5LMg2ZZKRC4ejomGWdlZUVz07MmpaWBqD9Qvvtt9/w8fHR2cbOzq5A5362NZdCodAeX6lU4u/vz6lTpzhw4AArVqxg5syZnD17tkAxZHd9ecm89me/XIUQOutyiz+34+Ymt3uf6aWXXqJChQp89dVXlCtXDo1GQ+3atUlNTc3z+Llp0KAB9erVY+vWrXTt2pUrV67w66+/FugYPj4+xMfHExsbi5OTE1OmTGHw4MGUK1fOoNgk45N1MpJR2draolar87Wtp6cnoaGh2uXY2FiCgoIAqFWrFnZ2dty7d4+qVavq/MuubsMQCoWCli1bMnfuXC5evIitrS0//fSTwTHkdS+qVq2Kra2tTrFPWloa586do2bNmnpfT9WqVbGxseHMmTPadY8ePeKff/7RLud27wGioqK4fv06H374IR07dqRmzZo8evSoQHHkdv2jR49m06ZNbNy4kU6dOhX4Pc1M+sHBwbz33nuEhoaycuXKbLe9f/8+Fy5cKNDxJeORTzKSUVWqVImzZ89y584dnJyccm3S2qFDBzZv3sxLL72Em5sbH330EUqlEgBnZ2emT5/OlClT0Gg0tGrVitjYWE6dOoWTk5NOyyJDnD17lkOHDtGlSxe8vLw4e/YsERER1KxZ0+AY8roXjo6OvPnmm7zzzjuULl2aihUrsmjRIhITExk1apTe1+Tk5MSoUaN45513cHd3p0yZMsycORMrq/9+U+Z27wHc3Nxwd3dn/fr1eHt7c+/ePd5///0CxZHd9WfGMHjwYKZPn85XX33F1q1bC3yNmUlm2rRp3Lhxg+PHj+Pi4pJlu+PHjzNr1iySkpL46KOP6NGjR4HPJRlGJhnJqKZPn87rr79OrVq1SEpK0vl1/KwZM2Zw+/ZtevbsiaurK/PmzdPZft68eXh5ebFgwQJu375NqVKlaNiwIR988IHR4nVxceHYsWMsW7aM2NhYfH19WbJkibYi3pAY8nMvPvvsMzQaDUOHDiUuLo7GjRuzf/9+3NzcDLquxYsXEx8fz8svv4yzszPTpk0jJiZG+3pe997Kyopvv/2WiRMnUrt2bapXr86XX35ZoMr57K6/UqVKQMZ979u3L7/99hu9evUq8PV5eHhgZ2fH3bt3OXbsWJbizEytW7cmPT2dpKQkmWDMRCHyU4ArSVKRZ2k98Dt37kzNmjX58ssvTXYOtVrN7du3sbe3x83NTa+6M8kw8klGkqRCFR0dzYEDBzh8+HCO9SjGolQqqVatmknPIeVOJhlJkgpVw4YNefToEQsXLixws2qp6JHFZZIkSZLJyCbMkiRJksnIJCNJkiSZjEwykiRJksnIJCNJkiSZjEwykiRJksnIJCNJkiSZjEwykiRJksnIJCNJkiSZjEwykiRJksnIJCNJkiSZzP8Bin6Tz/o0ev4AAAAASUVORK5CYII=",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def make_insert_example():\n",
+ " n = np.random.randint(6, 11)\n",
+ " x1 = sample_seq(n)\n",
+ " i = np.random.randint(1, n - 1)\n",
+ " gapsize = int(np.random.choice([0, 0, 1, 2])) # often empty -> nu* = 0\n",
+ " left = x1[i - 1]\n",
+ " rj = i + gapsize\n",
+ " right = x1[rj] if rj < n else PAD\n",
+ " gap = [x1[j] for j in range(i, rj)]\n",
+ " y = [left, MASK, right]\n",
+ " p = posterior_at(y, 1)\n",
+ " nu = float(sum(p[v] for v in set(gap)))\n",
+ " return window_feat(y, 1), torch.tensor([nu]), nu\n",
+ "\n",
+ "# IQL's target nu_star is a *soft* value in [0,1], so BCE(nu_star, nu_phi) bottoms out at E[H(nu_star)] rather than 0\n",
+ "# same idea as UQL, again reported with the floor + EMA\n",
+ "\n",
+ "_, _, nu_pool = zip(*[make_insert_example() for _ in range(4000)])\n",
+ "iql_floor = binary_entropy(nu_pool)\n",
+ "\n",
+ "nu_phi = QualityHead(din)\n",
+ "opt = torch.optim.Adam(nu_phi.parameters(), lr=2e-3)\n",
+ "run = None\n",
+ "for step in range(2000):\n",
+ " feats, soft, _ = zip(*[make_insert_example() for _ in range(64)])\n",
+ " loss = F.binary_cross_entropy(nu_phi(torch.stack(feats)), torch.stack(soft))\n",
+ " opt.zero_grad(); loss.backward(); opt.step()\n",
+ " run = loss.item() if run is None else 0.98 * run + 0.02 * loss.item()\n",
+ " if (step + 1) % 500 == 0:\n",
+ " print(f\" step {step+1:4d} IQL running avg = {run:.4f}\")\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " fs, _, qs = zip(*[make_insert_example() for _ in range(800)])\n",
+ " pred = nu_phi(torch.stack(fs)).squeeze(-1).numpy()\n",
+ "qs = np.array(qs)\n",
+ "print(f\"corr(nu_phi, nu_star) = {np.corrcoef(pred, qs)[0,1]:.3f}\")\n",
+ "\n",
+ "plt.figure(figsize=(4.2, 4.2))\n",
+ "plt.scatter(qs, pred, s=6, alpha=0.25)\n",
+ "bins = np.linspace(0, qs.max() + 1e-6, 9)\n",
+ "idx = np.digitize(qs, bins) - 1\n",
+ "bx = [qs[idx == b].mean() for b in range(len(bins)) if (idx == b).any()]\n",
+ "by = [pred[idx == b].mean() for b in range(len(bins)) if (idx == b).any()]\n",
+ "plt.plot(bx, by, \"o-\", color=\"crimson\", label=\"binned mean\")\n",
+ "lim = max(qs.max(), pred.max())\n",
+ "plt.plot([0, lim], [0, lim], \"k--\", lw=1, label=\"ideal\")\n",
+ "plt.xlabel(r\"true insertion quality $\\nu_\\star$\")\n",
+ "plt.ylabel(r\"predicted $\\nu_\\phi$\")\n",
+ "plt.title(\"IQL recovers the insertion quality\")\n",
+ "plt.legend(fontsize=8); plt.tight_layout(); plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Quality-guided inference with A2D2\n",
+ "\n",
+ "The A2D2 sampler (`_diffusion_loop`) alternates *unmask* and *insert* Euler steps, and after each one calls the two schedule-aware routines we vendored in §0 to **re-mask low-quality tokens** and **drop low-quality insertions**, keeping the masked/clean counts on the interpolant schedule. We exercise the actual functions with a minimal toy model.\n",
+ "\n",
+ "The toy example only needs `model.planner(seq, t)` returning per-position confidences, and `model.interpolant.{unmask,insertion}_schedule.at(t)` (linear schedule here)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class LinearSchedule:\n",
+ " def at(self, t):\n",
+ " return t.clone() if torch.is_tensor(t) else torch.tensor(float(t))\n",
+ "\n",
+ "class ToyInterpolant:\n",
+ " unmask_schedule = LinearSchedule()\n",
+ " insertion_schedule = LinearSchedule()\n",
+ "\n",
+ "class ToyPlanner:\n",
+ " \"\"\"Stand-in for the trained planner. We pass confidences in directly so we can script\n",
+ " exactly which token / insertion is low-quality; in A2D2 these come from mu_phi / nu_phi.\"\"\"\n",
+ " def __init__(self, remask_conf=None, insert_conf=None):\n",
+ " self.remask_conf, self.insert_conf = remask_conf, insert_conf\n",
+ " def __call__(self, seq, t):\n",
+ " B, L = seq.shape\n",
+ " out = {}\n",
+ " if self.remask_conf is not None:\n",
+ " out[\"remasking_conf\"] = self.remask_conf.view(B, L, 1)\n",
+ " if self.insert_conf is not None:\n",
+ " out[\"insertion_conf\"] = self.insert_conf.view(B, L, 1)\n",
+ " return out\n",
+ "\n",
+ "class ToyModel:\n",
+ " def __init__(self, planner):\n",
+ " self.planner = planner\n",
+ " self.interpolant = ToyInterpolant()\n",
+ "\n",
+ "neg_inf = torch.tensor(-np.inf)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.1 Schedule-aware **re-masking** (low unmasking quality → re-masked)\n",
+ "\n",
+ "Four clean tokens `ABCD` and two masks. Position 1 (`B`) is given a low unmasking-quality confidence (0.10) while the rest are confident. With the schedule asking for one more mask at this step, the routine re-masks exactly the lowest-quality token."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "before: ABCD__\n",
+ "after : A_CD__ <- the low-quality 'B' was re-masked\n"
+ ]
+ }
+ ],
+ "source": [
+ "new_xt = torch.tensor([[0, 1, 2, 3, MASK, MASK]]) # ABCD__\n",
+ "conf = torch.tensor([[0.9, 0.1, 0.95, 0.92, 0.0, 0.0]]) # pos 1 = low quality\n",
+ "clean_index = (new_xt != MASK) & (new_xt != PAD)\n",
+ "model = ToyModel(ToyPlanner()) # remasking_conf is passed in directly\n",
+ "t, dt = torch.tensor([0.4]), 0.1\n",
+ "\n",
+ "out = apply_schedule_aware_remasking(model, new_xt.clone(), t, dt, conf, clean_index,\n",
+ " MASK, neg_inf, batch_size=1)\n",
+ "print(\"before:\", decode(new_xt[0].tolist()))\n",
+ "print(\"after :\", decode(out[0].tolist()), \" <- the low-quality 'B' was re-masked\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.2 Schedule-aware **insertion deletion** (low insertion quality → dropped)\n",
+ "\n",
+ "State `A_BC_` after inserting two masks (between `A` & `B`, and after `C`). The first inserted mask gets a low insertion-quality confidence (0.05); with `quality_threshold=0.5` the routine drops it and compacts the sequence, keeping the confident insertion."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "before: A_BC_.\n",
+ "after : ABC_.. <- low-quality insertion dropped, kept the good one\n"
+ ]
+ }
+ ],
+ "source": [
+ "new_xt = torch.tensor([[0, 1, 2, PAD, PAD, PAD]]) # ABC (3 clean)\n",
+ "xt_tmp = torch.tensor([[0, MASK, 1, 2, MASK, PAD]]) # A_BC_ after inserts\n",
+ "ext = torch.tensor([[0, 1, 0, 1, 0, 0]]) # 1 insert in gaps 1 and 3\n",
+ "orig_mask = torch.tensor([[True, False, True, True, False, False]])\n",
+ "new_pos_orig = torch.tensor([[0, 2, 3, 3, 4, 5]]) # originals shifted right\n",
+ "insert_conf = torch.tensor([[0.9, 0.05, 0.9, 0.9, 0.95, 0.0]]) # inserted mask @pos1 = low\n",
+ "model = ToyModel(ToyPlanner(insert_conf=insert_conf))\n",
+ "t, dt = torch.tensor([0.0]), 0.05\n",
+ "\n",
+ "out = apply_schedule_aware_insertion(model, xt_tmp.clone(), new_xt, t, dt, ext,\n",
+ " MASK, PAD, max_length=6, orig_mask=orig_mask,\n",
+ " new_pos_orig=new_pos_orig, quality_threshold=0.5)\n",
+ "\n",
+ "print(\"before:\", decode(xt_tmp[0].tolist()))\n",
+ "print(\"after :\", decode(out[0].tolist()), \" <- low-quality insertion dropped, kept the good one\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Putting it together: quality inference catches decoding errors\n",
+ "\n",
+ "The point of both predictors is to flag mistakes *without* access to $\\boldsymbol{x}_1$. Here we corrupt one position of a clean target (decode it to the wrong letter) and let the trained $\\mu_\\phi$ score every clean token. The corrupted position gets the **lowest** predicted unmasking quality — so A2D2's schedule-aware re-masking would target it for a re-do, exactly as intended."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "target x1: BCDECDEAB\n",
+ "decoded y: BCDEEDEAB (position 4 corrupted: C -> E)\n",
+ "\n",
+ " predicted unmasking quality mu_phi per position:\n",
+ " pos 0 (B): 0.666\n",
+ " pos 1 (C): 0.948\n",
+ " pos 2 (D): 0.920\n",
+ " pos 3 (E): 0.339\n",
+ " pos 4 (E): 0.041 *corrupted* <-- lowest, would be re-masked\n",
+ " pos 5 (D): 0.595\n",
+ " pos 6 (E): 0.918\n",
+ " pos 7 (A): 0.957\n",
+ " pos 8 (B): 0.745\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAk4AAAD5CAYAAADRCboRAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAMq9JREFUeJzt3XlcVGX/P/7XwMAMDIssAkqsaoohYOCG4pqoqHd2p9ltBSpW3OCK9jHjdslMUssPZoKVIi1kpLeWCy6oSZr0SU2tlBZXSEEUVJAUHbi+f/ib+TnMIEdZZoDX8/E4f8w11znnfZ1t3nOdTSaEECAiIiKiWpkZOwAiIiKipoKJExEREZFETJyIiIiIJGLiRERERCQREyciIiIiiZg4EREREUnExImIiIhIIiZORERERBIxcSIiIiKSiIlTM7FgwQLIZDJcvXrV4Pf+/v7o379/4wbVAGprp8b48ePh7e3dOEGZkLS0NMhkMpw/f15bZmhZLF68GF9//XW9z//YsWPo168f7O3tIZPJkJSUhP3790Mmk2H//v31Pr/6kJycjLS0NL1yTdwbN25s/KCagVOnTmHBggU622J90RwHavPFF18gKSmpTvOSyWSYPHlynaZRn+7evYtOnTrhnXfeaZT59e/fX+e3448//oClpSV++umnRpm/KWLiRM3S3LlzsXnzZmOHYRIMLYuGSpwmTpyIgoICfPnll8jJycHzzz9f7/OobzUlTlQ3p06dwptvvtkgiZNU9ZE4mZrk5GRcu3YNU6ZMMcr8H3/8cbzwwguYMWOGUeZvCuTGDoCoIbRr187YIZiMxlwWv/76K15++WUMGzZMW/bbb7812vxJnxACt2/fhpWVld53t27dglKplNR7Q8anVquxbNkyTJw4ESqV6oF1//77b1hbWzdIHJMnT0ZISAgOHTqE0NDQBpmHKWOPUwulOQ2xfv16JCQkoG3btrCzs8NTTz2F33//Xadu//794e/vj5ycHISGhsLKygre3t5Yt24dAGD79u148sknYW1tjS5dumDnzp06458+fRoTJkxAhw4dYG1tDXd3d4wcORK//PKLTr2qqiosWrQIHTt2hJWVFVq1aoWAgACsWLHigW357bff4Ovrix49eqCoqAiA4dNTmi73zz77DH5+frC2tkZgYCC2bdumN81vvvkGAQEBUCgU8PX1xYoVKySfHhBCYOnSpfDy8oJSqcSTTz6JHTt26HV5GzqtBsDgqa2srCw8/fTTeOyxx6BUKtG+fXu8+uqrtZ6yNLQsZDIZysvL8cknn0Amk0Emk6F///44f/485HI5EhMT9abx3XffQSaTYcOGDQbnoWmLWq1GSkqKdro1OXLkCJ5//nl4e3trt6d//etfuHDhgl7dgwcPolevXlAqlXB3d8fcuXOxZs0avWW3b98+9O/fH05OTrCysoKnpyeeffZZ/P333zXG4e3tjZMnTyI7O1sbc/Xt5u7du7XuIwCwZ88eDBo0CHZ2drC2tkbv3r2xd+/eGud9v+vXr2PmzJnw9fWFQqGAi4sLIiIidJLOkpISxMbGwt3dHZaWlvD19UVCQgIqKip0pqXZzlevXg0/Pz8oFAp88skn2nW0e/duTJw4Ea1bt4a1tTUqKipqPLVtaJvXTP/DDz/E448/DoVCgc6dO+PLL7/U1klLS8OYMWMAAAMGDNAu2/t79qQur+3btyMoKAgKhQI+Pj549913JS3T/v37Y/v27bhw4YJ2/ve3ReryrE4IgTfeeAMWFhb4+OOPteUZGRno1asXVCoVbGxsMGTIEBw7dkxn3PHjx8PGxganT59GREQEbGxs4OHhgZkzZ9Y6XwDYsmULLl68iJdeekmnXLOefvrpJ4wePRoODg7aP0xCCCQnJyMoKAhWVlZwcHDA6NGjcfbsWb12GTpuGRIcHAw/Pz+sXr261pibJUHNwvz58wUAceXKFYPfP/HEE6Jfv37az99++60AILy9vcULL7wgtm/fLtavXy88PT1Fhw4dhFqt1tbt16+fcHJyEh07dhRr164Vu3btEiNGjBAAxJtvvim6dOki1q9fLzIzM0XPnj2FQqEQFy9e1I6fnZ0tZs6cKTZu3Ciys7PF5s2bxahRo4SVlZX47bfftPUSExOFubm5mD9/vti7d6/YuXOnSEpKEgsWLKixnfv37xcODg7i6aefFuXl5dp6UVFRwsvLS2cZaNrbvXt38dVXX4nMzEzRv39/IZfLxZkzZ7T1duzYIczMzET//v3F5s2bxYYNG0SPHj2Et7e3kLLLaGKMjo4WO3bsEB999JFwd3cXbm5uOutg3bp1AoA4d+6czviadfPtt99qy1JSUkRiYqLYsmWLyM7OFp988okIDAwUHTt2FHfu3HngNKsvi5ycHGFlZSUiIiJETk6OyMnJESdPnhRCCPHMM88IT09PnfUvhBBjxowRbdu2FXfv3jXY5qKiIpGTkyMAiNGjR2unW1N7NmzYIObNmyc2b94ssrOzxZdffin69esnWrdurbMNnzhxQiiVShEQECC+/PJLsWXLFhEREaFdF5p2njt3TiiVSjF48GDx9ddfi/3794v09HTx0ksviWvXrtWwpoT46aefhK+vr+jatas25p9++kknbin7yGeffSZkMpkYNWqU2LRpk9i6dasYMWKEMDc3F3v27Klx/kIIUVpaKp544gmhUqnEwoULxa5du8R///tfMW3aNLFv3z4hhBC3bt0SAQEBQqVSiXfffVfs3r1bzJ07V8jlchEREaEzPQDC3d1dBAQEiC+++ELs27dP/Prrr9ptw93dXbzyyitix44dYuPGjUKtVhvcX4T4/7fl6tP38PAQnTt3FuvXrxdbtmwRQ4cOFQDEhg0btNvD4sWLBQCxatUq7bItKip6qOW1Z88eYW5uLvr06SM2bdokNmzYILp16yY8PT1r3RdPnjwpevfuLdzc3LTz12yTD7s84+LihBBC3L59Wzz//PPC1tZW7NixQ1vn7bffFjKZTEycOFFs27ZNbNq0SfTq1UuoVCrtviXEvX3R0tJS+Pn5iXfffVfs2bNHzJs3T8hkMvHmm28+sD1CCDFx4kTh4uJS43ry8vISs2fPFllZWeLrr78WQgjx8ssvCwsLCzFz5kyxc+dO8cUXX4hOnToJV1dXUVhYqDeN2o5bGv/+97+Fs7OzqKqqqjXu5oaJUzPxqIlT9YPEV199JQBoDzBC3EucAIgjR45oy4qLi4W5ubmwsrLSSZKOHz8uAIj333+/xljVarW4c+eO6NChg5gxY4a2fMSIESIoKEhyOz/77DNhaWkppk6dKiorK3Xq1ZQ4ubq6itLSUm1ZYWGhMDMzE4mJidqybt26CQ8PD1FRUaEtKysrE05OTrUerK9duyaUSqV45plndMq///57AeCRE6f7VVVVibt374oLFy4IAOKbb7554DQNLQuVSiWioqL0pq2Z9+bNm7VlFy9eFHK5XNKB/f4fGantEeLeNnHz5k2hUqnEihUrtOVjxowRKpVKZ7uurKwUnTt31mnnxo0bBQBx/PjxWmOsrvq+UT3u2vaR8vJy4ejoKEaOHKlTr7KyUgQGBoru3bs/cP4LFy4UAERWVlaNdVavXi0AiK+++kqnfMmSJQKA2L17t7YMgLC3txclJSU6dTXbRmRkpN70HzZxsrKy0vnRVavVolOnTqJ9+/basg0bNhhc7w+zvHr06CHatm0rbt26pS0rLS0Vjo6Okv7EDB8+3GC7HnZ5xsXFieLiYtGnTx/h7u6us53l5eUJuVwupkyZojOtsrIy4ebmJp577jltWVRUlMH5RkREiI4dO9baHj8/PzF06FC9cs16mjdvnk655s/Me++9p1Oen58vrKysxP/8z/8IIR7uuKXx8ccfCwAiNze31ribG56qa+H+8Y9/6HwOCAgAAL1TJm3atEFwcLD2s6OjI1xcXBAUFIS2bdtqy/38/PTGV6vVWLx4MTp37gxLS0vI5XJYWlrizz//RG5urrZe9+7dceLECcTGxmLXrl0oLS2tMe63334b48ePxzvvvIMVK1bAzEzapjxgwADY2tpqP7u6usLFxUUbb3l5OY4cOYJRo0bB0tJSW8/GxgYjR46sdfo5OTm4ffs2XnjhBZ3y0NBQeHl5SYrRkKKiIsTExMDDwwNyuRwWFhba6d2/DOuqf//+CAwMxKpVq7Rlq1evhkwmwyuvvFJv87l58yZmz56N9u3bQy6XQy6Xw8bGBuXl5Trtyc7OxsCBA+Hs7KwtMzMzw3PPPaczvaCgIFhaWuKVV17BJ598oncaoi5q20cOHTqEkpISREVFQa1Wa4eqqioMHToUhw8fRnl5eY3T37FjBx5//HE89dRTNdbZt28fVCoVRo8erVM+fvx4ANA7xTVw4EA4ODgYnNazzz5b43ykGjRoEFxdXbWfzc3NMXbsWJw+fRp//fXXA8eVurzKy8tx+PBh/POf/4RSqdSOb2trK2lffJCHXZ7nzp1Dr169UFpaih9++AGBgYHa73bt2gW1Wo3IyEid9iiVSvTr10/vblKZTKYXf0BAgMHT1NVdunQJLi4uNX5ffd1u27YNMpkML774ok5sbm5uCAwM1Mb2KMctTRwXL16sNe7mhheHNxNy+b1VWVlZafB7tVoNCwsLvXInJyedzwqFAsC9i0bv5+joqDeupaWlXrkm2bh9+7a2LD4+HqtWrcLs2bPRr18/ODg4wMzMDJMmTdKZz5w5c6BSqfD5559j9erVMDc3R9++fbFkyRKEhITozOfzzz+Hu7v7Q9+1Vb29mjZr4rh27RqEEDo/ChqGyqorLi4GALi5uel9Z6hMiqqqKoSHh+PSpUuYO3cuunTpApVKhaqqKvTs2VNvXdXV1KlTMWnSJPz+++/w9fXFxx9/jNGjRz9y/IaMGzcOe/fuxdy5c9GtWzfY2dlBJpMhIiJCpz3FxcWS1kW7du2wZ88eLF26FHFxcSgvL4evry+mTp2KadOm1SnW2vaRy5cvA4Dej/D9SkpKaryY98qVK/D09HxgDMXFxXBzc9O73sjFxQVyuVy73Wm0adOmxmk96DupHrR9FxcX47HHHqtxXKnLSyaToaqqql73JY2HXZ4//vgjrl69irfffluvbZr2dOvWzeC8qv+ps7a21kkEgXvb1P3HzJpoLuavSfV1e/ny5RqPZwDg6+sL4NGOW5o46vv40xQwcWomNDvGxYsX9XYSIQQKCgr0ko/G8vnnnyMyMhKLFy/WKb969SpatWql/SyXyxEfH4/4+Hhcv34de/bswRtvvIEhQ4YgPz9f5w6RnTt3YuzYsQgLC8PevXvr1JtzPwcHB8hkMu3B8H6FhYW1jq/5kTVUt7CwUOcCXM2Bp/pFodUv+P71119x4sQJpKWlISoqSlt++vTpWuN5FOPGjcPs2bOxatUq9OzZE4WFhYiLi6u36d+4cQPbtm3D/Pnz8frrr2vLKyoqUFJSolPXyclJ8roICwtDWFgYKisrceTIEaxcuRLTp0+Hq6trgz4WQdMbtnLlSvTs2dNgnQcl3a1bt661l8bJyQn/93//ByGEzo99UVER1Gq1To8cgAdemG/oO6VSafDi5JpuPqhp+9bE+iBSl9fdu3chk8keOK9H9bDLc+zYsXBzc0NCQgKqqqrwn//8R689GzdurLfjUE2cnZ319pH7VV+3zs7OkMlkOHDggDbhv5+m7GGOWxqaOKovq5aAp+qaiYEDB0ImkyEjI0Pvu507d6K0tPSBpwIakkwm09tpt2/f/sAu3latWmH06NGIi4tDSUmJ3p1nXl5e2oNBWFgY/vzzz3qJVaVSISQkBF9//TXu3LmjLb9586bBu++q69mzJ5RKJdLT03XKDx06pNcVrzkY/fzzzzrlW7Zs0fmsORhWX4YffvhhrfHU5P5etuqUSqX2lNfy5csRFBSE3r17P/K8qpPJZBBC6LVnzZo1ej2m/fr1w759+3R+wKuqqmq8uw+4d9qoR48e2tONtT2o70HLQorevXujVatWOHXqFEJCQgwO95/2rW7YsGH4448/sG/fvhrrDBo0CDdv3tR79tann36q/b4uvL29UVRUpJOk3rlzB7t27TJYf+/evTp1KysrkZGRgXbt2ml7ZGrqvZa6vFQqFbp3745Nmzbp9MaUlZVh69atktpV07p9lOX5n//8B0lJSZg3bx7mzJmjLR8yZAjkcjnOnDlTY3vqS6dOnXDmzBnJ9UeMGAEhBC5evGgwri5dugB4uOOWxtmzZ2FmZoaOHTs+eoOaKPY4NRPt2rXD5MmTsWzZMly/fh0RERGwsrLC4cOH8c477yAkJATjxo0zSmwjRoxAWloaOnXqhICAABw9ehTLli3T6/IeOXIk/P39ERISgtatW+PChQtISkqCl5cXOnTooDfdNm3aIDs7G0OGDEHfvn2RlZUFf3//Ose7cOFCDB8+HEOGDMG0adNQWVmJZcuWwcbG5oH/9oB7PVazZs3CokWLMGnSJIwZMwb5+flYsGCBXpd3t27d0LFjR8yaNQtqtRoODg7YvHkzDh48qFOvU6dOaNeuHV5//XUIIeDo6IitW7ciKyvrkdvYpUsX7N+/H1u3bkWbNm1ga2urcwCMjY3F0qVLcfToUaxZs+aR52OInZ0d+vbti2XLlsHZ2Rne3t7Izs7G2rVrdXogASAhIQFbt27FoEGDkJCQACsrK6xevVp7zZDmNMjq1auxb98+DB8+HJ6enrh9+zZSU1MBoNY/DF26dMGXX36JjIwM+Pr6QqlUan9QpLCxscHKlSsRFRWFkpISjB49Gi4uLrhy5QpOnDiBK1euICUlpcbxp0+fjoyMDDz99NN4/fXX0b17d9y6dQvZ2dkYMWIEBgwYgMjISKxatQpRUVE4f/48unTpgoMHD2Lx4sWIiIio85+isWPHYt68eXj++efx2muv4fbt23j//fdrPPXv7OyMgQMHYu7cuVCpVEhOTsZvv/2m80gCzb740UcfwdbWFkqlEj4+PnBycpK8vN566y0MHToUgwcPxsyZM1FZWYklS5ZApVLVui8C99btpk2bkJKSguDgYJiZmSEkJOSRl+e0adNgY2ODV155BTdv3sT7778Pb29vLFy4EAkJCTh79iyGDh0KBwcHXL58GT/++CNUKhXefPPNh10lBvXv3x8LFy6U/Iym3r1745VXXsGECRNw5MgR9O3bFyqVCgUFBTh48CC6dOmCf//73w913NL44YcfEBQUVOO1dM2a8a5Lp/pWVVUlUlJSREhIiLC2thaWlpaiQ4cOYvbs2aKsrEynruaOIc3twxrnzp0TAMS6deu0Zf369RNPPPGE3vy8vLzE8OHD9cpR7c6qa9euiejoaOHi4iKsra1Fnz59xIEDB0S/fv107tZ47733RGhoqHB2dhaWlpbC09NTREdHi/Pnz2vrGLp78Pr166J3797C0dFRHD58WAhR81111e/40rSj+h1mmzdvFl26dNHG8c4774ipU6cKBwcHvfGrq6qqEomJicLDw0NYWlqKgIAAsXXrVr32CiHEH3/8IcLDw4WdnZ1o3bq1mDJliti+fbve3UinTp0SgwcPFra2tsLBwUGMGTNG5OXlCQBi/vz52npS76o7fvy46N27t7C2tq7xrpn+/fsLR0dH8ffff9faZg1Dy9jQXXV//fWXePbZZ4WDg4OwtbUVQ4cOFb/++qvBdXHgwAHRo0cPoVAohJubm3jttde0dz9dv35dCHHv7qFnnnlGeHl5CYVCIZycnES/fv3Eli1bao35/PnzIjw8XNja2mpv6b4/bin7iBD3HrsxfPhw4ejoKCwsLIS7u7sYPny43viGXLt2TUybNk14enoKCwsL4eLiIoYPH67zuI7i4mIRExMj2rRpI+RyufDy8hJz5swRt2/f1plWTdu5ZtvQ7CPVZWZmiqCgIGFlZSV8fX3FBx98UONddXFxcSI5OVm0a9dOWFhYiE6dOon09HS9aSYlJQkfHx9hbm6ut8ykLq8tW7aIgIAAnX3RUFyGlJSUiNGjR4tWrVoJmUymM05dluf69euFXC4XEyZM0N7R+/XXX4sBAwYIOzs7oVAohJeXlxg9erTO4xWioqKESqXSi1Nqe06fPi1kMpneXXm13VWdmpoqevToIVQqlbCyshLt2rUTkZGROndKP8xxq6ysTFhbW+vdrddSyIQQotGyNKIm6u7duwgKCoK7uzt27979SNPQPPzSVN/Zdr+ioiJ4eXlhypQpWLp0qbHD0RMeHo7z58/jjz/+MHYoLY5MJkNcXBw++OADY4fSIo0cORJqtbrGh1M2hrVr12LatGnIz89vkT1OPFVHZEB0dDQGDx6MNm3aoLCwEKtXr0Zubm6tTzFv6v766y+cPXsWy5Ytg5mZWZ3vSKsP8fHx6Nq1Kzw8PFBSUoL09HRkZWVh7dq1xg6NqNElJiaia9euOHz4cI138jUktVqNJUuWYM6cOS0yaQKYOBEZVFZWhlmzZuHKlSuwsLDAk08+iczMTKNdYN9Y1qxZg4ULF8Lb2xvp6elwd3c3dkiorKzEvHnzUFhYCJlMhs6dO+Ozzz7Diy++aOzQiBqdv78/1q1bV+c7Cx9Vfn4+XnzxRcycOdMo8zcFPFVHREREJBEfR0BEREQkkUkkTt999x1GjhyJtm3bQiaT6T1bw5Ds7GwEBwdDqVTC19e35b6lmYiIiBqNSVzjVF5ejsDAQEyYMEHSe5TOnTuHiIgIvPzyy/j888/x/fffIzY2Fq1bt36o9zBVVVXh0qVLsLW1feCTdomIiKj5EkKgrKwMbdu2rfXdpyZ3jZNMJsPmzZsxatSoGuvMnj0bW7Zs0XkZaExMDE6cOIGcnJwax6uoqNB5rcDFixfRuXPneombiIiImrb8/PwHvmsRMJEep4eVk5OD8PBwnbIhQ4Zg7dq1uHv3rsGX2QL3buM09ATX/Px82NnZNUisREREZNpKS0vh4eEBW1vbWus2ycSpsLBQ76WZrq6uUKvVuHr1ao1v/54zZw7i4+O1nzULys7OjokTERFRCyflsp0mmTgB+o3TnHF8UKMVCoXBN0QTERERSWESd9U9LDc3N72HfxUVFUEul8PJyclIUREREVFz1yQTp169eum9GX737t0ICQmp8fomIiIioroyicTp5s2bOH78OI4fPw7g3uMGjh8/jry8PAD3rk2KjIzU1o+JicGFCxcQHx+P3NxcpKamYu3atZg1a5YxwiciIqIWwiSucTpy5AgGDBig/ay5gDsqKgppaWkoKCjQJlEA4OPjg8zMTMyYMQOrVq1C27Zt8f777z/UM5yIiIiIHpbJPcepMZWWlsLe3h43btzgXXVEREQt1MPkAybR40RERNQUDHlru7FDkGTX3OHGDqHZMolrnIiIiIiaAvY40UPjPy4iImqp2ONEREREJBETJyIiIiKJmDgRERERScTEiYiIiEgiXhxOhKZzwTvAi96paeA+Rc0Ve5yIiIiIJGLiRERERCQREyciIiIiiZg4EREREUnExImIiIhIIiZORERERBIxcSIiIiKSiIkTERERkURMnIiIiIgkYuJEREREJBETJyIiIiKJmDgRERERScTEiYiIiEgiJk5EREREEjFxIiIiIpKIiRMRERGRREyciIiIiCRi4kREREQkkUklTsnJyfDx8YFSqURwcDAOHDjwwPrp6ekIDAyEtbU12rRpgwkTJqC4uLiRoiUiIqKWxmQSp4yMDEyfPh0JCQk4duwYwsLCMGzYMOTl5Rmsf/DgQURGRiI6OhonT57Ehg0bcPjwYUyaNKmRIyciIqKWwmQSp+XLlyM6OhqTJk2Cn58fkpKS4OHhgZSUFIP1f/jhB3h7e2Pq1Knw8fFBnz598Oqrr+LIkSM1zqOiogKlpaU6AxEREZFUJpE43blzB0ePHkV4eLhOeXh4OA4dOmRwnNDQUPz111/IzMyEEAKXL1/Gxo0bMXz48Brnk5iYCHt7e+3g4eFRr+0gIiKi5s0kEqerV6+isrISrq6uOuWurq4oLCw0OE5oaCjS09MxduxYWFpaws3NDa1atcLKlStrnM+cOXNw48YN7ZCfn1+v7SAiIqLmzSQSJw2ZTKbzWQihV6Zx6tQpTJ06FfPmzcPRo0exc+dOnDt3DjExMTVOX6FQwM7OTmcgIiIikkpu7AAAwNnZGebm5nq9S0VFRXq9UBqJiYno3bs3XnvtNQBAQEAAVCoVwsLCsGjRIrRp06bB4yYiIqKWxSR6nCwtLREcHIysrCyd8qysLISGhhoc5++//4aZmW745ubmAO71VBERERHVN5NInAAgPj4ea9asQWpqKnJzczFjxgzk5eVpT73NmTMHkZGR2vojR47Epk2bkJKSgrNnz+L777/H1KlT0b17d7Rt29ZYzSAiIqJmzCRO1QHA2LFjUVxcjIULF6KgoAD+/v7IzMyEl5cXAKCgoEDnmU7jx49HWVkZPvjgA8ycOROtWrXCwIEDsWTJEmM1gYiIiJo5k0mcACA2NhaxsbEGv0tLS9MrmzJlCqZMmdLAURERERHdYzKn6oiIiIhMHRMnIiIiIolM6lQdERERNa4hb203dgiS7Zpb89tBGgt7nIiIiIgkYuJEREREJBETJyIiIiKJmDgRERERScTEiYiIiEgiJk5EREREEjFxIiIiIpKIiRMRERGRREyciIiIiCRi4kREREQkERMnIiIiIomYOBERERFJxMSJiIiISCImTkREREQSyY0dQHM25K3txg5Bsl1zhxs7BCIiIpPHHiciIiIiiZg4EREREUnExImIiIhIIiZORERERBIxcSIiIiKSiIkTERERkURMnIiIiIgkYuJEREREJJFJJU7Jycnw8fGBUqlEcHAwDhw48MD6FRUVSEhIgJeXFxQKBdq1a4fU1NRGipaIiIhaGpN5cnhGRgamT5+O5ORk9O7dGx9++CGGDRuGU6dOwdPT0+A4zz33HC5fvoy1a9eiffv2KCoqglqtbuTIiYiIqKUwmcRp+fLliI6OxqRJkwAASUlJ2LVrF1JSUpCYmKhXf+fOncjOzsbZs2fh6OgIAPD29n7gPCoqKlBRUaH9XFpaWn8NICIiombPJE7V3blzB0ePHkV4eLhOeXh4OA4dOmRwnC1btiAkJARLly6Fu7s7Hn/8ccyaNQu3bt2qcT6JiYmwt7fXDh4eHvXaDiIiImreTKLH6erVq6isrISrq6tOuaurKwoLCw2Oc/bsWRw8eBBKpRKbN2/G1atXERsbi5KSkhqvc5ozZw7i4+O1n0tLS5k8EZHR8YXgRE2HSSROGjKZTOezEEKvTKOqqgoymQzp6emwt7cHcO903+jRo7Fq1SpYWVnpjaNQKKBQKOo/cCIiImoRTOJUnbOzM8zNzfV6l4qKivR6oTTatGkDd3d3bdIEAH5+fhBC4K+//mrQeImIiKhlMonEydLSEsHBwcjKytIpz8rKQmhoqMFxevfujUuXLuHmzZvasj/++ANmZmZ47LHHGjReIiIiaplMInECgPj4eKxZswapqanIzc3FjBkzkJeXh5iYGAD3rk+KjIzU1h83bhycnJwwYcIEnDp1Ct999x1ee+01TJw40eBpOiIiIqK6MplrnMaOHYvi4mIsXLgQBQUF8Pf3R2ZmJry8vAAABQUFyMvL09a3sbFBVlYWpkyZgpCQEDg5OeG5557DokWLjNUEIiIiauZMJnECgNjYWMTGxhr8Li0tTa+sU6dOeqf3iIiIiBqKyZyqIyIiIjJ1TJyIiIiIJGLiRERERCQREyciIiIiiZg4EREREUn0yInT1q1b0bVrV7Rv3x7//Oc/sXv37vqMi4iIiMjkPPTjCJKSkvDkk09i1qxZ2LhxI9q3b48TJ05g8eLFOHv2rPaBlURERETNzUP3OMnlcnz22WfIz8/H2LFj8fzzz2Pr1q0YM2YMVqxYASFEQ8RJREREZHQP3eM0efJkAMCZM2fw0Ucf4e7du/jll1/w888/o6CgAJ07d4aNjQ0OHz5c78ESERERGdMjPzl85cqVGDNmDEJDQ9GlSxeUl5fD398fBw8exI0bN+ozRiIiIiKT8MgXhz/xxBPIycnB4MGDcfnyZXh6euKbb74BANjb29dbgERERESmQnKPU35+Pjw8PHTKlEolRo0ahVGjRtV3XEREREQmR3Li5OXlBQcHBwQGBiIwMBBBQUEIDAxERUUFVq1ahU8//bQh4yQiIiIyOsmJ09mzZ3H8+HEcP34cx44dw8aNG3Hp0iUAgJ2dXYMFSERERGQqJCdO3t7e8Pb21jktl5OTg6ioKCxZsqQhYiMiIiIyKXV65UqvXr2wYsUKLFq0qL7iISIiIjJZkhOnu3fvGizv0KEDTp48WW8BEREREZkqyafqVCoVOnfujK5duyIoKAhdu3ZF27ZtsXLlSoSHhzdkjEREREQmQXLitG/fPpw4cQInTpxAeno63njjDdy6dQsAEB4ejoSEBAQEBCAgIAB+fn4NFjARERGRsUhOnPr06YM+ffpoP1dVVeH333/X3ml39OhRpKamoqioCJWVlQ0SLBEREZExPfIrV8zMzODn5wc/Pz/861//0pZfvny5XgIjIiIiMjV1uqvOEFdX1/qeJBEREZFJqPfEiYiIiKi5YuJEREREJBETJyIiIiKJmDgRERERSWRSiVNycjJ8fHygVCoRHByMAwcOSBrv+++/h1wuR1BQUMMGSERERC2aySROGRkZmD59OhISEnDs2DGEhYVh2LBhyMvLe+B4N27cQGRkJAYNGtRIkRIREVFLZTKJ0/LlyxEdHY1JkybBz88PSUlJ8PDwQEpKygPHe/XVVzFu3Dj06tWr1nlUVFSgtLRUZyAiIiKSyiQSpzt37uDo0aN677wLDw/HoUOHahxv3bp1OHPmDObPny9pPomJibC3t9cOHh4edYqbiIiIWhaTSJyuXr2KyspKvYdnurq6orCw0OA4f/75J15//XWkp6dDLpf2APQ5c+bgxo0b2iE/P7/OsRMREVHL8civXGkIMplM57MQQq8MACorKzFu3Di8+eabePzxxyVPX6FQQKFQ1DlOIiIiaplMInFydnaGubm5Xu9SUVGRwVe4lJWV4ciRIzh27BgmT54M4N5Lh4UQkMvl2L17NwYOHNgosRMREVHLYRKn6iwtLREcHIysrCyd8qysLISGhurVt7Ozwy+//ILjx49rh5iYGHTs2BHHjx9Hjx49Git0IiIiakFMoscJAOLj4/HSSy8hJCQEvXr1wkcffYS8vDzExMQAuHd90sWLF/Hpp5/CzMwM/v7+OuO7uLhAqVTqlRMRERHVF5NJnMaOHYvi4mIsXLgQBQUF8Pf3R2ZmJry8vAAABQUFtT7TiYiIiKghmUziBACxsbGIjY01+F1aWtoDx12wYAEWLFhQ/0ERERER/X9M4honIiIioqaAiRMRERGRREyciIiIiCRi4kREREQkkUldHE5E9WfIW9uNHYJku+YON3YIRESSsMeJiIiISCImTkREREQSMXEiIiIikoiJExEREZFETJyIiIiIJGLiRERERCQREyciIiIiiZg4EREREUnExImIiIhIIiZORERERBIxcSIiIiKSiIkTERERkURMnIiIiIgkYuJEREREJBETJyIiIiKJmDgRERERScTEiYiIiEgiJk5EREREEjFxIiIiIpKIiRMRERGRREyciIiIiCQyqcQpOTkZPj4+UCqVCA4OxoEDB2qsu2nTJgwePBitW7eGnZ0devXqhV27djVitERERNTSmEzilJGRgenTpyMhIQHHjh1DWFgYhg0bhry8PIP1v/vuOwwePBiZmZk4evQoBgwYgJEjR+LYsWONHDkRERG1FHJjB6CxfPlyREdHY9KkSQCApKQk7Nq1CykpKUhMTNSrn5SUpPN58eLF+Oabb7B161Z07drV4DwqKipQUVGh/VxaWlp/DSAiIqJmzyR6nO7cuYOjR48iPDxcpzw8PByHDh2SNI2qqiqUlZXB0dGxxjqJiYmwt7fXDh4eHnWKm4iIiFoWk0icrl69isrKSri6uuqUu7q6orCwUNI03nvvPZSXl+O5556rsc6cOXNw48YN7ZCfn1+nuImIiKhlMZlTdQAgk8l0Pgsh9MoMWb9+PRYsWIBvvvkGLi4uNdZTKBRQKBR1jpOIiIhaJpNInJydnWFubq7Xu1RUVKTXC1VdRkYGoqOjsWHDBjz11FMNGSYRERG1cCZxqs7S0hLBwcHIysrSKc/KykJoaGiN461fvx7jx4/HF198geHDhzd0mERERNTCmUSPEwDEx8fjpZdeQkhICHr16oWPPvoIeXl5iImJAXDv+qSLFy/i008/BXAvaYqMjMSKFSvQs2dPbW+VlZUV7O3tjdYOIiIiar5MJnEaO3YsiouLsXDhQhQUFMDf3x+ZmZnw8vICABQUFOg80+nDDz+EWq1GXFwc4uLitOVRUVFIS0tr7PCJiIioBTCZxAkAYmNjERsba/C76snQ/v37Gz4gIiIiovuYxDVORERERE0BEyciIiIiiZg4EREREUnExImIiIhIIiZORERERBIxcSIiIiKSiIkTERERkURMnIiIiIgkYuJEREREJBETJyIiIiKJmDgRERERScTEiYiIiEgiJk5EREREEjFxIiIiIpKIiRMRERGRREyciIiIiCRi4kREREQkERMnIiIiIomYOBERERFJxMSJiIiISCImTkREREQSMXEiIiIikkhu7ACIiB7GmdZhxg5BknZXDhg7BCJqAOxxIiIiIpKIiRMRERGRREyciIiIiCQyqcQpOTkZPj4+UCqVCA4OxoEDD75GIDs7G8HBwVAqlfD19cXq1asbKVIiIiJqiUwmccrIyMD06dORkJCAY8eOISwsDMOGDUNeXp7B+ufOnUNERATCwsJw7NgxvPHGG5g6dSr++9//NnLkRERE1FKYzF11y5cvR3R0NCZNmgQASEpKwq5du5CSkoLExES9+qtXr4anpyeSkpIAAH5+fjhy5AjeffddPPvsswbnUVFRgYqKCu3nGzduAABKS0vruTX3qG//3SDTbQgPswyaSruaY5sA6e1qjm0CgLIqdQNGUn+4/TW/NgFNp13NsU1Aw/1ea6YrhKi9sjABFRUVwtzcXGzatEmnfOrUqaJv374GxwkLCxNTp07VKdu0aZOQy+Xizp07BseZP3++AMCBAwcOHDhw4KA35Ofn15qzmESP09WrV1FZWQlXV1edcldXVxQWFhocp7Cw0GB9tVqNq1evok2bNnrjzJkzB/Hx8drPVVVVKCkpgZOTE2QyWT20pOGVlpbCw8MD+fn5sLOzM3Y49YJtajqaY7vYpqajObaLbTINQgiUlZWhbdu2tdY1icRJo3ryIoR4YEJjqL6hcg2FQgGFQqFT1qpVq0eI1Pjs7OyazAYpFdvUdDTHdrFNTUdzbBfbZHz29vaS6pnExeHOzs4wNzfX610qKirS61XScHNzM1hfLpfDycmpwWIlIiKilsskEidLS0sEBwcjKytLpzwrKwuhoaEGx+nVq5de/d27dyMkJAQWFhYNFisRERG1XCaROAFAfHw81qxZg9TUVOTm5mLGjBnIy8tDTEwMgHvXJ0VGRmrrx8TE4MKFC4iPj0dubi5SU1Oxdu1azJo1y1hNaBQKhQLz58/XO+XYlLFNTUdzbBfb1HQ0x3axTU2PTAgp9941juTkZCxduhQFBQXw9/fH//7v/6Jv374AgPHjx+P8+fPYv3+/tn52djZmzJiBkydPom3btpg9e7Y20SIiIiKqbyaVOBERERGZMpM5VUdERERk6pg4EREREUnExImIiIhIIiZORERERBIxcWoCxo8fD5lMph2cnJwwdOhQ/Pzzz8YOrc4KCwsxZcoU+Pr6QqFQwMPDAyNHjsTevXuNHdpDu389WVhYwNXVFYMHD0ZqaiqqqqqMHV6dVN8GNcPQoUONHdoja25t4vbX9Bw6dAjm5uZNvh0azfm36n5MnJqIoUOHoqCgAAUFBdi7dy/kcjlGjBhh7LDq5Pz58wgODsa+ffuwdOlS/PLLL9i5cycGDBiAuLg4Y4f3SDTr6fz589ixYwcGDBiAadOmYcSIEVCr1cYOr07u3wY1w/r1640dVp00tzZx+2taUlNTMWXKFBw8eBB5eXnGDqdeNMffqupM6l11VDOFQgE3NzcA9143M3v2bPTt2xdXrlxB69atjRzdo4mNjYVMJsOPP/4IlUqlLX/iiScwceJEI0b26O5fT+7u7njyySfRs2dPDBo0CGlpaZg0aZKRI3x097etuWhubeL213SUl5fjq6++wuHDh1FYWIi0tDTMmzfP2GHVWXP8raqOPU5N0M2bN5Geno727ds32ffylZSUYOfOnYiLi9NJmjSa6suXDRk4cCACAwOxadMmY4dCLRC3P9OUkZGBjh07omPHjnjxxRexbt06NLfHKjaH3ypDmDg1Edu2bYONjQ1sbGxga2uLLVu2ICMjA2ZmTXMVnj59GkIIdOrUydihNIpOnTrh/Pnzxg6jTu7fBjXDW2+9Zeyw6qQ5tskQbn+mZ+3atXjxxRcB3Du9dfPmzSZ5bWd1ze23yhCeqmsiBgwYgJSUFAD3emuSk5MxbNgw/Pjjj/Dy8jJydA9P889KJpMZOZLGIYRo8m29fxvUcHR0NFI09aM5tskQbn+m5ffff8ePP/6o7QWUy+UYO3YsUlNT8dRTTxk5urppbr9VhjBxaiJUKhXat2+v/RwcHAx7e3t8/PHHWLRokREjezQdOnSATCZDbm4uRo0aZexwGlxubi58fHyMHUadVN8Gm4Pm2CZDuP2ZlrVr10KtVsPd3V1bJoSAhYUFrl27BgcHByNGVzfN7bfKkObTd9bCyGQymJmZ4datW8YO5ZE4OjpiyJAhWLVqFcrLy/W+v379euMH1UD27duHX375Bc8++6yxQ6EWiNufaVGr1fj000/x3nvv4fjx49rhxIkT8PLyQnp6urFDrFdN/bfKEPY4NREVFRUoLCwEAFy7dg0ffPABbt68iZEjRxo5skeXnJyM0NBQdO/eHQsXLkRAQADUajWysrKQkpKC3NxcY4f40DTrqbKyEpcvX8bOnTuRmJiIESNGIDIy0tjh1cn926CGXC6Hs7OzkSKqu+bWJm5/pm/btm24du0aoqOjYW9vr/Pd6NGjsXbtWkyePNlI0dVdc/yt0iPI5EVFRQkA2sHW1lZ069ZNbNy40dih1dmlS5dEXFyc8PLyEpaWlsLd3V384x//EN9++62xQ3to968nuVwuWrduLZ566imRmpoqKisrjR1enVTfBjVDx44djR3aI2tubeL21zSMGDFCREREGPzu6NGjAoA4evRoI0dVP5rzb9X9ZEI0s/sfiYiIiBoIr3EiIiIikoiJExEREZFETJyIiIiIJGLiRERERCQREyciIiIiiZg4EREREUnExImIiIhIIiZORERERBIxcSIiIiKSiIkTERERkURMnIiIiIgk+n9hN13plaDdRwAAAABJRU5ErkJggg==",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "x1 = sample_seq(9)\n",
+ "y = x1.copy()\n",
+ "bad_pos = 4\n",
+ "y[bad_pos] = (x1[bad_pos] + 2) % V # decode this position to a WRONG letter\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " scores = []\n",
+ " for ell in range(len(y)):\n",
+ " scores.append(mu_phi(window_feat(y, ell).unsqueeze(0)).item())\n",
+ "\n",
+ "print(\"target x1:\", decode(x1))\n",
+ "print(\"decoded y:\", decode(y), f\" (position {bad_pos} corrupted: {LETTERS[x1[bad_pos]]} -> {LETTERS[y[bad_pos]]})\")\n",
+ "print(\"\\n predicted unmasking quality mu_phi per position:\")\n",
+ "for ell, s in enumerate(scores):\n",
+ " flag = \" <-- lowest, would be re-masked\" if ell == int(np.argmin(scores)) else \"\"\n",
+ " star = \" *corrupted*\" if ell == bad_pos else \"\"\n",
+ " print(f\" pos {ell} ({LETTERS[y[ell]]}): {s:.3f}{star}{flag}\")\n",
+ "\n",
+ "plt.figure(figsize=(6, 2.6))\n",
+ "colors = [\"crimson\" if ell == bad_pos else \"steelblue\" for ell in range(len(y))]\n",
+ "plt.bar(range(len(y)), scores, color=colors)\n",
+ "plt.xticks(range(len(y)), [LETTERS[t] for t in y])\n",
+ "plt.ylabel(r\"$\\mu_\\phi$\"); plt.title(\"Unmasking quality flags the corrupted token (red)\")\n",
+ "plt.tight_layout(); plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "- **Unmasking quality** $\\mu_\\star$ and **insertion quality** $\\nu_\\star$ are defined in closed form here from a known posterior; the predictors $\\mu_\\phi,\\nu_\\phi$ trained with the paper's **UQL / IQL** BCE losses recover them.\n",
+ "- At inference, A2D2 turns these scores that are used in `apply_schedule_aware_remasking` (re-mask low-quality tokens) and `apply_schedule_aware_insertion` (drop low-quality insertions), both shown here on toy states and tied back to a corrupted-decode example.\n",
+ "- Swap the toy `posterior_at` / `ToyPlanner` for a trained any-length MDM + its planner heads (`model/model_wrapper py::RemaskingAnyOrder`) and this becomes the quality-guided decode used in `a2d2_mol` / `a2d2_pep` / `a2d2_language`."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "peptune",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9a7021af20cd6805394dc00c9ca30b27a12b4f2a
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,57 @@
+# Conda environment shared across the molecule, peptide, and language experiments.
+# Create with:
+# conda env create -f environment.yml
+# conda activate a2d2
+#
+# NOTE: flash-attn is hardware-specific and must be built against your installed torch
+# and CUDA, so it is not listed below. It is imported by the shared transformer backbone
+# (model/casual_transformer.py, model/rotary.py) and is required for all experiments.
+# After creating the env, install it with:
+# pip install flash-attn==2.8.3 --no-build-isolation
+# Adjust pytorch-cuda below to match your CUDA toolkit / GPU.
+name: a2d2
+channels:
+ - pytorch
+ - nvidia
+ - conda-forge
+dependencies:
+ - python=3.11
+ - pip
+ - pytorch
+ - pytorch-cuda=12.1
+ - rdkit=2023.9.6
+ - jupyterlab # for demo/quality_inference_demo.ipynb
+ - pip:
+ # --- core scientific / DL stack ---
+ - numpy==1.26.4
+ - scipy==1.17.1
+ - pandas==2.1.4
+ - scikit-learn==1.8.0
+ - pytorch-lightning==2.6.0
+ - lightning==2.6.1
+ - transformers==4.55.4
+ - tokenizers==0.21.4
+ - safetensors==0.7.0
+ - accelerate==0.33.0
+ - peft==0.15.1 # LoRA adapters (language experiment)
+ - datasets==2.19.2
+ - huggingface-hub==0.36.2
+ - einops==0.8.2
+ - timm==1.0.26
+ - omegaconf==2.3.0
+ - wandb==0.26.1
+ # --- molecule experiment ---
+ - safe-mol==0.1.14
+ - datamol==0.12.5
+ - PyTDC==1.1.15
+ # --- peptide experiment ---
+ - SmilesPE==0.0.3
+ - fair-esm==2.0.0
+ - xgboost==3.2.0
+ # --- plotting / utilities ---
+ - matplotlib==3.10.6
+ - seaborn==0.13.2
+ - tqdm==4.67.1
+ - joblib==1.5.3
+ - loguru==0.7.3
+ - fsspec==2024.3.1
\ No newline at end of file
diff --git a/lightning_modules/__init__.py b/lightning_modules/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..67b96f756961ab511def199434620a0cec87082c
--- /dev/null
+++ b/lightning_modules/__init__.py
@@ -0,0 +1,16 @@
+from .mdm import MaskedDiffusionModule
+from .any_order import AnyOrderInsertionFlowModule
+
+
+__all__ = [
+ "MaskedDiffusionModule",
+ "AutoregressiveModule",
+ "AnyOrderInsertionFlowModule",
+]
+
+
+def __getattr__(name):
+ if name == "AutoregressiveModule":
+ from .autoregressive import AutoregressiveModule
+ return AutoregressiveModule
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/lightning_modules/any_length_remask.py b/lightning_modules/any_length_remask.py
new file mode 100644
index 0000000000000000000000000000000000000000..475818f01b3b3675cf800691d767f926f46d02b7
--- /dev/null
+++ b/lightning_modules/any_length_remask.py
@@ -0,0 +1,801 @@
+import os
+import torch
+import torch.nn as nn
+import pytorch_lightning as pl
+from omegaconf import DictConfig
+import torch.nn.functional as F
+from model.transformer import AnyOrderMaskInsertionFlow
+from model.interpolant import AnyOrderMaskInsertionInterpolant, ModelPrediction
+from .bregman import jump_kernel_elbo, mse
+from .schedule import get_schedule_from_config
+from lightning_modules.any_order import AnyOrderInsertionFlowModule
+from model.model_wrapper import RemaskingAnyOrder
+from sampling import _sample_tokens
+
+import re
+from typing import Dict, Any
+from dataclasses import dataclass
+
+def strip_orig_mod_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Returns a new state_dict where any key containing '._orig_mod.' is replaced
+ by removing the '_orig_mod' segment, e.g.
+ 'model._orig_mod.vocab_embed.embedding'
+ becomes
+ 'model.vocab_embed.embedding'
+ """
+ new_state_dict: Dict[str, Any] = {}
+ for key, value in state_dict.items():
+ # remove all occurrences of '._orig_mod.'
+ clean_key = re.sub(r"\._orig_mod\.", ".", key)
+ new_state_dict[clean_key] = value
+ return new_state_dict
+
+
+@torch.no_grad()
+def _binary_auc(scores: torch.Tensor, labels: torch.Tensor) -> float:
+ """Rank-based AUROC (Mann-Whitney U statistic).
+
+ AUC = P(score[pos] > score[neg]); 0.5 means no discrimination. Returns NaN
+ when only one class is present (AUC undefined). Ties are not averaged, which
+ is fine for continuous logits used here.
+ """
+ scores = scores.float().reshape(-1)
+ labels = labels.float().reshape(-1)
+ n_pos = labels.sum()
+ n_neg = labels.numel() - n_pos
+ if n_pos == 0 or n_neg == 0:
+ return float("nan")
+ order = torch.argsort(scores)
+ ranks = torch.empty_like(scores)
+ ranks[order] = torch.arange(1, scores.numel() + 1, device=scores.device, dtype=scores.dtype)
+ auc = (ranks[labels == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
+ return auc.item()
+
+
+class AnyOrderInsertionFlowModuleFT(AnyOrderInsertionFlowModule):
+ """
+ Wrapper around AnyOrderInsertionFlowModule that adds adaptive schedule model
+ for fine-tuning. Can load a pretrained AnyOrderInsertionFlowModule checkpoint
+ and add the schedule model on top.
+ """
+ def __init__(self, config, args, pretrained_checkpoint, insertion_planner=False):
+ # Initialize parent class first
+ super().__init__(config)
+
+ self.args = args
+ self.insertion_planner = insertion_planner
+
+ # Save hyperparameters for this class (overrides parent's save)
+ self.save_hyperparameters(ignore=['pretrained_checkpoint', 'args'])
+
+ # Load pretrained model weights BEFORE initializing planner to avoid circular reference
+ if pretrained_checkpoint is not None:
+ self.load_pretrained_model(pretrained_checkpoint)
+
+ # Initialize adaptive schedule model AFTER loading pretrained weights
+ self.planner = RemaskingAnyOrder(
+ backbone=self,
+ d_model=self.config.model.hidden_size,
+ insertion_planner=insertion_planner)
+
+ def load_pretrained_model(self, checkpoint_path: str):
+ """
+ Load pretrained AnyOrderInsertionFlowModule weights.
+ Only loads the base model and interpolant, not the schedule model.
+ """
+ print(f"Loading pretrained model from {checkpoint_path}")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
+
+ # Extract state dict - handle different checkpoint formats
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+
+ # Strip _orig_mod keys if present
+ state_dict = strip_orig_mod_keys(state_dict)
+
+ # Filter out planner keys (if any exist from a previous FT checkpoint)
+ base_state_dict = {k: v for k, v in state_dict.items()
+ if not k.startswith('planner.')}
+
+ # Load the base model weights
+ # Use strict=False to ignore missing schedule_model keys
+ incompatible_keys = self.load_state_dict(base_state_dict, strict=False)
+
+ # Filter out expected missing planner keys for cleaner output
+ unexpected_missing = [k for k in incompatible_keys.missing_keys
+ if not k.startswith('planner.')]
+ planner_missing = [k for k in incompatible_keys.missing_keys
+ if k.startswith('planner.')]
+
+ if unexpected_missing:
+ print(f"Warning: Unexpected missing keys from pretrained checkpoint: {unexpected_missing}")
+ if planner_missing:
+ print(f"Note: Planner will be trained from scratch ({len(planner_missing)} parameters)")
+ if incompatible_keys.unexpected_keys:
+ print(f"Warning: Unexpected keys in pretrained checkpoint: {incompatible_keys.unexpected_keys}")
+
+ # Freeze base model if specified
+ if self.config.training.get('freeze_base_model', False):
+ print("Freezing base model parameters")
+ for name, param in self.named_parameters():
+ if not name.startswith('planner.'):
+ param.requires_grad = False
+
+ def forward(self, x, t, return_features=False):
+ # Use parent class forward method
+ return super().forward(x, t, return_features=return_features)
+
+ def training_loss(self, x1, t):
+ # Use parent class training_loss for base model loss
+ # Planner is trained separately via loss_planner_flexible with reward gradients
+ unmask_loss, insertion_loss, total_loss = super().training_loss(x1, t)
+ return unmask_loss, insertion_loss, total_loss
+
+
+ def training_step(self, batch, batch_idx):
+ # Extract input data
+ if isinstance(batch, dict):
+ batch = batch["input_ids"]
+
+ x1 = batch
+ t = self.sample_time(x1.shape[0], x1.device)
+
+ # Calculate the base model loss (planner trained separately, not here)
+ unmask_loss, len_loss, loss = self.training_loss(x1, t)
+
+ # Log component losses
+ self.log("train/unmask_loss", unmask_loss, prog_bar=True)
+ self.log("train/len_loss", len_loss, prog_bar=True)
+ self.log("train/total_loss", loss, prog_bar=True)
+
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ if isinstance(batch, dict):
+ batch = batch["input_ids"]
+
+ x1 = batch
+ t = self.sample_time(x1.shape[0], x1.device)
+ unmask_loss, len_loss, loss = self.training_loss(x1, t)
+
+ self.log("val/unmask_loss", unmask_loss, prog_bar=True, sync_dist=True)
+ self.log("val/len_loss", len_loss, prog_bar=True, sync_dist=True)
+ self.log("val_loss", loss, prog_bar=True, sync_dist=True)
+
+ return loss
+
+ @classmethod
+ def load_from_checkpoint(cls, checkpoint_path, map_location=None, strict=True, **kwargs):
+ """
+ Custom checkpoint loading that handles finetuned checkpoints wrapped by PeptideFinetuner.
+ Extracts config from original pretrained checkpoint and loads finetuned weights.
+ """
+ print(f"Loading finetuned checkpoint from {checkpoint_path}")
+ checkpoint = torch.load(checkpoint_path, map_location=map_location or 'cpu', weights_only=False)
+
+ # Check if this is a wrapped checkpoint (from PeptideFinetuner)
+ hparams = checkpoint.get('hyper_parameters', {})
+ state_dict = checkpoint.get('state_dict', {})
+
+ # Check for policy_model prefix in state_dict (indicates PeptideFinetuner wrapper)
+ has_policy_prefix = any(k.startswith('policy_model.') for k in state_dict.keys())
+
+ if has_policy_prefix:
+ # Detect model type (molecule vs peptide) based on vocab size in checkpoint
+ # Molecule models have vocab size ~1882, peptide models have ~587
+ vocab_size = None
+ for k, v in state_dict.items():
+ if 'vocab_embed.embedding' in k:
+ vocab_size = v.shape[0]
+ break
+
+ is_molecule_model = vocab_size is not None and vocab_size > 1000
+ model_type = "MolFinetuner" if is_molecule_model else "PeptideFinetuner"
+ print(f"Detected wrapped finetuned checkpoint ({model_type}, vocab_size={vocab_size})")
+
+ # Extract args from hyperparameters
+ if 'args' not in hparams:
+ raise ValueError(f"Cannot find 'args' in hyperparameters. This checkpoint may not be from {model_type}.")
+
+ args = hparams['args']
+ print(f"Found args in hyperparameters, type: {type(args)}")
+
+ # Get original checkpoint path from args
+ # Handle both Namespace (hasattr) and dict (get) access patterns
+ original_ckpt_path = None
+ if hasattr(args, 'checkpoint_path'):
+ original_ckpt_path = args.checkpoint_path
+ elif isinstance(args, dict) and 'checkpoint_path' in args:
+ original_ckpt_path = args['checkpoint_path']
+
+ # If checkpoint_path is not set or is None, use default pretrained checkpoint
+ # Select appropriate default based on detected model type
+ if original_ckpt_path is None:
+ _repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ if is_molecule_model:
+ original_ckpt_path = os.path.join(_repo_root, 'pretrained', 'anylength_mol.ckpt')
+ print(f"Warning: checkpoint_path not found in args, using default molecule pretrained checkpoint")
+ else:
+ original_ckpt_path = os.path.join(_repo_root, 'pretrained', 'anylength_pep.ckpt')
+ print(f"Warning: checkpoint_path not found in args, using default peptide pretrained checkpoint")
+
+ # Try to load config directly from checkpoint first (new checkpoints)
+ # Fall back to loading from original checkpoint (old checkpoints)
+ if 'config' in checkpoint:
+ print("Found config directly in checkpoint")
+ config = checkpoint['config']
+ else:
+ print(f"Config not in checkpoint, loading from original checkpoint: {original_ckpt_path}")
+
+ # Load config from original pretrained checkpoint
+ orig_ckpt = torch.load(original_ckpt_path, map_location='cpu', weights_only=False)
+ if 'config' not in orig_ckpt:
+ raise ValueError(f"Original checkpoint {original_ckpt_path} does not contain config")
+
+ config = orig_ckpt['config']
+
+ # Ensure adaptive schedule is enabled
+ # Need to disable struct mode to add new keys to OmegaConf config
+ from omegaconf import OmegaConf
+ if hasattr(config, 'training'):
+ OmegaConf.set_struct(config, False)
+ config.training.use_adaptive_schedule = True
+ OmegaConf.set_struct(config, True)
+
+ # Create args object if needed
+ if not hasattr(args, '__dict__'):
+ # Convert dict to object with attributes
+ class Args:
+ pass
+ args_obj = Args()
+ for k, v in args.items():
+ setattr(args_obj, k, v)
+ args = args_obj
+
+ # Initialize model with config and args
+ model = cls(
+ config=config,
+ args=args,
+ pretrained_checkpoint=None, # Don't reload pretrained, weights already in checkpoint
+ insertion_planner=getattr(args, 'insertion_planner', False)
+ )
+
+ # Extract policy_model weights from state_dict
+ policy_state = {}
+ for k, v in state_dict.items():
+ if k.startswith('policy_model.'):
+ # Strip 'policy_model.' prefix
+ new_key = k[len('policy_model.'):]
+ policy_state[new_key] = v
+
+ # Load the finetuned weights
+ incompatible = model.load_state_dict(policy_state, strict=False)
+ if incompatible.missing_keys or incompatible.unexpected_keys:
+ print(f"Warning: Incompatible keys when loading finetuned weights:")
+ if incompatible.missing_keys:
+ print(f" Missing: {incompatible.missing_keys[:5]}...")
+ if incompatible.unexpected_keys:
+ print(f" Unexpected: {incompatible.unexpected_keys[:5]}...")
+
+ # Initialize or load EMA params
+ if model.use_ema:
+ if "ema_params" in checkpoint:
+ # Load EMA params from checkpoint
+ model.ema_params = checkpoint["ema_params"]
+ print("Loaded EMA params from checkpoint")
+ else:
+ # Initialize empty EMA params (will be populated if needed)
+ model.ema_params = {
+ name: param.clone().detach()
+ for name, param in model.named_parameters()
+ }
+ print("Initialized EMA params from current model state")
+ else:
+ model.ema_params = {}
+
+ # Load planner state if it exists
+ if "planner_state" in checkpoint and hasattr(model, 'planner'):
+ model.planner.load_state_dict(checkpoint["planner_state"], strict=False)
+ print("Loaded planner state from checkpoint")
+
+ return model
+ else:
+ # Not a wrapped checkpoint, use default Lightning loading
+ # But we still need to provide required __init__ arguments
+ raise NotImplementedError(
+ "Direct finetuned checkpoints (not wrapped by PeptideFinetuner) are not yet supported. "
+ "Please provide config and args as kwargs."
+ )
+
+ def on_save_checkpoint(self, checkpoint):
+ """Save config and EMA params, including planner state."""
+ # Call parent to save config and base model EMA
+ super().on_save_checkpoint(checkpoint)
+
+ # Explicitly save planner state
+ if hasattr(self, 'planner'):
+ checkpoint["planner_state"] = self.planner.state_dict()
+
+ def on_load_checkpoint(self, checkpoint):
+ """Load config and reinitialize interpolant, including planner."""
+ # For finetuned checkpoints loaded via custom load_from_checkpoint,
+ # config may not be in checkpoint (it's loaded from original checkpoint)
+ if "config" in checkpoint:
+ # Call parent to restore config and interpolant
+ super().on_load_checkpoint(checkpoint)
+ else:
+ # Config already set during __init__ via load_from_checkpoint
+ # Just restore EMA params if they exist
+ if self.use_ema and "ema_params" in checkpoint:
+ self.ema_params = checkpoint["ema_params"]
+
+ # Restore planner state if it exists in checkpoint
+ if hasattr(self, 'planner') and "planner_state" in checkpoint:
+ self.planner.load_state_dict(checkpoint["planner_state"])
+ print("Loaded planner from checkpoint")
+
+ def loss_wdce_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0):
+ r"""
+ Weighted denoising cross entropy loss
+ X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
+
+ log_rnd: [B] — pre-computed importance weights (already softmax-normalized over the full buffer)
+ x: [B, L] (no mask)
+ num_replicates: R, number of replicates of each row in x
+ weight_func: w(lambda) for each sample, 1/lambda by default
+ centering_strength: float, controls how much of the mean is subtracted (DMPO-style)
+ softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights)
+ """
+
+ batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
+
+ batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) # [B]
+ if centering:
+ batch_weights = batch_weights - centering_strength * batch_weights.mean()
+
+ batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)
+
+ lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
+ lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
+
+ t = lamda
+
+ # compute unmasking and insertion loss
+ interpolant_sample = self.interpolant.sample_interpolant(t, batch)
+ unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch)
+
+ prediction: ModelPrediction = self(interpolant_sample.xt, t)
+
+ scale_factor = self.config.interpolant.max_length
+
+ match self.unmask_loss_fn:
+ case "elbo":
+ mask_indices = interpolant_sample.mask_indices
+ unmask_loss_all = torch.zeros_like(unmask_weight) # [B*R, L]
+ unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy(
+ prediction.token_logits[mask_indices],
+ interpolant_sample.unmasked[mask_indices],
+ reduction="none",
+ )
+ unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor # [B*R]
+ case _:
+ raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}")
+
+ match self.insert_loss_fn:
+ case "expectation":
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
+ insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1]
+ insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo(
+ gaps[gaps_mask], prediction.expected_gaps[gaps_mask]
+ )
+ insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R]
+
+ case "distribution":
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
+ insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1]
+ insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy(
+ prediction.length_posterior[gaps_mask], gaps[gaps_mask]
+ )
+ insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R]
+
+ total_loss = unmask_loss + insertion_loss # [B*R]
+ # end compute unmasking and insertion loss
+
+ weighted_loss = total_loss * batch_weights # [B*R]
+ return weighted_loss.mean()
+
+ def one_step_sampler(self, xt, t, pred_rate=None):
+ """
+ Sample one step of unmasking using model predictions.
+
+ Args:
+ xt: Current state [B, L]
+ t: Time [B]
+ pred_rate: Optional pre-computed ModelPrediction. If None, will compute from model.
+
+ Returns:
+ new_xt: Next state [B, L]
+ update_ids: Boolean mask of updated positions [B, L]
+ """
+ mask = self.interpolant.mask_token
+ pad = self.interpolant.pad_token
+ batch_size, L = xt.shape
+ device = xt.device
+ steps = self.args.total_num_steps
+ dt = 1.0 / steps
+ max_length = self.interpolant.max_length
+ # Use actual tensor dimension L instead of max_length to handle replicated batches
+ batch_idx_L = (
+ torch.arange(batch_size, device=device)
+ .view(batch_size, 1)
+ .expand(batch_size, L)
+ )
+ pos_idx_L = (
+ torch.arange(L, device=device)
+ .view(1, L)
+ .expand(batch_size, L)
+ )
+
+ # ——— predict and convert rates ———
+ if pred_rate is None:
+ pred_rate = self(xt, t)
+ pred_rate = self.interpolant.to_actual_rate(xt, pred_rate, t)
+ unmask_rate = pred_rate.unmask_rate # (B, L, V)
+ len_rate = pred_rate.length_rate # (B, L+1)
+
+ # ——— unmask step (Euler) ———
+ mask_pos = (xt == self.interpolant.mask_token).nonzero(as_tuple=True)
+ unmask_rate[xt != mask] = 0
+ unmask_rate[mask_pos + (mask,)] = 0
+ unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
+ trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
+
+ # add "stay" probability
+ _xt = xt.clone()
+ _xt[xt == pad] = mask
+ trans_prob.scatter_add_(
+ 2,
+ _xt.unsqueeze(-1),
+ torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
+ )
+
+ trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
+
+ # Renormalize probabilities to ensure they sum to 1
+ prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
+ # Avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad)
+ mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
+ if mask_has_zero_prob.any():
+ # Create uniform distribution over valid tokens (excluding mask and pad)
+ num_zero_prob = mask_has_zero_prob.sum().item()
+ uniform_prob = torch.zeros((num_zero_prob, trans_prob.shape[-1]), device=device, dtype=trans_prob.dtype)
+ uniform_prob[:, :mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
+ trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
+ else:
+ # Normalize to sum to 1
+ trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
+
+ new_xt = _sample_tokens(trans_prob)
+ new_xt[xt == pad] = pad
+ new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
+
+ # update indices--boolean tensor of shape (B, max_length)
+ # A position is updated if:
+ # 1. The token changed (xt != new_xt)
+ # 2. It's not a pad position
+ # 3. It WAS a mask token that got unmasked (so we check xt == mask, not xt != mask)
+
+ # Debug before fix
+ old_update_ids = (xt != new_xt) & (xt != pad) & (xt != mask)
+
+ # Correct logic: updated positions are where mask tokens were changed
+ update_ids = (xt != new_xt) & (xt != pad)
+
+ if self.insertion_planner is False:
+ return new_xt, update_ids
+
+ # ——— Poisson insertion (tau-leaping) — can insert multiple masks per gap ———
+ ext = torch.poisson(len_rate * dt).long() # (B, L+1)
+ xt_len = xt.ne(pad).sum(dim=1) # (B,)
+ # Use ext.shape[1] to get the actual max_length dimension from the data
+ actual_max_length = ext.shape[1] - 1 # ext is (B, L+1), so L = ext.shape[1] - 1
+ gaps = torch.arange(ext.shape[1], device=device).view(1, -1)
+ ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
+ total_ext = ext.sum(dim=1)
+ valid = xt_len + total_ext <= actual_max_length
+ ext = ext * valid.view(batch_size, 1).long()
+
+ ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
+ new_len = xt_len + total_ext # (B,)
+
+ xt_tmp = torch.full_like(xt, pad)
+ # Create position indices that match xt_tmp's shape
+ pos_idx_for_fill = torch.arange(xt_tmp.shape[1], device=device).view(1, -1).expand(batch_size, -1)
+ mask_fill = pos_idx_for_fill < new_len.view(batch_size, 1)
+ xt_tmp[mask_fill] = mask
+
+ new_pos_orig = pos_idx_L + ext_ex[:, :actual_max_length] # (B, L)
+ orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
+ flat_b = batch_idx_L[orig_mask]
+ flat_p = new_pos_orig[orig_mask]
+ xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
+
+ new_ins_xt = xt_tmp
+
+ # Newly inserted masks: positions that are mask now but weren't before.
+ newly_inserted_masks = (new_ins_xt == mask) & (xt != mask) & (xt != pad)
+
+ update_ins_ids = newly_inserted_masks
+
+ return new_xt, update_ids, new_ins_xt, update_ins_ids
+
+ def loss_planner_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0):
+ r"""
+ Weighted denoising cross entropy loss
+ X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
+
+ log_rnd: [B] — pre-computed importance weights (already softmax-normalized over the full buffer)
+ x: [B, L] (no mask)
+ num_replicates: R, number of replicates of each row in x
+ weight_func: w(lambda) for each sample, 1/lambda by default
+ centering_strength: float, controls how much of the mean is subtracted (DMPO-style)
+ softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights)
+ """
+
+ batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
+ batch_size = batch.shape[0]
+
+ batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) # [B]
+ if centering:
+ batch_weights = batch_weights - centering_strength * batch_weights.mean()
+
+ batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)
+
+ lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
+ lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
+
+ t = lamda
+ scale_factor = self.config.interpolant.max_length
+
+ # compute unmasking and insertion loss
+ interpolant_sample = self.interpolant.sample_interpolant(t, batch)
+ unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch)
+
+ prediction: ModelPrediction = self(interpolant_sample.xt, t)
+
+ with torch.no_grad(): # no need to compute gradient in this step
+ sampler_out = self.one_step_sampler(interpolant_sample.xt, t, prediction)
+ # one_step_sampler returns (xs, update_ids) or (xs, update_ids, new_ins_xt, update_ins_ids)
+ xs, update_ids = sampler_out[0], sampler_out[1]
+
+ # The remasking head scores the freshly-decoded tokens to decide which to
+ # remask, so it reads the POST-unmask state xs (matching inference, which
+ # calls the planner on the decoded new_xt).
+ planner = self.planner(xs, t)
+ remasking_conf = planner["remasking_conf"] # [B*R, L, 1]
+
+ # Compute per-sample loss
+ # IMPORTANT: interpolant_sample.xt has been reordered via st permutation
+ # We need to map back to the original positions to compare with batch
+ st = interpolant_sample.st # [B*R, L] permutation indices
+ batch_reordered = torch.gather(batch, 1, st) # Apply same permutation to ground truth
+
+ binary_label = (xs == batch_reordered).float()
+
+ # Only compute loss on positions that were updated
+ per_token_loss = F.binary_cross_entropy_with_logits(
+ remasking_conf.squeeze(-1), # [B*R, L]
+ binary_label, # [B*R, L]
+ reduction="none" # [B*R, L]
+ )
+
+ per_token_loss = per_token_loss * update_ids.float() # [B*R, L]
+
+ # Mask out non-updated positions and average per sample
+ per_sample_loss = per_token_loss.sum(dim=1) / (update_ids.sum(dim=1).float() + 1e-8) # [B*R]
+
+ # Weight by importance sampling weights
+ weighted_loss = per_sample_loss * batch_weights # [B*R]
+
+ # ——— AUC / label-balance diagnostics (see loss_insert_planner_flexible) ———
+ with torch.no_grad():
+ metrics = {}
+ sel_u = update_ids.bool()
+ if sel_u.any():
+ u_scores = remasking_conf.squeeze(-1)[sel_u]
+ u_labels = binary_label[sel_u]
+ metrics["unmask_auc"] = _binary_auc(u_scores, u_labels)
+ metrics["unmask_label_mean"] = u_labels.mean().item()
+ metrics["unmask_conf_mean"] = torch.sigmoid(u_scores).mean().item()
+ metrics["unmask_n"] = float(sel_u.sum().item())
+ self._last_planner_metrics = metrics
+
+ return weighted_loss.mean()
+
+ def loss_insert_planner_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0):
+ r"""
+ Weighted denoising cross entropy loss
+ X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
+
+ log_rnd: [B] — pre-computed importance weights
+ x: [B, L] (no mask)
+ num_replicates: R, number of replicates of each row in x
+ weight_func: w(lambda) for each sample, 1/lambda by default
+ centering_strength: float, controls how much of the mean is subtracted (DMPO-style)
+ softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights)
+ """
+
+ batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
+ batch_size = batch.shape[0]
+
+ batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) # [B]
+ if centering:
+ batch_weights = batch_weights - centering_strength * batch_weights.mean()
+
+ batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)
+
+ lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
+ lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
+
+ t = lamda
+ scale_factor = self.config.interpolant.max_length
+
+ # compute unmasking and insertion loss
+ # deleted mask: binary tensor [B*R, L] where true tokens in batch were deleted
+ # gap_assignment: [B*R, max_gaps, L] maps x1 positions to gap indices
+ interpolant_sample, deleted_mask, gap_assignment = self.interpolant.sample_interpolant_plan(t, batch)
+ unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch)
+
+ prediction: ModelPrediction = self(interpolant_sample.xt, t)
+
+ with torch.no_grad(): # no need to compute gradient in this step
+ xs_unmask, update_unmask_ids, xs_insert, update_ins_ids = self.one_step_sampler(interpolant_sample.xt, t, prediction)
+
+ # The remasking head scores the freshly-decoded tokens to decide which to
+ # remask, so it must see the POST-unmask state xs_unmask (matching
+ # inference in inference_quality.py, which calls the planner on the
+ # decoded new_xt). Grad stays on here since this head is what we train.
+ planner = self.planner(xs_unmask, t)
+ remasking_conf = planner["remasking_conf"] # [B*R, L, 1]
+
+ # The insertion-quality head scores the freshly-inserted mask tokens, so
+ # it must see the POST-insertion state xs_insert (aligned with
+ # update_ins_ids / insertion_quality below, and matching inference in
+ # remasking_scheduleaware.apply_schedule_aware_insertion). Grad stays on
+ # here since this head is what we are training.
+ if self.planner.insertion_planner:
+ insertion_conf = self.planner(xs_insert, t)["insertion_conf"] # [B*R, L, 1]
+ else:
+ insertion_conf = None
+
+ # Compute per-sample loss
+ # IMPORTANT: interpolant_sample.xt has been reordered via st permutation
+ # We need to map back to the original positions to compare with batch
+ # Use the st (permutation) to get the ground truth in the reordered space
+ st = interpolant_sample.st # [B*R, L] permutation indices
+ batch_reordered = torch.gather(batch, 1, st) # Apply same permutation to ground truth
+
+ # Now compare in the reordered space
+ binary_label = (xs_unmask == batch_reordered).float()
+
+ # Only compute loss on positions that were updated
+ per_token_loss = F.binary_cross_entropy_with_logits(
+ remasking_conf.squeeze(-1), # [B*R, L]
+ binary_label, # [B*R, L]
+ reduction="none" # [B*R, L]
+ )
+
+ per_token_loss = per_token_loss * update_unmask_ids.float() # [B*R, L]
+
+ # Mask out non-updated positions and average per sample
+ unmask_per_sample_loss = per_token_loss.sum(dim=1) / (update_unmask_ids.sum(dim=1).float() + 1e-8) # [B*R]
+
+ # compute insertion planner loss
+ # For positions where masks were inserted, we evaluate the quality of insertion
+ # by computing the probability that the ground truth token would be predicted at that position
+
+ # IMPORTANT: We need to recompute predictions using xs_insert since that's where the masks were inserted
+ # The original prediction was computed from xt (before insertion)
+ with torch.no_grad():
+ prediction_after_insert: ModelPrediction = self(xs_insert, t)
+
+ # Get the token prediction probabilities at inserted mask positions
+ # prediction_after_insert.token_logits: [B*R, L, V] - logits for all positions in xs_insert
+ token_probs = F.softmax(prediction_after_insert.token_logits, dim=-1) # [B*R, L, V]
+
+ # For each gap where masks were inserted, compute the sum of probabilities
+ # of the ground truth tokens that were deleted in that specific gap
+ # gap_assignment: [B*R, max_gaps, L] - maps x1 positions to gap indices
+ # batch: [B*R, L] - ground truth tokens in original space (before permutation)
+
+ vocab_size = token_probs.shape[-1]
+ L = token_probs.shape[1]
+ max_gaps = gap_assignment.shape[1]
+
+ # For each gap, create a vocabulary mask of tokens that belong to that gap
+ # gap_vocab_mask[b, gap_idx, token_id] = 1 if token_id was deleted in gap gap_idx
+ gap_vocab_mask = torch.zeros(batch_size, max_gaps, vocab_size, device=batch.device, dtype=torch.float)
+
+ # Vectorized: gather tokens from batch for all gaps at once
+ # tokens_expanded[b, gap_idx, pos] = batch[b, pos] for all positions
+ tokens_expanded = batch.unsqueeze(1).expand(batch_size, max_gaps, L) # [B*R, max_gaps, L]
+
+ # valid_mask[b, gap_idx, pos] = 1 if position pos belongs to gap gap_idx and is not pad
+ valid_mask = (gap_assignment > 0) & (tokens_expanded != self.interpolant.pad_token) # [B*R, max_gaps, L]
+
+ # Scatter tokens into vocabulary dimension: mark which tokens appear in each gap
+ gap_vocab_mask.scatter_add_(
+ 2, # scatter along vocabulary dimension
+ tokens_expanded.clamp(0, vocab_size - 1), # token indices [B*R, max_gaps, L]
+ valid_mask.float() # values to add [B*R, max_gaps, L]
+ )
+
+ # Binarize: a token either appears in the gap or not
+ gap_vocab_mask = (gap_vocab_mask > 0).float() # [B*R, max_gaps, V]
+
+ # For each insertion position in xs_insert, determine which gap it corresponds to
+ # Position p in xs_insert corresponds to gap p (insertions occur between existing tokens)
+ # Vectorized: compute for all positions at once
+ # token_probs: [B*R, L, V]
+ # gap_vocab_mask[:, :L, :]: [B*R, L, V] - vocab mask for gaps 0 to L-1
+ insertion_quality_full = (token_probs * gap_vocab_mask[:, :L, :]).sum(dim=-1) # [B*R, L]
+
+ # Only consider quality at positions where masks were actually inserted
+ insertion_quality = insertion_quality_full * update_ins_ids.float() # [B*R, L]
+
+ # Compute insertion planner loss only if insertion_planner is enabled
+ if insertion_conf is not None:
+ # The planner predicts insertion confidence with insertion_conf
+ # We want to train it to predict high confidence when insertion_quality is high
+ # Use Bernoulli cross-entropy: treat insertion_quality as the "success probability"
+
+ # Binary cross-entropy with insertion_quality as continuous labels in [0,1]
+ ins_per_token_loss = F.binary_cross_entropy_with_logits(
+ insertion_conf.squeeze(-1), # [B*R, L] - planner's insertion confidence logits
+ insertion_quality, # [B*R, L] - ground truth token probability as quality metric
+ reduction="none"
+ )
+
+ # Only compute loss where masks were actually inserted
+ ins_per_token_loss = ins_per_token_loss * update_ins_ids.float()
+
+ # Average per sample
+ ins_per_sample_loss = ins_per_token_loss.sum(dim=1) / (update_ins_ids.sum(dim=1).float() + 1e-8)
+ else:
+ # No insertion planner - set loss to zero
+ ins_per_sample_loss = torch.zeros_like(unmask_per_sample_loss)
+
+ # Add to total loss
+ per_sample_loss = unmask_per_sample_loss + ins_per_sample_loss
+
+ # Weight by importance sampling weights
+ weighted_loss = per_sample_loss * batch_weights # [B*R]
+
+ # ——— AUC / label-balance diagnostics (the loss alone hides degenerate
+ # targets; near-0 BCE can mean "all labels one class", not "learned") ———
+ with torch.no_grad():
+ metrics = {}
+ sel_u = update_unmask_ids.bool()
+ if sel_u.any():
+ u_scores = remasking_conf.squeeze(-1)[sel_u]
+ u_labels = binary_label[sel_u]
+ metrics["unmask_auc"] = _binary_auc(u_scores, u_labels)
+ metrics["unmask_label_mean"] = u_labels.mean().item()
+ metrics["unmask_conf_mean"] = torch.sigmoid(u_scores).mean().item()
+ metrics["unmask_n"] = float(sel_u.sum().item())
+ if insertion_conf is not None:
+ sel_i = update_ins_ids.bool()
+ if sel_i.any():
+ i_scores = insertion_conf.squeeze(-1)[sel_i]
+ i_targets = insertion_quality[sel_i]
+ i_labels = (i_targets > 0.5).float()
+ metrics["insert_auc"] = _binary_auc(i_scores, i_labels)
+ metrics["insert_target_mean"] = i_targets.mean().item()
+ metrics["insert_conf_mean"] = torch.sigmoid(i_scores).mean().item()
+ metrics["insert_n"] = float(sel_i.sum().item())
+ self._last_planner_metrics = metrics
+
+ return unmask_per_sample_loss.mean(), ins_per_sample_loss.mean(), weighted_loss.mean()
diff --git a/lightning_modules/any_order.py b/lightning_modules/any_order.py
new file mode 100755
index 0000000000000000000000000000000000000000..9c3c8df74f67033e161bc3afd3ca4ad960d6790a
--- /dev/null
+++ b/lightning_modules/any_order.py
@@ -0,0 +1,417 @@
+import torch
+import pytorch_lightning as pl
+from omegaconf import DictConfig
+import torch.nn.functional as F
+from model.transformer import AnyOrderMaskInsertionFlow
+from model.interpolant import AnyOrderMaskInsertionInterpolant, ModelPrediction
+from .bregman import jump_kernel_elbo, mse
+from .schedule import get_schedule_from_config
+
+
+import re
+from typing import Dict, Any
+
+
+def strip_orig_mod_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Returns a new state_dict where any key containing '._orig_mod.' is replaced
+ by removing the '_orig_mod' segment, e.g.
+ 'model._orig_mod.vocab_embed.embedding'
+ becomes
+ 'model.vocab_embed.embedding'
+ """
+ new_state_dict: Dict[str, Any] = {}
+ for key, value in state_dict.items():
+ # remove all occurrences of '._orig_mod.'
+ clean_key = re.sub(r"\._orig_mod\.", ".", key)
+ new_state_dict[clean_key] = value
+ return new_state_dict
+
+
+class AnyOrderInsertionFlowModule(pl.LightningModule):
+ def __init__(self, config: DictConfig):
+ super().__init__()
+ self.config = config
+ self.model_type = config.interpolant.type
+ self.learning_rate = config.training.learning_rate
+ self.unmask_loss_fn = config.training.loss_fn.unmask
+ self.insert_loss_fn = config.training.loss_fn.insert
+
+ # Initialize model based on type
+ self.model = AnyOrderMaskInsertionFlow(config)
+ # self.model = torch.compile(self.model) # Disabled: incompatible with flex_attention nested functions
+
+ insert_schedule = get_schedule_from_config(config.interpolant.insert_schedule)
+ unmask_schedule = get_schedule_from_config(config.interpolant.unmask_schedule)
+
+ # Initialize interpolant
+ self.interpolant = AnyOrderMaskInsertionInterpolant(
+ insertion_schedule=insert_schedule,
+ unmask_schedule=unmask_schedule,
+ vocab_size=config.interpolant.tokens,
+ mask_token=config.interpolant.mask_token,
+ pad_token=config.interpolant.pad_token,
+ max_length=config.interpolant.max_length,
+ )
+
+ # Save hyperparameters
+ self.save_hyperparameters()
+
+ self.ema_decay = config.training.ema_decay or 0.0
+ self.use_ema = self.ema_decay > 0
+ self._orig_params = {}
+
+ def forward(self, x, t, return_features: bool = False):
+ if self.config.training.only_embed_insert:
+ result = self.model(x, self.interpolant.insertion_schedule.at(t), return_features=return_features)
+ else:
+ result = self.model(x, t, return_features=return_features)
+ return result
+
+ def get_hidden_states(self, indices: torch.Tensor, t: torch.Tensor):
+ """Delegate to backbone transformer for RemaskingAnyOrder compatibility."""
+ return self.model.get_hidden_states(indices, t)
+
+ def training_loss(self, x1, t):
+ interpolant_sample = self.interpolant.sample_interpolant(t, x1)
+ unmask_weight, insert_weight = self.interpolant.elbo_weight(t, x1)
+
+ prediction: ModelPrediction = self(interpolant_sample.xt, t)
+
+ scale_factor = x1.shape[0] * self.config.interpolant.max_length
+
+ match self.unmask_loss_fn:
+ case "elbo":
+ mask_indices = interpolant_sample.mask_indices
+ unmask_loss = unmask_weight[mask_indices] * F.cross_entropy(
+ prediction.token_logits[mask_indices],
+ interpolant_sample.unmasked[mask_indices],
+ reduction="none",
+ )
+ unmask_loss = unmask_loss.sum() / scale_factor
+ case _:
+ raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}")
+
+ match self.insert_loss_fn:
+ case "expectation":
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
+ insertion_loss = insert_weight[gaps_mask] * jump_kernel_elbo(
+ gaps[gaps_mask], prediction.expected_gaps[gaps_mask]
+ )
+ insertion_loss = insertion_loss.sum() / scale_factor
+
+ case "distribution":
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
+ insertion_loss = insert_weight[gaps_mask] * F.cross_entropy(
+ prediction.length_posterior[gaps_mask], gaps[gaps_mask]
+ )
+ insertion_loss = insertion_loss.sum() / scale_factor
+
+ total_loss = unmask_loss + insertion_loss
+ return unmask_loss, insertion_loss, total_loss
+
+ def prepare_noised_sample(self, x, num_samples=1, t=None):
+ """
+ Run the forward noising process on clean sequences x.
+ Replicates each sequence num_samples times with independent random times
+ so that both policy and pretrained can evaluate the same noised data.
+
+ Args:
+ x: [B, L] clean token sequences (no mask tokens)
+ num_samples: K, number of noisy time samples per sequence
+ t: [B*K] optional time values. If None, sampled uniformly.
+
+ Returns:
+ dict with all artifacts needed by compute_loss_from_noised.
+ """
+ B = x.shape[0]
+ x_rep = x.repeat_interleave(num_samples, dim=0) # [B*K, L]
+ if t is None:
+ t = torch.rand(B * num_samples, device=x.device)
+
+ interpolant_sample = self.interpolant.sample_interpolant(t, x_rep)
+ unmask_weight, insert_weight = self.interpolant.elbo_weight(t, x_rep)
+ scale_factor = self.config.interpolant.max_length
+
+ return {
+ "interpolant_sample": interpolant_sample,
+ "unmask_weight": unmask_weight,
+ "insert_weight": insert_weight,
+ "t": t,
+ "scale_factor": scale_factor,
+ "num_samples": num_samples,
+ "batch_size": B,
+ }
+
+ def compute_loss_from_noised(self, noised):
+ """
+ Compute per-sample denoising loss given pre-noised data.
+ Each model runs its own forward pass on the shared noised xt.
+
+ Args:
+ noised: dict from prepare_noised_sample()
+
+ Returns:
+ total_loss: [B] per-sample loss averaged over K noisy samples
+ """
+ interpolant_sample = noised["interpolant_sample"]
+ unmask_weight = noised["unmask_weight"]
+ insert_weight = noised["insert_weight"]
+ t = noised["t"]
+ scale_factor = noised["scale_factor"]
+ num_samples = noised["num_samples"]
+ B = noised["batch_size"]
+
+ prediction: ModelPrediction = self(interpolant_sample.xt, t)
+
+ match self.unmask_loss_fn:
+ case "elbo":
+ mask_indices = interpolant_sample.mask_indices
+ unmask_loss_all = torch.zeros_like(unmask_weight) # [B*K, L]
+ unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy(
+ prediction.token_logits[mask_indices],
+ interpolant_sample.unmasked[mask_indices],
+ reduction="none",
+ )
+ unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor # [B*K]
+ case _:
+ raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}")
+
+ match self.insert_loss_fn:
+ case "expectation":
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
+ insertion_loss_all = torch.zeros_like(insert_weight) # [B*K, L+1]
+ insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo(
+ gaps[gaps_mask], prediction.expected_gaps[gaps_mask]
+ )
+ insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*K]
+ case "distribution":
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
+ insertion_loss_all = torch.zeros_like(insert_weight) # [B*K, L+1]
+ insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy(
+ prediction.length_posterior[gaps_mask], gaps[gaps_mask]
+ )
+ insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*K]
+
+ per_replicate_loss = unmask_loss + insertion_loss # [B*K]
+ per_sample_loss = per_replicate_loss.view(B, num_samples).mean(dim=1) # [B]
+ return per_sample_loss
+
+ def loss_wdce_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False):
+ r"""
+ Weighted denoising cross entropy loss
+ X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
+
+ log_rnd: [B]; x: [B, L] (no mask)
+ num_replicates: R, number of replicates of each row in x
+ weight_func: w(lambda) for each sample, 1/lambda by default
+ """
+
+ print("logrnd shape:", log_rnd.shape)
+ print("x shape:", x.shape)
+
+ batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
+
+ batch_weights = log_rnd.detach().softmax(dim=-1) # [B*R]
+ if centering:
+ batch_weights = batch_weights - batch_weights.mean(dim=-1, keepdim=True)
+
+ batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)
+
+ lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
+ lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
+
+ t = lamda
+
+ # compute unmasking and insertion loss
+ interpolant_sample = self.interpolant.sample_interpolant(t, batch)
+ unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch)
+
+ prediction: ModelPrediction = self(interpolant_sample.xt, t)
+
+ scale_factor = self.config.interpolant.max_length
+
+ match self.unmask_loss_fn:
+ case "elbo":
+ mask_indices = interpolant_sample.mask_indices
+ unmask_loss_all = torch.zeros_like(unmask_weight) # [B*R, L]
+ unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy(
+ prediction.token_logits[mask_indices],
+ interpolant_sample.unmasked[mask_indices],
+ reduction="none",
+ )
+ unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor # [B*R]
+ case _:
+ raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}")
+
+ match self.insert_loss_fn:
+ case "expectation":
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
+ insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1]
+ insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo(
+ gaps[gaps_mask], prediction.expected_gaps[gaps_mask]
+ )
+ insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R]
+
+ case "distribution":
+ gaps, gaps_mask = interpolant_sample.gaps_and_mask
+ insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1]
+ insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy(
+ prediction.length_posterior[gaps_mask], gaps[gaps_mask]
+ )
+ insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R]
+
+ total_loss = unmask_loss + insertion_loss # [B*R]
+ # end compute unmasking and insertion loss
+
+ weighted_loss = total_loss * batch_weights # [B*R]
+ return weighted_loss.mean()
+
+ def sample_time(self, batch_size: int, device: torch.device) -> torch.Tensor:
+ eps = 1e-6
+ interval = 1.0 - eps
+ interval_size = interval / batch_size
+ u = torch.rand(batch_size, device=device)
+ return (torch.arange(batch_size, device=device, dtype=u.dtype) + u) * interval_size
+
+ def training_step(self, batch, batch_idx):
+ # Extract input data
+ if isinstance(batch, dict):
+ batch = batch["input_ids"]
+
+ x1 = batch
+ t = self.sample_time(x1.shape[0], x1.device)
+
+ # Calculate the combined loss normally
+ unmask_loss, len_loss, loss = self.training_loss(x1, t)
+
+ # Log component losses
+ self.log("train/unmask_loss", unmask_loss, prog_bar=True)
+ self.log("train/len_loss", len_loss, prog_bar=True)
+ self.log("train/total_loss", loss, prog_bar=True)
+
+
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ if isinstance(batch, dict):
+ batch = batch["input_ids"]
+
+ x1 = batch
+ t = self.sample_time(x1.shape[0], x1.device)
+ unmask_loss, len_loss, loss = self.training_loss(x1, t)
+
+ self.log("val/unmask_loss", unmask_loss, prog_bar=True, sync_dist=True)
+ self.log("val/len_loss", len_loss, prog_bar=True, sync_dist=True)
+ self.log("val_loss", loss, prog_bar=True, sync_dist=True)
+
+ return loss
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.AdamW(
+ self.parameters(),
+ lr=self.learning_rate,
+ weight_decay=self.config.training.weight_decay,
+ )
+
+ warmup_steps = self.config.training.warmup_steps
+ max_steps = self.config.training.max_steps
+
+ # Always create a fresh schedule starting from step 0
+ # This allows extending training beyond original max_steps
+ linear_scheduler = torch.optim.lr_scheduler.LinearLR(
+ optimizer,
+ start_factor=1e-6,
+ end_factor=1.0,
+ total_iters=warmup_steps,
+ last_epoch=-1,
+ )
+ post_warmup = max_steps - warmup_steps
+ cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+ optimizer,
+ T_max=post_warmup,
+ eta_min=0.0,
+ last_epoch=-1,
+ )
+
+ scheduler = torch.optim.lr_scheduler.SequentialLR(
+ optimizer,
+ schedulers=[linear_scheduler, cosine_scheduler],
+ milestones=[warmup_steps],
+ last_epoch=-1,
+ )
+ return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
+
+ def optimizer_step(
+ self,
+ epoch: int,
+ batch_idx: int,
+ optimizer,
+ optimizer_closure=None,
+ ):
+ super().optimizer_step(
+ epoch, batch_idx, optimizer, optimizer_closure=optimizer_closure
+ )
+ # log learning rate and gradient norm
+ lr = optimizer.param_groups[0]["lr"]
+ self.log("train/lr", lr, on_step=True, prog_bar=True)
+ grad_norm = torch.sqrt(
+ sum(p.grad.norm(2) ** 2 for p in self.parameters() if p.grad is not None)
+ )
+ self.log("train/grad_norm", grad_norm, on_step=True, prog_bar=True)
+
+ # update EMA
+ if self.use_ema:
+ for n, p in self.named_parameters():
+ self.ema_params[n].mul_(self.ema_decay).add_(
+ p.data.clone().detach(), alpha=1 - self.ema_decay
+ )
+
+ def on_save_checkpoint(self, checkpoint):
+ checkpoint["config"] = self.config
+ # save EMA state
+ if self.use_ema:
+ checkpoint["ema_params"] = {
+ n: v.clone() for n, v in self.ema_params.items()
+ }
+
+ def on_load_checkpoint(self, checkpoint):
+ self.config = checkpoint["config"]
+
+ insert_schedule = get_schedule_from_config(
+ self.config.interpolant.insert_schedule
+ )
+ unmask_schedule = get_schedule_from_config(
+ self.config.interpolant.unmask_schedule
+ )
+
+ self.interpolant = AnyOrderMaskInsertionInterpolant(
+ insertion_schedule=insert_schedule,
+ unmask_schedule=unmask_schedule,
+ vocab_size=self.config.interpolant.tokens,
+ mask_token=self.config.interpolant.mask_token,
+ pad_token=self.config.interpolant.pad_token,
+ max_length=self.config.interpolant.max_length,
+ )
+
+ self.ema_params = checkpoint["ema_params"] if self.use_ema else {}
+
+ def swap_to_ema(self):
+ for name, p in self.named_parameters():
+ self._orig_params[name] = p.data.clone()
+ p.data.copy_(self.ema_params[name].to(p.device))
+
+ def restore_original(self):
+ for name, p in self.named_parameters():
+ p.data.copy_(self._orig_params[name])
+ self._orig_params.clear()
+
+ def on_train_start(self):
+ # initialize and move EMA buffers once model is on correct device
+ if self.use_ema:
+ self.ema_params = {
+ name: param.clone().detach().to(self.device)
+ for name, param in self.named_parameters()
+ }
+ for buf in self.ema_params.values():
+ buf.requires_grad = False
\ No newline at end of file
diff --git a/lightning_modules/bregman.py b/lightning_modules/bregman.py
new file mode 100755
index 0000000000000000000000000000000000000000..825f26eacb1b1eaab7b52f3e0406923138681ab7
--- /dev/null
+++ b/lightning_modules/bregman.py
@@ -0,0 +1,19 @@
+# A file of bregman divergences
+import torch
+
+
+def mse(x, y):
+ sq_diff = (x - y) ** 2
+ if x.shape != y.shape:
+ assert False, "x and y must have the same shape"
+ return sq_diff.reshape(sq_diff.size(0), -1).sum(dim=-1)
+
+
+# TODO: check if this formulation is correct
+def jump_kernel_elbo(x, y, eps=1e-6):
+ # x_safe: true length
+ # y_safe: predicted length
+ x_safe = torch.clamp(x, min=eps)
+ y_safe = torch.clamp(y, min=eps)
+
+ return y_safe - x_safe + x_safe * (torch.log(x_safe) - torch.log(y_safe))
diff --git a/lightning_modules/mdm.py b/lightning_modules/mdm.py
new file mode 100755
index 0000000000000000000000000000000000000000..3360925d99a3452898d45b3855f7f25d807bc6d0
--- /dev/null
+++ b/lightning_modules/mdm.py
@@ -0,0 +1,184 @@
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from model.MDM_transformer import DDiTNoLengthModel
+from model.interpolant import MDMInterpolant # replaced relative import
+from .schedule import get_schedule_from_config
+
+
+class MaskedDiffusionModule(pl.LightningModule):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.learning_rate = config.training.learning_rate
+
+ # Initialize model (no length head)
+ self.model = DDiTNoLengthModel(config)
+ self.model = torch.compile(self.model)
+
+ unmask_schedule = get_schedule_from_config(config.interpolant.unmask_schedule)
+
+ # Initialize interpolant
+ self.interpolant = MDMInterpolant(
+ unmask_schedule=unmask_schedule,
+ vocab_size=config.interpolant.tokens,
+ mask_token=config.interpolant.mask_token,
+ pad_token=config.interpolant.pad_token,
+ max_length=config.interpolant.max_length,
+ )
+
+ # Save hyperparameters
+ self.save_hyperparameters()
+
+ self.ema_decay = config.training.ema_decay or 0.0
+ self.use_ema = self.ema_decay > 0
+ self._orig_params = {}
+
+ def forward(self, x, t) -> torch.Tensor:
+ return self.model(x, t)
+
+ def training_loss(self, x1, t):
+ # sample interpolant and elbo weight
+
+ interpolant_result = self.interpolant.sample_interpolant(t, x1)
+ unmask_weight = self.interpolant.elbo_weight(t, x1)
+
+ # model prediction
+ predicted_logits = self(interpolant_result.xt, t)
+ mask_indices = interpolant_result.mask_indices
+
+ # compute unmask loss
+ loss = unmask_weight[mask_indices] * F.cross_entropy(
+ predicted_logits[mask_indices],
+ interpolant_result.unmasked[mask_indices],
+ reduction="none",
+ )
+
+ loss = loss.sum() / (x1.shape[0] * self.config.interpolant.max_length)
+ return loss
+
+ def training_step(self, batch, batch_idx):
+ # Extract input data
+ if isinstance(batch, dict):
+ batch = batch["input_ids"]
+
+ x1 = batch
+ batch_size = x1.shape[0]
+ t = torch.rand(batch_size, device=x1.device)
+ loss = self.training_loss(x1, t)
+
+ self.log("train/total_loss", loss, prog_bar=True)
+
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ if isinstance(batch, dict):
+ batch = batch["input_ids"]
+
+ x1 = batch
+ batch_size = x1.shape[0]
+
+ t = torch.rand(batch_size, device=x1.device)
+ loss = self.training_loss(x1, t)
+
+ self.log("val_loss", loss, prog_bar=True)
+
+ return loss
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.AdamW(
+ self.parameters(),
+ lr=self.learning_rate,
+ weight_decay=self.config.training.weight_decay,
+ )
+ warmup_steps = self.config.training.warmup_steps
+ max_steps = self.config.training.max_steps
+
+ linear_scheduler = torch.optim.lr_scheduler.LinearLR(
+ optimizer,
+ start_factor=1e-6,
+ end_factor=1.0,
+ total_iters=warmup_steps,
+ )
+ post_warmup = max_steps - warmup_steps
+ cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+ optimizer,
+ T_0=post_warmup // 10,
+ T_mult=1,
+ eta_min=0.0,
+ )
+
+ scheduler = torch.optim.lr_scheduler.SequentialLR(
+ optimizer,
+ schedulers=[linear_scheduler, cosine_scheduler],
+ milestones=[warmup_steps],
+ )
+ return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
+
+ def optimizer_step(
+ self,
+ epoch: int,
+ batch_idx: int,
+ optimizer,
+ optimizer_closure=None,
+ ):
+ super().optimizer_step(
+ epoch, batch_idx, optimizer, optimizer_closure=optimizer_closure
+ )
+ # log learning rate and gradient norm
+ lr = optimizer.param_groups[0]["lr"]
+ self.log("train/lr", lr, on_step=True, prog_bar=True)
+ grad_norm = torch.sqrt(
+ sum(p.grad.norm(2) ** 2 for p in self.parameters() if p.grad is not None)
+ )
+ self.log("train/grad_norm", grad_norm, on_step=True, prog_bar=True)
+
+ # update EMA
+ if self.use_ema:
+ for n, p in self.named_parameters():
+ self.ema_params[n].mul_(self.ema_decay).add_(
+ p.data.clone().detach(), alpha=1 - self.ema_decay
+ )
+
+ def on_save_checkpoint(self, checkpoint):
+ checkpoint["config"] = self.config
+ # save EMA state
+ if self.use_ema:
+ checkpoint["ema_params"] = {n: v.cpu() for n, v in self.ema_params.items()}
+
+ def on_load_checkpoint(self, checkpoint):
+ self.config = checkpoint["config"]
+
+ unmask_schedule = get_schedule_from_config(
+ self.config.interpolant.unmask_schedule
+ )
+
+ self.interpolant = MDMInterpolant(
+ unmask_schedule=unmask_schedule,
+ vocab_size=self.config.interpolant.tokens,
+ mask_token=self.config.interpolant.mask_token,
+ pad_token=self.config.interpolant.pad_token,
+ max_length=self.config.interpolant.max_length,
+ )
+
+ self.ema_params = checkpoint["ema_params"] if self.use_ema else {}
+
+ def swap_to_ema(self):
+ for name, p in self.named_parameters():
+ self._orig_params[name] = p.data.clone()
+ p.data.copy_(self.ema_params[name].to(p.device))
+
+ def restore_original(self):
+ for name, p in self.named_parameters():
+ p.data.copy_(self._orig_params[name])
+ self._orig_params.clear()
+
+ def on_train_start(self):
+ # initialize and move EMA buffers once model is on correct device
+ if self.use_ema:
+ self.ema_params = {
+ name: param.clone().detach().to(self.device)
+ for name, param in self.named_parameters()
+ }
+ for buf in self.ema_params.values():
+ buf.requires_grad = False
diff --git a/lightning_modules/schedule.py b/lightning_modules/schedule.py
new file mode 100755
index 0000000000000000000000000000000000000000..38e27fd72d24cbba5b95c1d1f63fa0a2b7e2c48b
--- /dev/null
+++ b/lightning_modules/schedule.py
@@ -0,0 +1,156 @@
+import abc
+from omegaconf import DictConfig
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+
+def get_schedule_from_config(config: DictConfig):
+ match config.type:
+ case "geometric":
+ return GeometricSchedule(min_val=config.min, max_val=config.max)
+ case "linear":
+ return LinearSchedule()
+ case "sin":
+ return SinSchedule()
+ case "cosine":
+ return CosineSchedule()
+ case "polynomial":
+ return PolynomialSchedule(exp=config.exp)
+ case _:
+ raise ValueError(f"Invalid schedule type: {config.type}")
+
+
+class Schedule(abc.ABC):
+ """
+ Generic schedule class for masking or noising
+ This represents function a : [0, 1] -> [0, 1] satisfying a(0) = 0, a(1) = 1 or at least approximately
+ """
+
+ @abc.abstractmethod
+ def at(self, t: Tensor):
+ """
+ Return value a(t)
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def derivative_at(self, t: Tensor):
+ """
+ Return d/dt a(t)
+ """
+ raise NotImplementedError
+
+ def rate_scale_factor(self, t: Tensor) -> Tensor:
+ """
+ Return d/dt a(t) / (1 - a(t)) common in rate matrix calculation
+ """
+ return self.derivative_at(t) / (1 - self.at(t))
+
+ def sample(self, shape, device) -> Tensor:
+ """
+ Sample from the schedule, returns a tensor of shape `shape` with values in [0, 1]
+ """
+ uniform = torch.rand(shape, device=device)
+ return self.inv(uniform)
+
+ def sample_truncated(self, threshold, shape, device) -> Tensor:
+ """
+ Sample from a truncated schedule, returns a tensor of shape `shape` with values in [threshold, 1]
+ """
+ uniform = torch.rand(shape, device=device)
+ threshold = self.at(threshold)
+ return self.inv(uniform * (1 - threshold) + threshold)
+
+ @abc.abstractmethod
+ def inv(self, alpha: Tensor):
+ """
+ Given alpha in [0, 1] such that a(t)=alpha, returns the corresponding t.
+ """
+ raise NotImplementedError
+
+
+class LinearSchedule(Schedule):
+ def __init__(self):
+ pass
+
+ def at(self, t: Tensor):
+ return t
+
+ def derivative_at(self, t: Tensor):
+ return torch.ones_like(t, device=t.device)
+
+ def inv(self, alpha: Tensor):
+ return alpha
+
+
+class GeometricSchedule(Schedule, nn.Module):
+ def __init__(self, min_val: float, max_val: float):
+ super().__init__()
+ self.register_buffer("min", Tensor([min_val]))
+ self.register_buffer("max", Tensor([max_val]))
+
+ def at(self, t: Tensor):
+ min_val = self.min.to(t.device)
+ max_val = self.max.to(t.device)
+ return torch.exp(-(min_val ** (1 - t)) * max_val**t)
+
+ def derivative_at(self, t):
+ min_val = self.min.to(t.device)
+ max_val = self.max.to(t.device)
+ return (
+ self.at(t)
+ * min_val ** (1 - t)
+ * max_val**t
+ * (min_val.log() - max_val.log())
+ )
+
+ def inv(self, alpha: Tensor):
+ log_min = self.min.to(alpha.device).log()
+ log_max = self.max.to(alpha.device).log()
+ return (torch.log(-torch.log(alpha)) - log_min) / (log_max - log_min)
+
+
+class SinSchedule(Schedule, nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def at(self, t: Tensor):
+ return torch.sin(torch.pi / 2 * t)
+
+ def derivative_at(self, t: Tensor):
+ return (torch.pi / 2) * torch.cos(torch.pi / 2 * t)
+
+ def inv(self, alpha: Tensor):
+ return (2 / torch.pi) * torch.asin(alpha.clamp(min=0., max=1.))
+
+
+class CosineSchedule(Schedule, nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def at(self, t: Tensor):
+ return 1 - torch.cos(torch.pi / 2 * t)
+
+ def derivative_at(self, t: Tensor):
+ return (torch.pi / 2) * torch.sin(torch.pi / 2 * t)
+
+ def rate_scale_factor(self, t):
+ return (torch.pi/2) * torch.tan(torch.pi / 2 * t)
+
+ def inv(self, alpha):
+ return (2 / torch.pi) * torch.arccos(1 - alpha.clamp(min=0., max=1.))
+
+class PolynomialSchedule(Schedule, nn.Module):
+ def __init__(self, exp):
+ super().__init__()
+ self.exp = exp
+
+ def at(self, t: Tensor):
+ return t ** self.exp
+
+ def derivative_at(self, t: Tensor):
+ return self.exp * t ** (self.exp - 1)
+
+ def inv(self, alpha: Tensor):
+ return alpha ** (1 / self.exp)
\ No newline at end of file
diff --git a/model/MDM_transformer.py b/model/MDM_transformer.py
new file mode 100755
index 0000000000000000000000000000000000000000..1728fb66bea910df6c3e00b85329249271e3d4ba
--- /dev/null
+++ b/model/MDM_transformer.py
@@ -0,0 +1,75 @@
+import torch
+import torch.nn.functional as F
+from . import rotary
+from .transformer import EmbeddingLayer, TimestepEmbedder, DDiTBlock, DDitFinalLayer
+from omegaconf import OmegaConf
+from torch.nn.attention.flex_attention import create_block_mask
+
+
+def _dense_mask(b, h, q_idx, kv_idx):
+ return torch.full_like(q_idx, True, dtype=torch.bool)
+
+
+class DDiTNoLengthModel(torch.nn.Module):
+ """
+ A DDiT‐style model that predicts only per‐token posteriors,
+ without any sequence‐length head, opt for the vanilla MDM
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # allowing dict configs too
+ if isinstance(config, dict):
+ config = OmegaConf.create(config)
+
+ self.config = config
+ self.vocab_size = config.interpolant.tokens
+ self.pad_token = config.interpolant.pad_token
+ self.mask_token = config.interpolant.mask_token
+
+ self.vocab_embed = EmbeddingLayer(config.model.hidden_size, self.vocab_size)
+ self.sigma_map = TimestepEmbedder(config.model.cond_dim)
+ self.rotary_emb = rotary.Rotary(
+ config.model.hidden_size // config.model.n_heads
+ )
+
+ self.blocks = torch.nn.ModuleList(
+ [
+ DDiTBlock(
+ config.model.hidden_size,
+ config.model.n_heads,
+ config.model.cond_dim,
+ dropout=config.model.dropout,
+ )
+ for _ in range(config.model.n_blocks)
+ ]
+ )
+ # final per‐token head only / no length head
+ self.output_layer = DDitFinalLayer(
+ config.model.hidden_size, self.vocab_size, config.model.cond_dim
+ )
+
+ def forward(self, indices: torch.Tensor, t: torch.Tensor):
+ """
+ indices: (B, L) token indices
+ t: (B,) timestep scalars
+ returns: ReparametrizedRate with only per_token_posterior set
+ """
+ B, L = indices.shape
+
+ block_mask = create_block_mask(
+ _dense_mask, B=B, H=None, Q_LEN=indices.shape[1], KV_LEN=indices.shape[1]
+ )
+ print(block_mask)
+
+ x = self.vocab_embed(indices) # (B, L, hidden)
+ c = F.silu(self.sigma_map(t)) # (B, cond_dim)
+ rotary_cos_sin = self.rotary_emb(x) # precompute rotary embeddings
+
+ # run the stack
+ with torch.amp.autocast("cuda", dtype=torch.bfloat16):
+ for i in range(len(self.blocks)):
+ x = self.blocks[i](x, rotary_cos_sin, c, block_mask)
+
+ token_logits = self.output_layer(x, c)
+ return token_logits
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/casual_transformer.py b/model/casual_transformer.py
new file mode 100755
index 0000000000000000000000000000000000000000..fd6aaba43ad4eb11ee0bd3a3d7c525d99de9bd60
--- /dev/null
+++ b/model/casual_transformer.py
@@ -0,0 +1,98 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+from .fused_add_dropout_scale import modulate_fused, bias_dropout_add_scale_fused_train, bias_dropout_add_scale_fused_inference
+from .transformer import LayerNorm, EmbeddingLayer
+from . import rotary
+
+
+class CausalDiTBlock(nn.Module):
+ def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
+ super().__init__()
+ self.n_heads = n_heads
+ self.norm1 = LayerNorm(dim)
+ self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
+ self.attn_out = nn.Linear(dim, dim, bias=False)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm2 = LayerNorm(dim)
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, mlp_ratio * dim, bias=True),
+ nn.GELU(approximate="tanh"),
+ nn.Linear(mlp_ratio * dim, dim, bias=True)
+ )
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout = dropout
+ # No time or label conditioning, so no adaLN_modulation
+
+ def _get_bias_dropout_scale(self):
+ return (
+ bias_dropout_add_scale_fused_train
+ if self.training
+ else bias_dropout_add_scale_fused_inference
+ )
+
+ def forward(self, x, rotary_cos_sin, seqlens=None):
+ batch_size, seq_len = x.shape[0], x.shape[1]
+ bias_dropout_scale_fn = self._get_bias_dropout_scale()
+
+ # attention operation
+ x_skip = x
+ x = self.norm1(x)
+ # dtype0 = x.dtype
+
+ qkv = self.attn_qkv(x)
+ qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.n_heads)
+ with torch.cuda.amp.autocast(enabled=False):
+ cos, sin = rotary_cos_sin
+ qkv = rotary.apply_rotary_pos_emb(
+ qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)
+ )
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+ if seqlens is None:
+ cu_seqlens = torch.arange(
+ 0, (batch_size + 1) * seq_len, step=seq_len,
+ dtype=torch.int32, device=qkv.device
+ )
+ else:
+ cu_seqlens = seqlens.cumsum(-1)
+ x = flash_attn_varlen_qkvpacked_func(
+ qkv, cu_seqlens, seq_len, 0., causal=True)
+ x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size)
+
+ scale = torch.ones(1, device=x.device, dtype=x.dtype)
+ x = bias_dropout_scale_fn(self.attn_out(x), None, scale, x_skip, self.dropout)
+
+ # mlp operation
+ x = bias_dropout_scale_fn(
+ self.mlp(self.norm2(x)), None, scale, x, self.dropout
+ )
+ return x
+
+class CausalDiT(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ if isinstance(config, dict):
+ config = OmegaConf.create(config)
+
+ self.config = config
+ self.vocab_size = config.interpolant.tokens
+ self.pad_token = config.interpolant.pad_token
+
+ self.vocab_embed = EmbeddingLayer(config.model.hidden_size, self.vocab_size)
+ self.rotary_emb = rotary.Rotary(config.model.hidden_size // config.model.n_heads)
+ self.blocks = nn.ModuleList([
+ CausalDiTBlock(config.model.hidden_size, config.model.n_heads, config.model.cond_dim, dropout=config.model.dropout)
+ for _ in range(config.model.n_blocks)
+ ])
+ self.output_layer = nn.Linear(config.model.hidden_size, self.vocab_size)
+
+ def forward(self, indices):
+ x = self.vocab_embed(indices)
+ rotary_cos_sin = self.rotary_emb(x)
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
+ for block in self.blocks:
+ x = block(x, rotary_cos_sin, seqlens=None)
+ logits = self.output_layer(x)
+ return logits
\ No newline at end of file
diff --git a/model/fused_add_dropout_scale.py b/model/fused_add_dropout_scale.py
new file mode 100755
index 0000000000000000000000000000000000000000..56266ada1b469d84225e8db3eb617c05efc06cc8
--- /dev/null
+++ b/model/fused_add_dropout_scale.py
@@ -0,0 +1,66 @@
+import torch
+import torch.nn.functional as F
+from typing import Optional
+from torch import Tensor
+
+# flags required to enable jit fusion kernels
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_override_can_fuse_on_cpu(True)
+torch._C._jit_override_can_fuse_on_gpu(True)
+
+
+def bias_dropout_add_scale(
+ x: Tensor,
+ bias: Optional[Tensor],
+ scale: Tensor,
+ residual: Optional[Tensor],
+ prob: float,
+ training: bool,
+) -> Tensor:
+ if bias is not None:
+ out = scale * F.dropout(x + bias, p=prob, training=training)
+ else:
+ out = scale * F.dropout(x, p=prob, training=training)
+
+ if residual is not None:
+ out = residual + out
+ return out
+
+
+def get_bias_dropout_add_scale(training):
+ def _bias_dropout_add(x, bias, scale, residual, prob):
+ return bias_dropout_add_scale(x, bias, scale, residual, prob, training)
+
+ return _bias_dropout_add
+
+
+def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
+ return x * (1 + scale) + shift
+
+
+@torch.jit.script
+def bias_dropout_add_scale_fused_train(
+ x: Tensor,
+ bias: Optional[Tensor],
+ scale: Tensor,
+ residual: Optional[Tensor],
+ prob: float,
+) -> Tensor:
+ return bias_dropout_add_scale(x, bias, scale, residual, prob, True)
+
+
+@torch.jit.script
+def bias_dropout_add_scale_fused_inference(
+ x: Tensor,
+ bias: Optional[Tensor],
+ scale: Tensor,
+ residual: Optional[Tensor],
+ prob: float,
+) -> Tensor:
+ return bias_dropout_add_scale(x, bias, scale, residual, prob, False)
+
+
+@torch.jit.script
+def modulate_fused(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
+ return modulate(x, shift, scale)
diff --git a/model/interpolant.py b/model/interpolant.py
new file mode 100755
index 0000000000000000000000000000000000000000..a461bfdecd4eba425f17368e0893a092b8075826
--- /dev/null
+++ b/model/interpolant.py
@@ -0,0 +1,411 @@
+import abc
+from typing import Optional
+import torch
+from torch import Tensor
+from dataclasses import dataclass
+from .schedule import Schedule
+import torch.nn.functional as F
+
+
+@dataclass
+class ModelPrediction:
+ token_logits: Tensor
+ length_posterior: Optional[Tensor]
+ expected_gaps: Tensor
+
+ def __init__(
+ self,
+ token_logits: Tensor,
+ length_posterior: Optional[Tensor] = None,
+ expected_gaps: Optional[Tensor] = None,
+ ):
+ assert length_posterior is not None or expected_gaps is not None
+ self.token_logits = token_logits
+ self.length_posterior = length_posterior
+ self.expected_gaps = expected_gaps
+ if self.expected_gaps is None:
+ _, _, L = self.length_posterior.shape
+ index = torch.arange(0, L, device=token_logits.device).view(1, 1, -1)
+ self.expected_gaps = (F.softmax(self.length_posterior, dim=-1) * index).sum(dim=-1)
+
+
+@dataclass
+class Rate:
+ unmask_rate: Tensor # Shape [Batch, Length, Vocab]
+ length_rate: Tensor # Shape [Batch]
+
+
+@dataclass
+class HittingTime:
+ insertion_time: Tensor # Shape [Batch, Length]
+ unmasking_time: Tensor # Shape [Batch, Length]
+
+ def __iter__(self):
+ yield from [self.insertion_time, self.unmasking_time]
+
+
+@dataclass
+class JointInterpolantResult:
+ # Joint Interpolant
+ xt: Tensor # Shape [Batch, Length]
+ st: Tensor # Shape [Batch, Length]
+ _x1: Tensor
+ _pad_token: int
+ _mask_token: int
+
+ @property
+ def mask_indices(self) -> Tensor:
+ return self.xt == self._mask_token
+
+ @property
+ def unmasked(self) -> Tensor:
+ return torch.gather(self._x1, 1, self.st)
+
+ @property
+ def xt_length(self) -> Tensor:
+ # Calculate length of xt
+ return (self.xt != self._pad_token).sum(dim=1)
+
+ @property
+ def x1_length(self) -> Tensor:
+ # Calculate length of x1
+ return (self._x1 != self._pad_token).sum(dim=1)
+
+ @property
+ def gaps_and_mask(self) -> tuple[Tensor, Tensor]:
+ x1_len = self.x1_length
+ gaps = self.st.clone()
+
+ pad_front = gaps.new_zeros((gaps.shape[0], 1)) - 1 # -1 for the front padding
+ pad_back = gaps.new_zeros((gaps.shape[0], 1))
+ gaps = torch.cat([pad_front, gaps, pad_back], dim=1) # Add a leading zero
+
+ gaps.scatter_(
+ 1, self.xt_length.unsqueeze(1) + 1, x1_len.unsqueeze(1)
+ ) # Fill the last position with x1_len
+
+ gaps = gaps[:, 1:] - gaps[:, :-1] - 1
+ gaps = torch.clamp(gaps, min=0)
+
+ idx = torch.arange(gaps.size(1), device=self.xt.device).unsqueeze(
+ 0
+ ) # shape [1, max_gap]
+ mask = idx <= self.xt_length.unsqueeze(1)
+ gaps[~mask] = 0
+
+ return gaps, mask
+
+
+class JointInterpolant(abc.ABC):
+ def __init__(
+ self,
+ vocab_size: int,
+ mask_token: int,
+ pad_token: int,
+ max_length: int,
+ ):
+ """
+ TODO: Add knobs
+ """
+ self.mask_token = mask_token
+ self.pad_token = pad_token
+ self.max_length = max_length
+ self.vocab_size = vocab_size
+
+ @abc.abstractmethod
+ def elbo_weight(self, t: Tensor, x1: Tensor):
+ """
+ Return the ELBO weight for the training, can be changed depends on the empirical results
+ Shape:
+ t: [B]
+ Returns:
+ weight_unmask: [B, L]
+ weight_delete: [B, L+1]
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def to_actual_rate(self, prediction: ModelPrediction, t: Tensor) -> Rate:
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def sample_interpolant(self, t: Tensor, x1: Tensor) -> JointInterpolantResult:
+ """
+ Sample the interpolant xt from x1 at time t
+ Shapes:
+ x1: [B, L]
+ t: [B]
+ Returns:
+ xt: [B, L]
+ st: [B, L] boolean mask of positions that corresponds to xt
+ xt_mask_indices: [B, L] boolean mask of positions that are masked at xt
+ x1_remained: [B, L] tokens that are not deleted, used for the training target
+ gap_counts: [B, L+1] the number of deleted tokens between xt slots
+ """
+ raise NotImplementedError
+
+
+class AnyOrderMaskInsertionInterpolant(JointInterpolant):
+ def __init__(
+ self,
+ insertion_schedule: Schedule,
+ unmask_schedule: Schedule,
+ vocab_size: int,
+ mask_token: int,
+ pad_token: int,
+ max_length: int,
+ ):
+ super().__init__(vocab_size, mask_token, pad_token, max_length)
+ self.insertion_schedule = insertion_schedule
+ self.unmask_schedule = unmask_schedule
+ #self.max_length = 500
+
+ def expected_mask_fraction(self, t: Tensor, xt: Tensor) -> Tensor:
+ """
+ Compute the expected fraction of tokens that should be masked at time t.
+ For AnyOrderMaskInsertionInterpolant, tokens are:
+ - Deleted (pad) if t < insertion_time
+ - Masked if insertion_time <= t < unmasking_time
+ - Unmasked if t >= unmasking_time
+
+ We approximate: E[fraction masked] ≈ max(0, insertion_schedule.at(t) - unmask_schedule.at(t))
+
+ Args:
+ t: [B] current time
+ xt: [B, L] current sequence (to get current length)
+ Returns:
+ [B] expected number of masked tokens per sequence
+ """
+ # Get schedule values at time t
+ insertion_progress = self.insertion_schedule.at(t) # [B]
+ unmask_progress = self.unmask_schedule.at(t) # [B]
+
+ # Expected fraction of tokens that are inserted but not yet unmasked
+ # Clamp to ensure non-negative
+ expected_mask_frac = torch.clamp(insertion_progress - unmask_progress, min=0.0, max=1.0)
+
+ # Get current sequence length (non-pad tokens)
+ current_length = (xt != self.pad_token).sum(dim=1).float() # [B]
+
+ # Expected number of masked tokens
+ expected_num_masked = expected_mask_frac * current_length # [B]
+
+ return expected_num_masked
+
+ def hitting_time(self, t: Tensor, x1: Tensor) -> tuple[Tensor, Tensor]:
+ """
+ t1 is sampled from a uniform distribution over [0, 1]. when t1 < self.mask_schedule.at(t)
+ t2 is sampled from a uniform distribution over [t1, 1]
+ """
+ B, L = x1.shape
+ eps = 1e-6
+
+ insert_time = self.insertion_schedule.sample((B, L), device=x1.device)
+ insert_time = eps + (1 - eps) * insert_time # ensure t1 is not 0
+ unmask_time = self.unmask_schedule.sample_truncated(
+ insert_time, (B, L), device=x1.device
+ )
+
+ return insert_time, unmask_time
+
+ def elbo_weight(self, t: Tensor, x1: Tensor):
+ """
+ Return the ELBO weight for the training, can be changed depends on the empirical results
+ """
+ insert_weight = self.insertion_schedule.rate_scale_factor(t)
+ insert_weight = insert_weight[:, None].expand(-1, x1.shape[1] + 1)
+
+ unmask_weight = self.unmask_schedule.rate_scale_factor(t)
+ unmask_weight = unmask_weight.unsqueeze(1).expand(-1, x1.shape[1])
+
+ return unmask_weight, insert_weight
+
+ def to_actual_rate(
+ self, xt: Tensor, prediction: ModelPrediction, t: Tensor
+ ) -> Rate:
+ """
+ Return the actual rate for the sampling
+ Args:
+ xt: [B, L] the sampled tokens
+ prediction: ModelPrediction object containing token_posterior and expected_gaps
+ t: [B] the time parameter
+ """
+ token_posterior = F.softmax(prediction.token_logits, dim=-1) # (B, L, V)
+ unmask_rate = token_posterior * self.unmask_schedule.rate_scale_factor(t).view(
+ -1, 1, 1
+ )
+
+ length_rate = (
+ prediction.expected_gaps
+ * self.insertion_schedule.rate_scale_factor(t).view(-1, 1)
+ )
+ #print("expected_gaps:", prediction.expected_gaps, "length_rate:", length_rate)
+
+ return Rate(
+ unmask_rate=unmask_rate, # (B, L, V)
+ length_rate=length_rate, # (B, L+1)
+ )
+
+ def sample_interpolant(self, t: Tensor, x1: Tensor) -> JointInterpolantResult:
+ """
+ Shapes:
+ x1: [B, L]
+ t: [B]
+ Returns:
+ xt: [B, L]
+ st: [B, L] boolean mask of positions that corresponds to xt
+ xt_mask_indices: [B, L] boolean mask of positions that are masked at xt
+ x1_remained: [B, L] tokens that are not deleted, used for the training target
+ gap_counts: [B, L+1] the number of deleted tokens between xt slots
+ """
+ # sample the stopping time (B, L, 2)
+ insertion_time, unmasking_time = self.hitting_time(t, x1)
+
+ clean_tokens = x1.ne(self.pad_token)
+ deleted_tokens = clean_tokens & (t[:, None] < insertion_time)
+ masked_tokens = (
+ clean_tokens
+ & (t[:, None] >= insertion_time)
+ & (t[:, None] < unmasking_time)
+ )
+
+ xt = torch.where(
+ deleted_tokens,
+ self.pad_token, # for deletion, change to pad token
+ torch.where(
+ masked_tokens,
+ self.mask_token, # for masking, change to mask token
+ x1,
+ ),
+ )
+
+ st = xt.ne(self.pad_token).to(torch.int32).argsort(dim=1, descending=True, stable=True) # edited to sort integers
+ xt = torch.gather(xt, 1, st)
+ st[xt == self.pad_token] = 0
+
+ return JointInterpolantResult(
+ xt=xt, st=st, _x1=x1, _pad_token=self.pad_token, _mask_token=self.mask_token
+ )
+
+ def sample_interpolant_plan(self, t: Tensor, x1: Tensor) -> JointInterpolantResult:
+ """
+ Shapes:
+ x1: [B, L]
+ t: [B]
+ Returns:
+ xt: [B, L]
+ st: [B, L] boolean mask of positions that corresponds to xt
+ xt_mask_indices: [B, L] boolean mask of positions that are masked at xt
+ x1_remained: [B, L] tokens that are not deleted, used for the training target
+ gap_counts: [B, L+1] the number of deleted tokens between xt slots
+ """
+ # sample the stopping time (B, L, 2)
+ insertion_time, unmasking_time = self.hitting_time(t, x1)
+
+ clean_tokens = x1.ne(self.pad_token)
+ deleted_tokens = clean_tokens & (t[:, None] < insertion_time)
+ masked_tokens = (
+ clean_tokens
+ & (t[:, None] >= insertion_time)
+ & (t[:, None] < unmasking_time)
+ )
+
+ xt = torch.where(
+ deleted_tokens,
+ self.pad_token, # for deletion, change to pad token
+ torch.where(
+ masked_tokens,
+ self.mask_token, # for masking, change to mask token
+ x1,
+ ),
+ )
+ st = xt.ne(self.pad_token).to(torch.int32).argsort(dim=1, descending=True, stable=True) # edited to sort integers
+ xt = torch.gather(xt, 1, st)
+ st[xt == self.pad_token] = 0
+ num_gaps = (st != 0).sum(dim=1) + 1 # [B]
+
+ deleted_mask = deleted_tokens # [B, L]
+
+ # Create gap assignment tensor: gap_assignment[b, gap_idx, x1_pos] = 1 if x1_pos is in gap gap_idx
+ B, L = x1.shape
+ max_gaps = L + 1
+ gap_assignment = torch.zeros(B, max_gaps, L, device=x1.device, dtype=torch.float)
+
+ # For each deleted position in x1, determine which gap it belongs to
+ # Gap index = number of non-deleted positions (st values) that come before it
+ pos_indices = torch.arange(L, device=x1.device).view(1, L, 1) # [1, L, 1]
+ st_expanded = st.unsqueeze(1) # [B, 1, L]
+ st_valid_mask = (st != 0).unsqueeze(1) # [B, 1, L]
+
+ # Count how many valid st entries are less than each position
+ # gap_indices[b, pos] = number of st values < pos for deleted positions
+ gap_indices = ((st_expanded < pos_indices) & st_valid_mask).sum(dim=2) # [B, L]
+
+ # Set gap_assignment[b, gap_idx, pos] = 1 where pos is deleted and belongs to gap_idx
+ batch_idx = torch.arange(B, device=x1.device).view(B, 1).expand(B, L)
+ pos_idx = torch.arange(L, device=x1.device).view(1, L).expand(B, L)
+
+ gap_assignment[batch_idx[deleted_mask], gap_indices[deleted_mask], pos_idx[deleted_mask]] = 1.0
+
+ return JointInterpolantResult(
+ xt=xt, st=st, _x1=x1, _pad_token=self.pad_token, _mask_token=self.mask_token
+ ), deleted_mask, gap_assignment
+
+
+class MDMInterpolant(JointInterpolant):
+ def __init__(
+ self,
+ unmask_schedule: Schedule,
+ vocab_size: int,
+ mask_token: int,
+ pad_token: int,
+ max_length: int,
+ ):
+ super().__init__(vocab_size, mask_token, pad_token, max_length)
+ self.unmask_schedule = unmask_schedule
+
+ def elbo_weight(self, t: Tensor, x1: Tensor):
+ """
+ Return the ELBO weight for the training, can be changed depends on the empirical results
+ there's no weight_delete for the vanilla MDM
+ """
+ weight_unmask = self.unmask_schedule.rate_scale_factor(t)
+ weight_unmask_expanded = weight_unmask.unsqueeze(1).expand(
+ -1, x1.shape[1]
+ ) # (B,L)
+ return weight_unmask_expanded
+
+ def to_actual_rate(self, xt: Tensor, prediction: Tensor, t: Tensor) -> Rate:
+ """
+ Return the actual rate for the sampling
+ """
+ token_posterior = F.softmax(prediction, dim=-1) # (B, L, V)
+ unmask_rate = token_posterior * self.unmask_schedule.rate_scale_factor(t).view(
+ -1, 1, 1
+ )
+
+ return Rate(
+ unmask_rate=unmask_rate, # (B, L, V)
+ length_rate=None, # (B, L+1)
+ )
+
+ def sample_interpolant(self, t: Tensor, x1: Tensor) -> JointInterpolantResult:
+ # sample the stopping time (B, L, 2)
+ eps = 1e-6
+ unmask_time = self.unmask_schedule.sample(
+ (x1.shape[0], x1.shape[1]), device=x1.device
+ )
+ unmask_time = unmask_time * (1 - eps) + eps
+
+ xt = torch.where(
+ t[:, None] < unmask_time,
+ self.mask_token, # for masking, change to mask token
+ x1,
+ )
+ st = torch.arange(xt.shape[1], device=xt.device, dtype=torch.long).repeat(
+ xt.shape[0], 1
+ )
+
+ return JointInterpolantResult(
+ xt=xt, st=st, _x1=x1, _pad_token=self.pad_token, _mask_token=self.mask_token
+ )
diff --git a/model/model_wrapper.py b/model/model_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..0993746aa08fede9e06596c763e5bb1e5b426c46
--- /dev/null
+++ b/model/model_wrapper.py
@@ -0,0 +1,100 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# ------------------------------------------------------------
+# additional sigmoid head
+# ------------------------------------------------------------
+class RemaskingHead(nn.Module):
+ def __init__(self, hidden_size: int):
+ super().__init__()
+ self.norm = nn.LayerNorm(hidden_size)
+ self.proj1 = nn.Linear(hidden_size, hidden_size)
+ self.act = nn.GELU()
+ self.proj2 = nn.Linear(hidden_size, 1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.norm(x)
+ h = self.proj1(h)
+ h = self.act(h)
+ h = self.proj2(h)
+ return h
+
+
+class InsertionQualityHead(nn.Module):
+ def __init__(self, hidden_size: int):
+ super().__init__()
+ self.norm = nn.LayerNorm(hidden_size)
+ self.proj1 = nn.Linear(hidden_size, hidden_size)
+ self.act = nn.GELU()
+ self.proj2 = nn.Linear(hidden_size, 1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.norm(x)
+ h = self.proj1(h)
+ h = self.act(h)
+ h = self.proj2(h)
+ return h
+
+
+class RemaskingAnyOrder(nn.Module):
+ """Remasking adapter for AnyOrderMaskInsertionFlow models."""
+ def __init__(self, backbone: nn.Module, d_model: int, insertion_planner: bool = False):
+ super().__init__()
+
+ # Store backbone as non-module attribute to avoid circular reference in module tree
+ # Use object.__setattr__ to bypass nn.Module's __setattr__ which registers modules
+ object.__setattr__(self, 'backbone', backbone)
+ self.d_model = d_model
+ self.insertion_planner = insertion_planner
+ self.remasking_head = RemaskingHead(d_model)
+
+ if insertion_planner:
+ self.insertion_head = InsertionQualityHead(d_model)
+
+ def forward(self, indices: torch.Tensor, t: torch.Tensor, **kwargs):
+ """
+ Forward pass for remasking training.
+
+ Args:
+ indices: Token indices [batch_size, seq_len]
+ t: Timesteps [batch_size]
+ **kwargs: Additional arguments (ignored for compatibility)
+
+ Returns:
+ Dict with 'logits', 'remasking_conf', and optionally 'insertion_conf' keys
+ """
+ # Single backbone pass returning both prediction and post-block features.
+ # features has shape [B, L+1, hidden]; the remasking/insertion heads use
+ # the same [B, L, hidden] slice that get_hidden_states would have returned.
+ prediction, features = self.backbone(indices, t, return_features=True)
+ hidden_states = features[:, :-1]
+
+ remasking_conf = self.remasking_head(hidden_states)
+ token_logits = prediction.token_logits
+
+ result = {"logits": token_logits, "remasking_conf": remasking_conf}
+
+ if self.insertion_planner:
+ insertion_conf = self.insertion_head(hidden_states)
+ result["insertion_conf"] = insertion_conf
+
+ return result
+
+ def get_hidden_states(self, indices: torch.Tensor, t: torch.Tensor):
+ """
+ Get hidden states and logits for adapter training.
+
+ Args:
+ indices: Token indices [batch_size, seq_len]
+ t: Timesteps [batch_size]
+
+ Returns:
+ Tuple of (token_logits, hidden_states, conditioning)
+ """
+ return self.backbone.get_hidden_states(indices, t)
+
+ @property
+ def device(self):
+ return next(self.backbone.parameters()).device
diff --git a/model/rotary.py b/model/rotary.py
new file mode 100755
index 0000000000000000000000000000000000000000..6bda73203ff2aaaf75e8db376fc2d340b601035c
--- /dev/null
+++ b/model/rotary.py
@@ -0,0 +1,54 @@
+import torch
+
+
+class Rotary(torch.nn.Module):
+ def __init__(self, dim, base=10_000):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self.seq_len_cached = None
+ self.cos_cached = None
+ self.sin_cached = None
+
+ def forward(self, x, seq_dim=1):
+ seq_len = x.shape[seq_dim]
+ if seq_len != self.seq_len_cached:
+ self.seq_len_cached = seq_len
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ # dims are: batch, seq_len, qkv, head, dim
+ self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
+ self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
+ # This makes the transformation on v an identity.
+ self.cos_cached[:, :, 2, :, :].fill_(1.0)
+ self.sin_cached[:, :, 2, :, :].fill_(0.0)
+
+ return self.cos_cached, self.sin_cached
+
+
+def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def _apply_rotary_pos_emb_native(qkv, cos, sin):
+ """Native PyTorch implementation without JIT compilation"""
+ return (qkv * cos) + (rotate_half(qkv) * sin)
+
+
+@torch.jit.script
+def _apply_rotary_pos_emb_torchscript(qkv, cos, sin):
+ return (qkv * cos) + (rotate_half(qkv) * sin)
+
+
+def apply_rotary_pos_emb(qkv, cos, sin):
+ try:
+ import flash_attn.layers.rotary
+
+ cos_flash = cos[0, :, 0, 0, : cos.shape[-1] // 2]
+ sin_flash = sin[0, :, 0, 0, : sin.shape[-1] // 2]
+ return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos_flash, sin_flash)
+ except (ImportError, AttributeError, RuntimeError):
+ # Use native implementation without TorchScript due to compatibility issues
+ return _apply_rotary_pos_emb_native(qkv, cos, sin)
diff --git a/model/schedule.py b/model/schedule.py
new file mode 100755
index 0000000000000000000000000000000000000000..38e27fd72d24cbba5b95c1d1f63fa0a2b7e2c48b
--- /dev/null
+++ b/model/schedule.py
@@ -0,0 +1,156 @@
+import abc
+from omegaconf import DictConfig
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+
+def get_schedule_from_config(config: DictConfig):
+ match config.type:
+ case "geometric":
+ return GeometricSchedule(min_val=config.min, max_val=config.max)
+ case "linear":
+ return LinearSchedule()
+ case "sin":
+ return SinSchedule()
+ case "cosine":
+ return CosineSchedule()
+ case "polynomial":
+ return PolynomialSchedule(exp=config.exp)
+ case _:
+ raise ValueError(f"Invalid schedule type: {config.type}")
+
+
+class Schedule(abc.ABC):
+ """
+ Generic schedule class for masking or noising
+ This represents function a : [0, 1] -> [0, 1] satisfying a(0) = 0, a(1) = 1 or at least approximately
+ """
+
+ @abc.abstractmethod
+ def at(self, t: Tensor):
+ """
+ Return value a(t)
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def derivative_at(self, t: Tensor):
+ """
+ Return d/dt a(t)
+ """
+ raise NotImplementedError
+
+ def rate_scale_factor(self, t: Tensor) -> Tensor:
+ """
+ Return d/dt a(t) / (1 - a(t)) common in rate matrix calculation
+ """
+ return self.derivative_at(t) / (1 - self.at(t))
+
+ def sample(self, shape, device) -> Tensor:
+ """
+ Sample from the schedule, returns a tensor of shape `shape` with values in [0, 1]
+ """
+ uniform = torch.rand(shape, device=device)
+ return self.inv(uniform)
+
+ def sample_truncated(self, threshold, shape, device) -> Tensor:
+ """
+ Sample from a truncated schedule, returns a tensor of shape `shape` with values in [threshold, 1]
+ """
+ uniform = torch.rand(shape, device=device)
+ threshold = self.at(threshold)
+ return self.inv(uniform * (1 - threshold) + threshold)
+
+ @abc.abstractmethod
+ def inv(self, alpha: Tensor):
+ """
+ Given alpha in [0, 1] such that a(t)=alpha, returns the corresponding t.
+ """
+ raise NotImplementedError
+
+
+class LinearSchedule(Schedule):
+ def __init__(self):
+ pass
+
+ def at(self, t: Tensor):
+ return t
+
+ def derivative_at(self, t: Tensor):
+ return torch.ones_like(t, device=t.device)
+
+ def inv(self, alpha: Tensor):
+ return alpha
+
+
+class GeometricSchedule(Schedule, nn.Module):
+ def __init__(self, min_val: float, max_val: float):
+ super().__init__()
+ self.register_buffer("min", Tensor([min_val]))
+ self.register_buffer("max", Tensor([max_val]))
+
+ def at(self, t: Tensor):
+ min_val = self.min.to(t.device)
+ max_val = self.max.to(t.device)
+ return torch.exp(-(min_val ** (1 - t)) * max_val**t)
+
+ def derivative_at(self, t):
+ min_val = self.min.to(t.device)
+ max_val = self.max.to(t.device)
+ return (
+ self.at(t)
+ * min_val ** (1 - t)
+ * max_val**t
+ * (min_val.log() - max_val.log())
+ )
+
+ def inv(self, alpha: Tensor):
+ log_min = self.min.to(alpha.device).log()
+ log_max = self.max.to(alpha.device).log()
+ return (torch.log(-torch.log(alpha)) - log_min) / (log_max - log_min)
+
+
+class SinSchedule(Schedule, nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def at(self, t: Tensor):
+ return torch.sin(torch.pi / 2 * t)
+
+ def derivative_at(self, t: Tensor):
+ return (torch.pi / 2) * torch.cos(torch.pi / 2 * t)
+
+ def inv(self, alpha: Tensor):
+ return (2 / torch.pi) * torch.asin(alpha.clamp(min=0., max=1.))
+
+
+class CosineSchedule(Schedule, nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def at(self, t: Tensor):
+ return 1 - torch.cos(torch.pi / 2 * t)
+
+ def derivative_at(self, t: Tensor):
+ return (torch.pi / 2) * torch.sin(torch.pi / 2 * t)
+
+ def rate_scale_factor(self, t):
+ return (torch.pi/2) * torch.tan(torch.pi / 2 * t)
+
+ def inv(self, alpha):
+ return (2 / torch.pi) * torch.arccos(1 - alpha.clamp(min=0., max=1.))
+
+class PolynomialSchedule(Schedule, nn.Module):
+ def __init__(self, exp):
+ super().__init__()
+ self.exp = exp
+
+ def at(self, t: Tensor):
+ return t ** self.exp
+
+ def derivative_at(self, t: Tensor):
+ return self.exp * t ** (self.exp - 1)
+
+ def inv(self, alpha: Tensor):
+ return alpha ** (1 / self.exp)
\ No newline at end of file
diff --git a/model/transformer.py b/model/transformer.py
new file mode 100755
index 0000000000000000000000000000000000000000..f488aa4420c9d0aa0815d87d430f4986dea0bbd3
--- /dev/null
+++ b/model/transformer.py
@@ -0,0 +1,404 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from einops import rearrange
+from omegaconf import OmegaConf
+from .interpolant import ModelPrediction
+from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from . import rotary
+from .fused_add_dropout_scale import (
+ bias_dropout_add_scale_fused_train,
+ bias_dropout_add_scale_fused_inference,
+ modulate_fused,
+)
+
+
+# Disable torch.compile for flex_attention due to batch dimension mismatch issues
+# flex_attention = torch.compile(flex_attention, mode="max-autotune")
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+#################################################################################
+# Layers #
+#################################################################################
+class LayerNorm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones([dim]))
+ self.dim = dim
+
+ def forward(self, x):
+ with torch.amp.autocast("cuda", enabled=False):
+ x = F.layer_norm(x.float(), [self.dim])
+ return x * self.weight[None, None, :]
+
+
+#################################################################################
+# Embedding Layers for Timesteps and Class Labels #
+#################################################################################
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256, silu=True):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+ )
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class LabelEmbedder(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+
+ def __init__(self, num_classes, cond_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
+ self.num_classes = num_classes
+
+ # TODO think of initializing with 0.02 std deviation like in original DiT paper
+
+ def forward(self, labels):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+
+# length scalar head
+class ScalarLengthHead(nn.Module):
+ def __init__(self, d_model: int, normalized_len: int, cond_dim: int | None = None):
+ super().__init__()
+ self.has_cond = cond_dim is not None
+ if self.has_cond:
+ self.adaLN = nn.Linear(cond_dim, 2 * d_model, bias=True)
+ self.adaLN.weight.data.zero_()
+ self.adaLN.bias.data.zero_()
+
+ self.norm = LayerNorm(d_model)
+ self.proj1 = nn.Linear(d_model, d_model)
+ self.act = nn.GELU()
+ self.proj2 = nn.Linear(d_model, 1)
+ self.softplus = nn.Softplus()
+ self.normalized_len = normalized_len
+
+ def forward(self, x: torch.Tensor, c: torch.Tensor | None = None):
+ x_fp32 = x.float()
+ c_fp32 = c.float() if (self.has_cond and c is not None) else None
+ if self.has_cond and c_fp32 is not None:
+ shift, scale = self.adaLN(c_fp32)[:, None].chunk(2, dim=2)
+ x_fp32 = modulate_fused(self.norm(x_fp32), shift, scale)
+ else:
+ x_fp32 = self.norm(x_fp32)
+ s = self.proj2(self.act(self.proj1(x_fp32)))
+ out = self.softplus(s).squeeze(-1) * self.normalized_len
+ return out.to(x.dtype)
+
+
+#################################################################################
+# Core Model #
+#################################################################################
+
+
+def get_mask_mod(seq_len: torch.Tensor):
+ def mask_mod(b, h, q_idx, kv_idx):
+ return (q_idx <= seq_len[b]) & (kv_idx <= seq_len[b])
+
+ return mask_mod
+
+
+class DDiTBlock(nn.Module):
+ def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
+ super().__init__()
+ self.n_heads = n_heads
+
+ self.norm1 = LayerNorm(dim)
+ self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
+ self.attn_out = nn.Linear(dim, dim, bias=False)
+ self.dropout1 = nn.Dropout(dropout)
+
+ self.norm2 = LayerNorm(dim)
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, mlp_ratio * dim, bias=True),
+ nn.GELU(approximate="tanh"),
+ nn.Linear(mlp_ratio * dim, dim, bias=True),
+ )
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.dropout = dropout
+
+ self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
+ self.adaLN_modulation.weight.data.zero_()
+ self.adaLN_modulation.bias.data.zero_()
+
+ def _get_bias_dropout_scale(self):
+ return (
+ bias_dropout_add_scale_fused_train
+ if self.training
+ else bias_dropout_add_scale_fused_inference
+ )
+
+ def forward(self, x, rotary_cos_sin, c, block_mask):
+ batch_size = x.shape[0]
+
+ bias_dropout_scale_fn = self._get_bias_dropout_scale()
+
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
+ )
+
+ # attention operation
+ x_skip = x
+ x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
+ # dtype0 = x.dtype
+
+ qkv = self.attn_qkv(x)
+ qkv = rearrange(
+ qkv, "b s (three h d) -> b s three h d", three=3, h=self.n_heads
+ )
+ with torch.amp.autocast("cuda", enabled=False):
+ cos, sin = rotary_cos_sin
+ qkv = rotary.apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
+
+ q, k, v = rearrange(qkv, "b s three h d -> three b h s d", three=3)
+
+ x = flex_attention(q, k, v, block_mask=block_mask)
+
+ x = rearrange(x, "b h s d -> b s (h d)", b=batch_size)
+
+ x = bias_dropout_scale_fn(
+ self.attn_out(x), None, gate_msa, x_skip, self.dropout
+ )
+
+ # mlp operation
+ x = bias_dropout_scale_fn(
+ self.mlp(modulate_fused(self.norm2(x), shift_mlp, scale_mlp)),
+ None,
+ gate_mlp,
+ x,
+ self.dropout,
+ )
+
+ return x
+
+
+class EmbeddingLayer(nn.Module):
+ def __init__(self, dim, vocab_dim):
+ super().__init__()
+ self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
+ torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))
+
+ def forward(self, x):
+ return self.embedding[x]
+
+
+class DDitFinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels, cond_dim):
+ super().__init__()
+ self.norm_final = LayerNorm(hidden_size)
+ self.linear = nn.Linear(hidden_size, out_channels)
+ self.linear.weight.data.zero_()
+ self.linear.bias.data.zero_()
+
+ self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
+ self.adaLN_modulation.weight.data.zero_()
+ self.adaLN_modulation.bias.data.zero_()
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
+ x = modulate_fused(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class AnyOrderMaskInsertionFlow(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ # hack to make loading in configs easier
+ if isinstance(config, dict):
+ config = OmegaConf.create(config)
+
+ self.config = config
+ self.vocab_size = config.interpolant.tokens
+ self.pad_token = config.interpolant.pad_token
+ self.mask_token = config.interpolant.mask_token
+
+ # Get dtype from config, default to bfloat16
+ dtype_str = config.model.get('torch_dtype', 'bfloat16')
+ self.dtype = getattr(torch, dtype_str)
+
+ self.vocab_embed = EmbeddingLayer(config.model.hidden_size, self.vocab_size)
+ self.sigma_map = TimestepEmbedder(config.model.cond_dim)
+ self.rotary_emb = rotary.Rotary(
+ config.model.hidden_size // config.model.n_heads
+ )
+
+ self.blocks = nn.ModuleList(
+ [
+ DDiTBlock(
+ config.model.hidden_size,
+ config.model.n_heads,
+ config.model.cond_dim,
+ dropout=config.model.dropout,
+ )
+ for _ in range(config.model.n_blocks)
+ ]
+ )
+
+ self.output_layer = DDitFinalLayer(
+ config.model.hidden_size, self.vocab_size, config.model.cond_dim
+ )
+
+ self.len_predict_type = config.training.loss_fn.insert
+ if self.len_predict_type == "distribution":
+ self.len_pred = DDitFinalLayer(
+ config.model.hidden_size,
+ config.interpolant.max_length + 1,
+ config.model.cond_dim,
+ )
+ elif self.len_predict_type == "expectation":
+ normalized_len = config.interpolant.max_length
+ self.len_pred = ScalarLengthHead(
+ config.model.hidden_size, normalized_len, config.model.cond_dim
+ )
+ else:
+ raise ValueError(f"Invalid length prediction type: {self.len_predict_type}")
+
+ def _get_bias_dropout_scale(self):
+ return (
+ bias_dropout_add_scale_fused_train
+ if self.training
+ else bias_dropout_add_scale_fused_inference
+ )
+
+ def forward(self, indices: torch.Tensor, t: torch.Tensor, return_features: bool = False):
+ B, L = indices.shape
+ indices = torch.cat(
+ [
+ indices,
+ self.pad_token
+ * torch.ones((B, 1), device=indices.device, dtype=torch.int64),
+ ],
+ dim=-1,
+ )
+ seq_lens = (indices != self.pad_token).sum(dim=-1).to(indices.device)
+ block_mask = create_block_mask(
+ get_mask_mod(seq_lens),
+ B=B,
+ H=None,
+ Q_LEN=indices.shape[1],
+ KV_LEN=indices.shape[1],
+ device=indices.device,
+ )
+
+ x = self.vocab_embed(indices)
+ c = F.silu(self.sigma_map(t))
+
+ rotary_cos_sin = self.rotary_emb(x)
+
+ with torch.amp.autocast("cuda", dtype=torch.bfloat16):
+ for i in range(len(self.blocks)):
+ x = self.blocks[i](x, rotary_cos_sin, c, block_mask)
+
+ # Store features after transformer blocks for optional sharing
+ features = x.clone() if return_features else None
+
+ # --- unmasking ---
+ token_logits = self.output_layer(x[:, :-1], c)
+
+ # --- length prediction ---
+ match self.len_predict_type:
+ case "distribution":
+ length_posterior = self.len_pred(x, c)
+ prediction = ModelPrediction(
+ token_logits=token_logits,
+ length_posterior=length_posterior,
+ )
+ case "expectation":
+ prediction = ModelPrediction(
+ token_logits=token_logits,
+ expected_gaps=self.len_pred(x, c),
+ )
+
+ if return_features:
+ return prediction, features # [B, L+1, hidden_dim]
+ else:
+ return prediction
+
+ def get_hidden_states(self, indices: torch.Tensor, t: torch.Tensor):
+ """Returns token logits, hidden states, and conditioning for adapter training."""
+ B, L = indices.shape
+ indices = torch.cat(
+ [
+ indices,
+ self.pad_token
+ * torch.ones((B, 1), device=indices.device, dtype=torch.int64),
+ ],
+ dim=-1,
+ )
+ seq_lens = (indices != self.pad_token).sum(dim=-1).to(indices.device)
+ block_mask = create_block_mask(
+ get_mask_mod(seq_lens),
+ B=B,
+ H=None,
+ Q_LEN=indices.shape[1],
+ KV_LEN=indices.shape[1],
+ device=indices.device,
+ )
+
+ x = self.vocab_embed(indices)
+ c = F.silu(self.sigma_map(t))
+
+ rotary_cos_sin = self.rotary_emb(x)
+
+ with torch.amp.autocast("cuda", dtype=self.dtype):
+ for i in range(len(self.blocks)):
+ x = self.blocks[i](x, rotary_cos_sin, c, block_mask)
+
+ # Hidden states after transformer blocks
+ hidden_states = x[:, :-1] # [B, L, hidden_dim] - exclude padding position
+
+ # Token logits
+ token_logits = self.output_layer(hidden_states, c)
+
+ return token_logits, hidden_states, c