| | """Metrics for S2F training and evaluation. |
| | |
| | Includes: MSE, MS-SSIM, Pixel Correlation (Pearson), Relative Magnitude Error (WFM), |
| | and evaluation helpers for notebooks and scripts. |
| | """ |
| | import os |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from skimage.metrics import structural_similarity as ssim |
| | from scipy.stats import pearsonr |
| | from tqdm import tqdm |
| | import matplotlib.pyplot as plt |
| |
|
| | try: |
| | from torchmetrics import MultiScaleStructuralSimilarityIndexMeasure |
| | from torchmetrics import MeanSquaredError |
| | HAS_TORCHMETRICS = True |
| | except ImportError: |
| | HAS_TORCHMETRICS = False |
| |
|
| |
|
def calculate_mse(y_true, y_pred):
    """Return the mean squared error between ``y_true`` and ``y_pred`` as a Python float.

    When ``y_true`` is a torch tensor the computation goes through ``F.mse_loss``
    (so ``y_pred`` must be a tensor as well); otherwise both inputs are converted
    with ``np.asarray`` and the MSE is computed in NumPy.
    """
    if not isinstance(y_true, torch.Tensor):
        residual = np.asarray(y_true) - np.asarray(y_pred)
        return float(np.mean(residual ** 2))
    return F.mse_loss(y_pred, y_true).item()
| |
|
| |
|
def calculate_psnr(y_true, y_pred, max_pixel_value=1.0):
    """Return the peak signal-to-noise ratio (dB) computed from ``calculate_mse``.

    A perfect reconstruction (MSE of exactly 0) yields ``float('inf')``.
    """
    error = calculate_mse(y_true, y_pred)
    if error != 0:
        return 20 * np.log10(max_pixel_value / np.sqrt(error))
    return float('inf')
| |
|
| |
|
def calculate_ssim_tensor(y_true, y_pred, data_range=1.0):
    """Return the mean single-scale SSIM over a batch (skimage ``structural_similarity``).

    Args:
        y_true, y_pred: torch tensors or array-likes shaped ``(N, C, H, W)`` or
            ``(N, H, W)``. For 4-D input only channel 0 of each sample is compared,
            whatever the number of channels.
        data_range: value range forwarded to ``structural_similarity``.

    Returns:
        Mean SSIM over the ``N`` samples, or NaN for an empty batch.
    """
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.detach().cpu().numpy()
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.detach().cpu().numpy()
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    batch_size = y_true.shape[0]
    if batch_size == 0:
        # np.mean([]) would also give NaN, but with a RuntimeWarning.
        return float('nan')

    if y_true.ndim == 4:
        # Score only the first channel of every sample.
        y_true, y_pred = y_true[:, 0], y_pred[:, 0]

    ssim_values = [ssim(y_true[i], y_pred[i], data_range=data_range) for i in range(batch_size)]
    return np.mean(ssim_values)
| |
|
| |
|
def calculate_pearson_correlation(y_true, y_pred):
    """Return the Pearson correlation between two inputs after flattening them.

    Args:
        y_true, y_pred: torch tensors (detached and moved to CPU, so tensors that
            require grad are accepted) or array-likes with the same number of elements.

    Returns:
        The correlation coefficient as a float. scipy's ``pearsonr`` returns NaN
        (with a warning) when either input is constant, and raises ValueError when
        fewer than two elements are given.
    """
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.detach().cpu().numpy()
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.detach().cpu().numpy()
    correlation, _ = pearsonr(np.ravel(y_true), np.ravel(y_pred))
    return float(correlation)
| |
|
| |
|
def calculate_individual_pixel_correlation(y_true, y_pred):
    """Pixel-wise Pearson correlation per sample in batch.

    Args:
        y_true, y_pred: torch tensors (detached and moved to CPU, so grad-tracking
            tensors are accepted) or array-likes whose first axis is the batch.

    Returns:
        List with one correlation coefficient (float) per sample; a sample whose
        ground truth or prediction is constant gives NaN (scipy ``pearsonr``).
    """
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.detach().cpu().numpy()
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.detach().cpu().numpy()
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    correlations = []
    for i in range(y_true.shape[0]):
        r, _ = pearsonr(y_true[i].ravel(), y_pred[i].ravel())
        correlations.append(float(r))
    return correlations
