import torch
import numpy as np
from PIL import Image
from typing import List
import utils3d
from moge.model.v2 import MoGeModel
from utils.depth_utils import PointmapInfo, align_ground_to_z, crop_and_resize_foreground
# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
def moge_preprocess(image: "Image.Image", device) -> torch.Tensor:
    """Convert a PIL image to a normalized float32 CHW tensor on `device`.

    Pixel values are scaled from [0, 255] to [0.0, 1.0] and the layout is
    changed from HWC (as numpy sees the image) to CHW (as torch models expect).
    """
    # Force a 3-channel RGB view regardless of the source mode, then scale.
    pixels = np.asarray(image.convert("RGB")) / 255.0
    # Share the numpy buffer, cast to float32, and place on the target device.
    chw_tensor = torch.from_numpy(pixels).to(dtype=torch.float32, device=device)
    return chw_tensor.permute(2, 0, 1)
# ---------------------------------------------------------------------------
# MoGe-based pointmap
# ---------------------------------------------------------------------------
class PointmapInfoMoGe(PointmapInfo):
    """
    Concrete PointmapInfo implementation backed by the MoGe monocular depth estimator.

    The MoGe model is loaded once and cached as a class-level attribute, so
    subsequent instantiations reuse the same weights.
    """
    # Shared across all instances to avoid redundant weight loading.
    # NOTE(review): the cache ignores `device` — a second instance requested on a
    # different device will reuse the model on whichever device loaded it first.
    moge_model: MoGeModel | None = None

    def __init__(self, image: Image.Image, device: str = 'cuda'):
        """
        Run MoGe on `image` and derive a normalized point cloud, pointmap,
        camera intrinsic, and camera extrinsic.

        Args:
            image: Input PIL image (converted to RGB during preprocessing).
            device: Torch device string used for preprocessing and inference.
        """
        self._input_image = moge_preprocess(image, device)
        # Run MoGe inference (no gradients needed)
        with torch.no_grad():
            if PointmapInfoMoGe.moge_model is None:
                PointmapInfoMoGe.moge_model = (
                    MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").to(device)
                )
            predictions = PointmapInfoMoGe.moge_model.infer(self._input_image)
        # Mask out depth edges to suppress discontinuity artifacts
        # (rtol=0.04: relative depth jumps larger than 4% count as edges —
        # presumably; semantics of `depth_edge` to be confirmed in utils3d).
        depth_edge_mask = utils3d.numpy.depth_edge(predictions['depth'].cpu().numpy(), rtol=0.04)
        mask = predictions['mask'] & torch.from_numpy(~depth_edge_mask).to(device)
        # Align the ground plane with the XY plane (+Z up); `R` is the rotation
        # returned by the alignment (it is combined with numpy arrays below, so
        # it is assumed to already be a numpy array — TODO confirm).
        points = predictions['points']
        masked_points, _, R = align_ground_to_z(points[mask].reshape(-1, 3), return_transform=True)
        # Move arrays to CPU/numpy for coordinate normalization
        mask = mask.cpu().numpy()
        points = points.cpu().numpy()
        masked_points = masked_points.cpu().numpy()
        self.intrinsic = predictions['intrinsics'].cpu().numpy()
        # Normalize XY to [0, 1] and Z to a height relative to scene scale
        mins = masked_points[:, :2].min(axis=0)
        maxs = masked_points[:, :2].max(axis=0)
        # Largest XY extent becomes the uniform scale factor for all 3 axes.
        scaling = (maxs - mins).max()
        height = masked_points[:, 2].max() / scaling
        # Flip Z, center XY inside the unit square, and apply uniform scale.
        # The `(1 - extent/scaling) / 2` term centers the smaller axis.
        masked_points[:, 2] *= -1
        masked_points[:, :2] = (masked_points[:, :2] - mins) / scaling + (1 - (maxs - mins) / scaling) / 2
        masked_points[:, 2] -= masked_points[:, 2].min()
        masked_points[:, 2] *= 1.0 / scaling
        # Build the camera extrinsic [R | t] from the alignment transform,
        # compensating for the normalization applied to the points above.
        R = R.T
        R[:, 2] *= -1  # account for the Z flip applied to the points
        t = R @ np.array([*(mins / scaling - (1 - (maxs - mins) / scaling) / 2), -height])
        t += R @ np.array([0.5, 0.5, 0.0])
        # Permute axes from (y, x, z) to (x, y, z) convention
        P = np.array([[0, 1, 0],
                      [1, 0, 0],
                      [0, 0, 1]])
        self.intrinsic = P @ self.intrinsic @ P.T
        R = P @ R @ P.T
        t = P @ t
        # 4x4 homogeneous extrinsic matrix [R | t; 0 0 0 1].
        self.extrinsic = np.vstack((np.hstack((R, t.reshape(-1, 1))), [0, 0, 0, 1]))
        # Store the normalized masked points as the point cloud, and write them
        # back into the full pointmap (used for patch extraction below).
        self.pc = masked_points
        points[mask] = masked_points
        self._pointmap = points

    # -----------------------------------------------------------------------
    # PointmapInfo interface
    # -----------------------------------------------------------------------
    def point_cloud(self) -> np.ndarray:
        """Return the normalized (N, 3) point cloud of unmasked pixels."""
        return self.pc

    def camera_intrinsic(self) -> np.ndarray:
        """Return the 3x3 camera intrinsic matrix (axis-permuted to x, y, z)."""
        return self.intrinsic

    def camera_extrinsic(self) -> np.ndarray:
        """Return the 4x4 homogeneous camera extrinsic matrix."""
        return self.extrinsic

    def divide_image(self, width: int, length: int, div: int) -> List[List[Image.Image]]:
        """
        Slice the image into overlapping patches based on the normalized pointmap.

        Each patch covers a (1/width) x (1/length) window in normalized
        pointmap coordinates; successive windows are offset by 1/(width*div)
        (resp. 1/(length*div)), so adjacent patches overlap whenever div > 1.

        Args:
            width: Number of non-overlapping tiles along the Y axis.
            length: Number of non-overlapping tiles along the X axis.
            div: Overlap subdivision factor (higher = more overlap).

        Returns:
            2D list of PIL images of shape
            [div*(width-1)+1][div*(length-1)+1].
        """
        # Convert the input tensor back to a uint8 HWC numpy array
        image_np = (self._input_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        patches = []
        for i in range(width * div - div + 1):
            row = []
            for j in range(length * div - div + 1):
                # Compute normalized [0, 1] bounds for this patch
                y_start = i / (width * div)
                x_start = j / (length * div)
                y_end = y_start + 1.0 / width
                x_end = x_start + 1.0 / length
                # Mask pixels whose pointmap coordinates fall within this patch
                # (channel 0 is x and channel 1 is y after the axis permutation
                # performed in __init__).
                pm = self._pointmap
                in_patch = (
                    (y_start <= pm[:, :, 1]) & (pm[:, :, 1] < y_end) &
                    (x_start <= pm[:, :, 0]) & (pm[:, :, 0] < x_end)
                )[:, :, None]
                # Black out everything outside the patch, then crop/resize the
                # remaining foreground into a standalone image.
                patch_np = np.where(in_patch, image_np, 0).astype(np.uint8)
                patch_img = crop_and_resize_foreground(Image.fromarray(patch_np))
                row.append(patch_img)
            patches.append(row)
        return patches