CocoRoF commited on
Commit
2b3d88b
·
verified ·
1 Parent(s): c50f133

Upload checkpoint-1491 contents

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/zero_to_fp32-checkpoint.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
@dataclass
class zero_model_state:
    """Per-rank distillation of a DeepSpeed model-states checkpoint.

    Holds only what fp32 reconstruction needs: buffers, parameter shapes,
    shared-parameter aliases, the writing deepspeed version, and frozen-param
    bookkeeping.
    """
    # NOTE: the original annotations used `dict()` (a dict *instance*) as the
    # annotation. Dataclasses don't evaluate annotations for typing purposes,
    # so it happened to work, but the proper type objects belong here.
    buffers: dict                 # buffer name -> fp32 tensor
    param_shapes: dict            # per-group mapping of param name -> shape
    shared_params: list           # [alias_name, source_name] pairs
    ds_version: int               # deepspeed version that wrote the checkpoint (may be None)
    frozen_param_shapes: dict     # frozen param name -> shape, or None
    frozen_param_fragments: dict  # frozen param name -> this rank's fragment, or None
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
def atoi(text):
    """Convert *text* to an int when it is purely digits; otherwise return it unchanged."""
    if text.isdigit():
        return int(text)
    return text
57
+
58
+
59
def natural_keys(text):
    """Sort key giving human ("natural") ordering, so e.g. 'rank_10' > 'rank_2'.

    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    """
    # split into alternating non-digit / digit chunks and convert the digit ones
    chunks = re.split(r'(\d+)', text)
    return [int(chunk) if chunk.isdigit() else chunk for chunk in chunks]
66
+
67
+
68
def get_model_state_file(checkpoint_dir, zero_stage):
    """Return the path of rank 0's model-states file inside *checkpoint_dir*.

    Args:
        checkpoint_dir: directory holding the per-rank checkpoint files
        zero_stage: ZeRO stage the checkpoint was written with (1, 2 or 3)

    Raises:
        FileNotFoundError: if the directory or the expected file doesn't exist
        ValueError: if *zero_stage* is not a recognized stage
    """
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

    # there should be only one file
    if zero_stage <= 2:
        file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
    elif zero_stage == 3:
        file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
    else:
        # previously `file` was left unbound here, producing an obscure NameError below
        raise ValueError(f"unknown zero stage {zero_stage}")

    if not os.path.exists(file):
        raise FileNotFoundError(f"can't find model states file at '{file}'")

    return file
82
+
83
+
84
def get_checkpoint_files(checkpoint_dir, glob_pattern):
    """Return all files matching *glob_pattern* under *checkpoint_dir* in natural sort order.

    Raises FileNotFoundError when nothing matches.
    """
    # XXX: need to test that this simple glob rule works for multi-node setup too
    def _natural_key(path):
        # natural ("human") ordering so that rank 10 sorts after rank 2
        return [int(piece) if piece.isdigit() else piece for piece in re.split(r'(\d+)', path)]

    ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=_natural_key)
    if not ckpt_files:
        raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
    return ckpt_files
92
+
93
+
94
def get_optim_files(checkpoint_dir):
    """Return the per-rank ``*_optim_states.pt`` files, naturally sorted."""
    return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
def get_model_state_files(checkpoint_dir):
    """Return the per-rank ``*_model_states.pt`` files, naturally sorted."""
    return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
def parse_model_states(files):
    """Load each rank's ``*_model_states.pt`` file and distill it into a ``zero_model_state``.

    Only buffers (restored to fp32 if saved in fp16), parameter shapes,
    shared-parameter aliases and frozen-parameter info are kept; the trainable
    weights themselves live in the optimizer files.

    Args:
        files: list of model-states file paths, one per rank

    Returns:
        list of ``zero_model_state``, in the same order as *files*

    Raises:
        ValueError: when a file lacks the buffer-names key and is therefore
            not a model-states checkpoint
    """
    zero_model_states = []
    for file in files:
        state_dict = torch.load(file, map_location=device, weights_only=False)

        if BUFFER_NAMES not in state_dict:
            raise ValueError(f"{file} is not a model state checkpoint")
        buffer_names = state_dict[BUFFER_NAMES]
        if debug:
            print("Found buffers:", buffer_names)

        # recover just the buffers while restoring them to fp32 if they were saved in fp16
        buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
        param_shapes = state_dict[PARAM_SHAPES]

        # NOTE: the original code also built a `param_names` list (trainable +
        # frozen names) here, but it was never used — removed as dead code.
        frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
        if frozen_param_shapes is not None and debug:
            print(f"Found frozen_param_shapes: {frozen_param_shapes}")

        # handle shared params
        shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]

        ds_version = state_dict.get(DS_VERSION, None)

        frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)

        z_model_state = zero_model_state(buffers=buffers,
                                         param_shapes=param_shapes,
                                         shared_params=shared_params,
                                         ds_version=ds_version,
                                         frozen_param_shapes=frozen_param_shapes,
                                         frozen_param_fragments=frozen_param_fragments)
        zero_model_states.append(z_model_state)

    return zero_model_states
146
+
147
+
148
def parse_optim_states(files, ds_checkpoint_dir):
    """Load every rank's ``*_optim_states.pt`` and extract stage, world size and fp32 flat groups.

    Returns:
        (zero_stage, world_size, fp32_flat_groups) where ``fp32_flat_groups[i]``
        holds rank i's flattened fp32 master-weight partition(s).

    Raises:
        ValueError: not a zero checkpoint, unknown stage, or rank-count mismatch.
    """
    total_files = len(files)
    state_dicts = []
    for f in tqdm(files, desc='Loading checkpoint shards'):
        state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
        # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
        # and also handle the case where it was already removed by another helper script
        state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
        state_dicts.append(state_dict)

    first_optim_sd = state_dicts[0][OPTIMIZER_STATE_DICT]
    if ZERO_STAGE not in first_optim_sd:
        raise ValueError(f"{files[0]} is not a zero checkpoint")
    zero_stage = first_optim_sd[ZERO_STAGE]
    world_size = first_optim_sd[PARTITION_COUNT]

    # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
    # parameters can be different from data parallelism for non-expert parameters. So we can just
    # use the max of the partition_count to get the dp world_size.
    if type(world_size) is list:
        world_size = max(world_size)

    if world_size != total_files:
        raise ValueError(
            f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
            "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
        )

    # the groups are named differently in each stage
    if zero_stage <= 2:
        fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
    elif zero_stage == 3:
        fp32_groups_key = FP32_FLAT_GROUPS
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    fp32_flat_groups = [sd[OPTIMIZER_STATE_DICT][fp32_groups_key] for sd in state_dicts]
    return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
    """
    Returns fp32 state_dict reconstructed from ds checkpoint

    Args:
        - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
        - ``exclude_frozen_parameters``: when True, frozen parameters are left out of the result
    """
    print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")

    optim_files = get_optim_files(ds_checkpoint_dir)
    zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

    model_files = get_model_state_files(ds_checkpoint_dir)
    zero_model_states = parse_model_states(model_files)
    print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')

    # zero1 checkpoints share the zero2 on-disk layout, hence the <= 2 dispatch
    if zero_stage <= 2:
        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                                          exclude_frozen_parameters)
    elif zero_stage == 3:
        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                                          exclude_frozen_parameters)
