File size: 2,818 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
from torch.nn.functional import mse_loss
from typing import Mapping


def assert_model_parameters_fp32(model: torch.nn.Module, model_name: str) -> None:
    """Assert that *model* has at least one parameter and all of them are fp32.

    Args:
        model: module whose ``named_parameters()`` are inspected.
        model_name: label used in the assertion messages.

    Raises:
        AssertionError: if the model has no parameters, or any parameter's
            dtype is not ``torch.float32``.
    """
    named = list(model.named_parameters())
    offenders = [
        {"name": param_name, "dtype": str(param.dtype)}
        for param_name, param in named
        if param.dtype != torch.float32
    ]

    assert len(named) > 0, f"{model_name} has no parameters."
    assert len(offenders) == 0, (
        f"{model_name} parameters must all be torch.float32. "
        f"non_fp32_count={len(offenders)} sample={offenders[:5]}"
    )


def assert_state_dict_floating_tensors_fp32(
    state_dict: Mapping[str, torch.Tensor],
    state_dict_name: str,
) -> None:
    """Assert every floating-point tensor in *state_dict* is fp32.

    Every entry must be a tensor; non-floating tensors (e.g. integer index
    buffers) are allowed to keep their dtype. Keys are visited in sorted
    order so the first failing entry is deterministic.

    Args:
        state_dict: mapping of tensor names to tensors.
        state_dict_name: label used in the assertion messages.

    Raises:
        AssertionError: if an entry is not a tensor, or a floating tensor's
            dtype is not ``torch.float32``.
    """
    offenders: list[dict[str, str]] = []
    for key in sorted(state_dict.keys()):
        value = state_dict[key]
        assert torch.is_tensor(value), (
            f"{state_dict_name} state_dict entry must be a tensor. "
            f"name={key} type={type(value)}"
        )
        wrong_dtype = value.is_floating_point() and value.dtype != torch.float32
        if wrong_dtype:
            offenders.append({"name": key, "dtype": str(value.dtype)})

    assert not offenders, (
        f"{state_dict_name} floating tensors must be torch.float32. "
        f"non_fp32_count={len(offenders)} sample={offenders[:5]}"
    )


def assert_state_dict_equal(
    reference_state_dict: Mapping[str, torch.Tensor],
    candidate_state_dict: Mapping[str, torch.Tensor],
    context: str,
    max_report: int = 10,
) -> None:
    """Assert two state_dicts match: same keys in the same order, equal values.

    Tensor equality is measured as ``mse_loss(ref, cand) == 0``, so entries
    must be floating tensors with shapes compatible with ``mse_loss``.
    Every mismatch is printed as it is found; the first ``max_report``
    messages are also folded into the final assertion message.

    Args:
        reference_state_dict: the ground-truth state_dict.
        candidate_state_dict: the state_dict being checked against it.
        context: label used in the assertion message.
        max_report: max number of mismatch messages in the assertion text.

    Raises:
        AssertionError: if the key counts differ, a key name differs at some
            position, or a pair of tensors has non-zero MSE.
    """
    error_msgs: list[str] = []
    # BUG FIX: zip() silently truncates to the shorter mapping, so extra or
    # missing entries in either state_dict previously passed unreported.
    # Compare sizes explicitly before the pairwise walk.
    if len(reference_state_dict) != len(candidate_state_dict):
        msg = (
            f"Key count mismatch: {len(reference_state_dict)} != "
            f"{len(candidate_state_dict)}"
        )
        print(msg)
        error_msgs.append(msg)
    for (ref_name, ref_tensor), (cand_name, cand_tensor) in zip(
        reference_state_dict.items(), candidate_state_dict.items()
    ):
        if ref_name != cand_name:
            msg = f"Name mismatch: {ref_name} != {cand_name}"
            print(msg)
            error_msgs.append(msg)
        else:
            # Equal tensors yield exactly 0.0; any positive MSE is a mismatch.
            diff = mse_loss(ref_tensor, cand_tensor).item()
            if diff > 0.0:
                msg = f"{ref_name}: {diff}"
                print(msg)
                error_msgs.append(msg)
    assert not error_msgs, (
        f"{context} state_dict parity failed: {' | '.join(error_msgs[:max_report])}"
    )


def assert_models_fp32_and_equal(
    reference_model: torch.nn.Module,
    candidate_model: torch.nn.Module,
    context: str,
    max_report: int = 5,
) -> None:
    """Check that both models are fully fp32 and their state_dicts match.

    Args:
        reference_model: the ground-truth model.
        candidate_model: the model being compared against it.
        context: label threaded into all assertion messages.
        max_report: max number of mismatch messages in the parity assertion.

    Raises:
        AssertionError: if either model has a non-fp32 parameter (or none at
            all), or the two state_dicts differ.
    """
    for role, net in (("reference", reference_model), ("candidate", candidate_model)):
        assert_model_parameters_fp32(model=net, model_name=f"{context} {role} model")
    assert_state_dict_equal(
        reference_state_dict=reference_model.state_dict(),
        candidate_state_dict=candidate_model.state_dict(),
        context=context,
        max_report=max_report,
    )