diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..6e8075b1ea694b70e967356b554bddda7b77f13d --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +*.gif filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..01080488b905578eb4f1196ac2213c50a5682c4b --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +checkpoints/ +pretrained/ +__pycache__/ +results/ +a2d2_language/ +a2d2_language/wandb/ +a2d2_pep/wandb/ +a2d2_mol/wandb/ +logs/ +*.pt +*.pyc +*.out +*.json +*.log +*.txt +*.wandb \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..839c375c02da75f7ae46b3c50f8689a788fcdcee --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Sophia Tang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fef6ef0d7a5f8213300b5372dfe9d0a198de0db8 --- /dev/null +++ b/README.md @@ -0,0 +1,62 @@ +# [A2D2: Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding](https://arxiv.org/abs/2606.13565) 🃏🔮 + +[**Sophia Tang**](https://sophtang.github.io/), [**Yuchen Zhu**](https://yuchen-zhu-zyc.github.io/), [**Molei Tao**](https://mtao8.math.gatech.edu/), and [**Pranam Chatterjee**](https://www.chatterjeelab.com/) + +

+ arXiv + Project Page +

+ +![A2D2](assets/a2d2.gif) + +This is the repository for the paper [**A2D2: Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding**](https://arxiv.org/abs/2606.13565). + +Masked discrete diffusion models (MDMs) offer a simple, stable likelihood-based framework for sequence generation, recently extended to **any-length** settings via token insertion. **A2D2** is a unified framework for reward-guided fine-tuning of any-length MDMs that **jointly optimizes the insertion and unmasking policies together with a quality-based inference schedule**, converging to the intractable reward-tilted distribution without requiring target samples. + +🃏 We derive the **Radon–Nikodym derivative** for the joint insertion–unmasking path measures, enabling theoretically guaranteed convergence to the reward-tilted sequence distribution. + +🃏 We establish **unmasking and insertion quality** as tractable approaches for minimizing decoding error (compounding parallelization error), and train lightweight quality predictors alongside the policy. + +🃏 We introduce the **Adaptive Joint Decoding (AJD)** loss, which provably yields the optimal path measure that generates the reward-tilted distribution while remasking low-quality tokens and dropping low-quality insertions at inference. + +🃏 Empirically, A2D2 improves reward optimization while enhancing generation **flexibility** and **accuracy** over prior fixed-length fine-tuning and inference-time guidance methods. + +## Drug-Like Small Molecule Design 🧪 + +We pre-train an any-length MDM on the **SAFE** dataset ([Noutahi et al. 2024](https://arxiv.org/abs/2310.10773), ~950M molecules from ZINC and Unichem in SAFE notation) and fine-tune it with **A2D2** to optimize **QED** (drug-likeness) and **synthetic accessibility (SA)**. A2D2 jointly raises QED and lowers SA over the pre-trained baseline while increasing the fraction of valid, unique, drug-like, and synthesizable molecules. Code and instructions are in [`/a2d2_mol`](a2d2_mol). 
+ +## Multi-Objective Therapeutic Peptide Generation 💉 + +We pre-train an any-length **peptide SMILES** MDM on ~11M peptides (CycPeptMPDB, SmProt, CycloPs) and fine-tune with **A2D2** on five therapeutic properties simultaneously: **target-protein binding affinity, solubility, non-hemolysis, non-fouling, and permeability**. A2D2 outperforms inference-time multi-objective guidance and fixed-length off-policy RL fine-tuning on almost all objectives, while improving the fraction of valid peptides. Code and instructions are in [`/a2d2_pep`](a2d2_pep). + +## Language Model Reasoning 🧠 + +We additionally apply **A2D2** to reward fine-tuning of any-length language MDMs (LLaDA / FlexMDM), optimizing math-reasoning correctness and format rewards (GSM8K / MATH), including infilling variants. Code is in [`/a2d2_language`](a2d2_language). + +## Repository Structure + +| Directory | Experiment | +|-----------|------------| +| [`a2d2_mol`](a2d2_mol) | Drug-like small molecule design (QED, SA) | +| [`a2d2_pep`](a2d2_pep) | Multi-objective therapeutic peptide generation | +| [`a2d2_language`](a2d2_language) | Language model reasoning reward fine-tuning (code soon) | +| [`lightning_modules`](lightning_modules) | Any-length insertion MDM Lightning modules (policy + quality predictors) | +| [`model`](model) | Shared model architecture | +| [`demo`](demo) | Quality-guided inference demo notebook | + +Each experiment directory contains its own `README.md` with environment setup, pretrained weight placement, fine-tuning commands, and evaluation instructions. + +## Citation + +If you find this repository helpful for your publications, please consider citing our paper: + +```bibtex +@article{tang2026a2d2, + title={A2D2: Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding}, + author={Sophia Tang and Yuchen Zhu and Molei Tao and Pranam Chatterjee}, + journal={arXiv preprint arXiv:2606.13565}, + year={2026} +} +``` + +To use this repository, you agree to abide by the MIT License. 
\ No newline at end of file diff --git a/a2d2_mol/README.md b/a2d2_mol/README.md new file mode 100644 index 0000000000000000000000000000000000000000..61db0ba8e8ffbe7b929b1dd17480e9248fbe81ca --- /dev/null +++ b/a2d2_mol/README.md @@ -0,0 +1,132 @@ +# A2D2 for Molecule Generation 🧪 + +This part of the code fine-tunes an **any-length masked diffusion model (MDM)** over molecules with **A2D2** (Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding) to optimize drug-likeness rewards (QED, and optionally synthetic accessibility / SA). + +A2D2 jointly fine-tunes the insertion and unmasking policies together with **insertion and unmasking quality predictors**, generating molecules via **Adaptive Joint Decoding (AJD)** that remasks low-quality tokens and drops low-quality insertions to sample from the reward-tilted distribution while preserving generation quality. + +Molecules are represented as [SAFE](https://github.com/datamol-io/safe) strings and tokenized with the `datamol-io/safe-gpt` tokenizer. + +The codebase is partially built upon [FlexMDM (Kim et al., 2025)](https://github.com/brianlck/FlexMDM/tree/main) and [TR2-D2 (Tang et al., 2025)](https://github.com/sophtang/TR2-D2/tree/main). + +## Environment Installation +``` +# from the repository root +conda env create -f environment.yml + +conda activate a2d2 +``` +The molecule scripts share the `a2d2` environment with the peptide and language experiments. See the root [`environment.yml`](../environment.yml) for the `flash-attn` install step. + +## Model Pretrained Weights + +A2D2 fine-tunes a pretrained any-length insertion MDM trained on drug-like SAFE molecules. 
Download the base checkpoint and place it at: +``` +A2D2/pretrained/anylength_mol.ckpt +``` +```bash +# from the repository root +pip install gdown +mkdir -p pretrained +gdown 1I5EGiV1I5XZZpB9JAKABFLKVqfCyenxq -O pretrained/anylength_mol.ckpt +``` +(Or download manually from https://drive.google.com/file/d/1I5EGiV1I5XZZpB9JAKABFLKVqfCyenxq/view?usp=drive_link — a plain `wget`/`curl` of the link saves Google's HTML warning page, not the checkpoint.) +This is the default `--checkpoint_path` (for fine-tuning) and `--pretrained_ckpt` (for evaluation) used throughout. + +## Pretraining the Any-Length Model + +If you only want to fine-tune with A2D2, download the released `anylength_mol.ckpt` above and skip this section. Follow these steps to reproduce the base checkpoint by pretraining the any-length insertion MDM from scratch. + +### 1. The pretraining dataset + +The model is pretrained on drug-like [SAFE](https://github.com/datamol-io/safe) molecules from the [`datamol-io/safe-gpt`](https://huggingface.co/datasets/datamol-io/safe-gpt) dataset (~1.1B molecules) on the Hugging Face Hub. **No manual download is required** — the dataset is loaded in streaming mode (`load_dataset(..., streaming=True)`) and tokenized on the fly with the `datamol-io/safe-gpt` tokenizer, both fetched automatically on first run. + +The dataset is configured in [`config_mol.yaml`](config_mol.yaml): + +```yaml +hf_dataset: + name: "datamol-io/safe-gpt" + smiles_column: "smiles" +``` + +To pretrain on a different Hugging Face SMILES/SAFE dataset, change `hf_dataset.name` (and `smiles_column` to match its column). + +### 2. Configure + +Pretraining is driven by [`config_mol.yaml`](config_mol.yaml). Key fields: + +| Field | Default | Notes | +|-------|---------|-------| +| `hf_dataset.name` | `datamol-io/safe-gpt` | Streaming HF dataset (auto-downloaded). | +| `training.devices` | `2` | GPUs per node (DDP). 
| +| `training.batch_size` | `2048` | Global batch; gradient accumulation is derived automatically from `per_gpu_batch_size`. | +| `training.max_steps` | `500000` | Total optimizer steps. | +| `training.learning_rate` | `3e-4` | AdamW LR with `warmup_steps: 2000`. | +| `training.save_every_n_steps` | `1000` | Step-based checkpointing (used for streaming datasets). | +| `training.checkpoint_dir` | `checkpoints/pretrain_mol` | A timestamped subdirectory is created per run. | +| `interpolant.max_length` | `256` | Max token length. | + +### 3. Pre-training Any-Length Molecule Model + +Log in to Weights & Biases once (`wandb login`), or set `export WANDB_MODE=disabled` to skip logging. Then submit the SLURM job: + +```bash +# from a2d2_mol/ +sbatch train_mol.sh +``` + +`train_mol.sh` is a SLURM batch script that requests one `dgx-b200` node with 2 full B200 GPUs and launches DDP via `srun` (one task per GPU), running the equivalent of: + +```bash +python train.py --task mol +``` + +It activates the conda env (`CONDA_ENV`, defaults to the `peptune` env) from `CONDA_ROOT` (defaults to the shared miniconda install) — the batch shell does not source `~/.bashrc`, so override these env vars if your install or env path differs. The GPU count is auto-detected from the SLURM allocation and passed to hydra as `training.devices`/`training.nodes`, so to scale just change `--gpus-per-node` and `--ntasks-per-node` together at the top of the script (they must match). `--task mol` makes `train.py` load `config_mol.yaml`. + +Checkpoints are written to `checkpoints/pretrain_mol/{timestamp}/` (use `last.ckpt` / the best `train_loss` checkpoint as the `--checkpoint_path` / `--pretrained_ckpt` for fine-tuning and evaluation); the run log goes to `logs/_a2d2-mol_.log` and SLURM's catch-file to `logs/slurm/`. To resume, add a `training.resume_path: /path/to/last.ckpt` entry to the config. 
+ +## Fine-Tune with A2D2 + +The canonical run directory is the parent `a2d2/` package (`finetune_mol.py`, `inference_quality_mol.py`, `sampling.py`, and `mol_scoring/` here are the molecule-specific modules used from there). Before running: + +1. Set `--base_path` to the location of `a2d2`. Results plots are written to `/flexible/results//`. +2. Create the output directories: `a2d2/checkpoints/finetune_mol`, `a2d2/results`, and `a2d2/logs`. + +### Single run + +[`scripts/run_mol_finetune.slurm`](scripts/run_mol_finetune.slurm) runs a single `finetune_mol.py` experiment on one MIG GPU, then evaluates the resulting checkpoint. It bundles the full hyperparameter set used in the paper (replicates `R = 16`, pool size `1000`, buffer size `100`, sampling steps `N_steps = 90`, warmup `N_warmup = 20`, alternation frequency `N_alt = 5`, reward scaling `α = 0.01`, quality threshold `μ_min = 0.3`, `--qed_only`), so you don't have to pass them by hand. + +The script resolves the repo root automatically — `$A2D2_ROOT` if set, else the `sbatch` submit directory, else the script's own location — so either submit from the repo root or export your clone path. Set `CONDA_ROOT` (your miniconda install) and, if needed, `CONDA_ENV` (defaults to `peptune`): +```bash +export A2D2_ROOT=/path/to/your/A2D2 # absolute path to your clone +export CONDA_ROOT=/path/to/miniconda3 # or just have `conda` on PATH +sbatch scripts/run_mol_finetune.slurm +``` + +Select which variant to run with `MODE_ID` (default `0`): `0` = A2D2 (full planner), `1` = `--disable_planner`, `2` = `--disable_insertion_planner`, `3` = `--disable_unmasking_planner`. Override at submit time: +```bash +sbatch --export=ALL,MODE_ID=2 scripts/run_mol_finetune.slurm +``` +The pretrained base checkpoint is read from `$A2D2_ROOT/pretrained/anylength_mol.ckpt`. Outputs land in `checkpoints/finetune_mol/_mol_/` and `results/mol_ablation//`. 
+ +### Ablation flags +| Flag | Variant | +|------|---------| +| *(none)* | A2D2 w/ insertion + unmasking quality (alternation) | +| `--disable_planner` | A2D2 w/o quality (policy only, no remasking) | +| `--disable_insertion_planner` | A2D2 w/o insertion quality | +| `--disable_unmasking_planner` | A2D2 w/o unmasking/remasking quality | +| `--joint_training` | train policy + quality heads jointly (no alternation) | + +## Evaluation + +Evaluation runs automatically at the end of the SLURM job. To evaluate a checkpoint manually: +``` +python evaluate_mol_table.py \ + --checkpoint_path /path/to/a2d2/checkpoints/finetune_mol/my_run/last.ckpt \ + --pretrained_ckpt /path/to/A2D2/pretrained/anylength_mol.ckpt \ + --output_dir /path/to/results \ + --num_samples 1000 --batch_size 50 \ + --max_length 256 --total_num_steps 256 \ + --num_remasking 2 --quality_threshold 0.3 --seed 42 --device cuda:0 +``` +This reports validity, uniqueness, mean QED and SA, quality (the percentage of generated molecules that are unique with QED ≥ 0.6 and SA ≤ 4), diversity, and sampling time over the generated molecules, and writes them to `eval_metrics_{tag}.csv` (unique SMILES with their QED/SA go to `eval_smiles_{tag}.csv`), where `{tag}` is `with_planner`, `no_planner`, `no_insertion_planner`, or `no_unmasking_planner`. 
diff --git a/a2d2_mol/config_mol.yaml b/a2d2_mol/config_mol.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9a2a659d8417aba012326c69218188b80aa9e4a4 --- /dev/null +++ b/a2d2_mol/config_mol.yaml @@ -0,0 +1,54 @@ +trainer: "any-order-flow" +dataset: "safe-drugs" + +# HuggingFace dataset configuration +hf_dataset: + name: "datamol-io/safe-gpt" + smiles_column: "smiles" # Adjust based on actual column name in the dataset + +model: + hidden_size: 768 + n_heads: 12 + cond_dim: 128 + dropout: 0.05 + n_blocks: 12 + torch_dtype: 'float32' # Options: 'float32', 'float16', 'bfloat16' + +interpolant: + type: "any-order" + tokens: null # filled in automatically + pad_token: null # filled in automatically + mask_token: null # filled in automatically + max_length: 256 + insert_schedule: + type: "linear" + unmask_schedule: + type: "linear" + +training: + only_embed_insert: true + batch_size: 2048 + per_gpu_batch_size: 64 # Gradient accumulation happens automatically + cpus: 4 + learning_rate: 3e-4 + nodes: 1 + devices: 2 + max_steps: 500000 + weight_decay: 0.03 + checkpoint_dir: "checkpoints/pretrain_mol" + save_top_k: 3 + save_every_n_steps: 1000 # Save checkpoint every 1k steps (for streaming datasets) + # save_every_n_epochs: 1 # Not used with streaming datasets + loss_fn: + unmask: "elbo" + insert: "expectation" + reset_lr: false + warmup_steps: 2000 + ema_decay: 0.9999 + filter_max_length: false + +wandb: + entity: null # set to your W&B entity, or leave null to use the default + project: "a2d2-mol" + name: "a2d2-mol" + path: "./wandb" diff --git a/a2d2_mol/evaluate_mol_table.py b/a2d2_mol/evaluate_mol_table.py new file mode 100644 index 0000000000000000000000000000000000000000..103b277a1d0fa634be4e4b8e465bfe524e242327 --- /dev/null +++ b/a2d2_mol/evaluate_mol_table.py @@ -0,0 +1,308 @@ +""" +Evaluate a finetuned molecule model checkpoint by sampling sequences +and computing metrics for the De Novo Small Molecule Generation table: + Validity (%), 
Uniqueness (%), QED (↑), SA (↓), Quality (%), Diversity (↑), Sampling Time (↓) +""" + +import os +import sys +import argparse +import time +import torch +import numpy as np +import pandas as pd +from tdc import Oracle, Evaluator + +# add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, REPO_ROOT) + +from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT +from lightning_modules import AnyOrderInsertionFlowModule +from inference_quality_mol import sample_mol_eval +from mol_scoring.scoring_functions import MolScoringFunctions +from finetune_mol import MolFinetuner, get_tokenizer +from mol_utils.utils import str2bool, set_seed + + +def load_finetuned_model(checkpoint_path, pretrained_ckpt_path, device='cuda'): + """Load a finetuned MolFinetuner from a Lightning checkpoint.""" + # We need to reconstruct the model the same way main() does, then load state + # Load from Lightning checkpoint directly + ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False) + hparams = ckpt.get('hyper_parameters', {}) + args = hparams.get('args', None) + + # Load pretrained base checkpoint to get config + base_ckpt = torch.load(pretrained_ckpt_path, map_location='cpu', weights_only=False) + if 'hyper_parameters' in base_ckpt: + config = base_ckpt['hyper_parameters']['config'] + elif 'config' in base_ckpt: + config = base_ckpt['config'] + else: + raise ValueError("Cannot find config in base checkpoint") + + from omegaconf import OmegaConf, DictConfig + if not OmegaConf.is_config(config): + config = DictConfig(config) + OmegaConf.set_struct(config, False) + + # Set adaptive schedule config from args or defaults + config.training.use_adaptive_schedule = getattr(args, 'use_adaptive_schedule', True) + config.training.schedule_hidden_dim = getattr(args, 'schedule_hidden_dim', 256) + config.training.schedule_num_layers = 
getattr(args, 'schedule_num_layers', 2) + config.training.schedule_loss_weight = getattr(args, 'schedule_loss_weight', 0.1) + config.training.freeze_base_model = getattr(args, 'freeze_base_model', False) + config.training.schedule_warmup_epochs = getattr(args, 'schedule_warmup_epochs', 0) + config.training.use_bracket_safe = True + OmegaConf.set_struct(config, True) + + # Determine if planner should be loaded based on disable_planner flag + disable_planner = getattr(args, 'disable_planner', False) + + # Initialize policy model + policy_model = AnyOrderInsertionFlowModuleFT( + config=config, + args=args, + pretrained_checkpoint=pretrained_ckpt_path, + insertion_planner=not disable_planner, + ) + + # Load policy model weights from the finetuned checkpoint + state_dict = ckpt['state_dict'] + # Lightning wraps the model: 'policy_model.xxx' -> remove prefix for the sub-module + policy_state = {} + for k, v in state_dict.items(): + if k.startswith('policy_model.'): + policy_state[k[len('policy_model.'):]] = v + policy_model.load_state_dict(policy_state, strict=False) + policy_model = policy_model.to(device) + policy_model.eval() + + return policy_model, args, config + + +@torch.no_grad() +def evaluate_checkpoint(policy_model, tokenizer, reward_model, evaluator, + num_samples=1000, batch_size=50, max_length=256, + total_num_steps=256, quality_mode="both", num_remasking=2, + quality_threshold=0.5, unmask_quality_threshold=None, device='cuda'): + """ + Sample `num_samples` molecules and compute all table metrics. 
+ Returns a dict with: validity, uniqueness, qed, sa, quality, diversity, sampling_time + """ + all_valid_seqs = [] + all_smiles_generated = 0 + total_time = 0.0 + + num_batches = (num_samples + batch_size - 1) // batch_size + remaining = num_samples + + for b in range(num_batches): + bs = min(batch_size, remaining) + remaining -= bs + + t_start = time.time() + result = sample_mol_eval( + model=policy_model, + reward_model=reward_model, + tokenizer=tokenizer, + steps=total_num_steps, + mask=policy_model.interpolant.mask_token, + pad=policy_model.interpolant.pad_token, + batch_size=bs, + max_length=max_length, + quality_mode=quality_mode, + num_remasking=num_remasking, + quality_threshold=quality_threshold, + unmask_quality_threshold=unmask_quality_threshold, + evaluator=evaluator, + dataframe=True, + ) + t_end = time.time() + + # Unpack: uniqueSequences, qed, sa, valid_fraction, uniqueness, diversity, quality, df + unique_seqs, qed_scores, sa_scores, valid_frac, uniq, div, qual, df = result + + all_valid_seqs.extend(list(unique_seqs) if not isinstance(unique_seqs, list) else unique_seqs) + all_smiles_generated += bs + total_time += (t_end - t_start) + + print(f" Batch {b+1}/{num_batches}: {len(unique_seqs)} valid unique, " + f"time={t_end - t_start:.1f}s") + + # --- Aggregate metrics over all samples --- + total_generated = num_samples + + # Valid sequences (keeping duplicates for validity count) + # Re-evaluate from scratch on all collected valid sequences + all_unique = list(set(all_valid_seqs)) + num_valid = len(all_valid_seqs) # total valid across batches (before dedup) + num_unique = len(all_unique) + + validity = num_valid / total_generated * 100.0 + uniqueness = num_unique / num_valid * 100.0 if num_valid > 0 else 0.0 + + # Diversity on unique SMILES + diversity = evaluator(all_unique) if num_unique > 1 else 0.0 + + # QED and SA on unique sequences + if num_unique > 0: + oracle_qed = Oracle('qed') + oracle_sa = Oracle('sa') + qed_vals = 
oracle_qed(all_unique) + sa_vals = oracle_sa(all_unique) + mean_qed = np.mean(qed_vals) + mean_sa = np.mean(sa_vals) + + # Quality: unique sequences with QED >= 0.6 AND SA <= 4 + quality_mask = [(q >= 0.6 and s <= 4) for q, s in zip(qed_vals, sa_vals)] + quality = sum(quality_mask) / total_generated * 100.0 + else: + mean_qed = 0.0 + mean_sa = 0.0 + quality = 0.0 + + sampling_time = total_time + + metrics = { + 'Validity (%)': validity, + 'Uniqueness (%)': uniqueness, + 'QED': mean_qed, + 'Synthetic Accessibility': mean_sa, + 'Quality (%)': quality, + 'Diversity': diversity, + 'Sampling Time (s)': sampling_time, + 'Num Generated': total_generated, + 'Num Valid': num_valid, + 'Num Unique': num_unique, + } + + return metrics, all_unique, qed_vals if num_unique > 0 else [], sa_vals if num_unique > 0 else [] + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate a finetuned mol checkpoint") + parser.add_argument('--checkpoint_path', type=str, required=True, + help='Path to the finetuned Lightning checkpoint (e.g., last.ckpt)') + parser.add_argument('--pretrained_ckpt', type=str, + default=os.path.join(REPO_ROOT, 'pretrained', 'anylength_mol.ckpt'), + help='Path to the pretrained base model checkpoint ' + '(defaults to /pretrained/anylength_mol.ckpt)') + parser.add_argument('--num_samples', type=int, default=1000, + help='Number of molecules to sample') + parser.add_argument('--batch_size', type=int, default=50, + help='Batch size for sampling') + parser.add_argument('--max_length', type=int, default=256) + parser.add_argument('--total_num_steps', type=int, default=256) + parser.add_argument('--num_remasking', type=int, default=2) + parser.add_argument('--disable_planner', action='store_true', + help='If set, disable remasking during evaluation (matches training mode)') + parser.add_argument('--disable_insertion_planner', action='store_true', + help='If set, disable insertion quality filtering during evaluation') + 
parser.add_argument('--disable_unmasking_planner', action='store_true', + help='If set, disable unmasking confidence planner during evaluation') + parser.add_argument('--quality_threshold', type=float, default=0.5, + help='Threshold for insertion quality filtering during sampling') + parser.add_argument('--unmask_quality_threshold', type=float, default=None, + help='If set, gate unmasking remasking on confidence: remask clean ' + 'tokens whose remasking_conf < threshold (overrides the ' + 'schedule-driven count). Default None = schedule-driven behavior.') + parser.add_argument('--output_dir', type=str, default=None, + help='Directory to save results CSV. Defaults to checkpoint directory.') + parser.add_argument('--device', type=str, default='cuda:0') + parser.add_argument('--seed', type=int, default=42) + args = parser.parse_args() + + set_seed(args.seed, use_cuda=True) + device = torch.device(args.device if torch.cuda.is_available() else 'cpu') + + print(f"Loading checkpoint: {args.checkpoint_path}") + print(f"Pretrained base: {args.pretrained_ckpt}") + print(f"Disable planner (no remasking): {args.disable_planner}") + print(f"Disable insertion planner: {args.disable_insertion_planner}") + print(f"Disable unmasking planner: {args.disable_unmasking_planner}") + + policy_model, train_args, config = load_finetuned_model( + args.checkpoint_path, args.pretrained_ckpt, device=device + ) + + tokenizer = get_tokenizer() + score_func_names = ['qed', 'sa'] + reward_model = MolScoringFunctions(score_func_names, device=device) + evaluator = Evaluator('diversity') + + use_remasking = not args.disable_planner + disable_insertion_planner = args.disable_insertion_planner + disable_unmasking_planner = args.disable_unmasking_planner + + # Map flags to quality_mode + if args.disable_planner: + quality_mode = "none" + elif args.disable_insertion_planner and args.disable_unmasking_planner: + quality_mode = "none" + elif args.disable_insertion_planner: + quality_mode = "unmasking_only" 
+ elif args.disable_unmasking_planner: + quality_mode = "insertion_only" + else: + quality_mode = "both" + + print(f"\nSampling {args.num_samples} molecules (quality_mode={quality_mode})...") + + metrics, unique_smiles, qed_vals, sa_vals = evaluate_checkpoint( + policy_model=policy_model, + tokenizer=tokenizer, + reward_model=reward_model, + evaluator=evaluator, + num_samples=args.num_samples, + batch_size=args.batch_size, + max_length=args.max_length, + total_num_steps=args.total_num_steps, + quality_mode=quality_mode, + num_remasking=args.num_remasking, + quality_threshold=getattr(args, 'quality_threshold', 0.5), + unmask_quality_threshold=args.unmask_quality_threshold, + device=device, + ) + + # Print summary table + print("\n" + "=" * 60) + print(" De Novo Small Molecule Generation Results") + print("=" * 60) + for k, v in metrics.items(): + if isinstance(v, float): + print(f" {k:<30s}: {v:.4f}") + else: + print(f" {k:<30s}: {v}") + print("=" * 60) + + # Save results + output_dir = args.output_dir or os.path.dirname(args.checkpoint_path) + os.makedirs(output_dir, exist_ok=True) + + if args.disable_planner: + tag = "no_planner" + elif args.disable_insertion_planner: + tag = "no_insertion_planner" + elif args.disable_unmasking_planner: + tag = "no_unmasking_planner" + else: + tag = "with_planner" + metrics_path = os.path.join(output_dir, f'eval_metrics_{tag}.csv') + pd.DataFrame([metrics]).to_csv(metrics_path, index=False) + print(f"Metrics saved to: {metrics_path}") + + if unique_smiles: + smiles_path = os.path.join(output_dir, f'eval_smiles_{tag}.csv') + df = pd.DataFrame({ + 'SMILES': unique_smiles, + 'QED': qed_vals, + 'SA': sa_vals, + }) + df.to_csv(smiles_path, index=False) + print(f"SMILES saved to: {smiles_path}") + + +if __name__ == '__main__': + main() diff --git a/a2d2_mol/finetune_mol.py b/a2d2_mol/finetune_mol.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d84ebba28cd76e46ef5545662f59a47206d089 --- /dev/null +++ 
b/a2d2_mol/finetune_mol.py @@ -0,0 +1,747 @@ +import argparse +from datetime import datetime +import numpy as np +import torch +import pytorch_lightning as pl +from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import WandbLogger +import wandb +import os +import sys +from tqdm import tqdm +import pandas as pd + +# add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# imports +from inference_quality_mol import sample_mol_buffer, sample_mol_eval +from mol_utils.utils import str2bool, set_seed +from mol_scoring.scoring_functions import MolScoringFunctions +from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT +from lightning_modules import AnyOrderInsertionFlowModule +from safe.tokenizer import SAFETokenizer +from tdc import Evaluator + +# Repository root (two levels up from this file: A2D2/a2d2_mol/finetune_mol.py) +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +def get_tokenizer(): + """Get SAFE tokenizer with added special tokens.""" + tk = SAFETokenizer.from_pretrained('datamol-io/safe-gpt').get_pretrained() + tk.add_tokens(['<', '>']) # for bracket_safe + return tk + +class MolFinetuner(pl.LightningModule): + """Lightning module for distributed molecule finetuning.""" + + def __init__( + self, + args, + policy_model, + reward_model, + tokenizer, + pretrained=None, + mcts=None, + filename=None, + eps=1e-5 + ): + super().__init__() + self.args = args + self.policy_model = policy_model + self.reward_model = reward_model + self.tokenizer = tokenizer + self.pretrained = pretrained + self.mcts = mcts + self.filename = filename + self.eps = eps + + self.evaluator = Evaluator("diversity") + + # Save hyperparameters + self.save_hyperparameters(ignore=['policy_model', 'reward_model', 'tokenizer', 'pretrained', 
def freeze_policy_model(self):
    """Disable gradients for every policy-model parameter outside the planner.

    Parameters whose qualified name starts with ``planner.`` are left as they are.
    """
    for name, param in self.policy_model.named_parameters():
        if not name.startswith('planner.'):
            param.requires_grad = False

def unfreeze_policy_model(self):
    """Enable gradients for every policy-model parameter outside the planner.

    Parameters whose qualified name starts with ``planner.`` are left as they are.
    """
    for name, param in self.policy_model.named_parameters():
        if not name.startswith('planner.'):
            param.requires_grad = True

def freeze_planner_model(self):
    """Disable gradients for all planner parameters.

    Does nothing when the policy model has no ``planner`` attribute or when
    that attribute is ``None``.
    """
    # hasattr() alone is not enough: a planner attribute that is present but
    # None would crash on .parameters().
    planner = getattr(self.policy_model, 'planner', None)
    if planner is None:
        return
    for param in planner.parameters():
        param.requires_grad = False

def unfreeze_planner_model(self):
    """Enable gradients for all planner parameters.

    Does nothing when the policy model has no ``planner`` attribute or when
    that attribute is ``None``.
    """
    planner = getattr(self.policy_model, 'planner', None)
    if planner is None:
        return
    for param in planner.parameters():
        param.requires_grad = True

def configure_optimizers(self):
    """Build one AdamW optimizer with separate groups for backbone and planner.

    Group 0 holds every parameter whose name does not start with ``planner.``
    and uses ``args.learning_rate``. Group 1 holds the ``planner.`` parameters
    and uses ``args.planner_learning_rate``, falling back to
    ``args.learning_rate`` when that attribute is missing or ``None``.

    Returns:
        The configured ``torch.optim.AdamW`` instance.
    """
    policy_lr = self.args.learning_rate
    # A getattr default only covers a *missing* attribute; an attribute that
    # exists but is None needs the explicit fallback below.
    planner_lr = getattr(self.args, 'planner_learning_rate', None)
    if planner_lr is None:
        planner_lr = policy_lr

    planner_params = []
    policy_params = []
    for name, param in self.policy_model.named_parameters():
        bucket = planner_params if name.startswith('planner.') else policy_params
        bucket.append(param)

    param_groups = [
        {'params': policy_params, 'lr': policy_lr},
        {'params': planner_params, 'lr': planner_lr},
    ]
    return torch.optim.AdamW(param_groups)

def _get_quality_mode(self):
    """Map ablation flags + warmup state to quality_mode string.

    Returns:
        ``"none"`` when the planner is disabled, while the current epoch is
        still below ``args.schedule_warmup_epochs``, or when both planner
        heads are disabled; otherwise ``"unmasking_only"``,
        ``"insertion_only"`` or ``"both"`` depending on which head is disabled.
    """
    if self.args.disable_planner:
        return "none"
    if self.current_epoch < self.args.schedule_warmup_epochs:
        return "none"

    no_insertion = bool(getattr(self.args, 'disable_insertion_planner', False))
    no_unmasking = bool(getattr(self.args, 'disable_unmasking_planner', False))
    # Keyed by (insertion disabled, unmasking disabled).
    mode_by_flags = {
        (True, True): "none",
        (True, False): "unmasking_only",
        (False, True): "insertion_only",
        (False, False): "both",
    }
    return mode_by_flags[(no_insertion, no_unmasking)]
def _generate_buffer(self):
    """Generate buffer of sequences for training.

    When pool_size > 0, maintains a persistent pool and refreshes a fraction
    each time instead of regenerating the entire buffer from scratch.

    Sampling is repeated (at most 100 attempts) until this rank has collected
    its quota of valid sequences. The result is written to ``self.x_saved``,
    ``self.log_rnd_saved`` and ``self.final_rewards_saved``: pool mode
    overwrites a random subset of the existing pool, every other case
    replaces the buffers outright.

    Raises:
        RuntimeError: if no valid sequence was produced in any attempt.
    """
    import time

    rank = self.global_rank if self.trainer else 0
    world_size = self.trainer.world_size if self.trainer else 1

    pool_size = getattr(self.args, 'pool_size', 0)
    is_pool = pool_size > 0
    is_init = self.x_saved is None

    # Determine how many molecules to sample this call
    if is_pool:
        refresh_frac = getattr(self.args, 'pool_refresh_fraction', 0.2)
        if is_init:
            samples_per_gpu = pool_size
        else:
            samples_per_gpu = max(1, int(pool_size * refresh_frac))
        if rank == 0:
            if is_init:
                print(f"\n[POOL] Initializing pool with {pool_size} molecules at epoch {self.current_epoch}")
            else:
                print(f"\n[POOL] Refreshing {samples_per_gpu}/{pool_size} molecules ({refresh_frac*100:.0f}%) at epoch {self.current_epoch}")
    else:
        samples_per_gpu = self.args.buffer_size // world_size
        if rank == 0:
            # rank 0 absorbs the remainder of the integer division
            samples_per_gpu += self.args.buffer_size % world_size
        # Each rank trains from its own local buffer, so a quota of 0
        # (buffer_size < world_size) would skip the sampling loop and trip
        # the "no valid sequences" error below. Always ask for at least one.
        samples_per_gpu = max(1, samples_per_gpu)

        if rank == 0:
            print(f"\n[BUFFER] Starting buffer generation at epoch {self.current_epoch}")

    # args and current_epoch do not change during this call, so resolve the
    # quality mode once instead of once per attempt.
    quality_mode = self._get_quality_mode()

    accumulated_x = []
    accumulated_log_rnd = []
    accumulated_rewards = []
    total_accumulated = 0

    max_attempts = 100  # Prevent infinite loop
    attempts = 0

    while total_accumulated < samples_per_gpu and attempts < max_attempts:
        attempts += 1
        if rank == 0:
            print(f"[BUFFER] rank={rank} starting sampling attempt {attempts} at {time.strftime('%H:%M:%S')}")

        start_time = time.time()

        # The fourth return value (sampling trace) is not used here.
        x_final, log_rnd, final_rewards, _trace = \
            sample_mol_buffer(
                self.policy_model,
                self.pretrained,
                self.reward_model,
                self.tokenizer,
                steps=self.args.total_num_steps,
                mask=self.policy_model.interpolant.mask_token,
                pad=self.policy_model.interpolant.pad_token,
                batch_size=self.args.batch_size,
                max_length=self.args.max_length,
                quality_mode=quality_mode,
                alpha=self.args.alpha,
                num_remasking=self.args.num_remasking,
                quality_threshold=self.args.quality_threshold,
                use_quality_filter=self.args.use_quality_filter,
            )
        if self.args.elbo_rnd:
            # Override trajectory log_rnd with forward ELBO estimate
            if x_final.shape[0] > 0:
                with torch.no_grad():
                    noised = self.policy_model.prepare_noised_sample(
                        x_final, num_samples=self.args.elbo_rnd_num_samples)
                    policy_loss = self.policy_model.compute_loss_from_noised(noised)
                    pretrained_loss = self.pretrained.compute_loss_from_noised(noised)
                    log_rnd = (pretrained_loss - policy_loss) + (final_rewards / self.args.alpha)

        elapsed = time.time() - start_time
        if rank == 0:
            print(f"[BUFFER] rank={rank} sampling took {elapsed:.1f}s")

        n_valid = x_final.shape[0]
        if n_valid > 0:
            accumulated_x.append(x_final)
            accumulated_log_rnd.append(log_rnd)
            accumulated_rewards.append(final_rewards)
            total_accumulated += n_valid

        if rank == 0:
            qm = quality_mode
            print(f"[BUFFER] rank={rank} epoch={self.current_epoch} quality_mode={qm} accumulated={total_accumulated} / {samples_per_gpu} (batch yielded {n_valid} valid) attempt={attempts}")

    if total_accumulated == 0:
        raise RuntimeError(f"[BUFFER ERROR] Rank {rank}: No valid sequences generated after {attempts} attempts. Check sampling function and reward model.")

    if total_accumulated < samples_per_gpu:
        print(f"[BUFFER WARNING] Rank {rank}: Only generated {total_accumulated}/{samples_per_gpu} sequences after {attempts} attempts")

    # The last batch may overshoot the quota; trim to exactly samples_per_gpu.
    new_x = torch.cat(accumulated_x, dim=0)[:samples_per_gpu]
    new_log_rnd = torch.cat(accumulated_log_rnd, dim=0)[:samples_per_gpu]
    new_rewards = torch.cat(accumulated_rewards, dim=0)[:samples_per_gpu]

    del accumulated_x, accumulated_log_rnd, accumulated_rewards
    torch.cuda.empty_cache()

    # add to buffer: pool mode replaces a random subset, classic mode overwrites
    if is_pool and not is_init:
        actual_new = min(new_x.shape[0], self.x_saved.shape[0])
        indices = torch.randperm(self.x_saved.shape[0], device=self.x_saved.device)[:actual_new]
        self.x_saved[indices] = new_x[:actual_new]
        self.log_rnd_saved[indices] = new_log_rnd[:actual_new]
        self.final_rewards_saved[indices] = new_rewards[:actual_new]
        if rank == 0:
            print(f"[POOL] Replaced {actual_new}/{self.x_saved.shape[0]} molecules, reward mean={self.final_rewards_saved.mean():.4f}")
    else:
        self.x_saved = new_x
        self.log_rnd_saved = new_log_rnd
        self.final_rewards_saved = new_rewards

    if rank == 0:
        print(f"[BUFFER] After cleanup - GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
Train policy with WDCE loss + policy_loss = self.policy_model.loss_wdce_flexible( + log_rnd, + x_final, + num_replicates=self.args.wdce_num_replicates, + centering=self.args.centering, + centering_strength=self.args.centering_strength, + softmax_temperature=sm_temp, + ) + self.log('train/policy_loss', policy_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + + if (not self.train_policy) or joint: + # Train planner with appropriate loss based on ablation flags + if self.args.disable_insertion_planner: + # Ablation: only train unmasking planner (no insertion head) + planner_loss = self.policy_model.loss_planner_flexible( + log_rnd, + x_final, + num_replicates=self.args.wdce_num_replicates, + centering=self.args.centering, + centering_strength=self.args.centering_strength, + softmax_temperature=sm_temp, + ) + self.log('train/planner_unmask_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + self.log('train/planner_insert_loss', 0.0, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + self.log('train/planner_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + elif self.args.disable_unmasking_planner: + # only train insertion planner (no remasking head) + unmask_loss, insert_loss, _ = self.policy_model.loss_insert_planner_flexible( + log_rnd, + x_final, + num_replicates=self.args.wdce_num_replicates, + centering=self.args.centering, + centering_strength=self.args.centering_strength, + softmax_temperature=sm_temp, + ) + # Zero out the unmasking component - only backprop insertion loss + planner_loss = insert_loss + self.log('train/planner_unmask_loss', 0.0, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + self.log('train/planner_insert_loss', insert_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + self.log('train/planner_loss', planner_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + else: + # Full planner: train both remasking + 
def on_train_epoch_end(self):
    """Called at the end of each training epoch - only rank 0 evaluates.

    Evaluation runs when the epoch index is a multiple of
    ``args.eval_every_n_epochs`` (default 5) or on the last epoch. It samples
    a batch of 50 molecules, appends the metrics to the in-memory logs, logs
    them together with reward statistics of the current buffer, and prints a
    one-line summary.
    """
    # Only evaluate every N epochs to save time
    eval_frequency = getattr(self.args, 'eval_every_n_epochs', 5)
    is_last_epoch = (self.trainer and self.current_epoch == self.trainer.max_epochs - 1)
    if self.global_rank != 0:
        return
    if not (self.current_epoch % eval_frequency == 0 or is_last_epoch):
        return

    # Sample eval batch with updated policy
    # NOTE(review): on_fit_end unpacks the same helper as
    # (x_eval, qed, sa, valid_fraction, uniqueness, diversity, quality, df).
    # The order used here differs -- confirm against sample_mol_eval's return.
    x_eval, qed, sa, uniqueness, diversity, quality, valid_fraction = \
        sample_mol_eval(
            self.policy_model, self.reward_model,
            self.tokenizer,
            steps=self.args.total_num_steps,
            mask=self.policy_model.interpolant.mask_token,
            pad=self.policy_model.interpolant.pad_token,
            batch_size=50,
            max_length=self.args.max_length,
            quality_mode=self._get_quality_mode(),
            num_remasking=self.args.num_remasking,
            evaluator=self.evaluator,
        )

    # Append to logs
    self.valid_fraction_log.append(valid_fraction)
    self.uniqueness_log.append(uniqueness)
    self.diversity_log.append(diversity)
    self.qed_log.append(qed)
    self.sa_log.append(sa)
    self.quality_log.append(quality)

    # Compute reward stats
    buffer_rewards = self.final_rewards_saved
    mean_reward = buffer_rewards.mean().item()
    min_reward = buffer_rewards.min().item()
    max_reward = buffer_rewards.max().item()
    median_reward = buffer_rewards.median().item()

    # Log metrics
    # This call is reached on rank 0 only, so Lightning must be told via
    # rank_zero_only=True (its documented contract for rank-0-only logging).
    self.log_dict({
        "eval/valid_fraction": valid_fraction,
        "eval/uniqueness": np.mean(uniqueness),
        "eval/diversity": np.mean(diversity),
        "eval/qed": np.mean(qed),
        "eval/sa": np.mean(sa),
        "eval/quality": np.mean(quality),
        "eval/mean_reward_search": mean_reward,
        "eval/min_reward_search": min_reward,
        "eval/max_reward_search": max_reward,
        "eval/median_reward_search": median_reward
    }, rank_zero_only=True)

    print(f"epoch {self.current_epoch} | validity {valid_fraction:.4f} | uniqueness {np.mean(uniqueness):.4f} | diversity {np.mean(diversity):.4f} | "
          f"QED {np.mean(qed):.4f} | SA {np.mean(sa):.4f} | quality {np.mean(quality):.4f} | ")
def save_logs_to_file(valid_fraction_log, uniqueness_log,
                      diversity_log, qed_log, sa_log,
                      quality_log, output_path):
    """
    Saves the logs to a CSV file.

    Each log contributes one column; an ``Iteration`` column numbered from 1
    is derived from the length of ``valid_fraction_log``. The parent directory
    of ``output_path`` is created when it does not exist.
    """
    parent_dir = os.path.dirname(output_path)
    # dirname() is '' for a bare filename and os.makedirs('') raises, so only
    # create the directory when the path actually has one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    log_data = {
        "Iteration": list(range(1, len(valid_fraction_log) + 1)),
        "Valid Fraction": valid_fraction_log,
        "Uniqueness": uniqueness_log,
        "Diversity": diversity_log,
        "QED": qed_log,
        "Synthetic Accessibility": sa_log,
        "Quality": quality_log
    }

    df = pd.DataFrame(log_data)
    df.to_csv(output_path, index=False)


class DummyDataset(torch.utils.data.Dataset):
    """Dummy dataset for Lightning trainer (we use buffer instead).

    Its length sets how many items a DataLoader over it yields; every item is
    the same placeholder tensor.
    """
    def __init__(self, size=100):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return torch.zeros(1)  # Dummy data


def main():
    """Main entry point for distributed training.

    Parses the command line, derives a timestamped run name and save
    directory, loads the frozen pretrained reference model and its config,
    builds the fine-tunable policy model and the ``MolFinetuner`` module, and
    runs ``trainer.fit`` on a dummy dataloader.

    Raises:
        ValueError: if ``--joint_training`` is combined with
            ``--disable_planner``, or if the checkpoint carries no config.
    """
    # Disable DDP optimizer for higher-order ops like flex_attention
    import torch._dynamo
    torch._dynamo.config.optimize_ddp = False

    argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    argparser.add_argument('--base_path', type=str, default=REPO_ROOT)
    argparser.add_argument('--learning_rate', type=float, default=1e-4)
    argparser.add_argument('--num_epochs', type=int, default=100)
    argparser.add_argument('--num_accum_steps', type=int, default=4)
    argparser.add_argument('--truncate_steps', type=int, default=50)
    argparser.add_argument("--truncate_kl", type=str2bool, default=False)
    argparser.add_argument('--gumbel_temp', type=float, default=1.0)
    argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
    argparser.add_argument('--batch_size', type=int, default=50)
    argparser.add_argument('--name', type=str, default='debug')
    argparser.add_argument('--total_num_steps', type=int, default=128)
    argparser.add_argument('--copy_flag_temp', type=float, default=None)
    argparser.add_argument('--save_every_n_epochs', type=int, default=10)
    argparser.add_argument('--eval_every_n_epochs', type=int, default=5, help='Evaluate only every N epochs to save time')
    argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
    argparser.add_argument("--seed", type=int, default=0)
    # new
    # NOTE(review): the parsed --run_name is overwritten unconditionally
    # further down, so a user-supplied value has no effect.
    argparser.add_argument('--run_name', type=str, default='mol')
    argparser.add_argument("--save_path_dir", default="", type=str)
    # mcts
    argparser.add_argument('--num_sequences', type=int, default=10)
    argparser.add_argument('--max_length', type=int, default=1024)
    argparser.add_argument('--num_children', type=int, default=50)
    argparser.add_argument('--num_iter', type=int, default=30)  # iterations of mcts
    argparser.add_argument('--seq_length', type=int, default=1024)
    argparser.add_argument('--time_conditioning', action='store_true', default=False)
    argparser.add_argument('--mcts_sampling', type=int, default=0)  # for batched categorical sampling: '0' means gumbel noise
    argparser.add_argument('--buffer_size', type=int, default=100)
    argparser.add_argument('--wdce_num_replicates', type=int, default=16)
    argparser.add_argument('--noise_removal', action='store_true', default=False)
    argparser.add_argument('--grad_clip', action='store_true', default=False)
    argparser.add_argument('--resample_every_n_step', type=int, default=3)
    argparser.add_argument('--exploration', type=float, default=0.1)
    argparser.add_argument('--reset_every_n_step', type=int, default=100)
    argparser.add_argument('--alpha', type=float, default=0.01)
    argparser.add_argument('--scalarization', type=str, default='sum')
    argparser.add_argument('--no_mcts', action='store_true', default=False)
    argparser.add_argument("--centering", action='store_true', default=False)
    argparser.add_argument("--centering_strength", type=float, default=1.0)

    # adaptive schedule parameters
    # A store_true flag that defaults to True can never be turned off. Accept
    # an optional boolean value instead: the bare flag still means True, and
    # `--use_adaptive_schedule false` now disables it.
    argparser.add_argument('--use_adaptive_schedule', type=str2bool, nargs='?', const=True, default=True,
                           help='Enable the adaptive schedule; pass an explicit false value to disable it')
    argparser.add_argument('--schedule_hidden_dim', type=int, default=256)
    argparser.add_argument('--schedule_num_layers', type=int, default=2)
    argparser.add_argument('--schedule_loss_weight', type=float, default=0.1)
    argparser.add_argument('--adaptive_threshold', type=float, default=0.5)
    argparser.add_argument('--freeze_base_model', action='store_true', default=False)
    argparser.add_argument('--schedule_warmup_epochs', type=int, default=20, help='Number of initial epochs to train WITHOUT remasking in buffer generation')
    argparser.add_argument('--alternation_frequency', type=int, default=5, help='Number of epochs to train each model before alternating (1=alternate every epoch)')
    argparser.add_argument('--planner_learning_rate', type=float, default=None, help='Separate learning rate for planner heads (defaults to --learning_rate if not set)')

    # objectives
    argparser.add_argument('--num_obj', type=int, default=2)
    argparser.add_argument('--devices', type=int, default=-1)
    argparser.add_argument('--checkpoint_path', type=str, default=None)

    # ELBO-based log_rnd estimation
    argparser.add_argument('--elbo_rnd', action='store_true', default=False,
                           help='If set, compute log_rnd via forward ELBO instead of trajectory rollout')
    argparser.add_argument('--elbo_rnd_num_samples', type=int, default=4,
                           help='Number of noisy time samples per sequence for ELBO-based log_rnd estimation')

    # remasking
    argparser.add_argument('--num_remasking', type=int, default=5)
    argparser.add_argument('--quality_threshold', type=float, default=1)
    argparser.add_argument('--use_quality_filter', action='store_true', help='If set, filter buffer to only include molecules with QED>=0.6 and SA<=4')
    argparser.add_argument('--training_mini_batch_size', type=int, default=8, help='Mini-batch size for training step to avoid OOM')
    argparser.add_argument('--disable_planner', action='store_true', help='If set, disable remasking completely and only train policy (not planner) for quality optimization')
    argparser.add_argument('--disable_insertion_planner', action='store_true', help='Ablation: disable insertion quality filtering but keep unmasking/remasking planner')
    argparser.add_argument('--disable_unmasking_planner', action='store_true', help='Ablation: disable unmasking/remasking planner but keep insertion quality filtering')
    argparser.add_argument('--joint_training', action='store_true', help='Ablation: train policy and planner jointly each step (no alternation, both unfrozen, summed loss). Incompatible with --disable_planner.')
    argparser.add_argument('--qed_only', action='store_true', help='If set, optimize only for QED score (no SA)')
    argparser.add_argument('--softmax_temperature', type=float, default=1.0,
                           help='Temperature for softmax on importance weights (>1 smooths, prevents concentration)')
    argparser.add_argument('--pool_size', type=int, default=0,
                           help='If >0, maintain a persistent pool of this size and refresh a fraction each resample step (0=disabled, classic buffer)')
    argparser.add_argument('--pool_refresh_fraction', type=float, default=0.2,
                           help='Fraction of pool to replace each resample step (only used when pool_size>0)')
    argparser.add_argument('--num_training_steps_per_epoch', type=int, default=10,
                           help='Number of gradient updates per epoch (1=original, 10=recommended)')

    args = argparser.parse_args()

    # Default planner LR to policy LR if not specified
    if args.planner_learning_rate is None:
        args.planner_learning_rate = args.learning_rate

    # Set seed
    pl.seed_everything(args.seed)

    # Load models
    checkpoint_path = args.checkpoint_path if args.checkpoint_path else \
        os.path.join(REPO_ROOT, 'pretrained', 'anylength_mol.ckpt')

    curr_time = datetime.now().strftime("%Y%m%d_%H%M%S")

    if args.no_mcts:
        args.run_name = f'mol_al_resample{args.resample_every_n_step}_no-mcts_{curr_time}'
    else:
        args.run_name = f'mol_al_resample{args.resample_every_n_step}_buffer{args.buffer_size}_numiter{args.num_iter}_children{args.num_children}_{curr_time}'

    # append ablation tags to run name for easy identification
    if args.disable_planner:
        args.run_name += '_no_planner'
    if args.disable_insertion_planner:
        args.run_name += '_no_insertion_planner'
    if args.disable_unmasking_planner:
        args.run_name += '_no_unmasking_planner'
    if args.joint_training:
        if args.disable_planner:
            raise ValueError("--joint_training is incompatible with --disable_planner (no planner to train)")
        args.run_name += '_joint_training'

    args.save_path = os.path.join(args.save_path_dir, args.run_name)
    os.makedirs(args.save_path, exist_ok=True)
    set_seed(args.seed, use_cuda=False)  # Don't init CUDA before Lightning spawns DDP workers

    # Initialize the model
    print("Loading models..")

    # Load pretrained model for reference (frozen)
    pretrained = AnyOrderInsertionFlowModule.load_from_checkpoint(checkpoint_path,
                                                                  map_location='cpu',
                                                                  weights_only=False)
    pretrained.eval()
    for param in pretrained.parameters():
        param.requires_grad = False

    # Load checkpoint to extract config
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    if 'hyper_parameters' in checkpoint:
        config = checkpoint['hyper_parameters']['config']
    elif 'config' in checkpoint:
        config = checkpoint['config']
    else:
        raise ValueError("Cannot find config in checkpoint")
    # Only the config is needed from here on; drop the reference so the rest
    # of the checkpoint is not kept alive for the whole training run.
    del checkpoint

    # Update config for adaptive schedule
    from omegaconf import OmegaConf
    if not OmegaConf.is_config(config):
        from omegaconf import DictConfig
        config = DictConfig(config)

    # Struct mode is lifted temporarily so new keys can be added.
    OmegaConf.set_struct(config, False)

    config.training.use_adaptive_schedule = args.use_adaptive_schedule
    config.training.schedule_hidden_dim = args.schedule_hidden_dim
    config.training.schedule_num_layers = args.schedule_num_layers
    config.training.schedule_loss_weight = args.schedule_loss_weight
    config.training.freeze_base_model = args.freeze_base_model
    config.training.schedule_warmup_epochs = args.schedule_warmup_epochs
    config.training.use_bracket_safe = True

    OmegaConf.set_struct(config, True)

    # initialize policy model with adaptive schedule
    policy_model = AnyOrderInsertionFlowModuleFT(
        config=config,
        args=args,
        pretrained_checkpoint=checkpoint_path,
        insertion_planner=True,
    )

    # define mcts
    if args.qed_only:
        score_func_names = ['qed']
    else:
        score_func_names = ['qed', 'sa']

    tokenizer = get_tokenizer()

    filename = args.run_name

    # Device will be set by Lightning automatically in DDP
    reward_model = MolScoringFunctions(score_func_names, device='cpu')
    model = MolFinetuner(
        args=args,
        policy_model=policy_model,
        reward_model=reward_model,
        tokenizer=tokenizer,
        pretrained=pretrained,
        mcts=None,
        filename=filename,
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath=args.save_path,
        filename='model-{epoch:02d}-{train_loss:.4f}',
        every_n_epochs=args.save_every_n_epochs,
        save_top_k=-1,  # Save all checkpoints
        save_last=True,  # Also save last.ckpt
        auto_insert_metric_name=False
    )

    # Defaults to your default wandb entity; override with the WANDB_ENTITY env var.
    wandb_logger = WandbLogger(entity=os.environ.get('WANDB_ENTITY'), project='a2d2-mol', name=args.run_name)

    # create dummy dataloader
    dataset = DummyDataset(size=args.num_training_steps_per_epoch)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

    # setup trainer with DDP
    trainer = pl.Trainer(
        max_epochs=args.num_epochs,
        accelerator='gpu',
        devices=args.devices,
        strategy=DDPStrategy(find_unused_parameters=True) if args.devices != 1 else 'auto',
        gradient_clip_val=args.gradnorm_clip if args.grad_clip else None,
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
        enable_progress_bar=True,
        log_every_n_steps=1
    )

    # Train
    trainer.fit(model, dataloader)


if __name__ == '__main__':
    main()
+ +Quality modes: + "none" - No planner, no remasking (policy-only) + "both" - Both unmasking + insertion planners active + "unmasking_only" - Only unmasking/remasking planner (insertion planner disabled) + "insertion_only" - Only insertion planner (unmasking planner disabled) + +RND toggle: + compute_rnd=True - Run pretrained model in parallel, compute step-wise log importance weights + compute_rnd=False - Run policy model only (use with ELBO-based RND or eval) +""" + +import torch +import numpy as np +import pandas as pd +import torch.nn.functional as F +from sampling import SamplingResult, SamplingTraceDatapoint, _sample_tokens +from remasking_scheduleaware import apply_schedule_aware_remasking, apply_schedule_aware_insertion +from mol_utils.utils_chem import batch_safe_to_smiles, batch_validate_and_extract +from tdc import Evaluator, Oracle + +QUALITY_MODES = {"none", "both", "unmasking_only", "insertion_only"} + + +@torch.no_grad() +def _diffusion_loop( + model, steps, mask, pad, batch_size, max_length, + quality_mode="both", + compute_rnd=False, + pretrained=None, + remasking_mode="schedule_aware", + num_remasking=1, + quality_threshold=1, + temperature=1.0, + return_trace=False, + unmask_quality_threshold=None, +): + """Core discrete diffusion sampling loop for molecule generation. + + Args: + model: Finetuned policy model. + steps: Number of diffusion steps. + mask: Mask token ID. + pad: Pad token ID. + batch_size: Number of sequences to generate. + max_length: Maximum sequence length. + quality_mode: One of "none", "both", "unmasking_only", "insertion_only". + compute_rnd: Whether to compute step-wise log importance weights. + pretrained: Frozen pretrained model (required if compute_rnd=True). + remasking_mode: Remasking strategy ("schedule_aware", "remdm", "remdm_conf"). + num_remasking: Number of tokens to remask per step. + quality_threshold: Threshold for insertion quality filtering. None if schedule-driven. 
+ temperature: Sampling temperature (1.0 = no scaling). + return_trace: Whether to record sampling trace. + + Returns: + (xt, log_rnd, sampling_trace) + log_rnd is None when compute_rnd=False. + """ + assert quality_mode in QUALITY_MODES, f"quality_mode must be one of {QUALITY_MODES}" + if compute_rnd: + assert pretrained is not None, "pretrained model required when compute_rnd=True" + + # Derive flags from quality_mode + use_remasking = quality_mode != "none" + disable_unmasking_planner = quality_mode in ("none", "insertion_only") + disable_insertion_planner = quality_mode in ("none", "unmasking_only") + + device = next(model.parameters()).device + + # Initialize all-pad sequence + xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device) + + dt = 1.0 / steps + t = torch.zeros(batch_size, device=device) + + # Precompute index tensors + batch_idx_L = ( + torch.arange(batch_size, device=device) + .view(batch_size, 1) + .expand(batch_size, max_length) + ) + pos_idx_L = ( + torch.arange(max_length, device=device) + .view(1, max_length) + .expand(batch_size, max_length) + ) + sampling_trace = [[] for _ in range(batch_size)] if return_trace else None + + neg_inf = torch.tensor(-np.inf, device=device) + + if use_remasking and remasking_mode == "remdm_conf": + remasking_score = torch.zeros((batch_size, max_length), device=device) + + log_rnd = None + + for i in range(steps): + # --- Policy model forward --- + pred_rate = model(xt, t) + pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t) + unmask_rate = pred_rate.unmask_rate # (B, L, V) + len_rate = pred_rate.length_rate # (B, L+1) + + # --- Pretrained model forward (for RND) --- + if compute_rnd: + pretrained_pred = pretrained(xt, t) + pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t) + pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V) + pretrained_len_rate = pretrained_rate.length_rate # (B, L+1) + + # --- Unmask step (Euler) --- + 
mask_pos = (xt == mask).nonzero(as_tuple=True) + unmask_rate[xt != mask] = 0 + unmask_rate[mask_pos + (mask,)] = 0 + unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) + trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) + + if compute_rnd: + pretrained_unmask_rate[xt != mask] = 0 + pretrained_unmask_rate[mask_pos + (mask,)] = 0 + pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1) + pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0) + + # Add "stay" probability + _xt = xt.clone() + _xt[xt == pad] = mask + trans_prob.scatter_add_( + 2, + _xt.unsqueeze(-1), + torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), + ) + if compute_rnd: + pretrained_trans_prob.scatter_add_( + 2, + _xt.unsqueeze(-1), + torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype), + ) + + # Temperature scaling + if temperature != 1.0: + logits = torch.log(trans_prob + 1e-10) / temperature + trans_prob = torch.softmax(logits, dim=-1) + + # Final step: remove mask token from sampling + if i == steps - 1: + print("Final step, removing mask token from sampling") + trans_prob[mask_pos + (mask,)] = 0.0 + + prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) + mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) + if mask_has_zero_prob.any(): + num_zero_prob = mask_has_zero_prob.sum().item() + uniform_prob = torch.zeros((num_zero_prob, trans_prob.shape[-1]), device=device, dtype=trans_prob.dtype) + uniform_prob[:, :mask] = 1.0 / mask + trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob + else: + trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum + + new_xt = _sample_tokens(trans_prob) + new_xt[xt == pad] = pad + new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) + + # Update remasking_score buffer for remdm_conf mode + if use_remasking and remasking_mode == "remdm_conf" and i < steps - 1: + token_probs = 
F.softmax(unmask_rate, dim=-1) # (B, L, V) + chosen_probs = torch.gather(token_probs, dim=-1, index=new_xt.unsqueeze(-1)).squeeze(-1) # (B, L) + changed_mask_to_token = (xt == mask) & (new_xt != mask) & (new_xt != pad) + remasking_score = torch.where(changed_mask_to_token, chosen_probs, remasking_score) + + # --- Remasking step --- + if use_remasking and i < steps - 1: + if disable_unmasking_planner or not (hasattr(model, 'planner') and model.planner is not None): + remasking_conf = torch.zeros((batch_size, max_length), device=device) + else: + planner_out = model.planner(new_xt, t) + remasking_conf = planner_out["remasking_conf"].squeeze(-1) # (B, L) + + clean_index = (new_xt != mask) & (new_xt != pad) # (B, L) + + if remasking_mode == "schedule_aware": + new_xt = apply_schedule_aware_remasking( + model, new_xt, t, dt, remasking_conf, clean_index, + mask, neg_inf, batch_size, + unmask_quality_threshold=unmask_quality_threshold, + ) + remasking_score_temp = None + else: + raise ValueError(f"Unknown remasking_mode: {remasking_mode}") + + if remasking_score_temp is not None: + remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf) + for j in range(batch_size): + k = min(num_remasking, int(clean_index[j].sum().item())) + if k > 0: + _, select_indices = torch.topk(remasking_score_temp[j], k=k) + new_xt[j, select_indices] = mask + + if return_trace: + for batch_idx in range(batch_size): + for pos in range(max_length): + if clean_index[batch_idx, pos] and new_xt[batch_idx, pos] == mask: + sampling_trace[batch_idx].append( + SamplingTraceDatapoint( + t=t[batch_idx].item(), + event_type="change", + position=pos, + token=mask, + ) + ) + + # --- Compute log probabilities for RND --- + if compute_rnd: + lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1) + lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1) + + changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != 
pad) + + log_policy_step = (lp * changed_mask).sum(dim=1) + log_pretrained_step = (lp_pre * changed_mask).sum(dim=1) + + log_rnd = log_pretrained_step - log_policy_step # (B,) + + # --- Insertion step --- + if i != steps - 1: + ext = torch.poisson(len_rate * dt).long() # (B, L+1) + + xt_len = xt.ne(pad).sum(dim=1) # (B,) + gaps = torch.arange(max_length + 1, device=device).view(1, -1) + ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() + total_ext = ext.sum(dim=1) + valid = xt_len + total_ext <= max_length + ext = ext * valid.view(batch_size, 1).long() + + ext_ex = ext.int().cumsum(dim=1) # (B, L+1) + new_len = xt_len + total_ext # (B,) + + xt_tmp = torch.full_like(xt, pad) + mask_fill = pos_idx_L < new_len.view(batch_size, 1) + xt_tmp[mask_fill] = mask + + new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) + orig_mask = pos_idx_L < xt_len.view(batch_size, 1) + flat_b = batch_idx_L[orig_mask] + flat_p = new_pos_orig[orig_mask] + xt_tmp[flat_b, flat_p] = new_xt[orig_mask] + + # Schedule-aware insertion quality filtering + if use_remasking and not disable_insertion_planner: + if compute_rnd: + xt_tmp_before = xt_tmp.clone() + + xt_tmp = apply_schedule_aware_insertion( + model, xt_tmp, new_xt, t, dt, ext, mask, pad, max_length, + orig_mask, new_pos_orig, quality_threshold + ) + + if compute_rnd: + # Compute corrected ext based on what actually stayed + ext_corrected = torch.zeros_like(ext) + for b in range(batch_size): + after_len = xt_tmp[b].ne(pad).sum().item() + orig_len = xt_len[b].item() + surviving_insertions = after_len - orig_len + if total_ext[b] > 0: + ratio = surviving_insertions / total_ext[b].item() + ext_corrected[b] = (ext[b].float() * ratio).long() + else: + ext_corrected = ext + else: + ext_corrected = ext + + # Compute insertion log_rnd + if compute_rnd: + insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1) + pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1) + + log_policy_insert = 
def _decode_and_validate(model, tokenizer, samples):
    """Turn sampled token IDs into SMILES strings and keep the usable ones.

    The samples are decoded with ``tokenizer.batch_decode`` (special tokens
    skipped) and converted with ``batch_safe_to_smiles``. Entries whose
    conversion result is falsy are dropped. For every remaining entry only
    the longest ``.``-separated fragment is kept; when several fragments tie
    for longest, the last one wins.

    Args:
        model: Object exposing ``config.training.get('use_bracket_safe', ...)``.
        tokenizer: Tokenizer providing ``batch_decode``.
        samples: Batch of token IDs to decode.

    Returns:
        (validSequences, valid_indices): the kept SMILES strings and, in the
        same order, the batch index each one came from.
    """
    texts = tokenizer.batch_decode(samples, skip_special_tokens=True)
    bracket_mode = model.config.training.get('use_bracket_safe', False)
    converted = batch_safe_to_smiles(texts, use_bracket_safe=bracket_mode, fix=True)

    # Pair each surviving SMILES with its batch position.
    survivors = [(pos, smi) for pos, smi in enumerate(converted) if smi]

    valid_indices = [pos for pos, _ in survivors]
    # Scanning the fragments in reverse makes max() return the *last* of the
    # longest fragments, the same pick as a stable sort-by-length then [-1].
    validSequences = [max(reversed(smi.split('.')), key=len) for _, smi in survivors]

    return validSequences, valid_indices
+ + Returns: + (valid_x, log_rnd, scalar_rewards, sampling_trace) + """ + xt, log_rnd, trace = _diffusion_loop( + model, steps, mask, pad, batch_size, max_length, + quality_mode=quality_mode, + compute_rnd=True, + pretrained=pretrained, + remasking_mode=remasking_mode, + num_remasking=num_remasking, + quality_threshold=quality_threshold, + temperature=temperature, + ) + + device = xt.device + samples = xt.to(device) + + validSequences, valid_indices = _decode_and_validate(model, tokenizer, samples) + + valid_x_final = [samples[idx] for idx in valid_indices] + valid_log_rnd = [log_rnd[idx] for idx in valid_indices] + + print("len valid sequences:", len(validSequences)) + + if len(validSequences) == 0: + print("[WARNING] No valid molecules generated in this batch") + empty_x = torch.empty((0, max_length), dtype=torch.long, device=device) + empty_log_rnd = torch.empty((0,), dtype=torch.float32, device=device) + empty_rewards = torch.empty((0,), dtype=torch.float32, device=device) + return empty_x, empty_log_rnd, empty_rewards, trace + + # Compute multi-objective rewards + score_vectors = reward_model(input_seqs=validSequences) + scalar_rewards = np.sum(score_vectors, axis=-1) + scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device) + + print(f"scalar reward dim{len(scalar_rewards)}") + valid_log_rnd = torch.stack(valid_log_rnd, dim=0) + + log_rnd = valid_log_rnd + (scalar_rewards / alpha) + valid_x_final = torch.stack(valid_x_final, dim=0) + + # Optionally filter to only keep quality sequences (QED >= 0.6 and SA <= 4) + if use_quality_filter: + qed_scores = score_vectors[:, 0] + if score_vectors.shape[1] > 1: + sa_scores = score_vectors[:, 1] + else: + _oracle_sa = Oracle('sa') + raw_sa = np.array(_oracle_sa(validSequences)) + sa_scores = raw_sa + quality_mask = (qed_scores >= 0.6) & (sa_scores <= 4) + + n_quality = quality_mask.sum() + print(f"Quality filtering: {n_quality}/{len(validSequences)} sequences pass (QED>=0.6, SA<=4)") + + if 
@torch.no_grad()
def sample_mol_eval(
    model, reward_model, tokenizer,
    steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    temperature=1.0,
    evaluator=None,
    dataframe=False,
    unmask_quality_threshold=None,
):
    """Generate molecules for evaluation and summarise their quality.

    Args:
        model: Finetuned policy model.
        reward_model: Molecule scoring function; called as
            ``reward_model(input_seqs=...)`` and indexed as a 2-D array.
        tokenizer: SAFE tokenizer for decoding.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: "none", "both", "unmasking_only", or "insertion_only".
        remasking_mode: Remasking strategy.
        num_remasking: Number of tokens to remask per step.
        quality_threshold: Threshold for insertion quality filtering. Pass None
            to use schedule-driven deletion with no threshold gate.
        temperature: Sampling temperature.
        evaluator: TDC Evaluator for diversity (created if None).
        dataframe: If True, include a pandas DataFrame in the return.
        unmask_quality_threshold: Forwarded unchanged to ``_diffusion_loop``.

    Returns:
        Without dataframe:
            (validSequences, qed, sa, uniqueness, diversity, quality, valid_fraction)
        With dataframe:
            (validSequences, qed, sa, valid_fraction, uniqueness, diversity, quality, df)
        validSequences is the raw list including duplicates; qed/sa are scored
        on the unique set (both are ``[0.0]`` when nothing valid was
        generated). Caller can dedup with set(validSequences). The dataframe
        (when requested) has one row per unique molecule and is empty when no
        valid molecule was generated. ``quality`` and ``valid_fraction`` are
        fractions of ``batch_size``.
    """
    if evaluator is None:
        evaluator = Evaluator('diversity')

    # Evaluation does not need the RND term or the sampling trace.
    xt, _, _ = _diffusion_loop(
        model, steps, mask, pad, batch_size, max_length,
        quality_mode=quality_mode,
        compute_rnd=False,
        remasking_mode=remasking_mode,
        num_remasking=num_remasking,
        quality_threshold=quality_threshold,
        temperature=temperature,
        unmask_quality_threshold=unmask_quality_threshold,
    )

    # Shared decode/validate helper (largest fragment of each valid sample).
    validSequences, _ = _decode_and_validate(model, tokenizer, xt)

    print("len valid sequences:", len(validSequences))
    valid_fraction = len(validSequences) / batch_size
    uniqueSequences = list(set(validSequences))
    uniqueness = len(uniqueSequences) / len(validSequences) if validSequences else 0
    diversity = evaluator(uniqueSequences) if uniqueSequences else 0

    # Quality = unique sequences with QED >= 0.6 and SA <= 4, over batch_size.
    if uniqueSequences:
        score_vectors = reward_model(input_seqs=uniqueSequences)
        # NOTE(review): column 0 is treated as raw QED (0-1); confirm that the
        # reward model is always configured with QED as its first score.
        qed = score_vectors[:, 0]

        # Always use raw SA (1-10 scale) for quality filtering, never a
        # transformed reward column.
        sa = np.array(Oracle('sa')(uniqueSequences))

        quality_count = int(np.count_nonzero((qed >= 0.6) & (sa <= 4)))
        quality = quality_count / batch_size
        print(f'Quality:\t{quality}')
    else:
        # Placeholder scores kept for callers that index the returned qed/sa.
        qed = [0.0]
        sa = [0.0]
        quality = 0.0

    if dataframe:
        # Every column must have one entry per unique molecule. The [0.0]
        # placeholders above would not match an empty "Mol Sequence" column
        # and pandas would raise ValueError, so use empty columns instead.
        has_rows = bool(uniqueSequences)
        df = pd.DataFrame({
            "Mol Sequence": uniqueSequences,
            "QED": qed if has_rows else [],
            "SA": sa if has_rows else [],
        })
        return validSequences, qed, sa, valid_fraction, uniqueness, diversity, quality, df

    return validSequences, qed, sa, uniqueness, diversity, quality, valid_fraction
def get_tokenizer():
    """Load the pretrained SAFE tokenizer and register the bracket tokens.

    Returns:
        The tokenizer obtained from ``datamol-io/safe-gpt`` with ``<`` and
        ``>`` added (used by the bracket-SAFE format).
    """
    safe_wrapper = SAFETokenizer.from_pretrained('datamol-io/safe-gpt')
    tokenizer = safe_wrapper.get_pretrained()
    tokenizer.add_tokens(['<', '>'])  # for bracket_safe
    return tokenizer


class Collator:
    """Batch collator that tokenizes SAFE / bracket-SAFE strings."""

    # Keys probed, in priority order, when an example is a dict.
    _TEXT_KEYS = ('input', 'labels', 'smiles')

    def __init__(self, config, tokenizer=None):
        """Read max length and bracket-SAFE flag from ``config``.

        Args:
            config: Provides ``interpolant.max_length`` and
                ``training.get('use_bracket_safe', False)``.
            tokenizer: Callable tokenizer; built with ``get_tokenizer()``
                when omitted.
        """
        if tokenizer is None:
            tokenizer = get_tokenizer()
        self.tokenizer = tokenizer
        self.max_length = config.interpolant.max_length
        self.use_bracket_safe = config.training.get('use_bracket_safe', False)

    def _example_text(self, example):
        """Return the string to tokenize for one example."""
        if isinstance(example, dict):
            text = ''
            for key in self._TEXT_KEYS:
                if key in example:
                    text = example[key]
                    break
        else:
            # Non-dict examples are used as the text directly.
            text = example

        if self.use_bracket_safe:
            text = safe2bracketsafe(text)
        return text

    def __call__(self, examples):
        """Tokenize a list of examples into a padded, truncated batch.

        Returns:
            A plain dict holding only ``input_ids`` and ``attention_mask``;
            any other field produced by the tokenizer is dropped.
        """
        texts = [self._example_text(example) for example in examples]

        encoded = self.tokenizer(
            texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )

        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
        }


class HFDatasetAdapter(Dataset):
    """Expose a HuggingFace dataset as ``{'input', 'labels'}`` records."""

    def __init__(self, hf_dataset, tokenizer, smiles_column='smiles', max_length=1024, convert_to_safe=False, is_streaming=False):
        """
        Args:
            hf_dataset: HuggingFace dataset object (streaming or regular)
            tokenizer: Stored on the instance; not used by the adapter itself
            smiles_column: Name of the column containing SMILES strings
            max_length: Stored on the instance; not used by the adapter itself
            convert_to_safe: Stored on the instance; not used by the adapter itself
            is_streaming: Whether dataset is in streaming mode
        """
        self.tokenizer = tokenizer
        self.smiles_column = smiles_column
        self.max_length = max_length
        self.convert_to_safe = convert_to_safe
        self.is_streaming = is_streaming

        if is_streaming:
            # Streaming data is kept as-is and only read during iteration.
            self.data = hf_dataset
            self._length = None  # Unknown length for streaming
            print('Initialized streaming dataset adapter')
            return

        # Raw strings are stored; tokenization happens later in the collator.
        print(f'Initializing HF dataset adapter with {len(hf_dataset)} samples...')
        column_values = (row[smiles_column] for row in hf_dataset)
        # Rows with an empty/falsy SMILES value are skipped.
        self.data = [{'input': value, 'labels': value} for value in column_values if value]
        print(f'Processed {len(self.data)} valid samples')

    def __len__(self):
        if not self.is_streaming:
            return len(self.data)
        # Streaming datasets have no real length; a large stand-in is
        # returned to prevent issues with samplers.
        if self._length is None:
            return 10_000_000
        return self._length

    def __getitem__(self, idx):
        if not self.is_streaming:
            return self.data[idx]
        raise NotImplementedError("Streaming datasets should be iterated, not indexed")

    def __iter__(self):
        """Yield records; streaming rows are converted and filtered lazily."""
        if not self.is_streaming:
            yield from self.data
            return

        for row in self.data:
            value = row[self.smiles_column]
            if value:  # Skip empty SMILES
                yield {'input': value, 'labels': value}
Optional[float] = None, + streaming: bool = True, + max_train_samples: Optional[int] = None, + max_val_samples: Optional[int] = None, + ): + """ + Args: + config: Configuration object containing training parameters + dataset_name: HuggingFace dataset identifier (e.g., "datamol-io/safe-gpt") + tokenizer: SMILES tokenizer instance + smiles_column: Name of column containing SMILES strings + val_split: Fraction of data to use for validation + test_split: Optional fraction of data to use for testing + streaming: Whether to use streaming mode (recommended for large datasets) + max_train_samples: Maximum number of training samples to use (for non-streaming) + max_val_samples: Maximum number of validation samples to use (for non-streaming) + """ + super().__init__() + self.config = config + self.dataset_name = dataset_name + self.tokenizer = tokenizer + self.smiles_column = smiles_column + self.max_length = config.interpolant.max_length + self.batch_size = config.training.per_gpu_batch_size + self.num_workers = config.training.get('cpus', 4) + self.val_split = val_split + self.test_split = test_split + self.streaming = streaming + self.max_train_samples = max_train_samples + self.max_val_samples = max_val_samples + + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + + # Initialize collator + self.collator = Collator(config, tokenizer) + + def setup(self, stage: Optional[str] = None): + """Load and split the dataset.""" + print(f'Loading dataset: {self.dataset_name} (streaming={self.streaming})') + + if self.streaming: + # Load dataset in streaming mode + raw_dataset = load_dataset(self.dataset_name, streaming=True) + + # Handle different dataset structures + if 'train' in raw_dataset: + train_stream = raw_dataset['train'] + else: + # If no splits exist, use the entire dataset + train_stream = raw_dataset[list(raw_dataset.keys())[0]] + + # For streaming, we need to manually split train/val + # Skip validation samples, then take training 
samples + val_size = int(100000 * self.val_split) # Assume ~100k samples for val split calculation + train_size = 100000 - val_size + + # Create validation stream (take first val_size samples) + val_stream = train_stream.take(val_size) + + # Create training stream (skip val_size samples, then iterate) + train_stream_shifted = train_stream.skip(val_size) + + # Create adapted datasets + self.train_dataset = HFDatasetAdapter( + train_stream_shifted, + self.tokenizer, + self.smiles_column, + self.max_length, + is_streaming=True + ) + + self.val_dataset = HFDatasetAdapter( + val_stream, + self.tokenizer, + self.smiles_column, + self.max_length, + is_streaming=True + ) + + print(f'Streaming dataset initialized - samples will be loaded on-the-fly') + + else: + # Traditional non-streaming mode with full dataset loading + raw_dataset = load_dataset(self.dataset_name) + + # Handle different dataset structures + if 'train' in raw_dataset: + train_data = raw_dataset['train'] + else: + # If no splits exist, use the entire dataset and split it + train_data = raw_dataset[list(raw_dataset.keys())[0]] + + # Limit samples if specified + if self.max_train_samples: + train_data = train_data.select(range(min(self.max_train_samples, len(train_data)))) + + # Check if dataset already has validation split + if 'validation' in raw_dataset or 'val' in raw_dataset: + val_key = 'validation' if 'validation' in raw_dataset else 'val' + val_data = raw_dataset[val_key] + else: + # Create train/val split + split_dataset = train_data.train_test_split(test_size=self.val_split, seed=42) + train_data = split_dataset['train'] + val_data = split_dataset['test'] + + # Limit validation samples if specified + if self.max_val_samples: + val_data = val_data.select(range(min(self.max_val_samples, len(val_data)))) + + # Create test split if requested + if self.test_split and 'test' not in raw_dataset: + split_dataset = train_data.train_test_split(test_size=self.test_split, seed=42) + train_data = 
split_dataset['train'] + self.test_dataset = HFDatasetAdapter( + split_dataset['test'], + self.tokenizer, + self.smiles_column, + self.max_length, + is_streaming=False + ) + elif 'test' in raw_dataset: + self.test_dataset = HFDatasetAdapter( + raw_dataset['test'], + self.tokenizer, + self.smiles_column, + self.max_length, + is_streaming=False + ) + + # Create adapted datasets + self.train_dataset = HFDatasetAdapter( + train_data, + self.tokenizer, + self.smiles_column, + self.max_length, + is_streaming=False + ) + + self.val_dataset = HFDatasetAdapter( + val_data, + self.tokenizer, + self.smiles_column, + self.max_length, + is_streaming=False + ) + + print(f'Dataset splits - Train: {len(self.train_dataset)}, Val: {len(self.val_dataset)}') + if self.test_dataset: + print(f'Test: {len(self.test_dataset)}') + + def train_dataloader(self): + if self.streaming: + # Pass streaming dataset directly to DataLoader (HF IterableDataset) + # Must use num_workers=0 when using .skip() or .take() operations + return DataLoader( + self.train_dataset.data, # Use the raw HF streaming dataset + batch_size=self.batch_size, + collate_fn=self.collator, + num_workers=0, # Required for streaming with skip/take operations + pin_memory=True, + shuffle=False, # Cannot shuffle streaming datasets + ) + else: + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + collate_fn=self.collator, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + persistent_workers=True if self.num_workers > 0 else False + ) + + def val_dataloader(self): + if self.streaming: + # Pass streaming dataset directly to DataLoader (HF IterableDataset) + # Must use num_workers=0 when using .skip() or .take() operations + return DataLoader( + self.val_dataset.data, # Use the raw HF streaming dataset + batch_size=self.batch_size, + collate_fn=self.collator, + num_workers=0, # Required for streaming with skip/take operations + pin_memory=True, + shuffle=False, # Cannot shuffle streaming 
def setup_hf_data_and_update_config(config, dataset_name="datamol-io/safe-gpt", smiles_column="smiles", streaming=True):
    """
    Build the HuggingFace data module and record tokenizer facts in ``config``.

    The tokenizer from ``get_tokenizer()`` determines three fields that are
    written onto ``config.interpolant`` (mutating the caller's config):
    ``tokens`` (vocabulary size), ``pad_token`` and ``mask_token``.

    Args:
        config: Hydra config object
        dataset_name: HuggingFace dataset identifier
        smiles_column: Name of column containing SMILES strings
        streaming: Whether to use streaming mode (recommended for large datasets like safe-gpt)

    Returns:
        HFDataModule instance (created with a validation split of 0.1)
    """
    tokenizer = get_tokenizer()

    # Token information derived from the tokenizer, written into the config.
    tokenizer_fields = (
        ('tokens', len(tokenizer)),
        ('pad_token', tokenizer.pad_token_id),
        ('mask_token', tokenizer.mask_token_id),
    )
    for field_name, field_value in tokenizer_fields:
        setattr(config.interpolant, field_name, field_value)

    return HFDataModule(
        config=config,
        dataset_name=dataset_name,
        tokenizer=tokenizer,
        smiles_column=smiles_column,
        val_split=0.1,
        streaming=streaming,
    )
class MolScoringFunctions:
    """Compute per-molecule score vectors from a list of named oracles."""

    def __init__(self, score_func_names=None, device=None, sa_transform='inverse'):
        """
        Class for generating score vectors given generated sequence

        Args:
            score_func_names: list of scoring function names to be evaluated
                (keys of ``all_funcs``: 'qed', 'sa'). ``None`` means no
                scoring functions.
            device: not used by this class; accepted so existing callers that
                pass it keep working.
            sa_transform: how to transform SA scores to higher-is-better ~[0,1]:
                'inverse' (default): 1/(1+SA) — range ~0.09-0.5, weak gradient
                'linear': (10-SA)/9 — range ~0-1, stronger gradient
                Any value other than 'linear' selects the inverse transform.
        """
        self.score_func_names = [] if score_func_names is None else score_func_names
        self.sa_transform = sa_transform

        # Name -> callable taking a list of sequences and returning scores.
        self.all_funcs = {
            'qed': Oracle('qed'),
            'sa': Oracle('sa'),
        }

    def forward(self, input_seqs):
        """Score every sequence with every configured function.

        Args:
            input_seqs: list of sequences passed unchanged to each oracle.

        Returns:
            float32 array of shape (num_sequences, num_functions); columns
            follow the order of ``score_func_names``. With no configured
            functions the result has shape (num_sequences, 0).

        Raises:
            KeyError: if a configured name is not a key of ``all_funcs``.
        """
        if not self.score_func_names:
            # Keep the documented 2-D layout even when nothing is scored, so
            # callers can still slice columns or reduce over the last axis.
            return np.zeros((len(input_seqs), 0), dtype=np.float32)

        columns = []
        for score_func in self.score_func_names:
            score = self.all_funcs[score_func](input_seqs)

            # Transform SA to be maximized and normalized (original SA: 1-10,
            # lower is better) so it is higher-is-better like QED.
            if score_func == 'sa':
                raw_sa = np.array(score)
                if self.sa_transform == 'linear':
                    # range ~0-1, clipped at 0
                    score = np.maximum((10.0 - raw_sa) / 9.0, 0.0)
                else:
                    score = 1.0 / (1.0 + raw_sa)  # range ~0.09-0.5

            columns.append(score)

        # (num_functions, num_sequences) -> (num_sequences, num_functions)
        return np.asarray(columns, dtype=np.float32).T

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)
MolScoringFunctions(score_func_names=['qed', 'sa']) + + smiles = ['CCOc1cc(ccc1NC(=O)N[C@@H]2CCCC[C@@H]2O)F'] + + scores = scoring(input_seqs=smiles) + print(scores) + print(len(scores)) + +if __name__ == '__main__': + unittest() \ No newline at end of file diff --git a/a2d2_mol/mol_utils/bracket_safe_converter.py b/a2d2_mol/mol_utils/bracket_safe_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..4ccc9c5d0fdf8ba3ddd55dbb60838e74acacb452 --- /dev/null +++ b/a2d2_mol/mol_utils/bracket_safe_converter.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Patch: stub out `auto_docstring` if missing from transformers.utils +# (needed by safe.trainer.model in newer safe versions) +import transformers.utils as _tu +if not hasattr(_tu, 'auto_docstring'): + _tu.auto_docstring = lambda *a, **kw: (lambda fn: fn) + +from safe.converter import * + + +class BracketSAFEConverter(SAFEConverter): + def encoder( + self, + inp: Union[str, dm.Mol], + canonical: bool = True, + randomize: Optional[bool] = False, + seed: Optional[int] = None, + constraints: Optional[List[dm.Mol]] = None, + allow_empty: bool = False, + rdkit_safe: bool = True, + ): + rng = None + if randomize: + rng = np.random.default_rng(seed) + if not canonical: + inp = dm.to_mol(inp, remove_hs=False) + inp = self.randomize(inp, rng) + + if isinstance(inp, dm.Mol): + inp = dm.to_smiles(inp, canonical=canonical, randomize=False, ordered=False) + + branch_numbers = self._find_branch_number(inp) + + mol = dm.to_mol(inp, remove_hs=False) + if self.ignore_stereo: + mol = dm.remove_stereochemistry(mol) + + bond_map_id = 1 + for atom in mol.GetAtoms(): + if atom.GetAtomicNum() == 0: + atom.SetAtomMapNum(0) + atom.SetIsotope(bond_map_id) + bond_map_id += 1 + + if self.require_hs: + mol = dm.add_hs(mol) + matching_bonds = self._fragment(mol, allow_empty=allow_empty) + substructed_ignored = [] + if constraints is not None: + substructed_ignored = list( + itertools.chain( + *[ + mol.GetSubstructMatches(constraint, uniquify=True) + for constraint in constraints + ] + ) + ) + + bonds = [] + for i_a, i_b in matching_bonds: + # if both atoms of the bond are found in a disallowed substructure, we cannot consider them + # on the other end, a bond between two substructure to preserved independently is perfectly fine + if any((i_a in ignore_x and i_b in ignore_x) for ignore_x in substructed_ignored): + continue + obond = mol.GetBondBetweenAtoms(i_a, i_b) + bonds.append(obond.GetIdx()) + + if len(bonds) > 0: + mol = Chem.FragmentOnBonds( + mol, + bonds, + dummyLabels=[(i + 
bond_map_id, i + bond_map_id) for i in range(len(bonds))], + ) + + frags = list(Chem.GetMolFrags(mol, asMols=True)) + if randomize: + frags = rng.permutation(frags).tolist() + elif canonical: + frags = sorted( + frags, + key=lambda x: x.GetNumAtoms(), + reverse=True, + ) + + frags_str = [] + for frag in frags: + non_map_atom_idxs = [ + atom.GetIdx() for atom in frag.GetAtoms() if atom.GetAtomicNum() != 0 + ] + frags_str.append( + Chem.MolToSmiles( + frag, + isomericSmiles=True, + canonical=True, # needs to always be true + rootedAtAtom=non_map_atom_idxs[0], + ) + ) + + scaffold_str = ".".join(frags_str) + + # don't capture atom mapping in the scaffold + attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str)) + if canonical: + attach_pos = sorted(attach_pos) + starting_num = 1 + for attach in attach_pos: + val = str(starting_num) if starting_num < 10 else f"%{starting_num}" + val = '<' + val + '>' # bracket added + # we cannot have anything of the form "\([@=-#-$/\]*\d+\)" + attach_regexp = re.compile(r"(" + re.escape(attach) + r")") + scaffold_str = attach_regexp.sub(val, scaffold_str) + starting_num += 1 + + # now we need to remove all the parenthesis around digit only number + wrong_attach = re.compile(r"\((<[\%\d+]*>)\)") # bracket added + scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str) + # furthermore, we autoapply rdkit-compatible digit standardization. 
def bracketsafe2safe(safe_str):
    """Convert a bracket-SAFE string back to plain SAFE notation.

    Bracket-SAFE writes inter-fragment attachment points as ``<n>``. Plain
    SAFE encodes them as ordinary ring-closure labels, so every ``<n>`` is
    renumbered to ``n + offset`` where ``offset`` is one more than the largest
    label already used inside the fragments (0 when there is none). Labels of
    10 or more are emitted in the ``%NN`` form. Any space characters in the
    input are removed from the result.

    Args:
        safe_str: Bracket-SAFE string.

    Returns:
        The string with all ``<n>`` markers replaced by ring-closure labels.
    """
    # Labels already taken inside fragments: single digits that are not part
    # of a ``<n>`` marker, plus two-digit ``%NN`` closures. The look-arounds
    # keep the digits of attachment markers out of this list.
    # NOTE(review): the pattern in the committed text was reduced to '(?)',
    # which is not a valid regex; the look-around form below is a
    # reconstruction - confirm against the intended original.
    intrafrag_points = [m.group(0) for m in re.finditer(r'(?<!<)\d(?!>)', safe_str)] + \
                       [m.group(0).lstrip('%') for m in re.finditer(r'%\d+', safe_str)]
    starting_num = max(int(i) for i in intrafrag_points) + 1 if intrafrag_points else 0

    def _renumber(match):
        # Shift the attachment index past every label used inside fragments.
        label = int(match.group(1)) + starting_num
        return '%' + str(label) if label >= 10 else str(label)

    return re.sub(r'<(\d+)>', _renumber, safe_str).replace(' ', '')
def sample_categorical_logits(logits, dtype=torch.float64):
    """Draw one category index per row of ``logits`` with the Gumbel-max trick.

    The logits do not have to be log-softmaxed: the argmax of
    ``logits + gumbel`` is unaffected by a per-row constant shift.

    Args:
        logits: Tensor whose last dimension enumerates the categories.
        dtype: Floating dtype in which the uniform noise is drawn.

    Returns:
        Tensor of sampled indices with the last dimension reduced away.
    """
    uniform = torch.rand_like(logits, dtype=dtype)
    # Both 1e-10 offsets keep the two nested logs away from log(0).
    inner = (uniform + 1e-10).log()
    gumbel = -(1e-10 - inner).log()
    perturbed = logits + gumbel
    return perturbed.argmax(dim=-1)
def str2bool(v):
    """Parse a command-line style boolean.

    Intended as an ``argparse`` ``type=`` callable. ``bool`` inputs are
    returned unchanged; strings are matched case-insensitively (ignoring
    surrounding whitespace) against the accepted spellings.

    Args:
        v: A ``bool`` or a string such as ``"yes"``, ``"true"``, ``"0"``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is neither a ``bool`` nor one of
            the accepted strings.
    """
    if isinstance(v, bool):
        return v
    # Non-string input has no .lower(); report it as a bad argument rather
    # than leaking an AttributeError to the caller.
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError('Boolean value expected.')
    token = v.strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def _safe_to_smiles_worker(args):
    """Convert one SAFE string to SMILES; worker for ``executor.map``.

    Args:
        args: Tuple ``(safe_str, use_bracket_safe, fix)``. When
            ``use_bracket_safe`` is true the string is first converted from
            bracket-SAFE with ``bracketsafe2safe``; ``fix`` is forwarded to
            ``safe_to_smiles``.

    Returns:
        The value returned by ``safe_to_smiles``, or ``None`` if the
        conversion raised.
    """
    safe_str, use_bracket_safe, fix = args
    try:
        if use_bracket_safe:
            # Import only when the converter is actually needed, so a broken
            # converter import cannot turn every plain-SAFE result into None.
            from mol_utils.bracket_safe_converter import bracketsafe2safe
            safe_str = bracketsafe2safe(safe_str)
        return safe_to_smiles(safe_str, fix=fix)
    except Exception:
        # Best effort by design: one bad string must not abort the batch.
        return None
def batch_validate_and_extract(smiles_list, samples_tensor, log_rnd_tensor):
    """Keep the valid entries of ``smiles_list`` and reduce each to one fragment.

    An entry counts as valid when it is truthy (neither ``None`` nor ``""``).
    For a valid entry the ``.``-separated fragment with the most characters is
    kept; if several fragments tie, the last one in the string is chosen.

    Args:
        smiles_list: List of SMILES strings (may contain None for invalid).
        samples_tensor: Accepted for interface compatibility; not read here.
        log_rnd_tensor: Accepted for interface compatibility; not read here.

    Returns:
        valid_sequences: Longest fragment of every valid entry.
        valid_indices: Positions of those entries in ``smiles_list``.
    """
    kept = [(pos, smi) for pos, smi in enumerate(smiles_list) if smi]
    valid_indices = [pos for pos, _ in kept]
    # max() returns the first maximum, so scan right-to-left to keep the
    # "last fragment wins a tie" behaviour.
    valid_sequences = [max(reversed(smi.split('.')), key=len) for _, smi in kept]
    return valid_sequences, valid_indices
class Slicer:
    """Callable that enumerates the non-ring single bonds of a molecule."""

    def __call__(self, mol):
        """Yield one match per non-ring single bond.

        Args:
            mol: An RDKit molecule, or a SMILES string that is parsed with
                ``Chem.MolFromSmiles`` first.

        Yields:
            Each match returned by ``GetSubstructMatches`` for the SMARTS
            ``[*]-;!@[*]``.
        """
        if isinstance(mol, str):
            mol = Chem.MolFromSmiles(mol)

        # non-ring single bonds
        single_bond_query = Chem.MolFromSmarts('[*]-;!@[*]')
        yield from mol.GetSubstructMatches(single_bond_query)
def apply_schedule_aware_insertion(
    model,
    xt_tmp,
    new_xt,
    t,
    dt,
    ext,
    mask,
    pad,
    max_length,
    orig_mask,
    new_pos_orig,
    quality_threshold=1,
):
    """
    Remove low-confidence freshly inserted mask tokens, capped by a minimum
    length derived from the interpolant's insertion schedule.

    A position is a removal candidate when it holds ``mask``, was not carried
    over from the pre-insertion sequence, and its planner ``insertion_conf`` is
    below ``quality_threshold``. Candidates are removed lowest-confidence
    first, the row is compacted to the left, and the freed tail is set to
    ``pad``.

    Args:
        model: Object exposing ``planner(xt, t)`` (returning a mapping that may
            contain ``"insertion_conf"``) and
            ``interpolant.insertion_schedule.at(t)``
        xt_tmp: Sequence after insertion [B, L]
        new_xt: Sequence before insertion [B, L] (not read by this function)
        t: Current time [B]
        dt: Time step size
        ext: Number of insertions per gap [B, L+1]
        mask: Mask token ID
        pad: Pad token ID
        max_length: Maximum sequence length (not read by this function)
        orig_mask: Mask of original token positions [B, L]
        new_pos_orig: New positions of original tokens [B, L]
        quality_threshold: Insertions with confidence below this value are
            removal candidates

    Returns:
        xt_tmp: The input unchanged when nothing was inserted, the planner
            gives no ``insertion_conf``, or nothing is selected for removal;
            otherwise a new tensor with the selected insertions removed
    """
    device = xt_tmp.device
    batch_size, L = xt_tmp.shape
    total_ext = ext.sum(dim=1)

    # Only proceed if there were insertions
    if total_ext.sum() == 0:
        return xt_tmp

    # Get planner predictions on inserted state. The insertion head is trained
    # with the pre-step time t (see loss_insert_planner_flexible), so condition
    # on t here too; t_next is still used below for the length schedule.
    t_next = t + dt
    planner_out = model.planner(xt_tmp, t)
    insertion_conf = planner_out.get("insertion_conf", None)

    if insertion_conf is None:
        return xt_tmp

    insertion_conf = insertion_conf.squeeze(-1)  # (B, L)

    # Expected sequence length at next timestep according to schedule
    # NOTE(review): estimated_final_length is derived from the current length
    # and then multiplied by the same expected_progress, so whenever the
    # clamp(min=0.1) is inactive expected_length equals current_length_after
    # up to float rounding. min_length is then (almost) the current length and
    # max_removable is ~0, i.e. removal is effectively disabled unless
    # expected_progress < 0.1. Confirm whether the estimate was meant to use
    # the pre-insertion length and/or the schedule at t rather than t_next.
    current_length_after = xt_tmp.ne(pad).sum(dim=1).float()  # [B]
    expected_progress = model.interpolant.insertion_schedule.at(t_next)  # [B]
    estimated_final_length = current_length_after / (expected_progress.clamp(min=0.1))
    expected_length = estimated_final_length * expected_progress  # [B]

    # Mark positions in xt_tmp that came from new_xt (originals) vs. fresh insertions.
    # Fancy-indexing scatter avoids the per-batch python loop.
    valid_b, valid_l = orig_mask.nonzero(as_tuple=True)
    valid_p = new_pos_orig[valid_b, valid_l].long().clamp_(0, L - 1)
    is_original = torch.zeros_like(xt_tmp, dtype=torch.bool)
    is_original[valid_b, valid_p] = True
    inserted_positions = (xt_tmp == mask) & ~is_original

    # Single deletion mode: drop insertions whose confidence is below
    # `quality_threshold`, capped per row so the length never falls below the
    # scheduled minimum (at least 1).
    candidates = inserted_positions & (insertion_conf < quality_threshold)
    num_bad = candidates.sum(dim=1)  # [B], long
    min_length = expected_length.long().clamp(min=1)  # [B]
    max_removable = (current_length_after.long() - min_length).clamp(min=0)
    length_after_removal = current_length_after.long() - num_bad
    schedule_violates = length_after_removal < min_length
    # Remove every candidate unless that would undershoot min_length, in which
    # case remove only as many as the schedule allows.
    k_per_row = torch.where(schedule_violates, max_removable, num_bad)
    k_per_row = torch.where(num_bad > 0, k_per_row, torch.zeros_like(k_per_row))

    if not candidates.any():
        return xt_tmp

    # Select the lowest-confidence candidates per row via a sort.
    neg_inf = torch.tensor(float('-inf'), device=device, dtype=insertion_conf.dtype)
    scores = torch.where(candidates, -insertion_conf, neg_inf)  # higher = worse
    _, sorted_indices = scores.sort(dim=1, descending=True)
    positions = torch.arange(L, device=device).unsqueeze(0)  # [1, L]
    # Rank r in the sorted order is removed iff r < k_per_row; scatter maps
    # that decision back to the original column of each ranked entry.
    keep_in_topk = positions < k_per_row.unsqueeze(1)  # [B, L]
    final_bad = torch.zeros_like(candidates)
    final_bad.scatter_(1, sorted_indices, keep_in_topk)

    if not final_bad.any():
        return xt_tmp

    # Compact each row to the left (keep good, drop bad), then pad the tail.
    # Stable sort by the bad flag pushes bad positions to the right.
    sort_key = final_bad.long()
    _, perm = torch.sort(sort_key, dim=1, stable=True)
    xt_tmp = torch.gather(xt_tmp, 1, perm)
    num_keep = (~final_bad).sum(dim=1)  # [B]
    tail_mask = positions >= num_keep.unsqueeze(1)  # [B, L]
    xt_tmp = torch.where(tail_mask, torch.full_like(xt_tmp, pad), xt_tmp)

    return xt_tmp
+ if unmask_quality_threshold is not None: + to_mask = clean_index & (remasking_conf < unmask_quality_threshold) + return torch.where(to_mask, torch.full_like(new_xt, mask), new_xt) + + t_next = t + dt + num_clean = clean_index.sum(dim=1) # [B], long + current_seq_len = (num_clean + (new_xt == mask).sum(dim=1)).float() # [B] + expected_unmasked_frac = model.interpolant.unmask_schedule.at(t_next) # [B] + expected_num_clean = expected_unmasked_frac * current_seq_len # [B] + masks_to_add = (num_clean.float() - expected_num_clean).round().long() # [B] + + # Per-row k = min(masks_to_add, num_clean), clamped to >= 0. + k_per_row = torch.minimum(masks_to_add.clamp(min=0), num_clean) # [B] + + if k_per_row.sum() == 0: + return new_xt + + # Use confidence to decide which clean tokens to remask: lowest conf first. + remasking_score_temp = -1.0 * remasking_conf # low conf = high score + remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf) + + _, sorted_indices = remasking_score_temp.sort(dim=1, descending=True) + L = remasking_score_temp.shape[1] + positions = torch.arange(L, device=new_xt.device).unsqueeze(0) # [1, L] + keep_in_topk = positions < k_per_row.unsqueeze(1) # [B, L] + to_mask = torch.zeros_like(clean_index) + to_mask.scatter_(1, sorted_indices, keep_in_topk) + new_xt = torch.where(to_mask, torch.full_like(new_xt, mask), new_xt) + + return new_xt diff --git a/a2d2_mol/sampling.py b/a2d2_mol/sampling.py new file mode 100755 index 0000000000000000000000000000000000000000..2bdd2eee730434849c946259f801ad527fd15723 --- /dev/null +++ b/a2d2_mol/sampling.py @@ -0,0 +1,1401 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # add repo root to path + +import torch +from dataclasses import dataclass +from typing import Any, Literal, Optional +import numpy as np +import pandas as pd + +from lightning_modules.mdm import MaskedDiffusionModule + + +@dataclass +class SamplingTraceDatapoint: + t: float 
def _sample_tokens(probs: torch.Tensor) -> torch.Tensor:
    """Sample one token per position from a categorical distribution.

    Rows do not have to be normalised; ``torch.multinomial`` treats them as
    non-negative weights.

    Args:
        probs: [batch_size, seq_len, vocab_size] transition probabilities.

    Returns:
        [batch_size, seq_len] sampled token indices.
    """
    batch_size, seq_len, vocab_size = probs.shape
    # reshape (not view) so permuted or sliced, non-contiguous inputs work too.
    flat_probs = probs.reshape(-1, vocab_size)
    samples = torch.multinomial(flat_probs, num_samples=1)
    return samples.view(batch_size, seq_len)
@torch.no_grad()
def any_order_mask_insertion_euler_sampling(
    model: torch.nn.Module,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, Optional[list]]:
    """Sample sequences with Euler steps of unmasking and mask insertion.

    Starts from an all-``pad`` batch and takes ``steps`` uniform steps of size
    ``1 / steps``. Each step samples replacements for ``mask`` tokens from the
    predicted unmask rates and, except on the final step, inserts new ``mask``
    tokens into the gaps according to the predicted length rates. On the final
    step the ``mask`` token is removed from the sampling distribution.

    Args:
        model: Callable as ``model(xt, t)``; must expose ``parameters()`` and
            ``interpolant.to_actual_rate(xt, pred, t)`` whose result carries
            ``unmask_rate`` (indexed here as (B, L, V)) and ``length_rate``
            (indexed here as (B, L + 1)).
        steps: Number of Euler steps.
        mask: Mask token id.
        pad: Pad token id.
        batch_size: Number of sequences to sample.
        max_length: Sequence buffer length L.
        return_trace: If true, record per-sequence change/insertion events.
        temperature: Accepted for interface compatibility; not used here.

    Returns:
        ``(xt, trace)`` where ``xt`` is the (B, L) token tensor and ``trace``
        is a per-sequence list of ``SamplingTraceDatapoint`` or ``None``.
    """
    device = next(model.parameters()).device

    # Initialize the all-pad sequence
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Precompute row indices for scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        last_step = i == steps - 1

        # --- predict and convert rates ---
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # --- unmask step (Euler) ---
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        # NOTE(review): the clamp runs before the stay mass is added, so the
        # negative diagonal becomes 0 and the stay weight is 1 rather than
        # 1 - sum(rate * dt). Rows are therefore unnormalised weights, which
        # torch.multinomial rescales. Confirm this is intended.
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" probability (pad positions are parked on the mask index)
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        if last_step:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # no mask token may survive the last step

            # Renormalise every masked row. Rows left with zero mass fall back
            # to a uniform distribution over token ids below `mask`, excluding
            # `pad`. The fallback is a vector over the vocabulary axis so it
            # broadcasts against the (K, V) block of masked rows.
            rows = trans_prob[mask_pos]  # (K, V)
            row_sum = rows.sum(dim=-1, keepdim=True)
            fallback = torch.zeros(rows.size(-1), dtype=rows.dtype, device=rows.device)
            fallback[:mask] = 1.0
            if 0 <= pad < mask:
                fallback[pad] = 0.0
            fallback = fallback / fallback.sum().clamp(min=1.0)
            safe_sum = row_sum.clamp(min=torch.finfo(rows.dtype).tiny)
            trans_prob[mask_pos] = torch.where(row_sum > 0, rows / safe_sum, fallback)

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        ext = None  # no insertions are drawn on the final step
        if not last_step:
            # --- gap-wise insertion: compute new length, fill masks, scatter tokens ---
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # only gaps that border the current sequence may receive a token
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            valid = xt_len + ext.sum(dim=1) <= max_length
            # rows that would overflow the buffer get no insertions at all
            ext = ext * valid.view(batch_size, 1).long()
            # Count insertions after cancellation; otherwise cancelled rows
            # would still be extended with trailing mask tokens.
            total_ext = ext.sum(dim=1)

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # original token j moves right by the insertions in gaps 0..j
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            xt_tmp = new_xt

        if return_trace:
            for batch_idx in range(batch_size):
                # Check if the token was changed
                for j in range(max_length):
                    if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=t[batch_idx].item(),
                                event_type="change",
                                position=j,
                                token=new_xt[batch_idx, j].item(),
                            )
                        )

                # Check if a new token was inserted (right-to-left so earlier
                # insertions do not shift the positions of later ones)
                if ext is None:
                    continue
                for gap in range(max_length - 1, -1, -1):
                    if ext[batch_idx, gap]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=t[batch_idx].item(),
                                event_type="insertion",
                                position=gap,
                                token=mask,
                            )
                        )

        xt = xt_tmp
        t = t + dt

    return xt, sampling_trace
——— predict and convert rates ——— + pred_rate = model(xt, t) + pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t) + unmask_rate = pred_rate.unmask_rate # (B, L, V) + len_rate = pred_rate.length_rate # (B, L+1) + + # ——— get pretrained model rates for log_rnd computation ——— + pretrained_pred = pretrained(xt, t) + pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t) + pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V) + pretrained_len_rate = pretrained_rate.length_rate # (B, L+1) + + # ——— unmask step (Euler) ——— + mask_pos = (xt == mask).nonzero(as_tuple=True) + unmask_rate[xt != mask] = 0 + unmask_rate[mask_pos + (mask,)] = 0 + unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) + trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) + + # Same for pretrained + pretrained_unmask_rate[xt != mask] = 0 + pretrained_unmask_rate[mask_pos + (mask,)] = 0 + pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1) + pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0) + + # add “stay” probability + _xt = xt.clone() + _xt[xt == pad] = mask + trans_prob.scatter_add_( + 2, + _xt.unsqueeze(-1), + torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), + ) + pretrained_trans_prob.scatter_add_( + 2, + _xt.unsqueeze(-1), + torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype), + ) + + if last_step: + print("Final step, removing mask token from sampling") + trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step + + # renormalize probabilities to ensure they sum to 1 + prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) + # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad) + mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) + if mask_has_zero_prob.any(): + # create uniform distribution over valid tokens (excluding mask and pad) 
+ uniform_prob = torch.zeros_like(trans_prob[0]) + uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1 + trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob + else: + # normalize to sum to 1 + trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum + + new_xt = _sample_tokens(trans_prob) + new_xt[xt == pad] = pad + new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) + + # ——— compute log probabilities for RND ——— + lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1) + lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1) + + changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad) + + log_policy_step = (lp * changed_mask).sum(dim=1) + log_pretrained_step = (lp_pre * changed_mask).sum(dim=1) + + log_rnd = log_pretrained_step - log_policy_step # (B,) + + if not last_step: + # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ——— + ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1) + + insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1) + pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1) + + # log P(ext; λ) = ext*log(λ) - λ + log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1) # (B,) + log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1) # (B,) + + log_insert_diff = log_pretrained_insert - log_policy_insert # (B,) + log_rnd += log_insert_diff + log_pretrained_step += log_pretrained_insert + log_policy_step += log_policy_insert + + xt_len = xt.ne(pad).sum(dim=1) # (B,) + seq_dim = ext.size(1) # Use actual ext dimension to avoid mismatch + gaps = torch.arange(seq_dim, device=device).view(1, -1) + ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() + total_ext = ext.sum(dim=1) + valid = xt_len + total_ext <= max_length + ext = ext * 
@torch.no_grad()
def mcts_reverse_step(
    xt: torch.Tensor,
    t: torch.Tensor,
    dt: float,
    model: torch.nn.Module,
    pretrained: torch.nn.Module,
    mask: int,
    pad: int,
    max_length: int,
    last_step: bool = False,
    temperature: float = 1.0,
) -> SamplingResult:
    """Take one reverse Euler step (unmask, then insert) and score it.

    The step is sampled from ``model`` (the policy) and its log-probability
    is evaluated under both ``model`` and ``pretrained`` so the caller can
    form a per-step log density ratio.

    Args:
        xt: Current token ids, one row per sequence, right-padded with
            ``pad``. The scatter logic below requires ``xt.size(1)`` to equal
            ``max_length``.
        t: Time value(s) forwarded unchanged to both models.
        dt: Euler step size.
        model: Policy network; must expose ``interpolant.to_actual_rate``.
        pretrained: Reference network with the same interface.
        mask: Id of the mask token.
        pad: Id of the padding token.
        max_length: Maximum sequence length (width of ``xt``).
        last_step: When true, the mask token is removed from the sampling
            distribution and no insertion is performed.
        temperature: Accepted but not used by this function.

    Returns:
        A 4-tuple ``(xt_next, log_rnd, log_policy_step, log_pretrained_step)``
        where the last three are per-sequence sums over the unmasking events
        of this step plus (when ``last_step`` is false) the insertion term.
        NOTE(review): the signature is annotated ``SamplingResult`` but a plain
        4-tuple is returned -- confirm which one callers rely on.
    """
    device = next(model.parameters()).device
    batch_size = xt.size(0)

    # Index grids used to scatter the surviving tokens after insertion.
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )

    # Policy rates.
    pred_rate = model(xt, t)
    pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
    unmask_rate = pred_rate.unmask_rate  # (B, L, V)
    len_rate = pred_rate.length_rate  # (B, L+1)

    # Reference rates, needed only for the log-ratio.
    pretrained_pred = pretrained(xt, t)
    pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
    pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
    pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

    mask_pos = (xt == mask).nonzero(as_tuple=True)

    # "Stay" target per position: the current token, with pads mapped to mask.
    stay_idx = xt.clone()
    stay_idx[xt == pad] = mask
    stay_idx = stay_idx.unsqueeze(-1)

    def _euler_transition(rate):
        """Turn a rate tensor (edited in place) into one-step transition weights."""
        rate[xt != mask] = 0
        rate[mask_pos + (mask,)] = 0
        rate[mask_pos + (mask,)] = -rate[mask_pos + (slice(None),)].sum(dim=1)
        probs = (rate * dt).clamp(0.0, 1.0)
        probs.scatter_add_(2, stay_idx, torch.ones_like(stay_idx, dtype=probs.dtype))
        return probs

    trans_prob = _euler_transition(unmask_rate)
    pretrained_trans_prob = _euler_transition(pretrained_unmask_rate)

    if last_step:
        print("Final step, removing mask token from sampling")
        trans_prob[mask_pos + (mask,)] = 0.0  # nothing may stay masked at the end

        # Renormalise every masked row on its own.  The previous code skipped
        # normalisation for *all* rows as soon as one row summed to zero, and
        # its fallback tensor was shaped (L, V) instead of (V,).
        rows = trans_prob[mask_pos]  # (K, V)
        row_sum = rows.sum(dim=-1, keepdim=True)  # (K, 1)
        degenerate = row_sum.squeeze(-1) == 0.0
        rows = rows / torch.where(row_sum == 0.0, torch.ones_like(row_sum), row_sum)
        if degenerate.any():
            # Uniform over token ids 0..mask-1, excluding pad.
            fallback = trans_prob.new_zeros(trans_prob.size(-1))
            fallback[:mask] = 1.0
            if 0 <= pad < mask:
                fallback[pad] = 0.0
            fallback = fallback / fallback.sum().clamp(min=1.0)
            rows[degenerate] = fallback
        trans_prob[mask_pos] = rows

    new_xt = _sample_tokens(trans_prob)
    new_xt[xt == pad] = pad
    # Only masked positions are allowed to change.
    new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

    # Log-probability of the sampled tokens under both models.
    token_idx = new_xt.unsqueeze(-1)
    lp = torch.gather(torch.log(trans_prob + 1e-10), 2, token_idx).squeeze(-1)
    lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, token_idx).squeeze(-1)

    # Count only positions that actually left the mask state.
    changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)
    log_policy_step = (lp * changed_mask).sum(dim=1)
    log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
    log_rnd = log_pretrained_step - log_policy_step  # (B,)

    if last_step:
        return new_xt, log_rnd, log_policy_step, log_pretrained_step

    # Gap-wise insertion: one Bernoulli draw per gap.
    ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)

    insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
    pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

    # log P(ext; lambda) = ext * log(lambda) - lambda
    log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1)
    log_pretrained_insert = (
        ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate
    ).sum(dim=1)

    log_rnd += log_pretrained_insert - log_policy_insert
    log_pretrained_step += log_pretrained_insert
    log_policy_step += log_policy_insert

    xt_len = xt.ne(pad).sum(dim=1)  # (B,)
    gaps = torch.arange(ext.size(1), device=device).view(1, -1)
    # Gaps beyond the current end of the sequence do not exist.
    ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
    # Reject the whole insertion for rows that would overflow max_length.
    fits = xt_len + ext.sum(dim=1) <= max_length
    ext = ext * fits.view(batch_size, 1).long()

    ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
    # Use the count *after* rejection; the pre-rejection count made rejected
    # rows grow to max_length anyway.
    new_len = xt_len + ext.sum(dim=1)  # (B,)

    # Start from masks up to the new length, then drop the old tokens back in
    # at their shifted positions.
    xt_tmp = torch.full_like(xt, pad)
    xt_tmp[pos_idx_L < new_len.view(batch_size, 1)] = mask

    new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
    orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
    xt_tmp[batch_idx_L[orig_mask], new_pos_orig[orig_mask]] = new_xt[orig_mask]

    return xt_tmp, log_rnd, log_policy_step, log_pretrained_step
+) -> SamplingResult: + device = next(model.parameters()).device + + time_schedule = time_schedule.to(device) + if time_schedule[0] < time_schedule[-1]: + time_schedule = torch.flip(time_schedule, [0]) # descending order + + steps = len(time_schedule) - 1 + + # initialize all-pad sequence and trace + xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device) + + # precompute row indices for scatter + batch_idx_L = ( + torch.arange(batch_size, device=device) + .view(batch_size, 1) + .expand(batch_size, max_length) + ) + pos_idx_L = ( + torch.arange(max_length, device=device) + .view(1, max_length) + .expand(batch_size, max_length) + ) + sampling_trace = [[] for _ in range(batch_size)] if return_trace else None + + for i in range(steps): + # use scheduled timesteps + t = time_schedule[i].repeat(batch_size) + t_next = time_schedule[i + 1] + dt = (t - t_next).abs() # timestep difference + + # ——— predict and convert rates ——— + pred_rate = model(xt, t) + pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t) + unmask_rate = pred_rate.unmask_rate # (B, L, V) + len_rate = pred_rate.length_rate # (B, L+1) + + # ——— unmask step (Euler) ——— + mask_pos = (xt == mask).nonzero(as_tuple=True) + unmask_rate[xt != mask] = 0 + unmask_rate[mask_pos + (mask,)] = 0 + unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) + trans_prob = (unmask_rate * dt[:, None, None]).clamp(0.0, 1.0) + + # add "stay" probability + _xt = xt.clone() + _xt[xt == pad] = mask + trans_prob.scatter_add_( + 2, + _xt.unsqueeze(-1), + torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), + ) + + # Apply temperature scaling + if temperature != 1.0: + logits = torch.log(trans_prob + 1e-10) / temperature + trans_prob = torch.softmax(logits, dim=-1) + + if i == steps - 1: + print("Final step, removing mask token from sampling") + trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step + + prob_sum = 
trans_prob[mask_pos].sum(dim=-1, keepdim=True) + mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) + + if mask_has_zero_prob.any(): + uniform_prob = torch.zeros_like(trans_prob[0]) + uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1 + trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob + else: + trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum + + new_xt = _sample_tokens(trans_prob) + new_xt[xt == pad] = pad + new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) + + if i != steps - 1: + # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ——— + ext = torch.bernoulli((len_rate * dt[:, None]).clamp(0.0, 1.0)).long() # (B, L+1) + xt_len = xt.ne(pad).sum(dim=1) # (B,) + gaps = torch.arange(max_length + 1, device=device).view(1, -1) + ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() + total_ext = ext.sum(dim=1) + valid = xt_len + total_ext <= max_length + ext = ext * valid.view(batch_size, 1).long() + + ext_ex = ext.int().cumsum(dim=1) # (B, L+1) + new_len = xt_len + total_ext # (B,) + + xt_tmp = torch.full_like(xt, pad) + mask_fill = pos_idx_L < new_len.view(batch_size, 1) + xt_tmp[mask_fill] = mask + + new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) + orig_mask = pos_idx_L < xt_len.view(batch_size, 1) + flat_b = batch_idx_L[orig_mask] + flat_p = new_pos_orig[orig_mask] + xt_tmp[flat_b, flat_p] = new_xt[orig_mask] + else: + xt_tmp = new_xt + + if return_trace: + # Check if the token was changed + for batch_idx in range(batch_size): + for j in range(max_length): + if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]: + sampling_trace[batch_idx].append( + SamplingTraceDatapoint( + t=t[batch_idx].item(), + event_type="change", + position=j, + token=new_xt[batch_idx, j].item(), + ) + ) + + # Check if a new token was inserted + for j in range(max_length): + id = max_length - j - 1 + if ext[batch_idx, id]: + 
@torch.no_grad()
def any_order_mask_insertion_euler_sampling_with_rnd(
    model, pretrained, reward_model, analyzer,
    tokenizer, steps,
    mask,
    pad,
    batch_size,
    max_length,
    return_trace=False,
    alpha=0.1,
    temperature: float = 1.0,
):
    """Sample sequences with the policy and accumulate a log density ratio.

    Starting from an all-``pad`` batch, runs ``steps`` Euler steps of
    unmasking and gap-wise mask insertion driven by ``model``.  At every step
    the log-probability of the sampled move is evaluated under ``model`` and
    ``pretrained``; the difference (pretrained minus policy) is summed over
    the whole trajectory.  Finished samples are decoded, filtered with
    ``analyzer.is_peptide`` and scored with ``reward_model``.

    Args:
        model: Policy network exposing ``interpolant.to_actual_rate``.
        pretrained: Reference network with the same interface.
        reward_model: Callable invoked as ``reward_model(input_seqs=[...])``;
            its result is summed over the last axis with ``np.sum``.
        analyzer: Object providing ``is_peptide(str) -> bool``.
        tokenizer: Object providing ``batch_decode``.
        steps: Number of Euler steps (``dt = 1 / steps``).
        mask: Id of the mask token.
        pad: Id of the padding token.
        batch_size: Number of sequences to sample.
        max_length: Maximum sequence length.
        return_trace: When true, record per-sequence change/insertion events.
        alpha: Divisor applied to the scalar reward before it is added to the
            accumulated log ratio.
        temperature: Softmax temperature applied to the policy transition
            distribution (``1.0`` disables it).

    Returns:
        ``(valid_x_final, log_rnd, scalar_rewards, sampling_trace)`` restricted
        to the samples accepted by ``analyzer``.  When no sample is accepted
        the three tensors are empty (first dimension 0) and ``reward_model``
        is not called.  ``sampling_trace`` is ``None`` unless ``return_trace``.
    """
    device = next(model.parameters()).device

    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
    # Trajectory-level sum of log p_pretrained - log p_policy.
    log_rnd = torch.zeros(batch_size, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Index grids used to scatter the surviving tokens after insertion.
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for step in range(steps):
        is_last = step == steps - 1

        # Policy rates.
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # Reference rates, needed only for the log-ratio.
        pretrained_pred = pretrained(xt, t)
        pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
        pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
        pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

        mask_pos = (xt == mask).nonzero(as_tuple=True)

        # "Stay" target per position: the current token, pads mapped to mask.
        stay_idx = xt.clone()
        stay_idx[xt == pad] = mask
        stay_idx = stay_idx.unsqueeze(-1)

        def _euler_transition(rate):
            """Turn a rate tensor (edited in place) into one-step weights."""
            rate[xt != mask] = 0
            rate[mask_pos + (mask,)] = 0
            rate[mask_pos + (mask,)] = -rate[mask_pos + (slice(None),)].sum(dim=1)
            probs = (rate * dt).clamp(0.0, 1.0)
            probs.scatter_add_(2, stay_idx, torch.ones_like(stay_idx, dtype=probs.dtype))
            return probs

        trans_prob = _euler_transition(unmask_rate)
        pretrained_trans_prob = _euler_transition(pretrained_unmask_rate)

        # Temperature is applied to the policy distribution only.
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if is_last:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # nothing may stay masked

            # Renormalise every masked row on its own.  The previous code
            # skipped normalisation for all rows as soon as one row summed to
            # zero, and its fallback tensor was shaped (L, V) instead of (V,).
            rows = trans_prob[mask_pos]  # (K, V)
            row_sum = rows.sum(dim=-1, keepdim=True)  # (K, 1)
            degenerate = row_sum.squeeze(-1) == 0.0
            rows = rows / torch.where(row_sum == 0.0, torch.ones_like(row_sum), row_sum)
            if degenerate.any():
                # Uniform over token ids 0..mask-1, excluding pad.
                fallback = trans_prob.new_zeros(trans_prob.size(-1))
                fallback[:mask] = 1.0
                if 0 <= pad < mask:
                    fallback[pad] = 0.0
                fallback = fallback / fallback.sum().clamp(min=1.0)
                rows[degenerate] = fallback
            trans_prob[mask_pos] = rows

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        # Only masked positions are allowed to change.
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # Log-probability of the sampled tokens under both models.
        token_idx = new_xt.unsqueeze(-1)
        lp = torch.gather(torch.log(trans_prob + 1e-10), 2, token_idx).squeeze(-1)
        lp_pre = torch.gather(
            torch.log(pretrained_trans_prob + 1e-10), 2, token_idx
        ).squeeze(-1)

        changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)
        log_policy_step = (lp * changed_mask).sum(dim=1)
        log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

        # Accumulate: the old code assigned here, so every step discarded the
        # ratio collected so far and only the final step survived.
        log_rnd = log_rnd + (log_pretrained_step - log_policy_step)  # (B,)

        ext = None  # no insertion happens on the last step
        if not is_last:
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)

            insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
            pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)

            # log P(ext; lambda) = ext * log(lambda) - lambda
            log_policy_insert = (
                ext * torch.log(insertion_rate) - insertion_rate
            ).sum(dim=1)
            log_pretrained_insert = (
                ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate
            ).sum(dim=1)
            log_rnd = log_rnd + (log_pretrained_insert - log_policy_insert)

            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # Gaps beyond the current end of the sequence do not exist.
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # Reject the whole insertion for rows that would overflow.
            fits = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * fits.view(batch_size, 1).long()

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            # Use the count *after* rejection; the pre-rejection count made
            # rejected rows grow to max_length anyway.
            new_len = xt_len + ext.sum(dim=1)  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            xt_tmp[pos_idx_L < new_len.view(batch_size, 1)] = mask

            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            xt_tmp[batch_idx_L[orig_mask], new_pos_orig[orig_mask]] = new_xt[orig_mask]
        else:
            xt_tmp = new_xt

        if return_trace:
            changed = (xt != pad) & (xt != new_xt)
            for batch_idx in range(batch_size):
                t_now = t[batch_idx].item()
                for pos in changed[batch_idx].nonzero(as_tuple=True)[0].tolist():
                    sampling_trace[batch_idx].append(
                        SamplingTraceDatapoint(
                            t=t_now,
                            event_type="change",
                            position=pos,
                            token=new_xt[batch_idx, pos].item(),
                        )
                    )
                if ext is None:
                    continue
                # Insertions are reported from the right-most gap to the left.
                hit_gaps = ext[batch_idx, :max_length].nonzero(as_tuple=True)[0].tolist()
                for gap in reversed(hit_gaps):
                    sampling_trace[batch_idx].append(
                        SamplingTraceDatapoint(
                            t=t_now,
                            event_type="insertion",
                            position=gap,
                            token=mask,
                        )
                    )

        xt = xt_tmp
        t = t + dt

    # Decode, keep only valid peptides, then score them.
    decoded_samples = tokenizer.batch_decode(xt)
    keep = [idx for idx, seq in enumerate(decoded_samples) if analyzer.is_peptide(seq)]
    validSequences = [decoded_samples[idx] for idx in keep]

    print("len valid sequences:", len(validSequences))
    if not keep:
        # Nothing to score: return empty tensors instead of crashing in stack().
        empty_rewards = torch.zeros(0, dtype=torch.float32, device=device)
        return xt[:0], log_rnd[:0], empty_rewards, sampling_trace

    # Multi-objective scores collapsed to one scalar per sequence.
    score_vectors = reward_model(input_seqs=validSequences)
    scalar_rewards = np.sum(score_vectors, axis=-1)
    scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device)

    print(f"scalar reward dim{len(scalar_rewards)}")
    keep_idx = torch.as_tensor(keep, dtype=torch.long, device=device)
    valid_x_final = xt[keep_idx]
    log_rnd = log_rnd[keep_idx] + (scalar_rewards / alpha)  # reward scaled by 1/alpha

    return valid_x_final, log_rnd, scalar_rewards, sampling_trace
in range(batch_size)] if return_trace else None + + for i in range(steps): + # ——— predict and convert rates ——— + pred_rate = model(xt, t) + pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t) + unmask_rate = pred_rate.unmask_rate # (B, L, V) + len_rate = pred_rate.length_rate # (B, L+1) + + # ——— unmask step (Euler) ——— + mask_pos = (xt == mask).nonzero(as_tuple=True) + unmask_rate[xt != mask] = 0 + unmask_rate[mask_pos + (mask,)] = 0 + unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) + trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) + + # add “stay” probability + _xt = xt.clone() + _xt[xt == pad] = mask + trans_prob.scatter_add_( + 2, + _xt.unsqueeze(-1), + torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), + ) + + # Apply temperature scaling + if temperature != 1.0: + logits = torch.log(trans_prob + 1e-10) / temperature + trans_prob = torch.softmax(logits, dim=-1) + + if i == steps - 1: + print("Final step, removing mask token from sampling") + trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step + + # renormalize probabilities to ensure they sum to 1 + prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) + # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad) + mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) + if mask_has_zero_prob.any(): + # create uniform distribution over valid tokens (excluding mask and pad) + uniform_prob = torch.zeros_like(trans_prob[0]) + uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1 + trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob + else: + # normalize to sum to 1 + trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum + + new_xt = _sample_tokens(trans_prob) + new_xt[xt == pad] = pad + new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) + + if i != steps - 1: + # gap-wise insertion refactored — compute new length, fill 
masks, scatter tokens + ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1) + xt_len = xt.ne(pad).sum(dim=1) # (B,) + gaps = torch.arange(max_length + 1, device=device).view(1, -1) + ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() + total_ext = ext.sum(dim=1) + valid = xt_len + total_ext <= max_length + ext = ext * valid.view(batch_size, 1).long() + + ext_ex = ext.int().cumsum(dim=1) # (B, L+1) + new_len = xt_len + total_ext # (B,) + + xt_tmp = torch.full_like(xt, pad) + mask_fill = pos_idx_L < new_len.view(batch_size, 1) + xt_tmp[mask_fill] = mask + + new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) + orig_mask = pos_idx_L < xt_len.view(batch_size, 1) + flat_b = batch_idx_L[orig_mask] + flat_p = new_pos_orig[orig_mask] + xt_tmp[flat_b, flat_p] = new_xt[orig_mask] + else: + xt_tmp = new_xt + + if return_trace: + # check if the token was changed + for batch_idx in range(batch_size): + for j in range(max_length): + if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]: + sampling_trace[batch_idx].append( + SamplingTraceDatapoint( + t=t[batch_idx].item(), + event_type="change", + position=j, + token=new_xt[batch_idx, j].item(), + ) + ) + + # check if a new token was inserted + for j in range(max_length): + id = max_length - j - 1 + if ext[batch_idx, id]: + sampling_trace[batch_idx].append( + SamplingTraceDatapoint( + t=t[batch_idx].item(), + event_type="insertion", + position=id, + token=mask, + ) + ) + + xt = xt_tmp + t = t + dt + + # start eval + samples = xt.to(device) + + decoded_samples = tokenizer.batch_decode(samples) + + valid_x_final = [] + validSequences = [] + + for idx, seq in enumerate(decoded_samples): + if analyzer.is_peptide(seq): + valid_x_final.append(samples[idx]) + validSequences.append(seq) + + print("len valid sequences:", len(validSequences)) + valid_fraction = len(validSequences) / batch_size + + if (len(validSequences) != 0): + # add scores to log + score_vectors = 
reward_model(input_seqs=validSequences) # (num_children, num_objectives) + average_scores = score_vectors.T + + affinity = average_scores[0] + sol = average_scores[1] + hemo = average_scores[2] + nf = average_scores[3] + permeability = average_scores[4] + + else: + zeros = [0.0] + + affinity = zeros + sol = zeros + hemo = zeros + nf = zeros + permeability = zeros + + if dataframe: + df = pd.DataFrame({ + "Peptide Sequence": validSequences, + "Binding Affinity": affinity if len(validSequences) else [0.0], + "Solubility": sol if len(validSequences) else [0.0], + "Hemolysis": hemo if len(validSequences) else [0.0], + "Nonfouling": nf if len(validSequences) else [0.0], + "Permeability": permeability if len(validSequences) else [0.0], + }) + return samples, affinity, sol, hemo, nf, permeability, valid_fraction, df + + return samples, affinity, sol, hemo, nf, permeability, valid_fraction + +@torch.no_grad() +def mdm_tau_leaping_sampling( + model: MaskedDiffusionModule, + steps: int, + mask: int, + pad: int, + batch_size: int, + max_length: int, + return_trace: bool = False, + temperature: float = 1.0, +): + assert not return_trace, "Trace is not yet supported" + device = next(model.parameters()).device + xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device) + dt = 1.0 / steps + t = torch.zeros(batch_size, device=device) + + for i in range(steps): + # ——— predict and convert rates ——— + pred = model(xt, t) + pred = model.interpolant.to_actual_rate(xt, pred, t) + unmask_rate = pred.unmask_rate # (B, L, V) + + if i == steps - 1: + # last step: deterministic unmask via argmax + mask_pos = xt == mask # (B, L) + new_token = unmask_rate.argmax(dim=2) # (B, L) + new_xt = xt.clone() + new_xt[mask_pos] = new_token[mask_pos] + new_xt = torch.where(xt != mask, xt, new_xt) + xt = new_xt + t = t + dt + continue + # tau-leaping via Poisson counts + counts = torch.poisson(unmask_rate * dt).long() + mask_pos = xt == mask # (B, L) + # zero out non-mask 
# Not used in production, for debugging purposes
# Prior over final sequence lengths: {length: probability}.
lengths = {4: 0.1, 16: 0.4, 32: 0.4, 64: 0.1}


def binomial_mass(k, n, p):
    """Return the binomial probability mass ``P(X = k)`` for ``X ~ Bin(n, p)``.

    Args:
        k: Number of successes.
        n: Number of trials.
        p: Probability of success in a single trial.

    Returns:
        The probability mass, or ``0.0`` when ``k`` / ``n`` are not integral,
        ``k`` is negative, or ``k`` exceeds ``n``.
    """
    import math

    # Accept integral floats (e.g. 3.0); math.factorial / math.comb reject them.
    try:
        k_int, n_int = int(k), int(n)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    if k_int != k or n_int != n:
        return 0.0
    if k_int < 0 or k_int > n_int:
        return 0.0

    # math.comb gives the exact integer coefficient without building three
    # large factorials and dividing them as floats.
    coefficient = math.comb(n_int, k_int)
    return coefficient * (p ** k_int) * ((1 - p) ** (n_int - k_int))


def calculate_rate_batch(alpha_t, len_t, length_probs=None):
    """Calculate the rate for a batch of ``alpha_t`` / ``len_t`` values.

    For every batch entry this is the weighted mean of ``length - len_t`` over
    all candidate lengths ``length >= len_t``, weighted by the prior
    probability of that length times ``Binomial(length, alpha_t).pmf(len_t)``.

    Args:
        alpha_t (torch.Tensor): Tensor of shape (batch_size,) with success
            probabilities.
        len_t (torch.Tensor | int | float): Tensor of shape (batch_size,), or a
            single value that is applied to the whole batch.
        length_probs (dict | None): ``{length: probability}`` prior; defaults
            to the module-level ``lengths``.

    Returns:
        torch.Tensor: Tensor of shape (batch_size,) with the rates; entries
        with no admissible length (zero denominator) are 0.
    """
    if length_probs is None:
        length_probs = lengths

    batch_size = alpha_t.shape[0]
    device = alpha_t.device

    # A plain Python number used to crash on ``.any()``; broadcast it instead.
    len_t = torch.as_tensor(len_t, device=device)
    if len_t.ndim == 0:
        len_t = len_t.expand(batch_size)
    # Binomial.log_prob works on floating values.
    len_f = len_t.to(alpha_t.dtype)

    nom = torch.zeros(batch_size, device=device)
    denom = torch.zeros(batch_size, device=device)

    for length, probability in length_probs.items():
        # Only lengths that can still contain the current sequence contribute.
        in_range = (len_t >= 0) & (len_t <= length)
        if not in_range.any():
            continue

        count = len_f[in_range]
        binom_dist = torch.distributions.Binomial(
            total_count=length, probs=alpha_t[in_range]
        )
        binom_probs = binom_dist.log_prob(count).exp()

        nom[in_range] += (length - count) * probability * binom_probs
        denom[in_range] += probability * binom_probs

    # Leave the rate at zero wherever nothing contributed.
    result = torch.zeros_like(nom)
    has_mass = denom > 0
    result[has_mass] = nom[has_mass] / denom[has_mass]

    return result


# Keep the original function for backward compatibility
def calculate_rate(alpha_t, len_t, length_probs=None):
    """Scalar entry point for the rate; dispatches batched tensors.

    Args:
        alpha_t: Success probability; a tensor with at least one dimension is
            forwarded to ``calculate_rate_batch``.
        len_t: Current length.
        length_probs: ``{length: probability}`` prior; defaults to the
            module-level ``lengths``.

    Returns:
        The rate, or ``0.0`` when no candidate length carries any mass.
    """
    if isinstance(alpha_t, torch.Tensor) and alpha_t.ndim > 0:
        return calculate_rate_batch(alpha_t, len_t, length_probs)

    if length_probs is None:
        length_probs = lengths

    nom, denom = 0, 0
    for length, probability in length_probs.items():
        if length >= len_t:
            # Evaluate the PMF once and reuse it for both sums.
            pmf = binomial_mass(len_t, length, alpha_t)
            nom += (length - len_t) * probability * pmf
            denom += probability * pmf

    if denom == 0:
        return 0.0

    return nom / denom
int, + batch_size: int, + max_length: int, + return_trace: bool = False, + confidence_based_sampling: bool = True, # whether to use confidence-based decoding + alpha: float = 5.0, # hyperparameter for window size calculation + max_window: int = 32, # Maximum window size for sliding window + confidence_method: str = "prob_diff", # "position", "top_prob", "prob_diff", "entropy" + use_sliding_window: bool = False, # whether to use sliding window for position selection + temperature: float = 1.0, +) -> SamplingResult: + + device = next(model.parameters()).device + xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device) + sampling_trace = [] + dt = 1.0 / steps + t = torch.zeros(batch_size, device=device) + + # Precompute row indices for scatter + batch_idx_L = ( + torch.arange(batch_size, device=device) + .view(batch_size, 1) + .expand(batch_size, max_length) + ) + pos_idx_L = ( + torch.arange(max_length, device=device) + .view(1, max_length) + .expand(batch_size, max_length) + ) + + for i in range(steps): + # --- predict rates --- + pred = model(xt, t) + xt_len = (xt != pad).sum(dim=1) + pred = model.interpolant.to_actual_rate(xt, pred, t) + unmask_rate = pred.unmask_rate # (B, L, V) + len_rate = pred.length_rate # (B, L+1) + + if i == steps - 1: + # last step: deterministic unmask via argmax + mask_pos = xt == mask + new_token = unmask_rate.argmax(dim=2) + new_xt = xt.clone() + new_xt[mask_pos] = new_token[mask_pos] + new_xt = torch.where(xt == pad, pad, new_xt) + new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) + xt = new_xt + t = t + dt + continue + + # --- confidence-based decoding --- + if confidence_based_sampling > 0.0: + # Confidence-based unmasking (vectorized) + mask_positions = (xt == mask) # (B, L) + num_mask_positions = mask_positions.sum(dim=1) # (B,) + + # 1. Determine number of tokens to unmask using Poisson + unmask_counts = torch.poisson(num_mask_positions.float() * dt).long() # (B,) + + # 2. 
Calculate confidence based on selected method + if confidence_method == "position": + # Position-based confidence: position i / len(xt) + xt_len = (xt != pad).sum(dim=1) # (B,) - current sequence lengths + position_indices = torch.arange(max_length, device=device).unsqueeze(0).expand(batch_size, -1) # (B, L) + confidence = 1.0 - (position_indices.float() / xt_len.unsqueeze(1).float().clamp(min=1)) # (B, L) + + elif confidence_method == "top_prob": + # Top probability confidence + import torch.nn.functional as F + token_logits = unmask_rate # (B, L, V) - use the unmask_rate as logits + unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V) + confidence = unmask_probs.max(dim=-1)[0] # (B, L) + + elif confidence_method == "prob_diff": + # Probability difference confidence (top - second top) + import torch.nn.functional as F + token_logits = unmask_rate # (B, L, V) + unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V) + top2_probs, _ = torch.topk(unmask_probs, k=2, dim=-1) # (B, L, 2) + confidence = top2_probs[:, :, 0] - top2_probs[:, :, 1] # (B, L) + + elif confidence_method == "entropy": + # Entropy-based confidence (lower entropy = higher confidence) + import torch.nn.functional as F + token_logits = unmask_rate # (B, L, V) + unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V) + entropy = -torch.sum(unmask_probs * torch.log(unmask_probs + 1e-10), dim=-1) # (B, L) + confidence = -entropy # (B, L) - negative entropy so lower entropy gives higher confidence + + else: + raise ValueError(f"Unknown confidence_method: {confidence_method}") + + # 3. 
Apply window constraint if enabled + if use_sliding_window: + # Calculate dynamic k for each batch + k_values = torch.minimum( + torch.minimum( + (alpha * unmask_counts).long(), + torch.tensor(max_window, device=device) + ), num_mask_positions) # (B,) + + # Get cumulative count of mask positions + mask_cumsum = mask_positions.cumsum(dim=1) # (B, L) + + # Create window mask: position is eligible if it's a mask and within first k masks + is_within_window = mask_cumsum <= k_values.unsqueeze(1) # (B, L) + window_mask = mask_positions & is_within_window # (B, L) + + # Set confidence to -inf for positions outside the window or non-mask positions + confidence = torch.where(window_mask, confidence, torch.tensor(-float('inf'), device=device)) + else: + # No window constraint - only mask positions are eligible + confidence = torch.where(mask_positions, confidence, torch.tensor(-float('inf'), device=device)) + + new_xt = xt.clone() + + # vectorized unmasking + max_unmask = unmask_counts.max().item() + if max_unmask > 0: + _, all_top_indices = torch.topk(confidence, k=max_unmask, dim=1, largest=True) # (B, max_unmask) + + # create mask for valid unmask operations + unmask_mask = torch.arange(max_unmask, device=device).unsqueeze(0) < unmask_counts.unsqueeze(1) # (B, max_unmask) + + most_likely_tokens = unmask_rate.argmax(dim=-1) # (B, L) + + selected_positions = all_top_indices[unmask_mask] + batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, max_unmask)[unmask_mask] + + new_xt[batch_indices, selected_positions] = most_likely_tokens[batch_indices, selected_positions] + else: + # --- tau-leaping unmask via Poisson --- + counts = torch.poisson(unmask_rate * dt).long() + mask_pos = xt == mask + counts[~mask_pos.unsqueeze(-1).expand_as(counts)] = 0 + counts[..., mask] = 0 + sum_c = counts.sum(dim=2) + one_event = sum_c == 1 + new_token = counts.argmax(dim=2) + new_xt = xt.clone() + new_xt[one_event] = new_token[one_event] + new_xt = torch.where(xt == 
pad, pad, new_xt) + new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) + + # insertion only on non-last + if i != steps - 1: + # --- Poisson insertion, compute new lengths and fill masks --- + ext = torch.poisson(len_rate * dt).long() # (B, L+1) + xt_len = xt.ne(pad).sum(dim=1) # (B,) + gaps = torch.arange(max_length + 1, device=device).view(1, -1) + ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() + total_ext = ext.sum(dim=1) + valid = xt_len + total_ext <= max_length + ext = ext * valid.view(batch_size, 1).long() + + # compute prefix sums of insertions + ext_ex = ext.int().cumsum(dim=1) # (B, L+1) + new_len = xt_len + total_ext # (B,) + + # initialize with pads, then fill mask up to new_len + xt_tmp = torch.full_like(xt, pad) + mask_pos = pos_idx_L < new_len.view(batch_size, 1) + xt_tmp[mask_pos] = mask + + # shift and scatter original tokens + new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) + orig_mask = pos_idx_L < xt_len.view(batch_size, 1) + flat_b = batch_idx_L[orig_mask] + flat_p = new_pos_orig[orig_mask] + xt_tmp[flat_b, flat_p] = new_xt[orig_mask] + else: + xt_tmp = new_xt + + xt = xt_tmp + t = t + dt + if return_trace: + sampling_trace.append(xt) + + return xt, sampling_trace diff --git a/a2d2_mol/scripts/run_mol_finetune.slurm b/a2d2_mol/scripts/run_mol_finetune.slurm new file mode 100644 index 0000000000000000000000000000000000000000..f629a011f69d7ee13fe38d0b50fd054d06ac69c6 --- /dev/null +++ b/a2d2_mol/scripts/run_mol_finetune.slurm @@ -0,0 +1,200 @@ +#!/bin/bash +# NOTE: --partition and --qos below are specific to our cluster. Change them +# (or remove them and pass `--partition` on the `sbatch` command line) to match +# the partitions/QOS available on yours. 
#SBATCH --job-name=mol-finetune
#SBATCH --partition=dgx-b200
#SBATCH --nodes=1
#SBATCH --gpus-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --ntasks-per-node=1
#SBATCH --mem=80GB
#SBATCH --time=02-00:00:00
# NOTE: SLURM does not create the directory for --output; `logs/` must already
# exist relative to the submit directory when the job starts.
#SBATCH --output=logs/slurm-%A.%x.log

# =====================================================================
# run_mol_finetune.slurm
#
# Single-mode job (1 GPU) running ONE finetune_mol experiment.
# Select which mode to run via the MODE_ID variable below (or override
# at submit time with `sbatch --export=ALL,MODE_ID=2 ...`):
#   0) A2D2 (Ours)                 – with full planner (alternating)
#   1) A2D2 w/o quality            – --disable_planner
#   2) A2D2 w/o insertion planner  – --disable_insertion_planner
#   3) A2D2 w/o unmasking planner  – --disable_unmasking_planner
#
# The job trains the selected mode then evaluates the resulting
# checkpoint on the same GPU.
# =====================================================================

# Abort on the first failing command, so a failed training run never reaches
# the evaluation stage with a stale or missing checkpoint.
set -e

# --- Mode selection ---------------------------------------------------
# Which experiment to run (0-3). Override with `--export=ALL,MODE_ID=N`.
MODE_ID="${MODE_ID:-0}"

# Run prefix
PREFIX=${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}

# --- Paths ------------------------------------------------------------
# Repo root is resolved at submit time so the job runs from any clone:
#   - set A2D2_ROOT explicitly, OR
#   - submit with `sbatch` from the repo root (SLURM sets SLURM_SUBMIT_DIR;
#     note sbatch copies the script to a spool dir, so we can't rely on the
#     script's own path here), OR
#   - run the script directly, falling back to its location on disk.
if [ -n "${A2D2_ROOT:-}" ]; then
    HOME_LOC="$A2D2_ROOT"
elif [ -n "${SLURM_SUBMIT_DIR:-}" ]; then
    HOME_LOC="$SLURM_SUBMIT_DIR"
else
    # This script lives in a2d2_mol/scripts/, so the repo root is two levels up.
    HOME_LOC="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
fi
SCRIPT_LOC="$HOME_LOC/a2d2_mol"
LOG_LOC="$HOME_LOC/logs"
SAVE_DIR="$HOME_LOC/checkpoints/finetune_mol"
RESULTS_DIR="$HOME_LOC/results/mol_ablation"

mkdir -p "$LOG_LOC" "$SAVE_DIR" "$RESULTS_DIR"

# --- Environment setup ------------------------------------------------
# Set WANDB_API_KEY in your shell/secret store before submitting (do NOT commit it):
#   export WANDB_API_KEY=...   or   `wandb login`
export WANDB_DIR="$HOME_LOC/.wandb"
export WANDB_CONFIG_DIR="$HOME_LOC/.config/wandb"
export WANDB_CACHE_DIR="$HOME_LOC/.cache/wandb"
mkdir -p "$WANDB_DIR" "$WANDB_CONFIG_DIR" "$WANDB_CACHE_DIR"

export TRITON_CACHE_DIR="$HOME_LOC/.triton/cache"
mkdir -p "$TRITON_CACHE_DIR"

export TORCHINDUCTOR_CACHE_DIR="$HOME_LOC/.torchinductor/cache"
mkdir -p "$TORCHINDUCTOR_CACHE_DIR"

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Force unbuffered stdout/stderr so live training output is flushed to the
# redirected RUN_LOG (Python block-buffers stdout when it's a file, not a TTY).
export PYTHONUNBUFFERED=1

# Activate conda env. Override CONDA_ROOT to point at your conda/miniconda
# install, or just have `conda` on PATH; override CONDA_ENV if your env name
# differs from the one created by environment.yml.
CONDA_ENV="${CONDA_ENV:-a2d2}"
if [ -n "${CONDA_ROOT:-}" ]; then
    source "$CONDA_ROOT/bin/activate" "$CONDA_ENV"
elif command -v conda >/dev/null 2>&1; then
    source "$(conda info --base)/bin/activate" "$CONDA_ENV"
else
    echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2
    exit 1
fi
PYTHON_EXECUTABLE=$(command -v python)

cd "$SCRIPT_LOC"

# Pretrained base checkpoint
PRETRAINED_CKPT="$HOME_LOC/pretrained/anylength_mol.ckpt"

# --- Shared training hyperparameters ----------------------------------
COMMON_ARGS=(
    --base_path "$HOME_LOC"
    --use_quality_filter
    --noise_removal
    --wdce_num_replicates 16
    --pool_size 1000
    --pool_refresh_fraction 0.3
    --buffer_size 100
    --batch_size 200
    --training_mini_batch_size 20
    --max_length 256
    --total_num_steps 256
    --num_iter 20
    --resample_every_n_step 10
    --num_epochs 1000
    --save_every_n_epochs 100
    --reset_every_n_step 1
    --alpha 0.01
    --no_mcts
    --schedule_warmup_epochs 20
    --alternation_frequency 5
    --num_remasking 3
    --quality_threshold 0.3
    --checkpoint_path "$PRETRAINED_CKPT"
    --grad_clip
    --qed_only
    --seed 42
    --num_training_steps_per_epoch 25
)

# --- Shared evaluation hyperparameters --------------------------------
EVAL_COMMON_ARGS=(
    --pretrained_ckpt "$PRETRAINED_CKPT"
    --num_samples 1000
    --batch_size 50
    --max_length 256
    --total_num_steps 256
    --num_remasking 2
    --quality_threshold 0.3
    --seed 42
)

# =====================================================================
# Pick experiment from $MODE_ID
# =====================================================================
case "$MODE_ID" in
    0) MODE="with_planner";          EXTRA_ARGS=() ;;
    1) MODE="no_planner";            EXTRA_ARGS=(--disable_planner) ;;
    2) MODE="no_insertion_planner";  EXTRA_ARGS=(--disable_insertion_planner) ;;
    3) MODE="no_unmasking_planner";  EXTRA_ARGS=(--disable_unmasking_planner) ;;
    *) echo "Unknown MODE_ID=$MODE_ID (expected 0-3)"; exit 1 ;;
esac

RUN_NAME="${PREFIX}_mol_${MODE}"
RUN_LOG="$LOG_LOC/${RUN_NAME}.log"
RUN_SAVE_DIR="$SAVE_DIR/${RUN_NAME}"
RESULTS_SUBDIR="$RESULTS_DIR/${MODE}"
mkdir -p "$RUN_SAVE_DIR" "$RESULTS_SUBDIR"

echo "=== Mol finetune (MODE_ID=$MODE_ID) ==="
# The SLURM variables are unset when the script is run directly with bash.
echo "Job: ${SLURM_JOB_ID:-local}  Node: ${SLURM_NODELIST:-$(hostname)}"
echo "Mode: $MODE"
echo "Save dir: $RUN_SAVE_DIR"
echo "Results dir: $RESULTS_SUBDIR"
echo "Python: $PYTHON_EXECUTABLE"
echo "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES:-(unset)}"

# =====================================================================
# Train
# =====================================================================
# Interpreter and script paths are quoted so a clone path containing spaces
# does not get word-split.
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/finetune_mol.py" \
    "${COMMON_ARGS[@]}" \
    --devices 1 \
    "${EXTRA_ARGS[@]}" \
    --save_path_dir "$RUN_SAVE_DIR" \
    >> "$RUN_LOG" 2>&1

echo "Training finished for $MODE. Log: $RUN_LOG"

# =====================================================================
# Evaluate
# =====================================================================
# Newest last.ckpt, either directly in the run dir or one level below it.
# `ls` errors for the non-matching glob are discarded; without pipefail the
# pipeline status is that of `head`, so `set -e` does not trip here.
RUN_CKPT=$(ls -t "$RUN_SAVE_DIR"/*/last.ckpt "$RUN_SAVE_DIR"/last.ckpt 2>/dev/null | head -1)
if [ -z "$RUN_CKPT" ]; then
    echo "ERROR: no checkpoint found in $RUN_SAVE_DIR; cannot run evaluation." >&2
    exit 1
fi

echo "Evaluating checkpoint: $RUN_CKPT"
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/evaluate_mol_table.py" \
    --checkpoint_path "$RUN_CKPT" \
    "${EVAL_COMMON_ARGS[@]}" \
    "${EXTRA_ARGS[@]}" \
    --output_dir "$RESULTS_SUBDIR" \
    --device cuda:0 \
    >> "$RESULTS_SUBDIR/eval.log" 2>&1

echo "Eval finished for $MODE. CSV: $RESULTS_SUBDIR/eval_metrics_${MODE}.csv"

conda deactivate
#!/bin/bash
#SBATCH --job-name=a2d2-mol-pretrain
#SBATCH --partition=dgx-b200
#SBATCH --nodes=1
#SBATCH --gpus-per-node=2
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=8
#SBATCH --mem=512GB
#SBATCH --time=7-00:00:00
# SLURM's own catch-file (anything printed before the exec redirect below, plus
# slurm-infra messages). Relative to the submit dir, so submit this script from
# the a2d2_mol/ directory; the real run output is redirected via exec below.
# NOTE: SLURM does not create this directory; run `mkdir -p logs/slurm` before
# the first submission (the mkdir further down runs too late for these files).
#SBATCH --output=logs/slurm/%x_%j.out
#SBATCH --error=logs/slurm/%x_%j.err
#
# Pretrain the any-length insertion MDM on drug-like SAFE molecules on a dgx-b200 node.
# Submit with:  sbatch scripts/train_mol.sh   (from the a2d2_mol/ directory).
#
# DDP is launched by SLURM: one srun task per GPU. --gpus-per-node and
# --ntasks-per-node must match; change both together (and they override the
# training.devices value baked into config_mol.yaml via the hydra override below).

DATE=$(date +%Y%m%d)
SPECIAL_PREFIX='a2d2-mol'

# Resolve a2d2_mol/ (which holds train.py + config_mol.yaml) so paths are
# repo-relative. This script lives in a2d2_mol/scripts/, so the direct-run
# fallback goes one level up. Under sbatch, BASH_SOURCE points at the spooled
# copy, so we rely on SLURM_SUBMIT_DIR (submit from the a2d2_mol/ directory).
if [ -n "${SLURM_SUBMIT_DIR:-}" ]; then
    SCRIPT_DIR="$SLURM_SUBMIT_DIR"
else
    SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
fi
# Everything below uses paths relative to a2d2_mol/; stop if we cannot get there.
cd "$SCRIPT_DIR" || { echo "ERROR: cannot cd to $SCRIPT_DIR" >&2; exit 1; }

# Auto-detect GPUs from the SLURM allocation (falls back to 2 for `bash` runs).
DEVICES=${SLURM_GPUS_ON_NODE:-${SLURM_GPUS_PER_NODE:-2}}
NTASKS=${SLURM_NTASKS_PER_NODE:-$DEVICES}
NODES=${SLURM_NNODES:-1}

LOG_LOC="$SCRIPT_DIR/logs"
mkdir -p "$LOG_LOC/slurm"
exec > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_${SLURM_JOB_ID:-local}.log" 2>&1

# ---------------------------------------------------------------------------
# Weights & Biases: log in once on your machine before running this script with
# `wandb login` (or `export WANDB_API_KEY=`).
# Do NOT hardcode your API key here. To disable W&B entirely, uncomment:
#   export WANDB_MODE=disabled
# ---------------------------------------------------------------------------

export PYTORCH_ALLOC_CONF=expandable_segments:True

# Activate the conda env that has the deps (torch / pytorch_lightning / hydra).
# The batch shell does NOT source ~/.bashrc, so conda is not on PATH. Override
# CONDA_ROOT to point at your conda/miniconda install, or just have `conda` on
# PATH; override CONDA_ENV if your env name differs from the one created by
# environment.yml.
CONDA_ENV="${CONDA_ENV:-a2d2}"
if [ -n "${CONDA_ROOT:-}" ]; then
    source "$CONDA_ROOT/bin/activate" "$CONDA_ENV"
elif command -v conda >/dev/null 2>&1; then
    source "$(conda info --base)/bin/activate" "$CONDA_ENV"
else
    echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2
    exit 1
fi

# --- Distributed / NCCL setup (single node, intra-node NVLink) --------------
# Pick the first non-loopback IPv4 interface that looks like ethernet; if none
# matches, fall back to the first non-loopback interface whose line does not
# contain "ibp".
ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -E "ens|eth|enp|bond" | head -1 | awk '{print $2}')
if [ -z "$ETH_IFACE" ]; then
    ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -v "ibp" | head -1 | awk '{print $2}')
fi
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_FAMILY=AF_INET
export NCCL_SOCKET_IFNAME=$ETH_IFACE
export NCCL_P2P_LEVEL=NVL

export MASTER_ADDR=$(scontrol show hostnames "${SLURM_NODELIST:-$(hostname)}" | head -n 1)
export MASTER_PORT=$(shuf -i 15000-59999 -n 1)
export NODE_RANK=${SLURM_NODEID:-0}

echo "=== a2d2 molecule pretraining (dgx-b200) ==="
echo "Job ID: ${SLURM_JOB_ID:-local}  Node: ${SLURM_NODELIST:-$(hostname)}  GPUs: $DEVICES  Tasks: $NTASKS"

# --task mol makes train.py load config_mol.yaml; the hydra overrides pin
# devices/nodes to the SLURM allocation so the two never drift apart.
srun --ntasks-per-node=$NTASKS python train.py --task mol \
    training.devices=$DEVICES \
    training.nodes=$NODES
# Remember the training exit code: without this the script's status would be
# that of `conda deactivate`, and SLURM would report a failed run as COMPLETED.
TRAIN_STATUS=$?

conda deactivate

exit "$TRAIN_STATUS"
def is_global_rank_zero():
    """Return True when this process should own W&B logging.

    The rank is read from the environment in the same order PyTorch Lightning's
    rank-zero utilities use: ``RANK``, ``LOCAL_RANK``, ``SLURM_PROCID``,
    ``JSM_NAMESPACE_RANK``. Checking ``LOCAL_RANK`` alone is not enough when the
    job is launched with ``srun``: SLURM does not export ``LOCAL_RANK``, so every
    task would look like rank 0 and open its own W&B run.
    """
    for key in ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK"):
        value = os.environ.get(key)
        if value is not None:
            return int(value) == 0
    # No launcher variables at all: a plain single-process run.
    return True


def train(config):
    """Build the data module, model and Lightning trainer from ``config`` and fit.

    Args:
        config: Hydra/OmegaConf config. This function reads the ``wandb``,
            ``training`` and ``model`` sections and, when present,
            ``hf_dataset``.

    Raises:
        ValueError: if neither ``training.max_steps`` nor
            ``training.num_epochs`` is set; if ``training.warmup_steps`` is
            unset while ``training.max_steps`` is unset (the default warmup is
            1% of ``max_steps``); or if neither ``save_every_n_steps`` nor
            ``save_every_n_epochs`` is set.
    """
    wandb_logger = None

    # set the random seed
    pl.seed_everything(42)
    torch.manual_seed(42)

    # Only initialize wandb on rank 0 to avoid multiple runs
    log_to_wandb = is_global_rank_zero()
    if log_to_wandb:
        wandb.init(
            project=config.wandb.project,
            name=config.wandb.name,
            # Honour the `wandb.entity` config field; None means the W&B default.
            entity=config.wandb.get('entity', None),
            config=OmegaConf.to_container(config, resolve=True),  # Convert to dict
            dir=config.wandb.path
        )
        wandb_logger = WandbLogger(
            project=wandb.run.project,
            name=wandb.run.name,
            log_model=False,  # Disable checkpoint uploading to save disk space
        )

    # Modify config to add timestamp to checkpoint directory
    OmegaConf.set_struct(config, False)
    time_string = datetime.now().strftime("%Y%m%d-%H%M%S")
    config.training.checkpoint_dir = os.path.join(
        config.training.checkpoint_dir, time_string
    )
    OmegaConf.set_struct(config, True)

    # Create checkpoint directory
    os.makedirs(config.training.checkpoint_dir, exist_ok=True)

    # Setup data module - check if using HuggingFace dataset
    if hasattr(config, 'hf_dataset'):
        # Imported lazily: the HF/SAFE path is only used by the molecule configs,
        # which keep mol_dataset.py (and its `safe` dependency) in a2d2_mol/.
        from mol_dataset import setup_hf_data_and_update_config
        print(f"Using HuggingFace dataset: {config.hf_dataset.name}")
        data_module = setup_hf_data_and_update_config(
            config,
            dataset_name=config.hf_dataset.name,
            smiles_column=config.hf_dataset.get('smiles_column', 'smiles')
        )
    else:
        # Imported lazily: the local (arrow) path is used by the peptide config,
        # which keeps dataloading_for_dynamic_batching.py in a2d2_pep/.
        from dataloading_for_dynamic_batching import setup_data_and_update_config
        print("Using local dataset")
        data_module = setup_data_and_update_config(config)

    module = AnyOrderInsertionFlowModule(config)

    # Map torch_dtype to Lightning precision
    dtype_str = config.model.get('torch_dtype', 'bfloat16')
    precision_map = {
        'float32': '32-true',
        'float16': '16-mixed',
        'bfloat16': 'bf16-mixed'
    }
    precision = precision_map.get(dtype_str, 'bf16-mixed')

    # Global batch = per-GPU batch * world size * accumulation steps. Clamp to 1
    # so a global batch smaller than one pass over all GPUs does not floor to 0.
    samples_per_pass = (
        config.training.per_gpu_batch_size
        * config.training.nodes
        * config.training.devices
    )
    accumulate_grad_batches = max(1, config.training.batch_size // samples_per_pass)

    trainer_kwargs = dict(
        num_nodes=config.training.nodes,
        accelerator="gpu",
        devices=config.training.devices,
        strategy="ddp",
        precision=precision,
        accumulate_grad_batches=accumulate_grad_batches,
        log_every_n_steps=10,
        enable_checkpointing=True,
        default_root_dir=config.training.checkpoint_dir,
        gradient_clip_val=1.0,
    )

    # Only one of max_steps or num_epochs is used; max_steps wins when both are set.
    max_steps = config.training.get('max_steps', None)
    num_epochs = config.training.get('num_epochs', None)
    if max_steps is not None:
        trainer_kwargs["max_steps"] = max_steps
    elif num_epochs is not None:
        trainer_kwargs["max_epochs"] = num_epochs
    else:
        raise ValueError(
            "Either max_steps or num_epochs must be specified in the config"
        )

    # Default warmup is 1% of max_steps. In epoch-bounded runs max_steps is
    # unknown here, so the value has to come from the config.
    if config.training.get('warmup_steps', None) is None:
        if max_steps is None:
            raise ValueError(
                "training.warmup_steps must be set explicitly when training is "
                "bounded by num_epochs (max_steps is not set)"
            )
        OmegaConf.set_struct(config, False)
        config.training.warmup_steps = int(max_steps * 0.01)
        OmegaConf.set_struct(config, True)

    # Keep the best checkpoint(s) by *training* loss (no validation metric is
    # monitored here), plus last.ckpt.
    checkpoint_callback = ModelCheckpoint(
        monitor="train/total_loss",
        mode="min",
        save_top_k=config.training.save_top_k,
        save_last=True,
        filename="epoch-{epoch:02d}-train_loss-{train/total_loss:.4f}",
        dirpath=config.training.checkpoint_dir,
        auto_insert_metric_name=False
    )

    # Separate callback for periodic saves (no metric dependency). Use
    # step-based saves for streaming datasets (save_every_n_steps) and epoch-based
    # saves otherwise (save_every_n_epochs); whichever the config provides.
    save_every_n_steps = config.training.get('save_every_n_steps', None)
    save_every_n_epochs = config.training.get('save_every_n_epochs', None)
    if save_every_n_steps is not None:
        periodic_checkpoint_callback = ModelCheckpoint(
            save_top_k=-1,  # Save all periodic checkpoints
            filename="step-{step:08d}",
            dirpath=config.training.checkpoint_dir,
            every_n_train_steps=save_every_n_steps,
            auto_insert_metric_name=False
        )
    elif save_every_n_epochs is not None:
        periodic_checkpoint_callback = ModelCheckpoint(
            save_top_k=-1,  # Save all periodic checkpoints
            filename="epoch-{epoch:02d}",
            dirpath=config.training.checkpoint_dir,
            every_n_epochs=save_every_n_epochs,
            auto_insert_metric_name=False
        )
    else:
        raise ValueError(
            "Either save_every_n_steps or save_every_n_epochs must be specified in the config"
        )

    trainer_kwargs["callbacks"] = [checkpoint_callback, periodic_checkpoint_callback]

    if wandb_logger is not None:
        trainer_kwargs["logger"] = wandb_logger

    trainer = pl.Trainer(**trainer_kwargs)

    # Resume from a checkpoint only when the config carries training.resume_path.
    ckpt_path = None
    if "resume_path" in config.training:
        ckpt_path = config.training.resume_path

    trainer.fit(module,
                datamodule=data_module,
                ckpt_path=ckpt_path)

    # Only the process that opened the W&B run closes it.
    if log_to_wandb:
        wandb.finish()
if __name__ == '__main__':
    # Only --task / --config_name are consumed here; every other argument is
    # handed on to Hydra untouched.
    cli = argparse.ArgumentParser()
    cli.add_argument('--config_name', type=str, default='config',
                     help='Name of the config file to use')
    cli.add_argument('--task', type=str, default=None,
                     help='Task name (uses config_{task}.yaml)')
    args, hydra_args = cli.parse_known_args()

    # --task takes precedence over --config_name.
    config_name = f'config_{args.task}' if args.task else args.config_name

    print(f"Using config: {config_name}.yaml")

    # Pass the config name as a Hydra override (this persists across DDP
    # subprocesses), unless it is already among the forwarded arguments.
    already_present = (
        '--config-name' in hydra_args
        or f'--config-name={config_name}' in hydra_args
    )
    if not already_present:
        hydra_args.insert(0, f'--config-name={config_name}')

    # Rebuild sys.argv so Hydra only sees its own arguments.
    sys.argv = [sys.argv[0], *hydra_args]

    # config_name='config' is only a default; the --config-name override wins.
    @hydra.main(version_base=None,
                config_path=CONFIG_DIR,
                config_name='config')
    def main(config):
        """Hydra entry point: run training with the composed config."""
        train(config)

    main()
+ +The codebase is partially built upon [FlexMDM (Kim et al., 2025)](https://github.com/brianlck/FlexMDM/tree/main) and [TR2-D2 (Tang et al., 2025)](https://github.com/sophtang/TR2-D2/tree/main). + +## Environment Installation +``` +# from the repository root +conda env create -f environment.yml + +conda activate a2d2 +``` +The peptide scripts share the `a2d2` environment with the molecule and language experiments. See the root [`environment.yml`](../environment.yml) for the `flash-attn` install step. + +## Model Pretrained Weights + +A2D2 fine-tunes a pretrained any-length insertion MDM trained on ~11M peptide SMILES (7,451 sequences from CycPeptMPDB, 825,632 from SmProt, and ~10M modified peptides from CycloPs). Download the base checkpoint and place it at: +``` +A2D2/pretrained/anylength_pep.ckpt +``` +```bash +# from the repository root +pip install gdown +mkdir -p pretrained +gdown 1K8yxM-omh-MuPo0EG6UyxHZLk3HehoJc -O pretrained/anylength_pep.ckpt +``` +(Or download manually from https://drive.google.com/file/d/1K8yxM-omh-MuPo0EG6UyxHZLk3HehoJc/view?usp=drive_link — a plain `wget`/`curl` of the link saves Google's HTML warning page, not the checkpoint.) +This is the default `--checkpoint_path`; pass `--checkpoint_path` to override it. + +The reward classifiers (binding-affinity Transformer, plus XGBoost predictors for solubility, hemolysis, non-fouling, and permeability) and the SMILES PE tokenizer ship with the repo under [`pep_scoring/`](pep_scoring); no separate download is required. The PeptideCLM embedding model is fetched automatically from the Hugging Face Hub (`aaronfeller/PeptideCLM-23M-all`) on first run. + +## Pretraining the Any-Length Model + +If you only want to fine-tune with A2D2, download the released `anylength_pep.ckpt` above and skip this section. Follow these steps to reproduce the base checkpoint by pretraining the any-length insertion MDM from scratch. + +### 1. 
Download the pretraining dataset + +The pretraining corpus is ~11M peptide SMILES (7,451 from CycPeptMPDB, 825,632 from SmProt, and ~10M modified peptides from CycloPs), already tokenized with the in-repo SMILES PE tokenizer and saved as a Hugging Face `arrow` dataset (with `train`/`val` splits) via `save_to_disk`. + +Download the archive and unpack it into [`data/`](data): + +```bash +# from a2d2_pep/ +pip install gdown +gdown https://drive.google.com/uc?id=1yCDr641WVjCtECg3nbG0nsMNu8j7d7gp -O 11M_peptide_smiles.zip +mkdir -p data +unzip 11M_peptide_smiles.zip -d data/ +# result: a2d2_pep/data/11M_peptide_smiles/{train,val}/... +``` + +This is the default `training.data_path` in [`config_pep.yaml`](config_pep.yaml). To store the dataset elsewhere, set `training.data_path` (absolute, or relative to `a2d2_pep/`). + +### 2. Configure + +Pretraining is driven by [`config_pep.yaml`](config_pep.yaml). Key fields: + +| Field | Default | Notes | +|-------|---------|-------| +| `training.data_path` | `data/11M_peptide_smiles` | Preprocessed arrow dataset from step 1. | +| `training.devices` | `4` | GPUs per node (DDP). | +| `training.batch_size` | `1024` | Global batch; gradient accumulation is derived automatically from `per_gpu_batch_size`. | +| `training.max_steps` | `1000000` | Total optimizer steps. | +| `training.learning_rate` | `3e-4` | AdamW LR with `warmup_steps: 2000`. | +| `training.checkpoint_dir` | `checkpoints/peptides` | A timestamped subdirectory is created per run. | +| `interpolant.max_length` | `1024` | Max token length. | + +### 3. Pre-training Any-Length Peptide Model + +Log in to Weights & Biases once (`wandb login`), or set `export WANDB_MODE=disabled` to skip logging. 
Then submit the SLURM job: + +```bash +# from a2d2_pep/ +sbatch train_pep.sh +``` + +`train_pep.sh` is a SLURM batch script that requests one `dgx-b200` node with 4 full B200 GPUs and launches DDP via `srun` (one task per GPU), running the equivalent of: + +```bash +python train.py --task pep +``` + +It activates the conda env (`CONDA_ENV`, defaults to the `peptune` env) from `CONDA_ROOT` (defaults to the shared miniconda install) — the batch shell does not source `~/.bashrc`, so override these env vars if your install or env path differs. The GPU count is auto-detected from the SLURM allocation and passed to hydra as `training.devices`/`training.nodes`, so to scale just change `--gpus-per-node` and `--ntasks-per-node` together at the top of the script (they must match). `--task pep` makes `train.py` load `config_pep.yaml`. + +Checkpoints are written to `checkpoints/peptides/YYYYMMDD-HHMMSS/` (use `last.ckpt` / the best `train_loss` checkpoint as the `--checkpoint_path` for fine-tuning); the run log goes to `logs/DATE_a2d2-peptide_JOBID.log` and SLURM's catch-file to `logs/slurm/`. To resume, add a `training.resume_path: /path/to/last.ckpt` entry to the config. + +## Fine-Tune with A2D2 + +All paths resolve relative to the repository, so the scripts run from any checkout. Before running, create the output directories `A2D2/checkpoints`, `A2D2/results`, and `A2D2/logs` (the script also creates them on demand). Fine-tuning curves and a `_generation_results.csv` are written to `/results//`, and checkpoints to `--save_path_dir`. + +Choose a target protein with `--prot_name` (looked up in the built-in `PROTEINS` table — e.g. `glp1` for GLP-1R or `glast` for GLAST), or supply an arbitrary target with `--prot_seq SEQUENCE`. + +#### Available `--prot_name` targets + +The named targets and their amino-acid sequences are defined in the `PROTEINS` dict in [`finetune_quality.py`](finetune_quality.py) (search for `PROTEINS = {`). 
The default is `glast`; passing a name not in the table raises an error listing the valid keys. To add a new target, add a `'': ''` entry there, or skip the table entirely with `--prot_seq`. + +| `--prot_name` | Target | +|---------------|--------| +| `tfr` | Transferrin receptor (TfR) | +| `glp1` | GLP-1 receptor (GLP-1R) | + +### Single run + +[`scripts/run_peptide_finetune.slurm`](scripts/run_peptide_finetune.slurm) runs a single `finetune_quality.py` experiment on one MIG GPU, then evaluates the resulting checkpoint. It bundles the hyperparameter set from the peptide column of the fine-tuning table in the paper — replicates `R = 8`, buffer size `B = 50`, resample interval `N_resample = 10`, gradient steps per iteration `N_update = 10`, alternation frequency `N_alt = 5`, warmup `N_warmup = 20`, sampling steps `N_steps = 256`, training mini-batch `10`, reward scaling `α = 0.1`, quality threshold `μ_min = 0.5`, and `--num_obj 5` — so you don't have to pass them by hand. + +The script resolves the repo root automatically — `$A2D2_ROOT` if set, else the `sbatch` submit directory, else the script's own location — so either submit from the repo root or export your clone path. Set `CONDA_ROOT` (your miniconda install) and, if needed, `CONDA_ENV` (defaults to `peptune`) and `WANDB_ENTITY`: +```bash +export A2D2_ROOT=/path/to/your/A2D2 # absolute path to your clone +export CONDA_ROOT=/path/to/miniconda3 # or just have `conda` on PATH +export WANDB_ENTITY=your_wandb_entity # optional +sbatch scripts/run_peptide_finetune.slurm +``` + +Select which variant to run with `MODE_ID` (default `0`): `0` = A2D2 (full planner), `1` = `--disable_planner`, `2` = `--disable_insertion_planner`, `3` = `--disable_unmasking_planner`. 
Override at submit time: +```bash +sbatch --export=ALL,MODE_ID=2 scripts/run_peptide_finetune.slurm +``` +The target protein is set by the `PROT_NAME` variable near the top of the script (default `tfr`); edit it to one of the named targets above (or any key in the `PROTEINS` table). The pretrained base checkpoint is read from `$A2D2_ROOT/pretrained/anylength_pep.ckpt`. Outputs land in `checkpoints/finetune_test_peptides_/_peptide__/` and `results/peptide_test_ablation_//`. + +### Key arguments +- `--prot_name` / `--prot_seq` — target protein (named lookup, or a raw amino-acid sequence). +- `--alternation_frequency` — epochs to train each of {policy, planner} before alternating. +- `--alpha` — reward-tilting temperature (smaller = stronger reward optimization). +- `--buffer_size`, `--resample_every_n_step` — replay-buffer size and how often it is regenerated. + +### Ablation flags +| Flag | Variant | +|------|---------| +| *(none)* | A2D2 w/ insertion + unmasking quality (alternation) | +| `--disable_planner` | A2D2 w/o quality (policy only, no remasking) | +| `--disable_insertion_planner` | A2D2 w/o insertion quality | +| `--disable_unmasking_planner` | A2D2 w/o unmasking/remasking quality | +| `--joint_training` | train policy + quality heads jointly (no alternation) | + +During buffer generation only sequences passing the `SMILES2PEPTIDE` validity filter are retained; the scalarized multi-objective reward is added to the log Radon–Nikodym derivative of each sequence. Fine-tuning runs on a single GPU (`--devices 1`). + +## Evaluation + +Evaluation runs automatically every `--eval_every_n_epochs` epochs and at the end of training. It samples from the current model and reports the fraction of valid peptides along with the five therapeutic rewards (binding affinity, solubility, non-hemolysis, non-fouling, permeability), saving per-objective curves and `_generation_results.csv` under `/results//`. 
+ +To resume a run, pass `--resume_ckpt /path/to/last.ckpt` (restores epoch, optimizer, and planner state; new checkpoints continue in the same directory). diff --git a/a2d2_pep/config_pep.yaml b/a2d2_pep/config_pep.yaml new file mode 100755 index 0000000000000000000000000000000000000000..b7877e1a8fd72fd8bde7d391e5f6dda073df060f --- /dev/null +++ b/a2d2_pep/config_pep.yaml @@ -0,0 +1,50 @@ +trainer: "any-order-flow" +dataset: "peptides" + +model: + hidden_size: 768 + n_heads: 12 + cond_dim: 128 + dropout: 0.05 + n_blocks: 12 + +interpolant: + type: "any-order" + tokens: null # filled in automatically + pad_token: null # filled in automatically + mask_token: null # filled in automatically + max_length: 1024 + insert_schedule: + type: "linear" + unmask_schedule: + type: "linear" + +training: + only_embed_insert: true + batch_size: 1024 + per_gpu_batch_size: 64 # Gradient accumulation happens automatically + cpus: 4 + learning_rate: 3e-4 + nodes: 1 + devices: 4 + max_steps: 1000000 + weight_decay: 0.03 + # Path to the preprocessed (arrow) pretraining dataset; see README for the download link. + # Relative paths resolve against a2d2_pep/. Defaults to a2d2_pep/data/11M_peptide_smiles. 
+ data_path: "data/11M_peptide_smiles" + checkpoint_dir: "checkpoints/peptides" + save_top_k: 1 + save_every_n_epochs: 1 + loss_fn: + unmask: "elbo" + insert: "expectation" + reset_lr: false + warmup_steps: 2000 + ema_decay: 0.9999 + filter_max_length: false + +wandb: + entity: null # set to your W&B entity, or leave null to use the default + project: "a2d2-pep" + name: "a2d2-pep" + path: "./wandb" diff --git a/a2d2_pep/data/dataloading_for_dynamic_batching.py b/a2d2_pep/data/dataloading_for_dynamic_batching.py new file mode 100755 index 0000000000000000000000000000000000000000..ceffa4370e08296941116dc57932dec358eaec66 --- /dev/null +++ b/a2d2_pep/data/dataloading_for_dynamic_batching.py @@ -0,0 +1,189 @@ +#!/usr/bin/env +import os +import torch +from torch.utils.data import Dataset, DataLoader +from datasets import Dataset,load_from_disk +import sys +import pytorch_lightning as pl +from pep_scoring.tokenizer.my_tokenizers import SMILES_SPE_Tokenizer +from functools import partial +import re + +# Directory containing this file; used to resolve the in-repo tokenizer files. 
+_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +class DynamicBatchingDataset(Dataset): + def __init__(self, dataset_dict, tokenizer): + print('Initializing dataset...') + self.dataset_dict = { + 'attention_mask': [torch.tensor(item) for item in dataset_dict['attention_mask']], + 'input_ids': [torch.tensor(item) for item in dataset_dict['input_ids']], + 'labels': dataset_dict['labels'] + } + self.tokenizer = tokenizer + + def __len__(self): + return len(self.dataset_dict['attention_mask']) + + def __getitem__(self, idx): + if isinstance(idx, int): + return { + 'input_ids': self.dataset_dict['input_ids'][idx], + 'attention_mask': self.dataset_dict['attention_mask'][idx], + 'labels': self.dataset_dict['labels'][idx] + } + elif isinstance(idx, list): + return { + 'input_ids': [self.dataset_dict['input_ids'][i] for i in idx], + 'attention_mask': [self.dataset_dict['attention_mask'][i] for i in idx], + 'labels': [self.dataset_dict['labels'][i] for i in idx] + } + else: + raise ValueError(f"Expected idx to be int or list, but got {type(idx)}") + +class CustomDataModule(pl.LightningDataModule): + def __init__(self, dataset_path, tokenizer): + super().__init__() + self.dataset = load_from_disk(dataset_path) + self.tokenizer = tokenizer + + def peptide_bond_mask(self, smiles_list): + """ + Returns a mask with shape (batch_size, seq_length) that has 1 at the locations + of recognized bonds in the positions dictionary and 0 elsewhere. + + Args: + smiles_list: List of peptide SMILES strings (batch of SMILES strings). + + Returns: + np.ndarray: A mask of shape (batch_size, seq_length) with 1s at bond positions. 
def peptide_token_mask(self, smiles_list, token_lists):
    """
    Build a token-level mask marking tokens that overlap a recognized bond.

    Args:
        smiles_list: List of peptide SMILES strings (one per batch row).
        token_lists: List of token sequences, one per SMILES string.

    Returns:
        torch.Tensor: int tensor of shape (batch_size, longest token sequence)
        with 1 where any character of the token is flagged by
        ``self.peptide_bond_mask`` and 0 elsewhere. The first and last token
        of every sequence are never flagged.
    """
    num_rows = len(smiles_list)
    widest = max(len(tokens) for tokens in token_lists)
    tokenized_masks = torch.zeros((num_rows, widest), dtype=torch.int)
    char_masks = self.peptide_bond_mask(smiles_list)

    for row, char_mask in enumerate(char_masks):
        tokens = token_lists[row]
        final_pos = len(tokens) - 1
        cursor = 0  # character offset into the SMILES string

        for pos, token in enumerate(tokens):
            # Skip the boundary tokens; they do not advance the character cursor.
            if pos == 0 or pos == final_pos:
                continue
            width = len(token)
            if torch.sum(char_mask[cursor:cursor + width]) >= 1:
                tokenized_masks[row][pos] = 1
            cursor += width

    return tokenized_masks
def peptide_bond_mask(smiles_list):
    """
    Build a character-level mask of recognized backbone bonds.

    Each SMILES string is scanned with a fixed, ordered list of regular
    expressions (ester, N-methyl and peptide bond motifs). Patterns are tried
    in list order and a match is accepted only if none of its characters was
    already claimed by an earlier accepted match.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).
            May be empty.

    Returns:
        torch.Tensor: int tensor of shape (batch_size, longest SMILES length)
        with 1 at every character covered by an accepted bond match and 0
        elsewhere. An empty batch yields a tensor of shape (0, 0).
    """
    batch_size = len(smiles_list)
    # default=0 keeps an empty batch from raising ValueError in max().
    max_seq_length = max((len(smiles) for smiles in smiles_list), default=0)
    mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)

    # Order matters: earlier patterns claim their characters first.
    bond_patterns = [
        (r'OC\(=O\)', 'ester'),
        (r'N\(C\)C\(=O\)', 'n_methyl'),
        (r'N[12]C\(=O\)', 'peptide'),  # Pro peptide bonds
        (r'NC\(=O\)', 'peptide'),  # Regular peptide bonds
        (r'C\(=O\)N\(C\)', 'n_methyl'),
        (r'C\(=O\)N[12]?', 'peptide'),
    ]

    for batch_idx, smiles in enumerate(smiles_list):
        used = set()  # character positions already assigned to a bond

        for pattern, _bond_type in bond_patterns:
            for match in re.finditer(pattern, smiles):
                span = range(match.start(), match.end())
                # Accept the match only if it does not overlap an earlier one.
                if used.isdisjoint(span):
                    used.update(span)
                    mask[batch_idx, match.start():match.end()] = 1

    return mask
def extract_amino_acid_sequence(helm_string):
    """
    Pull the monomer names out of every ``PEPTIDE<n>{...}`` section of a HELM string.

    Square brackets around monomers are stripped and the sections are
    concatenated in the order they appear.

    Args:
        helm_string (str): The HELM notation string for a peptide.

    Returns:
        list: Monomer names in sequence order when at least one peptide
        section is found. Otherwise the string
        "Invalid HELM notation or no peptide sequence found." is returned
        (a ``str``, not a list).
    """
    # Capture the text between the braces that follow "PEPTIDE" plus a number.
    chains = re.findall(r'PEPTIDE\d+\{([^}]+)\}', helm_string)

    if not chains:
        return "Invalid HELM notation or no peptide sequence found."

    residues = []
    for chain in chains:
        unbracketed = chain.replace('[', '').replace(']', '')
        residues += unbracketed.split('.')
    return residues
def helm_collate_fn(batch, tokenizer):
    """Collate a batch of HELM records into padded token tensors.

    Args:
        batch: Iterable of mapping-like items, each with a 'HELM' entry.
        tokenizer: Callable invoked as
            ``tokenizer(sequences, return_tensors='pt', padding=True,
            truncation=True, max_length=1024)`` that returns a mapping with
            'input_ids' and 'attention_mask'.

    Returns:
        dict: ``{'input_ids': ..., 'attention_mask': ...}`` taken from the
        tokenizer output.
    """
    sequences = [item['HELM'] for item in batch]

    # The previous per-sequence scan for the longest monomer count was never
    # used; padding and truncation are handled entirely by the tokenizer.
    tokens = tokenizer(sequences, return_tensors='pt', padding=True, truncation=True, max_length=1024)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask']
    }
def val_dataloader(self):
    """Return the validation ``DataLoader``.

    The stored collate function is bound to the module's tokenizer; the
    loader uses ``self.batch_size``, 8 workers and pinned memory.
    """
    bound_collate = partial(self.collate_fn, tokenizer=self.tokenizer)
    loader_options = {
        'batch_size': self.batch_size,
        'collate_fn': bound_collate,
        'num_workers': 8,
        'pin_memory': True,
    }
    return DataLoader(self.val_dataset, **loader_options)
'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV', + 'tfr': 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF', + 'gfap': 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM', + 'glp1': 
'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS', + 'glast': 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM', + 'ncam': 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF', + 'cereblon': 
def load_finetuned_model(checkpoint_path, pretrained_ckpt_path, device='cuda'):
    """Rebuild the finetuned policy model from a Lightning checkpoint.

    The model config is read from the pretrained base checkpoint, the
    schedule-related training options are filled in from the finetuning
    ``args`` stored in the finetuned checkpoint (with fallbacks), and every
    ``policy_model.*`` entry of the finetuned ``state_dict`` is loaded
    non-strictly into a fresh ``AnyOrderInsertionFlowModuleFT``.

    Args:
        checkpoint_path: Path to the finetuned Lightning checkpoint.
        pretrained_ckpt_path: Path to the pretrained base checkpoint; it must
            contain the config under ``hyper_parameters.config`` or ``config``.
        device: Device the model is moved to before it is returned.

    Returns:
        tuple: ``(policy_model, args, config)`` where ``policy_model`` is in
        eval mode and ``args`` may be ``None`` if the checkpoint stored none.

    Raises:
        ValueError: if the base checkpoint carries no config.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    hparams = ckpt.get('hyper_parameters', {})
    args = hparams.get('args', None)

    # Load pretrained base checkpoint to get config
    base_ckpt = torch.load(pretrained_ckpt_path, map_location='cpu', weights_only=False)
    if 'hyper_parameters' in base_ckpt:
        config = base_ckpt['hyper_parameters']['config']
    elif 'config' in base_ckpt:
        config = base_ckpt['config']
    else:
        raise ValueError("Cannot find config in base checkpoint")

    from omegaconf import OmegaConf, DictConfig
    if not OmegaConf.is_config(config):
        config = DictConfig(config)

    # Temporarily lift struct mode so new training keys can be added.
    OmegaConf.set_struct(config, False)
    schedule_defaults = (
        ('use_adaptive_schedule', True),
        ('schedule_hidden_dim', 256),
        ('schedule_num_layers', 2),
        ('schedule_loss_weight', 0.1),
        ('freeze_base_model', False),
        ('schedule_warmup_epochs', 0),
    )
    for key, fallback in schedule_defaults:
        # getattr tolerates args being None and falls back to the default.
        setattr(config.training, key, getattr(args, key, fallback))
    OmegaConf.set_struct(config, True)

    disable_planner = getattr(args, 'disable_planner', False)

    policy_model = AnyOrderInsertionFlowModuleFT(
        config=config,
        args=args,
        pretrained_checkpoint=pretrained_ckpt_path,
        insertion_planner=not disable_planner,
    )

    # Load finetuned weights: keep only the policy sub-module and strip its prefix.
    prefix = 'policy_model.'
    policy_state = {
        key[len(prefix):]: value
        for key, value in ckpt['state_dict'].items()
        if key.startswith(prefix)
    }
    if not policy_state:
        # strict=False would otherwise turn this into a silent no-op.
        print(f"WARNING: no '{prefix}*' entries found in {checkpoint_path}; "
              "no finetuned weights were loaded.")
    load_result = policy_model.load_state_dict(policy_state, strict=False)
    missing = getattr(load_result, 'missing_keys', None)
    unexpected = getattr(load_result, 'unexpected_keys', None)
    if missing or unexpected:
        print(f"WARNING: non-strict load left {len(missing or [])} missing and "
              f"{len(unexpected or [])} unexpected keys.")

    policy_model = policy_model.to(device)
    policy_model.eval()

    return policy_model, args, config
+ Returns a dict with: validity, affinity, sol, hemo, nf, permeability, sampling_time + """ + all_affinity = [] + all_sol = [] + all_hemo = [] + all_nf = [] + all_permeability = [] + all_valid_seqs = [] + total_valid = 0 + total_generated = 0 + total_time = 0.0 + + num_batches = (num_samples + batch_size - 1) // batch_size + remaining = num_samples + + for b in range(num_batches): + bs = min(batch_size, remaining) + remaining -= bs + + t_start = time.time() + result = sample_peptides_eval( + model=policy_model, + reward_model=reward_model, + analyzer=analyzer, + tokenizer=tokenizer, + steps=total_num_steps, + mask=policy_model.interpolant.mask_token, + pad=policy_model.interpolant.pad_token, + batch_size=bs, + max_length=max_length, + quality_mode=quality_mode, + num_remasking=num_remasking, + quality_threshold=quality_threshold, + unmask_quality_threshold=unmask_quality_threshold, + return_valid=True, + ) + t_end = time.time() + + # Unpack: validSequences, affinity, sol, hemo, nf, permeability, valid_fraction + valid_seqs, affinity, sol, hemo, nf, permeability, valid_fraction = result + + batch_valid = len(valid_seqs) + total_valid += batch_valid + total_generated += bs + total_time += (t_end - t_start) + all_valid_seqs.extend(valid_seqs) + + if isinstance(affinity, (list, np.ndarray)) and len(affinity) > 0: + all_affinity.extend(affinity if isinstance(affinity, list) else affinity.tolist()) + all_sol.extend(sol if isinstance(sol, list) else sol.tolist()) + all_hemo.extend(hemo if isinstance(hemo, list) else hemo.tolist()) + all_nf.extend(nf if isinstance(nf, list) else nf.tolist()) + all_permeability.extend(permeability if isinstance(permeability, list) else permeability.tolist()) + + print(f" Batch {b+1}/{num_batches}: {batch_valid}/{bs} valid, " + f"time={t_end - t_start:.1f}s") + + validity = total_valid / total_generated * 100.0 if total_generated > 0 else 0.0 + + # Uniqueness (% of valid sequences that are unique) and + # Diversity (1 - mean pairwise 
def _resolve_quality_mode(args):
    """Translate the planner ablation flags into the sampler's quality_mode string."""
    if args.disable_planner:
        return "none"
    if args.disable_insertion_planner and args.disable_unmasking_planner:
        return "none"
    if args.disable_insertion_planner:
        return "unmasking_only"
    if args.disable_unmasking_planner:
        return "insertion_only"
    return "both"


def main():
    """Command-line entry point: sample peptides from a finetuned checkpoint and save metrics.

    Parses the CLI options, loads the finetuned policy model, samples
    ``--num_samples`` peptides against the selected target protein, prints a
    summary table and writes a one-row CSV named
    ``eval_metrics_<tag>_<prot_name>.csv`` into ``--output_dir`` (default: the
    checkpoint's directory).

    Raises:
        ValueError: if ``--prot_name`` is not a key of ``PROTEINS`` and no
            ``--prot_seq`` is given.
    """
    parser = argparse.ArgumentParser(description="Evaluate a finetuned peptide checkpoint")
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to the finetuned Lightning checkpoint (e.g., last.ckpt)')
    parser.add_argument('--pretrained_ckpt', type=str,
                        default=os.path.join(REPO_ROOT, 'pretrained', 'anylength_pep.ckpt'),
                        help='Path to the pretrained base model checkpoint')
    parser.add_argument('--num_samples', type=int, default=500,
                        help='Number of peptides to sample')
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Batch size for sampling')
    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--total_num_steps', type=int, default=256)
    parser.add_argument('--num_remasking', type=int, default=3)
    parser.add_argument('--quality_threshold', type=float, default=0.5,
                        help='Threshold for insertion quality filtering during sampling')
    parser.add_argument('--unmask_quality_threshold', type=float, default=None,
                        help='If set, gate unmasking/remasking by confidence: remask '
                             'ALL clean tokens whose unmasking confidence is below this '
                             'threshold, regardless of the schedule budget. If unset '
                             '(default), remasking is purely schedule-driven (count-based).')
    parser.add_argument('--prot_name', type=str, default='glast',
                        help='Target protein name (must be one of: ' + ', '.join(PROTEINS.keys()) + ')')
    parser.add_argument('--prot_seq', type=str, default=None,
                        help='Custom protein sequence (overrides --prot_name)')
    parser.add_argument('--disable_planner', action='store_true',
                        help='If set, disable remasking during evaluation')
    parser.add_argument('--disable_insertion_planner', action='store_true',
                        help='If set, disable insertion quality filtering during evaluation')
    parser.add_argument('--disable_unmasking_planner', action='store_true',
                        help='If set, disable unmasking confidence planner during evaluation')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Directory to save results CSV. Defaults to checkpoint directory.')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    set_seed(args.seed, use_cuda=True)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # Map flags to quality_mode
    quality_mode = _resolve_quality_mode(args)

    print(f"Loading checkpoint: {args.checkpoint_path}")
    print(f"Pretrained base: {args.pretrained_ckpt}")
    print(f"Quality mode: {quality_mode}")

    policy_model, train_args, config = load_finetuned_model(
        args.checkpoint_path, args.pretrained_ckpt, device=device
    )

    # Setup tokenizer, reward model, analyzer
    tokenizer_dir = os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer')
    tokenizer = SMILES_SPE_Tokenizer(
        os.path.join(tokenizer_dir, 'new_vocab.txt'),
        os.path.join(tokenizer_dir, 'new_splits.txt')
    )

    # A custom sequence overrides the lookup; prot_name is still used to label the output.
    prot_name = args.prot_name
    if args.prot_seq is not None:
        prot = args.prot_seq
    elif prot_name not in PROTEINS:
        raise ValueError(f"Unknown protein: {prot_name}. Choose from: {list(PROTEINS.keys())}")
    else:
        prot = PROTEINS[prot_name]

    score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
    reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device=device)
    analyzer = PeptideAnalyzer()

    print(f"\nSampling {args.num_samples} peptides (quality_mode={quality_mode}, target={prot_name})...")

    metrics = evaluate_checkpoint(
        policy_model=policy_model,
        tokenizer=tokenizer,
        reward_model=reward_model,
        analyzer=analyzer,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        max_length=args.max_length,
        total_num_steps=args.total_num_steps,
        quality_mode=quality_mode,
        num_remasking=args.num_remasking,
        quality_threshold=args.quality_threshold,
        unmask_quality_threshold=args.unmask_quality_threshold,
        device=device,
    )

    # Print summary table
    print("\n" + "=" * 60)
    print(" De Novo Peptide Generation Results")
    print("=" * 60)
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f" {k:<30s}: {v:.4f}")
        else:
            print(f" {k:<30s}: {v}")
    print("=" * 60)

    # Save results. dirname() is '' for a bare filename, which makedirs rejects,
    # so fall back to the current directory.
    output_dir = args.output_dir or os.path.dirname(args.checkpoint_path) or '.'
    os.makedirs(output_dir, exist_ok=True)

    # Derive the tag from quality_mode so the file name always matches what was
    # actually evaluated (both ablation flags together mean "none", not
    # "no_insertion_planner").
    tag = {
        "none": "no_planner",
        "unmasking_only": "no_insertion_planner",
        "insertion_only": "no_unmasking_planner",
        "both": "with_planner",
    }[quality_mode]
    if args.unmask_quality_threshold is not None:
        tag += f"_ut{args.unmask_quality_threshold:g}"
    # Record the sweep parameter in the saved row for traceability.
    metrics['unmask_quality_threshold'] = args.unmask_quality_threshold
    metrics['quality_threshold'] = args.quality_threshold
    metrics_path = os.path.join(output_dir, f'eval_metrics_{tag}_{prot_name}.csv')
    pd.DataFrame([metrics]).to_csv(metrics_path, index=False)
    print(f"Metrics saved to: {metrics_path}")
def freeze_policy_model(self):
    """Disable gradients for every policy parameter outside the planner.

    Parameters whose name starts with ``'planner.'`` are left untouched.
    """
    non_planner_params = (
        weight
        for param_name, weight in self.policy_model.named_parameters()
        if not param_name.startswith('planner.')
    )
    for weight in non_planner_params:
        weight.requires_grad = False
def configure_optimizers(self):
    """Create an AdamW optimizer with separate groups for policy and planner.

    Group 0 holds every parameter whose name does not start with
    ``'planner.'`` and uses ``args.learning_rate``; group 1 holds the
    ``'planner.'`` parameters and uses ``args.planner_learning_rate`` when
    that attribute exists, otherwise ``args.learning_rate``.
    """
    base_lr = self.args.learning_rate
    head_lr = getattr(self.args, 'planner_learning_rate', base_lr)

    # Split parameters by name prefix, preserving their registration order.
    named = list(self.policy_model.named_parameters())
    head_weights = [weight for name, weight in named if name.startswith('planner.')]
    body_weights = [weight for name, weight in named if not name.startswith('planner.')]

    return torch.optim.AdamW([
        {'params': body_weights, 'lr': base_lr},
        {'params': head_weights, 'lr': head_lr},
    ])
def on_train_epoch_start(self):
    """Select what trains this epoch, then refresh the sample buffer if due.

    Three regimes are chosen from ``self.args``:

    * ``disable_planner``: policy trainable, planner frozen.
    * ``joint_training``: policy and planner both trainable.
    * otherwise: policy and planner alternate, switching every
      ``self.alternation_frequency`` epochs and starting with the policy.

    ``self.train_policy`` records the outcome for ``training_step``. After
    that, the buffer is regenerated when it does not exist yet or when the
    epoch index is a multiple of ``args.resample_every_n_step``.
    """
    epoch = self.current_epoch
    announce = self.global_rank == 0

    planner_off = self.args.disable_planner
    joint = (not planner_off) and getattr(self.args, 'joint_training', False)

    if planner_off or joint:
        # No alternation in either regime: the policy always trains, and the
        # planner is frozen (policy-only ablation) or trainable (joint mode).
        # In joint mode train_policy is only a marker; training_step adds the
        # planner loss when joint_training is set.
        self.train_policy = True
        self.unfreeze_policy_model()
        if planner_off:
            self.freeze_planner_model()
        else:
            self.unfreeze_planner_model()
        if announce and epoch == 0:
            if planner_off:
                print(f"[FINETUNE_QUALITY] Training ONLY policy model (planner frozen, no remasking)")
            else:
                print(f"[FINETUNE_QUALITY] JOINT TRAINING: policy + planner trained together (no alternation)")
    else:
        # Alternate from epoch 0: even cycles train the policy, odd cycles
        # train the planner.
        phase = (epoch // self.alternation_frequency) % 2
        self.train_policy = phase == 0
        if not self.train_policy:
            self.freeze_policy_model()
            self.unfreeze_planner_model()
            if announce:
                print(f"[ALTERNATION] Epoch {self.current_epoch}: Training PLANNER model (policy frozen)")
        else:
            self.unfreeze_policy_model()
            self.freeze_planner_model()
            if announce:
                print(f"[ALTERNATION] Epoch {self.current_epoch}: Training POLICY model (planner frozen)")

    # Resample only when there is no buffer yet or the schedule says so.
    buffer_due = self.x_saved is None or epoch % self.args.resample_every_n_step == 0
    if not buffer_due:
        return

    self._generate_buffer()
    # Synchronize all ranks after buffer generation to prevent NCCL timeout.
    if self.trainer and self.trainer.world_size > 1:
        torch.distributed.barrier()
print(f"[BUFFER] GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB") + print(f"[BUFFER] GPU memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB") + if not is_pool: + print(f"[BUFFER] Each of {world_size} ranks will generate {samples_per_gpu} samples") + + max_attempts = getattr(self.args, 'max_buffer_attempts', 100) # cap wasted GPU / infinite loop + starvation_patience = getattr(self.args, 'buffer_starvation_patience', 10) + attempts = 0 + + import time + while total_accumulated < samples_per_gpu and attempts < max_attempts: + attempts += 1 + if rank == 0: + print(f"[BUFFER] rank={rank} starting sampling attempt {attempts} at {time.strftime('%H:%M:%S')}") + + start_time = time.time() + + # new elbo loss + if self.args.elbo_rnd: + x_final, _, final_rewards, trace = \ + sample_peptides_buffer( + self.policy_model, + self.reward_model, self.analyzer, + self.tokenizer, + steps=self.args.total_num_steps, + mask=self.policy_model.interpolant.mask_token, + pad=self.policy_model.interpolant.pad_token, + batch_size=self.args.batch_size, + max_length=self.args.max_length, + # Buffer generation never uses the quality heads (planner): + # the backbone must train on raw policy samples so that a + # poorly-trained planner can't corrupt the backbone's data. 
+ quality_mode="none", + compute_rnd=False, + alpha=self.args.alpha, + num_remasking=self.args.num_remasking, + min_length=self.args.min_length, + ) + if x_final.shape[0] > 0: + with torch.no_grad(): + noised = self.policy_model.prepare_noised_sample( + x_final, num_samples=self.args.elbo_rnd_num_samples) + policy_loss = self.policy_model.compute_loss_from_noised(noised) + pretrained_loss = self.pretrained.compute_loss_from_noised(noised) + log_rnd = (pretrained_loss - policy_loss) + (final_rewards / self.args.alpha) + else: + log_rnd = torch.empty((0,), dtype=torch.float32, device=x_final.device) + else: + x_final, log_rnd, final_rewards, trace = \ + sample_peptides_buffer( + self.policy_model, + self.reward_model, self.analyzer, + self.tokenizer, + steps=self.args.total_num_steps, + mask=self.policy_model.interpolant.mask_token, + pad=self.policy_model.interpolant.pad_token, + batch_size=self.args.batch_size, + max_length=self.args.max_length, + # Buffer generation never uses the quality heads (planner): + # the backbone must train on raw policy samples so that a + # poorly-trained planner can't corrupt the backbone's data. + quality_mode="none", + compute_rnd=True, + pretrained=self.pretrained, + alpha=self.args.alpha, + num_remasking=self.args.num_remasking, + min_length=self.args.min_length, + ) + elapsed = time.time() - start_time + if rank == 0: + print(f"[BUFFER] rank={rank} sampling took {elapsed:.1f}s") + + n_valid = x_final.shape[0] + if n_valid > 0: + accumulated_x.append(x_final) + accumulated_log_rnd.append(log_rnd) + accumulated_rewards.append(final_rewards) + total_accumulated += n_valid + + if rank == 0: + print(f"[BUFFER] rank={rank} epoch={self.current_epoch} quality_mode=none (heads disabled for buffer gen) accumulated={total_accumulated} / {samples_per_gpu} (batch yielded {n_valid} valid) attempt={attempts}") + + # Starvation guard: if nothing valid comes through (e.g. 
the length + # cutoff is too aggressive for a collapsed policy), stop grinding GPU + # hours and fail fast with an actionable message. + if attempts >= starvation_patience and total_accumulated == 0: + if rank == 0: + print(f"[BUFFER STARVATION] 0 valid samples after {attempts} attempts " + f"(min_peptide_bonds={getattr(self.args, 'min_peptide_bonds', 4)}). " + f"Aborting refill early — lower --min_peptide_bonds or check the policy.") + break + + if total_accumulated == 0: + raise RuntimeError(f"[BUFFER ERROR] Rank {rank}: No valid sequences generated after {attempts} attempts. Check sampling function and reward model.") + + if total_accumulated < samples_per_gpu: + print(f"[BUFFER WARNING] Rank {rank}: Only generated {total_accumulated}/{samples_per_gpu} sequences after {attempts} attempts") + + new_x = torch.cat(accumulated_x, dim=0)[:samples_per_gpu] + new_log_rnd = torch.cat(accumulated_log_rnd, dim=0)[:samples_per_gpu] + new_rewards = torch.cat(accumulated_rewards, dim=0)[:samples_per_gpu] + + del accumulated_x, accumulated_log_rnd, accumulated_rewards + torch.cuda.empty_cache() + + # Pool mode (after init): replace a random subset of the existing pool. + # Classic mode / pool init: overwrite the buffer. + if is_pool and not is_init: + actual_new = min(new_x.shape[0], self.x_saved.shape[0]) + indices = torch.randperm(self.x_saved.shape[0], device=self.x_saved.device)[:actual_new] + self.x_saved[indices] = new_x[:actual_new] + self.log_rnd_saved[indices] = new_log_rnd[:actual_new] + self.final_rewards_saved[indices] = new_rewards[:actual_new] + if rank == 0: + print(f"[POOL] Replaced {actual_new}/{self.x_saved.shape[0]} sequences, reward mean={self.final_rewards_saved.mean():.4f}") + else: + self.x_saved = new_x + self.log_rnd_saved = new_log_rnd + self.final_rewards_saved = new_rewards + + # Sanity check: median length (non-pad tokens) of buffered peptides. 
def training_step(self, batch, batch_idx):
    """Run one optimisation step on a mini-batch drawn from the saved buffer.

    ``batch`` and ``batch_idx`` are not used; data comes from
    ``self.x_saved`` / ``self.log_rnd_saved``.

    Depending on ``self.train_policy`` and ``args.joint_training`` the step
    computes the policy loss, a planner loss, or both, logs each component,
    and returns the loss to optimise (the sum of both in joint mode).
    """
    args = self.args

    # Draw a random subset of the buffer to avoid OOM; if the buffer is not
    # larger than the requested mini-batch, use all of it.
    pool_len = self.x_saved.shape[0]
    draw = getattr(args, 'training_mini_batch_size', 6)
    if pool_len > draw:
        picks = torch.randperm(pool_len, device=self.x_saved.device)[:draw]
        x_final, log_rnd = self.x_saved[picks], self.log_rnd_saved[picks]
    else:
        x_final, log_rnd = self.x_saved, self.log_rnd_saved

    joint = getattr(args, 'joint_training', False)
    step_log = dict(on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
    loss_opts = dict(
        num_replicates=args.wdce_num_replicates,
        centering=args.centering,
        centering_strength=args.centering_strength,
    )

    policy_loss = None
    planner_loss = None

    if self.train_policy:
        # Train policy with WDCE loss.
        policy_loss = self.policy_model.loss_wdce_flexible(log_rnd, x_final, **loss_opts)
        self.log('train/policy_loss', policy_loss, **step_log)

    if joint or not self.train_policy:
        # Choose the planner objective from the ablation flags. A disabled
        # head is logged as 0.0 so the metric keys stay the same.
        if args.disable_insertion_planner:
            # Only the unmasking/remasking planner (no insertion head).
            planner_loss = self.policy_model.loss_planner_flexible(log_rnd, x_final, **loss_opts)
            unmask_metric, insert_metric = planner_loss, 0.0
        elif args.disable_unmasking_planner:
            # Only the insertion planner: backprop the insertion loss alone.
            _, insert_loss, _ = self.policy_model.loss_insert_planner_flexible(
                log_rnd, x_final, **loss_opts)
            planner_loss = insert_loss
            unmask_metric, insert_metric = 0.0, insert_loss
        else:
            # Full planner: remasking + insertion.
            unmask_metric, insert_metric, planner_loss = \
                self.policy_model.loss_insert_planner_flexible(log_rnd, x_final, **loss_opts)

        self.log('train/planner_unmask_loss', unmask_metric, **step_log)
        self.log('train/planner_insert_loss', insert_metric, **step_log)
        self.log('train/planner_loss', planner_loss, **step_log)

    # Combine losses; mode is logged as 0.0 policy / 1.0 planner / 0.5 joint.
    if joint:
        loss, mode_value = policy_loss + planner_loss, 0.5
    elif self.train_policy:
        loss, mode_value = policy_loss, 0.0
    else:
        loss, mode_value = planner_loss, 1.0

    self.log('train_loss', loss, **step_log)
    self.log('train/mode', mode_value, prog_bar=True, sync_dist=True)

    return loss
+ num_valid = len(valid_seqs) + unique_seqs = list(set(valid_seqs)) + num_unique = len(unique_seqs) + uniqueness = num_unique / num_valid * 100.0 if num_valid > 0 else 0.0 + diversity = self._diversity_evaluator(unique_seqs) if num_unique > 1 else 0.0 + + # Append to logs + self.affinity_log.append(affinity) + self.sol_log.append(sol) + self.hemo_log.append(hemo) + self.nf_log.append(nf) + self.permeability_log.append(permeability) + self.valid_fraction_log.append(valid_fraction) + self.uniqueness_log.append(uniqueness) + self.diversity_log.append(diversity) + + # Compute reward stats + mean_reward = self.final_rewards_saved.mean().item() + min_reward = self.final_rewards_saved.min().item() + max_reward = self.final_rewards_saved.max().item() + median_reward = self.final_rewards_saved.median().item() + + # Log metrics + self.log_dict({ + "eval/affinity": np.mean(affinity), + "eval/sol": np.mean(sol), + "eval/hemo": np.mean(hemo), + "eval/nf": np.mean(nf), + "eval/permeability": np.mean(permeability), + "eval/valid_fraction": valid_fraction, + "eval/uniqueness": uniqueness, + "eval/diversity": diversity, + "eval/mean_reward_search": mean_reward, + "eval/min_reward_search": min_reward, + "eval/max_reward_search": max_reward, + "eval/median_reward_search": median_reward + }) + + print(f"epoch {self.current_epoch} | affinity {np.mean(affinity):.4f} | " + f"sol {np.mean(sol):.4f} | hemo {np.mean(hemo):.4f} | " + f"nf {np.mean(nf):.4f} | permeability {np.mean(permeability):.4f} | " + f"valid {valid_fraction:.4f} | uniq {uniqueness:.2f}% | div {diversity:.4f}") + + def on_fit_end(self): + """Called at the end of training - save results.""" + if self.global_rank == 0: + # Save logs and plot + base_path = self.args.base_path + plot_path = f'{base_path}/results/{self.args.run_name}' + os.makedirs(plot_path, exist_ok=True) + + output_log_path = f'{plot_path}/log_{self.filename}.csv' + save_logs_to_file(self.valid_fraction_log, self.affinity_log, + self.sol_log, self.hemo_log, 
def save_logs_to_file(valid_fraction_log, affinity_log,
                      sol_log, hemo_log, nf_log,
                      permeability_log, output_path,
                      uniqueness_log=None, diversity_log=None):
    """Write the per-evaluation metric logs to a CSV file.

    Each log is a sequence with one entry per evaluation; the "Iteration"
    column numbers the rows from 1 using the length of
    ``valid_fraction_log``.

    Args:
        valid_fraction_log: Values for the "Valid Fraction" column.
        affinity_log: Values for the "Binding Affinity" column.
        sol_log: Values for the "Solubility" column.
        hemo_log: Values for the "Hemolysis" column.
        nf_log: Values for the "Nonfouling" column.
        permeability_log: Values for the "Permeability" column.
        output_path: Destination CSV path. Missing parent directories are
            created; a bare filename is written to the current directory.
        uniqueness_log: Optional values for the "Uniqueness (%)" column.
        diversity_log: Optional values for the "Diversity" column.
    """
    # os.path.dirname returns '' for a bare filename, and os.makedirs('')
    # raises FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    log_data = {
        "Iteration": list(range(1, len(valid_fraction_log) + 1)),
        "Valid Fraction": valid_fraction_log,
        "Binding Affinity": affinity_log,
        "Solubility": sol_log,
        "Hemolysis": hemo_log,
        "Nonfouling": nf_log,
        "Permeability": permeability_log,
    }
    if uniqueness_log is not None:
        log_data["Uniqueness (%)"] = uniqueness_log
    if diversity_log is not None:
        log_data["Diversity"] = diversity_log

    pd.DataFrame(log_data).to_csv(output_path, index=False)
argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + argparser.add_argument('--base_path', type=str, default=REPO_ROOT) + argparser.add_argument('--learning_rate', type=float, default=1e-4) + argparser.add_argument('--num_epochs', type=int, default=100) + argparser.add_argument('--num_accum_steps', type=int, default=4) + argparser.add_argument('--truncate_steps', type=int, default=50) + argparser.add_argument("--truncate_kl", type=str2bool, default=False) + argparser.add_argument('--gumbel_temp', type=float, default=1.0) + argparser.add_argument('--gradnorm_clip', type=float, default=1.0) + argparser.add_argument('--batch_size', type=int, default=50) + argparser.add_argument('--name', type=str, default='debug') + argparser.add_argument('--total_num_steps', type=int, default=128) + argparser.add_argument('--copy_flag_temp', type=float, default=None) + argparser.add_argument('--save_every_n_epochs', type=int, default=10) + argparser.add_argument('--alpha_schedule_warmup', type=int, default=0) + argparser.add_argument("--seed", type=int, default=0) + # new + argparser.add_argument('--run_name', type=str, default='peptides') + argparser.add_argument("--save_path_dir", default=os.path.join(REPO_ROOT, "checkpoints", "finetune_peptides"), type=str) + # mcts + argparser.add_argument('--num_sequences', type=int, default=10) + argparser.add_argument('--max_length', type=int, default=1024) + argparser.add_argument('--min_length', type=int, default=0, + help='Minimum sequence length (in SMILES SPE tokens). ' + 'Samples shorter than this are dropped from the buffer. 
0 disables the filter.') + argparser.add_argument('--num_children', type=int, default=50) + argparser.add_argument('--num_iter', type=int, default=30) + argparser.add_argument('--seq_length', type=int, default=1024) + argparser.add_argument('--time_conditioning', action='store_true', default=False) + argparser.add_argument('--mcts_sampling', type=int, default=0) # for batched categorical sampling: '0' means gumbel noise + argparser.add_argument('--buffer_size', type=int, default=100) + argparser.add_argument('--wdce_num_replicates', type=int, default=16) + argparser.add_argument('--noise_removal', action='store_true', default=False) + argparser.add_argument('--grad_clip', action='store_true', default=False) + argparser.add_argument('--resample_every_n_step', type=int, default=10) + argparser.add_argument('--exploration', type=float, default=0.1) + argparser.add_argument('--reset_every_n_step', type=int, default=100) + argparser.add_argument('--alpha', type=float, default=0.01) + argparser.add_argument('--scalarization', type=str, default='sum') + argparser.add_argument('--no_mcts', action='store_true', default=False) + argparser.add_argument("--centering", action='store_true', default=False) + argparser.add_argument("--centering_strength", type=float, default=1.0) + + # ELBO-based log_rnd estimation + argparser.add_argument('--elbo_rnd', action='store_true', default=False, + help='If set, compute log_rnd via ELBO instead of trajectory rollout') + argparser.add_argument('--elbo_rnd_num_samples', type=int, default=16, + help='Number of noisy time samples per sequence for ELBO-based log_rnd estimation') + + # adaptive schedule parameters + argparser.add_argument('--use_adaptive_schedule', action='store_true', default=True) + argparser.add_argument('--schedule_hidden_dim', type=int, default=256) + argparser.add_argument('--schedule_num_layers', type=int, default=2) + argparser.add_argument('--schedule_loss_weight', type=float, default=0.1) + 
argparser.add_argument('--adaptive_threshold', type=float, default=0.5) + argparser.add_argument('--freeze_base_model', action='store_true', default=False) + argparser.add_argument('--schedule_warmup_epochs', type=int, default=0, help='Number of initial epochs to train WITHOUT remasking in buffer generation') + argparser.add_argument('--alternation_frequency', type=int, default=20, help='Number of epochs to train each model before alternating (1=alternate every epoch)') + argparser.add_argument('--planner_learning_rate', type=float, default=None, help='Separate learning rate for planner heads (defaults to --learning_rate if not set)') + + # objectives + argparser.add_argument('--num_obj', type=int, default=5) + argparser.add_argument('--prot_seq', type=str, default=None) + argparser.add_argument('--prot_name', type=str, default='glast', + help='Protein target name. Looked up in PROTEINS dict unless --prot_seq is given.') + argparser.add_argument('--devices', type=int, default=-1) + argparser.add_argument('--checkpoint_path', type=str, default=None) + argparser.add_argument('--resume_ckpt', type=str, default=None, + help='Path to a Lightning last.ckpt to resume training from (restores epoch/optimizer/planner state). ' + 'New checkpoints continue in the same directory as this checkpoint.') + + # remasking + argparser.add_argument('--num_remasking', type=int, default=5) + argparser.add_argument('--quality_threshold', type=float, default=1) + + # length cutoff (peptide-bond filter) + buffer starvation guard + argparser.add_argument('--min_peptide_bonds', type=int, default=4, + help='Minimum backbone peptide bonds for a sample to count as valid. ' + '0 disables the cutoff. 
Filters degenerate short reward-hacked molecules.') + argparser.add_argument('--max_buffer_attempts', type=int, default=100, + help='Max sampling rounds per buffer refill before giving up (caps wasted GPU when validity is low).') + argparser.add_argument('--buffer_starvation_patience', type=int, default=10, + help='If 0 valid samples after this many rounds, abort the refill early (starvation guard).') + + # planner ablation flags + argparser.add_argument('--disable_planner', action='store_true', help='If set, disable remasking completely and only train policy (not planner) for quality optimization') + argparser.add_argument('--disable_insertion_planner', action='store_true', help='Ablation: disable insertion quality filtering but keep unmasking/remasking planner') + argparser.add_argument('--disable_unmasking_planner', action='store_true', help='Ablation: disable unmasking/remasking planner but keep insertion quality filtering') + argparser.add_argument('--joint_training', action='store_true', help='Ablation: train policy and planner jointly each step (no alternation, both unfrozen, summed loss). Incompatible with --disable_planner.') + + # performance optimization + argparser.add_argument('--eval_every_n_epochs', type=int, default=5, help='Evaluate only every N epochs to save time') + argparser.add_argument('--num_training_steps_per_epoch', type=int, default=10, help='Number of gradient updates per epoch') + argparser.add_argument('--training_mini_batch_size', type=int, default=6, help='Mini-batch size for training from buffer to avoid OOM') + argparser.add_argument('--pool_size', type=int, default=0, + help='If >0, maintain a persistent pool of this size and refresh a fraction each resample step (0=disabled, classic buffer). 
Helps preserve uniqueness/diversity over training.') + argparser.add_argument('--pool_refresh_fraction', type=float, default=0.2, + help='Fraction of pool to replace each resample step (only used when pool_size>0)') + + args = argparser.parse_args() + + # Default planner LR to policy LR if not specified + if args.planner_learning_rate is None: + args.planner_learning_rate = args.learning_rate + + # Set seed + pl.seed_everything(args.seed) + + # Load models + checkpoint_path = args.checkpoint_path if args.checkpoint_path else \ + os.path.join(REPO_ROOT, 'pretrained', 'anylength_pep.ckpt') + + # Update args.checkpoint_path to ensure it's saved in hyperparameters for later inference + args.checkpoint_path = checkpoint_path + + PROTEINS = { + 'amhr': 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV', + 'tfr': 
'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF', + 'gfap': 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM', + 'glp1': 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS', + 'glast': 
'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM', + 'ncam': 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF', + 'cereblon': 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL', + 'ligase': 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS', + 'skp2': 
'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL', + } + + if args.prot_seq is not None: + prot = args.prot_seq + prot_name = args.prot_name + else: + prot_name = args.prot_name + if prot_name not in PROTEINS: + raise ValueError(f"Unknown protein: {prot_name}. Choose from: {list(PROTEINS.keys())}") + prot = PROTEINS[prot_name] + filename = prot_name + + curr_time = datetime.now().strftime("%Y%m%d_%H%M%S") + + if args.no_mcts: + args.run_name = f'{curr_time}_adaptive_{prot_name}_resample{args.resample_every_n_step}_no-mcts' + else: + args.run_name = f'{curr_time}_adaptive_{prot_name}_resample{args.resample_every_n_step}_buffer{args.buffer_size}_numiter{args.num_iter}_children{args.num_children}' + + # Append ablation tags to run name for easy identification + if args.disable_planner: + args.run_name += '_no_planner' + if args.disable_insertion_planner: + args.run_name += '_no_insertion_planner' + if args.disable_unmasking_planner: + args.run_name += '_no_unmasking_planner' + if args.joint_training: + if args.disable_planner: + raise ValueError("--joint_training is incompatible with --disable_planner (no planner to train)") + args.run_name += '_joint_training' + + # When resuming, continue writing checkpoints into the SAME directory as the + # checkpoint we resume from (keeps model-{epoch}.ckpt contiguous) instead of + # spawning a fresh timestamped run directory. 
+ if args.resume_ckpt: + args.save_path = os.path.dirname(os.path.abspath(args.resume_ckpt)) + args.run_name = os.path.basename(args.save_path) + else: + args.save_path = os.path.join(args.save_path_dir, args.run_name) + os.makedirs(args.save_path, exist_ok=True) + set_seed(args.seed, use_cuda=False) # Don't init CUDA before Lightning spawns DDP workers + + # Initialize the model + print("Loading models..") + + # Load pretrained model for reference (frozen) + pretrained = AnyOrderInsertionFlowModule.load_from_checkpoint(checkpoint_path, + map_location='cpu', + weights_only=False) + pretrained.eval() + for param in pretrained.parameters(): + param.requires_grad = False + + # Load checkpoint to extract config + checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False) + if 'hyper_parameters' in checkpoint: + config = checkpoint['hyper_parameters']['config'] + elif 'config' in checkpoint: + config = checkpoint['config'] + else: + raise ValueError("Cannot find config in checkpoint") + + # Update config for adaptive schedule + from omegaconf import OmegaConf + if not OmegaConf.is_config(config): + from omegaconf import DictConfig + config = DictConfig(config) + + # Disable struct mode to allow adding new keys + OmegaConf.set_struct(config, False) + + config.training.use_adaptive_schedule = args.use_adaptive_schedule + config.training.schedule_hidden_dim = args.schedule_hidden_dim + config.training.schedule_num_layers = args.schedule_num_layers + config.training.schedule_loss_weight = args.schedule_loss_weight + config.training.freeze_base_model = args.freeze_base_model + config.training.schedule_warmup_epochs = args.schedule_warmup_epochs + + # Re-enable struct mode + OmegaConf.set_struct(config, True) + + # Initialize policy model with adaptive schedule + policy_model = AnyOrderInsertionFlowModuleFT( + config=config, + args=args, + pretrained_checkpoint=checkpoint_path, + insertion_planner=True, + ) + + # define mcts + score_func_names = 
['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability'] + + tokenizer = SMILES_SPE_Tokenizer( + os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer', 'new_vocab.txt'), + os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer', 'new_splits.txt') + ) + + # Device will be set by Lightning automatically in DDP + reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device='cpu') + model = PeptideFinetuner( + args=args, + policy_model=policy_model, + reward_model=reward_model, + tokenizer=tokenizer, + pretrained=pretrained, + mcts=None, + filename=filename, + prot_name=prot_name + ) + + # Setup checkpoint callback + checkpoint_callback = ModelCheckpoint( + dirpath=args.save_path, + filename='model-{epoch:02d}', + every_n_epochs=args.save_every_n_epochs, + save_top_k=-1, + save_last=True, # Also save last.ckpt + auto_insert_metric_name=False + ) + + # Setup wandb logger - only on rank 0 to avoid multiple runs + # Check if we're in a spawned DDP process + rank = int(os.environ.get('LOCAL_RANK', 0)) + if rank == 0: + # Defaults to your default wandb entity; override with the WANDB_ENTITY env var. + wandb_logger = WandbLogger(entity=os.environ.get('WANDB_ENTITY'), project='a2d2-pep', name=args.run_name) + else: + wandb_logger = None + + # Create dummy dataloader + dataset = DummyDataset(size=args.num_training_steps_per_epoch) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=1) + + # Setup trainer with DDP + trainer = pl.Trainer( + max_epochs=args.num_epochs, + accelerator='gpu', + devices=args.devices, + strategy=DDPStrategy(find_unused_parameters=True) if args.devices != 1 else 'auto', + gradient_clip_val=args.gradnorm_clip if args.grad_clip else None, + logger=wandb_logger, + callbacks=[checkpoint_callback], + enable_progress_bar=True, + log_every_n_steps=1 + ) + + # Train (resume full training state from --resume_ckpt if provided). 
+ # weights_only=False is required when resuming because these checkpoints + # store argparse.Namespace / OmegaConf objects in hyper_parameters, which + # PyTorch 2.6's default weights_only=True unpickler rejects. + if args.resume_ckpt: + trainer.fit(model, dataloader, ckpt_path=args.resume_ckpt, weights_only=False) + else: + trainer.fit(model, dataloader) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/a2d2_pep/inference_quality.py b/a2d2_pep/inference_quality.py new file mode 100644 index 0000000000000000000000000000000000000000..ed9ddc5cabc3bedc557e2c4a045362c8af2e4efa --- /dev/null +++ b/a2d2_pep/inference_quality.py @@ -0,0 +1,605 @@ +"""Unified peptide sampling with quality-guided planning. + +Supports 4 quality modes and optional RND (importance weight) computation. + +Quality modes: + "none" - No planner, no remasking (policy-only) + "both" - Both unmasking + insertion planners active + "unmasking_only" - Only unmasking/remasking planner (insertion planner disabled) + "insertion_only" - Only insertion planner (unmasking planner disabled) + +RND toggle: + compute_rnd=True - Run pretrained model in parallel, compute step-wise log importance weights + compute_rnd=False - Run policy model only (use with ELBO-based RND or eval) +""" + +import os +import torch +import numpy as np +import pandas as pd +import torch.nn.functional as F +from sampling import SamplingResult, SamplingTraceDatapoint, _sample_tokens +from remasking_scheduleaware import apply_schedule_aware_remasking, apply_schedule_aware_insertion + +QUALITY_MODES = {"none", "both", "unmasking_only", "insertion_only"} + +# When set (e.g. A2D2_QUALITY_DEBUG=1), the diffusion loop prints, per step, how +# many already-unmasked tokens get remasked and how many proposed insertions get +# filtered by the quality planner, plus a per-batch total. Off by default so it +# never spams training/eval runs. 
_QUALITY_DEBUG = os.environ.get("A2D2_QUALITY_DEBUG", "") not in ("", "0", "false", "False")


@torch.no_grad()
def _diffusion_loop(
    model, steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    compute_rnd=False,
    pretrained=None,
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    unmask_quality_threshold=None,
    unmask_all=False,
    freq_penalty=0.0,
    return_trace=False,
):
    """Core discrete diffusion sampling loop for peptide generation.

    Starts from an all-pad batch and alternates, for ``steps`` Euler steps,
    an unmasking move, an optional remasking move and (except on the final
    step) a Poisson insertion of fresh mask tokens.

    Args:
        model: Finetuned policy model.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: One of "none", "both", "unmasking_only", "insertion_only".
        compute_rnd: Whether to compute step-wise log importance weights.
        pretrained: Frozen pretrained model (required if compute_rnd=True).
        remasking_mode: Remasking strategy ("schedule_aware", "remdm", "remdm_conf").
        num_remasking: Number of tokens to remask per step (used by the
            "remdm" and "remdm_conf" modes).
        quality_threshold: Threshold for insertion quality filtering. None if schedule-driven.
        unmask_quality_threshold: Forwarded to ``apply_schedule_aware_remasking``.
        unmask_all: If True, the mask token is removed from the sampling
            distribution on every step (not only the final one).
        freq_penalty: When > 0, sampling probabilities of a token are scaled
            by ``exp(-freq_penalty * count)`` where ``count`` is how often the
            token already occurs in the sequence.
        return_trace: Whether to record sampling trace.

    Returns:
        (xt, log_rnd, sampling_trace)
        log_rnd is None when compute_rnd=False; otherwise it is the log
        importance weight summed over all steps, shape (B,).
        sampling_trace is None when return_trace=False.
    """
    assert quality_mode in QUALITY_MODES, f"quality_mode must be one of {QUALITY_MODES}"
    if compute_rnd:
        assert pretrained is not None, "pretrained model required when compute_rnd=True"

    # Derive flags from quality_mode
    use_remasking = quality_mode != "none"
    disable_unmasking_planner = quality_mode in ("none", "insertion_only")
    disable_insertion_planner = quality_mode in ("none", "unmasking_only")

    device = next(model.parameters()).device

    # Initialize all-pad sequence
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Precompute index tensors
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    neg_inf = torch.tensor(-np.inf, device=device)

    if use_remasking and remasking_mode == "remdm_conf":
        remasking_score = torch.zeros((batch_size, max_length), device=device)

    # Accumulated over the whole trajectory. Previously this was re-assigned on
    # every step, which discarded the contribution of all earlier steps.
    log_rnd = torch.zeros(batch_size, device=device) if compute_rnd else None

    dbg_total_remasked = 0
    dbg_total_proposed_ins = 0
    dbg_total_filtered = 0

    for i in range(steps):
        step_remasked = 0
        step_proposed_ins = 0
        step_filtered = 0

        # --- Policy model forward ---
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # --- Pretrained model forward (for RND) ---
        if compute_rnd:
            pretrained_pred = pretrained(xt, t)
            pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
            pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
            pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

        # --- Unmask step (Euler) ---
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        if compute_rnd:
            pretrained_unmask_rate[xt != mask] = 0
            pretrained_unmask_rate[mask_pos + (mask,)] = 0
            pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
            pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

        # Add "stay" probability
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )
        if compute_rnd:
            pretrained_trans_prob.scatter_add_(
                2,
                _xt.unsqueeze(-1),
                torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
            )

        # Remove mask token from sampling so every masked position is decoded.
        # The final step always does this; unmask_all does it every step, so the
        # schedule-aware remasking below re-masks the lowest-quality tokens back
        # down to the schedule's expected mask count.
        if i == steps - 1 or unmask_all:
            if i == steps - 1:
                print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0

            prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
            mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
            if mask_has_zero_prob.any():
                # Fall back to a uniform distribution over token ids below `mask`
                # for masked positions that have no probability mass left.
                num_zero_prob = mask_has_zero_prob.sum().item()
                uniform_prob = torch.zeros((num_zero_prob, trans_prob.shape[-1]), device=device, dtype=trans_prob.dtype)
                uniform_prob[:, :mask] = 1.0 / mask
                trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
            else:
                trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum

        # --- Frequency penalty: down-weight residues already abundant in the
        # sequence so (re)decoded masked positions don't collapse onto the modal
        # token (glycine). Only masked positions are sampled; clean positions are
        # overwritten below, so penalizing the whole tensor is harmless. mask/pad
        # never accumulate counts, so their entries stay untouched. Applied to a
        # copy so trans_prob (used for RND log-probs) is unchanged.
        sample_prob = trans_prob
        if freq_penalty > 0.0:
            V = trans_prob.shape[-1]
            clean_tok = (xt != mask) & (xt != pad)  # (B, L)
            counts = torch.zeros(batch_size, V, device=device, dtype=trans_prob.dtype)
            counts.scatter_add_(1, torch.where(clean_tok, xt, torch.zeros_like(xt)),
                                clean_tok.to(trans_prob.dtype))
            sample_prob = trans_prob * torch.exp(-freq_penalty * counts).unsqueeze(1)

        new_xt = _sample_tokens(sample_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # Update remasking_score buffer for remdm_conf mode
        if use_remasking and remasking_mode == "remdm_conf" and i < steps - 1:
            token_probs = F.softmax(unmask_rate, dim=-1)  # (B, L, V)
            chosen_probs = torch.gather(token_probs, dim=-1, index=new_xt.unsqueeze(-1)).squeeze(-1)  # (B, L)
            changed_mask_to_token = (xt == mask) & (new_xt != mask) & (new_xt != pad)
            remasking_score = torch.where(changed_mask_to_token, chosen_probs, remasking_score)

        # --- Remasking step ---
        if use_remasking and i < steps - 1:
            if disable_unmasking_planner or not (hasattr(model, 'planner') and model.planner is not None):
                remasking_conf = torch.zeros((batch_size, max_length), device=device)
            else:
                planner_out = model.planner(new_xt, t)
                remasking_conf = planner_out["remasking_conf"].squeeze(-1)  # (B, L)

            clean_index = (new_xt != mask) & (new_xt != pad)  # (B, L)

            if remasking_mode == "remdm":
                remasking_score_temp = torch.rand(remasking_conf.shape, device=device)
            elif remasking_mode == "remdm_conf":
                remasking_score_temp = -1.0 * remasking_conf
            elif remasking_mode == "schedule_aware":
                # Only remask when the unmasking planner is active. Otherwise
                # (e.g. insertion_only / no_unmasking_planner) remasking_conf is
                # all zeros, so this would remask schedule-excess tokens by
                # position rather than by quality.
                if not disable_unmasking_planner:
                    new_xt = apply_schedule_aware_remasking(
                        model, new_xt, t, dt, remasking_conf, clean_index,
                        mask, neg_inf, batch_size,
                        unmask_quality_threshold=unmask_quality_threshold,
                    )
                remasking_score_temp = None
            else:
                raise ValueError(f"Unknown remasking_mode: {remasking_mode}")

            if remasking_score_temp is not None:
                # Non-clean positions can never be selected for remasking.
                remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf)
                for j in range(batch_size):
                    k = min(num_remasking, int(clean_index[j].sum().item()))
                    if k > 0:
                        _, select_indices = torch.topk(remasking_score_temp[j], k=k)
                        new_xt[j, select_indices] = mask

            if _QUALITY_DEBUG:
                # Positions that were clean before this remasking block and are
                # now mask are exactly the unmasked tokens that got remasked.
                step_remasked = int((clean_index & (new_xt == mask)).sum().item())

            if return_trace:
                for batch_idx in range(batch_size):
                    for pos in range(max_length):
                        if clean_index[batch_idx, pos] and new_xt[batch_idx, pos] == mask:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=t[batch_idx].item(),
                                    event_type="change",
                                    position=pos,
                                    token=mask,
                                )
                            )

        # --- Compute log probabilities for RND ---
        if compute_rnd:
            lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
            lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

            # Only positions that went mask -> token in this step contribute.
            changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

            log_policy_step = (lp * changed_mask).sum(dim=1)
            log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

            log_rnd = log_rnd + (log_pretrained_step - log_policy_step)  # (B,)

        # --- Insertion step ---
        if i != steps - 1:
            ext = torch.poisson(len_rate * dt).long()  # (B, L+1)

            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            total_ext = ext.sum(dim=1)
            valid = xt_len + total_ext <= max_length
            ext = ext * valid.view(batch_size, 1).long()

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            # NOTE(review): total_ext is computed before `ext` is zeroed for rows
            # that would overflow max_length, so for those rows new_len exceeds
            # max_length and the tail of the row is filled with mask tokens.
            # Confirm whether that is intended.
            new_len = xt_len + total_ext  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # Shift every original token right by the number of insertions
            # sampled in the gaps up to and including its own position.
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]

            if _QUALITY_DEBUG:
                # ext has been masked by the max-length validity check above, so
                # this is the number of fresh mask tokens actually inserted.
                step_proposed_ins = int(ext.sum().item())

            # Insertion counts used for the RND term; replaced below when the
            # insertion planner filters some of the proposed insertions.
            ext_corrected = ext

            # Schedule-aware insertion quality filtering
            if use_remasking and not disable_insertion_planner:
                dbg_nonpad_before = int((xt_tmp != pad).sum().item()) if _QUALITY_DEBUG else 0

                xt_tmp = apply_schedule_aware_insertion(
                    model, xt_tmp, new_xt, t, dt, ext, mask, pad, max_length,
                    orig_mask, new_pos_orig, quality_threshold
                )

                if _QUALITY_DEBUG:
                    # Filtering only drops/compacts tokens, so the drop in
                    # non-pad count is the number of insertions filtered out.
                    step_filtered = dbg_nonpad_before - int((xt_tmp != pad).sum().item())

                if compute_rnd:
                    # Compute corrected ext based on what actually stayed:
                    # scale each row's counts by the surviving fraction.
                    ext_corrected = torch.zeros_like(ext)
                    for b in range(batch_size):
                        after_len = xt_tmp[b].ne(pad).sum().item()
                        orig_len = xt_len[b].item()
                        surviving_insertions = after_len - orig_len
                        if total_ext[b] > 0:
                            ratio = surviving_insertions / total_ext[b].item()
                            ext_corrected[b] = (ext[b].float() * ratio).long()

            # Compute insertion log_rnd
            if compute_rnd:
                insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
                pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

                log_policy_insert = (ext_corrected * torch.log(insertion_rate) - insertion_rate).sum(dim=1)
                log_pretrained_insert = (ext_corrected * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)

                log_insert_diff = log_pretrained_insert - log_policy_insert
                log_rnd = log_rnd + log_insert_diff
        else:
            xt_tmp = new_xt

        if return_trace:
            for batch_idx in range(batch_size):
                for j in range(max_length):
                    if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=t[batch_idx].item(),
                                event_type="change",
                                position=j,
                                token=new_xt[batch_idx, j].item(),
                            )
                        )

                if i != steps - 1:
                    # Record insertions from the rightmost gap to the leftmost.
                    for j in range(max_length):
                        gap_pos = max_length - j - 1
                        if ext[batch_idx, gap_pos]:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=t[batch_idx].item(),
                                    event_type="insertion",
                                    position=gap_pos,
                                    token=mask,
                                )
                            )

        if _QUALITY_DEBUG:
            dbg_total_remasked += step_remasked
            dbg_total_proposed_ins += step_proposed_ins
            dbg_total_filtered += step_filtered
            print(
                f"[QUALITY {quality_mode}] step {i+1}/{steps}: "
                f"remasked {step_remasked} unmasked tokens -> mask | "
                f"insertions proposed {step_proposed_ins}, "
                f"filtered {step_filtered}, kept {step_proposed_ins - step_filtered}"
            )

        xt = xt_tmp
        t = t + dt

    if _QUALITY_DEBUG:
        print(
            f"[QUALITY {quality_mode}] TOTAL over {steps} steps (batch_size={batch_size}): "
            f"remasked {dbg_total_remasked} unmasked tokens | "
            f"insertions proposed {dbg_total_proposed_ins}, "
            f"filtered {dbg_total_filtered}, kept {dbg_total_proposed_ins - dbg_total_filtered}"
        )

    return xt, log_rnd, sampling_trace
@torch.no_grad()
def sample_peptides_eval(
    model, reward_model, analyzer, tokenizer,
    steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    unmask_quality_threshold=None,
    unmask_all=False,
    freq_penalty=0.0,
    dataframe=False,
    return_valid=False,
):
    """Generate peptides for evaluation.

    Args:
        model: Finetuned policy model.
        reward_model: Multi-objective scoring function.
        analyzer: PeptideAnalyzer for validation.
        tokenizer: Tokenizer for decoding.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: "none", "both", "unmasking_only", or "insertion_only".
        remasking_mode: Remasking strategy.
        num_remasking: Number of tokens to remask per step.
        quality_threshold: Threshold for insertion quality filtering.
        unmask_quality_threshold: Forwarded unchanged to ``_diffusion_loop``.
        unmask_all: Forwarded unchanged to ``_diffusion_loop``.
        freq_penalty: Forwarded unchanged to ``_diffusion_loop``.
        dataframe: If True, include a pandas DataFrame in the return.
        return_valid: If True, return decoded valid sequences instead of raw token tensors.

    Returns:
        For multi-objective (5 objectives):
            (samples, affinity, sol, hemo, nf, permeability, valid_fraction[, df])
        For single objective:
            (samples, sol, valid_fraction[, df])
        When return_valid=True, samples is replaced with validSequences list.
        When no sequence is valid, every score is the placeholder ``[0.0]``
        and the DataFrame (if requested) is empty but has all columns.
    """
    samples, _, _ = _diffusion_loop(
        model, steps, mask, pad, batch_size, max_length,
        quality_mode=quality_mode,
        compute_rnd=False,
        remasking_mode=remasking_mode,
        num_remasking=num_remasking,
        quality_threshold=quality_threshold,
        unmask_quality_threshold=unmask_quality_threshold,
        unmask_all=unmask_all,
        freq_penalty=freq_penalty,
    )

    decoded_samples = tokenizer.batch_decode(samples)
    validSequences = [seq for seq in decoded_samples if analyzer.is_peptide(seq)]

    print("len valid sequences:", len(validSequences))

    valid_fraction = len(validSequences) / batch_size

    # Determine number of objectives from reward model
    num_objectives = len(reward_model.score_func_names) if hasattr(reward_model, 'score_func_names') else 5

    # Column order matches the objective order of the reward model output.
    if num_objectives == 1:
        score_names = ("Solubility",)
    else:
        score_names = ("Binding Affinity", "Solubility", "Hemolysis", "Nonfouling", "Permeability")

    if validSequences:
        # reward_model returns (N, num_objectives); transpose to index by objective.
        per_objective = reward_model(input_seqs=validSequences).T
        scores = [per_objective[k] for k in range(len(score_names))]
    else:
        scores = [[0.0] for _ in score_names]

    first = validSequences if return_valid else samples
    result = (first, *scores, valid_fraction)

    if dataframe:
        # With no valid sequences the score columns must be empty too;
        # mixing an empty sequence column with the [0.0] placeholders makes
        # pandas raise "All arrays must be of the same length".
        columns = {"Peptide Sequence": validSequences}
        for name, values in zip(score_names, scores):
            columns[name] = values if validSequences else []
        result += (pd.DataFrame(columns),)

    return result
Prediction heads + self.shared_head = nn.Sequential( + nn.Linear(hidden_dim * 2, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + ) + + # Regression head + self.regression_head = nn.Linear(hidden_dim, 1) + + # Classification head (3 classes: tight, medium, loose binding) + self.classification_head = nn.Linear(hidden_dim, 3) + + def get_binding_class(self, affinity): + """Convert affinity values to class indices + 0: tight binding (>= 7.5) + 1: medium binding (6.0-7.5) + 2: weak binding (< 6.0) + """ + if isinstance(affinity, torch.Tensor): + tight_mask = affinity >= self.tight_threshold + weak_mask = affinity < self.weak_threshold + medium_mask = ~(tight_mask | weak_mask) + + classes = torch.zeros_like(affinity, dtype=torch.long) + classes[medium_mask] = 1 + classes[weak_mask] = 2 + return classes + else: + if affinity >= self.tight_threshold: + return 0 # tight binding + elif affinity < self.weak_threshold: + return 2 # weak binding + else: + return 1 # medium binding + + def forward(self, protein_emb, smiles_emb): + protein = self.protein_norm(self.protein_projection(protein_emb)) + smiles = self.smiles_norm(self.smiles_projection(smiles_emb)) + + #protein = protein.transpose(0, 1) + #smiles = smiles.transpose(0, 1) + + # Cross attention layers + for layer in self.cross_attention_layers: + # Protein attending to SMILES + attended_protein = layer['attention']( + protein, smiles, smiles + )[0] + protein = layer['norm1'](protein + attended_protein) + protein = layer['norm2'](protein + layer['ffn'](protein)) + + # SMILES attending to protein + attended_smiles = layer['attention']( + smiles, protein, protein + )[0] + smiles = layer['norm1'](smiles + attended_smiles) + smiles = layer['norm2'](smiles + layer['ffn'](smiles)) + + # Get sequence-level representations + protein_pool = torch.mean(protein, dim=0) + smiles_pool = torch.mean(smiles, dim=0) + + # Concatenate both representations + combined = torch.cat([protein_pool, smiles_pool], dim=-1) + + # Shared features 
+ shared_features = self.shared_head(combined) + + regression_output = self.regression_head(shared_features) + classification_logits = self.classification_head(shared_features) + + return regression_output, classification_logits + +class BindingAffinity: + def __init__(self, prot_seq, tokenizer, base_path, device=None, emb_model=None): + super().__init__() + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device + + # peptide embeddings + if emb_model is not None: + self.pep_model = emb_model.to(self.device).eval() + else: + self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval() + + self.pep_tokenizer = tokenizer + + self.model = ImprovedBindingPredictor().to(self.device) + checkpoint = torch.load(f'{base_path}/functions/classifiers/binding-affinity.pt', + map_location=self.device, + weights_only=False) + self.model.load_state_dict(checkpoint['model_state_dict']) + + self.model.eval() + + self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() # load ESM-2 model + self.esm_model = self.esm_model.to(self.device).eval() + self.prot_tokenizer = alphabet.get_batch_converter() # load esm tokenizer + + data = [("target", prot_seq)] + # get tokenized protein + _, _, prot_tokens = self.prot_tokenizer(data) + prot_tokens = prot_tokens.to(self.device) + with torch.no_grad(): + results = self.esm_model.forward(prot_tokens, repr_layers=[33]) # Example with ESM-2 + prot_emb = results["representations"][33] + + self.prot_emb = prot_emb[0].to(self.device) + self.prot_emb = torch.mean(self.prot_emb, dim=0, keepdim=True) + + + def forward(self, input_seqs): + with torch.no_grad(): + scores = [] + for seq in input_seqs: + pep_tokens = self.pep_tokenizer(seq, return_tensors='pt', padding=True) + + pep_tokens = {k: v.to(self.device) for k, v in pep_tokens.items()} + + with torch.no_grad(): + emb = self.pep_model(input_ids=pep_tokens['input_ids'], + 
attention_mask=pep_tokens['attention_mask'], + output_hidden_states=True) + + #emb = self.pep_model(input_ids=pep_tokens['input_ids'], attention_mask=pep_tokens['attention_mask']) + pep_emb = emb.last_hidden_state.squeeze(0) + pep_emb = torch.mean(pep_emb, dim=0, keepdim=True) + + score, logits = self.model.forward(self.prot_emb, pep_emb) + scores.append(score.item()) + return scores + + def __call__(self, input_seqs: list): + return self.forward(input_seqs) \ No newline at end of file diff --git a/a2d2_pep/pep_scoring/functions/binding_utils.py b/a2d2_pep/pep_scoring/functions/binding_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..5f5b7fe1d0ab2c2be70a11a27d3f1f4bddcd8ff7 --- /dev/null +++ b/a2d2_pep/pep_scoring/functions/binding_utils.py @@ -0,0 +1,290 @@ +from torch import nn +import torch +import numpy as np + +def to_var(x): + if torch.cuda.is_available(): + x = x.cuda() + return x + +class MultiHeadAttentionSequence(nn.Module): + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + + super().__init__() + + self.n_head = n_head + self.d_model = d_model + self.d_k = d_k + self.d_v = d_v + + self.W_Q = nn.Linear(d_model, n_head*d_k) + self.W_K = nn.Linear(d_model, n_head*d_k) + self.W_V = nn.Linear(d_model, n_head*d_v) + self.W_O = nn.Linear(n_head*d_v, d_model) + + self.layer_norm = nn.LayerNorm(d_model) + + self.dropout = nn.Dropout(dropout) + + def forward(self, q, k, v): + + batch, len_q, _ = q.size() + batch, len_k, _ = k.size() + batch, len_v, _ = v.size() + + Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k]) + K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k]) + V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v]) + + Q = Q.transpose(1, 2) + K = K.transpose(1, 2).transpose(2, 3) + V = V.transpose(1, 2) + + attention = torch.matmul(Q, K) + + attention = attention / np.sqrt(self.d_k) + + attention = F.softmax(attention, dim=-1) + + output = torch.matmul(attention, V) + + output = 
output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head]) + + output = self.W_O(output) + + output = self.dropout(output) + + output = self.layer_norm(output + q) + + return output, attention + +class MultiHeadAttentionReciprocal(nn.Module): + + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + + super().__init__() + + self.n_head = n_head + self.d_model = d_model + self.d_k = d_k + self.d_v = d_v + + self.W_Q = nn.Linear(d_model, n_head*d_k) + self.W_K = nn.Linear(d_model, n_head*d_k) + self.W_V = nn.Linear(d_model, n_head*d_v) + self.W_O = nn.Linear(n_head*d_v, d_model) + self.W_V_2 = nn.Linear(d_model, n_head*d_v) + self.W_O_2 = nn.Linear(n_head*d_v, d_model) + + self.layer_norm = nn.LayerNorm(d_model) + + self.dropout = nn.Dropout(dropout) + + self.layer_norm_2 = nn.LayerNorm(d_model) + + self.dropout_2 = nn.Dropout(dropout) + + def forward(self, q, k, v, v_2): + + batch, len_q, _ = q.size() + batch, len_k, _ = k.size() + batch, len_v, _ = v.size() + batch, len_v_2, _ = v_2.size() + + Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k]) + K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k]) + V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v]) + V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v]) + + Q = Q.transpose(1, 2) + K = K.transpose(1, 2).transpose(2, 3) + V = V.transpose(1, 2) + V_2 = V_2.transpose(1,2) + + attention = torch.matmul(Q, K) + + + attention = attention /np.sqrt(self.d_k) + + attention_2 = attention.transpose(-2, -1) + + + + attention = F.softmax(attention, dim=-1) + + attention_2 = F.softmax(attention_2, dim=-1) + + + output = torch.matmul(attention, V) + + output_2 = torch.matmul(attention_2, V_2) + + output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head]) + + output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head]) + + output = self.W_O(output) + + output_2 = self.W_O_2(output_2) + + output = self.dropout(output) + + output = 
class FFN(nn.Module):
    """Position-wise feed-forward block with residual + LayerNorm.

    Two kernel-size-1 convolutions (``d_in -> d_hid -> d_in``) with a ReLU in
    between are applied along the sequence axis; dropout is applied to the
    result, which is added to the input and layer-normalised.

    Args:
        d_in: Feature size of the input and output.
        d_hid: Hidden feature size between the two convolutions.
        dropout: Dropout probability applied before the residual sum.
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()

        pointwise = 1  # kernel size: each position is transformed independently
        self.layer_1 = nn.Conv1d(d_in, d_hid, pointwise)
        self.layer_2 = nn.Conv1d(d_hid, d_in, pointwise)
        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(d_in)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Transform ``x`` of shape (batch, length, d_in); shape is preserved."""
        # Conv1d expects channels before length, so swap the last two axes
        # on the way in and swap them back before the residual sum.
        channels_first = x.transpose(1, 2)
        hidden = self.relu(self.layer_1(channels_first))
        projected = self.dropout(self.layer_2(hidden))
        return self.layer_norm(projected.transpose(1, 2) + x)
class ReciprocalLayerwithCNN(nn.Module):
    """Reciprocal attention layer whose protein branch starts with a dilated CNN.

    The protein encoding is first passed through ``DilatedCNN`` (mapping
    ``d_model`` features to ``d_hidden``). Each stream then goes through its
    own self-attention, the two streams are mixed by the reciprocal attention
    layer, and each stream is finished by its own feed-forward block.

    Args:
        d_model: Feature size of the incoming protein encoding.
        d_inner: Hidden size of the feed-forward blocks.
        d_hidden: Feature size used by the attention and feed-forward layers.
        n_head: Number of attention heads.
        d_k: Per-head query/key size.
        d_v: Per-head value size.
    """

    def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v):
        super().__init__()

        # Submodules are registered in the same order as before.
        self.cnn = DilatedCNN(d_model, d_hidden)
        self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
        self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, d_k, d_v)
        self.ffn_seq = FFN(d_hidden, d_inner)
        self.ffn_protein = FFN(d_hidden, d_inner)

    def forward(self, sequence_enc, protein_seq_enc):
        """Run one reciprocal layer.

        Args:
            sequence_enc: Sequence-side encoding; it is fed straight into a
                ``d_hidden``-sized attention layer, so it presumably already
                has ``d_hidden`` features (confirm against the caller).
            protein_seq_enc: Protein encoding of shape B x L x d_model.

        Returns:
            ``(prot_enc, seq_enc, prot_attention, sequence_attention,
            prot_seq_attention, seq_prot_attention)``.
        """
        conv_prot = self.cnn(protein_seq_enc)

        # Self-attention within each stream.
        prot_self, prot_attention = self.protein_attention_layer(conv_prot, conv_prot, conv_prot)
        seq_self, sequence_attention = self.sequence_attention_layer(
            sequence_enc, sequence_enc, sequence_enc
        )

        # Cross-attention: protein queries over the sequence and vice versa.
        mixed = self.reciprocal_attention_layer(prot_self, seq_self, seq_self, prot_self)
        prot_cross, seq_cross, prot_seq_attention, seq_prot_attention = mixed

        prot_out = self.ffn_protein(prot_cross)
        seq_out = self.ffn_seq(seq_cross)

        return (
            prot_out,
            seq_out,
            prot_attention,
            sequence_attention,
            prot_seq_attention,
            seq_prot_attention,
        )
class Hemolysis:
    """Score peptides by their predicted probability of being non-hemolytic.

    Sequences are tokenized, embedded with a PeptideCLM encoder (mean-pooled
    over the token axis) and fed to an XGBoost classifier loaded from
    ``{base_path}/functions/classifiers/hemolysis-xgboost.json``.

    Args:
        tokenizer: Callable used as ``tokenizer(sequence, return_tensors='pt')``
            returning a mapping of tensors accepted by the embedding model.
        base_path: Directory that contains ``functions/classifiers``.
        device: Torch device; defaults to CUDA when available, else CPU.
        emb_model: Optional pre-loaded embedding model. When omitted, the
            ``roformer`` encoder of ``aaronfeller/PeptideCLM-23M-all`` is loaded.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/hemolysis-xgboost.json')
        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        # Always place the encoder on the resolved device (not the raw
        # ``device`` argument, which may be None) so it matches the inputs
        # moved in generate_embeddings; same handling as Solubility.
        self.emb_model = emb_model.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return one mean-pooled embedding per sequence as a NumPy array."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Return ``1 - P(hemolytic)`` for every input sequence.

        An empty input yields an empty array without touching the classifier.
        """
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize features: NaN -> 0 and clamp to the float32 range.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        probs = self.predictor.predict(features)
        # return the probability of it being not hemolytic
        return scores - probs

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
class Nonfouling:
    """Score peptides with an XGBoost non-fouling classifier.

    Sequences are tokenized, embedded with a PeptideCLM encoder (mean-pooled
    over the token axis) and fed to the booster loaded from
    ``{base_path}/functions/classifiers/nonfouling-xgboost.json``.

    Args:
        tokenizer: Callable used as ``tokenizer(sequence, return_tensors='pt')``
            returning a mapping of tensors accepted by the embedding model.
        base_path: Directory that contains ``functions/classifiers``.
        device: Torch device; defaults to CUDA when available, else CPU.
        emb_model: Optional pre-loaded embedding model. When omitted, the
            ``roformer`` encoder of ``aaronfeller/PeptideCLM-23M-all`` is loaded.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/nonfouling-xgboost.json')
        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        # Use the resolved device (the raw ``device`` argument may be None)
        # so the encoder lives where generate_embeddings sends its inputs.
        self.emb_model = emb_model.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return one mean-pooled embedding per sequence as a NumPy array."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Return the classifier's prediction for every input sequence.

        An empty input yields an empty array without touching the classifier.
        """
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize features: NaN -> 0 and clamp to the float32 range.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        # Raw booster output is returned unchanged (no 1 - p inversion here).
        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
class Permeability:
    """Score peptides with an XGBoost permeability model.

    Features are PeptideCLM embeddings (mean-pooled over the token axis),
    optionally prefixed with ECFP fingerprints and RDKit descriptors. The
    booster is loaded from
    ``{base_path}/functions/classifiers/permeability-xgboost.json``.

    Args:
        tokenizer: Callable used as ``tokenizer(sequence, return_tensors='pt')``
            returning a mapping of tensors accepted by the embedding model.
        base_path: Directory that contains ``functions/classifiers``.
        device: Torch device; defaults to CUDA when available, else CPU.
        emb_model: Optional pre-loaded embedding model. When omitted, the
            ``roformer`` encoder of ``aaronfeller/PeptideCLM-23M-all`` is loaded.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/permeability-xgboost.json')
        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        # Both branches now use the resolved device; the fallback branch
        # previously used the raw ``device`` argument, which may be None.
        self.emb_model = emb_model.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return one mean-pooled embedding per sequence as a NumPy array."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_features(self, input_seqs: list, dps=False, fps=False):
        """Build the feature matrix ``[fingerprints | descriptors | embeddings]``.

        Args:
            input_seqs: SMILES strings to featurize.
            dps: Include RDKit descriptors when True.
            fps: Include ECFP fingerprints when True.

        Returns:
            2-D NumPy array with one row per input sequence.
        """
        n_seqs = len(input_seqs)

        # Zero-width float32 placeholders keep the concatenation purely NumPy
        # and leave the dtype of the embedding-only feature matrix unchanged.
        if fps:
            fingerprints = fingerprints_from_smiles(input_seqs)[0]
        else:
            fingerprints = np.empty((n_seqs, 0), dtype=np.float32)

        if dps:
            descriptors = get_pep_dps(input_seqs)
        else:
            descriptors = np.empty((n_seqs, 0), dtype=np.float32)

        embeddings = self.generate_embeddings(input_seqs)

        return np.concatenate([fingerprints, descriptors, embeddings], axis=1)

    def get_scores(self, input_seqs: list):
        """Return the booster's prediction for every input sequence.

        An empty input yields an empty array without touching the models.
        """
        scores = -10 * np.ones(len(input_seqs))

        # Bail out before featurization: with no sequences the embedding
        # array is 1-D and cannot be concatenated with the 2-D placeholders.
        if len(input_seqs) == 0:
            return scores

        features = self.get_features(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize features: NaN -> 0 and clamp to the float32 range.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
def get_pep_dps(smi_list):
    """Return the descriptor matrix for a list of SMILES strings.

    Args:
        smi_list: SMILES strings to describe.

    Returns:
        2-D array with one row per SMILES and one column per descriptor
        produced by ``getMolDescriptors``. For an empty list the result has
        zero rows but the same number of columns.
    """
    if len(smi_list) == 0:
        # Derive the column count instead of hard-coding it: the size of
        # RDKit's descriptor list differs between RDKit versions.
        # getMolDescriptors substitutes missingVal for every descriptor that
        # fails, so calling it with None still yields the full-length list.
        n_descriptors = len(getMolDescriptors(None)[0])
        return np.zeros((0, n_descriptors))
    return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
+from rdkit import rdBase + +rdBase.DisableLog('rdApp.error') +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + +class Solubility: + def __init__(self, tokenizer, base_path, device=None, emb_model=None): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device + self.predictor = xgb.Booster(model_file=f'{base_path}/functions/classifiers/solubility-xgboost.json') + if emb_model is not None: + self.emb_model = emb_model.to(self.device).eval() + else: + self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval() + + self.tokenizer = tokenizer + + def generate_embeddings(self, sequences): + embeddings = [] + for sequence in sequences: + tokenized = self.tokenizer(sequence, return_tensors='pt') + tokenized = {k: v.to(self.device) for k, v in tokenized.items()} + with torch.no_grad(): + output = self.emb_model(**tokenized) + # Mean pooling across sequence length + embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy() + embeddings.append(embedding) + return np.array(embeddings) + + def get_scores(self, input_seqs: list): + scores = np.zeros(len(input_seqs)) + features = self.generate_embeddings(input_seqs) + + if len(features) == 0: + return scores + + features = np.nan_to_num(features, nan=0.) 
class ScoringFunctions:
    """Bundle of peptide scoring functions evaluated on generated sequences."""

    def __init__(self, score_func_names=None, prot_seqs=None, device=None):
        """Load the shared embedding model, tokenizer and every scorer.

        Args:
            score_func_names: Names of the scoring functions to evaluate, in
                output-column order. ``None`` selects no scoring function.
            prot_seqs: Target protein sequences for the binding-affinity
                scorers (one or two entries). ``None`` means no targets.
            device: Torch device handed to the embedding model and scorers.
        """
        emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device).eval()
        tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/tokenizer/new_vocab.txt',
                                         f'{base_path}/tokenizer/new_splits.txt')
        prot_seqs = prot_seqs if prot_seqs is not None else []

        # Normalise first so the membership tests below never see None.
        self.score_func_names = [] if score_func_names is None else score_func_names
        names = self.score_func_names

        # binding affinities
        self.target_protein = prot_seqs

        wants_first = 'binding_affinity1' in names
        wants_second = 'binding_affinity2' in names
        if wants_first and len(prot_seqs) == 1:
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
            binding_affinity2 = None
        elif wants_first and wants_second and len(prot_seqs) == 2:
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
            binding_affinity2 = BindingAffinity(prot_seqs[1], tokenizer=tokenizer, base_path=base_path, device=device)
        else:
            # No usable (names, targets) combination: binding scorers stay unset.
            binding_affinity1 = None
            binding_affinity2 = None

        permeability = Permeability(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        sol = Solubility(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        nonfouling = Nonfouling(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        hemo = Hemolysis(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)

        self.all_funcs = {'binding_affinity1': binding_affinity1,
                          'binding_affinity2': binding_affinity2,
                          'permeability': permeability,
                          'nonfouling': nonfouling,
                          'solubility': sol,
                          'hemolysis': hemo
                          }

    def forward(self, input_seqs):
        """Evaluate every selected scoring function on ``input_seqs``.

        Returns:
            float32 array of shape (num_sequences, num_functions), columns in
            the order of ``score_func_names``.
        """
        scores = [
            self.all_funcs[score_func](input_seqs=input_seqs)
            for score_func in self.score_func_names
        ]

        # (num_functions, num_sequences) -> (num_sequences, num_functions)
        return np.asarray(scores, dtype=np.float32).T

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)
+ Args: + vocab_file (:obj:`string`): + File containing the vocabulary. + spe_file (:obj:`string`): + File containing the trained SMILES Pair Encoding vocabulary. + unk_token (:obj:`string`, `optional`, defaults to "[UNK]"): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (:obj:`string`, `optional`, defaults to "[SEP]"): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences + for sequence classification or for a text and a question for question answering. + It is also used as the last token of a sequence built with special tokens. + pad_token (:obj:`string`, `optional`, defaults to "[PAD]"): + The token used for padding, for example when batching sequences of different lengths. + cls_token (:obj:`string`, `optional`, defaults to "[CLS]"): + The classifier token which is used when doing sequence classification (classification of the whole + sequence instead of per-token classification). It is the first token of the sequence when built with + special tokens. + mask_token (:obj:`string`, `optional`, defaults to "[MASK]"): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. 
+ """ + + def __init__(self, vocab_file, spe_file, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + **kwargs): + if not os.path.isfile(vocab_file): + raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file)) + if not os.path.isfile(spe_file): + raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file)) + + self.vocab = load_vocab(vocab_file) + self.spe_vocab = open(spe_file, 'r', encoding='utf-8') + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab) + + super().__init__( + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs) + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text): + return self.spe_tokenizer.tokenize(text).split(' ') + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. 
""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + # changed encode and decode functions + def encode(self, token_array): + token_ids = [] + token_ids.append(2) + for token in token_array: + id = self._convert_token_to_id(token) + token_ids.append(id) + token_ids.append(3) + token_ids = torch.tensor([token_ids]) + attn_mask = torch.ones_like(token_ids) + return {'input_ids': token_ids, 'attention_mask': attn_mask} + + def decode(self, token_ids, skip_special_tokens=True): + token_ids = token_ids.squeeze(0).cpu().tolist() + token_array = [] + for idx in token_ids: + if idx == 3: # Stop decoding when token ID 3 is encountered + break + if skip_special_tokens and idx in self.all_special_ids: + continue + token = self._convert_id_to_token(idx) + token_array.append(token) + sequence = "".join(token_array) + return sequence + + def batch_decode(self, batch_token_ids, skip_special_tokens=True): + sequences = [] + for token_ids in batch_token_ids: + sequences.append(self.decode(token_ids)) + return sequences + + def get_token_split(self, token_ids): + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.cpu().tolist() + + token_array = [] + for seq_ids in token_ids: + seq_array = [] + for id in seq_ids: + token = self._convert_id_to_token(id) + seq_array.append(token) + token_array.append(seq_array) + + return token_array + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. 
""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. + A BERT sequence has the following format: + - single sequence: ``[CLS] X [SEP]`` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + Returns: + :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + Args: + token_ids_0 (:obj:`List[int]`): + List of ids. + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True if the token list is already formatted with special tokens for the model + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 
+ """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formated with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. + A BERT sequence pair mask has the following format: + :: + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + if token_ids_1 is None, only returns the first portion of the mask (0's). + Args: + token_ids_0 (:obj:`List[int]`): + List of ids. + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + Returns: + :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given + sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, vocab_path): + """ + Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory. + Args: + vocab_path (:obj:`str`): + The directory in which to save the vocabulary. + Returns: + :obj:`Tuple(str)`: Paths to the files saved. 
+ """ + index = 0 + vocab_file = vocab_path + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + +class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer): + r""" + Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE). + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users + should refer to the superclass for more information regarding methods. + Args: + vocab_file (:obj:`string`): + File containing the vocabulary. + unk_token (:obj:`string`, `optional`, defaults to "[UNK]"): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (:obj:`string`, `optional`, defaults to "[SEP]"): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences + for sequence classification or for a text and a question for question answering. + It is also used as the last token of a sequence built with special tokens. + pad_token (:obj:`string`, `optional`, defaults to "[PAD]"): + The token used for padding, for example when batching sequences of different lengths. + cls_token (:obj:`string`, `optional`, defaults to "[CLS]"): + The classifier token which is used when doing sequence classification (classification of the whole + sequence instead of per-token classification). It is the first token of the sequence when built with + special tokens. + mask_token (:obj:`string`, `optional`, defaults to "[MASK]"): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. 
+ """ + + def __init__( + self, + vocab_file, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + **kwargs + ): + super().__init__( + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs, + ) + + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'.".format(vocab_file) + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.tokenizer = Atomwise_Tokenizer() + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + + def _tokenize(self, text): + return self.tokenizer.tokenize(text) + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. + A BERT sequence has the following format: + - single sequence: ``[CLS] X [SEP]`` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. 
+ Returns: + :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + Args: + token_ids_0 (:obj:`List[int]`): + List of ids. + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True if the token list is already formatted with special tokens for the model + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formated with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. 
+ A BERT sequence pair mask has the following format: + :: + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + if token_ids_1 is None, only returns the first portion of the mask (0's). + Args: + token_ids_0 (:obj:`List[int]`): + List of ids. + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + Returns: + :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given + sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, vocab_path): + """ + Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory. + Args: + vocab_path (:obj:`str`): + The directory in which to save the vocabulary. + Returns: + :obj:`Tuple(str)`: Paths to the files saved. 
+ """ + index = 0 + vocab_file = vocab_path + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) diff --git a/a2d2_pep/pep_utils/analyzer.py b/a2d2_pep/pep_utils/analyzer.py new file mode 100755 index 0000000000000000000000000000000000000000..1ba3d3c6fe69ea73b933225148cf9f2c8844a129 --- /dev/null +++ b/a2d2_pep/pep_utils/analyzer.py @@ -0,0 +1,1274 @@ +import os +import re +import pandas as pd +from io import StringIO +import rdkit +from rdkit import Chem +from rdkit.Chem import AllChem, Draw +import numpy as np +from PIL import Image, ImageDraw, ImageFont +import matplotlib.pyplot as plt +import matplotlib.patches as patches +from io import BytesIO +import tempfile +from rdkit import Chem + +class PeptideAnalyzer: + def __init__(self, min_peptide_bonds=2, enforce_min_peptide_bonds=True): + # length cutoff: minimum number of backbone residues (N-Cα-C(=O) units) + + self.min_peptide_bonds = min_peptide_bonds + self.enforce_min_peptide_bonds = enforce_min_peptide_bonds + self.bond_patterns = [ + (r'OC\(=O\)', 'ester'), # Ester bond + (r'N\(C\)C\(=O\)', 'n_methyl'), # N-methylated peptide bond + (r'N[0-9]C\(=O\)', 'proline'), # Proline peptide bond + (r'NC\(=O\)', 'peptide'), # Standard peptide bond + (r'C\(=O\)N\(C\)', 'n_methyl_reverse'), # Reverse N-methylated + (r'C\(=O\)N[12]?', 'peptide_reverse') # Reverse peptide bond + ] + # Three to one letter code mapping + self.three_to_one = { + 'Ala': 'A', 'Cys': 'C', 'Asp': 'D', 'Glu': 'E', + 'Phe': 'F', 'Gly': 'G', 'His': 'H', 'Ile': 'I', + 'Lys': 'K', 'Leu': 'L', 'Met': 'M', 'Asn': 'N', + 'Pro': 'P', 'Gln': 'Q', 'Arg': 'R', 'Ser': 'S', + 'Thr': 'T', 'Val': 'V', 'Trp': 'W', 'Tyr': 'Y' + } + + def count_peptide_bonds(self, smiles): + """Count backbone peptide residues via N-Cα-C(=O) units. 
+ + Matches the backbone pattern [NX3][CX4][CX3](=O): an amide nitrogen + bonded to an sp3 alpha-carbon bonded to a carbonyl. Requiring the sp3 + Cα excludes non-backbone amides — ureas/biurets (N-C(=O)-N, no Cα), + sulfonamides, and side-chain amides (Asn/Gln) — and uniquify=True + avoids the multiple-mapping over-count of symmetric N-methyl groups. + Each match corresponds to one backbone residue. + """ + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return 0 + backbone_pattern = Chem.MolFromSmarts('[NX3][CX4][CX3](=O)') + return len(mol.GetSubstructMatches(backbone_pattern, uniquify=True)) + + def is_peptide(self, smiles): + """Check if the SMILES represents a peptide structure""" + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return False + + # Count backbone residues (N-Cα-C(=O) units). Requiring a real backbone + # unit rejects ureas/biurets and side-chain-only amides outright. + n_residues = self.count_peptide_bonds(smiles) + if n_residues == 0: + return False + + # length cutoff: reject molecules with too few backbone residues + if self.enforce_min_peptide_bonds and n_residues < self.min_peptide_bonds: + return False + + return True + + def is_cyclic(self, smiles): + """Improved cyclic peptide detection""" + # Check for C-terminal carboxyl + if smiles.endswith('C(=O)O'): + return False, [], [] + + # Find all numbers used in ring closures + ring_numbers = re.findall(r'(?:^|[^c])[0-9](?=[A-Z@\(\)])', smiles) + + # Find aromatic ring numbers + aromatic_matches = re.findall(r'c[0-9](?:ccccc|c\[nH\]c)[0-9]', smiles) + aromatic_cycles = [] + for match in aromatic_matches: + numbers = re.findall(r'[0-9]', match) + aromatic_cycles.extend(numbers) + + # Numbers that aren't part of aromatic rings are peptide cycles + peptide_cycles = [n for n in ring_numbers if n not in aromatic_cycles] + + is_cyclic = len(peptide_cycles) > 0 and not smiles.endswith('C(=O)O') + return is_cyclic, peptide_cycles, aromatic_cycles + + def split_on_bonds(self, smiles): + 
"""Split SMILES into segments with simplified Pro handling""" + positions = [] + used = set() + + # Find Gly pattern first + gly_pattern = r'NCC\(=O\)' + for match in re.finditer(gly_pattern, smiles): + if not any(p in range(match.start(), match.end()) for p in used): + positions.append({ + 'start': match.start(), + 'end': match.end(), + 'type': 'gly', + 'pattern': match.group() + }) + used.update(range(match.start(), match.end())) + + for pattern, bond_type in self.bond_patterns: + for match in re.finditer(pattern, smiles): + if not any(p in range(match.start(), match.end()) for p in used): + positions.append({ + 'start': match.start(), + 'end': match.end(), + 'type': bond_type, + 'pattern': match.group() + }) + used.update(range(match.start(), match.end())) + + # Sort by position + positions.sort(key=lambda x: x['start']) + + # Create segments + segments = [] + + if positions: + # First segment + if positions[0]['start'] > 0: + segments.append({ + 'content': smiles[0:positions[0]['start']], + 'bond_after': positions[0]['pattern'] + }) + + # Process segments + for i in range(len(positions)-1): + current = positions[i] + next_pos = positions[i+1] + + if current['type'] == 'gly': + segments.append({ + 'content': 'NCC(=O)', + 'bond_before': positions[i-1]['pattern'] if i > 0 else None, + 'bond_after': next_pos['pattern'] + }) + else: + content = smiles[current['end']:next_pos['start']] + if content: + segments.append({ + 'content': content, + 'bond_before': current['pattern'], + 'bond_after': next_pos['pattern'] + }) + + # Last segment + if positions[-1]['end'] < len(smiles): + segments.append({ + 'content': smiles[positions[-1]['end']:], + 'bond_before': positions[-1]['pattern'] + }) + + return segments + + def clean_terminal_carboxyl(self, segment): + """Remove C-terminal carboxyl only if it's the true terminus""" + content = segment['content'] + + # Only clean if: + # 1. Contains C(=O)O + # 2. No bond_after exists (meaning it's the last segment) + # 3. 
C(=O)O is at the end of the content + if 'C(=O)O' in content and not segment.get('bond_after'): + print('recognized?') + # Remove C(=O)O pattern regardless of position + cleaned = re.sub(r'\(C\(=O\)O\)', '', content) + # Remove any leftover empty parentheses + cleaned = re.sub(r'\(\)', '', cleaned) + print(cleaned) + return cleaned + return content + + def identify_residue(self, segment): + """Identify residue with Pro reconstruction""" + # Only clean terminal carboxyl if this is the last segment + content = self.clean_terminal_carboxyl(segment) + mods = self.get_modifications(segment) + + # UAA pattern matching section - before regular residues + # Phenylglycine and derivatives + if 'c1ccccc1' in content: + if '[C@@H](c1ccccc1)' in content or '[C@H](c1ccccc1)' in content: + return '4', mods # Base phenylglycine + + # 4-substituted phenylalanines + if 'Cc1ccc' in content: + if 'OMe' in content or 'OCc1ccc' in content: + return '0A1', mods # 4-methoxy-Phenylalanine + elif 'Clc1ccc' in content: + return '200', mods # 4-chloro-Phenylalanine + elif 'Brc1ccc' in content: + return '4BF', mods # 4-Bromo-phenylalanine + elif 'C#Nc1ccc' in content: + return '4CF', mods # 4-cyano-phenylalanine + elif 'Ic1ccc' in content: + return 'PHI', mods # 4-Iodo-phenylalanine + elif 'Fc1ccc' in content: + return 'PFF', mods # 4-Fluoro-phenylalanine + + # Modified tryptophans + if 'c[nH]c2' in content: + if 'Oc2cccc2' in content: + return '0AF', mods # 7-hydroxy-tryptophan + elif 'Fc2cccc2' in content: + return '4FW', mods # 4-fluoro-tryptophan + elif 'Clc2cccc2' in content: + return '6CW', mods # 6-chloro-tryptophan + elif 'Brc2cccc2' in content: + return 'BTR', mods # 6-bromo-tryptophan + elif 'COc2cccc2' in content: + return 'MOT5', mods # 5-Methoxy-tryptophan + elif 'Cc2cccc2' in content: + return 'MTR5', mods # 5-Methyl-tryptophan + + # Special amino acids + if 'CC(C)(C)[C@@H]' in content or 'CC(C)(C)[C@H]' in content: + return 'BUG', mods # Tertleucine + + if 'CCCNC(=N)N' in 
content: + return 'CIR', mods # Citrulline + + if '[SeH]' in content: + return 'CSE', mods # Selenocysteine + + if '[NH3]CC[C@@H]' in content or '[NH3]CC[C@H]' in content: + return 'DAB', mods # Diaminobutyric acid + + if 'C1CCCCC1' in content: + if 'C1CCCCC1[C@@H]' in content or 'C1CCCCC1[C@H]' in content: + return 'CHG', mods # Cyclohexylglycine + elif 'C1CCCCC1C[C@@H]' in content or 'C1CCCCC1C[C@H]' in content: + return 'ALC', mods # 3-cyclohexyl-alanine + + # Naphthalene derivatives + if 'c1cccc2c1cccc2' in content: + if 'c1cccc2c1cccc2[C@@H]' in content or 'c1cccc2c1cccc2[C@H]' in content: + return 'NAL', mods # 2-Naphthyl-alanine + + # Heteroaromatic derivatives + if 'c1cncc' in content: + return 'PYR4', mods # 3-(4-Pyridyl)-alanine + if 'c1cscc' in content: + return 'THA3', mods # 3-(3-thienyl)-alanine + if 'c1nnc' in content: + return 'TRZ4', mods # 3-(1,2,4-Triazol-1-yl)-alanine + + # Modified serines and threonines + if 'OP(O)(O)O' in content: + if '[C@@H](COP' in content or '[C@H](COP' in content: + return 'SEP', mods # phosphoserine + elif '[C@@H](OP' in content or '[C@H](OP' in content: + return 'TPO', mods # phosphothreonine + + # Specialized ring systems + if 'c1c2ccccc2cc2c1cccc2' in content: + return 'ANTH', mods # 3-(9-anthryl)-alanine + if 'c1csc2c1cccc2' in content: + return 'BTH3', mods # 3-(3-benzothienyl)-alanine + if '[C@]12C[C@H]3C[C@@H](C2)C[C@@H](C1)C3' in content: + return 'ADAM', mods # Adamanthane + + # Fluorinated derivatives + if 'FC(F)(F)' in content: + if 'CC(F)(F)F' in content: + return 'FLA', mods # Trifluoro-alanine + if 'C(F)(F)F)c1' in content: + if 'c1ccccc1C(F)(F)F' in content: + return 'TFG2', mods # 2-(Trifluoromethyl)-phenylglycine + if 'c1cccc(c1)C(F)(F)F' in content: + return 'TFG3', mods # 3-(Trifluoromethyl)-phenylglycine + if 'c1ccc(cc1)C(F)(F)F' in content: + return 'TFG4', mods # 4-(Trifluoromethyl)-phenylglycine + + # Multiple halogen patterns + if 'F' in content and 'c1' in content: + if 'c1ccc(c(c1)F)F' in 
content: + return 'F2F', mods # 3,4-Difluoro-phenylalanine + if 'cc(F)cc(c1)F' in content: + return 'WFP', mods # 3,5-Difluoro-phenylalanine + if 'Cl' in content and 'c1' in content: + if 'c1ccc(cc1Cl)Cl' in content: + return 'CP24', mods # 2,4-dichloro-phenylalanine + if 'c1ccc(c(c1)Cl)Cl' in content: + return 'CP34', mods # 3,4-dichloro-phenylalanine + + # Hydroxy and amino derivatives + if 'O' in content and 'c1' in content: + if 'c1cc(O)cc(c1)O' in content: + return '3FG', mods # (2s)-amino(3,5-dihydroxyphenyl)-ethanoic acid + if 'c1ccc(c(c1)O)O' in content: + return 'DAH', mods # 3,4-Dihydroxy-phenylalanine + + # Cyclic amino acids + if 'C1CCCC1' in content: + return 'CPA3', mods # 3-Cyclopentyl-alanine + if 'C1CCCCC1' in content: + if 'CC1CCCCC1' in content: + return 'ALC', mods # 3-cyclohexyl-alanine + else: + return 'CHG', mods # Cyclohexylglycine + + # Chain-length variants + if 'CCC[C@@H]' in content or 'CCC[C@H]' in content: + return 'NLE', mods # Norleucine + if 'CC[C@@H]' in content or 'CC[C@H]' in content: + if not any(x in content for x in ['CC(C)', 'COC', 'CN(']): + return 'ABA', mods # 2-Aminobutyric acid + + # Modified histidines + if 'c1cnc' in content: + if '[C@@H]1CN[C@@H](N1)F' in content: + return '2HF', mods # 2-fluoro-l-histidine + if 'c1cnc([nH]1)F' in content: + return '2HF1', mods # 2-fluoro-l-histidine variant + if 'c1c[nH]c(n1)F' in content: + return '2HF2', mods # 2-fluoro-l-histidine variant + + # Sulfur and selenium containing + if '[SeH]' in content: + return 'CSE', mods # Selenocysteine + if 'S' in content: + if 'CSCc1ccccc1' in content: + return 'BCS', mods # benzylcysteine + if 'CCSC' in content: + return 'ESC', mods # Ethionine + if 'CCS' in content: + return 'HCS', mods # homocysteine + + # Additional modifications + if 'CN=[N]=N' in content: + return 'AZDA', mods # azido-alanine + if '[NH]=[C](=[NH2])=[NH2]' in content: + if 'CCC[NH]=' in content: + return 'AGM', mods # 5-methyl-arginine + if 'CC[NH]=' in content: + return 
'GDPR', mods # 2-Amino-3-guanidinopropionic acid + + if 'CCON' in content: + return 'CAN', mods # canaline + if '[C@@H]1C=C[C@@H](C=C1)' in content: + return 'ACZ', mods # cis-amiclenomycin + if 'CCC(=O)[NH3]' in content: + return 'ONL', mods # 5-oxo-l-norleucine + if 'c1ccncc1' in content: + return 'PYR4', mods # 3-(4-Pyridyl)-alanine + if 'c1ccco1' in content: + return 'FUA2', mods # (2-furyl)-alanine + + if 'c1ccc' in content: + if 'c1ccc(cc1)c1ccccc1' in content: + return 'BIF', mods # 4,4-biphenylalanine + if 'c1ccc(cc1)C(=O)c1ccccc1' in content: + return 'PBF', mods # 4-benzoyl-phenylalanine + if 'c1ccc(cc1)C(C)(C)C' in content: + return 'TBP4', mods # 4-tert-butyl-phenylalanine + if 'c1ccc(cc1)[C](=[NH2])=[NH2]' in content: + return '0BN', mods # 4-carbamimidoyl-l-phenylalanine + if 'c1cccc(c1)[C](=[NH2])=[NH2]' in content: + return 'APM', mods # m-amidinophenyl-3-alanine + + # Multiple hydroxy patterns + if 'O' in content: + if '[C@H]([C@H](C)O)O' in content: + return 'ILX', mods # 4,5-dihydroxy-isoleucine + if '[C@H]([C@@H](C)O)O' in content: + return 'ALO', mods # Allo-threonine + if '[C@H](COP(O)(O)O)' in content: + return 'SEP', mods # phosphoserine + if '[C@H]([C@@H](C)OP(O)(O)O)' in content: + return 'TPO', mods # phosphothreonine + if '[C@H](c1ccc(O)cc1)O' in content: + return 'OMX', mods # (betar)-beta-hydroxy-l-tyrosine + if '[C@H](c1ccc(c(Cl)c1)O)O' in content: + return 'OMY', mods # (betar)-3-chloro-beta-hydroxy-l-tyrosine + + # Heterocyclic patterns + if 'n1' in content: + if 'n1cccn1' in content: + return 'PYZ1', mods # 3-(1-Pyrazolyl)-alanine + if 'n1nncn1' in content: + return 'TEZA', mods # 3-(2-Tetrazolyl)-alanine + if 'c2c(n1)cccc2' in content: + return 'QU32', mods # 3-(2-Quinolyl)-alanine + if 'c1cnc2c(c1)cccc2' in content: + return 'QU33', mods # 3-(3-quinolyl)-alanine + if 'c1ccnc2c1cccc2' in content: + return 'QU34', mods # 3-(4-quinolyl)-alanine + if 'c1ccc2c(c1)nccc2' in content: + return 'QU35', mods # 3-(5-Quinolyl)-alanine + if 
'c1ccc2c(c1)cncc2' in content: + return 'QU36', mods # 3-(6-Quinolyl)-alanine + if 'c1cnc2c(n1)cccc2' in content: + return 'QX32', mods # 3-(2-quinoxalyl)-alanine + + # Multiple nitrogen patterns + if 'N' in content: + if '[NH3]CC[C@@H]' in content: + return 'DAB', mods # Diaminobutyric acid + if '[NH3]C[C@@H]' in content: + return 'DPP', mods # 2,3-Diaminopropanoic acid + if '[NH3]CCCCCC[C@@H]' in content: + return 'HHK', mods # (2s)-2,8-diaminooctanoic acid + if 'CCC[NH]=[C](=[NH2])=[NH2]' in content: + return 'GBUT', mods # 2-Amino-4-guanidinobutryric acid + if '[NH]=[C](=S)=[NH2]' in content: + return 'THIC', mods # Thio-citrulline + + # Chain modified amino acids + if 'CC' in content: + if 'CCCC[C@@H]' in content: + return 'AHP', mods # 2-Aminoheptanoic acid + if 'CCC([C@@H])(C)C' in content: + return 'I2M', mods # 3-methyl-l-alloisoleucine + if 'CC[C@H]([C@@H])C' in content: + return 'IIL', mods # Allo-Isoleucine + if '[C@H](CCC(C)C)' in content: + return 'HLEU', mods # Homoleucine + if '[C@@H]([C@@H](C)O)C' in content: + return 'HLU', mods # beta-hydroxyleucine + + # Modified glutamate/aspartate patterns + if '[C@@H]' in content: + if '[C@@H](C[C@@H](F))' in content: + return 'FGA4', mods # 4-Fluoro-glutamic acid + if '[C@@H](C[C@@H](O))' in content: + return '3GL', mods # 4-hydroxy-glutamic-acid + if '[C@@H](C[C@H](C))' in content: + return 'LME', mods # (3r)-3-methyl-l-glutamic acid + if '[C@@H](CC[C@H](C))' in content: + return 'MEG', mods # (3s)-3-methyl-l-glutamic acid + + # Sulfur and selenium modifications + if 'S' in content: + if 'SCC[C@@H]' in content: + return 'HSER', mods # homoserine + if 'SCCN' in content: + return 'SLZ', mods # thialysine + if 'SC(=O)' in content: + return 'CSA', mods # s-acetonylcysteine + if '[S@@](=O)' in content: + return 'SME', mods # Methionine sulfoxide + if 'S(=O)(=O)' in content: + return 'OMT', mods # Methionine sulfone + + # Double bond containing + if 'C=' in content: + if 'C=C[C@@H]' in content: + return '2AG', 
mods # 2-Allyl-glycine + if 'C=C[C@@H]' in content: + return 'LVG', mods # vinylglycine + if 'C=Cc1ccccc1' in content: + return 'STYA', mods # Styrylalanine + + # Special cases + if '[C@@H]1Cc2c(C1)cccc2' in content: + return 'IGL', mods # alpha-amino-2-indanacetic acid + if '[C](=[C](=O)=O)=O' in content: + return '26P', mods # 2-amino-6-oxopimelic acid + if '[C](=[C](=O)=O)=C' in content: + return '2NP', mods # l-2-amino-6-methylene-pimelic acid + if 'c2cnc[nH]2' in content: + return 'HIS', mods # histidine core + if 'c1cccc2c1cc(O)cc2' in content: + return 'NAO1', mods # 5-hydroxy-1-naphthalene + if 'c1ccc2c(c1)cc(O)cc2' in content: + return 'NAO2', mods # 6-hydroxy-2-naphthalene + + # Proline (P) - flexible ring numbers + if any([ + # Check for any ring number in bond patterns + (segment.get('bond_after', '').startswith(f'N{n}C(=O)') and 'CCC' in content and + any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789')) + for n in '123456789' + ]) or any([ + # Check ending patterns with any ring number + (f'CCCN{n}' in content and content.endswith('=O') and + any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789')) + for n in '123456789' + ]) or any([ + # Handle CCC[C@H]n patterns + (content == f'CCC[C@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or + (content == f'CCC[C@@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or + # N-terminal Pro with any ring number + (f'N{n}CCC[C@H]{n}' in content) or + (f'N{n}CCC[C@@H]{n}' in content) + for n in '123456789' + ]): + return 'Pro', mods + + # Tryptophan (W) - more specific indole pattern + if re.search(r'c[0-9]c\[nH\]c[0-9]ccccc[0-9][0-9]', content) and \ + 'c[nH]c' in content.replace(' ', ''): + return 'Trp', mods + + # Lysine (K) - both patterns + if '[C@@H](CCCCN)' in content or '[C@H](CCCCN)' in content: + return 'Lys', mods + + # Arginine (R) - both patterns + if '[C@@H](CCCNC(=N)N)' in content or '[C@H](CCCNC(=N)N)' in content: + 
return 'Arg', mods + + if ('C[C@H](CCCC)' in content or 'C[C@@H](CCCC)' in content) and 'CC(C)' not in content: + return 'Nle', mods + + # Ornithine (Orn) - 3-carbon chain with NH2 + if ('C[C@H](CCCN)' in content or 'C[C@@H](CCCN)' in content) and 'CC(C)' not in content: + return 'Orn', mods + + # 2-Naphthylalanine (2Nal) - distinct from Phe pattern + if ('Cc3cc2ccccc2c3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return '2Nal', mods + + # Cyclohexylalanine (Cha) - already in your code but moved here for clarity + if 'N2CCCCC2' in content or 'CCCCC2' in content: + return 'Cha', mods + + # Aminobutyric acid (Abu) - 2-carbon chain + if ('C[C@H](CC)' in content or 'C[C@@H](CC)' in content) and not any(p in content for p in ['CC(C)', 'CCCC', 'CCC(C)']): + return 'Abu', mods + + # Pipecolic acid (Pip) - 6-membered ring like Pro + if ('N3CCCCC3' in content or 'CCCCC3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return 'Pip', mods + + # Cyclohexylglycine (Chg) - direct cyclohexyl without CH2 + if ('C[C@H](C1CCCCC1)' in content or 'C[C@@H](C1CCCCC1)' in content): + return 'Chg', mods + + # 4-Fluorophenylalanine (4F-Phe) + if ('Cc2ccc(F)cc2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return '4F-Phe', mods + + # Regular residue identification + if ('NCC(=O)' in content) or (content == 'C'): + # Middle case - between bonds + if segment.get('bond_before') and segment.get('bond_after'): + if ('C(=O)N' in segment['bond_before'] or 'C(=O)N(C)' in segment['bond_before']): + return 'Gly', mods + # Terminal case - at the end + elif segment.get('bond_before') and segment.get('bond_before').startswith('C(=O)N'): + return 'Gly', mods + + if 'CC(C)C[C@H]' in content or 'CC(C)C[C@@H]' in content: + return 'Leu', mods + if '[C@@H](CC(C)C)' in content or '[C@H](CC(C)C)' in content: + return 'Leu', mods + + if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content: + return 'Thr', mods + + if '[C@H](Cc2ccccc2)' in 
content or '[C@@H](Cc2ccccc2)' in content: + return 'Phe', mods + + if ('[C@H](C(C)C)' in content or # With outer parentheses + '[C@@H](C(C)C)' in content or # With outer parentheses + '[C@H]C(C)C' in content or # Without outer parentheses + '[C@@H]C(C)C' in content): # Without outer parentheses + if not any(p in content for p in ['CC(C)C[C@H]', 'CC(C)C[C@@H]']): # Still check not Leu + return 'Val', mods + + if '[C@H](COC(C)(C)C)' in content or '[C@@H](COC(C)(C)C)' in content: + return 'O-tBu', mods + + if any([ + 'CC[C@H](C)' in content, + 'CC[C@@H](C)' in content, + 'C(C)C[C@H]' in content and 'CC(C)C' not in content, + 'C(C)C[C@@H]' in content and 'CC(C)C' not in content + ]): + return 'Ile', mods + + if ('[C@H](C)' in content or '[C@@H](C)' in content): + if not any(p in content for p in ['C(C)C', 'COC', 'CN(', 'C(C)O', 'CC[C@H]', 'CC[C@@H]']): + return 'Ala', mods + + # Tyrosine (Tyr) - 4-hydroxybenzyl side chain + if re.search(r'Cc[0-9]ccc\(O\)cc[0-9]', content): + return 'Tyr', mods + + + # Serine (Ser) - Hydroxymethyl side chain + if '[C@H](CO)' in content or '[C@@H](CO)' in content: + if not ('C(C)O' in content or 'COC' in content): + return 'Ser', mods + + # Threonine (Thr) - 1-hydroxyethyl side chain + if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content or '[C@@H](C)O' in content or '[C@H](C)O' in content: + return 'Thr', mods + + # Cysteine (Cys) - Thiol side chain + if '[C@H](CS)' in content or '[C@@H](CS)' in content: + return 'Cys', mods + + # Methionine (Met) - Methylthioethyl side chain + if ('C[C@H](CCSC)' in content or 'C[C@@H](CCSC)' in content): + return 'Met', mods + + # Asparagine (Asn) - Carbamoylmethyl side chain + if ('CC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return 'Asn', mods + + # Glutamine (Gln) - Carbamoylethyl side chain + if ('CCC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return 'Gln', mods + + # Aspartic acid (Asp) - Carboxymethyl side chain + if 
('CC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return 'Asp', mods + + # Glutamic acid (Glu) - Carboxyethyl side chain + if ('CCC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return 'Glu', mods + + # Arginine (Arg) - 3-guanidinopropyl side chain + if ('CCCNC(=N)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return 'Arg', mods + + # Histidine (His) - Imidazole side chain + if ('Cc2cnc[nH]2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return 'His', mods + + return None, mods + + def get_modifications(self, segment): + """Get modifications based on bond types""" + mods = [] + if segment.get('bond_after'): + if 'N(C)' in segment['bond_after'] or segment['bond_after'].startswith('C(=O)N(C)'): + mods.append('N-Me') + if 'OC(=O)' in segment['bond_after']: + mods.append('O-linked') + return mods + + def analyze_structure(self, smiles): + """Main analysis function with debug output""" + print("\nAnalyzing structure:", smiles) + + # Split into segments + segments = self.split_on_bonds(smiles) + + print("\nSegment Analysis:") + sequence = [] + for i, segment in enumerate(segments): + print(f"\nSegment {i}:") + print(f"Content: {segment['content']}") + print(f"Bond before: {segment.get('bond_before', 'None')}") + print(f"Bond after: {segment.get('bond_after', 'None')}") + + residue, mods = self.identify_residue(segment) + if residue: + if mods: + sequence.append(f"{residue}({','.join(mods)})") + else: + sequence.append(residue) + print(f"Identified as: {residue}") + print(f"Modifications: {mods}") + else: + print(f"Warning: Could not identify residue in segment: {segment['content']}") + + # Check if cyclic + is_cyclic, peptide_cycles, aromatic_cycles = self.is_cyclic(smiles) + three_letter = '-'.join(sequence) + one_letter = ''.join(self.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence) + + if is_cyclic: + three_letter = f"cyclo({three_letter})" + one_letter = 
def return_sequence(self, smiles):
    """Return the residue labels found in ``smiles``, printing a debug trace.

    The SMILES string is split with ``self.split_on_bonds`` and every segment
    is passed to ``self.identify_residue``. Identified residues are collected
    in segment order; a residue with modifications is rendered as
    ``Name(mod1,mod2)``. Segments that cannot be identified are reported on
    stdout and left out of the result.

    Args:
        smiles: SMILES string handed to ``self.split_on_bonds``.

    Returns:
        List of residue label strings (no ``cyclo(...)`` wrapping is applied).
    """
    print("\nAnalyzing structure:", smiles)

    # Split into segments
    pieces = self.split_on_bonds(smiles)

    print("\nSegment Analysis:")
    labels = []
    for idx, piece in enumerate(pieces):
        print(f"\nSegment {idx}:")
        print(f"Content: {piece['content']}")
        print(f"Bond before: {piece.get('bond_before', 'None')}")
        print(f"Bond after: {piece.get('bond_after', 'None')}")

        name, modifications = self.identify_residue(piece)
        if not name:
            print(f"Warning: Could not identify residue in segment: {piece['content']}")
            continue

        if modifications:
            labels.append(f"{name}({','.join(modifications)})")
        else:
            labels.append(name)
        print(f"Identified as: {name}")
        print(f"Modifications: {modifications}")

    return labels
Image.open(BytesIO(drawer.GetDrawingText())) + draw = ImageDraw.Draw(img) + + try: + # Try to use DejaVuSans as it's commonly available on Linux systems + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60) + small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60) + except OSError: + try: + # Fallback to Arial if available (common on Windows) + font = ImageFont.truetype("arial.ttf", 60) + small_font = ImageFont.truetype("arial.ttf", 60) + except OSError: + # If no TrueType fonts are available, fall back to default + print("Warning: TrueType fonts not available, using default font") + font = ImageFont.load_default() + small_font = ImageFont.load_default() + # Get molecule bounds + conf = mol.GetConformer() + positions = [] + for i in range(mol.GetNumAtoms()): + pos = conf.GetAtomPosition(i) + positions.append((pos.x, pos.y)) + + x_coords = [p[0] for p in positions] + y_coords = [p[1] for p in positions] + min_x, max_x = min(x_coords), max(x_coords) + min_y, max_y = min(y_coords), max(y_coords) + + # Calculate scaling factors + scale = 150 # Increased scale factor + center_x = 1000 # Image center + center_y = 1000 + + # Add residue labels in a circular arrangement around the structure + n_residues = len(residues) + radius = 700 # Distance of labels from center + + # Start from the rightmost point (3 o'clock position) and go counterclockwise + # Offset by -3 positions to align with structure + offset = 0 # Adjust this value to match the structure alignment + for i, residue in enumerate(residues): + # Calculate position in a circle around the structure + # Start from 0 (3 o'clock) and go counterclockwise + angle = -(2 * np.pi * ((i + offset) % n_residues) / n_residues) + + # Calculate label position + label_x = center_x + radius * np.cos(angle) + label_y = center_y + radius * np.sin(angle) + + # Draw residue label + text = f"{i+1}. 
def _bond_flags(analyzer, segment):
    """Return ``(is_ester, is_n_methylated)`` for the bond after ``segment``."""
    # get_modifications() translates the raw bond SMILES into the labels
    # 'O-linked' / 'N-Me' and tolerates a missing or None 'bond_after'.
    mods = analyzer.get_modifications(segment)
    return 'O-linked' in mods, 'N-Me' in mods


def create_enhanced_linear_viz(sequence, smiles):
    """Draw a linear overview of a peptide plus a per-segment breakdown.

    The figure has two stacked axes: the top one shows one box per residue of
    ``sequence`` joined by bond lines (ester bonds red/dashed, peptide bonds
    black/solid, N-methylated bonds annotated); the bottom one lists every
    segment returned by ``PeptideAnalyzer.split_on_bonds(smiles)`` with its
    identified residue or bond type and its SMILES content.

    Bond types are derived through ``PeptideAnalyzer.get_modifications`` so
    that the raw ``bond_after`` SMILES text is interpreted consistently with
    the rest of the analyzer (the labels ``'O-linked'`` / ``'N-Me'`` never
    occur inside the raw bond text itself, and ``bond_after`` may be ``None``).

    Args:
        sequence: Three-letter sequence joined by ``-``, optionally wrapped in
            ``cyclo(...)``.
        smiles: SMILES string of the same peptide.

    Returns:
        The matplotlib figure; the caller is responsible for closing it.
    """
    analyzer = PeptideAnalyzer()  # Create analyzer instance

    # Create figure with two subplots
    fig = plt.figure(figsize=(15, 10))
    gs = fig.add_gridspec(2, 1, height_ratios=[1, 2])
    ax_struct = fig.add_subplot(gs[0])
    ax_detail = fig.add_subplot(gs[1])

    # Parse sequence and get residues
    is_cyclic = sequence.startswith('cyclo(')
    if is_cyclic:
        residues = sequence[6:-1].split('-')
    else:
        residues = sequence.split('-')

    # Get segments using analyzer
    segments = analyzer.split_on_bonds(smiles)

    # Debug print
    print(f"Number of residues: {len(residues)}")
    print(f"Number of segments: {len(segments)}")

    # Top subplot - Basic structure
    ax_struct.set_xlim(0, 10)
    ax_struct.set_ylim(0, 2)

    num_residues = len(residues)
    spacing = 9.0 / (num_residues - 1) if num_residues > 1 else 9.0

    # Draw basic structure
    y_pos = 1.5
    for i, residue_label in enumerate(residues):
        x_pos = 0.5 + i * spacing

        # Draw amino acid box
        rect = patches.Rectangle((x_pos-0.3, y_pos-0.2), 0.6, 0.4,
                                 facecolor='lightblue', edgecolor='black')
        ax_struct.add_patch(rect)

        # Draw connecting bonds if not the last residue
        if i < num_residues - 1:
            segment = segments[i] if i < len(segments) else None
            if segment:
                is_ester, is_n_methylated = _bond_flags(analyzer, segment)
                bond_type = 'ester' if is_ester else 'peptide'
                bond_color = 'red' if is_ester else 'black'
                linestyle = '--' if is_ester else '-'

                # Draw bond line
                ax_struct.plot([x_pos+0.3, x_pos+spacing-0.3], [y_pos, y_pos],
                               color=bond_color, linestyle=linestyle, linewidth=2)

                # Add bond type label
                mid_x = x_pos + spacing/2
                bond_label = f"{bond_type}"
                if is_n_methylated:
                    bond_label += "\n(N-Me)"
                ax_struct.text(mid_x, y_pos+0.1, bond_label,
                               ha='center', va='bottom', fontsize=10,
                               color=bond_color)

        # Add residue label
        ax_struct.text(x_pos, y_pos-0.5, residue_label,
                       ha='center', va='top', fontsize=14)

    # Bottom subplot - Detailed breakdown
    ax_detail.set_ylim(0, len(segments)+1)
    ax_detail.set_xlim(0, 1)

    # Create detailed breakdown
    segment_y = len(segments)  # Start from top
    for i, segment in enumerate(segments):
        y = segment_y - i

        # Check if this is a bond or residue
        residue, mods = analyzer.identify_residue(segment)
        if residue:
            text = f"Residue {i+1}: {residue}"
            if mods:
                text += f" ({', '.join(mods)})"
            color = 'blue'
        else:
            # Must be a bond
            is_ester, is_n_methylated = _bond_flags(analyzer, segment)
            text = f"Bond {i}: "
            if is_ester:
                text += "ester"
            elif is_n_methylated:
                text += "peptide (N-methylated)"
            else:
                text += "peptide"
            color = 'red'

        # Add segment analysis
        ax_detail.text(0.05, y, text, fontsize=12, color=color)
        ax_detail.text(0.5, y, f"SMILES: {segment.get('content', '')}", fontsize=10, color='gray')

    # If cyclic, add connection indicator
    if is_cyclic:
        ax_struct.annotate('', xy=(9.5, y_pos), xytext=(0.5, y_pos),
                           arrowprops=dict(arrowstyle='<->', color='red', lw=2))
        ax_struct.text(5, y_pos+0.3, 'Cyclic Connection',
                       ha='center', color='red', fontsize=14)

    # Add titles and adjust layout
    ax_struct.set_title("Peptide Structure Overview", pad=20)
    ax_detail.set_title("Segment Analysis Breakdown", pad=20)

    # Remove axes
    for ax in [ax_struct, ax_detail]:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')

    plt.tight_layout()
    return fig
Chem.SanitizeMol(mol, + sanitizeOps=Chem.SANITIZE_FINDRADICALS| + Chem.SANITIZE_KEKULIZE| + Chem.SANITIZE_SETAROMATICITY| + Chem.SANITIZE_SETCONJUGATION| + Chem.SANITIZE_SETHYBRIDIZATION| + Chem.SANITIZE_CLEANUPCHIRALITY) + + mol = Chem.AddHs(mol) + return mol + + @staticmethod + def get_etkdg_params(attempt=0): + """Get ETKDG parameters with optional modifications based on attempt number""" + params = AllChem.ETKDGv3() + params.randomSeed = -1 + params.maxIterations = 200 + params.numThreads = 4 # Reduced for web interface + params.useBasicKnowledge = True + params.enforceChirality = True + params.useExpTorsionAnglePrefs = True + params.useSmallRingTorsions = True + params.useMacrocycleTorsions = True + params.ETversion = 2 + params.pruneRmsThresh = -1 + params.embedRmsThresh = 0.5 + + if attempt > 10: + params.bondLength = 1.5 + (attempt - 10) * 0.02 + params.useExpTorsionAnglePrefs = False + + return params + + def generate_structure_etkdg(self, smiles, max_attempts=20): + """Generate 3D structure using ETKDG without UFF optimization""" + success = False + mol = None + + for attempt in range(max_attempts): + try: + mol = self.prepare_molecule(smiles) + params = self.get_etkdg_params(attempt) + + if AllChem.EmbedMolecule(mol, params) == 0: + success = True + break + except Exception as e: + continue + + if not success: + raise ValueError("Failed to generate structure with ETKDG") + + return mol + + def generate_structure_uff(self, smiles, max_attempts=20): + """Generate 3D structure using ETKDG followed by UFF optimization""" + best_mol = None + lowest_energy = float('inf') + + for attempt in range(max_attempts): + try: + test_mol = self.prepare_molecule(smiles) + params = self.get_etkdg_params(attempt) + + if AllChem.EmbedMolecule(test_mol, params) == 0: + res = AllChem.UFFOptimizeMolecule(test_mol, maxIters=2000, + vdwThresh=10.0, confId=0, + ignoreInterfragInteractions=True) + + if res == 0: + ff = AllChem.UFFGetMoleculeForceField(test_mol) + if ff: + 
def _format_residue_label(residue, mods):
    """Return ``residue`` with its modifications appended, e.g. ``Ala(N-Me)``."""
    if mods:
        return f"{residue}({','.join(mods)})"
    return residue


def _describe_segments(analyzer, segments, show_details, warn_unidentified):
    """Identify every segment; return ``(sequence_parts, report_text)``.

    ``report_text`` is empty unless ``show_details`` is true. The warning line
    for unidentified segments is only emitted when ``warn_unidentified`` is set.
    """
    parts = []
    report = []
    for i, segment in enumerate(segments):
        if show_details:
            report.append(f"\nSegment {i}:\n")
            report.append(f"Content: {segment['content']}\n")
            report.append(f"Bond before: {segment.get('bond_before', 'None')}\n")
            report.append(f"Bond after: {segment.get('bond_after', 'None')}\n")

        residue, mods = analyzer.identify_residue(segment)
        if residue:
            parts.append(_format_residue_label(residue, mods))
            if show_details:
                report.append(f"Identified as: {residue}\n")
                report.append(f"Modifications: {mods}\n")
        elif show_details and warn_unidentified:
            report.append(f"Warning: Could not identify residue in segment: {segment['content']}\n")
    return parts, "".join(report)


def process_input(smiles_input=None, file_obj=None, show_linear=False,
                  show_segment_details=False, generate_3d=False, use_uff=False):
    """Analyze a peptide SMILES (or a file of SMILES) and build visualizations.

    ``smiles_input`` takes precedence over ``file_obj``. Images and 3D
    structures are only produced for direct SMILES input.

    Args:
        smiles_input: A single SMILES string.
        file_obj: Object with a ``name`` path attribute, raw ``bytes``, or any
            value convertible with ``str()``; one SMILES per line.
        show_linear: Also render the linear overview figure.
        show_segment_details: Include the per-segment breakdown in the text.
        generate_3d: Write an ETKDG-embedded SDF file to a fresh temp dir.
        use_uff: Additionally write a UFF-optimized SDF file.

    Returns:
        Always a 4-tuple ``(text, cyclic_image, linear_image, structure_files)``.
        Entries that do not apply are ``None``; on failure ``text`` carries the
        error message. Every path returns the same arity so callers can
        unpack uniformly.
        NOTE(review): presumably wired to a UI with four outputs - confirm.
    """
    analyzer = PeptideAnalyzer()

    # Handle direct SMILES input
    if smiles_input:
        smiles = smiles_input.strip()

        # First check if it's a peptide using analyzer's method
        if not analyzer.is_peptide(smiles):
            return "Error: Input SMILES does not appear to be a peptide structure.", None, None, None

        try:
            # Create molecule
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return "Error: Invalid SMILES notation.", None, None, None

            # Generate 3D structures if requested
            structure_files = []
            if generate_3d:
                import shutil

                generator = PeptideStructureGenerator()
                # Created only once the input is known to be valid, so early
                # exits above do not leave an empty directory behind.
                temp_dir = tempfile.mkdtemp()
                try:
                    # Generate ETKDG structure
                    mol_etkdg = generator.generate_structure_etkdg(smiles)
                    etkdg_path = os.path.join(temp_dir, "structure_etkdg.sdf")
                    writer = Chem.SDWriter(etkdg_path)
                    writer.write(mol_etkdg)
                    writer.close()
                    structure_files.append(etkdg_path)

                    # Generate UFF structure if requested
                    if use_uff:
                        mol_uff = generator.generate_structure_uff(smiles)
                        uff_path = os.path.join(temp_dir, "structure_uff.sdf")
                        writer = Chem.SDWriter(uff_path)
                        writer.write(mol_uff)
                        writer.close()
                        structure_files.append(uff_path)

                except Exception as e:
                    # No file from this directory is handed back, so drop it.
                    shutil.rmtree(temp_dir, ignore_errors=True)
                    return f"Error generating 3D structures: {str(e)}", None, None, None

            # Use analyzer to get sequence
            segments = analyzer.split_on_bonds(smiles)
            sequence_parts, details = _describe_segments(
                analyzer, segments, show_segment_details, warn_unidentified=True)

            # Only include segment analysis in output if requested
            output_text = ""
            if show_segment_details:
                output_text = "Segment Analysis:\n" + details + "\n"

            # Check if cyclic using analyzer's method
            is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
            three_letter = '-'.join(sequence_parts)
            one_letter = ''.join(analyzer.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence_parts)

            if is_cyclic:
                three_letter = f"cyclo({three_letter})"
                one_letter = f"cyclo({one_letter})"

            # Create cyclic structure visualization
            img_cyclic = annotate_cyclic_structure(mol, three_letter)

            # Create linear representation if requested
            img_linear = None
            if show_linear:
                fig_linear = create_enhanced_linear_viz(three_letter, smiles)
                buf = BytesIO()
                fig_linear.savefig(buf, format='png', bbox_inches='tight', dpi=300)
                buf.seek(0)
                img_linear = Image.open(buf)
                plt.close(fig_linear)

            # Add summary to output
            summary = "Summary:\n"
            summary += f"Sequence: {three_letter}\n"
            summary += f"One-letter code: {one_letter}\n"
            summary += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"

            if structure_files:
                summary += "\n3D Structures Generated:\n"
                for filepath in structure_files:
                    summary += f"- {os.path.basename(filepath)}\n"

            return summary + output_text, img_cyclic, img_linear, structure_files or None

        except Exception as e:
            return f"Error processing SMILES: {str(e)}", None, None, None

    # Handle file input
    if file_obj is not None:
        try:
            # Handle file content
            if hasattr(file_obj, 'name'):
                with open(file_obj.name, 'r') as f:
                    content = f.read()
            else:
                content = file_obj.decode('utf-8') if isinstance(file_obj, bytes) else str(file_obj)

            output_text = ""
            for line in content.splitlines():
                smiles = line.strip()
                if not smiles:
                    continue

                # Check if it's a peptide
                if not analyzer.is_peptide(smiles):
                    output_text += f"Skipping non-peptide SMILES: {smiles}\n"
                    continue

                # Process this SMILES
                segments = analyzer.split_on_bonds(smiles)
                if show_segment_details:
                    output_text += f"\nSegment Analysis for SMILES: {smiles}\n"
                sequence_parts, details = _describe_segments(
                    analyzer, segments, show_segment_details, warn_unidentified=False)
                output_text += details

                # Get cyclicity and create sequence
                is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
                sequence = f"cyclo({'-'.join(sequence_parts)})" if is_cyclic else '-'.join(sequence_parts)

                output_text += f"\nSummary for SMILES: {smiles}\n"
                output_text += f"Sequence: {sequence}\n"
                output_text += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
                if is_cyclic:
                    output_text += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
                output_text += "-" * 50 + "\n"

            return output_text, None, None, None

        except Exception as e:
            return f"Error processing file: {str(e)}", None, None, None

    return "No input provided.", None, None, None
def sample_categorical_logits(logits, dtype=torch.float64):
    """Draw one category index per row of ``logits`` via the Gumbel-max trick.

    The logits do not need to be log-softmaxed. Uniform noise is drawn in
    ``dtype`` with the same shape as ``logits``, turned into Gumbel noise
    (with 1e-10 guards inside both logarithms), added to the logits, and the
    argmax over the last dimension is returned.

    Args:
        logits: Tensor whose last dimension indexes the categories.
        dtype: Floating dtype used for the noise.

    Returns:
        Integer tensor of ``logits.shape[:-1]`` holding the sampled indices.
    """
    uniform = torch.rand_like(logits, dtype=dtype)
    inner = 1e-10 - torch.log(uniform + 1e-10)
    gumbel_noise = -torch.log(inner)
    perturbed = logits + gumbel_noise
    return torch.argmax(perturbed, dim=-1)
def str2bool(v):
    """Parse a boolean-like value; usable as an ``argparse`` ``type=`` callable.

    Booleans are returned unchanged. Any other value is converted with
    ``str()``, stripped of surrounding whitespace and compared
    case-insensitively against the accepted spellings, so non-string inputs
    such as ``1`` or ``None`` yield a result or a clean ``ArgumentTypeError``
    instead of an ``AttributeError``.

    Args:
        v: A ``bool``, or a value whose string form is one of
            yes/true/t/y/1 or no/false/f/n/0.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognized spelling.
    """
    if isinstance(v, bool):
        return v
    token = str(v).strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def apply_schedule_aware_insertion(
    model,
    xt_tmp,
    new_xt,
    t,
    dt,
    ext,
    mask,
    pad,
    max_length,
    orig_mask,
    new_pos_orig,
    quality_threshold=1,
):
    """
    Remove low-quality insertions based on insertion confidence while respecting
    the interpolant schedule for expected sequence length.

    Args:
        model: Model exposing ``planner(xt, t)`` (returning a mapping that may
            contain ``"insertion_conf"``) and ``interpolant.insertion_schedule.at(t)``
        xt_tmp: Sequence after insertion [B, L]
        new_xt: Sequence before insertion [B, L] (not read by this function)
        t: Current time [B]
        dt: Time step size
        ext: Number of insertions per gap [B, L+1]
        mask: Mask token ID
        pad: Pad token ID
        max_length: Maximum sequence length (not read by this function)
        orig_mask: Mask of original token positions [B, L]
        new_pos_orig: New positions of original tokens [B, L]
        quality_threshold: If a number, only freshly inserted masks whose
            confidence is below it are candidates for removal. If None, every
            freshly inserted mask is a candidate (schedule-driven deletion).
            In both modes at most ``current_length - scheduled_min_length``
            tokens are removed per row, lowest confidence first.

    Returns:
        xt_tmp: Sequence with the selected insertions removed, the remaining
        tokens compacted to the left and the tail filled with ``pad``. The
        input tensor itself is returned when nothing is removed.
    """
    device = xt_tmp.device
    _, L = xt_tmp.shape

    # Only proceed if there were insertions
    if ext.sum() == 0:
        return xt_tmp

    # Get planner predictions on inserted state. The insertion head is
    # conditioned on the pre-step time t; t_next is only used for the
    # length schedule below.
    t_next = t + dt
    planner_out = model.planner(xt_tmp, t)
    insertion_conf = planner_out.get("insertion_conf", None)
    if insertion_conf is None:
        return xt_tmp
    insertion_conf = insertion_conf.squeeze(-1)  # (B, L)

    # Expected sequence length at next timestep according to schedule.
    # NOTE(review): once the schedule value reaches the 0.1 clamp, dividing
    # and re-multiplying by it gives back (up to float rounding) the current
    # length, so essentially nothing can be removed from then on - confirm
    # this is intended.
    current_length_after = xt_tmp.ne(pad).sum(dim=1).float()  # [B]
    expected_progress = model.interpolant.insertion_schedule.at(t_next)  # [B]
    estimated_final_length = current_length_after / (expected_progress.clamp(min=0.1))
    expected_length = estimated_final_length * expected_progress  # [B]

    # Mark positions in xt_tmp that came from new_xt (originals) vs. fresh
    # insertions. Fancy-indexing scatter avoids a per-batch python loop.
    valid_b, valid_l = orig_mask.nonzero(as_tuple=True)
    valid_p = new_pos_orig[valid_b, valid_l].long().clamp_(0, L - 1)
    is_original = torch.zeros_like(xt_tmp, dtype=torch.bool)
    is_original[valid_b, valid_p] = True
    inserted_positions = (xt_tmp == mask) & ~is_original

    # Two deletion modes, selected by `quality_threshold`:
    #   * number: candidates are insertions whose confidence is below it;
    #   * None:   every fresh insertion is a candidate.
    # Either way the count is capped so the length never falls below the
    # scheduled minimum.
    if quality_threshold is None:
        candidates = inserted_positions
    else:
        candidates = inserted_positions & (insertion_conf < quality_threshold)

    if not candidates.any():
        return xt_tmp

    num_bad = candidates.sum(dim=1)  # [B], long
    current_len = current_length_after.long()
    min_length = expected_length.long().clamp(min=1)  # [B]
    max_removable = (current_len - min_length).clamp(min=0)
    schedule_violates = (current_len - num_bad) < min_length
    k_per_row = torch.where(schedule_violates, max_removable, num_bad)
    k_per_row = torch.where(num_bad > 0, k_per_row, torch.zeros_like(k_per_row))

    # Select the k lowest-confidence candidates per row via a sort:
    # non-candidates score -inf and therefore sort behind all candidates.
    neg_inf = torch.tensor(float('-inf'), device=device, dtype=insertion_conf.dtype)
    scores = torch.where(candidates, -insertion_conf, neg_inf)  # higher = worse
    _, sorted_indices = scores.sort(dim=1, descending=True)
    positions = torch.arange(L, device=device).unsqueeze(0)  # [1, L]
    keep_in_topk = positions < k_per_row.unsqueeze(1)  # [B, L]
    final_bad = torch.zeros_like(candidates)
    final_bad.scatter_(1, sorted_indices, keep_in_topk)

    if not final_bad.any():
        return xt_tmp

    # Compact each row to the left (keep good, drop bad), then pad the tail.
    # A stable sort on the bad flag pushes bad positions to the right while
    # preserving the relative order of everything that is kept.
    _, perm = torch.sort(final_bad.long(), dim=1, stable=True)
    xt_tmp = torch.gather(xt_tmp, 1, perm)
    num_keep = (~final_bad).sum(dim=1)  # [B]
    tail_mask = positions >= num_keep.unsqueeze(1)  # [B, L]
    xt_tmp = torch.where(tail_mask, torch.full_like(xt_tmp, pad), xt_tmp)

    return xt_tmp
+ + Returns: + new_xt: Modified sequence with schedule-aware remasking applied + """ + # Threshold gate (overrides the schedule-driven count when set): remask every + # clean token whose unmasking-quality confidence is below the threshold, + # regardless of the schedule budget. Higher threshold => more remasking. + if unmask_quality_threshold is not None: + to_mask = clean_index & (remasking_conf < unmask_quality_threshold) + return torch.where(to_mask, torch.full_like(new_xt, mask), new_xt) + + t_next = t + dt + num_clean = clean_index.sum(dim=1) # [B], long + current_seq_len = (num_clean + (new_xt == mask).sum(dim=1)).float() # [B] + expected_unmasked_frac = model.interpolant.unmask_schedule.at(t_next) # [B] + expected_num_clean = expected_unmasked_frac * current_seq_len # [B] + masks_to_add = (num_clean.float() - expected_num_clean).round().long() # [B] + + # Per-row k = min(masks_to_add, num_clean), clamped to >= 0. + k_per_row = torch.minimum(masks_to_add.clamp(min=0), num_clean) # [B] + + if k_per_row.sum() == 0: + return new_xt + + # Use confidence to decide which clean tokens to remask: lowest conf first. 
+ remasking_score_temp = -1.0 * remasking_conf # low conf = high score + remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf) + + _, sorted_indices = remasking_score_temp.sort(dim=1, descending=True) + L = remasking_score_temp.shape[1] + positions = torch.arange(L, device=new_xt.device).unsqueeze(0) # [1, L] + keep_in_topk = positions < k_per_row.unsqueeze(1) # [B, L] + to_mask = torch.zeros_like(clean_index) + to_mask.scatter_(1, sorted_indices, keep_in_topk) + new_xt = torch.where(to_mask, torch.full_like(new_xt, mask), new_xt) + + return new_xt diff --git a/a2d2_pep/sampling.py b/a2d2_pep/sampling.py new file mode 100755 index 0000000000000000000000000000000000000000..2bdd2eee730434849c946259f801ad527fd15723 --- /dev/null +++ b/a2d2_pep/sampling.py @@ -0,0 +1,1401 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # add repo root to path + +import torch +from dataclasses import dataclass +from typing import Any, Literal, Optional +import numpy as np +import pandas as pd + +from lightning_modules.mdm import MaskedDiffusionModule + + +@dataclass +class SamplingTraceDatapoint: + t: float + event_type: Literal["insertion", "change"] + position: int + token: Any + + +@dataclass +class SamplingResult: + samples: torch.Tensor + # Trace is supposed to be processed sequentially as updates are not commutative + trace: Optional[list[SamplingTraceDatapoint]] + + def __iter__(self): + yield from [self.samples, self.trace] + + +# Sample from categorical distribution for each position using the transition probabilities +def _sample_tokens(probs: torch.Tensor) -> torch.Tensor: + """Sample one token per position from probability distribution. 
+ Args: + probs: [batch_size, seq_len, vocab_size] transition probabilities + Returns: + [batch_size, seq_len] sampled token indices + """ + batch_size, seq_len, vocab_size = probs.shape + flat_probs = probs.view(-1, vocab_size) + samples = torch.multinomial(flat_probs, num_samples=1) + return samples.view(batch_size, seq_len) + + +def _sample_batched_tokens(probs: torch.Tensor) -> torch.Tensor: + + batch_size, seq_len, vocab_size = probs.shape + + gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, seq_len, vocab_size) + 1e-10) + 1e-10)).to(probs.device) + noisy_logits = torch.log(probs + 1e-10) + gumbel_noise # add Gumbel noise to log probabilities + + # select the highest score (most likely category after Gumbel noise) + samples = noisy_logits.argmax(dim=-1).to(dtype=torch.long) + + return samples.view(batch_size, seq_len) + +@torch.no_grad() +def mdm_euler_sampling( + model: MaskedDiffusionModule, + steps: int, + mask: int, + pad: int, + batch_size: int, + max_length: int, + return_trace: bool = False, + temperature: float = 1.0, +): + assert not return_trace, "Trace is not yet implemented in MDM Euler sampling" + device = next(model.parameters()).device + xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device) + + dt = 1.0 / steps + t = torch.zeros(batch_size, device=device) + + for i in range(steps): + print("i-th sampling step") + # ——— predict and convert rates ——— + pred_rate = model(xt, t) + pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t) + unmask_rate = pred_rate.unmask_rate + + # ——— unmask step (Euler) ——— + mask_pos = (xt == mask).nonzero(as_tuple=True) + unmask_rate[xt != mask] = 0 + unmask_rate[mask_pos + (mask,)] = 0 + unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) + trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) + + _xt = xt.clone() + trans_prob.scatter_add_( + 2, + _xt.unsqueeze(-1), + torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), + ) + + # 
@torch.no_grad()
def any_order_mask_insertion_euler_sampling(
    model: torch.nn.Module,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, Optional[list]]:
    """Euler sampler that alternates unmasking with gap-wise mask insertion.

    Starts from an all-``pad`` batch and runs ``steps`` Euler steps of size
    ``1 / steps``. Each step unmasks tokens at ``mask`` positions and (except
    on the last step) inserts new ``mask`` tokens into the gaps of each
    sequence. On the last step the mask token is removed from the sampling
    distribution.

    Args:
        model: callable as ``model(xt, t)``; must expose
            ``model.interpolant.to_actual_rate(xt, pred, t)`` whose result has
            ``unmask_rate`` (expected (B, L, V)) and ``length_rate``
            (expected (B, L+1)).
        steps: number of Euler steps.
        mask: id of the mask token; ids ``0 .. mask-1`` are treated as the
            regular vocabulary by the uniform fallback.
        pad: id of the padding token.
        batch_size: number of sequences to sample.
        max_length: sequence capacity L.
        return_trace: when True, record per-sequence change/insertion events.
        temperature: accepted for signature parity with the other samplers;
            this function does not apply it.

    Returns:
        ``(xt, sampling_trace)`` where ``xt`` is the (B, L) token tensor and
        ``sampling_trace`` is a list of per-sequence event lists, or ``None``
        when ``return_trace`` is False.
    """
    device = next(model.parameters()).device

    # start from an empty (all-pad) batch
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # row / column index grids reused by the insertion scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        is_last = i == steps - 1

        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        # NOTE(review): the negative diagonal written above is clamped to 0
        # here, so the "stay" weight added below is exactly 1 and rows are
        # left unnormalised before sampling; confirm this is intended.
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" probability (pad positions are routed to the mask column)
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        if is_last:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # mask may not survive the last step

            # Renormalise every masked row. Rows with zero total mass fall
            # back to a uniform distribution over tokens 0..mask-1; the
            # fallback is built per row along the vocab axis so it works for
            # any number of masked positions.
            rows = trans_prob[mask_pos]  # (K, V)
            row_sum = rows.sum(dim=-1, keepdim=True)  # (K, 1)
            degenerate = row_sum == 0.0
            if bool(degenerate.any()):
                uniform = torch.zeros_like(rows)
                uniform[:, :mask] = 1.0 / mask
                safe_sum = torch.where(degenerate, torch.ones_like(row_sum), row_sum)
                rows = torch.where(degenerate, uniform, rows / safe_sum)
            else:
                rows = rows / row_sum
            trans_prob[mask_pos] = rows

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        if not is_last:
            # ——— gap-wise insertion: compute new length, fill masks, scatter tokens ———
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # only gaps inside (or right after) the current sequence are usable
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # reject all insertions for rows that would exceed max_length
            fits = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * fits.view(batch_size, 1).long()

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            # use the post-rejection count so rejected rows keep their length
            new_len = xt_len + ext.sum(dim=1)  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # shift each existing token right by the insertions before it
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            # no insertion on the last step; keep the trace from reading a
            # stale (or undefined) `ext`
            ext = torch.zeros_like(len_rate, dtype=torch.long)
            xt_tmp = new_xt

        if return_trace:
            for batch_idx in range(batch_size):
                # tokens that changed during this step
                for j in range(max_length):
                    if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=t[batch_idx].item(),
                                event_type="change",
                                position=j,
                                token=new_xt[batch_idx, j].item(),
                            )
                        )

                # insertions, reported from the rightmost gap to the leftmost
                for j in range(max_length):
                    gap = max_length - j - 1
                    if ext[batch_idx, gap]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=t[batch_idx].item(),
                                event_type="insertion",
                                position=gap,
                                token=mask,
                            )
                        )

        xt = xt_tmp
        t = t + dt

    return xt, sampling_trace
unmask step (Euler) ——— + mask_pos = (xt == mask).nonzero(as_tuple=True) + unmask_rate[xt != mask] = 0 + unmask_rate[mask_pos + (mask,)] = 0 + unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) + trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) + + # Same for pretrained + pretrained_unmask_rate[xt != mask] = 0 + pretrained_unmask_rate[mask_pos + (mask,)] = 0 + pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1) + pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0) + + # add “stay” probability + _xt = xt.clone() + _xt[xt == pad] = mask + trans_prob.scatter_add_( + 2, + _xt.unsqueeze(-1), + torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), + ) + pretrained_trans_prob.scatter_add_( + 2, + _xt.unsqueeze(-1), + torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype), + ) + + if last_step: + print("Final step, removing mask token from sampling") + trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step + + # renormalize probabilities to ensure they sum to 1 + prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) + # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad) + mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) + if mask_has_zero_prob.any(): + # create uniform distribution over valid tokens (excluding mask and pad) + uniform_prob = torch.zeros_like(trans_prob[0]) + uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1 + trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob + else: + # normalize to sum to 1 + trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum + + new_xt = _sample_tokens(trans_prob) + new_xt[xt == pad] = pad + new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) + + # ——— compute log probabilities for RND ——— + lp = torch.gather(torch.log(trans_prob + 1e-10), 2, 
new_xt.unsqueeze(-1)).squeeze(-1) + lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1) + + changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad) + + log_policy_step = (lp * changed_mask).sum(dim=1) + log_pretrained_step = (lp_pre * changed_mask).sum(dim=1) + + log_rnd = log_pretrained_step - log_policy_step # (B,) + + if not last_step: + # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ——— + ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1) + + insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1) + pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1) + + # log P(ext; λ) = ext*log(λ) - λ + log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1) # (B,) + log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1) # (B,) + + log_insert_diff = log_pretrained_insert - log_policy_insert # (B,) + log_rnd += log_insert_diff + log_pretrained_step += log_pretrained_insert + log_policy_step += log_policy_insert + + xt_len = xt.ne(pad).sum(dim=1) # (B,) + seq_dim = ext.size(1) # Use actual ext dimension to avoid mismatch + gaps = torch.arange(seq_dim, device=device).view(1, -1) + ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() + total_ext = ext.sum(dim=1) + valid = xt_len + total_ext <= max_length + ext = ext * valid.view(batch_size, 1).long() + + ext_ex = ext.int().cumsum(dim=1) # (B, L+1) + new_len = xt_len + total_ext # (B,) + + xt_tmp = torch.full_like(xt, pad) + mask_fill = pos_idx_L < new_len.view(batch_size, 1) + xt_tmp[mask_fill] = mask + + new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) + orig_mask = pos_idx_L < xt_len.view(batch_size, 1) + flat_b = batch_idx_L[orig_mask] + flat_p = new_pos_orig[orig_mask] + xt_tmp[flat_b, flat_p] = new_xt[orig_mask] + else: + xt_tmp = new_xt + + return xt_tmp, log_rnd, log_policy_step, 
@torch.no_grad()
def mcts_reverse_step(
    xt: torch.Tensor,
    t: torch.Tensor,
    dt: float,
    model: torch.nn.Module,
    pretrained: torch.nn.Module,
    mask: int,
    pad: int,
    max_length: int,
    last_step: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run one reverse (unmask + insert) step and score it under both models.

    The step is sampled from ``model`` (the policy); the log-probability of
    the sampled unmasking and insertion events is evaluated under both
    ``model`` and ``pretrained``.

    Args:
        xt: (B, L) current token ids; ``L`` must equal ``max_length``.
        t: time passed to both models alongside ``xt``.
        dt: Euler step size.
        model: policy network; callable as ``model(xt, t)`` and exposing
            ``interpolant.to_actual_rate`` whose result has ``unmask_rate``
            (expected (B, L, V)) and ``length_rate`` (expected (B, L+1)).
        pretrained: reference network with the same interface.
        mask: id of the mask token; ids ``0 .. mask-1`` are treated as the
            regular vocabulary by the uniform fallback.
        pad: id of the padding token.
        max_length: sequence capacity L.
        last_step: when True the mask token is removed from the sampling
            distribution and no insertion is performed.
        temperature: accepted for signature parity; not applied here.

    Returns:
        ``(xt_next, log_rnd, log_policy_step, log_pretrained_step)`` where the
        last three are (B,) tensors and
        ``log_rnd = log_pretrained_step - log_policy_step``.
    """
    device = next(model.parameters()).device

    batch_size = xt.size(0)

    # row / column index grids reused by the insertion scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )

    # ——— predict and convert rates ———
    pred_rate = model(xt, t)
    pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
    unmask_rate = pred_rate.unmask_rate  # (B, L, V)
    len_rate = pred_rate.length_rate  # (B, L+1)

    # ——— pretrained model rates for the log_rnd computation ———
    pretrained_pred = pretrained(xt, t)
    pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
    pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
    pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

    # ——— unmask step (Euler) ———
    mask_pos = (xt == mask).nonzero(as_tuple=True)
    unmask_rate[xt != mask] = 0
    unmask_rate[mask_pos + (mask,)] = 0
    unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

    # same for pretrained
    pretrained_unmask_rate[xt != mask] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

    # add "stay" probability (pad positions are routed to the mask column)
    _xt = xt.clone()
    _xt[xt == pad] = mask
    stay_index = _xt.unsqueeze(-1)
    trans_prob.scatter_add_(
        2,
        stay_index,
        torch.ones_like(stay_index, dtype=trans_prob.dtype),
    )
    pretrained_trans_prob.scatter_add_(
        2,
        stay_index,
        torch.ones_like(stay_index, dtype=pretrained_trans_prob.dtype),
    )

    if last_step:
        print("Final step, removing mask token from sampling")
        trans_prob[mask_pos + (mask,)] = 0.0  # mask may not survive the last step

        # Renormalise every masked row; rows with zero total mass fall back
        # to a uniform distribution over tokens 0..mask-1. The fallback is
        # built per row along the vocab axis so it works for any number of
        # masked positions.
        rows = trans_prob[mask_pos]  # (K, V)
        row_sum = rows.sum(dim=-1, keepdim=True)  # (K, 1)
        degenerate = row_sum == 0.0
        if bool(degenerate.any()):
            uniform = torch.zeros_like(rows)
            uniform[:, :mask] = 1.0 / mask
            safe_sum = torch.where(degenerate, torch.ones_like(row_sum), row_sum)
            rows = torch.where(degenerate, uniform, rows / safe_sum)
        else:
            rows = rows / row_sum
        trans_prob[mask_pos] = rows

    new_xt = _sample_tokens(trans_prob)
    new_xt[xt == pad] = pad
    new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

    # ——— log probabilities of the sampled unmasking events ———
    chosen = new_xt.unsqueeze(-1)
    lp = torch.gather(torch.log(trans_prob + 1e-10), 2, chosen).squeeze(-1)
    lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, chosen).squeeze(-1)

    # only positions that went from mask to a real token contribute
    changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

    log_policy_step = (lp * changed_mask).sum(dim=1)
    log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

    log_rnd = log_pretrained_step - log_policy_step  # (B,)

    if not last_step:
        # ——— gap-wise insertion: compute new length, fill masks, scatter tokens ———
        ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)

        insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
        pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

        # log P(ext; λ) = ext*log(λ) - λ
        # NOTE(review): scored on the raw draw, before the gap/length
        # filtering below; confirm that is the intended proposal.
        log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1)  # (B,)
        log_pretrained_insert = (
            ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate
        ).sum(dim=1)  # (B,)

        log_rnd += log_pretrained_insert - log_policy_insert
        log_pretrained_step += log_pretrained_insert
        log_policy_step += log_policy_insert

        xt_len = xt.ne(pad).sum(dim=1)  # (B,)
        seq_dim = ext.size(1)  # use actual ext dimension to avoid mismatch
        gaps = torch.arange(seq_dim, device=device).view(1, -1)
        # only gaps inside (or right after) the current sequence are usable
        ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
        # reject all insertions for rows that would exceed max_length
        fits = xt_len + ext.sum(dim=1) <= max_length
        ext = ext * fits.view(batch_size, 1).long()

        ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
        # use the post-rejection count so rejected rows keep their length
        new_len = xt_len + ext.sum(dim=1)  # (B,)

        xt_tmp = torch.full_like(xt, pad)
        mask_fill = pos_idx_L < new_len.view(batch_size, 1)
        xt_tmp[mask_fill] = mask

        # shift each existing token right by the insertions before it
        new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
        orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
        flat_b = batch_idx_L[orig_mask]
        flat_p = new_pos_orig[orig_mask]
        xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
    else:
        xt_tmp = new_xt

    return xt_tmp, log_rnd, log_policy_step, log_pretrained_step
.expand(batch_size, max_length) + ) + pos_idx_L = ( + torch.arange(max_length, device=device) + .view(1, max_length) + .expand(batch_size, max_length) + ) + sampling_trace = [[] for _ in range(batch_size)] if return_trace else None + + for i in range(steps): + # use scheduled timesteps + t = time_schedule[i].repeat(batch_size) + t_next = time_schedule[i + 1] + dt = (t - t_next).abs() # timestep difference + + # ——— predict and convert rates ——— + pred_rate = model(xt, t) + pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t) + unmask_rate = pred_rate.unmask_rate # (B, L, V) + len_rate = pred_rate.length_rate # (B, L+1) + + # ——— unmask step (Euler) ——— + mask_pos = (xt == mask).nonzero(as_tuple=True) + unmask_rate[xt != mask] = 0 + unmask_rate[mask_pos + (mask,)] = 0 + unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) + trans_prob = (unmask_rate * dt[:, None, None]).clamp(0.0, 1.0) + + # add "stay" probability + _xt = xt.clone() + _xt[xt == pad] = mask + trans_prob.scatter_add_( + 2, + _xt.unsqueeze(-1), + torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), + ) + + # Apply temperature scaling + if temperature != 1.0: + logits = torch.log(trans_prob + 1e-10) / temperature + trans_prob = torch.softmax(logits, dim=-1) + + if i == steps - 1: + print("Final step, removing mask token from sampling") + trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step + + prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) + mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) + + if mask_has_zero_prob.any(): + uniform_prob = torch.zeros_like(trans_prob[0]) + uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1 + trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob + else: + trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum + + new_xt = _sample_tokens(trans_prob) + new_xt[xt == pad] = pad + new_xt = torch.where((xt != mask) & (xt != pad), 
@torch.no_grad()
def any_order_mask_insertion_euler_sampling_with_rnd(
    model, pretrained, reward_model, analyzer,
    tokenizer, steps,
    mask,
    pad,
    batch_size,
    max_length,
    return_trace = False,
    alpha = 0.1,
    temperature: float = 1.0,
):
    """Sample with the policy and accumulate log-ratio weights plus rewards.

    Runs the unmask/insert Euler sampler with ``model`` while accumulating,
    over all steps, the log-probability ratio between ``pretrained`` and
    ``model`` for the sampled unmasking and insertion events. After sampling,
    sequences are decoded, filtered with ``analyzer.is_peptide`` and scored
    by ``reward_model``.

    Args:
        model: policy network; callable as ``model(xt, t)`` and exposing
            ``interpolant.to_actual_rate`` whose result has ``unmask_rate``
            (expected (B, L, V)) and ``length_rate`` (expected (B, L+1)).
        pretrained: reference network with the same interface.
        reward_model: called as ``reward_model(input_seqs=list_of_str)``; its
            result is summed over the last axis with ``np.sum``.
        analyzer: object with ``is_peptide(seq) -> bool``.
        tokenizer: object with ``batch_decode(token_tensor)``.
        steps: number of Euler steps (step size ``1 / steps``).
        mask: id of the mask token; ids ``0 .. mask-1`` are treated as the
            regular vocabulary by the uniform fallback.
        pad: id of the padding token.
        batch_size: number of sequences to sample.
        max_length: sequence capacity L.
        return_trace: when True, record per-sequence change/insertion events.
        alpha: rewards are divided by ``alpha`` before being added.
        temperature: when not 1.0, the policy transition probabilities are
            re-softmaxed at this temperature (the pretrained ones are not).

    Returns:
        ``(valid_x_final, log_rnd, scalar_rewards, sampling_trace)``. The
        first three cover only the valid sequences (N of them): (N, L) token
        ids, (N,) accumulated log-ratio plus ``reward / alpha``, and (N,)
        summed rewards. When no sequence is valid, empty tensors are returned
        and ``reward_model`` is not called.
    """
    device = next(model.parameters()).device

    # start from an empty (all-pad) batch
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    # running sum of per-step log probability ratios
    log_rnd = torch.zeros(batch_size, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # row / column index grids reused by the insertion scatter
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        is_last = i == steps - 1

        # ——— predict and convert rates ———
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # ——— pretrained model rates for the log_rnd computation ———
        pretrained_pred = pretrained(xt, t)
        pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
        pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()  # (B, L, V)
        pretrained_len_rate = pretrained_rate.length_rate  # (B, L+1)

        # ——— unmask step (Euler) ———
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # same for pretrained
        pretrained_unmask_rate[xt != mask] = 0
        pretrained_unmask_rate[mask_pos + (mask,)] = 0
        pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

        # add "stay" probability (pad positions are routed to the mask column)
        _xt = xt.clone()
        _xt[xt == pad] = mask
        stay_index = _xt.unsqueeze(-1)
        trans_prob.scatter_add_(
            2,
            stay_index,
            torch.ones_like(stay_index, dtype=trans_prob.dtype),
        )
        pretrained_trans_prob.scatter_add_(
            2,
            stay_index,
            torch.ones_like(stay_index, dtype=pretrained_trans_prob.dtype),
        )

        # apply temperature scaling (policy only)
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if is_last:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0  # mask may not survive the last step

            # Renormalise every masked row; rows with zero total mass fall
            # back to a uniform distribution over tokens 0..mask-1. The
            # fallback is built per row along the vocab axis so it works for
            # any number of masked positions.
            rows = trans_prob[mask_pos]  # (K, V)
            row_sum = rows.sum(dim=-1, keepdim=True)  # (K, 1)
            degenerate = row_sum == 0.0
            if bool(degenerate.any()):
                uniform = torch.zeros_like(rows)
                uniform[:, :mask] = 1.0 / mask
                safe_sum = torch.where(degenerate, torch.ones_like(row_sum), row_sum)
                rows = torch.where(degenerate, uniform, rows / safe_sum)
            else:
                rows = rows / row_sum
            trans_prob[mask_pos] = rows

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # ——— log probabilities of the sampled unmasking events ———
        chosen = new_xt.unsqueeze(-1)
        lp = torch.gather(torch.log(trans_prob + 1e-10), 2, chosen).squeeze(-1)
        lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, chosen).squeeze(-1)

        # only positions that went from mask to a real token contribute
        changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

        log_policy_step = (lp * changed_mask).sum(dim=1)
        log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

        # accumulate across steps (a plain assignment here would discard
        # every earlier step's contribution)
        log_rnd = log_rnd + (log_pretrained_step - log_policy_step)  # (B,)

        if not is_last:
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)

            insertion_rate = (len_rate * dt).clamp(min=1e-10)  # (B, L+1)
            pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)  # (B, L+1)

            # log P(ext; λ) = ext*log(λ) - λ
            log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1)  # (B,)
            log_pretrained_insert = (
                ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate
            ).sum(dim=1)  # (B,)

            log_rnd = log_rnd + (log_pretrained_insert - log_policy_insert)

            xt_len = xt.ne(pad).sum(dim=1)  # (B,)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            # only gaps inside (or right after) the current sequence are usable
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # reject all insertions for rows that would exceed max_length
            fits = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * fits.view(batch_size, 1).long()

            ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
            # use the post-rejection count so rejected rows keep their length
            new_len = xt_len + ext.sum(dim=1)  # (B,)

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # shift each existing token right by the insertions before it
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]  # (B, L)
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
        else:
            # no insertion on the last step; keep the trace from reading a
            # stale (or undefined) `ext`
            ext = torch.zeros_like(len_rate, dtype=torch.long)
            xt_tmp = new_xt

        if return_trace:
            for b in range(batch_size):
                # tokens that changed during this step
                for j in range(max_length):
                    if xt[b, j] != pad and xt[b, j] != new_xt[b, j]:
                        sampling_trace[b].append(
                            SamplingTraceDatapoint(
                                t=t[b].item(),
                                event_type="change",
                                position=j,
                                token=new_xt[b, j].item(),
                            )
                        )

                # insertions, reported from the rightmost gap to the leftmost
                for j in range(max_length):
                    gap = max_length - j - 1
                    if ext[b, gap]:
                        sampling_trace[b].append(
                            SamplingTraceDatapoint(
                                t=t[b].item(),
                                event_type="insertion",
                                position=gap,
                                token=mask,
                            )
                        )

        xt = xt_tmp
        t = t + dt

    # ——— decode, keep valid peptides, score them ———
    decoded_samples = tokenizer.batch_decode(xt)
    keep = [idx for idx, seq in enumerate(decoded_samples) if analyzer.is_peptide(seq)]
    validSequences = [decoded_samples[idx] for idx in keep]

    print("len valid sequences:", len(validSequences))

    if not keep:
        # nothing to score: return empty results instead of failing in stack()
        empty = torch.zeros(0, dtype=torch.float32, device=device)
        return xt[:0], empty, empty.clone(), sampling_trace

    # compute multi-objective rewards and collapse them to one scalar each
    score_vectors = reward_model(input_seqs=validSequences)
    scalar_rewards = np.sum(score_vectors, axis=-1)
    scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device)

    print(f"scalar reward dim{len(scalar_rewards)}")

    keep_idx = torch.as_tensor(keep, dtype=torch.long, device=device)
    log_rnd = log_rnd[keep_idx] + (scalar_rewards / alpha)  # reward scaled by 1 / alpha
    valid_x_final = xt[keep_idx]

    return valid_x_final, log_rnd, scalar_rewards, sampling_trace
trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) + + # add “stay” probability + _xt = xt.clone() + _xt[xt == pad] = mask + trans_prob.scatter_add_( + 2, + _xt.unsqueeze(-1), + torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), + ) + + # Apply temperature scaling + if temperature != 1.0: + logits = torch.log(trans_prob + 1e-10) / temperature + trans_prob = torch.softmax(logits, dim=-1) + + if i == steps - 1: + print("Final step, removing mask token from sampling") + trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step + + # renormalize probabilities to ensure they sum to 1 + prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) + # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad) + mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) + if mask_has_zero_prob.any(): + # create uniform distribution over valid tokens (excluding mask and pad) + uniform_prob = torch.zeros_like(trans_prob[0]) + uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1 + trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob + else: + # normalize to sum to 1 + trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum + + new_xt = _sample_tokens(trans_prob) + new_xt[xt == pad] = pad + new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) + + if i != steps - 1: + # gap-wise insertion refactored — compute new length, fill masks, scatter tokens + ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1) + xt_len = xt.ne(pad).sum(dim=1) # (B,) + gaps = torch.arange(max_length + 1, device=device).view(1, -1) + ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() + total_ext = ext.sum(dim=1) + valid = xt_len + total_ext <= max_length + ext = ext * valid.view(batch_size, 1).long() + + ext_ex = ext.int().cumsum(dim=1) # (B, L+1) + new_len = xt_len + total_ext # (B,) + + xt_tmp = torch.full_like(xt, pad) + mask_fill = pos_idx_L < 
new_len.view(batch_size, 1) + xt_tmp[mask_fill] = mask + + new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) + orig_mask = pos_idx_L < xt_len.view(batch_size, 1) + flat_b = batch_idx_L[orig_mask] + flat_p = new_pos_orig[orig_mask] + xt_tmp[flat_b, flat_p] = new_xt[orig_mask] + else: + xt_tmp = new_xt + + if return_trace: + # check if the token was changed + for batch_idx in range(batch_size): + for j in range(max_length): + if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]: + sampling_trace[batch_idx].append( + SamplingTraceDatapoint( + t=t[batch_idx].item(), + event_type="change", + position=j, + token=new_xt[batch_idx, j].item(), + ) + ) + + # check if a new token was inserted + for j in range(max_length): + id = max_length - j - 1 + if ext[batch_idx, id]: + sampling_trace[batch_idx].append( + SamplingTraceDatapoint( + t=t[batch_idx].item(), + event_type="insertion", + position=id, + token=mask, + ) + ) + + xt = xt_tmp + t = t + dt + + # start eval + samples = xt.to(device) + + decoded_samples = tokenizer.batch_decode(samples) + + valid_x_final = [] + validSequences = [] + + for idx, seq in enumerate(decoded_samples): + if analyzer.is_peptide(seq): + valid_x_final.append(samples[idx]) + validSequences.append(seq) + + print("len valid sequences:", len(validSequences)) + valid_fraction = len(validSequences) / batch_size + + if (len(validSequences) != 0): + # add scores to log + score_vectors = reward_model(input_seqs=validSequences) # (num_children, num_objectives) + average_scores = score_vectors.T + + affinity = average_scores[0] + sol = average_scores[1] + hemo = average_scores[2] + nf = average_scores[3] + permeability = average_scores[4] + + else: + zeros = [0.0] + + affinity = zeros + sol = zeros + hemo = zeros + nf = zeros + permeability = zeros + + if dataframe: + df = pd.DataFrame({ + "Peptide Sequence": validSequences, + "Binding Affinity": affinity if len(validSequences) else [0.0], + "Solubility": sol if 
@torch.no_grad()
def mdm_tau_leaping_sampling(
    model: "MaskedDiffusionModule",
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
):
    """Tau-leaping sampler for a fixed-length masked diffusion model.

    Starts from an all-``mask`` batch. On every step but the last, Poisson
    event counts are drawn from ``unmask_rate * dt`` and a masked position is
    unmasked only when exactly one event fired there. On the last step every
    remaining masked position takes the highest-rate non-mask token.

    Args:
        model: callable as ``model(xt, t)``; must expose
            ``model.interpolant.to_actual_rate(xt, pred, t)`` whose result has
            ``unmask_rate`` (expected (B, L, V)).
        steps: number of steps (step size ``1 / steps``).
        mask: id of the mask token.
        pad: accepted for signature parity with the other samplers; unused.
        batch_size: number of sequences to sample.
        max_length: sequence length L.
        return_trace: must be False; traces are not supported here.
        temperature: accepted for signature parity; not applied here.

    Returns:
        ``(xt, [])`` where ``xt`` is the (B, L) token tensor.
    """
    assert not return_trace, "Trace is not yet supported"
    device = next(model.parameters()).device
    xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    for i in range(steps):
        # ——— predict and convert rates ———
        pred = model(xt, t)
        pred = model.interpolant.to_actual_rate(xt, pred, t)
        unmask_rate = pred.unmask_rate  # (B, L, V)
        still_masked = xt == mask  # (B, L)

        if i == steps - 1:
            # last step: deterministic unmask via argmax, with the mask
            # column excluded so no mask token can remain in the output
            scores = unmask_rate.clone()
            scores[..., mask] = float("-inf")
            new_token = scores.argmax(dim=2)  # (B, L)
            update = still_masked
        else:
            # tau-leaping via Poisson counts; torch.poisson needs
            # non-negative rates, hence the clamp
            counts = torch.poisson((unmask_rate * dt).clamp(min=0.0)).long()
            # zero out non-mask positions and mask→mask
            counts[~still_masked.unsqueeze(-1).expand_as(counts)] = 0
            counts[..., mask] = 0
            # only accept exactly one event
            one_event = counts.sum(dim=2) == 1  # (B, L)
            new_token = counts.argmax(dim=2)  # (B, L)
            update = still_masked & one_event

        # already-unmasked tokens are never overwritten
        xt = torch.where(update, new_token, xt)
        t = t + dt

    return xt, []
result[div_mask] = nom[div_mask] / (denom[div_mask]) + + return result + +# Keep the original function for backward compatibility +def calculate_rate(alpha_t, len_t): + """Legacy scalar version of calculate_rate""" + if isinstance(alpha_t, torch.Tensor) and alpha_t.ndim > 0: + return calculate_rate_batch(alpha_t, len_t) + + nom, denom = 0, 0 + for length, probability in lengths.items(): + if length >= len_t: + nom += (length - len_t) * probability * binomial_mass(len_t, length, alpha_t) + denom += probability * binomial_mass(len_t, length, alpha_t) + + if denom == 0: + return 0.0 + + return nom /denom + + +@torch.no_grad() +def any_order_mask_insertion_tau_leaping_sampling( + model: torch.nn.Module, + steps: int, + mask: int, + pad: int, + batch_size: int, + max_length: int, + return_trace: bool = False, + confidence_based_sampling: bool = True, # whether to use confidence-based decoding + alpha: float = 5.0, # hyperparameter for window size calculation + max_window: int = 32, # Maximum window size for sliding window + confidence_method: str = "prob_diff", # "position", "top_prob", "prob_diff", "entropy" + use_sliding_window: bool = False, # whether to use sliding window for position selection + temperature: float = 1.0, +) -> SamplingResult: + + device = next(model.parameters()).device + xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device) + sampling_trace = [] + dt = 1.0 / steps + t = torch.zeros(batch_size, device=device) + + # Precompute row indices for scatter + batch_idx_L = ( + torch.arange(batch_size, device=device) + .view(batch_size, 1) + .expand(batch_size, max_length) + ) + pos_idx_L = ( + torch.arange(max_length, device=device) + .view(1, max_length) + .expand(batch_size, max_length) + ) + + for i in range(steps): + # --- predict rates --- + pred = model(xt, t) + xt_len = (xt != pad).sum(dim=1) + pred = model.interpolant.to_actual_rate(xt, pred, t) + unmask_rate = pred.unmask_rate # (B, L, V) + len_rate = pred.length_rate # 
(B, L+1) + + if i == steps - 1: + # last step: deterministic unmask via argmax + mask_pos = xt == mask + new_token = unmask_rate.argmax(dim=2) + new_xt = xt.clone() + new_xt[mask_pos] = new_token[mask_pos] + new_xt = torch.where(xt == pad, pad, new_xt) + new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) + xt = new_xt + t = t + dt + continue + + # --- confidence-based decoding --- + if confidence_based_sampling > 0.0: + # Confidence-based unmasking (vectorized) + mask_positions = (xt == mask) # (B, L) + num_mask_positions = mask_positions.sum(dim=1) # (B,) + + # 1. Determine number of tokens to unmask using Poisson + unmask_counts = torch.poisson(num_mask_positions.float() * dt).long() # (B,) + + # 2. Calculate confidence based on selected method + if confidence_method == "position": + # Position-based confidence: position i / len(xt) + xt_len = (xt != pad).sum(dim=1) # (B,) - current sequence lengths + position_indices = torch.arange(max_length, device=device).unsqueeze(0).expand(batch_size, -1) # (B, L) + confidence = 1.0 - (position_indices.float() / xt_len.unsqueeze(1).float().clamp(min=1)) # (B, L) + + elif confidence_method == "top_prob": + # Top probability confidence + import torch.nn.functional as F + token_logits = unmask_rate # (B, L, V) - use the unmask_rate as logits + unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V) + confidence = unmask_probs.max(dim=-1)[0] # (B, L) + + elif confidence_method == "prob_diff": + # Probability difference confidence (top - second top) + import torch.nn.functional as F + token_logits = unmask_rate # (B, L, V) + unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V) + top2_probs, _ = torch.topk(unmask_probs, k=2, dim=-1) # (B, L, 2) + confidence = top2_probs[:, :, 0] - top2_probs[:, :, 1] # (B, L) + + elif confidence_method == "entropy": + # Entropy-based confidence (lower entropy = higher confidence) + import torch.nn.functional as F + token_logits = unmask_rate # (B, L, V) + unmask_probs = 
F.softmax(token_logits, dim=-1) # (B, L, V) + entropy = -torch.sum(unmask_probs * torch.log(unmask_probs + 1e-10), dim=-1) # (B, L) + confidence = -entropy # (B, L) - negative entropy so lower entropy gives higher confidence + + else: + raise ValueError(f"Unknown confidence_method: {confidence_method}") + + # 3. Apply window constraint if enabled + if use_sliding_window: + # Calculate dynamic k for each batch + k_values = torch.minimum( + torch.minimum( + (alpha * unmask_counts).long(), + torch.tensor(max_window, device=device) + ), num_mask_positions) # (B,) + + # Get cumulative count of mask positions + mask_cumsum = mask_positions.cumsum(dim=1) # (B, L) + + # Create window mask: position is eligible if it's a mask and within first k masks + is_within_window = mask_cumsum <= k_values.unsqueeze(1) # (B, L) + window_mask = mask_positions & is_within_window # (B, L) + + # Set confidence to -inf for positions outside the window or non-mask positions + confidence = torch.where(window_mask, confidence, torch.tensor(-float('inf'), device=device)) + else: + # No window constraint - only mask positions are eligible + confidence = torch.where(mask_positions, confidence, torch.tensor(-float('inf'), device=device)) + + new_xt = xt.clone() + + # vectorized unmasking + max_unmask = unmask_counts.max().item() + if max_unmask > 0: + _, all_top_indices = torch.topk(confidence, k=max_unmask, dim=1, largest=True) # (B, max_unmask) + + # create mask for valid unmask operations + unmask_mask = torch.arange(max_unmask, device=device).unsqueeze(0) < unmask_counts.unsqueeze(1) # (B, max_unmask) + + most_likely_tokens = unmask_rate.argmax(dim=-1) # (B, L) + + selected_positions = all_top_indices[unmask_mask] + batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, max_unmask)[unmask_mask] + + new_xt[batch_indices, selected_positions] = most_likely_tokens[batch_indices, selected_positions] + else: + # --- tau-leaping unmask via Poisson --- + counts = 
torch.poisson(unmask_rate * dt).long() + mask_pos = xt == mask + counts[~mask_pos.unsqueeze(-1).expand_as(counts)] = 0 + counts[..., mask] = 0 + sum_c = counts.sum(dim=2) + one_event = sum_c == 1 + new_token = counts.argmax(dim=2) + new_xt = xt.clone() + new_xt[one_event] = new_token[one_event] + new_xt = torch.where(xt == pad, pad, new_xt) + new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) + + # insertion only on non-last + if i != steps - 1: + # --- Poisson insertion, compute new lengths and fill masks --- + ext = torch.poisson(len_rate * dt).long() # (B, L+1) + xt_len = xt.ne(pad).sum(dim=1) # (B,) + gaps = torch.arange(max_length + 1, device=device).view(1, -1) + ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() + total_ext = ext.sum(dim=1) + valid = xt_len + total_ext <= max_length + ext = ext * valid.view(batch_size, 1).long() + + # compute prefix sums of insertions + ext_ex = ext.int().cumsum(dim=1) # (B, L+1) + new_len = xt_len + total_ext # (B,) + + # initialize with pads, then fill mask up to new_len + xt_tmp = torch.full_like(xt, pad) + mask_pos = pos_idx_L < new_len.view(batch_size, 1) + xt_tmp[mask_pos] = mask + + # shift and scatter original tokens + new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) + orig_mask = pos_idx_L < xt_len.view(batch_size, 1) + flat_b = batch_idx_L[orig_mask] + flat_p = new_pos_orig[orig_mask] + xt_tmp[flat_b, flat_p] = new_xt[orig_mask] + else: + xt_tmp = new_xt + + xt = xt_tmp + t = t + dt + if return_trace: + sampling_trace.append(xt) + + return xt, sampling_trace diff --git a/a2d2_pep/scripts/run_peptide_finetune.slurm b/a2d2_pep/scripts/run_peptide_finetune.slurm new file mode 100644 index 0000000000000000000000000000000000000000..2a42d480dd9ec3a4dc7d3a7642af0996dfdd6088 --- /dev/null +++ b/a2d2_pep/scripts/run_peptide_finetune.slurm @@ -0,0 +1,210 @@ +#!/bin/bash +# NOTE: --partition and --qos below are specific to our cluster. 
Change them +# (or remove them and pass `--partition` on the `sbatch` command line) to match +# the partitions/QOS available on yours. +#SBATCH --job-name=peptide-finetune-len256 +#SBATCH --partition=b200-mig90 +#SBATCH --qos=mig +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=8 +#SBATCH --ntasks-per-node=1 +#SBATCH --mem=80GB +#SBATCH --time=02-00:00:00 +#SBATCH --output=logs/peptide_finetune_%A.log + +# ===================================================================== +# run_peptide_finetune.slurm +# +# Single-mode job (1 MIG GPU) running ONE finetune_quality (peptide) +# experiment. Select which mode to run via the MODE_ID variable below +# (or override at submit time with `sbatch --export=ALL,MODE_ID=2 ...`): +# 0) A2D2 (Ours) – with full planner (alternating) +# 1) A2D2 w/o quality – --disable_planner +# 2) A2D2 w/o insertion planner – --disable_insertion_planner +# 3) A2D2 w/o unmasking planner – --disable_unmasking_planner +# +# The job trains the selected mode then evaluates the resulting +# checkpoint on the same GPU. +# ===================================================================== + +set -e + +# --- Mode selection --------------------------------------------------- +# Which experiment to run (0-3). Override with `--export=ALL,MODE_ID=N`. +MODE_ID="${MODE_ID:-0}" + +# Run prefix: YYYYMMDD + SLURM job ID +DATE_STAMP=$(date +%Y%m%d) +PREFIX="${DATE_STAMP}_job${SLURM_JOB_ID:-local$(date +%H%M%S)}" + +# Default protein target (must be defined before path definitions below) +PROT_NAME=tfr + +# --- Paths ------------------------------------------------------------ +# Repo root is resolved at submit time so the script works from any clone: +# - set A2D2_ROOT explicitly, OR +# - run `sbatch` from the repo root (SLURM sets SLURM_SUBMIT_DIR), OR +# - fall back to this script's location (a2d2_pep/scripts/ -> two levels up). 
+if [ -n "${A2D2_ROOT:-}" ]; then + HOME_LOC="$A2D2_ROOT" +elif [ -n "${SLURM_SUBMIT_DIR:-}" ]; then + HOME_LOC="$SLURM_SUBMIT_DIR" +else + HOME_LOC="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +fi +SCRIPT_LOC="$HOME_LOC/a2d2_pep" +LOG_LOC="$HOME_LOC/logs" +SAVE_DIR="$HOME_LOC/checkpoints/finetune_test_peptides_${PROT_NAME}" +RESULTS_DIR="$HOME_LOC/results/peptide_test_ablation_${PROT_NAME}" + +cd "$SCRIPT_LOC" + +# BASE_PATH is passed as --base_path to finetune_quality.py: it's used +# to build the plot output path at $BASE_PATH/flexible/results/ +# (see finetune_quality.py:421). The pretrained checkpoint is now passed +# explicitly via --checkpoint_path below, so base_path no longer needs +# to follow the legacy /scratch layout. +BASE_PATH="${A2D2_BASE_PATH:-$HOME_LOC}" + +mkdir -p "$LOG_LOC" "$SAVE_DIR" "$RESULTS_DIR" + +# --- Environment setup ------------------------------------------------ +# Do NOT hardcode your W&B key. Either `wandb login` once on the cluster, +# export WANDB_API_KEY in your shell/SLURM environment before submitting, +# or set WANDB_MODE=offline to skip logging entirely. +export WANDB_DIR=$HOME_LOC/.wandb +export WANDB_CONFIG_DIR=$HOME_LOC/.config/wandb +export WANDB_CACHE_DIR=$HOME_LOC/.cache/wandb +# Stop wandb from hijacking stdout/stderr (its default fd-redirect mode sends +# all output to wandb/run-*/files/output.log and freezes the RUN_LOG below). +# With console off, everything flows to the `>> "$RUN_LOG" 2>&1` redirect. +export WANDB_CONSOLE=off +mkdir -p "$WANDB_DIR" "$WANDB_CONFIG_DIR" "$WANDB_CACHE_DIR" + +export TRITON_CACHE_DIR=$HOME_LOC/.triton/cache +mkdir -p "$TRITON_CACHE_DIR" + +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# Activate conda env. Override CONDA_ROOT to point at your conda/miniconda +# install, or just have `conda` on PATH; override CONDA_ENV if your env name +# differs from the one created by environment.yml. 
+CONDA_ENV="${CONDA_ENV:-a2d2}" +if [ -n "${CONDA_ROOT:-}" ]; then + source "$CONDA_ROOT/bin/activate" "$CONDA_ENV" +elif command -v conda >/dev/null 2>&1; then + source "$(conda info --base)/bin/activate" "$CONDA_ENV" +else + echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." >&2 + exit 1 +fi +PYTHON_EXECUTABLE=$(which python) + +# Pretrained base checkpoint +PRETRAINED_CKPT="$HOME_LOC/pretrained/anylength_pep.ckpt" + +# --- Shared training hyperparameters ---------------------------------- +COMMON_ARGS=( + --base_path "$BASE_PATH" + --checkpoint_path "$PRETRAINED_CKPT" + --prot_name "$PROT_NAME" + --noise_removal + --wdce_num_replicates 8 + --pool_size 100 + --pool_refresh_fraction 1.0 + --buffer_size 50 + --batch_size 200 + --total_num_steps 256 + --num_iter 20 + --resample_every_n_step 10 + --num_epochs 1000 + --save_every_n_epochs 50 + --reset_every_n_step 1 + --alpha 0.1 + --no_mcts + --schedule_warmup_epochs 20 + --alternation_frequency 5 + --num_remasking 3 + --quality_threshold 0.2 + --training_mini_batch_size 10 + --max_length 256 + --eval_every_n_epochs 50 + --min_peptide_bonds 4 + --grad_clip + --seed 42 +) + +# --- Shared evaluation hyperparameters -------------------------------- +EVAL_COMMON_ARGS=( + --pretrained_ckpt "$PRETRAINED_CKPT" + --num_samples 50 + --batch_size 200 + --max_length 256 + --total_num_steps 256 + --num_remasking 3 + --quality_threshold 0.2 + --prot_name "$PROT_NAME" + --seed 42 +) + +# ===================================================================== +# Pick experiment from $MODE_ID +# ===================================================================== +case "$MODE_ID" in + 0) MODE="with_planner"; EXTRA_ARGS=() ;; + 1) MODE="no_planner"; EXTRA_ARGS=(--disable_planner) ;; + 2) MODE="no_insertion_planner"; EXTRA_ARGS=(--disable_insertion_planner) ;; + 3) MODE="no_unmasking_planner"; EXTRA_ARGS=(--disable_unmasking_planner) ;; + *) echo "Unknown MODE_ID=$MODE_ID (expected 0-3)"; exit 1 ;; +esac + 
+RUN_NAME="${PREFIX}_peptide_${PROT_NAME}_${MODE}" +RUN_LOG="$LOG_LOC/${RUN_NAME}.log" +RUN_SAVE_DIR="$SAVE_DIR/${RUN_NAME}" +RESULTS_SUBDIR="$RESULTS_DIR/${MODE}" +mkdir -p "$RUN_SAVE_DIR" "$RESULTS_SUBDIR" + +echo "=== Peptide finetune (MODE_ID=$MODE_ID) ===" +echo "Job: ${SLURM_JOB_ID} Node: $SLURM_NODELIST" +echo "Mode: $MODE" +echo "Save dir: $RUN_SAVE_DIR" +echo "Results dir: $RESULTS_SUBDIR" +echo "Python: $PYTHON_EXECUTABLE" +echo "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES:-(unset)}" + +# ===================================================================== +# Train +# ===================================================================== +$PYTHON_EXECUTABLE $SCRIPT_LOC/finetune_quality.py \ + "${COMMON_ARGS[@]}" \ + --devices 1 \ + "${EXTRA_ARGS[@]}" \ + --save_path_dir "$RUN_SAVE_DIR" \ + >> "$RUN_LOG" 2>&1 + +echo "Training finished for $MODE. Log: $RUN_LOG" + +# ===================================================================== +# Evaluate +# ===================================================================== +# finetune_quality.py saves to $RUN_SAVE_DIR//last.ckpt, +# so glob the run_name subdir. +RUN_CKPT=$(ls -t "$RUN_SAVE_DIR"/*/last.ckpt 2>/dev/null | head -1) +if [ -z "$RUN_CKPT" ]; then + echo "No checkpoint found in $RUN_SAVE_DIR — skipping eval." + exit 1 +fi + +echo "Evaluating checkpoint: $RUN_CKPT" +$PYTHON_EXECUTABLE $SCRIPT_LOC/evaluate_peptide_table.py \ + --checkpoint_path "$RUN_CKPT" \ + "${EVAL_COMMON_ARGS[@]}" \ + "${EXTRA_ARGS[@]}" \ + --output_dir "$RESULTS_SUBDIR" \ + --device cuda:0 \ + >> "$RESULTS_SUBDIR/${RUN_NAME}_eval.log" 2>&1 + +echo "Eval finished for $MODE. 
CSV: $RESULTS_SUBDIR/eval_metrics_${MODE}_${PROT_NAME}.csv" + +conda deactivate diff --git a/a2d2_pep/scripts/train_pep.sh b/a2d2_pep/scripts/train_pep.sh new file mode 100755 index 0000000000000000000000000000000000000000..5888c59889c44134c85585edc8ed124d6437053e --- /dev/null +++ b/a2d2_pep/scripts/train_pep.sh @@ -0,0 +1,93 @@ +#!/bin/bash +#SBATCH --job-name=a2d2-pep-pretrain +#SBATCH --partition=dgx-b200 +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=512GB +#SBATCH --time=7-00:00:00 +# SLURM's own catch-file (anything printed before the exec redirect below, plus +# slurm-infra messages). Relative to the submit dir, so submit this script from +# the a2d2_pep/ directory; the real run output is redirected via exec below. +#SBATCH --output=logs/slurm/%x_%j.out +#SBATCH --error=logs/slurm/%x_%j.err +# +# Pretrain the any-length insertion MDM on ~11M peptide SMILES on a dgx-b200 node. +# Submit with: sbatch scripts/train_pep.sh (from the a2d2_pep/ directory). +# +# DDP is launched by SLURM: one srun task per GPU. --gpus-per-node and +# --ntasks-per-node must match; change both together (and they override the +# training.devices value baked into config_pep.yaml via the hydra override below). + +DATE=$(date +%Y%m%d) +SPECIAL_PREFIX='a2d2-peptide' + +# Resolve a2d2_pep/ (which holds train.py + config_pep.yaml) so paths are +# repo-relative. This script lives in a2d2_pep/scripts/, so the direct-run +# fallback goes one level up. Under sbatch, BASH_SOURCE points at the spooled +# copy, so we rely on SLURM_SUBMIT_DIR (submit from the a2d2_pep/ directory). +if [ -n "${SLURM_SUBMIT_DIR:-}" ]; then + SCRIPT_DIR="$SLURM_SUBMIT_DIR" +else + SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +fi +cd "$SCRIPT_DIR" + +# Auto-detect GPUs from the SLURM allocation (falls back to 4 for `bash` runs). 
+DEVICES=${SLURM_GPUS_ON_NODE:-${SLURM_GPUS_PER_NODE:-4}} +NTASKS=${SLURM_NTASKS_PER_NODE:-$DEVICES} +NODES=${SLURM_NNODES:-1} + +LOG_LOC="$SCRIPT_DIR/logs" +mkdir -p "$LOG_LOC/slurm" +exec > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_${SLURM_JOB_ID:-local}.log" 2>&1 + +# --------------------------------------------------------------------------- +# Weights & Biases: log in once on your machine before running this script with +# `wandb login` (or `export WANDB_API_KEY=`). +# Do NOT hardcode your API key here. To disable W&B entirely, uncomment: +# export WANDB_MODE=disabled +# --------------------------------------------------------------------------- + +export PYTORCH_ALLOC_CONF=expandable_segments:True + +# Activate the conda env that has the deps (torch / pytorch_lightning / hydra). +# The batch shell does NOT source ~/.bashrc, so conda is not on PATH. Override +# CONDA_ROOT to point at your conda/miniconda install, or just have `conda` on +# PATH; override CONDA_ENV if your env name differs from the one created by +# environment.yml. +CONDA_ENV="${CONDA_ENV:-a2d2}" +if [ -n "${CONDA_ROOT:-}" ]; then + source "$CONDA_ROOT/bin/activate" "$CONDA_ENV" +elif command -v conda >/dev/null 2>&1; then + source "$(conda info --base)/bin/activate" "$CONDA_ENV" +else + echo "ERROR: conda not found; set CONDA_ROOT to your miniconda install." 
>&2 + exit 1 +fi + +# --- Distributed / NCCL setup (single node, intra-node NVLink) -------------- +ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -E "ens|eth|enp|bond" | head -1 | awk '{print $2}') +if [ -z "$ETH_IFACE" ]; then + ETH_IFACE=$(ip -o -4 addr list | grep -v "127.0.0.1" | grep -v "ibp" | head -1 | awk '{print $2}') +fi +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_FAMILY=AF_INET +export NCCL_SOCKET_IFNAME=$ETH_IFACE +export NCCL_P2P_LEVEL=NVL + +export MASTER_ADDR=$(scontrol show hostnames "${SLURM_NODELIST:-$(hostname)}" | head -n 1) +export MASTER_PORT=$(shuf -i 15000-59999 -n 1) +export NODE_RANK=${SLURM_NODEID:-0} + +echo "=== a2d2 peptide pretraining (dgx-b200) ===" +echo "Job ID: ${SLURM_JOB_ID:-local} Node: ${SLURM_NODELIST:-$(hostname)} GPUs: $DEVICES Tasks: $NTASKS" + +# --task pep makes train.py load config_pep.yaml; the hydra overrides pin +# devices/nodes to the SLURM allocation so the two never drift apart. +srun --ntasks-per-node=$NTASKS python train.py --task pep \ + training.devices=$DEVICES \ + training.nodes=$NODES + +conda deactivate diff --git a/a2d2_pep/train.py b/a2d2_pep/train.py new file mode 100755 index 0000000000000000000000000000000000000000..9823bf0fd4ced62852356a7b6ce5f83d1156301d --- /dev/null +++ b/a2d2_pep/train.py @@ -0,0 +1,216 @@ +import torch +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.callbacks import ModelCheckpoint +import os +import sys +import argparse +import hydra +from omegaconf import OmegaConf +from datetime import datetime +# Directory containing this file and the config_*.yaml files (used by Hydra below). +CONFIG_DIR = os.path.dirname(os.path.abspath(__file__)) +# Add the repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve. 
+sys.path.insert(0, os.path.dirname(CONFIG_DIR)) + +import wandb +from lightning_modules import AnyOrderInsertionFlowModule + + +torch.set_printoptions(threshold=10_000) +torch.set_float32_matmul_precision("high") + +# Disable DDP optimizer due to incompatibility with flex_attention higher-order ops +torch._dynamo.config.optimize_ddp = False + +def train(config): + wandb_logger = None + + # set the random seed + pl.seed_everything(42) + torch.manual_seed(42) + + # Only initialize wandb on rank 0 to avoid multiple runs + if int(os.environ.get("LOCAL_RANK", 0)) == 0: + wandb.init( + project=config.wandb.project, + name=config.wandb.name, + config=OmegaConf.to_container(config, resolve=True), # Convert to dict + dir=config.wandb.path + ) + wandb_logger = WandbLogger( + project=wandb.run.project, + name=wandb.run.name, + log_model=False, # Disable checkpoint uploading to save disk space + ) + + # Modify config to add timestamp to checkpoint directory + OmegaConf.set_struct(config, False) + time_string = datetime.now().strftime("%Y%m%d-%H%M%S") + config.training.checkpoint_dir = os.path.join( + config.training.checkpoint_dir, time_string + ) + OmegaConf.set_struct(config, True) + + # Create checkpoint directory + os.makedirs(config.training.checkpoint_dir, exist_ok=True) + + # Setup data module - check if using HuggingFace dataset + if hasattr(config, 'hf_dataset'): + # Imported lazily: the HF/SAFE path is only used by the molecule configs, + # which keep mol_dataset.py (and its `safe` dependency) in a2d2_mol/. + from mol_dataset import setup_hf_data_and_update_config + print(f"Using HuggingFace dataset: {config.hf_dataset.name}") + data_module = setup_hf_data_and_update_config( + config, + dataset_name=config.hf_dataset.name, + smiles_column=config.hf_dataset.get('smiles_column', 'smiles') + ) + else: + # Imported lazily: the local (arrow) path is used by the peptide config, + # which keeps dataloading_for_dynamic_batching.py in a2d2_pep/. 
+ from data.dataloading_for_dynamic_batching import setup_data_and_update_config + print("Using local dataset") + data_module = setup_data_and_update_config(config) + + module = AnyOrderInsertionFlowModule(config) + + # Initialize trainer + + # Configure trainer arguments + # Map torch_dtype to Lightning precision + dtype_str = config.model.get('torch_dtype', 'bfloat16') + precision_map = { + 'float32': '32-true', + 'float16': '16-mixed', + 'bfloat16': 'bf16-mixed' + } + precision = precision_map.get(dtype_str, 'bf16-mixed') + + trainer_kwargs = dict( + num_nodes=config.training.nodes, + accelerator="gpu", + devices=config.training.devices, + strategy="ddp", + precision=precision, + accumulate_grad_batches=( + config.training.batch_size + // ( + config.training.per_gpu_batch_size + * config.training.nodes + * config.training.devices + ) + ), + log_every_n_steps=10, + enable_checkpointing=True, + default_root_dir=config.training.checkpoint_dir, + gradient_clip_val=1.0, + ) + # Only one of max_steps or max_epochs will be used + if config.training.max_steps is not None: + trainer_kwargs["max_steps"] = config.training.max_steps + elif config.training.num_epochs is not None: + trainer_kwargs["max_epochs"] = config.training.num_epochs + config.training.max_steps = config.training.max_steps + else: + raise ValueError( + "Either max_steps or num_epochs must be specified in the config" + ) + + if config.training.warmup_steps is None: + config.training.warmup_steps = int(config.training.max_steps * 0.01) + + # Add ModelCheckpoint callback to save the checkpoint when validation loss is at a new low + checkpoint_callback = ModelCheckpoint( + monitor="train/total_loss", + mode="min", + save_top_k=config.training.save_top_k, + save_last=True, + filename="epoch-{epoch:02d}-train_loss-{train/total_loss:.4f}", + dirpath=config.training.checkpoint_dir, + # Don't use val_loss in filename for periodic saves - causes failures when val doesn't run + auto_insert_metric_name=False + ) + + 
# Add separate callback for periodic saves (no val_loss dependency). Use + # step-based saves for streaming datasets (save_every_n_steps) and epoch-based + # saves otherwise (save_every_n_epochs); whichever the config provides. + save_every_n_steps = config.training.get('save_every_n_steps', None) + save_every_n_epochs = config.training.get('save_every_n_epochs', None) + if save_every_n_steps is not None: + periodic_checkpoint_callback = ModelCheckpoint( + save_top_k=-1, # Save all periodic checkpoints + filename="step-{step:08d}", + dirpath=config.training.checkpoint_dir, + every_n_train_steps=save_every_n_steps, + auto_insert_metric_name=False + ) + elif save_every_n_epochs is not None: + periodic_checkpoint_callback = ModelCheckpoint( + save_top_k=-1, # Save all periodic checkpoints + filename="epoch-{epoch:02d}", + dirpath=config.training.checkpoint_dir, + every_n_epochs=save_every_n_epochs, + auto_insert_metric_name=False + ) + else: + raise ValueError( + "Either save_every_n_steps or save_every_n_epochs must be specified in the config" + ) + + trainer_kwargs["callbacks"] = [checkpoint_callback, periodic_checkpoint_callback] + + if wandb_logger is not None: + trainer_kwargs["logger"] = wandb_logger + + trainer = pl.Trainer(**trainer_kwargs) + + # Train the model + ckpt_path = None + if "resume_path" in config.training: + ckpt_path = config.training.resume_path + + trainer.fit(module, + datamodule=data_module, + ckpt_path=ckpt_path) + + # Only finish wandb on rank 0 + if int(os.environ.get("LOCAL_RANK", 0)) == 0: + wandb.finish() + + +if __name__ == '__main__': + # Parse arguments to get config name + parser = argparse.ArgumentParser() + parser.add_argument('--config_name', type=str, default='config', + help='Name of the config file to use') + parser.add_argument('--task', type=str, default=None, + help='Task name (uses config_{task}.yaml)') + + # Parse known args (hydra will handle the rest) + args, unknown = parser.parse_known_args() + + # Determine config 
name from task or config_name + if args.task: + config_name = f'config_{args.task}' + else: + config_name = args.config_name + + print(f"Using config: {config_name}.yaml") + + # Add config name to Hydra overrides (this persists across DDP subprocesses) + if '--config-name' not in unknown and f'--config-name={config_name}' not in unknown: + unknown.insert(0, f'--config-name={config_name}') + + # Reconstruct sys.argv for hydra + sys.argv = [sys.argv[0]] + unknown + + # Define main function with default config (will be overridden by command line) + @hydra.main(version_base=None, + config_path=CONFIG_DIR, + config_name='config') + def main(config): + """Main entry point for training""" + train(config) + + main() \ No newline at end of file diff --git a/assets/a2d2.gif b/assets/a2d2.gif new file mode 100644 index 0000000000000000000000000000000000000000..2d30aa8009f5e608aa559974ef07a7a7ed0a98fa --- /dev/null +++ b/assets/a2d2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:178ca7850ca39365492fea70cfc5e4f2e8653ceeda9a13dcd0438af61e1a83bb +size 7826144 diff --git a/demo/quality_inference_demo.ipynb b/demo/quality_inference_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d1b22bdb1fd856d37cd2ec573bdc2c35cee27870 --- /dev/null +++ b/demo/quality_inference_demo.ipynb @@ -0,0 +1,785 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# **A2D2 Unmasking + Insertion Quality Inference on a Toy Example**\n", + "\n", + "This notebook is a self-contained, runnable illustration of the two quantities at the heart of **A2D2: Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding**, and how they are used at *inference* time to decode adaptively:\n", + "\n", + "| Quantity | Definition | Inference use |\n", + "|---|---|---|\n", + "| **Unmasking quality** $\\mu_\\star^\\ell(\\boldsymbol{y})=p(\\boldsymbol{y}^\\ell=\\boldsymbol{x}_1^{s_t[\\ell]}\\mid 
\\boldsymbol{y})=f_\\theta(\\tilde{\\boldsymbol{x}}_t,t)[\\ell,\\boldsymbol{x}_1^\\ell]$ | probability an unmasked token matches the true token given context | **re-mask** low-quality tokens so only mutually high-quality tokens stay unmasked in parallel |\n", + "| **Insertion quality** $\\nu_\\star^\\ell(\\boldsymbol{y})=\\sum_{\\boldsymbol{v}\\in\\mathcal{S}_\\ell}p(\\boldsymbol{y}^\\ell=\\boldsymbol{v}\\mid \\boldsymbol{y})$ | probability an inserted mask decodes to *some* true token belonging in its gap | **drop** low-quality insertions that would otherwise cause length/spacing errors |\n", + "\n", + "In this notebook, we build a tiny, fully-analytic **toy \"language\"** whose true denoising posterior $f_\\theta$ is known in closed form. That lets us (1) compute $\\mu_\\star,\\nu_\\star$ exactly, (2) train the light-weight predictors $\\mu_\\phi,\\nu_\\phi$ with the paper's **UQL/IQL** BCE losses and watch them recover the truth, and (3) drive the **actual** A2D2 inference routines `apply_schedule_aware_remasking` / `apply_schedule_aware_insertion` (from `remasking_scheduleaware.py`) on hand-crafted states." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 0. Setup\n", + "\n", + "The two schedule-aware routines below are from the repo's `remasking_scheduleaware.py`. They are the exact code the A2D2 sampler calls. We only feed them a tiny toy `model` exposing the same interface (`model.planner(seq, t)` → confidences, `model.interpolant.{unmask,insertion}_schedule.at(t)`)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import matplotlib.pyplot as plt\n", + "\n", + "torch.manual_seed(0)\n", + "np.random.seed(0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "vendored A2D2 schedule-aware routines loaded\n" + ] + } + ], + "source": [ + "def _debias_by_token(conf, tokens, clean_index):\n", + " \"\"\"Subtract the per-token-identity mean confidence so a token's score reflects\n", + " how confident it is RELATIVE to others of the same type (see repo docstring).\"\"\"\n", + " if clean_index.sum() == 0:\n", + " return conf\n", + " V = int(tokens.max().item()) + 1\n", + " flat_tok = tokens[clean_index]\n", + " flat_conf = conf[clean_index]\n", + " sums = torch.zeros(V, device=conf.device, dtype=conf.dtype).scatter_add_(0, flat_tok, flat_conf)\n", + " cnts = torch.zeros(V, device=conf.device, dtype=conf.dtype).scatter_add_(\n", + " 0, flat_tok, torch.ones_like(flat_conf))\n", + " mean_v = sums / cnts.clamp(min=1.0)\n", + " baseline = mean_v[tokens]\n", + " return torch.where(clean_index, conf - baseline, conf)\n", + "\n", + "\n", + "def apply_schedule_aware_insertion(model, xt_tmp, new_xt, t, dt, ext, mask, pad,\n", + " max_length, orig_mask, new_pos_orig, quality_threshold=1):\n", + " \"\"\"Remove low-quality insertions (insertion_conf < threshold) while keeping the\n", + " sequence length no shorter than the interpolant schedule allows.\"\"\"\n", + " device = xt_tmp.device\n", + " batch_size, L = xt_tmp.shape\n", + " total_ext = ext.sum(dim=1)\n", + " if total_ext.sum() == 0:\n", + " return xt_tmp\n", + "\n", + " t_next = t + dt\n", + " planner_out = model.planner(xt_tmp, t)\n", + " insertion_conf = planner_out.get(\"insertion_conf\", None)\n", + " if 
insertion_conf is None:\n", + " return xt_tmp\n", + " insertion_conf = insertion_conf.squeeze(-1) # (B, L)\n", + "\n", + " current_length_after = xt_tmp.ne(pad).sum(dim=1).float()\n", + " expected_progress = model.interpolant.insertion_schedule.at(t_next)\n", + " estimated_final_length = current_length_after / (expected_progress.clamp(min=0.1))\n", + " expected_length = estimated_final_length * expected_progress\n", + "\n", + " valid_b, valid_l = orig_mask.nonzero(as_tuple=True)\n", + " valid_p = new_pos_orig[valid_b, valid_l].long().clamp_(0, L - 1)\n", + " is_original = torch.zeros_like(xt_tmp, dtype=torch.bool)\n", + " is_original[valid_b, valid_p] = True\n", + " inserted_positions = (xt_tmp == mask) & ~is_original\n", + "\n", + " candidates = inserted_positions & (insertion_conf < quality_threshold)\n", + " num_bad = candidates.sum(dim=1)\n", + " min_length = expected_length.long().clamp(min=1)\n", + " max_removable = (current_length_after.long() - min_length).clamp(min=0)\n", + " length_after_removal = current_length_after.long() - num_bad\n", + " schedule_violates = length_after_removal < min_length\n", + " k_per_row = torch.where(schedule_violates, max_removable, num_bad)\n", + " k_per_row = torch.where(num_bad > 0, k_per_row, torch.zeros_like(k_per_row))\n", + "\n", + " if not candidates.any():\n", + " return xt_tmp\n", + "\n", + " neg_inf = torch.tensor(float('-inf'), device=device, dtype=insertion_conf.dtype)\n", + " scores = torch.where(candidates, -insertion_conf, neg_inf)\n", + " _, sorted_indices = scores.sort(dim=1, descending=True)\n", + " positions = torch.arange(L, device=device).unsqueeze(0)\n", + " keep_in_topk = positions < k_per_row.unsqueeze(1)\n", + " final_bad = torch.zeros_like(candidates)\n", + " final_bad.scatter_(1, sorted_indices, keep_in_topk)\n", + " if not final_bad.any():\n", + " return xt_tmp\n", + "\n", + " sort_key = final_bad.long()\n", + " _, perm = torch.sort(sort_key, dim=1, stable=True)\n", + " xt_tmp = torch.gather(xt_tmp, 
1, perm)\n", + " num_keep = (~final_bad).sum(dim=1)\n", + " tail_mask = positions >= num_keep.unsqueeze(1)\n", + " xt_tmp = torch.where(tail_mask, torch.full_like(xt_tmp, pad), xt_tmp)\n", + " return xt_tmp\n", + "\n", + "\n", + "def apply_schedule_aware_remasking(model, new_xt, t, dt, remasking_conf, clean_index,\n", + " mask, neg_inf, batch_size, debias=False):\n", + " \"\"\"Re-mask the lowest-confidence clean tokens so the unmasked count tracks the\n", + " unmask schedule (FlexMDM-style schedule-aware remasking).\"\"\"\n", + " t_next = t + dt\n", + " num_clean = clean_index.sum(dim=1)\n", + " current_seq_len = (num_clean + (new_xt == mask).sum(dim=1)).float()\n", + " expected_unmasked_frac = model.interpolant.unmask_schedule.at(t_next)\n", + " expected_num_clean = expected_unmasked_frac * current_seq_len\n", + " masks_to_add = (num_clean.float() - expected_num_clean).round().long()\n", + "\n", + " k_per_row = torch.minimum(masks_to_add.clamp(min=0), num_clean)\n", + " if k_per_row.sum() == 0:\n", + " return new_xt\n", + "\n", + " conf = _debias_by_token(remasking_conf, new_xt, clean_index) if debias else remasking_conf\n", + " remasking_score_temp = -1.0 * conf\n", + " remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf)\n", + " _, sorted_indices = remasking_score_temp.sort(dim=1, descending=True)\n", + " L = remasking_score_temp.shape[1]\n", + " positions = torch.arange(L, device=new_xt.device).unsqueeze(0)\n", + " keep_in_topk = positions < k_per_row.unsqueeze(1)\n", + " to_mask = torch.zeros_like(clean_index)\n", + " to_mask.scatter_(1, sorted_indices, keep_in_topk)\n", + " new_xt = torch.where(to_mask, torch.full_like(new_xt, mask), new_xt)\n", + " return new_xt\n", + "\n", + "print(\"vendored A2D2 schedule-aware routines loaded\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 
A toy \"language\" with a known denoising posterior $f_\\theta$\n", + "\n", + "A real any-length MDM parameterizes an **unmasking posterior** $f_\\theta(\\boldsymbol{x}_t,t)[\\ell]\\in\\Delta^V$. Here we replace it with an analytic stand-in: a 5-letter **bigram** model with a peaked transition matrix $T$. Given a partially-masked sequence, the posterior at a masked position is the (normalized) product of the constraints from its clean left/right neighbours, which is exactly the kind of context-dependent posterior $f_\\theta$ learns, but in closed form so we know $\\mu_\\star,\\nu_\\star$ exactly.\n", + "\n", + "Special tokens: `MASK` and `PAD` (matching the sampler's two reserved ids)." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a clean target x1 : BCDEABCDEA\n", + "posterior at a masked middle position (neighbours clean):\n", + " state: BCDEA_CDEA f_theta[5] = [0.009 0.963 0.009 0.009 0.009]\n" + ] + } + ], + "source": [ + "LETTERS = [\"A\", \"B\", \"C\", \"D\", \"E\"]\n", + "V = len(LETTERS)\n", + "MASK, PAD = V, V + 1 # reserved ids, as in the sampler\n", + "NTOK = V + 2\n", + "\n", + "# Peaked bigram transition matrix -> context is informative.\n", + "T = np.array([\n", + " [0.05, 0.70, 0.10, 0.10, 0.05], # A -> mostly B\n", + " [0.05, 0.05, 0.75, 0.10, 0.05], # B -> mostly C\n", + " [0.10, 0.05, 0.05, 0.70, 0.10], # C -> mostly D\n", + " [0.10, 0.10, 0.05, 0.05, 0.70], # D -> mostly E\n", + " [0.70, 0.05, 0.10, 0.10, 0.05], # E -> mostly A\n", + "], dtype=np.float64)\n", + "T /= T.sum(1, keepdims=True)\n", + "pi0 = np.array([0.4, 0.2, 0.2, 0.1, 0.1]) # initial-token distribution\n", + "\n", + "def sample_seq(n):\n", + " \"\"\"Draw a clean sequence x_1 ~ p_target from the bigram chain.\"\"\"\n", + " s = [np.random.choice(V, p=pi0)]\n", + " for _ in range(n - 1):\n", + " s.append(np.random.choice(V, p=T[s[-1]]))\n", + " return s\n", + "\n", + "def 
posterior_at(seq, ell):\n", + " \"\"\"Toy f_theta(seq, t)[ell] in Delta^V: bigram constraints from clean neighbours.\"\"\"\n", + " factor = np.ones(V)\n", + " if ell - 1 >= 0 and seq[ell - 1] < V: # clean left neighbour a -> T[a, :]\n", + " factor *= T[seq[ell - 1], :]\n", + " if ell + 1 < len(seq) and seq[ell + 1] < V: # clean right neighbour c -> T[:, c]\n", + " factor *= T[:, seq[ell + 1]]\n", + " if factor.sum() == 0:\n", + " factor = np.ones(V)\n", + " return factor / factor.sum()\n", + "\n", + "def decode(seq):\n", + " return \"\".join(\"_\" if t == MASK else (\".\" if t == PAD else LETTERS[t]) for t in seq)\n", + "\n", + "x1 = sample_seq(10)\n", + "print(\"a clean target x1 :\", decode(x1))\n", + "print(\"posterior at a masked middle position (neighbours clean):\")\n", + "probe = x1.copy(); probe[5] = MASK\n", + "print(\" state:\", decode(probe), \" f_theta[5] =\", np.round(posterior_at(probe, 5), 3))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Unmasking quality $\\mu_\\star$\n", + "\n", + "Given a clean target $\\boldsymbol{x}_1$, an interpolant sample $\\tilde{\\boldsymbol{x}}_t$ (partially masked), and a candidate unmasking $\\boldsymbol{y}$ with $\\boldsymbol{y}^\\ell \\sim f_\\theta(\\tilde{\\boldsymbol{x}}_t,t)[\\ell]$, the **unmasking quality** is the posterior probability of the *true* token:\n", + "\n", + "$$\\mu_\\star^\\ell(\\boldsymbol{y}) = p\\!\\left(\\boldsymbol{y}^\\ell=\\boldsymbol{x}_1^{s_t[\\ell]}\\mid \\boldsymbol{y}\\right)=f_\\theta(\\tilde{\\boldsymbol{x}}_t,t)[\\ell,\\boldsymbol{x}_1^{s_t[\\ell]}]$$\n", + "\n", + "High when the context pins down the token; low when it doesn't. Below, the middle masked position (both neighbours clean) is high-quality, while an isolated masked position (no clean neighbours) has a near-uniform posterior → low quality." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "clean x1 : CDEDABAC\n", + "masked x~t: C_ED_B_C\n", + "masked isolated: ___DABAC\n", + "\n", + " pos 1: true D mu* = 0.966 posterior=[0.01 0. 0.01 0.97 0.01]\n", + " pos 4: true A mu* = 0.596 posterior=[0.6 0.04 0.02 0.04 0.3 ]\n", + " pos 6: true A mu* = 0.056 posterior=[0.06 0.42 0.42 0.06 0.06]\n", + " isolated pos 1: true D mu* = 0.200 (near-uniform -> low quality)\n" + ] + } + ], + "source": [ + "x1 = sample_seq(8)\n", + "xt = x1.copy()\n", + "for ell in [1, 4, 6]:\n", + " xt[ell] = MASK\n", + "print(\"clean x1 :\", decode(x1))\n", + "print(\"masked x~t:\", decode(xt))\n", + "print(\"masked isolated: \", end=\"\")\n", + "iso = [MASK if i in (0, 1, 2) else x1[i] for i in range(len(x1))] # left token has no clean nbrs\n", + "print(decode(iso))\n", + "print()\n", + "for ell in [1, 4, 6]:\n", + " p = posterior_at(xt, ell)\n", + " print(f\" pos {ell}: true {LETTERS[x1[ell]]} mu* = {p[x1[ell]]:.3f} posterior={np.round(p,2)}\")\n", + "p_iso = posterior_at(iso, 1)\n", + "print(f\" isolated pos 1: true {LETTERS[x1[1]]} mu* = {p_iso[x1[1]]:.3f} (near-uniform -> low quality)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.1 Training the unmasking-quality predictor $\\mu_\\phi$ (UQL)\n", + "\n", + "At inference we have no $\\boldsymbol{x}_1$, so A2D2 trains a light head $\\mu_\\phi$ to predict quality from the *post-unmasking* sequence $\\boldsymbol{y}$ alone, via the **Unmasking Quality Loss** (binary cross-entropy against the 0/1 hit indicator):\n", + "\n", + "$$\\mathcal{L}_{\\text{UQL}}(\\phi)=\\mathbb{E}\\Big[\\textstyle\\sum_{\\ell\\in\\mathcal{M}}\\text{BCE}\\big(\\mathbf{1}[\\boldsymbol{y}^\\ell=\\boldsymbol{x}_1^{s_t[\\ell]}],\\,\\mu_\\phi^\\ell(\\boldsymbol{y})\\big)\\Big].$$\n", + "\n", + "The unique minimizer is the true quality: $\\mu_{\\phi^\\star}=\\mu_\\star$. 
Note the head sees the *still-masked* neighbours (the realistic decode state), so the BCE minimizer is $P(\\text{correct}\\mid \\boldsymbol{y}) = f_\\theta(\\cdot)[\\ell,\\boldsymbol{y}^\\ell]$. We verify $\\mu_\\phi$ recovers it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " step 500 UQL running avg = 0.5410 (irreducible floor = 0.4777)\n", + " step 1000 UQL running avg = 0.4752 (irreducible floor = 0.4777)\n", + " step 1500 UQL running avg = 0.4573 (irreducible floor = 0.4777)\n", + " step 2000 UQL running avg = 0.4625 (irreducible floor = 0.4777)\n", + "\n", + "running loss 0.4625 sits at the floor 0.4777 -> converged\n", + "corr(mu_phi, mu_star) on held-out states = 0.959\n" + ] + } + ], + "source": [ + "WIN = 2 # context window (each side) fed to the head\n", + "\n", + "def window_feat(seq, ell):\n", + " feats = []\n", + " for d in range(-WIN, WIN + 1):\n", + " j = ell + d\n", + " tok = seq[j] if 0 <= j < len(seq) else PAD\n", + " feats.append(F.one_hot(torch.tensor(tok), NTOK).float())\n", + " return torch.cat(feats)\n", + "\n", + "def make_unmask_example():\n", + " n = np.random.randint(6, 11)\n", + " x1 = sample_seq(n)\n", + " ctx = x1.copy()\n", + " for i in range(n): # interpolant: mask ~half the tokens\n", + " if np.random.rand() < 0.5:\n", + " ctx[i] = MASK\n", + " ell = np.random.randint(n)\n", + " p = posterior_at(ctx, ell)\n", + " samp = int(np.random.choice(V, p=p)) # y^ell ~ f_theta(ctx)[ell]\n", + " y = ctx.copy(); y[ell] = samp # predictor sees still-masked neighbours\n", + " label = float(samp == x1[ell]) # 1[y^ell == x_1]\n", + " return window_feat(y, ell), torch.tensor([label]), p[samp] # p[samp] = true quality of y\n", + "\n", + "class QualityHead(nn.Module):\n", + " def __init__(self, din):\n", + " super().__init__()\n", + " self.net = nn.Sequential(nn.Linear(din, 64), nn.GELU(), nn.Linear(64, 1))\n", + " def forward(self, 
x):\n", + " return torch.sigmoid(self.net(x))\n", + " \n", + "din = (2 * WIN + 1) * NTOK\n", + "\n", + "\"\"\"\n", + "The UQL label 1[y^l == x_1] is a coin flip with probability mu_star, so even a PERFECT\n", + "mu_phi cannot beat the entropy floor E[H(mu_star)]. The per-batch loss is also noisy\n", + "(a fresh random minibatch each step), so we report an EMA-smoothed running loss + the floor. \n", + "\n", + "Convergence means the running loss sits AT the floor, not that it reaches 0\n", + "\"\"\" \n", + "def binary_entropy(q, eps=1e-12):\n", + "    \"\"\"Return the mean Bernoulli entropy E[H(q)] in nats.\n", + "\n", + "    Nats match the natural log used by F.binary_cross_entropy, so the value is\n", + "    directly comparable to the UQL loss. q is clipped to [eps, 1 - eps] to keep\n", + "    the logs finite.\n", + "    \"\"\"\n", + "    q = np.clip(np.asarray(q, dtype=np.float64), eps, 1.0 - eps)\n", + "    return float(-(q * np.log(q) + (1.0 - q) * np.log1p(-q)).mean())\n", + "\n", + "# Monte-Carlo estimate of the irreducible BCE floor from the true qualities p[samp].\n", + "_, _, q_pool = zip(*[make_unmask_example() for _ in range(4000)])\n", + "uql_floor = binary_entropy(q_pool)\n", + "\n", + "mu_phi = QualityHead(din)\n", + "opt = torch.optim.Adam(mu_phi.parameters(), lr=2e-3)\n", + "run = None\n", + "for step in range(2000):\n", + " feats, labels, _ = zip(*[make_unmask_example() for _ in range(64)])\n", + " loss = F.binary_cross_entropy(mu_phi(torch.stack(feats)), torch.stack(labels))\n", + " opt.zero_grad(); loss.backward(); opt.step()\n", + " run = loss.item() if run is None else 0.98 * run + 0.02 * loss.item()\n", + " if (step + 1) % 500 == 0:\n", + " # Report the floor next to the EMA loss: convergence is judged against it, not against 0.\n", + " print(f\" step {step+1:4d} UQL running avg = {run:.4f} (irreducible floor = {uql_floor:.4f})\")\n", + "\n", + "with torch.no_grad():\n", + " fs, _, qs = zip(*[make_unmask_example() for _ in range(800)])\n", + " pred = mu_phi(torch.stack(fs)).squeeze(-1).numpy()\n", + "qs = np.array(qs)\n", + "print(f\"corr(mu_phi, mu_star) on held-out states = {np.corrcoef(pred, qs)[0,1]:.3f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZkAAAGZCAYAAABbpUzOAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAlEJJREFUeJzs3Wd0FFUfgPFn+2bTewNC6J1AkKaIIkWKKChFRJSiYkNAVBAbyAsCgtgAGyhKFREsWFCkSJFi6B0CAdJ72WTrvB9CVlJJ2U02cH/ncHRmZ+7c2STzn9tlkiRJCIIgCIIDyGs6A4IgCMLNSwQZQRAEwWFEkBEEQRAcRgQZQRAEwWFEkBEEQRAcRgQZQRAEwWFEkBEEQRAcRgQZQRAEwWFEkBEEQRAcRgQZO3rrrbeQyWQkJyeX+HmrVq246667iu2/fPkyzz33HA0bNkSr1eLt7U2PHj1Yu3ZtsWMvXryITCbj3XfftXf2hRKsWrWKRYsWFdsvfg72U79+fQYMGHDD42QyGW+99ZbjM+Rk7rrrrmLPjaLfxYkTJ3jrrbe4ePFiteatPJQ1nYFb3a5duxgwYABubm689NJLtGnThoyMDNatW8fw4cPZvHkzX375JTKZrKazektatWoVx44dY+LEiTWdlVvenj17qFOnTk1nwykU/S5OnDjBjBkzuOuuu6hfv37NZawEIsjUoPT0dAYPHoynpyf//PMPgYGBts/uv/9+2rRpw9SpU4mIiGDSpEkOyYNer0en0zkkbWdws9/fraRz5841nQWnUZu+C1FdVoM+//xzEhMTeeeddwoFmAIvv/wyzZo1Y86cOZjN5ipf76677qJVq1bs2LGDrl27otPpGDNmDACZmZlMmTKF8PBw1Go1oaGhTJw4kZycnEJpWK1WPvzwQyIiInBxccHLy4vOnTvzww8/FDpm3rx5NGvWDI1GQ0BAAKNGjeLKlSu2YyZOnIirqyuZmZnF8jls2DACAwMxmUy2fWvXrqVLly64urri5uZGnz59iIqKKnTe448/jpubG0ePHqV37964u7tzzz33ABAVFcWAAQMICAhAo9EQEhJC//79C+WppO/r559/5tKlS8hkMtu/ohYuXEh4eDhubm506dKFvXv3FjvmwIEDDBw4EB8fH7RaLe3atWPdunWlXrvAtm3bkMlkbNu2rdD+guq6L7/8stj9nzt3jn79+uHm5kbdunV58cUXMRgMxc6dP38+c+fOpX79+ri4uHDXXXdx5swZTCYTU6dOJSQkBE9PTwYNGkRiYmKh669du5bevXsTHByMi4sLzZs3Z+rUqcV+Xy5cuMDw4cMJCQlBo9EQGBjIPffcw6FDh8q878WLF6NUKnnzzTdt+4pWERWU8P/66y+efvpp/Pz88PX1ZfDgwcTGxhZKz2Aw8OKLLxIUFIROp+POO+/k4MGD1K9fn8cff7zMvADExsYydOhQ3N3d8fT0ZNiwYezdu7fYz6Ckqi3I/9kULWHMmDGDTp064ePjg4eHB+3bt+eLL76gPHMWX/9dfPnllwwZMgSAu+++2/Z7+uWXX/L222+jVCq5fPlysTTGjBmDr68veXl5N7xeVYggU4O2bNmCQqHgvvvuK/FzmUzGwIEDSUpKKvZAray4uDhGjhzJiBEj2Lx5M8888wx6vZ7u3bvz1VdfMWHCBH755RdeeeUVvvzySwYOHFjol/7xxx/nhRde4LbbbmPt2rWsWbOGgQMHFqoLfvrpp3nllVfo1asXP/zwA2+//Ta//vorXbt2tbVXjRkzBr1eX+xBm56ezqZNmxg5ciQqlQqA2bNn8/DDD9OiRQvWrVvH119/TVZWFt26dePEiROFzjcajQwcOJAePXqwadMmZsyYQU5ODr169SIhIYGPP/6YLVu2sGjRIurVq0dWVlap39XixYu5/fbbCQoKYs+ePbZ/17s+vZUrV5KTk0O/fv3IyMiwHfPXX39x++23k56eztKlS9m0aRMREREMGza
s0APKHkwmEwMHDuSee+5h06ZNjBkzhvfee4+5c+cWO/bjjz9m165dfPzxx3z++eecOnWK++67j7Fjx5KUlMSyZcuYN28ef/zxB+PGjSt07tmzZ+nXrx9ffPEFv/76KxMnTmTdunXFfpf79evHwYMHmTdvHlu2bGHJkiW0a9eO9PT0EvMvSRJTpkxh4sSJfP7558yYMeOG9zxu3DhUKhWrVq1i3rx5bNu2jZEjRxY6ZvTo0SxatIjRo0ezadMmHnzwQQYNGlRqPq6Xm5tLz549+f3335kzZw7ffvstQUFBDBs27IbnluXixYs89dRTrFu3jg0bNjB48GCef/553n777Qql079/f2bPng3k/0wLfk/79+/PU089hVKp5JNPPil0TmpqKmvWrGHs2LFotdoq3ccNSYLdvPnmmxIgJSUllfh5y5Ytpe7du9u2mzVrJgUFBZWZ5pIlSyRA+vbbbyVJkqTo6GgJkObPn1/h/HXv3l0CpD///LPQ/jlz5khyuVzav39/of3r16+XAGnz5s2SJEnSjh07JECaPn16qdc4efKkBEjPPPNMof3//POPBEivvvqqbV/79u2lrl27Fjpu8eLFEiAdPXpUkiRJiomJkZRKpfT8888XOi4rK0sKCgqShg4datv32GOPSYC0bNmyQsceOHBAAqSNGzeWmu/S9O/fXwoLCyu2v+Dn0Lp1a8lsNtv279u3TwKk1atX2/Y1a9ZMateunWQymQqlMWDAACk4OFiyWCylXv+vv/6SAOmvv/4q8frLly+37Su4/3Xr1hU6tl+/flLTpk2Lndu2bdtC1160aJEESAMHDix0/sSJEyVAysjIKDGPVqtVMplM0vbt2yVAOnz4sCRJkpScnCwB0qJFi0q9P0mSpLCwMKl///6SXq+XHnzwQcnT01P6448/ih0HSG+++aZte/ny5SX+rs2bN08CpLi4OEmSJOn48eMSIL3yyiuFjlu9erUESI899liZ+Sv4G9y0aVOh/U888USxn0H37t0L/Y0XeOyxx0r8PSpgsVgkk8kkzZw5U/L19ZWsVmuZaRb9Lr799tsSf08Krh0QECAZDAbbvrlz50pyuVyKjo4uNU/2IkoyTk66VoqwV8N/Qc+16/3000+0atWKiIgIzGaz7V+fPn0KVdX88ssvADz77LOlpv/XX38BFKuC6NixI82bN+fPP/+07Rs9ejS7d+/m9OnTtn3Lly/ntttuo1WrVgD89ttvmM1mRo0aVShvWq2W7t27F6tGAnjwwQcLbTdq1Ahvb29eeeUVli5dWqz0UxX9+/dHoVDYttu0aQPApUuXADh37hynTp3ikUceASh0D/369SMuLq7Q/VeVTCYrVppo06aNLT/X69evH3L5f4+A5s2b2+7pegX7Y2JibPsuXLjAiBEjCAoKQqFQoFKp6N69OwAnT54EwMfHh4YNGzJ//nwWLlxIVFQUVqu1xHynpKTQo0cP9u3bx99//22r5iyPgQMHFrtf+O9nsH37dgCGDh1a6LiHHnoIpfLGzdJ//fUX7u7uxa4zYsSIcuexJFu3bqVnz554enravsM33niDlJSUYtWTVfHCCy+QmJjIt99+C+RXZy9ZsoT+/ftXSycBEWTsqOAX1mKxlPi52Wy2VQEB1KtXj6SkpGL12NcrqIaqW7euXfIYHBxcbF9CQgJHjhxBpVIV+ufu7o4kSbYqrqSkJBQKBUFBQaWmn5KSUup1QkJCbJ8DPPLII2g0GluV0YkTJ9i/fz+jR48ulDeA2267rVj+1q5dW6y7uE6nw8PDo9A+T09Ptm/fTkREBK+++iotW7YkJCSEN998s1C7T2X4+voW2tZoNEB+Fcv1+Z8yZUqx/D/zzDMApXZ5rwydTles+kOj0ZRY7+7j41NoW61Wl7m/II3s7Gy6devGP//8w6xZs9i2bRv79+9nw4YNwH/3LpPJ+PPPP+nTpw/z5s2jffv2+Pv7M2HChGLVlGfOnOGff/6hb9++theM8rr
Rz6Dgd65ou6dSqSx2bklSUlJKbDMt6+/gRvbt20fv3r0B+Oyzz9i1axf79+9n+vTphfJuD+3ataNbt258/PHHQP5L5cWLF3nuuefsdo2yiN5ldlTwi3j16tViv5SSJBEXF0eHDh1s+3r37s3vv//Ojz/+yPDhw4ulJ0kSP/zwA76+vrRt29YueSypROTn54eLiwvLli0r8Rw/Pz8A/P39sVgsxMfHlxhE4L8/+Li4uGLdTWNjY21pQX6p6v7772fFihXMmjWL5cuXo9Vqefjhh4tde/369YSFhVXq/gBat27NmjVrkCSJI0eO8OWXXzJz5kxcXFyYOnXqDdOtrIL8T5s2jcGDB5d4TNOmTUs9vyBgXN9wD/YNTBW1detWYmNj2bZtm630ApTYvhEWFsYXX3wB5AeSdevW8dZbb2E0Glm6dKntuC5dujBkyBDGjh0LwJIlSwqVsqqi4HcyISGB0NBQ236z2Vzopaes8/ft21dsf3x8fLF9Wq22UHtcgaI/rzVr1qBSqfjpp58KvRRs3LjxhvmpjAkTJjBkyBD+/fdfPvroI5o0aUKvXr0ccq2iREnGjnr06IFMJitxEOWvv/5KZmYmPXv2tO0bO3YsgYGBTJs2rcTi8bx58zh16hTjx4+3vZ05woABAzh//jy+vr506NCh2L+CInXfvn2B/AdAaQqq4r755ptC+/fv38/JkyeLVYOMHj2a2NhYNm/ezDfffMOgQYPw8vKyfd6nTx+USiXnz58vMW/XB+3ykMlktG3blvfeew8vLy/+/fffMo/XaDRVeqts2rQpjRs35vDhw6Xm393dvdTzC777I0eOFNp/fW++6lYQyIv+ThZtXC6qSZMmvPbaa7Ru3brE7/2xxx5jzZo1LF++nFGjRpVaI1BRd955J0Cxv8v169eXq9fm3XffTVZWVrHvfNWqVcWOrV+/PmfOnCn0UpCSksLu3bsLHSeTyVAqlYWqWnNzc/n6669vfEMlKFp6K2rQoEHUq1ePF198kT/++INnnnmm2sbeiZKMHTVs2JDnnnuO+fPnk56eTr9+/XBxcWH//v288847dOjQoVA9rpeXF9999x0DBgwgMjKSl156ibZt25KZmcnatWtZuXIlvXr1KnGU89GjR1m/fn2x/bfddlu53vivN3HiRL777jvuvPNOJk2aRJs2bbBarcTExPD777/z4osv0qlTJ7p168ajjz7KrFmzSEhIYMCAAWg0GqKiotDpdDz//PM0bdqUJ598kg8//BC5XE7fvn25ePEir7/+OnXr1i023qd3797UqVOHZ555hvj4+EJVZZD/Rztz5kymT5/OhQsXuPfee/H29iYhIYF9+/bh6up6wx5IP/30E4sXL+aBBx6gQYMGSJLEhg0bSE9Pv+HbXOvWrdmwYQNLliwhMjISuVxe4cD2ySef0LdvX/r06cPjjz9OaGgoqampnDx5kn///ddWV16SoKAgevbsyZw5c/D29iYsLIw///zTVjVVE7p27Yq3tzfjx4/nzTffRKVSsXLlSg4fPlzouCNHjvDcc88xZMgQGjdujFqtZuvWrRw5cqTU0uNDDz2ETqfjoYceIjc3l9WrV9uq6yqrZcuWPPzwwyxYsACFQkGPHj04fvw4CxYswNPT84YlplGjRvHee+8xatQo/ve//9G4cWM2b97Mb7/9VuzYRx99lE8++YSRI0fyxBNPkJKSwrx584pV4fbv35+FCxcyYsQInnzySVJSUnj33Xcr/TJZUMX46aef4u7ujlarJTw83FaKUygUPPvss7zyyiu4urqWq9u23Ti8a8Etxmq1SkuWLJE6dOgg6XQ6Sa1WS40bN5ZeeeUVKSsrq8RzLl26JD3zzDNSeHi4pFKpJEACpJkzZxbquSRJ//UMKu3f9T1diurevbvUsmXLEj/Lzs6WXnvtNalp06aSWq2WPD09pdatW0uTJk2S4uPjbcdZLBbpvffek1q1amU7rkuXLtKPP/5Y6Ji5c+dKTZo
0kVQqleTn5yeNHDlSunz5conXfvXVVyVAqlu3bqk9rTZu3CjdfffdkoeHh6TRaKSwsDDpoYceKtQL6bHHHpNcXV2LnXvq1Cnp4Ycflho2bCi5uLhInp6eUseOHaUvv/yy1O+qQGpqqvTQQw9JXl5ekkwmkwr+ZMrq5UeRnj+SJEmHDx+Whg4dKgUEBEgqlUoKCgqSevToIS1duvSGeYiLi5MeeughycfHR/L09JRGjhxp6zFXtHdZSfdf0OuxQGl5L+jJVtCTsUBBL67rex/u3r1b6tKli6TT6SR/f39p3Lhx0r///lsoTwkJCdLjjz8uNWvWTHJ1dZXc3NykNm3aSO+9916h3+uC3mVF8+Lm5ibde++9kl6vL/F7LSlf19/H9T2t8vLypMmTJ0sBAQGSVquVOnfuLO3Zs0fy9PSUJk2aVMK3XtiVK1ekBx98UHJzc5Pc3d2lBx98UNq9e3eJf3NfffWV1Lx5c0mr1UotWrSQ1q5dW2LvsmXLlklNmzaVNBqN1KBBA2nOnDnSF198IQGFen2Vp3eZJOX3DgwPD5cUCkWJ+bp48aIESOPHj7/h/dqTTJLKMfJHqFZHjx6lW7duRERE8Msvv+Di4lLTWRKEm87u3bu5/fbbWblyZaV6il28eJHw8HCWL19evSWDSvrwww+ZMGECx44do2XLltV2XVFd5oRat27Npk2b6NOnD4MHD2bTpk1VrjIQhFvZli1b2LNnD5GRkbi4uHD48GHeeecdGjduXGqHjJtFVFQU0dHRzJw5k/vvv79aAwyIIOO0unfv7vDpHgThVuHh4cHvv//OokWLyMrKws/Pj759+zJnzhzHj3ivYYMGDSI+Pp5u3boV6tFXXUR1mSAIguAwoguzIAiC4DAiyAiCIAgOI9pkrrFarcTGxuLu7i4WCBMEQbgBSZLIysoiJCSkzLFGIshcExsba7f5wQRBEG4Vly9fLnPFUhFkrimY2uPy5cvFRucKgiAIhWVmZlK3bt0yp0UCEWRsCqrIPDw8RJARBEEopxs1L4iGf0EQBMFhRJARBEEQHEZUl5WDxWKp8uJWgvNTqVSFpl4XBKHqRJC5gezsbK5cuYKYGOHmJ5PJqFOnDm5ubjWdFUG4aYggUwaLxcKVK1fQ6XT4+/uL8TM3MUmSSEpK4sqVKzRu3FiUaATBTkSQKYPJZEKSJPz9/cV0+7cAf39/Ll68iMlkEkFGEOxENPyXgyjB3BrEz1kQ7E8EGUEQBMFhnC7I7Nixg/vuu4+QkBBkMhkbN2684Tnbt28nMjISrVZLgwYNamTNhOokk8nIzs6263n169fn2LFjVc1aIV9++SVnzpyp0Dnp6enMmzev3Me/9dZbGI3GimZNEIRq4nRBJicnh7Zt2/LRRx+V6/jo6Gj69etHt27diIqK4tVXX2XChAl89913Ds6pcCPVEWRmzJghgowgODGnCzJ9+/Zl1qxZ5V4SdenSpdSrV49FixbRvHlzxo0bx5gxY3j33XcdnNOymSxWTsVlsud8MqfiMjFZrHZN/+OPP6ZTp062NcYLnD17lv79+3PbbbfRtm1bFi9eXOL5O3fupHXr1nTs2JHnnnuuzC7aCQkJDBo0iNatW9OqVSs+/fRT22dFS0AdOnRg27ZtfP755xw4cIAJEyYQERHB5s2bC6VptVp57rnnaNasGW3btiUyMpK8vDzGjx9Peno6ERERdOjQAYCFCxdy22230a5dOzp27Mg///wDwPjx4wHo2rUrERERJCYmkpWVxRNPPEHHjh1p06YN48ePt41xmjVrFs2bNyciIoKIiAguXbpUka9cEITKkJwYIH3//fdlHtOtWzdpwoQJhfZt2LBBUiqVktFoLPW8vLw8KSMjw/bv8uXLEiBlZGTYjsnNzZVOnDgh5ebmVjjvJ2MzpDX7YqQNB69Iq/+5JJ2MzbjxSeUESIsWLZIkSZJOnDghubm5SSaTSTKbzVKHDh2kkydPSpIkSTk5OVLr1q2lgwc
P2s7LysqS8vLypJCQEOmvv/6SJEmS1q5dKwHS0aNHS7ze0KFDpalTp0qSJEkJCQlSnTp1pH/++UeSJEkKCwsrdF5kZKQt3e7du0s//vhjiWn++++/UrNmzSSLxSJJkiSlp6dLFotFio6Olnx9fQsdm5iYaPv/PXv2SC1btiz0XWRlZdm2n3jiCWnFihWSJEmS1WqVxo4dKy1cuFBKTU2VPD09Jb1eb/tuiv5cq/LzFoTaxmQy2f5WKyMjI6PYM7MkTleSqaj4+HgCAwML7QsMDMRsNpOcnFzqeXPmzMHT09P2z97T/KfpjagVcoI8tWiUCtL09q3SeeSRRwBo3rw5SqWS+Ph4Tp8+zfHjxxk+fDgRERF07dqVrKwsTpw4Uejc06dPo9PpuOuuuwAYOnQonp6epV7rjz/+4NlnnwUgICCAwYMH8+eff1Yp/w0aNMBkMjFmzBi++uorTCZTqWtSREVF0b17d1q1asX48eM5ceJEqVVkGzduZP78+URERNCuXTt27tzJ2bNn8fDwoHHjxowcOZJPPvmE1NTUm35td0EojSRJPPLII/Tp04fLly879Fo3xTiZol1PpWtVP2V1SZ02bRqTJ0+2bRdMW20v3jo1l1L0xGfkYTBb8Nap7ZY2UOgBqVAoMJvNSJKEn58fhw4dKvNcqYyqsT/++IMpU6YAMGTIEKZPnw4U/y4LtpVKJRaLxbY/Ly+vXPn39PTk+PHjbN++nb/++otp06axY8cOlMrCv5JGo5EHH3yQbdu2ERkZSWZmJp6enhiNRtTq4t+pJEls3LiRBg0aFPts79697N69m23bttG5c2dWr15Nt27dypVfQbiZyGQyunXrxogRIxy+jlatL8kEBQURHx9faF9iYiJKpRJfX99Sz9NoNLZp/R0xvX/DADci6noR5Kkhoq4XDQMcP1VJ06ZN0el0rFixwrbv3LlzpKamFjquWbNm5ObmsmPHDgDWr19PRkYGAD179uTQoUMcOnTIFmB69uxpa4dJSkri+++/p0ePHvn32bChrY1k3759nD592nYdDw8PW7pFJSUlkZOTQ+/evZk9ezb169fnxIkTeHh4oNfrMZvNQH7QMplMtj+EDz/8sFA67u7uha4xcOBA3nnnHdv5aWlpnDt3jqysLBISEujWrRuvv/46d9xxB1FRUeX+bgXhZqDX6/n2228BeO6557j//vsdfs1aH2S6dOnCli1bCu37/fff6dChAyqVqoZyBSqFnGbBHnRp6EezYA9UCsd/1Uqlkh9//JF169bRpk0bWrZsybhx48jNzS10nEajYfXq1Tz77LN07NiRffv2Ua9evVLT/eCDDzhy5Aht2rTh7rvvZvr06XTs2BGA//3vf7z//vt06tSJ5cuX07JlS9t5Tz75JDNnziyx4f/y5cv06tWLNm3a2DoU9O3bFx8fHx555BFat25Nhw4d8PDwYObMmXTs2JE777wTjUZTKJ0XX3yRHj162Br+Fy1ahFKpJCIigjZt2tCzZ08uXrxIRkYGgwcPpnXr1rRp0waTycRjjz1W1a9cEGqN7OxsBgwYwOOPP87Vq1er7boyqay6kxqQnZ3NuXPnAGjXrh0LFy7k7rvvxsfHh3r16jFt2jSuXr1qe1uPjo6mVatWPPXUUzzxxBPs2bOH8ePHs3r1ah588MFyX7egGiYjI8NWqsnLyyM6Oprw8HBRf38LED9v4WaVmZlJv379OHz4MJs3b7ZLNXFJz8ySOF2bzIEDB7j77rtt2wXtJo899hhffvklcXFxxMTE2D4PDw9n8+bNTJo0iY8//piQkBA++OCDCgUYQRCEm1VaWhr33nsvp0+fZsuWLXTu3Llar+90Qeauu+4qs2H6yy+/LLave/fu/Pvvvw7MlSAIQu2kVCrx9/dnyZIltG/fvvqvX+1XFARBEBziYlImr6w/zOU0A37ybCbdHcbdndrx008/1VieRJARBEG4STy38iDH4vWYM5PZv3Y6O5d6kBp9vEZ
nGBdBRhAE4SZxPF6POSOBhNWvIklWXPtMrvElLESQcRDJYiFv7xHMCSkoA33Rdm6DTCyEJQiCg+iNZoypV0lY8xoyhYKgh+ei9Ayo6WyJIOMI2T9tJ3n6+1hik2z7FCH++P3vBdwGdK/BnAmCcLNat/8i1oxEXJVq+nR9HIM+lxhf+07MWxm1fjCms8n+aTsJY14rFGAALHFJJIx5jeyftlf5GmWtCxMREVFs8KUj3XXXXTXaqCgIQv54wYOf/cSGf3azX16fOf9sZ+Hm1Sxf84ldnjlVIYJMBUiShDUnt9R/lsxskl9dBCX1wL62L/nV97FkZpeaRlXHxh46dAgXF5cqpSEIQu1x8OBBIttGoPp2Ef45WSiua4Nxzciw28ttZYnqsgqQ9HlE1+9dhQTySzQXG/Yt9ZDwi78jc71xkHj33XfZsmULSUlJzJgxg4cffhjIL+VkZWXh5uZG/fr1GT16NL/99htxcXGMHTuW1157DcgvgXTq1Indu3cTGxtLr169bCuKZmVlMXnyZA4fPkxeXh5du3blww8/RKVSceLECUaPHo3JZKJ58+alToi5bds2Jk6cSOfOndm1axcqlYoVK1bw9ttvc/ToUUJDQ/n+++9xc3PDZDLx+uuvs3XrVoxGI82aNWPp0qV4eXmxatUq3n//fYxGI5IkMXv2bPr16wdQ5v0Jwq1gz5499O3bl3CLgpHaYIo28ReUIpJf+wDXvnfUSLuwKMnUUjKZjF27dvHrr7/y/PPPlzpdd3p6Ort372bfvn3Mnz+/0JxF58+fZ9u2bRw7dozffvuNPXv2APnzgd15553s27ePw4cPYzabbSuVPvroozzzzDP8+++/PP/88+zfv7/UPB4/fpzx48dz9OhRunTpwr333suCBQs4ceIEKpWKVatWATB//nzc3NzYt28fhw4domXLlrz55psA9OnTh7179xIVFcXGjRsZN26cbRGyG92fINzMduzYQe/evWkZFs5ybVM85aWUGSSwXE0kb++R6s3gNaIkUwEynZbwi7+X+nnunsPEP/zSDdMJWj0fly5tS71GeYwbNw7IX5fljjvuYOfOnYwYMaLYcQXrzvj7+9OgQQOio6MJDQ0FYPjw4SgUClxcXIiIiOD8+fN06dKFjRs3snfvXhYsWJB/X7m5qNVqMjMzOXbsGI8++igAnTt3pnXr1qXmsWnTpkRERADQvn17Ll26RJ06dQCIjIzkwoULQP4aMJmZmaxfvx7In96/YcOGQH5d8yOPPMKVK1dQKpUkJydz6dIlGjVqdMP7E4Sb2cqVK+nUqRPfjHmBnBduvGS5OSGlGnJVnAgyFSCTycqsytLdfRuKEH8scUklt8vIQBESgO7u2+xebC2tL3xJ687c6LPS1mTJzMysUJ/7oukX3S7ooCBJEosXL7YtH3C94cOH8+677/LAAw8A4OPjU6iKrqz7E4SbUXp6Ol5eXnz88cf560gdPElOOc5TBpa+9IkjieoyO5IpFPj974VrG0U/zP+P36wJdgkwy5YtA+DixYv8/fff3HHHHVVOs0Bpa7J4eHjQqlUrVq5cCeSvH3P06FG7XG/hwoXo9Xogf82L48eP265dv359AL755hvS0tKqfD1BqK2+//576tevz6FDh1AqlWi1Wswp6WWfJANFaADazm2qJY9FiSBjZ24DuhO4bBaKYP9C+xUhAQQum2W3cTIajYbbb7+d3r178+GHH9p1dbvS1mQBWLFiBR999BHt27fn008/pVOnTlW+3tSpU4mIiKBTp060adOGzp0721b3fP/99xk0aBB33HEHhw8fLnPdG0G4ma1Zs4YhQ4Zw77330rJlSyRJIn3xGhLH5bdfSkDRUTEFFSr2ermtDKdbT6am2Hs9GTHiv/YR68kIzuqrr75izJgxjBw5kmXLliGXJJJffZ/M5RsB8Bg9iBOhYegWLcMnO9N2nt7Hm/AFLzpkEHitXU/mZiFTKHC5vV1NZ0MQhFouOzubV19
9lbFjx+YPM9DnET/uTfR/7gWZDN8Zz+A5fhiu0SlM0bsRcPYCAXk5mH29qNOrE68NKL1zTnUQQUYQBMFJmUwm3Nzc2L9/P8HBwVjik4l7+GWMx88hc9EQsOQN3PrfCcD5lBx0LlrOhzXgrAw0CjmNNTW3BH0B0SYjCILghObOncs999yDwWAgJCQE4/HzXOnzFMbj51D4exOy8QNbgAFAkuGmUSBXyDGZJeRyGfV9dTV3A9eIICMIguBEJElixowZTJ06lbvvvhu1Wk3OH3u5OuAZLHFJqJqEEfrLUrTtWxQ6r2WIB+5aNW5qBSHeWur66jCYa36CTFFdJgiC4CQkSeLVV1/lnXfeYfbs2UybNo2MLzeSPHURWCxo72hP0PJZKLzci53bNNiD5qHuZOSZ8NAoUSjlSDW7lAwggowgCILT2L59O++88w7vvfceL0yYQPKbH5OxeA0A7sP74r/gJWTqkttZVAo5DXzc2HYykXS9Ca1ShosT9GgV1WW1UFnT+devX59jx45VOu1t27bRoUOHSp8vCELFFYwkueuuuzh48CATnhxPwpg3bAHGZ+o4/D+YVmqAKRCdkoMkk+HhosSKjOiU8swF4FgiyNRCYjp/Qbh5WCwWxo4dyyeffAJAmzr1iR38Ajk/bwe1ioAlr+P94mPlmtLJYLbgplZRz9sVd40Kg9ni6OzfkAgytdD1i5bt3LmT1q1b07FjR5577rlC69GcPXuW/v37c9ttt9G2bVsWL15s+2zkyJF06NCBNm3aMGDAABITE6v9PgThVmcymXj00UdZsWIF7u7uGM9c5Grf8RgOnkDu5U7Itwtxf6j8y4u0CfXCTasgy2DGTaugTaiX4zJfTqJNphLi4uKIi4srtM/b25vw8HDy8vI4ceJEsXPat28PwOnTp8nJKVyErV+/Pj4+PhXOh8FgYPjw4axcuZK77rqLdevW8fHHHwP5b0cjRozg66+/plmzZuj1ejp37kznzp1p3749ixYtws/PD4B33nmHmTNn2qbzF4TaKCPXyLr9MVxK1hPmp2PobfXwdFHXdLZKZTQaGT58OD/++CNr166lX2A4V/s9jTUjG2X9EIJXz0fdqGLTKPVoEYhCIeNqWi6h3i50bxrgoNyXnwgylfDJJ58wY8aMQvseeeQRvvnmG65cuUJkZGSxcwpKGI8//jh79+4t9NnXX3/NyJEjK5yP06dPo9PpuOuuuwAYOnQoTz75pO2z48ePM3z4cNvxWVlZnDhxgvbt27Ny5Uq+/vprDAYDubm5BAUFVfj6guBM1u2P4ZdjCWiVco7HZQHwxJ2NajhXpXvttdf4+eef2bBhA3fpVcQOfRFMZjS3tSJ4xWwUft4VTlOnVtK3dYgDclt5IshUwlNPPcXAgQML7fP2zv+FqFOnDgcPHiz13C+//LLEkkxllDXtnCRJ+Pn52SaavN7ff//NRx99xO7du/H39+eHH35g5syZlcqDIDiLS8l6tEo5DQPcOJ+YzaVkfU1nqUxTp06lX79+tNl3gcT5ywFwHXg3AR9NR+6iqXB6JouV84nZpOmNeOvUNAxwQ6Wo+RYREWQqITg4mODg4BI/02q1tqqxkjRt2tRu+WjWrBm5ubns2LGDO++8k/Xr15ORkWG7jk6nY8WKFYwaNQqAc+fO4ePjQ1paGh4eHvj4+GA0Gm0NjoJQmxStHgv01HI8LovzidnkmqyE+dX8aPeisrKyePrpp5k9ezZ1A4Novm4nad/mL4ToNeERfKY/iUxeucBwPjGbw1cyUCvkXErJD7DNgkufuLK61HyYEypNo9GwevVqnn32WTp27Mi+fftsU+ErlUp+/PFH1q1bR5s2bWjZsiXjxo0jNzeXvn370qhRI5o1a0afPn1sq1cKQm1SUD12ITmHzUcTAOjXOpAGfq70ax3I0Nuca1mI9PR0evfuzQ8//MDlU2eJHTKZ7G9/B4UC/4Uv4fv6+EoHGIA0vRG1Qk6QpxaNUkGa3mjH3Fe
emOr/GntP9S/UPuLnXbu8tuEIF5JzbNVjQZ5afF1VRCfpCffX8dRdjfBzc46fY0pKCn369OHChQv8vOxrQueuwnQuBpmbjqBlb6O7u2OVr3EqLpODMWmk642k6020r+dFzxZBDqsyE1P9C4JwU6vrq+PApTTSL6ZilcBsMXPgkgm1Qs7pxPwu/tMHtKrhXILVaqVv375cunSJXz/8DP+pSzClZKAMDSBo1Tw0LRra5ToNA9yITs5vi/JyUZOaY+J8YnaNV5mJICMIgtOIz9Az/9eTRCflEu7vwkv3NifIs+S2lch63hy5nE6a3oiXTk18Zi5qhZxwP1eik3OITnKOhn+5XM4bb7xB0MUkfKYuxppnRN26McGr5qEM8rPbdVQKOV46Nc2DPQjy1BKfkecUVWaiTUYQBKcx/9eT/HU6hSvpuWw9lcL8X0+WeqzJKtGzRTDP39OUXi2CCfLQYjBbiU7OwWC2Eu5fsw3/MTExzJw5E6vVyu3RGXjPWIaUZ0TX53ZCf/jIrgGmgLdOjcFsIT4jD4PZgreu5scJiZJMOYhmq1uD+DnXvOikXJRyGXV9dFxO1ROdVPIcfZD/QL2Uorc9UEd1qc+fJxMKtcnUlAsXLtCjRw9kMhlD4q2ov/0TAM9xD+I763mHLcVez1dHdHK2bTBmPSdYT0YEmTKoVCpkMhlJSUn4+/uXa+4goXaSJImkpCRkMhkqVc2vJnirCvd3ITpFz+VUPSaLRLh/6XP0eemU7I9OslWt3dW0uVO0wZw+fZp77rkHF42WVc175AcYmQzft5/H66khDr12TIqezDwLvm5aMnLNxKToRZuMM1MoFNSpU4crV65w8eLFms6O4GAymYw6deqgcILp0W9Vz/dozJVUPZfTDIT7anm+R+NSj/3oj7P8diwBixVOJ2Sikino3jyg0JQqOnX1PuKio6Pp3r07Ph4erPBpi/c/J5G5aAj85E1c+3Zz+PWv78bsLG0yIsjcgJubG40bN8ZkMtV0VgQHU6lUIsDUsJPx2dTxdadRoBe5Jgsn47Op71/ym/i2M0nkmiU0Shl5ZonfTsZhArRKBSeuTStT3VOs1KlTh5F972P43qt4n49H4e9D0Kq5aCOaVcv1i1YhijaZWkKhUIiHjyBUg6tpuWiVChr4u3IhKYeraaW3ycgKui1JMkDCaqXc59rb/v37MRqNROTIeGbrJSR9Lqpm4QSvmoeqbvXNC9gwwA2g0NQyNU0EGUEQnEaotwsn4rK4kJRDrslCqHfpbTLdG/ux8VAcFquEi0pOm1APck2Wcp1rT7t27aJfv350rNeQpYnuYLXi0r0DgV/MROFZfJnkW40IMoIgOI3WoR58dyCGIym51Pd1oXVo6Y3WE3s3w0WttPUme6xrfY5ezazWae63bdvGgAEDaOsbzLuxGpBbcR/RH/93pyBTVf/j1RnnLxNBRhAEp/HV7oucSdJfG7Wv56vdF0vtMebnpi32WR2f6qse2rJlCwMHDqSTTzAf5fjjIlfg8+oTeE18tMZ6ojpjw78YjCkIgtM4m5CDwWRBBhjNFs4m1Pwa9aXxU2np712HxYZgXDRaAj55E+9Jo2p0qIMYjCkIglAWyUpKtpGUbKNtuzRXUrN5a9MxLl6rWnvr/lbVUpLZunUrkb7B+Ez+gLmmIOQ+HgStmINL5zYOv/aNiIZ/QRCEMpiRkMtALiN/0ktKn4XhrU3H2B2djgK4mmHgrU3H+Hx0Z4fmb+XKlYwaNYqpXo0ZI/dHFV6HoNXzUDes69DrlpdKIa/xNpiiRHWZIAhOQyaBh05NXV9XPHRqZGXM9HMxJRcFEOylRSnL33akL774gkcffZTBmgAek/mh7dia0F+WOE2AcVYiyAiC4HAmi5VTcZnsOZ/MqbhMTJaSq8G6NPLHTa3AYLLiplbQpZF/qWnW93XBLEFceh5mKX/bUT766CPGjRvHw5pAZusa4jm4J8HfvYfC18th17xZiOoyQRAcrrxdax/
pHIZKIbMtqVzW6pZTejchetW/JGabCHFXMaV3E4fkXTIYiV6+nse1IUx3Dcd70ih8po2r0iqWtxIRZARBcDhHdK3deT4VT1cXAj1dyTVZ2Xk+lWahPnbI7X+idu4iYN5qnr5kBs9GBCx4GY9H+tv1Gjc7EWQEQXC48s6ptXLvJb49cBmrVeLvczJMFoln7i55ksxLyXq0Srlt+eWziTn8cjTWLhNkSpLEaxMmMeej9/nZqx3NvAMIXD4LXfcOlUrvVua05b3Fixfb1lqPjIxk586dZR6/cuVK2rZti06nIzg4mNGjR5OSklJNuRUEoSwNA9yIqOtFkKeGiLpepXat3XU2iTS9CbMkkZ5rYtfZpFLTDPFyITYjl+1nkojNyMVgNPPHyUSik/VsOZHI9tOJlcqrJElMfnQ0sz96n5d09WlRvwEhPy8WAaaSnDLIrF27lokTJzJ9+nSioqLo1q0bffv2JSYmpsTj//77b0aNGsXYsWM5fvw43377Lfv372fcuHHVnHNBEEpS0LW2S0M/mgV7oFKU/OgxWSVMFisGkxWTxYrJWnr3skAPDe4aJWqFDDe1EpNkJddoQSaDPJOFmJSKL79stVp5uv8gFq38itddG/B8l3sI/WUpmuYNKpyWkM8pg8zChQsZO3Ys48aNo3nz5ixatIi6deuyZMmSEo/fu3cv9evXZ8KECYSHh3PHHXfw1FNPceDAgWrOuSAIVdG+njfeOjVqpRxvnZr29bxLPTZNb6JNHW8e7hhG27reGI1W4jLyuJCUTVxGHhVtl5ckicvzPuev37cwy60RTw8eSsimDx2yTPKtxOmCjNFo5ODBg/Tu3bvQ/t69e7N79+4Sz+natStXrlxh8+bNSJJEQkIC69evp3//0hvoDAYDmZmZhf4JglCzOjfwxlUtx2yx4qqW07lB6UEm1Nul0KzLTYI9iKjrQQN/VyLqetAooPwzIJty8zg5/g3MC75mo1cET73wPEHLZyF3rZ6ZnG9mThdkkpOTsVgsBAYGFtofGBhIfHx8ied07dqVlStXMmzYMNRqNUFBQXh5efHhhx+Wep05c+bg6elp+1e3rhhQJQg17Y8TSVxNzyMrz8yV9Dz+OFF6m0zXRn40D3ZDhkTzYDd6NAugaZAnkWG+NA3yJNizfAEiLyWNwU3a8MAXCzHLZITMmYTfrAnIxBpSduF0QaZA0UnmJEkqdeK5EydOMGHCBN544w0OHjzIr7/+SnR0NOPHjy81/WnTppGRkWH7d/nyZbvmXxCEitt2Ko5so5Vcs0SO0cq2U3GlHns2LpNDMemcT84mKiYds9lSrs4F18s+H8P9Tdvy25VzTPJuRJ2v5+D5xEP2vKVbntN1Yfbz80OhUBQrtSQmJhYr3RSYM2cOt99+Oy+99BIAbdq0wdXVlW7dujFr1iyCg4OLnaPRaNBoNPa/AUEQKi0201zm9vU2Hr7K8dhMdGolsel5/HQ0jlmD2pb7Wun/HOaBe3qzJyeJT+p2ZMSPX6Fp27TSeRdK5nQlGbVaTWRkJFu2bCm0f8uWLXTt2rXEc/R6PfIirXwFyyVLUhmTHwmC4FSK/rWW9deblmNGLpPh7apGIZORllN6QCoq57dd/DRwDAf0KXzZqicjd28SAcZBnC7IAEyePJnPP/+cZcuWcfLkSSZNmkRMTIyt+mvatGmMGjXKdvx9993Hhg0bWLJkCRcuXGDXrl1MmDCBjh07EhISUlO3IQhCBQXo5GVuX69ZkDt5JjNnE7LIM5lpFlS+hv74j1cS9+g0ulh17B3wBEN3bUBVp+RaEqHqnK66DGDYsGGkpKQwc+ZM4uLiaNWqFZs3byYsLAyAuLi4QmNmHn/8cbKysvjoo4948cUX8fLyokePHsydO7embkEQhEqYMqAFr60/hsEKGnn+dmlCvLX4uWvQG624qOWEeGvLTFuyWDj/0nwe+vB/dFf7MP2pZ2kwd3KNLJN8K5FJoj4JgMzMTDw9PcnIyMDDw7n
WYxCEW8XirWf59uBlLFL+mjJDI+vyTI+Sp5X5fMd5opP1NPB35UJSDuF+Osbd2bDEY605uZx4fCrDvv+CWKuB7198k+7zptfoKpa1XXmfmU5ZXSYIwq3peGwGcpmcut4uKJBzPDaj1GOLjpMJ9S65y7I5PpnD947jwe8/J0Ey8ut7S7hr/msiwFQTUU4UBMFpaJVyEjL0xGfmIpMkIuqW/obcvWkAgG1CzK6N/DgVl1lo6WHrmYvEj3iZRad2k4GFP1evp+3QgdV1OwIiyAiCUA1MFivnE7MLBYCS5i8zmS2YpWvj4oBcg7nUmZX1RjP/XkolOklPuL8Ob52KS6l5tjVrZHuiUL44G7L1TG/djdcXTKJRN8cuzywUJ4KMIAgOV95FyzKNFjxdVHhoVWTmmbiYpuePk4lolQpOxGUB0Ld1fo/RxX+dZfPReOTIOBqbwemETJKzDKTnWrj39CH8f/qW59OPs6hrf+7Z+CkKH8/qu2HBRgQZQRAcrryLlrmqFKTpjaRmG5HJwctFiVapsDXuX03LtR17/GoWViv4eahJzDRwIDoNixUe27+dDgf+ZETmMdzd3Wnx5TsiwNQgEWQEQXC48i5almc2gwQyGSCBUiaRkJlLQmYeINGlwX8rX/q6KFBcuoBfbg4eLq4c9Qhg8s5fCDl7gBEZx1C5ebLz5JESZ/wQqo8IMoIgOFw9Xx3Rydm2tpV6vroSj8s2Snjp1AR4aEnMzEOSyQjy1JKZZ8Zdo6TOtR5k2T9t5/l3FqJOTrWda5IrkFnM9Mk6CR5+eI15VwQYJyCCjCAIDheToiczz4Kvm5aMXDMxKfoS22RCPdX8E51GYnZ+dVojfx2RYb62arZcs5Xsn7aTMOY1VEVG+KmsFiSZjF4dhvNr53tx8xLj3ZyBGCcjCILDXd8mo1EqSm2TOR5beF2nq6l6LqZks/NsEhdTsnFXykie/n5+ldp1x+03ZTAp6zQGycqY+ATUah0eLioH3pFQXiLICILgcN46NQaz5YZtMomZBuC/AJKSa0J+bVsGyKNOYIktvMbMbmM6YzKOk2w1YkUiICeLrulX6dlCzFvoDER1mSAIDlewtsv142RKkmOwAv/NvpxjlNh5NpH0XDNeLkraG+MIve74bcZUnsk8SWeVF4s9mqGV5c++3tVTSf16Xg66G6EiRJARBMHhVAp5iW0wNyIBJ+JzUABX0w3systm6LXPzphzeDrzJN3V3rzv3gyN7L+KmeZtwuh4bUYAoWaJICMI1aS8o95vRldSs3lr0zEupuRS39eFt+5vRR2f4qUZlRKMpsL7tEoZvm4aVFfiuHPzj7b9jRU6Zrs1YoDGH1VBgJGBIiSA7iN6iOWTncSt8RsuCE6gYNR7fIaBQ5fTOZ+YXdNZqjZvbTrG7uh0EjIN7LqQzlubjpV4XLOgwqUdTzVIkoy6p04ze90yglKS2SjPZLMhGZlcxiBtYKEAA+A3a4IIME5EBBlBqCbl7WF1M7qYkosCCPbSopTlb5fEU6dCKQMFoJRBy2APXog/xtSNq3A35LHWD15KOsrRbs1QBPkXOlcK8OOLwcO58988Biz6i3Px6Q6/L+HGRHWZIFST8o56vxnV93XhaoaBuPQ8zFL+9vUki4W8vUdouO8gFklNStPGpGTkMnDjBm779wAAa5p58trfP/P888/z/vvvc/pqOhd+24dLegZ6Tw/eS5ETnW5ClmsmI9fMxDVR/DTx7pq4XeE6IsgIQjUpbw+rm9HUvs2YuCaK2EwjYR5qpvZtZvss+6ftJE9/H0tsEsOv7UtzdUOv1hCaloJVJmNiiBeb//6RMU89x/vvv49MJiMp10R0eEM0SjkGs5X4y2eRAV46Fel6E7GZt05J0ZmJICMI1aSyPaxuBmeT9DQN8aZtPQW5Jgtnk/Q0CvKyjd6nyOh9r5xsvHOyMSiVvHXPYI54euDtG0h0g4G2xcay88xcSMpGq8pP00sjJ85oJV1vQgJCPG6dkqI
zE0FGEASHu5qWW2w2ZclisY3eL0oGWCWJL/Li2efljbefH0r3B0jKMduOcdMqaRDgilapIM9soWN4c77ZfZ7YTCMhHmoWDW9XfTcolEoEGUEQHC7U24UTcVmFlkrO23uk2Oj9ApIkMScnmuV5sbQ7tIW0O/JHxwS4/TdVjL+bFl/XPDRKBQazhVYhHrw7LNJWHRnmf+tURzozEWQEQXC4oksld28agPmnEyUea5Uk3so5z6q8eN5ybUhu4wh+1sgJcFPx8Yj2tuOKtnGZLdZyLYwmVC8RZARBcDidWmlb0bJAbqBvice+ln2Obw0JzHFrxBBtECHj7uKd24tXfRVt49pzPrlcC6MJ1UsEGUGooFt55L49aTu3QRHijyUuqVC7THuVB51UntynDSDDy4sGnduUKz0XpZyDF1PIMuSvPfNQZB0H5VyoCPGXIQgVdCuP3LcnmUKB3/9eAAmMkpWfDfntMw9pA7lPG4AM2HRvv3KP3r+Slkt8loFsg4W4zDyupJU84FOoXqIkIwgVVN716oUbc+nWHqOLimfjD/G3KZ2WSjfqK1xIcfPgky49SW3SstxpXcnIRa2U4+uqJiXHyJUMEWScgQgyglBBt/LIfXuLW7SCJ+Oj2G/J4tv573P4iopPMq0cD66HVSbjDg9NudNSyCA+I4/0HCO5Jivt6no6MOdCeYkgIwgVdCuP3LenjIuXeWjO6xw1ZbLhfwvoN+V53vz+MMcOXsVslVDIZIR4aMudXqNAN9qk5WG2SCgVMhoFip+LMxBBRqgVnKmx/VYeuW9PmZ+ux0OS8U1kP/pOewGAs0lZmCUJuVyGRZI4k5jFqbjMcv3cgz10NAv2sI2bCfbQVeftCKUQQUaoFQoa28UYiNovNTWVK0dP4vb1TyzxaEHw3NdsU8Wk5piwWECmkLBY8qu/yvtzFyVM5ySCjFAriMb26uHoEmNSUhI9e/bEGJvID1JDdJ3a4NKjo+1zq8WKBbBY8rcNJku5f+6ihOmcRJARagXR2F49HFlijIuL45577iEtOYVlhCGXyfB59QlbKQZAUSSgqRRyDlw39mWIGPtS64ggI9QKoiqkejiqxHj58mV69OhBXl4eG/qMJODXfbh074BLkZH8BrMVhSx/gkwJ0BstJGbmISFDbzBzJS2XiDC7ZEmoJiLICLWCqAqpHo4qMV65cgWtVstPny5HOXw6AD6vPlHsuAB3F66m56GQy7FYrWg0SgI8XGyzN8dn5tklP0L1ESP+BUGwaRjgRkRdL4I8NUTU9apyiTEmJgaTyUSXLl04dOgQXqu2gNWK7t470LZvUez4RtfagCRJQqWQ08DbhVyTpdDszULtIkoygiDY2LPEePz4ce655x4ef/xx3nnnHcynosneuBUAn1fGlniOq1aJl4sKqwzkVmgS5ElkuE+h2ZuF2kUEGUEQ7O7w4cP07NmTkJAQJk+eDEDqO18A4PZADzStGpV4nmSVCPV2JchLQ3y6AblcVmz2ZqF2EdVlgiDY1f79+7n77rsJCwtj69atBAQEkPfvCfS//g1yOd4vjyn13DZ1vXDTKMjONeOmUdAy1JNTcZnsOZ/MqbhMTBZrNd6JYA+iJCMIgl2tW7eOZs2asXnzZry8vABInfM5AO5D+6BuXHr3sG5N/InLyOVSsp4wPx2hXi5iEG4tJ4KMIAh2kZ6ejpeXF3PnziUvLw+dLn9al9xdUeRu2w8qJd5THi8zjbj0PDxcNETWd8FgtnA+ORu1Qllil2q90cz204mF2mt0avFIczZ2qS778ccfadeuHY0aNWLw4MH8/vvv9khWEIRa4rfffqN+/frs3r0buVxuCzCSJNlKMR4jB6AKK7t95fpxOhqlAqPJysGLKazZd4mDF1NwUf73yNp+OpE/TiYSnaxny4lEtp5MEFVrTqhKYX/RokW0b9+eKVOmsH79eho1asThw4eZPXs2Fy5cYPz48fbKpyAITurHH3/koYceolevXrRv377QZ7l
b95H3zxFkWjXekx+7YVpFx+nojRbiswxIEmQXGYx5NS0XrVJhG0Nz5HI6OUarqFpzMlUqySiVSr7++msuX77MsGHDGD58OD/++CNDhgzh/fffR5KkGyciCEKttX79egYPHsyAAQPYsGEDWu1/U/Pnl2I+A8Bj9CCUQX43TM/PXc35xEx+PnyV84mZZBpM5JnMyGRgMFuIScuxHRtaZAyNTqMoVAoS89s5hyqVZJ577jkAzp8/z6efforJZOLo0aMcOXKEuLg4WrRogZubG/v377dLZgXBkZxpOYHawGAw8MorrzBkyBBWrFiBUln4cZKzeSeGw6eR6VzwnvBIudJcvTeGn4/EYZLgeFwGwZ4aknPMqBVyDGYrzYP+CxwFY2YK2mSCPV04FZ8l5rdzMnZpJfvggw8YMmQIXbt2pXXr1uTk5NCqVSv+/vtvMjIy7HEJQXA4sZxA+ZnNZjQaDX///TcBAQEoFIpCn0sWC6nv5LfFeD41BIWfd7nS3Xoqnsw8C64aBVl5FiRrLo0CPdGq5OSZrHi5/rdSpk6tLDSGxmSxolUpxPx2TqZSr2mXL18utN2qVSv27NlDr169SEhIoF69emzatAkAT0+xBKpQOxRtdBbVLSVbsmQJXbt2JScnh+Dg4GIBBiB741ZMp6KRe7rh9ezwcqcts02NCSChVSkJ8NAS6OFCgIeWBn6upZ5bMFtBl4Z+NAv2EKVQJ1GpkkxYWBje3t60bduWtm3bEhERQUREBMHBwWzYsIEVK1bYO5+C4HBiOYEbe++995g8eTITJ0609SArSjKZSZu7DACvZx9G4elOcnYen2w7R3SSnnB/HU/d1Qg/t+JLK3dv6s/ltFwMZisuKgUD2obQMtRTTCtTi1UqyFy4cIFDhw5x6NAhoqKiWL9+PVevXkUmk+HhIaoXhNqpnq+O6ORs2wOtnq9Yvvd6s2fPZvr06UydOpXZs2cXWgfmellrfsEUfQW5nxeeTzwEwCfbzvHbiUTUCjmnE7MBmD6gVbFzb2/kz5mE/HYxL52aOxr74aXT4KVT461Ti9JJLVSpIFO/fn3q16/PAw88YNu3Z88eHnvsMebOnWuvvAlCtYpJ0ZOZZ8HXTUtGrpmYFL1ok7nmwIEDTJ8+nRkzZvD666+XGmAkg5G0BV8C4P3CSORu+YE6OkmPWiEn3M+V6OQcopP0JZ5vskr0bBFsG3x5KSWXS6l5op2sFrPba0GXLl14//33mTVrll3SW7x4MeHh4Wi1WiIjI9m5c2eZxxsMBqZPn05YWBgajYaGDRuybNkyu+RFuDWINpniCoYhdOjQgQMHDvDGG2+UGmAAMlf8gPlqIoogPzwee8C2P9xfh8FsJTo5B4PZSrh/yaVEF5WCA5dSWLPvIgcupWCwmMXPpJarVJAxmUwl7m/cuDHHjx+vUoYA1q5dy8SJE5k+fTpRUVF069aNvn37EhMTU+o5Q4cO5c8//+SLL77g9OnTrF69mmbNmlU5L8Ktw1unxmC2iDaZa6xWKy+88AILFiwAIDIysuzj9Xmkvfc1AN4vPobc5b+eYMM61MFTIyM1x4CnRsawDiUvo3wlTU9CRh7ZeRbi0/PQGyziZ1LLVaq6zNXVlRYtWtCuXTsiIiJo164dISEhfPjhh/Tu3bvKmVq4cCFjx45l3LhxQP7MAr/99htLlixhzpw5xY7/9ddf2b59OxcuXMDHxwfIr9Iri8FgwGAw2LYzMzOrnG+hdhNLPP/HarUyfvx4Pv/8c5YuXVquczK++A5LUirKsGA8RvQv9Nmq/ZdJ1pvRKhUk5ZhZtf8yb9znVSyNq2m5qJUKfFzVpOYYsUgQUddL/ExqsUqVZLZu3coTTzyBSqVi5cqV9O3blyZNmvDhhx9iNBqZPn06a9eu5eTJkxVO22g0cvDgwWLBqnfv3uzevbvEc3744Qc6dOjAvHnzCA0NpUmTJkyZMoXc3NxSrzNnzhw8PT1t/+rWrVvhvAo3F9EFNp/ZbGb
06NF88cUXLF++nCeffPKG51gys0n/cBUA3lNGI1OrCn1+/GoWViv4uKmRpPztkuSZzOy/mMpPh6+y/2IqZrNF/ExquUqVZO644w7uuOMO27bVauX06dO2HmcHDx5k2bJlJCYmYrFYKpR2cnIyFouFwMDAQvsDAwOJj48v8ZwLFy7w999/o9Vq+f7770lOTuaZZ54hNTW11HaZadOm2RZTgvySjAg0N3Yzz3wrRvzne+edd1i5ciWrVq1i2LBh5Ton45NvsaZlomochvuQ4rUZvm4qziZKpOUYMVslfN1UJaQCx69mkpVnRCaDPJOF41dFDUNtZ5eng1wup3nz5jRv3pyHH37Ytj8hIaHSaRZtXJQkqdQGR6vVikwmY+XKlbbBnwsXLuShhx7i448/xsWl+LrgGo0GjUZTbL9QtoKZb7VKBSfi8t9Gb5aVCx094r+2BLEJEybQuXNnevbsWa7jLakZpC9eA4DPy2OQlTA4c1SX+mTmmUnNNhHur2JUl/olphWbnotSrsDTRUlGrpnY9NJrI4TawaG/4UVLI+Xh5+eHQqEoVmpJTEwsNb3g4GBCQ0MLzS7QvHlzJEniypUrFc6DULrrZ751USm4mnbzPAQc3busIIjFZxg4dDmd89fGiziD3NxcRo8ezdmzZ/Hw8Ch3gAFI/2g1UrYedctGuA68q8RjIuv78lr/lrx+Xwte69+SyPq+JR5Xx9cFuQxyTVbksvxtoXZzutcotVpNZGQkW7ZsKbR/y5YtdO3atcRzbr/9dmJjY8nO/u+P9syZM8jlcurUKbkXi1A5RWe+DfW+eR4Cju5d5qxdpHNycrjvvvtYu3ZthV/KzAkpZHy+HgCfV8chk5f8SClve9djt4fTtq4nQR4a2tb15LHbwyt2M4LTccrK9MmTJ/Poo4/SoUMHunTpwqeffkpMTIxtfZpp06Zx9epV2/Q1I0aM4O2332b06NHMmDGD5ORkXnrpJcaMGVNiVZlQeUVnvr2ZpvlwdO8yZ5y2JjMzkwEDBhAVFcWvv/7KnXfeWaHz0xd9jZRrQBPZAl2vkl8CK6JFsCd3NvG3Lb/cIljMfVjbOWWQGTZsGCkpKcycOZO4uDhatWrF5s2bCQvLX60oLi6u0JgZNzc3tmzZwvPPP0+HDh3w9fVl6NChdhsYKvyn6My3N5OCt21HcbYu0pIkMWjQII4cOcKWLVvo3Llzhc43XUkgY8UPAPi8+kSZgzTLa/e5ZE7FZ6NVKjgZl83O00k0CnR3+nYsoXQySawsBuS/0Xl6epKRkSHmXxNuGX/88Qfe3t43HGhZksRJc8n65ie0t7cj5Pv37RJkPt9xnuhkvW21S1eNnIYBHtfWk7EQUddLTCvjJMr7zBSvBIJwi0lISOCNN97AarXSs2fPSgUY04UrZK3+BQBfO5ViAPzc1JxLzOLno7GcS8xCpZA7ZTuWUH4iyAjCLeTq1at0796dzz//nNjY2Eqnkzp/GVgs6Hp2Rtuxtd3yJ0MGMpAkGRISrmqlmFamlnPKNhlBEKqmpDE5sVcu06NHD0wmEzt27Kh0z0vDyQtkf/cHAD5Tx9kz28Rl5eHrqvlvWhmseLooxfILtZgIMoJwEyo6sDQxIZ5H7++FUqlkx44dN5zbryxpc5eBJOE6oDuatk3tl2lAIYP4jDzS9UZyTVZCvTRi+YVartxB5vopWG5k4cKFlcqMIAj2cf2YnPiMPOQ6Tx555BGeeeaZKo0dMxw+Tc7P20Emw+eVsXbMcb5GgW60Sc/FbAGlArxcNYXuQ7TJ1D7lDjJRUVGFtg8ePIjFYqFp0/w3mTNnzqBQKCrViCgIgn0VjMn550AUycmJjBw8gNmzZ1c53ZTZnwHg9lAv1M3sP1Ay2ENHsyBPNEoFBrPFNr1MSW0ytWWanltduYPMX3/9Zfv/hQsX4u7uzldffYW3tzcAaWlpjB49mm7dutk/l4IgVEjDADeOHznEm+OH0qhxU954qnwTXZYld+8Rcrf
+AwoFPi+NsUMuiys6lqier46YFH2JY4scPdecYB+VGicTGhrK77//TsuWLQvtP3bsGL17965Sr5WaIsbJCDeTf/75hz59+tC0aVN+/fVX28tgZUmSROwDE8jbfQj3R+8jYOHLdspp5e05n0x8hsFWlRbkqaFLQ7+aztYtw6HjZDIzM0ucYTkxMZGsrJLXiRCEW5XJYuVUXCZ7zidzKi4Tk8Xq0Ovt2rWLnj170rp1a7Zs2VLlAAOQu+MgebsPgVqFz4uPVT2TdiBWMq0dKtW7bNCgQYwePZoFCxbYpqLYu3cvL730EoMHD7ZrBgWhtqvuah1/f3/uu+8+PvvsM1xdXaucniRJpM7+FADPx+5HGVrx2dUdwdmm6RFKVqkgs3TpUqZMmcLIkSMxmUz5CSmVjB07lvnz59s1g4JQXcrbkFzRBueiPb0c1UNqx44dRERE0KRJE1atWmW3dPW/7cLw70lkOi1eEx+1W7pV5ei55gT7qFR1mU6nY/HixaSkpBAVFcW///5LamoqixcvtsubkyDUhPKu91LRdWGqo1pn48aN9OzZk3fffdeu6UpWK6nvfA6A57gHUQb42DV94eZXpcGYrq6utGnTxl55EYQaVd4SR0VLJo6u1lm7di2PPPIIgwcP5vXXX7dr2jmb/sJ4/Dxyd1e8nhth17SFW0OlO5Xv3LmTkSNH0qVLF65evQrA119/zd9//223zAlCdSpviaOiJZPyLthVGStWrGDEiBGMGDGCVatWoVKp7Ja2ZDaTOm8ZAJ7PDEPhLaqmhIqr1G/7d999R58+fXBxcSEqKgqDwQBAVlaWXQZ8CUJNaBjgRkRdL4I8NUTU9Sq1xFHe46pDYmIiY8aMYfny5SiV9p0lKuvb3zGdi0Hu44nXU0PtmnZp9EYzvxyN5fMd5/nlaCx6o7laris4TqXGybRr145JkyYxatQo3N3dOXz4MA0aNODQoUPce++9xMfHOyKvDiXGyQi1ydGjR2ndOn/2Y0mS7DbVfgHJaCKmyyOYY+LwefNpvKupquyXo7H8cTIRrVJBrslCrxYBN+0iebWdQ8fJnD59usRlWj08PEhPT69MkoIglNP8+fNp06YNe/bsAbBbgJEsFnJ3RZG14Q+SZy7BHBOHIsAHzzHVNyzhalouWqWCBv6uuKgUXE3LrbZrC45RqfJ1cHAw586dKzaT699//02DBg3skS9BEErw9ttv88Ybb/Daa69VeLnksmT/tJ3k6e9jiU0qtF/XqytyndZu17mRUG8XTsRlcSEph1yThVBvl2q7tuAYlQoyTz31FC+88ALLli1DJpMRGxvLnj17mDJlCm+88Ya98ygItzxJknjttdeYPXs2s2bNYvr06XZLO/un7SSMeQ1KqDjPWvUTup6dcRvQ3W7XK0v3pgEAtvVjCraF2qtSbTIA06dP57333iMvLw8AjUbDlClTePvtt+2aweoi2mQEZ6bX67nzzjt5+OGHefHFF+2WrmSxcKn9kGIlGBsZKEICCDu4DplCUenriBmTbz7lfWZWKsjExMRQp04d8vLyOHHiBFarlRYtWuDq6srly5epV69elTJfE0SQESqiuh6aVquV5ORkAgICMBqNqNX2HciZuyuK2Acm3PC4kI0f4HJ7u0pf51Rcpm1qHYPZQkRdr3KN1hfByXmV95lZqeqy8PBw4uLiCAgIoEOHDrb9KSkphIeHY7FYKpOsINQKJouVP07EExWTgaeLCi9d/tgUe09xYrFYePLJJ9m2bRvHjx9Hq7V/24g5IcWux5WmslPriOn8a79KBZnSCj/Z2dkO+UMQBGdyPjGbqJh0cgwWpGsNGfaej8xsNvPYY4+xZs0avvrqK4f8XVkys8nZsqdcxyoDfat0rYJF1Co6tU51zfsmOE6FgkzBEswymYw33ngDnU5n+8xisfDPP/8QERFh1wwKgrNJ0xvx1KmQAL3BApLRrvORGY1GRowYwaZNm1izZg1DhgyxW9oA1lwDmcs2kPb
+N1jTMss++FqbjLZz1aaPquzUOpUNToLzqFCQKViCWZIkjh49Wqh+WK1W07ZtW6ZMmWLfHAqCk/HWqfG67mHXvp59R/1HRUXx22+/8d133zFw4EC7pSuZzWSt3kzq/C+xxOU39KuahKHrfTsZH1+btfm6SgpJBjLAb9aEKjX6Q+VnTK7nqyM6OdvW26yer+7GJwlOpUJBpmAJ5tGjR/P++++LBnLhllTSW3l5GqNv1Iidl5eHWq2mU6dOXLx4EV/fqlVRFZCsVnJ+3E7qnM8wnb8MgLJOIN4vj8F9aB9kCgXayBbFxskYfX1QvDq+2rovQ/HvyGyxkplnwddNS0aumZgUvWiTqWUq3YX5ZiN6lwkVUZleTwU9rBRyuJKqJ9hLS9s63jQMcMOQq+e+++4jMjLSbtP1S5JE7rb9pP7vUwyHTwMg9/XEe+IoPB6/H7lWU/h4i4WD3+0k63IiymAf1ls9yTFJdGrow9Db6uHp4viqqmNX0vnjVCJGsxWVQkaguxYXtVIsseyEHNq7bM6cOQQGBjJmzJhC+5ctW0ZSUhKvvPJKZZIVBKdUUkCpTK+ngkZso8VCXKYBo0VCkmRkZWYwccxwjh49yqxZs+yS57yDx0mZ9Sl5f/8LgMzVBa9nH8Zr/FDk7iWv+SRTKHDr1p5zl9PZfS6JQ1cz8HZVsflo/lLrT9zZyC55K8vxuAzi0vPwc1MTn2FAJpMI8tCJNplarFIdzj/55BOaNWtWbH/Lli1ZunRplTMlCM6kpEXKru/1pFEqytXrqWCJgJgUPUgQ5qvDmJPJqCEDOXHiBH/88Qe33357lfJqPB1N/GOvcvXe8fkBRq3C86khhB1Yi89Lo0sNMAUKZpjWGyx4u6poFeqJi0rOpWR9lfJVbpLM1mNPQiLATes0M14LlVOpkkx8fDzBwcHF9vv7+xMXF1flTAmCMympG21lej0VPCBlgFIhQyGT8/3XnxB/JYatW7fSrl3lBzuaLseTNm8ZWet+A6sV5HLch92L90ujUdUNKnc6BQ30nRr6sPloAtFJOeSarIT5VU+De8sQD+IycjFZJEI8tbSpU75Bm4LzqlSQqVu3Lrt27SI8PLzQ/l27dhESIqblFm4uJQWUynTJLXiANwxw42x8Jhl5Zma+9RbqaS/QtHHlqqLMSWmkL/qajC83gtEEgGv/O/GZNg510/CyTy7D0NvyZ+24lKwnzE9n23a0psEeKBVyh60iKlS/SgWZcePGMXHiREwmEz169ADgzz//5OWXX7brvEqC4AzTipTWm6yyb9jxsVd59IEH+Oijj2jdpQtQ8V5k1qwc0pesJX3xGqSc/OnwtXe0x/e1J9FGtqxUvq7n6aKuljaYoqryvQrOqVJB5uWXXyY1NZVnnnkGozG/Llqr1fLKK68wbdo0u2ZQuLU5w7QiVX3wXR8oc5JjeWL4/UiSRGBgYIXTsuYZyPxyE2mLVmBNyQBA07YpPq89hUv3DnZfvEwQqqpKXZizs7M5efIkLi4uNG7cGI1Gc+OTnJTowuyc9pxPJj7DUKu7sBZ0XU6+epE3nhqGm6sLO7f9VaGJZCWzmax1v5E2bxnmq4kAqBrWxefVJ3C97y67BxdnKEEKzs2hXZgLuLm5cdttt1UlCUEo080wrUhSdh6JGbnMe+kpVFoXFq34vtwBRpIkcn7ekT+Q8swlABTB/vi8NBr3h/siU1bpT7hUzlCCFG4O5f4NnTx5Mm+//Taurq62OcxKs3DhwipnTBCg8nNeOZPsPDMXU/Tc++wsFDovtJ7lK4npdx4kddYnGP49CYDc2wPviY/iMXoQchfH1hqIiSkFeyl3kImKisJkMtn+XxCqQ3U2BDuiiujAgQPMe30mA1/4H67BbYhLz+VSqp5TcZmlpp936BSpsz4hd/sBAGQ6F7zGD8Xz2eEoPKonyN4MJUjBOZQ7yBTMW1b0/wXhZmHvKqLdu3fTt29fwhs3xUMFeRaJzDwLfm5w6HJ6sfSN52J
Inf0ZOT9uy9+hUuL52P14TRqFMsCn0vmoDGcpQYq2odqvQtVl5SGTyViwYEGlMyQINcWeVUTbt2+nf//+REZG8v2mH0jMlfH32SQa+LnRpq4nyVlGW/rm2ERS5y8na/UvYLGATIbbkN74vDwGVVjNjDurqa7EJU2QeTwuS7QN1WIVqi673sGDB7FYLDRt2hSAM2fOoFAoiIyMtG8OBaGalFZFVNG36ejoaPr27cvtt9/Oxg0bkB85R2hCCt2UWqL865CcZcxP35hH8psfk/nFBiRDfsDR3XsHPtPGoWnRsFru2dkULU3KZBJqhVK0DdVilaouW7hwIe7u7nz11Vd4e3sDkJaWxujRo+nWrZv9cykI1aC0KqLTcZn8eTIBo0VCpZDRs3kgrep4lZpOeHg4n376KffqAki64zHb9PlKoEOgH5lPj8QtKRnliu/JyMoBQNulLb6vPYW2Y2uH3qOzK1qazM4zcuRKCtl5Jty0KoZE1q3pLAoVVKlxMqGhofz++++0bFl4ZPGxY8fo3bs3sbGxdstgdRHjZITSrN0Xw7+X0/FzVZOUbSCynjfDOhbvgvzdd9+RlZXF448/TvZP20kY81qhRcBKom7VGN/XnsKlR8ebeiBleUuDp+IyOXQ5HY1SgcFsISPXyP7oVAqeUoPahzKgbWiVriHYR3mfmZX6CWRmZpKQkFBsf2JiIllZWZVJUhCcl0yyBQsZsvztIlatWsWwYcP4448/sJrNJE9/v+wAo1Dgv/R16vz5Obp7Ot3UAQZKnsm6JAWzQBfMuowEgR4udG3kR5CnC/EZeVW+hlC9KhVkBg0axOjRo1m/fj1XrlzhypUrrF+/nrFjxzJ48GB751EQalTLYE+CPDVIQJCnhpbBnoU+X758OSNHjuTRRx/lq6++wvDP0UIrTJbIYkEV5I9Mfmu8aZd3aYSCDgddGvrRLNiDer46ck0WLiTlkGuyEOrtUuVrCNWrUsOFly5dypQpUxg5cqRt7IxSqWTs2LHMnz/frhkUhJpW1szAa9euZcyYMTz11FMsXrwYuVyOOe4GAeaamFOXSQ2qe0tU7VR23E33pgEAXE3LJdTbxbZtz2sIjlWluctycnI4f/48kiTRqFEjXF3LXhDJmYk2mfKJz9Az/9eTRCflEu7vwkv3NifIs3rWGnFGiYmJLF++nJdffhksFrK//5OUWZ9iiU284bmn33kNQ7tWGMwWIure3OumVEd7iWiTqV7lfWZWKcjcTESQKZ8X1x7kr9MpKOUyTBaJHs18WTDs5ui2XpGH1JIlSxg4cCChoaH5k1eu30L6eyswXbiSf4BMBqX9acnAEuDH8eUfEOSjc8qJP53lge0s+RCKc2jDP8DOnTsZOXIkXbp04erVqwB8/fXX/P3335VNspDFixcTHh6OVqslMjKSnTt3luu8Xbt2oVQqiYiIsEs+hMKik3JRymXU9dGhUsiITsqtluuaLFZOxWWy53wyp+IyMVmsdr9GeRqOJUnizTff5JlnnuH7774jc+XPxHR5hKTnZ2O6cAW5jyc+058k4OPp+UtgFmnPlwr+vfIkBkly2qodZ2lEd5Z8CJVXqSDz3Xff0adPH1xcXIiKisJgMACQlZXF7Nmzq5yptWvXMnHiRKZPn05UVBTdunWjb9++xMTElHleRkYGo0aN4p577qlyHoSShflpyTVaOJ+YRa7RQpiftlquWx0Pmxs1HEuSxNSpU5k5cyZvDnmU+776m6SJ72C+GIvczwufN8YTdnAd3hMfxX1IHwKXzUIR7F8ojTxfH45NeQ7T3V1oFeLhtGvXO0sjurPkQ6i8SlWXtWvXjkmTJjFq1Cjc3d05fPgwDRo04NChQ9x7773Ex8dXKVOdOnWiffv2LFmyxLavefPmPPDAA8yZM6fU84YPH07jxo1RKBRs3LiRQ4cOlfuaorqsfLafjuedzadIyTHh66piar9mdG9a/jXkK6s61pUpOkajaDvJ5IkTee/993k9pA2
PGfP3K/x98HruYTweux+5a/GeT5LFwpmf93Dp1BVSdK4c9AslPMADXze1U7fD3Oi7uNXyIRTn0PVkTp8+zZ133llsv4eHB+np6ZVJ0sZoNHLw4EGmTp1aaH/v3r3ZvXt3qectX76c8+fP88033zBr1qwbXsdgMNhKYJD/hQk3Fp9hpE09H9vAxPiM6nmzrI6eQw18XZDtP0LO1SRcQ/0JjwgG8lejzPrmJxps2M1M14aMMHqgCPTF6/lH8Hj0PuS60ktzMoWCBv27It2WTczZJMLNUrG5y5yRs0yQ6Sz5ECqvUkEmODiYc+fOUb9+/UL7//77bxo0aFClDCUnJ2OxWIotTRsYGFhqCens2bNMnTqVnTt3oiznIk5z5sxhxowZVcrrLakcAxMdwdEPm+yftpM8/X2UsUkUjIKJDfZD3f02vtv4PX1zNfSVaVA0aIP3hEdwf2RAudd0uX6yyUOX0/+bu8zJ2mGuV1MTZDprPoTKq1SQeeqpp3jhhRdYtmwZMpmM2NhY9uzZw5QpU3jjjTfskrGiI6AlSSpxVLTFYmHEiBHMmDGDJk2alDv9adOmFZpZOjMzk7p1xbxIN9Iy2JPY9DxMFqnEgYmO4siHTWlTwOTFJvLcp+/xqzGZho17cPvU53Af0Q+5tnILhom3cuFWVKkg8/LLL5ORkcHdd99NXl4ed955JxqNhilTpvDcc89VKUN+fn4oFIpipZbExMRipRvI72xw4MABoqKibNe2Wq1IkoRSqeT333+nR48exc7TaDRoNI5dXfBmVNbAxNpIslhKnALGIFmZmHWKbcY0Pgpqx72HfkKuK320eVkyco2s2x/DpWQ9YX46ht5WD51aKbrmCreESi8Q/r///Y/p06dz4sQJrFYrLVq0wM2t6g8ctVpNZGQkW7ZsYdCgQbb9W7Zs4f777y92vIeHB0ePHi20b/HixWzdupX169cTHh5e5TwJ/7nZqi/y9h4pNgVMnmTh2cxT7DGls9ijOXebXTFEncLl9naVusa6/TH8ciwBrVLO8bj8uf26NQ6w6wJpguCsKhxkTCYTvXv35pNPPqFJkyZ06NDB7pmaPHkyjz76KB06dKBLly58+umnxMTEMH78eCC/quvq1ausWLECuVxOq1atCp0fEBCAVqsttl8QijInpBTbJ0eGTqbgE48WdFN7l3pceV1K1qNVymkY4Mb5xGwuJetpFWq/BdJuJmLw5c2nwkFGpVJx7Ngxh84aO2zYMFJSUpg5cyZxcXG0atWKzZs3ExYWBkBcXNwNx8wIwo2YLFZilVrbH0G21cwVq4FmSlc+9GhW6NhYpZbU88mVevCF+ek4HpfF+cRsck1Wwvx0Yp6tUth7CWyh5lVqnMyLL76ISqXinXfecUSeaoQYJ1M+N9Ob5qm4TOI+XE29z74h02pmdOZxUqxGtnhHopLl35MkA6OvD1sWvUOonxtmi1ThsRqiTab8qmM8lGAfDh0nYzQa+fzzz9myZQsdOnQoNjHmwoULK5OsUAv8eTKOtzYdIyvPgrtWwVv3t+LeViUvIuXs9Bv/pN5n35BmNfF4xjGuWA185dmqUIBBgn9HDONqlpEgH6lSo849XdQ8cWejYvvFG3pxooR386lUkDl27Bjt27cH4MyZM4U+u9kXX7rVzf7pOPFZZmRAjsnM7J+O18ogo/9rHx4zPyDZauQR6SLJLgq+CexG8wyT7RhrgB+XnxyFqkskskvpxKToqeujEw8+BxLdvG8+lQoyf/31l+3/C2rbRHC5NaTozQColTIMZsm2XZvk/XuC+MdfQ2Y2E9+5NZbT6axc9g29Okdg2X8Mc0IKykBfLoXVJyk2C6VMTpCnhhBPF9reYJ4xvdHM9tOJhdY/0akr3YnzlnOz9V4UqtCF+YsvvuC9997j7NmzADRu3JiJEycybtw4u2VOcD4+OjU5RgMGs2Tbrk2M52KIe/hl4rLSCeneiYFrFjBAqUChUOQfcF035YYWKyjyq8dahXq
Uq91k++lE/jiZiFap4MS17sp9W4c47H4EwdlVKsi8/vrrvPfeezz//PN06dIFgD179jBp0iQuXrxYrrnDhNrp3laBrNgTg9kCSkX+dm1hjksibshkopPiGZV7mgF+EUxMzSu18b0yb9VX03LRKhU08HflQlIOV9OqZykEQXBWlQoyS5Ys4bPPPuPhhx+27Rs4cCBt2rTh+eefF0HmJubn5sI9zYPwcVWTmmPEz61yo+CrmyU9i7hhUzhzMZpHc07iUTeE4eMn2b27bKi3Cyfissq1Jr0g3AoqFWQsFkuJgzAjIyMxm2tfHb1QfvV8dZxOyEaSQKtSUM/X+ZdeturziH/kFY4dPcqo7BP4h4fx5/ZtXNSrUF/XXdYeAyIrsib9reRm6vouVEylgszIkSNZsmRJsa7Kn376KY888ohdMiY4p9r2EJXMZhKeeJO8fUf5XZZFaKMGbNmxHX9/fzLiMu3eXVanVoo2mBKIQZa3rio1/P/+++907twZgL1793L58mVGjRpVaHZjMWbm5mKyWLmSpudSsh4JySHLINuLJEkkTZ5P/K878NTp+N/ab7G2aWSbY89Zu8vejG/9169wKabRubVUeZzM+fPnAfD398ff359jx47ZjhPdmm8+JU32WNJAQ2eQ+vYnbF2xmqcyT/LNi/No0DWi0OdFG/ZNFiun4jJr/OF+M771i0GWt64qj5MRbi0lTfZoL/Z8g09fsoZfFnzMk5kn6NC8FXdPeOKG5zjLw/1mfOt31lKj4HhilJhQISVN9mgv9nrIZ637lQ1TZ/FM5klub9aSnw7uxcXlxr28nOXhfjO+9YtBlrcuEWSEChnUvg6JmXlEJ+lpGeLOoPZ17Ja2PR7yOX/sJW7CHOblRHN34xZ8H7UPrVZbrnOd5eEu3vqFm4kIMkKFJGcZaRToSYsQbwxmC8lZRvzcyvcQv5GqPuTzDhzn6pjpKCxW1o94liafvY26AqufVtfD/UbVguKtX7iZiCAjVEhcRi6n4tIxWSRUChmBHhq7PRCr8pA3nrnI4gGP8FnyWTYMeJyWn/8PmVpVoetX18PdWdp+BKE6iCAjVMjJ2Ey2n07CYgW5HPzdtNzdzD5Ty1T2IW++msCCHoOZnnCYYaFNafTVOxUOMFVVkU4LztL2IwjVQQQZoUJOxWaSqjeikMmxWK2cis2s0fxYUjOY3W0Ab16N4tHAxnwRtRuVp3ul0qpK77aKlE6cpe1HEKqDCDJChSRk52GySCiUYLJKJGTn1VherDm57Lz/SWZEH2CcXyM+3r8Dlb9PpdOrSjVWRUonomFfuJWIICNUiL+rBjn5b/3ya9s1wWo0ET/mdeqeiuX7Ol3p8/vXqOsGVSnNqlRjVaR0Ihr2hVuJCDJChfh7aFCrFCABsvzt6ma1WJjcpSfWk9FM9GtC7/XvoWneoMrpVqUaS5ROBKFkIsgIFeLvoaVFsDsalQKDyYK/h326L5eXJEk8d0dvlvy7g6luDQj8fCbajq3tknZVAoUonQhCyUSQESqknreOf5UKkGRolArqeVffVP9Wq5Un7+rDF3u38oZrA1787CNce3e1W/qlBYqbccJKQaguIsgIFVLH24UgTy2ZeWbcNUrqVOOiXO+OeYZlO//gf26NeHrODDyG962W64pxLYJQeSLI1FI19Xada7YSGeZraxzPNVfPVP85v++i3y/H8fNowcCJz+D93IhquS6IcS2CUBUiyNRSNfV2Xd1jPEwmE08PHcGDey7TxKqi/8iH8X3zaYdesygxrkUQKk9ULNdS179da5SKanu7ruerw9NFSUp2Hp4uSocuv2wwGBh8bz++3vgdV3Kz0PXsTMCiqcjk1ftr2zDAjYi6XgR5aoio6yV6jglCBYiSTC1VU2/XMSl6MvMs+Lppycg1E5Oid0gJSq/XM6jfALbv2M4Sjxb07noHgZ/PRKZy3K9saVWQjug5JjoTCLcKEWRqqZoal1Fd7RMPPzSEnTt38Ll7C+5s1Zb
glXORuzq2k0F1VkGKzgTCrUIEmVqqpsZlVEcJypqt57FkBQ+7t6BzWGNC1i1A4eNp9+sUVVoAdUSpQ3QmEG4VIsgIFeLIElRaWhrvzpvHU2fzaBudhjygLiHfLkAZap9Znm+ktADqiFKH6Ewg3CpEkBEqxFElqKSkJHr37k3MqTP0cGlOuLs3wavmoW5S3+7XKk1pAdQRpQ4xDY1wqxBBRqgQR1QdxcfHc88995B06TLfaJtSX+NG4Bdvo+3QssTj9UYz208ncjUtl1BvF7o3DUCnrvqvcmkB1BGlDjENjXCrEEFGqBB7Vx2lpqbSvXt3shKSWKlqTAOljoAPpuHas3Op52w/ncgfJxPRKhWciMsCoG/rkErn4UZEqUMQKk8EGaFC4jL1nIxLx2wBpQICPdVVCjLe3t4MahlJ76QzhClc8J35LO5D+pR5ztW0XLRKBQ38XbmQlMPVtNxKX788RKlDECpPBBmhQs4lZHPkSiZapZxck5VQLxfublrxdM6ePcvZs2fpLvfgiV1XQeGC1/Mj8Hp6+A3PDfV24URcFheScsg1WQitxvnTBEGoGBFkhArJM1uQJCt5JgmQyDNbKpzGiRMn6NmzJ35u7nynD0VuteI+vC8+r48v1/ndmwYAFGqTEQTBOYkgI1RIpt5ESrYJtVKOwWwlU2+q0PmHDx+mV69eBHr78IUhBLnBhK53V/zfexmZTFauNHRqpUPbYARBsB8RZIQKyTNYuJKWi5X8ie/yDOUvyfz777/07NmT+qF1+MJUB4+cLLQdWxP42QxkSvGrKAg3IzFZklAhPx25QsHk/tZr2+Xl6+vLPd3u5CtNEzxSslA3b0DQyrnIddW7uqYgCNVHBBmhQnKMZW+XZO/evaSmplLX15/39P7oLiWgrBtE8Np3UXi5OyajgiA4BVFHIVSIWSp7u6gtW7Zw//3388TYcbwSp8Bw6BRyX0+C1y1AGezvuIwKguAURJARKsRDA+mGwtul+fnnn3nwwQe5p0cPJuW4k7t9BzKdC8Gr56NuVM/xma2g0mYzENPyC0LliSAjVIhaIQesRbaL+/777xk2bBj9+/fno0a3k/flJlApCfpyFtp2zasptxVT2mwGYlp+Qag88TomVEigZ/7AR1mR7aISEhJ48MEH+aTrffkBBgj4aDq6uztWRzYrpbTVRmtqFVJBuBmIICNUiEqRH16kItsFjh49CsD48eNZ2nc4WfOWA+D7vxdwH9yz2vJZGd46NQazpdhEmKXtLy+TxcqpuEz2nE/mVFwmJov1xicJwk1CVJfZ0a1Qd3/5WnVRSdtLly7l6aef5vfff6eLUUPySwsB8Jr4KF5PPlSt+ayM0ibCrOoEmaK6TbiViSBjR9X1MKnJYKY35b+FK2Rgkf7bXrRoEZMmTWLChAnc7uJL/OiXwGrFfeQAfF59olryVlWlTYRZ1QkyxSqYwq1MBBk7qq6HSU2+GbtrleRkm7FI/22/8847TJs2jZdffpkZI8cRd//zSAYjrv264T//xXJPF3OzclHKOXgxhSyDGXeNkoci69R0lgSh2jhtXc7ixYsJDw9Hq9USGRnJzp07Sz12w4YN9OrVC39/fzw8POjSpQu//fZbNeY2X1Xr7surJhuiH+5UHw+NDK0CPDQyhrQP4aeffuLNN9/k7fETiB8+BWtWDtrObQlY+qaYLga4kpZLfJaBbIOFuMw8rjh4aQJBcCZOGWTWrl3LxIkTmT59OlFRUXTr1o2+ffsSExNT4vE7duygV69ebN68mYMHD3L33Xdz3333ERUVVa35bhjgRkRdL4I8NUTU9XLY4lbVFcxKElnfi4YB7gS4awnVmujUOICtW7fy2rMvED9sCpbEVNQtGxL0zRzkLmUMormFxGfmEeiupWtDX4I8XIjPzKvpLAlCtZFJknSDMdvVr1OnTrRv354lS5bY9jVv3pwHHniAOXPmlCuNli1bMmzYMN54440SPzcYDBgM/40qzMzMpG7dumR
kZODh4dyNsjXZJjP6811sPZtG2p+foj+9i2GzV/P147dz9YEJGI+cQVkvmNCfF6MM8quW/NQGvxyNZcuJRFxUCnJNFnq1CBCzSAu1XmZmJp6enjd8ZjpdXYbRaOTgwYNMnTq10P7evXuze/fucqVhtVrJysrCx8en1GPmzJnDjBkzqpTXmlKTKzVuPZtK2m+LyTr8K53bDSLvn9PEbf0J45EzyP288qeLEQGmELH+jXArc7ogk5ycjMViITAwsND+wMBA4uPjy5XGggULyMnJYejQoaUeM23aNCZPnmzbLijJCKWzWCwo171N9sX9vOPWmIcuJ8HldeQBaNSErHkXdUPxHRYl1r8RbmVOF2QKFO2RJElSuXoprV69mrfeeotNmzYREFD6G6NGo0Gjqb1tBjVRZXZgyQoSLv3Lu25NuF9b5Ls1GDFdjkfTthJrMQuCcNNyuiDj5+eHQqEoVmpJTEwsVropau3atYwdO5Zvv/2Wnj2de3R5VRV0Y1bIZOy9kEKIpwttr3U2sHewMRqNyIHAz35kq3cHfOWq4gfJIPm1D3DtewcyhcKu1xcEofZyut5larWayMhItmzZUmj/li1b6Nq1a6nnrV69mscff5xVq1bRv39/R2ezxhV0Y7ZIVuIzDMSk6jl0OZ3zidl2vU5eXh6DBg3iyYeGY4lNKjnAAEhguZpI3t4jdr2+IAi1m9MFGYDJkyfz+eefs2zZMk6ePMmkSZOIiYlh/PjxQH57yqhRo2zHr169mlGjRrFgwQI6d+5MfHw88fHxZGRk1NQtOFxBN+ZLKXokJOr56uw+ZiYnJ4f77ruPv/76i4ER5ZvY0pyQYrfrC4JQ+zlddRnAsGHDSElJYebMmcTFxdGqVSs2b95MWFgYAHFxcYXGzHzyySeYzWaeffZZnn32Wdv+xx57jC+//LK6s18tCsbgyGQSaoUMhVxm1zEzWVlZDBgwgIMHD/LLL79wG67ELfn5hucpA33tcn1BEG4OTjlOpiaUt8+3s3FUB4C5c+cye/ZsfvnlFzpHtCd+7Ovk/rEXif+m+S9EBoqQAMIOrhNtMoJwCyjvM1MEmWtqW5BxVHAp6MVnsVi4cOECDQKCiH9kKnn/HAG1EsloRqJIPeu1qBO4bBZuA7pXOQ+CIDi/8j4znbJNRrixgt5l8RkGuzX4JyYm0rlzZ/78808UCgXhnr7EPvACef8cQe7hRsh377Ph0UdJdnUvdJ4iJEAEGEEQSuSUbTK3ksqWSOw943NsbCz33HMPGRkZBAcHY4qJI+6hyZiir6Dw9yF43QI0rRqxeONllg59mtYJl/HR55Cqc2XjsqdEFZkgCCUSQaaGVXbafneNkr0pKZxNzEKlkNE4wLXSeYiJiaFHjx4YDAa2b99OmFXJ1f7PYIlPRlkvmJBvF6Jq8N/09Fa5nMPBYbZtEWAEQSiNqC6rYVWZtl9OfnNIVVZrkSSJESNGYLVa2bFjB3WzTFy97zks8cmomtYn9KePCwUYQRCEihAlmRrmrVNzKUVf4Wn703KNKBUKXDVyDGYrabmVqy6TyWQsW7YMFxcXfC8kEDvqVSR9LprIFgSvmofCx7NS6QqCIIAoydQok8WK2WJFJpMwmi20CvEo9xo0qVlGdpxJYPPROHacSSA1q2JB5tixYwwaNIisrCyaNGmC9+ELxI14GUmfi0v3DoSsf6/EAOOjlZW5LQiCcD1RkqlB5xOzOR6XhVqhxGC2oFTIy9Xob7JY2RudTHymARe1ArkJziVllfu6UVFR9OrVizp16mAwGJA2bSfpxflgteI6oDuBS99Apim5ROXvqSU1L7fQtiAIQmlESaYGVbY95nxiNlfSclHK5GiUcowWiQtJOZyKy8RksZZ57r59++jRowcNGjRg69atKNb8TtKkuWC14v5IfwI/n1FqgAFILLJ0cNFtQRCE64mSjB1VtDtypdtj9EbCfV3JMVpIzzFhtkj4uKk5dDkdKL13Wnx8PL169aJVq1b
8/PPPWD5cQ9oHKwHwen4EPq+Pv+FyCmnGsrcFQRCuJ4KMHZ2Oy+SPU4kYzVZUChm9mgfSqo5XiccWtMdYrVYupeoJcHPBbLFislhvWGXmrVPTKMgDuULGybgsAj003NM8kLQcc5mloaCgIJYuXcqAfv3Im/kpmSt+AMDnjfF4P/9Ipe9bEAShNCLI2NHxuAzi0vPwc1MTn2HgeFxGqUHmfGI2R65mcDYxi3OJOTQMMCHJQFmOpZULOgc0DnCjU7gvqTkm0nLMpZaGfvvtNy5cuMDTTz/N8AcfIuHpt8n54S+Qy/F/dwoej95X7nvUAIYi24IgCKURbTL2JMmQyJ8KTkICqfSqpzS9kTS9kaQsExarRGKWkfRr+25EpZDTMMANb50arVqO3mgiISsXTxcl9Xx1hY794YcfGDhwIL/++ivmzGziHnklP8ColAR+9laFAgxA0Rhmp0mfBUG4SYkgY0ctQzwIudbbKsRTS8uQ0ksk3jo16XoTVklCIZeBBOl6U7nbZQpmCjhxNZtTcdlYrBIZuWZirs0aAPDtt9/y4IMPct9997Hm0y+IHzaF3G37kelcCF41D7eBd1f4HkWbjCAIFSGqy+yoabAHSoW8UMN/aRoGuNE21JPkrDysagUapYy2oZ7lHidT0DNNo5SjVSnQKhWFeqj9/PPPDB8+nIcffpjP33mXpCEvYjx5AbmXO8Gr56Pt0NIu9ywIglAWEWTsSFWO9pTrj63v50qYn5uto0B9P9dyj5NJyMjlp8NxZOQakMnk+LupcdUobSWhrl27MnPmTF4c9igJD0zAfCkORaAvwd8uRNO8QZXuUxAEobxEkKlBSTl5pGQZSM0xYrJY8XVT0zTY44aB5nxiNnsvpHAxJQejxYpaLsNotRJR14u/N3+LpkcPwsPDmfLAMOLufx5LYirK+qGErF+IKiykmu5OEARBtMnYlcli5VRcJnvOJ5drYOS5hGz2XUzl8JV0Dl1J5/fj8ZyOy7zhddL0RuIy8rACWpUCqyTDYLby86rPeGLcONasWUPuP0eIvRZg1C0b5k90aYcAE6Are1sQBOF6oiRjRxWdtt9ksWKxWrBKEiqFjAy9ieOxmaV2ey7grVOTZ7KQmWtCrZAjIfHnqiVsX/0xr776KhNuu5u4IZORcg1oO7YmaNVcFJ7uZaZZXtnGsrcFQRCuJ4KMHVV4ITEJ0nJM5BgtyGVytAo5Zqns0g/kdxpoF+ZNUlb+iJXYv77h9C/Lefvtt3mhZWfiH50KZgu6ezoTuOxt5Dr7zS9msZS9LQiCcD0RZOygYDqZmBQ9sRm5WKwSZqv1ht2RkzIMZOWZMVoBLBhy8wg+dYasq+dQBvqi7dymxAXBVAo5zQLdOXwxldALF6gj09J96BM8F9KCxKdmgiThNrgnAR++ikytsuu9KhVgMBfeFgRBKI0IMnZQUE2mUMiQARbJQkRd7xt2R953KRmLBAoZ3B59mmf/+QO/7CwSr32uCPHH738v4Dage6HzTBYrOT9vp9VHCxhqcUV+bb6xlK0nAfAYPQi/dyYik9u/yS3ES83ZZGOhbUEQhNKIhn87KKgmC/XSEebrRj0fV5rdoJeYyWIlI9eCVcoPMG/8+T2+2YWn67fEJZEw5jWyf9peaP+pFT/zw/+m8HriIaLMxaf4d+nW3iEBBkBxregiK7ItCIJQElGSsYPKzKZ8PjEbtVKGUrLyzN4/gBKWUc6foYakSXOxJKchU8gxmcxMn/AcPxsSmOfWhEiVR7FTkl//ENd+3UqsaquqLL3h+qzZtgVBEEoigowdFFSLlWekf4E0vZGWdType/Yc/jllLzhmTc8i+aUFmCQrk7POsMWYwnvuTemv8S92rAywXE0kb+8RXG5vV6n7KUuq3lrmtiAIwvVEkLGDioz0L+CtU6OWyfHOzSnX8eq2TVEH+qL9M4EP9f700viWebw5IaVC+Skvo7nsbUEQhOuJNpka0jDAjWyjmViVyw2PzZMsXB7
Rg7or57Lyu29vGGAAlIE3PqYyVLKytwVBEK4ngkwNUSnk6A0WToWEkezqTmmVTjlYeMJwjsGvTkav16Pt3AZrgG+pxyMDRWgA2s5tHJJvq1T2tiAIwvVEkHGQ8kwxE+qjA4WcT7r2QgbFAkeWZGZM+nGOWXL49ttv0el0yBQKDFOeKPH4gp4DfrMmOKTRnxKuKVpkBEEoiwgyDlIwdiY+w8Chy+mcT8wudsxDkXUI99exr1FT5t37IBnu/7XrZFjNPJ57mnMaC1v+2kq3bt1snx1u0ZLZPQeR7Fp4qhhLgB/mBa+i6dsNQRAEZyAa/h2kPFPMtAjxpHmQBxk5Zi5FtOXdLh0YrsjkHn812YYsDDNfZevKlbRv377QeVdSctka1pRtdRvTOuEyPvoczN4eDH6iHwZJQkrMrnBHhPJy08pIz5MKbQuCIJRGBBkHKc/YmZgUPYlZRoxWKxazDHOOicN1PejVpx1tXV05NnQgihKqvS6m5Hd5tsrlHA4OA8BDDQOQUMrl5VrCubK8XJSk55kKbQuCIJRGVJc5SMMANyLqehHkqSGirlexsTMmi5XDV9LI0BtAkpAhQ5+WyMcvPcr48eMBSgwwAMiKlx7yTLDp0FUuJGeVewnnylAp5WVuC4IgXE+8hjrIjcbOnE/MJj49DwkZIEPKTODoZ1PQKmW89dZbZabdOsSLbScTyb1uBmQPnZLkLCMZemO5l3CuDLM5f342pRzM1vxtQRCE0ojX0BqSpjcS5OVCsyB35FnxHFwyEa1KwarvN9OwYcMyz+3YwIcADy2Ka893OeB2bellV7WyXEs4V5aHTolcBpIEcln+tiAIQmlEkKkh3jo1sem55JmtZJz+B4Vay9g5y+nUttkNz801WfB2U+OmUaBV5M8jlpJjIMdooqmDGvwLNAv2wlUjR6uU46qR0yzYy6HXEwShdhOvoTWknq+OvJxMLqbk4dv5AQI79eNstoadp5MYEBFa5rnZeWYy9CYsVpDJZICESi5Hp1KglDu2+spNo0CnViInfyCmm0bMwiwIQulEkKkmBQubFUyieTjqX/73+GAaD34RRYOO+Op0GM1WjlzJKBZkip6rVcsJ99ORZ7SQkWvBRQkd6vvg7+5iWy3TUTx1aoI9tMhkMiRJwtOBnQwEQaj9RHVZNbl+cOa6zX8yZtj9hNStT/PIjlitVvRGExZJQq0uXhIpOrAzJ9eCVqlErVLgplWhVSnJMVjINVkI9b7xXGhV4aKSI5PJ0aoUIJPjohK/QoIglE6UZKpJweDMuNP/MmfCo9Rv0pKXFy7nfLqFuJwk1Eo57holjf3dSz23YGBnaq4Rg9mKi1qBl06Fh1ZJHW8ddzbxp3vTAIfeR6NAN9qk52K25C+93CjQcT3ZBEGo/USQqSbeOjUXk3N4f85bNGoZwRffrMXD3R3DmUS6N/MnxFOHyWrFx6149VPRgZ1JmXnkmixolQrMVit+bloe6xrusFH+1wv20NEsyBONUoHBbCHYQ+fwawqCUHuJIOMgRdtRQjzVUM+bRV+sItjXi5Zh/rauxocup9se2v5u2mJpFV0U7Up6DgazFY1Sjl5vwVUjd+jYmLLyUl3XFQShdhJBxkEK2lHUCjnrv9vAps8XsvvvHTS7rWmh4+r56ohOzuZqWi6h3i7U8y1eMig6sPPfS6n4uanRqpR46VS0CPZ06NiYsvIiCIJQFtFq6yAF7ShHd25m0avPUL9Jc7y9vYsdF5OiJzPPgq+bloxcMzEp+hum3aaOF82DPajj7ULzYA/a1PFywB0IgiBUnSjJOIi3Ts2KFSv4dNZL3H7vIN5f+hkqlarYceWZrbmopsEeKBVyUWUlCILTE0HGQVR5qXwxZyr3DR3J7Hffp3FQyVVM5ZmtuVjaospKEIRaQgQZB5AkiYbh9dmzZw/t27fHbJUKdQJoGOBma0MRDemCINzMRJCxs7lz55KUlMT8+fOJjIwE4Hxilq0TwKVrbS4FJRFRKhE
E4WbmtA3/ixcvJjw8HK1WS2RkJDt37izz+O3btxMZGYlWq6VBgwYsXbq0mnKaT5IkZsyYwdSpU3F1dS302fXtLhqlwqGLigmCIDgTpwwya9euZeLEiUyfPp2oqCi6detG3759iYmJKfH46Oho+vXrR7du3YiKiuLVV19lwoQJfPfdd9WSX0mSePXVV3nrrbeYPXs2M2bMuDZxZT5vnRqD2VKhdhdBEISbgUySJOnGh1WvTp060b59e5YsWWLb17x5cx544AHmzJlT7PhXXnmFH374gZMnT9r2jR8/nsOHD7Nnz54Sr2EwGDAY/ptMMjMzk7p165KRkYGHR8Wqr7788ktGjx7Ne++9x8SJE4t9XnRg5vVtMoIgCLVRZmYmnp6eN3xmOl2bjNFo5ODBg0ydOrXQ/t69e7N79+4Sz9mzZw+9e/cutK9Pnz588cUXmEymErsOz5kzhxkzZtglzw8//DB+fn4MGDCgxM9Fu4sgCLcqp3udTk5OxmKxEBgYWGh/YGAg8fHxJZ4THx9f4vFms5nk5OQSz5k2bRoZGRm2f5cvX650njUaTakBRhAE4VbmdCWZAte3aUB+u0fRfTc6vqT9BTQaDRqNpoq5FARBEMridCUZPz8/FApFsVJLYmJisdJKgaCgoBKPVyqV+Pr6OiyvgiAIQtmcLsio1WoiIyPZsmVLof1btmyha9euJZ7TpUuXYsf//vvvdOjQocT2GEEQBKF6OF2QAZg8eTKff/45y5Yt4+TJk0yaNImYmBjGjx8P5LenjBo1ynb8+PHjuXTpEpMnT+bkyZMsW7aML774gilTptTULQiCIAg4aZvMsGHDSElJYebMmcTFxdGqVSs2b95MWFgYAHFxcYXGzISHh7N582YmTZrExx9/TEhICB988AEPPvhgTd2CIAiCgJOOk6kJ5e3zLQiCIJT/memU1WWCIAjCzUEEGUEQBMFhRJARBEEQHEYEGUEQBMFhRJARBEEQHEYEGUEQBMFhnHKcTE0o6MmdmZlZwzkRBEFwfgXPyhuNghFB5pqsrCwA6tatW8M5EQRBqD2ysrLw9PQs9XMxGPMaq9VKbGws7u7uZc72XJKCBc8uX758SwzkvNXuF8Q93wr3fKvdL1TtniVJIisri5CQEOTy0lteREnmGrlcTp06daqUhoeHxy3zywm33v2CuOdbwa12v1D5ey6rBFNANPwLgiAIDiOCjCAIguAwIsjYgUaj4c0337xlVtq81e4XxD3fCm61+4XquWfR8C8IgiA4jCjJCIIgCA4jgowgCILgMCLICIIgCA4jgowgCILgMCLIlNPixYsJDw9Hq9USGRnJzp07yzx++/btREZGotVqadCgAUuXLq2mnNpHRe53w4YN9OrVC39/fzw8POjSpQu//fZbNebWPir6My6wa9culEolERERjs2gnVX0fg0GA9OnTycsLAyNRkPDhg1ZtmxZNeXWPip6zytXrqRt27bodDqCg4MZPXo0KSkp1ZTbqtmxYwf33XcfISEhyGQyNm7ceMNzHPLckoQbWrNmjaRSqaTPPvtMOnHihPTCCy9Irq6u0qVLl0o8/sKFC5JOp5NeeOEF6cSJE9Jnn30mqVQqaf369dWc88qp6P2+8MIL0ty5c6V9+/ZJZ86ckaZNmyapVCrp33//reacV15F77lAenq61KBBA6l3795S27ZtqyezdlCZ+x04cKDUqVMnacuWLVJ0dLT0zz//SLt27arGXFdNRe95586dklwul95//33pwoUL0s6dO6WWLVtKDzzwQDXnvHI2b94sTZ8+Xfruu+8kQPr+++/LPN5Rzy0RZMqhY8eO0vjx4wvta9asmTR16tQSj3/55ZelZs2aFdr31FNPSZ07d3ZYHu2povdbkhYtWkgzZsywd9YcprL3PGzYMOm1116T3nzzzVoVZCp6v7/88ovk6ekppaSkVEf2HKKi9zx//nypQYMGhfZ98MEHUp06dRyWR0cpT5Bx1HNLVJfdgNF
o5ODBg/Tu3bvQ/t69e7N79+4Sz9mzZ0+x4/v06cOBAwcwmUwOy6s9VOZ+i7JarWRlZeHj4+OILNpdZe95+fLlnD9/njfffNPRWbSrytzvDz/8QIcOHZg3bx6hoaE0adKEKVOmkJubWx1ZrrLK3HPXrl25cuUKmzdvRpIkEhISWL9+Pf3796+OLFc7Rz23xASZN5CcnIzFYiEwMLDQ/sDAQOLj40s8Jz4+vsTjzWYzycnJBAcHOyy/VVWZ+y1qwYIF5OTkMHToUEdk0e4qc89nz55l6tSp7Ny5E6Wydv0ZVeZ+L1y4wN9//41Wq+X777/n/+3de1BU5RsH8O+yy9rGTVgEVnAkrrMMd4jiopAsMRGal6gJUjBkxhphgKDxgnipGUyDUkRGGVphBtTJctSamGW4BVpYsGsOkBZCpFFmUiBgCJzfHw47Lgt7wwPi7/nM7B+e877vPs/L8Tx73j1wbt++jXfeeQd37tyZF9/LGJJzSEgIKioq8Prrr+PevXsYHR3FqlWrUFhYOBshzzq2zlt0JaOjyX/+n2EYjY8EmKr9VNsfV/rmO+HEiRPYvXs3Tp06BRsbG7bCY4WuOY+NjSE+Ph579uyBm5vbbIX3yOnzMx4fHweHw0FFRQWCgoIQExODgoICHD9+fN5czQD65dze3o60tDTk5uaipaUFVVVV6OrqwubNm2cj1DnBxnlrfn0EmwPW1tbgcrlqn3Zu3bqlVvUn2NnZTdmex+NBKBSyFuujYEi+E06dOoXk5GR89tlnkEgkbIb5SOmb88DAAH744QfI5XJs2bIFwIOTMMMw4PF4kMlkWLFixazEbghDfsYikQj29vYqf9pdLBaDYRjcuHEDrq6urMY8U4bknJeXh9DQUGRnZwMAvL29YWJigmXLluGDDz54rFckDMHWeYuuZLTg8/kICAhAdXW1yvbq6mqEhIRM2Sc4OFitvUwmQ2BgIIyNjVmL9VEwJF/gwRVMUlISKisr592atb45m5ub48qVK1AoFMrX5s2b4e7uDoVCgeeee262QjeIIT/j0NBQ/P7777h7965y27Vr1x7Jc5hmgyE5Dw0NqT2Mi8vlAtD+yOH5iLXz1oxuG/g/MXHrY2lpKdPe3s6kp6czJiYmTHd3N8MwDLN161Zm/fr1yvYTtwJmZGQw7e3tTGlp6by8hVnXfCsrKxkej8cUFRUxvb29ytc///wzVynoTd+cJ5tvd5fpm+/AwADj4ODAvPrqq0xbWxvT0NDAuLq6Mps2bZqrFPSmb85SqZTh8XjMkSNHmM7OTqapqYkJDAxkgoKC5ioFvQwMDDByuZyRy+UMAKagoICRy+XKW7Zn67xFRUZHRUVFzNKlSxk+n8/4+/szDQ0Nyn2JiYlMeHi4Svv6+nrGz8+P4fP5jKOjI1NcXDzLEc+MPvmGh4czANReiYmJsx/4DOj7M37YfCsyDKN/vh0dHYxEImEEAgHj4ODAZGZmMkNDQ7Mc9czom/OhQ4cYDw8PRiAQMCKRiElISGBu3Lgxy1Ebpq6uTuP/y9k6b9Gf+ieEEMIa+k6GEEIIa6jIEEIIYQ0VGUIIIayhIkMIIYQ1VGQIIYSwhooMIYQQ1lCRIYQQwhoqMoQQQlhDRYYQQghrqMgQQghhDRUZQmZRREQE0tPTDd4/30zO50nLj2hHz5MheouIiICvry8++eSTuQ7lifPFF1889o+DmInJ+dGx9OSjIkNYMTIyAj6fP9dhzDtWVlZzHQKrnvT8iDpaLiN6SUpKQkNDAw4ePAgOhwMOh4Pu7m5ERERgy5YtyMzMhLW1NaKiogAAjo6Oap9SfX19sXv3bgAPHv60f/9+ODk5QSAQwMfHB6dPn9YYg7YxgQefkNPS0vDee+/BysoKdnZ2Kvsn2qSmpiI9PR2WlpawtbXFsWPHMDg4iI0bN8LMzAzOzs74+uuvVfpVVVUhLCwMCxcuhFAoRGxsLDo7O5X7T58+DS8vLwgEAgiFQkgkEgwODk6ZS1V
VFSwsLFBeXq6MafLykrY8BgYGkJCQABMTE4hEInz88cdal6UGBwexYcMGmJqaQiQSIT8/X62PLvOsbS4me/g9pjuWysvLIRQK8d9//6n0XbduHTZs2DDt2JM1NTXB2NhYZZyuri5wOBz8+uuvOo9DZoaKDNHLwYMHERwcjJSUFPT29qK3txdLliwBAJSVlYHH4+HChQs4evSoTuPl5ORAKpWiuLgYbW1tyMjIwJtvvomGhoYZx1pWVgYTExM0Nzdj//792Lt3r9qT/8rKymBtbY1Lly4hNTUVb7/9NuLi4hASEoLW1lZER0dj/fr1GBoaUvYZHBxEZmYmvv/+e9TU1MDIyAhr1qzB+Pg4ent78cYbb+Ctt95CR0cH6uvrsXbt2imfpHjy5Em89tprKC8v13jy1JZHZmYmLly4gHPnzqG6uhqNjY1obW3VODfZ2dmoq6vDmTNnIJPJUF9fj5aWFl2nVqe50Ga6YykuLg5jY2M4d+6csu3t27fx5ZdfYuPGjTrHplAoIBaLsWDBApVtCxcuxNKlS/VLlBiMlsuIXiwsLMDn8/H000/Dzs5OZZ+Liwv279+v81iDg4MoKChAbW0tgoODAQBOTk5oamrC0aNHER4ePqNYvb29sWvXLgCAq6srDh8+jJqaGuVVFgD4+PggJycHALBt2zbs27cP1tbWSElJAQDk5uaiuLgYP/74I55//nkADz5RP6y0tBQ2NjZob2/HyMgIRkdHsXbtWuWJzMvLSy22I0eOYPv27Th79ixeeOEFg/MYGBhAWVkZKisrERkZCQCQSqVYvHjxtOPdvXsXpaWlKC8vV85FWVmZQY9R1jQXnp6eGvtOdywJBALEx8dDKpUiLi4OAFBRUQEHBwdEREToHNvly5fh5+ensk2hUMDHx0fnMcjM0ZUMeWQCAwP1at/e3o579+4hKioKpqamyld5ebnGJRddeXt7q/xbJBLh1q1b07bhcrkQCoUqRcHW1hYAVPp1dnYiPj4eTk5OMDc3xzPPPAMA6OnpgY+PDyIjI+Hl5YW4uDiUlJSgr69P5T0///xzpKenQyaTaS0w2vK4fv067t+/j6CgIOV+CwsLuLu7TzteZ2cnRkZGlIUdePBdiaY+msaabi5mIiUlBTKZDDdv3gTwoHAmJSWBw+HoPIZCoYCvr6/KNrlcTkVmltGVDHlkTExM1LYZGRmpLRXdv38fAJRLKl999RXs7e1V2jy8xKHPmA+bfJcWh8NRW8aZqs3D2yZOag/3W7lyJZYsWYKSkhIsXrwY4+Pj8PT0xMjICLhcLqqrq3Hx4kXIZDIUFhZix44daG5uVp6AfX190draCqlUimeffVbriVNTHhPzMHkMTQ+81fVhuLrMs6a5mAk/Pz/4+PigvLwc0dHRuHLlCs6fP69z/7GxMbS1taldybS2tmLNmjUzio3oh65kiN74fD7GxsZ0arto0SL09vYq/93f34+uri4AgIeHBxYsWICenh64uLiovCa+59F3TLb9/fff6OjoQE5ODiIjIyEWi9WuVDgcDkJDQ7Fnzx7I5XLw+XycOXNGud/Z2Rl1dXU4e/YsUlNTZxSPs7MzjI2NcenSJeW2/v5+/Pzzz9P2cXFxgbGxMb777jvltr6+Ply7dk2lnbZ51mUutNF0LG3atAlSqRSffvopJBKJxmNisqtXr2J4eFhl2fDbb7/FzZs3p7yS+e2337R+j0UMQ1cyRG+Ojo5obm5Gd3c3TE1NNd6WumLFChw/fhwrV66EpaUldu7cCS6XCwAwMzNDVlYWMjIyMD4+jrCwMPT39+PixYswNTVFYmKi3mOyzdLSEkKhEMeOHYNIJEJPTw+2bt2q3N/c3Iyamhq8+OKLsLGxQXNzM/766y+IxWKVcdzc3FBXV4eIiAjweDyDf0/EzMwMiYmJyM7OhpWVFWxsbLBr1y4YGRlNe4VkamqK5ORkZGdnQygUwtbWFjt27ICRkepnTm3zrG0udDHVsTQRR0JCArK
yslBSUqK8+05XCoUCAFBYWIi0tDT88ssvSEtLAwC1u9YaGxuRm5uL4eFh7Ny5Ey+//LJe70U0oysZoresrCxwuVx4eHhg0aJFGtfft23bhuXLlyM2NhYxMTFYvXo1nJ2dlfvff/995ObmIi8vD2KxGNHR0Th//rxyacmQMdlkZGSEkydPoqWlBZ6ensjIyMCBAweU+83NzfHNN98gJiYGbm5uyMnJQX5+Pl566SW1sdzd3VFbW4sTJ07g3XffNTimgoICBAcHIzY2FhKJBKGhoRCLxXjqqaem7XPgwAEsX74cq1atgkQiQVhYGAICAlTaaJtnbXOhC03Hkrm5OdatWwdTU1OsXr1ar3EVCgWioqLQ1dUFT09PbN++Hfv27YO5uTmKiopU2i5btgyjo6MYHh6mAsMCDqPrAi0hZF4YHByEvb098vPzkZycrHO/x/G376OioiAWi3Ho0CG9+kVHR8Pf3x95eXla246NjeH69esQCASwtLSc8rtFYjhaLiNknpPL5fjpp58QFBSEf//9F3v37gUAvPLKK3McmeHu3LkDmUyG2tpaHD58WO/+ly9fRlJSkk5tuVwuXF1d9X4PohsqMoQ8AT766CNcvXoVfD4fAQEBaGxshLW19VyHZTB/f3/09fXhww8/1PvW6j/++AN//vmn2q3fZG7QchkhhBDW0Bf/hBBCWENFhhBCCGuoyBBCCGENFRlCCCGsoSJDCCGENVRkCCGEsIaKDCGEENZQkSGEEMIaKjKEEEJYQ0WGEEIIa/4HzuqdTc/4qEkAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# predicted quality vs. true quality mu_star\n", + "plt.figure(figsize=(4.2, 4.2))\n", + "plt.scatter(qs, pred, s=6, alpha=0.25, label=\"held-out states\")\n", + "bins = np.linspace(0, 1, 11)\n", + "idx = np.digitize(qs, bins) - 1\n", + "bx = [qs[idx == b].mean() for b in range(10) if (idx == b).any()]\n", + "by = [pred[idx == b].mean() for b in range(10) if (idx == b).any()]\n", + "plt.plot(bx, by, \"o-\", color=\"crimson\", label=\"binned mean\")\n", + "plt.plot([0, 1], [0, 1], \"k--\", lw=1, label=\"ideal\")\n", + "plt.xlabel(r\"true unmasking quality $\\mu_\\star$\")\n", + "plt.ylabel(r\"predicted $\\mu_\\phi$\")\n", + "plt.title(\"UQL recovers the unmasking quality\")\n", + "plt.legend(fontsize=8); plt.tight_layout(); plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Insertion quality $\\nu_\\star$\n", + "\n", + "When the model inserts a mask into the gap between surviving tokens $\\ell-1$ and $\\ell$,\n", + "its **insertion quality** is the probability that the mask decodes to *some* token that\n", + "genuinely belongs in that gap, $\\mathcal{S}_\\ell=\\{\\boldsymbol{x}_1^i: s_t[\\ell-1] empty gap\n", + "print(f\" insert into an EMPTY gap (nothing belongs): nu* = {nu_empty:.3f} gap={gap_empty}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.1 Training the insertion-quality predictor $\\nu_\\phi$ (IQL)\n", + "\n", + "Similarly, $\\nu_\\phi$ is trained with the **Insertion Quality Loss**, which is a BCE against the (soft) target $\\nu_\\star$, taking only the post-insertion sequence $\\boldsymbol{y}$ as input:\n", + "\n", + "$$\\mathcal{L}_{\\text{IQL}}(\\phi)=\\mathbb{E}\\Big[\\textstyle\\sum_{i\\in\\mathcal{I}}\\text{BCE}\\big(\\nu_\\star^i(\\boldsymbol{y}),\\,\\nu_\\phi^i(\\boldsymbol{y})\\big)\\Big],$$\n", + "\n", + "whose unique minimizer is $\\nu_\\star$. 
The predictor learns to assign near-zero quality to inserts whose neighbours are already adjacent in the target." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " step 500 IQL running avg = 0.4728 (irreducible floor = 0.2009)\n", + " step 1000 IQL running avg = 0.4763 (irreducible floor = 0.2009)\n", + " step 1500 IQL running avg = 0.4646 (irreducible floor = 0.2009)\n", + " step 2000 IQL running avg = 0.4717 (irreducible floor = 0.2009)\n", + "\n", + "running loss 0.4717 sits at the floor 0.2009 -> converged\n", + "corr(nu_phi, nu_star) = 0.644\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZkAAAGZCAYAAABbpUzOAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAhuNJREFUeJzt3Xd4U2UbwOFfmo50UzqgFChlyJANsveUoTIEZMoUBGU7EGUIioAgyMaPqYAoKioiUFbZKFOgiAJltLR0QfdM3u+P0kjozmjS9r2vi0vPyRnPOUnz5LxTIYQQSJIkSZIJWJk7AEmSJKn4kklGkiRJMhmZZCRJkiSTkUlGkiRJMhmZZCRJkiSTkUlGkiRJMhmZZCRJkiSTkUlGkiRJMhmZZCRJkiSTkUnGTDZv3oxCoeDcuXNZXtu3bx89evTA09MTOzs7KlasyIgRI7hx40aWbefMmYNCoSAyMrIwwi7REhMTmTNnDkePHs3yWmG9D5mfmzt37pj0PIYIDAxkzpw52cY4fPhwKlWqVOgxmVu7du1o166dzjqFQsGcOXO0y7ndt6JMJhkL8+6779KtWzc0Gg2rV6/G39+fWbNmcfbsWRo0aMCePXvMHWKJlZiYyNy5c7NNMoWlR48enD59Gm9vb7PFkJfAwEDmzp2b7ZflRx99xE8//VT4QVmg06dPM3r0aO1ybvetKLM2dwDSf3bs2MHixYt58803Wb16tXZ9mzZtGDhwIG3btmXQoEFcvXqVihUrGv38arWa9PR07OzsjH5sSyCEIDk5GXt7e3OHojdPT088PT3NHUa20tLSUCgUuW5TpUqVQorG8jVr1szcIRQK+SRjQT755BPc3Nz4/PPPs7zm6OjIihUriIuLY9myZQaf686dOygUChYtWsT8+fPx8/PDzs6OI0eOAHDu3DlefvllSpcujUqlokGDBnz33XdZjhMSEsIbb7xBhQoVsLW1pVy5crz66qs8fPhQu829e/cYMmQIXl5e2NnZUbNmTZYsWYJGowEyvpy8vLwYOnRoluM/fvwYe3t7pk6dql0XGxvL9OnT8fPzw9bWFh8fHyZPnkxCQoLOvgqFgrfeeou1a9dSs2ZN7Ozs2LJlCwBr1qyhXr16ODk54ezsTI0aNfjggw9yvV+ZX+5z585FoVCgUCgYPny4znYPHz5k4MCBuLq6UqZMGUaOHElMTIzONkIIVq9eTf369bG3t8fNzY1XX32V27dv53j+TNkVl7Vr147atWvz559/0rp1axwcHKhcuTKfffaZ9h4DaDQa5s+f
T/Xq1bG3t6dUqVLUrVuX5cuX65zj33//ZdCgQTrv16pVq3S2OXr0KAqFgq+//ppp06bh4+ODnZ0d//vf/+jXrx8A7du3196nzZs3A9kXlyUnJzNjxgyd93PChAk8fvxYZ7tKlSrRs2dP9u3bR8OGDbG3t6dGjRps3Lgxz/sG8ODBA/r374+zszOurq4MGDCAM2fO6MSXeT+fLdrKKfa5c+fStGlTSpcujYuLCw0bNmTDhg3kZ9zhp4vLNm/enON9mzdvHtbW1ty/fz/LMUaOHIm7uzvJycn5ugdmISSz2LRpkwDEn3/+KYQQ4sGDBwIQAwYMyHU/Ly8vUatWLe3y7NmzBSAiIiIKdP6goCABCB8fH9G+fXuxa9cuceDAAREUFCQOHz4sbG1tRevWrcXOnTvFvn37xPDhwwUgNm3apD1GcHCw8Pb2Fh4eHmLp0qXi4MGDYufOnWLkyJHi+vXrQgghwsPDhY+Pj/D09BRr164V+/btE2+99ZYAxJtvvqk91pQpU4S9vb2IiYnRiXP16tUCEH/99ZcQQoiEhARRv359nXMuX75cuLq6ig4dOgiNRqPdN/P66tatK7Zv3y4OHz4srl69Knbs2CEA8fbbb4sDBw6IgwcPirVr14qJEyfmeL+Sk5PFvn37BCBGjRolTp8+LU6fPi1u3ryp8z5Ur15dzJo1S/j7+4ulS5cKOzs7MWLECJ1jjRkzRtjY2Ihp06aJffv2ie3bt4saNWqIMmXKiLCwsFzft8zPTVBQkHZd27Zthbu7u6hWrZpYu3at8Pf3F+PHjxeA2LJli3a7BQsWCKVSKWbPni0OHTok9u3bJ5YtWybmzJmj3ebatWvC1dVV1KlTR2zdulUcOHBATJs2TVhZWelsd+TIEe39ffXVV8Uvv/wi9uzZI8LCwsSnn34qALFq1SrtfQoPDxdCCPH6668LX19f7XE0Go3o2rWrsLa2Fh999JE4cOCA+Pzzz4Wjo6No0KCBSE5O1m7r6+srypcvL2rVqiW2bt0q9u/fL/r16ycAERAQkOt9S0xMFDVr1hSurq5ixYoVYv/+/WLixImiYsWKWT7Xbdu2FW3bts1yjGdjF0KI4cOHiw0bNgh/f3/h7+8v5s2bJ+zt7cXcuXN1tsvumICYPXu2ECLj7ySn+/bw4UNhZ2cnZs6cqbN/VFSUsLe3F++8806u125uMsmYybNJ5syZMwIQ77//fq77NW3aVDg6OmqXDU0yVapUEampqTqv1ahRQzRo0ECkpaXprO/Zs6fw9vYWarVaCCHEyJEjhY2NjQgMDMzxPO+//74AxNmzZ3XWv/nmm0KhUIgbN24IIYT466+/BCDWr1+vs12TJk1Eo0aNtMsLFiwQVlZW2vuWadeuXQIQe/fu1a4DhKurq4iOjtbZ9q233hKlSpXKMeacRERE6HwxPC3zfVi0aJHO+vHjxwuVSqVNfqdPnxaAWLJkic529+/fF/b29uLdd9/NNYackkx297hWrVqia9eu2uWePXuK+vXr53r8rl27ivLly2dJ9m+99ZZQqVTae5mZZNq0aZPlGN9//70AxJEjR7K89uwXdWbifva+7dy5M8vnwdfXV6hUKnH37l3tuqSkJFG6dGkxduzYXK9rzZo1AhA///yzzvoxY8YYlGSeplarRVpamvj444+Fu7u7zg+evJKMEHnfNy8vL5GSkqJdt3DhQmFlZaXzWbBEsrisiBFC5FnuXRAvv/wyNjY22uWbN2/y999/M3jwYADS09O1/7p3705oaKi2ldvvv/9O+/btqVmzZo7HP3z4MLVq1aJJkyY664cPH44QgsOHDwNQp04dGjVqxKZNm7TbXL9+nT/++IORI0dq1+3Zs4fatWtTv359ndi6du2KQqHIUinfoUMH3NzcdNY1adKEx48fM3DgQH7++Wejtgh7+eWXdZbr1q1LcnIy4eHh2vgVCgVDhgzRib9s2bLUq1dP70YFZcuWzXKP69aty927d7XLTZo04fLl
y4wfP579+/cTGxurs31ycjKHDh2id+/eODg4ZHnvk5OTOXPmjM4+ffv21SveTJnv/7PFjv369cPR0ZFDhw7prK9fv75OfaRKpeK5557Tuc7sHDlyBGdn5yzvz6BBgwyIPiP+Tp064erqilKpxMbGhlmzZhEVFaV9z41h0qRJhIeH8/333wMZRZ9r1qyhR48eFt9aTyYZC5H5hxMUFJTrdnfv3qVChQpGO++zrZQy61KmT5+OjY2Nzr/x48cDaL+UIyIiKF++fK7Hj4qKyrYlVLly5bSvZxo5ciSnT5/m77//BmDTpk3Y2dkxcOBAnfj++uuvLLE5OzsjhMiSMLI799ChQ9m4cSN3796lb9++eHl50bRpU/z9/XO9lvxwd3fXWc5sRJGUlKSNXwhBmTJlslzDmTNn9E54z54389yZ5wWYMWMGn3/+OWfOnKFbt264u7vTsWNHbTP6qKgo0tPTWbFiRZbYunfvDpCv+1sQUVFRWFtbZ2nMoFAoKFu2rM7nI7/XmdN5ypQpk2V92bJl9Yg6wx9//EGXLl0A+Oqrrzh58iR//vknM2fOBMgzpoJo0KABrVu31taN7dmzhzt37vDWW28Z7RymIluXWQhvb29q167NgQMHSExMxMHBIcs2p0+f5uHDh7z66qtGO++zT0UeHh5AxhdSnz59st2nevXqQEZLp+Dg4FyP7+7uTmhoaJb1Dx480DkfwMCBA5k6dSqbN2/mk08+4euvv6ZXr146TyIeHh7Y29vnWNn79PGyu75MI0aMYMSIESQkJHDs2DFmz55Nz549+eeff/D19c31mgzh4eGBQqHg+PHj2bbiM2XLPmtra6ZOncrUqVN5/PgxBw8e5IMPPqBr167cv38fNzc3lEolQ4cOZcKECdkew8/PT2fZ0Kdqd3d30tPTiYiI0Ek0QgjCwsJ44YUXDDr+0+f5448/sqwPCwvLsk6lUmVprAFZE+y3336LjY0Ne/bsQaVSadfv3r3b8ICzMXHiRPr168eFCxdYuXIlzz33HJ07dzbJuYxJPslYkJkzZ/Lo0SOmT5+e5bWEhAQmTpyIra2t9onCFKpXr061atW4fPkyjRs3zvafs7MzAN26dePIkSPZdhLN1LFjRwIDA7lw4YLO+q1bt6JQKGjfvr12nZubG7169WLr1q3s2bOHsLAwnaIygJ49e3Lr1i3c3d2zja2gRQeOjo5069aNmTNnkpqayrVr13Lc9tmnEn307NkTIQQhISHZxl+nTh29j10QpUqV4tVXX2XChAlER0dz584dHBwcaN++PRcvXqRu3brZxpfdk8SzCnKfOnbsCMA333yjs/6HH34gISFB+7qh2rdvT1xcHL/88ovO+u3bt2fZtlKlSvzzzz+kpKRo10VFRXHq1Cmd7RQKBdbW1iiVSu26pKQkvv76a71izOu+9e7dm4oVKzJt2jQOHjzI+PHjjVp0biryScaCvPbaa5w/f57PP/+cO3fuMHLkSMqUKcONGzf44osv+Pvvv9mwYQO1atXKsu+vv/6q/fJ/mj5PPevWraNbt2507dqV4cOH4+PjQ3R0NNevX+fChQvacuGPP/6Y33//nTZt2vDBBx9Qp04dHj9+zL59+5g6dSo1atRgypQpbN26lR49evDxxx/j6+vLb7/9xurVq3nzzTd57rnndM49cuRIdu7cyVtvvUX58uXp1KmTzuuTJ0/mhx9+oE2bNkyZMoW6deui0Wi4d+8eBw4cYNq0aTRt2jTX6xszZgz29va0bNkSb29vwsLCWLBgAa6urrn+cnZ2dsbX15eff/6Zjh07Urp0aTw8PAqU2Fq2bMkbb7zBiBEjOHfuHG3atMHR0ZHQ0FBOnDhBnTp1ePPNN/N9vIJ46aWXqF27No0bN8bT05O7d++ybNkyfH19qVatGgDLly+nVatWtG7dmjfffJNKlSoRFxfHzZs3+fXXX7V1KLmpXbs2AOvXr8fZ2RmVSoWfn1+2Capz58507dqV9957j9jY
WFq2bMlff/3F7NmzadCgQbbN2vUxbNgwvvjiC4YNG8Ynn3xCtWrV2Lt3L/v378+y7dChQ1m3bh1DhgxhzJgxREVFsWjRIlxcXHS269GjB0uXLmXQoEG88cYbREVF8fnnn+v9NJrXfVMqlUyYMIH33nsPR0fHLPVYFsucrQ5Ksmdblz3tt99+E926dROlS5cWCoVCAMLLy0ucOXMmy7aZrZpy+peTzNZlixcvzvb1y5cvi/79+wsvLy9hY2MjypYtKzp06CDWrl2rs939+/fFyJEjRdmyZYWNjY0oV66c6N+/v3j48KF2m7t374pBgwYJd3d3YWNjI6pXry4WL16sbaX2NLVaLSpUqCCALE02M8XHx4sPP/xQVK9eXdja2mqb3E6ZMkWnCTAgJkyYkGX/LVu2iPbt24syZcoIW1tbbcyZzaRzc/DgQdGgQQNhZ2cnAPH6668LIXJu5ZddazAhhNi4caO2paC9vb2oUqWKGDZsmDh37lyu58+pddnzzz+fZdtnW0MtWbJEtGjRQnh4eAhbW1tRsWJFMWrUKHHnzh2d/YKCgsTIkSOFj4+PsLGxEZ6enqJFixZi/vz52m0yW5d9//332ca5bNky4efnJ5RKpU7rrexaaCUlJYn33ntP+Pr6ChsbG+Ht7S3efPNN8ejRI53tfH19RY8ePbKcK6fWYM8KDg4Wffv2FU5OTsLZ2Vn07dtXnDp1KkvrMiEyPiM1a9YUKpVK1KpVS+zcuTPb2Ddu3CiqV68u7OzsROXKlcWCBQvEhg0bsn2P8mpdJkTO9y3TnTt3BCDGjRuX5/VaCoUQ+eg1JJnVxx9/zOzZs1m1apVJi8okqaS5c+cOfn5+bNq0qUg8GaxYsYKJEydy9epVnn/+eXOHky+yuKwImDVrFqGhobz11ls4Ojry+uuvmzskSZIK0cWLFwkKCuLjjz/mlVdeKTIJBmSSKTLWrFnDmjVrzB2GJElm0Lt3b8LCwmjdujVr1641dzgFIovLJEmSJJORTZglSZIkk5FJRpIkSTIZWSfzhEaj4cGDBzg7OxeJDk6SJEmFSQhBXFwc5cqVw8oq/88nMsk88eDBA6OOCSZJklQc3b9/P88xC58mk8wTmb3l79+/n6VnryRJUkkXGxtLhQoVsh1ZJDcyyTyRWUTm4uIik4wkSVIOClqdICv+JUmSJJOxuCRz7NgxXnrpJcqVK4dCocjXsNkBAQE0atQIlUpF5cqVi1xnJUmSpOLK4orLEhISqFevHiNGjMjXrHtBQUF0796dMWPG8M0333Dy5EnGjx+Pp6enwbP2ZVKr1aSlpRnlWFLhUiqVWFtbyxaDkmQmFpdkunXrRrdu3fK9/dq1a6lYsSLLli0DoGbNmpw7d47PP//cKEkmPj6e4OBg5MAIRZeDgwPe3t7Y2tqaOxRJKnEsLskU1OnTp7VToGbq2rUrGzZsIC0tTWf++qelpKToTEr07HznkPEEExwcjIODA56envLXcBEjhCA1NZWIiAiCgoKoVq1agdr3S5JkuCKfZMLCwrLM3V2mTBnS09OJjIzMcQ7yBQsWMHfu3FyPnZaWhhACT09P7O3tjRazVHjs7e2xsbHh7t27pKam6kyTK0mS6RWLn3XPPmFkFm3l9uQxY8YMYmJitP/u37+f7+Pnh1CrSTp5kbgfD5J08iJCrS7wMSTjkE8vkmQ+Rf5JpmzZsoSFhemsCw8Px9raOtf5yO3s7PSeJjUv8XsCiJy5HPWDCO06ZTlPPD6ZhFPPtiY5pyRJkiUq8j/xmjdvjr+/v866AwcO0Lhx4xzrY0wpfk8AD0d+qJNgANShETwc+SHxewIMPodCoSA+Pj7b1+rXr09SUpLB58ivdu3asWfPnkI7nyRJRYvFPcnEx8dz8+ZN7XJQUBCXLl2idOnSVKxYkRkzZhASEsLWrVsBGDduHCtXrmTq1KmMGTOG06dPs2HDBnbs2GH02IQQiMTknF9Xq4n8YBlk1xBNAAqI/GA59m0a
oVAqsz2GwkFlUAODS5cu6b2vJEmSsVlckjl37hzt27fXLk+dOhWA119/nc2bNxMaGsq9e/e0r/v5+bF3716mTJnCqlWrKFeuHF9++aXR+sg8TSQmE1SpS94b5niAjCeaO1VybqLtd+cACse8Gxl8/vnn+Pv7ExERwdy5cxk4cCCQ8ZQTFxeHk5MTlSpVYsSIEezfv5/Q0FBGjRrFhx9+CGQ8gTRt2pRTp07x4MEDOnfurO3EGhcXx9SpU7l8+TLJycm0aNGCFStWYGNjQ2BgICNGjCAtLY2aNWuSnJx90j169CiTJ0+mWbNmnDx5EhsbG7Zu3cq8efO4cuUKPj4+/PTTTzg5OZGWlsZHH33E4cOHSU1NpUaNGqxdu5ZSpUqxfft2li9fTmpqKkIIPv30U7p37w6Q6/VJkmQZLK64rF27dhlPDM/827x5MwCbN2/m6NGjOvu0bduWCxcukJKSQlBQEOPGjSv8wAuZQqHg5MmT7Nu3j7fffjvHhguPHz/m1KlT/PHHHyxevJiQkBDta7du3eLo0aNcvXqV/fv3c/r0aQCmTZtGmzZt+OOPP7h8+TLp6emsXLkSgKFDhzJ+/HguXLjA22+/zZ9//pljjNeuXWPcuHFcuXKF5s2b8+KLL7JkyRICAwOxsbFh+/btACxevBgnJyf++OMPLl26xPPPP8/s2bOBjOboZ86c4eLFi+zevZvRo0frdIzN7fokSTI/i3uSsWQKBxV+dw7k+HrS6cuEDXwnz+OU3bEY++b1cjxHfowePRqAypUr06pVK44fP86gQYOybDd48GAAPD09qVy5MkFBQfj4+ADw2muvoVQqsbe3p379+ty6dYvmzZuze/duzpw5w5IlSzKuKykJW1tbYmNjuXr1KkOHDgWgWbNm1KlTJ8cYq1evTv369QFo2LAhd+/e1Q4R3qhRI27fvg3A7t27iY2NZdeuXQCkpqZSpUoVIKO4dPDgwQQHB2NtbU1kZCR3796latWqeV6fJEn/iYiIwNPTs9DPK5NMASgUilyLshzav4CynCfq0Ijs62UUoCznhUP7F3KskzEktuw83S9EqVSSnp6e52tCCHbv3k3lypV1jhUbG1ug+qJnj//scmYDBSEEq1evpkOHDlmO8dprr/H555/Tq1cvAEqXLq1TRJfb9UmSlGHPnj0MGDCAvXv30rZt4bZwtbjisqJMoVTi8cmkJwvPvpjxH4/5E42SYDZu3AjAnTt3OHHiBK1atTL4mJlefvllPvvsM+0X9qNHj7h58yYuLi7Url2bbdu2AfDHH39w5coVo5xv6dKlJCYmApCYmMi1a9e0565UqRIA33zzDY8ePTL4fJJUkuzevZs+ffrQtWtXmjdvXujnl0nGyJx6tqXMxvkovXUfS5XlvCizcb7R+snY2dnRsmVLunTpwooVK4w6q+eyZcuwtramfv361K1bl06dOnHnzh0Atm7dysqVK2nYsCHr16+nadOmBp/v/fffp379+jRt2pS6devSrFkzbSu55cuX07t3b1q1asXly5epWLGiweeTpJJi165d9OvXj169erFz506zjN+nEHLkRyCjKMjV1ZWYmBjtpGXJyckEBQXh5+dX4OFIhFpN8pm/SH8YhXUZd1TN6hq9iEzKH0PeR0kqqtLS0qhfvz7169dny5YtWFsbVjuS3Xdkfsg6GRNRKJXYt2xg7jAkSSqBMgcHPnr0KKVLl0Zpxh+4srhMkiSpGNm4cSONGjXi8ePHeHp6mjXBgEwykiRJxca6desYNWoULVu2LFCRlinJJCNJklQMrFy5knHjxjFx4kRWr15tMaOPW0YUkiRJkt7+/vtvJk2axLRp01i2bJlFTbAoK/4lSZKKuBo1anD27FkaNWpkUQkG5JOMJElSkTV//nzmzJkDQOPGjS0uwYBMMkVSbnPGVKpUiatXr+p97KNHj9K4cWO995ckyfSEEMyaNYuPPvrI7K3H8iKLy4ogOWeMJJVcQgg++OADPvvs
Mz777DPee+89c4eUK5lk9BAaGkpoaKjOOjc3N/z8/EhOTiYwMDDLPg0bNgTgxo0bJCQk6LxWqVIlSpcune/zPz1nzPHjxxk/fjz29vY0adKEpwdw+Pfff5k8eTLh4eGkpqYyduxYxo8fD8CQIUP4+++/SU1NpWLFimzcuBEvL698xyBJknn873//47PPPmPp0qVMmTLF3OHkTUhCCCFiYmIEIGJiYrTrkpKSRGBgoEhKStLZdvbs2YKMcZa1/wYPHiyEEOLff//N8trTt7lZs2ZZXvv6668LFCsg4uLiRHJysihXrpw4cuSIEEKInTt3CkBcuXJFpKeni8aNG4vr168LIYRISEgQderUEefPnxdCCBEREaE93oIFC8SECROEEEIcOXJENGrUqEDxWLqc3kdJKooSEhLEDz/8UOjnze47Mj/kk4wexo4dy8svv6yzzs3NDYDy5ctz/vz5HPfdvHlztk8y+rhx4wYODg60a9cOgP79+/PGG29oX7t27Rqvvfaadvu4uDgCAwNp2LAh27Zt4+uvvyYlJYWkpCTKli2rVwySJJmeRqNhxowZDB06lNq1a9OnTx9zh5RvMsnowdvbG29v72xfU6lU2qKx7FSvXt1ocYhcxjYVQuDh4ZFt/c2JEydYuXIlp06dwtPTk19++YWPP/7YaHFJkmQ8arWaN954g02bNlGnTh1q165t7pAKRLYuK8Jq1KhBUlISx44dAzKG9Y6JiQEykpmDgwNbt27Vbn/z5k2io6N59OgRLi4ulC5dmtTUVNatW2eW+CVJyp1arWbEiBFs3ryZrVu3MmTIEHOHVGAyyRRhdnZ27NixgwkTJtCkSRP++OMP7Xwr1tbW/Prrr3z33XfUrVuX559/ntGjR5OUlES3bt2oWrUqNWrUoGvXrtopkiVJsixvvvkm27dvZ9u2bUUywYCcT0bL2PPJSJZDvo9SUXXixAkePnxI3759zR2K3vPJyCcZSZIkC5KSksKSJUtIT0+nVatWFpFgDCGTjCRJkoVITk6mb9++zJw5s9h0upaty/JBligWbRqNxtwhSFKekpKS6N27NwEBAfzyyy/FZngnmWRyYWNjg0KhICIiAk9PT4scfE7KmRCC1NRUIiIisLKywtbW1twhSVK2kpOT6dmzJ2fOnOG3336jQ4cO5g7JaGSSyYVSqaR8+fIEBwdz584dc4cj6cnBwYGKFStazCROkvQsOzs76tevz+zZs2nTpo25wzEq2brsidxaTqjVatLS0swUmWQIpVKJtbW1fAqVLFJsbCwXLlzQjtphyfRtXSafZPJBqVRa/HDakiQVLY8fP+bFF1/kzp073Lp1C0dHR3OHZBIyyUiSJBWy6OhounTpwu3btzl48GCxTTAgk4wkSVKhioyMpHPnzgQHB3PkyBHq1atn7pBMSiYZSZKkQpSQkICNjQ1HjhwpcoNd6kMmGUmSpEIQFhaGra0tvr6+nD17tsQ0RpFtOiVJkkwsJCSEtm3bMnLkSIASk2BAJhlJkiSTunfvHm3btiU5OZklS5aYO5xCJ4vLJEmSTOTOnTu0b98egICAAL1nwS3K5JOMJEmSiQQEBGBtbc2xY8dKZIIB2eNfS9/erJIkSc96/PgxpUqVAiAxMREHBwfzBmQEcj4ZSZIkCxAYGEjNmjXZtGkTQLFIMIaQSUaSJMlIrly5Qrt27fD09KRHjx7mDsciyCQjSZJkBJcuXaJ9+/b4+Phw+PBhvLy8zB2SRZBJRpIkyQhmzZqFn58fhw4dwsPDw9zhWAzZhFmSJMkAarUapVLJN998g0aj0Vb4Sxnkk4wkSZKeTpw4wfPPP8/t27dxcXGRCSYbMslIkiTpISAggBdffJGyZcvK+pdcyCQjSZJUQIcOHaJbt240a9aMvXv34uTkZO6QLJZMMpIkSQUQGxtLv379aNOmDb/++muJ7weTF1nxL0mSVAAuLi7s37+fOnXqoFKpzB2OxZNPMpIkSfmwe/du3njjDTQa
DS+88IJMMPkkk4wkSVIedu3aRb9+/Xj8+DFqtdrc4RQpFptkVq9ejZ+fHyqVikaNGnH8+PFct9+2bRv16tXDwcEBb29vRowYQVRUVCFFK0lScbVjxw5ee+01+vfvz/bt27GxsTF3SEWKRSaZnTt3MnnyZGbOnMnFixdp3bo13bp14969e9luf+LECYYNG8aoUaO4du0a33//PX/++SejR48u5MglSSpOTp48yZAhQxg8eDBbt27F2lpWYxeURQ7137RpUxo2bMiaNWu062rWrEmvXr1YsGBBlu0///xz1qxZw61bt7TrVqxYwaJFi7h//36+zimH+pck6VlqtZqtW7fy+uuvY2Vlkb/JC02xGeo/NTWV8+fP06VLF531Xbp04dSpU9nu06JFC4KDg9m7dy9CCB4+fMiuXbtyHQU1JSWF2NhYnX+SJEkA69ev59ChQyiVSkaMGFHiE4whLO7ORUZGolarKVOmjM76MmXKEBYWlu0+LVq0YNu2bQwYMABbW1vKli1LqVKlWLFiRY7nWbBgAa6urtp/FSpUMOp1SJJUNK1YsYKxY8eyf/9+c4dSLFhcksmkUCh0loUQWdZlCgwMZOLEicyaNYvz58+zb98+goKCGDduXI7HnzFjBjExMdp/+S1WkySp+Fq6dCkTJ05k2rRpLFy40NzhFAsWV4vl4eGBUqnM8tQSHh6e5ekm04IFC2jZsiXvvPMOAHXr1sXR0ZHWrVszf/58vL29s+xjZ2eHnZ2d8S9AkqQiadWqVUybNo0ZM2bwySef5PijVioYi3uSsbW1pVGjRvj7++us9/f3p0WLFtnuk5iYmKXMVKlUAhlPQJIkSXnp1KkTixcvlgnGyCzuSQZg6tSpDB06lMaNG9O8eXPWr1/PvXv3tMVfM2bMICQkhK1btwLw0ksvMWbMGNasWUPXrl0JDQ1l8uTJNGnShHLlypnzUiRJsmBCCDZs2MCAAQOoXr061atXN3dIxY5FJpkBAwYQFRXFxx9/TGhoKLVr12bv3r34+voCEBoaqtNnZvjw4cTFxbFy5UqmTZtGqVKl6NChgyxTlSQpR0IIPvjgAz777DMcHBwYNGiQuUMqliyyn4w5yH4yklRyCCGYPn06S5cuZenSpUyZMsXcIVk8fb8jLfJJRpIkyVSEEEyePJkvv/ySlStXMmHCBHOHVKzJJCNJUomiUCjw8fFh3bp1vPHGG+YOp9iTxWVP6PsoeOafUEZvuUC8GlRWMK5DZRr5uuPppKKKlxM2SotrwCdJJZJarSYgIIAOHTqYO5RCs/PMbd7bfV27PKldZZpV88TNwbbA30/FZliZombM1owEA5CsgS8P3uav+7Fcuv+YW+Hx5g1OkiQgI8GMGDGCLl26cPPmTXOHU2ieTjAAy4/eJiwmpVC/n2SSMVB8esZ/M1vVawA7ayvsrJU8Skw1V1iSJD2Rnp7OkCFD2L59O9988w1Vq1Y1d0hmVdZVVajfT7JOxkBO1hCXDplljlZASrqGlHQ1bg625gxNkkq8tLQ0Bg0axO7du/n222959dVXzR2S2YXFJBfq95N8kjHQV8Ma4pQxuAAqK5jYqTJ1K7hQv0Ipqng5mTc4SSrhEhMTCQ4OZteuXSUywSzsVVNneVK7ypR1tSvU7ydZ8f+E7CcjScVHcnIykZGRlC9fHo1GI4fqNwJZ8S9JkgQkJSXxyiuv0KVLF9LT02WCMTNZJ2OgxNR0Am6EE/IoCR83e9pW98LBVt5WSTKHhIQEXn75Zc6cOcOvv/5aoqdLPhoYwhtbL5FKxtPE3Fdq8VoT30LvViFTvIECboRz8Ho4QZGJ+AeGE3Aj3NwhSVKJFB8fT/fu3Tl79iy///57ieoPk53MBAMZrV4/+jnQLN0qZJIxUMijJFTWSip7OmJvoyTkUZK5Q5KkEunChQsEBgZy4MAB2rRpY+5wzC67Bsrm6FZRcp8ljcTHzZ7A0DhuRySQlKbGx83e3CFJUokSHx+Po6Mjbdq0
4fbt2zg7O5s7JItgS9ZEY45uFTLJGKhtdS8AnTqZoiBNreFWeDyPElMLPMSEIftKkjFFR0fTpUsXOnfuzIIFC7QJxpC60r9Dopmw/QLh8Wl4OdmwalBDaviUNuVlmMT6YfWz1MmYo1uFbML8RElrwvx3aCyXg2OwVVqRkq6mfoVS1PDO33Ubsq8kGUtkZCSdO3cmODiYgwcPUq9ePe1rv195wMHr4aislSSlqelcy4tudfI3gWHHxQe5HZWiXa7sbsehdzoZPf6iRjZhlgrkUWIqtkorvYaYMGRfSTKG8PBwOnTowIMHDzhy5IhOggHD6krD49MAcFYpdZYl/cgkU0K5OdiSkq7Wa4gJQ/aVJGNYsmQJERERHD16lNq1a2d53cfNnqQ0tV51pV5ONgDEJat1liX9yOKyJ0pacZmsk5GKosze+2lpaYSGhlKxYsVst5N1Msan73ekTDJPlLQkI0lFzb1793jllVdYs2YNzZo1M3c4JY6cflmSpGIrKChI27mybNmyZo5GKghZxiFJkkW7efMmbdu2RalUcuzYMSpVqmTukKQCkE8yBjrx9wPGbL1IkgZUCnijbSVUdrY8SkjBzcEOP09HOZ6ZJOlJCEG/fv2wt7fn8OHD+Pj4mDukImXp3mt8eeyOdnlim0pM7f58ocYg62Se0Le8seYHv5Gk+W9ZAbSqWpq70UlUdHPAy0VVoDb6kiTpunLlCp6enrKYTA+V3v8ty7o7n/XQ61iyTsZMMhOMFRmD0AnAzkaJrdIKlY2VycYzM7SFl2whJlmyK1eusGDBAjZs2ECdOnUKvL8hn++YpFS++/MedyMT8fVwoP8LFXG1z38zfUP3L25kkjGQvVVGosl8mFEAKWlqUtI1JKdpTDae2a3weG2v+7tRiQAF6nVv6P6SZCqXLl2iU6dOVKxYkaSkJOztC/73Y8jn+7s/7/H71YeorK24FhoHwJg2VfN9bkP3L27kT1cDfTWsAfZP7qJKAW+3q0TzKh68+LwXLaq407mWl0nGMzO0173stS9ZonPnztGhQwf8/Pw4dOgQpUvr1z/FkM/33chEVNZWVPFywt7GiruRiQU6t6H7G9PENpVyXS4M8knGQK1qlOP6p4Vf3+LmYMvdqES9e90bur8kGVtwcDCdOnWiZs2a7Nu3D1dXV72PZcjn29fDgWuhcdwKjycpTYOvh0OBzm3o/sY0tfvzhV7R/yyZZIqozNFUny5zLsz9JcnYfHx8+Pzzz+nfv7/BHaIN+Xz3fyFjFIGn61QKwtD9ixvZuuwJ2eNfkswjICCAe/fuMXToUHOHIuVCjsIsSVKRc+jQIbp168a2bduQv3eLJ5lkJEkyi3379tGzZ0/atm3L7t27USgU5g5JMgGZZCRJKnSHDh3ilVdeoXPnzuzevRuVSmXukCQTkUlGkqRCV6dOHd5++2127dqFnZ2ducORTEgmGUmSCs2ePXt48OABXl5efP7559jayqbzxZ1MMpIkFYodO3bQq1cvVqxYYe5QpEIkk4wkSSa3detWhgwZwpAhQ5g/f765w5EKkeyMaaDspmqtUraUHHxSkp7YuHEjo0ePZtSoUaxbtw4rK/m3UFh2/RHE9B8Dtcs9a5emRz3fQp1+RL7bBpqw/QK3o1KIT9FwOyqFCdsvaAfnC4tJ4dL9x9wKjzd3mJJkNo6OjkyYMEEmGDN4OsEA7LkajX9gOAE3wgstBvkkY6Dw+DQAnFVK4pLVhMen6QzOFxaTbJGDT0bGJ7Pu6E2CIhLx83RgbLuqeDjlrxlpYmo6ATfCCXmUhI+bvZyUTcrWiRMnaNmyJQMGDGDAgAEF3t+Q4foN+XwX92kwTDX9SE6Kz50zEy8nGwDiktXaZTcHW1LS1RY9+OS6ozfZHxjOnehE9l0LZ93Rm/neN+BGOAevhxMUmVjov4qkouGLL76gdevW7Nu3T+9jGFIiYMjnu7iXRJhq+pGcyCRjoFWDGlLZ3Q4nOysqu9tl1Ml4OVG/
QinKutpRv0Ipixx8MigiEVulFX4ejthZWxEUkf/hyEMeJaGyVlLZ07HQfxVJlm/hwoVMnTqVGTNm8OKLL+p9HEOG6zfk812cpsH4vE8tneWetUubbPqRnMgyDgPV8CnNoXc6ZV1v4ROA+Xk6cCM8nqDIBFLSNfh55n84ch83ewJD47gdkVDov4okyzZv3jxmzZrF7NmzmT17tkFDxRgyXL8hn+/iNA3Gq038eLWJn1ljkEmmiDK03Hhsu4yZ+p4us86vzF9BT9fJSFJ6ejpnzpxh3rx5fPjhhwYfz5Dh+g35fHs423IrPFa7b6da8vNtCDnU/xNFbaj/v0NjtdPLpqSrqV+hlMU/PUnFkxCCkJAQypcvj1qtRqlUmjskg3x17KZ2+uSkNA3d65Qp0dMnZ5JD/ZcwxancWCq6hBC888471K1bl8jIyCKfYMCypk/OD6FWk3TyInE/HiTp5EWEWm3ukHQYJcn8+uuvNGjQgKpVq9KnTx8OHDhgjMNKuSgKLdik4k0IwaRJk1iyZAnz5s3Dw8PD3CEZha+HA0lpGouYPjkv8XsCuNuwHw96TSR87Fwe9JrI3Yb9iN8TYO7QtAyqk1m2bBkNGzZk+vTp7Nq1i6pVq3L58mU+/fRTbt++zbhx44wVp8U6808oY7ZeID4dHK1hVs+aeLs7EZ+cjpPKGk8nlUna2cvpk4uu7PoZ2SitstSxpak1FtsfSaPRMGHCBNauXcu6det44403zB2S0RSV6ZPj9wTwcOSH8EyFR9qDCMJGfMjHHXtzrFJ17frBjXzwK+Nc6J8lg+pkVq5cyeXLl9m2bRuVKlWiWrVq1K5dmxo1avDpp58SGBhYZCYi0re8sc6HvxGX/t+yrRVM6lSd2xHxVPZyxN3RrsjUlwi1muQzf5H+MArrMu6omtVFUQyKPyzN71cecPB6OCprJUlpajrX8sLPwylLHVtQZHyW7brVKWfu8AG4ffs2jRs3ZsmSJYwYMcLc4ZQ4Qq3mbsN+qB9EZPu6Boh0dGZw/zfRPDXKwuAmFfX+LOn7HWlQKnvrrbcAuHXrFuvXryctLY0rV67w119/ERoaSq1atXBycuLPP/805DQWLf5JglEqQC0gVQN21laobJSorJVFpr4kfk8AkTOX63xoleU88fhkEk4925oxsuLn6X5GtyMSCHmURCkH2yyjRGS3nbmlp6ejVqupXLkyt27dws3NzdwhlUjJZ/7KMcFARj2IV0IcdR7e57K3r3a9OT5LRinD+fLLL+nXrx8rV64kOjqahIQEateuzfXr1zl48KBex1y9ejV+fn6oVCoaNWrE8ePHc90+JSWFmTNn4uvri52dHVWqVGHjxo16nbsgnJ6kafWT50FbK0hJ15CUpiY5XV0k6ksyH7uf/dCqQyN4OPJDiyrfLQ583OxJSlPr9DPKro4tu+3MKT09naFDhzJo0CCEEDLBmIk6Ooa4HXvztW3pxASdZXN8lvR6ktFoNJw8eZKIiAhatGhB7dq1OX36NPv27ePy5ctUrFhR207e1dW1wMffuXMnkydPZvXq1bRs2ZJ169bRrVs3AgMDqVgx+/LR/v378/DhQzZs2EDVqlUJDw8nPT09222N6athDbOtk6ns6aBTJ2OphFpN5MzlWcp1M14EFBD54Zc4dmsli86MJLt+Rpl1dk/XyVR0d8iynbmkpaUxcOBAfv75Z7799tsiUwxeXAi1mqSAc8Ru30vC78chNS1f+0U7OGr/f3AjH/w8HAr9s6RXnYyHhwcpKSmoVCoeP37MkCFDWLlyJY6OjnnvnA9NmzalYcOGrFmzRruuZs2a9OrViwULFmTZft++fbz22mvcvn2b0qVL63XOotZPxliSTl7kQa+JeW5XbveX2LdsUAgRSZYmJSWFAQMGsHfvXr7//nteeeUVc4dUYqQFhRD37e/Efvs76gf/jRFoU7sq6vthaGLjs/+BqABlOS98z39ntB+HhdpP5vvvvycuLo6IiAj+/PNP
7ty5wwsvvEBYWJg+h9ORmprK+fPn6dKli876Ll26cOrUqWz3+eWXX2jcuDGLFi3Cx8eH5557junTp5OUlHO5Y0pKCrGxsTr/SqL0h1FG3U4qfrZv387vv//OTz/9JBNMIdAkJhP33T5Cek3kXpPXeLR0C+oH4ViVcsZ1dF/KH9pAxSOb8Fz2fsYOzz5UPln2mD/RIkof9Coua9++vfb/69evz+HDh3n33Xdp1aoVx48fx9vbW++AIiMjUavVlClTRmd9mTJlckxit2/f5sSJE6hUKn766SciIyMZP3480dHROdbLLFiwgLlz5+odZ3FhXcbdqNtJxYcQAoVCwfDhw2natCm1atXKeydJL0IIUi4EErd9L3E/HkTEP+kAqlBg3+4FXAb1wOHFllip7LT7OPVsCxvnZ9NgxwuP+RMtpsGOURpKJyQkMGrUKK5du0aXLl24cuWKwcd8tsw38wOfHY1Gg0KhYNu2bdo6oKVLl/Lqq6+yatUq7O2zVnLNmDGDqVOnapdjY2OpUKGCwXEXNapmdVGW88y5pcqTx25Vs7qFG5hkVgkJCfTt25cxY8bQt29fmWBMJD08mvhdB4jd/htpN+5o11tXKofLa91xfu1FrH3K5Li/U8+2OHZrZdFdD/RKMiNGjCAkJITg4GBCQkKIj8+Yb0EIgUqVv4mBcuLh4YFSqczy1BIeHp7l6SaTt7c3Pj4+Oo0MatasiRCC4OBgqlWrlmUfOzs77OzssqwvaRRKJR6fTOLhiBwGNBTgMe9ti/rQSqYVFxdHz549OX/+PB988IG5wyl2RHo6iYfOErf9NxIOnIL0jGFgFPZ2OL7UDpdBPVA1r4cin7OIKpRKi64v1SvJPHz4EF9fX1q0aIGPj4/OP0OHlrC1taVRo0b4+/vTu3dv7Xp/f/8cy4NbtmzJ999/T3x8PE5OGS25/vnnH6ysrChfvrxB8ZQEds/nPvifSE4ppEgkc4uNjaVbt25cuXKFAwcO0KJFC3OHVGyk/nuXuB17idu5D3V4tHa9XaNauAzqgWOvDihdLLclqr70SjJ79+avjba+pk6dytChQ2ncuDHNmzdn/fr13Lt3TztMzYwZMwgJCWHr1q0ADBo0iHnz5jFixAjmzp1LZGQk77zzDiNHjsy2qEzSFbN5NwCq9k0oPWmI9rE76cxlHn22gcgPV+DQvglKD9kvorh78803uXbtGgcPHqRJkybmDqfI08QnEr/7MHHbfyP5z6va9VYepXDu/yIuA7tjW8O8872YmmUMhPSMAQMGEBUVxccff0xoaCi1a9dm7969+Ppm9FwNDQ3l3r172u2dnJzw9/fn7bffpnHjxri7u9O/f3/mz59vrksoMjSJycRt2wNAqdF9dB67VU3qkPDrUVKv3SJy5peUWTfbTFFKhWXhwoVMmzaNhg0bmjuUIksIQfKZv4jb/hvxvxxBJCZnvKBU4tC5OS6DuuPQqTkKG4v8+jU6OZ/MEyW1n0zsN3uImLIQa19vKp7dkaXuJfnS34R0HQsaDWW3LcSxiyw+KW4iIyOZNGkSy5cvLzYjKZtDemgEcTv3Ebd9L2lBwdr1NlUr4jy4B879uhbpVppmGbtMKtqEEMRs+BEAl+G9sq3cV9Wvgeu4/sSs/pbId5dg36I+Vk6WO/S5VDDh4eF07NiR8PBwwsPDZZIpIJGSSsL+k8Rt30vikT9AowFA4WiPU++OuAzqgV3j50v0CAkyyZRgKX9eJfXqvyhUtrgM6pHjdqXfG0XC3mOk33lA1Lx1eC6cUohRSqYSGhpKx44defToEUePHqVmzZrmDsns8jsSecq1mxl9WnYdQBMdo12val4P50E9cHqpHVaOsj4YZJIp0WI2/gSAU+9OKEvnPMaclYMKzyXvENp3CrGbfsKpT0fsm8p+M0VZSkoKHTp0IC4ujoCAAJ577jlzh2R2eY1Ern4cR/yPB4nb/hspl2/8t01ZD5xf64bza92wrVLy+trlRdbJPFHS6mTS
w6O5W78vpKVT/uD/sKtXPc99wicuIG7HXmyq+VLhyEYUdpY9urSUu+3bt9O0aVOqVKli7lDMLqcJwFAAAlRN65Jy6W9EypNpO2yscXyxFc6DeuDQ/oUS0Y+sUMcuk4q+uK9/hbR07Bo/n68EA+D+8VsoPUuT9u9dHn2x1cQRSqYQFBTEl19+CWQ0/ZcJJh8jkQPJZ/9CpKRiW6sy7vMnUunKT5TdOA/HTs1KRIIxhEwyJZBITydmy88AuI7sncfW/1GWcsbjs8kAPFr+DSmBt0wRnmQiN2/epG3btnz55ZcldkDY7OQ1AVgmj8+nU/7oZkqN7YfSvZTpAysmZJIpgRJ+P4E6NAIrj1I4vdw+7x2e4vhSOxy7t4Z0NRFTFiLUahNFKRnTjRs3aNu2LQ4ODgQEBJSIIuH8yu8I41bOjiW6lZi+ZJIpgTIr/F2GvFTgehWFQoHHwqlYOTuScuE6MV/tMkWIkhHdvHmTdu3aUapUKY4ePYqPj4+5Q7IoSo9S+dquKPdxMad8ty57esTivCxdulSvYCTTS70RRPKJC2Blhcvr+s0NYl3WA/c544mYtpjoBf/DsVtrbHzLGTlSyVjKlStHr169mDt3Ll5e5ptd0xKl3X1A9IL/5b6RHIncIPlOMhcvXtRZPn/+PGq1murVMyqN//nnH5RKJY0aNTJuhJJRxWzIeIpxfLElNuVzHkI8L85DehL3gz/Jpy4RMf1zvL9bIosSLMylS5ewtramdu3aOrPMShkdkeO/20/E+18g4hNRqOwyBoJ90ppMy8ImACuK8l1cduTIEe2/l156iXbt2hEcHMyFCxe4cOEC9+/fp3379vTokXOnPsm8NHEJxH23DwCXUX0MOpbCygrPpe+iUNmSdPRP4nbuM0aIkpGcO3eODh068P7775s7FIujfhTLwzFzCH/rE0R8Iqqmdalw8mvKbJqP0ttTZ1tlOS/KbJxvMROAFUV69ZPx8fHhwIEDPP/88zrrr169SpcuXXjw4IHRAiwsJaGfTMyGH4l8/4uMfi4nvzbKk8ejL7cRPW8tVqWcqXDyG6y9ShshUskQZ86coWvXrtSqVYt9+/bpzLNU0iUeP0/4hE9Qh0aAtZLS746k1MTB2qeU/Pb4L4kKtZ9MbGwsDx8+zLI+PDycuLg4fQ4pmZgQgpiNGeOUuY7oZbSirVJvDsC2djU0j+OImrncKMeU9Hfy5Em6dOlC3bp12b9/v0wwT4iUVCLnrCK07xTUoRHYVKmAz941uE0ZppNEMicAc+7TCfuWDWSCMQK9kkzv3r0ZMWIEu3btIjg4mODgYHbt2sWoUaPo08ewYhjJNJJOXCDtn7soHO1xfq2b0Y6rsLHGa9l7oFQSv/swCftPGu3YUsFpNBratGnD77//XmyfyAsq9UYQwS+OI2bVtyAELsNepvyhDagayLHaCoNeSWbt2rX06NGDIUOG4Ovri6+vL4MHD6Zbt26sXr3a2DFKRhD7pMLfuX9XrJwdjXpsu3rVKTV+AAAR7yxBHRtv1ONLebt48SLp6em0bt2aPXv2aGeILcmEEMT87weCO40m9eq/WLm7Unbrp3gueUcOXlmI9EoyDg4OrF69mqioKC5evMiFCxeIjo5m9erVODoa9wtMMlx6yEMSfj8OgOtI0zxpur0zEutKPqhDI4iet9Yk55Cyt2/fPlq0aMGyZcvMHYrFSH8YRdjAd4mcsQyRnIp9h6ZUCNiCY7fW5g6txDGoM6ajoyN169alXr16MrlYsJgtv4BGg6plA5NN9Wplb4fXF+8CELv5Z5JOXzbJeSRde/bs4ZVXXqFz5868/fbb5g7HIiTsO8H9tq+TeOgMCpUtHgsm4/3tYtmZ0kz0TjLHjx9nyJAhNG/enJCQEAC+/vprTpw4YbTgJMOJlFTivvkVKNg4Zfqwb9UQ5yE9AYiYshBNcopJz1fS7d69mz59+tCjRw927dqFnZ2duUMyK01CEhHTFhM2dAaaqBhsn69Kef//
4Tq6r+zDZUZ6JZkffviBrl27Ym9vz8WLF0lJyfgyiYuL49NPPzVqgJJh4n89ijriEUpvz0IpKnCfMx5lGXfSbt3n0ZItJj9fSXbw4EF69+7Nzp07sbUt2dMuJF/6m+COo4jd+gsoFLhOeI3y+9eZ7Mldyj+9+sk0aNCAKVOmMGzYMJydnbl8+TKVK1fm0qVLvPjii4SFhZkiVpPStw14cHQ8c36+yp2oJCq52zPnldqUL23cStc0tYZb4fE8SkzFzcGWKl5O2Cjz9/sguPubpPx5Fbf3R1F62nCjxpWT+N+O8XD4TLBWUt7/f9jVrmr0cySmphNwI5yQR0n4uNnTtroXDrYlYw6+sLAwypYti0ajQaPRYG2t33Xr87ky5LNoCkKt5vHybUQv3gjpapTennitmolDa8sdeeRSUDhvfH2ex8kaSqmsWD+0EfX9LH+4n0LtJ3Pjxg3atGmTZb2LiwuPHz/W55BF1pyfr3Iq6DEPY1M4efsxc36+avRz3AqP53JwDGExKVy6/5hb4flrvZVy+QYpf14FG2tchrxk9Lhy4tSjDY492v43UnN6utHPEXAjnIPXwwmKTMQ/MJyAG+FGP4cl2rp1q/YHnZWVld4JBvT7XOn7WTSFtHuhPHhlItELvoJ0NY6vdKBCwGaLTjAAb3x9nvBEDakaCE/U8MbX580dkknplWS8vb25efNmlvUnTpygcuXKBgdVlNyJSkIJeJdSYa3IWDa2R4mp2CqtKOuqws5ayaPE1Hztp51e+aV2hV7p6fHZZKxcnUi59Dcx640/UnPIoyRU1koqezpib6Mk5JHx77ul2bhxI8OHD2fw4MHUrWv4YI36fK70/SwakxCCuO/3E9xuBMln/0Lh5IDXypmU+WoOSjfL7xv0OFkDgJ1SobNcXOmVZMaOHcukSZM4e/YsCoWCBw8esG3bNqZPn8748eONHaNFq+RuT7qA0MfJpIuMZWNzc7AlJV1NWEwyKelq3BzyLn9XP4ol/kd/AFxMXOGfnYyRmicAEP3Z/0gLCjHq8X3c7ElKU3M7IoGkNDU+bsW738O6desYNWoU48aNY926dVhZGV5Epc/nSp99jEkdE0f42LmEj5+PJi4B1Qu1qXB0E84DXiwylfulVBnvXYpa6CwXV3o9a7/77rvExMTQvn17kpOTadOmDXZ2dkyfPp233nrL2DFatDmv1M5SJ2NsVbwy6nieLgfPS9yOvYjkVGyfr4qqSR2jx5QfzoN7EP+jP0nHLxAxbRHePywz2hdB2+oZZdhP18kUVwkJCSxYsIBJkybxxRdfGO0e6vO50mcfY0k6eZHwCfNJDwkHpZLS74yg1KTBKAwoMjSH9UMbZamTKc70qvi/d+8e5cuXJzk5mcDAQDQaDbVq1cLR0ZH79+9TsWJFU8RqUvpWat0Me8zkby/yIDaVci62LHutAVXLlsrXvqaqvBYaDfeaDiT9zgM8l76Ly9DCq495VtrtYO63fR2RnIrn8vdxGSRH6c5vY5E0tYbrwdEkpINIiqVxDV9srUveWFoiNY3ozzbweOV2EAIbv/J4rfkQVaPn8965mDgaGMIbWy+RCtgC64fVp12twp18rlAr/v38/IiMjMTBwYHGjRvTpEkTnJyciI6Oxs+vZDUZnPztRa4/TCQ2KZ3rDxOZ/O3FvHd6wlSV14mHzpJ+5wFWrk449e1slGPqy6ZyedzeGwVA1KyV+Z7qtjjLb2OR92fNo0/3Ttx9+Jh7idbcjkgo5EjNL/WfOwS/OJbHK7aBEDgP6Un5wxtKVIIBtAkGIPXJclGhV5LJ6eEnPj4elUplUEBFzYPYVBRAKQcbFE+W88tUldexGzJGW3Ye2B0rB/O/H6XG9ceuXnU0MfFEzlhm7nDMLj+NRebNm8fST+fQqGV7KnqVMlslu7lkjBr+U8a4Y1f+xaq0K2U2f4LXF+9h5eRg7vAKXeY7r3hmuSgoUNlM5hTMCoWCWbNm4eDw35utVqs5e/Ys9evX
N2qAlq6ciy0xSek8TkxDPFnOLx83ewJD44xaeZ0WFELi4bMAuI4o/Ar/7CisrfH84j2CO48h4dejJOw9hmP3rE3gS4pK7vaExKRk21hECMHs2bOZN28eE9/9kOZ9x/AwNsUslezmkh4eTcTkz0j0Pw2AfbsX8FrxAdZlPcwcmfnYkpFYxFPLRUWBkkzmFMxCCK5cuaLTy9jW1pZ69eoxffp040Zo4Za91iBLnUx+maLyOmbzbhAC+w5Nsalc3uDjGYtdnWqUmvAaj7/cRsS7S1G1bIDS1dncYZlFbo1Fzp49y7x581i4cCFTpk3P0vGxuEs4cJLwSZ+hiXyMws6W0rPexHV0HxRGaE1XlK0fVj9LnUxRUaAkc+TIEQBGjBjB8uXL5XwVQLpaQ1KamrQn/01X57/Nu4OtNd3qlDNaLJrEZOK27QHAdZRlPMU8zW36CBL2BJB2O5joj9fiueQdc4dUaCLjk1l39CZBEYn4eTrwWb/6eDj9V5QphEChUFC3YWNW7NyPqmwVDgaG0ba6FzW8i//fmSYxmag5q4jdtBsA2+er4LVmFnY1S1a/O8h+RIB2tXz457PCreg3Fr1+HmzatEkmmCcmbL/A7agU4lM03I5KYcL2C2aLJf7Hg2hi4rH29cahYzOzxZETK3s7PL94D4DYrb+QdDL/jSSKunVHb7I/MJw70YnsuxbOuqP/dWYWQjBp0iSWLVtGwI1wgvAqUSMZpFy+kTHu2JME4/rmAHz2rSuRCQaK34gAeiWZBQsWsHHjxizrN27cyMKFCw0OqigJj08DwFml1FkubEIIYp5U+LsM72Wx08bat6iPy+uvABAxZRGapJIxUnNQRCK2Siv8PByxs7YiKCIRyJjJcvz48axYsQIHB4cSNZKBUKt5tPwbgl8cS9rNeyjLeuC96ws8Pn4LK1XJHVG6uI0IoFeSWbduHTVq1Miy/vnnn2ft2pI1YVVplTUCiE1WI54sQ0Yfh79DYzl9K5K/Q2NJK0Axmj5S/rxK6tV/UahsLb4vSulZ41CW9SAtKJhHizfle7/E1HR+v/KA/x27xe9XHpCYavwx0UzFz9OBlHQNQZEJpKRr8PN0QK1WM2bMGNatW8fGjRt54403LGokA2Pdb6FWk3TyInE/HiTp5EWEWk1a8EMe9JlM9Px1GeOO9WybMe5Y28ZGvoqi4eq9SFp9up+aH/5G6pOviuIyIoBePf/CwsLw9vbOst7T05PQ0FCDgypKqno5cjcmVWcZ/htI0FZpxd2ojF+tpixb145T1rsTytKuJjuPMShdnPBcNJWwYR/wePW3OPXqgF3d5/LcL7NfkcpaSWBoHIBR67RMaWy7jJGoM+tkxraryqJFi9i8eTNbt25lyJAhgGWNZGCM+x2/J4DImctRP4jQrrNyc0EkpyKSklE42uOxYDLOr3UrMsPCmMK4b84THKubxG2tKBYjAuiVZCpUqMDJkyezdLw8efIk5coVjT96Y3mUpMHLyZYKpR24H53Io6SMnyFPDyQYFpNs0j4O6eHRxP+S0SjDdZRpplc2NsdurXF8uT0JvxwhfMpCyu9fl+fwIE8XJd2OSChSRUkeTipm9tQdcmj8+PHUr1+fbt26adcZuzGIIQy93/F7Ang48sP/2t0+oXkUC4B1lQqU27EYG7+iWaFtTFGJGQnGwcaKxDQN9tZwfb5ll0jkl17PYaNHj2by5Mls2rSJu3fvcvfuXTZu3MiUKVMYM2aMsWO0aH6e9qSpBfejE0lTC/w8M4o3CnMgwbivf4W0dOwa1cKuXnWTncfYPBZMxqqUM6l//cPjNTvz3N6SipL0lZaWxoQJE7h58yaurq46CcbSGHK/hVpN5MzlWRKMzjZJyVhXLGuESIs+d4eMH1iJaRqd5eJA7wEyo6OjGT9+PKmpGb/QVSoV7733HjNmzDBqgJbunRdrAtcJikjCz9P+yXLhDSQo0tOJ2foLUHSeYjJZe5XG/eO3iJi4gEeLNuLYvQ22
VSrkuL0lFSXpIyUlhQEDBrB371569OhB1arGn8zNmAy538ln/tIpIsuO+kEEyWf+wr5l/vuWFVdrhzRi3DfniUpMx93BmrVDinYR2dP0GiAzU3x8PNevX8fe3p5q1aoV6TnG9R38zdzi9wTwcMSHWHmUotKlH1DYFaW+wBmt4kL7TSUp4Byqlg0o99PyYlk2n5yczKuvvsrBgwf58ccf6d69u7lDMhlNcgqRM5YR982ePLf1Wjcb5z6dCiEqyVCFOkBmJicnJ1544QVq165dpBNMUaZttjzkpSKXYCBjiCLPz99B4aAi+eTFfH0xFUUDBw7k0KFD/PLLL8U2waSHRxO9cAN3G7ya7/exsCfTkwpfvp9kpk6dyrx583B0dNSOYZaTpUuXGiW4wlQUn2RSbwRxv9UwsLKi4vnvsClfxtwh6e3xmm+JmrUKKxcnKpz8utiNU7V//35sbGzo0KGDuUMxupTAW8Ss/Y64H/whNaOfmLJ8GURsPJq4hOzrZRSgLOeF7/nvLLZPl6RL3+/IfNfJXLx4kbS0NO3/S+YXs3E3AI4vtizSCQbA9Y1+xP90mJSL14l8/wvKbv7E3CEZLD4+nrVr1zJ16lS6du1q7nCMSmg0JB46S8y670gKOKddb/dCbUqN7Y9jj9Yk7DuZ0bpMgW6ieVIa6jF/okwwJYBBdTLFSVF7ktHEJXCnTm9EQhLeP3yBQ5ui34kt5dpNgjuNhnQ1ZTbNx6lnW3OHpLfY2Fi6devGlStXOHfuHM89l3c/oKJAk5hM3Hf7iFn3PWk372WsVCpx7NmWUuP6o2qsO89Ldv1klD5eeMyfWKTf35LI5E8yeRWRZVIoFCxZsiTfAUj6iftuPyIhCZtqvti3Lh4tUeyer0qptwfz+IutRL63FPtWDVGWKnojNT9+/JgXX3yRv//+m4MHDxaLBJMeFknMhh+J3fKztp+LlbMjzkNfwnV0X2wqZN8U2alnWxy7tSL5zF+kP4zCuow7qmZ15RNMCVKg4rKnnT9/HrVaTfXqGf0y/vnnH5RKJY0aFY8vPEuWMaFTRoW/64hexao1ltvUYST8epS0m/eImrMKr2XvmzukAomJiaFTp04EBQVx+PBhGjZsaO6QDJLy1z88Xvcd8T8dgrSMDoPWvt64vtEfl0Hd8zWBmEKplM2US7B8J5nMYf4ho2Lf2dmZLVu24ObmBsCjR48YMWIErVu3Nn6Uko6kExdI++cuCkd7nF+z3M58+rBSZYzU/OClCcRt+w2nvp1xKEJPak5OTjRv3pwNGzZQr149c4ejF6HRkHjgFI/X7CT51CXtelWzeriO64/jiy3lk4iUb3rVyfj4+HDgwAGef163/PXq1at06dKFBw8eGC3AwlKU6mTChn9Iwm8BuIzoheeiaeYOxyQi3l1C7KbdWFcqR4WALRYxjXRuwsPD+eeff2jVqpW5Q9GbJj6RuG9/J2b9LtKCgjNWWitxeqUDruP6o6qfdVBcqeQweZ3Msyd7+PBhliQTHh5OXFycPoeU8ik95CEJvx8HwHVk0erhXxDuH40jYf8p0u884NHijbjPHm/ukHIUGhpKx44dSUtLIzAwEBsbG3OHVCDpIQ8z6lu2/oImJh4AK1cnXIa9guvoPliXK1ojK0iWRa8k07t3b0aMGMGSJUto1ixjcqwzZ87wzjvv0KdP8f3iswQxW34BjQZVywbY1vDLe4ciysrZEc/FUwkb/D6PV+/EqVdHixyXLSQkhA4dOpCQkMCRI0csKsEItTrXCvfki9eJWfsd8T8fAbUaAJvK5XEd2x/nAS9i5Vj0xoaTLI9eSWbt2rVMnz6dIUOGaPvOWFtbM2rUKBYvXmzUAKX/iJRU4r75FQDXkZY3vbKxOXZpiVPvjsT/dIjwSZ9R3v8rFDaWM3DgvXv36NChA2lpaQQEBFClShVzh6SVbdPhcp54zHsbrKyIWbOT5D+uaF9TtWpIqXH9cejcHIVV0Z6/RLIsen2aHBwcWL16
NVFRUVy8eJELFy4QHR3N6tWrcXR0NEpgq1evxs/PD5VKRaNGjTh+/Hi+9jt58iTW1tbUr1/fKHFYkvhfj6KOeISyrAeO3UpGAwv3TyZh5eZC6rWbPF61w9zh6EhKSsLLy8siE8zDkR9mGaBS/SCCh6Nm8XDEhxkJxsYap/4vUv7wRnx+Wo5j15YywUhGZ9AnytHRkbp161KvXj2jJReAnTt3MnnyZGbOnMnFixdp3bo13bp14969e7nuFxMTw7Bhw+jYsaPRYrEkmROTuQx/xaJ+0ZuStadbxq9v4NHnm0m9lftnoDAEBQURFxdH9erVOXnyJJUqVTJ3SFr5GWIfhQLXSUPwvfA9ZVbNxK5OtUKLTyp59E4yx48fZ8iQITRv3pyQkBAAvv76a06cOGFwUEuXLmXUqFGMHj2amjVrsmzZMipUqMCaNWty3W/s2LEMGjSI5s2bGxyDpUm5fIOUP6+CjTUuQ14ydziFyql/V+zbvYBISSViyiKExnxznt+4cYNWrVoxefJkAIvro5SfIfYRAsf2TYrd+HCSZdIryfzwww907doVe3t7Ll68SEpKCgBxcXF8+umnBgWUmprK+fPn6dKli876Ll26cOrUqRz327RpE7du3WL27Nn5Ok9KSgqxsbE6/yyZdnrlnm1L3Mi1CoUCzyVPRmo+fZnYr381SxyBgYG0a9eOUqVK8cknljm2WvrDKKNuJ0mG0ivJzJ8/n7Vr1/LVV1/ptKZp0aIFFy5cMCigyMhI1Go1ZcroDvhYpkwZwsLCst3n33//5f3332fbtm1Y5zGFb6YFCxbg6uqq/VehQs6TZZmb+lEs8T/6A+BSxCYmMxabit6UnpEx62r03DWkh+bxa93Irly5Qrt27fD09OTIkSOULWuZMzrm9wdISfuhIpmPXgX7N27coE2bNlnWu7i48PjxY0NjArIWQwghsi2aUKvVDBo0iLlz5xZojKgZM2bojMcWGxurV6IJjo5nzs9XuROVRCV3e+a8UpvypY07C2bcjr2I5FRsn6+Kqkkdox67KEhMTSfgRjgh1RvQrGZV7K/fJOLdpZTZNI+Us1eMNiaW9jxPzQTpYJvxJ3L48GHKly+Pv78/7u6W+wVtU7UCWCshXZ39Bk+G2Fc1q5vlpdyu31gK4xyWLDI+mXVHbxIUkYifpwNj21XFw8myOxobSq9319vbm5s3b2ap8Dxx4gSVK1c2KCAPDw+USmWWp5bw8PAsTzeQUUR37tw5Ll68yFtvvQWARqNBCIG1tTUHDhzIdg4POzs7o0y0NuGbc1x+kADAzchEIr45x88T2xl83ExCoyFmU0ZRmeuo3hZXB1AYAm6Ec/B6OCprJWG9+/Hav4tI3HeCu7Ve0Q7WCE+a6H4yqcCj+6apNdwKj2fftQdcC47Fw0VFYGhGp+Im5exwd3dn0qRJjB07FpUq6xdC5v5PT7Vto8y7kODvkGgmbL/Aw/g0HJTQs355XvBzz/LFGxaTyOJ9ulN8l3XNOmaYOuoxof2n5ZpgIOch9p++z5nX361OuTyvoyAK4xzZ0fc9MrZ1R2+yPzAcW6UVN8IzOr7O7Fm7wMe5ExHLe7suc/9RChXc7Fj4aj0qeVrmSCV63eWxY8cyadIkzp49i0Kh4MGDB2zbto3p06czfrxhPbNtbW1p1KgR/v7+Ouv9/f1p0aJFlu1dXFy4cuUKly5d0v4bN24c1atX59KlSzRt2tSgePKSmWByWjZU4qGzpN95gJWrE059Ohv12EVFyKMkVNZKKns6kuRbgfAXMgadfDrBAKhDI3g48kPi9wQU6Pi3wuO5HBzD7YhEYpLTcVZZY2+j5Njxk1StWpWffspI8tklmKf3D4tJ4dL9x9x68uWRlwnbL3A7KoX4FA3hiRq+/eMe/oHhBNwI19lu8b7rHLkRRfDjJA7/HcXifdezHEsd9ZgHfSeTGngbZRl33BdMRlnOU2cbZTkvymzMeQqFp++zvY2SkEdJ+bqOgiiMc2RH
3/fI2IIiErFVWuHn4YidtRVBEYl6Hee9XZe5GBxHVEIqF4LjeG/XZSNHajx6Pcm8++67xMTE0L59e5KTk2nTpg12dnZMnz5d+zRhiKlTpzJ06FAaN25M8+bNWb9+Pffu3WPcuHFARlFXSEgIW7duxcrKitq1dX8JeHl5oVKpsqwvimKfjLbsPLB7ie2B7eNmT2BoHLcjEkhOSaP0v7ey31AACoj88Escu7XKd9HZo8RUbJVWVPZ05MHjZG6HJxB39yq/LnqbRg0b5NkkPnP/sq4qwmKSeZSYmq/zhsdndGS2UypIUQvSNWT7xRsUkYS1lYIKpR24H51IUITu6+roGB70nULqtVsovUpTbveX2FatiOuIXgUaYv/p+5yUpsbHzfift8I4R3b0fY+Mzc/TgRvh8QRFJpCSrsHPM+9RrLNz/1EKCsDT2Y6IuBTuP0oxbqBGpHdh6CeffMLMmTMJDAxEo9FQq1YtnJyMUxcxYMAAoqKi+PjjjwkNDaV27drs3bsXX19fIGOsqLz6zBQWGyDtmWVjSQsKIfHQWQBcRxT/Hv45aVs9Y+yskEdJVL17G5vI6Jw3FqAOCSf5zF/5Hl7ezcGWu1GJeLs4UL1MCg9vXOCXhW/RpMkL/LZnT56f68z9w2KSSUlX4+Zgm6/zejnZEJ+SQoo6o1OLtRXZfvH6edoTFJXI/ehE0tQCP8//Xlc/in2SYG6i9PwvwUDBh9h/+j5n1pcYW2GcIzv6vkfGNrZdVQCdOhl9VHCzIyIhlYi4FDRPli1VgUdhTktLo0uXLqxbt65YTMaUSd8RRnf9EcT0HwO1y5/3qcWrTYwzpljk7FXErP4W+w5NKbfzc6Mcs6iL+/Eg4WPn5rmd17rZOPfplK9jPl1eX8rehtH9euDgYM/PP/+Mg0PevzSzK++Pik/Osx7l75Boxn1zjnuP0lAANcraM7xFVXrUL5evOpmMBDOZ1Cv/ovR0w+uHZdwr5WH2egdLZCl1MvmVV7zmqJMptFGYbWxsuHr1aomsgM7O9+eDsVUqUACaJ8vGSDKaxGTitv8GZFT4SxlM0UTXRmlFDW8X0tPTsba25tdff8He3h57+/wV5WTu/7TMehRrKwVBUYnAdZYM0J0Xp4ZPaTo/X05bERybouGf8Fj62VbU2a6sq0OWfdWP4wjtN5XUK/9i5VGKcj8t57aLO5eDY7BVWnE3KqOs/9m4Sqrs3iNLllmHlNN7WcnThZ1vFo2hpfRK5cOGDWPDhg3GjqVIerps1OrJsjHE/3gQzeM4rH29cejYzCjHLA5UzepmVGjn8Rsn4fBZNPH5r1Tds2cPderUITQ0lNKlS+c7weTk6XoUG6UiSz3Kf9sVvCJYHZORYFIu38hIMD8ux7a6n069g5210mz1DpLhitN7qVedTGpqKv/73//w9/encePGWcYtW7p0qVGCKwpMUTb69PTKLsN7yVkIn6JQKvH4ZBIPR36YkWhyKOyN+XIb8d/+TukZY3Ae2C3Xe/jTTz8xYMAAevbsabQ+MLnVo+huV7CK4IwEM42US39j5e5KuR+XY1czo9uApdQ7SIYrTu+lXjNjtm/fPucDKhQcPnzYoKDMQd/yRlOUjSb/eZWQ7m+iUNnie/lHlKVdDTpecZTtUPY+XnjMexuF0oqoOWu0szvaPl8V93lvZTuN8/fff8+gQYPo3bs327ZtM9p8MPnt21KQznnq2PiMJ5gL17Eq7Uq5n5ZjV+u/0Z+LWr2DlDNLfC/1/Y7UK8k8LXP3ol5HY0nTLz8c9zHxP/jjPLA7Xl/OMGssliy3SblEahoxG3/k0eebtbM9OrzYCvc5b2JbJaPOIywsjMqVK9O7d2+2bNmS7yGJzEETl8CDflNJOR+YkWB+WIZdbf1aJkmSPgo9yWzYsIEvvviCf//9F4Bq1aoxefJkRo8erc/hzE7fGxiTlMp3f97jbmQivh4O9H+hIq72+j/apodHc7d+X0hLp/zB/+U4G6Ql/tIxJmPd
V3XUY6IXbyJ2888Zsz9aK3Ed2QflW4PY9e9jTp/5g6YvNOK1Zn4GvW+mpIlL4MGA6aT8eRUrN5eMBFNIw/OXxGFQCkNR/PsttNZlAB999BFffPEFb7/9tnZY/dOnTzNlyhTu3LnD/Pnz9TlskfTdn/f4/epDVNZWXHsyTMaYNvr/woz7Zg+kpWPXqFau0w3n1fqkqDPWfVW6l8Lzsym4juxN1OxVJB48w1dfruD6iiXUfWU8qS1asi8wEiul0qD3zVQ08YmEZiaYUs6U2/VFoc7/YqxhUCRdxf3v92l6pc41a9bw1VdfsWDBAl5++WVefvllFixYwPr161m7dq2xY7RodyMTUVlbUcXLCXsbK+5G6jdMBIBITydmy88AuOYx2nJxan2SHWPeVwDb5yrhvWMxvw1pyQfxN1GnpfLK778yddUX1L/5N3cjjDsckDFkJpjkpxNM3cLtm2asYVAkXcX97/dpeiUZtVpN48aNs6xv1KgR6enpBgdVlPh6OJCUlvHom5SmwddDv2EiABL2nUT9IBwrj1I4vZxz4wrIaH2Skq4uFq1PsmPM+5ppxYoVTFq2kIlvv83AqXNJdHTC9WE4Q77ZzCurVpNy7aYRIjcOTXwioa+9Q/IfV7BydcpIMLk82ZqKn6cDKekag4dBkXQV97/fp+lVXDZkyBDWrFmTpany+vXrGTx4sFECKyr6v5BRifx03YG+YjY8abY8uCcKu9w/dFW8MoY6ebpMtzgx5n0F+PXXX5k4cSLTp09n0aJFxCan8UPr5rh98xO1Dx2h9JVAgjuMwnlQd0q/P9qs861o4hMJHfguyWf/wsrFCW8zJRgw3jAokq7i/vf7NL0q/t9++222bt1KhQoVaNYso6PgmTNnuH//PsOGDdNpBlpU+szoW6mV0/wYBa3YS70RxP1Ww8DKiornv8OmfJlcj19cmaqiOS0tje+++44evV/l+3P3dZKXQ3gkUfPWkfBzRtN7haM9bpOH4jq2P1b2hTsmlCYhidCB75B8+jJWzo54//AFqgY19T6eqSuYi2IFtqFK2t9kpkKt+L969SoNG2YMt37rVsaIuJ6ennh6enL16lXtdkW9WXN+5DQ/RkEr9mI27gbA8cWW2gST2/GLK2NXNC9cuJA2bdrQvHlzBg8ezFfHbmbboKDs/+aSNKYvUR+tJOXidaI/WU/s1l8o/dFYnHp1LJTPsiYhidDB7/2XYL5fYlCCAdNXMJekCuxMJe1v0lB6JZkjR44YO44i6+n5MW5HJGiHaS/I0OKauATidv4OgMtI3Qr/nI5fXD1d0RwUmaB3RbMQglmzZjF//nwWL16sbQX5dIOCW+HxOg0K7JvWxWffWuJ/PEjUvHWk3w8j/I25xKzfhce8t1E1ft4o15gdTWIyoUPeJ/nkRRRODnh/twRVI8PPZ+oh7i1lCP3CVNL+Jg1VvJ9rC4GPmz1Jaeos82MUpGIv7rv9iIQkbKpWxL6Nbq/0nI5fXBmjolkIwYwZM5g/fz4LFy5k+vTp2tfyalCgsLLC+dUuVDy9Dbf3R6FwUJFy7hoh3cbxcOxc0oIfGnyNz9IkJhM25D2ST1xA4eRAue+WGC2hmbqCuSRVYGcqaX+Thir+BYkmltP8GPmt2Ht6nDLXkVmnVzbX/BvmYoyK5rlz57Jw4UK++OILJk+erPNafhsUWDmoKD1tOC6DexL96VfEffs78T8eJGHvMVzHDcBt0hCsnAxvaaVJSiFs2AySjl9A4WhPuZ2fo3rBeP1QTF3BXJIqsDOVtL9JQxk8rExxYa5hZRKPnye0z2QUjvZUuvITVs6Oee8k5erSpUv88ccfvPHGG0Y7ZsrlG0TOWknyqUsAKL1K52vwzdxoE8zRP1E42OP93efYN61rtJglyZj0/Y6UxWVmFrshY/545/5dZYIxgEajYeXKlSQnJ1O/fn2jJhgAu3rVKbf7S8pu+QQbv/Kow6OJmLKQ4I6jSTx+vuDxJj+TYL5dLBOMVCzJ
JGNG6SEPSdh3AgDXkbn38JdyplarGTNmDBMnTuTo0aMmO49CocCxexsqnNiK+7y3sHJ1IvXaTUL7TCZ06AxSb+VvSnBNcgphr8/8L8HsWIR983omi1uSzEkmGTOK2fILqNWoWjbAtoZxpmwuadRqNSNGjGDz5s1s3bqVF1980eTnVNjaUGrcACqe3YHLqD6gVJK47wT3Ww0j8sMvUT+K1W4r1GqSTl4k7seDJJ28iCYxiYfDPyTp8FkUDiq8ty/EvkV9k8csSeYi62SeKOw6GZGSyt0Gr6KOeESZDR/nOYyMlJVGo2Hw4MF8//33fPPNN7z22mtmiSP1nzvawTcBrNxccJs+AusypYmctVJnzhuFnS0iJRWFvR3eOxZj37KBWWKWpIIq1M6YkuHi9wSgjniEsqwHjt2KxlzdlkahUPDcc8+xc+dO+vbta7Y4MgffTDzyB1GzV5F6/TZRM5dnu61IyehH4vr2IJlgpBJBFpeZiXacsuGvoLCRub4gUlJSOHToEAqFgrlz55o1wTzNoX0Tyh/egPuiqWCV+wgBcdt+Q6jVhRSZJJmP/HYzkD5jbaX89Q8pf14FG2tchrxUSJEWD8nJybz66qscOXKE27dvU6ZMmbx3MpK8xunKHNMqIcmGBprcS6HVIeEkn/nLZE8zxp5MT9JVUscv04e8KwbSZ6ytzM6XTj3bmnW036ImKSmJXr16cezYMX7++edCTTCQ9zhdmWNa1boXnq/jpT+MMkmcYPzJ9CRdcvyy/JPFZQYq6KRO6kexxP/gD5DRMknKl4SEBHr27MmJEyf47bff6NKlS6HHkNdEU5ljWpXyzV/yM+UPDGNP+ibpenr8MnsbpRy/LBcyyRiooGNtxe3Yi0hOxfb5qqia1CmkKIu+xMRE4uPj2bdvHx06dDBLDHmN05U5ptVFz/LElSpFjgVmClD6eKFqZrrOl6aY9E36jxy/LP9kcZmBCjLWltBoiNmU0cPfdVTWccqkrGJjY0lISMDb25szZ86Y9Z7lNU7X02NaxU0eifPcJ3MpPZ1tnoTvMX+i3sPR5IexJ32TdMnxy/JPJhkD2Sit8HJRkZSqwctFleuETYmHzpJ+5wFWrk449elciFEWTY8fP6Zr165YWVlx6tQpsydlG6VVrnOlONha/1cu36YK8b6liZy5XKefTHwpN4JGD8KjY/NsjxEcHc+cn69yJyqJSu72zHmlNuVLF3zQSVd7W1kHY0I677WUK5lkDFSQCtbYJxX+zgO7Y+UoH69zEx0dTZcuXQgKCsLf39/sCUYfTj3b4titFcln/sL/6DWOPdYQWqUKiWqI+PNetp+TOT9f5VTQY5RASEwKc36+yv9GNCv84CXJSGSdjIHyW8GaFhRC4qGzALgM71WIERY9kZGRdOjQgbt373L48GHtLKxFkUKpxL5lA/6oWZcH1apRuaxLrp+TO1FJKAHvUiqsFRnLklSUySRjoPxWsMZs3g1CYN+hKbZVKhRukEXMmTNniIiI4MiRI9SrVzwGjszv56SSuz3pAkIfJ5MuMpYlqSiTxWUGyk8FqyYxmbjtvwEZFf5S9jLHROrZsyf//PMPjo7FZ+qD/FbEz3mldpY6GUkqyuQAmU+YcoDM2G/2EDFlIdYVvan4xw6TtioqqkJCQujQoQOjRo3i3XffNXc4kiQ9Q05aZqGenl7ZZUQvmWCyce/ePdq2bUtycrLFjEMmSZJxyOIyE0s5d43UK/+iUNniMqiHucOxOHfu3KF9+/YoFAoCAgKoVKmSuUOSJMmI5JOMiWWOtuzUuxPK0q5mjsbyfPLJJ1hbW8sEI0nFlHySMaH08GjifzkCgKscp0yHRqPBysqKFStW8PjxY8qWLWvukCRJMgGZZAwUFpPI4n3XCYpIws/TnnderElZ14zmqXHf7IG0dOwa1cKuXnWTxVDUhh0PDAykX79+7Nixg7p161pUgrGEIfLzmlJAyj95L83Pcr+JiojF+65z5EYU1lYKgqIS
gessGdAIkZ5OzJafAdM/xRSlYcevXLlCx44dKVu2rEUll0yWMER+XlMKSPkn76X5yZRuoKCIJKytFFQo7YCNUkFQREYP7YR9J1E/CMfKoxROL7c3aQxFZdjxixcv0r59e8qXL8+RI0fw8rK8QQUtYYj8vKYUkPJP3kvzk0nGQH6e9qSpBfejE0lTC/w8M3poZ45T5jK4Jwo70xa3FIVhx9PS0ujTpw+VK1fm0KFDuLtb5mRtljBEfl5TCkj5J++l+cniMgO982JNQLdOJvVGEEnHL4CVVaGMU1YUhh23sbHhhx9+oEqVKri6Wm4rO0sYIj+vKQWk/JP30vxkj/8njNnjP+K9L4jd+COO3VtTdsunRoqwaDpx4gTr1q1j48aN2NjYmDscSZL0JHv8WwhNXAJxO38HwGVkyW62fPToUV588UWCg4NJTZVl4ZJUEskkY2Rx3+1HJCRhU7Ui9m0amTscszl48CDdu3enefPm/Pbbb8VqsEtJkvJPJhkjenqcMteRJXd65evXr9OzZ0/atm3LL7/8goODnF9ekkoqWfFvBEKtJvnMXyQeO0faP3fBXoXTgBfNHZbZ1KhRg5UrVzJ06FDs7OzMHY4kSWYkk4yB4vcEZJnHXaGApGPncerZ1oyRFb7du3ej0Wjo06cPo0ePNnc4kiRZAFlcZoD4PQE8HPmhToIBEInJPBz5IfF7AswUWeH7/vvv6devHz/99JO5Q5EkyYJYbJJZvXo1fn5+qFQqGjVqxPHjx3Pc9scff6Rz5854enri4uJC8+bN2b9/v0njE2o1kTOXQy4NwCM//BKhVps0Dkuwfft2XnvtNQYMGMCmTZvMHY4kSRbEIpPMzp07mTx5MjNnzuTixYu0bt2abt26ce/evWy3P3bsGJ07d2bv3r2cP3+e9u3b89JLL3Hx4kWTxZh85q8sTzA6BKhDwkk+85fJYrAEu3btYujQoQwbNowtW7ZgbS1LYCVJ+o9FdsZs2rQpDRs2ZM2aNdp1NWvWpFevXixYsCBfx3j++ecZMGAAs2bNytf2Be1oFPfjQcLHzs1zO691s3Hu0ylfMRRF9+/fZ/369cydOxcrK4v8zSJJkhEUm86YqampnD9/ni5duuis79KlC6dOncrXMTQaDXFxcZQuXTrHbVJSUoiNjdX5VxDWZfI39lZ+tytqvv32WyIjI6lQoQLz5s2TCUaSpGxZ3DdDZGQkarWaMmXK6KwvU6YMYWFh+TrGkiVLSEhIoH///jlus2DBAlxdXbX/KlSoUKA4Vc3qoiznCTl1hVGA0scLVbO6BTpuUbBixQoGDhzIli1bzB2KJEkWzuKSTKZnOzIKIfLVuXHHjh3MmTOHnTt35jqU/IwZM4iJidH+u3//fsHiUyrx+GTSk4VnX8z4j8f8iSiUygId19ItXbqUiRMnMn36dKZOnWrucCRJsnAWl2Q8PDxQKpVZnlrCw8OzPN08a+fOnYwaNYrvvvuOTp1yrwexs7PDxcVF519BOfVsS5mN81F6e+qsV5bzoszG+cWun8zChQuZNm0aH3zwAYsWLSqxIxpIkpR/FtcUyNbWlkaNGuHv70/v3r216/39/XnllVdy3G/Hjh2MHDmSHTt20KNHj8IIFchINI7dWpF85i/SH0ZhXcYdVbO6xe4JBsDd3Z3Zs2cze/ZsmWAkScoXi0syAFOnTmXo0KE0btyY5s2bs379eu7du8e4ceOAjKKukJAQtm7dCmQkmGHDhrF8+XKaNWumfQqyt7cvlLlLFEol9i0bmPw85iCE4NixY7Rt21b24pckqcAsrrgMYMCAASxbtoyPP/6Y+vXrc+zYMfbu3Yuvry8AoaGhOn1m1q1bR3p6OhMmTMDb21v7b9KkSea6hGJBCMEHH3xAu3btOHv2rLnDkSSpCLLIfjLmYMxJy4oDIQTTp09n6dKlfPHFF0yePNncIUmSZEb6fkdaZHGZZF5CCCZOnMjKlStZuXIlEyZMMHdIkiQVUTLJSFmk
pKRw9epV1q1bxxtvvGHucCRJKsJkkpG01Go1ISEhVKxYkYMHD6Ishi3kJEkqXBZZ8S8VPrVazYgRI2jevDkJCQkywUiSZBTySUYiPT2doUOH8v333/PNN9/g6Oho7pAkSSomZJIp4dLS0hg4cCA///wzO3fupG/fvuYOSZKkYkQmmRIuMDCQQ4cOsWvXrlxHVJAkSdKHTDIlVHJyMtbW1tSrV487d+4UysgIkiSVPLLivwRKSkrilVde0Q7TIxOMJEmmIp9kSpiEhARefvllzpw5w549e8wdjiRJxZxMMiVIXFwcPXv25MKFC+zbt4/WrVubOyRJkoo5mWRKkPXr13Px4kX2799PixYtzB2OJEklgBwg84niPEBm5qyiGo2GW7duUa1aNXOHJElSEaPvd6Ss+C/moqOjadWqFQcOHMDKykomGEmSCpUsLivGIiMj6dSpEyEhIXlOXS1JkmQKMskUU+Hh4XTs2JHw8HCOHDlC7dq1zR2SJEklkEwyxdSwYcOIiooiICCAGjVqmDscSZJKKJlkiqlVq1ahVqt57rnnzB2KJEklmKz4L0bu3bvHwIEDiYmJoUqVKjLBSJJkdvJJppgICgqiffv2WFlZERMTI4eKkSTJIsgnmWLg5s2btG3bFhsbGwICAqhYsaK5Q5IkSQJkkinyYmNjadeuHQ4ODhw9epQKFSqYOyRJkiQtmWSKOBcXFz755BOOHj2Kj4+PucORJEnSIZNMEXXlyhXWrVsHwOuvv07ZsmXNHJEkSVJWMskUQZcuXaJ9+/asW7eO1NRUc4cjSZKUI5lkiphz587RoUMHKleuzKFDh7C1tTV3SJIkSTmSSaYIuXTpEp06daJGjRr4+/vj5uZm7pAkSZJyJZNMEeLn58eQIUPYv3+/7AcjSVKRIDtjFgHHjh2jXLlyVK1alZUrV5o7HEmSpHyTTzIW7uDBg7z44ovMnz/f3KFIkiQVmEwyFmzfvn289NJLtGvXjrVr15o7HEmSpAKTxWUG6rV8L5dC/5vBur63gt2Tuhf4OGlqDbfC43mUmIqbgy1//3GUAf370bVrV77//nvs7OyMGbZkoMTUdAJuhBPyKAkfN3vaVvfCwbbgf07GOo5UeJ79W63i5YSNUv5ez4n8NBvo6QST3XJ+3QqP53JwDLZKK+5GJRKfmE7fvn3ZvHmzbKZsgQJuhHPwejgqayWBoXEAdKtTzmzHkQrPs3+rADW88z/nfUkjk4yZZf4qOvFvBClqDTZRt3GvVIvnm7Rh9MA+ee4fk5TKd3/e425kIr4eDvR/oSKu9jIpFYQ+TxMhj5JQWSup7OnI7YgEQh4l6XVuYx1HKjyPElOxVVpR1lVFWEwyjxJlh+jcyGc8M8v8VZSSLvjtx12MG9Ad/5934uaQv0Tx3Z/3+P3qQ25HJrD3ykO++/OeiSMufjKfJoIiE/EPDCfgRnie+/i42ZOUpuZ2RAJJaWp83Oz1OrexjiMVHjcHW1LS1YTFJJOSrs7332pJJZ9kzCzzV9GDc/vYt/pDmnXpxdtjR1HFyylf+9+NTERlbUUVLyduhcdzNzLRxBEXP/o8TbSt7qXdN/PpRx/GOo5UeDL/Np+uk5FyJpOMmbk52LJ500a++vR92r08gJWr1lDLp1S+9/f1cOBaaBy3wuNJStPg6+FgumCLKR83ewJD4wr0NOFga22UuhNjHUcqPDZKK1kHUwAyyZhZZU9HQq+fp/eg4cxbuJRqZQv24e3/QkXS0gWnb0Xh7mhNaXsbElPTZQulAnj6acLDyZbUNA3/O3YrX/UzxqgTi4xPZt3RmwRFJOLn6cCo1pV5nJguWy8VAbKlWd7kN5EZPXjwgHLlyvHjzm0olUoUCkWBj+Fqb4ufpyO3IhNQWSs5efsRDiobvX8dl8Q/mqefJn6/8iDH1l7ZNRDIrBNTWVtx9UEs4bHJdKhZtkD3bt3Rm+wPDMdWacWN8HiiE1JoUtlLtl4yk4I0BJEtzfJWvL89LNjS
pUt57rnnuHXrFtbW1nolmExP1ynY2ygNaqGU+UcTFpPCpfuPuRUer/exiqLc7mV2DQSerhMTQnAlJLbA9y4oIhFbpRV+Ho7YWVsRFJGkbb1kZ62UrZcKWUEagjzd0ky+V9mTScYMPvvsM6ZNm8bEiROpXLmywcczZgulkv5Hk9u9zC4B+Xo4kJSW8fQXn6qmjItdge+dn6cDKekagiITSEnX4OdpL1svmVFBfrTJlmZ5k8VlhWzevHnMmjWL2bNnM3v2bIOeYDIZs4WSm4Mtd6MSS+wfTW73MrsGAi2qegAZrfyeL+dMOTfHAt+7se2qAuRaJyMVnoI0BJEtzfKmEELo10W9mImNjcXV1ZWYmBhcXPJfpjp03WGOB/33S6e1nz1fj+2Q7baRkZHUrl2bt99+m5kzZxocsymUxDqZ/MqrrF7eu+JBDvWTPX2/I2WSeULfG5gfQghSUlJQqVRERUXh7u5u1ONLkiSZmr7fkfJnlokJIZg+fTrdunVDrVbLBCNJUokik4wJCSGYOHEiS5cu5dVXX0WpVJo7JEmSpEIlCxpNRKPRMH78eNatW8e6det44403zB2SJElSoZNJxkT27dvH+vXr2bhxIyNGjDB3OJIkSWYhk4yRCSFQKBR0796dy5cvU6dOHXOHJEmSZDayTsaI0tPTGTx4MBs2bACQCUaSpBLPYpPM6tWr8fPzQ6VS0ahRI44fP57r9gEBATRq1AiVSkXlypVZu3ZtIUWaIS0tjddee43vv/+eUqVKFeq5JUmSLJVFJpmdO3cyefJkZs6cycWLF2ndujXdunXj3r3sJ+QKCgqie/futG7dmosXL/LBBx8wceJEfvjhh0KJNyUlhX79+vHLL7/www8/0Ldv30I5ryRJkqWzyM6YTZs2pWHDhqxZs0a7rmbNmvTq1YsFCxZk2f69997jl19+4fr169p148aN4/Lly5w+fTrbc6SkpJCSkqJdjo2NpUKFCnp1xpw2bRqrVq3ixx9/pHv37gXaV5IkqSgoNp0xU1NTOX/+PF26dNFZ36VLF06dOpXtPqdPn86yfdeuXTl37hxpaWnZ7rNgwQJcXV21/ypUqKB3zDNmzGD//v0ywUiSJD3D4pJMZGQkarWaMmXK6KwvU6YMYWFh2e4TFhaW7fbp6elERkZmu8+MGTOIiYnR/rt//77eMXt4eNC2bVu995ckSSquLLYJ87OjE2c2DS7I9tmtz2RnZ4ednZ2BUUqSJEm5sbgnGQ8PD5RKZZanlvDw8CxPK5nKli2b7fbW1tZyrDBJkiQzsrgkY2trS6NGjfD399dZ7+/vT4sWLbLdp3nz5lm2P3DgAI0bN8bGxsZksUqSJEm5s7gkAzB16lT+97//sXHjRq5fv86UKVO4d+8e48aNAzLqU4YNG6bdfty4cdy9e5epU6dy/fp1Nm7cyIYNG5g+fbq5LkGSJEnCQutkBgwYQFRUFB9//DGhoaHUrl2bvXv34uvrC0BoaKhOnxk/Pz/27t3LlClTWLVqFeXKlePLL7+U/VUkSZLMzCL7yZiDKSctkyRJKuqKTT8ZSZIkqfiQSUaSJEkyGZlkJEmSJJORSUaSJEkyGZlkJEmSJJORSUaSJEkyGYvsJ2MOmS25Y2NjzRyJJEmS5cn8bixorxeZZJ6Ii4sDMGjIf0mSpOIuLi4OV1fXfG8vO2M+odFoePDgAc7OzrmO9pydzAnP7t+/X2I6cpbEa4aSed3ymkvGNUPu1y2EIC4ujnLlymFllf+aFvkk84SVlRXly5c36BguLi4l6gMJJfOaoWRet7zmkiOn6y7IE0wmWfEvSZIkmYxMMpIkSZLJyCRjBHZ2dsyePbtEzbRZEq8ZSuZ1y2suOUxx3bLiX5IkSTIZ+SQjSZIkmYxMMpIkSZLJyCQjSZIkmYxMMpIkSZLJyCSTD6tXr8bPzw+VSkWjRo04fvx4rtsHBATQqFEjVCoVlStXZu3atYUUqXEV5Lp//PFHOnfujKenJy4u
LjRv3pz9+/cXYrTGUdD3OtPJkyextramfv36pg3QRAp63SkpKcycORNfX1/s7OyoUqUKGzduLKRojaOg17xt2zbq1auHg4MD3t7ejBgxgqioqEKK1nDHjh3jpZdeoly5cigUCnbv3p3nPkb5LhNSrr799lthY2MjvvrqKxEYGCgmTZokHB0dxd27d7Pd/vbt28LBwUFMmjRJBAYGiq+++krY2NiIXbt2FXLkhinodU+aNEksXLhQ/PHHH+Kff/4RM2bMEDY2NuLChQuFHLn+CnrNmR4/fiwqV64sunTpIurVq1c4wRqRPtf98ssvi6ZNmwp/f38RFBQkzp49K06ePFmIURumoNd8/PhxYWVlJZYvXy5u374tjh8/Lp5//nnRq1evQo5cf3v37hUzZ84UP/zwgwDETz/9lOv2xvouk0kmD02aNBHjxo3TWVejRg3x/vvvZ7v9u+++K2rUqKGzbuzYsaJZs2Ymi9EUCnrd2alVq5aYO3eusUMzGX2vecCAAeLDDz8Us2fPLpJJpqDX/fvvvwtXV1cRFRVVGOGZREGvefHixaJy5co667788ktRvnx5k8VoSvlJMsb6LpPFZblITU3l/PnzdOnSRWd9ly5dOHXqVLb7nD59Osv2Xbt25dy5c6SlpZksVmPS57qfpdFoiIuLo3Tp0qYI0ej0veZNmzZx69YtZs+ebeoQTUKf6/7ll19o3LgxixYtwsfHh+eee47p06eTlJRUGCEbTJ9rbtGiBcHBwezduxchBA8fPmTXrl306NGjMEI2C2N9l8kBMnMRGRmJWq2mTJkyOuvLlClDWFhYtvuEhYVlu316ejqRkZF4e3ubLF5j0ee6n7VkyRISEhLo37+/KUI0On2u+d9//+X999/n+PHjWFsXzT8lfa779u3bnDhxApVKxU8//URkZCTjx48nOjq6SNTL6HPNLVq0YNu2bQwYMIDk5GTS09N5+eWXWbFiRWGEbBbG+i6TTzL58OzQ/0KIXKcDyG777NZbuoJed6YdO3YwZ84cdu7ciZeXl6nCM4n8XrNarWbQoEHMnTuX5557rrDCM5mCvNcajQaFQsG2bdto0qQJ3bt3Z+nSpWzevLnIPM1Awa45MDCQiRMnMmvWLM6fP8++ffsICgpi3LhxhRGq2Rjju6xo/vwqJB4eHiiVyiy/bsLDw7Nk+Exly5bNdntra2vc3d1NFqsx6XPdmXbu3MmoUaP4/vvv6dSpkynDNKqCXnNcXBznzp3j4sWLvPXWW0DGl68QAmtraw4cOECHDh0KJXZD6PNee3t74+PjozPse82aNRFCEBwcTLVq1Uwas6H0ueYFCxbQsmVL3nnnHQDq1q2Lo6MjrVu3Zv78+UWihKKgjPVdJp9kcmFra0ujRo3w9/fXWe/v70+LFi2y3ad58+ZZtj9w4ACNGzfGxsbGZLEakz7XDRlPMMOHD2f79u1Frqy6oNfs4uLClStXuHTpkvbfuHHjqF69OpcuXaJp06aFFbpB9HmvW7ZsyYMHD4iPj9eu++eff4wyJ1Nh0OeaExMTs0zUpVQqgYJPR1xUGO27rEDNBEqgzKaOGzZsEIGBgWLy5MnC0dFR3LlzRwghxPvvvy+GDh2q3T6z2d+UKVNEYGCg2LBhQ5Fuwpzf696+fbuwtrYWq1atEqGhodp/jx8/NtclFFhBr/lZRbV1WUGvOy4uTpQvX168+uqr4tq1ayIgIEBUq1ZNjB492lyXUGAFveZNmzYJa2trsXr1anHr1i1x4sQJ0bhxY9GkSRNzXUKBxcXFiYsXL4qLFy8KQCxdulRcvHhR22zbVN9lMsnkw6pVq4Svr6+wtbUVDRs2FAEBAdrXXn/9ddG2bVud7Y8ePSoaNGggbG1tRaVKlcSaNWsKOWLjKMh1t23bVgBZ/r3++uuFH7gBCvpeP62oJhkhCn7d169fF506dRL29vaifPnyYurUqSIxMbGQozZMQa/5yy+/FLVq1RL29vbC29tbDB48WAQH
Bxdy1Po7cuRIrn+jpvouk0P9S5IkSSYj62QkSZIkk5FJRpIkSTIZmWQkSZIkk5FJRpIkSTIZmWQkSZIkk5FJRpIkSTIZmWQkSZIkk5FJRpIkSTIZmWQkSZIkk5FJRpIkSTIZmWSkYqVdu3ZMnjzZ3GHkyNLjK6hnr6e4XZ9kODmfjGRU7dq1o379+ixbtsws5//xxx8tZkqF7O6FJcVnCk9fn7k/C5JlkElGKnSpqanY2tqa5NilS5c2yXELIrfrs4T4TKm4X59UcLK4TDKa4cOHExAQwPLly1EoFCgUCu7cuUO7du146623mDp1Kh4eHnTu3BmASpUqZfmVW79+febMmQNkTAa1aNEiKleujL29PfXq1WPXrl25xpBd8c3EiRN59913KV26NGXLltUeP9OuXbuoU6cO9vb2uLu706lTJxISEvIdQ3bXl9u9eDq+lJQUJk6ciJeXFyqVilatWvHnn38WKP7sJCQkMGzYMJycnPD29mbJkiU6587r3gPs27ePVq1aUapUKdzd3enZsye3bt3K9byZ58jp+rdu3Yq7uzspKSk6+/Xt25dhw4bleV2ZunXrxuuvv65dPnz4MO7u7qSnp+f7GFLhkElGMprly5fTvHlzxowZQ2hoKKGhoVSoUAGALVu2YG1tzcmTJ1m3bl2+jvfhhx+yadMm1qxZw7Vr15gyZQpDhgwhICCgQHFt2bIFR0dHzp49y6JFi/j444+1M/6FhoYycOBARo4cyfXr1zl69Ch9+vTRznaY3xievb7c7sXT3n33XX744Qe2bNnChQsXqFq1Kl27diU6Ojpf8efknXfe4ciRI/z0008cOHCAo0ePcv78+QLdt4SEBKZOncqff/7JoUOHsLKyonfv3mg0mjz3zen6+/Xrh1qt5pdfftFuGxkZyZ49exgxYkS+Y/Px8SEkJES73L59e1JSUjh58mSBrlEyPVlcJhmNq6srtra2ODg4ULZsWZ3XqlatyqJFi/J9rISEBJYuXcrhw4dp3rw5AJUrV+bEiROsW7eOtm3b5vtYdevWZfbs2QBUq1aNlStXcujQITp37kxoaCjp6en06dMHX19fAOrUqVPgGLK7vpzuxdPXuGbNGjZv3ky3bt0A+Oqrr/D392fDhg0688nnFH924uPj2bBhA1u3btVus2XLlgJPjdy3b1+d5Q0bNuDl5UVgYCC1a9fOdd+cPgv29vYMGjSITZs20a9fPwC2bdtG+fLladeuXb5j8/Hx4cSJE9plhUKBSqUiIiIi38eQCodMMlKhaNy4cYG2DwwMJDk5OcsXaWpqKg0aNCjQserWrauz7O3tTXh4OAD16tWjY8eO1KlTh65du9KlSxdeffVV3NzcChRDQa8P4NatW6SlpdGyZUvtOhsbG5o0acL169fzFX9Ox01NTdUmRsioK6levXqB4/voo484c+YMkZGR2ieYe/fu5ZlkcjNmzBheeOEFQkJC8PHxYdOmTQwfPhyFQpHvYzz7JHPp0iUePXpEixYt9I5LMg2ZZKRC4ejomGWdlZUVz07MmpaWBqD9Qvvtt9/w8fHR2cbOzq5A5362NZdCodAeX6lU4u/vz6lTpzhw4AArVqxg5syZnD17tkAxZHd9ecm89me/XIUQOutyiz+34+Ymt3uf6aWXXqJChQp89dVXlCtXDo1GQ+3atUlNTc3z+Llp0KAB9erVY+vWrXTt2pUrV67w66+/FugYPj4+xMfHExsbi5OTE1OmTGHw4MGUK1fOoNgk45N1MpJR2draolar87Wtp6cnoaGh2uXY2FiCgoIAqFWrFnZ2dty7d4+qVavq/MuubsMQCoWCli1bMnfuXC5evIitrS0//fSTwTHkdS+qVq2Kra2tTrFPWloa586do2bNmnpfT9WqVbGxseHMmTPadY8ePeKff/7RLud27wGioqK4fv06H374IR07dqRmzZo8evSoQHHkdv2jR49m06ZNbNy4kU6dOhX4Pc1M+sHBwbz33nuEhoay
cuXKbLe9f/8+Fy5cKNDxJeORTzKSUVWqVImzZ89y584dnJyccm3S2qFDBzZv3sxLL72Em5sbH330EUqlEgBnZ2emT5/OlClT0Gg0tGrVitjYWE6dOoWTk5NOyyJDnD17lkOHDtGlSxe8vLw4e/YsERER1KxZ0+AY8roXjo6OvPnmm7zzzjuULl2aihUrsmjRIhITExk1apTe1+Tk5MSoUaN45513cHd3p0yZMsycORMrq/9+U+Z27wHc3Nxwd3dn/fr1eHt7c+/ePd5///0CxZHd9WfGMHjwYKZPn85XX33F1q1bC3yNmUlm2rRp3Lhxg+PHj+Pi4pJlu+PHjzNr1iySkpL46KOP6NGjR4HPJRlGJhnJqKZPn87rr79OrVq1SEpK0vl1/KwZM2Zw+/ZtevbsiaurK/PmzdPZft68eXh5ebFgwQJu375NqVKlaNiwIR988IHR4nVxceHYsWMsW7aM2NhYfH19WbJkibYi3pAY8nMvPvvsMzQaDUOHDiUuLo7GjRuzf/9+3NzcDLquxYsXEx8fz8svv4yzszPTpk0jJiZG+3pe997Kyopvv/2WiRMnUrt2bapXr86XX35ZoMr57K6/UqVKQMZ979u3L7/99hu9evUq8PV5eHhgZ2fH3bt3OXbsWJbizEytW7cmPT2dpKQkmWDMRCHyU4ArSVKRZ2k98Dt37kzNmjX58ssvTXYOtVrN7du3sbe3x83NTa+6M8kw8klGkqRCFR0dzYEDBzh8+HCO9SjGolQqqVatmknPIeVOJhlJkgpVw4YNefToEQsXLixws2qp6JHFZZIkSZLJyCbMkiRJksnIJCNJkiSZjEwykiRJksnIJCNJkiSZjEwykiRJksnIJCNJkiSZjEwykiRJksnIJCNJkiSZjEwykiRJksnIJCNJkiSZzP8Bin6Tz/o0ev4AAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def make_insert_example():\n", + " n = np.random.randint(6, 11)\n", + " x1 = sample_seq(n)\n", + " i = np.random.randint(1, n - 1)\n", + " gapsize = int(np.random.choice([0, 0, 1, 2])) # often empty -> nu* = 0\n", + " left = x1[i - 1]\n", + " rj = i + gapsize\n", + " right = x1[rj] if rj < n else PAD\n", + " gap = [x1[j] for j in range(i, rj)]\n", + " y = [left, MASK, right]\n", + " p = posterior_at(y, 1)\n", + " nu = float(sum(p[v] for v in set(gap)))\n", + " return window_feat(y, 1), torch.tensor([nu]), nu\n", + "\n", + "# IQL's target nu_star is a *soft* value in [0,1], so BCE(nu_star, nu_phi) bottoms out at E[H(nu_star)] rather than 0\n", + "# same idea as UQL, again reported with the floor + EMA\n", + "\n", + "_, _, nu_pool = zip(*[make_insert_example() for _ in range(4000)])\n", + "iql_floor = binary_entropy(nu_pool)\n", + "\n", + "nu_phi = QualityHead(din)\n", + "opt = torch.optim.Adam(nu_phi.parameters(), lr=2e-3)\n", + "run = None\n", + "for step in range(2000):\n", + " feats, soft, _ = zip(*[make_insert_example() for _ in range(64)])\n", + " loss = F.binary_cross_entropy(nu_phi(torch.stack(feats)), torch.stack(soft))\n", + " opt.zero_grad(); loss.backward(); opt.step()\n", + " run = loss.item() if run is None else 0.98 * run + 0.02 * loss.item()\n", + " if (step + 1) % 500 == 0:\n", + " print(f\" step {step+1:4d} IQL running avg = {run:.4f}\")\n", + "\n", + "with torch.no_grad():\n", + " fs, _, qs = zip(*[make_insert_example() for _ in range(800)])\n", + " pred = nu_phi(torch.stack(fs)).squeeze(-1).numpy()\n", + "qs = np.array(qs)\n", + "print(f\"corr(nu_phi, nu_star) = {np.corrcoef(pred, qs)[0,1]:.3f}\")\n", + "\n", + "plt.figure(figsize=(4.2, 4.2))\n", + "plt.scatter(qs, pred, s=6, alpha=0.25)\n", + "bins = np.linspace(0, qs.max() + 1e-6, 9)\n", + "idx = np.digitize(qs, bins) - 1\n", + "bx = [qs[idx == b].mean() for b in range(len(bins)) if (idx == b).any()]\n", + "by 
= [pred[idx == b].mean() for b in range(len(bins)) if (idx == b).any()]\n", + "plt.plot(bx, by, \"o-\", color=\"crimson\", label=\"binned mean\")\n", + "lim = max(qs.max(), pred.max())\n", + "plt.plot([0, lim], [0, lim], \"k--\", lw=1, label=\"ideal\")\n", + "plt.xlabel(r\"true insertion quality $\\nu_\\star$\")\n", + "plt.ylabel(r\"predicted $\\nu_\\phi$\")\n", + "plt.title(\"IQL recovers the insertion quality\")\n", + "plt.legend(fontsize=8); plt.tight_layout(); plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Quality-guided inference with A2D2\n", + "\n", + "The A2D2 sampler (`_diffusion_loop`) alternates *unmask* and *insert* Euler steps, and after each one calls the two schedule-aware routines we vendored in §0 to **re-mask low-quality tokens** and **drop low-quality insertions**, keeping the masked/clean counts on the interpolant schedule. We exercise the actual functions with a minimal toy model.\n", + "\n", + "The toy example only needs `model.planner(seq, t)` returning per-position confidences, and `model.interpolant.{unmask,insertion}_schedule.at(t)` (linear schedule here)." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "class LinearSchedule:\n", + " def at(self, t):\n", + " return t.clone() if torch.is_tensor(t) else torch.tensor(float(t))\n", + "\n", + "class ToyInterpolant:\n", + " unmask_schedule = LinearSchedule()\n", + " insertion_schedule = LinearSchedule()\n", + "\n", + "class ToyPlanner:\n", + " \"\"\"Stand-in for the trained planner. 
We pass confidences in directly so we can script\n", + " exactly which token / insertion is low-quality; in A2D2 these come from mu_phi / nu_phi.\"\"\"\n", + " def __init__(self, remask_conf=None, insert_conf=None):\n", + " self.remask_conf, self.insert_conf = remask_conf, insert_conf\n", + " def __call__(self, seq, t):\n", + " B, L = seq.shape\n", + " out = {}\n", + " if self.remask_conf is not None:\n", + " out[\"remasking_conf\"] = self.remask_conf.view(B, L, 1)\n", + " if self.insert_conf is not None:\n", + " out[\"insertion_conf\"] = self.insert_conf.view(B, L, 1)\n", + " return out\n", + "\n", + "class ToyModel:\n", + " def __init__(self, planner):\n", + " self.planner = planner\n", + " self.interpolant = ToyInterpolant()\n", + "\n", + "neg_inf = torch.tensor(-np.inf)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4.1 Schedule-aware **re-masking** (low unmasking quality → re-masked)\n", + "\n", + "Four clean tokens `ABCD` and two masks. Position 1 (`B`) is given a low unmasking-quality confidence (0.10) while the rest are confident. With the schedule asking for one more mask at this step, the routine re-masks exactly the lowest-quality token." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "before: ABCD__\n", + "after : A_CD__ <- the low-quality 'B' was re-masked\n" + ] + } + ], + "source": [ + "new_xt = torch.tensor([[0, 1, 2, 3, MASK, MASK]]) # ABCD__\n", + "conf = torch.tensor([[0.9, 0.1, 0.95, 0.92, 0.0, 0.0]]) # pos 1 = low quality\n", + "clean_index = (new_xt != MASK) & (new_xt != PAD)\n", + "model = ToyModel(ToyPlanner()) # remasking_conf is passed in directly\n", + "t, dt = torch.tensor([0.4]), 0.1\n", + "\n", + "out = apply_schedule_aware_remasking(model, new_xt.clone(), t, dt, conf, clean_index,\n", + " MASK, neg_inf, batch_size=1)\n", + "print(\"before:\", decode(new_xt[0].tolist()))\n", + "print(\"after :\", decode(out[0].tolist()), \" <- the low-quality 'B' was re-masked\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4.2 Schedule-aware **insertion deletion** (low insertion quality → dropped)\n", + "\n", + "State `A_BC_` after inserting two masks (between `A` & `B`, and after `C`). The first inserted mask gets a low insertion-quality confidence (0.05); with `quality_threshold=0.5` the routine drops it and compacts the sequence, keeping the confident insertion." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "before: A_BC_.\n", + "after : ABC_.. 
<- low-quality insertion dropped, kept the good one\n" + ] + } + ], + "source": [ + "new_xt = torch.tensor([[0, 1, 2, PAD, PAD, PAD]]) # ABC (3 clean)\n", + "xt_tmp = torch.tensor([[0, MASK, 1, 2, MASK, PAD]]) # A_BC_ after inserts\n", + "ext = torch.tensor([[0, 1, 0, 1, 0, 0]]) # 1 insert in gaps 1 and 3\n", + "orig_mask = torch.tensor([[True, False, True, True, False, False]])\n", + "new_pos_orig = torch.tensor([[0, 2, 3, 3, 4, 5]]) # originals shifted right\n", + "insert_conf = torch.tensor([[0.9, 0.05, 0.9, 0.9, 0.95, 0.0]]) # inserted mask @pos1 = low\n", + "model = ToyModel(ToyPlanner(insert_conf=insert_conf))\n", + "t, dt = torch.tensor([0.0]), 0.05\n", + "\n", + "out = apply_schedule_aware_insertion(model, xt_tmp.clone(), new_xt, t, dt, ext,\n", + " MASK, PAD, max_length=6, orig_mask=orig_mask,\n", + " new_pos_orig=new_pos_orig, quality_threshold=0.5)\n", + "\n", + "print(\"before:\", decode(xt_tmp[0].tolist()))\n", + "print(\"after :\", decode(out[0].tolist()), \" <- low-quality insertion dropped, kept the good one\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Putting it together: quality inference catches decoding errors\n", + "\n", + "The point of both predictors is to flag mistakes *without* access to $\\boldsymbol{x}_1$. Here we corrupt one position of a clean target (decode it to the wrong letter) and let the trained $\\mu_\\phi$ score every clean token. The corrupted position gets the **lowest** predicted unmasking quality — so A2D2's schedule-aware re-masking would target it for a re-do, exactly as intended." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "target x1: BCDECDEAB\n", + "decoded y: BCDEEDEAB (position 4 corrupted: C -> E)\n", + "\n", + " predicted unmasking quality mu_phi per position:\n", + " pos 0 (B): 0.666\n", + " pos 1 (C): 0.948\n", + " pos 2 (D): 0.920\n", + " pos 3 (E): 0.339\n", + " pos 4 (E): 0.041 *corrupted* <-- lowest, would be re-masked\n", + " pos 5 (D): 0.595\n", + " pos 6 (E): 0.918\n", + " pos 7 (A): 0.957\n", + " pos 8 (B): 0.745\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAk4AAAD5CAYAAADRCboRAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAMq9JREFUeJzt3XlcVGX/P/7XwMAMDIssAkqsaoohYOCG4pqoqHd2p9ltBSpW3OCK9jHjdslMUssPZoKVIi1kpLeWCy6oSZr0SU2tlBZXSEEUVJAUHbi+f/ib+TnMIEdZZoDX8/E4f8w11znnfZ1t3nOdTSaEECAiIiKiWpkZOwAiIiKipoKJExEREZFETJyIiIiIJGLiRERERCQREyciIiIiiZg4EREREUnExImIiIhIIiZORERERBIxcSIiIiKSiIlTM7FgwQLIZDJcvXrV4Pf+/v7o379/4wbVAGprp8b48ePh7e3dOEGZkLS0NMhkMpw/f15bZmhZLF68GF9//XW9z//YsWPo168f7O3tIZPJkJSUhP3790Mmk2H//v31Pr/6kJycjLS0NL1yTdwbN25s/KCagVOnTmHBggU622J90RwHavPFF18gKSmpTvOSyWSYPHlynaZRn+7evYtOnTrhnXfeaZT59e/fX+e3448//oClpSV++umnRpm/KWLiRM3S3LlzsXnzZmOHYRIMLYuGSpwmTpyIgoICfPnll8jJycHzzz9f7/OobzUlTlQ3p06dwptvvtkgiZNU9ZE4mZrk5GRcu3YNU6ZMMcr8H3/8cbzwwguYMWOGUeZvCuTGDoCoIbRr187YIZiMxlwWv/76K15++WUMGzZMW/bbb7812vxJnxACt2/fhpWVld53t27dglKplNR7Q8anVquxbNkyTJw4ESqV6oF1//77b1hbWzdIHJMnT0ZISAgOHTqE0NDQBpmHKWOPUwulOQ2xfv16JCQkoG3btrCzs8NTTz2F33//Xadu//794e/vj5ycHISGhsLKygre3t5Yt24dAGD79u148sknYW1tjS5dumDnzp06458+fRoTJkxAhw4dYG1tDXd3d4wcORK//PKLTr2qqiosWrQIHTt2hJWVFVq1aoWAgACsWLHigW357bff4Ovrix49eqCoqAiA4dNTmi73zz77DH5+frC2tkZgYCC2bdumN81vvvkGAQEBUCgU8PX1xYoVKySfHhBCYOnSpfDy8oJSqcSTTz6JHTt26HV5GzqtBsDgqa2srCw8/fTTeOyxx6BUKtG+fXu8+uqrtZ6yNLQsZDIZysvL8cknn0Amk0Emk6F///44f/485HI5EhMT9abx3XffQSaTYcOGDQbnoWmLWq1GSkqKdro1OXLkCJ5//nl4e3trt6d//et
fuHDhgl7dgwcPolevXlAqlXB3d8fcuXOxZs0avWW3b98+9O/fH05OTrCysoKnpyeeffZZ/P333zXG4e3tjZMnTyI7O1sbc/Xt5u7du7XuIwCwZ88eDBo0CHZ2drC2tkbv3r2xd+/eGud9v+vXr2PmzJnw9fWFQqGAi4sLIiIidJLOkpISxMbGwt3dHZaWlvD19UVCQgIqKip0pqXZzlevXg0/Pz8oFAp88skn2nW0e/duTJw4Ea1bt4a1tTUqKipqPLVtaJvXTP/DDz/E448/DoVCgc6dO+PLL7/U1klLS8OYMWMAAAMGDNAu2/t79qQur+3btyMoKAgKhQI+Pj549913JS3T/v37Y/v27bhw4YJ2/ve3ReryrE4IgTfeeAMWFhb4+OOPteUZGRno1asXVCoVbGxsMGTIEBw7dkxn3PHjx8PGxganT59GREQEbGxs4OHhgZkzZ9Y6XwDYsmULLl68iJdeekmnXLOefvrpJ4wePRoODg7aP0xCCCQnJyMoKAhWVlZwcHDA6NGjcfbsWb12GTpuGRIcHAw/Pz+sXr261pibJUHNwvz58wUAceXKFYPfP/HEE6Jfv37az99++60AILy9vcULL7wgtm/fLtavXy88PT1Fhw4dhFqt1tbt16+fcHJyEh07dhRr164Vu3btEiNGjBAAxJtvvim6dOki1q9fLzIzM0XPnj2FQqEQFy9e1I6fnZ0tZs6cKTZu3Ciys7PF5s2bxahRo4SVlZX47bfftPUSExOFubm5mD9/vti7d6/YuXOnSEpKEgsWLKixnfv37xcODg7i6aefFuXl5dp6UVFRwsvLS2cZaNrbvXt38dVXX4nMzEzRv39/IZfLxZkzZ7T1duzYIczMzET//v3F5s2bxYYNG0SPHj2Et7e3kLLLaGKMjo4WO3bsEB999JFwd3cXbm5uOutg3bp1AoA4d+6czviadfPtt99qy1JSUkRiYqLYsmWLyM7OFp988okIDAwUHTt2FHfu3HngNKsvi5ycHGFlZSUiIiJETk6OyMnJESdPnhRCCPHMM88IT09PnfUvhBBjxowRbdu2FXfv3jXY5qKiIpGTkyMAiNGjR2unW1N7NmzYIObNmyc2b94ssrOzxZdffin69esnWrdurbMNnzhxQiiVShEQECC+/PJLsWXLFhEREaFdF5p2njt3TiiVSjF48GDx9ddfi/3794v09HTx0ksviWvXrtWwpoT46aefhK+vr+jatas25p9++kknbin7yGeffSZkMpkYNWqU2LRpk9i6dasYMWKEMDc3F3v27Klx/kIIUVpaKp544gmhUqnEwoULxa5du8R///tfMW3aNLFv3z4hhBC3bt0SAQEBQqVSiXfffVfs3r1bzJ07V8jlchEREaEzPQDC3d1dBAQEiC+++ELs27dP/Prrr9ptw93dXbzyyitix44dYuPGjUKtVhvcX4T4/7fl6tP38PAQnTt3FuvXrxdbtmwRQ4cOFQDEhg0btNvD4sWLBQCxatUq7bItKip6qOW1Z88eYW5uLvr06SM2bdokNmzYILp16yY8PT1r3RdPnjwpevfuLdzc3LTz12yTD7s84+LihBBC3L59Wzz//PPC1tZW7NixQ1vn7bffFjKZTEycOFFs27ZNbNq0SfTq1UuoVCrtviXEvX3R0tJS+Pn5iXfffVfs2bNHzJs3T8hkMvHmm28+sD1CCDFx4kTh4uJS43ry8vISs2fPFllZWeLrr78WQgjx8ssvCwsLCzFz5kyxc+dO8cUXX4hOnToJV1dXUVhYqDeN2o5bGv/+97+Fs7OzqKqqqjXu5oaJUzPxqIlT9YPEV199JQBoDzBC3EucAIgjR45oy4qLi4W5ubmwsrLSSZKOHz8uAIj333+/xljVarW4c+eO6NChg5gxY4a2fMSIESIoKEhyOz/77DNhaWkppk6dKiorK3Xq1ZQ4ubq6itLSUm1ZYWGhMDMzE4mJidqybt26CQ8PD1FRUaEtKysrE05OTrU
erK9duyaUSqV45plndMq///57AeCRE6f7VVVVibt374oLFy4IAOKbb7554DQNLQuVSiWioqL0pq2Z9+bNm7VlFy9eFHK5XNKB/f4fGantEeLeNnHz5k2hUqnEihUrtOVjxowRKpVKZ7uurKwUnTt31mnnxo0bBQBx/PjxWmOsrvq+UT3u2vaR8vJy4ejoKEaOHKlTr7KyUgQGBoru3bs/cP4LFy4UAERWVlaNdVavXi0AiK+++kqnfMmSJQKA2L17t7YMgLC3txclJSU6dTXbRmRkpN70HzZxsrKy0vnRVavVolOnTqJ9+/basg0bNhhc7w+zvHr06CHatm0rbt26pS0rLS0Vjo6Okv7EDB8+3GC7HnZ5xsXFieLiYtGnTx/h7u6us53l5eUJuVwupkyZojOtsrIy4ebmJp577jltWVRUlMH5RkREiI4dO9baHj8/PzF06FC9cs16mjdvnk655s/Me++9p1Oen58vrKysxP/8z/8IIR7uuKXx8ccfCwAiNze31ribG56qa+H+8Y9/6HwOCAgAAL1TJm3atEFwcLD2s6OjI1xcXBAUFIS2bdtqy/38/PTGV6vVWLx4MTp37gxLS0vI5XJYWlrizz//RG5urrZe9+7dceLECcTGxmLXrl0oLS2tMe63334b48ePxzvvvIMVK1bAzEzapjxgwADY2tpqP7u6usLFxUUbb3l5OY4cOYJRo0bB0tJSW8/GxgYjR46sdfo5OTm4ffs2XnjhBZ3y0NBQeHl5SYrRkKKiIsTExMDDwwNyuRwWFhba6d2/DOuqf//+CAwMxKpVq7Rlq1evhkwmwyuvvFJv87l58yZmz56N9u3bQy6XQy6Xw8bGBuXl5Trtyc7OxsCBA+Hs7KwtMzMzw3PPPaczvaCgIFhaWuKVV17BJ598oncaoi5q20cOHTqEkpISREVFQa1Wa4eqqioMHToUhw8fRnl5eY3T37FjBx5//HE89dRTNdbZt28fVCoVRo8erVM+fvx4ANA7xTVw4EA4ODgYnNazzz5b43ykGjRoEFxdXbWfzc3NMXbsWJw+fRp//fXXA8eVurzKy8tx+PBh/POf/4RSqdSOb2trK2lffJCHXZ7nzp1Dr169UFpaih9++AGBgYHa73bt2gW1Wo3IyEid9iiVSvTr10/vblKZTKYXf0BAgMHT1NVdunQJLi4uNX5ffd1u27YNMpkML774ok5sbm5uCAwM1Mb2KMctTRwXL16sNe7mhheHNxNy+b1VWVlZafB7tVoNCwsLvXInJyedzwqFAsC9i0bv5+joqDeupaWlXrkm2bh9+7a2LD4+HqtWrcLs2bPRr18/ODg4wMzMDJMmTdKZz5w5c6BSqfD5559j9erVMDc3R9++fbFkyRKEhITozOfzzz+Hu7v7Q9+1Vb29mjZr4rh27RqEEDo/ChqGyqorLi4GALi5uel9Z6hMiqqqKoSHh+PSpUuYO3cuunTpApVKhaqqKvTs2VNvXdXV1KlTMWnSJPz+++/w9fXFxx9/jNGjRz9y/IaMGzcOe/fuxdy5c9GtWzfY2dlBJpMhIiJCpz3FxcWS1kW7du2wZ88eLF26FHFxcSgvL4evry+mTp2KadOm1SnW2vaRy5cvA4Dej/D9SkpKaryY98qVK/D09HxgDMXFxXBzc9O73sjFxQVyuVy73Wm0adOmxmk96DupHrR9FxcX47HHHqtxXKnLSyaToaqqql73JY2HXZ4//vgjrl69irfffluvbZr2dOvWzeC8qv+ps7a21kkEgXvb1P3HzJpoLuavSfV1e/ny5RqPZwDg6+sL4NGOW5o46vv40xQwcWomNDvGxYsX9XYSIQQKCgr0ko/G8vnnnyMyMhKLFy/WKb969SpatWql/SyXyxEfH4/4+Hhcv34de/bswRtvvIEhQ4YgPz9f5w6RnTt3YuzYsQgLC8PevXvr1JtzPwcHB8hkMu3B8H6FhYW1jq/5kTVUt7CwUOcCXM2Bp/pFodUv+P71119
x4sQJpKWlISoqSlt++vTpWuN5FOPGjcPs2bOxatUq9OzZE4WFhYiLi6u36d+4cQPbtm3D/Pnz8frrr2vLKyoqUFJSolPXyclJ8roICwtDWFgYKisrceTIEaxcuRLTp0+Hq6trgz4WQdMbtnLlSvTs2dNgnQcl3a1bt661l8bJyQn/93//ByGEzo99UVER1Gq1To8cgAdemG/oO6VSafDi5JpuPqhp+9bE+iBSl9fdu3chk8keOK9H9bDLc+zYsXBzc0NCQgKqqqrwn//8R689GzdurLfjUE2cnZ319pH7VV+3zs7OkMlkOHDggDbhv5+m7GGOWxqaOKovq5aAp+qaiYEDB0ImkyEjI0Pvu507d6K0tPSBpwIakkwm09tpt2/f/sAu3latWmH06NGIi4tDSUmJ3p1nXl5e2oNBWFgY/vzzz3qJVaVSISQkBF9//TXu3LmjLb9586bBu++q69mzJ5RKJdLT03XKDx06pNcVrzkY/fzzzzrlW7Zs0fmsORhWX4YffvhhrfHU5P5etuqUSqX2lNfy5csRFBSE3r17P/K8qpPJZBBC6LVnzZo1ej2m/fr1w759+3R+wKuqqmq8uw+4d9qoR48e2tONtT2o70HLQorevXujVatWOHXqFEJCQgwO95/2rW7YsGH4448/sG/fvhrrDBo0CDdv3tR79tann36q/b4uvL29UVRUpJOk3rlzB7t27TJYf+/evTp1KysrkZGRgXbt2ml7ZGrqvZa6vFQqFbp3745Nmzbp9MaUlZVh69atktpV07p9lOX5n//8B0lJSZg3bx7mzJmjLR8yZAjkcjnOnDlTY3vqS6dOnXDmzBnJ9UeMGAEhBC5evGgwri5dugB4uOOWxtmzZ2FmZoaOHTs+eoOaKPY4NRPt2rXD5MmTsWzZMly/fh0RERGwsrLC4cOH8c477yAkJATjxo0zSmwjRoxAWloaOnXqhICAABw9ehTLli3T6/IeOXIk/P39ERISgtatW+PChQtISkqCl5cXOnTooDfdNm3aIDs7G0OGDEHfvn2RlZUFf3//Ose7cOFCDB8+HEOGDMG0adNQWVmJZcuWwcbG5oH/9oB7PVazZs3CokWLMGnSJIwZMwb5+flYsGCBXpd3t27d0LFjR8yaNQtqtRoODg7YvHkzDh48qFOvU6dOaNeuHV5//XUIIeDo6IitW7ciKyvrkdvYpUsX7N+/H1u3bkWbNm1ga2urcwCMjY3F0qVLcfToUaxZs+aR52OInZ0d+vbti2XLlsHZ2Rne3t7Izs7G2rVrdXogASAhIQFbt27FoEGDkJCQACsrK6xevVp7zZDmNMjq1auxb98+DB8+HJ6enrh9+zZSU1MBoNY/DF26dMGXX36JjIwM+Pr6QqlUan9QpLCxscHKlSsRFRWFkpISjB49Gi4uLrhy5QpOnDiBK1euICUlpcbxp0+fjoyMDDz99NN4/fXX0b17d9y6dQvZ2dkYMWIEBgwYgMjISKxatQpRUVE4f/48unTpgoMHD2Lx4sWIiIio85+isWPHYt68eXj++efx2muv4fbt23j//fdrPPXv7OyMgQMHYu7cuVCpVEhOTsZvv/2m80gCzb740UcfwdbWFkqlEj4+PnBycpK8vN566y0MHToUgwcPxsyZM1FZWYklS5ZApVLVui8C99btpk2bkJKSguDgYJiZmSEkJOSRl+e0adNgY2ODV155BTdv3sT7778Pb29vLFy4EAkJCTh79iyGDh0KBwcHXL58GT/++CNUKhXefPPNh10lBvXv3x8LFy6U/Iym3r1745VXXsGECRNw5MgR9O3bFyqVCgUFBTh48CC6dOmCf//73w913NL44YcfEBQUVOO1dM2a8a5Lp/pWVVUlUlJSREhIiLC2thaWlpaiQ4cOYvbs2aKsrEynruaOIc3twxrnzp0TAMS6deu0Zf369RNPPPGE3vy8vLzE8OHD9cpR7c6qa9euiejoaOHi4iKsra1Fnz59xIEDB0S/fv1
07tZ47733RGhoqHB2dhaWlpbC09NTREdHi/Pnz2vrGLp78Pr166J3797C0dFRHD58WAhR81111e/40rSj+h1mmzdvFl26dNHG8c4774ipU6cKBwcHvfGrq6qqEomJicLDw0NYWlqKgIAAsXXrVr32CiHEH3/8IcLDw4WdnZ1o3bq1mDJliti+fbve3UinTp0SgwcPFra2tsLBwUGMGTNG5OXlCQBi/vz52npS76o7fvy46N27t7C2tq7xrpn+/fsLR0dH8ffff9faZg1Dy9jQXXV//fWXePbZZ4WDg4OwtbUVQ4cOFb/++qvBdXHgwAHRo0cPoVAohJubm3jttde0dz9dv35dCHHv7qFnnnlGeHl5CYVCIZycnES/fv3Eli1bao35/PnzIjw8XNja2mpv6b4/bin7iBD3HrsxfPhw4ejoKCwsLIS7u7sYPny43viGXLt2TUybNk14enoKCwsL4eLiIoYPH67zuI7i4mIRExMj2rRpI+RyufDy8hJz5swRt2/f1plWTdu5ZtvQ7CPVZWZmiqCgIGFlZSV8fX3FBx98UONddXFxcSI5OVm0a9dOWFhYiE6dOon09HS9aSYlJQkfHx9hbm6ut8ykLq8tW7aIgIAAnX3RUFyGlJSUiNGjR4tWrVoJmUymM05dluf69euFXC4XEyZM0N7R+/XXX4sBAwYIOzs7oVAohJeXlxg9erTO4xWioqKESqXSi1Nqe06fPi1kMpneXXm13VWdmpoqevToIVQqlbCyshLt2rUTkZGROndKP8xxq6ysTFhbW+vdrddSyIQQotGyNKIm6u7duwgKCoK7uzt27979SNPQPPzSVN/Zdr+ioiJ4eXlhypQpWLp0qbHD0RMeHo7z58/jjz/+MHYoLY5MJkNcXBw++OADY4fSIo0cORJqtbrGh1M2hrVr12LatGnIz89vkT1OPFVHZEB0dDQGDx6MNm3aoLCwEKtXr0Zubm6tTzFv6v766y+cPXsWy5Ytg5mZWZ3vSKsP8fHx6Nq1Kzw8PFBSUoL09HRkZWVh7dq1xg6NqNElJiaia9euOHz4cI138jUktVqNJUuWYM6cOS0yaQKYOBEZVFZWhlmzZuHKlSuwsLDAk08+iczMTKNdYN9Y1qxZg4ULF8Lb2xvp6elwd3c3dkiorKzEvHnzUFhYCJlMhs6dO+Ozzz7Diy++aOzQiBqdv78/1q1bV+c7Cx9Vfn4+XnzxRcycOdMo8zcFPFVHREREJBEfR0BEREQkkUkkTt999x1GjhyJtm3bQiaT6T1bw5Ds7GwEBwdDqVTC19e35b6lmYiIiBqNSVzjVF5ejsDAQEyYMEHSe5TOnTuHiIgIvPzyy/j888/x/fffIzY2Fq1bt36o9zBVVVXh0qVLsLW1feCTdomIiKj5EkKgrKwMbdu2rfXdpyZ3jZNMJsPmzZsxatSoGuvMnj0bW7Zs0XkZaExMDE6cOIGcnJwax6uoqNB5rcDFixfRuXPneombiIiImrb8/PwHvmsRMJEep4eVk5OD8PBwnbIhQ4Zg7dq1uHv3rsGX2QL3buM09ATX/Px82NnZNUisREREZNpKS0vh4eEBW1vbWus2ycSpsLBQ76WZrq6uUKvVuHr1ao1v/54zZw7i4+O1nzULys7OjokTERFRCyflsp0mmTgB+o3TnHF8UKMVCoXBN0QTERERSWESd9U9LDc3N72HfxUVFUEul8PJyclIUREREVFz1yQTp169eum9GX737t0ICQmp8fomIiIioroyicTp5s2bOH78OI4fPw7g3uMGjh8/jry8PAD3rk2KjIzU1o+JicGFCxcQHx+P3NxcpKamYu3atZg1a5YxwiciIqIWwiSucTpy5AgGDBig/ay5gDsqKgppaWkoKCjQJlEA4OPjg8zMTMyYMQOrVq1C27Zt8f777z/UM5yIiIiIHpbJPcepMZWWlsLe3h43btzgXXVEREQt1MPkAybR40RERNQUDHlru7FDkGTX3OHGDqHZMol
rnIiIiIiaAvY40UPjPy4iImqp2ONEREREJBETJyIiIiKJmDgRERERScTEiYiIiEgiXhxOhKZzwTvAi96paeA+Rc0Ve5yIiIiIJGLiRERERCQREyciIiIiiZg4EREREUnExImIiIhIIiZORERERBIxcSIiIiKSiIkTERERkURMnIiIiIgkYuJEREREJBETJyIiIiKJmDgRERERScTEiYiIiEgiJk5EREREEjFxIiIiIpKIiRMRERGRREyciIiIiCRi4kREREQkkUklTsnJyfDx8YFSqURwcDAOHDjwwPrp6ekIDAyEtbU12rRpgwkTJqC4uLiRoiUiIqKWxmQSp4yMDEyfPh0JCQk4duwYwsLCMGzYMOTl5Rmsf/DgQURGRiI6OhonT57Ehg0bcPjwYUyaNKmRIyciIqKWwmQSp+XLlyM6OhqTJk2Cn58fkpKS4OHhgZSUFIP1f/jhB3h7e2Pq1Knw8fFBnz598Oqrr+LIkSM1zqOiogKlpaU6AxEREZFUJpE43blzB0ePHkV4eLhOeXh4OA4dOmRwnNDQUPz111/IzMyEEAKXL1/Gxo0bMXz48Brnk5iYCHt7e+3g4eFRr+0gIiKi5s0kEqerV6+isrISrq6uOuWurq4oLCw0OE5oaCjS09MxduxYWFpaws3NDa1atcLKlStrnM+cOXNw48YN7ZCfn1+v7SAiIqLmzSQSJw2ZTKbzWQihV6Zx6tQpTJ06FfPmzcPRo0exc+dOnDt3DjExMTVOX6FQwM7OTmcgIiIikkpu7AAAwNnZGebm5nq9S0VFRXq9UBqJiYno3bs3XnvtNQBAQEAAVCoVwsLCsGjRIrRp06bB4yYiIqKWxSR6nCwtLREcHIysrCyd8qysLISGhhoc5++//4aZmW745ubmAO71VBERERHVN5NInAAgPj4ea9asQWpqKnJzczFjxgzk5eVpT73NmTMHkZGR2vojR47Epk2bkJKSgrNnz+L777/H1KlT0b17d7Rt29ZYzSAiIqJmzCRO1QHA2LFjUVxcjIULF6KgoAD+/v7IzMyEl5cXAKCgoEDnmU7jx49HWVkZPvjgA8ycOROtWrXCwIEDsWTJEmM1gYiIiJo5k0mcACA2NhaxsbEGv0tLS9MrmzJlCqZMmdLAURERERHdYzKn6oiIiIhMHRMnIiIiIolM6lQdERERNa4hb203dgiS7Zpb89tBGgt7nIiIiIgkYuJEREREJBETJyIiIiKJmDgRERERScTEiYiIiEgiJk5EREREEjFxIiIiIpKIiRMRERGRREyciIiIiCRi4kREREQkERMnIiIiIomYOBERERFJxMSJiIiISCImTkREREQSyY0dQHM25K3txg5Bsl1zhxs7BCIiIpPHHiciIiIiiZg4EREREUnExImIiIhIIiZORERERBIxcSIiIiKSiIkTERERkURMnIiIiIgkYuJEREREJJFJJU7Jycnw8fGBUqlEcHAwDhw48MD6FRUVSEhIgJeXFxQKBdq1a4fU1NRGipaIiIhaGpN5cnhGRgamT5+O5ORk9O7dGx9++CGGDRuGU6dOwdPT0+A4zz33HC5fvoy1a9eiffv2KCoqglqtbuTIiYiIqKUwmcRp+fLliI6OxqRJkwAASUlJ2LVrF1JSUpCYmKhXf+fOncjOzsbZs2fh6OgIAPD29n7gPCoqKlBRUaH9XFpaWn8NICIiombPJE7V3blzB0ePHkV4eLhOeXh4OA4dOmRwnC1btiAkJARLly6Fu7s7Hn/8ccyaNQu3bt2qcT6JiYmwt7fXDh4eHvXaDiIiImreTKLH6erVq6isrISrq6tOuaurKwoLCw2Oc/bsWRw8eBBKpRKbN2/G1atXERsbi5KSkhqvc5ozZw7i4+O1n0tLS5k8EZHR8YXgRE2HSSROGjKZTOezEEKvTKOqqgoymQzp6emwt7cHcO903+jRo7Fq1SpYWVnpjaNQKKBQKOo/cCIiImoRTOJUnbOzM8zNzfV6l4qKivR
6oTTatGkDd3d3bdIEAH5+fhBC4K+//mrQeImIiKhlMonEydLSEsHBwcjKytIpz8rKQmhoqMFxevfujUuXLuHmzZvasj/++ANmZmZ47LHHGjReIiIiaplMInECgPj4eKxZswapqanIzc3FjBkzkJeXh5iYGAD3rk+KjIzU1h83bhycnJwwYcIEnDp1Ct999x1ee+01TJw40eBpOiIiIqK6MplrnMaOHYvi4mIsXLgQBQUF8Pf3R2ZmJry8vAAABQUFyMvL09a3sbFBVlYWpkyZgpCQEDg5OeG5557DokWLjNUEIiIiauZMJnECgNjYWMTGxhr8Li0tTa+sU6dOeqf3iIiIiBqKyZyqIyIiIjJ1TJyIiIiIJGLiRERERCQREyciIiIiiZg4EREREUn0yInT1q1b0bVrV7Rv3x7//Oc/sXv37vqMi4iIiMjkPPTjCJKSkvDkk09i1qxZ2LhxI9q3b48TJ05g8eLFOHv2rPaBlURERETNzUP3OMnlcnz22WfIz8/H2LFj8fzzz2Pr1q0YM2YMVqxYASFEQ8RJREREZHQP3eM0efJkAMCZM2fw0Ucf4e7du/jll1/w888/o6CgAJ07d4aNjQ0OHz5c78ESERERGdMjPzl85cqVGDNmDEJDQ9GlSxeUl5fD398fBw8exI0bN+ozRiIiIiKT8MgXhz/xxBPIycnB4MGDcfnyZXh6euKbb74BANjb29dbgERERESmQnKPU35+Pjw8PHTKlEolRo0ahVGjRtV3XEREREQmR3Li5OXlBQcHBwQGBiIwMBBBQUEIDAxERUUFVq1ahU8//bQh4yQiIiIyOsmJ09mzZ3H8+HEcP34cx44dw8aNG3Hp0iUAgJ2dXYMFSERERGQqJCdO3t7e8Pb21jktl5OTg6ioKCxZsqQhYiMiIiIyKXV65UqvXr2wYsUKLFq0qL7iISIiIjJZkhOnu3fvGizv0KEDTp48WW8BEREREZkqyafqVCoVOnfujK5duyIoKAhdu3ZF27ZtsXLlSoSHhzdkjEREREQmQXLitG/fPpw4cQInTpxAeno63njjDdy6dQsAEB4ejoSEBAQEBCAgIAB+fn4NFjARERGRsUhOnPr06YM+ffpoP1dVVeH333/X3ml39OhRpKamoqioCJWVlQ0SLBEREZExPfIrV8zMzODn5wc/Pz/861//0pZfvny5XgIjIiIiMjV1uqvOEFdX1/qeJBEREZFJqPfEiYiIiKi5YuJEREREJBETJyIiIiKJmDgRERERSWRSiVNycjJ8fHygVCoRHByMAwcOSBrv+++/h1wuR1BQUMMGSERERC2aySROGRkZmD59OhISEnDs2DGEhYVh2LBhyMvLe+B4N27cQGRkJAYNGtRIkRIREVFLZTKJ0/LlyxEdHY1JkybBz88PSUlJ8PDwQEpKygPHe/XVVzFu3Dj06tWr1nlUVFSgtLRUZyAiIiKSyiQSpzt37uDo0aN677wLDw/HoUOHahxv3bp1OHPmDObPny9pPomJibC3t9cOHh4edYqbiIiIWhaTSJyuXr2KyspKvYdnurq6orCw0OA4f/75J15//XWkp6dDLpf2APQ5c+bgxo0b2iE/P7/OsRMREVHL8civXGkIMplM57MQQq8MACorKzFu3Di8+eabePzxxyVPX6FQQKFQ1DlOIiIiaplMInFydnaGubm5Xu9SUVGRwVe4lJWV4ciRIzh27BgmT54M4N5Lh4UQkMvl2L17NwYOHNgosRMREVHLYRKn6iwtLREcHIysrCyd8qysLISGhurVt7Ozwy+//ILjx49rh5iYGHTs2BHHjx9Hjx49Git0IiIiakFMoscJAOLj4/HSSy8hJCQEvXr1wkcffYS8vDzExMQAuHd90sWLF/Hpp5/CzMwM/v7+OuO7uLhAqVTqlRMRERHVF5NJnMaOHYvi4mIsXLgQBQUF8Pf3R2ZmJry8vAAABQUFtT7TiYiIiKghmUziBACxsbGIjY01+F1aWtoDx12
wYAEWLFhQ/0ERERER/X9M4honIiIioqaAiRMRERGRREyciIiIiCRi4kREREQkkUldHE5E9WfIW9uNHYJku+YON3YIRESSsMeJiIiISCImTkREREQSMXEiIiIikoiJExEREZFETJyIiIiIJGLiRERERCQREyciIiIiiZg4EREREUnExImIiIhIIiZORERERBIxcSIiIiKSiIkTERERkURMnIiIiIgkYuJEREREJBETJyIiIiKJmDgRERERScTEiYiIiEgiJk5EREREEjFxIiIiIpKIiRMRERGRREyciIiIiCQyqcQpOTkZPj4+UCqVCA4OxoEDB2qsu2nTJgwePBitW7eGnZ0devXqhV27djVitERERNTSmEzilJGRgenTpyMhIQHHjh1DWFgYhg0bhry8PIP1v/vuOwwePBiZmZk4evQoBgwYgJEjR+LYsWONHDkRERG1FHJjB6CxfPlyREdHY9KkSQCApKQk7Nq1CykpKUhMTNSrn5SUpPN58eLF+Oabb7B161Z07drV4DwqKipQUVGh/VxaWlp/DSAiIqJmzyR6nO7cuYOjR48iPDxcpzw8PByHDh2SNI2qqiqUlZXB0dGxxjqJiYmwt7fXDh4eHnWKm4iIiFoWk0icrl69isrKSri6uuqUu7q6orCwUNI03nvvPZSXl+O5556rsc6cOXNw48YN7ZCfn1+nuImIiKhlMZlTdQAgk8l0Pgsh9MoMWb9+PRYsWIBvvvkGLi4uNdZTKBRQKBR1jpOIiIhaJpNInJydnWFubq7Xu1RUVKTXC1VdRkYGoqOjsWHDBjz11FMNGSYRERG1cCZxqs7S0hLBwcHIysrSKc/KykJoaGiN461fvx7jx4/HF198geHDhzd0mERERNTCmUSPEwDEx8fjpZdeQkhICHr16oWPPvoIeXl5iImJAXDv+qSLFy/i008/BXAvaYqMjMSKFSvQs2dPbW+VlZUV7O3tjdYOIiIiar5MJnEaO3YsiouLsXDhQhQUFMDf3x+ZmZnw8vICABQUFOg80+nDDz+EWq1GXFwc4uLitOVRUVFIS0tr7PCJiIioBTCZxAkAYmNjERsba/C76snQ/v37Gz4gIiIiovuYxDVORERERE0BEyciIiIiiZg4EREREUnExImIiIhIIiZORERERBIxcSIiIiKSiIkTERERkURMnIiIiIgkYuJEREREJBETJyIiIiKJmDgRERERScTEiYiIiEgiJk5EREREEjFxIiIiIpKIiRMRERGRREyciIiIiCRi4kREREQkERMnIiIiIomYOBERERFJxMSJiIiISCImTkREREQSMXEiIiIikkhu7ACIiB7GmdZhxg5BknZXDhg7BCJqAOxxIiIiIpKIiRMRERGRREyciIiIiCQyqcQpOTkZPj4+UCqVCA4OxoEDD75GIDs7G8HBwVAqlfD19cXq1asbKVIiIiJqiUwmccrIyMD06dORkJCAY8eOISwsDMOGDUNeXp7B+ufOnUNERATCwsJw7NgxvPHGG5g6dSr++9//NnLkRERE1FKYzF11y5cvR3R0NCZNmgQASEpKwq5du5CSkoLExES9+qtXr4anpyeSkpIAAH5+fjhy5AjeffddPPvsswbnUVFRgYqKCu3nGzduAABKS0vruTX3qG//3SDTbQgPswyaSruaY5sA6e1qjm0CgLIqdQNGUn+4/TW/NgFNp13NsU1Aw/1ea6YrhKi9sjABFRUVwtzcXGzatEmnfOrUqaJv374GxwkLCxNTp07VKdu0aZOQy+Xizp07BseZP3++AMCBAwcOHDhw4KA35Ofn15qzmESP09WrV1FZWQlXV1edcldXVxQWFhocp7Cw0GB9tVqNq1evok2bNnrjzJkzB/Hx8drPVVVVKCkpgZOTE2QyWT20pOGVlpbCw8MD+fn5sLOzM3Y49YJtajqaY7vYpqajObaLbTINQgiUlZWhbdu2tdY1icRJo3ryIoR4YEJjqL6hcg2FQgGFQqFT1qpVq0e
I1Pjs7OyazAYpFdvUdDTHdrFNTUdzbBfbZHz29vaS6pnExeHOzs4wNzfX610qKirS61XScHNzM1hfLpfDycmpwWIlIiKilsskEidLS0sEBwcjKytLpzwrKwuhoaEGx+nVq5de/d27dyMkJAQWFhYNFisRERG1XCaROAFAfHw81qxZg9TUVOTm5mLGjBnIy8tDTEwMgHvXJ0VGRmrrx8TE4MKFC4iPj0dubi5SU1Oxdu1azJo1y1hNaBQKhQLz58/XO+XYlLFNTUdzbBfb1HQ0x3axTU2PTAgp9941juTkZCxduhQFBQXw9/fH//7v/6Jv374AgPHjx+P8+fPYv3+/tn52djZmzJiBkydPom3btpg9e7Y20SIiIiKqbyaVOBERERGZMpM5VUdERERk6pg4EREREUnExImIiIhIIiZORERERBIxcWoCxo8fD5lMph2cnJwwdOhQ/Pzzz8YOrc4KCwsxZcoU+Pr6QqFQwMPDAyNHjsTevXuNHdpDu389WVhYwNXVFYMHD0ZqaiqqqqqMHV6dVN8GNcPQoUONHdoja25t4vbX9Bw6dAjm5uZNvh0azfm36n5MnJqIoUOHoqCgAAUFBdi7dy/kcjlGjBhh7LDq5Pz58wgODsa+ffuwdOlS/PLLL9i5cycGDBiAuLg4Y4f3SDTr6fz589ixYwcGDBiAadOmYcSIEVCr1cYOr07u3wY1w/r1640dVp00tzZx+2taUlNTMWXKFBw8eBB5eXnGDqdeNMffqupM6l11VDOFQgE3NzcA9143M3v2bPTt2xdXrlxB69atjRzdo4mNjYVMJsOPP/4IlUqlLX/iiScwceJEI0b26O5fT+7u7njyySfRs2dPDBo0CGlpaZg0aZKRI3x097etuWhubeL213SUl5fjq6++wuHDh1FYWIi0tDTMmzfP2GHVWXP8raqOPU5N0M2bN5Geno727ds32ffylZSUYOfOnYiLi9NJmjSa6suXDRk4cCACAwOxadMmY4dCLRC3P9OUkZGBjh07omPHjnjxxRexbt06NLfHKjaH3ypDmDg1Edu2bYONjQ1sbGxga2uLLVu2ICMjA2ZmTXMVnj59GkIIdOrUydihNIpOnTrh/Pnzxg6jTu7fBjXDW2+9Zeyw6qQ5tskQbn+mZ+3atXjxxRcB3Du9dfPmzSZ5bWd1ze23yhCeqmsiBgwYgJSUFAD3emuSk5MxbNgw/Pjjj/Dy8jJydA9P889KJpMZOZLGIYRo8m29fxvUcHR0NFI09aM5tskQbn+m5ffff8ePP/6o7QWUy+UYO3YsUlNT8dRTTxk5urppbr9VhjBxaiJUKhXat2+v/RwcHAx7e3t8/PHHWLRokREjezQdOnSATCZDbm4uRo0aZexwGlxubi58fHyMHUadVN8Gm4Pm2CZDuP2ZlrVr10KtVsPd3V1bJoSAhYUFrl27BgcHByNGVzfN7bfKkObTd9bCyGQymJmZ4datW8YO5ZE4OjpiyJAhWLVqFcrLy/W+v379euMH1UD27duHX375Bc8++6yxQ6EWiNufaVGr1fj000/x3nvv4fjx49rhxIkT8PLyQnp6urFDrFdN/bfKEPY4NREVFRUoLCwEAFy7dg0ffPABbt68iZEjRxo5skeXnJyM0NBQdO/eHQsXLkRAQADUajWysrKQkpKC3NxcY4f40DTrqbKyEpcvX8bOnTuRmJiIESNGIDIy0tjh1cn926CGXC6Hs7OzkSKqu+bWJm5/pm/btm24du0aoqOjYW9vr/Pd6NGjsXbtWkyePNlI0dVdc/yt0iPI5EVFRQkA2sHW1lZ069ZNbNy40dih1dmlS5dEXFyc8PLyEpaWlsLd3V384x//EN9++62xQ3to968nuVwuWrduLZ566imRmpoqKisrjR1enVTfBjVDx44djR3aI2tubeL21zSMGDFCREREGPzu6NGjAoA4evRoI0dVP5rzb9X9ZEI0s/sfiYiIiBoIr3EiIiIikoiJExEREZFETJy
IiIiIJGLiRERERCQREyciIiIiiZg4EREREUnExImIiIhIIiZORERERBIxcSIiIiKSiIkTERERkURMnIiIiIgk+n9hN13plaDdRwAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "x1 = sample_seq(9)\n", + "y = x1.copy()\n", + "bad_pos = 4\n", + "y[bad_pos] = (x1[bad_pos] + 2) % V # decode this position to a WRONG letter\n", + "\n", + "with torch.no_grad():\n", + " scores = []\n", + " for ell in range(len(y)):\n", + " scores.append(mu_phi(window_feat(y, ell).unsqueeze(0)).item())\n", + "\n", + "print(\"target x1:\", decode(x1))\n", + "print(\"decoded y:\", decode(y), f\" (position {bad_pos} corrupted: {LETTERS[x1[bad_pos]]} -> {LETTERS[y[bad_pos]]})\")\n", + "print(\"\\n predicted unmasking quality mu_phi per position:\")\n", + "for ell, s in enumerate(scores):\n", + " flag = \" <-- lowest, would be re-masked\" if ell == int(np.argmin(scores)) else \"\"\n", + " star = \" *corrupted*\" if ell == bad_pos else \"\"\n", + " print(f\" pos {ell} ({LETTERS[y[ell]]}): {s:.3f}{star}{flag}\")\n", + "\n", + "plt.figure(figsize=(6, 2.6))\n", + "colors = [\"crimson\" if ell == bad_pos else \"steelblue\" for ell in range(len(y))]\n", + "plt.bar(range(len(y)), scores, color=colors)\n", + "plt.xticks(range(len(y)), [LETTERS[t] for t in y])\n", + "plt.ylabel(r\"$\\mu_\\phi$\"); plt.title(\"Unmasking quality flags the corrupted token (red)\")\n", + "plt.tight_layout(); plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "- **Unmasking quality** $\\mu_\\star$ and **insertion quality** $\\nu_\\star$ are defined in closed form here from a known posterior; the predictors $\\mu_\\phi,\\nu_\\phi$ trained with the paper's **UQL / IQL** BCE losses recover them.\n", + "- At inference, A2D2 turns these scores that are used in `apply_schedule_aware_remasking` (re-mask low-quality tokens) and `apply_schedule_aware_insertion` (drop low-quality insertions), both shown here on toy states and tied back to a corrupted-decode example.\n", + "- Swap the toy `posterior_at` / `ToyPlanner` for a trained any-length MDM + its planner heads 
(`model/model_wrapper.py::RemaskingAnyOrder`) and this becomes the quality-guided decode used in `a2d2_mol` / `a2d2_pep` / `a2d2_language`." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "peptune", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..9a7021af20cd6805394dc00c9ca30b27a12b4f2a --- /dev/null +++ b/environment.yml @@ -0,0 +1,57 @@ +# Conda environment shared across the molecule, peptide, and language experiments. +# Create with: +# conda env create -f environment.yml +# conda activate a2d2 +# +# NOTE: flash-attn is hardware-specific and must be built against your installed torch +# and CUDA, so it is not listed below. It is imported by the shared transformer backbone +# (model/casual_transformer.py, model/rotary.py) and is required for all experiments. +# After creating the env, install it with: +# pip install flash-attn==2.8.3 --no-build-isolation +# Adjust pytorch-cuda below to match your CUDA toolkit / GPU. 
+name: a2d2 +channels: + - pytorch + - nvidia + - conda-forge +dependencies: + - python=3.11 + - pip + - pytorch + - pytorch-cuda=12.1 + - rdkit=2023.9.6 + - jupyterlab # for demo/quality_inference_demo.ipynb + - pip: + # --- core scientific / DL stack --- + - numpy==1.26.4 + - scipy==1.17.1 + - pandas==2.1.4 + - scikit-learn==1.8.0 + - pytorch-lightning==2.6.0 + - lightning==2.6.1 + - transformers==4.55.4 + - tokenizers==0.21.4 + - safetensors==0.7.0 + - accelerate==0.33.0 + - peft==0.15.1 # LoRA adapters (language experiment) + - datasets==2.19.2 + - huggingface-hub==0.36.2 + - einops==0.8.2 + - timm==1.0.26 + - omegaconf==2.3.0 + - wandb==0.26.1 + # --- molecule experiment --- + - safe-mol==0.1.14 + - datamol==0.12.5 + - PyTDC==1.1.15 + # --- peptide experiment --- + - SmilesPE==0.0.3 + - fair-esm==2.0.0 + - xgboost==3.2.0 + # --- plotting / utilities --- + - matplotlib==3.10.6 + - seaborn==0.13.2 + - tqdm==4.67.1 + - joblib==1.5.3 + - loguru==0.7.3 + - fsspec==2024.3.1 \ No newline at end of file diff --git a/lightning_modules/__init__.py b/lightning_modules/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..67b96f756961ab511def199434620a0cec87082c --- /dev/null +++ b/lightning_modules/__init__.py @@ -0,0 +1,16 @@ +from .mdm import MaskedDiffusionModule +from .any_order import AnyOrderInsertionFlowModule + + +__all__ = [ + "MaskedDiffusionModule", + "AutoregressiveModule", + "AnyOrderInsertionFlowModule", +] + + +def __getattr__(name): + if name == "AutoregressiveModule": + from .autoregressive import AutoregressiveModule + return AutoregressiveModule + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/lightning_modules/any_length_remask.py b/lightning_modules/any_length_remask.py new file mode 100644 index 0000000000000000000000000000000000000000..475818f01b3b3675cf800691d767f926f46d02b7 --- /dev/null +++ b/lightning_modules/any_length_remask.py @@ -0,0 +1,801 @@ +import os +import torch +import 
def strip_orig_mod_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``state_dict`` without ``_orig_mod`` key segments.

    A key such as ``'model._orig_mod.vocab_embed.embedding'`` becomes
    ``'model.vocab_embed.embedding'``. The segment is removed wherever it
    appears as a whole dotted component: in the middle of a key, at the very
    start (``'_orig_mod.layer.weight'`` -> ``'layer.weight'``), and when
    several segments are chained. Names that merely end in ``_orig_mod``
    (for example ``'my_orig_mod.x'``) are left untouched.

    Args:
        state_dict: Mapping from parameter name to value; it is not modified.

    Returns:
        A new dict with cleaned keys and the same value objects. If two keys
        clean to the same name, the one iterated last wins.
    """
    # Match "_orig_mod." only at the start of the key or directly after a dot.
    # Using a look-behind keeps the preceding separator in place, so chained
    # segments are each matched and removed.
    segment = re.compile(r"(?:^|(?<=\.))_orig_mod\.")
    return {segment.sub("", key): value for key, value in state_dict.items()}


@torch.no_grad()
def _binary_auc(scores: torch.Tensor, labels: torch.Tensor) -> float:
    """Rank-based AUROC (Mann-Whitney U statistic) with tie-averaged ranks.

    AUC = P(score[pos] > score[neg]) + 0.5 * P(score[pos] == score[neg]);
    0.5 means no discrimination. Tied scores share the average of the ranks
    they occupy, so the result does not depend on element order.

    Args:
        scores: Tensor of scores; flattened and cast to float.
        labels: Tensor of the same number of elements; entries equal to 1 are
            treated as positives, and ``labels.sum()`` is used as the
            positive count (so labels are expected to be 0/1).

    Returns:
        The AUROC as a Python float, or NaN when only one class is present
        (AUC undefined).
    """
    scores = scores.float().reshape(-1)
    labels = labels.float().reshape(-1)
    n_pos = labels.sum()
    n_neg = labels.numel() - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # torch.unique returns the distinct scores in ascending order. A group of
    # `c` tied values ending at 1-based rank `e` occupies ranks e-c+1 .. e,
    # whose mean is e - (c - 1) / 2.
    _, inverse, counts = torch.unique(scores, return_inverse=True, return_counts=True)
    counts = counts.to(scores.dtype)
    average_rank = counts.cumsum(0) - (counts - 1) / 2
    ranks = average_rank[inverse]

    auc = (ranks[labels == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
    return auc.item()
+ """ + print(f"Loading pretrained model from {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False) + + # Extract state dict - handle different checkpoint formats + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + # Strip _orig_mod keys if present + state_dict = strip_orig_mod_keys(state_dict) + + # Filter out planner keys (if any exist from a previous FT checkpoint) + base_state_dict = {k: v for k, v in state_dict.items() + if not k.startswith('planner.')} + + # Load the base model weights + # Use strict=False to ignore missing schedule_model keys + incompatible_keys = self.load_state_dict(base_state_dict, strict=False) + + # Filter out expected missing planner keys for cleaner output + unexpected_missing = [k for k in incompatible_keys.missing_keys + if not k.startswith('planner.')] + planner_missing = [k for k in incompatible_keys.missing_keys + if k.startswith('planner.')] + + if unexpected_missing: + print(f"Warning: Unexpected missing keys from pretrained checkpoint: {unexpected_missing}") + if planner_missing: + print(f"Note: Planner will be trained from scratch ({len(planner_missing)} parameters)") + if incompatible_keys.unexpected_keys: + print(f"Warning: Unexpected keys in pretrained checkpoint: {incompatible_keys.unexpected_keys}") + + # Freeze base model if specified + if self.config.training.get('freeze_base_model', False): + print("Freezing base model parameters") + for name, param in self.named_parameters(): + if not name.startswith('planner.'): + param.requires_grad = False + + def forward(self, x, t, return_features=False): + # Use parent class forward method + return super().forward(x, t, return_features=return_features) + + def training_loss(self, x1, t): + # Use parent class training_loss for base model loss + # Planner is trained separately via loss_planner_flexible with reward gradients + unmask_loss, insertion_loss, total_loss = 
def training_step(self, batch, batch_idx):
    """Run one optimisation step of the base-model loss and log its parts.

    Args:
        batch: Either the token tensor itself or a dict holding it under
            ``"input_ids"``.
        batch_idx: Index of the batch (unused).

    Returns:
        The total loss returned by ``self.training_loss``.
    """
    # Accept both raw tensors and dict batches.
    x1 = batch["input_ids"] if isinstance(batch, dict) else batch
    t = self.sample_time(x1.shape[0], x1.device)

    # Only the base-model loss is computed here; per the original comments the
    # planner is trained separately.
    unmask_loss, len_loss, loss = self.training_loss(x1, t)

    logged = (
        ("train/unmask_loss", unmask_loss),
        ("train/len_loss", len_loss),
        ("train/total_loss", loss),
    )
    for key, value in logged:
        self.log(key, value, prog_bar=True)

    return loss


def validation_step(self, batch, batch_idx):
    """Evaluate the base-model loss on a validation batch and log its parts.

    Args:
        batch: Either the token tensor itself or a dict holding it under
            ``"input_ids"``.
        batch_idx: Index of the batch (unused).

    Returns:
        The total loss returned by ``self.training_loss``.
    """
    x1 = batch["input_ids"] if isinstance(batch, dict) else batch
    t = self.sample_time(x1.shape[0], x1.device)
    unmask_loss, len_loss, loss = self.training_loss(x1, t)

    logged = (
        ("val/unmask_loss", unmask_loss),
        ("val/len_loss", len_loss),
        ("val_loss", loss),
    )
    for key, value in logged:
        self.log(key, value, prog_bar=True, sync_dist=True)

    return loss
+ """ + print(f"Loading finetuned checkpoint from {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=map_location or 'cpu', weights_only=False) + + # Check if this is a wrapped checkpoint (from PeptideFinetuner) + hparams = checkpoint.get('hyper_parameters', {}) + state_dict = checkpoint.get('state_dict', {}) + + # Check for policy_model prefix in state_dict (indicates PeptideFinetuner wrapper) + has_policy_prefix = any(k.startswith('policy_model.') for k in state_dict.keys()) + + if has_policy_prefix: + # Detect model type (molecule vs peptide) based on vocab size in checkpoint + # Molecule models have vocab size ~1882, peptide models have ~587 + vocab_size = None + for k, v in state_dict.items(): + if 'vocab_embed.embedding' in k: + vocab_size = v.shape[0] + break + + is_molecule_model = vocab_size is not None and vocab_size > 1000 + model_type = "MolFinetuner" if is_molecule_model else "PeptideFinetuner" + print(f"Detected wrapped finetuned checkpoint ({model_type}, vocab_size={vocab_size})") + + # Extract args from hyperparameters + if 'args' not in hparams: + raise ValueError(f"Cannot find 'args' in hyperparameters. 
This checkpoint may not be from {model_type}.") + + args = hparams['args'] + print(f"Found args in hyperparameters, type: {type(args)}") + + # Get original checkpoint path from args + # Handle both Namespace (hasattr) and dict (get) access patterns + original_ckpt_path = None + if hasattr(args, 'checkpoint_path'): + original_ckpt_path = args.checkpoint_path + elif isinstance(args, dict) and 'checkpoint_path' in args: + original_ckpt_path = args['checkpoint_path'] + + # If checkpoint_path is not set or is None, use default pretrained checkpoint + # Select appropriate default based on detected model type + if original_ckpt_path is None: + _repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + if is_molecule_model: + original_ckpt_path = os.path.join(_repo_root, 'pretrained', 'anylength_mol.ckpt') + print(f"Warning: checkpoint_path not found in args, using default molecule pretrained checkpoint") + else: + original_ckpt_path = os.path.join(_repo_root, 'pretrained', 'anylength_pep.ckpt') + print(f"Warning: checkpoint_path not found in args, using default peptide pretrained checkpoint") + + # Try to load config directly from checkpoint first (new checkpoints) + # Fall back to loading from original checkpoint (old checkpoints) + if 'config' in checkpoint: + print("Found config directly in checkpoint") + config = checkpoint['config'] + else: + print(f"Config not in checkpoint, loading from original checkpoint: {original_ckpt_path}") + + # Load config from original pretrained checkpoint + orig_ckpt = torch.load(original_ckpt_path, map_location='cpu', weights_only=False) + if 'config' not in orig_ckpt: + raise ValueError(f"Original checkpoint {original_ckpt_path} does not contain config") + + config = orig_ckpt['config'] + + # Ensure adaptive schedule is enabled + # Need to disable struct mode to add new keys to OmegaConf config + from omegaconf import OmegaConf + if hasattr(config, 'training'): + OmegaConf.set_struct(config, False) + 
config.training.use_adaptive_schedule = True + OmegaConf.set_struct(config, True) + + # Create args object if needed + if not hasattr(args, '__dict__'): + # Convert dict to object with attributes + class Args: + pass + args_obj = Args() + for k, v in args.items(): + setattr(args_obj, k, v) + args = args_obj + + # Initialize model with config and args + model = cls( + config=config, + args=args, + pretrained_checkpoint=None, # Don't reload pretrained, weights already in checkpoint + insertion_planner=getattr(args, 'insertion_planner', False) + ) + + # Extract policy_model weights from state_dict + policy_state = {} + for k, v in state_dict.items(): + if k.startswith('policy_model.'): + # Strip 'policy_model.' prefix + new_key = k[len('policy_model.'):] + policy_state[new_key] = v + + # Load the finetuned weights + incompatible = model.load_state_dict(policy_state, strict=False) + if incompatible.missing_keys or incompatible.unexpected_keys: + print(f"Warning: Incompatible keys when loading finetuned weights:") + if incompatible.missing_keys: + print(f" Missing: {incompatible.missing_keys[:5]}...") + if incompatible.unexpected_keys: + print(f" Unexpected: {incompatible.unexpected_keys[:5]}...") + + # Initialize or load EMA params + if model.use_ema: + if "ema_params" in checkpoint: + # Load EMA params from checkpoint + model.ema_params = checkpoint["ema_params"] + print("Loaded EMA params from checkpoint") + else: + # Initialize empty EMA params (will be populated if needed) + model.ema_params = { + name: param.clone().detach() + for name, param in model.named_parameters() + } + print("Initialized EMA params from current model state") + else: + model.ema_params = {} + + # Load planner state if it exists + if "planner_state" in checkpoint and hasattr(model, 'planner'): + model.planner.load_state_dict(checkpoint["planner_state"], strict=False) + print("Loaded planner state from checkpoint") + + return model + else: + # Not a wrapped checkpoint, use default Lightning 
loading + # But we still need to provide required __init__ arguments + raise NotImplementedError( + "Direct finetuned checkpoints (not wrapped by PeptideFinetuner) are not yet supported. " + "Please provide config and args as kwargs." + ) + + def on_save_checkpoint(self, checkpoint): + """Save config and EMA params, including planner state.""" + # Call parent to save config and base model EMA + super().on_save_checkpoint(checkpoint) + + # Explicitly save planner state + if hasattr(self, 'planner'): + checkpoint["planner_state"] = self.planner.state_dict() + + def on_load_checkpoint(self, checkpoint): + """Load config and reinitialize interpolant, including planner.""" + # For finetuned checkpoints loaded via custom load_from_checkpoint, + # config may not be in checkpoint (it's loaded from original checkpoint) + if "config" in checkpoint: + # Call parent to restore config and interpolant + super().on_load_checkpoint(checkpoint) + else: + # Config already set during __init__ via load_from_checkpoint + # Just restore EMA params if they exist + if self.use_ema and "ema_params" in checkpoint: + self.ema_params = checkpoint["ema_params"] + + # Restore planner state if it exists in checkpoint + if hasattr(self, 'planner') and "planner_state" in checkpoint: + self.planner.load_state_dict(checkpoint["planner_state"]) + print("Loaded planner from checkpoint") + + def loss_wdce_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0): + r""" + Weighted denoising cross entropy loss + X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X) + + log_rnd: [B] — pre-computed importance weights (already softmax-normalized over the full buffer) + x: [B, L] (no mask) + num_replicates: R, number of replicates of each row in x + weight_func: w(lambda) for each sample, 1/lambda by default + centering_strength: float, controls how much of the mean is subtracted (DMPO-style) + softmax_temperature: float, 
temperature for softmax on log_rnd (>1 smooths weights) + """ + + batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L] + + batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) # [B] + if centering: + batch_weights = batch_weights - centering_strength * batch_weights.mean() + + batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0) + + lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R] + lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R] + + t = lamda + + # compute unmasking and insertion loss + interpolant_sample = self.interpolant.sample_interpolant(t, batch) + unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch) + + prediction: ModelPrediction = self(interpolant_sample.xt, t) + + scale_factor = self.config.interpolant.max_length + + match self.unmask_loss_fn: + case "elbo": + mask_indices = interpolant_sample.mask_indices + unmask_loss_all = torch.zeros_like(unmask_weight) # [B*R, L] + unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy( + prediction.token_logits[mask_indices], + interpolant_sample.unmasked[mask_indices], + reduction="none", + ) + unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor # [B*R] + case _: + raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}") + + match self.insert_loss_fn: + case "expectation": + gaps, gaps_mask = interpolant_sample.gaps_and_mask + insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1] + insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo( + gaps[gaps_mask], prediction.expected_gaps[gaps_mask] + ) + insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R] + + case "distribution": + gaps, gaps_mask = interpolant_sample.gaps_and_mask + insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1] + insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy( + prediction.length_posterior[gaps_mask], gaps[gaps_mask] + ) + 
def one_step_sampler(self, xt, t, pred_rate=None):
    """
    Take one Euler unmasking step and, unless insertion is disabled, one
    tau-leaping (Poisson) mask-insertion step.

    Args:
        xt: Current state [B, L]
        t: Time [B]
        pred_rate: Optional pre-computed ModelPrediction. If None, it is
            computed with ``self(xt, t)``.
            NOTE(review): the unmask-rate tensor returned by
            ``interpolant.to_actual_rate`` is edited in place below; if that
            call returns views of ``pred_rate`` the caller's object changes
            too -- confirm against the interpolant.

    Returns:
        ``(new_xt, update_ids)`` when ``self.insertion_planner is False``,
        otherwise ``(new_xt, update_ids, new_ins_xt, update_ins_ids)`` where

        new_xt: state after the unmask step [B, L]
        update_ids: boolean mask of positions changed by the unmask step [B, L]
        new_ins_xt: ``new_xt`` with the sampled mask tokens inserted [B, L]
        update_ins_ids: boolean mask of inserted-mask positions [B, L]
    """
    mask = self.interpolant.mask_token
    pad = self.interpolant.pad_token
    batch_size, L = xt.shape
    device = xt.device
    dt = 1.0 / self.args.total_num_steps

    # Index grids use the actual tensor width L (not the configured
    # max_length) so replicated batches of any width are handled.
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, L)
    )
    pos_idx_L = torch.arange(L, device=device).view(1, L).expand(batch_size, L)

    # ——— predict and convert rates ———
    if pred_rate is None:
        pred_rate = self(xt, t)
    pred_rate = self.interpolant.to_actual_rate(xt, pred_rate, t)
    unmask_rate = pred_rate.unmask_rate  # (B, L, V)
    len_rate = pred_rate.length_rate  # (B, L+1)

    # ——— unmask step (Euler) ———
    is_mask = xt == mask
    mask_pos = is_mask.nonzero(as_tuple=True)
    unmask_rate[~is_mask] = 0
    # Diagonal entry of the rate matrix: minus the total outgoing rate.
    unmask_rate[mask_pos + (mask,)] = 0
    unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

    # Add the "stay" probability (pad positions are routed to the mask column;
    # they are restored to pad after sampling).
    _xt = xt.clone()
    _xt[xt == pad] = mask
    trans_prob.scatter_add_(
        2,
        _xt.unsqueeze(-1),
        torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
    )

    # A masked position must not stay masked: drop the mask column so every
    # mask token is resolved by this single step.
    trans_prob[mask_pos + (mask,)] = 0.0

    # Renormalise every masked row so it sums to 1. Rows with zero total mass
    # cannot be normalised and fall back to a uniform distribution over token
    # ids 0..mask-1. Both cases are handled per row, so a single degenerate
    # row no longer leaves the remaining rows unnormalised.
    masked_probs = trans_prob[mask_pos]  # (N, V)
    prob_sum = masked_probs.sum(dim=-1, keepdim=True)  # (N, 1)
    zero_rows = prob_sum.squeeze(-1) == 0.0  # (N,)
    safe_sum = torch.where(prob_sum == 0.0, torch.ones_like(prob_sum), prob_sum)
    masked_probs = masked_probs / safe_sum
    if zero_rows.any():
        uniform_prob = torch.zeros(
            trans_prob.shape[-1], device=device, dtype=trans_prob.dtype
        )
        uniform_prob[:mask] = 1.0 / mask  # uniform over tokens 0 to mask-1
        masked_probs[zero_rows] = uniform_prob
    trans_prob[mask_pos] = masked_probs

    new_xt = _sample_tokens(trans_prob)
    new_xt[xt == pad] = pad
    new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

    # Non-mask, non-pad positions were copied from xt and pad positions were
    # reset to pad, so the only positions that can differ are former masks.
    update_ids = (xt != new_xt) & (xt != pad)

    if self.insertion_planner is False:
        return new_xt, update_ids

    # ——— Poisson insertion (tau-leaping) — can insert multiple masks per gap ———
    ext = torch.poisson(len_rate * dt).long()  # (B, L+1)
    xt_len = xt.ne(pad).sum(dim=1)  # (B,)
    # The gap axis of the data defines the usable length: ext is (B, L+1).
    actual_max_length = ext.shape[1] - 1
    gaps = torch.arange(ext.shape[1], device=device).view(1, -1)
    # Only gaps 0..xt_len exist for a sequence of length xt_len.
    ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
    # Cancel all insertions for rows that would overflow the maximum length.
    valid = xt_len + ext.sum(dim=1) <= actual_max_length
    ext = ext * valid.view(batch_size, 1).long()
    # Count insertions AFTER cancellation, so a cancelled row keeps its
    # original length instead of being filled with masks up to new_len.
    total_ext = ext.sum(dim=1)

    ext_ex = ext.int().cumsum(dim=1)  # (B, L+1)
    new_len = xt_len + total_ext  # (B,)

    # Start from "all masks up to the new length", then drop the already
    # decoded tokens at their shifted positions; what remains as mask is
    # exactly the inserted material.
    xt_tmp = torch.full_like(xt, pad)
    mask_fill = pos_idx_L < new_len.view(batch_size, 1)
    xt_tmp[mask_fill] = mask

    new_pos_orig = pos_idx_L + ext_ex[:, :actual_max_length]  # (B, L)
    orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
    flat_b = batch_idx_L[orig_mask]
    flat_p = new_pos_orig[orig_mask]
    xt_tmp[flat_b, flat_p] = new_xt[orig_mask]

    new_ins_xt = xt_tmp

    # Newly inserted masks: positions that are mask now but weren't before.
    # NOTE(review): the comparison is against the PRE-shift xt at the same
    # index, so an inserted mask that lands on a former mask or pad index is
    # not flagged (e.g. insertions in the trailing gap) -- confirm this
    # exclusion is intended.
    update_ins_ids = (new_ins_xt == mask) & (xt != mask) & (xt != pad)

    return new_xt, update_ids, new_ins_xt, update_ins_ids
def loss_insert_planner_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0):
    r"""
    Importance-weighted planner loss for the remasking and insertion heads.
    X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)

    log_rnd: [B] — pre-computed importance weights (a softmax over the batch
        is applied here)
    x: [B, L] (no mask)
    num_replicates: R, number of replicates of each row in x
    weight_func: accepted for signature compatibility; not used by this loss
    eps: accepted for signature compatibility; not used by this loss
    centering: if True, subtract ``centering_strength * mean`` from the weights
    centering_strength: float, controls how much of the mean is subtracted (DMPO-style)
    softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights)

    Returns:
        (mean remasking loss, mean insertion loss, mean importance-weighted
        total loss). The insertion loss is zero when
        ``self.planner.insertion_planner`` is falsy.

    Raises:
        ValueError: if the insertion head is enabled but ``one_step_sampler``
            returned no post-insertion state.

    Side effect: stores AUC / label-balance diagnostics in
    ``self._last_planner_metrics``.
    """

    batch = x.repeat_interleave(num_replicates, dim=0)  # [B*R, L]
    batch_size = batch.shape[0]

    batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1)  # [B]
    if centering:
        batch_weights = batch_weights - centering_strength * batch_weights.mean()
    batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)  # [B*R]

    t = torch.rand(batch_size, device=batch.device)  # [B*R]

    # gap_assignment: [B*R, max_gaps, L] maps x1 positions to gap indices
    interpolant_sample, _deleted_mask, gap_assignment = self.interpolant.sample_interpolant_plan(t, batch)

    # The backbone prediction only drives the sampler, whose outputs are
    # discrete states; no gradient is needed, so the forward pass runs without
    # autograd bookkeeping.
    with torch.no_grad():
        prediction: ModelPrediction = self(interpolant_sample.xt, t)
        sampler_out = self.one_step_sampler(interpolant_sample.xt, t, prediction)

    # one_step_sampler returns (xs, update_ids) or
    # (xs, update_ids, new_ins_xt, update_ins_ids).
    xs_unmask, update_unmask_ids = sampler_out[0], sampler_out[1]
    if len(sampler_out) >= 4:
        xs_insert, update_ins_ids = sampler_out[2], sampler_out[3]
    else:
        xs_insert, update_ins_ids = None, None

    # The remasking head scores the freshly-decoded tokens, so it reads the
    # POST-unmask state. Grad stays on here since this head is what we train.
    planner = self.planner(xs_unmask, t)
    remasking_conf = planner["remasking_conf"]  # [B*R, L, 1]

    # The insertion-quality head scores the freshly-inserted mask tokens, so it
    # reads the POST-insertion state (aligned with update_ins_ids below).
    if self.planner.insertion_planner:
        if xs_insert is None:
            raise ValueError(
                "insertion planner is enabled but one_step_sampler returned no insertion state"
            )
        insertion_conf = self.planner(xs_insert, t)["insertion_conf"]  # [B*R, L, 1]
    else:
        insertion_conf = None

    # ——— remasking loss ———
    # interpolant_sample.xt is expressed in the space reordered by `st`; apply
    # the same gather to the ground truth before comparing.
    st = interpolant_sample.st  # [B*R, L] permutation indices
    batch_reordered = torch.gather(batch, 1, st)
    binary_label = (xs_unmask == batch_reordered).float()

    per_token_loss = F.binary_cross_entropy_with_logits(
        remasking_conf.squeeze(-1),  # [B*R, L]
        binary_label,  # [B*R, L]
        reduction="none",
    )
    # Only positions that were actually updated contribute.
    per_token_loss = per_token_loss * update_unmask_ids.float()
    unmask_per_sample_loss = per_token_loss.sum(dim=1) / (
        update_unmask_ids.sum(dim=1).float() + 1e-8
    )  # [B*R]

    # ——— insertion loss (only when the insertion head exists) ———
    insertion_quality = None
    if insertion_conf is not None:
        # Re-predict on xs_insert: the first prediction was made on xt, before
        # the masks were inserted. The result is a training target, so no grad.
        with torch.no_grad():
            prediction_after_insert: ModelPrediction = self(xs_insert, t)
            token_probs = F.softmax(prediction_after_insert.token_logits, dim=-1)  # [B*R, L, V]

            vocab_size = token_probs.shape[-1]
            L = token_probs.shape[1]
            max_gaps = gap_assignment.shape[1]

            # gap_vocab_mask[b, g, v] = 1 if token v is assigned to gap g.
            tokens_expanded = batch.unsqueeze(1).expand(batch_size, max_gaps, L)
            valid_mask = (gap_assignment > 0) & (tokens_expanded != self.interpolant.pad_token)
            gap_vocab_mask = torch.zeros(
                batch_size, max_gaps, vocab_size, device=batch.device, dtype=torch.float
            )
            gap_vocab_mask.scatter_add_(
                2,  # scatter along the vocabulary dimension
                tokens_expanded.clamp(0, vocab_size - 1),
                valid_mask.float(),
            )
            gap_vocab_mask = (gap_vocab_mask > 0).float()  # binarise counts

            # Position p of xs_insert is scored against gap p: the target is the
            # probability mass the model puts on tokens belonging to that gap.
            insertion_quality_full = (token_probs * gap_vocab_mask[:, :L, :]).sum(dim=-1)  # [B*R, L]
            insertion_quality = insertion_quality_full * update_ins_ids.float()

        # Bernoulli cross-entropy with the soft target insertion_quality.
        ins_per_token_loss = F.binary_cross_entropy_with_logits(
            insertion_conf.squeeze(-1),  # [B*R, L]
            insertion_quality,  # [B*R, L]
            reduction="none",
        )
        ins_per_token_loss = ins_per_token_loss * update_ins_ids.float()
        ins_per_sample_loss = ins_per_token_loss.sum(dim=1) / (
            update_ins_ids.sum(dim=1).float() + 1e-8
        )
    else:
        ins_per_sample_loss = torch.zeros_like(unmask_per_sample_loss)

    per_sample_loss = unmask_per_sample_loss + ins_per_sample_loss
    weighted_loss = per_sample_loss * batch_weights  # [B*R]

    # ——— AUC / label-balance diagnostics (the loss alone hides degenerate
    # targets; near-0 BCE can mean "all labels one class", not "learned") ———
    with torch.no_grad():
        metrics = {}
        sel_u = update_unmask_ids.bool()
        if sel_u.any():
            u_scores = remasking_conf.squeeze(-1)[sel_u]
            u_labels = binary_label[sel_u]
            metrics["unmask_auc"] = _binary_auc(u_scores, u_labels)
            metrics["unmask_label_mean"] = u_labels.mean().item()
            metrics["unmask_conf_mean"] = torch.sigmoid(u_scores).mean().item()
            metrics["unmask_n"] = float(sel_u.sum().item())
        if insertion_conf is not None:
            sel_i = update_ins_ids.bool()
            if sel_i.any():
                i_scores = insertion_conf.squeeze(-1)[sel_i]
                i_targets = insertion_quality[sel_i]
                i_labels = (i_targets > 0.5).float()
                metrics["insert_auc"] = _binary_auc(i_scores, i_labels)
                metrics["insert_target_mean"] = i_targets.mean().item()
                metrics["insert_conf_mean"] = torch.sigmoid(i_scores).mean().item()
                metrics["insert_n"] = float(sel_i.sum().item())
        self._last_planner_metrics = metrics

    return unmask_per_sample_loss.mean(), ins_per_sample_loss.mean(), weighted_loss.mean()
+ clean_key = re.sub(r"\._orig_mod\.", ".", key) + new_state_dict[clean_key] = value + return new_state_dict + + +class AnyOrderInsertionFlowModule(pl.LightningModule): + def __init__(self, config: DictConfig): + super().__init__() + self.config = config + self.model_type = config.interpolant.type + self.learning_rate = config.training.learning_rate + self.unmask_loss_fn = config.training.loss_fn.unmask + self.insert_loss_fn = config.training.loss_fn.insert + + # Initialize model based on type + self.model = AnyOrderMaskInsertionFlow(config) + # self.model = torch.compile(self.model) # Disabled: incompatible with flex_attention nested functions + + insert_schedule = get_schedule_from_config(config.interpolant.insert_schedule) + unmask_schedule = get_schedule_from_config(config.interpolant.unmask_schedule) + + # Initialize interpolant + self.interpolant = AnyOrderMaskInsertionInterpolant( + insertion_schedule=insert_schedule, + unmask_schedule=unmask_schedule, + vocab_size=config.interpolant.tokens, + mask_token=config.interpolant.mask_token, + pad_token=config.interpolant.pad_token, + max_length=config.interpolant.max_length, + ) + + # Save hyperparameters + self.save_hyperparameters() + + self.ema_decay = config.training.ema_decay or 0.0 + self.use_ema = self.ema_decay > 0 + self._orig_params = {} + + def forward(self, x, t, return_features: bool = False): + if self.config.training.only_embed_insert: + result = self.model(x, self.interpolant.insertion_schedule.at(t), return_features=return_features) + else: + result = self.model(x, t, return_features=return_features) + return result + + def get_hidden_states(self, indices: torch.Tensor, t: torch.Tensor): + """Delegate to backbone transformer for RemaskingAnyOrder compatibility.""" + return self.model.get_hidden_states(indices, t) + + def training_loss(self, x1, t): + interpolant_sample = self.interpolant.sample_interpolant(t, x1) + unmask_weight, insert_weight = self.interpolant.elbo_weight(t, x1) + + prediction: 
ModelPrediction = self(interpolant_sample.xt, t) + + scale_factor = x1.shape[0] * self.config.interpolant.max_length + + match self.unmask_loss_fn: + case "elbo": + mask_indices = interpolant_sample.mask_indices + unmask_loss = unmask_weight[mask_indices] * F.cross_entropy( + prediction.token_logits[mask_indices], + interpolant_sample.unmasked[mask_indices], + reduction="none", + ) + unmask_loss = unmask_loss.sum() / scale_factor + case _: + raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}") + + match self.insert_loss_fn: + case "expectation": + gaps, gaps_mask = interpolant_sample.gaps_and_mask + insertion_loss = insert_weight[gaps_mask] * jump_kernel_elbo( + gaps[gaps_mask], prediction.expected_gaps[gaps_mask] + ) + insertion_loss = insertion_loss.sum() / scale_factor + + case "distribution": + gaps, gaps_mask = interpolant_sample.gaps_and_mask + insertion_loss = insert_weight[gaps_mask] * F.cross_entropy( + prediction.length_posterior[gaps_mask], gaps[gaps_mask] + ) + insertion_loss = insertion_loss.sum() / scale_factor + + total_loss = unmask_loss + insertion_loss + return unmask_loss, insertion_loss, total_loss + + def prepare_noised_sample(self, x, num_samples=1, t=None): + """ + Run the forward noising process on clean sequences x. + Replicates each sequence num_samples times with independent random times + so that both policy and pretrained can evaluate the same noised data. + + Args: + x: [B, L] clean token sequences (no mask tokens) + num_samples: K, number of noisy time samples per sequence + t: [B*K] optional time values. If None, sampled uniformly. + + Returns: + dict with all artifacts needed by compute_loss_from_noised. 
+ """ + B = x.shape[0] + x_rep = x.repeat_interleave(num_samples, dim=0) # [B*K, L] + if t is None: + t = torch.rand(B * num_samples, device=x.device) + + interpolant_sample = self.interpolant.sample_interpolant(t, x_rep) + unmask_weight, insert_weight = self.interpolant.elbo_weight(t, x_rep) + scale_factor = self.config.interpolant.max_length + + return { + "interpolant_sample": interpolant_sample, + "unmask_weight": unmask_weight, + "insert_weight": insert_weight, + "t": t, + "scale_factor": scale_factor, + "num_samples": num_samples, + "batch_size": B, + } + + def compute_loss_from_noised(self, noised): + """ + Compute per-sample denoising loss given pre-noised data. + Each model runs its own forward pass on the shared noised xt. + + Args: + noised: dict from prepare_noised_sample() + + Returns: + total_loss: [B] per-sample loss averaged over K noisy samples + """ + interpolant_sample = noised["interpolant_sample"] + unmask_weight = noised["unmask_weight"] + insert_weight = noised["insert_weight"] + t = noised["t"] + scale_factor = noised["scale_factor"] + num_samples = noised["num_samples"] + B = noised["batch_size"] + + prediction: ModelPrediction = self(interpolant_sample.xt, t) + + match self.unmask_loss_fn: + case "elbo": + mask_indices = interpolant_sample.mask_indices + unmask_loss_all = torch.zeros_like(unmask_weight) # [B*K, L] + unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy( + prediction.token_logits[mask_indices], + interpolant_sample.unmasked[mask_indices], + reduction="none", + ) + unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor # [B*K] + case _: + raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}") + + match self.insert_loss_fn: + case "expectation": + gaps, gaps_mask = interpolant_sample.gaps_and_mask + insertion_loss_all = torch.zeros_like(insert_weight) # [B*K, L+1] + insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo( + gaps[gaps_mask], 
prediction.expected_gaps[gaps_mask] + ) + insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*K] + case "distribution": + gaps, gaps_mask = interpolant_sample.gaps_and_mask + insertion_loss_all = torch.zeros_like(insert_weight) # [B*K, L+1] + insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy( + prediction.length_posterior[gaps_mask], gaps[gaps_mask] + ) + insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*K] + + per_replicate_loss = unmask_loss + insertion_loss # [B*K] + per_sample_loss = per_replicate_loss.view(B, num_samples).mean(dim=1) # [B] + return per_sample_loss + + def loss_wdce_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False): + r""" + Weighted denoising cross entropy loss + X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X) + + log_rnd: [B]; x: [B, L] (no mask) + num_replicates: R, number of replicates of each row in x + weight_func: w(lambda) for each sample, 1/lambda by default + """ + + print("logrnd shape:", log_rnd.shape) + print("x shape:", x.shape) + + batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L] + + batch_weights = log_rnd.detach().softmax(dim=-1) # [B*R] + if centering: + batch_weights = batch_weights - batch_weights.mean(dim=-1, keepdim=True) + + batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0) + + lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R] + lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R] + + t = lamda + + # compute unmasking and insertion loss + interpolant_sample = self.interpolant.sample_interpolant(t, batch) + unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch) + + prediction: ModelPrediction = self(interpolant_sample.xt, t) + + scale_factor = self.config.interpolant.max_length + + match self.unmask_loss_fn: + case "elbo": + mask_indices = interpolant_sample.mask_indices + unmask_loss_all = torch.zeros_like(unmask_weight) # [B*R, L] + 
unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy( + prediction.token_logits[mask_indices], + interpolant_sample.unmasked[mask_indices], + reduction="none", + ) + unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor # [B*R] + case _: + raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}") + + match self.insert_loss_fn: + case "expectation": + gaps, gaps_mask = interpolant_sample.gaps_and_mask + insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1] + insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo( + gaps[gaps_mask], prediction.expected_gaps[gaps_mask] + ) + insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R] + + case "distribution": + gaps, gaps_mask = interpolant_sample.gaps_and_mask + insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1] + insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy( + prediction.length_posterior[gaps_mask], gaps[gaps_mask] + ) + insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R] + + total_loss = unmask_loss + insertion_loss # [B*R] + # end compute unmasking and insertion loss + + weighted_loss = total_loss * batch_weights # [B*R] + return weighted_loss.mean() + + def sample_time(self, batch_size: int, device: torch.device) -> torch.Tensor: + eps = 1e-6 + interval = 1.0 - eps + interval_size = interval / batch_size + u = torch.rand(batch_size, device=device) + return (torch.arange(batch_size, device=device, dtype=u.dtype) + u) * interval_size + + def training_step(self, batch, batch_idx): + # Extract input data + if isinstance(batch, dict): + batch = batch["input_ids"] + + x1 = batch + t = self.sample_time(x1.shape[0], x1.device) + + # Calculate the combined loss normally + unmask_loss, len_loss, loss = self.training_loss(x1, t) + + # Log component losses + self.log("train/unmask_loss", unmask_loss, prog_bar=True) + self.log("train/len_loss", len_loss, prog_bar=True) + 
self.log("train/total_loss", loss, prog_bar=True) + + + return loss + + def validation_step(self, batch, batch_idx): + if isinstance(batch, dict): + batch = batch["input_ids"] + + x1 = batch + t = self.sample_time(x1.shape[0], x1.device) + unmask_loss, len_loss, loss = self.training_loss(x1, t) + + self.log("val/unmask_loss", unmask_loss, prog_bar=True, sync_dist=True) + self.log("val/len_loss", len_loss, prog_bar=True, sync_dist=True) + self.log("val_loss", loss, prog_bar=True, sync_dist=True) + + return loss + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.learning_rate, + weight_decay=self.config.training.weight_decay, + ) + + warmup_steps = self.config.training.warmup_steps + max_steps = self.config.training.max_steps + + # Always create a fresh schedule starting from step 0 + # This allows extending training beyond original max_steps + linear_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1e-6, + end_factor=1.0, + total_iters=warmup_steps, + last_epoch=-1, + ) + post_warmup = max_steps - warmup_steps + cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=post_warmup, + eta_min=0.0, + last_epoch=-1, + ) + + scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[linear_scheduler, cosine_scheduler], + milestones=[warmup_steps], + last_epoch=-1, + ) + return [optimizer], [{"scheduler": scheduler, "interval": "step"}] + + def optimizer_step( + self, + epoch: int, + batch_idx: int, + optimizer, + optimizer_closure=None, + ): + super().optimizer_step( + epoch, batch_idx, optimizer, optimizer_closure=optimizer_closure + ) + # log learning rate and gradient norm + lr = optimizer.param_groups[0]["lr"] + self.log("train/lr", lr, on_step=True, prog_bar=True) + grad_norm = torch.sqrt( + sum(p.grad.norm(2) ** 2 for p in self.parameters() if p.grad is not None) + ) + self.log("train/grad_norm", grad_norm, on_step=True, prog_bar=True) + + # 
update EMA + if self.use_ema: + for n, p in self.named_parameters(): + self.ema_params[n].mul_(self.ema_decay).add_( + p.data.clone().detach(), alpha=1 - self.ema_decay + ) + + def on_save_checkpoint(self, checkpoint): + checkpoint["config"] = self.config + # save EMA state + if self.use_ema: + checkpoint["ema_params"] = { + n: v.clone() for n, v in self.ema_params.items() + } + + def on_load_checkpoint(self, checkpoint): + self.config = checkpoint["config"] + + insert_schedule = get_schedule_from_config( + self.config.interpolant.insert_schedule + ) + unmask_schedule = get_schedule_from_config( + self.config.interpolant.unmask_schedule + ) + + self.interpolant = AnyOrderMaskInsertionInterpolant( + insertion_schedule=insert_schedule, + unmask_schedule=unmask_schedule, + vocab_size=self.config.interpolant.tokens, + mask_token=self.config.interpolant.mask_token, + pad_token=self.config.interpolant.pad_token, + max_length=self.config.interpolant.max_length, + ) + + self.ema_params = checkpoint["ema_params"] if self.use_ema else {} + + def swap_to_ema(self): + for name, p in self.named_parameters(): + self._orig_params[name] = p.data.clone() + p.data.copy_(self.ema_params[name].to(p.device)) + + def restore_original(self): + for name, p in self.named_parameters(): + p.data.copy_(self._orig_params[name]) + self._orig_params.clear() + + def on_train_start(self): + # initialize and move EMA buffers once model is on correct device + if self.use_ema: + self.ema_params = { + name: param.clone().detach().to(self.device) + for name, param in self.named_parameters() + } + for buf in self.ema_params.values(): + buf.requires_grad = False \ No newline at end of file diff --git a/lightning_modules/bregman.py b/lightning_modules/bregman.py new file mode 100755 index 0000000000000000000000000000000000000000..825f26eacb1b1eaab7b52f3e0406923138681ab7 --- /dev/null +++ b/lightning_modules/bregman.py @@ -0,0 +1,19 @@ +# A file of bregman divergences +import torch + + +def mse(x, y): + 
sq_diff = (x - y) ** 2 + if x.shape != y.shape: + assert False, "x and y must have the same shape" + return sq_diff.reshape(sq_diff.size(0), -1).sum(dim=-1) + + +# TODO: check if this formulation is correct +def jump_kernel_elbo(x, y, eps=1e-6): + # x_safe: true length + # y_safe: predicted length + x_safe = torch.clamp(x, min=eps) + y_safe = torch.clamp(y, min=eps) + + return y_safe - x_safe + x_safe * (torch.log(x_safe) - torch.log(y_safe)) diff --git a/lightning_modules/mdm.py b/lightning_modules/mdm.py new file mode 100755 index 0000000000000000000000000000000000000000..3360925d99a3452898d45b3855f7f25d807bc6d0 --- /dev/null +++ b/lightning_modules/mdm.py @@ -0,0 +1,184 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from model.MDM_transformer import DDiTNoLengthModel +from model.interpolant import MDMInterpolant # replaced relative import +from .schedule import get_schedule_from_config + + +class MaskedDiffusionModule(pl.LightningModule): + def __init__(self, config): + super().__init__() + self.config = config + self.learning_rate = config.training.learning_rate + + # Initialize model (no length head) + self.model = DDiTNoLengthModel(config) + self.model = torch.compile(self.model) + + unmask_schedule = get_schedule_from_config(config.interpolant.unmask_schedule) + + # Initialize interpolant + self.interpolant = MDMInterpolant( + unmask_schedule=unmask_schedule, + vocab_size=config.interpolant.tokens, + mask_token=config.interpolant.mask_token, + pad_token=config.interpolant.pad_token, + max_length=config.interpolant.max_length, + ) + + # Save hyperparameters + self.save_hyperparameters() + + self.ema_decay = config.training.ema_decay or 0.0 + self.use_ema = self.ema_decay > 0 + self._orig_params = {} + + def forward(self, x, t) -> torch.Tensor: + return self.model(x, t) + + def training_loss(self, x1, t): + # sample interpolant and elbo weight + + interpolant_result = self.interpolant.sample_interpolant(t, x1) + 
unmask_weight = self.interpolant.elbo_weight(t, x1) + + # model prediction + predicted_logits = self(interpolant_result.xt, t) + mask_indices = interpolant_result.mask_indices + + # compute unmask loss + loss = unmask_weight[mask_indices] * F.cross_entropy( + predicted_logits[mask_indices], + interpolant_result.unmasked[mask_indices], + reduction="none", + ) + + loss = loss.sum() / (x1.shape[0] * self.config.interpolant.max_length) + return loss + + def training_step(self, batch, batch_idx): + # Extract input data + if isinstance(batch, dict): + batch = batch["input_ids"] + + x1 = batch + batch_size = x1.shape[0] + t = torch.rand(batch_size, device=x1.device) + loss = self.training_loss(x1, t) + + self.log("train/total_loss", loss, prog_bar=True) + + return loss + + def validation_step(self, batch, batch_idx): + if isinstance(batch, dict): + batch = batch["input_ids"] + + x1 = batch + batch_size = x1.shape[0] + + t = torch.rand(batch_size, device=x1.device) + loss = self.training_loss(x1, t) + + self.log("val_loss", loss, prog_bar=True) + + return loss + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.learning_rate, + weight_decay=self.config.training.weight_decay, + ) + warmup_steps = self.config.training.warmup_steps + max_steps = self.config.training.max_steps + + linear_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1e-6, + end_factor=1.0, + total_iters=warmup_steps, + ) + post_warmup = max_steps - warmup_steps + cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, + T_0=post_warmup // 10, + T_mult=1, + eta_min=0.0, + ) + + scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[linear_scheduler, cosine_scheduler], + milestones=[warmup_steps], + ) + return [optimizer], [{"scheduler": scheduler, "interval": "step"}] + + def optimizer_step( + self, + epoch: int, + batch_idx: int, + optimizer, + optimizer_closure=None, + ): + 
def on_train_start(self):
    """Prepare the EMA shadow weights on the module's device.

    ``on_load_checkpoint`` stores the checkpointed EMA tensors in
    ``self.ema_params`` (``on_save_checkpoint`` writes them as CPU copies).
    Those restored averages are kept and only moved to ``self.device``;
    previously they were overwritten with fresh copies of the live weights,
    which silently reset the EMA every time training was resumed.
    Parameters without a restored entry start from a detached copy of the
    current weight. Does nothing when EMA is disabled.
    """
    if not self.use_ema:
        return

    # ``ema_params`` is not created in __init__, so it may be absent on a
    # fresh run; treat "missing" and "empty" the same way.
    restored = getattr(self, "ema_params", None) or {}

    shadows = {}
    for name, param in self.named_parameters():
        if name in restored:
            shadow = restored[name].detach().to(self.device)
        else:
            shadow = param.clone().detach().to(self.device)
        shadow.requires_grad = False
        shadows[name] = shadow
    self.ema_params = shadows
class Schedule(abc.ABC):
    """
    Base class for masking / noising schedules.

    A schedule is a function a : [0, 1] -> [0, 1] with a(0) = 0 and a(1) = 1
    (at least approximately). Subclasses provide ``at``, ``derivative_at``
    and ``inv``; the sampling helpers below are built on top of them.
    """

    @abc.abstractmethod
    def at(self, t: Tensor):
        """Evaluate a(t)."""
        raise NotImplementedError

    @abc.abstractmethod
    def derivative_at(self, t: Tensor):
        """Evaluate d/dt a(t)."""
        raise NotImplementedError

    def rate_scale_factor(self, t: Tensor) -> Tensor:
        """Return a'(t) / (1 - a(t)), the factor used in rate computations."""
        slope = self.derivative_at(t)
        remaining = 1 - self.at(t)
        return slope / remaining

    def sample(self, shape, device) -> Tensor:
        """Draw times by pushing uniform noise of ``shape`` through ``inv``."""
        return self.inv(torch.rand(shape, device=device))

    def sample_truncated(self, threshold, shape, device) -> Tensor:
        """Draw times from the schedule restricted to [threshold, 1].

        Uniform noise is rescaled into [a(threshold), 1) and then mapped back
        to a time with ``inv``.
        """
        noise = torch.rand(shape, device=device)
        floor = self.at(threshold)
        return self.inv(noise * (1 - floor) + floor)

    @abc.abstractmethod
    def inv(self, alpha: Tensor):
        """Return the t for which a(t) = alpha, for alpha in [0, 1]."""
        raise NotImplementedError
def forward(self, indices: torch.Tensor, t: torch.Tensor):
    """
    Run the DDiT stack and return per-token logits.

    Args:
        indices: (B, L) token indices.
        t: (B,) timestep scalars.

    Returns:
        The output of ``self.output_layer(x, c)``, i.e. the token logits
        (a plain tensor; there is no length head in this model).
    """
    B, L = indices.shape

    # _dense_mask is all-True, so every query may attend to every key.
    # The leftover debug ``print(block_mask)`` that ran on every forward
    # pass has been removed.
    block_mask = create_block_mask(_dense_mask, B=B, H=None, Q_LEN=L, KV_LEN=L)

    x = self.vocab_embed(indices)  # (B, L, hidden)
    c = F.silu(self.sigma_map(t))  # (B, cond_dim)
    rotary_cos_sin = self.rotary_emb(x)  # precompute rotary embeddings

    # run the stack
    with torch.amp.autocast("cuda", dtype=torch.bfloat16):
        for block in self.blocks:
            x = block(x, rotary_cos_sin, c, block_mask)

        # NOTE(review): the flattened source does not preserve indentation,
        # so whether the output layer originally ran inside this autocast
        # region could not be verified -- confirm against the real file.
        token_logits = self.output_layer(x, c)
    return token_logits
def _get_bias_dropout_scale(self):
    """Return the fused bias-dropout-add-scale function for the current mode.

    The training variant is chosen when ``self.training`` is truthy,
    otherwise the inference variant.
    """
    if self.training:
        return bias_dropout_add_scale_fused_train
    return bias_dropout_add_scale_fused_inference
def __init__(self, config):
    """Build the causal transformer from ``config``.

    Args:
        config: an OmegaConf-style object (attribute access) or a plain
            ``dict``, which is converted with ``OmegaConf.create``. Reads
            ``interpolant.tokens``, ``interpolant.pad_token`` and
            ``model.{hidden_size, n_heads, cond_dim, dropout, n_blocks}``.
    """
    super().__init__()
    if isinstance(config, dict):
        # This module has no file-level OmegaConf import, so the dict path
        # used to fail with NameError; import it where it is needed.
        from omegaconf import OmegaConf

        config = OmegaConf.create(config)

    self.config = config
    self.vocab_size = config.interpolant.tokens
    self.pad_token = config.interpolant.pad_token

    self.vocab_embed = EmbeddingLayer(config.model.hidden_size, self.vocab_size)
    self.rotary_emb = rotary.Rotary(config.model.hidden_size // config.model.n_heads)
    self.blocks = nn.ModuleList([
        CausalDiTBlock(
            config.model.hidden_size,
            config.model.n_heads,
            config.model.cond_dim,
            dropout=config.model.dropout,
        )
        for _ in range(config.model.n_blocks)
    ])
    self.output_layer = nn.Linear(config.model.hidden_size, self.vocab_size)
def bias_dropout_add_scale(
    x: Tensor,
    bias: Optional[Tensor],
    scale: Tensor,
    residual: Optional[Tensor],
    prob: float,
    training: bool,
) -> Tensor:
    """Compute ``residual + scale * dropout(x + bias)``.

    ``bias`` and ``residual`` are each skipped when ``None``. Dropout uses
    probability ``prob`` and is only active when ``training`` is true.
    Written with plain ``if`` statements because the scripted wrappers in
    this module call it.
    """
    pre_dropout = x
    if bias is not None:
        pre_dropout = x + bias

    scaled = scale * F.dropout(pre_dropout, p=prob, training=training)

    if residual is not None:
        return residual + scaled
    return scaled
@dataclass
class HittingTime:
    """Pair of per-token event times; iterating yields (insertion, unmasking)."""

    insertion_time: Tensor  # Shape [Batch, Length]
    unmasking_time: Tensor  # Shape [Batch, Length]

    def __iter__(self):
        # Allows ``insertion, unmasking = hitting_time`` tuple-style unpacking.
        yield from (self.insertion_time, self.unmasking_time)
class JointInterpolant(abc.ABC):
    """Abstract interface shared by the interpolants in this module.

    Stores the vocabulary/token bookkeeping and declares the three
    operations every interpolant implements: ELBO weights, conversion of a
    model prediction into rates, and sampling of the noised state ``xt``.
    """

    def __init__(
        self,
        vocab_size: int,
        mask_token: int,
        pad_token: int,
        max_length: int,
    ):
        # TODO: Add knobs
        self.mask_token = mask_token
        self.pad_token = pad_token
        self.max_length = max_length
        self.vocab_size = vocab_size

    @abc.abstractmethod
    def elbo_weight(self, t: Tensor, x1: Tensor):
        """
        Return the ELBO weight(s) used in training; may change depending on
        empirical results.

        Shape:
            t: [B]
            x1: [B, L]
        Returns:
            Implementation specific. The insertion interpolant returns
            (weight_unmask [B, L], weight_insert [B, L+1]); the vanilla MDM
            interpolant returns only weight_unmask [B, L].
        """
        raise NotImplementedError

    @abc.abstractmethod
    def to_actual_rate(
        self, xt: Tensor, prediction: "ModelPrediction", t: Tensor
    ) -> "Rate":
        """
        Turn a model prediction at state ``xt`` and time ``t`` into a Rate.

        The signature includes ``xt`` so that it matches the concrete
        implementations in this module, which all take (xt, prediction, t).
        """
        raise NotImplementedError

    @abc.abstractmethod
    def sample_interpolant(self, t: Tensor, x1: Tensor) -> "JointInterpolantResult":
        """
        Sample the interpolant xt from x1 at time t.

        Shapes:
            x1: [B, L]
            t: [B]
        Returns:
            A JointInterpolantResult holding
            xt: [B, L] the noised sequence
            st: [B, L] integer indices into x1 for each xt slot (used with
                torch.gather to recover the training targets)
            Masked positions, targets and gap counts are derived from these
            through the result's properties.
        """
        raise NotImplementedError
def expected_mask_fraction(self, t: Tensor, xt: Tensor) -> Tensor:
    """
    Estimate how many tokens of each sequence are masked at time t.

    A token counts as masked once it has been inserted but not yet
    unmasked, so the masked share is approximated by
    ``insertion_schedule.at(t) - unmask_schedule.at(t)`` clamped to [0, 1].
    Despite the method name, the value returned is that share multiplied by
    the number of non-pad tokens in ``xt`` (an expected count).

    Args:
        t: [B] current time
        xt: [B, L] current sequence (only used to count non-pad tokens)
    Returns:
        [B] expected number of masked tokens per sequence
    """
    inserted_share = self.insertion_schedule.at(t)  # [B]
    revealed_share = self.unmask_schedule.at(t)  # [B]

    masked_share = torch.clamp(inserted_share - revealed_share, min=0.0, max=1.0)
    live_tokens = (xt != self.pad_token).sum(dim=1).float()  # [B]

    return masked_share * live_tokens
def elbo_weight(self, t: Tensor, x1: Tensor):
    """
    Return the per-position ELBO weights for training.

    Each schedule's ``rate_scale_factor(t)`` ([B]) is broadcast along the
    sequence axis: L columns for unmasking, L + 1 columns for insertion.

    Returns:
        (unmask_weight [B, L], insert_weight [B, L+1])
    """
    seq_len = x1.shape[1]

    per_sample_insert = self.insertion_schedule.rate_scale_factor(t)
    per_sample_unmask = self.unmask_schedule.rate_scale_factor(t)

    insert_weight = per_sample_insert.unsqueeze(1).expand(-1, seq_len + 1)
    unmask_weight = per_sample_unmask.unsqueeze(1).expand(-1, seq_len)
    return unmask_weight, insert_weight
def sample_interpolant_plan(
    self, t: Tensor, x1: Tensor
) -> tuple[JointInterpolantResult, Tensor, Tensor]:
    """
    Sample xt like ``sample_interpolant`` and also report where the deleted
    tokens of x1 fall relative to the surviving ones.

    Shapes:
        x1: [B, L]
        t: [B]
    Returns:
        result: JointInterpolantResult with
            xt: [B, L] surviving tokens packed to the front, pad afterwards
            st: [B, L] index into x1 of each xt slot (0 for pad slots)
        deleted_mask: [B, L] True where the x1 token is deleted at time t
        gap_assignment: [B, L+1, L] float one-hot;
            gap_assignment[b, g, pos] = 1 iff x1 position ``pos`` is deleted
            and lies in gap ``g``, where g is the number of surviving tokens
            located before ``pos``.
    """
    insertion_time, unmasking_time = self.hitting_time(t, x1)

    t_col = t[:, None]
    clean_tokens = x1.ne(self.pad_token)
    deleted_tokens = clean_tokens & (t_col < insertion_time)
    masked_tokens = (
        clean_tokens & (t_col >= insertion_time) & (t_col < unmasking_time)
    )

    xt_unpacked = torch.where(
        deleted_tokens,
        self.pad_token,  # for deletion, change to pad token
        torch.where(
            masked_tokens,
            self.mask_token,  # for masking, change to mask token
            x1,
        ),
    )

    # Surviving (non-pad) positions in x1 order; a stable sort on the 0/1
    # indicator packs them to the front without reordering them.
    kept = xt_unpacked.ne(self.pad_token)
    st = kept.to(torch.int32).argsort(dim=1, descending=True, stable=True)
    xt = torch.gather(xt_unpacked, 1, st)
    st[xt == self.pad_token] = 0

    # Gap index of a position = number of surviving positions strictly
    # before it (exclusive prefix count). The earlier version counted valid
    # entries of ``st`` using ``st != 0`` as the validity test, which also
    # discarded a genuinely surviving token at x1 position 0 and shifted
    # every later gap index down by one in that case.
    kept_count = kept.long()
    gap_indices = kept_count.cumsum(dim=1) - kept_count  # [B, L]

    B, L = x1.shape
    gap_assignment = torch.zeros(B, L + 1, L, device=x1.device, dtype=torch.float)
    batch_idx, pos_idx = deleted_tokens.nonzero(as_tuple=True)
    gap_assignment[batch_idx, gap_indices[batch_idx, pos_idx], pos_idx] = 1.0

    result = JointInterpolantResult(
        xt=xt, st=st, _x1=x1, _pad_token=self.pad_token, _mask_token=self.mask_token
    )
    return result, deleted_tokens, gap_assignment
class InsertionQualityHead(nn.Module):
    """Small MLP head: LayerNorm -> Linear -> GELU -> Linear to one value per position."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # Attribute names and creation order are kept so state_dict keys and
        # seeded initialisation stay the same.
        self.norm = nn.LayerNorm(hidden_size)
        self.proj1 = nn.Linear(hidden_size, hidden_size)
        self.act = nn.GELU()
        self.proj2 = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features with last dimension ``hidden_size`` to last dimension 1."""
        hidden = self.act(self.proj1(self.norm(x)))
        return self.proj2(hidden)
class Rotary(torch.nn.Module):
    """Rotary position-embedding tables for packed qkv tensors.

    ``forward`` returns cos/sin tables of shape
    [1, seq_len, 3, 1, dim] (dims: batch, seq_len, qkv, head, dim). The
    slot for v (index 2 on the qkv axis) is set to cos = 1, sin = 0 so that
    applying the rotation leaves v unchanged. Tables are cached and rebuilt
    when the sequence length or the input device changes.
    """

    def __init__(self, dim, base=10_000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        """Return ``(cos, sin)`` tables for the sequence length of ``x``.

        Args:
            x: tensor whose size along ``seq_dim`` is the sequence length and
                whose device decides where the tables live.
            seq_dim: axis of ``x`` holding the sequence length.
        """
        seq_len = x.shape[seq_dim]
        # The cache used to be keyed on seq_len alone, so a call on a new
        # device with an unchanged length returned tables on the old device.
        cache_stale = (
            self.cos_cached is None
            or seq_len != self.seq_len_cached
            or self.cos_cached.device != x.device
        )
        if cache_stale:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            # dims are: batch, seq_len, qkv, head, dim
            cos = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            sin = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            # This makes the transformation on v an identity.
            cos[:, :, 2, :, :].fill_(1.0)
            sin[:, :, 2, :, :].fill_(0.0)
            self.cos_cached = cos
            self.sin_cached = sin

        return self.cos_cached, self.sin_cached
class Schedule(abc.ABC):
    """
    Generic schedule class for masking or noising.

    This represents a function a : [0, 1] -> [0, 1] satisfying a(0) = 0 and
    a(1) = 1, or at least approximately.
    """

    @abc.abstractmethod
    def at(self, t: Tensor):
        """
        Return value a(t)
        """
        raise NotImplementedError

    @abc.abstractmethod
    def derivative_at(self, t: Tensor):
        """
        Return d/dt a(t)
        """
        raise NotImplementedError

    def rate_scale_factor(self, t: Tensor) -> Tensor:
        """
        Return d/dt a(t) / (1 - a(t)), common in rate matrix calculation.

        No guard is applied where a(t) == 1; the division is left as is.
        """
        return self.derivative_at(t) / (1 - self.at(t))

    def sample(self, shape, device) -> Tensor:
        """
        Sample from the schedule, returns a tensor of shape `shape` with values in [0, 1]
        """
        uniform = torch.rand(shape, device=device)
        return self.inv(uniform)

    def sample_truncated(self, threshold, shape, device) -> Tensor:
        """
        Sample from a truncated schedule, returns a tensor of shape `shape` with values in [threshold, 1]

        `threshold` may be a tensor or a plain Python number. Numbers are
        converted to a tensor first, because the concrete schedules evaluate
        `at` with tensor-only functions such as `torch.sin`.
        """
        uniform = torch.rand(shape, device=device)
        alpha_min = self.at(torch.as_tensor(threshold, device=device))
        # Map U[0, 1) onto [a(threshold), 1) before inverting the schedule.
        return self.inv(uniform * (1 - alpha_min) + alpha_min)

    @abc.abstractmethod
    def inv(self, alpha: Tensor):
        """
        Given alpha in [0, 1] such that a(t)=alpha, returns the corresponding t.
        """
        raise NotImplementedError


class LinearSchedule(Schedule):
    """Identity schedule: a(t) = t."""

    def at(self, t: Tensor):
        return t

    def derivative_at(self, t: Tensor):
        # ones_like already inherits the device and dtype of t.
        return torch.ones_like(t)

    def inv(self, alpha: Tensor):
        return alpha


class GeometricSchedule(Schedule, nn.Module):
    """Schedule a(t) = exp(-min_val ** (1 - t) * max_val ** t).

    `min_val` and `max_val` are stored as the buffers `min` and `max`.
    """

    def __init__(self, min_val: float, max_val: float):
        super().__init__()
        self.register_buffer("min", Tensor([min_val]))
        self.register_buffer("max", Tensor([max_val]))

    def at(self, t: Tensor):
        # Buffers are moved per call so the schedule works even when the
        # module itself was never moved to the device of t.
        min_val = self.min.to(t.device)
        max_val = self.max.to(t.device)
        return torch.exp(-(min_val ** (1 - t)) * max_val**t)

    def derivative_at(self, t):
        min_val = self.min.to(t.device)
        max_val = self.max.to(t.device)
        return (
            self.at(t)
            * min_val ** (1 - t)
            * max_val**t
            * (min_val.log() - max_val.log())
        )

    def inv(self, alpha: Tensor):
        log_min = self.min.to(alpha.device).log()
        log_max = self.max.to(alpha.device).log()
        return (torch.log(-torch.log(alpha)) - log_min) / (log_max - log_min)


class SinSchedule(Schedule, nn.Module):
    """Schedule a(t) = sin(pi / 2 * t)."""

    def at(self, t: Tensor):
        return torch.sin(torch.pi / 2 * t)

    def derivative_at(self, t: Tensor):
        return (torch.pi / 2) * torch.cos(torch.pi / 2 * t)

    def inv(self, alpha: Tensor):
        # Clamp so rounding just outside [0, 1] cannot produce NaN in asin.
        return (2 / torch.pi) * torch.asin(alpha.clamp(min=0., max=1.))


class CosineSchedule(Schedule, nn.Module):
    """Schedule a(t) = 1 - cos(pi / 2 * t)."""

    def at(self, t: Tensor):
        return 1 - torch.cos(torch.pi / 2 * t)

    def derivative_at(self, t: Tensor):
        return (torch.pi / 2) * torch.sin(torch.pi / 2 * t)

    def rate_scale_factor(self, t):
        # Closed form of derivative_at(t) / (1 - at(t)).
        return (torch.pi/2) * torch.tan(torch.pi / 2 * t)

    def inv(self, alpha):
        return (2 / torch.pi) * torch.arccos(1 - alpha.clamp(min=0., max=1.))
class LayerNorm(nn.Module):
    """Bias-free layer norm over the last axis with a learnable scale.

    The normalization itself is computed in float32 with CUDA autocast
    disabled; the learnable ``weight`` is applied afterwards.

    Args:
        dim: Size of the last (normalized) axis.
    """

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones([dim]))
        self.dim = dim

    def forward(self, x):
        """Normalize ``x`` over its last axis and scale by ``weight``.

        Works for any number of leading axes; the output has the same shape
        as ``x``.
        """
        with torch.amp.autocast("cuda", enabled=False):
            x = F.layer_norm(x.float(), [self.dim])
        # Plain broadcasting aligns weight with the last axis for every rank.
        # Indexing it as [None, None, :] is the same for rank >= 3 but
        # silently adds leading axes to 1-D and 2-D inputs.
        return x * self.weight
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal feature vector of width ``frequency_embedding_size`` is
    passed through a two-layer MLP with a SiLU in between.

    Args:
        hidden_size: Output width of the MLP.
        frequency_embedding_size: Width of the sinusoidal features.
        silu: Accepted for signature compatibility; it is not read here.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, silu=True):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        # Build the frequency table directly on the device of t; creating it
        # on the CPU and moving it costs a host-to-device copy per call.
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
            / half
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd widths are padded with a single zero column.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        """Return the MLP embedding of the timesteps ``t`` (one row each)."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)
# length scalar head
class ScalarLengthHead(nn.Module):
    """Head that maps each position's features to one non-negative scalar.

    The features are layer-normalized, optionally modulated by a shift and
    scale derived from a conditioning vector, passed through a two-layer MLP
    and a softplus, and finally multiplied by ``normalized_len``.

    Args:
        d_model: Width of the incoming features.
        normalized_len: Factor applied to the softplus output.
        cond_dim: Width of the conditioning vector; ``None`` disables the
            conditioning branch altogether.
    """

    def __init__(self, d_model: int, normalized_len: int, cond_dim: int | None = None):
        super().__init__()
        self.has_cond = cond_dim is not None
        if self.has_cond:
            # Zero initialization makes the modulation start with zero shift
            # and zero scale.
            self.adaLN = nn.Linear(cond_dim, 2 * d_model, bias=True)
            nn.init.zeros_(self.adaLN.weight)
            nn.init.zeros_(self.adaLN.bias)

        self.norm = LayerNorm(d_model)
        self.proj1 = nn.Linear(d_model, d_model)
        self.act = nn.GELU()
        self.proj2 = nn.Linear(d_model, 1)
        self.softplus = nn.Softplus()
        self.normalized_len = normalized_len

    def forward(self, x: torch.Tensor, c: torch.Tensor | None = None):
        """Return one scalar per position, cast back to the dtype of ``x``.

        The conditioning ``c`` is used only when the head was built with a
        ``cond_dim`` and ``c`` is not ``None``.
        """
        hidden = self.norm(x.float())
        if self.has_cond and c is not None:
            modulation = self.adaLN(c.float())[:, None]
            shift, scale = modulation.chunk(2, dim=2)
            hidden = modulate_fused(hidden, shift, scale)

        raw = self.proj2(self.act(self.proj1(hidden)))
        lengths = self.softplus(raw).squeeze(-1) * self.normalized_len
        return lengths.to(x.dtype)
class DDiTBlock(nn.Module):
    """Transformer block with adaLN conditioning, rotary qkv and flex attention.

    A single linear layer maps the conditioning vector to six modulation
    tensors: shift, scale and gate for the attention branch, and the same
    three for the MLP branch. That layer is zero-initialized.

    Args:
        dim: Width of the token features.
        n_heads: Number of attention heads.
        cond_dim: Width of the conditioning vector.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        dropout: Dropout probability passed to the fused residual helpers.
    """

    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads

        # Attention branch.
        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        self.dropout1 = nn.Dropout(dropout)

        # MLP branch.
        self.norm2 = LayerNorm(dim)
        hidden_dim = mlp_ratio * dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden_dim, dim, bias=True),
        )
        self.dropout2 = nn.Dropout(dropout)

        self.dropout = dropout

        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        nn.init.zeros_(self.adaLN_modulation.weight)
        nn.init.zeros_(self.adaLN_modulation.bias)

    def _get_bias_dropout_scale(self):
        """Pick the fused bias/dropout/scale helper for the current mode."""
        if self.training:
            return bias_dropout_add_scale_fused_train
        return bias_dropout_add_scale_fused_inference

    def forward(self, x, rotary_cos_sin, c, block_mask):
        """Apply the attention and MLP sub-layers to ``x``.

        Args:
            x: Token features; the qkv rearrange below treats them as
                ``(batch, seq, dim)``.
            rotary_cos_sin: ``(cos, sin)`` pair passed to the rotary helper.
            c: Conditioning input to ``adaLN_modulation``.
            block_mask: Mask forwarded to ``flex_attention``.

        Returns:
            The updated token features.
        """
        n_batch = x.shape[0]
        residual_fn = self._get_bias_dropout_scale()

        modulation = self.adaLN_modulation(c)[:, None]
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = modulation.chunk(6, dim=2)

        # --- attention ---
        attn_in = modulate_fused(self.norm1(x), shift_msa, scale_msa)
        qkv = rearrange(
            self.attn_qkv(attn_in),
            "b s (three h d) -> b s three h d",
            three=3,
            h=self.n_heads,
        )
        cos, sin = rotary_cos_sin
        # The rotary application runs with CUDA autocast switched off.
        with torch.amp.autocast("cuda", enabled=False):
            qkv = rotary.apply_rotary_pos_emb(
                qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)
            )

        q, k, v = rearrange(qkv, "b s three h d -> three b h s d", three=3)
        attn = flex_attention(q, k, v, block_mask=block_mask)
        attn = rearrange(attn, "b h s d -> b s (h d)", b=n_batch)

        # The block input itself is the residual for the attention branch.
        x = residual_fn(self.attn_out(attn), None, gate_msa, x, self.dropout)

        # --- mlp ---
        mlp_in = modulate_fused(self.norm2(x), shift_mlp, scale_mlp)
        return residual_fn(self.mlp(mlp_in), None, gate_mlp, x, self.dropout)
class DDitFinalLayer(nn.Module):
    """Output head: adaLN-modulated layer norm followed by a linear layer.

    Both the output projection and the modulation layer are initialized to
    zero, so a freshly built layer returns zeros.

    Args:
        hidden_size: Width of the incoming features.
        out_channels: Width of the output.
        cond_dim: Width of the conditioning vector.
    """

    def __init__(self, hidden_size, out_channels, cond_dim):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)

        self.linear = nn.Linear(hidden_size, out_channels)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
        nn.init.zeros_(self.adaLN_modulation.weight)
        nn.init.zeros_(self.adaLN_modulation.bias)

    def forward(self, x, c):
        """Normalize ``x``, modulate it with ``c`` and project it."""
        modulation = self.adaLN_modulation(c)[:, None]
        shift, scale = modulation.chunk(2, dim=2)
        normed = modulate_fused(self.norm_final(x), shift, scale)
        return self.linear(normed)
self.len_predict_type == "expectation": + normalized_len = config.interpolant.max_length + self.len_pred = ScalarLengthHead( + config.model.hidden_size, normalized_len, config.model.cond_dim + ) + else: + raise ValueError(f"Invalid length prediction type: {self.len_predict_type}") + + def _get_bias_dropout_scale(self): + return ( + bias_dropout_add_scale_fused_train + if self.training + else bias_dropout_add_scale_fused_inference + ) + + def forward(self, indices: torch.Tensor, t: torch.Tensor, return_features: bool = False): + B, L = indices.shape + indices = torch.cat( + [ + indices, + self.pad_token + * torch.ones((B, 1), device=indices.device, dtype=torch.int64), + ], + dim=-1, + ) + seq_lens = (indices != self.pad_token).sum(dim=-1).to(indices.device) + block_mask = create_block_mask( + get_mask_mod(seq_lens), + B=B, + H=None, + Q_LEN=indices.shape[1], + KV_LEN=indices.shape[1], + device=indices.device, + ) + + x = self.vocab_embed(indices) + c = F.silu(self.sigma_map(t)) + + rotary_cos_sin = self.rotary_emb(x) + + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + for i in range(len(self.blocks)): + x = self.blocks[i](x, rotary_cos_sin, c, block_mask) + + # Store features after transformer blocks for optional sharing + features = x.clone() if return_features else None + + # --- unmasking --- + token_logits = self.output_layer(x[:, :-1], c) + + # --- length prediction --- + match self.len_predict_type: + case "distribution": + length_posterior = self.len_pred(x, c) + prediction = ModelPrediction( + token_logits=token_logits, + length_posterior=length_posterior, + ) + case "expectation": + prediction = ModelPrediction( + token_logits=token_logits, + expected_gaps=self.len_pred(x, c), + ) + + if return_features: + return prediction, features # [B, L+1, hidden_dim] + else: + return prediction + + def get_hidden_states(self, indices: torch.Tensor, t: torch.Tensor): + """Returns token logits, hidden states, and conditioning for adapter training.""" + B, L = 
indices.shape + indices = torch.cat( + [ + indices, + self.pad_token + * torch.ones((B, 1), device=indices.device, dtype=torch.int64), + ], + dim=-1, + ) + seq_lens = (indices != self.pad_token).sum(dim=-1).to(indices.device) + block_mask = create_block_mask( + get_mask_mod(seq_lens), + B=B, + H=None, + Q_LEN=indices.shape[1], + KV_LEN=indices.shape[1], + device=indices.device, + ) + + x = self.vocab_embed(indices) + c = F.silu(self.sigma_map(t)) + + rotary_cos_sin = self.rotary_emb(x) + + with torch.amp.autocast("cuda", dtype=self.dtype): + for i in range(len(self.blocks)): + x = self.blocks[i](x, rotary_cos_sin, c, block_mask) + + # Hidden states after transformer blocks + hidden_states = x[:, :-1] # [B, L, hidden_dim] - exclude padding position + + # Token logits + token_logits = self.output_layer(hidden_states, c) + + return token_logits, hidden_states, c