jboth commited on
Commit
09ceeeb
·
verified ·
1 Parent(s): 6b31dc3

Upload pytorch3d_stub/pytorch3d/transforms/__init__.py with huggingface_hub

Browse files
pytorch3d_stub/pytorch3d/transforms/__init__.py CHANGED
@@ -55,38 +55,60 @@ def quaternion_invert(quaternion):
55
 
56
  class Transform3d:
57
  def __init__(self, dtype=torch.float32, device="cpu", matrix=None):
58
- self.device = device
59
- self.dtype = dtype
60
  if matrix is not None:
61
  self._matrix = matrix.to(device=device, dtype=dtype)
 
 
62
  else:
63
  self._matrix = torch.eye(4, dtype=dtype, device=device).unsqueeze(0)
 
 
 
64
  def get_matrix(self):
65
  return self._matrix
 
66
  def compose(self, *others):
67
  m = self._matrix
68
  for o in others:
69
- m = m @ o.get_matrix()
70
- return Transform3d(matrix=m, device=self.device, dtype=self.dtype)
 
 
 
 
 
 
 
71
  def transform_points(self, points):
72
  if points.dim() == 2:
73
  points = points.unsqueeze(0)
74
- ones = torch.ones(*points.shape[:-1], 1, dtype=points.dtype, device=points.device)
 
 
75
  pts4 = torch.cat([points, ones], dim=-1)
76
- out = torch.bmm(pts4, self._matrix.expand(pts4.shape[0], -1, -1).transpose(-2, -1))
 
77
  return out[..., :3]
 
78
  def translate(self, x, y=None, z=None):
 
 
79
  if isinstance(x, torch.Tensor) and x.dim() >= 1 and x.shape[-1] == 3:
80
- t = x
81
  else:
82
- t = torch.tensor([[x, y, z]], dtype=self.dtype, device=self.device)
83
- if t.dim() == 1: t = t.unsqueeze(0)
84
- T = torch.eye(4, dtype=self.dtype, device=self.device).unsqueeze(0).expand(t.shape[0], -1, -1).clone()
 
85
  T[:, :3, 3] = t
86
  new_m = self._matrix @ T
87
- return Transform3d(matrix=new_m, device=self.device, dtype=self.dtype)
 
88
  def scale(self, x, y=None, z=None):
 
 
89
  if isinstance(x, torch.Tensor):
 
90
  if x.dim() == 0:
91
  s = x.expand(3)
92
  elif x.shape[-1] == 3:
@@ -96,23 +118,49 @@ class Transform3d:
96
  else:
97
  if y is None: y = x
98
  if z is None: z = x
99
- s = torch.tensor([x, y, z], dtype=self.dtype, device=self.device)
100
- if s.dim() == 1: s = s.unsqueeze(0)
101
- S = torch.eye(4, dtype=self.dtype, device=self.device).unsqueeze(0).expand(s.shape[0], -1, -1).clone()
 
102
  S[:, 0, 0] = s[:, 0]; S[:, 1, 1] = s[:, 1]; S[:, 2, 2] = s[:, 2]
103
  new_m = self._matrix @ S
104
- return Transform3d(matrix=new_m, device=self.device, dtype=self.dtype)
 
105
  def to(self, device=None, dtype=None):
106
- if device is not None: self.device = device
107
- if dtype is not None: self.dtype = dtype
108
  self._matrix = self._matrix.to(device=device, dtype=dtype)
 
 
 
109
  return self
 
110
  def inverse(self):
111
  inv_m = torch.inverse(self._matrix)
112
- return Transform3d(matrix=inv_m, device=self.device, dtype=self.dtype)
 
113
  def rotate(self, R):
114
- if R.dim() == 2: R = R.unsqueeze(0)
115
- T = torch.eye(4, dtype=self.dtype, device=self.device).unsqueeze(0).expand(R.shape[0], -1, -1).clone()
 
 
 
 
116
  T[:, :3, :3] = R
117
  new_m = self._matrix @ T