213
+
214
+
215
def _zero2_merge_frozen_params(state_dict, zero_model_states):
    """Copy frozen (non-trainable) params from rank 0's fragments into *state_dict*.

    Under ZeRO-1/2 frozen params are not partitioned across ranks, so each one
    can be copied whole — no stitching required.
    """
    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    if not frozen_param_shapes:
        # nothing frozen in this checkpoint (None or empty)
        return

    frozen_param_fragments = zero_model_states[0].frozen_param_fragments

    if debug:
        num_elem = sum(s.numel() for s in frozen_param_shapes.values())
        print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    avail_numel = sum(p.numel() for p in frozen_param_fragments.values())
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        state_dict[name] = frozen_param_fragments[name]

        if debug:
            print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    """Re-assemble trainable fp32 params from the per-rank ZeRO-1/2 flat partitions.

    For each param group the per-rank 1-D partitions are concatenated into one
    flat vector, which is then sliced back into individual params using the
    shapes recorded in the model-states file.
    """
    param_shapes = zero_model_states[0].param_shapes

    # Reconstruction protocol:
    #
    # XXX: document this

    if debug:
        for i in range(world_size):
            for j in range(len(fp32_flat_groups[0])):
                print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")

    # XXX: memory usage doubles here (zero2)
    num_param_groups = len(fp32_flat_groups[0])
    merged_single_partition_of_fp32_groups = []
    for group_idx in range(num_param_groups):
        per_rank_partitions = [rank_groups[group_idx] for rank_groups in fp32_flat_groups]
        merged_single_partition_of_fp32_groups.append(torch.cat(per_rank_partitions, 0))
    avail_numel = sum(vec.numel() for vec in merged_single_partition_of_fp32_groups)

    if debug:
        wanted_params = sum(len(shapes) for shapes in param_shapes)
        wanted_numel = sum(sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes)
        # not asserting if there is a mismatch due to possible padding
        print(f"Have {avail_numel} numels to process.")
        print(f"Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    total_numel = 0
    total_params = 0
    for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
        offset = 0
        avail_numel = full_single_fp32_vector.numel()
        for name, shape in shapes.items():
            # shapes may be torch.Size (has .numel) or plain tuples/lists
            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
            total_numel += unpartitioned_numel
            total_params += 1

            if debug:
                print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
            state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
            offset += unpartitioned_numel

        # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
        # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
        # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
        # live optimizer object, so we are checking that the numbers are within the right range
        align_to = 2 * world_size

        def zero2_align(x):
            return align_to * math.ceil(x / align_to)

        if debug:
            print(f"original offset={offset}, avail_numel={avail_numel}")

        offset = zero2_align(offset)
        avail_numel = zero2_align(avail_numel)

        if debug:
            print(f"aligned offset={offset}, avail_numel={avail_numel}")

        # Sanity check
        if offset != avail_numel:
            raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    """Build the consolidated fp32 state_dict for a ZeRO-1/2 checkpoint."""
    state_dict = OrderedDict()

    # buffers
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero2_merge_frozen_params(state_dict, zero_model_states)

    _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters: each alias points at its source's tensor
    for alias, source in zero_model_states[0].shared_params:
        if source in state_dict:
            state_dict[alias] = state_dict[source]

    return state_dict
346
+
347
+
348
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    """Return (per-rank partition size, padding element count) for a ZeRO-3 param.

    Each rank holds ceil(numel / world_size) elements; the final rank's slot is
    padded with (world_size - remainder) filler elements when numel doesn't
    divide evenly.
    """
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    remainder = unpartitioned_numel % world_size
    if remainder:
        padding_numel = world_size - remainder
    else:
        padding_numel = 0
    return partitioned_numel, padding_numel
353
+
354
+
355
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
    """Stitch frozen params back together from every rank's fragments (ZeRO-3).

    Each rank holds one fragment per frozen param; fragments are concatenated
    in rank order and trimmed of alignment padding before reshaping.
    """
    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    if not frozen_param_shapes:
        # nothing frozen in this checkpoint (None or empty)
        return

    if debug:
        for rank in range(world_size):
            num_elem = sum(s.numel() for s in zero_model_states[rank].frozen_param_fragments.values())
            print(f'rank {rank}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    avail_numel = sum(p.numel() for p in zero_model_states[0].frozen_param_fragments.values()) * world_size
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        # one fragment per rank, concatenated then trimmed of trailing padding
        param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
        state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
class GatheredTensor:
    """
    A pseudo tensor that collects partitioned weights.
    It is more memory efficient when there are multiple groups.
    """

    def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
        # flat_groups[rank][group] is that rank's flat fp32 tensor for a param group
        self.flat_groups = flat_groups
        # cumulative numel boundaries of the groups within one rank's flat space
        self.flat_groups_offset = flat_groups_offset
        # start of this param's partition within a rank's flat space
        self.offset = offset
        # number of elements each rank holds for this param
        self.partitioned_numel = partitioned_numel
        self.shape = shape
        self.dtype = self.flat_groups[0][0].dtype

    def contiguous(self):
        """
        Merge partitioned weights from flat_groups into a single tensor.
        """
        end_idx = self.offset + self.partitioned_numel
        world_size = len(self.flat_groups)
        chunks = []

        for rank in range(world_size):
            # for each rank, collect the slices of the group tensor(s) that
            # hold this param's partition
            rank_groups = self.flat_groups[rank]
            first_gid = None
            last_gid = None
            for gid in range(len(self.flat_groups_offset)):
                if self.flat_groups_offset[gid] <= self.offset < self.flat_groups_offset[gid + 1]:
                    first_gid = gid
                if self.flat_groups_offset[gid] < end_idx <= self.flat_groups_offset[gid + 1]:
                    last_gid = gid
                    break
            # NOTE(review): for a partition spanning multiple groups, `lo` below
            # goes negative for groups after the first — presumably partitions
            # never straddle a group boundary in practice; confirm upstream.
            for gid in range(first_gid, last_gid + 1):
                flat_tensor = rank_groups[gid]
                lo = self.offset - self.flat_groups_offset[gid]
                hi = min(end_idx, self.flat_groups_offset[gid + 1]) - self.flat_groups_offset[gid]
                chunks.append(flat_tensor[lo:hi])

        # concatenate across ranks, drop trailing padding, restore the shape
        pad_flat_param = torch.cat(chunks, dim=0)
        return pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
435
+
436
+
437
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    """Populate *state_dict* with lazy GatheredTensor views over the ZeRO-3 flat partitions."""
    param_shapes = zero_model_states[0].param_shapes
    avail_numel = sum(flat_group.numel() for flat_group in fp32_flat_groups[0]) * world_size

    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
    # param, re-consolidating each param, while dealing with padding if any

    # merge list of dicts, preserving order
    param_shapes = {name: shape for group in param_shapes for name, shape in group.items()}

    if debug:
        for i in range(world_size):
            print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")

    wanted_params = len(param_shapes)
    wanted_numel = sum(shape.numel() for shape in param_shapes.values())
    # not asserting if there is a mismatch due to possible padding
    # NOTE(review): this overwrites the list-aware sum above with a direct
    # .numel() call — only valid if fp32_flat_groups[0] is a single tensor;
    # verify against the checkpoint layout.
    avail_numel = fp32_flat_groups[0].numel() * world_size
    print(f"Trainable params: Have {avail_numel} numels to process.")
    print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    offset = 0
    total_numel = 0
    total_params = 0
    flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
    for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel
        total_params += 1
        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

        # memory efficient tensor: materialized lazily via .contiguous()
        state_dict[name] = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
        offset += partitioned_numel

    offset *= world_size

    # Sanity check
    if offset != avail_numel:
        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    """Build the consolidated fp32 state_dict for a ZeRO-3 checkpoint (lazy tensors)."""
    state_dict = OrderedDict()

    # buffers
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)

    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters: each alias points at its source's tensor
    for alias, source in zero_model_states[0].shared_params:
        if source in state_dict:
            state_dict[alias] = state_dict[source]

    return state_dict
511
+
512
+
513
def to_torch_tensor(state_dict, return_empty_tensor=False):
    """
    Convert state_dict of GatheredTensor to torch tensor

    Shared entries (same underlying object under several names) are materialized
    once and aliased, so tying is preserved. With ``return_empty_tensor=True``
    only shape/dtype placeholders are produced (used for shard planning).
    """
    torch_state_dict = {}
    seen = {}  # id(tensor) -> first name it was materialized under
    for name, tensor in state_dict.items():
        key = id(tensor)
        if key in seen:  # shared tensors
            torch_state_dict[name] = torch_state_dict[seen[key]]
        else:
            seen[key] = name
            if return_empty_tensor:
                torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
            else:
                torch_state_dict[name] = tensor.contiguous()
    return torch_state_dict
531
+
532
+
533
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
                                             tag=None,
                                             exclude_frozen_parameters=False,
                                             lazy_mode=False):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
    ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
    via a model hub.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters
        - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pseudo tensor instead of torch tensor, which is more memory efficient.
          Convert the pseudo tensor to torch tensor by ``.contiguous()``

    Returns:
        - pytorch ``state_dict``

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        # do the training and checkpoint saving
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
        model = model.cpu() # move to cpu
        model.load_state_dict(state_dict)
        # submit to model hub or save the model to share with others

    In this example the ``model`` will no longer be usable in the deepspeed context of the same
    application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.

    Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
    You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
    the checkpoint. Or you can load state_dict in lazy mode ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
        for name, lazy_tensor in state_dict.items():
            tensor = lazy_tensor.contiguous() # to cpu
            print(name, tensor)
            # del tensor to release memory if it no longer in use
    """
    # resolve the tag from the 'latest' marker file when not given explicitly
    if tag is None:
        latest_path = os.path.join(checkpoint_dir, 'latest')
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)

    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

    state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
    if lazy_mode:
        return state_dict
    return to_torch_tensor(state_dict)
596
+
597
+
598
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
                                               output_dir,
                                               max_shard_size="5GB",
                                               safe_serialization=False,
                                               tag=None,
                                               exclude_frozen_parameters=False):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``output_dir``: directory to the pytorch fp32 state_dict output files
        - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
        - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters
    """

    # Dependency pre-check
    if safe_serialization:
        try:
            from safetensors.torch import save_file
        except ImportError:
            print('If you want to use `safe_serialization`, please `pip install safetensors`')
            raise
    if max_shard_size is not None:
        try:
            from huggingface_hub import split_torch_state_dict_into_shards
        except ImportError:
            print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
            raise

    # Convert zero checkpoint to a lazily-materialized state_dict (keeps peak RAM low)
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
                                                          tag,
                                                          exclude_frozen_parameters,
                                                          lazy_mode=True)

    # Shard the model if it is too big.
    weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
    if max_shard_size is not None:
        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
        # an memory-efficient approach for sharding: plan shards from
        # shape/dtype placeholders instead of materialized tensors
        empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
        state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
                                                              filename_pattern=filename_pattern,
                                                              max_shard_size=max_shard_size)
    else:
        # no sharding requested: emulate the split result with a single shard
        from collections import namedtuple
        StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
        state_dict_split = StateDictSplit(is_sharded=False,
                                          filename_to_tensors={weights_name: list(state_dict.keys())})

    # Save the model by shard
    os.makedirs(output_dir, exist_ok=True)
    for shard_file, tensor_names in tqdm(state_dict_split.filename_to_tensors.items(),
                                         desc="Saving checkpoint shards"):
        # materialize only the tensors belonging to this shard
        shard_state_dict = to_torch_tensor({name: state_dict[name] for name in tensor_names})
        output_path = os.path.join(output_dir, shard_file)
        if safe_serialization:
            save_file(shard_state_dict, output_path, metadata={"format": "pt"})
        else:
            torch.save(shard_state_dict, output_path)
        # release the memory of current shard
        for tensor_name in list(shard_state_dict.keys()):
            del state_dict[tensor_name]
            del shard_state_dict[tensor_name]
        del shard_state_dict
        gc.collect()

    # Save index if sharded
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
        save_index_file = os.path.join(output_dir, save_index_file)
        with open(save_index_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")
681
+
682
+
683
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
    """
    1. Put the provided model to cpu
    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
    3. Load it into the provided model

    Args:
        - ``model``: the model object to update
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``

    Returns:
        - ``model``: modified model

    Make sure you have plenty of CPU memory available before you call this function. If you don't
    have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
    conveniently placed for you in the checkpoint folder.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
        # submit to model hub or save the model to share with others

    Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
    of the same application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
    """
    logger.info("Extracting fp32 weights")
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    logger.info("Overwriting model with fp32 weights")
    model = model.cpu()
    # strict=False: the reconstructed dict may legitimately omit some entries
    # (e.g. frozen params when excluded)
    model.load_state_dict(state_dict, strict=False)

    return model
720
+
721
+
722
if __name__ == "__main__":
    # CLI entry point: convert a DeepSpeed ZeRO checkpoint folder into plain
    # fp32 weight files (optionally sharded and/or in safetensors format).
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint_dir",
                        type=str,
                        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
    parser.add_argument("output_dir",
                        type=str,
                        help="directory to the pytorch fp32 state_dict output files"
                        "(e.g. path/checkpoint-12-output/)")
    parser.add_argument(
        "--max_shard_size",
        type=str,
        default="5GB",
        help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
        "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
        "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
        "without CPU OOM issues.")
    parser.add_argument(
        "--safe_serialization",
        default=False,
        action='store_true',
        help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
    parser.add_argument("-t",
                        "--tag",
                        type=str,
                        default=None,
                        help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
    parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
    args = parser.parse_args()

    # rebinds the module-level `debug` flag consulted throughout this script
    debug = args.debug

    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
                                               args.output_dir,
                                               max_shard_size=args.max_shard_size,
                                               safe_serialization=args.safe_serialization,
                                               tag=args.tag,
                                               exclude_frozen_parameters=args.exclude_frozen_parameters)
README.md ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: google/gemma-3-27b-it
3
+ library_name: peft
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.14.0
adapter_config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "google/gemma-3-27b-it",
5
+ "bias": "none",
6
+ "eva_config": null,
7
+ "exclude_modules": null,
8
+ "fan_in_fan_out": false,
9
+ "inference_mode": true,
10
+ "init_lora_weights": true,
11
+ "layer_replication": null,
12
+ "layers_pattern": null,
13
+ "layers_to_transform": null,
14
+ "loftq_config": {},
15
+ "lora_alpha": 32,
16
+ "lora_bias": false,
17
+ "lora_dropout": 0.05,
18
+ "megatron_config": null,
19
+ "megatron_core": "megatron.core",
20
+ "modules_to_save": [
21
+ "embed_tokens",
22
+ "lm_head"
23
+ ],
24
+ "peft_type": "LORA",
25
+ "r": 16,
26
+ "rank_pattern": {},
27
+ "revision": null,
28
+ "target_modules": [
29
+ "v_proj",
30
+ "k_proj",
31
+ "q_proj",
32
+ "o_proj"
33
+ ],
34
+ "task_type": "CAUSAL_LM",
35
+ "use_dora": false,
36
+ "use_rslora": false
37
+ }
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93c0b35bb2f4d35f6fa6698722098abcee6d9402322bd3ffd2f6d2ffdf2c8a3a
3
+ size 5711640384
chat_template.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n"
3
+ }
global_step1491/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9088be42fdd195304bd2479329e0be460086e1781f3f785fcd8405175ff3c2fc
3
+ size 4283677776
global_step1491/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:786cd946e7fa2a921b4fe31bb44bbb0b64543eacd00c9894a0c909faf33e1df4
3
+ size 4283655696
global_step1491/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a7c969b8e5e37fb200a4840ab0089b4e5f1ab18a2cdf21da64d6d309e94f65d
3
+ size 4283655696
global_step1491/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e8514fb7770ebca226733faa38c8ca3e349e796eb471f68ecf4a09a9cd116f7
3
+ size 4283684048
global_step1491/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46e4430135e20d55ea949a6242f48bfa19b371388c06b30d632f2be28358f621
3
+ size 4283689808
global_step1491/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:534dc9092aa96349817ea4dce8cf5dd800005632fabf268e68e8b7ce552f9185
3
+ size 4283655696
global_step1491/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85eca2ec0dbc051955695599819daa2e010eb6dc639af7a597759ada8656d673
3
+ size 4283655696
global_step1491/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:361d25aef70a8089a1f1cf5afa96ac938e6293701d699c17c4ffb4c238f44c0d
3
+ size 4283655696
global_step1491/mp_rank_00_model_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8556e68d38c1f01727a1243cdd9e8d56a36c3a032a04c0756d031d745b16039b
3
+ size 8531477944
latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step1491
preprocessor_config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": null,
3
+ "data_format": "channels_first",
4
+ "default_to_square": true,
5
+ "device": null,
6
+ "do_center_crop": null,
7
+ "do_convert_rgb": null,
8
+ "do_normalize": true,
9
+ "do_pan_and_scan": null,
10
+ "do_rescale": true,
11
+ "do_resize": true,
12
+ "image_mean": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "image_processor_type": "Gemma3ImageProcessorFast",
18
+ "image_seq_length": 256,
19
+ "image_std": [
20
+ 0.5,
21
+ 0.5,
22
+ 0.5
23
+ ],
24
+ "input_data_format": null,
25
+ "pan_and_scan_max_num_crops": null,
26
+ "pan_and_scan_min_crop_size": null,
27
+ "pan_and_scan_min_ratio_to_activate": null,
28
+ "processor_class": "Gemma3Processor",
29
+ "resample": 2,
30
+ "rescale_factor": 0.00392156862745098,
31
+ "return_tensors": null,
32
+ "size": {
33
+ "height": 896,
34
+ "width": 896
35
+ }
36
+ }
processor_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "image_seq_length": 256,
3
+ "processor_class": "Gemma3Processor"
4
+ }
rng_state_0.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efca4649b288b8ad5a7993429fbf1df9c8d390d3e6a2fd99c9abac426716a574
3
+ size 15984
rng_state_1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f4319170fbce0a9f543035ef13c7e23bd26040d6391ed5c3a57f13d1b4396b8
3
+ size 15984
rng_state_2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:327036d7728641b60db20720aa0f304e65c9e7c31b217dc9a0749c97f1ffbf21
3
+ size 15984
rng_state_3.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e79f1e1aec1dff62dac8c2ee0d2498a5360a8b6e7625b9af9c18fbf44c7021c6
3
+ size 15984
rng_state_4.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:941cf38d14342643415ee2d35c13039c0c9eec406c8c759c7377366fc2692851
3
+ size 15984
rng_state_5.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b52955539d014f9249f1d5ca976e749c952d5d82b447f3fd08ab16c6cf0b267a
3
+ size 15984
rng_state_6.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d28b3d402ce58c9d30f488e68bffada0c398774ba8971cf64a7ead8878a828d8
3
+ size 15984
rng_state_7.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2b3b7df2243bb8b9173d3d9a966115ca954c1686c9213a4c94f221cb77eb668
3
+ size 15984
scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5cda86e1665017464cb4091acb5734e3208741b558b0a1874603320d04971f1
3
+ size 1064
special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "boi_token": "<start_of_image>",
3
+ "bos_token": {
4
+ "content": "<bos>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ "eoi_token": "<end_of_image>",
11
+ "eos_token": {
12
+ "content": "<eos>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false
17
+ },
18
+ "image_token": "<image_soft_token>",
19
+ "pad_token": {
20
+ "content": "<pad>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false
25
+ },
26
+ "unk_token": {
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
3
+ size 33384568
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
trainer_state.json ADDED
@@ -0,0 +1,2427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_global_step": null,
3
+ "best_metric": null,
4
+ "best_model_checkpoint": null,
5
+ "epoch": 0.9994972347913524,
6
+ "eval_steps": 750,
7
+ "global_step": 1491,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "epoch": 0.0033517680576504107,
14
+ "grad_norm": 14.694869995117188,
15
+ "learning_rate": 6.666666666666667e-06,
16
+ "loss": 53.6406,
17
+ "mean_token_accuracy": 0.5338318642228842,
18
+ "step": 5
19
+ },
20
+ {
21
+ "epoch": 0.006703536115300821,
22
+ "grad_norm": 14.033230781555176,
23
+ "learning_rate": 1.3333333333333333e-05,
24
+ "loss": 52.3838,
25
+ "mean_token_accuracy": 0.5248840853571892,
26
+ "step": 10
27
+ },
28
+ {
29
+ "epoch": 0.010055304172951232,
30
+ "grad_norm": 6.804769039154053,
31
+ "learning_rate": 2e-05,
32
+ "loss": 47.9105,
33
+ "mean_token_accuracy": 0.5399681400507689,
34
+ "step": 15
35
+ },
36
+ {
37
+ "epoch": 0.013407072230601643,
38
+ "grad_norm": 7.750083923339844,
39
+ "learning_rate": 2.6666666666666667e-05,
40
+ "loss": 41.8861,
41
+ "mean_token_accuracy": 0.55653104968369,
42
+ "step": 20
43
+ },
44
+ {
45
+ "epoch": 0.01675884028825205,
46
+ "grad_norm": 6.184543132781982,
47
+ "learning_rate": 3.3333333333333335e-05,
48
+ "loss": 37.33,
49
+ "mean_token_accuracy": 0.5655230440199375,
50
+ "step": 25
51
+ },
52
+ {
53
+ "epoch": 0.020110608345902465,
54
+ "grad_norm": 4.537179946899414,
55
+ "learning_rate": 4e-05,
56
+ "loss": 32.7503,
57
+ "mean_token_accuracy": 0.587661711126566,
58
+ "step": 30
59
+ },
60
+ {
61
+ "epoch": 0.023462376403552875,
62
+ "grad_norm": 3.6645753383636475,
63
+ "learning_rate": 4.666666666666667e-05,
64
+ "loss": 29.1892,
65
+ "mean_token_accuracy": 0.6075583577156067,
66
+ "step": 35
67
+ },
68
+ {
69
+ "epoch": 0.026814144461203285,
70
+ "grad_norm": 3.7526533603668213,
71
+ "learning_rate": 5.333333333333333e-05,
72
+ "loss": 26.3524,
73
+ "mean_token_accuracy": 0.6198613092303276,
74
+ "step": 40
75
+ },
76
+ {
77
+ "epoch": 0.030165912518853696,
78
+ "grad_norm": 3.0561397075653076,
79
+ "learning_rate": 6e-05,
80
+ "loss": 24.1513,
81
+ "mean_token_accuracy": 0.6353930421173573,
82
+ "step": 45
83
+ },
84
+ {
85
+ "epoch": 0.0335176805765041,
86
+ "grad_norm": 2.857618808746338,
87
+ "learning_rate": 6.666666666666667e-05,
88
+ "loss": 23.5029,
89
+ "mean_token_accuracy": 0.6437373287975788,
90
+ "step": 50
91
+ },
92
+ {
93
+ "epoch": 0.03686944863415452,
94
+ "grad_norm": 2.7901978492736816,
95
+ "learning_rate": 7.333333333333333e-05,
96
+ "loss": 22.9387,
97
+ "mean_token_accuracy": 0.646886795759201,
98
+ "step": 55
99
+ },
100
+ {
101
+ "epoch": 0.04022121669180493,
102
+ "grad_norm": 2.8266501426696777,
103
+ "learning_rate": 8e-05,
104
+ "loss": 22.0359,
105
+ "mean_token_accuracy": 0.6525138475000858,
106
+ "step": 60
107
+ },
108
+ {
109
+ "epoch": 0.04357298474945534,
110
+ "grad_norm": 2.5010733604431152,
111
+ "learning_rate": 8.666666666666667e-05,
112
+ "loss": 21.5158,
113
+ "mean_token_accuracy": 0.6548139773309231,
114
+ "step": 65
115
+ },
116
+ {
117
+ "epoch": 0.04692475280710575,
118
+ "grad_norm": 2.5834386348724365,
119
+ "learning_rate": 9.333333333333334e-05,
120
+ "loss": 21.5409,
121
+ "mean_token_accuracy": 0.6478891499340534,
122
+ "step": 70
123
+ },
124
+ {
125
+ "epoch": 0.05027652086475616,
126
+ "grad_norm": 2.6927576065063477,
127
+ "learning_rate": 0.0001,
128
+ "loss": 20.1017,
129
+ "mean_token_accuracy": 0.6757474772632122,
130
+ "step": 75
131
+ },
132
+ {
133
+ "epoch": 0.05362828892240657,
134
+ "grad_norm": 2.0276572704315186,
135
+ "learning_rate": 9.964689265536724e-05,
136
+ "loss": 19.9912,
137
+ "mean_token_accuracy": 0.6763999305665493,
138
+ "step": 80
139
+ },
140
+ {
141
+ "epoch": 0.05698005698005698,
142
+ "grad_norm": 2.4628567695617676,
143
+ "learning_rate": 9.929378531073446e-05,
144
+ "loss": 19.9089,
145
+ "mean_token_accuracy": 0.672279854118824,
146
+ "step": 85
147
+ },
148
+ {
149
+ "epoch": 0.06033182503770739,
150
+ "grad_norm": 2.258838415145874,
151
+ "learning_rate": 9.89406779661017e-05,
152
+ "loss": 19.7132,
153
+ "mean_token_accuracy": 0.6713059276342392,
154
+ "step": 90
155
+ },
156
+ {
157
+ "epoch": 0.0636835930953578,
158
+ "grad_norm": 2.447565793991089,
159
+ "learning_rate": 9.858757062146892e-05,
160
+ "loss": 18.7631,
161
+ "mean_token_accuracy": 0.6825208596885204,
162
+ "step": 95
163
+ },
164
+ {
165
+ "epoch": 0.0670353611530082,
166
+ "grad_norm": 2.1105902194976807,
167
+ "learning_rate": 9.823446327683616e-05,
168
+ "loss": 19.4631,
169
+ "mean_token_accuracy": 0.6674435302615166,
170
+ "step": 100
171
+ },
172
+ {
173
+ "epoch": 0.07038712921065862,
174
+ "grad_norm": 2.309248447418213,
175
+ "learning_rate": 9.78813559322034e-05,
176
+ "loss": 19.0249,
177
+ "mean_token_accuracy": 0.6734571024775505,
178
+ "step": 105
179
+ },
180
+ {
181
+ "epoch": 0.07373889726830904,
182
+ "grad_norm": 2.101681709289551,
183
+ "learning_rate": 9.752824858757063e-05,
184
+ "loss": 18.593,
185
+ "mean_token_accuracy": 0.6875097192823887,
186
+ "step": 110
187
+ },
188
+ {
189
+ "epoch": 0.07709066532595944,
190
+ "grad_norm": 2.157726526260376,
191
+ "learning_rate": 9.717514124293787e-05,
192
+ "loss": 18.5973,
193
+ "mean_token_accuracy": 0.6829216606914997,
194
+ "step": 115
195
+ },
196
+ {
197
+ "epoch": 0.08044243338360986,
198
+ "grad_norm": 2.0711209774017334,
199
+ "learning_rate": 9.682203389830509e-05,
200
+ "loss": 19.1541,
201
+ "mean_token_accuracy": 0.6785640828311443,
202
+ "step": 120
203
+ },
204
+ {
205
+ "epoch": 0.08379420144126026,
206
+ "grad_norm": 2.015594959259033,
207
+ "learning_rate": 9.646892655367233e-05,
208
+ "loss": 18.9493,
209
+ "mean_token_accuracy": 0.6861244946718216,
210
+ "step": 125
211
+ },
212
+ {
213
+ "epoch": 0.08714596949891068,
214
+ "grad_norm": 2.1295998096466064,
215
+ "learning_rate": 9.611581920903955e-05,
216
+ "loss": 18.5125,
217
+ "mean_token_accuracy": 0.6793887488543987,
218
+ "step": 130
219
+ },
220
+ {
221
+ "epoch": 0.09049773755656108,
222
+ "grad_norm": 2.2496395111083984,
223
+ "learning_rate": 9.576271186440679e-05,
224
+ "loss": 18.4019,
225
+ "mean_token_accuracy": 0.6890006221830844,
226
+ "step": 135
227
+ },
228
+ {
229
+ "epoch": 0.0938495056142115,
230
+ "grad_norm": 2.1168577671051025,
231
+ "learning_rate": 9.540960451977402e-05,
232
+ "loss": 18.7305,
233
+ "mean_token_accuracy": 0.6841622419655323,
234
+ "step": 140
235
+ },
236
+ {
237
+ "epoch": 0.0972012736718619,
238
+ "grad_norm": 1.8554915189743042,
239
+ "learning_rate": 9.505649717514125e-05,
240
+ "loss": 18.6606,
241
+ "mean_token_accuracy": 0.6859239712357521,
242
+ "step": 145
243
+ },
244
+ {
245
+ "epoch": 0.10055304172951232,
246
+ "grad_norm": 1.9698066711425781,
247
+ "learning_rate": 9.470338983050848e-05,
248
+ "loss": 19.1065,
249
+ "mean_token_accuracy": 0.6759489566087723,
250
+ "step": 150
251
+ },
252
+ {
253
+ "epoch": 0.10390480978716272,
254
+ "grad_norm": 2.2483623027801514,
255
+ "learning_rate": 9.43502824858757e-05,
256
+ "loss": 18.8041,
257
+ "mean_token_accuracy": 0.68142851293087,
258
+ "step": 155
259
+ },
260
+ {
261
+ "epoch": 0.10725657784481314,
262
+ "grad_norm": 1.8570690155029297,
263
+ "learning_rate": 9.399717514124294e-05,
264
+ "loss": 18.8862,
265
+ "mean_token_accuracy": 0.6791303649544715,
266
+ "step": 160
267
+ },
268
+ {
269
+ "epoch": 0.11060834590246355,
270
+ "grad_norm": 2.143021583557129,
271
+ "learning_rate": 9.364406779661016e-05,
272
+ "loss": 18.7605,
273
+ "mean_token_accuracy": 0.681893227249384,
274
+ "step": 165
275
+ },
276
+ {
277
+ "epoch": 0.11396011396011396,
278
+ "grad_norm": 1.8951307535171509,
279
+ "learning_rate": 9.32909604519774e-05,
280
+ "loss": 18.3005,
281
+ "mean_token_accuracy": 0.6897541806101799,
282
+ "step": 170
283
+ },
284
+ {
285
+ "epoch": 0.11731188201776437,
286
+ "grad_norm": 1.971745252609253,
287
+ "learning_rate": 9.293785310734464e-05,
288
+ "loss": 18.8995,
289
+ "mean_token_accuracy": 0.6820204116404056,
290
+ "step": 175
291
+ },
292
+ {
293
+ "epoch": 0.12066365007541478,
294
+ "grad_norm": 1.910328984260559,
295
+ "learning_rate": 9.258474576271187e-05,
296
+ "loss": 18.8808,
297
+ "mean_token_accuracy": 0.6812884464859963,
298
+ "step": 180
299
+ },
300
+ {
301
+ "epoch": 0.12401541813306519,
302
+ "grad_norm": 1.730974555015564,
303
+ "learning_rate": 9.223163841807911e-05,
304
+ "loss": 18.0871,
305
+ "mean_token_accuracy": 0.6907590143382549,
306
+ "step": 185
307
+ },
308
+ {
309
+ "epoch": 0.1273671861907156,
310
+ "grad_norm": 2.125452995300293,
311
+ "learning_rate": 9.187853107344633e-05,
312
+ "loss": 18.1569,
313
+ "mean_token_accuracy": 0.689236406236887,
314
+ "step": 190
315
+ },
316
+ {
317
+ "epoch": 0.13071895424836602,
318
+ "grad_norm": 2.0234949588775635,
319
+ "learning_rate": 9.152542372881357e-05,
320
+ "loss": 18.3342,
321
+ "mean_token_accuracy": 0.6902932204306126,
322
+ "step": 195
323
+ },
324
+ {
325
+ "epoch": 0.1340707223060164,
326
+ "grad_norm": 1.9802364110946655,
327
+ "learning_rate": 9.11723163841808e-05,
328
+ "loss": 18.7942,
329
+ "mean_token_accuracy": 0.6788501650094986,
330
+ "step": 200
331
+ },
332
+ {
333
+ "epoch": 0.13742249036366683,
334
+ "grad_norm": 1.8897534608840942,
335
+ "learning_rate": 9.081920903954803e-05,
336
+ "loss": 18.4679,
337
+ "mean_token_accuracy": 0.6900524459779263,
338
+ "step": 205
339
+ },
340
+ {
341
+ "epoch": 0.14077425842131724,
342
+ "grad_norm": 1.9040635824203491,
343
+ "learning_rate": 9.046610169491526e-05,
344
+ "loss": 18.0058,
345
+ "mean_token_accuracy": 0.690093420445919,
346
+ "step": 210
347
+ },
348
+ {
349
+ "epoch": 0.14412602647896766,
350
+ "grad_norm": 2.0558955669403076,
351
+ "learning_rate": 9.011299435028249e-05,
352
+ "loss": 17.5489,
353
+ "mean_token_accuracy": 0.7006829999387264,
354
+ "step": 215
355
+ },
356
+ {
357
+ "epoch": 0.14747779453661808,
358
+ "grad_norm": 1.7952055931091309,
359
+ "learning_rate": 8.975988700564972e-05,
360
+ "loss": 18.2907,
361
+ "mean_token_accuracy": 0.6876891441643238,
362
+ "step": 220
363
+ },
364
+ {
365
+ "epoch": 0.15082956259426847,
366
+ "grad_norm": 1.8588192462921143,
367
+ "learning_rate": 8.940677966101694e-05,
368
+ "loss": 18.4005,
369
+ "mean_token_accuracy": 0.6897859051823616,
370
+ "step": 225
371
+ },
372
+ {
373
+ "epoch": 0.15418133065191889,
374
+ "grad_norm": 1.9269477128982544,
375
+ "learning_rate": 8.905367231638418e-05,
376
+ "loss": 18.2096,
377
+ "mean_token_accuracy": 0.6909494370222091,
378
+ "step": 230
379
+ },
380
+ {
381
+ "epoch": 0.1575330987095693,
382
+ "grad_norm": 1.8693301677703857,
383
+ "learning_rate": 8.870056497175142e-05,
384
+ "loss": 18.394,
385
+ "mean_token_accuracy": 0.6836515329778194,
386
+ "step": 235
387
+ },
388
+ {
389
+ "epoch": 0.16088486676721972,
390
+ "grad_norm": 1.787061333656311,
391
+ "learning_rate": 8.834745762711864e-05,
392
+ "loss": 18.1503,
393
+ "mean_token_accuracy": 0.6907145738601684,
394
+ "step": 240
395
+ },
396
+ {
397
+ "epoch": 0.1642366348248701,
398
+ "grad_norm": 1.8895225524902344,
399
+ "learning_rate": 8.799435028248588e-05,
400
+ "loss": 18.3026,
401
+ "mean_token_accuracy": 0.6878940775990486,
402
+ "step": 245
403
+ },
404
+ {
405
+ "epoch": 0.16758840288252053,
406
+ "grad_norm": 1.835693120956421,
407
+ "learning_rate": 8.764124293785311e-05,
408
+ "loss": 17.9347,
409
+ "mean_token_accuracy": 0.6917316012084485,
410
+ "step": 250
411
+ },
412
+ {
413
+ "epoch": 0.17094017094017094,
414
+ "grad_norm": 1.7408661842346191,
415
+ "learning_rate": 8.728813559322035e-05,
416
+ "loss": 18.0051,
417
+ "mean_token_accuracy": 0.689583633840084,
418
+ "step": 255
419
+ },
420
+ {
421
+ "epoch": 0.17429193899782136,
422
+ "grad_norm": 1.9096996784210205,
423
+ "learning_rate": 8.693502824858759e-05,
424
+ "loss": 17.6064,
425
+ "mean_token_accuracy": 0.6965925216674804,
426
+ "step": 260
427
+ },
428
+ {
429
+ "epoch": 0.17764370705547175,
430
+ "grad_norm": 1.9822146892547607,
431
+ "learning_rate": 8.658192090395481e-05,
432
+ "loss": 17.6301,
433
+ "mean_token_accuracy": 0.7005406267940998,
434
+ "step": 265
435
+ },
436
+ {
437
+ "epoch": 0.18099547511312217,
438
+ "grad_norm": 1.8383901119232178,
439
+ "learning_rate": 8.622881355932204e-05,
440
+ "loss": 17.9114,
441
+ "mean_token_accuracy": 0.6876685306429863,
442
+ "step": 270
443
+ },
444
+ {
445
+ "epoch": 0.18434724317077258,
446
+ "grad_norm": 1.7920355796813965,
447
+ "learning_rate": 8.587570621468927e-05,
448
+ "loss": 18.1271,
449
+ "mean_token_accuracy": 0.689356567710638,
450
+ "step": 275
451
+ },
452
+ {
453
+ "epoch": 0.187699011228423,
454
+ "grad_norm": 1.6455663442611694,
455
+ "learning_rate": 8.55225988700565e-05,
456
+ "loss": 17.787,
457
+ "mean_token_accuracy": 0.6919776491820813,
458
+ "step": 280
459
+ },
460
+ {
461
+ "epoch": 0.1910507792860734,
462
+ "grad_norm": 1.9442647695541382,
463
+ "learning_rate": 8.516949152542373e-05,
464
+ "loss": 17.6019,
465
+ "mean_token_accuracy": 0.6980393722653389,
466
+ "step": 285
467
+ },
468
+ {
469
+ "epoch": 0.1944025473437238,
470
+ "grad_norm": 2.294377565383911,
471
+ "learning_rate": 8.481638418079096e-05,
472
+ "loss": 17.8778,
473
+ "mean_token_accuracy": 0.6954585202038288,
474
+ "step": 290
475
+ },
476
+ {
477
+ "epoch": 0.19775431540137423,
478
+ "grad_norm": 1.8009259700775146,
479
+ "learning_rate": 8.44632768361582e-05,
480
+ "loss": 17.5257,
481
+ "mean_token_accuracy": 0.6998075112700463,
482
+ "step": 295
483
+ },
484
+ {
485
+ "epoch": 0.20110608345902464,
486
+ "grad_norm": 2.015516757965088,
487
+ "learning_rate": 8.411016949152542e-05,
488
+ "loss": 17.7554,
489
+ "mean_token_accuracy": 0.6968327619135379,
490
+ "step": 300
491
+ },
492
+ {
493
+ "epoch": 0.20445785151667506,
494
+ "grad_norm": 1.5640082359313965,
495
+ "learning_rate": 8.375706214689266e-05,
496
+ "loss": 17.3438,
497
+ "mean_token_accuracy": 0.69996168166399,
498
+ "step": 305
499
+ },
500
+ {
501
+ "epoch": 0.20780961957432545,
502
+ "grad_norm": 1.9527899026870728,
503
+ "learning_rate": 8.340395480225988e-05,
504
+ "loss": 17.6883,
505
+ "mean_token_accuracy": 0.6988407798111439,
506
+ "step": 310
507
+ },
508
+ {
509
+ "epoch": 0.21116138763197587,
510
+ "grad_norm": 1.8222606182098389,
511
+ "learning_rate": 8.305084745762712e-05,
512
+ "loss": 17.0646,
513
+ "mean_token_accuracy": 0.7061679445207119,
514
+ "step": 315
515
+ },
516
+ {
517
+ "epoch": 0.21451315568962628,
518
+ "grad_norm": 1.8560868501663208,
519
+ "learning_rate": 8.269774011299435e-05,
520
+ "loss": 17.8875,
521
+ "mean_token_accuracy": 0.6941629223525524,
522
+ "step": 320
523
+ },
524
+ {
525
+ "epoch": 0.2178649237472767,
526
+ "grad_norm": 1.7588037252426147,
527
+ "learning_rate": 8.234463276836159e-05,
528
+ "loss": 17.6412,
529
+ "mean_token_accuracy": 0.6954927705228329,
530
+ "step": 325
531
+ },
532
+ {
533
+ "epoch": 0.2212166918049271,
534
+ "grad_norm": 1.738242268562317,
535
+ "learning_rate": 8.199152542372883e-05,
536
+ "loss": 17.8251,
537
+ "mean_token_accuracy": 0.6898994512856007,
538
+ "step": 330
539
+ },
540
+ {
541
+ "epoch": 0.2245684598625775,
542
+ "grad_norm": 1.8485089540481567,
543
+ "learning_rate": 8.163841807909605e-05,
544
+ "loss": 17.3078,
545
+ "mean_token_accuracy": 0.7000270999968052,
546
+ "step": 335
547
+ },
548
+ {
549
+ "epoch": 0.22792022792022792,
550
+ "grad_norm": 1.8579105138778687,
551
+ "learning_rate": 8.128531073446328e-05,
552
+ "loss": 17.3078,
553
+ "mean_token_accuracy": 0.6995702408254146,
554
+ "step": 340
555
+ },
556
+ {
557
+ "epoch": 0.23127199597787834,
558
+ "grad_norm": 1.7994352579116821,
559
+ "learning_rate": 8.093220338983051e-05,
560
+ "loss": 17.7557,
561
+ "mean_token_accuracy": 0.6928035505115986,
562
+ "step": 345
563
+ },
564
+ {
565
+ "epoch": 0.23462376403552873,
566
+ "grad_norm": 1.9240634441375732,
567
+ "learning_rate": 8.057909604519774e-05,
568
+ "loss": 17.4329,
569
+ "mean_token_accuracy": 0.6960855178534985,
570
+ "step": 350
571
+ },
572
+ {
573
+ "epoch": 0.23797553209317915,
574
+ "grad_norm": 1.6718952655792236,
575
+ "learning_rate": 8.022598870056498e-05,
576
+ "loss": 17.5951,
577
+ "mean_token_accuracy": 0.6947735913097859,
578
+ "step": 355
579
+ },
580
+ {
581
+ "epoch": 0.24132730015082957,
582
+ "grad_norm": 1.6835826635360718,
583
+ "learning_rate": 7.98728813559322e-05,
584
+ "loss": 18.1085,
585
+ "mean_token_accuracy": 0.6882089108228684,
586
+ "step": 360
587
+ },
588
+ {
589
+ "epoch": 0.24467906820847998,
590
+ "grad_norm": 1.7387073040008545,
591
+ "learning_rate": 7.951977401129944e-05,
592
+ "loss": 17.799,
593
+ "mean_token_accuracy": 0.6932998545467853,
594
+ "step": 365
595
+ },
596
+ {
597
+ "epoch": 0.24803083626613037,
598
+ "grad_norm": 2.0071725845336914,
599
+ "learning_rate": 7.916666666666666e-05,
600
+ "loss": 17.4076,
601
+ "mean_token_accuracy": 0.6961173862218857,
602
+ "step": 370
603
+ },
604
+ {
605
+ "epoch": 0.2513826043237808,
606
+ "grad_norm": 2.326915740966797,
607
+ "learning_rate": 7.88135593220339e-05,
608
+ "loss": 17.3121,
609
+ "mean_token_accuracy": 0.7005321949720382,
610
+ "step": 375
611
+ },
612
+ {
613
+ "epoch": 0.2547343723814312,
614
+ "grad_norm": 2.1876060962677,
615
+ "learning_rate": 7.846045197740113e-05,
616
+ "loss": 17.9069,
617
+ "mean_token_accuracy": 0.6906426399946213,
618
+ "step": 380
619
+ },
620
+ {
621
+ "epoch": 0.2580861404390816,
622
+ "grad_norm": 1.849671483039856,
623
+ "learning_rate": 7.810734463276837e-05,
624
+ "loss": 17.483,
625
+ "mean_token_accuracy": 0.7000573620200157,
626
+ "step": 385
627
+ },
628
+ {
629
+ "epoch": 0.26143790849673204,
630
+ "grad_norm": 1.6676862239837646,
631
+ "learning_rate": 7.775423728813561e-05,
632
+ "loss": 16.8936,
633
+ "mean_token_accuracy": 0.7045633904635906,
634
+ "step": 390
635
+ },
636
+ {
637
+ "epoch": 0.26478967655438246,
638
+ "grad_norm": 1.6702505350112915,
639
+ "learning_rate": 7.740112994350283e-05,
640
+ "loss": 17.904,
641
+ "mean_token_accuracy": 0.6874841086566448,
642
+ "step": 395
643
+ },
644
+ {
645
+ "epoch": 0.2681414446120328,
646
+ "grad_norm": 1.7280704975128174,
647
+ "learning_rate": 7.704802259887007e-05,
648
+ "loss": 17.4515,
649
+ "mean_token_accuracy": 0.7018027983605861,
650
+ "step": 400
651
+ },
652
+ {
653
+ "epoch": 0.27149321266968324,
654
+ "grad_norm": 1.8801991939544678,
655
+ "learning_rate": 7.669491525423729e-05,
656
+ "loss": 17.43,
657
+ "mean_token_accuracy": 0.7009049601852894,
658
+ "step": 405
659
+ },
660
+ {
661
+ "epoch": 0.27484498072733365,
662
+ "grad_norm": 1.9758073091506958,
663
+ "learning_rate": 7.634180790960453e-05,
664
+ "loss": 17.5984,
665
+ "mean_token_accuracy": 0.6948069363832474,
666
+ "step": 410
667
+ },
668
+ {
669
+ "epoch": 0.27819674878498407,
670
+ "grad_norm": 1.5747147798538208,
671
+ "learning_rate": 7.598870056497176e-05,
672
+ "loss": 18.3079,
673
+ "mean_token_accuracy": 0.6853139907121658,
674
+ "step": 415
675
+ },
676
+ {
677
+ "epoch": 0.2815485168426345,
678
+ "grad_norm": 1.6292234659194946,
679
+ "learning_rate": 7.563559322033898e-05,
680
+ "loss": 17.4527,
681
+ "mean_token_accuracy": 0.697540608048439,
682
+ "step": 420
683
+ },
684
+ {
685
+ "epoch": 0.2849002849002849,
686
+ "grad_norm": 1.6185086965560913,
687
+ "learning_rate": 7.528248587570622e-05,
688
+ "loss": 17.4193,
689
+ "mean_token_accuracy": 0.7012022204697133,
690
+ "step": 425
691
+ },
692
+ {
693
+ "epoch": 0.2882520529579353,
694
+ "grad_norm": 1.8361762762069702,
695
+ "learning_rate": 7.492937853107344e-05,
696
+ "loss": 17.4544,
697
+ "mean_token_accuracy": 0.698820473998785,
698
+ "step": 430
699
+ },
700
+ {
701
+ "epoch": 0.29160382101558574,
702
+ "grad_norm": 1.7740592956542969,
703
+ "learning_rate": 7.457627118644068e-05,
704
+ "loss": 18.0507,
705
+ "mean_token_accuracy": 0.6881603226065636,
706
+ "step": 435
707
+ },
708
+ {
709
+ "epoch": 0.29495558907323616,
710
+ "grad_norm": 1.8252911567687988,
711
+ "learning_rate": 7.42231638418079e-05,
712
+ "loss": 17.155,
713
+ "mean_token_accuracy": 0.7065504610538482,
714
+ "step": 440
715
+ },
716
+ {
717
+ "epoch": 0.2983073571308865,
718
+ "grad_norm": 1.8424382209777832,
719
+ "learning_rate": 7.387005649717514e-05,
720
+ "loss": 17.3055,
721
+ "mean_token_accuracy": 0.6978819817304611,
722
+ "step": 445
723
+ },
724
+ {
725
+ "epoch": 0.30165912518853694,
726
+ "grad_norm": 1.7494243383407593,
727
+ "learning_rate": 7.351694915254238e-05,
728
+ "loss": 16.8365,
729
+ "mean_token_accuracy": 0.7099504336714745,
730
+ "step": 450
731
+ },
732
+ {
733
+ "epoch": 0.30501089324618735,
734
+ "grad_norm": 1.936540961265564,
735
+ "learning_rate": 7.316384180790961e-05,
736
+ "loss": 18.2753,
737
+ "mean_token_accuracy": 0.6913827233016491,
738
+ "step": 455
739
+ },
740
+ {
741
+ "epoch": 0.30836266130383777,
742
+ "grad_norm": 1.810272216796875,
743
+ "learning_rate": 7.281073446327685e-05,
744
+ "loss": 17.0536,
745
+ "mean_token_accuracy": 0.6986232809722424,
746
+ "step": 460
747
+ },
748
+ {
749
+ "epoch": 0.3117144293614882,
750
+ "grad_norm": 1.6832094192504883,
751
+ "learning_rate": 7.245762711864407e-05,
752
+ "loss": 17.2231,
753
+ "mean_token_accuracy": 0.702030860632658,
754
+ "step": 465
755
+ },
756
+ {
757
+ "epoch": 0.3150661974191386,
758
+ "grad_norm": 1.8872151374816895,
759
+ "learning_rate": 7.21045197740113e-05,
760
+ "loss": 17.5502,
761
+ "mean_token_accuracy": 0.6932449921965599,
762
+ "step": 470
763
+ },
764
+ {
765
+ "epoch": 0.318417965476789,
766
+ "grad_norm": 1.788021445274353,
767
+ "learning_rate": 7.175141242937854e-05,
768
+ "loss": 16.8596,
769
+ "mean_token_accuracy": 0.7096694305539131,
770
+ "step": 475
771
+ },
772
+ {
773
+ "epoch": 0.32176973353443944,
774
+ "grad_norm": 1.8025559186935425,
775
+ "learning_rate": 7.139830508474577e-05,
776
+ "loss": 16.662,
777
+ "mean_token_accuracy": 0.7063573338091373,
778
+ "step": 480
779
+ },
780
+ {
781
+ "epoch": 0.3251215015920898,
782
+ "grad_norm": 2.274674654006958,
783
+ "learning_rate": 7.1045197740113e-05,
784
+ "loss": 17.5965,
785
+ "mean_token_accuracy": 0.6934389650821686,
786
+ "step": 485
787
+ },
788
+ {
789
+ "epoch": 0.3284732696497402,
790
+ "grad_norm": 1.6426053047180176,
791
+ "learning_rate": 7.069209039548022e-05,
792
+ "loss": 17.0914,
793
+ "mean_token_accuracy": 0.7049042917788029,
794
+ "step": 490
795
+ },
796
+ {
797
+ "epoch": 0.33182503770739064,
798
+ "grad_norm": 1.6252586841583252,
799
+ "learning_rate": 7.033898305084746e-05,
800
+ "loss": 17.6078,
801
+ "mean_token_accuracy": 0.6924709647893905,
802
+ "step": 495
803
+ },
804
+ {
805
+ "epoch": 0.33517680576504105,
806
+ "grad_norm": 1.7185930013656616,
807
+ "learning_rate": 6.998587570621468e-05,
808
+ "loss": 17.314,
809
+ "mean_token_accuracy": 0.7039985358715057,
810
+ "step": 500
811
+ },
812
+ {
813
+ "epoch": 0.33852857382269147,
814
+ "grad_norm": 1.7891852855682373,
815
+ "learning_rate": 6.963276836158192e-05,
816
+ "loss": 17.2188,
817
+ "mean_token_accuracy": 0.6977060906589031,
818
+ "step": 505
819
+ },
820
+ {
821
+ "epoch": 0.3418803418803419,
822
+ "grad_norm": 1.9103929996490479,
823
+ "learning_rate": 6.927966101694916e-05,
824
+ "loss": 17.4467,
825
+ "mean_token_accuracy": 0.6982413403689861,
826
+ "step": 510
827
+ },
828
+ {
829
+ "epoch": 0.3452321099379923,
830
+ "grad_norm": 1.8996375799179077,
831
+ "learning_rate": 6.892655367231638e-05,
832
+ "loss": 16.9608,
833
+ "mean_token_accuracy": 0.7054095402359962,
834
+ "step": 515
835
+ },
836
+ {
837
+ "epoch": 0.3485838779956427,
838
+ "grad_norm": 2.0335419178009033,
839
+ "learning_rate": 6.857344632768362e-05,
840
+ "loss": 17.3361,
841
+ "mean_token_accuracy": 0.7016568422317505,
842
+ "step": 520
843
+ },
844
+ {
845
+ "epoch": 0.35193564605329314,
846
+ "grad_norm": 1.9008755683898926,
847
+ "learning_rate": 6.822033898305085e-05,
848
+ "loss": 16.9694,
849
+ "mean_token_accuracy": 0.7059390284121037,
850
+ "step": 525
851
+ },
852
+ {
853
+ "epoch": 0.3552874141109435,
854
+ "grad_norm": 1.8340988159179688,
855
+ "learning_rate": 6.786723163841809e-05,
856
+ "loss": 17.3528,
857
+ "mean_token_accuracy": 0.7033507622778415,
858
+ "step": 530
859
+ },
860
+ {
861
+ "epoch": 0.3586391821685939,
862
+ "grad_norm": 1.6903594732284546,
863
+ "learning_rate": 6.751412429378532e-05,
864
+ "loss": 17.3021,
865
+ "mean_token_accuracy": 0.7001501135528088,
866
+ "step": 535
867
+ },
868
+ {
869
+ "epoch": 0.36199095022624433,
870
+ "grad_norm": 1.8101950883865356,
871
+ "learning_rate": 6.716101694915255e-05,
872
+ "loss": 17.938,
873
+ "mean_token_accuracy": 0.6908830553293228,
874
+ "step": 540
875
+ },
876
+ {
877
+ "epoch": 0.36534271828389475,
878
+ "grad_norm": 1.6470075845718384,
879
+ "learning_rate": 6.680790960451978e-05,
880
+ "loss": 17.6612,
881
+ "mean_token_accuracy": 0.6923478744924069,
882
+ "step": 545
883
+ },
884
+ {
885
+ "epoch": 0.36869448634154517,
886
+ "grad_norm": 2.1860337257385254,
887
+ "learning_rate": 6.6454802259887e-05,
888
+ "loss": 17.5684,
889
+ "mean_token_accuracy": 0.6983748801052571,
890
+ "step": 550
891
+ },
892
+ {
893
+ "epoch": 0.3720462543991956,
894
+ "grad_norm": 1.717653512954712,
895
+ "learning_rate": 6.610169491525424e-05,
896
+ "loss": 17.1166,
897
+ "mean_token_accuracy": 0.7025655619800091,
898
+ "step": 555
899
+ },
900
+ {
901
+ "epoch": 0.375398022456846,
902
+ "grad_norm": 1.9525723457336426,
903
+ "learning_rate": 6.574858757062147e-05,
904
+ "loss": 17.2908,
905
+ "mean_token_accuracy": 0.6997996769845486,
906
+ "step": 560
907
+ },
908
+ {
909
+ "epoch": 0.3787497905144964,
910
+ "grad_norm": 1.6053602695465088,
911
+ "learning_rate": 6.53954802259887e-05,
912
+ "loss": 17.3894,
913
+ "mean_token_accuracy": 0.698741364479065,
914
+ "step": 565
915
+ },
916
+ {
917
+ "epoch": 0.3821015585721468,
918
+ "grad_norm": 1.7356934547424316,
919
+ "learning_rate": 6.504237288135594e-05,
920
+ "loss": 17.1546,
921
+ "mean_token_accuracy": 0.7013543620705605,
922
+ "step": 570
923
+ },
924
+ {
925
+ "epoch": 0.3854533266297972,
926
+ "grad_norm": 1.7188559770584106,
927
+ "learning_rate": 6.468926553672316e-05,
928
+ "loss": 17.7637,
929
+ "mean_token_accuracy": 0.6936320647597313,
930
+ "step": 575
931
+ },
932
+ {
933
+ "epoch": 0.3888050946874476,
934
+ "grad_norm": 1.8413478136062622,
935
+ "learning_rate": 6.43361581920904e-05,
936
+ "loss": 17.8498,
937
+ "mean_token_accuracy": 0.695782047510147,
938
+ "step": 580
939
+ },
940
+ {
941
+ "epoch": 0.39215686274509803,
942
+ "grad_norm": 1.5715190172195435,
943
+ "learning_rate": 6.398305084745762e-05,
944
+ "loss": 17.4304,
945
+ "mean_token_accuracy": 0.6989135831594467,
946
+ "step": 585
947
+ },
948
+ {
949
+ "epoch": 0.39550863080274845,
950
+ "grad_norm": 1.8729442358016968,
951
+ "learning_rate": 6.362994350282486e-05,
952
+ "loss": 16.9125,
953
+ "mean_token_accuracy": 0.708356649428606,
954
+ "step": 590
955
+ },
956
+ {
957
+ "epoch": 0.39886039886039887,
958
+ "grad_norm": 2.099592685699463,
959
+ "learning_rate": 6.327683615819209e-05,
960
+ "loss": 17.542,
961
+ "mean_token_accuracy": 0.6888726130127907,
962
+ "step": 595
963
+ },
964
+ {
965
+ "epoch": 0.4022121669180493,
966
+ "grad_norm": 1.6204314231872559,
967
+ "learning_rate": 6.292372881355933e-05,
968
+ "loss": 16.9305,
969
+ "mean_token_accuracy": 0.7038852870464325,
970
+ "step": 600
971
+ },
972
+ {
973
+ "epoch": 0.4055639349756997,
974
+ "grad_norm": 2.12034010887146,
975
+ "learning_rate": 6.257062146892656e-05,
976
+ "loss": 17.0389,
977
+ "mean_token_accuracy": 0.704576326906681,
978
+ "step": 605
979
+ },
980
+ {
981
+ "epoch": 0.4089157030333501,
982
+ "grad_norm": 1.6821502447128296,
983
+ "learning_rate": 6.221751412429379e-05,
984
+ "loss": 16.788,
985
+ "mean_token_accuracy": 0.7000284940004349,
986
+ "step": 610
987
+ },
988
+ {
989
+ "epoch": 0.4122674710910005,
990
+ "grad_norm": 1.8137435913085938,
991
+ "learning_rate": 6.186440677966102e-05,
992
+ "loss": 17.5926,
993
+ "mean_token_accuracy": 0.6961537927389145,
994
+ "step": 615
995
+ },
996
+ {
997
+ "epoch": 0.4156192391486509,
998
+ "grad_norm": 1.6652235984802246,
999
+ "learning_rate": 6.151129943502825e-05,
1000
+ "loss": 17.3539,
1001
+ "mean_token_accuracy": 0.7028377398848533,
1002
+ "step": 620
1003
+ },
1004
+ {
1005
+ "epoch": 0.4189710072063013,
1006
+ "grad_norm": 1.766480803489685,
1007
+ "learning_rate": 6.115819209039548e-05,
1008
+ "loss": 17.529,
1009
+ "mean_token_accuracy": 0.6905739739537239,
1010
+ "step": 625
1011
+ },
1012
+ {
1013
+ "epoch": 0.42232277526395173,
1014
+ "grad_norm": 1.6319854259490967,
1015
+ "learning_rate": 6.080508474576272e-05,
1016
+ "loss": 16.9847,
1017
+ "mean_token_accuracy": 0.7060947254300117,
1018
+ "step": 630
1019
+ },
1020
+ {
1021
+ "epoch": 0.42567454332160215,
1022
+ "grad_norm": 2.1006696224212646,
1023
+ "learning_rate": 6.045197740112994e-05,
1024
+ "loss": 16.9317,
1025
+ "mean_token_accuracy": 0.7015593230724335,
1026
+ "step": 635
1027
+ },
1028
+ {
1029
+ "epoch": 0.42902631137925257,
1030
+ "grad_norm": 1.7353427410125732,
1031
+ "learning_rate": 6.009887005649718e-05,
1032
+ "loss": 17.4744,
1033
+ "mean_token_accuracy": 0.7001501567661762,
1034
+ "step": 640
1035
+ },
1036
+ {
1037
+ "epoch": 0.432378079436903,
1038
+ "grad_norm": 1.9449700117111206,
1039
+ "learning_rate": 5.974576271186441e-05,
1040
+ "loss": 16.8705,
1041
+ "mean_token_accuracy": 0.7026407413184643,
1042
+ "step": 645
1043
+ },
1044
+ {
1045
+ "epoch": 0.4357298474945534,
1046
+ "grad_norm": 1.6030067205429077,
1047
+ "learning_rate": 5.9392655367231644e-05,
1048
+ "loss": 16.8924,
1049
+ "mean_token_accuracy": 0.702277285605669,
1050
+ "step": 650
1051
+ },
1052
+ {
1053
+ "epoch": 0.43908161555220376,
1054
+ "grad_norm": 1.5722424983978271,
1055
+ "learning_rate": 5.903954802259888e-05,
1056
+ "loss": 17.364,
1057
+ "mean_token_accuracy": 0.6959278948605061,
1058
+ "step": 655
1059
+ },
1060
+ {
1061
+ "epoch": 0.4424333836098542,
1062
+ "grad_norm": 1.8168216943740845,
1063
+ "learning_rate": 5.86864406779661e-05,
1064
+ "loss": 16.704,
1065
+ "mean_token_accuracy": 0.7045813865959645,
1066
+ "step": 660
1067
+ },
1068
+ {
1069
+ "epoch": 0.4457851516675046,
1070
+ "grad_norm": 1.905402660369873,
1071
+ "learning_rate": 5.833333333333334e-05,
1072
+ "loss": 16.8896,
1073
+ "mean_token_accuracy": 0.7026248089969158,
1074
+ "step": 665
1075
+ },
1076
+ {
1077
+ "epoch": 0.449136919725155,
1078
+ "grad_norm": 1.7437454462051392,
1079
+ "learning_rate": 5.798022598870056e-05,
1080
+ "loss": 17.0496,
1081
+ "mean_token_accuracy": 0.702862861007452,
1082
+ "step": 670
1083
+ },
1084
+ {
1085
+ "epoch": 0.45248868778280543,
1086
+ "grad_norm": 1.7496871948242188,
1087
+ "learning_rate": 5.76271186440678e-05,
1088
+ "loss": 16.7024,
1089
+ "mean_token_accuracy": 0.7073140636086463,
1090
+ "step": 675
1091
+ },
1092
+ {
1093
+ "epoch": 0.45584045584045585,
1094
+ "grad_norm": 1.6521803140640259,
1095
+ "learning_rate": 5.727401129943503e-05,
1096
+ "loss": 17.4437,
1097
+ "mean_token_accuracy": 0.6910906590521335,
1098
+ "step": 680
1099
+ },
1100
+ {
1101
+ "epoch": 0.45919222389810627,
1102
+ "grad_norm": 1.7904677391052246,
1103
+ "learning_rate": 5.6920903954802264e-05,
1104
+ "loss": 17.4803,
1105
+ "mean_token_accuracy": 0.6987466789782047,
1106
+ "step": 685
1107
+ },
1108
+ {
1109
+ "epoch": 0.4625439919557567,
1110
+ "grad_norm": 2.4545388221740723,
1111
+ "learning_rate": 5.65677966101695e-05,
1112
+ "loss": 17.2987,
1113
+ "mean_token_accuracy": 0.699196208268404,
1114
+ "step": 690
1115
+ },
1116
+ {
1117
+ "epoch": 0.46589576001340705,
1118
+ "grad_norm": 1.6428866386413574,
1119
+ "learning_rate": 5.6214689265536723e-05,
1120
+ "loss": 16.7636,
1121
+ "mean_token_accuracy": 0.7029999569058418,
1122
+ "step": 695
1123
+ },
1124
+ {
1125
+ "epoch": 0.46924752807105746,
1126
+ "grad_norm": 1.9685977697372437,
1127
+ "learning_rate": 5.586158192090396e-05,
1128
+ "loss": 17.3887,
1129
+ "mean_token_accuracy": 0.6938736639916897,
1130
+ "step": 700
1131
+ },
1132
+ {
1133
+ "epoch": 0.4725992961287079,
1134
+ "grad_norm": 1.5567928552627563,
1135
+ "learning_rate": 5.550847457627118e-05,
1136
+ "loss": 17.1879,
1137
+ "mean_token_accuracy": 0.7024729043245316,
1138
+ "step": 705
1139
+ },
1140
+ {
1141
+ "epoch": 0.4759510641863583,
1142
+ "grad_norm": 1.6846567392349243,
1143
+ "learning_rate": 5.515536723163842e-05,
1144
+ "loss": 16.8679,
1145
+ "mean_token_accuracy": 0.7025640495121479,
1146
+ "step": 710
1147
+ },
1148
+ {
1149
+ "epoch": 0.4793028322440087,
1150
+ "grad_norm": 1.6596832275390625,
1151
+ "learning_rate": 5.480225988700565e-05,
1152
+ "loss": 16.7137,
1153
+ "mean_token_accuracy": 0.7031160019338131,
1154
+ "step": 715
1155
+ },
1156
+ {
1157
+ "epoch": 0.48265460030165913,
1158
+ "grad_norm": 2.04453444480896,
1159
+ "learning_rate": 5.4449152542372885e-05,
1160
+ "loss": 17.0646,
1161
+ "mean_token_accuracy": 0.7018779084086418,
1162
+ "step": 720
1163
+ },
1164
+ {
1165
+ "epoch": 0.48600636835930955,
1166
+ "grad_norm": 1.7244528532028198,
1167
+ "learning_rate": 5.409604519774012e-05,
1168
+ "loss": 17.1897,
1169
+ "mean_token_accuracy": 0.6981223806738853,
1170
+ "step": 725
1171
+ },
1172
+ {
1173
+ "epoch": 0.48935813641695997,
1174
+ "grad_norm": 1.6929802894592285,
1175
+ "learning_rate": 5.3742937853107344e-05,
1176
+ "loss": 17.2678,
1177
+ "mean_token_accuracy": 0.6996262572705746,
1178
+ "step": 730
1179
+ },
1180
+ {
1181
+ "epoch": 0.4927099044746104,
1182
+ "grad_norm": 1.7945303916931152,
1183
+ "learning_rate": 5.338983050847458e-05,
1184
+ "loss": 17.1465,
1185
+ "mean_token_accuracy": 0.7002299666404724,
1186
+ "step": 735
1187
+ },
1188
+ {
1189
+ "epoch": 0.49606167253226074,
1190
+ "grad_norm": 1.5936013460159302,
1191
+ "learning_rate": 5.30367231638418e-05,
1192
+ "loss": 17.0265,
1193
+ "mean_token_accuracy": 0.6998031720519066,
1194
+ "step": 740
1195
+ },
1196
+ {
1197
+ "epoch": 0.49941344058991116,
1198
+ "grad_norm": 1.553004264831543,
1199
+ "learning_rate": 5.268361581920904e-05,
1200
+ "loss": 16.7301,
1201
+ "mean_token_accuracy": 0.7022854961454869,
1202
+ "step": 745
1203
+ },
1204
+ {
1205
+ "epoch": 0.5027652086475616,
1206
+ "grad_norm": 1.7667690515518188,
1207
+ "learning_rate": 5.2330508474576275e-05,
1208
+ "loss": 16.8576,
1209
+ "mean_token_accuracy": 0.7085686258971691,
1210
+ "step": 750
1211
+ },
1212
+ {
1213
+ "epoch": 0.5027652086475616,
1214
+ "eval_loss": 1.0600364208221436,
1215
+ "eval_mean_token_accuracy": 0.7049777010093035,
1216
+ "eval_runtime": 1736.5707,
1217
+ "eval_samples_per_second": 1.392,
1218
+ "eval_steps_per_second": 0.174,
1219
+ "step": 750
1220
+ },
1221
+ {
1222
+ "epoch": 0.506116976705212,
1223
+ "grad_norm": 1.4901829957962036,
1224
+ "learning_rate": 5.1977401129943505e-05,
1225
+ "loss": 17.0004,
1226
+ "mean_token_accuracy": 0.6990960523486137,
1227
+ "step": 755
1228
+ },
1229
+ {
1230
+ "epoch": 0.5094687447628624,
1231
+ "grad_norm": 1.8451662063598633,
1232
+ "learning_rate": 5.162429378531074e-05,
1233
+ "loss": 17.2012,
1234
+ "mean_token_accuracy": 0.7007680244743824,
1235
+ "step": 760
1236
+ },
1237
+ {
1238
+ "epoch": 0.5128205128205128,
1239
+ "grad_norm": 1.6952011585235596,
1240
+ "learning_rate": 5.1271186440677964e-05,
1241
+ "loss": 17.612,
1242
+ "mean_token_accuracy": 0.6927438467741013,
1243
+ "step": 765
1244
+ },
1245
+ {
1246
+ "epoch": 0.5161722808781632,
1247
+ "grad_norm": 1.7307817935943604,
1248
+ "learning_rate": 5.09180790960452e-05,
1249
+ "loss": 16.8776,
1250
+ "mean_token_accuracy": 0.706513649225235,
1251
+ "step": 770
1252
+ },
1253
+ {
1254
+ "epoch": 0.5195240489358136,
1255
+ "grad_norm": 1.6692585945129395,
1256
+ "learning_rate": 5.056497175141243e-05,
1257
+ "loss": 17.0364,
1258
+ "mean_token_accuracy": 0.704279126226902,
1259
+ "step": 775
1260
+ },
1261
+ {
1262
+ "epoch": 0.5228758169934641,
1263
+ "grad_norm": 1.6963402032852173,
1264
+ "learning_rate": 5.0211864406779666e-05,
1265
+ "loss": 16.8957,
1266
+ "mean_token_accuracy": 0.7085353158414364,
1267
+ "step": 780
1268
+ },
1269
+ {
1270
+ "epoch": 0.5262275850511144,
1271
+ "grad_norm": 1.678458571434021,
1272
+ "learning_rate": 4.9858757062146896e-05,
1273
+ "loss": 17.7932,
1274
+ "mean_token_accuracy": 0.6964584030210972,
1275
+ "step": 785
1276
+ },
1277
+ {
1278
+ "epoch": 0.5295793531087649,
1279
+ "grad_norm": 1.7449827194213867,
1280
+ "learning_rate": 4.9505649717514125e-05,
1281
+ "loss": 16.8765,
1282
+ "mean_token_accuracy": 0.7036922007799149,
1283
+ "step": 790
1284
+ },
1285
+ {
1286
+ "epoch": 0.5329311211664153,
1287
+ "grad_norm": 1.7107524871826172,
1288
+ "learning_rate": 4.915254237288136e-05,
1289
+ "loss": 17.243,
1290
+ "mean_token_accuracy": 0.6997682720422744,
1291
+ "step": 795
1292
+ },
1293
+ {
1294
+ "epoch": 0.5362828892240656,
1295
+ "grad_norm": 1.6416223049163818,
1296
+ "learning_rate": 4.879943502824859e-05,
1297
+ "loss": 16.7253,
1298
+ "mean_token_accuracy": 0.7050332672894001,
1299
+ "step": 800
1300
+ },
1301
+ {
1302
+ "epoch": 0.5396346572817161,
1303
+ "grad_norm": 1.867213249206543,
1304
+ "learning_rate": 4.844632768361582e-05,
1305
+ "loss": 16.8566,
1306
+ "mean_token_accuracy": 0.7032786093652248,
1307
+ "step": 805
1308
+ },
1309
+ {
1310
+ "epoch": 0.5429864253393665,
1311
+ "grad_norm": 1.6539360284805298,
1312
+ "learning_rate": 4.809322033898305e-05,
1313
+ "loss": 16.6993,
1314
+ "mean_token_accuracy": 0.7117977932095527,
1315
+ "step": 810
1316
+ },
1317
+ {
1318
+ "epoch": 0.546338193397017,
1319
+ "grad_norm": 1.752715826034546,
1320
+ "learning_rate": 4.7740112994350286e-05,
1321
+ "loss": 17.5809,
1322
+ "mean_token_accuracy": 0.6992670528590679,
1323
+ "step": 815
1324
+ },
1325
+ {
1326
+ "epoch": 0.5496899614546673,
1327
+ "grad_norm": 1.806174397468567,
1328
+ "learning_rate": 4.7387005649717516e-05,
1329
+ "loss": 17.1588,
1330
+ "mean_token_accuracy": 0.6960965767502785,
1331
+ "step": 820
1332
+ },
1333
+ {
1334
+ "epoch": 0.5530417295123178,
1335
+ "grad_norm": 1.719764232635498,
1336
+ "learning_rate": 4.703389830508475e-05,
1337
+ "loss": 16.8685,
1338
+ "mean_token_accuracy": 0.7025568410754204,
1339
+ "step": 825
1340
+ },
1341
+ {
1342
+ "epoch": 0.5563934975699681,
1343
+ "grad_norm": 1.7800629138946533,
1344
+ "learning_rate": 4.668079096045198e-05,
1345
+ "loss": 16.8872,
1346
+ "mean_token_accuracy": 0.6994628652930259,
1347
+ "step": 830
1348
+ },
1349
+ {
1350
+ "epoch": 0.5597452656276186,
1351
+ "grad_norm": 1.7011103630065918,
1352
+ "learning_rate": 4.632768361581921e-05,
1353
+ "loss": 17.2342,
1354
+ "mean_token_accuracy": 0.7006913289427757,
1355
+ "step": 835
1356
+ },
1357
+ {
1358
+ "epoch": 0.563097033685269,
1359
+ "grad_norm": 1.6887695789337158,
1360
+ "learning_rate": 4.597457627118644e-05,
1361
+ "loss": 16.7385,
1362
+ "mean_token_accuracy": 0.7045929700136184,
1363
+ "step": 840
1364
+ },
1365
+ {
1366
+ "epoch": 0.5664488017429193,
1367
+ "grad_norm": 1.9496142864227295,
1368
+ "learning_rate": 4.562146892655367e-05,
1369
+ "loss": 16.8387,
1370
+ "mean_token_accuracy": 0.7083131410181522,
1371
+ "step": 845
1372
+ },
1373
+ {
1374
+ "epoch": 0.5698005698005698,
1375
+ "grad_norm": 1.7757388353347778,
1376
+ "learning_rate": 4.5268361581920906e-05,
1377
+ "loss": 17.3856,
1378
+ "mean_token_accuracy": 0.6994826771318913,
1379
+ "step": 850
1380
+ },
1381
+ {
1382
+ "epoch": 0.5731523378582202,
1383
+ "grad_norm": 1.7115302085876465,
1384
+ "learning_rate": 4.491525423728814e-05,
1385
+ "loss": 16.5993,
1386
+ "mean_token_accuracy": 0.7093915119767189,
1387
+ "step": 855
1388
+ },
1389
+ {
1390
+ "epoch": 0.5765041059158706,
1391
+ "grad_norm": 1.7968231439590454,
1392
+ "learning_rate": 4.456214689265537e-05,
1393
+ "loss": 16.8983,
1394
+ "mean_token_accuracy": 0.7087731070816516,
1395
+ "step": 860
1396
+ },
1397
+ {
1398
+ "epoch": 0.579855873973521,
1399
+ "grad_norm": 1.6066899299621582,
1400
+ "learning_rate": 4.42090395480226e-05,
1401
+ "loss": 16.7126,
1402
+ "mean_token_accuracy": 0.7053335346281528,
1403
+ "step": 865
1404
+ },
1405
+ {
1406
+ "epoch": 0.5832076420311715,
1407
+ "grad_norm": 1.6380205154418945,
1408
+ "learning_rate": 4.385593220338983e-05,
1409
+ "loss": 17.0037,
1410
+ "mean_token_accuracy": 0.7038719221949578,
1411
+ "step": 870
1412
+ },
1413
+ {
1414
+ "epoch": 0.5865594100888218,
1415
+ "grad_norm": 1.8956695795059204,
1416
+ "learning_rate": 4.350282485875706e-05,
1417
+ "loss": 16.9679,
1418
+ "mean_token_accuracy": 0.6983371920883655,
1419
+ "step": 875
1420
+ },
1421
+ {
1422
+ "epoch": 0.5899111781464723,
1423
+ "grad_norm": 1.625135064125061,
1424
+ "learning_rate": 4.314971751412429e-05,
1425
+ "loss": 17.0642,
1426
+ "mean_token_accuracy": 0.7067640118300915,
1427
+ "step": 880
1428
+ },
1429
+ {
1430
+ "epoch": 0.5932629462041227,
1431
+ "grad_norm": 1.6344581842422485,
1432
+ "learning_rate": 4.279661016949153e-05,
1433
+ "loss": 16.3079,
1434
+ "mean_token_accuracy": 0.7225491903722286,
1435
+ "step": 885
1436
+ },
1437
+ {
1438
+ "epoch": 0.596614714261773,
1439
+ "grad_norm": 1.7680976390838623,
1440
+ "learning_rate": 4.244350282485876e-05,
1441
+ "loss": 16.7187,
1442
+ "mean_token_accuracy": 0.7041032016277313,
1443
+ "step": 890
1444
+ },
1445
+ {
1446
+ "epoch": 0.5999664823194235,
1447
+ "grad_norm": 1.8056613206863403,
1448
+ "learning_rate": 4.209039548022599e-05,
1449
+ "loss": 17.3536,
1450
+ "mean_token_accuracy": 0.6975419208407402,
1451
+ "step": 895
1452
+ },
1453
+ {
1454
+ "epoch": 0.6033182503770739,
1455
+ "grad_norm": 1.8398966789245605,
1456
+ "learning_rate": 4.173728813559322e-05,
1457
+ "loss": 16.6245,
1458
+ "mean_token_accuracy": 0.7088275127112865,
1459
+ "step": 900
1460
+ },
1461
+ {
1462
+ "epoch": 0.6066700184347243,
1463
+ "grad_norm": 1.8332566022872925,
1464
+ "learning_rate": 4.138418079096045e-05,
1465
+ "loss": 17.0128,
1466
+ "mean_token_accuracy": 0.7018843114376068,
1467
+ "step": 905
1468
+ },
1469
+ {
1470
+ "epoch": 0.6100217864923747,
1471
+ "grad_norm": 1.6582337617874146,
1472
+ "learning_rate": 4.103107344632768e-05,
1473
+ "loss": 16.8948,
1474
+ "mean_token_accuracy": 0.7051651798188686,
1475
+ "step": 910
1476
+ },
1477
+ {
1478
+ "epoch": 0.6133735545500252,
1479
+ "grad_norm": 1.7373839616775513,
1480
+ "learning_rate": 4.067796610169492e-05,
1481
+ "loss": 16.9138,
1482
+ "mean_token_accuracy": 0.7022108249366283,
1483
+ "step": 915
1484
+ },
1485
+ {
1486
+ "epoch": 0.6167253226076755,
1487
+ "grad_norm": 1.6373577117919922,
1488
+ "learning_rate": 4.0324858757062154e-05,
1489
+ "loss": 17.0573,
1490
+ "mean_token_accuracy": 0.7042267486453057,
1491
+ "step": 920
1492
+ },
1493
+ {
1494
+ "epoch": 0.620077090665326,
1495
+ "grad_norm": 1.581024408340454,
1496
+ "learning_rate": 3.997175141242938e-05,
1497
+ "loss": 16.6234,
1498
+ "mean_token_accuracy": 0.7054463028907776,
1499
+ "step": 925
1500
+ },
1501
+ {
1502
+ "epoch": 0.6234288587229764,
1503
+ "grad_norm": 1.6900616884231567,
1504
+ "learning_rate": 3.961864406779661e-05,
1505
+ "loss": 17.0468,
1506
+ "mean_token_accuracy": 0.7014504976570606,
1507
+ "step": 930
1508
+ },
1509
+ {
1510
+ "epoch": 0.6267806267806267,
1511
+ "grad_norm": 1.6560430526733398,
1512
+ "learning_rate": 3.926553672316384e-05,
1513
+ "loss": 16.909,
1514
+ "mean_token_accuracy": 0.7064756542444229,
1515
+ "step": 935
1516
+ },
1517
+ {
1518
+ "epoch": 0.6301323948382772,
1519
+ "grad_norm": 1.8687000274658203,
1520
+ "learning_rate": 3.891242937853107e-05,
1521
+ "loss": 17.0047,
1522
+ "mean_token_accuracy": 0.7055176287889481,
1523
+ "step": 940
1524
+ },
1525
+ {
1526
+ "epoch": 0.6334841628959276,
1527
+ "grad_norm": 1.777716040611267,
1528
+ "learning_rate": 3.855932203389831e-05,
1529
+ "loss": 16.556,
1530
+ "mean_token_accuracy": 0.7047871246933937,
1531
+ "step": 945
1532
+ },
1533
+ {
1534
+ "epoch": 0.636835930953578,
1535
+ "grad_norm": 1.6830016374588013,
1536
+ "learning_rate": 3.820621468926554e-05,
1537
+ "loss": 16.5832,
1538
+ "mean_token_accuracy": 0.7049862682819367,
1539
+ "step": 950
1540
+ },
1541
+ {
1542
+ "epoch": 0.6401876990112284,
1543
+ "grad_norm": 1.5959638357162476,
1544
+ "learning_rate": 3.7853107344632774e-05,
1545
+ "loss": 16.8336,
1546
+ "mean_token_accuracy": 0.7072055459022522,
1547
+ "step": 955
1548
+ },
1549
+ {
1550
+ "epoch": 0.6435394670688789,
1551
+ "grad_norm": 1.82794189453125,
1552
+ "learning_rate": 3.7500000000000003e-05,
1553
+ "loss": 16.6644,
1554
+ "mean_token_accuracy": 0.7058505766093731,
1555
+ "step": 960
1556
+ },
1557
+ {
1558
+ "epoch": 0.6468912351265292,
1559
+ "grad_norm": 1.6554478406906128,
1560
+ "learning_rate": 3.714689265536723e-05,
1561
+ "loss": 16.2796,
1562
+ "mean_token_accuracy": 0.7101977132260799,
1563
+ "step": 965
1564
+ },
1565
+ {
1566
+ "epoch": 0.6502430031841796,
1567
+ "grad_norm": 1.8698370456695557,
1568
+ "learning_rate": 3.679378531073446e-05,
1569
+ "loss": 16.1934,
1570
+ "mean_token_accuracy": 0.7142874717712402,
1571
+ "step": 970
1572
+ },
1573
+ {
1574
+ "epoch": 0.6535947712418301,
1575
+ "grad_norm": 1.8040566444396973,
1576
+ "learning_rate": 3.644067796610169e-05,
1577
+ "loss": 16.5345,
1578
+ "mean_token_accuracy": 0.7125143676996231,
1579
+ "step": 975
1580
+ },
1581
+ {
1582
+ "epoch": 0.6569465392994804,
1583
+ "grad_norm": 1.6644558906555176,
1584
+ "learning_rate": 3.608757062146893e-05,
1585
+ "loss": 16.508,
1586
+ "mean_token_accuracy": 0.7078846462070942,
1587
+ "step": 980
1588
+ },
1589
+ {
1590
+ "epoch": 0.6602983073571309,
1591
+ "grad_norm": 1.7228506803512573,
1592
+ "learning_rate": 3.573446327683616e-05,
1593
+ "loss": 16.8474,
1594
+ "mean_token_accuracy": 0.7084795109927654,
1595
+ "step": 985
1596
+ },
1597
+ {
1598
+ "epoch": 0.6636500754147813,
1599
+ "grad_norm": 1.486241102218628,
1600
+ "learning_rate": 3.5381355932203394e-05,
1601
+ "loss": 17.1453,
1602
+ "mean_token_accuracy": 0.6975291892886162,
1603
+ "step": 990
1604
+ },
1605
+ {
1606
+ "epoch": 0.6670018434724317,
1607
+ "grad_norm": 1.7130765914916992,
1608
+ "learning_rate": 3.5028248587570624e-05,
1609
+ "loss": 16.458,
1610
+ "mean_token_accuracy": 0.7106956362724304,
1611
+ "step": 995
1612
+ },
1613
+ {
1614
+ "epoch": 0.6703536115300821,
1615
+ "grad_norm": 1.863926649093628,
1616
+ "learning_rate": 3.467514124293785e-05,
1617
+ "loss": 17.3095,
1618
+ "mean_token_accuracy": 0.6962033234536648,
1619
+ "step": 1000
1620
+ },
1621
+ {
1622
+ "epoch": 0.6737053795877326,
1623
+ "grad_norm": 1.6535072326660156,
1624
+ "learning_rate": 3.432203389830508e-05,
1625
+ "loss": 16.6846,
1626
+ "mean_token_accuracy": 0.7084034703671932,
1627
+ "step": 1005
1628
+ },
1629
+ {
1630
+ "epoch": 0.6770571476453829,
1631
+ "grad_norm": 1.7278594970703125,
1632
+ "learning_rate": 3.396892655367232e-05,
1633
+ "loss": 16.9805,
1634
+ "mean_token_accuracy": 0.7026786416769027,
1635
+ "step": 1010
1636
+ },
1637
+ {
1638
+ "epoch": 0.6804089157030333,
1639
+ "grad_norm": 1.9055004119873047,
1640
+ "learning_rate": 3.361581920903955e-05,
1641
+ "loss": 17.2562,
1642
+ "mean_token_accuracy": 0.6977267302572727,
1643
+ "step": 1015
1644
+ },
1645
+ {
1646
+ "epoch": 0.6837606837606838,
1647
+ "grad_norm": 1.6398614645004272,
1648
+ "learning_rate": 3.326271186440678e-05,
1649
+ "loss": 17.3378,
1650
+ "mean_token_accuracy": 0.6958214737474918,
1651
+ "step": 1020
1652
+ },
1653
+ {
1654
+ "epoch": 0.6871124518183341,
1655
+ "grad_norm": 1.926950454711914,
1656
+ "learning_rate": 3.2909604519774014e-05,
1657
+ "loss": 16.6536,
1658
+ "mean_token_accuracy": 0.7083842910826206,
1659
+ "step": 1025
1660
+ },
1661
+ {
1662
+ "epoch": 0.6904642198759846,
1663
+ "grad_norm": 1.8061659336090088,
1664
+ "learning_rate": 3.2556497175141244e-05,
1665
+ "loss": 16.643,
1666
+ "mean_token_accuracy": 0.7093963578343392,
1667
+ "step": 1030
1668
+ },
1669
+ {
1670
+ "epoch": 0.693815987933635,
1671
+ "grad_norm": 1.6816084384918213,
1672
+ "learning_rate": 3.2203389830508473e-05,
1673
+ "loss": 16.9696,
1674
+ "mean_token_accuracy": 0.7000316813588142,
1675
+ "step": 1035
1676
+ },
1677
+ {
1678
+ "epoch": 0.6971677559912854,
1679
+ "grad_norm": 1.630842685699463,
1680
+ "learning_rate": 3.185028248587571e-05,
1681
+ "loss": 16.587,
1682
+ "mean_token_accuracy": 0.7107978977262974,
1683
+ "step": 1040
1684
+ },
1685
+ {
1686
+ "epoch": 0.7005195240489358,
1687
+ "grad_norm": 1.755123257637024,
1688
+ "learning_rate": 3.149717514124294e-05,
1689
+ "loss": 17.0736,
1690
+ "mean_token_accuracy": 0.7017260067164898,
1691
+ "step": 1045
1692
+ },
1693
+ {
1694
+ "epoch": 0.7038712921065863,
1695
+ "grad_norm": 1.4850029945373535,
1696
+ "learning_rate": 3.114406779661017e-05,
1697
+ "loss": 16.3165,
1698
+ "mean_token_accuracy": 0.7119720429182053,
1699
+ "step": 1050
1700
+ },
1701
+ {
1702
+ "epoch": 0.7072230601642366,
1703
+ "grad_norm": 1.916961908340454,
1704
+ "learning_rate": 3.0790960451977405e-05,
1705
+ "loss": 17.0237,
1706
+ "mean_token_accuracy": 0.6976533338427544,
1707
+ "step": 1055
1708
+ },
1709
+ {
1710
+ "epoch": 0.710574828221887,
1711
+ "grad_norm": 1.5003294944763184,
1712
+ "learning_rate": 3.043785310734463e-05,
1713
+ "loss": 16.8504,
1714
+ "mean_token_accuracy": 0.7056308597326278,
1715
+ "step": 1060
1716
+ },
1717
+ {
1718
+ "epoch": 0.7139265962795375,
1719
+ "grad_norm": 1.9166836738586426,
1720
+ "learning_rate": 3.0084745762711864e-05,
1721
+ "loss": 16.8231,
1722
+ "mean_token_accuracy": 0.7023352533578873,
1723
+ "step": 1065
1724
+ },
1725
+ {
1726
+ "epoch": 0.7172783643371878,
1727
+ "grad_norm": 1.7789411544799805,
1728
+ "learning_rate": 2.97316384180791e-05,
1729
+ "loss": 17.3132,
1730
+ "mean_token_accuracy": 0.6994914725422859,
1731
+ "step": 1070
1732
+ },
1733
+ {
1734
+ "epoch": 0.7206301323948383,
1735
+ "grad_norm": 1.7289875745773315,
1736
+ "learning_rate": 2.937853107344633e-05,
1737
+ "loss": 17.3902,
1738
+ "mean_token_accuracy": 0.69447166249156,
1739
+ "step": 1075
1740
+ },
1741
+ {
1742
+ "epoch": 0.7239819004524887,
1743
+ "grad_norm": 1.4835467338562012,
1744
+ "learning_rate": 2.902542372881356e-05,
1745
+ "loss": 16.751,
1746
+ "mean_token_accuracy": 0.7052346661686897,
1747
+ "step": 1080
1748
+ },
1749
+ {
1750
+ "epoch": 0.7273336685101391,
1751
+ "grad_norm": 1.5802119970321655,
1752
+ "learning_rate": 2.8672316384180792e-05,
1753
+ "loss": 16.6574,
1754
+ "mean_token_accuracy": 0.7059398606419564,
1755
+ "step": 1085
1756
+ },
1757
+ {
1758
+ "epoch": 0.7306854365677895,
1759
+ "grad_norm": 1.8420851230621338,
1760
+ "learning_rate": 2.8319209039548022e-05,
1761
+ "loss": 16.9315,
1762
+ "mean_token_accuracy": 0.7063411138951778,
1763
+ "step": 1090
1764
+ },
1765
+ {
1766
+ "epoch": 0.7340372046254399,
1767
+ "grad_norm": 1.7593777179718018,
1768
+ "learning_rate": 2.7966101694915255e-05,
1769
+ "loss": 16.8653,
1770
+ "mean_token_accuracy": 0.7089171193540096,
1771
+ "step": 1095
1772
+ },
1773
+ {
1774
+ "epoch": 0.7373889726830903,
1775
+ "grad_norm": 1.681443452835083,
1776
+ "learning_rate": 2.7612994350282488e-05,
1777
+ "loss": 16.9878,
1778
+ "mean_token_accuracy": 0.7057393230497837,
1779
+ "step": 1100
1780
+ },
1781
+ {
1782
+ "epoch": 0.7407407407407407,
1783
+ "grad_norm": 1.6064281463623047,
1784
+ "learning_rate": 2.725988700564972e-05,
1785
+ "loss": 16.6153,
1786
+ "mean_token_accuracy": 0.7038764618337154,
1787
+ "step": 1105
1788
+ },
1789
+ {
1790
+ "epoch": 0.7440925087983912,
1791
+ "grad_norm": 1.5632483959197998,
1792
+ "learning_rate": 2.690677966101695e-05,
1793
+ "loss": 16.0927,
1794
+ "mean_token_accuracy": 0.7171440742909908,
1795
+ "step": 1110
1796
+ },
1797
+ {
1798
+ "epoch": 0.7474442768560415,
1799
+ "grad_norm": 1.8588156700134277,
1800
+ "learning_rate": 2.6553672316384183e-05,
1801
+ "loss": 16.5765,
1802
+ "mean_token_accuracy": 0.7098327249288559,
1803
+ "step": 1115
1804
+ },
1805
+ {
1806
+ "epoch": 0.750796044913692,
1807
+ "grad_norm": 1.5576221942901611,
1808
+ "learning_rate": 2.6200564971751413e-05,
1809
+ "loss": 16.6568,
1810
+ "mean_token_accuracy": 0.7029327027499676,
1811
+ "step": 1120
1812
+ },
1813
+ {
1814
+ "epoch": 0.7541478129713424,
1815
+ "grad_norm": 1.645244836807251,
1816
+ "learning_rate": 2.5847457627118642e-05,
1817
+ "loss": 16.7294,
1818
+ "mean_token_accuracy": 0.7060277953743934,
1819
+ "step": 1125
1820
+ },
1821
+ {
1822
+ "epoch": 0.7574995810289928,
1823
+ "grad_norm": 1.4038984775543213,
1824
+ "learning_rate": 2.549435028248588e-05,
1825
+ "loss": 16.5925,
1826
+ "mean_token_accuracy": 0.7068064086139202,
1827
+ "step": 1130
1828
+ },
1829
+ {
1830
+ "epoch": 0.7608513490866432,
1831
+ "grad_norm": 1.7987641096115112,
1832
+ "learning_rate": 2.514124293785311e-05,
1833
+ "loss": 16.6834,
1834
+ "mean_token_accuracy": 0.7070130936801433,
1835
+ "step": 1135
1836
+ },
1837
+ {
1838
+ "epoch": 0.7642031171442936,
1839
+ "grad_norm": 1.5423444509506226,
1840
+ "learning_rate": 2.478813559322034e-05,
1841
+ "loss": 16.4551,
1842
+ "mean_token_accuracy": 0.7121224895119667,
1843
+ "step": 1140
1844
+ },
1845
+ {
1846
+ "epoch": 0.767554885201944,
1847
+ "grad_norm": 1.7546942234039307,
1848
+ "learning_rate": 2.443502824858757e-05,
1849
+ "loss": 16.9741,
1850
+ "mean_token_accuracy": 0.7010989025235176,
1851
+ "step": 1145
1852
+ },
1853
+ {
1854
+ "epoch": 0.7709066532595944,
1855
+ "grad_norm": 1.8481935262680054,
1856
+ "learning_rate": 2.4081920903954803e-05,
1857
+ "loss": 16.6323,
1858
+ "mean_token_accuracy": 0.7058765202760696,
1859
+ "step": 1150
1860
+ },
1861
+ {
1862
+ "epoch": 0.7742584213172449,
1863
+ "grad_norm": 1.6855909824371338,
1864
+ "learning_rate": 2.3728813559322036e-05,
1865
+ "loss": 16.6844,
1866
+ "mean_token_accuracy": 0.7119428858160972,
1867
+ "step": 1155
1868
+ },
1869
+ {
1870
+ "epoch": 0.7776101893748952,
1871
+ "grad_norm": 1.9828130006790161,
1872
+ "learning_rate": 2.3375706214689266e-05,
1873
+ "loss": 16.866,
1874
+ "mean_token_accuracy": 0.7036800056695938,
1875
+ "step": 1160
1876
+ },
1877
+ {
1878
+ "epoch": 0.7809619574325457,
1879
+ "grad_norm": 1.5005120038986206,
1880
+ "learning_rate": 2.30225988700565e-05,
1881
+ "loss": 16.3539,
1882
+ "mean_token_accuracy": 0.711839384585619,
1883
+ "step": 1165
1884
+ },
1885
+ {
1886
+ "epoch": 0.7843137254901961,
1887
+ "grad_norm": 2.262735366821289,
1888
+ "learning_rate": 2.266949152542373e-05,
1889
+ "loss": 16.4102,
1890
+ "mean_token_accuracy": 0.7110463745892048,
1891
+ "step": 1170
1892
+ },
1893
+ {
1894
+ "epoch": 0.7876654935478465,
1895
+ "grad_norm": 1.6699568033218384,
1896
+ "learning_rate": 2.231638418079096e-05,
1897
+ "loss": 17.1027,
1898
+ "mean_token_accuracy": 0.7031991191208362,
1899
+ "step": 1175
1900
+ },
1901
+ {
1902
+ "epoch": 0.7910172616054969,
1903
+ "grad_norm": 1.6248890161514282,
1904
+ "learning_rate": 2.196327683615819e-05,
1905
+ "loss": 16.3399,
1906
+ "mean_token_accuracy": 0.7143234215676785,
1907
+ "step": 1180
1908
+ },
1909
+ {
1910
+ "epoch": 0.7943690296631473,
1911
+ "grad_norm": 1.7570775747299194,
1912
+ "learning_rate": 2.1610169491525427e-05,
1913
+ "loss": 16.2255,
1914
+ "mean_token_accuracy": 0.7123358778655529,
1915
+ "step": 1185
1916
+ },
1917
+ {
1918
+ "epoch": 0.7977207977207977,
1919
+ "grad_norm": 1.9391677379608154,
1920
+ "learning_rate": 2.1257062146892657e-05,
1921
+ "loss": 16.3472,
1922
+ "mean_token_accuracy": 0.711616413295269,
1923
+ "step": 1190
1924
+ },
1925
+ {
1926
+ "epoch": 0.8010725657784481,
1927
+ "grad_norm": 1.8997981548309326,
1928
+ "learning_rate": 2.0903954802259886e-05,
1929
+ "loss": 16.5601,
1930
+ "mean_token_accuracy": 0.7071553356945515,
1931
+ "step": 1195
1932
+ },
1933
+ {
1934
+ "epoch": 0.8044243338360986,
1935
+ "grad_norm": 1.6094359159469604,
1936
+ "learning_rate": 2.055084745762712e-05,
1937
+ "loss": 16.622,
1938
+ "mean_token_accuracy": 0.7043877936899662,
1939
+ "step": 1200
1940
+ },
1941
+ {
1942
+ "epoch": 0.8077761018937489,
1943
+ "grad_norm": 1.7940973043441772,
1944
+ "learning_rate": 2.0197740112994352e-05,
1945
+ "loss": 16.6535,
1946
+ "mean_token_accuracy": 0.705554535984993,
1947
+ "step": 1205
1948
+ },
1949
+ {
1950
+ "epoch": 0.8111278699513994,
1951
+ "grad_norm": 1.6890041828155518,
1952
+ "learning_rate": 1.984463276836158e-05,
1953
+ "loss": 17.2328,
1954
+ "mean_token_accuracy": 0.6988375537097454,
1955
+ "step": 1210
1956
+ },
1957
+ {
1958
+ "epoch": 0.8144796380090498,
1959
+ "grad_norm": 1.5568735599517822,
1960
+ "learning_rate": 1.9491525423728814e-05,
1961
+ "loss": 16.9753,
1962
+ "mean_token_accuracy": 0.7015632651746273,
1963
+ "step": 1215
1964
+ },
1965
+ {
1966
+ "epoch": 0.8178314060667002,
1967
+ "grad_norm": 1.7157835960388184,
1968
+ "learning_rate": 1.9138418079096047e-05,
1969
+ "loss": 16.3668,
1970
+ "mean_token_accuracy": 0.7098449252545833,
1971
+ "step": 1220
1972
+ },
1973
+ {
1974
+ "epoch": 0.8211831741243506,
1975
+ "grad_norm": 1.7175644636154175,
1976
+ "learning_rate": 1.8785310734463277e-05,
1977
+ "loss": 16.8061,
1978
+ "mean_token_accuracy": 0.7032932281494141,
1979
+ "step": 1225
1980
+ },
1981
+ {
1982
+ "epoch": 0.824534942182001,
1983
+ "grad_norm": 1.7225829362869263,
1984
+ "learning_rate": 1.843220338983051e-05,
1985
+ "loss": 16.5716,
1986
+ "mean_token_accuracy": 0.7074852548539639,
1987
+ "step": 1230
1988
+ },
1989
+ {
1990
+ "epoch": 0.8278867102396514,
1991
+ "grad_norm": 1.8654727935791016,
1992
+ "learning_rate": 1.8079096045197743e-05,
1993
+ "loss": 16.8172,
1994
+ "mean_token_accuracy": 0.7035241700708866,
1995
+ "step": 1235
1996
+ },
1997
+ {
1998
+ "epoch": 0.8312384782973018,
1999
+ "grad_norm": 1.9604694843292236,
2000
+ "learning_rate": 1.7725988700564972e-05,
2001
+ "loss": 16.2992,
2002
+ "mean_token_accuracy": 0.714275274425745,
2003
+ "step": 1240
2004
+ },
2005
+ {
2006
+ "epoch": 0.8345902463549523,
2007
+ "grad_norm": 1.7569185495376587,
2008
+ "learning_rate": 1.7372881355932205e-05,
2009
+ "loss": 16.6269,
2010
+ "mean_token_accuracy": 0.7052666112780571,
2011
+ "step": 1245
2012
+ },
2013
+ {
2014
+ "epoch": 0.8379420144126026,
2015
+ "grad_norm": 1.6537069082260132,
2016
+ "learning_rate": 1.7019774011299435e-05,
2017
+ "loss": 16.5978,
2018
+ "mean_token_accuracy": 0.708269502967596,
2019
+ "step": 1250
2020
+ },
2021
+ {
2022
+ "epoch": 0.8412937824702531,
2023
+ "grad_norm": 1.8623359203338623,
2024
+ "learning_rate": 1.6666666666666667e-05,
2025
+ "loss": 16.1831,
2026
+ "mean_token_accuracy": 0.7164609245955944,
2027
+ "step": 1255
2028
+ },
2029
+ {
2030
+ "epoch": 0.8446455505279035,
2031
+ "grad_norm": 1.7004101276397705,
2032
+ "learning_rate": 1.63135593220339e-05,
2033
+ "loss": 16.9611,
2034
+ "mean_token_accuracy": 0.7057129152119159,
2035
+ "step": 1260
2036
+ },
2037
+ {
2038
+ "epoch": 0.8479973185855538,
2039
+ "grad_norm": 1.8294973373413086,
2040
+ "learning_rate": 1.596045197740113e-05,
2041
+ "loss": 16.8036,
2042
+ "mean_token_accuracy": 0.7046464517712593,
2043
+ "step": 1265
2044
+ },
2045
+ {
2046
+ "epoch": 0.8513490866432043,
2047
+ "grad_norm": 1.7992702722549438,
2048
+ "learning_rate": 1.5607344632768363e-05,
2049
+ "loss": 16.139,
2050
+ "mean_token_accuracy": 0.7126708298921585,
2051
+ "step": 1270
2052
+ },
2053
+ {
2054
+ "epoch": 0.8547008547008547,
2055
+ "grad_norm": 2.033846855163574,
2056
+ "learning_rate": 1.5254237288135596e-05,
2057
+ "loss": 16.49,
2058
+ "mean_token_accuracy": 0.707030464708805,
2059
+ "step": 1275
2060
+ },
2061
+ {
2062
+ "epoch": 0.8580526227585051,
2063
+ "grad_norm": 1.690617561340332,
2064
+ "learning_rate": 1.4901129943502825e-05,
2065
+ "loss": 16.7829,
2066
+ "mean_token_accuracy": 0.7026272863149643,
2067
+ "step": 1280
2068
+ },
2069
+ {
2070
+ "epoch": 0.8614043908161555,
2071
+ "grad_norm": 1.7161706686019897,
2072
+ "learning_rate": 1.4548022598870056e-05,
2073
+ "loss": 16.4907,
2074
+ "mean_token_accuracy": 0.7054763376712799,
2075
+ "step": 1285
2076
+ },
2077
+ {
2078
+ "epoch": 0.864756158873806,
2079
+ "grad_norm": 1.5910500288009644,
2080
+ "learning_rate": 1.419491525423729e-05,
2081
+ "loss": 16.3073,
2082
+ "mean_token_accuracy": 0.7165283918380737,
2083
+ "step": 1290
2084
+ },
2085
+ {
2086
+ "epoch": 0.8681079269314563,
2087
+ "grad_norm": 1.5939749479293823,
2088
+ "learning_rate": 1.384180790960452e-05,
2089
+ "loss": 16.6524,
2090
+ "mean_token_accuracy": 0.705347529053688,
2091
+ "step": 1295
2092
+ },
2093
+ {
2094
+ "epoch": 0.8714596949891068,
2095
+ "grad_norm": 1.7478996515274048,
2096
+ "learning_rate": 1.3488700564971752e-05,
2097
+ "loss": 17.1832,
2098
+ "mean_token_accuracy": 0.6956523738801479,
2099
+ "step": 1300
2100
+ },
2101
+ {
2102
+ "epoch": 0.8748114630467572,
2103
+ "grad_norm": 1.6442205905914307,
2104
+ "learning_rate": 1.3135593220338985e-05,
2105
+ "loss": 16.3978,
2106
+ "mean_token_accuracy": 0.7132278561592102,
2107
+ "step": 1305
2108
+ },
2109
+ {
2110
+ "epoch": 0.8781632311044075,
2111
+ "grad_norm": 1.7201565504074097,
2112
+ "learning_rate": 1.2782485875706216e-05,
2113
+ "loss": 16.3159,
2114
+ "mean_token_accuracy": 0.711051919311285,
2115
+ "step": 1310
2116
+ },
2117
+ {
2118
+ "epoch": 0.881514999162058,
2119
+ "grad_norm": 1.829209327697754,
2120
+ "learning_rate": 1.2429378531073447e-05,
2121
+ "loss": 16.7987,
2122
+ "mean_token_accuracy": 0.7058401651680469,
2123
+ "step": 1315
2124
+ },
2125
+ {
2126
+ "epoch": 0.8848667672197084,
2127
+ "grad_norm": 1.4660886526107788,
2128
+ "learning_rate": 1.2076271186440678e-05,
2129
+ "loss": 16.7297,
2130
+ "mean_token_accuracy": 0.7092804253101349,
2131
+ "step": 1320
2132
+ },
2133
+ {
2134
+ "epoch": 0.8882185352773588,
2135
+ "grad_norm": 1.4927663803100586,
2136
+ "learning_rate": 1.172316384180791e-05,
2137
+ "loss": 15.9333,
2138
+ "mean_token_accuracy": 0.7158772744238376,
2139
+ "step": 1325
2140
+ },
2141
+ {
2142
+ "epoch": 0.8915703033350092,
2143
+ "grad_norm": 1.6522186994552612,
2144
+ "learning_rate": 1.137005649717514e-05,
2145
+ "loss": 16.4156,
2146
+ "mean_token_accuracy": 0.7134528748691082,
2147
+ "step": 1330
2148
+ },
2149
+ {
2150
+ "epoch": 0.8949220713926597,
2151
+ "grad_norm": 1.7809523344039917,
2152
+ "learning_rate": 1.1016949152542374e-05,
2153
+ "loss": 16.2625,
2154
+ "mean_token_accuracy": 0.7148336976766586,
2155
+ "step": 1335
2156
+ },
2157
+ {
2158
+ "epoch": 0.89827383945031,
2159
+ "grad_norm": 1.8860619068145752,
2160
+ "learning_rate": 1.0663841807909605e-05,
2161
+ "loss": 16.6187,
2162
+ "mean_token_accuracy": 0.7087382405996323,
2163
+ "step": 1340
2164
+ },
2165
+ {
2166
+ "epoch": 0.9016256075079605,
2167
+ "grad_norm": 1.854195475578308,
2168
+ "learning_rate": 1.0310734463276836e-05,
2169
+ "loss": 16.5843,
2170
+ "mean_token_accuracy": 0.7144103929400444,
2171
+ "step": 1345
2172
+ },
2173
+ {
2174
+ "epoch": 0.9049773755656109,
2175
+ "grad_norm": 1.7052239179611206,
2176
+ "learning_rate": 9.957627118644067e-06,
2177
+ "loss": 16.3345,
2178
+ "mean_token_accuracy": 0.7125584341585636,
2179
+ "step": 1350
2180
+ },
2181
+ {
2182
+ "epoch": 0.9083291436232612,
2183
+ "grad_norm": 1.5887420177459717,
2184
+ "learning_rate": 9.6045197740113e-06,
2185
+ "loss": 16.2409,
2186
+ "mean_token_accuracy": 0.7080107174813748,
2187
+ "step": 1355
2188
+ },
2189
+ {
2190
+ "epoch": 0.9116809116809117,
2191
+ "grad_norm": 1.6052732467651367,
2192
+ "learning_rate": 9.251412429378532e-06,
2193
+ "loss": 16.2373,
2194
+ "mean_token_accuracy": 0.7137157171964645,
2195
+ "step": 1360
2196
+ },
2197
+ {
2198
+ "epoch": 0.9150326797385621,
2199
+ "grad_norm": 1.7612617015838623,
2200
+ "learning_rate": 8.898305084745763e-06,
2201
+ "loss": 16.0292,
2202
+ "mean_token_accuracy": 0.7181592255830764,
2203
+ "step": 1365
2204
+ },
2205
+ {
2206
+ "epoch": 0.9183844477962125,
2207
+ "grad_norm": 1.8271749019622803,
2208
+ "learning_rate": 8.545197740112996e-06,
2209
+ "loss": 16.8757,
2210
+ "mean_token_accuracy": 0.701992305368185,
2211
+ "step": 1370
2212
+ },
2213
+ {
2214
+ "epoch": 0.9217362158538629,
2215
+ "grad_norm": 1.6350926160812378,
2216
+ "learning_rate": 8.192090395480225e-06,
2217
+ "loss": 16.6061,
2218
+ "mean_token_accuracy": 0.7089238859713077,
2219
+ "step": 1375
2220
+ },
2221
+ {
2222
+ "epoch": 0.9250879839115134,
2223
+ "grad_norm": 1.7321621179580688,
2224
+ "learning_rate": 7.838983050847458e-06,
2225
+ "loss": 16.2532,
2226
+ "mean_token_accuracy": 0.7115737572312355,
2227
+ "step": 1380
2228
+ },
2229
+ {
2230
+ "epoch": 0.9284397519691637,
2231
+ "grad_norm": 1.8958040475845337,
2232
+ "learning_rate": 7.48587570621469e-06,
2233
+ "loss": 16.5068,
2234
+ "mean_token_accuracy": 0.7108790181577206,
2235
+ "step": 1385
2236
+ },
2237
+ {
2238
+ "epoch": 0.9317915200268141,
2239
+ "grad_norm": 1.629992127418518,
2240
+ "learning_rate": 7.1327683615819206e-06,
2241
+ "loss": 16.2367,
2242
+ "mean_token_accuracy": 0.7134776934981346,
2243
+ "step": 1390
2244
+ },
2245
+ {
2246
+ "epoch": 0.9351432880844646,
2247
+ "grad_norm": 1.904123067855835,
2248
+ "learning_rate": 6.779661016949153e-06,
2249
+ "loss": 16.3444,
2250
+ "mean_token_accuracy": 0.7045241884887219,
2251
+ "step": 1395
2252
+ },
2253
+ {
2254
+ "epoch": 0.9384950561421149,
2255
+ "grad_norm": 1.6319600343704224,
2256
+ "learning_rate": 6.426553672316385e-06,
2257
+ "loss": 16.3,
2258
+ "mean_token_accuracy": 0.7118948072195053,
2259
+ "step": 1400
2260
+ },
2261
+ {
2262
+ "epoch": 0.9418468241997654,
2263
+ "grad_norm": 1.6921709775924683,
2264
+ "learning_rate": 6.073446327683617e-06,
2265
+ "loss": 16.5816,
2266
+ "mean_token_accuracy": 0.7079687170684338,
2267
+ "step": 1405
2268
+ },
2269
+ {
2270
+ "epoch": 0.9451985922574158,
2271
+ "grad_norm": 1.636551856994629,
2272
+ "learning_rate": 5.720338983050848e-06,
2273
+ "loss": 16.785,
2274
+ "mean_token_accuracy": 0.7054948009550571,
2275
+ "step": 1410
2276
+ },
2277
+ {
2278
+ "epoch": 0.9485503603150662,
2279
+ "grad_norm": 1.6171858310699463,
2280
+ "learning_rate": 5.367231638418079e-06,
2281
+ "loss": 16.6877,
2282
+ "mean_token_accuracy": 0.7033485405147075,
2283
+ "step": 1415
2284
+ },
2285
+ {
2286
+ "epoch": 0.9519021283727166,
2287
+ "grad_norm": 1.6833641529083252,
2288
+ "learning_rate": 5.014124293785311e-06,
2289
+ "loss": 16.5803,
2290
+ "mean_token_accuracy": 0.706027402728796,
2291
+ "step": 1420
2292
+ },
2293
+ {
2294
+ "epoch": 0.9552538964303671,
2295
+ "grad_norm": 2.0238494873046875,
2296
+ "learning_rate": 4.6610169491525425e-06,
2297
+ "loss": 16.4305,
2298
+ "mean_token_accuracy": 0.7110757566988468,
2299
+ "step": 1425
2300
+ },
2301
+ {
2302
+ "epoch": 0.9586056644880174,
2303
+ "grad_norm": 1.5262683629989624,
2304
+ "learning_rate": 4.307909604519774e-06,
2305
+ "loss": 16.105,
2306
+ "mean_token_accuracy": 0.7173994883894921,
2307
+ "step": 1430
2308
+ },
2309
+ {
2310
+ "epoch": 0.9619574325456678,
2311
+ "grad_norm": 1.6822128295898438,
2312
+ "learning_rate": 3.954802259887006e-06,
2313
+ "loss": 17.0064,
2314
+ "mean_token_accuracy": 0.7033144362270832,
2315
+ "step": 1435
2316
+ },
2317
+ {
2318
+ "epoch": 0.9653092006033183,
2319
+ "grad_norm": 2.1382946968078613,
2320
+ "learning_rate": 3.6016949152542374e-06,
2321
+ "loss": 16.6567,
2322
+ "mean_token_accuracy": 0.7085098147392273,
2323
+ "step": 1440
2324
+ },
2325
+ {
2326
+ "epoch": 0.9686609686609686,
2327
+ "grad_norm": 1.6137080192565918,
2328
+ "learning_rate": 3.248587570621469e-06,
2329
+ "loss": 16.4193,
2330
+ "mean_token_accuracy": 0.7077061600983143,
2331
+ "step": 1445
2332
+ },
2333
+ {
2334
+ "epoch": 0.9720127367186191,
2335
+ "grad_norm": 1.6318018436431885,
2336
+ "learning_rate": 2.8954802259887007e-06,
2337
+ "loss": 16.5904,
2338
+ "mean_token_accuracy": 0.7037704810500145,
2339
+ "step": 1450
2340
+ },
2341
+ {
2342
+ "epoch": 0.9753645047762695,
2343
+ "grad_norm": 1.6723519563674927,
2344
+ "learning_rate": 2.5423728813559323e-06,
2345
+ "loss": 16.351,
2346
+ "mean_token_accuracy": 0.715372896194458,
2347
+ "step": 1455
2348
+ },
2349
+ {
2350
+ "epoch": 0.9787162728339199,
2351
+ "grad_norm": 2.6915719509124756,
2352
+ "learning_rate": 2.189265536723164e-06,
2353
+ "loss": 16.5627,
2354
+ "mean_token_accuracy": 0.706637478619814,
2355
+ "step": 1460
2356
+ },
2357
+ {
2358
+ "epoch": 0.9820680408915703,
2359
+ "grad_norm": 1.9349390268325806,
2360
+ "learning_rate": 1.8361581920903956e-06,
2361
+ "loss": 16.7821,
2362
+ "mean_token_accuracy": 0.7010103747248649,
2363
+ "step": 1465
2364
+ },
2365
+ {
2366
+ "epoch": 0.9854198089492208,
2367
+ "grad_norm": 1.6685172319412231,
2368
+ "learning_rate": 1.4830508474576273e-06,
2369
+ "loss": 16.7016,
2370
+ "mean_token_accuracy": 0.7086931586265564,
2371
+ "step": 1470
2372
+ },
2373
+ {
2374
+ "epoch": 0.9887715770068711,
2375
+ "grad_norm": 1.7148998975753784,
2376
+ "learning_rate": 1.129943502824859e-06,
2377
+ "loss": 16.4809,
2378
+ "mean_token_accuracy": 0.7131018862128258,
2379
+ "step": 1475
2380
+ },
2381
+ {
2382
+ "epoch": 0.9921233450645215,
2383
+ "grad_norm": 1.8873836994171143,
2384
+ "learning_rate": 7.768361581920904e-07,
2385
+ "loss": 16.5183,
2386
+ "mean_token_accuracy": 0.7111847102642059,
2387
+ "step": 1480
2388
+ },
2389
+ {
2390
+ "epoch": 0.995475113122172,
2391
+ "grad_norm": 1.8390552997589111,
2392
+ "learning_rate": 4.2372881355932204e-07,
2393
+ "loss": 16.1742,
2394
+ "mean_token_accuracy": 0.7128683432936669,
2395
+ "step": 1485
2396
+ },
2397
+ {
2398
+ "epoch": 0.9988268811798223,
2399
+ "grad_norm": 1.8799461126327515,
2400
+ "learning_rate": 7.062146892655368e-08,
2401
+ "loss": 17.1633,
2402
+ "mean_token_accuracy": 0.6963419988751411,
2403
+ "step": 1490
2404
+ }
2405
+ ],
2406
+ "logging_steps": 5,
2407
+ "max_steps": 1491,
2408
+ "num_input_tokens_seen": 0,
2409
+ "num_train_epochs": 1,
2410
+ "save_steps": 750,
2411
+ "stateful_callbacks": {
2412
+ "TrainerControl": {
2413
+ "args": {
2414
+ "should_epoch_stop": false,
2415
+ "should_evaluate": false,
2416
+ "should_log": false,
2417
+ "should_save": true,
2418
+ "should_training_stop": true
2419
+ },
2420
+ "attributes": {}
2421
+ }
2422
+ },
2423
+ "total_flos": 1.5012213304045076e+19,
2424
+ "train_batch_size": 1,
2425
+ "trial_name": null,
2426
+ "trial_params": null
2427
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4615ef1249a88d7982238a4f457d81abd14a1991eab8956de34c050cdf3445ef
3
+ size 7224
zero_to_fp32.py ADDED
@@ -0,0 +1,761 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #!/usr/bin/env python
3
+
4
+ # Copyright (c) Microsoft Corporation.
5
+ # SPDX-License-Identifier: Apache-2.0
6
+
7
+ # DeepSpeed Team
8
+
9
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
10
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
11
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
12
+ # application.
13
+ #
14
+ # example:
15
+ # python zero_to_fp32.py . output_dir/
16
+ # or
17
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
18
+
19
+ import argparse
20
+ import torch
21
+ import glob
22
+ import math
23
+ import os
24
+ import re
25
+ import gc
26
+ import json
27
+ import numpy as np
28
+ from tqdm import tqdm
29
+ from collections import OrderedDict
30
+ from dataclasses import dataclass
31
+
32
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
33
+ # DeepSpeed data structures it has to be available in the current python environment.
34
+ from deepspeed.utils import logger
35
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
36
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
37
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
38
+
39
+
40
+ @dataclass
41
+ class zero_model_state:
42
+ buffers: dict()
43
+ param_shapes: dict()
44
+ shared_params: list
45
+ ds_version: int
46
+ frozen_param_shapes: dict()
47
+ frozen_param_fragments: dict()
48
+
49
+
50
+ debug = 0
51
+
52
+ # load to cpu
53
+ device = torch.device('cpu')
54
+
55
+
56
+ def atoi(text):
57
+ return int(text) if text.isdigit() else text
58
+
59
+
60
+ def natural_keys(text):
61
+ '''
62
+ alist.sort(key=natural_keys) sorts in human order
63
+ http://nedbatchelder.com/blog/200712/human_sorting.html
64
+ (See Toothy's implementation in the comments)
65
+ '''
66
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
67
+
68
+
69
+ def get_model_state_file(checkpoint_dir, zero_stage):
70
+ if not os.path.isdir(checkpoint_dir):
71
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
72
+
73
+ # there should be only one file
74
+ if zero_stage <= 2:
75
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
76
+ elif zero_stage == 3:
77
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
78
+
79
+ if not os.path.exists(file):
80
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
81
+
82
+ return file
83
+
84
+
85
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
86
+ # XXX: need to test that this simple glob rule works for multi-node setup too
87
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
88
+
89
+ if len(ckpt_files) == 0:
90
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
91
+
92
+ return ckpt_files
93
+
94
+
95
+ def get_optim_files(checkpoint_dir):
96
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
97
+
98
+
99
+ def get_model_state_files(checkpoint_dir):
100
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
101
+
102
+
103
def parse_model_states(files):
    """Load every ``*_model_states.pt`` file and distill it into a ``zero_model_state``.

    Only the pieces needed for fp32 reconstruction are kept: the buffers
    (restored to fp32), the per-group parameter shapes, the shared-parameter
    mapping, the frozen-parameter info and the deepspeed version that wrote
    the checkpoint.
    """
    zero_model_states = []
    for file in files:
        state_dict = torch.load(file, map_location=device, weights_only=False)

        if BUFFER_NAMES not in state_dict:
            raise ValueError(f"{file} is not a model state checkpoint")
        buffer_names = state_dict[BUFFER_NAMES]
        if debug:
            print("Found buffers:", buffer_names)

        # recover just the buffers while restoring them to fp32 if they were saved in fp16
        module_sd = state_dict["module"]
        buffers = {name: tensor.float() for name, tensor in module_sd.items() if name in buffer_names}
        param_shapes = state_dict[PARAM_SHAPES]

        # collect parameters that are included in param_shapes
        param_names = [name for group in param_shapes for name in group.keys()]

        # update with frozen parameters
        frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
        if frozen_param_shapes is not None:
            if debug:
                print(f"Found frozen_param_shapes: {frozen_param_shapes}")
            param_names += list(frozen_param_shapes.keys())

        # handle shared params
        shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]

        ds_version = state_dict.get(DS_VERSION, None)

        frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)

        zero_model_states.append(
            zero_model_state(buffers=buffers,
                             param_shapes=param_shapes,
                             shared_params=shared_params,
                             ds_version=ds_version,
                             frozen_param_shapes=frozen_param_shapes,
                             frozen_param_fragments=frozen_param_fragments))

    return zero_model_states
