Daankular committed on
Commit
b7c4667
·
1 Parent(s): 9336ed9

Pre-convert grid_logits to CPU before DiffDMC attempt (prevents CUDA ctx corruption in fallback)

Browse files
patches/triposg/triposg/inference_utils.py CHANGED
@@ -478,13 +478,18 @@ def flash_extract_geometry(
478
  torch.cuda.empty_cache()
479
  mesh_v_f = []
480
  grid_logits = grid_logits[0]
 
 
 
 
 
481
  try:
482
- print("final grids shape = ", grid_logits.shape)
483
  if DiffDMC is not None:
484
  print("[TripoSG] Attempting DiffDMC extraction...")
485
- dmc = DiffDMC(dtype=torch.float32).to(grid_logits.device)
486
- sdf = -grid_logits / octree_resolution
487
- sdf = sdf.to(torch.float32).contiguous()
 
488
  vertices, faces = dmc(sdf, deform=None, return_quads=False, normalize=False)
489
  vertices = vertices.detach().cpu().numpy()
490
  faces = faces.detach().cpu().numpy()[:, ::-1]
@@ -492,28 +497,17 @@ def flash_extract_geometry(
492
  mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
493
  print(f"[TripoSG] DiffDMC succeeded: {len(vertices)} vertices, {len(faces)} faces")
494
  else:
495
- # DiffDMC unavailable — fall back to marching cubes
496
- print("[TripoSG] DiffDMC unavailable, falling back to marching cubes")
497
- gl = grid_logits.float().cpu().numpy()
498
- gl_clean = np.nan_to_num(gl, nan=0.0)
499
- vertices, faces, normals, _ = measure.marching_cubes(gl_clean, 0, method="lewiner")
500
- vertices = vertices / (2 ** octree_depth) * bbox_size + bbox_min
501
- mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
502
- print(f"[TripoSG] marching cubes succeeded: {len(vertices)} vertices")
503
  except Exception as e:
504
- import traceback as _tb
505
- print(f"[TripoSG] mesh extraction failed: {e}")
506
- print(_tb.format_exc())
507
  torch.cuda.empty_cache()
508
- # Last-resort: marching cubes on nan-cleaned grid
509
  try:
510
- gl_np = np.nan_to_num(grid_logits.float().cpu().numpy(), nan=0.0)
511
- vertices, faces, normals, _ = measure.marching_cubes(gl_np, 0, method="lewiner")
512
- vertices = vertices / max(gl_np.shape) * bbox_size + bbox_min
513
  mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
514
- print(f"[TripoSG] last-resort MC succeeded: {len(vertices)} vertices")
515
  except Exception as e2:
516
- print(f"[TripoSG] last-resort MC also failed: {e2}")
517
  mesh_v_f = (None, None)
518
 
519
  return [mesh_v_f]
 
478
  torch.cuda.empty_cache()
479
  mesh_v_f = []
480
  grid_logits = grid_logits[0]
481
+ print("final grids shape = ", grid_logits.shape)
482
+ # Pre-convert to CPU float32 before any CUDA-compiled ops (DiffDMC may corrupt CUDA ctx)
483
+ grid_logits_cpu = np.nan_to_num(grid_logits.float().cpu().numpy(), nan=0.0)
484
+ torch.cuda.empty_cache()
485
+ mesh_v_f = []
486
  try:
 
487
  if DiffDMC is not None:
488
  print("[TripoSG] Attempting DiffDMC extraction...")
489
+ gl_gpu = torch.from_numpy(grid_logits_cpu).to(grid_logits.device)
490
+ dmc = DiffDMC(dtype=torch.float32).to(gl_gpu.device)
491
+ sdf = -gl_gpu / octree_resolution
492
+ sdf = sdf.contiguous()
493
  vertices, faces = dmc(sdf, deform=None, return_quads=False, normalize=False)
494
  vertices = vertices.detach().cpu().numpy()
495
  faces = faces.detach().cpu().numpy()[:, ::-1]
 
497
  mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
498
  print(f"[TripoSG] DiffDMC succeeded: {len(vertices)} vertices, {len(faces)} faces")
499
  else:
500
+ raise RuntimeError("DiffDMC unavailable — using marching cubes")
 
 
 
 
 
 
 
501
  except Exception as e:
502
+ print(f"[TripoSG] DiffDMC failed ({e}), falling back to marching cubes")
 
 
503
  torch.cuda.empty_cache()
 
504
  try:
505
+ vertices, faces, normals, _ = measure.marching_cubes(grid_logits_cpu, 0, method="lewiner")
506
+ vertices = vertices / (2 ** octree_depth) * bbox_size + bbox_min
 
507
  mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
508
+ print(f"[TripoSG] marching cubes succeeded: {len(vertices)} vertices")
509
  except Exception as e2:
510
+ print(f"[TripoSG] marching cubes also failed: {e2}")
511
  mesh_v_f = (None, None)
512
 
513
  return [mesh_v_f]