Daankular commited on
Commit
6e9c7d3
·
1 Parent(s): efc7f8b

Fix TripoSG NaN cascade and remove stale autocast wrapper

Browse files

- inference_utils.py: replace empty next_points guard — instead of
filling next_logits with all-(-10000) sentinel (which cascades to
all-NaN after the loop), trilinearly upsample the coarse grid to the
target resolution so real SDF values are preserved through refinement
- app.py: remove torch.autocast wrapper around run_triposg(); autocast
was previously identified (commit 9757b8d) as causing NaN / no-isosurface
issues and must not re-wrap the inference call

app.py CHANGED
@@ -880,9 +880,8 @@ def generate_shape(input_image, remove_background, num_steps, guidance_scale,
880
  img.save(img_path)
881
 
882
  progress(0.5, desc="Generating shape (SDF diffusion)...")
883
- with torch.autocast(device_type="cuda", dtype=torch.float16):
884
- from scripts.inference_triposg import run_triposg
885
- mesh = run_triposg(
886
  pipe=pipe,
887
  image_input=img_path,
888
  rmbg_net=rmbg_net if remove_background else None,
 
880
  img.save(img_path)
881
 
882
  progress(0.5, desc="Generating shape (SDF diffusion)...")
883
+ from scripts.inference_triposg import run_triposg
884
+ mesh = run_triposg(
 
885
  pipe=pipe,
886
  image_input=img_path,
887
  rmbg_net=rmbg_net if remove_background else None,
patches/triposg/triposg/inference_utils.py CHANGED
@@ -431,8 +431,13 @@ def flash_extract_geometry(
431
 
432
  next_points = torch.stack(nidx, dim=1)
433
  if next_points.shape[0] == 0:
434
- # No active voxels at this refinement level — skip and carry forward
435
- grid_logits = next_logits.unsqueeze(0)
 
 
 
 
 
436
  continue
437
  next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
438
  torch.tensor(bbox_min, dtype=torch.float32, device=device))
 
431
 
432
  next_points = torch.stack(nidx, dim=1)
433
  if next_points.shape[0] == 0:
434
+ # No active voxels at this refinement level — upsample coarse grid instead of
435
+ # filling with all-sentinel values (which cascade to all-NaN at extract time)
436
+ print(f"[TripoSG] skip level {octree_depth_now}: empty near-surface, upsampling coarse grid")
437
+ tgt = tuple(int(x) for x in grid_size)
438
+ grid_logits = F.interpolate(
439
+ grid_logits.float().unsqueeze(1), size=tgt, mode='trilinear', align_corners=False
440
+ ).squeeze(1).to(dtype)
441
  continue
442
  next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
443
  torch.tensor(bbox_min, dtype=torch.float32, device=device))