| | |
import os
from typing import List, Optional
| |
|
| | from mmdet.registry import DATASETS |
| | from .base_det_dataset import BaseDetDataset |
| |
|
| | try: |
| | from dsdl.dataset import DSDLDataset |
| | except ImportError: |
| | DSDLDataset = None |
| |
|
| |
|
@DATASETS.register_module()
class DSDLDetDataset(BaseDetDataset):
    """Dataset for dsdl detection.

    Args:
        with_bbox (bool): Load bbox or not. Defaults to True.
        with_polygon (bool): Load polygon or not. Defaults to False.
        with_mask (bool): Load seg map mask or not. Defaults to False.
        with_imagelevel_label (bool): Load image level label or not.
            Defaults to False.
        with_hierarchy (bool): Load hierarchy information or not.
            Defaults to False.
        specific_key_path (dict, optional): Path of specific keys which
            cannot be loaded by their field name. ``None`` (the default)
            is treated as an empty dict.
        pre_transform (dict, optional): Pre-transform functions applied
            before loading. ``None`` (the default) is treated as an empty
            dict.
        **kwargs: Forwarded to :class:`BaseDetDataset`. ``ann_file`` must
            be present; when ``data_root`` is truthy it is prepended to
            ``ann_file``.
    """

    METAINFO = {}

    def __init__(self,
                 with_bbox: bool = True,
                 with_polygon: bool = False,
                 with_mask: bool = False,
                 with_imagelevel_label: bool = False,
                 with_hierarchy: bool = False,
                 specific_key_path: Optional[dict] = None,
                 pre_transform: Optional[dict] = None,
                 **kwargs) -> None:

        if DSDLDataset is None:
            raise RuntimeError(
                'Package dsdl is not installed. Please run "pip install dsdl".'
            )

        # Use a fresh dict per instance instead of a shared mutable default:
        # these objects are stored on ``self`` and handed to DSDLDataset.
        if specific_key_path is None:
            specific_key_path = {}
        if pre_transform is None:
            pre_transform = {}

        self.with_hierarchy = with_hierarchy
        self.specific_key_path = specific_key_path

        loc_config = dict(type='LocalFileReader', working_dir='')
        if kwargs.get('data_root'):
            kwargs['ann_file'] = os.path.join(kwargs['data_root'],
                                              kwargs['ann_file'])

        # Fields always requested from DSDL, extended by the ``with_*`` flags.
        self.required_fields = ['Image', 'ImageShape', 'Label', 'ignore_flag']
        if with_bbox:
            self.required_fields.append('Bbox')
        if with_polygon:
            self.required_fields.append('Polygon')
        if with_mask:
            self.required_fields.append('LabelMap')
        if with_imagelevel_label:
            self.required_fields.append('image_level_labels')
            assert 'image_level_labels' in specific_key_path, \
                '`image_level_labels` not specified in `specific_key_path` !'

        # Keys named in ``specific_key_path`` that are not standard fields;
        # load_data_list copies them onto each instance by index.
        self.extra_keys = [
            key for key in self.specific_key_path
            if key not in self.required_fields
        ]

        self.dsdldataset = DSDLDataset(
            dsdl_yaml=kwargs['ann_file'],
            location_config=loc_config,
            required_fields=self.required_fields,
            specific_key_path=specific_key_path,
            transform=pre_transform,
        )

        # Must run after ``self.dsdldataset`` exists: the base initializer
        # may call ``load_data_list``.
        super().__init__(**kwargs)

    def _label_name(self, label) -> str:
        """Return the class name used to look up a DSDL label object.

        Hierarchical datasets identify labels by ``leaf_node_name``;
        flat datasets use ``name``.
        """
        return label.leaf_node_name if self.with_hierarchy else label.name

    def load_data_list(self) -> List[dict]:
        """Load data info from a dsdl yaml file named as ``self.ann_file``.

        As a side effect, sets ``self._metainfo['classes']`` (and
        ``self._metainfo['RELATION_MATRIX']`` when ``with_hierarchy`` is
        enabled) from the DSDL dataset. Samples that end up with no
        instances are not included in the returned list.

        Returns:
            List[dict]: A list of data info.

        Raises:
            ValueError: If a label name is not among the dataset classes.
        """
        if self.with_hierarchy:
            classes_names, relation_matrix = \
                self.dsdldataset.class_dom.get_hierarchy_info()
            self._metainfo['classes'] = tuple(classes_names)
            self._metainfo['RELATION_MATRIX'] = relation_matrix
        else:
            self._metainfo['classes'] = tuple(self.dsdldataset.class_names)

        # Build the name -> index table once instead of calling tuple.index
        # per label; setdefault keeps the first occurrence, as index() did.
        class_to_index = {}
        for class_idx, class_name in enumerate(self._metainfo['classes']):
            class_to_index.setdefault(class_name, class_idx)

        def to_index(label) -> int:
            name = self._label_name(label)
            try:
                return class_to_index[name]
            except KeyError:
                # Preserve the ValueError that tuple.index used to raise.
                raise ValueError(
                    f'{name!r} is not in the dataset classes') from None

        img_prefix = self.data_prefix['img_path']
        data_list = []
        for img_id, data in enumerate(self.dsdldataset):
            fields = data.keys()
            datainfo = dict(
                img_id=img_id,
                img_path=os.path.join(img_prefix, data['Image'][0].location),
                width=data['ImageShape'][0].width,
                height=data['ImageShape'][0].height,
            )

            if 'image_level_labels' in fields:
                datainfo['image_level_labels'] = [
                    to_index(label) for label in data['image_level_labels']
                ]

            if 'LabelMap' in fields:
                datainfo['seg_map_path'] = data['LabelMap']

            instances = []
            if 'Bbox' in fields:
                has_ignore_flag = 'ignore_flag' in fields
                has_polygon = 'Polygon' in fields
                # Bbox, Label, ignore_flag, Polygon and extra keys are
                # parallel per-instance sequences indexed by ``idx``.
                for idx, bbox in enumerate(data['Bbox']):
                    instance = {
                        'bbox': bbox.xyxy,
                        'bbox_label': to_index(data['Label'][idx]),
                        'ignore_flag': (data['ignore_flag'][idx]
                                        if has_ignore_flag else 0),
                    }
                    if has_polygon:
                        instance['mask'] = \
                            data['Polygon'][idx].openmmlabformat
                    for key in self.extra_keys:
                        instance[key] = data[key][idx]
                    instances.append(instance)

            datainfo['instances'] = instances
            # Only samples with at least one instance are kept.
            if instances:
                data_list.append(datainfo)

        return data_list

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        In test mode the data list is returned unfiltered. Otherwise a
        sample is dropped when ``filter_empty_gt`` is set and it has no
        instances, or when its shorter side is below ``min_size``.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode:
            return self.data_list

        filter_cfg = self.filter_cfg if self.filter_cfg is not None else {}
        filter_empty_gt = filter_cfg.get('filter_empty_gt', False)
        min_size = filter_cfg.get('min_size', 0)

        valid_data_list = []
        for data_info in self.data_list:
            if filter_empty_gt and not data_info['instances']:
                continue
            if min(data_info['width'], data_info['height']) >= min_size:
                valid_data_list.append(data_info)

        return valid_data_list
| |
|