147
+
148
+
149
def parse_optim_states(files, ds_checkpoint_dir):
    """Load all sharded optimizer states and return ``(zero_stage, world_size, fp32_flat_groups)``."""
    total_files = len(files)
    state_dicts = []
    for f in tqdm(files, desc='Loading checkpoint shards'):
        state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
        # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
        # and also handle the case where it was already removed by another helper script
        state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
        state_dicts.append(state_dict)

    first_optim_sd = state_dicts[0][OPTIMIZER_STATE_DICT]
    if ZERO_STAGE not in first_optim_sd:
        raise ValueError(f"{files[0]} is not a zero checkpoint")
    zero_stage = first_optim_sd[ZERO_STAGE]
    world_size = first_optim_sd[PARTITION_COUNT]

    # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
    # parameters can be different from data parallelism for non-expert parameters. So we can just
    # use the max of the partition_count to get the dp world_size.
    if type(world_size) is list:
        world_size = max(world_size)

    if world_size != total_files:
        raise ValueError(
            f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
            "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
        )

    # the groups are named differently in each stage
    if zero_stage <= 2:
        fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
    elif zero_stage == 3:
        fp32_groups_key = FP32_FLAT_GROUPS
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    fp32_flat_groups = [sd[OPTIMIZER_STATE_DICT][fp32_groups_key] for sd in state_dicts]
    return zero_stage, world_size, fp32_flat_groups
