File size: 9,747 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import entrypoint_setup

import random
import tempfile
import os
import torch

from esm2.modeling_fastesm import FastEsmForMaskedLM
from esm_plusplus.modeling_esm_plusplus import ESMplusplusForMaskedLM
from e1_fastplms.modeling_e1 import E1ForMaskedLM
from dplm_fastplms.modeling_dplm import DPLMForMaskedLM
from dplm2_fastplms.modeling_dplm2 import (
    DPLM2ForMaskedLM,
    _has_packed_multimodal_layout,
    _normalize_dplm2_input_ids,
)
from embedding_mixin import parse_fasta


CANONICAL_AAS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 canonical amino-acid one-letter codes
SEED = 42  # seed for the module-level `random` RNG so sequence generation is reproducible
DEFAULT_BATCH_SIZE = 4  # default --batch_size for the batched NaN / match tests
MAX_EMBED_LEN = 128  # fixed pad length used to keep max_seqlen identical across runs


# One entry per model under test:
# (display_name, model_class, hf_path, use_model_tokenizer)
# use_model_tokenizer=True -> wrap model.tokenizer in FixedLengthTokenizer and run
# both the NaN test and the batch-vs-single match test; False (E1) -> NaN test only.
MODEL_CONFIGS = [
    ("ESM2",  FastEsmForMaskedLM,       "Synthyra/ESM2-8M",           True),
    ("ESM++", ESMplusplusForMaskedLM,   "Synthyra/ESMplusplus_small",  True),
    ("E1",    E1ForMaskedLM,            "Synthyra/Profluent-E1-150M",  False),
    ("DPLM",  DPLMForMaskedLM,          "Synthyra/DPLM-150M",          True),
    ("DPLM2", DPLM2ForMaskedLM,         "Synthyra/DPLM2-150M",         True),
]


def test_parse_fasta() -> None:
    """Test parse_fasta with single-line and multi-line sequences.

    Writes a small FASTA file to a temporary path, parses it back, and checks
    that multi-line records are joined into single sequences.  The temp file
    is removed in a ``finally`` block so it is not leaked if parse_fasta
    raises (the original unlinked only on the success path).
    """
    fasta_content = (
        ">seq1 a simple protein\n"
        "MKTLLLTLVVVTIVCLDLGYT\n"
        ">seq2 multi-line sequence\n"
        "ACDEFGHIKL\n"
        "MNPQRSTVWY\n"
        ">seq3 another entry\n"
        "MALWMRLLPLLALL\n"
    )
    expected = [
        "MKTLLLTLVVVTIVCLDLGYT",
        "ACDEFGHIKLMNPQRSTVWY",
        "MALWMRLLPLLALL",
    ]
    with tempfile.NamedTemporaryFile(mode='w', suffix='.fasta', delete=False) as f:
        f.write(fasta_content)
        tmp_path = f.name
    try:
        parsed = parse_fasta(tmp_path)
    finally:
        # Always remove the temp file, even when parse_fasta raises.
        os.unlink(tmp_path)
    assert parsed == expected, f"parse_fasta mismatch:\n  got:      {parsed}\n  expected: {expected}"
    print("test_parse_fasta: OK")


