SS3M committed on
Commit
1299746
·
verified ·
1 Parent(s): 599a668

Upload 4_doc_level_entities_5's state dict

Browse files
.gitattributes CHANGED
@@ -46,3 +46,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
46
  1_doc_level_issues_3/results/1_doc_level_issues_3_pred_test.json filter=lfs diff=lfs merge=lfs -text
47
  1_pointer_base_entities_4/logs/1_pointer_base_entities_4_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
48
  1_pointer_base_issues_4/logs/1_pointer_base_issues_4_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
 
 
46
  1_doc_level_issues_3/results/1_doc_level_issues_3_pred_test.json filter=lfs diff=lfs merge=lfs -text
47
  1_pointer_base_entities_4/logs/1_pointer_base_entities_4_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
48
  1_pointer_base_issues_4/logs/1_pointer_base_issues_4_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
49
+ 4_doc_level_entities_5/logs/4_doc_level_entities_5_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
4_doc_level_entities_5/4_doc_level_entities_5.py ADDED
@@ -0,0 +1,1630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [code]
2
+ get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch] pytorch-crf')
3
+
4
+ # %% [code]
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from torch.utils.data import Dataset, TensorDataset, DataLoader
12
+ import torch.nn.functional as F
13
+ import albumentations as albu
14
+ from transformers import AutoTokenizer, AutoModel
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from positional_encodings.torch_encodings import PositionalEncoding1D
18
+ from torchcrf import CRF
19
+
20
+ from sklearn.metrics import f1_score
21
+ from sklearn.preprocessing import MinMaxScaler, StandardScaler
22
+ from scipy.spatial.transform import Rotation as R
23
+ from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold
24
+ from sklearn.metrics import precision_recall_fscore_support
25
+ from timm.utils import ModelEmaV3
26
+ import timm
27
+
28
+ import os
29
+ import gc
30
+ import json
31
+ from pathlib import Path
32
+ import pickle
33
+ from tqdm.auto import tqdm
34
+ import copy
35
+ import numpy as np
36
+ import pandas as pd
37
+ import polars as pl
38
+ from PIL import Image
39
+ import time
40
+ from tqdm import tqdm
41
+ from matplotlib import pyplot as plt
42
+ import seaborn as sns
43
+ from multiprocessing import Manager as MemoryManager
44
+ from functools import lru_cache
45
+ import shutil
46
+ import glob
47
+ import cv2
48
+ import random
49
+ import re
50
+ import joblib
51
+ import math
52
+ from huggingface_hub import HfApi, snapshot_download
53
+ import evaluate
54
+ from underthesea import word_tokenize as vi_tokenize_tool
55
+ import spacy
56
+ en_tokenize_tool = spacy.load("en_core_web_sm")
57
+ from collections import defaultdict, Counter
58
+
59
+ # %% [code]
60
+ # Global config
61
+ SEEDS = [26092004]
62
+ topk = 1
63
+ nfolds = 5
64
+ only_fold_idx = 0
65
+ test_only = 0
66
+ debug_only = 0
67
+
68
+ # Config thư mục
69
+ dataset = 'kltn/raw' # conll003, ontonotes, phoner, vietbio, vietmed, vimed, kltn/only_entities, kltn/raw
70
+ root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}' ## Thư mục chứa file train, val, test
71
+ train_dir = f'{root_dir}'
72
+ # val_dir = f'{root_dir}/val'
73
+ test_dir = f'{root_dir}'
74
+
75
+ # Config checkpoints
76
+
77
+ # Config training
78
+ epochs = 18 if not debug_only else 2
79
+ batch_size = 16
80
+ device = "cuda" if torch.cuda.is_available() else "cpu"
81
+ # # Thêm biến toàn cục nào đó vào đây
82
+ repo_name = 'SS3M/kltn-experiments'
83
+ state_dict_save_name = "4_doc_level_entities_5"
84
+ checkpoints_dir = state_dict_save_name
85
+ pretrained_dir = "/kaggle/working"
86
+ os.makedirs(f'{checkpoints_dir}', exist_ok=True)
87
+
88
+ backbone_model_name = "bert-base-uncased" if dataset in ["conll003", "ontonotes"] else "vinai/phobert-base"
89
def word_tokenize(text):
    """Split ``text`` into tokens with the tokenizer that matches ``dataset``.

    The English datasets ("conll003", "ontonotes") go through the spaCy
    pipeline held in ``en_tokenize_tool``; every other dataset is handed to
    ``vi_tokenize_tool``. ``dataset`` is read at call time, exactly as the
    original lambda did.

    Args:
        text: raw string to tokenize.

    Returns:
        For the English datasets, a list with the ``.text`` of each spaCy
        token; otherwise whatever ``vi_tokenize_tool(text)`` returns.
    """
    # The original condition was the chained comparison
    # `dataset == dataset in [...]`, which reduces to the membership test.
    if dataset in ["conll003", "ontonotes"]:
        return [token.text for token in en_tokenize_tool(text)]
    return vi_tokenize_tool(text)
90
+ max_len_dict = {
91
+ 'kltn/raw': 256,
92
+ 'kltn/only_entities': 68,
93
+ 'conll003': 46,
94
+ 'ontonotes': 61,
95
+ 'phoner': 68,
96
+ 'vietbio': 125,
97
+ 'vietmed': 36,
98
+ 'vimed': 100,
99
+ }
100
+ zero_entities_rate_dict = {
101
+ 'kltn/raw': 1000,
102
+ 'kltn/only_entities': 0.2,
103
+ 'conll003': 1000, # mean keep all zero-entities samples
104
+ 'ontonotes': 1000,
105
+ 'phoner': 1000,
106
+ 'vietbio': 1000,
107
+ 'vietmed': 1000,
108
+ 'vimed': 1000,
109
+ }
110
+
111
+ max_len = max_len_dict[dataset]
112
+ max_n_parts = 3
113
+ max_span_len = 10
114
+ zero_entities_rate = zero_entities_rate_dict[dataset]
115
+
116
+ # Trainer
117
+ trainer_params = {
118
+ "training_time": "00:11:30:00",
119
+ "eval_mode": "max",
120
+ "topk": topk,
121
+ "save_name": state_dict_save_name,
122
+ "save_best": True,
123
+ "save_last": True,
124
+ "device": device,
125
+ "logging": True,
126
+ "logging_file": True,
127
+ "checkpoints_dir": checkpoints_dir,
128
+ "early_stopping": 30,
129
+ "eval_from_ratio": 0.4,
130
+ "eval_every": 1,
131
+ "schedule_in_step": False,
132
+ "use_ema": True,
133
+ "ema_from_ratio": 0.3,
134
+ "ema_decay": 0.9995,
135
+ "max_grad_norm": 200.0,
136
+ "return_best": True,
137
+ "return_last": True,
138
+ }
139
+
140
+ # Memory
141
+ train_memory_params = {
142
+ 'max_len': max_len,
143
+ 'max_n_parts': max_n_parts,
144
+ }
145
+ val_memory_params = {
146
+ 'max_len': max_len,
147
+ 'max_n_parts': max_n_parts,
148
+ }
149
+
150
+ # Data Loader
151
def seed_worker(worker_id):
    """Seed NumPy and ``random`` from the current torch initial seed.

    Used as a DataLoader ``worker_init_fn``. ``worker_id`` is accepted to
    match that callback signature but is not used.
    """
    # Keep only the low 32 bits of the torch seed.
    derived_seed = torch.initial_seed() & 0xFFFFFFFF
    np.random.seed(derived_seed)
    random.seed(derived_seed)
