| |
| import os |
| import sys |
| import shutil |
| import logging |
| import colorlog |
| from tqdm import tqdm |
| import time |
| import yaml |
| import random |
| import importlib |
| from PIL import Image |
| from warnings import simplefilter |
| import imageio |
| import math |
| import collections |
| import json |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.optim import Adam |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| from torch.utils.data import DataLoader, Dataset |
| from einops import rearrange, repeat |
| import torch.distributed as dist |
| from torchvision import datasets, transforms, utils |
|
|
# Keep third-party libraries that log through the root logger quiet.
logging.root.setLevel(logging.WARNING)
# Silence FutureWarning noise emitted by dependencies.
simplefilter('ignore', category=FutureWarning)
|
|
class _UniqueLogger:
    """Proxy that emits ``info``/``warning`` records on rank 0 only.

    Any other attribute (``debug``, ``error``, ``handlers`` ...) is forwarded
    to the wrapped logger unchanged, on every rank.
    """

    def __init__(self, logger):
        self.logger = logger
        # NOTE: despite the attribute name, get_rank() returns the rank within
        # the default process group (the global rank).
        self.local_rank = torch.distributed.get_rank()

    def __getattr__(self, name):
        # Only called for names that are not defined on the proxy itself.
        if name == 'logger':
            # Guard against infinite recursion if `logger` was never set.
            raise AttributeError(name)
        return getattr(self.logger, name)

    def info(self, msg, *args, **kwargs):
        if self.local_rank == 0:
            return self.logger.info(msg, *args, **kwargs)

    def warning(self, msg, *args, **kwargs):
        if self.local_rank == 0:
            return self.logger.warning(msg, *args, **kwargs)


def get_logger(filename=None):
    """Return the shared ``'utils'`` logger with colored console output.

    examples:
        logger = get_logger('try_logging.txt')

        logger.debug("Do something.")
        logger.info("Start print log.")
        logger.warning("Something maybe fail.")
        try:
            raise ValueError()
        except ValueError:
            logger.error("Error", exc_info=True)

    tips:
        DO NOT logger.info(some big tensors) since color may not be helpful.

    Args:
        filename: optional path of a log file; when given, records are also
            written there (without colors).

    Returns:
        The ``logging.Logger`` named ``'utils'``. When ``torch.distributed``
        has an initialized process group, a thin proxy is returned instead
        that drops ``info``/``warning`` records on every rank except 0.
    """
    logger = logging.getLogger('utils')
    level = logging.DEBUG
    logger.setLevel(level)
    # Do not hand records to the root logger (avoids duplicated output).
    logger.propagate = False

    format_str = '[%(asctime)s <%(filename)s:%(lineno)d> %(funcName)s] %(message)s'

    # The 'utils' logger is a process-wide singleton, so repeated calls must
    # not stack extra handlers (each one would print every record again).
    # FileHandler subclasses StreamHandler, hence the exact type check.
    has_console = any(type(h) is logging.StreamHandler for h in logger.handlers)
    if not has_console:
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(level)
        colored_formatter = colorlog.ColoredFormatter(
            '%(log_color)s' + format_str,
            datefmt='%Y-%m-%d %H:%M:%S',
            reset=True,
            log_colors={
                'DEBUG': 'cyan',
                'WARNING': 'yellow',
                'ERROR': 'red',
                # Was 'reg,bg_white': an unknown color name that made
                # colorlog fail when formatting CRITICAL records.
                'CRITICAL': 'red,bg_white',
            },
        )
        stream_handler.setFormatter(colored_formatter)
        logger.addHandler(stream_handler)

    if filename:
        log_path = os.path.abspath(filename)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == log_path
            for h in logger.handlers
        )
        if not already_attached:
            file_handler = logging.FileHandler(filename)
            file_handler.setLevel(level)
            file_handler.setFormatter(logging.Formatter(format_str))
            logger.addHandler(file_handler)

    # Rank filtering only makes sense once a process group exists; otherwise
    # (single process, or called before init_process_group) use the plain
    # logger.
    dist_module = torch.distributed
    if dist_module.is_available() and dist_module.is_initialized():
        return _UniqueLogger(logger)
    return logger
|
|
|
|
# Module-wide logger used by the helpers below (console only, no log file).
logger = get_logger(filename=None)
|
|
def split_filename(filename):
    """Split a path into ``(dirname, rootname, extname)``.

    The path is made absolute first. ``extname`` is the text after the last
    ``.`` of the base name (without the dot), or ``None`` when the base name
    contains no dot at all.
    """
    folder, base = os.path.split(os.path.abspath(filename))
    stem, dot, suffix = base.rpartition('.')
    if not dot:
        # No '.' in the base name: everything is the root, no extension.
        return folder, base, None
    return folder, stem, suffix
