"""
Utility functions for geometric operations (torch only).
"""
import torch
def rots_mul_vecs(m, v):
    """Apply (batched) 3x3 rotation matrices ``m`` to (batched) 3-vectors ``v``.

    ``m`` is indexed as ``m[..., i, j]`` with ``i, j`` in ``0..2`` and ``v`` as
    ``v[..., j]``; leading dimensions are combined by ordinary broadcasting.
    Component ``i`` of the result is ``sum_j m[..., i, j] * v[..., j]`` and the
    three components are stacked along a new last dimension.
    """
    vx, vy, vz = v[..., 0], v[..., 1], v[..., 2]
    # One output component per matrix row; written element-wise (no matmul)
    # so that m and v only need broadcast-compatible leading dimensions.
    components = [
        m[..., row, 0] * vx + m[..., row, 1] * vy + m[..., row, 2] * vz
        for row in range(3)
    ]
    return torch.stack(components, dim=-1)
def distance(p, eps=1e-10):
    """Euclidean distance between the two points stored along ``dim=-2`` of ``p``.

    ``p`` is read as ``[*, 2, 3]`` (point index, then coordinates). ``eps`` is
    added to the squared distance before the square root, so the result for
    coincident points is ``eps ** 0.5`` rather than exactly zero.
    """
    start = p[..., 0, :]
    end = p[..., 1, :]
    squared_length = torch.sum((start - end) ** 2, dim=-1)  # [*]
    return (eps + squared_length) ** 0.5
def dihedral(p, eps=1e-10):
    """Encode the dihedral angle of four points (along ``dim=-2``) as (cos, sin).

    ``p`` is read as ``[*, 4, 3]``. The three consecutive bond vectors are
    formed, then the normals of the two planes they span. The returned tensor
    has shape ``[*, 2]`` and holds ``[cos, sin]`` of the angle between those
    planes. ``eps`` is added under every square root so that degenerate
    (zero-length) vectors do not cause a division by zero.
    """
    def _safe_norm(x):
        # Length along the last dim, regularised by eps.
        return (eps + torch.sum(x ** 2, dim=-1)) ** 0.5

    # Consecutive bond vectors, each [*, 3]
    b1 = p[..., 1, :] - p[..., 0, :]
    b2 = p[..., 2, :] - p[..., 1, :]
    b3 = p[..., 3, :] - p[..., 2, :]

    # Normals of the planes (b1, b2) and (b2, b3), each [*, 3]
    n1 = torch.cross(b1, b2, dim=-1)
    n2 = torch.cross(b2, b3, dim=-1)

    # Regularised lengths, each [*]
    b2_len = _safe_norm(b2)
    n1_len = _safe_norm(n1)
    n2_len = _safe_norm(n2)

    # cos: normalised dot product of the two plane normals
    cos_enc = torch.einsum('...d,...d->...', n1, n2) / (n1_len * n2_len)
    # sin: projection of (n1 x n2) on the central bond, normalised
    n1xn2 = torch.cross(n1, n2, dim=-1)
    sin_enc = torch.einsum('...d,...d->...', b2, n1xn2) / (b2_len * n1_len * n2_len)

    return torch.stack([cos_enc, sin_enc], dim=-1)
def calc_distogram(pos: torch.Tensor, min_bin: float, max_bin: float, num_bins: int):
    """Bin all pairwise distances of ``pos`` into a one-hot style distogram.

    Args:
        pos: ``[*, L, 3]`` coordinates.
        min_bin: lower edge of the first bin.
        max_bin: lower edge of the last bin.
        num_bins: number of bins; lower edges are ``linspace(min_bin, max_bin, num_bins)``.

    Returns:
        Tensor ``[*, L, L, num_bins]`` with the dtype of ``pos``. Entry ``k`` is
        1 when the pair distance lies strictly between lower edge ``k`` and lower
        edge ``k + 1``; the last bin is bounded above by ``1e8``. Both
        comparisons are strict, so a distance that is exactly on an edge, or
        not greater than ``min_bin``, is assigned to no bin.
    """
    # Pairwise displacement vectors [*, L, L, 3] -> distances [*, L, L, 1]
    displacement = pos[..., :, None, :] - pos[..., None, :, :]
    dists_2d = torch.linalg.norm(displacement, dim=-1).unsqueeze(-1)

    # Bin edges live on the same device as the input.
    lower = torch.linspace(min_bin, max_bin, num_bins, device=pos.device)
    open_ended = lower.new_tensor([1e8])  # upper edge of the last bin
    upper = torch.cat([lower[1:], open_ended], dim=-1)

    in_bin = (dists_2d > lower) & (dists_2d < upper)
    return in_bin.type(pos.dtype)
