from __future__ import annotations import torch from train.checkpoint_compat import filter_compatible_state_dict def test_filter_compatible_state_dict_remaps_legacy_wrapper_keys() -> None: model_state = { "trunk.backbone.foo": torch.zeros(2, 2), "trunk.decoder.bar": torch.zeros(3), "adapter.state_head.decoder.phase_head.0.weight": torch.zeros(4, 4), "adapter.transition_model.transition.weight_ih_l0": torch.zeros(5, 5), "adapter.planner.reranker.network.0.weight": torch.zeros(6, 6), "adapter.planner.reranker.score_head.weight": torch.zeros(7, 7), } checkpoint_state = { "backbone.foo": torch.ones(2, 2), "decoder.bar": torch.ones(3), "elastic_state_head.decoder.phase_head.0.weight": torch.ones(4, 4), "world_model.transition.weight_ih": torch.ones(5, 5), "planner.residual.trunk.0.weight": torch.ones(6, 6), "planner.residual.residual_head.weight": torch.ones(7, 7), } compatible, skipped, remapped = filter_compatible_state_dict(model_state, checkpoint_state) assert skipped == [] assert compatible["trunk.backbone.foo"].shape == (2, 2) assert compatible["trunk.decoder.bar"].shape == (3,) assert compatible["adapter.state_head.decoder.phase_head.0.weight"].shape == (4, 4) assert compatible["adapter.transition_model.transition.weight_ih_l0"].shape == (5, 5) assert compatible["adapter.planner.reranker.network.0.weight"].shape == (6, 6) assert compatible["adapter.planner.reranker.score_head.weight"].shape == (7, 7) assert remapped["backbone.foo"] == "trunk.backbone.foo" assert remapped["decoder.bar"] == "trunk.decoder.bar" assert remapped["elastic_state_head.decoder.phase_head.0.weight"] == "adapter.state_head.decoder.phase_head.0.weight" assert remapped["world_model.transition.weight_ih"] == "adapter.transition_model.transition.weight_ih_l0" assert remapped["planner.residual.trunk.0.weight"] == "adapter.planner.reranker.network.0.weight" assert remapped["planner.residual.residual_head.weight"] == "adapter.planner.reranker.score_head.weight" def test_filter_compatible_state_dict_skips_shape_mismatch_after_remap() -> None: model_state = { "trunk.backbone.foo": torch.zeros(2, 2), } checkpoint_state = { "backbone.foo": torch.ones(3, 3), } compatible, skipped, remapped = filter_compatible_state_dict(model_state, checkpoint_state) assert compatible == {} assert skipped == ["backbone.foo"] assert remapped == {}