# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# NoMaD, GNM, ViNT: https://github.com/robodhruv/visualnav-transformer
# --------------------------------------------------------
import numpy as np
import torch
import os
from PIL import Image
from typing import Optional, Tuple
import yaml
import pickle
import tqdm
from torch.utils.data import Dataset
from misc import angle_difference, get_data_path, get_delta_np, normalize_data, to_local_coords
import torchaudio


class BaseDataset(Dataset):
    """Shared trajectory indexing and action/goal computation for the navigation datasets.

    Each trajectory lives in ``data_folder/<traj_name>/`` and provides a
    ``traj_data.pkl`` whose ``position``, ``yaw`` and ``distance_to_target``
    entries are read below. Frames and audio clips are located through
    ``misc.get_data_path``.
    """

    def __init__(
        self,
        data_folder: str,
        data_split_folder: str,
        dataset_name: str,
        image_size: Tuple[int, int],
        min_dist_cat: int,
        max_dist_cat: int,
        len_traj_pred: int,
        traj_stride: int,
        context_size: int,
        transform: object,
        traj_names: str,
        normalize: bool = True,
        predefined_index: Optional[str] = None,
        goals_per_obs: int = 1,
    ):
        """
        Args:
            data_folder: root folder containing one sub-folder per trajectory.
            data_split_folder: folder holding the trajectory-name list and where the
                built index is cached.
            dataset_name: key of this dataset in ``config/data_config.yaml``.
            image_size: stored as ``self.image_size`` (not used in this class).
            min_dist_cat / max_dist_cat: inclusive range of goal offsets (in steps).
            len_traj_pred: number of future steps predicted.
            traj_stride: step between consecutive sampled current times.
            context_size: number of context frames (including the current one).
            transform: callable applied to every loaded PIL image.
            traj_names: file name (inside ``data_split_folder``) listing trajectories,
                one per line.
            normalize: divide metric quantities by ``metric_waypoint_spacing``.
            predefined_index: optional path to a pickled index to use instead of
                building one.
            goals_per_obs: number of goals sampled per observation.
        """
        self.data_folder = data_folder
        self.data_split_folder = data_split_folder
        self.dataset_name = dataset_name
        self.goals_per_obs = goals_per_obs

        # One trajectory name per line. Skip every blank line (not just the first)
        # and strip surrounding whitespace such as a trailing "\r" from CRLF files.
        traj_names_file = os.path.join(data_split_folder, traj_names)
        with open(traj_names_file, "r") as f:
            self.traj_names = [line.strip() for line in f if line.strip()]

        self.image_size = image_size
        self.distance_categories = list(range(min_dist_cat, max_dist_cat + 1))
        self.min_dist_cat = self.distance_categories[0]
        self.max_dist_cat = self.distance_categories[-1]
        self.len_traj_pred = len_traj_pred
        self.traj_stride = traj_stride
        self.context_size = context_size
        self.normalize = normalize

        # NOTE: path is relative to the current working directory.
        with open("config/data_config.yaml", "r") as f:
            all_data_config = yaml.safe_load(f)
        self.data_config = all_data_config[self.dataset_name]
        self.transform = transform
        self._load_index(predefined_index)

        # Normalization statistics from the YAML, each given a leading axis
        # (presumably so they broadcast over rows in normalize_data).
        self.ACTION_STATS = {
            key: np.expand_dims(value, axis=0)
            for key, value in all_data_config["action_stats"].items()
        }
        self.DISTANCE_DIFF_STATS = {
            key: np.expand_dims(value, axis=0)
            for key, value in all_data_config["distance_diff_stats"].items()
        }

    def _load_index(self, predefined_index) -> None:
        """Populate ``self.index_to_data`` with sample tuples.

        Each sample is ``(traj_name, curr_time, min_goal_distance, max_goal_distance)``.
        If ``predefined_index`` is set, it is treated as a path to a pickle file. That
        file may hold either a bare sample list or the
        ``(samples_index, goals_index)`` pair written by this method. Otherwise the
        index is rebuilt from the trajectories. The rebuilt index is dumped together
        with ``self.goals_index`` to ``data_split_folder``.
        """
        if predefined_index:
            print(f"****** Using a predefined evaluation index... {predefined_index}******")
            # NOTE: pickle.load can execute arbitrary code; only pass trusted files.
            with open(predefined_index, "rb") as f:
                loaded = pickle.load(f)
            # Accept the (samples_index, goals_index) pair produced by the branch
            # below as well as a plain list of samples.
            if isinstance(loaded, tuple) and len(loaded) == 2 and isinstance(loaded[0], list):
                self.index_to_data, self.goals_index = loaded
            else:
                self.index_to_data = loaded
            return

        print("****** Evaluating from NON PREDEFINED index... ******")
        index_to_data_path = os.path.join(
            self.data_split_folder,
            f"dataset_dist_{self.min_dist_cat}_to_{self.max_dist_cat}_n{self.context_size}_len_traj_pred_{self.len_traj_pred}.pkl",
        )
        self.index_to_data, self.goals_index = self._build_index()
        with open(index_to_data_path, "wb") as f:
            pickle.dump((self.index_to_data, self.goals_index), f)

    def _build_index(self, use_tqdm: bool = False):
        """Build the sample index and the goal index over all trajectories.

        Returns:
            samples_index: list of ``(traj_name, curr_time, min_goal_distance,
                max_goal_distance)``. The distances are clamped so that
                ``curr_time + offset`` stays inside the trajectory.
            goals_index: list of ``(traj_name, goal_time)`` for every timestep.
        """
        samples_index = []
        goals_index = []

        for traj_name in tqdm.tqdm(self.traj_names, disable=not use_tqdm, dynamic_ncols=True):
            traj_data = self._get_trajectory(traj_name)
            traj_len = len(traj_data["position"])

            goals_index.extend((traj_name, goal_time) for goal_time in range(traj_len))

            # Leave room for context_size - 1 past frames and len_traj_pred future steps.
            begin_time = self.context_size - 1
            end_time = traj_len - self.len_traj_pred
            for curr_time in range(begin_time, end_time, self.traj_stride):
                max_goal_distance = min(self.max_dist_cat, traj_len - curr_time - 1)
                min_goal_distance = max(self.min_dist_cat, -curr_time)
                samples_index.append((traj_name, curr_time, min_goal_distance, max_goal_distance))

        return samples_index, goals_index

    def _get_trajectory(self, trajectory_name):
        """Load ``traj_data.pkl`` for one trajectory and cast every entry to float."""
        with open(os.path.join(self.data_folder, trajectory_name, "traj_data.pkl"), "rb") as f:
            traj_data = pickle.load(f)
        for key, value in traj_data.items():
            traj_data[key] = value.astype("float")
        return traj_data

    def _load_image(self, traj_name, time):
        """Open one frame, apply ``self.transform`` and close the underlying file."""
        with Image.open(get_data_path(self.data_folder, traj_name, time)) as img:
            return self.transform(img)

    def _load_audio(self, traj_name, time):
        """Load one audio clip with torchaudio and return only its waveform tensor."""
        return torchaudio.load(get_data_path(self.data_folder, traj_name, time, data_type="audio"))[0]

    def __len__(self) -> int:
        return len(self.index_to_data)

    def _compute_actions(self, traj_data, curr_time, goal_time):
        """Compute actions, goal poses and distance-to-target differences in the current frame.

        Args:
            traj_data: dict returned by ``_get_trajectory``.
            curr_time: index of the current (last context) step.
            goal_time: integer array of goal step indices.

        Returns:
            actions: one row per step ``curr_time + 1 .. curr_time + len_traj_pred``:
                local position (from ``to_local_coords``) followed by yaw relative to
                ``curr_time``.
            goal_pos: one row per goal, with local position followed by relative yaw.
            diffs_seq: ``(len_traj_pred, 1)`` values of
                ``distance_to_target[curr_time] - distance_to_target[t]``.
            goal_diff: ``(num_goals, 1)`` of the same difference at each goal.
            When ``self.normalize`` is set, positions and distance differences are
            divided by ``metric_waypoint_spacing``.

        Raises:
            ValueError: if the yaw window does not contain ``len_traj_pred + 1`` steps.
        """
        start_index = curr_time
        end_index = curr_time + self.len_traj_pred + 1
        yaw = traj_data["yaw"][start_index:end_index]
        positions = traj_data["position"][start_index:end_index]
        goal_pos = traj_data["position"][goal_time]
        goal_yaw = traj_data["yaw"][goal_time]
        dist_window = traj_data["distance_to_target"][start_index:end_index]
        goal_dist = traj_data["distance_to_target"][goal_time]

        if len(yaw.shape) == 2:
            yaw = yaw.squeeze(1)

        if yaw.shape != (self.len_traj_pred + 1,):
            raise ValueError(
                f"Expected a yaw window of shape ({self.len_traj_pred + 1},) starting at "
                f"t={curr_time}, got {yaw.shape}"
            )

        # Express future waypoints and goals relative to the pose at curr_time.
        waypoints_pos = to_local_coords(positions, positions[0], yaw[0])
        waypoints_yaw = angle_difference(yaw[0], yaw)
        actions = np.concatenate([waypoints_pos, waypoints_yaw.reshape(-1, 1)], axis=-1)
        actions = actions[1:]  # drop the row for curr_time itself

        goal_pos = to_local_coords(goal_pos, positions[0], yaw[0])
        goal_yaw = angle_difference(yaw[0], goal_yaw)

        # Positive when distance_to_target has decreased since curr_time.
        diffs_seq = (dist_window[0] - dist_window).reshape(-1, 1)[1:]
        goal_diff = (dist_window[0] - goal_dist).reshape(-1, 1)

        if self.normalize:
            spacing = self.data_config["metric_waypoint_spacing"]
            actions[:, :2] /= spacing
            goal_pos[:, :2] /= spacing
            diffs_seq /= spacing
            goal_diff /= spacing

        goal_pos = np.concatenate([goal_pos, goal_yaw.reshape(-1, 1)], axis=-1)

        return actions, goal_pos, diffs_seq, goal_diff


