| | |
| | """ |
| | utils.segmentation |
───────────────────────────────────────────────────────────────────────────────
| | All high-quality person-segmentation code for BackgroundFX Pro. |
| | |
| | Exports |
| | ------- |
segment_person_hq(image, predictor, fallback_enabled=True) → np.ndarray
segment_person_hq_original(image, predictor, fallback_enabled=True) → np.ndarray
| | SegmentationError - Custom exception for segmentation errors |
| | |
| | Everything else is prefixed "_" and considered private. |
| | """ |
| |
|
| | from __future__ import annotations |
| | from typing import Any, Tuple, Optional, Dict |
| | import logging, os, math |
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| |
|
| | log = logging.getLogger(__name__) |
| |
|
| | |
| | |
| | |
class SegmentationError(Exception):
    """Raised when person segmentation cannot produce a usable mask."""
| |
|
| | |
| | |
| | |
# Feature flags for the enhanced segmentation pipeline.
USE_ENHANCED_SEGMENTATION = True   # route segment_person_hq through the smart path
USE_INTELLIGENT_PROMPTING = True   # saliency-driven SAM2 point prompts
USE_ITERATIVE_REFINEMENT = True    # post-prediction refinement pass (currently a no-op stub)

# Mask-quality and classical-CV tuning knobs.
MIN_AREA_RATIO = 0.015   # mask must cover at least 1.5% of the frame to be valid
MAX_AREA_RATIO = 0.97    # ...and at most 97% (otherwise it is likely background)
SALIENCY_THRESH = 0.65   # saliency cutoff used as foreground evidence
GRABCUT_ITERS = 3        # cv2.grabCut iteration count
| |
|
| | |
| | |
| | |
# Public API; everything else in this module is "_"-prefixed and private.
__all__ = [
    "segment_person_hq",
    "segment_person_hq_original",
    "SegmentationError",
]
| |
|
| | |
| | |
| | |
| | def _sam2_to_matanyone_mask(masks: Any, scores: Any = None) -> np.ndarray: |
| | """ |
| | Convert SAM2 multi-mask output to single best mask for MatAnyone. |
| | SAM2 returns (N, H, W) where N is typically 3 masks. |
| | We need to return a single (H, W) mask. |
| | """ |
| | if masks is None or len(masks) == 0: |
| | raise SegmentationError("No masks returned from SAM2") |
| | |
| | |
| | if isinstance(masks, torch.Tensor): |
| | masks = masks.cpu().numpy() |
| | if scores is not None and isinstance(scores, torch.Tensor): |
| | scores = scores.cpu().numpy() |
| | |
| | |
| | if masks.ndim == 4: |
| | masks = masks[0] |
| | if masks.ndim != 3: |
| | raise SegmentationError(f"Unexpected mask shape: {masks.shape}") |
| | |
| | |
| | if scores is not None and len(scores) > 0: |
| | best_idx = int(np.argmax(scores)) |
| | else: |
| | |
| | areas = [np.sum(m > 0.5) for m in masks] |
| | best_idx = int(np.argmax(areas)) |
| | |
| | mask = masks[best_idx] |
| | |
| | |
| | if mask.dtype in (np.float32, np.float64): |
| | mask = (mask > 0.5).astype(np.uint8) * 255 |
| | elif mask.dtype != np.uint8: |
| | mask = mask.astype(np.uint8) |
| | |
| | |
| | if mask.ndim == 3: |
| | mask = mask[:, :, 0] if mask.shape[2] > 1 else mask.squeeze() |
| | |
| | |
| | _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY) |
| | |
| | |
| | assert mask.ndim == 2, f"Output mask must be 2D, got shape {mask.shape}" |
| | |
| | return mask |
| |
|
| | |
| | |
| | |
| |
|
def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    High-quality person segmentation. Tries SAM-2 with smart prompts first,
    then a classical CV cascade, then a geometric fallback.

    Parameters
    ----------
    image : np.ndarray
        Input frame (BGR, per the cv2 usage downstream).
    predictor : Any
        SAM2-style predictor exposing ``set_image``/``predict``; may be None.
    fallback_enabled : bool
        Forwarded to the legacy path; the enhanced path always falls back.

    Returns
    -------
    np.ndarray
        uint8 mask (0/255). Never raises past input validation.

    Raises
    ------
    SegmentationError
        Only for an invalid (None/empty) input image.
    """
    # Feature flag: route to the legacy implementation when disabled.
    if not USE_ENHANCED_SEGMENTATION:
        return segment_person_hq_original(image, predictor, fallback_enabled)

    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    # 1) SAM2 path with intelligent (or basic) point prompts.
    if predictor and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
        try:
            predictor.set_image(image)
            mask = (
                _segment_with_intelligent_prompts(image, predictor)
                if USE_INTELLIGENT_PROMPTING
                else _segment_with_basic_prompts(image, predictor)
            )
            if USE_ITERATIVE_REFINEMENT:
                mask = _auto_refine_mask_iteratively(image, mask, predictor)
            if _validate_mask_quality(mask, image.shape[:2]):
                return mask
            # Fixed mojibake in the log text (was a garbled arrow character).
            log.warning("SAM2 mask failed validation -> fallback")
        except Exception as e:
            # Lazy %-formatting: the string is only built if the record is emitted.
            log.warning("SAM2 path failed: %s", e)

    # 2) Classical CV cascade.
    try:
        mask = _classical_segmentation_cascade(image)
        if _validate_mask_quality(mask, image.shape[:2]):
            return mask
        log.warning("Classical cascade weak -> geometric fallback")
    except Exception as e:
        log.debug("Classical cascade error: %s", e)

    # 3) Last resort: centered geometric ellipse.
    return _geometric_person_mask(image)
| |
|
| |
|
def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    Very first implementation kept for rollback. Fewer smarts, still robust.

    Parameters
    ----------
    image : np.ndarray
        Input frame.
    predictor : Any
        SAM2-style predictor exposing ``set_image``/``predict``; may be None.
    fallback_enabled : bool
        When False, a SAM2 failure raises instead of running the cascade.

    Raises
    ------
    SegmentationError
        If the input image is invalid, or if SAM2 produced no valid mask and
        ``fallback_enabled`` is False.
    """
    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    try:
        if predictor and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
            h, w = image.shape[:2]
            predictor.set_image(image)

            # Fixed grid of positive prompts: head, torso, legs, left, right.
            points = np.array([
                [w // 2, h // 4],
                [w // 2, h // 2],
                [w // 2, 3 * h // 4],
                [w // 3, h // 2],
                [2 * w // 3, h // 2],
            ], dtype=np.float32)
            labels = np.ones(len(points), np.int32)

            with torch.no_grad():
                masks, scores, _ = predictor.predict(
                    point_coords=points,
                    point_labels=labels,
                    multimask_output=True,
                )

            if masks is not None and len(masks):
                mask = _sam2_to_matanyone_mask(masks, scores)
                if _validate_mask_quality(mask, image.shape[:2]):
                    return mask
    except Exception as e:
        # NOTE(fix): the fallback decision used to live inside this try-block,
        # so the "fallback disabled" RuntimeError was swallowed by this very
        # handler and the cascade ran anyway. The decision now happens below,
        # after the handler, so fallback_enabled=False is actually honored.
        log.warning("segment_person_hq_original error: %s", e)

    if fallback_enabled:
        return _classical_segmentation_cascade(image)
    raise SegmentationError("SAM2 failed and fallback disabled")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """Run SAM2 using auto-generated, saliency-based point prompts."""
    positive, negative = _generate_smart_prompts(image)
    return _sam2_predict(image, predictor, positive, negative)
| |
|
| |
|
def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """Run SAM2 with a fixed layout: three positive points down the vertical
    centerline and the four near-corner points as negatives."""
    h, w = image.shape[:2]
    center_x = w // 2
    positive = np.array(
        [[center_x, h // 3], [center_x, h // 2], [center_x, 2 * h // 3]],
        np.float32,
    )
    negative = np.array(
        [[10, 10], [w - 10, 10], [10, h - 10], [w - 10, h - 10]],
        np.float32,
    )
    return _sam2_predict(image, predictor, positive, negative)
| |
|
| |
|
def _sam2_predict(image: np.ndarray, predictor: Any,
                  pos_points: np.ndarray, neg_points: np.ndarray) -> np.ndarray:
    """Call the SAM2 predictor with labelled point prompts and reduce the
    multi-mask output to a single binary mask."""
    # Guarantee at least one positive prompt (fall back to the image center).
    if pos_points.size == 0:
        h, w = image.shape[:2]
        pos_points = np.array([[w // 2, h // 2]], np.float32)
    coords = np.vstack([pos_points, neg_points])
    # Label 1 = foreground point, 0 = background point.
    point_labels = np.concatenate(
        [np.ones(len(pos_points)), np.zeros(len(neg_points))]
    ).astype(np.int32)
    with torch.no_grad():
        masks, scores, _ = predictor.predict(
            point_coords=coords,
            point_labels=point_labels,
            multimask_output=True,
        )
    return _sam2_to_matanyone_mask(masks, scores)
| |
|
| |
|
def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Auto-place SAM2 point prompts from a saliency map.

    Positive points are the centroids of up to three largest salient blobs
    (image center when saliency is unavailable or empty); negative points pin
    the four near-corner pixels as background.
    """
    h, w = image.shape[:2]
    positives: list = []
    saliency = _compute_saliency(image)
    if saliency is not None:
        # Slightly relax the threshold so blobs are not over-fragmented.
        blob_mask = ((saliency > (SALIENCY_THRESH - .1)) * 255).astype(np.uint8)
        contours, _ = cv2.findContours(
            blob_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        largest_three = sorted(contours, key=cv2.contourArea, reverse=True)[:3]
        for contour in largest_three:
            moments = cv2.moments(contour)
            if moments["m00"]:
                cx = int(moments["m10"] / moments["m00"])
                cy = int(moments["m01"] / moments["m00"])
                positives.append([cx, cy])
    if not positives:
        positives = [[w // 2, h // 2]]
    negatives = [[10, 10], [w - 10, 10], [10, h - 10], [w - 10, h - 10]]
    return np.asarray(positives, np.float32), np.asarray(negatives, np.float32)
| |
|
| | |
| | |
| | |
| |
|
def _classical_segmentation_cascade(image: np.ndarray) -> np.ndarray:
    """
    Classical (non-ML) segmentation pipeline, cheapest-first:
    edge-median background subtraction -> saliency flood-fill -> GrabCut.
    Falls back to a geometric ellipse when every stage fails validation.
    """
    frame_hw = image.shape[:2]
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Stage 1: treat the border pixels as background; keep everything that
    # differs from their median by more than 30 gray levels, then close gaps.
    border = np.concatenate([gray[0], gray[-1], gray[:, 0], gray[:, -1]])
    delta = np.abs(gray.astype(float) - np.median(border))
    mask = np.where(delta > 30, 255, 0).astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    if _validate_mask_quality(mask, frame_hw):
        return mask

    # Stage 2: saliency-guided flood fill seeded from the current mask.
    mask = _refine_with_saliency(image, mask)
    if _validate_mask_quality(mask, frame_hw):
        return mask

    # Stage 3: GrabCut refinement seeded with the current mask.
    mask = _refine_with_grabcut(image, mask)
    if _validate_mask_quality(mask, frame_hw):
        return mask

    return _geometric_person_mask(image)
| |
|
| | |
| |
|
| | def _compute_saliency(image: np.ndarray) -> Optional[np.ndarray]: |
| | try: |
| | if hasattr(cv2, "saliency"): |
| | s = cv2.saliency.StaticSaliencySpectralResidual_create() |
| | ok, smap = s.computeSaliency(image) |
| | if ok: |
| | smap = (smap - smap.min()) / max(1e-6, smap.max()-smap.min()) |
| | return smap |
| | except Exception: |
| | pass |
| | return None |
| |
|
def _auto_person_rect(image: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
    """Bounding rect (x, y, w, h) of the dominant salient blob, padded by 5%
    of the frame size on each side; None when saliency gives no contour."""
    saliency = _compute_saliency(image)
    if saliency is None:
        return None
    blob = (saliency > SALIENCY_THRESH).astype(np.uint8)
    contours, _ = cv2.findContours(blob * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None
    x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
    frame_h, frame_w = image.shape[:2]
    pad = 0.05
    # Expand by the pad on every side, clamped to the frame bounds.
    x = max(0, int(x - frame_w * pad))
    y = max(0, int(y - frame_h * pad))
    w = min(frame_w - x, int(w * (1 + 2 * pad)))
    h = min(frame_h - y, int(h * (1 + 2 * pad)))
    return x, y, w, h
| |
|
def _refine_with_grabcut(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    """Refine a seed mask with GrabCut: seed pixels above 200 become definite
    foreground, everything else starts as probable background."""
    h, w = image.shape[:2]
    grab = np.full((h, w), cv2.GC_PR_BGD, np.uint8)
    grab[seed > 200] = cv2.GC_FGD
    # Saliency-derived rect, else a centered portrait-ish default.
    rect = _auto_person_rect(image)
    if rect is None:
        rect = (w // 4, h // 6, w // 2, int(h * 0.7))
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    cv2.grabCut(image, grab, rect, bgd_model, fgd_model, GRABCUT_ITERS, cv2.GC_INIT_WITH_MASK)
    foreground = (grab == cv2.GC_FGD) | (grab == cv2.GC_PR_FGD)
    return np.where(foreground, 255, 0).astype(np.uint8)
| |
|
def _refine_with_saliency(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    """Flood-fill the thresholded saliency map starting from the seed-mask
    centroid; returns the seed unchanged when saliency is unavailable."""
    saliency = _compute_saliency(image)
    if saliency is None:
        return seed
    high = (saliency > SALIENCY_THRESH).astype(np.uint8) * 255
    ys, xs = np.where(seed > 127)
    # Centroid of the seed, defaulting to the frame center for an empty seed.
    cy = int(np.mean(ys)) if len(ys) else image.shape[0] // 2
    cx = int(np.mean(xs)) if len(xs) else image.shape[1] // 2
    filled = high.copy()
    # floodFill takes the seed point as (x, y).
    cv2.floodFill(filled, None, (cx, cy), 255, loDiff=5, upDiff=5)
    return filled
| |
|
| | |
| | |
| | |
| |
|
def _validate_mask_quality(mask: np.ndarray, shape: Tuple[int, int]) -> bool:
    """True when the foreground (pixels > 127) covers a plausible fraction of
    the frame, i.e. within [MIN_AREA_RATIO, MAX_AREA_RATIO]."""
    h, w = shape
    coverage = np.count_nonzero(mask > 127) / (h * w)
    return MIN_AREA_RATIO <= coverage <= MAX_AREA_RATIO
| |
|
| | def _process_mask(mask: np.ndarray) -> np.ndarray: |
| | """Legacy mask processor - kept for compatibility but mostly replaced by _sam2_to_matanyone_mask""" |
| | if mask.dtype in (np.float32, np.float64): |
| | if mask.max() <= 1.0: |
| | mask = (mask*255).astype(np.uint8) |
| | if mask.dtype != np.uint8: |
| | mask = mask.astype(np.uint8) |
| | if mask.ndim == 3: |
| | mask = mask.squeeze() |
| | if mask.ndim == 3: |
| | mask = mask[:,:,0] |
| | _,mask = cv2.threshold(mask,127,255,cv2.THRESH_BINARY) |
| | return mask |
| |
|
def _geometric_person_mask(image: np.ndarray) -> np.ndarray:
    """Last-resort mask: a filled ellipse centered in the frame, roughly where
    a person stands in a typical portrait composition."""
    h, w = image.shape[:2]
    fallback = np.zeros((h, w), np.uint8)
    center = (w // 2, h // 2)
    axes = (w // 3, int(h / 2.5))
    cv2.ellipse(fallback, center, axes, 0, 0, 360, 255, -1)
    return fallback
| |
|
| | |
| | |
| | |
| |
|
| | def _auto_refine_mask_iteratively(image, mask, predictor, max_iterations=1): |
| | |
| | return mask |