Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import random | |
| import time | |
| import torch | |
| import numpy as np | |
| from os import path as osp | |
| from .dist_util import master_only | |
| from .logger import get_root_logger | |
| # --------------------------- | |
| # GPU / MPS Compatibility | |
| # --------------------------- | |
# Check if PyTorch >= 1.12, which is required for MPS (Apple Silicon)
# Parse the leading "major.minor.patch" of torch.__version__. The anchored
# regex ignores local/pre-release suffixes such as "+cu118" or "a0".
try:
    version_match = re.findall(
        r"^([0-9]+)\.([0-9]+)\.([0-9]+)",
        torch.__version__,
    )[0]
    # Lists compare element-wise (lexicographically): [2, 0, 1] >= [1, 12, 0].
    IS_HIGH_VERSION = [int(x) for x in version_match] >= [1, 12, 0]
except (IndexError, TypeError):
    # IndexError: no "X.Y.Z" prefix was found; TypeError: the version is not a
    # string. In either case disable the MPS code paths instead of guessing.
    # (A bare `except:` here would also swallow KeyboardInterrupt/SystemExit.)
    IS_HIGH_VERSION = False
def gpu_is_available():
    """Tell whether a GPU backend usable by this module is present.

    Returns:
        bool: True when Apple MPS is available (checked only when
        ``IS_HIGH_VERSION`` is set), otherwise the result of requiring both
        CUDA and cuDNN to be available.
    """
    if not IS_HIGH_VERSION or not torch.backends.mps.is_available():
        # No usable MPS backend: fall back to the CUDA + cuDNN check.
        return torch.cuda.is_available() and torch.backends.cudnn.is_available()
    return True
def get_device(gpu_id=None):
    """Return the preferred torch device, in the order MPS -> CUDA -> CPU.

    Args:
        gpu_id (int | None): Index appended as ``":<gpu_id>"`` to the CUDA
            device string. It is honoured only when it is an ``int``. Any other
            value, including the default ``None``, selects plain ``"cuda"``.
            It is not used for the MPS or CPU devices.

    Returns:
        torch.device: ``mps`` when ``IS_HIGH_VERSION`` is set and MPS is
        available; otherwise ``cuda[:gpu_id]`` when both CUDA and cuDNN are
        available; otherwise ``cpu``.
    """
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return torch.device("mps")
    cuda_ready = torch.cuda.is_available() and torch.backends.cudnn.is_available()
    if not cuda_ready:
        return torch.device("cpu")
    index_suffix = ":{}".format(gpu_id) if isinstance(gpu_id, int) else ""
    return torch.device("cuda" + index_suffix)