class TrainingDataset(BaseDataset):
    """Samples context and goal frames with audio, goal pose, goal distance diff and relative time."""

    def __init__(
        self,
        data_folder: str,
        data_split_folder: str,
        dataset_name: str,
        image_size: Tuple[int, int],
        min_dist_cat: int,
        max_dist_cat: int,
        len_traj_pred: int,
        traj_stride: int,
        context_size: int,
        transform: object,
        traj_names: str = 'traj_names.txt',
        normalize: bool = True,
        predefined_index: Optional[str] = None,
        goals_per_obs: int = 1,
        sample_rate: int = 16000,
        input_sr: int = 48000,
        evaluate: bool = False,
    ):
        """See ``BaseDataset``.

        Additional args:
            sample_rate: ``orig_freq`` of the audio resampler.
            input_sr: ``new_freq`` of the audio resampler.
            evaluate: also return the un-resampled audio.
        """
        super().__init__(data_folder, data_split_folder, dataset_name, image_size, min_dist_cat, max_dist_cat, len_traj_pred, traj_stride, context_size, transform, traj_names, normalize, predefined_index, goals_per_obs)
        self.resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=input_sr, lowpass_filter_width=64)
        self.evaluate = evaluate

        # Samples with an empty goal-offset range (min > max) cannot be drawn from:
        # np.random.randint would raise for them on every access in __getitem__.
        self.index_to_data = [
            sample for sample in self.index_to_data if sample[2] <= sample[3]
        ]

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, ...]:
        """Return ``(obs_image, obs_audio, goal_pos, goal_diff, rel_time)`` for sample ``i``.

        The stacked frames and audio clips are the ``context_size`` context steps
        followed by ``goals_per_obs`` randomly drawn goal steps. When
        ``self.evaluate`` is set, the un-resampled audio is appended as a sixth
        element.
        """
        try:
            f_curr, curr_time, min_goal_dist, max_goal_dist = self.index_to_data[i]
            goal_offset = np.random.randint(min_goal_dist, max_goal_dist + 1, size=(self.goals_per_obs))
            goal_time = (curr_time + goal_offset).astype('int')
            rel_time = goal_offset.astype('float') / 128.0  # TODO: refactor, currently a fixed const

            context_times = list(range(curr_time - self.context_size + 1, curr_time + 1))
            context = [(f_curr, t) for t in context_times] + [(f_curr, t) for t in goal_time]

            obs_image = torch.stack([self._load_image(f, t) for f, t in context])
            orig_obs_audio = torch.stack([self._load_audio(f, t) for f, t in context])
            obs_audio = self.resampler(orig_obs_audio)

            curr_traj_data = self._get_trajectory(f_curr)
            _, goal_pos, _, goal_diff = self._compute_actions(curr_traj_data, curr_time, goal_time)
            goal_pos[:, :2] = normalize_data(goal_pos[:, :2], self.ACTION_STATS)
            goal_diff = normalize_data(goal_diff, self.DISTANCE_DIFF_STATS)

            sample = (
                torch.as_tensor(obs_image, dtype=torch.float32),
                torch.as_tensor(obs_audio, dtype=torch.float32),
                torch.as_tensor(goal_pos, dtype=torch.float32),
                torch.as_tensor(goal_diff, dtype=torch.float32),
                torch.as_tensor(rel_time, dtype=torch.float32),
            )
            if self.evaluate:
                sample += (torch.as_tensor(orig_obs_audio, dtype=torch.float32),)
            return sample
        except Exception as e:
            print(f"Exception in {self.dataset_name}", e)
            # Bare raise keeps the original exception type and traceback.
            raise


