| | """ |
| | GPU-Native Eye Image Processor for Color Fundus Photography (CFP) Images. |
| | |
| | This module implements a fully PyTorch-based image processor that: |
| | 1. Localizes the eye/fundus region using gradient-based radial symmetry |
| | 2. Crops to a border-minimized square centered on the eye |
| | 3. Applies CLAHE for contrast enhancement |
| | 4. Outputs tensors compatible with Hugging Face vision models |
| | |
| | Constraints: |
- PyTorch-only core pipeline (no OpenCV; PIL and NumPy are optional, used only for input conversion)
| | - CUDA-compatible, batch-friendly, deterministic |
| | """ |
| |
|
| | from typing import Dict, List, Optional, Union |
| | import math |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from transformers.image_processing_utils import BaseImageProcessor |
| | from transformers.image_processing_base import BatchFeature |
| |
|
| | |
| | try: |
| | from PIL import Image |
| | PIL_AVAILABLE = True |
| | except ImportError: |
| | PIL_AVAILABLE = False |
| |
|
| | try: |
| | import numpy as np |
| | NUMPY_AVAILABLE = True |
| | except ImportError: |
| | NUMPY_AVAILABLE = False |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _pil_to_tensor(image: "Image.Image") -> torch.Tensor:
    """Convert one PIL Image into a (C, H, W) float32 tensor scaled to [0, 1].

    Non-RGB images are first converted to RGB. When numpy is installed it is
    used as a fast intermediate; otherwise pixels are read via ``getdata``.
    """
    if not PIL_AVAILABLE:
        raise ImportError("PIL is required to process PIL Images")

    rgb = image.convert("RGB") if image.mode != "RGB" else image

    if NUMPY_AVAILABLE:
        data = np.array(rgb, dtype=np.float32) / 255.0
        return torch.from_numpy(data).permute(2, 0, 1)

    # Pure-Python fallback: getdata yields (R, G, B) tuples in row-major order.
    width, height = rgb.size
    flat = torch.tensor(list(rgb.getdata()), dtype=torch.float32)
    return (flat.view(height, width, 3) / 255.0).permute(2, 0, 1)
| |
|
| |
|
def _numpy_to_tensor(arr: "np.ndarray") -> torch.Tensor:
    """Convert one numpy array into a float32 (C, H, W) tensor.

    Accepts grayscale (H, W) or channel-last (H, W, C) with C in {1, 3, 4},
    and uint8/float dtypes. uint8 values are rescaled to [0, 1]. The returned
    tensor owns its memory (no sharing with ``arr``).
    """
    if not NUMPY_AVAILABLE:
        raise ImportError("NumPy is required to process numpy arrays")

    # Promote grayscale to a trailing singleton channel.
    if arr.ndim == 2:
        arr = arr[..., None]

    # Channel-last -> channel-first when the layout is recognisable.
    if arr.ndim == 3 and arr.shape[-1] in (1, 3, 4):
        arr = arr.transpose(2, 0, 1)

    # uint8 is rescaled; other non-float32 dtypes are cast without rescaling.
    if arr.dtype == np.uint8:
        arr = arr.astype(np.float32) / 255.0
    elif arr.dtype != np.float32:
        arr = arr.astype(np.float32)

    return torch.from_numpy(arr.copy())
