Spaces:
Sleeping
Sleeping
# Standalone Hugging Face Space viewer for TrajectoryBuffer-style HDF5 files.
#
# requirements.txt:
#   gradio
#   huggingface_hub
#   h5py
#   numpy
#   pillow
#   matplotlib
#   imageio
#   imageio-ffmpeg
#
# Optional:
#   opencv-python-headless
| import os | |
| import re | |
| import tempfile | |
| from functools import lru_cache | |
| import gradio as gr | |
| import h5py | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image, ImageDraw | |
| try: | |
| import imageio.v2 as imageio | |
| except Exception: | |
| imageio = None | |
| try: | |
| import cv2 | |
| except Exception: | |
| cv2 = None | |
# Hub datasets offered in the preset dropdown (insertion order is kept).
# "default_reverse_channels" is the initial value of the BGR<->RGB toggle.
DATASET_PRESETS = {
    "Robosuite Square Correction": dict(
        repo_id="Zhaoting123/Robosuite_Square_image_abs_with_state",
        filename=(
            "20260410_205606_Diffusion_CLIC_intervention_Circular_square_image_abs_"
            "Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
        ),
        default_reverse_channels=False,
    ),
    "InsertT Demonstration": dict(
        repo_id="Zhaoting123/InsertT",
        filename="trajectory_buffer_Nov10_demo.hdf5",
        default_reverse_channels=True,
    ),
    "InsertT Correction": dict(
        repo_id="Zhaoting123/InsertT",
        filename="trajectory_buffer_Nov11_intervention.hdf5",
        default_reverse_channels=True,
    ),
    "RoundTable Correction": dict(
        repo_id="Zhaoting123/Furniture_Bench_Round_Table_Assembly",
        filename="trajectory_buffer_0_Nov24_intervention_relabeled.hdf5",
        default_reverse_channels=True,
    ),
}
# Fallback used for empty or unknown preset names.
DEFAULT_PRESET = "Robosuite Square Correction"
# Hub repository type passed to hf_hub_download.
REPO_TYPE = "dataset"

DEFAULT_CHUNK_LEN = 16
DEFAULT_DISPLAY_SCALE = 1

# Size of the compact status plot embedded in exported videos.
VIDEO_STATUS_FIGSIZE = (6.0, 1.8)
VIDEO_STATUS_DPI = 120

# Image observation keys listed first (in this order) when present.
PREFERRED_IMAGE_KEYS = [
    "image1",
    "image2",
    "agentview_image",
    "robot0_eye_in_hand_image",
    "front_image",
    "wrist_image",
]
# Lower-case substrings that suggest an observation key holds an image.
IMAGE_KEY_HINTS = ["rgb", "image", "img", "camera", "cam"]
def resolve_dataset(preset_name, custom_repo_id=None, custom_filename=None):
    """Map the UI selection to a ``(repo_id, filename)`` pair on the Hub.

    ``"Custom"`` uses the two free-text inputs, which must both be non-blank
    after stripping. Any other name selects its DATASET_PRESETS entry; empty or
    unknown names fall back to DEFAULT_PRESET.

    Raises:
        ValueError: in Custom mode when the repo id or filename is missing.
    """
    selected = preset_name or DEFAULT_PRESET
    if selected != "Custom":
        preset = DATASET_PRESETS.get(selected, DATASET_PRESETS[DEFAULT_PRESET])
        return preset["repo_id"], preset["filename"]
    repo = str(custom_repo_id or "").strip()
    path_in_repo = str(custom_filename or "").strip()
    if repo and path_in_repo:
        return repo, path_in_repo
    raise ValueError("For Custom mode, provide both repo_id and HDF5 filename/path.")
def get_default_reverse_channels(preset_name):
    """Return the initial BGR<->RGB reversal flag for a dataset preset.

    The value is the preset's ``default_reverse_channels`` entry: False for
    the Robosuite Square preset, True for the InsertT and RoundTable presets.
    Empty or unknown names use DEFAULT_PRESET. ``"Custom"`` always starts at
    False; the user can still toggle it manually.
    """
    selected = preset_name or DEFAULT_PRESET
    if selected == "Custom":
        return False
    preset = DATASET_PRESETS.get(selected, DATASET_PRESETS[DEFAULT_PRESET])
    return bool(preset.get("default_reverse_channels", False))
def get_local_hdf5_path(repo_id, filename):
    """Return a local filesystem path to ``filename`` from the Hub repo ``repo_id``.

    Delegates to ``hf_hub_download`` with the module-level REPO_TYPE.
    """
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type=REPO_TYPE,
    )
    return local_path
| def _natural_sort_key(name): | |
| match = re.search(r"([0-9]+)$", str(name)) | |
| if match: | |
| return 0, int(match.group(1)) | |
| return 1, str(name) | |
@lru_cache(maxsize=16)
def get_trajectory_keys(repo_id, filename):
    """Return the HDF5 group paths of all trajectories in a file, naturally sorted.

    Layouts are tried in order:
      1. root-level groups whose name starts with ``episode_``;
      2. groups under ``/data``, returned as ``"data/<key>"``;
      3. every root-level group.

    The result is an immutable tuple and is memoised per ``(repo_id, filename)``.
    Almost every UI callback, and every ``load_traj`` call, asks for it. Without
    the cache each call re-resolves the Hub file and re-opens the HDF5 file.
    A file replaced on the Hub while the process runs will therefore keep its
    old key list until restart.
    """
    path = get_local_hdf5_path(repo_id, filename)
    with h5py.File(path, "r") as f:
        root_episode_keys = [
            key for key in f.keys()
            if isinstance(f[key], h5py.Group) and str(key).startswith("episode_")
        ]
        if root_episode_keys:
            return tuple(sorted(root_episode_keys, key=_natural_sort_key))

        if "data" in f and isinstance(f["data"], h5py.Group):
            data_group = f["data"]
            keys = [key for key in data_group.keys() if isinstance(data_group[key], h5py.Group)]
            return tuple("data/" + key for key in sorted(keys, key=_natural_sort_key))

        keys = [key for key in f.keys() if isinstance(f[key], h5py.Group)]
        return tuple(sorted(keys, key=_natural_sort_key))
def get_num_trajectories(repo_id, filename):
    """Return how many trajectory groups ``get_trajectory_keys`` finds in the file."""
    trajectory_keys = get_trajectory_keys(repo_id, filename)
    return len(trajectory_keys)
