cpuai commited on
Commit
cfef1cb
·
verified ·
1 Parent(s): 0846da3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +342 -68
app.py CHANGED
@@ -1,12 +1,18 @@
1
  import os
2
- import numpy as np
 
 
 
3
  import cv2
 
4
  import kiui
5
- import trimesh
6
- import torch
7
  import rembg
8
- from datetime import datetime
9
- import gradio as gr
 
 
 
10
 
11
  try:
12
  import spaces
@@ -19,61 +25,70 @@ except ImportError:
19
  def __call__(self, func):
20
  return func
21
 
22
-
23
  from flow.model import Model
24
  from flow.configs.schema import ModelConfig
25
  from flow.utils import get_random_color, recenter_foreground
26
  from vae.utils import postprocess_mesh
27
- from huggingface_hub import hf_hub_download
28
 
29
- # =========================
30
- # CPU 基础设置
31
- # =========================
32
  DEVICE = torch.device("cpu")
33
  DTYPE = torch.float32
34
 
 
35
  CPU_THREADS = int(os.environ.get("CPU_THREADS", "2"))
36
  torch.set_num_threads(CPU_THREADS)
37
  torch.set_num_interop_threads(max(1, min(2, CPU_THREADS)))
38
 
 
 
 
 
 
 
 
 
 
39
  TRIMESH_GLB_EXPORT = np.array(
40
  [[0, 1, 0], [0, 0, 1], [1, 0, 0]],
41
  dtype=np.float32
42
  )
 
43
  MAX_SEED = np.iinfo(np.int32).max
44
  bg_remover = rembg.new_session()
45
 
46
- # =========================
47
- # 下载模型
48
- # =========================
49
- flow_ckpt_path = hf_hub_download(repo_id="nvidia/PartPacker", filename="flow.pt")
50
- vae_ckpt_path = hf_hub_download(repo_id="nvidia/PartPacker", filename="vae.pt")
 
 
 
 
 
 
 
51
 
52
- # =========================
53
- # 模型配置
54
- # =========================
55
- model_config = ModelConfig(
56
- vae_conf="vae.configs.part_woenc",
57
- vae_ckpt_path=vae_ckpt_path,
58
- qknorm=True,
59
- qknorm_type="RMSNorm",
60
- use_pos_embed=False,
61
- dino_model="dinov2_vitg14",
62
- hidden_dim=1536,
63
- flow_shift=3.0,
64
- logitnorm_mean=1.0,
65
- logitnorm_std=1.0,
66
- latent_size=4096,
67
- use_parts=True,
68
- )
69
 
70
- # =========================
 
 
 
 
 
 
 
 
 
71
  # 工具函数:强制整个模块转 float32
72
- # =========================
73
  def force_module_fp32(module: torch.nn.Module):
74
  """
75
- 递归把模块参数和 buffer 全部 float32。
76
- 这一步是解决 CPU 下 bfloat16/float32 混用问题的关键。
77
  """
78
  module.to(device=DEVICE)
79
  module.float()
@@ -81,31 +96,233 @@ def force_module_fp32(module: torch.nn.Module):
81
  for child in module.children():
82
  force_module_fp32(child)
83
 
 
84
  for name, buf in module.named_buffers(recurse=False):
85
- if torch.is_floating_point(buf):
86
  setattr(module, name, buf.to(device=DEVICE, dtype=torch.float32))
87
 
88
  return module
89
 
90
 
91
- # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  # 初始化模型(CPU + float32)
93
- # =========================
94
  print("正在加载模型到 CPU ...")
 
 
 
95
  model = Model(model_config)
96
  model.eval()
97
  model.to(DEVICE)
98
 
99
  # 显式按 CPU 加载权重
100
- ckpt_dict = torch.load(flow_ckpt_path, map_location=DEVICE, weights_only=True)
 
 
 
 
 
101
  model.load_state_dict(ckpt_dict, strict=True)
102
 
103
- # 关键:再次强制整个模型 float32
104
  force_module_fp32(model)
105
  model.eval()
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  print("模型加载完成。")
108
- print("主模型 dtype:", next(model.parameters()).dtype)
 
 
 
109
 
110
 
111
  def get_random_seed(randomize_seed, seed):
