dreamlessx committed on
Commit
996d0fe
·
verified ·
1 Parent(s): d895a0c

Upload landmarkdiff/inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/inference.py +525 -0
landmarkdiff/inference.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference pipeline for surgical outcome prediction.
2
+
3
+ Modes:
4
+ 1. ControlNet: CrucibleAI/ControlNetMediaPipeFace + SD1.5 (HF auth + GPU)
5
+ 2. ControlNet + IP-Adapter: ControlNet w/ identity preservation
6
+ 3. Img2Img: SD1.5 img2img with mask compositing (MPS ok, no auth)
7
+ 4. TPS-only: geometric warp, no diffusion, instant
8
+
9
+ Works on MPS (Apple Silicon), CUDA, and CPU.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import sys
15
+ from pathlib import Path
16
+ from typing import Optional
17
+
18
+ import cv2
19
+ import numpy as np
20
+ import torch
21
+ from PIL import Image
22
+
23
+ from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks, render_landmark_image
24
+ from landmarkdiff.conditioning import generate_conditioning
25
+ from landmarkdiff.manipulation import apply_procedure_preset
26
+ from landmarkdiff.masking import generate_surgical_mask, mask_to_3channel
27
+ from landmarkdiff.synthetic.tps_warp import warp_image_tps
28
+
29
+
30
def get_device() -> torch.device:
    """Pick the torch device to run on.

    Preference order is MPS, then CUDA, then CPU.

    Returns:
        ``torch.device("mps")``, ``torch.device("cuda")`` or
        ``torch.device("cpu")``, whichever is the first one available.
    """
    # Some torch builds expose no ``torch.backends.mps`` attribute at all;
    # treat that as "MPS unavailable" rather than raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
36
+
37
+
38
def numpy_to_pil(arr: np.ndarray) -> Image.Image:
    """Convert a NumPy image array into a PIL image.

    A 2-D array is wrapped as a mode ``"L"`` image. Any other array has its
    last axis reversed before conversion (presumably BGR -> RGB, since the
    rest of this module works with OpenCV images).
    """
    is_single_plane = arr.ndim == 2
    if not is_single_plane:
        channels_reversed = arr[:, :, ::-1]
        return Image.fromarray(channels_reversed)
    return Image.fromarray(arr, mode="L")
42
+
43
+
44
def pil_to_numpy(img: Image.Image) -> np.ndarray:
    """Convert a PIL image into a NumPy array.

    Three-channel data has its channel axis reversed and is returned as a
    fresh contiguous copy; anything else (2-D, or a channel count other than
    three) is returned exactly as ``np.array`` produced it.
    """
    pixels = np.array(img)
    has_three_channels = pixels.ndim == 3 and pixels.shape[2] == 3
    if not has_three_channels:
        return pixels
    # ``.copy()`` materialises the reversed view so callers get contiguous memory.
    return pixels[:, :, ::-1].copy()
49
+
50
+
51
# Every procedure prompt shares the same opening and closing clauses; only the
# anatomical description in the middle differs.
_PROMPT_PREFIX = "clinical photograph, patient face, "
_PROMPT_SUFFIX = (
    "realistic skin pores and texture, sharp focus, studio lighting, "
    "DSLR quality, natural skin color"
)

