| | import entrypoint_setup |
| |
|
| | import random |
| | import tempfile |
| | import os |
| | import torch |
| |
|
| | from esm2.modeling_fastesm import FastEsmForMaskedLM |
| | from esm_plusplus.modeling_esm_plusplus import ESMplusplusForMaskedLM |
| | from e1_fastplms.modeling_e1 import E1ForMaskedLM |
| | from dplm_fastplms.modeling_dplm import DPLMForMaskedLM |
| | from dplm2_fastplms.modeling_dplm2 import ( |
| | DPLM2ForMaskedLM, |
| | _has_packed_multimodal_layout, |
| | _normalize_dplm2_input_ids, |
| | ) |
| | from embedding_mixin import parse_fasta |
| |
|
| |
|
# The 20 canonical amino-acid one-letter codes; used to build random test sequences.
CANONICAL_AAS = "ACDEFGHIKLMNPQRSTVWY"
# Seed for the `random` module so generated test sequences are reproducible.
SEED = 42
# Batch size used when --batch_size is not supplied on the command line.
DEFAULT_BATCH_SIZE = 4
# Fixed token length FixedLengthTokenizer pads/truncates every call to.
MAX_EMBED_LEN = 128


# One entry per model under test:
# (display name, model class, HuggingFace checkpoint path, use_model_tokenizer).
# use_model_tokenizer=False (E1 only) means no external tokenizer wrapper is
# used — presumably the model tokenizes raw sequences itself; see test_model.
MODEL_CONFIGS = [
    ("ESM2", FastEsmForMaskedLM, "Synthyra/ESM2-8M", True),
    ("ESM++", ESMplusplusForMaskedLM, "Synthyra/ESMplusplus_small", True),
    ("E1", E1ForMaskedLM, "Synthyra/Profluent-E1-150M", False),
    ("DPLM", DPLMForMaskedLM, "Synthyra/DPLM-150M", True),
    ("DPLM2", DPLM2ForMaskedLM, "Synthyra/DPLM2-150M", True),
]
| |
|
| |
|
def test_parse_fasta() -> None:
    """Test parse_fasta with single-line and multi-line sequences.

    Fix: the temporary FASTA file is now removed in a ``finally`` block, so it
    no longer leaks if ``parse_fasta`` raises before the original ``os.unlink``
    call was reached.
    """
    fasta_content = (
        ">seq1 a simple protein\n"
        "MKTLLLTLVVVTIVCLDLGYT\n"
        ">seq2 multi-line sequence\n"
        "ACDEFGHIKL\n"
        "MNPQRSTVWY\n"
        ">seq3 another entry\n"
        "MALWMRLLPLLALL\n"
    )
    # seq2 spans two lines and must be joined into one sequence.
    expected = [
        "MKTLLLTLVVVTIVCLDLGYT",
        "ACDEFGHIKLMNPQRSTVWY",
        "MALWMRLLPLLALL",
    ]
    # delete=False so the file survives the `with` block and can be re-opened
    # by parse_fasta (required on platforms where an open NamedTemporaryFile
    # cannot be opened a second time).
    with tempfile.NamedTemporaryFile(mode='w', suffix='.fasta', delete=False) as f:
        f.write(fasta_content)
        tmp_path = f.name
    try:
        parsed = parse_fasta(tmp_path)
    finally:
        os.unlink(tmp_path)
    assert parsed == expected, f"parse_fasta mismatch:\n got: {parsed}\n expected: {expected}"
    print("test_parse_fasta: OK")
| |
|
| |
|
class FixedLengthTokenizer:
    """Wraps a tokenizer so every call pads to exactly MAX_EMBED_LEN tokens.

    Both batch=1 and batch=N therefore receive tensors of the same shape,
    keeping max_seqlen_in_batch identical and eliminating floating-point
    variability from different softmax vector lengths / flash-attention tile sizes.
    """

    def __init__(self, tokenizer, max_length: int = MAX_EMBED_LEN):
        # Underlying tokenizer that does the actual encoding.
        self._tok = tokenizer
        # Every call is padded/truncated to exactly this many tokens.
        self.max_length = max_length

    def __call__(self, sequences, **kwargs):
        # NOTE(review): extra kwargs are accepted but ignored — presumably so
        # the fixed-shape settings below always win; confirm no caller relies
        # on pass-through options.
        fixed_settings = {
            "return_tensors": "pt",
            "padding": "max_length",
            "max_length": self.max_length,
            "truncation": True,
        }
        return self._tok(sequences, **fixed_settings)
