"""Tests for building teacher datasets and reading items back from them."""

from sim_reveal.dataset import (
    RGBD_PROXY_DATASET_VERSION,
    collect_teacher_dataset,
    dataset_from_bundle,
)

# Dataset version string used by the phase-aware variants below.
_PHASE_DATASET_VERSION = RGBD_PROXY_DATASET_VERSION + "_phase"

# Suffixes shared by the candidate_* sample fields and the
# proposal_target_* fields produced by ``_copy_candidate_targets``.
_PROPOSAL_TARGET_SUFFIXES = (
    "action_chunks",
    "retrieval_success",
    "risk",
    "utility",
)


def _first_item(**collect_kwargs):
    """Collect a small teacher dataset and return its first item.

    Small shared defaults are used: one episode per proxy, resolution=16,
    two history steps, three planner candidates and the base RGB-D dataset
    version. Any keyword in ``collect_kwargs`` overrides or extends them
    before it is forwarded to ``collect_teacher_dataset``.
    """
    params = {
        "episodes_per_proxy": 1,
        "resolution": 16,
        "history_steps": 2,
        "planner_candidates": 3,
        "dataset_version": RGBD_PROXY_DATASET_VERSION,
    }
    params.update(collect_kwargs)
    bundle = collect_teacher_dataset(**params)
    # Read the bundle back at the same resolution it was collected with.
    dataset = dataset_from_bundle(bundle, resolution=params["resolution"])
    return dataset[0]


def _copy_candidate_targets(env, observation, sample):
    """Build proposal targets by copying each ``candidate_*`` array in ``sample``.

    Passed as ``proposal_target_builder``. ``env`` and ``observation`` are
    accepted to match the builder call signature but are not used here.
    """
    return {
        f"proposal_target_{suffix}": sample[f"candidate_{suffix}"].copy()
        for suffix in _PROPOSAL_TARGET_SUFFIXES
    }


def test_dataset_v6_keys():
    """Items built with the base RGB-D version expose the expected keys."""
    item = _first_item()
    for key in (
        "images",
        "depths",
        "depth_valid",
        "belief_map",
        "visibility_map",
        "clearance_map",
        "support_stability",
        "reocclusion_target",
        "candidate_rollout_belief_map",
    ):
        assert key in item, f"missing key: {key}"


def test_phase_dataset_version_keeps_rgbd_path():
    """The ``_phase`` version still yields valid depth and adds phase keys."""
    item = _first_item(
        proxy_names=["bag_proxy"],
        dataset_version=_PHASE_DATASET_VERSION,
    )
    assert float(item["depth_valid"].sum()) > 0.0
    for key in ("phase", "rollout_phase", "candidate_rollout_phase"):
        assert key in item, f"missing key: {key}"


def test_dataset_proposal_target_keys_roundtrip():
    """Keys returned by a proposal target builder survive into dataset items."""
    item = _first_item(
        proxy_names=["cloth_proxy"],
        dataset_version=_PHASE_DATASET_VERSION,
        proposal_target_builder=_copy_candidate_targets,
    )
    for suffix in _PROPOSAL_TARGET_SUFFIXES:
        key = f"proposal_target_{suffix}"
        assert key in item, f"missing key: {key}"