File size: 4,614 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import os
import torch
from safetensors.torch import load_file
from rich.console import Console
from rich.table import Table

from transformers import AutoModelForMaskedLM, AutoConfig, AutoModel

from e1_fastplms.modeling_e1 import E1ForMaskedLM, E1Config, E1Model


def load_weights(path, cast_fp32=True):
    """Load a flat state dict from a ``.safetensors``, ``.pt`` or ``.pth`` file.

    Files with any other extension are tried as safetensors first and then as a
    ``torch.load`` pickle. For ``torch.load`` results, a wrapping checkpoint dict
    is unwrapped via its ``"state_dict"`` key (preferred) or its ``"model"`` key.
    This happens whenever that entry is itself a dict.

    Args:
        path: Path to the checkpoint file.
        cast_fp32: If True, every tensor value is converted with ``.float()`` so
            checkpoints stored in different dtypes can be compared numerically.
            Non-tensor values are passed through unchanged.

    Returns:
        The state dict (parameter name -> value).

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        TypeError: If the loaded object is not a dict.
    """
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if not os.path.exists(path):
        raise FileNotFoundError(f"File {path} not found.")

    if path.endswith(".safetensors"):
        sd = load_file(path)
    elif path.endswith((".pth", ".pt")):
        sd = torch.load(path, map_location="cpu", weights_only=True)
    else:
        # Unknown extension: try safetensors first, then a torch pickle.
        # Deliberately broad, because the two loaders raise unrelated exception types.
        try:
            sd = load_file(path)
        except Exception:
            sd = torch.load(path, map_location="cpu", weights_only=True)

    # Unwrap training-style checkpoints regardless of which branch loaded them.
    # The dict check keeps a flat state dict that happens to contain such a key
    # (e.g. safetensors, whose values are tensors) from being unwrapped.
    for wrapper_key in ("state_dict", "model"):
        if isinstance(sd, dict) and isinstance(sd.get(wrapper_key), dict):
            sd = sd[wrapper_key]
            break

    if not isinstance(sd, dict):
        raise TypeError(
            f"Expected a state dict in {path}, got {type(sd).__name__}."
        )

    if cast_fp32:
        return {k: v.float() if isinstance(v, torch.Tensor) else v for k, v in sd.items()}
    return sd


def _parse_args():
    """Build and parse the command-line arguments for ``main``."""
    parser = argparse.ArgumentParser(
        description="Compare the state dicts of several checkpoints against a reference."
    )
    parser.add_argument("--file1", type=str, default=None,
                        help="Reference checkpoint (default: official.pth).")
    parser.add_argument("--files", type=str, nargs="+", default=None,
                        help="Checkpoints to compare against the reference.")
    parser.add_argument("--strict", action="store_true",
                        help="Compare raw tensors bit-for-bit (no fp32 cast; dtypes must match).")
    parser.add_argument("--assert_exact", action="store_true",
                        help="With --strict, fail if any tensor is missing or differs.")
    parser.add_argument("--skip_export", action="store_true",
                        help="Do not download and re-export the Hugging Face checkpoints; "
                             "compare files already on disk.")
    return parser.parse_args()


def _export_reference_checkpoints():
    """Download the official and Synthyra E1-150M models and save their state dicts.

    Writes ``official.pth``, ``load_from_pretrained_1.pth`` and
    ``load_from_pretrained_2.pth`` to the current working directory, overwriting
    existing files. Each model is released when this function returns.
    """
    official = E1ForMaskedLM.from_pretrained('Profluent-Bio/E1-150m', dtype=torch.float32).eval()
    torch.save(official.state_dict(), 'official.pth')

    base = AutoModel.from_pretrained(
        'Synthyra/Profluent-E1-150M', dtype=torch.float32, trust_remote_code=True
    ).eval()
    torch.save(base.state_dict(), 'load_from_pretrained_1.pth')

    mlm = AutoModelForMaskedLM.from_pretrained(
        'Synthyra/Profluent-E1-150M', dtype=torch.float32, trust_remote_code=True
    ).eval()
    torch.save(mlm.state_dict(), 'load_from_pretrained_2.pth')


