import torch

from models.world_model import LightweightRevealStateTransitionModel


def test_lightweight_transition_contract(tiny_policy_config, tiny_state):
    """Check the output shapes of a LightweightRevealStateTransitionModel rollout.

    The model is built from ``config.world_model`` and called with an
    interaction state, a random action chunk and one proposal-mode id per
    candidate. The test then asserts the shapes of three entries of the
    returned mapping: ``visibility_summary``, ``access_field`` (leading four
    dimensions only) and ``clearance_field``.

    Args:
        tiny_policy_config: Callable returning a config object; called with
            ``num_candidates`` and ``chunk_size``.
            NOTE(review): presumably a pytest fixture defined in conftest --
            confirm.
        tiny_state: Callable returning the interaction state; called with
            ``batch_size`` and ``field_size``.
            NOTE(review): presumably a pytest fixture defined in conftest --
            confirm.
    """
    batch_size = 2
    num_candidates = 4

    config = tiny_policy_config(num_candidates=num_candidates, chunk_size=2)
    model = LightweightRevealStateTransitionModel(config.world_model)
    state = tiny_state(
        batch_size=batch_size,
        field_size=config.reveal_head.field_size,
    )

    # Expected shapes are derived from the config's chunk size, not from the
    # literal passed to ``tiny_policy_config`` above.
    chunk_len = config.decoder.chunk_size
    action_chunk = torch.rand(
        batch_size,
        num_candidates,
        chunk_len,
        config.decoder.action_dim,
    )

    # One mode id per candidate (0..num_candidates-1), repeated for every
    # batch row via a broadcasted view.
    mode_row = torch.arange(num_candidates, dtype=torch.long).unsqueeze(0)
    proposal_mode_ids = mode_row.expand(batch_size, -1)

    rollout = model(
        interaction_state=state,
        action_chunk=action_chunk,
        proposal_mode_ids=proposal_mode_ids,
    )

    # Every checked output shares the (batch, candidate, chunk step) prefix.
    leading_dims = (batch_size, num_candidates, chunk_len)
    support_modes = config.world_model.num_support_modes

    assert rollout["visibility_summary"].shape == leading_dims
    # Only the first four dimensions of the access field are pinned here.
    assert rollout["access_field"].shape[:4] == leading_dims + (support_modes,)
    assert rollout["clearance_field"].shape == leading_dims + (2, 1, 1)