| |
|
| |
|
def standardize_input(
    images: Union[torch.Tensor, List[torch.Tensor], "Image.Image", List["Image.Image"], "np.ndarray", List["np.ndarray"]],
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Normalise heterogeneous image input into a (B, C, H, W) float32 tensor in [0, 1].

    Accepts torch.Tensor, PIL.Image, numpy.ndarray, or lists of any of them.
    uint8 input is rescaled to [0, 1]; the result is clamped to [0, 1].

    Note: all images in a list must share spatial dimensions (``torch.stack``
    requirement). A single 3-D numpy array is treated as one HWC image when its
    last dimension is in {1, 3, 4}; otherwise it takes the tensor path (CHW).

    Args:
        images: One of:
            - torch.Tensor (C,H,W), (B,C,H,W), or list of tensors
            - PIL.Image.Image or list of PIL Images
            - numpy.ndarray (H,W,C), (B,H,W,C), or list of arrays
        device: Target device (defaults to input device or CPU).

    Returns:
        Float32 tensor of shape (B, C, H, W), values in [0, 1].
    """
    # Promote a single PIL image / single HWC numpy image to a one-element list.
    if PIL_AVAILABLE and isinstance(images, Image.Image):
        images = [images]
    if (
        NUMPY_AVAILABLE
        and isinstance(images, np.ndarray)
        and images.ndim == 3
        and images.shape[-1] in (1, 3, 4)
    ):
        images = [images]

    if isinstance(images, list):
        tensors = []
        for item in images:
            if PIL_AVAILABLE and isinstance(item, Image.Image):
                tensors.append(_pil_to_tensor(item))
            elif NUMPY_AVAILABLE and isinstance(item, np.ndarray):
                tensors.append(_numpy_to_tensor(item))
            elif isinstance(item, torch.Tensor):
                tensors.append(item if item.dim() == 3 else item.squeeze(0))
            else:
                raise TypeError(f"Unsupported image type: {type(item)}")
        images = torch.stack(tensors)
    elif NUMPY_AVAILABLE and isinstance(images, np.ndarray):
        # Batched numpy input: BHWC -> BCHW, then rescale uint8.
        if images.ndim == 4:
            images = images.transpose(0, 3, 1, 2)
        if images.dtype == np.uint8:
            images = images.astype(np.float32) / 255.0
        images = torch.from_numpy(images.copy())

    # Single CHW tensor -> add batch dimension.
    if images.dim() == 3:
        images = images.unsqueeze(0)

    if device is not None:
        images = images.to(device)

    # Normalise dtype: uint8 gets rescaled, everything else just cast.
    if images.dtype == torch.uint8:
        images = images.float() / 255.0
    elif images.dtype != torch.float32:
        images = images.float()

    return images.clamp(0.0, 1.0)
| |
|
def _pil_mask_to_tensor(mask: "Image.Image") -> torch.Tensor:
    """Read a PIL mask into a tensor without any value normalization.

    Prefers numpy when available. Otherwise falls back to ``getdata`` so PIL
    masks still work when numpy is not installed (the previous code called
    ``np.array`` unconditionally here and raised NameError in that case).
    """
    if NUMPY_AVAILABLE:
        return torch.from_numpy(np.array(mask))
    width, height = mask.size
    data = torch.tensor(list(mask.getdata()))
    # Single-band masks flatten to (H*W,); multi-band flatten to (H*W, C).
    return data.view(height, width) if data.dim() == 1 else data.view(height, width, -1)


def standardize_mask_input(
    masks: Union[
        torch.Tensor,
        List[torch.Tensor],
        "Image.Image",
        List["Image.Image"],
        "np.ndarray",
        List["np.ndarray"],
    ],
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Convert heterogeneous mask inputs to a standardized (B, 1, H, W) tensor.

    Unlike ``standardize_input``, this preserves the original dtype (typically
    integer label values) and does **not** normalize to [0, 1].

    Accepts torch.Tensor, PIL.Image, numpy.ndarray, or lists thereof.
    A single 2-D input is treated as (H, W) and expanded to (1, 1, H, W).

    Args:
        masks: Input masks in any supported format.
        device: Target device.

    Returns:
        Tensor of shape (B, 1, H, W) with original dtype preserved.

    Raises:
        TypeError: If a list element is not a supported mask type.
        ValueError: If the resulting tensor is not 2-D, 3-D, or 4-D.
    """
    # Promote a single PIL image / single 2-D numpy mask to a one-element list.
    if PIL_AVAILABLE and isinstance(masks, Image.Image):
        masks = [masks]
    if NUMPY_AVAILABLE and isinstance(masks, np.ndarray) and masks.ndim == 2:
        masks = [masks]

    if isinstance(masks, list):
        converted = []
        for m in masks:
            if PIL_AVAILABLE and isinstance(m, Image.Image):
                converted.append(_pil_mask_to_tensor(m))
            elif NUMPY_AVAILABLE and isinstance(m, np.ndarray):
                converted.append(torch.from_numpy(m))
            elif isinstance(m, torch.Tensor):
                converted.append(m)
            else:
                raise TypeError(f"Unsupported mask type: {type(m)}")
        masks = torch.stack(converted)
    elif NUMPY_AVAILABLE and isinstance(masks, np.ndarray):
        masks = torch.from_numpy(masks)

    # Normalise rank to (B, 1, H, W).
    if masks.dim() == 2:
        masks = masks.unsqueeze(0).unsqueeze(0)
    elif masks.dim() == 3:
        masks = masks.unsqueeze(1)
    elif masks.dim() == 4:
        pass
    else:
        raise ValueError(f"Invalid mask shape: {masks.shape}")

    if device is not None:
        masks = masks.to(device)

    return masks
| |
|
| |
|
def rgb_to_grayscale(images: torch.Tensor) -> torch.Tensor:
    """Collapse RGB to luma using ITU-R BT.601 weights: Y = 0.299R + 0.587G + 0.114B.

    Args:
        images: Tensor of shape (B, 3, H, W), any value range.

    Returns:
        Tensor of shape (B, 1, H, W) in the same value range as the input.
    """
    # new_tensor inherits the input's device and dtype automatically.
    coeffs = images.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
    return (images * coeffs).sum(dim=1, keepdim=True)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def create_sobel_kernels(device: torch.device, dtype: torch.dtype) -> tuple:
    """Build the pair of 3x3 Sobel kernels used for gradient estimation.

    Args:
        device: Device the kernels are allocated on.
        dtype: Element dtype of the kernels.

    Returns:
        Tuple ``(sobel_x, sobel_y)``, each shaped (1, 1, 3, 3), suitable for
        ``F.conv2d`` on single-channel input.
    """
    smooth = torch.tensor([1.0, 2.0, 1.0], device=device, dtype=dtype)
    diff = torch.tensor([-1.0, 0.0, 1.0], device=device, dtype=dtype)

    # Sobel is separable: [1,2,1] smoothing along one axis, central
    # difference [-1,0,1] along the other. Outer products rebuild the
    # classic 3x3 forms.
    sobel_x = (smooth.view(3, 1) * diff.view(1, 3)).view(1, 1, 3, 3)
    sobel_y = (diff.view(3, 1) * smooth.view(1, 3)).view(1, 1, 3, 3)

    return sobel_x, sobel_y
| |
|
| |
|
def compute_gradients(grayscale: torch.Tensor) -> tuple:
    """Apply 3x3 Sobel filtering to single-channel images.

    Uses zero-padded convolution (padding=1) so spatial size is preserved.

    Args:
        grayscale: Single-channel images of shape (B, 1, H, W).

    Returns:
        Tuple ``(grad_x, grad_y, grad_magnitude)``, each (B, 1, H, W), where
        the magnitude is ``sqrt(gx^2 + gy^2 + 1e-8)`` (the epsilon keeps the
        sqrt differentiable at zero gradient).
    """
    kernel_x, kernel_y = create_sobel_kernels(grayscale.device, grayscale.dtype)

    gx = F.conv2d(grayscale, kernel_x, padding=1)
    gy = F.conv2d(grayscale, kernel_y, padding=1)

    magnitude = torch.sqrt(gx.pow(2) + gy.pow(2) + 1e-8)

    return gx, gy, magnitude
| |
|
| |
|
def compute_radial_symmetry_response(
    grayscale: torch.Tensor,
    grad_x: torch.Tensor,
    grad_y: torch.Tensor,
    grad_magnitude: torch.Tensor,
) -> torch.Tensor:
    """Compute a radial-symmetry response map for circular-region detection.

    The algorithm:
    1. Estimates an initial center as the intensity-weighted center of mass of
       dark regions (squared inverse intensity).
    2. For each pixel, computes the dot product between the normalized gradient
       vector and the unit vector pointing toward the estimated center.
    3. Weights this alignment score by gradient magnitude and darkness.
    4. Smooths the response with a separable Gaussian whose sigma is
       proportional to the image size (kernel_size = max(H,W)//8, sigma = kernel_size/6).

    High response indicates pixels whose gradients point radially inward toward
    a dark center — characteristic of the fundus disc boundary.

    Args:
        grayscale: Grayscale images (B, 1, H, W) in [0, 1].
        grad_x: Horizontal gradient (B, 1, H, W).
        grad_y: Vertical gradient (B, 1, H, W).
        grad_magnitude: Gradient magnitude (B, 1, H, W).

    Returns:
        Smoothed radial symmetry response map (B, 1, H, W).
    """
    B, _, H, W = grayscale.shape
    device = grayscale.device
    dtype = grayscale.dtype

    # Per-pixel coordinate grids, broadcast across the batch.
    y_coords = torch.arange(H, device=device, dtype=dtype).view(1, 1, H, 1).expand(B, 1, H, W)
    x_coords = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, W).expand(B, 1, H, W)

    # Darkness weight: inverted intensity, squared to emphasise the darkest pixels.
    dark_weight = 1.0 - grayscale
    dark_weight = dark_weight ** 2

    # Epsilon guards against division by zero on an all-white image.
    weight_sum = dark_weight.sum(dim=(2, 3), keepdim=True) + 1e-8

    # Initial center estimate: darkness-weighted center of mass, per image.
    cx_init = (dark_weight * x_coords).sum(dim=(2, 3), keepdim=True) / weight_sum
    cy_init = (dark_weight * y_coords).sum(dim=(2, 3), keepdim=True) / weight_sum

    # Vector from each pixel toward the estimated center, plus its length.
    dx_to_center = cx_init - x_coords
    dy_to_center = cy_init - y_coords
    dist_to_center = torch.sqrt(dx_to_center ** 2 + dy_to_center ** 2 + 1e-8)

    # Unit direction toward the center.
    dx_norm = dx_to_center / dist_to_center
    dy_norm = dy_to_center / dist_to_center

    # Unit gradient direction (epsilon avoids 0/0 in flat regions).
    grad_norm = grad_magnitude + 1e-8
    gx_norm = grad_x / grad_norm
    gy_norm = grad_y / grad_norm

    # Cosine between gradient direction and direction-to-center:
    # +1 when the gradient points straight at the center.
    radial_alignment = gx_norm * dx_norm + gy_norm * dy_norm

    # Score = alignment, re-weighted by edge strength and darkness.
    response = radial_alignment * grad_magnitude * dark_weight

    # Gaussian kernel size scales with image size; forced odd and >= 5.
    kernel_size = max(H, W) // 8
    if kernel_size % 2 == 0:
        kernel_size += 1
    kernel_size = max(kernel_size, 5)

    # sigma chosen so ~3 sigma spans half the kernel.
    sigma = kernel_size / 6.0

    # Build a normalised 1-D Gaussian (separable smoothing).
    x = torch.arange(kernel_size, device=device, dtype=dtype) - kernel_size // 2
    gaussian_1d = torch.exp(-x ** 2 / (2 * sigma ** 2))
    gaussian_1d = gaussian_1d / gaussian_1d.sum()

    # Horizontal and vertical views of the same 1-D kernel.
    gaussian_1d_h = gaussian_1d.view(1, 1, 1, kernel_size)
    gaussian_1d_v = gaussian_1d.view(1, 1, kernel_size, 1)

    pad_h = kernel_size // 2
    pad_v = kernel_size // 2

    # Two-pass separable blur with reflect padding to keep spatial size.
    # NOTE(review): reflect padding requires pad < the padded dimension;
    # extremely elongated images could violate this — confirm input aspect ratios.
    response = F.pad(response, (pad_h, pad_h, 0, 0), mode='reflect')
    response = F.conv2d(response, gaussian_1d_h)
    response = F.pad(response, (0, 0, pad_v, pad_v), mode='reflect')
    response = F.conv2d(response, gaussian_1d_v)

    return response