def inspect_hdf5_tree(preset_name, custom_repo_id, custom_filename, max_lines=180):
    """Return a text listing of the groups and datasets in the selected HDF5 file.

    Each object gets one line, ``GROUP <path>`` or
    ``DATASET <path> shape=... dtype=...``, in h5py ``visititems`` order. At
    most ``max_lines`` objects are listed. When the file contains more, the
    traversal stops early and a final ``...`` line marks the truncation.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    path = get_local_hdf5_path(repo_id, filename)
    limit = max(0, int(max_lines))
    lines = []
    truncated = False

    def visitor(name, obj):
        nonlocal truncated
        if len(lines) >= limit:
            # visititems stops as soon as the callback returns a non-None
            # value, so large files are not walked in full once the listing
            # is already complete.
            truncated = True
            return True
        if isinstance(obj, h5py.Dataset):
            lines.append("DATASET {} shape={} dtype={}".format(name, obj.shape, obj.dtype))
        elif isinstance(obj, h5py.Group):
            lines.append("GROUP {}".format(name))
        return None

    with h5py.File(path, "r") as f:
        f.visititems(visitor)

    # Only mark truncation when an object was actually left out.
    if truncated:
        lines.append("...")
    return "\n".join(lines) if lines else "No HDF5 contents found."
| def _read_dataset_value(dataset): | |
| value = dataset[()] | |
| if isinstance(value, bytes): | |
| return value.decode("utf-8") | |
| return value | |
def _read_group_recursive(group):
    """Load an HDF5 group into nested dicts.

    Sub-groups become dicts and datasets are read fully. Objects that are
    neither a group nor a dataset are skipped.
    """
    return {
        key: (
            _read_group_recursive(member)
            if isinstance(member, h5py.Group)
            else _read_dataset_value(member)
        )
        for key, member in group.items()
        if isinstance(member, (h5py.Dataset, h5py.Group))
    }
| def _find_first_key(mapping, candidate_keys): | |
| for key in candidate_keys: | |
| if key in mapping: | |
| return key | |
| return None | |
| def _infer_time_length(data): | |
| for key in ["timesteps", "dones", "robot_actions", "teacher_actions", "actions"]: | |
| if key in data: | |
| arr = np.asarray(data[key]) | |
| if arr.ndim >= 1: | |
| return int(arr.shape[0]) | |
| obs_group = None | |
| if isinstance(data.get("observation"), dict): | |
| obs_group = data["observation"] | |
| elif isinstance(data.get("obs"), dict): | |
| obs_group = data["obs"] | |
| if obs_group: | |
| lengths = [] | |
| for value in obs_group.values(): | |
| arr = np.asarray(value) | |
| if arr.ndim >= 1: | |
| lengths.append(int(arr.shape[0])) | |
| if lengths: | |
| values, counts = np.unique(lengths, return_counts=True) | |
| return int(values[np.argmax(counts)]) | |
| return 1 | |
| def _slice_time(value, t, T): | |
| arr = np.asarray(value) | |
| if arr.ndim >= 1 and arr.shape[0] == T: | |
| return arr[t] | |
| return arr | |
@lru_cache(maxsize=2)
def _load_traj_cached(repo_id, filename, traj_id):
    """Read and unpack one trajectory; memoised worker behind ``load_traj``.

    Returns a tuple so the cached sequence itself cannot be mutated. Only two
    trajectories are kept because each entry holds every observation of a
    trajectory, images included, in memory.
    """
    traj_keys = get_trajectory_keys(repo_id, filename)
    if not traj_keys:
        return ()
    traj_id = int(np.clip(int(traj_id), 0, len(traj_keys) - 1))
    traj_key = traj_keys[traj_id]

    path = get_local_hdf5_path(repo_id, filename)
    with h5py.File(path, "r") as f:
        data = _read_group_recursive(f[traj_key])

    T = _infer_time_length(data)
    if isinstance(data.get("observation"), dict):
        obs_all = data["observation"]
    elif isinstance(data.get("obs"), dict):
        obs_all = data["obs"]
    else:
        obs_all = {}

    # Both plural and singular spellings are accepted for every field.
    action_key = _find_first_key(data, ["actions", "action"])
    teacher_key = _find_first_key(data, ["teacher_actions", "teacher_action"])
    robot_key = _find_first_key(data, ["robot_actions", "robot_action"])
    no_teacher_key = _find_first_key(data, ["no_teacher_actions", "no_teacher_action"])
    no_robot_key = _find_first_key(data, ["no_robot_actions", "no_robot_action"])
    done_key = _find_first_key(data, ["dones", "done"])
    timestep_key = _find_first_key(data, ["timesteps", "timestep"])
    success_key = _find_first_key(data, ["if_success", "success", "successes"])

    traj = []
    for t in range(T):
        obs_t = {key: _slice_time(value, t, T) for key, value in obs_all.items()}

        # Teacher/robot actions fall back to the generic action, else zeros(1).
        default_action = np.zeros(1, dtype=np.float32)
        if action_key is not None:
            default_action = _slice_time(data[action_key], t, T)
        teacher_action = _slice_time(data[teacher_key], t, T) if teacher_key else default_action
        robot_action = _slice_time(data[robot_key], t, T) if robot_key else default_action

        no_teacher = _slice_time(data[no_teacher_key], t, T) if no_teacher_key else False
        no_robot = _slice_time(data[no_robot_key], t, T) if no_robot_key else False
        done = _slice_time(data[done_key], t, T) if done_key else False
        if_success = _slice_time(data[success_key], t, T) if success_key else False

        timestep = t
        if timestep_key is not None:
            timestep_arr = _slice_time(data[timestep_key], t, T)
            timestep = int(np.asarray(timestep_arr).reshape(-1)[0])

        traj.append({
            "obs": obs_t,
            "robot_action": np.asarray(robot_action),
            "teacher_action": np.asarray(teacher_action),
            "done": bool(np.asarray(done).reshape(-1)[0]),
            "timestep": timestep,
            "no_robot_action": bool(np.asarray(no_robot).reshape(-1)[0]),
            "no_teacher_action": bool(np.asarray(no_teacher).reshape(-1)[0]),
            "episode_id": traj_key,
            "if_success": bool(np.asarray(if_success).reshape(-1)[0]),
        })
    return tuple(traj)


def load_traj(repo_id, filename, traj_id):
    """Load one trajectory as a list of per-timestep dicts.

    ``traj_id`` is clipped into range. An empty list is returned when the file
    has no trajectories. Each step dict has these keys:
      obs                  dict of observation values sliced at this step
      robot_action         robot action (falls back to actions/action, else zeros(1))
      teacher_action       teacher action (same fallback)
      done, no_robot_action, no_teacher_action, if_success
                           first element of the stored flag as bool (False if absent)
      timestep             stored timestep value, else the step index
      episode_id           HDF5 path of the trajectory group

    Results are memoised by ``_load_traj_cached``. The viewer calls this once
    per rendered frame, and without the cache every call re-read the whole
    trajectory from the HDF5 file. Each call returns a fresh list, but the step
    dicts are shared between calls and must be treated as read-only.
    """
    return list(_load_traj_cached(repo_id, filename, int(traj_id)))
| def _extract_latest_obs_value(value): | |
| """Return the latest stacked observation only when there is a clear stack axis. | |
| Important: | |
| - [obs_T, C, H, W] or [obs_T, H, W, C] should become the latest frame. | |
| - [C, H, W] must NOT be sliced, otherwise an RGB image becomes one | |
| grayscale channel. | |
| """ | |
| arr = np.asarray(value) | |
| # Stacked image observations, e.g. [obs_T, C, H, W] or [obs_T, H, W, C]. | |
| if arr.ndim == 4 and arr.shape[0] in (1, 2, 3, 4): | |
| channel_first = arr.shape[1] in (1, 3, 4) | |
| channel_last = arr.shape[-1] in (1, 3, 4) | |
| if channel_first or channel_last: | |
| return arr[-1] | |
| # Stacked vector observations, e.g. [obs_T, D]. Keep this for non-image obs. | |
| if arr.ndim == 2 and arr.shape[0] in (1, 2): | |
| return arr[-1] | |
| return arr | |
| def _looks_like_image_array(key, value): | |
| arr = np.asarray(value) | |
| key_l = str(key).lower() | |
| key_hint = any(hint in key_l for hint in IMAGE_KEY_HINTS) | |
| # Remove only a clear stacked-image axis for shape detection. | |
| if arr.ndim == 4 and arr.shape[0] in (1, 2, 3, 4): | |
| if arr.shape[1] in (1, 3, 4) or arr.shape[-1] in (1, 3, 4): | |
| arr = arr[-1] | |
| shape_hint = False | |
| if arr.ndim == 2: | |
| shape_hint = True | |
| elif arr.ndim == 3: | |
| shape_hint = arr.shape[-1] in (1, 3, 4) or arr.shape[0] in (1, 3, 4) | |
| elif arr.ndim == 4: | |
| shape_hint = arr.shape[1] in (1, 3, 4) or arr.shape[-1] in (1, 3, 4) | |
| return key_hint or shape_hint | |
| def _float_img_to_uint8(img): | |
| arr = img.astype(np.float32) | |
| arr_min = float(np.nanmin(arr)) | |
| arr_max = float(np.nanmax(arr)) | |
| if arr_min >= -1.01 and arr_max <= 1.01: | |
| if arr_min < 0.0: | |
| arr = (arr + 1.0) * 0.5 | |
| arr = np.clip(arr, 0.0, 1.0) * 255.0 | |
| elif arr_max <= 255.0: | |
| arr = np.clip(arr, 0.0, 255.0) | |
| else: | |
| arr = 255.0 * (arr - arr_min) / max(arr_max - arr_min, 1e-8) | |
| return np.round(arr).astype(np.uint8) | |
def _extract_display_image(value, reverse_channels=False):
    """Turn one observation value into a 3-channel uint8 image for display.

    Steps, in order:
      1. Drop a clear stacking axis.
      2. Treat 2-D data as grayscale and move a leading channel axis of
         length 1/3/4 to the end.
      3. Expand one channel to three, or drop a fourth (alpha) channel.
      4. Convert non-uint8 data with ``_float_img_to_uint8``.
      5. Optionally reverse the channel order (BGR <-> RGB).

    Raises:
        ValueError: if the value is not 2-D or 3-D after stack removal.
    """
    frame = np.asarray(_extract_latest_obs_value(value))
    if frame.ndim == 2:
        # Grayscale: add a channel axis; it is expanded to three channels below.
        frame = frame[..., np.newaxis]
    elif frame.ndim == 3 and frame.shape[0] in (1, 3, 4):
        frame = frame.transpose(1, 2, 0)

    if frame.ndim != 3:
        raise ValueError("Unsupported image shape: {}".format(frame.shape))

    channels = frame.shape[-1]
    if channels == 1:
        frame = np.repeat(frame, 3, axis=-1)
    elif channels == 4:
        frame = frame[..., :3]

    rgb = frame.copy() if frame.dtype == np.uint8 else _float_img_to_uint8(frame)
    if reverse_channels and rgb.shape[-1] == 3:
        rgb = rgb[..., ::-1]
    return rgb
| def _resize_image_for_display(img, display_scale): | |
| scale = float(display_scale) | |
| if scale == 1.0: | |
| return img | |
| h, w = img.shape[:2] | |
| new_size = (max(1, int(round(w * scale))), max(1, int(round(h * scale)))) | |
| if cv2 is not None: | |
| return cv2.resize(img, new_size, interpolation=cv2.INTER_NEAREST) | |
| pil_img = Image.fromarray(img) | |
| return np.asarray(pil_img.resize(new_size, resample=Image.Resampling.NEAREST)) | |
| def _extract_mixed_action_chunk(traj, start_idx, chunk_length): | |
| chunk = [] | |
| sources = [] | |
| end_idx = min(len(traj), int(start_idx) + int(chunk_length)) | |
| for idx in range(int(start_idx), end_idx): | |
| step = traj[idx] | |
| use_teacher = not bool(step.get("no_teacher_action", False)) | |
| action = step["teacher_action"] if use_teacher else step["robot_action"] | |
| chunk.append(np.asarray(action, dtype=np.float32).reshape(-1)) | |
| sources.append("T" if use_teacher else "R") | |
| if not chunk: | |
| return None, "" | |
| return np.stack(chunk, axis=0), "".join(sources) | |
| def _extract_robot_action_chunk(traj, start_idx, chunk_length): | |
| chunk = [] | |
| end_idx = min(len(traj), int(start_idx) + int(chunk_length)) | |
| for idx in range(int(start_idx), end_idx): | |
| step = traj[idx] | |
| chunk.append(np.asarray(step["robot_action"], dtype=np.float32).reshape(-1)) | |
| if not chunk: | |
| return None | |
| return np.stack(chunk, axis=0) | |
| def _safe_array_str(value, precision=3, max_items=24): | |
| arr = np.asarray(value).reshape(-1) | |
| shown = arr[:max_items] | |
| text = np.array2string(shown, precision=precision, separator=", ") | |
| if arr.size > max_items: | |
| text += " ... +{} more".format(arr.size - max_items) | |
| return text | |
| def _make_action_chunk_plot(mixed_chunk, robot_chunk): | |
| if mixed_chunk is None: | |
| return None | |
| mixed_chunk = np.asarray(mixed_chunk, dtype=np.float32) | |
| if mixed_chunk.ndim == 1: | |
| mixed_chunk = mixed_chunk[:, None] | |
| fig, ax = plt.subplots(figsize=(7, 3.2), dpi=140) | |
| x = np.arange(mixed_chunk.shape[0]) | |
| max_dims = min(mixed_chunk.shape[1], 10) | |
| for dim in range(max_dims): | |
| ax.plot(x, mixed_chunk[:, dim], label="mixed[{}]".format(dim)) | |
| if robot_chunk is not None: | |
| robot_chunk = np.asarray(robot_chunk, dtype=np.float32) | |
| if robot_chunk.ndim == 1: | |
| robot_chunk = robot_chunk[:, None] | |
| for dim in range(min(robot_chunk.shape[1], max_dims)): | |
| ax.plot( | |
| x, | |
| robot_chunk[:, dim], | |
| linestyle="--", | |
| alpha=0.55, | |
| label="robot[{}]".format(dim), | |
| ) | |
| ax.set_title("Action chunk") | |
| ax.set_xlabel("chunk step") | |
| ax.set_ylabel("action value") | |
| ax.grid(True, alpha=0.3) | |
| ax.legend(loc="upper right", fontsize=7, ncol=2) | |
| fig.tight_layout() | |
| fig.canvas.draw() | |
| rgba = np.asarray(fig.canvas.buffer_rgba()) | |
| image = rgba[..., :3].copy() | |
| plt.close(fig) | |
| return image | |
@lru_cache(maxsize=512)
def _gallery_items_cached(repo_id, filename, traj_id, timestep, image_keys_tuple, display_scale, reverse_channels):
    """Memoised worker for ``get_cached_gallery_items``; returns tuples only."""
    traj = load_traj(repo_id, filename, traj_id)
    if not traj:
        return (), ("Trajectory could not be loaded.",)
    timestep = int(np.clip(timestep, 0, len(traj) - 1))
    obs = traj[timestep].get("obs", {})
    items = []
    warnings = []
    for key in image_keys_tuple:
        if key not in obs:
            warnings.append("Missing image key: {}".format(key))
            continue
        try:
            img = _extract_display_image(obs[key], reverse_channels=reverse_channels)
            img = _resize_image_for_display(img, display_scale)
        except Exception as exc:
            # Best-effort per key: report the problem and keep the other images.
            warnings.append("{}: {}".format(key, exc))
        else:
            items.append((img, key))
    return tuple(items), tuple(warnings)


def get_cached_gallery_items(repo_id, filename, traj_id, timestep, image_keys_tuple, display_scale, reverse_channels):
    """Return ``(gallery_items, warnings)`` for one timestep of a trajectory.

    ``gallery_items`` is a list of ``(image, key)`` pairs, one per requested key
    that could be rendered. ``warnings`` is a tuple of messages for keys that
    were missing or failed to render. ``timestep`` is clipped into range. An
    empty trajectory yields no items and one warning instead of raising
    IndexError.

    Results are memoised, as the name promises, so scrubbing and video export
    do not re-extract the same frame. Arguments are normalised first, so a
    list of keys is accepted and equal values share a cache entry. Returned
    images are shared with the cache and must not be modified in place.
    """
    items, warnings = _gallery_items_cached(
        repo_id,
        filename,
        int(traj_id),
        int(timestep),
        tuple(image_keys_tuple),
        float(display_scale),
        bool(reverse_channels),
    )
    return list(items), warnings
| def _compute_valid_start_indices(traj, min_seq_len): | |
| """Match the original local script's valid-start heuristic. | |
| A timestep is valid when the following min_seq_len steps all have | |
| no_teacher_action == False. | |
| """ | |
| total_steps = len(traj) | |
| min_seq_len = int(max(1, min_seq_len)) | |
| no_teacher = np.asarray( | |
| [int(bool(step.get("no_teacher_action", False))) for step in traj], | |
| dtype=np.int32, | |
| ) | |
| valid_indices = [] | |
| max_start = total_steps - min_seq_len + 1 | |
| for t in range(max(0, max_start)): | |
| if int(np.sum(no_teacher[t:t + min_seq_len])) == 0: | |
| valid_indices.append(t) | |
| return no_teacher, valid_indices | |
| def _make_trajectory_status_plot(traj, timestep, min_seq_len): | |
| """Render the same high-level status figure as the local matplotlib tool. | |
| Shows: | |
| - orange no_teacher_action step plot | |
| - green triangles for algorithmic valid start points | |
| - black vertical cursor at current timestep | |
| """ | |
| total_steps = len(traj) | |
| if total_steps == 0: | |
| return None, False, 0 | |
| timestep = int(np.clip(int(timestep), 0, total_steps - 1)) | |
| timesteps = np.asarray( | |
| [int(np.asarray(step.get("timestep", idx)).reshape(-1)[0]) for idx, step in enumerate(traj)], | |
| dtype=np.int32, | |
| ) | |
| no_teacher, valid_indices = _compute_valid_start_indices(traj, min_seq_len) | |
| is_valid_start = timestep in set(valid_indices) | |
| fig, ax = plt.subplots(figsize=(10.5, 2.8), dpi=170) | |
| ax.step( | |
| np.arange(total_steps), | |
| no_teacher, | |
| where="post", | |
| label="no_teacher_action", | |
| color="orange", | |
| ) | |
| if valid_indices: | |
| ax.scatter( | |
| valid_indices, | |
| [-0.15] * len(valid_indices), | |
| color="green", | |
| marker="^", | |
| s=18, | |
| label="Valid Start (len >= {})".format(int(min_seq_len)), | |
| ) | |
| ax.axvline(timestep, color="black", linestyle="-", alpha=0.85, linewidth=1.5) | |
| ax.set_xlim(0, max(total_steps - 1, 1)) | |
| ax.set_ylim(-0.38, 1.1) | |
| ax.set_ylabel("Flag", fontsize=10) | |
| ax.set_xlabel("Timestep index", fontsize=10) | |
| ax.set_yticks([0, 1]) | |
| ax.set_yticklabels(["False", "True"]) | |
| ax.grid(True, axis="x", alpha=0.2) | |
| title = "no_teacher_action | step {} / {}".format(timestep, total_steps - 1) | |
| if is_valid_start: | |
| title += " | VALID START" | |
| ax.set_title(title, fontsize=11) | |
| ax.tick_params(axis="both", labelsize=9) | |
| ax.legend(loc="upper right", fontsize=9) | |
| # Add saved timestep annotation if the stored timestep is not the same as index. | |
| saved_timestep = int(timesteps[timestep]) if len(timesteps) else timestep | |
| if saved_timestep != timestep: | |
| ax.text( | |
| 0.01, | |
| 0.04, | |
| "saved timestep: {}".format(saved_timestep), | |
| transform=ax.transAxes, | |
| fontsize=8, | |
| va="bottom", | |
| ha="left", | |
| ) | |
| fig.tight_layout() | |
| fig.canvas.draw() | |
| rgba = np.asarray(fig.canvas.buffer_rgba()) | |
| image = rgba[..., :3].copy() | |
| plt.close(fig) | |
| return image, bool(is_valid_start), len(valid_indices) | |
@lru_cache(maxsize=32)
def _status_plot_cached(repo_id, filename, traj_id, timestep, min_seq_len):
    """Memoised worker for ``get_cached_status_plot``.

    Kept small: each rendered figure is a few MB.
    """
    traj = load_traj(repo_id, filename, traj_id)
    # For an empty trajectory the clip yields -1; the plot function then
    # returns (None, False, 0).
    timestep = int(np.clip(timestep, 0, len(traj) - 1))
    return _make_trajectory_status_plot(traj, timestep, min_seq_len)


def get_cached_status_plot(repo_id, filename, traj_id, timestep, min_seq_len):
    """Return ``(image, is_valid_start, n_valid_starts)`` for one timestep.

    Wraps ``_make_trajectory_status_plot`` with memoisation, as the name
    promises. Rendering this matplotlib figure is the slowest step of the
    viewer. Arguments are normalised to ints so equal values share a cache
    entry. The returned image is shared with the cache and must not be
    modified in place.
    """
    return _status_plot_cached(repo_id, filename, int(traj_id), int(timestep), int(min_seq_len))
def preload_current_trajectory(preset_name, custom_repo_id, custom_filename, traj_id, image_keys, chunk_len, display_scale, reverse_channels):
    """Render the gallery and the status plot for every timestep of one trajectory.

    The ``get_cached_*`` helpers are called once per timestep and their results
    are discarded; the intent (per the "Frames cached" status text) is to
    populate their caches ahead of scrubbing. ``chunk_len`` is used as the
    status plot's minimum valid-window length. Returns a status string.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return "No trajectories found."
    index = int(np.clip(int(traj_id), 0, n_traj - 1))
    traj = load_traj(repo_id, filename, index)
    if not traj:
        return "Trajectory could not be loaded."

    # Normalise the keys to a tuple (hashable, so it can serve as a cache key).
    if image_keys is None:
        keys = ()
    elif isinstance(image_keys, str):
        keys = (image_keys,)
    else:
        keys = tuple(image_keys)

    scale = float(display_scale)
    flip = bool(reverse_channels)
    frame_count = len(traj)
    for t in range(frame_count):
        get_cached_gallery_items(repo_id, filename, index, t, keys, scale, flip)
        get_cached_status_plot(repo_id, filename, index, t, int(chunk_len))

    return "\n".join([
        "Preloaded trajectory {}".format(index),
        "Frames cached: {}".format(frame_count),
        "Image keys: {}".format(", ".join(keys) if keys else "none"),
    ])
