import logging
import sys
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import cm
from PIL import Image, ImageDraw, ImageOps
from skimage.measure import label, regionprops  # From scikit-image
from skimage.transform import resize
from torchvision import transforms

# Configure logging
logger = logging.getLogger(__name__)

# Size used for error placeholders when no source image size is available.
_FALLBACK_SIZE = (256, 256)


def _make_error_image(size, messages):
    """Render a black RGB placeholder image carrying diagnostic text.

    Args:
        size: ``(width, height)`` of the placeholder.
        messages: Iterable of ``(position, text, fill)`` tuples; each one is
            drawn with ``ImageDraw.text``.

    Returns:
        A new ``PIL.Image.Image`` in RGB mode.
    """
    error_img = Image.new("RGB", tuple(size), color="black")
    draw = ImageDraw.Draw(error_img)
    for position, text, fill in messages:
        draw.text(position, text, fill=fill)
    return error_img


class GradCAMSegmentation:
    """Grad-CAM for a model, computed at a single named layer.

    Forward and backward hooks on the target layer capture its activations
    and the gradient of ``output.mean()`` with respect to them. Calling the
    instance returns a min-max normalised class-activation map.

    The hooks stay attached to the model until ``remove_hooks()`` is called,
    so callers creating short-lived instances should call it when done.
    """

    def __init__(self, model, target_layer_name):
        """Resolve the target layer and attach the capture hooks.

        Args:
            model: Module whose sub-layer is inspected.
            target_layer_name: Dotted path to the layer, e.g.
                ``"block.0.conv"``; purely numeric components are used as
                indices.

        Raises:
            ValueError: If the path cannot be resolved on ``model``.
        """
        self.model = model
        self.target_layer = self._find_layer(target_layer_name)
        self.activations = None
        self.gradients = None
        self._hook_handles = []
        self._register_hooks()

    def _find_layer(self, target_layer_name):
        """Walk the dotted ``target_layer_name`` path down from the model.

        Raises:
            ValueError: If any path component is missing or not indexable.
        """
        module = self.model
        for attr in target_layer_name.split('.'):
            try:
                if attr.isdigit():
                    module = module[int(attr)]
                else:
                    module = getattr(module, attr)
            except (AttributeError, IndexError, KeyError, TypeError) as e:
                # TypeError/KeyError cover indexing into modules that are not
                # subscriptable or that are keyed by something other than int.
                raise ValueError(
                    f"Could not find layer {target_layer_name}: {str(e)}"
                ) from e
        return module

    def _register_hooks(self):
        """Attach the activation and gradient capture hooks to the layer."""

        def forward_hook(module, input, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        # NOTE(review): register_backward_hook is deprecated in favour of
        # register_full_backward_hook. It is kept here so that target layers
        # whose output is later modified in-place keep working - confirm
        # before migrating.
        self._hook_handles.append(
            self.target_layer.register_forward_hook(forward_hook)
        )
        self._hook_handles.append(
            self.target_layer.register_backward_hook(backward_hook)
        )

    def remove_hooks(self):
        """Detach every hook this instance registered (safe to call twice)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def __call__(self, input_tensor):
        """Compute the normalised Grad-CAM map for ``input_tensor``.

        Args:
            input_tensor: Batch fed straight to the model. The target layer's
                output must be 4-D, because gradients are averaged over
                dimensions 2 and 3.

        Returns:
            ``numpy.ndarray`` with values in ``[0, 1]``. The channel axis is
            dropped, and the batch axis too when the batch size is 1. The
            array is all zeros when the map is constant.

        Raises:
            RuntimeError: If the hooks captured no activations or gradients.
        """
        # Clear captures from a previous call so the check below cannot be
        # satisfied by stale data.
        self.activations = None
        self.gradients = None

        # Grad-CAM needs autograd even if the caller is inside no_grad().
        with torch.enable_grad():
            self.model.zero_grad()
            output = self.model(input_tensor)
            output.mean().backward()  # More robust gradient calculation

        if self.gradients is None or self.activations is None:
            raise RuntimeError("Gradients not captured - check hook registration")

        # Channel weights are the spatially averaged gradients.
        weights = torch.mean(self.gradients, dim=[2, 3], keepdim=True)
        cam = torch.sum(self.activations * weights, dim=1, keepdim=True)
        cam = torch.relu(cam)

        # Better normalization
        cam_min, cam_max = torch.min(cam), torch.max(cam)
        if cam_max - cam_min > 1e-8:
            cam = (cam - cam_min) / (cam_max - cam_min)
        else:
            cam = torch.zeros_like(cam)

        # Drop only the channel and batch axes; a bare squeeze() would also
        # collapse spatial dimensions of size 1.
        return cam.squeeze(1).squeeze(0).cpu().numpy()


def preprocess_image_pil(image):
    """Convert a PIL image to a ``(1, 1, 256, 256)`` float tensor.

    The image is converted to grayscale (mode ``"L"``) if needed, resized to
    256x256 and turned into a tensor with a leading batch dimension.

    Args:
        image: Source ``PIL.Image.Image``.

    Returns:
        ``(tensor, image)`` where ``image`` is the grayscale PIL image at its
        original size, or ``(None, None)`` if any step fails. The failure is
        logged, not raised.
    """
    try:
        # Convert to grayscale
        if image.mode != 'L':
            image = image.convert('L')

        # Validate original size
        if not all(dim > 0 for dim in image.size):
            raise ValueError(f"Invalid image dimensions: {image.size}")

        # Transform pipeline
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])

        tensor = transform(image).unsqueeze(0)  # Add batch dimension
        logger.debug("Output tensor shape: %s", tensor.shape)

        if tensor.shape != (1, 1, 256, 256):
            raise ValueError(f"Invalid output shape: {tensor.shape}")

        return tensor, image
    except Exception as e:
        # Deliberate best-effort: callers get (None, None) instead of a raise.
        logger.error(f"Preprocessing failed: {e}")
        return None, None


def postprocess_mask(mask_tensor, original_size):
    """Threshold a model output at 0.5 and return it as a binary PIL mask.

    Args:
        mask_tensor: ``torch.Tensor`` that squeezes to a 2-D array.
        original_size: Target ``(width, height)`` for the returned mask.

    Returns:
        Mode ``"L"`` ``PIL.Image.Image`` containing only 0 and 255, sized to
        ``original_size``.

    Raises:
        ValueError: If ``mask_tensor`` is not a ``torch.Tensor``.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        if not isinstance(mask_tensor, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor")

        mask_np = mask_tensor.squeeze().cpu().numpy()
        logger.info(
            "Mask range before threshold: %.2f-%.2f",
            mask_np.min(),
            mask_np.max(),
        )

        # Apply threshold and convert to binary mask
        binary_mask = (mask_np > 0.5).astype(np.uint8) * 255
        mask_img = Image.fromarray(binary_mask).convert("L")

        # Resize to original dimensions. Compare against the mask's real size
        # rather than an assumed 256x256, and use NEAREST so the result stays
        # strictly binary (interpolating filters introduce gray levels).
        target_size = tuple(original_size)
        if mask_img.size != target_size:
            mask_img = mask_img.resize(target_size, resample=Image.NEAREST)
            logger.info("Resized mask to original size: %s", original_size)

        return mask_img
    except Exception as e:
        logger.error(f"Postprocessing failed: {str(e)}", exc_info=True)
        raise


def overlay_mask(original, mask):
    """Blend a red rendering of ``mask`` over ``original``.

    Args:
        original: Base ``PIL.Image.Image`` (converted to RGB if needed).
        mask: ``PIL.Image.Image`` mask; white areas are drawn in red. It is
            resized to ``original``'s size when the sizes differ.

    Returns:
        RGB ``PIL.Image.Image`` with the overlay. On failure the error is
        logged and a black placeholder image with the error text is returned.
    """
    try:
        if not all(isinstance(img, Image.Image) for img in [original, mask]):
            raise ValueError("Both inputs must be PIL Images")

        # Convert original to RGB if needed
        if original.mode != 'RGB':
            original = original.convert("RGB")

        # Create red mask overlay
        mask_gray = mask.convert("L")
        if mask_gray.size != original.size:
            # Image.blend requires both images to have the same size.
            mask_gray = mask_gray.resize(original.size, resample=Image.NEAREST)
        red_mask = ImageOps.colorize(
            mask_gray, black="black", white="red"
        ).convert("RGBA")

        # Blend with original
        overlayed = Image.blend(
            original.convert("RGBA"),
            red_mask,
            alpha=0.4,  # Adjust transparency
        )
        return overlayed.convert("RGB")
    except Exception as e:
        logger.error(f"Overlay creation failed: {str(e)}", exc_info=True)
        size = original.size if isinstance(original, Image.Image) else _FALLBACK_SIZE
        return _make_error_image(
            size, [((10, 10), f"Overlay Error: {str(e)}", "white")]
        )


def generate_gradcam(model, input_tensor, original_size,
                     target_layer_name="output_block.conv.conv"):
    """Compute a Grad-CAM heatmap for ``model`` as a colour PIL image.

    Args:
        model: Module to inspect.
        input_tensor: ``torch.Tensor`` batch fed to the model.
        original_size: ``(width, height)`` of the returned image.
        target_layer_name: Dotted path of the layer to inspect.

    Returns:
        RGB heatmap (blue for low, yellow for mid, red for high activation)
        resized to ``original_size``. On failure the error is logged and a
        black placeholder image with the error text is returned.
    """
    gradcam = None
    try:
        # Verify input
        if not isinstance(input_tensor, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor")

        # Generate CAM
        gradcam = GradCAMSegmentation(model, target_layer_name)
        cam = gradcam(input_tensor)

        # Convert to PIL Image
        cam_uint8 = np.uint8(255 * cam)
        heatmap = Image.fromarray(cam_uint8).convert('L')

        # Apply color mapping
        heatmap_color = ImageOps.colorize(
            heatmap,
            black='blue',
            white='red',
            mid='yellow',
        ).resize(original_size)

        return heatmap_color
    except Exception as e:
        logger.error("Grad-CAM generation failed: %s", e, exc_info=True)
        # Create error image with detailed message
        return _make_error_image(
            original_size,
            [
                ((10, 10), "GRAD-CAM ERROR", "red"),
                ((10, 40), str(e)[:50], "white"),
            ],
        )
    finally:
        # Without this every call would leave two more hooks on the model.
        if gradcam is not None:
            gradcam.remove_hooks()


def overlay_gradcam(original, gradcam):
    """Blend a Grad-CAM heatmap 50/50 over the original image.

    Args:
        original: Base ``PIL.Image.Image``.
        gradcam: Heatmap ``PIL.Image.Image``; resized to ``original``'s size
            when the sizes differ.

    Returns:
        RGB ``PIL.Image.Image`` with the blend. On failure the error is
        logged and a black placeholder image is returned.
    """
    try:
        original = original.convert("RGB")
        gradcam = gradcam.convert("RGB")
        if gradcam.size != original.size:
            # Image.blend requires both images to have the same size.
            gradcam = gradcam.resize(original.size)

        return Image.blend(original, gradcam, alpha=0.5)
    except Exception as e:
        logger.error("Grad-CAM overlay failed: %s", e, exc_info=True)
        size = original.size if isinstance(original, Image.Image) else _FALLBACK_SIZE
        return _make_error_image(size, [((10, 10), "OVERLAY ERROR", "red")])