| |
|
| |
|
def soft_argmax_2d(response: torch.Tensor, temperature: float = 0.1) -> tuple:
    """Locate the sub-pixel peak of a response map via softmax-weighted coordinates.

    The flattened response is divided by ``temperature`` and passed through a
    softmax; the resulting probabilities weight the (x, y) coordinate grids.
    Small temperatures behave like a hard argmax; large ones average broadly.

    Caution: very low temperatures (< 0.01) with large response magnitudes can
    overflow the softmax exponential.

    Args:
        response: Response map (B, 1, H, W).
        temperature: Softmax temperature. Default 0.1.

    Returns:
        Tuple ``(cx, cy)``, each of shape (B,), in pixel coordinates.
    """
    B, _, H, W = response.shape
    dev, dt = response.device, response.dtype

    # Flatten, sharpen, and convert to a spatial probability mass.
    probs = F.softmax(response.reshape(B, -1) / temperature, dim=1).view(B, 1, H, W)

    ys = torch.arange(H, device=dev, dtype=dt).view(1, 1, H, 1).expand(B, 1, H, W)
    xs = torch.arange(W, device=dev, dtype=dt).view(1, 1, 1, W).expand(B, 1, H, W)

    # Expected coordinate under the probability mass = soft peak location.
    cx = (probs * xs).sum(dim=(2, 3)).squeeze(-1)
    cy = (probs * ys).sum(dim=(2, 3)).squeeze(-1)

    return cx, cy
| |
|
| |
|
def estimate_eye_center(
    images: torch.Tensor,
    softmax_temperature: float = 0.1,
) -> tuple:
    """Locate the fundus/eye disc center in each image of a batch.

    Runs the full chain: RGB -> grayscale -> Sobel gradients ->
    radial-symmetry response -> soft argmax.

    Args:
        images: RGB images of shape (B, 3, H, W) in [0, 1].
        softmax_temperature: Passed through to the soft-argmax peak finder.
            Lower values (0.01-0.1) localize sharply; higher values (0.3-0.5)
            average broadly, which helps on noisy or low-contrast images.
            Default 0.1.

    Returns:
        Tuple ``(cx, cy)``, each of shape (B,), in pixel coordinates.
    """
    gray = rgb_to_grayscale(images)
    gx, gy, gmag = compute_gradients(gray)
    peak_map = compute_radial_symmetry_response(gray, gx, gy, gmag)
    return soft_argmax_2d(peak_map, temperature=softmax_temperature)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def estimate_radius(
    images: torch.Tensor,
    cx: torch.Tensor,
    cy: torch.Tensor,
    num_radii: int = 100,
    num_angles: int = 36,
    min_radius_frac: float = 0.1,
    max_radius_frac: float = 0.5,
) -> torch.Tensor:
    """Estimate the radius of the fundus disc by analyzing radial intensity profiles.

    Samples grayscale intensity along ``num_angles`` rays emanating from ``(cx, cy)``
    at ``num_radii`` radial distances. The per-radius mean intensity across all angles
    gives a 1-D radial profile. The discrete derivative of this profile is linearly
    weighted by radius (range 0.5–1.5) to bias toward the outer fundus boundary
    rather than the smaller pupil boundary. The radius at the strongest weighted
    negative gradient is selected as the disc edge.

    Uses ``F.grid_sample`` with bilinear interpolation and border padding for
    sub-pixel sampling.

    Args:
        images: RGB images (B, 3, H, W) in [0, 1].
        cx, cy: Center coordinates (B,) in pixel units.
        num_radii: Number of radial sample points. Default 100.
        num_angles: Number of angular sample rays. Default 36.
        min_radius_frac: Minimum search radius as fraction of min(H, W). Default 0.1.
        max_radius_frac: Maximum search radius as fraction of min(H, W). Default 0.5.

    Returns:
        Estimated radius for each image (B,), clamped to [min_radius, max_radius].
    """
    B, _, H, W = images.shape
    device = images.device
    dtype = images.dtype

    grayscale = rgb_to_grayscale(images)

    # Radial search band, as fractions of the shorter image side.
    min_dim = min(H, W)
    min_radius = int(min_radius_frac * min_dim)
    max_radius = int(max_radius_frac * min_dim)

    # Sample distances and ray directions; the 2*pi endpoint is dropped so the
    # 0-degree ray is not duplicated.
    radii = torch.linspace(min_radius, max_radius, num_radii, device=device, dtype=dtype)
    angles = torch.linspace(0, 2 * math.pi, num_angles + 1, device=device, dtype=dtype)[:-1]

    # Column vectors so the outer product below yields (num_angles, num_radii).
    cos_angles = torch.cos(angles).view(-1, 1)
    sin_angles = torch.sin(angles).view(-1, 1)

    # Offsets from center for every (angle, radius) pair.
    dx = cos_angles * radii
    dy = sin_angles * radii

    # Broadcast each image's center over the polar sample grid.
    cx_expanded = cx.view(B, 1, 1).expand(B, num_angles, num_radii)
    cy_expanded = cy.view(B, 1, 1).expand(B, num_angles, num_radii)

    sample_x = cx_expanded + dx.unsqueeze(0)
    sample_y = cy_expanded + dy.unsqueeze(0)

    # Pixel coordinates -> [-1, 1] normalized coordinates (align_corners=True).
    sample_x_norm = 2.0 * sample_x / (W - 1) - 1.0
    sample_y_norm = 2.0 * sample_y / (H - 1) - 1.0

    # grid_sample expects (x, y) order in the last dimension.
    grid = torch.stack([sample_x_norm, sample_y_norm], dim=-1)

    # (B, 1, num_angles, num_radii) intensities; border padding repeats edges
    # for rays that exit the image.
    sampled = F.grid_sample(
        grayscale, grid, mode='bilinear', padding_mode='border', align_corners=True
    )

    # Average over all rays at each radius -> (B, num_radii) profile.
    radial_profile = sampled.mean(dim=2).squeeze(1)

    # Discrete derivative along the radius axis.
    radial_gradient = radial_profile[:, 1:] - radial_profile[:, :-1]

    # Up-weight larger radii (0.5 -> 1.5) so the outer fundus boundary beats
    # the smaller pupil boundary.
    radius_weights = torch.linspace(0.5, 1.5, num_radii - 1, device=device, dtype=dtype)
    weighted_gradient = radial_gradient * radius_weights.unsqueeze(0)

    # Strongest bright->dark transition marks the disc edge.
    min_idx = weighted_gradient.argmin(dim=1)

    # +1 because gradient index i spans the step radii[i] -> radii[i+1].
    estimated_radius = radii[min_idx + 1]

    # Defensive clamp back into the search band.
    estimated_radius = estimated_radius.clamp(min_radius, max_radius)

    return estimated_radius
