| | |
| | import itertools |
| | from collections.abc import Sized |
| | from typing import Any, List, Union |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from mmengine.device import get_device |
| | from .base_data_element import BaseDataElement |
| |
|
# Declared up front so static type checkers see one binding per name; the
# concrete Union is chosen below from the active device.
BoolTypeTensor: Union[Any]
LongTypeTensor: Union[Any]

# Resolve the device name a single time. The 'npu', 'mlu' and 'musa'
# backends expose their tensor classes under ``torch.<device name>``; any
# other device name falls back to the ``torch.cuda`` tensor classes.
_device_name = get_device()
if _device_name in ('npu', 'mlu', 'musa'):
    _device_module = getattr(torch, _device_name)
else:
    _device_module = torch.cuda

# An index/mask tensor may live on the CPU or on the active accelerator.
BoolTypeTensor = Union[torch.BoolTensor, _device_module.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, _device_module.LongTensor]

# Remove the temporary helpers so only the public aliases remain.
del _device_name, _device_module

# Every type accepted by ``InstanceData.__getitem__``.
IndexType: Union[Any] = Union[str, slice, int, list, LongTypeTensor,
                              BoolTypeTensor, np.ndarray]
| |
|
| |
|
| | |
| | |
class InstanceData(BaseDataElement):
    """Data structure for instance-level annotations or predictions.

    Subclass of :class:`BaseDataElement`. All values in `data_fields`
    should have the same length. This design refers to
    https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
    InstanceData also supports extra functions: ``index``, ``slice`` and
    ``cat`` for data fields. The type of a value in a data field can be a
    basic data structure such as `torch.Tensor`, `numpy.ndarray`, `list`,
    `str`, `tuple`, or a customized data structure that has ``__len__``,
    ``__getitem__`` and ``cat`` attributes.

    Examples:
        >>> # custom data structure
        >>> class TmpObject:
        ...     def __init__(self, tmp) -> None:
        ...         assert isinstance(tmp, list)
        ...         self.tmp = tmp
        ...     def __len__(self):
        ...         return len(self.tmp)
        ...     def __getitem__(self, item):
        ...         if isinstance(item, int):
        ...             if item >= len(self) or item < -len(self):  # type:ignore
        ...                 raise IndexError(f'Index {item} out of range!')
        ...             else:
        ...                 # keep the dimension
        ...                 item = slice(item, None, len(self))
        ...         return TmpObject(self.tmp[item])
        ...     @staticmethod
        ...     def cat(tmp_objs):
        ...         assert all(isinstance(results, TmpObject) for results in tmp_objs)
        ...         if len(tmp_objs) == 1:
        ...             return tmp_objs[0]
        ...         tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
        ...         tmp_list = list(itertools.chain(*tmp_list))
        ...         new_data = TmpObject(tmp_list)
        ...         return new_data
        ...     def __repr__(self):
        ...         return str(self.tmp)
        >>> from mmengine.structures import InstanceData
        >>> import numpy as np
        >>> import torch
        >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
        >>> instance_data = InstanceData(metainfo=img_meta)
        >>> 'img_shape' in instance_data
        True
        >>> instance_data.det_labels = torch.LongTensor([2, 3])
        >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
        >>> instance_data.bboxes = torch.rand((2, 4))
        >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
        >>> len(instance_data)
        2
        >>> print(instance_data)
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2, 3])
            det_scores: tensor([0.8000, 0.7000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263]])
            polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
        ) at 0x7fb492de6280>
        >>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
        >>> sorted_results.det_scores
        tensor([0.7000, 0.8000])
        >>> print(instance_data[instance_data.det_scores > 0.75])
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2])
            det_scores: tensor([0.8000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
            polygons: [[1, 2, 3, 4]]
        ) at 0x7f64ecf0ec40>
        >>> print(instance_data[instance_data.det_scores > 1])
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([], dtype=torch.int64)
            det_scores: tensor([])
            bboxes: tensor([], size=(0, 4))
            polygons: []
        ) at 0x7f660a6a7f70>
        >>> print(instance_data.cat([instance_data, instance_data]))
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2, 3, 2, 3])
            det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263],
                        [0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263]])
            polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
        ) at 0x7f203542feb0>
    """

    def __setattr__(self, name: str, value: Sized):
        """setattr is only used to set data.

        ``_metainfo_fields`` and ``_data_fields`` are private bookkeeping
        attributes: each may be assigned once, and any later assignment
        raises. Every other value must have ``__len__``. Once this
        instance is non-empty (``len(self) > 0``), the value must also have
        the same length as ``InstanceData``.

        Args:
            name (str): Name of the attribute / field to set.
            value (Sized): Value to store.

        Raises:
            AttributeError: If a private bookkeeping attribute is assigned
                a second time.
            AssertionError: If ``value`` has no ``__len__``, or if its
                length differs from ``len(self)``.
        """
        if name in ('_metainfo_fields', '_data_fields'):
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(f'{name} has been used as a '
                                     'private attribute, which is immutable.')

        else:
            assert isinstance(value,
                              Sized), 'value must contain `__len__` attribute'

            # The length check only applies once the instance is non-empty.
            # While ``len(self) == 0``, a value of any length is accepted.
            if len(self) > 0:
                assert len(value) == len(self), 'The length of ' \
                                                f'values {len(value)} is ' \
                                                'not consistent with ' \
                                                'the length of this ' \
                                                ':obj:`InstanceData` ' \
                                                f'{len(self)}'
            super().__setattr__(name, value)

    # ``instance_data['key'] = value`` goes through the same checks.
    __setitem__ = __setattr__

    def __getitem__(self, item: IndexType) -> 'InstanceData':
        """Index along the first (instance) dimension of every data field.

        Args:
            item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
                :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
                Get the corresponding values according to item. A ``str``
                returns the field with that name directly (not wrapped in
                :obj:`InstanceData`). Every other type selects instances and
                returns a new :obj:`InstanceData` that carries this
                instance's metainfo. Lists and ndarrays are converted to
                tensors first. An empty list or array selects no instances.

        Returns:
            :obj:`InstanceData`: Corresponding values.

        Raises:
            IndexError: If an ``int`` index is out of range.
            AssertionError: If ``item`` has an unsupported type, is a tensor
                that is not 1-D, or is a bool mask whose length differs
                from ``len(self)``.
            ValueError: If tensor indexing is applied to a field type that
                cannot be gathered and concatenated.
        """
        assert isinstance(item, IndexType.__args__)
        if isinstance(item, list):
            item = np.array(item)
        if isinstance(item, np.ndarray):
            # numpy's default integer dtype is platform dependent (int32 on
            # Windows). Normalise it to int64 before using it as a torch index.
            # ``np.array([])`` is float64, which torch rejects as an index,
            # so an empty non-bool array is treated as an empty int64 index.
            if item.dtype == np.int32 or (item.size == 0
                                          and item.dtype != np.bool_):
                item = item.astype(np.int64)
            # ``torch.from_numpy`` rejects arrays with negative strides,
            # e.g. ``np.argsort(x)[::-1]``, so copy any non-C-contiguous
            # array. ``.copy()`` keeps the rank, unlike np.ascontiguousarray.
            if not item.flags.c_contiguous:
                item = item.copy()
            item = torch.from_numpy(item)

        if isinstance(item, str):
            return getattr(self, item)

        if isinstance(item, int):
            if item >= len(self) or item < -len(self):  # type:ignore
                raise IndexError(f'Index {item} out of range!')
            else:
                # Keep the leading dimension. A slice whose step is
                # ``len(self)`` yields exactly one element.
                item = slice(item, None, len(self))

        new_data = self.__class__(metainfo=self.metainfo)
        if isinstance(item, torch.Tensor):
            assert item.dim() == 1, 'Only support to get the' \
                                    ' values along the first dimension.'
            if isinstance(item, BoolTypeTensor.__args__):
                assert len(item) == len(self), 'The shape of the ' \
                                               'input(BoolTensor) ' \
                                               f'{len(item)} ' \
                                               'does not match the shape ' \
                                               'of the indexed tensor ' \
                                               'in results_field ' \
                                               f'{len(self)} at ' \
                                               'first dimension.'

            for k, v in self.items():
                if isinstance(v, torch.Tensor):
                    new_data[k] = v[item]
                elif isinstance(v, np.ndarray):
                    new_data[k] = v[item.cpu().numpy()]
                elif isinstance(
                        v, (str, list, tuple)) or (hasattr(v, '__getitem__')
                                                   and hasattr(v, 'cat')):
                    # These types cannot be indexed by a tensor directly.
                    # Turn the tensor into plain integer positions, take
                    # each as a one-element slice, then join the pieces.
                    if isinstance(item, BoolTypeTensor.__args__):
                        # convert to indexes from BoolTensor
                        indexes = torch.nonzero(item).view(
                            -1).cpu().numpy().tolist()
                    else:
                        indexes = item.cpu().numpy().tolist()
                    slice_list = []
                    if indexes:
                        for index in indexes:
                            slice_list.append(slice(index, None, len(v)))
                    else:
                        # Nothing selected: an empty slice still yields a
                        # value of the right type for concatenation.
                        slice_list.append(slice(None, 0, None))
                    r_list = [v[s] for s in slice_list]
                    if isinstance(v, (str, list, tuple)):
                        new_value = r_list[0]
                        for r in r_list[1:]:
                            new_value = new_value + r
                    else:
                        new_value = v.cat(r_list)
                    new_data[k] = new_value
                else:
                    raise ValueError(
                        f'The type of `{k}` is `{type(v)}`, which has no '
                        'attribute of `cat`, so it does not '
                        'support slice with `bool`')

        else:
            # item is a slice (an int index was converted to one above).
            for k, v in self.items():
                new_data[k] = v[item]
        return new_data

    @staticmethod
    def cat(instances_list: List['InstanceData']) -> 'InstanceData':
        """Concat the instances of all :obj:`InstanceData` in the list.

        Note: To ensure that cat returns as expected, make sure that
        all elements in the list have exactly the same keys.

        Tensors are joined with ``torch.cat`` and ndarrays with
        ``np.concatenate``, both along dim 0. ``str``/``list``/``tuple`` are
        joined with ``+``. Any other value type uses its own ``cat``. The
        result carries the metainfo of the first element. If the list has a
        single element, that same object is returned (not a copy).

        Args:
            instances_list (list[:obj:`InstanceData`]): A list
                of :obj:`InstanceData`.

        Returns:
            :obj:`InstanceData`

        Raises:
            AssertionError: If the list is empty, contains a
                non-:obj:`InstanceData` element, or the elements' keys
                differ.
            ValueError: If a field's value type has no way to be
                concatenated.
        """
        assert all(
            isinstance(results, InstanceData) for results in instances_list)
        assert len(instances_list) > 0
        if len(instances_list) == 1:
            return instances_list[0]

        # Check that every element has the same keys. All key lists must
        # have the same length, and their union must be no larger than the
        # first element's key list.
        field_keys_list = [
            instances.all_keys() for instances in instances_list
        ]
        assert len({len(field_keys) for field_keys in field_keys_list}) \
               == 1 and len(set(itertools.chain(*field_keys_list))) \
               == len(field_keys_list[0]), 'There are different keys in ' \
                                           '`instances_list`, which may ' \
                                           'cause the cat operation ' \
                                           'to fail. Please make sure all ' \
                                           'elements in `instances_list` ' \
                                           'have the exact same key.'

        new_data = instances_list[0].__class__(
            metainfo=instances_list[0].metainfo)
        for k in instances_list[0].keys():
            values = [results[k] for results in instances_list]
            v0 = values[0]
            if isinstance(v0, torch.Tensor):
                new_values = torch.cat(values, dim=0)
            elif isinstance(v0, np.ndarray):
                new_values = np.concatenate(values, axis=0)
            elif isinstance(v0, (str, list, tuple)):
                # Start from a shallow copy so ``+=`` never mutates the
                # first element's list in place.
                new_values = v0[:]
                for v in values[1:]:
                    new_values += v
            elif hasattr(v0, 'cat'):
                new_values = v0.cat(values)
            else:
                raise ValueError(
                    f'The type of `{k}` is `{type(v0)}` which has no '
                    'attribute of `cat`')
            new_data[k] = new_values
        return new_data

    def __len__(self) -> int:
        """int: The length of InstanceData.

        This is the length of the first value in ``self.values()``, or 0
        when there are no data fields.
        """
        if len(self._data_fields) > 0:
            return len(self.values()[0])
        else:
            return 0
| |
|