import argparse
import collections
import json
import os
import re
from contextlib import contextmanager, nullcontext
from time import time
from timeit import default_timer

# import deepspeed
import torch

from src.text_utils.logging import get_logger

logger = get_logger(__name__)
########################################################################################################
## text_utils


@contextmanager
def elapsed_timer():
    """Yield a callable that returns seconds elapsed since entering the block.

    While the block runs, the callable gives a live reading; once the block
    exits, `elapser` is rebound so the callable is frozen at the total time.
    """
    start = default_timer()
    elapser = lambda: default_timer() - start
    # Yield an indirection rather than `elapser` itself so that rebinding
    # `elapser` below takes effect for callers holding the yielded lambda.
    yield lambda: elapser()
    end = default_timer()
    elapser = lambda: end - start
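
# Usage sketch for elapsed_timer (illustrative; `do_work` is a placeholder):
#
#     with elapsed_timer() as elapsed:
#         do_work()
#         print(f'so far: {elapsed():.2f}s')   # live reading inside the block
#     print(f'total: {elapsed():.2f}s')        # frozen once the block exits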

class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
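
# Usage sketch for AverageMeter in a training loop (illustrative; `train_loader`
# and `compute_loss` are placeholders):
#
#     losses = AverageMeter('loss', fmt=':.4f')
#     for batch in train_loader:
#         loss = compute_loss(batch)
#         losses.update(loss.item(), n=len(batch))
#     print(losses)   # e.g. "loss 0.4213 (0.5127)"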


def save_args_to_json(args, output_json_path):
    """Serialize argparse args to JSON, skipping values that cannot be encoded.

    Each value is stored as its own JSON string so that load_args_from_json
    can recover the original Python type.
    """
    serializable_args = {}
    for k, v in vars(args).items():
        try:
            serializable_args[k] = json.dumps(v)
        except (TypeError, ValueError):
            # Skip values json cannot serialize (e.g. objects, devices).
            continue
    with open(output_json_path, 'w') as arg_json:
        json.dump(serializable_args, arg_json)


def load_args_from_json(output_json_path):
    """Inverse of save_args_to_json: rebuild an argparse.Namespace from JSON."""
    if os.path.isdir(output_json_path):
        output_json_path = os.path.join(output_json_path, 'train_args.json')
    with open(output_json_path, 'r') as arg_json:
        kwargs = json.load(arg_json)
    _kwargs = {}
    for k, v in kwargs.items():
        try:
            # Values were written with json.dumps, so json.loads is the exact
            # inverse; it also decodes null/true/false natively, which the old
            # eval-based round trip handled case by case.
            _kwargs[k] = json.loads(v)
        except (TypeError, ValueError):
            _kwargs[k] = v
    return argparse.Namespace(**_kwargs)
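
# Round-trip sketch for save_args_to_json / load_args_from_json (illustrative):
#
#     args = argparse.Namespace(lr=1e-4, use_fp16=True, run_name='demo')
#     save_args_to_json(args, 'train_args.json')
#     restored = load_args_from_json('train_args.json')
#     assert restored.lr == args.lr and restored.use_fp16 is True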

def tensor_norm(input, input_mask=None):
    """Mean L2 norm over rows of a 2D tensor, optionally restricted by a mask.

    input: (num_rows, hidden); input_mask: (num_rows,) with 1 = keep, 0 = drop.
    """
    if input_mask is not None:
        # Zero out masked rows, take per-row norms, then keep only valid rows.
        _norm = torch.linalg.norm(input * input_mask.unsqueeze(-1), dim=1)
        _norm = torch.masked_select(_norm, input_mask.bool().reshape(-1))
    else:
        _norm = torch.linalg.norm(input, dim=1, ord=2)
    return _norm.mean()
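
# Shape sketch for tensor_norm (illustrative): rows are token embeddings and
# the mask marks which rows are real tokens rather than padding.
#
#     hidden = torch.randn(8, 768)
#     mask = torch.tensor([1, 1, 1, 0, 1, 0, 1, 1])
#     mean_norm = tensor_norm(hidden, mask)   # mean L2 norm of the 6 valid rows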


class print_time:
    """Context manager that logs the task name on entry and its duration on exit."""

    def __init__(self, task):
        self.task = task

    def __enter__(self):
        print_master(self.task)
        self.t = time()

    def __exit__(self, exc_type, exc_value, traceback):
        print_master(f'{self.task} took {time()-self.t:.02f}s')
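
# Usage sketch for print_time (illustrative; `load_corpus` is a placeholder):
#
#     with print_time('loading corpus'):
#         corpus = load_corpus()
#     # logs "loading corpus" on entry, "loading corpus took 12.34s" on exit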


def print_rank(message):
    """Log the message, prefixed with the rank if distributed is initialized."""
    if torch.distributed.is_initialized():
        logger.info(f'rank{torch.distributed.get_rank()}: ' + message)
    else:
        logger.info(message)


def print_master(message):
    """If distributed is initialized, log only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            logger.info(message)
    else:
        logger.info(message)


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
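
# Usage sketch for str2bool with argparse (illustrative): unlike type=bool,
# which treats any non-empty string (including "false") as True, this parses
# the flag's textual value.
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--use_fp16', type=str2bool, default=False)
#     args = parser.parse_args(['--use_fp16', 'true'])   # args.use_fp16 is True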


def calc_gradient_norm(model, return_param_norm=False, return_details=True, is_deepspeed=False):
    '''
    Compute the global L2 norm of gradients (or parameters) plus per-group stats.

    return_param_norm: if True, return the norm of parameters instead of grads.
    is_deepspeed: currently a no-op; DeepSpeed partitions parameters, so they
    would first need to be gathered (see the commented-out context below).
    '''
    total_norm = 0.0
    n_parameter = 0
    group_norm = collections.defaultdict(float)
    group_norm['total'] = 0.0
    for n, p in model.named_parameters():
        # with deepspeed.zero.GatheredParameters(p, modifier_rank=None) if is_deepspeed else nullcontext():
        with nullcontext():
            if p.requires_grad and p.grad is not None:
                if return_param_norm:
                    param_norm = p.detach().data.norm(p=2).item()
                else:
                    param_norm = p.grad.detach().data.norm(p=2).item()
                total_norm += param_norm ** 2
                n_parameter += torch.numel(p.grad)
                module_name = 'q_encoder'
                # The grouping below assumes BERT/Mistral-style parameter names.
                if return_details:
                    if 'embed' in n:
                        part_name = 'embeddings'
                    elif 'addon_layer' in n:
                        part_name = 'addon_layer'
                    elif 'layer' in n:
                        match = re.search(r'layers\.\d+|layer\.\d+', n)
                        part_name = match.group(0) if match else 'unknown_group'
                    else:
                        part_name = 'other'
                    group_norm[f'{module_name}-{part_name}'] += param_norm
                    # Per-parameter stats; this yields many keys for large models.
                    if 'model' in n:
                        part_name = n[n.rfind('model') + 6:]
                    part_name = (part_name.replace('module.', '').replace('.dense', '')
                                 .replace('.weight', '').replace('.bias', '')
                                 .replace('.pytorch', '').replace('.default', ''))
                    group_norm[part_name] += param_norm

    group_norm['total'] = total_norm ** 0.5
    return group_norm
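
# Usage sketch for calc_gradient_norm (illustrative): call between backward()
# and optimizer.step() to see where gradient mass concentrates.
#
#     loss.backward()
#     norms = calc_gradient_norm(model, return_details=True)
#     print_master(f"grad norm: {norms['total']:.4f}")
#     for group, norm in sorted(norms.items()):
#         print_master(f'  {group}: {norm:.4f}')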


def get_gradient_norm(model):
    """Global L2 norm of all gradients in the model."""
    total_norm = 0.0
    for p in model.parameters():
        param_norm = p.grad.data.norm(2).item() if p.grad is not None else 0.0
        total_norm += param_norm ** 2
    return total_norm ** 0.5


def count_parameters(model):
    """Print the total and trainable (requires_grad) parameter counts."""
    total_num = sum(p.numel() for p in model.parameters())
    grad_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'#Total parameters: {total_num}')
    print(f'#Parameters requiring gradient: {grad_num}')