|
|
def data2file(data, filename, type=None, override=False, printable=False, **kwargs):
    """Save ``data`` to ``filename`` in a format chosen by the extension.

    Args:
        data: payload to write. Passed as-is to ``utils.save_image`` for
            images and to ``imageio.mimsave`` for GIFs; iterated item by item
            for ``txt``.
        filename: destination path; missing parent directories are created.
        type: optional format name overriding the file extension
            (``'jpg'``, ``'png'``, ``'jpeg'``, ``'gif'`` or ``'txt'``).
        override: when False (default) an existing file is left untouched.
        printable: log a message after a successful save.
        **kwargs: forwarded to ``utils.save_image`` for images; ``duration``
            is read for GIFs; ``max_step`` caps the number of lines for txt.

    Raises:
        ValueError: if the resolved format is not supported.
    """
    dirname, _, extname = split_filename(filename)
    if type:
        extname = type
    os.makedirs(dirname, exist_ok=True)

    if os.path.exists(filename) and not override:
        logger.info(
            'Did not save data to %s because file exists and override is False',
            os.path.abspath(filename))
        return

    if extname in ('jpg', 'png', 'jpeg'):
        utils.save_image(data, filename, **kwargs)
    elif extname == 'gif':
        imageio.mimsave(filename, data, format='GIF',
                        duration=kwargs.get('duration'), loop=0)
    elif extname == 'txt':
        max_step = kwargs.get('max_step')
        if max_step is None:
            # np.Infinity was removed in NumPy 2.0; math.inf works everywhere.
            max_step = math.inf
        with open(filename, 'w', encoding='utf-8') as f:
            for i, e in enumerate(data):
                if i >= max_step:
                    break
                f.write(str(e) + '\n')
    else:
        raise ValueError('Do not support this type')

    if printable:
        logger.info('Saved data to %s', os.path.abspath(filename))
|
|
|
|
def file2data(filename, type=None, printable=True, **kwargs):
    """Load ``filename`` with a reader chosen by its extension.

    Supported formats:
        * ``pth`` / ``ckpt``: ``torch.load`` (honours ``map_location``).
        * ``txt``: list of lines. With ``top=N`` the first N lines are
          returned exactly as ``readline`` yields them (trailing newline
          kept); otherwise all non-empty lines without newlines.
        * ``yaml`` / ``yml``: parsed YAML document.

    Args:
        filename: path of the file to read.
        type: optional format name overriding the file extension.
        printable: log a message once the data is loaded.
        **kwargs: ``map_location`` for checkpoints, ``top`` for txt.

    Returns:
        The loaded object.

    Raises:
        ValueError: if the resolved format is not supported.
    """
    _, _, extname = split_filename(filename)
    if type:
        extname = type

    if extname in ('pth', 'ckpt'):
        # NOTE(security): torch.load unpickles; only load trusted checkpoints.
        data = torch.load(filename, map_location=kwargs.get('map_location'))
    elif extname == 'txt':
        top = kwargs.get('top', None)
        with open(filename, encoding='utf-8') as f:
            if top:
                data = [f.readline() for _ in range(top)]
            else:
                data = [e for e in f.read().split('\n') if e]
    elif extname in ('yaml', 'yml'):
        with open(filename, 'r') as f:
            # PyYAML >= 6 requires an explicit Loader. FullLoader matches what
            # PyYAML 5.x used when none was given.
            # NOTE(security): use yaml.safe_load if files may be untrusted.
            data = yaml.load(f, Loader=yaml.FullLoader)
    else:
        raise ValueError('type can only support pth, ckpt, txt, yaml, yml')

    if printable:
        logger.info('Loaded data from %s', os.path.abspath(filename))
    return data
|
|
|
|
def ensure_dirname(dirname, override=False):
    """Make sure ``dirname`` exists, optionally wiping it first.

    With ``override=True`` an existing path is removed with
    ``shutil.rmtree`` before being recreated. A failed removal is reported
    as ``ValueError``.
    """
    if override and os.path.exists(dirname):
        logger.info('Removing dirname: %s' % os.path.abspath(dirname))
        try:
            shutil.rmtree(dirname)
        except OSError as e:
            raise ValueError('Failed to delete %s because %s' % (dirname, e))

    # Nothing left to do when the directory is (still) there.
    if os.path.exists(dirname):
        return

    logger.info('Making dirname: %s' % os.path.abspath(dirname))
    os.makedirs(dirname, exist_ok=True)
