import argparse
import collections
import json
import os
import re
from contextlib import contextmanager, nullcontext
from time import time
from timeit import default_timer

import torch

from src.text_utils.logging import get_logger

logger = get_logger(__name__)
@contextmanager
def elapsed_timer():
    """Yield a zero-argument callable that returns the elapsed time in
    seconds: a running value while the block executes, and the final,
    frozen duration once the block has exited."""
    start = default_timer()
    elapser = lambda: default_timer() - start
    # Yield a wrapper rather than `elapser` itself so that rebinding
    # `elapser` below freezes the value reported after the block exits.
    yield lambda: elapser()
    end = default_timer()
    elapser = lambda: end - start
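# Illustrative usage of elapsed_timer (kept as a comment so nothing runs at
# import time; `run_step` is a hypothetical workload):
#
#     with elapsed_timer() as elapsed:
#         run_step()
#         print(f'so far: {elapsed():.2f}s')
#     print(f'total: {elapsed():.2f}s')  # frozen after the block exits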
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
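# Illustrative usage of AverageMeter (comment only; `loss` and `batch_size`
# are assumed to come from a surrounding training loop):
#
#     losses = AverageMeter('loss', ':.4f')
#     losses.update(loss.item(), n=batch_size)
#     print(losses)  # e.g. "loss 0.5321 (0.6110)"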
def save_args_to_json(args, output_json_path):
    """Dump an argparse.Namespace to JSON, skipping values that are not
    JSON-serializable. Each value is stored as its own JSON-encoded string,
    which is what load_args_from_json expects when reading it back."""
    serializable_args = {}
    for k, v in vars(args).items():
        try:
            serializable_args[k] = json.dumps(v)
        except (TypeError, ValueError):
            # Skip values json cannot encode (e.g. custom objects).
            continue
    with open(output_json_path, 'w') as arg_json:
        json.dump(serializable_args, arg_json)
def load_args_from_json(output_json_path):
    """Inverse of save_args_to_json: rebuild an argparse.Namespace from the
    JSON file written above. If a directory is given, look for
    'train_args.json' inside it."""
    if os.path.isdir(output_json_path):
        output_json_path = os.path.join(output_json_path, 'train_args.json')
    with open(output_json_path, 'r') as arg_json:
        kwargs = json.load(arg_json)
    _kwargs = {}
    for k, v in kwargs.items():
        # Values were stored as JSON-encoded strings; decode the common
        # literals explicitly and fall back to eval for the rest.
        if v == 'null':
            v = None
        elif v == 'true' or v == 'false':
            v = (v == 'true')
        else:
            try:
                v = eval(v)
            except (ValueError, SyntaxError, NameError):
                pass
        _kwargs[k] = v
    args = argparse.Namespace(**_kwargs)
    return args
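# Illustrative round trip (comment only; the path is hypothetical):
#
#     save_args_to_json(args, '/tmp/train_args.json')
#     restored = load_args_from_json('/tmp/train_args.json')
#     # `restored` holds the JSON-serializable subset of the original args.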
def tensor_norm(input, input_mask=None):
    """Mean L2 norm of the rows of a 2-D tensor `input` (n_vectors, dim).
    If `input_mask` (a 1-D 0/1 mask over the rows) is given, masked rows are
    zeroed out and excluded from the mean."""
    if input_mask is not None:
        _norm = torch.linalg.norm(input * input_mask.unsqueeze(-1), dim=1)
        _norm = torch.masked_select(_norm, input_mask.bool().reshape(-1))
    else:
        _norm = torch.linalg.norm(input, dim=1, ord=2)
    return _norm.mean()
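# Illustrative call (comment only; the shapes are assumptions read off the
# code above):
#
#     emb = torch.randn(8, 768)           # 8 vectors of dimension 768
#     mask = torch.tensor([1, 1, 1, 0, 1, 0, 1, 1])
#     mean_norm = tensor_norm(emb, mask)  # mean L2 norm of the 6 kept rows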
class print_time():
    """Context manager that logs (via print_master) when a task starts and
    how long it took once the block finishes."""

    def __init__(self, task):
        self.task = task

    def __enter__(self):
        print_master(self.task)
        self.t = time()

    def __exit__(self, type, value, traceback):
        print_master(f'{self.task} took {time()-self.t:.02f}s')
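# Illustrative usage (comment only; `build_dataset` is a hypothetical call):
#
#     with print_time('loading dataset'):
#         dataset = build_dataset()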
def print_rank(message):
    """Log the message, prefixed with the caller's rank if torch.distributed
    is initialized."""
    if torch.distributed.is_initialized():
        logger.info(f'rank{torch.distributed.get_rank()}: ' + message)
    else:
        logger.info(message)
def print_master(message):
    """Log the message only on rank 0 when torch.distributed is initialized;
    otherwise log it unconditionally."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            logger.info(message)
    else:
        logger.info(message)
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
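# Illustrative argparse wiring (comment only; the flag name is an example):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--use_amp', type=str2bool, default=False)
#     args = parser.parse_args(['--use_amp', 'yes'])  # args.use_amp is True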
def calc_gradient_norm(model, return_param_norm=False, return_details=True, is_deepspeed=False):
    """Return a dict mapping parameter groups to their (gradient) norms.

    return_param_norm: if True, report the norm of the parameters themselves
        instead of their gradients.
    is_deepspeed: currently unused; DeepSpeed handles parameters differently,
        so per-group norms may not be meaningful there.
    """
    total_norm = 0.0
    n_parameter = 0  # number of gradient elements seen (informational only)
    group_norm = collections.defaultdict(float)
    group_norm['total'] = 0.0
    for n, p in model.named_parameters():
        # No-op placeholder context; a parameter-gathering context (e.g. for
        # sharded setups) could be swapped in here.
        with nullcontext():
            if p.requires_grad and p.grad is not None:
                if return_param_norm:
                    param_norm = p.detach().data.norm(p=2).item()
                else:
                    param_norm = p.grad.detach().data.norm(p=2).item()
                total_norm += param_norm ** 2
                n_parameter += torch.numel(p.grad)
                # Fixed label used as a prefix for the per-group keys.
                module_name = 'q_encoder'
                if return_details:
                    if 'embed' in n:
                        part_name = 'embeddings'
                        group_norm[f'{module_name}-{part_name}'] += param_norm
                    elif 'addon_layer' in n:
                        part_name = 'addon_layer'
                        group_norm[f'{module_name}-{part_name}'] += param_norm
                    elif 'layer' in n:
                        part_name = re.search(r'layers\.\d+|layer\.\d+', n)
                        if part_name:
                            part_name = part_name.group(0)
                        else:
                            part_name = 'unknown_group'
                        group_norm[f'{module_name}-{part_name}'] += param_norm
                    if "model" in n:
                        # Strip everything up to and including "model." and
                        # drop common wrapper/suffix fragments from the key.
                        part_name = n[n.rfind("model") + 6:]
                        part_name = (part_name.replace('module.', '')
                                     .replace('.dense', '')
                                     .replace('.weight', '')
                                     .replace('.bias', '')
                                     .replace('.pytorch', '')
                                     .replace('.default', ''))
                        group_norm[f'{part_name}'] += param_norm
    group_norm['total'] = total_norm ** 0.5
    return group_norm
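# Illustrative use after backward() (comment only; `model` and `loss` come
# from a hypothetical training loop):
#
#     loss.backward()
#     norms = calc_gradient_norm(model)                 # per-group grad norms
#     print_master(f"grad norm: {norms['total']:.4f}")
#     weights = calc_gradient_norm(model, return_param_norm=True)  # weight norms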
def get_gradient_norm(model):
    """Global L2 norm over all parameter gradients in the model."""
    total_norm = 0.0
    for p in model.parameters():
        param_norm = p.grad.data.norm(2).item() if p.grad is not None else 0.0
        total_norm += param_norm ** 2
    total_norm = total_norm ** (1. / 2)
    return total_norm
def count_parameters(model):
    """Print the total number of parameters and how many require gradients."""
    total_num = sum(p.numel() for p in model.parameters())
    grad_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'#Total parameters: {total_num}')
    print(f'#Parameters requiring gradient: {grad_num}')