| import numpy as np |
| import torch |
| from typing import List, Tuple |
|
|
|
|
| |
| |
| |
|
|
def sparse_structure_loss(pc: torch.Tensor, vox: torch.Tensor) -> torch.Tensor:
    """
    BCE loss that pushes voxels covered by the point cloud toward "occupied".

    Every point in `pc` is mapped to its containing voxel; the logit at that
    voxel is driven toward 1 with binary cross-entropy (mean reduction).

    Args:
        pc: Point cloud of shape [N, 3] with coordinates in [0, 1].
        vox: Voxel logit grid of shape [1, 1, Rx, Ry, Rz].

    Returns:
        Scalar loss tensor.
    """
    grid = vox.shape[2:5]

    # Per-axis voxel index of each point, clamped so pc values of exactly 1.0
    # (or slightly outside [0, 1]) still land inside the grid.
    ix, iy, iz = (
        torch.clamp((pc[:, axis] * grid[axis]).floor().long(), 0, grid[axis] - 1)
        for axis in range(3)
    )

    logits = vox[0, 0, ix, iy, iz]
    target = torch.ones_like(logits)
    return torch.nn.functional.binary_cross_entropy_with_logits(logits, target)
|
|
|
|
| |
| |
| |
|
|
def pointcloud_to_voxel(pc: torch.Tensor, reso: Tuple[int, int, int]) -> torch.Tensor:
    """
    Rasterize a point cloud into a binary occupancy voxel grid.

    Args:
        pc: Point cloud of shape [N, 3] with coordinates in [0, 1].
        reso: Grid resolution as (Rx, Ry, Rz).

    Returns:
        Occupancy grid of shape [1, 1, Rx, Ry, Rz], dtype uint8, on the same
        device as `pc`.
    """
    resox, resoy, resoz = reso

    # Clamp so coordinates of exactly 1.0 (or slightly out of range) still map
    # to the last voxel along each axis.
    idx_x = torch.clamp(torch.floor(pc[:, 0] * resox).long(), 0, resox - 1)
    idx_y = torch.clamp(torch.floor(pc[:, 1] * resoy).long(), 0, resoy - 1)
    idx_z = torch.clamp(torch.floor(pc[:, 2] * resoz).long(), 0, resoz - 1)

    # Allocate on pc's device: a CPU-only grid would make the indexed
    # assignment below fail for CUDA point clouds.
    vox = torch.zeros((1, 1, resox, resoy, resoz), dtype=torch.uint8, device=pc.device)
    vox[0, 0, idx_x, idx_y, idx_z] = 1
    return vox
|
|
|
|
| |
| |
| |
|
|
def schedule(
    steps: int,
    rescale_t: float,
    start: float = 1.0,
    stop: float = 0.0,
) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
    """
    Build a rescaled time schedule for flow-matching sampling.

    A linear 1 -> 0 ramp is warped by `rescale_t` (identity at the endpoints,
    so start/stop are always hit exactly) and then affinely mapped onto
    [stop, start]. rescale_t > 1 concentrates steps near t = 0, the
    high-detail regime.

    Args:
        steps: Number of sampling steps.
        rescale_t: Time-rescaling factor (1.0 = linear, >1 = compressed near t=0).
        start: Starting time value (inclusive).
        stop: Ending time value (inclusive).

    Returns:
        t_seq: Array of shape [steps + 1] with the full time sequence.
        t_pairs: List of (t, t_prev) pairs for each step.
    """
    ramp = np.linspace(1, 0, steps + 1)
    warped = rescale_t * ramp / (1 + (rescale_t - 1) * ramp)
    t_seq = stop + (start - stop) * warped

    # Consecutive (current, next) times for each of the `steps` updates.
    t_pairs = list(zip(t_seq[:-1], t_seq[1:]))
    return t_seq, t_pairs
|
|
|
|
| |
| |
| |
|
|
def get_views(
    width: int,
    length: int,
    reso: int,
    div: int,
) -> List[Tuple[int, int, int, int, int, int]]:
    """
    Enumerate overlapping patch views over a [width*reso, length*reso] latent grid.

    Each view is a (i, j, y_start, y_end, x_start, x_end) tuple where (i, j) is
    the patch index and the remaining four values are pixel-space bounds.

    Args:
        width: Number of tiles along Y.
        length: Number of tiles along X.
        reso: Resolution of each tile (must be divisible by `div`).
        div: Overlap subdivision factor.

    Returns:
        List of (i, j, y_start, y_end, x_start, x_end) tuples.
    """
    assert reso % div == 0, f"reso ({reso}) must be divisible by div ({div})"

    # Patches advance by reso/div, so adjacent views overlap by (div-1)/div
    # of a tile; the last start position is reso * (width-1) / (length-1).
    step = reso // div
    y_starts = range(0, reso * (width - 1) + step, step)
    x_starts = range(0, reso * (length - 1) + step, step)
    return [
        (i, j, y0, y0 + reso, x0, x0 + reso)
        for i, y0 in enumerate(y_starts)
        for j, x0 in enumerate(x_starts)
    ]
|
|
|
|
def dilated_sampling(reso: int, width: int, length: int) -> np.ndarray:
    """
    Generate dilated (strided) coordinate samples covering the full latent grid.

    Each of the width*length samples picks every width-th row and every
    length-th column, shifted by (i, j), so together they tile the grid without
    overlap. The batch dimension is shuffled to randomize processing order.

    Args:
        reso: Per-tile resolution.
        width: Number of tiles along Y.
        length: Number of tiles along X.

    Returns:
        Array of shape [width*length, reso*reso, 2] containing (y, x) index pairs.
    """
    grids = []
    for off_y in range(width):
        for off_x in range(length):
            grids.append([
                [y, x]
                for y in range(off_y, reso * width, width)
                for x in range(off_x, reso * length, length)
            ])
    coords = np.array(grids)

    # Shuffle along the batch axis independently at every spatial slot;
    # coords[:, slot] is a view, so the shuffle mutates coords in place.
    for slot in range(coords.shape[1]):
        np.random.shuffle(coords[:, slot])

    return coords
|
|
|
|
| |
| |
| |
|
|
def diffuse(x_0: torch.Tensor, t: torch.Tensor, sigma_min: float) -> torch.Tensor:
    """
    Apply forward diffusion to `x_0` at time `t` under the flow-matching noise schedule.

    x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * eps, so t = 0
    keeps the clean latent (plus sigma_min-scaled noise) and t = 1 is pure noise.

    Args:
        x_0: Clean latent of any shape.
        t: Scalar or batch time value in [0, 1].
        sigma_min: Minimum noise level at t = 1.

    Returns:
        Noisy latent x_t of the same shape as `x_0`.
    """
    eps = torch.randn_like(x_0)

    # Broadcast t over all non-batch dimensions of x_0.
    bshape = (-1,) + (1,) * (x_0.ndim - 1)
    t_b = t.reshape(bshape)

    signal_coef = 1.0 - t_b
    noise_coef = sigma_min + (1.0 - sigma_min) * t_b
    return signal_coef * x_0 + noise_coef * eps
|
|