Zhaoting123 committed on
Commit
f7db5fc
·
verified ·
1 Parent(s): f084ff8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +335 -165
app.py CHANGED
@@ -1,81 +1,293 @@
1
  """
2
- Hugging Face Space viewer for your TrajectoryBuffer HDF5 files.
3
 
4
- Expected Space files:
5
- app.py # this file
 
 
 
6
  requirements.txt
7
- tools/buffer_trajectory.py # copied from your repo, plus any minimal dependencies it needs
8
-
9
- This first Space version visualizes:
10
- - selected trajectory
11
- - selected timestep
12
- - image observations, e.g. image1/image2/agentview_image
13
- - teacher action, robot action
14
- - no_teacher/no_robot flags
15
- - action chunk source mask
16
- - simple action-chunk plot
17
-
18
- It intentionally removes:
19
- - matplotlib keyboard controls
20
- - manual region labeling
21
- - robosuite/MuJoCo environment construction
22
- - 3D camera projection overlay
23
-
24
- You can add the heavy robosuite projection overlay later, but the Space will be much easier
25
- to deploy if the first viewer does not need MuJoCo/robosuite/Hydra.
26
  """
27
 
28
  import os
29
- import sys
30
  from functools import lru_cache
31
 
32
- import cv2
33
  import h5py
34
  import matplotlib
35
  matplotlib.use("Agg")
36
  import matplotlib.pyplot as plt
37
  import numpy as np
38
- import gradio as gr
39
  from huggingface_hub import hf_hub_download
 
 
 
 
 
 
40
 
41
 
42
  # -----------------------------------------------------------------------------
43
  # EDIT THESE FOR YOUR DATASET
44
  # -----------------------------------------------------------------------------
45
  REPO_ID = "Zhaoting123/Robosuite_Square_image_abs_with_state"
 
46
  HDF5_FILENAME = (
47
  "20260409_205051_Diffusion_CLIC_intervention_Circular_square_image_abs_"
48
  "Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
49
  )
50
- REPO_TYPE = "dataset"
51
 
52
- # The keys you used locally were ["image1", "image2"].
53
- # For the robosuite upload, you may also have "agentview_image" or
54
- # "robot0_eye_in_hand_image" inside step["obs"].
55
- DEFAULT_IMAGE_KEYS = ["image1", "image2", "agentview_image", "robot0_eye_in_hand_image"]
56
  DEFAULT_CHUNK_LEN = 16
 
 
 
 
 
 
57
 
58
 
59
  # -----------------------------------------------------------------------------
60
- # Import your project loader
61
  # -----------------------------------------------------------------------------
62
- HERE = os.path.dirname(os.path.abspath(__file__))
63
- sys.path.insert(0, HERE)
 
 
 
 
 
64
 
65
- try:
66
- from tools.buffer_trajectory import TrajectoryBuffer
67
- except Exception as exc:
68
- TrajectoryBuffer = None
69
- TRAJECTORY_BUFFER_IMPORT_ERROR = repr(exc)
70
- else:
71
- TRAJECTORY_BUFFER_IMPORT_ERROR = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
 
74
  # -----------------------------------------------------------------------------
75
- # Small utility functions extracted from your local script
76
  # -----------------------------------------------------------------------------
77
  def _extract_latest_obs_value(value):
78
  arr = np.asarray(value)
 
79
  if arr.ndim >= 1 and arr.shape[0] in (1, 2):
80
  return arr[-1]
81
  return arr
@@ -85,11 +297,19 @@ def _extract_display_image(value, reverse_channels=False, output_uint8=True):
85
  img = _extract_latest_obs_value(value)
86
  img = np.asarray(img)
87
 
 
 
 
 
 
88
  if img.ndim == 2:
89
  img = np.repeat(img[..., None], 3, axis=-1)
90
- elif img.ndim == 3 and img.shape[0] in (1, 3):
91
  img = np.transpose(img, (1, 2, 0))
92
 
 
 
 
93
  if img.ndim != 3:
94
  raise ValueError(f"Unsupported image shape: {img.shape}")
95
 
@@ -111,19 +331,20 @@ def _extract_display_image(value, reverse_channels=False, output_uint8=True):
111
  return img_rgb
112
 
113
 
114
- def _resize_image_for_display(img, display_scale=4, interpolation=None):
115
  if display_scale is None or float(display_scale) == 1.0:
116
  return img
117
- display_scale = float(display_scale)
118
- if display_scale <= 0:
119
- raise ValueError(f"display_scale must be positive, got {display_scale}")
120
 
 
121
  h, w = img.shape[:2]
122
  new_w = max(1, int(round(w * display_scale)))
123
  new_h = max(1, int(round(h * display_scale)))
124
- if interpolation is None:
125
- interpolation = cv2.INTER_NEAREST if display_scale >= 1 else cv2.INTER_AREA
126
- return cv2.resize(img, (new_w, new_h), interpolation=interpolation)
 
 
 
127
 
128
 
129
  def _extract_mixed_action_chunk(traj, start_idx, chunk_length=16):