| |
|
| |
|
| | |
| |
|
| | def _to_numpy_wfm(x): |
| | if isinstance(x, torch.Tensor): |
| | return x.detach().cpu().numpy() |
| | return np.asarray(x) |
| |
|
| |
|
| | def _ensure_shape_wfm(f): |
| | """Ensure (N, 2, H, W). Heatmap -> fx=magnitude, fy=0.""" |
| | if f.ndim == 3: |
| | if f.shape[-1] == 2: |
| | f = np.transpose(f, (2, 0, 1))[None, ...] |
| | elif f.shape[0] == 2: |
| | f = f[None, ...] |
| | else: |
| | raise ValueError(f"Unsupported 3D shape {f.shape}") |
| | elif f.ndim == 4: |
| | if f.shape[-1] == 2: |
| | f = np.transpose(f, (0, 3, 1, 2)) |
| | else: |
| | raise ValueError(f"Unsupported ndim={f.ndim}") |
| | return f |
| |
|
| |
|
| | def _force_mag_wfm(f): |
| | fx, fy = f[:, 0], f[:, 1] |
| | return np.sqrt(fx**2 + fy**2) |
| |
|
| |
|
def wfm_correlation(y_true, y_pred, mode="magnitude"):
    """Pearson correlation of force magnitudes between ground truth and prediction.

    Both inputs are converted to NumPy and normalised to ``(N, 2, H, W)``. Only
    ``mode="magnitude"`` is supported. Returns 0.0 when either magnitude map is
    (numerically) constant, since the correlation is undefined there.

    Raises:
        ValueError: on a shape mismatch or an unknown ``mode``.
    """
    truth = _ensure_shape_wfm(_to_numpy_wfm(y_true))
    guess = _ensure_shape_wfm(_to_numpy_wfm(y_pred))
    if truth.shape != guess.shape:
        raise ValueError(f"Shape mismatch: true {truth.shape} vs pred {guess.shape}")
    if mode != "magnitude":
        raise ValueError(f"Unknown mode '{mode}'")

    truth_vec = _force_mag_wfm(truth).ravel().astype(np.float64)
    guess_vec = _force_mag_wfm(guess).ravel().astype(np.float64)

    degenerate = np.allclose(truth_vec.std(), 0) or np.allclose(guess_vec.std(), 0)
    if degenerate:
        return 0.0
    return float(np.corrcoef(truth_vec, guess_vec)[0, 1])
| |
|
| |
|
def wfm_relative_magnitude_error(y_true, y_pred, eps=1e-8):
    """Magnitude-weighted relative error between predicted and true force magnitudes.

    Per pixel, ``|mag_pred - mag_true| / (mag_true + eps)`` is weighted by
    ``mag_true / mean(mag_true)`` and the weighted values are averaged. Returns
    0.0 when the mean true magnitude is (close to) zero.

    Raises:
        ValueError: when the normalised shapes of the inputs differ.
    """
    truth = _ensure_shape_wfm(_to_numpy_wfm(y_true))
    guess = _ensure_shape_wfm(_to_numpy_wfm(y_pred))
    if truth.shape != guess.shape:
        raise ValueError(f"Shape mismatch: true {truth.shape} vs pred {guess.shape}")

    mag_true = _force_mag_wfm(truth)
    mag_pred = _force_mag_wfm(guess)
    mean_true = np.mean(mag_true)
    if np.isclose(mean_true, 0):
        return 0.0

    per_pixel = np.abs(mag_pred - mag_true) / (mag_true + eps)
    weights = mag_true / mean_true
    return float(np.mean(per_pixel * weights))
| |
|
| |
|
def apply_threshold_mask(tensor, threshold=0.0):
    """Zero out entries of ``tensor`` that are below ``threshold``.

    Implemented as multiplication by a float32 mask, so entries that are NaN
    (or -inf below the threshold) become NaN rather than 0.
    """
    keep = tensor.ge(threshold).float()
    return tensor * keep
