File size: 9,747 Bytes
714cf46 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 | import entrypoint_setup
import random
import tempfile
import os
import torch
from esm2.modeling_fastesm import FastEsmForMaskedLM
from esm_plusplus.modeling_esm_plusplus import ESMplusplusForMaskedLM
from e1_fastplms.modeling_e1 import E1ForMaskedLM
from dplm_fastplms.modeling_dplm import DPLMForMaskedLM
from dplm2_fastplms.modeling_dplm2 import (
DPLM2ForMaskedLM,
_has_packed_multimodal_layout,
_normalize_dplm2_input_ids,
)
from embedding_mixin import parse_fasta
# One-letter codes of the 20 canonical amino acids; the random sequence
# generators below sample residues from this alphabet.
CANONICAL_AAS = "ACDEFGHIKLMNPQRSTVWY"

# Seed handed to the module-level `random` RNG by the __main__ entry point so
# the generated test sequences are reproducible between runs.
SEED = 42

# Batch size the CLI falls back to when --batch_size is omitted.
DEFAULT_BATCH_SIZE = 4

# Fixed pad length used to keep max_seqlen identical across runs
# (FixedLengthTokenizer pads every call to this many tokens).
MAX_EMBED_LEN = 128

# Models under test. Each entry is a 4-tuple:
#   (display_name, model_class, hf_path, use_model_tokenizer)
# use_model_tokenizer=False means test_model embeds without a tokenizer
# (tokenizer=None) and skips the batch-vs-single match test.
MODEL_CONFIGS = [
    ("ESM2",  FastEsmForMaskedLM,     "Synthyra/ESM2-8M",           True),
    ("ESM++", ESMplusplusForMaskedLM, "Synthyra/ESMplusplus_small", True),
    ("E1",    E1ForMaskedLM,          "Synthyra/Profluent-E1-150M", False),
    ("DPLM",  DPLMForMaskedLM,        "Synthyra/DPLM-150M",         True),
    ("DPLM2", DPLM2ForMaskedLM,       "Synthyra/DPLM2-150M",        True),
]
def test_parse_fasta() -> None:
    """Check ``parse_fasta`` on a FASTA file mixing single- and multi-line records.

    Header lines (starting with ``>``) must not appear in the output, and a
    record whose residues are wrapped over several lines must come back as a
    single joined string. Records are expected in file order.

    The temporary FASTA file is always removed, even when ``parse_fasta``
    raises.
    """
    fasta_content = (
        ">seq1 a simple protein\n"
        "MKTLLLTLVVVTIVCLDLGYT\n"
        ">seq2 multi-line sequence\n"
        "ACDEFGHIKL\n"
        "MNPQRSTVWY\n"
        ">seq3 another entry\n"
        "MALWMRLLPLLALL\n"
    )
    expected = [
        "MKTLLLTLVVVTIVCLDLGYT",
        "ACDEFGHIKLMNPQRSTVWY",
        "MALWMRLLPLLALL",
    ]
    # delete=False keeps the file on disk after the `with` closes it, so
    # parse_fasta can reopen it by path; it is removed explicitly below.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.fasta', delete=False) as f:
        f.write(fasta_content)
        tmp_path = f.name
    try:
        parsed = parse_fasta(tmp_path)
    finally:
        # Clean up even if parse_fasta raises, so failures don't leak temp files.
        os.unlink(tmp_path)
    assert parsed == expected, f"parse_fasta mismatch:\n got: {parsed}\n expected: {expected}"
    print("test_parse_fasta: OK")
