| | |
| | import json |
| | import os.path as osp |
| | from typing import List, Optional |
| |
|
| | from mmengine.fileio import get_local_path |
| |
|
| | from mmdet.registry import DATASETS |
| | from .base_det_dataset import BaseDetDataset |
| |
|
| |
|
@DATASETS.register_module()
class ODVGDataset(BaseDetDataset):
    """Object detection (OD) and visual grounding (VG) dataset.

    The annotation file is read line by line, each non-blank line being one
    JSON object describing one image. Every record must provide
    ``filename``, ``height`` and ``width``. Depending on the mode the record
    additionally carries:

    - ``'OD'`` mode (selected when ``label_map_file`` is given): an optional
      ``detection`` dict whose ``instances`` list holds objects with
      ``bbox`` and ``label``.
    - ``'VG'`` mode (default): a ``grounding`` dict with ``caption`` and
      ``regions``; each region has ``bbox`` (a single box or a list of
      boxes), ``phrase`` and ``tokens_positive``.

    Boxes are unpacked as ``(x1, y1, x2, y2)``.

    Args:
        data_root (str): Root directory. It is prepended to
            ``label_map_file`` and forwarded to the base class.
            Defaults to ''.
        label_map_file (str, optional): Path, relative to ``data_root``, of
            a JSON file holding the label map. When given, the dataset
            works in ``'OD'`` mode. Defaults to None.
        need_text (bool): In ``'OD'`` mode, whether to attach the label map
            to every sample under the ``text`` key. Defaults to True.
    """

    def __init__(self,
                 *args,
                 data_root: str = '',
                 label_map_file: Optional[str] = None,
                 need_text: bool = True,
                 **kwargs) -> None:
        # These attributes are assigned before ``super().__init__``.
        # NOTE(review): presumably the base initializer may already call
        # ``load_data_list``, which reads them - confirm against
        # ``BaseDetDataset``.
        self.dataset_mode = 'VG'
        self.need_text = need_text
        if label_map_file:
            label_map_file = osp.join(data_root, label_map_file)
            # JSON text is UTF-8; do not depend on the platform default.
            with open(label_map_file, 'r', encoding='utf-8') as file:
                self.label_map = json.load(file)
            self.dataset_mode = 'OD'
        super().__init__(*args, data_root=data_root, **kwargs)
        assert self.return_classes is True, (
            'ODVGDataset requires `return_classes=True`')

    @staticmethod
    def _is_valid_bbox(bbox, img_width, img_height) -> bool:
        """Return whether ``bbox`` should be kept for an image.

        A box ``(x1, y1, x2, y2)`` is rejected when its overlap with the
        image rectangle ``[0, img_width] x [0, img_height]`` has zero area,
        or when its own width or height is smaller than 1.
        """
        x1, y1, x2, y2 = bbox
        inter_w = max(0, min(x2, img_width) - max(x1, 0))
        inter_h = max(0, min(y2, img_height) - max(y1, 0))
        if inter_w * inter_h == 0:
            return False
        return (x2 - x1) >= 1 and (y2 - y1) >= 1

    def load_data_list(self) -> List[dict]:
        """Load annotations from ``self.ann_file``.

        Returns:
            List[dict]: One dict per image with the keys ``img_path``,
            ``height``, ``width``, ``instances`` and ``dataset_mode``.
            ``text`` is present in ``'VG'`` mode (the caption) and, when
            ``need_text`` is set, in ``'OD'`` mode (the label map).
            ``'VG'`` samples additionally carry ``phrases``, which maps a
            region index to its ``phrase`` and ``tokens_positive``.
            Each instance holds ``ignore_flag`` (always 0), ``bbox`` and
            ``bbox_label``.
        """
        with get_local_path(
                self.ann_file, backend_args=self.backend_args) as local_path:
            with open(local_path, 'r', encoding='utf-8') as f:
                # Blank lines (e.g. stray newlines at the end of the file)
                # are not JSON documents and are skipped.
                data_list = [json.loads(line) for line in f if line.strip()]

        out_data_list = []
        for data in data_list:
            data_info = {
                'img_path': osp.join(self.data_prefix['img'],
                                     data['filename']),
                'height': data['height'],
                'width': data['width'],
            }
            width, height = data['width'], data['height']

            if self.dataset_mode == 'OD':
                if self.need_text:
                    data_info['text'] = self.label_map
                anno = data.get('detection', {})

                instances = []
                for obj in anno.get('instances', []):
                    bbox = obj['bbox']
                    # The ``str`` round trip is kept from the original
                    # implementation: integer and digit-string labels are
                    # accepted, other values (e.g. floats) raise.
                    label = str(obj['label'])
                    if not self._is_valid_bbox(bbox, width, height):
                        continue
                    instances.append({
                        'ignore_flag': 0,
                        'bbox': bbox,
                        'bbox_label': int(label),
                    })
            else:
                anno = data['grounding']
                data_info['text'] = anno['caption']

                instances = []
                phrases = {}
                for i, region in enumerate(anno['regions']):
                    bbox = region['bbox']
                    phrase = region['phrase']
                    tokens_positive = region['tokens_positive']
                    # A region may hold one box or a list of boxes;
                    # normalize to a list of boxes.
                    if not isinstance(bbox[0], list):
                        bbox = [bbox]
                    for box in bbox:
                        if not self._is_valid_bbox(box, width, height):
                            continue
                        # The region index doubles as the label and as the
                        # key into ``phrases``.
                        instances.append({
                            'ignore_flag': 0,
                            'bbox': box,
                            'bbox_label': i,
                        })
                        # Only regions with at least one kept box get an
                        # entry in ``phrases``.
                        phrases[i] = {
                            'phrase': phrase,
                            'tokens_positive': tokens_positive,
                        }

            data_info['instances'] = instances
            if self.dataset_mode != 'OD':
                data_info['phrases'] = phrases
            data_info['dataset_mode'] = self.dataset_mode
            out_data_list.append(data_info)

        return out_data_list
| |
|