| |
|
| |
|
| | |
| | |
| | |
| |
|
def compute_crop_box(
    cx: torch.Tensor,
    cy: torch.Tensor,
    radius: torch.Tensor,
    H: int,
    W: int,
    scale_factor: float = 1.1,
    allow_overflow: bool = False,
) -> tuple:
    """Compute a square bounding box centered on the detected eye.

    The half-side length is ``radius * scale_factor``. With
    ``allow_overflow=False`` the box is clamped to the image, shrunk to the
    shorter clamped side to restore squareness, and re-centered — the result
    is square and lies within [0, W-1] x [0, H-1].

    With ``allow_overflow=True`` the raw (possibly out-of-bounds) box is
    returned unchanged; useful for images where the fundus disc is partially
    clipped (out-of-bounds regions are filled later during grid_sample).

    Args:
        cx, cy: Detected eye center coordinates (B,).
        radius: Estimated disc radius (B,).
        H, W: Spatial dimensions of the source images.
        scale_factor: Padding multiplier applied to ``radius``. Default 1.1.
        allow_overflow: Skip clamping / squareness enforcement. Default False.

    Returns:
        Tuple ``(x1, y1, x2, y2)``, each of shape (B,), in pixel coordinates.
    """
    half = radius * scale_factor

    x1, x2 = cx - half, cx + half
    y1, y2 = cy - half, cy + half

    if allow_overflow:
        # Caller asked for the raw box even when it leaves the image.
        return x1, y1, x2, y2

    # Clamp every edge into the image.
    x1 = x1.clamp(min=0)
    y1 = y1.clamp(min=0)
    x2 = x2.clamp(max=W - 1)
    y2 = y2.clamp(max=H - 1)

    # Clamping may have made the box rectangular; shrink to the shorter side.
    side = torch.minimum(x2 - x1, y2 - y1)

    # Re-center the square on the clamped box's midpoint.
    mid_x = (x1 + x2) / 2
    mid_y = (y1 + y2) / 2

    x1 = (mid_x - side / 2).clamp(min=0)
    y1 = (mid_y - side / 2).clamp(min=0)
    x2 = x1 + side
    y2 = y1 + side

    # Safety clamp; a no-op unless re-centering pushed past the far edge.
    x2 = x2.clamp(max=W - 1)
    y2 = y2.clamp(max=H - 1)

    return x1, y1, x2, y2
| |
|
| |
|
def batch_crop_and_resize(
    images: torch.Tensor,
    x1: torch.Tensor,
    y1: torch.Tensor,
    x2: torch.Tensor,
    y2: torch.Tensor,
    output_size: int,
    padding_mode: str = 'border',
) -> torch.Tensor:
    """Crop and resize images to a square via ``F.grid_sample`` (GPU-friendly).

    A regular unit grid over the output square is affinely mapped into the
    source rectangle [x1, x2] x [y1, y2], normalized to [-1, 1], and sampled
    with bilinear interpolation (``align_corners=True``).

    Crop coordinates may extend past the image; ``padding_mode`` controls how
    out-of-bounds pixels are filled.

    Args:
        images: Input images (B, C, H, W).
        x1, y1, x2, y2: Crop box corners (B,). May exceed [0, W-1] / [0, H-1].
        output_size: Side length of the square output.
        padding_mode: ``'border'`` (repeat edge, default) or ``'zeros'`` (black fill).

    Returns:
        Cropped and resized images (B, C, output_size, output_size).
    """
    B, _, H, W = images.shape

    # Unit grid over the output square, (x, y) in the last dimension.
    t = torch.linspace(0, 1, output_size, device=images.device, dtype=images.dtype)
    gy, gx = torch.meshgrid(t, t, indexing='ij')
    unit = torch.stack([gx, gy], dim=-1).unsqueeze(0).expand(B, -1, -1, -1)

    # Per-image box corners, shaped for broadcasting against the grid.
    x1b, y1b = x1.view(B, 1, 1, 1), y1.view(B, 1, 1, 1)
    x2b, y2b = x2.view(B, 1, 1, 1), y2.view(B, 1, 1, 1)

    # Affine map: unit grid -> source rectangle (pixel coordinates).
    src_x = x1b + unit[..., 0:1] * (x2b - x1b)
    src_y = y1b + unit[..., 1:2] * (y2b - y1b)

    # grid_sample wants [-1, 1] coordinates (align_corners=True convention).
    grid = torch.cat([
        2.0 * src_x / (W - 1) - 1.0,
        2.0 * src_y / (H - 1) - 1.0,
    ], dim=-1)

    return F.grid_sample(
        images, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True
    )
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
def batch_crop_and_resize_mask(
    masks: torch.Tensor,
    x1: torch.Tensor,
    y1: torch.Tensor,
    x2: torch.Tensor,
    y2: torch.Tensor,
    output_size: int,
    padding_mode: str = "zeros",
) -> torch.Tensor:
    """Crop and resize segmentation masks with nearest-neighbor sampling.

    Same spatial transform as ``batch_crop_and_resize`` but uses
    ``mode='nearest'`` so discrete label values are preserved. The result is
    rounded and cast to ``torch.long`` to guard against floating-point drift
    inside ``grid_sample``.

    Args:
        masks: Integer label masks (B, 1, H, W) — any dtype (converted to float internally).
        x1, y1, x2, y2: Crop box corners (B,). May exceed image bounds.
        output_size: Side length of the square output.
        padding_mode: ``'zeros'`` (background = 0, default) or ``'border'`` (repeat edge).

    Returns:
        Cropped and resized masks (B, 1, output_size, output_size) as ``torch.long``.
    """
    B = masks.shape[0]
    H, W = masks.shape[2], masks.shape[3]

    # grid_sample only supports floating-point input.
    labels_f = masks.float()

    # Unit grid over the output square, (x, y) in the last dimension.
    t = torch.linspace(0, 1, output_size, device=masks.device)
    gy, gx = torch.meshgrid(t, t, indexing="ij")
    unit = torch.stack([gx, gy], dim=-1).unsqueeze(0).expand(B, -1, -1, -1)

    x1b, y1b = x1.view(B, 1, 1, 1), y1.view(B, 1, 1, 1)
    x2b, y2b = x2.view(B, 1, 1, 1), y2.view(B, 1, 1, 1)

    # Affine map into the source rectangle, then normalize to [-1, 1].
    src_x = x1b + unit[..., 0:1] * (x2b - x1b)
    src_y = y1b + unit[..., 1:2] * (y2b - y1b)

    grid = torch.cat([
        2.0 * src_x / (W - 1) - 1.0,
        2.0 * src_y / (H - 1) - 1.0,
    ], dim=-1)

    sampled = F.grid_sample(
        labels_f,
        grid,
        mode="nearest",
        padding_mode=padding_mode,
        align_corners=True,
    )

    # round() before the cast protects exact label values.
    return sampled.round().long()