| |
|
| |
|
def detect_tanh_output_model(model):
    """Detect if model outputs [-1, 1] (Tanh).

    Heuristics, checked in order:
      * ``model.use_sigmoid`` exists and is falsy -> True
      * ``model.use_tanh_output`` exists and is truthy -> True
      * ``model.final_conv`` is ``nn.Tanh``, or an ``nn.Sequential`` whose last
        layer is ``nn.Tanh`` -> True

    Returns:
        bool; False when none of the heuristics match (including an empty
        ``nn.Sequential`` head, which previously raised IndexError).
    """
    if hasattr(model, 'use_sigmoid') and not model.use_sigmoid:
        return True
    if hasattr(model, 'use_tanh_output') and model.use_tanh_output:
        return True
    final = getattr(model, 'final_conv', None)
    if isinstance(final, nn.Sequential):
        # An empty Sequential has no last layer; final[-1] would raise IndexError.
        return len(final) > 0 and isinstance(final[-1], nn.Tanh)
    return isinstance(final, nn.Tanh)
| |
|
| |
|
def convert_tanh_to_sigmoid_range(tensor):
    """Affinely map values from [-1, 1] (Tanh range) to [0, 1] (Sigmoid range)."""
    shifted = tensor + 1.0
    return shifted / 2.0
| |
|
| |
|
| | |
| |
|
class TorchMetricsWrapper:
    """MS-SSIM / MSE helpers backed by torchmetrics, with a fallback.

    When torchmetrics could not be imported (``HAS_TORCHMETRICS`` is False),
    ``compute_ms_ssim`` falls back to *single-scale* SSIM via
    ``calculate_ssim_tensor``, so the reported value is not MS-SSIM in that case.
    ``compute_mse`` falls back to ``calculate_mse``.
    """

    def __init__(self, device='cpu'):
        self.device = device
        self.reset_metrics()

    def reset_metrics(self):
        """(Re)create the metric objects on ``self.device`` (None without torchmetrics)."""
        if HAS_TORCHMETRICS:
            self.ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0).to(self.device)
            self.mse = MeanSquaredError().to(self.device)
        else:
            self.ms_ssim = None
            self.mse = None

    def compute_ms_ssim(self, y_true, y_pred):
        """Return MS-SSIM for this batch as a float, scoring only channel 0.

        Assumes ``(N, C, H, W)`` tensors when torchmetrics is available.
        """
        if not HAS_TORCHMETRICS:
            return float(calculate_ssim_tensor(y_true, y_pred))
        y_true = y_true.to(self.device)
        y_pred = y_pred.to(self.device)
        if y_true.shape[1] != 1:
            # Multi-channel input: keep only the first channel of both tensors.
            y_true, y_pred = y_true[:, 0:1], y_pred[:, 0:1]
        # Calling the metric object returns the value for this batch.
        return self.ms_ssim(y_pred, y_true).item()

    def compute_mse(self, y_true, y_pred):
        """Return the mean squared error for this batch as a float."""
        if not HAS_TORCHMETRICS:
            return calculate_mse(y_true, y_pred)
        y_true = y_true.to(self.device)
        y_pred = y_pred.to(self.device)
        return self.mse(y_pred, y_true).item()
| |
|
| |
|
| | |
| |
|
def _heatmap_as_force_field(heatmap_batch):
    """Embed a ``(B, C, H, W)`` heatmap batch as a ``(B, 2, H, W)`` field: fx = channel 0, fy = 0."""
    batch, _, height, width = heatmap_batch.shape
    field = torch.zeros(batch, 2, height, width, device=heatmap_batch.device)
    field[:, 0] = heatmap_batch[:, 0]
    return field


def _safe_wfm_metric(metric_fn, gt_field, pred_field):
    """Best-effort WFM metric: return NaN instead of aborting the whole evaluation."""
    try:
        return metric_fn(gt_field, pred_field)
    except Exception:  # deliberate: one bad batch must not stop the evaluation
        return float('nan')


