| | |
| | from abc import ABCMeta, abstractmethod |
| | from typing import List, Tuple |
| |
|
| | from mmengine.model import BaseModel |
| | from mmengine.structures import PixelData |
| | from torch import Tensor |
| |
|
| | from mmseg.structures import SegDataSample |
| | from mmseg.utils import (ForwardResults, OptConfigType, OptMultiConfig, |
| | OptSampleList, SampleList) |
| | from ..utils import resize |
| |
|
| | import torch |
| |
|
| |
|
class BaseSegmentor(BaseModel, metaclass=ABCMeta):
    """Base class for segmentors.

    Subclasses implement the abstract hooks (``extract_feat``,
    ``encode_decode``, ``loss``, ``predict`` and ``_forward``); this class
    provides the mode dispatch in :meth:`forward` and the shared
    post-processing in :meth:`postprocess_result`.

    Args:
        data_preprocessor (dict, optional): Model preprocessing config
            for processing the input data. It usually includes
            ``to_rgb``, ``pad_size_divisor``, ``pad_val``,
            ``mean`` and ``std``. Defaults to None.
        init_cfg (dict, optional): The config to control the
            initialization. Defaults to None.
    """

    def __init__(self,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

    @property
    def with_neck(self) -> bool:
        """bool: whether the segmentor has neck"""
        return getattr(self, 'neck', None) is not None

    @property
    def with_auxiliary_head(self) -> bool:
        """bool: whether the segmentor has auxiliary head"""
        return getattr(self, 'auxiliary_head', None) is not None

    @property
    def with_decode_head(self) -> bool:
        """bool: whether the segmentor has decode head"""
        return getattr(self, 'decode_head', None) is not None

    @abstractmethod
    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
        """Placeholder for extract features from images.

        Args:
            inputs (Tensor): The input images.

        Returns:
            The extracted features. The exact container is decided by the
            concrete segmentor (annotated as a list of tensors here --
            TODO confirm against the subclasses in use).
        """
        pass

    @abstractmethod
    def encode_decode(self, inputs: Tensor,
                      batch_data_samples: SampleList) -> Tensor:
        """Placeholder for encode images with backbone and decode into a
        semantic segmentation map of the same size as input."""
        pass

    def forward(self,
                inputs: Tensor,
                data_samples: OptSampleList = None,
                mode: str = 'tensor') -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`SegDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape (N, C, ...) in
                general.
            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_sem_seg`. Default to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of :obj:`SegDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of the three supported
                values.
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        elif mode == 'tensor':
            return self._forward(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    @abstractmethod
    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples."""
        pass

    @abstractmethod
    def predict(self,
                inputs: Tensor,
                data_samples: OptSampleList = None) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing."""
        pass

    @abstractmethod
    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process.

        Usually includes backbone, neck and head forward without any post-
        processing.
        """
        pass

    def postprocess_result(self,
                           seg_logits: Tensor,
                           data_samples: OptSampleList = None) -> SampleList:
        """Convert results list to `SegDataSample`.

        When ``data_samples`` is given, each sample's logits are cropped by
        the padding recorded in its ``metainfo``, un-flipped if the sample
        was flipped, and resized to ``metainfo['ori_shape']``. When
        ``data_samples`` is None, fresh :obj:`SegDataSample` objects are
        created and the logits are used as they are.

        This method reads ``self.align_corners`` (whenever ``data_samples``
        is given) and ``self.decode_head.threshold`` (whenever the logits
        have a single channel); neither is set by this base class, so the
        concrete segmentor has to provide them.

        Args:
            seg_logits (Tensor): The segmentation results, seg_logits from
                model of each input image. Must be 4-dimensional; it is
                unpacked as (batch_size, C, H, W).
            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_sem_seg`. Default to None.

        Returns:
            list[:obj:`SegDataSample`]: Segmentation results of the
            input images. Each SegDataSample usually contain:

            - ``pred_sem_seg``(PixelData): Prediction of semantic
                segmentation. For ``C > 1`` it is the argmax over the
                channel dimension; for ``C == 1`` it is the thresholded
                sigmoid output.
            - ``seg_logits``(PixelData): Predicted logits of semantic
                segmentation. For ``C > 1`` they are stored before
                normalization; for ``C == 1`` the stored values are the
                sigmoid output.

        Raises:
            AssertionError: If a sample is marked as flipped but its
                ``flip_direction`` is neither ``'horizontal'`` nor
                ``'vertical'``.
        """
        batch_size, C, H, W = seg_logits.shape

        # Without data samples there is no meta information to undo
        # padding/flip/resize, so only the raw predictions are attached.
        only_prediction = data_samples is None
        if only_prediction:
            data_samples = [SegDataSample() for _ in range(batch_size)]

        for i in range(batch_size):
            if only_prediction:
                i_seg_logits = seg_logits[i]
            else:
                img_meta = data_samples[i].metainfo

                # Remove the padded border. ``img_padding_size`` takes
                # precedence over ``padding_size``; no padding is assumed
                # when neither key is present.
                if 'img_padding_size' in img_meta:
                    padding_size = img_meta['img_padding_size']
                else:
                    padding_size = img_meta.get('padding_size', [0] * 4)
                (padding_left, padding_right, padding_top,
                 padding_bottom) = padding_size

                # Keep the batch dimension (1, C, h, w) for ``resize`` below.
                i_seg_logits = seg_logits[i:i + 1, :,
                                          padding_top:H - padding_bottom,
                                          padding_left:W - padding_right]

                # Undo the flip recorded in the meta information.
                flip = img_meta.get('flip', None)
                if flip:
                    flip_direction = img_meta.get('flip_direction', None)
                    assert flip_direction in ['horizontal', 'vertical']
                    if flip_direction == 'horizontal':
                        i_seg_logits = i_seg_logits.flip(dims=(3, ))
                    else:
                        i_seg_logits = i_seg_logits.flip(dims=(2, ))

                # Resize to the original image shape and drop the batch
                # dimension again.
                i_seg_logits = resize(
                    i_seg_logits,
                    size=img_meta['ori_shape'],
                    mode='bilinear',
                    align_corners=self.align_corners,
                    warning=False).squeeze(0)

            if C > 1:
                # Multi-channel: the predicted class is the channel argmax.
                i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
            else:
                # Single-channel: binarize the sigmoid output with the
                # decode head's threshold; ``.to(i_seg_logits)`` casts the
                # boolean mask back to the logits' dtype/device.
                i_seg_logits = i_seg_logits.sigmoid()
                i_seg_pred = (i_seg_logits >
                              self.decode_head.threshold).to(i_seg_logits)

            data_samples[i].set_data({
                'seg_logits':
                PixelData(**{'data': i_seg_logits}),
                'pred_sem_seg':
                PixelData(**{'data': i_seg_pred})
            })

        return data_samples
| |
|