File size: 7,252 Bytes
a68e3ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import torch
import numpy as np
import open3d as o3d
from PIL import Image, ImageOps
from sklearn.neighbors import NearestNeighbors
from typing import List, Tuple, Union
from abc import ABC, abstractmethod


# ---------------------------------------------------------------------------
# Point cloud utilities
# ---------------------------------------------------------------------------

def align_ground_to_z(
    pc: torch.Tensor,
    voxel_size: float = 0.05,
    distance_threshold: float = 0.005,
    ransac_n: int = 3,
    num_iterations: int = 10000,
    return_transform: bool = False,
) -> Union[
    Tuple[torch.Tensor, o3d.geometry.PointCloud],
    Tuple[torch.Tensor, o3d.geometry.PointCloud, np.ndarray],
]:
    """
    Level a point cloud so the dominant RANSAC plane's normal points along +Z.

    Args:
        pc: Input point cloud of shape [N, 3], dtype float32.
        voxel_size: Voxel edge length for downsampling before RANSAC; 0 disables.
        distance_threshold: RANSAC inlier distance to the candidate plane.
        ransac_n: Number of points drawn per RANSAC hypothesis.
        num_iterations: Maximum RANSAC iterations.
        return_transform: When True, additionally return the 3×3 rotation matrix.

    Returns:
        The rotated cloud as a tensor on `pc`'s device, the rotated Open3D
        cloud, and — only when `return_transform=True` — the rotation matrix
        mapping the detected plane normal to [0, 0, 1].

    Raises:
        ValueError: If `pc` is not shaped [N, 3].
        RuntimeError: If plane segmentation yields too few inliers.
    """
    if pc.ndim != 2 or pc.shape[1] != 3:
        raise ValueError(f"Expected pc of shape [N, 3], got {tuple(pc.shape)}")

    device = pc.device
    points_np = pc.detach().cpu().numpy()

    # RANSAC runs on a (possibly voxel-downsampled) copy of the cloud.
    sample = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points_np))
    if voxel_size > 0:
        sample = sample.voxel_down_sample(voxel_size)

    plane_model, inliers = sample.segment_plane(
        distance_threshold=distance_threshold,
        ransac_n=ransac_n,
        num_iterations=num_iterations,
    )
    if len(inliers) < ransac_n:
        raise RuntimeError("RANSAC failed to find a dominant plane.")

    # Unit normal of the fitted plane (plane_model is [a, b, c, d]).
    normal = np.asarray(plane_model[:3], dtype=np.float64)
    normal = normal / np.linalg.norm(normal)

    z_axis = np.array([0.0, 0.0, 1.0], dtype=np.float64)
    cos_theta = np.dot(normal, z_axis)

    if np.allclose(cos_theta, 1.0, atol=1e-6):
        # Normal already points at +Z: identity rotation.
        R = np.eye(3)
    elif np.allclose(cos_theta, -1.0, atol=1e-6):
        # Antiparallel: rotate 180° about any axis perpendicular to the normal.
        ortho = np.array([0.0, 1.0, 0.0] if abs(normal[0]) > 0.9 else [1.0, 0.0, 0.0])
        axis = np.cross(normal, ortho)
        axis = axis / np.linalg.norm(axis)
        R = o3d.geometry.get_rotation_matrix_from_axis_angle(axis * np.pi)
    else:
        # General case: axis-angle rotation from the normal onto +Z.
        axis = np.cross(normal, z_axis)
        axis = axis / np.linalg.norm(axis)
        R = o3d.geometry.get_rotation_matrix_from_axis_angle(axis * np.arccos(cos_theta))

    # Apply the rotation to the original, full-resolution cloud about the origin.
    full_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points_np))
    full_pcd.rotate(R, center=(0, 0, 0))

    aligned_pc = torch.from_numpy(
        np.asarray(full_pcd.points, dtype=np.float32)
    ).to(device)

    if return_transform:
        return aligned_pc, full_pcd, R
    return aligned_pc, full_pcd


