# Source: directionality_probe/protify/FastPLMs/weight_comparison.py
# Uploaded by nikraf using huggingface_hub (commit 714cf46, verified).
import argparse
import os
import torch
from safetensors.torch import load_file
from rich.console import Console
from rich.table import Table
from transformers import AutoModelForMaskedLM, AutoConfig, AutoModel
from e1_fastplms.modeling_e1 import E1ForMaskedLM, E1Config, E1Model
def load_weights(path, cast_fp32=True):
    """Load a checkpoint file as a flat ``name -> value`` state dict.

    ``.safetensors`` files are read with ``safetensors.torch.load_file``.
    ``.pth`` / ``.pt`` files are read with ``torch.load(..., weights_only=True)``.
    Any other extension is tried as safetensors first and, if that fails,
    as a torch checkpoint. Torch checkpoints that nest their weights under a
    ``"state_dict"`` or ``"model"`` key are unwrapped, whichever branch
    loaded them.

    Args:
        path: Path to the checkpoint file.
        cast_fp32: When True, every tensor value is converted with
            ``Tensor.float()`` so checkpoints stored in different dtypes can
            be compared numerically. Non-tensor values are returned as-is.

    Returns:
        A dict mapping entry names to tensors (or any non-tensor values the
        checkpoint stores).

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"File {path} not found.")

    def _torch_load(p):
        # Load a torch checkpoint and unwrap common nesting conventions.
        loaded = torch.load(p, map_location="cpu", weights_only=True)
        if isinstance(loaded, dict) and "state_dict" in loaded:
            return loaded["state_dict"]
        if isinstance(loaded, dict) and "model" in loaded:
            return loaded["model"]
        return loaded

    if path.endswith(".safetensors"):
        sd = load_file(path)
    elif path.endswith((".pth", ".pt")):
        sd = _torch_load(path)
    else:
        try:
            sd = load_file(path)
        except Exception:
            # Deliberate best-effort: anything safetensors cannot parse is
            # retried as a torch checkpoint.
            sd = _torch_load(path)
    if cast_fp32:
        return {k: v.float() if isinstance(v, torch.Tensor) else v for k, v in sd.items()}
    return sd
def _export_default_checkpoints():
    """Download the E1-150M models and save their state dicts to the CWD.

    Writes ``official.pth`` (``E1ForMaskedLM`` from ``Profluent-Bio/E1-150m``),
    ``load_from_pretrained_1.pth`` (``AutoModel`` from
    ``Synthyra/Profluent-E1-150M``) and ``load_from_pretrained_2.pth``
    (``AutoModelForMaskedLM`` from the same Synthyra repo). All models are
    loaded in float32 and switched to eval mode before saving.
    """
    model = E1ForMaskedLM.from_pretrained('Profluent-Bio/E1-150m', dtype=torch.float32).eval()
    torch.save(model.state_dict(), 'official.pth')
    # Release each model before loading the next to lower peak memory.
    del model
    model = AutoModel.from_pretrained('Synthyra/Profluent-E1-150M', dtype=torch.float32, trust_remote_code=True).eval()
    torch.save(model.state_dict(), 'load_from_pretrained_1.pth')
    del model
    model = AutoModelForMaskedLM.from_pretrained('Synthyra/Profluent-E1-150M', dtype=torch.float32, trust_remote_code=True).eval()
    torch.save(model.state_dict(), 'load_from_pretrained_2.pth')
    del model


def _compare_entry(k, ref_w, other_w, has_ref, has_other, strict):
    """Compare one state-dict entry of a candidate file against the reference.

    Args:
        k: Entry name, used in error messages.
        ref_w: Reference value (ignored when ``has_ref`` is False).
        other_w: Candidate value (ignored when ``has_other`` is False).
        has_ref: Whether the reference contains ``k``.
        has_other: Whether the candidate contains ``k``.
        strict: Selects the label used in the mismatch cell.

    Returns:
        ``(cell, mismatch)``: a rich-markup string for the table cell and a
        bool that is True when the entries are not identical (present on only
        one side, different shape, or different values).

    Raises:
        TypeError: If both sides are present but either is not a tensor.
    """
    if not (has_ref and has_other):
        if not has_ref and not has_other:
            # Key comes from some other file; absent here and in the reference.
            return "[dim]✔[/dim]", False
        return "[red]✘[/red]", True
    if not isinstance(ref_w, torch.Tensor):
        raise TypeError(f"Weight {k} in reference is not a tensor.")
    if not isinstance(other_w, torch.Tensor):
        raise TypeError(f"Weight {k} in comparison file is not a tensor.")
    if ref_w.shape != other_w.shape:
        return "[red]✘ (Shape)[/red]", True
    # Exact equality decides the verdict. Testing MSE == 0 instead gives wrong
    # answers: tiny squared differences underflow to 0 in fp32, and the mean of
    # an empty tensor is NaN.
    if torch.equal(ref_w, other_w):
        return "[green]✔[/green]", False
    mse = torch.mean((ref_w.float() - other_w.float()) ** 2).item()
    if strict:
        return f"[red]✘ (Strict, MSE: {mse:.2e})[/red]", True
    return f"[red]✘ (MSE: {mse:.2e})[/red]", True


def main():
    """Export the E1 reference checkpoints, then print a weight-comparison table.

    The first file (``--file1``, default ``official.pth``) is the reference.
    Each file in ``--files`` gets one table column showing, per state-dict
    key, whether it matches the reference. Without ``--strict``, all tensors
    are cast to float32 before comparison. With ``--strict --assert_exact``,
    any missing key, shape mismatch or value mismatch in any file raises
    ``AssertionError`` after the table is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--file1", type=str, default=None, help="Reference checkpoint (default: official.pth).")
    parser.add_argument("--files", type=str, nargs="+", default=None, help="Checkpoints to compare against the reference.")
    parser.add_argument("--strict", action="store_true", help="Compare tensors in their stored dtype.")
    parser.add_argument("--assert_exact", action="store_true", help="With --strict, fail if any entry differs.")
    args = parser.parse_args()

    _export_default_checkpoints()

    if args.file1 is None:
        args.file1 = 'official.pth'
    if args.files is None:
        args.files = ['load_from_pretrained_1.pth', 'load_from_pretrained_2.pth', 'old.safetensors']
    paths = [args.file1] + args.files
    sds = [load_weights(p, cast_fp32=not args.strict) for p in paths]
    all_keys = sorted(set().union(*(sd.keys() for sd in sds)))

    console = Console()
    table = Table(title=f"Weights Comparison (Reference: {os.path.basename(paths[0])})")
    table.add_column("Tensor Name", style="cyan", no_wrap=True)
    for p in paths[1:]:
        table.add_column(f"{os.path.basename(p)} == Ref", justify="center")

    ref_sd, other_sds = sds[0], sds[1:]
    # Each differing key is recorded once, however many files disagree on it.
    strict_mismatches = []
    for k in all_keys:
        has_ref = k in ref_sd
        ref_w = ref_sd[k] if has_ref else None
        row = [k]
        key_differs = False
        for sd in other_sds:
            has_other = k in sd
            other_w = sd[k] if has_other else None
            cell, mismatch = _compare_entry(k, ref_w, other_w, has_ref, has_other, args.strict)
            row.append(cell)
            key_differs = key_differs or mismatch
        if key_differs:
            strict_mismatches.append(k)
        table.add_row(*row)
    console.print(table)

    # Explicit raise so the check still runs under `python -O`.
    if args.strict and args.assert_exact and strict_mismatches:
        raise AssertionError(
            f"Found {len(strict_mismatches)} strict mismatches. "
            f"First mismatches: {strict_mismatches[:10]}"
        )
# Script entry point: run the comparison only when executed directly, not on import.
if __name__ == "__main__":
    main()