class FixedLengthTokenizer:
    """Adapter that forces a wrapped tokenizer to a fixed output width.

    Every call asks the wrapped tokenizer for PyTorch tensors padded
    ("max_length") and truncated to exactly ``max_length`` tokens. A batch of
    one and a batch of N therefore get tensors of the same width, which keeps
    max_seqlen_in_batch identical and removes floating-point variability
    caused by different softmax vector lengths / flash-attention tile sizes.

    Any keyword arguments the caller passes are ignored, so the fixed-length
    settings always take effect. Because truncation is on, inputs longer than
    ``max_length`` tokens are cut.
    """

    def __init__(self, tokenizer, max_length: int = MAX_EMBED_LEN):
        self._tok = tokenizer
        self.max_length = max_length

    def __call__(self, sequences, **kwargs):
        # The caller's kwargs are intentionally dropped in favour of these.
        forced_options = {
            "return_tensors": "pt",
            "padding": "max_length",
            "max_length": self.max_length,
            "truncation": True,
        }
        return self._tok(sequences, **forced_options)
def random_sequences(n: int, min_len: int = 8, max_len: int = 64) -> list[str]:
    """Return ``n`` random protein sequences of varying length.

    Each sequence is ``"M"`` followed by between ``min_len`` and ``max_len``
    residues (inclusive) drawn with replacement from ``CANONICAL_AAS``. The
    total length is therefore ``min_len + 1`` .. ``max_len + 1``. Randomness
    comes from the module-level ``random`` RNG. ``test_model`` uses these
    sequences for tokenizer-based models, in both the NaN and match checks.
    """
    generated: list[str] = []
    for _ in range(n):
        # Draw the tail length first, then its residues (same RNG order as before).
        tail_length = random.randint(min_len, max_len)
        tail_residues = random.choices(CANONICAL_AAS, k=tail_length)
        generated.append("M" + "".join(tail_residues))
    return generated
def random_sequences_fixed_len(n: int, length: int = 64) -> list[str]:
    """Return ``n`` random protein sequences, each exactly ``length`` residues long.

    Every sequence is ``"M"`` followed by ``length - 1`` residues drawn with
    replacement from ``CANONICAL_AAS`` using the module-level ``random`` RNG.
    ``test_model`` uses these for models embedded without a tokenizer (E1), so
    every batch has the same unpadded length.

    Raises:
        ValueError: if ``length < 1``. The leading ``"M"`` alone is one
            residue, so a shorter length cannot be honored; without this
            check ``random.choices`` with a negative ``k`` silently returns
            ``[]`` and the result would be ``"M"``.
    """
    if length < 1:
        raise ValueError(f"length must be >= 1, got {length}")
    return [
        "M" + "".join(random.choices(CANONICAL_AAS, k=length - 1))
        for _ in range(n)
    ]
def assert_no_nan(embeddings: dict[str, torch.Tensor], label: str) -> None:
    """Raise AssertionError on the first embedding tensor that contains a NaN.

    Args:
        embeddings: Mapping from sequence string to its embedding tensor.
        label: Tag prefixed to the failure message to identify the check.
    """
    for sequence, tensor in embeddings.items():
        contains_nan = bool(torch.isnan(tensor).any())
        # Only the first 20 characters of the sequence are shown, to keep messages short.
        assert not contains_nan, f"[{label}] NaN found in embedding for sequence '{sequence[:20]}...'"
def assert_embeddings_match(
    a: dict[str, torch.Tensor],
    b: dict[str, torch.Tensor],
    label: str,
    atol: float = 5e-3,
) -> None:
    """Assert two embedding runs agree element-wise within ``atol``.

    Both dicts must have the same keys. For each key, the two tensors are
    upcast to float32, must have the same shape, and their largest absolute
    element difference must not exceed ``atol``. With full_embeddings=True
    the dicts already hold only non-pad token rows (padding stripped via
    emb[mask.bool()]), so only real tokens are compared.

    Args:
        a: First run's embeddings, keyed by sequence.
        b: Second run's embeddings, keyed by sequence.
        label: Tag prefixed to every failure message.
        atol: Maximum allowed absolute difference.
    """
    assert set(a) == set(b), f"[{label}] Key sets differ between batch and single runs"
    for seq, first_tensor in a.items():
        lhs = first_tensor.float()
        rhs = b[seq].float()
        assert lhs.shape == rhs.shape, (
            f"[{label}] Shape mismatch for '{seq[:20]}': {lhs.shape} vs {rhs.shape}"
        )
        worst_diff = torch.max(torch.abs(lhs - rhs)).item()
        assert worst_diff <= atol, f"[{label}] Max abs diff {worst_diff:.5f} > {atol} for '{seq[:20]}'"
