Zhaoting123 commited on
Commit
7927dd6
·
verified ·
1 Parent(s): 9baa8f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +455 -310
app.py CHANGED
@@ -1,32 +1,33 @@
1
  """
2
- Standalone Hugging Face Space viewer for HDF5 trajectory datasets.
3
 
4
- This version does NOT require your local TrajectoryBuffer class.
5
- It reads the HDF5 file directly with h5py.
6
-
7
- Files needed in the Space:
8
- app.py
9
- requirements.txt
 
10
 
11
  requirements.txt:
12
- gradio
13
- huggingface_hub
14
- h5py
15
- numpy
16
- pillow
17
- matplotlib
18
 
19
  Optional:
20
- opencv-python-headless
21
  """
22
 
23
- import os
24
  import re
25
  from functools import lru_cache
26
 
27
  import gradio as gr
28
  import h5py
29
  import matplotlib
 
30
  matplotlib.use("Agg")
31
  import matplotlib.pyplot as plt
32
  import numpy as np
@@ -40,8 +41,7 @@ except Exception:
40
 
41
 
42
  # -----------------------------------------------------------------------------
43
- # Dataset presets.
44
- # The same Space can visualize multiple HDF5 files by changing repo_id + filename.
45
  # -----------------------------------------------------------------------------