187
+
188
+
189
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
    """
    Returns fp32 state_dict reconstructed from ds checkpoint

    Args:
        - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
        - ``exclude_frozen_parameters``: when True, frozen parameters are not merged in
    """
    print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")

    optim_files = get_optim_files(ds_checkpoint_dir)
    zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

    model_files = get_model_state_files(ds_checkpoint_dir)
    zero_model_states = parse_model_states(model_files)
    print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')

    # dispatch on the stage found in the optimizer shards; parse_optim_states
    # has already rejected anything other than stages 1-3
    if zero_stage <= 2:
        builder = _get_fp32_state_dict_from_zero2_checkpoint
    else:
        builder = _get_fp32_state_dict_from_zero3_checkpoint
    return builder(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters)
214
+
215
+
216
def _zero2_merge_frozen_params(state_dict, zero_model_states):
    """Copy the frozen parameters of a ZeRO-1/2 checkpoint into *state_dict*.

    Under ZeRO-2 frozen params are not partitioned, so rank 0's fragments are
    already the full tensors.
    """
    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    if frozen_param_shapes is None or len(frozen_param_shapes) == 0:
        return

    frozen_param_fragments = zero_model_states[0].frozen_param_fragments

    if debug:
        num_elem = sum(s.numel() for s in frozen_param_shapes.values())
        print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    avail_numel = sum(p.numel() for p in frozen_param_fragments.values())
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        # the fragment is the full (unpartitioned) tensor under ZeRO-2
        state_dict[name] = frozen_param_fragments[name]

        if debug:
            print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
