Zhaoting123 commited on
Commit
cdbfe0c
·
verified ·
1 Parent(s): f2f511d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -18
app.py CHANGED
@@ -445,13 +445,110 @@ def get_cached_gallery_items(repo_id, filename, traj_id, timestep, image_keys_tu
445
  return gallery_items, tuple(warnings)
446
 
447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  @lru_cache(maxsize=8192)
449
- def get_cached_action_plot(repo_id, filename, traj_id, timestep, chunk_len):
450
  traj = load_traj(repo_id, filename, int(traj_id))
451
  timestep = int(np.clip(int(timestep), 0, len(traj) - 1))
452
- mixed_chunk, source_mask = _extract_mixed_action_chunk(traj, timestep, int(chunk_len))
453
- robot_chunk = _extract_robot_action_chunk(traj, timestep, int(chunk_len))
454
- return _make_action_chunk_plot(mixed_chunk, robot_chunk), source_mask
455
 
456
 
457
  def preload_current_trajectory(preset_name, custom_repo_id, custom_filename, traj_id, image_keys, chunk_len, display_scale, reverse_channels):
@@ -474,7 +571,7 @@ def preload_current_trajectory(preset_name, custom_repo_id, custom_filename, tra
474
  total = len(traj)
475
  for t in range(total):
476
  get_cached_gallery_items(repo_id, filename, traj_id, t, image_keys_tuple, float(display_scale), bool(reverse_channels))
477
- get_cached_action_plot(repo_id, filename, traj_id, t, int(chunk_len))
478
 
479
  status = "Preloaded trajectory {}".format(traj_id)
480
  status += "\nFrames cached: {}".format(total)
@@ -675,7 +772,7 @@ def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep
675
  )
676
  warnings = list(warnings_tuple)
677
 
678
- action_plot, source_mask = get_cached_action_plot(repo_id, filename, traj_id, timestep, chunk_len)
679
 
680
  info_lines = [
681
  "dataset: {} / {}".format(repo_id, filename),
@@ -688,8 +785,9 @@ def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep
688
  "if_success: {}".format(int(bool(step.get("if_success", False)))),
689
  "no_teacher_action: {}".format(int(bool(step.get("no_teacher_action", False)))),
690
  "no_robot_action: {}".format(int(bool(step.get("no_robot_action", False)))),
691
- "chunk_len: {}".format(chunk_len),
692
- "chunk source mask: {} (T=teacher, R=robot fallback)".format(source_mask),
 
693
  "",
694
  "teacher_action: {}".format(_safe_array_str(step.get("teacher_action", []))),
695
  "robot_action: {}".format(_safe_array_str(step.get("robot_action", []))),
@@ -700,7 +798,7 @@ def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep
700
  info_lines.append("Image warnings:")
701
  info_lines.extend(warnings)
702
 
703
- return gallery_items, action_plot, "\n".join(info_lines)
704
 
705
 
706
  def build_app():
@@ -720,7 +818,8 @@ def build_app():
720
  with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
721
  gr.Markdown(
722
  "# HDF5 Trajectory Viewer\n\n"
723
- "Standalone viewer for TrajectoryBuffer-style HDF5 datasets on Hugging Face."
 
724
  )
725
 
726
  if startup_warning:
@@ -743,7 +842,7 @@ def build_app():
743
 
744
  with gr.Row():
745
  image_keys = gr.CheckboxGroup(choices=first_keys, value=first_keys[:2], label="Image keys")
746
- chunk_len = gr.Slider(minimum=1, maximum=64, value=DEFAULT_CHUNK_LEN, step=1, label="Action chunk length")
747
  display_scale = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Image display scale")
748
  reverse_channels = gr.Checkbox(value=False, label="Reverse channels BGR↔RGB")
749
 
@@ -757,7 +856,7 @@ def build_app():
757
  trajectory_video = gr.Video(label="Trajectory video: smooth browser-side playback")
758
 
759
  gallery = gr.Gallery(label="Observation images", columns=2, height="auto", object_fit="contain")
760
- action_plot = gr.Image(label="Action chunk plot", type="numpy")
761
  info = gr.Textbox(label="Frame info", lines=16)
762
 
763
  with gr.Accordion("Debug: HDF5 tree", open=False):
@@ -775,7 +874,7 @@ def build_app():
775
  ).then(
776
  fn=render_frame,
777
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
778
- outputs=[gallery, action_plot, info],
779
  )
780
 
781
  custom_repo_id.submit(
@@ -796,26 +895,26 @@ def build_app():
796
  ).then(
797
  fn=render_frame,
798
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
799
- outputs=[gallery, action_plot, info],
800
  )
801
 
802
  timestep_slider.release(
803
  fn=render_frame,
804
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
805
- outputs=[gallery, action_plot, info],
806
  )
807
 
808
  for widget in [image_keys, chunk_len, display_scale, reverse_channels]:
809
  widget.change(
810
  fn=render_frame,
811
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
812
- outputs=[gallery, action_plot, info],
813
  )
814
 
815
  render_btn.click(
816
  fn=render_frame,
817
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
818
- outputs=[gallery, action_plot, info],
819
  )
820
 
821
  preload_btn.click(
@@ -843,7 +942,7 @@ def build_app():
843
  ).then(
844
  fn=render_frame,
845
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
846
- outputs=[gallery, action_plot, info],
847
  )
848
 
849
  return demo
 
445
  return gallery_items, tuple(warnings)
446
 
447
 
448
+ def _compute_valid_start_indices(traj, min_seq_len):
449
+ """Match the original local script's valid-start heuristic.
450
+
451
+ A timestep is valid when the following min_seq_len steps all have
452
+ no_teacher_action == False.
453
+ """
454
+ total_steps = len(traj)
455
+ min_seq_len = int(max(1, min_seq_len))
456
+ no_teacher = np.asarray(
457
+ [int(bool(step.get("no_teacher_action", False))) for step in traj],
458
+ dtype=np.int32,
459
+ )
460
+
461
+ valid_indices = []
462
+ max_start = total_steps - min_seq_len + 1
463
+ for t in range(max(0, max_start)):
464
+ if int(np.sum(no_teacher[t:t + min_seq_len])) == 0:
465
+ valid_indices.append(t)
466
+
467
+ return no_teacher, valid_indices
468
+
469
+
470
+ def _make_trajectory_status_plot(traj, timestep, min_seq_len):
471
+ """Render the same high-level status figure as the local matplotlib tool.
472
+
473
+ Shows:
474
+ - orange no_teacher_action step plot
475
+ - green triangles for algorithmic valid start points
476
+ - black vertical cursor at current timestep
477
+ """
478
+ total_steps = len(traj)
479
+ if total_steps == 0:
480
+ return None, False, 0
481
+
482
+ timestep = int(np.clip(int(timestep), 0, total_steps - 1))
483
+ timesteps = np.asarray(
484
+ [int(np.asarray(step.get("timestep", idx)).reshape(-1)[0]) for idx, step in enumerate(traj)],
485
+ dtype=np.int32,
486
+ )
487
+ no_teacher, valid_indices = _compute_valid_start_indices(traj, min_seq_len)
488
+ is_valid_start = timestep in set(valid_indices)
489
+
490
+ fig, ax = plt.subplots(figsize=(9, 2.4), dpi=140)
491
+
492
+ ax.step(
493
+ np.arange(total_steps),
494
+ no_teacher,
495
+ where="post",
496
+ label="no_teacher_action",
497
+ color="orange",
498
+ )
499
+
500
+ if valid_indices:
501
+ ax.scatter(
502
+ valid_indices,
503
+ [-0.15] * len(valid_indices),
504
+ color="green",
505
+ marker="^",
506
+ s=18,
507
+ label="Valid Start (len >= {})".format(int(min_seq_len)),
508
+ )
509
+
510
+ ax.axvline(timestep, color="black", linestyle="-", alpha=0.85, linewidth=1.5)
511
+ ax.set_xlim(0, max(total_steps - 1, 1))
512
+ ax.set_ylim(-0.38, 1.1)
513
+ ax.set_ylabel("Flag")
514
+ ax.set_xlabel("Timestep index")
515
+ ax.set_yticks([0, 1])
516
+ ax.set_yticklabels(["False", "True"])
517
+ ax.grid(True, axis="x", alpha=0.2)
518
+
519
+ title = "no_teacher_action | step {} / {}".format(timestep, total_steps - 1)
520
+ if is_valid_start:
521
+ title += " | VALID START"
522
+ ax.set_title(title)
523
+ ax.legend(loc="upper right", fontsize=8)
524
+
525
+ # Add saved timestep annotation if the stored timestep is not the same as index.
526
+ saved_timestep = int(timesteps[timestep]) if len(timesteps) else timestep
527
+ if saved_timestep != timestep:
528
+ ax.text(
529
+ 0.01,
530
+ 0.04,
531
+ "saved timestep: {}".format(saved_timestep),
532
+ transform=ax.transAxes,
533
+ fontsize=8,
534
+ va="bottom",
535
+ ha="left",
536
+ )
537
+
538
+ fig.tight_layout()
539
+ fig.canvas.draw()
540
+ rgba = np.asarray(fig.canvas.buffer_rgba())
541
+ image = rgba[..., :3].copy()
542
+ plt.close(fig)
543
+
544
+ return image, bool(is_valid_start), len(valid_indices)
545
+
546
+
547
  @lru_cache(maxsize=8192)
548
+ def get_cached_status_plot(repo_id, filename, traj_id, timestep, min_seq_len):
549
  traj = load_traj(repo_id, filename, int(traj_id))
550
  timestep = int(np.clip(int(timestep), 0, len(traj) - 1))
551
+ return _make_trajectory_status_plot(traj, timestep, int(min_seq_len))
 
 
552
 
553
 
554
  def preload_current_trajectory(preset_name, custom_repo_id, custom_filename, traj_id, image_keys, chunk_len, display_scale, reverse_channels):
 
571
  total = len(traj)
572
  for t in range(total):
573
  get_cached_gallery_items(repo_id, filename, traj_id, t, image_keys_tuple, float(display_scale), bool(reverse_channels))
574
+ get_cached_status_plot(repo_id, filename, traj_id, t, int(chunk_len))
575
 
576
  status = "Preloaded trajectory {}".format(traj_id)
577
  status += "\nFrames cached: {}".format(total)
 
772
  )
