| from __future__ import annotations | |
| import torch | |
| from train.checkpoint_compat import filter_compatible_state_dict | |
def test_filter_compatible_state_dict_remaps_legacy_wrapper_keys() -> None:
    """Legacy checkpoint keys are renamed onto the current model key layout.

    Every checkpoint entry below uses an older name (no ``trunk.`` /
    ``adapter.`` prefix, older submodule names such as ``world_model`` or
    ``planner.residual``). Each must be routed to its current model key,
    with nothing skipped.

    Model tensors are zeros and checkpoint tensors are ones, so the value
    assertions can tell which side a loaded tensor came from. Shape checks
    alone cannot, because both sides share shapes by construction.
    """
    # (legacy checkpoint key, current model key, tensor shape)
    cases = [
        ("backbone.foo", "trunk.backbone.foo", (2, 2)),
        ("decoder.bar", "trunk.decoder.bar", (3,)),
        (
            "elastic_state_head.decoder.phase_head.0.weight",
            "adapter.state_head.decoder.phase_head.0.weight",
            (4, 4),
        ),
        (
            "world_model.transition.weight_ih",
            "adapter.transition_model.transition.weight_ih_l0",
            (5, 5),
        ),
        (
            "planner.residual.trunk.0.weight",
            "adapter.planner.reranker.network.0.weight",
            (6, 6),
        ),
        (
            "planner.residual.residual_head.weight",
            "adapter.planner.reranker.score_head.weight",
            (7, 7),
        ),
    ]
    model_state = {model_key: torch.zeros(shape) for _, model_key, shape in cases}
    checkpoint_state = {legacy_key: torch.ones(shape) for legacy_key, _, shape in cases}

    compatible, skipped, remapped = filter_compatible_state_dict(model_state, checkpoint_state)

    assert skipped == []
    # Every legacy key is reported with its new name, and nothing else is reported.
    assert remapped == {legacy_key: model_key for legacy_key, model_key, _ in cases}
    # Exactly the model keys are populated; no legacy names leak through.
    assert set(compatible) == set(model_state)
    for legacy_key, model_key, shape in cases:
        loaded = compatible[model_key]
        assert loaded.shape == shape
        # The loaded tensor must be the checkpoint's ones, not the model's zeros.
        assert torch.equal(loaded, checkpoint_state[legacy_key])
def test_filter_compatible_state_dict_skips_shape_mismatch_after_remap() -> None:
    """A legacy key whose renamed target has a different shape is skipped.

    The entry must not be loaded and must not be reported as remapped; it is
    reported under its original checkpoint name in ``skipped``.
    """
    legacy_key = "backbone.foo"
    # The model expects a (2, 2) tensor at the prefixed location...
    model_state = {"trunk.backbone.foo": torch.zeros(2, 2)}
    # ...but the checkpoint holds a (3, 3) tensor under the unprefixed legacy name.
    checkpoint_state = {legacy_key: torch.ones(3, 3)}

    result = filter_compatible_state_dict(model_state, checkpoint_state)
    compatible, skipped, remapped = result

    assert compatible == {}
    assert skipped == [legacy_key]
    assert remapped == {}