| |
| import copy |
| import os.path as osp |
| from collections import defaultdict |
| from typing import Any, List, Tuple |
|
|
| import mmengine.fileio as fileio |
| from mmengine.dataset import BaseDataset |
| from mmengine.logging import print_log |
|
|
| from mmdet.datasets.api_wrappers import COCO |
| from mmdet.registry import DATASETS |
|
|
|
|
| @DATASETS.register_module() |
| class BaseVideoDataset(BaseDataset): |
| """Base video dataset for VID, MOT and VIS tasks.""" |
|
|
| META = dict(classes=None) |
| |
| ANN_ID_UNIQUE = True |
|
|
| def __init__(self, *args, backend_args: dict = None, **kwargs): |
| self.backend_args = backend_args |
| super().__init__(*args, **kwargs) |
|
|
| def load_data_list(self) -> Tuple[List[dict], List]: |
| """Load annotations from an annotation file named as ``self.ann_file``. |
| |
| Returns: |
| tuple(list[dict], list): A list of annotation and a list of |
| valid data indices. |
| """ |
| with fileio.get_local_path(self.ann_file) as local_path: |
| self.coco = COCO(local_path) |
| |
| |
| self.cat_ids = self.coco.get_cat_ids( |
| cat_names=self.metainfo['classes']) |
| self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} |
| self.cat_img_map = copy.deepcopy(self.coco.cat_img_map) |
| |
| self.img_ids_with_ann = set() |
|
|
| img_ids = self.coco.get_img_ids() |
| total_ann_ids = [] |
| |
| |
| single_video_id = 100000 |
| videos = {} |
| for img_id in img_ids: |
| raw_img_info = self.coco.load_imgs([img_id])[0] |
| raw_img_info['img_id'] = img_id |
| if 'video_id' not in raw_img_info: |
| single_video_id = single_video_id + 1 |
| video_id = single_video_id |
| else: |
| video_id = raw_img_info['video_id'] |
|
|
| if video_id not in videos: |
| videos[video_id] = { |
| 'video_id': video_id, |
| 'images': [], |
| 'video_length': 0 |
| } |
|
|
| videos[video_id]['video_length'] += 1 |
| ann_ids = self.coco.get_ann_ids( |
| img_ids=[img_id], cat_ids=self.cat_ids) |
| raw_ann_info = self.coco.load_anns(ann_ids) |
| total_ann_ids.extend(ann_ids) |
|
|
| parsed_data_info = self.parse_data_info( |
| dict(raw_img_info=raw_img_info, raw_ann_info=raw_ann_info)) |
|
|
| if len(parsed_data_info['instances']) > 0: |
| self.img_ids_with_ann.add(parsed_data_info['img_id']) |
|
|
| videos[video_id]['images'].append(parsed_data_info) |
|
|
| data_list = [v for v in videos.values()] |
|
|
| if self.ANN_ID_UNIQUE: |
| assert len(set(total_ann_ids)) == len( |
| total_ann_ids |
| ), f"Annotation ids in '{self.ann_file}' are not unique!" |
|
|
| del self.coco |
|
|
| return data_list |
|
|
| def parse_data_info(self, raw_data_info: dict) -> dict: |
| """Parse raw annotation to target format. |
| |
| Args: |
| raw_data_info (dict): Raw data information loaded from |
| ``ann_file``. |
| |
| Returns: |
| dict: Parsed annotation. |
| """ |
| img_info = raw_data_info['raw_img_info'] |
| ann_info = raw_data_info['raw_ann_info'] |
| data_info = {} |
|
|
| data_info.update(img_info) |
| if self.data_prefix.get('img_path', None) is not None: |
| img_path = osp.join(self.data_prefix['img_path'], |
| img_info['file_name']) |
| else: |
| img_path = img_info['file_name'] |
| data_info['img_path'] = img_path |
|
|
| instances = [] |
| for i, ann in enumerate(ann_info): |
| instance = {} |
|
|
| if ann.get('ignore', False): |
| continue |
| x1, y1, w, h = ann['bbox'] |
| inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0)) |
| inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0)) |
| if inter_w * inter_h == 0: |
| continue |
| if ann['area'] <= 0 or w < 1 or h < 1: |
| continue |
| if ann['category_id'] not in self.cat_ids: |
| continue |
| bbox = [x1, y1, x1 + w, y1 + h] |
|
|
| if ann.get('iscrowd', False): |
| instance['ignore_flag'] = 1 |
| else: |
| instance['ignore_flag'] = 0 |
| instance['bbox'] = bbox |
| instance['bbox_label'] = self.cat2label[ann['category_id']] |
| if ann.get('segmentation', None): |
| instance['mask'] = ann['segmentation'] |
| if ann.get('instance_id', None): |
| instance['instance_id'] = ann['instance_id'] |
| else: |
| |
| |
| instance['instance_id'] = i |
| instances.append(instance) |
| data_info['instances'] = instances |
| return data_info |
|
|
| def filter_data(self) -> List[int]: |
| """Filter image annotations according to filter_cfg. |
| |
| Returns: |
| list[int]: Filtered results. |
| """ |
| if self.test_mode: |
| return self.data_list |
|
|
| num_imgs_before_filter = sum( |
| [len(info['images']) for info in self.data_list]) |
| num_imgs_after_filter = 0 |
|
|
| |
| ids_in_cat = set() |
| for i, class_id in enumerate(self.cat_ids): |
| ids_in_cat |= set(self.cat_img_map[class_id]) |
| |
| |
| ids_in_cat &= self.img_ids_with_ann |
|
|
| new_data_list = [] |
| for video_data_info in self.data_list: |
| imgs_data_info = video_data_info['images'] |
| valid_imgs_data_info = [] |
|
|
| for data_info in imgs_data_info: |
| img_id = data_info['img_id'] |
| width = data_info['width'] |
| height = data_info['height'] |
| |
| if self.filter_cfg is None: |
| if img_id not in ids_in_cat: |
| video_data_info['video_length'] -= 1 |
| continue |
| if min(width, height) >= 32: |
| valid_imgs_data_info.append(data_info) |
| num_imgs_after_filter += 1 |
| else: |
| video_data_info['video_length'] -= 1 |
| else: |
| if self.filter_cfg.get('filter_empty_gt', |
| True) and img_id not in ids_in_cat: |
| video_data_info['video_length'] -= 1 |
| continue |
| if min(width, height) >= self.filter_cfg.get( |
| 'min_size', 32): |
| valid_imgs_data_info.append(data_info) |
| num_imgs_after_filter += 1 |
| else: |
| video_data_info['video_length'] -= 1 |
| video_data_info['images'] = valid_imgs_data_info |
| new_data_list.append(video_data_info) |
|
|
| print_log( |
| 'The number of samples before and after filtering: ' |
| f'{num_imgs_before_filter} / {num_imgs_after_filter}', 'current') |
| return new_data_list |
|
|
| def prepare_data(self, idx) -> Any: |
| """Get date processed by ``self.pipeline``. Note that ``idx`` is a |
| video index in default since the base element of video dataset is a |
| video. However, in some cases, we need to specific both the video index |
| and frame index. For example, in traing mode, we may want to sample the |
| specific frames and all the frames must be sampled once in a epoch; in |
| test mode, we may want to output data of a single image rather than the |
| whole video for saving memory. |
| |
| Args: |
| idx (int): The index of ``data_info``. |
| |
| Returns: |
| Any: Depends on ``self.pipeline``. |
| """ |
| if isinstance(idx, tuple): |
| assert len(idx) == 2, 'The length of idx must be 2: ' |
| '(video_index, frame_index)' |
| video_idx, frame_idx = idx[0], idx[1] |
| else: |
| video_idx, frame_idx = idx, None |
|
|
| data_info = self.get_data_info(video_idx) |
| if self.test_mode: |
| |
| final_data_info = defaultdict(list) |
| if frame_idx is None: |
| frames_idx_list = list(range(data_info['video_length'])) |
| else: |
| frames_idx_list = [frame_idx] |
| for index in frames_idx_list: |
| frame_ann = data_info['images'][index] |
| frame_ann['video_id'] = data_info['video_id'] |
| |
| for key, value in frame_ann.items(): |
| final_data_info[key].append(value) |
| |
| |
| |
| final_data_info['ori_video_length'].append( |
| data_info['video_length']) |
|
|
| final_data_info['video_length'] = [len(frames_idx_list) |
| ] * len(frames_idx_list) |
| return self.pipeline(final_data_info) |
| else: |
| |
| if frame_idx is not None: |
| data_info['key_frame_id'] = frame_idx |
| return self.pipeline(data_info) |
|
|
| def get_cat_ids(self, index) -> List[int]: |
| """Following image detection, we provide this interface function. Get |
| category ids by video index and frame index. |
| |
| Args: |
| index: The index of the dataset. It support two kinds of inputs: |
| Tuple: |
| video_idx (int): Index of video. |
| frame_idx (int): Index of frame. |
| Int: Index of video. |
| |
| Returns: |
| List[int]: All categories in the image of specified video index |
| and frame index. |
| """ |
| if isinstance(index, tuple): |
| assert len( |
| index |
| ) == 2, f'Expect the length of index is 2, but got {len(index)}' |
| video_idx, frame_idx = index |
| instances = self.get_data_info( |
| video_idx)['images'][frame_idx]['instances'] |
| return [instance['bbox_label'] for instance in instances] |
| else: |
| cat_ids = [] |
| for img in self.get_data_info(index)['images']: |
| for instance in img['instances']: |
| cat_ids.append(instance['bbox_label']) |
| return cat_ids |
|
|
| @property |
| def num_all_imgs(self): |
| """Get the number of all the images in this video dataset.""" |
| return sum( |
| [len(self.get_data_info(i)['images']) for i in range(len(self))]) |
|
|
| def get_len_per_video(self, idx): |
| """Get length of one video. |
| |
| Args: |
| idx (int): Index of video. |
| |
| Returns: |
| int (int): The length of the video. |
| """ |
| return len(self.get_data_info(idx)['images']) |
|
|