def _corr_or_zero(value):
    """Map a missing or NaN correlation coefficient to 0.0; otherwise return it as a float."""
    if value is None:
        return 0.0
    value = float(value)
    return 0.0 if np.isnan(value) else value


def evaluate_metrics_on_dataset(generator, data_loader, device=None, description="Evaluating",
                                save_predictions=False, threshold=0.0, use_settings=False,
                                normalization_params=None, config_path=None, substrate_override=None):
    """
    Evaluate S2F generator on a dataset. Returns MSE, MS-SSIM, Pixel Correlation,
    Relative Magnitude Error, and force sum/mean correlations.

    Args:
        generator: model run on each batch of images (under ``torch.no_grad()``).
            Outputs detected as Tanh-range are mapped to [0, 1].
        data_loader: iterable of batches ``(images, heatmaps, _, _)`` or
            ``(images, heatmaps, _, _, metadata)``; the third/fourth items are ignored.
        device: torch device; defaults to MPS if available, else CUDA, else CPU.
        description: tqdm progress-bar label.
        save_predictions: also return per-sample arrays and metrics under
            ``results['individual_predictions']``.
        threshold: ground-truth heatmap values below this are zeroed before scoring.
        use_settings: append settings channels from ``create_settings_channels`` to
            the input (only when ``normalization_params`` is not None).
        normalization_params, config_path: forwarded to ``create_settings_channels``.
        substrate_override: substrate used for settings channels when a batch has
            no metadata (default ``'fibroblasts_PDMS'``).

    Returns:
        dict with 'heatmap', 'wfm', 'force_sum' and 'force_mean' sections, plus
        'individual_predictions' when ``save_predictions`` is True. Heatmap MSE and
        MS-SSIM are averaged over batches; pixel correlation over samples.
    """
    if device is None:
        device = torch.device('mps' if torch.backends.mps.is_available() else
                              'cuda' if torch.cuda.is_available() else 'cpu')

    generator = generator.to(device)
    generator.eval()
    metrics_wrapper = TorchMetricsWrapper(device=device)

    # Loop invariants hoisted out of the batch loop: the model's output range
    # and the settings-channel helper do not change between batches.
    uses_tanh = detect_tanh_output_model(generator)
    add_settings = use_settings and normalization_params is not None
    if add_settings:
        from models.s2f_model import create_settings_channels

    heatmap_mse = []
    heatmap_ms_ssim = []
    heatmap_pixel_corr = []
    wfm_corr_mag = []
    wfm_rel_mag_err = []
    force_sum_gt, force_sum_pred = [], []
    force_mean_gt, force_mean_pred = [], []
    individual_predictions = [] if save_predictions else None

    with torch.no_grad():
        for batch_idx, batch_data in enumerate(tqdm(data_loader, desc=description)):
            if len(batch_data) == 5:
                images, heatmaps, _, _, metadata = batch_data
                has_metadata = True
            else:
                images, heatmaps, _, _ = batch_data
                has_metadata = False

            images = images.to(device, dtype=torch.float32)
            heatmaps = heatmaps.to(device, dtype=torch.float32)

            if add_settings:
                # NOTE(review): substrate_override is only consulted when the batch
                # carries no metadata; confirm that metadata should take precedence.
                if has_metadata:
                    meta = metadata
                else:
                    meta = {'substrate': [substrate_override or 'fibroblasts_PDMS'] * images.size(0)}
                settings_ch = create_settings_channels(meta, normalization_params, device,
                                                       images.shape, config_path=config_path)
                images = torch.cat([images, settings_ch], dim=1)

            pred = generator(images)
            if uses_tanh:
                pred = convert_tanh_to_sigmoid_range(pred)

            gt_thresh = apply_threshold_mask(heatmaps, threshold)
            # NOTE(review): only the ground truth is thresholded; the prediction is
            # scored unmodified despite the name. Confirm this is intended.
            pred_thresh = pred

            heatmap_mse.append(metrics_wrapper.compute_mse(gt_thresh, pred_thresh))
            heatmap_ms_ssim.append(metrics_wrapper.compute_ms_ssim(gt_thresh, pred_thresh))
            heatmap_pixel_corr.extend(calculate_individual_pixel_correlation(gt_thresh, pred_thresh))

            # WFM metrics treat channel 0 of each heatmap as the force magnitude.
            gt_ff = _heatmap_as_force_field(gt_thresh)
            pred_ff = _heatmap_as_force_field(pred_thresh)
            wfm_corr_mag.append(_safe_wfm_metric(wfm_correlation, gt_ff, pred_ff))
            wfm_rel_mag_err.append(_safe_wfm_metric(wfm_relative_magnitude_error, gt_ff, pred_ff))

            force_sum_gt.extend(torch.sum(gt_thresh, dim=[1, 2, 3]).cpu().numpy().tolist())
            force_sum_pred.extend(torch.sum(pred_thresh, dim=[1, 2, 3]).cpu().numpy().tolist())
            force_mean_gt.extend(torch.mean(gt_thresh, dim=[1, 2, 3]).cpu().numpy().tolist())
            force_mean_pred.extend(torch.mean(pred_thresh, dim=[1, 2, 3]).cpu().numpy().tolist())

            if save_predictions:
                for i in range(images.size(0)):
                    p, t = pred_thresh[i:i + 1], gt_thresh[i:i + 1]
                    rme = _safe_wfm_metric(wfm_relative_magnitude_error,
                                           _heatmap_as_force_field(t),
                                           _heatmap_as_force_field(p))
                    individual_predictions.append({
                        'batch_idx': batch_idx,
                        'sample_idx': i,
                        'original_image': images[i].cpu().numpy(),
                        'ground_truth': heatmaps[i].cpu().numpy(),
                        'ground_truth_thresholded': gt_thresh[i].cpu().numpy(),
                        'prediction': pred[i].cpu().numpy(),
                        'prediction_thresholded': pred_thresh[i].cpu().numpy(),
                        'mse': metrics_wrapper.compute_mse(t, p),
                        'ms_ssim': metrics_wrapper.compute_ms_ssim(t, p),
                        'pixel_correlation': calculate_pearson_correlation(t, p),
                        'wfm_relative_magnitude_error': rme,
                        'force_sum_gt': torch.sum(gt_thresh[i]).item(),
                        'force_sum_pred': torch.sum(pred_thresh[i]).item(),
                        'force_mean_gt': torch.mean(gt_thresh[i]).item(),
                        'force_mean_pred': torch.mean(pred_thresh[i]).item(),
                    })

    valid_wfm_corr = [x for x in wfm_corr_mag if not np.isnan(x)]
    valid_wfm_rme = [x for x in wfm_rel_mag_err if not np.isnan(x)]
    try:
        force_sum_corr, _ = pearsonr(force_sum_gt, force_sum_pred)
        force_mean_corr, _ = pearsonr(force_mean_gt, force_mean_pred)
    except Exception:  # e.g. fewer than two samples in total
        force_sum_corr = force_mean_corr = 0.0
    force_sum_corr = _corr_or_zero(force_sum_corr)
    force_mean_corr = _corr_or_zero(force_mean_corr)

    results = {
        'heatmap': {
            'mse': np.mean(heatmap_mse),
            'mse_std': np.std(heatmap_mse),
            'ms_ssim': np.mean(heatmap_ms_ssim),
            'ms_ssim_std': np.std(heatmap_ms_ssim),
            'pixel_correlation': np.mean(heatmap_pixel_corr),
            'pixel_correlation_std': np.std(heatmap_pixel_corr),
        },
        'wfm': {
            'correlation_magnitude': np.mean(valid_wfm_corr) if valid_wfm_corr else float('nan'),
            'correlation_magnitude_std': np.std(valid_wfm_corr) if valid_wfm_corr else float('nan'),
            'relative_magnitude_error': np.mean(valid_wfm_rme) if valid_wfm_rme else float('nan'),
            'relative_magnitude_error_std': np.std(valid_wfm_rme) if valid_wfm_rme else float('nan'),
        },
        'force_sum': {
            'correlation': float(force_sum_corr),
            'gt_mean': np.mean(force_sum_gt),
            'pred_mean': np.mean(force_sum_pred),
            'gt_std': np.std(force_sum_gt),
            'pred_std': np.std(force_sum_pred),
        },
        'force_mean': {
            'correlation': float(force_mean_corr),
            'gt_mean': np.mean(force_mean_gt),
            'pred_mean': np.mean(force_mean_pred),
        },
    }

    if save_predictions:
        results['individual_predictions'] = individual_predictions
    return results