| |
|
| | |
| | |
| | |
| |
|
| | def _srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor: |
| | """Apply the sRGB electro-optical transfer function (EOTF) to convert sRGB to linear RGB. |
| | |
| | Uses the IEC 61966-2-1 piecewise formula with threshold 0.04045. |
| | """ |
| | threshold = 0.04045 |
| | linear = torch.where( |
| | rgb <= threshold, |
| | rgb / 12.92, |
| | ((rgb + 0.055) / 1.055) ** 2.4 |
| | ) |
| | return linear |
| |
|
| |
|
| | def _linear_to_srgb(linear: torch.Tensor) -> torch.Tensor: |
| | """Apply the inverse sRGB EOTF to convert linear RGB to sRGB. |
| | |
| | Uses the IEC 61966-2-1 piecewise formula with threshold 0.0031308. |
| | Input must be non-negative; negative values will produce NaN from the power function. |
| | """ |
| | threshold = 0.0031308 |
| | srgb = torch.where( |
| | linear <= threshold, |
| | linear * 12.92, |
| | 1.055 * (linear ** (1.0 / 2.4)) - 0.055 |
| | ) |
| | return srgb |
| |
|
| |
|
def rgb_to_lab(images: torch.Tensor) -> tuple:
    """Convert sRGB images to CIE LAB (D65 illuminant), with internal rescaling.

    Chain: sRGB -> linear RGB -> CIE XYZ -> CIE LAB. The raw LAB values are
    rescaled for internal convenience:

    - L in [0, 100]      -> L / 100          -> [0, 1]
    - a in ~[-128, 127]  -> a / 256 + 0.5    -> ~[0, 1]
    - b in ~[-128, 127]  -> b / 256 + 0.5    -> ~[0, 1]

    These normalised values are **not** standard LAB; invert with ``lab_to_rgb``.

    Args:
        images: RGB images (B, 3, H, W) in [0, 1] sRGB.

    Returns:
        Tuple ``(L, a, b_ch)``, each (B, 1, H, W):
        - L: Normalised luminance in [0, 1].
        - a: Normalised green-red chrominance, roughly [0, 1].
        - b_ch: Normalised blue-yellow chrominance, roughly [0, 1].
    """
    lin = _srgb_to_linear(images)

    r, g, b = lin[:, 0:1], lin[:, 1:2], lin[:, 2:3]

    # Linear RGB -> XYZ (sRGB/D65 matrix), normalised by the D65 white point.
    x = (0.4124564 * r + 0.3575761 * g + 0.1804375 * b) / 0.95047
    y = (0.2126729 * r + 0.7151522 * g + 0.0721750 * b) / 1.0
    z = (0.0193339 * r + 0.1191920 * g + 0.9503041 * b) / 1.08883

    delta = 6.0 / 29.0

    def _f(t):
        # CIE piecewise cube-root companding.
        return torch.where(
            t > delta ** 3,
            t ** (1.0 / 3.0),
            t / (3.0 * delta ** 2) + 4.0 / 29.0,
        )

    fx, fy, fz = _f(x), _f(y), _f(z)

    L = 116.0 * fy - 16.0
    a = 500.0 * (fx - fy)
    b_ch = 200.0 * (fy - fz)

    # Rescale into roughly [0, 1] for internal use (not standard LAB).
    return L / 100.0, a / 256.0 + 0.5, b_ch / 256.0 + 0.5
| |
|
| |
|
def lab_to_rgb(L: torch.Tensor, a: torch.Tensor, b_ch: torch.Tensor) -> torch.Tensor:
    """Convert normalised CIE LAB back to sRGB (inverse of ``rgb_to_lab``).

    Denormalisation: L*100, (a-0.5)*256, (b_ch-0.5)*256; then LAB -> XYZ ->
    linear RGB -> sRGB. Output is clamped to [0, 1].

    Args:
        L: Normalised luminance (B, 1, H, W) in [0, 1].
        a: Normalised green-red chrominance (B, 1, H, W), roughly [0, 1].
        b_ch: Normalised blue-yellow chrominance (B, 1, H, W), roughly [0, 1].

    Returns:
        sRGB images (B, 3, H, W) clamped to [0, 1].
    """
    # Undo the internal [0, 1] rescaling back to standard LAB ranges.
    L_std = L * 100.0
    a_std = (a - 0.5) * 256.0
    b_std = (b_ch - 0.5) * 256.0

    # LAB -> intermediate f-values.
    fy = (L_std + 16.0) / 116.0
    fx = fy + a_std / 500.0
    fz = fy - b_std / 200.0

    delta = 6.0 / 29.0

    def _f_inv(t):
        # Inverse of the CIE piecewise companding.
        return torch.where(
            t > delta,
            t ** 3,
            3.0 * (delta ** 2) * (t - 4.0 / 29.0),
        )

    # Back to XYZ using the D65 white point.
    x = 0.95047 * _f_inv(fx)
    y = 1.0 * _f_inv(fy)
    z = 1.08883 * _f_inv(fz)

    # XYZ -> linear RGB (inverse sRGB matrix).
    rgb_lin = torch.cat([
        3.2404542 * x - 1.5371385 * y - 0.4985314 * z,
        -0.9692660 * x + 1.8760108 * y + 0.0415560 * z,
        0.0556434 * x - 0.2040259 * y + 1.0572252 * z,
    ], dim=1)

    # Clamp before the EOTF: negatives would NaN in the power branch.
    srgb = _linear_to_srgb(rgb_lin.clamp(0.0, 1.0))

    return srgb.clamp(0.0, 1.0)
| |
|
| |
|
def compute_histogram(
    tensor: torch.Tensor,
    num_bins: int = 256,
) -> torch.Tensor:
    """Compute per-image histograms for a batch of single-channel images.

    Bins are uniformly spaced over [0, 1]. Each pixel is assigned to bin
    ``floor(value * (num_bins - 1))`` (clamped to a valid index), matching the
    binning used elsewhere in the CLAHE pipeline. Accumulation uses a single
    batched ``scatter_add_`` along dim 1 — the previous per-sample Python loop
    was unnecessary, since ``scatter_add_`` accepts (B, N) index/src tensors.

    Note: This function is used only by ``clahe_single_tile``.
    The vectorized CLAHE path (``apply_clahe_vectorized``) computes histograms
    inline for better GPU efficiency.

    Args:
        tensor: Input (B, 1, H, W) with values in [0, 1].
        num_bins: Number of histogram bins. Default 256.

    Returns:
        Histograms of shape (B, num_bins), dtype matching input.
    """
    B = tensor.shape[0]

    flat = tensor.reshape(B, -1)

    # Map each pixel value to its bin index.
    bin_indices = (flat * (num_bins - 1)).long().clamp(0, num_bins - 1)

    # One scatter_add_ covers the whole batch: row b of bin_indices scatters
    # into row b of the histogram.
    histograms = torch.zeros(B, num_bins, device=tensor.device, dtype=tensor.dtype)
    histograms.scatter_add_(1, bin_indices, torch.ones_like(flat))

    return histograms
