# VLAarchtests3/code/VLAarchtests2_code/VLAarchtests/tests/test_checkpoint_compat.py
# lsnu's picture
# Add files using upload-large-folder tool
# 31ade1f verified
from __future__ import annotations
import torch
from train.checkpoint_compat import filter_compatible_state_dict
def test_filter_compatible_state_dict_remaps_legacy_wrapper_keys() -> None:
    """Legacy checkpoint keys are remapped onto the current model's wrapped keys.

    Every legacy key must:
      * be reported in ``remapped`` with its new model key,
      * appear in ``compatible`` under that model key with the expected shape,
      * carry the *checkpoint* tensor into ``compatible``.

    Model tensors are zeros and checkpoint tensors are ones, so a value check
    distinguishes a correctly loaded tensor from the model's own tensor.
    A shape check alone cannot, because the shapes are identical.
    """
    # legacy checkpoint key -> (key in the current model, tensor shape)
    expected_remaps = {
        "backbone.foo": ("trunk.backbone.foo", (2, 2)),
        "decoder.bar": ("trunk.decoder.bar", (3,)),
        "elastic_state_head.decoder.phase_head.0.weight": (
            "adapter.state_head.decoder.phase_head.0.weight",
            (4, 4),
        ),
        "world_model.transition.weight_ih": (
            "adapter.transition_model.transition.weight_ih_l0",
            (5, 5),
        ),
        "planner.residual.trunk.0.weight": (
            "adapter.planner.reranker.network.0.weight",
            (6, 6),
        ),
        "planner.residual.residual_head.weight": (
            "adapter.planner.reranker.score_head.weight",
            (7, 7),
        ),
    }
    model_state = {
        target_key: torch.zeros(*shape)
        for target_key, shape in expected_remaps.values()
    }
    checkpoint_state = {
        legacy_key: torch.ones(*shape)
        for legacy_key, (_, shape) in expected_remaps.items()
    }

    compatible, skipped, remapped = filter_compatible_state_dict(model_state, checkpoint_state)

    assert skipped == []
    # Every model parameter is covered by the checkpoint, and nothing extra leaks in.
    assert set(compatible) == set(model_state)
    for legacy_key, (target_key, shape) in expected_remaps.items():
        assert remapped[legacy_key] == target_key, legacy_key
        loaded = compatible[target_key]
        assert loaded.shape == shape, target_key
        # Proves the checkpoint tensor (ones) was carried over, not the model's zeros.
        assert torch.equal(loaded, checkpoint_state[legacy_key]), target_key
def test_filter_compatible_state_dict_skips_shape_mismatch_after_remap() -> None:
    """A legacy key whose remapped target has a different shape is skipped.

    The legacy key ``backbone.foo`` would remap to ``trunk.backbone.foo``.
    The checkpoint tensor is 3x3 while the model expects 2x2.
    The entry must therefore land in ``skipped`` under its original checkpoint key.
    It must be absent from both ``compatible`` and ``remapped``.
    """
    legacy_key = "backbone.foo"
    model_key = "trunk.backbone.foo"

    compatible, skipped, remapped = filter_compatible_state_dict(
        {model_key: torch.zeros(2, 2)},
        {legacy_key: torch.ones(3, 3)},
    )

    assert compatible == {}
    assert skipped == [legacy_key]
    assert remapped == {}