Spaces:
Sleeping
Sleeping
Upload landmarkdiff/face_verifier.py with huggingface_hub
Browse files- landmarkdiff/face_verifier.py +166 -43
landmarkdiff/face_verifier.py
CHANGED
|
@@ -1,7 +1,16 @@
|
|
| 1 |
-
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
|
@@ -21,7 +30,7 @@ import numpy as np
|
|
| 21 |
|
| 22 |
@dataclass
|
| 23 |
class DistortionReport:
|
| 24 |
-
"""
|
| 25 |
|
| 26 |
# Overall quality score (0-100, higher = better)
|
| 27 |
quality_score: float = 0.0
|
|
@@ -63,7 +72,7 @@ class DistortionReport:
|
|
| 63 |
|
| 64 |
@dataclass
|
| 65 |
class RestorationResult:
|
| 66 |
-
"""
|
| 67 |
|
| 68 |
restored: np.ndarray # Restored BGR image
|
| 69 |
original: np.ndarray # Original BGR image
|
|
@@ -81,14 +90,14 @@ class RestorationResult:
|
|
| 81 |
f"Improvement: +{self.improvement:.1f}",
|
| 82 |
f"Identity Sim: {self.identity_similarity:.3f}",
|
| 83 |
f"Identity OK: {self.identity_preserved}",
|
| 84 |
-
f"Stages Used: {'
|
| 85 |
]
|
| 86 |
return "\n".join(lines)
|
| 87 |
|
| 88 |
|
| 89 |
@dataclass
|
| 90 |
class BatchVerificationReport:
|
| 91 |
-
"""
|
| 92 |
|
| 93 |
total: int = 0
|
| 94 |
passed: int = 0 # Good quality, no fix needed
|
|
@@ -125,7 +134,11 @@ class BatchVerificationReport:
|
|
| 125 |
# ---------------------------------------------------------------------------
|
| 126 |
|
| 127 |
def detect_blur(image: np.ndarray) -> float:
|
| 128 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
|
| 130 |
|
| 131 |
# Laplacian variance (primary metric)
|
|
@@ -144,19 +157,27 @@ def detect_blur(image: np.ndarray) -> float:
|
|
| 144 |
|
| 145 |
|
| 146 |
def detect_noise(image: np.ndarray) -> float:
|
| 147 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
|
| 149 |
|
| 150 |
# Robust noise estimation via MAD of Laplacian
|
| 151 |
lap = cv2.Laplacian(gray.astype(np.float64), cv2.CV_64F)
|
| 152 |
-
sigma_est = np.median(np.abs(lap)) * 1.4826 # MAD
|
| 153 |
|
| 154 |
# Normalize: sigma > 20 is very noisy
|
| 155 |
return float(np.clip(sigma_est / 25.0, 0, 1))
|
| 156 |
|
| 157 |
|
| 158 |
def detect_compression_artifacts(image: np.ndarray) -> float:
|
| 159 |
-
"""JPEG
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
|
| 161 |
h, w = gray.shape
|
| 162 |
|
|
@@ -188,11 +209,18 @@ def detect_compression_artifacts(image: np.ndarray) -> float:
|
|
| 188 |
|
| 189 |
|
| 190 |
def detect_oversmoothing(image: np.ndarray) -> float:
|
| 191 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
|
| 193 |
h, w = gray.shape
|
| 194 |
|
| 195 |
# Focus on face center region (avoid background)
|
|
|
|
|
|
|
| 196 |
roi = gray[h // 4:3 * h // 4, w // 4:3 * w // 4]
|
| 197 |
|
| 198 |
# Texture energy: variance of high-pass filtered image
|
|
@@ -216,7 +244,12 @@ def detect_oversmoothing(image: np.ndarray) -> float:
|
|
| 216 |
|
| 217 |
|
| 218 |
def detect_color_cast(image: np.ndarray) -> float:
|
| 219 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)
|
| 221 |
h, w = image.shape[:2]
|
| 222 |
|
|
@@ -242,7 +275,11 @@ def detect_color_cast(image: np.ndarray) -> float:
|
|
| 242 |
|
| 243 |
|
| 244 |
def detect_geometric_distortion(image: np.ndarray) -> float:
|
| 245 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
try:
|
| 247 |
from landmarkdiff.landmarks import extract_landmarks
|
| 248 |
except ImportError:
|
|
@@ -255,6 +292,9 @@ def detect_geometric_distortion(image: np.ndarray) -> float:
|
|
| 255 |
coords = face.pixel_coords
|
| 256 |
h, w = image.shape[:2]
|
| 257 |
|
|
|
|
|
|
|
|
|
|
| 258 |
# Key ratios that should be anatomically consistent
|
| 259 |
left_eye = coords[33]
|
| 260 |
right_eye = coords[263]
|
|
@@ -288,7 +328,10 @@ def detect_geometric_distortion(image: np.ndarray) -> float:
|
|
| 288 |
|
| 289 |
|
| 290 |
def detect_lighting_issues(image: np.ndarray) -> float:
|
| 291 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 292 |
lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
|
| 293 |
l_channel = lab[:, :, 0]
|
| 294 |
|
|
@@ -298,7 +341,10 @@ def detect_lighting_issues(image: np.ndarray) -> float:
|
|
| 298 |
|
| 299 |
# Check for bimodal distribution (harsh shadows)
|
| 300 |
hist = cv2.calcHist([l_channel], [0], None, [256], [0, 256]).flatten()
|
| 301 |
-
|
|
|
|
|
|
|
|
|
|
| 302 |
# Measure how spread out the histogram is
|
| 303 |
entropy = -np.sum(hist[hist > 0] * np.log2(hist[hist > 0] + 1e-10))
|
| 304 |
# Low entropy = concentrated = potentially problematic
|
|
@@ -309,7 +355,11 @@ def detect_lighting_issues(image: np.ndarray) -> float:
|
|
| 309 |
|
| 310 |
|
| 311 |
def analyze_distortions(image: np.ndarray) -> DistortionReport:
|
| 312 |
-
"""Run
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
blur = detect_blur(image)
|
| 314 |
noise = detect_noise(image)
|
| 315 |
compression = detect_compression_artifacts(image)
|
|
@@ -318,7 +368,7 @@ def analyze_distortions(image: np.ndarray) -> DistortionReport:
|
|
| 318 |
geometric = detect_geometric_distortion(image)
|
| 319 |
lighting = detect_lighting_issues(image)
|
| 320 |
|
| 321 |
-
# weighted combination (inverted
|
| 322 |
weighted = (
|
| 323 |
0.25 * blur
|
| 324 |
+ 0.15 * noise
|
|
@@ -380,7 +430,10 @@ _FACE_QUALITY_NET = None
|
|
| 380 |
|
| 381 |
|
| 382 |
def _get_face_quality_scorer():
|
| 383 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 384 |
global _FACE_QUALITY_NET
|
| 385 |
if _FACE_QUALITY_NET is not None:
|
| 386 |
return _FACE_QUALITY_NET
|
|
@@ -396,7 +449,12 @@ def _get_face_quality_scorer():
|
|
| 396 |
|
| 397 |
|
| 398 |
def neural_quality_score(image: np.ndarray) -> float:
|
| 399 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
# Try neural scorer
|
| 401 |
scorer = _get_face_quality_scorer()
|
| 402 |
if scorer is not None:
|
|
@@ -429,14 +487,31 @@ def restore_face(
|
|
| 429 |
mode: str = "auto",
|
| 430 |
codeformer_fidelity: float = 0.7,
|
| 431 |
) -> tuple[np.ndarray, list[str]]:
|
| 432 |
-
"""Cascaded neural face restoration.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
if distortion is None:
|
| 434 |
distortion = analyze_distortions(image)
|
| 435 |
|
| 436 |
result = image.copy()
|
| 437 |
stages = []
|
| 438 |
|
| 439 |
-
#
|
| 440 |
if distortion.color_cast_score > 0.25:
|
| 441 |
result = _fix_color_cast(result)
|
| 442 |
stages.append("color_correction")
|
|
@@ -518,6 +593,9 @@ def _try_gfpgan(image: np.ndarray) -> np.ndarray | None:
|
|
| 518 |
return None
|
| 519 |
|
| 520 |
|
|
|
|
|
|
|
|
|
|
| 521 |
def _try_realesrgan(image: np.ndarray) -> np.ndarray | None:
|
| 522 |
"""Try Real-ESRGAN 2x upscale + downsample. Returns None if unavailable."""
|
| 523 |
try:
|
|
@@ -525,20 +603,22 @@ def _try_realesrgan(image: np.ndarray) -> np.ndarray | None:
|
|
| 525 |
from basicsr.archs.rrdbnet_arch import RRDBNet
|
| 526 |
import torch
|
| 527 |
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
|
|
|
|
|
|
| 542 |
|
| 543 |
# Downsample to 512x512 for pipeline consistency
|
| 544 |
enhanced = cv2.resize(enhanced, (512, 512), interpolation=cv2.INTER_LANCZOS4)
|
|
@@ -604,7 +684,10 @@ def _get_arcface():
|
|
| 604 |
|
| 605 |
|
| 606 |
def get_face_embedding(image: np.ndarray) -> np.ndarray | None:
|
| 607 |
-
"""ArcFace 512-d embedding
|
|
|
|
|
|
|
|
|
|
| 608 |
app = _get_arcface()
|
| 609 |
if app is None:
|
| 610 |
return None
|
|
@@ -623,12 +706,16 @@ def verify_identity(
|
|
| 623 |
restored: np.ndarray,
|
| 624 |
threshold: float = 0.6,
|
| 625 |
) -> tuple[float, bool]:
|
| 626 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 627 |
emb_orig = get_face_embedding(original)
|
| 628 |
emb_rest = get_face_embedding(restored)
|
| 629 |
|
| 630 |
if emb_orig is None or emb_rest is None:
|
| 631 |
-
return -1.0, True #
|
| 632 |
|
| 633 |
sim = float(np.dot(emb_orig, emb_rest) / (
|
| 634 |
np.linalg.norm(emb_orig) * np.linalg.norm(emb_rest) + 1e-8
|
|
@@ -648,13 +735,30 @@ def verify_and_restore(
|
|
| 648 |
restore_mode: str = "auto",
|
| 649 |
codeformer_fidelity: float = 0.7,
|
| 650 |
) -> RestorationResult:
|
| 651 |
-
"""Full pipeline: analyze
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 652 |
# Step 1: Analyze distortions
|
| 653 |
report = analyze_distortions(image)
|
| 654 |
|
| 655 |
# Step 2: Decide if restoration needed
|
| 656 |
if report.quality_score >= quality_threshold and report.severity in ("none", "mild"):
|
| 657 |
-
#
|
| 658 |
return RestorationResult(
|
| 659 |
restored=image.copy(),
|
| 660 |
original=image.copy(),
|
|
@@ -718,7 +822,26 @@ def verify_batch(
|
|
| 718 |
save_rejected: bool = False,
|
| 719 |
extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp", ".bmp"),
|
| 720 |
) -> BatchVerificationReport:
|
| 721 |
-
"""Process a directory of face images: analyze, restore, verify, sort.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 722 |
image_path = Path(image_dir)
|
| 723 |
if output_dir is None:
|
| 724 |
out_path = image_path.parent / f"{image_path.name}_verified"
|
|
|
|
| 1 |
+
"""Neural face verification, distortion detection, and restoration pipeline.
|
| 2 |
+
|
| 3 |
+
End-to-end system that:
|
| 4 |
+
1. Detects face distortions (blur, beauty filters, compression, warping, etc.)
|
| 5 |
+
2. Classifies distortion type and severity using no-reference quality metrics
|
| 6 |
+
3. Restores faces using cascaded neural networks (CodeFormer β GFPGAN β Real-ESRGAN)
|
| 7 |
+
4. Verifies output identity matches input via ArcFace embeddings
|
| 8 |
+
5. Scores output realism using learned perceptual metrics
|
| 9 |
+
|
| 10 |
+
Designed for:
|
| 11 |
+
- Cleaning scraped training data (reject/fix bad images before pair generation)
|
| 12 |
+
- Post-diffusion quality gate (ensure generated faces pass realism threshold)
|
| 13 |
+
- Filter removal (undo Snapchat/Instagram beauty filters for clinical use)
|
| 14 |
"""
|
| 15 |
|
| 16 |
from __future__ import annotations
|
|
|
|
| 30 |
|
| 31 |
@dataclass
|
| 32 |
class DistortionReport:
|
| 33 |
+
"""Analysis of detected distortions in a face image."""
|
| 34 |
|
| 35 |
# Overall quality score (0-100, higher = better)
|
| 36 |
quality_score: float = 0.0
|
|
|
|
| 72 |
|
| 73 |
@dataclass
|
| 74 |
class RestorationResult:
|
| 75 |
+
"""Result of neural face restoration pipeline."""
|
| 76 |
|
| 77 |
restored: np.ndarray # Restored BGR image
|
| 78 |
original: np.ndarray # Original BGR image
|
|
|
|
| 90 |
f"Improvement: +{self.improvement:.1f}",
|
| 91 |
f"Identity Sim: {self.identity_similarity:.3f}",
|
| 92 |
f"Identity OK: {self.identity_preserved}",
|
| 93 |
+
f"Stages Used: {' β '.join(self.restoration_stages) or 'none'}",
|
| 94 |
]
|
| 95 |
return "\n".join(lines)
|
| 96 |
|
| 97 |
|
| 98 |
@dataclass
|
| 99 |
class BatchVerificationReport:
|
| 100 |
+
"""Summary of batch face verification/restoration."""
|
| 101 |
|
| 102 |
total: int = 0
|
| 103 |
passed: int = 0 # Good quality, no fix needed
|
|
|
|
| 134 |
# ---------------------------------------------------------------------------
|
| 135 |
|
| 136 |
def detect_blur(image: np.ndarray) -> float:
|
| 137 |
+
"""Detect blur using Laplacian variance.
|
| 138 |
+
|
| 139 |
+
Low variance = blurry. We normalize to 0-1 where 1 = very blurry.
|
| 140 |
+
Uses both Laplacian variance and gradient magnitude for robustness.
|
| 141 |
+
"""
|
| 142 |
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
|
| 143 |
|
| 144 |
# Laplacian variance (primary metric)
|
|
|
|
| 157 |
|
| 158 |
|
| 159 |
def detect_noise(image: np.ndarray) -> float:
|
| 160 |
+
"""Detect image noise level.
|
| 161 |
+
|
| 162 |
+
Estimates noise by measuring high-frequency energy in smooth regions.
|
| 163 |
+
Uses the median absolute deviation of the Laplacian (robust estimator).
|
| 164 |
+
"""
|
| 165 |
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
|
| 166 |
|
| 167 |
# Robust noise estimation via MAD of Laplacian
|
| 168 |
lap = cv2.Laplacian(gray.astype(np.float64), cv2.CV_64F)
|
| 169 |
+
sigma_est = np.median(np.abs(lap)) * 1.4826 # MAD β std conversion
|
| 170 |
|
| 171 |
# Normalize: sigma > 20 is very noisy
|
| 172 |
return float(np.clip(sigma_est / 25.0, 0, 1))
|
| 173 |
|
| 174 |
|
| 175 |
def detect_compression_artifacts(image: np.ndarray) -> float:
|
| 176 |
+
"""Detect JPEG compression block artifacts.
|
| 177 |
+
|
| 178 |
+
Measures energy at 8x8 block boundaries (JPEG DCT block size).
|
| 179 |
+
High boundary energy relative to interior = compression artifacts.
|
| 180 |
+
"""
|
| 181 |
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
|
| 182 |
h, w = gray.shape
|
| 183 |
|
|
|
|
| 209 |
|
| 210 |
|
| 211 |
def detect_oversmoothing(image: np.ndarray) -> float:
|
| 212 |
+
"""Detect beauty filter / airbrushed skin (oversmoothing).
|
| 213 |
+
|
| 214 |
+
Beauty filters remove skin texture while preserving edges. We detect
|
| 215 |
+
this by measuring the ratio of edge energy to texture energy.
|
| 216 |
+
High edge / low texture = beauty filtered.
|
| 217 |
+
"""
|
| 218 |
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
|
| 219 |
h, w = gray.shape
|
| 220 |
|
| 221 |
# Focus on face center region (avoid background)
|
| 222 |
+
if h < 8 or w < 8:
|
| 223 |
+
return 0.0 # Too small to analyze
|
| 224 |
roi = gray[h // 4:3 * h // 4, w // 4:3 * w // 4]
|
| 225 |
|
| 226 |
# Texture energy: variance of high-pass filtered image
|
|
|
|
| 244 |
|
| 245 |
|
| 246 |
def detect_color_cast(image: np.ndarray) -> float:
|
| 247 |
+
"""Detect unnatural color cast (Instagram-style filters).
|
| 248 |
+
|
| 249 |
+
Measures deviation of average A/B channels in LAB space from
|
| 250 |
+
neutral. Natural skin has consistent LAB distributions; filtered
|
| 251 |
+
images shift these channels.
|
| 252 |
+
"""
|
| 253 |
lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)
|
| 254 |
h, w = image.shape[:2]
|
| 255 |
|
|
|
|
| 275 |
|
| 276 |
|
| 277 |
def detect_geometric_distortion(image: np.ndarray) -> float:
|
| 278 |
+
"""Detect geometric face distortion (warping filters, lens distortion).
|
| 279 |
+
|
| 280 |
+
Uses MediaPipe landmarks to check face proportions against anatomical
|
| 281 |
+
norms. Distorted faces have abnormal inter-ocular / face-width ratios.
|
| 282 |
+
"""
|
| 283 |
try:
|
| 284 |
from landmarkdiff.landmarks import extract_landmarks
|
| 285 |
except ImportError:
|
|
|
|
| 292 |
coords = face.pixel_coords
|
| 293 |
h, w = image.shape[:2]
|
| 294 |
|
| 295 |
+
if len(coords) < 478:
|
| 296 |
+
return 0.5 # Incomplete landmark set
|
| 297 |
+
|
| 298 |
# Key ratios that should be anatomically consistent
|
| 299 |
left_eye = coords[33]
|
| 300 |
right_eye = coords[263]
|
|
|
|
| 328 |
|
| 329 |
|
| 330 |
def detect_lighting_issues(image: np.ndarray) -> float:
|
| 331 |
+
"""Detect over/under exposure and harsh lighting.
|
| 332 |
+
|
| 333 |
+
Checks luminance histogram for clipping and uneven distribution.
|
| 334 |
+
"""
|
| 335 |
lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
|
| 336 |
l_channel = lab[:, :, 0]
|
| 337 |
|
|
|
|
| 341 |
|
| 342 |
# Check for bimodal distribution (harsh shadows)
|
| 343 |
hist = cv2.calcHist([l_channel], [0], None, [256], [0, 256]).flatten()
|
| 344 |
+
hist_sum = hist.sum()
|
| 345 |
+
if hist_sum < 1e-10:
|
| 346 |
+
return 0.0
|
| 347 |
+
hist = hist / hist_sum
|
| 348 |
# Measure how spread out the histogram is
|
| 349 |
entropy = -np.sum(hist[hist > 0] * np.log2(hist[hist > 0] + 1e-10))
|
| 350 |
# Low entropy = concentrated = potentially problematic
|
|
|
|
| 355 |
|
| 356 |
|
| 357 |
def analyze_distortions(image: np.ndarray) -> DistortionReport:
|
| 358 |
+
"""Run full distortion analysis on a face image.
|
| 359 |
+
|
| 360 |
+
Combines all detection methods into a comprehensive report with
|
| 361 |
+
quality score, primary distortion classification, and severity.
|
| 362 |
+
"""
|
| 363 |
blur = detect_blur(image)
|
| 364 |
noise = detect_noise(image)
|
| 365 |
compression = detect_compression_artifacts(image)
|
|
|
|
| 368 |
geometric = detect_geometric_distortion(image)
|
| 369 |
lighting = detect_lighting_issues(image)
|
| 370 |
|
| 371 |
+
# Overall quality: weighted combination (inverted β 100 = perfect)
|
| 372 |
weighted = (
|
| 373 |
0.25 * blur
|
| 374 |
+ 0.15 * noise
|
|
|
|
| 430 |
|
| 431 |
|
| 432 |
def _get_face_quality_scorer():
|
| 433 |
+
"""Get or create singleton face quality assessment model.
|
| 434 |
+
|
| 435 |
+
Uses FaceXLib's quality scorer or falls back to BRISQUE-style features.
|
| 436 |
+
"""
|
| 437 |
global _FACE_QUALITY_NET
|
| 438 |
if _FACE_QUALITY_NET is not None:
|
| 439 |
return _FACE_QUALITY_NET
|
|
|
|
| 449 |
|
| 450 |
|
| 451 |
def neural_quality_score(image: np.ndarray) -> float:
|
| 452 |
+
"""Score face quality using neural network (0-100, higher = better).
|
| 453 |
+
|
| 454 |
+
Tries FaceXLib quality assessment first, then falls back to
|
| 455 |
+
BRISQUE-style scoring using OpenCV's QualityBRISQUE if available,
|
| 456 |
+
or classical metrics as last resort.
|
| 457 |
+
"""
|
| 458 |
# Try neural scorer
|
| 459 |
scorer = _get_face_quality_scorer()
|
| 460 |
if scorer is not None:
|
|
|
|
| 487 |
mode: str = "auto",
|
| 488 |
codeformer_fidelity: float = 0.7,
|
| 489 |
) -> tuple[np.ndarray, list[str]]:
|
| 490 |
+
"""Cascaded neural face restoration.
|
| 491 |
+
|
| 492 |
+
Selects and applies restoration networks based on detected distortions:
|
| 493 |
+
- Blur/oversmooth β CodeFormer (recovers texture from codebook)
|
| 494 |
+
- Noise/compression β GFPGAN (trained on degraded faces)
|
| 495 |
+
- Background β Real-ESRGAN (neural 4x upscale + downsample)
|
| 496 |
+
- Color cast β Classical LAB correction (no neural net needed)
|
| 497 |
+
- Geometric β Not fixable by restoration (flag and skip)
|
| 498 |
+
|
| 499 |
+
Args:
|
| 500 |
+
image: BGR face image to restore.
|
| 501 |
+
distortion: Pre-computed distortion report (computed if None).
|
| 502 |
+
mode: 'auto' (choose based on distortion), 'codeformer', 'gfpgan', 'all'.
|
| 503 |
+
codeformer_fidelity: CodeFormer quality-fidelity tradeoff.
|
| 504 |
+
|
| 505 |
+
Returns:
|
| 506 |
+
Tuple of (restored BGR image, list of stages applied).
|
| 507 |
+
"""
|
| 508 |
if distortion is None:
|
| 509 |
distortion = analyze_distortions(image)
|
| 510 |
|
| 511 |
result = image.copy()
|
| 512 |
stages = []
|
| 513 |
|
| 514 |
+
# Step 0: Fix color cast first (classical β fast, doesn't affect identity)
|
| 515 |
if distortion.color_cast_score > 0.25:
|
| 516 |
result = _fix_color_cast(result)
|
| 517 |
stages.append("color_correction")
|
|
|
|
| 593 |
return None
|
| 594 |
|
| 595 |
|
| 596 |
+
_FV_REALESRGAN = None
|
| 597 |
+
|
| 598 |
+
|
| 599 |
def _try_realesrgan(image: np.ndarray) -> np.ndarray | None:
|
| 600 |
"""Try Real-ESRGAN 2x upscale + downsample. Returns None if unavailable."""
|
| 601 |
try:
|
|
|
|
| 603 |
from basicsr.archs.rrdbnet_arch import RRDBNet
|
| 604 |
import torch
|
| 605 |
|
| 606 |
+
global _FV_REALESRGAN
|
| 607 |
+
if _FV_REALESRGAN is None:
|
| 608 |
+
model = RRDBNet(
|
| 609 |
+
num_in_ch=3, num_out_ch=3, num_feat=64,
|
| 610 |
+
num_block=23, num_grow_ch=32, scale=4,
|
| 611 |
+
)
|
| 612 |
+
_FV_REALESRGAN = RealESRGANer(
|
| 613 |
+
scale=4,
|
| 614 |
+
model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
| 615 |
+
model=model,
|
| 616 |
+
tile=400,
|
| 617 |
+
tile_pad=10,
|
| 618 |
+
pre_pad=0,
|
| 619 |
+
half=torch.cuda.is_available(),
|
| 620 |
+
)
|
| 621 |
+
enhanced, _ = _FV_REALESRGAN.enhance(image, outscale=2)
|
| 622 |
|
| 623 |
# Downsample to 512x512 for pipeline consistency
|
| 624 |
enhanced = cv2.resize(enhanced, (512, 512), interpolation=cv2.INTER_LANCZOS4)
|
|
|
|
| 684 |
|
| 685 |
|
| 686 |
def get_face_embedding(image: np.ndarray) -> np.ndarray | None:
|
| 687 |
+
"""Extract ArcFace 512-d embedding from a face image.
|
| 688 |
+
|
| 689 |
+
Returns None if no face detected or InsightFace unavailable.
|
| 690 |
+
"""
|
| 691 |
app = _get_arcface()
|
| 692 |
if app is None:
|
| 693 |
return None
|
|
|
|
| 706 |
restored: np.ndarray,
|
| 707 |
threshold: float = 0.6,
|
| 708 |
) -> tuple[float, bool]:
|
| 709 |
+
"""Compare identity between original and restored using ArcFace.
|
| 710 |
+
|
| 711 |
+
Returns (cosine_similarity, passed).
|
| 712 |
+
Similarity > threshold means same person (threshold=0.6 is conservative).
|
| 713 |
+
"""
|
| 714 |
emb_orig = get_face_embedding(original)
|
| 715 |
emb_rest = get_face_embedding(restored)
|
| 716 |
|
| 717 |
if emb_orig is None or emb_rest is None:
|
| 718 |
+
return -1.0, True # Can't verify β assume OK
|
| 719 |
|
| 720 |
sim = float(np.dot(emb_orig, emb_rest) / (
|
| 721 |
np.linalg.norm(emb_orig) * np.linalg.norm(emb_rest) + 1e-8
|
|
|
|
| 735 |
restore_mode: str = "auto",
|
| 736 |
codeformer_fidelity: float = 0.7,
|
| 737 |
) -> RestorationResult:
|
| 738 |
+
"""Full pipeline: analyze β restore β verify identity.
|
| 739 |
+
|
| 740 |
+
This is the main entry point for the face verifier. It:
|
| 741 |
+
1. Analyzes the input for distortions
|
| 742 |
+
2. If quality is below threshold, applies neural restoration
|
| 743 |
+
3. Verifies the restored face preserves identity
|
| 744 |
+
4. Returns comprehensive result with metrics
|
| 745 |
+
|
| 746 |
+
Args:
|
| 747 |
+
image: BGR face image.
|
| 748 |
+
quality_threshold: Min quality to skip restoration (0-100).
|
| 749 |
+
identity_threshold: Min ArcFace similarity to pass (0-1).
|
| 750 |
+
restore_mode: 'auto', 'codeformer', 'gfpgan', 'all'.
|
| 751 |
+
codeformer_fidelity: CodeFormer quality-fidelity balance.
|
| 752 |
+
|
| 753 |
+
Returns:
|
| 754 |
+
RestorationResult with restored image and full metrics.
|
| 755 |
+
"""
|
| 756 |
# Step 1: Analyze distortions
|
| 757 |
report = analyze_distortions(image)
|
| 758 |
|
| 759 |
# Step 2: Decide if restoration needed
|
| 760 |
if report.quality_score >= quality_threshold and report.severity in ("none", "mild"):
|
| 761 |
+
# Image is good enough β no restoration needed
|
| 762 |
return RestorationResult(
|
| 763 |
restored=image.copy(),
|
| 764 |
original=image.copy(),
|
|
|
|
| 822 |
save_rejected: bool = False,
|
| 823 |
extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp", ".bmp"),
|
| 824 |
) -> BatchVerificationReport:
|
| 825 |
+
"""Process a directory of face images: analyze, restore, verify, sort.
|
| 826 |
+
|
| 827 |
+
Outputs:
|
| 828 |
+
- {output_dir}/passed/ β good images (no fix needed)
|
| 829 |
+
- {output_dir}/restored/ β fixed images
|
| 830 |
+
- {output_dir}/rejected/ β too distorted to use (if save_rejected=True)
|
| 831 |
+
- {output_dir}/report.txt β batch verification report
|
| 832 |
+
|
| 833 |
+
Args:
|
| 834 |
+
image_dir: Directory of face images to process.
|
| 835 |
+
output_dir: Where to save results (default: {image_dir}_verified/).
|
| 836 |
+
quality_threshold: Min quality to pass without restoration.
|
| 837 |
+
identity_threshold: Min identity similarity after restoration.
|
| 838 |
+
restore_mode: 'auto', 'codeformer', 'gfpgan', 'all'.
|
| 839 |
+
save_rejected: Whether to copy rejected images to rejected/ subdir.
|
| 840 |
+
extensions: File extensions to process.
|
| 841 |
+
|
| 842 |
+
Returns:
|
| 843 |
+
BatchVerificationReport with summary statistics.
|
| 844 |
+
"""
|
| 845 |
image_path = Path(image_dir)
|
| 846 |
if output_dir is None:
|
| 847 |
out_path = image_path.parent / f"{image_path.name}_verified"
|