Stefano Demartis Claude Opus 4.6 commited on
Commit
89f99fe
·
1 Parent(s): e382bab

fix: disable spatial_correlation_sampler, use pure-PyTorch fallback

Browse files

The C++ extension cannot compile on ZeroGPU (PyTorch 2.10 / CUDA 12.8+
dispatch macro incompatibility + CUDA version mismatch).

model/attention.py already has a pure-PyTorch path (pad_and_unfold)
when enable_corr=False. Set all 3 defaults to False and remove
the install machinery from app.py.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. app.py +3 -71
  2. model/attention.py +3 -3
app.py CHANGED
@@ -232,70 +232,9 @@ def build_args_list_for_test(d16_batch_path: str,
232
  args.extend([flag, str(v)])
233
  return args
234
 
235
- # ===== ZeroGPU: install spatial-correlation-sampler at runtime =====
236
-
237
- def ensure_correlation_extension_installed():
238
- """Install spatial_correlation_sampler at runtime inside @spaces.GPU where CUDA is live."""
239
- try:
240
- import spatial_correlation_sampler
241
- return
242
- except ImportError:
243
- pass
244
-
245
- import torch
246
- cuda_home = os.environ.get("CUDA_HOME", "")
247
- nvcc_ok = shutil.which("nvcc") is not None
248
- diag = (f"torch={torch.__version__}, cuda={torch.version.cuda}, "
249
- f"torch.cuda.avail={torch.cuda.is_available()}, "
250
- f"CUDA_HOME={cuda_home!r}, nvcc={nvcc_ok}")
251
- print(f"[DEBUG] {diag}")
252
- # Store diagnostics for error reporting
253
- _diag_info = diag
254
-
255
- # Ensure CUDA_HOME is set so setuptools can find nvcc
256
- if not cuda_home:
257
- for candidate in ["/usr/local/cuda", "/usr/lib/cuda", "/opt/cuda"]:
258
- if path.isdir(candidate):
259
- os.environ["CUDA_HOME"] = candidate
260
- print(f"[INFO] Set CUDA_HOME={candidate}")
261
- break
262
-
263
- # Also make sure nvcc is in PATH
264
- cuda_home = os.environ.get("CUDA_HOME", "")
265
- if cuda_home and not nvcc_ok:
266
- cuda_bin = path.join(cuda_home, "bin")
267
- if path.isdir(cuda_bin):
268
- os.environ["PATH"] = cuda_bin + ":" + os.environ.get("PATH", "")
269
- print(f"[INFO] Added {cuda_bin} to PATH")
270
-
271
- print("[INFO] Installing spatial-correlation-sampler from git source...")
272
- result = subprocess.run(
273
- [sys.executable, "-m", "pip", "install", "--no-build-isolation",
274
- "git+https://github.com/ClementPinard/Pytorch-Correlation-extension.git"],
275
- capture_output=True, text=True, timeout=120,
276
- env={**os.environ},
277
- )
278
- if result.returncode != 0:
279
- print(f"[WARN] pip install from git failed (rc={result.returncode})")
280
- print(f"[WARN] stderr tail: {result.stderr[-1500:]}")
281
- # Fallback: try PyPI package
282
- print("[INFO] Trying pip install spatial-correlation-sampler from PyPI...")
283
- result2 = subprocess.run(
284
- [sys.executable, "-m", "pip", "install", "spatial-correlation-sampler"],
285
- capture_output=True, text=True, timeout=120,
286
- env={**os.environ},
287
- )
288
- if result2.returncode != 0:
289
- raise RuntimeError(
290
- f"Cannot install spatial-correlation-sampler.\n"
291
- f"Diag: {_diag_info}\n"
292
- f"Git stderr:\n{result.stderr[-2000:]}\n"
293
- f"PyPI stderr:\n{result2.stderr[-2000:]}"
294
- )
295
-
296
- importlib.invalidate_caches()
297
- import spatial_correlation_sampler # noqa: F811
298
- print("[INFO] spatial-correlation-sampler installed successfully.")
299
 
300
  # ----------------- GRADIO HANDLER -----------------
301
 
@@ -308,13 +247,6 @@ def gradio_infer(
308
  num_proto, top_k, mem_every, deep_update,
309
  save_scores, flip, size, reverse
310
  ):
311
- # <<< ZeroGPU 分配后:按需安装 Pytorch-Correlation-extension >>>
312
- try:
313
- ensure_correlation_extension_installed()
314
- except Exception as install_err:
315
- return None, f"spatial-correlation-sampler install failed:\n{install_err}"
316
- # --------------------------------------------------------------
317
-
318
  # 1) 基本校验与临时目录
319
  if bw_video is None:
320
  return None, "请上传黑白视频 / Please upload a B&W video."
 
232
  args.extend([flag, str(v)])
233
  return args
234
 
235
+ # ===== spatial_correlation_sampler: DISABLED =====
236
+ # The C++ extension is incompatible with PyTorch 2.10 / CUDA 12.8+ on ZeroGPU.
237
+ # model/attention.py now uses enable_corr=False (pure-PyTorch fallback).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  # ----------------- GRADIO HANDLER -----------------
240
 
 
247
  num_proto, top_k, mem_every, deep_update,
248
  save_scores, flip, size, reverse
249
  ):
 
 
 
 
 
 
 
250
  # 1) 基本校验与临时目录
251
  if bw_video is None:
252
  return None, "请上传黑白视频 / Please upload a B&W video."
model/attention.py CHANGED
@@ -130,7 +130,7 @@ class MultiheadLocalAttentionV1(nn.Module):
130
  max_dis=7,
131
  dilation=1,
132
  use_linear=True,
133
- enable_corr=True):
134
  super().__init__()
135
  self.dilation = dilation
136
  self.window_size = 2 * max_dis + 1
@@ -248,7 +248,7 @@ class MultiheadLocalAttentionV2(nn.Module):
248
  max_dis=7,
249
  dilation=1,
250
  use_linear=True,
251
- enable_corr=True,
252
  d_att=None,
253
  use_dis=False):
254
  super().__init__()
@@ -721,7 +721,7 @@ class LocalGatedPropagation(nn.Module):
721
  max_dis=7,
722
  dilation=1,
723
  use_linear=True,
724
- enable_corr=True,
725
  d_att=None,
726
  use_dis=False,
727
  expand_ratio=2.):
 
130
  max_dis=7,
131
  dilation=1,
132
  use_linear=True,
133
+ enable_corr=False):
134
  super().__init__()
135
  self.dilation = dilation
136
  self.window_size = 2 * max_dis + 1
 
248
  max_dis=7,
249
  dilation=1,
250
  use_linear=True,
251
+ enable_corr=False,
252
  d_att=None,
253
  use_dis=False):
254
  super().__init__()
 
721
  max_dis=7,
722
  dilation=1,
723
  use_linear=True,
724
+ enable_corr=False,
725
  d_att=None,
726
  use_dis=False,
727
  expand_ratio=2.):