def test_dplm2_multimodal_layout_guard() -> None:
    """Check ``_has_packed_multimodal_layout`` on three type-id batches.

    Type ids are passed as aa_type=1, struct_type=0, pad_type=2. The packed
    batch is expected to be detected; the plain-sequence batch and the batch
    whose amino-acid and structure segments differ in length are not.
    """
    layout_types = dict(aa_type=1, struct_type=0, pad_type=2)
    cases = [
        # Plain sequences: a run of AA tokens, one structure-type id, then padding.
        (
            torch.tensor([
                [1, 1, 1, 1, 1, 1, 0, 2],
                [1, 1, 1, 1, 1, 0, 2, 2],
            ]),
            False,
        ),
        # Packed rows: padded AA half followed by an equally long padded structure half.
        (
            torch.tensor([
                [1, 1, 1, 2, 0, 0, 0, 2],
                [1, 1, 2, 2, 0, 0, 2, 2],
            ]),
            True,
        ),
        # 3 AA tokens but only 2 structure tokens.
        (
            torch.tensor([
                [1, 1, 1, 2, 0, 0, 2, 2],
            ]),
            False,
        ),
    ]
    for type_ids, should_be_packed in cases:
        assert bool(_has_packed_multimodal_layout(type_ids, **layout_types)) == should_be_packed
    print("test_dplm2_multimodal_layout_guard: OK")
def test_dplm2_special_token_normalization() -> None:
    """Check ``_normalize_dplm2_input_ids`` remaps ids at or above ``vocab_size``.

    With vocab_size=8229, the ids 8231, 8229 and 8232 are expected to become
    0, 2 and 32. In-range ids and the -100 entry must pass through unchanged.
    """
    vocab_size = 8229
    raw_ids = torch.tensor([[8231, 5, 23, 13, 8229, 1, 8232, -100]])
    actual = _normalize_dplm2_input_ids(raw_ids, vocab_size=vocab_size)
    wanted = torch.tensor([[0, 5, 23, 13, 2, 1, 32, -100]])
    failure_message = (
        "DPLM2 special-token normalization mismatch:\n"
        f" got: {actual.tolist()}\n"
        f" expected: {wanted.tolist()}"
    )
    assert torch.equal(actual, wanted), failure_message
    print("test_dplm2_special_token_normalization: OK")
def _embed(model, sequences, tokenizer, batch_size: int, embed_dtype: torch.dtype) -> dict[str, torch.Tensor]:
    """Call ``model.embed_dataset`` with the settings shared by every check here.

    full_embeddings=True returns per-token rows with padding stripped via
    emb[mask.bool()]; save=False keeps results in memory only.
    """
    return model.embed_dataset(
        sequences=sequences,
        batch_size=batch_size,
        tokenizer=tokenizer,
        full_embeddings=True,
        embed_dtype=embed_dtype,
        save=False,
    )


