Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -445,13 +445,110 @@ def get_cached_gallery_items(repo_id, filename, traj_id, timestep, image_keys_tu
|
|
| 445 |
return gallery_items, tuple(warnings)
|
| 446 |
|
| 447 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
@lru_cache(maxsize=8192)
|
| 449 |
-
def
|
| 450 |
traj = load_traj(repo_id, filename, int(traj_id))
|
| 451 |
timestep = int(np.clip(int(timestep), 0, len(traj) - 1))
|
| 452 |
-
|
| 453 |
-
robot_chunk = _extract_robot_action_chunk(traj, timestep, int(chunk_len))
|
| 454 |
-
return _make_action_chunk_plot(mixed_chunk, robot_chunk), source_mask
|
| 455 |
|
| 456 |
|
| 457 |
def preload_current_trajectory(preset_name, custom_repo_id, custom_filename, traj_id, image_keys, chunk_len, display_scale, reverse_channels):
|
|
@@ -474,7 +571,7 @@ def preload_current_trajectory(preset_name, custom_repo_id, custom_filename, tra
|
|
| 474 |
total = len(traj)
|
| 475 |
for t in range(total):
|
| 476 |
get_cached_gallery_items(repo_id, filename, traj_id, t, image_keys_tuple, float(display_scale), bool(reverse_channels))
|
| 477 |
-
|
| 478 |
|
| 479 |
status = "Preloaded trajectory {}".format(traj_id)
|
| 480 |
status += "\nFrames cached: {}".format(total)
|
|
@@ -675,7 +772,7 @@ def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep
|
|
| 675 |
)
|
| 676 |
warnings = list(warnings_tuple)
|
| 677 |
|
| 678 |
-
|
| 679 |
|
| 680 |
info_lines = [
|
| 681 |
"dataset: {} / {}".format(repo_id, filename),
|
|
@@ -688,8 +785,9 @@ def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep
|
|
| 688 |
"if_success: {}".format(int(bool(step.get("if_success", False)))),
|
| 689 |
"no_teacher_action: {}".format(int(bool(step.get("no_teacher_action", False)))),
|
| 690 |
"no_robot_action: {}".format(int(bool(step.get("no_robot_action", False)))),
|
| 691 |
-
"
|
| 692 |
-
"
|
|
|
|
| 693 |
"",
|
| 694 |
"teacher_action: {}".format(_safe_array_str(step.get("teacher_action", []))),
|
| 695 |
"robot_action: {}".format(_safe_array_str(step.get("robot_action", []))),
|
|
@@ -700,7 +798,7 @@ def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep
|
|
| 700 |
info_lines.append("Image warnings:")
|
| 701 |
info_lines.extend(warnings)
|
| 702 |
|
| 703 |
-
return gallery_items,
|
| 704 |
|
| 705 |
|
| 706 |
def build_app():
|
|
@@ -720,7 +818,8 @@ def build_app():
|
|
| 720 |
with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
|
| 721 |
gr.Markdown(
|
| 722 |
"# HDF5 Trajectory Viewer\n\n"
|
| 723 |
-
"Standalone viewer for TrajectoryBuffer-style HDF5 datasets on Hugging Face."
|
|
|
|
| 724 |
)
|
| 725 |
|
| 726 |
if startup_warning:
|
|
@@ -743,7 +842,7 @@ def build_app():
|
|
| 743 |
|
| 744 |
with gr.Row():
|
| 745 |
image_keys = gr.CheckboxGroup(choices=first_keys, value=first_keys[:2], label="Image keys")
|
| 746 |
-
chunk_len = gr.Slider(minimum=1, maximum=64, value=DEFAULT_CHUNK_LEN, step=1, label="
|
| 747 |
display_scale = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Image display scale")
|
| 748 |
reverse_channels = gr.Checkbox(value=False, label="Reverse channels BGR↔RGB")
|
| 749 |
|
|
@@ -757,7 +856,7 @@ def build_app():
|
|
| 757 |
trajectory_video = gr.Video(label="Trajectory video: smooth browser-side playback")
|
| 758 |
|
| 759 |
gallery = gr.Gallery(label="Observation images", columns=2, height="auto", object_fit="contain")
|
| 760 |
-
|
| 761 |
info = gr.Textbox(label="Frame info", lines=16)
|
| 762 |
|
| 763 |
with gr.Accordion("Debug: HDF5 tree", open=False):
|
|
@@ -775,7 +874,7 @@ def build_app():
|
|
| 775 |
).then(
|
| 776 |
fn=render_frame,
|
| 777 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
| 778 |
-
outputs=[gallery,
|
| 779 |
)
|
| 780 |
|
| 781 |
custom_repo_id.submit(
|
|
@@ -796,26 +895,26 @@ def build_app():
|
|
| 796 |
).then(
|
| 797 |
fn=render_frame,
|
| 798 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
| 799 |
-
outputs=[gallery,
|
| 800 |
)
|
| 801 |
|
| 802 |
timestep_slider.release(
|
| 803 |
fn=render_frame,
|
| 804 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
| 805 |
-
outputs=[gallery,
|
| 806 |
)
|
| 807 |
|
| 808 |
for widget in [image_keys, chunk_len, display_scale, reverse_channels]:
|
| 809 |
widget.change(
|
| 810 |
fn=render_frame,
|
| 811 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
| 812 |
-
outputs=[gallery,
|
| 813 |
)
|
| 814 |
|
| 815 |
render_btn.click(
|
| 816 |
fn=render_frame,
|
| 817 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
| 818 |
-
outputs=[gallery,
|
| 819 |
)
|
| 820 |
|
| 821 |
preload_btn.click(
|
|
@@ -843,7 +942,7 @@ def build_app():
|
|
| 843 |
).then(
|
| 844 |
fn=render_frame,
|
| 845 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
| 846 |
-
outputs=[gallery,
|
| 847 |
)
|
| 848 |
|
| 849 |
return demo
|
|
|
|
| 445 |
return gallery_items, tuple(warnings)
|
| 446 |
|
| 447 |
|
| 448 |
+
def _compute_valid_start_indices(traj, min_seq_len):
|
| 449 |
+
"""Match the original local script's valid-start heuristic.
|
| 450 |
+
|
| 451 |
+
A timestep is valid when the following min_seq_len steps all have
|
| 452 |
+
no_teacher_action == False.
|
| 453 |
+
"""
|
| 454 |
+
total_steps = len(traj)
|
| 455 |
+
min_seq_len = int(max(1, min_seq_len))
|
| 456 |
+
no_teacher = np.asarray(
|
| 457 |
+
[int(bool(step.get("no_teacher_action", False))) for step in traj],
|
| 458 |
+
dtype=np.int32,
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
valid_indices = []
|
| 462 |
+
max_start = total_steps - min_seq_len + 1
|
| 463 |
+
for t in range(max(0, max_start)):
|
| 464 |
+
if int(np.sum(no_teacher[t:t + min_seq_len])) == 0:
|
| 465 |
+
valid_indices.append(t)
|
| 466 |
+
|
| 467 |
+
return no_teacher, valid_indices
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def _make_trajectory_status_plot(traj, timestep, min_seq_len):
|
| 471 |
+
"""Render the same high-level status figure as the local matplotlib tool.
|
| 472 |
+
|
| 473 |
+
Shows:
|
| 474 |
+
- orange no_teacher_action step plot
|
| 475 |
+
- green triangles for algorithmic valid start points
|
| 476 |
+
- black vertical cursor at current timestep
|
| 477 |
+
"""
|
| 478 |
+
total_steps = len(traj)
|
| 479 |
+
if total_steps == 0:
|
| 480 |
+
return None, False, 0
|
| 481 |
+
|
| 482 |
+
timestep = int(np.clip(int(timestep), 0, total_steps - 1))
|
| 483 |
+
timesteps = np.asarray(
|
| 484 |
+
[int(np.asarray(step.get("timestep", idx)).reshape(-1)[0]) for idx, step in enumerate(traj)],
|
| 485 |
+
dtype=np.int32,
|
| 486 |
+
)
|
| 487 |
+
no_teacher, valid_indices = _compute_valid_start_indices(traj, min_seq_len)
|
| 488 |
+
is_valid_start = timestep in set(valid_indices)
|
| 489 |
+
|
| 490 |
+
fig, ax = plt.subplots(figsize=(9, 2.4), dpi=140)
|
| 491 |
+
|
| 492 |
+
ax.step(
|
| 493 |
+
np.arange(total_steps),
|
| 494 |
+
no_teacher,
|
| 495 |
+
where="post",
|
| 496 |
+
label="no_teacher_action",
|
| 497 |
+
color="orange",
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
if valid_indices:
|
| 501 |
+
ax.scatter(
|
| 502 |
+
valid_indices,
|
| 503 |
+
[-0.15] * len(valid_indices),
|
| 504 |
+
color="green",
|
| 505 |
+
marker="^",
|
| 506 |
+
s=18,
|
| 507 |
+
label="Valid Start (len >= {})".format(int(min_seq_len)),
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
ax.axvline(timestep, color="black", linestyle="-", alpha=0.85, linewidth=1.5)
|
| 511 |
+
ax.set_xlim(0, max(total_steps - 1, 1))
|
| 512 |
+
ax.set_ylim(-0.38, 1.1)
|
| 513 |
+
ax.set_ylabel("Flag")
|
| 514 |
+
ax.set_xlabel("Timestep index")
|
| 515 |
+
ax.set_yticks([0, 1])
|
| 516 |
+
ax.set_yticklabels(["False", "True"])
|
| 517 |
+
ax.grid(True, axis="x", alpha=0.2)
|
| 518 |
+
|
| 519 |
+
title = "no_teacher_action | step {} / {}".format(timestep, total_steps - 1)
|
| 520 |
+
if is_valid_start:
|
| 521 |
+
title += " | VALID START"
|
| 522 |
+
ax.set_title(title)
|
| 523 |
+
ax.legend(loc="upper right", fontsize=8)
|
| 524 |
+
|
| 525 |
+
# Add saved timestep annotation if the stored timestep is not the same as index.
|
| 526 |
+
saved_timestep = int(timesteps[timestep]) if len(timesteps) else timestep
|
| 527 |
+
if saved_timestep != timestep:
|
| 528 |
+
ax.text(
|
| 529 |
+
0.01,
|
| 530 |
+
0.04,
|
| 531 |
+
"saved timestep: {}".format(saved_timestep),
|
| 532 |
+
transform=ax.transAxes,
|
| 533 |
+
fontsize=8,
|
| 534 |
+
va="bottom",
|
| 535 |
+
ha="left",
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
fig.tight_layout()
|
| 539 |
+
fig.canvas.draw()
|
| 540 |
+
rgba = np.asarray(fig.canvas.buffer_rgba())
|
| 541 |
+
image = rgba[..., :3].copy()
|
| 542 |
+
plt.close(fig)
|
| 543 |
+
|
| 544 |
+
return image, bool(is_valid_start), len(valid_indices)
|
| 545 |
+
|
| 546 |
+
|
| 547 |
@lru_cache(maxsize=8192)
|
| 548 |
+
def get_cached_status_plot(repo_id, filename, traj_id, timestep, min_seq_len):
|
| 549 |
traj = load_traj(repo_id, filename, int(traj_id))
|
| 550 |
timestep = int(np.clip(int(timestep), 0, len(traj) - 1))
|
| 551 |
+
return _make_trajectory_status_plot(traj, timestep, int(min_seq_len))
|
|
|
|
|
|
|
| 552 |
|
| 553 |
|
| 554 |
def preload_current_trajectory(preset_name, custom_repo_id, custom_filename, traj_id, image_keys, chunk_len, display_scale, reverse_channels):
|
|
|
|
| 571 |
total = len(traj)
|
| 572 |
for t in range(total):
|
| 573 |
get_cached_gallery_items(repo_id, filename, traj_id, t, image_keys_tuple, float(display_scale), bool(reverse_channels))
|
| 574 |
+
get_cached_status_plot(repo_id, filename, traj_id, t, int(chunk_len))
|
| 575 |
|
| 576 |
status = "Preloaded trajectory {}".format(traj_id)
|
| 577 |
status += "\nFrames cached: {}".format(total)
|
|
|
|
| 772 |
)
|
| 773 |
warnings = list(warnings_tuple)
|
| 774 |
|
| 775 |
+
status_plot, is_valid_start, num_valid_starts = get_cached_status_plot(repo_id, filename, traj_id, timestep, chunk_len)
|
| 776 |
|
| 777 |
info_lines = [
|
| 778 |
"dataset: {} / {}".format(repo_id, filename),
|
|
|
|
| 785 |
"if_success: {}".format(int(bool(step.get("if_success", False)))),
|
| 786 |
"no_teacher_action: {}".format(int(bool(step.get("no_teacher_action", False)))),
|
| 787 |
"no_robot_action: {}".format(int(bool(step.get("no_robot_action", False)))),
|
| 788 |
+
"valid-window length: {}".format(chunk_len),
|
| 789 |
+
"valid_start: {}".format(int(bool(is_valid_start))),
|
| 790 |
+
"num_valid_starts: {}".format(num_valid_starts),
|
| 791 |
"",
|
| 792 |
"teacher_action: {}".format(_safe_array_str(step.get("teacher_action", []))),
|
| 793 |
"robot_action: {}".format(_safe_array_str(step.get("robot_action", []))),
|
|
|
|
| 798 |
info_lines.append("Image warnings:")
|
| 799 |
info_lines.extend(warnings)
|
| 800 |
|
| 801 |
+
return gallery_items, status_plot, "\n".join(info_lines)
|
| 802 |
|
| 803 |
|
| 804 |
def build_app():
|
|
|
|
| 818 |
with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
|
| 819 |
gr.Markdown(
|
| 820 |
"# HDF5 Trajectory Viewer\n\n"
|
| 821 |
+
"Standalone viewer for TrajectoryBuffer-style HDF5 datasets on Hugging Face.\n\n"
|
| 822 |
+
"The status plot matches the local labeling view: orange `no_teacher_action`, green valid-start markers, and a black timestep cursor."
|
| 823 |
)
|
| 824 |
|
| 825 |
if startup_warning:
|
|
|
|
| 842 |
|
| 843 |
with gr.Row():
|
| 844 |
image_keys = gr.CheckboxGroup(choices=first_keys, value=first_keys[:2], label="Image keys")
|
| 845 |
+
chunk_len = gr.Slider(minimum=1, maximum=64, value=DEFAULT_CHUNK_LEN, step=1, label="Valid-window length")
|
| 846 |
display_scale = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Image display scale")
|
| 847 |
reverse_channels = gr.Checkbox(value=False, label="Reverse channels BGR↔RGB")
|
| 848 |
|
|
|
|
| 856 |
trajectory_video = gr.Video(label="Trajectory video: smooth browser-side playback")
|
| 857 |
|
| 858 |
gallery = gr.Gallery(label="Observation images", columns=2, height="auto", object_fit="contain")
|
| 859 |
+
status_plot = gr.Image(label="Trajectory status: no_teacher_action and valid starts", type="numpy")
|
| 860 |
info = gr.Textbox(label="Frame info", lines=16)
|
| 861 |
|
| 862 |
with gr.Accordion("Debug: HDF5 tree", open=False):
|
|
|
|
| 874 |
).then(
|
| 875 |
fn=render_frame,
|
| 876 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
| 877 |
+
outputs=[gallery, status_plot, info],
|
| 878 |
)
|
| 879 |
|
| 880 |
custom_repo_id.submit(
|
|
|
|
| 895 |
).then(
|
| 896 |
fn=render_frame,
|
| 897 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
| 898 |
+
outputs=[gallery, status_plot, info],
|
| 899 |
)
|
| 900 |
|
| 901 |
timestep_slider.release(
|
| 902 |
fn=render_frame,
|
| 903 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
| 904 |
+
outputs=[gallery, status_plot, info],
|
| 905 |
)
|
| 906 |
|
| 907 |
for widget in [image_keys, chunk_len, display_scale, reverse_channels]:
|
| 908 |
widget.change(
|
| 909 |
fn=render_frame,
|
| 910 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
| 911 |
+
outputs=[gallery, status_plot, info],
|
| 912 |
)
|
| 913 |
|
| 914 |
render_btn.click(
|
| 915 |
fn=render_frame,
|
| 916 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
| 917 |
+
outputs=[gallery, status_plot, info],
|
| 918 |
)
|
| 919 |
|
| 920 |
preload_btn.click(
|
|
|
|
| 942 |
).then(
|
| 943 |
fn=render_frame,
|
| 944 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
| 945 |
+
outputs=[gallery, status_plot, info],
|
| 946 |
)
|
| 947 |
|
| 948 |
return demo
|