# SentinelWatch / utils/visualization.py
# Uploaded by VishaliniS456 ("Upload 8 files", commit 9875bf8, verified)
"""Visualization utilities for satellite change detection."""
import numpy as np
import cv2
from typing import Tuple, Optional
def create_overlay(
    image: np.ndarray,
    mask: np.ndarray,
    alpha: float = 0.5,
    color: Tuple[int, int, int] = (255, 0, 0)
) -> np.ndarray:
    """
    Overlay a binary mask on an image with transparency.

    Pixels where ``mask`` is non-zero are blended towards ``color``; all other
    pixels are returned unchanged. The input ``image`` is never modified.

    Args:
        image: Base image, uint8 or float in [0, 1] (float values outside that
            range are clipped). Accepted layouts: (H, W) grayscale,
            (H, W, 1), (H, W, 3) RGB, or (H, W, 4) RGBA (alpha is dropped).
        mask: Binary mask (H, W) or (H, W, 1); any non-zero value counts as
            inside the mask.
        alpha: Overlay opacity (0 = invisible, 1 = opaque).
        color: RGB color for the overlay.
    Returns:
        Blended image as uint8 (H, W, 3)
    Raises:
        ValueError: If the image layout is unsupported or the mask's spatial
            shape does not match the image's.
    """
    # Convert image to uint8; non-uint8 input is treated as float in [0, 1].
    if image.dtype != np.uint8:
        base = (np.clip(image, 0, 1) * 255).astype(np.uint8)
    else:
        base = image.copy()

    # Normalise to exactly 3 channels so the per-pixel colour broadcasts below.
    if base.ndim == 2:
        base = base[:, :, np.newaxis]
    if base.ndim != 3 or base.shape[2] not in (1, 3, 4):
        raise ValueError(
            "image must have shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4); "
            f"got {image.shape}"
        )
    if base.shape[2] == 1:
        base = np.repeat(base, 3, axis=2)
    elif base.shape[2] == 4:
        base = np.ascontiguousarray(base[:, :, :3])

    mask_bool = np.asarray(mask).astype(bool)
    # Accept a trailing singleton channel, e.g. masks saved as (H, W, 1).
    if mask_bool.ndim == 3 and mask_bool.shape[2] == 1:
        mask_bool = mask_bool[:, :, 0]
    if mask_bool.shape != base.shape[:2]:
        raise ValueError(
            f"mask shape {np.shape(mask)} does not match image spatial shape "
            f"{base.shape[:2]}"
        )

    # `base` is always a private copy at this point, so blend in place.
    # Only masked pixels change; blending colour with itself elsewhere is a no-op.
    # float32 arithmetic + round-half-even + saturation mirrors cv2.addWeighted.
    region = base[mask_bool].astype(np.float32)
    tint = np.asarray(color, dtype=np.float32)
    blended = np.float32(alpha) * tint + np.float32(1 - alpha) * region
    base[mask_bool] = np.clip(np.rint(blended), 0, 255).astype(np.uint8)
    return base
def visualize_predictions(
    image: np.ndarray,
    pred_mask: np.ndarray,
    gt_mask: Optional[np.ndarray] = None,
    confidence: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Create a side-by-side visualization of image, prediction, and (optionally) ground truth.

    Panels, left to right: the original image, the prediction overlaid in red,
    the ground truth overlaid in green (only if ``gt_mask`` is given), and a
    JET-colormap heatmap of ``confidence`` (only if it is given).

    Args:
        image: Original image, uint8 or float in [0, 1]. Accepted layouts:
            (H, W), (H, W, 1), (H, W, 3) RGB, or (H, W, 4) RGBA (alpha dropped).
        pred_mask: Predicted binary mask (H, W); non-zero pixels are highlighted.
        gt_mask: Optional ground truth mask (H, W)
        confidence: Optional confidence map; values are clipped to [0, 1] and
            the heatmap is resized to (H, W) if its shape differs.
    Returns:
        Combined visualization as uint8 (H, W*N, 3), where N is 2 to 4
        depending on which optional inputs are supplied.
    """
    if image.dtype != np.uint8:
        img_u8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)
    else:
        img_u8 = image.copy()

    # The overlay panels are always (H, W, 3); bring the raw panel to the same
    # layout, otherwise np.concatenate below fails for grayscale/RGBA input.
    if img_u8.ndim == 2:
        img_u8 = img_u8[:, :, np.newaxis]
    if img_u8.ndim == 3 and img_u8.shape[2] == 1:
        img_u8 = np.repeat(img_u8, 3, axis=2)
    elif img_u8.ndim == 3 and img_u8.shape[2] == 4:
        img_u8 = np.ascontiguousarray(img_u8[:, :, :3])

    h, w = img_u8.shape[:2]
    panels = [img_u8]

    # Prediction overlay (red)
    panels.append(create_overlay(img_u8, pred_mask, alpha=0.5, color=(255, 0, 0)))

    # Ground truth overlay (green)
    if gt_mask is not None:
        panels.append(create_overlay(img_u8, gt_mask, alpha=0.5, color=(0, 255, 0)))

    # Confidence heatmap; applyColorMap produces BGR, converted to RGB to
    # match the RGB colours used for the overlays.
    if confidence is not None:
        heatmap = (np.clip(confidence, 0, 1) * 255).astype(np.uint8)
        heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        panels.append(cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB))

    # Resize every panel to the full (H, W) of the image so they can be
    # concatenated horizontally. Note cv2.resize takes dsize as (width, height).
    panels_resized = [
        p if p.shape[:2] == (h, w)
        else cv2.resize(p, (w, h), interpolation=cv2.INTER_LINEAR)
        for p in panels
    ]
    return np.concatenate(panels_resized, axis=1)