dreamlessx committed on
Commit
eac09b2
·
verified ·
1 Parent(s): cc423b0

Upload landmarkdiff/displacement_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/displacement_model.py +728 -0
landmarkdiff/displacement_model.py ADDED
@@ -0,0 +1,728 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data-driven surgical displacement extraction and modeling.
2
+
3
+ Extracts real landmark displacements from before/after surgery image pairs,
4
+ classifies procedures based on regional displacement patterns, and fits
5
+ per-procedure statistical models that can replace the hand-tuned RBF
6
+ displacement vectors in ``manipulation.py``.
7
+
8
+ Typical usage::
9
+
10
+ from landmarkdiff.displacement_model import (
11
+ extract_displacements,
12
+ extract_from_directory,
13
+ DisplacementModel,
14
+ )
15
+
16
+ # Single pair
17
+ result = extract_displacements(before_img, after_img)
18
+
19
+ # Batch from directory
20
+ all_displacements = extract_from_directory("data/surgery_pairs/")
21
+
22
+ # Fit model
23
+ model = DisplacementModel()
24
+ model.fit(all_displacements)
25
+ model.save("displacement_model.npz")
26
+
27
+ # Generate displacement field
28
+ field = model.get_displacement_field("rhinoplasty", intensity=0.7)
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import json
34
+ import logging
35
+ from pathlib import Path
36
+ from typing import Optional, Union
37
+
38
+ import cv2
39
+ import numpy as np
40
+
41
+ from landmarkdiff.landmarks import extract_landmarks, FaceLandmarks
42
+ from landmarkdiff.manipulation import PROCEDURE_LANDMARKS
43
+
44
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# Landmark count of MediaPipe Face Mesh: 468 face points plus 10 iris points.
NUM_LANDMARKS = 478

# Names of every supported procedure, in the order PROCEDURE_LANDMARKS lists them.
PROCEDURES = list(PROCEDURE_LANDMARKS)
51
+
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # Helpers
55
+ # ---------------------------------------------------------------------------
56
+
57
+ def _normalized_coords_2d(face: FaceLandmarks) -> np.ndarray:
58
+ """Extract (478, 2) normalized [0, 1] coordinates from a FaceLandmarks object.
59
+
60
+ ``FaceLandmarks.landmarks`` is (478, 3) with (x, y, z) in normalized space.
61
+ We take only the x, y columns.
62
+ """
63
+ return face.landmarks[:, :2].copy()
64
+
65
+
66
+ def _compute_alignment_quality(
67
+ landmarks_before: np.ndarray,
68
+ landmarks_after: np.ndarray,
69
+ ) -> float:
70
+ """Estimate alignment quality between two landmark sets.
71
+
72
+ Uses a Procrustes-style analysis on landmarks that should *not* move during
73
+ surgery (forehead, temples, ears) to measure how well the faces are aligned.
74
+ A score of 1.0 means perfect alignment; lower values indicate pose/scale
75
+ mismatches that contaminate the displacement signal.
76
+
77
+ Args:
78
+ landmarks_before: (478, 2) normalized coordinates.
79
+ landmarks_after: (478, 2) normalized coordinates.
80
+
81
+ Returns:
82
+ Quality score in [0, 1].
83
+ """
84
+ # Stable landmarks: forehead, temple region, outer face oval
85
+ # These should exhibit near-zero displacement after surgery.
86
+ stable_indices = [
87
+ 10, 109, 67, 103, 54, 21, 162, 127, # left forehead/temple
88
+ 338, 297, 332, 284, 251, 389, 356, 454, # right forehead/temple
89
+ 234, 93, # outer cheek anchors
90
+ ]
91
+ stable_indices = [i for i in stable_indices if i < NUM_LANDMARKS]
92
+
93
+ before_stable = landmarks_before[stable_indices]
94
+ after_stable = landmarks_after[stable_indices]
95
+
96
+ # RMS displacement on stable points
97
+ diffs = after_stable - before_stable
98
+ rms = np.sqrt(np.mean(np.sum(diffs ** 2, axis=1)))
99
+
100
+ # Map RMS to quality: 0 displacement -> 1.0, rms >= 0.05 (5% of image) -> 0.0
101
+ quality = float(np.clip(1.0 - rms / 0.05, 0.0, 1.0))
102
+ return quality
103
+
104
+
105
+ # ---------------------------------------------------------------------------
106
+ # Procedure classification
107
+ # ---------------------------------------------------------------------------
108
+
109
def classify_procedure(displacements: np.ndarray) -> str:
    """Guess the surgical procedure from per-landmark displacement vectors.

    Each procedure in ``PROCEDURE_LANDMARKS`` is scored by the mean
    displacement magnitude over its landmark region; the highest-scoring
    procedure wins (the first one encountered wins ties).

    Args:
        displacements: (478, 2) displacement vectors (after - before) in
            normalized coordinate space.

    Returns:
        Procedure name string, one of ``PROCEDURES``, or ``"unknown"`` if
        no region shows significant displacement.
    """
    per_landmark = np.linalg.norm(displacements, axis=1)
    n_points = len(per_landmark)

    winner = "unknown"
    top_score = 0.0
    for name, region in PROCEDURE_LANDMARKS.items():
        # Ignore indices that fall outside the supplied array.
        in_range = [idx for idx in region if idx < n_points]
        if in_range:
            region_score = float(np.mean(per_landmark[in_range]))
            if region_score > top_score:
                winner, top_score = name, region_score

    # Significance threshold: mean displacement of 0.002 (~1 pixel at 512x512).
    if top_score >= 0.002:
        return winner

    logger.debug(
        "No significant displacement detected (best=%.5f). "
        "Classified as 'unknown'.",
        top_score,
    )
    return "unknown"