155
+
156
+ train_loader_params = {
157
+ 'batch_size': batch_size,
158
+ 'shuffle': True,
159
+ 'pin_memory':True,
160
+ 'num_workers': 2,
161
+ 'drop_last': False,
162
+ 'worker_init_fn': seed_worker,
163
+ 'persistent_workers': False,
164
+ }
165
+ val_loader_params = {
166
+ 'batch_size': batch_size,
167
+ 'shuffle': False,
168
+ 'pin_memory':True,
169
+ 'num_workers': 1,
170
+ 'drop_last': False,
171
+ 'worker_init_fn': seed_worker,
172
+ 'persistent_workers': False,
173
+ }
174
+
175
+ # Model
176
+ model_params = {
177
+ 'backbone_model_name': backbone_model_name,
178
+ }
179
+
180
+ # Loss Func
181
+ loss_func_params = {
182
+ 'lambda_ce': 1.0,
183
+ }
184
+ eval_func_params = {}
185
+
186
+ # Optim
187
+ optim_params = {
188
+ 'name': 'AdamW',
189
+ 'lr': 1e-4,
190
+ 'weight_decay': 1e-4,
191
+ }
192
+ scheduler_params = {
193
+ 'name': 'CosineAnnealingLR',
194
+ 'T_max': 20, # Số epoch để hoàn thành một chu kỳ giảm LR
195
+ 'eta_min': 1e-6 # Learning rate nhỏ nhất trong chu kỳ
196
+ }
197
+
198
+ # %% [code]
199
def set_seed(seed=42):
    """Seed Python, NumPy and torch (CPU and CUDA) RNGs with ``seed``.

    Also turns off ``torch.use_deterministic_algorithms``, sets cuDNN to
    deterministic mode with benchmarking disabled, and writes the seed into
    the ``PYTHONHASHSEED`` environment variable.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.use_deterministic_algorithms(False)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
209
+
210
+ # %% [code]
211
class CustomLoss(nn.Module):
    """Sum of two cross-entropy losses: one for start labels, one for end labels.

    Positions whose label equals -100 are ignored by the criterion.
    ``lambda_ce`` is stored on the module but is not used in ``forward``.
    """

    def __init__(self, lambda_ce=1.0):
        super().__init__()
        self.lambda_ce = lambda_ce
        self.ce = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(
        self,
        start_logits, start_labels,
        end_logits, end_labels,
    ):
        """Return a dict with ``total``, ``start_loss`` and ``end_loss``.

        ``start_logits`` must be 3-dimensional; its shape is unpacked as
        (B, L, C) and reused to flatten ``end_logits`` as well.
        """
        n_batch, n_pos, n_cls = start_logits.shape
        flat_rows = n_batch * n_pos

        def _flat_ce(logits, labels):
            # Collapse batch and position into a single axis before the CE call.
            return self.ce(logits.view(flat_rows, n_cls), labels.view(-1))

        start_loss = _flat_ce(start_logits, start_labels)
        end_loss = _flat_ce(end_logits, end_labels)

        return {
            "total": start_loss + end_loss,
            "start_loss": start_loss,
            "end_loss": end_loss,
        }
240
+
241
+ # %% [code]
242
+ ## Viết eval_fn vào đây
243
+
244
+ # Bỏ hết eval_fn và trọng số vào đây
245
class CustomEvalFn(nn.Module):
    """Set-based precision / recall / F1 between predicted and gold items."""

    def __init__(self):
        super().__init__()

    def compute_f1(self, tp, fp, fn):
        """Return ``(precision, recall, f1)``; a 1e-8 term keeps each denominator non-zero."""
        eps = 1e-8
        precision = tp / (tp + fp + eps)
        recall = tp / (tp + fn + eps)
        f1 = 2 * precision * recall / (precision + recall + eps)
        return precision, recall, f1

    def forward(self, pred, gold):
        """Compare ``pred`` and ``gold`` as sets (duplicates collapse) and score them."""
        predicted, expected = set(pred), set(gold)

        hits = len(predicted & expected)
        spurious = len(predicted - expected)
        missed = len(expected - predicted)

        precision, recall, f1 = self.compute_f1(hits, spurious, missed)
        return {"precision": precision, "recall": recall, "f1": f1}
270
+
271
class SpanErrorAnalyzer:
    """Categorise span prediction errors against gold spans, per sample.

    Inputs are sequences of ``(sample_key, ids)`` pairs. Items equal to
    ``pad_token_id`` are stripped from ``ids`` and spans that end up empty
    are discarded before comparison.
    """

    def __init__(self, pad_token_id=0):
        self.pad_token_id = pad_token_id

    # ===== helpers =====
    def _to_set(self, data):
        """Group spans by sample key.

        data: iterable of (b, ids)
        -> dict[b] = set(tuple(ids))   (padding removed, empty spans dropped)
        """
        grouped = defaultdict(set)
        for b, ids in data:
            ids = tuple(i for i in ids if i != self.pad_token_id)
            if ids:
                grouped[b].add(ids)
        return grouped

    def _iou(self, a, b):
        """Intersection-over-union of the *sets* of items in ``a`` and ``b``."""
        set_a, set_b = set(a), set(b)
        union = len(set_a | set_b)
        if union == 0:
            return 0.0
        return len(set_a & set_b) / union

    def _boundary_error(self, pred, gold):
        """Measure boundary drift as the length of the shared prefix and suffix."""
        shortest = min(len(pred), len(gold))

        # Matching items counted from the left edge.
        left = 0
        for i in range(shortest):
            if pred[i] != gold[i]:
                break
            left += 1

        # Matching items counted from the right edge.
        right = 0
        for i in range(1, shortest + 1):
            if pred[-i] != gold[-i]:
                break
            right += 1

        return {
            "left_match": left,
            "right_match": right,
            "pred_len": len(pred),
            "gold_len": len(gold),
        }

    @staticmethod
    def _rate(count, total):
        """Return ``count / total``, or 0.0 when ``total`` is zero.

        Guards the summary against ZeroDivisionError when there are no
        predictions or no gold spans at all.
        """
        return count / total if total else 0.0

    # ===== main =====
    def analyze(self, preds, golds):
        """Compare predictions with gold spans and collect error statistics.

        Args:
            preds: sized sequence of ``(sample_key, ids)`` predicted spans.
            golds: sized sequence of ``(sample_key, ids)`` gold spans.

        Returns:
            dict with
              * ``summary``: each entry is ``(count, rate)``. Prediction-side
                counters are divided by ``len(preds)``, miss counters by
                ``len(golds)``; a rate is 0.0 when its denominator is 0.
              * ``details``: list of dicts describing every non-exact
                prediction and every unmatched gold span.
        """
        pred_map = self._to_set(preds)
        gold_map = self._to_set(golds)

        all_batches = set(pred_map.keys()) | set(gold_map.keys())

        stats = Counter()
        detailed_errors = []

        for b in all_batches:
            pset = pred_map.get(b, set())
            gset = gold_map.get(b, set())

            matched_gold = set()

            # ===== check predictions =====
            for p in pset:
                if p in gset:
                    stats["exact_match"] += 1
                    matched_gold.add(p)
                    continue

                # Find the gold span with the highest overlap.
                best_iou = 0
                best_g = None
                for g in gset:
                    iou = self._iou(p, g)
                    if iou > best_iou:
                        best_iou = iou
                        best_g = g

                if best_iou > 0:
                    stats["partial_match"] += 1
                    boundary = self._boundary_error(p, best_g)
                    detailed_errors.append({
                        "type": "boundary_error",
                        "batch": b,
                        "pred": p,
                        "gold": best_g,
                        "iou": best_iou,
                        **boundary
                    })
                else:
                    # No overlap with any gold span: distinguish samples that
                    # have no gold spans at all from plain wrong predictions.
                    if b not in gold_map:
                        err_type = "no_event_sample"
                    else:
                        err_type = "completely_wrong"
                    stats[err_type] += 1

                    detailed_errors.append({
                        "type": err_type,
                        "batch": b,
                        "pred": p
                    })

            # ===== check missing =====
            for g in gset:
                if g in matched_gold:
                    continue

                # Did any prediction at least touch this gold span?
                overlap = any(self._iou(p, g) > 0 for p in pset)
                if overlap:
                    stats["miss_with_overlap"] += 1
                else:
                    stats["miss"] += 1

                detailed_errors.append({
                    "type": "miss",
                    "batch": b,
                    "gold": g
                })

        n_pred, n_gold = len(preds), len(golds)
        return {
            "summary": {
                "exact_match": (stats["exact_match"], self._rate(stats["exact_match"], n_pred)),
                "partial_match": (stats["partial_match"], self._rate(stats["partial_match"], n_pred)),
                "no_event_sample": (stats["no_event_sample"], self._rate(stats["no_event_sample"], n_pred)),
                "completely_wrong": (stats["completely_wrong"], self._rate(stats["completely_wrong"], n_pred)),
                "miss": (stats["miss"], self._rate(stats["miss"], n_gold)),
                "miss_with_overlap": (stats["miss_with_overlap"], self._rate(stats["miss_with_overlap"], n_gold)),
            },
            "details": detailed_errors
        }
414
+
415
+ # %% [code]
416
+ ## Viết cấu trúc model vào đây
417
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear."""

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        layers = [
            nn.Linear(in_size, hid_size),
            nn.ReLU(),
            nn.Linear(hid_size, out_size),
        ]
        # Attribute name and layer order are kept so state-dict keys stay the same.
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stack to ``x``."""
        out = self.mlp(x)
        return out
428
+
429
class IEModel(nn.Module):
    """Pretrained encoder followed by two per-token MLP heads (start and end)."""

    def __init__(self, backbone_model_name, num_labels):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(backbone_model_name)
        width = self.encoder.config.hidden_size

        self.start_classifier = MLP(width, width, num_labels)
        self.end_classifier = MLP(width, width, num_labels)

    def encode(self, input_ids, attention_mask):
        """Encode every part separately, then concatenate the parts along the token axis.

        ``input_ids`` must be 3-dimensional; its shape is unpacked as
        (B, n_parts, L). The result has shape (B, n_parts * L, H).
        """
        n_batch, n_parts, part_len = input_ids.shape

        # Fold the parts into the batch axis so the encoder sees 2-D inputs.
        flat_ids = input_ids.view(-1, part_len)
        flat_mask = attention_mask.view(-1, part_len)

        encoded = self.encoder(input_ids=flat_ids, attention_mask=flat_mask)
        per_part = encoded.last_hidden_state  # (B * n_parts, L, H)

        unfolded = per_part.view(n_batch, n_parts, part_len, -1)
        return unfolded.reshape(n_batch, n_parts * part_len, -1)

    def get_logits(self, hidden_states):
        """Run both heads on ``hidden_states`` and return ``(start_logits, end_logits)``."""
        return self.start_classifier(hidden_states), self.end_classifier(hidden_states)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``(start_logits, end_logits)``; ``labels`` is accepted but unused."""
        return self.get_logits(self.encode(input_ids, attention_mask))
458
+
459
def test():
    """Smoke test: build a 7-label IEModel, print its size and its output shapes."""
    model = nn.DataParallel(IEModel(backbone_model_name, 7)).to(device)
    model.eval()
    print(f"Total params: {sum(p.numel() for p in model.parameters()):,}")

    encoder_cfg = model.module.encoder.config
    vocab_size = encoder_cfg.vocab_size
    max_len = encoder_cfg.max_position_embeddings  # read but not used below

    bz = 32
    dummy_ids = torch.randint(0, vocab_size, (bz, 5, 10)).to(device)
    dummy_mask = torch.ones(bz, 5, 10).to(device)
    g = torch.ones(bz, 3, 2, dtype=torch.long).to(device)  # built but not used below

    with torch.no_grad():
        r = model(dummy_ids, dummy_mask)

    if type(r) is tuple:
        # Tensors report their shape, anything else its length.
        summary = []
        for item in r:
            summary.append(item.shape if type(item) is torch.Tensor else len(item))
        print(summary)
    else:
        print(r.shape)