@@ -134,7 +355,7 @@ def _extract_mixed_action_chunk(traj, start_idx, chunk_length=16):
134
  step = traj[idx]
135
  use_teacher = not bool(step.get("no_teacher_action", False))
136
  action = step["teacher_action"] if use_teacher else step["robot_action"]
137
- chunk.append(np.asarray(action, dtype=np.float32))
138
  sources.append("T" if use_teacher else "R")
139
  if not chunk:
140
  return None, ""
@@ -146,17 +367,16 @@ def _extract_robot_action_chunk(traj, start_idx, chunk_length=16):
146
  end_idx = min(len(traj), int(start_idx) + int(chunk_length))
147
  for idx in range(int(start_idx), end_idx):
148
  step = traj[idx]
149
- chunk.append(np.asarray(step["robot_action"], dtype=np.float32))
150
  if not chunk:
151
  return None
152
  return np.stack(chunk, axis=0)
153
 
154
 
155
- def _safe_array_str(x, precision=3, max_items=20):
156
- arr = np.asarray(x)
157
- flat = arr.reshape(-1)
158
- suffix = "" if flat.size <= max_items else f" ... +{flat.size - max_items} more"
159
- shown = flat[:max_items]
160
  return np.array2string(shown, precision=precision, separator=", ") + suffix
161
 
162
 
@@ -170,9 +390,8 @@ def _make_action_chunk_plot(mixed_chunk, robot_chunk=None):
170
 
171
  fig, ax = plt.subplots(figsize=(7, 3.2), dpi=140)
172
  t = np.arange(mixed_chunk.shape[0])
173
-
174
- # Plot up to 10 dims to avoid clutter.
175
  max_dims = min(mixed_chunk.shape[1], 10)
 
176
  for d in range(max_dims):
177
  ax.plot(t, mixed_chunk[:, d], label=f"mixed[{d}]")
178
 
@@ -180,8 +399,7 @@ def _make_action_chunk_plot(mixed_chunk, robot_chunk=None):
180
  robot_chunk = np.asarray(robot_chunk, dtype=np.float32)
181
  if robot_chunk.ndim == 1:
182
  robot_chunk = robot_chunk[:, None]
183
- max_robot_dims = min(robot_chunk.shape[1], max_dims)
184
- for d in range(max_robot_dims):
185
  ax.plot(t, robot_chunk[:, d], linestyle="--", alpha=0.55, label=f"robot[{d}]")
186
 
187
  ax.set_title("Action chunk")
@@ -190,7 +408,6 @@ def _make_action_chunk_plot(mixed_chunk, robot_chunk=None):
190
  ax.grid(True, alpha=0.3)
191
  ax.legend(loc="upper right", fontsize=7, ncol=2)
192
  fig.tight_layout()
193
-
194
  fig.canvas.draw()
195
  rgba = np.asarray(fig.canvas.buffer_rgba())
196
  img = rgba[..., :3].copy()
@@ -198,104 +415,62 @@ def _make_action_chunk_plot(mixed_chunk, robot_chunk=None):
198
  return img
199
 
200
 
201
- # -----------------------------------------------------------------------------
202
- # HDF5 download/load/cache
203
- # -----------------------------------------------------------------------------
204
- @lru_cache(maxsize=1)
205
- def get_local_hdf5_path():
206
- return hf_hub_download(
207
- repo_id=REPO_ID,
208
- filename=HDF5_FILENAME,
209
- repo_type=REPO_TYPE,
210
- )
211
-
212
-
213
- @lru_cache(maxsize=1)
214
- def get_buffer():
215
- if TrajectoryBuffer is None:
216
- raise RuntimeError(
217
- "Could not import tools.buffer_trajectory.TrajectoryBuffer. "
218
- "Copy tools/buffer_trajectory.py into the Space. "
219
- f"Original import error: {TRAJECTORY_BUFFER_IMPORT_ERROR}"
220
- )
221
- return TrajectoryBuffer()
222
-
223
-
224
- @lru_cache(maxsize=1)
225
- def get_num_trajectories():
226
- path = get_local_hdf5_path()
227
- return int(get_buffer().count_trajectories_in_hdf5(path))
228
-
229
-
230
- @lru_cache(maxsize=32)
231
- def load_traj(traj_id):
232
- path = get_local_hdf5_path()
233
- traj = get_buffer().load_from_file(path, traj_id=int(traj_id))
234
- return traj
235
-
236
-
237
- def inspect_hdf5_tree(max_lines=120):
238
- """Useful debug panel for Space deployment."""
239
- path = get_local_hdf5_path()
240
- lines = []
241
- with h5py.File(path, "r") as f:
242
- def visitor(name, obj):
243
- if len(lines) >= max_lines:
244
- return
245
- if isinstance(obj, h5py.Dataset):
246
- lines.append(f"DATASET {name} shape={obj.shape} dtype={obj.dtype}")
247
- elif isinstance(obj, h5py.Group):
248
- lines.append(f"GROUP {name}")
249
- f.visititems(visitor)
250
- if len(lines) >= max_lines:
251
- lines.append("...")
252
- return "\n".join(lines) if lines else "Empty or unsupported HDF5 tree."
253
-
254
-
255
- # -----------------------------------------------------------------------------
256
- # Gradio rendering functions
257
- # -----------------------------------------------------------------------------
258
  def get_available_image_keys(traj_id):
