| | import logging |
| | import os |
| | from typing import Callable, Dict, Optional, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from tqdm import tqdm |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def scale_rotmat(
    rotation_matrix: torch.Tensor, scalar: torch.Tensor, tol: float = 1e-7
) -> torch.Tensor:
    """
    Scale a rotation by rescaling its rotation vector.

    The matrices are mapped to rotation vectors, the vectors are multiplied by `scalar`
    and the result is mapped back to rotation matrices.

    Args:
        rotation_matrix: Rotation matrices.
        scalar: Scaling factors. Must have exactly one dimension fewer than
            `rotation_matrix` so that it broadcasts against the rotation vectors.
        tol: Numerical offset forwarded to `rotvec_to_rotmat`.

    Returns:
        Scaled rotation matrices.
    """
    # Rotation vectors have one dimension fewer than the matrices; the scalar has to match.
    assert scalar.ndim == rotation_matrix.ndim - 1
    rotation_vector = rotmat_to_rotvec(rotation_matrix)
    scaled_vector = rotation_vector * scalar
    return rotvec_to_rotmat(scaled_vector, tol=tol)
| |
|
| |
|
| | def _broadcast_identity(target: torch.Tensor) -> torch.Tensor: |
| | """ |
| | Generate a 3 by 3 identity matrix and broadcast it to a batch of target matrices. |
| | |
| | Args: |
| | target (torch.Tensor): Batch of target 3 by 3 matrices. |
| | |
| | Returns: |
| | torch.Tensor: 3 by 3 identity matrices in the shapes of the target. |
| | """ |
| | id3 = torch.eye(3, device=target.device, dtype=target.dtype) |
| | id3 = torch.broadcast_to(id3, target.shape) |
| | return id3 |
| |
|
| |
|
def skew_matrix_exponential_map_axis_angle(
    angles: torch.Tensor, skew_matrices: torch.Tensor
) -> torch.Tensor:
    r"""
    Compute the matrix exponential of a rotation in axis-angle representation with the axis in
    skew matrix representation form. Maps the rotation from the Lie algebra so(3) to the rotation
    matrix representation. Uses Rodrigues' formula instead of `torch.linalg.matrix_exp` for
    better computational performance:

    .. math::

        \exp(\theta \mathbf{K}) = \mathbf{I} + \sin(\theta) \mathbf{K}
            + [1 - \cos(\theta)] \mathbf{K}^2

    Args:
        angles (torch.Tensor): Rotation angles. Any number of leading (batch) dimensions is
            accepted, including none.
        skew_matrices (torch.Tensor): Rotation axes in skew matrix (so(3)) basis, with two more
            trailing dimensions (3 by 3) than `angles`.

    Returns:
        torch.Tensor: Corresponding rotation matrices.
    """
    # A plain 3x3 identity is enough: it broadcasts against the batched terms below.
    id3 = torch.eye(3, device=skew_matrices.device, dtype=skew_matrices.dtype)

    # Add two trailing singleton dimensions so the angles broadcast over the matrix entries.
    angles = angles[..., None, None]

    # K^2; the ellipsis-only pattern also covers inputs without any batch dimension.
    skew_sq = torch.einsum("...ik,...kj->...ij", skew_matrices, skew_matrices)

    return id3 + torch.sin(angles) * skew_matrices + (1.0 - torch.cos(angles)) * skew_sq
| |
|
| |
|
def skew_matrix_exponential_map(
    angles: torch.Tensor, skew_matrices: torch.Tensor, tol: float = 1e-7
) -> torch.Tensor:
    r"""
    Compute the matrix exponential of a rotation vector in skew matrix representation. Maps the
    rotation from the Lie algebra so(3) to the rotation matrix representation. Uses the following
    form of Rodrigues' formula instead of `torch.linalg.matrix_exp` for better computational
    performance (in this case the skew matrix already contains the angle factor):

    .. math::

        \exp(\mathbf{K}) = \mathbf{I} + \frac{\sin(\theta)}{\theta} \mathbf{K}
            + \frac{1-\cos(\theta)}{\theta^2} \mathbf{K}^2

    This form has the advantage that Taylor expansions can be used for small angles (instead of
    having to compute the unit length axis by dividing the rotation vector by small angles):

    .. math::

        \frac{\sin(\theta)}{\theta} \approx 1 - \frac{\theta^2}{6}, \qquad
        \frac{1-\cos(\theta)}{\theta^2} \approx \frac{1}{2} - \frac{\theta^2}{24}

    Args:
        angles (torch.Tensor): Rotation angles (norms of the rotation vectors).
        skew_matrices (torch.Tensor): Rotation vectors in skew matrix (so(3)) basis, with two
            more trailing dimensions (3 by 3) than `angles`.
        tol (float, optional): Angles with absolute value below this threshold use the Taylor
            expansions. Defaults to 1e-7.

    Returns:
        torch.Tensor: Corresponding rotation matrices.
    """
    # A plain 3x3 identity is enough: it broadcasts against the batched terms below.
    id3 = torch.eye(3, device=skew_matrices.device, dtype=skew_matrices.dtype)

    # Add two trailing singleton dimensions so the angles broadcast over the matrix entries.
    angles = angles[..., None, None]
    angles_sq = angles.square()

    mask_zero = torch.abs(angles) < tol

    # Replace (near-)zero angles by 1 before dividing. The closed-form branch is discarded for
    # those entries anyway, but evaluating 0/0 there would inject NaN into the gradients.
    safe_angles = torch.where(mask_zero, torch.ones_like(angles), angles)
    sin_coeff = torch.sin(safe_angles) / safe_angles
    cos_coeff = (1.0 - torch.cos(safe_angles)) / safe_angles.square()

    # Truncated Taylor expansions used for small angles.
    sin_coeff_small = 1.0 - angles_sq / 6.0
    cos_coeff_small = 0.5 - angles_sq / 24.0

    sin_coeff = torch.where(mask_zero, sin_coeff_small, sin_coeff)
    cos_coeff = torch.where(mask_zero, cos_coeff_small, cos_coeff)

    # K^2; the ellipsis-only pattern also covers inputs without any batch dimension.
    skew_sq = torch.einsum("...ik,...kj->...ij", skew_matrices, skew_matrices)

    return id3 + sin_coeff * skew_matrices + cos_coeff * skew_sq
| |
|
| |
|
def rotvec_to_rotmat(rotation_vectors: torch.Tensor, tol: float = 1e-7) -> torch.Tensor:
    """
    Map rotation vectors to rotation matrices via the exponential map.

    The norm of a rotation vector is the rotation angle and its direction is the rotation axis.

    Args:
        rotation_vectors (torch.Tensor): Batch of rotation vectors.
        tol: Small-angle threshold forwarded to `skew_matrix_exponential_map`.

    Returns:
        torch.Tensor: Rotations in rotation matrix representation.
    """
    # The angle is the vector length; the skew matrix keeps the angle factor.
    angles = torch.norm(rotation_vectors, dim=-1)
    skew = vector_to_skew_matrix(rotation_vectors)
    return skew_matrix_exponential_map(angles, skew, tol=tol)
