File size: 6,151 Bytes
a68e3ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import numpy as np
import torch
from typing import List, Tuple


# ---------------------------------------------------------------------------
# Loss
# ---------------------------------------------------------------------------

def sparse_structure_loss(pc: torch.Tensor, vox: torch.Tensor) -> torch.Tensor:
    """
    Compute BCE loss that encourages voxels occupied by the point cloud to be active.

    For each point in `pc`, the corresponding voxel logit is looked up and pushed
    toward 1 via binary cross-entropy with logits.

    Args:
        pc: Point cloud of shape [N, 3] with coordinates in [0, 1].
        vox: Voxel logit grid of shape [1, 1, Rx, Ry, Rz].

    Returns:
        Scalar loss tensor.
    """
    resox, resoy, resoz = vox.shape[2], vox.shape[3], vox.shape[4]

    # Map continuous [0, 1] coordinates to integer voxel indices; clamp guards
    # against a coordinate of exactly 1.0 landing one cell past the end.
    idx_x = torch.clamp(torch.floor(pc[:, 0] * resox).long(), 0, resox - 1)
    idx_y = torch.clamp(torch.floor(pc[:, 1] * resoy).long(), 0, resoy - 1)
    idx_z = torch.clamp(torch.floor(pc[:, 2] * resoz).long(), 0, resoz - 1)

    vox_on_pc = vox[0, 0, idx_x, idx_y, idx_z]
    # Functional form instead of instantiating a BCEWithLogitsLoss Module on
    # every call — same math, no per-call Module allocation.
    return torch.nn.functional.binary_cross_entropy_with_logits(
        vox_on_pc, torch.ones_like(vox_on_pc)
    )


# ---------------------------------------------------------------------------
# Voxelization
# ---------------------------------------------------------------------------

def pointcloud_to_voxel(pc: torch.Tensor, reso: Tuple[int, int, int]) -> torch.Tensor:
    """
    Rasterize a point cloud into a binary occupancy voxel grid.

    Args:
        pc: Point cloud of shape [N, 3] with coordinates in [0, 1].
        reso: Grid resolution as (Rx, Ry, Rz).

    Returns:
        Occupancy grid of shape [1, 1, Rx, Ry, Rz], dtype uint8, on `pc`'s device.
    """
    resox, resoy, resoz = reso

    # Clamp guards against a coordinate of exactly 1.0 indexing past the grid.
    idx_x = torch.clamp(torch.floor(pc[:, 0] * resox).long(), 0, resox - 1)
    idx_y = torch.clamp(torch.floor(pc[:, 1] * resoy).long(), 0, resoy - 1)
    idx_z = torch.clamp(torch.floor(pc[:, 2] * resoz).long(), 0, resoz - 1)

    # Allocate the grid on pc's device: a CPU-only grid would make the scatter
    # assignment below fail (or silently force a transfer) for CUDA inputs.
    vox = torch.zeros((1, 1, resox, resoy, resoz), dtype=torch.uint8, device=pc.device)
    vox[0, 0, idx_x, idx_y, idx_z] = 1
    return vox


# ---------------------------------------------------------------------------
# Sampling schedule
# ---------------------------------------------------------------------------

def schedule(
    steps: int,
    rescale_t: float,
    start: float = 1.0,
    stop: float = 0.0,
) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
    """
    Build a rescaled time schedule for flow-matching sampling.

    A linear descent from 1 to 0 is warped by `rescale_t` (identity when
    rescale_t == 1; values > 1 concentrate steps near t = 0, the high-detail
    regime), then mapped affinely onto [stop, start].

    Args:
        steps: Number of sampling steps.
        rescale_t: Time-rescaling factor (1.0 = linear, >1 = compressed near t=0).
        start: Starting time value (inclusive).
        stop: Ending time value (inclusive).

    Returns:
        t_seq: Array of shape [steps + 1] with the full time sequence.
        t_pairs: List of (t, t_prev) pairs for each step.
    """
    raw = np.linspace(1, 0, steps + 1)
    warped = rescale_t * raw / (1 + (rescale_t - 1) * raw)
    t_seq = stop + (start - stop) * warped

    # Consecutive (current, next) pairs drive each sampling step.
    t_pairs = list(zip(t_seq[:-1], t_seq[1:]))
    return t_seq, t_pairs


# ---------------------------------------------------------------------------
# View / patch utilities
# ---------------------------------------------------------------------------

def get_views(
    width: int,
    length: int,
    reso: int,
    div: int,
) -> List[Tuple[int, int, int, int, int, int]]:
    """
    Enumerate overlapping patch views over a [width*reso, length*reso] latent grid.

    Patches of size reso x reso slide with stride reso/div along both axes, so
    consecutive views overlap by (div - 1)/div of a patch.

    Args:
        width: Number of tiles along Y.
        length: Number of tiles along X.
        reso: Resolution of each tile (must be divisible by `div`).
        div: Overlap subdivision factor.

    Returns:
        List of (i, j, y_start, y_end, x_start, x_end) tuples, where (i, j)
        is the patch index and the rest are pixel-space bounds.
    """
    assert reso % div == 0, f"reso ({reso}) must be divisible by div ({div})"

    stride = reso // div
    y_starts = range(0, reso * (width - 1) + stride, stride)
    x_starts = range(0, reso * (length - 1) + stride, stride)
    return [
        (i, j, y0, y0 + reso, x0, x0 + reso)
        for i, y0 in enumerate(y_starts)
        for j, x0 in enumerate(x_starts)
    ]


def dilated_sampling(reso: int, width: int, length: int) -> np.ndarray:
    """
    Generate dilated (strided) coordinate samples covering the full latent grid.

    Batch entry (i, j) takes rows i, i+width, i+2*width, ... and columns
    j, j+length, j+2*length, ..., so the width*length entries jointly tile the
    [reso*width, reso*length] grid with no overlap. The batch axis is then
    permuted independently at each spatial position to randomize which entry
    processes which coordinate.

    Args:
        reso: Per-tile resolution.
        width: Number of tiles along Y.
        length: Number of tiles along X.

    Returns:
        Array of shape [width*length, reso*reso, 2] containing (y, x) index pairs.
    """
    grids = []
    for dy in range(width):
        for dx in range(length):
            grids.append([
                [y, x]
                for y in range(dy, reso * width, width)
                for x in range(dx, reso * length, length)
            ])
    samples = np.array(grids)

    # Per-position permutation of the batch axis (in place, global NumPy RNG).
    for pos in range(samples.shape[1]):
        np.random.shuffle(samples[:, pos])

    return samples


# ---------------------------------------------------------------------------
# Diffusion
# ---------------------------------------------------------------------------

def diffuse(x_0: torch.Tensor, t: torch.Tensor, sigma_min: float) -> torch.Tensor:
    """
    Apply forward diffusion to `x_0` at time `t` under the flow-matching noise schedule.

    Computes x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * noise,
    so t = 0 yields the clean latent plus sigma_min residual noise and t = 1
    yields pure unit-variance noise.

    Args:
        x_0: Clean latent of any shape.
        t: Scalar or batch time tensor with values in [0, 1].
        sigma_min: Residual noise level at t = 0. (The noise coefficient grows
            linearly from sigma_min at t = 0 to 1 at t = 1.)

    Returns:
        Noisy latent x_t of the same shape as `x_0`.
    """
    noise = torch.randn_like(x_0)
    t     = t.view(-1, *([1] * (x_0.ndim - 1)))  # broadcast over spatial dims
    return (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * noise