Zhaoting123 commited on
Commit
b8b9e9d
·
verified ·
1 Parent(s): 3c79c68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -16
app.py CHANGED
@@ -487,7 +487,7 @@ def _make_trajectory_status_plot(traj, timestep, min_seq_len):
487
  no_teacher, valid_indices = _compute_valid_start_indices(traj, min_seq_len)
488
  is_valid_start = timestep in set(valid_indices)
489
 
490
- fig, ax = plt.subplots(figsize=(9, 2.4), dpi=140)
491
 
492
  ax.step(
493
  np.arange(total_steps),
@@ -510,8 +510,8 @@ def _make_trajectory_status_plot(traj, timestep, min_seq_len):
510
  ax.axvline(timestep, color="black", linestyle="-", alpha=0.85, linewidth=1.5)
511
  ax.set_xlim(0, max(total_steps - 1, 1))
512
  ax.set_ylim(-0.38, 1.1)
513
- ax.set_ylabel("Flag")
514
- ax.set_xlabel("Timestep index")
515
  ax.set_yticks([0, 1])
516
  ax.set_yticklabels(["False", "True"])
517
  ax.grid(True, axis="x", alpha=0.2)
@@ -519,8 +519,9 @@ def _make_trajectory_status_plot(traj, timestep, min_seq_len):
519
  title = "no_teacher_action | step {} / {}".format(timestep, total_steps - 1)
520
  if is_valid_start:
521
  title += " | VALID START"
522
- ax.set_title(title)
523
- ax.legend(loc="upper right", fontsize=8)
 
524
 
525
  # Add saved timestep annotation if the stored timestep is not the same as index.
526
  saved_timestep = int(timesteps[timestep]) if len(timesteps) else timestep
@@ -584,29 +585,39 @@ def _compose_video_frame(gallery_items, frame_label, status_plot=None):
584
 
585
  Top: selected observation images.
586
  Bottom: trajectory-status plot with the moving timestep cursor.
 
 
 
 
587
  """
 
 
588
  if not gallery_items:
589
  obs_canvas = Image.new("RGB", (640, 360), color=(20, 20, 20))
590
  draw = ImageDraw.Draw(obs_canvas)
591
- draw.text((16, 16), "No selected image keys", fill=(255, 255, 255))
592
  else:
593
  pil_images = []
594
  for img, label in gallery_items:
595
  pil_img = Image.fromarray(np.asarray(img, dtype=np.uint8)).convert("RGB")
596
- label_h = 24
 
 
597
  panel = Image.new("RGB", (pil_img.width, pil_img.height + label_h), color=(0, 0, 0))
598
  panel.paste(pil_img, (0, label_h))
599
  draw = ImageDraw.Draw(panel)
600
- draw.text((6, 4), str(label), fill=(255, 255, 255))
601
  pil_images.append(panel)
602
 
603
  gap = 8
604
- top_h = 28
605
  width = sum(im.width for im in pil_images) + gap * max(len(pil_images) - 1, 0)
606
  height = max(im.height for im in pil_images) + top_h
607
  obs_canvas = Image.new("RGB", (width, height), color=(0, 0, 0))
608
  draw = ImageDraw.Draw(obs_canvas)
609
- draw.text((8, 6), frame_label, fill=(255, 255, 255))
 
 
610
 
611
  x = 0
612
  for im in pil_images:
@@ -616,15 +627,22 @@ def _compose_video_frame(gallery_items, frame_label, status_plot=None):
616
  if status_plot is not None:
617
  status_img = Image.fromarray(np.asarray(status_plot, dtype=np.uint8)).convert("RGB")
618
 
619
- target_w = obs_canvas.width
620
- if status_img.width != target_w:
621
- target_h = max(1, int(round(status_img.height * float(target_w) / float(status_img.width))))
622
- status_img = status_img.resize((target_w, target_h), resample=Image.Resampling.BILINEAR)
 
 
 
 
 
 
 
623
 
624
  gap_h = 8
625
  canvas = Image.new(
626
  "RGB",
627
- (max(obs_canvas.width, status_img.width), obs_canvas.height + gap_h + status_img.height),
628
  color=(0, 0, 0),
629
  )
630
  canvas.paste(obs_canvas, (0, 0))
@@ -632,6 +650,7 @@ def _compose_video_frame(gallery_items, frame_label, status_plot=None):
632
  else:
633
  canvas = obs_canvas
634
 
 
635
  pad_w = int(np.ceil(canvas.width / 16.0) * 16)
636
  pad_h = int(np.ceil(canvas.height / 16.0) * 16)
637
  if pad_w != canvas.width or pad_h != canvas.height:
@@ -884,9 +903,9 @@ def build_app():
884
  video_fps = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Video FPS")
885
 
886
  preload_status = gr.Textbox(label="Preload / video status", lines=4, value="Not preloaded yet.", interactive=False)
887
- trajectory_video = gr.Video(label="Trajectory video: smooth browser-side playback")
888
 
889
  gallery = gr.Gallery(label="Observation images", columns=2, height="auto", object_fit="contain")
 
890
  status_plot = gr.Image(label="Trajectory status: no_teacher_action and valid starts", type="numpy")
891
  info = gr.Textbox(label="Frame info", lines=16)
892
 
 
487
  no_teacher, valid_indices = _compute_valid_start_indices(traj, min_seq_len)
488
  is_valid_start = timestep in set(valid_indices)
489
 
490
+ fig, ax = plt.subplots(figsize=(10.5, 2.8), dpi=170)
491
 
492
  ax.step(
493
  np.arange(total_steps),
 
510
  ax.axvline(timestep, color="black", linestyle="-", alpha=0.85, linewidth=1.5)
511
  ax.set_xlim(0, max(total_steps - 1, 1))
512
  ax.set_ylim(-0.38, 1.1)
513
+ ax.set_ylabel("Flag", fontsize=10)
514
+ ax.set_xlabel("Timestep index", fontsize=10)
515
  ax.set_yticks([0, 1])
516
  ax.set_yticklabels(["False", "True"])
517
  ax.grid(True, axis="x", alpha=0.2)
 
519
  title = "no_teacher_action | step {} / {}".format(timestep, total_steps - 1)
520
  if is_valid_start:
521
  title += " | VALID START"
522
+ ax.set_title(title, fontsize=11)
523
+ ax.tick_params(axis="both", labelsize=9)
524
+ ax.legend(loc="upper right", fontsize=9)
525
 
526
  # Add saved timestep annotation if the stored timestep is not the same as index.
527
  saved_timestep = int(timesteps[timestep]) if len(timesteps) else timestep
 
585
 
586
  Top: selected observation images.
587
  Bottom: trajectory-status plot with the moving timestep cursor.
588
+
589
+ Important: do NOT downscale the status plot to the image width. The plot
590
+ contains tick labels and a legend, so preserving its native width makes the
591
+ generated MP4 much more readable.
592
  """
593
+ small_text_y = 3
594
+
595
  if not gallery_items:
596
  obs_canvas = Image.new("RGB", (640, 360), color=(20, 20, 20))
597
  draw = ImageDraw.Draw(obs_canvas)
598
+ draw.text((8, small_text_y), "No selected image keys", fill=(255, 255, 255))
599
  else:
600
  pil_images = []
601
  for img, label in gallery_items:
602
  pil_img = Image.fromarray(np.asarray(img, dtype=np.uint8)).convert("RGB")
603
+
604
+ # Keep the image-key caption compact; large captions waste video space.
605
+ label_h = 16
606
  panel = Image.new("RGB", (pil_img.width, pil_img.height + label_h), color=(0, 0, 0))
607
  panel.paste(pil_img, (0, label_h))
608
  draw = ImageDraw.Draw(panel)
609
+ draw.text((4, small_text_y), str(label), fill=(220, 220, 220))
610
  pil_images.append(panel)
611
 
612
  gap = 8
613
+ top_h = 18
614
  width = sum(im.width for im in pil_images) + gap * max(len(pil_images) - 1, 0)
615
  height = max(im.height for im in pil_images) + top_h
616
  obs_canvas = Image.new("RGB", (width, height), color=(0, 0, 0))
617
  draw = ImageDraw.Draw(obs_canvas)
618
+
619
+ # Compact frame label above the image panels.
620
+ draw.text((6, small_text_y), frame_label, fill=(220, 220, 220))
621
 
622
  x = 0
623
  for im in pil_images:
 
627
  if status_plot is not None:
628
  status_img = Image.fromarray(np.asarray(status_plot, dtype=np.uint8)).convert("RGB")
629
 
630
+ # Preserve the status plot resolution. If needed, pad the observation
631
+ # canvas to the same width and center it above the plot.
632
+ final_w = max(obs_canvas.width, status_img.width)
633
+ if obs_canvas.width < final_w:
634
+ padded_obs = Image.new("RGB", (final_w, obs_canvas.height), color=(0, 0, 0))
635
+ padded_obs.paste(obs_canvas, ((final_w - obs_canvas.width) // 2, 0))
636
+ obs_canvas = padded_obs
637
+ elif status_img.width < final_w:
638
+ padded_status = Image.new("RGB", (final_w, status_img.height), color=(255, 255, 255))
639
+ padded_status.paste(status_img, ((final_w - status_img.width) // 2, 0))
640
+ status_img = padded_status
641
 
642
  gap_h = 8
643
  canvas = Image.new(
644
  "RGB",
645
+ (final_w, obs_canvas.height + gap_h + status_img.height),
646
  color=(0, 0, 0),
647
  )
648
  canvas.paste(obs_canvas, (0, 0))
 
650
  else:
651
  canvas = obs_canvas
652
 
653
+ # Many MP4 encoders prefer dimensions divisible by 16.
654
  pad_w = int(np.ceil(canvas.width / 16.0) * 16)
655
  pad_h = int(np.ceil(canvas.height / 16.0) * 16)
656
  if pad_w != canvas.width or pad_h != canvas.height:
 
903
  video_fps = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Video FPS")
904
 
905
  preload_status = gr.Textbox(label="Preload / video status", lines=4, value="Not preloaded yet.", interactive=False)
 
906
 
907
  gallery = gr.Gallery(label="Observation images", columns=2, height="auto", object_fit="contain")
908
+ trajectory_video = gr.Video(label="Trajectory video: smooth browser-side playback")
909
  status_plot = gr.Image(label="Trajectory status: no_teacher_action and valid starts", type="numpy")
910
  info = gr.Textbox(label="Frame info", lines=16)
911