| |
|
| |
|
def clahe_single_tile(
    tile: torch.Tensor,
    clip_limit: float,
    num_bins: int = 256,
) -> torch.Tensor:
    """Compute the clipped-and-redistributed CDF for a single CLAHE tile.

    Clips the histogram so no bin exceeds ``clip_limit * num_pixels / num_bins``,
    spreads the clipped excess uniformly across all bins, then computes and
    min-max normalises the CDF.

    Note: This function is not used by the main pipeline — see
    ``apply_clahe_vectorized``, which processes all tiles in a single pass.

    Args:
        tile: Single-channel tile images (B, 1, tile_h, tile_w) in [0, 1].
        clip_limit: Relative clip limit (higher = less contrast limiting).
        num_bins: Number of histogram bins. Default 256.

    Returns:
        Normalised CDF lookup table (B, num_bins) in [0, 1].
    """
    num_pixels = tile.shape[2] * tile.shape[3]

    hist = compute_histogram(tile, num_bins)

    # Contrast limiting: cap each bin and pool the removed mass.
    ceiling = clip_limit * num_pixels / num_bins
    overflow = (hist - ceiling).clamp(min=0).sum(dim=1, keepdim=True)

    # Redistribute the pooled excess uniformly across every bin.
    hist = hist.clamp(max=ceiling) + overflow / num_bins

    # The min-max normalised CDF is the tile's intensity mapping.
    cdf = hist.cumsum(dim=1)
    lo = cdf[:, 0:1]
    hi = cdf[:, -1:]
    return (cdf - lo) / (hi - lo + 1e-8)
| |
|
| |
|
def apply_clahe_vectorized(
    images: torch.Tensor,
    grid_size: int = 8,
    clip_limit: float = 2.0,
    num_bins: int = 256,
) -> torch.Tensor:
    """Fully-vectorized CLAHE (Contrast Limited Adaptive Histogram Equalisation).

    For RGB input, converts to CIE LAB, applies CLAHE to the L channel only,
    then converts back to sRGB. For single-channel input, operates directly.

    Algorithm:
        1. Pads the luminance channel to be divisible by ``grid_size`` (reflect padding).
        2. Reshapes into ``grid_size x grid_size`` non-overlapping tiles.
        3. Computes a histogram per tile via ``scatter_add_`` (fully batched, no loops).
        4. Clips each histogram at ``clip_limit * num_pixels / num_bins`` and
           redistributes excess counts uniformly across all bins.
        5. Computes the cumulative distribution function (CDF) per tile and
           min-max normalises it to [0, 1].
        6. Maps each output pixel to the four surrounding tile centres and
           bilinearly interpolates their CDF values for a smooth result.

    Args:
        images: Input images (B, C, H, W) in [0, 1]. C must be 1 or 3.
        grid_size: Tile grid resolution (tiles per axis). Default 8.
        clip_limit: Relative clip limit for histogram clipping. Default 2.0.
        num_bins: Number of histogram bins. Default 256.

    Returns:
        CLAHE-enhanced images (B, C, H, W) in [0, 1].
    """
    B, C, H, W = images.shape
    device = images.device
    dtype = images.dtype

    # Work on a single luminance channel: LAB L for RGB input, the raw
    # channel otherwise (a/b channels are carried through untouched).
    if C == 3:
        L, a, b_ch = rgb_to_lab(images)
    else:
        L = images.clone()
        a = b_ch = None

    # Pad so H and W divide evenly into grid_size tiles.
    pad_h = (grid_size - H % grid_size) % grid_size
    pad_w = (grid_size - W % grid_size) % grid_size

    # Reflect padding mirrors real content, so tile histograms stay realistic.
    if pad_h > 0 or pad_w > 0:
        L_padded = F.pad(L, (0, pad_w, 0, pad_h), mode='reflect')
    else:
        L_padded = L

    _, _, H_pad, W_pad = L_padded.shape
    tile_h = H_pad // grid_size
    tile_w = W_pad // grid_size

    # Rearrange into one "image" per tile: (B * grid_size^2, 1, tile_h, tile_w).
    L_tiles = L_padded.view(B, 1, grid_size, tile_h, grid_size, tile_w)
    L_tiles = L_tiles.permute(0, 2, 4, 1, 3, 5)
    L_tiles = L_tiles.reshape(B * grid_size * grid_size, 1, tile_h, tile_w)

    # Quantise every pixel to a histogram bin index.
    num_pixels = tile_h * tile_w
    flat = L_tiles.view(B * grid_size * grid_size, -1)
    bin_indices = (flat * (num_bins - 1)).long().clamp(0, num_bins - 1)

    # One histogram per tile, built with a single batched scatter_add_.
    histograms = torch.zeros(B * grid_size * grid_size, num_bins, device=device, dtype=dtype)
    histograms.scatter_add_(1, bin_indices, torch.ones_like(flat))

    # Contrast limiting: clip each bin, then redistribute the clipped
    # excess uniformly over all bins.
    clip_value = clip_limit * num_pixels / num_bins
    excess = (histograms - clip_value).clamp(min=0).sum(dim=1, keepdim=True)
    histograms = histograms.clamp(max=clip_value)
    histograms = histograms + excess / num_bins

    # Per-tile CDF, min-max normalised to [0, 1].
    cdfs = histograms.cumsum(dim=1)
    cdf_min = cdfs[:, 0:1]
    cdf_max = cdfs[:, -1:]
    # Epsilon guards against division by zero on flat (single-value) tiles.
    cdfs = (cdfs - cdf_min) / (cdf_max - cdf_min + 1e-8)

    # Back to (B, tiles_y, tiles_x, bins) for per-tile lookups below.
    cdfs = cdfs.view(B, grid_size, grid_size, num_bins)

    # Pixel coordinates in padded space.
    y_coords = torch.arange(H_pad, device=device, dtype=dtype)
    x_coords = torch.arange(W_pad, device=device, dtype=dtype)

    # Continuous tile-space coordinate of each pixel: tile centres land on
    # integer coordinates, so the fractional part is the bilinear weight.
    tile_y = (y_coords + 0.5) / tile_h - 0.5
    tile_x = (x_coords + 0.5) / tile_w - 0.5

    # Clamp just below the last tile index so the floor below always leaves
    # room for a +1 neighbour.
    tile_y = tile_y.clamp(0, grid_size - 1.001)
    tile_x = tile_x.clamp(0, grid_size - 1.001)

    # Integer indices of the four surrounding tiles. .long() floors here
    # because tile_y/tile_x are clamped non-negative.
    ty0 = tile_y.long().clamp(0, grid_size - 2)
    tx0 = tile_x.long().clamp(0, grid_size - 2)
    ty1 = (ty0 + 1).clamp(max=grid_size - 1)
    tx1 = (tx0 + 1).clamp(max=grid_size - 1)

    # Fractional bilinear weights, shaped for broadcasting.
    wy = (tile_y - ty0.float()).view(1, H_pad, 1, 1)
    wx = (tile_x - tx0.float()).view(1, 1, W_pad, 1)

    # Histogram bin of every padded pixel — the lookup key into each tile CDF.
    bin_idx = (L_padded * (num_bins - 1)).long().clamp(0, num_bins - 1)
    bin_idx = bin_idx.squeeze(1)



    # Broadcast index grids so advanced indexing below produces
    # (B, H_pad, W_pad) values in one shot.
    b_idx = torch.arange(B, device=device).view(B, 1, 1).expand(B, H_pad, W_pad)
    ty0_exp = ty0.view(1, H_pad, 1).expand(B, H_pad, W_pad)
    ty1_exp = ty1.view(1, H_pad, 1).expand(B, H_pad, W_pad)
    tx0_exp = tx0.view(1, 1, W_pad).expand(B, H_pad, W_pad)
    tx1_exp = tx1.view(1, 1, W_pad).expand(B, H_pad, W_pad)

    # CDF values of each pixel under its four neighbouring tiles.
    v00 = cdfs[b_idx, ty0_exp, tx0_exp, bin_idx]
    v01 = cdfs[b_idx, ty0_exp, tx1_exp, bin_idx]
    v10 = cdfs[b_idx, ty1_exp, tx0_exp, bin_idx]
    v11 = cdfs[b_idx, ty1_exp, tx1_exp, bin_idx]

    # Drop the trailing broadcast dim so weights match the (B, H, W) values.
    wy = wy.squeeze(-1)
    wx = wx.squeeze(-1)

    # Standard bilinear blend of the four tile CDF values.
    L_out = (1 - wy) * (1 - wx) * v00 + (1 - wy) * wx * v01 + wy * (1 - wx) * v10 + wy * wx * v11
    L_out = L_out.unsqueeze(1)

    # Remove the padding added at the top of the function.
    if pad_h > 0 or pad_w > 0:
        L_out = L_out[:, :, :H, :W]

    # Reassemble the colour image if we split off LAB channels earlier.
    if C == 3:
        output = lab_to_rgb(L_out, a, b_ch)
    else:
        output = L_out

    return output
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
# ImageNet per-channel statistics (RGB order) — the defaults used by
# normalize_images(mode='imagenet').
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
| |
|
| |
|
def resize_images(
    images: torch.Tensor,
    size: int,
    mode: str = 'bilinear',
    antialias: bool = True,
) -> torch.Tensor:
    """Resample a batch of images to a square ``size x size`` output.

    Thin wrapper over ``F.interpolate`` that only passes ``align_corners``
    and ``antialias`` for the interpolation modes that support them.

    Args:
        images: Input batch (B, C, H, W). Must be floating point for
            bilinear/bicubic modes.
        size: Target side length; output is always square.
        mode: Interpolation mode (``'bilinear'``, ``'bicubic'``,
            ``'nearest'``, ...). Default ``'bilinear'``.
        antialias: Apply antialiasing when downscaling with a
            bilinear/bicubic kernel. Ignored for other modes. Default True.

    Returns:
        Resized batch (B, C, size, size).
    """
    smooth = mode in ('bilinear', 'bicubic')
    return F.interpolate(
        images,
        size=(size, size),
        mode=mode,
        # align_corners/antialias are only valid for the smooth modes.
        align_corners=False if smooth else None,
        antialias=antialias and smooth,
    )
