File size: 44,884 Bytes
4dfb808
e80f0b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7db5fc
e80f0b4
4dfb808
 
f7db5fc
4dfb808
 
 
 
 
 
e80f0b4
 
 
 
 
 
f7db5fc
 
 
 
 
4dfb808
 
d85b559
ccaa72d
d85b559
 
 
 
 
fdf3717
d85b559
77e7303
d85b559
77e7303
 
 
 
 
 
 
 
9c8b386
77e7303
 
fdf3717
d85b559
 
 
ccaa72d
f7db5fc
4dfb808
a54b7c3
2fb0eb7
 
f7db5fc
 
 
 
 
7927dd6
 
f7db5fc
7927dd6
 
4dfb808
d85b559
 
 
 
 
 
7927dd6
d85b559
 
e80f0b4
d85b559
 
 
fdf3717
 
 
 
 
 
 
 
 
 
 
 
 
 
d85b559
 
e80f0b4
4dfb808
f7db5fc
 
7927dd6
 
 
 
f7db5fc
 
d85b559
 
 
f7db5fc
 
e80f0b4
7927dd6
f7db5fc
 
7927dd6
 
 
 
e80f0b4
7927dd6
 
e80f0b4
7927dd6
f7db5fc
 
d85b559
 
7927dd6
f7db5fc
 
7927dd6
d85b559
 
7927dd6
f7db5fc
 
 
 
 
 
e80f0b4
f7db5fc
7927dd6
 
f7db5fc
 
 
 
e80f0b4
f7db5fc
 
7927dd6
 
f7db5fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7927dd6
 
f7db5fc
 
 
 
 
7927dd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7db5fc
 
 
 
 
 
d85b559
 
 
f7db5fc
 
 
 
 
7927dd6
f7db5fc
 
e80f0b4
7927dd6
f7db5fc
7927dd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7db5fc
 
 
e80f0b4
f7db5fc
 
 
7927dd6
 
e80f0b4
 
 
 
 
 
7927dd6
 
 
 
 
 
e80f0b4
 
 
 
 
 
 
 
 
 
 
f7db5fc
 
4dfb808
 
 
3cd8b50
 
 
 
 
 
 
4dfb808
3cd8b50
 
 
 
 
 
 
 
 
 
4dfb808
3cd8b50
4dfb808
 
 
7927dd6
3cd8b50
7927dd6
 
4dfb808
3cd8b50
 
 
 
 
7927dd6
 
 
 
 
3cd8b50
 
7927dd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7db5fc
4dfb808
 
f7db5fc
4dfb808
 
7927dd6
 
 
f7db5fc
 
4dfb808
7927dd6
4dfb808
e80f0b4
4dfb808
7927dd6
 
 
4dfb808
 
7927dd6
 
 
4dfb808
 
 
7927dd6
f7db5fc
 
7927dd6
f7db5fc
 
7927dd6
4dfb808
 
7927dd6
4dfb808
 
 
7927dd6
4dfb808
 
 
 
f7db5fc
4dfb808
7927dd6
4dfb808
 
 
 
 
7927dd6
4dfb808
 
7927dd6
4dfb808
 
f7db5fc
7927dd6
4dfb808
 
 
 
 
7927dd6
 
f7db5fc
7927dd6
 
 
 
4dfb808
 
7927dd6
4dfb808
 
 
 
 
 
 
 
7927dd6
4dfb808
f7db5fc
7927dd6
 
4dfb808
 
 
 
 
7927dd6
 
 
 
 
 
 
 
4dfb808
 
 
 
 
 
 
 
 
7927dd6
4dfb808
7927dd6
4dfb808
 
ca76c1e
e80f0b4
ca76c1e
 
 
 
 
 
 
 
 
 
 
e80f0b4
ca76c1e
 
 
 
 
 
 
 
cdbfe0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8b9e9d
cdbfe0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8b9e9d
 
cdbfe0c
 
 
 
 
 
 
b8b9e9d
 
 
cdbfe0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca76c1e
cdbfe0c
ca76c1e
 
cdbfe0c
ca76c1e
 
e80f0b4
ca76c1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e80f0b4
cdbfe0c
ca76c1e
 
e80f0b4
 
ca76c1e
 
 
80f274e
 
 
 
 
b8b9e9d
 
 
 
80f274e
b8b9e9d
 
e80f0b4
80f274e
 
b8b9e9d
80f274e
 
 
 
b8b9e9d
 
 
80f274e
 
 
b8b9e9d
80f274e
 
 
b8b9e9d
80f274e
 
 
 
b8b9e9d
 
 
80f274e
 
 
 
 
 
 
 
 
b8b9e9d
 
 
 
 
 
 
 
 
 
 
80f274e
 
 
 
b8b9e9d
80f274e
 
 
 
 
 
e80f0b4
b8b9e9d
e80f0b4
 
 
 
 
 
 
 
 
 
16b8b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fb0eb7
16b8b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fb0eb7
 
16b8b5e
 
 
2fb0eb7
 
 
16b8b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a54b7c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fb0eb7
e80f0b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a54b7c3
 
 
 
 
e80f0b4
 
 
 
a54b7c3
 
 
 
 
 
 
 
 
 
 
e80f0b4
 
a54b7c3
 
e80f0b4
a54b7c3
 
 
e80f0b4
 
 
 
 
 
 
 
 
 
a54b7c3
80f274e
e80f0b4
a54b7c3
e80f0b4
 
 
a54b7c3
 
e80f0b4
a54b7c3
 
80f274e
a54b7c3
e80f0b4
 
d85b559
 
7927dd6
 
 
 
d85b559
4dfb808
 
f7db5fc
4dfb808
7927dd6
f7db5fc
4dfb808
7927dd6
 
4dfb808
 
f7db5fc
7927dd6
 
4dfb808
 
 
7927dd6
 
 
 
 
 
 
 
 
fdf3717
 
7927dd6
 
e80f0b4
fdf3717
7927dd6
 
 
 
 
fdf3717
7927dd6
 
 
 
 
 
e80f0b4
fdf3717
7927dd6
 
 
 
 
 
fdf3717
7927dd6
 
 
d85b559
 
 
7927dd6
 
 
 
d85b559
7927dd6
 
4dfb808
7927dd6
 
4dfb808
 
 
e80f0b4
d85b559
 
f7db5fc
7927dd6
 
 
 
 
4dfb808
7927dd6
4dfb808
f7db5fc
 
 
4dfb808
 
 
 
 
 
7927dd6
ca76c1e
e80f0b4
ca76c1e
e80f0b4
ca76c1e
 
 
cdbfe0c
4dfb808
3cd8b50
 
 
 
 
 
 
 
4dfb808
7927dd6
 
 
 
 
 
 
 
 
 
cdbfe0c
 
 
4dfb808
7927dd6
 
3cd8b50
 
 
4dfb808
 
7927dd6
 
 
 
4dfb808
cdbfe0c
4dfb808
 
 
7927dd6
 
4dfb808
d85b559
7927dd6
 
4dfb808
7927dd6
4dfb808
7927dd6
 
fdf3717
4dfb808
f7db5fc
4dfb808
7927dd6
cdbfe0c
 
4dfb808
 
7927dd6
 
4dfb808
d85b559
 
 
 
 
 
e80f0b4
 
d85b559
e80f0b4
7927dd6
4dfb808
e80f0b4
 
4dfb808
 
e80f0b4
cdbfe0c
a54b7c3
fdf3717
4dfb808
ca76c1e
 
 
e80f0b4
 
2fb0eb7
ca76c1e
e80f0b4
4dfb808
16b8b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8b9e9d
3c79c68
 
 
 
 
d85b559
 
 
 
 
 
 
 
fdf3717
d85b559
 
e80f0b4
cdbfe0c
d85b559
 
 
 
 
fdf3717
d85b559
 
 
 
fdf3717
d85b559
4dfb808
 
 
d85b559
4dfb808
 
 
e80f0b4
cdbfe0c
d85b559
 
 
 
e80f0b4
cdbfe0c
4dfb808
 
a54b7c3
4dfb808
 
e80f0b4
cdbfe0c
4dfb808
 
 
 
e80f0b4
cdbfe0c
4dfb808
 
ca76c1e
 
e80f0b4
ca76c1e
 
 
e80f0b4
 