480
+
481
+ test()
482
+
483
+ # %% [code]
484
def configure_optimizers(network, optim_params, scheduler_params):
    """Build an optimizer, and optionally an LR scheduler, from two config dicts.

    Both dicts must contain a ``'name'`` key; the remaining entries are passed
    as keyword arguments. Names are looked up in this module's globals first,
    then in ``torch.optim`` / ``torch.optim.lr_scheduler``. The input dicts
    are shallow-copied and left unmodified.

    Returns:
        ``(optimizer, scheduler)``. ``scheduler`` is None when no scheduler
        kwargs remain after removing ``'name'`` or the name cannot be resolved.

    Raises:
        ValueError: if the optimizer name cannot be resolved, or a KeyError
            (e.g. a missing ``'name'``) occurs while building.
    """
    try:
        opt_kwargs = copy.copy(optim_params)
        sched_kwargs = copy.copy(scheduler_params)

        opt_name = opt_kwargs.pop('name')
        sched_name = sched_kwargs.pop('name')

        module_scope = globals()
        opt_factory = module_scope.get(opt_name) or getattr(optim, opt_name, None)
        sched_factory = module_scope.get(sched_name) or getattr(optim.lr_scheduler, sched_name, None)

        if opt_factory is None:
            raise ValueError(f"Optimizer '{opt_name}' is not available!")

        optimizer = opt_factory(network.parameters(), **opt_kwargs)

        # A scheduler is only created when it has parameters and a resolvable class.
        if sched_kwargs and sched_factory:
            scheduler = sched_factory(optimizer, **sched_kwargs)
        else:
            scheduler = None

        return optimizer, scheduler

    except KeyError as e:
        raise ValueError(f"Missing {e} in config!!")
508
+
509
def freeze(self, model):
    """Put ``model`` in eval mode and disable gradients on all its parameters.

    ``self`` is part of the signature but is not used.
    """
    model.eval()
    for weight in model.parameters():
        weight.requires_grad_(False)
513
+
514
def unfreeze(self, model):
    """Put ``model`` in train mode and enable gradients on all its parameters.

    ``self`` is part of the signature but is not used.
    """
    model.train()
    for weight in model.parameters():
        weight.requires_grad_(True)
518
+
519
def reduce_batch_size(loader, ratio=0.5):
    """Return a new DataLoader like ``loader`` but with a scaled-down batch size.

    Args:
        loader: the existing ``DataLoader`` to clone.
        ratio: multiplier applied to ``loader.batch_size``; the result is
            truncated to int and never goes below 1.

    Returns:
        A fresh ``DataLoader`` over the same dataset. If the original used a
        ``RandomSampler`` the copy is created with ``shuffle=True``; otherwise
        the original sampler object is reused.
    """
    # Local import: the file-level import brings in DataLoader only, so the
    # bare name `RandomSampler` raised NameError the first time this ran.
    from torch.utils.data import DataLoader, RandomSampler

    new_bs = max(1, int(loader.batch_size * ratio))

    shuffle = isinstance(loader.sampler, RandomSampler)

    new_loader = DataLoader(
        dataset=loader.dataset,
        batch_size=new_bs,
        shuffle=shuffle,
        # `shuffle=True` and an explicit sampler are mutually exclusive.
        sampler=None if shuffle else loader.sampler,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
        timeout=loader.timeout,
        worker_init_fn=loader.worker_init_fn,
        multiprocessing_context=loader.multiprocessing_context,
        generator=loader.generator,
        # prefetch_factor is only forwarded when worker processes are in use.
        prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None,
        persistent_workers=loader.persistent_workers,
        pin_memory_device=loader.pin_memory_device
    )

    return new_loader
543
+
544
def list_to_tuple(x):
    """Recursively turn lists/tuples into tuples; any other value is returned as is."""
    if not isinstance(x, (list, tuple)):
        return x
    return tuple(map(list_to_tuple, x))
548
+
549
def fmt(x):
    """Round floats to 5 decimals, recursing into dict values and list items.

    Values of any other type (including tuples) are returned unchanged.
    """
    if isinstance(x, float):
        return round(x, 5)
    if isinstance(x, dict):
        formatted = {}
        for key, value in x.items():
            formatted[key] = fmt(value)
        return formatted
    if isinstance(x, list):
        return list(map(fmt, x))
    return x
557
+
558
class ModelEmaV3Proxy(ModelEmaV3):
    """ModelEmaV3 that forwards unknown attribute lookups to the wrapped ``module``."""

    def __getattr__(self, name):
        # Try the normal lookup first; fall back to the wrapped module only
        # when that raises AttributeError.
        try:
            resolved = super().__getattr__(name)
        except AttributeError:
            resolved = getattr(self.module, name)
        return resolved
564
+
565
class DataParallelProxy(nn.DataParallel):
    """DataParallel wrapper that exposes the wrapped module's attributes and methods.

    Plain attributes are returned directly from ``self.module``. Callable
    attributes are wrapped so that invoking them goes through
    ``_parallel_apply_method`` (scatter / replicate / parallel_apply / gather).
    """

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            target = getattr(self.module, name)

            if not callable(target):
                return target

            def dispatch(*args, **kwargs):
                return self._parallel_apply_method(name, *args, **kwargs)

            return dispatch

    def _parallel_apply_method(self, method_name, *inputs, **kwargs):
        """Call ``method_name`` on every replica and gather the outputs.

        With no ``device_ids`` the method is simply called on ``self.module``.
        """
        if not self.device_ids:
            direct_call = getattr(self.module, method_name)
            return direct_call(*inputs, **kwargs)

        scattered_args, scattered_kwargs = self.scatter(inputs, kwargs, self.device_ids)
        replicas = self.replicate(self.module, self.device_ids)
        bound_methods = [getattr(replica, method_name) for replica in replicas]

        per_device = self.parallel_apply(bound_methods, scattered_args, scattered_kwargs)
        return self.gather(per_device, self.output_device)
594
+
595
def fix_bio(tags):
    """Repair a BIO tag sequence so that no ``I-X`` tag starts an entity.

    An ``I-X`` tag is rewritten to ``B-X`` when it is the first tag, when it
    follows ``'O'``, or when the (already repaired) previous tag has a
    different type. Returns a new list; the input is not modified.
    """
    repaired = []
    for tag in tags:
        if tag.startswith('I-'):
            entity_type = tag[2:]
            previous = repaired[-1] if repaired else None
            starts_entity = (
                previous is None
                or previous == 'O'
                or previous[2:] != entity_type
            )
            if starts_entity:
                tag = 'B-' + entity_type
        repaired.append(tag)
    return repaired
610
+
611
def extract_entities(input_ids, start_logits, end_logits, id2label):
    """Pair predicted start and end positions into labelled token spans.

    Args:
        input_ids: tensor indexed as ``input_ids[b, s:e+1]``.
        start_logits: tensor whose argmax over the last axis is 2-D (B, L).
        end_logits: tensor indexed the same way as ``start_logits``.
        id2label: mapping {label_id: label_name}; label id 0 is treated as "O".

    Returns:
        List of ``(bidx, (token_id_list, label_name))``. For every start
        position with a non-zero label, the nearest position at or after it
        whose end label equals the start label (and which has not already
        closed another span in that row) is taken as the end.
    """
    start_pred = start_logits.argmax(dim=-1)  # (B, L)
    end_pred = end_logits.argmax(dim=-1)      # (B, L)

    n_rows, n_tokens = start_pred.shape
    found = []

    for row in range(n_rows):
        row_starts = start_pred[row].tolist()
        row_ends = end_pred[row].tolist()
        closed_at = set()  # end positions already consumed in this row

        for begin, label_id in enumerate(row_starts):
            # Label 0 means "O": nothing starts here.
            if label_id == 0:
                continue

            # Nearest free end position carrying the same label.
            finish = next(
                (
                    pos
                    for pos in range(begin, n_tokens)
                    if pos not in closed_at and row_ends[pos] == label_id
                ),
                None,
            )
            if finish is None:
                continue

            closed_at.add(finish)
            span_tokens = input_ids[row, begin:finish + 1].tolist()
            found.append((row, (span_tokens, id2label[label_id])))

    return found
671
+
672
+ class Trainer:
673
def __init__(
    self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0,
    logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu',
    schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True
):
    """Store the training configuration and create the bookkeeping state.

    ``training_time`` is a "DD:HH:MM:SS" string converted to seconds.
    ``logging`` is compared against the module-level ``epochs`` value and
    replaced by 1 when it is not smaller than it.
    """
    # Run-time state
    self.ema_net = None
    self.device = device
    self.mode = eval_mode
    self.topk = topk

    # Time budget in seconds
    self.training_time = self._time_str_to_seconds(training_time)

    # Logging / checkpoint output
    self.logging = logging if logging < epochs else 1
    self.logging_file = logging_file
    self.checkpoints_dir = checkpoints_dir
    self.save_name = save_name
    self.save_best = save_best
    self.save_last = save_last
    self.return_best = return_best
    self.return_last = return_last

    # Evaluation schedule and early stopping
    self.early_stopping = early_stopping
    self.eval_from_ratio = eval_from_ratio
    self.eval_every = eval_every

    # Optimisation details
    self.max_grad_norm = max_grad_norm
    self.schedule_in_step = schedule_in_step

    # Exponential moving average of the weights
    self.use_ema = use_ema
    self.ema_from_ratio = ema_from_ratio
    self.ema_decay = ema_decay

    # Each entry is [score, epoch, state_dict]; starts with a worst-case sentinel.
    worst_score = float('-inf') if self.mode == 'max' else float('inf')
    self.best_stage = [[worst_score, None, None]]
    self.grad_scaler = torch.amp.GradScaler(self.device, init_scale=1024.0)
