File size: 5,338 Bytes
ca7299e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Utility functions for geometric operations (torch only).
"""
import torch


def rots_mul_vecs(m, v):
    """Apply (batched) 3x3 matrices ``m`` to 3-vectors ``v``, i.e. ``m @ v``.

    Args:
        m: tensor whose last two dims are 3x3 (e.g. rotations), shape (*, 3, 3).
        v: tensor whose last dim has size 3, shape (*, 3); leading dims
            broadcast against those of ``m``.

    Returns:
        Tensor of shape (*, 3) whose component i is sum_j m[..., i, j] * v[..., j].
    """
    # One output component per matrix row, written out explicitly so the
    # multiply/add order is fixed and leading dims broadcast elementwise.
    components = [
        m[..., row, 0] * v[..., 0] + m[..., row, 1] * v[..., 1] + m[..., row, 2] * v[..., 2]
        for row in range(3)
    ]
    return torch.stack(components, dim=-1)
  
def distance(p, eps=1e-10):
    """Euclidean distance between the two points stacked along dim -2.

    Args:
        p: tensor of shape (*, 2, 3) holding point pairs.
        eps: added under the square root so the root (and its gradient) is
            never evaluated at exactly zero when the points coincide.

    Returns:
        Tensor of shape (*,).
    """
    delta = p[..., 0, :] - p[..., 1, :]
    squared_length = (delta ** 2).sum(dim=-1)
    return (squared_length + eps) ** 0.5

def dihedral(p, eps=1e-10):
    """Encode the dihedral angle of four consecutive points as (cos, sin).

    Args:
        p: tensor of shape (*, 4, 3); the points are taken in order along dim -2.
        eps: added inside every square root so the norms (and their gradients)
            stay finite for degenerate (collinear or coincident) geometry.

    Returns:
        Tensor of shape (*, 2) holding [cos(phi), sin(phi)].
    """
    def _safe_norm(x):
        # [*, 3] -> [*]
        return (eps + torch.sum(x ** 2, dim=-1)) ** 0.5

    def _dot(a, b):
        # Batched dot product over the last dim: [*, 3], [*, 3] -> [*]
        return torch.einsum('...d,...d->...', a, b)

    # Bond vectors between consecutive points, each [*, 3].
    b1, b2, b3 = (p[..., k + 1, :] - p[..., k, :] for k in range(3))

    # Normals of the planes spanned by (b1, b2) and (b2, b3).
    n1 = torch.cross(b1, b2, dim=-1)
    n2 = torch.cross(b2, b3, dim=-1)

    b2_len = _safe_norm(b2)
    n1_len = _safe_norm(n1)
    n2_len = _safe_norm(n2)

    cos_phi = _dot(n1, n2) / (n1_len * n2_len)
    # n1 x n2 is parallel to b2; projecting it onto b2 yields the signed sine.
    sin_phi = _dot(b2, torch.cross(n1, n2, dim=-1)) / (b2_len * n1_len * n2_len)

    return torch.stack([cos_phi, sin_phi], dim=-1)

def calc_distogram(pos: torch.Tensor, min_bin: float, max_bin: float, num_bins: int):
    """One-hot histogram of pairwise distances between positions.

    Args:
        pos: tensor of shape (*, L, 3).
        min_bin: lower edge of the first bin.
        max_bin: lower edge of the last bin; the last bin's upper edge is 1e8.
        num_bins: number of bins; lower edges are evenly spaced (linspace).

    Returns:
        Tensor of shape (*, L, L, num_bins) with dtype ``pos.dtype``. Entry
        [..., i, j, k] is 1 when lower[k] < |pos_i - pos_j| < upper[k].
        Both comparisons are strict, so a distance lying exactly on a bin edge,
        or below ``min_bin`` (e.g. the zero diagonal when min_bin >= 0),
        yields an all-zero row.
    """
    # [*, L, 1, 3] - [*, 1, L, 3] -> [*, L, L, 3]
    pairwise_offsets = pos.unsqueeze(-2) - pos.unsqueeze(-3)
    dists = torch.linalg.norm(pairwise_offsets, dim=-1).unsqueeze(-1)  # [*, L, L, 1]

    bin_lo = torch.linspace(min_bin, max_bin, num_bins, device=pos.device)
    # Each bin's upper edge is the next bin's lower edge; the last is open-ended.
    bin_hi = torch.cat([bin_lo[1:], bin_lo.new_tensor([1e8])], dim=-1)

    in_bin = torch.logical_and(dists > bin_lo, dists < bin_hi)
    return in_bin.to(pos.dtype)

def rmsd(xyz1, xyz2):
    """Root-mean-square deviation of ``xyz1`` from ``xyz2`` after alignment.

    Shorthand for ``squared_deviation(xyz1, xyz2, 'rmsd')``; see that function
    for accepted shapes and the returned type.
    """
    return squared_deviation(xyz1, xyz2, reduction='rmsd')

def squared_deviation(xyz1, xyz2, reduction='none'):
    """Squared point-wise deviation between two point clouds after alignment.

    ``xyz1`` is superimposed onto ``xyz2`` using the (R, t) returned by
    ``_find_rigid_alignment``. The squared Euclidean distance of every
    corresponding point pair is then computed and optionally reduced to an RMSD.

    Args:
        xyz1: (*, L, 3) tensor or array-like, the cloud to be transformed.
        xyz2: (*, L, 3) tensor or array-like, the reference cloud.
        reduction: 'none' returns per-point squared deviations of shape (*, L);
            'rmsd' returns sqrt(mean over L) of shape (*,).

    Returns:
        A tensor, or a numpy array when ``xyz1`` was not a tensor.

    Raises:
        NotImplementedError: if ``reduction`` is neither 'none' nor 'rmsd'.
    """
    # Return the result in the same flavour (tensor / ndarray) as xyz1.
    map_to_np = not torch.is_tensor(xyz1)
    xyz1 = torch.as_tensor(xyz1)
    # Bring xyz2 to xyz1's dtype/device so mixed inputs (tensor + ndarray,
    # float32 + float64) do not fail inside the alignment's matmuls.
    # as_tensor is a no-op when xyz2 already matches.
    xyz2 = torch.as_tensor(xyz2, dtype=xyz1.dtype, device=xyz1.device)

    R, t = _find_rigid_alignment(xyz1, xyz2)  # R: (*, 3, 3), t: (*, 3)

    # Apply x -> R x + t to every point. t needs a point axis, giving (*, 1, 3),
    # so it broadcasts over L for any number of leading batch dims.
    # (unsqueeze(0) only worked for unbatched input.)
    xyz1_aligned = (
        torch.matmul(R, xyz1.transpose(-2, -1)).transpose(-2, -1) + t.unsqueeze(-2)
    )

    sd = ((xyz1_aligned - xyz2) ** 2).sum(dim=-1)  # (*, L)
    assert sd.shape == xyz1.shape[:-1]

    if reduction == 'rmsd':
        sd = torch.sqrt(sd.mean(dim=-1))
    elif reduction != 'none':
        raise NotImplementedError(f"Unsupported reduction: {reduction!r}")

    return sd.numpy() if map_to_np else sd

def _find_rigid_alignment(src, tgt):
    """Inspired by https://research.pasteur.fr/en/member/guillaume-bouvier/;
        https://gist.github.com/bougui505/e392a371f5bab095a3673ea6f4976cc8
    
    See: https://en.wikipedia.org/wiki/Kabsch_algorithm
    
    2-D or 3-D registration with known correspondences.
    Registration occurs in the zero centered coordinate system, and then
    must be transported back.
    
    Args:
        src: Torch tensor of shape (*, L, 3) -- Point Cloud to Align (source)
        tgt: Torch tensor of shape (*, L, 3) -- Reference Point Cloud (target)
    Returns:
        R: optimal rotation (*, 3, 3)
        t: optimal translation (*, 3)
        
    Test on rotation + translation and on rotation + translation + reflection
        >>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]], dtype=torch.float)
        >>> R0 = torch.tensor([[np.cos(60), -np.sin(60)], [np.sin(60), np.cos(60)]], dtype=torch.float)
        >>> B = (R0.mm(A.T)).T
        >>> t0 = torch.tensor([3., 3.])
        >>> B += t0
        >>> R, t = find_rigid_alignment(A, B)
        >>> A_aligned = (R.mm(A.T)).T + t
        >>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean())
        >>> rmsd
        tensor(3.7064e-07)
        >>> B *= torch.tensor([-1., 1.])
        >>> R, t = find_rigid_alignment(A, B)
        >>> A_aligned = (R.mm(A.T)).T + t
        >>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean())
        >>> rmsd
        tensor(3.7064e-07)
    """
    assert src.shape[-2] > 1
    src_com = src.mean(dim=-2, keepdim=True)
    tgt_com = tgt.mean(dim=-2, keepdim=True)
    src_centered = src - src_com
    tgt_centered = tgt - tgt_com

    # Covariance matrix

    # H = src_centered.transpose(-2,-1).bmm(tgt_centered)    # *, 3, 3
    H = torch.matmul(src_centered.transpose(-2,-1),  tgt_centered)

    U, S, V = torch.svd(H)
    # Rotation matrix

    # R = V.bmm(U.transpose(-2,-1))
    R = torch.matmul(V, U.transpose(-2, -1))

    # Translation vector

    # t = tgt_com - R.bmm(src_com.transpose(-2,-1)).transpose(-2,-1)
    t = tgt_com - torch.matmul(R, src_com.transpose(-2, -1)).transpose(-2, -1)

    return R, t.squeeze(-2) # (B, 3, 3), (B, 3)