| | |
| | |
| | ''' |
| | @File : summary.py |
| | @Time : 2022/10/15 23:38:13 |
| | @Author : BQH |
| | @Version : 1.0 |
| | @Contact : raogx.vip@hotmail.com |
| | @License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA |
| | @Desc : 运行时日志文件 |
| | ''' |
| |
|
| | |
| |
|
| | import os |
| | import sys |
| | import torch |
| | import logging |
| | from datetime import datetime |
| |
|
| | |
| |
|
try:
    from tensorboardX import SummaryWriter
except ImportError:
    class SummaryWriter:
        """Minimal stand-in used when tensorboardX cannot be imported.

        Scalars are accumulated in memory and dumped with ``torch.save`` to
        ``<log_dir>/log_<timestamp>.pickle`` when :meth:`close` is called.
        """

        def __init__(self, log_dir=None, comment='', **kwargs):
            """Prepare the output directory and the in-memory log store.

            Args:
                log_dir: Directory that receives the pickle written by
                    ``close``; defaults to ``'./logs'`` when ``None``.
                comment: Free-form text stored under the ``'comment'`` key.
                **kwargs: Accepted and ignored, so calls written for the real
                    tensorboardX writer keep working.
            """
            print('\nunable to import tensorboardX, log will be recorded by pytorch!\n')
            self.log_dir = log_dir if log_dir is not None else './logs'
            # Create the directory that close() actually writes into, not a
            # hard-coded './logs', so a custom log_dir works out of the box.
            os.makedirs(self.log_dir, exist_ok=True)
            self.logs = {'comment': comment}

        def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
            """Append ``(scalar_value, global_step, walltime)`` to ``tag``'s history.

            NOTE: the key ``'comment'`` already holds the comment string, so
            it cannot be used as a scalar tag.
            """
            self.logs.setdefault(tag, []).append((scalar_value, global_step, walltime))

        def close(self):
            """Write all recorded values to a timestamped pickle in ``log_dir``."""
            # Spaces and colons are replaced so the timestamp is usable as a
            # file name.
            timestamp = str(datetime.now()).replace(' ', '_').replace(':', '_')
            torch.save(self.logs, os.path.join(self.log_dir, 'log_%s.pickle' % timestamp))
| |
|
| |
|
class EmptySummaryWriter:
    """No-op writer exposing the same ``add_scalar``/``close`` methods.

    ``create_summary`` returns it when ``distributed_rank`` is greater than
    zero, so calling code can log unconditionally.
    """

    def __init__(self, **kwargs):
        """Accept and ignore any keyword arguments."""

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        """Discard the value; nothing is recorded."""
        return None

    def close(self):
        """Do nothing; there is no state to flush."""
        return None
| |
|
| |
|
def create_summary(distributed_rank=0, **kwargs):
    """Return a summary writer appropriate for this process.

    Args:
        distributed_rank: Rank of the calling process. Ranks above zero get
            an ``EmptySummaryWriter``; otherwise a ``SummaryWriter`` is built.
        **kwargs: Forwarded unchanged to the chosen writer's constructor.

    Returns:
        The newly constructed writer instance.
    """
    writer_cls = EmptySummaryWriter if distributed_rank > 0 else SummaryWriter
    return writer_cls(**kwargs)
| |
|
| |
|
def create_logger(distributed_rank=0, save_dir=None):
    """Configure and return the shared ``'logger'`` logger.

    Args:
        distributed_rank: Rank of the calling process. For ranks above zero
            the logger is returned without any handler being attached.
        save_dir: Optional directory for a ``log_<timestamp>.txt`` file. It
            is created if missing. When ``None`` only stdout is used.

    Returns:
        The ``logging.Logger`` named ``'logger'`` with its level set to DEBUG.
    """
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)

    if distributed_rank > 0:
        return logger

    # logging.getLogger returns the same object on every call, so handlers
    # from an earlier create_logger() call would otherwise pile up and every
    # message would be emitted multiple times.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    formatter = logging.Formatter("%(message)s [%(asctime)s]")

    console = logging.StreamHandler(stream=sys.stdout)
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)

    if save_dir is not None:
        # FileHandler opens the file immediately and fails if the directory
        # does not exist yet.
        os.makedirs(save_dir, exist_ok=True)
        filename = "log_%s.txt" % (datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))
        file_handler = logging.FileHandler(os.path.join(save_dir, filename))
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger
| |
|
| |
|
class Saver:
    """Write checkpoints with ``torch.save``, but only from rank 0."""

    def __init__(self, distributed_rank, save_dir):
        """Remember the rank and make sure the checkpoint directory exists.

        Args:
            distributed_rank: Rank of the calling process; only rank 0 writes.
            save_dir: Directory that receives the checkpoint files. It is
                created (on every rank) if it does not exist.
        """
        self.distributed_rank = distributed_rank
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

    def save(self, obj, save_name):
        """Serialize ``obj`` to ``<save_dir>/<save_name>.t7`` on rank 0.

        Args:
            obj: Object handed to ``torch.save``.
            save_name: File name without extension; ``.t7`` is appended.

        Returns:
            A status message naming the written file on rank 0, or an empty
            string on every other rank (nothing is written there).
        """
        if self.distributed_rank != 0:
            return ''
        # Build the path once so the message reports the file that was
        # actually written, including the '.t7' suffix.
        checkpoint_path = os.path.join(self.save_dir, save_name + '.t7')
        torch.save(obj, checkpoint_path)
        return 'checkpoint saved in %s !' % checkpoint_path
| |
|
| |
|
def create_saver(distributed_rank, save_dir):
    """Build a ``Saver`` for the given rank and checkpoint directory.

    Args:
        distributed_rank: Passed through as the saver's rank.
        save_dir: Passed through as the saver's output directory.

    Returns:
        The newly constructed ``Saver``.
    """
    saver = Saver(distributed_rank, save_dir)
    return saver
| |
|
| |
|
class DisablePrint:
    """Context manager that silences ``print`` on non-zero local ranks.

    For ``local_rank == 0`` it does nothing. For any other rank,
    ``sys.stdout`` is pointed at ``os.devnull`` for the duration of the
    ``with`` block and restored afterwards.
    """

    def __init__(self, local_rank=0):
        """Store the rank; no redirection happens until ``__enter__``."""
        self.local_rank = local_rank
        self._original_stdout = None
        self._devnull = None

    def __enter__(self):
        """Redirect stdout to the null device when the rank is not 0."""
        if self.local_rank != 0:
            self._original_stdout = sys.stdout
            self._devnull = open(os.devnull, 'w')
            sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the saved stdout and release the null-device handle.

        Exceptions raised inside the block are not suppressed.
        """
        if self._devnull is not None:
            sys.stdout = self._original_stdout
            # Close the handle opened in __enter__, not whatever sys.stdout
            # is now: code inside the block may have swapped it, and closing
            # that stream would break its owner.
            self._devnull.close()
            self._devnull = None
            self._original_stdout = None