# directionality_probe / protify / FastPLMs / weight_parity_utils.py
# Uploaded via huggingface_hub by nikraf (commit 714cf46, verified).
import torch
from torch.nn.functional import mse_loss
from typing import Mapping
def assert_model_parameters_fp32(model: torch.nn.Module, model_name: str) -> None:
    """Assert that *model* has at least one parameter and that every
    parameter tensor is stored as ``torch.float32``.

    Args:
        model: Module whose named parameters are inspected.
        model_name: Label used in the assertion messages.

    Raises:
        AssertionError: If the model has no parameters, or if any
            parameter has a dtype other than ``torch.float32``.
    """
    named = list(model.named_parameters())
    assert len(named) > 0, f"{model_name} has no parameters."
    # Collect every offending parameter so the failure message can show
    # both the total count and a small sample of names/dtypes.
    non_fp32: list[dict[str, str]] = [
        {"name": param_name, "dtype": str(param.dtype)}
        for param_name, param in named
        if param.dtype != torch.float32
    ]
    assert len(non_fp32) == 0, (
        f"{model_name} parameters must all be torch.float32. "
        f"non_fp32_count={len(non_fp32)} sample={non_fp32[:5]}"
    )
def assert_state_dict_floating_tensors_fp32(
    state_dict: Mapping[str, torch.Tensor],
    state_dict_name: str,
) -> None:
    """Assert that every entry of *state_dict* is a tensor and that all
    floating-point entries are ``torch.float32``.

    Integer/bool tensors are allowed in any dtype; only floating tensors
    are checked for fp32.

    Args:
        state_dict: Mapping of tensor names to tensors.
        state_dict_name: Label used in the assertion messages.

    Raises:
        AssertionError: If any entry is not a tensor, or if any floating
            tensor has a dtype other than ``torch.float32``.
    """
    non_fp32: list[dict[str, str]] = []
    # Iterate in sorted-key order so failure samples are deterministic.
    for key in sorted(state_dict.keys()):
        value = state_dict[key]
        assert torch.is_tensor(value), (
            f"{state_dict_name} state_dict entry must be a tensor. "
            f"name={key} type={type(value)}"
        )
        if value.is_floating_point() and value.dtype != torch.float32:
            non_fp32.append({"name": key, "dtype": str(value.dtype)})
    assert len(non_fp32) == 0, (
        f"{state_dict_name} floating tensors must be torch.float32. "
        f"non_fp32_count={len(non_fp32)} sample={non_fp32[:5]}"
    )
def assert_state_dict_equal(
    reference_state_dict: Mapping[str, torch.Tensor],
    candidate_state_dict: Mapping[str, torch.Tensor],
    context: str,
    max_report: int = 10,
) -> None:
    """Assert that two state_dicts agree: same entry count, same key
    order, same tensor shapes, and exactly equal values (MSE == 0).

    Args:
        reference_state_dict: State dict treated as ground truth.
        candidate_state_dict: State dict being validated against it.
        context: Label used in the assertion message.
        max_report: Maximum number of mismatch messages included in the
            final assertion message (all mismatches are still printed).

    Raises:
        AssertionError: If the dicts differ in size, key order, tensor
            shape, or tensor values.
    """
    # Bug fix: the previous zip() over the two dicts silently truncated
    # at the shorter one, so a candidate with missing or extra entries
    # could pass the parity check. Enforce equal sizes up front.
    ref_names = list(reference_state_dict.keys())
    cand_names = list(candidate_state_dict.keys())
    assert len(ref_names) == len(cand_names), (
        f"{context} state_dict size mismatch: "
        f"reference={len(ref_names)} candidate={len(cand_names)}"
    )
    error_msgs: list[str] = []
    for ref_name, cand_name in zip(ref_names, cand_names):
        if ref_name != cand_name:
            msg = f"Name mismatch: {ref_name} != {cand_name}"
            print(msg)
            error_msgs.append(msg)
            continue
        ref_tensor = reference_state_dict[ref_name]
        cand_tensor = candidate_state_dict[cand_name]
        # Guard shapes explicitly: mse_loss would otherwise broadcast or
        # raise with a confusing message instead of naming the tensor.
        if ref_tensor.shape != cand_tensor.shape:
            msg = (
                f"{ref_name}: shape mismatch "
                f"{tuple(ref_tensor.shape)} != {tuple(cand_tensor.shape)}"
            )
            print(msg)
            error_msgs.append(msg)
            continue
        # Cast to float so integer/bool entries (e.g. index buffers)
        # don't make mse_loss raise; exact parity still requires 0.0.
        diff = mse_loss(ref_tensor.float(), cand_tensor.float()).item()
        if diff > 0.0:
            msg = f"{ref_name}: {diff}"
            print(msg)
            error_msgs.append(msg)
    assert not error_msgs, (
        f"{context} state_dict parity failed: {' | '.join(error_msgs[:max_report])}"
    )
def assert_models_fp32_and_equal(
    reference_model: torch.nn.Module,
    candidate_model: torch.nn.Module,
    context: str,
    max_report: int = 5,
) -> None:
    """Assert both models are fully fp32 and share identical state_dicts.

    Runs the fp32 parameter check on each model, then the state_dict
    parity check between them.

    Args:
        reference_model: Module treated as ground truth.
        candidate_model: Module being validated against it.
        context: Label prefixed to all assertion messages.
        max_report: Maximum number of parity mismatches reported.

    Raises:
        AssertionError: If either model has non-fp32 parameters or the
            state_dicts differ.
    """
    for module, role in ((reference_model, "reference"), (candidate_model, "candidate")):
        assert_model_parameters_fp32(model=module, model_name=f"{context} {role} model")
    assert_state_dict_equal(
        reference_state_dict=reference_model.state_dict(),
        candidate_state_dict=candidate_model.state_dict(),
        context=context,
        max_report=max_report,
    )