| |
|
| |
|
def print_metrics_report(report, threshold=0.0, uses_tanh=False):
    """Print formatted metrics report.

    ``report`` maps a split name to a metrics dict containing 'heatmap', 'wfm'
    and 'force_sum' sections. The threshold is shown in each header only when it
    is positive, and ``uses_tanh`` adds a note about the [-1, 1] -> [0, 1] mapping.
    """
    for split_name, split_metrics in report.items():
        header_suffix = f" (threshold={threshold})" if threshold > 0 else ""
        print(f"\n🔸 {split_name.upper()} SET METRICS" + header_suffix)
        print("-" * 60)

        print("HEATMAP METRICS:")
        hm = split_metrics['heatmap']
        print(f" MSE: {hm['mse']:.6f} ± {hm['mse_std']:.6f}")
        print(f" MS-SSIM: {hm['ms_ssim']:.4f} ± {hm['ms_ssim_std']:.4f}")
        print(f" Pixel Corr: {hm['pixel_correlation']:.4f} ± {hm['pixel_correlation_std']:.4f}")

        print("WFM METRICS (heatmap as magnitude):")
        wfm = split_metrics['wfm']
        print(f" Correlation (Magnitude): {wfm['correlation_magnitude']:.4f} ± {wfm['correlation_magnitude_std']:.4f}")
        print(f" Relative Magnitude Error: {wfm['relative_magnitude_error']:.4f} ± {wfm['relative_magnitude_error_std']:.4f}")

        print("FORCE SUM CORRELATION:")
        fsum = split_metrics['force_sum']
        print(f" Correlation: {fsum['correlation']:.4f}")
        print(f" GT Mean: {fsum['gt_mean']:.2f} ± {fsum['gt_std']:.2f}")
        print(f" Pred Mean: {fsum['pred_mean']:.2f} ± {fsum['pred_std']:.2f}")

        if uses_tanh:
            print(" Note: Model outputs [-1,1], converted to [0,1] for evaluation")
        print("=" * 60)
