Spaces:
Running
Running
Update binary_segmentation.py
Browse files- binary_segmentation.py +11 -1
binary_segmentation.py
CHANGED
|
@@ -448,6 +448,11 @@ class BinarySegmenter:
|
|
| 448 |
self.model = None
|
| 449 |
self.transform = None
|
| 450 |
self._load_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
|
| 452 |
def _load_model(self):
|
| 453 |
"""Load the specified segmentation model"""
|
|
@@ -497,7 +502,9 @@ class BinarySegmenter:
|
|
| 497 |
self.model = AutoModelForImageSegmentation.from_pretrained(
|
| 498 |
'ZhengPeng7/BiRefNet',
|
| 499 |
trust_remote_code=True,
|
| 500 |
-
cache_dir=str(self.cache_dir)
|
|
|
|
|
|
|
| 501 |
)
|
| 502 |
|
| 503 |
self.transform = transforms.Compose([
|
|
@@ -559,6 +566,9 @@ class BinarySegmenter:
|
|
| 559 |
|
| 560 |
# Transform
|
| 561 |
input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)
|
|
|
|
|
|
|
|
|
|
| 562 |
|
| 563 |
# Inference
|
| 564 |
with torch.no_grad():
|
|
|
|
| 448 |
self.model = None
|
| 449 |
self.transform = None
|
| 450 |
self._load_model()
|
| 451 |
+
|
| 452 |
+
if DEVICE == "cpu":
|
| 453 |
+
self.model = self.model.float()
|
| 454 |
+
self.model.to(DEVICE)
|
| 455 |
+
self.model.eval()
|
| 456 |
|
| 457 |
def _load_model(self):
|
| 458 |
"""Load the specified segmentation model"""
|
|
|
|
| 502 |
self.model = AutoModelForImageSegmentation.from_pretrained(
|
| 503 |
'ZhengPeng7/BiRefNet',
|
| 504 |
trust_remote_code=True,
|
| 505 |
+
cache_dir=str(self.cache_dir),
|
| 506 |
+
torch_dtype=torch.float32,
|
| 507 |
+
low_cpu_mem_usage=False
|
| 508 |
)
|
| 509 |
|
| 510 |
self.transform = transforms.Compose([
|
|
|
|
| 566 |
|
| 567 |
# Transform
|
| 568 |
input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)
|
| 569 |
+
if DEVICE == "cpu":
|
| 570 |
+
input_tensor = input_tensor.float()
|
| 571 |
+
|
| 572 |
|
| 573 |
# Inference
|
| 574 |
with torch.no_grad():
|