| | import utils |
| | import torch |
| | import numpy as np |
| | from pathlib import Path |
| | from datasets.core import TrajectoryDataset |
| |
|
| |
|
class SimKitchenTrajectoryDataset(TrajectoryDataset):
    """Trajectory dataset read from pre-exported files in ``data_directory``.

    Files read:

    * ``observations_seq.npy`` -- loaded into ``self.states``
    * ``actions_seq.npy`` -- loaded into ``self.actions``
    * ``onehot_goals.pth`` -- loaded into ``self.goals``
    * ``existence_mask.npy`` -- summed over axis 0 to give ``self.Ts``, the
      number of steps used for each trajectory
    * ``obses/{i:03d}.pth`` -- one observation tensor per trajectory ``i``

    Actions and goals are indexed as ``[trajectory, step]`` by the methods
    below, after passing through ``utils.transpose_batch_timestep``.

    NOTE(review): ``torch.load`` unpickles its input; only point this class at
    trusted data directories.
    """

    def __init__(self, data_directory, prefetch=True, onehot_goals=False):
        """Load states, actions, goals and trajectory lengths from disk.

        Args:
            data_directory: Directory containing the exported dataset files.
            prefetch: If true, load every trajectory's observation file into
                memory now; otherwise each file is loaded on every access.
            onehot_goals: If true, ``get_frames`` / ``__getitem__`` also
                return the goals for the requested frames.
        """
        self.data_directory = Path(data_directory)
        states = torch.from_numpy(np.load(self.data_directory / "observations_seq.npy"))
        actions = torch.from_numpy(np.load(self.data_directory / "actions_seq.npy"))
        goals = torch.load(self.data_directory / "onehot_goals.pth")

        # Presumably the raw arrays are stored time-major and this helper swaps
        # the first two axes so they index as [trajectory, step] -- TODO confirm
        # against utils.transpose_batch_timestep.
        self.states, self.actions, self.goals = utils.transpose_batch_timestep(
            states, actions, goals
        )

        # Summing the existence mask over axis 0 yields one length per trajectory.
        existence_mask = np.load(self.data_directory / "existence_mask.npy")
        self.Ts = existence_mask.sum(axis=0).astype(int).tolist()

        self.prefetch = prefetch
        if self.prefetch:
            # Keep every observation tensor in memory to avoid per-item disk reads.
            self.obses = [torch.load(self._obs_path(i)) for i in range(len(self.Ts))]
        self.onehot_goals = onehot_goals

    def _obs_path(self, idx):
        """Return the path of the observation file for trajectory ``idx``."""
        return self.data_directory / "obses" / f"{idx:03d}.pth"

    def get_seq_length(self, idx):
        """Return the number of steps of trajectory ``idx``."""
        return self.Ts[idx]

    def get_all_actions(self):
        """Concatenate the first ``Ts[i]`` actions of every trajectory along dim 0."""
        return torch.cat(
            [self.actions[i, :T, :] for i, T in enumerate(self.Ts)], dim=0
        )

    def get_frames(self, idx, frames):
        """Return the data for the given ``frames`` of trajectory ``idx``.

        Args:
            idx: Trajectory index.
            frames: Step indices to select (``__getitem__`` passes a ``range``).

        Returns:
            ``(obs, act, mask)``, or ``(obs, act, mask, goal)`` when the dataset
            was built with ``onehot_goals=True``. ``obs`` is divided by 255.0
            and ``mask`` is a tensor of ones with one entry per frame.
        """
        if self.prefetch:
            obs = self.obses[idx][frames]
        else:
            obs = torch.load(self._obs_path(idx))[frames]
        # Presumably rescales 0-255 pixel values to [0, 1] -- confirm stored dtype.
        obs = obs / 255.0
        act = self.actions[idx, frames]
        mask = torch.ones(len(frames))
        if self.onehot_goals:
            return obs, act, mask, self.goals[idx, frames]
        return obs, act, mask

    def __getitem__(self, idx):
        """Return all ``Ts[idx]`` frames of trajectory ``idx`` via ``get_frames``."""
        return self.get_frames(idx, range(self.Ts[idx]))

    def __len__(self):
        """Return the number of trajectories."""
        return len(self.Ts)