| |
|
| |
|
def normalize_images(
    images: torch.Tensor,
    mean: Optional[List[float]] = None,
    std: Optional[List[float]] = None,
    mode: str = 'imagenet',
) -> torch.Tensor:
    """Channel-wise normalisation: ``(image - mean) / std``.

    Args:
        images: Input images (B, C, H, W) in [0, 1].
        mean: Per-channel means (length C). Required when ``mode='custom'``.
        std: Per-channel stds (length C). Required when ``mode='custom'``.
        mode: ``'imagenet'`` (uses ImageNet stats), ``'none'`` (identity), or
            ``'custom'`` (uses caller-supplied mean/std). Default ``'imagenet'``.

    Returns:
        Normalised images (B, C, H, W). Range depends on mean/std.
    """
    # Identity mode short-circuits before any stats are resolved.
    if mode == 'none':
        return images

    # Resolve the per-channel statistics for the requested mode.
    if mode == 'imagenet':
        mean, std = IMAGENET_MEAN, IMAGENET_STD
    elif mode == 'custom':
        if mean is None or std is None:
            raise ValueError("Custom mode requires mean and std")
    else:
        raise ValueError(f"Unknown normalization mode: {mode}")

    # Shape to (1, C, 1, 1) so the stats broadcast over batch and space.
    mu = torch.tensor(mean, device=images.device, dtype=images.dtype).view(1, -1, 1, 1)
    sigma = torch.tensor(std, device=images.device, dtype=images.dtype).view(1, -1, 1, 1)

    return (images - mu) / sigma