703
+
704
def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, id2label=None):
    """Run the training loop and return ``(training_log, best_model, last_model)``.

    Per epoch: optionally create the EMA copy, train one epoch, optionally
    evaluate and update the best-checkpoint list, step the scheduler (when
    it is not stepped per batch), then check early stopping and the time
    budget. A CUDA out-of-memory RuntimeError clears the cache and halves
    the loader batch sizes instead of aborting.

    ``best_model`` / ``last_model`` are CPU state dicts with "module."
    removed from the keys, or None when the matching ``return_*`` flag is off.

    NOTE(review): block nesting was reconstructed from a listing with no
    indentation; confirm in particular that the per-epoch scheduler step is
    meant to run on every epoch rather than only on evaluated epochs.
    """
    if eval_fn is None:
        # Fall back to the loss as the metric, negated when higher is better.
        if self.mode == "max":
            eval_fn = lambda *x: -loss_fn(*x)
        else:
            eval_fn = lambda *x: loss_fn(*x)

    if torch.cuda.device_count() > 1:
        network = DataParallelProxy(network)
    network = network.to(self.device)

    if not start_training_time:
        start_training_time = time.time()

    # Epoch offsets (relative to start_epoch) at which EMA / evaluation begin.
    start_ema = int(epochs * self.ema_from_ratio)
    start_eval = int(epochs * self.eval_from_ratio)

    if val_loader is None:
        print(f'[Trainer CallBack] 📢 Không có Val Set, không thể đánh giá và Early Stopping!')
    else:
        model_to_use_str = 'mô hình EMA' if self.use_ema else 'mô hình gốc'
        start_model_update_str = f'Bắt đầu cập nhật EMA từ epoch {start_epoch + start_ema}!' if self.use_ema else ''
        print(f'[Trainer CallBack] 📢 Đánh giá bằng {model_to_use_str} từ epoch {start_epoch + start_eval}!', start_model_update_str)

    training_log = {}
    last_epoch_exclusive = epochs + start_epoch
    for epoch in range(start_epoch, last_epoch_exclusive):
        elapsed_epochs = epoch - start_epoch

        if self.use_ema and self.ema_net is None and elapsed_epochs >= start_ema:
            self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)

        try:
            train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn)
            epoch_log = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
            epoch_log.update(train_loss_epoch_dict)

            due_for_eval = (
                val_loader is not None
                and elapsed_epochs >= start_eval
                and (elapsed_epochs - start_eval) % self.eval_every == 0
            )
            if due_for_eval:
                # Evaluate the EMA weights once they exist, else the raw network.
                eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network

                val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, id2label)
                update = self._update_best_network(eval_net, val_score, epoch)
                epoch_log.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update})
                epoch_log.update(val_score_dict)

            if not self.schedule_in_step and scheduler:
                scheduler.step()

        except RuntimeError as e:
            if "out of memory" not in str(e).lower():
                raise

            print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...")
            torch.cuda.empty_cache()
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared")

            # Retry the following epochs with half the batch size.
            train_loader = reduce_batch_size(train_loader, ratio=0.5)
            if val_loader is not None:
                val_loader = reduce_batch_size(val_loader, ratio=0.5)

            epoch_log = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')}

        training_log[epoch] = epoch_log

        if self.is_early_stopping(epoch):
            print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Detect Overfitting! Breaking Training Process...')
            break

        if self.logging:
            if epoch % self.logging == 0:
                print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(epoch_log))
            else:
                print(f'{epoch}...', end=' ')

        if self._at_time_limit(start_training_time):
            print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: Thời gian training giới hạn là {self.training_time}, hết giờ tại epoch {epoch}/{epochs}')
            break

    if self.logging_file:
        os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True)
        with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f:
            f.write(json.dumps(training_log))

    ema_ready = self.use_ema and self.ema_net is not None
    if ema_ready:
        self._save_state_dict(self.ema_net.module)
    else:
        self._save_state_dict(network)
    print(f'[Trainer CallBack] 📢 Kết thúc training.\n')

    best_model, last_model = None, None
    eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
    if self.return_best:
        # Use the top-ranked snapshot, or the current weights if none was recorded.
        best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict()
        best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()}
    if self.return_last:
        last_model = eval_net.state_dict()
        last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()}

    del network
    torch.cuda.empty_cache()
    gc.collect()
    return training_log, best_model, last_model
803
+
804
+ def _time_str_to_seconds(self, time_str):
805
+ days, hours, minutes, seconds = map(int, time_str.split(":"))
806
+ return days * 86400 + hours * 3600 + minutes * 60 + seconds
807
+
808
+ def _update_best_network(self, network, val_score, epoch):
809
+ topk = max(1, self.topk)
810
+ self.best_stage.append([val_score, epoch, {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}])
811
+ self.best_stage = sorted(self.best_stage, reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk]
812
+ if val_score in [x[0] for x in self.best_stage]:
813
+ return True
814
+ return False
815
+
816
def is_early_stopping(self, epoch):
    """Return True when the best entry is at least ``early_stopping`` epochs old.

    Always False when no best epoch has been recorded yet or when
    ``early_stopping`` is falsy.
    """
    best_epoch = self.best_stage[0][1]
    if best_epoch is None or not self.early_stopping:
        return False
    epochs_since_best = epoch - best_epoch
    return epochs_since_best >= self.early_stopping
822
+
823
+ def _at_time_limit(self, start_training_time):
824
+ return time.time() - start_training_time >= self.training_time
825
+
826
def _save_state_dict(self, network):
    """Write the ranked best snapshots and/or the last weights to ``checkpoints_dir``.

    Nothing is written when ``topk`` is not positive. Keys have "module."
    removed and tensors are copied to CPU before saving. File names end in
    "_ema.pth" when an EMA model exists and in "_.pth" otherwise.

    NOTE(review): nesting was reconstructed from a listing with no
    indentation; the ranked-snapshot loop is assumed to sit under
    ``save_best`` - confirm.
    """
    if self.topk <= 0:
        return

    ema_tag = "ema" if self.ema_net is not None else ""

    def _to_cpu(state):
        return {k.replace("module.", ""): v.detach().cpu().clone() for k, v in state.items()}

    if self.save_best:
        # One directory per rank: r1s, r2s, ...
        for r in range(self.topk):
            os.makedirs(f'{self.checkpoints_dir}/r{r+1}s', exist_ok=True)

        for rank, (score, epoch, state_dict) in enumerate(self.best_stage):
            if state_dict is None:
                # Placeholder entry with no recorded weights.
                continue
            torch.save(_to_cpu(state_dict), f'{self.checkpoints_dir}/r{rank+1}s/{self.save_name}_r{rank+1}_vs{score:.5f}_{ema_tag}.pth')

    if self.save_last:
        os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True)
        torch.save(_to_cpu(network.state_dict()), f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{ema_tag}.pth')
843
+
844
+ def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn):
845
+ network.train()
846
+ total_loss = 0
847
+ total_loss_dict = {}
848
+ for batch_idx, batch in enumerate(train_loader):
849
+ optimizer.zero_grad()
850
+ with torch.autocast(device_type=self.device, dtype=torch.float16):
851
+ loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn)
852
+
853
+ for k, v in loss_dict.items():
854
+ t = total_loss_dict.get(k, 0)
855
+ total_loss_dict[k] = t + v
856
+ self.grad_scaler.scale(loss).backward()
857
+ self.grad_scaler.unscale_(optimizer)
858
+ grad_norm = nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm)
859
+ # print(grad_norm) # Bỏ cmt dòng này để biết nên chọn max_grad_norm bằng bao nhiêu...
860
+ self.grad_scaler.step(optimizer)
861
+ self.grad_scaler.update()
862
+ if self.schedule_in_step and scheduler:
863
+ scheduler.step()
864
+ if self.use_ema and self.ema_net is not None:
865
+ self.ema_net.update(network)
866
+ total_loss += loss
867
+ return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()}
868
+
869
+ def _eval_epoch(self, network, val_loader, eval_fn, id2label):
870
+ network.eval()
871
+ total_score = 0.0
872
+ total_score_dict = {}
873
+ object_lists = None # sẽ init sau
874
+
875
+ with torch.no_grad():
876
+ for batch_idx, batch in enumerate(val_loader):
877
+ score, score_dict, objects = self._cal_val_score(network, batch, batch_idx, eval_fn, id2label)
878
+ total_score += score
879
+
880
+ for k, v in score_dict.items():
881
+ t = total_score_dict.get(k, 0)
882
+ total_score_dict[k] = t + v
883
+
884
+ if objects:
885
+ if object_lists is None:
886
+ object_lists = [[] for _ in range(len(objects))]
887
+
888
+ for i, obj in enumerate(objects):
889
+ object_lists[i].append(obj.detach())
890
+
891
+ if object_lists is not None:
892
+ object_arrays = [
893
+ torch.concat(obj_list, dim=0).cpu().numpy()
894
+ for obj_list in object_lists
895
+ ]
896
+ else:
897
+ object_arrays = []
898
+
899
+ return total_score / len(val_loader), {k: v / len(val_loader) for k, v in total_score_dict.items()}, object_arrays
900
+
901
+ def _cal_loss(self, network, batch, batch_idx, loss_fn):
902
+ # Bạn cần override _cal_loss để tính loss
903
+ input_ids = batch['input_ids'].to(self.device)
904
+ attention_mask = batch['attention_mask'].to(self.device)
905
+ start_labels = batch['start_labels'].to(self.device)
906
+ end_labels = batch['end_labels'].to(self.device)
907
+
908
+ start_logits, end_logits = network(input_ids, attention_mask)
909
+
910
+ loss_dict = loss_fn(
911
+ start_logits, start_labels,
912
+ end_logits, end_labels,
913
+ )
914
+ return loss_dict['total'], loss_dict
915
+
916
def _cal_val_score(self, network, batch, batch_idx, eval_fn, id2label):
    """Compute the validation score for one batch (override to customise).

    Predicts start/end logits, decodes entities with ``extract_entities`` on
    the ids flattened to one row per batch item, and scores them against
    ``batch['gold_entities']`` with ``eval_fn``.

    Returns:
        tuple: ``(score_dict['f1'], score_dict, [])``. The trailing list is the
        slot for extra per-batch tensors (targets, predictions, ...) and is
        left empty here.
    """
    ids = batch['input_ids'].to(self.device)
    mask = batch['attention_mask'].to(self.device)

    # The unpacking also enforces that the ids tensor is 3-dimensional.
    n_items, _, _ = ids.shape

    start_logits, end_logits = network(ids, mask)

    flat_ids = ids.reshape(n_items, -1)
    predicted = list_to_tuple(extract_entities(flat_ids, start_logits, end_logits, id2label))
    gold = list_to_tuple(batch['gold_entities'])

    scores = eval_fn(predicted, gold)
    return scores['f1'], scores, []