246
+
247
+
248
+ def _has_callable(obj, fn):
249
+ attr = getattr(obj, fn, None)
250
+ return callable(attr)
251
+
252
+
253
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    """Reconstruct the trainable fp32 params of a ZeRO-1/2 checkpoint into *state_dict*.

    Every rank holds one flat fp32 partition per param group; concatenating a
    group's partitions across ranks yields the flat vector from which each
    param is sliced back out by shape.
    """
    param_shapes = zero_model_states[0].param_shapes

    # Reconstruction protocol:
    #
    # XXX: document this

    if debug:
        for i in range(world_size):
            for j in range(len(fp32_flat_groups[0])):
                print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")

    # XXX: memory usage doubles here (zero2)
    num_param_groups = len(fp32_flat_groups[0])
    merged_single_partition_of_fp32_groups = [
        torch.cat([rank_groups[group_idx] for rank_groups in fp32_flat_groups], 0)
        for group_idx in range(num_param_groups)
    ]
    avail_numel = sum(vec.numel() for vec in merged_single_partition_of_fp32_groups)

    if debug:
        wanted_params = sum(len(shapes) for shapes in param_shapes)
        wanted_numel = sum(sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes)
        # not asserting if there is a mismatch due to possible padding
        print(f"Have {avail_numel} numels to process.")
        print(f"Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    total_numel = 0
    total_params = 0
    for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
        offset = 0
        avail_numel = full_single_fp32_vector.numel()
        for name, shape in shapes.items():
            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
            total_numel += unpartitioned_numel
            total_params += 1

            if debug:
                print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
            state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
            offset += unpartitioned_numel

        # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
        # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
        # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
        # live optimizer object, so we are checking that the numbers are within the right range
        align_to = 2 * world_size

        def zero2_align(x):
            return align_to * math.ceil(x / align_to)

        if debug:
            print(f"original offset={offset}, avail_numel={avail_numel}")

        offset = zero2_align(offset)
        avail_numel = zero2_align(avail_numel)

        if debug:
            print(f"aligned offset={offset}, avail_numel={avail_numel}")

        # Sanity check
        if offset != avail_numel:
            raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
324
+
325
+
326
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    """Assemble the consolidated fp32 state_dict for a ZeRO-1/2 checkpoint."""
    state_dict = OrderedDict()

    # buffers
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero2_merge_frozen_params(state_dict, zero_model_states)

    _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters
    for alias_name, source_name in zero_model_states[0].shared_params:
        if source_name in state_dict:
            state_dict[alias_name] = state_dict[source_name]

    return state_dict
347
+
348
+
349
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    """Return ``(partitioned_numel, padding_numel)`` for a param sharded over *world_size* ranks."""
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    remainder = unpartitioned_numel % world_size
    padding_numel = world_size - remainder if remainder else 0
    return partitioned_numel, padding_numel
354
+
355
+
356
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
    """Merge the rank-partitioned frozen parameters of a ZeRO-3 checkpoint into *state_dict*."""
    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    if frozen_param_shapes is None or len(frozen_param_shapes) == 0:
        return

    if debug:
        for i in range(world_size):
            num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
            print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    avail_numel = sum(p.numel() for p in zero_model_states[0].frozen_param_fragments.values()) * world_size
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        # each rank holds one fragment; stitch them together and trim off the padding
        param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
        state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
390
+
391
+
392
class GatheredTensor:
    """
    A pseudo tensor that collects partitioned weights.
    It is more memory efficient when there are multiple groups.
    """

    def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
        # flat_groups: per-rank lists of flat fp32 partitions (one per param group)
        # flat_groups_offset: cumulative numel boundaries of those partitions
        self.flat_groups = flat_groups
        self.flat_groups_offset = flat_groups_offset
        self.offset = offset
        self.partitioned_numel = partitioned_numel
        self.shape = shape
        self.dtype = self.flat_groups[0][0].dtype

    def contiguous(self):
        """
        Merge partitioned weights from flat_groups into a single tensor.
        """
        end_idx = self.offset + self.partitioned_numel
        world_size = len(self.flat_groups)
        pad_flat_param_chunks = []

        for rank_i in range(world_size):
            # for each rank, locate the group(s) whose span overlaps [offset, end_idx)
            flat_groups_at_rank_i = self.flat_groups[rank_i]
            start_group_id = None
            end_group_id = None
            for group_id in range(len(self.flat_groups_offset)):
                if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
                    start_group_id = group_id
                if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
                    end_group_id = group_id
                    break
            # slice this rank's share out of each overlapping group
            for group_id in range(start_group_id, end_group_id + 1):
                flat_tensor = flat_groups_at_rank_i[group_id]
                start_offset = self.offset - self.flat_groups_offset[group_id]
                end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
                pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])

        # concatenate across ranks, drop tail padding and restore the original shape
        pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
        return pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