773
  warnings = list(warnings_tuple)
774
 
775
+ status_plot, is_valid_start, num_valid_starts = get_cached_status_plot(repo_id, filename, traj_id, timestep, chunk_len)
776
 
777
  info_lines = [
778
  "dataset: {} / {}".format(repo_id, filename),
 
785
  "if_success: {}".format(int(bool(step.get("if_success", False)))),
786
  "no_teacher_action: {}".format(int(bool(step.get("no_teacher_action", False)))),
787
  "no_robot_action: {}".format(int(bool(step.get("no_robot_action", False)))),
788
+ "valid-window length: {}".format(chunk_len),
789
+ "valid_start: {}".format(int(bool(is_valid_start))),
790
+ "num_valid_starts: {}".format(num_valid_starts),
791
  "",
792
  "teacher_action: {}".format(_safe_array_str(step.get("teacher_action", []))),
793
  "robot_action: {}".format(_safe_array_str(step.get("robot_action", []))),
 
798
  info_lines.append("Image warnings:")
799
  info_lines.extend(warnings)
800
 
801
+ return gallery_items, status_plot, "\n".join(info_lines)
802
 
803
 
804
  def build_app():
 
818
  with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
819
  gr.Markdown(
820
  "# HDF5 Trajectory Viewer\n\n"
821
+ "Standalone viewer for TrajectoryBuffer-style HDF5 datasets on Hugging Face.\n\n"
822
+ "The status plot matches the local labeling view: orange `no_teacher_action`, green valid-start markers, and a black timestep cursor."
823
  )
824
 
825
  if startup_warning:
 
842
 
843
  with gr.Row():
844
  image_keys = gr.CheckboxGroup(choices=first_keys, value=first_keys[:2], label="Image keys")
845
+ chunk_len = gr.Slider(minimum=1, maximum=64, value=DEFAULT_CHUNK_LEN, step=1, label="Valid-window length")
846
  display_scale = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Image display scale")
847
  reverse_channels = gr.Checkbox(value=False, label="Reverse channels BGR↔RGB")
848
 
 
856
  trajectory_video = gr.Video(label="Trajectory video: smooth browser-side playback")
857
 
858
  gallery = gr.Gallery(label="Observation images", columns=2, height="auto", object_fit="contain")
859
+ status_plot = gr.Image(label="Trajectory status: no_teacher_action and valid starts", type="numpy")
860
  info = gr.Textbox(label="Frame info", lines=16)
861
 
862
  with gr.Accordion("Debug: HDF5 tree", open=False):
 
874
  ).then(
875
  fn=render_frame,
876
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
877
+ outputs=[gallery, status_plot, info],
878
  )
879
 
880
  custom_repo_id.submit(
 
895
  ).then(
896
  fn=render_frame,
897
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
898
+ outputs=[gallery, status_plot, info],
899
  )
900
 
901
  timestep_slider.release(
902
  fn=render_frame,
903
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
904
+ outputs=[gallery, status_plot, info],
905
  )
906
 
907
  for widget in [image_keys, chunk_len, display_scale, reverse_channels]:
908
  widget.change(
909
  fn=render_frame,
910
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
911
+ outputs=[gallery, status_plot, info],
912
  )
913
 
914
  render_btn.click(
915
  fn=render_frame,
916
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
917
+ outputs=[gallery, status_plot, info],
918
  )
919
 
920
  preload_btn.click(
 
942
  ).then(
943
  fn=render_frame,
944
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
945
+ outputs=[gallery, status_plot, info],
946
  )
947
 
948
  return demo