| | import numpy as np |
| | import torch |
| | from torchvision.transforms import functional as tfn |
| | import torchvision.transforms.functional as tvf |
| |
|
| | from ..utils import decompose_rotmat |
| | from ..image import pad_image, rectify_image, resize_image |
| | from ...utils.wrappers import Camera |
| | from ..schema import KITTIDataConfiguration |
| |
|
| |
|
class BEVTransform:
    """Turn one raw KITTI sample into model-ready tensors.

    The front-view image is converted to a tensor, rectified using the ego
    roll/pitch, optionally resized/padded according to ``cfg`` and finally
    passed through ``augmentations``. The BEV instance mask is mapped to
    per-class labels, rotated, cropped, one-hot encoded and down-sampled.
    """

    # Number of rows kept, and width of the centred column window, taken
    # from the rotated BEV mask before down-sampling.
    BEV_CROP_SIZE = 672
    # (height, width) of the down-sampled BEV segmentation map.
    BEV_OUT_SIZE = (100, 100)
    # Category value that is excluded from the one-hot encoding.
    IGNORE_LABEL = 255

    def __init__(self, cfg: "KITTIDataConfiguration", augmentations):
        """
        Args:
            cfg: dataset configuration. Fields read by this class:
                gravity_align, rectify_pitch, target_focal_length,
                resize_image, pad_to_multiple, pad_to_square, num_classes,
                class_mapping.
            augmentations: callable applied to the processed image tensor.
        """
        self.cfg = cfg
        self.augmentations = augmentations

    @staticmethod
    def _compact_labels(msk, cat, iscrowd):
        """Renumber the ids present in ``msk`` to the contiguous range 0..K-1.

        Id 0 is always kept (mapped to 0) even when absent from ``msk``.
        ``cat`` and ``iscrowd`` are re-indexed so that entry ``k`` describes
        compacted id ``k``.

        Args:
            msk: integer id array; its values are used as indices, so they
                must be non-negative.
            cat: per-id category array, indexable by every id in ``msk``.
            iscrowd: per-id crowd flags indexable like ``cat``, or None.

        Returns:
            Tuple ``(msk, cat, iscrowd)`` after compaction. ``iscrowd`` stays
            None if None was passed.
        """
        ids = np.unique(msk)
        if 0 not in ids:
            ids = np.concatenate((np.array([0], dtype=np.int32), ids), axis=0)

        # Dense lookup table: original id -> compact id.
        ids_to_compact = np.zeros((ids.max() + 1,), dtype=np.int32)
        ids_to_compact[ids] = np.arange(0, ids.size, dtype=np.int32)

        msk = ids_to_compact[msk]
        cat = cat[ids]
        if iscrowd is not None:
            iscrowd = iscrowd[ids]
        return msk, cat, iscrowd

    @staticmethod
    def _rotate_and_crop(bev_msk, crop_size):
        """Rotate a (1, H, W) mask by 90 degrees over dims (1, 2) and crop it.

        The first ``crop_size`` rows of the rotated mask are kept, together
        with ``crop_size`` columns centred horizontally (the extra column
        goes to the right when the surplus is odd).

        Raises:
            ValueError: if the rotated mask is narrower than ``crop_size``.
        """
        rotated = torch.rot90(bev_msk, dims=(1, 2))
        width = rotated.size(2)
        if width < crop_size:
            raise ValueError(
                f"rotated BEV mask width {width} is smaller than the crop "
                f"size {crop_size}"
            )
        # Slice by explicit start/stop: the former ``a:-a`` form produced an
        # empty slice when the surplus ``a`` was zero.
        left = (width - crop_size) // 2
        return rotated[:, :crop_size, left:left + crop_size]

    def _prepare_image(self, img, fv_intrinsics, roll, pitch):
        """Convert, rectify and resize/pad the front-view image.

        Returns:
            Tuple ``(img, valid, cam, roll, pitch)``. ``roll``/``pitch`` are
            set to 0.0 for each angle the rectification removed (or that is
            ignored when gravity alignment is disabled).
        """
        img = tfn.to_tensor(img)  # (C, H, W)

        fx = fv_intrinsics[0][0]
        fy = fv_intrinsics[1][1]
        cx = fv_intrinsics[0][2]
        cy = fv_intrinsics[1][2]
        height, width = img.shape[1], img.shape[2]

        # NOTE(review): the -0.5 shifts the principal point by half a pixel,
        # presumably to match the pixel-centre convention Camera expects.
        cam = Camera(
            torch.tensor([width, height, fx, fy, cx - 0.5, cy - 0.5])
        ).float()

        if not self.cfg.gravity_align:
            roll = 0.0
            pitch = 0.0
            img, valid = rectify_image(img, cam, roll, pitch)
        else:
            img, valid = rectify_image(
                img, cam, roll, pitch if self.cfg.rectify_pitch else None
            )
            roll = 0.0
            if self.cfg.rectify_pitch:
                pitch = 0.0

        if self.cfg.target_focal_length is not None:
            # Rescale so the focal length matches the configured target.
            factor = self.cfg.target_focal_length / cam.f.numpy()
            size = (np.array(img.shape[-2:][::-1]) * factor).astype(int)
            img, _, cam, valid = resize_image(img, size, camera=cam, valid=valid)
            size_out = self.cfg.resize_image
            if size_out is None:
                # Round each edge up to a multiple of the configured stride.
                stride = self.cfg.pad_to_multiple
                size_out = (np.ceil(size / stride) * stride).astype(int)
            img, valid, cam = pad_image(
                img, size_out, cam, valid, crop_and_center=False
            )
        elif self.cfg.resize_image is not None:
            img, _, cam, valid = resize_image(
                img, self.cfg.resize_image, fn=max, camera=cam, valid=valid
            )
            if self.cfg.pad_to_square:
                img, valid, cam = pad_image(img, self.cfg.resize_image, cam, valid)

        return img, valid, cam, roll, pitch

    def _prepare_bev_masks(self, bev_msk, bev_cat, bev_iscrowd):
        """Convert a 2-D BEV instance mask into a down-sampled multi-hot map.

        Returns:
            Integer tensor of shape ``(*BEV_OUT_SIZE, C)``. C is
            ``cfg.num_classes``, or the channels selected by
            ``cfg.class_mapping`` when that is set.
        """
        # np.asarray (not np.array(copy=False)): NumPy >= 2 raises if a
        # dtype conversion would need a copy under copy=False.
        bev_msk = np.expand_dims(np.asarray(bev_msk, dtype=np.int32), axis=0)
        bev_msk, bev_cat, _ = self._compact_labels(bev_msk, bev_cat, bev_iscrowd)

        bev_msk = torch.from_numpy(bev_msk)
        bev_cat = torch.from_numpy(bev_cat)

        bev_msk = self._rotate_and_crop(bev_msk, self.BEV_CROP_SIZE).squeeze(0)
        seg_masks = bev_cat[bev_msk]  # per-pixel category

        # Ignored pixels get a temporary valid label for one_hot, then their
        # whole one-hot vector is zeroed.
        ignore = seg_masks == self.IGNORE_LABEL
        onehot = torch.nn.functional.one_hot(
            seg_masks.masked_fill(ignore, 0).to(torch.int64),
            num_classes=self.cfg.num_classes,
        )
        onehot[ignore] = 0

        # tvf.resize works on (C, H, W). Per torchvision's tensor resize, the
        # integer input is interpolated in float and rounded back, so a cell
        # may end up with zero or several active classes.
        down = tvf.resize(onehot.permute(2, 0, 1), self.BEV_OUT_SIZE)
        down = down.permute(1, 2, 0)

        if self.cfg.class_mapping is not None:
            down = down[:, :, self.cfg.class_mapping]
        return down

    def __call__(self, img, bev_msk=None, bev_plabel=None, fv_msk=None, bev_weights_msk=None,
                 bev_cat=None, bev_iscrowd=None, fv_cat=None, fv_iscrowd=None,
                 fv_intrinsics=None, ego_pose=None):
        """Process one sample.

        Args:
            img: front-view image in a form accepted by torchvision's
                ``to_tensor``.
            bev_msk: 2-D BEV instance-id mask (required).
            bev_plabel, fv_msk, bev_weights_msk, fv_cat, fv_iscrowd: accepted
                for interface compatibility; not used.
            bev_cat: per-instance-id category labels (required).
            bev_iscrowd: per-instance-id crowd flags; optional and not part
                of the output.
            fv_intrinsics: intrinsics indexed as ``[0][0]``=fx, ``[1][1]``=fy,
                ``[0][2]``=cx, ``[1][2]``=cy (required).
            ego_pose: pose matrix whose ``[:3, :3]`` block is decomposed into
                roll/pitch/yaw (required).

        Returns:
            dict with keys "image", "valid", "camera", "seg_masks",
            "flood_masks", "roll_pitch_yaw" and "confidence_map". Any numpy
            values are converted to torch tensors.

        Raises:
            ValueError: if ego_pose, fv_intrinsics, bev_msk or bev_cat is None.
        """
        required = (
            ("ego_pose", ego_pose),
            ("fv_intrinsics", fv_intrinsics),
            ("bev_msk", bev_msk),
            ("bev_cat", bev_cat),
        )
        missing = [name for name, value in required if value is None]
        if missing:
            raise ValueError(
                f"BEVTransform requires non-None inputs: {', '.join(missing)}"
            )

        bev_cat = np.array(bev_cat, dtype=np.int32)
        if bev_iscrowd is not None:
            bev_iscrowd = np.array(bev_iscrowd, dtype=np.uint8)

        ego_pose = np.array(ego_pose, dtype=np.float32)
        roll, pitch, yaw = decompose_rotmat(ego_pose[:3, :3])

        img, valid, cam, roll, pitch = self._prepare_image(
            img, fv_intrinsics, roll, pitch
        )
        seg_masks = self._prepare_bev_masks(bev_msk, bev_cat, bev_iscrowd)

        img = self.augmentations(img)
        # 1.0 where no class channel is active after down-sampling/mapping.
        flood_masks = torch.all(seg_masks == 0, dim=2).float()

        ret = {
            "image": img,
            "valid": valid,
            "camera": cam,
            "seg_masks": seg_masks.float().contiguous(),
            "flood_masks": flood_masks,
            "roll_pitch_yaw": torch.tensor((roll, pitch, yaw)).float(),
            # NOTE: same tensor object as "flood_masks".
            "confidence_map": flood_masks,
        }
        return {
            key: torch.from_numpy(value) if isinstance(value, np.ndarray) else value
            for key, value in ret.items()
        }
| |
|