|
|
|
|
def import_filename(filename):
    """Import the Python source file ``filename`` and return the module.

    The module is executed and registered in ``sys.modules`` under the fixed
    name ``"mymodule"``, so a later call replaces the entry of an earlier one.

    Args:
        filename: path to a Python source file.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: if no import spec/loader can be built for ``filename``
            (e.g. the file does not have a Python source suffix).
    """
    # `import importlib` alone does not guarantee that the `util` submodule
    # is loaded, so import it explicitly where it is used.
    import importlib.util

    spec = importlib.util.spec_from_file_location("mymodule", filename)
    if spec is None or spec.loader is None:
        raise ImportError('Cannot import %s as a Python module' % (filename,))
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module can be resolved while it runs.
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return module
|
|
|
|
def adaptively_load_state_dict(target, state_dict):
    """Load into ``target`` only the entries of ``state_dict`` that fit.

    An entry fits when its key exists in ``target.state_dict()`` and its
    ``size()`` matches. If the size comparison raises, the size check is
    dropped and every entry with a shared key is used. Entries that are
    skipped or missing are reported through the module logger.
    """
    target_dict = target.state_dict()

    try:
        matched = {}
        for key, value in state_dict.items():
            if key not in target_dict:
                continue
            if value.size() == target_dict[key].size():
                matched[key] = value
        common_dict = matched
    except Exception as e:
        # Values without a usable size(): fall back to matching on keys only.
        logger.warning('load error %s', e)
        common_dict = {}
        for key, value in state_dict.items():
            if key in target_dict:
                common_dict[key] = value

    if 'param_groups' in common_dict:
        current_params = target.state_dict()['param_groups'][0]['params']
        if common_dict['param_groups'][0]['params'] != current_params:
            logger.warning('Detected mismatch params, auto adapte state_dict to current')
            common_dict['param_groups'][0]['params'] = current_params

    target_dict.update(common_dict)
    target.load_state_dict(target_dict)

    missing_keys = []
    for key in target_dict.keys():
        if key not in common_dict:
            missing_keys.append(key)
    unexpected_keys = []
    for key in state_dict.keys():
        if key not in common_dict:
            unexpected_keys.append(key)

    if unexpected_keys:
        logger.warning(
            f"Some weights of state_dict were not used in target: {unexpected_keys}"
        )
    if missing_keys:
        logger.warning(
            f"Some weights of state_dict are missing used in target {missing_keys}"
        )
    if not unexpected_keys and not missing_keys:
        logger.warning("Strictly Loaded state_dict.")
|
|
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs for reproducible runs.

    Args:
        seed: integer seed applied to every generator.

    Also exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic
    algorithms.
    """
    random.seed(seed)
    # Was misspelt 'PYHTONHASHSEED'. Hash randomization of the current
    # interpreter is fixed at startup, so this only affects child processes.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible CUDA device, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
|
|
def image2pil(filename):
    """Open ``filename`` with PIL and return the resulting image object."""
    pil_image = Image.open(filename)
    return pil_image
|
|
|
|
def image2arr(filename):
    """Read the image at ``filename`` and return it as a NumPy array."""
    return pil2arr(image2pil(filename))
|
|
|
|
| |
def pil2arr(pil):
    """Convert a PIL image, or a list of them, to a NumPy array.

    A single image goes straight through ``np.array``. A ``list`` of images
    is converted frame by frame to RGB ``uint8`` arrays of shape
    ``(height, width, 3)`` and stacked along a new leading axis.
    """
    if not isinstance(pil, list):
        return np.array(pil)

    frames = []
    for image in pil:
        pixels = np.array(image.convert('RGB').getdata(), dtype=np.uint8)
        # PIL reports size as (width, height); arrays are row-major.
        height, width = image.size[1], image.size[0]
        frames.append(pixels.reshape(height, width, 3))
    return np.array(frames)
|
|
|
|
def arr2pil(arr):
    """Convert an image array to PIL RGB image(s).

    A 3-dimensional array yields one image; a 4-dimensional array yields a
    list with one image per entry along the first axis. Values are cast to
    ``uint8`` first. Any other ``ndim`` raises ``ValueError``.
    """
    def _as_rgb(pixels):
        return Image.fromarray(pixels.astype('uint8'), 'RGB')

    rank = arr.ndim
    if rank == 3:
        return _as_rgb(arr)
    if rank == 4:
        return [_as_rgb(frame) for frame in list(arr)]
    raise ValueError('arr must has ndim of 3 or 4, but got %s' % arr.ndim)
|
|
def notebook_show(*images):
    """Display each of ``images`` inline via ``IPython.display``."""
    # Imported lazily so the module does not require IPython at import time;
    # aliased to avoid confusion with the module-level PIL ``Image``.
    from IPython.display import Image as NotebookImage, display

    rendered = [NotebookImage(item) for item in images]
    display(*rendered)