| | |
| | import datetime as dt |
| | import fnmatch |
| | import glob |
| | import importlib |
| | import os |
| | import random |
| | import re |
| | import shutil |
| | import socket |
| | import subprocess |
| | import sys |
| | import time |
| | from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.distributed as dist |
| | from transformers import HfArgumentParser, enable_full_determinism, set_seed |
| | from transformers.utils import strtobool |
| |
|
| | from .env import is_dist, is_dist_ta |
| | from .logger import get_logger |
| | from .np_utils import stat_array |
| |
|
# Module-level logger shared by all helpers in this file.
logger = get_logger()
| |
|
| |
|
def check_json_format(obj: Any, token_safe: bool = True) -> Any:
    """Recursively convert ``obj`` into a JSON-friendly structure.

    Scalars pass through unchanged, containers are walked recursively and any
    other object falls back to ``repr``. When ``token_safe`` is True, string
    values stored under a key/attribute whose name contains ``'_token'`` are
    masked with ``None`` so secrets do not leak into logs or saved configs.

    Args:
        obj: Arbitrary object to convert.
        token_safe: Mask token-like string values when True.

    Returns:
        A structure composed of JSON-compatible scalars, lists, dicts and
        ``repr`` strings.
    """
    if obj is None or isinstance(obj, (int, float, str, complex)):
        return obj
    if isinstance(obj, bytes):
        # Raw bytes are not JSON-serializable; replace with a placeholder.
        return '<<<bytes>>>'
    if isinstance(obj, (torch.dtype, torch.device)):
        # e.g. torch.float32 -> 'float32'; torch.device('cuda:0') -> 'cuda:0'.
        obj = str(obj)
        return obj[len('torch.'):] if obj.startswith('torch.') else obj

    if isinstance(obj, Sequence):  # str/bytes were already handled above
        res = [check_json_format(x, token_safe) for x in obj]
    elif isinstance(obj, Mapping):
        res = {}
        for k, v in obj.items():
            if token_safe and isinstance(k, str) and '_token' in k and isinstance(v, str):
                res[k] = None  # mask secret value
            else:
                res[k] = check_json_format(v, token_safe)
    else:
        obj_dict = getattr(obj, '__dict__', None)
        if token_safe and obj_dict is not None:
            # Temporarily blank out token-like attributes so they do not show
            # up in repr(), then restore them afterwards (even if repr raises).
            unsafe_items = {k: v for k, v in obj_dict.items() if '_token' in k}
            for k in unsafe_items:
                setattr(obj, k, None)
            try:
                res = repr(obj)
            finally:
                for k, v in unsafe_items.items():
                    setattr(obj, k, v)
        else:
            # Fix: objects without __dict__ (e.g. set, __slots__ classes)
            # previously raised AttributeError when token_safe was True.
            res = repr(obj)
    return res
