# Author: Zhaoting123 — last commit 2fb0eb7 "Update app.py" (verified).
# Standalone Hugging Face Space viewer for TrajectoryBuffer-style HDF5 files.
#
# requirements.txt:
# gradio
# huggingface_hub
# h5py
# numpy
# pillow
# matplotlib
# imageio
# imageio-ffmpeg
#
# Optional:
# opencv-python-headless
import hashlib
import os
import re
import tempfile
from functools import lru_cache

import gradio as gr
import h5py
import matplotlib

# The headless backend is selected before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw

# Optional: video export is disabled when imageio is unavailable.
try:
    import imageio.v2 as imageio
except Exception:
    imageio = None

# Optional: PIL is used for resizing when OpenCV is unavailable.
try:
    import cv2
except Exception:
    cv2 = None
# Dataset presets offered in the UI dropdown. Each preset names the Hugging Face
# dataset repo, the HDF5 path inside that repo, and whether the
# "Reverse channels BGR<->RGB" checkbox starts enabled for it.
DATASET_PRESETS = {
    "Robosuite Square Correction": dict(
        repo_id="Zhaoting123/Robosuite_Square_image_abs_with_state",
        filename=(
            "20260410_205606_Diffusion_CLIC_intervention_Circular_square_image_abs_"
            "Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
        ),
        default_reverse_channels=False,
    ),
    "InsertT Demonstration": dict(
        repo_id="Zhaoting123/InsertT",
        filename="trajectory_buffer_Nov10_demo.hdf5",
        default_reverse_channels=True,
    ),
    "InsertT Correction": dict(
        repo_id="Zhaoting123/InsertT",
        filename="trajectory_buffer_Nov11_intervention.hdf5",
        default_reverse_channels=True,
    ),
    "RoundTable Correction": dict(
        repo_id="Zhaoting123/Furniture_Bench_Round_Table_Assembly",
        filename="trajectory_buffer_0_Nov24_intervention_relabeled.hdf5",
        default_reverse_channels=True,
    ),
}

# Preset selected at startup and used as the fallback for unknown names.
DEFAULT_PRESET = "Robosuite Square Correction"
# hf_hub_download repo type for every preset.
REPO_TYPE = "dataset"

# Initial value of the "Valid-window length" slider.
DEFAULT_CHUNK_LEN = 16
# Image scale factor passed to _resize_image_for_display (1 = native size).
DEFAULT_DISPLAY_SCALE = 1

# Figure size (inches) and DPI of the static status plot used in exported videos.
VIDEO_STATUS_FIGSIZE = (6.0, 1.8)
VIDEO_STATUS_DPI = 120

# Image keys listed first (in this order) when present in a trajectory.
PREFERRED_IMAGE_KEYS = [
    "image1", "image2",
    "agentview_image", "robot0_eye_in_hand_image",
    "front_image", "wrist_image",
]
# Lower-case substrings that mark an observation key as an image.
IMAGE_KEY_HINTS = ["rgb", "image", "img", "camera", "cam"]
def resolve_dataset(preset_name, custom_repo_id=None, custom_filename=None):
    """Map a UI preset selection to a ``(repo_id, filename)`` pair.

    For ``"Custom"`` the two free-text fields are used after whitespace
    stripping; a blank field raises ``ValueError``. Any other name is looked
    up in ``DATASET_PRESETS``. Empty or unknown names fall back to
    ``DEFAULT_PRESET``.
    """
    name = preset_name or DEFAULT_PRESET
    if name != "Custom":
        preset = DATASET_PRESETS.get(name, DATASET_PRESETS[DEFAULT_PRESET])
        return preset["repo_id"], preset["filename"]
    repo_id, filename = (str(field or "").strip() for field in (custom_repo_id, custom_filename))
    if repo_id and filename:
        return repo_id, filename
    raise ValueError("For Custom mode, provide both repo_id and HDF5 filename/path.")
def get_default_reverse_channels(preset_name):
    """Return whether BGR<->RGB channel reversal should start enabled.

    The value is the preset's ``default_reverse_channels`` entry, and a
    missing entry counts as False. Empty or unknown preset names use
    ``DEFAULT_PRESET``. ``"Custom"`` always starts with False; the user can
    still tick the checkbox manually.
    """
    chosen = preset_name or DEFAULT_PRESET
    is_custom = chosen == "Custom"
    if is_custom:
        return False
    if chosen in DATASET_PRESETS:
        entry = DATASET_PRESETS[chosen]
    else:
        entry = DATASET_PRESETS[DEFAULT_PRESET]
    return bool(entry.get("default_reverse_channels", False))
@lru_cache(maxsize=8)
def get_local_hdf5_path(repo_id, filename):
    """Download ``filename`` from the dataset repo (or reuse the hub cache) and return its local path.

    Memoised per ``(repo_id, filename)`` so repeated UI events skip the hub call.
    """
    download_kwargs = dict(repo_id=repo_id, filename=filename, repo_type=REPO_TYPE)
    return hf_hub_download(**download_kwargs)
def _natural_sort_key(name):
match = re.search(r"([0-9]+)$", str(name))
if match:
return 0, int(match.group(1))
return 1, str(name)
@lru_cache(maxsize=8)
def get_trajectory_keys(repo_id, filename):
    """Return the HDF5 paths of the groups that each hold one trajectory.

    Layouts are tried in this order:
      1. root-level groups whose name starts with ``episode_``;
      2. every group inside a root ``data`` group, returned as ``data/<name>``;
      3. every root-level group.
    The result is a tuple sorted with ``_natural_sort_key``.
    """
    path = get_local_hdf5_path(repo_id, filename)
    with h5py.File(path, "r") as f:
        root_groups = [name for name in f.keys() if isinstance(f[name], h5py.Group)]
        episodes = [name for name in root_groups if str(name).startswith("episode_")]
        if episodes:
            return tuple(sorted(episodes, key=_natural_sort_key))
        data_node = f["data"] if "data" in f else None
        if isinstance(data_node, h5py.Group):
            children = [name for name in data_node.keys() if isinstance(data_node[name], h5py.Group)]
            return tuple("data/" + name for name in sorted(children, key=_natural_sort_key))
        return tuple(sorted(root_groups, key=_natural_sort_key))
@lru_cache(maxsize=8)
def get_num_trajectories(repo_id, filename):
    """Return how many trajectory groups ``get_trajectory_keys`` found in the file."""
    keys = get_trajectory_keys(repo_id, filename)
    return len(keys)
