# Standalone Hugging Face Space viewer for TrajectoryBuffer-style HDF5 files. # # requirements.txt: # gradio # huggingface_hub # h5py # numpy # pillow # matplotlib # imageio # imageio-ffmpeg # # Optional: # opencv-python-headless import os import re import tempfile from functools import lru_cache import gradio as gr import h5py import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np from huggingface_hub import hf_hub_download from PIL import Image, ImageDraw try: import imageio.v2 as imageio except Exception: imageio = None try: import cv2 except Exception: cv2 = None DATASET_PRESETS = { "Robosuite Square Correction": { "repo_id": "Zhaoting123/Robosuite_Square_image_abs_with_state", "filename": ( "20260410_205606_Diffusion_CLIC_intervention_Circular_square_image_abs_" "Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5" ), "default_reverse_channels": False, }, "InsertT Demonstration": { "repo_id": "Zhaoting123/InsertT", "filename": "trajectory_buffer_Nov10_demo.hdf5", "default_reverse_channels": True, }, "InsertT Correction": { "repo_id": "Zhaoting123/InsertT", "filename": "trajectory_buffer_Nov11_intervention.hdf5", "default_reverse_channels": True, }, "RoundTable Correction": { "repo_id": "Zhaoting123/Furniture_Bench_Round_Table_Assembly", "filename": "trajectory_buffer_0_Nov24_intervention_relabeled.hdf5", "default_reverse_channels": True, }, } DEFAULT_PRESET = "Robosuite Square Correction" REPO_TYPE = "dataset" DEFAULT_CHUNK_LEN = 16 DEFAULT_DISPLAY_SCALE = 1 VIDEO_STATUS_FIGSIZE = (6.0, 1.8) VIDEO_STATUS_DPI = 120 PREFERRED_IMAGE_KEYS = [ "image1", "image2", "agentview_image", "robot0_eye_in_hand_image", "front_image", "wrist_image", ] IMAGE_KEY_HINTS = ["rgb", "image", "img", "camera", "cam"] def resolve_dataset(preset_name, custom_repo_id=None, custom_filename=None): preset_name = preset_name or DEFAULT_PRESET if preset_name == "Custom": repo_id = str(custom_repo_id or "").strip() filename = str(custom_filename or 
"").strip() if not repo_id or not filename: raise ValueError("For Custom mode, provide both repo_id and HDF5 filename/path.") return repo_id, filename item = DATASET_PRESETS.get(preset_name, DATASET_PRESETS[DEFAULT_PRESET]) return item["repo_id"], item["filename"] def get_default_reverse_channels(preset_name): """Dataset-specific default for BGR<->RGB reversal. Robosuite Square presets use normal RGB ordering. InsertT / PushT-style preset requires reversal. Custom datasets default to False so users can still override manually. """ preset_name = preset_name or DEFAULT_PRESET if preset_name == "Custom": return False item = DATASET_PRESETS.get(preset_name, DATASET_PRESETS[DEFAULT_PRESET]) return bool(item.get("default_reverse_channels", False)) @lru_cache(maxsize=8) def get_local_hdf5_path(repo_id, filename): return hf_hub_download(repo_id=repo_id, filename=filename, repo_type=REPO_TYPE) def _natural_sort_key(name): match = re.search(r"([0-9]+)$", str(name)) if match: return 0, int(match.group(1)) return 1, str(name) @lru_cache(maxsize=8) def get_trajectory_keys(repo_id, filename): path = get_local_hdf5_path(repo_id, filename) with h5py.File(path, "r") as f: root_episode_keys = [ key for key in f.keys() if isinstance(f[key], h5py.Group) and str(key).startswith("episode_") ] if root_episode_keys: return tuple(sorted(root_episode_keys, key=_natural_sort_key)) if "data" in f and isinstance(f["data"], h5py.Group): data_group = f["data"] keys = [key for key in data_group.keys() if isinstance(data_group[key], h5py.Group)] return tuple("data/" + key for key in sorted(keys, key=_natural_sort_key)) keys = [key for key in f.keys() if isinstance(f[key], h5py.Group)] return tuple(sorted(keys, key=_natural_sort_key)) @lru_cache(maxsize=8) def get_num_trajectories(repo_id, filename): return len(get_trajectory_keys(repo_id, filename)) def inspect_hdf5_tree(preset_name, custom_repo_id, custom_filename, max_lines=180): repo_id, filename = resolve_dataset(preset_name, custom_repo_id, 
custom_filename) path = get_local_hdf5_path(repo_id, filename) lines = [] with h5py.File(path, "r") as f: def visitor(name, obj): if len(lines) >= max_lines: return if isinstance(obj, h5py.Dataset): lines.append("DATASET {} shape={} dtype={}".format(name, obj.shape, obj.dtype)) elif isinstance(obj, h5py.Group): lines.append("GROUP {}".format(name)) f.visititems(visitor) if len(lines) >= max_lines: lines.append("...") return "\n".join(lines) if lines else "No HDF5 contents found." def _read_dataset_value(dataset): value = dataset[()] if isinstance(value, bytes): return value.decode("utf-8") return value def _read_group_recursive(group): out = {} for key, obj in group.items(): if isinstance(obj, h5py.Dataset): out[key] = _read_dataset_value(obj) elif isinstance(obj, h5py.Group): out[key] = _read_group_recursive(obj) return out def _find_first_key(mapping, candidate_keys): for key in candidate_keys: if key in mapping: return key return None def _infer_time_length(data): for key in ["timesteps", "dones", "robot_actions", "teacher_actions", "actions"]: if key in data: arr = np.asarray(data[key]) if arr.ndim >= 1: return int(arr.shape[0]) obs_group = None if isinstance(data.get("observation"), dict): obs_group = data["observation"] elif isinstance(data.get("obs"), dict): obs_group = data["obs"] if obs_group: lengths = [] for value in obs_group.values(): arr = np.asarray(value) if arr.ndim >= 1: lengths.append(int(arr.shape[0])) if lengths: values, counts = np.unique(lengths, return_counts=True) return int(values[np.argmax(counts)]) return 1 def _slice_time(value, t, T): arr = np.asarray(value) if arr.ndim >= 1 and arr.shape[0] == T: return arr[t] return arr @lru_cache(maxsize=64) def load_traj(repo_id, filename, traj_id): traj_keys = get_trajectory_keys(repo_id, filename) if not traj_keys: return [] traj_id = int(np.clip(int(traj_id), 0, len(traj_keys) - 1)) traj_key = traj_keys[traj_id] path = get_local_hdf5_path(repo_id, filename) with h5py.File(path, "r") as f: data = 
_read_group_recursive(f[traj_key]) T = _infer_time_length(data) if isinstance(data.get("observation"), dict): obs_all = data["observation"] elif isinstance(data.get("obs"), dict): obs_all = data["obs"] else: obs_all = {} action_key = _find_first_key(data, ["actions", "action"]) teacher_key = _find_first_key(data, ["teacher_actions", "teacher_action"]) robot_key = _find_first_key(data, ["robot_actions", "robot_action"]) no_teacher_key = _find_first_key(data, ["no_teacher_actions", "no_teacher_action"]) no_robot_key = _find_first_key(data, ["no_robot_actions", "no_robot_action"]) done_key = _find_first_key(data, ["dones", "done"]) timestep_key = _find_first_key(data, ["timesteps", "timestep"]) success_key = _find_first_key(data, ["if_success", "success", "successes"]) traj = [] for t in range(T): obs_t = {key: _slice_time(value, t, T) for key, value in obs_all.items()} default_action = np.zeros(1, dtype=np.float32) if action_key is not None: default_action = _slice_time(data[action_key], t, T) teacher_action = _slice_time(data[teacher_key], t, T) if teacher_key else default_action robot_action = _slice_time(data[robot_key], t, T) if robot_key else default_action no_teacher = _slice_time(data[no_teacher_key], t, T) if no_teacher_key else False no_robot = _slice_time(data[no_robot_key], t, T) if no_robot_key else False done = _slice_time(data[done_key], t, T) if done_key else False if_success = _slice_time(data[success_key], t, T) if success_key else False timestep = t if timestep_key is not None: timestep_arr = _slice_time(data[timestep_key], t, T) timestep = int(np.asarray(timestep_arr).reshape(-1)[0]) traj.append({ "obs": obs_t, "robot_action": np.asarray(robot_action), "teacher_action": np.asarray(teacher_action), "done": bool(np.asarray(done).reshape(-1)[0]), "timestep": timestep, "no_robot_action": bool(np.asarray(no_robot).reshape(-1)[0]), "no_teacher_action": bool(np.asarray(no_teacher).reshape(-1)[0]), "episode_id": traj_key, "if_success": 
bool(np.asarray(if_success).reshape(-1)[0]), }) return traj def _extract_latest_obs_value(value): """Return the latest stacked observation only when there is a clear stack axis. Important: - [obs_T, C, H, W] or [obs_T, H, W, C] should become the latest frame. - [C, H, W] must NOT be sliced, otherwise an RGB image becomes one grayscale channel. """ arr = np.asarray(value) # Stacked image observations, e.g. [obs_T, C, H, W] or [obs_T, H, W, C]. if arr.ndim == 4 and arr.shape[0] in (1, 2, 3, 4): channel_first = arr.shape[1] in (1, 3, 4) channel_last = arr.shape[-1] in (1, 3, 4) if channel_first or channel_last: return arr[-1] # Stacked vector observations, e.g. [obs_T, D]. Keep this for non-image obs. if arr.ndim == 2 and arr.shape[0] in (1, 2): return arr[-1] return arr def _looks_like_image_array(key, value): arr = np.asarray(value) key_l = str(key).lower() key_hint = any(hint in key_l for hint in IMAGE_KEY_HINTS) # Remove only a clear stacked-image axis for shape detection. if arr.ndim == 4 and arr.shape[0] in (1, 2, 3, 4): if arr.shape[1] in (1, 3, 4) or arr.shape[-1] in (1, 3, 4): arr = arr[-1] shape_hint = False if arr.ndim == 2: shape_hint = True elif arr.ndim == 3: shape_hint = arr.shape[-1] in (1, 3, 4) or arr.shape[0] in (1, 3, 4) elif arr.ndim == 4: shape_hint = arr.shape[1] in (1, 3, 4) or arr.shape[-1] in (1, 3, 4) return key_hint or shape_hint def _float_img_to_uint8(img): arr = img.astype(np.float32) arr_min = float(np.nanmin(arr)) arr_max = float(np.nanmax(arr)) if arr_min >= -1.01 and arr_max <= 1.01: if arr_min < 0.0: arr = (arr + 1.0) * 0.5 arr = np.clip(arr, 0.0, 1.0) * 255.0 elif arr_max <= 255.0: arr = np.clip(arr, 0.0, 255.0) else: arr = 255.0 * (arr - arr_min) / max(arr_max - arr_min, 1e-8) return np.round(arr).astype(np.uint8) def _extract_display_image(value, reverse_channels=False): img = np.asarray(_extract_latest_obs_value(value)) if img.ndim == 2: img = np.repeat(img[..., None], 3, axis=-1) elif img.ndim == 3 and img.shape[0] in (1, 3, 
4): img = np.transpose(img, (1, 2, 0)) if img.ndim == 3 and img.shape[-1] == 1: img = np.repeat(img, 3, axis=-1) elif img.ndim == 3 and img.shape[-1] == 4: img = img[..., :3] if img.ndim != 3: raise ValueError("Unsupported image shape: {}".format(img.shape)) out = img.copy() if img.dtype == np.uint8 else _float_img_to_uint8(img) if reverse_channels and out.shape[-1] == 3: out = out[..., ::-1] return out def _resize_image_for_display(img, display_scale): scale = float(display_scale) if scale == 1.0: return img h, w = img.shape[:2] new_size = (max(1, int(round(w * scale))), max(1, int(round(h * scale)))) if cv2 is not None: return cv2.resize(img, new_size, interpolation=cv2.INTER_NEAREST) pil_img = Image.fromarray(img) return np.asarray(pil_img.resize(new_size, resample=Image.Resampling.NEAREST)) def _extract_mixed_action_chunk(traj, start_idx, chunk_length): chunk = [] sources = [] end_idx = min(len(traj), int(start_idx) + int(chunk_length)) for idx in range(int(start_idx), end_idx): step = traj[idx] use_teacher = not bool(step.get("no_teacher_action", False)) action = step["teacher_action"] if use_teacher else step["robot_action"] chunk.append(np.asarray(action, dtype=np.float32).reshape(-1)) sources.append("T" if use_teacher else "R") if not chunk: return None, "" return np.stack(chunk, axis=0), "".join(sources) def _extract_robot_action_chunk(traj, start_idx, chunk_length): chunk = [] end_idx = min(len(traj), int(start_idx) + int(chunk_length)) for idx in range(int(start_idx), end_idx): step = traj[idx] chunk.append(np.asarray(step["robot_action"], dtype=np.float32).reshape(-1)) if not chunk: return None return np.stack(chunk, axis=0) def _safe_array_str(value, precision=3, max_items=24): arr = np.asarray(value).reshape(-1) shown = arr[:max_items] text = np.array2string(shown, precision=precision, separator=", ") if arr.size > max_items: text += " ... 
+{} more".format(arr.size - max_items) return text def _make_action_chunk_plot(mixed_chunk, robot_chunk): if mixed_chunk is None: return None mixed_chunk = np.asarray(mixed_chunk, dtype=np.float32) if mixed_chunk.ndim == 1: mixed_chunk = mixed_chunk[:, None] fig, ax = plt.subplots(figsize=(7, 3.2), dpi=140) x = np.arange(mixed_chunk.shape[0]) max_dims = min(mixed_chunk.shape[1], 10) for dim in range(max_dims): ax.plot(x, mixed_chunk[:, dim], label="mixed[{}]".format(dim)) if robot_chunk is not None: robot_chunk = np.asarray(robot_chunk, dtype=np.float32) if robot_chunk.ndim == 1: robot_chunk = robot_chunk[:, None] for dim in range(min(robot_chunk.shape[1], max_dims)): ax.plot( x, robot_chunk[:, dim], linestyle="--", alpha=0.55, label="robot[{}]".format(dim), ) ax.set_title("Action chunk") ax.set_xlabel("chunk step") ax.set_ylabel("action value") ax.grid(True, alpha=0.3) ax.legend(loc="upper right", fontsize=7, ncol=2) fig.tight_layout() fig.canvas.draw() rgba = np.asarray(fig.canvas.buffer_rgba()) image = rgba[..., :3].copy() plt.close(fig) return image @lru_cache(maxsize=8192) def get_cached_gallery_items(repo_id, filename, traj_id, timestep, image_keys_tuple, display_scale, reverse_channels): traj = load_traj(repo_id, filename, int(traj_id)) timestep = int(np.clip(int(timestep), 0, len(traj) - 1)) obs = traj[timestep].get("obs", {}) gallery_items = [] warnings = [] for key in image_keys_tuple: if key not in obs: warnings.append("Missing image key: {}".format(key)) continue try: img = _extract_display_image(obs[key], reverse_channels=bool(reverse_channels)) img = _resize_image_for_display(img, float(display_scale)) gallery_items.append((img, key)) except Exception as exc: warnings.append("{}: {}".format(key, exc)) return gallery_items, tuple(warnings) def _compute_valid_start_indices(traj, min_seq_len): """Match the original local script's valid-start heuristic. A timestep is valid when the following min_seq_len steps all have no_teacher_action == False. 
""" total_steps = len(traj) min_seq_len = int(max(1, min_seq_len)) no_teacher = np.asarray( [int(bool(step.get("no_teacher_action", False))) for step in traj], dtype=np.int32, ) valid_indices = [] max_start = total_steps - min_seq_len + 1 for t in range(max(0, max_start)): if int(np.sum(no_teacher[t:t + min_seq_len])) == 0: valid_indices.append(t) return no_teacher, valid_indices def _make_trajectory_status_plot(traj, timestep, min_seq_len): """Render the same high-level status figure as the local matplotlib tool. Shows: - orange no_teacher_action step plot - green triangles for algorithmic valid start points - black vertical cursor at current timestep """ total_steps = len(traj) if total_steps == 0: return None, False, 0 timestep = int(np.clip(int(timestep), 0, total_steps - 1)) timesteps = np.asarray( [int(np.asarray(step.get("timestep", idx)).reshape(-1)[0]) for idx, step in enumerate(traj)], dtype=np.int32, ) no_teacher, valid_indices = _compute_valid_start_indices(traj, min_seq_len) is_valid_start = timestep in set(valid_indices) fig, ax = plt.subplots(figsize=(10.5, 2.8), dpi=170) ax.step( np.arange(total_steps), no_teacher, where="post", label="no_teacher_action", color="orange", ) if valid_indices: ax.scatter( valid_indices, [-0.15] * len(valid_indices), color="green", marker="^", s=18, label="Valid Start (len >= {})".format(int(min_seq_len)), ) ax.axvline(timestep, color="black", linestyle="-", alpha=0.85, linewidth=1.5) ax.set_xlim(0, max(total_steps - 1, 1)) ax.set_ylim(-0.38, 1.1) ax.set_ylabel("Flag", fontsize=10) ax.set_xlabel("Timestep index", fontsize=10) ax.set_yticks([0, 1]) ax.set_yticklabels(["False", "True"]) ax.grid(True, axis="x", alpha=0.2) title = "no_teacher_action | step {} / {}".format(timestep, total_steps - 1) if is_valid_start: title += " | VALID START" ax.set_title(title, fontsize=11) ax.tick_params(axis="both", labelsize=9) ax.legend(loc="upper right", fontsize=9) # Add saved timestep annotation if the stored timestep is not the 
same as index. saved_timestep = int(timesteps[timestep]) if len(timesteps) else timestep if saved_timestep != timestep: ax.text( 0.01, 0.04, "saved timestep: {}".format(saved_timestep), transform=ax.transAxes, fontsize=8, va="bottom", ha="left", ) fig.tight_layout() fig.canvas.draw() rgba = np.asarray(fig.canvas.buffer_rgba()) image = rgba[..., :3].copy() plt.close(fig) return image, bool(is_valid_start), len(valid_indices) @lru_cache(maxsize=8192) def get_cached_status_plot(repo_id, filename, traj_id, timestep, min_seq_len): traj = load_traj(repo_id, filename, int(traj_id)) timestep = int(np.clip(int(timestep), 0, len(traj) - 1)) return _make_trajectory_status_plot(traj, timestep, int(min_seq_len)) def preload_current_trajectory(preset_name, custom_repo_id, custom_filename, traj_id, image_keys, chunk_len, display_scale, reverse_channels): repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename) n_traj = get_num_trajectories(repo_id, filename) if n_traj == 0: return "No trajectories found." traj_id = int(np.clip(int(traj_id), 0, n_traj - 1)) traj = load_traj(repo_id, filename, traj_id) if not traj: return "Trajectory could not be loaded." if image_keys is None: image_keys = [] if isinstance(image_keys, str): image_keys = [image_keys] image_keys_tuple = tuple(image_keys) total = len(traj) for t in range(total): get_cached_gallery_items(repo_id, filename, traj_id, t, image_keys_tuple, float(display_scale), bool(reverse_channels)) get_cached_status_plot(repo_id, filename, traj_id, t, int(chunk_len)) status = "Preloaded trajectory {}".format(traj_id) status += "\nFrames cached: {}".format(total) status += "\nImage keys: {}".format(", ".join(image_keys_tuple) if image_keys_tuple else "none") return status def _compose_video_frame(gallery_items, frame_label, status_plot=None): """Compose one video frame. Top: selected observation images. Bottom: trajectory-status plot with the moving timestep cursor. 
Important: do NOT downscale the status plot to the image width. The plot contains tick labels and a legend, so preserving its native width makes the generated MP4 much more readable. """ small_text_y = 3 if not gallery_items: obs_canvas = Image.new("RGB", (640, 360), color=(20, 20, 20)) draw = ImageDraw.Draw(obs_canvas) draw.text((8, small_text_y), "No selected image keys", fill=(255, 255, 255)) else: pil_images = [] for img, label in gallery_items: pil_img = Image.fromarray(np.asarray(img, dtype=np.uint8)).convert("RGB") # Keep the image-key caption compact; large captions waste video space. label_h = 16 panel = Image.new("RGB", (pil_img.width, pil_img.height + label_h), color=(0, 0, 0)) panel.paste(pil_img, (0, label_h)) draw = ImageDraw.Draw(panel) draw.text((4, small_text_y), str(label), fill=(220, 220, 220)) pil_images.append(panel) gap = 8 top_h = 18 width = sum(im.width for im in pil_images) + gap * max(len(pil_images) - 1, 0) height = max(im.height for im in pil_images) + top_h obs_canvas = Image.new("RGB", (width, height), color=(0, 0, 0)) draw = ImageDraw.Draw(obs_canvas) # Compact frame label above the image panels. draw.text((6, small_text_y), frame_label, fill=(220, 220, 220)) x = 0 for im in pil_images: obs_canvas.paste(im, (x, top_h)) x += im.width + gap if status_plot is not None: status_img = Image.fromarray(np.asarray(status_plot, dtype=np.uint8)).convert("RGB") # Preserve the status plot resolution. If needed, pad the observation # canvas to the same width and center it above the plot. 
final_w = max(obs_canvas.width, status_img.width) if obs_canvas.width < final_w: padded_obs = Image.new("RGB", (final_w, obs_canvas.height), color=(0, 0, 0)) padded_obs.paste(obs_canvas, ((final_w - obs_canvas.width) // 2, 0)) obs_canvas = padded_obs elif status_img.width < final_w: padded_status = Image.new("RGB", (final_w, status_img.height), color=(255, 255, 255)) padded_status.paste(status_img, ((final_w - status_img.width) // 2, 0)) status_img = padded_status gap_h = 8 canvas = Image.new( "RGB", (final_w, obs_canvas.height + gap_h + status_img.height), color=(0, 0, 0), ) canvas.paste(obs_canvas, (0, 0)) canvas.paste(status_img, (0, obs_canvas.height + gap_h)) else: canvas = obs_canvas # Many MP4 encoders prefer dimensions divisible by 16. pad_w = int(np.ceil(canvas.width / 16.0) * 16) pad_h = int(np.ceil(canvas.height / 16.0) * 16) if pad_w != canvas.width or pad_h != canvas.height: padded = Image.new("RGB", (pad_w, pad_h), color=(0, 0, 0)) padded.paste(canvas, (0, 0)) canvas = padded return np.asarray(canvas) @lru_cache(maxsize=128) def get_video_status_plot_base(repo_id, filename, traj_id, valid_window_len): """Render the static part of the status plot once for video export. Matplotlib per frame is slow. This function draws no_teacher_action and valid-start markers once, records the axes pixel bounds, and returns a base image. The moving cursor is later drawn with PIL, which is much faster. 
""" traj = load_traj(repo_id, filename, int(traj_id)) total_steps = len(traj) if total_steps == 0: return None, (0, 0, 1, 1), 0 no_teacher, valid_indices = _compute_valid_start_indices(traj, int(valid_window_len)) fig, ax = plt.subplots(figsize=VIDEO_STATUS_FIGSIZE, dpi=VIDEO_STATUS_DPI) ax.step( np.arange(total_steps), no_teacher, where="post", label="no_teacher_action", color="orange", ) if valid_indices: ax.scatter( valid_indices, [-0.15] * len(valid_indices), color="green", marker="^", s=18, label="Valid Start (len >= {})".format(int(valid_window_len)), ) ax.set_xlim(0, max(total_steps - 1, 1)) ax.set_ylim(-0.38, 1.1) ax.set_ylabel("Flag", fontsize=8) ax.set_xlabel("Timestep index", fontsize=8) ax.set_yticks([0, 1]) ax.set_yticklabels(["False", "True"]) ax.grid(True, axis="x", alpha=0.2) ax.set_title("no_teacher_action and valid starts", fontsize=9) ax.tick_params(axis="both", labelsize=7) ax.legend(loc="upper right", fontsize=7) fig.tight_layout() fig.canvas.draw() rgba = np.asarray(fig.canvas.buffer_rgba()) base = rgba[..., :3].copy() bbox = ax.get_window_extent() height = base.shape[0] # Matplotlib bbox origin is bottom-left, image origin is top-left. x0 = int(round(bbox.x0)) x1 = int(round(bbox.x1)) y0 = int(round(height - bbox.y1)) y1 = int(round(height - bbox.y0)) plt.close(fig) return base, (x0, y0, x1, y1), total_steps @lru_cache(maxsize=8192) def get_cached_video_status_frame(repo_id, filename, traj_id, timestep, valid_window_len): """Draw the moving cursor on a cached static status plot.""" base, bounds, total_steps = get_video_status_plot_base( repo_id, filename, int(traj_id), int(valid_window_len), ) if base is None: return None timestep = int(np.clip(int(timestep), 0, max(total_steps - 1, 0))) x0, y0, x1, y1 = bounds denom = max(total_steps - 1, 1) x = int(round(x0 + (x1 - x0) * float(timestep) / float(denom))) img = Image.fromarray(np.asarray(base, dtype=np.uint8)).convert("RGB") draw = ImageDraw.Draw(img) # Moving cursor. 
draw.line([(x, y0), (x, y1)], fill=(0, 0, 0), width=4) # Compact step label, top-left of the plot area. label = "step {}/{}".format(timestep, total_steps - 1) draw.rectangle((x0 + 4, y0 + 4, x0 + 118, y0 + 24), fill=(255, 255, 255)) draw.text((x0 + 8, y0 + 7), label, fill=(0, 0, 0)) return np.asarray(img) def _draw_status_cursor_on_base(base, bounds, total_steps, timestep): """Fast video status frame: copy one static Matplotlib image and draw cursor. This avoids calling the lru-cached per-timestep status frame function during video export. For long trajectories, caching thousands of status images can consume a lot of memory and still requires PIL conversion for every frame. """ if base is None: return None total_steps = int(max(total_steps, 1)) timestep = int(np.clip(int(timestep), 0, total_steps - 1)) x0, y0, x1, y1 = [int(v) for v in bounds] denom = max(total_steps - 1, 1) x = int(round(x0 + (x1 - x0) * float(timestep) / float(denom))) img = np.asarray(base, dtype=np.uint8).copy() # Draw the vertical cursor directly with NumPy. This is much cheaper than # creating a Matplotlib plot for every frame. x_left = max(0, x - 2) x_right = min(img.shape[1], x + 2) y_top = max(0, y0) y_bottom = min(img.shape[0], y1) img[y_top:y_bottom, x_left:x_right, :] = 0 # Small text label. PIL is used only for the label, not for the whole plot. 
pil_img = Image.fromarray(img).convert("RGB") draw = ImageDraw.Draw(pil_img) label = "step {}/{}".format(timestep, total_steps - 1) draw.rectangle((x0 + 4, y0 + 4, x0 + 126, y0 + 24), fill=(255, 255, 255)) draw.text((x0 + 8, y0 + 7), label, fill=(0, 0, 0)) return np.asarray(pil_img) def _get_fast_video_writer(out_path, fps): """Use ffmpeg's ultrafast x264 preset for interactive Spaces exports.""" return imageio.get_writer( out_path, fps=float(fps), codec="libx264", macro_block_size=16, ffmpeg_params=[ "-preset", "ultrafast", "-crf", "28", "-pix_fmt", "yuv420p", "-movflags", "+faststart", ], ) def build_current_trajectory_video(preset_name, custom_repo_id, custom_filename, traj_id, image_keys, display_scale, reverse_channels, fps, valid_window_len, video_stride=4): if imageio is None: return None, "Video export requires imageio and imageio-ffmpeg in requirements.txt." repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename) n_traj = get_num_trajectories(repo_id, filename) if n_traj == 0: return None, "No trajectories found." traj_id = int(np.clip(int(traj_id), 0, n_traj - 1)) traj = load_traj(repo_id, filename, traj_id) if not traj: return None, "Trajectory could not be loaded." if image_keys is None: image_keys = [] if isinstance(image_keys, str): image_keys = [image_keys] image_keys_tuple = tuple(image_keys) video_stride = int(max(1, int(video_stride))) frame_indices = list(range(0, len(traj), video_stride)) if frame_indices and frame_indices[-1] != len(traj) - 1: frame_indices.append(len(traj) - 1) safe_repo = re.sub(r"[^A-Za-z0-9_.-]+", "_", repo_id) safe_file = re.sub(r"[^A-Za-z0-9_.-]+", "_", filename)[-80:] out_path = os.path.join( tempfile.gettempdir(), "trajectory_{}_{}_traj{:04d}_fps{}_stride{}.mp4".format( safe_repo, safe_file, traj_id, int(fps), video_stride ), ) # Build the static status plot once. During export, only draw the cursor. 
status_base, status_bounds, total_steps = get_video_status_plot_base( repo_id, filename, traj_id, int(valid_window_len), ) writer = _get_fast_video_writer(out_path, fps) written = 0 try: for t in frame_indices: # Use the existing cached image extraction for correctness, but avoid # cached per-timestep status images to reduce memory pressure. gallery_items, _warnings = get_cached_gallery_items( repo_id, filename, traj_id, t, image_keys_tuple, float(display_scale), bool(reverse_channels), ) label = "trajectory {} | frame {}/{}".format(traj_id, t, len(traj) - 1) status_plot = _draw_status_cursor_on_base(status_base, status_bounds, total_steps, t) frame = _compose_video_frame(gallery_items, label, status_plot=status_plot) writer.append_data(frame) written += 1 finally: writer.close() approx_seconds = float(written) / float(max(float(fps), 1.0)) status = "Built trajectory video with optimized encoder and status rendering" status += "\nTrajectory: {}".format(traj_id) status += "\nOriginal timesteps: {} | Written frames: {} | Stride: {}".format(len(traj), written, video_stride) status += "\nFPS: {} | Approx video duration: {:.1f}s".format(fps, approx_seconds) status += "\nValid-window length: {}".format(int(valid_window_len)) status += "\nSpeedups: x264 ultrafast preset; static status plot rendered once; cursor drawn with NumPy/PIL" return out_path, status def get_available_image_keys(repo_id, filename, traj_id): n_traj = get_num_trajectories(repo_id, filename) if n_traj == 0: return [] traj_id = int(np.clip(int(traj_id), 0, n_traj - 1)) traj = load_traj(repo_id, filename, traj_id) if not traj: return [] obs = traj[0].get("obs", {}) keys = [] for key, value in obs.items(): try: if _looks_like_image_array(key, value): keys.append(key) except Exception: pass ordered = [key for key in PREFERRED_IMAGE_KEYS if key in keys] ordered.extend([key for key in keys if key not in ordered]) return ordered def update_custom_visibility(preset_name): visible = preset_name == "Custom" 
return gr.update(visible=visible), gr.update(visible=visible) def update_after_dataset_change(preset_name, custom_repo_id, custom_filename): repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename) n_traj = get_num_trajectories(repo_id, filename) reverse_default = get_default_reverse_channels(preset_name) if n_traj == 0: status = "Loaded `{}` / `{}`".format(repo_id, filename) status += "\nDetected trajectories: 0" status += "\nreverse_channels default: {}".format(int(reverse_default)) return ( gr.update(maximum=1, value=0), gr.update(maximum=1, value=0), gr.update(choices=[], value=[]), status, gr.update(value=reverse_default), ) keys = get_available_image_keys(repo_id, filename, 0) traj = load_traj(repo_id, filename, 0) status = "Loaded `{}` / `{}`".format(repo_id, filename) status += "\nDetected trajectories: {}".format(n_traj) status += "\nreverse_channels default: {}".format(int(reverse_default)) return ( gr.update(maximum=max(n_traj - 1, 1), value=0), gr.update(maximum=max(len(traj) - 1, 1), value=0), gr.update(choices=keys, value=keys[:2]), status, gr.update(value=reverse_default), ) def update_after_traj_change(preset_name, custom_repo_id, custom_filename, traj_id): repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename) n_traj = get_num_trajectories(repo_id, filename) if n_traj == 0: return gr.update(maximum=1, value=0), gr.update(choices=[], value=[]) traj_id = int(np.clip(int(traj_id), 0, n_traj - 1)) traj = load_traj(repo_id, filename, traj_id) keys = get_available_image_keys(repo_id, filename, traj_id) return ( gr.update(maximum=max(len(traj) - 1, 1), value=0), gr.update(choices=keys, value=keys[:2]), ) def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep, image_keys, chunk_len, display_scale, reverse_channels): repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename) n_traj = get_num_trajectories(repo_id, filename) if n_traj == 0: return [], None, 
"No trajectory groups found. Open Debug: HDF5 tree." traj_id = int(np.clip(int(traj_id), 0, n_traj - 1)) traj = load_traj(repo_id, filename, traj_id) if not traj: return [], None, "Trajectory could not be loaded. Open Debug: HDF5 tree." timestep = int(np.clip(int(timestep), 0, len(traj) - 1)) chunk_len = int(chunk_len) display_scale = float(display_scale) if image_keys is None: image_keys = [] if isinstance(image_keys, str): image_keys = [image_keys] step = traj[timestep] image_keys_tuple = tuple(image_keys) gallery_items, warnings_tuple = get_cached_gallery_items( repo_id, filename, traj_id, timestep, image_keys_tuple, display_scale, bool(reverse_channels) ) warnings = list(warnings_tuple) status_plot, is_valid_start, num_valid_starts = get_cached_status_plot(repo_id, filename, traj_id, timestep, chunk_len) image_debug_lines = [] for _key in image_keys: if _key in step.get("obs", {}): _arr = np.asarray(step["obs"][_key]) image_debug_lines.append( "{} shape={} dtype={}".format(_key, tuple(_arr.shape), _arr.dtype) ) info_lines = [ "dataset: {} / {}".format(repo_id, filename), "detected trajectories: {}".format(n_traj), "trajectory: {}".format(traj_id), "episode_id: {}".format(step.get("episode_id", "")), "timestep: {} / {}".format(timestep, len(traj) - 1), "saved timestep: {}".format(step.get("timestep", timestep)), "done: {}".format(int(bool(step.get("done", False)))), "if_success: {}".format(int(bool(step.get("if_success", False)))), "no_teacher_action: {}".format(int(bool(step.get("no_teacher_action", False)))), "no_robot_action: {}".format(int(bool(step.get("no_robot_action", False)))), "valid-window length: {}".format(chunk_len), "valid_start: {}".format(int(bool(is_valid_start))), "num_valid_starts: {}".format(num_valid_starts), "", "teacher_action: {}".format(_safe_array_str(step.get("teacher_action", []))), "robot_action: {}".format(_safe_array_str(step.get("robot_action", []))), "", "selected image tensors:", *image_debug_lines, ] if warnings: 
def build_app():
    """Build the Gradio Blocks UI for the HDF5 trajectory viewer.

    The default preset is probed once so the widgets start with sensible
    ranges and image-key choices. All widgets are then created and every
    event handler is wired to the module-level callbacks.

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """
    repo_id, filename = resolve_dataset(DEFAULT_PRESET)

    # Probe the default dataset up front. Any failure is kept as text and
    # shown in the page instead of preventing the app from being built.
    try:
        n_traj = get_num_trajectories(repo_id, filename)
        first_keys = get_available_image_keys(repo_id, filename, 0) if n_traj else []
        startup_warning = ""
    except Exception as exc:
        n_traj = 0
        first_keys = []
        startup_warning = repr(exc)

    default_reverse = get_default_reverse_channels(DEFAULT_PRESET)
    default_status = "Loaded default dataset\nDetected trajectories: {}\nreverse_channels default: {}".format(
        n_traj, int(default_reverse)
    )

    with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
        gr.Markdown(
            "# HDF5 Trajectory Viewer\n\n"
            "Standalone viewer for TrajectoryBuffer-style HDF5 datasets on Hugging Face.\n\n"
            "The status plot matches the local labeling view: orange `no_teacher_action`, green valid-start markers, and a black timestep cursor."
        )
        if startup_warning:
            gr.Markdown("Startup warning: `{}`".format(startup_warning))

        # --- Dataset selection -------------------------------------------
        with gr.Row():
            preset = gr.Dropdown(
                choices=list(DATASET_PRESETS.keys()) + ["Custom"],
                value=DEFAULT_PRESET,
                label="Dataset preset",
            )
            custom_repo_id = gr.Textbox(
                value="",
                label="Custom repo_id, e.g. Zhaoting123/InsertT",
                visible=False,
            )
            custom_filename = gr.Textbox(
                value="",
                label="Custom HDF5 path in repo",
                visible=False,
            )

        dataset_status = gr.Textbox(
            label="Dataset status",
            lines=2,
            value=default_status,
            interactive=False,
        )

        # --- Navigation ----------------------------------------------------
        with gr.Row():
            # max(..., 1) keeps the slider maximum above its minimum even when
            # zero or one trajectory was detected.
            traj_slider = gr.Slider(
                minimum=0,
                maximum=max(n_traj - 1, 1),
                value=0,
                step=1,
                label="Trajectory index",
            )
            # Placeholder range; update_after_dataset_change and
            # update_after_traj_change both write to this slider.
            timestep_slider = gr.Slider(minimum=0, maximum=1, value=0, step=1, label="Timestep")

        # --- Rendering options ---------------------------------------------
        with gr.Row():
            image_keys = gr.CheckboxGroup(choices=first_keys, value=first_keys[:2], label="Image keys")
            chunk_len = gr.Slider(
                minimum=1,
                maximum=64,
                value=DEFAULT_CHUNK_LEN,
                step=1,
                label="Valid-window length",
            )
            # Not user-editable: carried as state so callbacks receive it.
            display_scale = gr.State(value=DEFAULT_DISPLAY_SCALE)
            reverse_channels = gr.Checkbox(value=default_reverse, label="Reverse channels BGR↔RGB")

        # --- Actions ---------------------------------------------------------
        with gr.Row():
            render_btn = gr.Button("Render frame", variant="primary")
            preload_btn = gr.Button("Preload current trajectory")
            video_btn = gr.Button("Build trajectory video")
            video_fps = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Video FPS")
            video_stride = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Video frame stride")

        preload_status = gr.Textbox(
            label="Preload / video status",
            lines=4,
            value="Not preloaded yet.",
            interactive=False,
        )

        # --- Outputs -----------------------------------------------------------
        with gr.Row():
            with gr.Column(scale=3):
                gallery = gr.Gallery(
                    label="Camera images",
                    columns=2,
                    height=360,
                    object_fit="contain",
                )
            with gr.Column(scale=2):
                status_plot = gr.Image(
                    label="no_teacher_action + valid starts",
                    type="numpy",
                    height=360,
                )

        trajectory_video = gr.Video(label="Trajectory video: smooth browser-side playback")
        info = gr.Textbox(label="Frame info", lines=16)

        with gr.Accordion("Debug: HDF5 tree", open=False):
            inspect_btn = gr.Button("Inspect HDF5 structure")
            hdf5_tree = gr.Textbox(lines=24, label="HDF5 tree")

        # --- Event wiring ------------------------------------------------------
        # Shared argument lists; the order of each list must match the
        # positional parameters / return tuple of the callback it is used with.
        dataset_inputs = [preset, custom_repo_id, custom_filename]
        dataset_outputs = [traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels]
        render_inputs = [
            preset,
            custom_repo_id,
            custom_filename,
            traj_slider,
            timestep_slider,
            image_keys,
            chunk_len,
            display_scale,
            reverse_channels,
        ]
        render_outputs = [gallery, status_plot, info]

        def _then_render(event):
            """Chain a render_frame call after ``event`` and return the new event."""
            return event.then(fn=render_frame, inputs=render_inputs, outputs=render_outputs)

        _then_render(
            preset.change(
                fn=update_custom_visibility,
                inputs=preset,
                outputs=[custom_repo_id, custom_filename],
            ).then(
                fn=update_after_dataset_change,
                inputs=dataset_inputs,
                outputs=dataset_outputs,
            )
        )

        # Submitting either custom field reloads the dataset. A render is
        # chained explicitly (as for preset.change and demo.load) so the frame
        # view cannot keep showing the previous dataset when the reload happens
        # to leave every widget value unchanged and no .change listener fires.
        for custom_field in (custom_repo_id, custom_filename):
            _then_render(
                custom_field.submit(
                    fn=update_after_dataset_change,
                    inputs=dataset_inputs,
                    outputs=dataset_outputs,
                )
            )

        _then_render(
            traj_slider.change(
                fn=update_after_traj_change,
                inputs=[preset, custom_repo_id, custom_filename, traj_slider],
                outputs=[timestep_slider, image_keys],
            )
        )

        # The timestep slider renders on release rather than on change --
        # presumably to avoid a render for every intermediate drag value.
        timestep_slider.release(fn=render_frame, inputs=render_inputs, outputs=render_outputs)

        for widget in [image_keys, chunk_len, reverse_channels]:
            widget.change(fn=render_frame, inputs=render_inputs, outputs=render_outputs)

        render_btn.click(fn=render_frame, inputs=render_inputs, outputs=render_outputs)

        preload_btn.click(
            fn=preload_current_trajectory,
            inputs=[
                preset,
                custom_repo_id,
                custom_filename,
                traj_slider,
                image_keys,
                chunk_len,
                display_scale,
                reverse_channels,
            ],
            outputs=preload_status,
        )

        video_btn.click(
            fn=build_current_trajectory_video,
            inputs=[
                preset,
                custom_repo_id,
                custom_filename,
                traj_slider,
                image_keys,
                display_scale,
                reverse_channels,
                video_fps,
                chunk_len,
                video_stride,
            ],
            outputs=[trajectory_video, preload_status],
        )

        inspect_btn.click(fn=inspect_hdf5_tree, inputs=dataset_inputs, outputs=hdf5_tree)

        # On page load, refresh the dataset-dependent widgets and draw the
        # first frame.
        _then_render(
            demo.load(
                fn=update_after_dataset_change,
                inputs=dataset_inputs,
                outputs=dataset_outputs,
            )
        )

    return demo


if __name__ == "__main__":
    demo = build_app()
    # Listen on all interfaces; the port comes from $PORT when set, else 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
        share=False,
        ssr_mode=False,
    )