Daankular commited on
Commit
51d578c
·
verified ·
1 Parent(s): 3d6d1b8

Upload external/MV-Adapter/mvadapter/utils/mesh_utils/cv_ops.py with huggingface_hub

Browse files
external/MV-Adapter/mvadapter/utils/mesh_utils/cv_ops.py CHANGED
@@ -1,9 +1,12 @@
1
  from typing import Optional
2
 
3
- import cv2
4
- import numpy as np
5
  import torch
6
 
 
 
 
 
7
 
8
  def inpaint_cvc(
9
  image: torch.Tensor,
@@ -14,11 +17,18 @@ def inpaint_cvc(
14
  input_dtype = image.dtype
15
  input_device = image.device
16
 
17
- img_np = (image.detach().float().cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
18
- msk_np = (mask.detach().float().cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
19
 
20
- result = cv2.inpaint(img_np, msk_np, padding_size, cv2.INPAINT_TELEA)
21
- output = torch.from_numpy(result).to(device=input_device)
 
 
 
 
 
 
 
22
 
23
  if return_dtype == torch.uint8 or input_dtype == torch.uint8:
24
  return output
@@ -31,11 +41,14 @@ def batch_inpaint_cvc(
31
  padding_size: int,
32
  return_dtype: Optional[torch.dtype] = None,
33
  ):
34
- return torch.stack(
35
- [inpaint_cvc(image, mask, padding_size, return_dtype)
36
- for image, mask in zip(images, masks)],
37
- dim=0,
 
 
38
  )
 
39
 
40
 
41
  def batch_erode(
@@ -43,19 +56,20 @@ def batch_erode(
43
  ):
44
  input_dtype = masks.dtype
45
  input_device = masks.device
46
-
47
- msk_np = (masks.detach().float().cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
48
- kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
49
-
50
- if msk_np.ndim == 3:
51
- eroded = np.stack([cv2.erode(m, kernel) for m in msk_np])
52
- else:
53
- eroded = cv2.erode(msk_np, kernel)
54
-
55
- output = torch.from_numpy(eroded).to(device=input_device)
 
56
  if return_dtype == torch.uint8 or input_dtype == torch.uint8:
57
- return output
58
- return (output > 0).to(dtype=input_dtype)
59
 
60
 
61
  def batch_dilate(
@@ -63,16 +77,17 @@ def batch_dilate(
63
  ):
64
  input_dtype = masks.dtype
65
  input_device = masks.device
66
-
67
- msk_np = (masks.detach().float().cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
68
- kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
69
-
70
- if msk_np.ndim == 3:
71
- dilated = np.stack([cv2.dilate(m, kernel) for m in msk_np])
72
- else:
73
- dilated = cv2.dilate(msk_np, kernel)
74
-
75
- output = torch.from_numpy(dilated).to(device=input_device)
 
76
  if return_dtype == torch.uint8 or input_dtype == torch.uint8:
77
- return output
78
- return (output > 0).to(dtype=input_dtype)
 
1
  from typing import Optional
2
 
3
+ import cvcuda
 
4
  import torch
5
 
6
+ torch_to_cvc = lambda x, layout: cvcuda.as_tensor(x, layout)
7
+
8
+ cvc_to_torch = lambda x, device: torch.tensor(x.cuda(), device=device)
9
+
10
 
11
  def inpaint_cvc(
12
  image: torch.Tensor,
 
17
  input_dtype = image.dtype
18
  input_device = image.device
19
 
20
+ image = image.detach()
21
+ mask = mask.detach()
22
 
23
+ if image.dtype != torch.uint8:
24
+ image = (image * 255).to(torch.uint8)
25
+ if mask.dtype != torch.uint8:
26
+ mask = (mask * 255).to(torch.uint8)
27
+
28
+ image_cvc = torch_to_cvc(image, "HWC")
29
+ mask_cvc = torch_to_cvc(mask, "HW")
30
+ output_cvc = cvcuda.inpaint(image_cvc, mask_cvc, padding_size)
31
+ output = cvc_to_torch(output_cvc, device=input_device)
32
 
33
  if return_dtype == torch.uint8 or input_dtype == torch.uint8:
34
  return output
 
41
  padding_size: int,
42
  return_dtype: Optional[torch.dtype] = None,
43
  ):
44
+ output = torch.stack(
45
+ [
46
+ inpaint_cvc(image, mask, padding_size, return_dtype)
47
+ for (image, mask) in zip(images, masks)
48
+ ],
49
+ axis=0,
50
  )
51
+ return output
52
 
53
 
54
  def batch_erode(
 
56
  ):
57
  input_dtype = masks.dtype
58
  input_device = masks.device
59
+ masks = masks.detach()
60
+ if masks.dtype != torch.uint8:
61
+ masks = (masks.float() * 255).to(torch.uint8)
62
+ masks_cvc = torch_to_cvc(masks[..., None], "NHWC")
63
+ masks_erode_cvc = cvcuda.morphology(
64
+ masks_cvc,
65
+ cvcuda.MorphologyType.ERODE,
66
+ maskSize=(kernel_size, kernel_size),
67
+ anchor=(-1, -1),
68
+ )
69
+ masks_erode = cvc_to_torch(masks_erode_cvc, device=input_device)[..., 0]
70
  if return_dtype == torch.uint8 or input_dtype == torch.uint8:
71
+ return masks_erode
72
+ return (masks_erode > 0).to(dtype=input_dtype)
73
 
74
 
75
  def batch_dilate(
 
77
  ):
78
  input_dtype = masks.dtype
79
  input_device = masks.device
80
+ masks = masks.detach()
81
+ if masks.dtype != torch.uint8:
82
+ masks = (masks.float() * 255).to(torch.uint8)
83
+ masks_cvc = torch_to_cvc(masks[..., None], "NHWC")
84
+ masks_dilate_cvc = cvcuda.morphology(
85
+ masks_cvc,
86
+ cvcuda.MorphologyType.DILATE,
87
+ maskSize=(kernel_size, kernel_size),
88
+ anchor=(-1, -1),
89
+ )
90
+ masks_dilate = cvc_to_torch(masks_dilate_cvc, device=input_device)[..., 0]
91
  if return_dtype == torch.uint8 or input_dtype == torch.uint8:
92
+ return masks_dilate
93
+ return (masks_dilate > 0).to(dtype=input_dtype)