Spaces:
Paused
Paused
| """SAM 3D Objects – kaolin+pytorch3d stubbed for ZeroGPU (PyTorch 2.10+cu128).""" | |
| import os, sys, subprocess | |
# Toolchain locations: setdefault only fills these in when the environment has
# not already provided a value.
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
os.environ.setdefault("CONDA_PREFIX", "/usr/local")
# Unconditional overrides: any value inherited from the environment is replaced.
# These are set before the third-party imports below so that code imported later
# sees them (presumably read by the SAM3D code at import/initialisation time —
# TODO confirm).
os.environ["LIDRA_SKIP_INIT"] = "true"
# Both attention backends are pinned to "sdpa"; the clone-and-patch step further
# down notes that the real flash_attn package is not available in this Space.
os.environ["ATTN_BACKEND"] = "sdpa"
os.environ["SPARSE_ATTN_BACKEND"] = "sdpa"
# Sparse convolutions go through spconv, which is pip-installed at startup below.
os.environ["SPARSE_BACKEND"] = "spconv"
| # MUST import spaces before torch | |
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| from huggingface_hub import snapshot_download, login | |
| import tempfile | |
| from pathlib import Path | |
# Log in to the Hugging Face Hub only when a non-empty token is configured.
if os.getenv("HF_TOKEN"):
    login(token=os.environ["HF_TOKEN"])

# --- Stubs (must be before sam3d imports) ---
# Each stub directory that exists is pushed to the front of sys.path so that it
# is found ahead of any other installation of the same package name.
STUB_KAOLIN = Path("/home/user/app", "kaolin_stub")
STUB_PT3D = Path("/home/user/app", "pytorch3d_stub")
STUB_FA = Path("/home/user/app", "flash_attn_stub")
for stub in (STUB_KAOLIN, STUB_PT3D, STUB_FA):
    if not stub.exists():
        continue
    sys.path.insert(0, str(stub))
    print(f"Stub added: {stub.name}")
| # --- Runtime pip installs --- | |
| def _pip(*a): | |
| r = subprocess.run([sys.executable, "-m", "pip", "install", "--no-cache-dir"] + list(a), | |
| capture_output=True, text=True, timeout=1200) | |
| ok = r.returncode == 0 | |
| tag = a[-1][:50] if a else "?" | |
| if ok: | |
| print(f" pip OK: {tag}") | |
| else: | |
| print(f" pip FAIL: {tag}") | |
| print(f" {r.stderr[-300:]}") | |
| return ok | |
print("=== Runtime installs ===")
# Best-effort installs, attempted in this exact order; _pip's result is ignored,
# so a failure is logged and the sequence continues.
for _pip_args in (
    ("open3d>=0.18.0",),
    ("--no-deps", "git+https://github.com/EasternJournalist/utils3d.git"),  # --no-deps: skip jupyter dependency
    ("iopath",),
    ("--no-deps", "sam2>=1.1.0"),
    ("--no-deps", "git+https://github.com/microsoft/MoGe.git@a8c37341bc0325ca99b9d57981cc3bb2bd3e255b"),
):
    _pip(*_pip_args)

# gsplat: walk the wheel indexes in order and stop at the first successful install.
for idx in ("https://docs.gsplat.studio/whl/pt210cu128",
            "https://docs.gsplat.studio/whl/pt28cu128"):
    if not _pip("--no-deps", f"--extra-index-url={idx}", "gsplat"):
        continue
    break

