"""evaluation.py — Evaluation metrics for NSGF/NSGF++ experiments.

Implements:
    - 2-Wasserstein distance (2D experiments)
    - FID (Fréchet Inception Distance) for image experiments
    - IS (Inception Score) for image experiments
    - Visualization utilities

Reference: arXiv:2401.14069, Section 5, Appendix E
"""

import os
import logging

import numpy as np
import torch
import torch.nn as nn
from typing import Dict, Optional, List, Tuple

logger = logging.getLogger(__name__)


def compute_w2_distance(samples: torch.Tensor, targets: torch.Tensor,
                        num_iter_max: int = 100000) -> float:
    """Compute the exact 2-Wasserstein distance between two empirical samples.

    Both point sets get uniform weights. The optimal-transport problem on the
    squared-Euclidean cost matrix is solved exactly with POT's ``ot.emd2``,
    and the square root of the optimal cost is returned.

    Args:
        samples: Generated points, shape ``(n, ...)``. Trailing dimensions are
            flattened into a single feature axis.
        targets: Reference points, shape ``(m, ...)``, with the same per-point
            size as ``samples``.
        num_iter_max: Iteration cap passed to the network-simplex solver
            (``numItermax``; 100000 is POT's own default). Increase it if POT
            warns that the cap was reached before optimality.

    Returns:
        The W2 distance as a Python float.
    """
    import ot

    # float64 keeps the cost matrix and the solver accurate for small distances.
    x = samples.detach().reshape(len(samples), -1).cpu().double().numpy()
    y = targets.detach().reshape(len(targets), -1).cpu().double().numpy()
    cost = ot.dist(x, y, metric="sqeuclidean")
    a = np.full(len(x), 1.0 / len(x))
    b = np.full(len(y), 1.0 / len(y))
    w2_sq = ot.emd2(a, b, cost, numItermax=num_iter_max)
    # Clamp tiny negative values that can come from floating-point round-off.
    return float(np.sqrt(max(float(w2_sq), 0.0)))