class EvalDataset(BaseDataset):
    """Returns context and future frames/audio with normalized action and distance deltas."""

    def __init__(
        self,
        data_folder: str,
        data_split_folder: str,
        dataset_name: str,
        image_size: Tuple[int, int],
        min_dist_cat: int,
        max_dist_cat: int,
        len_traj_pred: int,
        traj_stride: int,
        context_size: int,
        transform: object,
        traj_names: str,
        normalize: bool = True,
        predefined_index: Optional[str] = None,
        goals_per_obs: int = 1,
        sample_rate: int = 16000,
        input_sr: int = 48000,
    ):
        """See ``BaseDataset``; ``sample_rate`` / ``input_sr`` configure the audio resampler."""
        super().__init__(data_folder, data_split_folder, dataset_name, image_size, min_dist_cat, max_dist_cat, len_traj_pred, traj_stride, context_size, transform, traj_names, normalize, predefined_index, goals_per_obs)
        self.resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=input_sr, lowpass_filter_width=64)

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, ...]:
        """Return the evaluation tuple for sample ``i``.

        The tuple is ``(index, obs_image, pred_image, obs_audio, pred_audio,
        diffs_seq_delta, action_delta, orig_obs_audio, orig_pred_audio)``.
        """
        try:
            f_curr, curr_time, _, _ = self.index_to_data[i]
            context_times = list(range(curr_time - self.context_size + 1, curr_time + 1))
            pred_times = list(range(curr_time + 1, curr_time + self.len_traj_pred + 1))

            obs_image = torch.stack([self._load_image(f_curr, t) for t in context_times])
            pred_image = torch.stack([self._load_image(f_curr, t) for t in pred_times])
            orig_obs_audio = torch.stack([self._load_audio(f_curr, t) for t in context_times])
            orig_pred_audio = torch.stack([self._load_audio(f_curr, t) for t in pred_times])
            obs_audio = self.resampler(orig_obs_audio)
            pred_audio = self.resampler(orig_pred_audio)

            curr_traj_data = self._get_trajectory(f_curr)
            # The goal argument is a dummy; only actions and diffs_seq are used here.
            actions, _, diffs_seq, _ = self._compute_actions(curr_traj_data, curr_time, np.array([curr_time + 1]))
            actions[:, :2] = normalize_data(actions[:, :2], self.ACTION_STATS)
            diffs_seq = normalize_data(diffs_seq, self.DISTANCE_DIFF_STATS)

            delta = get_delta_np(actions)
            diffs_seq = get_delta_np(diffs_seq)

            return (
                torch.tensor([i], dtype=torch.float32),  # for logging purposes
                torch.as_tensor(obs_image, dtype=torch.float32),
                torch.as_tensor(pred_image, dtype=torch.float32),
                torch.as_tensor(obs_audio, dtype=torch.float32),
                torch.as_tensor(pred_audio, dtype=torch.float32),
                torch.as_tensor(diffs_seq, dtype=torch.float32),
                torch.as_tensor(delta, dtype=torch.float32),
                torch.as_tensor(orig_obs_audio, dtype=torch.float32),
                torch.as_tensor(orig_pred_audio, dtype=torch.float32),
            )
        except Exception as e:
            print(f"Exception in {self.dataset_name}", e)
            # Bare raise keeps the original exception type and traceback.
            raise