def _compose_video_frame(gallery_items, frame_label, status_plot=None):
    """Compose one video frame and return it as an RGB uint8 array.

    Top: the selected observation images side by side (8 px apart). Each image
    sits under a 16-px caption strip holding its key, and ``frame_label`` is
    drawn in an 18-px strip above the row. With no images, a 640x360 dark
    placeholder reading "No selected image keys" is used instead.
    NOTE(review): ``frame_label`` is not drawn in that case; confirm this is
    intended.

    Bottom (when ``status_plot`` is given): the trajectory-status plot with the
    moving timestep cursor, 8 px below the images.

    Important: do NOT downscale the status plot to the image width. The plot
    contains tick labels and a legend, so preserving its native width makes the
    generated MP4 much more readable. Whichever part is narrower is centred
    and padded (images with black, the plot with white) to the wider width.

    Finally the canvas is padded with black on the right/bottom so both
    dimensions are multiples of 16.

    Args:
        gallery_items: sequence of ``(image_array, label)`` pairs.
        frame_label: text drawn above the image row.
        status_plot: optional RGB image array, or None to omit the plot.
    """
    small_text_y = 3
    if not gallery_items:
        # Placeholder panel when no image keys are selected.
        obs_canvas = Image.new("RGB", (640, 360), color=(20, 20, 20))
        draw = ImageDraw.Draw(obs_canvas)
        draw.text((8, small_text_y), "No selected image keys", fill=(255, 255, 255))
    else:
        pil_images = []
        for img, label in gallery_items:
            pil_img = Image.fromarray(np.asarray(img, dtype=np.uint8)).convert("RGB")
            # Keep the image-key caption compact; large captions waste video space.
            label_h = 16
            panel = Image.new("RGB", (pil_img.width, pil_img.height + label_h), color=(0, 0, 0))
            panel.paste(pil_img, (0, label_h))
            draw = ImageDraw.Draw(panel)
            draw.text((4, small_text_y), str(label), fill=(220, 220, 220))
            pil_images.append(panel)
        gap = 8
        top_h = 18
        # Row width: all panels plus the gaps between them; height: tallest panel plus the label strip.
        width = sum(im.width for im in pil_images) + gap * max(len(pil_images) - 1, 0)
        height = max(im.height for im in pil_images) + top_h
        obs_canvas = Image.new("RGB", (width, height), color=(0, 0, 0))
        draw = ImageDraw.Draw(obs_canvas)
        # Compact frame label above the image panels.
        draw.text((6, small_text_y), frame_label, fill=(220, 220, 220))
        x = 0
        for im in pil_images:
            obs_canvas.paste(im, (x, top_h))
            x += im.width + gap
    if status_plot is not None:
        status_img = Image.fromarray(np.asarray(status_plot, dtype=np.uint8)).convert("RGB")
        # Preserve the status plot resolution. If needed, pad the observation
        # canvas to the same width and center it above the plot.
        final_w = max(obs_canvas.width, status_img.width)
        if obs_canvas.width < final_w:
            padded_obs = Image.new("RGB", (final_w, obs_canvas.height), color=(0, 0, 0))
            padded_obs.paste(obs_canvas, ((final_w - obs_canvas.width) // 2, 0))
            obs_canvas = padded_obs
        elif status_img.width < final_w:
            # Plot is narrower: centre it on a white strip to match its background.
            padded_status = Image.new("RGB", (final_w, status_img.height), color=(255, 255, 255))
            padded_status.paste(status_img, ((final_w - status_img.width) // 2, 0))
            status_img = padded_status
        gap_h = 8
        canvas = Image.new(
            "RGB",
            (final_w, obs_canvas.height + gap_h + status_img.height),
            color=(0, 0, 0),
        )
        canvas.paste(obs_canvas, (0, 0))
        canvas.paste(status_img, (0, obs_canvas.height + gap_h))
    else:
        canvas = obs_canvas
    # Many MP4 encoders prefer dimensions divisible by 16.
    pad_w = int(np.ceil(canvas.width / 16.0) * 16)
    pad_h = int(np.ceil(canvas.height / 16.0) * 16)
    if pad_w != canvas.width or pad_h != canvas.height:
        padded = Image.new("RGB", (pad_w, pad_h), color=(0, 0, 0))
        padded.paste(canvas, (0, 0))
        canvas = padded
    return np.asarray(canvas)
@lru_cache(maxsize=16)
def _video_status_plot_base_cached(repo_id, filename, traj_id, valid_window_len):
    """Memoised worker for ``get_video_status_plot_base``."""
    traj = load_traj(repo_id, filename, traj_id)
    total_steps = len(traj)
    if total_steps == 0:
        return None, (0, 0, 1, 1), 0
    no_teacher, valid_indices = _compute_valid_start_indices(traj, valid_window_len)

    fig, ax = plt.subplots(figsize=VIDEO_STATUS_FIGSIZE, dpi=VIDEO_STATUS_DPI)
    try:
        ax.step(
            np.arange(total_steps),
            no_teacher,
            where="post",
            label="no_teacher_action",
            color="orange",
        )
        if valid_indices:
            ax.scatter(
                valid_indices,
                [-0.15] * len(valid_indices),
                color="green",
                marker="^",
                s=18,
                label="Valid Start (len >= {})".format(int(valid_window_len)),
            )
        ax.set_xlim(0, max(total_steps - 1, 1))
        ax.set_ylim(-0.38, 1.1)
        ax.set_ylabel("Flag", fontsize=8)
        ax.set_xlabel("Timestep index", fontsize=8)
        ax.set_yticks([0, 1])
        ax.set_yticklabels(["False", "True"])
        ax.grid(True, axis="x", alpha=0.2)
        ax.set_title("no_teacher_action and valid starts", fontsize=9)
        ax.tick_params(axis="both", labelsize=7)
        ax.legend(loc="upper right", fontsize=7)
        fig.tight_layout()
        fig.canvas.draw()
        base = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
        bbox = ax.get_window_extent()
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    # Matplotlib bbox origin is bottom-left, image origin is top-left.
    height = base.shape[0]
    bounds = (
        int(round(bbox.x0)),
        int(round(height - bbox.y1)),
        int(round(bbox.x1)),
        int(round(height - bbox.y0)),
    )
    return base, bounds, total_steps


def get_video_status_plot_base(repo_id, filename, traj_id, valid_window_len):
    """Render the static part of the video status plot once per trajectory.

    Matplotlib per frame is slow. This draws the no_teacher_action step plot
    and the valid-start markers, records the axes' pixel bounds, and returns
    the image. The moving cursor is drawn later with NumPy/PIL. The result is
    memoised, so per-frame callers such as ``get_cached_video_status_frame``
    no longer re-render the figure every time.

    Returns:
        ``(base_rgb, (x0, y0, x1, y1), total_steps)``, with bounds in image
        pixel coordinates (top-left origin), or ``(None, (0, 0, 1, 1), 0)``
        for an empty trajectory. ``base_rgb`` is shared with the cache and
        must be treated as read-only.
    """
    return _video_status_plot_base_cached(repo_id, filename, int(traj_id), int(valid_window_len))
@lru_cache(maxsize=128)
def _video_status_frame_cached(repo_id, filename, traj_id, timestep, valid_window_len):
    """Memoised worker for ``get_cached_video_status_frame``."""
    base, bounds, total_steps = get_video_status_plot_base(
        repo_id,
        filename,
        traj_id,
        valid_window_len,
    )
    if base is None:
        return None
    timestep = int(np.clip(timestep, 0, max(total_steps - 1, 0)))
    x0, y0, x1, y1 = bounds
    denom = max(total_steps - 1, 1)
    x = int(round(x0 + (x1 - x0) * float(timestep) / float(denom)))

    # Draw on a PIL copy so the cached base image stays untouched.
    img = Image.fromarray(np.asarray(base, dtype=np.uint8)).convert("RGB")
    draw = ImageDraw.Draw(img)
    # Moving cursor.
    draw.line([(x, y0), (x, y1)], fill=(0, 0, 0), width=4)
    # Compact step label, top-left of the plot area.
    label = "step {}/{}".format(timestep, total_steps - 1)
    draw.rectangle((x0 + 4, y0 + 4, x0 + 118, y0 + 24), fill=(255, 255, 255))
    draw.text((x0 + 8, y0 + 7), label, fill=(0, 0, 0))
    return np.asarray(img)


def get_cached_video_status_frame(repo_id, filename, traj_id, timestep, valid_window_len):
    """Return the static status plot with a cursor drawn at ``timestep``.

    ``timestep`` is clipped into range. Returns None for an empty trajectory.
    Results are memoised, as the name promises, with arguments normalised to
    ints. The returned image is shared with the cache and must not be
    modified in place.
    """
    return _video_status_frame_cached(
        repo_id,
        filename,
        int(traj_id),
        int(timestep),
        int(valid_window_len),
    )
| def _draw_status_cursor_on_base(base, bounds, total_steps, timestep): | |
| """Fast video status frame: copy one static Matplotlib image and draw cursor. | |
| This avoids calling the lru-cached per-timestep status frame function during | |
| video export. For long trajectories, caching thousands of status images can | |
| consume a lot of memory and still requires PIL conversion for every frame. | |
| """ | |
| if base is None: | |
| return None | |
| total_steps = int(max(total_steps, 1)) | |
| timestep = int(np.clip(int(timestep), 0, total_steps - 1)) | |
| x0, y0, x1, y1 = [int(v) for v in bounds] | |
| denom = max(total_steps - 1, 1) | |
| x = int(round(x0 + (x1 - x0) * float(timestep) / float(denom))) | |
| img = np.asarray(base, dtype=np.uint8).copy() | |
| # Draw the vertical cursor directly with NumPy. This is much cheaper than | |
| # creating a Matplotlib plot for every frame. | |
| x_left = max(0, x - 2) | |
| x_right = min(img.shape[1], x + 2) | |
| y_top = max(0, y0) | |
| y_bottom = min(img.shape[0], y1) | |
| img[y_top:y_bottom, x_left:x_right, :] = 0 | |
| # Small text label. PIL is used only for the label, not for the whole plot. | |
| pil_img = Image.fromarray(img).convert("RGB") | |
| draw = ImageDraw.Draw(pil_img) | |
| label = "step {}/{}".format(timestep, total_steps - 1) | |
| draw.rectangle((x0 + 4, y0 + 4, x0 + 126, y0 + 24), fill=(255, 255, 255)) | |
| draw.text((x0 + 8, y0 + 7), label, fill=(0, 0, 0)) | |
| return np.asarray(pil_img) | |
def _get_fast_video_writer(out_path, fps):
    """Open an imageio/ffmpeg MP4 writer tuned for fast interactive exports.

    Encoder settings:
      - H.264 (libx264) with the x264 ``ultrafast`` preset at CRF 28;
      - yuv420p pixels;
      - ``+faststart`` movflags;
      - a 16-px macro block size.
    """
    encoder_flags = [
        "-preset", "ultrafast",
        "-crf", "28",
        "-pix_fmt", "yuv420p",
        "-movflags", "+faststart",
    ]
    return imageio.get_writer(
        out_path,
        fps=float(fps),
        codec="libx264",
        macro_block_size=16,
        ffmpeg_params=encoder_flags,
    )
def build_current_trajectory_video(preset_name, custom_repo_id, custom_filename, traj_id, image_keys, display_scale, reverse_channels, fps, valid_window_len, video_stride=4):
    """Export one trajectory as an MP4 and return ``(path, status_text)``.

    Every ``video_stride``-th timestep is rendered, and the last timestep is
    always included. Each frame shows the selected observation images above
    the no_teacher_action status plot, with a cursor at the current step. The
    static part of that plot is rendered once; only the cursor is redrawn.

    The file is written to the system temp directory under a name built from
    the repo, file, trajectory, fps and stride, so re-exporting with the same
    settings overwrites it. NOTE(review): concurrent exports with identical
    settings would write to the same path; confirm this is acceptable.

    Returns ``(None, message)`` instead of a path when imageio is missing,
    ``fps`` is not a positive number, or no trajectory can be loaded.
    """
    if imageio is None:
        return None, "Video export requires imageio and imageio-ffmpeg in requirements.txt."

    # Validate fps up front: ffmpeg fails obscurely on fps <= 0, and the
    # duration estimate below divides by it.
    try:
        fps_value = float(fps)
    except (TypeError, ValueError):
        fps_value = float("nan")
    if not np.isfinite(fps_value) or fps_value <= 0.0:
        return None, "FPS must be a positive number (got {}).".format(fps)

    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return None, "No trajectories found."
    traj_id = int(np.clip(int(traj_id), 0, n_traj - 1))
    traj = load_traj(repo_id, filename, traj_id)
    if not traj:
        return None, "Trajectory could not be loaded."

    if image_keys is None:
        image_keys = []
    if isinstance(image_keys, str):
        image_keys = [image_keys]
    image_keys_tuple = tuple(image_keys)

    video_stride = int(max(1, int(video_stride)))
    frame_indices = list(range(0, len(traj), video_stride))
    if frame_indices and frame_indices[-1] != len(traj) - 1:
        frame_indices.append(len(traj) - 1)

    # Sanitise repo/file names for use in a filesystem path.
    safe_repo = re.sub(r"[^A-Za-z0-9_.-]+", "_", repo_id)
    safe_file = re.sub(r"[^A-Za-z0-9_.-]+", "_", filename)[-80:]
    out_path = os.path.join(
        tempfile.gettempdir(),
        "trajectory_{}_{}_traj{:04d}_fps{}_stride{}.mp4".format(
            safe_repo, safe_file, traj_id, int(fps_value), video_stride
        ),
    )

    # Build the static status plot once. During export, only draw the cursor.
    status_base, status_bounds, total_steps = get_video_status_plot_base(
        repo_id,
        filename,
        traj_id,
        int(valid_window_len),
    )

    writer = _get_fast_video_writer(out_path, fps_value)
    written = 0
    try:
        for t in frame_indices:
            # Image extraction goes through the shared gallery helper; the
            # status frame is drawn from the static base instead of the
            # per-timestep cached variant to limit memory use.
            gallery_items, _warnings = get_cached_gallery_items(
                repo_id,
                filename,
                traj_id,
                t,
                image_keys_tuple,
                float(display_scale),
                bool(reverse_channels),
            )
            label = "trajectory {} | frame {}/{}".format(traj_id, t, len(traj) - 1)
            status_plot = _draw_status_cursor_on_base(status_base, status_bounds, total_steps, t)
            frame = _compose_video_frame(gallery_items, label, status_plot=status_plot)
            writer.append_data(frame)
            written += 1
    finally:
        writer.close()

    # Exact playback length; the old max(fps, 1) underestimated fractional fps.
    approx_seconds = float(written) / fps_value
    status = "Built trajectory video with optimized encoder and status rendering"
    status += "\nTrajectory: {}".format(traj_id)
    status += "\nOriginal timesteps: {} | Written frames: {} | Stride: {}".format(len(traj), written, video_stride)
    status += "\nFPS: {} | Approx video duration: {:.1f}s".format(fps, approx_seconds)
    status += "\nValid-window length: {}".format(int(valid_window_len))
    status += "\nSpeedups: x264 ultrafast preset; static status plot rendered once; cursor drawn with NumPy/PIL"
    return out_path, status
def get_available_image_keys(repo_id, filename, traj_id):
    """List the observation keys of a trajectory that look like images.

    Detection uses ``_looks_like_image_array`` on the first timestep's
    observations; keys whose check raises are skipped. Keys from
    PREFERRED_IMAGE_KEYS come first, in that order, followed by the rest in
    observation order. Returns [] when there is no trajectory data.
    """
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return []
    traj = load_traj(repo_id, filename, int(np.clip(int(traj_id), 0, n_traj - 1)))
    if not traj:
        return []

    detected = []
    for name, value in traj[0].get("obs", {}).items():
        try:
            is_image = _looks_like_image_array(name, value)
        except Exception:
            # Best-effort: an observation that cannot be inspected is simply not offered.
            continue
        if is_image:
            detected.append(name)

    preferred = [name for name in PREFERRED_IMAGE_KEYS if name in detected]
    return preferred + [name for name in detected if name not in preferred]
def update_custom_visibility(preset_name):
    """Show the custom repo_id / filename textboxes only for the "Custom" preset.

    Returns two Gradio updates, one for each textbox.
    """
    is_custom = preset_name == "Custom"
    repo_update = gr.update(visible=is_custom)
    filename_update = gr.update(visible=is_custom)
    return repo_update, filename_update
def update_after_dataset_change(preset_name, custom_repo_id, custom_filename):
    """Refresh the dependent widgets after the dataset selection changes.

    Returns a 5-tuple of values/updates, in the order of the ``outputs`` lists
    wired in ``build_app``:
    (trajectory slider, timestep slider, image-key checkbox group,
    dataset status text, reverse-channels checkbox).

    When the dataset has no trajectories, both sliders are reset to a
    0..1 range and the image-key choices are cleared.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    reverse_default = get_default_reverse_channels(preset_name)
    # Same status text for both branches; built once instead of duplicated.
    status = "\n".join(
        [
            "Loaded `{}` / `{}`".format(repo_id, filename),
            "Detected trajectories: {}".format(n_traj),
            "reverse_channels default: {}".format(int(reverse_default)),
        ]
    )
    if n_traj == 0:
        return (
            gr.update(maximum=1, value=0),
            gr.update(maximum=1, value=0),
            gr.update(choices=[], value=[]),
            status,
            gr.update(value=reverse_default),
        )
    keys = get_available_image_keys(repo_id, filename, 0)
    traj = load_traj(repo_id, filename, 0)
    # Other callers treat a falsy load_traj result as "nothing loaded"; do the
    # same here instead of calling len() on it unconditionally.
    n_steps = len(traj) if traj else 0
    return (
        # max(..., 1) keeps each slider's maximum at least 1.
        gr.update(maximum=max(n_traj - 1, 1), value=0),
        gr.update(maximum=max(n_steps - 1, 1), value=0),
        gr.update(choices=keys, value=keys[:2]),
        status,
        gr.update(value=reverse_default),
    )
def update_after_traj_change(preset_name, custom_repo_id, custom_filename, traj_id):
    """Reset the timestep slider and image-key choices for a newly selected trajectory.

    The trajectory index is clamped into range. Returns two Gradio updates:
    (timestep slider, image-key checkbox group). The first two detected image
    keys are pre-selected.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    total = get_num_trajectories(repo_id, filename)
    if total == 0:
        return gr.update(maximum=1, value=0), gr.update(choices=[], value=[])
    index = int(np.clip(int(traj_id), 0, total - 1))
    steps = load_traj(repo_id, filename, index)
    available = get_available_image_keys(repo_id, filename, index)
    # Keep the slider maximum at least 1.
    last_step = max(len(steps) - 1, 1)
    timestep_update = gr.update(maximum=last_step, value=0)
    keys_update = gr.update(choices=available, value=available[:2])
    return timestep_update, keys_update
def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep, image_keys, chunk_len, display_scale, reverse_channels):
    """Render one timestep of one trajectory for the UI.

    ``traj_id`` and ``timestep`` are clamped into range. Returns a 3-tuple
    ``(gallery_items, status_plot, info_text)``. When the dataset has no
    trajectories, or the chosen trajectory cannot be loaded, it returns an
    empty gallery, ``None`` for the plot, and an explanatory message.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return [], None, "No trajectory groups found. Open Debug: HDF5 tree."

    traj_id = int(np.clip(int(traj_id), 0, n_traj - 1))
    traj = load_traj(repo_id, filename, traj_id)
    if not traj:
        return [], None, "Trajectory could not be loaded. Open Debug: HDF5 tree."

    last_index = len(traj) - 1
    timestep = int(np.clip(int(timestep), 0, last_index))
    chunk_len = int(chunk_len)
    display_scale = float(display_scale)

    # Normalize the checkbox selection: None -> no keys, a bare string -> one key.
    if image_keys is None:
        selected_keys = []
    elif isinstance(image_keys, str):
        selected_keys = [image_keys]
    else:
        selected_keys = image_keys

    step = traj[timestep]
    gallery_items, warnings_tuple = get_cached_gallery_items(
        repo_id, filename, traj_id, timestep, tuple(selected_keys), display_scale, bool(reverse_channels)
    )
    image_warnings = list(warnings_tuple)
    status_plot, is_valid_start, num_valid_starts = get_cached_status_plot(
        repo_id, filename, traj_id, timestep, chunk_len
    )

    # One "key shape=... dtype=..." line per selected key present in this step's obs.
    obs = step.get("obs", {})
    tensor_lines = []
    for key in selected_keys:
        if key in obs:
            arr = np.asarray(obs[key])
            tensor_lines.append("{} shape={} dtype={}".format(key, tuple(arr.shape), arr.dtype))

    def flag(name):
        # Boolean per-step fields are shown as 0/1.
        return int(bool(step.get(name, False)))

    info_lines = [
        "dataset: {} / {}".format(repo_id, filename),
        "detected trajectories: {}".format(n_traj),
        "trajectory: {}".format(traj_id),
        "episode_id: {}".format(step.get("episode_id", "")),
        "timestep: {} / {}".format(timestep, last_index),
        "saved timestep: {}".format(step.get("timestep", timestep)),
        "done: {}".format(flag("done")),
        "if_success: {}".format(flag("if_success")),
        "no_teacher_action: {}".format(flag("no_teacher_action")),
        "no_robot_action: {}".format(flag("no_robot_action")),
        "valid-window length: {}".format(chunk_len),
        "valid_start: {}".format(int(bool(is_valid_start))),
        "num_valid_starts: {}".format(num_valid_starts),
        "",
        "teacher_action: {}".format(_safe_array_str(step.get("teacher_action", []))),
        "robot_action: {}".format(_safe_array_str(step.get("robot_action", []))),
        "",
        "selected image tensors:",
    ]
    info_lines.extend(tensor_lines)
    if image_warnings:
        info_lines.extend(["", "Image warnings:"])
        info_lines.extend(image_warnings)
    return gallery_items, status_plot, "\n".join(info_lines)
def build_app():
    """Build the Gradio Blocks UI for the HDF5 trajectory viewer.

    The default preset is probed once at build time to pre-populate the
    trajectory slider and the image-key choices. Any exception raised while
    probing is shown in the UI as a startup warning instead of aborting the
    build.

    Returns:
        The configured, not yet launched, ``gr.Blocks`` app.
    """
    try:
        # resolve_dataset is inside the try too, so a broken default preset
        # degrades to a visible warning rather than crashing app construction.
        repo_id, filename = resolve_dataset(DEFAULT_PRESET)
        n_traj = get_num_trajectories(repo_id, filename)
        first_keys = get_available_image_keys(repo_id, filename, 0) if n_traj else []
        startup_warning = ""
    except Exception as exc:
        n_traj = 0
        first_keys = []
        startup_warning = repr(exc)
    reverse_default = get_default_reverse_channels(DEFAULT_PRESET)
    # Do not claim the dataset loaded when the startup probe failed.
    headline = "Failed to load default dataset" if startup_warning else "Loaded default dataset"
    default_status = "{}\nDetected trajectories: {}\nreverse_channels default: {}".format(
        headline, n_traj, int(reverse_default)
    )

    with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
        gr.Markdown(
            "# HDF5 Trajectory Viewer\n\n"
            "Standalone viewer for TrajectoryBuffer-style HDF5 datasets on Hugging Face.\n\n"
            "The status plot matches the local labeling view: orange `no_teacher_action`, green valid-start markers, and a black timestep cursor."
        )
        if startup_warning:
            gr.Markdown("Startup warning: `{}`".format(startup_warning))

        # NOTE(review): which widgets sit inside each gr.Row below was inferred
        # from a copy of this file that had lost its indentation; confirm the layout.
        with gr.Row():
            preset = gr.Dropdown(
                choices=list(DATASET_PRESETS.keys()) + ["Custom"],
                value=DEFAULT_PRESET,
                label="Dataset preset",
            )
            custom_repo_id = gr.Textbox(value="", label="Custom repo_id, e.g. Zhaoting123/InsertT", visible=False)
            custom_filename = gr.Textbox(value="", label="Custom HDF5 path in repo", visible=False)
        dataset_status = gr.Textbox(label="Dataset status", lines=2, value=default_status, interactive=False)

        with gr.Row():
            traj_slider = gr.Slider(minimum=0, maximum=max(n_traj - 1, 1), value=0, step=1, label="Trajectory index")
            timestep_slider = gr.Slider(minimum=0, maximum=1, value=0, step=1, label="Timestep")

        with gr.Row():
            image_keys = gr.CheckboxGroup(choices=first_keys, value=first_keys[:2], label="Image keys")
            chunk_len = gr.Slider(minimum=1, maximum=64, value=DEFAULT_CHUNK_LEN, step=1, label="Valid-window length")
            display_scale = gr.State(value=DEFAULT_DISPLAY_SCALE)
            reverse_channels = gr.Checkbox(value=reverse_default, label="Reverse channels BGR↔RGB")

        with gr.Row():
            render_btn = gr.Button("Render frame", variant="primary")
            preload_btn = gr.Button("Preload current trajectory")
            video_btn = gr.Button("Build trajectory video")
            video_fps = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Video FPS")
            video_stride = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Video frame stride")
        preload_status = gr.Textbox(label="Preload / video status", lines=4, value="Not preloaded yet.", interactive=False)

        with gr.Row():
            with gr.Column(scale=3):
                gallery = gr.Gallery(
                    label="Camera images",
                    columns=2,
                    height=360,
                    object_fit="contain",
                )
            with gr.Column(scale=2):
                status_plot = gr.Image(
                    label="no_teacher_action + valid starts",
                    type="numpy",
                    height=360,
                )
        trajectory_video = gr.Video(label="Trajectory video: smooth browser-side playback")
        info = gr.Textbox(label="Frame info", lines=16)

        with gr.Accordion("Debug: HDF5 tree", open=False):
            inspect_btn = gr.Button("Inspect HDF5 structure")
            hdf5_tree = gr.Textbox(lines=24, label="HDF5 tree")

        # Shared wiring lists. They were previously repeated verbatim for every event.
        dataset_inputs = [preset, custom_repo_id, custom_filename]
        dataset_outputs = [traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels]
        frame_inputs = [
            preset, custom_repo_id, custom_filename, traj_slider, timestep_slider,
            image_keys, chunk_len, display_scale, reverse_channels,
        ]
        frame_outputs = [gallery, status_plot, info]

        preset.change(
            fn=update_custom_visibility,
            inputs=preset,
            outputs=[custom_repo_id, custom_filename],
        ).then(
            fn=update_after_dataset_change,
            inputs=dataset_inputs,
            outputs=dataset_outputs,
        ).then(
            fn=render_frame,
            inputs=frame_inputs,
            outputs=frame_outputs,
        )
        # Submitting a custom repo/filename loads a different dataset. Re-render
        # explicitly, as preset changes and app load do, instead of relying on
        # the updated widgets' own .change events to fire.
        for custom_box in (custom_repo_id, custom_filename):
            custom_box.submit(
                fn=update_after_dataset_change,
                inputs=dataset_inputs,
                outputs=dataset_outputs,
            ).then(
                fn=render_frame,
                inputs=frame_inputs,
                outputs=frame_outputs,
            )
        traj_slider.change(
            fn=update_after_traj_change,
            inputs=[preset, custom_repo_id, custom_filename, traj_slider],
            outputs=[timestep_slider, image_keys],
        ).then(
            fn=render_frame,
            inputs=frame_inputs,
            outputs=frame_outputs,
        )
        # The timestep slider renders on release rather than on every change.
        timestep_slider.release(fn=render_frame, inputs=frame_inputs, outputs=frame_outputs)
        for widget in [image_keys, chunk_len, reverse_channels]:
            widget.change(fn=render_frame, inputs=frame_inputs, outputs=frame_outputs)
        render_btn.click(fn=render_frame, inputs=frame_inputs, outputs=frame_outputs)
        preload_btn.click(
            fn=preload_current_trajectory,
            inputs=[preset, custom_repo_id, custom_filename, traj_slider, image_keys, chunk_len, display_scale, reverse_channels],
            outputs=preload_status,
        )
        video_btn.click(
            fn=build_current_trajectory_video,
            inputs=[preset, custom_repo_id, custom_filename, traj_slider, image_keys, display_scale, reverse_channels, video_fps, chunk_len, video_stride],
            outputs=[trajectory_video, preload_status],
        )
        inspect_btn.click(
            fn=inspect_hdf5_tree,
            inputs=dataset_inputs,
            outputs=hdf5_tree,
        )
        demo.load(
            fn=update_after_dataset_change,
            inputs=dataset_inputs,
            outputs=dataset_outputs,
        ).then(
            fn=render_frame,
            inputs=frame_inputs,
            outputs=frame_outputs,
        )
    return demo
if __name__ == "__main__":
    demo = build_app()
    # The PORT environment variable, when set, overrides the 7860 fallback.
    port = int(os.environ.get("PORT", 7860))
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces
        server_port=port,
        share=False,
        ssr_mode=False,
    )