Zhaoting123 commited on
Commit
fdf3717
·
verified ·
1 Parent(s): 3cd8b50

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -6
app.py CHANGED
@@ -46,6 +46,7 @@ DATASET_PRESETS = {
46
  "20260409_205051_Diffusion_CLIC_intervention_Circular_square_image_abs_"
47
  "Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
48
  ),
 
49
  },
50
  "Robosuite Square 20260410": {
51
  "repo_id": "Zhaoting123/Robosuite_Square_image_abs_with_state",
@@ -53,10 +54,12 @@ DATASET_PRESETS = {
53
  "20260410_205606_Diffusion_CLIC_intervention_Circular_square_image_abs_"
54
  "Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
55
  ),
 
56
  },
57
  "InsertT Nov10 noisy": {
58
  "repo_id": "Zhaoting123/InsertT",
59
  "filename": "trajectory_buffer_Nov10_demo_noisy.hdf5",
 
60
  },
61
  }
62
 
@@ -87,6 +90,20 @@ def resolve_dataset(preset_name, custom_repo_id=None, custom_filename=None):
87
  return item["repo_id"], item["filename"]
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  @lru_cache(maxsize=8)
91
  def get_local_hdf5_path(repo_id, filename):
92
  return hf_hub_download(repo_id=repo_id, filename=filename, repo_type=REPO_TYPE)
@@ -871,14 +888,18 @@ def update_after_dataset_change(preset_name, custom_repo_id, custom_filename):
871
  repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
872
  n_traj = get_num_trajectories(repo_id, filename)
873
 
 
 
874
  if n_traj == 0:
875
  status = "Loaded `{}` / `{}`".format(repo_id, filename)
876
  status += "\nDetected trajectories: 0"
 
877
  return (
878
  gr.update(maximum=1, value=0),
879
  gr.update(maximum=1, value=0),
880
  gr.update(choices=[], value=[]),
881
  status,
 
882
  )
883
 
884
  keys = get_available_image_keys(repo_id, filename, 0)
@@ -886,12 +907,14 @@ def update_after_dataset_change(preset_name, custom_repo_id, custom_filename):
886
 
887
  status = "Loaded `{}` / `{}`".format(repo_id, filename)
888
  status += "\nDetected trajectories: {}".format(n_traj)
 
889
 
890
  return (
891
  gr.update(maximum=max(n_traj - 1, 1), value=0),
892
  gr.update(maximum=max(len(traj) - 1, 1), value=0),
893
  gr.update(choices=keys, value=keys[:2]),
894
  status,
 
895
  )
896
 
897
 
@@ -992,7 +1015,7 @@ def build_app():
992
  first_keys = []
993
  startup_warning = repr(exc)
994
 
995
- default_status = "Loaded default dataset\nDetected trajectories: {}".format(n_traj)
996
 
997
  with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
998
  gr.Markdown(
@@ -1023,7 +1046,7 @@ def build_app():
1023
  image_keys = gr.CheckboxGroup(choices=first_keys, value=first_keys[:2], label="Image keys")
1024
  chunk_len = gr.Slider(minimum=1, maximum=64, value=DEFAULT_CHUNK_LEN, step=1, label="Valid-window length")
1025
  display_scale = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Image display scale")
1026
- reverse_channels = gr.Checkbox(value=False, label="Reverse channels BGR↔RGB")
1027
 
1028
  with gr.Row():
1029
  render_btn = gr.Button("Render frame", variant="primary")
@@ -1062,7 +1085,7 @@ def build_app():
1062
  ).then(
1063
  fn=update_after_dataset_change,
1064
  inputs=[preset, custom_repo_id, custom_filename],
1065
- outputs=[traj_slider, timestep_slider, image_keys, dataset_status],
1066
  ).then(
1067
  fn=render_frame,
1068
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
@@ -1072,12 +1095,12 @@ def build_app():
1072
  custom_repo_id.submit(
1073
  fn=update_after_dataset_change,
1074
  inputs=[preset, custom_repo_id, custom_filename],
1075
- outputs=[traj_slider, timestep_slider, image_keys, dataset_status],
1076
  )
1077
  custom_filename.submit(
1078
  fn=update_after_dataset_change,
1079
  inputs=[preset, custom_repo_id, custom_filename],
1080
- outputs=[traj_slider, timestep_slider, image_keys, dataset_status],
1081
  )
1082
 
1083
  traj_slider.change(
@@ -1130,7 +1153,7 @@ def build_app():
1130
  demo.load(
1131
  fn=update_after_dataset_change,
1132
  inputs=[preset, custom_repo_id, custom_filename],
1133
- outputs=[traj_slider, timestep_slider, image_keys, dataset_status],
1134
  ).then(
1135
  fn=render_frame,
1136
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
 
46
  "20260409_205051_Diffusion_CLIC_intervention_Circular_square_image_abs_"
47
  "Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
48
  ),
49
+ "default_reverse_channels": False,
50
  },
51
  "Robosuite Square 20260410": {
52
  "repo_id": "Zhaoting123/Robosuite_Square_image_abs_with_state",
 
54
  "20260410_205606_Diffusion_CLIC_intervention_Circular_square_image_abs_"
55
  "Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
56
  ),
57
+ "default_reverse_channels": False,
58
  },