| |
|
| |
|
def rotmat_to_rotvec(rotation_matrices: torch.Tensor) -> torch.Tensor:
    r"""
    Convert a batch of rotation matrices to rotation vectors (logarithmic map from SO(3) to so(3)).
    The standard logarithmic map can be derived from Rodrigues' formula via Taylor approximation
    (in this case operating on the vector coefficients of the skew so(3) basis).

    .. math::

        \left[\log(\mathbf{R})\right]^\lor = \frac{\theta}{2\sin(\theta)}
            \left[\mathbf{R} - \mathbf{R}^\top\right]^\lor

    This formula has problems at 1) angles theta close or equal to zero and 2) at angles close and
    equal to pi.

    To improve numerical stability for case 1), the angle term at small or zero angles is
    approximated by its truncated Taylor expansion:

    .. math::

        \left[\log(\mathbf{R})\right]^\lor \approx \frac{1}{2} (1 + \frac{\theta^2}{6})
            \left[\mathbf{R} - \mathbf{R}^\top\right]^\lor

    For angles close or equal to pi (case 2), the outer product relation can be used to obtain the
    squared rotation vector:

    .. math:: \omega \otimes \omega = \frac{1}{2}(\mathbf{I} + \mathbf{R})

    Taking the root of the diagonal elements recovers the normalized rotation vector up to the
    signs of the components. The latter can be obtained from the off-diagonal elements.

    Adapted from https://github.com/jasonkyuyim/se3_diffusion/blob/2cba9e09fdc58112126a0441493b42022c62bbea/data/so3_utils.py
    which was adapted from https://github.com/geomstats/geomstats/blob/master/geomstats/geometry/special_orthogonal.py
    with heavy help from https://cvg.cit.tum.de/_media/members/demmeln/nurlanov2021so3log.pdf

    Args:
        rotation_matrices (torch.Tensor): Input batch of rotation matrices.

    Returns:
        torch.Tensor: Batch of rotation vectors.
    """
    # Rotation angles and their sines.
    angles, angles_sin, _ = angle_from_rotmat(rotation_matrices)
    # Vector part of the antisymmetric matrix R - R^T.
    vector = skew_matrix_to_vector(rotation_matrices - rotation_matrices.transpose(-2, -1))

    # Float masks (0/1) selecting the three regimes: zero angle, angle near pi, everything else.
    mask_zero = torch.isclose(angles, torch.zeros_like(angles)).to(angles.dtype)
    mask_pi = torch.isclose(angles, torch.full_like(angles, np.pi), atol=1e-2).to(angles.dtype)
    mask_else = (1 - mask_zero) * (1 - mask_pi)

    # Prefactor of the antisymmetric part:
    #   zero regime: 1 / (2 * (1 - theta^2 / 6)), the small-angle form of theta / (2 sin(theta))
    #   else regime: theta / (2 sin(theta))
    #   pi regime:   0 / 1, i.e. the antisymmetric part is dropped and replaced further below.
    numerator = mask_zero / 2.0 + angles * mask_else
    denominator = (
        (1.0 - angles**2 / 6.0) * mask_zero
        + 2.0 * angles_sin * mask_else
        + mask_pi
    )
    prefactor = numerator / denominator
    vector = vector * prefactor[..., None]

    # Near pi: use the outer product relation (I + R) / 2.
    id3 = _broadcast_identity(rotation_matrices)
    skew_outer = (id3 + rotation_matrices) / 2.0
    # Clamp negative diagonal entries to zero (off-diagonals untouched) so the root is defined.
    skew_outer = skew_outer + (torch.relu(skew_outer) - skew_outer) * id3

    # Magnitudes of the normalized axis components.
    vector_pi = torch.sqrt(torch.diagonal(skew_outer, dim1=-2, dim2=-1))

    # Take the signs from the row of the outer product with the largest norm.
    signs_line_idx = torch.argmax(torch.norm(skew_outer, dim=-1), dim=-1).long()
    signs_line = torch.take_along_dim(skew_outer, dim=-2, indices=signs_line_idx[..., None, None])
    signs_line = signs_line.squeeze(-2)
    signs = torch.sign(signs_line)

    # Scale the signed unit axis by the angle to obtain the rotation vector.
    vector_pi = vector_pi * angles[..., None] * signs

    # Only entries in the pi regime receive the outer-product solution.
    vector = vector + vector_pi * mask_pi[..., None]

    return vector
