Upload external/MV-Adapter/mvadapter/utils/mesh_utils/cv_ops.py with huggingface_hub
Browse files
external/MV-Adapter/mvadapter/utils/mesh_utils/cv_ops.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
| 1 |
from typing import Optional
|
| 2 |
|
| 3 |
-
import
|
| 4 |
-
import numpy as np
|
| 5 |
import torch
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
def inpaint_cvc(
|
| 9 |
image: torch.Tensor,
|
|
@@ -14,11 +17,18 @@ def inpaint_cvc(
|
|
| 14 |
input_dtype = image.dtype
|
| 15 |
input_device = image.device
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
if return_dtype == torch.uint8 or input_dtype == torch.uint8:
|
| 24 |
return output
|
|
@@ -31,11 +41,14 @@ def batch_inpaint_cvc(
|
|
| 31 |
padding_size: int,
|
| 32 |
return_dtype: Optional[torch.dtype] = None,
|
| 33 |
):
|
| 34 |
-
|
| 35 |
-
[
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
| 38 |
)
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
def batch_erode(
|
|
@@ -43,19 +56,20 @@ def batch_erode(
|
|
| 43 |
):
|
| 44 |
input_dtype = masks.dtype
|
| 45 |
input_device = masks.device
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
| 56 |
if return_dtype == torch.uint8 or input_dtype == torch.uint8:
|
| 57 |
-
return
|
| 58 |
-
return (
|
| 59 |
|
| 60 |
|
| 61 |
def batch_dilate(
|
|
@@ -63,16 +77,17 @@ def batch_dilate(
|
|
| 63 |
):
|
| 64 |
input_dtype = masks.dtype
|
| 65 |
input_device = masks.device
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
| 76 |
if return_dtype == torch.uint8 or input_dtype == torch.uint8:
|
| 77 |
-
return
|
| 78 |
-
return (
|
|
|
|
| 1 |
from typing import Optional
|
| 2 |
|
| 3 |
+
import cvcuda
|
|
|
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
torch_to_cvc = lambda x, layout: cvcuda.as_tensor(x, layout)
|
| 7 |
+
|
| 8 |
+
cvc_to_torch = lambda x, device: torch.tensor(x.cuda(), device=device)
|
| 9 |
+
|
| 10 |
|
| 11 |
def inpaint_cvc(
|
| 12 |
image: torch.Tensor,
|
|
|
|
| 17 |
input_dtype = image.dtype
|
| 18 |
input_device = image.device
|
| 19 |
|
| 20 |
+
image = image.detach()
|
| 21 |
+
mask = mask.detach()
|
| 22 |
|
| 23 |
+
if image.dtype != torch.uint8:
|
| 24 |
+
image = (image * 255).to(torch.uint8)
|
| 25 |
+
if mask.dtype != torch.uint8:
|
| 26 |
+
mask = (mask * 255).to(torch.uint8)
|
| 27 |
+
|
| 28 |
+
image_cvc = torch_to_cvc(image, "HWC")
|
| 29 |
+
mask_cvc = torch_to_cvc(mask, "HW")
|
| 30 |
+
output_cvc = cvcuda.inpaint(image_cvc, mask_cvc, padding_size)
|
| 31 |
+
output = cvc_to_torch(output_cvc, device=input_device)
|
| 32 |
|
| 33 |
if return_dtype == torch.uint8 or input_dtype == torch.uint8:
|
| 34 |
return output
|
|
|
|
| 41 |
padding_size: int,
|
| 42 |
return_dtype: Optional[torch.dtype] = None,
|
| 43 |
):
|
| 44 |
+
output = torch.stack(
|
| 45 |
+
[
|
| 46 |
+
inpaint_cvc(image, mask, padding_size, return_dtype)
|
| 47 |
+
for (image, mask) in zip(images, masks)
|
| 48 |
+
],
|
| 49 |
+
axis=0,
|
| 50 |
)
|
| 51 |
+
return output
|
| 52 |
|
| 53 |
|
| 54 |
def batch_erode(
|
|
|
|
| 56 |
):
|
| 57 |
input_dtype = masks.dtype
|
| 58 |
input_device = masks.device
|
| 59 |
+
masks = masks.detach()
|
| 60 |
+
if masks.dtype != torch.uint8:
|
| 61 |
+
masks = (masks.float() * 255).to(torch.uint8)
|
| 62 |
+
masks_cvc = torch_to_cvc(masks[..., None], "NHWC")
|
| 63 |
+
masks_erode_cvc = cvcuda.morphology(
|
| 64 |
+
masks_cvc,
|
| 65 |
+
cvcuda.MorphologyType.ERODE,
|
| 66 |
+
maskSize=(kernel_size, kernel_size),
|
| 67 |
+
anchor=(-1, -1),
|
| 68 |
+
)
|
| 69 |
+
masks_erode = cvc_to_torch(masks_erode_cvc, device=input_device)[..., 0]
|
| 70 |
if return_dtype == torch.uint8 or input_dtype == torch.uint8:
|
| 71 |
+
return masks_erode
|
| 72 |
+
return (masks_erode > 0).to(dtype=input_dtype)
|
| 73 |
|
| 74 |
|
| 75 |
def batch_dilate(
|
|
|
|
| 77 |
):
|
| 78 |
input_dtype = masks.dtype
|
| 79 |
input_device = masks.device
|
| 80 |
+
masks = masks.detach()
|
| 81 |
+
if masks.dtype != torch.uint8:
|
| 82 |
+
masks = (masks.float() * 255).to(torch.uint8)
|
| 83 |
+
masks_cvc = torch_to_cvc(masks[..., None], "NHWC")
|
| 84 |
+
masks_dilate_cvc = cvcuda.morphology(
|
| 85 |
+
masks_cvc,
|
| 86 |
+
cvcuda.MorphologyType.DILATE,
|
| 87 |
+
maskSize=(kernel_size, kernel_size),
|
| 88 |
+
anchor=(-1, -1),
|
| 89 |
+
)
|
| 90 |
+
masks_dilate = cvc_to_torch(masks_dilate_cvc, device=input_device)[..., 0]
|
| 91 |
if return_dtype == torch.uint8 or input_dtype == torch.uint8:
|
| 92 |
+
return masks_dilate
|
| 93 |
+
return (masks_dilate > 0).to(dtype=input_dtype)
|