# spconv (sparse convolution – needed for SAM3D's SLatFlowModel)
# cu124 wheel is forward-compatible with cu128
_pip("spconv-cu124==2.3.8")
# DO NOT import CUDA-dependent packages here!
# --- Clone sam-3d-objects ---
# NOTE(review): the nesting below was reconstructed from a listing that had lost
# its indentation — confirm which statements belong under the "not cloned yet"
# branch before relying on it.
SAM3D_PATH = Path("/home/user/app/sam-3d-objects")
if not SAM3D_PATH.exists():
    print("Cloning sam-3d-objects...")
    # check=True: a failed clone raises CalledProcessError and aborts startup.
    subprocess.run(["git", "clone", "--depth", "1",
                    "https://github.com/facebookresearch/sam-3d-objects.git",
                    str(SAM3D_PATH)], check=True)
    # Editable install without dependency resolution. There is no check=True and
    # the captured output is never inspected, so a failed install is silent here.
    subprocess.run([sys.executable, "-m", "pip", "install", "-e", str(SAM3D_PATH), "--no-deps"],
                   capture_output=True, text=True)
    # Hydra patch: run the repository's own patching/hydra script with bash, if it
    # exists. Its exit status and output are likewise ignored.
    patch = SAM3D_PATH / "patching" / "hydra"
    if patch.exists():
        subprocess.run(["bash", str(patch)], capture_output=True, cwd=str(SAM3D_PATH))
# CRITICAL PATCH: Prevent SAM3D from overriding ATTN_BACKEND to flash_attn
# inference_pipeline.py auto-detects H200/A100 and forces flash_attn,
# but we don't have the real flash_attn package.
ip_file = SAM3D_PATH / "sam3d_objects" / "pipeline" / "inference_pipeline.py"
if ip_file.exists():
    ip_src = ip_file.read_text()
    # The presence of this assignment is used as the "not yet patched" marker.
    old_marker = 'os.environ["ATTN_BACKEND"] = "flash_attn"'
    if old_marker in ip_src:
        # Replace the entire if-block that forces flash_attn. The replacement uses
        # setdefault, so the "sdpa" values already exported at the top of this file
        # are left in place.
        # NOTE(review): str.replace is a no-op unless the four-line pattern matches
        # the upstream text exactly (including leading whitespace inside these
        # literals, which could not be verified from this listing), yet the file is
        # rewritten and the success message printed regardless — confirm the patch
        # actually applies.
        # NOTE(review): the replacement calls logger.info while the matched text
        # only has that call commented out — confirm `logger` is defined in that
        # scope and that the resulting indentation is valid.
        ip_src = ip_src.replace(
            'if "A100" in gpu_name or "H100" in gpu_name or "H200" in gpu_name:\n'
            ' # logger.info("Use flash_attn")\n'
            ' os.environ["ATTN_BACKEND"] = "flash_attn"\n'
            ' os.environ["SPARSE_ATTN_BACKEND"] = "flash_attn"',
            '# PATCHED: Always use sdpa backend (flash_attn not available on ZeroGPU)\n'
            ' logger.info("Using sdpa backend (patched for ZeroGPU)")\n'
            ' os.environ.setdefault("ATTN_BACKEND", "sdpa")\n'
            ' os.environ.setdefault("SPARSE_ATTN_BACKEND", "sdpa")'
        )
        ip_file.write_text(ip_src)
        print("PATCHED: inference_pipeline.py - forced sdpa backend")
    else:
        print("INFO: inference_pipeline.py already patched or different version")