def inspect_hdf5_tree(preset_name, custom_repo_id, custom_filename, max_lines=180):
    """List the groups and datasets of the selected HDF5 file, one per line.

    Dataset lines include shape and dtype. At most ``max_lines`` entries are
    listed. When more objects exist, traversal stops early and a final
    ``"..."`` line marks the truncation. A file with exactly ``max_lines``
    entries is therefore not flagged as truncated.

    Args:
        preset_name, custom_repo_id, custom_filename: Forwarded to ``resolve_dataset``.
        max_lines: Maximum number of GROUP/DATASET lines to emit (negative -> 0).

    Returns:
        The listing as one string, or ``"No HDF5 contents found."`` when
        nothing was listed.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    path = get_local_hdf5_path(repo_id, filename)
    max_lines = max(0, int(max_lines))
    lines = []
    truncated = False

    def visitor(name, obj):
        nonlocal truncated
        if len(lines) >= max_lines:
            # Any non-None return value stops h5py's visititems, so large
            # files are not walked to the end only to discard the rest.
            truncated = True
            return True
        if isinstance(obj, h5py.Dataset):
            lines.append("DATASET {} shape={} dtype={}".format(name, obj.shape, obj.dtype))
        elif isinstance(obj, h5py.Group):
            lines.append("GROUP {}".format(name))
        return None

    with h5py.File(path, "r") as f:
        f.visititems(visitor)
    if truncated:
        lines.append("...")
    return "\n".join(lines) if lines else "No HDF5 contents found."
def _read_dataset_value(dataset):
value = dataset[()]
if isinstance(value, bytes):
return value.decode("utf-8")
return value
def _read_group_recursive(group):
    """Load an HDF5 group into nested dicts: sub-groups recurse, datasets become values.

    Members that are neither a group nor a dataset are skipped.
    """
    result = {}
    for member_name, node in group.items():
        if isinstance(node, h5py.Group):
            result[member_name] = _read_group_recursive(node)
        elif isinstance(node, h5py.Dataset):
            result[member_name] = _read_dataset_value(node)
    return result
def _find_first_key(mapping, candidate_keys):
for key in candidate_keys:
if key in mapping:
return key
return None
def _infer_time_length(data):
for key in ["timesteps", "dones", "robot_actions", "teacher_actions", "actions"]:
if key in data:
arr = np.asarray(data[key])
if arr.ndim >= 1:
return int(arr.shape[0])
obs_group = None
if isinstance(data.get("observation"), dict):
obs_group = data["observation"]
elif isinstance(data.get("obs"), dict):
obs_group = data["obs"]
if obs_group:
lengths = []
for value in obs_group.values():
arr = np.asarray(value)
if arr.ndim >= 1:
lengths.append(int(arr.shape[0]))
if lengths:
values, counts = np.unique(lengths, return_counts=True)
return int(values[np.argmax(counts)])
return 1
def _slice_time(value, t, T):
arr = np.asarray(value)
if arr.ndim >= 1 and arr.shape[0] == T:
return arr[t]
return arr
@lru_cache(maxsize=64)
def load_traj(repo_id, filename, traj_id):
    """Load one trajectory as a list of per-timestep dicts.

    ``traj_id`` is clamped into ``[0, num_trajectories - 1]``. The whole
    trajectory group is read into memory, and its length T comes from
    ``_infer_time_length``. Entries whose leading dimension equals T are
    indexed per step; other entries are repeated unchanged.

    Each step dict has the keys ``obs`` (per-step observation dict),
    ``robot_action``, ``teacher_action``, ``done``, ``timestep``,
    ``no_robot_action``, ``no_teacher_action``, ``episode_id`` (the HDF5
    group path) and ``if_success``. Fallbacks for missing fields:
      * robot/teacher actions use the generic ``actions``/``action`` entry,
        or a one-element zero array;
      * the flags default to False;
      * ``timestep`` defaults to the step index.

    Returns ``[]`` when the file has no trajectories. Results are memoised
    (up to 64), so the returned list is shared between calls.
    """
    episode_keys = get_trajectory_keys(repo_id, filename)
    if not episode_keys:
        return []
    episode_key = episode_keys[int(np.clip(int(traj_id), 0, len(episode_keys) - 1))]
    with h5py.File(get_local_hdf5_path(repo_id, filename), "r") as handle:
        raw = _read_group_recursive(handle[episode_key])
    horizon = _infer_time_length(raw)

    observations = {}
    for obs_name in ("observation", "obs"):
        if isinstance(raw.get(obs_name), dict):
            observations = raw[obs_name]
            break

    action_key = _find_first_key(raw, ["actions", "action"])
    teacher_key = _find_first_key(raw, ["teacher_actions", "teacher_action"])
    robot_key = _find_first_key(raw, ["robot_actions", "robot_action"])
    no_teacher_key = _find_first_key(raw, ["no_teacher_actions", "no_teacher_action"])
    no_robot_key = _find_first_key(raw, ["no_robot_actions", "no_robot_action"])
    done_key = _find_first_key(raw, ["dones", "done"])
    timestep_key = _find_first_key(raw, ["timesteps", "timestep"])
    success_key = _find_first_key(raw, ["if_success", "success", "successes"])

    def column(key, t, fallback):
        # Per-step value of raw[key], or `fallback` when that field is absent.
        return _slice_time(raw[key], t, horizon) if key else fallback

    def first_as_bool(value):
        return bool(np.asarray(value).reshape(-1)[0])

    steps = []
    for t in range(horizon):
        # A fresh zero array per step, used when no generic action field exists.
        generic_action = column(action_key, t, np.zeros(1, dtype=np.float32))
        if timestep_key is None:
            saved_t = t
        else:
            saved_t = int(np.asarray(column(timestep_key, t, None)).reshape(-1)[0])
        steps.append({
            "obs": {name: _slice_time(arr, t, horizon) for name, arr in observations.items()},
            "robot_action": np.asarray(column(robot_key, t, generic_action)),
            "teacher_action": np.asarray(column(teacher_key, t, generic_action)),
            "done": first_as_bool(column(done_key, t, False)),
            "timestep": saved_t,
            "no_robot_action": first_as_bool(column(no_robot_key, t, False)),
            "no_teacher_action": first_as_bool(column(no_teacher_key, t, False)),
            "episode_id": episode_key,
            "if_success": first_as_bool(column(success_key, t, False)),
        })
    return steps
def _extract_latest_obs_value(value):
"""Return the latest stacked observation only when there is a clear stack axis.
Important:
- [obs_T, C, H, W] or [obs_T, H, W, C] should become the latest frame.
- [C, H, W] must NOT be sliced, otherwise an RGB image becomes one
grayscale channel.
"""
arr = np.asarray(value)
# Stacked image observations, e.g. [obs_T, C, H, W] or [obs_T, H, W, C].
if arr.ndim == 4 and arr.shape[0] in (1, 2, 3, 4):
channel_first = arr.shape[1] in (1, 3, 4)
channel_last = arr.shape[-1] in (1, 3, 4)
if channel_first or channel_last:
return arr[-1]
# Stacked vector observations, e.g. [obs_T, D]. Keep this for non-image obs.
if arr.ndim == 2 and arr.shape[0] in (1, 2):
return arr[-1]
return arr
def _looks_like_image_array(key, value):
arr = np.asarray(value)
key_l = str(key).lower()
key_hint = any(hint in key_l for hint in IMAGE_KEY_HINTS)
# Remove only a clear stacked-image axis for shape detection.
if arr.ndim == 4 and arr.shape[0] in (1, 2, 3, 4):
if arr.shape[1] in (1, 3, 4) or arr.shape[-1] in (1, 3, 4):
arr = arr[-1]
shape_hint = False
if arr.ndim == 2:
shape_hint = True
elif arr.ndim == 3:
shape_hint = arr.shape[-1] in (1, 3, 4) or arr.shape[0] in (1, 3, 4)
elif arr.ndim == 4:
shape_hint = arr.shape[1] in (1, 3, 4) or arr.shape[-1] in (1, 3, 4)
return key_hint or shape_hint
def _float_img_to_uint8(img):
arr = img.astype(np.float32)
arr_min = float(np.nanmin(arr))
arr_max = float(np.nanmax(arr))
if arr_min >= -1.01 and arr_max <= 1.01:
if arr_min < 0.0:
arr = (arr + 1.0) * 0.5
arr = np.clip(arr, 0.0, 1.0) * 255.0
elif arr_max <= 255.0:
arr = np.clip(arr, 0.0, 255.0)
else:
arr = 255.0 * (arr - arr_min) / max(arr_max - arr_min, 1e-8)
return np.round(arr).astype(np.uint8)
def _extract_display_image(value, reverse_channels=False):
    """Turn one observation entry into an ``H x W x 3`` uint8 image.

    Steps:
      1. Take the newest frame of a stacked observation.
      2. Expand 2-D grayscale to 3 channels.
      3. Move a leading 1/3/4 channel axis to the end.
      4. Repeat a single channel, or drop alpha from 4 channels.
      5. Convert non-uint8 data with ``_float_img_to_uint8``.
      6. Optionally reverse the channel order (BGR <-> RGB).

    Raises:
        ValueError: if the result is not 3-D.
    """
    frame = np.asarray(_extract_latest_obs_value(value))
    if frame.ndim == 2:
        frame = np.repeat(frame[..., None], 3, axis=-1)
    elif frame.ndim == 3 and frame.shape[0] in (1, 3, 4):
        frame = np.transpose(frame, (1, 2, 0))

    if frame.ndim != 3:
        raise ValueError("Unsupported image shape: {}".format(frame.shape))
    channels = frame.shape[-1]
    if channels == 1:
        frame = np.repeat(frame, 3, axis=-1)
    elif channels == 4:
        frame = frame[..., :3]

    rgb = frame.copy() if frame.dtype == np.uint8 else _float_img_to_uint8(frame)
    if reverse_channels and rgb.shape[-1] == 3:
        rgb = rgb[..., ::-1]
    return rgb
def _resize_image_for_display(img, display_scale):
scale = float(display_scale)
if scale == 1.0:
return img
h, w = img.shape[:2]
new_size = (max(1, int(round(w * scale))), max(1, int(round(h * scale))))
if cv2 is not None:
return cv2.resize(img, new_size, interpolation=cv2.INTER_NEAREST)
pil_img = Image.fromarray(img)
return np.asarray(pil_img.resize(new_size, resample=Image.Resampling.NEAREST))
def _extract_mixed_action_chunk(traj, start_idx, chunk_length):
chunk = []
sources = []
end_idx = min(len(traj), int(start_idx) + int(chunk_length))
for idx in range(int(start_idx), end_idx):
step = traj[idx]
use_teacher = not bool(step.get("no_teacher_action", False))
action = step["teacher_action"] if use_teacher else step["robot_action"]
chunk.append(np.asarray(action, dtype=np.float32).reshape(-1))
sources.append("T" if use_teacher else "R")
if not chunk:
return None, ""
return np.stack(chunk, axis=0), "".join(sources)
def _extract_robot_action_chunk(traj, start_idx, chunk_length):
chunk = []
end_idx = min(len(traj), int(start_idx) + int(chunk_length))
for idx in range(int(start_idx), end_idx):
step = traj[idx]
chunk.append(np.asarray(step["robot_action"], dtype=np.float32).reshape(-1))
if not chunk:
return None
return np.stack(chunk, axis=0)
def _safe_array_str(value, precision=3, max_items=24):
arr = np.asarray(value).reshape(-1)
shown = arr[:max_items]
text = np.array2string(shown, precision=precision, separator=", ")
if arr.size > max_items:
text += " ... +{} more".format(arr.size - max_items)
return text
def _make_action_chunk_plot(mixed_chunk, robot_chunk):
if mixed_chunk is None:
return None
mixed_chunk = np.asarray(mixed_chunk, dtype=np.float32)
if mixed_chunk.ndim == 1:
mixed_chunk = mixed_chunk[:, None]
fig, ax = plt.subplots(figsize=(7, 3.2), dpi=140)
x = np.arange(mixed_chunk.shape[0])
max_dims = min(mixed_chunk.shape[1], 10)
for dim in range(max_dims):
ax.plot(x, mixed_chunk[:, dim], label="mixed[{}]".format(dim))
if robot_chunk is not None:
robot_chunk = np.asarray(robot_chunk, dtype=np.float32)
if robot_chunk.ndim == 1:
robot_chunk = robot_chunk[:, None]
for dim in range(min(robot_chunk.shape[1], max_dims)):
ax.plot(
x,
robot_chunk[:, dim],
linestyle="--",
alpha=0.55,
label="robot[{}]".format(dim),
)
ax.set_title("Action chunk")
ax.set_xlabel("chunk step")
ax.set_ylabel("action value")
ax.grid(True, alpha=0.3)
ax.legend(loc="upper right", fontsize=7, ncol=2)
fig.tight_layout()
fig.canvas.draw()
rgba = np.asarray(fig.canvas.buffer_rgba())
image = rgba[..., :3].copy()
plt.close(fig)
return image
@lru_cache(maxsize=8192)
def get_cached_gallery_items(repo_id, filename, traj_id, timestep, image_keys_tuple, display_scale, reverse_channels):
    """Build (and memoise) the gallery images for one timestep.

    ``timestep`` is clamped into the trajectory. Each requested key is
    converted with ``_extract_display_image`` and resized with
    ``_resize_image_for_display``. Missing keys and conversion errors are
    reported as warnings instead of raising.

    Returns:
        ``(items, warnings)``: a list of ``(image, key)`` pairs and a tuple of
        warning strings. The list is shared by every caller hitting the cache.
    """
    traj = load_traj(repo_id, filename, int(traj_id))
    index = int(np.clip(int(timestep), 0, len(traj) - 1))
    observation = traj[index].get("obs", {})
    items, problems = [], []
    for key in image_keys_tuple:
        if key not in observation:
            problems.append("Missing image key: {}".format(key))
            continue
        try:
            picture = _resize_image_for_display(
                _extract_display_image(observation[key], reverse_channels=bool(reverse_channels)),
                float(display_scale),
            )
        except Exception as exc:
            problems.append("{}: {}".format(key, exc))
        else:
            items.append((picture, key))
    return items, tuple(problems)
def _compute_valid_start_indices(traj, min_seq_len):
"""Match the original local script's valid-start heuristic.
A timestep is valid when the following min_seq_len steps all have
no_teacher_action == False.
"""
total_steps = len(traj)
min_seq_len = int(max(1, min_seq_len))
no_teacher = np.asarray(
[int(bool(step.get("no_teacher_action", False))) for step in traj],
dtype=np.int32,
)
valid_indices = []
max_start = total_steps - min_seq_len + 1
for t in range(max(0, max_start)):
if int(np.sum(no_teacher[t:t + min_seq_len])) == 0:
valid_indices.append(t)
return no_teacher, valid_indices
def _make_trajectory_status_plot(traj, timestep, min_seq_len):
"""Render the same high-level status figure as the local matplotlib tool.
Shows:
- orange no_teacher_action step plot
- green triangles for algorithmic valid start points
- black vertical cursor at current timestep
"""
total_steps = len(traj)
if total_steps == 0:
return None, False, 0
timestep = int(np.clip(int(timestep), 0, total_steps - 1))
timesteps = np.asarray(
[int(np.asarray(step.get("timestep", idx)).reshape(-1)[0]) for idx, step in enumerate(traj)],
dtype=np.int32,
)
no_teacher, valid_indices = _compute_valid_start_indices(traj, min_seq_len)
is_valid_start = timestep in set(valid_indices)
fig, ax = plt.subplots(figsize=(10.5, 2.8), dpi=170)
ax.step(
np.arange(total_steps),
no_teacher,
where="post",
label="no_teacher_action",
color="orange",
)
if valid_indices:
ax.scatter(
valid_indices,
[-0.15] * len(valid_indices),
color="green",
marker="^",
s=18,
label="Valid Start (len >= {})".format(int(min_seq_len)),
)
ax.axvline(timestep, color="black", linestyle="-", alpha=0.85, linewidth=1.5)
ax.set_xlim(0, max(total_steps - 1, 1))
ax.set_ylim(-0.38, 1.1)
ax.set_ylabel("Flag", fontsize=10)
ax.set_xlabel("Timestep index", fontsize=10)
ax.set_yticks([0, 1])
ax.set_yticklabels(["False", "True"])
ax.grid(True, axis="x", alpha=0.2)
title = "no_teacher_action | step {} / {}".format(timestep, total_steps - 1)
if is_valid_start:
title += " | VALID START"
ax.set_title(title, fontsize=11)
ax.tick_params(axis="both", labelsize=9)
ax.legend(loc="upper right", fontsize=9)
# Add saved timestep annotation if the stored timestep is not the same as index.
saved_timestep = int(timesteps[timestep]) if len(timesteps) else timestep
if saved_timestep != timestep:
ax.text(
0.01,
0.04,
"saved timestep: {}".format(saved_timestep),
transform=ax.transAxes,
fontsize=8,
va="bottom",
ha="left",
)
fig.tight_layout()
fig.canvas.draw()
rgba = np.asarray(fig.canvas.buffer_rgba())
image = rgba[..., :3].copy()
plt.close(fig)
return image, bool(is_valid_start), len(valid_indices)
@lru_cache(maxsize=8192)
def get_cached_status_plot(repo_id, filename, traj_id, timestep, min_seq_len):
    """Memoised ``_make_trajectory_status_plot`` for one (trajectory, timestep, window) triple."""
    traj = load_traj(repo_id, filename, int(traj_id))
    last_index = len(traj) - 1
    clamped = int(np.clip(int(timestep), 0, last_index))
    return _make_trajectory_status_plot(traj, clamped, int(min_seq_len))
def preload_current_trajectory(preset_name, custom_repo_id, custom_filename, traj_id, image_keys, chunk_len, display_scale, reverse_channels):
    """Warm the per-timestep gallery and status-plot caches for one trajectory.

    Every timestep is rendered once with the current UI settings, so later
    slider scrubbing hits ``lru_cache`` instead of re-rendering.

    Returns:
        A short human-readable status string.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return "No trajectories found."
    index = int(np.clip(int(traj_id), 0, n_traj - 1))
    traj = load_traj(repo_id, filename, index)
    if not traj:
        return "Trajectory could not be loaded."
    if image_keys is None:
        keys = ()
    elif isinstance(image_keys, str):
        keys = (image_keys,)
    else:
        keys = tuple(image_keys)
    # Argument forms must match render_frame's calls so the cache keys line up.
    for t in range(len(traj)):
        get_cached_gallery_items(repo_id, filename, index, t, keys, float(display_scale), bool(reverse_channels))
        get_cached_status_plot(repo_id, filename, index, t, int(chunk_len))
    return "\n".join([
        "Preloaded trajectory {}".format(index),
        "Frames cached: {}".format(len(traj)),
        "Image keys: {}".format(", ".join(keys) if keys else "none"),
    ])
