| | """Load official ESMC model from the esm package for comparison.""" |
from dataclasses import dataclass

import torch
import torch.nn as nn
| |
|
| |
|
| | class _ESMCComplianceOutput: |
| | """Mimics HuggingFace model output so the test suite can access .logits and .hidden_states.""" |
| | def __init__(self, logits: torch.Tensor, last_hidden_state: torch.Tensor, hidden_states: tuple): |
| | self.logits = logits |
| | self.last_hidden_state = last_hidden_state |
| | self.hidden_states = hidden_states |
| |
|
| |
|
class _OfficialESMCForwardWrapper(nn.Module):
    """Adapter around the official ESMC model.

    Presents a HuggingFace-style forward so our test suite can read
    ``.logits``, ``.last_hidden_state`` and ``.hidden_states`` off the output.
    """

    def __init__(self, model: nn.Module, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        sequence_id: torch.Tensor | None = None,
        **kwargs,
    ):
        # NOTE(review): only the token ids are forwarded; attention_mask and
        # sequence_id are accepted for signature compatibility but ignored —
        # confirm this matches how the comparison tests batch their inputs.
        out = self.model(sequence_tokens=input_ids)

        final_embeddings = out.embeddings
        per_layer = out.hidden_states
        if per_layer is None:
            all_hiddens = (final_embeddings,)
        else:
            # ``per_layer`` is indexed along dim 0 (one entry per layer);
            # unpack it and append the final embeddings, mirroring the HF
            # convention that the last entry is the final hidden state.
            all_hiddens = (*per_layer, final_embeddings)

        return _ESMCComplianceOutput(
            logits=out.sequence_logits,
            last_hidden_state=final_embeddings,
            hidden_states=all_hiddens,
        )
| |
|
| |
|
def load_official_model(
    reference_repo_id: str,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> tuple[nn.Module, object]:
    """Load the official ESMC model from the esm submodule.

    Args:
        reference_repo_id: e.g. "EvolutionaryScale/esmc-300m-2024-12"
        device: target device
        dtype: target dtype (should be float32 for comparison)

    Returns (wrapped_model, tokenizer).

    Raises:
        ValueError: if the repo id does not name a known ESMC size.
    """
    # Imported lazily so merely importing this module does not require esm.
    from esm.pretrained import ESMC_300M_202412, ESMC_600M_202412

    # Dispatch on the size marker embedded in the repo id.
    if "300" in reference_repo_id:
        builder = ESMC_300M_202412
    elif "600" in reference_repo_id:
        builder = ESMC_600M_202412
    else:
        raise ValueError(f"Unsupported ESMC reference repo id: {reference_repo_id}")

    # Flash attention is disabled so the numerics are comparable on any device.
    base = builder(use_flash_attn=False).to(device=device, dtype=dtype).eval()
    tokenizer = base.tokenizer
    wrapper = _OfficialESMCForwardWrapper(base, tokenizer)
    wrapper = wrapper.to(device=device, dtype=dtype).eval()
    return wrapper, tokenizer
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: load the 300M reference model on GPU and print what we got.
    _model, _tokenizer = load_official_model(
        reference_repo_id="EvolutionaryScale/esmc-300m-2024-12",
        device=torch.device("cuda"),
        dtype=torch.float32,
    )
    print(_model)
    print(_tokenizer)