class FixedLengthTokenizer:
    """Adapter that forces every tokenization call to pad to a fixed length.

    With a constant padded length (MAX_EMBED_LEN by default), a batch of 1 and
    a batch of N receive tensors of identical shape, so max_seqlen_in_batch
    never varies and results are not perturbed by different softmax vector
    lengths or flash-attention tile sizes.
    """

    def __init__(self, tokenizer, max_length: int = MAX_EMBED_LEN):
        self.max_length = max_length
        self._tok = tokenizer

    def __call__(self, sequences, **kwargs):
        # Ignore caller kwargs and always apply the fixed padding policy.
        fixed_kwargs = dict(
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        return self._tok(sequences, **fixed_kwargs)


def random_sequences(n: int, min_len: int = 8, max_len: int = 64) -> list[str]:
    """Variable-length sequences; used for the NaN test."""
    seqs: list[str] = []
    for _ in range(n):
        # Same RNG call order as before: length first, then the residue draw.
        tail_len = random.randint(min_len, max_len)
        tail = "".join(random.choices(CANONICAL_AAS, k=tail_len))
        seqs.append("M" + tail)
    return seqs


def random_sequences_fixed_len(n: int, length: int = 64) -> list[str]:
    """Fixed-length sequences; used for the match test with E1 (sequence mode)."""
    result: list[str] = []
    for _ in range(n):
        # Leading 'M' plus (length - 1) random residues -> exactly `length` chars.
        body = random.choices(CANONICAL_AAS, k=length - 1)
        result.append("M" + "".join(body))
    return result


def assert_no_nan(embeddings: dict[str, torch.Tensor], label: str) -> None:
    """Fail with a descriptive message if any embedding tensor contains a NaN."""
    for sequence, tensor in embeddings.items():
        found_nan = bool(torch.isnan(tensor).any())
        assert not found_nan, (
            f"[{label}] NaN found in embedding for sequence '{sequence[:20]}...'"
        )


def assert_embeddings_match(
    a: dict[str, torch.Tensor],
    b: dict[str, torch.Tensor],
    label: str,
    atol: float = 5e-3,
) -> None:
    """Compare real-token embeddings from two runs.

    full_embeddings=True already strips padding via emb[mask.bool()], so both
    dicts contain only non-pad token rows and the comparison is over those rows.
    """
    assert set(a) == set(b), f"[{label}] Key sets differ between batch and single runs"
    for seq, emb in a.items():
        first = emb.float()
        second = b[seq].float()
        assert first.shape == second.shape, (
            f"[{label}] Shape mismatch for '{seq[:20]}': {first.shape} vs {second.shape}"
        )
        max_diff = (first - second).abs().max().item()
        assert max_diff <= atol, (
            f"[{label}] Max abs diff {max_diff:.5f} > {atol} for '{seq[:20]}'"
        )


def test_dplm2_multimodal_layout_guard() -> None:
    """Check _has_packed_multimodal_layout on plain, packed, and mismatched layouts."""
    type_kwargs = dict(aa_type=1, struct_type=0, pad_type=2)
    cases = [
        # Plain sequence layout (AA tokens then pad): not packed multimodal.
        (torch.tensor([[1, 1, 1, 1, 1, 1, 0, 2],
                       [1, 1, 1, 1, 1, 0, 2, 2]]), False),
        # AA half followed by struct half: packed multimodal.
        (torch.tensor([[1, 1, 1, 2, 0, 0, 0, 2],
                       [1, 1, 2, 2, 0, 0, 2, 2]]), True),
        # AA/struct halves with mismatched lengths: rejected.
        (torch.tensor([[1, 1, 1, 2, 0, 0, 2, 2]]), False),
    ]
    for sequence_type_ids, expected in cases:
        observed = bool(_has_packed_multimodal_layout(sequence_type_ids, **type_kwargs))
        assert observed == expected
    print("test_dplm2_multimodal_layout_guard: OK")


def test_dplm2_special_token_normalization() -> None:
    """Verify _normalize_dplm2_input_ids remaps out-of-vocab special ids and keeps -100."""
    raw_ids = torch.tensor([[8231, 5, 23, 13, 8229, 1, 8232, -100]])
    remapped = _normalize_dplm2_input_ids(raw_ids, vocab_size=8229)
    expected = torch.tensor([[0, 5, 23, 13, 2, 1, 32, -100]])
    assert torch.equal(remapped, expected), (
        "DPLM2 special-token normalization mismatch:\n"
        f"  got:      {remapped.tolist()}\n"
        f"  expected: {expected.tolist()}"
    )
    print("test_dplm2_special_token_normalization: OK")


def test_model(name: str, model_cls, model_path: str, use_model_tokenizer: bool, batch_size: int) -> None:
    """Load one model and run the NaN test; tokenizer models also get the match test.

    Call order (model load, sequence generation, embed runs) matches the
    original exactly so the shared RNG stream is consumed identically.
    """
    print(f"\n--- {name} ({model_path}) ---")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model_cls.from_pretrained(
        model_path,
        dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    ).eval()

    if use_model_tokenizer:
        # FixedLengthTokenizer pads every batch to MAX_EMBED_LEN regardless of
        # actual sequence lengths, so batch=1 and batch=N see the same tensor
        # shape and produce numerically identical real-token outputs.
        tokenizer = FixedLengthTokenizer(model.tokenizer)
        sequences = random_sequences(n=8)          # variable lengths, all padded to MAX_EMBED_LEN
    else:
        # E1 (sequence mode): control padding length via fixed-length sequences
        # so max_seqlen_in_batch is the same in every forward call.
        tokenizer = None
        sequences = random_sequences_fixed_len(n=8)  # fixed length, no padding variability

    # NaN test ----------------------------------------------------------------
    # Run in bfloat16 to match the real-world user scenario.
    # batch_size > 1 with padding present must produce no NaN in real-token rows.
    embeddings = model.embed_dataset(
        sequences=sequences,
        batch_size=batch_size,
        tokenizer=tokenizer,
        full_embeddings=True,  # extracts only real (non-pad) token rows via emb[mask.bool()]
        embed_dtype=torch.bfloat16,
        save=False,
    )
    assert_no_nan(embeddings, f"{name} NaN check batch_size={batch_size}")
    sample_shapes = [tuple(t.shape) for t in list(embeddings.values())[:3]]
    print(f"  NaN check batch_size={batch_size}: OK  sample shapes={sample_shapes}")

    # Match test (tokenizer / SDPA models only) --------------------------------
    # The NaN fix only touches SDPA backends; E1 uses flash varlen which
    # inherently unpads and is unaffected.  Flash varlen is also NOT
    # bit-deterministic across different batch sizes (different numbers of
    # packed query blocks → different online-softmax accumulation order), so
    # a tight match test for E1 is not meaningful.
    #
    # For SDPA models we cast to float32: bfloat16 CUBLAS selects different
    # mat-mul algorithms for batch=1 vs batch=N (simple vs batched GEMM),
    # producing 1-ULP differences.  Float32 differences are < 1e-3.
    if not use_model_tokenizer:
        return

    model.to(torch.float32)
    shared_kwargs = dict(
        sequences=sequences,
        tokenizer=tokenizer,
        full_embeddings=True,
        embed_dtype=torch.float32,
        save=False,
    )
    batch_embs = model.embed_dataset(batch_size=batch_size, **shared_kwargs)
    single_embs = model.embed_dataset(batch_size=1, **shared_kwargs)
    assert_no_nan(batch_embs, f"{name} match test batch_size={batch_size}")
    assert_no_nan(single_embs, f"{name} match test batch_size=1")
    assert_embeddings_match(batch_embs, single_embs, name)
    print(f"  Match test batch_size={batch_size} vs 1: OK  (non-pad tokens only)")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test embed_dataset produces no NaN with batch_size > 1.")
    parser.add_argument("--models", nargs="+", default=["ESM2", "ESM++", "E1", "DPLM", "DPLM2"])
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE)
    args = parser.parse_args()

    # Seed once so every model sees the same generated sequences.
    random.seed(SEED)
    test_parse_fasta()
    test_dplm2_multimodal_layout_guard()
    test_dplm2_special_token_normalization()

    # Validate every requested name before loading any model.
    configs_by_name = {cfg[0]: cfg for cfg in MODEL_CONFIGS}
    for requested in args.models:
        assert requested in configs_by_name, (
            f"Unknown model '{requested}'. Choose from {sorted(configs_by_name)}"
        )

    for requested in args.models:
        display, model_cls, hf_path, uses_own_tokenizer = configs_by_name[requested]
        test_model(display, model_cls, hf_path, uses_own_tokenizer, args.batch_size)

    print("\nAll tests passed!")