436
+
437
+
438
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    """Populate *state_dict* with lazy ``GatheredTensor`` entries for every trainable ZeRO-3 param.

    Args:
        - ``fp32_flat_groups``: one entry per rank, each a *list* of flat fp32
          partitions (one per param group), as produced by ``parse_optim_states``
    """
    param_shapes = zero_model_states[0].param_shapes
    # total available numels across all ranks; each rank's share is the sum of
    # its flat partitions (fp32_flat_groups[i] is a list of tensors, not a tensor)
    avail_numel = sum(flat_group.numel() for flat_group in fp32_flat_groups[0]) * world_size

    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
    # param, re-consolidating each param, while dealing with padding if any

    # merge list of dicts, preserving order
    param_shapes = {k: v for d in param_shapes for k, v in d.items()}

    if debug:
        for i in range(world_size):
            # BUGFIX: fp32_flat_groups[i] is a list, so the old
            # `fp32_flat_groups[i].shape` raised AttributeError here
            for j, flat_tensor in enumerate(fp32_flat_groups[i]):
                print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={flat_tensor.shape}")

    wanted_params = len(param_shapes)
    wanted_numel = sum(shape.numel() for shape in param_shapes.values())
    # not asserting if there is a mismatch due to possible padding.
    # BUGFIX: the old `avail_numel = fp32_flat_groups[0].numel() * world_size`
    # reassignment called .numel() on a list and raised AttributeError; the
    # correct total was already computed above.
    print(f"Trainable params: Have {avail_numel} numels to process.")
    print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    offset = 0
    total_numel = 0
    total_params = 0
    flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
    for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel
        total_params += 1
        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

        # memory efficient tensor: materialized only when .contiguous() is called
        tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
        state_dict[name] = tensor
        offset += partitioned_numel

    offset *= world_size

    # Sanity check
    if offset != avail_numel:
        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
