import numpy as np
import torch
from typing import List, Tuple


# ---------------------------------------------------------------------------
# Loss
# ---------------------------------------------------------------------------
def sparse_structure_loss(pc: torch.Tensor, vox: torch.Tensor) -> torch.Tensor:
    """
    Compute BCE loss that encourages voxels occupied by the point cloud
    to be active.

    For each point in `pc`, the corresponding voxel logit is looked up and
    pushed toward 1 via binary cross-entropy with logits (mean reduction).

    Args:
        pc:  Point cloud of shape [N, 3] with coordinates in [0, 1].
        vox: Voxel logit grid of shape [1, 1, Rx, Ry, Rz].

    Returns:
        Scalar loss tensor.
    """
    resox, resoy, resoz = vox.shape[2], vox.shape[3], vox.shape[4]
    # Map normalized coordinates to grid indices; clamp handles pc == 1.0,
    # which would otherwise floor to an out-of-range index.
    idx_x = torch.clamp(torch.floor(pc[:, 0] * resox).long(), 0, resox - 1)
    idx_y = torch.clamp(torch.floor(pc[:, 1] * resoy).long(), 0, resoy - 1)
    idx_z = torch.clamp(torch.floor(pc[:, 2] * resoz).long(), 0, resoz - 1)
    vox_on_pc = vox[0, 0, idx_x, idx_y, idx_z]
    # Functional form avoids constructing a Module instance on every call;
    # default reduction is "mean", identical to BCEWithLogitsLoss().
    return torch.nn.functional.binary_cross_entropy_with_logits(
        vox_on_pc, torch.ones_like(vox_on_pc)
    )


# ---------------------------------------------------------------------------
# Voxelization
# ---------------------------------------------------------------------------
def pointcloud_to_voxel(
    pc: torch.Tensor, reso: Tuple[int, int, int]
) -> torch.Tensor:
    """
    Rasterize a point cloud into a binary occupancy voxel grid.

    Args:
        pc:   Point cloud of shape [N, 3] with coordinates in [0, 1].
        reso: Grid resolution as (Rx, Ry, Rz).

    Returns:
        Occupancy grid of shape [1, 1, Rx, Ry, Rz], dtype uint8, on the
        same device as `pc`.
    """
    resox, resoy, resoz = reso
    idx_x = torch.clamp(torch.floor(pc[:, 0] * resox).long(), 0, resox - 1)
    idx_y = torch.clamp(torch.floor(pc[:, 1] * resoy).long(), 0, resoy - 1)
    idx_z = torch.clamp(torch.floor(pc[:, 2] * resoz).long(), 0, resoz - 1)
    # BUGFIX: allocate on pc's device — a CPU grid indexed with CUDA index
    # tensors would raise when pc lives on GPU.
    vox = torch.zeros(
        (1, 1, resox, resoy, resoz), dtype=torch.uint8, device=pc.device
    )
    vox[0, 0, idx_x, idx_y, idx_z] = 1
    return vox


# ---------------------------------------------------------------------------
# Sampling schedule
# ---------------------------------------------------------------------------
def schedule(
    steps: int,
    rescale_t: float,
    start: float = 1.0,
    stop: float = 0.0,
) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
    """
    Build a rescaled time schedule for flow-matching sampling.

    The raw linear sequence is warped by `rescale_t` to concentrate steps
    near t = 0 (high-detail regime) when rescale_t > 1.

    Args:
        steps:     Number of sampling steps.
        rescale_t: Time-rescaling factor (1.0 = linear, >1 = compressed
                   near t = 0).
        start:     Starting time value (inclusive).
        stop:      Ending time value (inclusive).

    Returns:
        t_seq:   Array of shape [steps + 1] with the full time sequence.
        t_pairs: List of (t, t_prev) pairs for each step.
    """
    t_seq = np.linspace(1, 0, steps + 1)
    # Möbius-style warp: identity when rescale_t == 1, otherwise bends the
    # sequence so consecutive steps near t = 0 are closer together.
    t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
    # Affine map from [1, 0] to [start, stop].
    t_seq = stop + (start - stop) * t_seq
    t_pairs = [(t_seq[i], t_seq[i + 1]) for i in range(steps)]
    return t_seq, t_pairs


