# NOTE(review): removed non-Python web-page residue ("Spaces: / Running on
# Zero") that was scraped into this file and would break import.
| import os | |
| import glob | |
| from dataclasses import dataclass | |
| from typing import List, Dict, Any, Iterator, Optional, Tuple | |
| import torch | |
| from longstream.utils.vendor.dust3r.utils.image import load_images_for_eval | |
# Per-dataset configuration for the "relpose" benchmark layout consumed by
# LongStreamDataLoader. Each entry maps a dataset name to:
#   img_path            root directory containing the RGB frames
#   anno_path           (optional) root containing pose annotations
#   mask_path           root with per-sequence masks, or None
#   dir_path_func       (img_path, seq) -> directory holding seq's frames
#   gt_traj_func        (img_path, anno_path, seq) -> ground-truth trajectory
#                       file for seq, or None when unavailable
#   traj_format         trajectory file format tag ("kitti", "tum", ...) or None
#   seq_list            explicit sequence names, or None to discover them
#   full_seq            if True, sequences are enumerated from img_path subdirs
#   mask_path_seq_func  (mask_path, seq) -> mask directory for seq, or None
#   skip_condition      placeholder (always None in this table)
#   process_func        placeholder (always None in this table)
dataset_metadata: Dict[str, Dict[str, Any]] = {
    "davis": {
        "img_path": "data/davis/DAVIS/JPEGImages/480p",
        "mask_path": "data/davis/DAVIS/masked_images/480p",
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        # DAVIS has no camera ground truth.
        "gt_traj_func": lambda img_path, anno_path, seq: None,
        "traj_format": None,
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: os.path.join(mask_path, seq),
        "skip_condition": None,
        "process_func": None,
    },
    "kitti": {
        "img_path": "data/kitti/sequences",
        "anno_path": "data/kitti/poses",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "image_2"),
        # Only the odometry training sequences ship poses; return None when
        # the pose file does not exist.
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            anno_path, f"{seq}.txt"
        )
        if os.path.exists(os.path.join(anno_path, f"{seq}.txt"))
        else None,
        "traj_format": "kitti",
        "seq_list": ["00", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10"],
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "bonn": {
        "img_path": "data/bonn/rgbd_bonn_dataset",
        "mask_path": None,
        # "_110" suffixes presumably denote 110-frame subsampled variants —
        # TODO confirm against the data preparation scripts.
        "dir_path_func": lambda img_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", "rgb_110"
        ),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", "groundtruth_110.txt"
        ),
        "traj_format": "tum",
        "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "nyu": {
        # NYU entry is intentionally sparse (single images, no trajectories);
        # the loader falls back to defaults for the missing keys via .get().
        "img_path": "data/nyu-v2/val/nyu_images",
        "mask_path": None,
        "process_func": None,
    },
    "scannet": {
        "img_path": "data/scannetv2",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "pose_90.txt"
        ),
        "traj_format": "replica",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "tum": {
        "img_path": "data/tum",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "rgb_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "groundtruth_90.txt"
        ),
        "traj_format": "tum",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "sintel": {
        "img_path": "data/sintel/training/final",
        "anno_path": "data/sintel/training/camdata_left",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        # Ground truth is a per-sequence directory of cam files, not a single
        # trajectory file; hence traj_format is None.
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(anno_path, seq),
        "traj_format": None,
        "seq_list": [
            "alley_2",
            "ambush_4",
            "ambush_5",
            "ambush_6",
            "cave_2",
            "cave_4",
            "market_2",
            "market_5",
            "market_6",
            "shaman_3",
            "sleeping_1",
            "sleeping_2",
            "temple_2",
            "temple_3",
        ],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "waymo": {
        # NOTE(review): absolute bucket path baked into the table; callers can
        # override it through cfg["img_path"].
        "img_path": "/horizon-bucket/saturn_v_4dlabel/004_vision/01_users/tao02.xie/datasets/scatt3r_evaluation/waymo_open_dataset_v1_4_3",
        "anno_path": None,
        "mask_path": None,
        # Sequence names may carry a "_cam<idx>" suffix selecting a camera
        # sub-directory; without it, camera "00" is used.
        "dir_path_func": lambda img_path, seq: os.path.join(
            img_path,
            seq.split("_cam")[0] if "_cam" in seq else seq,
            "images",
            seq.split("_cam")[1] if "_cam" in seq else "00",
        ),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path,
            seq.split("_cam")[0] if "_cam" in seq else seq,
            "cameras",
            seq.split("_cam")[1] if "_cam" in seq else "00",
            "extri.yml",
        ),
        "traj_format": "waymo",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
}
@dataclass
class LongStreamSequenceInfo:
    """Lightweight description of a discovered sequence (paths only, no pixel
    data).

    BUG FIX: the class previously declared only annotations with no
    ``@dataclass`` decorator (despite ``dataclass`` being imported), so the
    keyword construction in ``iter_sequence_infos`` raised ``TypeError``.

    Attributes:
        name: Sequence identifier (directory/base name).
        scene_root: Root directory of the scene.
        image_dir: Directory the frame files were collected from.
        image_paths: Ordered list of frame file paths.
        camera: Camera sub-directory name, or None when not applicable.
    """

    name: str
    scene_root: str
    image_dir: str
    image_paths: List[str]
    camera: Optional[str]