| |
|
| |
|
def random_sequences(n: int, min_len: int = 8, max_len: int = 64) -> list[str]:
    """Variable-length sequences; used for the NaN test."""
    sequences: list[str] = []
    for _ in range(n):
        # Draw the body length first, then the residues (same RNG call order
        # as the original, so seeded runs produce identical sequences).
        body_len = random.randint(min_len, max_len)
        body = "".join(random.choices(CANONICAL_AAS, k=body_len))
        sequences.append("M" + body)
    return sequences
| |
|
| |
|
def random_sequences_fixed_len(n: int, length: int = 64) -> list[str]:
    """Fixed-length sequences; used for the match test with E1 (sequence mode)."""
    sequences: list[str] = []
    for _ in range(n):
        # Leading "M" plus (length - 1) random residues gives exactly `length`.
        residues = random.choices(CANONICAL_AAS, k=length - 1)
        sequences.append("M" + "".join(residues))
    return sequences
| |
|
| |
|
def assert_no_nan(embeddings: dict[str, torch.Tensor], label: str) -> None:
    """Raise AssertionError if any embedding tensor contains a NaN value."""
    for sequence, tensor in embeddings.items():
        contains_nan = bool(torch.isnan(tensor).any())
        assert not contains_nan, (
            f"[{label}] NaN found in embedding for sequence '{sequence[:20]}...'"
        )
| |
|
| |
|
def assert_embeddings_match(
    a: dict[str, torch.Tensor],
    b: dict[str, torch.Tensor],
    label: str,
    atol: float = 5e-3,
) -> None:
    """Compare real-token embeddings from two runs.

    full_embeddings=True already strips padding via emb[mask.bool()], so both
    dicts contain only non-pad token rows and the comparison is over those rows.
    """
    assert set(a) == set(b), f"[{label}] Key sets differ between batch and single runs"
    for seq, emb_a in a.items():
        # Compare in float32 regardless of the stored dtype.
        ea = emb_a.float()
        eb = b[seq].float()
        assert ea.shape == eb.shape, (
            f"[{label}] Shape mismatch for '{seq[:20]}': {ea.shape} vs {eb.shape}"
        )
        max_diff = torch.max(torch.abs(ea - eb)).item()
        assert max_diff <= atol, (
            f"[{label}] Max abs diff {max_diff:.5f} > {atol} for '{seq[:20]}'"
        )
| |
|
| |
|
def test_dplm2_multimodal_layout_guard() -> None:
    """Exercise _has_packed_multimodal_layout on three hand-built type-id grids."""
    # All cases use aa_type=1, struct_type=0, pad_type=2.
    type_kwargs = {"aa_type": 1, "struct_type": 0, "pad_type": 2}

    # AA tokens with a short struct/pad tail — should NOT count as packed.
    plain = torch.tensor([
        [1, 1, 1, 1, 1, 1, 0, 2],
        [1, 1, 1, 1, 1, 0, 2, 2],
    ])
    # AA run and struct run of matching length per row — should count as packed.
    packed = torch.tensor([
        [1, 1, 1, 2, 0, 0, 0, 2],
        [1, 1, 2, 2, 0, 0, 2, 2],
    ])
    # AA run (3) and struct run (2) of different lengths — should be rejected.
    mismatched = torch.tensor([
        [1, 1, 1, 2, 0, 0, 2, 2],
    ])

    assert not _has_packed_multimodal_layout(plain, **type_kwargs)
    assert _has_packed_multimodal_layout(packed, **type_kwargs)
    assert not _has_packed_multimodal_layout(mismatched, **type_kwargs)
    print("test_dplm2_multimodal_layout_guard: OK")
