| | """ |
| | This file will compute the min, max, mean, and standard deviation of each datasets |
| | in `pretrain_datasets.json` or `pretrain_datasets.json`. |
| | """ |

import json
import argparse

import numpy as np
from tqdm import tqdm

from data.hdf5_vla_dataset import AnubisHDF5VLADataset as HDF5VLADataset


def process_hdf5_dataset(vla_dataset):
    """Accumulate per-dimension state statistics over all episodes."""
    EPS = 1e-8
    episode_cnt = 0
    # Running sums for the mean/std; the z_* variants are taken over states
    # whose near-zero entries have been zeroed out.
    state_sum = 0
    state_sum_sq = 0
    z_state_sum = 0
    z_state_sum_sq = 0
    state_cnt = 0
    nz_state_cnt = None
    state_max = None
    state_min = None
    for i in tqdm(range(len(vla_dataset))):
        episode = vla_dataset.get_item(i, state_only=True)
        episode_cnt += 1

        states = episode['state']

        # Zero out near-zero entries so that padded or inactive dimensions
        # do not contribute to the standard deviation.
        z_states = states.copy()
        z_states[np.abs(states) <= EPS] = 0

        # Count, per dimension, how many samples are non-zero.
        if nz_state_cnt is None:
            nz_state_cnt = np.zeros(states.shape[1])
        nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)

        state_sum += np.sum(states, axis=0)
        state_sum_sq += np.sum(states**2, axis=0)
        z_state_sum += np.sum(z_states, axis=0)
        z_state_sum_sq += np.sum(z_states**2, axis=0)
        state_cnt += states.shape[0]
        if state_max is None:
            state_max = np.max(states, axis=0)
            state_min = np.min(states, axis=0)
        else:
            state_max = np.maximum(state_max, np.max(states, axis=0))
            state_min = np.minimum(state_min, np.min(states, axis=0))

    # Avoid division by zero for dimensions that are (near-)zero everywhere.
    nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))

    result = {
        "dataset_name": vla_dataset.get_dataset_name(),
        "state_mean": (state_sum / state_cnt).tolist(),
        # Per-dimension std of the zeroed states: the expression equals
        # sqrt((sum(x^2) - sum(x)^2 / N) / N_nz), i.e. the squared deviations
        # from the overall mean normalized by the non-zero sample count.
        # np.maximum guards against tiny negative values from rounding.
        "state_std": np.sqrt(
            np.maximum(
                (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
                np.zeros_like(state_sum_sq)
            )
        ).tolist(),
        "state_min": state_min.tolist(),
        "state_max": state_max.tolist(),
    }

    return result
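

# Optional sanity-check sketch (illustrative; `_check_std_formula` is not
# part of the original script). For dimensions where every sample is
# non-zero, the streaming formula above reduces to the ordinary population
# standard deviation; for always-zero dimensions it yields 0.
def _check_std_formula():
    rng = np.random.default_rng(0)
    states = rng.normal(size=(100, 4))
    states[:, 3] = 0.0  # simulate a padded, always-zero dimension
    eps = 1e-8
    n = states.shape[0]
    # Clamp the non-zero count, as process_hdf5_dataset does.
    nz = np.maximum(np.sum(np.abs(states) > eps, axis=0), 1)
    var = (np.sum(states**2, axis=0) - np.sum(states, axis=0)**2 / n) / nz
    std = np.sqrt(np.maximum(var, 0))
    assert np.allclose(std[:3], states[:, :3].std(axis=0))
    assert std[3] == 0.0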


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_path', type=str,
                        default="configs/dataset_stat.json",
                        help="JSON file path to save the dataset statistics.")
    parser.add_argument('--skip_exist', action='store_true',
                        help="Whether to skip datasets whose statistics already exist.")
    parser.add_argument('--dataset', type=str,
                        help="The HDF5 dataset to compute statistics for.")
    args = parser.parse_args()

    vla_dataset = HDF5VLADataset(args.dataset)
    dataset_name = vla_dataset.get_dataset_name()

    # Load any existing statistics so that results accumulate across runs.
    try:
        with open(args.save_path, 'r') as f:
            results = json.load(f)
    except FileNotFoundError:
        results = {}
    if args.skip_exist and dataset_name in results:
        print(f"Skipping existing {dataset_name} dataset statistics")
    else:
        print(f"Processing {dataset_name} dataset")
        result = process_hdf5_dataset(vla_dataset)
        results[result["dataset_name"]] = result
        with open(args.save_path, 'w') as f:
            json.dump(results, f, indent=4)
    print("All datasets have been processed.")