Stefano Demartis Claude Opus 4.6 commited on
Commit
961bb51
·
1 Parent(s): 4896667

fix: robust spatial-correlation-sampler install with CUDA_HOME detection and fallback

Browse files

- Auto-detect CUDA_HOME from common paths (/usr/local/cuda, etc.)
- Add CUDA bin dir to PATH if nvcc not found
- Try git source first, fall back to PyPI
- Detailed debug logging for build failures

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +45 -8
app.py CHANGED
@@ -235,25 +235,62 @@ def build_args_list_for_test(d16_batch_path: str,
235
  # ===== ZeroGPU: install spatial-correlation-sampler at runtime =====
236
 
237
  def ensure_correlation_extension_installed():
238
- """Install spatial_correlation_sampler at runtime (needs CUDA)."""
239
  try:
240
  import spatial_correlation_sampler
241
  return
242
  except ImportError:
243
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  print("[INFO] Installing spatial-correlation-sampler from git source...")
245
  result = subprocess.run(
246
  [sys.executable, "-m", "pip", "install", "--no-build-isolation",
247
- "git+https://github.com/ClementPinard/Pytorch-Correlation-extension.git"],
248
- capture_output=True, text=True
 
249
  )
250
- print(f"[DEBUG] pip stdout: {result.stdout[-2000:]}")
251
- print(f"[DEBUG] pip stderr: {result.stderr[-2000:]}")
252
- print(f"[DEBUG] pip returncode: {result.returncode}")
253
  if result.returncode != 0:
254
- raise RuntimeError(f"Failed to install spatial-correlation-sampler:{result.stderr[-500:]}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  importlib.invalidate_caches()
256
- import spatial_correlation_sampler
257
  print("[INFO] spatial-correlation-sampler installed successfully.")
258
 
259
  # ----------------- GRADIO HANDLER -----------------
 
235
  # ===== ZeroGPU: install spatial-correlation-sampler at runtime =====
236
 
237
  def ensure_correlation_extension_installed():
238
+ """Install spatial_correlation_sampler at runtime inside @spaces.GPU where CUDA is live."""
239
  try:
240
  import spatial_correlation_sampler
241
  return
242
  except ImportError:
243
  pass
244
+
245
+ import torch
246
+ cuda_home = os.environ.get("CUDA_HOME", "")
247
+ nvcc_ok = shutil.which("nvcc") is not None
248
+ print(f"[DEBUG] torch.cuda.is_available={torch.cuda.is_available()}, "
249
+ f"CUDA_HOME={cuda_home!r}, nvcc_in_PATH={nvcc_ok}")
250
+
251
+ # Ensure CUDA_HOME is set so setuptools can find nvcc
252
+ if not cuda_home:
253
+ for candidate in ["/usr/local/cuda", "/usr/lib/cuda", "/opt/cuda"]:
254
+ if path.isdir(candidate):
255
+ os.environ["CUDA_HOME"] = candidate
256
+ print(f"[INFO] Set CUDA_HOME={candidate}")
257
+ break
258
+
259
+ # Also make sure nvcc is in PATH
260
+ cuda_home = os.environ.get("CUDA_HOME", "")
261
+ if cuda_home and not nvcc_ok:
262
+ cuda_bin = path.join(cuda_home, "bin")
263
+ if path.isdir(cuda_bin):
264
+ os.environ["PATH"] = cuda_bin + ":" + os.environ.get("PATH", "")
265
+ print(f"[INFO] Added {cuda_bin} to PATH")
266
+
267
  print("[INFO] Installing spatial-correlation-sampler from git source...")
268
  result = subprocess.run(
269
  [sys.executable, "-m", "pip", "install", "--no-build-isolation",
270
+ "git+https://github.com/ClementPinard/Pytorch-Correlation-extension.git"],
271
+ capture_output=True, text=True, timeout=120,
272
+ env={**os.environ},
273
  )
 
 
 
274
  if result.returncode != 0:
275
+ print(f"[WARN] pip install from git failed (rc={result.returncode})")
276
+ print(f"[WARN] stderr tail: {result.stderr[-1500:]}")
277
+ # Fallback: try PyPI package
278
+ print("[INFO] Trying pip install spatial-correlation-sampler from PyPI...")
279
+ result2 = subprocess.run(
280
+ [sys.executable, "-m", "pip", "install", "spatial-correlation-sampler"],
281
+ capture_output=True, text=True, timeout=120,
282
+ env={**os.environ},
283
+ )
284
+ if result2.returncode != 0:
285
+ print(f"[ERROR] PyPI install also failed: {result2.stderr[-500:]}")
286
+ raise RuntimeError(
287
+ f"Cannot install spatial-correlation-sampler.\n"
288
+ f"Git error: {result.stderr[-300:]}\n"
289
+ f"PyPI error: {result2.stderr[-300:]}"
290
+ )
291
+
292
  importlib.invalidate_caches()
293
+ import spatial_correlation_sampler # noqa: F811
294
  print("[INFO] spatial-correlation-sampler installed successfully.")
295
 
296
  # ----------------- GRADIO HANDLER -----------------