File size: 2,560 Bytes
31ade1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from __future__ import annotations

import torch

from train.checkpoint_compat import filter_compatible_state_dict


def test_filter_compatible_state_dict_remaps_legacy_wrapper_keys() -> None:
    """Legacy checkpoint keys are remapped onto the current wrapper key layout.

    Each case pairs a legacy checkpoint key with the model key it should be
    remapped to, plus the tensor shape shared by both sides. The model state
    is filled with zeros and the checkpoint state with ones. Shapes alone are
    identical on both sides, so the value check is what proves the checkpoint
    tensor (not the model's own tensor) ends up in ``compatible``.
    """
    cases = [
        # (legacy checkpoint key, current model key, shape)
        ("backbone.foo", "trunk.backbone.foo", (2, 2)),
        ("decoder.bar", "trunk.decoder.bar", (3,)),
        (
            "elastic_state_head.decoder.phase_head.0.weight",
            "adapter.state_head.decoder.phase_head.0.weight",
            (4, 4),
        ),
        (
            "world_model.transition.weight_ih",
            "adapter.transition_model.transition.weight_ih_l0",
            (5, 5),
        ),
        (
            "planner.residual.trunk.0.weight",
            "adapter.planner.reranker.network.0.weight",
            (6, 6),
        ),
        (
            "planner.residual.residual_head.weight",
            "adapter.planner.reranker.score_head.weight",
            (7, 7),
        ),
    ]
    # Build both state dicts from the same table so keys and shapes cannot drift apart.
    model_state = {model_key: torch.zeros(shape) for _, model_key, shape in cases}
    checkpoint_state = {legacy_key: torch.ones(shape) for legacy_key, _, shape in cases}

    compatible, skipped, remapped = filter_compatible_state_dict(model_state, checkpoint_state)

    assert skipped == []
    # Every model key has a matching legacy entry, so all of them must be loaded.
    assert set(compatible) == set(model_state)
    for _, model_key, shape in cases:
        tensor = compatible[model_key]
        assert tensor.shape == shape
        # Checkpoint holds ones and the model holds zeros: values must come from the checkpoint.
        assert torch.equal(tensor, torch.ones_like(tensor))
    assert remapped == {legacy_key: model_key for legacy_key, model_key, _ in cases}


def test_filter_compatible_state_dict_skips_shape_mismatch_after_remap() -> None:
    """A legacy key whose remapped target has a different shape is skipped.

    ``backbone.foo`` maps onto ``trunk.backbone.foo``, but the checkpoint
    tensor is 3x3 while the model expects 2x2. Nothing is loaded, the legacy
    key is reported as skipped, and no remap is recorded.
    """
    legacy_key = "backbone.foo"
    model_state = {"trunk.backbone.foo": torch.zeros(2, 2)}
    checkpoint_state = {legacy_key: torch.ones(3, 3)}

    loaded, dropped, renames = filter_compatible_state_dict(model_state, checkpoint_state)

    assert loaded == {}
    assert dropped == [legacy_key]
    assert renames == {}