153
+
154
+
155
+ # ---------------------------------------------------------------------------
156
+ # Single-pair extraction
157
+ # ---------------------------------------------------------------------------
158
+
159
def extract_displacements(
    before_img: np.ndarray,
    after_img: np.ndarray,
    min_detection_confidence: float = 0.5,
) -> Optional[dict]:
    """Measure landmark motion between a pre- and post-surgery photo.

    Landmarks are extracted from both images, subtracted to get
    per-landmark displacement vectors, and the result is annotated with a
    procedure guess and an alignment-quality score.

    Args:
        before_img: Pre-surgery BGR image as numpy array.
        after_img: Post-surgery BGR image as numpy array.
        min_detection_confidence: Minimum face detection confidence
            forwarded to ``extract_landmarks`` (default 0.5).

    Returns:
        Dictionary with keys:
            - ``landmarks_before``: (478, 2) normalized coordinates
            - ``landmarks_after``: (478, 2) normalized coordinates
            - ``displacements``: (478, 2) displacement vectors
            - ``magnitude``: (478,) per-landmark displacement magnitudes
            - ``procedure``: classified procedure name or ``"unknown"``
            - ``quality_score``: float in [0, 1] indicating alignment quality

        Returns ``None`` if face detection fails on either image.
    """
    pre_face = extract_landmarks(
        before_img, min_detection_confidence=min_detection_confidence
    )
    if pre_face is None:
        logger.warning("Face detection failed on before image.")
        return None

    # The after image is only processed once the before image has a face.
    post_face = extract_landmarks(
        after_img, min_detection_confidence=min_detection_confidence
    )
    if post_face is None:
        logger.warning("Face detection failed on after image.")
        return None

    xy_pre = _normalized_coords_2d(pre_face)
    xy_post = _normalized_coords_2d(post_face)
    delta = xy_post - xy_pre

    output = {
        "landmarks_before": xy_pre,
        "landmarks_after": xy_post,
        "displacements": delta,
    }
    output["magnitude"] = np.linalg.norm(delta, axis=1)
    output["procedure"] = classify_procedure(delta)
    output["quality_score"] = _compute_alignment_quality(xy_pre, xy_post)
    return output
224
+
225
+
226
+ # ---------------------------------------------------------------------------
227
+ # Batch extraction from directory
228
+ # ---------------------------------------------------------------------------
229
+
230
def _find_image_pairs(pairs_dir: Path) -> list[tuple[str, Path, Path]]:
    """Locate before/after image pairs in ``pairs_dir`` deterministically."""
    image_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".webp"}

    # Path.iterdir() yields entries in arbitrary order. Sorting first means
    # that when two files share a lower-cased stem (e.g. ``a_before.png`` and
    # ``a_before.jpg``) the same one wins on every run and platform.
    by_stem: dict[str, Path] = {}
    for entry in sorted(pairs_dir.iterdir()):
        if entry.is_file() and entry.suffix.lower() in image_extensions:
            by_stem[entry.stem.lower()] = entry

    suffix_conventions = (("_before", "_after"), ("_input", "_target"))
    pairs: list[tuple[str, Path, Path]] = []
    claimed: set[str] = set()

    for stem, before_path in by_stem.items():
        for before_suffix, after_suffix in suffix_conventions:
            if not stem.endswith(before_suffix):
                continue
            base = stem[: -len(before_suffix)]
            after_path = by_stem.get(base + after_suffix)
            # A base name is used at most once, even if both naming
            # conventions are present for it.
            if after_path is None or base in claimed:
                continue
            pairs.append((base, before_path, after_path))
            claimed.add(base)

    return pairs