| |
|
| |
|
| | def _get_version(work_dir: str) -> int: |
| | if os.path.isdir(work_dir): |
| | fnames = os.listdir(work_dir) |
| | else: |
| | fnames = [] |
| | v_list = [-1] |
| | for fname in fnames: |
| | m = re.match(r'v(\d+)', fname) |
| | if m is None: |
| | continue |
| | v = m.group(1) |
| | v_list.append(int(v)) |
| | return max(v_list) + 1 |
| |
|
| |
|
def format_time(seconds):
    """Render a duration given in seconds as a compact human-readable string.

    Leading zero-valued units are omitted, e.g. 3661 -> '1h 1m 1s', 5 -> '5s'.
    """
    days = int(seconds // (24 * 3600))
    hours = int((seconds % (24 * 3600)) // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)

    if days > 0:
        return f'{days}d {hours}h {minutes}m {secs}s'
    if hours > 0:
        return f'{hours}h {minutes}m {secs}s'
    if minutes > 0:
        return f'{minutes}m {secs}s'
    return f'{secs}s'
| |
|
| |
|
def deep_getattr(obj, attr: str, default=None):
    """Resolve a dotted ``attr`` path on ``obj``.

    Each path component is looked up with ``dict.get`` for dicts and
    ``getattr`` otherwise; traversal stops early (returning None) as soon as
    an intermediate value is None.
    """
    for part in attr.split('.'):
        if obj is None:
            break
        if isinstance(obj, dict):
            obj = obj.get(part, default)
        else:
            obj = getattr(obj, part, default)
    return obj
| |
|
| |
|
def seed_everything(seed: Optional[int] = None, full_determinism: bool = False, *, verbose: bool = True) -> int:
    """Seed all random number generators and return the seed that was used.

    When ``seed`` is None a random int32 seed is drawn. ``full_determinism``
    additionally enables deterministic framework behavior via
    ``enable_full_determinism``; otherwise only ``set_seed`` is applied.
    """
    if seed is None:
        seed = random.randint(0, np.iinfo(np.int32).max)

    seed_fn = enable_full_determinism if full_determinism else set_seed
    seed_fn(seed)
    if verbose:
        logger.info(f'Global seed set to {seed}')
    return seed
| |
|
| |
|
def add_version_to_work_dir(work_dir: str) -> str:
    """Append a fresh ``v<N>-<timestamp>`` sub-folder name to ``work_dir``.

    Under distributed training the name chosen by rank 0 is broadcast so all
    ranks agree on the same output path.
    """
    timestamp = dt.datetime.now().strftime('%Y%m%d-%H%M%S')
    sub_folder = f'v{_get_version(work_dir)}-{timestamp}'
    if (dist.is_initialized() and is_dist()) or is_dist_ta():
        buf = [sub_folder]
        dist.broadcast_object_list(buf)
        sub_folder = buf[0]

    return os.path.join(work_dir, sub_folder)
| |
|
| |
|
# Generic type variable used by the typed helpers below (parse_args, test_time, ...).
_T = TypeVar('_T')
| |
|
| |
|
def parse_args(class_type: Type[_T], argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]:
    """Parse ``argv`` (default: ``sys.argv[1:]``) into a ``class_type`` dataclass.

    A first argument ending in '.json' is treated as a JSON config file and
    every argument after it is returned unparsed; otherwise normal CLI parsing
    is used. Returns ``(args, remaining_args)``.
    """
    parser = HfArgumentParser([class_type])
    if argv is None:
        argv = sys.argv[1:]
    if argv and argv[0].endswith('.json'):
        json_path = os.path.abspath(os.path.expanduser(argv[0]))
        args, = parser.parse_json_file(json_path)
        return args, argv[1:]
    args, remaining_args = parser.parse_args_into_dataclasses(argv, return_remaining_strings=True)
    return args, remaining_args
| |
|
| |
|
def lower_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int:
    """Binary search for the smallest ``i`` in ``[lo, hi]`` with ``cond(i)`` true.

    Assumes ``cond`` is monotonic (False...False True...True) over the range;
    returns ``hi`` when ``cond`` is never true.
    """
    left, right = lo, hi
    while left < right:
        mid = (left + right) // 2
        if cond(mid):
            right = mid
        else:
            left = mid + 1
    return left
| |
|
| |
|
def upper_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int:
    """Binary search for the largest ``i`` in ``[lo, hi]`` with ``cond(i)`` true.

    Assumes ``cond`` is monotonic (True...True False...False) over the range;
    returns ``lo`` when ``cond`` is never true.
    """
    left, right = lo, hi
    while left < right:
        mid = (left + right + 1) // 2
        if cond(mid):
            left = mid
        else:
            right = mid - 1
    return left
| |
|
| |
|
def test_time(func: Callable[[], _T],
              number: int = 1,
              warmup: int = 0,
              timer: Optional[Callable[[], float]] = None) -> _T:
    """Time ``number`` calls of ``func`` (after ``warmup`` untimed calls).

    Logs summary statistics of the per-call durations and returns the result
    of the last call. ``timer`` defaults to ``time.perf_counter``.
    """
    clock = time.perf_counter if timer is None else timer

    res = None
    for _ in range(warmup):
        res = func()

    durations = []
    for _ in range(number):
        start = clock()
        res = func()
        durations.append(clock() - start)

    _, stat_str = stat_array(np.array(durations))
    logger.info(f'time[number={number}]: {stat_str}')
    return res
| |
|
| |
|
def read_multi_line(addi_prompt: str = '') -> str:
    """Read lines from stdin until a line ends with ``#``; return the text.

    The prompt ``'<<<{addi_prompt} '`` is shown only for the first line, and
    the terminating ``#`` (with its newline) is stripped from the result.
    """
    lines = []
    prompt = f'<<<{addi_prompt} '
    while True:
        line = input(prompt) + '\n'
        prompt = ''
        if line.endswith('#\n'):
            lines.append(line[:-2])
            break
        lines.append(line)
    return ''.join(lines)
| |
|
| |
|
def subprocess_run(command: List[str], env: Optional[Dict[str, str]] = None, stdout=None, stderr=None):
    """Run ``command`` and raise ``CalledProcessError`` on a non-zero exit.

    Thin wrapper over ``subprocess.run`` that enforces the return-code check;
    returns the completed-process object.
    """
    completed = subprocess.run(command, env=env, stdout=stdout, stderr=stderr)
    completed.check_returncode()
    return completed
| |
|
| |
|
def get_env_args(args_name: str, type_func: Callable[[str], _T], default_value: Optional[_T]) -> Optional[_T]:
    """Read hyperparameter ``args_name`` from env var ``ARGS_NAME.upper()``.

    Falls back to ``default_value`` when the variable is unset; otherwise the
    raw string is converted with ``type_func`` (via ``strtobool`` first when
    ``type_func is bool``). Logs (once) which source was used.
    """
    args_name_upper = args_name.upper()
    raw = os.getenv(args_name_upper)
    if raw is None:
        value = default_value
        log_info = (f'Setting {args_name}: {default_value}. '
                    f'You can adjust this hyperparameter through the environment variable: `{args_name_upper}`.')
    else:
        value = type_func(strtobool(raw)) if type_func is bool else type_func(raw)
        log_info = f'Using environment variable `{args_name_upper}`, Setting {args_name}: {value}.'
    logger.info_once(log_info)
    return value
| |
|
| |
|
def find_free_port(start_port: Optional[int] = None, retry: int = 100) -> int:
    """Find a free TCP port and return its number.

    Tries ``start_port, start_port + 1, ...`` for up to ``retry`` candidates.
    With ``start_port=None`` it binds port 0 so the OS picks an ephemeral
    free port.

    Raises:
        OSError: if none of the candidate ports can be bound. (Previously the
            last candidate was silently returned even though binding it had
            failed, and ``retry <= 0`` raised NameError on an unbound local.)
    """
    if start_port is None:
        start_port = 0
    for port in range(start_port, start_port + retry):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind(('', port))
                # getsockname resolves the real port when binding port 0.
                return sock.getsockname()[1]
            except OSError:
                continue
    raise OSError(f'No free port found in range [{start_port}, {start_port + retry}).')
| |
|
| |
|
def copy_files_by_pattern(source_dir, dest_dir, patterns):
    """Copy files matching glob-style ``patterns`` from ``source_dir`` to ``dest_dir``.

    ``patterns`` may be a single pattern string or a list of them. A pattern
    containing a path separator (e.g. ``'sub*/ *.txt'``) matches files in
    sub-directories and preserves the relative directory layout under
    ``dest_dir``; a plain pattern (e.g. ``'*.txt'``) matches only files
    directly inside ``source_dir``. Existing destination files are never
    overwritten.
    """
    if not os.path.exists(dest_dir):
        os.makedirs(dest_dir)

    if isinstance(patterns, str):
        patterns = [patterns]

    for pattern in patterns:
        pattern_parts = pattern.split(os.path.sep)
        if len(pattern_parts) > 1:
            # Multi-part pattern: last part matches the file name, the rest
            # matches the directory path relative to source_dir.
            subdir_pattern = os.path.sep.join(pattern_parts[:-1])
            file_pattern = pattern_parts[-1]

            for root, dirs, files in os.walk(source_dir):
                rel_path = os.path.relpath(root, source_dir)
                # NOTE(review): files directly under source_dir (rel_path == '.')
                # are always skipped for multi-part patterns; the second clause's
                # `rel_path != '.'` test is redundant given the `or` short-circuit.
                if rel_path == '.' or (rel_path != '.' and not fnmatch.fnmatch(rel_path, subdir_pattern)):
                    continue

                for file in files:
                    if fnmatch.fnmatch(file, file_pattern):
                        file_path = os.path.join(root, file)
                        # Mirror the matched sub-directory under dest_dir.
                        target_dir = os.path.join(dest_dir, rel_path)
                        if not os.path.exists(target_dir):
                            os.makedirs(target_dir)
                        dest_file = os.path.join(target_dir, file)

                        # copy2 preserves metadata; skip if already present.
                        if not os.path.exists(dest_file):
                            shutil.copy2(file_path, dest_file)
        else:
            # Flat pattern: glob directly inside source_dir (no recursion).
            search_path = os.path.join(source_dir, pattern)
            matched_files = glob.glob(search_path)

            for file_path in matched_files:
                if os.path.isfile(file_path):
                    file_name = os.path.basename(file_path)
                    destination = os.path.join(dest_dir, file_name)
                    if not os.path.exists(destination):
                        shutil.copy2(file_path, destination)
| |
|
| |
|
def split_list(ori_list, num_shards):
    """Split ``ori_list`` into ``num_shards`` nearly-equal contiguous shards.

    Boundaries are computed with ``np.linspace`` so shard sizes differ by at
    most one element; empty shards are possible when the list is short.
    """
    boundaries = np.linspace(0, len(ori_list), num_shards + 1)
    return [
        ori_list[int(boundaries[i]):int(boundaries[i + 1])]
        for i in range(num_shards)
    ]
| |
|
| |
|
def patch_getattr(obj_cls, item_name: str) -> None:
    """Patch ``obj_cls`` so failed attribute lookups are forwarded to the
    object stored in its ``item_name`` attribute (simple delegation).

    Idempotent: a class is patched at most once, guarded by the ``_patch``
    marker attribute.
    """
    if hasattr(obj_cls, '_patch'):
        return

    def __new_getattr__(self, key: str):
        try:
            # Fix: must use obj_cls, not self.__class__. With the old
            # `super(self.__class__, self)`, a lookup on a *subclass* instance
            # resolved back to this very method and recursed infinitely.
            return super(obj_cls, self).__getattr__(key)
        except AttributeError:
            if item_name in dir(self):
                item = getattr(self, item_name)
                return getattr(item, key)
            raise

    obj_cls.__getattr__ = __new_getattr__
    obj_cls._patch = True
| |
|
| |
|
def import_external_file(file_path: str):
    """Import a ``.py`` file by path and return the module object.

    The file's directory is prepended to ``sys.path`` (and left there) so the
    imported module can resolve sibling imports; the module name is the file
    name up to the first dot.
    """
    expanded = os.path.abspath(os.path.expanduser(file_path))
    py_dir, py_file = os.path.split(expanded)
    assert os.path.isdir(py_dir), f'py_dir: {py_dir}'
    sys.path.insert(0, py_dir)
    module_name = py_file.split('.', 1)[0]
    return importlib.import_module(module_name)
| |
|