Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,12 +1,18 @@
|
|
| 1 |
import os
|
| 2 |
-
import
|
|
|
|
|
|
|
|
|
|
| 3 |
import cv2
|
|
|
|
| 4 |
import kiui
|
| 5 |
-
import
|
| 6 |
-
import torch
|
| 7 |
import rembg
|
| 8 |
-
|
| 9 |
-
import
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
try:
|
| 12 |
import spaces
|
|
@@ -19,61 +25,70 @@ except ImportError:
|
|
| 19 |
def __call__(self, func):
|
| 20 |
return func
|
| 21 |
|
| 22 |
-
|
| 23 |
from flow.model import Model
|
| 24 |
from flow.configs.schema import ModelConfig
|
| 25 |
from flow.utils import get_random_color, recenter_foreground
|
| 26 |
from vae.utils import postprocess_mesh
|
| 27 |
-
from huggingface_hub import hf_hub_download
|
| 28 |
|
| 29 |
-
# =========================
|
| 30 |
-
# CPU 基础设置
|
| 31 |
-
# =========================
|
| 32 |
DEVICE = torch.device("cpu")
|
| 33 |
DTYPE = torch.float32
|
| 34 |
|
|
|
|
| 35 |
CPU_THREADS = int(os.environ.get("CPU_THREADS", "2"))
|
| 36 |
torch.set_num_threads(CPU_THREADS)
|
| 37 |
torch.set_num_interop_threads(max(1, min(2, CPU_THREADS)))
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
TRIMESH_GLB_EXPORT = np.array(
|
| 40 |
[[0, 1, 0], [0, 0, 1], [1, 0, 0]],
|
| 41 |
dtype=np.float32
|
| 42 |
)
|
|
|
|
| 43 |
MAX_SEED = np.iinfo(np.int32).max
|
| 44 |
bg_remover = rembg.new_session()
|
| 45 |
|
| 46 |
-
# =========================
|
| 47 |
-
#
|
| 48 |
-
# =========================
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
# =========================
|
| 55 |
-
model_config = ModelConfig(
|
| 56 |
-
vae_conf="vae.configs.part_woenc",
|
| 57 |
-
vae_ckpt_path=vae_ckpt_path,
|
| 58 |
-
qknorm=True,
|
| 59 |
-
qknorm_type="RMSNorm",
|
| 60 |
-
use_pos_embed=False,
|
| 61 |
-
dino_model="dinov2_vitg14",
|
| 62 |
-
hidden_dim=1536,
|
| 63 |
-
flow_shift=3.0,
|
| 64 |
-
logitnorm_mean=1.0,
|
| 65 |
-
logitnorm_std=1.0,
|
| 66 |
-
latent_size=4096,
|
| 67 |
-
use_parts=True,
|
| 68 |
-
)
|
| 69 |
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
# 工具函数:强制整个模块转 float32
|
| 72 |
-
# =========================
|
| 73 |
def force_module_fp32(module: torch.nn.Module):
|
| 74 |
"""
|
| 75 |
-
递归把模块参数和 buffer
|
| 76 |
-
这一步是解决 CPU 下 bfloat16/float32 混用问题的关键。
|
| 77 |
"""
|
| 78 |
module.to(device=DEVICE)
|
| 79 |
module.float()
|
|
@@ -81,31 +96,233 @@ def force_module_fp32(module: torch.nn.Module):
|
|
| 81 |
for child in module.children():
|
| 82 |
force_module_fp32(child)
|
| 83 |
|
|
|
|
| 84 |
for name, buf in module.named_buffers(recurse=False):
|
| 85 |
-
if torch.
|
| 86 |
setattr(module, name, buf.to(device=DEVICE, dtype=torch.float32))
|
| 87 |
|
| 88 |
return module
|
| 89 |
|
| 90 |
|
| 91 |
-
# =========================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
# 初始化模型(CPU + float32)
|
| 93 |
-
# =========================
|
| 94 |
print("正在加载模型到 CPU ...")
|
|
|
|
|
|
|
|
|
|
| 95 |
model = Model(model_config)
|
| 96 |
model.eval()
|
| 97 |
model.to(DEVICE)
|
| 98 |
|
| 99 |
# 显式按 CPU 加载权重
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
model.load_state_dict(ckpt_dict, strict=True)
|
| 102 |
|
| 103 |
-
#
|
| 104 |
force_module_fp32(model)
|
| 105 |
model.eval()
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
print("模型加载完成。")
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
def get_random_seed(randomize_seed, seed):
|
|
@@ -165,7 +382,9 @@ def process_3d(
|
|
| 165 |
os.makedirs("output", exist_ok=True)
|
| 166 |
output_glb_path = f"output/partpacker_{datetime.now().strftime('%Y%m%d_%H%M%S')}.glb"
|
| 167 |
|
| 168 |
-
#
|
|
|
|
|
|
|
| 169 |
image = input_image.astype(np.float32) / 255.0
|
| 170 |
image = image[..., :3] * image[..., 3:4] + (1.0 - image[..., 3:4])
|
| 171 |
|
|
@@ -178,56 +397,83 @@ def process_3d(
|
|
| 178 |
)
|
| 179 |
|
| 180 |
data = {
|
| 181 |
-
"cond_images": image_tensor
|
| 182 |
}
|
|
|
|
| 183 |
|
| 184 |
-
#
|
|
|
|
|
|
|
| 185 |
force_module_fp32(model)
|
| 186 |
model.eval()
|
|
|
|
|
|
|
|
|
|
| 187 |
|
|
|
|
|
|
|
|
|
|
| 188 |
with torch.inference_mode():
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
| 194 |
|
| 195 |
-
|
| 196 |
|
| 197 |
-
|
| 198 |
-
if isinstance(latent, torch.Tensor):
|
| 199 |
-
latent = latent.to(device=DEVICE, dtype=torch.float32).contiguous()
|
| 200 |
-
else:
|
| 201 |
raise gr.Error("模型输出 latent 异常。")
|
| 202 |
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
data_part0 = {
|
| 205 |
-
"latent": latent[:, : model.config.latent_size, :].
|
| 206 |
}
|
| 207 |
data_part1 = {
|
| 208 |
-
"latent": latent[:, model.config.latent_size:, :].
|
| 209 |
}
|
| 210 |
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
model.vae.eval()
|
| 214 |
|
| 215 |
with torch.inference_mode():
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
if not simplify_mesh:
|
| 220 |
target_num_faces = -1
|
| 221 |
|
| 222 |
parts = []
|
| 223 |
|
|
|
|
|
|
|
|
|
|
| 224 |
vertices, faces = results_part0["meshes"][0]
|
|
|
|
|
|
|
|
|
|
| 225 |
mesh_part0 = trimesh.Trimesh(vertices, faces, process=False)
|
| 226 |
mesh_part0.vertices = mesh_part0.vertices @ TRIMESH_GLB_EXPORT.T
|
| 227 |
mesh_part0 = postprocess_mesh(mesh_part0, int(target_num_faces))
|
| 228 |
parts.extend(mesh_part0.split(only_watertight=False))
|
| 229 |
|
|
|
|
|
|
|
|
|
|
| 230 |
vertices, faces = results_part1["meshes"][0]
|
|
|
|
|
|
|
|
|
|
| 231 |
mesh_part1 = trimesh.Trimesh(vertices, faces, process=False)
|
| 232 |
mesh_part1.vertices = mesh_part1.vertices @ TRIMESH_GLB_EXPORT.T
|
| 233 |
mesh_part1 = postprocess_mesh(mesh_part1, int(target_num_faces))
|
|
@@ -244,6 +490,8 @@ def process_3d(
|
|
| 244 |
|
| 245 |
return output_glb_path
|
| 246 |
|
|
|
|
|
|
|
| 247 |
except Exception as e:
|
| 248 |
raise gr.Error(
|
| 249 |
"CPU 生成失败:"
|
|
@@ -282,8 +530,16 @@ with block:
|
|
| 282 |
|
| 283 |
with gr.Row():
|
| 284 |
with gr.Column():
|
| 285 |
-
input_image = gr.Image(
|
| 286 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
|
| 288 |
with gr.Accordion("高级设置", open=False):
|
| 289 |
num_steps = gr.Slider(
|
|
@@ -307,9 +563,16 @@ with block:
|
|
| 307 |
step=1,
|
| 308 |
value=128
|
| 309 |
)
|
|
|
|
| 310 |
with gr.Row():
|
| 311 |
randomize_seed = gr.Checkbox(label="随机种子", value=True)
|
| 312 |
-
seed = gr.Slider(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
with gr.Row():
|
| 315 |
simplify_mesh = gr.Checkbox(label="简化网格", value=True)
|
|
@@ -349,8 +612,19 @@ with block:
|
|
| 349 |
outputs=[seed]
|
| 350 |
).then(
|
| 351 |
fn=process_3d,
|
| 352 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
outputs=[output_model]
|
| 354 |
)
|
| 355 |
|
| 356 |
-
block.launch(
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
import contextlib
|
| 3 |
+
import functools
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
import cv2
|
| 7 |
+
import gradio as gr
|
| 8 |
import kiui
|
| 9 |
+
import numpy as np
|
|
|
|
| 10 |
import rembg
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import trimesh
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
|
| 17 |
try:
|
| 18 |
import spaces
|
|
|
|
| 25 |
def __call__(self, func):
|
| 26 |
return func
|
| 27 |
|
|
|
|
| 28 |
from flow.model import Model
|
| 29 |
from flow.configs.schema import ModelConfig
|
| 30 |
from flow.utils import get_random_color, recenter_foreground
|
| 31 |
from vae.utils import postprocess_mesh
|
|
|
|
| 32 |
|
| 33 |
+
# =========================================================
|
| 34 |
+
# CPU / dtype 基础设置
|
| 35 |
+
# =========================================================
|
| 36 |
DEVICE = torch.device("cpu")
|
| 37 |
DTYPE = torch.float32
|
| 38 |
|
| 39 |
+
# 线程数可按 HF CPU Space 机器情况调整
|
| 40 |
CPU_THREADS = int(os.environ.get("CPU_THREADS", "2"))
|
| 41 |
torch.set_num_threads(CPU_THREADS)
|
| 42 |
torch.set_num_interop_threads(max(1, min(2, CPU_THREADS)))
|
| 43 |
|
| 44 |
+
# 显式设默认浮点 dtype 为 float32
|
| 45 |
+
torch.set_default_dtype(torch.float32)
|
| 46 |
+
|
| 47 |
+
# 对 CPU 推理更稳妥
|
| 48 |
+
try:
|
| 49 |
+
torch.set_grad_enabled(False)
|
| 50 |
+
except Exception:
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
TRIMESH_GLB_EXPORT = np.array(
|
| 54 |
[[0, 1, 0], [0, 0, 1], [1, 0, 0]],
|
| 55 |
dtype=np.float32
|
| 56 |
)
|
| 57 |
+
|
| 58 |
MAX_SEED = np.iinfo(np.int32).max
|
| 59 |
bg_remover = rembg.new_session()
|
| 60 |
|
| 61 |
+
# =========================================================
|
| 62 |
+
# 工具函数:递归转换任意对象中的浮点 Tensor 为 float32
|
| 63 |
+
# =========================================================
|
| 64 |
+
def to_cpu_fp32(obj):
|
| 65 |
+
"""
|
| 66 |
+
递归把对象中的浮点 Tensor 转成 CPU + float32。
|
| 67 |
+
支持 Tensor / dict / list / tuple。
|
| 68 |
+
"""
|
| 69 |
+
if torch.is_tensor(obj):
|
| 70 |
+
if obj.is_floating_point():
|
| 71 |
+
return obj.to(device=DEVICE, dtype=torch.float32, non_blocking=False)
|
| 72 |
+
return obj.to(device=DEVICE, non_blocking=False)
|
| 73 |
|
| 74 |
+
if isinstance(obj, dict):
|
| 75 |
+
return {k: to_cpu_fp32(v) for k, v in obj.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
+
if isinstance(obj, list):
|
| 78 |
+
return [to_cpu_fp32(v) for v in obj]
|
| 79 |
+
|
| 80 |
+
if isinstance(obj, tuple):
|
| 81 |
+
return tuple(to_cpu_fp32(v) for v in obj)
|
| 82 |
+
|
| 83 |
+
return obj
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# =========================================================
|
| 87 |
# 工具函数:强制整个模块转 float32
|
| 88 |
+
# =========================================================
|
| 89 |
def force_module_fp32(module: torch.nn.Module):
|
| 90 |
"""
|
| 91 |
+
递归把模块参数和 buffer 都转到 CPU + float32。
|
|
|
|
| 92 |
"""
|
| 93 |
module.to(device=DEVICE)
|
| 94 |
module.float()
|
|
|
|
| 96 |
for child in module.children():
|
| 97 |
force_module_fp32(child)
|
| 98 |
|
| 99 |
+
# 处理 buffer
|
| 100 |
for name, buf in module.named_buffers(recurse=False):
|
| 101 |
+
if torch.is_tensor(buf) and buf.is_floating_point():
|
| 102 |
setattr(module, name, buf.to(device=DEVICE, dtype=torch.float32))
|
| 103 |
|
| 104 |
return module
|
| 105 |
|
| 106 |
|
| 107 |
+
# =========================================================
|
| 108 |
+
# 工具函数:禁用 CPU autocast
|
| 109 |
+
# =========================================================
|
| 110 |
+
@contextlib.contextmanager
|
| 111 |
+
def disable_cpu_autocast():
|
| 112 |
+
"""
|
| 113 |
+
显式关闭 CPU autocast,防止内部偷偷切到 bfloat16。
|
| 114 |
+
"""
|
| 115 |
+
try:
|
| 116 |
+
with torch.autocast(device_type="cpu", enabled=False):
|
| 117 |
+
yield
|
| 118 |
+
except Exception:
|
| 119 |
+
# 某些环境/版本可能不支持该写法,直接退化为普通上下文
|
| 120 |
+
yield
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# =========================================================
|
| 124 |
+
# 兜底补丁 1:全局修补 F.linear
|
| 125 |
+
# =========================================================
|
| 126 |
+
def patch_functional_linear():
|
| 127 |
+
"""
|
| 128 |
+
给 torch.nn.functional.linear 打补丁:
|
| 129 |
+
如果 input 和 weight dtype 不一致,自动把 input 转成 weight.dtype。
|
| 130 |
+
这是最后一道保险。
|
| 131 |
+
"""
|
| 132 |
+
if getattr(F.linear, "_fp32_safe_patched", False):
|
| 133 |
+
return
|
| 134 |
+
|
| 135 |
+
original_linear = F.linear
|
| 136 |
+
|
| 137 |
+
@functools.wraps(original_linear)
|
| 138 |
+
def linear_fp32_safe(input, weight, bias=None):
|
| 139 |
+
if (
|
| 140 |
+
torch.is_tensor(input)
|
| 141 |
+
and torch.is_tensor(weight)
|
| 142 |
+
and input.device.type == "cpu"
|
| 143 |
+
and input.is_floating_point()
|
| 144 |
+
and weight.is_floating_point()
|
| 145 |
+
and input.dtype != weight.dtype
|
| 146 |
+
):
|
| 147 |
+
input = input.to(dtype=weight.dtype)
|
| 148 |
+
|
| 149 |
+
if (
|
| 150 |
+
bias is not None
|
| 151 |
+
and torch.is_tensor(bias)
|
| 152 |
+
and bias.is_floating_point()
|
| 153 |
+
and torch.is_tensor(weight)
|
| 154 |
+
and weight.is_floating_point()
|
| 155 |
+
and bias.dtype != weight.dtype
|
| 156 |
+
):
|
| 157 |
+
bias = bias.to(dtype=weight.dtype)
|
| 158 |
+
|
| 159 |
+
return original_linear(input, weight, bias)
|
| 160 |
+
|
| 161 |
+
linear_fp32_safe._fp32_safe_patched = True
|
| 162 |
+
F.linear = linear_fp32_safe
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# =========================================================
|
| 166 |
+
# 兜底补丁 2:给常见模块加 forward pre-hook
|
| 167 |
+
# =========================================================
|
| 168 |
+
def register_dtype_guard_hooks(root_module: nn.Module):
|
| 169 |
+
"""
|
| 170 |
+
给常见算子模块注册前置 hook,在 forward 入口把输入对齐到参数 dtype。
|
| 171 |
+
"""
|
| 172 |
+
hooks = []
|
| 173 |
+
|
| 174 |
+
guarded_types = (
|
| 175 |
+
nn.Linear,
|
| 176 |
+
nn.Conv1d,
|
| 177 |
+
nn.Conv2d,
|
| 178 |
+
nn.Conv3d,
|
| 179 |
+
nn.LayerNorm,
|
| 180 |
+
nn.GroupNorm,
|
| 181 |
+
nn.BatchNorm1d,
|
| 182 |
+
nn.BatchNorm2d,
|
| 183 |
+
nn.BatchNorm3d,
|
| 184 |
+
nn.MultiheadAttention,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
def cast_obj_to_dtype(obj, dtype, device):
|
| 188 |
+
if torch.is_tensor(obj):
|
| 189 |
+
if obj.is_floating_point():
|
| 190 |
+
return obj.to(device=device, dtype=dtype)
|
| 191 |
+
return obj.to(device=device)
|
| 192 |
+
if isinstance(obj, dict):
|
| 193 |
+
return {k: cast_obj_to_dtype(v, dtype, device) for k, v in obj.items()}
|
| 194 |
+
if isinstance(obj, list):
|
| 195 |
+
return [cast_obj_to_dtype(v, dtype, device) for v in obj]
|
| 196 |
+
if isinstance(obj, tuple):
|
| 197 |
+
return tuple(cast_obj_to_dtype(v, dtype, device) for v in obj)
|
| 198 |
+
return obj
|
| 199 |
+
|
| 200 |
+
def pre_hook(module, inputs):
|
| 201 |
+
ref_tensor = None
|
| 202 |
+
|
| 203 |
+
# 先从参数里找参考 dtype
|
| 204 |
+
for p in module.parameters(recurse=False):
|
| 205 |
+
if torch.is_tensor(p) and p.is_floating_point():
|
| 206 |
+
ref_tensor = p
|
| 207 |
+
break
|
| 208 |
+
|
| 209 |
+
# 参数没有,再从 buffer 里找
|
| 210 |
+
if ref_tensor is None:
|
| 211 |
+
for b in module.buffers(recurse=False):
|
| 212 |
+
if torch.is_tensor(b) and b.is_floating_point():
|
| 213 |
+
ref_tensor = b
|
| 214 |
+
break
|
| 215 |
+
|
| 216 |
+
if ref_tensor is None:
|
| 217 |
+
return inputs
|
| 218 |
+
|
| 219 |
+
return cast_obj_to_dtype(inputs, ref_tensor.dtype, ref_tensor.device)
|
| 220 |
+
|
| 221 |
+
for submodule in root_module.modules():
|
| 222 |
+
if isinstance(submodule, guarded_types):
|
| 223 |
+
hooks.append(submodule.register_forward_pre_hook(pre_hook))
|
| 224 |
+
|
| 225 |
+
return hooks
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
# =========================================================
|
| 229 |
+
# 兜底补丁 3:包装 forward,统一禁用 autocast + 输入转 fp32
|
| 230 |
+
# =========================================================
|
| 231 |
+
def wrap_forward_fp32(module: nn.Module):
|
| 232 |
+
"""
|
| 233 |
+
包装模块的 forward:
|
| 234 |
+
1. 进入 forward 前先把输入递归转为 float32
|
| 235 |
+
2. forward 期间禁用 CPU autocast
|
| 236 |
+
"""
|
| 237 |
+
if getattr(module, "_forward_fp32_wrapped", False):
|
| 238 |
+
return
|
| 239 |
+
|
| 240 |
+
original_forward = module.forward
|
| 241 |
+
|
| 242 |
+
@functools.wraps(original_forward)
|
| 243 |
+
def forward_fp32_safe(*args, **kwargs):
|
| 244 |
+
args = to_cpu_fp32(args)
|
| 245 |
+
kwargs = to_cpu_fp32(kwargs)
|
| 246 |
+
with disable_cpu_autocast():
|
| 247 |
+
out = original_forward(*args, **kwargs)
|
| 248 |
+
return to_cpu_fp32(out)
|
| 249 |
+
|
| 250 |
+
module.forward = forward_fp32_safe
|
| 251 |
+
module._forward_fp32_wrapped = True
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# =========================================================
|
| 255 |
+
# 下载模型
|
| 256 |
+
# =========================================================
|
| 257 |
+
flow_ckpt_path = hf_hub_download(
|
| 258 |
+
repo_id="nvidia/PartPacker",
|
| 259 |
+
filename="flow.pt"
|
| 260 |
+
)
|
| 261 |
+
vae_ckpt_path = hf_hub_download(
|
| 262 |
+
repo_id="nvidia/PartPacker",
|
| 263 |
+
filename="vae.pt"
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# =========================================================
|
| 267 |
+
# 模型配置
|
| 268 |
+
# =========================================================
|
| 269 |
+
model_config = ModelConfig(
|
| 270 |
+
vae_conf="vae.configs.part_woenc",
|
| 271 |
+
vae_ckpt_path=vae_ckpt_path,
|
| 272 |
+
qknorm=True,
|
| 273 |
+
qknorm_type="RMSNorm",
|
| 274 |
+
use_pos_embed=False,
|
| 275 |
+
dino_model="dinov2_vitg14",
|
| 276 |
+
hidden_dim=1536,
|
| 277 |
+
flow_shift=3.0,
|
| 278 |
+
logitnorm_mean=1.0,
|
| 279 |
+
logitnorm_std=1.0,
|
| 280 |
+
latent_size=4096,
|
| 281 |
+
use_parts=True,
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
# =========================================================
|
| 285 |
# 初始化模型(CPU + float32)
|
| 286 |
+
# =========================================================
|
| 287 |
print("正在加载模型到 CPU ...")
|
| 288 |
+
|
| 289 |
+
patch_functional_linear()
|
| 290 |
+
|
| 291 |
model = Model(model_config)
|
| 292 |
model.eval()
|
| 293 |
model.to(DEVICE)
|
| 294 |
|
| 295 |
# 显式按 CPU 加载权重
|
| 296 |
+
# 某些环境下 weights_only=True 不兼容时,可退回普通 torch.load
|
| 297 |
+
try:
|
| 298 |
+
ckpt_dict = torch.load(flow_ckpt_path, map_location=DEVICE, weights_only=True)
|
| 299 |
+
except TypeError:
|
| 300 |
+
ckpt_dict = torch.load(flow_ckpt_path, map_location=DEVICE)
|
| 301 |
+
|
| 302 |
model.load_state_dict(ckpt_dict, strict=True)
|
| 303 |
|
| 304 |
+
# 强制全模型转 float32
|
| 305 |
force_module_fp32(model)
|
| 306 |
model.eval()
|
| 307 |
|
| 308 |
+
# 包装 forward,彻底关闭 CPU autocast
|
| 309 |
+
wrap_forward_fp32(model)
|
| 310 |
+
if hasattr(model, "dit"):
|
| 311 |
+
wrap_forward_fp32(model.dit)
|
| 312 |
+
if hasattr(model, "vae"):
|
| 313 |
+
wrap_forward_fp32(model.vae)
|
| 314 |
+
|
| 315 |
+
# 给模型注册 dtype 保护 hook
|
| 316 |
+
_DTYPE_GUARD_HOOKS = []
|
| 317 |
+
_DTYPE_GUARD_HOOKS.extend(register_dtype_guard_hooks(model))
|
| 318 |
+
if hasattr(model, "vae"):
|
| 319 |
+
_DTYPE_GUARD_HOOKS.extend(register_dtype_guard_hooks(model.vae))
|
| 320 |
+
|
| 321 |
print("模型加载完成。")
|
| 322 |
+
try:
|
| 323 |
+
print("主模型 dtype:", next(model.parameters()).dtype)
|
| 324 |
+
except StopIteration:
|
| 325 |
+
print("主模型没有可见参数。")
|
| 326 |
|
| 327 |
|
| 328 |
def get_random_seed(randomize_seed, seed):
|
|
|
|
| 382 |
os.makedirs("output", exist_ok=True)
|
| 383 |
output_glb_path = f"output/partpacker_{datetime.now().strftime('%Y%m%d_%H%M%S')}.glb"
|
| 384 |
|
| 385 |
+
# -------------------------------------------------
|
| 386 |
+
# 1) RGBA -> RGB 白底合成 -> float32
|
| 387 |
+
# -------------------------------------------------
|
| 388 |
image = input_image.astype(np.float32) / 255.0
|
| 389 |
image = image[..., :3] * image[..., 3:4] + (1.0 - image[..., 3:4])
|
| 390 |
|
|
|
|
| 397 |
)
|
| 398 |
|
| 399 |
data = {
|
| 400 |
+
"cond_images": image_tensor
|
| 401 |
}
|
| 402 |
+
data = to_cpu_fp32(data)
|
| 403 |
|
| 404 |
+
# -------------------------------------------------
|
| 405 |
+
# 2) 推理前再次强制模型为 float32
|
| 406 |
+
# -------------------------------------------------
|
| 407 |
force_module_fp32(model)
|
| 408 |
model.eval()
|
| 409 |
+
if hasattr(model, "vae"):
|
| 410 |
+
force_module_fp32(model.vae)
|
| 411 |
+
model.vae.eval()
|
| 412 |
|
| 413 |
+
# -------------------------------------------------
|
| 414 |
+
# 3) 主模型推理:显式禁用 CPU autocast
|
| 415 |
+
# -------------------------------------------------
|
| 416 |
with torch.inference_mode():
|
| 417 |
+
with disable_cpu_autocast():
|
| 418 |
+
results = model(
|
| 419 |
+
data,
|
| 420 |
+
num_steps=int(num_steps),
|
| 421 |
+
cfg_scale=float(cfg_scale)
|
| 422 |
+
)
|
| 423 |
|
| 424 |
+
results = to_cpu_fp32(results)
|
| 425 |
|
| 426 |
+
latent = results.get("latent", None)
|
| 427 |
+
if not isinstance(latent, torch.Tensor):
|
|
|
|
|
|
|
| 428 |
raise gr.Error("模型输出 latent 异常。")
|
| 429 |
|
| 430 |
+
latent = latent.to(device=DEVICE, dtype=torch.float32).contiguous()
|
| 431 |
+
|
| 432 |
+
# -------------------------------------------------
|
| 433 |
+
# 4) VAE 解码:再次显式禁用 CPU autocast
|
| 434 |
+
# -------------------------------------------------
|
| 435 |
data_part0 = {
|
| 436 |
+
"latent": latent[:, : model.config.latent_size, :].contiguous()
|
| 437 |
}
|
| 438 |
data_part1 = {
|
| 439 |
+
"latent": latent[:, model.config.latent_size:, :].contiguous()
|
| 440 |
}
|
| 441 |
|
| 442 |
+
data_part0 = to_cpu_fp32(data_part0)
|
| 443 |
+
data_part1 = to_cpu_fp32(data_part1)
|
|
|
|
| 444 |
|
| 445 |
with torch.inference_mode():
|
| 446 |
+
with disable_cpu_autocast():
|
| 447 |
+
results_part0 = model.vae(data_part0, resolution=int(grid_res))
|
| 448 |
+
results_part1 = model.vae(data_part1, resolution=int(grid_res))
|
| 449 |
+
|
| 450 |
+
results_part0 = to_cpu_fp32(results_part0)
|
| 451 |
+
results_part1 = to_cpu_fp32(results_part1)
|
| 452 |
|
| 453 |
if not simplify_mesh:
|
| 454 |
target_num_faces = -1
|
| 455 |
|
| 456 |
parts = []
|
| 457 |
|
| 458 |
+
# -------------------------------------------------
|
| 459 |
+
# 5) part 0 mesh
|
| 460 |
+
# -------------------------------------------------
|
| 461 |
vertices, faces = results_part0["meshes"][0]
|
| 462 |
+
vertices = np.asarray(vertices, dtype=np.float32)
|
| 463 |
+
faces = np.asarray(faces, dtype=np.int64)
|
| 464 |
+
|
| 465 |
mesh_part0 = trimesh.Trimesh(vertices, faces, process=False)
|
| 466 |
mesh_part0.vertices = mesh_part0.vertices @ TRIMESH_GLB_EXPORT.T
|
| 467 |
mesh_part0 = postprocess_mesh(mesh_part0, int(target_num_faces))
|
| 468 |
parts.extend(mesh_part0.split(only_watertight=False))
|
| 469 |
|
| 470 |
+
# -------------------------------------------------
|
| 471 |
+
# 6) part 1 mesh
|
| 472 |
+
# -------------------------------------------------
|
| 473 |
vertices, faces = results_part1["meshes"][0]
|
| 474 |
+
vertices = np.asarray(vertices, dtype=np.float32)
|
| 475 |
+
faces = np.asarray(faces, dtype=np.int64)
|
| 476 |
+
|
| 477 |
mesh_part1 = trimesh.Trimesh(vertices, faces, process=False)
|
| 478 |
mesh_part1.vertices = mesh_part1.vertices @ TRIMESH_GLB_EXPORT.T
|
| 479 |
mesh_part1 = postprocess_mesh(mesh_part1, int(target_num_faces))
|
|
|
|
| 490 |
|
| 491 |
return output_glb_path
|
| 492 |
|
| 493 |
+
except gr.Error:
|
| 494 |
+
raise
|
| 495 |
except Exception as e:
|
| 496 |
raise gr.Error(
|
| 497 |
"CPU 生成失败:"
|
|
|
|
| 530 |
|
| 531 |
with gr.Row():
|
| 532 |
with gr.Column():
|
| 533 |
+
input_image = gr.Image(
|
| 534 |
+
label="上传图片",
|
| 535 |
+
type="filepath"
|
| 536 |
+
)
|
| 537 |
+
seg_image = gr.Image(
|
| 538 |
+
label="处理后图片",
|
| 539 |
+
type="numpy",
|
| 540 |
+
interactive=False,
|
| 541 |
+
image_mode="RGBA"
|
| 542 |
+
)
|
| 543 |
|
| 544 |
with gr.Accordion("高级设置", open=False):
|
| 545 |
num_steps = gr.Slider(
|
|
|
|
| 563 |
step=1,
|
| 564 |
value=128
|
| 565 |
)
|
| 566 |
+
|
| 567 |
with gr.Row():
|
| 568 |
randomize_seed = gr.Checkbox(label="随机种子", value=True)
|
| 569 |
+
seed = gr.Slider(
|
| 570 |
+
label="Seed",
|
| 571 |
+
minimum=0,
|
| 572 |
+
maximum=MAX_SEED,
|
| 573 |
+
step=1,
|
| 574 |
+
value=0
|
| 575 |
+
)
|
| 576 |
|
| 577 |
with gr.Row():
|
| 578 |
simplify_mesh = gr.Checkbox(label="简化网格", value=True)
|
|
|
|
| 612 |
outputs=[seed]
|
| 613 |
).then(
|
| 614 |
fn=process_3d,
|
| 615 |
+
inputs=[
|
| 616 |
+
seg_image,
|
| 617 |
+
num_steps,
|
| 618 |
+
cfg_scale,
|
| 619 |
+
input_grid_res,
|
| 620 |
+
seed,
|
| 621 |
+
simplify_mesh,
|
| 622 |
+
target_num_faces
|
| 623 |
+
],
|
| 624 |
outputs=[output_model]
|
| 625 |
)
|
| 626 |
|
| 627 |
+
block.launch(
|
| 628 |
+
server_name="0.0.0.0",
|
| 629 |
+
server_port=int(os.environ.get("PORT", 7860))
|
| 630 |
+
)
|