def rmsd(xyz1, xyz2):
    """Shorthand for ``squared_deviation(xyz1, xyz2, 'rmsd')``.

    Both arguments are forwarded unchanged; the result is whatever
    ``squared_deviation`` returns for the ``'rmsd'`` reduction.
    """
    return squared_deviation(xyz1, xyz2, reduction='rmsd')
def squared_deviation(xyz1, xyz2, reduction='none'):
    """Squared point-wise deviation between two point clouds after alignment.

    ``xyz1`` is superimposed onto ``xyz2`` with the transform returned by
    ``_find_rigid_alignment`` and the squared distance between corresponding
    points is computed.

    Args:
        xyz1: (*, L, 3), to be transformed. Tensor, or anything accepted by
            ``torch.as_tensor``.
        xyz2: (*, L, 3), the reference. Tensor, or anything accepted by
            ``torch.as_tensor``.
        reduction: ``'none'`` returns the per-point squared deviation;
            ``'rmsd'`` returns the root of its mean over the ``L`` points.

    Returns:
        rmsd: (*, ) or none: (*, L). A NumPy array when ``xyz1`` was not a
        tensor, otherwise a tensor.

    Raises:
        NotImplementedError: if ``reduction`` is neither ``'none'`` nor ``'rmsd'``.
    """
    # Reject an unknown reduction before doing any work.
    if reduction not in ('none', 'rmsd'):
        raise NotImplementedError(f"Unsupported reduction: {reduction!r}")

    # The output container type follows xyz1; each input is converted on its
    # own so that a tensor/array mix does not reach matmul unconverted.
    map_to_np = not torch.is_tensor(xyz1)
    xyz1 = torch.as_tensor(xyz1)
    xyz2 = torch.as_tensor(xyz2)

    R, t = _find_rigid_alignment(xyz1, xyz2)  # (*, 3, 3), (*, 3)

    # Row-vector form of R @ x + t. The translation must be expanded on the
    # point axis (dim -2): unsqueeze(0) would only be right for unbatched input
    # and mis-broadcasts (*, 3) against (*, L, 3) when batch dims are present.
    xyz1_aligned = torch.matmul(xyz1, R.transpose(-2, -1)) + t.unsqueeze(-2)

    sd = ((xyz1_aligned - xyz2) ** 2).sum(dim=-1)  # (*, L)
    assert sd.shape == xyz1.shape[:-1]

    if reduction == 'rmsd':
        sd = torch.sqrt(sd.mean(dim=-1))  # (*, )

    return sd.numpy() if map_to_np else sd
def _find_rigid_alignment(src, tgt):
"""Inspired by https://research.pasteur.fr/en/member/guillaume-bouvier/;
https://gist.github.com/bougui505/e392a371f5bab095a3673ea6f4976cc8
See: https://en.wikipedia.org/wiki/Kabsch_algorithm
2-D or 3-D registration with known correspondences.
Registration occurs in the zero centered coordinate system, and then
must be transported back.
Args:
src: Torch tensor of shape (*, L, 3) -- Point Cloud to Align (source)
tgt: Torch tensor of shape (*, L, 3) -- Reference Point Cloud (target)
Returns:
R: optimal rotation (*, 3, 3)
t: optimal translation (*, 3)
Test on rotation + translation and on rotation + translation + reflection
>>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]], dtype=torch.float)
>>> R0 = torch.tensor([[np.cos(60), -np.sin(60)], [np.sin(60), np.cos(60)]], dtype=torch.float)
>>> B = (R0.mm(A.T)).T
>>> t0 = torch.tensor([3., 3.])
>>> B += t0
>>> R, t = find_rigid_alignment(A, B)
>>> A_aligned = (R.mm(A.T)).T + t
>>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean())
>>> rmsd
tensor(3.7064e-07)
>>> B *= torch.tensor([-1., 1.])
>>> R, t = find_rigid_alignment(A, B)
>>> A_aligned = (R.mm(A.T)).T + t
>>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean())
>>> rmsd
tensor(3.7064e-07)
"""
assert src.shape[-2] > 1
src_com = src.mean(dim=-2, keepdim=True)
tgt_com = tgt.mean(dim=-2, keepdim=True)
src_centered = src - src_com
tgt_centered = tgt - tgt_com
# Covariance matrix
# H = src_centered.transpose(-2,-1).bmm(tgt_centered) # *, 3, 3
H = torch.matmul(src_centered.transpose(-2,-1), tgt_centered)
U, S, V = torch.svd(H)
# Rotation matrix
# R = V.bmm(U.transpose(-2,-1))
R = torch.matmul(V, U.transpose(-2, -1))
# Translation vector
# t = tgt_com - R.bmm(src_com.transpose(-2,-1)).transpose(-2,-1)
t = tgt_com - torch.matmul(R, src_com.transpose(-2, -1)).transpose(-2, -1)
return R, t.squeeze(-2) # (B, 3, 3), (B, 3)
|