File size: 2,560 Bytes
31ade1f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | from __future__ import annotations
import torch
from train.checkpoint_compat import filter_compatible_state_dict
def test_filter_compatible_state_dict_remaps_legacy_wrapper_keys() -> None:
    """Every legacy checkpoint key should land on its current model key.

    Model tensors are zeros and checkpoint tensors are ones of the same
    shape. Every entry is expected to appear in ``compatible`` under its
    current key, nothing should be skipped, and each legacy -> current
    rename should be reported in ``remapped``.
    """
    # (legacy checkpoint key, current model key, shared tensor shape)
    expected_renames = [
        ("backbone.foo", "trunk.backbone.foo", (2, 2)),
        ("decoder.bar", "trunk.decoder.bar", (3,)),
        (
            "elastic_state_head.decoder.phase_head.0.weight",
            "adapter.state_head.decoder.phase_head.0.weight",
            (4, 4),
        ),
        # The leaf name changes too here: weight_ih -> weight_ih_l0.
        (
            "world_model.transition.weight_ih",
            "adapter.transition_model.transition.weight_ih_l0",
            (5, 5),
        ),
        (
            "planner.residual.trunk.0.weight",
            "adapter.planner.reranker.network.0.weight",
            (6, 6),
        ),
        (
            "planner.residual.residual_head.weight",
            "adapter.planner.reranker.score_head.weight",
            (7, 7),
        ),
    ]
    model_state = {new: torch.zeros(shape) for _, new, shape in expected_renames}
    checkpoint_state = {old: torch.ones(shape) for old, _, shape in expected_renames}

    compatible, skipped, remapped = filter_compatible_state_dict(model_state, checkpoint_state)

    assert skipped == []
    for _, new_key, shape in expected_renames:
        assert compatible[new_key].shape == shape
    for old_key, new_key, _ in expected_renames:
        assert remapped[old_key] == new_key
def test_filter_compatible_state_dict_skips_shape_mismatch_after_remap() -> None:
    """A legacy key whose remapped target has a different shape is dropped.

    The checkpoint's ``backbone.foo`` (3x3) corresponds to the model's
    ``trunk.backbone.foo`` (2x2). It should be reported in ``skipped``, and
    it should appear in neither ``compatible`` nor ``remapped``.
    """
    result = filter_compatible_state_dict(
        {"trunk.backbone.foo": torch.zeros(2, 2)},
        {"backbone.foo": torch.ones(3, 3)},
    )
    compatible, skipped, remapped = result

    assert compatible == {}
    assert skipped == ["backbone.foo"]
    assert remapped == {}
|