# Make the cloned repository and its notebook/ directory importable; the
# endpoints below import `Inference` from a top-level `inference` module.
sys.path.insert(0, str(SAM3D_PATH))
sys.path.insert(0, str(SAM3D_PATH / "notebook"))
| # --- Monkey-patch: inject depth_edge into utils3d.numpy --- | |
| # utils3d package lacks depth_edge in newer versions; SAM3D needs it for layout post-optimization | |
try:
    import utils3d.numpy as _u3d_np
    if not hasattr(_u3d_np, 'depth_edge'):
        def _depth_edge(depth, rtol=0.03, mask=None):
            """Flag pixels whose Sobel gradient, relative to |depth|, exceeds ``rtol``.

            Args:
                depth: Depth array (assumed 2-D, since gradients are taken along
                    axes 0 and 1 — TODO confirm against the SAM3D caller).
                rtol: Relative threshold applied to gradient / |depth|.
                mask: Optional boolean array; pixels outside it are zeroed before
                    filtering and are never reported as edges.

            Returns:
                Boolean array marking the detected edge pixels.
            """
            from scipy.ndimage import sobel
            import numpy as _np
            d = _np.where(mask, depth, 0.0) if mask is not None else depth.copy()
            gx = sobel(d, axis=1)
            gy = sobel(d, axis=0)
            grad = _np.sqrt(gx**2 + gy**2)
            denom = _np.abs(d)
            # Clamp the denominator so zero depth does not divide by zero.
            denom[denom < 1e-6] = 1e-6
            edge = (grad / denom) > rtol
            if mask is not None:
                edge = edge & mask
            return edge
        _u3d_np.depth_edge = _depth_edge

        def _normals_edge(normals, tol=0.1, mask=None):
            """Detect normal discontinuities.

            A pixel is an edge when the Sobel gradient magnitude of any channel
            along the last axis exceeds ``tol``. Pixels outside ``mask`` (when
            given) are zeroed before filtering and never reported.
            """
            import numpy as _np
            from scipy.ndimage import sobel
            edges = _np.zeros(normals.shape[:2], dtype=bool)
            for c in range(normals.shape[-1]):
                ch = normals[..., c]
                if mask is not None:
                    ch = _np.where(mask, ch, 0.0)
                gx = sobel(ch, axis=1)
                gy = sobel(ch, axis=0)
                grad = _np.sqrt(gx**2 + gy**2)
                edges |= (grad > tol)
            if mask is not None:
                edges = edges & mask
            return edges
        _u3d_np.normals_edge = _normals_edge

        # Also inject a catch-all __getattr__ for any future missing functions.
        # Any __getattr__ hook the module already defines is kept and consulted
        # first, so names it can resolve are not shadowed by the dummy fallback.
        _orig_getattr = getattr(_u3d_np, '__getattr__', None)

        def _u3d_catchall(name):
            """Module-level ``__getattr__``: previous hook first, then a warning dummy."""
            if name.startswith('__') and name.endswith('__'):
                # Dunder probes (copy, pickle, inspect, ...) must fail normally.
                raise AttributeError(name)
            if _orig_getattr is not None:
                try:
                    return _orig_getattr(name)
                except AttributeError:
                    pass  # genuinely missing: fall through to the dummy below
            import warnings
            warnings.warn(f"utils3d.numpy stub: {name} not implemented, returning dummy")

            def _dummy(*a, **kw):
                import numpy as _np
                return _np.zeros(1)
            return _dummy
        _u3d_np.__getattr__ = _u3d_catchall
        print("Injected depth_edge + normals_edge + catch-all into utils3d.numpy")
except Exception as e:
    # Best-effort patch: any failure (e.g. utils3d not importable) is only logged.
    print(f"depth_edge patch skipped: {e}")
# --- Pre-download checkpoints ---
print("Downloading SAM3D checkpoints...")
# snapshot_download returns the local directory holding the downloaded snapshot.
CKPT_DIR = snapshot_download(repo_id="facebook/sam-3d-objects",
                             token=os.environ.get("HF_TOKEN"))
hf_ckpt = Path(CKPT_DIR) / "checkpoints"
local_ckpt = SAM3D_PATH / "checkpoints" / "hf"
# Path.exists() follows symlinks, so a dangling link left by an earlier snapshot
# reports False and symlink_to() would then fail with FileExistsError. Remove the
# stale link first so it can be recreated against the current snapshot.
if local_ckpt.is_symlink() and not local_ckpt.exists():
    local_ckpt.unlink()
# Expose the downloaded checkpoints inside the cloned repo as checkpoints/hf.
if hf_ckpt.exists() and not local_ckpt.exists():
    local_ckpt.parent.mkdir(parents=True, exist_ok=True)
    local_ckpt.symlink_to(hf_ckpt)