| |
|
| |
|
| | |
| | |
| | |
| |
|
class EyeCLAHEImageProcessor(BaseImageProcessor):
    """GPU-native Hugging Face image processor for Colour Fundus Photography (CFP).

    Processing pipeline (all steps optional via constructor flags):

    1. **Eye localisation** (``do_crop=True``): detects the fundus disc centre via
       gradient-based radial symmetry (dark-region centre-of-mass → Sobel gradients →
       radial alignment score → Gaussian smoothing → soft argmax) and estimates the
       disc radius from the strongest negative radial intensity gradient.
    2. **Square crop & resize**: crops a square region around the detected disc
       (``radius * crop_scale_factor``), optionally allowing overflow beyond image
       bounds (``allow_overflow``), then resamples to ``size x size`` via bilinear
       ``grid_sample``. When ``do_crop=False``, the whole image is resized directly.
    3. **CLAHE** (``do_clahe=True``): applies Contrast Limited Adaptive Histogram
       Equalisation to the CIE LAB luminance channel, using a fully-vectorized
       tile-based implementation with bilinear CDF interpolation.
    4. **Normalisation**: channel-wise ``(image - mean) / std`` with configurable
       mode (ImageNet, custom, or none).

    The processor also returns per-image coordinate-mapping scalars (``scale_x/y``,
    ``offset_x/y``) so that predictions in processed-image space can be mapped back
    to original pixel coordinates.

    All operations are pure PyTorch — no OpenCV, PIL, or NumPy at runtime — and are
    CUDA-compatible and batch-friendly.
    """

    # Tensor names this processor produces for the downstream model's forward().
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        size: int = 224,
        crop_scale_factor: float = 1.1,
        clahe_grid_size: int = 8,
        clahe_clip_limit: float = 2.0,
        normalization_mode: str = "imagenet",
        custom_mean: Optional[List[float]] = None,
        custom_std: Optional[List[float]] = None,
        do_clahe: bool = True,
        do_crop: bool = True,
        min_radius_frac: float = 0.1,
        max_radius_frac: float = 0.5,
        allow_overflow: bool = False,
        softmax_temperature: float = 0.1,
        **kwargs,
    ):
        """
        Initialize the EyeCLAHEImageProcessor.

        Args:
            size: Output image size (square)
            crop_scale_factor: Scale factor for crop box (relative to detected radius)
            clahe_grid_size: Number of tiles for CLAHE
            clahe_clip_limit: Histogram clip limit for CLAHE
            normalization_mode: 'imagenet', 'none', or 'custom'
            custom_mean: Custom normalization mean (if mode='custom')
            custom_std: Custom normalization std (if mode='custom')
            do_clahe: Whether to apply CLAHE
            do_crop: Whether to perform eye-centered cropping
            min_radius_frac: Minimum radius as fraction of image size
            max_radius_frac: Maximum radius as fraction of image size
            allow_overflow: If True, allow crop box to extend beyond image bounds
                and fill missing regions with black. Useful for pre-cropped
                images where the fundus circle is partially cut off.
            softmax_temperature: Temperature for soft argmax in eye center detection.
                Lower values (0.01-0.1) give sharper peak detection, higher values
                (0.3-0.5) provide more averaging for noisy images. Default: 0.1.
        """
        super().__init__(**kwargs)

        # Output geometry and cropping behaviour.
        self.size = size
        self.crop_scale_factor = crop_scale_factor
        # CLAHE parameters, forwarded to apply_clahe_vectorized.
        self.clahe_grid_size = clahe_grid_size
        self.clahe_clip_limit = clahe_clip_limit
        # Normalisation configuration, forwarded to normalize_images.
        self.normalization_mode = normalization_mode
        self.custom_mean = custom_mean
        self.custom_std = custom_std
        # Pipeline step toggles.
        self.do_clahe = do_clahe
        self.do_crop = do_crop
        # Eye-localisation constraints and soft-argmax sharpness.
        self.min_radius_frac = min_radius_frac
        self.max_radius_frac = max_radius_frac
        self.allow_overflow = allow_overflow
        self.softmax_temperature = softmax_temperature

    def preprocess(
        self,
        images,
        masks=None,
        return_tensors: str = "pt",
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Run the full preprocessing pipeline on a batch of images.

        Accepts any combination of torch.Tensor, PIL.Image, or numpy.ndarray inputs
        (see ``standardize_input`` for format details). Optionally processes
        accompanying segmentation masks with matching spatial transforms.

        Args:
            images: Input images in any supported format.
            masks: Optional segmentation masks in any format accepted by
                ``standardize_mask_input``. Undergo the same crop/resize as images
                (nearest-neighbour interpolation, label-preserving). Returned as
                ``torch.long`` under the ``"mask"`` key (or ``None`` if not provided).
            return_tensors: Only ``"pt"`` is supported.
            device: Device for all tensor operations (e.g. ``"cuda:0"``).
                Defaults to the device of the input tensor, or CPU for PIL/numpy.
            **kwargs: Passed through to ``BaseImageProcessor``.

        Returns:
            ``BatchFeature`` with keys:

            - ``pixel_values`` (B, 3, size, size): Processed float32 images.
            - ``mask`` (B, 1, size, size) or ``None``: Processed long masks.
            - ``scale_x``, ``scale_y`` (B,): Per-image scale factors.
            - ``offset_x``, ``offset_y`` (B,): Per-image offsets.

        Coordinate mapping from processed → original pixel space::

            orig_x = offset_x + proc_x * scale_x
            orig_y = offset_y + proc_y * scale_y
        """
        if return_tensors != "pt":
            raise ValueError("Only 'pt' (PyTorch) tensors are supported")

        # Resolve target device: explicit argument > device of the input
        # tensor(s) > CPU fallback for PIL/numpy inputs.
        if device is not None:
            device = torch.device(device)
        elif isinstance(images, torch.Tensor):
            device = images.device
        elif isinstance(images, list) and len(images) > 0 and isinstance(images[0], torch.Tensor):
            device = images[0].device
        else:
            # PIL / numpy inputs carry no device information.
            device = torch.device('cpu')

        # Convert whatever came in to batched (B, C, H, W) tensors on `device`.
        images = standardize_input(images, device)
        if masks is not None:
            masks = standardize_mask_input(masks, device)

        B, C, H_orig, W_orig = images.shape

        if self.do_crop:
            # Locate the fundus disc centre per image.
            cx, cy = estimate_eye_center(images, softmax_temperature=self.softmax_temperature)

            # Estimate the disc radius, bounded by the configured fractions.
            radius = estimate_radius(
                images, cx, cy,
                min_radius_frac=self.min_radius_frac,
                max_radius_frac=self.max_radius_frac,
            )

            # Square crop box around the disc (radius * crop_scale_factor).
            x1, y1, x2, y2 = compute_crop_box(
                cx, cy, radius, H_orig, W_orig,
                scale_factor=self.crop_scale_factor,
                allow_overflow=self.allow_overflow,
            )

            # Processed→original mapping: the (size-1) denominator matches a
            # corner-aligned pixel-centre convention (pixel 0 → x1, pixel
            # size-1 → x2).
            scale_x = (x2 - x1) / (self.size - 1)
            scale_y = (y2 - y1) / (self.size - 1)
            offset_x = x1
            offset_y = y1

            # With overflow allowed, out-of-bounds samples become black;
            # otherwise edge pixels are replicated.
            padding_mode = 'zeros' if self.allow_overflow else 'border'
            images = batch_crop_and_resize(images, x1, y1, x2, y2, self.size, padding_mode=padding_mode)

            # Masks get the same spatial transform (label-preserving variant).
            if masks is not None:
                masks = batch_crop_and_resize_mask(
                    masks, x1, y1, x2, y2,
                    self.size,
                    padding_mode=padding_mode,
                )
        else:
            # No cropping: the mapping is a plain full-image resize, so the
            # offset is zero and the scale covers the whole original frame.
            scale_x = torch.full((B,), (W_orig - 1) / (self.size - 1), device=device, dtype=images.dtype)
            scale_y = torch.full((B,), (H_orig - 1) / (self.size - 1), device=device, dtype=images.dtype)
            offset_x = torch.zeros(B, device=device, dtype=images.dtype)
            offset_y = torch.zeros(B, device=device, dtype=images.dtype)
            images = resize_images(images, self.size)

            if masks is not None:
                # Nearest-neighbour + round keeps integer class labels intact.
                masks = resize_images(masks.float(), self.size, mode="nearest", antialias=False).round().long()

        # Contrast enhancement on the (cropped) images.
        if self.do_clahe:
            images = apply_clahe_vectorized(
                images,
                grid_size=self.clahe_grid_size,
                clip_limit=self.clahe_clip_limit,
            )

        # Final channel-wise normalisation (mean/std only used in 'custom' mode).
        images = normalize_images(
            images,
            mean=self.custom_mean,
            std=self.custom_std,
            mode=self.normalization_mode,
        )

        # Package outputs; the mask key is only present when masks were given.
        data = {
            "pixel_values": images,
            "scale_x": scale_x,
            "scale_y": scale_y,
            "offset_x": offset_x,
            "offset_y": offset_y,
        }
        if masks is not None:
            data["mask"] = masks
        return BatchFeature(data=data, tensor_type="pt")

    def __call__(
        self,
        images: Union[torch.Tensor, List[torch.Tensor]],
        **kwargs,
    ) -> BatchFeature:
        """Alias for ``preprocess`` — enables ``processor(images, ...)`` call syntax."""
        return self.preprocess(images, **kwargs)
| |
|
| |
|
| | |
# Module-level alias — presumably kept for backward compatibility with code
# importing the older name; confirm no external users before removing.
EyeGPUImageProcessor = EyeCLAHEImageProcessor
| |
|