| |
|
| |
|
def test_dplm2_special_token_normalization() -> None:
    """Check that ids at/above vocab_size are remapped to their canonical slots.

    With vocab_size=8229: 8231 -> 0, 8229 -> 2, 8232 -> 32; in-vocab ids and
    the -100 ignore index pass through unchanged.
    """
    raw_ids = torch.tensor([[8231, 5, 23, 13, 8229, 1, 8232, -100]])
    got = _normalize_dplm2_input_ids(raw_ids, vocab_size=8229)
    want = torch.tensor([[0, 5, 23, 13, 2, 1, 32, -100]])
    assert torch.equal(got, want), (
        f"DPLM2 special-token normalization mismatch:\n"
        f" got: {got.tolist()}\n"
        f" expected: {want.tolist()}"
    )
    print("test_dplm2_special_token_normalization: OK")
| |
|
| |
|
def test_model(name: str, model_cls, model_path: str, use_model_tokenizer: bool, batch_size: int) -> None:
    """Run the NaN check and (when a tokenizer is available) the batch-vs-single match test.

    The NaN check runs in bfloat16 at the requested batch size. The match test
    re-runs in float32 so batching itself is the only remaining source of
    numerical variation.
    """
    print(f"\n--- {name} ({model_path}) ---")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model_cls.from_pretrained(
        model_path,
        dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    ).eval()

    if use_model_tokenizer:
        # Fixed-shape tokenization so batch=1 and batch=N see identical padding.
        tokenizer = FixedLengthTokenizer(model.tokenizer)
        sequences = random_sequences(n=8)
    else:
        # No external tokenizer (E1) — presumably the model handles raw
        # sequences itself; use fixed-length inputs instead.
        tokenizer = None
        sequences = random_sequences_fixed_len(n=8)

    # --- NaN check in bfloat16 at the requested batch size ---
    nan_embs = model.embed_dataset(
        sequences=sequences,
        batch_size=batch_size,
        tokenizer=tokenizer,
        full_embeddings=True,
        embed_dtype=torch.bfloat16,
        save=False,
    )
    assert_no_nan(nan_embs, f"{name} NaN check batch_size={batch_size}")
    shapes = [tuple(emb.shape) for emb in list(nan_embs.values())[:3]]
    print(f" NaN check batch_size={batch_size}: OK sample shapes={shapes}")

    # Without a controllable tokenizer the padding shapes can differ between
    # batch sizes, so the exact-match comparison is skipped.
    if not use_model_tokenizer:
        return

    # --- Match test in float32: batched vs one-at-a-time embeddings ---
    model.to(torch.float32)
    fp32_kwargs = {
        "tokenizer": tokenizer,
        "full_embeddings": True,
        "embed_dtype": torch.float32,
        "save": False,
    }
    batch_embs = model.embed_dataset(sequences=sequences, batch_size=batch_size, **fp32_kwargs)
    single_embs = model.embed_dataset(sequences=sequences, batch_size=1, **fp32_kwargs)
    assert_no_nan(batch_embs, f"{name} match test batch_size={batch_size}")
    assert_no_nan(single_embs, f"{name} match test batch_size=1")
    assert_embeddings_match(batch_embs, single_embs, name)
    print(f" Match test batch_size={batch_size} vs 1: OK (non-pad tokens only)")
| |
|
| |
|
if __name__ == "__main__":
    import argparse

    # Single source of truth for name -> config; also drives CLI validation.
    configs_by_name = {cfg[0]: cfg for cfg in MODEL_CONFIGS}

    parser = argparse.ArgumentParser(description="Test embed_dataset produces no NaN with batch_size > 1.")
    # `choices` replaces the old assert-based validation: asserts are stripped
    # under `python -O`, while argparse rejects unknown names with a usage error.
    parser.add_argument(
        "--models",
        nargs="+",
        default=list(configs_by_name),
        choices=sorted(configs_by_name),
    )
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE)
    args = parser.parse_args()

    random.seed(SEED)
    # Cheap, model-free unit tests first.
    test_parse_fasta()
    test_dplm2_multimodal_layout_guard()
    test_dplm2_special_token_normalization()

    # Then the (slow) per-model embedding tests.
    for model_name in args.models:
        name, model_cls, model_path, use_model_tokenizer = configs_by_name[model_name]
        test_model(name, model_cls, model_path, use_model_tokenizer, args.batch_size)

    print("\nAll tests passed!")
| |
|