class InceptionV3Features(nn.Module):
    """Frozen torchvision Inception-V3 trunk used for FID / IS computation.

    ``forward`` returns the globally average-pooled ``Mixed_7c`` activations
    (the features FID compares) and the ImageNet classifier logits produced
    by the pretrained ``fc`` layer (the input to IS).

    NOTE: these are torchvision's ImageNet weights, not the TF "FID Inception"
    weights used by pytorch-fid / TTUR. Absolute FID/IS values are therefore
    not directly comparable with numbers computed by those tools.
    """

    def __init__(self, device: str = "cpu"):
        super().__init__()
        import torchvision.models as models

        self.device = device
        weights_enum = getattr(models, "Inception_V3_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13: the ``pretrained`` flag is deprecated.
            inception = models.inception_v3(weights=weights_enum.IMAGENET1K_V1,
                                            transform_input=False)
        else:
            inception = models.inception_v3(pretrained=True, transform_input=False)
        inception.eval()
        # Same layer order as torchvision's Inception3 forward pass, up to
        # (but excluding) its own avgpool/dropout/fc.
        self.blocks = nn.Sequential(
            inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3,
            nn.MaxPool2d(3, stride=2),
            inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3,
            nn.MaxPool2d(3, stride=2),
            inception.Mixed_5b, inception.Mixed_5c, inception.Mixed_5d,
            inception.Mixed_6a, inception.Mixed_6b, inception.Mixed_6c,
            inception.Mixed_6d, inception.Mixed_6e,
            inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = inception.fc
        self.to(device)
        for p in self.parameters():
            p.requires_grad_(False)
        # Pin the whole wrapper (not just the borrowed submodules) to inference
        # mode so BatchNorm always uses its running statistics.
        self.eval()

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(pooled_features, logits)`` for a batch of images.

        Args:
            x: Images of shape ``(B, C, H, W)`` with ``C`` equal to 1 or 3 and
                pixel values in ``[-1, 1]``. Images that are not 299x299 are
                bilinearly resized. Grayscale input is repeated to 3 channels.
        """
        x = x.float()
        if x.shape[2] != 299 or x.shape[3] != 299:
            x = torch.nn.functional.interpolate(x, size=(299, 299), mode="bilinear",
                                                align_corners=False)
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        # With transform_input=False, torchvision's Inception weights (ported
        # from TF) expect input already scaled to [-1, 1]. That is exactly what
        # transform_input=True would produce from ImageNet-normalised data.
        # The images are already in that range, so they must NOT be rescaled
        # to [0, 1]; only clamp stray out-of-range generator outputs.
        x = x.clamp(-1.0, 1.0)
        h = self.blocks(x)
        features = torch.flatten(self.avgpool(h), 1)
        logits = self.fc(features)
        return features, logits


def _run_inception(model: InceptionV3Features, images: torch.Tensor, device: str,
                   batch_size: int) -> Tuple[np.ndarray, np.ndarray]:
    """Run ``model`` over ``images`` in mini-batches; return (features, logits) as numpy."""
    if len(images) == 0:
        raise ValueError("Cannot run Inception on an empty set of images")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    feats, logits = [], []
    for start in range(0, len(images), batch_size):
        f, lg = model(images[start:start + batch_size].to(device))
        feats.append(f.cpu().numpy())
        logits.append(lg.cpu().numpy())
    return np.concatenate(feats, axis=0), np.concatenate(logits, axis=0)


def frechet_distance(mu1: np.ndarray, sigma1: np.ndarray, mu2: np.ndarray,
                     sigma2: np.ndarray, eps: float = 1e-6) -> float:
    """Fréchet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    Computes ``||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2})``.

    Args:
        mu1, mu2: Mean vectors of equal shape.
        sigma1, sigma2: Covariance matrices of equal shape.
        eps: Diagonal offset added to both covariances if the matrix square
            root of their product is not finite (e.g. singular covariances
            estimated from too few samples).

    Returns:
        The distance as a Python float.

    Raises:
        ValueError: if the mean or covariance shapes differ.
    """
    from scipy import linalg

    mu1 = np.atleast_1d(np.asarray(mu1, dtype=np.float64))
    mu2 = np.atleast_1d(np.asarray(mu2, dtype=np.float64))
    sigma1 = np.atleast_2d(np.asarray(sigma1, dtype=np.float64))
    sigma2 = np.atleast_2d(np.asarray(sigma2, dtype=np.float64))
    if mu1.shape != mu2.shape or sigma1.shape != sigma2.shape:
        raise ValueError("Mean vectors and covariance matrices must have matching shapes")

    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        logger.warning("FID: sqrtm of the covariance product is not finite; "
                       "adding %g to the diagonals", eps)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
    # Numerical error can leave a small imaginary component; drop it, but
    # report when it is large enough to make the result questionable.
    if np.iscomplexobj(covmean):
        max_imag = float(np.max(np.abs(covmean.imag)))
        if max_imag > 1e-3:
            logger.warning("FID: discarding large imaginary component %g of sqrtm", max_imag)
        covmean = covmean.real
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))


def compute_fid(generated: torch.Tensor, real: torch.Tensor, device: str = "cpu",
                batch_size: int = 64,
                model: Optional[InceptionV3Features] = None) -> float:
    """Fréchet Inception Distance between two image sets.

    Args:
        generated: Generated images, ``(N, C, H, W)`` in ``[-1, 1]``.
        real: Reference images in the same format.
        device: Device on which Inception runs.
        batch_size: Images per forward pass.
        model: Optional pre-built feature extractor to reuse. If omitted, a
            new one is created (and its weights loaded) for this call.

    Returns:
        The FID as a Python float.
    """
    if model is None:
        model = InceptionV3Features(device)
    logger.info("Computing FID: extracting generated features...")
    feats_gen, _ = _run_inception(model, generated, device, batch_size)
    logger.info("Computing FID: extracting real features...")
    feats_real, _ = _run_inception(model, real, device, batch_size)
    return frechet_distance(feats_gen.mean(0), np.cov(feats_gen, rowvar=False),
                            feats_real.mean(0), np.cov(feats_real, rowvar=False))


def inception_score_from_probs(probs: np.ndarray, splits: int = 10) -> Tuple[float, float]:
    """Inception Score from per-image class probabilities.

    The rows are partitioned into ``splits`` contiguous, near-equal chunks
    covering every row. For each chunk the score is
    ``exp(mean_i KL(p(y|x_i) || p(y)))``, where ``p(y)`` is the chunk's mean
    prediction. ``splits`` is reduced to the number of rows when there are
    fewer rows than splits, so no chunk is empty.

    Args:
        probs: ``(N, K)`` array of class probabilities (rows sum to 1).
        splits: Requested number of chunks.

    Returns:
        ``(mean, std)`` of the per-chunk scores.

    Raises:
        ValueError: if ``probs`` has no rows.
    """
    probs = np.asarray(probs, dtype=np.float64)
    n = len(probs)
    if n == 0:
        raise ValueError("Cannot compute the Inception Score of an empty set of predictions")
    splits = max(1, min(int(splits), n))
    scores = []
    for i in range(splits):
        # Integer boundaries i*n//splits cover all n rows without dropping a remainder.
        part = probs[i * n // splits:(i + 1) * n // splits]
        py = part.mean(axis=0, keepdims=True)
        kl = part * (np.log(part + 1e-10) - np.log(py + 1e-10))
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))


def compute_inception_score(images: torch.Tensor, device: str = "cpu", batch_size: int = 64,
                            splits: int = 10,
                            model: Optional[InceptionV3Features] = None) -> Tuple[float, float]:
    """Inception Score (mean, std over ``splits`` chunks) of an image set.

    Args:
        images: Images, ``(N, C, H, W)`` in ``[-1, 1]``.
        device: Device on which Inception runs.
        batch_size: Images per forward pass.
        splits: Number of chunks the score is averaged over.
        model: Optional pre-built feature extractor to reuse. If omitted, a
            new one is created for this call.
    """
    if model is None:
        model = InceptionV3Features(device)
    _, logits = _run_inception(model, images, device, batch_size)
    probs = torch.softmax(torch.from_numpy(logits), dim=1).numpy()
    return inception_score_from_probs(probs, splits)


class Evaluation:
    """Pick and compute metrics based on the configured dataset.

    Image datasets ("mnist", "cifar10") get the metrics listed in
    ``config["evaluation"]["metrics"]`` (default ``["fid"]``; "is" is also
    recognised). ``config["evaluation"]["batch_size"]`` optionally sets the
    Inception batch size (default 64). Every other dataset gets the W2 distance.
    """

    def __init__(self, config: dict, device: str = "cpu"):
        self.config = config
        self.device = device
        self.dataset_name = config.get("dataset", "8gaussians")
        self.is_image = self.dataset_name in ("mnist", "cifar10")

    def evaluate(self, generated: torch.Tensor, real: torch.Tensor) -> Dict[str, float]:
        """Compute the configured metrics for ``generated`` against ``real``.

        Returns:
            A dict with keys ``"w2"`` (non-image datasets), or ``"fid"`` and/or
            ``"is_mean"`` / ``"is_std"`` (image datasets).
        """
        metrics: Dict[str, float] = {}
        if not self.is_image:
            w2 = compute_w2_distance(generated, real)
            metrics["w2"] = w2
            logger.info("W2 distance: %.4f", w2)
            return metrics

        eval_cfg = self.config.get("evaluation", {})
        metric_names = eval_cfg.get("metrics", ["fid"])
        batch_size = int(eval_cfg.get("batch_size", 64))
        # Load the Inception weights once and share them between FID and IS.
        model = None
        if "fid" in metric_names or "is" in metric_names:
            model = InceptionV3Features(self.device)

        if "fid" in metric_names:
            logger.info("Computing FID...")
            metrics["fid"] = compute_fid(generated, real, self.device,
                                         batch_size=batch_size, model=model)
            logger.info("FID: %.2f", metrics["fid"])
        if "is" in metric_names:
            logger.info("Computing Inception Score...")
            is_mean, is_std = compute_inception_score(generated, self.device,
                                                      batch_size=batch_size, model=model)
            metrics["is_mean"] = is_mean
            metrics["is_std"] = is_std
            logger.info("IS: %.2f ± %.2f", is_mean, is_std)
        return metrics


def _headless_pyplot():
    """Select matplotlib's non-interactive Agg backend and return ``pyplot``."""
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    return plt


def _save_and_close(plt, fig, save_path: Optional[str], label: str) -> None:
    """Save ``fig`` to ``save_path`` (creating parent dirs) if given, then close ``fig``."""
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info("%s to %s", label, save_path)
    plt.close(fig)


def plot_2d_samples(samples: torch.Tensor, targets: torch.Tensor,
                    title: str = "Generated vs Target", save_path: Optional[str] = None,
                    xy_lim: float = 6.0):
    """Plot target, generated, and overlaid 2D point clouds side by side.

    Args:
        samples: Generated points; columns 0 and 1 are plotted.
        targets: Target points; columns 0 and 1 are plotted.
        title: Figure title.
        save_path: If given, the figure is written there. The figure is
            always closed afterwards.
        xy_lim: Each panel shows ``[-xy_lim, xy_lim]`` on both axes.
    """
    plt = _headless_pyplot()
    s = samples.detach().cpu().numpy()
    t = targets.detach().cpu().numpy()
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].scatter(t[:, 0], t[:, 1], s=3, alpha=0.5, c="blue")
    axes[0].set_title("Target Distribution")
    axes[1].scatter(s[:, 0], s[:, 1], s=3, alpha=0.5, c="red")
    axes[1].set_title("Generated Samples")
    axes[2].scatter(t[:, 0], t[:, 1], s=3, alpha=0.3, c="blue", label="Target")
    axes[2].scatter(s[:, 0], s[:, 1], s=3, alpha=0.3, c="red", label="Generated")
    axes[2].set_title("Overlay")
    axes[2].legend()
    for ax in axes:
        ax.set_xlim(-xy_lim, xy_lim)
        ax.set_ylim(-xy_lim, xy_lim)
        ax.set_aspect("equal")

    fig.suptitle(title)
    fig.tight_layout()
    _save_and_close(plt, fig, save_path, "Saved plot")


