# Extend3D/utils/moge_utils.py
# Author: Seungwoo-Yoon — initial commit for HF space (commit a68e3ed)
import torch
import numpy as np
from PIL import Image
from typing import List
import utils3d
from moge.model.v2 import MoGeModel
from utils.depth_utils import PointmapInfo, align_ground_to_z, crop_and_resize_foreground
# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
def moge_preprocess(image: Image.Image, device) -> torch.Tensor:
"""Convert a PIL image to a normalized float32 CHW tensor on `device`."""
rgb = np.array(image.convert("RGB"))
return torch.tensor(rgb / 255.0, dtype=torch.float32, device=device).permute(2, 0, 1)
# ---------------------------------------------------------------------------
# MoGe-based pointmap
# ---------------------------------------------------------------------------
class PointmapInfoMoGe(PointmapInfo):
    """
    Concrete PointmapInfo implementation backed by the MoGe monocular depth estimator.

    On construction this runs MoGe inference on the input image, masks out
    depth-edge pixels, aligns the estimated ground plane to the XY plane
    (+Z up), normalizes the foreground point cloud into a unit-ish box
    (XY centered in [0, 1], Z starting at 0), and derives camera
    intrinsic/extrinsic matrices consistent with that normalization.

    The MoGe model is loaded once and cached as a class-level attribute, so
    subsequent instantiations reuse the same weights.
    """
    # Shared across all instances to avoid redundant weight loading
    moge_model: MoGeModel | None = None

    def __init__(self, image: Image.Image, device: str = 'cuda'):
        # Normalized float32 CHW tensor on `device`; also kept for
        # divide_image(), which reconstructs the uint8 image from it.
        self._input_image = moge_preprocess(image, device)
        # Run MoGe inference (no gradients needed)
        with torch.no_grad():
            if PointmapInfoMoGe.moge_model is None:
                # Lazily load + cache the pretrained weights on first use.
                PointmapInfoMoGe.moge_model = (
                    MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").to(device)
                )
            # NOTE(review): assumes infer() returns a dict with at least
            # 'depth', 'mask', 'points', 'intrinsics' — confirm against the
            # MoGe v2 API.
            predictions = PointmapInfoMoGe.moge_model.infer(self._input_image)
        # Mask out depth edges to suppress discontinuity artifacts
        # (rtol=0.04 is the relative-depth-jump threshold passed to utils3d).
        depth_edge_mask = utils3d.numpy.depth_edge(predictions['depth'].cpu().numpy(), rtol=0.04)
        # Keep only valid, non-edge pixels.
        mask = predictions['mask'] & torch.from_numpy(~depth_edge_mask).to(device)
        # Align the ground plane with the XY plane (+Z up)
        points = predictions['points']
        # NOTE(review): align_ground_to_z is assumed to return the aligned
        # (N, 3) points plus the rotation R used for the alignment — verify
        # in utils.depth_utils.
        masked_points, _, R = align_ground_to_z(points[mask].reshape(-1, 3), return_transform=True)
        # Move arrays to CPU/numpy for coordinate normalization
        mask = mask.cpu().numpy()
        points = points.cpu().numpy()
        masked_points = masked_points.cpu().numpy()
        self.intrinsic = predictions['intrinsics'].cpu().numpy()
        # Normalize XY to [0, 1] and Z to a height relative to scene scale
        mins = masked_points[:, :2].min(axis=0)
        maxs = masked_points[:, :2].max(axis=0)
        # Uniform scale: the larger of the two XY extents, so aspect ratio
        # is preserved.
        scaling = (maxs - mins).max()
        height = masked_points[:, 2].max() / scaling
        # Flip Z, center XY, and apply uniform scale.  The second term
        # centers the smaller XY extent inside [0, 1].
        masked_points[:, 2] *= -1
        masked_points[:, :2] = (masked_points[:, :2] - mins) / scaling + (1 - (maxs - mins) / scaling) / 2
        masked_points[:, 2] -= masked_points[:, 2].min()
        masked_points[:, 2] *= 1.0 / scaling
        # Build the camera extrinsic [R | t] from the alignment transform:
        # invert the alignment rotation (R.T), flip the Z column to match the
        # sign flip applied to the points above, then fold the XY shift/scale
        # into the translation.
        R = R.T
        R[:, 2] *= -1
        t = R @ np.array([*(mins / scaling - (1 - (maxs - mins) / scaling) / 2), -height])
        t += R @ np.array([0.5, 0.5, 0.0])
        # Permute axes from (y, x, z) to (x, y, z) convention
        P = np.array([[0, 1, 0],
                      [1, 0, 0],
                      [0, 0, 1]])
        self.intrinsic = P @ self.intrinsic @ P.T
        R = P @ R @ P.T
        t = P @ t
        # Homogeneous 4x4 world-to-camera (presumably) matrix — confirm the
        # direction convention against PointmapInfo's contract.
        self.extrinsic = np.vstack((np.hstack((R, t.reshape(-1, 1))), [0, 0, 0, 1]))
        # Store the full pointmap (with masked region filled in) for patch extraction
        self.pc = masked_points
        points[mask] = masked_points
        # (H, W, 3) array of normalized coordinates; invalid pixels keep the
        # raw (un-normalized) MoGe values.
        self._pointmap = points

    # -----------------------------------------------------------------------
    # PointmapInfo interface
    # -----------------------------------------------------------------------
    def point_cloud(self) -> np.ndarray:
        # Normalized (N, 3) foreground points computed in __init__.
        return self.pc

    def camera_intrinsic(self) -> np.ndarray:
        # 3x3 intrinsic matrix in the permuted (x, y, z) convention.
        return self.intrinsic

    def camera_extrinsic(self) -> np.ndarray:
        # 4x4 homogeneous extrinsic matrix assembled in __init__.
        return self.extrinsic

    def divide_image(self, width: int, length: int, div: int) -> List[List[Image.Image]]:
        """
        Slice the image into overlapping patches based on the normalized pointmap.

        Pixels are assigned to a patch by where their normalized pointmap XY
        coordinates fall, not by their pixel position, so patches follow the
        scene geometry.

        Args:
            width: Number of tiles along the Y axis.
            length: Number of tiles along the X axis.
            div: Overlap subdivision factor (higher = more overlap).

        Returns:
            2D list of PIL images of shape [(width-1)*div + 1][(length-1)*div + 1].
        """
        # Convert the input tensor back to a uint8 HWC numpy array
        image_np = (self._input_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        patches = []
        for i in range(width * div - div + 1):
            row = []
            for j in range(length * div - div + 1):
                # Compute normalized [0, 1] bounds for this patch; each step
                # advances by 1/(tiles*div) while the window stays 1/tiles
                # wide, producing the overlap.
                y_start = i / (width * div)
                x_start = j / (length * div)
                y_end = y_start + 1.0 / width
                x_end = x_start + 1.0 / length
                # Mask pixels whose pointmap coordinates fall within this patch
                pm = self._pointmap
                in_patch = (
                    (y_start <= pm[:, :, 1]) & (pm[:, :, 1] < y_end) &
                    (x_start <= pm[:, :, 0]) & (pm[:, :, 0] < x_end)
                )[:, :, None]
                # Zero out everything outside the patch, then crop/resize to
                # the foreground (helper from utils.depth_utils).
                patch_np = np.where(in_patch, image_np, 0).astype(np.uint8)
                patch_img = crop_and_resize_foreground(Image.fromarray(patch_np))
                row.append(patch_img)
            patches.append(row)
        return patches