Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -46,6 +46,7 @@ DATASET_PRESETS = {
|
|
| 46 |
"20260409_205051_Diffusion_CLIC_intervention_Circular_square_image_abs_"
|
| 47 |
"Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
|
| 48 |
),
|
|
|
|
| 49 |
},
|
| 50 |
"Robosuite Square 20260410": {
|
| 51 |
"repo_id": "Zhaoting123/Robosuite_Square_image_abs_with_state",
|
|
@@ -53,10 +54,12 @@ DATASET_PRESETS = {
|
|
| 53 |
"20260410_205606_Diffusion_CLIC_intervention_Circular_square_image_abs_"
|
| 54 |
"Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
|
| 55 |
),
|
|
|
|
| 56 |
},
|
| 57 |
"InsertT Nov10 noisy": {
|
| 58 |
"repo_id": "Zhaoting123/InsertT",
|
| 59 |
"filename": "trajectory_buffer_Nov10_demo_noisy.hdf5",
|
|
|
|
| 60 |
},
|
| 61 |
}
|
| 62 |
|
|
@@ -87,6 +90,20 @@ def resolve_dataset(preset_name, custom_repo_id=None, custom_filename=None):
|
|
| 87 |
return item["repo_id"], item["filename"]
|
| 88 |
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
@lru_cache(maxsize=8)
|
| 91 |
def get_local_hdf5_path(repo_id, filename):
|
| 92 |
return hf_hub_download(repo_id=repo_id, filename=filename, repo_type=REPO_TYPE)
|
|
@@ -871,14 +888,18 @@ def update_after_dataset_change(preset_name, custom_repo_id, custom_filename):
|
|
| 871 |
repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
|
| 872 |
n_traj = get_num_trajectories(repo_id, filename)
|
| 873 |
|
|
|
|
|
|
|
| 874 |
if n_traj == 0:
|
| 875 |
status = "Loaded `{}` / `{}`".format(repo_id, filename)
|
| 876 |
status += "\nDetected trajectories: 0"
|
|
|
|
| 877 |
return (
|
| 878 |
gr.update(maximum=1, value=0),
|
| 879 |
gr.update(maximum=1, value=0),
|
| 880 |
gr.update(choices=[], value=[]),
|
| 881 |
status,
|
|
|
|
| 882 |
)
|
| 883 |
|
| 884 |
keys = get_available_image_keys(repo_id, filename, 0)
|
|
@@ -886,12 +907,14 @@ def update_after_dataset_change(preset_name, custom_repo_id, custom_filename):
|
|
| 886 |
|
| 887 |
status = "Loaded `{}` / `{}`".format(repo_id, filename)
|
| 888 |
status += "\nDetected trajectories: {}".format(n_traj)
|
|
|
|
| 889 |
|
| 890 |
return (
|
| 891 |
gr.update(maximum=max(n_traj - 1, 1), value=0),
|
| 892 |
gr.update(maximum=max(len(traj) - 1, 1), value=0),
|
| 893 |
gr.update(choices=keys, value=keys[:2]),
|
| 894 |
status,
|
|
|
|
| 895 |
)
|
| 896 |
|
| 897 |
|
|
@@ -992,7 +1015,7 @@ def build_app():
|
|
| 992 |
first_keys = []
|
| 993 |
startup_warning = repr(exc)
|
| 994 |
|
| 995 |
-
default_status = "Loaded default dataset\nDetected trajectories: {}".format(n_traj)
|
| 996 |
|
| 997 |
with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
|
| 998 |
gr.Markdown(
|
|
@@ -1023,7 +1046,7 @@ def build_app():
|
|
| 1023 |
image_keys = gr.CheckboxGroup(choices=first_keys, value=first_keys[:2], label="Image keys")
|
| 1024 |
chunk_len = gr.Slider(minimum=1, maximum=64, value=DEFAULT_CHUNK_LEN, step=1, label="Valid-window length")
|
| 1025 |
display_scale = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Image display scale")
|
| 1026 |
-
reverse_channels = gr.Checkbox(value=
|
| 1027 |
|
| 1028 |
with gr.Row():
|
| 1029 |
render_btn = gr.Button("Render frame", variant="primary")
|
|
@@ -1062,7 +1085,7 @@ def build_app():
|
|
| 1062 |
).then(
|
| 1063 |
fn=update_after_dataset_change,
|
| 1064 |
inputs=[preset, custom_repo_id, custom_filename],
|
| 1065 |
-
outputs=[traj_slider, timestep_slider, image_keys, dataset_status],
|
| 1066 |
).then(
|
| 1067 |
fn=render_frame,
|
| 1068 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
|
@@ -1072,12 +1095,12 @@ def build_app():
|
|
| 1072 |
custom_repo_id.submit(
|
| 1073 |
fn=update_after_dataset_change,
|
| 1074 |
inputs=[preset, custom_repo_id, custom_filename],
|
| 1075 |
-
outputs=[traj_slider, timestep_slider, image_keys, dataset_status],
|
| 1076 |
)
|
| 1077 |
custom_filename.submit(
|
| 1078 |
fn=update_after_dataset_change,
|
| 1079 |
inputs=[preset, custom_repo_id, custom_filename],
|
| 1080 |
-
outputs=[traj_slider, timestep_slider, image_keys, dataset_status],
|
| 1081 |
)
|
| 1082 |
|
| 1083 |
traj_slider.change(
|
|
@@ -1130,7 +1153,7 @@ def build_app():
|
|
| 1130 |
demo.load(
|
| 1131 |
fn=update_after_dataset_change,
|
| 1132 |
inputs=[preset, custom_repo_id, custom_filename],
|
| 1133 |
-
outputs=[traj_slider, timestep_slider, image_keys, dataset_status],
|
| 1134 |
).then(
|
| 1135 |
fn=render_frame,
|
| 1136 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
|
|
|
| 46 |
"20260409_205051_Diffusion_CLIC_intervention_Circular_square_image_abs_"
|
| 47 |
"Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
|
| 48 |
),
|
| 49 |
+
"default_reverse_channels": False,
|
| 50 |
},
|
| 51 |
"Robosuite Square 20260410": {
|
| 52 |
"repo_id": "Zhaoting123/Robosuite_Square_image_abs_with_state",
|
|
|
|
| 54 |
"20260410_205606_Diffusion_CLIC_intervention_Circular_square_image_abs_"
|
| 55 |
"Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
|
| 56 |
),
|
| 57 |
+
"default_reverse_channels": False,
|
| 58 |
},
|
| 59 |
"InsertT Nov10 noisy": {
|
| 60 |
"repo_id": "Zhaoting123/InsertT",
|
| 61 |
"filename": "trajectory_buffer_Nov10_demo_noisy.hdf5",
|
| 62 |
+
"default_reverse_channels": True,
|
| 63 |
},
|
| 64 |
}
|
| 65 |
|
|
|
|
| 90 |
return item["repo_id"], item["filename"]
|
| 91 |
|
| 92 |
|
| 93 |
+
def get_default_reverse_channels(preset_name):
|
| 94 |
+
"""Dataset-specific default for BGR<->RGB reversal.
|
| 95 |
+
|
| 96 |
+
Robosuite Square presets use normal RGB ordering.
|
| 97 |
+
InsertT / PushT-style preset requires reversal.
|
| 98 |
+
Custom datasets default to False so users can still override manually.
|
| 99 |
+
"""
|
| 100 |
+
preset_name = preset_name or DEFAULT_PRESET
|
| 101 |
+
if preset_name == "Custom":
|
| 102 |
+
return False
|
| 103 |
+
item = DATASET_PRESETS.get(preset_name, DATASET_PRESETS[DEFAULT_PRESET])
|
| 104 |
+
return bool(item.get("default_reverse_channels", False))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
@lru_cache(maxsize=8)
|
| 108 |
def get_local_hdf5_path(repo_id, filename):
|
| 109 |
return hf_hub_download(repo_id=repo_id, filename=filename, repo_type=REPO_TYPE)
|
|
|
|
| 888 |
repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
|
| 889 |
n_traj = get_num_trajectories(repo_id, filename)
|
| 890 |
|
| 891 |
+
reverse_default = get_default_reverse_channels(preset_name)
|
| 892 |
+
|
| 893 |
if n_traj == 0:
|
| 894 |
status = "Loaded `{}` / `{}`".format(repo_id, filename)
|
| 895 |
status += "\nDetected trajectories: 0"
|
| 896 |
+
status += "\nreverse_channels default: {}".format(int(reverse_default))
|
| 897 |
return (
|
| 898 |
gr.update(maximum=1, value=0),
|
| 899 |
gr.update(maximum=1, value=0),
|
| 900 |
gr.update(choices=[], value=[]),
|
| 901 |
status,
|
| 902 |
+
gr.update(value=reverse_default),
|
| 903 |
)
|
| 904 |
|
| 905 |
keys = get_available_image_keys(repo_id, filename, 0)
|
|
|
|
| 907 |
|
| 908 |
status = "Loaded `{}` / `{}`".format(repo_id, filename)
|
| 909 |
status += "\nDetected trajectories: {}".format(n_traj)
|
| 910 |
+
status += "\nreverse_channels default: {}".format(int(reverse_default))
|
| 911 |
|
| 912 |
return (
|
| 913 |
gr.update(maximum=max(n_traj - 1, 1), value=0),
|
| 914 |
gr.update(maximum=max(len(traj) - 1, 1), value=0),
|
| 915 |
gr.update(choices=keys, value=keys[:2]),
|
| 916 |
status,
|
| 917 |
+
gr.update(value=reverse_default),
|
| 918 |
)
|
| 919 |
|
| 920 |
|
|
|
|
| 1015 |
first_keys = []
|
| 1016 |
startup_warning = repr(exc)
|
| 1017 |
|
| 1018 |
+
default_status = "Loaded default dataset\nDetected trajectories: {}\nreverse_channels default: {}".format(n_traj, int(get_default_reverse_channels(DEFAULT_PRESET)))
|
| 1019 |
|
| 1020 |
with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
|
| 1021 |
gr.Markdown(
|
|
|
|
| 1046 |
image_keys = gr.CheckboxGroup(choices=first_keys, value=first_keys[:2], label="Image keys")
|
| 1047 |
chunk_len = gr.Slider(minimum=1, maximum=64, value=DEFAULT_CHUNK_LEN, step=1, label="Valid-window length")
|
| 1048 |
display_scale = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Image display scale")
|
| 1049 |
+
reverse_channels = gr.Checkbox(value=get_default_reverse_channels(DEFAULT_PRESET), label="Reverse channels BGR↔RGB")
|
| 1050 |
|
| 1051 |
with gr.Row():
|
| 1052 |
render_btn = gr.Button("Render frame", variant="primary")
|
|
|
|
| 1085 |
).then(
|
| 1086 |
fn=update_after_dataset_change,
|
| 1087 |
inputs=[preset, custom_repo_id, custom_filename],
|
| 1088 |
+
outputs=[traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels],
|
| 1089 |
).then(
|
| 1090 |
fn=render_frame,
|
| 1091 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|
|
|
|
| 1095 |
custom_repo_id.submit(
|
| 1096 |
fn=update_after_dataset_change,
|
| 1097 |
inputs=[preset, custom_repo_id, custom_filename],
|
| 1098 |
+
outputs=[traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels],
|
| 1099 |
)
|
| 1100 |
custom_filename.submit(
|
| 1101 |
fn=update_after_dataset_change,
|
| 1102 |
inputs=[preset, custom_repo_id, custom_filename],
|
| 1103 |
+
outputs=[traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels],
|
| 1104 |
)
|
| 1105 |
|
| 1106 |
traj_slider.change(
|
|
|
|
| 1153 |
demo.load(
|
| 1154 |
fn=update_after_dataset_change,
|
| 1155 |
inputs=[preset, custom_repo_id, custom_filename],
|
| 1156 |
+
outputs=[traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels],
|
| 1157 |
).then(
|
| 1158 |
fn=render_frame,
|
| 1159 |
inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
|