File size: 2,782 Bytes
9875bf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Visualization utilities for satellite change detection."""

import numpy as np
import cv2
from typing import Tuple, Optional


def create_overlay(
    image: np.ndarray,
    mask: np.ndarray,
    alpha: float = 0.5,
    color: Tuple[int, int, int] = (255, 0, 0)
) -> np.ndarray:
    """
    Overlay a binary mask on an image with transparency.

    Pixels where ``mask`` is nonzero are blended towards ``color``; all other
    pixels keep their original value.

    Args:
        image: Base image (H, W, 3), or grayscale (H, W) / (H, W, 1).
            uint8 is used as-is; any other dtype is clipped to [0, 1] and
            scaled by 255, so non-uint8 input is expected to be float in [0, 1].
        mask: Binary mask (H, W) or (H, W, 1); any nonzero value is treated
            as part of the mask.
        alpha: Overlay transparency (0 = invisible, 1 = opaque)
        color: RGB color for the overlay

    Returns:
        Blended image as uint8 (H, W, 3)
    """
    # Convert image to uint8
    if image.dtype != np.uint8:
        base = (np.clip(image, 0, 1) * 255).astype(np.uint8)
    else:
        base = image.copy()

    # Drop a trailing singleton channel axis so (H, W, 1) input takes the same
    # grayscale path as (H, W); otherwise assigning a 3-value colour below
    # fails to broadcast against a 1-channel image.
    if base.ndim == 3 and base.shape[2] == 1:
        base = base.reshape(base.shape[:2])

    # Ensure 3 channels
    if base.ndim == 2:
        base = cv2.cvtColor(base, cv2.COLOR_GRAY2RGB)

    mask_bool = mask.astype(bool)
    # A (H, W, 1) mask would be interpreted as a 3-D boolean index and fail
    # against the (H, W, 3) image; reduce it to (H, W).
    if mask_bool.ndim == 3 and mask_bool.shape[2] == 1:
        mask_bool = mask_bool.reshape(mask_bool.shape[:2])

    # Paint the masked pixels; unmasked pixels are identical in both inputs to
    # addWeighted, so they come out unchanged after blending.
    overlay = base.copy()
    overlay[mask_bool] = color

    # Blend: result = alpha * overlay + (1 - alpha) * base, saturated to uint8
    result = cv2.addWeighted(overlay, alpha, base, 1 - alpha, 0)
    return result.astype(np.uint8)


def visualize_predictions(
    image: np.ndarray,
    pred_mask: np.ndarray,
    gt_mask: Optional[np.ndarray] = None,
    confidence: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Create a side-by-side visualization of image, prediction, and (optionally) ground truth.

    Panels are laid out left to right: the original image, the prediction
    overlay (red), the ground-truth overlay (green, only if ``gt_mask`` is
    given) and a JET-colormapped confidence heatmap (only if ``confidence``
    is given).

    Args:
        image: Original image (H, W, 3), or grayscale (H, W) / (H, W, 1).
            uint8 is used as-is; any other dtype is clipped to [0, 1] and
            scaled by 255.
        pred_mask: Predicted binary mask (H, W) or (H, W, 1); nonzero values
            count as positive. A mask at a different resolution is resized to
            (H, W) with nearest-neighbour interpolation.
        gt_mask: Optional ground truth mask, same conventions as ``pred_mask``
        confidence: Optional confidence map (clipped to [0, 1]); a map at a
            different resolution is resized to (H, W)

    Returns:
        Combined visualization as uint8 RGB (H, W*N, 3), where N is the number
        of panels (2 to 4).
    """
    if image.dtype != np.uint8:
        img_u8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)
    else:
        img_u8 = image.copy()

    # Promote grayscale input to RGB. The overlay and heatmap panels are always
    # 3-channel, so leaving the raw-image panel 2-D makes np.concatenate fail.
    if img_u8.ndim == 3 and img_u8.shape[2] == 1:
        img_u8 = img_u8.reshape(img_u8.shape[:2])
    if img_u8.ndim == 2:
        img_u8 = cv2.cvtColor(img_u8, cv2.COLOR_GRAY2RGB)

    h, w = img_u8.shape[:2]

    def _fit_mask(mask: np.ndarray) -> np.ndarray:
        """Return ``mask`` at (h, w), nearest-neighbour resizing it if needed."""
        if mask.ndim == 3 and mask.shape[2] == 1:
            mask = mask.reshape(mask.shape[:2])
        if mask.shape[:2] == (h, w):
            return mask
        # cv2.resize rejects bool arrays, so binarize to uint8 first with the
        # same "nonzero means positive" rule as a bool cast; nearest-neighbour
        # interpolation keeps the resized mask strictly binary.
        binary = (mask != 0).astype(np.uint8)
        return cv2.resize(binary, (w, h), interpolation=cv2.INTER_NEAREST)

    panels = [img_u8]

    # Prediction overlay (red)
    pred_overlay = create_overlay(
        img_u8, _fit_mask(pred_mask), alpha=0.5, color=(255, 0, 0)
    )
    panels.append(pred_overlay)

    # Ground truth overlay (green)
    if gt_mask is not None:
        gt_overlay = create_overlay(
            img_u8, _fit_mask(gt_mask), alpha=0.5, color=(0, 255, 0)
        )
        panels.append(gt_overlay)

    # Confidence heatmap; applyColorMap yields BGR, converted to RGB to match
    # the other panels.
    if confidence is not None:
        heatmap = (np.clip(confidence, 0, 1) * 255).astype(np.uint8)
        heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
        panels.append(heatmap_color)

    # Bring every panel to the full (h, w) image size so they can be joined
    # horizontally. The image and overlay panels already match; this mainly
    # rescales a confidence map produced at a different resolution.
    panels_resized = [
        cv2.resize(p, (w, h), interpolation=cv2.INTER_LINEAR)
        if p.shape[:2] != (h, w) else p
        for p in panels
    ]

    return np.concatenate(panels_resized, axis=1)