489
+
490
+
491
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    """Assemble the consolidated fp32 state_dict for a ZeRO-3 checkpoint."""
    state_dict = OrderedDict()

    # buffers
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)

    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters
    for alias_name, source_name in zero_model_states[0].shared_params:
        if source_name in state_dict:
            state_dict[alias_name] = state_dict[source_name]

    return state_dict
512
+
513
+
514
def to_torch_tensor(state_dict, return_empty_tensor=False):
    """
    Convert state_dict of GatheredTensor to torch tensor
    """
    torch_state_dict = {}
    converted_tensors = {}  # id(tensor) -> first name it was materialized under
    for name, tensor in state_dict.items():
        first_name = converted_tensors.get(id(tensor))
        if first_name is not None:
            # shared tensors: reuse the already-converted object
            torch_state_dict[name] = torch_state_dict[first_name]
        else:
            converted_tensors[id(tensor)] = name
            if return_empty_tensor:
                torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
            else:
                torch_state_dict[name] = tensor.contiguous()
    return torch_state_dict
532
+
533
+
534
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
                                             tag=None,
                                             exclude_frozen_parameters=False,
                                             lazy_mode=False):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
    ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
    via a model hub.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters
        - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pseudo tensors instead of torch tensors, which is more memory efficient.
          Convert a pseudo tensor to a torch tensor by calling ``.contiguous()``

    Returns:
        - pytorch ``state_dict``

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        # do the training and checkpoint saving
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
        model = model.cpu() # move to cpu
        model.load_state_dict(state_dict)
        # submit to model hub or save the model to share with others

    In this example the ``model`` will no longer be usable in the deepspeed context of the same
    application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.

    Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
    You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
    the checkpoint. Or you can load state_dict in lazy mode ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
        for name, lazy_tensor in state_dict.item():
            tensor = lazy_tensor.contiguous()  # to cpu
            print(name, tensor)
            # del tensor to release memory if it no longer in use
    """
    # resolve the checkpoint tag: explicit argument wins, otherwise read 'latest'
    if tag is None:
        latest_path = os.path.join(checkpoint_dir, 'latest')
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

    state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
    return state_dict if lazy_mode else to_torch_tensor(state_dict)
597
+
598
+
599
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
                                               output_dir,
                                               max_shard_size="5GB",
                                               safe_serialization=False,
                                               tag=None,
                                               exclude_frozen_parameters=False):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``output_dir``: directory to the pytorch fp32 state_dict output files
        - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
        - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters
    """

    # Dependency pre-check: fail early, before any heavy loading happens
    if safe_serialization:
        try:
            from safetensors.torch import save_file
        except ImportError:
            print('If you want to use `safe_serialization`, please `pip install safetensors`')
            raise
    if max_shard_size is not None:
        try:
            from huggingface_hub import split_torch_state_dict_into_shards
        except ImportError:
            print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
            raise

    # Convert zero checkpoint to state_dict (lazily, to keep peak host memory low)
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
                                                          tag,
                                                          exclude_frozen_parameters,
                                                          lazy_mode=True)

    # Shard the model if it is too big.
    weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
    if max_shard_size is not None:
        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
        # a memory-efficient approach for sharding: the planner only needs
        # shapes/dtypes, so hand it empty tensors instead of materialized weights
        empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
        state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
                                                              filename_pattern=filename_pattern,
                                                              max_shard_size=max_shard_size)
    else:
        from collections import namedtuple
        StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
        state_dict_split = StateDictSplit(is_sharded=False,
                                          filename_to_tensors={weights_name: list(state_dict.keys())})

    # Save the model by shard
    os.makedirs(output_dir, exist_ok=True)
    filename_to_tensors = state_dict_split.filename_to_tensors.items()
    for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
        shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
        # materialize only this shard's tensors
        shard_state_dict = to_torch_tensor(shard_state_dict)
        output_path = os.path.join(output_dir, shard_file)
        if safe_serialization:
            save_file(shard_state_dict, output_path, metadata={"format": "pt"})
        else:
            torch.save(shard_state_dict, output_path)
        # release the memory of current shard
        for tensor_name in list(shard_state_dict.keys()):
            del state_dict[tensor_name]
            del shard_state_dict[tensor_name]
        del shard_state_dict
        gc.collect()

    # Save index if sharded
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
        save_index_file = os.path.join(output_dir, save_index_file)
        with open(save_index_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")
682
+
683
+
684
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
    """
    1. Put the provided model to cpu
    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
    3. Load it into the provided model

    Args:
        - ``model``: the model object to update
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``

    Returns:
        - ``model``: modified model

    Make sure you have plenty of CPU memory available before you call this function. If you don't
    have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
    conveniently placed for you in the checkpoint folder.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
        # submit to model hub or save the model to share with others

    Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
    of the same application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
    """
    logger.info("Extracting fp32 weights")
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    logger.info("Overwriting model with fp32 weights")
    model = model.cpu()
    # strict=False: buffers/params the checkpoint doesn't carry are left untouched
    model.load_state_dict(state_dict, strict=False)

    return model
721
+
722
+
723
if __name__ == "__main__":
    # CLI entry point: convert a deepspeed ZeRO checkpoint folder into
    # consolidated fp32 weight file(s) in output_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint_dir",
                        type=str,
                        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
    parser.add_argument("output_dir",
                        type=str,
                        help="directory to the pytorch fp32 state_dict output files"
                        "(e.g. path/checkpoint-12-output/)")
    parser.add_argument(
        "--max_shard_size",
        type=str,
        default="5GB",
        help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
        "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
        "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
        "without CPU OOM issues.")
    parser.add_argument(
        "--safe_serialization",
        default=False,
        action='store_true',
        help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
    parser.add_argument("-t",
                        "--tag",
                        type=str,
                        default=None,
                        help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
    parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
    args = parser.parse_args()

    # the helper functions consult the module-level `debug` flag
    debug = args.debug

    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
                                               args.output_dir,
                                               max_shard_size=args.max_shard_size,
                                               safe_serialization=args.safe_serialization,
                                               tag=args.tag,
                                               exclude_frozen_parameters=args.exclude_frozen_parameters)