933
+
934
+ # %% [code]
935
class PhoBERTSpanAligner:
    """Align character-level entity offsets with subword token positions.

    The pipeline is: raw text -> words (``word_tokenize``) -> character offset
    of each word -> subword index range of each word -> subword span of each
    entity.
    """

    def __init__(self, tokenizer, max_len):
        # tokenizer: object providing ``tokenize(str)`` and ``__call__`` as used below.
        # max_len:   total number of token positions produced by ``encode``.
        self.tokenizer = tokenizer
        self.max_len = max_len

    # ===== 1. Extract spans =====
    def extract_spans(self, sample):
        """Return one ``{"spans": [(start, end)], "label": ...}`` dict per
        item of ``sample["entities"]``, built from its ``offset`` and ``label``."""
        entity_spans = []

        for event in sample["entities"]:
            entity_spans.append({
                "spans": [tuple(event["offset"])],
                "label": event["label"],
            })

        return entity_spans

    # ===== 2. Word offsets =====
    def build_word_offsets(self, text, words):
        """Locate every word of ``words`` in ``text``, scanning left to right.

        Returns one ``(start, end)`` character range per word, in order. A word
        that cannot be found from the current position gets ``(-1, -1)`` and
        the scan position is left untouched. Without this, ``str.find``
        returning -1 would record a bogus range and move the scan position
        backwards, so later occurrences of repeated words would be matched at
        the wrong place.
        """
        offsets = []
        pointer = 0

        for word in words:
            start = text.find(word, pointer)
            if start == -1:
                # Keep a placeholder so indices stay aligned with ``words``;
                # (-1, -1) can never contain a character position.
                offsets.append((-1, -1))
                continue
            end = start + len(word)
            offsets.append((start, end))
            pointer = end

        return offsets

    # ===== 3. Char -> word =====
    def char_span_to_word_span(self, word_offsets, start, end):
        """Map a character range ``[start, end)`` to ``(start_word, end_word)``
        indices. Either value is None when the boundary falls outside every
        word range (for example on whitespace)."""
        start_word = None
        end_word = None

        for i, (w_start, w_end) in enumerate(word_offsets):
            if w_start <= start < w_end:
                start_word = i
            if w_start < end <= w_end:
                end_word = i

        return start_word, end_word

    # ===== 4. Word -> subword =====
    def word_to_subword_map(self, words):
        """Return the inclusive ``(first, last)`` subword position of each word.

        Positions start at 1 because position 0 is the ``<s>`` token.
        NOTE(review): this assumes tokenizing each word on its own yields the
        same pieces as tokenizing the joined sentence; confirm for the
        tokenizer in use.
        """
        mapping = []
        subword_index = 1  # position 0 is <s>

        for word in words:
            sub_tokens = self.tokenizer.tokenize(word)
            start = subword_index
            end = subword_index + len(sub_tokens) - 1
            mapping.append((start, end))
            subword_index += len(sub_tokens)

        return mapping

    # ===== 5. Span -> subword =====
    def span_to_subword(self, word_offsets, word_subword_map, spans):
        """Convert character spans into inclusive subword spans, silently
        skipping spans whose start or end cannot be mapped to a word."""
        sub_spans = []

        for span_start, span_end in spans:
            w_start, w_end = self.char_span_to_word_span(
                word_offsets, span_start, span_end
            )
            if w_start is None or w_end is None:
                continue

            sub_start = word_subword_map[w_start][0]
            sub_end = word_subword_map[w_end][1]
            sub_spans.append((sub_start, sub_end))

        return sub_spans

    def extract_valid_spans(self, sub_spans):
        """Keep only spans with ``0 <= s <= e < self.max_len``.

        NOTE(review): when the encoded text is truncated, position
        ``max_len - 1`` presumably holds the closing special token rather than
        a word piece, so a span ending there may be off by one. Confirm
        against the tokenizer before tightening this bound.
        """
        valid_spans = []
        for s, e in sub_spans:
            if s < 0 or e < 0 or s >= self.max_len or e >= self.max_len or s > e:
                continue
            valid_spans.append((s, e))
        return valid_spans

    def encode(self, sample):
        """Tokenize ``sample["text"]`` and align its entities to subwords.

        Returns:
            dict: ``input_ids`` and ``attention_mask`` (first row of the
            tokenizer output, padded/truncated to ``self.max_len``) and
            ``entities_gold_spans``, a list of ``(tuple_of_spans, label)`` for
            every entity that has at least one valid subword span.
        """
        text = sample["text"]
        entities = self.extract_spans(sample)

        # ===== 1. Word tokenize =====
        words = word_tokenize(text)
        sentence = " ".join(words)

        # ===== 2. Mapping =====
        word_offsets = self.build_word_offsets(text, words)
        word_subword_map = self.word_to_subword_map(words)

        # ===== 3. Tokenize the full sentence =====
        encoding = self.tokenizer(
            sentence,
            max_length=self.max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        input_ids = encoding["input_ids"][0]
        attention_mask = encoding["attention_mask"][0]

        # ===== 4. Convert spans =====
        entities_gold_spans = []

        for ent in entities:
            label = ent["label"]

            sub_spans = self.span_to_subword(
                word_offsets,
                word_subword_map,
                ent["spans"]
            )
            valid_spans = self.extract_valid_spans(sub_spans)
            if len(valid_spans) == 0:
                continue
            entities_gold_spans.append((tuple(valid_spans), label))

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "entities_gold_spans": entities_gold_spans,
        }
1063
+
1064
def generate_candidate_spans(seq_len, max_span_len):
    """Enumerate every inclusive span ``(i, j)`` with ``1 <= i <= j <= seq_len``
    that covers at most ``max_span_len`` positions, ordered by start then end."""
    return [
        (begin, finish)
        for begin in range(1, seq_len + 1)
        for finish in range(begin, min(begin + max_span_len, seq_len + 1))
    ]
1070
+
1071
class KLTNDataset(Dataset):
    """Dataset producing start/end label sequences for entity extraction.

    Every document is encoded to ``max_len * max_n_parts`` token positions and
    then viewed as ``max_n_parts`` rows of ``max_len`` tokens. Rows beyond the
    last one containing attended tokens are dropped.
    """

    def __init__(self, all_data, using_idxes, label2id, tokenizer, max_len, max_n_parts):
        super().__init__()
        self.all_data = all_data
        self.using_idxes = using_idxes
        self.label2id = label2id
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_n_parts = max_n_parts
        # The aligner works on the whole document: all parts concatenated.
        self.aligner = PhoBERTSpanAligner(tokenizer, max_len*max_n_parts)

    def __len__(self):
        return len(self.using_idxes)

    def __getitem__(self, idx):
        """Return ``input_ids``/``attention_mask`` as (n_valid_parts, max_len),
        flat ``start_labels``/``end_labels`` of length n_valid_parts*max_len,
        and ``gold_entities`` as ``(token_id_tuple, label)`` pairs."""
        sample = self.all_data[self.using_idxes[idx]]
        encoded = self.aligner.encode(sample)

        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)

        # Build labels: 0 on attended positions, -100 on padding positions.
        start_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
        end_labels = start_labels.clone()

        gold_entities = []
        for spans, label in encoded["entities_gold_spans"]:
            # Only the first span of an entity is used.
            first, last = spans[0]
            label_id = self.label2id[f'{label}']

            start_labels[first] = label_id
            end_labels[last] = label_id

            gold_entities.append((tuple(input_ids[first:last+1].tolist()), label))

        n_valid_parts = math.ceil(attention_mask.sum().item() / self.max_len)
        n_valid_tokens = n_valid_parts*self.max_len

        return {
            "input_ids": input_ids.reshape(self.max_n_parts, self.max_len)[:n_valid_parts],
            "attention_mask": attention_mask.reshape(self.max_n_parts, self.max_len)[:n_valid_parts],
            "start_labels": start_labels[:n_valid_tokens],
            "end_labels": end_labels[:n_valid_tokens],
            "gold_entities": gold_entities,
        }
1122
+
1123
+ def _pad_batch(tensor_list, pad_value=0):
1124
+ """
1125
+ tensor_list: list of tensors
1126
+ mỗi tensor shape: (Nk, n_parts_i, max_len_i)
1127
+
1128
+ return:
1129
+ padded tensor shape: (B, max_Nk, max_n_parts, max_len)
1130
+ """
1131
+
1132
+ # lấy max toàn batch
1133
+ max_Nk = max(t.size(0) for t in tensor_list)
1134
+ max_n_parts = max(t.size(1) for t in tensor_list)
1135
+ max_len = max(t.size(2) for t in tensor_list)
1136
+
1137
+ padded = []
1138
+
1139
+ for t in tensor_list:
1140
+ Nk, n_parts_i, max_len_i = t.shape
1141
+
1142
+ # pad chiều n_parts và max_len trước
1143
+ if n_parts_i < max_n_parts or max_len_i < max_len:
1144
+ new_t = t.new_full(
1145
+ (Nk, max_n_parts, max_len),
1146
+ pad_value
1147
+ )
1148
+ new_t[:, :n_parts_i, :max_len_i] = t
1149
+ t = new_t
1150
+
1151
+ # pad chiều Nk
1152
+ if Nk < max_Nk:
1153
+ pad_tensor = t.new_full(
1154
+ (max_Nk - Nk, max_n_parts, max_len),
1155
+ pad_value
1156
+ )
1157
+ t = torch.cat([t, pad_tensor], dim=0)
1158
+
1159
+ padded.append(t)
1160
+
1161
+ return torch.stack(padded) # (B, max_Nk, max_n_parts, max_len)
1162
+
1163
def collate_fn(batch):
    """Collate dataset items into one padded batch.

    ``input_ids`` and ``attention_mask`` are padded with 0, ``start_labels``
    and ``end_labels`` with -100. ``gold_entities`` becomes a flat list of
    ``[index_in_batch, entity]`` pairs.
    """
    gold_entities = [
        [bidx, entity]
        for bidx, item in enumerate(batch)
        for entity in item['gold_entities']
    ]

    def _padded(field, pad_value, n_extra_dims):
        # _pad_batch expects 3-D tensors: add trailing singleton dims to get
        # there, pad, then remove the same number of dims again.
        prepared = []
        for item in batch:
            t = item[field]
            for _ in range(n_extra_dims):
                t = t.unsqueeze(-1)
            prepared.append(t)
        stacked = _pad_batch(prepared, pad_value=pad_value)
        for _ in range(n_extra_dims):
            stacked = stacked.squeeze(-1)
        return stacked

    return {
        "input_ids": _padded("input_ids", 0, 1),
        "attention_mask": _padded("attention_mask", 0, 1),
        "start_labels": _padded("start_labels", -100, 2),
        "end_labels": _padded("end_labels", -100, 2),
        "gold_entities": gold_entities,
    }