a54b7c3
e80f0b4
 
 
7927dd6
 
 
 
 
 
4dfb808
7927dd6
 
fdf3717
4dfb808
 
e80f0b4
cdbfe0c
4dfb808
 
 
 
 
 
cabf23d
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244

# Standalone Hugging Face Space viewer for TrajectoryBuffer-style HDF5 files.
#
# requirements.txt:
# gradio
# huggingface_hub
# h5py
# numpy
# pillow
# matplotlib
# imageio
# imageio-ffmpeg
#
# Optional:
# opencv-python-headless

import os
import re
import tempfile
from functools import lru_cache

import gradio as gr
import h5py
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw

try:
    import imageio.v2 as imageio
except Exception:
    imageio = None

try:
    import cv2
except Exception:
    cv2 = None


# Built-in dataset choices offered by the viewer. Each entry names a Hugging Face
# dataset repo ("repo_id"), the HDF5 file inside it ("filename"), and whether the
# image channels should be reversed (BGR<->RGB) by default when displaying it
# ("default_reverse_channels", read by get_default_reverse_channels).
DATASET_PRESETS = {
    "Robosuite Square Correction": {
        "repo_id": "Zhaoting123/Robosuite_Square_image_abs_with_state",
        "filename": (
            "20260410_205606_Diffusion_CLIC_intervention_Circular_square_image_abs_"
            "Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
        ),
        "default_reverse_channels": False,
    },
    "InsertT Demonstration": {
        "repo_id": "Zhaoting123/InsertT",
        "filename": "trajectory_buffer_Nov10_demo.hdf5",
        "default_reverse_channels": True,
    },
    "InsertT Correction": {
        "repo_id": "Zhaoting123/InsertT",
        "filename": "trajectory_buffer_Nov11_intervention.hdf5",
        "default_reverse_channels": True,
    },
    "RoundTable Correction": {
        "repo_id": "Zhaoting123/Furniture_Bench_Round_Table_Assembly",
        "filename": "trajectory_buffer_0_Nov24_intervention_relabeled.hdf5",
        "default_reverse_channels": True,
    },
}

# Preset used when the preset name is empty or not found in DATASET_PRESETS.
DEFAULT_PRESET = "Robosuite Square Correction"
# repo_type passed to hf_hub_download.
REPO_TYPE = "dataset"
# NOTE(review): the following defaults are not used in this section; presumably
# they seed UI controls (chunk length, display scale) and the status plot drawn
# into exported videos — confirm against the UI code.
DEFAULT_CHUNK_LEN = 16
DEFAULT_DISPLAY_SCALE = 1
VIDEO_STATUS_FIGSIZE = (6.0, 1.8)
VIDEO_STATUS_DPI = 120
# Observation keys that look like camera images; presumably tried first when
# choosing which images to show (usage is outside this section — confirm).
PREFERRED_IMAGE_KEYS = [
    "image1",
    "image2",
    "agentview_image",
    "robot0_eye_in_hand_image",
    "front_image",
    "wrist_image",
]
# Substrings that mark an observation key as an image; matched against the
# lower-cased key in _looks_like_image_array.
IMAGE_KEY_HINTS = ["rgb", "image", "img", "camera", "cam"]


def resolve_dataset(preset_name, custom_repo_id=None, custom_filename=None):
    """Map a preset name (or Custom inputs) to a ``(repo_id, filename)`` pair.

    For ``"Custom"`` the stripped custom_repo_id / custom_filename are returned,
    and ValueError is raised when either one is blank. Any other name, including
    an empty or unknown one, resolves through DATASET_PRESETS with
    DEFAULT_PRESET as the fallback.
    """
    name = preset_name if preset_name else DEFAULT_PRESET
    if name != "Custom":
        if name in DATASET_PRESETS:
            preset = DATASET_PRESETS[name]
        else:
            preset = DATASET_PRESETS[DEFAULT_PRESET]
        return preset["repo_id"], preset["filename"]

    repo_id = str(custom_repo_id or "").strip()
    filename = str(custom_filename or "").strip()
    if repo_id and filename:
        return repo_id, filename
    raise ValueError("For Custom mode, provide both repo_id and HDF5 filename/path.")


def get_default_reverse_channels(preset_name):
    """Return whether a preset's images default to BGR<->RGB channel reversal.

    The value is the preset's ``default_reverse_channels`` field. The Robosuite
    Square preset sets it False; the InsertT and RoundTable presets set it True.
    ``"Custom"`` always yields False so the user can toggle it manually, and an
    empty or unknown name falls back to DEFAULT_PRESET.
    """
    name = preset_name if preset_name else DEFAULT_PRESET
    if name == "Custom":
        return False
    if name in DATASET_PRESETS:
        preset = DATASET_PRESETS[name]
    else:
        preset = DATASET_PRESETS[DEFAULT_PRESET]
    return bool(preset.get("default_reverse_channels", False))


@lru_cache(maxsize=8)
def get_local_hdf5_path(repo_id, filename):
    """Fetch an HDF5 file from the Hub and return the local path reported by
    hf_hub_download. Results are memoised per (repo_id, filename)."""
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type=REPO_TYPE,
    )
    return local_path


def _natural_sort_key(name):
    match = re.search(r"([0-9]+)$", str(name))
    if match:
        return 0, int(match.group(1))
    return 1, str(name)


@lru_cache(maxsize=8)
def get_trajectory_keys(repo_id, filename):
    """Return the HDF5 group paths holding trajectories, naturally sorted.

    Lookup order:
      1. root-level groups whose name starts with ``episode_``;
      2. otherwise every group inside a root ``data`` group, as ``data/<key>``;
      3. otherwise every root-level group.
    """
    path = get_local_hdf5_path(repo_id, filename)
    with h5py.File(path, "r") as h5:
        root_groups = [name for name in h5.keys() if isinstance(h5[name], h5py.Group)]
        episodes = [name for name in root_groups if str(name).startswith("episode_")]
        if episodes:
            return tuple(sorted(episodes, key=_natural_sort_key))

        data_group = h5["data"] if "data" in h5 else None
        if isinstance(data_group, h5py.Group):
            children = [name for name in data_group.keys() if isinstance(data_group[name], h5py.Group)]
            return tuple("data/" + name for name in sorted(children, key=_natural_sort_key))

        return tuple(sorted(root_groups, key=_natural_sort_key))


@lru_cache(maxsize=8)
def get_num_trajectories(repo_id, filename):
    """Return how many trajectory groups get_trajectory_keys finds in the file."""
    keys = get_trajectory_keys(repo_id, filename)
    return len(keys)


