# Origin: oops/batch_select_nearest_touching_sidewalk.py
# Uploaded by deansmile123 ("Upload folder using huggingface_hub"), commit 75f0bc0 (verified).
# batch_select_nearest_touching_sidewalk.py
import argparse
import os
import glob
import csv
import pickle
import numpy as np
import cv2
def load_mask_stack(pkl_path):
    """Load a stack of object masks from a pickle file.

    The unpickled object is converted with ``np.asarray``. It must be 3-D,
    i.e. (N, H, W), and is returned as a boolean array.

    Note: ``pickle.load`` can execute arbitrary code, so only use this on
    trusted files.

    Raises:
        ValueError: if the loaded array is not 3-D.
    """
    with open(pkl_path, "rb") as fh:
        raw = pickle.load(fh)
    stack = np.asarray(raw)
    if stack.ndim == 3:
        return stack.astype(bool)
    raise ValueError(f"Expected (N,H,W), got {stack.shape} from {pkl_path}")
def load_sidewalk_mask(path):
    """Load a sidewalk mask as a boolean (H, W) array.

    Supports:
      - .pkl: pickle of (H,W) or (N,H,W). Pickle can run arbitrary code, so use
        trusted files only.
      - .npy: numpy array (H,W) or (N,H,W)
      - .png/.jpg/.jpeg/.bmp/.tif/.tiff: image mask (nonzero gray value => True).
        3-channel (BGR) and 4-channel (BGRA) images are converted to grayscale
        first.
    For (N,H,W) arrays, returns the union over N.

    Raises:
        FileNotFoundError: if an image file cannot be read by OpenCV.
        ValueError: for an unsupported extension or an array that is not 2-D/3-D.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext == ".pkl":
        with open(path, "rb") as f:
            arr = np.asarray(pickle.load(f))
    elif ext == ".npy":
        arr = np.load(path)
    elif ext in (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"):
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise FileNotFoundError(f"Failed to read sidewalk image: {path}")
        if img.ndim == 3:
            # IMREAD_UNCHANGED keeps an alpha channel, and COLOR_BGR2GRAY rejects
            # 4-channel input, so pick the conversion by channel count.
            code = cv2.COLOR_BGRA2GRAY if img.shape[2] == 4 else cv2.COLOR_BGR2GRAY
            img = cv2.cvtColor(img, code)
        return img > 0
    else:
        raise ValueError(f"Unsupported sidewalk mask extension: {ext} ({path})")

    if arr.ndim == 3:
        # Several sidewalk instances: treat any of them as sidewalk.
        return np.any(arr.astype(bool), axis=0)
    if arr.ndim == 2:
        return arr.astype(bool)
    raise ValueError(f"Unexpected sidewalk mask shape {arr.shape} from {path}")
def robust_depth(depth, mask, q=10.0):
    """Score how near the object under ``mask`` is, using its depth values.

    Only finite, strictly positive depth values inside ``mask`` count. The
    score is their ``q``-th percentile (default 10) rather than the minimum.
    This makes it less sensitive to isolated outlier pixels.

    Returns:
        float percentile value, or ``np.inf`` when the mask has no valid pixels.
    """
    usable = mask & (np.isfinite(depth) & (depth > 0))
    samples = depth[usable]
    if samples.size:
        return float(np.percentile(samples, q))
    return np.inf
def touches_sidewalk(obj_mask, sidewalk_mask, margin_px=8, min_contact_px=30, use_boundary=False):
    """Decide whether an object mask touches the sidewalk.

    The sidewalk mask is dilated with an elliptical kernel of size
    ``2*margin_px+1``, so objects within about ``margin_px`` pixels still count
    as touching.

    - Default: count the object pixels inside the dilated sidewalk and require
      at least ``min_contact_px``.
    - ``use_boundary=True``: count only the object's boundary pixels (object
      minus its 3x3 elliptical erosion) inside the dilated sidewalk. Require at
      least ``max(5, min_contact_px // 3)`` of them.

    Returns:
        bool
    """
    ksize = 2 * margin_px + 1
    disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
    near_sidewalk = cv2.dilate(sidewalk_mask.astype(np.uint8) * 255, disk, iterations=1) > 0

    if use_boundary:
        obj_u8 = obj_mask.astype(np.uint8) * 255
        small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        eroded = cv2.erode(obj_u8, small, iterations=1)
        edge = (obj_u8 > 0) & (eroded == 0)
        needed = max(5, min_contact_px // 3)
        return int((edge & near_sidewalk).sum()) >= needed

    return int((obj_mask & near_sidewalk).sum()) >= min_contact_px
def find_matching_file(stem, folder, exts, allow_contains=False):
    """Locate the file in ``folder`` that corresponds to an image ``stem``.

    Tries, in order:
      1) exact: ``folder/stem + ext`` for each ext in ``exts``; the first existing
         regular file wins.
      2) if ``allow_contains``: glob ``folder/*stem*ext`` for each ext and return
         the lexicographically first hit.

    ``folder`` and ``stem`` are glob-escaped, so names containing ``[``, ``]``,
    ``*`` or ``?`` are matched literally instead of as wildcards.

    Returns:
        The matching path, or None if nothing matches.
    """
    for ext in exts:
        candidate = os.path.join(folder, stem + ext)
        # isfile rather than exists: a directory with that name cannot be loaded.
        if os.path.isfile(candidate):
            return candidate
    if allow_contains:
        folder_pat = glob.escape(folder)
        stem_pat = glob.escape(stem)
        for ext in exts:
            pattern = os.path.join(folder_pat, f"*{stem_pat}*{glob.escape(ext)}")
            hits = sorted(glob.glob(pattern))
            if hits:
                return hits[0]
    return None
def overlay_mask(rgb_bgr, mask_bool, alpha=0.4):
    """Return a copy of a BGR image with the masked pixels tinted red.

    Each masked pixel becomes ``(1 - alpha) * pixel + alpha * red``. The result
    is clipped to [0, 255] and truncated to uint8. The input image is not
    modified.

    Args:
        rgb_bgr: (H, W, 3) uint8 image in OpenCV BGR channel order.
        mask_bool: (H, W) mask. Any dtype is accepted and converted to bool, so
            0/1 or 0/255 masks select pixels instead of acting as integer indices.
        alpha: blend weight of the red tint (0 = no tint, 1 = solid red).
    """
    overlay = rgb_bgr.copy()
    m = np.asarray(mask_bool, dtype=bool)
    red = np.array([0, 0, 255], dtype=np.float64)  # BGR
    blended = (1.0 - alpha) * overlay[m] + alpha * red
    overlay[m] = np.clip(blended, 0, 255).astype(np.uint8)
    return overlay
def process_one(rgb_path, depth_path, masks_path, sidewalk_path, out_dir, args):
    """Pick the nearest object mask that touches the sidewalk, for one image.

    Loads the RGB image, depth map, object-mask stack and sidewalk mask. It keeps
    the objects that ``touches_sidewalk`` accepts and selects the one with the
    smallest ``robust_depth`` score (the ``args.quantile`` percentile of its valid
    depth pixels). On success it writes ``nearest_mask.png``,
    ``nearest_overlay.png`` and ``nearest_meta.txt`` into ``out_dir``. On no
    match it writes ``no_match.txt``.

    Args:
        rgb_path: image readable by ``cv2.imread``.
        depth_path: ``.npy`` depth map with the same (H, W) as the masks.
        masks_path: pickle of an (N, H, W) mask stack (see ``load_mask_stack``).
        sidewalk_path: sidewalk mask file (see ``load_sidewalk_mask``).
        out_dir: per-image output folder (created if needed).
        args: parsed CLI namespace. Reads ``margin_px``, ``min_contact_px``,
            ``use_boundary``, ``quantile`` and ``overlay_alpha``.

    Returns:
        dict whose ``status`` is "ok", "no_match" or "fail". Failures and
        no-matches carry a ``reason``. Success carries best_i, depth_score,
        total_masks, kept_after_touch, mask_png and overlay_png.
    """
    rgb = cv2.imread(rgb_path)
    if rgb is None:
        return {"status": "fail", "reason": "rgb_read_failed"}
    depth = np.load(depth_path)
    masks = load_mask_stack(masks_path)
    sidewalk = load_sidewalk_mask(sidewalk_path)

    # All inputs must share the same pixel grid; depth is the reference.
    if depth.shape != masks.shape[1:]:
        return {"status": "fail", "reason": f"shape_mismatch_depth_vs_masks depth={depth.shape} masksHW={masks.shape[1:]}"}
    if rgb.shape[:2] != depth.shape:
        return {"status": "fail", "reason": f"shape_mismatch_rgb_vs_depth rgbHW={rgb.shape[:2]} depth={depth.shape}"}
    if sidewalk.shape != depth.shape:
        return {"status": "fail", "reason": f"shape_mismatch_sidewalk_vs_depth sidewalk={sidewalk.shape} depth={depth.shape}"}

    n_masks = masks.shape[0]
    best_i, best_score = None, np.inf
    kept = 0
    for i, m in enumerate(masks):
        if not touches_sidewalk(
            m, sidewalk,
            margin_px=args.margin_px,
            min_contact_px=args.min_contact_px,
            use_boundary=args.use_boundary,
        ):
            continue
        kept += 1
        score = robust_depth(depth, m, q=args.quantile)
        # Strict "<" keeps the lowest index on ties. An inf score (no valid depth
        # pixels under the mask) can never be selected.
        if score < best_score:
            best_i, best_score = i, score

    os.makedirs(out_dir, exist_ok=True)

    if best_i is None:
        if kept == 0:
            note = f"No object touching sidewalk found. total_masks={n_masks} kept_after_touch={kept}\n"
            reason = f"no_touching_object kept={kept}/{n_masks}"
        else:
            # Objects do touch the sidewalk, but none has a finite, positive depth
            # pixel, so every score was inf. Report that instead of "no touching object".
            note = f"Objects touch the sidewalk but none has valid depth. total_masks={n_masks} kept_after_touch={kept}\n"
            reason = f"no_valid_depth kept={kept}/{n_masks}"
        # still save a quick note file for debugging
        with open(os.path.join(out_dir, "no_match.txt"), "w") as f:
            f.write(note)
        return {"status": "no_match", "reason": reason}

    nearest_mask = masks[best_i]
    mask_png = os.path.join(out_dir, "nearest_mask.png")
    overlay_png = os.path.join(out_dir, "nearest_overlay.png")
    # cv2.imwrite signals some failures by returning False instead of raising.
    if not cv2.imwrite(mask_png, nearest_mask.astype(np.uint8) * 255):
        return {"status": "fail", "reason": f"imwrite_failed {mask_png}"}
    if not cv2.imwrite(overlay_png, overlay_mask(rgb, nearest_mask, alpha=args.overlay_alpha)):
        return {"status": "fail", "reason": f"imwrite_failed {overlay_png}"}

    # optional: save index + score
    with open(os.path.join(out_dir, "nearest_meta.txt"), "w") as f:
        f.write(f"best_i={best_i}\n")
        f.write(f"depth_score_p{args.quantile:g}={best_score}\n")
        f.write(f"total_masks={n_masks}\n")
        f.write(f"kept_after_touch={kept}\n")

    return {
        "status": "ok",
        "best_i": int(best_i),
        "depth_score": float(best_score),
        "total_masks": int(n_masks),
        "kept_after_touch": int(kept),
        "mask_png": mask_png,
        "overlay_png": overlay_png,
    }
def main():
    """CLI entry point: process every RGB image and write ``summary.csv``.

    For each RGB image stem, ``find_matching_file`` looks up:
      - the depth file (.npy),
      - the object-mask file (.pkl),
      - the sidewalk-mask file, trying ``--sidewalk_exts`` in order.
    Images with a missing input are recorded as ``skip_missing_inputs``.
    Otherwise ``process_one`` writes results into ``<out_dir>/<stem>/``.
    Per-image exceptions are recorded as ``fail_exception`` so one bad image
    does not abort the batch.

    Raises:
        FileNotFoundError: if no RGB images are found.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--rgb_dir", default="/scratch/ds5725/OOPS/images_resized",
                    help="folder with RGB images (png/jpg)")
    ap.add_argument("--depth_dir", default="/scratch/ds5725/OOPS/depthpro_out",
                    help="folder with depth .npy")
    ap.add_argument("--masks_dir", default="/scratch/ds5725/sam3/object_union_batch",
                    help="folder with object masks .pkl")
    ap.add_argument("--sidewalk_dir", default="/scratch/ds5725/sam3/batch_surface",
                    help="folder with sidewalk masks (.pkl/.npy/.png)")
    ap.add_argument("--out_dir", default="./nearest_out",
                    help="output root folder")
    ap.add_argument("--rgb_exts", nargs="+", default=[".png", ".jpg", ".jpeg"],
                    help="RGB extensions to scan")
    ap.add_argument("--sidewalk_exts", nargs="+", default=[".pkl", ".npy", ".png"],
                    help="sidewalk mask extensions to try, in order")
    ap.add_argument("--quantile", type=float, default=10.0)
    ap.add_argument("--margin_px", type=int, default=8)
    ap.add_argument("--min_contact_px", type=int, default=30)
    ap.add_argument("--use_boundary", action="store_true",
                    help="use boundary-touch instead of mask-touch")
    ap.add_argument("--overlay_alpha", type=float, default=0.4)
    # matching behavior
    ap.add_argument("--allow_contains_match", action="store_true",
                    help="if exact stem.ext not found, try *stem*.ext glob in depth/masks/sidewalk dirs")
    args = ap.parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    # Gather RGBs. The directory is glob-escaped so "[" etc. in its path are literal;
    # the set removes duplicates when extension patterns overlap.
    rgb_dir_pat = glob.escape(args.rgb_dir)
    rgb_paths = sorted({
        p
        for ext in args.rgb_exts
        for p in glob.glob(os.path.join(rgb_dir_pat, f"*{ext}"))
    })
    if not rgb_paths:
        raise FileNotFoundError(f"No RGB images found in {args.rgb_dir} with exts {args.rgb_exts}")

    fieldnames = [
        "image", "stem", "status", "reason",
        "depth_path", "masks_path", "sidewalk_path",
        "best_i", "depth_score", "total_masks", "kept_after_touch", "overlay_png",
    ]
    result_keys = ("status", "reason", "best_i", "depth_score",
                   "total_masks", "kept_after_touch", "overlay_png")
    summary_csv = os.path.join(args.out_dir, "summary.csv")
    rows = []
    for rgb_path in rgb_paths:
        fname = os.path.basename(rgb_path)
        stem = os.path.splitext(fname)[0]
        depth_path = find_matching_file(stem, args.depth_dir, exts=[".npy"],
                                        allow_contains=args.allow_contains_match)
        masks_path = find_matching_file(stem, args.masks_dir, exts=[".pkl"],
                                        allow_contains=args.allow_contains_match)
        # sidewalk could be pkl/npy/png; try in --sidewalk_exts order
        sidewalk_path = find_matching_file(stem, args.sidewalk_dir, exts=args.sidewalk_exts,
                                           allow_contains=args.allow_contains_match)
        out_subdir = os.path.join(args.out_dir, stem)

        # Every column defaults to "" so all rows have the same shape.
        row = dict.fromkeys(fieldnames, "")
        row.update(
            image=fname,
            stem=stem,
            depth_path=depth_path or "",
            masks_path=masks_path or "",
            sidewalk_path=sidewalk_path or "",
        )

        missing = [
            name
            for name, path in (("depth", depth_path), ("masks", masks_path), ("sidewalk", sidewalk_path))
            if path is None
        ]
        if missing:
            row.update(status="skip_missing_inputs", reason="missing_" + ",".join(missing))
        else:
            try:
                res = process_one(rgb_path, depth_path, masks_path, sidewalk_path, out_subdir, args)
            except Exception as e:  # batch boundary: record and continue with the next image
                row.update(status="fail_exception", reason=repr(e))
            else:
                for key in result_keys:
                    row[key] = res.get(key, "")
        rows.append(row)

    # write CSV
    with open(summary_csv, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(rows)
    print(f"Done. Wrote summary: {summary_csv}")
    print(f"Outputs per image are in: {args.out_dir}/<stem>/nearest_mask.png and nearest_overlay.png")
# Run the batch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()