| | import entrypoint_setup |
| |
|
| | import torch |
| | import random |
| | from torch.nn.functional import mse_loss |
| | from tqdm import tqdm |
| | from collections import defaultdict |
| | from transformers import AutoModelForMaskedLM |
| |
|
| | from esm2.modeling_fastesm import FastEsmForMaskedLM |
| | from esm2.load_official import load_official_model as load_official_esm2_model |
| |
|
| | from esm_plusplus.modeling_esm_plusplus import ESMplusplusForMaskedLM |
| | from esm_plusplus.load_official import load_official_model as load_official_esmc_model |
| |
|
| | from e1_fastplms.load_official import load_official_model as load_official_e1_model |
| | from e1_fastplms.modeling_e1 import E1ForMaskedLM |
| |
|
| | from weight_parity_utils import assert_state_dict_equal |
| |
|
| |
|
class ComplianceChecker:
    """Check weight and forward-pass parity between official protein language
    models and their fast re-implementations.

    For each supported model family (ESMC, ESM2, E1) the checker loads the
    official reference model and the Synthyra re-implementation, verifies the
    state dicts match, then compares logits, argmax predictions, and per-layer
    hidden states on batches of random protein sequences.
    """

    def __init__(
        self,
        test_number_batches: int = 25,
        batch_size: int = 8,
        min_sequence_length: int = 16,
        max_sequence_length: int = 128,
    ):
        """
        Args:
            test_number_batches: Number of random batches used for the forward
                compliance check.
            batch_size: Sequences per batch.
            min_sequence_length: Minimum random sequence length (excluding the
                leading 'M').
            max_sequence_length: Maximum random sequence length (excluding the
                leading 'M').
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.test_number_batches = test_number_batches
        self.batch_size = batch_size
        self.min_sequence_length = min_sequence_length
        self.max_sequence_length = max_sequence_length
        # The 20 canonical amino acids used to build random test sequences.
        self.canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"

    def _load_pair(
        self,
        official_loader,
        official_model_path: str,
        fast_model_path: str,
        fast_class,
        from_auto_model: bool,
        force_download: bool,
    ):
        """Load one (official_model, fast_model, tokenizer) triple.

        Shared implementation behind ``_load_esmc`` / ``_load_esm2`` /
        ``_load_e1``: the official model comes from its reference loader, the
        fast model from the Hub via either ``AutoModelForMaskedLM`` (with
        remote code) or the concrete in-repo class.
        """
        official_model, tokenizer = official_loader(
            reference_repo_id=official_model_path,
            device=self.device,
            dtype=torch.bfloat16,
        )
        load_class = AutoModelForMaskedLM if from_auto_model else fast_class
        fast_model = load_class.from_pretrained(
            fast_model_path,
            dtype=torch.bfloat16,
            device_map=self.device,
            force_download=force_download,
            trust_remote_code=True,
        ).eval()
        return official_model, fast_model, tokenizer

    def _load_esmc(self, from_auto_model: bool = False, force_download: bool = False):
        """Load the official ESMC model and the fast ESM++ re-implementation."""
        return self._load_pair(
            load_official_esmc_model,
            "esmc-300",
            "Synthyra/ESMplusplus_small",
            ESMplusplusForMaskedLM,
            from_auto_model,
            force_download,
        )

    def _load_esm2(self, from_auto_model: bool = False, force_download: bool = False):
        """Load the official ESM2-8M model and the fast re-implementation."""
        return self._load_pair(
            load_official_esm2_model,
            "facebook/esm2_t6_8M_UR50D",
            "Synthyra/ESM2-8M",
            FastEsmForMaskedLM,
            from_auto_model,
            force_download,
        )

    def _load_e1(self, from_auto_model: bool = False, force_download: bool = False):
        """Load the official Profluent E1-150m model and the fast re-implementation."""
        return self._load_pair(
            load_official_e1_model,
            "Profluent-Bio/E1-150m",
            "Synthyra/Profluent-E1-150M",
            E1ForMaskedLM,
            from_auto_model,
            force_download,
        )

    def _generate_random_sequence(self, length: int) -> str:
        """Return a random protein sequence: a leading 'M' plus ``length`` residues."""
        return 'M' + "".join(random.choices(self.canonical_amino_acids, k=length))

    def _generate_random_batch(self, batch_size: int, min_length: int, max_length: int) -> list[str]:
        """Return ``batch_size`` random sequences with lengths in [min_length, max_length]."""
        return [
            self._generate_random_sequence(random.randint(min_length, max_length))
            for _ in range(batch_size)
        ]

    def _weight_compliance(self, official_model, fast_model):
        """Compare parameters pairwise (by state-dict order) and assert near-equality.

        Relies on both state dicts listing parameters in the same order; pairs
        whose names differ are reported and skipped rather than compared.

        Raises:
            AssertionError: If any same-named parameter pair differs by
                MSE >= 1e-3.
        """
        official_items = official_model.model.state_dict().items()
        fast_items = fast_model.state_dict().items()
        for (official_name, official_param), (fast_name, fast_param) in zip(official_items, fast_items):
            if official_name != fast_name:
                print(f"Name mismatch: {official_name} != {fast_name}")
                continue
            # Compare in float32: computing MSE directly in bfloat16 loses
            # precision relative to the 1e-3 threshold.
            diff = mse_loss(official_param.float(), fast_param.float()).item()
            if diff > 0.0:
                print(f"{official_name}: {diff}")
            assert diff < 1e-3, f"Parameter {official_name} has a large difference: {diff}"

    @torch.inference_mode()
    def _forward_compliance(self, model_type: str, official_model, fast_model, tokenizer, only_non_pad_tokens: bool = False):
        """Compare forward passes of the two models on random batches.

        Accumulates average logits MSE, argmax-prediction agreement, and
        per-layer hidden-state MSE over ``self.test_number_batches`` batches;
        prints per-layer diagnostics when the aggregate differences are large.

        Args:
            model_type: One of "ESMC", "ESM2", "E1" (controls tokenization and
                extra model inputs).
            only_non_pad_tokens: If True, restrict all comparisons to
                non-padding positions.
        """
        cumulative_logits_mse = 0.0
        cumulative_preds_accuracy = 0.0
        hidden_state_diff_dict = defaultdict(float)

        for _ in tqdm(range(self.test_number_batches)):
            batch = self._generate_random_batch(self.batch_size, self.min_sequence_length, self.max_sequence_length)
            if model_type == "E1":
                # E1 uses a custom batching API with explicit position /
                # sequence ids; -1 in sequence_ids marks padding.
                tokenized = tokenizer.get_batch_kwargs(batch, device=self.device)
                tokenized = {
                    "input_ids": tokenized["input_ids"],
                    "within_seq_position_ids": tokenized["within_seq_position_ids"],
                    "global_position_ids": tokenized["global_position_ids"],
                    "sequence_ids": tokenized["sequence_ids"],
                    "attention_mask": (tokenized["sequence_ids"] != -1).long(),
                }
            else:
                tokenized = tokenizer(batch, return_tensors="pt", padding=True)
                tokenized = {k: v.to(self.device) for k, v in tokenized.items()}

            attention_mask = tokenized['attention_mask'].cpu().bool()
            model_inputs = tokenized.copy()
            if model_type == "ESMC":
                # NOTE(review): the official ESMC forward appears to take the
                # padding mask as a boolean `sequence_id` argument — confirm
                # against the official model signature.
                model_inputs["sequence_id"] = model_inputs["attention_mask"].to(dtype=torch.bool)

            official_output = official_model(**model_inputs, output_hidden_states=True)
            official_hidden_states = official_output.hidden_states
            official_logits = official_output.logits.cpu()
            if only_non_pad_tokens:
                official_logits = official_logits[attention_mask]
            official_preds = official_logits.argmax(dim=-1)

            fast_output = fast_model(**model_inputs, output_hidden_states=True)
            fast_hidden_states = fast_output.hidden_states
            fast_logits = fast_output.logits.cpu()
            if only_non_pad_tokens:
                fast_logits = fast_logits[attention_mask]
            fast_preds = fast_logits.argmax(dim=-1)

            # Accumulate plain Python floats (not 0-dim tensors) so the final
            # averages print as numbers; MSE is computed in float32 for
            # precision against the 1e-3 threshold below.
            cumulative_logits_mse += mse_loss(official_logits.float(), fast_logits.float()).item()
            cumulative_preds_accuracy += (official_preds == fast_preds).float().mean().item()

            for layer_idx, (official_state, fast_state) in enumerate(zip(official_hidden_states, fast_hidden_states)):
                if only_non_pad_tokens:
                    official_state = official_state[attention_mask]
                    fast_state = fast_state[attention_mask]
                hidden_state_diff_dict[layer_idx] += mse_loss(official_state.float(), fast_state.float()).item()

        avg_logits_mse = cumulative_logits_mse / self.test_number_batches
        avg_preds_accuracy = cumulative_preds_accuracy / self.test_number_batches
        print(f"Average logits MSE: {avg_logits_mse}")
        print(f"Average preds accuracy: {avg_preds_accuracy}")

        if avg_logits_mse > 1e-3 or avg_preds_accuracy < 0.95:
            print("Differences were too large, printing hidden state differences for debugging...")
            for k, v in hidden_state_diff_dict.items():
                print(f"Hidden state {k} Avg MSE: {v / self.test_number_batches}")

    # Backward-compatible alias for the original (misspelled) method name.
    _foward_compliance = _forward_compliance

    def __call__(
        self,
        model_type: str = "ESMC",
        force_download: bool = False,
        from_auto_model: bool = False,
        only_non_pad_tokens: bool = False,
    ):
        """Run the full parity check (weights, then forward pass) for one family.

        Raises:
            ValueError: If ``model_type`` is not one of ESMC, ESM2, E1.
            AssertionError: If weight parity fails.
        """
        loaders = {
            "ESMC": self._load_esmc,
            "ESM2": self._load_esm2,
            "E1": self._load_e1,
        }
        if model_type not in loaders:
            raise ValueError(f"Unsupported model type: {model_type}. Supported: ESMC, ESM2, E1")
        official_model, fast_model, tokenizer = loaders[model_type](from_auto_model, force_download)
        assert_state_dict_equal(
            reference_state_dict=official_model.model.state_dict(),
            candidate_state_dict=fast_model.state_dict(),
            context=f"{model_type} weight parity",
        )
        self._weight_compliance(official_model, fast_model)
        self._forward_compliance(model_type, official_model, fast_model, tokenizer, only_non_pad_tokens)
| |
|
| |
|
| | if __name__ == "__main__": |
| | import argparse |
| | parser = argparse.ArgumentParser() |
| | parser.add_argument("--hf_token", type=str, default=None) |
| | parser.add_argument("--only_non_pad_tokens", action="store_true") |
| | parser.add_argument("--force_download", action="store_true") |
| | parser.add_argument("--from_auto_model", action="store_true") |
| | parser.add_argument("--model_types", nargs="+", default=["ESMC", "ESM2", "E1"]) |
| | args = parser.parse_args() |
| |
|
| | if args.hf_token is not None: |
| | from huggingface_hub import login |
| | login(token=args.hf_token) |
| |
|
| | checker = ComplianceChecker() |
| | for model_type in args.model_types: |
| | print(f"Checking {model_type}...") |
| | checker( |
| | model_type=model_type, |
| | from_auto_model=args.from_auto_model, |
| | only_non_pad_tokens=args.only_non_pad_tokens, |
| | force_download=args.force_download |
| | ) |
| |
|