def extract_from_directory(
    pairs_dir: Union[str, Path],
    min_detection_confidence: float = 0.5,
    min_quality: float = 0.0,
) -> list[dict]:
    """Batch-extract displacements from a directory of before/after image pairs.

    Supports two naming conventions (matched case-insensitively):
        - ``<name>_before.{png,jpg,...}`` / ``<name>_after.{png,jpg,...}``
        - ``<name>_input.{png,jpg,...}`` / ``<name>_target.{png,jpg,...}``

    Only the top level of ``pairs_dir`` is scanned. Pairs are discovered
    from a sorted directory listing, so the outcome does not depend on
    filesystem enumeration order.

    Args:
        pairs_dir: Path to directory containing image pairs.
        min_detection_confidence: Passed to ``extract_displacements``.
        min_quality: Minimum alignment quality score to include a pair
            in the results (default 0.0 = include all).

    Returns:
        List of displacement dictionaries (same format as
        ``extract_displacements``), ordered by pair name, each augmented with:
            - ``pair_name``: lower-cased stem of the pair (e.g. ``"patient_001"``)
            - ``before_path``: path to the before image
            - ``after_path``: path to the after image

        Pairs whose images cannot be read, whose faces cannot be detected,
        or whose quality is below ``min_quality`` are skipped with a log
        message. An empty list is returned when no pairs are found.

    Raises:
        FileNotFoundError: If ``pairs_dir`` is not an existing directory.
    """
    pairs_dir = Path(pairs_dir)
    if not pairs_dir.is_dir():
        raise FileNotFoundError(f"Directory not found: {pairs_dir}")

    pairs = _find_image_pairs(pairs_dir)
    if not pairs:
        logger.warning("No image pairs found in %s", pairs_dir)
        return []

    logger.info("Found %d image pairs in %s", len(pairs), pairs_dir)

    results: list[dict] = []
    # Base names are unique, so ordering by name alone is a total order.
    for pair_name, before_path, after_path in sorted(pairs, key=lambda p: p[0]):
        logger.info("Processing pair: %s", pair_name)

        # cv2.imread signals failure by returning None rather than raising.
        before_img = cv2.imread(str(before_path))
        if before_img is None:
            logger.warning("Failed to load before image: %s", before_path)
            continue

        after_img = cv2.imread(str(after_path))
        if after_img is None:
            logger.warning("Failed to load after image: %s", after_path)
            continue

        result = extract_displacements(
            before_img, after_img, min_detection_confidence=min_detection_confidence
        )
        if result is None:
            logger.warning("Skipping pair %s: face detection failed.", pair_name)
            continue

        if result["quality_score"] < min_quality:
            logger.info(
                "Skipping pair %s: quality %.3f < threshold %.3f",
                pair_name,
                result["quality_score"],
                min_quality,
            )
            continue

        # Augment with metadata.
        result["pair_name"] = pair_name
        result["before_path"] = str(before_path)
        result["after_path"] = str(after_path)
        results.append(result)

    logger.info(
        "Successfully extracted %d / %d pairs (%.0f%%)",
        len(results),
        len(pairs),
        100.0 * len(results) / len(pairs),
    )
    return results