259
- traj = load_traj(int(traj_id))
 
 
260
  if not traj:
261
  return []
 
262
  obs = traj[0].get("obs", {})
263
- keys = []
264
- for k, v in obs.items():
265
  try:
266
- arr = np.asarray(_extract_latest_obs_value(v))
267
- if arr.ndim in (2, 3):
268
- keys.append(k)
 
 
 
 
 
 
 
269
  except Exception:
270
  pass
271
- # Prefer your known keys first, then any detected keys.
272
- ordered = [k for k in DEFAULT_IMAGE_KEYS if k in keys]
273
- ordered += [k for k in keys if k not in ordered]
274
  return ordered
275
 
276
 
 
 
 
277
  def update_after_traj_change(traj_id):
278
- traj = load_traj(int(traj_id))
279
- image_keys = get_available_image_keys(int(traj_id))
 
 
280
  max_step = max(len(traj) - 1, 0)
281
- default_keys = image_keys[:2] if image_keys else []
282
  return (
283
- gr.update(maximum=max_step, value=0),
284
- gr.update(choices=image_keys, value=default_keys),
285
  )
286
 
287
 
288
  def render_frame(traj_id, timestep, image_keys, chunk_len, display_scale, reverse_channels):
289
- traj_id = int(traj_id)
290
- timestep = int(timestep)
291
- chunk_len = int(chunk_len)
292
- display_scale = float(display_scale)
293
-
294
  traj = load_traj(traj_id)
 
295
  if not traj:
296
- return [], None, "No trajectory loaded."
297
 
298
- timestep = int(np.clip(timestep, 0, len(traj) - 1))
 
 
299
  step = traj[timestep]
300
  obs = step.get("obs", {})
301
 