1187
+
1188
+ # %% [code]
1189
def shift_bidx(spans, batch_idx):
    """Turn in-batch indices into global ones by adding
    ``batch_idx * batch_size`` (module-level ``batch_size``) to the first
    element of every ``(bidx, entity)`` pair."""
    return [(bidx + batch_idx * batch_size, ent) for bidx, ent in spans]
1195
+
1196
def refactor_entities(entities, save_dict):
    """Split ``(bidx, (ids, label))`` entities into two de-duplicated views
    and append them to ``save_dict``.

    ``save_dict['Ent-I']`` receives ``(bidx, ids)`` (label ignored) and
    ``save_dict['Ent-C']`` receives ``(bidx, (ids, label))``. Duplicates are
    removed within this call only; first-seen order is kept.
    """
    identified, classified = [], []
    for bidx, (ids, lb) in entities:
        span_only = (bidx, ids)
        span_with_label = (bidx, (ids, lb))

        if span_only not in identified:
            identified.append(span_only)
        if span_with_label not in classified:
            classified.append(span_with_label)

    save_dict['Ent-I'].extend(identified)
    save_dict['Ent-C'].extend(classified)
1207
+
1208
def test(network, state_dicts, test_loader, eval_fn, analyzer, device, id2label, tokenizer):
    """Evaluate an ensemble of checkpoints on ``test_loader``.

    For every batch the start/end logits of all ``state_dicts`` are averaged,
    entities are decoded with ``extract_entities`` and collected together with
    the gold entities under global sample indices. Scores are computed once
    over the whole test set for 'Ent-I' (span only) and 'Ent-C' (span + label).

    Returns:
        tuple: ``(final_score, analyze_result, predictions)`` where
        ``predictions[i]`` is ``[decoded_text, (decoded_entity, label), ...]``.

    Raises:
        ValueError: if ``state_dicts`` is empty.
    """
    if len(state_dicts) == 0:
        raise ValueError("test() needs at least one state dict to evaluate")

    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        network = DataParallelProxy(network)
    network = network.to(device)
    network.eval()

    # Weights are loaded into the wrapped module when the proxy is in use.
    target = network.module if multi_gpu else network
    single_model = len(state_dicts) == 1
    if single_model:
        # With one checkpoint there is nothing to swap: load it once instead
        # of re-loading the same weights for every batch.
        target.load_state_dict(state_dicts[0])

    eval_types = ['Ent-I', 'Ent-C']

    all_pred = {eval_type: [] for eval_type in eval_types}
    all_gold = {eval_type: [] for eval_type in eval_types}

    list_input_ids = []

    # Global index of the first sample of the current batch. Counting actual
    # batch sizes keeps indices correct even when the loader's batch size
    # differs from the module-level ``batch_size``.
    sample_offset = 0

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            gold_entities = batch['gold_entities']

            B, _, _ = input_ids.shape
            list_input_ids.extend(input_ids.reshape(B, -1).tolist())

            list_start_logits = []
            list_end_logits = []
            for sd in state_dicts:
                if not single_model:
                    target.load_state_dict(sd)

                start_logits, end_logits = network(input_ids, attention_mask)
                list_start_logits.append(start_logits)
                list_end_logits.append(end_logits)

            ensemble_start_logits = torch.stack(list_start_logits, dim=0).mean(dim=0)
            ensemble_end_logits = torch.stack(list_end_logits, dim=0).mean(dim=0)

            pred_entities = extract_entities(input_ids.reshape(B, -1), ensemble_start_logits, ensemble_end_logits, id2label)
            pred_entities = [(bidx + sample_offset, ent) for bidx, ent in pred_entities]
            refactor_entities(pred_entities, all_pred)

            gold_entities = [(bidx + sample_offset, ent) for bidx, ent in gold_entities]
            refactor_entities(gold_entities, all_gold)

            sample_offset += B

    # ===== GLOBAL EVAL =====
    final_score = {}
    for eval_type in eval_types:
        score = eval_fn(list_to_tuple(all_pred[eval_type]), list_to_tuple(all_gold[eval_type]))
        final_score[eval_type] = score

    analyze_result = analyzer.analyze(list_to_tuple(all_pred['Ent-I']), list_to_tuple(all_gold['Ent-I']))

    # ===== PREDICT =====
    predictions = []
    for input_ids in list_input_ids:
        predictions.append([tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)])
    for bidx, (ids, lb) in all_pred['Ent-C']:
        predictions[bidx].append((tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True), lb))

    return final_score, analyze_result, predictions
1268
+
1269
+ # %% [code]
1270
# Load the train and test splits.
with open(f'{train_dir}/train.json', "r", encoding="utf-8") as f:
    data_train = json.load(f)
print('Train:', len(data_train))

with open(f'{test_dir}/test.json', "r", encoding="utf-8") as f:
    data_test = json.load(f)
print('Test:', len(data_test))

# %% [code]
# Label vocabulary: 'O' first, then every entity label seen in train or test, sorted.
seen_labels = {e['label'] for d in data_train + data_test for e in d['entities']}
entity_types = ['O'] + sorted(seen_labels)
# bio_entity_type = ['O'] + [f'{prefix}-{ent}' for ent in entity_types for prefix in ['B', 'I']]
label2id = {}
for i, l in enumerate(entity_types):
    label2id[l] = i
id2label = {i: l for l, i in label2id.items()}
1284
+
1285
+ # %% [code]
1286
# Down-sample training documents without entities: keep at most
# ``zero_entities_rate`` of them per document that has entities.
zero_entities_idxes = []
for idx, d in enumerate(data_train):
    if len(d['entities']) == 0:
        zero_entities_idxes.append(idx)

n_zero_entities_samples = len(zero_entities_idxes)
n_has_entities_samples = len(data_train) - n_zero_entities_samples

random.seed(42)
# Clamp at 0 so a negative rate cannot make random.sample raise.
k = max(0, min(int(n_has_entities_samples * zero_entities_rate), len(zero_entities_idxes)))
sampled_zero_entities_idxes = random.sample(zero_entities_idxes, k)

# Set membership keeps the filter below linear instead of quadratic.
kept_zero_entities_idxes = set(sampled_zero_entities_idxes)

new_data_train = []
for idx, d in enumerate(data_train):
    if len(d['entities']) == 0:
        if idx in kept_zero_entities_idxes:
            new_data_train.append(d)
    else:
        new_data_train.append(d)
data_train = new_data_train

print('Train:', len(data_train))
1308
+
1309
+ # %% [code]
1310
# In debug mode only the first 10 documents of each split are used.
if debug_only:
    data_train, data_test = data_train[:10], data_test[:10]

print('Train:', len(data_train))
print('Test:', len(data_test))

# %% [code]
tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)

# %% [code]
print('Experiment name:', state_dict_save_name)
1322
+
1323
+ # %% [code]
1324
# Cross-validated training: one model per (seed, fold) pair.
if not test_only:
    full_idxes = np.array(range(len(data_train)))
    training_logs, best_models, last_models = [], [], []
    start_training_time = time.time()
    for seed in SEEDS:
        kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)
        for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)):
            # ``only_fold_idx`` (when set and non-negative) restricts the run to one fold.
            fold_is_selected = only_fold_idx is None or only_fold_idx < 0 or only_fold_idx == fold_idx
            if not fold_is_selected:
                continue
            set_seed(seed)

            train_idxes = full_idxes[tr_idx]
            val_idxes = full_idxes[va_idx]

            trainset = KLTNDataset(data_train, train_idxes, label2id, tokenizer, **train_memory_params)
            valset = KLTNDataset(data_train, val_idxes, label2id, tokenizer, **val_memory_params)

            # Both loaders share one generator seeded with the current seed.
            generator = torch.Generator()
            generator.manual_seed(seed)
            train_loader = DataLoader(trainset, generator=generator, collate_fn=collate_fn, **train_loader_params)
            val_loader = DataLoader(valset, generator=generator, collate_fn=collate_fn, **val_loader_params)

            my_model = IEModel(
                num_labels=len(label2id),
                **model_params
            )
            total_params = 0
            for p in my_model.parameters():
                total_params += p.numel()
            print(f"Total params: {total_params:,}")

            # optimizer, scheduler = configure_optimizers(my_model, optim_params, scheduler_params)
            # Two parameter groups: the encoder gets lr=2e-5, everything else
            # falls back to the optimizer default lr=5e-4.
            encoder_params = {id(p) for p in my_model.encoder.parameters()}
            other_params = []
            for p in my_model.parameters():
                if id(p) not in encoder_params:
                    other_params.append(p)
            optimizer = optim.AdamW([
                {"params": my_model.encoder.parameters(), "lr": 2e-5},
                {"params": other_params}
            ], lr=5e-4)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)

            loss_fn = CustomLoss(**loss_func_params)
            eval_fn = CustomEvalFn(**eval_func_params)
            trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}'
            trainer = Trainer(**trainer_params)

            print(f'Start Training Fold {fold_idx}...')
            training_log, best_model, last_model = trainer.fit(
                my_model, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader, eval_fn,
                start_epoch=1, start_training_time=start_training_time, id2label=id2label
            )

            training_logs.append(training_log)
            best_models.append(best_model)
            last_models.append(last_model)
