|
|
import math
from typing import Optional, Tuple, Union

import torch
|
|
|
|
|
|
|
|
def _random_orthonormal_matrix(d: int, device: torch.device) -> torch.Tensor:
    """Draw a Haar-distributed random rotation matrix Q ∈ SO(d).

    A plain QR factorisation of a standard Gaussian matrix is *not*
    Haar-distributed: the QR routine chooses the signs of R's diagonal by its
    own convention, which biases Q. Multiplying each column of Q by the sign
    of the matching diagonal entry of R makes the factorisation unique (R with
    a positive diagonal), and the resulting Q is Haar-distributed on O(d)
    (Mezzadri, 2007). Negating one column when det(Q) < 0 then yields a
    Haar-distributed element of SO(d).

    Args:
        d: Size of the square matrix.
        device: Device on which the matrix is created.

    Returns:
        A ``(d, d)`` orthonormal tensor with determinant +1.
    """
    a = torch.randn(d, d, device=device)
    q, r = torch.linalg.qr(a, mode="reduced")

    # Sign correction: scale column j of Q by sign(R[j, j]) so that the
    # factorisation no longer depends on the QR routine's sign convention.
    # sign() is 0 only for an exactly-zero diagonal entry (probability zero
    # for Gaussian input); map that case to +1 so no column gets zeroed out.
    signs = torch.sign(torch.diagonal(r))
    signs = torch.where(signs == 0, torch.ones_like(signs), signs)
    q = q * signs  # (d, d) * (d,) broadcasts over rows -> per-column scaling

    # Map reflections (det = -1) to rotations so that Q ∈ SO(d).
    if torch.det(q) < 0:
        q[:, 0] = -q[:, 0]

    return q
|
|
|
|
|
|
|
|
def sobol_sphere(
    n: int,
    d: int,
    device: torch.device,
    sobol_engine: Optional[torch.quasirandom.SobolEngine] = None,
) -> Tuple[torch.Tensor, torch.quasirandom.SobolEngine]:
    """Draw ``n`` quasi-random unit vectors on the sphere S^{d-1}.

    Points of a scrambled Sobol sequence in (0, 1)^d are mapped through the
    inverse standard-normal CDF, normalised to unit length, and then rotated
    by a freshly drawn random rotation matrix.

    Args:
        n: Number of points to draw.
        d: Ambient dimension; the returned points lie on S^{d-1} in R^d.
        device: Device of the returned points (Sobol draws are moved here).
        sobol_engine: Optional Sobol engine to continue drawing from. If
            ``None``, a new scrambled engine of dimension ``d`` is created.

    Returns:
        A tuple ``(points, engine)``. ``points`` has shape ``(n, d)``.
        ``engine`` is the Sobol engine that was used, so callers can pass it
        back in to continue the same sequence.

    Raises:
        ValueError: If ``sobol_engine`` has a dimension other than ``d``.
    """
    if sobol_engine is None:
        sob = torch.quasirandom.SobolEngine(dimension=d, scramble=True)
    else:
        # Fail early with a clear message instead of a shape error in the
        # final matmul.
        if sobol_engine.dimension != d:
            raise ValueError(
                f"sobol_engine has dimension {sobol_engine.dimension}, "
                f"expected d={d}"
            )
        sob = sobol_engine

    u01 = sob.draw(n).to(device)

    # Keep values strictly inside (0, 1): erfinv(±1) = ±inf, which would turn
    # into NaN after normalisation.
    eps = 1e-7
    u01 = u01.clamp(min=eps, max=1.0 - eps)

    # Inverse standard-normal CDF: Phi^{-1}(u) = sqrt(2) * erfinv(2u - 1).
    z = torch.erfinv(2.0 * u01 - 1.0) * math.sqrt(2.0)
    # Project onto the unit sphere; the small constant guards against a
    # zero-norm row.
    z = z / (z.norm(dim=1, keepdim=True) + 1e-8)

    # NOTE(review): a new random rotation is drawn on every call, even when
    # sobol_engine is reused, so points from separate calls are not one
    # jointly rotated low-discrepancy set — confirm this is intended.
    Q = _random_orthonormal_matrix(d, device)
    return z @ Q.T, sob
|
|
|