| | |
| | import copy |
| | import logging |
| | import os |
| | import shutil |
| | import tempfile |
| | import time |
| | from unittest import TestCase |
| | from uuid import uuid4 |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.distributed import destroy_process_group |
| | from torch.utils.data import Dataset |
| |
|
| | import mmengine.hooks |
| | import mmengine.optim |
| | from mmengine.config import Config |
| | from mmengine.dist import is_distributed |
| | from mmengine.evaluator import BaseMetric |
| | from mmengine.logging import MessageHub, MMLogger |
| | from mmengine.model import BaseModel |
| | from mmengine.registry import DATASETS, METRICS, MODELS, DefaultScope |
| | from mmengine.runner import Runner |
| | from mmengine.visualization import Visualizer |
| |
|
| |
|
class ToyModel(BaseModel):
    """Minimal two-layer linear model used to exercise the runner.

    The network is ``nn.Linear(2, 2)`` followed by ``nn.Linear(2, 1)``, so
    the last dimension of the input must be 2.
    """

    def __init__(self, data_preprocessor=None):
        super().__init__(data_preprocessor=data_preprocessor)
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 1)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        """Run both linear layers and shape the result according to ``mode``.

        Args:
            inputs: A tensor, or a list of tensors that is stacked into one.
            data_samples: Targets, as a tensor or a list of tensors that is
                stacked into one. Only read when ``mode == 'loss'``.
            mode: ``'tensor'`` or ``'predict'`` return the raw output of the
                second linear layer; ``'loss'`` returns ``dict(loss=...)``.

        Returns:
            The output tensor, a dict holding the loss, or ``None`` when
            ``mode`` is none of the three values above.
        """
        # Lists are collapsed into a single batched tensor first.
        if isinstance(inputs, list):
            inputs = torch.stack(inputs)
        if isinstance(data_samples, list):
            data_samples = torch.stack(data_samples)

        preds = self.linear2(self.linear1(inputs))

        if mode == 'loss':
            # Toy loss: the summed difference between targets and outputs.
            return dict(loss=(data_samples - preds).sum())
        if mode in ('tensor', 'predict'):
            return preds
        # Any other mode falls through without a result.
        return None
| |
|
| |
|
class ToyDataset(Dataset):
    """In-memory toy dataset: 12 random 2-element samples, all labels 1.

    ``data`` and ``label`` are class attributes created once when the class
    body runs, so every instance reads the same tensors.
    """
    METAINFO = {}
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        """Return the class-level ``METAINFO`` dict."""
        return self.METAINFO

    def __len__(self):
        """Return the number of samples (first dimension of ``data``)."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return the sample at ``index`` as an inputs/data_samples dict."""
        sample = self.data[index]
        target = self.label[index]
        return {'inputs': sample, 'data_samples': target}
| |
|
| |
|
class ToyMetric(BaseMetric):
    """Dummy metric that always reports an accuracy of 1."""

    def __init__(self, collect_device='cpu', dummy_metrics=None):
        super().__init__(collect_device=collect_device)
        # Stored as-is; nothing in this class reads it back.
        self.dummy_metrics = dummy_metrics

    def process(self, data_batch, predictions):
        """Record one constant result per call; both arguments are ignored.

        ``self.results`` is presumably provided by ``BaseMetric``.
        """
        self.results.append(dict(acc=1))

    def compute_metrics(self, results):
        """Return a fixed ``acc`` of 1 regardless of ``results``."""
        return {'acc': 1}