| # --------------------------- | |
| # Utilities | |
| # --------------------------- | |
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with one value.

    When CUDA is available, the CUDA generators are seeded as well: first the
    current device, then all devices.

    Args:
        seed (int): Seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def get_time_str():
    """Return the current local time as a ``YYYYMMDD_HHMMSS`` string."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
def mkdir_and_rename(path):
    """Create directory ``path``, archiving whatever already exists there.

    If ``path`` exists, it is first renamed to
    ``path + '_archived_' + get_time_str()`` and a notice is printed. A new
    directory is then created at ``path``.

    Args:
        path (str): Directory to (re)create.
    """
    if not osp.exists(path):
        os.makedirs(path, exist_ok=True)
        return
    archived = path + '_archived_' + get_time_str()
    print(f'Path already exists. Renamed to {archived}', flush=True)
    os.rename(path, archived)
    os.makedirs(path, exist_ok=True)
def make_exp_dirs(opt):
    """Create the directories listed in ``opt['path']`` for an experiment.

    The root directory is ``opt['path']['experiments_root']`` when
    ``opt['is_train']`` is truthy, otherwise ``opt['path']['results_root']``.
    It is created with ``mkdir_and_rename``. Every other ``opt['path']`` entry
    is created with ``os.makedirs(..., exist_ok=True)``, except keys that
    contain ``strict_load``, ``pretrain_network`` or ``resume``. Those keys
    presumably hold flags or file paths rather than directories; verify
    against the option files.

    Args:
        opt (dict): Options with at least ``'is_train'`` and ``'path'``.
            ``opt['path']`` itself is not modified, because a shallow copy is
            popped from.
    """
    remaining = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(remaining.pop(root_key))
    skip_markers = ('strict_load', 'pretrain_network', 'resume')
    for key, path in remaining.items():
        if any(marker in key for marker in skip_markers):
            continue
        os.makedirs(path, exist_ok=True)
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory and lazily yield the paths of the files it contains.

    Entries whose name starts with ``'.'`` are skipped. This applies to hidden
    files and to hidden directories, which are not descended into.

    Args:
        dir_path (str): Directory to scan.
        suffix (str | tuple[str] | None): Only yield paths ending with this
            suffix, or with any suffix in the tuple. ``None`` yields every file.
        recursive (bool): Also descend into non-hidden sub-directories.
        full_path (bool): Yield ``entry.path``, which is rooted at
            ``dir_path``, instead of the path relative to ``dir_path``.

    Returns:
        Generator[str]: The matching file paths, in ``os.scandir`` order.

    Raises:
        TypeError: If ``suffix`` is not ``None``, a ``str`` or a ``tuple``.
            The check runs when ``scandir`` is called, before any iteration.
    """
    # Validate eagerly. The scan itself is lazy, so a bad suffix would
    # otherwise only surface (from str.endswith) on the first file reached.
    if suffix is not None and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scan(path):
        # Use the context manager so the directory handle is released even if
        # the consumer stops iterating before exhaustion.
        with os.scandir(path) as entries:
            for entry in entries:
                if entry.name.startswith('.'):
                    continue
                if entry.is_file():
                    file_path = entry.path if full_path else osp.relpath(entry.path, root)
                    if suffix is None or file_path.endswith(suffix):
                        yield file_path
                elif recursive and entry.is_dir():
                    yield from _scan(entry.path)

    return _scan(dir_path)
def check_resume(opt, resume_iter):
    """Point each network's pretrain path at its saved checkpoint when resuming.

    Nothing changes unless ``opt['path']['resume_state']`` is truthy. In that
    case, for every top-level ``opt`` key starting with ``network_``, the path
    ``opt['path']['pretrain_<key>']`` is set to
    ``osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')``.
    Here ``basename`` is the key with ``network_`` removed, and an info message
    is logged. Networks whose basename appears in
    ``opt['path']['ignore_resume_networks']`` are left untouched. If any
    ``pretrain_<key>`` entry was already truthy, a warning is logged first.

    Args:
        opt (dict): Options dict; ``opt['path']`` is updated in place.
        resume_iter: Iteration value embedded in the checkpoint filename.
    """
    logger = get_root_logger()
    path_opt = opt['path']
    if not path_opt['resume_state']:
        return
    networks = [key for key in opt if key.startswith('network_')]
    if any(path_opt.get(f'pretrain_{net}') for net in networks):
        logger.warning('pretrain_network path will be ignored during resuming.')
    for network in networks:
        basename = network.replace('network_', '')
        ignored = path_opt.get('ignore_resume_networks')
        if ignored is not None and basename in ignored:
            continue
        path_opt[f'pretrain_{network}'] = osp.join(
            path_opt['models'], f'net_{basename}_{resume_iter}.pth'
        )
        logger.info(f"Set pretrain for {network}")
def sizeof_fmt(size, suffix='B'):
    """Format a size as a human-readable string using binary (1024) prefixes.

    Args:
        size (int | float): Value to format. Negative values are scaled by
            their magnitude and keep their sign.
        suffix (str): Unit appended after the prefix. Default: ``'B'``.

    Returns:
        str: For example ``'512.0 B'``, ``'1.5 KB'`` or ``'1.0 EB'``. Values
        of 1024**8 and above are reported with the ``Y`` prefix.
    """
    # Include 'E' and 'Z': without them, 1024**6 .. 1024**8 was mislabelled 'Y'.
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        # Compare the magnitude so negative sizes are scaled as well, instead
        # of always being reported in base units.
        if abs(size) < 1024:
            return f"{size:3.1f} {unit}{suffix}"
        size /= 1024
    return f"{size:3.1f} Y{suffix}"