def inspect_hdf5_tree(preset_name, custom_repo_id, custom_filename, max_lines=180):
    """Return a text listing of the groups and datasets in the selected HDF5 file.

    Dataset lines show path, shape and dtype; group lines show the path. At most
    ``max_lines`` entries are listed. A trailing ``"..."`` line is added only
    when the file really contains more objects than were listed.

    Raises:
        ValueError: from resolve_dataset when Custom mode lacks repo/filename.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    path = get_local_hdf5_path(repo_id, filename)
    limit = max(0, int(max_lines))

    lines = []
    truncated = False

    def visitor(name, obj):
        nonlocal truncated
        if len(lines) >= limit:
            # A non-None return value stops h5py's traversal, so a huge file is
            # not walked past the listing limit. Reaching this point means at
            # least one more object exists, so the listing is truncated.
            truncated = True
            return True
        if isinstance(obj, h5py.Dataset):
            lines.append("DATASET {} shape={} dtype={}".format(name, obj.shape, obj.dtype))
        elif isinstance(obj, h5py.Group):
            lines.append("GROUP   {}".format(name))
        return None

    with h5py.File(path, "r") as f:
        f.visititems(visitor)

    if truncated:
        lines.append("...")
    return "\n".join(lines) if lines else "No HDF5 contents found."


def _read_dataset_value(dataset):
    value = dataset[()]
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return value


def _read_group_recursive(group):
    """Load an HDF5 group into nested dicts.

    Sub-groups become nested dicts and datasets become the values returned by
    _read_dataset_value. Members that are neither (e.g. committed datatypes)
    are skipped.
    """
    result = {}
    for name, member in group.items():
        if isinstance(member, h5py.Group):
            result[name] = _read_group_recursive(member)
        elif isinstance(member, h5py.Dataset):
            result[name] = _read_dataset_value(member)
    return result


def _find_first_key(mapping, candidate_keys):
    for key in candidate_keys:
        if key in mapping:
            return key
    return None


def _infer_time_length(data):
    for key in ["timesteps", "dones", "robot_actions", "teacher_actions", "actions"]:
        if key in data:
            arr = np.asarray(data[key])
            if arr.ndim >= 1:
                return int(arr.shape[0])

    obs_group = None
    if isinstance(data.get("observation"), dict):
        obs_group = data["observation"]
    elif isinstance(data.get("obs"), dict):
        obs_group = data["obs"]

    if obs_group:
        lengths = []
        for value in obs_group.values():
            arr = np.asarray(value)
            if arr.ndim >= 1:
                lengths.append(int(arr.shape[0]))
        if lengths:
            values, counts = np.unique(lengths, return_counts=True)
            return int(values[np.argmax(counts)])
    return 1


def _slice_time(value, t, T):
    arr = np.asarray(value)
    if arr.ndim >= 1 and arr.shape[0] == T:
        return arr[t]
    return arr


@lru_cache(maxsize=64)
def load_traj(repo_id, filename, traj_id):
    """Load one trajectory as a list of per-timestep dicts.

    ``traj_id`` is clamped into the index range of get_trajectory_keys. The
    whole trajectory group is read into memory, T comes from
    _infer_time_length, and each list entry holds:

      obs               observation dict sliced at that step
      robot_action      robot action (falls back to ``actions``, else zeros(1))
      teacher_action    teacher action (same fallback)
      done, no_robot_action, no_teacher_action, if_success   bools
      timestep          stored timestep if present, else the step index
      episode_id        HDF5 group path of the trajectory

    Arrays whose leading dimension is not T are passed through unsliced.
    Returns an empty list when the file has no trajectories. The returned list
    is shared through the cache, so callers should treat it as read-only.
    """
    group_paths = get_trajectory_keys(repo_id, filename)
    if not group_paths:
        return []

    index = int(np.clip(int(traj_id), 0, len(group_paths) - 1))
    group_path = group_paths[index]
    h5_path = get_local_hdf5_path(repo_id, filename)
    with h5py.File(h5_path, "r") as h5:
        raw = _read_group_recursive(h5[group_path])

    T = _infer_time_length(raw)

    observations = {}
    for obs_field in ("observation", "obs"):
        if isinstance(raw.get(obs_field), dict):
            observations = raw[obs_field]
            break

    def lookup(*candidates):
        return _find_first_key(raw, candidates)

    action_key = lookup("actions", "action")
    teacher_key = lookup("teacher_actions", "teacher_action")
    robot_key = lookup("robot_actions", "robot_action")
    no_teacher_key = lookup("no_teacher_actions", "no_teacher_action")
    no_robot_key = lookup("no_robot_actions", "no_robot_action")
    done_key = lookup("dones", "done")
    timestep_key = lookup("timesteps", "timestep")
    success_key = lookup("if_success", "success", "successes")

    def at_step(key, t, fallback):
        # Value of raw[key] at step t, or ``fallback`` when the key is absent.
        return fallback if key is None else _slice_time(raw[key], t, T)

    def first_as_bool(value):
        return bool(np.asarray(value).reshape(-1)[0])

    steps = []
    for t in range(T):
        obs_t = {name: _slice_time(val, t, T) for name, val in observations.items()}

        fallback_action = at_step(action_key, t, np.zeros(1, dtype=np.float32))
        teacher_action = at_step(teacher_key, t, fallback_action)
        robot_action = at_step(robot_key, t, fallback_action)
        no_teacher = at_step(no_teacher_key, t, False)
        no_robot = at_step(no_robot_key, t, False)
        done = at_step(done_key, t, False)
        if_success = at_step(success_key, t, False)

        if timestep_key is None:
            timestep = t
        else:
            timestep = int(np.asarray(at_step(timestep_key, t, t)).reshape(-1)[0])

        steps.append({
            "obs": obs_t,
            "robot_action": np.asarray(robot_action),
            "teacher_action": np.asarray(teacher_action),
            "done": first_as_bool(done),
            "timestep": timestep,
            "no_robot_action": first_as_bool(no_robot),
            "no_teacher_action": first_as_bool(no_teacher),
            "episode_id": group_path,
            "if_success": first_as_bool(if_success),
        })

    return steps


def _extract_latest_obs_value(value):
    """Return the latest stacked observation only when there is a clear stack axis.

    Important:
    - [obs_T, C, H, W] or [obs_T, H, W, C] should become the latest frame.
    - [C, H, W] must NOT be sliced, otherwise an RGB image becomes one
      grayscale channel.
    """
    arr = np.asarray(value)

    # Stacked image observations, e.g. [obs_T, C, H, W] or [obs_T, H, W, C].
    if arr.ndim == 4 and arr.shape[0] in (1, 2, 3, 4):
        channel_first = arr.shape[1] in (1, 3, 4)
        channel_last = arr.shape[-1] in (1, 3, 4)
        if channel_first or channel_last:
            return arr[-1]

    # Stacked vector observations, e.g. [obs_T, D]. Keep this for non-image obs.
    if arr.ndim == 2 and arr.shape[0] in (1, 2):
        return arr[-1]

    return arr


def _looks_like_image_array(key, value):
    """Heuristically decide whether an observation entry is an image.

    True when the lower-cased key contains any IMAGE_KEY_HINTS substring, or
    when the array shape looks like an image. Before the shape check, a clear
    frame-stack axis is dropped. Image-like shapes are: any 2-D array, a 3-D
    array with 1/3/4 channels first or last, or a 4-D array with 1/3/4
    channels on axis 1 or the last axis.
    """
    arr = np.asarray(value)
    name = str(key).lower()
    has_key_hint = any(hint in name for hint in IMAGE_KEY_HINTS)
    channels = (1, 3, 4)

    def channels_at(a, axis):
        return a.shape[axis] in channels or a.shape[-1] in channels

    if arr.ndim == 4 and arr.shape[0] in (1, 2, 3, 4) and channels_at(arr, 1):
        arr = arr[-1]

    if arr.ndim == 2:
        image_shaped = True
    elif arr.ndim == 3:
        image_shaped = channels_at(arr, 0)
    elif arr.ndim == 4:
        image_shaped = channels_at(arr, 1)
    else:
        image_shaped = False

    return has_key_hint or image_shaped


def _float_img_to_uint8(img):
    arr = img.astype(np.float32)
    arr_min = float(np.nanmin(arr))
    arr_max = float(np.nanmax(arr))

    if arr_min >= -1.01 and arr_max <= 1.01:
        if arr_min < 0.0:
            arr = (arr + 1.0) * 0.5
        arr = np.clip(arr, 0.0, 1.0) * 255.0
    elif arr_max <= 255.0:
        arr = np.clip(arr, 0.0, 255.0)
    else:
        arr = 255.0 * (arr - arr_min) / max(arr_max - arr_min, 1e-8)

    return np.round(arr).astype(np.uint8)


def _extract_display_image(value, reverse_channels=False):
    """Turn an observation value into an H x W x 3 uint8 image for display.

    Processing steps:
    1. Drop a frame-stack axis via _extract_latest_obs_value.
    2. Expand 2-D grayscale to 3 channels.
    3. Move a leading 1/3/4-channel axis to the end.
    4. Repeat a single channel to 3, or drop the 4th (alpha) channel.
    5. Copy uint8 data, or convert other dtypes with _float_img_to_uint8.
    6. Optionally reverse the channel order (BGR<->RGB).

    Raises:
        ValueError: if the array is not 3-D after these steps.
    """
    frame = np.asarray(_extract_latest_obs_value(value))

    if frame.ndim == 2:
        frame = np.repeat(frame[:, :, np.newaxis], 3, axis=2)
    elif frame.ndim == 3 and frame.shape[0] in (1, 3, 4):
        frame = frame.transpose(1, 2, 0)

    if frame.ndim != 3:
        raise ValueError("Unsupported image shape: {}".format(frame.shape))

    channel_count = frame.shape[-1]
    if channel_count == 1:
        frame = np.repeat(frame, 3, axis=-1)
    elif channel_count == 4:
        frame = frame[..., :3]

    if frame.dtype == np.uint8:
        out = frame.copy()
    else:
        out = _float_img_to_uint8(frame)

    if reverse_channels and out.shape[-1] == 3:
        out = out[..., ::-1]
    return out


def _resize_image_for_display(img, display_scale):
    scale = float(display_scale)
    if scale == 1.0:
        return img

    h, w = img.shape[:2]
    new_size = (max(1, int(round(w * scale))), max(1, int(round(h * scale))))

    if cv2 is not None:
        return cv2.resize(img, new_size, interpolation=cv2.INTER_NEAREST)

    pil_img = Image.fromarray(img)
    return np.asarray(pil_img.resize(new_size, resample=Image.Resampling.NEAREST))


def _extract_mixed_action_chunk(traj, start_idx, chunk_length):
    chunk = []
    sources = []
    end_idx = min(len(traj), int(start_idx) + int(chunk_length))

    for idx in range(int(start_idx), end_idx):
        step = traj[idx]
        use_teacher = not bool(step.get("no_teacher_action", False))
        action = step["teacher_action"] if use_teacher else step["robot_action"]
        chunk.append(np.asarray(action, dtype=np.float32).reshape(-1))
        sources.append("T" if use_teacher else "R")

    if not chunk:
        return None, ""
    return np.stack(chunk, axis=0), "".join(sources)


def _extract_robot_action_chunk(traj, start_idx, chunk_length):
    chunk = []
    end_idx = min(len(traj), int(start_idx) + int(chunk_length))

    for idx in range(int(start_idx), end_idx):
        step = traj[idx]
        chunk.append(np.asarray(step["robot_action"], dtype=np.float32).reshape(-1))

    if not chunk:
        return None
    return np.stack(chunk, axis=0)


def _safe_array_str(value, precision=3, max_items=24):
    arr = np.asarray(value).reshape(-1)
    shown = arr[:max_items]
    text = np.array2string(shown, precision=precision, separator=", ")
    if arr.size > max_items:
        text += " ... +{} more".format(arr.size - max_items)
    return text


def _make_action_chunk_plot(mixed_chunk, robot_chunk):
    """Render an action chunk as an RGB line-plot image.

    Plots the first (at most 10) action dimensions of ``mixed_chunk`` as solid
    lines and the same dimensions of ``robot_chunk`` (when given) as dashed
    lines, against the chunk step index. A 1-D chunk is treated as a single
    action dimension.

    Returns:
        An (H, W, 3) uint8 array of the rendered figure, or None when
        ``mixed_chunk`` is None.
    """
    if mixed_chunk is None:
        return None

    mixed_chunk = np.asarray(mixed_chunk, dtype=np.float32)
    if mixed_chunk.ndim == 1:
        mixed_chunk = mixed_chunk[:, None]
    max_dims = min(mixed_chunk.shape[1], 10)

    fig, ax = plt.subplots(figsize=(7, 3.2), dpi=140)
    try:
        x = np.arange(mixed_chunk.shape[0])
        for dim in range(max_dims):
            ax.plot(x, mixed_chunk[:, dim], label="mixed[{}]".format(dim))

        if robot_chunk is not None:
            robot_chunk = np.asarray(robot_chunk, dtype=np.float32)
            if robot_chunk.ndim == 1:
                robot_chunk = robot_chunk[:, None]
            # Use the robot chunk's own step axis so a length mismatch with the
            # mixed chunk cannot raise an x/y size error.
            robot_x = np.arange(robot_chunk.shape[0])
            for dim in range(min(robot_chunk.shape[1], max_dims)):
                ax.plot(
                    robot_x,
                    robot_chunk[:, dim],
                    linestyle="--",
                    alpha=0.55,
                    label="robot[{}]".format(dim),
                )

        ax.set_title("Action chunk")
        ax.set_xlabel("chunk step")
        ax.set_ylabel("action value")
        ax.grid(True, alpha=0.3)
        ax.legend(loc="upper right", fontsize=7, ncol=2)
        fig.tight_layout()
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return rgba[..., :3].copy()
    finally:
        # pyplot keeps every figure alive until it is closed; closing in
        # ``finally`` stops a failed render from leaking a figure per call.
        plt.close(fig)


@lru_cache(maxsize=8192)
def get_cached_gallery_items(repo_id, filename, traj_id, timestep, image_keys_tuple, display_scale, reverse_channels):
    """Build the gallery entries for the selected image keys at one timestep.

    Returns ``(items, warnings)``:
    - items is a list of ``(image, key)`` pairs ready for display;
    - warnings is a tuple of messages for keys that are missing or failed to
      convert.

    The timestep is clamped into range. An empty trajectory yields no items and
    one warning instead of raising IndexError. Results are memoised, and the
    returned list is shared through the cache, so callers should not mutate it.
    """
    traj = load_traj(repo_id, filename, int(traj_id))
    if not traj:
        # np.clip(t, 0, -1) would give -1 here and traj[-1] would raise.
        return [], ("Trajectory has no timesteps.",)

    timestep = int(np.clip(int(timestep), 0, len(traj) - 1))
    obs = traj[timestep].get("obs", {})

    gallery_items = []
    warnings = []
    for key in image_keys_tuple:
        if key not in obs:
            warnings.append("Missing image key: {}".format(key))
            continue
        try:
            img = _extract_display_image(obs[key], reverse_channels=bool(reverse_channels))
            img = _resize_image_for_display(img, float(display_scale))
        except Exception as exc:
            # Best effort: one unconvertible key must not hide the other images.
            warnings.append("{}: {}".format(key, exc))
        else:
            gallery_items.append((img, key))

    return gallery_items, tuple(warnings)


def _compute_valid_start_indices(traj, min_seq_len):
    """Match the original local script's valid-start heuristic.

    A timestep is valid when the following min_seq_len steps all have
    no_teacher_action == False.
    """
    total_steps = len(traj)
    min_seq_len = int(max(1, min_seq_len))
    no_teacher = np.asarray(
        [int(bool(step.get("no_teacher_action", False))) for step in traj],
        dtype=np.int32,
    )

    valid_indices = []
    max_start = total_steps - min_seq_len + 1
    for t in range(max(0, max_start)):
        if int(np.sum(no_teacher[t:t + min_seq_len])) == 0:
            valid_indices.append(t)

    return no_teacher, valid_indices


def _make_trajectory_status_plot(traj, timestep, min_seq_len):
    """Render the high-level trajectory status figure as an RGB image.

    This is the same figure as the local matplotlib tool. It shows:
      - an orange step plot of the per-step no_teacher_action flag;
      - green triangles at valid start indices (_compute_valid_start_indices);
      - a black vertical cursor at the current timestep;
      - a "saved timestep" note when the stored timestep differs from the index.

    ``timestep`` is clamped into range.

    Returns:
        (image, is_valid_start, num_valid_starts), where image is an (H, W, 3)
        uint8 array, or (None, False, 0) for an empty trajectory.
    """
    total_steps = len(traj)
    if total_steps == 0:
        return None, False, 0

    timestep = int(np.clip(int(timestep), 0, total_steps - 1))
    no_teacher, valid_indices = _compute_valid_start_indices(traj, min_seq_len)
    is_valid_start = timestep in valid_indices
    # Only the current step's stored timestep is displayed. Read just that one
    # instead of converting the whole trajectory (into int32) on every call.
    saved_timestep = int(np.asarray(traj[timestep].get("timestep", timestep)).reshape(-1)[0])

    fig, ax = plt.subplots(figsize=(10.5, 2.8), dpi=170)
    try:
        ax.step(
            np.arange(total_steps),
            no_teacher,
            where="post",
            label="no_teacher_action",
            color="orange",
        )

        if valid_indices:
            ax.scatter(
                valid_indices,
                [-0.15] * len(valid_indices),
                color="green",
                marker="^",
                s=18,
                label="Valid Start (len >= {})".format(int(min_seq_len)),
            )

        ax.axvline(timestep, color="black", linestyle="-", alpha=0.85, linewidth=1.5)
        ax.set_xlim(0, max(total_steps - 1, 1))
        ax.set_ylim(-0.38, 1.1)
        ax.set_ylabel("Flag", fontsize=10)
        ax.set_xlabel("Timestep index", fontsize=10)
        ax.set_yticks([0, 1])
        ax.set_yticklabels(["False", "True"])
        ax.grid(True, axis="x", alpha=0.2)

        title = "no_teacher_action | step {} / {}".format(timestep, total_steps - 1)
        if is_valid_start:
            title += " | VALID START"
        ax.set_title(title, fontsize=11)
        ax.tick_params(axis="both", labelsize=9)
        ax.legend(loc="upper right", fontsize=9)

        # Annotate the stored timestep when it differs from the list index.
        if saved_timestep != timestep:
            ax.text(
                0.01,
                0.04,
                "saved timestep: {}".format(saved_timestep),
                transform=ax.transAxes,
                fontsize=8,
                va="bottom",
                ha="left",
            )

        fig.tight_layout()
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        image = rgba[..., :3].copy()
    finally:
        # pyplot keeps every figure alive until closed; closing in ``finally``
        # stops a failed render from leaking a figure per call.
        plt.close(fig)

    return image, bool(is_valid_start), len(valid_indices)


@lru_cache(maxsize=8192)
def get_cached_status_plot(repo_id, filename, traj_id, timestep, min_seq_len):
    """Memoised wrapper around _make_trajectory_status_plot for one timestep.

    Returns the same ``(image, is_valid_start, num_valid_starts)`` tuple.
    """
    trajectory = load_traj(repo_id, filename, int(traj_id))
    clamped_step = int(np.clip(int(timestep), 0, len(trajectory) - 1))
    return _make_trajectory_status_plot(trajectory, clamped_step, int(min_seq_len))


def preload_current_trajectory(preset_name, custom_repo_id, custom_filename, traj_id, image_keys, chunk_len, display_scale, reverse_channels):
    """Warm the gallery and status-plot caches for every frame of one trajectory.

    ``traj_id`` is clamped into range. ``chunk_len`` is passed to
    get_cached_status_plot as its ``min_seq_len``. ``image_keys`` may be None,
    a single key string, or an iterable of keys. Returns a short
    human-readable status message.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return "No trajectories found."

    traj_id = int(np.clip(int(traj_id), 0, n_traj - 1))
    traj = load_traj(repo_id, filename, traj_id)
    if not traj:
        return "Trajectory could not be loaded."

    if image_keys is None:
        keys = ()
    elif isinstance(image_keys, str):
        keys = (image_keys,)
    else:
        keys = tuple(image_keys)

    frame_count = len(traj)
    for step in range(frame_count):
        get_cached_gallery_items(repo_id, filename, traj_id, step, keys, float(display_scale), bool(reverse_channels))
        get_cached_status_plot(repo_id, filename, traj_id, step, int(chunk_len))

    summary = [
        "Preloaded trajectory {}".format(traj_id),
        "Frames cached: {}".format(frame_count),
        "Image keys: {}".format(", ".join(keys) if keys else "none"),
    ]
    return "\n".join(summary)


