Zhaoting123 commited on
Commit
d85b559
·
verified ·
1 Parent(s): f7db5fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -45
app.py CHANGED
@@ -40,15 +40,32 @@ except Exception:
40
 
41
 
42
  # -----------------------------------------------------------------------------
43
- # EDIT THESE FOR YOUR DATASET
 
44
  # -----------------------------------------------------------------------------
45
- REPO_ID = "Zhaoting123/Robosuite_Square_image_abs_with_state"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  REPO_TYPE = "dataset"
47
- HDF5_FILENAME = (
48
- "20260409_205051_Diffusion_CLIC_intervention_Circular_square_image_abs_"
49
- "Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
50
- )
51
-
52
  DEFAULT_CHUNK_LEN = 16
53
  PREFERRED_IMAGE_KEYS = [
54
  "image1",
@@ -61,11 +78,34 @@ PREFERRED_IMAGE_KEYS = [
61
  # -----------------------------------------------------------------------------
62
  # HDF5 helpers
63
  # -----------------------------------------------------------------------------
64
- @lru_cache(maxsize=1)
65
- def get_local_hdf5_path():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  return hf_hub_download(
67
- repo_id=REPO_ID,
68
- filename=HDF5_FILENAME,
69
  repo_type=REPO_TYPE,
70
  )
71
 
@@ -75,10 +115,10 @@ def _natural_sort_key(name):
75
  return (0, int(m.group(1))) if m else (1, str(name))
76
 
77
 
78
- @lru_cache(maxsize=1)
79
- def get_trajectory_keys():
80
  """Detect trajectory groups in common HDF5 layouts."""
81
- path = get_local_hdf5_path()
82
  with h5py.File(path, "r") as f:
83
  # Your TrajectoryBuffer saves root-level groups:
84
  # /episode_0000
@@ -106,14 +146,15 @@ def get_trajectory_keys():
106
  return tuple(f"{prefix}/{k}" if prefix else k for k in group_keys)
107
 
108
 
109
- @lru_cache(maxsize=1)
110
- def get_num_trajectories():
111
- return max(len(get_trajectory_keys()), 1)
112
 
113
 
114
- def inspect_hdf5_tree(max_lines=160):
115
  """Show the HDF5 tree for debugging inside the Space."""
116
- path = get_local_hdf5_path()
 
117
  lines = []
118
  with h5py.File(path, "r") as f:
119
  def visitor(name, obj):
@@ -184,8 +225,8 @@ def _infer_time_length(data):
184
  return int(values[np.argmax(counts)])
185
 
186
 
187
- @lru_cache(maxsize=32)
188
- def load_traj(traj_id):
189
  """Load one trajectory as a list of step dictionaries.
190
 
191
  Output step format:
@@ -198,8 +239,8 @@ def load_traj(traj_id):
198
  "no_robot_action": bool,
199
  }
200
  """
201
- path = get_local_hdf5_path()
202
- traj_keys = get_trajectory_keys()
203
  if not traj_keys:
204
  return []
205
 
@@ -415,10 +456,10 @@ def _make_action_chunk_plot(mixed_chunk, robot_chunk=None):
415
  return img
416
 
417
 
418
- def get_available_image_keys(traj_id):
419
- n_traj = get_num_trajectories()
420
  traj_id = int(np.clip(int(traj_id), 0, max(n_traj - 1, 0)))
421
- traj = load_traj(traj_id)
422
  if not traj:
423
  return []
424
 
@@ -447,11 +488,12 @@ def get_available_image_keys(traj_id):
447
  # -----------------------------------------------------------------------------
448
  # Gradio callbacks
449
  # -----------------------------------------------------------------------------
450
- def update_after_traj_change(traj_id):
451
- n_traj = get_num_trajectories()
 
452
  traj_id = int(np.clip(int(traj_id), 0, max(n_traj - 1, 0)))
453
- traj = load_traj(traj_id)
454
- image_keys = get_available_image_keys(traj_id)
455
  max_step = max(len(traj) - 1, 0)
456
  slider_max = max(max_step, 1) # Gradio requires min < max.
457
  return (
@@ -460,10 +502,11 @@ def update_after_traj_change(traj_id):
460
  )
461
 
462
 
463
- def render_frame(traj_id, timestep, image_keys, chunk_len, display_scale, reverse_channels):
464
- n_traj = get_num_trajectories()
 
465
  traj_id = int(np.clip(int(traj_id), 0, max(n_traj - 1, 0)))
466
- traj = load_traj(traj_id)
467
 
468
  if not traj:
469
  return [], None, "No trajectory could be loaded. Open the HDF5 debug panel to inspect the file layout."
@@ -525,8 +568,9 @@ def render_frame(traj_id, timestep, image_keys, chunk_len, display_scale, revers
525
  # -----------------------------------------------------------------------------
526
  def build_app():
527
  try:
528
- n_traj = get_num_trajectories()
529
- first_keys = get_available_image_keys(0)
 
530
  startup_error = None
531
  except Exception as exc:
532
  n_traj = 1
@@ -537,7 +581,7 @@ def build_app():
537
  gr.Markdown(
538
  "# HDF5 Trajectory Viewer\n"
539
  "Standalone viewer: no local `TrajectoryBuffer` dependency.\n\n"
540
- f"Detected trajectories: **{n_traj}**"
541
  )
542
 
543
  if startup_error is not None:
@@ -546,6 +590,23 @@ def build_app():
546
  f"```text\n{startup_error}\n```"
547
  )
548
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
  with gr.Row():
550
  traj_slider = gr.Slider(
551
  minimum=0,
@@ -583,7 +644,7 @@ def build_app():
583
  label="Image display scale",
584
  )
585
  reverse_channels = gr.Checkbox(
586
- value=True,
587
  label="Reverse channels BGR↔RGB",
588
  )
589
 
@@ -601,38 +662,96 @@ def build_app():
601
  with gr.Accordion("Debug: HDF5 tree", open=False):
602
  inspect_btn = gr.Button("Inspect HDF5 structure")
603
  hdf5_tree = gr.Textbox(lines=22, label="HDF5 tree")
604
- inspect_btn.click(fn=inspect_hdf5_tree, outputs=hdf5_tree)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
605
 
606
  traj_slider.change(
607
  fn=update_after_traj_change,
608
- inputs=traj_slider,
609
  outputs=[timestep_slider, image_keys],
610
  ).then(
611
  fn=render_frame,
612
- inputs=[traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
 
 
 
 
 
 
 
 
613
  outputs=[gallery, action_plot, info],
614
  )
615
 
616
- for widget in [timestep_slider, image_keys, chunk_len, display_scale, reverse_channels]:
 
617
  widget.change(
618
  fn=render_frame,
619
- inputs=[traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
620
  outputs=[gallery, action_plot, info],
621
  )
622
 
623
  render_btn.click(
624
  fn=render_frame,
625
- inputs=[traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
626
  outputs=[gallery, action_plot, info],
627
  )
628
 
629
  demo.load(
630
  fn=update_after_traj_change,
631
- inputs=traj_slider,
632
  outputs=[timestep_slider, image_keys],
633
  ).then(
634
  fn=render_frame,
635
- inputs=[traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
636
  outputs=[gallery, action_plot, info],
637
  )
638
 
 
40
 
41
 
42
  # -----------------------------------------------------------------------------
43
+ # Dataset presets.
44
+ # The same Space can visualize multiple HDF5 files by changing repo_id + filename.
45
  # -----------------------------------------------------------------------------
46
+ DATASET_PRESETS = {
47
+ "Robosuite Square 20260409": {
48
+ "repo_id": "Zhaoting123/Robosuite_Square_image_abs_with_state",
49
+ "filename": (
50
+ "20260409_205051_Diffusion_CLIC_intervention_Circular_square_image_abs_"
51
+ "Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
52
+ ),
53
+ },
54
+ "Robosuite Square 20260410": {
55
+ "repo_id": "Zhaoting123/Robosuite_Square_image_abs_with_state",
56
+ "filename": (
57
+ "20260410_205606_Diffusion_CLIC_intervention_Circular_square_image_abs_"
58
+ "Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
59
+ ),
60
+ },
61
+ "InsertT Nov10 noisy": {
62
+ "repo_id": "Zhaoting123/InsertT",
63
+ "filename": "trajectory_buffer_Nov10_demo_noisy.hdf5",
64
+ },
65
+ }
66
+
67
+ DEFAULT_PRESET = "Robosuite Square 20260409"
68
  REPO_TYPE = "dataset"
 
 
 
 
 
69
  DEFAULT_CHUNK_LEN = 16
70
  PREFERRED_IMAGE_KEYS = [
71
  "image1",
 
78
  # -----------------------------------------------------------------------------
79
  # HDF5 helpers
80
  # -----------------------------------------------------------------------------
81
+ def _clear_dataset_caches():
82
+ get_local_hdf5_path.cache_clear()
83
+ get_trajectory_keys.cache_clear()
84
+ get_num_trajectories.cache_clear()
85
+ load_traj.cache_clear()
86
+
87
+
88
+ def resolve_dataset(preset_name, custom_repo_id=None, custom_filename=None):
89
+ """Return (repo_id, filename) from a preset or custom fields."""
90
+ preset_name = preset_name or DEFAULT_PRESET
91
+ if preset_name == "Custom":
92
+ repo_id = str(custom_repo_id or "").strip()
93
+ filename = str(custom_filename or "").strip()
94
+ if not repo_id or not filename:
95
+ raise ValueError("For Custom, provide both repo_id and HDF5 filename/path.")
96
+ return repo_id, filename
97
+
98
+ if preset_name not in DATASET_PRESETS:
99
+ preset_name = DEFAULT_PRESET
100
+ item = DATASET_PRESETS[preset_name]
101
+ return item["repo_id"], item["filename"]
102
+
103
+
104
+ @lru_cache(maxsize=8)
105
+ def get_local_hdf5_path(repo_id, filename):
106
  return hf_hub_download(
107
+ repo_id=repo_id,
108
+ filename=filename,
109
  repo_type=REPO_TYPE,
110
  )
111
 
 
115
  return (0, int(m.group(1))) if m else (1, str(name))
116
 
117
 
118
+ @lru_cache(maxsize=8)
119
+ def get_trajectory_keys(repo_id, filename):
120
  """Detect trajectory groups in common HDF5 layouts."""
121
+ path = get_local_hdf5_path(repo_id, filename)
122
  with h5py.File(path, "r") as f:
123
  # Your TrajectoryBuffer saves root-level groups:
124
  # /episode_0000
 
146
  return tuple(f"{prefix}/{k}" if prefix else k for k in group_keys)
147
 
148
 
149
+ @lru_cache(maxsize=8)
150
+ def get_num_trajectories(repo_id, filename):
151
+ return max(len(get_trajectory_keys(repo_id, filename)), 1)
152
 
153
 
154
+ def inspect_hdf5_tree(preset_name, custom_repo_id, custom_filename, max_lines=160):
155
  """Show the HDF5 tree for debugging inside the Space."""
156
+ repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
157
+ path = get_local_hdf5_path(repo_id, filename)
158
  lines = []
159
  with h5py.File(path, "r") as f:
160
  def visitor(name, obj):
 
225
  return int(values[np.argmax(counts)])
226
 
227
 
228
+ @lru_cache(maxsize=64)
229
+ def load_traj(repo_id, filename, traj_id):
230
  """Load one trajectory as a list of step dictionaries.
231
 
232
  Output step format:
 
239
  "no_robot_action": bool,
240
  }
241
  """
242
+ path = get_local_hdf5_path(repo_id, filename)
243
+ traj_keys = get_trajectory_keys(repo_id, filename)
244
  if not traj_keys:
245
  return []
246
 
 
456
  return img
457
 
458
 
459
+ def get_available_image_keys(repo_id, filename, traj_id):
460
+ n_traj = get_num_trajectories(repo_id, filename)
461
  traj_id = int(np.clip(int(traj_id), 0, max(n_traj - 1, 0)))
462
+ traj = load_traj(repo_id, filename, traj_id)
463
  if not traj:
464
  return []
465
 
 
488
  # -----------------------------------------------------------------------------
489
  # Gradio callbacks
490
  # -----------------------------------------------------------------------------
491
+ def update_after_traj_change(preset_name, custom_repo_id, custom_filename, traj_id):
492
+ repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
493
+ n_traj = get_num_trajectories(repo_id, filename)
494
  traj_id = int(np.clip(int(traj_id), 0, max(n_traj - 1, 0)))
495
+ traj = load_traj(repo_id, filename, traj_id)
496
+ image_keys = get_available_image_keys(repo_id, filename, traj_id)
497
  max_step = max(len(traj) - 1, 0)
498
  slider_max = max(max_step, 1) # Gradio requires min < max.
499
  return (
 
502
  )
503
 
504
 
505
+ def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep, image_keys, chunk_len, display_scale, reverse_channels):
506
+ repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
507
+ n_traj = get_num_trajectories(repo_id, filename)
508
  traj_id = int(np.clip(int(traj_id), 0, max(n_traj - 1, 0)))
509
+ traj = load_traj(repo_id, filename, traj_id)
510
 
511
  if not traj:
512
  return [], None, "No trajectory could be loaded. Open the HDF5 debug panel to inspect the file layout."
 
568
  # -----------------------------------------------------------------------------
569
  def build_app():
570
  try:
571
+ repo_id, filename = resolve_dataset(DEFAULT_PRESET)
572
+ n_traj = get_num_trajectories(repo_id, filename)
573
+ first_keys = get_available_image_keys(repo_id, filename, 0)
574
  startup_error = None
575
  except Exception as exc:
576
  n_traj = 1
 
581
  gr.Markdown(
582
  "# HDF5 Trajectory Viewer\n"
583
  "Standalone viewer: no local `TrajectoryBuffer` dependency.\n\n"
584
+ f"Default dataset detected trajectories: **{n_traj}**"
585
  )
586
 
587
  if startup_error is not None:
 
590
  f"```text\n{startup_error}\n```"
591
  )
592
 
593
+ with gr.Row():
594
+ preset = gr.Dropdown(
595
+ choices=list(DATASET_PRESETS.keys()) + ["Custom"],
596
+ value=DEFAULT_PRESET,
597
+ label="Dataset preset",
598
+ )
599
+ custom_repo_id = gr.Textbox(
600
+ value="",
601
+ label="Custom repo_id, e.g. Zhaoting123/InsertT",
602
+ visible=False,
603
+ )
604
+ custom_filename = gr.Textbox(
605
+ value="",
606
+ label="Custom HDF5 path in repo",
607
+ visible=False,
608
+ )
609
+
610
  with gr.Row():
611
  traj_slider = gr.Slider(
612
  minimum=0,
 
644
  label="Image display scale",
645
  )
646
  reverse_channels = gr.Checkbox(
647
+ value=False,
648
  label="Reverse channels BGR↔RGB",
649
  )
650
 
 
662
  with gr.Accordion("Debug: HDF5 tree", open=False):
663
  inspect_btn = gr.Button("Inspect HDF5 structure")
664
  hdf5_tree = gr.Textbox(lines=22, label="HDF5 tree")
665
+ inspect_btn.click(
666
+ fn=inspect_hdf5_tree,
667
+ inputs=[preset, custom_repo_id, custom_filename],
668
+ outputs=hdf5_tree,
669
+ )
670
+
671
+ def update_custom_visibility(preset_name):
672
+ visible = preset_name == "Custom"
673
+ return gr.update(visible=visible), gr.update(visible=visible)
674
+
675
+ def update_after_dataset_change(preset_name, custom_repo_id, custom_filename):
676
+ repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
677
+ n = get_num_trajectories(repo_id, filename)
678
+ keys = get_available_image_keys(repo_id, filename, 0)
679
+ traj = load_traj(repo_id, filename, 0)
680
+ return (
681
+ gr.update(maximum=max(n - 1, 1), value=0),
682
+ gr.update(maximum=max(len(traj) - 1, 1), value=0),
683
+ gr.update(choices=keys, value=keys[:2]),
684
+ f"Loaded `{repo_id}` / `{filename}`
685
+ Detected trajectories: {n}",
686
+ )
687
+
688
+ dataset_status = gr.Textbox(label="Dataset status", lines=2, value=f"Loaded default dataset
689
+ Detected trajectories: {n_traj}")
690
+
691
+ preset.change(
692
+ fn=update_custom_visibility,
693
+ inputs=preset,
694
+ outputs=[custom_repo_id, custom_filename],
695
+ ).then(
696
+ fn=update_after_dataset_change,
697
+ inputs=[preset, custom_repo_id, custom_filename],
698
+ outputs=[traj_slider, timestep_slider, image_keys, dataset_status],
699
+ ).then(
700
+ fn=render_frame,
701
+ inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
702
+ outputs=[gallery, action_plot, info],
703
+ )
704
+
705
+ custom_repo_id.submit(
706
+ fn=update_after_dataset_change,
707
+ inputs=[preset, custom_repo_id, custom_filename],
708
+ outputs=[traj_slider, timestep_slider, image_keys, dataset_status],
709
+ )
710
+ custom_filename.submit(
711
+ fn=update_after_dataset_change,
712
+ inputs=[preset, custom_repo_id, custom_filename],
713
+ outputs=[traj_slider, timestep_slider, image_keys, dataset_status],
714
+ )
715
 
716
  traj_slider.change(
717
  fn=update_after_traj_change,
718
+ inputs=[preset, custom_repo_id, custom_filename, traj_slider],
719
  outputs=[timestep_slider, image_keys],
720
  ).then(
721
  fn=render_frame,
722
+ inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
723
+ outputs=[gallery, action_plot, info],
724
+ )
725
+
726
+ # Use release for the timestep slider so the gallery does not clear/re-render
727
+ # continuously while the user drags through a trajectory.
728
+ timestep_slider.release(
729
+ fn=render_frame,
730
+ inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
731
  outputs=[gallery, action_plot, info],
732
  )
733
 
734
+ # These controls can re-render immediately because they are changed less often.
735
+ for widget in [image_keys, chunk_len, display_scale, reverse_channels]:
736
  widget.change(
737
  fn=render_frame,
738
+ inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
739
  outputs=[gallery, action_plot, info],
740
  )
741
 
742
  render_btn.click(
743
  fn=render_frame,
744
+ inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
745
  outputs=[gallery, action_plot, info],
746
  )
747
 
748
  demo.load(
749
  fn=update_after_traj_change,
750
+ inputs=[preset, custom_repo_id, custom_filename, traj_slider],
751
  outputs=[timestep_slider, image_keys],
752
  ).then(
753
  fn=render_frame,
754
+ inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
755
  outputs=[gallery, action_plot, info],
756
  )
757