@@ -165,7 +382,9 @@ def process_3d(
165
  os.makedirs("output", exist_ok=True)
166
  output_glb_path = f"output/partpacker_{datetime.now().strftime('%Y%m%d_%H%M%S')}.glb"
167
 
168
- # RGBA -> float32
 
 
169
  image = input_image.astype(np.float32) / 255.0
170
  image = image[..., :3] * image[..., 3:4] + (1.0 - image[..., 3:4])
171
 
@@ -178,56 +397,83 @@ def process_3d(
178
  )
179
 
180
  data = {
181
- "cond_images": image_tensor.float()
182
  }
 
183
 
184
- # 再保险:推理前确保模型仍是 float32
 
 
185
  force_module_fp32(model)
186
  model.eval()
 
 
 
187
 
 
 
 
188
  with torch.inference_mode():
189
- results = model(
190
- data,
191
- num_steps=int(num_steps),
192
- cfg_scale=float(cfg_scale)
193
- )
 
194
 
195
- latent = results["latent"]
196
 
197
- # 关键:latent 强制 float32
198
- if isinstance(latent, torch.Tensor):
199
- latent = latent.to(device=DEVICE, dtype=torch.float32).contiguous()
200
- else:
201
  raise gr.Error("模型输出 latent 异常。")
202
 
203
- # VAE 输入前再做 float32 保证
 
 
 
 
204
  data_part0 = {
205
- "latent": latent[:, : model.config.latent_size, :].float().contiguous()
206
  }
207
  data_part1 = {
208
- "latent": latent[:, model.config.latent_size:, :].float().contiguous()
209
  }
210
 
211
- # 再保险:把 VAE 也强制成 float32
212
- force_module_fp32(model.vae)
213
- model.vae.eval()
214
 
215
  with torch.inference_mode():
216
- results_part0 = model.vae(data_part0, resolution=int(grid_res))
217
- results_part1 = model.vae(data_part1, resolution=int(grid_res))
 
 
 
 
218
 
219
  if not simplify_mesh:
220
  target_num_faces = -1
221
 
222
  parts = []
223
 
 
 
 
224
  vertices, faces = results_part0["meshes"][0]
 
 
 
225
  mesh_part0 = trimesh.Trimesh(vertices, faces, process=False)
226
  mesh_part0.vertices = mesh_part0.vertices @ TRIMESH_GLB_EXPORT.T
227
  mesh_part0 = postprocess_mesh(mesh_part0, int(target_num_faces))
228
  parts.extend(mesh_part0.split(only_watertight=False))
229
 
 
 
 
230
  vertices, faces = results_part1["meshes"][0]
 
 
 
231
  mesh_part1 = trimesh.Trimesh(vertices, faces, process=False)
232
  mesh_part1.vertices = mesh_part1.vertices @ TRIMESH_GLB_EXPORT.T
233
  mesh_part1 = postprocess_mesh(mesh_part1, int(target_num_faces))
@@ -244,6 +490,8 @@ def process_3d(
244
 
245
  return output_glb_path
246
 
 
 
247
  except Exception as e:
248
  raise gr.Error(
249
  "CPU 生成失败:"
@@ -282,8 +530,16 @@ with block:
282
 
283
  with gr.Row():
284
  with gr.Column():
285
- input_image = gr.Image(label="上传图片", type="filepath")
286
- seg_image = gr.Image(label="处理后图片", type="numpy", interactive=False, image_mode="RGBA")
 
 
 
 
 
 
 
 
287
 
288
  with gr.Accordion("高级设置", open=False):
289
  num_steps = gr.Slider(
@@ -307,9 +563,16 @@ with block:
307
  step=1,
308
  value=128
309
  )
 
310
  with gr.Row():
311
  randomize_seed = gr.Checkbox(label="随机种子", value=True)
312
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
 
 
 
 
 
 
313
 
314
  with gr.Row():
315
  simplify_mesh = gr.Checkbox(label="简化网格", value=True)
@@ -349,8 +612,19 @@ with block:
349
  outputs=[seed]
350
  ).then(
351
  fn=process_3d,
352
- inputs=[seg_image, num_steps, cfg_scale, input_grid_res, seed, simplify_mesh, target_num_faces],
 
 
 
 
 
 
 
 
353
  outputs=[output_model]
354
  )
355
 
356
- block.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
 
 
 
1
  import os
2
+ import contextlib
3
+ import functools
4
+ from datetime import datetime
5
+
6
  import cv2
7
+ import gradio as gr
8
  import kiui
9
+ import numpy as np
 
10
  import rembg
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import trimesh
15
+ from huggingface_hub import hf_hub_download
16
 
17
  try:
18
  import spaces
 
25
  def __call__(self, func):
26
  return func
27
 
 
28
  from flow.model import Model
29
  from flow.configs.schema import ModelConfig
30
  from flow.utils import get_random_color, recenter_foreground
31
  from vae.utils import postprocess_mesh
 
32
 
33
+ # =========================================================
34
+ # CPU / dtype 基础设置
35
+ # =========================================================
36
  DEVICE = torch.device("cpu")
37
  DTYPE = torch.float32
38
 
39
+ # 线程数可按 HF CPU Space 机器情况调整
40
  CPU_THREADS = int(os.environ.get("CPU_THREADS", "2"))
41
  torch.set_num_threads(CPU_THREADS)
42
  torch.set_num_interop_threads(max(1, min(2, CPU_THREADS)))
43
 
44
+ # 显式设默认浮点 dtype 为 float32
45
+ torch.set_default_dtype(torch.float32)
46
+
47
+ # 对 CPU 推理更稳妥
48
+ try:
49
+ torch.set_grad_enabled(False)
50
+ except Exception:
51
+ pass
52
+
53
  TRIMESH_GLB_EXPORT = np.array(
54
  [[0, 1, 0], [0, 0, 1], [1, 0, 0]],
55
  dtype=np.float32
56
  )
57
+
58
  MAX_SEED = np.iinfo(np.int32).max
59
  bg_remover = rembg.new_session()
60
 
61
# =========================================================
# Helper: recursively normalize tensors to CPU + float32
# =========================================================
def to_cpu_fp32(obj):
    """Recursively move tensors found in *obj* to CPU.

    Floating-point tensors are additionally cast to float32; integer/bool
    tensors only change device. Dict, list and tuple containers are
    rebuilt with converted elements; any other value is returned as-is.
    """
    if torch.is_tensor(obj):
        # Non-floating tensors keep their dtype (casting to it is a no-op).
        target_dtype = torch.float32 if obj.is_floating_point() else obj.dtype
        return obj.to(device=DEVICE, dtype=target_dtype, non_blocking=False)

    if isinstance(obj, dict):
        return {key: to_cpu_fp32(value) for key, value in obj.items()}

    if isinstance(obj, (list, tuple)):
        converted = [to_cpu_fp32(value) for value in obj]
        return converted if isinstance(obj, list) else tuple(converted)

    return obj
84
+
85
+
86
# =========================================================
# Helper: force an entire module tree to float32
# =========================================================
def force_module_fp32(module: torch.nn.Module):
    """Recursively move *module*'s parameters and buffers to CPU float32.

    Returns the module to allow chaining.
    """
    module.to(device=DEVICE)
    module.float()

    # Recurse into direct children explicitly.
    for submodule in module.children():
        force_module_fp32(submodule)

    # Explicitly re-assign floating-point buffers registered directly on
    # this module, guaranteeing CPU + float32.
    for buffer_name, buffer_value in module.named_buffers(recurse=False):
        if not (torch.is_tensor(buffer_value) and buffer_value.is_floating_point()):
            continue
        setattr(module, buffer_name, buffer_value.to(device=DEVICE, dtype=torch.float32))

    return module
105
 
106
 
107
# =========================================================
# Helper: disable CPU autocast for the duration of a block
# =========================================================
@contextlib.contextmanager
def disable_cpu_autocast():
    """Context manager that explicitly disables CPU autocast.

    Prevents code inside the block from silently running CPU ops in
    bfloat16. If this torch build does not accept
    ``torch.autocast(device_type="cpu", enabled=False)``, the manager
    degrades to a plain no-op context.

    Bug fixed: the previous version wrapped ``yield`` in ``try/except``
    and yielded a second time when an exception escaped the managed
    block — a @contextmanager generator must yield exactly once, so any
    caller exception was turned into ``RuntimeError: generator didn't
    stop after throw()``. Entering the autocast context via ExitStack
    keeps the fallback while letting exceptions propagate normally.
    """
    with contextlib.ExitStack() as stack:
        try:
            # Only the autocast setup is guarded; the caller's block is not.
            stack.enter_context(torch.autocast(device_type="cpu", enabled=False))
        except Exception:
            # Unsupported torch version/environment: run without autocast control.
            pass
        yield
121
+
122
+
123
# =========================================================
# Safety net 1: globally patch F.linear
# =========================================================
def patch_functional_linear():
    """Monkey-patch ``torch.nn.functional.linear`` to auto-align dtypes.

    If the input (on CPU) or the bias has a floating dtype different from
    the weight's, it is cast to ``weight.dtype`` before the call. This is
    the last line of defense against bfloat16/float32 mixing. Idempotent:
    re-invoking after the patch is installed is a no-op.
    """
    if getattr(F.linear, "_fp32_safe_patched", False):
        return

    base_linear = F.linear

    def _mismatched(tensor, reference):
        # True when both are floating tensors with different dtypes.
        return (
            torch.is_tensor(tensor)
            and torch.is_tensor(reference)
            and tensor.is_floating_point()
            and reference.is_floating_point()
            and tensor.dtype != reference.dtype
        )

    @functools.wraps(base_linear)
    def safe_linear(input, weight, bias=None):
        # Input alignment is restricted to CPU tensors on purpose.
        if _mismatched(input, weight) and input.device.type == "cpu":
            input = input.to(dtype=weight.dtype)

        if bias is not None and _mismatched(bias, weight):
            bias = bias.to(dtype=weight.dtype)

        return base_linear(input, weight, bias)

    safe_linear._fp32_safe_patched = True
    F.linear = safe_linear
163
+
164
+
165
# =========================================================
# Safety net 2: forward pre-hooks on common op modules
# =========================================================
def register_dtype_guard_hooks(root_module: nn.Module):
    """Install forward pre-hooks that align inputs to each module's dtype.

    Every Linear / Conv / Norm / MultiheadAttention submodule gets a hook
    that, at forward entry, recursively casts floating inputs to the
    dtype and device of the module's own parameters (or buffers, if it
    has no parameters). Returns the list of hook handles so callers can
    remove them later.
    """
    handles = []

    guarded = (
        nn.Linear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nn.LayerNorm,
        nn.GroupNorm,
        nn.BatchNorm1d,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.MultiheadAttention,
    )

    def _recast(value, dtype, device):
        # Recursively cast floating tensors; move non-floating ones only.
        if torch.is_tensor(value):
            if value.is_floating_point():
                return value.to(device=device, dtype=dtype)
            return value.to(device=device)
        if isinstance(value, dict):
            return {k: _recast(v, dtype, device) for k, v in value.items()}
        if isinstance(value, list):
            return [_recast(v, dtype, device) for v in value]
        if isinstance(value, tuple):
            return tuple(_recast(v, dtype, device) for v in value)
        return value

    def _reference_tensor(module):
        # Prefer a parameter as the dtype reference, then fall back to buffers.
        for param in module.parameters(recurse=False):
            if torch.is_tensor(param) and param.is_floating_point():
                return param
        for buf in module.buffers(recurse=False):
            if torch.is_tensor(buf) and buf.is_floating_point():
                return buf
        return None

    def _align_inputs(module, inputs):
        ref = _reference_tensor(module)
        if ref is None:
            return inputs
        return _recast(inputs, ref.dtype, ref.device)

    for submodule in root_module.modules():
        if isinstance(submodule, guarded):
            handles.append(submodule.register_forward_pre_hook(_align_inputs))

    return handles
226
+
227
+
228
# =========================================================
# Safety net 3: wrap forward — disable autocast + cast to fp32
# =========================================================
def wrap_forward_fp32(module: nn.Module):
    """Wrap *module*'s forward with fp32 input/output guards.

    The wrapped forward:
      1. recursively converts args/kwargs to CPU float32 on entry,
      2. runs with CPU autocast disabled,
      3. converts the result back to CPU float32.
    Idempotent — a module is only wrapped once.
    """
    if getattr(module, "_forward_fp32_wrapped", False):
        return

    inner_forward = module.forward

    @functools.wraps(inner_forward)
    def guarded_forward(*args, **kwargs):
        safe_args = to_cpu_fp32(args)
        safe_kwargs = to_cpu_fp32(kwargs)
        with disable_cpu_autocast():
            result = inner_forward(*safe_args, **safe_kwargs)
        return to_cpu_fp32(result)

    module.forward = guarded_forward
    module._forward_fp32_wrapped = True
252
+
253
+
254
# =========================================================
# Download model checkpoints from the Hugging Face Hub
# =========================================================
flow_ckpt_path = hf_hub_download(
    repo_id="nvidia/PartPacker",
    filename="flow.pt"
)
vae_ckpt_path = hf_hub_download(
    repo_id="nvidia/PartPacker",
    filename="vae.pt"
)

# =========================================================
# Model configuration
# =========================================================
# NOTE(review): values mirror the PartPacker release defaults —
# confirm against the upstream repo before changing any of them.
model_config = ModelConfig(
    vae_conf="vae.configs.part_woenc",
    vae_ckpt_path=vae_ckpt_path,
    qknorm=True,
    qknorm_type="RMSNorm",
    use_pos_embed=False,
    dino_model="dinov2_vitg14",
    hidden_dim=1536,
    flow_shift=3.0,
    logitnorm_mean=1.0,
    logitnorm_std=1.0,
    latent_size=4096,   # per-part latent length; used to split the latent downstream
    use_parts=True,
)
283
+
284
# =========================================================
# Initialize the model (CPU + float32)
# =========================================================
print("正在加载模型到 CPU ...")

# Install the F.linear dtype-safety patch before any forward runs.
patch_functional_linear()

model = Model(model_config)
model.eval()
model.to(DEVICE)

# Load weights explicitly onto CPU.
# Fall back to a plain torch.load when weights_only=True is not accepted
# (the keyword does not exist in older torch releases).
try:
    ckpt_dict = torch.load(flow_ckpt_path, map_location=DEVICE, weights_only=True)
except TypeError:
    ckpt_dict = torch.load(flow_ckpt_path, map_location=DEVICE)

model.load_state_dict(ckpt_dict, strict=True)

# Force the whole model to float32.
force_module_fp32(model)
model.eval()

# Wrap forward() so CPU autocast is fully disabled during inference.
wrap_forward_fp32(model)
if hasattr(model, "dit"):
    wrap_forward_fp32(model.dit)
if hasattr(model, "vae"):
    wrap_forward_fp32(model.vae)

# Register dtype-guard hooks on the model (and the VAE, if present);
# handles are kept alive for the process lifetime.
_DTYPE_GUARD_HOOKS = []
_DTYPE_GUARD_HOOKS.extend(register_dtype_guard_hooks(model))
if hasattr(model, "vae"):
    _DTYPE_GUARD_HOOKS.extend(register_dtype_guard_hooks(model.vae))

print("模型加载完成。")
try:
    print("主模型 dtype:", next(model.parameters()).dtype)
except StopIteration:
    print("主模型没有可见参数。")
326
 
327
 
328
  def get_random_seed(randomize_seed, seed):
 
382
  os.makedirs("output", exist_ok=True)
383
  output_glb_path = f"output/partpacker_{datetime.now().strftime('%Y%m%d_%H%M%S')}.glb"
384
 
385
+ # -------------------------------------------------
386
+ # 1) RGBA -> RGB 白底合成 -> float32
387
+ # -------------------------------------------------
388
  image = input_image.astype(np.float32) / 255.0
389
  image = image[..., :3] * image[..., 3:4] + (1.0 - image[..., 3:4])
390
 
 
397
  )
398
 
399
  data = {
400
+ "cond_images": image_tensor
401
  }
402
+ data = to_cpu_fp32(data)
403
 
404
+ # -------------------------------------------------
405
+ # 2) 推理前再次强制模型为 float32
406
+ # -------------------------------------------------
407
  force_module_fp32(model)
408
  model.eval()
409
+ if hasattr(model, "vae"):
410
+ force_module_fp32(model.vae)
411
+ model.vae.eval()
412
 
413
+ # -------------------------------------------------
414
+ # 3) 主模型推理:显式禁用 CPU autocast
415
+ # -------------------------------------------------
416
  with torch.inference_mode():
417
+ with disable_cpu_autocast():
418
+ results = model(
419
+ data,
420
+ num_steps=int(num_steps),
421
+ cfg_scale=float(cfg_scale)
422
+ )
423
 
424
+ results = to_cpu_fp32(results)
425
 
426
+ latent = results.get("latent", None)
427
+ if not isinstance(latent, torch.Tensor):
 
 
428
  raise gr.Error("模型输出 latent 异常。")
429
 
430
+ latent = latent.to(device=DEVICE, dtype=torch.float32).contiguous()
431
+
432
+ # -------------------------------------------------
433
+ # 4) VAE 解码:再次显式禁用 CPU autocast
434
+ # -------------------------------------------------
435
  data_part0 = {
436
+ "latent": latent[:, : model.config.latent_size, :].contiguous()
437
  }
438
  data_part1 = {
439
+ "latent": latent[:, model.config.latent_size:, :].contiguous()
440
  }
441
 
442
+ data_part0 = to_cpu_fp32(data_part0)
443
+ data_part1 = to_cpu_fp32(data_part1)
 
444
 
445
  with torch.inference_mode():
446
+ with disable_cpu_autocast():
447
+ results_part0 = model.vae(data_part0, resolution=int(grid_res))
448
+ results_part1 = model.vae(data_part1, resolution=int(grid_res))
449
+
450
+ results_part0 = to_cpu_fp32(results_part0)
451
+ results_part1 = to_cpu_fp32(results_part1)
452
 
453
  if not simplify_mesh:
454
  target_num_faces = -1
455
 
456
  parts = []
457
 
458
+ # -------------------------------------------------
459
+ # 5) part 0 mesh
460
+ # -------------------------------------------------
461
  vertices, faces = results_part0["meshes"][0]
462
+ vertices = np.asarray(vertices, dtype=np.float32)
463
+ faces = np.asarray(faces, dtype=np.int64)
464
+
465
  mesh_part0 = trimesh.Trimesh(vertices, faces, process=False)
466
  mesh_part0.vertices = mesh_part0.vertices @ TRIMESH_GLB_EXPORT.T
467
  mesh_part0 = postprocess_mesh(mesh_part0, int(target_num_faces))
468
  parts.extend(mesh_part0.split(only_watertight=False))
469
 
470
+ # -------------------------------------------------
471
+ # 6) part 1 mesh
472
+ # -------------------------------------------------
473
  vertices, faces = results_part1["meshes"][0]
474
+ vertices = np.asarray(vertices, dtype=np.float32)
475
+ faces = np.asarray(faces, dtype=np.int64)
476
+
477
  mesh_part1 = trimesh.Trimesh(vertices, faces, process=False)
478
  mesh_part1.vertices = mesh_part1.vertices @ TRIMESH_GLB_EXPORT.T
479
  mesh_part1 = postprocess_mesh(mesh_part1, int(target_num_faces))
 
490
 
491
  return output_glb_path
492
 
493
+ except gr.Error:
494
+ raise
495
  except Exception as e:
496
  raise gr.Error(
497
  "CPU 生成失败:"
 
530
 
531
  with gr.Row():
532
  with gr.Column():
533
+ input_image = gr.Image(
534
+ label="上传图片",
535
+ type="filepath"
536
+ )
537
+ seg_image = gr.Image(
538
+ label="处理后图片",
539
+ type="numpy",
540
+ interactive=False,
541
+ image_mode="RGBA"
542
+ )
543
 
544
  with gr.Accordion("高级设置", open=False):
545
  num_steps = gr.Slider(
 
563
  step=1,
564
  value=128
565
  )
566
+
567
  with gr.Row():
568
  randomize_seed = gr.Checkbox(label="随机种子", value=True)
569
+ seed = gr.Slider(
570
+ label="Seed",
571
+ minimum=0,
572
+ maximum=MAX_SEED,
573
+ step=1,
574
+ value=0
575
+ )
576
 
577
  with gr.Row():
578
  simplify_mesh = gr.Checkbox(label="简化网格", value=True)
 
612
  outputs=[seed]
613
  ).then(
614
  fn=process_3d,
615
+ inputs=[
616
+ seg_image,
617
+ num_steps,
618
+ cfg_scale,
619
+ input_grid_res,
620
+ seed,
621
+ simplify_mesh,
622
+ target_num_faces
623
+ ],
624
  outputs=[output_model]
625
  )
626
 
627
+ block.launch(
628
+ server_name="0.0.0.0",
629
+ server_port=int(os.environ.get("PORT", 7860))
630
+ )