Daankular committed on
Commit
45eaaef
·
verified ·
1 Parent(s): 1097d24

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +2238 -0
app.py ADDED
@@ -0,0 +1,2238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import subprocess
4
+ import tempfile
5
+ import shutil
6
+ import traceback
7
+ import json
8
+ import random
9
+ from pathlib import Path
10
+
11
+ REPO_DIR = Path(__file__).resolve().parent
12
+ PIPELINE_DIR = REPO_DIR / "pipeline"
13
+ if str(REPO_DIR) not in sys.path:
14
+ sys.path.insert(0, str(REPO_DIR))
15
+ if str(PIPELINE_DIR) not in sys.path:
16
+ sys.path.insert(0, str(PIPELINE_DIR))
17
+
18
+ try:
19
+ from pipeline.enhance_surface import (
20
+ run_stable_normal,
21
+ run_depth_anything,
22
+ bake_normal_into_glb,
23
+ bake_depth_as_occlusion,
24
+ unload_models,
25
+ )
26
+ import pipeline.enhance_surface as _enh_mod
27
+ except Exception:
28
+ from enhance_surface import (
29
+ run_stable_normal,
30
+ run_depth_anything,
31
+ bake_normal_into_glb,
32
+ bake_depth_as_occlusion,
33
+ unload_models,
34
+ )
35
+ import enhance_surface as _enh_mod
36
+
37
+ import cv2
38
+ import gradio as gr
39
+ import torch
40
+ import numpy as np
41
+ from PIL import Image
42
+
43
# --- Interpreter and external checkout locations ---------------------------
# Every path can be overridden through a MESHFORGE_* environment variable;
# the defaults point inside this repository's "external" directory.
PYTHON = os.environ.get("MESHFORGE_PYTHON", sys.executable)
TRIPOSG_DIR = os.environ.get(
    "MESHFORGE_TRIPOSG_DIR", str(REPO_DIR / "external" / "TripoSG")
)
MVADAPTER_DIR = os.environ.get(
    "MESHFORGE_MVADAPTER_DIR", str(REPO_DIR / "external" / "MV-Adapter")
)
CKPT_DIR = os.environ.get(
    "MESHFORGE_CKPT_DIR", str(Path(MVADAPTER_DIR) / "checkpoints")
)
FIRERED_DIR = os.environ.get(
    "MESHFORGE_FIRERED_DIR", str(REPO_DIR / "external" / "FireRed-Image-Edit")
)

# Scratch directory for intermediate images and meshes; created on import.
TMP_DIR = Path(os.environ.get("MESHFORGE_TMP_DIR", tempfile.gettempdir())) / "meshforge"
TMP_DIR.mkdir(parents=True, exist_ok=True)

# Process-wide environment tweaks (comments carried over from the original author):
os.environ.update(
    {
        "GRADIO_CDN_BACKEND_ENABLED": "False",
        # 8 MB chunks - fixes 504 timeout on gradio.live tunnel
        "GRADIO_UPLOAD_CHUNK_SIZE": "8388608",
        # reduces fragmentation for 17GB transformer + 5GB activations
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Lazily populated model handles (kept in memory between calls) ----------
_triposg_pipe = None
_rmbg_net = None
_last_glb_path = None
_hyperswap_sess = None
_gfpgan_restorer = None
_rmbg_version = None  # "2.0" or "1.4" once a background-removal net is loaded
_firered_pipe = None
_init_seed = random.randint(0, 2**31 - 1)

import threading

_model_load_lock = threading.Lock()

# Five 2-D landmark coordinates given on a 112-pixel grid and rescaled to a
# 256-pixel grid (scale by 256/112, then add the centring offset term).
ARCFACE_256 = (
    np.array(
        [
            [38.2946, 51.6963],
            [73.5318, 51.5014],
            [56.0252, 71.7366],
            [41.5493, 92.3655],
            [70.7299, 92.2041],
        ],
        dtype=np.float32,
    )
    * (256 / 112)
    + (256 - 112 * (256 / 112)) / 2
)

# Names of the rendered views and the PNG path each one is written to.
VIEW_NAMES = ["front", "3q_front", "side", "back", "3q_back"]
VIEW_PATHS = [str(TMP_DIR / f"render_{view}.png") for view in VIEW_NAMES]
95
+
96
+
97
+ def _build_texture_env() -> dict:
98
+ """Build subprocess env for the MV-Adapter texture subprocess.
99
+
100
+ Runs vcvarsall.bat to initialise MSVC (needed by nvdiffrast JIT), captures
101
+ the resulting environment, then layers our extra variables on top.
102
+ """
103
+ import subprocess as _sp
104
+
105
+ base_env = os.environ.copy()
106
+
107
+ # Run vcvarsall.bat x64 and capture the environment it produces
108
+ vcvarsall = (
109
+ r"C:\Program Files\Microsoft Visual Studio\2022\Professional"
110
+ r"\VC\Auxiliary\Build\vcvarsall.bat"
111
+ )
112
+ if os.path.exists(vcvarsall):
113
+ try:
114
+ result = _sp.run(
115
+ f'"{vcvarsall}" x64 && set',
116
+ shell=True,
117
+ capture_output=True,
118
+ text=True,
119
+ timeout=30,
120
+ )
121
+ for line in result.stdout.splitlines():
122
+ if "=" in line:
123
+ k, _, v = line.partition("=")
124
+ base_env[k.strip()] = v.strip()
125
+ except Exception:
126
+ pass
127
+
128
+ base_env["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9;9.0;12.0"
129
+ base_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
130
+ base_env.setdefault("CUDA_VISIBLE_DEVICES", "0")
131
+ base_env["HF_HUB_DISABLE_XET"] = "1"
132
+
133
+ try:
134
+ import ninja as _ninja
135
+ base_env["PATH"] = _ninja.BIN_DIR + os.pathsep + base_env.get("PATH", "")
136
+ except ImportError:
137
+ pass
138
+
139
+ return base_env
140
+
141
+
142
def load_triposg():
    """Return ``(_triposg_pipe, _rmbg_net)``, loading them on first use.

    First call: puts ``TRIPOSG_DIR`` at the front of ``sys.path``, downloads
    the "VAST-AI/TripoSG" snapshot, builds the pipeline in float16 on
    ``DEVICE`` and then loads the background-removal net via ``_load_rmbg()``.

    Later calls: move the cached pipeline (and the RMBG net, when present)
    back onto ``DEVICE`` and return them.
    """
    global _triposg_pipe, _rmbg_net, _rmbg_version

    if _triposg_pipe is None:
        print("Loading TripoSG pipeline...")
        sys.path.insert(0, TRIPOSG_DIR)
        from triposg.pipelines.pipeline_triposg import TripoSGPipeline
        from huggingface_hub import snapshot_download

        checkpoint_dir = snapshot_download("VAST-AI/TripoSG")
        fresh_pipe = TripoSGPipeline.from_pretrained(
            checkpoint_dir, torch_dtype=torch.float16
        )
        _triposg_pipe = fresh_pipe.to(DEVICE)

        _load_rmbg()
    else:
        # Already loaded: just make sure everything lives on the active device.
        _triposg_pipe.to(DEVICE)
        if _rmbg_net is not None:
            _rmbg_net.to(DEVICE)

    return _triposg_pipe, _rmbg_net
161
+
162
+
163
def load_gfpgan():
    """Lazily build and cache the GFPGAN v1.4 face restorer.

    Looks for ``GFPGANv1.4.pth`` in ``CKPT_DIR``. When
    ``RealESRGAN_x2plus.pth`` is also present there, a RealESRGAN x2 model is
    attached as GFPGAN's background upsampler.

    Returns:
        The cached ``GFPGANer`` instance, or ``None`` when the checkpoint is
        missing or anything fails while loading (the error is printed).
    """
    global _gfpgan_restorer
    if _gfpgan_restorer is not None:
        return _gfpgan_restorer
    try:
        from gfpgan import GFPGANer
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        model_path = os.path.join(CKPT_DIR, "GFPGANv1.4.pth")
        if not os.path.exists(model_path):
            print(f"[GFPGAN] Not found at {model_path}")
            return None

        # RealESRGAN x2plus as background upsampler (optional).
        realesrgan_path = os.path.join(CKPT_DIR, "RealESRGAN_x2plus.pth")
        bg_upsampler = None
        if os.path.exists(realesrgan_path):
            bg_model = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=2,
            )
            bg_upsampler = RealESRGANer(
                scale=2,
                model_path=realesrgan_path,
                model=bg_model,
                tile=400,
                tile_pad=10,
                pre_pad=0,
                # fp16 weights are only requested when running on CUDA;
                # half precision on a CPU-only host is not a safe default.
                half=(DEVICE == "cuda"),
            )
            print("[GFPGAN] RealESRGAN x2plus bg_upsampler loaded")
        else:
            print("[GFPGAN] RealESRGAN_x2plus.pth not found, running without upsampler")

        _gfpgan_restorer = GFPGANer(
            model_path=model_path,
            upscale=2,
            arch="clean",
            channel_multiplier=2,
            bg_upsampler=bg_upsampler,
        )
        # Report what was actually configured rather than always claiming
        # the upsampler is present.
        if bg_upsampler is not None:
            print("[GFPGAN] Loaded GFPGANv1.4 (upscale=2 + RealESRGAN bg_upsampler)")
        else:
            print("[GFPGAN] Loaded GFPGANv1.4 (upscale=2, no bg_upsampler)")
        return _gfpgan_restorer
    except Exception as e:
        print(f"[GFPGAN] Load failed: {e}")
        return None
214
+
215
+
216
def _load_rmbg():
    """Populate ``_rmbg_net`` / ``_rmbg_version`` with RMBG-2.0, else RMBG-1.4.

    No-op when a net is already loaded. RMBG-2.0 is tried first; while it
    loads, ``PreTrainedModel.mark_tied_weights_as_initialized`` is temporarily
    wrapped (only if that hook exists on the installed transformers) so that
    models lacking ``all_tied_weights_keys`` get a ``None`` default. If
    RMBG-2.0 fails for any reason, RMBG-1.4 is loaded through the ``BriaRMBG``
    class shipped with TripoSG. If that fails too, both globals stay ``None``
    and background removal is disabled.
    """
    global _rmbg_net, _rmbg_version
    if _rmbg_net is not None:
        return

    # Try RMBG-2.0 with transformers 5.x compatibility patches
    try:
        from transformers import AutoModelForImageSegmentation
        from transformers import PreTrainedModel as _PTM

        # The hook is only patched when the installed transformers has it;
        # a missing attribute must not force the RMBG-1.4 fallback.
        _orig_mark_tied = getattr(_PTM, "mark_tied_weights_as_initialized", None)
        if _orig_mark_tied is not None:

            def _safe_mark_tied(self, *args, **kwargs):
                if not hasattr(self, "all_tied_weights_keys"):
                    self.all_tied_weights_keys = None
                return _orig_mark_tied(self, *args, **kwargs)

            _PTM.mark_tied_weights_as_initialized = _safe_mark_tied

        try:
            # Load with low_cpu_mem_usage=False to avoid meta device issues
            _rmbg_net = AutoModelForImageSegmentation.from_pretrained(
                "1038lab/RMBG-2.0",
                trust_remote_code=True,
                low_cpu_mem_usage=False,
                torch_dtype=torch.float32,
            )
            _rmbg_net.to(DEVICE).eval()
            _rmbg_version = "2.0"
            print("RMBG-2.0 loaded successfully.")
        finally:
            # Always undo the monkeypatch, whether or not loading succeeded.
            if _orig_mark_tied is not None:
                _PTM.mark_tied_weights_as_initialized = _orig_mark_tied

    except Exception as e:
        print(f"RMBG-2.0 load failed ({type(e).__name__}: {str(e)[:80]}...) - falling back to RMBG-1.4")
        _rmbg_net = None
        _rmbg_version = None

        # Fallback to RMBG-1.4
        try:
            from huggingface_hub import snapshot_download
            from external.TripoSG.scripts.briarmbg import BriaRMBG

            rmbg_weights_dir = snapshot_download("briaai/RMBG-1.4")
            _rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(DEVICE).eval()
            _rmbg_version = "1.4"
            print("RMBG-1.4 fallback loaded successfully.")
        except Exception as e2:
            _rmbg_net = None
            _rmbg_version = None
            print(f"RMBG-1.4 fallback failed ({type(e2).__name__}: {str(e2)[:80]}...) - background removal disabled.")
269
+
270
+
271
def load_rmbg_only():
    """Make sure the background-removal net is loaded, without loading TripoSG.

    Delegates to ``_load_rmbg()`` and hands back the module-level
    ``_rmbg_net``, which is ``None`` when no RMBG variant could be loaded.
    """
    _load_rmbg()
    net = _rmbg_net
    return net
275
+
276
+
277
def load_firered():
    """Lazy-load the FireRed image-edit pipeline with a GGUF-quantized transformer.

    Steps, in order:
      1. Patch ``torch.nn.functional.scaled_dot_product_attention`` (once) so
         K/V are cast to the query dtype.
      2. Load RMBG first - the original author notes that ``dispatch_model``
         creates meta tensors that break later model loads.
      3. Load the transformer from a Q4_K_M GGUF file: the
         Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF build is tried first, with
         unsloth/Qwen-Image-Edit-2511-GGUF as fallback if that fails.
      4. Dispatch the transformer with an 18 GiB GPU / 90 GiB CPU budget.
      5. Load the text encoder in 4-bit NF4 and assemble the
         ``QwenImageEditPlusPipeline`` with a ``FlowMatchEulerDiscreteScheduler``.

    Size figures in the log messages (~12 GB on disk, ~5.6 GB text encoder)
    are the original author's estimates.

    Returns:
        The cached ``QwenImageEditPlusPipeline`` instance.

    Raises:
        Whatever the fallback GGUF load or any later stage raises; only the
        first transformer attempt is caught.
    """
    global _firered_pipe
    if _firered_pipe is not None:
        return _firered_pipe

    import math
    from diffusers import (
        QwenImageEditPlusPipeline,
        FlowMatchEulerDiscreteScheduler,
        GGUFQuantizationConfig,
    )
    from diffusers.models import QwenImageTransformer2DModel
    from transformers import BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration
    from accelerate import dispatch_model, infer_auto_device_map
    from huggingface_hub import hf_hub_download

    # Patch SDPA to cast K/V to match Q dtype.
    import torch.nn.functional as _F

    # This function runs again after generate_shape() resets _firered_pipe,
    # so the patch is tagged and applied only once; otherwise every reload
    # would wrap the previous wrapper.
    if not getattr(_F.scaled_dot_product_attention, "_meshforge_dtype_safe", False):
        _orig_sdpa = _F.scaled_dot_product_attention

        def _dtype_safe_sdpa(query, key, value, *a, **kw):
            if key.dtype != query.dtype:
                key = key.to(query.dtype)
            if value.dtype != query.dtype:
                value = value.to(query.dtype)
            return _orig_sdpa(query, key, value, *a, **kw)

        _dtype_safe_sdpa._meshforge_dtype_safe = True
        _F.scaled_dot_product_attention = _dtype_safe_sdpa

    torch.cuda.empty_cache()

    # Load RMBG NOW - before dispatch_model creates meta tensors that poison later loads
    _load_rmbg()

    gguf_config = GGUFQuantizationConfig(compute_dtype=torch.bfloat16)

    def _transformer_from_gguf(path):
        """Build the Qwen image transformer from one local GGUF file."""
        return QwenImageTransformer2DModel.from_single_file(
            path,
            quantization_config=gguf_config,
            torch_dtype=torch.bfloat16,
            config="Qwen/Qwen-Image-Edit-2511",
            subfolder="transformer",
        )

    # -- Transformer: GGUF Q4_K_M - try fine-tuned Rapid-AIO first, fall back to base --
    transformer = None

    # Attempt 1: Arunk25 Rapid-AIO GGUF (fine-tuned, fully merged, ~12.4 GB)
    try:
        print(
            "[FireRed] Downloading Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF Q4_K_M (~12 GB)..."
        )
        gguf_path = hf_hub_download(
            repo_id="Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF",
            filename="v23/Qwen-Rapid-AIO-NSFW-v23-Q4_K_M.gguf",
        )
        print("[FireRed] Loading Rapid-AIO transformer from GGUF...")
        transformer = _transformer_from_gguf(gguf_path)
        print("[FireRed] Rapid-AIO GGUF transformer loaded OK.")
    except Exception as e:
        print(
            f"[FireRed] Rapid-AIO GGUF failed ({e}), falling back to unsloth base GGUF..."
        )
        transformer = None

    # Attempt 2: unsloth base GGUF Q4_K_M (~12.3 GB)
    if transformer is None:
        print(
            "[FireRed] Downloading unsloth/Qwen-Image-Edit-2511-GGUF Q4_K_M (~12 GB)..."
        )
        gguf_path = hf_hub_download(
            repo_id="unsloth/Qwen-Image-Edit-2511-GGUF",
            filename="qwen-image-edit-2511-Q4_K_M.gguf",
        )
        print("[FireRed] Loading base transformer from GGUF...")
        transformer = _transformer_from_gguf(gguf_path)
        print("[FireRed] Base GGUF transformer loaded OK.")

    print("[FireRed] Dispatching transformer (18 GiB GPU, rest CPU)...")
    device_map = infer_auto_device_map(
        transformer,
        max_memory={0: "18GiB", "cpu": "90GiB"},
        dtype=torch.bfloat16,
    )
    n_gpu = sum(1 for d in device_map.values() if str(d) in ("0", "cuda", "cuda:0"))
    n_cpu = sum(1 for d in device_map.values() if str(d) == "cpu")
    print(f"[FireRed] Dispatched: {n_gpu} modules on GPU, {n_cpu} on CPU")
    transformer = dispatch_model(transformer, device_map=device_map)
    used_mb = torch.cuda.memory_allocated() // (1024**2)
    print(f"[FireRed] Transformer dispatched — VRAM: {used_mb} MB")

    # -- text_encoder: 4-bit NF4 on GPU (~5.6 GB) --
    bnb_enc = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    print("[FireRed] Loading text_encoder (4-bit NF4)...")
    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen-Image-Edit-2511",
        subfolder="text_encoder",
        quantization_config=bnb_enc,
        device_map="auto",
    )
    used_mb = torch.cuda.memory_allocated() // (1024**2)
    print(f"[FireRed] Text encoder loaded — VRAM: {used_mb} MB")

    # -- Pipeline: VAE + scheduler + processor + tokenizer --
    print("[FireRed] Loading pipeline...")
    _firered_pipe = QwenImageEditPlusPipeline.from_pretrained(
        "Qwen/Qwen-Image-Edit-2511",
        transformer=transformer,
        text_encoder=text_encoder,
        torch_dtype=torch.bfloat16,
    )
    _firered_pipe.vae.to(DEVICE)

    # Lightning scheduler - 4 steps, use_dynamic_shifting, matches reference space config
    _firered_pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
        {
            "base_image_seq_len": 256,
            "base_shift": math.log(3),
            "max_image_seq_len": 8192,
            "max_shift": math.log(3),
            "num_train_timesteps": 1000,
            "shift": 1.0,
            "time_shift_type": "exponential",
            "use_dynamic_shifting": True,
        }
    )

    used_mb = torch.cuda.memory_allocated() // (1024**2)
    print(f"[FireRed] Pipeline ready — total VRAM: {used_mb} MB")
    return _firered_pipe