@@ -311,30 +486,23 @@ def render_frame(traj_id, timestep, image_keys, chunk_len, display_scale, revers
311
  errors.append(f"Missing image key: {key}")
312
  continue
313
  try:
314
- img = _extract_display_image(
315
- obs[key],
316
- reverse_channels=bool(reverse_channels),
317
- output_uint8=True,
318
- )
319
  img = _resize_image_for_display(img, display_scale=display_scale)
320
  gallery.append((img, key))
321
  except Exception as exc:
322
  errors.append(f"{key}: {exc}")
323
 
324
- mixed_chunk, chunk_sources = _extract_mixed_action_chunk(
325
- traj, timestep, chunk_length=chunk_len
326
- )
327
- robot_chunk = _extract_robot_action_chunk(
328
- traj, timestep, chunk_length=chunk_len
329
- )
330
  action_plot = _make_action_chunk_plot(mixed_chunk, robot_chunk)
331
 
 
 
332
  no_teacher = bool(step.get("no_teacher_action", False))
333
  no_robot = bool(step.get("no_robot_action", False))
334
- teacher_action = step.get("teacher_action", None)
335
- robot_action = step.get("robot_action", None)
336
 
337
  info_lines = [
 
338
  f"trajectory: {traj_id}",
339
  f"timestep: {timestep} / {len(traj) - 1}",
340
  f"no_teacher_action: {int(no_teacher)}",
@@ -342,8 +510,8 @@ def render_frame(traj_id, timestep, image_keys, chunk_len, display_scale, revers
342
  f"chunk_len: {chunk_len}",
343
  f"chunk source mask: {chunk_sources} (T=teacher, R=robot fallback)",
344
  "",
345
- f"teacher_action: {_safe_array_str(teacher_action) if teacher_action is not None else 'None'}",
346
- f"robot_action: {_safe_array_str(robot_action) if robot_action is not None else 'None'}",
347
  ]
348
 
349
  if errors:
@@ -352,34 +520,36 @@ def render_frame(traj_id, timestep, image_keys, chunk_len, display_scale, revers
352
  return gallery, action_plot, "\n".join(info_lines)
353
 
354
 
 
 
 
355
  def build_app():
356
  try:
357
  n_traj = get_num_trajectories()
358
- first_keys = get_available_image_keys(0) if n_traj > 0 else []
359
  startup_error = None
360
  except Exception as exc:
361
  n_traj = 1
362
  first_keys = []
363
  startup_error = repr(exc)
364
 
365
- with gr.Blocks(title="Robosuite HDF5 Trajectory Viewer") as demo:
366
  gr.Markdown(
367
- "# Robosuite HDF5 Trajectory Viewer\n"
368
- "Browse trajectories directly from the Hugging Face dataset repository."
 
369
  )
370
 
371
  if startup_error is not None:
372
  gr.Markdown(
373
  "⚠️ **Startup warning**\n\n"
374
- f"```text\n{startup_error}\n```\n"
375
- "Most likely fix: copy `tools/buffer_trajectory.py` and its minimal dependencies "
376
- "into this Space."
377
  )
378
 
379
  with gr.Row():
380
  traj_slider = gr.Slider(
381
  minimum=0,
382
- maximum=max(n_traj - 1, 0),
383
  value=0,
384
  step=1,
385
  label="Trajectory index",
@@ -414,7 +584,7 @@ def build_app():
414
  )
415
  reverse_channels = gr.Checkbox(
416
  value=True,
417
- label="Reverse image channels BGR↔RGB",
418
  )
419
 
420
  render_btn = gr.Button("Render frame", variant="primary")
@@ -426,11 +596,11 @@ def build_app():
426
  object_fit="contain",
427
  )
428
  action_plot = gr.Image(label="Action chunk plot", type="numpy")
429
- info = gr.Textbox(label="Frame info", lines=12)
430
 
431
  with gr.Accordion("Debug: HDF5 tree", open=False):
432
  inspect_btn = gr.Button("Inspect HDF5 structure")
433
- hdf5_tree = gr.Textbox(lines=20, label="HDF5 tree")
434
  inspect_btn.click(fn=inspect_hdf5_tree, outputs=hdf5_tree)
435
 
436
  traj_slider.change(
 
1
  """
2
+ Standalone Hugging Face Space viewer for HDF5 trajectory datasets.
3
 
4
+ This version does NOT require your local TrajectoryBuffer class.
5
+ It reads the HDF5 file directly with h5py.
6
+
7
+ Files needed in the Space:
8
+ app.py
9
  requirements.txt
10
+
11
+ requirements.txt:
12
+ gradio
13
+ huggingface_hub
14
+ h5py
15
+ numpy
16
+ pillow
17
+ matplotlib
18
+
19
+ Optional:
20
+ opencv-python-headless
 
 
 
 
 
 
 
 
21
  """
22
 
23
  import os
24
+ import re
25
  from functools import lru_cache
26
 
27
+ import gradio as gr
28
  import h5py
29
  import matplotlib
30
  matplotlib.use("Agg")
31
  import matplotlib.pyplot as plt
32
  import numpy as np
 
33
  from huggingface_hub import hf_hub_download
34
+ from PIL import Image
35
+
36
+ try:
37
+ import cv2
38
+ except Exception:
39
+ cv2 = None
40
 
41
 
42
  # -----------------------------------------------------------------------------
43
  # EDIT THESE FOR YOUR DATASET
44
  # -----------------------------------------------------------------------------
45
  REPO_ID = "Zhaoting123/Robosuite_Square_image_abs_with_state"
46
+ REPO_TYPE = "dataset"
47
  HDF5_FILENAME = (
48
  "20260409_205051_Diffusion_CLIC_intervention_Circular_square_image_abs_"
49
  "Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
50
  )
 
51
 
 
 
 
 
52
  DEFAULT_CHUNK_LEN = 16
53
+ PREFERRED_IMAGE_KEYS = [
54
+ "image1",
55
+ "image2",
56
+ "agentview_image",
57
+ "robot0_eye_in_hand_image",
58
+ ]
59
 
60
 
61
  # -----------------------------------------------------------------------------
62
+ # HDF5 helpers
63
  # -----------------------------------------------------------------------------
64
+ @lru_cache(maxsize=1)
65
+ def get_local_hdf5_path():
66
+ return hf_hub_download(
67
+ repo_id=REPO_ID,
68
+ filename=HDF5_FILENAME,
69
+ repo_type=REPO_TYPE,
70
+ )
71
 
72
+
73
+ def _natural_sort_key(name):
74
+ m = re.search(r"([0-9]+)$", str(name))
75
+ return (0, int(m.group(1))) if m else (1, str(name))
76
+
77
+
78
@lru_cache(maxsize=1)
def get_trajectory_keys():
    """Return the HDF5 paths of all trajectory groups, in natural order.

    Layouts are probed in this order:

    1. Root-level groups whose name starts with ``episode_``
       (``/episode_0000``, ``/episode_0001``, ...).
    2. Every group directly under ``/data`` (``/data/demo_0`` style); the
       returned paths are prefixed with ``data/``.
    3. Every root-level group.

    The result is a tuple of strings and is memoized for the process.
    """
    def child_group_names(parent):
        # Only groups can be trajectories; datasets at this level are ignored.
        return [name for name in parent.keys() if isinstance(parent[name], h5py.Group)]

    with h5py.File(get_local_hdf5_path(), "r") as h5_file:
        root_groups = child_group_names(h5_file)
        episode_groups = [name for name in root_groups if str(name).startswith("episode_")]

        if episode_groups:
            parent_path, names = "", episode_groups
        elif "data" in h5_file and isinstance(h5_file["data"], h5py.Group):
            parent_path, names = "data", child_group_names(h5_file["data"])
        else:
            parent_path, names = "", root_groups

    ordered = sorted(names, key=_natural_sort_key)
    if parent_path:
        return tuple(f"{parent_path}/{name}" for name in ordered)
    return tuple(ordered)
107
+
108
+
109
@lru_cache(maxsize=1)
def get_num_trajectories():
    """Return the number of detected trajectory groups, floored at 1.

    The value is 1 even when ``get_trajectory_keys()`` finds nothing, so a
    result of 1 does not by itself prove that a trajectory exists.
    """
    detected = len(get_trajectory_keys())
    return detected if detected >= 1 else 1
112
+
113
+
114
def inspect_hdf5_tree(max_lines=160):
    """Return a text listing of the HDF5 file for debugging inside the Space.

    Each dataset is listed with its shape and dtype and each group by name,
    one entry per line, in the order ``visititems`` reports them.

    Args:
        max_lines: Maximum number of entries to list.

    Returns:
        The listing as a newline-joined string. A final ``"..."`` line is
        added only when the walk was cut short because ``max_lines`` was
        reached. If nothing was listed, a short placeholder message is
        returned instead.
    """
    path = get_local_hdf5_path()
    lines = []
    with h5py.File(path, "r") as f:
        def visitor(name, obj):
            if len(lines) >= max_lines:
                # Any non-None return value makes h5py stop the traversal, so
                # the rest of a large file is not walked once the cap is hit.
                return True
            if isinstance(obj, h5py.Dataset):
                lines.append(f"DATASET {name} shape={obj.shape} dtype={obj.dtype}")
            elif isinstance(obj, h5py.Group):
                lines.append(f"GROUP {name}")
            return None

        # visititems returns the visitor's stop value, or None after a full walk.
        truncated = f.visititems(visitor) is not None

    if truncated:
        lines.append("...")
    return "\n".join(lines) if lines else "No HDF5 contents found."
131
+
132
+
133
+ def _read_dataset_value(ds):
134
+ value = ds[()]
135
+ if isinstance(value, bytes):
136
+ return value.decode("utf-8")
137
+ return value
138
+
139
+
140
+ def _read_group_recursive(group):
141
+ """Read a group into nested dictionaries of numpy arrays."""
142
+ out = {}
143
+ for key, obj in group.items():
144
+ if isinstance(obj, h5py.Dataset):
145
+ out[key] = _read_dataset_value(obj)
146
+ elif isinstance(obj, h5py.Group):
147
+ out[key] = _read_group_recursive(obj)
148
+ return out
149
+
150
+
151
+ def _find_first_existing_key(mapping, candidates):
152
+ for key in candidates:
153
+ if key in mapping:
154
+ return key
155
+ return None
156
+
157
+
158
+ def _maybe_time_slice(value, t, T):
159
+ arr = np.asarray(value)
160
+ if arr.ndim >= 1 and arr.shape[0] == T:
161
+ return arr[t]
162
+ return arr
163
+
164
+
165
+ def _infer_time_length(data):
166
+ """Infer T from datasets whose first dimension is time."""
167
+ candidate_lengths = []
168
+
169
+ def collect(obj):
170
+ if isinstance(obj, dict):
171
+ for v in obj.values():
172
+ collect(v)
173
+ else:
174
+ arr = np.asarray(obj)
175
+ if arr.ndim >= 1 and arr.shape[0] > 1:
176
+ candidate_lengths.append(int(arr.shape[0]))
177
+
178
+ collect(data)
179
+ if not candidate_lengths:
180
+ return 1
181
+
182
+ # The trajectory length should usually be the most common large first dim.
183
+ values, counts = np.unique(candidate_lengths, return_counts=True)
184
+ return int(values[np.argmax(counts)])
185
+
186
+
187
@lru_cache(maxsize=32)
def load_traj(traj_id):
    """Load one trajectory as a list of step dictionaries.

    Args:
        traj_id: Index into ``get_trajectory_keys()``; it is converted to
            ``int`` and clipped into the valid range.

    Returns:
        A list with one dictionary per step::

            {
                "timestep": int,
                "obs": dict,
                "teacher_action": np.ndarray,
                "robot_action": np.ndarray,
                "no_teacher_action": bool,
                "no_robot_action": bool,
            }

        The list is empty when no trajectory group was detected. Results are
        memoized, so the returned list is shared between calls and should be
        treated as read-only.
    """
    traj_keys = get_trajectory_keys()
    if not traj_keys:
        return []

    traj_id = int(np.clip(int(traj_id), 0, len(traj_keys) - 1))
    with h5py.File(get_local_hdf5_path(), "r") as f:
        data = _read_group_recursive(f[traj_keys[traj_id]])

    # Case A: the trajectory group contains step groups: step_0, step_1, ...
    step_group_keys = [
        k for k, v in data.items()
        if isinstance(v, dict) and (str(k).startswith("step") or str(k).isdigit())
    ]
    if step_group_keys:
        return _load_traj_from_step_groups(data, step_group_keys)

    # Case B: the trajectory group contains arrays whose first dimension is T.
    return _load_traj_from_time_major_arrays(data)


def _load_traj_first_value(mapping, candidates, default):
    """Return the value of the first candidate key present in ``mapping``, else ``default``."""
    key = _find_first_existing_key(mapping, candidates)
    return default if key is None else mapping[key]


def _load_traj_first_dict(mapping, candidates):
    """Return the first candidate entry of ``mapping`` that is a dict, else an empty dict."""
    for key in candidates:
        value = mapping.get(key)
        if isinstance(value, dict):
            return value
    return {}


def _load_traj_flag(value):
    """Collapse a scalar or array-like flag to a bool using its first element."""
    return bool(np.asarray(value).reshape(-1)[0])


def _load_traj_from_step_groups(data, step_group_keys):
    """Build the step list for a layout with one sub-group per step."""
    traj = []
    for index, step_key in enumerate(sorted(step_group_keys, key=_natural_sort_key)):
        step = data[step_key]
        teacher_action = _load_traj_first_value(
            step,
            ["teacher_action", "teacher_actions", "action", "actions"],
            np.zeros(1, dtype=np.float32),
        )
        # With no robot-specific or generic action, fall back to the teacher's.
        robot_action = _load_traj_first_value(
            step,
            ["robot_action", "robot_actions", "action", "actions"],
            teacher_action,
        )
        timestep = _load_traj_first_value(step, ["timestep"], index)
        traj.append({
            # reshape(-1)[0] also accepts a timestep stored as a 1-element array.
            "timestep": int(np.asarray(timestep).reshape(-1)[0]),
            "obs": _load_traj_first_dict(step, ["obs", "observation"]),
            "teacher_action": np.asarray(teacher_action),
            "robot_action": np.asarray(robot_action),
            "no_teacher_action": _load_traj_flag(
                _load_traj_first_value(step, ["no_teacher_action", "no_teacher_actions"], False)
            ),
            "no_robot_action": _load_traj_flag(
                _load_traj_first_value(step, ["no_robot_action", "no_robot_actions"], False)
            ),
        })
    return traj


def _load_traj_from_time_major_arrays(data):
    """Build the step list for a layout of arrays indexed by time on axis 0.

    Expected layout (singular key aliases are accepted as well):
        <traj>/observation/<image_or_state_key>[T, ...]
        <traj>/robot_actions[T, D]
        <traj>/teacher_actions[T, D]
        <traj>/no_teacher_actions[T]
        <traj>/no_robot_actions[T]
    """
    T = _infer_time_length(data)
    # Checking presence first is what lets the "obs" alias be reached when
    # "observation" is absent; a `.get(key, {})` default would always pass an
    # isinstance(..., dict) test and hide the alias.
    obs_all = _load_traj_first_dict(data, ["observation", "obs"])

    action_key = _find_first_existing_key(data, ["actions", "action"])
    teacher_key = _find_first_existing_key(data, ["teacher_actions", "teacher_action"])
    robot_key = _find_first_existing_key(data, ["robot_actions", "robot_action"])
    no_teacher_key = _find_first_existing_key(data, ["no_teacher_actions", "no_teacher_action"])
    no_robot_key = _find_first_existing_key(data, ["no_robot_actions", "no_robot_action"])

    traj = []
    for t in range(T):
        obs_t = {key: _maybe_time_slice(value, t, T) for key, value in obs_all.items()}

        if action_key is not None:
            default_action = _maybe_time_slice(data[action_key], t, T)
        else:
            default_action = np.zeros(1, dtype=np.float32)

        teacher_action = _maybe_time_slice(data[teacher_key], t, T) if teacher_key is not None else default_action
        robot_action = _maybe_time_slice(data[robot_key], t, T) if robot_key is not None else default_action
        no_teacher = _maybe_time_slice(data[no_teacher_key], t, T) if no_teacher_key is not None else False
        no_robot = _maybe_time_slice(data[no_robot_key], t, T) if no_robot_key is not None else False

        traj.append({
            "timestep": t,
            "obs": obs_t,
            "teacher_action": np.asarray(teacher_action),
            "robot_action": np.asarray(robot_action),
            "no_teacher_action": _load_traj_flag(no_teacher),
            "no_robot_action": _load_traj_flag(no_robot),
        })

    return traj
283
 
284
 
285
  # -----------------------------------------------------------------------------
286
+ # Visualization helpers
287
  # -----------------------------------------------------------------------------
288
  def _extract_latest_obs_value(value):
289
  arr = np.asarray(value)
290
+ # Your local script handled stacked recent observations by taking the latest.
291
  if arr.ndim >= 1 and arr.shape[0] in (1, 2):
292
  return arr[-1]
293
  return arr
 
297
  img = _extract_latest_obs_value(value)
298
  img = np.asarray(img)
299
 
300
+ # Your saved image shape can be [obs_T, C, H, W] per timestep.
301
+ # Take the latest stacked observation, then convert CHW -> HWC.
302
+ if img.ndim == 4 and img.shape[0] in (1, 2, 3, 4):
303
+ img = img[-1]
304
+
305
  if img.ndim == 2:
306
  img = np.repeat(img[..., None], 3, axis=-1)
307
+ elif img.ndim == 3 and img.shape[0] in (1, 3, 4):
308
  img = np.transpose(img, (1, 2, 0))
309
 
310
+ if img.ndim == 3 and img.shape[-1] == 4:
311
+ img = img[..., :3]
312
+
313
  if img.ndim != 3:
314
  raise ValueError(f"Unsupported image shape: {img.shape}")
315
 
 
331
  return img_rgb
332
 
333
 
334
+ def _resize_image_for_display(img, display_scale=4):
335
  if display_scale is None or float(display_scale) == 1.0:
336
  return img
 
 
 
337
 
338
+ display_scale = float(display_scale)
339
  h, w = img.shape[:2]
340
  new_w = max(1, int(round(w * display_scale)))
341
  new_h = max(1, int(round(h * display_scale)))
342
+
343
+ if cv2 is not None:
344
+ return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
345
+
346
+ pil_img = Image.fromarray(img)
347
+ return np.asarray(pil_img.resize((new_w, new_h), resample=Image.Resampling.NEAREST))
348
 
349
 
350
  def _extract_mixed_action_chunk(traj, start_idx, chunk_length=16):
 
355
  step = traj[idx]
356
  use_teacher = not bool(step.get("no_teacher_action", False))
357
  action = step["teacher_action"] if use_teacher else step["robot_action"]
358
+ chunk.append(np.asarray(action, dtype=np.float32).reshape(-1))
359
  sources.append("T" if use_teacher else "R")
360
  if not chunk:
361
  return None, ""
 
367
  end_idx = min(len(traj), int(start_idx) + int(chunk_length))
368
  for idx in range(int(start_idx), end_idx):
369
  step = traj[idx]
370
+ chunk.append(np.asarray(step["robot_action"], dtype=np.float32).reshape(-1))
371
  if not chunk:
372
  return None
373
  return np.stack(chunk, axis=0)
374
 
375
 
376
+ def _safe_array_str(x, precision=3, max_items=24):
377
+ arr = np.asarray(x).reshape(-1)
378
+ shown = arr[:max_items]
379
+ suffix = "" if arr.size <= max_items else f" ... +{arr.size - max_items} more"
 
380
  return np.array2string(shown, precision=precision, separator=", ") + suffix
381
 
382
 
 
390
 
391
  fig, ax = plt.subplots(figsize=(7, 3.2), dpi=140)
392
  t = np.arange(mixed_chunk.shape[0])
 
 
393
  max_dims = min(mixed_chunk.shape[1], 10)
394
+
395
  for d in range(max_dims):
396
  ax.plot(t, mixed_chunk[:, d], label=f"mixed[{d}]")
397
 
 
399
  robot_chunk = np.asarray(robot_chunk, dtype=np.float32)
400
  if robot_chunk.ndim == 1:
401
  robot_chunk = robot_chunk[:, None]
402
+ for d in range(min(robot_chunk.shape[1], max_dims)):
 
403
  ax.plot(t, robot_chunk[:, d], linestyle="--", alpha=0.55, label=f"robot[{d}]")
404
 
405
  ax.set_title("Action chunk")
 
408
  ax.grid(True, alpha=0.3)
409
  ax.legend(loc="upper right", fontsize=7, ncol=2)
410
  fig.tight_layout()
 
411
  fig.canvas.draw()
412
  rgba = np.asarray(fig.canvas.buffer_rgba())
413
  img = rgba[..., :3].copy()
 
415
  return img
416
 
417
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  def get_available_image_keys(traj_id):
419
+ n_traj = get_num_trajectories()
420
+ traj_id = int(np.clip(int(traj_id), 0, max(n_traj - 1, 0)))
421
+ traj = load_traj(traj_id)
422
  if not traj:
423
  return []
424
+
425
  obs = traj[0].get("obs", {})
426
+ image_keys = []
427
+ for key, value in obs.items():
428
  try:
429
+ arr = np.asarray(_extract_latest_obs_value(value))
430
+ key_l = str(key).lower()
431
+ key_hint = any(s in key_l for s in ["rgb", "image", "img", "camera", "cam"])
432
+ looks_like_shape = (
433
+ arr.ndim == 2
434
+ or (arr.ndim == 3 and (arr.shape[-1] in (1, 3, 4) or arr.shape[0] in (1, 3, 4)))
435
+ or (arr.ndim == 4 and arr.shape[0] in (1, 2, 3, 4) and arr.shape[1] in (1, 3, 4))
436
+ )
437
+ if key_hint or looks_like_shape:
438
+ image_keys.append(key)
439
  except Exception:
440
  pass
441
+
442
+ ordered = [k for k in PREFERRED_IMAGE_KEYS if k in image_keys]
443
+ ordered += [k for k in image_keys if k not in ordered]
444
  return ordered
445
 
446
 
447
+ # -----------------------------------------------------------------------------
448
+ # Gradio callbacks
449
+ # -----------------------------------------------------------------------------
450
  def update_after_traj_change(traj_id):
451
+ n_traj = get_num_trajectories()
452
+ traj_id = int(np.clip(int(traj_id), 0, max(n_traj - 1, 0)))
453
+ traj = load_traj(traj_id)
454
+ image_keys = get_available_image_keys(traj_id)
455
  max_step = max(len(traj) - 1, 0)
456
+ slider_max = max(max_step, 1) # Gradio requires min < max.
457
  return (
458
+ gr.update(maximum=slider_max, value=0),
459
+ gr.update(choices=image_keys, value=image_keys[:2]),
460
  )
461
 
462
 
463
  def render_frame(traj_id, timestep, image_keys, chunk_len, display_scale, reverse_channels):
464
+ n_traj = get_num_trajectories()
465
+ traj_id = int(np.clip(int(traj_id), 0, max(n_traj - 1, 0)))
 
 
 
466
  traj = load_traj(traj_id)
467
+
468
  if not traj:
469
+ return [], None, "No trajectory could be loaded. Open the HDF5 debug panel to inspect the file layout."
470
 
471
+ timestep = int(np.clip(int(timestep), 0, len(traj) - 1))
472
+ chunk_len = int(chunk_len)
473
+ display_scale = float(display_scale)
474
  step = traj[timestep]
475
  obs = step.get("obs", {})
476
 
 
486
  errors.append(f"Missing image key: {key}")
487
  continue
488
  try:
489
+ img = _extract_display_image(obs[key], reverse_channels=bool(reverse_channels), output_uint8=True)
 
 
 
 
490
  img = _resize_image_for_display(img, display_scale=display_scale)
491
  gallery.append((img, key))
492
  except Exception as exc:
493
  errors.append(f"{key}: {exc}")
494
 
495
+ mixed_chunk, chunk_sources = _extract_mixed_action_chunk(traj, timestep, chunk_length=chunk_len)
496
+ robot_chunk = _extract_robot_action_chunk(traj, timestep, chunk_length=chunk_len)
 
 
 
 
497
  action_plot = _make_action_chunk_plot(mixed_chunk, robot_chunk)
498
 
499
+ teacher_action = step.get("teacher_action", np.zeros(1, dtype=np.float32))
500
+ robot_action = step.get("robot_action", np.zeros(1, dtype=np.float32))
501
  no_teacher = bool(step.get("no_teacher_action", False))
502
  no_robot = bool(step.get("no_robot_action", False))
 
 
503
 
504
  info_lines = [
505
+ f"detected trajectories: {n_traj}",
506
  f"trajectory: {traj_id}",
507
  f"timestep: {timestep} / {len(traj) - 1}",
508
  f"no_teacher_action: {int(no_teacher)}",
 
510
  f"chunk_len: {chunk_len}",
511
  f"chunk source mask: {chunk_sources} (T=teacher, R=robot fallback)",
512
  "",
513
+ f"teacher_action: {_safe_array_str(teacher_action)}",
514
+ f"robot_action: {_safe_array_str(robot_action)}",
515
  ]
516
 
517
  if errors:
 
520
  return gallery, action_plot, "\n".join(info_lines)
521
 
522
 
523
+ # -----------------------------------------------------------------------------
524
+ # App
525
+ # -----------------------------------------------------------------------------
526
  def build_app():
527
  try:
528
  n_traj = get_num_trajectories()
529
+ first_keys = get_available_image_keys(0)
530
  startup_error = None
531
  except Exception as exc:
532
  n_traj = 1
533
  first_keys = []
534
  startup_error = repr(exc)
535
 
536
+ with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
537
  gr.Markdown(
538
+ "# HDF5 Trajectory Viewer\n"
539
+ "Standalone viewer: no local `TrajectoryBuffer` dependency.\n\n"
540
+ f"Detected trajectories: **{n_traj}**"
541
  )
542
 
543
  if startup_error is not None:
544
  gr.Markdown(
545
  "⚠️ **Startup warning**\n\n"
546
+ f"```text\n{startup_error}\n```"
 
 
547
  )
548
 
549
  with gr.Row():
550
  traj_slider = gr.Slider(
551
  minimum=0,
552
+ maximum=max(n_traj - 1, 1),
553
  value=0,
554
  step=1,
555
  label="Trajectory index",
 
584
  )
585
  reverse_channels = gr.Checkbox(
586
  value=True,
587
+ label="Reverse channels BGR↔RGB",
588
  )
589
 
590
  render_btn = gr.Button("Render frame", variant="primary")
 
596
  object_fit="contain",
597
  )
598
  action_plot = gr.Image(label="Action chunk plot", type="numpy")
599
+ info = gr.Textbox(label="Frame info", lines=13)
600
 
601
  with gr.Accordion("Debug: HDF5 tree", open=False):
602
  inspect_btn = gr.Button("Inspect HDF5 structure")
603
+ hdf5_tree = gr.Textbox(lines=22, label="HDF5 tree")
604
  inspect_btn.click(fn=inspect_hdf5_tree, outputs=hdf5_tree)
605
 
606
  traj_slider.change(