# Positive text prompt for each supported procedure name.
PROCEDURE_PROMPTS: dict[str, str] = {
    procedure_name: f"{_PROMPT_PREFIX}{anatomy}, {_PROMPT_SUFFIX}"
    for procedure_name, anatomy in (
        ("rhinoplasty", "natural refined nose, smooth nasal bridge"),
        ("blepharoplasty", "natural eyelids, smooth periorbital area"),
        ("rhytidectomy", "defined jawline, smooth facial contour"),
        ("orthognathic", "balanced jaw and chin proportions"),
    )
}
73
+
74
# Negative prompt passed to the diffusion pipelines alongside the procedure
# prompt; kept as individual terms so entries are easy to add or remove.
NEGATIVE_PROMPT = ", ".join(
    (
        "painting",
        "drawing",
        "illustration",
        "cartoon",
        "anime",
        "render",
        "3d",
        "cgi",
        "blurry",
        "distorted",
        "deformed",
        "disfigured",
        "bad anatomy",
        "bad proportions",
        "extra limbs",
        "mutated",
        "poorly drawn face",
        "ugly",
        "low quality",
        "low resolution",
        "watermark",
        "text",
        "signature",
        "duplicate",
        "artifact",
        "noise",
        "overexposed",
        "plastic skin",
        "waxy",
        "smooth skin",
        "airbrushed",
        "oversaturated",
    )
)
81
+
82
+
83
def mask_composite(
    warped: np.ndarray,
    original: np.ndarray,
    mask: np.ndarray,
    use_laplacian: bool = True,
) -> np.ndarray:
    """Blend ``warped`` into ``original`` under ``mask``.

    The warped image is first colour-matched to the original with
    ``_match_skin_tone``. The result is then combined with the original using
    ``laplacian_pyramid_blend`` when ``use_laplacian`` is true, or a plain
    per-pixel alpha blend otherwise (and as a fallback if the Laplacian path
    fails for any reason).

    Args:
        warped: Image supplying the content inside the mask.
        original: Image supplying the content outside the mask.
        mask: Blend weights. Values above 1.0 are taken to mean a 0-255
            mask and are rescaled to 0-1.
        use_laplacian: Try the Laplacian pyramid blend before falling back
            to the alpha blend.

    Returns:
        The composited image; the alpha-blend path returns ``uint8``.
    """
    mask_f = mask.astype(np.float32)
    if mask_f.max() > 1.0:
        mask_f = mask_f / 255.0

    # Match color of warped region to original skin tone in LAB space
    corrected = _match_skin_tone(warped, original, mask_f)

    if use_laplacian:
        try:
            from landmarkdiff.postprocess import laplacian_pyramid_blend

            return laplacian_pyramid_blend(corrected, original, mask_f)
        except Exception as exc:
            # Best-effort path: keep the fallback, but leave a trace so a
            # broken or missing postprocess module does not go unnoticed.
            import logging

            logging.getLogger(__name__).warning(
                "Laplacian blend failed (%r); falling back to alpha blend", exc
            )

    # Fallback: simple alpha blend
    mask_3ch = mask_to_3channel(mask_f)
    blended = (
        corrected.astype(np.float32) * mask_3ch
        + original.astype(np.float32) * (1.0 - mask_3ch)
    )
    # Clip before the cast so out-of-range weights cannot wrap around in uint8.
    return np.clip(blended, 0, 255).astype(np.uint8)
