dreamlessx commited on
Commit
6421899
·
verified ·
1 Parent(s): d2c39e0

Upload landmarkdiff/fid.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/fid.py +232 -0
landmarkdiff/fid.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Self-contained FID computation using InceptionV3 feature extraction.
2
+
3
+ Avoids dependency on torch-fidelity by implementing FID directly.
4
+ Supports GPU acceleration, batched processing, and caching.
5
+
6
+ Usage:
7
+ from landmarkdiff.fid import compute_fid_from_dirs, compute_fid_from_arrays
8
+
9
+ # From directories
10
+ fid = compute_fid_from_dirs("path/to/real", "path/to/generated")
11
+
12
+ # From numpy arrays
13
+ fid = compute_fid_from_arrays(real_images, generated_images)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from pathlib import Path
19
+
20
+ import numpy as np
21
+
22
+ try:
23
+ import torch
24
+ import torch.nn as nn
25
+ from torch.utils.data import DataLoader, Dataset
26
+ HAS_TORCH = True
27
+ except ImportError:
28
+ HAS_TORCH = False
29
+
30
+
31
+ def _load_inception_v3():
32
+ """Load InceptionV3 with pool3 features (2048-dim)."""
33
+ from torchvision.models import inception_v3, Inception_V3_Weights
34
+
35
+ model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
36
+ # We want features from the avg pool layer (2048-dim)
37
+ # Remove the final FC layer
38
+ model.fc = nn.Identity()
39
+ model.eval()
40
+ return model
41
+
42
+
43
+ class ImageFolderDataset(Dataset):
44
+ """Simple dataset that loads images from a directory."""
45
+
46
+ def __init__(self, directory: str | Path, image_size: int = 299):
47
+ self.directory = Path(directory)
48
+ exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
49
+ self.files = sorted(
50
+ f for f in self.directory.iterdir()
51
+ if f.suffix.lower() in exts and f.is_file()
52
+ )
53
+ self.image_size = image_size
54
+
55
+ def __len__(self):
56
+ return len(self.files)
57
+
58
+ def __getitem__(self, idx):
59
+ import cv2
60
+ img = cv2.imread(str(self.files[idx]))
61
+ if img is None:
62
+ # Return zeros if image can't be loaded
63
+ return torch.zeros(3, self.image_size, self.image_size)
64
+ img = cv2.resize(img, (self.image_size, self.image_size))
65
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
66
+ # Normalize to [0, 1] then ImageNet normalize
67
+ t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
68
+ t = _imagenet_normalize(t)
69
+ return t
70
+
71
+
72
+ class NumpyArrayDataset(Dataset):
73
+ """Dataset wrapping a list of numpy arrays."""
74
+
75
+ def __init__(self, images: list[np.ndarray], image_size: int = 299):
76
+ self.images = images
77
+ self.image_size = image_size
78
+
79
+ def __len__(self):
80
+ return len(self.images)
81
+
82
+ def __getitem__(self, idx):
83
+ import cv2
84
+ img = self.images[idx]
85
+ if img.shape[:2] != (self.image_size, self.image_size):
86
+ img = cv2.resize(img, (self.image_size, self.image_size))
87
+ if img.shape[2] == 3:
88
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
89
+ t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
90
+ t = _imagenet_normalize(t)
91
+ return t
92
+
93
+
94
+ def _imagenet_normalize(t: "torch.Tensor") -> "torch.Tensor":
95
+ """Apply ImageNet normalization."""
96
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
97
+ std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
98
+ return (t - mean) / std
99
+
100
+
101
+ @torch.no_grad()
102
+ def _extract_features(
103
+ model: nn.Module,
104
+ dataloader: DataLoader,
105
+ device: torch.device,
106
+ ) -> np.ndarray:
107
+ """Extract InceptionV3 pool3 features from a dataloader."""
108
+ features = []
109
+ for batch in dataloader:
110
+ batch = batch.to(device)
111
+ feat = model(batch)
112
+ if isinstance(feat, tuple):
113
+ feat = feat[0]
114
+ features.append(feat.cpu().numpy())
115
+ return np.concatenate(features, axis=0)
116
+
117
+
118
+ def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
119
+ """Compute mean and covariance of feature vectors."""
120
+ mu = np.mean(features, axis=0)
121
+ sigma = np.cov(features, rowvar=False)
122
+ return mu, sigma
123
+
124
+
125
+ def _calculate_fid(
126
+ mu1: np.ndarray, sigma1: np.ndarray,
127
+ mu2: np.ndarray, sigma2: np.ndarray,
128
+ ) -> float:
129
+ """Calculate FID given two sets of statistics.
130
+
131
+ FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrt(sigma1*sigma2))
132
+ """
133
+ from scipy.linalg import sqrtm
134
+
135
+ diff = mu1 - mu2
136
+ covmean = sqrtm(sigma1 @ sigma2)
137
+
138
+ # Handle numerical instability
139
+ if np.iscomplexobj(covmean):
140
+ covmean = covmean.real
141
+
142
+ fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
143
+ return float(fid)
144
+
145
+
146
+ def compute_fid_from_dirs(
147
+ real_dir: str | Path,
148
+ generated_dir: str | Path,
149
+ batch_size: int = 32,
150
+ num_workers: int = 4,
151
+ device: str | None = None,
152
+ ) -> float:
153
+ """Compute FID between two directories of images.
154
+
155
+ Args:
156
+ real_dir: Path to real images.
157
+ generated_dir: Path to generated images.
158
+ batch_size: Batch size for feature extraction.
159
+ num_workers: DataLoader workers.
160
+ device: "cuda" or "cpu". Auto-detects if None.
161
+
162
+ Returns:
163
+ FID score (lower = better).
164
+ """
165
+ if not HAS_TORCH:
166
+ raise ImportError("PyTorch required for FID computation")
167
+
168
+ if device is None:
169
+ device = "cuda" if torch.cuda.is_available() else "cpu"
170
+ dev = torch.device(device)
171
+
172
+ model = _load_inception_v3().to(dev)
173
+
174
+ real_ds = ImageFolderDataset(real_dir)
175
+ gen_ds = ImageFolderDataset(generated_dir)
176
+
177
+ if len(real_ds) == 0 or len(gen_ds) == 0:
178
+ raise ValueError("Need at least 1 image in each directory")
179
+
180
+ real_loader = DataLoader(real_ds, batch_size=batch_size,
181
+ num_workers=num_workers, pin_memory=True)
182
+ gen_loader = DataLoader(gen_ds, batch_size=batch_size,
183
+ num_workers=num_workers, pin_memory=True)
184
+
185
+ real_features = _extract_features(model, real_loader, dev)
186
+ gen_features = _extract_features(model, gen_loader, dev)
187
+
188
+ mu_real, sigma_real = _compute_statistics(real_features)
189
+ mu_gen, sigma_gen = _compute_statistics(gen_features)
190
+
191
+ return _calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)
192
+
193
+
194
+ def compute_fid_from_arrays(
195
+ real_images: list[np.ndarray],
196
+ generated_images: list[np.ndarray],
197
+ batch_size: int = 32,
198
+ device: str | None = None,
199
+ ) -> float:
200
+ """Compute FID from lists of numpy arrays.
201
+
202
+ Args:
203
+ real_images: List of (H, W, 3) BGR uint8 images.
204
+ generated_images: List of (H, W, 3) BGR uint8 images.
205
+ batch_size: Batch size for feature extraction.
206
+ device: "cuda" or "cpu".
207
+
208
+ Returns:
209
+ FID score (lower = better).
210
+ """
211
+ if not HAS_TORCH:
212
+ raise ImportError("PyTorch required for FID computation")
213
+
214
+ if device is None:
215
+ device = "cuda" if torch.cuda.is_available() else "cpu"
216
+ dev = torch.device(device)
217
+
218
+ model = _load_inception_v3().to(dev)
219
+
220
+ real_ds = NumpyArrayDataset(real_images)
221
+ gen_ds = NumpyArrayDataset(generated_images)
222
+
223
+ real_loader = DataLoader(real_ds, batch_size=batch_size, num_workers=0)
224
+ gen_loader = DataLoader(gen_ds, batch_size=batch_size, num_workers=0)
225
+
226
+ real_features = _extract_features(model, real_loader, dev)
227
+ gen_features = _extract_features(model, gen_loader, dev)
228
+
229
+ mu_real, sigma_real = _compute_statistics(real_features)
230
+ mu_gen, sigma_gen = _compute_statistics(gen_features)
231
+
232
+ return _calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)