def _compare_entry(key, ref_w, other_w, strict):
    """Compare one entry that is present in both the reference and another file.

    Returns:
        ``(cell, differs)``: the Rich markup for the table cell, and whether the
        pair should count as a mismatch.

    Raises:
        TypeError: If either value is not a tensor.
    """
    # Explicit raises rather than `assert`, so the checks survive `python -O`.
    if not isinstance(ref_w, torch.Tensor):
        raise TypeError(f"Weight {key} in reference is not a tensor.")
    if not isinstance(other_w, torch.Tensor):
        raise TypeError(f"Weight {key} in comparison file is not a tensor.")

    if ref_w.shape != other_w.shape:
        return "[red]✘ (Shape)[/red]", True

    if strict:
        # Strict mode skips the fp32 cast, so a dtype difference is itself a mismatch.
        if ref_w.dtype != other_w.dtype:
            return f"[red]✘ (Dtype: {ref_w.dtype} vs {other_w.dtype})[/red]", True
        if torch.equal(ref_w, other_w):
            return "[green]✔[/green]", False
        mse = torch.mean((ref_w.float() - other_w.float()) ** 2).item()
        return f"[red]✘ (Strict, MSE: {mse:.2e})[/red]", True

    # Decide equality with torch.equal rather than `mse == 0`:
    # squaring tiny differences can underflow to exactly 0 (hiding a real
    # difference), and the mean over an empty tensor is NaN.
    if torch.equal(ref_w, other_w):
        return "[green]✔[/green]", False
    mse = torch.mean((ref_w - other_w) ** 2).item()
    return f"[red]✘ (MSE: {mse:.2e})[/red]", True


def main():
    """Print a table comparing each checkpoint in ``--files`` against ``--file1``.

    Unless ``--skip_export`` is given, the Hugging Face checkpoints are first
    downloaded and saved as ``.pth`` files (see ``_export_reference_checkpoints``).
    In non-strict mode all tensors are cast to fp32 before comparing.
    With ``--strict --assert_exact``, an AssertionError is raised if any entry
    is missing, has a different shape or dtype, or differs in value.
    """
    args = _parse_args()

    if not args.skip_export:
        _export_reference_checkpoints()

    file1 = args.file1 if args.file1 is not None else 'official.pth'
    files = args.files if args.files is not None else [
        'load_from_pretrained_1.pth', 'load_from_pretrained_2.pth', 'old.safetensors'
    ]

    paths = [file1] + files
    sds = [load_weights(p, cast_fp32=not args.strict) for p in paths]
    all_keys = sorted(set().union(*(sd.keys() for sd in sds)))
    ref_sd, other_sds = sds[0], sds[1:]

    table = Table(title=f"Weights Comparison (Reference: {os.path.basename(paths[0])})")
    table.add_column("Tensor Name", style="cyan", no_wrap=True)
    for p in paths[1:]:
        table.add_column(f"{os.path.basename(p)} == Ref", justify="center")

    # One entry per (key, file) mismatch, including missing keys and shape
    # mismatches, so --assert_exact cannot pass while tensors are absent.
    mismatches = []
    for k in all_keys:
        row = [k]
        for sd in other_sds:
            in_ref, in_other = k in ref_sd, k in sd
            if not in_ref and not in_other:
                # Key only exists in some other file; nothing to compare here.
                cell, differs = "[dim]✔[/dim]", False
            elif in_ref != in_other:
                cell, differs = "[red]✘[/red]", True
            else:
                cell, differs = _compare_entry(k, ref_sd[k], sd[k], args.strict)
            row.append(cell)
            if differs:
                mismatches.append(k)
        table.add_row(*row)

    Console().print(table)

    # Explicit raise, so the exactness gate still works under `python -O`.
    if args.strict and args.assert_exact and mismatches:
        raise AssertionError(
            f"Found {len(mismatches)} strict mismatches. "
            f"First mismatches: {mismatches[:10]}"
        )


if __name__ == "__main__":
    # Run the comparison only when executed as a script, not when imported.
    main()