Steewee98 commited on
Commit
6047a0c
·
verified ·
1 Parent(s): 5225c04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -53
app.py CHANGED
@@ -233,60 +233,23 @@ def build_args_list_for_test(d16_batch_path: str,
233
  return args
234
 
235
  # ===== 新增:ZeroGPU 后按需安装 Pytorch-Correlation-extension =====
236
- _CORR_OK_STAMP = path.join(os.getcwd(), ".corr_ext_installed")
237
-
238
- def ensure_correlation_extension_installed():
239
- """
240
- 在 ZeroGPU 分配后调用(位于 @spaces.GPU 函数体内):
241
- - 若已能 import 或存在本地 stamp,则直接返回
242
- - 否则执行:
243
- git clone https://github.com/ClementPinard/Pytorch-Correlation-extension.git
244
- cd Pytorch-Correlation-extension && python setup.py install && cd ..
245
- """
246
- # 1) 尝试直接导入
247
- try:
248
- import spatial_correlation_sampler # noqa: F401
249
- return
250
- except Exception:
251
- pass
252
- # 2) 之前成功过(打过 stamp)
253
- if path.exists(_CORR_OK_STAMP):
254
- return
255
-
256
- repo_url = "https://github.com/ClementPinard/Pytorch-Correlation-extension.git"
257
- workdir = tempfile.mkdtemp(prefix="corr_ext_")
258
- repo_dir = path.join(workdir, "Pytorch-Correlation-extension")
259
- try:
260
- print("[INFO] Installing Pytorch-Correlation-extension ...")
261
- # clone
262
- subprocess.run(
263
- ["git", "clone", "--depth", "1", repo_url],
264
- cwd=workdir, check=True
265
- )
266
- # build & install
267
- subprocess.run(
268
- [sys.executable, "setup.py", "install"],
269
- cwd=repo_dir, check=True
270
- )
271
- # 验证
272
- importlib.invalidate_caches()
273
- import spatial_correlation_sampler # noqa: F401
274
- # 打 stamp,避免下次重复
275
- with open(_CORR_OK_STAMP, "w") as f:
276
- f.write("ok")
277
- print("[INFO] Pytorch-Correlation-extension installed successfully.")
278
- except subprocess.CalledProcessError as e:
279
- print(f"[WARN] Failed to build/install correlation extension: {e}")
280
- print("You can still proceed if your pipeline doesn't use it.")
281
- except Exception as e:
282
- print(f"[WARN] Correlation extension install check failed: {e}")
283
- finally:
284
- # 清理临时目录
285
- try:
286
- shutil.rmtree(workdir, ignore_errors=True)
287
- except Exception:
288
- pass
289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  # ----------------- GRADIO HANDLER -----------------
291
  @spaces.GPU(duration=120) # 确保 CUDA 初始化在此函数体内
292
  def gradio_infer(
 
233
  return args
234
 
235
  # ===== 新增:ZeroGPU 后按需安装 Pytorch-Correlation-extension =====
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
+ def ensure_correlation_extension_installed():
238
+ """Install spatial_correlation_sampler at runtime (needs CUDA)."""
239
+ try:
240
+ import spatial_correlation_sampler
241
+ return
242
+ except ImportError:
243
+ pass
244
+ print("[INFO] Installing spatial-correlation-sampler via pip (needs CUDA)...")
245
+ subprocess.run(
246
+ [sys.executable, "-m", "pip", "install", "spatial-correlation-sampler"],
247
+ check=True
248
+ )
249
+ importlib.invalidate_caches()
250
+ import spatial_correlation_sampler # verify
251
+ print("[INFO] spatial-correlation-sampler installed successfully.")
252
+
253
  # ----------------- GRADIO HANDLER -----------------
254
  @spaces.GPU(duration=120) # 确保 CUDA 初始化在此函数体内
255
  def gradio_infer(