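"""Compliance checks for Synthyra reimplementations of protein language models.

For each supported model family (ESMC, ESM2, E1), this script loads the
official reference model and the Synthyra "fast" version, asserts weight
parity, and compares logits, argmax predictions, and per-layer hidden states
on batches of random protein sequences. See the argparse flags at the bottom
for the available command-line options.
"""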
import entrypoint_setup  # noqa: F401 -- side-effect import (presumably configures sys.path)

import torch
import random
from torch.nn.functional import mse_loss
from tqdm import tqdm
from collections import defaultdict
from transformers import AutoModelForMaskedLM

from esm2.modeling_fastesm import FastEsmForMaskedLM
from esm2.load_official import load_official_model as load_official_esm2_model

from esm_plusplus.modeling_esm_plusplus import ESMplusplusForMaskedLM
from esm_plusplus.load_official import load_official_model as load_official_esmc_model

from e1_fastplms.load_official import load_official_model as load_official_e1_model
from e1_fastplms.modeling_e1 import E1ForMaskedLM

from weight_parity_utils import assert_state_dict_equal


class ComplianceChecker:
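    """Checks that a reimplemented ("fast") model matches its official
    counterpart in both weights and forward-pass outputs, using batches of
    randomly generated protein sequences."""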
    def __init__(
        self,
        test_number_batches: int = 25,
        batch_size: int = 8,
        min_sequence_length: int = 16,
        max_sequence_length: int = 128,
    ):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.test_number_batches = test_number_batches
        self.batch_size = batch_size
        self.min_sequence_length = min_sequence_length
        self.max_sequence_length = max_sequence_length
        self.canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"

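    # Each _load_* helper returns (official_model, fast_model, tokenizer).
    # The official model comes from the upstream loader; the fast model is
    # pulled from the Hugging Face Hub, either through its concrete class or
    # through AutoModelForMaskedLM when `from_auto_model` is set.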
    def _load_esmc(self, from_auto_model: bool = False, force_download: bool = False):
        official_model_path = "esmc-300"
        fast_model_path = "Synthyra/ESMplusplus_small"
        official_model, tokenizer = load_official_esmc_model(
            reference_repo_id=official_model_path,
            device=self.device,
            dtype=torch.bfloat16,
        )
        load_class = AutoModelForMaskedLM if from_auto_model else ESMplusplusForMaskedLM
        fast_model = load_class.from_pretrained(
            fast_model_path,
            dtype=torch.bfloat16,
            device_map=self.device,
            force_download=force_download,
            trust_remote_code=True,
        ).eval()
        return official_model, fast_model, tokenizer

    def _load_esm2(self, from_auto_model: bool = False, force_download: bool = False):
        official_model_path = "facebook/esm2_t6_8M_UR50D"
        fast_model_path = "Synthyra/ESM2-8M"
        official_model, tokenizer = load_official_esm2_model(
            reference_repo_id=official_model_path,
            device=self.device,
            dtype=torch.bfloat16,
        )
        load_class = AutoModelForMaskedLM if from_auto_model else FastEsmForMaskedLM
        fast_model = load_class.from_pretrained(
            fast_model_path,
            dtype=torch.bfloat16,
            device_map=self.device,
            force_download=force_download,
            trust_remote_code=True,
        ).eval()
        return official_model, fast_model, tokenizer

    def _load_e1(self, from_auto_model: bool = False, force_download: bool = False):
        official_model_path = "Profluent-Bio/E1-150m"
        fast_model_path = "Synthyra/Profluent-E1-150M"
        official_model, tokenizer = load_official_e1_model(
            reference_repo_id=official_model_path,
            device=self.device,
            dtype=torch.bfloat16,
        )
        load_class = AutoModelForMaskedLM if from_auto_model else E1ForMaskedLM
        fast_model = load_class.from_pretrained(
            fast_model_path,
            dtype=torch.bfloat16,
            device_map=self.device,
            force_download=force_download,
            trust_remote_code=True,
        ).eval()
        return official_model, fast_model, tokenizer

    def _generate_random_sequence(self, length: int) -> str:
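        """Return 'M' followed by `length` random canonical amino acids."""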
        return 'M' + "".join(random.choices(self.canonical_amino_acids, k=length))
    
    def _generate_random_batch(self, batch_size: int, min_length: int, max_length: int) -> list[str]:
        return [self._generate_random_sequence(random.randint(min_length, max_length)) for _ in range(batch_size)]

    def _weight_compliance(self, official_model, fast_model):
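        """Pairwise-compare the two state dicts (assumes identical key order),
        reporting name mismatches and asserting per-parameter MSE < 1e-3."""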
        for (official_name, official_param), (fast_name, fast_param) in zip(official_model.model.state_dict().items(), fast_model.state_dict().items()):
            if official_name == fast_name:
                diff = mse_loss(official_param, fast_param).item()
                if diff > 0.0:
                    print(f"{official_name}: {diff}")
                    assert diff < 1e-3, f"Parameter {official_name} has a large difference: {diff}"
            else:
                print(f"Name mismatch: {official_name} != {fast_name}")

    @torch.inference_mode()
    def _forward_compliance(self, model_type: str, official_model, fast_model, tokenizer, only_non_pad_tokens: bool = False):
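        """Run both models on random batches and compare logits, argmax
        predictions, and per-layer hidden states."""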
        cumulative_logits_mse = 0.0
        cumulative_preds_accuracy = 0.0
        hidden_state_diff_dict = defaultdict(int)

        for _ in tqdm(range(self.test_number_batches)):
            batch = self._generate_random_batch(self.batch_size, self.min_sequence_length, self.max_sequence_length)
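            # E1 uses a packed-sequence tokenizer with explicit position and
            # sequence ids; the attention mask is derived from sequence_ids,
            # where -1 marks padding.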
            if model_type == "E1":
                tokenized = tokenizer.get_batch_kwargs(batch, device=self.device)
                tokenized = {
                    "input_ids": tokenized["input_ids"],
                    "within_seq_position_ids": tokenized["within_seq_position_ids"],
                    "global_position_ids": tokenized["global_position_ids"],
                    "sequence_ids": tokenized["sequence_ids"],
                    "attention_mask": (tokenized["sequence_ids"] != -1).long(),
                }
            else:
                tokenized = tokenizer(batch, return_tensors="pt", padding=True)
                tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            
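            # Boolean mask of real (non-pad) tokens, used when the comparison
            # should ignore padding positions.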
            attention_mask = tokenized['attention_mask'].cpu().bool()
            model_inputs = tokenized.copy()
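            # The official ESMC forward also expects a boolean `sequence_id`
            # padding mask.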
            if model_type == "ESMC":
                model_inputs["sequence_id"] = model_inputs["attention_mask"].to(dtype=torch.bool)

            official_output = official_model(**model_inputs, output_hidden_states=True)
            official_hidden_states = official_output.hidden_states
            official_logits = official_output.logits.cpu()
            if only_non_pad_tokens:
                official_logits = official_logits[attention_mask]
            official_preds = official_logits.argmax(dim=-1)
            
            fast_output = fast_model(**model_inputs, output_hidden_states=True)
            fast_hidden_states = fast_output.hidden_states
            fast_logits = fast_output.logits.cpu()
            if only_non_pad_tokens:
                fast_logits = fast_logits[attention_mask]
            fast_preds = fast_logits.argmax(dim=-1)

            cumulative_logits_mse += mse_loss(official_logits, fast_logits).item()
            cumulative_preds_accuracy += (official_preds == fast_preds).float().mean().item()

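            # Accumulate per-layer hidden-state MSE so any divergence can be
            # traced to the first mismatching layer.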
            for i in range(len(official_hidden_states)):
                official_state, fast_state = official_hidden_states[i], fast_hidden_states[i]
                if only_non_pad_tokens:
                    official_state, fast_state = official_state[attention_mask], fast_state[attention_mask]
                hidden_state_diff_dict[i] += mse_loss(official_state, fast_state).item()

        avg_logits_mse = cumulative_logits_mse / self.test_number_batches
        avg_preds_accuracy = cumulative_preds_accuracy / self.test_number_batches
        print(f"Average logits MSE: {avg_logits_mse}")
        print(f"Average preds accuracy: {avg_preds_accuracy}")

        if avg_logits_mse > 1e-3 or avg_preds_accuracy < 0.95:
            print("Differences were too large, printing hidden state differences for debugging...")
            for k, v in hidden_state_diff_dict.items():
                print(f"Hidden state {k} Avg MSE: {v / self.test_number_batches}")

    def __call__(
        self,
        model_type: str = "ESMC",
        force_download: bool = False,
        from_auto_model: bool = False,
        only_non_pad_tokens: bool = False,
    ):
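        """Load the requested official/fast model pair, assert exact weight
        parity, then run the tolerance-based weight and forward checks."""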
        if model_type == "ESMC":
            official_model, fast_model, tokenizer = self._load_esmc(from_auto_model, force_download)
        elif model_type == "ESM2":
            official_model, fast_model, tokenizer = self._load_esm2(from_auto_model, force_download)
        elif model_type == "E1":
            official_model, fast_model, tokenizer = self._load_e1(from_auto_model, force_download)
        else:
            raise ValueError(f"Unsupported model type: {model_type}. Supported: ESMC, ESM2, E1")
        assert_state_dict_equal(
            reference_state_dict=official_model.model.state_dict(),
            candidate_state_dict=fast_model.state_dict(),
            context=f"{model_type} weight parity",
        )
        self._weight_compliance(official_model, fast_model)
        self._forward_compliance(model_type, official_model, fast_model, tokenizer, only_non_pad_tokens)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--hf_token", type=str, default=None)
    parser.add_argument("--only_non_pad_tokens", action="store_true")
    parser.add_argument("--force_download", action="store_true")
    parser.add_argument("--from_auto_model", action="store_true")
    parser.add_argument("--model_types", nargs="+", default=["ESMC", "ESM2", "E1"])
    args = parser.parse_args()

    if args.hf_token is not None:
        from huggingface_hub import login
        login(token=args.hf_token)

    checker = ComplianceChecker()
    for model_type in args.model_types:
        print(f"Checking {model_type}...")
        checker(
            model_type=model_type,
            from_auto_model=args.from_auto_model,
            only_non_pad_tokens=args.only_non_pad_tokens,
            force_download=args.force_download
        )