46
  DATASET_PRESETS = {
47
  "Robosuite Square 20260409": {
@@ -67,36 +67,35 @@ DATASET_PRESETS = {
67
  DEFAULT_PRESET = "Robosuite Square 20260409"
68
  REPO_TYPE = "dataset"
69
  DEFAULT_CHUNK_LEN = 16
 
70
  PREFERRED_IMAGE_KEYS = [
71
  "image1",
72
  "image2",
73
  "agentview_image",
74
  "robot0_eye_in_hand_image",
 
 
75
  ]
76
 
 
 
77
 
78
  # -----------------------------------------------------------------------------
79
- # HDF5 helpers
80
  # -----------------------------------------------------------------------------
81
- def _clear_dataset_caches():
82
- get_local_hdf5_path.cache_clear()
83
- get_trajectory_keys.cache_clear()
84
- get_num_trajectories.cache_clear()
85
- load_traj.cache_clear()
86
-
87
-
88
  def resolve_dataset(preset_name, custom_repo_id=None, custom_filename=None):
89
- """Return (repo_id, filename) from a preset or custom fields."""
90
  preset_name = preset_name or DEFAULT_PRESET
 
91
  if preset_name == "Custom":
92
  repo_id = str(custom_repo_id or "").strip()
93
  filename = str(custom_filename or "").strip()
94
  if not repo_id or not filename:
95
- raise ValueError("For Custom, provide both repo_id and HDF5 filename/path.")
96
  return repo_id, filename
97
 
98
  if preset_name not in DATASET_PRESETS:
99
  preset_name = DEFAULT_PRESET
 
100
  item = DATASET_PRESETS[preset_name]
101
  return item["repo_id"], item["filename"]
102
 
@@ -111,75 +110,92 @@ def get_local_hdf5_path(repo_id, filename):
111
 
112
 
113
  def _natural_sort_key(name):
114
- m = re.search(r"([0-9]+)$", str(name))
115
- return (0, int(m.group(1))) if m else (1, str(name))
 
 
116
 
117
 
118
  @lru_cache(maxsize=8)
119
  def get_trajectory_keys(repo_id, filename):
120
- """Detect trajectory groups in common HDF5 layouts."""
121
  path = get_local_hdf5_path(repo_id, filename)
 
122
  with h5py.File(path, "r") as f:
123
- # Your TrajectoryBuffer saves root-level groups:
124
  # /episode_0000
125
  # /episode_0001
126
- # ...
127
- # Some other robotics datasets use /data/demo_0, so keep that fallback.
128
  root_episode_keys = [
129
- k for k in f.keys()
130
- if isinstance(f[k], h5py.Group) and str(k).startswith("episode_")
 
131
  ]
132
  if root_episode_keys:
133
- group = f
134
- prefix = ""
135
- group_keys = root_episode_keys
136
- elif "data" in f and isinstance(f["data"], h5py.Group):
137
- group = f["data"]
138
- prefix = "data"
139
- group_keys = [k for k in group.keys() if isinstance(group[k], h5py.Group)]
140
- else:
141
- group = f
142
- prefix = ""
143
- group_keys = [k for k in group.keys() if isinstance(group[k], h5py.Group)]
144
-
145
- group_keys = sorted(group_keys, key=_natural_sort_key)
146
- return tuple(f"{prefix}/{k}" if prefix else k for k in group_keys)
 
 
 
 
 
 
 
147
 
148
 
149
  @lru_cache(maxsize=8)
150
  def get_num_trajectories(repo_id, filename):
151
- return max(len(get_trajectory_keys(repo_id, filename)), 1)
152
 
153
 
154
- def inspect_hdf5_tree(preset_name, custom_repo_id, custom_filename, max_lines=160):
155
- """Show the HDF5 tree for debugging inside the Space."""
156
  repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
157
  path = get_local_hdf5_path(repo_id, filename)
 
158
  lines = []
159
  with h5py.File(path, "r") as f:
160
  def visitor(name, obj):
161
  if len(lines) >= max_lines:
162
  return
163
  if isinstance(obj, h5py.Dataset):
164
- lines.append(f"DATASET {name} shape={obj.shape} dtype={obj.dtype}")
 
 
165
  elif isinstance(obj, h5py.Group):
166
- lines.append(f"GROUP {name}")
 
167
  f.visititems(visitor)
168
 
169
  if len(lines) >= max_lines:
170
  lines.append("...")
171
- return "\n".join(lines) if lines else "No HDF5 contents found."
 
 
 
172
 
173
 
174
- def _read_dataset_value(ds):
175
- value = ds[()]
 
 
 
176
  if isinstance(value, bytes):
177
  return value.decode("utf-8")
178
  return value
179
 
180
 
181
  def _read_group_recursive(group):
182
- """Read a group into nested dictionaries of numpy arrays."""
183
  out = {}
184
  for key, obj in group.items():
185
  if isinstance(obj, h5py.Dataset):
@@ -189,239 +205,266 @@ def _read_group_recursive(group):
189
  return out
190
 
191
 
192
- def _find_first_existing_key(mapping, candidates):
193
- for key in candidates:
194
  if key in mapping:
195
  return key
196
  return None
197
 
198
 
199
- def _maybe_time_slice(value, t, T):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  arr = np.asarray(value)
201
  if arr.ndim >= 1 and arr.shape[0] == T:
202
  return arr[t]
203
  return arr
204
 
205
 
206
- def _infer_time_length(data):
207
- """Infer T from datasets whose first dimension is time."""
208
- candidate_lengths = []
209
-
210
- def collect(obj):
211
- if isinstance(obj, dict):
212
- for v in obj.values():
213
- collect(v)
214
- else:
215
- arr = np.asarray(obj)
216
- if arr.ndim >= 1 and arr.shape[0] > 1:
217
- candidate_lengths.append(int(arr.shape[0]))
218
-
219
- collect(data)
220
- if not candidate_lengths:
221
- return 1
222
-
223
- # The trajectory length should usually be the most common large first dim.
224
- values, counts = np.unique(candidate_lengths, return_counts=True)
225
- return int(values[np.argmax(counts)])
226
-
227
-
228
  @lru_cache(maxsize=64)
229
  def load_traj(repo_id, filename, traj_id):
230
- """Load one trajectory as a list of step dictionaries.
231
-
232
- Output step format:
233
- {
234
- "timestep": int,
235
- "obs": dict,
236
- "teacher_action": np.ndarray,
237
- "robot_action": np.ndarray,
238
- "no_teacher_action": bool,
239
- "no_robot_action": bool,
240
- }
241
- """
242
- path = get_local_hdf5_path(repo_id, filename)
243
  traj_keys = get_trajectory_keys(repo_id, filename)
244
  if not traj_keys:
245
  return []
246
 
247
  traj_id = int(np.clip(int(traj_id), 0, len(traj_keys) - 1))
248
  traj_key = traj_keys[traj_id]
 
249
 
250
  with h5py.File(path, "r") as f:
251
- g = f[traj_key]
252
- data = _read_group_recursive(g)
253
- attrs = dict(g.attrs)
254
-
255
- # Case A: trajectory group contains step groups: step_0, step_1, ...
256
- step_group_keys = [
257
- k for k, v in data.items()
258
- if isinstance(v, dict) and (str(k).startswith("step") or str(k).isdigit())
259
- ]
260
- if step_group_keys:
261
- traj = []
262
- for step_key in sorted(step_group_keys, key=_natural_sort_key):
263
- step = data[step_key]
264
- obs = step.get("obs", {}) if isinstance(step.get("obs", {}), dict) else {}
265
- teacher_action = step.get("teacher_action", step.get("teacher_actions", step.get("action", step.get("actions", np.zeros(1, dtype=np.float32)))))
266
- robot_action = step.get("robot_action", step.get("robot_actions", step.get("action", step.get("actions", teacher_action))))
267
- traj.append({
268
- "timestep": int(step.get("timestep", len(traj))),
269
- "obs": obs,
270
- "teacher_action": np.asarray(teacher_action),
271
- "robot_action": np.asarray(robot_action),
272
- "no_teacher_action": bool(np.asarray(step.get("no_teacher_action", step.get("no_teacher_actions", False))).reshape(-1)[0]),
273
- "no_robot_action": bool(np.asarray(step.get("no_robot_action", step.get("no_robot_actions", False))).reshape(-1)[0]),
274
- })
275
- return traj
276
-
277
- # Case B: trajectory group contains array datasets with first dimension T.
278
- # Your TrajectoryBuffer layout is:
279
- # /episode_0000/observation/<image_or_state_key>[T,...]
280
- # /episode_0000/robot_actions[T,D]
281
- # /episode_0000/teacher_actions[T,D]
282
- # /episode_0000/no_teacher_actions[T]
283
- # /episode_0000/no_robot_actions[T]
284
- #
285
- # Keep obs/action aliases for compatibility with other layouts.
286
  T = _infer_time_length(data)
287
- obs_all = {}
288
- if isinstance(data.get("observation", {}), dict):
289
- obs_all = data.get("observation", {})
290
- elif isinstance(data.get("obs", {}), dict):
291
- obs_all = data.get("obs", {})
292
-
293
- action_key = _find_first_existing_key(data, ["actions", "action"])
294
- teacher_key = _find_first_existing_key(data, ["teacher_actions", "teacher_action"])
295
- robot_key = _find_first_existing_key(data, ["robot_actions", "robot_action"])
296
- no_teacher_key = _find_first_existing_key(data, ["no_teacher_actions", "no_teacher_action"])
297
- no_robot_key = _find_first_existing_key(data, ["no_robot_actions", "no_robot_action"])
 
 
 
 
 
298
 
299
  traj = []
300
  for t in range(T):
301
  obs_t = {}
302
  for key, value in obs_all.items():
303
- obs_t[key] = _maybe_time_slice(value, t, T)
304
 
305
  default_action = np.zeros(1, dtype=np.float32)
306
  if action_key is not None:
307
- default_action = _maybe_time_slice(data[action_key], t, T)
308
-
309
- teacher_action = _maybe_time_slice(data[teacher_key], t, T) if teacher_key is not None else default_action
310
- robot_action = _maybe_time_slice(data[robot_key], t, T) if robot_key is not None else default_action
311
- no_teacher = _maybe_time_slice(data[no_teacher_key], t, T) if no_teacher_key is not None else False
312
- no_robot = _maybe_time_slice(data[no_robot_key], t, T) if no_robot_key is not None else False
313
-
314
- traj.append({
315
- "timestep": t,
316
- "obs": obs_t,
317
- "teacher_action": np.asarray(teacher_action),
318
- "robot_action": np.asarray(robot_action),
319
- "no_teacher_action": bool(np.asarray(no_teacher).reshape(-1)[0]),
320
- "no_robot_action": bool(np.asarray(no_robot).reshape(-1)[0]),
321
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
  return traj
324
 
325
 
326
  # -----------------------------------------------------------------------------
327
- # Visualization helpers
328
  # -----------------------------------------------------------------------------
329
  def _extract_latest_obs_value(value):
330
  arr = np.asarray(value)
331
- # Your local script handled stacked recent observations by taking the latest.
332
- if arr.ndim >= 1 and arr.shape[0] in (1, 2):
333
  return arr[-1]
334
  return arr
335
 
336
 
337
- def _extract_display_image(value, reverse_channels=False, output_uint8=True):
338
- img = _extract_latest_obs_value(value)
339
- img = np.asarray(img)
 
340
 
341
- # Your saved image shape can be [obs_T, C, H, W] per timestep.
342
- # Take the latest stacked observation, then convert CHW -> HWC.
343
- if img.ndim == 4 and img.shape[0] in (1, 2, 3, 4):
344
- img = img[-1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
  if img.ndim == 2:
347
  img = np.repeat(img[..., None], 3, axis=-1)
348
  elif img.ndim == 3 and img.shape[0] in (1, 3, 4):
349
  img = np.transpose(img, (1, 2, 0))
350
 
351
- if img.ndim == 3 and img.shape[-1] == 4:
 
 
352
  img = img[..., :3]
353
 
354
  if img.ndim != 3:
355
- raise ValueError(f"Unsupported image shape: {img.shape}")
356
 
357
  if img.dtype == np.uint8:
358
- img_rgb = img.copy()
359
  else:
360
- img_rgb = img.astype(np.float32)
361
- if np.nanmin(img_rgb) < 0:
362
- img_rgb = (img_rgb + 1.0) / 2.0
363
- if np.nanmax(img_rgb) > 1.5:
364
- img_rgb = img_rgb / 255.0
365
- img_rgb = np.clip(img_rgb, 0.0, 1.0)
366
- if output_uint8:
367
- img_rgb = np.round(img_rgb * 255.0).astype(np.uint8)
368
 
369
- if reverse_channels and img_rgb.shape[-1] == 3:
370
- img_rgb = img_rgb[..., ::-1]
 
 
371
 
372
- return img_rgb
373
 
374
 
375
- def _resize_image_for_display(img, display_scale=4):
376
- if display_scale is None or float(display_scale) == 1.0:
 
377
  return img
378
 
379
- display_scale = float(display_scale)
380
  h, w = img.shape[:2]
381
- new_w = max(1, int(round(w * display_scale)))
382
- new_h = max(1, int(round(h * display_scale)))
383
 
384
  if cv2 is not None:
385
- return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
386
 
387
  pil_img = Image.fromarray(img)
388
- return np.asarray(pil_img.resize((new_w, new_h), resample=Image.Resampling.NEAREST))
389
 
390
 
391
- def _extract_mixed_action_chunk(traj, start_idx, chunk_length=16):
392
  chunk = []
393
  sources = []
394
  end_idx = min(len(traj), int(start_idx) + int(chunk_length))
 
395
  for idx in range(int(start_idx), end_idx):
396
  step = traj[idx]
397
  use_teacher = not bool(step.get("no_teacher_action", False))
398
  action = step["teacher_action"] if use_teacher else step["robot_action"]
399
  chunk.append(np.asarray(action, dtype=np.float32).reshape(-1))
400
  sources.append("T" if use_teacher else "R")
 
401
  if not chunk:
402
  return None, ""
 
403
  return np.stack(chunk, axis=0), "".join(sources)
404
 
405
 
406
- def _extract_robot_action_chunk(traj, start_idx, chunk_length=16):
407
  chunk = []
408
  end_idx = min(len(traj), int(start_idx) + int(chunk_length))
 
409
  for idx in range(int(start_idx), end_idx):
410
  step = traj[idx]
411
  chunk.append(np.asarray(step["robot_action"], dtype=np.float32).reshape(-1))
 
412
  if not chunk:
413
  return None
 
414
  return np.stack(chunk, axis=0)
415
 
416
 
417
- def _safe_array_str(x, precision=3, max_items=24):
418
- arr = np.asarray(x).reshape(-1)
419
  shown = arr[:max_items]
420
- suffix = "" if arr.size <= max_items else f" ... +{arr.size - max_items} more"
421
- return np.array2string(shown, precision=precision, separator=", ") + suffix
 
 
422
 
423
 
424
- def _make_action_chunk_plot(mixed_chunk, robot_chunk=None):
425
  if mixed_chunk is None:
426
  return None
427
 
@@ -430,18 +473,24 @@ def _make_action_chunk_plot(mixed_chunk, robot_chunk=None):
430
  mixed_chunk = mixed_chunk[:, None]
431
 
432
  fig, ax = plt.subplots(figsize=(7, 3.2), dpi=140)
433
- t = np.arange(mixed_chunk.shape[0])
434
  max_dims = min(mixed_chunk.shape[1], 10)
435
 
436
- for d in range(max_dims):
437
- ax.plot(t, mixed_chunk[:, d], label=f"mixed[{d}]")
438
 
439
  if robot_chunk is not None:
440
  robot_chunk = np.asarray(robot_chunk, dtype=np.float32)
441
  if robot_chunk.ndim == 1:
442
  robot_chunk = robot_chunk[:, None]
443
- for d in range(min(robot_chunk.shape[1], max_dims)):
444
- ax.plot(t, robot_chunk[:, d], linestyle="--", alpha=0.55, label=f"robot[{d}]")
 
 
 
 
 
 
445
 
446
  ax.set_title("Action chunk")
447
  ax.set_xlabel("chunk step")
@@ -451,144 +500,193 @@ def _make_action_chunk_plot(mixed_chunk, robot_chunk=None):
451
  fig.tight_layout()
452
  fig.canvas.draw()
453
  rgba = np.asarray(fig.canvas.buffer_rgba())
454
- img = rgba[..., :3].copy()
455
  plt.close(fig)
456
- return img
457
 
458
 
 
 
 
459
  def get_available_image_keys(repo_id, filename, traj_id):
460
  n_traj = get_num_trajectories(repo_id, filename)
461
- traj_id = int(np.clip(int(traj_id), 0, max(n_traj - 1, 0)))
 
 
 
462
  traj = load_traj(repo_id, filename, traj_id)
463
  if not traj:
464
  return []
465
 
466
  obs = traj[0].get("obs", {})
467
- image_keys = []
468
  for key, value in obs.items():
469
  try:
470
- arr = np.asarray(_extract_latest_obs_value(value))
471
- key_l = str(key).lower()
472
- key_hint = any(s in key_l for s in ["rgb", "image", "img", "camera", "cam"])
473
- looks_like_shape = (
474
- arr.ndim == 2
475
- or (arr.ndim == 3 and (arr.shape[-1] in (1, 3, 4) or arr.shape[0] in (1, 3, 4)))
476
- or (arr.ndim == 4 and arr.shape[0] in (1, 2, 3, 4) and arr.shape[1] in (1, 3, 4))
477
- )
478
- if key_hint or looks_like_shape:
479
- image_keys.append(key)
480
  except Exception:
481
  pass
482
 
483
- ordered = [k for k in PREFERRED_IMAGE_KEYS if k in image_keys]
484
- ordered += [k for k in image_keys if k not in ordered]
485
  return ordered
486
 
487
 
488
- # -----------------------------------------------------------------------------
489
- # Gradio callbacks
490
- # -----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
491
  def update_after_traj_change(preset_name, custom_repo_id, custom_filename, traj_id):
492
  repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
493
  n_traj = get_num_trajectories(repo_id, filename)
494
- traj_id = int(np.clip(int(traj_id), 0, max(n_traj - 1, 0)))
 
 
 
495
  traj = load_traj(repo_id, filename, traj_id)
496
- image_keys = get_available_image_keys(repo_id, filename, traj_id)
497
- max_step = max(len(traj) - 1, 0)
498
- slider_max = max(max_step, 1) # Gradio requires min < max.
499
  return (
500
- gr.update(maximum=slider_max, value=0),
501
- gr.update(choices=image_keys, value=image_keys[:2]),
502
  )
503
 
504
 
505
- def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep, image_keys, chunk_len, display_scale, reverse_channels):
 
 
 
 
 
 
 
 
 
 
506
  repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
507
  n_traj = get_num_trajectories(repo_id, filename)
508
- traj_id = int(np.clip(int(traj_id), 0, max(n_traj - 1, 0)))
509
- traj = load_traj(repo_id, filename, traj_id)
510
 
 
 
 
 
 
511
  if not traj:
512
- return [], None, "No trajectory could be loaded. Open the HDF5 debug panel to inspect the file layout."
513
 
514
  timestep = int(np.clip(int(timestep), 0, len(traj) - 1))
515
  chunk_len = int(chunk_len)
516
  display_scale = float(display_scale)
517
- step = traj[timestep]
518
- obs = step.get("obs", {})
519
 
520
  if image_keys is None:
521
  image_keys = []
522
  if isinstance(image_keys, str):
523
  image_keys = [image_keys]
524
 
525
- gallery = []
526
- errors = []
 
 
 
527
  for key in image_keys:
528
  if key not in obs:
529
- errors.append(f"Missing image key: {key}")
530
  continue
531
  try:
532
- img = _extract_display_image(obs[key], reverse_channels=bool(reverse_channels), output_uint8=True)
533
- img = _resize_image_for_display(img, display_scale=display_scale)
534
- gallery.append((img, key))
 
 
 
535
  except Exception as exc:
536
- errors.append(f"{key}: {exc}")
537
 
538
- mixed_chunk, chunk_sources = _extract_mixed_action_chunk(traj, timestep, chunk_length=chunk_len)
539
- robot_chunk = _extract_robot_action_chunk(traj, timestep, chunk_length=chunk_len)
540
  action_plot = _make_action_chunk_plot(mixed_chunk, robot_chunk)
541
 
542
- teacher_action = step.get("teacher_action", np.zeros(1, dtype=np.float32))
543
- robot_action = step.get("robot_action", np.zeros(1, dtype=np.float32))
544
- no_teacher = bool(step.get("no_teacher_action", False))
545
- no_robot = bool(step.get("no_robot_action", False))
546
-
547
  info_lines = [
548
- f"detected trajectories: {n_traj}",
549
- f"trajectory: {traj_id}",
550
- f"timestep: {timestep} / {len(traj) - 1}",
551
- f"no_teacher_action: {int(no_teacher)}",
552
- f"no_robot_action: {int(no_robot)}",
553
- f"chunk_len: {chunk_len}",
554
- f"chunk source mask: {chunk_sources} (T=teacher, R=robot fallback)",
 
 
 
 
 
555
  "",
556
- f"teacher_action: {_safe_array_str(teacher_action)}",
557
- f"robot_action: {_safe_array_str(robot_action)}",
558
  ]
559
 
560
- if errors:
561
- info_lines += ["", "Image warnings:"] + errors
 
 
562
 
563
- return gallery, action_plot, "\n".join(info_lines)
564
 
565
 
566
  # -----------------------------------------------------------------------------
567
  # App
568
  # -----------------------------------------------------------------------------
569
  def build_app():
 
 
570
  try:
571
- repo_id, filename = resolve_dataset(DEFAULT_PRESET)
572
  n_traj = get_num_trajectories(repo_id, filename)
573
- first_keys = get_available_image_keys(repo_id, filename, 0)
574
- startup_error = None
575
  except Exception as exc:
576
- n_traj = 1
577
  first_keys = []
578
- startup_error = repr(exc)
 
 
 
579
 
580
  with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
581
  gr.Markdown(
582
- "# HDF5 Trajectory Viewer\n"
583
- "Standalone viewer: no local `TrajectoryBuffer` dependency.\n\n"
584
- f"Default dataset detected trajectories: **{n_traj}**"
585
  )
586
 
587
- if startup_error is not None:
588
- gr.Markdown(
589
- "⚠️ **Startup warning**\n\n"
590
- f"```text\n{startup_error}\n```"
591
- )
592
 
593
  with gr.Row():
594
  preset = gr.Dropdown(
@@ -607,6 +705,13 @@ def build_app():
607
  visible=False,
608
  )
609
 
 
 
 
 
 
 
 
610
  with gr.Row():
611
  traj_slider = gr.Slider(
612
  minimum=0,
@@ -657,37 +762,13 @@ def build_app():
657
  object_fit="contain",
658
  )
659
  action_plot = gr.Image(label="Action chunk plot", type="numpy")
660
- info = gr.Textbox(label="Frame info", lines=13)
661
 
662
  with gr.Accordion("Debug: HDF5 tree", open=False):
663
  inspect_btn = gr.Button("Inspect HDF5 structure")
664
- hdf5_tree = gr.Textbox(lines=22, label="HDF5 tree")
665
- inspect_btn.click(
666
- fn=inspect_hdf5_tree,
667
- inputs=[preset, custom_repo_id, custom_filename],
668
- outputs=hdf5_tree,
669
- )
670
-
671
- def update_custom_visibility(preset_name):
672
- visible = preset_name == "Custom"
673
- return gr.update(visible=visible), gr.update(visible=visible)
674
-
675
- def update_after_dataset_change(preset_name, custom_repo_id, custom_filename):
676
- repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
677
- n = get_num_trajectories(repo_id, filename)
678
- keys = get_available_image_keys(repo_id, filename, 0)
679
- traj = load_traj(repo_id, filename, 0)
680
- status_text = "Loaded `{}` / `{}`".format(repo_id, filename)
681
- status_text = status_text + chr(10) + "Detected trajectories: {}".format(n)
682
- return (
683
- gr.update(maximum=max(n - 1, 1), value=0),
684
- gr.update(maximum=max(len(traj) - 1, 1), value=0),
685
- gr.update(choices=keys, value=keys[:2]),
686
- status_text,
687
- )
688
- dataset_status = gr.Textbox(label="Dataset status", lines=2, value=f"Loaded default dataset
689
- Detected trajectories: {n_traj}")
690
 
 
691
  preset.change(
692
  fn=update_custom_visibility,
693
  inputs=preset,
@@ -698,7 +779,17 @@ Detected trajectories: {n_traj}")
698
  outputs=[traj_slider, timestep_slider, image_keys, dataset_status],
699
  ).then(
700
  fn=render_frame,
701
- inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
 
 
 
 
 
 
 
 
 
 
702
  outputs=[gallery, action_plot, info],
703
  )
704
 
@@ -719,39 +810,93 @@ Detected trajectories: {n_traj}")
719
  outputs=[timestep_slider, image_keys],
720
  ).then(
721
  fn=render_frame,
722
- inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
 
 
 
 
 
 
 
 
 
 
723
  outputs=[gallery, action_plot, info],
724
  )
725
 
726
- # Use release for the timestep slider so the gallery does not clear/re-render
727
- # continuously while the user drags through a trajectory.
728
  timestep_slider.release(
729
  fn=render_frame,
730
- inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
 
 
 
 
 
 
 
 
 
 
731
  outputs=[gallery, action_plot, info],
732
  )
733
 
734
- # These controls can re-render immediately because they are changed less often.
735
  for widget in [image_keys, chunk_len, display_scale, reverse_channels]:
736
  widget.change(
737
  fn=render_frame,
738
- inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
 
 
 
 
 
 
 
 
 
 
739
  outputs=[gallery, action_plot, info],
740
  )
741
 
742
  render_btn.click(
743
  fn=render_frame,
744
- inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
 
 
 
 
 
 
 
 
 
 
745
  outputs=[gallery, action_plot, info],
746
  )
747
 
 
 
 
 
 
 
748
  demo.load(
749
- fn=update_after_traj_change,
750
- inputs=[preset, custom_repo_id, custom_filename, traj_slider],
751
- outputs=[timestep_slider, image_keys],
752
  ).then(
753
  fn=render_frame,
754
- inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
 
 
 
 
 
 
 
 
 
 
755
  outputs=[gallery, action_plot, info],
756
  )
757
 
 
1
  """
2
+ Standalone Hugging Face Space viewer for TrajectoryBuffer-style HDF5 files.
3
 
4
+ Best-practice version:
5
+ - No dependency on your local TrajectoryBuffer Python class.
6
+ - Dataset preset + custom dataset support.
7
+ - Robust HDF5 schema detection for root-level episode_XXXX groups.
8
+ - Auto-detect image keys from each trajectory's observation group.
9
+ - Avoids fragile multiline f-strings in UI status text.
10
+ - Uses slider.release() for timestep rendering to reduce image flicker.
11
 
12
  requirements.txt:
13
+ gradio
14
+ huggingface_hub
15
+ h5py
16
+ numpy
17
+ pillow
18
+ matplotlib
19
 
20
  Optional:
21
+ opencv-python-headless
22
  """
23
 
 
24
  import re
25
  from functools import lru_cache
26
 
27
  import gradio as gr
28
  import h5py
29
  import matplotlib
30
+
31
  matplotlib.use("Agg")
32
  import matplotlib.pyplot as plt
33
  import numpy as np
 
41
 
42
 
43
  # -----------------------------------------------------------------------------
44
+ # Dataset presets
 
45
  # -----------------------------------------------------------------------------
46
  DATASET_PRESETS = {
47
  "Robosuite Square 20260409": {
 
67
  DEFAULT_PRESET = "Robosuite Square 20260409"
68
  REPO_TYPE = "dataset"
69
  DEFAULT_CHUNK_LEN = 16
70
+
71
  PREFERRED_IMAGE_KEYS = [
72
  "image1",
73
  "image2",
74
  "agentview_image",
75
  "robot0_eye_in_hand_image",
76
+ "front_image",
77
+ "wrist_image",
78
  ]
79
 
80
+ IMAGE_KEY_HINTS = ["rgb", "image", "img", "camera", "cam"]
81
+
82
 
83
  # -----------------------------------------------------------------------------
84
+ # Dataset resolution and cache helpers
85
  # -----------------------------------------------------------------------------
 
 
 
 
 
 
 
86
  def resolve_dataset(preset_name, custom_repo_id=None, custom_filename=None):
 
87
  preset_name = preset_name or DEFAULT_PRESET
88
+
89
  if preset_name == "Custom":
90
  repo_id = str(custom_repo_id or "").strip()
91
  filename = str(custom_filename or "").strip()
92
  if not repo_id or not filename:
93
+ raise ValueError("For Custom mode, provide both repo_id and HDF5 filename/path.")
94
  return repo_id, filename
95
 
96
  if preset_name not in DATASET_PRESETS:
97
  preset_name = DEFAULT_PRESET
98
+
99
  item = DATASET_PRESETS[preset_name]
100
  return item["repo_id"], item["filename"]
101
 
 
110
 
111
 
112
  def _natural_sort_key(name):
113
+ match = re.search(r"([0-9]+)$", str(name))
114
+ if match:
115
+ return 0, int(match.group(1))
116
+ return 1, str(name)
117
 
118
 
119
  @lru_cache(maxsize=8)
120
  def get_trajectory_keys(repo_id, filename):
121
+ """Return ordered trajectory group paths."""
122
  path = get_local_hdf5_path(repo_id, filename)
123
+
124
  with h5py.File(path, "r") as f:
125
+ # Your TrajectoryBuffer format:
126
  # /episode_0000
127
  # /episode_0001
 
 
128
  root_episode_keys = [
129
+ key
130
+ for key in f.keys()
131
+ if isinstance(f[key], h5py.Group) and str(key).startswith("episode_")
132
  ]
133
  if root_episode_keys:
134
+ return tuple(sorted(root_episode_keys, key=_natural_sort_key))
135
+
136
+ # Robomimic-style fallback:
137
+ # /data/demo_0
138
+ # /data/demo_1
139
+ if "data" in f and isinstance(f["data"], h5py.Group):
140
+ data_group = f["data"]
141
+ keys = [
142
+ key
143
+ for key in data_group.keys()
144
+ if isinstance(data_group[key], h5py.Group)
145
+ ]
146
+ return tuple("data/" + key for key in sorted(keys, key=_natural_sort_key))
147
+
148
+ # Generic root-level group fallback.
149
+ keys = [
150
+ key
151
+ for key in f.keys()
152
+ if isinstance(f[key], h5py.Group)
153
+ ]
154
+ return tuple(sorted(keys, key=_natural_sort_key))
155
 
156
 
157
  @lru_cache(maxsize=8)
158
  def get_num_trajectories(repo_id, filename):
159
+ return len(get_trajectory_keys(repo_id, filename))
160
 
161
 
162
+ def inspect_hdf5_tree(preset_name, custom_repo_id, custom_filename, max_lines=180):
 
163
  repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
164
  path = get_local_hdf5_path(repo_id, filename)
165
+
166
  lines = []
167
  with h5py.File(path, "r") as f:
168
  def visitor(name, obj):
169
  if len(lines) >= max_lines:
170
  return
171
  if isinstance(obj, h5py.Dataset):
172
+ lines.append(
173
+ "DATASET {} shape={} dtype={}".format(name, obj.shape, obj.dtype)
174
+ )
175
  elif isinstance(obj, h5py.Group):
176
+ lines.append("GROUP {}".format(name))
177
+
178
  f.visititems(visitor)
179
 
180
  if len(lines) >= max_lines:
181
  lines.append("...")
182
+
183
+ if not lines:
184
+ return "No HDF5 contents found."
185
+ return chr(10).join(lines)
186
 
187
 
188
+ # -----------------------------------------------------------------------------
189
+ # HDF5 loading helpers
190
+ # -----------------------------------------------------------------------------
191
+ def _read_dataset_value(dataset):
192
+ value = dataset[()]
193
  if isinstance(value, bytes):
194
  return value.decode("utf-8")
195
  return value
196
 
197
 
198
  def _read_group_recursive(group):
 
199
  out = {}
200
  for key, obj in group.items():
201
  if isinstance(obj, h5py.Dataset):
 
205
  return out
206
 
207
 
208
+ def _find_first_key(mapping, candidate_keys):
209
+ for key in candidate_keys:
210
  if key in mapping:
211
  return key
212
  return None
213
 
214
 
215
+ def _infer_time_length(data):
216
+ """Infer trajectory length from common TrajectoryBuffer fields."""
217
+ for key in ["timesteps", "dones", "robot_actions", "teacher_actions", "actions"]:
218
+ if key in data:
219
+ arr = np.asarray(data[key])
220
+ if arr.ndim >= 1:
221
+ return int(arr.shape[0])
222
+
223
+ obs_group = None
224
+ if isinstance(data.get("observation"), dict):
225
+ obs_group = data["observation"]
226
+ elif isinstance(data.get("obs"), dict):
227
+ obs_group = data["obs"]
228
+
229
+ if obs_group:
230
+ lengths = []
231
+ for value in obs_group.values():
232
+ arr = np.asarray(value)
233
+ if arr.ndim >= 1:
234
+ lengths.append(int(arr.shape[0]))
235
+ if lengths:
236
+ values, counts = np.unique(lengths, return_counts=True)
237
+ return int(values[np.argmax(counts)])
238
+
239
+ return 1
240
+
241
+
242
+ def _slice_time(value, t, T):
243
  arr = np.asarray(value)
244
  if arr.ndim >= 1 and arr.shape[0] == T:
245
  return arr[t]
246
  return arr
247
 
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  @lru_cache(maxsize=64)
250
  def load_traj(repo_id, filename, traj_id):
251
+ """Load one trajectory as list[dict]."""
 
 
 
 
 
 
 
 
 
 
 
 
252
  traj_keys = get_trajectory_keys(repo_id, filename)
253
  if not traj_keys:
254
  return []
255
 
256
  traj_id = int(np.clip(int(traj_id), 0, len(traj_keys) - 1))
257
  traj_key = traj_keys[traj_id]
258
+ path = get_local_hdf5_path(repo_id, filename)
259
 
260
  with h5py.File(path, "r") as f:
261
+ group = f[traj_key]
262
+ data = _read_group_recursive(group)
263
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  T = _infer_time_length(data)
265
+
266
+ if isinstance(data.get("observation"), dict):
267
+ obs_all = data["observation"]
268
+ elif isinstance(data.get("obs"), dict):
269
+ obs_all = data["obs"]
270
+ else:
271
+ obs_all = {}
272
+
273
+ action_key = _find_first_key(data, ["actions", "action"])
274
+ teacher_key = _find_first_key(data, ["teacher_actions", "teacher_action"])
275
+ robot_key = _find_first_key(data, ["robot_actions", "robot_action"])
276
+ no_teacher_key = _find_first_key(data, ["no_teacher_actions", "no_teacher_action"])
277
+ no_robot_key = _find_first_key(data, ["no_robot_actions", "no_robot_action"])
278
+ done_key = _find_first_key(data, ["dones", "done"])
279
+ timestep_key = _find_first_key(data, ["timesteps", "timestep"])
280
+ success_key = _find_first_key(data, ["if_success", "success", "successes"])
281
 
282
  traj = []
283
  for t in range(T):
284
  obs_t = {}
285
  for key, value in obs_all.items():
286
+ obs_t[key] = _slice_time(value, t, T)
287
 
288
  default_action = np.zeros(1, dtype=np.float32)
289
  if action_key is not None:
290
+ default_action = _slice_time(data[action_key], t, T)
291
+
292
+ teacher_action = default_action
293
+ if teacher_key is not None:
294
+ teacher_action = _slice_time(data[teacher_key], t, T)
295
+
296
+ robot_action = default_action
297
+ if robot_key is not None:
298
+ robot_action = _slice_time(data[robot_key], t, T)
299
+
300
+ no_teacher = False
301
+ if no_teacher_key is not None:
302
+ no_teacher = _slice_time(data[no_teacher_key], t, T)
303
+
304
+ no_robot = False
305
+ if no_robot_key is not None:
306
+ no_robot = _slice_time(data[no_robot_key], t, T)
307
+
308
+ done = False
309
+ if done_key is not None:
310
+ done = _slice_time(data[done_key], t, T)
311
+
312
+ timestep = t
313
+ if timestep_key is not None:
314
+ timestep_arr = _slice_time(data[timestep_key], t, T)
315
+ timestep = int(np.asarray(timestep_arr).reshape(-1)[0])
316
+
317
+ if_success = False
318
+ if success_key is not None:
319
+ if_success = _slice_time(data[success_key], t, T)
320
+
321
+ traj.append(
322
+ {
323
+ "obs": obs_t,
324
+ "robot_action": np.asarray(robot_action),
325
+ "teacher_action": np.asarray(teacher_action),
326
+ "done": bool(np.asarray(done).reshape(-1)[0]),
327
+ "timestep": timestep,
328
+ "no_robot_action": bool(np.asarray(no_robot).reshape(-1)[0]),
329
+ "no_teacher_action": bool(np.asarray(no_teacher).reshape(-1)[0]),
330
+ "episode_id": traj_key,
331
+ "if_success": bool(np.asarray(if_success).reshape(-1)[0]),
332
+ }
333
+ )
334
 
335
  return traj
336
 
337
 
338
  # -----------------------------------------------------------------------------
339
+ # Image and plotting helpers
340
  # -----------------------------------------------------------------------------
341
  def _extract_latest_obs_value(value):
342
  arr = np.asarray(value)
343
+ # Per-timestep stacked observation commonly has shape [obs_T, C, H, W].
344
+ if arr.ndim >= 1 and arr.shape[0] in (1, 2, 3, 4):
345
  return arr[-1]
346
  return arr
347
 
348
 
349
+ def _looks_like_image_array(key, value):
350
+ arr = np.asarray(_extract_latest_obs_value(value))
351
+ key_l = str(key).lower()
352
+ key_hint = any(hint in key_l for hint in IMAGE_KEY_HINTS)
353
 
354
+ shape_hint = False
355
+ if arr.ndim == 2:
356
+ shape_hint = True
357
+ elif arr.ndim == 3:
358
+ shape_hint = arr.shape[-1] in (1, 3, 4) or arr.shape[0] in (1, 3, 4)
359
+
360
+ return key_hint or shape_hint
361
+
362
+
363
+ def _float_img_to_uint8(img):
364
+ arr = img.astype(np.float32)
365
+ arr_min = float(np.nanmin(arr))
366
+ arr_max = float(np.nanmax(arr))
367
+
368
+ # TrajectoryBuffer saves float images originally in [-1, 1] as uint8.
369
+ # But for compatibility, handle float [-1, 1], [0, 1], and [0, 255].
370
+ if arr_min >= -1.01 and arr_max <= 1.01:
371
+ if arr_min < 0.0:
372
+ arr = (arr + 1.0) * 0.5
373
+ arr = np.clip(arr, 0.0, 1.0) * 255.0
374
+ elif arr_max <= 255.0:
375
+ arr = np.clip(arr, 0.0, 255.0)
376
+ else:
377
+ arr = 255.0 * (arr - arr_min) / max(arr_max - arr_min, 1e-8)
378
+
379
+ return np.round(arr).astype(np.uint8)
380
+
381
+
382
+ def _extract_display_image(value, reverse_channels=False):
383
+ img = np.asarray(_extract_latest_obs_value(value))
384
 
385
  if img.ndim == 2:
386
  img = np.repeat(img[..., None], 3, axis=-1)
387
  elif img.ndim == 3 and img.shape[0] in (1, 3, 4):
388
  img = np.transpose(img, (1, 2, 0))
389
 
390
+ if img.ndim == 3 and img.shape[-1] == 1:
391
+ img = np.repeat(img, 3, axis=-1)
392
+ elif img.ndim == 3 and img.shape[-1] == 4:
393
  img = img[..., :3]
394
 
395
  if img.ndim != 3:
396
+ raise ValueError("Unsupported image shape: {}".format(img.shape))
397
 
398
  if img.dtype == np.uint8:
399
+ out = img.copy()
400
  else:
401
+ out = _float_img_to_uint8(img)
 
 
 
 
 
 
 
402
 
403
+ # Browser display expects RGB. Your current data appears RGB already,
404
+ # so default reverse_channels=False.
405
+ if reverse_channels and out.shape[-1] == 3:
406
+ out = out[..., ::-1]
407
 
408
+ return out
409
 
410
 
411
+ def _resize_image_for_display(img, display_scale):
412
+ scale = float(display_scale)
413
+ if scale == 1.0:
414
  return img
415
 
 
416
  h, w = img.shape[:2]
417
+ new_size = (max(1, int(round(w * scale))), max(1, int(round(h * scale))))
 
418
 
419
  if cv2 is not None:
420
+ return cv2.resize(img, new_size, interpolation=cv2.INTER_NEAREST)
421
 
422
  pil_img = Image.fromarray(img)
423
+ return np.asarray(pil_img.resize(new_size, resample=Image.Resampling.NEAREST))
424
 
425
 
426
+ def _extract_mixed_action_chunk(traj, start_idx, chunk_length):
427
  chunk = []
428
  sources = []
429
  end_idx = min(len(traj), int(start_idx) + int(chunk_length))
430
+
431
  for idx in range(int(start_idx), end_idx):
432
  step = traj[idx]
433
  use_teacher = not bool(step.get("no_teacher_action", False))
434
  action = step["teacher_action"] if use_teacher else step["robot_action"]
435
  chunk.append(np.asarray(action, dtype=np.float32).reshape(-1))
436
  sources.append("T" if use_teacher else "R")
437
+
438
  if not chunk:
439
  return None, ""
440
+
441
  return np.stack(chunk, axis=0), "".join(sources)
442
 
443
 
444
+ def _extract_robot_action_chunk(traj, start_idx, chunk_length):
445
  chunk = []
446
  end_idx = min(len(traj), int(start_idx) + int(chunk_length))
447
+
448
  for idx in range(int(start_idx), end_idx):
449
  step = traj[idx]
450
  chunk.append(np.asarray(step["robot_action"], dtype=np.float32).reshape(-1))
451
+
452
  if not chunk:
453
  return None
454
+
455
  return np.stack(chunk, axis=0)
456
 
457
 
458
+ def _safe_array_str(value, precision=3, max_items=24):
459
+ arr = np.asarray(value).reshape(-1)
460
  shown = arr[:max_items]
461
+ text = np.array2string(shown, precision=precision, separator=", ")
462
+ if arr.size > max_items:
463
+ text += " ... +{} more".format(arr.size - max_items)
464
+ return text
465
 
466
 
467
+ def _make_action_chunk_plot(mixed_chunk, robot_chunk):
468
  if mixed_chunk is None:
469
  return None
470
 
 
473
  mixed_chunk = mixed_chunk[:, None]
474
 
475
  fig, ax = plt.subplots(figsize=(7, 3.2), dpi=140)
476
+ x = np.arange(mixed_chunk.shape[0])
477
  max_dims = min(mixed_chunk.shape[1], 10)
478
 
479
+ for dim in range(max_dims):
480
+ ax.plot(x, mixed_chunk[:, dim], label="mixed[{}]".format(dim))
481
 
482
  if robot_chunk is not None:
483
  robot_chunk = np.asarray(robot_chunk, dtype=np.float32)
484
  if robot_chunk.ndim == 1:
485
  robot_chunk = robot_chunk[:, None]
486
+ for dim in range(min(robot_chunk.shape[1], max_dims)):
487
+ ax.plot(
488
+ x,
489
+ robot_chunk[:, dim],
490
+ linestyle="--",
491
+ alpha=0.55,
492
+ label="robot[{}]".format(dim),
493
+ )
494
 
495
  ax.set_title("Action chunk")
496
  ax.set_xlabel("chunk step")
 
500
  fig.tight_layout()
501
  fig.canvas.draw()
502
  rgba = np.asarray(fig.canvas.buffer_rgba())
503
+ image = rgba[..., :3].copy()
504
  plt.close(fig)
505
+ return image
506
 
507
 
508
+ # -----------------------------------------------------------------------------
509
+ # Gradio callbacks
510
+ # -----------------------------------------------------------------------------
511
  def get_available_image_keys(repo_id, filename, traj_id):
512
  n_traj = get_num_trajectories(repo_id, filename)
513
+ if n_traj == 0:
514
+ return []
515
+
516
+ traj_id = int(np.clip(int(traj_id), 0, n_traj - 1))
517
  traj = load_traj(repo_id, filename, traj_id)
518
  if not traj:
519
  return []
520
 
521
  obs = traj[0].get("obs", {})
522
+ keys = []
523
  for key, value in obs.items():
524
  try:
525
+ if _looks_like_image_array(key, value):
526
+ keys.append(key)
 
 
 
 
 
 
 
 
527
  except Exception:
528
  pass
529
 
530
+ ordered = [key for key in PREFERRED_IMAGE_KEYS if key in keys]
531
+ ordered.extend([key for key in keys if key not in ordered])
532
  return ordered
533
 
534
 
535
+ def update_custom_visibility(preset_name):
536
+ visible = preset_name == "Custom"
537
+ return gr.update(visible=visible), gr.update(visible=visible)
538
+
539
+
540
+ def update_after_dataset_change(preset_name, custom_repo_id, custom_filename):
541
+ repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
542
+ n_traj = get_num_trajectories(repo_id, filename)
543
+
544
+ if n_traj == 0:
545
+ status = "Loaded `{}` / `{}`".format(repo_id, filename)
546
+ status = status + chr(10) + "Detected trajectories: 0"
547
+ return (
548
+ gr.update(maximum=1, value=0),
549
+ gr.update(maximum=1, value=0),
550
+ gr.update(choices=[], value=[]),
551
+ status,
552
+ )
553
+
554
+ keys = get_available_image_keys(repo_id, filename, 0)
555
+ traj = load_traj(repo_id, filename, 0)
556
+
557
+ status = "Loaded `{}` / `{}`".format(repo_id, filename)
558
+ status = status + chr(10) + "Detected trajectories: {}".format(n_traj)
559
+
560
+ return (
561
+ gr.update(maximum=max(n_traj - 1, 1), value=0),
562
+ gr.update(maximum=max(len(traj) - 1, 1), value=0),
563
+ gr.update(choices=keys, value=keys[:2]),
564
+ status,
565
+ )
566
+
567
+
568
  def update_after_traj_change(preset_name, custom_repo_id, custom_filename, traj_id):
569
  repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
570
  n_traj = get_num_trajectories(repo_id, filename)
571
+ if n_traj == 0:
572
+ return gr.update(maximum=1, value=0), gr.update(choices=[], value=[])
573
+
574
+ traj_id = int(np.clip(int(traj_id), 0, n_traj - 1))
575
  traj = load_traj(repo_id, filename, traj_id)
576
+ keys = get_available_image_keys(repo_id, filename, traj_id)
577
+
 
578
  return (
579
+ gr.update(maximum=max(len(traj) - 1, 1), value=0),
580
+ gr.update(choices=keys, value=keys[:2]),
581
  )
582
 
583
 
584
+ def render_frame(
585
+ preset_name,
586
+ custom_repo_id,
587
+ custom_filename,
588
+ traj_id,
589
+ timestep,
590
+ image_keys,
591
+ chunk_len,
592
+ display_scale,
593
+ reverse_channels,
594
+ ):
595
  repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
596
  n_traj = get_num_trajectories(repo_id, filename)
 
 
597
 
598
+ if n_traj == 0:
599
+ return [], None, "No trajectory groups found. Open Debug: HDF5 tree."
600
+
601
+ traj_id = int(np.clip(int(traj_id), 0, n_traj - 1))
602
+ traj = load_traj(repo_id, filename, traj_id)
603
  if not traj:
604
+ return [], None, "Trajectory could not be loaded. Open Debug: HDF5 tree."
605
 
606
  timestep = int(np.clip(int(timestep), 0, len(traj) - 1))
607
  chunk_len = int(chunk_len)
608
  display_scale = float(display_scale)
 
 
609
 
610
  if image_keys is None:
611
  image_keys = []
612
  if isinstance(image_keys, str):
613
  image_keys = [image_keys]
614
 
615
+ step = traj[timestep]
616
+ obs = step.get("obs", {})
617
+
618
+ gallery_items = []
619
+ warnings = []
620
  for key in image_keys:
621
  if key not in obs:
622
+ warnings.append("Missing image key: {}".format(key))
623
  continue
624
  try:
625
+ img = _extract_display_image(
626
+ obs[key],
627
+ reverse_channels=bool(reverse_channels),
628
+ )
629
+ img = _resize_image_for_display(img, display_scale)
630
+ gallery_items.append((img, key))
631
  except Exception as exc:
632
+ warnings.append("{}: {}".format(key, exc))
633
 
634
+ mixed_chunk, source_mask = _extract_mixed_action_chunk(traj, timestep, chunk_len)
635
+ robot_chunk = _extract_robot_action_chunk(traj, timestep, chunk_len)
636
  action_plot = _make_action_chunk_plot(mixed_chunk, robot_chunk)
637
 
 
 
 
 
 
638
  info_lines = [
639
+ "dataset: {} / {}".format(repo_id, filename),
640
+ "detected trajectories: {}".format(n_traj),
641
+ "trajectory: {}".format(traj_id),
642
+ "episode_id: {}".format(step.get("episode_id", "")),
643
+ "timestep: {} / {}".format(timestep, len(traj) - 1),
644
+ "saved timestep: {}".format(step.get("timestep", timestep)),
645
+ "done: {}".format(int(bool(step.get("done", False)))),
646
+ "if_success: {}".format(int(bool(step.get("if_success", False)))),
647
+ "no_teacher_action: {}".format(int(bool(step.get("no_teacher_action", False)))),
648
+ "no_robot_action: {}".format(int(bool(step.get("no_robot_action", False)))),
649
+ "chunk_len: {}".format(chunk_len),
650
+ "chunk source mask: {} (T=teacher, R=robot fallback)".format(source_mask),
651
  "",
652
+ "teacher_action: {}".format(_safe_array_str(step.get("teacher_action", []))),
653
+ "robot_action: {}".format(_safe_array_str(step.get("robot_action", []))),
654
  ]
655
 
656
+ if warnings:
657
+ info_lines.append("")
658
+ info_lines.append("Image warnings:")
659
+ info_lines.extend(warnings)
660
 
661
+ return gallery_items, action_plot, chr(10).join(info_lines)
662
 
663
 
664
  # -----------------------------------------------------------------------------
665
  # App
666
  # -----------------------------------------------------------------------------
667
  def build_app():
668
+ repo_id, filename = resolve_dataset(DEFAULT_PRESET)
669
+
670
  try:
 
671
  n_traj = get_num_trajectories(repo_id, filename)
672
+ first_keys = get_available_image_keys(repo_id, filename, 0) if n_traj else []
673
+ startup_warning = ""
674
  except Exception as exc:
675
+ n_traj = 0
676
  first_keys = []
677
+ startup_warning = repr(exc)
678
+
679
+ default_status = "Loaded default dataset" + chr(10)
680
+ default_status += "Detected trajectories: {}".format(n_traj)
681
 
682
  with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
683
  gr.Markdown(
684
+ "# HDF5 Trajectory Viewer\n\n"
685
+ "Standalone viewer for TrajectoryBuffer-style HDF5 datasets on Hugging Face."
 
686
  )
687
 
688
+ if startup_warning:
689
+ gr.Markdown("Startup warning: `{}`".format(startup_warning))
 
 
 
690
 
691
  with gr.Row():
692
  preset = gr.Dropdown(
 
705
  visible=False,
706
  )
707
 
708
+ dataset_status = gr.Textbox(
709
+ label="Dataset status",
710
+ lines=2,
711
+ value=default_status,
712
+ interactive=False,
713
+ )
714
+
715
  with gr.Row():
716
  traj_slider = gr.Slider(
717
  minimum=0,
 
762
  object_fit="contain",
763
  )
764
  action_plot = gr.Image(label="Action chunk plot", type="numpy")
765
+ info = gr.Textbox(label="Frame info", lines=16)
766
 
767
  with gr.Accordion("Debug: HDF5 tree", open=False):
768
  inspect_btn = gr.Button("Inspect HDF5 structure")
769
+ hdf5_tree = gr.Textbox(lines=24, label="HDF5 tree")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
770
 
771
+ # Dataset selection and custom-field visibility.
772
  preset.change(
773
  fn=update_custom_visibility,
774
  inputs=preset,
 
779
  outputs=[traj_slider, timestep_slider, image_keys, dataset_status],
780
  ).then(
781
  fn=render_frame,
782
+ inputs=[
783
+ preset,
784
+ custom_repo_id,
785
+ custom_filename,
786
+ traj_slider,
787
+ timestep_slider,
788
+ image_keys,
789
+ chunk_len,
790
+ display_scale,
791
+ reverse_channels,
792
+ ],
793
  outputs=[gallery, action_plot, info],
794
  )
795
 
 
810
  outputs=[timestep_slider, image_keys],
811
  ).then(
812
  fn=render_frame,
813
+ inputs=[
814
+ preset,
815
+ custom_repo_id,
816
+ custom_filename,
817
+ traj_slider,
818
+ timestep_slider,
819
+ image_keys,
820
+ chunk_len,
821
+ display_scale,
822
+ reverse_channels,
823
+ ],
824
  outputs=[gallery, action_plot, info],
825
  )
826
 
827
+ # Render only after releasing the timestep slider, reducing flicker.
 
828
  timestep_slider.release(
829
  fn=render_frame,
830
+ inputs=[
831
+ preset,
832
+ custom_repo_id,
833
+ custom_filename,
834
+ traj_slider,
835
+ timestep_slider,
836
+ image_keys,
837
+ chunk_len,
838
+ display_scale,
839
+ reverse_channels,
840
+ ],
841
  outputs=[gallery, action_plot, info],
842
  )
843
 
 
844
  for widget in [image_keys, chunk_len, display_scale, reverse_channels]:
845
  widget.change(
846
  fn=render_frame,
847
+ inputs=[
848
+ preset,
849
+ custom_repo_id,
850
+ custom_filename,
851
+ traj_slider,
852
+ timestep_slider,
853
+ image_keys,
854
+ chunk_len,
855
+ display_scale,
856
+ reverse_channels,
857
+ ],
858
  outputs=[gallery, action_plot, info],
859
  )
860
 
861
  render_btn.click(
862
  fn=render_frame,
863
+ inputs=[
864
+ preset,
865
+ custom_repo_id,
866
+ custom_filename,
867
+ traj_slider,
868
+ timestep_slider,
869
+ image_keys,
870
+ chunk_len,
871
+ display_scale,
872
+ reverse_channels,
873
+ ],
874
  outputs=[gallery, action_plot, info],
875
  )
876
 
877
+ inspect_btn.click(
878
+ fn=inspect_hdf5_tree,
879
+ inputs=[preset, custom_repo_id, custom_filename],
880
+ outputs=hdf5_tree,
881
+ )
882
+
883
  demo.load(
884
+ fn=update_after_dataset_change,
885
+ inputs=[preset, custom_repo_id, custom_filename],
886
+ outputs=[traj_slider, timestep_slider, image_keys, dataset_status],
887
  ).then(
888
  fn=render_frame,
889
+ inputs=[
890
+ preset,
891
+ custom_repo_id,
892
+ custom_filename,
893
+ traj_slider,
894
+ timestep_slider,
895
+ image_keys,
896
+ chunk_len,
897
+ display_scale,
898
+ reverse_channels,
899
+ ],
900
  outputs=[gallery, action_plot, info],
901
  )
902