| import torch |
| import numpy as np |
| from PIL import Image |
| from typing import List |
|
|
| import utils3d |
| from moge.model.v2 import MoGeModel |
|
|
| from utils.depth_utils import PointmapInfo, align_ground_to_z, crop_and_resize_foreground |
|
|
|
|
| |
| |
| |
|
|
def moge_preprocess(image: Image.Image, device) -> torch.Tensor:
    """
    Turn a PIL image into the float tensor layout passed to MoGe's ``infer``.

    The image is forced to RGB, scaled from [0, 255] to [0, 1] in float64,
    cast to float32, placed on ``device`` and returned channel-first
    (3, H, W).
    """
    pixels = np.asarray(image.convert("RGB"), dtype=np.float64) / 255.0
    hwc = torch.from_numpy(pixels).to(device=device, dtype=torch.float32)
    return hwc.permute(2, 0, 1)
|
|
|
|
| |
| |
| |
|
|
class PointmapInfoMoGe(PointmapInfo):
    """
    Concrete PointmapInfo implementation backed by the MoGe monocular depth estimator.

    The MoGe model is loaded once and cached as a class-level attribute, so
    subsequent instantiations reuse the same weights.
    """

    # Shared MoGe model, lazily loaded by the first instance.
    moge_model: MoGeModel | None = None

    def __init__(self, image: Image.Image, device: str = 'cuda'):
        """
        Run MoGe on ``image`` and build a normalized, ground-aligned point cloud.

        Args:
            image: Input PIL image (converted to RGB by ``moge_preprocess``).
            device: Torch device for the input tensor and the shared model.

        Raises:
            ValueError: If no pixel survives MoGe's mask and the depth-edge filter.
        """
        self._input_image = moge_preprocess(image, device)

        with torch.no_grad():
            if PointmapInfoMoGe.moge_model is None:
                PointmapInfoMoGe.moge_model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal")
            # The model is shared by all instances: move it to the device
            # requested by *this* instance, otherwise an instance created with
            # a different ``device`` would feed its input to a model living on
            # another device.
            PointmapInfoMoGe.moge_model = PointmapInfoMoGe.moge_model.to(device)
            predictions = PointmapInfoMoGe.moge_model.infer(self._input_image)

        # Keep only pixels that MoGe marks valid AND that are not flagged by
        # utils3d's depth-edge detector (relative tolerance 0.04).
        depth_edge_mask = utils3d.numpy.depth_edge(predictions['depth'].cpu().numpy(), rtol=0.04)
        valid = predictions['mask']
        mask = valid & torch.from_numpy(~depth_edge_mask).to(valid.device)
        if not bool(mask.any()):
            raise ValueError("MoGe produced no valid points for this image")

        # align_ground_to_z (utils.depth_utils) returns the transformed points
        # and, with return_transform=True, a rotation R as its third output;
        # presumably it rotates the cloud so the ground plane is normal to z.
        points = predictions['points']
        masked_points, _, R = align_ground_to_z(points[mask].reshape(-1, 3), return_transform=True)

        mask = mask.cpu().numpy()
        points = points.cpu().numpy()
        masked_points = masked_points.cpu().numpy()
        intrinsic = predictions['intrinsics'].cpu().numpy()

        # One isotropic scale: the larger of the x / y extents.
        mins = masked_points[:, :2].min(axis=0)
        maxs = masked_points[:, :2].max(axis=0)
        scaling = (maxs - mins).max()
        # Top of the (pre-flip) z range in normalized units; reused below as
        # the z component of the translation.
        height = masked_points[:, 2].max() / scaling

        # Normalize: flip z; map x / y into [0, 1] (the dominant axis spans
        # exactly [0, 1], the other one is centered); shift z so its minimum
        # is 0 and scale it by the same factor.
        masked_points[:, 2] *= -1
        masked_points[:, :2] = (masked_points[:, :2] - mins) / scaling + (1 - (maxs - mins) / scaling) / 2
        masked_points[:, 2] -= masked_points[:, 2].min()
        masked_points[:, 2] *= 1.0 / scaling

        # Build the extrinsic from R^T with its third column negated (mirroring
        # the z flip) and a translation that folds in the xy shift / z offset
        # applied above.
        # NOTE(review): whether this is world->camera or camera->world, and the
        # reason for the extra (0.5, 0.5, 0) term, is not visible here — confirm
        # against PointmapInfo consumers.
        R = R.T
        R[:, 2] *= -1
        t = R @ np.array([*(mins / scaling - (1 - (maxs - mins) / scaling) / 2), -height])
        t += R @ np.array([0.5, 0.5, 0.0])

        # Swap the x and y axes of the intrinsic matrix (3x3, as required by
        # P @ K @ P.T), the rotation and the translation.
        P = np.array([[0, 1, 0],
                      [1, 0, 0],
                      [0, 0, 1]])
        self.intrinsic = P @ intrinsic @ P.T
        R = P @ R @ P.T
        t = P @ t

        # 4x4 homogeneous [R | t; 0 0 0 1].
        self.extrinsic = np.vstack((np.hstack((R, t.reshape(-1, 1))), [0, 0, 0, 1]))

        self.pc = masked_points
        # Only masked pixels receive normalized coordinates; the others keep
        # MoGe's raw values, so the mask is stored to tell them apart.
        points[mask] = masked_points
        self._pointmap = points
        self._mask = mask

    def point_cloud(self) -> np.ndarray:
        """Return the normalized, ground-aligned valid points as an (N, 3) array."""
        return self.pc

    def camera_intrinsic(self) -> np.ndarray:
        """Return MoGe's intrinsic matrix with its x and y axes swapped."""
        return self.intrinsic

    def camera_extrinsic(self) -> np.ndarray:
        """Return the 4x4 homogeneous extrinsic built in ``__init__``."""
        return self.extrinsic

    def divide_image(self, width: int, length: int, div: int) -> List[List[Image.Image]]:
        """
        Slice the image into overlapping patches based on the normalized pointmap.

        Each patch is a window of size 1/width along normalized y and
        1/length along normalized x. Consecutive windows start
        1/(width*div) (resp. 1/(length*div)) apart, so neighbours overlap by
        (div-1)/div of a window. A pixel is kept in a patch only if it is a
        valid point (inside the mask built in ``__init__``) and its normalized
        coordinates fall in the window. All other pixels are set to 0 before
        the patch is passed to ``crop_and_resize_foreground``. The last window
        on each axis includes its upper edge, so points at normalized 1.0 are
        kept.

        Args:
            width: Number of tiles along the Y axis (>= 1).
            length: Number of tiles along the X axis (>= 1).
            div: Overlap subdivision factor, >= 1 (1 = no overlap, higher = more overlap).

        Returns:
            2D list of PIL images of shape [div*(width-1)+1][div*(length-1)+1].

        Raises:
            ValueError: If ``width``, ``length`` or ``div`` is smaller than 1.
        """
        if min(width, length, div) < 1:
            raise ValueError(f"width, length and div must be >= 1, got {width}, {length}, {div}")

        # Round rather than truncate: float32 (rgb / 255) * 255 may land a
        # fraction below the original integer, which astype would floor.
        image_np = np.rint(self._input_image.permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)

        pm = self._pointmap
        ys, xs = pm[:, :, 1], pm[:, :, 0]
        valid = self._mask

        n_rows = div * (width - 1) + 1
        n_cols = div * (length - 1) + 1
        y_steps = width * div
        x_steps = length * div

        patches = []
        for i in range(n_rows):
            # Grid-exact edges: a window's end equals the start of the window
            # `div` steps later, and the last window ends at exactly 1.0.
            y_start, y_end = i / y_steps, (i + div) / y_steps
            y_upper = (ys <= y_end) if i == n_rows - 1 else (ys < y_end)
            in_y = valid & (y_start <= ys) & y_upper

            row = []
            for j in range(n_cols):
                x_start, x_end = j / x_steps, (j + div) / x_steps
                x_upper = (xs <= x_end) if j == n_cols - 1 else (xs < x_end)
                in_patch = (in_y & (x_start <= xs) & x_upper)[:, :, None]
                patch_np = np.where(in_patch, image_np, 0).astype(np.uint8)
                row.append(crop_and_resize_foreground(Image.fromarray(patch_np)))
            patches.append(row)

        return patches
|
|