59
  "InsertT Nov10 noisy": {
60
  "repo_id": "Zhaoting123/InsertT",
61
  "filename": "trajectory_buffer_Nov10_demo_noisy.hdf5",
62
+ "default_reverse_channels": True,
63
  },
64
  }
65
 
 
90
  return item["repo_id"], item["filename"]
91
 
92
 
93
+ def get_default_reverse_channels(preset_name):
94
+ """Dataset-specific default for BGR<->RGB reversal.
95
+
96
+ Robosuite Square presets use normal RGB ordering.
97
+ InsertT / PushT-style preset requires reversal.
98
+ Custom datasets default to False so users can still override manually.
99
+ """
100
+ preset_name = preset_name or DEFAULT_PRESET
101
+ if preset_name == "Custom":
102
+ return False
103
+ item = DATASET_PRESETS.get(preset_name, DATASET_PRESETS[DEFAULT_PRESET])
104
+ return bool(item.get("default_reverse_channels", False))
105
+
106
+
107
  @lru_cache(maxsize=8)
108
  def get_local_hdf5_path(repo_id, filename):
109
  return hf_hub_download(repo_id=repo_id, filename=filename, repo_type=REPO_TYPE)
 
888
  repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
889
  n_traj = get_num_trajectories(repo_id, filename)
890
 
891
+ reverse_default = get_default_reverse_channels(preset_name)
892
+
893
  if n_traj == 0:
894
  status = "Loaded `{}` / `{}`".format(repo_id, filename)
895
  status += "\nDetected trajectories: 0"
896
+ status += "\nreverse_channels default: {}".format(int(reverse_default))
897
  return (
898
  gr.update(maximum=1, value=0),
899
  gr.update(maximum=1, value=0),
900
  gr.update(choices=[], value=[]),
901
  status,
902
+ gr.update(value=reverse_default),
903
  )
904
 
905
  keys = get_available_image_keys(repo_id, filename, 0)
 
907
 
908
  status = "Loaded `{}` / `{}`".format(repo_id, filename)
909
  status += "\nDetected trajectories: {}".format(n_traj)
910
+ status += "\nreverse_channels default: {}".format(int(reverse_default))
911
 
912
  return (
913
  gr.update(maximum=max(n_traj - 1, 1), value=0),
914
  gr.update(maximum=max(len(traj) - 1, 1), value=0),
915
  gr.update(choices=keys, value=keys[:2]),
916
  status,
917
+ gr.update(value=reverse_default),
918
  )
919
 
920
 
 
1015
  first_keys = []
1016
  startup_warning = repr(exc)
1017
 
1018
+ default_status = "Loaded default dataset\nDetected trajectories: {}\nreverse_channels default: {}".format(n_traj, int(get_default_reverse_channels(DEFAULT_PRESET)))
1019
 
1020
  with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
1021
  gr.Markdown(
 
1046
  image_keys = gr.CheckboxGroup(choices=first_keys, value=first_keys[:2], label="Image keys")
1047
  chunk_len = gr.Slider(minimum=1, maximum=64, value=DEFAULT_CHUNK_LEN, step=1, label="Valid-window length")
1048
  display_scale = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Image display scale")
1049
+ reverse_channels = gr.Checkbox(value=get_default_reverse_channels(DEFAULT_PRESET), label="Reverse channels BGR↔RGB")
1050
 
1051
  with gr.Row():
1052
  render_btn = gr.Button("Render frame", variant="primary")
 
1085
  ).then(
1086
  fn=update_after_dataset_change,
1087
  inputs=[preset, custom_repo_id, custom_filename],
1088
+ outputs=[traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels],
1089
  ).then(
1090
  fn=render_frame,
1091
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
 
1095
  custom_repo_id.submit(
1096
  fn=update_after_dataset_change,
1097
  inputs=[preset, custom_repo_id, custom_filename],
1098
+ outputs=[traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels],
1099
  )
1100
  custom_filename.submit(
1101
  fn=update_after_dataset_change,
1102
  inputs=[preset, custom_repo_id, custom_filename],
1103
+ outputs=[traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels],
1104
  )
1105
 
1106
  traj_slider.change(
 
1153
  demo.load(
1154
  fn=update_after_dataset_change,
1155
  inputs=[preset, custom_repo_id, custom_filename],
1156
+ outputs=[traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels],
1157
  ).then(
1158
  fn=render_frame,
1159
  inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],