Zhaoting123 commited on
Commit
ca76c1e
·
verified ·
1 Parent(s): 7927dd6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -20
app.py CHANGED
@@ -505,6 +505,104 @@ def _make_action_chunk_plot(mixed_chunk, robot_chunk):
505
  return image
506
 
507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  # -----------------------------------------------------------------------------
509
  # Gradio callbacks
510
  # -----------------------------------------------------------------------------
@@ -615,25 +713,25 @@ def render_frame(
615
  step = traj[timestep]
616
  obs = step.get("obs", {})
617
 
618
- gallery_items = []
619
- warnings = []
620
- for key in image_keys:
621
- if key not in obs:
622
- warnings.append("Missing image key: {}".format(key))
623
- continue
624
- try:
625
- img = _extract_display_image(
626
- obs[key],
627
- reverse_channels=bool(reverse_channels),
628
- )
629
- img = _resize_image_for_display(img, display_scale)
630
- gallery_items.append((img, key))
631
- except Exception as exc:
632
- warnings.append("{}: {}".format(key, exc))
633
-
634
- mixed_chunk, source_mask = _extract_mixed_action_chunk(traj, timestep, chunk_len)
635
- robot_chunk = _extract_robot_action_chunk(traj, timestep, chunk_len)
636
- action_plot = _make_action_chunk_plot(mixed_chunk, robot_chunk)
637
 
638
  info_lines = [
639
  "dataset: {} / {}".format(repo_id, filename),
@@ -753,7 +851,16 @@ def build_app():
753
  label="Reverse channels BGR↔RGB",
754
  )
755
 
756
- render_btn = gr.Button("Render frame", variant="primary")
 
 
 
 
 
 
 
 
 
757
 
758
  gallery = gr.Gallery(
759
  label="Observation images",
@@ -874,6 +981,21 @@ def build_app():
874
  outputs=[gallery, action_plot, info],
875
  )
876
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
877
  inspect_btn.click(
878
  fn=inspect_hdf5_tree,
879
  inputs=[preset, custom_repo_id, custom_filename],
 
505
  return image
506
 
507
 
508
+ # -----------------------------------------------------------------------------
509
+ # Frame-level render cache
510
+ # -----------------------------------------------------------------------------
511
+ @lru_cache(maxsize=8192)
512
+ def get_cached_gallery_items(
513
+ repo_id,
514
+ filename,
515
+ traj_id,
516
+ timestep,
517
+ image_keys_tuple,
518
+ display_scale,
519
+ reverse_channels,
520
+ ):
521
+ """Cache decoded/resized observation images for one frame.
522
+
523
+ Gradio may briefly clear the image component while a callback is running.
524
+ This cache makes the callback fast after preloading, which largely removes
525
+ the black-frame effect when scrubbing.
526
+ """
527
+ traj = load_traj(repo_id, filename, int(traj_id))
528
+ timestep = int(np.clip(int(timestep), 0, len(traj) - 1))
529
+ obs = traj[timestep].get("obs", {})
530
+
531
+ gallery_items = []
532
+ warnings = []
533
+ for key in image_keys_tuple:
534
+ if key not in obs:
535
+ warnings.append("Missing image key: {}".format(key))
536
+ continue
537
+ try:
538
+ img = _extract_display_image(
539
+ obs[key],
540
+ reverse_channels=bool(reverse_channels),
541
+ )
542
+ img = _resize_image_for_display(img, float(display_scale))
543
+ gallery_items.append((img, key))
544
+ except Exception as exc:
545
+ warnings.append("{}: {}".format(key, exc))
546
+
547
+ return gallery_items, tuple(warnings)
548
+
549
+
550
+ @lru_cache(maxsize=8192)
551
+ def get_cached_action_plot(repo_id, filename, traj_id, timestep, chunk_len):
552
+ traj = load_traj(repo_id, filename, int(traj_id))
553
+ timestep = int(np.clip(int(timestep), 0, len(traj) - 1))
554
+ mixed_chunk, source_mask = _extract_mixed_action_chunk(traj, timestep, int(chunk_len))
555
+ robot_chunk = _extract_robot_action_chunk(traj, timestep, int(chunk_len))
556
+ return _make_action_chunk_plot(mixed_chunk, robot_chunk), source_mask
557
+
558
+
559
+ def preload_current_trajectory(
560
+ preset_name,
561
+ custom_repo_id,
562
+ custom_filename,
563
+ traj_id,
564
+ image_keys,
565
+ chunk_len,
566
+ display_scale,
567
+ reverse_channels,
568
+ ):
569
+ """Pre-render all selected observation frames for the current trajectory."""
570
+ repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
571
+ n_traj = get_num_trajectories(repo_id, filename)
572
+ if n_traj == 0:
573
+ return "No trajectories found."
574
+
575
+ traj_id = int(np.clip(int(traj_id), 0, n_traj - 1))
576
+ traj = load_traj(repo_id, filename, traj_id)
577
+ if not traj:
578
+ return "Trajectory could not be loaded."
579
+
580
+ if image_keys is None:
581
+ image_keys = []
582
+ if isinstance(image_keys, str):
583
+ image_keys = [image_keys]
584
+ image_keys_tuple = tuple(image_keys)
585
+
586
+ total = len(traj)
587
+ for t in range(total):
588
+ get_cached_gallery_items(
589
+ repo_id,
590
+ filename,
591
+ traj_id,
592
+ t,
593
+ image_keys_tuple,
594
+ float(display_scale),
595
+ bool(reverse_channels),
596
+ )
597
+ # The action plot is usually the slowest part. Preload it too.
598
+ get_cached_action_plot(repo_id, filename, traj_id, t, int(chunk_len))
599
+
600
+ status = "Preloaded trajectory {}".format(traj_id)
601
+ status += chr(10) + "Frames cached: {}".format(total)
602
+ status += chr(10) + "Image keys: {}".format(", ".join(image_keys_tuple) if image_keys_tuple else "none")
603
+ return status
604
+
605
+
606
  # -----------------------------------------------------------------------------
607
  # Gradio callbacks
608
  # -----------------------------------------------------------------------------
 
713
  step = traj[timestep]
714
  obs = step.get("obs", {})
715
 
716
+ image_keys_tuple = tuple(image_keys)
717
+ gallery_items, warnings_tuple = get_cached_gallery_items(
718
+ repo_id,
719
+ filename,
720
+ traj_id,
721
+ timestep,
722
+ image_keys_tuple,
723
+ display_scale,
724
+ bool(reverse_channels),
725
+ )
726
+ warnings = list(warnings_tuple)
727
+
728
+ action_plot, source_mask = get_cached_action_plot(
729
+ repo_id,
730
+ filename,
731
+ traj_id,
732
+ timestep,
733
+ chunk_len,
734
+ )
735
 
736
  info_lines = [
737
  "dataset: {} / {}".format(repo_id, filename),
 
851
  label="Reverse channels BGR↔RGB",
852
  )
853
 
854
+ with gr.Row():
855
+ render_btn = gr.Button("Render frame", variant="primary")
856
+ preload_btn = gr.Button("Preload current trajectory")
857
+
858
+ preload_status = gr.Textbox(
859
+ label="Preload status",
860
+ lines=3,
861
+ value="Not preloaded yet.",
862
+ interactive=False,
863
+ )
864
 
865
  gallery = gr.Gallery(
866
  label="Observation images",
 
981
  outputs=[gallery, action_plot, info],
982
  )
983
 
984
+ preload_btn.click(
985
+ fn=preload_current_trajectory,
986
+ inputs=[
987
+ preset,
988
+ custom_repo_id,
989
+ custom_filename,
990
+ traj_slider,
991
+ image_keys,
992
+ chunk_len,
993
+ display_scale,
994
+ reverse_channels,
995
+ ],
996
+ outputs=preload_status,
997
+ )
998
+
999
  inspect_btn.click(
1000
  fn=inspect_hdf5_tree,
1001
  inputs=[preset, custom_repo_id, custom_filename],