Zhaoting123 commited on
Commit
3cd8b50
·
verified ·
1 Parent(s): 16b8b5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -2
app.py CHANGED
@@ -268,22 +268,46 @@ def load_traj(repo_id, filename, traj_id):
268
 
269
 
270
  def _extract_latest_obs_value(value):
 
 
 
 
 
 
 
271
  arr = np.asarray(value)
272
- if arr.ndim >= 1 and arr.shape[0] in (1, 2, 3, 4):
 
 
 
 
 
 
 
 
 
273
  return arr[-1]
 
274
  return arr
275
 
276
 
277
  def _looks_like_image_array(key, value):
278
- arr = np.asarray(_extract_latest_obs_value(value))
279
  key_l = str(key).lower()
280
  key_hint = any(hint in key_l for hint in IMAGE_KEY_HINTS)
281
 
 
 
 
 
 
282
  shape_hint = False
283
  if arr.ndim == 2:
284
  shape_hint = True
285
  elif arr.ndim == 3:
286
  shape_hint = arr.shape[-1] in (1, 3, 4) or arr.shape[0] in (1, 3, 4)
 
 
287
 
288
  return key_hint or shape_hint
289
 
@@ -918,6 +942,14 @@ def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep
918
 
919
  status_plot, is_valid_start, num_valid_starts = get_cached_status_plot(repo_id, filename, traj_id, timestep, chunk_len)
920
 
 
 
 
 
 
 
 
 
921
  info_lines = [
922
  "dataset: {} / {}".format(repo_id, filename),
923
  "detected trajectories: {}".format(n_traj),
@@ -935,6 +967,9 @@ def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep
935
  "",
936
  "teacher_action: {}".format(_safe_array_str(step.get("teacher_action", []))),
937
  "robot_action: {}".format(_safe_array_str(step.get("robot_action", []))),
 
 
 
938
  ]
939
 
940
  if warnings:
 
268
 
269
 
270
  def _extract_latest_obs_value(value):
271
+ """Return the latest stacked observation only when there is a clear stack axis.
272
+
273
+ Important:
274
+ - [obs_T, C, H, W] or [obs_T, H, W, C] should become the latest frame.
275
+ - [C, H, W] must NOT be sliced, otherwise an RGB image becomes one
276
+ grayscale channel.
277
+ """
278
  arr = np.asarray(value)
279
+
280
+ # Stacked image observations, e.g. [obs_T, C, H, W] or [obs_T, H, W, C].
281
+ if arr.ndim == 4 and arr.shape[0] in (1, 2, 3, 4):
282
+ channel_first = arr.shape[1] in (1, 3, 4)
283
+ channel_last = arr.shape[-1] in (1, 3, 4)
284
+ if channel_first or channel_last:
285
+ return arr[-1]
286
+
287
+ # Stacked vector observations, e.g. [obs_T, D]. Keep this for non-image obs.
288
+ if arr.ndim == 2 and arr.shape[0] in (1, 2):
289
  return arr[-1]
290
+
291
  return arr
292
 
293
 
294
  def _looks_like_image_array(key, value):
295
+ arr = np.asarray(value)
296
  key_l = str(key).lower()
297
  key_hint = any(hint in key_l for hint in IMAGE_KEY_HINTS)
298
 
299
+ # Remove only a clear stacked-image axis for shape detection.
300
+ if arr.ndim == 4 and arr.shape[0] in (1, 2, 3, 4):
301
+ if arr.shape[1] in (1, 3, 4) or arr.shape[-1] in (1, 3, 4):
302
+ arr = arr[-1]
303
+
304
  shape_hint = False
305
  if arr.ndim == 2:
306
  shape_hint = True
307
  elif arr.ndim == 3:
308
  shape_hint = arr.shape[-1] in (1, 3, 4) or arr.shape[0] in (1, 3, 4)
309
+ elif arr.ndim == 4:
310
+ shape_hint = arr.shape[1] in (1, 3, 4) or arr.shape[-1] in (1, 3, 4)
311
 
312
  return key_hint or shape_hint
313
 
 
942
 
943
  status_plot, is_valid_start, num_valid_starts = get_cached_status_plot(repo_id, filename, traj_id, timestep, chunk_len)
944
 
945
+ image_debug_lines = []
946
+ for _key in image_keys:
947
+ if _key in step.get("obs", {}):
948
+ _arr = np.asarray(step["obs"][_key])
949
+ image_debug_lines.append(
950
+ "{} shape={} dtype={}".format(_key, tuple(_arr.shape), _arr.dtype)
951
+ )
952
+
953
  info_lines = [
954
  "dataset: {} / {}".format(repo_id, filename),
955
  "detected trajectories: {}".format(n_traj),
 
967
  "",
968
  "teacher_action: {}".format(_safe_array_str(step.get("teacher_action", []))),
969
  "robot_action: {}".format(_safe_array_str(step.get("robot_action", []))),
970
+ "",
971
+ "selected image tensors:",
972
+ *image_debug_lines,
973
  ]
974
 
975
  if warnings: