# nsgf-plusplus / evaluation.py
# rogermt's picture
# Upload evaluation.py
# da55996 verified
"""evaluation.py — Evaluation metrics for NSGF/NSGF++ experiments.
Implements:
- 2-Wasserstein distance (2D experiments)
- FID (Fréchet Inception Distance) for image experiments
- IS (Inception Score) for image experiments
- Visualization utilities
Reference: arXiv:2401.14069, Section 5, Appendix E
"""
import os
import logging
import numpy as np
import torch
import torch.nn as nn
from typing import Dict, Optional, List, Tuple
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def compute_w2_distance(samples: torch.Tensor, targets: torch.Tensor) -> float:
    """Return the 2-Wasserstein distance between two empirical point sets.

    Every point of each set receives the same probability mass. The POT
    (``ot``) library builds the squared-Euclidean cost matrix and solves the
    transport problem; the square root of the resulting cost is returned.

    Args:
        samples: Generated points; detached and moved to the CPU before use.
        targets: Reference points; detached and moved to the CPU before use.

    Returns:
        The distance as a Python float.
    """
    import ot

    source_pts = samples.detach().cpu().numpy()
    target_pts = targets.detach().cpu().numpy()

    cost = ot.dist(source_pts, target_pts, metric="sqeuclidean")

    # Uniform weights: each point carries 1 / (number of points in its set).
    n_source = len(source_pts)
    n_target = len(target_pts)
    source_mass = np.ones(n_source) / n_source
    target_mass = np.ones(n_target) / n_target

    squared_w2 = ot.emd2(source_mass, target_mass, cost)
    # Clamp at zero so a slightly negative solver result cannot reach sqrt.
    return float(np.sqrt(max(squared_w2, 0)))
class InceptionV3Features(nn.Module):
    """Frozen Inception V3 wrapper for FID/IS computation.

    The convolutional trunk of torchvision's ImageNet-pretrained Inception V3
    is re-assembled into a plain ``nn.Sequential`` (the auxiliary classifier
    and the dropout layer are not part of this pipeline), followed by global
    average pooling and the original fully connected classifier.

    All parameters are frozen and the module is pinned to evaluation mode so
    that normalisation layers always use their stored statistics.
    """

    def __init__(self, device: str = "cpu"):
        """Load the pretrained network and move it to *device*.

        Args:
            device: Device the wrapped layers are moved to.
        """
        super().__init__()
        import torchvision.models as models

        self.device = device
        # Prefer the `weights=` API; `pretrained=True` is the legacy spelling
        # and is only used when the weights enum is not available.
        weights_enum = getattr(models, "Inception_V3_Weights", None)
        if weights_enum is not None:
            inception = models.inception_v3(
                weights=weights_enum.IMAGENET1K_V1, transform_input=False
            )
        else:
            inception = models.inception_v3(pretrained=True, transform_input=False)
        inception.eval()
        self.blocks = nn.Sequential(
            inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3, nn.MaxPool2d(3, stride=2),
            inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3,
            nn.MaxPool2d(3, stride=2),
            inception.Mixed_5b, inception.Mixed_5c, inception.Mixed_5d,
            inception.Mixed_6a, inception.Mixed_6b, inception.Mixed_6c,
            inception.Mixed_6d, inception.Mixed_6e,
            inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = inception.fc
        self.to(device)
        for p in self.parameters():
            p.requires_grad_(False)
        # The wrapper itself is a fresh module (training=True by default);
        # make its own flag agree with the already-eval submodules.
        self.eval()

    def train(self, mode: bool = True):
        """Keep the extractor in evaluation mode regardless of *mode*.

        A metric network must not switch its normalisation layers to batch
        statistics, even if a parent module calls ``.train()``.
        """
        return super().train(False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(features, logits)`` for a batch of images.

        Args:
            x: Image batch indexed as (batch, channels, height, width).
                Assumed to be scaled to [-1, 1] — confirm against the data
                pipeline. Single-channel input is tiled to three channels and
                anything that is not 299x299 is bilinearly resized.

        Returns:
            A tuple of the pooled trunk activations (one row per image) and
            the classifier outputs computed from them.
        """
        if x.shape[2] != 299 or x.shape[3] != 299:
            x = torch.nn.functional.interpolate(
                x, size=(299, 299), mode="bilinear", align_corners=False
            )
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        # The [-1, 1] values are passed through unchanged. torchvision's
        # ported Inception V3 weights expect (pixel - 0.5) / 0.5 inputs (that
        # is what the model's own `transform_input` step converts to), and the
        # layers are called directly here, so no further rescaling applies.
        # Mapping to [0, 1] first would feed the stem the wrong value range.
        h = self.blocks(x)
        features = torch.flatten(self.avgpool(h), 1)
        logits = self.fc(features)
        return features, logits
def _frechet_distance(mu_a: np.ndarray, sigma_a: np.ndarray,
                      mu_b: np.ndarray, sigma_b: np.ndarray,
                      eps: float = 1e-6) -> float:
    """Return the Fréchet distance between two Gaussians given their moments."""
    from scipy import linalg

    diff = mu_a - mu_b
    # Called without the optional `disp` flag so the return value is always
    # the bare matrix rather than a (matrix, error-estimate) tuple.
    covmean = linalg.sqrtm(sigma_a @ sigma_b)
    if not np.isfinite(covmean).all():
        # A singular covariance product can yield inf/NaN entries; retry with
        # a small offset on both diagonals.
        logger.warning(
            "FID: matrix square root is not finite; adding %s to the covariance diagonals",
            eps,
        )
        jitter = np.eye(sigma_a.shape[0]) * eps
        covmean = linalg.sqrtm((sigma_a + jitter) @ (sigma_b + jitter))
    if np.iscomplexobj(covmean):
        # Only the real part contributes to the trace.
        covmean = covmean.real
    return float(
        diff @ diff + np.trace(sigma_a) + np.trace(sigma_b) - 2.0 * np.trace(covmean)
    )


def compute_fid(generated: torch.Tensor, real: torch.Tensor,
                device: str = "cpu", batch_size: int = 64) -> float:
    """Compute the Fréchet Inception Distance between two image sets.

    Pooled Inception V3 features are extracted for both sets in mini-batches,
    a Gaussian (mean and covariance) is fitted to each, and the Fréchet
    distance between the two Gaussians is returned.

    Args:
        generated: Generated image batch.
        real: Reference image batch.
        device: Device on which the Inception network runs.
        batch_size: Number of images per forward pass.

    Returns:
        The FID as a Python float.

    Raises:
        ValueError: If either set holds fewer than two images, because a
            covariance matrix cannot be estimated from a single sample.
    """
    if len(generated) < 2 or len(real) < 2:
        raise ValueError(
            "compute_fid needs at least two generated and two real images "
            f"(got {len(generated)} and {len(real)})"
        )

    model = InceptionV3Features(device)

    def get_features(images: torch.Tensor) -> np.ndarray:
        # Run the extractor batch by batch and gather the features on the CPU.
        feats = []
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size].to(device)
            f, _ = model(batch)
            feats.append(f.cpu().numpy())
        # float64 keeps the mean/covariance statistics in double precision.
        return np.concatenate(feats, axis=0).astype(np.float64)

    logger.info("Computing FID: extracting generated features...")
    feats_gen = get_features(generated)
    logger.info("Computing FID: extracting real features...")
    feats_real = get_features(real)

    mu_gen, sigma_gen = feats_gen.mean(0), np.cov(feats_gen, rowvar=False)
    mu_real, sigma_real = feats_real.mean(0), np.cov(feats_real, rowvar=False)
    return _frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
def _inception_score_from_probs(probs: np.ndarray, splits: int) -> Tuple[float, float]:
    """Return mean and std of exp(mean KL(p(y|x) || p(y))) over contiguous splits."""
    n = len(probs)
    # Never use more splits than samples: an empty split would yield NaN.
    splits = min(splits, n)
    scores = []
    for k in range(splits):
        # Proportional boundaries cover every sample, including the remainder
        # when n is not a multiple of `splits`.
        part = probs[k * n // splits:(k + 1) * n // splits]
        py = part.mean(axis=0, keepdims=True)
        kl = part * (np.log(part + 1e-10) - np.log(py + 1e-10))
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))


def compute_inception_score(images: torch.Tensor, device: str = "cpu",
                            batch_size: int = 64, splits: int = 10) -> Tuple[float, float]:
    """Compute the Inception Score of an image set.

    Class logits are obtained from the Inception V3 wrapper in mini-batches
    and turned into probabilities with a softmax. The images are divided into
    contiguous groups; for each group the score is the exponential of the
    average KL divergence between per-image class probabilities and the
    group's mean class probabilities.

    Args:
        images: Image batch to score.
        device: Device on which the Inception network runs.
        batch_size: Number of images per forward pass.
        splits: Number of groups. When there are fewer images than groups,
            one group per image is used instead.

    Returns:
        ``(mean, std)`` of the per-group scores.

    Raises:
        ValueError: If *images* is empty or *splits* is smaller than 1.
    """
    if len(images) == 0:
        raise ValueError("compute_inception_score needs at least one image")
    if splits < 1:
        raise ValueError(f"splits must be >= 1 (got {splits})")

    model = InceptionV3Features(device)
    all_logits = []
    for start in range(0, len(images), batch_size):
        batch = images[start:start + batch_size].to(device)
        _, logits = model(batch)
        all_logits.append(logits.cpu())
    probs = torch.softmax(torch.cat(all_logits, dim=0), dim=1).numpy()
    return _inception_score_from_probs(probs, splits)
class Evaluation:
    """Select and run the metrics that match the configured dataset.

    Datasets named ``"mnist"`` or ``"cifar10"`` are treated as image data and
    scored with FID and/or Inception Score; every other dataset name is
    scored with the 2-Wasserstein distance.
    """

    def __init__(self, config: dict, device: str = "cpu"):
        """Store the configuration and derive the dataset kind.

        Args:
            config: Experiment configuration. ``"dataset"`` (default
                ``"8gaussians"``) and the optional ``"evaluation"`` section
                are read from it.
            device: Device forwarded to the image metrics.
        """
        self.config = config
        self.device = device
        dataset = config.get("dataset", "8gaussians")
        self.dataset_name = dataset
        self.is_image = dataset in ("mnist", "cifar10")

    def evaluate(self, generated: torch.Tensor, real: torch.Tensor) -> Dict[str, float]:
        """Score *generated* against *real* and return the metrics by name.

        Returns:
            ``{"w2": ...}`` for non-image datasets. For image datasets the
            dict holds ``"fid"`` and/or ``"is_mean"``/``"is_std"``, depending
            on ``config["evaluation"]["metrics"]`` (default ``["fid"]``).
        """
        metrics = {}

        # Non-image data: a single Wasserstein distance, then done.
        if not self.is_image:
            w2 = compute_w2_distance(generated, real)
            metrics["w2"] = w2
            logger.info(f"W2 distance: {w2:.4f}")
            return metrics

        requested = self.config.get("evaluation", {}).get("metrics", ["fid"])

        if "fid" in requested:
            logger.info("Computing FID...")
            metrics["fid"] = compute_fid(generated, real, self.device)
            logger.info(f"FID: {metrics['fid']:.2f}")

        if "is" in requested:
            logger.info("Computing Inception Score...")
            is_mean, is_std = compute_inception_score(generated, self.device)
            metrics["is_mean"] = is_mean
            metrics["is_std"] = is_std
            logger.info(f"IS: {is_mean:.2f} ± {is_std:.2f}")

        return metrics
def plot_2d_samples(samples: torch.Tensor, targets: torch.Tensor,
                    title: str = "Generated vs Target", save_path: Optional[str] = None):
    """Draw target, generated and overlaid scatter plots side by side.

    The first two columns of each point set are plotted on three panels that
    share fixed [-6, 6] limits and an equal aspect ratio. The figure is
    written to *save_path* when one is given and closed afterwards.

    Args:
        samples: Generated points (drawn in red).
        targets: Reference points (drawn in blue).
        title: Figure-level title.
        save_path: Optional output file; its directory is created if missing.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    gen_xy = samples.detach().cpu().numpy()
    tgt_xy = targets.detach().cpu().numpy()
    target_ax, generated_ax, overlay_ax = axes

    def frame(panel, heading):
        # Formatting shared by all three panels.
        panel.set_title(heading)
        panel.set_xlim(-6, 6)
        panel.set_ylim(-6, 6)
        panel.set_aspect("equal")

    target_ax.scatter(tgt_xy[:, 0], tgt_xy[:, 1], s=3, alpha=0.5, c="blue")
    frame(target_ax, "Target Distribution")

    generated_ax.scatter(gen_xy[:, 0], gen_xy[:, 1], s=3, alpha=0.5, c="red")
    frame(generated_ax, "Generated Samples")

    overlay_ax.scatter(tgt_xy[:, 0], tgt_xy[:, 1], s=3, alpha=0.3, c="blue", label="Target")
    overlay_ax.scatter(gen_xy[:, 0], gen_xy[:, 1], s=3, alpha=0.3, c="red", label="Generated")
    frame(overlay_ax, "Overlay")
    overlay_ax.legend()

    plt.suptitle(title)
    plt.tight_layout()
    if save_path:
        out_dir = os.path.dirname(save_path) or "."
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info(f"Saved plot to {save_path}")
    plt.close()
def plot_2d_trajectory(trajectory: List[torch.Tensor], targets: torch.Tensor,
                       title: str = "Flow Trajectory", save_path: Optional[str] = None,
                       max_particles: int = 200):
    """Plot particle paths through a flow on top of the target points.

    For up to *max_particles* particles, the positions recorded at successive
    steps are joined by line segments coloured with the ``coolwarm`` colormap
    from the first step to the last. The particles' final positions are drawn
    in red and the targets in blue. The figure is written to *save_path* when
    one is given and is always closed before returning.

    Args:
        trajectory: One tensor of particle positions per step; all steps are
            assumed to hold the same particles in the same order — confirm
            against the caller.
        targets: Reference points.
        title: Axes title.
        save_path: Optional output file; its directory is created if missing.
        max_particles: Upper bound on the number of particle paths drawn.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    try:
        t = targets.detach().cpu().numpy()
        ax.scatter(t[:, 0], t[:, 1], s=3, alpha=0.2, c="blue", label="Target")

        n = min(trajectory[0].shape[0], max_particles)
        # One host transfer per step rather than one per (particle, step).
        # Resulting layout: (steps, particles, coordinates).
        paths = np.stack(
            [step[:n].detach().cpu().numpy() for step in trajectory], axis=0
        )
        num_segments = paths.shape[0] - 1

        if num_segments > 0 and n > 0:
            # Pair each position with its successor, then order the pairs
            # particle by particle so the draw order matches one collection
            # per particle.
            pairs = np.stack([paths[:-1], paths[1:]], axis=2)
            segments = pairs.transpose(1, 0, 2, 3).reshape(-1, 2, paths.shape[2])
            gradient = plt.cm.coolwarm(np.linspace(0, 1, num_segments))
            colors = np.tile(gradient, (n, 1))
            ax.add_collection(
                LineCollection(segments, colors=colors, linewidths=0.5, alpha=0.5)
            )

        final = paths[-1]
        ax.scatter(final[:, 0], final[:, 1], s=5, c="red", alpha=0.5, label="Generated")
        ax.set_xlim(-6, 6)
        ax.set_ylim(-6, 6)
        ax.set_aspect("equal")
        ax.set_title(title)
        ax.legend()

        if save_path:
            os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
            plt.savefig(save_path, dpi=150, bbox_inches="tight")
            logger.info(f"Saved trajectory plot to {save_path}")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
def plot_image_grid(images: torch.Tensor, nrow: int = 8,
                    title: str = "Generated Images", save_path: Optional[str] = None):
    """Render the first ``nrow * nrow`` images as a single grid figure.

    The grid is built with torchvision's ``make_grid`` using
    ``normalize=True`` and ``value_range=(-1, 1)``, its first axis is moved
    to the end, and the result is shown with ``imshow`` on an axis-free
    plot. The figure is written to *save_path* when one is given and closed
    afterwards.

    Args:
        images: Image batch; only the first ``nrow * nrow`` entries are used.
        nrow: Value passed to ``make_grid`` as ``nrow``.
        title: Axes title.
        save_path: Optional output file; its directory is created if missing.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import torchvision.utils as vutils

    tile_limit = nrow * nrow
    grid = vutils.make_grid(
        images[:tile_limit], nrow=nrow, normalize=True, value_range=(-1, 1)
    )
    # imshow receives the grid with its leading axis moved to the end.
    grid_last_axis = grid.permute(1, 2, 0).cpu().numpy()

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(grid_last_axis)
    ax.set_title(title)
    ax.axis("off")

    if save_path:
        out_dir = os.path.dirname(save_path) or "."
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info(f"Saved image grid to {save_path}")
    plt.close()