def remove_outliers_statistical(
    pts: np.ndarray,
    nb_neighbors: int = 20,
    std_ratio: float = 2.0,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Filter a point cloud with Statistical Outlier Removal.

    Each point's mean distance to its `nb_neighbors` nearest neighbors is
    compared against the population statistics; points whose mean distance
    exceeds (mu + std_ratio * sigma) are dropped.

    Args:
        pts: Input point cloud of shape [N, 3].
        nb_neighbors: Neighborhood size used per point.
        std_ratio: Standard-deviation multiplier for the distance cutoff.

    Returns:
        clean_pts: Inlier points of shape [M, 3] with M <= N.
        mask: Boolean array of shape [N], True for kept (inlier) points.
    """
    # Query one extra neighbor: the closest hit is the point itself (distance 0).
    knn = NearestNeighbors(n_neighbors=nb_neighbors + 1, algorithm="auto")
    knn.fit(pts)
    distances, _ = knn.kneighbors(pts)

    # Drop column 0 (the zero-distance self-match) before averaging.
    mean_dist = np.mean(distances[:, 1:], axis=1)
    cutoff = mean_dist.mean() + std_ratio * mean_dist.std()

    keep = mean_dist <= cutoff
    return pts[keep], keep


# ---------------------------------------------------------------------------
# Image utilities
# ---------------------------------------------------------------------------

def crop_and_resize_foreground(img: Image.Image, padding: float = 0.05) -> Image.Image:
    """
    Crop an image's foreground and re-fit it onto a black canvas of the original size.

    The foreground is the tightest bounding box around all non-black pixels.
    After cropping, a black border of `padding` × the original dimensions is
    added, the result is scaled uniformly to fit the original canvas, and it
    is pasted centered on a black background.

    Args:
        img: Input PIL image.
        padding: Border width as a fraction of the original width/height.

    Returns:
        A new RGB PIL image with the same dimensions as `img`. A fully black
        input is returned as an unmodified copy.
    """
    rgb = img.convert("RGB")
    orig_w, orig_h = rgb.size

    # Binary mask of non-black pixels; getbbox() gives the tightest box around them.
    fg_mask = rgb.convert("L").point(lambda x: 0 if x == 0 else 255, mode="1")
    bbox = fg_mask.getbbox()
    if bbox is None:
        # No foreground at all — nothing to crop.
        return img.copy()

    border_x = int(padding * orig_w)
    border_y = int(padding * orig_h)
    padded = ImageOps.expand(
        rgb.crop(bbox),
        border=(border_x, border_y, border_x, border_y),
        fill=(0, 0, 0),
    )

    # Uniform scale so the padded crop fits inside the original canvas.
    pad_w, pad_h = padded.size
    ratio = min(orig_w / pad_w, orig_h / pad_h)
    fit_w = max(1, int(pad_w * ratio))
    fit_h = max(1, int(pad_h * ratio))
    foreground = padded.resize((fit_w, fit_h), Image.LANCZOS)

    result = Image.new("RGB", (orig_w, orig_h), (0, 0, 0))
    result.paste(foreground, ((orig_w - fit_w) // 2, (orig_h - fit_h) // 2))
    return result


# ---------------------------------------------------------------------------
# Pointmap interface
# ---------------------------------------------------------------------------

class PointmapInfo(ABC):
    """Abstract base class for depth-based point cloud extraction."""

    @abstractmethod
    def __init__(self, image: Image.Image, device):
        """Initialize the extractor from a source image and a compute device."""
        pass

    @abstractmethod
    def point_cloud(self) -> torch.Tensor:
        """Return the extracted point cloud as a torch.Tensor."""
        pass

    @abstractmethod
    def camera_intrinsic(self) -> np.ndarray:
        """Return the camera intrinsic matrix as a NumPy array."""
        pass

    @abstractmethod
    def camera_extrinsic(self) -> np.ndarray:
        """Return the camera extrinsic matrix as a NumPy array."""
        pass

    @abstractmethod
    def divide_image(self, width: int, length: int, div: int) -> List[List[Image.Image]]:
        """Split the image into a grid of sub-images, returned as rows of PIL images.

        NOTE(review): the exact semantics of `width`, `length`, and `div`
        (e.g. whether `div` is per-axis) are defined by implementations —
        confirm against a concrete subclass.
        """
        pass