112
+
113
+
114
+ def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -> np.ndarray:
115
+ """LAB-space color transfer so warped region matches original skin tone."""
116
+ mask_bool = mask > 0.3
117
+ if not np.any(mask_bool):
118
+ return source
119
+
120
+ src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
121
+ tgt_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
122
+
123
+ # match per-channel stats in masked region
124
+ for ch in range(3):
125
+ src_vals = src_lab[:, :, ch][mask_bool]
126
+ tgt_vals = tgt_lab[:, :, ch][mask_bool]
127
+
128
+ src_mean, src_std = np.mean(src_vals), np.std(src_vals) + 1e-6
129
+ tgt_mean, tgt_std = np.mean(tgt_vals), np.std(tgt_vals) + 1e-6
130
+
131
+ # shift+scale to match target distribution
132
+ src_lab[:, :, ch] = np.where(
133
+ mask_bool,
134
+ (src_lab[:, :, ch] - src_mean) * (tgt_std / src_std) + tgt_mean,
135
+ src_lab[:, :, ch],
136
+ )
137
+
138
+ src_lab = np.clip(src_lab, 0, 255)
139
+ return cv2.cvtColor(src_lab.astype(np.uint8), cv2.COLOR_LAB2BGR)
140
+
141
+
142
def poisson_blend(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Blend ``source`` into ``target`` under ``mask``.

    Despite the name, no Poisson solve happens here: the call is forwarded
    to ``mask_composite`` with its default settings.
    """
    return mask_composite(warped=source, original=target, mask=mask)
145
+
146
+
147
class LandmarkDiffPipeline:
    """Image -> landmarks -> manipulate -> generate.

    ``mode`` selects the generation backend:

    * ``"controlnet"``    - ControlNet pipeline conditioned on rendered landmarks.
    * ``"controlnet_ip"`` - same, plus an IP-Adapter fed with the input image.
    * ``"tps"``           - TPS warp only; no model is loaded.
    * anything else       - img2img pipeline applied to the TPS-warped image.

    Call :meth:`load` once before :meth:`generate` (``"tps"`` mode needs no
    loading and is always considered loaded).
    """

    # Default IP-Adapter model for SD1.5 face identity
    IP_ADAPTER_REPO = "h94/IP-Adapter"
    IP_ADAPTER_SUBFOLDER = "models"
    IP_ADAPTER_WEIGHT_NAME = "ip-adapter-plus-face_sd15.bin"
    IP_ADAPTER_SCALE_DEFAULT = 0.6

    def __init__(
        self,
        mode: str = "img2img",
        controlnet_id: str = "CrucibleAI/ControlNetMediaPipeFace",
        base_model_id: str | None = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        ip_adapter_scale: float = 0.6,
        clinical_flags: Optional["ClinicalFlags"] = None,
    ):
        """Store configuration; no model weights are loaded here.

        Args:
            mode: Generation backend (see the class docstring).
            controlnet_id: Repo id of the ControlNet weights.
            base_model_id: Repo id of the base diffusion model; falls back
                to ``runwayml/stable-diffusion-v1-5`` when empty.
            device: Target device; chosen by ``get_device()`` when omitted.
            dtype: Weight dtype. Ignored on MPS, where float32 is forced.
                When omitted, float16 is used on CUDA and float32 elsewhere.
            ip_adapter_scale: Scale handed to ``set_ip_adapter_scale``.
            clinical_flags: Default flags forwarded to the manipulation and
                masking helpers when ``generate`` is given none.
        """
        self.mode = mode
        self.device = device or get_device()
        self.ip_adapter_scale = ip_adapter_scale
        self.clinical_flags = clinical_flags

        if self.device.type == "mps":
            self.dtype = torch.float32
        elif dtype:
            self.dtype = dtype
        else:
            self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32

        # Every mode shares the same default base model, so no per-mode branch
        # is needed.
        self.base_model_id = base_model_id or "runwayml/stable-diffusion-v1-5"

        self.controlnet_id = controlnet_id
        self._pipe = None
        self._ip_adapter_loaded = False

    def load(self) -> None:
        """Load the model(s) required by ``self.mode`` (no-op for ``"tps"``)."""
        if self.mode == "tps":
            print("TPS mode - no model to load")
            return
        if self.mode in ("controlnet", "controlnet_ip"):
            self._load_controlnet()
            if self.mode == "controlnet_ip":
                self._load_ip_adapter()
        else:
            self._load_img2img()

    def _load_controlnet(self) -> None:
        """Build the ControlNet pipeline and place it on the target device."""
        from diffusers import (
            ControlNetModel,
            StableDiffusionControlNetPipeline,
            DPMSolverMultistepScheduler,
        )

        print(f"Loading ControlNet from {self.controlnet_id}...")
        controlnet = ControlNetModel.from_pretrained(
            self.controlnet_id, subfolder="diffusion_sd15", torch_dtype=self.dtype,
        )
        print(f"Loading base model from {self.base_model_id}...")
        self._pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.base_model_id,
            controlnet=controlnet,
            torch_dtype=self.dtype,
            safety_checker=None,
            requires_safety_checker=False,
        )
        # DPM++ 2M Karras - better skin than UniPC
        self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self._pipe.scheduler.config,
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True,
        )
        # FP32 VAE decode - prevents color banding on skin
        if hasattr(self._pipe, "vae") and self._pipe.vae is not None:
            self._pipe.vae.config.force_upcast = True
        self._apply_device_optimizations()

    def _load_ip_adapter(self) -> None:
        """Load IP-Adapter for identity preservation.

        A failed load is reported on stdout and leaves the pipeline in
        ControlNet-only mode instead of raising.

        Raises:
            RuntimeError: If the base pipeline has not been loaded yet.
        """
        if self._pipe is None:
            raise RuntimeError("Base pipeline must be loaded before IP-Adapter")
        try:
            print(f"Loading IP-Adapter ({self.IP_ADAPTER_WEIGHT_NAME})...")
            self._pipe.load_ip_adapter(
                self.IP_ADAPTER_REPO,
                subfolder=self.IP_ADAPTER_SUBFOLDER,
                weight_name=self.IP_ADAPTER_WEIGHT_NAME,
            )
            self._pipe.set_ip_adapter_scale(self.ip_adapter_scale)
            self._ip_adapter_loaded = True
            print(f"IP-Adapter loaded (scale={self.ip_adapter_scale})")
        except Exception as e:
            print(f"WARNING: IP-Adapter load failed: {e}")
            print("Falling back to ControlNet-only mode")
            self._ip_adapter_loaded = False

    def _load_img2img(self) -> None:
        """Build the img2img pipeline and place it on the target device."""
        from diffusers import (
            StableDiffusionImg2ImgPipeline,
            DPMSolverMultistepScheduler,
        )

        print(f"Loading SD1.5 img2img from {self.base_model_id}...")
        self._pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            self.base_model_id,
            torch_dtype=self.dtype,
            safety_checker=None,
            requires_safety_checker=False,
        )
        self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self._pipe.scheduler.config
        )
        self._apply_device_optimizations()

    def _apply_device_optimizations(self) -> None:
        """Move the pipeline to ``self.device`` with device-specific tuning."""
        if self.device.type == "mps":
            self._pipe = self._pipe.to(self.device)
            self._pipe.enable_attention_slicing()
        elif self.device.type == "cuda":
            try:
                self._pipe.enable_model_cpu_offload()
            except Exception:
                # Offload is optional; fall back to keeping everything on the GPU.
                self._pipe = self._pipe.to(self.device)
        else:
            # Sequential CPU offload shuttles weights to an accelerator for
            # each forward pass, so it cannot work on a CPU-only host. Plain
            # placement on the requested device is the correct behaviour here.
            self._pipe = self._pipe.to(self.device)
        print(f"Pipeline loaded on {self.device} ({self.dtype})")

    @property
    def is_loaded(self) -> bool:
        """True once a pipeline exists, or always in ``"tps"`` mode."""
        return self._pipe is not None or self.mode == "tps"

    def generate(
        self,
        image: np.ndarray,
        procedure: str = "rhinoplasty",
        intensity: float = 50.0,
        num_inference_steps: int = 30,
        guidance_scale: float = 9.0,
        controlnet_conditioning_scale: float = 0.9,
        strength: float = 0.5,
        seed: Optional[int] = None,
        clinical_flags: Optional["ClinicalFlags"] = None,
        postprocess: bool = True,
        use_gfpgan: bool = False,
    ) -> dict:
        """Run the full prediction pipeline on one image.

        The image is resized to 512x512, landmarks are extracted and
        manipulated for ``procedure``, a TPS warp is computed, and the
        mode-specific generator produces the raw output. That output is then
        either run through ``full_postprocess`` or composited with
        ``mask_composite``.

        Args:
            image: Input image array (presumably BGR as loaded by OpenCV).
            procedure: Procedure name; also selects the text prompt, with a
                generic prompt used for unknown names.
            intensity: Forwarded to ``apply_procedure_preset``.
            num_inference_steps: Diffusion steps (unused in ``"tps"`` mode).
            guidance_scale: Classifier-free guidance scale.
            controlnet_conditioning_scale: ControlNet modes only.
            strength: img2img mode only.
            seed: Seeds a CPU ``torch.Generator`` when given.
            clinical_flags: Overrides the flags stored on the instance.
            postprocess: Run ``full_postprocess``; skipped in ``"tps"`` mode.
            use_gfpgan: Enables ``restore_mode="codeformer"`` and
                ``use_realesrgan`` in the post-processing call.

        Returns:
            Dict with the composited ``output``, ``output_raw``,
            ``output_tps``, the resized ``input``, both landmark sets, the
            ``conditioning`` image, ``mask`` and run metadata.

        Raises:
            RuntimeError: If :meth:`load` has not been called.
            ValueError: If no face is detected.
        """
        if not self.is_loaded:
            raise RuntimeError("Pipeline not loaded. Call .load() first.")

        flags = clinical_flags or self.clinical_flags
        image_512 = cv2.resize(image, (512, 512))

        face = extract_landmarks(image_512)
        if face is None:
            raise ValueError("No face detected in image.")

        # face view angle for multi-view awareness
        view_info = estimate_face_view(face)

        manipulated = apply_procedure_preset(
            face, procedure, intensity, image_size=512, clinical_flags=flags,
        )
        landmark_img = render_landmark_image(manipulated, 512, 512)
        mask = generate_surgical_mask(
            face, procedure, 512, 512, clinical_flags=flags,
        )

        generator = None
        if seed is not None:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        prompt = PROCEDURE_PROMPTS.get(procedure, "a photo of a person's face")

        # TPS warp is always the geometric baseline
        tps_warped = warp_image_tps(image_512, face.pixel_coords, manipulated.pixel_coords)

        if self.mode == "tps":
            raw_output = tps_warped
        elif self.mode in ("controlnet", "controlnet_ip"):
            ip_image = numpy_to_pil(image_512) if self._ip_adapter_loaded else None
            raw_output = self._generate_controlnet(
                image_512, landmark_img, prompt, num_inference_steps,
                guidance_scale, controlnet_conditioning_scale, generator,
                ip_adapter_image=ip_image,
            )
        else:
            raw_output = self._generate_img2img(
                tps_warped, mask, prompt, num_inference_steps,
                guidance_scale, strength, generator,
            )

        # postprocess for photorealism
        identity_check = None
        restore_used = "none"
        if postprocess and self.mode != "tps":
            from landmarkdiff.postprocess import full_postprocess

            pp_result = full_postprocess(
                generated=raw_output,
                original=image_512,
                mask=mask,
                restore_mode="codeformer" if use_gfpgan else "none",
                use_realesrgan=use_gfpgan,
                use_laplacian_blend=True,
                sharpen_strength=0.25,
                verify_identity=True,
            )
            composited = pp_result["image"]
            identity_check = pp_result["identity_check"]
            restore_used = pp_result["restore_used"]
        else:
            composited = mask_composite(raw_output, image_512, mask)

        return {
            "output": composited,
            "output_raw": raw_output,
            "output_tps": tps_warped,
            "input": image_512,
            "landmarks_original": face,
            "landmarks_manipulated": manipulated,
            "conditioning": landmark_img,
            "mask": mask,
            "procedure": procedure,
            "intensity": intensity,
            "device": str(self.device),
            "mode": self.mode,
            "view_info": view_info,
            "ip_adapter_active": self._ip_adapter_loaded,
            "identity_check": identity_check,
            "restore_used": restore_used,
        }

    def _generate_controlnet(
        self, image: np.ndarray, conditioning: np.ndarray,
        prompt: str, steps: int, cfg: float, cn_scale: float,
        generator: torch.Generator | None,
        ip_adapter_image: Image.Image | None = None,
    ) -> np.ndarray:
        """Run the ControlNet pipeline on the rendered landmark image.

        ``image`` is accepted for signature symmetry but is not used; only
        ``conditioning`` is passed to the pipeline. ``ip_adapter_image`` is
        forwarded only when an IP-Adapter was actually loaded.
        """
        kwargs = dict(
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            image=numpy_to_pil(conditioning),  # control conditioning only
            num_inference_steps=steps,
            guidance_scale=cfg,
            controlnet_conditioning_scale=cn_scale,
            generator=generator,
        )
        if ip_adapter_image is not None and self._ip_adapter_loaded:
            kwargs["ip_adapter_image"] = ip_adapter_image
        result = self._pipe(**kwargs)
        return pil_to_numpy(result.images[0])

    def _generate_img2img(
        self, image: np.ndarray, mask: np.ndarray,
        prompt: str, steps: int, cfg: float, strength: float,
        generator: torch.Generator | None,
    ) -> np.ndarray:
        """Run the img2img pipeline on ``image``.

        ``mask`` is not used here; masking is applied later when the result
        is composited back onto the original.
        """
        result = self._pipe(
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            image=numpy_to_pil(image),
            num_inference_steps=steps,
            guidance_scale=cfg,
            strength=strength,
            generator=generator,
        )
        return pil_to_numpy(result.images[0])
418
+
419
+
420
def estimate_face_view(face: FaceLandmarks) -> dict:
    """Estimate head yaw and pitch from five landmark positions.

    Uses ``face.pixel_coords`` at indices 1 (nose tip), 234 and 454 (the two
    ear-side points), 10 (forehead) and 152 (chin).

    Returns:
        Dict with ``yaw`` and ``pitch`` rounded to one decimal, a ``view``
        label (``"frontal"``, ``"three_quarter"`` or ``"profile"``), an
        ``is_frontal`` flag, and a ``warning`` string that is set only when
        the absolute yaw exceeds 30.
    """
    pts = face.pixel_coords
    nose = pts[1]

    # Yaw: asymmetry between the two nose-to-ear distances (equal -> 0 degrees).
    to_left_ear = np.linalg.norm(nose - pts[234])
    to_right_ear = np.linalg.norm(nose - pts[454])
    ear_span = to_left_ear + to_right_ear
    if ear_span < 1.0:
        # Degenerate landmarks: avoid dividing by a near-zero span.
        yaw = 0.0
    else:
        asymmetry = (to_right_ear - to_left_ear) / ear_span
        yaw = float(np.arcsin(np.clip(asymmetry, -1, 1)) * 180 / np.pi)

    # Pitch: imbalance between forehead-to-nose and nose-to-chin distances.
    upper_face = np.linalg.norm(pts[10] - nose)
    lower_face = np.linalg.norm(nose - pts[152])
    face_height = upper_face + lower_face
    if face_height < 1.0:
        pitch = 0.0
    else:
        pitch = float((lower_face - upper_face) / face_height * 45)

    # Classify view
    yaw_magnitude = abs(yaw)
    frontal = yaw_magnitude < 15
    if frontal:
        label = "frontal"
    elif yaw_magnitude < 45:
        label = "three_quarter"
    else:
        label = "profile"

    side_view_note = None
    if yaw_magnitude > 30:
        side_view_note = "Side-view detected: results may be less accurate"

    return {
        "yaw": round(yaw, 1),
        "pitch": round(pitch, 1),
        "view": label,
        "is_frontal": frontal,
        "warning": side_view_note,
    }
464
+
465
+
466
def run_inference(
    image_path: str,
    procedure: str = "rhinoplasty",
    intensity: float = 50.0,
    output_dir: str = "scripts/inference_output",
    seed: int = 42,
    mode: str = "img2img",
    ip_adapter_scale: float = 0.6,
) -> None:
    """Run one prediction end to end and write the artefacts as PNG files.

    Creates ``output_dir`` if needed, loads the image with OpenCV, builds and
    loads a ``LandmarkDiffPipeline``, then saves the input, outputs,
    conditioning image, mask and a side-by-side comparison strip. Exits the
    process with status 1 when the image cannot be read.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    image = cv2.imread(image_path)
    if image is None:
        print(f"ERROR: Could not load {image_path}")
        sys.exit(1)

    pipe = LandmarkDiffPipeline(mode=mode, ip_adapter_scale=ip_adapter_scale)
    pipe.load()

    print(f"\nGenerating {procedure} prediction (intensity={intensity}, mode={mode})...")
    result = pipe.generate(image, procedure=procedure, intensity=intensity, seed=seed)

    # Each of these result entries is written as <key>.png.
    for key in ("input", "output", "output_raw", "output_tps", "conditioning"):
        cv2.imwrite(str(out / f"{key}.png"), result[key])

    # The mask is scaled by 255 for saving (assumes a 0-1 mask).
    mask_u8 = (result["mask"] * 255).astype(np.uint8)
    cv2.imwrite(str(out / "mask.png"), mask_u8)

    panels = [result[key] for key in ("input", "output_tps", "output")]
    cv2.imwrite(str(out / "comparison.png"), np.hstack(panels))

    view = result.get("view_info", {})
    if view.get("warning"):
        print(f"WARNING: {view['warning']}")
    print(f"Face view: {view.get('view', 'unknown')} (yaw={view.get('yaw', 0)})")
    print(f"Results saved to {out}/")
504
+
505
+
506
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse the options and hand them to run_inference.
    cli = argparse.ArgumentParser(description="LandmarkDiff inference")
    cli.add_argument("image", help="Path to face image")
    cli.add_argument("--procedure", default="rhinoplasty")
    cli.add_argument("--intensity", type=float, default=50.0)
    cli.add_argument("--output", default="scripts/inference_output")
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument(
        "--mode",
        default="img2img",
        choices=["img2img", "controlnet", "controlnet_ip", "tps"],
    )
    cli.add_argument("--ip-adapter-scale", type=float, default=0.6)
    opts = cli.parse_args()

    run_inference(
        image_path=opts.image,
        procedure=opts.procedure,
        intensity=opts.intensity,
        output_dir=opts.output,
        seed=opts.seed,
        mode=opts.mode,
        ip_adapter_scale=opts.ip_adapter_scale,
    )