@dataclass
class LongStreamSequence:
    """A loaded sequence: preloaded image tensor plus its source paths.

    Converted from a hand-written ``__init__`` to a ``@dataclass`` for
    consistency with ``LongStreamSequenceInfo`` and the file's existing
    ``dataclass`` import; the generated constructor has the identical
    signature (same names, order, and defaults), and callers additionally
    gain ``__repr__``/``__eq__`` for free.

    Attributes:
        name: Sequence identifier.
        images: Batched image tensor produced by the loader.
        image_paths: Ordered list of the frame file paths that were loaded.
        scene_root: Root directory of the scene, when known.
        image_dir: Directory the frames were read from, when known.
        camera: Camera sub-directory name, or None.
    """

    name: str
    images: torch.Tensor
    image_paths: List[str]
    scene_root: Optional[str] = None
    image_dir: Optional[str] = None
    camera: Optional[str] = None
| def _read_list_file(path: str) -> List[str]: | |
| with open(path, "r") as f: | |
| lines = [] | |
| for line in f.readlines(): | |
| line = line.strip() | |
| if not line: | |
| continue | |
| if line.startswith("#"): | |
| continue | |
| lines.append(line) | |
| return lines | |
| def _is_generalizable_scene_root(path: str) -> bool: | |
| return os.path.isdir(os.path.join(path, "images")) | |
| def _direct_image_files(dir_path: str) -> List[str]: | |
| filelist = sorted(glob.glob(os.path.join(dir_path, "*.png"))) | |
| if not filelist: | |
| filelist = sorted(glob.glob(os.path.join(dir_path, "*.jpg"))) | |
| if not filelist: | |
| filelist = sorted(glob.glob(os.path.join(dir_path, "*.jpeg"))) | |
| return filelist | |
class LongStreamDataLoader:
    """Iterate image sequences from one of two on-disk layouts.

    * "relpose": benchmark datasets described by ``dataset_metadata``
      (frames under ``img_path`` via ``dir_path_func``).
    * "generalizable": each scene root holds an ``images/`` directory,
      optionally split into per-camera sub-directories; scenes are listed in
      a ``*data_roots.txt`` file or discovered by scanning.

    Iterating the loader yields ``LongStreamSequence`` objects with frames
    preloaded as a single tensor.
    """

    def __init__(self, cfg: Dict[str, Any]):
        """Configure the loader from a flat config dict.

        Values in ``cfg`` override the defaults taken from
        ``dataset_metadata[cfg["dataset"]]`` where both exist.
        """
        self.cfg = cfg
        self.dataset = cfg.get("dataset", None)
        # Dataset defaults; empty dict when the dataset name is unknown, in
        # which case every meta.get() below falls through to its default.
        meta = dataset_metadata.get(self.dataset, {})
        self.img_path = cfg.get("img_path", meta.get("img_path"))
        self.mask_path = cfg.get("mask_path", meta.get("mask_path"))
        self.dir_path_func = meta.get("dir_path_func", lambda p, s: os.path.join(p, s))
        self.mask_path_seq_func = meta.get("mask_path_seq_func", lambda p, s: None)
        self.full_seq = bool(cfg.get("full_seq", meta.get("full_seq", True)))
        self.seq_list = cfg.get("seq_list", None)
        self.stride = int(cfg.get("stride", 1))  # temporal subsampling step
        self.max_frames = cfg.get("max_frames", None)  # cap applied after striding
        self.size = int(cfg.get("size", 518))  # resize target for the image loader
        self.crop = bool(cfg.get("crop", False))
        self.patch_size = int(cfg.get("patch_size", 14))
        self.format = cfg.get("format", "auto")  # "relpose" | "generalizable" | auto-detect
        self.data_roots_file = cfg.get("data_roots_file", None)
        self.split = cfg.get("split", None)
        self.camera = cfg.get("camera", None)  # preferred camera sub-dir (generalizable)

    def _infer_format(self) -> str:
        """Resolve the effective layout: explicit cfg value, else detect."""
        if self.format in ["relpose", "generalizable"]:
            return self.format
        if self.img_path is None:
            return "relpose"
        # img_path pointing directly at a scene root => generalizable layout.
        if _is_generalizable_scene_root(self.img_path):
            return "generalizable"
        # So is a dataset root that carries a data-roots list file.
        default_list = self.data_roots_file or "data_roots.txt"
        if os.path.exists(os.path.join(self.img_path, default_list)):
            return "generalizable"
        return "relpose"

    def _resolve_seq_list_generalizable(self) -> List[str]:
        """Return scene entries for the generalizable layout.

        Resolution order: explicit ``seq_list`` from cfg; ``img_path`` itself
        when it is already a scene root; the first existing data-roots list
        file among several candidate names; finally a recursive scan for
        ``**/images`` directories under ``img_path``.
        """
        if self.seq_list is not None:
            return list(self.seq_list)
        if self.img_path is None or not os.path.isdir(self.img_path):
            return []
        if _is_generalizable_scene_root(self.img_path):
            # img_path points directly at a single scene.
            return [self.img_path]
        candidates = []
        if isinstance(self.data_roots_file, str) and self.data_roots_file:
            candidates.append(self.data_roots_file)
        if isinstance(self.split, str) and self.split:
            split_name = self.split.lower()
            # Normalize the common validation-split aliases.
            if split_name in ["val", "valid", "validate"]:
                split_name = "validate"
            candidates.append(f"{split_name}_data_roots.txt")
        candidates.append("data_roots.txt")
        candidates.append("train_data_roots.txt")
        candidates.append("validate_data_roots.txt")
        for fname in candidates:
            path = os.path.join(self.img_path, fname)
            if os.path.exists(path):
                return _read_list_file(path)
        # No list file found: discover scenes by locating images/ dirs and
        # taking their parents as scene roots (relative to img_path).
        img_dirs = sorted(
            glob.glob(os.path.join(self.img_path, "**", "images"), recursive=True)
        )
        scene_roots = [os.path.dirname(p) for p in img_dirs]
        rels = []
        for p in scene_roots:
            try:
                rels.append(os.path.relpath(p, self.img_path))
            except ValueError:
                # relpath can raise across drives on Windows; keep as-is.
                rels.append(p)
        return sorted(set(rels))

    def _resolve_seq_list_relpose(self) -> List[str]:
        """Return sequence names for the relpose layout.

        Priority: explicit cfg ``seq_list``; directory listing of
        ``img_path`` when ``full_seq``; otherwise the dataset metadata's
        ``seq_list`` (empty when absent).
        """
        if self.seq_list is not None:
            return list(self.seq_list)
        meta = dataset_metadata.get(self.dataset, {})
        if self.full_seq:
            if self.img_path is None or not os.path.isdir(self.img_path):
                return []
            seqs = [
                s
                for s in os.listdir(self.img_path)
                if os.path.isdir(os.path.join(self.img_path, s))
            ]
            return sorted(seqs)
        seqs = meta.get("seq_list", []) or []
        return list(seqs)

    def _resolve_seq_list(self) -> List[str]:
        """Dispatch sequence discovery based on the inferred layout."""
        fmt = self._infer_format()
        if fmt == "generalizable":
            return self._resolve_seq_list_generalizable()
        return self._resolve_seq_list_relpose()

    def _resolve_scene_root(self, seq_entry: str) -> Tuple[str, str]:
        """Map a sequence entry to ``(name, scene_root)``.

        Entries that are absolute or contain a path separator are used
        verbatim as scene roots (name = last path component); bare names are
        joined under ``img_path``.
        """
        if os.path.isabs(seq_entry) or os.path.sep in seq_entry:
            scene_root = seq_entry
            name = os.path.basename(os.path.normpath(seq_entry))
        else:
            scene_root = os.path.join(self.img_path, seq_entry)
            name = seq_entry
        return name, scene_root

    def _resolve_image_dir_generalizable(self, scene_root: str) -> Optional[str]:
        """Pick the directory holding the frames of a generalizable scene.

        Preference order: the configured camera sub-dir; ``images/`` itself
        when it directly contains image files; otherwise the first (sorted)
        camera sub-directory. Returns None when nothing usable exists.
        """
        images_root = os.path.join(scene_root, "images")
        if not os.path.isdir(images_root):
            return None
        if isinstance(self.camera, str) and self.camera:
            cam_dir = os.path.join(images_root, self.camera)
            if os.path.isdir(cam_dir):
                return cam_dir
        if _direct_image_files(images_root):
            return images_root
        cams = [
            d
            for d in os.listdir(images_root)
            if os.path.isdir(os.path.join(images_root, d))
        ]
        if not cams:
            return None
        cams = sorted(cams)
        frame_dirs = []
        for name in cams:
            child_dir = os.path.join(images_root, name)
            child_images = _direct_image_files(child_dir)
            if child_images:
                frame_dirs.append((name, len(child_images)))
        # Heuristic: many sub-dirs each holding exactly one image are
        # presumably per-frame folders rather than cameras — return
        # images/ itself and let _collect_filelist take one image per
        # child dir. (max() is safe: frame_dirs non-empty is implied by
        # len(frame_dirs) == len(cams) with cams non-empty.)
        if (
            len(cams) > 10
            and len(frame_dirs) == len(cams)
            and max(count for _, count in frame_dirs) == 1
        ):
            return images_root
        return os.path.join(images_root, cams[0])

    def _camera_from_image_dir(self, image_dir: str) -> Optional[str]:
        """Recover the camera name from an ``images/<camera>`` path, if any."""
        parent = os.path.basename(os.path.dirname(image_dir))
        if parent != "images":
            return None
        return os.path.basename(image_dir)

    def _collect_filelist(self, dir_path: str) -> List[str]:
        """List frame paths under *dir_path*, applying stride and frame cap.

        When *dir_path* has no direct image files, fall back to taking the
        first image out of each child directory (per-frame folder layout).
        """
        filelist = _direct_image_files(dir_path)
        if not filelist:
            nested = []
            child_dirs = sorted(
                d for d in glob.glob(os.path.join(dir_path, "*")) if os.path.isdir(d)
            )
            for child_dir in child_dirs:
                child_images = _direct_image_files(child_dir)
                if child_images:
                    nested.append(child_images[0])
            filelist = nested
        if self.stride > 1:
            filelist = filelist[:: self.stride]
        if self.max_frames is not None:
            filelist = filelist[: self.max_frames]
        return filelist

    def _load_images(self, filelist: List[str]) -> torch.Tensor:
        """Load the listed frames and stack them into one batched tensor.

        Returns a tensor of shape (1, N, ...) rescaled from [-1, 1] into
        [0, 1] — assumes load_images_for_eval yields per-view "img" tensors
        with a leading batch dim in [-1, 1]; TODO confirm against the
        vendored dust3r loader.
        """
        views = load_images_for_eval(
            filelist,
            size=self.size,
            verbose=False,
            crop=self.crop,
            patch_size=self.patch_size,
        )
        imgs = torch.cat([view["img"] for view in views], dim=0)
        images = imgs.unsqueeze(0)
        # Map the loader's [-1, 1] range to [0, 1].
        images = (images + 1.0) / 2.0
        return images

    def iter_sequence_infos(self) -> Iterator[LongStreamSequenceInfo]:
        """Yield path-level sequence descriptions without loading pixels.

        Sequences with a missing frame directory or an empty file list are
        silently skipped.
        """
        fmt = self._infer_format()
        seqs = self._resolve_seq_list()
        for seq_entry in seqs:
            if fmt == "generalizable":
                seq, scene_root = self._resolve_scene_root(seq_entry)
                dir_path = self._resolve_image_dir_generalizable(scene_root)
                if dir_path is None or not os.path.isdir(dir_path):
                    continue
                camera = self._camera_from_image_dir(dir_path)
            else:
                seq = seq_entry
                scene_root = os.path.join(self.img_path, seq)
                dir_path = self.dir_path_func(self.img_path, seq)
                if not os.path.isdir(dir_path):
                    continue
                camera = None
            filelist = self._collect_filelist(dir_path)
            if not filelist:
                continue
            yield LongStreamSequenceInfo(
                name=seq,
                scene_root=scene_root,
                image_dir=dir_path,
                image_paths=filelist,
                camera=camera,
            )

    def __iter__(self) -> Iterator[LongStreamSequence]:
        """Yield fully-loaded sequences, logging progress to stdout."""
        for info in self.iter_sequence_infos():
            print(
                f"[longstream] loading sequence {info.name}: {len(info.image_paths)} frames",
                flush=True,
            )
            images = self._load_images(info.image_paths)
            print(
                f"[longstream] loaded sequence {info.name}: {tuple(images.shape)}",
                flush=True,
            )
            yield LongStreamSequence(
                info.name,
                images,
                info.image_paths,
                scene_root=info.scene_root,
                image_dir=info.image_dir,
                camera=info.camera,
            )