# NOTE(review): `entrypoint_setup` is kept as the very first import, exactly as
# in the original ordering; it is never referenced by name, so it is presumably
# imported for its side effects -- confirm before reordering.
import entrypoint_setup
import torch
import random
from torch.nn.functional import mse_loss
from tqdm import tqdm
from collections import defaultdict
from transformers import AutoModelForMaskedLM

from esm2.modeling_fastesm import FastEsmForMaskedLM
from esm2.load_official import load_official_model as load_official_esm2_model
from esm_plusplus.modeling_esm_plusplus import ESMplusplusForMaskedLM
from esm_plusplus.load_official import load_official_model as load_official_esmc_model
from e1_fastplms.load_official import load_official_model as load_official_e1_model
from e1_fastplms.modeling_e1 import E1ForMaskedLM
from weight_parity_utils import assert_state_dict_equal


class ComplianceChecker:
    """Check that a "fast" model matches its official reference implementation.

    For a given model family ("ESMC", "ESM2" or "E1") the checker loads the
    official model and the fast re-implementation in bfloat16, compares their
    state dicts, and then compares logits, argmax predictions and hidden
    states on batches of randomly generated amino-acid sequences.

    Args:
        test_number_batches: Number of random batches used by the forward check.
        batch_size: Number of sequences per random batch.
        min_sequence_length: Lower bound (inclusive) passed to
            `random.randint` when drawing the number of random residues.
        max_sequence_length: Upper bound (inclusive) passed to
            `random.randint` when drawing the number of random residues.
    """

    def __init__(
        self,
        test_number_batches: int = 25,
        batch_size: int = 8,
        min_sequence_length: int = 16,
        max_sequence_length: int = 128,
    ):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.test_number_batches = test_number_batches
        self.batch_size = batch_size
        self.min_sequence_length = min_sequence_length
        self.max_sequence_length = max_sequence_length
        self.canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"

    def _load_model_pair(
        self,
        load_official,
        official_model_path: str,
        fast_model_class,
        fast_model_path: str,
        from_auto_model: bool,
        force_download: bool,
    ):
        """Load one (official_model, fast_model, tokenizer) triple.

        Shared implementation behind `_load_esmc`, `_load_esm2` and `_load_e1`.
        The official model is loaded first through `load_official`; the fast
        model is loaded with `from_pretrained` from either
        `AutoModelForMaskedLM` (when `from_auto_model` is true) or
        `fast_model_class`, and put in eval mode.
        """
        official_model, tokenizer = load_official(
            reference_repo_id=official_model_path,
            device=self.device,
            dtype=torch.bfloat16,
        )
        load_class = AutoModelForMaskedLM if from_auto_model else fast_model_class
        fast_model = load_class.from_pretrained(
            fast_model_path,
            dtype=torch.bfloat16,
            device_map=self.device,
            force_download=force_download,
            trust_remote_code=True,
        ).eval()
        return official_model, fast_model, tokenizer

    def _load_esmc(self, from_auto_model: bool = False, force_download: bool = False):
        """Load the official "esmc-300" model and "Synthyra/ESMplusplus_small"."""
        return self._load_model_pair(
            load_official=load_official_esmc_model,
            official_model_path="esmc-300",
            fast_model_class=ESMplusplusForMaskedLM,
            fast_model_path="Synthyra/ESMplusplus_small",
            from_auto_model=from_auto_model,
            force_download=force_download,
        )

    def _load_esm2(self, from_auto_model: bool = False, force_download: bool = False):
        """Load the official "facebook/esm2_t6_8M_UR50D" model and "Synthyra/ESM2-8M"."""
        return self._load_model_pair(
            load_official=load_official_esm2_model,
            official_model_path="facebook/esm2_t6_8M_UR50D",
            fast_model_class=FastEsmForMaskedLM,
            fast_model_path="Synthyra/ESM2-8M",
            from_auto_model=from_auto_model,
            force_download=force_download,
        )

    def _load_e1(self, from_auto_model: bool = False, force_download: bool = False):
        """Load the official "Profluent-Bio/E1-150m" model and "Synthyra/Profluent-E1-150M"."""
        return self._load_model_pair(
            load_official=load_official_e1_model,
            official_model_path="Profluent-Bio/E1-150m",
            fast_model_class=E1ForMaskedLM,
            fast_model_path="Synthyra/Profluent-E1-150M",
            from_auto_model=from_auto_model,
            force_download=force_download,
        )

    def _generate_random_sequence(self, length: int) -> str:
        """Return 'M' followed by `length` random canonical residues.

        The returned string therefore has `length + 1` characters.
        """
        return 'M' + "".join(random.choices(self.canonical_amino_acids, k=length))

    def _generate_random_batch(self, batch_size: int, min_length: int, max_length: int) -> list[str]:
        """Return `batch_size` random sequences with lengths drawn from [min_length, max_length]."""
        return [
            self._generate_random_sequence(random.randint(min_length, max_length))
            for _ in range(batch_size)
        ]

    def _weight_compliance(self, official_model, fast_model) -> None:
        """Compare the two state dicts entry by entry, in iteration order.

        Entries are paired positionally. A pair whose names differ is reported
        and skipped; a pair whose names match is compared with MSE, any
        non-zero difference is printed, and the check fails when the MSE is
        not below 1e-3.

        Raises:
            AssertionError: If a same-named pair has an MSE that is not < 1e-3
                (this includes a NaN difference).
        """
        official_items = list(official_model.model.state_dict().items())
        fast_items = list(fast_model.state_dict().items())

        # zip() stops at the shorter input, so trailing entries would otherwise
        # be skipped without any trace.
        if len(official_items) != len(fast_items):
            print(
                f"State dict size mismatch: official has {len(official_items)} entries, "
                f"fast has {len(fast_items)}"
            )

        for (official_name, official_param), (fast_name, fast_param) in zip(official_items, fast_items):
            if official_name != fast_name:
                print(f"Name mismatch: {official_name} != {fast_name}")
                continue
            # Upcast so the comparison also works for integer/bool buffers and
            # is not computed in reduced (bfloat16) precision.
            diff = mse_loss(official_param.float(), fast_param.float()).item()
            if diff > 0.0:
                print(f"{official_name}: {diff}")
            assert diff < 1e-3, f"Parameter {official_name} has a large difference: {diff}"

    @torch.inference_mode()
    def _foward_compliance(
        self,
        model_type: str,
        official_model,
        fast_model,
        tokenizer,
        only_non_pad_tokens: bool = False,
    ) -> None:
        """Compare forward passes of both models on random batches and print a report.

        For each of `self.test_number_batches` random batches, both models are
        run with `output_hidden_states=True`. The logits MSE, the agreement of
        argmax predictions, and the per-layer hidden-state MSE are accumulated
        and averaged. Hidden-state averages are printed only when the average
        logits MSE exceeds 1e-3 or the average prediction agreement is below
        0.95.

        Args:
            model_type: "E1" selects the `get_batch_kwargs` tokenization path;
                "ESMC" additionally passes a boolean `sequence_id` input.
            official_model: Reference model, called with the tokenized inputs.
            fast_model: Candidate model, called with the same inputs.
            tokenizer: Tokenizer returned by the matching loader.
            only_non_pad_tokens: When true, only positions where the attention
                mask is set are compared.

        Raises:
            ValueError: If `self.test_number_batches` is smaller than 1.
        """
        if self.test_number_batches < 1:
            # Without this guard the loop is skipped and the averaging below
            # fails with a ZeroDivisionError.
            raise ValueError(
                f"test_number_batches must be >= 1, got {self.test_number_batches}"
            )

        # Plain Python floats: accumulating tensors would keep the running sums
        # in the model dtype and print a tensor repr in the report.
        cumulative_logits_mse = 0.0
        cumulative_preds_accuracy = 0.0
        hidden_state_diff_dict = defaultdict(float)

        for _ in tqdm(range(self.test_number_batches)):
            batch = self._generate_random_batch(
                self.batch_size, self.min_sequence_length, self.max_sequence_length
            )
            if model_type == "E1":
                tokenized = tokenizer.get_batch_kwargs(batch, device=self.device)
                tokenized = {
                    "input_ids": tokenized["input_ids"],
                    "within_seq_position_ids": tokenized["within_seq_position_ids"],
                    "global_position_ids": tokenized["global_position_ids"],
                    "sequence_ids": tokenized["sequence_ids"],
                    # Positions whose sequence id is -1 are treated as masked out.
                    "attention_mask": (tokenized["sequence_ids"] != -1).long(),
                }
            else:
                tokenized = tokenizer(batch, return_tensors="pt", padding=True)
                tokenized = {k: v.to(self.device) for k, v in tokenized.items()}

            # CPU copy of the mask: the logits are moved to CPU before masking.
            attention_mask = tokenized['attention_mask'].cpu().bool()

            model_inputs = tokenized.copy()
            if model_type == "ESMC":
                model_inputs["sequence_id"] = model_inputs["attention_mask"].to(dtype=torch.bool)

            official_output = official_model(**model_inputs, output_hidden_states=True)
            official_hidden_states = official_output.hidden_states
            official_logits = official_output.logits.cpu().float()
            if only_non_pad_tokens:
                official_logits = official_logits[attention_mask]
            official_preds = official_logits.argmax(dim=-1)

            fast_output = fast_model(**model_inputs, output_hidden_states=True)
            fast_hidden_states = fast_output.hidden_states
            fast_logits = fast_output.logits.cpu().float()
            if only_non_pad_tokens:
                fast_logits = fast_logits[attention_mask]
            fast_preds = fast_logits.argmax(dim=-1)

            cumulative_logits_mse += mse_loss(official_logits, fast_logits).item()
            cumulative_preds_accuracy += (official_preds == fast_preds).float().mean().item()

            for i, official_state in enumerate(official_hidden_states):
                # Indexing (not zip) so a shorter fast_hidden_states still fails loudly.
                fast_state = fast_hidden_states[i]
                if only_non_pad_tokens:
                    # Hidden states are not moved to CPU, so mask on their own device.
                    official_state = official_state[attention_mask.to(official_state.device)]
                    fast_state = fast_state[attention_mask.to(fast_state.device)]
                hidden_state_diff_dict[i] += mse_loss(official_state.float(), fast_state.float()).item()

        avg_logits_mse = cumulative_logits_mse / self.test_number_batches
        avg_preds_accuracy = cumulative_preds_accuracy / self.test_number_batches
        print(f"Average logits MSE: {avg_logits_mse}")
        print(f"Average preds accuracy: {avg_preds_accuracy}")
        if avg_logits_mse > 1e-3 or avg_preds_accuracy < 0.95:
            print("Differences were too large, printing hidden state differences for debugging...")
            for k, v in hidden_state_diff_dict.items():
                print(f"Hidden state {k} Avg MSE: {v / self.test_number_batches}")

    # Correctly spelled alias; the original (misspelled) name stays the
    # canonical one so existing callers and overrides keep working.
    _forward_compliance = _foward_compliance

    def __call__(
        self,
        model_type: str = "ESMC",
        force_download: bool = False,
        from_auto_model: bool = False,
        only_non_pad_tokens: bool = False,
    ) -> None:
        """Run the full compliance check for one model family.

        Loads both models, asserts state-dict parity, then runs the weight and
        forward comparisons.

        Args:
            model_type: One of "ESMC", "ESM2" or "E1".
            force_download: Forwarded to `from_pretrained` for the fast model.
            from_auto_model: Load the fast model through `AutoModelForMaskedLM`
                instead of its dedicated class.
            only_non_pad_tokens: Restrict the forward comparison to positions
                where the attention mask is set.

        Raises:
            ValueError: If `model_type` is not supported.
        """
        loaders = {
            "ESMC": self._load_esmc,
            "ESM2": self._load_esm2,
            "E1": self._load_e1,
        }
        if model_type not in loaders:
            raise ValueError(f"Unsupported model type: {model_type}. Supported: ESMC, ESM2, E1")
        official_model, fast_model, tokenizer = loaders[model_type](from_auto_model, force_download)

        assert_state_dict_equal(
            reference_state_dict=official_model.model.state_dict(),
            candidate_state_dict=fast_model.state_dict(),
            context=f"{model_type} weight parity",
        )
        self._weight_compliance(official_model, fast_model)
        self._foward_compliance(model_type, official_model, fast_model, tokenizer, only_non_pad_tokens)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--hf_token", type=str, default=None)
    parser.add_argument("--only_non_pad_tokens", action="store_true")
    parser.add_argument("--force_download", action="store_true")
    parser.add_argument("--from_auto_model", action="store_true")
    parser.add_argument("--model_types", nargs="+", default=["ESMC", "ESM2", "E1"])
    args = parser.parse_args()

    if args.hf_token is not None:
        # Imported lazily: only needed when a token is supplied.
        from huggingface_hub import login
        login(token=args.hf_token)

    checker = ComplianceChecker()
    for model_type in args.model_types:
        print(f"Checking {model_type}...")
        checker(
            model_type=model_type,
            from_auto_model=args.from_auto_model,
            only_non_pad_tokens=args.only_non_pad_tokens,
            force_download=args.force_download,
        )