import numpy as np
import torch
from typing import List, Tuple
# ---------------------------------------------------------------------------
# Loss
# ---------------------------------------------------------------------------
def sparse_structure_loss(pc: torch.Tensor, vox: torch.Tensor) -> torch.Tensor:
    """
    BCE loss pushing the voxels hit by the point cloud toward occupancy.

    Each point in `pc` is mapped to the voxel cell that contains it, and the
    logit at that cell is driven toward 1 with binary cross-entropy.

    Args:
        pc: Point cloud of shape [N, 3] with coordinates in [0, 1].
        vox: Voxel logit grid of shape [1, 1, Rx, Ry, Rz].

    Returns:
        Scalar loss tensor (mean over the N looked-up logits).
    """
    # One clamped integer index per axis, derived from the grid resolution.
    ix, iy, iz = (
        torch.clamp(torch.floor(pc[:, axis] * r).long(), 0, r - 1)
        for axis, r in enumerate(vox.shape[2:5])
    )
    hit_logits = vox[0, 0, ix, iy, iz]
    targets = torch.ones_like(hit_logits)
    return torch.nn.functional.binary_cross_entropy_with_logits(hit_logits, targets)
# ---------------------------------------------------------------------------
# Voxelization
# ---------------------------------------------------------------------------
def pointcloud_to_voxel(pc: torch.Tensor, reso: Tuple[int, int, int]) -> torch.Tensor:
    """
    Rasterize a point cloud into a binary occupancy voxel grid.

    Args:
        pc: Point cloud of shape [N, 3] with coordinates in [0, 1].
        reso: Grid resolution as (Rx, Ry, Rz).

    Returns:
        Occupancy grid of shape [1, 1, Rx, Ry, Rz], dtype uint8 (1 where at
        least one point lands in the cell, else 0).
    """
    # Clamp keeps boundary points (coordinate exactly 1.0) inside the grid.
    ix, iy, iz = (
        torch.clamp(torch.floor(pc[:, axis] * r).long(), 0, r - 1)
        for axis, r in enumerate(reso)
    )
    grid = torch.zeros((1, 1, *reso), dtype=torch.uint8)
    grid[0, 0, ix, iy, iz] = 1
    return grid
# ---------------------------------------------------------------------------
# Sampling schedule
# ---------------------------------------------------------------------------
def schedule(
    steps: int,
    rescale_t: float,
    start: float = 1.0,
    stop: float = 0.0,
) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
    """
    Build a rescaled time schedule for flow-matching sampling.

    A linear 1 -> 0 sequence is warped by `rescale_t` (identity when
    rescale_t == 1; >1 concentrates steps near t = 0, the high-detail
    regime) and then mapped affinely onto [stop, start].

    Args:
        steps: Number of sampling steps.
        rescale_t: Time-rescaling factor (1.0 = linear, >1 = compressed near t=0).
        start: Starting time value (inclusive).
        stop: Ending time value (inclusive).

    Returns:
        t_seq: Array of shape [steps + 1] with the full time sequence.
        t_pairs: List of (t, t_prev) pairs for each step.
    """
    raw = np.linspace(1, 0, steps + 1)
    warped = rescale_t * raw / (1 + (rescale_t - 1) * raw)
    t_seq = stop + (start - stop) * warped
    # Consecutive (t, t_prev) pairs for the steps of the sampler.
    t_pairs = list(zip(t_seq[:-1], t_seq[1:]))
    return t_seq, t_pairs
# ---------------------------------------------------------------------------
# View / patch utilities
# ---------------------------------------------------------------------------
def get_views(
    width: int,
    length: int,
    reso: int,
    div: int,
) -> List[Tuple[int, int, int, int, int, int]]:
    """
    Enumerate overlapping patch views over a [width*reso, length*reso] latent grid.

    Patches slide with stride reso/div, so neighbouring views overlap by
    (div - 1)/div of a patch. Each view is (i, j, y_start, y_end, x_start,
    x_end): the patch index plus pixel-space bounds.

    Args:
        width: Number of tiles along Y.
        length: Number of tiles along X.
        reso: Resolution of each tile (must be divisible by `div`).
        div: Overlap subdivision factor.

    Returns:
        List of (i, j, y_start, y_end, x_start, x_end) tuples.
    """
    assert reso % div == 0, f"reso ({reso}) must be divisible by div ({div})"
    stride = reso // div
    # Last start position is reso*(width-1) / reso*(length-1), so the final
    # view ends exactly at the grid boundary.
    y_starts = range(0, reso * (width - 1) + stride, stride)
    x_starts = range(0, reso * (length - 1) + stride, stride)
    return [
        (i, j, y0, y0 + reso, x0, x0 + reso)
        for i, y0 in enumerate(y_starts)
        for j, x0 in enumerate(x_starts)
    ]
def dilated_sampling(reso: int, width: int, length: int) -> np.ndarray:
    """
    Generate dilated (strided) coordinate samples covering the full latent grid.

    Sample (i, j) takes every width-th row and every length-th column offset
    by (i, j); the width*length samples therefore partition the grid with no
    overlap. The batch dimension is then shuffled independently at each
    spatial position to randomize processing order.

    Args:
        reso: Per-tile resolution.
        width: Number of tiles along Y.
        length: Number of tiles along X.

    Returns:
        Array of shape [width*length, reso*reso, 2] containing (y, x) index pairs.
    """
    offsets = [(i, j) for i in range(width) for j in range(length)]
    samples = np.array([
        [[y, x]
         for y in range(oy, reso * width, width)
         for x in range(ox, reso * length, length)]
        for oy, ox in offsets
    ])
    # In-place shuffle of the batch axis, one spatial position at a time.
    for col in range(samples.shape[1]):
        np.random.shuffle(samples[:, col])
    return samples
# ---------------------------------------------------------------------------
# Diffusion
# ---------------------------------------------------------------------------
def diffuse(x_0: torch.Tensor, t: "torch.Tensor | float", sigma_min: float) -> torch.Tensor:
    """
    Apply forward diffusion to `x_0` at time `t` under the flow-matching noise schedule.

    x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * noise, so t = 1
    yields pure noise and t = 0 leaves sigma_min residual noise on the data.

    Args:
        x_0: Clean latent of any shape.
        t: Scalar (Python float or 0-dim tensor) or batch time value in [0, 1].
        sigma_min: Minimum noise level at t = 1.

    Returns:
        Noisy latent x_t of the same shape as `x_0`.
    """
    noise = torch.randn_like(x_0)
    # The docstring promises scalar t, but `.view` only exists on tensors:
    # accept plain Python numbers too. Tensor inputs pass through untouched
    # (keeping their original dtype/device behavior).
    if not torch.is_tensor(t):
        t = torch.as_tensor(t, dtype=x_0.dtype, device=x_0.device)
    t = t.view(-1, *([1] * (x_0.ndim - 1)))  # broadcast over spatial dims
    return (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * noise