Daankular committed on
Commit
9757b8d
·
1 Parent(s): 396e206

Fix queries dtype cast in inference_utils.py; remove autocast (caused NaN/no isosurface)

Browse files
Files changed (1) hide show
  1. app.py +18 -12
app.py CHANGED
@@ -216,7 +216,7 @@ def load_triposg():
216
  _iu_path = triposg_src / "triposg" / "inference_utils.py"
217
  if _iu_path.exists():
218
  _iu_text = _iu_path.read_text()
219
- if "DiffDMC = None" not in _iu_text and "from diso import DiffDMC" in _iu_text:
220
  _iu_text = _iu_text.replace(
221
  "from diso import DiffDMC",
222
  "try:\n from diso import DiffDMC\n"
@@ -237,8 +237,15 @@ def load_triposg():
237
  " return flash_extract_geometry(*args, **kwargs)\n"
238
  " return _hierarchical_extract_geometry_impl(*args, **kwargs)\n"
239
  )
 
 
 
 
 
 
 
240
  _iu_path.write_text(_iu_text)
241
- print("[load_triposg] Patched inference_utils.py: diso optional")
242
 
243
  weights_path = snapshot_download("VAST-AI/TripoSG")
244
 
@@ -334,16 +341,15 @@ def generate_shape(input_image, remove_background, num_steps, guidance_scale,
334
 
335
  progress(0.5, desc="Generating shape (SDF diffusion)...")
336
  from scripts.inference_triposg import run_triposg
337
- with torch.autocast(device_type="cuda", dtype=torch.float16):
338
- mesh = run_triposg(
339
- pipe=pipe,
340
- image_input=img_path,
341
- rmbg_net=rmbg_net if remove_background else None,
342
- seed=int(seed),
343
- num_inference_steps=int(num_steps),
344
- guidance_scale=float(guidance_scale),
345
- faces=int(face_count) if int(face_count) > 0 else -1,
346
- )
347
 
348
  out_path = "/tmp/triposg_shape.glb"
349
  mesh.export(out_path)
 
216
  _iu_path = triposg_src / "triposg" / "inference_utils.py"
217
  if _iu_path.exists():
218
  _iu_text = _iu_path.read_text()
219
+ if "queries.to(dtype=batch_latents.dtype)" not in _iu_text:
220
  _iu_text = _iu_text.replace(
221
  "from diso import DiffDMC",
222
  "try:\n from diso import DiffDMC\n"
 
237
  " return flash_extract_geometry(*args, **kwargs)\n"
238
  " return _hierarchical_extract_geometry_impl(*args, **kwargs)\n"
239
  )
240
+ # Also cast queries to match batch_latents dtype before vae.decode.
241
+ # TripoSGPipeline loads as float16 but flash_extract_geometry creates
242
+ # query grids as float32, causing a dtype mismatch in F.linear.
243
+ _iu_text = _iu_text.replace(
244
+ "logits = vae.decode(batch_latents, queries).sample",
245
+ "logits = vae.decode(batch_latents, queries.to(dtype=batch_latents.dtype)).sample",
246
+ )
247
  _iu_path.write_text(_iu_text)
248
+ print("[load_triposg] Patched inference_utils.py: diso optional + queries dtype cast")
249
 
250
  weights_path = snapshot_download("VAST-AI/TripoSG")
251
 
 
341
 
342
  progress(0.5, desc="Generating shape (SDF diffusion)...")
343
  from scripts.inference_triposg import run_triposg
344
+ mesh = run_triposg(
345
+ pipe=pipe,
346
+ image_input=img_path,
347
+ rmbg_net=rmbg_net if remove_background else None,
348
+ seed=int(seed),
349
+ num_inference_steps=int(num_steps),
350
+ guidance_scale=float(guidance_scale),
351
+ faces=int(face_count) if int(face_count) > 0 else -1,
352
+ )
 
353
 
354
  out_path = "/tmp/triposg_shape.glb"
355
  mesh.export(out_path)