def _compose_video_frame(gallery_items, frame_label, status_plot=None):
    """Assemble one MP4 frame as an RGB uint8 array.

    Layout, top to bottom:
      * the observation strip:
          - each ``(image, label)`` becomes a panel with a 16 px caption bar;
          - panels sit side by side with 8 px gaps, below an 18 px header
            showing ``frame_label``;
          - with no items, a 640x360 dark placeholder says
            "No selected image keys" instead.
      * if ``status_plot`` is given: an 8 px gap and then the plot at its
        native resolution.

    The status plot is never downscaled, because its tick labels and legend
    must stay readable. The narrower part is centred instead: the observation
    strip on black, or the plot on white. The finished frame is padded with
    black on the right and bottom so both sides are multiples of 16, which
    many MP4 encoders prefer.
    """
    text_y = 3
    text_grey = (220, 220, 220)

    def centred(img, width, background):
        # Horizontally centre `img` on a canvas `width` pixels wide.
        holder = Image.new("RGB", (width, img.height), color=background)
        holder.paste(img, ((width - img.width) // 2, 0))
        return holder

    if gallery_items:
        caption_h, gap, header_h = 16, 8, 18
        panels = []
        for picture, caption in gallery_items:
            rgb = Image.fromarray(np.asarray(picture, dtype=np.uint8)).convert("RGB")
            panel = Image.new("RGB", (rgb.width, rgb.height + caption_h), color=(0, 0, 0))
            panel.paste(rgb, (0, caption_h))
            ImageDraw.Draw(panel).text((4, text_y), str(caption), fill=text_grey)
            panels.append(panel)
        strip_w = sum(p.width for p in panels) + gap * max(len(panels) - 1, 0)
        strip_h = header_h + max(p.height for p in panels)
        top = Image.new("RGB", (strip_w, strip_h), color=(0, 0, 0))
        ImageDraw.Draw(top).text((6, text_y), frame_label, fill=text_grey)
        offset = 0
        for panel in panels:
            top.paste(panel, (offset, header_h))
            offset += panel.width + gap
    else:
        top = Image.new("RGB", (640, 360), color=(20, 20, 20))
        ImageDraw.Draw(top).text((8, text_y), "No selected image keys", fill=(255, 255, 255))

    if status_plot is None:
        frame = top
    else:
        bottom = Image.fromarray(np.asarray(status_plot, dtype=np.uint8)).convert("RGB")
        shared_w = max(top.width, bottom.width)
        if top.width < shared_w:
            top = centred(top, shared_w, (0, 0, 0))
        elif bottom.width < shared_w:
            bottom = centred(bottom, shared_w, (255, 255, 255))
        spacer = 8
        frame = Image.new("RGB", (shared_w, top.height + spacer + bottom.height), color=(0, 0, 0))
        frame.paste(top, (0, 0))
        frame.paste(bottom, (0, top.height + spacer))

    aligned_w = int(np.ceil(frame.width / 16.0) * 16)
    aligned_h = int(np.ceil(frame.height / 16.0) * 16)
    if (aligned_w, aligned_h) != frame.size:
        padded = Image.new("RGB", (aligned_w, aligned_h), color=(0, 0, 0))
        padded.paste(frame, (0, 0))
        frame = padded
    return np.asarray(frame)
@lru_cache(maxsize=128)
def get_video_status_plot_base(repo_id, filename, traj_id, valid_window_len):
    """Render the static part of the video status plot once per trajectory.

    Drawing a Matplotlib figure per frame is slow. This function draws the
    no_teacher_action step curve and valid-start markers once. The moving
    cursor is added per frame later, with NumPy/PIL.

    Returns:
        ``(base_rgb, (x0, y0, x1, y1), total_steps)``. The bounds are the
        axes' pixel extent in image coordinates (origin top-left). For an
        empty trajectory, ``(None, (0, 0, 1, 1), 0)``.

    The figure is closed in a ``finally`` block, so a rendering error cannot
    leave it registered with pyplot.
    """
    traj = load_traj(repo_id, filename, int(traj_id))
    total_steps = len(traj)
    if total_steps == 0:
        return None, (0, 0, 1, 1), 0
    no_teacher, valid_indices = _compute_valid_start_indices(traj, int(valid_window_len))
    fig, ax = plt.subplots(figsize=VIDEO_STATUS_FIGSIZE, dpi=VIDEO_STATUS_DPI)
    try:
        ax.step(
            np.arange(total_steps),
            no_teacher,
            where="post",
            label="no_teacher_action",
            color="orange",
        )
        if valid_indices:
            ax.scatter(
                valid_indices,
                [-0.15] * len(valid_indices),
                color="green",
                marker="^",
                s=18,
                label="Valid Start (len >= {})".format(int(valid_window_len)),
            )
        ax.set_xlim(0, max(total_steps - 1, 1))
        ax.set_ylim(-0.38, 1.1)
        ax.set_ylabel("Flag", fontsize=8)
        ax.set_xlabel("Timestep index", fontsize=8)
        ax.set_yticks([0, 1])
        ax.set_yticklabels(["False", "True"])
        ax.grid(True, axis="x", alpha=0.2)
        ax.set_title("no_teacher_action and valid starts", fontsize=9)
        ax.tick_params(axis="both", labelsize=7)
        ax.legend(loc="upper right", fontsize=7)
        fig.tight_layout()
        fig.canvas.draw()
        base = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
        bbox = ax.get_window_extent()
    finally:
        plt.close(fig)
    height = base.shape[0]
    # Matplotlib bbox origin is bottom-left, image origin is top-left.
    x0 = int(round(bbox.x0))
    x1 = int(round(bbox.x1))
    y0 = int(round(height - bbox.y1))
    y1 = int(round(height - bbox.y0))
    return base, (x0, y0, x1, y1), total_steps
@lru_cache(maxsize=8192)
def get_cached_video_status_frame(repo_id, filename, traj_id, timestep, valid_window_len):
    """Return (and memoise) the static status plot with a cursor drawn at ``timestep``.

    The base image and axes bounds come from ``get_video_status_plot_base``.
    The drawing is delegated to ``_draw_status_cursor_on_base``, so this cached
    variant and the video export path produce identical cursor and label
    geometry. This replaces the duplicated copy that had drifted (label box
    118 px vs 126 px). Returns None for an empty trajectory.
    """
    base, bounds, total_steps = get_video_status_plot_base(
        repo_id,
        filename,
        int(traj_id),
        int(valid_window_len),
    )
    return _draw_status_cursor_on_base(base, bounds, total_steps, timestep)
def _draw_status_cursor_on_base(base, bounds, total_steps, timestep):
"""Fast video status frame: copy one static Matplotlib image and draw cursor.
This avoids calling the lru-cached per-timestep status frame function during
video export. For long trajectories, caching thousands of status images can
consume a lot of memory and still requires PIL conversion for every frame.
"""
if base is None:
return None
total_steps = int(max(total_steps, 1))
timestep = int(np.clip(int(timestep), 0, total_steps - 1))
x0, y0, x1, y1 = [int(v) for v in bounds]
denom = max(total_steps - 1, 1)
x = int(round(x0 + (x1 - x0) * float(timestep) / float(denom)))
img = np.asarray(base, dtype=np.uint8).copy()
# Draw the vertical cursor directly with NumPy. This is much cheaper than
# creating a Matplotlib plot for every frame.
x_left = max(0, x - 2)
x_right = min(img.shape[1], x + 2)
y_top = max(0, y0)
y_bottom = min(img.shape[0], y1)
img[y_top:y_bottom, x_left:x_right, :] = 0
# Small text label. PIL is used only for the label, not for the whole plot.
pil_img = Image.fromarray(img).convert("RGB")
draw = ImageDraw.Draw(pil_img)
label = "step {}/{}".format(timestep, total_steps - 1)
draw.rectangle((x0 + 4, y0 + 4, x0 + 126, y0 + 24), fill=(255, 255, 255))
draw.text((x0 + 8, y0 + 7), label, fill=(0, 0, 0))
return np.asarray(pil_img)
def _get_fast_video_writer(out_path, fps):
    """Open an imageio MP4 writer using libx264's ``ultrafast`` preset for quick Space exports.

    ``macro_block_size=16`` is consistent with the 16-pixel padding applied by
    ``_compose_video_frame``. The extra flags (preset, CRF 28, yuv420p pixel
    format, ``+faststart``) are forwarded to ffmpeg verbatim.
    """
    x264_flags = []
    x264_flags += ["-preset", "ultrafast"]
    x264_flags += ["-crf", "28"]
    x264_flags += ["-pix_fmt", "yuv420p"]
    x264_flags += ["-movflags", "+faststart"]
    writer_options = dict(
        fps=float(fps),
        codec="libx264",
        macro_block_size=16,
        ffmpeg_params=x264_flags,
    )
    return imageio.get_writer(out_path, **writer_options)
def build_current_trajectory_video(preset_name, custom_repo_id, custom_filename, traj_id, image_keys, display_scale, reverse_channels, fps, valid_window_len, video_stride=4):
    """Export one trajectory as an MP4 and return ``(path, status_text)``.

    Every ``video_stride``-th timestep becomes a frame, and the last timestep
    is always included. Each frame shows the selected observation images on
    top and, below, the static no_teacher_action / valid-start plot with a
    moving cursor.

    Output file handling:
      * The name encodes dataset, trajectory, fps and stride, plus a short hash
        of the remaining render settings (image keys, display scale, channel
        reversal, valid-window length, exact fps). Exports with different
        settings therefore no longer overwrite each other.
      * Frames go to a unique temporary file, which is atomically moved into
        place only after the writer closes. A failed or concurrent export
        never leaves a half-written MP4 at the returned path, and the temp file
        is removed on failure.

    Returns ``(None, message)`` when imageio is missing or no trajectory can be
    loaded. Errors raised while encoding propagate to the caller.
    """
    if imageio is None:
        return None, "Video export requires imageio and imageio-ffmpeg in requirements.txt."
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return None, "No trajectories found."
    traj_id = int(np.clip(int(traj_id), 0, n_traj - 1))
    traj = load_traj(repo_id, filename, traj_id)
    if not traj:
        return None, "Trajectory could not be loaded."
    if image_keys is None:
        image_keys = []
    if isinstance(image_keys, str):
        image_keys = [image_keys]
    image_keys_tuple = tuple(image_keys)
    video_stride = int(max(1, int(video_stride)))
    frame_indices = list(range(0, len(traj), video_stride))
    if frame_indices and frame_indices[-1] != len(traj) - 1:
        frame_indices.append(len(traj) - 1)

    safe_repo = re.sub(r"[^A-Za-z0-9_.-]+", "_", repo_id)
    safe_file = re.sub(r"[^A-Za-z0-9_.-]+", "_", filename)[-80:]
    # Settings that change frame content but are not spelled out in the name.
    settings_key = repr((
        image_keys_tuple,
        float(display_scale),
        bool(reverse_channels),
        int(valid_window_len),
        float(fps),
    ))
    settings_tag = hashlib.sha1(settings_key.encode("utf-8")).hexdigest()[:10]
    out_dir = tempfile.gettempdir()
    out_path = os.path.join(
        out_dir,
        "trajectory_{}_{}_traj{:04d}_fps{}_stride{}_{}.mp4".format(
            safe_repo, safe_file, traj_id, int(fps), video_stride, settings_tag
        ),
    )

    # Build the static status plot once; during export only the cursor is drawn.
    status_base, status_bounds, total_steps = get_video_status_plot_base(
        repo_id,
        filename,
        traj_id,
        int(valid_window_len),
    )
    fd, tmp_path = tempfile.mkstemp(prefix="trajectory_", suffix=".mp4", dir=out_dir)
    os.close(fd)
    written = 0
    try:
        writer = _get_fast_video_writer(tmp_path, fps)
        try:
            for t in frame_indices:
                # Cached image extraction for correctness; status frames are
                # drawn uncached to limit memory use on long trajectories.
                gallery_items, _warnings = get_cached_gallery_items(
                    repo_id,
                    filename,
                    traj_id,
                    t,
                    image_keys_tuple,
                    float(display_scale),
                    bool(reverse_channels),
                )
                label = "trajectory {} | frame {}/{}".format(traj_id, t, len(traj) - 1)
                status_plot = _draw_status_cursor_on_base(status_base, status_bounds, total_steps, t)
                frame = _compose_video_frame(gallery_items, label, status_plot=status_plot)
                writer.append_data(frame)
                written += 1
        finally:
            writer.close()
        os.replace(tmp_path, out_path)
    finally:
        # After a successful os.replace the temp path is gone; on failure this
        # removes the partial file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    approx_seconds = float(written) / float(max(float(fps), 1.0))
    status = "Built trajectory video with optimized encoder and status rendering"
    status += "\nTrajectory: {}".format(traj_id)
    status += "\nOriginal timesteps: {} | Written frames: {} | Stride: {}".format(len(traj), written, video_stride)
    status += "\nFPS: {} | Approx video duration: {:.1f}s".format(fps, approx_seconds)
    status += "\nValid-window length: {}".format(int(valid_window_len))
    status += "\nSpeedups: x264 ultrafast preset; static status plot rendered once; cursor drawn with NumPy/PIL"
    return out_path, status
def get_available_image_keys(repo_id, filename, traj_id):
    """List observation keys of a trajectory's first step that look like images.

    ``traj_id`` is clamped into range. Keys from ``PREFERRED_IMAGE_KEYS`` come
    first, in that list's order, followed by the remaining detected keys in
    observation order. Keys whose detection raises are silently skipped
    (best effort). Returns [] when nothing can be loaded.
    """
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return []
    traj = load_traj(repo_id, filename, int(np.clip(int(traj_id), 0, n_traj - 1)))
    if not traj:
        return []

    def is_image(name, value):
        try:
            return bool(_looks_like_image_array(name, value))
        except Exception:
            return False

    detected = [name for name, value in traj[0].get("obs", {}).items() if is_image(name, value)]
    preferred = [name for name in PREFERRED_IMAGE_KEYS if name in detected]
    return preferred + [name for name in detected if name not in preferred]
def update_custom_visibility(preset_name):
    """Show the custom repo_id / filename textboxes only when the "Custom" preset is selected."""
    visible = preset_name == "Custom"
    return tuple(gr.update(visible=visible) for _ in range(2))
def update_after_dataset_change(preset_name, custom_repo_id, custom_filename):
    """Reset the UI after a dataset change.

    Returns five Gradio updates, in order:
      1. the trajectory slider;
      2. the timestep slider;
      3. the image-key checkboxes (first two keys pre-selected);
      4. the dataset status text;
      5. the reverse-channels checkbox, set to the preset default.
    Slider maxima are kept >= 1, and both sliders reset to 0.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    reverse_default = get_default_reverse_channels(preset_name)
    status = "\n".join([
        "Loaded `{}` / `{}`".format(repo_id, filename),
        "Detected trajectories: {}".format(n_traj),
        "reverse_channels default: {}".format(int(reverse_default)),
    ])
    if n_traj == 0:
        traj_update = gr.update(maximum=1, value=0)
        timestep_update = gr.update(maximum=1, value=0)
        keys_update = gr.update(choices=[], value=[])
    else:
        keys = get_available_image_keys(repo_id, filename, 0)
        first_traj = load_traj(repo_id, filename, 0)
        traj_update = gr.update(maximum=max(n_traj - 1, 1), value=0)
        timestep_update = gr.update(maximum=max(len(first_traj) - 1, 1), value=0)
        keys_update = gr.update(choices=keys, value=keys[:2])
    return traj_update, timestep_update, keys_update, status, gr.update(value=reverse_default)
def update_after_traj_change(preset_name, custom_repo_id, custom_filename, traj_id):
    """Reset the timestep slider and image-key choices when the trajectory changes.

    Returns ``(timestep_slider_update, image_keys_update)``. The slider maximum
    is kept >= 1, and the first two detected image keys are pre-selected.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    if not n_traj:
        return gr.update(maximum=1, value=0), gr.update(choices=[], value=[])
    index = int(np.clip(int(traj_id), 0, n_traj - 1))
    n_steps = len(load_traj(repo_id, filename, index))
    keys = get_available_image_keys(repo_id, filename, index)
    return (
        gr.update(maximum=max(n_steps - 1, 1), value=0),
        gr.update(choices=keys, value=keys[:2]),
    )
def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep, image_keys, chunk_len, display_scale, reverse_channels):
    """Render the viewer for one timestep.

    ``traj_id`` and ``timestep`` are clamped into range. ``chunk_len`` is used
    as the valid-window length for the status plot.

    Returns:
        ``(gallery_items, status_plot, info_text)``. On load failure the
        result is ``([], None, message)``.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return [], None, "No trajectory groups found. Open Debug: HDF5 tree."
    traj_index = int(np.clip(int(traj_id), 0, n_traj - 1))
    traj = load_traj(repo_id, filename, traj_index)
    if not traj:
        return [], None, "Trajectory could not be loaded. Open Debug: HDF5 tree."
    t = int(np.clip(int(timestep), 0, len(traj) - 1))
    window = int(chunk_len)
    scale = float(display_scale)
    if image_keys is None:
        selected = []
    elif isinstance(image_keys, str):
        selected = [image_keys]
    else:
        selected = image_keys

    step = traj[t]
    gallery_items, warnings_tuple = get_cached_gallery_items(
        repo_id, filename, traj_index, t, tuple(selected), scale, bool(reverse_channels)
    )
    status_plot, is_valid_start, num_valid_starts = get_cached_status_plot(repo_id, filename, traj_index, t, window)

    obs = step.get("obs", {})
    tensor_lines = []
    for key in selected:
        if key not in obs:
            continue
        arr = np.asarray(obs[key])
        tensor_lines.append("{} shape={} dtype={}".format(key, tuple(arr.shape), arr.dtype))

    def flag(name):
        return int(bool(step.get(name, False)))

    info_lines = [
        "dataset: {} / {}".format(repo_id, filename),
        "detected trajectories: {}".format(n_traj),
        "trajectory: {}".format(traj_index),
        "episode_id: {}".format(step.get("episode_id", "")),
        "timestep: {} / {}".format(t, len(traj) - 1),
        "saved timestep: {}".format(step.get("timestep", t)),
        "done: {}".format(flag("done")),
        "if_success: {}".format(flag("if_success")),
        "no_teacher_action: {}".format(flag("no_teacher_action")),
        "no_robot_action: {}".format(flag("no_robot_action")),
        "valid-window length: {}".format(window),
        "valid_start: {}".format(int(bool(is_valid_start))),
        "num_valid_starts: {}".format(num_valid_starts),
        "",
        "teacher_action: {}".format(_safe_array_str(step.get("teacher_action", []))),
        "robot_action: {}".format(_safe_array_str(step.get("robot_action", []))),
        "",
        "selected image tensors:",
    ]
    info_lines.extend(tensor_lines)
    if warnings_tuple:
        info_lines.extend(["", "Image warnings:"])
        info_lines.extend(warnings_tuple)
    return gallery_items, status_plot, "\n".join(info_lines)
def build_app():
    """Assemble the Gradio Blocks UI for browsing TrajectoryBuffer-style HDF5 files.

    The default preset is probed once while the layout is built, so that the
    trajectory slider range and the image-key checkboxes start populated.
    Any exception raised by that probe is caught and shown as a startup
    warning rather than aborting app construction. Every event handler is
    wired here. The dataset state is refreshed again (and a frame rendered)
    on ``demo.load``.

    Returns:
        gr.Blocks: the assembled, not-yet-launched app.
    """
    repo_id, filename = resolve_dataset(DEFAULT_PRESET)
    try:
        n_traj = get_num_trajectories(repo_id, filename)
        first_keys = get_available_image_keys(repo_id, filename, 0) if n_traj else []
        startup_warning = ""
    except Exception as exc:
        # Deliberate best-effort probe: keep the UI buildable (the user can
        # still pick another preset) even if the default dataset fails here.
        n_traj = 0
        first_keys = []
        startup_warning = repr(exc)
    default_status = "Loaded default dataset\nDetected trajectories: {}\nreverse_channels default: {}".format(n_traj, int(get_default_reverse_channels(DEFAULT_PRESET)))
    with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
        gr.Markdown(
            "# HDF5 Trajectory Viewer\n\n"
            "Standalone viewer for TrajectoryBuffer-style HDF5 datasets on Hugging Face.\n\n"
            "The status plot matches the local labeling view: orange `no_teacher_action`, green valid-start markers, and a black timestep cursor."
        )
        if startup_warning:
            gr.Markdown("Startup warning: `{}`".format(startup_warning))
        with gr.Row():
            preset = gr.Dropdown(
                choices=list(DATASET_PRESETS.keys()) + ["Custom"],
                value=DEFAULT_PRESET,
                label="Dataset preset",
            )
            # Only shown when the "Custom" preset is selected
            # (toggled by update_custom_visibility).
            custom_repo_id = gr.Textbox(value="", label="Custom repo_id, e.g. Zhaoting123/InsertT", visible=False)
            custom_filename = gr.Textbox(value="", label="Custom HDF5 path in repo", visible=False)
        dataset_status = gr.Textbox(label="Dataset status", lines=2, value=default_status, interactive=False)
        with gr.Row():
            # max(..., 1) keeps min < max even when 0 or 1 trajectories were
            # detected; the real range is set by update_after_dataset_change.
            traj_slider = gr.Slider(minimum=0, maximum=max(n_traj - 1, 1), value=0, step=1, label="Trajectory index")
            timestep_slider = gr.Slider(minimum=0, maximum=1, value=0, step=1, label="Timestep")
        with gr.Row():
            image_keys = gr.CheckboxGroup(choices=first_keys, value=first_keys[:2], label="Image keys")
            chunk_len = gr.Slider(minimum=1, maximum=64, value=DEFAULT_CHUNK_LEN, step=1, label="Valid-window length")
            # Not user-editable; held in State so it can be passed to callbacks.
            display_scale = gr.State(value=DEFAULT_DISPLAY_SCALE)
            reverse_channels = gr.Checkbox(value=get_default_reverse_channels(DEFAULT_PRESET), label="Reverse channels BGR↔RGB")
        with gr.Row():
            render_btn = gr.Button("Render frame", variant="primary")
            preload_btn = gr.Button("Preload current trajectory")
            video_btn = gr.Button("Build trajectory video")
        video_fps = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Video FPS")
        video_stride = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Video frame stride")
        preload_status = gr.Textbox(label="Preload / video status", lines=4, value="Not preloaded yet.", interactive=False)
        with gr.Row():
            with gr.Column(scale=3):
                gallery = gr.Gallery(
                    label="Camera images",
                    columns=2,
                    height=360,
                    object_fit="contain",
                )
            with gr.Column(scale=2):
                status_plot = gr.Image(
                    label="no_teacher_action + valid starts",
                    type="numpy",
                    height=360,
                )
        trajectory_video = gr.Video(label="Trajectory video: smooth browser-side playback")
        info = gr.Textbox(label="Frame info", lines=16)
        with gr.Accordion("Debug: HDF5 tree", open=False):
            inspect_btn = gr.Button("Inspect HDF5 structure")
            hdf5_tree = gr.Textbox(lines=24, label="HDF5 tree")

        # Shared argument/target lists. Defining them once guarantees every
        # trigger passes the same components to the same callback.
        dataset_inputs = [preset, custom_repo_id, custom_filename]
        dataset_outputs = [traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels]
        render_inputs = [
            preset, custom_repo_id, custom_filename, traj_slider, timestep_slider,
            image_keys, chunk_len, display_scale, reverse_channels,
        ]
        render_outputs = [gallery, status_plot, info]

        preset.change(
            fn=update_custom_visibility,
            inputs=preset,
            outputs=[custom_repo_id, custom_filename],
        ).then(
            fn=update_after_dataset_change,
            inputs=dataset_inputs,
            outputs=dataset_outputs,
        ).then(
            fn=render_frame,
            inputs=render_inputs,
            outputs=render_outputs,
        )
        # Submitting a custom repo/path switches datasets just like a preset
        # change, so it must re-render too. Otherwise the previous dataset's
        # frame stays visible whenever the slider/key values happen not to
        # change (and thus no downstream .change event fires).
        for custom_box in (custom_repo_id, custom_filename):
            custom_box.submit(
                fn=update_after_dataset_change,
                inputs=dataset_inputs,
                outputs=dataset_outputs,
            ).then(
                fn=render_frame,
                inputs=render_inputs,
                outputs=render_outputs,
            )
        traj_slider.change(
            fn=update_after_traj_change,
            inputs=[preset, custom_repo_id, custom_filename, traj_slider],
            outputs=[timestep_slider, image_keys],
        ).then(
            fn=render_frame,
            inputs=render_inputs,
            outputs=render_outputs,
        )
        # .release (not .change) so dragging the timestep slider renders once
        # when the user lets go instead of on every intermediate value.
        timestep_slider.release(
            fn=render_frame,
            inputs=render_inputs,
            outputs=render_outputs,
        )
        for widget in [image_keys, chunk_len, reverse_channels]:
            widget.change(
                fn=render_frame,
                inputs=render_inputs,
                outputs=render_outputs,
            )
        render_btn.click(
            fn=render_frame,
            inputs=render_inputs,
            outputs=render_outputs,
        )
        preload_btn.click(
            fn=preload_current_trajectory,
            inputs=[preset, custom_repo_id, custom_filename, traj_slider, image_keys, chunk_len, display_scale, reverse_channels],
            outputs=preload_status,
        )
        video_btn.click(
            fn=build_current_trajectory_video,
            inputs=[preset, custom_repo_id, custom_filename, traj_slider, image_keys, display_scale, reverse_channels, video_fps, chunk_len, video_stride],
            outputs=[trajectory_video, preload_status],
        )
        inspect_btn.click(
            fn=inspect_hdf5_tree,
            inputs=dataset_inputs,
            outputs=hdf5_tree,
        )
        # Refresh dataset-dependent widgets for each new browser session, then
        # draw the first frame.
        demo.load(
            fn=update_after_dataset_change,
            inputs=dataset_inputs,
            outputs=dataset_outputs,
        ).then(
            fn=render_frame,
            inputs=render_inputs,
            outputs=render_outputs,
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI first, then start the server.
    demo = build_app()
    # "0.0.0.0" binds every network interface. The PORT environment
    # variable, when set, overrides the default port 7860. share=False means
    # no public Gradio share link is requested, and ssr_mode=False turns off
    # Gradio's server-side rendering.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": int(os.environ.get("PORT", 7860)),
        "share": False,
        "ssr_mode": False,
    }
    demo.launch(**launch_kwargs)