| | |
| | from mmengine.model import is_model_wrapper |
| | from mmengine.runner import ValLoop |
| |
|
| | from mmdet.registry import LOOPS |
| |
|
| |
|
@LOOPS.register_module()
class TeacherStudentValLoop(ValLoop):
    """Loop for validation of model teacher and student.

    The (unwrapped) model must expose ``teacher`` and ``student`` attributes
    and a mutable ``semi_test_cfg`` mapping. Its ``'predict_on'`` entry is
    switched to each branch name in turn while the full validation
    dataloader is run. The metrics of each pass are reported under a
    ``'teacher/'`` or ``'student/'`` prefix.
    """

    def run(self) -> dict:
        """Launch validation for model teacher and student.

        ``semi_test_cfg['predict_on']`` is restored to its pre-validation
        state afterwards, even if an iteration or the evaluator raises. If
        the key did not exist beforehand, it is removed again rather than
        being left set to ``None``.

        Returns:
            dict: Metrics of both branches, keyed ``'<branch>/<metric>'``.
        """
        self.runner.call_hook('before_val')
        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()

        model = self.runner.model
        if is_model_wrapper(model):
            # Wrapped models (e.g. distributed wrappers) keep the real
            # detector in ``.module``.
            model = model.module
        assert hasattr(model, 'teacher'), 'model has no `teacher` attribute'
        assert hasattr(model, 'student'), 'model has no `student` attribute'

        semi_test_cfg = model.semi_test_cfg
        # Remember whether the key existed at all, so that an absent key is
        # not turned into an explicit ``None`` when it is restored.
        had_predict_on = 'predict_on' in semi_test_cfg
        original_predict_on = semi_test_cfg.get('predict_on', None)

        multi_metrics = dict()
        try:
            for branch in ('teacher', 'student'):
                semi_test_cfg['predict_on'] = branch
                for idx, data_batch in enumerate(self.dataloader):
                    self.run_iter(idx, data_batch)
                metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
                multi_metrics.update(
                    {'/'.join((branch, k)): v
                     for k, v in metrics.items()})
        finally:
            # Always undo the temporary branch switch, even on failure.
            if had_predict_on:
                semi_test_cfg['predict_on'] = original_predict_on
            else:
                semi_test_cfg.pop('predict_on', None)

        self.runner.call_hook('after_val_epoch', metrics=multi_metrics)
        self.runner.call_hook('after_val')
        return multi_metrics
| |
|