STAMP Hybrid AO-GPT (31M, d=512/l=8, epoch 7)

Pretrained AO-GPT (any-order GPT) over STAMP molecular token sequences with a hybrid motif + character vocabulary. Trained on 30M unique filtered molecules. Achieves a 79.00% GenMol score, matching the 112M-parameter AR baseline (79.64%) with 28% of the parameters and 20% of the vocabulary size.

Highlights

  • Small vocab (2481 tokens): 2387 high-frequency atomic motifs (freq ≥ 5000) + 49 SMILES character tokens + ~45 STAMP structural tokens. Covers ~91% of motif occurrences as atomic tokens; rarer motifs are expanded to character tokens.
  • Training-time char fallback with log-interpolated probability (~2% at the most frequent motif, ~15% at the cutoff). The model sees each atomic motif in both atomic and char form, closing the train/inference gap for OOV motifs.
  • STAMP structural tokens ([J_*], [B_*], [S_*], [END]) act as natural motif boundaries, so no extra [MS]/[ME] markers are needed.
  • Drug-like outputs: 100% validity, 100% uniqueness (at N=1024), 79.00% pass the GenMol filter (QED ≥ 0.6 AND SA ≤ 4.0).
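The char-fallback schedule described above can be sketched as a linear interpolation in log-frequency between the two quoted endpoints. This is a minimal sketch, not the `HybridVocab.fallback_prob` implementation: the function names, the exact interpolation form, and the `f_max` default (a placeholder for the largest count in `motif_freq`) are assumptions; only the 5000 cutoff and the ~2% / ~15% endpoints come from this card. The toy expansion map uses character strings, whereas the real `motif_expansion` maps to char ids.

```python
import math
import random


def fallback_prob_sketch(freq, f_min=5_000, f_max=10_000_000, p_at_max=0.02, p_at_min=0.15):
    """Hypothetical char-fallback probability for a motif with count `freq`.

    Interpolates linearly in log-frequency: p_at_max at f_max (most frequent
    motif), p_at_min at f_min (the vocabulary cutoff). f_max is a placeholder.
    """
    f = min(max(freq, f_min), f_max)  # clamp into [f_min, f_max]
    t = (math.log(f_max) - math.log(f)) / (math.log(f_max) - math.log(f_min))
    return p_at_max + t * (p_at_min - p_at_max)


def expand_with_fallback(tokens, motif_freq, motif_expansion, rand=random.random):
    """Hypothetical training-time augmentation: each atomic motif token is
    replaced by its character expansion with probability fallback_prob_sketch."""
    out = []
    for tok in tokens:
        if tok in motif_freq and rand() < fallback_prob_sketch(motif_freq[tok]):
            out.extend(motif_expansion[tok])  # char form
        else:
            out.append(tok)  # atomic form, or a non-motif token
    return out
```

With the placeholder `f_max`, a motif seen 100,000 times would fall back to chars about 10% of the time, so rarer motifs are shown in char form more often.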

Files

file               what it is
model.pt           torch checkpoint: {model_state, cfg, epoch, representation, model_type}
hybrid_vocab.json  full vocab with atomic motif map, frequencies, and char expansions
motif_vocab.txt    source motif-freq file (v3_cm_union format: smiles\tn_heavy\tfreq)
hybrid_vocab.py    self-contained HybridVocab class for decoding
config.json        architecture summary + default sampling + eval numbers

Evaluation (N=1024 at T=0.95, top_p=0.85)

metric                                                       value
validity                                                     100.00%
uniqueness (raw SMILES)                                      100.00%
quality over valid (QED ≥ 0.6 ∧ SA ≤ 4)                      79.16%
GenMol score                                                 79.00%
QED mean                                                     0.727
SA mean                                                      2.92
diversity (1 − mean pairwise Tanimoto, 1024-bit Morgan r=2)  0.860

Reference (AR baseline, old 12573-token vocab, d=768/l=12, 112M params): 79.64%. The hybrid model is 0.64 points lower, which is within noise, at 28% of the parameter count.
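The quality and diversity rows can be approximated with a few lines once per-molecule scores and fingerprints exist. A minimal sketch, assuming QED and SA have already been computed (for example with RDKit's `rdkit.Chem.QED.qed` and the `sascorer` module from RDKit Contrib) and that each fingerprint is given as a set of on-bit indices (in practice, 1024-bit Morgan r=2 bit vectors). Averaging over all unordered pairs is an assumption about how the diversity number was computed; the function names are hypothetical.

```python
from itertools import combinations


def passes_genmol_filter(qed, sa):
    """Drug-likeness criterion used in this card: QED >= 0.6 and SA <= 4.0."""
    return qed >= 0.6 and sa <= 4.0


def quality_rate(scores):
    """Fraction of (qed, sa) pairs that pass the filter."""
    if not scores:
        return 0.0
    return sum(passes_genmol_filter(q, s) for q, s in scores) / len(scores)


def tanimoto(a, b):
    """Tanimoto similarity of two fingerprints given as sets of on-bit indices."""
    union = len(a | b)
    return len(a & b) / union if union else 1.0  # two empty fingerprints count as identical


def internal_diversity(fps):
    """1 - mean Tanimoto similarity over all unordered pairs of fingerprints."""
    pairs = list(combinations(fps, 2))
    if not pairs:
        return 0.0
    return 1.0 - sum(tanimoto(a, b) for a, b in pairs) / len(pairs)
```

The first sample output below (QED=0.935, SA=2.46) passes `passes_genmol_filter`; a molecule with QED=0.59 fails regardless of its SA.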

Usage

1. Load vocab

from hybrid_vocab import HybridVocab
vocab = HybridVocab.load("hybrid_vocab.json")
# vocab.itos          -> list of 2481 token strings
# vocab.atomic_motifs -> {smiles: id} for the 2387 motifs
# vocab.motif_freq    -> {smiles: freq}
# vocab.motif_expansion -> {smiles: [char_id, ...]}

2. Load model

import torch

# Clone https://github.com/... (STAMP repo) to get `stamp.benchmark.lm`
from stamp.benchmark.lm import LMConfig, TinyDecoderLM

# weights_only=False unpickles arbitrary objects; only load checkpoints you trust.
ckpt = torch.load("model.pt", map_location="cpu", weights_only=False)
cfg = LMConfig(**ckpt["cfg"])
cfg.use_adaln = True  # AO-GPT arch: target position conditioned via AdaLN
# `vocab` comes from step 1.
model = TinyDecoderLM(vocab_size=len(vocab.itos), cfg=cfg, bidirectional=False)

state = ckpt["model_state"]
# Strip torch.compile prefix if present.
if any(k.startswith("_orig_mod.") for k in state):
    state = {k.replace("_orig_mod.", "", 1): v for k, v in state.items()}
model.load_state_dict(state)
model.eval().cuda()

3. Sample (AR, top-p)

import torch
from hybrid_vocab import is_stamp_structural

BOS, EOS = vocab.bos_id, vocab.eos_id
PAD, UNK, MASK = vocab.pad_id, vocab.unk_id, vocab.mask_id
struct_ids = {vocab.stoi[t] for t in vocab.itos if is_stamp_structural(t)}
suppress = {PAD, BOS, MASK, UNK}
T, P = 0.95, 0.85      # temperature, top-p threshold
n, max_new = 64, 64    # default batch size, max generated tokens

@torch.no_grad()
def sample(n_samples=n):
    x = torch.full((n_samples, 1), BOS, dtype=torch.long, device="cuda")
    finished = torch.zeros(n_samples, dtype=torch.bool, device="cuda")
    for _ in range(max_new):
        # Crop the context and build `orders` for the cropped length so the shapes agree.
        ctx = x[:, -cfg.max_seq_len:]
        # Identity order, i.e. plain left-to-right (AR) decoding.
        orders = torch.arange(ctx.size(1), device="cuda").unsqueeze(0).expand(ctx.size(0), -1)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(ctx, orders=orders)[:, -1, :].float()
        for sid in suppress:
            logits[:, sid] = float("-inf")
        # top-p
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits / T, dim=-1)
        cum = torch.cumsum(sorted_probs, dim=-1)
        remove = cum > P
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.zeros_like(logits).scatter_(-1, sorted_idx, sorted_logits)
        probs = torch.softmax(logits / T, dim=-1)
        nxt = torch.multinomial(probs, 1).squeeze(-1)
        nxt = torch.where(finished, torch.full_like(nxt, EOS), nxt)
        x = torch.cat([x, nxt.unsqueeze(1)], dim=1)
        finished = finished | (nxt == EOS)
        if finished.all():
            break
    return x
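The `remove[..., 1:] = remove[..., :-1]` shift in the sampler is what keeps the token that crosses the `top_p` threshold. A pure-Python sketch of the same masking rule (the function name is hypothetical), applied to probabilities already sorted in descending order:

```python
from itertools import accumulate


def top_p_keep_mask(sorted_probs, p):
    """Nucleus mask over probabilities sorted in descending order.

    Mirrors the tensor code above: a token is dropped only if the cumulative
    mass *before* it already exceeds p, so the token that crosses the
    threshold is kept and the first token is always kept.
    """
    if not sorted_probs:
        return []
    cum = list(accumulate(sorted_probs))
    remove = [c > p for c in cum]
    remove = [False] + remove[:-1]  # shift right by one position
    return [not r for r in remove]
```

For `[0.5, 0.3, 0.15, 0.05]` and P=0.85 the cumulative sums are 0.5, 0.8, 0.95, 1.0. Without the shift only the first two tokens (mass 0.8 < P) would survive; with it the first three (mass 0.95) are kept.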

4. Decode token stream → SMILES

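def decode_to_stamp_tokens(ids):
    """Flush character runs to motif SMILES at structural token boundaries.

    `ids` is one row of the sampled tensor, e.g. sample()[0].tolist().
    """
    skip = {PAD, BOS, MASK, UNK}
    out, buf = [], []
    for i in ids:
        if i == EOS:
            break  # the sampler pads finished rows with EOS
        if i in skip:
            continue
        tok = vocab.itos[i]
        if i in struct_ids:
            if buf:
                out.append("".join(buf))
                buf = []
            out.append(tok)
        else:
            # Atomic motif tokens and SMILES char tokens both accumulate here.
            buf.append(tok)
    if buf:
        out.append("".join(buf))
    return out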

# Then run through the STAMP codec in the stamp repo:
#     from stamp.benchmark.representations import build_representation
#     rep = build_representation("stamp")
#     stamp_tokens = decode_to_stamp_tokens(sample()[0].tolist())
#     text = rep.detokenize(stamp_tokens)
#     mol  = rep.codec.decode_stamp_to_mol(text)
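To see the flush logic without loading the checkpoint, here is a standalone variant that works on token strings instead of ids. The special-token spellings (`<bos>`, `<eos>`, ...) and the prefix test for structural tokens are illustrative assumptions; the real code uses the id constants and `is_stamp_structural` from `hybrid_vocab`.

```python
# Hypothetical spellings for illustration only.
SKIP = {"<pad>", "<bos>", "<mask>", "<unk>"}
EOS = "<eos>"


def is_structural(tok):
    """Assumed test for STAMP structural tokens: [J_*], [B_*], [S_*], [END]."""
    return tok == "[END]" or tok.startswith(("[J_", "[B_", "[S_"))


def group_motifs(tokens):
    """Concatenate runs of motif/char tokens, splitting at structural tokens."""
    out, buf = [], []
    for tok in tokens:
        if tok == EOS:
            break
        if tok in SKIP:
            continue
        if is_structural(tok):
            if buf:
                out.append("".join(buf))
                buf = []
            out.append(tok)
        else:
            buf.append(tok)
    if buf:
        out.append("".join(buf))
    return out
```

For example, `["<bos>", "C", "C", "O", "[J_1]", "c1ccccc1", "[END]", "<eos>"]` groups into `["CCO", "[J_1]", "c1ccccc1", "[END]"]`: the char run `C C O` and the atomic motif `c1ccccc1` each become one motif SMILES.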

Sample outputs

Ten representative draws from this checkpoint (all drug-like, QED ≥ 0.6 ∧ SA ≤ 4):

Cn1nc(CNCc2cc(Cl)ccc2Cl)n(C)c1=O                   QED=0.935 SA=2.46 MW=300
CCn1ncc(NC[C@@H]2CCCC[C@@H]2C)c(Br)c1=O            QED=0.923 SA=3.31 MW=327
Cc1cccc(Cl)c1NC(=O)CN1CCO[C@@H](C(F)F)CC1          QED=0.921 SA=2.86 MW=332
CCN1CCN(CC(=O)Nc2cc(C(F)(F)F)ccc2Cl)CC1            QED=0.908 SA=1.94 MW=349
COc1ccc(F)c(CNC(=O)C2=CCCCC2)c1                    QED=0.907 SA=2.16 MW=263
CN1CC[C@@H]2[C@@H](CCCN2C(=O)NCc2ccc(OC(F)F)cc2)C1 QED=0.905 SA=3.03 MW=353
CC[C@H](C(=O)NCc1c(F)cc(F)cc1F)N1CCCC1=O           QED=0.905 SA=2.93 MW=314
Cc1ccc(C2CCN(C(=O)NCc3cccc(F)c3F)CC2)c(=O)n1C      QED=0.897 SA=2.45 MW=375
CN(CC(=O)NCc1ccccc1)C(=O)C12CC3CC(CC(C3)C1)C2      QED=0.896 SA=3.37 MW=340
NCC1CCN(c2cc3c(cc2F)c(=O)c(C(=O)O)cn3C2CC2)C1      QED=0.884 SA=2.96 MW=345

Architecture notes

  • AO-GPT: decoder-only transformer with causal attention over a shuffled token order per batch (random permutation of middle tokens, BOS/EOS pinned at ends). Target position is conditioned via AdaLN so the model learns "any-order" decoding.
  • Hybrid vocab: structural tokens + SMILES char tokens + atomic motif tokens share a single id space. At training time, atomic motif tokens may be expanded to their SMILES char form with a log-frequency-weighted probability (HybridVocab.fallback_prob) so the model is not brittle at char-level decoding.
  • Decoder: the STAMP structural tokens delimit motifs; consecutive character tokens between structural tokens concatenate to a single motif SMILES, which the STAMP codec parses to a molecule via a stack machine with safety fallbacks.
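The order handling in the AO-GPT bullet can be sketched as follows. This is an illustration under assumptions, not the training code: it draws one permutation with BOS (position 0) and EOS (last position) pinned, then lists, for each step, the input token, the target position (the quantity AdaLN would be conditioned on), and the target token. Function names are hypothetical.

```python
import random


def sample_order(seq_len, rng=random):
    """Random permutation of 0..seq_len-1 with BOS (0) and EOS (seq_len-1) pinned."""
    if seq_len < 2:
        return list(range(seq_len))
    middle = list(range(1, seq_len - 1))
    rng.shuffle(middle)  # shuffle only the interior positions
    return [0] + middle + [seq_len - 1]


def training_pairs(tokens, order):
    """For each step t: (input token, target position, target token).

    The model reads tokens in `order`; the target position order[t + 1] is
    what the next prediction would be conditioned on.
    """
    seq = [tokens[i] for i in order]
    return [(seq[t], order[t + 1], seq[t + 1]) for t in range(len(order) - 1)]
```

With the identity order this reduces to ordinary next-token prediction, which is why the sampler in step 3 can pass `orders = arange(...)`.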

License

Apache-2.0.

Citation

Cite the STAMP representation paper and this repository. (Placeholder: fill in with your actual citation info.)
