| | """Contains linear algebra related utility functions. |
| | |
| | For licensing see accompanying LICENSE file. |
| | Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from scipy.spatial.transform import Rotation |
| |
|
| |
|
def rotation_matrices_from_quaternions(quaternions: torch.Tensor) -> torch.Tensor:
    """Convert a batch of quaternions into rotation matrices.

    Quaternions are read in scalar-first order: index 0 of the last dimension
    is the real part, indices 1..3 are the vector part. They do not need to be
    unit length; they are normalized before conversion.

    Args:
        quaternions: The quaternions to convert, with shape ``(..., 4)``.

    Returns:
        The rotation matrices corresponding to the normalized quaternions,
        with shape ``(..., 3, 3)``.

    Raises:
        ValueError: If the last dimension of ``quaternions`` is not 4.
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(f"quaternions have invalid shape {quaternions.shape}")

    # Only unit quaternions describe rotations. A zero quaternion divides
    # 0 by 0 here and therefore yields NaNs.
    quaternions = quaternions / torch.linalg.norm(quaternions, dim=-1, keepdim=True)
    real_part = quaternions[..., 0, None, None]
    vector_part = quaternions[..., 1:]

    vector_cross = get_cross_product_matrix(vector_part)
    # Build the identity in the input dtype and let broadcasting expand it
    # over the batch, so the result is not promoted to the default dtype.
    identity = torch.eye(3, dtype=quaternions.dtype, device=quaternions.device)

    # R = v v^T + w^2 I + 2 w [v]_x + [v]_x^2, for the unit quaternion (w, v).
    matrix_outer = vector_part[..., :, None] * vector_part[..., None, :]
    matrix_diag = real_part.square() * identity
    matrix_cross_1 = 2 * real_part * vector_cross
    matrix_cross_2 = vector_cross @ vector_cross

    return matrix_outer + matrix_diag + matrix_cross_1 + matrix_cross_2
| |
|
| |
|
| | def quaternions_from_rotation_matrices(matrices: torch.Tensor) -> torch.Tensor: |
| | """Convert batch of rotation matrices to quaternions. |
| | |
| | Args: |
| | matrices: The matrices to convert to quaternions. |
| | |
| | Returns: |
| | The quaternions corresponding to the rotation matrices. |
| | |
| | Note: this operation is not differentiable and will be performed on the CPU. |
| | """ |
| | if not matrices.shape[-2:] == (3, 3): |
| | raise ValueError(f"matrices have invalid shape {matrices.shape}") |
| | matrices_np = matrices.detach().cpu().numpy() |
| | quaternions_np = Rotation.from_matrix(matrices_np.reshape(-1, 3, 3)).as_quat() |
| | |
| | quaternions_np = quaternions_np[:, [3, 0, 1, 2]] |
| | quaternions_np = quaternions_np.reshape(matrices_np.shape[:-2] + (4,)) |
| | return torch.as_tensor(quaternions_np, device=matrices.device, dtype=matrices.dtype) |
| |
|
| |
|
def get_cross_product_matrix(vectors: torch.Tensor) -> torch.Tensor:
    """Generate the skew-symmetric cross product matrix of each vector.

    For a vector ``v`` the returned matrix ``M`` satisfies
    ``M @ w == cross(v, w)`` for every 3-vector ``w``.

    Args:
        vectors: The vectors to convert, with shape ``(..., 3)``.

    Returns:
        The cross product matrices with shape ``(..., 3, 3)``, in the dtype and
        on the device of ``vectors``.

    Raises:
        ValueError: If the last dimension of ``vectors`` is not 3.
    """
    if vectors.shape[-1] != 3:
        raise ValueError("Only 3-dimensional vectors are supported")

    x, y, z = vectors.unbind(dim=-1)
    zero = torch.zeros_like(x)
    # Row-major entries of [[0, -z, y], [z, 0, -x], [-y, x, 0]]. Building them
    # from the components keeps the input dtype instead of mixing in a
    # default-dtype identity basis.
    entries = [zero, -z, y, z, zero, -x, -y, x, zero]
    return torch.stack(entries, dim=-1).reshape((*vectors.shape[:-1], 3, 3))
| |
|
| |
|
| | def eyes( |
| | dim: int, shape: tuple[int, ...], device: torch.device | str | None = None |
| | ) -> torch.Tensor: |
| | """Create batch of identity matrices.""" |
| | return torch.eye(dim, device=device).broadcast_to(shape + (dim, dim)).clone() |
| |
|
| |
|
def quaternion_product(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Compute the Hamilton product ``q1 * q2`` of two quaternions.

    Quaternions are read in scalar-first order along the last dimension. The
    leading (batch) dimensions of the two inputs are broadcast against each
    other.

    Args:
        q1: The left factors, with shape ``(..., 4)``.
        q2: The right factors, with shape ``(..., 4)``.

    Returns:
        The products with shape ``(..., 4)`` in scalar-first order.
    """
    real_1, vector_1 = q1[..., :1], q1[..., 1:]
    real_2, vector_2 = q2[..., :1], q2[..., 1:]
    # Align the batch dimensions up front so the cross product sees operands
    # of identical shape.
    vector_1, vector_2 = torch.broadcast_tensors(vector_1, vector_2)

    real_out = real_1 * real_2 - (vector_1 * vector_2).sum(dim=-1, keepdim=True)
    # dim must be explicit: without it torch.cross uses the first dimension of
    # size 3, which is the batch dimension for a batch of three quaternions.
    vector_out = (
        real_1 * vector_2
        + real_2 * vector_1
        + torch.cross(vector_1, vector_2, dim=-1)
    )
    return torch.cat([real_out, vector_out], dim=-1)
| |
|
| |
|
def quaternion_conj(q):
    """Return the conjugate of each quaternion in ``q``.

    The first component along the last dimension (the real part) is kept and
    all remaining components (the vector part) are negated.
    """
    scalar_part, imaginary_part = q[..., :1], q[..., 1:]
    negated = -imaginary_part
    return torch.cat((scalar_part, negated), dim=-1)
| |
|
| |
|
def project(u: torch.Tensor, basis: torch.Tensor) -> torch.Tensor:
    """Scale ``u`` by the inner product of its unit direction with ``basis``.

    ``u`` is normalized along the last dimension, the inner product of that
    unit vector with ``basis`` is taken over the last dimension, and ``u``
    itself (not its normalized version) is multiplied by the result.

    NOTE(review): the output is parallel to ``u`` rather than ``basis``, and
    it equals the orthogonal projection of ``basis`` onto ``u`` only when
    ``u`` has unit length -- confirm this is what callers expect.
    """
    direction = F.normalize(u, dim=-1)
    coefficient = torch.sum(direction * basis, dim=-1, keepdim=True)
    return coefficient * u
| |
|