# Extend3D/utils/utils.py
# Author: Seungwoo-Yoon
# Initial commit for HF space (commit a68e3ed)
import numpy as np
import torch
from typing import List, Tuple
# ---------------------------------------------------------------------------
# Loss
# ---------------------------------------------------------------------------
def sparse_structure_loss(pc: torch.Tensor, vox: torch.Tensor) -> torch.Tensor:
    """
    BCE loss pushing voxels hit by the point cloud toward occupancy.

    Each point is mapped to its containing voxel cell; the logit of that cell
    is driven toward 1 with binary cross-entropy (mean reduction).

    Args:
        pc: Point cloud of shape [N, 3], coordinates assumed in [0, 1].
        vox: Voxel logit grid of shape [1, 1, Rx, Ry, Rz].

    Returns:
        Scalar loss tensor.
    """
    grid_res = vox.shape[2:]
    # Per-axis cell index of every point, clamped so pc values of exactly 1.0
    # (or slightly outside [0, 1]) still land inside the grid.
    cell_idx = [
        torch.clamp((pc[:, axis] * res).floor().long(), 0, res - 1)
        for axis, res in enumerate(grid_res)
    ]
    logits = vox[0, 0, cell_idx[0], cell_idx[1], cell_idx[2]]
    targets = torch.ones_like(logits)
    return torch.nn.functional.binary_cross_entropy_with_logits(logits, targets)
# ---------------------------------------------------------------------------
# Voxelization
# ---------------------------------------------------------------------------
def pointcloud_to_voxel(pc: torch.Tensor, reso: Tuple[int, int, int]) -> torch.Tensor:
    """
    Rasterize a point cloud into a binary occupancy voxel grid.

    Args:
        pc: Point cloud of shape [N, 3] with coordinates in [0, 1].
        reso: Grid resolution as (Rx, Ry, Rz).

    Returns:
        Occupancy grid of shape [1, 1, Rx, Ry, Rz], dtype uint8, on the same
        device as `pc`.
    """
    resox, resoy, resoz = reso
    # Clamp so coordinates of exactly 1.0 map to the last cell instead of
    # falling out of bounds.
    idx_x = torch.clamp(torch.floor(pc[:, 0] * resox).long(), 0, resox - 1)
    idx_y = torch.clamp(torch.floor(pc[:, 1] * resoy).long(), 0, resoy - 1)
    idx_z = torch.clamp(torch.floor(pc[:, 2] * resoz).long(), 0, resoz - 1)
    # Allocate on pc's device: the original CPU-only allocation raised a
    # device-mismatch error when indexed with CUDA index tensors.
    vox = torch.zeros((1, 1, resox, resoy, resoz), dtype=torch.uint8, device=pc.device)
    vox[0, 0, idx_x, idx_y, idx_z] = 1
    return vox
# ---------------------------------------------------------------------------
# Sampling schedule
# ---------------------------------------------------------------------------
def schedule(
    steps: int,
    rescale_t: float,
    start: float = 1.0,
    stop: float = 0.0,
) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
    """
    Build a rescaled time schedule for flow-matching sampling.

    A linear sequence from 1 to 0 is warped by `rescale_t` (identity when
    rescale_t == 1, more steps concentrated near t = 0 when rescale_t > 1),
    then affinely mapped onto [stop, start].

    Args:
        steps: Number of sampling steps.
        rescale_t: Time-rescaling factor (1.0 = linear).
        start: Starting time value (inclusive).
        stop: Ending time value (inclusive).

    Returns:
        t_seq: Array of shape [steps + 1] with the full time sequence.
        t_pairs: List of (t, t_prev) pairs, one per step.
    """
    raw = np.linspace(1.0, 0.0, steps + 1)
    warped = rescale_t * raw / (1.0 + (rescale_t - 1.0) * raw)
    t_seq = stop + (start - stop) * warped
    t_pairs = list(zip(t_seq[:-1], t_seq[1:]))
    return t_seq, t_pairs
# ---------------------------------------------------------------------------
# View / patch utilities
# ---------------------------------------------------------------------------
def get_views(
    width: int,
    length: int,
    reso: int,
    div: int,
) -> List[Tuple[int, int, int, int, int, int]]:
    """
    Enumerate overlapping patch views over a [width*reso, length*reso] latent grid.

    Each view is a (i, j, y_start, y_end, x_start, x_end) tuple where (i, j) is
    the patch index and the remaining four values are pixel-space bounds.
    Consecutive views overlap by reso - reso/div pixels along each axis.

    Args:
        width: Number of tiles along Y.
        length: Number of tiles along X.
        reso: Resolution of each tile (must be divisible by `div`).
        div: Overlap subdivision factor.

    Returns:
        List of (i, j, y_start, y_end, x_start, x_end) tuples.

    Raises:
        ValueError: If `reso` is not divisible by `div`.
    """
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would silently allow an invalid stride below.
    if reso % div != 0:
        raise ValueError(f"reso ({reso}) must be divisible by div ({div})")
    stride = reso // div
    views = []
    for i, y0 in enumerate(range(0, reso * (width - 1) + stride, stride)):
        for j, x0 in enumerate(range(0, reso * (length - 1) + stride, stride)):
            views.append((i, j, y0, y0 + reso, x0, x0 + reso))
    return views
def dilated_sampling(reso: int, width: int, length: int) -> np.ndarray:
    """
    Generate dilated (strided) coordinate samples covering the full latent grid.

    Sample (i, j) picks every width-th row starting at row i and every
    length-th column starting at column j, so the width*length samples
    together cover each grid cell exactly once. Afterward, the batch axis is
    shuffled independently at every spatial position, randomizing which
    sample owns which coordinate at that position.

    Args:
        reso: Per-tile resolution.
        width: Number of tiles along Y.
        length: Number of tiles along X.

    Returns:
        Array of shape [width*length, reso*reso, 2] containing (y, x) index pairs.
    """
    grids = []
    for off_y in range(width):
        for off_x in range(length):
            coords = [
                [y, x]
                for y in range(off_y, reso * width, width)
                for x in range(off_x, reso * length, length)
            ]
            grids.append(coords)
    samples = np.array(grids)
    # In-place shuffle of the batch axis at each spatial position; the
    # column slice is a view, so the shuffle writes back into `samples`.
    for pos in range(samples.shape[1]):
        np.random.shuffle(samples[:, pos])
    return samples
# ---------------------------------------------------------------------------
# Diffusion
# ---------------------------------------------------------------------------
def diffuse(x_0: torch.Tensor, t: torch.Tensor, sigma_min: float) -> torch.Tensor:
    """
    Apply forward diffusion to `x_0` at time `t` under the flow-matching noise schedule.

    Computes x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * eps,
    eps ~ N(0, I): the data coefficient shrinks from 1 at t = 0 to 0 at
    t = 1, while the noise coefficient grows from sigma_min at t = 0 to 1
    at t = 1.

    Args:
        x_0: Clean latent of any shape.
        t: Scalar or batch time value in [0, 1].
        sigma_min: Noise floor applied at t = 0.

    Returns:
        Noisy latent x_t of the same shape as `x_0`.
    """
    eps = torch.randn_like(x_0)
    # Reshape t to [B, 1, ..., 1] so it broadcasts over all non-batch dims.
    t_b = t.view(-1, *([1] * (x_0.ndim - 1)))
    data_coef = 1.0 - t_b
    noise_coef = sigma_min + (1.0 - sigma_min) * t_b
    return data_coef * x_0 + noise_coef * eps