import os
import re
import random
import time
import torch
import numpy as np
from os import path as osp

from .dist_util import master_only
from .logger import get_root_logger

# ---------------------------
# GPU / MPS Compatibility
# ---------------------------

# MPS (Apple Silicon) support needs PyTorch >= 1.12.0. Only the leading
# "major.minor.patch" of torch.__version__ is parsed, so local suffixes such
# as "+cu118" or "a0" do not affect the comparison.
try:
    version_match = re.findall(
        r"^([0-9]+)\.([0-9]+)\.([0-9]+)", torch.__version__
    )[0]
    IS_HIGH_VERSION = [int(x) for x in version_match] >= [1, 12, 0]
except (IndexError, TypeError, ValueError):
    # IndexError: the version string did not match the pattern at all.
    # Treat an unparseable version as "no MPS" instead of failing at import.
    IS_HIGH_VERSION = False


def _mps_available():
    """Return True if this PyTorch build is new enough and MPS is usable."""
    if not IS_HIGH_VERSION:
        return False
    mps_backend = getattr(torch.backends, "mps", None)
    return mps_backend is not None and mps_backend.is_available()


def _cuda_available():
    """Return True if both CUDA and cuDNN are available."""
    return torch.cuda.is_available() and torch.backends.cudnn.is_available()


def gpu_is_available():
    """Return True if CUDA (with cuDNN) or MPS is available."""
    return _mps_available() or _cuda_available()


def get_device(gpu_id=None):
    """Return the best available device, preferring MPS, then CUDA, then CPU.

    Args:
        gpu_id (int, optional): CUDA device index. It is only used for CUDA.
            When it is not an int, the un-indexed "cuda" device is returned.

    Returns:
        torch.device: ``mps``, ``cuda[:gpu_id]`` or ``cpu``.
    """
    # Apple MPS
    if _mps_available():
        return torch.device("mps")
    # NVIDIA CUDA
    if _cuda_available():
        gpu_str = f":{gpu_id}" if isinstance(gpu_id, int) else ""
        return torch.device("cuda" + gpu_str)
    # CPU fallback
    return torch.device("cpu")


# ---------------------------
# Utilities
# ---------------------------

def set_random_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (including all CUDA devices).

    Args:
        seed (int): Seed value applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def get_time_str():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    return time.strftime('%Y%m%d_%H%M%S', time.localtime())


def mkdir_and_rename(path):
    """Create ``path``, first archiving any existing entry at that location.

    If ``path`` exists, it is renamed to ``<path>_archived_<timestamp>`` and a
    fresh, empty directory is created in its place.

    Args:
        path (str): Directory to create.
    """
    if osp.exists(path):
        new_name = path + '_archived_' + get_time_str()
        print(f'Path already exists. Renamed to {new_name}', flush=True)
        os.rename(path, new_name)
    os.makedirs(path, exist_ok=True)


# Substrings marking entries of opt['path'] that are not directories to
# create. These include flags, checkpoint files, resume state and state-dict
# key names such as 'param_key_g': 'params_ema'.
_NON_DIR_PATH_KEYS = ('strict_load', 'pretrain_network', 'resume', 'param_key')


@master_only
def make_exp_dirs(opt):
    """Create the experiment (training) or results (testing) directory tree.

    The root directory (``experiments_root`` when ``opt['is_train']`` is truthy,
    otherwise ``results_root``) is archived and recreated via
    ``mkdir_and_rename``. Every other ``opt['path']`` entry is then created,
    except those whose key contains one of ``_NON_DIR_PATH_KEYS``. The
    ``master_only`` decorator presumably restricts execution to the main
    process; confirm in ``dist_util``.

    Args:
        opt (dict): Options dict with ``'is_train'`` and a ``'path'`` mapping.
    """
    path_opt = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(path_opt.pop(root_key))
    for key, path in path_opt.items():
        if any(marker in key for marker in _NON_DIR_PATH_KEYS):
            continue
        os.makedirs(path, exist_ok=True)


def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Lazily yield the non-hidden files under ``dir_path``.

    Args:
        dir_path (str): Directory to scan.
        suffix (str | tuple[str], optional): Only yield paths ending with this
            suffix (or any suffix in the tuple). ``None`` yields every file.
        recursive (bool): Descend into sub-directories.
        full_path (bool): Yield ``entry.path`` instead of a path relative to
            ``dir_path``.

    Returns:
        Generator[str]: Matching file paths. Scanning starts on first iteration.

    Raises:
        TypeError: If ``suffix`` is neither ``None``, a str nor a tuple.
    """
    # Validate eagerly so a bad argument fails at the call site rather than
    # midway through iteration.
    if suffix is not None and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scan(path):
        for entry in os.scandir(path):
            if entry.is_file() and not entry.name.startswith('.'):
                file_path = entry.path if full_path else osp.relpath(entry.path, root)
                if suffix is None or file_path.endswith(suffix):
                    yield file_path
            elif entry.is_dir() and recursive:
                yield from _scan(entry.path)

    return _scan(dir_path)


def check_resume(opt, resume_iter):
    """Point every ``network_*`` at its checkpoint when resuming training.

    If ``opt['path']['resume_state']`` is truthy, each ``pretrain_network_*``
    path is overwritten with
    ``<opt['path']['models']>/net_<name>_<resume_iter>.pth``. Networks listed
    in ``opt['path']['ignore_resume_networks']`` are skipped. ``opt`` is
    modified in place.

    Args:
        opt (dict): Options dict with a ``'path'`` mapping and ``network_*`` keys.
        resume_iter (int): Iteration number of the checkpoint to load.
    """
    logger = get_root_logger()
    if not opt['path']['resume_state']:
        return

    networks = [k for k in opt.keys() if k.startswith('network_')]
    if any(opt['path'].get(f'pretrain_{n}') for n in networks):
        logger.warning('pretrain_network path will be ignored during resuming.')

    ignored = opt['path'].get('ignore_resume_networks')
    for network in networks:
        basename = network.replace('network_', '')
        if ignored is None or basename not in ignored:
            opt['path'][f'pretrain_{network}'] = osp.join(
                opt['path']['models'], f'net_{basename}_{resume_iter}.pth'
            )
            logger.info(f"Set pretrain for {network}")


def sizeof_fmt(size, suffix='B'):
    """Format a byte count as a human-readable string in powers of 1024.

    Args:
        size (int | float): Size to format. Negative values are scaled by
            their magnitude.
        suffix (str): Unit suffix appended after the prefix. Default: 'B'.

    Returns:
        str: e.g. ``sizeof_fmt(2048) == '2.0 KB'``.
    """
    # Include E and Z; without them anything >= 1024 P was mislabelled as Y.
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        if abs(size) < 1024:
            return f"{size:3.1f} {unit}{suffix}"
        size /= 1024
    return f"{size:3.1f} Y{suffix}"