1380
+
1381
+ # %% [code]
1382
def load_all_state_dicts(folder):
    """Load every ``.pt`` / ``.pth`` checkpoint in ``folder``, ordered by fold.

    The fold number is read from the ``_s<seed>_f<fold>`` part of the file
    name. Files that do not follow that naming fall back to the first
    ``f<digits>`` found anywhere in the name; files with no fold marker are
    skipped. Ties are broken by file name so the order is deterministic.

    Returns:
        list: the loaded state dicts, mapped to CPU.
    """
    strict_fold = re.compile(r"_s\d+_f(\d+)(?=[_.])")
    loose_fold = re.compile(r"f(\d+)")  # legacy: look for f<number> anywhere

    files = []
    for file in os.listdir(folder):
        if not file.endswith((".pt", ".pth")):
            continue
        # Prefer the strict pattern: the loose one can match an "f<digit>"
        # inside the experiment name and report the wrong fold.
        m = strict_fold.search(file) or loose_fold.search(file)
        if m:
            files.append((int(m.group(1)), file))

    # Sort by fold, then by file name (os.listdir order is arbitrary).
    files.sort()

    state_dicts = []
    for fold, file in files:
        path = os.path.join(folder, file)
        print(f"Loading fold {fold}: {file}")
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(path, map_location="cpu")
        state_dicts.append(state_dict)

    return state_dicts
1403
+
1404
# In test-only mode the checkpoints of this experiment are fetched from the
# hub instead of coming from the training loop.
if test_only:
    snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=[f"{state_dict_save_name}/**"])
    get_ipython().system('rm -rf .cache .gitattributes')

    best_models = load_all_state_dicts(f"{state_dict_save_name}/r1s")
    last_models = load_all_state_dicts(f"{state_dict_save_name}/lasts")

# %% [code]
# Test loader, metric, error analyzer and a fresh model to load weights into.
os.makedirs(f'{checkpoints_dir}/results', exist_ok=True)
testset = KLTNDataset(data_test, range(len(data_test)), label2id, tokenizer, **val_memory_params)
generator = torch.Generator()
test_loader = DataLoader(testset, generator=generator, collate_fn=collate_fn, **val_loader_params)
eval_fn = CustomEvalFn(**eval_func_params)
analyzer = SpanErrorAnalyzer()
my_model = IEModel(num_labels=len(label2id), **model_params)
total_params = 0
for p in my_model.parameters():
    total_params += p.numel()
print(f"Total params: {total_params:,}")
1424
+
1425
+ # %% [code]
1426
# Evaluate the best and the last checkpoints on the test set and store the
# scores, the error analysis and the decoded predictions as JSON files.
start_time = time.time()

best_score, best_analyze_result, best_pred_test = test(my_model, best_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
last_score, last_analyze_result, last_pred_test = test(my_model, last_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)

result_test = {"Best model": best_score, "Last model": last_score}
analyze_result = {"Best model": best_analyze_result, "Last model": last_analyze_result}
analyze_result_sumary = {"Best model": best_analyze_result['summary'], "Last model": last_analyze_result['summary']}
pred_test = {"Best model": best_pred_test, "Last model": last_pred_test}


def _dump_json(payload, path):
    # Write ``payload`` as UTF-8 JSON with readable indentation.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)


_dump_json(result_test, f"{checkpoints_dir}/results/{state_dict_save_name}_test.json")
_dump_json(analyze_result, f"{checkpoints_dir}/results/{state_dict_save_name}_error_analyze_result.json")
_dump_json(pred_test, f"{checkpoints_dir}/results/{state_dict_save_name}_pred_test.json")

print('Test:', time.time() - start_time, 's --> Done!')
print(json.dumps(analyze_result_sumary, ensure_ascii=False, indent=4))

# %% [code]
best_pred_test[:10]

# %% [code]
last_pred_test[:10]
1455
+
1456
+ # %% [code]
1457
def dict_to_df(data):
    """Turn ``{model: {evaluation: {metric: value}}}`` into a DataFrame.

    Rows are the evaluation keys of the first model, columns are a MultiIndex
    of ``(model, metric)`` for precision, recall and f1. Rows are sorted in
    descending order by the f1 of "Best model", then "Last model", for
    whichever of those columns exist.
    """
    metrics = ["precision", "recall", "f1"]

    # The first model defines which evaluations become rows.
    reference_model = next(iter(data.values()))
    eval_keys = list(reference_model.keys())

    records = [
        {
            (model_name, metric): model_data[eval_key][metric]
            for model_name, model_data in data.items()
            for metric in metrics
        }
        for eval_key in eval_keys
    ]

    # ===== DataFrame =====
    df = pd.DataFrame(records)
    df.columns = pd.MultiIndex.from_tuples(df.columns)
    df.index = pd.Index(eval_keys, name="evaluation")

    # ===== Sort =====
    sort_keys = [
        column
        for column in (("Best model", "f1"), ("Last model", "f1"))
        if column in df.columns
    ]
    if sort_keys:
        df = df.sort_values(by=sort_keys, ascending=False)

    return df
1499
+
1500
# Full score table, exported to Excel and displayed.
result_test_df = dict_to_df(result_test)
result_test_df.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df.xlsx")
result_test_df

# %% [code]
# One row per evaluation: the one with the highest "Best model" f1.
key = ("Best model", "f1")
ranked_by_best_f1 = result_test_df.sort_values(by=key, ascending=False)
result_test_df_best = ranked_by_best_f1.groupby(level="evaluation").head(1)
result_test_df_best.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df_best.xlsx")
result_test_df_best
1509
+
1510
+ # %% [code]
1511
def get_avg_best_score(logs):
    """Average the final ``best_score`` over several training logs.

    Each log maps epoch -> entry dict. The value used for a log is the
    ``best_score`` of its most recent entry that has one, so a trailing epoch
    without validation no longer raises ``KeyError``. Logs without any
    ``best_score`` are ignored.

    Returns:
        float: the mean, or NaN when no log contains a ``best_score``.
    """
    finals = []
    for log in logs:
        for entry in reversed(list(log.values())):
            if 'best_score' in entry:
                finals.append(entry['best_score'])
                break
    if not finals:
        return float('nan')
    return float(np.mean(finals))
1513
+
1514
def get_avg_log(logs, epochs):
    """Average ``train_loss`` and ``val_score`` per epoch across several logs.

    Epoch keys may be ints or their string form (as after a JSON round trip).
    Epochs present in no log are left out of the result.

    ``train_loss`` is averaged over the logs that contain the epoch.
    ``val_score`` is averaged only over the logs that actually recorded one
    for that epoch; when none did, it is ``float('inf')``.

    Returns:
        dict: ``{epoch: {'train_loss': float, 'val_score': float}}`` for
        epochs 1..``epochs``.
    """
    avg_log = {}

    for epoch in range(1, epochs + 1):
        train_sum, n_train = 0.0, 0
        val_sum, n_val = 0.0, 0

        for run in logs:
            entry = run.get(epoch, run.get(str(epoch)))
            if entry is None:
                continue

            train_sum += entry.get('train_loss', 0.0)
            n_train += 1
            # Count validation separately: not every logged epoch is
            # validated, and a genuine score of 0 must not be taken as missing.
            if 'val_score' in entry:
                val_sum += entry['val_score']
                n_val += 1

        if n_train == 0:
            continue

        avg_log[epoch] = {
            'train_loss': train_sum / n_train,
            'val_score': val_sum / n_val if n_val else float('inf')
        }

    return avg_log
1540
+
1541
def parse_label_key(label: str):
    """Build a numeric sort key ``(first, last)`` from an experiment label.

    ``first`` is the number before the first underscore and ``last`` is the
    number after the final underscore, e.g. ``"4_doc_level_entities_5"`` gives
    ``(4.0, 5.0)``. Labels that do not follow this shape yield ``(0, 0)``.
    """
    try:
        first = float(label.split('_', 1)[0])  # leading number: before the first "_"
        last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0])
        return first, last
    except Exception:
        # Deliberately tolerant, but unlike a bare ``except`` this lets
        # KeyboardInterrupt / SystemExit propagate.
        return (0, 0)
1548
+
1549
def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)):
    """Plot train loss (left) and validation score (right) per experiment.

    Args:
        logs_dict: ``{name: {epoch: {'train_loss': ..., 'val_score': ...}}}``.
        save_path: optional image path; its parent directory is created.
        figsize: figure size passed to ``plt.subplots``.

    One legend, sorted with ``parse_label_key``, is placed to the right of the
    loss plot. With nothing to plot the legend is skipped instead of raising.
    """
    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # ===== Plot Train Loss =====
    for name, log in logs_dict.items():
        epochs = sorted(log.keys())
        train_loss = [log[e]['train_loss'] for e in epochs]
        axes[0].plot(epochs, train_loss, label=name)

    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Train Loss')
    axes[0].set_title('Training Loss')
    axes[0].grid(True)

    # ===== Plot Validation Score =====
    for name, log in logs_dict.items():
        epochs = sorted(log.keys())
        val_score = [log[e]['val_score'] for e in epochs]
        axes[1].plot(epochs, val_score, label=name)

    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Validation Score')
    axes[1].set_title('Validation Score')
    axes[1].grid(True)

    # ===== Shared Legend =====
    handles, labels = axes[0].get_legend_handles_labels()
    pairs_sorted = sorted(
        zip(handles, labels),
        key=lambda x: parse_label_key(x[1])
    )
    # zip(*[]) cannot be unpacked into two names, so only build the legend
    # when at least one curve was drawn.
    if pairs_sorted:
        handles_sorted, labels_sorted = zip(*pairs_sorted)

        axes[0].legend(
            handles_sorted,
            labels_sorted,
            loc='center left',
            bbox_to_anchor=(1.01, 0.5),
            borderaxespad=0.
        )

    plt.tight_layout(rect=[0, 0, 1, 1])

    if save_path is not None:
        target_dir = os.path.dirname(save_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()
1598
+
1599
+ # %% [code]
1600
# Fetch the JSON logs of all experiments stored in the hub repository.
if not test_only:
    snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=["**/*.json"])
    get_ipython().system('rm -rf .cache .gitattributes')

