# Shape2Force / S2FApp / utils / metrics.py
# Revision 2b9ff22 ("added") by kaveh.
"""Metrics for S2F training and evaluation.
Includes: MSE, PSNR, SSIM / MS-SSIM, pixel-wise Pearson correlation, WFM
(Wrinkle Force Microscopy) magnitude correlation and relative magnitude error,
and evaluation, reporting and plotting helpers for notebooks and scripts.
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from skimage.metrics import structural_similarity as ssim
from scipy.stats import pearsonr
from tqdm import tqdm
import matplotlib.pyplot as plt
# Optional dependency: torchmetrics supplies the MS-SSIM and MSE metric modules.
# When it cannot be imported, HAS_TORCHMETRICS is False and TorchMetricsWrapper
# falls back to calculate_ssim_tensor (single-scale SSIM) and calculate_mse.
try:
    from torchmetrics import MultiScaleStructuralSimilarityIndexMeasure
    from torchmetrics import MeanSquaredError
    HAS_TORCHMETRICS = True
except ImportError:
    HAS_TORCHMETRICS = False
def calculate_mse(y_true, y_pred):
    """Return the mean squared error between ``y_true`` and ``y_pred`` as a float.

    Dispatch is on ``y_true``: torch tensors go through ``F.mse_loss``; anything
    else is converted with ``np.asarray`` and averaged with NumPy.
    """
    if not isinstance(y_true, torch.Tensor):
        residual = np.asarray(y_true) - np.asarray(y_pred)
        return float(np.mean(residual ** 2))
    return F.mse_loss(y_pred, y_true).item()
def calculate_psnr(y_true, y_pred, max_pixel_value=1.0):
    """Peak signal-to-noise ratio in dB: ``20 * log10(max_pixel_value / sqrt(MSE))``.

    Returns ``float('inf')`` when the inputs are identical (MSE == 0).
    """
    mse = calculate_mse(y_true, y_pred)
    if mse != 0:
        return 20 * np.log10(max_pixel_value / np.sqrt(mse))
    return float('inf')
def calculate_ssim_tensor(y_true, y_pred, data_range=1.0):
    """Mean single-scale SSIM over a batch, via skimage ``structural_similarity``.

    Args:
        y_true, y_pred: torch tensors or NumPy arrays with the batch on axis 0,
            shaped (N, C, H, W) or (N, H, W). For 4-D input only channel 0 of
            each sample is compared.
        data_range: dynamic range passed to skimage (1.0 for images in [0, 1]).

    Returns:
        The mean of the per-sample SSIM values, as a float.
    """
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.detach().cpu().numpy()
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.detach().cpu().numpy()
    if y_true.ndim == 4:
        # Only channel 0 is scored, regardless of the channel count.
        y_true, y_pred = y_true[:, 0], y_pred[:, 0]
    scores = [ssim(y_true[i], y_pred[i], data_range=data_range)
              for i in range(y_true.shape[0])]
    return float(np.mean(scores))
def calculate_pearson_correlation(y_true, y_pred):
    """Pearson correlation between all elements of ``y_true`` and ``y_pred``.

    Both inputs are flattened first. Torch tensors are detached and moved to
    CPU before conversion, so tensors that still require grad are accepted.
    Plain array-likes (lists, NumPy arrays) are accepted as well.

    Returns:
        The Pearson r from ``scipy.stats.pearsonr``. scipy yields NaN when
        either input is constant.
    """
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.detach().cpu().numpy()
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.detach().cpu().numpy()
    correlation, _ = pearsonr(np.ravel(y_true), np.ravel(y_pred))
    return correlation
def calculate_individual_pixel_correlation(y_true, y_pred):
    """Pixel-wise Pearson correlation computed separately for each sample in a batch.

    Args:
        y_true, y_pred: batched tensors or array-likes with the batch on axis 0.
            Each sample is flattened before correlating. Torch tensors are
            detached and moved to CPU, so tensors that require grad are accepted.

    Returns:
        A list with one Pearson r per sample. scipy yields NaN for a constant sample.
    """
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.detach().cpu().numpy()
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.detach().cpu().numpy()
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    correlations = []
    for i in range(y_true.shape[0]):
        r, _ = pearsonr(y_true[i].ravel(), y_pred[i].ravel())
        correlations.append(r)
    return correlations
# --- WFM (Wrinkle Force Microscopy) metrics for heatmap as magnitude ---
def _to_numpy_wfm(x):
if isinstance(x, torch.Tensor):
return x.detach().cpu().numpy()
return np.asarray(x)
def _ensure_shape_wfm(f):
"""Ensure (N, 2, H, W). Heatmap -> fx=magnitude, fy=0."""
if f.ndim == 3:
if f.shape[-1] == 2:
f = np.transpose(f, (2, 0, 1))[None, ...]
elif f.shape[0] == 2:
f = f[None, ...]
else:
raise ValueError(f"Unsupported 3D shape {f.shape}")
elif f.ndim == 4:
if f.shape[-1] == 2:
f = np.transpose(f, (0, 3, 1, 2))
else:
raise ValueError(f"Unsupported ndim={f.ndim}")
return f
def _force_mag_wfm(f):
fx, fy = f[:, 0], f[:, 1]
return np.sqrt(fx**2 + fy**2)
def wfm_correlation(y_true, y_pred, mode="magnitude"):
    """Pearson correlation of per-pixel force magnitudes, prediction vs ground truth.

    Both inputs are normalised with ``_ensure_shape_wfm``. Only
    ``mode="magnitude"`` is supported. Returns 0.0 when either magnitude field
    is (numerically) constant, because the correlation is undefined there.

    Raises:
        ValueError: on a shape mismatch or an unknown ``mode``.
    """
    truth = _ensure_shape_wfm(_to_numpy_wfm(y_true))
    guess = _ensure_shape_wfm(_to_numpy_wfm(y_pred))
    if truth.shape != guess.shape:
        raise ValueError(f"Shape mismatch: true {truth.shape} vs pred {guess.shape}")
    if mode != "magnitude":
        raise ValueError(f"Unknown mode '{mode}'")
    mag_true = _force_mag_wfm(truth).ravel().astype(np.float64)
    mag_pred = _force_mag_wfm(guess).ravel().astype(np.float64)
    degenerate = np.allclose(mag_true.std(), 0) or np.allclose(mag_pred.std(), 0)
    if degenerate:
        return 0.0
    return float(np.corrcoef(mag_true, mag_pred)[0, 1])
def wfm_relative_magnitude_error(y_true, y_pred, eps=1e-8):
    """Relative magnitude error weighted by the normalised ground-truth magnitude.

    Computes ``mean(|m_pred - m_true| / (m_true + eps) * m_true / mean(m_true))``
    over all pixels and samples. Returns 0.0 when ``mean(m_true)`` is ~0.

    Raises:
        ValueError: when the normalised shapes of the two inputs differ.
    """
    truth = _ensure_shape_wfm(_to_numpy_wfm(y_true))
    guess = _ensure_shape_wfm(_to_numpy_wfm(y_pred))
    if truth.shape != guess.shape:
        raise ValueError(f"Shape mismatch: true {truth.shape} vs pred {guess.shape}")
    mag_true = _force_mag_wfm(truth)
    mag_pred = _force_mag_wfm(guess)
    mean_true = np.mean(mag_true)
    if np.isclose(mean_true, 0):
        return 0.0
    relative = np.abs(mag_pred - mag_true) / (mag_true + eps)
    weights = mag_true / mean_true
    return float(np.mean(relative * weights))
def apply_threshold_mask(tensor, threshold=0.0):
    """Zero out entries of ``tensor`` below ``threshold``; entries >= threshold are kept."""
    keep = tensor >= threshold
    return tensor * keep.float()
def detect_tanh_output_model(model):
    """Return True if ``model`` appears to produce outputs in [-1, 1] (a Tanh head).

    Checks, in order:
      * ``model.use_sigmoid`` exists and is falsy;
      * ``model.use_tanh_output`` exists and is truthy;
      * ``model.final_conv`` is an ``nn.Tanh``, or an ``nn.Sequential`` whose
        last module is an ``nn.Tanh``.
    """
    if hasattr(model, 'use_sigmoid') and not model.use_sigmoid:
        return True
    if getattr(model, 'use_tanh_output', False):
        return True
    head = getattr(model, 'final_conv', None)
    if isinstance(head, nn.Sequential):
        # Guard: indexing an empty Sequential with [-1] raises IndexError.
        return len(head) > 0 and isinstance(head[-1], nn.Tanh)
    return isinstance(head, nn.Tanh)
def convert_tanh_to_sigmoid_range(tensor):
    """Affinely map values from the Tanh range [-1, 1] to the Sigmoid range [0, 1]."""
    shifted = tensor + 1.0
    return shifted / 2.0
# --- TorchMetrics wrapper for MS-SSIM ---
class TorchMetricsWrapper:
    """Batch MS-SSIM / MSE scorer backed by torchmetrics when it is installed.

    Without torchmetrics (``HAS_TORCHMETRICS`` is False), ``compute_ms_ssim``
    falls back to single-scale SSIM (``calculate_ssim_tensor``) and
    ``compute_mse`` falls back to ``calculate_mse``.

    Args:
        device: device the torchmetrics modules are placed on; inputs are moved there.
    """

    def __init__(self, device='cpu'):
        self.device = device
        self.reset_metrics()

    def reset_metrics(self):
        """Create fresh metric modules on ``self.device`` (both ``None`` without torchmetrics)."""
        if not HAS_TORCHMETRICS:
            self.ms_ssim = None
            self.mse = None
            return
        self.ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0).to(self.device)
        self.mse = MeanSquaredError().to(self.device)

    def compute_ms_ssim(self, y_true, y_pred):
        """Return the MS-SSIM of ``y_pred`` against ``y_true`` for one batch, as a float.

        The metric is built with ``data_range=1.0``. For multi-channel input only
        channel 0 is scored.
        """
        if not HAS_TORCHMETRICS:
            # Fallback: single-scale SSIM, not MS-SSIM.
            return float(calculate_ssim_tensor(y_true, y_pred))
        target = y_true.to(self.device)
        preds = y_pred.to(self.device)
        if target.shape[1] != 1:
            target = target[:, 0:1]
            preds = preds[:, 0:1]
        # The metric is called with the prediction first, then the target.
        return self.ms_ssim(preds, target).item()

    def compute_mse(self, y_true, y_pred):
        """Return the mean squared error of ``y_pred`` against ``y_true`` for one batch."""
        if not HAS_TORCHMETRICS:
            return calculate_mse(y_true, y_pred)
        target = y_true.to(self.device)
        preds = y_pred.to(self.device)
        return self.mse(preds, target).item()
# --- Full evaluation on dataset ---
def _select_default_device():
    """Return the first available device among MPS, CUDA and CPU, in that order."""
    if torch.backends.mps.is_available():
        return torch.device('mps')
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


def _unpack_eval_batch(batch_data):
    """Split a loader batch into ``(images, heatmaps, metadata, has_metadata)``.

    5-element batches carry metadata in the last slot. 4-element batches carry
    none, so ``metadata`` is returned as None. The two middle elements are ignored.
    """
    if len(batch_data) == 5:
        images, heatmaps, _, _, metadata = batch_data
        return images, heatmaps, metadata, True
    images, heatmaps, _, _ = batch_data
    return images, heatmaps, None, False


def _heatmap_to_force_field(heatmap):
    """Embed channel 0 of a (B, C, H, W) heatmap as fx of a (B, 2, H, W) field with fy = 0."""
    batch, _, height, width = heatmap.shape
    field = torch.zeros(batch, 2, height, width, device=heatmap.device)
    field[:, 0] = heatmap[:, 0]
    return field


def _nan_on_error(metric_fn, *args, **kwargs):
    """Best-effort metric call: return NaN instead of propagating an exception.

    NaNs are dropped when the WFM metrics are aggregated, so a failing
    batch or sample is excluded instead of aborting the whole evaluation.
    """
    try:
        return metric_fn(*args, **kwargs)
    except Exception:  # deliberate best-effort, as in the per-metric aggregation below
        return float('nan')


def _safe_pearson(a, b):
    """Pearson r of two equal-length sequences, or 0.0 when it cannot be computed.

    That covers scipy raising (e.g. fewer than two samples) and a NaN result.
    """
    try:
        r, _ = pearsonr(a, b)
    except Exception:  # best-effort: the report should still be produced
        return 0.0
    if r is None or np.isnan(r):
        return 0.0
    return float(r)


def _nan_mean_std(values):
    """Return ``(mean, std)`` of the non-NaN entries of ``values``, or ``(nan, nan)`` if none remain."""
    valid = [v for v in values if not np.isnan(v)]
    if not valid:
        return float('nan'), float('nan')
    return np.mean(valid), np.std(valid)


def evaluate_metrics_on_dataset(generator, data_loader, device=None, description="Evaluating",
                                save_predictions=False, threshold=0.0, use_settings=False,
                                normalization_params=None, config_path=None, substrate_override=None):
    """
    Evaluate S2F generator on a dataset. Returns MSE, MS-SSIM, Pixel Correlation,
    Relative Magnitude Error, and force sum/mean correlations.

    Args:
        generator: model mapping input images to heatmaps. Its output is rescaled
            from [-1, 1] to [0, 1] when ``detect_tanh_output_model`` flags a Tanh head.
        data_loader: iterable of batches ``(images, heatmaps, _, _[, metadata])``.
        device: torch device. Defaults to MPS, then CUDA, then CPU.
        description: tqdm progress-bar label.
        save_predictions: also return per-sample arrays and metrics under
            ``'individual_predictions'``.
        threshold: ground-truth values below this are zeroed before scoring.
            The prediction itself is not thresholded.
        use_settings: append settings channels from
            ``models.s2f_model.create_settings_channels`` (only when
            ``normalization_params`` is not None). Batch metadata is used when
            present; otherwise every sample gets ``substrate_override``
            (default ``'fibroblasts_PDMS'``).
        normalization_params, config_path: forwarded to ``create_settings_channels``.

    Returns:
        dict with the following entries:
          * ``'heatmap'``: MSE and MS-SSIM (single-scale SSIM without
            torchmetrics), recorded per batch, plus pixel correlation recorded
            per sample; each has a ``*_std``.
          * ``'wfm'``: magnitude correlation and relative magnitude error per
            batch, with NaNs excluded.
          * ``'force_sum'`` / ``'force_mean'``: per-sample GT-vs-prediction
            correlation and means.
          * ``'individual_predictions'``: only when ``save_predictions`` is set.
    """
    if device is None:
        device = _select_default_device()
    generator = generator.to(device)
    generator.eval()
    # The output head does not change during evaluation, so inspect it once.
    uses_tanh = detect_tanh_output_model(generator)
    metrics_wrapper = TorchMetricsWrapper(device=device)

    heatmap_mse, heatmap_ms_ssim, heatmap_pixel_corr = [], [], []
    wfm_corr_mag, wfm_rel_mag_err = [], []
    force_sum_gt, force_sum_pred = [], []
    force_mean_gt, force_mean_pred = [], []
    individual_predictions = [] if save_predictions else None

    with torch.no_grad():
        for batch_idx, batch_data in enumerate(tqdm(data_loader, desc=description)):
            images, heatmaps, metadata, has_metadata = _unpack_eval_batch(batch_data)
            images = images.to(device, dtype=torch.float32)
            heatmaps = heatmaps.to(device, dtype=torch.float32)
            if use_settings and normalization_params is not None:
                from models.s2f_model import create_settings_channels
                meta = metadata if has_metadata else {
                    'substrate': [substrate_override or 'fibroblasts_PDMS'] * images.size(0)}
                settings_ch = create_settings_channels(meta, normalization_params, device, images.shape,
                                                       config_path=config_path)
                images = torch.cat([images, settings_ch], dim=1)

            pred = generator(images)
            if uses_tanh:
                pred = convert_tanh_to_sigmoid_range(pred)
            gt_thresh = apply_threshold_mask(heatmaps, threshold)
            pred_thresh = pred  # only the GT is thresholded; the prediction is scored as-is

            heatmap_mse.append(metrics_wrapper.compute_mse(gt_thresh, pred_thresh))
            heatmap_ms_ssim.append(metrics_wrapper.compute_ms_ssim(gt_thresh, pred_thresh))
            heatmap_pixel_corr.extend(calculate_individual_pixel_correlation(gt_thresh, pred_thresh))

            # WFM metrics treat the heatmap as a force magnitude (fx = heatmap, fy = 0).
            gt_ff = _heatmap_to_force_field(gt_thresh)
            pred_ff = _heatmap_to_force_field(pred_thresh)
            wfm_corr_mag.append(_nan_on_error(wfm_correlation, gt_ff, pred_ff, mode="magnitude"))
            wfm_rel_mag_err.append(_nan_on_error(wfm_relative_magnitude_error, gt_ff, pred_ff))

            force_sum_gt.extend(gt_thresh.sum(dim=(1, 2, 3)).tolist())
            force_sum_pred.extend(pred_thresh.sum(dim=(1, 2, 3)).tolist())
            force_mean_gt.extend(gt_thresh.mean(dim=(1, 2, 3)).tolist())
            force_mean_pred.extend(pred_thresh.mean(dim=(1, 2, 3)).tolist())

            if save_predictions:
                for i in range(images.size(0)):
                    p, t = pred_thresh[i:i + 1], gt_thresh[i:i + 1]
                    rme = _nan_on_error(wfm_relative_magnitude_error,
                                        _heatmap_to_force_field(t), _heatmap_to_force_field(p))
                    individual_predictions.append({
                        'batch_idx': batch_idx,
                        'sample_idx': i,
                        'original_image': images[i].cpu().numpy(),
                        'ground_truth': heatmaps[i].cpu().numpy(),
                        'ground_truth_thresholded': gt_thresh[i].cpu().numpy(),
                        'prediction': pred[i].cpu().numpy(),
                        'prediction_thresholded': pred_thresh[i].cpu().numpy(),
                        'mse': metrics_wrapper.compute_mse(t, p),
                        'ms_ssim': metrics_wrapper.compute_ms_ssim(t, p),
                        'pixel_correlation': calculate_pearson_correlation(t, p),
                        'wfm_relative_magnitude_error': rme,
                        'force_sum_gt': torch.sum(gt_thresh[i]).item(),
                        'force_sum_pred': torch.sum(pred_thresh[i]).item(),
                        'force_mean_gt': torch.mean(gt_thresh[i]).item(),
                        'force_mean_pred': torch.mean(pred_thresh[i]).item(),
                    })

    force_sum_corr = _safe_pearson(force_sum_gt, force_sum_pred)
    force_mean_corr = _safe_pearson(force_mean_gt, force_mean_pred)
    corr_mean, corr_std = _nan_mean_std(wfm_corr_mag)
    rme_mean, rme_std = _nan_mean_std(wfm_rel_mag_err)

    results = {
        'heatmap': {
            'mse': np.mean(heatmap_mse),
            'mse_std': np.std(heatmap_mse),
            'ms_ssim': np.mean(heatmap_ms_ssim),
            'ms_ssim_std': np.std(heatmap_ms_ssim),
            'pixel_correlation': np.mean(heatmap_pixel_corr),
            'pixel_correlation_std': np.std(heatmap_pixel_corr),
        },
        'wfm': {
            'correlation_magnitude': corr_mean,
            'correlation_magnitude_std': corr_std,
            'relative_magnitude_error': rme_mean,
            'relative_magnitude_error_std': rme_std,
        },
        'force_sum': {
            'correlation': force_sum_corr,
            'gt_mean': np.mean(force_sum_gt),
            'pred_mean': np.mean(force_sum_pred),
            'gt_std': np.std(force_sum_gt),
            'pred_std': np.std(force_sum_pred),
        },
        'force_mean': {
            'correlation': force_mean_corr,
            'gt_mean': np.mean(force_mean_gt),
            'pred_mean': np.mean(force_mean_pred),
        },
    }
    if save_predictions:
        results['individual_predictions'] = individual_predictions
    return results
def print_metrics_report(report, threshold=0.0, uses_tanh=False):
    """Print one formatted section per split of ``evaluate_metrics_on_dataset`` results.

    Args:
        report: mapping of split name -> results dict (``'heatmap'``, ``'wfm'``,
            ``'force_sum'`` entries are read).
        threshold: shown in each section header when > 0.
        uses_tanh: append a note that [-1, 1] outputs were rescaled to [0, 1].
    """
    for split, res in report.items():
        header = f"\n🔸 {split.upper()} SET METRICS"
        if threshold > 0:
            header += f" (threshold={threshold})"
        print(header)
        print("-" * 60)

        print("HEATMAP METRICS:")
        hm = res['heatmap']
        print(f" MSE: {hm['mse']:.6f} ± {hm['mse_std']:.6f}")
        print(f" MS-SSIM: {hm['ms_ssim']:.4f} ± {hm['ms_ssim_std']:.4f}")
        print(f" Pixel Corr: {hm['pixel_correlation']:.4f} ± {hm['pixel_correlation_std']:.4f}")

        print("WFM METRICS (heatmap as magnitude):")
        wfm = res['wfm']
        print(f" Correlation (Magnitude): {wfm['correlation_magnitude']:.4f} ± {wfm['correlation_magnitude_std']:.4f}")
        print(f" Relative Magnitude Error: {wfm['relative_magnitude_error']:.4f} ± {wfm['relative_magnitude_error_std']:.4f}")

        print("FORCE SUM CORRELATION:")
        fsum = res['force_sum']
        print(f" Correlation: {fsum['correlation']:.4f}")
        print(f" GT Mean: {fsum['gt_mean']:.2f} ± {fsum['gt_std']:.2f}")
        print(f" Pred Mean: {fsum['pred_mean']:.2f} ± {fsum['pred_std']:.2f}")

        if uses_tanh:
            print(" Note: Model outputs [-1,1], converted to [0,1] for evaluation")
        print("=" * 60)
def gen_prediction_plots(individual_predictions, save_dir, sort_by='ms_ssim', sort_order='desc', threshold=0.0):
    """Save one BF | GT | Pred figure per prediction, ranked by a per-sample metric.

    Args:
        individual_predictions: per-sample records as produced by
            ``evaluate_metrics_on_dataset(..., save_predictions=True)``.
        save_dir: output directory (created if missing).
        sort_by: record key to rank by (case-insensitive), e.g. 'ms_ssim', 'mse',
            'pixel_correlation' or 'wfm_relative_magnitude_error'. Records whose
            value is missing or NaN are skipped.
        sort_order: 'desc' ranks the largest value first; anything else is
            ascending. For error metrics (MSE, relative magnitude error), 'desc'
            therefore lists the worst predictions first.
        threshold: unused; kept for API compatibility.
    """
    os.makedirs(save_dir, exist_ok=True)
    metric_key = sort_by.lower()
    reverse = sort_order.lower() == 'desc'
    # Missing keys count as NaN so that they are filtered out here instead of
    # raising KeyError inside sorted().
    valid = [p for p in individual_predictions if not np.isnan(p.get(metric_key, float('nan')))]
    sorted_preds = sorted(valid, key=lambda rec: rec[metric_key], reverse=reverse)
    print(f"Sorting {len(sorted_preds)} predictions by {sort_by} ({sort_order})")
    for rank, p in enumerate(tqdm(sorted_preds, desc="Saving plots"), 1):
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (p['original_image'], 'Bright Field', {'cmap': 'gray'}),
            (p['ground_truth'], 'Ground Truth', {'cmap': 'jet', 'vmin': 0, 'vmax': 1}),
            (p['prediction'], 'Prediction', {'cmap': 'jet', 'vmin': 0, 'vmax': 1}),
        )
        for ax, (arr, title, style) in zip(axes, panels):
            # (C, H, W) arrays are shown by their first channel.
            ax.imshow(arr[0] if arr.ndim == 3 else arr, **style)
            ax.set_title(title)
            ax.axis('off')
        rme = p.get('wfm_relative_magnitude_error')
        rme_text = 'N/A' if rme is None else f"{rme:.4f}"
        m = (f"MSE: {p['mse']:.4f} | MS-SSIM: {p['ms_ssim']:.4f} | "
             f"Pixel Corr: {p['pixel_correlation']:.4f} | Rel Mag Err: {rme_text}")
        fig.suptitle(f"Rank {rank} (by {sort_by})\n{m}", fontsize=10, y=0.02)
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f"rank{rank:03d}_batch{p['batch_idx']:03d}_sample{p['sample_idx']:02d}.png"),
                    dpi=150, bbox_inches='tight')
        plt.close(fig)
def plot_predictions(loader, generator, n_samples, device, threshold=0.0,
                     use_settings=False, normalization_params=None, config_path=None, substrate_override=None):
    """Plot BF | GT | Pred for first n_samples from loader.

    Predictions from a Tanh-head model are rescaled from [-1, 1] to [0, 1].
    When ``threshold > 0``, predicted values below it are zeroed; the GT panel
    is shown unthresholded. With ``use_settings`` and truthy
    ``normalization_params``, settings channels for ``substrate_override``
    (default ``'fibroblasts_PDMS'``) are appended to every sample. Returns
    without plotting when the loader yields no samples or ``n_samples <= 0``.
    """
    generator = generator.to(device)
    generator.eval()
    if n_samples <= 0:
        return
    bf_list, gt_list = [], []
    # NOTE(review): per-sample loader metadata is not used here. Every sample gets
    # substrate_override (or the default), whereas evaluate_metrics_on_dataset
    # uses batch metadata when present. Confirm this is intended.
    for batch in loader:
        # Batches are (images, heatmaps, ...); only the first two entries are used.
        images, heatmaps = batch[0], batch[1]
        take = min(images.shape[0], n_samples - len(bf_list))
        bf_list.extend(images[i] for i in range(take))
        gt_list.extend(heatmaps[i] for i in range(take))
        if len(bf_list) >= n_samples:
            break
    n = len(bf_list)
    if n == 0:
        print("plot_predictions: loader yielded no samples; nothing to plot.")
        return

    bf_batch = torch.stack(bf_list).to(device, dtype=torch.float32)
    if use_settings and normalization_params:
        from models.s2f_model import create_settings_channels
        meta_dict = {'substrate': [substrate_override or 'fibroblasts_PDMS'] * n}
        settings_ch = create_settings_channels(meta_dict, normalization_params, device, bf_batch.shape,
                                               config_path=config_path)
        bf_batch = torch.cat([bf_batch, settings_ch], dim=1)
    with torch.no_grad():
        pred = generator(bf_batch)
        if detect_tanh_output_model(generator):
            pred = convert_tanh_to_sigmoid_range(pred)
        if threshold > 0:
            pred = apply_threshold_mask(pred, threshold)

    # squeeze=False keeps axes 2-D even when n == 1.
    fig, axes = plt.subplots(n, 3, figsize=(12, 4 * n), squeeze=False)
    for row, (bf, gt, pr) in enumerate(zip(bf_list, gt_list, pred)):
        panels = (
            (bf.squeeze().cpu().numpy(), 'Bright Field', {'cmap': 'gray'}),
            (gt.squeeze().cpu().numpy(), 'Ground Truth', {'cmap': 'jet', 'vmin': 0, 'vmax': 1}),
            (pr.squeeze().cpu().numpy(), 'Prediction', {'cmap': 'jet', 'vmin': 0, 'vmax': 1}),
        )
        for col, (arr, title, style) in enumerate(panels):
            ax = axes[row, col]
            ax.imshow(arr, **style)
            ax.set_title(title)
            ax.axis('off')
    plt.tight_layout()
    plt.show()