# Pipeline config path used by the endpoints; its existence is only logged here.
CONFIG_PATH = str(local_ckpt / "pipeline.yaml")
print(f"Config exists: {Path(CONFIG_PATH).exists()}")
print("=== Startup complete ===")
| # --- Endpoints --- | |
def diagnose():
    """Build a plain-text report on torch/CUDA and the importability of key modules.

    Returns:
        A newline-joined string with one entry per check: torch version, CUDA
        availability (plus GPU name when available), each third-party module,
        the SAM2 mask generator, the SAM3D ``Inference`` class, and whether the
        pipeline config file exists.
    """
    import torch

    cuda_available = torch.cuda.is_available()
    report = [f"torch={torch.__version__}", f"cuda={cuda_available}"]
    if torch.cuda.is_available():
        report.append(f"gpu={torch.cuda.get_device_name()}")

    # Import every module by name; failures are recorded rather than raised.
    for module_name in ("kaolin", "utils3d", "iopath", "pytorch3d", "open3d", "gsplat", "moge"):
        try:
            version = getattr(__import__(module_name), '__version__', '-')
            entry = f"{module_name}: OK ({version})"
        except Exception as exc:
            entry = f"{module_name}: FAIL - {exc}"
        report.append(entry)

    try:
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        entry = "sam2: OK"
    except Exception as exc:
        entry = f"sam2: FAIL - {exc}"
    report.append(entry)

    try:
        from inference import Inference
        entry = "SAM3D Inference: importable"
    except Exception as exc:
        entry = f"SAM3D Inference: FAIL - {exc}"
    report.append(entry)

    report.append(f"config: {Path(CONFIG_PATH).exists()}")
    return "\n".join(report)
# Keys under which a SAM3D result dict may carry its gaussian-splat object.
_GAUSSIAN_KEYS = ("gs", "gaussian", "gaussians", "scene")


