P2DFlow / analysis /src /common /geo_utils.py
Holmes
test
ca7299e
"""
Utility functions for geometric operations (torch only).
"""
import torch
def rots_mul_vecs(m, v):
"""(Batch) Apply rotations 'm' to vectors 'v'."""
return torch.stack([
m[..., 0, 0] * v[..., 0] + m[..., 0, 1] * v[..., 1] + m[..., 0, 2] * v[..., 2],
m[..., 1, 0] * v[..., 0] + m[..., 1, 1] * v[..., 1] + m[..., 1, 2] * v[..., 2],
m[..., 2, 0] * v[..., 0] + m[..., 2, 1] * v[..., 1] + m[..., 2, 2] * v[..., 2],
], dim=-1)
def distance(p, eps=1e-10):
"""Calculate distance between a pair of points (dim=-2)."""
# [*, 2, 3]
return (eps + torch.sum((p[..., 0, :] - p[..., 1, :]) ** 2, dim=-1)) ** 0.5
def dihedral(p, eps=1e-10):
"""Calculate dihedral angle between a quadruple of points (dim=-2)."""
# p: [*, 4, 3]
# [*, 3]
u1 = p[..., 1, :] - p[..., 0, :]
u2 = p[..., 2, :] - p[..., 1, :]
u3 = p[..., 3, :] - p[..., 2, :]
# [*, 3]
u1xu2 = torch.cross(u1, u2, dim=-1)
u2xu3 = torch.cross(u2, u3, dim=-1)
# [*]
u2_norm = (eps + torch.sum(u2 ** 2, dim=-1)) ** 0.5
u1xu2_norm = (eps + torch.sum(u1xu2 ** 2, dim=-1)) ** 0.5
u2xu3_norm = (eps + torch.sum(u2xu3 ** 2, dim=-1)) ** 0.5
# [*]
cos_enc = torch.einsum('...d,...d->...', u1xu2, u2xu3)/ (u1xu2_norm * u2xu3_norm)
sin_enc = torch.einsum('...d,...d->...', u2, torch.cross(u1xu2, u2xu3, dim=-1)) / (u2_norm * u1xu2_norm * u2xu3_norm)
return torch.stack([cos_enc, sin_enc], dim=-1)
def calc_distogram(pos: torch.Tensor, min_bin: float, max_bin: float, num_bins: int):
# pos: [*, L, 3]
dists_2d = torch.linalg.norm(
pos[..., :, None, :] - pos[..., None, :, :], axis=-1
)[..., None]
lower = torch.linspace(
min_bin,
max_bin,
num_bins,
device=pos.device)
upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1)
distogram = ((dists_2d > lower) * (dists_2d < upper)).type(pos.dtype)
return distogram
def rmsd(xyz1, xyz2):
""" Abbreviation for squared_deviation(xyz1, xyz2, 'rmsd') """
return squared_deviation(xyz1, xyz2, 'rmsd')
def squared_deviation(xyz1, xyz2, reduction='none'):
"""Squared point-wise deviation between two point clouds after alignment.
Args:
xyz1: (*, L, 3), to be transformed
xyz2: (*, L, 3), the reference
Returns:
rmsd: (*, ) or none: (*, L)
"""
map_to_np = False
if not torch.is_tensor(xyz1):
map_to_np = True
xyz1 = torch.as_tensor(xyz1)
xyz2 = torch.as_tensor(xyz2)
R, t = _find_rigid_alignment(xyz1, xyz2)
# print(R.shape, t.shape) # B, 3, 3 & B, 3
# xyz1_aligned = (R.bmm(xyz1.transpose(-2,-1))).transpose(-2,-1) + t.unsqueeze(1)
xyz1_aligned = (torch.matmul(R, xyz1.transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0)
sd = ((xyz1_aligned - xyz2)**2).sum(dim=-1) # (*, L)
assert sd.shape == xyz1.shape[:-1]
if reduction == 'none':
pass
elif reduction == 'rmsd':
sd = torch.sqrt(sd.mean(dim=-1))
else:
raise NotImplementedError()
sd = sd.numpy() if map_to_np else sd
return sd
def _find_rigid_alignment(src, tgt):
"""Inspired by https://research.pasteur.fr/en/member/guillaume-bouvier/;
https://gist.github.com/bougui505/e392a371f5bab095a3673ea6f4976cc8
See: https://en.wikipedia.org/wiki/Kabsch_algorithm
2-D or 3-D registration with known correspondences.
Registration occurs in the zero centered coordinate system, and then
must be transported back.
Args:
src: Torch tensor of shape (*, L, 3) -- Point Cloud to Align (source)
tgt: Torch tensor of shape (*, L, 3) -- Reference Point Cloud (target)
Returns:
R: optimal rotation (*, 3, 3)
t: optimal translation (*, 3)
Test on rotation + translation and on rotation + translation + reflection
>>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]], dtype=torch.float)
>>> R0 = torch.tensor([[np.cos(60), -np.sin(60)], [np.sin(60), np.cos(60)]], dtype=torch.float)
>>> B = (R0.mm(A.T)).T
>>> t0 = torch.tensor([3., 3.])
>>> B += t0
>>> R, t = find_rigid_alignment(A, B)
>>> A_aligned = (R.mm(A.T)).T + t
>>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean())
>>> rmsd
tensor(3.7064e-07)
>>> B *= torch.tensor([-1., 1.])
>>> R, t = find_rigid_alignment(A, B)
>>> A_aligned = (R.mm(A.T)).T + t
>>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean())
>>> rmsd
tensor(3.7064e-07)
"""
assert src.shape[-2] > 1
src_com = src.mean(dim=-2, keepdim=True)
tgt_com = tgt.mean(dim=-2, keepdim=True)
src_centered = src - src_com
tgt_centered = tgt - tgt_com
# Covariance matrix
# H = src_centered.transpose(-2,-1).bmm(tgt_centered) # *, 3, 3
H = torch.matmul(src_centered.transpose(-2,-1), tgt_centered)
U, S, V = torch.svd(H)
# Rotation matrix
# R = V.bmm(U.transpose(-2,-1))
R = torch.matmul(V, U.transpose(-2, -1))
# Translation vector
# t = tgt_com - R.bmm(src_com.transpose(-2,-1)).transpose(-2,-1)
t = tgt_com - torch.matmul(R, src_com.transpose(-2, -1)).transpose(-2, -1)
return R, t.squeeze(-2) # (B, 3, 3), (B, 3)