# %% [code]
# Average the per-(seed, fold) logs of every experiment for comparison plots.
if not test_only:
    experiments = {}
    for experiment in os.listdir(pretrained_dir):
        if '.virtual_documents' in experiment:
            continue
        experiment_logs = []
        for seed in SEEDS:
            for fold_idx in range(nfolds):
                log_path = f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json"
                # Best effort per file: a missing or unreadable log for one
                # (seed, fold) must not hide the other logs of the experiment.
                try:
                    with open(log_path, "r", encoding="utf-8") as f:
                        experiment_log = json.load(f)
                except (OSError, ValueError):
                    continue
                experiment_logs.append(experiment_log)
        experiments[experiment] = get_avg_log(experiment_logs, 1000)
    # The current run always uses its in-memory logs.
    experiments[state_dict_save_name] = get_avg_log(training_logs, 1000)

# %% [code]
if not test_only:
    score = get_avg_best_score(training_logs)
    # A bare expression inside ``if`` is not displayed by a notebook; print it.
    print(state_dict_save_name, score)

# %% [code]
if not test_only:
    plot_training_logs(experiments, save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg', figsize=(18, 7.5))
1630
+
4_doc_level_entities_5/lasts/4_doc_level_entities_5_s26092004_f0_last_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:989e461477030756eebc74b7390f7a06e4f25ca16c6562a295e536c959794b1e
3
+ size 544836586
4_doc_level_entities_5/logs/4_doc_level_entities_5_log_plot.jpg ADDED

Git LFS Details

  • SHA256: 5b9ee3babb5ac258b6a693b0cbe9fe378336522c2e78f1ee4074177f9531bc3f
  • Pointer size: 131 Bytes
  • Size of remote file: 543 kB
4_doc_level_entities_5/logs/4_doc_level_entities_5_s26092004_f0_logging.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"1": {"lr": [2e-05, 0.0005], "train_loss": 0.11218122392892838, "total": 0.112181226545353, "start_loss": 0.05426829794178838, "end_loss": 0.057912820158993523}, "2": {"lr": [1.988303923565381e-05, 0.0004969282409784868], "train_loss": 0.051757387816905975, "total": 0.05175738829035424, "start_loss": 0.025174654447115384, "end_loss": 0.026582730653692647}, "3": {"lr": [1.9535036904803962e-05, 0.0004877886008156408], "train_loss": 0.050453245639801025, "total": 0.050453246636534214, "start_loss": 0.024131022169438493, "end_loss": 0.026322175029129487}, "4": {"lr": [1.8964561979789496e-05, 0.00047280612778499774], "train_loss": 0.04355930536985397, "total": 0.0435593040491825, "start_loss": 0.021132941230084985, "end_loss": 0.02242637238773614}, "5": {"lr": [1.8185661446562005e-05, 0.00045234974009654937], "train_loss": 0.04147800803184509, "total": 0.04147800713478522, "start_loss": 0.020000786286931374, "end_loss": 0.021477184168072448}, "6": {"lr": [1.7217514421272206e-05, 0.00042692314190604356], "train_loss": 0.038297928869724274, "total": 0.038297927499216136, "start_loss": 0.018655312898565694, "end_loss": 0.01964264649611253}, "7": {"lr": [1.60839598967785e-05, 0.00039715242044697206], "train_loss": 0.03591204434633255, "total": 0.03591204486961747, "start_loss": 0.017833500801520203, "end_loss": 0.018078555231509003}, "8": {"lr": [1.4812909747525698e-05, 0.00036377062968501693], "train_loss": 0.03257141634821892, "total": 0.032571416236086434, "start_loss": 0.016170229002783528, "end_loss": 0.016401166501252548, "val_score": 0.6940790796844161, "best_score": 0.6940790796844161, "new_best_model": true, "precision": 0.7583838288443805, "recall": 0.6418669729906946, "f1": 0.6940790796844161}, "9": {"lr": [1.3435661446562005e-05, 0.0003275997400965494], "train_loss": 0.02933138608932495, "total": 0.029331385889978315, "start_loss": 0.014569072021688506, "end_loss": 0.014762310678743598, "val_score": 0.6921624291792371, "best_score": 0.6940790796844161, 
"new_best_model": false, "precision": 0.7511150914707015, "recall": 0.6437836669090866, "f1": 0.6921624291792371}, "10": {"lr": [1.1986127417882198e-05, 0.00028953039902753766], "train_loss": 0.026413651183247566, "total": 0.026413651214395477, "start_loss": 0.013195129541250376, "end_loss": 0.013218529647010625, "val_score": 0.6909192478816919, "best_score": 0.6940790796844161, "new_best_model": false, "precision": 0.7476423364255042, "recall": 0.6442246008215343, "f1": 0.6909192478816919}, "11": {"lr": [1.0500000000000003e-05, 0.0002505], "train_loss": 0.023491036146879196, "total": 0.023491036533113307, "start_loss": 0.011716610611880503, "end_loss": 0.011774446653283161, "val_score": 0.6898487763592059, "best_score": 0.6940790796844161, "new_best_model": false, "precision": 0.7446957700484234, "recall": 0.6445442603504218, "f1": 0.6898487763592059}, "12": {"lr": [9.013872582117811e-06, 0.00021146960097246246], "train_loss": 0.021468782797455788, "total": 0.02146878290335869, "start_loss": 0.010837098029146227, "end_loss": 0.01063168008989315, "val_score": 0.6863720192343292, "best_score": 0.6940790796844161, "new_best_model": false, "precision": 0.7411644152446474, "recall": 0.64116415736283, "f1": 0.6863720192343292}, "13": {"lr": [7.564338553438001e-06, 0.00017340025990345064], "train_loss": 0.01929703913629055, "total": 0.019297040027120838, "start_loss": 0.009791418860189891, "end_loss": 0.009505625951250261, "val_score": 0.6860198614724654, "best_score": 0.6940790796844161, "new_best_model": false, "precision": 0.7413199740858584, "recall": 0.6404082609825267, "f1": 0.6860198614724654}, "14": {"lr": [6.1870902524743065e-06, 0.00013722937031498307], "train_loss": 0.01754027046263218, "total": 0.017540269870821846, "start_loss": 0.008926345353142474, "end_loss": 0.008613921328133165, "val_score": 0.6850928313451818, "best_score": 0.6940790796844161, "new_best_model": false, "precision": 0.7409721990232857, "recall": 0.6392113813963958, "f1": 
0.6850928313451818}, "15": {"lr": [4.916040103221507e-06, 0.00010384757955302797], "train_loss": 0.015996301546692848, "total": 0.015996301453249113, "start_loss": 0.008183492864653419, "end_loss": 0.007812814170301559, "val_score": 0.6835492727374415, "best_score": 0.6940790796844161, "new_best_model": false, "precision": 0.7404580245033745, "recall": 0.6369536981327771, "f1": 0.6835492727374415}, "16": {"lr": [3.7824855787278e-06, 7.40768580939564e-05], "train_loss": 0.014749271795153618, "total": 0.014749271813842365, "start_loss": 0.007579577009016056, "end_loss": 0.007169686830960787, "val_score": 0.6834176192629552, "best_score": 0.6940790796844161, "new_best_model": false, "precision": 0.7402310798281497, "recall": 0.636857565242963, "f1": 0.6834176192629552}, "17": {"lr": [2.814338553438001e-06, 4.865025990345063e-05], "train_loss": 0.013936230912804604, "total": 0.013936230968870846, "start_loss": 0.007186701465211186, "end_loss": 0.00674954066707139, "val_score": 0.6826696675128174, "best_score": 0.6940790796844161, "new_best_model": false, "precision": 0.7403559238036548, "recall": 0.6355319830339294, "f1": 0.6826696675128174}, "18": {"lr": [2.0354380202105066e-06, 2.8193872215002235e-05], "train_loss": 0.013103893958032131, "total": 0.013103894166723144, "start_loss": 0.0067786979037383725, "end_loss": 0.006325190681278905, "val_score": 0.6824660475925308, "best_score": 0.6940790796844161, "new_best_model": false, "precision": 0.7407196580353178, "recall": 0.634898670043125, "f1": 0.6824660475925308}}
4_doc_level_entities_5/r1s/4_doc_level_entities_5_s26092004_f0_r1_vs0.69408_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77c70a0401c2ed198a5690b78d4e1b0534eadcb2adbd991790abc150f65362f8
3
+ size 544838290
4_doc_level_entities_5/results/4_doc_level_entities_5_error_analyze_result.json ADDED
The diff for this file is too large to render. See raw diff
 
4_doc_level_entities_5/results/4_doc_level_entities_5_pred_test.json ADDED
The diff for this file is too large to render. See raw diff
 
4_doc_level_entities_5/results/4_doc_level_entities_5_test.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Best model": {
3
+ "Ent-I": {
4
+ "precision": 0.8314164648902023,
5
+ "recall": 0.7037574722453426,
6
+ "f1": 0.7622791551485675
7
+ },
8
+ "Ent-C": {
9
+ "precision": 0.7554479418878577,
10
+ "recall": 0.6389078498288063,
11
+ "f1": 0.6923076873419802
12
+ }
13
+ },
14
+ "Last model": {
15
+ "Ent-I": {
16
+ "precision": 0.8109859435741092,
17
+ "recall": 0.6947053800164861,
18
+ "f1": 0.7483556364439702
19
+ },
20
+ "Ent-C": {
21
+ "precision": 0.7360446570965549,
22
+ "recall": 0.6300341296922951,
23
+ "f1": 0.6789260707926343
24
+ }
25
+ }
26
+ }
4_doc_level_entities_5/results/4_doc_level_entities_5_test_df.xlsx ADDED
Binary file (5.29 kB). View file
 
4_doc_level_entities_5/results/4_doc_level_entities_5_test_df_best.xlsx ADDED
Binary file (5.29 kB). View file