import math
from functools import partial

import laspy
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities.combined_loader import CombinedLoader
from scipy.spatial import cKDTree
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset


class GaussianMM:
    """Isotropic Gaussian mixture with equal component weights."""

    def __init__(self, mu, var):
        self.centers = torch.tensor(mu)
        # log(std) = log(var) / 2
        self.logstd = torch.tensor(var).log() / 2.0
        self.K = self.centers.shape[0]

    def logprob(self, x):
        # Per-component, per-dimension log-densities: (N, K, D).
        logprobs = self.normal_logprob(
            x.unsqueeze(1), self.centers.unsqueeze(0), self.logstd
        )
        # Sum over dimensions, then log-average over the K components.
        logprobs = torch.sum(logprobs, dim=2)
        return torch.logsumexp(logprobs, dim=1) - math.log(self.K)

    def normal_logprob(self, z, mean, log_std):
        mean = mean + torch.tensor(0.0)
        log_std = log_std + torch.tensor(0.0)
        c = torch.tensor([math.log(2 * math.pi)]).to(z)
        inv_sigma = torch.exp(-log_std)
        tmp = (z - mean) * inv_sigma
        return -0.5 * (tmp * tmp + 2 * log_std + c)

    def __call__(self, n_samples):
        # Pick a random component per sample, then add isotropic noise.
        idx = torch.randint(self.K, (n_samples,)).to(self.centers.device)
        mean = self.centers[idx]
        return torch.randn(*mean.shape).to(mean) * torch.exp(self.logstd) + mean

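# Example usage (sketch): draw samples and score them under the mixture.
#
#     gmm = GaussianMM(mu=[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], var=0.05)
#     x = gmm(128)             # (128, 3) samples
#     logp = gmm.logprob(x)    # (128,) mixture log-densities

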
class LidarSingleDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        # Source distribution p0: four bumps along the lower-left edge.
        self.p0_mu = [
            [-4.5, -4.0, 0.5],
            [-4.2, -3.5, 0.5],
            [-4.0, -3.0, 0.5],
            [-3.75, -2.5, 0.5],
        ]
        self.p0_var = 0.02

        # Target distribution p1: two separate three-bump clusters.
        self.p1_1_mu = [
            [-2.5, -0.25, 0.5],
            [-2.25, 0.675, 0.5],
            [-2.0, 1.5, 0.5],
        ]
        self.p1_2_mu = [
            [2.0, -2.0, 0.5],
            [2.6, -1.25, 0.5],
            [3.2, -0.5, 0.5],
        ]
        self.p1_var = 0.03

        self.k = 20  # neighbours used for local plane fitting
        self.n_samples = 5000
        self.num_timesteps = 2
        self.split_ratios = args.split_ratios
        self.metric_clusters = 3
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        las = laspy.read(self.data_path)

        # Keep only ground returns (ASPRS classification code 2).
        self.mask = las.classification == 2

        x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
        y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
        z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
        dataset = np.vstack(
            (
                las.X[self.mask] * x_scale + x_offset,
                las.Y[self.mask] * y_scale + y_offset,
                las.Z[self.mask] * z_scale + z_offset,
            )
        ).transpose()
        # Rescale so x, y span [-5, 5] and z spans [0, 2].
        mi = dataset.min(axis=0, keepdims=True)
        ma = dataset.max(axis=0, keepdims=True)
        dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]

        self.dataset = torch.tensor(dataset, dtype=torch.float32)
        self.tree = cKDTree(dataset)

        # Sample the endpoint distributions and project each sample onto the
        # tangent plane of its local surface neighbourhood.
        x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
        x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
        x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)

        x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
        x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
        x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)

        split_index = int(self.n_samples * self.split_ratios[0])

        self.scaler = StandardScaler()
        if self.whiten:
            # Fit on the full point cloud, then apply the same transform to
            # the sampled endpoints (transform is only valid after fitting).
            self.dataset = torch.tensor(
                self.scaler.fit_transform(dataset), dtype=torch.float32
            )
            x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
            x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
            x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)
        x1 = torch.cat([x1_1, x1_2], dim=0)

        self.coords_t0 = x0
        self.coords_t1 = x1
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),
            np.ones(len(self.coords_t1)),
        ])

        # Keep at least one full validation batch; this must happen before
        # the split, where the index is actually used.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]

        # x1 stacks both clusters, so shuffle before splitting; a plain
        # head/tail split would put only first-cluster points in training.
        x1 = x1[torch.randperm(x1.shape[0])]
        split_index_x1 = int(x1.shape[0] * self.split_ratios[0])
        train_x1 = x1[:split_index_x1]
        val_x1 = x1[split_index_x1:]

        self.val_x0 = val_x0

        # Uniform per-sample weights.
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0)

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Test loaders return each split as a single full batch.
        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=val_x0.shape[0], shuffle=False, drop_last=False),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=val_x1.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Partition the cloud along the diagonal coordinate u = (x + y) / sqrt(2):
        # region A is the lowest 30% of u; B and C split the rest across x = y.
        points = self.dataset.cpu().numpy()
        x, y = points[:, 0], points[:, 1]

        u = (x + y) / np.sqrt(2)

        u_thresh = np.percentile(u, 30)
        mask_A = u <= u_thresh

        remaining = ~mask_A
        mask_B = remaining & (x < y)
        mask_C = remaining & (x >= y)

        # One full-batch loader per metric region (A and the remainder).
        self.metric_samples_dataloaders = [
            DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False),
            DataLoader(torch.tensor(points[remaining], dtype=torch.float32), batch_size=points[remaining].shape[0], shuffle=False),
        ]

    def train_dataloader(self):
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_tangent_proj(self, points):
        # Fit a local plane around each point and return the projection op.
        w = self.get_tangent_plane(points)
        return partial(LidarSingleDataModule.projection_op, w=w)

    def get_tangent_plane(self, points, temp=1e-3):
        points_np = points.detach().cpu().numpy()
        _, idx = self.tree.query(points_np, k=self.k)
        nearest_pts = self.dataset[idx].to(points)

        # Soft neighbourhood weights: closer points contribute more.
        dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)

        w = LidarSingleDataModule.fit_plane(nearest_pts, weights)
        return w

    @staticmethod
    def fit_plane(points, weights=None):
        """Expects points to be of shape (..., 3).
        Returns w = [a, b, c] such that the plane is defined as
            ax + by + c = z,
        by solving the (weighted) normal equations (D^T W D) w = D^T W z
        with design matrix D = [x, y, 1].
        """
        D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
        z = points[..., 2]
        if weights is None:
            Dtrans = D.transpose(-1, -2)
        else:
            # Weighted least squares: scale each design row by its weight.
            DW = D * weights
            Dtrans = DW.transpose(-1, -2)
        w = torch.linalg.solve(
            torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
        ).squeeze(-1)
        return w

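    # Example (sketch): fit a plane to one neighbourhood and check residuals.
    #     pts = torch.randn(20, 3)
    #     w = LidarSingleDataModule.fit_plane(pts)  # w = [a, b, c]
    #     res = pts[:, 0] * w[0] + pts[:, 1] * w[1] + w[2] - pts[:, 2]
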
    @staticmethod
    def projection_op(x, w):
        """Projects points x onto the plane defined by w = [a, b, c]."""
        # Normal of the plane ax + by - z + c = 0.
        n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=-1)

        pn = torch.sum(x * n, dim=-1, keepdim=True)
        nn = torch.sum(n * n, dim=-1, keepdim=True)

        d = w[..., 2:3]

        # Offset of x from the plane along the (unnormalized) normal.
        projn_x = ((pn + d) / nn) * n

        return x - projn_x

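    # Sanity check (sketch): a projected point satisfies ax + by + c = z up
    # to numerical error:
    #     proj = LidarSingleDataModule.projection_op(x, w)
    #     err = (proj[..., 0] * w[..., 0] + proj[..., 1] * w[..., 1]
    #            + w[..., 2] - proj[..., 2]).abs().max()
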
    def get_timepoint_data(self):
        """Return data organized by timepoint for visualization."""
        return {
            "t0": self.coords_t0,
            "t1": self.coords_t1,
            "time_labels": self.time_labels,
        }
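
# Minimal smoke test (sketch): the attribute names mirror what __init__ reads
# from `args`; the .las path and the values below are placeholders, not part
# of the original code.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        data_path="path/to/ground_points.las",  # placeholder path (assumption)
        batch_size=256,
        dim=3,
        whiten=True,
        split_ratios=[0.8, 0.2],
    )
    dm = LidarSingleDataModule(args)
    print("t0 samples:", dm.coords_t0.shape)
    print("t1 samples:", dm.coords_t1.shape)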