"""Visualization utilities for satellite change detection.""" import numpy as np import cv2 from typing import Tuple, Optional def create_overlay( image: np.ndarray, mask: np.ndarray, alpha: float = 0.5, color: Tuple[int, int, int] = (255, 0, 0) ) -> np.ndarray: """ Overlay a binary mask on an image with transparency. Args: image: Base image (H, W, 3), uint8 or float [0,1] mask: Binary mask (H, W), values 0 or 1 alpha: Overlay transparency (0 = invisible, 1 = opaque) color: RGB color for the overlay Returns: Blended image as uint8 (H, W, 3) """ # Convert image to uint8 if image.dtype != np.uint8: base = (np.clip(image, 0, 1) * 255).astype(np.uint8) else: base = image.copy() # Ensure 3 channels if base.ndim == 2: base = cv2.cvtColor(base, cv2.COLOR_GRAY2RGB) overlay = base.copy() mask_bool = mask.astype(bool) # Apply colour to masked region overlay[mask_bool] = [color[0], color[1], color[2]] # Blend result = cv2.addWeighted(overlay, alpha, base, 1 - alpha, 0) return result.astype(np.uint8) def visualize_predictions( image: np.ndarray, pred_mask: np.ndarray, gt_mask: Optional[np.ndarray] = None, confidence: Optional[np.ndarray] = None ) -> np.ndarray: """ Create a side-by-side visualization of image, prediction, and (optionally) ground truth. Args: image: Original image (H, W, 3) pred_mask: Predicted binary mask (H, W) gt_mask: Optional ground truth mask (H, W) confidence: Optional confidence map (H, W) Returns: Combined visualization as uint8 (H, W*N, 3) """ if image.dtype != np.uint8: img_u8 = (np.clip(image, 0, 1) * 255).astype(np.uint8) else: img_u8 = image.copy() h, w = img_u8.shape[:2] panels = [img_u8] # Prediction overlay (red) pred_overlay = create_overlay(img_u8, pred_mask, alpha=0.5, color=(255, 0, 0)) panels.append(pred_overlay) # Ground truth overlay (green) if gt_mask is not None: gt_overlay = create_overlay(img_u8, gt_mask, alpha=0.5, color=(0, 255, 0)) panels.append(gt_overlay) # Confidence heatmap if confidence is not None: heatmap = (np.clip(confidence, 0, 1) * 255).astype(np.uint8) heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB) panels.append(heatmap_color) # Resize all panels to same height panels_resized = [ cv2.resize(p, (w, h), interpolation=cv2.INTER_LINEAR) if p.shape[:2] != (h, w) else p for p in panels ] return np.concatenate(panels_resized, axis=1)