from typing import Mapping, Optional

import torch
from torch.nn.functional import mse_loss


def assert_model_parameters_fp32(model: torch.nn.Module, model_name: str) -> None:
    """Fail unless *model* has at least one parameter and all are float32.

    Only ``model.named_parameters()`` is inspected; buffers are not checked.

    Args:
        model: Module whose parameters are inspected.
        model_name: Label used as the prefix of the failure message.

    Raises:
        AssertionError: If the model has no parameters, or if any parameter
            has a dtype other than ``torch.float32``. At most five offending
            parameters are listed in the message.
    """
    parameter_count = 0
    non_fp32: list[dict[str, str]] = []
    for name, parameter in model.named_parameters():
        parameter_count += 1
        if parameter.dtype != torch.float32:
            non_fp32.append({"name": name, "dtype": str(parameter.dtype)})

    # Raise explicitly instead of using `assert` so the checks still run
    # when Python is started with -O (which strips assert statements).
    if parameter_count == 0:
        raise AssertionError(f"{model_name} has no parameters.")
    if non_fp32:
        raise AssertionError(
            f"{model_name} parameters must all be torch.float32. "
            f"non_fp32_count={len(non_fp32)} sample={non_fp32[:5]}"
        )


def assert_state_dict_floating_tensors_fp32(
    state_dict: Mapping[str, torch.Tensor],
    state_dict_name: str,
) -> None:
    """Fail unless every floating-point tensor in *state_dict* is float32.

    Entries are visited in sorted key order so the reported sample is
    deterministic. Non-floating tensors (e.g. integer counters) are ignored.

    Args:
        state_dict: Mapping from entry name to tensor.
        state_dict_name: Label used as the prefix of the failure message.

    Raises:
        AssertionError: If an entry is not a tensor, or if any floating-point
            tensor has a dtype other than ``torch.float32``. At most five
            offending entries are listed in the message.
    """
    non_fp32: list[dict[str, str]] = []
    for tensor_name in sorted(state_dict):
        tensor = state_dict[tensor_name]
        if not torch.is_tensor(tensor):
            raise AssertionError(
                f"{state_dict_name} state_dict entry must be a tensor. "
                f"name={tensor_name} type={type(tensor)}"
            )
        if tensor.is_floating_point() and tensor.dtype != torch.float32:
            non_fp32.append({"name": tensor_name, "dtype": str(tensor.dtype)})

    if non_fp32:
        raise AssertionError(
            f"{state_dict_name} floating tensors must be torch.float32. "
            f"non_fp32_count={len(non_fp32)} sample={non_fp32[:5]}"
        )


def _describe_tensor_mismatch(reference: object, candidate: object) -> Optional[str]:
    """Return a short description of how two entries differ, or None if equal."""
    if not (torch.is_tensor(reference) and torch.is_tensor(candidate)):
        return (
            "entries must both be tensors. "
            f"reference_type={type(reference)} candidate_type={type(candidate)}"
        )
    # Check shape and dtype first: an elementwise comparison would otherwise
    # broadcast or promote and could hide a structural difference.
    if reference.shape != candidate.shape:
        return f"shape mismatch {tuple(reference.shape)} != {tuple(candidate.shape)}"
    if reference.dtype != candidate.dtype:
        return f"dtype mismatch {reference.dtype} != {candidate.dtype}"

    differs = reference != candidate
    if reference.is_floating_point() or reference.is_complex():
        # NaN != NaN elementwise; treat NaNs at the same position as equal so
        # that identical tensors containing NaN do not count as a mismatch.
        differs = differs & ~(torch.isnan(reference) & torch.isnan(candidate))
    mismatched = int(differs.sum().item())
    if mismatched == 0:
        return None

    detail = f"mismatched={mismatched}/{reference.numel()}"
    if reference.is_floating_point():
        # MSE is informational only; compute it in float64 so tiny
        # differences do not underflow to zero.
        mse = mse_loss(reference.double(), candidate.double()).item()
        detail = f"mse={mse} {detail}"
    return detail


def assert_state_dict_equal(
    reference_state_dict: Mapping[str, torch.Tensor],
    candidate_state_dict: Mapping[str, torch.Tensor],
    context: str,
    max_report: int = 10,
) -> None:
    """Fail unless both state dicts hold the same keys, order and values.

    Checks performed:
      * keys present in one mapping but not the other,
      * key order (only when both mappings hold the same key set),
      * for every shared key: both entries are tensors with equal shape,
        dtype and elementwise-equal values (NaNs at matching positions are
        considered equal).

    Every mismatch is printed as it is found; the raised message contains the
    total mismatch count and the first *max_report* mismatches.

    Args:
        reference_state_dict: Expected entries.
        candidate_state_dict: Entries to validate against the reference.
        context: Label used as the prefix of the failure message.
        max_report: Maximum number of mismatches included in the message.

    Raises:
        AssertionError: If any mismatch is found.
    """
    error_msgs: list[str] = []

    def record(msg: str) -> None:
        # Printing keeps the full mismatch list visible even though the
        # raised message is truncated to max_report entries.
        print(msg)
        error_msgs.append(msg)

    ref_names = list(reference_state_dict)
    cand_names = list(candidate_state_dict)
    ref_name_set = set(ref_names)
    cand_name_set = set(cand_names)

    # Compare by key rather than zipping the two mappings: zip() stops at the
    # shorter input, which would silently ignore missing or extra entries.
    for name in ref_names:
        if name not in cand_name_set:
            record(f"Missing from candidate: {name}")
    for name in cand_names:
        if name not in ref_name_set:
            record(f"Unexpected in candidate: {name}")

    if ref_name_set == cand_name_set:
        # Same keys: still report positions where the ordering differs.
        for ref_name, cand_name in zip(ref_names, cand_names):
            if ref_name != cand_name:
                record(f"Name mismatch: {ref_name} != {cand_name}")

    for name in ref_names:
        if name not in cand_name_set:
            continue
        mismatch = _describe_tensor_mismatch(
            reference_state_dict[name], candidate_state_dict[name]
        )
        if mismatch is not None:
            record(f"{name}: {mismatch}")

    if error_msgs:
        sample = " | ".join(error_msgs[:max_report])
        raise AssertionError(
            f"{context} state_dict parity failed: "
            f"mismatch_count={len(error_msgs)} sample={sample}"
        )


def assert_models_fp32_and_equal(
    reference_model: torch.nn.Module,
    candidate_model: torch.nn.Module,
    context: str,
    max_report: int = 5,
) -> None:
    """Fail unless both models are all-float32 and have equal state dicts.

    Each model's parameters are first checked with
    ``assert_model_parameters_fp32``; the two ``state_dict()`` results are
    then compared with ``assert_state_dict_equal``.

    Args:
        reference_model: Model providing the expected state.
        candidate_model: Model to validate against the reference.
        context: Label used as the prefix of failure messages.
        max_report: Maximum number of state-dict mismatches in the message.

    Raises:
        AssertionError: If either dtype check or the parity check fails.
    """
    assert_model_parameters_fp32(
        model=reference_model,
        model_name=f"{context} reference model",
    )
    assert_model_parameters_fp32(
        model=candidate_model,
        model_name=f"{context} candidate model",
    )
    assert_state_dict_equal(
        reference_state_dict=reference_model.state_dict(),
        candidate_state_dict=candidate_model.state_dict(),
        context=context,
        max_report=max_report,
    )