def plot_2d_trajectory(trajectory: List[torch.Tensor], targets: torch.Tensor,
                       title: str = "Flow Trajectory", save_path: Optional[str] = None,
                       max_particles: int = 200, xy_lim: float = 6.0):
    """Draw particle paths across flow steps on top of the target distribution.

    Args:
        trajectory: One tensor per step. Rows are particles, assumed 2D points,
            and every step must have at least as many rows as the first.
        targets: Target points drawn in the background.
        title: Figure title.
        save_path: If given, the figure is written there. The figure is
            always closed afterwards.
        max_particles: At most this many particles (the first rows) are drawn.
        xy_lim: The plot shows ``[-xy_lim, xy_lim]`` on both axes.

    Raises:
        ValueError: if ``trajectory`` is empty.
    """
    if len(trajectory) == 0:
        raise ValueError("trajectory must contain at least one step")
    plt = _headless_pyplot()
    from matplotlib.collections import LineCollection

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    t = targets.detach().cpu().numpy()
    ax.scatter(t[:, 0], t[:, 1], s=3, alpha=0.2, c="blue", label="Target")

    num_steps = len(trajectory)
    n = min(trajectory[0].shape[0], max_particles)
    # One host transfer per step (instead of one per particle per step);
    # paths has shape (num_steps, n, point_dim).
    paths = torch.stack([step[:n].detach().cpu() for step in trajectory]).numpy()
    if num_steps > 1 and n > 0:
        # Segment (particle i, step j) joins paths[j, i] -> paths[j + 1, i],
        # ordered particle-major so the tiled colour ramp lines up per particle.
        starts = paths[:-1].transpose(1, 0, 2)
        ends = paths[1:].transpose(1, 0, 2)
        segments = np.stack([starts, ends], axis=2).reshape(-1, 2, paths.shape[-1])
        step_colors = plt.cm.coolwarm(np.linspace(0, 1, num_steps - 1))
        colors = np.tile(step_colors, (n, 1))
        ax.add_collection(LineCollection(segments, colors=colors, linewidths=0.5, alpha=0.5))

    final = paths[-1]
    ax.scatter(final[:, 0], final[:, 1], s=5, c="red", alpha=0.5, label="Generated")
    ax.set_xlim(-xy_lim, xy_lim)
    ax.set_ylim(-xy_lim, xy_lim)
    ax.set_aspect("equal")
    ax.set_title(title)
    ax.legend()
    _save_and_close(plt, fig, save_path, "Saved trajectory plot")


def plot_image_grid(images: torch.Tensor, nrow: int = 8, title: str = "Generated Images",
                    save_path: Optional[str] = None):
    """Show the first ``nrow * nrow`` images, assumed in ``[-1, 1]``, as a square grid.

    Args:
        images: Image batch ``(N, C, H, W)``.
        nrow: Images per row (and number of rows).
        title: Axes title.
        save_path: If given, the figure is written there. The figure is
            always closed afterwards.
    """
    plt = _headless_pyplot()
    import torchvision.utils as vutils

    # detach() so tensors that still track gradients can be converted to numpy.
    grid = vutils.make_grid(images[:nrow * nrow].detach().cpu(), nrow=nrow,
                            normalize=True, value_range=(-1, 1))
    grid_np = grid.permute(1, 2, 0).numpy()
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(grid_np)
    ax.set_title(title)
    ax.axis("off")
    _save_and_close(plt, fig, save_path, "Saved image grid")