| |
|
| |
|
def gen_prediction_plots(individual_predictions, save_dir, sort_by='ms_ssim', sort_order='desc', threshold=0.0):
    """Generate prediction plots (BF | GT | Pred) sorted by metric.

    Args:
        individual_predictions: list of per-sample dicts, e.g. the
            ``individual_predictions`` returned by ``evaluate_metrics_on_dataset``.
        save_dir: output directory (created if missing).
        sort_by: per-sample metric key to rank by (case-insensitive), e.g. 'ms_ssim'.
        sort_order: 'desc' ranks largest values first; anything else ranks ascending.
            The same rule applies to every metric, so pass 'asc' for error metrics
            such as 'mse' to get the best samples first.
        threshold: unused; kept for backward compatibility.

    Samples whose ``sort_by`` value is missing or NaN are skipped. One PNG per
    sample is written as ``rank{rank:03d}_batch{batch_idx:03d}_sample{sample_idx:02d}.png``.
    """
    os.makedirs(save_dir, exist_ok=True)
    key = sort_by.lower()
    reverse = sort_order.lower() == 'desc'
    # Filter and sort on the same key so a missing entry cannot raise KeyError later.
    valid = [p for p in individual_predictions if key in p and not np.isnan(p[key])]
    sorted_preds = sorted(valid, key=lambda rec: rec[key], reverse=reverse)
    print(f"Sorting {len(sorted_preds)} predictions by {sort_by} ({sort_order})")

    panels = (
        ('original_image', 'Bright Field', {'cmap': 'gray'}),
        ('ground_truth', 'Ground Truth', {'cmap': 'jet', 'vmin': 0, 'vmax': 1}),
        ('prediction', 'Prediction', {'cmap': 'jet', 'vmin': 0, 'vmax': 1}),
    )
    for rank, p in enumerate(tqdm(sorted_preds, desc="Saving plots"), 1):
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (field, title, style) in zip(axes, panels):
            arr = p[field]
            # (C, H, W) arrays are shown by their first channel.
            ax.imshow(arr[0] if arr.ndim == 3 else arr, **style)
            ax.set_title(title)
            ax.axis('off')
        rme = p.get('wfm_relative_magnitude_error', 'N/A')
        rme_text = f"{rme:.4f}" if isinstance(rme, (int, float, np.floating)) else str(rme)
        m = (f"MSE: {p['mse']:.4f} | MS-SSIM: {p['ms_ssim']:.4f} | "
             f"Pixel Corr: {p['pixel_correlation']:.4f} | Rel Mag Err: {rme_text}")
        fig.suptitle(f"Rank {rank} (by {sort_by})\n{m}", fontsize=10, y=0.02)
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f"rank{rank:03d}_batch{p['batch_idx']:03d}_sample{p['sample_idx']:02d}.png"),
                    dpi=150, bbox_inches='tight')
        plt.close(fig)