# ---------------------------------------------------------------------------
# View / patch utilities
# ---------------------------------------------------------------------------
def get_views(
    width: int,
    length: int,
    reso: int,
    div: int,
) -> List[Tuple[int, int, int, int, int, int]]:
    """
    Enumerate overlapping patch views over a [width*reso, length*reso]
    latent grid.

    Each view is a (i, j, y_start, y_end, x_start, x_end) tuple where
    (i, j) is the patch index and the remaining four values are
    pixel-space bounds.

    Args:
        width:  Number of tiles along Y.
        length: Number of tiles along X.
        reso:   Resolution of each tile (must be divisible by `div`).
        div:    Overlap subdivision factor.

    Returns:
        List of (i, j, y_start, y_end, x_start, x_end) tuples.

    Raises:
        AssertionError: If `reso` is not divisible by `div`.
    """
    assert reso % div == 0, f"reso ({reso}) must be divisible by div ({div})"
    stride = reso // div
    views = []
    # Last patch starts at reso*(width-1) so its end lands exactly on the
    # grid boundary reso*width; `+ stride` makes the range inclusive of it.
    for i, y0 in enumerate(range(0, reso * (width - 1) + stride, stride)):
        for j, x0 in enumerate(range(0, reso * (length - 1) + stride, stride)):
            views.append((i, j, y0, y0 + reso, x0, x0 + reso))
    return views


def dilated_sampling(reso: int, width: int, length: int) -> np.ndarray:
    """
    Generate dilated (strided) coordinate samples covering the full latent
    grid.

    Each of the width*length samples picks every width-th row and every
    length-th column, shifted by (i, j), so together they tile the grid
    without overlap. The batch axis is then shuffled independently at each
    spatial position, randomizing which sample owns which coordinate while
    preserving full coverage.

    Args:
        reso:   Per-tile resolution.
        width:  Number of tiles along Y.
        length: Number of tiles along X.

    Returns:
        Array of shape [width*length, reso*reso, 2] containing (y, x)
        index pairs.
    """
    samples = np.array([
        [[y, x]
         for y in range(i, reso * width, width)
         for x in range(j, reso * length, length)]
        for i in range(width)
        for j in range(length)
    ])
    # Shuffle the batch dimension independently at each spatial position;
    # np.random.shuffle permutes the first axis of the [B, 2] view in place.
    for k in range(samples.shape[1]):
        np.random.shuffle(samples[:, k])
    return samples


# ---------------------------------------------------------------------------
# Diffusion
# ---------------------------------------------------------------------------
def diffuse(x_0: torch.Tensor, t: torch.Tensor, sigma_min: float) -> torch.Tensor:
    """
    Apply forward diffusion to `x_0` at time `t` under the flow-matching
    noise schedule.

    Args:
        x_0:       Clean latent of any shape.
        t:         Scalar or batch time value in [0, 1]. A plain Python
                   float is also accepted (converted to a tensor).
        sigma_min: Minimum noise level at t = 1.

    Returns:
        Noisy latent x_t of the same shape as `x_0`.
    """
    noise = torch.randn_like(x_0)
    # Robustness: accept a Python float for t (the original .view call
    # would fail on non-tensors); existing tensor callers are unaffected.
    if not torch.is_tensor(t):
        t = torch.as_tensor(t, dtype=x_0.dtype, device=x_0.device)
    # reshape (not view) also handles non-contiguous t; broadcast over
    # all non-batch dims of x_0.
    t = t.reshape(-1, *([1] * (x_0.ndim - 1)))
    return (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * noise