def _compose_video_frame(gallery_items, frame_label, status_plot=None):
    """Assemble one RGB video frame and return it as a NumPy array.

    Layout, top to bottom:
      * a header line with ``frame_label`` followed by the selected observation
        images side by side, each with its own caption strip;
      * optionally, ``status_plot`` (the trajectory-status image with cursor).

    The status plot keeps its native resolution. It is deliberately not
    shrunk to the image-row width, because its tick labels and legend become
    unreadable in the MP4. Whichever of the two parts is narrower is centered
    on a padded strip. The final canvas is padded to multiples of 16 pixels.

    Args:
        gallery_items: iterable of ``(image_array, caption)`` pairs; empty means
            "no image keys selected".
        frame_label: text drawn above the image row.
        status_plot: optional image array placed below the image row.
    """
    text_y = 3
    caption_h = 16     # caption strip above each observation image
    panel_gap = 8      # horizontal gap between observation panels
    header_h = 18      # strip above the panel row holding frame_label
    stack_gap = 8      # vertical gap between image row and status plot

    if gallery_items:
        panels = []
        for raw_img, caption in gallery_items:
            src = Image.fromarray(np.asarray(raw_img, dtype=np.uint8)).convert("RGB")
            panel = Image.new("RGB", (src.width, src.height + caption_h), color=(0, 0, 0))
            panel.paste(src, (0, caption_h))
            ImageDraw.Draw(panel).text((4, text_y), str(caption), fill=(220, 220, 220))
            panels.append(panel)

        row_w = sum(p.width for p in panels) + panel_gap * max(len(panels) - 1, 0)
        row_h = max(p.height for p in panels) + header_h
        obs_canvas = Image.new("RGB", (row_w, row_h), color=(0, 0, 0))
        ImageDraw.Draw(obs_canvas).text((6, text_y), frame_label, fill=(220, 220, 220))

        offset_x = 0
        for panel in panels:
            obs_canvas.paste(panel, (offset_x, header_h))
            offset_x += panel.width + panel_gap
    else:
        obs_canvas = Image.new("RGB", (640, 360), color=(20, 20, 20))
        ImageDraw.Draw(obs_canvas).text((8, text_y), "No selected image keys", fill=(255, 255, 255))

    def centered_on(img, width, background):
        # Horizontally center `img` on a strip of `width`; no-op if already that wide.
        if img.width >= width:
            return img
        strip = Image.new("RGB", (width, img.height), color=background)
        strip.paste(img, ((width - img.width) // 2, 0))
        return strip

    if status_plot is None:
        canvas = obs_canvas
    else:
        plot_img = Image.fromarray(np.asarray(status_plot, dtype=np.uint8)).convert("RGB")
        full_w = max(obs_canvas.width, plot_img.width)
        obs_canvas = centered_on(obs_canvas, full_w, (0, 0, 0))
        plot_img = centered_on(plot_img, full_w, (255, 255, 255))

        canvas = Image.new(
            "RGB",
            (full_w, obs_canvas.height + stack_gap + plot_img.height),
            color=(0, 0, 0),
        )
        canvas.paste(obs_canvas, (0, 0))
        canvas.paste(plot_img, (0, obs_canvas.height + stack_gap))

    # H.264 encoders work on 16x16 macroblocks, so round both sides up to 16.
    aligned_w = ((canvas.width + 15) // 16) * 16
    aligned_h = ((canvas.height + 15) // 16) * 16
    if (aligned_w, aligned_h) != canvas.size:
        aligned = Image.new("RGB", (aligned_w, aligned_h), color=(0, 0, 0))
        aligned.paste(canvas, (0, 0))
        canvas = aligned

    return np.asarray(canvas)


@lru_cache(maxsize=128)
def get_video_status_plot_base(repo_id, filename, traj_id, valid_window_len):
    """Render the static part of the video status plot once per trajectory.

    Running Matplotlib for every exported frame is slow. This function draws
    the ``no_teacher_action`` step line and the valid-start markers a single
    time, records the pixel bounds of the plot area, and returns the image.
    The moving timestep cursor is drawn later on a copy of this image.

    Args:
        repo_id, filename: dataset location passed through to ``load_traj``.
        traj_id: trajectory index (coerced to ``int``).
        valid_window_len: window length forwarded to
            ``_compute_valid_start_indices`` and shown in the legend.

    Returns:
        ``(base, (x0, y0, x1, y1), total_steps)``. ``base`` is a read-only
        ``(H, W, 3)`` uint8 RGB array. The bounds are the axes area in image
        pixel coordinates with the origin at the top-left. For an empty or
        unloadable trajectory the result is ``(None, (0, 0, 1, 1), 0)``.
    """
    traj = load_traj(repo_id, filename, int(traj_id))
    # Other handlers treat a falsy load result as failure (`if not traj`);
    # do the same here instead of calling len() on it.
    total_steps = len(traj) if traj else 0
    if total_steps == 0:
        return None, (0, 0, 1, 1), 0

    no_teacher, valid_indices = _compute_valid_start_indices(traj, int(valid_window_len))

    fig, ax = plt.subplots(figsize=VIDEO_STATUS_FIGSIZE, dpi=VIDEO_STATUS_DPI)
    try:
        ax.step(
            np.arange(total_steps),
            no_teacher,
            where="post",
            label="no_teacher_action",
            color="orange",
        )

        if valid_indices:
            ax.scatter(
                valid_indices,
                [-0.15] * len(valid_indices),
                color="green",
                marker="^",
                s=18,
                label="Valid Start (len >= {})".format(int(valid_window_len)),
            )

        ax.set_xlim(0, max(total_steps - 1, 1))
        ax.set_ylim(-0.38, 1.1)
        ax.set_ylabel("Flag", fontsize=8)
        ax.set_xlabel("Timestep index", fontsize=8)
        ax.set_yticks([0, 1])
        ax.set_yticklabels(["False", "True"])
        ax.grid(True, axis="x", alpha=0.2)
        ax.set_title("no_teacher_action and valid starts", fontsize=9)
        ax.tick_params(axis="both", labelsize=7)
        ax.legend(loc="upper right", fontsize=7)
        fig.tight_layout()
        fig.canvas.draw()

        rgba = np.asarray(fig.canvas.buffer_rgba())
        base = rgba[..., :3].copy()
        bbox = ax.get_window_extent()
    finally:
        # pyplot keeps every open figure alive until it is closed, so close it
        # even when rendering fails; otherwise each failed call leaks a figure.
        plt.close(fig)

    height = base.shape[0]

    # Matplotlib bbox origin is bottom-left, image origin is top-left.
    x0 = int(round(bbox.x0))
    x1 = int(round(bbox.x1))
    y0 = int(round(height - bbox.y1))
    y1 = int(round(height - bbox.y0))

    # The array is shared through lru_cache; make accidental in-place edits by
    # callers fail loudly instead of silently corrupting the cached image.
    base.setflags(write=False)
    return base, (x0, y0, x1, y1), total_steps


@lru_cache(maxsize=8192)
def get_cached_video_status_frame(repo_id, filename, traj_id, timestep, valid_window_len):
    """Return the cached static status plot with the timestep cursor drawn on it.

    This delegates the cursor and label drawing to
    ``_draw_status_cursor_on_base``, the same renderer the video export uses.
    That keeps per-timestep cached frames pixel-consistent with exported
    frames. Previously this function carried its own copy of the drawing
    code, with a different label-box width.

    Returns:
        An ``(H, W, 3)`` uint8 RGB array, or ``None`` when the trajectory is
        empty. The timestep is clamped into ``[0, total_steps - 1]``.
    """
    base, bounds, total_steps = get_video_status_plot_base(
        repo_id,
        filename,
        int(traj_id),
        int(valid_window_len),
    )
    return _draw_status_cursor_on_base(base, bounds, total_steps, timestep)

def _draw_status_cursor_on_base(base, bounds, total_steps, timestep):
    """Overlay the timestep cursor and a "step t/T" label on a static status image.

    The static plot from ``get_video_status_plot_base`` is copied once. The
    vertical cursor is blacked out with a NumPy slice, and PIL is used only
    for the small label. This avoids going through the lru-cached
    per-timestep frame function during export, which for long trajectories
    would retain thousands of images.

    Args:
        base: static RGB plot image, or ``None``.
        bounds: ``(x0, y0, x1, y1)`` pixel bounds of the plot area (top-left origin).
        total_steps: number of timesteps the x-axis spans (treated as at least 1).
        timestep: timestep to mark; clamped into ``[0, total_steps - 1]``.

    Returns:
        A new uint8 RGB array, or ``None`` if ``base`` is ``None``.
    """
    if base is None:
        return None

    n_steps = int(max(total_steps, 1))
    last = n_steps - 1
    step_idx = int(np.clip(int(timestep), 0, last))
    left, top, right, bottom = [int(v) for v in bounds]
    span = max(last, 1)
    cursor_x = int(round(left + (right - left) * float(step_idx) / float(span)))

    frame = np.array(base, dtype=np.uint8)
    rows, cols = frame.shape[0], frame.shape[1]

    # 4-pixel-wide black cursor over the plot area, clipped to the image.
    frame[max(0, top):min(rows, bottom), max(0, cursor_x - 2):min(cols, cursor_x + 2), :] = 0

    labeled = Image.fromarray(frame).convert("RGB")
    painter = ImageDraw.Draw(labeled)
    painter.rectangle((left + 4, top + 4, left + 126, top + 24), fill=(255, 255, 255))
    painter.text((left + 8, top + 7), "step {}/{}".format(step_idx, last), fill=(0, 0, 0))
    return np.asarray(labeled)


def _get_fast_video_writer(out_path, fps):
    """Open an imageio/ffmpeg MP4 writer tuned for fast interactive export.

    The writer encodes H.264 (libx264) with the ``ultrafast`` preset at CRF 28,
    emits ``yuv420p`` pixels, and passes ``-movflags +faststart``. Frames are
    expected to have dimensions that are multiples of ``macro_block_size=16``.
    """
    x264_options = (
        ("-preset", "ultrafast"),
        ("-crf", "28"),
        ("-pix_fmt", "yuv420p"),
        ("-movflags", "+faststart"),
    )
    # Flatten the (flag, value) pairs into ffmpeg's flat argument list.
    ffmpeg_args = [token for option in x264_options for token in option]
    return imageio.get_writer(
        out_path,
        codec="libx264",
        fps=float(fps),
        macro_block_size=16,
        ffmpeg_params=ffmpeg_args,
    )


def _pad_or_crop_video_frame(frame, height, width):
    """Return ``frame`` as exactly ``height x width`` pixels by black-padding or cropping.

    Content stays anchored at the top-left corner. Missing area is filled
    with zeros (black) and any overflow is cut from the bottom/right.
    """
    frame = np.asarray(frame)
    if frame.shape[0] == height and frame.shape[1] == width:
        return frame
    fitted = np.zeros((height, width) + tuple(frame.shape[2:]), dtype=frame.dtype)
    keep_h = min(height, frame.shape[0])
    keep_w = min(width, frame.shape[1])
    fitted[:keep_h, :keep_w] = frame[:keep_h, :keep_w]
    return fitted


def build_current_trajectory_video(preset_name, custom_repo_id, custom_filename, traj_id, image_keys, display_scale, reverse_channels, fps, valid_window_len, video_stride=4):
    """Export the selected trajectory as an MP4 file.

    Every ``video_stride``-th timestep is written, and the last timestep is
    always included. Each frame shows the selected observation images above
    the status plot with a moving cursor.

    Args:
        preset_name, custom_repo_id, custom_filename: dataset selection.
        traj_id: trajectory index, clamped into ``[0, n_traj - 1]``.
        image_keys: ``None``, a single key string, or an iterable of keys.
        display_scale, reverse_channels: forwarded to ``get_cached_gallery_items``.
        fps: output frame rate.
        valid_window_len: window length for the valid-start markers.
        video_stride: timestep step between written frames (minimum 1).

    Returns:
        ``(out_path, status_text)`` on success, or ``(None, error_text)``.
    """
    if imageio is None:
        return None, "Video export requires imageio and imageio-ffmpeg in requirements.txt."

    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return None, "No trajectories found."

    traj_id = int(np.clip(int(traj_id), 0, n_traj - 1))
    traj = load_traj(repo_id, filename, traj_id)
    if not traj:
        return None, "Trajectory could not be loaded."

    if image_keys is None:
        image_keys = []
    if isinstance(image_keys, str):
        image_keys = [image_keys]
    image_keys_tuple = tuple(image_keys)

    video_stride = int(max(1, int(video_stride)))
    frame_indices = list(range(0, len(traj), video_stride))
    if frame_indices and frame_indices[-1] != len(traj) - 1:
        frame_indices.append(len(traj) - 1)

    safe_repo = re.sub(r"[^A-Za-z0-9_.-]+", "_", repo_id)
    safe_file = re.sub(r"[^A-Za-z0-9_.-]+", "_", filename)[-80:]
    out_path = os.path.join(
        tempfile.gettempdir(),
        "trajectory_{}_{}_traj{:04d}_fps{}_stride{}.mp4".format(
            safe_repo, safe_file, traj_id, int(fps), video_stride
        ),
    )

    # Build the static status plot once. During export, only draw the cursor.
    status_base, status_bounds, total_steps = get_video_status_plot_base(
        repo_id,
        filename,
        traj_id,
        int(valid_window_len),
    )

    writer = _get_fast_video_writer(out_path, fps)
    written = 0
    refitted = 0
    frame_hw = None
    try:
        for t in frame_indices:
            # Use the existing cached image extraction for correctness, but avoid
            # cached per-timestep status images to reduce memory pressure.
            gallery_items, _warnings = get_cached_gallery_items(
                repo_id,
                filename,
                traj_id,
                t,
                image_keys_tuple,
                float(display_scale),
                bool(reverse_channels),
            )
            label = "trajectory {} | frame {}/{}".format(traj_id, t, len(traj) - 1)
            status_plot = _draw_status_cursor_on_base(status_base, status_bounds, total_steps, t)
            frame = _compose_video_frame(gallery_items, label, status_plot=status_plot)

            # The ffmpeg writer expects every frame to have the size of the
            # first one. The composed size depends on which image panels come
            # back for this timestep, so a missing or failed image at one step
            # would otherwise abort the whole export.
            if frame_hw is None:
                frame_hw = frame.shape[:2]
            elif frame.shape[:2] != frame_hw:
                frame = _pad_or_crop_video_frame(frame, frame_hw[0], frame_hw[1])
                refitted += 1

            writer.append_data(frame)
            written += 1
    finally:
        writer.close()

    approx_seconds = float(written) / float(max(float(fps), 1.0))
    status = "Built trajectory video with optimized encoder and status rendering"
    status += "\nTrajectory: {}".format(traj_id)
    status += "\nOriginal timesteps: {} | Written frames: {} | Stride: {}".format(len(traj), written, video_stride)
    status += "\nFPS: {} | Approx video duration: {:.1f}s".format(fps, approx_seconds)
    status += "\nValid-window length: {}".format(int(valid_window_len))
    status += "\nSpeedups: x264 ultrafast preset; static status plot rendered once; cursor drawn with NumPy/PIL"
    if refitted:
        status += "\nFrames padded/cropped to first-frame size: {}".format(refitted)
    return out_path, status

def get_available_image_keys(repo_id, filename, traj_id):
    """List observation keys of a trajectory's first step that look like images.

    Keys listed in ``PREFERRED_IMAGE_KEYS`` come first, in that order. The
    remaining detected keys follow in observation order. Keys whose values
    make ``_looks_like_image_array`` raise are skipped on a best-effort basis.

    Returns:
        A list of key names; empty if there are no trajectories or the
        trajectory could not be loaded.
    """
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return []

    traj = load_traj(repo_id, filename, int(np.clip(int(traj_id), 0, n_traj - 1)))
    if not traj:
        return []

    detected = []
    for key, value in traj[0].get("obs", {}).items():
        try:
            if _looks_like_image_array(key, value):
                detected.append(key)
        except Exception:
            # Best effort: a value that cannot be inspected is simply not offered.
            continue

    preferred = [key for key in PREFERRED_IMAGE_KEYS if key in detected]
    already_listed = set(preferred)
    return preferred + [key for key in detected if key not in already_listed]


def update_custom_visibility(preset_name):
    """Show the custom repo_id / filename textboxes only for the "Custom" preset.

    Returns:
        A pair of ``gr.update`` objects: (repo_id textbox, filename textbox).
    """
    is_custom = preset_name == "Custom"
    repo_box_update = gr.update(visible=is_custom)
    file_box_update = gr.update(visible=is_custom)
    return repo_box_update, file_box_update


def update_after_dataset_change(preset_name, custom_repo_id, custom_filename):
    """Refresh dataset-dependent controls after the dataset selection changes.

    Returns:
        A 5-tuple of updates for, in order: the trajectory slider, the
        timestep slider, the image-key checkbox group, the dataset status
        text, and the reverse-channels checkbox. Image keys and timestep range
        are taken from trajectory 0; the first two detected keys are
        preselected.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)

    reverse_default = get_default_reverse_channels(preset_name)

    status = "Loaded `{}` / `{}`".format(repo_id, filename)
    status += "\nDetected trajectories: {}".format(n_traj)
    status += "\nreverse_channels default: {}".format(int(reverse_default))

    if n_traj == 0:
        return (
            gr.update(maximum=1, value=0),
            gr.update(maximum=1, value=0),
            gr.update(choices=[], value=[]),
            status,
            gr.update(value=reverse_default),
        )

    keys = get_available_image_keys(repo_id, filename, 0)
    traj = load_traj(repo_id, filename, 0)
    # Other handlers treat a falsy load result as failure (`if not traj`);
    # treat it as zero steps here instead of crashing on len(None).
    n_steps = len(traj) if traj else 0

    return (
        gr.update(maximum=max(n_traj - 1, 1), value=0),
        gr.update(maximum=max(n_steps - 1, 1), value=0),
        gr.update(choices=keys, value=keys[:2]),
        status,
        gr.update(value=reverse_default),
    )


def update_after_traj_change(preset_name, custom_repo_id, custom_filename, traj_id):
    """Refresh the timestep slider and image-key choices for a newly selected trajectory.

    Returns:
        ``(timestep_slider_update, image_keys_update)``. The slider is reset
        to 0 and the first two detected image keys are preselected.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return gr.update(maximum=1, value=0), gr.update(choices=[], value=[])

    traj_id = int(np.clip(int(traj_id), 0, n_traj - 1))
    traj = load_traj(repo_id, filename, traj_id)
    # A failed load is falsy elsewhere in this file (`if not traj`); avoid len(None).
    n_steps = len(traj) if traj else 0
    keys = get_available_image_keys(repo_id, filename, traj_id)

    return (
        gr.update(maximum=max(n_steps - 1, 1), value=0),
        gr.update(choices=keys, value=keys[:2]),
    )


def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep, image_keys, chunk_len, display_scale, reverse_channels):
    """Produce the gallery, status plot and info text for one timestep.

    Args:
        preset_name, custom_repo_id, custom_filename: dataset selection, resolved
            via ``resolve_dataset``.
        traj_id: trajectory index, clamped into ``[0, n_traj - 1]``.
        timestep: frame index, clamped into the trajectory's range.
        image_keys: ``None``, a single key string, or an iterable of keys.
        chunk_len: valid-window length used for the valid-start computation.
        display_scale: image scale forwarded to ``get_cached_gallery_items``.
        reverse_channels: BGR<->RGB flag forwarded to ``get_cached_gallery_items``.

    Returns:
        ``(gallery_items, status_plot, info_text)``. When no trajectory is
        available the result is ``([], None, <error message>)``.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)

    if n_traj == 0:
        return [], None, "No trajectory groups found. Open Debug: HDF5 tree."

    traj_id = int(np.clip(int(traj_id), 0, n_traj - 1))
    traj = load_traj(repo_id, filename, traj_id)
    if not traj:
        return [], None, "Trajectory could not be loaded. Open Debug: HDF5 tree."

    last_step = len(traj) - 1
    timestep = int(np.clip(int(timestep), 0, last_step))
    chunk_len = int(chunk_len)
    display_scale = float(display_scale)

    # Normalize the CheckboxGroup value: None -> no keys, bare string -> one key.
    if image_keys is None:
        selected = []
    elif isinstance(image_keys, str):
        selected = [image_keys]
    else:
        selected = image_keys

    step = traj[timestep]
    gallery_items, warnings_tuple = get_cached_gallery_items(
        repo_id, filename, traj_id, timestep, tuple(selected), display_scale, bool(reverse_channels)
    )
    status_plot, is_valid_start, num_valid_starts = get_cached_status_plot(
        repo_id, filename, traj_id, timestep, chunk_len
    )

    # Shape/dtype summary of each selected key present in this step's obs.
    obs = step.get("obs", {})
    tensor_lines = []
    for key in selected:
        if key in obs:
            arr = np.asarray(obs[key])
            tensor_lines.append("{} shape={} dtype={}".format(key, tuple(arr.shape), arr.dtype))

    def flag(name):
        # Missing flags read as False; printed as 0/1.
        return int(bool(step.get(name, False)))

    info_lines = [
        "dataset: {} / {}".format(repo_id, filename),
        "detected trajectories: {}".format(n_traj),
        "trajectory: {}".format(traj_id),
        "episode_id: {}".format(step.get("episode_id", "")),
        "timestep: {} / {}".format(timestep, last_step),
        "saved timestep: {}".format(step.get("timestep", timestep)),
        "done: {}".format(flag("done")),
        "if_success: {}".format(flag("if_success")),
        "no_teacher_action: {}".format(flag("no_teacher_action")),
        "no_robot_action: {}".format(flag("no_robot_action")),
        "valid-window length: {}".format(chunk_len),
        "valid_start: {}".format(int(bool(is_valid_start))),
        "num_valid_starts: {}".format(num_valid_starts),
        "",
        "teacher_action: {}".format(_safe_array_str(step.get("teacher_action", []))),
        "robot_action:   {}".format(_safe_array_str(step.get("robot_action", []))),
        "",
        "selected image tensors:",
    ]
    info_lines.extend(tensor_lines)

    warnings = list(warnings_tuple)
    if warnings:
        info_lines.extend(["", "Image warnings:"])
        info_lines.extend(warnings)

    return gallery_items, status_plot, "\n".join(info_lines)


def build_app():
    """Build the Gradio Blocks UI for the trajectory viewer and wire its events.

    The default preset is probed once at build time so the trajectory slider
    and image-key checkboxes start populated. Probe failures are shown as a
    "Startup warning" line instead of aborting app construction. Components
    are laid out in the order they are created inside the ``with`` contexts,
    so statement order here determines the page layout.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    repo_id, filename = resolve_dataset(DEFAULT_PRESET)

    # Best-effort probe of the default dataset; any exception is kept as text
    # (repr) and rendered in the UI rather than raised.
    try:
        n_traj = get_num_trajectories(repo_id, filename)
        first_keys = get_available_image_keys(repo_id, filename, 0) if n_traj else []
        startup_warning = ""
    except Exception as exc:
        n_traj = 0
        first_keys = []
        startup_warning = repr(exc)

    default_status = "Loaded default dataset\nDetected trajectories: {}\nreverse_channels default: {}".format(n_traj, int(get_default_reverse_channels(DEFAULT_PRESET)))

    with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
        gr.Markdown(
            "# HDF5 Trajectory Viewer\n\n"
            "Standalone viewer for TrajectoryBuffer-style HDF5 datasets on Hugging Face.\n\n"
            "The status plot matches the local labeling view: orange `no_teacher_action`, green valid-start markers, and a black timestep cursor."
        )

        if startup_warning:
            gr.Markdown("Startup warning: `{}`".format(startup_warning))

        # Dataset selection; the custom textboxes are hidden unless the preset is "Custom".
        with gr.Row():
            preset = gr.Dropdown(
                choices=list(DATASET_PRESETS.keys()) + ["Custom"],
                value=DEFAULT_PRESET,
                label="Dataset preset",
            )
            custom_repo_id = gr.Textbox(value="", label="Custom repo_id, e.g. Zhaoting123/InsertT", visible=False)
            custom_filename = gr.Textbox(value="", label="Custom HDF5 path in repo", visible=False)

        dataset_status = gr.Textbox(label="Dataset status", lines=2, value=default_status, interactive=False)

        # Timestep slider starts at maximum=1; update_after_dataset_change sets the real range on load.
        with gr.Row():
            traj_slider = gr.Slider(minimum=0, maximum=max(n_traj - 1, 1), value=0, step=1, label="Trajectory index")
            timestep_slider = gr.Slider(minimum=0, maximum=1, value=0, step=1, label="Timestep")

        with gr.Row():
            image_keys = gr.CheckboxGroup(choices=first_keys, value=first_keys[:2], label="Image keys")
            chunk_len = gr.Slider(minimum=1, maximum=64, value=DEFAULT_CHUNK_LEN, step=1, label="Valid-window length")
            # No visible control: a per-session value passed to the render/preload/video handlers.
            display_scale = gr.State(value=DEFAULT_DISPLAY_SCALE)
            reverse_channels = gr.Checkbox(value=get_default_reverse_channels(DEFAULT_PRESET), label="Reverse channels BGR↔RGB")

        with gr.Row():
            render_btn = gr.Button("Render frame", variant="primary")
            preload_btn = gr.Button("Preload current trajectory")
            video_btn = gr.Button("Build trajectory video")
            video_fps = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Video FPS")
            video_stride = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Video frame stride")

        # Shared status box for both the preload and the video-export buttons.
        preload_status = gr.Textbox(label="Preload / video status", lines=4, value="Not preloaded yet.", interactive=False)

        with gr.Row():
            with gr.Column(scale=3):
                gallery = gr.Gallery(
                    label="Camera images",
                    columns=2,
                    height=360,
                    object_fit="contain",
                )
            with gr.Column(scale=2):
                status_plot = gr.Image(
                    label="no_teacher_action + valid starts",
                    type="numpy",
                    height=360,
                )

        trajectory_video = gr.Video(label="Trajectory video: smooth browser-side playback")
        info = gr.Textbox(label="Frame info", lines=16)

        with gr.Accordion("Debug: HDF5 tree", open=False):
            inspect_btn = gr.Button("Inspect HDF5 structure")
            hdf5_tree = gr.Textbox(lines=24, label="HDF5 tree")

        # Preset switch: toggle custom textboxes, refresh dataset-dependent controls, then re-render.
        preset.change(
            fn=update_custom_visibility,
            inputs=preset,
            outputs=[custom_repo_id, custom_filename],
        ).then(
            fn=update_after_dataset_change,
            inputs=[preset, custom_repo_id, custom_filename],
            outputs=[traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels],
        ).then(
            fn=render_frame,
            inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
            outputs=[gallery, status_plot, info],
        )

        # NOTE(review): unlike preset.change, these submit handlers do not chain a
        # render_frame call; the frame is re-rendered only if a downstream .change
        # listener (traj_slider / image_keys) fires as a result. Confirm this is intended.
        custom_repo_id.submit(
            fn=update_after_dataset_change,
            inputs=[preset, custom_repo_id, custom_filename],
            outputs=[traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels],
        )
        custom_filename.submit(
            fn=update_after_dataset_change,
            inputs=[preset, custom_repo_id, custom_filename],
            outputs=[traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels],
        )

        traj_slider.change(
            fn=update_after_traj_change,
            inputs=[preset, custom_repo_id, custom_filename, traj_slider],
            outputs=[timestep_slider, image_keys],
        ).then(
            fn=render_frame,
            inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
            outputs=[gallery, status_plot, info],
        )

        # Bound to .release rather than .change, presumably so dragging the
        # slider does not render every intermediate timestep.
        timestep_slider.release(
            fn=render_frame,
            inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
            outputs=[gallery, status_plot, info],
        )

        # Changing the selected keys, window length or channel order re-renders the current frame.
        for widget in [image_keys, chunk_len, reverse_channels]:
            widget.change(
                fn=render_frame,
                inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
                outputs=[gallery, status_plot, info],
            )

        render_btn.click(
            fn=render_frame,
            inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
            outputs=[gallery, status_plot, info],
        )

        preload_btn.click(
            fn=preload_current_trajectory,
            inputs=[preset, custom_repo_id, custom_filename, traj_slider, image_keys, chunk_len, display_scale, reverse_channels],
            outputs=preload_status,
        )

        # chunk_len is passed positionally as build_current_trajectory_video's
        # valid_window_len, so the video uses the same window as the interactive plot.
        video_btn.click(
            fn=build_current_trajectory_video,
            inputs=[preset, custom_repo_id, custom_filename, traj_slider, image_keys, display_scale, reverse_channels, video_fps, chunk_len, video_stride],
            outputs=[trajectory_video, preload_status],
        )

        inspect_btn.click(
            fn=inspect_hdf5_tree,
            inputs=[preset, custom_repo_id, custom_filename],
            outputs=hdf5_tree,
        )

        # On page load: recompute dataset-dependent controls, then render the first frame.
        demo.load(
            fn=update_after_dataset_change,
            inputs=[preset, custom_repo_id, custom_filename],
            outputs=[traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels],
        ).then(
            fn=render_frame,
            inputs=[preset, custom_repo_id, custom_filename, traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels],
            outputs=[gallery, status_plot, info],
        )

    return demo


if __name__ == "__main__":
    # Keep the module-level name `demo` (Gradio's reload tooling looks for it).
    demo = build_app()
    # Listen on all interfaces; the port comes from $PORT (default 7860).
    launch_kwargs = dict(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
        share=False,
        ssr_mode=False,
    )
    demo.launch(**launch_kwargs)