# STAMP Hybrid AO-GPT (31M, d=512/l=8, epoch 7)
Pretrained AO-GPT (any-order GPT) over STAMP molecular token sequences with a hybrid motif + character vocabulary. Trained on 30M unique filtered molecules. Achieves 79.00% GenMol quality, matching the 112M-parameter AR baseline (79.64%) with 28% of the parameters and 20% of the vocabulary size.
## Highlights
- Small vocab (2481): 2387 high-frequency atomic motifs (freq ≥ 5000) + 49 SMILES character tokens + ~45 STAMP structural tokens. Covers ~91% of motif occurrences as atomic tokens; rare motifs expand to chars.
- Training-time char fallback with log-interpolated probability (~2% at the most frequent motif, ~15% at the cutoff). The model sees each atomic motif in both atomic and char form, closing the train/inference gap for OOV motifs.
- STAMP structural tokens (`[J_*]`, `[B_*]`, `[S_*]`, `[END]`) act as natural motif boundaries, so no extra `[MS]`/`[ME]` markers are needed.
- Drug-like outputs: 100% validity, 100% uniqueness (at N=1024), 79.00% pass the GenMol filter (QED ≥ 0.6 AND SA ≤ 4.0).
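The log-interpolated char-fallback schedule mentioned above can be sketched as follows. The endpoint constants (~2% at the most frequent motif, ~15% at the freq-5000 cutoff) come from this card; the linear-in-log-frequency form and the function shape are assumptions, since the real schedule lives in `HybridVocab.fallback_prob`.

```python
import math

# Assumed reconstruction of the char-fallback schedule: interpolate linearly
# in log-frequency between ~2% (most frequent motif) and ~15% (freq cutoff).
P_TOP, P_CUTOFF = 0.02, 0.15
FREQ_CUTOFF = 5000.0

def fallback_prob(freq, max_freq):
    """Probability of expanding an atomic motif to chars during training."""
    span = math.log(max_freq) - math.log(FREQ_CUTOFF)
    t = (math.log(max_freq) - math.log(freq)) / span
    t = min(max(t, 0.0), 1.0)  # clamp for frequencies outside [cutoff, max]
    return P_TOP + t * (P_CUTOFF - P_TOP)
```

Rarer motifs thus see their character form more often, which is exactly where char-level decoding matters at inference time.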
## Files
| file | what it is |
|---|---|
| `model.pt` | torch checkpoint: `{model_state, cfg, epoch, representation, model_type}` |
| `hybrid_vocab.json` | full vocab with atomic motif map, frequencies, and char expansions |
| `motif_vocab.txt` | source motif-freq file (v3_cm_union format: `smiles\tn_heavy\tfreq`) |
| `hybrid_vocab.py` | self-contained `HybridVocab` class for decoding |
| `config.json` | architecture summary + default sampling + eval numbers |
## Evaluation (N=1024 at T=0.95, top_p=0.85)
| metric | value |
|---|---|
| validity | 100.00% |
| uniqueness (raw SMILES) | 100.00% |
| quality over valid (QED ≥ 0.6 ∧ SA ≤ 4) | 79.16% |
| GenMol score | 79.00% |
| QED mean | 0.727 |
| SA mean | 2.92 |
| diversity (1 − pairwise Tanimoto, 1024-bit Morgan r=2) | 0.860 |
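The diversity number is 1 minus the mean pairwise Tanimoto similarity over fingerprints. A minimal sketch of that computation, representing each fingerprint as a set of on-bit indices (in the actual eval these would be the 1024-bit Morgan r=2 fingerprints):

```python
from itertools import combinations

def tanimoto(a, b):
    """Tanimoto similarity between two fingerprints given as sets of on-bits."""
    inter = len(a & b)
    union = len(a) + len(b) - inter
    return inter / union if union else 1.0

def diversity(fps):
    """1 - mean pairwise Tanimoto over all fingerprint pairs."""
    sims = [tanimoto(a, b) for a, b in combinations(fps, 2)]
    return 1.0 - sum(sims) / len(sims)
```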
Reference (AR baseline, old 12573-token vocab, d=768/l=12, 112M params): 79.64%. The hybrid model matches within noise at 28% of the parameter count.
## Usage
1. Load vocab

```python
from hybrid_vocab import HybridVocab

vocab = HybridVocab.load("hybrid_vocab.json")
# vocab.itos            -> list of 2481 token strings
# vocab.atomic_motifs   -> {smiles: id} for the 2387 motifs
# vocab.motif_freq      -> {smiles: freq}
# vocab.motif_expansion -> {smiles: [char_id, ...]}
```
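Motifs outside the atomic set decode through their character expansions. For illustration, here is a character-level SMILES tokenizer of the usual regex-based kind; this pattern is an assumption for demonstration, since the checkpoint's actual 49-token char inventory is defined in `hybrid_vocab.json`.

```python
import re

# Common SMILES tokenization pattern (illustrative, not the repo's exact inventory):
# bracket atoms, two-letter halogens, stereo markers, organic-subset atoms,
# ring-closure digits (incl. %NN), and bond/branch punctuation.
SMILES_TOKEN_RE = re.compile(
    r"(\[[^\]]+\]|Br|Cl|@@|[BCNOSPFIbcnosp]|%[0-9]{2}|[0-9]|[=#$+\-()/\\.:@~*])"
)

def char_tokenize(smiles):
    tokens = SMILES_TOKEN_RE.findall(smiles)
    assert "".join(tokens) == smiles, f"untokenizable: {smiles}"
    return tokens
```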
2. Load model

```python
import torch

# Option A: clone https://github.com/... (STAMP repo) to get `stamp.benchmark.lm`
from stamp.benchmark.lm import LMConfig, TinyDecoderLM

ckpt = torch.load("model.pt", map_location="cpu", weights_only=False)
cfg = LMConfig(**ckpt["cfg"])
cfg.use_adaln = True  # AO-GPT arch
model = TinyDecoderLM(vocab_size=len(vocab.itos), cfg=cfg, bidirectional=False)

state = ckpt["model_state"]
# Strip the torch.compile prefix if present.
if any(k.startswith("_orig_mod.") for k in state):
    state = {k.replace("_orig_mod.", "", 1): v for k, v in state.items()}
model.load_state_dict(state)
model.eval().cuda()
```
3. Sample (AR, top-p)

```python
import torch
from hybrid_vocab import is_stamp_structural

BOS, EOS = vocab.bos_id, vocab.eos_id
PAD, UNK, MASK = vocab.pad_id, vocab.unk_id, vocab.mask_id
struct_ids = {vocab.stoi[t] for t in vocab.itos if is_stamp_structural(t)}
suppress = {PAD, BOS, MASK, UNK}
T, P = 0.95, 0.85
max_new = 64

@torch.no_grad()
def sample(n_samples=64):
    x = torch.full((n_samples, 1), BOS, dtype=torch.long, device="cuda")
    finished = torch.zeros(n_samples, dtype=torch.bool, device="cuda")
    for step in range(max_new):
        orders = torch.arange(x.size(1), device="cuda").unsqueeze(0).expand(x.size(0), -1)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(x[:, -cfg.max_seq_len:], orders=orders)[:, -1, :].float()
        for sid in suppress:
            logits[:, sid] = float("-inf")
        # top-p (nucleus) filtering
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits / T, dim=-1)
        cum = torch.cumsum(sorted_probs, dim=-1)
        remove = cum > P
        remove[..., 1:] = remove[..., :-1].clone()  # keep the first token past the threshold
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.zeros_like(logits).scatter_(-1, sorted_idx, sorted_logits)
        probs = torch.softmax(logits / T, dim=-1)
        nxt = torch.multinomial(probs, 1).squeeze(-1)
        nxt = torch.where(finished, torch.full_like(nxt, EOS), nxt)
        x = torch.cat([x, nxt.unsqueeze(1)], dim=1)
        finished = finished | (nxt == EOS)
        if finished.all():
            break
    return x
```
4. Decode token stream → SMILES

```python
def decode_to_stamp_tokens(ids):
    """Flush character runs to motif SMILES at structural token boundaries."""
    special = {PAD, BOS, EOS, MASK, UNK}
    out, buf = [], []
    for i in ids:
        if i in special:
            continue
        tok = vocab.itos[i]
        if i in struct_ids:
            if buf:
                out.append("".join(buf))
                buf = []
            out.append(tok)
        else:
            buf.append(tok)
    if buf:
        out.append("".join(buf))
    return out

# Then run the tokens through the STAMP codec in the stamp repo:
# from stamp.benchmark.representations import build_representation
# rep = build_representation("stamp")
# text = rep.detokenize(stamp_tokens)
# mol = rep.codec.decode_stamp_to_mol(text)
```
## Sample outputs

Ten representative draws from this checkpoint (all drug-like, QED ≥ 0.6 ∧ SA ≤ 4):

```
Cn1nc(CNCc2cc(Cl)ccc2Cl)n(C)c1=O                     QED=0.935 SA=2.46 MW=300
CCn1ncc(NC[C@@H]2CCCC[C@@H]2C)c(Br)c1=O              QED=0.923 SA=3.31 MW=327
Cc1cccc(Cl)c1NC(=O)CN1CCO[C@@H](C(F)F)CC1            QED=0.921 SA=2.86 MW=332
CCN1CCN(CC(=O)Nc2cc(C(F)(F)F)ccc2Cl)CC1              QED=0.908 SA=1.94 MW=349
COc1ccc(F)c(CNC(=O)C2=CCCCC2)c1                      QED=0.907 SA=2.16 MW=263
CN1CC[C@@H]2[C@@H](CCCN2C(=O)NCc2ccc(OC(F)F)cc2)C1   QED=0.905 SA=3.03 MW=353
CC[C@H](C(=O)NCc1c(F)cc(F)cc1F)N1CCCC1=O             QED=0.905 SA=2.93 MW=314
Cc1ccc(C2CCN(C(=O)NCc3cccc(F)c3F)CC2)c(=O)n1C        QED=0.897 SA=2.45 MW=375
CN(CC(=O)NCc1ccccc1)C(=O)C12CC3CC(CC(C3)C1)C2        QED=0.896 SA=3.37 MW=340
NCC1CCN(c2cc3c(cc2F)c(=O)c(C(=O)O)cn3C2CC2)C1        QED=0.884 SA=2.96 MW=345
```
## Architecture notes
- AO-GPT: decoder-only transformer with causal attention over a shuffled token order per batch (random permutation of middle tokens, BOS/EOS pinned at ends). Target position is conditioned via AdaLN so the model learns "any-order" decoding.
- Hybrid vocab: structural tokens + SMILES char tokens + atomic motif tokens share a single id space. At training time, atomic motif tokens may be expanded to their SMILES char form with a log-frequency-weighted probability (`HybridVocab.fallback_prob`) so the model is not brittle at char-level decoding.
- Decoder: the STAMP structural tokens delimit motifs; consecutive character tokens between structural tokens concatenate to a single motif SMILES, which the STAMP codec parses to a molecule via a stack machine with safety fallbacks.
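The two training-time mechanics described above (the any-order permutation with pinned ends, and the stochastic char fallback) can be sketched as follows. Function names, the rng plumbing, and the per-token probability lookup are assumptions for illustration; the real versions live in the STAMP repo and `hybrid_vocab.py`.

```python
import random

def ao_order(seq_len, rng=random):
    """Per-sequence decoding order: shuffle the middle, pin BOS and EOS."""
    middle = list(range(1, seq_len - 1))
    rng.shuffle(middle)
    return [0] + middle + [seq_len - 1]

def expand_for_training(ids, expansion, fallback_p, rng=random):
    """Stochastically replace atomic motif ids with their char-id expansions.

    `expansion` maps motif id -> [char_id, ...]; `fallback_p` maps motif id
    -> expansion probability (hypothetical stand-ins for
    HybridVocab.motif_expansion / HybridVocab.fallback_prob).
    """
    out = []
    for i in ids:
        if i in expansion and rng.random() < fallback_p.get(i, 0.0):
            out.extend(expansion[i])  # emit the char-level form
        else:
            out.append(i)             # keep the atomic motif token
    return out
```

Because each motif appears in both forms during training, the model can fall back to char-level decoding for OOV motifs at inference without a distribution shift.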
## License
Apache-2.0.
## Citation

Cite the STAMP representation paper and this repository. (Placeholder: fill in with your actual citation info.)