| |
|
| |
|
class RunnerTestCase(TestCase):
    """A test case to build runner easily.

    `RunnerTestCase` will do the following things:

    1. Registers a toy model, a toy metric, and a toy dataset, which can be
       used to run the `Runner` successfully.
    2. Provides epoch based and iteration based cfg to build runner.
    3. Provides `build_runner` method to build runner easily.
    4. Cleans the global variables used by the runner.
    """
    # Values exported to the environment by ``setup_dist_env``. The dict lives
    # on the class, so the port increment done there is shared by every test.
    # ``MASTER_PORT`` is kept as an int so that it can be incremented.
    dist_cfg = dict(
        MASTER_ADDR='127.0.0.1',
        MASTER_PORT=29600,
        RANK='0',
        WORLD_SIZE='1',
        LOCAL_RANK='0')

    def setUp(self) -> None:
        """Create the work dir, register the toy modules and build both cfgs.

        After this method ``self.epoch_based_cfg`` and ``self.iter_based_cfg``
        are available, both using ``self.temp_dir.name`` as ``work_dir``.
        """
        self.temp_dir = tempfile.TemporaryDirectory()
        # ``force=True`` lets every test re-register the toy modules even if
        # an entry with the same name is already present in the registry.
        MODELS.register_module(module=ToyModel, force=True)
        METRICS.register_module(module=ToyMetric, force=True)
        DATASETS.register_module(module=ToyDataset, force=True)
        epoch_based_cfg = dict(
            work_dir=self.temp_dir.name,
            model=dict(type='ToyModel'),
            train_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            val_evaluator=[dict(type='ToyMetric')],
            test_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            test_evaluator=[dict(type='ToyMetric')],
            optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)),
            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
            val_cfg=dict(),
            test_cfg=dict(),
            default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
            custom_hooks=[],
            env_cfg=dict(dist_cfg=dict(backend='nccl')),
            experiment_name='test1')
        self.epoch_based_cfg = Config(epoch_based_cfg)

        # Derive the iteration based cfg from a deep copy so that the epoch
        # based cfg is left untouched by the overrides below.
        self.iter_based_cfg: Config = copy.deepcopy(self.epoch_based_cfg)
        self.iter_based_cfg.train_dataloader = dict(
            dataset=dict(type='ToyDataset'),
            sampler=dict(type='InfiniteSampler', shuffle=True),
            batch_size=3,
            num_workers=0)
        self.iter_based_cfg.log_processor = dict(by_epoch=False)

        self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
        self.iter_based_cfg.default_hooks = dict(
            logger=dict(type='LoggerHook', interval=1),
            checkpoint=dict(
                type='CheckpointHook', interval=12, by_epoch=False))

    def tearDown(self):
        """Undo everything ``setUp`` and a built runner left behind."""
        # Shut logging down before deleting the temporary directory --
        # presumably because handlers may still hold files inside it open.
        logging.shutdown()
        # Empty the ``_instance_dict`` of each manager class.
        MMLogger._instance_dict.clear()
        Visualizer._instance_dict.clear()
        DefaultScope._instance_dict.clear()
        MessageHub._instance_dict.clear()
        # Unregister the toy modules; a missing entry is not an error.
        MODELS.module_dict.pop('ToyModel', None)
        METRICS.module_dict.pop('ToyMetric', None)
        DATASETS.module_dict.pop('ToyDataset', None)
        try:
            self.temp_dir.cleanup()
        finally:
            # Tear the process group down even if the directory could not be
            # removed, so that it does not leak into the following tests.
            if is_distributed():
                destroy_process_group()

    def build_runner(self, cfg: Config):
        """Build a ``Runner`` from ``cfg``.

        Args:
            cfg: Runner config. Its ``experiment_name`` is overwritten in
                place with a freshly generated ``self.experiment_name``.

        Returns:
            The object returned by ``Runner.from_cfg(cfg)``.
        """
        cfg.experiment_name = self.experiment_name
        runner = Runner.from_cfg(cfg)
        return runner

    @property
    def experiment_name(self):
        """Return a new unique experiment name on every access.

        The name has the form ``<test method>_<timestamp>_<uuid4>``. The
        random UUID keeps names distinct even when two accesses observe the
        same timestamp.
        """
        return f'{self._testMethodName}_{time.time()}_{uuid4()}'

    def setup_dist_env(self):
        """Export the ``dist_cfg`` values as environment variables.

        ``MASTER_PORT`` is incremented on every call, so each call exports a
        port different from the previous one.
        """
        self.dist_cfg['MASTER_PORT'] += 1
        os.environ['MASTER_PORT'] = str(self.dist_cfg['MASTER_PORT'])
        os.environ['MASTER_ADDR'] = self.dist_cfg['MASTER_ADDR']
        os.environ['RANK'] = self.dist_cfg['RANK']
        os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE']
        os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK']

    def clear_work_dir(self):
        """Delete everything inside the work dir, keeping the dir itself."""
        # Shut logging down first, as in ``tearDown``, before removing files.
        logging.shutdown()
        for filename in os.listdir(self.temp_dir.name):
            filepath = os.path.join(self.temp_dir.name, filename)
            if os.path.isfile(filepath):
                os.remove(filepath)
            else:
                shutil.rmtree(filepath)
| |
|