| |
|
| |
|
def angle_from_rotmat(
    rotation_matrices: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Extract the rotation angles encoded by rotation matrices, together with their sines and
    cosines. The angle is obtained with atan2 for better numerical stability at small angles.

    Args:
        rotation_matrices (torch.Tensor): Batch of rotation matrices.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Angles, sines of the angles and cosines
        of the angles.
    """
    # Sine: half the norm of the vector part of R - R^T.
    antisymmetric = rotation_matrices - rotation_matrices.transpose(-2, -1)
    sin_values = torch.norm(skew_matrix_to_vector(antisymmetric), dim=-1) / 2.0

    # Cosine: (trace(R) - 1) / 2.
    cos_values = (torch.einsum("...ii", rotation_matrices) - 1.0) / 2.0

    return torch.atan2(sin_values, cos_values), sin_values, cos_values
| |
|
| |
|
def vector_to_skew_matrix(vectors: torch.Tensor) -> torch.Tensor:
    """
    Embed 3-vectors into the skew-symmetric so(3) matrix basis.
    ```
                [  0 -z  y]
    [x,y,z] ->  [  z  0 -x]
                [ -y  x  0]
    ```

    Args:
        vectors (torch.Tensor): Batch of vectors to be mapped to skew matrices.

    Returns:
        torch.Tensor: Vectors in skew matrix representation.
    """
    x, y, z = vectors[..., 0], vectors[..., 1], vectors[..., 2]
    zero = torch.zeros_like(x)

    # Assemble the matrix row by row; negated entries are written as `zero - value`.
    row_0 = torch.stack((zero, zero - z, y), dim=-1)
    row_1 = torch.stack((z, zero, zero - x), dim=-1)
    row_2 = torch.stack((zero - y, x, zero), dim=-1)
    return torch.stack((row_0, row_1, row_2), dim=-2)
| |
|
| |
|
def skew_matrix_to_vector(skew_matrices: torch.Tensor) -> torch.Tensor:
    """
    Read the vector coefficients out of matrices in the so(3) skew basis.

    Args:
        skew_matrices (torch.Tensor): Skew matrices.

    Returns:
        torch.Tensor: Vectors `[m[2, 1], m[0, 2], m[1, 0]]` for each matrix `m`.
    """
    x = skew_matrices[..., 2, 1]
    y = skew_matrices[..., 0, 2]
    z = skew_matrices[..., 1, 0]
    return torch.stack((x, y, z), dim=-1)
| |
|
| |
|
| | def _rotquat_to_axis_angle( |
| | rotation_quaternions: torch.Tensor, tol: float = 1e-7 |
| | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | """ |
| | Auxiliary routine for computing rotation angle and rotation axis from unit quaternions. To avoid |
| | complications, rotations vectors with angles below `tol` are set to zero. |
| | |
| | Args: |
| | rotation_quaternions (torch.Tensor): Rotation quaternions in [r, i, j, k] format. |
| | tol (float, optional): Threshold for small rotations. Defaults to 1e-7. |
| | |
| | Returns: |
| | Tuple[torch.Tensor, torch.Tensor]: Rotation angles and axes. |
| | """ |
| | |
| | rotation_axes = rotation_quaternions[..., 1:] |
| | rotation_axes_norms = torch.norm(rotation_axes, dim=-1) |
| |
|
| | |
| | rotation_angles = 2.0 * torch.atan2(rotation_axes_norms, rotation_quaternions[..., 0]) |
| |
|
| | |
| | rotation_axes = rotation_axes / (rotation_axes_norms[:, None] + tol) |
| | return rotation_angles, rotation_axes |
| |
|
| |
|
def rotquat_to_rotvec(rotation_quaternions: torch.Tensor) -> torch.Tensor:
    """
    Turn unit quaternions into rotation vectors (axis scaled by angle).

    Args:
        rotation_quaternions (torch.Tensor): Input quaternions in [r,i,j,k] format.

    Returns:
        torch.Tensor: Rotation vectors.
    """
    angles, axes = _rotquat_to_axis_angle(rotation_quaternions)
    return axes * angles[..., None]
| |
|
| |
|
def rotquat_to_rotmat(rotation_quaternions: torch.Tensor) -> torch.Tensor:
    """
    Turn unit quaternions into rotation matrices.

    The quaternions are decomposed into angle and axis, the resulting rotation vector is
    written in the skew matrix basis and mapped to a matrix with the exponential map.

    Args:
        rotation_quaternions (torch.Tensor): Input quaternions in [r,i,j,k] format.

    Returns:
        torch.Tensor: Rotation matrices.
    """
    angles, axes = _rotquat_to_axis_angle(rotation_quaternions)
    rotation_vectors = axes * angles[..., None]
    skew = vector_to_skew_matrix(rotation_vectors)
    return skew_matrix_exponential_map(angles, skew)
| |
|
| |
|
def apply_rotvec_to_rotmat(
    rotation_matrices: torch.Tensor,
    rotation_vectors: torch.Tensor,
    tol: float = 1e-7,
) -> torch.Tensor:
    """
    Compose rotation matrices with the rotations described by rotation vectors.

    The rotation vectors are converted to matrices and multiplied onto `rotation_matrices`
    from the right.

    Args:
        rotation_matrices: Input batch of rotation matrices.
        rotation_vectors: Input batch of rotation vectors.
        tol: Small offset for numerical stability, forwarded to `rotvec_to_rotmat`.

    Returns:
        Updated rotation matrices.
    """
    update = rotvec_to_rotmat(rotation_vectors, tol=tol)
    # Right-multiplication: R_new = R @ R_update.
    return torch.einsum("...ij,...jk->...ik", rotation_matrices, update)
| |
|
| |
|
def rotmat_to_skew_matrix(mat: torch.Tensor) -> torch.Tensor:
    """
    Map rotation matrices to the skew matrices of their rotation vectors.

    Args:
        mat (torch.Tensor): Batch of rotation matrices.

    Returns:
        torch.Tensor: Skew matrices in the shapes of mat.
    """
    rotation_vectors = rotmat_to_rotvec(mat)
    skew = vector_to_skew_matrix(rotation_vectors)
    return skew
| |
|
| |
|
def skew_matrix_to_rotmat(skew: torch.Tensor) -> torch.Tensor:
    """
    Map skew matrices to the rotation matrices of their rotation vectors.

    Args:
        skew (torch.Tensor): Batch of 3 by 3 skew symmetric matrices.

    Returns:
        torch.Tensor: Rotation matrices in the shapes of skew.
    """
    rotation_vectors = skew_matrix_to_vector(skew)
    rotation_matrices = rotvec_to_rotmat(rotation_vectors)
    return rotation_matrices
| |
|
| |
|
def local_log(point: torch.Tensor, base_point: torch.Tensor) -> torch.Tensor:
    """
    Matrix logarithm. Computes the left-invariant vector field bringing base_point to point
    on the manifold. Follows the signature of geomstats' equivalent function.
    https://geomstats.github.io/api/geometry.html#geomstats.geometry.lie_group.MatrixLieGroup.log

    Args:
        point (torch.Tensor): Batch of rotation matrices to compute vector field at.
        base_point (torch.Tensor): Transport coordinates to take matrix logarithm.

    Returns:
        torch.Tensor: Skew matrix that holds the vector field (in the tangent space).
    """
    # Relative rotation base_point^T @ point, then its logarithm in the skew basis.
    relative_rotation = rot_mult(rot_transpose(base_point), point)
    return rotmat_to_skew_matrix(relative_rotation)
| |
|
| |
|
def multidim_trace(mat: torch.Tensor) -> torch.Tensor:
    """Sum the diagonal of the last two dimensions, keeping all leading dimensions."""
    trace = torch.einsum("...ii->...", mat)
    return trace
| |
|
| |
|
def geodesic_dist(mat_1: torch.Tensor, mat_2: torch.Tensor) -> torch.Tensor:
    """
    Geodesic distance between two rotation matrices.

    Computed as `sqrt(trace(A A^T))`, where `A` is the skew matrix logarithm of the relative
    rotation `mat_1^T mat_2`.

    Args:
        mat_1 (torch.Tensor): First rotation matrix.
        mat_2 (torch.Tensor): Second rotation matrix.

    Returns:
        Scalar for the geodesic distance between mat_1 and mat_2 with the same
        leading (i.e. batch) dimensions.
    """
    relative_rotation = rot_mult(rot_transpose(mat_1), mat_2)
    log_skew = rotmat_to_skew_matrix(relative_rotation)
    gram = rot_mult(log_skew, rot_transpose(log_skew))
    return torch.sqrt(multidim_trace(gram))
| |
|
| |
|
def rot_transpose(mat: torch.Tensor) -> torch.Tensor:
    """Swap the last two dimensions of `mat`."""
    return mat.transpose(-1, -2)
| |
|
| |
|
def rot_mult(mat_1: torch.Tensor, mat_2: torch.Tensor) -> torch.Tensor:
    """Batched matrix product `mat_1 @ mat_2` over the last two dimensions."""
    product = torch.einsum("...ij,...jk->...ik", mat_1, mat_2)
    return product
| |
|
| |
|
def calc_rot_vf(mat_t: torch.Tensor, mat_1: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the vector field Log_{mat_t}(mat_1).

    Args:
        mat_t (torch.Tensor): base point to compute vector field at.
        mat_1 (torch.Tensor): target rotation.

    Returns:
        Rotation vector representing the vector field.
    """
    # Relative rotation mat_t^T @ mat_1, expressed as a rotation vector.
    relative_rotation = rot_mult(rot_transpose(mat_t), mat_1)
    return rotmat_to_rotvec(relative_rotation)
| |
|
| |
|
def geodesic_t(t: float, mat: torch.Tensor, base_mat: torch.Tensor, rot_vf=None) -> torch.Tensor:
    """
    Evaluate the geodesic at time t, i.e. R_t = Exp_{base_mat}(t * Log_{base_mat}(mat)).

    Args:
        t: time along geodesic.
        mat: target points on manifold.
        base_mat: source point on manifold.
        rot_vf: optional precomputed rotation vector field; when None it is computed with
            `calc_rot_vf(base_mat, mat)`.

    Returns:
        Point along geodesic starting at base_mat and ending at mat.

    Raises:
        ValueError: If the scaled rotation does not have the same shape as `base_mat`.
    """
    vector_field = calc_rot_vf(base_mat, mat) if rot_vf is None else rot_vf
    mat_t = rotvec_to_rotmat(t * vector_field)

    shapes_match = base_mat.shape == mat_t.shape
    if not shapes_match:
        raise ValueError(
            f'Incompatible shapes: base_mat={base_mat.shape}, mat_t={mat_t.shape}')

    # Apply the scaled rotation to the base point from the right.
    return torch.einsum("...ij,...jk->...ik", base_mat, mat_t)
| |
|
| |
|
class SO3LookupCache:
    def __init__(
        self,
        cache_dir: str,
        cache_file: str,
        overwrite: bool = False,
    ) -> None:
        """
        Auxiliary class for handling storage / loading of SO(3) lookup tables in npz format.

        Args:
            cache_dir: Path to the cache directory.
            cache_file: Basic file name of the cache file.
            overwrite: Whether existing cache files should be overwritten if requested.

        Raises:
            ValueError: If `cache_file` does not end with '.npz'.
        """
        if not cache_file.endswith(".npz"):
            raise ValueError("Filename should have '.npz' extension.")
        self.cache_file = cache_file
        self.cache_dir = cache_dir
        self.cache_path = os.path.join(cache_dir, cache_file)
        self.overwrite = overwrite

    @property
    def path_exists(self) -> bool:
        """Whether the cache file exists."""
        return os.path.exists(self.cache_path)

    @property
    def dir_exists(self) -> bool:
        """Whether the cache directory exists."""
        return os.path.exists(self.cache_dir)

    def delete_cache(self) -> None:
        """
        Delete the cache file (no-op if it does not exist).
        """
        if self.path_exists:
            os.remove(self.cache_path)

    def load_cache(self) -> Dict[str, torch.Tensor]:
        """
        Load data from the cache file.

        Returns:
            Dictionary of loaded data tensors.

        Raises:
            ValueError: If no cache file exists at the cache path.
        """
        if not self.path_exists:
            raise ValueError(f"No cache data found at {self.cache_path}.")

        # The context manager closes the underlying file handle once all arrays are read;
        # otherwise the archive stays open until garbage collection.
        with np.load(self.cache_path) as npz_data:
            torch_dict = {f: torch.from_numpy(npz_data[f]) for f in npz_data.files}
        logger.info(f"Data loaded from {self.cache_path}")
        return torch_dict

    def save_cache(self, data: Dict[str, torch.Tensor]) -> None:
        """
        Save a dictionary of tensors to the cache file. If overwrite is set to True, an existing
        file is replaced by the new data, otherwise a warning is logged and the file is not
        modified.

        Args:
            data: Dictionary of tensors that should be saved to the cache.
        """
        # exist_ok avoids failing when the directory appears between the check and the call.
        if not self.dir_exists:
            os.makedirs(self.cache_dir, exist_ok=True)

        if self.path_exists:
            if not self.overwrite:
                logger.warning(
                    f"Cache at {self.cache_path} exits and overwriting disabled. Doing nothing."
                )
                return
            logger.info("Overwriting cache ...")
            self.delete_cache()

        # Move everything to CPU / numpy and store. This runs both for a fresh cache and after
        # an existing one has been removed for overwriting.
        numpy_dict = {k: v.detach().cpu().numpy() for k, v in data.items()}
        np.savez(self.cache_path, **numpy_dict)
        logger.info(f"Data saved to {self.cache_path}")
| |
|
| |
|
class BaseSampleSO3(nn.Module):
    # Identifier used as part of the cache file name.
    so3_type: str = "base"

    def __init__(
        self,
        num_omega: int,
        sigma_grid: torch.Tensor,
        omega_exponent: int = 3,
        tol: float = 1e-7,
        interpolate: bool = True,
        cache_dir: Optional[str] = None,
        overwrite_cache: bool = False,
        device: str = 'cpu',
    ) -> None:
        """
        Base torch.nn module for sampling rotations from the IGSO(3) distribution. Samples are
        created by uniformly sampling a rotation axis and using inverse transform sampling for
        the angles. The latter uses the associated SO(3) cumulative probability distribution
        function (CDF) and a uniform distribution [0,1] as described in [#leach2022_1]_. CDF values
        are obtained by numerically integrating the probability distribution evaluated on a grid of
        angles and noise levels and stored in a lookup table. Linear interpolation is used to
        approximate continuous sampling of the function. Angles are discretized in an interval
        [0,pi] and the grid can be squashed to have higher resolutions at low angles by taking
        different powers. Since sampling relies on tabulated values of the CDF and indexing in the
        form of `torch.bucketize`, gradients are not supported.

        Args:
            num_omega (int): Number of discrete angles used for generating the lookup table.
            sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
            omega_exponent (int, optional): Make the angle grid denser for smaller angles by taking
              its power with the provided number. Defaults to 3.
            tol (float, optional): Small value for numerical stability. Defaults to 1e-7.
            interpolate (bool, optional): If enabled, perform linear interpolation of the angle CDF
              to sample angles. Otherwise the tabulated angle at the upper bracket of the sampled
              CDF value is returned. Defaults to True.
            cache_dir: Path to an optional cache directory. If set to None, lookup tables are
              computed on the fly.
            overwrite_cache: If set to true, existing cache files are overwritten. Can be used for
              updating stale caches.
            device (str, optional): Device used for computing the lookup table when CUDA is
              available. Defaults to 'cpu'.

        References
        ----------
        .. [#leach2022_1] Leach, Schmon, Degiacomi, Willcocks:
           Denoising diffusion probabilistic models on so (3) for rotational alignment.
           ICLR 2022 Workshop on Geometrical and Topological Representation Learning. 2022.
        """
        super().__init__()
        self.num_omega = num_omega
        self.omega_exponent = omega_exponent
        self.tol = tol
        self.interpolate = interpolate
        self.device = device
        self.register_buffer("sigma_grid", sigma_grid, persistent=False)

        # Generate / load the lookup tables and store them as non-persistent buffers.
        omega_grid, cdf_igso3 = self._setup_lookup(sigma_grid, cache_dir, overwrite_cache)
        self.register_buffer("omega_grid", omega_grid, persistent=False)
        self.register_buffer("cdf_igso3", cdf_igso3, persistent=False)

    def _setup_lookup(
        self,
        sigma_grid: torch.Tensor,
        cache_dir: Optional[str] = None,
        overwrite_cache: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Master function for setting up the lookup tables. These can either be loaded from a npz
        cache file or computed on the fly. The returned tables are moved to the device and cast
        to the dtype of `sigma_grid`.

        Args:
            sigma_grid: Grid of sigma values used for computing the lookup tables.
            cache_dir: Path to the cache directory.
            overwrite_cache: If set to true, an existing cache is overwritten. Can be used for
              updating stale caches.

        Returns:
            Grid of angle values and SO(3) cumulative distribution function.
        """
        if cache_dir is not None:
            cache_name = self._get_cache_name()
            cache = SO3LookupCache(cache_dir, cache_name, overwrite=True)

            # Only reuse an existing cache when overwriting was not requested.
            if cache.path_exists and not overwrite_cache:
                cache_data = cache.load_cache()
                omega_grid = cache_data["omega_grid"]
                cdf_igso3 = cache_data["cdf_igso3"]
            else:
                omega_grid, cdf_igso3 = self._generate_lookup(sigma_grid)
                cache.save_cache({"omega_grid": omega_grid, "cdf_igso3": cdf_igso3})
        else:
            omega_grid, cdf_igso3 = self._generate_lookup(sigma_grid)

        # Tables loaded from disk live on the CPU; place them next to `sigma_grid` so all
        # buffers share one device, exactly as for freshly generated tables.
        omega_grid = omega_grid.to(device=sigma_grid.device, dtype=sigma_grid.dtype)
        cdf_igso3 = cdf_igso3.to(device=sigma_grid.device, dtype=sigma_grid.dtype)
        return omega_grid, cdf_igso3

    def _get_cache_name(self) -> str:
        """
        Auxiliary function for determining the cache file name based on the parameters used for
        generating the lookup tables (type, min / max / number of sigmas, number of angles and
        angle exponent).

        Returns:
            Base name of the cache file.
        """
        cache_name = "cache_{:s}_s{:04.3f}-{:04.3f}-{:d}_o{:d}-{:d}.npz".format(
            self.so3_type,
            torch.min(self.sigma_grid).cpu().item(),
            torch.max(self.sigma_grid).cpu().item(),
            self.sigma_grid.shape[0],
            self.num_omega,
            self.omega_exponent,
        )
        return cache_name

    def get_sigma_idx(self, sigma: torch.Tensor) -> torch.Tensor:
        """
        Convert continuous sigmas to indices into the tabulated sigma grid using
        `torch.bucketize`, i.e. the index of the first grid value that is not smaller than the
        given sigma.

        Args:
            sigma (torch.Tensor): IGSO3 std devs.

        Returns:
            torch.Tensor: Index tensor mapping the provided sigma values to the internal lookup
              table.
        """
        return torch.bucketize(sigma, self.sigma_grid)

    def expansion_function(
        self, omega_grid: torch.Tensor, sigma_grid: torch.Tensor
    ) -> torch.Tensor:
        """
        Function for generating the angle probability distribution. Should return a 2D tensor with
        values for the std dev at the first dimension (rows) and angles at the second
        (columns). Must be implemented by subclasses.

        Args:
            omega_grid (torch.Tensor): Grid of angle values.
            sigma_grid (torch.Tensor): IGSO3 std devs.

        Returns:
            torch.Tensor: Distribution for angles discretized on a 2D grid.
        """
        raise NotImplementedError

    @torch.no_grad()
    def _generate_lookup(self, sigma_grid: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate the lookup table for sampling from the target SO(3) CDF. The table is 2D, with the
        rows corresponding to different sigma values and the columns with angles computed on a grid.
        Variance is scaled by a factor of 1/2 to account for the deacceleration of time in the
        diffusion process due to the choice of SO(3) basis and guarantee time-reversibility (see
        appendix E.3 in [#yim2023_2]_). The tables are computed in double precision and cast to
        the dtype of `sigma_grid` before being returned.

        Args:
            sigma_grid (torch.Tensor): Grid of IGSO3 std devs.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Tuple containing the grid used to compute the angles
              (without the leading zero angle) and the associated lookup table.

        References
        ----------
        .. [#yim2023_2] Yim, Trippe, De Bortoli, Mathieu, Doucet, Barzilay, Jaakkola:
           SE(3) diffusion model with application to protein backbone generation.
           arXiv preprint arXiv:2302.02277. 2023.
        """
        current_device = sigma_grid.device
        sigma_grid_tmp = sigma_grid.to(torch.float64)

        # Compute on the configured device when CUDA is available; results are moved back below.
        if torch.cuda.is_available():
            sigma_grid_tmp = sigma_grid_tmp.to(device=self.device)

        # Unit grid with one extra point, matched to the dtype / device of the sigma grid.
        omega_grid = torch.linspace(0.0, 1, self.num_omega + 1).to(sigma_grid_tmp)

        # Taking a power makes the grid denser at small angles.
        omega_grid = omega_grid**self.omega_exponent

        # Rescale to [0, pi].
        omega_grid = omega_grid * np.pi

        # Evaluate the angle distribution on the grid.
        pdf_igso3 = self.expansion_function(omega_grid, sigma_grid_tmp)

        # Weight with (1 - cos(omega)) / pi.
        pdf_igso3 = pdf_igso3 * (1.0 - torch.cos(omega_grid)) / np.pi

        # Integrate to the CDF and normalize each row by its last entry.
        cdf_igso3 = integrate_trapezoid_cumulative(pdf_igso3, omega_grid)
        cdf_igso3 = cdf_igso3 / cdf_igso3[:, -1][:, None]

        # Move back to the device of the input grid.
        cdf_igso3 = cdf_igso3.to(device=current_device)
        omega_grid = omega_grid.to(device=current_device)

        return omega_grid[1:].to(sigma_grid.dtype), cdf_igso3.to(sigma_grid.dtype)

    def sample(self, sigma: torch.Tensor, num_samples: int) -> torch.Tensor:
        """
        Generate samples from the target SO(3) distribution by sampling a rotation axis and angle,
        which are then combined into a rotation vector and transformed into the corresponding
        rotation matrix via an exponential map.

        Args:
            sigma (torch.Tensor): IGSO3 std devs for which to take samples.
            num_samples (int): Number of angle samples to take for each std dev

        Returns:
            torch.Tensor: Sampled rotations in matrix representation with dimensions
              [num_sigma x num_samples x 3 x 3].
        """
        vectors = self.sample_vector(sigma.shape[0], num_samples)
        angles = self.sample_angle(sigma, num_samples)

        # Hook for subclasses to post-process the sampled angles.
        angles = self._process_angles(sigma, angles)

        rotation_vectors = vectors * angles[..., None]

        rotation_matrices = rotvec_to_rotmat(rotation_vectors, tol=self.tol)
        return rotation_matrices

    def _process_angles(self, sigma: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
        """
        Auxiliary function for performing additional processing steps on the sampled angles. One
        example would be to ensure sampled angles are 0 for a std dev of 0 for IGSO(3). The base
        implementation returns the angles unchanged.

        Args:
            sigma (torch.Tensor): Current values of sigma.
            angles (torch.Tensor): Sampled angles.

        Returns:
            torch.Tensor: Processed sampled angles.
        """
        return angles

    def sample_vector(self, num_sigma: int, num_samples: int) -> torch.Tensor:
        """
        Uniformly sample rotation axes for constructing the overall rotation (normalized
        standard normal vectors).

        Args:
            num_sigma (int): Number of std devs for which samples are drawn.
            num_samples (int): Number of samples to take for each std dev.

        Returns:
            torch.Tensor: Batch of rotation axes with dimensions [num_sigma x num_samples x 3].
        """
        vectors = torch.randn(num_sigma, num_samples, 3, device=self.sigma_grid.device)
        vectors = vectors / torch.norm(vectors, dim=2, keepdim=True)
        return vectors

    def sample_angle(self, sigma: torch.Tensor, num_samples: int) -> torch.Tensor:
        """
        Create a series of samples from the IGSO(3) angle distribution.

        Args:
            sigma (torch.Tensor): IGSO3 std devs for which to take samples.
            num_samples (int): Number of angle samples to take for each std dev.

        Returns:
            torch.Tensor: Collected samples, will have the dimension [num_sigma x num_samples].
        """
        # Rows of the CDF table belonging to the requested sigmas.
        sigma_indices = self.get_sigma_idx(sigma)
        cdf_tmp = self.cdf_igso3[sigma_indices, :]

        # Uniform samples for inverse transform sampling.
        p_uniform = torch.rand((*sigma_indices.shape, num_samples), device=sigma_indices.device)

        # Number of CDF entries below each sample = index of the upper bracket. Clamp to the
        # last table entry so the lookups below can never run past the end of the grid.
        idx_stop = torch.sum(cdf_tmp[..., None] < p_uniform[:, None, :], dim=1).long()
        idx_stop = torch.clamp(idx_stop, max=cdf_tmp.shape[-1] - 1)
        idx_start = torch.clamp(idx_stop - 1, min=0)

        if not self.interpolate:
            # Return the tabulated ANGLE at the bracket (not the CDF value stored there).
            return self.omega_grid[idx_stop]

        # CDF values at both ends of the bracket.
        cdf_start = torch.gather(cdf_tmp, dim=1, index=idx_start)
        cdf_stop = torch.gather(cdf_tmp, dim=1, index=idx_stop)

        # Interpolation weight; the clamps guard against vanishing bracket widths.
        cdf_delta = torch.clamp(cdf_stop - cdf_start, min=self.tol)
        cdf_weight = torch.clamp((p_uniform - cdf_start) / cdf_delta, min=0.0, max=1.0)

        # Angles at both ends of the bracket.
        omega_start = self.omega_grid[idx_start]
        omega_stop = self.omega_grid[idx_stop]

        # Linear interpolation between the bracketing angles.
        omega = torch.lerp(omega_start, omega_stop, cdf_weight)

        return omega
| |
|
| |
|
class SampleIGSO3(BaseSampleSO3):
    # Identifier used as part of the cache file name.
    so3_type = "igso3"

    def __init__(
        self,
        num_omega: int,
        sigma_grid: torch.Tensor,
        omega_exponent: int = 3,
        tol: float = 1e-7,
        interpolate: bool = True,
        l_max: int = 1000,
        cache_dir: Optional[str] = None,
        overwrite_cache: bool = False,
        device: str = "cpu",
    ) -> None:
        """
        Module for sampling rotations from the IGSO(3) distribution using the explicit series
        expansion. Samples are created using inverse transform sampling based on the associated
        cumulative probability distribution function (CDF) and a uniform distribution [0,1] as
        described in [#leach2022_2]_. CDF values are obtained by numerically integrating the
        probability distribution evaluated on a grid of angles and noise levels and stored in a
        lookup table. Linear interpolation is used to approximate continuous sampling of the
        function. Angles are discretized in an interval [0,pi] and the grid can be squashed to
        have higher resolutions at low angles by taking different powers.
        Since sampling relies on tabulated values of the CDF and on index lookups, gradients are
        not supported.

        Args:
            num_omega (int): Number of discrete angles used for generating the lookup table.
            sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
            omega_exponent (int, optional): Make the angle grid denser for smaller angles by
              taking its power with the provided number. Defaults to 3.
            tol (float, optional): Small value for numerical stability. Defaults to 1e-7.
            interpolate (bool, optional): If enabled, perform linear interpolation of the angle
              CDF to sample angles. Otherwise the closest tabulated point is returned. Defaults
              to True.
            l_max (int, optional): Maximum number of terms used in the series expansion.
            cache_dir: Path to an optional cache directory. If set to None, lookup tables are
              computed on the fly.
            overwrite_cache: If set to true, existing cache files are overwritten. Can be used
              for updating stale caches.
            device: Passed on unchanged to the base class constructor. Defaults to 'cpu'.

        References
        ----------
        .. [#leach2022_2] Leach, Schmon, Degiacomi, Willcocks:
           Denoising diffusion probabilistic models on so (3) for rotational alignment.
           ICLR 2022 Workshop on Geometrical and Topological Representation Learning. 2022.
        """
        # Stored before delegating to the base class because `_get_cache_name` and
        # `expansion_function` below read `self.l_max`.
        self.l_max = l_max
        super().__init__(
            num_omega=num_omega,
            sigma_grid=sigma_grid,
            omega_exponent=omega_exponent,
            tol=tol,
            interpolate=interpolate,
            cache_dir=cache_dir,
            overwrite_cache=overwrite_cache,
            device=device,
        )

    def _get_cache_name(self) -> str:
        """
        Build the cache file name from the parameters that determine the lookup table: the
        distribution type, smallest and largest sigma, number of sigmas, l_max, number of angles
        and the angle grid exponent.

        Returns:
            Base name of the cache file.
        """
        sigma_lo = self.sigma_grid.min().cpu().item()
        sigma_hi = self.sigma_grid.max().cpu().item()
        num_sigma = self.sigma_grid.shape[0]
        return (
            f"cache_{self.so3_type:s}_s{sigma_lo:04.3f}-{sigma_hi:04.3f}-{num_sigma:d}"
            f"_l{self.l_max:d}_o{self.num_omega:d}-{self.omega_exponent:d}.npz"
        )

    def expansion_function(
        self,
        omega_grid: torch.Tensor,
        sigma_grid: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluate the truncated IGSO(3) series expansion on the given grids to obtain the lookup
        table.

        Args:
            omega_grid (torch.Tensor): Grid of angle values.
            sigma_grid (torch.Tensor): Grid of IGSO3 std devs.

        Returns:
            torch.Tensor: IGSO(3) distribution for angles discretized on a 2D grid.
        """
        table = generate_igso3_lookup_table(
            omega_grid, sigma_grid, l_max=self.l_max, tol=self.tol
        )
        return table

    def _process_angles(self, sigma: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
        """
        Set sampled angles to 0 wherever the noise level is below the tolerance. (For such small
        sigma the series expansion gives a uniform probability distribution.)

        Args:
            sigma (torch.Tensor): Current values of sigma.
            angles (torch.Tensor): Sampled angles.

        Returns:
            torch.Tensor: Processed sampled angles.
        """
        # Trailing singleton dimension lets the per-sigma mask broadcast over the samples.
        vanishing_noise = sigma[..., None] < self.tol
        return torch.where(vanishing_noise, torch.zeros_like(angles), angles)
| |
|
| |
|
class SampleUSO3(BaseSampleSO3):
    # Identifier used as part of the cache file name.
    so3_type = "uso3"

    def __init__(
        self,
        num_omega: int,
        sigma_grid: torch.Tensor,
        omega_exponent: int = 3,
        tol: float = 1e-7,
        interpolate: bool = True,
        cache_dir: Optional[str] = None,
        overwrite_cache: bool = False,
        device: Optional[str] = None,
    ) -> None:
        """
        Module for sampling rotations from the USO(3) distribution. Can be used to generate
        initial unbiased samples in the reverse process. Samples are created using inverse
        transform sampling based on the associated cumulative probability distribution function
        (CDF) and a uniform distribution [0,1] as described in [#leach2022_4]_. CDF values are
        obtained by numerically integrating the probability distribution evaluated on a grid of
        angles and noise levels and stored in a lookup table. Linear interpolation is used to
        approximate continuous sampling of the function. Angles are discretized in an interval
        [0,pi] and the grid can be squashed to have higher resolutions at low angles by taking
        different powers.
        Since sampling relies on tabulated values of the CDF and on index lookups, gradients are
        not supported.

        Args:
            num_omega (int): Number of discrete angles used for generating the lookup table.
            sigma_grid (torch.Tensor): Grid of IGSO3 std devs.
            omega_exponent (int, optional): Make the angle grid denser for smaller angles by
              taking its power with the provided number. Defaults to 3.
            tol (float, optional): Small value for numerical stability. Defaults to 1e-7.
            interpolate (bool, optional): If enabled, perform linear interpolation of the angle
              CDF to sample angles. Otherwise the closest tabulated point is returned. Defaults
              to True.
            cache_dir: Path to an optional cache directory. If set to None, lookup tables are
              computed on the fly.
            overwrite_cache: If set to true, existing cache files are overwritten. Can be used
              for updating stale caches.
            device: Optional device forwarded to the base class constructor (as `SampleIGSO3`
              does). If left as None, nothing is forwarded and the base class default applies.

        References
        ----------
        .. [#leach2022_4] Leach, Schmon, Degiacomi, Willcocks:
           Denoising diffusion probabilistic models on so (3) for rotational alignment.
           ICLR 2022 Workshop on Geometrical and Topological Representation Learning. 2022.
        """
        # Only forward `device` when the caller asked for one, so existing call sites keep the
        # exact behaviour they had before this parameter existed.
        base_kwargs = {}
        if device is not None:
            base_kwargs["device"] = device

        super().__init__(
            num_omega=num_omega,
            sigma_grid=sigma_grid,
            omega_exponent=omega_exponent,
            tol=tol,
            interpolate=interpolate,
            cache_dir=cache_dir,
            overwrite_cache=overwrite_cache,
            **base_kwargs,
        )

    def get_sigma_idx(self, sigma: torch.Tensor) -> torch.Tensor:
        """
        Map every sigma to index 0, because the USO(3) lookup table built by
        `expansion_function` has a single row.

        Args:
            sigma (torch.Tensor): Values of sigma; only their shape and device are used.

        Returns:
            torch.Tensor: Integer tensor of zeros with the same shape as `sigma`.
        """
        return torch.zeros_like(sigma).long()

    def sample_shape(self, num_sigma: int, num_samples: int) -> torch.Tensor:
        """
        Draw samples by calling `self.sample` with a dummy all-zero sigma tensor.

        Args:
            num_sigma (int): Length of the dummy sigma tensor.
            num_samples (int): Number of samples passed on to `self.sample`.

        Returns:
            torch.Tensor: Whatever `self.sample` returns for the dummy sigmas.
        """
        dummy_sigma = torch.zeros(num_sigma, device=self.sigma_grid.device)
        return self.sample(dummy_sigma, num_samples)

    def expansion_function(
        self,
        omega_grid: torch.Tensor,
        sigma_grid: torch.Tensor,
    ) -> torch.Tensor:
        """
        The probability density function of the uniform SO(3) distribution is the cosine scaling
        term (1-cos(omega))/pi which is applied automatically during sampling. This means, it is
        sufficient to return a tensor of ones to create the correct USO(3) lookup table.

        Args:
            omega_grid (torch.Tensor): Grid of angle values.
            sigma_grid (torch.Tensor): Grid of IGSO3 std devs (not used here).

        Returns:
            torch.Tensor: Tensor of ones with shape [1 x num_omega].
        """
        return torch.ones(1, omega_grid.shape[0], device=omega_grid.device)
| |
|
| |
|
@torch.no_grad()
def integrate_trapezoid_cumulative(f_grid: torch.Tensor, x_grid: torch.Tensor) -> torch.Tensor:
    """
    Cumulative numerical integration of a discretized function with the trapezoid rule along
    the last dimension. This is mainly used for computing the cumulative probability
    distributions for sampling from the IGSO(3) distribution.

    Args:
        f_grid (torch.Tensor): Discretized function values.
        x_grid (torch.Tensor): Discretized input values.

    Returns:
        torch.Tensor: Integrated function (not normalized). The last dimension is one shorter
        than that of the inputs. Because the grid spacings get a leading dimension of size one
        before broadcasting, a 1D `f_grid` produces a result of shape [1 x (n - 1)].
    """
    # Interval widths with a leading singleton dimension for broadcasting against f_grid.
    widths = torch.diff(x_grid, dim=-1).unsqueeze(0)
    # Sum of the function values at both ends of every interval.
    edge_sums = f_grid[..., 1:] + f_grid[..., :-1]
    # Area of every trapezoid, accumulated from the first interval onwards.
    return torch.cumsum(edge_sums * widths / 2.0, dim=-1)
| |
|
| |
|
def uniform_so3_density(omega: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the angle density of the uniform distribution on SO(3), (1 - cos(omega)) / pi.

    Args:
        omega: Angles in radians.

    Returns:
        Uniform distribution density, same shape as `omega`.
    """
    one_minus_cos = 1.0 - omega.cos()
    return one_minus_cos / np.pi
| |
|
| |
|
def igso3_expansion(
    omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7
) -> torch.Tensor:
    """
    Compute the IGSO(3) angle probability distribution function for pairs of angles and std dev
    levels. The expansion is computed using a grid of expansion orders ranging from 0 to l_max.

    This function approximates the power series in equation 5 of [#yim2023_3]_. With this
    parameterization, IGSO(3) agrees with the Brownian motion on SO(3) with t=sigma^2.

    Args:
        omega: Values of angles (1D tensor).
        sigma: Values of std dev of IGSO3 distribution (1D tensor of same shape as `omega`).
        l_grid: Tensor containing expansion orders (0 to l_max).
        tol: Small offset for numerical stability.

    Returns:
        IGSO(3) angle distribution function (without pre-factor for uniform SO(3) distribution).

    References
    ----------
    .. [#yim2023_3] Yim, Trippe, De Bortoli, Mathieu, Doucet, Barzilay, Jaakkola:
       SE(3) diffusion model with application to protein backbone generation.
       arXiv preprint arXiv:2302.02277. 2023.
    """
    # Per-order factors: (2l + 1) and the exponent coefficient -l(l + 1).
    degeneracy = 2.0 * l_grid + 1.0
    decay = -l_grid * (l_grid + 1.0)

    # Broadcast to [num_points x num_orders]: (2l + 1) * exp(-l(l + 1) * sigma^2 / 2).
    weights = degeneracy.unsqueeze(0) * torch.exp(
        decay.unsqueeze(0) * sigma.unsqueeze(1) ** 2 / 2
    )

    # Angle-dependent part of every term: sin((l + 1/2) * omega).
    angular = torch.sin((l_grid.unsqueeze(0) + 1 / 2) * omega.unsqueeze(1))

    # Sum over the orders and divide by sin(omega / 2); tol keeps the denominator non-zero.
    series = torch.sum(weights * angular, dim=1) / (torch.sin(0.5 * omega) + tol)

    # For omega -> 0 the ratio sin((l + 1/2) omega) / sin(omega / 2) tends to (2l + 1), so the
    # series is replaced by this limit for vanishing angles.
    small_angle_limit = torch.sum(weights * degeneracy.unsqueeze(0), dim=1)
    density = torch.where(omega <= tol, small_angle_limit, series)

    # Replace any remaining inf / nan entries by zero.
    return torch.where(torch.isfinite(density), density, torch.zeros_like(density))
| |
|
| |
|
def digso3_expansion(
    omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7
) -> torch.Tensor:
    r"""
    Compute the derivative of the IGSO(3) angle probability distribution function with respect
    to the angles for pairs of angles and std dev levels. As in `igso3_expansion` a grid is used
    for the expansion levels. Evaluates the derivative directly in order to avoid second
    derivatives during backpropagation.

    The derivative of the angle-dependent part is computed as:

    .. math ::
        \frac{\partial}{\partial \omega} \frac{\sin((l+\tfrac{1}{2})\omega)}{\sin(\tfrac{1}{2}\omega)} = \frac{l\sin((l+1)\omega) - (l+1)\sin(l\omega)}{1 - \cos(\omega)}

    (obtained via quotient rule + different trigonometric identities).

    Args:
        omega: Values of angles (1D tensor).
        sigma: Values of IGSO3 distribution std devs (1D tensor of same shape as `omega`).
        l_grid: Tensor containing expansion orders (0 to l_max).
        tol: Small offset for numerical stability.

    Returns:
        IGSO(3) angle distribution derivative (without pre-factor for uniform SO(3)
        distribution).
    """
    # Per-order factors: (l + 1), (2l + 1) and the exponent coefficient -l(l + 1).
    l_plus_one = l_grid + 1.0
    degeneracy = 2.0 * l_grid + 1.0
    decay = -l_grid * l_plus_one

    # Numerator of the differentiated ratio: l * sin((l + 1) omega) - (l + 1) * sin(l omega).
    angles = omega.unsqueeze(1)
    sin_upper = torch.sin(l_plus_one.unsqueeze(0) * angles)
    sin_lower = torch.sin(l_grid.unsqueeze(0) * angles)
    angular = l_grid.unsqueeze(0) * sin_upper - l_plus_one.unsqueeze(0) * sin_lower

    # Weight every order with (2l + 1) * exp(-l(l + 1) * sigma^2 / 2) and sum over the orders.
    weights = degeneracy.unsqueeze(0) * torch.exp(
        decay.unsqueeze(0) * sigma.unsqueeze(1) ** 2 / 2
    )
    derivative = torch.sum(weights * angular, dim=1)

    # Divide by (1 - cos(omega)); tol keeps the denominator non-zero.
    derivative = derivative / ((1.0 - torch.cos(omega)) + tol)

    # The derivative is set to zero for vanishing angles.
    derivative = torch.where(omega <= tol, torch.zeros_like(derivative), derivative)

    # Replace any remaining inf / nan entries by zero.
    return torch.where(torch.isfinite(derivative), derivative, torch.zeros_like(derivative))
| |
|
| |
|
def dlog_igso3_expansion(
    omega: torch.Tensor, sigma: torch.Tensor, l_grid: torch.Tensor, tol=1e-7
) -> torch.Tensor:
    r"""
    Compute the derivative of the logarithm of the IGSO(3) angle distribution function for pairs
    of angles and std dev levels:

    .. math ::
        \frac{\partial}{\partial \omega} \log f(\omega) = \frac{\tfrac{\partial}{\partial \omega} f(\omega)}{f(\omega)}

    Required for SO(3) score computation.

    Args:
        omega: Values of angles (1D tensor).
        sigma: Values of IGSO3 std devs (1D tensor of same shape as `omega`).
        l_grid: Tensor containing expansion orders (0 to l_max).
        tol: Small offset for numerical stability.

    Returns:
        Derivative of the logarithm of the IGSO(3) angle distribution (without pre-factor for
        uniform SO(3) distribution).
    """
    # Numerator: derivative of the density with respect to the angle.
    slope = digso3_expansion(omega, sigma, l_grid, tol=tol)
    # Denominator: the density itself, offset by tol to avoid division by zero.
    density = igso3_expansion(omega, sigma, l_grid, tol=tol)
    return torch.div(slope, density + tol)
| |
|
| |
|
@torch.no_grad()
def generate_lookup_table(
    base_function: Callable,
    omega_grid: torch.Tensor,
    sigma_grid: torch.Tensor,
    l_max: int = 1000,
    tol: float = 1e-7,
):
    """
    Auxiliary function for generating a lookup table from IGSO(3) expansions and their
    derivatives. Evaluates `base_function` on the full angle grid once per std dev level.

    Args:
        base_function: Function used for setting up the lookup table. It is called as
          `base_function(omega, sigma, l_grid, tol=tol)`.
        omega_grid: Grid of angle values ranging from [0,pi] (shape is[num_omega]).
        sigma_grid: Grid of IGSO3 std dev values (shape is [num_sigma]).
        l_max: Number of terms used in the series expansion.
        tol: Small value for numerical stability.

    Returns:
        Table of function values evaluated at different angles and std dev levels. The final
        shape is [num_sigma x num_omega].
    """
    # Expansion orders 0..l_max on the device and in the dtype of the angle grid.
    l_grid = torch.arange(l_max + 1, device=omega_grid.device).to(dtype=omega_grid.dtype)

    num_sigma = len(sigma_grid)
    num_omega = len(omega_grid)

    # Output table, filled one row (std dev level) at a time.
    table = torch.zeros(
        (num_sigma, num_omega), device=omega_grid.device, dtype=omega_grid.dtype
    )

    # Template used to expand a single sigma value to the shape of the angle grid.
    ones = torch.ones_like(omega_grid)

    progress = tqdm(range(num_sigma), desc=f"Computing {base_function.__name__}")
    for row in progress:
        sigma_row = ones * sigma_grid[row]
        table[row, :] = base_function(omega_grid, sigma_row, l_grid, tol=tol)

    return table
| |
|
| |
|
def generate_igso3_lookup_table(
    omega_grid: torch.Tensor,
    sigma_grid: torch.Tensor,
    l_max: int = 1000,
    tol: float = 1e-7,
) -> torch.Tensor:
    """
    Tabulate the IGSO(3) angle probability distribution function by evaluating
    `igso3_expansion` through `generate_lookup_table`.

    Args:
        omega_grid: Grid of angle values ranging from [0,pi] (shape is[num_omega]).
        sigma_grid: Grid of IGSO3 std dev values (shape is [num_sigma]).
        l_max: Number of terms used in the series expansion.
        tol: Small value for numerical stability.

    Returns:
        Table of function values evaluated at different angles and std dev levels. The final
        shape is [num_sigma x num_omega].
    """
    return generate_lookup_table(
        base_function=igso3_expansion,
        omega_grid=omega_grid,
        sigma_grid=sigma_grid,
        l_max=l_max,
        tol=tol,
    )
| |
|
| |
|
def generate_dlog_igso3_lookup_table(
    omega_grid: torch.Tensor,
    sigma_grid: torch.Tensor,
    l_max: int = 1000,
    tol: float = 1e-7,
) -> torch.Tensor:
    """
    Tabulate the derivative of the logarithm of the angular IGSO(3) probability distribution
    function by evaluating `dlog_igso3_expansion` through `generate_lookup_table`. Used e.g.
    for computing scaling of SO(3) norms.

    Args:
        omega_grid: Grid of angle values ranging from [0,pi] (shape is[num_omega]).
        sigma_grid: Grid of IGSO3 std dev values (shape is [num_sigma]).
        l_max: Number of terms used in the series expansion.
        tol: Small value for numerical stability.

    Returns:
        Table of function values evaluated at different angles and std dev levels. The final
        shape is [num_sigma x num_omega].
    """
    return generate_lookup_table(
        base_function=dlog_igso3_expansion,
        omega_grid=omega_grid,
        sigma_grid=sigma_grid,
        l_max=l_max,
        tol=tol,
    )
| |
|