File size: 2,963 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Load official ESMC model from the esm package for comparison."""
import torch
import torch.nn as nn


class _ESMCComplianceOutput:
    """Mimics HuggingFace model output so the test suite can access .logits and .hidden_states."""
    def __init__(self, logits: torch.Tensor, last_hidden_state: torch.Tensor, hidden_states: tuple):
        self.logits = logits
        self.last_hidden_state = last_hidden_state
        self.hidden_states = hidden_states


class _OfficialESMCForwardWrapper(nn.Module):
    """Adapter around the official ESMC model that reshapes its outputs into
    the HuggingFace-style object (`.logits` / `.hidden_states`) our tests expect."""

    def __init__(self, model: nn.Module, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        sequence_id: torch.Tensor | None = None,
        **kwargs,
    ):
        # The official model returns sequence_logits, embeddings, and hidden
        # states stacked into one tensor of shape [n_layers, B, L, D].
        out = self.model(sequence_tokens=input_ids)
        stacked = out.hidden_states
        if stacked is None:
            # No per-layer states available; expose only the final embeddings.
            layers: tuple = ()
        else:
            # Split the layer axis into separate tensors so hidden_states
            # behaves like a HuggingFace tuple (supports hidden_states[-1]).
            layers = torch.unbind(stacked, dim=0)
        return _ESMCComplianceOutput(
            logits=out.sequence_logits,
            last_hidden_state=out.embeddings,
            hidden_states=tuple(layers) + (out.embeddings,),
        )


def load_official_model(
    reference_repo_id: str,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> tuple[nn.Module, object]:
    """Load the official ESMC model from the esm submodule.

    Args:
        reference_repo_id: e.g. "EvolutionaryScale/esmc-300m-2024-12"
        device: target device
        dtype: target dtype (should be float32 for comparison)

    Returns (wrapped_model, tokenizer).
    """
    from esm.pretrained import ESMC_300M_202412, ESMC_600M_202412

    # Select the constructor from the size marker embedded in the repo id.
    if "300" in reference_repo_id:
        build = ESMC_300M_202412
    elif "600" in reference_repo_id:
        build = ESMC_600M_202412
    else:
        raise ValueError(f"Unsupported ESMC reference repo id: {reference_repo_id}")

    # Flash attention disabled so results are deterministic / comparable.
    base_model = build(use_flash_attn=False).to(device=device, dtype=dtype).eval()
    tok = base_model.tokenizer
    wrapped_model = _OfficialESMCForwardWrapper(base_model, tok)
    wrapped_model = wrapped_model.to(device=device, dtype=dtype).eval()
    return wrapped_model, tok


if __name__ == "__main__":
    model, tokenizer = load_official_model(reference_repo_id="EvolutionaryScale/esmc-300m-2024-12", device=torch.device("cuda"), dtype=torch.float32)
    print(model)
    print(tokenizer)