| | |
| | from typing import Sequence, Tuple |
| |
|
| | import torch |
| | from mmcv.ops import batched_nms |
| | from mmengine.structures import InstanceData |
| |
|
| | from mmdet.structures import DetDataSample, SampleList |
| |
|
| |
|
def shift_rbboxes(bboxes: torch.Tensor, offset: Sequence[int]):
    """Translate the centers of rotated bboxes by a fixed offset.

    Args:
        bboxes (Tensor): Rotated bboxes to translate, with shape (n, 5)
            laid out as (x, y, w, h, a).
        offset (Sequence[int]): The translation to apply, with shape (2, ).

    Returns:
        Tensor: A new tensor holding the translated rotated bboxes; the
        input tensor is left untouched.
    """
    # Work on a copy so the caller's tensor is never modified.
    translated = bboxes.clone()
    # Build the delta with the dtype/device of ``bboxes`` and add it to the
    # first two columns only; the remaining columns are copied unchanged.
    delta = bboxes.new_tensor(offset)
    translated[:, :2] += delta
    return translated
| |
|
| |
|
def shift_predictions(det_data_samples: SampleList,
                      offsets: Sequence[Tuple[int, int]],
                      src_image_shape: Tuple[int, int]) -> InstanceData:
    """Shift patch predictions to the original image and concatenate them.

    Args:
        det_data_samples (List[:obj:`DetDataSample`]): A list of patch
            results.
        offsets (Sequence[Tuple[int, int]]): Positions of the left top points
            of patches, one entry per patch result.
        src_image_shape (Tuple[int, int]): A (height, width) tuple giving
            the shape of the large image.

    Returns:
        :obj:`InstanceData`: The shifted ``pred_instances`` of every patch,
        concatenated with ``InstanceData.cat``.

    Raises:
        ImportError: If ``sahi`` is not installed.
        AssertionError: If ``det_data_samples`` and ``offsets`` differ in
            length.
        NotImplementedError: If the last dimension of the predicted bboxes
            is neither 4 nor 5.
    """
    try:
        from sahi.slicing import shift_bboxes, shift_masks
    except ImportError as err:
        raise ImportError(
            'Please run "pip install -U sahi" '
            'to install sahi first for large image inference.') from err

    assert len(det_data_samples) == len(offsets), (
        'The `det_data_samples` should have the same length as `offsets`.')

    shifted_predictions = []
    for det_data_sample, offset in zip(det_data_samples, offsets):
        # Clone so the caller's patch results are not modified when the
        # shifted fields are assigned below.
        pred_inst = det_data_sample.pred_instances.clone()

        box_dim = pred_inst.bboxes.size(-1)
        if box_dim == 4:
            # 4 values per box: delegate to sahi's ``shift_bboxes``.
            shifted_bboxes = shift_bboxes(pred_inst.bboxes, offset)
        elif box_dim == 5:
            # 5 values per box: rotated boxes in (x, y, w, h, a) layout.
            shifted_bboxes = shift_rbboxes(pred_inst.bboxes, offset)
        else:
            raise NotImplementedError(
                'Only bboxes whose last dimension is 4 or 5 are supported, '
                f'but got {box_dim}.')

        pred_inst.bboxes = shifted_bboxes
        # The masks are read from ``pred_inst``, so the membership test has
        # to be made on ``pred_inst`` as well; testing the enclosing data
        # sample would skip the shift even when masks are present.
        if 'masks' in pred_inst:
            pred_inst.masks = shift_masks(pred_inst.masks, offset,
                                          src_image_shape)

        # ``pred_inst`` is already a private clone; no second copy needed.
        shifted_predictions.append(pred_inst)

    return InstanceData.cat(shifted_predictions)
| |
|
| |
|
def merge_results_by_nms(results: SampleList, offsets: Sequence[Tuple[int,
                                                                      int]],
                         src_image_shape: Tuple[int, int],
                         nms_cfg: dict) -> DetDataSample:
    """Combine per-patch detections into one result using NMS.

    Args:
        results (List[:obj:`DetDataSample`]): A list of patch results.
        offsets (Sequence[Tuple[int, int]]): Positions of the left top points
            of patches.
        src_image_shape (Tuple[int, int]): A (height, width) tuple giving
            the shape of the large image.
        nms_cfg (dict): Config forwarded to ``batched_nms``; it should
            specify the nms type and other parameters like `iou_threshold`.

    Returns:
        :obj:`DetDataSample`: A clone of the first patch result whose
        ``pred_instances`` are the instances kept by NMS.
    """
    # Bring every patch prediction into the large image's coordinates and
    # gather them into one instance set.
    all_instances = shift_predictions(results, offsets, src_image_shape)

    # ``batched_nms`` returns a pair; only its second element (the kept
    # indices) is needed here.
    nms_output = batched_nms(
        boxes=all_instances.bboxes,
        scores=all_instances.scores,
        idxs=all_instances.labels,
        nms_cfg=nms_cfg)
    kept_instances = all_instances[nms_output[1]]

    # Start from a copy of the first patch result and swap in the merged
    # predictions.
    merged = results[0].clone()
    merged.pred_instances = kept_instances
    return merged
| |
|