| | import uniperceiver.utils.comm as comm |
| | import torch |
| | import numpy as np |
| | from uniperceiver.utils.events import get_event_storage |
| | from typing import Dict |
| | from uniperceiver.datasets import ( |
| | build_standard_valtest_loader, |
| | build_unified_train_loader, |
| | ) |
| | import weakref |
| |
|
def write_metrics(loss_dict: Dict[str, torch.Tensor],
                  data_time: float,
                  prefix: str = "",
                  ):
    """Log scalar losses and dataloader timing to the current event storage.

    Tensor values are detached, moved to CPU and converted to Python numbers
    with ``.item()``; any other value is logged unchanged. Nothing is written
    unless this is the main process.

    Args:
        loss_dict (dict): mapping from name to scalar loss (tensor or number).
        data_time (float): time taken by the dataloader iteration.
        prefix (str): prepended to the logged "total_loss" name and to each
            per-key scalar name. It is not applied to "data_time".
    """
    local_metrics = {
        name: (value.detach().cpu().item()
               if isinstance(value, torch.Tensor) else value)
        for name, value in loss_dict.items()
    }
    local_metrics["data_time"] = data_time

    # Only this process's metrics are collected (there is no cross-rank
    # gather), so the max/mean reductions below each see a single entry.
    gathered = [local_metrics]
    if not comm.is_main_process():
        return

    storage = get_event_storage()

    slowest_data_time = np.max([entry.pop("data_time") for entry in gathered])
    storage.put_scalar("data_time", slowest_data_time)

    reduced = {}
    for name in gathered[0]:
        reduced[name] = np.mean([entry[name] for entry in gathered])

    # The total sums every remaining entry, including 'null_loss'.
    grand_total = sum(reduced.values())
    storage.put_scalar("{}total_loss".format(prefix), grand_total)

    # Per-key scalars are only logged when there is more than one entry;
    # 'null_loss' is never logged on its own.
    if len(reduced) <= 1:
        return
    for name, value in reduced.items():
        if name == 'null_loss':
            continue
        storage.put_scalar(f'{prefix}{name}', value)
| |
|
def build_writers(cfg, max_iter):
    """Return the default event writers for a run of ``max_iter`` iterations.

    The writers are created by ``default_writers`` with ``cfg.OUTPUT_DIR`` as
    the output directory.
    """
    # Imported inside the function, presumably to avoid a circular import
    # with uniperceiver.engine (verify before hoisting).
    from uniperceiver.engine.defaults import default_writers as _make_writers

    output_dir = cfg.OUTPUT_DIR
    return _make_writers(output_dir, max_iter)
| |
|
def build_train_loader(cfg, task_cfg, model):
    """Build the unified training dataloader.

    Args:
        cfg: global config. Reads ``cfg.DATALOADER.UNIFIED_DATASET`` and
            ``cfg.DATALOADER.LOAD_INLABEL``.
        task_cfg: per-task configs, forwarded to ``build_unified_train_loader``.
        model: the (possibly wrapped) model. When ``LOAD_INLABEL`` is set, the
            loader receives a ``weakref.proxy`` to ``comm.unwrap_model(model)``;
            otherwise it receives ``None``.

    Returns:
        Whatever ``build_unified_train_loader`` returns.

    Raises:
        NotImplementedError: if ``cfg.DATALOADER.UNIFIED_DATASET`` is false.
    """
    if not cfg.DATALOADER.UNIFIED_DATASET:
        raise NotImplementedError('please use unified dataset.')

    # A weak proxy means the loader does not hold a strong reference to the
    # model (presumably so it does not keep the model alive; confirm).
    loader_model = (weakref.proxy(comm.unwrap_model(model))
                    if cfg.DATALOADER.LOAD_INLABEL else None)
    return build_unified_train_loader(cfg, task_cfg, model=loader_model)
| |
|
def build_test_loader(cfg, task_cfg):
    """Build one test-stage loader per entry of ``task_cfg``.

    Args:
        cfg: global config. It is not read by this function.
        task_cfg: mapping from task name to that task's config. Each value is
            passed to ``build_standard_valtest_loader`` together with the full
            mapping.

    Returns:
        dict: task name -> loader built with ``stage='test'``.
    """
    # Task names that are built with multi_gpu_eval=True; every other task
    # gets multi_gpu_eval=False.
    multi_gpu_tasks = (
        'K400_retrieve', 'imagenet', 'vqa', 'mscoco_caption',
        'flickr30k_caption', 'K700_retrieve', 'imagenet_caption',
    )
    return {
        task_name: build_standard_valtest_loader(
            task_conf, task_cfg, stage='test',
            multi_gpu_eval=task_name in multi_gpu_tasks)
        for task_name, task_conf in task_cfg.items()
    }
| |
|
def build_val_loader(cfg, task_cfg):
    """Build one validation-stage loader per entry of ``task_cfg``.

    Args:
        cfg: global config. It is not read by this function.
        task_cfg: mapping from task name to that task's config. Each value is
            passed to ``build_standard_valtest_loader`` together with the full
            mapping.

    Returns:
        dict: task name -> loader built with ``stage='val'``.
    """
    # Task names that are built with multi_gpu_eval=True.
    multi_gpu_names = [
        'K400_retrieve', 'imagenet', 'vqa', 'mscoco_caption',
        'flickr30k_caption', 'K700_retrieve', 'imagenet_caption',
    ]
    val_loaders = {}
    for task_name, task_conf in task_cfg.items():
        use_multi_gpu = task_name in multi_gpu_names
        val_loaders[task_name] = build_standard_valtest_loader(
            task_conf, task_cfg, stage='val', multi_gpu_eval=use_multi_gpu)
    return val_loaders
| |
|
def get_batch_data(cfg, train_data_loader_iter, train_data_loader):
    """Fetch the next training batch.

    Args:
        cfg: config. Reads ``cfg.DATALOADER.FAKE_DATA``.
        train_data_loader_iter: iterator to pull the next batch from.
        train_data_loader: iterable used to build a fresh iterator when
            ``train_data_loader_iter`` is exhausted.

    Returns:
        The next batch.

    Raises:
        NotImplementedError: if ``cfg.DATALOADER.FAKE_DATA`` is set. Fake-data
            generation was never implemented; previously this path failed
            with an ``UnboundLocalError`` on ``data``.
        StopIteration: if ``train_data_loader`` itself yields nothing.
    """
    if cfg.DATALOADER.FAKE_DATA:
        # TODO: the original stub only chose a batch size of 32 and never
        # produced any data.
        raise NotImplementedError('fake data generation is not implemented.')

    try:
        return next(train_data_loader_iter)
    except StopIteration:
        # NOTE(review): the fresh iterator is local to this call and is not
        # returned to the caller. Once ``train_data_loader_iter`` is
        # exhausted, every later call builds a new iterator and returns its
        # first batch. Fixing that requires the caller to receive the new
        # iterator, which would be an interface change.
        return next(iter(train_data_loader))
| |
|