import hydra
import rerun as rr
from omegaconf import DictConfig, OmegaConf

from pfp import DATA_DIRS, set_seeds
from pfp.common.visualization import RerunTraj
from pfp.common.visualization import RerunViewer as RV
from pfp.data.dataset_images import RobotDatasetImages
from pfp.data.dataset_pcd import RobotDatasetPcd
|
|
# Recorded data visualized by main(): DATA_DIRS.PFP_REAL / TASK_NAME / MODE.
TASK_NAME: str = "sponge_on_plate"  # task sub-directory under DATA_DIRS.PFP_REAL
MODE: str = "valid"  # split sub-directory inside the task folder
|
|
|
|
def _build_dataset(obs_mode, data_path, dataset_kwargs):
    """Instantiate the dataset class that matches ``obs_mode``.

    Args:
        obs_mode: ``"pcd"`` selects ``RobotDatasetPcd``; ``"rgb"`` selects
            ``RobotDatasetImages``.
        data_path: Directory of the recorded split to load.
        dataset_kwargs: Mapping of extra keyword arguments (``cfg.dataset``)
            forwarded to the dataset constructor.

    Returns:
        The constructed dataset.

    Raises:
        ValueError: If ``obs_mode`` is neither ``"pcd"`` nor ``"rgb"``.
    """
    if obs_mode == "pcd":
        return RobotDatasetPcd(data_path, **dataset_kwargs)
    if obs_mode == "rgb":
        return RobotDatasetImages(data_path, **dataset_kwargs)
    raise ValueError(f"Unknown observation mode: {obs_mode}")


def _log_dataset(dataset):
    """Log every sample of ``dataset`` to Rerun, one timestep per sample index."""
    obs_traj = RerunTraj()
    pred_traj = RerunTraj()
    # Index explicitly so the loop stops at len(dataset) and does not depend on
    # how __getitem__ handles out-of-range indices.
    for i in range(len(dataset)):
        # NOTE(review): assumes each sample is a 3-tuple whose first element is
        # a point-cloud history. Confirm this also holds for RobotDatasetImages
        # (obs_mode == "rgb"), whose first element may be images instead.
        pcd, robot_state_obs, robot_state_pred = dataset[i]
        rr.set_time_sequence("timestep", i)
        # Only the last entry along the first axis (presumably the most recent
        # frame) is drawn.
        RV.add_np_pointcloud("vis/pointcloud", pcd[-1])
        obs_traj.add_traj("vis/robot_state_obs", robot_state_obs, size=0.008)
        pred_traj.add_traj("vis/robot_state_pred", robot_state_pred, size=0.004)
        # Scalar plot of the last state column at the first predicted step.
        rr.log("plot/gripper_pred", rr.Scalar(robot_state_pred[0, -1]))


@hydra.main(version_base=None, config_path="../conf", config_name="train")
def main(cfg: DictConfig) -> None:
    """Visualize every sample of one recorded split in Rerun.

    Loads the Hydra ``train`` config from ``../conf``. It builds the dataset from
    the config's ``obs_mode`` and ``dataset`` entries, reading from
    ``DATA_DIRS.PFP_REAL / TASK_NAME / MODE``.

    Raises:
        ValueError: If ``cfg.obs_mode`` is not ``"pcd"`` or ``"rgb"``.
    """
    # Register the resolver only once; OmegaConf rejects re-registering a name.
    # NOTE(review): the builtin ``eval`` runs arbitrary expressions found in the
    # config, so only use this with trusted config files.
    if not OmegaConf.has_resolver("eval"):
        OmegaConf.register_new_resolver("eval", eval)
    OmegaConf.resolve(cfg)
    print(OmegaConf.to_yaml(cfg))
    set_seeds(cfg.seed)

    # MODE may name a non-training split (e.g. "valid"), so the names are generic.
    data_path = DATA_DIRS.PFP_REAL / TASK_NAME / MODE
    dataset = _build_dataset(cfg.obs_mode, data_path, cfg.dataset)

    RV("Dataset visualization")
    _log_dataset(dataset)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: the @hydra.main decorator supplies `cfg`, so main()
    # is called without arguments.
    main()
|
|