def test_model(name: str, model_cls, model_path: str, use_model_tokenizer: bool, batch_size: int) -> None:
    """Load one model and run the NaN check and (tokenizer models only) the match check.

    The model is loaded in bfloat16 on CUDA if available, otherwise on CPU.

    1. NaN check: embed 8 random sequences with ``batch_size`` in bfloat16 and
       assert no real-token row contains NaN.
    2. Match check (only when ``use_model_tokenizer``): cast the model to
       float32, embed with ``batch_size`` and with batch size 1, and assert the
       two runs agree via ``assert_embeddings_match``.

    Args:
        name: Display name used in printed output and failure labels.
        model_cls: Class exposing ``from_pretrained``; the instance must
            provide ``embed_dataset`` (and ``tokenizer`` when
            ``use_model_tokenizer``).
        model_path: Hugging Face repo id passed to ``from_pretrained``.
        use_model_tokenizer: Wrap ``model.tokenizer`` in FixedLengthTokenizer;
            if False, embed with tokenizer=None and skip the match check.
        batch_size: Batch size for the batched runs.
    """
    print(f"\n--- {name} ({model_path}) ---")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model_cls.from_pretrained(
        model_path,
        dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    ).eval()

    if use_model_tokenizer:
        # FixedLengthTokenizer pads every batch to MAX_EMBED_LEN regardless of
        # actual sequence lengths, so batch=1 and batch=N see the same tensor
        # shape and produce numerically identical real-token outputs.
        tokenizer = FixedLengthTokenizer(model.tokenizer)
        sequences = random_sequences(n=8)  # variable lengths, all padded to MAX_EMBED_LEN
    else:
        # E1 (sequence mode): control padding length via fixed-length sequences
        # so max_seqlen_in_batch is the same in every forward call.
        tokenizer = None
        sequences = random_sequences_fixed_len(n=8)  # fixed length, no padding variability

    # NaN test ----------------------------------------------------------------
    # Run in bfloat16 to match the real-world user scenario.
    # batch_size > 1 with padding present must produce no NaN in real-token rows.
    nan_embs = _embed(model, sequences, tokenizer, batch_size, torch.bfloat16)
    assert_no_nan(nan_embs, f"{name} NaN check batch_size={batch_size}")
    shapes = [tuple(e.shape) for e in list(nan_embs.values())[:3]]
    print(f" NaN check batch_size={batch_size}: OK sample shapes={shapes}")

    # Match test (tokenizer / SDPA models only) --------------------------------
    # The NaN fix only touches SDPA backends; E1 uses flash varlen which
    # inherently unpads and is unaffected. Flash varlen is also NOT
    # bit-deterministic across different batch sizes (different numbers of
    # packed query blocks -> different online-softmax accumulation order), so
    # a tight match test for E1 is not meaningful.
    #
    # For SDPA models we cast to float32: bfloat16 CUBLAS selects different
    # mat-mul algorithms for batch=1 vs batch=N (simple vs batched GEMM),
    # producing 1-ULP differences. Float32 differences are < 1e-3.
    if not use_model_tokenizer:
        return
    model.to(torch.float32)
    batch_embs = _embed(model, sequences, tokenizer, batch_size, torch.float32)
    single_embs = _embed(model, sequences, tokenizer, 1, torch.float32)
    assert_no_nan(batch_embs, f"{name} match test batch_size={batch_size}")
    assert_no_nan(single_embs, f"{name} match test batch_size=1")
    assert_embeddings_match(batch_embs, single_embs, name)
    print(f" Match test batch_size={batch_size} vs 1: OK (non-pad tokens only)")
def main(argv: "list[str] | None" = None) -> None:
    """Parse CLI options, run the standalone checks, then test each selected model.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (argparse default).

    Invalid ``--models`` names or a non-positive ``--batch_size`` are reported
    through argparse, which prints usage and exits with status 2, before any
    test runs.
    """
    import argparse

    # Derive the valid/default model names from MODEL_CONFIGS so the CLI can't drift.
    model_names = [cfg[0] for cfg in MODEL_CONFIGS]
    parser = argparse.ArgumentParser(description="Test embed_dataset produces no NaN with batch_size > 1.")
    parser.add_argument("--models", nargs="+", default=model_names, choices=model_names)
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE)
    args = parser.parse_args(argv)
    if args.batch_size < 1:
        parser.error(f"--batch_size must be >= 1, got {args.batch_size}")

    random.seed(SEED)
    test_parse_fasta()
    test_dplm2_multimodal_layout_guard()
    test_dplm2_special_token_normalization()

    configs_by_name = {cfg[0]: cfg for cfg in MODEL_CONFIGS}
    for model_name in args.models:
        name, model_cls, model_path, use_model_tokenizer = configs_by_name[model_name]
        test_model(name, model_cls, model_path, use_model_tokenizer, args.batch_size)
    print("\nAll tests passed!")


if __name__ == "__main__":
    main()
|