def _overlay_mask(image_np, mask):
    """Return a copy of ``image_np`` with the masked pixels blended 50/50 with green."""
    preview = image_np.copy()
    preview[mask] = (preview[mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
    return preview


def _find_gaussians(result):
    """Return the gaussian-like object held by ``result``, or None when absent.

    ``result`` itself is returned when it has ``save_ply``. Otherwise, for a
    dict, the first non-None value among ``_GAUSSIAN_KEYS`` is used (its first
    element when that value is a list or tuple).
    """
    if hasattr(result, "save_ply"):
        return result
    if isinstance(result, dict):
        for key in _GAUSSIAN_KEYS:
            value = result.get(key)
            if value is not None:
                return value[0] if isinstance(value, (list, tuple)) else value
    return None


def _export_glb(result, out_dir):
    """Write ``result`` to ``<out_dir>/object.glb``; return the path or None.

    Gaussians are turned into a point cloud (through a temporary PLY when
    ``save_ply`` exists, else from ``_xyz``) and meshed with Poisson surface
    reconstruction. Failing that, a ``result["mesh"]`` with ``export`` is written
    directly. None means nothing exportable was found.

    Raises:
        RuntimeError: when Open3D reports that the GLB could not be written.
    """
    glb = f"{out_dir}/object.glb"
    gs = _find_gaussians(result)
    if gs is not None and (hasattr(gs, "save_ply") or hasattr(gs, "_xyz")):
        import open3d as o3d  # imported lazily: not needed on the mesh-only path
        if hasattr(gs, "save_ply"):
            ply = f"{out_dir}/temp.ply"
            gs.save_ply(ply)
            pcd = o3d.io.read_point_cloud(ply)
        else:
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(gs._xyz.detach().cpu().numpy())
        pcd.estimate_normals()
        mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=8)
        # write_triangle_mesh signals failure through its return value only.
        if not o3d.io.write_triangle_mesh(glb, mesh):
            raise RuntimeError(f"open3d could not write {glb}")
        return glb
    mesh = result.get("mesh") if isinstance(result, dict) else None
    if hasattr(mesh, "export"):
        mesh.export(glb)
        return glb
    # Includes a "mesh" entry without .export(): report it instead of returning
    # a path to a file that was never written.
    return None


def _result_to_glb(result):
    """Export ``result`` into a fresh temp directory.

    Returns:
        ``(glb_path, None)`` on success, or ``(None, message)`` describing why no
        GLB could be produced.
    """
    if result is None:
        return None, "Reconstruction returned None"
    # The directory is intentionally not removed: the returned file must outlive
    # this call so the UI can serve it.
    glb = _export_glb(result, tempfile.mkdtemp())
    if glb is None:
        keys = list(result.keys()) if isinstance(result, dict) else dir(result)
        return None, f"Cannot extract 3D. Keys: {keys}"
    return glb, None


def _count_faces(glb_path):
    """Return the face count of the mesh at ``glb_path``, or 0 when it cannot be read."""
    try:
        import trimesh
        return len(trimesh.load(glb_path, force="mesh").faces)
    except Exception as exc:
        # The count is informational only; never fail a finished reconstruction.
        print(f" Face count unavailable: {exc}")
        return 0


def reconstruct_objects(image: np.ndarray):
    """Segment the largest object with SAM2 and reconstruct it in 3D with SAM3D.

    Args:
        image: Input image as a numpy array (anything else is converted with
            ``np.array``), or None.

    Returns:
        ``(glb_path, preview, status)``. ``glb_path`` is None on any failure.
        ``preview`` is the input with the chosen mask tinted green, the plain
        image when no mask was found, or None when there was no input or an
        exception occurred. ``status`` is a human-readable message, including the
        traceback tail on error.

    NOTE(review): this function uses CUDA directly, but no ``spaces.GPU``
    decorator is visible in this file — confirm how the GPU is attached on ZeroGPU.
    """
    if image is None:
        return None, None, "No image"
    try:
        import time
        import torch
        t0 = time.time()
        print(f"GPU: {torch.cuda.get_device_name()}")
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        print(f" Loading SAM2... (VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        sam2_gen = SAM2AutomaticMaskGenerator.from_pretrained("facebook/sam2-hiera-small")
        print(f" SAM2 loaded ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        image_np = np.array(image) if not isinstance(image, np.ndarray) else image
        masks = sam2_gen.generate(image_np)
        if not masks:
            return None, image_np, "No objects detected"
        # Reconstruct the largest-area mask (first one wins on ties).
        best_mask = max(masks, key=lambda m: m["area"])["segmentation"]
        preview = _overlay_mask(image_np, best_mask)
        print(f" {len(masks)} masks ({time.time()-t0:.0f}s)")
        # Free SAM2 to save VRAM for SAM3D
        del sam2_gen
        torch.cuda.empty_cache()
        print(f" SAM2 freed (VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        from inference import Inference
        print(f" Loading SAM3D... (VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        sam3d = Inference(CONFIG_PATH, compile=False)
        print(f" SAM3D loaded ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        print(f" Running reconstruction... (VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        result = sam3d(image=image_np, mask=best_mask, seed=42)
        print(f" Reconstructed ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        glb, problem = _result_to_glb(result)
        if glb is None:
            return None, preview, problem
        n = _count_faces(glb)
        elapsed = int(time.time() - t0)
        return glb, preview, f"OK: {len(masks)} objects, {n:,} faces ({elapsed}s)"
    except Exception:
        # Endpoint boundary: log the full traceback and show its tail in the UI.
        import traceback
        tb = traceback.format_exc()
        print(tb)
        return None, None, f"Error:\n{tb[-1500:]}"


def test_sam3d_only(image: np.ndarray):
    """Test SAM3D reconstruction with center-crop mask (no SAM2).

    The mask is the central rectangle spanning 20%..80% of each image dimension.

    Args:
        image: Input image as a numpy array (anything else is converted with
            ``np.array``), or None.

    Returns:
        ``(glb_path, preview, status)`` with the same conventions as
        ``reconstruct_objects``.

    NOTE(review): this function uses CUDA directly, but no ``spaces.GPU``
    decorator is visible in this file — confirm how the GPU is attached on ZeroGPU.
    """
    if image is None:
        return None, None, "No image"
    try:
        import time
        import torch
        t0 = time.time()
        print(f"GPU: {torch.cuda.get_device_name()}, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")
        image_np = np.array(image) if not isinstance(image, np.ndarray) else image
        h, w = image_np.shape[:2]
        # Create a center mask (middle 60% of image)
        mask = np.zeros((h, w), dtype=bool)
        y1, y2 = int(h * 0.2), int(h * 0.8)
        x1, x2 = int(w * 0.2), int(w * 0.8)
        mask[y1:y2, x1:x2] = True
        preview = _overlay_mask(image_np, mask)
        print(f" Mask created: {mask.sum()} pixels ({time.time()-t0:.0f}s)")
        from inference import Inference
        print(f" Loading SAM3D... VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")
        sam3d = Inference(CONFIG_PATH, compile=False)
        print(f" SAM3D loaded ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        print(" Running reconstruction...")
        result = sam3d(image=image_np, mask=mask, seed=42)
        print(f" Done ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        glb, problem = _result_to_glb(result)
        if glb is None:
            return None, preview, problem
        n = _count_faces(glb)
        elapsed = int(time.time() - t0)
        return glb, preview, f"OK: {n:,} faces ({elapsed}s)"
    except Exception:
        # Endpoint boundary: log the full traceback and show its tail in the UI.
        import traceback
        tb = traceback.format_exc()
        print(tb)
        return None, None, f"Error:\n{tb[-1500:]}"


# The name matches pytest's ``test_*`` pattern, but this is a UI endpoint; keep
# test runners from collecting it.
test_sam3d_only.__test__ = False
| # --- UI --- | |
# NOTE(review): container nesting was reconstructed from a listing that had lost
# its indentation — confirm the Row/Column layout against the deployed app.
with gr.Blocks(title="SAM 3D Objects") as demo:
    gr.Markdown("# SAM 3D Objects\nImage → 3D (GLB). SAM2 detection + SAM3D reconstruction.")
    # Tab 1: full pipeline (SAM2 mask selection, then SAM3D reconstruction).
    with gr.Tab("Reconstruct"):
        with gr.Row():
            with gr.Column():
                inp = gr.Image(label="Input", type="numpy")
                btn = gr.Button("Reconstruct", variant="primary", size="lg")
            with gr.Column():
                prev = gr.Image(label="Detection", type="numpy", interactive=False)
                stat = gr.Textbox(label="Status")
        with gr.Row():
            m3d = gr.Model3D(label="3D Preview")
            dl = gr.File(label="Download GLB")
        # reconstruct_objects returns (glb path, preview image, status text).
        btn.click(reconstruct_objects, inputs=[inp], outputs=[m3d, prev, stat])
        # Mirror the 3D viewer's current value into the download component.
        m3d.change(lambda x: x, inputs=[m3d], outputs=[dl])
    # Tab 2: SAM3D alone, driven by a fixed centre mask instead of SAM2.
    with gr.Tab("Test SAM3D Only"):
        with gr.Row():
            with gr.Column():
                tinp = gr.Image(label="Input", type="numpy")
                tbtn = gr.Button("Test SAM3D (no SAM2)", variant="primary")
            with gr.Column():
                tprev = gr.Image(label="Mask Preview", type="numpy", interactive=False)
                tstat = gr.Textbox(label="Status")
        with gr.Row():
            tm3d = gr.Model3D(label="3D Preview")
        tbtn.click(test_sam3d_only, inputs=[tinp], outputs=[tm3d, tprev, tstat])
    # Tab 3: environment report produced by diagnose() (takes no inputs).
    with gr.Tab("Diagnose"):
        dbtn = gr.Button("Diagnose GPU & Modules")
        dout = gr.Textbox(lines=15)
        dbtn.click(diagnose, outputs=[dout])
# Runs at import time: starts the app with mcp_server=True passed to launch().
demo.launch(mcp_server=True)