118
- return Transform3d(matrix=new_m, device=self.device, dtype=self.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  class Transform3d:
57
  def __init__(self, dtype=torch.float32, device="cpu", matrix=None):
 
 
58
  if matrix is not None:
59
  self._matrix = matrix.to(device=device, dtype=dtype)
60
+ self.device = self._matrix.device
61
+ self.dtype = self._matrix.dtype
62
  else:
63
  self._matrix = torch.eye(4, dtype=dtype, device=device).unsqueeze(0)
64
+ self.device = self._matrix.device
65
+ self.dtype = dtype
66
+
67
  def get_matrix(self):
68
  return self._matrix
69
+
70
  def compose(self, *others):
71
  m = self._matrix
72
  for o in others:
73
+ om = o.get_matrix().to(device=m.device, dtype=m.dtype)
74
+ if om.shape[0] != m.shape[0]:
75
+ if m.shape[0] == 1:
76
+ m = m.expand(om.shape[0], -1, -1)
77
+ elif om.shape[0] == 1:
78
+ om = om.expand(m.shape[0], -1, -1)
79
+ m = m @ om
80
+ return Transform3d(matrix=m, device=str(m.device), dtype=m.dtype)
81
+
82
  def transform_points(self, points):
83
  if points.dim() == 2:
84
  points = points.unsqueeze(0)
85
+ mat = self._matrix
86
+ points = points.to(device=mat.device, dtype=mat.dtype)
87
+ ones = torch.ones(*points.shape[:-1], 1, dtype=mat.dtype, device=mat.device)
88
  pts4 = torch.cat([points, ones], dim=-1)
89
+ m = mat.expand(pts4.shape[0], -1, -1)
90
+ out = torch.bmm(pts4, m.transpose(-2, -1))
91
  return out[..., :3]
92
+
93
  def translate(self, x, y=None, z=None):
94
+ dev = self._matrix.device
95
+ dt = self._matrix.dtype
96
  if isinstance(x, torch.Tensor) and x.dim() >= 1 and x.shape[-1] == 3:
97
+ t = x.to(device=dev, dtype=dt)
98
  else:
99
+ t = torch.tensor([[x, y, z]], dtype=dt, device=dev)
100
+ if t.dim() == 1:
101
+ t = t.unsqueeze(0)
102
+ T = torch.eye(4, dtype=dt, device=dev).unsqueeze(0).expand(t.shape[0], -1, -1).clone()
103
  T[:, :3, 3] = t
104
  new_m = self._matrix @ T
105
+ return Transform3d(matrix=new_m, device=str(dev), dtype=dt)
106
+
107
  def scale(self, x, y=None, z=None):
108
+ dev = self._matrix.device
109
+ dt = self._matrix.dtype
110
  if isinstance(x, torch.Tensor):
111
+ x = x.to(device=dev, dtype=dt)
112
  if x.dim() == 0:
113
  s = x.expand(3)
114
  elif x.shape[-1] == 3:
 
118
  else:
119
  if y is None: y = x
120
  if z is None: z = x
121
+ s = torch.tensor([x, y, z], dtype=dt, device=dev)
122
+ if s.dim() == 1:
123
+ s = s.unsqueeze(0)
124
+ S = torch.eye(4, dtype=dt, device=dev).unsqueeze(0).expand(s.shape[0], -1, -1).clone()
125
  S[:, 0, 0] = s[:, 0]; S[:, 1, 1] = s[:, 1]; S[:, 2, 2] = s[:, 2]
126
  new_m = self._matrix @ S
127
+ return Transform3d(matrix=new_m, device=str(dev), dtype=dt)
128
+
129
  def to(self, device=None, dtype=None):
 
 
130
  self._matrix = self._matrix.to(device=device, dtype=dtype)
131
+ self.device = self._matrix.device
132
+ if dtype is not None:
133
+ self.dtype = dtype
134
  return self
135
+
136
  def inverse(self):
137
  inv_m = torch.inverse(self._matrix)
138
+ return Transform3d(matrix=inv_m, device=str(self._matrix.device), dtype=self._matrix.dtype)
139
+
140
  def rotate(self, R):
141
+ dev = self._matrix.device
142
+ dt = self._matrix.dtype
143
+ R = R.to(device=dev, dtype=dt)
144
+ if R.dim() == 2:
145
+ R = R.unsqueeze(0)
146
+ T = torch.eye(4, dtype=dt, device=dev).unsqueeze(0).expand(R.shape[0], -1, -1).clone()
147
  T[:, :3, :3] = R
148
  new_m = self._matrix @ T
149
+ return Transform3d(matrix=new_m, device=str(dev), dtype=dt)
150
+
151
+ def stack(self, *others):
152
+ matrices = [self._matrix] + [o.get_matrix().to(device=self._matrix.device, dtype=self._matrix.dtype) for o in others]
153
+ stacked = torch.cat(matrices, dim=0)
154
+ return Transform3d(matrix=stacked, device=str(self._matrix.device), dtype=self._matrix.dtype)
155
+
156
+ def clone(self):
157
+ return Transform3d(matrix=self._matrix.clone(), device=str(self._matrix.device), dtype=self._matrix.dtype)
158
+
159
+ def __len__(self):
160
+ return self._matrix.shape[0]
161
+
162
+ def __getitem__(self, index):
163
+ m = self._matrix[index]
164
+ if m.dim() == 2:
165
+ m = m.unsqueeze(0)
166
+ return Transform3d(matrix=m, device=str(self._matrix.device), dtype=self._matrix.dtype)