jboth commited on
Commit
31053b5
·
verified ·
1 Parent(s): e213e91

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +0 -181
app.py CHANGED
@@ -1,181 +0,0 @@
1
- """SAM 3D Objects – pinned torch 2.8.0 for kaolin ABI compat."""
2
- import os, sys, subprocess
3
- os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
4
- os.environ.setdefault("CONDA_PREFIX", "/usr/local")
5
- os.environ["LIDRA_SKIP_INIT"] = "true"
6
-
7
- import spaces
8
- import gradio as gr
9
- import numpy as np
10
- from PIL import Image
11
- from huggingface_hub import snapshot_download, login
12
- import tempfile, uuid
13
- from pathlib import Path
14
-
15
- if os.environ.get("HF_TOKEN"):
16
- login(token=os.environ["HF_TOKEN"])
17
-
18
- # Runtime installs for things that need source builds
19
- def _pip(*a):
20
- r = subprocess.run([sys.executable, "-m", "pip", "install", "--no-cache-dir"] + list(a),
21
- capture_output=True, text=True, timeout=1200)
22
- return r.returncode == 0
23
-
24
- _pip("utils3d")
25
- _pip("iopath")
26
- _pip("--no-deps", "pytorch3d")
27
-
28
- # gsplat
29
- for idx in ["https://docs.gsplat.studio/whl/pt28cu128",
30
- "https://docs.gsplat.studio/whl/pt27cu128",
31
- "https://docs.gsplat.studio/whl/pt26cu124"]:
32
- if _pip("--no-deps", f"--extra-index-url={idx}", "gsplat"):
33
- break
34
-
35
- _pip("--no-deps", "git+https://github.com/microsoft/MoGe.git@a8c37341bc0325ca99b9d57981cc3bb2bd3e255b")
36
-
37
- # Clone sam-3d-objects
38
- SAM3D_PATH = Path("/home/user/app/sam-3d-objects")
39
- if not SAM3D_PATH.exists():
40
- subprocess.run(["git", "clone", "--depth", "1",
41
- "https://github.com/facebookresearch/sam-3d-objects.git", str(SAM3D_PATH)], check=True)
42
- subprocess.run([sys.executable, "-m", "pip", "install", "-e", str(SAM3D_PATH), "--no-deps"],
43
- capture_output=True, text=True)
44
- patch = SAM3D_PATH / "patching" / "hydra"
45
- if patch.exists():
46
- subprocess.run(["bash", str(patch)], capture_output=True, cwd=str(SAM3D_PATH))
47
- sys.path.insert(0, str(SAM3D_PATH))
48
- sys.path.insert(0, str(SAM3D_PATH / "notebook"))
49
-
50
- # Pre-download checkpoints
51
- CKPT_DIR = snapshot_download(repo_id="facebook/sam-3d-objects", token=os.environ.get("HF_TOKEN"))
52
- hf_ckpt = Path(CKPT_DIR) / "checkpoints"
53
- local_ckpt = SAM3D_PATH / "checkpoints" / "hf"
54
- if hf_ckpt.exists() and not local_ckpt.exists():
55
- local_ckpt.parent.mkdir(parents=True, exist_ok=True)
56
- local_ckpt.symlink_to(hf_ckpt)
57
- CONFIG_PATH = str(local_ckpt / "pipeline.yaml")
58
- print(f"Config exists: {Path(CONFIG_PATH).exists()}")
59
-
60
- # Verify
61
- for mod in ["torch", "kaolin", "gsplat", "open3d", "sam2"]:
62
- try:
63
- m = __import__(mod)
64
- print(f" {mod}={getattr(m, '__version__', 'ok')}")
65
- except Exception as e:
66
- print(f" {mod}: {e}")
67
- try:
68
- import kaolin; kaolin.ops.mesh
69
- print(" kaolin C++: OK")
70
- except Exception as e:
71
- print(f" kaolin C++: {e}")
72
-
73
- print("=== Setup done ===")
74
-
75
- SAM3D_MODEL = None
76
- SAM2_GEN = None
77
-
78
- @spaces.GPU(duration=60)
79
- def diagnose():
80
- import torch
81
- lines = [f"torch={torch.__version__}", f"cuda={torch.cuda.is_available()}"]
82
- if torch.cuda.is_available():
83
- lines.append(f"gpu={torch.cuda.get_device_name()}")
84
- for mod in ["kaolin", "gsplat", "open3d", "sam2", "utils3d"]:
85
- try:
86
- m = __import__(mod)
87
- lines.append(f"{mod}={getattr(m, '__version__', 'ok')}")
88
- except Exception as e:
89
- lines.append(f"{mod}: {e}")
90
- try:
91
- import kaolin; kaolin.ops.mesh
92
- lines.append("kaolin C++: OK")
93
- except Exception as e:
94
- lines.append(f"kaolin C++: {e}")
95
- return "\n".join(lines)
96
-
97
- @spaces.GPU(duration=300)
98
- def reconstruct_objects(image: np.ndarray):
99
- global SAM3D_MODEL, SAM2_GEN
100
- if image is None:
101
- return None, None, "No image"
102
- try:
103
- import torch, trimesh, time
104
- t0 = time.time()
105
-
106
- if SAM2_GEN is None:
107
- from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
108
- SAM2_GEN = SAM2AutomaticMaskGenerator.from_pretrained("facebook/sam2-hiera-large")
109
- image_np = np.array(image) if not isinstance(image, np.ndarray) else image
110
- masks = SAM2_GEN.generate(image_np)
111
- if not masks:
112
- return None, image_np, "No objects"
113
- masks = sorted(masks, key=lambda x: x["area"], reverse=True)
114
- best_mask = masks[0]["segmentation"]
115
- preview = image_np.copy()
116
- preview[best_mask] = (preview[best_mask]*0.5 + np.array([0,255,0])*0.5).astype(np.uint8)
117
- print(f" SAM2: {len(masks)} masks ({time.time()-t0:.0f}s)")
118
-
119
- if SAM3D_MODEL is None:
120
- from inference import Inference
121
- SAM3D_MODEL = Inference(CONFIG_PATH, compile=False)
122
- print(f" SAM3D loaded ({time.time()-t0:.0f}s)")
123
-
124
- result = SAM3D_MODEL(image=image_np, mask=best_mask, seed=42)
125
- print(f" Reconstructed ({time.time()-t0:.0f}s)")
126
-
127
- if result is None:
128
- return None, preview, "Reconstruction None"
129
-
130
- od = tempfile.mkdtemp()
131
- glb = f"{od}/obj.glb"
132
- gs=None
133
- if hasattr(result,"save_ply"): gs=result
134
- elif isinstance(result,dict):
135
- for k in("gs","gaussian","gaussians"):
136
- v=result.get(k)
137
- if v: gs=v[0] if isinstance(v,(list,tuple)) else v; break
138
- if gs and hasattr(gs,"save_ply"):
139
- ply=f"{od}/t.ply"; gs.save_ply(ply)
140
- import open3d as o3d
141
- p=o3d.io.read_point_cloud(ply); p.estimate_normals()
142
- m,_=o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(p,depth=8)
143
- o3d.io.write_triangle_mesh(glb,m)
144
- elif gs and hasattr(gs,"_xyz"):
145
- import open3d as o3d
146
- p=o3d.geometry.PointCloud()
147
- p.points=o3d.utility.Vector3dVector(gs._xyz.detach().cpu().numpy())
148
- p.estimate_normals()
149
- m,_=o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(p,depth=8)
150
- o3d.io.write_triangle_mesh(glb,m)
151
- else:
152
- return None,preview,f"No 3D: {type(result)}"
153
-
154
- n=0
155
- try: n=len(trimesh.load(glb,force="mesh").faces)
156
- except: pass
157
- return glb,preview,f"OK: {n:,} faces ({int(time.time()-t0)}s)"
158
- except Exception as e:
159
- import traceback; traceback.print_exc()
160
- return None,None,f"Error: {e}"
161
-
162
- with gr.Blocks(title="SAM 3D Objects") as demo:
163
- gr.Markdown("# SAM 3D Objects\nImage -> 3D (GLB)")
164
- with gr.Tab("Reconstruct"):
165
- with gr.Row():
166
- with gr.Column():
167
- inp=gr.Image(label="Input",type="numpy")
168
- btn=gr.Button("Reconstruct",variant="primary",size="lg")
169
- with gr.Column():
170
- prev=gr.Image(label="Detection",type="numpy",interactive=False)
171
- stat=gr.Textbox(label="Status")
172
- with gr.Row():
173
- m3d=gr.Model3D(label="3D Preview")
174
- dl=gr.File(label="Download GLB")
175
- btn.click(reconstruct_objects,inputs=[inp],outputs=[m3d,prev,stat])
176
- m3d.change(lambda x:x,inputs=[m3d],outputs=[dl])
177
- with gr.Tab("Diagnose"):
178
- dbtn=gr.Button("GPU Diagnose")
179
- dout=gr.Textbox(label="Env",lines=15)
180
- dbtn.click(diagnose,outputs=[dout])
181
- demo.launch(mcp_server=True)