"""Load official ESMC model from the esm package for comparison."""
import torch
import torch.nn as nn
class _ESMCComplianceOutput:
"""Mimics HuggingFace model output so the test suite can access .logits and .hidden_states."""
def __init__(self, logits: torch.Tensor, last_hidden_state: torch.Tensor, hidden_states: tuple):
self.logits = logits
self.last_hidden_state = last_hidden_state
self.hidden_states = hidden_states
class _OfficialESMCForwardWrapper(nn.Module):
"""Wraps official ESMC model to produce outputs compatible with our test suite."""
def __init__(self, model: nn.Module, tokenizer):
super().__init__()
self.model = model
self.tokenizer = tokenizer
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
sequence_id: torch.Tensor | None = None,
**kwargs,
    ):
        # attention_mask / sequence_id / extra kwargs are accepted only so the
        # wrapper matches the HF-style call signature used by the test suite;
        # the official ESMC forward only needs sequence_tokens.
        esmc_output = self.model(sequence_tokens=input_ids)
# ESMC returns: sequence_logits, embeddings, hidden_states (stacked [n_layers, B, L, D])
logits = esmc_output.sequence_logits
embeddings = esmc_output.embeddings
raw_hiddens = esmc_output.hidden_states
        # Unpack the stacked hidden-state tensor into a tuple and append the final
        # embeddings so hidden_states[-1] follows the HF convention of ending with
        # the last hidden state.
if raw_hiddens is not None:
hidden_states = tuple(raw_hiddens[i] for i in range(raw_hiddens.shape[0]))
hidden_states = hidden_states + (embeddings,)
else:
hidden_states = (embeddings,)
return _ESMCComplianceOutput(
logits=logits,
last_hidden_state=embeddings,
hidden_states=hidden_states,
)
def load_official_model(
reference_repo_id: str,
device: torch.device,
dtype: torch.dtype = torch.float32,
) -> tuple[nn.Module, object]:
"""Load the official ESMC model from the esm submodule.

    Args:
        reference_repo_id: e.g. "EvolutionaryScale/esmc-300m-2024-12"
        device: target device
        dtype: target dtype (keep float32 for numerical comparison)

    Returns:
        Tuple of (wrapped_model, tokenizer).
    """
    from esm.pretrained import ESMC_300M_202412, ESMC_600M_202412

    # Select the checkpoint constructor from the size encoded in the repo id.
    if "300" in reference_repo_id:
official_model = ESMC_300M_202412(use_flash_attn=False)
elif "600" in reference_repo_id:
official_model = ESMC_600M_202412(use_flash_attn=False)
else:
raise ValueError(f"Unsupported ESMC reference repo id: {reference_repo_id}")
official_model = official_model.to(device=device, dtype=dtype).eval()
tokenizer = official_model.tokenizer
wrapped = _OfficialESMCForwardWrapper(official_model, tokenizer).to(device=device, dtype=dtype).eval()
return wrapped, tokenizer
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tokenizer = load_official_model(
        reference_repo_id="EvolutionaryScale/esmc-300m-2024-12",
        device=device,
        dtype=torch.float32,
    )
    print(model)
    print(tokenizer)
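    # Optional smoke test of the wrapper. This is a minimal sketch, assuming the
    # ESMC tokenizer follows the Hugging Face callable convention (returning a
    # dict with "input_ids"); the sequence below is an arbitrary example and not
    # part of the test suite.
    sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    with torch.no_grad():
        out = model(input_ids=input_ids)
    print("logits shape:", tuple(out.logits.shape))
    print("last_hidden_state shape:", tuple(out.last_hidden_state.shape))
    print("num hidden states:", len(out.hidden_states))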