335
+
336
+
337
+ # ---------------------------------------------------------------------------
338
+ # Displacement model
339
+ # ---------------------------------------------------------------------------
340
+
341
class DisplacementModel:
    """Statistical model of per-procedure surgical displacements.

    Aggregates displacement vectors from multiple before/after pairs and
    computes per-procedure, per-landmark statistics (mean, std, min, max,
    median, mean magnitude). Can then generate displacement fields for use
    in the conditioning pipeline, replacing hand-tuned RBF vectors.

    Attributes:
        procedures: List of procedure names the model has data for.
        stats: Nested dict ``{procedure: {stat_name: array}}``.
        n_samples: Dict ``{procedure: int}`` sample counts.
    """

    def __init__(self) -> None:
        """Create an empty, unfitted model."""
        self.stats: dict[str, dict[str, np.ndarray]] = {}
        self.n_samples: dict[str, int] = {}
        self._fitted = False

    @property
    def procedures(self) -> list[str]:
        """Return list of procedures the model has been fitted on."""
        return list(self.stats.keys())

    @property
    def fitted(self) -> bool:
        """Whether the model has been fitted."""
        return self._fitted

    def fit(self, displacement_list: list[dict]) -> None:
        """Fit the model from a list of extracted displacement dictionaries.

        Groups displacements by classified procedure and computes per-landmark
        statistics for each group. Entries labelled ``"unknown"`` (or with no
        ``"procedure"`` key) are grouped under ``"unknown"``. Existing state
        is replaced only after all statistics were computed successfully.

        Args:
            displacement_list: List of dicts as returned by
                ``extract_displacements`` or ``extract_from_directory``.
                Each must contain ``"displacements"`` (478, 2) and
                ``"procedure"`` (str) keys. Entries without displacements
                or with a different shape are skipped with a warning.

        Raises:
            ValueError: If ``displacement_list`` is empty or contains no
                valid displacement data.
        """
        if not displacement_list:
            raise ValueError("displacement_list is empty.")

        # Group by procedure.
        procedure_groups: dict[str, list[np.ndarray]] = {}
        for entry in displacement_list:
            disp = entry.get("displacements")
            if disp is None:
                logger.warning("Skipping entry without 'displacements' key.")
                continue
            # Accept array-likes (e.g. nested lists) as well as ndarrays.
            disp = np.asarray(disp)
            if disp.shape != (NUM_LANDMARKS, 2):
                logger.warning(
                    "Skipping entry with unexpected shape %s (expected (%d, 2)).",
                    disp.shape,
                    NUM_LANDMARKS,
                )
                continue
            proc = entry.get("procedure", "unknown")
            procedure_groups.setdefault(proc, []).append(disp)

        if not procedure_groups:
            raise ValueError("No valid displacement data found in displacement_list.")

        # Build into locals so a failure midway leaves the previous fit intact.
        stats: dict[str, dict[str, np.ndarray]] = {}
        n_samples: dict[str, int] = {}

        for proc, disp_arrays in procedure_groups.items():
            stacked = np.stack(disp_arrays, axis=0)  # (N, 478, 2)
            n = stacked.shape[0]

            stats[proc] = {
                "mean": np.mean(stacked, axis=0),      # (478, 2)
                "std": np.std(stacked, axis=0),        # (478, 2)
                "min": np.min(stacked, axis=0),        # (478, 2)
                "max": np.max(stacked, axis=0),        # (478, 2)
                "median": np.median(stacked, axis=0),  # (478, 2)
                "mean_magnitude": np.mean(             # (478,)
                    np.linalg.norm(stacked, axis=2), axis=0
                ),
            }
            n_samples[proc] = n
            logger.info(
                "Fitted procedure '%s': %d samples, mean magnitude=%.5f",
                proc,
                n,
                float(np.mean(stats[proc]["mean_magnitude"])),
            )

        self.stats = stats
        self.n_samples = n_samples
        self._fitted = True

    def get_displacement_field(
        self,
        procedure: str,
        intensity: float = 1.0,
        noise_scale: float = 0.0,
        rng: Optional[np.random.Generator] = None,
    ) -> np.ndarray:
        """Generate a displacement field for a given procedure and intensity.

        Returns the mean displacement scaled by ``intensity``, optionally
        with Gaussian noise added (scaled by per-landmark std). The stored
        statistics are never modified.

        Args:
            procedure: Procedure name (must exist in the fitted model).
            intensity: Scaling factor for the mean displacement. 1.0 = average
                observed displacement; 0.5 = half intensity; etc.
            noise_scale: If > 0, adds Gaussian noise with this many standard
                deviations of variation. 0.0 = deterministic mean field.
            rng: NumPy random generator for reproducible noise. If ``None``
                and ``noise_scale > 0``, uses ``np.random.default_rng()``.

        Returns:
            (478, 2) float32 displacement field in normalized coordinate space.

        Raises:
            RuntimeError: If the model has not been fitted.
            KeyError: If the procedure is not in the model.
        """
        if not self._fitted:
            raise RuntimeError("Model has not been fitted. Call fit() first.")

        if procedure not in self.stats:
            available = ", ".join(self.procedures)
            raise KeyError(
                f"Procedure '{procedure}' not in model. "
                f"Available: {available}"
            )

        proc_stats = self.stats[procedure]
        # Multiplication allocates a new array, so the stored mean is untouched.
        field = proc_stats["mean"] * intensity

        if noise_scale > 0:
            if rng is None:
                rng = np.random.default_rng()
            # The array-valued scale broadcasts to one draw per (landmark, axis).
            noise = rng.normal(
                loc=0.0,
                scale=proc_stats["std"] * noise_scale,
            )
            field = field + noise

        return field.astype(np.float32)

    def get_summary(self, procedure: Optional[str] = None) -> dict:
        """Get a human-readable summary of the model statistics.

        Args:
            procedure: If provided, return summary for one procedure.
                If ``None``, return summaries for all procedures.

        Returns:
            ``{"fitted": False}`` for an unfitted model. Otherwise a dict with
            ``"fitted": True`` and a ``"procedures"`` mapping of procedure name
            to sample count, global mean/max magnitude, and the ten landmarks
            with the largest mean magnitude. Procedures not present in the
            model are silently omitted.
        """
        if not self._fitted:
            return {"fitted": False}

        procs = [procedure] if procedure else self.procedures
        summary = {"fitted": True, "procedures": {}}

        for proc in procs:
            if proc not in self.stats:
                continue
            s = self.stats[proc]
            summary["procedures"][proc] = {
                "n_samples": self.n_samples[proc],
                "global_mean_magnitude": float(np.mean(s["mean_magnitude"])),
                "global_max_magnitude": float(np.max(s["mean_magnitude"])),
                "top_landmarks": _top_k_landmarks(s["mean_magnitude"], k=10),
            }

        return summary

    def save(self, path: Union[str, Path]) -> None:
        """Save the fitted model to disk as a ``.npz`` file.

        The file contains:
            - Per-procedure stat arrays keyed as ``{procedure}__{stat_name}``
            - A JSON metadata string with sample counts and procedure list

        Args:
            path: Output file path. ``.npz`` is appended when the path does
                not already end with it; an existing different suffix is
                kept (``model.v1`` is written as ``model.v1.npz``).

        Raises:
            RuntimeError: If the model has not been fitted.
        """
        if not self._fitted:
            raise RuntimeError("Model has not been fitted. Call fit() first.")

        path = Path(path)
        if path.suffix != ".npz":
            # Append rather than with_suffix(): with_suffix would replace the
            # last dotted component and silently rename "model.v1" to
            # "model.npz".
            path = path.with_name(path.name + ".npz")

        arrays: dict[str, np.ndarray] = {}
        for proc, proc_stats in self.stats.items():
            for stat_name, arr in proc_stats.items():
                arrays[f"{proc}__{stat_name}"] = arr

        # Store metadata as UTF-8 JSON bytes in a uint8 array so the file can
        # be read back with allow_pickle=False.
        metadata = {
            "procedures": self.procedures,
            "n_samples": self.n_samples,
            "num_landmarks": NUM_LANDMARKS,
        }
        arrays["__metadata__"] = np.frombuffer(
            json.dumps(metadata).encode("utf-8"), dtype=np.uint8
        )

        np.savez_compressed(str(path), **arrays)
        logger.info("Saved displacement model to %s", path)

    @classmethod
    def load(cls, path: Union[str, Path]) -> "DisplacementModel":
        """Load a fitted model from a ``.npz`` file.

        Supports two formats:
            1. ``save()`` format: keys like ``{proc}__{stat}`` with ``__metadata__``
            2. ``extract_displacements.py`` format: keys like ``{proc}_{stat}``
               with a ``procedures`` array

        Args:
            path: Path to the ``.npz`` file.

        Returns:
            A fitted ``DisplacementModel`` instance.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the archive matches neither supported format.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Model file not found: {path}")

        model = cls()

        # NpzFile keeps the archive open until closed; the context manager
        # releases the handle once all arrays have been read into memory.
        with np.load(str(path), allow_pickle=False) as data:
            keys = data.files

            # Format 1: save() format with __metadata__
            if "__metadata__" in keys:
                meta_bytes = data["__metadata__"].tobytes()
                metadata = json.loads(meta_bytes.decode("utf-8"))
                model.n_samples = {
                    k: int(v) for k, v in metadata["n_samples"].items()
                }

                for proc in metadata["procedures"]:
                    prefix = f"{proc}__"
                    model.stats[proc] = {
                        key[len(prefix):]: data[key]
                        for key in keys
                        if key.startswith(prefix)
                    }

            # Format 2: extract_displacements.py format with procedures array
            elif "procedures" in keys:
                procedures = [str(p) for p in data["procedures"]]
                # Map from extraction script key names to DisplacementModel
                # stat names; "_count" is routed to n_samples instead of stats.
                stat_map = {
                    "mean": "mean",
                    "std": "std",
                    "median": "median",
                    "min": "min",
                    "max": "max",
                    "mag_mean": "mean_magnitude",
                    "mag_std": "std_magnitude",
                    "count": "_count",
                }
                for proc in procedures:
                    model.stats[proc] = {}
                    for ext_key, model_key in stat_map.items():
                        npz_key = f"{proc}_{ext_key}"
                        if npz_key not in keys:
                            continue
                        arr = data[npz_key]
                        if model_key == "_count":
                            # Works for 0-d and 1-element arrays alike;
                            # int() on a 1-d array is deprecated in NumPy.
                            model.n_samples[proc] = int(
                                np.asarray(arr).reshape(-1)[0]
                            )
                        else:
                            model.stats[proc][model_key] = arr

                    # Ensure count is set
                    model.n_samples.setdefault(proc, 0)

            else:
                raise ValueError(
                    f"Unrecognized displacement model format. "
                    f"Keys: {data.files[:10]}"
                )

        model._fitted = True
        logger.info(
            "Loaded displacement model from %s (%d procedures, %s samples)",
            path,
            len(model.procedures),
            model.n_samples,
        )
        return model
641
+
642
+
643
+ # ---------------------------------------------------------------------------
644
+ # Utilities
645
+ # ---------------------------------------------------------------------------
646
+
647
+ def _top_k_landmarks(
648
+ magnitudes: np.ndarray,
649
+ k: int = 10,
650
+ ) -> list[dict]:
651
+ """Return the top-k landmarks by mean displacement magnitude.
652
+
653
+ Args:
654
+ magnitudes: (478,) array of per-landmark magnitudes.
655
+ k: Number of top landmarks to return.
656
+
657
+ Returns:
658
+ List of dicts with ``index`` and ``magnitude`` keys, sorted
659
+ descending by magnitude.
660
+ """
661
+ top_indices = np.argsort(magnitudes)[::-1][:k]
662
+ return [
663
+ {"index": int(idx), "magnitude": float(magnitudes[idx])}
664
+ for idx in top_indices
665
+ ]
666
+
667
+
668
def visualize_displacements(
    before_img: np.ndarray,
    result: dict,
    scale: float = 10.0,
    arrow_color: tuple[int, int, int] = (0, 255, 0),
    thickness: int = 1,
) -> np.ndarray:
    """Render displacement vectors as arrows over a copy of the before image.

    Args:
        before_img: BGR image (will be copied).
        result: Displacement dict from ``extract_displacements``.
        scale: Arrow length multiplier (displacements are small in
            normalized space, so scale up for visibility).
        arrow_color: BGR color for arrows.
        thickness: Arrow line thickness.

    Returns:
        Annotated BGR image.
    """
    annotated = before_img.copy()
    height, width = annotated.shape[:2]

    origins = result["landmarks_before"]
    vectors = result["displacements"]

    for lm in range(NUM_LANDMARKS):
        # Convert normalized coordinates to integer pixel positions.
        start = (int(origins[lm, 0] * width), int(origins[lm, 1] * height))
        step_x = int(vectors[lm, 0] * width * scale)
        step_y = int(vectors[lm, 1] * height * scale)

        # Draw only arrows at least one pixel long (after scaling).
        if np.sqrt(step_x ** 2 + step_y ** 2) >= 1.0:
            cv2.arrowedLine(
                annotated,
                start,
                (start[0] + step_x, start[1] + step_y),
                arrow_color,
                thickness,
                tipLength=0.3,
            )

    # Caption with the classified procedure and alignment quality.
    caption = "{} (quality={:.2f})".format(
        result.get("procedure", "unknown"),
        result.get("quality_score", 0.0),
    )
    cv2.putText(
        annotated,
        caption,
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (255, 255, 255),
        2,
    )

    return annotated