| |
|
| |
|
def plot_predictions(loader, generator, n_samples, device, threshold=0.0,
                     use_settings=False, normalization_params=None, config_path=None, substrate_override=None):
    """Plot BF | GT | Pred for first n_samples from loader.

    Args:
        loader: iterable of batches whose first two items are images and heatmaps;
            any further items (e.g. metadata) are ignored.
        generator: model to run; moved to ``device`` and put in eval mode.
        n_samples: maximum number of rows to plot.
        device: torch device used for inference.
        threshold: if > 0, predictions below it are zeroed before plotting.
        use_settings: append settings channels when ``normalization_params`` is truthy.
        normalization_params, config_path: forwarded to ``create_settings_channels``.
        substrate_override: substrate used for every sample's settings channels
            (default ``'fibroblasts_PDMS'``).

    If the loader yields no samples, a message is printed and nothing is plotted
    (previously ``torch.stack`` raised on the empty list).
    """
    generator = generator.to(device)
    generator.eval()

    bf_list, gt_list = [], []
    batches = iter(loader)
    while len(bf_list) < n_samples:
        try:
            batch = next(batches)
        except StopIteration:
            break
        images, heatmaps = batch[0], batch[1]
        take = min(n_samples - len(bf_list), images.shape[0])
        for i in range(take):
            bf_list.append(images[i])
            gt_list.append(heatmaps[i])

    n = len(bf_list)
    if n == 0:
        print("plot_predictions: loader yielded no samples; nothing to plot.")
        return

    bf_batch = torch.stack(bf_list).to(device, dtype=torch.float32)
    if use_settings and normalization_params:
        from models.s2f_model import create_settings_channels
        sub = substrate_override or 'fibroblasts_PDMS'
        meta_dict = {'substrate': [sub] * n}
        settings_ch = create_settings_channels(meta_dict, normalization_params, device,
                                               bf_batch.shape, config_path=config_path)
        bf_batch = torch.cat([bf_batch, settings_ch], dim=1)

    with torch.no_grad():
        pred = generator(bf_batch)
        if detect_tanh_output_model(generator):
            pred = convert_tanh_to_sigmoid_range(pred)
        if threshold > 0:
            pred = apply_threshold_mask(pred, threshold)

    # squeeze=False keeps axes 2-D even when n == 1.
    fig, axes = plt.subplots(n, 3, figsize=(12, 4 * n), squeeze=False)
    for i in range(n):
        row = (
            (bf_list[i], 'Bright Field', {'cmap': 'gray'}),
            (gt_list[i], 'Ground Truth', {'cmap': 'jet', 'vmin': 0, 'vmax': 1}),
            (pred[i], 'Prediction', {'cmap': 'jet', 'vmin': 0, 'vmax': 1}),
        )
        for ax, (tensor, title, style) in zip(axes[i], row):
            ax.imshow(tensor.squeeze().cpu().numpy(), **style)
            ax.set_title(title)
            ax.axis('off')
    fig.tight_layout()
    plt.show()
| |
|