429
+
430
+
431
+ def _gallery_to_pil_list(gallery_value):
432
+ """Convert a Gradio Gallery value (list of various formats) to a list of PIL Images."""
433
+ pil_images = []
434
+ if not gallery_value:
435
+ return pil_images
436
+ for item in gallery_value:
437
+ try:
438
+ if isinstance(item, np.ndarray):
439
+ pil_images.append(Image.fromarray(item).convert("RGB"))
440
+ continue
441
+ if isinstance(item, Image.Image):
442
+ pil_images.append(item.convert("RGB"))
443
+ continue
444
+ # Gradio 6 Gallery returns dicts: {"image": FileData, "caption": ...}
445
+ if isinstance(item, dict):
446
+ img_data = item.get("image") or item
447
+ if isinstance(img_data, dict):
448
+ path = (
449
+ img_data.get("path")
450
+ or img_data.get("url")
451
+ or img_data.get("name")
452
+ )
453
+ else:
454
+ path = img_data
455
+ elif isinstance(item, (list, tuple)):
456
+ path = item[0]
457
+ else:
458
+ path = item
459
+ if path and os.path.exists(str(path)):
460
+ pil_images.append(Image.open(str(path)).convert("RGB"))
461
+ except Exception as e:
462
+ print(f"[FireRed] Could not load gallery image: {e}")
463
+ return pil_images
464
+
465
+
466
+ def _firered_resize(img):
467
+ """Resize to max 1024px maintaining aspect ratio, align dims to multiple of 8."""
468
+ w, h = img.size
469
+ if max(w, h) > 1024:
470
+ if w > h:
471
+ nw, nh = 1024, int(1024 * h / w)
472
+ else:
473
+ nw, nh = int(1024 * w / h), 1024
474
+ else:
475
+ nw, nh = w, h
476
+ nw, nh = max(8, (nw // 8) * 8), max(8, (nh // 8) * 8)
477
+ if (nw, nh) != (w, h):
478
+ img = img.resize((nw, nh), Image.LANCZOS)
479
+ return img
480
+
481
+
482
+ _FIRERED_NEGATIVE = (
483
+ "worst quality, low quality, bad anatomy, bad hands, text, error, "
484
+ "missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, "
485
+ "signature, watermark, username, blurry"
486
+ )
487
+
488
+
489
def firered_generate(
    gallery_images,
    prompt,
    seed,
    randomize_seed,
    guidance_scale,
    steps,
    progress=gr.Progress(),
):
    """Run FireRed image-edit inference on one to three reference images.

    Args:
        gallery_images: Gradio Gallery value; decoded with ``_gallery_to_pil_list``.
        prompt: edit instruction; must be non-blank.
        seed: seed used when ``randomize_seed`` is false.
        randomize_seed: when true, a fresh random seed replaces ``seed``.
        guidance_scale: forwarded to the pipeline as ``true_cfg_scale``.
        steps: number of inference steps.
        progress: Gradio progress callback.

    Returns:
        ``(image, seed, status)`` where ``image`` is a numpy array, or
        ``None`` on validation failure or error; ``status`` is a
        human-readable message (a traceback on error).
    """
    pil_images = _gallery_to_pil_list(gallery_images)
    if not pil_images:
        return None, int(seed), "Please upload at least one image."
    if not prompt or not prompt.strip():
        return None, int(seed), "Please enter an edit prompt."
    try:
        import gc

        progress(0.05, desc="Loading FireRed pipeline...")
        pipe = load_firered()

        if randomize_seed:
            seed = random.randint(0, 2**31 - 1)

        # FireRed natively handles 1-3 images; cap and remember that we did,
        # so the gallery does not have to be decoded a second time later.
        was_truncated = len(pil_images) > 3
        if was_truncated:
            print(
                f"[FireRed] {len(pil_images)} images given, truncating to 3 (native limit)."
            )
            pil_images = pil_images[:3]

        # Resize to max 1024px and align to multiple of 8 (prevents padding bars)
        pil_images = [_firered_resize(img) for img in pil_images]
        # Output size follows the first reference image.
        height, width = pil_images[0].height, pil_images[0].width
        print(f"[FireRed] Input size after resize: {width}x{height}")

        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

        progress(0.4, desc=f"Running FireRed edit ({len(pil_images)} image(s))...")
        with torch.inference_mode():
            result = pipe(
                image=pil_images,
                prompt=prompt.strip(),
                negative_prompt=_FIRERED_NEGATIVE,
                num_inference_steps=int(steps),
                generator=generator,
                true_cfg_scale=float(guidance_scale),
                num_images_per_prompt=1,
                height=height,
                width=width,
            ).images[0]

        gc.collect()
        torch.cuda.empty_cache()
        progress(1.0, desc="Done!")
        n = len(pil_images)
        note = " (truncated to 3)" if was_truncated else ""
        return np.array(result), int(seed), f"Preview ready — {n} image(s) used{note}."
    except Exception:
        return None, int(seed), f"FireRed error:\n{traceback.format_exc()}"
553
+
554
+
555
def firered_load_into_pipeline(
    firered_output, threshold, erode_px, progress=gr.Progress()
):
    """Hand a FireRed result to the main pipeline, removing its background when possible.

    Args:
        firered_output: image array produced by ``firered_generate``.
        threshold: mask confidence cutoff forwarded to ``_remove_bg_rmbg``.
        erode_px: mask erosion in pixels forwarded to ``_remove_bg_rmbg``.
        progress: Gradio progress callback.

    Returns:
        ``(image, image, status)`` - the same array twice, followed by a
        status message. When RMBG is unavailable the input is passed through
        unchanged; on error both images are ``None`` and the status carries
        the traceback.
    """
    if firered_output is None:
        return None, None, "No FireRed output — generate an image first."
    try:
        progress(0.1, desc="Loading RMBG model...")
        load_rmbg_only()

        img = Image.fromarray(firered_output).convert("RGB")
        if _rmbg_net is not None:
            progress(0.5, desc="Removing background...")
            composited = _remove_bg_rmbg(
                img, threshold=float(threshold), erode_px=int(erode_px)
            )
            result = np.array(composited)
            msg = "Loaded into pipeline — background removed."
        else:
            result = firered_output
            msg = "Loaded into pipeline (RMBG unavailable — background not removed)."

        progress(1.0, desc="Done!")
        return result, result, msg
    except Exception:
        return None, None, f"Error:\n{traceback.format_exc()}"
581
+
582
+
583
def generate_shape(
    input_image,
    remove_background,
    num_steps,
    guidance_scale,
    seed,
    face_count,
    progress=gr.Progress(),
):
    """Generate a mesh from one image with TripoSG and export it as a GLB.

    Args:
        input_image: image array from the UI.
        remove_background: accepted for interface compatibility; not read here
            (the RMBG net is always passed to ``run_triposg``).
        num_steps: number of inference steps.
        guidance_scale: guidance scale for the shape pipeline.
        seed: RNG seed.
        face_count: target face count; values <= 0 are sent as ``-1``.
        progress: Gradio progress callback.

    Returns:
        ``(glb_path, status)``; ``glb_path`` is ``None`` on failure and
        ``status`` then carries the traceback.

    Side effects: drops any loaded FireRed pipeline to free VRAM, and moves
    the TripoSG / RMBG models back to the CPU once generation finishes,
    whether or not it succeeded.
    """
    global _firered_pipe
    if input_image is None:
        return None, "Please upload an image."
    try:
        progress(0.05, desc="Freeing VRAM from FireRed (if loaded)...")
        if _firered_pipe is not None:
            # dispatch_model attaches accelerate hooks - remove them before .to("cpu")
            try:
                from accelerate.hooks import remove_hook_from_submodules

                remove_hook_from_submodules(_firered_pipe.transformer)
                _firered_pipe.transformer.to("cpu")
            except Exception as _e:
                print(f"[TripoSG] Transformer CPU offload: {_e}")
            try:
                _firered_pipe.text_encoder.to("cpu")
            except Exception as _e:
                print(f"[TripoSG] TextEncoder CPU offload: {_e}")
            try:
                _firered_pipe.vae.to("cpu")
            except Exception as _e:
                print(f"[TripoSG] VAE CPU offload: {_e}")
            # Mark pipe for full reload next FireRed call (hooks are gone)
            _firered_pipe = None
            torch.cuda.empty_cache()
            print("[TripoSG] FireRed offloaded — VRAM freed for shape generation.")

        progress(0.1, desc="Loading TripoSG...")
        # Keep TRIPOSG_DIR first on sys.path without adding a duplicate
        # entry on every call.
        if TRIPOSG_DIR in sys.path:
            sys.path.remove(TRIPOSG_DIR)
        sys.path.insert(0, TRIPOSG_DIR)
        from scripts.inference_triposg import run_triposg

        pipe, rmbg_net = load_triposg()

        img = Image.fromarray(input_image).convert("RGB")
        img_path = str(TMP_DIR / "triposg_input.png")
        img.save(img_path)

        progress(0.5, desc="Generating shape (SDF diffusion)...")
        try:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                mesh = run_triposg(
                    pipe=pipe,
                    image_input=img_path,
                    rmbg_net=rmbg_net,  # always pass; TripoSG always calls it internally
                    seed=int(seed),
                    num_inference_steps=int(num_steps),
                    guidance_scale=float(guidance_scale),
                    faces=int(face_count) if int(face_count) > 0 else -1,
                )

            out_path = str(TMP_DIR / "triposg_shape.glb")
            mesh.export(out_path)
        finally:
            # Offload models to CPU to free VRAM for the texture subprocess,
            # even when generation or export raised.
            pipe.to("cpu")
            if rmbg_net is not None:
                rmbg_net.to("cpu")
            torch.cuda.empty_cache()

        return out_path, "Shape generated!"
    except Exception:
        return None, f"Error:\n{traceback.format_exc()}"
654
+
655
+
656
def _remove_bg_rmbg(img_pil, threshold=0.5, erode_px=2):
    """Remove the background with the loaded RMBG net and composite on mid-gray.

    Args:
        img_pil: input PIL image.
        threshold: float in [0, 1]; mask values below it are zeroed (raise it
            to cut more background). Values at or above it keep their soft
            alpha.
        erode_px: shrink the mask by this many pixels to remove fringe;
            0 disables erosion.

    Returns:
        An RGB PIL image at the input size, or ``img_pil`` unchanged when no
        RMBG net is loaded.
    """
    import torchvision.transforms.functional as TF
    from torchvision import transforms

    if _rmbg_net is None:
        return img_pil

    device = next(_rmbg_net.parameters()).device
    _rmbg_net.eval()

    # Resize and preprocess (mean/std normalisation constants as in the original code).
    img_resized = img_pil.resize((1024, 1024))
    img_tensor = transforms.ToTensor()(img_resized)
    img_tensor = TF.normalize(
        img_tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    ).unsqueeze(0).to(device)

    with torch.no_grad():
        result = _rmbg_net(img_tensor)

    # The net may return a tensor or a (possibly nested) list of tensors;
    # take the last entry, and its first element if that is again a list.
    if isinstance(result, (list, tuple)):
        candidate = result[-1]
        if isinstance(candidate, (list, tuple)):
            candidate = candidate[0]
    else:
        candidate = result

    # Reduce a (N, C, H, W) output to a single (H, W) mask.
    if candidate.dim() == 4:
        mask_tensor = candidate[0, 0]
    else:
        mask_tensor = candidate

    # Values outside [0, 1] mean the net returned raw logits. Checking only
    # the maximum would miss logits that are all below 1 (or negative).
    if mask_tensor.min() < 0.0 or mask_tensor.max() > 1.0:
        mask_tensor = torch.sigmoid(mask_tensor)

    mask_pil = transforms.ToPILImage()(mask_tensor.cpu())
    mask = np.array(mask_pil.resize(img_pil.size, Image.BILINEAR), dtype=np.float32) / 255.0

    # Apply threshold: zero out low-confidence pixels, keep soft alpha elsewhere.
    mask = (mask >= threshold).astype(np.float32) * mask

    # Erode mask to remove background fringe
    if erode_px > 0:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_px * 2 + 1,) * 2)
        mask = cv2.erode((mask * 255).astype(np.uint8), kernel).astype(np.float32) / 255.0

    # Composite on gray background (0.5 == mid-gray in normalised RGB).
    rgb = np.array(img_pil.convert("RGB"), dtype=np.float32) / 255.0
    alpha = mask[:, :, np.newaxis]
    composited = rgb * alpha + 0.5 * (1.0 - alpha)
    composited = (composited * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(composited)
718
+
719
+
720
+ def _load_realesrgan(scale: int = 4):
721
+ """Load RealESRGAN upsampler (x4plus by default). Returns RealESRGANer or None."""
722
+ try:
723
+ from basicsr.archs.rrdbnet_arch import RRDBNet
724
+ from realesrgan import RealESRGANer
725
+
726
+ if scale == 4:
727
+ model_path = os.path.join(CKPT_DIR, "RealESRGAN_x4plus.pth")
728
+ model = RRDBNet(
729
+ num_in_ch=3,
730
+ num_out_ch=3,
731
+ num_feat=64,
732
+ num_block=23,
733
+ num_grow_ch=32,
734
+ scale=4,
735
+ )
736
+ else:
737
+ model_path = os.path.join(CKPT_DIR, "RealESRGAN_x2plus.pth")
738
+ model = RRDBNet(
739
+ num_in_ch=3,
740
+ num_out_ch=3,
741
+ num_feat=64,
742
+ num_block=23,
743
+ num_grow_ch=32,
744
+ scale=2,
745
+ )
746
+ if not os.path.exists(model_path):
747
+ print(f"[RealESRGAN] {model_path} not found")
748
+ return None
749
+ upsampler = RealESRGANer(
750
+ scale=scale,
751
+ model_path=model_path,
752
+ model=model,
753
+ tile=512,
754
+ tile_pad=32,
755
+ pre_pad=0,
756
+ half=True,
757
+ )
758
+ print(f"[RealESRGAN] Loaded x{scale}plus")
759
+ return upsampler
760
+ except Exception as e:
761
+ print(f"[RealESRGAN] Load failed: {e}")
762
+ return None
763
+
764
+
765
def _enhance_glb_texture(glb_path: str) -> bool:
    """Sharpen the base-color texture of a GLB in place using RealESRGAN.

    The first material whose base-color texture is stored in the binary blob
    is decoded, upscaled 4x with RealESRGAN (x4 weights, else x2 weights),
    resized back to its original resolution, re-encoded as PNG and written
    back into ``glb_path``.

    Returns:
        True if a texture was enhanced and the file saved, False otherwise
        (no checkpoint available, or no usable base-color texture).

    NOTE(review): the atlas is decoded with ``cv2.IMREAD_COLOR``, so an alpha
    channel, if present, is not preserved - confirm this is acceptable.
    """
    import pygltflib

    upsampler = _load_realesrgan(scale=4)
    if upsampler is None:
        # Try x2 fallback
        upsampler = _load_realesrgan(scale=2)
    if upsampler is None:
        print("[enhance_glb] No RealESRGAN checkpoint available")
        return False

    glb = pygltflib.GLTF2().load(glb_path)
    blob = bytearray(glb.binary_blob() or b"")

    for mat in glb.materials:
        bct = getattr(mat.pbrMetallicRoughness, "baseColorTexture", None)
        if bct is None:
            continue
        tex = glb.textures[bct.index]
        if tex.source is None:
            continue
        img_obj = glb.images[tex.source]
        if img_obj.bufferView is None:
            # Image is referenced by URI rather than embedded; nothing to patch.
            continue
        bv = glb.bufferViews[img_obj.bufferView]
        offset, length = bv.byteOffset or 0, bv.byteLength

        img_arr = np.frombuffer(blob[offset : offset + length], dtype=np.uint8)
        atlas_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
        if atlas_bgr is None:
            continue

        orig_h, orig_w = atlas_bgr.shape[:2]
        print(f"[enhance_glb] atlas {orig_w}x{orig_h}, upscaling with RealESRGAN…")

        try:
            upscaled, _ = upsampler.enhance(atlas_bgr, outscale=4)
        except Exception as e:
            print(f"[enhance_glb] RealESRGAN enhance failed: {e}")
            continue

        # Downscale back to original resolution - net effect: sharper details
        restored = cv2.resize(
            upscaled, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4
        )

        ok, encoded = cv2.imencode(".png", restored)
        if not ok:
            continue
        new_bytes = encoded.tobytes()
        new_len = len(new_bytes)

        if new_len > length:
            # Grow the blob. The shift applied to later bufferViews is rounded
            # up to a multiple of 4 so their byte offsets keep the alignment
            # they had before; the gap is filled with zero padding.
            growth = new_len - length
            delta = -(-growth // 4) * 4
            blob[offset : offset + length] = new_bytes + bytes(delta - growth)
            bv.byteLength = new_len
            for other_bv in glb.bufferViews:
                if (other_bv.byteOffset or 0) > offset:
                    other_bv.byteOffset += delta
            glb.buffers[0].byteLength += delta
        else:
            # New image fits in the old slot; trailing bytes are left unused.
            blob[offset : offset + new_len] = new_bytes
            bv.byteLength = new_len

        # The payload is now PNG regardless of the original encoding.
        img_obj.mimeType = "image/png"

        glb.set_binary_blob(bytes(blob))
        glb.save(glb_path)
        print(f"[enhance_glb] GLB texture enhanced OK (was {length}B → {new_len}B)")
        return True

    print("[enhance_glb] No base-color texture found in GLB")
    return False
843
+
844
+
845
def apply_texture(
    glb_path,
    input_image,
    remove_background,
    variant,
    tex_seed,
    enhance_face,
    rembg_threshold=0.5,
    rembg_erode=2,
    progress=gr.Progress(),
):
    """
    Texture a shape GLB by running ``scripts.texture_i2tex`` in a subprocess.

    Args:
        glb_path: Shape GLB to texture; defaults to ``TMP_DIR/triposg_shape.glb``
            when None.
        input_image: Reference image as a numpy array.
        remove_background: Run background removal first (only if the RMBG
            network is already loaded).
        variant: Value forwarded to the script's ``--variant`` flag.
        tex_seed: Seed forwarded to the script's ``--seed`` flag.
        enhance_face: When true, post-process the texture atlas with
            ``_enhance_glb_texture`` (RealESRGAN).
        rembg_threshold: Threshold forwarded to ``_remove_bg_rmbg``.
        rembg_erode: Erosion in pixels forwarded to ``_remove_bg_rmbg``.
        progress: Gradio progress callback.

    Returns:
        ``(glb_path, multiview_png_path_or_None, status)`` on success, or
        ``(None, None, error_message)`` on failure. Never raises.
    """
    global _last_glb_path

    if glb_path is None:
        glb_path = str(TMP_DIR / "triposg_shape.glb")
    if not os.path.exists(glb_path):
        return None, None, "Generate a shape first."
    if input_image is None:
        return None, None, "Please upload an image."
    try:
        progress(0.1, desc="Preprocessing image...")
        img = Image.fromarray(input_image).convert("RGB")

        # Keep an unprocessed copy of the photo on disk; the original comment
        # describes it as the HyperSwap face source.
        face_ref_path = str(TMP_DIR / "triposg_face_ref.png")
        img.save(face_ref_path)

        if remove_background and _rmbg_net is not None:
            img = _remove_bg_rmbg(
                img, threshold=float(rembg_threshold), erode_px=int(rembg_erode)
            )

        img = img.resize((768, 768), Image.LANCZOS)
        img_path = str(TMP_DIR / "tex_input.png")
        img.save(img_path)

        # Free GPU memory before launching the texture subprocess.
        import gc

        gc.collect()
        torch.cuda.empty_cache()

        out_dir = str(TMP_DIR / "tex_out")
        os.makedirs(out_dir, exist_ok=True)
        out_name = "textured"
        out_glb = f"{out_dir}/{out_name}_shaded.glb"
        mv_png = f"{out_dir}/{out_name}.png"

        # Success is detected by the presence of the output file, so leftovers
        # from an earlier run must be removed or a failed run would silently
        # return the previous result.
        for stale in (out_glb, mv_png):
            if os.path.exists(stale):
                os.remove(stale)

        cmd = [
            PYTHON,
            "-m",
            "scripts.texture_i2tex",
            "--mesh",
            glb_path,
            "--image",
            img_path,
            "--save_dir",
            out_dir,
            "--save_name",
            out_name,
            "--variant",
            variant,
            "--seed",
            str(int(tex_seed)),
            "--device",
            DEVICE,
            "--reference_conditioning_scale",
            "1.5",
            "--text",
            "photorealistic person, detailed skin texture, realistic clothing",
            "--preprocess_mesh",
        ]
        # Face enhancement is handled in-app after the texture subprocess returns.

        progress(0.3, desc="Running MV-Adapter SDXL...")
        env = _build_texture_env()

        result = subprocess.run(
            cmd,
            cwd=MVADAPTER_DIR,
            capture_output=True,
            text=True,
            timeout=3600,
            env=env,
        )

        if os.path.exists(out_glb):
            final_path = str(TMP_DIR / "triposg_textured.glb")
            shutil.copy(out_glb, final_path)

            # Optional enhancement: sharpen the UV atlas inside the GLB with
            # RealESRGAN. Failure here is non-fatal; the plain texture is kept.
            face_enhanced = False
            if enhance_face:
                try:
                    face_enhanced = _enhance_glb_texture(final_path)
                except Exception as _fe:
                    print(f"[enhance_glb] {_fe}")

            mv_out = mv_png if os.path.exists(mv_png) else None
            label = "Texture applied" + (" + face enhanced!" if face_enhanced else "!")
            _last_glb_path = final_path
            return final_path, mv_out, label
        else:
            # No output file: surface the tail of the subprocess log.
            combined = (result.stdout or "") + (result.stderr or "")
            err = combined[-3000:] if combined else "No output (exit code %d)" % result.returncode
            return None, None, f"Texture failed:\n{err}"
    except Exception:
        return None, None, f"Error:\n{traceback.format_exc()}"
955
+
956
+
957
def preview_rembg(input_image, do_remove_bg, threshold, erode_px):
    """
    Produce the background-removal preview shown after an upload.

    The input is handed back untouched when there is no image, when removal is
    switched off, when the RMBG network has not been loaded yet (a blocking
    load is deliberately avoided), or when removal raises. Otherwise the result
    of ``_remove_bg_rmbg`` is returned as a numpy array.
    """
    # Short-circuit order matters: the network global is only consulted once
    # an image is present and removal is requested.
    if input_image is None or not do_remove_bg or _rmbg_net is None:
        return input_image

    try:
        rgb = Image.fromarray(input_image).convert("RGB")
        cutout = _remove_bg_rmbg(
            rgb, threshold=float(threshold), erode_px=int(erode_px)
        )
        return np.array(cutout)
    except Exception:
        # Best-effort preview: fall back to the raw upload.
        return input_image
973
+
974
+
975
def render_views(glb_file):
    """
    Render a GLB from 5 fixed azimuths using the MV-Adapter nvdiffrast helpers.

    Args:
        glb_file: A path string, a dict carrying a ``"path"`` or ``"name"``
            key, or any other object whose ``str()`` is the path.

    Returns:
        A list of ``(png_path, view_name)`` tuples for the views ``front``,
        ``3q_front``, ``side``, ``back`` and ``3q_back``. The PNGs are written
        next to the GLB as ``render_<view>.png``. An empty list is returned
        when no file is given, the file does not exist, or rendering fails,
        so the result is always a valid gallery value.
    """
    if not glb_file:
        return []
    if isinstance(glb_file, str):
        glb_path = glb_file
    elif isinstance(glb_file, dict):
        glb_path = glb_file.get("path") or glb_file.get("name") or ""
    else:
        glb_path = str(glb_file)
    if not glb_path or not os.path.exists(glb_path):
        # Report through the log only; returning a dict entry here would not
        # be a valid gallery item.
        print(f"render_views: GLB not found ({glb_path!r})")
        return []
    print(f"render_views: loading {glb_path} ({os.path.getsize(glb_path) // 1024}KB)")
    try:
        # Make the MV-Adapter checkout importable without growing sys.path
        # on every call.
        if MVADAPTER_DIR not in sys.path:
            sys.path.insert(0, MVADAPTER_DIR)
        print("render_views: importing nvdiffrast utils...")
        from mvadapter.utils.mesh_utils import (
            NVDiffRastContextWrapper,
            load_mesh,
            render,
            get_orthogonal_camera,
        )

        device = "cuda"
        ctx = NVDiffRastContextWrapper(device=device, context_type="cuda")
        print("render_views: loading mesh...")
        mesh = load_mesh(glb_path, rescale=True, device=device)
        print("render_views: mesh loaded, rendering...")

        names = ["front", "3q_front", "side", "back", "3q_back"]
        # Each requested angle is offset by -90 degrees before being passed on.
        azimuth_deg = [x - 90 for x in [0, 45, 90, 180, 315]]
        cameras = get_orthogonal_camera(
            elevation_deg=[0, 0, 0, 0, 0],
            distance=[1.8] * 5,
            left=-0.55,
            right=0.55,
            bottom=-0.55,
            top=0.55,
            azimuth_deg=azimuth_deg,
            device=device,
        )

        render_out = render(
            ctx,
            mesh,
            cameras,
            height=1024,
            width=768,
            render_attr=True,
            normal_background=0.0,
        )
        print(f"render_views: render complete, attr shape={render_out.attr.shape}")

        save_dir = os.path.dirname(glb_path)
        results = []
        for i, name in enumerate(names):
            # Scale the rendered attribute image to 8-bit before saving.
            arr = (render_out.attr[i].cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
            path = os.path.join(save_dir, f"render_{name}.png")
            Image.fromarray(arr).save(path)
            results.append((path, name))
            print(f"render_views: saved {name} -> {path}")

        return results
    except Exception:
        err = traceback.format_exc()
        print(f"render_views FAILED:\n{err}")
        return []
1044
+
1045
+
1046
def hyperswap_views(embedding_json: str):
    """
    Run the HyperSwap ONNX face swap on the previously rendered views.

    For each existing file in ``VIEW_PATHS`` a face is detected with
    InsightFace, aligned to a 256x256 crop, swapped using the supplied
    identity embedding, blended back with the model's mask and optionally
    restored with GFPGAN. Views where the image cannot be swapped (no face,
    failed alignment, failed write) fall back to the original render.

    Args:
        embedding_json: JSON list holding the identity embedding (the original
            documentation describes it as a 512-d ArcFace embedding).

    Returns:
        A list of ``(image_path, view_name)`` tuples, or ``[]`` if the stage
        fails as a whole (missing dependencies, bad embedding, ...).
    """
    global _hyperswap_sess
    try:
        import onnxruntime as ort
        from insightface.app import FaceAnalysis

        embedding = np.array(json.loads(embedding_json), dtype=np.float32)
        norm = float(np.linalg.norm(embedding))
        if not np.isfinite(norm) or norm == 0.0:
            # Dividing by this would feed NaN/inf into the swap model.
            raise ValueError("identity embedding has zero or non-finite norm")
        embedding /= norm

        # Load HyperSwap once and keep the session in a module-level global.
        if _hyperswap_sess is None:
            hs_path = os.path.join(CKPT_DIR, "hyperswap_1a_256.onnx")
            _hyperswap_sess = ort.InferenceSession(
                hs_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
            )
            print(f"[hyperswap_views] Loaded {hs_path}")

        app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
        app.prepare(ctx_id=0, det_size=(640, 640), det_thresh=0.1)

        results = []
        for view_path, name in zip(VIEW_PATHS, VIEW_NAMES):
            if not os.path.exists(view_path):
                print(f"[hyperswap_views] Missing {view_path}, skipping")
                continue

            bgr = cv2.imread(view_path)
            if bgr is None:
                # File exists but is not a decodable image.
                print(f"[hyperswap_views] {name}: could not read {view_path}, skipping")
                continue

            faces = app.get(bgr)
            face = faces[0] if faces else None
            M = None
            if face is not None:
                M, _ = cv2.estimateAffinePartial2D(
                    face.kps, ARCFACE_256, method=cv2.RANSAC, ransacReprojThreshold=100
                )

            if face is None:
                print(f"[hyperswap_views] {name}: no face detected")
                out_path = view_path  # return original
            elif M is None:
                # Landmark fit failed; keep this view instead of aborting all.
                print(f"[hyperswap_views] {name}: face alignment failed")
                out_path = view_path
            else:
                H, W = bgr.shape[:2]
                aligned = cv2.warpAffine(bgr, M, (256, 256), flags=cv2.INTER_LINEAR)
                # Normalise to [-1, 1], flip BGR -> RGB, and lay out as NCHW.
                t = (
                    ((aligned.astype(np.float32) / 255 - 0.5) / 0.5)[:, :, ::-1]
                    .copy()
                    .transpose(2, 0, 1)[None]
                )
                out, mask = _hyperswap_sess.run(
                    None,
                    {
                        "source": embedding.reshape(1, -1),
                        "target": t,
                    },
                )
                # Undo the normalisation and channel flip on the model output.
                out_bgr = (
                    ((out[0].transpose(1, 2, 0) + 1) / 2 * 255)
                    .clip(0, 255)
                    .astype(np.uint8)
                )[:, :, ::-1].copy()
                m = (mask[0, 0] * 255).clip(0, 255).astype(np.uint8)
                # Warp the swapped crop and its mask back into image space and blend.
                Mi = cv2.invertAffineTransform(M)
                of = cv2.warpAffine(out_bgr, Mi, (W, H), flags=cv2.INTER_LINEAR)
                mf = (
                    cv2.warpAffine(m, Mi, (W, H), flags=cv2.INTER_LINEAR).astype(
                        np.float32
                    )[:, :, None]
                    / 255
                )
                swapped = (of * mf + bgr * (1 - mf)).clip(0, 255).astype(np.uint8)

                # GFPGAN face restoration: reuse the bbox of the face that was
                # already detected instead of re-running detection on the crop.
                restorer = load_gfpgan()
                if restorer is not None:
                    b = face.bbox.astype(int)
                    h2, w2 = swapped.shape[:2]
                    pad = 0.35
                    bw2, bh2 = b[2] - b[0], b[3] - b[1]
                    cx1 = max(0, b[0] - int(bw2 * pad))
                    cy1 = max(0, b[1] - int(bh2 * pad))
                    cx2 = min(w2, b[2] + int(bw2 * pad))
                    cy2 = min(h2, b[3] + int(bh2 * pad))
                    crop = swapped[cy1:cy2, cx1:cx2]
                    try:
                        _, _, rest = restorer.enhance(
                            crop,
                            has_aligned=False,
                            only_center_face=True,
                            paste_back=True,
                            weight=0.5,
                        )
                        if rest is not None:
                            ch, cw = cy2 - cy1, cx2 - cx1
                            if rest.shape[:2] != (ch, cw):
                                rest = cv2.resize(
                                    rest, (cw, ch), interpolation=cv2.INTER_LANCZOS4
                                )
                            swapped[cy1:cy2, cx1:cx2] = rest
                    except Exception as _ge:
                        print(f"[hyperswap_views] GFPGAN failed: {_ge}")

                # Rename only the file name; a plain str.replace on the whole
                # path would also rewrite directory names containing "render_".
                view_dir, view_file = os.path.split(view_path)
                out_path = os.path.join(
                    view_dir, view_file.replace("render_", "swapped_")
                )
                if cv2.imwrite(out_path, swapped):
                    print(f"[hyperswap_views] {name}: swapped+restored OK -> {out_path}")
                else:
                    print(f"[hyperswap_views] {name}: could not write {out_path}, using original")
                    out_path = view_path

            results.append((out_path, name))

        return results
    except Exception:
        err = traceback.format_exc()
        print(f"hyperswap_views FAILED:\n{err}")
        return []
1159
+
1160
+
1161
def gradio_tpose(glb_state_path, export_skel_flag, progress=gr.Progress()):
    """
    Rig the surface mesh with ``rig_yolo`` and optionally export a SKEL bone mesh.

    Returns ``(rigged_path, bones_path_or_None, status_text)``; on any
    exception returns ``(None, None, "Error:\\n<traceback>")``.
    """
    try:
        # Source GLB: explicit state, then the last textured result, then the default path.
        source_glb = (
            glb_state_path or _last_glb_path or str(TMP_DIR / "triposg_textured.glb")
        )
        if not os.path.exists(source_glb):
            return (
                None,
                None,
                "No GLB found β€” run Generate Shape + Apply Texture first.",
            )

        # Surface rigging via the rig_yolo module located under /root.
        progress(0.1, desc="YOLO pose detection + rigging surface ...")
        sys.path.insert(0, "/root")
        from rig_yolo import rig_yolo

        rig_dir = str(TMP_DIR / "rig_out")
        os.makedirs(rig_dir, exist_ok=True)
        target_glb = os.path.join(rig_dir, "anatomy_rigged.glb")
        rigged, _rigged_skel = rig_yolo(source_glb, target_glb, debug_dir=None)

        # Optional SKEL bone mesh, generated from an all-zero 10-vector.
        bones = None
        if export_skel_flag:
            progress(0.7, desc="Generating SKEL bone mesh ...")
            import torch
            from tpose_smpl import export_skel_bones

            bones = export_skel_bones(
                torch.zeros(10), str(TMP_DIR / "tposed_bones.glb"), gender="male"
            )

        report = [f"Rigged surface: {os.path.getsize(rigged) // 1024} KB"]
        if bones:
            report.append(f"SKEL bone mesh: {os.path.getsize(bones) // 1024} KB")
        elif export_skel_flag:
            report.append("SKEL bone mesh: failed (check logs)")
        progress(1.0, desc="Done!")
        return rigged, bones, "\n".join(report)
    except Exception:
        return None, None, f"Error:\n{traceback.format_exc()}"
1203
+
1204
+
1205
# Locations of the UniRig checkout and its dedicated conda environment.
# Each can be overridden with an environment variable of the same name; the
# defaults are the paths that were previously hard-coded. An empty variable
# is treated as unset.
UNIRIG_DIR = os.environ.get("UNIRIG_DIR") or "/root/UniRig"
UNIRIG_PY = os.environ.get("UNIRIG_PY") or "/root/miniconda/envs/unirig/bin/python"
UNIRIG_BASH = os.environ.get("UNIRIG_BASH") or "/root/miniconda/envs/unirig/bin"  # prepended to PATH for launch scripts
1208
+
1209
+
1210
def _run_unirig(glb_path: str, out_dir: str) -> str:
    """
    Run the 3-step UniRig pipeline (skeleton, skinning, merge) on a GLB.

    Args:
        glb_path: Input mesh.
        out_dir: Directory for intermediate FBX files and the rigged GLB;
            created if missing.

    Returns:
        Path to ``<out_dir>/rigged.glb``.

    Raises:
        RuntimeError: If the UniRig interpreter is missing, a launch script
            exits non-zero, or no rigged output can be found afterwards.
        subprocess.TimeoutExpired: If a step exceeds its 300 s timeout.
    """
    if not os.path.exists(UNIRIG_PY):
        raise RuntimeError("UniRig conda env not found β€” run setup_unirig.sh first")

    os.makedirs(out_dir, exist_ok=True)
    skel_fbx = os.path.join(out_dir, "skeleton.fbx")
    skin_fbx = os.path.join(out_dir, "skin.fbx")
    rigged = os.path.join(out_dir, "rigged.glb")
    fallback = str(TMP_DIR / "rig_out" / "rigged.glb")

    # Drop artefacts of a previous run so a step that exits 0 without writing
    # its output cannot be mistaken for success. The input itself is never
    # deleted (it may be an earlier rigged.glb fed back in).
    source_abs = os.path.abspath(glb_path)
    for stale in (skel_fbx, skin_fbx, rigged, fallback):
        if os.path.isfile(stale) and os.path.abspath(stale) != source_abs:
            os.remove(stale)

    env = os.environ.copy()
    env["PATH"] = f"{UNIRIG_BASH}:{env.get('PATH', '')}"
    env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    env.setdefault("CUDA_VISIBLE_DEVICES", "0")

    def _launch(script: str, extra_args: list[str]):
        """Run one UniRig launch script; raise RuntimeError on a non-zero exit."""
        sh = os.path.join(UNIRIG_DIR, "launch", "inference", script)
        cmd = ["bash", sh] + extra_args
        r = subprocess.run(
            cmd, cwd=UNIRIG_DIR, capture_output=True, text=True, timeout=300, env=env
        )
        if r.returncode != 0:
            # Prefer stderr, but fall back to stdout so the message is not empty.
            log_tail = (r.stderr or r.stdout or "")[-2000:]
            raise RuntimeError(f"{script} failed:\n{log_tail}")
        return r

    print("[UniRig] Step 1/3 β€” generate skeleton...")
    _launch("generate_skeleton.sh", ["--input", glb_path, "--output", skel_fbx])

    print("[UniRig] Step 2/3 β€” generate skinning...")
    _launch("generate_skin.sh", ["--input", skel_fbx, "--output", skin_fbx])

    print("[UniRig] Step 3/3 β€” merge rig into mesh...")
    _launch(
        "merge.sh", ["--source", skin_fbx, "--target", glb_path, "--output", rigged]
    )

    # Per the original note, UniRig may ignore the requested output directory
    # and write to TMP_DIR/rig_out/rigged.glb instead; fall back to that file
    # if the requested path was not populated.
    if not os.path.exists(rigged):
        if os.path.exists(fallback):
            import shutil

            shutil.copy2(fallback, rigged)
        else:
            raise RuntimeError(
                f"UniRig finished but output not found at {rigged} or {fallback}"
            )

    print(f"[UniRig] Done β€” {os.path.getsize(rigged) // 1024} KB")
    return rigged
1264
+
1265
+
1266
def gradio_rig(
    input_image,
    glb_state_path,
    export_fbx_flag,
    pshuman_weight_threshold: float,
    pshuman_retract_mm: float,
    progress=gr.Progress(),
):
    """
    One-click rig pipeline.

    1. UniRig produces skeleton + skinning for the textured mesh.
    2. If a portrait is supplied, PSHuman generates a face mesh from it
       (the portrait is first converted to RGBA).
    3. The PSHuman face is transplanted into the rigged mesh.

    A failure in stages 2-3 is logged and the plain UniRig result is used.
    Returns a 7-tuple ``(final, None, fbx, status, final, final, None)``; on a
    fatal error every slot is None except the status text.
    """
    try:
        source_glb = (
            glb_state_path or _last_glb_path or str(TMP_DIR / "triposg_textured.glb")
        )
        if not os.path.exists(source_glb):
            missing_msg = "No GLB found β€” run Generate Shape + Apply Texture first."
            return (None, None, None, missing_msg, None, None, None)

        rig_dir = str(TMP_DIR / "rig_out")
        os.makedirs(rig_dir, exist_ok=True)

        # Stage 1: UniRig
        progress(0.05, desc="Stage 1/3: UniRig β€” generating skeleton + skinning...")
        rigged = _run_unirig(source_glb, rig_dir)
        final = rigged

        # Stages 2+3: PSHuman face (only when a portrait is loaded)
        if input_image is not None:
            try:
                # Locate the directory that contains the `pipeline` package:
                # ./MeshForge next to this file if present, else this file's dir.
                here = os.path.dirname(os.path.abspath(__file__))
                pipeline_root = os.path.join(here, "MeshForge")
                if not os.path.isdir(pipeline_root):
                    pipeline_root = here
                if pipeline_root not in sys.path:
                    sys.path.insert(0, pipeline_root)

                scratch = tempfile.mkdtemp(prefix="pshuman_rig_")
                portrait_png = os.path.join(scratch, "portrait.png")

                progress(
                    0.6,
                    desc="Stage 2/3: PSHuman β€” RMBG + multi-view face generation...",
                )
                if isinstance(input_image, np.ndarray):
                    portrait = Image.fromarray(input_image)
                else:
                    portrait = input_image
                _portrait_to_rgba(portrait).save(portrait_png)

                from pipeline.pshuman_client import generate_pshuman_mesh

                face_mesh = os.path.join(scratch, "pshuman_face.obj")
                generate_pshuman_mesh(
                    image_path=portrait_png, output_path=face_mesh, service_url="direct"
                )

                progress(
                    0.85,
                    desc="Stage 3/3: Face transplant β€” stitching into rigged mesh...",
                )
                from pipeline.face_transplant import transplant_face

                final = os.path.join(scratch, "rigged_hd_face.glb")
                transplant_face(
                    body_glb_path=rigged,
                    pshuman_mesh_path=face_mesh,
                    output_path=final,
                    weight_threshold=float(pshuman_weight_threshold),
                    # UI value is in millimetres; converted by /1000 before the call.
                    retract_amount=float(pshuman_retract_mm) / 1000.0,
                )
                print(f"[rig] PSHuman face transplant complete: {final}")
            except Exception as _pse:
                print(
                    f"[rig] PSHuman stage failed, using plain rig: {_pse}\n{traceback.format_exc()}"
                )
                final = rigged

        fbx = None
        if export_fbx_flag:
            progress(0.92, desc="Exporting FBX...")
            try:
                sys.path.insert(0, "/root")
                from rig_stage import export_fbx as _export_fbx

                fbx_target = os.path.join(rig_dir, "rigged.fbx")
                if _export_fbx(final, fbx_target):
                    fbx = fbx_target
                else:
                    fbx = None
            except Exception as _fe:
                print(f"[rig] FBX export failed: {_fe}")

        # The face counts as transplanted only if `final` moved off the plain rig.
        had_pshuman = input_image is not None and final != rigged
        prefix = "Rigged + PSHuman HD face: " if had_pshuman else "Rigged: "
        status_msg = prefix + os.path.basename(final)
        if fbx:
            status_msg += " | FBX: " + os.path.basename(fbx)
        progress(1.0, desc="Done!")
        return final, None, fbx, status_msg, final, final, None
    except Exception:
        return None, None, None, f"Error:\n{traceback.format_exc()}", None, None, None
1379
+
1380
+
1381
def run_full_pipeline(
    input_image,
    remove_background,
    num_steps,
    guidance,
    seed,
    face_count,
    variant,
    tex_seed,
    enhance_face,
    rembg_threshold,
    rembg_erode,
    export_fbx,
    progress=gr.Progress(),
):
    """
    Single-click full pipeline: shape, then texture, then rig.

    Returns:
        A 6-tuple ``(glb, glb, multiview_image, rigged_glb, fbx, status)``.
        If shape generation or texturing fails, the first five slots are None
        and ``status`` carries the failure message, so every exit path has the
        same arity.

    NOTE(review): the early exits previously returned 7 values while the
    success path returned 6; they are aligned to 6 here on the assumption that
    the success path matches the outputs wired to this handler. Confirm
    against the click() registration.
    """
    progress(0.0, desc="Stage 1/3: Generating shape...")
    glb, status = generate_shape(
        input_image, remove_background, num_steps, guidance, seed, face_count
    )
    if not glb:
        return None, None, None, None, None, status

    progress(0.33, desc="Stage 2/3: Applying texture + face enhancement...")
    glb, mv_img, status = apply_texture(
        glb,
        input_image,
        remove_background,
        variant,
        tex_seed,
        enhance_face,
        rembg_threshold,
        rembg_erode,
    )
    if not glb:
        return None, None, None, None, None, status

    progress(0.66, desc="Stage 3/3: Rigging (UniRig + PSHuman)...")
    # Only the rigged GLB, the FBX and the status text are used from the
    # 7-tuple; 0.5 / 2.0 are the weight threshold and retraction (mm) passed
    # to the face transplant.
    rigged, _, fbx, rig_status, _, _, _ = gradio_rig(
        input_image, glb, export_fbx, 0.5, 2.0
    )

    progress(1.0, desc="Pipeline complete!")
    combined_status = f"[Texture] {status}\n[Rig] {rig_status}"
    return glb, glb, mv_img, rigged, fbx, combined_status
1426
+
1427
+
1428
+ # ─────────────────────────────────────────────────────────────────────────────
1429
+ # Animate tab — motion search + bake
1430
+ # ─────────────────────────────────────────────────────────────────────────────
1431
+
1432
+
1433
def gradio_search_motions(query: str, progress=gr.Progress()):
    """
    Search motions via ``Retarget.search.search_motions`` and expose them as radio choices.

    Args:
        query: Free-text motion description; None, empty or whitespace-only
            input is rejected with a hint instead of raising.
        progress: Gradio progress callback.

    Returns:
        ``(radio_update, results, status_text)`` where ``results`` is the raw
        list returned by the search (empty on no match or on error).
    """
    # `query` can be None when the textbox was never touched; guard before .strip().
    if not query or not query.strip():
        return (
            gr.update(choices=[], visible=False),
            [],
            "Enter a motion description and click Search.",
        )
    try:
        progress(0.1, desc="Connecting to HumanML3D dataset…")
        sys.path.insert(0, "/root")
        sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
        from Retarget.search import search_motions, format_choice_label

        progress(0.3, desc="Streaming dataset…")
        results = search_motions(query, top_k=8)
        progress(1.0)
        if not results:
            return (
                gr.update(
                    choices=["No matches β€” try different keywords"], visible=True
                ),
                [],
                f"No motions matched '{query}'. Try broader terms.",
            )
        choices = [format_choice_label(r) for r in results]
        status = f"Found {len(results)} motions matching '{query}'"
        return (
            # Pre-select the first hit so Animate works without an extra click.
            gr.update(choices=choices, value=choices[0], visible=True),
            results,
            status,
        )
    except Exception:
        return (
            gr.update(choices=[], visible=False),
            [],
            f"Search error:\n{traceback.format_exc()}",
        )
1471
+
1472
+
1473
def gradio_animate(
    rigged_glb_path,
    selected_label: str,
    motion_results: list,
    fps: int,
    max_frames: int,
    progress=gr.Progress(),
):
    """
    Bake the motion chosen in the search results onto the rigged GLB.

    Returns ``(animated_glb, status_text, animated_glb)``; the GLB slots are
    None when preconditions are not met or an exception occurs.
    """
    try:
        source_glb = rigged_glb_path or str(TMP_DIR / "rig_out" / "rigged.glb")
        if not os.path.exists(source_glb):
            return None, "No rigged GLB β€” run the Rig step first.", None

        if not motion_results or not selected_label:
            return None, "No motion selected β€” run Search first.", None

        sys.path.insert(0, "/root")
        sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
        from Retarget.search import format_choice_label

        # Index of the first result whose label matches the radio selection;
        # falls back to the first result when nothing matches.
        picked = next(
            (
                pos
                for pos, entry in enumerate(motion_results)
                if format_choice_label(entry) == selected_label
            ),
            0,
        )

        chosen = motion_results[picked]
        motion = chosen["motion"]
        caption = chosen["caption"]
        total_frames = motion.shape[0]
        # A non-positive max_frames means "use the whole clip".
        if max_frames > 0:
            n_frames = min(max_frames, total_frames)
        else:
            n_frames = total_frames

        progress(0.2, desc="Parsing skeleton…")
        from Retarget.animate import animate_glb_from_hml3d

        out_path = str(TMP_DIR / "animated_out" / "animated.glb")
        os.makedirs(str(TMP_DIR / "animated_out"), exist_ok=True)

        progress(0.4, desc="Mapping bones to SMPL joints…")
        animated = animate_glb_from_hml3d(
            motion=motion,
            rigged_glb=source_glb,
            output_glb=out_path,
            fps=int(fps),
            num_frames=int(n_frames),
        )
        progress(1.0, desc="Done!")
        summary = f"Animated: {n_frames} frames @ {fps} fps\nMotion: {caption[:120]}"
        return animated, summary, animated

    except Exception:
        return None, f"Error:\n{traceback.format_exc()}", None
1527
+
1528
+
1529
+ # ─────────────────────────────────────────────────────────────────────────────
1530
+ # PSHuman Face Transplant tab
1531
+ # ─────────────────────────────────────────────────────────────────────────────
1532
+
1533
+
1534
def _portrait_to_rgba(img_pil: Image.Image) -> Image.Image:
    """
    Convert a portrait to RGBA with the RMBG foreground mask as alpha.

    If the RMBG network is unavailable after ``load_rmbg_only()``, the image
    is returned as plain RGBA (fully opaque). The original documentation
    states that PSHuman reads the fourth channel as its mask.

    NOTE(review): this moves the shared ``_rmbg_net`` to the CPU and leaves it
    there; confirm that other users of the network cope with that.
    """
    import torchvision.transforms.functional as _TF
    from torchvision import transforms as _tvt

    load_rmbg_only()
    if _rmbg_net is None:
        return img_pil.convert("RGBA")

    # Inference runs on the CPU (the original comment says this keeps the GPU
    # free for the PSHuman subprocess).
    _rmbg_net.to("cpu").eval()

    rgb = img_pil.convert("RGB")
    resized = rgb.resize((1024, 1024))
    tensor = _tvt.ToTensor()(resized)
    tensor = _TF.normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    batch = tensor.unsqueeze(0)

    with torch.no_grad():
        result = _rmbg_net(batch)

    # The network may return a tensor, a sequence of tensors, or a sequence of
    # sequences; take the last entry and unwrap one more level if needed.
    candidate = result
    if isinstance(result, (list, tuple)):
        candidate = result[-1]
        if isinstance(candidate, (list, tuple)):
            candidate = candidate[0]

    # First batch item, first channel, squashed to [0, 1], resized back to the
    # portrait's own resolution.
    alpha = candidate.sigmoid()[0, 0].cpu()
    alpha_img = _tvt.ToPILImage()(alpha).resize(rgb.size, Image.BILINEAR)

    out = rgb.convert("RGBA")
    out.putalpha(alpha_img)
    return out
1570
+
1571
+
1572
def gradio_pshuman_face(
    input_image,
    rigged_glb_path,
    weight_threshold: float,
    retract_mm: float,
    progress=gr.Progress(),
):
    """
    PSHuman face transplant for an already rigged GLB.

    1. Convert the portrait to RGBA via ``_portrait_to_rgba``.
    2. Run ``generate_pshuman_mesh`` on it to obtain a face OBJ.
    3. Call ``transplant_face`` to merge that face into the rigged GLB.

    Returns ``(out_glb, status_text, out_glb)``; the GLB slots are None when a
    precondition is missing or an exception occurs.
    """
    try:
        if input_image is None:
            return None, "No portrait found β€” run Generate first.", None
        rigged = rigged_glb_path
        rigged_missing = not rigged or not os.path.exists(str(rigged))
        if rigged_missing:
            return None, "No rigged GLB found β€” run Rig & Export first.", None

        scratch = tempfile.mkdtemp(prefix="pshuman_transplant_")
        portrait_png = os.path.join(scratch, "portrait.png")

        progress(0.03, desc="Preparing portrait (RMBG β†’ RGBA)...")
        if isinstance(input_image, np.ndarray):
            portrait = Image.fromarray(input_image)
        else:
            portrait = input_image
        rgba = _portrait_to_rgba(portrait)
        rgba.save(portrait_png)
        print(f"[pshuman] Portrait saved as RGBA {rgba.size} β†’ {portrait_png}")

        # Make the `pipeline` package importable: prefer ./MeshForge next to
        # this file, otherwise this file's own directory.
        here = os.path.dirname(os.path.abspath(__file__))
        pipeline_root = os.path.join(here, "MeshForge")
        if not os.path.isdir(pipeline_root):
            pipeline_root = here
        if pipeline_root not in sys.path:
            sys.path.insert(0, pipeline_root)

        # Step 2: PSHuman inference
        progress(0.08, desc="Step 2/3: Running PSHuman (multi-view face generation)...")
        from pipeline.pshuman_client import generate_pshuman_mesh

        face_mesh = os.path.join(scratch, "pshuman_face.obj")
        generate_pshuman_mesh(
            image_path=portrait_png,
            output_path=face_mesh,
            service_url="direct",
        )

        # Step 3: transplant into the rigged GLB
        progress(0.7, desc="Step 3/3: Transplanting PSHuman face into rigged GLB...")
        result_glb = os.path.join(scratch, "rigged_pshuman_face.glb")

        from pipeline.face_transplant import transplant_face

        transplant_face(
            body_glb_path=str(rigged),
            pshuman_mesh_path=face_mesh,
            output_path=result_glb,
            weight_threshold=float(weight_threshold),
            # UI value is in millimetres; converted by /1000 before the call.
            retract_amount=float(retract_mm) / 1000.0,
        )

        progress(1.0, desc="Done!")
        return result_glb, "PSHuman face transplant complete.", result_glb

    except Exception:
        return None, f"Error:\n{traceback.format_exc()}", None
1645
+
1646
+
1647
+ # ── UI ────────────────────────────────────────────────────────────────────────
1648
+ with gr.Blocks(title="TripoSG + MV-Adapter 3D Studio", theme=gr.themes.Soft()) as demo:
1649
+ gr.Markdown("# TripoSG + MV-Adapter 3D Studio")
1650
+ glb_state = gr.State(None)
1651
+ rigged_glb_state = gr.State(None) # persists UniRig output for Animate tab
1652
+
1653
+ with gr.Tabs() as tabs:
1654
+ # ════════════════════════════════════════════════════════════════════
1655
+ with gr.Tab("Edit", id=0):
1656
+ gr.Markdown(
1657
+ "### Image Edit β€” FireRed\n"
1658
+ "Upload one or more reference images, write an edit prompt, preview the result, "
1659
+ "then click **Load to Generate** to send it to the 3D pipeline."
1660
+ )
1661
+ with gr.Row():
1662
+ with gr.Column(scale=1):
1663
+ firered_gallery = gr.Gallery(
1664
+ label="Reference Images (1–3 images, drag & drop)",
1665
+ interactive=True,
1666
+ columns=3,
1667
+ height=220,
1668
+ object_fit="contain",
1669
+ )
1670
+ firered_prompt = gr.Textbox(
1671
+ label="Edit Prompt",
1672
+ placeholder="make the person wear a red jacket",
1673
+ lines=2,
1674
+ )
1675
+ with gr.Row():
1676
+ firered_seed = gr.Number(
1677
+ value=_init_seed, label="Seed", precision=0
1678
+ )
1679
+ firered_rand = gr.Checkbox(label="Random Seed", value=True)
1680
+ with gr.Row():
1681
+ firered_guidance = gr.Slider(
1682
+ 1.0, 10.0, value=1.0, step=0.5, label="Guidance Scale"
1683
+ )
1684
+ firered_steps = gr.Slider(
1685
+ 1, 40, value=4, step=1, label="Inference Steps"
1686
+ )
1687
+ firered_btn = gr.Button("Generate Preview", variant="secondary")
1688
+ firered_status = gr.Textbox(
1689
+ label="Status", lines=2, interactive=False
1690
+ )
1691
+ with gr.Column(scale=1):
1692
+ firered_output_img = gr.Image(
1693
+ label="FireRed Output", type="numpy", interactive=False
1694
+ )
1695
+ load_to_generate_btn = gr.Button(
1696
+ "Load to Generate", variant="primary"
1697
+ )
1698
+
1699
+ # ════════════════════════════════════════════════════════════════════
1700
+ with gr.Tab("Generate", id=1):
1701
+ with gr.Row():
1702
+ with gr.Column(scale=1):
1703
+ input_image = gr.Image(label="Input Image", type="numpy")
1704
+ remove_bg_check = gr.Checkbox(label="Remove Background", value=True)
1705
+ with gr.Row():
1706
+ rembg_threshold = gr.Slider(
1707
+ 0.1,
1708
+ 0.95,
1709
+ value=0.5,
1710
+ step=0.05,
1711
+ label="BG Threshold (higher = stricter)",
1712
+ )
1713
+ rembg_erode = gr.Slider(
1714
+ 0, 8, value=2, step=1, label="Edge Erode (px)"
1715
+ )
1716
+
1717
+ with gr.Accordion("Shape Settings", open=True):
1718
+ num_steps = gr.Slider(
1719
+ 20, 100, value=50, step=5, label="Inference Steps"
1720
+ )
1721
+ guidance = gr.Slider(
1722
+ 1.0, 20.0, value=7.0, step=0.5, label="Guidance Scale"
1723
+ )
1724
+ seed = gr.Number(value=_init_seed, label="Seed", precision=0)
1725
+ face_count = gr.Number(
1726
+ value=0, label="Max Faces (0 = unlimited)", precision=0
1727
+ )
1728
+
1729
+ with gr.Accordion("Texture Settings", open=True):
1730
+ variant = gr.Radio(
1731
+ ["sdxl", "sd21"],
1732
+ value="sdxl",
1733
+ label="Model (sdxl = better quality, sd21 = less VRAM)",
1734
+ )
1735
+ tex_seed = gr.Number(
1736
+ value=_init_seed, label="Texture Seed", precision=0
1737
+ )
1738
+ enhance_face_check = gr.Checkbox(
1739
+ label="Enhance Face (HyperSwap + RealESRGAN)", value=True
1740
+ )
1741
+
1742
+ with gr.Row():
1743
+ shape_btn = gr.Button(
1744
+ "Generate Shape",
1745
+ variant="primary",
1746
+ scale=2,
1747
+ interactive=False,
1748
+ )
1749
+ texture_btn = gr.Button(
1750
+ "Apply Texture", variant="secondary", scale=2
1751
+ )
1752
+ render_btn = gr.Button(
1753
+ "Render Views", variant="secondary", scale=1
1754
+ )
1755
+ run_all_btn = gr.Button(
1756
+ "β–Ά Run Full Pipeline (Shape + Texture + Rig)",
1757
+ variant="primary",
1758
+ interactive=False,
1759
+ )
1760
+
1761
+ with gr.Column(scale=1):
1762
+ rembg_preview = gr.Image(
1763
+ label="BG Removed Preview", type="numpy", interactive=False
1764
+ )
1765
+ status = gr.Textbox(label="Status", lines=3, interactive=False)
1766
+ model_3d = gr.Model3D(
1767
+ label="3D Preview", clear_color=[0.9, 0.9, 0.9, 1.0]
1768
+ )
1769
+ download_file = gr.File(label="Download GLB")
1770
+ multiview_img = gr.Image(
1771
+ label="Multiview", type="filepath", interactive=False
1772
+ )
1773
+
1774
+ render_gallery = gr.Gallery(label="Rendered Views", columns=5, height=300)
1775
+
1776
+ # ── wiring: Generate tab ──────────────────────────────────────
1777
# Inputs shared by every background-removal preview refresh.
_rembg_inputs = [input_image, remove_bg_check, rembg_threshold, rembg_erode]
# Buttons that only make sense once an input image is present.
_pipeline_btns = [shape_btn, run_all_btn]

# Enable the pipeline buttons on upload, disable them again when the image is cleared.
input_image.upload(
    fn=lambda: (gr.update(interactive=True), gr.update(interactive=True)),
    inputs=[],
    outputs=_pipeline_btns,
)
input_image.clear(
    fn=lambda: (gr.update(interactive=False), gr.update(interactive=False)),
    inputs=[],
    outputs=_pipeline_btns,
)

# Refresh the background-removal preview whenever the image or any of its
# settings changes. Registration order is kept identical to the original
# (upload, checkbox change, threshold release, erode release).
for _register_preview in (
    input_image.upload,
    remove_bg_check.change,
    rembg_threshold.release,
    rembg_erode.release,
):
    _register_preview(
        fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview]
    )

# Shape generation: store the GLB path in state, then mirror it into the
# 3D viewer and the download slot.
shape_btn.click(
    fn=generate_shape,
    inputs=[
        input_image,
        remove_bg_check,
        num_steps,
        guidance,
        seed,
        face_count,
    ],
    outputs=[glb_state, status],
).then(
    fn=lambda p: (p, p) if p else (None, None),
    inputs=[glb_state],
    outputs=[model_3d, download_file],
)

# Texturing: same follow-up step as above, plus the multiview image output.
texture_btn.click(
    fn=apply_texture,
    inputs=[
        glb_state,
        input_image,
        remove_bg_check,
        variant,
        tex_seed,
        enhance_face_check,
        rembg_threshold,
        rembg_erode,
    ],
    outputs=[glb_state, multiview_img, status],
).then(
    fn=lambda p: (p, p) if p else (None, None),
    inputs=[glb_state],
    outputs=[model_3d, download_file],
)

# Render views from whatever file is currently in the download slot.
render_btn.click(
    fn=render_views, inputs=[download_file], outputs=[render_gallery]
)
1843
+
1844
+ # ── Edit tab wiring (after Generate so all components are defined) ──
1845
# Edit tab: generate a FireRed preview. The handler also writes back to
# `firered_seed`, which is listed in both inputs and outputs.
firered_btn.click(
    fn=firered_generate,
    inputs=[
        firered_gallery,
        firered_prompt,
        firered_seed,
        firered_rand,
        firered_guidance,
        firered_steps,
    ],
    outputs=[firered_output_img, firered_seed, firered_status],
    api_name="firered_generate",
)

# Copy the FireRed result into the Generate tab, then enable the pipeline
# buttons (only if an image actually arrived) and switch to tab id 1.
load_to_generate_btn.click(
    fn=firered_load_into_pipeline,
    inputs=[firered_output_img, rembg_threshold, rembg_erode],
    outputs=[input_image, rembg_preview, firered_status],
).then(
    fn=lambda img: (
        gr.update(interactive=img is not None),
        gr.update(interactive=img is not None),
        gr.update(selected=1),
    ),
    inputs=[input_image],
    outputs=[shape_btn, run_all_btn, tabs],
)
1872
+
1873
+ # ════════════════════════════════════════════════════════════════════
1874
+ with gr.Tab("Rig & Export"):
1875
+ with gr.Row():
1876
+ # ── Left column: controls ──────────────────────────────────
1877
+ with gr.Column(scale=1):
1878
+ gr.Markdown("### UniRig + PSHuman — Rig & HD Face")
1879
+ gr.Markdown(
1880
+ "One click runs the full pipeline:\n"
1881
+ "1. **UniRig** skeletonises + skins the mesh\n"
1882
+ "2. **PSHuman** generates an HD face from your portrait (RMBG → multi-view diffusion)\n"
1883
+ "3. **Face transplant** stitches the HD face into the rigged mesh using bone weights + KNN\n\n"
1884
+ "Portrait is pulled automatically from the Generate tab."
1885
+ )
1886
+ export_fbx_check = gr.Checkbox(label="Export FBX", value=True)
1887
+ with gr.Accordion("PSHuman settings", open=False):
1888
+ pshuman_weight_thresh = gr.Slider(
1889
+ minimum=0.1,
1890
+ maximum=0.9,
1891
+ value=0.35,
1892
+ step=0.05,
1893
+ label="Head bone weight threshold",
1894
+ info="Vertices with head-bone weight above this get replaced",
1895
+ )
1896
+ pshuman_retract_mm = gr.Slider(
1897
+ minimum=0.0,
1898
+ maximum=20.0,
1899
+ value=4.0,
1900
+ step=0.5,
1901
+ label="Face retract (mm)",
1902
+ info="How far to push original face verts inward to avoid z-fighting",
1903
+ )
1904
+ rig_btn = gr.Button("Rig with UniRig", variant="primary")
1905
+
1906
+ # ── Right column: preview + downloads ─────────────────────
1907
+ with gr.Column(scale=2):
1908
+ rig_status = gr.Textbox(label="Status", lines=4, interactive=False)
1909
+ rig_model_3d = gr.Model3D(
1910
+ label="Preview", clear_color=[0.9, 0.9, 0.9, 1.0]
1911
+ )
1912
+ with gr.Row():
1913
+ rig_glb_dl = gr.File(label="Download Rigged GLB")
1914
+ rig_fbx_dl = gr.File(label="Download FBX")
1915
+
1916
+ rig_btn.click(
1917
+ fn=gradio_rig,
1918
+ inputs=[
1919
+ input_image,
1920
+ glb_state,
1921
+ export_fbx_check,
1922
+ pshuman_weight_thresh,
1923
+ pshuman_retract_mm,
1924
+ ],
1925
+ outputs=[
1926
+ rig_glb_dl,
1927
+ gr.State(None),
1928
+ rig_fbx_dl,
1929
+ rig_status,
1930
+ rig_model_3d,
1931
+ rigged_glb_state,
1932
+ gr.State(None),
1933
+ ],
1934
+ )
1935
+
1936
+ # ════════════════════════════════════════════════════════════════════
1937
+ with gr.Tab("Enhancement"):
1938
+ gr.Markdown("""
1939
+ **Surface Enhancement** — runs on the reference portrait to produce
1940
+ calibrated normal + depth maps that are baked into the GLB as PBR textures.
1941
+ """)
1942
+ with gr.Row():
1943
+ with gr.Column(scale=1):
1944
+ gr.Markdown("### StableNormal")
1945
+ run_normal_check = gr.Checkbox(label="Run StableNormal", value=True)
1946
+ normal_res = gr.Slider(
1947
+ 512, 1024, value=768, step=128, label="Resolution"
1948
+ )
1949
+ normal_strength = gr.Slider(
1950
+ 0.1, 3.0, value=1.0, step=0.1, label="Normal Strength"
1951
+ )
1952
+
1953
+ gr.Markdown("### Depth-Anything V2")
1954
+ run_depth_check = gr.Checkbox(
1955
+ label="Run Depth-Anything V2", value=True
1956
+ )
1957
+ depth_res = gr.Slider(
1958
+ 512, 1024, value=768, step=128, label="Resolution"
1959
+ )
1960
+ displacement_scale = gr.Slider(
1961
+ 0.1, 3.0, value=1.0, step=0.1, label="Displacement Scale"
1962
+ )
1963
+
1964
+ enhance_btn = gr.Button("Run Enhancement", variant="primary")
1965
+ unload_btn = gr.Button(
1966
+ "Unload Models (free VRAM)", variant="secondary"
1967
+ )
1968
+
1969
+ with gr.Column(scale=2):
1970
+ enhance_status = gr.Textbox(
1971
+ label="Status", lines=5, interactive=False
1972
+ )
1973
+ with gr.Row():
1974
+ normal_map_img = gr.Image(label="Normal Map", type="pil")
1975
+ depth_map_img = gr.Image(label="Depth Map", type="pil")
1976
+ enhanced_glb_dl = gr.File(label="Download Enhanced GLB")
1977
+ enhanced_model_3d = gr.Model3D(
1978
+ label="Enhanced Preview", clear_color=[0.9, 0.9, 0.9, 1.0]
1979
+ )
1980
+
1981
def gradio_enhance(
    glb_path,
    ref_img_np,
    do_normal,
    norm_res,
    norm_strength,
    do_depth,
    dep_res,
    disp_scale,
):
    """Run the optional normal/depth enhancement passes and stream progress.

    This is a generator. Every update is yielded as a 5-tuple
    ``(normal_map, depth_preview, glb_file, glb_model, log_text)``; the two
    GLB slots stay ``None`` until the final update, where both carry the
    path of the enhanced file.

    Args:
        glb_path: Path of the GLB to enhance; falsy means nothing was generated.
        ref_img_np: Reference image array (cast to ``uint8`` before use), or None.
        do_normal: Whether to call ``run_stable_normal`` and bake its result.
        norm_res: Resolution passed to ``run_stable_normal``.
        norm_strength: ``normal_strength`` passed to ``bake_normal_into_glb``.
        do_depth: Whether to call ``run_depth_anything`` and bake its result.
        dep_res: Resolution passed to ``run_depth_anything``.
        disp_scale: ``displacement_scale`` passed to ``bake_depth_as_occlusion``.

    Yields:
        Progress tuples as described above. Any exception raised while
        processing is reported as a final tuple whose text is the traceback.
    """
    # The function body contains ``yield``, so it is a generator and a plain
    # ``return <value>`` would end the iteration without delivering the value.
    # Guard messages therefore have to be yielded before returning.
    if not glb_path:
        yield None, None, None, None, "No GLB loaded — run Generate first."
        return
    if ref_img_np is None:
        yield None, None, None, None, "No reference image — run Generate first."
        return

    try:
        import os
        import shutil

        ref_pil = Image.fromarray(ref_img_np.astype(np.uint8))

        # Work on a copy so the input GLB is left untouched. Only the file
        # name gets the suffix: splitting off the extension avoids rewriting
        # a ".glb" that happens to appear in a directory name, and avoids
        # copying a file onto itself when the extension is not ".glb".
        stem, ext = os.path.splitext(glb_path)
        out_path = f"{stem}_enhanced{ext}"
        shutil.copy2(glb_path, out_path)

        normal_out = None
        depth_preview = None
        log = []

        if do_normal:
            log.append("[StableNormal] Running...")
            yield None, None, None, None, "\n".join(log)
            normal_out = run_stable_normal(ref_pil, resolution=norm_res)
            out_path = bake_normal_into_glb(
                out_path,
                normal_out,
                out_path,
                normal_strength=norm_strength,
            )
            log.append(
                f"[StableNormal] Done → baked normalTexture (strength {norm_strength})"
            )
            yield normal_out, None, None, None, "\n".join(log)

        if do_depth:
            log.append("[Depth-Anything] Running...")
            yield normal_out, None, None, None, "\n".join(log)
            depth_out = run_depth_anything(ref_pil, resolution=dep_res)
            out_path = bake_depth_as_occlusion(
                out_path, depth_out, out_path, displacement_scale=disp_scale
            )
            # The preview shown to the user is a grayscale rendering of the depth map.
            depth_preview = depth_out.convert("L").convert("RGB")
            log.append(
                f"[Depth-Anything] Done → baked occlusionTexture (scale {disp_scale})"
            )
            yield normal_out, depth_preview, None, None, "\n".join(log)

        log.append("Enhancement complete.")
        yield normal_out, depth_preview, out_path, out_path, "\n".join(log)

    except Exception:
        yield None, None, None, None, f"Error:\n{traceback.format_exc()}"
2051
+
2052
# Stream enhancement progress: the five outputs line up, in order, with the
# 5-tuples yielded by `gradio_enhance`.
enhance_btn.click(
    fn=gradio_enhance,
    inputs=[
        glb_state,
        input_image,
        run_normal_check,
        normal_res,
        normal_strength,
        run_depth_check,
        depth_res,
        displacement_scale,
    ],
    outputs=[
        normal_map_img,
        depth_map_img,
        enhanced_glb_dl,
        enhanced_model_3d,
        enhance_status,
    ],
)

# Call `unload_models()` for its side effect, then show a fixed confirmation
# message: the tuple is built only to sequence the call before the string.
unload_btn.click(
    fn=lambda: (unload_models(), "Models unloaded — VRAM freed.")[1],
    inputs=[],
    outputs=[enhance_status],
)
2078
+
2079
+ # ════════════════════════════════════════════════════════════════════
2080
+ with gr.Tab("Settings"):
2081
+
2082
def get_vram_status():
    """Return a newline-joined report of GPU memory use and loaded models.

    The first section lists the name and memory figures of CUDA device 0 (or
    a single notice when CUDA is unavailable). The second section shows, for
    each model handle this app keeps, whether it is currently non-None.
    """

    def _state(handle):
        # One shared definition for the loaded / not-loaded marker text.
        return "✓ loaded" if handle is not None else "○ not loaded"

    gib = 1024**3
    lines = []

    if torch.cuda.is_available():
        alloc = torch.cuda.memory_allocated() / gib
        reserv = torch.cuda.memory_reserved() / gib
        total = torch.cuda.get_device_properties(0).total_memory / gib
        # "Free" is reported as device total minus what torch has reserved.
        free = total - reserv
        lines.extend(
            [
                f"GPU: {torch.cuda.get_device_name(0)}",
                f"VRAM total: {total:.1f} GB",
                f"VRAM allocated: {alloc:.1f} GB",
                f"VRAM reserved: {reserv:.1f} GB",
                f"VRAM free: {free:.1f} GB",
            ]
        )
    else:
        lines.append("No CUDA device available.")

    lines.extend(
        [
            "",
            "Loaded models:",
            f" TripoSG pipeline: {_state(_triposg_pipe)}",
            f" RMBG-{_rmbg_version or '?'}: {_state(_rmbg_net)}",
            f" StableNormal: {_state(_enh_mod._normal_pipe)}",
            f" Depth-Anything: {_state(_enh_mod._depth_pipe)}",
        ]
    )
    return "\n".join(lines)
2111
+
2112
def preload_triposg():
    """Call ``load_triposg()`` and return the refreshed status report.

    Returns:
        The text from ``get_vram_status()`` on success, or a message starting
        with ``"Preload failed:"`` followed by the traceback when either call
        raises.
    """
    try:
        load_triposg()
        return get_vram_status()
    except Exception:
        # The traceback text is the user-facing error; no need to bind the exception.
        return f"Preload failed:\n{traceback.format_exc()}"
2118
+
2119
def unload_triposg():
    """Drop the TripoSG pipeline and RMBG network, then return the status report.

    Each handle that is not None is moved to the CPU and its module-level
    name is reset to None, all while holding ``_model_load_lock``. The CUDA
    cache is emptied afterwards.
    """
    global _triposg_pipe, _rmbg_net
    with _model_load_lock:
        if _triposg_pipe is not None:
            _triposg_pipe.to("cpu")
        # Rebinding the global to None releases the reference; a separate
        # `del` before the assignment adds nothing.
        _triposg_pipe = None
        if _rmbg_net is not None:
            _rmbg_net.to("cpu")
        _rmbg_net = None
    torch.cuda.empty_cache()
    return get_vram_status()
2132
+
2133
def unload_enhancement():
    """Call ``unload_models()`` and return the refreshed ``get_vram_status()`` text."""
    unload_models()
    return get_vram_status()
2136
+
2137
def unload_all():
    """Unload every model this app manages and return the status report.

    Runs ``unload_triposg()`` (its own status text is discarded) followed by
    ``unload_models()``, then reports the state after both have finished.
    """
    unload_triposg()
    unload_models()
    return get_vram_status()
2141
+
2142
+ with gr.Row():
2143
+ with gr.Column(scale=1):
2144
+ gr.Markdown("### VRAM Management")
2145
+ preload_btn = gr.Button(
2146
+ "Preload TripoSG + RMBG to VRAM", variant="primary"
2147
+ )
2148
+ unload_triposg_btn = gr.Button("Unload TripoSG / RMBG")
2149
+ unload_enh_btn = gr.Button(
2150
+ "Unload Enhancement Models (StableNormal / Depth)"
2151
+ )
2152
+ unload_all_btn = gr.Button("Unload All Models", variant="stop")
2153
+ refresh_btn = gr.Button("Refresh Status")
2154
+
2155
+ with gr.Column(scale=1):
2156
+ gr.Markdown("### GPU Status")
2157
+ vram_status = gr.Textbox(
2158
+ label="",
2159
+ lines=12,
2160
+ interactive=False,
2161
+ value="Click Refresh to check VRAM status.",
2162
+ )
2163
+
2164
+ preload_btn.click(fn=preload_triposg, inputs=[], outputs=[vram_status])
2165
+ unload_triposg_btn.click(
2166
+ fn=unload_triposg, inputs=[], outputs=[vram_status]
2167
+ )
2168
+ unload_enh_btn.click(
2169
+ fn=unload_enhancement, inputs=[], outputs=[vram_status]
2170
+ )
2171
+ unload_all_btn.click(fn=unload_all, inputs=[], outputs=[vram_status])
2172
+ refresh_btn.click(fn=get_vram_status, inputs=[], outputs=[vram_status])
2173
+
2174
+ # ── run_all wiring (after Rig tab so all components are defined) ──
2175
# Full pipeline in one click. The inputs combine the shape, texture and
# background-removal settings with the FBX export flag from the Rig tab.
run_all_btn.click(
    fn=run_full_pipeline,
    inputs=[
        input_image,
        remove_bg_check,
        num_steps,
        guidance,
        seed,
        face_count,
        variant,
        tex_seed,
        enhance_face_check,
        rembg_threshold,
        rembg_erode,
        export_fbx_check,
    ],
    outputs=[
        glb_state,
        download_file,
        multiview_img,
        rig_glb_dl,
        rig_fbx_dl,
        status,
    ],
).then(
    # Mirror the stored GLB path into the viewer and the download slot;
    # both are cleared when no path was produced.
    fn=lambda p: (p, p) if p else (None, None),
    inputs=[glb_state],
    outputs=[model_3d, download_file],
)
2204
+
2205
+ # ── Hidden API endpoints — use invisible Gallery (State is stripped from API in Gradio 6) ──
2206
# Invisible galleries used only as output targets for the two API routes below.
_api_render_gallery = gr.Gallery(visible=False)
_api_swap_gallery = gr.Gallery(visible=False)


def _render_last():
    """Render views for ``_last_glb_path``, or for the default textured GLB.

    When ``_last_glb_path`` is falsy, the path falls back to
    ``TMP_DIR / "triposg_textured.glb"``.
    """
    path = _last_glb_path or str(TMP_DIR / "triposg_textured.glb")
    return render_views(path)


# Invisible textbox carrying the single input of the `hyperswap_views` route.
_hs_emb_input = gr.Textbox(visible=False)

# Invisible buttons exist only so each handler gets a named API route.
gr.Button(visible=False).click(
    fn=_render_last,
    inputs=[],
    outputs=[_api_render_gallery],
    api_name="render_last",
)
gr.Button(visible=False).click(
    fn=hyperswap_views,
    inputs=[_hs_emb_input],
    outputs=[_api_swap_gallery],
    api_name="hyperswap_views",
)
2227
+
2228
+
2229
if __name__ == "__main__":
    # Launch the Gradio app only when this file is executed as a script.
    demo.launch(
        server_name="0.0.0.0",  # bind on all network interfaces
        server_port=7860,
        share=True,
        show_error=True,
        allowed_paths=["/tmp"],
        max_threads=4,
        max_file_size="50mb",
    )